├── .travis.yml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── cmd └── tlsrouter │ ├── README.md │ ├── config.go │ ├── config_test.go │ ├── e2e_test.go │ ├── main.go │ ├── sni.go │ └── sni_test.go ├── go.mod ├── go.sum ├── http.go ├── listener.go ├── listener_test.go ├── scripts └── prune_old_versions.go ├── sni.go ├── systemd └── tlsrouter.service ├── tcpproxy.go └── tcpproxy_test.go /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | go: 3 | - "1.16.x" 4 | - "1.17.x" 5 | - tip 6 | os: 7 | - linux 8 | script: 9 | - go build ./... 10 | - go test ./... 11 | - go vet ./... 12 | 13 | jobs: 14 | include: 15 | - stage: deploy 16 | go: "1.16" 17 | install: 18 | - gem install fpm 19 | script: 20 | - go build ./cmd/tlsrouter 21 | - fpm -s dir -t deb -n tlsrouter -v $(date '+%Y%m%d%H%M%S') 22 | --license Apache2 23 | --vendor "David Anderson " 24 | --maintainer "David Anderson " 25 | --description "TLS SNI router" 26 | --url "https://github.com/inetaf/tcpproxy/tree/master/cmd/tlsrouter" 27 | ./tlsrouter=/usr/bin/tlsrouter 28 | ./systemd/tlsrouter.service=/lib/systemd/system/tlsrouter.service 29 | deploy: 30 | - provider: packagecloud 31 | repository: tlsrouter 32 | username: danderson 33 | dist: debian/stretch 34 | skip_cleanup: true 35 | on: 36 | branch: master 37 | token: 38 | secure: 
gNU3o70EU4oYeIS6pr0K5oLMGqqxrcf41EOv6c/YoHPVdV6Cx4j9NW0/ISgu6a1/Xf2NgWKT5BWwLpAuhmGdALuOz1Ah//YBWd9N8mGHGaC6RpOPDU8/9NkQdBEmjEH9sgX4PNOh1KQ7d7O0OH0g8RqJlJa0MkUYbTtN6KJ29oiUXxKmZM4D/iWB8VonKOnrtx1NwQL8jL8imZyEV/1fknhDwumz2iKeU1le4Neq9zkxwICMLUonmgphlrp+SDb1EOoHxT6cn51bqBQtQUplfC4dN4OQU/CPqE9E1N1noibvN29YA93qfcrjD3I95KT9wzq+3B6he33+kb0Gz+Cj5ypGy4P85l7TuX4CtQg0U3NAlJCk32IfsdjK+o47pdmADij9IIb9yKt+g99FMERkJJY5EInqEsxHlW/vNF5OqQCmpiHstZL4R2XaHEsWh6j77npnjjC1Aea8xZTWr8PTsbSzVkbG7bTmFpZoPH8eEmr4GNuw5gnbi6D1AJDjcA+UdY9s5qZNpzuWOqfhOFxL+zUW+8sHBvcoFw3R+pwHECs2LCL1c0xAC1LtNUnmW/gnwHavtvKkzErjR1P8Xl7obCbeChJjp+b/BcFYlNACldZcuzBAPyPwIdlWVyUonL4bm63upfMEEShiAIDDJ21y7fjsQK7CfPA7g25bpyo+hV8= 39 | - provider: script 40 | on: 41 | branch: master 42 | script: go run scripts/prune_old_versions.go -user=danderson -repo=tlsrouter -distro=debian -version=stretch -package=tlsrouter -arch=amd64 -limit=2 43 | env: 44 | # Packagecloud API key, for prune_old_versions.go 45 | - secure: "SRcNwt+45QyPS1w9aGxMg9905Y6d9w4mBM29G6iTTnUB5nD7cAk4m+tf834knGSobVXlWcRnTDW8zrHdQ9yX22dPqCpH5qE+qzTmIvxRHrVJRMmPeYvligJ/9jYfHgQbvuRT8cUpIcpCQAla6rw8nXfKTOE3h8XqMP2hdc3DTVOu2HCfKCNco1tJ7is+AIAnFV2Wpsbb3ZsdKFvHvi2RKUfFaX61J1GNt2/XJIlZs8jC6Y1IAC+ftjql9UsAE/WjZ9fL0Ww1b9/LBIIGHXWI3HpVv9WvlhhIxIlJgOVjmU2lbSuj2w/EBDJ9cd1Qe+wJkT3yKzE1NRsNScVjGg+Ku5igJu/XXuaHkIX01+15BqgPduBYRL0atiNQDhqgBiSyVhXZBX9vsgsp0bgpKaBSF++CV18Q9dara8aljqqS33M3imO3I8JmXU10944QA9Wvu7pCYuIzXxhINcDXRvqxBqz5LnFJGwnGqngTrOCSVS2xn7Y+sjmhe1n5cPCEISlozfa9mPYPvMPp8zg3TbATOOM8CVfcpaNscLqa/+SExN3zMwSanjNKrBgoaQcBzGW5mIgSPxhXkWikBgapiEN7+2Y032Lhqdb9dYjH+EuwcnofspDjjMabWxnuJaln+E3/9vZi2ooQrBEtvymUTy4VMSnqwIX5bU7nPdIuQycdWhk=" 46 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | Contributions are welcome by pull request. 
2 | 3 | You need to sign the Google Contributor License Agreement before your 4 | contributions can be accepted. You can find the individual and organization 5 | level CLAs here: 6 | 7 | Individual: https://cla.developers.google.com/about/google-individual 8 | Organization: https://cla.developers.google.com/about/google-corporate 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 
30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 
62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tcpproxy 2 | 3 | For library usage, see https://pkg.go.dev/github.com/inetaf/tcpproxy/ 4 | 5 | For CLI usage, see https://github.com/inetaf/tcpproxy/blob/master/cmd/tlsrouter/README.md 6 | -------------------------------------------------------------------------------- /cmd/tlsrouter/README.md: -------------------------------------------------------------------------------- 1 | # TLS SNI router 2 | 3 | [![license](https://img.shields.io/github/license/google/tlsrouter.svg?maxAge=2592000)](https://github.com/inetaf/tcpproxy/blob/master/LICENSE) [![Travis](https://img.shields.io/travis/google/tlsrouter.svg?maxAge=2592000)](https://travis-ci.org/google/tlsrouter) [![api](https://img.shields.io/badge/api-unstable-red.svg)](https://godoc.org/go.universe.tf/tlsrouter) 4 | 5 | TLSRouter is a TLS proxy that routes connections to backends based on 6 | the TLS SNI (Server Name Indication) of the TLS handshake. 
It carries 7 | no encryption keys and cannot decode the traffic that it proxies. 8 | 9 | ## Installation 10 | 11 | Install TLSRouter via `go get`: 12 | 13 | ```shell 14 | go get go.universe.tf/tcpproxy/cmd/tlsrouter 15 | ``` 16 | 17 | ## Usage 18 | 19 | TLSRouter requires a configuration file that tells it what backend to 20 | use for a given hostname. The config file looks like: 21 | 22 | ``` 23 | # Basic hostname -> backend mapping 24 | go.universe.tf localhost:1234 25 | 26 | # DNS wildcards are understood as well; each * matches exactly one label. 27 | *.go.universe.tf 1.2.3.4:8080 28 | 29 | # DNS wildcards can go anywhere in the name. 30 | google.* 10.20.30.40:443 31 | 32 | # RE2 regexes, delimited by slashes, are also available. 33 | /(alpha|beta|gamma)\.mon(itoring)?\.dave\.tf/ 100.200.100.200:443 34 | 35 | # If your backend supports HAProxy's PROXY protocol, you can enable 36 | # it to receive the real client ip:port. 37 | 38 | fancy.backend 2.3.4.5:443 PROXY 39 | ``` 40 | 41 | TLSRouter takes one mandatory command-line flag, `-conf`, naming the configuration file to use: 42 | 43 | ```shell 44 | tlsrouter -conf tlsrouter.conf 45 | ``` 46 | 47 | Optional flags are: 48 | 49 | * `-listen ADDR`: set the listen address (default `:443`) 50 | * `-hello-timeout DURATION`: how long to wait for the start of the 51 | TLS handshake (default `3s`) 52 | -------------------------------------------------------------------------------- /cmd/tlsrouter/config.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Google Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package main 16 | 17 | import ( 18 | "bufio" 19 | "bytes" 20 | "errors" 21 | "fmt" 22 | "io" 23 | "os" 24 | "regexp" 25 | "strings" 26 | "sync" 27 | ) 28 | 29 | // A Route maps a match on a domain name to a backend. 30 | type Route struct { 31 | match *regexp.Regexp 32 | backend string 33 | proxyInfo bool 34 | } 35 | 36 | // Config stores the TLS routing configuration. 37 | type Config struct { 38 | mu sync.Mutex 39 | routes []Route 40 | } 41 | 42 | func dnsRegex(s string) (*regexp.Regexp, error) { 43 | if len(s) >= 2 && s[0] == '/' && s[len(s)-1] == '/' { 44 | return regexp.Compile(s[1 : len(s)-1]) 45 | } 46 | 47 | var b []string 48 | for _, f := range strings.Split(s, ".") { 49 | switch f { 50 | case "*": 51 | b = append(b, `[^.]+`) 52 | case "": 53 | return nil, fmt.Errorf("DNS name %q has empty label", s) 54 | default: 55 | b = append(b, regexp.QuoteMeta(f)) 56 | } 57 | } 58 | return regexp.Compile(fmt.Sprintf("^%s$", strings.Join(b, `\.`))) 59 | } 60 | 61 | // Match returns the backend for hostname, and whether to use the PROXY protocol. 62 | func (c *Config) Match(hostname string) (string, bool) { 63 | c.mu.Lock() 64 | defer c.mu.Unlock() 65 | 66 | for _, r := range c.routes { 67 | if r.match.MatchString(hostname) { 68 | return r.backend, r.proxyInfo 69 | } 70 | } 71 | return "", false 72 | } 73 | 74 | // Read replaces the current Config with one read from r. 75 | func (c *Config) Read(r io.Reader) error { 76 | var routes []Route 77 | var backends []string 78 | 79 | s := bufio.NewScanner(r) 80 | for s.Scan() { 81 | if strings.HasPrefix(strings.TrimSpace(s.Text()), "#") { 82 | // Comment, ignore. 
83 | continue 84 | } 85 | 86 | fs := strings.Fields(s.Text()) 87 | switch len(fs) { 88 | case 0: 89 | continue 90 | case 1: 91 | return fmt.Errorf("invalid %q on a line by itself", s.Text()) 92 | case 2: 93 | re, err := dnsRegex(fs[0]) 94 | if err != nil { 95 | return err 96 | } 97 | routes = append(routes, Route{re, fs[1], false}) 98 | backends = append(backends, fs[1]) 99 | case 3: 100 | re, err := dnsRegex(fs[0]) 101 | if err != nil { 102 | return err 103 | } 104 | if fs[2] != "PROXY" { 105 | return errors.New("third item on a line can only be PROXY") 106 | } 107 | routes = append(routes, Route{re, fs[1], true}) 108 | backends = append(backends, fs[1]) 109 | default: 110 | // TODO: multiple backends? 111 | return fmt.Errorf("too many fields on line: %q", s.Text()) 112 | } 113 | } 114 | if err := s.Err(); err != nil { 115 | return err 116 | } 117 | 118 | c.mu.Lock() 119 | defer c.mu.Unlock() 120 | c.routes = routes 121 | return nil 122 | } 123 | 124 | // ReadFile replaces the current Config with one read from path. 125 | func (c *Config) ReadFile(path string) error { 126 | f, err := os.Open(path) 127 | if err != nil { 128 | return err 129 | } 130 | return c.Read(f) 131 | } 132 | 133 | // ReadString replaces the current Config with one read from cfg. 
134 | func (c *Config) ReadString(cfg string) error { 135 | b := bytes.NewBufferString(cfg) 136 | return c.Read(b) 137 | } 138 | -------------------------------------------------------------------------------- /cmd/tlsrouter/config_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | ) 7 | 8 | func TestConfig(t *testing.T) { 9 | type result struct { 10 | backend string 11 | proxy bool 12 | } 13 | 14 | cases := []struct { 15 | Config string 16 | Tests map[string]result 17 | }{ 18 | { 19 | Config: ` 20 | # Comment 21 | go.universe.tf 1.2.3.4 22 | *.universe.tf 2.3.4.5 23 | # Comment 24 | google.* 3.4.5.6 25 | /gooo+gle\.com/ 4.5.6.7 26 | foobar.net 6.7.8.9 PROXY 27 | `, 28 | Tests: map[string]result{ 29 | "go.universe.tf": result{"1.2.3.4", false}, 30 | "foo.universe.tf": result{"2.3.4.5", false}, 31 | "bar.universe.tf": result{"2.3.4.5", false}, 32 | "google.com": result{"3.4.5.6", false}, 33 | "google.fr": result{"3.4.5.6", false}, 34 | "goooooooooogle.com": result{"4.5.6.7", false}, 35 | "foobar.net": result{"6.7.8.9", true}, 36 | 37 | "blah.com": result{"", false}, 38 | "google.com.br": result{"", false}, 39 | "foo.bar.universe.tf": result{"", false}, 40 | "goooooglexcom": result{"", false}, 41 | }, 42 | }, 43 | } 44 | 45 | for _, test := range cases { 46 | var cfg Config 47 | if err := cfg.Read(bytes.NewBufferString(test.Config)); err != nil { 48 | t.Fatalf("Failed to read config (%s):\n%q", err, test.Config) 49 | } 50 | 51 | for hostname, expected := range test.Tests { 52 | backend, proxy := cfg.Match(hostname) 53 | if expected.backend != backend { 54 | t.Errorf("cfg.Match(%q) is %q, want %q", hostname, backend, expected.backend) 55 | } 56 | if expected.proxy != proxy { 57 | t.Errorf("cfg.Match(%q).proxy is %v, want %v", hostname, proxy, expected.proxy) 58 | } 59 | } 60 | } 61 | } 62 | 
-------------------------------------------------------------------------------- /cmd/tlsrouter/e2e_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "crypto/rand" 6 | "crypto/rsa" 7 | "crypto/tls" 8 | "crypto/x509" 9 | "crypto/x509/pkix" 10 | "encoding/pem" 11 | "fmt" 12 | "io/ioutil" 13 | "math/big" 14 | "net" 15 | "strings" 16 | "sync/atomic" 17 | "testing" 18 | "time" 19 | 20 | proxyproto "github.com/armon/go-proxyproto" 21 | ) 22 | 23 | func TestRouting(t *testing.T) { 24 | // Backend servers 25 | s1, err := serveTLS(t, "server1", false, "test.com") 26 | if err != nil { 27 | t.Fatalf("serve TLS server1: %s", err) 28 | } 29 | defer s1.Close() 30 | 31 | s2, err := serveTLS(t, "server2", false, "foo.net") 32 | if err != nil { 33 | t.Fatalf("serve TLS server2: %s", err) 34 | } 35 | defer s2.Close() 36 | 37 | s4, err := serveTLS(t, "server4", true, "proxy.design") 38 | if err != nil { 39 | t.Fatalf("server TLS server4: %s", err) 40 | } 41 | defer s4.Close() 42 | 43 | // One proxy 44 | var p Proxy 45 | l, err := net.Listen("tcp", "localhost:0") 46 | if err != nil { 47 | t.Fatalf("create listener: %s", err) 48 | } 49 | defer l.Close() 50 | go p.Serve(l) 51 | 52 | if err := p.Config.ReadString(fmt.Sprintf(` 53 | test.com %s 54 | foo.net %s 55 | proxy.design %s PROXY 56 | `, s1.Addr(), s2.Addr(), s4.Addr())); err != nil { 57 | t.Fatalf("configure proxy: %s", err) 58 | } 59 | 60 | for _, test := range []struct { 61 | N, V string 62 | P *x509.CertPool 63 | OK bool 64 | Transparent bool 65 | }{ 66 | {"test.com", "server1", s1.Pool, true, false}, 67 | {"foo.net", "server2", s2.Pool, true, false}, 68 | {"bar.org", "", s1.Pool, false, false}, 69 | {"proxy.design", "server4", s4.Pool, true, true}, 70 | } { 71 | res, transparent, err := getTLS(l.Addr().String(), test.N, test.P) 72 | switch { 73 | case test.OK && err != nil: 74 | t.Fatalf("get %q failed: %s", test.N, err) 75 | case 
!test.OK && err == nil: 76 | t.Fatalf("get %q should have failed, but returned %q", test.N, res) 77 | case test.OK && res != test.V: 78 | t.Fatalf("got wrong value from %q, got %q, want %q", test.N, res, test.V) 79 | case test.OK && transparent != test.Transparent: 80 | t.Fatalf("connection transparency for %q was %v, want %v", test.N, transparent, test.Transparent) 81 | } 82 | } 83 | } 84 | 85 | // getTLS attempts to set up a TLS session using the given proxy 86 | // address, domain, and cert pool. It returns the value served by the 87 | // server, as well as a bool indicating whether the server knew the 88 | // true client address, indicating that the PROXY protocol was in use. 89 | func getTLS(addr string, domain string, pool *x509.CertPool) (string, bool, error) { 90 | cfg := tls.Config{ 91 | RootCAs: pool, 92 | ServerName: domain, 93 | } 94 | conn, err := tls.Dial("tcp", addr, &cfg) 95 | if err != nil { 96 | return "", false, fmt.Errorf("dial TLS %q for %q: %s", addr, domain, err) 97 | } 98 | defer conn.Close() 99 | bs, err := ioutil.ReadAll(conn) 100 | if err != nil { 101 | return "", false, fmt.Errorf("read TLS from %q (domain %q): %s", addr, domain, err) 102 | } 103 | fs := strings.Split(string(bs), " ") 104 | if len(fs) != 2 { 105 | return "", false, fmt.Errorf("read TLS from %q (domain %q): incoherent response %q", addr, domain, string(bs)) 106 | } 107 | transparent := fs[1] == conn.LocalAddr().String() 108 | return fs[0], transparent, nil 109 | } 110 | 111 | type tlsServer struct { 112 | Domains []string 113 | Value string 114 | Pool *x509.CertPool 115 | Test *testing.T 116 | NumHits uint32 117 | l net.Listener 118 | } 119 | 120 | func (s *tlsServer) Serve() { 121 | for { 122 | c, err := s.l.Accept() 123 | if err != nil { 124 | s.Test.Logf("accept failed on %q: %s", s.Domains, err) 125 | return 126 | } 127 | atomic.AddUint32(&s.NumHits, 1) 128 | fmt.Fprintf(c, "%s %s", s.Value, c.RemoteAddr()) 129 | c.Close() 130 | } 131 | } 132 | 133 | func (s 
*tlsServer) Addr() string { 134 | return s.l.Addr().String() 135 | } 136 | 137 | func (s *tlsServer) Close() error { 138 | return s.l.Close() 139 | } 140 | 141 | func serveTLS(t *testing.T, value string, understandProxy bool, domains ...string) (*tlsServer, error) { 142 | cert, pool, err := selfSignedCert(domains) 143 | if err != nil { 144 | return nil, err 145 | } 146 | 147 | cfg := &tls.Config{ 148 | Certificates: []tls.Certificate{cert}, 149 | } 150 | cfg.BuildNameToCertificate() 151 | 152 | var l net.Listener 153 | 154 | l, err = net.Listen("tcp", "localhost:0") 155 | if err != nil { 156 | return nil, err 157 | } 158 | 159 | if understandProxy { 160 | l = &proxyproto.Listener{Listener: l} 161 | } 162 | 163 | l = tls.NewListener(l, cfg) 164 | 165 | ret := &tlsServer{ 166 | Domains: domains, 167 | Value: value, 168 | Pool: pool, 169 | Test: t, 170 | l: l, 171 | } 172 | go ret.Serve() 173 | return ret, nil 174 | } 175 | 176 | func selfSignedCert(domains []string) (tls.Certificate, *x509.CertPool, error) { 177 | pkey, err := rsa.GenerateKey(rand.Reader, 2048) 178 | if err != nil { 179 | return tls.Certificate{}, nil, err 180 | } 181 | template := &x509.Certificate{ 182 | SerialNumber: big.NewInt(1), 183 | Subject: pkix.Name{ 184 | Organization: []string{"Test Co"}, 185 | CommonName: domains[0], 186 | }, 187 | NotBefore: time.Now().Add(-5 * time.Minute), 188 | NotAfter: time.Now().Add(60 * time.Minute), 189 | IsCA: true, 190 | KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, 191 | ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, 192 | BasicConstraintsValid: true, 193 | DNSNames: domains[:], 194 | } 195 | 196 | derBytes, err := x509.CreateCertificate(rand.Reader, template, template, pkey.Public(), pkey) 197 | if err != nil { 198 | return tls.Certificate{}, nil, err 199 | } 200 | 201 | var cert, key bytes.Buffer 202 | pem.Encode(&cert, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) 203 | pem.Encode(&key, 
&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(pkey)}) 204 | 205 | tlscert, err := tls.X509KeyPair(cert.Bytes(), key.Bytes()) 206 | if err != nil { 207 | return tls.Certificate{}, nil, err 208 | } 209 | 210 | pool := x509.NewCertPool() 211 | if !pool.AppendCertsFromPEM(cert.Bytes()) { 212 | return tls.Certificate{}, nil, fmt.Errorf("failed to add cert %q to pool", domains) 213 | } 214 | 215 | return tlscert, pool, nil 216 | } 217 | -------------------------------------------------------------------------------- /cmd/tlsrouter/main.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Google Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package main 16 | 17 | import ( 18 | "bytes" 19 | "flag" 20 | "fmt" 21 | "io" 22 | "log" 23 | "net" 24 | "sync" 25 | "time" 26 | ) 27 | 28 | var ( 29 | cfgFile = flag.String("conf", "", "configuration file") 30 | listen = flag.String("listen", ":443", "listening port") 31 | helloTimeout = flag.Duration("hello-timeout", 3*time.Second, "how long to wait for the TLS ClientHello") 32 | ) 33 | 34 | func main() { 35 | flag.Parse() 36 | 37 | p := &Proxy{} 38 | if err := p.Config.ReadFile(*cfgFile); err != nil { 39 | log.Fatalf("Failed to read config %q: %s", *cfgFile, err) 40 | } 41 | 42 | log.Fatalf("%s", p.ListenAndServe(*listen)) 43 | } 44 | 45 | // Proxy routes connections to backends based on a Config. 46 | type Proxy struct { 47 | Config Config 48 | l net.Listener 49 | } 50 | 51 | // Serve accepts connections from l and routes them according to TLS SNI. 52 | func (p *Proxy) Serve(l net.Listener) error { 53 | for { 54 | c, err := l.Accept() 55 | if err != nil { 56 | return fmt.Errorf("accept new conn: %s", err) 57 | } 58 | 59 | conn := &Conn{ 60 | TCPConn: c.(*net.TCPConn), 61 | config: &p.Config, 62 | } 63 | go conn.proxy() 64 | } 65 | } 66 | 67 | // ListenAndServe creates a listener on addr calls Serve on it. 68 | func (p *Proxy) ListenAndServe(addr string) error { 69 | l, err := net.Listen("tcp", addr) 70 | if err != nil { 71 | return fmt.Errorf("create listener: %s", err) 72 | } 73 | return p.Serve(l) 74 | } 75 | 76 | // A Conn handles the TLS proxying of one user connection. 77 | type Conn struct { 78 | *net.TCPConn 79 | config *Config 80 | 81 | tlsMinor int 82 | hostname string 83 | backend string 84 | backendConn *net.TCPConn 85 | } 86 | 87 | func (c *Conn) logf(msg string, args ...interface{}) { 88 | msg = fmt.Sprintf(msg, args...) 89 | log.Printf("%s <> %s: %s", c.RemoteAddr(), c.LocalAddr(), msg) 90 | } 91 | 92 | func (c *Conn) abort(alert byte, msg string, args ...interface{}) { 93 | c.logf(msg, args...) 
94 | alertMsg := []byte{21, 3, byte(c.tlsMinor), 0, 2, 2, alert} 95 | 96 | if err := c.SetWriteDeadline(time.Now().Add(*helloTimeout)); err != nil { 97 | c.logf("error while setting write deadline during abort: %s", err) 98 | // Do NOT send the alert if we can't set a write deadline, 99 | // that could result in leaking a connection for an extended 100 | // period. 101 | return 102 | } 103 | 104 | if _, err := c.Write(alertMsg); err != nil { 105 | c.logf("error while sending alert: %s", err) 106 | } 107 | } 108 | 109 | func (c *Conn) internalError(msg string, args ...interface{}) { c.abort(80, msg, args...) } 110 | func (c *Conn) sniFailed(msg string, args ...interface{}) { c.abort(112, msg, args...) } 111 | 112 | func (c *Conn) proxy() { 113 | defer c.Close() 114 | 115 | if err := c.SetReadDeadline(time.Now().Add(*helloTimeout)); err != nil { 116 | c.internalError("Setting read deadline for ClientHello: %s", err) 117 | return 118 | } 119 | 120 | var ( 121 | err error 122 | handshakeBuf bytes.Buffer 123 | ) 124 | c.hostname, c.tlsMinor, err = extractSNI(io.TeeReader(c, &handshakeBuf)) 125 | if err != nil { 126 | c.internalError("Extracting SNI: %s", err) 127 | return 128 | } 129 | 130 | c.logf("extracted SNI %s", c.hostname) 131 | 132 | if err = c.SetReadDeadline(time.Time{}); err != nil { 133 | c.internalError("Clearing read deadline for ClientHello: %s", err) 134 | return 135 | } 136 | 137 | addProxyHeader := false 138 | c.backend, addProxyHeader = c.config.Match(c.hostname) 139 | if c.backend == "" { 140 | c.sniFailed("no backend found for %q", c.hostname) 141 | return 142 | } 143 | 144 | c.logf("routing %q to %q", c.hostname, c.backend) 145 | backend, err := net.DialTimeout("tcp", c.backend, 10*time.Second) 146 | if err != nil { 147 | c.internalError("failed to dial backend %q for %q: %s", c.backend, c.hostname, err) 148 | return 149 | } 150 | defer backend.Close() 151 | 152 | c.backendConn = backend.(*net.TCPConn) 153 | 154 | // If the backend supports the 
HAProxy PROXY protocol, give it the 155 | // real source information about the connection. 156 | if addProxyHeader { 157 | remote := c.TCPConn.RemoteAddr().(*net.TCPAddr) 158 | local := c.TCPConn.LocalAddr().(*net.TCPAddr) 159 | family := "TCP6" 160 | if remote.IP.To4() != nil { 161 | family = "TCP4" 162 | } 163 | if _, err := fmt.Fprintf(c.backendConn, "PROXY %s %s %s %d %d\r\n", family, remote.IP, local.IP, remote.Port, local.Port); err != nil { 164 | c.internalError("failed to send PROXY header to %q: %s", c.backend, err) 165 | return 166 | } 167 | } 168 | 169 | // Replay the piece of the handshake we had to read to do the 170 | // routing, then blindly proxy any other bytes. 171 | if _, err = io.Copy(c.backendConn, &handshakeBuf); err != nil { 172 | c.internalError("failed to replay handshake to %q: %s", c.backend, err) 173 | return 174 | } 175 | 176 | var wg sync.WaitGroup 177 | wg.Add(2) 178 | go proxy(&wg, c.TCPConn, c.backendConn) 179 | go proxy(&wg, c.backendConn, c.TCPConn) 180 | wg.Wait() 181 | } 182 | 183 | func proxy(wg *sync.WaitGroup, a, b net.Conn) { 184 | defer wg.Done() 185 | atcp, btcp := a.(*net.TCPConn), b.(*net.TCPConn) 186 | if _, err := io.Copy(atcp, btcp); err != nil { 187 | log.Printf("%s<>%s -> %s<>%s: %s", atcp.RemoteAddr(), atcp.LocalAddr(), btcp.LocalAddr(), btcp.RemoteAddr(), err) 188 | } 189 | btcp.CloseWrite() 190 | atcp.CloseRead() 191 | } 192 | -------------------------------------------------------------------------------- /cmd/tlsrouter/sni.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Google Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package main 16 | 17 | import ( 18 | "encoding/binary" 19 | "errors" 20 | "fmt" 21 | "io" 22 | ) 23 | 24 | func extractSNI(r io.Reader) (string, int, error) { 25 | handshake, tlsver, err := handshakeRecord(r) 26 | if err != nil { 27 | return "", 0, fmt.Errorf("reading TLS record: %s", err) 28 | } 29 | 30 | sni, err := parseHello(handshake) 31 | if err != nil { 32 | return "", 0, fmt.Errorf("reading ClientHello: %s", err) 33 | } 34 | if len(sni) == 0 { 35 | // ClientHello did not present an SNI extension. Valid packet, 36 | // no hostname. 37 | return "", tlsver, nil 38 | } 39 | 40 | hostname, err := parseSNI(sni) 41 | if err != nil { 42 | return "", 0, fmt.Errorf("parsing SNI extension: %s", err) 43 | } 44 | return hostname, tlsver, nil 45 | } 46 | 47 | // Extract the indicated hostname, if any, from the given SNI 48 | // extension bytes. 49 | func parseSNI(b []byte) (string, error) { 50 | b, _, err := vector(b, 2) 51 | if err != nil { 52 | return "", err 53 | } 54 | 55 | var ret []byte 56 | for len(b) >= 3 { 57 | typ := b[0] 58 | ret, b, err = vector(b[1:], 2) 59 | if err != nil { 60 | return "", fmt.Errorf("truncated SNI extension") 61 | } 62 | 63 | if typ == sniHostnameID { 64 | return string(ret), nil 65 | } 66 | } 67 | 68 | if len(b) != 0 { 69 | return "", fmt.Errorf("trailing garbage at end of SNI extension") 70 | } 71 | 72 | // No DNS-based SNI present. 
73 | return "", nil 74 | } 75 | 76 | const sniExtensionID = 0 77 | const sniHostnameID = 0 78 | 79 | // Parse a TLS handshake record as a ClientHello message and extract 80 | // the SNI extension bytes, if any. 81 | func parseHello(b []byte) ([]byte, error) { 82 | if len(b) == 0 { 83 | return nil, errors.New("zero length handshake record") 84 | } 85 | if b[0] != 1 { 86 | return nil, fmt.Errorf("non-ClientHello handshake record type %d", b[0]) 87 | } 88 | 89 | // We're expecting a stricter TLS parser to run after we've 90 | // proxied, so we ignore any trailing bytes that might be present 91 | // (e.g. another handshake message). 92 | b, _, err := vector(b[1:], 3) 93 | if err != nil { 94 | return nil, fmt.Errorf("reading ClientHello: %s", err) 95 | } 96 | 97 | // ClientHello must be at least 34 bytes to reach the first vector 98 | // length byte. The actual minimal size is larger than that, but 99 | // vector() will correctly handle truncated packets. 100 | if len(b) < 34 { 101 | return nil, errors.New("ClientHello packet too short") 102 | } 103 | 104 | if b[0] != 3 { 105 | return nil, fmt.Errorf("ClientHello has unsupported version %d.%d", b[0], b[1]) 106 | } 107 | switch b[1] { 108 | case 1, 2, 3: 109 | // TLS 1.0, TLS 1.1, TLS 1.2 110 | default: 111 | return nil, fmt.Errorf("TLS record has unsupported version %d.%d", b[0], b[1]) 112 | } 113 | 114 | // Skip over version and random struct 115 | b = b[34:] 116 | 117 | // We don't technically care about SessionID, but we care that the 118 | // framing is well-formed all the way up to the SNI field, so that 119 | // we are sure that we're pulling the same SNI bytes as the 120 | // eventual TLS implementation. 
121 | vec, b, err := vector(b, 1) 122 | if err != nil { 123 | return nil, fmt.Errorf("reading ClientHello SessionID: %s", err) 124 | } 125 | if len(vec) > 32 { 126 | return nil, fmt.Errorf("ClientHello SessionID too long (%db)", len(vec)) 127 | } 128 | 129 | // Likewise, we're just checking the bare minimum of framing. 130 | vec, b, err = vector(b, 2) 131 | if err != nil { 132 | return nil, fmt.Errorf("reading ClientHello CipherSuites: %s", err) 133 | } 134 | if len(vec) < 2 || len(vec)%2 != 0 { 135 | return nil, fmt.Errorf("ClientHello CipherSuites invalid length %d", len(vec)) 136 | } 137 | 138 | vec, b, err = vector(b, 1) 139 | if err != nil { 140 | return nil, fmt.Errorf("reading ClientHello CompressionMethods: %s", err) 141 | } 142 | if len(vec) < 1 { 143 | return nil, fmt.Errorf("ClientHello CompressionMethods invalid length %d", len(vec)) 144 | } 145 | 146 | // Finally, we reach the extensions. 147 | if len(b) == 0 { 148 | // No extensions. This is not an error, it just means we have 149 | // no SNI payload. 150 | return nil, nil 151 | } 152 | b, vec, err = vector(b, 2) 153 | if err != nil { 154 | return nil, fmt.Errorf("reading ClientHello extensions: %s", err) 155 | } 156 | if len(vec) != 0 { 157 | return nil, fmt.Errorf("%d bytes of trailing garbage in ClientHello", len(vec)) 158 | } 159 | 160 | for len(b) >= 4 { 161 | typ := binary.BigEndian.Uint16(b[:2]) 162 | vec, b, err = vector(b[2:], 2) 163 | if err != nil { 164 | return nil, fmt.Errorf("reading ClientHello extension %d: %s", typ, err) 165 | } 166 | if typ == sniExtensionID { 167 | // Found the SNI extension, return its payload. We don't 168 | // care about anything in the packet beyond this point. 169 | return vec, nil 170 | } 171 | } 172 | 173 | if len(b) != 0 { 174 | return nil, fmt.Errorf("%d bytes of trailing garbage in ClientHello", len(b)) 175 | } 176 | 177 | // Successfully parsed all extensions, but there was no SNI. 
178 | return nil, nil 179 | } 180 | 181 | const maxTLSRecordLength = 16384 182 | 183 | // Read one TLS record, which must be for the handshake protocol, from r. 184 | func handshakeRecord(r io.Reader) ([]byte, int, error) { 185 | var hdr struct { 186 | Type uint8 187 | Major, Minor uint8 188 | Length uint16 189 | } 190 | if err := binary.Read(r, binary.BigEndian, &hdr); err != nil { 191 | return nil, 0, fmt.Errorf("reading TLS record header: %s", err) 192 | } 193 | 194 | if hdr.Type != 22 { 195 | return nil, 0, fmt.Errorf("TLS record is not a handshake") 196 | } 197 | 198 | if hdr.Major != 3 { 199 | return nil, 0, fmt.Errorf("TLS record has unsupported version %d.%d", hdr.Major, hdr.Minor) 200 | } 201 | switch hdr.Minor { 202 | case 1, 2, 3: 203 | // TLS 1.0, TLS 1.1, TLS 1.2 204 | default: 205 | return nil, 0, fmt.Errorf("TLS record has unsupported version %d.%d", hdr.Major, hdr.Minor) 206 | } 207 | 208 | if hdr.Length > maxTLSRecordLength { 209 | return nil, 0, fmt.Errorf("TLS record length is greater than %d", maxTLSRecordLength) 210 | } 211 | 212 | ret := make([]byte, hdr.Length) 213 | if _, err := io.ReadFull(r, ret); err != nil { 214 | return nil, 0, err 215 | } 216 | 217 | return ret, int(hdr.Minor), nil 218 | } 219 | 220 | func vector(b []byte, lenBytes int) ([]byte, []byte, error) { 221 | if len(b) < lenBytes { 222 | return nil, nil, errors.New("not enough space in packet for vector") 223 | } 224 | var l int 225 | for _, b := range b[:lenBytes] { 226 | l = (l << 8) + int(b) 227 | } 228 | if len(b) < l+lenBytes { 229 | return nil, nil, errors.New("not enough space in packet for vector") 230 | } 231 | return b[lenBytes : l+lenBytes], b[l+lenBytes:], nil 232 | } 233 | -------------------------------------------------------------------------------- /cmd/tlsrouter/sni_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Google Inc. 
2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package main 16 | 17 | import ( 18 | "bytes" 19 | "testing" 20 | ) 21 | 22 | func slice(l int) []byte { 23 | ret := make([]byte, l) 24 | for i := 0; i < l; i++ { 25 | ret[i] = byte(i) 26 | } 27 | return ret 28 | } 29 | 30 | func vec(l, lenBytes int) []byte { 31 | b := slice(l) 32 | vecLen := len(b) 33 | ret := make([]byte, vecLen+l) 34 | for i := l - 1; i >= 0; i-- { 35 | ret[i] = byte(vecLen & 0xff) 36 | vecLen >>= 8 37 | } 38 | copy(ret[l:], b) 39 | return ret 40 | } 41 | 42 | func packet(bs ...[]byte) []byte { 43 | var ret []byte 44 | for _, b := range bs { 45 | ret = append(ret, b...) 
46 | } 47 | return ret 48 | } 49 | 50 | func offset(b []byte, off int) []byte { 51 | return b[off:] 52 | } 53 | 54 | func TestVector(t *testing.T) { 55 | tests := []struct { 56 | in []byte 57 | inLen int 58 | out1, out2 []byte 59 | err bool 60 | }{ 61 | { 62 | // 1b length 63 | append([]byte{3}, slice(10)...), 1, 64 | slice(3), offset(slice(10), 3), false, 65 | }, 66 | { 67 | // 1b length, no trailer 68 | append([]byte{10}, slice(10)...), 1, 69 | slice(10), []byte{}, false, 70 | }, 71 | { 72 | // 1b length, no vector 73 | append([]byte{0}, slice(10)...), 1, 74 | []byte{}, slice(10), false, 75 | }, 76 | { 77 | // 1b length, no vector or trailer 78 | []byte{0}, 1, 79 | []byte{}, []byte{}, false, 80 | }, 81 | { 82 | // 2b length, LSB only 83 | append([]byte{0, 3}, slice(10)...), 2, 84 | slice(3), offset(slice(10), 3), false, 85 | }, 86 | { 87 | // 2b length, MSB only 88 | append([]byte{3, 0}, slice(1024)...), 2, 89 | slice(768), offset(slice(1024), 768), false, 90 | }, 91 | { 92 | // 2b length, both bytes 93 | append([]byte{3, 2}, slice(1024)...), 2, 94 | slice(770), offset(slice(1024), 770), false, 95 | }, 96 | { 97 | // 3b length 98 | append([]byte{1, 2, 3}, slice(100000)...), 3, 99 | slice(66051), offset(slice(100000), 66051), false, 100 | }, 101 | { 102 | // no bytes 103 | []byte{}, 1, 104 | nil, nil, true, 105 | }, 106 | { 107 | // no slice 108 | nil, 1, 109 | nil, nil, true, 110 | }, 111 | { 112 | // not enough bytes for length 113 | []byte{1}, 2, 114 | nil, nil, true, 115 | }, 116 | { 117 | // no bytes after length 118 | []byte{1}, 1, 119 | nil, nil, true, 120 | }, 121 | { 122 | // not enough bytes for vector 123 | []byte{4, 1, 2}, 1, 124 | nil, nil, true, 125 | }, 126 | } 127 | 128 | for _, test := range tests { 129 | actual1, actual2, err := vector(test.in, test.inLen) 130 | if !test.err && (err != nil) { 131 | t.Errorf("unexpected error %q", err) 132 | } 133 | if test.err && (err == nil) { 134 | t.Errorf("unexpected success") 135 | } 136 | if err != nil { 
137 | continue 138 | } 139 | if !bytes.Equal(actual1, test.out1) { 140 | t.Errorf("wrong bytes for vector slice. Got %#v, want %#v", actual1, test.out1) 141 | } 142 | if !bytes.Equal(actual2, test.out2) { 143 | t.Errorf("wrong bytes for vector slice. Got %#v, want %#v", actual2, test.out2) 144 | } 145 | } 146 | } 147 | 148 | func TestHandshakeRecord(t *testing.T) { 149 | tests := []struct { 150 | in []byte 151 | out []byte 152 | tlsver int 153 | }{ 154 | { 155 | // TLS 1.0, 1b packet 156 | []byte{22, 3, 1, 0, 1, 3}, 157 | []byte{3}, 158 | 1, 159 | }, 160 | { 161 | // TLS 1.1, 1b packet 162 | []byte{22, 3, 2, 0, 1, 3}, 163 | []byte{3}, 164 | 2, 165 | }, 166 | { 167 | // TLS 1.2, 1b packet 168 | []byte{22, 3, 3, 0, 1, 3}, 169 | []byte{3}, 170 | 3, 171 | }, 172 | { 173 | // TLS 1.2, no payload bytes 174 | []byte{22, 3, 3, 0, 0}, 175 | []byte{}, 176 | 3, 177 | }, 178 | { 179 | // TLS 1.2, >255b payload w/ trailing stuff 180 | append([]byte{22, 3, 3, 3, 2}, slice(1024)...), 181 | slice(770), 182 | 3, 183 | }, 184 | { 185 | // TLS 1.2, 2^14 payload 186 | append([]byte{22, 3, 3, 64, 0}, slice(maxTLSRecordLength)...), 187 | slice(maxTLSRecordLength), 188 | 3, 189 | }, 190 | { 191 | // TLS 1.2, >2^14 payload 192 | append([]byte{22, 3, 3, 64, 1}, slice(maxTLSRecordLength+1)...), 193 | nil, 194 | 0, 195 | }, 196 | { 197 | // TLS 1.2, truncated payload 198 | []byte{22, 3, 3, 0, 4, 1, 2}, 199 | nil, 200 | 0, 201 | }, 202 | { 203 | // truncated header 204 | []byte{22}, 205 | nil, 206 | 0, 207 | }, 208 | { 209 | // wrong record type 210 | []byte{42, 3, 3, 0, 1, 3}, 211 | nil, 212 | 0, 213 | }, 214 | { 215 | // wrong TLS major version 216 | []byte{22, 2, 3, 0, 1, 3}, 217 | nil, 218 | 0, 219 | }, 220 | { 221 | // wrong TLS minor version 222 | []byte{22, 3, 42, 0, 1, 3}, 223 | nil, 224 | 0, 225 | }, 226 | { 227 | // Obsolete SSL 3.0 228 | []byte{22, 3, 0, 0, 1, 3}, 229 | nil, 230 | 0, 231 | }, 232 | } 233 | 234 | for _, test := range tests { 235 | r := bytes.NewBuffer(test.in) 236 | 
actual, tlsver, err := handshakeRecord(r) 237 | if test.out == nil && err == nil { 238 | t.Errorf("unexpected success") 239 | continue 240 | } 241 | if !bytes.Equal(test.out, actual) { 242 | t.Errorf("wrong bytes for TLS record. Got %#v, want %#v", actual, test.out) 243 | } 244 | if tlsver != test.tlsver { 245 | t.Errorf("wrong TLS version returned. Got %d, want %d", tlsver, test.tlsver) 246 | } 247 | } 248 | } 249 | 250 | func TestParseHello(t *testing.T) { 251 | tests := []struct { 252 | in []byte 253 | out []byte 254 | err bool 255 | }{ 256 | { 257 | // Wrong record type 258 | packet([]byte{42, 0, 0, 1, 1}), 259 | nil, 260 | true, 261 | }, 262 | { 263 | // Truncated payload 264 | packet([]byte{1, 0, 0, 1}), 265 | nil, 266 | true, 267 | }, 268 | { 269 | // Payload too small 270 | packet([]byte{1, 0, 0, 1, 1}), 271 | nil, 272 | true, 273 | }, 274 | { 275 | // Unknown major version 276 | packet([]byte{1, 0, 0, 34, 1, 0}, slice(32)), 277 | nil, 278 | true, 279 | }, 280 | { 281 | // Unknown minor version 282 | packet([]byte{1, 0, 0, 34, 3, 42}, slice(32)), 283 | nil, 284 | true, 285 | }, 286 | { 287 | // Missing required variadic fields 288 | packet([]byte{1, 0, 0, 34, 3, 1}, slice(32)), 289 | nil, 290 | true, 291 | }, 292 | { 293 | // All zero variadic fields (no ciphersuites, no compression) 294 | packet([]byte{1, 0, 0, 38, 3, 1}, slice(32), []byte{0, 0, 0, 0}), 295 | nil, 296 | true, 297 | }, 298 | { 299 | // All zero variadic fields (no ciphersuites, no compression, nonzero session ID) 300 | packet([]byte{1, 0, 0, 70, 3, 1}, slice(32), []byte{32}, slice(32), []byte{0, 0, 0}), 301 | nil, 302 | true, 303 | }, 304 | { 305 | // Session + ciphersuites, no compression 306 | packet([]byte{1, 0, 0, 72, 3, 1}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 0}), 307 | nil, 308 | true, 309 | }, 310 | { 311 | // First valid packet. TLS 1.0, no extensions present. 
312 | packet([]byte{1, 0, 0, 73, 3, 1}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}), 313 | nil, 314 | false, 315 | }, 316 | { 317 | // TLS 1.1, no extensions present. 318 | packet([]byte{1, 0, 0, 73, 3, 2}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}), 319 | nil, 320 | false, 321 | }, 322 | { 323 | // TLS 1.2, no extensions present. 324 | packet([]byte{1, 0, 0, 73, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}), 325 | nil, 326 | false, 327 | }, 328 | { 329 | // TLS 1.2, garbage extensions 330 | packet([]byte{1, 0, 0, 115, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, slice(42)), 331 | nil, 332 | true, 333 | }, 334 | { 335 | // empty extensions vector 336 | packet([]byte{1, 0, 0, 75, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 0}), 337 | nil, 338 | false, 339 | }, 340 | { 341 | // non-SNI extensions 342 | packet([]byte{1, 0, 0, 85, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 10, 42, 42, 0, 0, 100, 100, 0, 2, 1, 2}), 343 | nil, 344 | false, 345 | }, 346 | { 347 | // SNI present 348 | packet([]byte{1, 0, 0, 90, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 15, 42, 42, 0, 0, 100, 100, 0, 2, 1, 2, 0, 0, 0, 1, 182}), 349 | []byte{182}, 350 | false, 351 | }, 352 | { 353 | // Longer SNI 354 | packet([]byte{1, 0, 0, 93, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 18, 42, 42, 0, 0, 100, 100, 0, 2, 1, 2, 0, 0, 0, 4}, slice(4)), 355 | slice(4), 356 | false, 357 | }, 358 | { 359 | // Embedded SNI 360 | packet([]byte{1, 0, 0, 93, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 18, 42, 42, 0, 0, 0, 0, 0, 4}, slice(4), []byte{100, 100, 0, 2, 1, 2}), 361 | slice(4), 362 | false, 363 | }, 364 | } 365 | 366 | for _, test := range tests { 367 | actual, err := parseHello(test.in) 368 | if test.err { 369 | if err == nil { 370 | t.Errorf("unexpected success") 
371 | } 372 | continue 373 | } 374 | if err != nil { 375 | t.Errorf("unexpected error %q", err) 376 | continue 377 | } 378 | if !bytes.Equal(test.out, actual) { 379 | t.Errorf("wrong bytes for SNI data. Got %#v, want %#v", actual, test.out) 380 | } 381 | } 382 | } 383 | 384 | func TestParseSNI(t *testing.T) { 385 | tests := []struct { 386 | in []byte 387 | out string 388 | err bool 389 | }{ 390 | { 391 | // Empty packet 392 | []byte{}, 393 | "", 394 | true, 395 | }, 396 | { 397 | // Truncated packet 398 | []byte{0, 2, 1}, 399 | "", 400 | true, 401 | }, 402 | { 403 | // Truncated packet within SNI vector 404 | []byte{0, 2, 1, 2}, 405 | "", 406 | true, 407 | }, 408 | { 409 | // Wrong SNI kind 410 | []byte{0, 3, 1, 0, 0}, 411 | "", 412 | false, 413 | }, 414 | { 415 | // Right SNI kind, no hostname 416 | []byte{0, 3, 0, 0, 0}, 417 | "", 418 | false, 419 | }, 420 | { 421 | // SNI hostname 422 | packet([]byte{0, 6, 0, 0, 3}, []byte("lol")), 423 | "lol", 424 | false, 425 | }, 426 | { 427 | // Multiple SNI kinds 428 | packet([]byte{0, 13, 1, 0, 0, 0, 0, 3}, []byte("lol"), []byte{42, 0, 1, 2}), 429 | "lol", 430 | false, 431 | }, 432 | { 433 | // Multiple SNI hostnames (illegal, but we just return the first) 434 | packet([]byte{0, 13, 1, 0, 0, 0, 0, 3}, []byte("bar"), []byte{0, 0, 3}, []byte("lol")), 435 | "bar", 436 | false, 437 | }, 438 | } 439 | 440 | for _, test := range tests { 441 | actual, err := parseSNI(test.in) 442 | if test.err { 443 | if err == nil { 444 | t.Errorf("unexpected success") 445 | } 446 | continue 447 | } 448 | if err != nil { 449 | t.Errorf("unexpected error %q", err) 450 | continue 451 | } 452 | if test.out != actual { 453 | t.Errorf("wrong SNI hostname. 
Got %q, want %q", actual, test.out) 454 | } 455 | } 456 | } 457 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/inetaf/tcpproxy 2 | 3 | go 1.16 4 | 5 | require github.com/armon/go-proxyproto v0.0.0-20210323213023-7e956b284f0a 6 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/armon/go-proxyproto v0.0.0-20210323213023-7e956b284f0a h1:AP/vsCIvJZ129pdm9Ek7bH7yutN3hByqsMoNrWAxRQc= 2 | github.com/armon/go-proxyproto v0.0.0-20210323213023-7e956b284f0a/go.mod h1:QmP9hvJ91BbJmGVGSbutW19IC0Q9phDCLGaomwTJbgU= 3 | -------------------------------------------------------------------------------- /http.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 Google Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package tcpproxy 16 | 17 | import ( 18 | "bufio" 19 | "bytes" 20 | "context" 21 | "net/http" 22 | ) 23 | 24 | // AddHTTPHostRoute appends a route to the ipPort listener that 25 | // routes to dest if the incoming HTTP/1.x Host header name is 26 | // httpHost. If it doesn't match, rule processing continues for any 27 | // additional routes on ipPort. 
28 | // 29 | // The ipPort is any valid net.Listen TCP address. 30 | func (p *Proxy) AddHTTPHostRoute(ipPort, httpHost string, dest Target) { 31 | p.AddHTTPHostMatchRoute(ipPort, equals(httpHost), dest) 32 | } 33 | 34 | // AddHTTPHostMatchRoute appends a route to the ipPort listener that 35 | // routes to dest if the incoming HTTP/1.x Host header name is 36 | // accepted by matcher. If it doesn't match, rule processing continues 37 | // for any additional routes on ipPort. 38 | // 39 | // The ipPort is any valid net.Listen TCP address. 40 | func (p *Proxy) AddHTTPHostMatchRoute(ipPort string, match Matcher, dest Target) { 41 | p.addRoute(ipPort, httpHostMatch{match, dest}) 42 | } 43 | 44 | type httpHostMatch struct { 45 | matcher Matcher 46 | target Target 47 | } 48 | 49 | func (m httpHostMatch) match(br *bufio.Reader) (Target, string) { 50 | hh := httpHostHeader(br) 51 | if m.matcher(context.TODO(), hh) { 52 | return m.target, hh 53 | } 54 | return nil, "" 55 | } 56 | 57 | // httpHostHeader returns the HTTP Host header from br without 58 | // consuming any of its bytes. It returns "" if it can't find one. 59 | func httpHostHeader(br *bufio.Reader) string { 60 | const maxPeek = 4 << 10 61 | peekSize := 0 62 | for { 63 | peekSize++ 64 | if peekSize > maxPeek { 65 | b, _ := br.Peek(br.Buffered()) 66 | return httpHostHeaderFromBytes(b) 67 | } 68 | b, err := br.Peek(peekSize) 69 | if n := br.Buffered(); n > peekSize { 70 | b, _ = br.Peek(n) 71 | peekSize = n 72 | } 73 | if len(b) > 0 { 74 | if b[0] < 'A' || b[0] > 'Z' { 75 | // Doesn't look like an HTTP verb 76 | // (GET, POST, etc). 77 | return "" 78 | } 79 | if bytes.Index(b, crlfcrlf) != -1 || bytes.Index(b, lflf) != -1 { 80 | req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(b))) 81 | if err != nil { 82 | return "" 83 | } 84 | if len(req.Header["Host"]) > 1 { 85 | // TODO(bradfitz): what does 86 | // ReadRequest do if there are 87 | // multiple Host headers? 
88 | return "" 89 | } 90 | return req.Host 91 | } 92 | } 93 | if err != nil { 94 | return httpHostHeaderFromBytes(b) 95 | } 96 | } 97 | } 98 | 99 | var ( 100 | lfHostColon = []byte("\nHost:") 101 | lfhostColon = []byte("\nhost:") 102 | crlf = []byte("\r\n") 103 | lf = []byte("\n") 104 | crlfcrlf = []byte("\r\n\r\n") 105 | lflf = []byte("\n\n") 106 | ) 107 | 108 | func httpHostHeaderFromBytes(b []byte) string { 109 | if i := bytes.Index(b, lfHostColon); i != -1 { 110 | return string(bytes.TrimSpace(untilEOL(b[i+len(lfHostColon):]))) 111 | } 112 | if i := bytes.Index(b, lfhostColon); i != -1 { 113 | return string(bytes.TrimSpace(untilEOL(b[i+len(lfhostColon):]))) 114 | } 115 | return "" 116 | } 117 | 118 | // untilEOL returns v, truncated before the first '\n' byte, if any. 119 | // The returned slice may include a '\r' at the end. 120 | func untilEOL(v []byte) []byte { 121 | if i := bytes.IndexByte(v, '\n'); i != -1 { 122 | return v[:i] 123 | } 124 | return v 125 | } 126 | -------------------------------------------------------------------------------- /listener.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 Google Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package tcpproxy 16 | 17 | import ( 18 | "io" 19 | "net" 20 | "sync" 21 | ) 22 | 23 | // TargetListener implements both net.Listener and Target. 
24 | // Matched Targets become accepted connections. 25 | type TargetListener struct { 26 | Address string // Address is the string reported by TargetListener.Addr().String(). 27 | 28 | mu sync.Mutex 29 | cond *sync.Cond 30 | closed bool 31 | nextConn net.Conn 32 | } 33 | 34 | var ( 35 | _ net.Listener = (*TargetListener)(nil) 36 | _ Target = (*TargetListener)(nil) 37 | ) 38 | 39 | func (tl *TargetListener) lock() { 40 | tl.mu.Lock() 41 | if tl.cond == nil { 42 | tl.cond = sync.NewCond(&tl.mu) 43 | } 44 | } 45 | 46 | type tcpAddr string 47 | 48 | func (a tcpAddr) Network() string { return "tcp" } 49 | func (a tcpAddr) String() string { return string(a) } 50 | 51 | // Addr returns the listener's Address field as a net.Addr. 52 | func (tl *TargetListener) Addr() net.Addr { return tcpAddr(tl.Address) } 53 | 54 | // Close stops listening for new connections. All new connections 55 | // routed to this listener will be closed. Already accepted 56 | // connections are not closed. 57 | func (tl *TargetListener) Close() error { 58 | tl.lock() 59 | if tl.closed { 60 | tl.mu.Unlock() 61 | return nil 62 | } 63 | tl.closed = true 64 | tl.mu.Unlock() 65 | tl.cond.Broadcast() 66 | return nil 67 | } 68 | 69 | // HandleConn implements the Target interface. It blocks until tl is 70 | // closed or another goroutine has called Accept and received c. 71 | func (tl *TargetListener) HandleConn(c net.Conn) { 72 | tl.lock() 73 | defer tl.mu.Unlock() 74 | for tl.nextConn != nil && !tl.closed { 75 | tl.cond.Wait() 76 | } 77 | if tl.closed { 78 | c.Close() 79 | return 80 | } 81 | tl.nextConn = c 82 | tl.cond.Broadcast() // Signal might be sufficient; verify. 83 | for tl.nextConn == c && !tl.closed { 84 | tl.cond.Wait() 85 | } 86 | if tl.closed { 87 | c.Close() 88 | return 89 | } 90 | } 91 | 92 | // Accept implements the Accept method in the net.Listener interface. 
93 | func (tl *TargetListener) Accept() (net.Conn, error) { 94 | tl.lock() 95 | for tl.nextConn == nil && !tl.closed { 96 | tl.cond.Wait() 97 | } 98 | if tl.closed { 99 | tl.mu.Unlock() 100 | return nil, io.EOF 101 | } 102 | c := tl.nextConn 103 | tl.nextConn = nil 104 | tl.mu.Unlock() 105 | tl.cond.Broadcast() // Signal might be sufficient; verify. 106 | 107 | return c, nil 108 | } 109 | -------------------------------------------------------------------------------- /listener_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 Google Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 
15 | package tcpproxy
16 | 
17 | import (
18 | 	"io"
19 | 	"testing"
20 | )
21 | 
22 | func TestListenerAccept(t *testing.T) {
23 | 	tl := new(TargetListener)
24 | 	ch := make(chan interface{}, 1)
25 | 	go func() {
26 | 		for {
27 | 			conn, err := tl.Accept()
28 | 			if err != nil {
29 | 				ch <- err
30 | 				return
31 | 			}
32 | 			ch <- conn
33 | 		}
34 | 	}()
35 | 
36 | 	for i := 0; i < 3; i++ {
37 | 		conn := new(Conn)
38 | 		tl.HandleConn(conn)
39 | 		got := <-ch
40 | 		if got != conn {
41 | 			t.Errorf("Accept conn = %v; want %v", got, conn)
42 | 		}
43 | 	}
44 | 	tl.Close()
45 | 	got := <-ch
46 | 	if got != io.EOF {
47 | 		t.Errorf("Accept error post-Close = %v; want io.EOF", got)
48 | 	}
49 | }
50 | 
--------------------------------------------------------------------------------
/scripts/prune_old_versions.go:
--------------------------------------------------------------------------------
1 | // Copyright 2017 Google Inc.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | //     http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | 
15 | // Command prune_old_versions deletes all but the newest -limit versions of
16 | // a package from a packagecloud.io repository. The API token is read from
17 | // the PACKAGECLOUD_API_KEY environment variable.
18 | package main
19 | 
20 | import (
21 | 	"encoding/json"
22 | 	"flag"
23 | 	"fmt"
24 | 	"io/ioutil"
25 | 	"net/http"
26 | 	"os"
27 | 	"sort"
28 | 	"strings"
29 | 	"time"
30 | )
31 | 
32 | var (
33 | 	user          = flag.String("user", "", "username")
34 | 	repo          = flag.String("repo", "", "repository name")
35 | 	pkgType       = flag.String("pkg-type", "deb", "Package type, e.g. 'deb'")
36 | 	distro        = flag.String("distro", "", "distro name, e.g. 'debian'")
37 | 	distroVersion = flag.String("version", "", "distro version, e.g. 'stretch'")
38 | 	pkg           = flag.String("package", "", "package name")
39 | 	arch          = flag.String("arch", "", "package architecture")
40 | 	limit         = flag.Int("limit", 2, "package versions to keep")
41 | )
42 | 
43 | // apiKeyEnv names the environment variable holding the packagecloud API token.
44 | const apiKeyEnv = "PACKAGECLOUD_API_KEY"
45 | 
46 | // apiBase is the scheme and host every packagecloud API path is appended to.
47 | const apiBase = "https://packagecloud.io"
48 | 
49 | // httpClient is used for every API call. Unlike http.DefaultClient it has a
50 | // timeout, so a stalled request cannot hang the deploy job indefinitely.
51 | var httpClient = &http.Client{Timeout: time.Minute}
52 | 
53 | // fatalf prints a formatted message to stderr and exits with status 1.
54 | func fatalf(msg string, args ...interface{}) {
55 | 	fmt.Fprintf(os.Stderr, msg+"\n", args...)
56 | 	os.Exit(1)
57 | }
58 | 
59 | func main() {
60 | 	flag.Parse()
61 | 	if *user == "" {
62 | 		fatalf("missing -user")
63 | 	}
64 | 	if *repo == "" {
65 | 		fatalf("missing -repo")
66 | 	}
67 | 	if *pkgType == "" {
68 | 		fatalf("missing -pkg-type")
69 | 	}
70 | 	if *distro == "" {
71 | 		fatalf("missing -distro")
72 | 	}
73 | 	if *distroVersion == "" {
74 | 		fatalf("missing -version")
75 | 	}
76 | 	if *pkg == "" {
77 | 		fatalf("missing -package")
78 | 	}
79 | 	if *arch == "" {
80 | 		fatalf("missing -arch")
81 | 	}
82 | 	if *limit < 1 {
83 | 		fatalf("limit must be >= 1")
84 | 	}
85 | 	if os.Getenv(apiKeyEnv) == "" {
86 | 		fatalf("missing %s environment variable", apiKeyEnv)
87 | 	}
88 | 
89 | 	files, err := packageVersions(*user, *repo, *pkgType, *distro, *distroVersion, *pkg, *arch)
90 | 	if err != nil {
91 | 		fatalf("%v", err)
92 | 	}
93 | 	if len(files) <= *limit {
94 | 		fmt.Println("Below limit, no packages deleted")
95 | 		return
96 | 	}
97 | 	// files is sorted oldest first: delete the head, keep the newest *limit.
98 | 	cut := len(files) - *limit
99 | 	toDelete, keep := files[:cut], files[cut:]
100 | 	if err := deletePackages(toDelete); err != nil {
101 | 		fatalf("%v", err)
102 | 	}
103 | 
104 | 	fmt.Printf("Deleted:\n\n%s\n\nKept:\n\n%s\n", strings.Join(toDelete, "\n"), strings.Join(keep, "\n"))
105 | }
106 | 
107 | // packageMeta is the subset of a versions.json entry that this tool reads.
108 | type packageMeta struct {
109 | 	Created  time.Time `json:"created_at"`
110 | 	Filename string    `json:"filename"`
111 | }
112 | 
113 | // metaSort sorts packageMeta by creation time, oldest first.
114 | type metaSort []packageMeta
115 | 
116 | func (m metaSort) Len() int           { return len(m) }
117 | func (m metaSort) Less(i, j int) bool { return m[i].Created.Before(m[j].Created) }
118 | func (m metaSort) Swap(i, j int)      { m[i], m[j] = m[j], m[i] }
119 | 
120 | // newAPIRequest builds a packagecloud API request for path, which must
121 | // begin with "/". The token from $PACKAGECLOUD_API_KEY is sent as the
122 | // basic-auth username with an empty password (as a "https://KEY:@host" URL
123 | // would), but kept out of the URL itself so it cannot be echoed in
124 | // net/http error messages, which embed the request URL.
125 | func newAPIRequest(method, path string) (*http.Request, error) {
126 | 	req, err := http.NewRequest(method, apiBase+path, nil)
127 | 	if err != nil {
128 | 		return nil, err
129 | 	}
130 | 	req.SetBasicAuth(os.Getenv(apiKeyEnv), "")
131 | 	return req, nil
132 | }
133 | 
134 | // packageVersions returns the API path of every published file of the given
135 | // package, ordered oldest first by creation time.
136 | func packageVersions(user, repo, typ, distro, version, pkgname, arch string) ([]string, error) {
137 | 	path := fmt.Sprintf("/api/v1/repos/%s/%s/package/%s/%s/%s/%s/%s/versions.json", user, repo, typ, distro, version, pkgname, arch)
138 | 	req, err := newAPIRequest("GET", path)
139 | 	if err != nil {
140 | 		return nil, fmt.Errorf("build versions.json request: %w", err)
141 | 	}
142 | 	resp, err := httpClient.Do(req)
143 | 	if err != nil {
144 | 		return nil, fmt.Errorf("get versions.json: %w", err)
145 | 	}
146 | 	defer resp.Body.Close()
147 | 	if resp.StatusCode != http.StatusOK {
148 | 		msg, err := ioutil.ReadAll(resp.Body)
149 | 		if err != nil {
150 | 			return nil, fmt.Errorf("get error message of versions.json get: %w", err)
151 | 		}
152 | 		return nil, fmt.Errorf("get versions.json: %s (%q)", resp.Status, string(msg))
153 | 	}
154 | 
155 | 	var files []packageMeta
156 | 	if err := json.NewDecoder(resp.Body).Decode(&files); err != nil {
157 | 		return nil, fmt.Errorf("decode versions.json: %w", err)
158 | 	}
159 | 
160 | 	// Oldest first, so callers can delete from the front and keep the tail.
161 | 	sort.Sort(metaSort(files))
162 | 
163 | 	ret := make([]string, 0, len(files))
164 | 	for _, meta := range files {
165 | 		ret = append(ret, fmt.Sprintf("/api/v1/repos/%s/%s/%s/%s/%s", user, repo, distro, version, meta.Filename))
166 | 	}
167 | 	return ret, nil
168 | }
169 | 
170 | // deletePackages deletes each package, given as an API path, stopping at
171 | // the first failure.
172 | func deletePackages(urls []string) error {
173 | 	for _, url := range urls {
174 | 		if err := deletePackage(url); err != nil {
175 | 			return err
176 | 		}
177 | 	}
178 | 	return nil
179 | }
180 | 
181 | // deletePackage issues a DELETE for one package API path. It is a separate
182 | // function so the deferred Body.Close runs after each request instead of
183 | // piling up until the whole loop in deletePackages finishes.
184 | func deletePackage(url string) error {
185 | 	req, err := newAPIRequest("DELETE", url)
186 | 	if err != nil {
187 | 		return fmt.Errorf("build delete request for %s: %w", url, err)
188 | 	}
189 | 	resp, err := httpClient.Do(req)
190 | 	if err != nil {
191 | 		return fmt.Errorf("delete %s: %w", url, err)
192 | 	}
193 | 	defer resp.Body.Close()
194 | 	if resp.StatusCode != http.StatusOK {
195 | 		return fmt.Errorf("delete %s: %s", url, resp.Status)
196 | 	}
197 | 	return nil
198 | }
199 | 
--------------------------------------------------------------------------------
/sni.go:
--------------------------------------------------------------------------------
1 | // Copyright 2017 Google Inc.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | //     http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | 
15 | package tcpproxy
16 | 
17 | import (
18 | 	"bufio"
19 | 	"bytes"
20 | 	"context"
21 | 	"crypto/tls"
22 | 	"io"
23 | 	"net"
24 | )
25 | 
26 | // AddSNIRoute appends a route to the ipPort listener that routes to
27 | // dest if the incoming TLS SNI server name is sni. If it doesn't
28 | // match, rule processing continues for any additional routes on
29 | // ipPort.
30 | //
31 | // The ipPort is any valid net.Listen TCP address.
32 | func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) {
33 | 	p.AddSNIMatchRoute(ipPort, equals(sni), dest)
34 | }
35 | 
36 | // AddSNIMatchRoute appends a route to the ipPort listener that routes
37 | // to dest if the incoming TLS SNI server name is accepted by
38 | // matcher. If it doesn't match, rule processing continues for any
39 | // additional routes on ipPort.
40 | //
41 | // The ipPort is any valid net.Listen TCP address.
42 | func (p *Proxy) AddSNIMatchRoute(ipPort string, matcher Matcher, dest Target) {
43 | 	p.addRoute(ipPort, sniMatch{matcher: matcher, target: dest})
44 | }
45 | 
46 | // SNITargetFunc is the func callback used by Proxy.AddSNIRouteFunc.
47 | type SNITargetFunc func(ctx context.Context, sniName string) (t Target, ok bool)
48 | 
49 | // AddSNIRouteFunc adds a route to ipPort that matches an SNI request and calls
50 | // fn to map its server name to a target. If fn reports ok == false, rule
51 | // processing continues for any additional routes on ipPort.
52 | func (p *Proxy) AddSNIRouteFunc(ipPort string, fn SNITargetFunc) {
53 | 	p.addRoute(ipPort, sniMatch{targetFunc: fn})
54 | }
55 | 
56 | // sniMatch is a route that matches on the SNI server name of a TLS
57 | // ClientHello.
58 | type sniMatch struct {
59 | 	matcher Matcher
60 | 	target  Target
61 | 
62 | 	// Alternatively, if targetFunc is non-nil, it's used instead:
63 | 	targetFunc SNITargetFunc
64 | }
65 | 
66 | // match implements route. A connection without a parseable SNI name never
67 | // matches.
68 | func (m sniMatch) match(br *bufio.Reader) (Target, string) {
69 | 	sni := clientHelloServerName(br)
70 | 	if sni == "" {
71 | 		return nil, ""
72 | 	}
73 | 	if m.targetFunc != nil {
74 | 		if t, ok := m.targetFunc(context.TODO(), sni); ok {
75 | 			return t, sni
76 | 		}
77 | 		return nil, ""
78 | 	}
79 | 	if m.matcher(context.TODO(), sni) {
80 | 		return m.target, sni
81 | 	}
82 | 	return nil, ""
83 | }
84 | 
85 | // clientHelloServerName returns the SNI server name inside the TLS ClientHello,
86 | // without consuming any bytes from br.
87 | // On any error, the empty string is returned. That includes a record
88 | // longer than br's buffer (Peek fails) and a ClientHello split across
89 | // several TLS records, since only the first record is examined.
90 | func clientHelloServerName(br *bufio.Reader) (sni string) {
91 | 	const recordHeaderLen = 5
92 | 	hdr, err := br.Peek(recordHeaderLen)
93 | 	if err != nil {
94 | 		return ""
95 | 	}
96 | 	const recordTypeHandshake = 0x16
97 | 	if hdr[0] != recordTypeHandshake {
98 | 		return "" // Not TLS.
99 | 	}
100 | 	recLen := int(hdr[3])<<8 | int(hdr[4]) // ignoring version in hdr[1:3]
101 | 	helloBytes, err := br.Peek(recordHeaderLen + recLen)
102 | 	if err != nil {
103 | 		return ""
104 | 	}
105 | 	// Run a throwaway server handshake over the peeked copy purely so that
106 | 	// crypto/tls parses the ClientHello; GetConfigForClient captures the name.
107 | 	// The handshake is expected to fail afterwards (the config has no
108 | 	// certificates and Writes return io.EOF); that error is ignored on purpose.
109 | 	tls.Server(sniSniffConn{r: bytes.NewReader(helloBytes)}, &tls.Config{
110 | 		GetConfigForClient: func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
111 | 			sni = hello.ServerName
112 | 			return nil, nil
113 | 		},
114 | 	}).Handshake()
115 | 	return
116 | }
117 | 
118 | // sniSniffConn is a net.Conn that reads from r, fails on Writes,
119 | // and crashes otherwise.
120 | type sniSniffConn struct {
121 | 	r        io.Reader
122 | 	net.Conn // nil; crash on any unexpected use
123 | }
124 | 
125 | func (c sniSniffConn) Read(p []byte) (int, error) { return c.r.Read(p) }
126 | func (sniSniffConn) Write(p []byte) (int, error)  { return 0, io.EOF }
127 | 
--------------------------------------------------------------------------------
/systemd/tlsrouter.service:
--------------------------------------------------------------------------------
1 | [Unit]
2 | Description=TLS SNI proxy
3 | Documentation=https://github.com/inetaf/tcpproxy/tree/master/cmd/tlsrouter
4 | 
5 | [Service]
6 | WorkingDirectory=/tmp
7 | ExecStart=/usr/bin/tlsrouter -conf /etc/tlsrouter.conf
8 | Restart=always
9 | User=nobody
10 | Group=nogroup
11 | CapabilityBoundingSet=CAP_NET_BIND_SERVICE
12 | AmbientCapabilities=CAP_NET_BIND_SERVICE
13 | PrivateTmp=true
14 | PrivateDevices=true
15 | ProtectSystem=strict
16 | ProtectHome=true
17 | ProtectKernelTunables=true
18 | ProtectControlGroups=true
19 | ProtectKernelModules=true
20 | NoNewPrivileges=true
21 | SystemCallFilter=~@clock @cpu-emulation @debug @keyring @module @mount @obsolete @privileged @raw-io
22 | RestrictAddressFamilies=AF_INET AF_INET6 AF_UNIX
23 | 
24 | [Install]
25 | WantedBy=multi-user.target
26 | 
--------------------------------------------------------------------------------
/tcpproxy.go:
--------------------------------------------------------------------------------
1 | // Copyright 2017 Google Inc.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | //     http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | 
15 | // Package tcpproxy lets users build TCP proxies, optionally making
16 | // routing decisions based on HTTP/1 Host headers and the SNI hostname
17 | // in TLS connections.
18 | //
19 | // Typical usage:
20 | //
21 | //	var p tcpproxy.Proxy
22 | //	p.AddHTTPHostRoute(":80", "foo.com", tcpproxy.To("10.0.0.1:8081"))
23 | //	p.AddHTTPHostRoute(":80", "bar.com", tcpproxy.To("10.0.0.2:8082"))
24 | //	p.AddRoute(":80", tcpproxy.To("10.0.0.1:8081")) // fallback
25 | //	p.AddSNIRoute(":443", "foo.com", tcpproxy.To("10.0.0.1:4431"))
26 | //	p.AddSNIRoute(":443", "bar.com", tcpproxy.To("10.0.0.2:4432"))
27 | //	p.AddRoute(":443", tcpproxy.To("10.0.0.1:4431")) // fallback
28 | //	log.Fatal(p.Run())
29 | //
30 | // Calling Run (or Start) on a proxy also starts all the necessary
31 | // listeners.
32 | //
33 | // For each accepted connection, the rules for that ipPort are
34 | // matched, in order. If one matches (currently HTTP Host, SNI, or
35 | // always), then the connection is handed to the target.
36 | //
37 | // The two predefined Target implementations are:
38 | //
39 | // 1) DialProxy, proxying to another address (use the To func to return a
40 | // DialProxy value),
41 | //
42 | // 2) TargetListener, making the matched connection available via a
43 | // net.Listener.Accept call.
44 | //
45 | // But Target is an interface, so you can also write your own.
46 | //
47 | // Note that tcpproxy does not do any TLS encryption or decryption. It
48 | // only (via DialProxy) copies bytes around. The SNI hostname in the TLS
49 | // header is unencrypted, for better or worse.
50 | //
51 | // This package makes no API stability promises. If you depend on it,
52 | // vendor it.
53 | package tcpproxy
54 | 
55 | import (
56 | 	"bufio"
57 | 	"context"
58 | 	"errors"
59 | 	"fmt"
60 | 	"io"
61 | 	"log"
62 | 	"net"
63 | 	"time"
64 | )
65 | 
66 | // Proxy is a proxy. Its zero value is a valid proxy that does
67 | // nothing. Call methods to add routes before calling Start or Run.
68 | //
69 | // The order that routes are added in matters; each is matched in the order
70 | // registered.
71 | type Proxy struct {
72 | 	configs map[string]*config // ip:port => config
73 | 
74 | 	lns   []net.Listener
75 | 	donec chan struct{} // closed before err
76 | 	err   error         // any error from listening
77 | 
78 | 	// ListenFunc optionally specifies an alternate listen
79 | 	// function. If nil, net.Listen is used.
80 | 	// The provided net is always "tcp".
81 | 	ListenFunc func(net, laddr string) (net.Listener, error)
82 | }
83 | 
84 | // Matcher reports whether hostname matches the Matcher's criteria.
85 | type Matcher func(ctx context.Context, hostname string) bool
86 | 
87 | // equals is a trivial Matcher that implements string equality.
88 | func equals(want string) Matcher {
89 | 	return func(_ context.Context, got string) bool {
90 | 		return want == got
91 | 	}
92 | }
93 | 
94 | // config contains the proxying state for one listener.
95 | type config struct {
96 | 	routes []route
97 | }
98 | 
99 | // A route matches a connection to a target.
100 | type route interface {
101 | 	// match examines the initial bytes of a connection, looking for a
102 | 	// match. If a match is found, match returns a non-nil Target to
103 | 	// which the stream should be proxied. match returns nil if the
104 | 	// connection doesn't match.
105 | 	//
106 | 	// match must not consume bytes from the given bufio.Reader, it
107 | 	// can only Peek.
108 | 	//
109 | 	// If an sni or host header was parsed successfully, that will be
110 | 	// returned as the second parameter.
111 | 	match(*bufio.Reader) (Target, string)
112 | }
113 | 
114 | func (p *Proxy) netListen() func(net, laddr string) (net.Listener, error) {
115 | 	if p.ListenFunc != nil {
116 | 		return p.ListenFunc
117 | 	}
118 | 	return net.Listen
119 | }
120 | 
121 | func (p *Proxy) configFor(ipPort string) *config {
122 | 	if p.configs == nil {
123 | 		p.configs = make(map[string]*config)
124 | 	}
125 | 	if p.configs[ipPort] == nil {
126 | 		p.configs[ipPort] = &config{}
127 | 	}
128 | 	return p.configs[ipPort]
129 | }
130 | 
131 | func (p *Proxy) addRoute(ipPort string, r route) {
132 | 	cfg := p.configFor(ipPort)
133 | 	cfg.routes = append(cfg.routes, r)
134 | }
135 | 
136 | // AddRoute appends an always-matching route to the ipPort listener,
137 | // directing any connection to dest.
138 | //
139 | // This is generally used as either the only rule (for simple TCP
140 | // proxies), or as the final fallback rule for an ipPort.
141 | //
142 | // The ipPort is any valid net.Listen TCP address.
143 | func (p *Proxy) AddRoute(ipPort string, dest Target) {
144 | 	p.addRoute(ipPort, fixedTarget{dest})
145 | }
146 | 
147 | type fixedTarget struct {
148 | 	t Target
149 | }
150 | 
151 | func (m fixedTarget) match(*bufio.Reader) (Target, string) { return m.t, "" }
152 | 
153 | // Run calls Start, and then Wait.
154 | //
155 | // It blocks until there's an error. The return value is always
156 | // non-nil.
157 | func (p *Proxy) Run() error {
158 | 	if err := p.Start(); err != nil {
159 | 		return err
160 | 	}
161 | 	return p.Wait()
162 | }
163 | 
164 | // Wait waits for the Proxy to finish running. Currently this can only
165 | // happen if a Listener is closed, or Close is called on the proxy.
166 | //
167 | // It is only valid to call Wait after a successful call to Start.
168 | func (p *Proxy) Wait() error {
169 | 	<-p.donec
170 | 	return p.err
171 | }
172 | 
173 | // Close closes all the proxy's self-opened listeners.
174 | func (p *Proxy) Close() error {
175 | 	for _, c := range p.lns {
176 | 		c.Close()
177 | 	}
178 | 	return nil
179 | }
180 | 
181 | // Start creates a TCP listener for each unique ipPort from the
182 | // previously created routes and starts the proxy. It returns any
183 | // error from starting listeners.
184 | //
185 | // If it returns a non-nil error, any successfully opened listeners
186 | // are closed.
187 | func (p *Proxy) Start() error {
188 | 	if p.donec != nil {
189 | 		return errors.New("already started")
190 | 	}
191 | 	p.donec = make(chan struct{})
192 | 	errc := make(chan error, len(p.configs))
193 | 	p.lns = make([]net.Listener, 0, len(p.configs))
194 | 	for ipPort, config := range p.configs {
195 | 		ln, err := p.netListen()("tcp", ipPort)
196 | 		if err != nil {
197 | 			p.Close()
198 | 			return err
199 | 		}
200 | 		p.lns = append(p.lns, ln)
201 | 		go p.serveListener(errc, ln, config.routes)
202 | 	}
203 | 	go p.awaitFirstError(errc)
204 | 	return nil
205 | }
206 | 
207 | func (p *Proxy) awaitFirstError(errc <-chan error) {
208 | 	p.err = <-errc
209 | 	close(p.donec)
210 | }
211 | 
212 | func (p *Proxy) serveListener(ret chan<- error, ln net.Listener, routes []route) {
213 | 	for {
214 | 		c, err := ln.Accept()
215 | 		if err != nil {
216 | 			ret <- err
217 | 			return
218 | 		}
219 | 		go p.serveConn(c, routes)
220 | 	}
221 | }
222 | 
223 | // serveConn runs in its own goroutine and matches c against routes.
224 | // It returns whether it matched purely for testing.
225 | func (p *Proxy) serveConn(c net.Conn, routes []route) bool {
226 | 	br := bufio.NewReader(c)
227 | 	for _, route := range routes {
228 | 		if target, hostName := route.match(br); target != nil {
229 | 			if n := br.Buffered(); n > 0 {
230 | 				peeked, _ := br.Peek(n) // cannot fail: n bytes are already buffered
231 | 				c = &Conn{
232 | 					HostName: hostName,
233 | 					Peeked:   peeked,
234 | 					Conn:     c,
235 | 				}
236 | 			}
237 | 			target.HandleConn(c)
238 | 			return true
239 | 		}
240 | 	}
241 | 	// TODO: hook for this? (%v below tolerates a nil Addr, unlike .String())
242 | 	log.Printf("tcpproxy: no routes matched conn %v/%v; closing", c.RemoteAddr(), c.LocalAddr())
243 | 	c.Close()
244 | 	return false
245 | }
246 | 
247 | // Conn is an incoming connection that has had some bytes read from it
248 | // to determine how to route the connection. The Read method stitches
249 | // the peeked bytes and unread bytes back together.
250 | type Conn struct {
251 | 	// HostName is the hostname field that was sent to the request router.
252 | 	// In the case of TLS, this is the SNI header, in the case of HTTPHost
253 | 	// route, it will be the host header. In the case of a fixed
254 | 	// route, i.e. those created with AddRoute(), this will always be
255 | 	// empty. This can be useful in the case where further routing decisions
256 | 	// need to be made in the Target implementation.
257 | 	HostName string
258 | 
259 | 	// Peeked are the bytes that have been read from Conn for the
260 | 	// purposes of route matching, but have not yet been consumed
261 | 	// by Read calls. It is set to nil by Read when fully consumed.
262 | 	Peeked []byte
263 | 
264 | 	// Conn is the underlying connection.
265 | 	// It can be type asserted against *net.TCPConn or other types
266 | 	// as needed. It should not be read from directly unless
267 | 	// Peeked is nil.
268 | 	net.Conn
269 | }
270 | 
271 | func (c *Conn) Read(p []byte) (n int, err error) {
272 | 	if len(c.Peeked) > 0 {
273 | 		n = copy(p, c.Peeked)
274 | 		c.Peeked = c.Peeked[n:]
275 | 		if len(c.Peeked) == 0 {
276 | 			c.Peeked = nil
277 | 		}
278 | 		return n, nil
279 | 	}
280 | 	return c.Conn.Read(p)
281 | }
282 | 
283 | // Target is what an incoming matched connection is sent to.
284 | type Target interface {
285 | 	// HandleConn is called when an incoming connection is
286 | 	// matched. After the call to HandleConn, the tcpproxy
287 | 	// package never touches the conn again. Implementations are
288 | 	// responsible for closing the connection when needed.
289 | 	//
290 | 	// The concrete type of conn will be of type *Conn if any
291 | 	// bytes have been consumed for the purposes of route
292 | 	// matching.
293 | 	HandleConn(net.Conn)
294 | }
295 | 
296 | // To is a shorthand way of writing &tcpproxy.DialProxy{Addr: addr}.
297 | func To(addr string) *DialProxy {
298 | 	return &DialProxy{Addr: addr}
299 | }
300 | 
301 | // DialProxy implements Target by dialing a new connection to Addr
302 | // and then proxying data back and forth.
303 | //
304 | // The To func is a shorthand way of creating a DialProxy.
305 | type DialProxy struct {
306 | 	// Addr is the TCP address to proxy to.
307 | 	Addr string
308 | 
309 | 	// KeepAlivePeriod sets the period between TCP keep alives.
310 | 	// If zero, a default is used. To disable, use a negative number.
311 | 	// The keep-alive is used for both the client connection and the backend connection.
312 | 	KeepAlivePeriod time.Duration
313 | 
314 | 	// DialTimeout optionally specifies a dial timeout.
315 | 	// If zero, a default is used.
316 | 	// If negative, the timeout is disabled.
317 | 	DialTimeout time.Duration
318 | 
319 | 	// DialContext optionally specifies an alternate dial function
320 | 	// for TCP targets. If nil, the standard
321 | 	// net.Dialer.DialContext method is used.
322 | 	DialContext func(ctx context.Context, network, address string) (net.Conn, error)
323 | 
324 | 	// OnDialError optionally specifies an alternate way to handle errors dialing Addr.
325 | 	// If nil, the error is logged and src is closed.
326 | 	// If non-nil, src is not closed automatically.
327 | 	OnDialError func(src net.Conn, dstDialErr error)
328 | 
329 | 	// ProxyProtocolVersion optionally specifies the version of
330 | 	// HAProxy's PROXY protocol to use. The PROXY protocol provides
331 | 	// connection metadata to the DialProxy target, via a header
332 | 	// inserted ahead of the client's traffic. The DialProxy target
333 | 	// must explicitly support and expect the PROXY header; there is
334 | 	// no graceful downgrade.
335 | 	// If zero, no PROXY header is sent. Currently, version 1 is supported.
336 | 	ProxyProtocolVersion int
337 | }
338 | 
339 | // UnderlyingConn returns c.Conn if c is of type *Conn,
340 | // otherwise it returns c.
341 | func UnderlyingConn(c net.Conn) net.Conn {
342 | 	if wrap, ok := c.(*Conn); ok {
343 | 		return wrap.Conn
344 | 	}
345 | 	return c
346 | }
347 | 
348 | func tcpConn(c net.Conn) (t *net.TCPConn, ok bool) {
349 | 	if c, ok := UnderlyingConn(c).(*net.TCPConn); ok {
350 | 		return c, ok
351 | 	}
352 | 	if c, ok := c.(*net.TCPConn); ok {
353 | 		return c, ok
354 | 	}
355 | 	return nil, false
356 | }
357 | 
358 | type closeReader interface{ CloseRead() error }
359 | type closeWriter interface{ CloseWrite() error }
360 | 
361 | func closeRead(c net.Conn) {
362 | 	// prefer the interfaces, for compatibility with e.g. gvisor/netstack.
363 | 	if c, ok := UnderlyingConn(c).(closeReader); ok {
364 | 		c.CloseRead()
365 | 	}
366 | }
367 | 
368 | func closeWrite(c net.Conn) {
369 | 	// prefer the interfaces, for compatibility with e.g. gvisor/netstack.
370 | 	if c, ok := UnderlyingConn(c).(closeWriter); ok {
371 | 		c.CloseWrite()
372 | 	}
373 | }
374 | 
375 | // HandleConn implements the Target interface.
376 | func (dp *DialProxy) HandleConn(src net.Conn) {
377 | 	ctx := context.Background()
378 | 	var cancel context.CancelFunc
379 | 	if dp.DialTimeout >= 0 {
380 | 		ctx, cancel = context.WithTimeout(ctx, dp.dialTimeout())
381 | 	}
382 | 	dst, err := dp.dialContext()(ctx, "tcp", dp.Addr)
383 | 	if cancel != nil {
384 | 		cancel()
385 | 	}
386 | 	if err != nil {
387 | 		dp.onDialError()(src, err)
388 | 		return
389 | 	}
390 | 	defer dst.Close()
391 | 
392 | 	if err = dp.sendProxyHeader(dst, src); err != nil {
393 | 		dp.onDialError()(src, err)
394 | 		return
395 | 	}
396 | 	defer src.Close()
397 | 
398 | 	if ka := dp.keepAlivePeriod(); ka > 0 {
399 | 		for _, c := range []net.Conn{src, dst} {
400 | 			if c, ok := tcpConn(c); ok {
401 | 				c.SetKeepAlive(true)
402 | 				c.SetKeepAlivePeriod(ka)
403 | 			}
404 | 		}
405 | 	}
406 | 
407 | 	errc := make(chan error, 2)
408 | 	go proxyCopy(errc, src, dst)
409 | 	go proxyCopy(errc, dst, src)
410 | 	<-errc
411 | 	<-errc
412 | }
413 | 
414 | func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error {
415 | 	switch dp.ProxyProtocolVersion {
416 | 	case 0:
417 | 		return nil
418 | 	case 1:
419 | 		var srcAddr, dstAddr *net.TCPAddr
420 | 		if a, ok := src.RemoteAddr().(*net.TCPAddr); ok {
421 | 			srcAddr = a
422 | 		}
423 | 		if a, ok := src.LocalAddr().(*net.TCPAddr); ok {
424 | 			dstAddr = a
425 | 		}
426 | 
427 | 		if srcAddr == nil || dstAddr == nil {
428 | 			_, err := io.WriteString(w, "PROXY UNKNOWN\r\n")
429 | 			return err
430 | 		}
431 | 
432 | 		family := "TCP4"
433 | 		if srcAddr.IP.To4() == nil {
434 | 			family = "TCP6"
435 | 		}
436 | 		_, err := fmt.Fprintf(w, "PROXY %s %s %s %d %d\r\n", family, srcAddr.IP, dstAddr.IP, srcAddr.Port, dstAddr.Port)
437 | 		return err
438 | 	default:
439 | 		return fmt.Errorf("PROXY protocol version %d not supported", dp.ProxyProtocolVersion)
440 | 	}
441 | }
442 | 
443 | // proxyCopy is the function that copies bytes around.
444 | // It's a named function instead of a func literal so users get
445 | // named goroutines in debug goroutine stack dumps.
446 | func proxyCopy(errc chan<- error, dst, src net.Conn) {
447 | 	defer closeRead(src)
448 | 	defer closeWrite(dst)
449 | 
450 | 	// Before we unwrap src and/or dst, copy any buffered data.
451 | 	if wc, ok := src.(*Conn); ok && len(wc.Peeked) > 0 {
452 | 		if _, err := dst.Write(wc.Peeked); err != nil {
453 | 			errc <- err
454 | 			return
455 | 		}
456 | 		wc.Peeked = nil
457 | 	}
458 | 
459 | 	// Unwrap the src and dst from *Conn to *net.TCPConn so Go
460 | 	// 1.11's splice optimization kicks in.
461 | 	src = UnderlyingConn(src)
462 | 	dst = UnderlyingConn(dst)
463 | 
464 | 	_, err := io.Copy(dst, src)
465 | 	errc <- err
466 | }
467 | 
468 | func (dp *DialProxy) keepAlivePeriod() time.Duration {
469 | 	if dp.KeepAlivePeriod != 0 {
470 | 		return dp.KeepAlivePeriod
471 | 	}
472 | 	return time.Minute
473 | }
474 | 
475 | func (dp *DialProxy) dialTimeout() time.Duration {
476 | 	if dp.DialTimeout > 0 {
477 | 		return dp.DialTimeout
478 | 	}
479 | 	return 10 * time.Second
480 | }
481 | 
482 | var defaultDialer = new(net.Dialer)
483 | 
484 | func (dp *DialProxy) dialContext() func(ctx context.Context, network, address string) (net.Conn, error) {
485 | 	if dp.DialContext != nil {
486 | 		return dp.DialContext
487 | 	}
488 | 	return defaultDialer.DialContext
489 | }
490 | 
491 | func (dp *DialProxy) onDialError() func(src net.Conn, dstDialErr error) {
492 | 	if dp.OnDialError != nil {
493 | 		return dp.OnDialError
494 | 	}
495 | 	return func(src net.Conn, dstDialErr error) {
496 | 		var remoteAddr string
497 | 		if ra := src.RemoteAddr(); ra != nil {
498 | 			remoteAddr = ra.String()
499 | 		} else {
500 | 			remoteAddr = fmt.Sprintf("[%T with nil RemoteAddr]", src)
501 | 		}
502 | 		log.Printf("tcpproxy: for incoming conn %v, error dialing %q: %v", remoteAddr, dp.Addr, dstDialErr)
503 | 		src.Close()
504 | 	}
505 | }
506 | 
--------------------------------------------------------------------------------
/tcpproxy_test.go:
--------------------------------------------------------------------------------
1 | // Copyright 2017 Google Inc.
2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package tcpproxy 16 | 17 | import ( 18 | "bufio" 19 | "bytes" 20 | "context" 21 | "crypto/rand" 22 | "crypto/rsa" 23 | "crypto/tls" 24 | "crypto/x509" 25 | "crypto/x509/pkix" 26 | "encoding/pem" 27 | "errors" 28 | "fmt" 29 | "io" 30 | "io/ioutil" 31 | "math/big" 32 | "net" 33 | "strings" 34 | "testing" 35 | "time" 36 | ) 37 | 38 | type noopTarget struct{} 39 | 40 | func (noopTarget) HandleConn(net.Conn) {} 41 | 42 | func TestMatchHTTPHost(t *testing.T) { 43 | tests := []struct { 44 | name string 45 | r io.Reader 46 | host string 47 | want bool 48 | }{ 49 | { 50 | name: "match", 51 | r: strings.NewReader("GET / HTTP/1.1\r\nHost: foo.com\r\n\r\n"), 52 | host: "foo.com", 53 | want: true, 54 | }, 55 | { 56 | name: "no-match", 57 | r: strings.NewReader("GET / HTTP/1.1\r\nHost: foo.com\r\n\r\n"), 58 | host: "bar.com", 59 | want: false, 60 | }, 61 | { 62 | name: "match-huge-request", 63 | r: io.MultiReader(strings.NewReader("GET / HTTP/1.1\r\nHost: foo.com\r\n"), neverEnding('a')), 64 | host: "foo.com", 65 | want: true, 66 | }, 67 | } 68 | for i, tt := range tests { 69 | name := tt.name 70 | if name == "" { 71 | name = fmt.Sprintf("test_index_%d", i) 72 | } 73 | t.Run(name, func(t *testing.T) { 74 | br := bufio.NewReader(tt.r) 75 | r := httpHostMatch{equals(tt.host), noopTarget{}} 76 | m, name := r.match(br) 77 | got := m != nil 78 | if got != tt.want { 79 | 
t.Fatalf("match = %v; want %v", got, tt.want) 80 | } 81 | if tt.want && name != tt.host { 82 | t.Fatalf("host = %s; want %s", name, tt.host) 83 | } 84 | get := make([]byte, 3) 85 | if _, err := io.ReadFull(br, get); err != nil { 86 | t.Fatal(err) 87 | } 88 | if string(get) != "GET" { 89 | t.Fatalf("did bufio.Reader consume bytes? got %q; want GET", get) 90 | } 91 | }) 92 | } 93 | } 94 | 95 | type neverEnding byte 96 | 97 | func (b neverEnding) Read(p []byte) (n int, err error) { 98 | for i := range p { 99 | p[i] = byte(b) 100 | } 101 | return len(p), nil 102 | } 103 | 104 | type recordWritesConn struct { 105 | buf bytes.Buffer 106 | net.Conn 107 | } 108 | 109 | func (c *recordWritesConn) Write(p []byte) (int, error) { 110 | c.buf.Write(p) 111 | return len(p), nil 112 | } 113 | 114 | func (c *recordWritesConn) Read(p []byte) (int, error) { return 0, io.EOF } 115 | 116 | func clientHelloRecord(t *testing.T, hostName string) string { 117 | rec := new(recordWritesConn) 118 | cl := tls.Client(rec, &tls.Config{ServerName: hostName}) 119 | cl.Handshake() 120 | 121 | s := rec.buf.String() 122 | if !strings.Contains(s, hostName) { 123 | t.Fatalf("clientHello sent in test didn't contain %q", hostName) 124 | } 125 | return s 126 | } 127 | 128 | func TestSNI(t *testing.T) { 129 | const hostName = "foo.com" 130 | greeting := clientHelloRecord(t, hostName) 131 | got := clientHelloServerName(bufio.NewReader(strings.NewReader(greeting))) 132 | if got != hostName { 133 | t.Errorf("got SNI %q; want %q", got, hostName) 134 | } 135 | } 136 | 137 | func TestProxyStartNone(t *testing.T) { 138 | var p Proxy 139 | if err := p.Start(); err != nil { 140 | t.Fatal(err) 141 | } 142 | } 143 | 144 | func newLocalListener(t *testing.T) net.Listener { 145 | ln, err := net.Listen("tcp", "127.0.0.1:0") 146 | if err != nil { 147 | ln, err = net.Listen("tcp", "[::1]:0") 148 | if err != nil { 149 | t.Fatal(err) 150 | } 151 | } 152 | return ln 153 | } 154 | 155 | const testFrontAddr = "1.2.3.4:567" 156 
| 157 | func testListenFunc(t *testing.T, ln net.Listener) func(network, laddr string) (net.Listener, error) { 158 | return func(network, laddr string) (net.Listener, error) { 159 | if network != "tcp" { 160 | t.Errorf("got Listen call with network %q, not tcp", network) 161 | return nil, errors.New("invalid network") 162 | } 163 | if laddr != testFrontAddr { 164 | t.Fatalf("got Listen call with laddr %q, want %q", laddr, testFrontAddr) 165 | panic("bogus address") 166 | } 167 | return ln, nil 168 | } 169 | } 170 | 171 | func testProxy(t *testing.T, front net.Listener) *Proxy { 172 | return &Proxy{ 173 | ListenFunc: testListenFunc(t, front), 174 | } 175 | } 176 | 177 | func TestBufferedClose(t *testing.T) { 178 | front := newLocalListener(t) 179 | defer front.Close() 180 | back := newLocalListener(t) 181 | defer back.Close() 182 | 183 | p := testProxy(t, front) 184 | p.AddRoute(testFrontAddr, To(back.Addr().String())) 185 | if err := p.Start(); err != nil { 186 | t.Fatal(err) 187 | } 188 | 189 | toFront, err := net.Dial("tcp", front.Addr().String()) 190 | if err != nil { 191 | t.Fatal(err) 192 | } 193 | defer toFront.Close() 194 | 195 | fromProxy, err := back.Accept() 196 | if err != nil { 197 | t.Fatal(err) 198 | } 199 | defer fromProxy.Close() 200 | const msg = "message" 201 | if _, err := io.WriteString(toFront, msg); err != nil { 202 | t.Fatal(err) 203 | } 204 | // actively close toFront, the write should still make to the back. 
205 | toFront.Close() 206 | 207 | buf := make([]byte, len(msg)) 208 | if _, err := io.ReadFull(fromProxy, buf); err != nil { 209 | t.Fatal(err) 210 | } 211 | if string(buf) != msg { 212 | t.Fatalf("got %q; want %q", buf, msg) 213 | } 214 | } 215 | 216 | func TestProxyAlwaysMatch(t *testing.T) { 217 | front := newLocalListener(t) 218 | defer front.Close() 219 | back := newLocalListener(t) 220 | defer back.Close() 221 | 222 | p := testProxy(t, front) 223 | p.AddRoute(testFrontAddr, To(back.Addr().String())) 224 | if err := p.Start(); err != nil { 225 | t.Fatal(err) 226 | } 227 | 228 | toFront, err := net.Dial("tcp", front.Addr().String()) 229 | if err != nil { 230 | t.Fatal(err) 231 | } 232 | defer toFront.Close() 233 | 234 | fromProxy, err := back.Accept() 235 | if err != nil { 236 | t.Fatal(err) 237 | } 238 | defer fromProxy.Close() 239 | const msg = "message" 240 | io.WriteString(toFront, msg) 241 | 242 | buf := make([]byte, len(msg)) 243 | if _, err := io.ReadFull(fromProxy, buf); err != nil { 244 | t.Fatal(err) 245 | } 246 | if string(buf) != msg { 247 | t.Fatalf("got %q; want %q", buf, msg) 248 | } 249 | } 250 | 251 | func TestProxyHTTP(t *testing.T) { 252 | front := newLocalListener(t) 253 | defer front.Close() 254 | 255 | backFoo := newLocalListener(t) 256 | defer backFoo.Close() 257 | backBar := newLocalListener(t) 258 | defer backBar.Close() 259 | 260 | p := testProxy(t, front) 261 | p.AddHTTPHostRoute(testFrontAddr, "foo.com", To(backFoo.Addr().String())) 262 | p.AddHTTPHostRoute(testFrontAddr, "bar.com", To(backBar.Addr().String())) 263 | if err := p.Start(); err != nil { 264 | t.Fatal(err) 265 | } 266 | 267 | toFront, err := net.Dial("tcp", front.Addr().String()) 268 | if err != nil { 269 | t.Fatal(err) 270 | } 271 | defer toFront.Close() 272 | 273 | const msg = "GET / HTTP/1.1\r\nHost: bar.com\r\n\r\n" 274 | io.WriteString(toFront, msg) 275 | 276 | fromProxy, err := backBar.Accept() 277 | if err != nil { 278 | t.Fatal(err) 279 | } 280 | 281 | buf := 
make([]byte, len(msg)) 282 | if _, err := io.ReadFull(fromProxy, buf); err != nil { 283 | t.Fatal(err) 284 | } 285 | if string(buf) != msg { 286 | t.Fatalf("got %q; want %q", buf, msg) 287 | } 288 | } 289 | 290 | func TestProxySNI(t *testing.T) { 291 | front := newLocalListener(t) 292 | defer front.Close() 293 | 294 | backFoo := newLocalListener(t) 295 | defer backFoo.Close() 296 | backBar := newLocalListener(t) 297 | defer backBar.Close() 298 | 299 | p := testProxy(t, front) 300 | p.AddSNIRoute(testFrontAddr, "foo.com", To(backFoo.Addr().String())) 301 | p.AddSNIRoute(testFrontAddr, "bar.com", To(backBar.Addr().String())) 302 | if err := p.Start(); err != nil { 303 | t.Fatal(err) 304 | } 305 | 306 | toFront, err := net.Dial("tcp", front.Addr().String()) 307 | if err != nil { 308 | t.Fatal(err) 309 | } 310 | defer toFront.Close() 311 | 312 | msg := clientHelloRecord(t, "bar.com") 313 | io.WriteString(toFront, msg) 314 | 315 | fromProxy, err := backBar.Accept() 316 | if err != nil { 317 | t.Fatal(err) 318 | } 319 | 320 | buf := make([]byte, len(msg)) 321 | if _, err := io.ReadFull(fromProxy, buf); err != nil { 322 | t.Fatal(err) 323 | } 324 | if string(buf) != msg { 325 | t.Fatalf("got %q; want %q", buf, msg) 326 | } 327 | } 328 | 329 | func TestAddSNIRouteFunc(t *testing.T) { 330 | front := newLocalListener(t) 331 | defer front.Close() 332 | 333 | backFoo := newLocalListener(t) 334 | defer backFoo.Close() 335 | backBar := newLocalListener(t) 336 | defer backBar.Close() 337 | 338 | p := testProxy(t, front) 339 | p.AddSNIRouteFunc(testFrontAddr, func(ctx context.Context, sniName string) (_ Target, ok bool) { 340 | if sniName == "bar.com" { 341 | return To(backBar.Addr().String()), true 342 | } 343 | t.Fatalf("failed to match %q", sniName) 344 | return nil, false 345 | }) 346 | if err := p.Start(); err != nil { 347 | t.Fatal(err) 348 | } 349 | 350 | toFront, err := net.Dial("tcp", front.Addr().String()) 351 | if err != nil { 352 | t.Fatal(err) 353 | } 354 | defer 
toFront.Close() 355 | 356 | msg := clientHelloRecord(t, "bar.com") 357 | io.WriteString(toFront, msg) 358 | 359 | fromProxy, err := backBar.Accept() 360 | if err != nil { 361 | t.Fatal(err) 362 | } 363 | 364 | buf := make([]byte, len(msg)) 365 | if _, err := io.ReadFull(fromProxy, buf); err != nil { 366 | t.Fatal(err) 367 | } 368 | if string(buf) != msg { 369 | t.Fatalf("got %q; want %q", buf, msg) 370 | } 371 | } 372 | func TestProxyPROXYOut(t *testing.T) { 373 | front := newLocalListener(t) 374 | defer front.Close() 375 | back := newLocalListener(t) 376 | defer back.Close() 377 | 378 | p := testProxy(t, front) 379 | p.AddRoute(testFrontAddr, &DialProxy{ 380 | Addr: back.Addr().String(), 381 | ProxyProtocolVersion: 1, 382 | }) 383 | if err := p.Start(); err != nil { 384 | t.Fatal(err) 385 | } 386 | 387 | toFront, err := net.Dial("tcp", front.Addr().String()) 388 | if err != nil { 389 | t.Fatal(err) 390 | } 391 | 392 | io.WriteString(toFront, "foo") 393 | toFront.Close() 394 | 395 | fromProxy, err := back.Accept() 396 | if err != nil { 397 | t.Fatal(err) 398 | } 399 | 400 | bs, err := ioutil.ReadAll(fromProxy) 401 | if err != nil { 402 | t.Fatal(err) 403 | } 404 | 405 | want := fmt.Sprintf("PROXY TCP4 %s %s %d %d\r\nfoo", toFront.LocalAddr().(*net.TCPAddr).IP, toFront.RemoteAddr().(*net.TCPAddr).IP, toFront.LocalAddr().(*net.TCPAddr).Port, toFront.RemoteAddr().(*net.TCPAddr).Port) 406 | if string(bs) != want { 407 | t.Fatalf("got %q; want %q", bs, want) 408 | } 409 | } 410 | 411 | type tlsServer struct { 412 | Listener net.Listener 413 | Domain string 414 | Test *testing.T 415 | } 416 | 417 | func (t *tlsServer) Start() { 418 | cert := cert(t.Test, t.Domain) 419 | cfg := &tls.Config{ 420 | Certificates: []tls.Certificate{cert}, 421 | } 422 | cfg.BuildNameToCertificate() 423 | 424 | go func() { 425 | for { 426 | rawConn, err := t.Listener.Accept() 427 | if err != nil { 428 | return // assume Close() 429 | } 430 | 431 | conn := tls.Server(rawConn, cfg) 432 | if _, 
err = io.WriteString(conn, t.Domain); err != nil { 433 | t.Test.Errorf("writing to tlsconn: %s", err) 434 | } 435 | conn.Close() 436 | } 437 | }() 438 | } 439 | 440 | func (t *tlsServer) Close() { 441 | t.Listener.Close() 442 | } 443 | 444 | // cert creates a well-formed, but completely insecure self-signed 445 | // cert for domain. 446 | func cert(t *testing.T, domain string) tls.Certificate { 447 | private, err := rsa.GenerateKey(rand.Reader, 1024) 448 | if err != nil { 449 | t.Fatal(err) 450 | } 451 | template := &x509.Certificate{ 452 | SerialNumber: big.NewInt(1), 453 | Subject: pkix.Name{ 454 | Organization: []string{"Test Co"}, 455 | CommonName: domain, 456 | }, 457 | NotBefore: time.Time{}, 458 | NotAfter: time.Now().Add(60 * time.Minute), 459 | IsCA: true, 460 | KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, 461 | ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, 462 | BasicConstraintsValid: true, 463 | } 464 | 465 | derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &private.PublicKey, private) 466 | if err != nil { 467 | t.Fatal(err) 468 | } 469 | 470 | var cert, key bytes.Buffer 471 | pem.Encode(&cert, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) 472 | pem.Encode(&key, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(private)}) 473 | 474 | tlscert, err := tls.X509KeyPair(cert.Bytes(), key.Bytes()) 475 | if err != nil { 476 | t.Fatal(err) 477 | } 478 | 479 | return tlscert 480 | } 481 | 482 | // newTLSServer starts a TLS server that serves a self-signed cert for 483 | // domain. 
484 | func newTLSServer(t *testing.T, domain string) net.Listener { 485 | cert := cert(t, domain) 486 | 487 | l := newLocalListener(t) 488 | go func() { 489 | for { 490 | rawConn, err := l.Accept() 491 | if err != nil { 492 | return // assume closed 493 | } 494 | 495 | cfg := &tls.Config{ 496 | Certificates: []tls.Certificate{cert}, 497 | } 498 | cfg.BuildNameToCertificate() 499 | conn := tls.Server(rawConn, cfg) 500 | if _, err = io.WriteString(conn, domain); err != nil { 501 | t.Errorf("writing to tlsconn: %s", err) 502 | } 503 | conn.Close() 504 | } 505 | }() 506 | 507 | return l 508 | } 509 | 510 | func readTLS(dest, domain string) (string, error) { 511 | conn, err := tls.Dial("tcp", dest, &tls.Config{ 512 | ServerName: domain, 513 | InsecureSkipVerify: true, 514 | }) 515 | if err != nil { 516 | return "", err 517 | } 518 | defer conn.Close() 519 | 520 | bs, err := ioutil.ReadAll(conn) 521 | if err != nil { 522 | return "", err 523 | } 524 | return string(bs), nil 525 | } 526 | --------------------------------------------------------------------------------