├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── cli ├── cli.go ├── cli_test.go └── opts.go ├── client └── client.go ├── functional_test.go ├── main ├── wormhole │ └── wormhole.go └── wormholed │ └── wormholed.go ├── mysql ├── Dockerfile └── Makefile ├── pkg ├── netaddr │ ├── LICENSE │ └── ip.go └── proxy │ ├── LICENSE │ ├── config │ ├── api.go │ ├── api_test.go │ ├── config.go │ ├── config_test.go │ ├── doc.go │ ├── etcd.go │ └── file.go │ ├── doc.go │ ├── loadbalancer.go │ ├── proxier.go │ ├── proxier_test.go │ ├── roundrobin.go │ ├── roundrobin_test.go │ └── udp_server.go ├── pong ├── .gitignore ├── Dockerfile ├── Makefile └── pong.go ├── server ├── api.go ├── echo.go ├── opts.go ├── segment.go ├── segment_test.go ├── server.go └── tunnel.go ├── utils ├── utils.go └── utils_test.go └── wordpress ├── Dockerfile └── Makefile /.gitignore: -------------------------------------------------------------------------------- 1 | wormhole 2 | wormholed 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | Copyright 2014 Vishvananda Ishaya. 180 | 181 | Licensed under the Apache License, Version 2.0 (the "License"); 182 | you may not use this file except in compliance with the License. 
183 | You may obtain a copy of the License at 184 | 185 | http://www.apache.org/licenses/LICENSE-2.0 186 | 187 | Unless required by applicable law or agreed to in writing, software 188 | distributed under the License is distributed on an "AS IS" BASIS, 189 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 190 | See the License for the specific language governing permissions and 191 | limitations under the License. 192 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SERVER_NAME = wormholed 2 | CLI_NAME = wormhole 3 | 4 | SHARED = \ 5 | client \ 6 | utils 7 | 8 | CLI = \ 9 | cli \ 10 | main/$(CLI_NAME) 11 | 12 | SERVER = \ 13 | pkg/netaddr \ 14 | pkg/proxy \ 15 | server \ 16 | main/$(SERVER_NAME) 17 | 18 | COMBINED := $(SHARED) $(CLI) $(SERVER) 19 | 20 | SHARED_DEPS = \ 21 | github.com/raff/tls-ext \ 22 | github.com/raff/tls-psk 23 | 24 | # kubernetes/pkg/api is needed for pkg/proxy 25 | # the other dependencies besides netns and netlink are for kubernetes 26 | SERVER_DEPS = \ 27 | github.com/GoogleCloudPlatform/kubernetes/pkg/api \ 28 | github.com/fsouza/go-dockerclient \ 29 | github.com/golang/glog \ 30 | code.google.com/p/go.net/context \ 31 | gopkg.in/v1/yaml \ 32 | github.com/vishvananda/netns \ 33 | github.com/vishvananda/netlink 34 | 35 | CLI_DEPS = 36 | 37 | TEST_DEPS = 38 | 39 | COMBINED_DEPS := $(SHARED_DEPS) $(CLI_DEPS) $(SERVER_DEPS) $(TEST_DEPS) 40 | 41 | uniq = $(if $1,$(firstword $1) $(call uniq,$(filter-out $(firstword $1),$1))) 42 | gofiles = $(foreach d,$(1),$(wildcard $(d)/*.go)) 43 | testdirs = $(call uniq,$(foreach d,$(1),$(dir $(wildcard $(d)/*_test.go)))) 44 | goroot = $(addprefix ../../../,$(1)) 45 | unroot = $(subst ../../../,,$(1)) 46 | 47 | all: $(SERVER_NAME) $(CLI_NAME) 48 | 49 | $(call goroot,$(COMBINED_DEPS)): 50 | go get $(call unroot,$@) 51 | 52 | $(SERVER_NAME): $(call 
goroot,$(SHARED_DEPS)) $(call goroot,$(SERVER_DEPS)) $(call gofiles,$(SERVER)) $(call gofiles,$(SHARED)) 53 | go build github.com/vishvananda/wormhole/main/wormholed 54 | 55 | $(CLI_NAME): $(call goroot,$(SHARED_DEPS)) $(call goroot,$(CLI_DEPS)) $(call gofiles,$(CLI)) $(call gofiles,$(SHARED)) 56 | go build github.com/vishvananda/wormhole/main/wormhole 57 | 58 | .PHONY: $(call testdirs,$(COMBINED)) 59 | $(call testdirs,$(COMBINED)): $(call goroot,$(TEST_DEPS)) 60 | sudo -E go test -v github.com/vishvananda/wormhole/$@ 61 | 62 | fmt: 63 | for dir in . $(COMBINED); do go fmt github.com/vishvananda/wormhole/$$dir; done 64 | 65 | test: test-unit test-functional 66 | 67 | test-unit: $(call testdirs,$(COMBINED)) 68 | 69 | .PHONY: pong 70 | pong: 71 | $(MAKE) -C pong all 72 | 73 | test-functional: $(SERVER_NAME) $(CLI_NAME) pong 74 | sudo -E go test -v functional_test.go 75 | 76 | .PHONY: clean 77 | clean: 78 | -rm wormhole wormholed 79 | 80 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # wormhole # 2 | 3 | ### A smart proxy that connects docker containers ### 4 | 5 | Wormhole is a namespace-aware socket-activated tunneling proxy. It allows you 6 | to securely connect ports together on different physical machines inside 7 | docker containers. This allows for interesting things like connecting services 8 | running on localhost inside the container namespace or creating on-demand 9 | services that start when you connect to them. 10 | 11 | ## But Why? ## 12 | 13 | Containers give us the opportunity to move a whole bunch of complicated 14 | distributed systems problems (orchestration, service discovery, configuration 15 | management, and security) into the communication layer. 
They have the 16 | potential to finally give us [truly standard components](https://medium.com/@vishvananda/standard-components-not-standard-containers-c30567f23da6). 17 | 18 | This isn't intended to be a production solution for container relationships. 19 | Consider it an exploration of the above value. If you can create an 20 | application container that talks to a single port to get a secure connection 21 | to the database, many things get simpler. You don't have to configure the 22 | application with the proper address. You don't have to set a secure password 23 | for the database server. You can move the database without breaking the 24 | application. You can add a second database and start load balancing. 25 | 26 | Most importantly, standardizing the communication layer means that containers 27 | are trivially sharable. Everyone who needs a mysql database container can 28 | use the same one. No configuration of the container is necessary, you can just 29 | drop it in and start using it. 30 | 31 | Yes it is computationally more expensive to proxy connections, but consider: 32 | 33 | 1. It is possible to accomplish many of the same things with sdn instead 34 | of proxying 35 | 2. This proxy could replace the proxies that many services already use 36 | for load balancing or ssl termination. 37 | 38 | Some people may feel that it is inappropriate to proxy localhost connections this 39 | way, that localhost traffic should always be local. The above principles can 40 | be accomplished by using another well-known ip address. The one advantage of 41 | the localhost approach is almost every application is configured to listen on 42 | localhost out of the box so it makes container builds very easy. 43 | 44 | ## Examples ## 45 | 46 | Wormhole connects services together. The examples use a wordpress 47 | container and a mysql container as the canonical example of two things that 48 | need to be connected. 
49 | 50 | ### Legend for diagrams ### 51 | ![ex-legend](https://cloud.githubusercontent.com/assets/142222/4346902/25bd6e18-411f-11e4-8f0c-b2a4cfa2208f.png) 52 | 53 | ### Proxy to the mysql in a local container ### 54 | ![ex-01](https://cloud.githubusercontent.com/assets/142222/4346904/2a7fb85c-411f-11e4-9637-0e7bbd5fe506.png) 55 | 56 | mysql=`docker run -d wormhole/mysql` 57 | ./wormhole create url :3306 docker-ns tail docker-ns $mysql 58 | mysql -u root -h 127.0.0.1 59 | 60 | This requires a local install of mysql-client (ubuntu: apt-get install mysql-client). 61 | 62 | ### Connect a local wp container to a local mysql container ### 63 | ![ex-02](https://cloud.githubusercontent.com/assets/142222/4346903/2a750024-411f-11e4-9aa4-818bfe05b0e1.png) 64 | 65 | app=`docker run -d wormhole/wordpress` 66 | mysql=`docker run -d wormhole/mysql` 67 | ./wormhole create url :3306 docker-ns $app tail docker-ns $mysql 68 | 69 | ### Create a local port that does the above on connection ### 70 | ![ex-03](https://cloud.githubusercontent.com/assets/142222/4346905/2a8f1446-411f-11e4-97e3-41060c6d2432.png) 71 | 72 | ./wormhole create url :80 trigger docker-run wormhole/wordpress \ 73 | child url :3306 trigger docker-run wormhole/mysql 74 | 75 | ### Create a local port to talk to a remote mysql ### 76 | ![ex-04](https://cloud.githubusercontent.com/assets/142222/4346908/2a96b5f2-411f-11e4-9e36-1921a8a3cbda.png) 77 | 78 | mysql=`docker -H myserver run -d wormhole/mysql` 79 | ./wormhole create url :3306 remote myserver tail url :3306 docker-ns $mysql 80 | 81 | The remote server must be running wormhole with the same key.secret 82 | 83 | ### Do the above over an ipsec tunnel ### 84 | ![ex-05](https://cloud.githubusercontent.com/assets/142222/4346907/2a96b868-411f-11e4-8e86-7bb2f9ff8e15.png) 85 | 86 | mysql=`docker -H myserver run -d wormhole/mysql` 87 | ./wormhole create url :3306 tunnel myserver trigger url :3306 docker-ns $mysql 88 | 89 | ### Create a local port that runs a remote mysql 
on connection ### 90 | ![ex-06](https://cloud.githubusercontent.com/assets/142222/4346909/2a96d406-411f-11e4-9461-3308404704ba.png) 91 | 92 | ./wormhole create url :3306 trigger tunnel myserver trigger url :3306 docker-run wormhole/mysql 93 | 94 | If the image has not been downloaded on 'myserver' then the initial 95 | connection will time out. 96 | 97 | ### Create a local port that runs wp followed by the above ### 98 | ![ex-07](https://cloud.githubusercontent.com/assets/142222/4346906/2a949a4c-411f-11e4-9784-44ba18ca7a1d.png) 99 | 100 | ./wormhole create url :80 trigger docker-run wormhole/wordpress \ 101 | child url :3306 tunnel myserver trigger url :3306 docker-run wormhole/mysql 102 | 103 | ### Forget all this proxy stuff and make an ipsec tunnel ### 104 | ![ex-08](https://cloud.githubusercontent.com/assets/142222/4346910/2a973aa4-411f-11e4-8ff3-b4a7c6e4efce.png) 105 | 106 | 107 | ./wormhole tunnel-create myserver 108 | 109 | This command outputs a local and remote ip for the tunnel. Tunnels are 110 | not deleted when wormholed is closed. To delete the tunnel: 111 | 112 | ./wormhole tunnel-delete myserver 113 | 114 | ## Getting Started ## 115 | 116 | To get started you will need to: 117 | 118 | a) Create a secret key 119 | 120 | sudo mkdir -p /etc/wormhole 121 | cat /dev/urandom | tr -dc '0-9a-zA-Z' | head -c 32 | sudo tee /etc/wormhole/key.secret 122 | sudo chmod 600 /etc/wormhole/key.secret 123 | 124 | b) Run the daemon as root 125 | 126 | sudo ./wormholed 127 | 128 | The wormhole cli communicates with the daemon over port 9999. 
To verify it 129 | is working: 130 | 131 | ./wormhole ping 132 | 133 | ## Local Build and Test ## 134 | 135 | Getting the source code: 136 | 137 | go get github.com/vishvananda/wormhole 138 | 139 | Building the binaries (will install go dependencies via go get): 140 | 141 | make 142 | 143 | Testing dependencies: 144 | 145 | docker is required for functional tests 146 | 147 | Unit Tests (functional tests use sudo): 148 | 149 | make test-unit 150 | 151 | Functional tests (requires root): 152 | 153 | make test-functional # or sudo -E go test -v functional_test.go 154 | 155 | ## Alternative Tools ## 156 | 157 | Most of what wormhole does can be accomplished by hacking together various 158 | tools like socat and iproute2. 159 | 160 | ## Future Work ## 161 | 162 | Wormhole could be extended to support unix socket proxying. It would also be 163 | interesting to allow proxies from one type of socket to another a la socat. 164 | 165 | Wormhole discovers existing tunnels when it starts, but it doesn't attempt 166 | to clean up if it finds a partial tunnel. This option could be added. 167 | 168 | Namespace support should be upstreamed to kubernetes/proxy so we don't have 169 | to maintain a fork. 170 | 171 | Commands for list and tunnel-list should be added. 172 | 173 | Wormhole could grow support for load balancing. 174 | 175 | Traffic analysis and reporting could be added to the proxy layer. 176 | 177 | ## Disclaimer ## 178 | 179 | Wormhole is alpha quality code and, while efforts have been made to keep the 180 | daemon secure, it must run as root and therefore offers a tempting attack 181 | surface. It is not recommended to run this in production until it has been 182 | more thoroughly vetted. 
183 | -------------------------------------------------------------------------------- /cli/cli.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "bytes" 5 | "crypto/rand" 6 | "fmt" 7 | "github.com/vishvananda/wormhole/client" 8 | "github.com/vishvananda/wormhole/utils" 9 | "log" 10 | "os" 11 | "time" 12 | ) 13 | 14 | func ping(args []string, c *client.Client) { 15 | host := "" 16 | if len(args) > 1 { 17 | log.Fatalf("Unknown args for ping: %v", args[1:]) 18 | } 19 | if len(args) == 1 { 20 | var err error 21 | host, err = utils.ValidateAddr(args[0]) 22 | if err != nil { 23 | log.Fatalf("%v", err) 24 | } 25 | } 26 | startTime := time.Now() 27 | 28 | log.Printf("Connection took %v", time.Since(startTime)) 29 | value := make([]byte, 16) 30 | rand.Read(value) 31 | echoTime := time.Now() 32 | result, err := c.Echo(value, host) 33 | if err != nil { 34 | log.Fatalf("client.Echo failed: %v", err) 35 | } 36 | if !bytes.Equal(value, result) { 37 | log.Fatalf("Incorrect response from echo") 38 | } 39 | log.Printf("Reply took %v: %v", time.Since(echoTime), result) 40 | // milliseconds 41 | fmt.Printf("%f\n", float64(time.Since(startTime))/1000000) 42 | } 43 | 44 | func tunnelCreate(args []string, c *client.Client) { 45 | host := "" 46 | udp := false 47 | filtered := make([]string, 0) 48 | for _, arg := range args { 49 | if arg == "--udp" { 50 | udp = true 51 | } else { 52 | filtered = append(filtered, arg) 53 | } 54 | } 55 | args = filtered 56 | if len(args) > 1 { 57 | log.Fatalf("Too many args for tunnel-create: %v", args[1:]) 58 | } else if len(args) == 0 { 59 | log.Fatalf("Argument host is required for tunnel-create") 60 | } 61 | host = args[0] 62 | var err error 63 | host, err = utils.ValidateAddr(args[0]) 64 | if err != nil { 65 | log.Fatalf("%v", err) 66 | } 67 | 68 | src, dst, err := c.CreateTunnel(host, udp) 69 | if err != nil { 70 | log.Fatalf("client.CreateTunnel failed: %v", err) 71 | } 72 | 
fmt.Printf("%v %v\n", src, dst) 73 | } 74 | 75 | func tunnelDelete(args []string, c *client.Client) { 76 | host := "" 77 | if len(args) > 1 { 78 | log.Fatalf("Unknown args for tunnel-delete: %v", args[1:]) 79 | } 80 | if len(args) == 1 { 81 | var err error 82 | host, err = utils.ValidateAddr(args[0]) 83 | if err != nil { 84 | log.Fatalf("%v", err) 85 | } 86 | } else { 87 | log.Fatalf("Argument host is required for tunnel-delete") 88 | } 89 | 90 | err := c.DeleteTunnel(host) 91 | if err != nil { 92 | log.Fatalf("client.DeleteTunnel failed: %v", err) 93 | } 94 | } 95 | 96 | func segmentCreate(args []string, c *client.Client) { 97 | id, init, trig, err := parseSegment(args) 98 | if err != nil { 99 | log.Fatalf("Could not parse create: %v", err) 100 | } 101 | 102 | url, err := c.CreateSegment(id, init, trig) 103 | if err != nil { 104 | log.Fatalf("client.CreateSegment failed: %v", err) 105 | } 106 | fmt.Printf("%v %v\n", id, url) 107 | } 108 | 109 | func segmentDelete(args []string, c *client.Client) { 110 | id := "" 111 | if len(args) > 1 { 112 | log.Fatalf("Unknown args for delete: %v", args[1:]) 113 | } 114 | if len(args) == 1 { 115 | id = args[0] 116 | } else { 117 | log.Fatalf("Argument id is required for delete") 118 | } 119 | 120 | err := c.DeleteSegment(id) 121 | if err != nil { 122 | log.Fatalf("client.DeleteSegment failed: %v", err) 123 | } 124 | } 125 | 126 | func parseSegment(args []string) (string, []client.SegmentCommand, []client.SegmentCommand, error) { 127 | id := utils.Uuid() 128 | s := client.SegmentCommand{} 129 | chain, tail, trigger := false, false, false 130 | command := "" 131 | cur := &s 132 | for len(args) > 0 { 133 | command, args = args[0], args[1:] 134 | var action *client.SegmentCommand 135 | switch command { 136 | case "id": 137 | id = parseName(&args) 138 | continue 139 | case "url": 140 | action = parseUrl(tail, &args) 141 | case "docker-ns": 142 | action = parseDockerNs(tail, &args) 143 | case "docker-run": 144 | action = 
parseDockerRun(tail, &args) 145 | case "child": 146 | action = parseChild() 147 | chain = true 148 | case "chain": 149 | action = parseChain() 150 | chain = true 151 | case "remote": 152 | action = parseRemote(&args) 153 | chain = true 154 | case "tunnel": 155 | action = parseTunnel(&args) 156 | chain = true 157 | case "udptunnel": 158 | action = parseUdptunnel(&args) 159 | chain = true 160 | case "tail": 161 | tail = true 162 | continue 163 | case "trigger": 164 | trigger = true 165 | tail = true 166 | continue 167 | default: 168 | log.Fatalf("Action %s not recognized", command) 169 | } 170 | if trigger { 171 | cur.AddTrig(action) 172 | if chain { 173 | cur = &cur.ChildTrig[len(cur.ChildTrig)-1] 174 | chain = false 175 | trigger = false 176 | tail = false 177 | } 178 | } else { 179 | cur.AddInit(action) 180 | if chain { 181 | cur = &cur.ChildInit[len(cur.ChildInit)-1] 182 | chain = false 183 | trigger = false 184 | tail = false 185 | } 186 | } 187 | } 188 | return id, s.ChildInit, s.ChildTrig, nil 189 | } 190 | 191 | func parseName(args *[]string) string { 192 | if len(*args) == 0 { 193 | createFail("Argument ID is required for id") 194 | } 195 | var id string 196 | id, *args = (*args)[0], (*args)[1:] 197 | return id 198 | } 199 | 200 | func parseUrl(tail bool, args *[]string) *client.SegmentCommand { 201 | if len(*args) == 0 { 202 | createFail("Argument URL is required for url") 203 | } 204 | 205 | url := (*args)[0] 206 | proto, _, _, _, err := utils.ParseUrl(url) 207 | if err != nil { 208 | createFail(fmt.Sprintf("Unable to parse URL: %v", url)) 209 | } 210 | if proto != "" && proto != "tcp" && proto != "udp" { 211 | createFail("Only tcp and udp protocols are currently supported.") 212 | } 213 | *args = (*args)[1:] 214 | return &client.SegmentCommand{Type: client.URL, Tail: tail, Arg: url} 215 | } 216 | 217 | func parseDockerNs(tail bool, args *[]string) *client.SegmentCommand { 218 | if len(*args) == 0 { 219 | createFail("Argument ID is required for docker-ns") 
220 | } 221 | var id string 222 | id, *args = (*args)[0], (*args)[1:] 223 | return &client.SegmentCommand{Type: client.DOCKER_NS, Tail: tail, Arg: id} 224 | } 225 | 226 | func parseDockerRun(tail bool, args *[]string) *client.SegmentCommand { 227 | if len(*args) == 0 { 228 | createFail("Argument ARGS is required for docker-run") 229 | } 230 | var run string 231 | run, *args = (*args)[0], (*args)[1:] 232 | return &client.SegmentCommand{Type: client.DOCKER_RUN, Tail: tail, Arg: run} 233 | } 234 | 235 | func parseChild() *client.SegmentCommand { 236 | return &client.SegmentCommand{Type: client.CHILD} 237 | } 238 | 239 | func parseChain() *client.SegmentCommand { 240 | return &client.SegmentCommand{Type: client.CHAIN} 241 | } 242 | 243 | func parseRemote(args *[]string) *client.SegmentCommand { 244 | if len(*args) == 0 { 245 | createFail("Argument HOST is required for remote") 246 | } 247 | host, err := utils.ValidateAddr((*args)[0]) 248 | if err != nil { 249 | createFail(fmt.Sprintf("Unable to parse HOST: %v", host)) 250 | } 251 | *args = (*args)[1:] 252 | return &client.SegmentCommand{Type: client.REMOTE, Arg: host} 253 | } 254 | 255 | func parseTunnel(args *[]string) *client.SegmentCommand { 256 | if len(*args) == 0 { 257 | createFail("Argument HOST is required for tunnel") 258 | } 259 | host, err := utils.ValidateAddr((*args)[0]) 260 | if err != nil { 261 | createFail(fmt.Sprintf("Unable to parse HOST: %v", host)) 262 | } 263 | *args = (*args)[1:] 264 | return &client.SegmentCommand{Type: client.TUNNEL, Arg: host} 265 | } 266 | 267 | func parseUdptunnel(args *[]string) *client.SegmentCommand { 268 | if len(*args) == 0 { 269 | createFail("Argument HOST is required for udptunnel") 270 | } 271 | host, err := utils.ValidateAddr((*args)[0]) 272 | if err != nil { 273 | createFail(fmt.Sprintf("Unable to parse HOST: %v", host)) 274 | } 275 | *args = (*args)[1:] 276 | return &client.SegmentCommand{Type: client.UDPTUNNEL, Arg: host} 277 | } 278 | 279 | func createFail(msg 
string) { 280 | fmt.Println(msg) 281 | usage("create") 282 | } 283 | 284 | func usage(command string) { 285 | u := "" 286 | if command == "" { 287 | u = `Usage: %s [ OPTIONS ] [ help ] COMMAND { SUBCOMMAND ... } 288 | where COMMAND := { ping | create | delete | tunnel-create | tunnel-delete } 289 | OPTIONS := { -K[eyfile] | -H[ost] }` 290 | } else { 291 | switch command { 292 | case "ping": 293 | u = `Usage: %s ping HOST 294 | Pings wormholed on HOST and prints the latency in ms.` 295 | case "create": 296 | u = `Usage: %s create { SUBCOMMAND ... } 297 | where SUBCOMMAND := { url | name | docker-ns | docker-run | child | 298 | child | chain | remote | tunnel | udptunnel | 299 | tail | trigger } 300 | 301 | Creates a proxy wormhole. The wormhole has a head and a tail. The head 302 | represents where the proxy listens, and the tail represents where the 303 | proxy connects. Both the head and the tail have the following values: 304 | 305 | protocol: the protocol of the connection (currently udp or tcp) 306 | namespace: the network namespace of the connection 307 | host: hostname or ip address of the connection 308 | port: port of the connection 309 | 310 | Prints the id and the listen url of the wormhole. 
311 | 312 | SUBCOMMANDS 313 | =========== 314 | 315 | url URL 316 | set the head data to values specified in URL 317 | URL is in the form {protocol://}{namespace@}{host}{:port} 318 | 319 | id ID 320 | sets the id of the wormhole to ID 321 | 322 | docker-ns ID 323 | set the namespace using docker ID 324 | 325 | docker-run ARGS 326 | docker-run using ARGS and set the namespace to the container's namespace 327 | 328 | child 329 | create a child wormhole using the current proxy values as a base 330 | everything following this command applies to child wormhole 331 | 332 | chain 333 | create a child wormhole using the current proxy values as a base 334 | set the current wormhole's values to the child wormhole 335 | everything following this command applies to child wormhole 336 | 337 | remote HOST 338 | create a child wormhole on HOST 339 | set the current wormhole's tail values to the child wormhole 340 | 341 | tunnel HOST 342 | create an ipsec tunnel to HOST 343 | create a child wormhole on HOST 344 | set the current wormhole's tail values to the child wormhole 345 | 346 | udptunnel HOST 347 | create an ipsec tunnel to HOST using espinudp encapsulation 348 | create a child wormhole on HOST 349 | set the current wormhole's tail values to the child wormhole 350 | 351 | tail 352 | all following commands modify the tail instead of the head 353 | 354 | trigger 355 | all following commands modify the tail instead of the head 356 | all following commands are executed when something connects to the head 357 | 358 | ` 359 | case "delete": 360 | u = `Usage: %s delete ID 361 | Deletes the proxy wormhole ID.` 362 | case "tunnel-create": 363 | u = `Usage: %s tunnel-create [--udp] HOST 364 | Creates an ipsec tunnel to HOST and prints out the source and destination 365 | tunnel ip addresses. If --udp is specified the tunnel will use espinudp 366 | encapsulation. 
Wormholed must be running on HOST with the same key as the 367 | local wormholed.` 368 | case "tunnel-delete": 369 | u = `Usage: %s tunnel-delete HOST 370 | Deletes an ipsec tunnel to HOST.` 371 | default: 372 | log.Printf("Unknown command: %v", command) 373 | } 374 | } 375 | fmt.Printf(u, os.Args[0]) 376 | fmt.Println() 377 | } 378 | 379 | func Main() { 380 | args := parseFlags() 381 | if len(args) == 0 { 382 | usage("") 383 | return 384 | } 385 | 386 | command, args := args[0], args[1:] 387 | if command == "help" { 388 | if len(args) != 0 { 389 | usage(args[0]) 390 | } else { 391 | usage("") 392 | } 393 | return 394 | } 395 | 396 | c, err := client.NewClient(opts.host, opts.config) 397 | if err != nil { 398 | log.Fatalf("Failed to create client: %v", err) 399 | } 400 | defer c.Close() 401 | switch command { 402 | case "ping": 403 | ping(args, c) 404 | case "create": 405 | segmentCreate(args, c) 406 | case "delete": 407 | segmentDelete(args, c) 408 | case "tunnel-create": 409 | tunnelCreate(args, c) 410 | case "tunnel-delete": 411 | tunnelDelete(args, c) 412 | default: 413 | log.Printf("Unknown command: %v", command) 414 | usage("") 415 | } 416 | } 417 | -------------------------------------------------------------------------------- /cli/cli_test.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "github.com/vishvananda/wormhole/client" 5 | "testing" 6 | ) 7 | 8 | func validateBasicParse(t *testing.T, args []string, commandType int) { 9 | _, init, _, err := parseSegment(args) 10 | if err != nil { 11 | t.Fatal(err) 12 | } 13 | if len(init) != 1 { 14 | t.Fatalf("Parse Failed to create action: %v", args) 15 | } 16 | if init[0].Type != commandType { 17 | t.Fatalf("Types don't match, %v: %s != %s", args, client.CommandName[init[0].Type], client.CommandName[commandType]) 18 | } 19 | } 20 | func TestSegmentParseBasic(t *testing.T) { 21 | validateBasicParse(t, []string{"url", ":40"}, client.URL) 22 | 
validateBasicParse(t, []string{"docker-ns", "foo"}, client.DOCKER_NS) 23 | validateBasicParse(t, []string{"docker-run", "foo"}, client.DOCKER_RUN) 24 | validateBasicParse(t, []string{"child"}, client.CHILD) 25 | validateBasicParse(t, []string{"chain"}, client.CHAIN) 26 | validateBasicParse(t, []string{"remote", "foo"}, client.REMOTE) 27 | validateBasicParse(t, []string{"tunnel", "foo"}, client.TUNNEL) 28 | validateBasicParse(t, []string{"udptunnel", "foo"}, client.UDPTUNNEL) 29 | } 30 | 31 | func TestParseId(t *testing.T) { 32 | id, _, _, err := parseSegment([]string{"id", "foo"}) 33 | if err != nil { 34 | t.Fatal(err) 35 | } 36 | if id != "foo" { 37 | t.Fatalf("Id Parse Failed") 38 | } 39 | } 40 | 41 | func TestSegmentParseComplex(t *testing.T) { 42 | args := []string{"id", "foo", "url", ":40", "docker-run", "bar"} 43 | id, init, _, err := parseSegment(args) 44 | if err != nil { 45 | t.Fatal(err) 46 | } 47 | if id != "foo" { 48 | t.Fatalf("Name Parse Failed") 49 | } 50 | if len(init) != 2 { 51 | t.Fatalf("Wrong number of actions, %v: %v", args, len(init)) 52 | } 53 | if init[0].Type != client.URL { 54 | t.Fatalf("Types don't match, %v: %s != %s", args, client.CommandName[init[0].Type], client.CommandName[client.URL]) 55 | } 56 | if init[1].Type != client.DOCKER_RUN { 57 | t.Fatalf("Types don't match, %v: %s != %s", args, client.CommandName[init[1].Type], client.CommandName[client.DOCKER_RUN]) 58 | } 59 | } 60 | 61 | func TestSegmentParseRemote(t *testing.T) { 62 | args := []string{"id", "foo", "url", ":40", "remote", "bar", "docker-run", "baz"} 63 | id, init, _, err := parseSegment(args) 64 | if err != nil { 65 | t.Fatal(err) 66 | } 67 | if id != "foo" { 68 | t.Fatalf("Name Parse Failed") 69 | } 70 | if len(init) != 2 { 71 | t.Fatalf("Wrong number of actions, %v: %v", args, len(init)) 72 | } 73 | if init[0].Type != client.URL { 74 | t.Fatalf("Types don't match, %v: %s != %s", args, client.CommandName[init[0].Type], client.CommandName[client.URL]) 75 | } 76 | if 
init[1].Type != client.REMOTE { 77 | t.Fatalf("Types don't match, %v: %s != %s", args, client.CommandName[init[1].Type], client.CommandName[client.REMOTE]) 78 | } 79 | if len(init[1].ChildInit) != 1 { 80 | t.Fatalf("Wrong number of child actions, %v: %v", args, len(init[1].ChildInit)) 81 | } 82 | if init[1].ChildInit[0].Type != client.DOCKER_RUN { 83 | t.Fatalf("Types don't match, %v: %s != %s", args, client.CommandName[init[1].ChildInit[0].Type], client.CommandName[client.DOCKER_RUN]) 84 | } 85 | } 86 | 87 | func TestSegmentParseTrigger(t *testing.T) { 88 | args := []string{"id", "foo", "url", ":40", "trigger", "docker-run", "baz"} 89 | id, init, trig, err := parseSegment(args) 90 | if err != nil { 91 | t.Fatal(err) 92 | } 93 | if id != "foo" { 94 | t.Fatalf("Name Parse Failed") 95 | } 96 | if len(init) != 1 { 97 | t.Fatalf("Wrong number of actions, %v: %v", args, len(init)) 98 | } 99 | if init[0].Type != client.URL { 100 | t.Fatalf("Types don't match, %v: %s != %s", args, client.CommandName[init[0].Type], client.CommandName[client.URL]) 101 | } 102 | if len(trig) != 1 { 103 | t.Fatalf("Wrong number of actions, %v: %v", args, len(trig)) 104 | } 105 | if trig[0].Type != client.DOCKER_RUN { 106 | t.Fatalf("Types don't match, %v: %s != %s", args, client.CommandName[trig[0].Type], client.CommandName[client.DOCKER_RUN]) 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /cli/opts.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "flag" 5 | "io/ioutil" 6 | "log" 7 | "os" 8 | 9 | "github.com/raff/tls-ext" 10 | "github.com/raff/tls-psk" 11 | "github.com/vishvananda/wormhole/utils" 12 | ) 13 | 14 | type options struct { 15 | host string 16 | config *tls.Config 17 | } 18 | 19 | var opts *options 20 | 21 | func parseFlags() []string { 22 | keyfile := flag.String("K", "/etc/wormhole/key.secret", "Keyfile for psk auth (if not found defaults to insecure 
// Command type identifiers for SegmentCommand.Type. The numeric values
// form the index space of CommandName below, so the two declarations
// must stay in sync.
const (
	NONE = iota
	URL
	DOCKER_NS
	DOCKER_RUN
	CHILD
	CHAIN
	REMOTE
	TUNNEL
	UDPTUNNEL
)

// CommandName maps a command type constant to its CLI spelling.
//
// BUG FIX: the UDPTUNNEL entry was missing. Because this is a sparse
// indexed slice literal, the slice length was only TUNNEL+1, so any
// lookup CommandName[UDPTUNNEL] (e.g. in cli_test's failure messages)
// panicked with an index-out-of-range error.
var CommandName = []string{
	NONE:       "none",
	URL:        "url",
	DOCKER_NS:  "docker-ns",
	DOCKER_RUN: "docker-run",
	CHILD:      "child",
	CHAIN:      "chain",
	REMOTE:     "remote",
	TUNNEL:     "tunnel",
	UDPTUNNEL:  "udptunnel",
}

// SegmentCommand describes one action used to build a wormhole segment.
type SegmentCommand struct {
	Type      int              // one of the command constants above
	Tail      bool             // true when the command applies to the tail rather than the head
	Arg       string           // command argument (url, host, docker id/args, ...)
	ChildInit []SegmentCommand // commands forwarded to a child segment at creation (see cli parseSegment)
	ChildTrig []SegmentCommand // commands forwarded to a child segment's trigger list
}
41 | 42 | type Tunnel struct { 43 | Reqid int 44 | AuthKey []byte 45 | EncKey []byte 46 | Src net.IP 47 | Dst net.IP 48 | SrcPort int 49 | DstPort int 50 | } 51 | 52 | func (t Tunnel) Equal(o *Tunnel) bool { 53 | return t.Reqid == o.Reqid && bytes.Equal(t.AuthKey, o.AuthKey) && bytes.Equal(t.EncKey, o.EncKey) && t.Src.Equal(o.Src) && t.Dst.Equal(o.Dst) && t.SrcPort == o.SrcPort && t.DstPort == o.DstPort 54 | } 55 | 56 | func (s *SegmentCommand) AddInit(c *SegmentCommand) { 57 | s.ChildInit = append(s.ChildInit, *c) 58 | } 59 | 60 | func (s *SegmentCommand) AddTrig(c *SegmentCommand) { 61 | s.ChildTrig = append(s.ChildTrig, *c) 62 | } 63 | 64 | type Client struct { 65 | RpcClient *rpc.Client 66 | } 67 | 68 | func NewClient(host string, config *tls.Config) (*Client, error) { 69 | proto, address := utils.ParseAddr(host) 70 | conn, err := tls.Dial(proto, address, config) 71 | if err != nil { 72 | return nil, err 73 | } 74 | return &Client{rpc.NewClient(conn)}, nil 75 | } 76 | 77 | func (c *Client) Close() error { 78 | return c.RpcClient.Close() 79 | } 80 | 81 | type EchoArgs struct { 82 | Value []byte 83 | Host string 84 | } 85 | 86 | type EchoReply struct { 87 | Value []byte 88 | } 89 | 90 | func (c *Client) Echo(value []byte, host string) ([]byte, error) { 91 | reply := EchoReply{} 92 | args := EchoArgs{value, host} 93 | err := c.RpcClient.Call("Api.Echo", args, &reply) 94 | return reply.Value, err 95 | } 96 | 97 | type CreateTunnelArgs struct { 98 | Host string 99 | Udp bool 100 | } 101 | 102 | type CreateTunnelReply struct { 103 | Src net.IP 104 | Dst net.IP 105 | } 106 | 107 | func (c *Client) CreateTunnel(host string, udp bool) (net.IP, net.IP, error) { 108 | reply := CreateTunnelReply{} 109 | args := CreateTunnelArgs{host, udp} 110 | err := c.RpcClient.Call("Api.CreateTunnel", args, &reply) 111 | return reply.Src, reply.Dst, err 112 | } 113 | 114 | type DeleteTunnelArgs struct { 115 | Host string 116 | } 117 | 118 | type DeleteTunnelReply struct { 119 | } 120 | 
121 | func (c *Client) DeleteTunnel(host string) error { 122 | reply := DeleteTunnelReply{} 123 | args := DeleteTunnelArgs{host} 124 | err := c.RpcClient.Call("Api.DeleteTunnel", args, &reply) 125 | return err 126 | } 127 | 128 | type CreateSegmentArgs struct { 129 | Id string 130 | Init []SegmentCommand 131 | Trig []SegmentCommand 132 | } 133 | 134 | type CreateSegmentReply struct { 135 | Url string 136 | } 137 | 138 | func (c *Client) CreateSegment(id string, init []SegmentCommand, trig []SegmentCommand) (string, error) { 139 | reply := CreateSegmentReply{} 140 | args := CreateSegmentArgs{id, init, trig} 141 | err := c.RpcClient.Call("Api.CreateSegment", args, &reply) 142 | return reply.Url, err 143 | } 144 | 145 | type DeleteSegmentArgs struct { 146 | Id string 147 | } 148 | 149 | type DeleteSegmentReply struct { 150 | } 151 | 152 | func (c *Client) DeleteSegment(id string) error { 153 | reply := DeleteSegmentReply{} 154 | args := DeleteSegmentArgs{id} 155 | err := c.RpcClient.Call("Api.DeleteSegment", args, &reply) 156 | return err 157 | } 158 | 159 | type GetSrcIPArgs struct { 160 | Dst net.IP 161 | } 162 | 163 | type GetSrcIPReply struct { 164 | Src net.IP 165 | } 166 | 167 | func (c *Client) GetSrcIP(dst net.IP) (net.IP, error) { 168 | reply := GetSrcIPReply{} 169 | args := GetSrcIPArgs{dst} 170 | err := c.RpcClient.Call("Api.GetSrcIP", args, &reply) 171 | return reply.Src, err 172 | } 173 | 174 | type BuildTunnelArgs struct { 175 | Dst net.IP 176 | Tunnel *Tunnel 177 | } 178 | 179 | type BuildTunnelReply struct { 180 | Src net.IP 181 | Tunnel *Tunnel 182 | } 183 | 184 | func (c *Client) BuildTunnel(dst net.IP, tunnel *Tunnel) (net.IP, *Tunnel, error) { 185 | reply := BuildTunnelReply{} 186 | args := BuildTunnelArgs{dst, tunnel} 187 | err := c.RpcClient.Call("Api.BuildTunnel", args, &reply) 188 | return reply.Src, reply.Tunnel, err 189 | } 190 | 191 | type DestroyTunnelArgs struct { 192 | Dst net.IP 193 | } 194 | 195 | // DestroyTunnel has no reply value 196 
| type DestroyTunnelReply struct { 197 | Src net.IP 198 | Error error 199 | } 200 | 201 | func (c *Client) DestroyTunnel(dst net.IP) (net.IP, error) { 202 | reply := DestroyTunnelReply{} 203 | args := DestroyTunnelArgs{dst} 204 | err := c.RpcClient.Call("Api.DestroyTunnel", args, &reply) 205 | return reply.Src, err 206 | } 207 | -------------------------------------------------------------------------------- /functional_test.go: -------------------------------------------------------------------------------- 1 | package functional 2 | 3 | import ( 4 | "bytes" 5 | "net" 6 | "os/exec" 7 | "runtime" 8 | "strconv" 9 | "strings" 10 | "sync" 11 | "syscall" 12 | "testing" 13 | "time" 14 | 15 | "github.com/vishvananda/netlink" 16 | "github.com/vishvananda/netns" 17 | ) 18 | 19 | type Server struct { 20 | Cmd *exec.Cmd 21 | CommandLine string 22 | Outb *bytes.Buffer 23 | Errb *bytes.Buffer 24 | } 25 | 26 | type Context struct { 27 | t *testing.T 28 | Failed bool 29 | Servers []Server 30 | Ns *netns.NsHandle 31 | } 32 | 33 | func addrAdd(t *testing.T, l netlink.Link, a string) { 34 | addr, err := netlink.ParseAddr(a) 35 | if err != nil { 36 | t.Fatal(err) 37 | } 38 | 39 | err = netlink.AddrAdd(l, addr) 40 | if err != nil && err != syscall.EEXIST { 41 | t.Fatal(err) 42 | } 43 | } 44 | 45 | func ensureNetwork(t *testing.T) *netns.NsHandle { 46 | addrs, err := netlink.AddrList(nil, netlink.FAMILY_V4) 47 | if err != nil { 48 | t.Fatal("Failed to list addresses", err) 49 | } 50 | for _, a := range addrs { 51 | if a.Label == "lo-wormhole" { 52 | // NOTE(vish): We are already namespaced so just continue. This 53 | // means we can leak data between tests if wormhole 54 | // doesn't cleanup afeter itself, but it makes the 55 | // tests run 5x faster. 
56 | return nil 57 | } 58 | } 59 | ns, err := netns.New() 60 | if err != nil { 61 | t.Fatal("Failed to create newns", ns) 62 | } 63 | link, err := netlink.LinkByName("lo") 64 | if err != nil { 65 | t.Fatal(err) 66 | } 67 | 68 | err = netlink.LinkSetUp(link) 69 | if err != nil { 70 | t.Fatal(err) 71 | } 72 | 73 | addrAdd(t, link, "127.0.0.1/32 lo-wormhole") 74 | addrAdd(t, link, "127.0.0.2/32") 75 | 76 | _, dst, _ := net.ParseCIDR("0.0.0.0/0") 77 | err = netlink.RouteAdd(&netlink.Route{LinkIndex: link.Attrs().Index, Dst: dst}) 78 | if err != nil { 79 | t.Fatal(err) 80 | } 81 | return &ns 82 | } 83 | func getContext(t *testing.T) *Context { 84 | runtime.LockOSThread() 85 | ns := ensureNetwork(t) 86 | c := &Context{t: t, Ns: ns} 87 | return c 88 | } 89 | 90 | func conditionalClose(ns *netns.NsHandle) { 91 | if ns != nil && ns.IsOpen() { 92 | ns.Close() 93 | } 94 | } 95 | 96 | func (c *Context) cleanup() { 97 | defer runtime.UnlockOSThread() 98 | defer conditionalClose(c.Ns) 99 | var wg sync.WaitGroup 100 | for _, s := range c.Servers { 101 | err := s.Cmd.Process.Signal(syscall.SIGTERM) 102 | if err != nil { 103 | c.Logf("Failed to terminate %s: %v", s.CommandLine, err) 104 | } 105 | done := make(chan error, 1) 106 | go func(s Server) { 107 | done <- s.Cmd.Wait() 108 | }(s) 109 | wg.Add(1) 110 | go c.waitExit(s.Cmd, s.CommandLine, done, &wg) 111 | } 112 | wg.Wait() 113 | if c.Failed { 114 | for _, s := range c.Servers { 115 | c.Logf("DUMPING PROCESS '%s'", s.CommandLine) 116 | c.Logf("STDOUT") 117 | c.Logf(string(s.Outb.Bytes())) 118 | c.Logf("STDERR") 119 | c.Logf(string(s.Errb.Bytes())) 120 | } 121 | } 122 | } 123 | 124 | func (c *Context) waitExit(cmd *exec.Cmd, commandLine string, done chan error, wg *sync.WaitGroup) { 125 | if wg != nil { 126 | defer wg.Done() 127 | } 128 | select { 129 | case <-time.After(1 * time.Second): 130 | c.Logf("Killing process %s because it did not exit in time", commandLine) 131 | err := cmd.Process.Kill() 132 | if err != nil { 133 | 
c.Logf("Failed to kill %s: %v", commandLine, err) 134 | } 135 | <-done // allow goroutine to exit 136 | case err := <-done: 137 | if exiterr, ok := err.(*exec.ExitError); ok { 138 | if status, ok := exiterr.Sys().(syscall.WaitStatus); ok { 139 | if status.Signaled() { 140 | signal := status.Signal() 141 | if signal != syscall.SIGTERM { 142 | c.Failed = true 143 | c.Logf("Process %s exited with abnormal signal %d", commandLine, status.Signal()) 144 | } 145 | } else { 146 | exitStatus := status.ExitStatus() 147 | if exitStatus != 0 { 148 | c.Failed = true 149 | c.Logf("Process %s exited with abnormal status %d", commandLine, exitStatus) 150 | } 151 | } 152 | } 153 | } else if err != nil { 154 | c.Failed = true 155 | c.Logf("Process %s exited with error %v", commandLine, err) 156 | } 157 | } 158 | } 159 | 160 | func (c *Context) Fatalf(format string, arg ...interface{}) { 161 | c.Failed = true 162 | c.t.Fatalf(format, arg...) 163 | } 164 | 165 | func (c *Context) Logf(format string, arg ...interface{}) { 166 | c.t.Logf(format, arg...) 167 | } 168 | 169 | func (c *Context) start(name string, arg ...string) { 170 | s := Server{} 171 | s.Cmd = exec.Command(name, arg...) 172 | s.CommandLine = strings.Join(append([]string{name}, arg...), " ") 173 | 174 | s.Outb = new(bytes.Buffer) 175 | s.Errb = new(bytes.Buffer) 176 | s.Cmd.Stdout = s.Outb 177 | s.Cmd.Stderr = s.Errb 178 | 179 | err := s.Cmd.Start() 180 | if err != nil { 181 | c.Fatalf("Error starting %s: %v", s.CommandLine, err) 182 | } 183 | c.Servers = append(c.Servers, s) 184 | } 185 | 186 | func (c *Context) execute(name string, arg ...string) (string, string) { 187 | cmd := exec.Command(name, arg...) 
188 | commandLine := strings.Join(append([]string{name}, arg...), " ") 189 | 190 | var err error 191 | var outb, errb bytes.Buffer 192 | cmd.Stdout = &outb 193 | cmd.Stderr = &errb 194 | 195 | err = cmd.Start() 196 | if err != nil { 197 | c.Fatalf("Error starting %s: %v", commandLine, err) 198 | } 199 | done := make(chan error, 1) 200 | go func() { 201 | done <- cmd.Wait() 202 | }() 203 | c.waitExit(cmd, commandLine, done, nil) 204 | return string(outb.Bytes()), string(errb.Bytes()) 205 | } 206 | 207 | func (c *Context) listening(address string) bool { 208 | network := "tcp" 209 | parts := strings.SplitN(address, "://", 2) 210 | if len(parts) == 2 { 211 | network = parts[0] 212 | address = parts[1] 213 | } 214 | if network != "unix" && !strings.Contains(address, ":") { 215 | address += ":9999" 216 | } 217 | conn, err := net.Dial(network, address) 218 | if err != nil { 219 | return false 220 | } 221 | conn.Close() 222 | return true 223 | } 224 | 225 | func (c *Context) wait(address string) { 226 | for i := 0; i < 100; i++ { 227 | if c.listening(address) { 228 | return 229 | } 230 | time.Sleep(10 * time.Millisecond) 231 | } 232 | c.Fatalf("Nothing started listening on '%s'", address) 233 | } 234 | 235 | func (c *Context) sendTimeout(msg string, address string) string { 236 | conn, err := net.Dial("tcp", address) 237 | if err != nil { 238 | c.Fatalf("Dial to %s failed: %v", address, err) 239 | } 240 | defer conn.Close() 241 | 242 | _, err = conn.Write([]byte(msg)) 243 | if err != nil { 244 | c.Fatalf("Write to server at %s failed: %v", address, err) 245 | } 246 | 247 | reply := make([]byte, 1024) 248 | 249 | n, err := conn.Read(reply) 250 | if err != nil { 251 | c.Fatalf("Read from server at %s failed: %v", address, err) 252 | } 253 | return string(reply[:n]) 254 | } 255 | 256 | func (c *Context) validatePolicy(policy netlink.XfrmPolicy, src string, dst string) { 257 | if policy.Tmpls[0].Dst.String() != dst || policy.Tmpls[0].Src.String() != src { 258 | 
c.Fatalf("Policy src and dst don't match: %v != %v, %v", policy, dst, src) 259 | } 260 | } 261 | 262 | func (c *Context) validateState(state netlink.XfrmState, src string, dst string, udp bool) { 263 | if state.Dst.String() != dst || state.Src.String() != src { 264 | c.Fatalf("State src and dst don't match: %v != %v, %v", state, dst, src) 265 | } 266 | if udp != (state.Encap != nil) { 267 | c.Fatalf("Udp and encap don't match: %v != %v", udp, state.Encap != nil) 268 | } 269 | } 270 | 271 | func (c *Context) validateTunnel(udp bool) { 272 | policies, err := netlink.XfrmPolicyList(netlink.FAMILY_ALL) 273 | if err != nil { 274 | c.Fatalf("Failed to get policies: %v", err) 275 | } 276 | if len(policies) != 4 { 277 | c.Fatalf("Wrong number of policies found: %v", policies) 278 | } 279 | c.validatePolicy(policies[0], "127.0.0.2", "127.0.0.1") 280 | c.validatePolicy(policies[1], "127.0.0.1", "127.0.0.2") 281 | c.validatePolicy(policies[2], "127.0.0.1", "127.0.0.2") 282 | c.validatePolicy(policies[3], "127.0.0.2", "127.0.0.1") 283 | states, err := netlink.XfrmStateList(netlink.FAMILY_ALL) 284 | if err != nil { 285 | c.Fatalf("Failed to get states: %v", err) 286 | } 287 | if len(states) != 2 { 288 | c.Fatalf("Wrong number of states found: %v", states) 289 | } 290 | c.validateState(states[0], "127.0.0.1", "127.0.0.2", udp) 291 | c.validateState(states[1], "127.0.0.2", "127.0.0.1", udp) 292 | } 293 | 294 | func (c *Context) validateNoTunnel() { 295 | policies, err := netlink.XfrmPolicyList(netlink.FAMILY_ALL) 296 | if err != nil { 297 | c.Fatalf("Failed to get policies: %v", err) 298 | } 299 | if len(policies) != 0 { 300 | c.Fatalf("Policies not removed") 301 | } 302 | states, err := netlink.XfrmStateList(netlink.FAMILY_ALL) 303 | if err != nil { 304 | c.Fatalf("Failed to get states: %v", err) 305 | } 306 | if len(states) != 0 { 307 | c.Fatalf("States not removed") 308 | } 309 | } 310 | 311 | const ( 312 | SERVER = "./wormholed" 313 | CLIENT = "./wormhole" 314 | PONG = 
"./pong/pong" 315 | ) 316 | 317 | func TestServerStartTerminate(t *testing.T) { 318 | c := getContext(t) 319 | defer c.cleanup() 320 | c.start(SERVER) 321 | c.wait("") 322 | } 323 | 324 | func TestServerStartTerminateOtherPort(t *testing.T) { 325 | c := getContext(t) 326 | defer c.cleanup() 327 | host := ":6666" 328 | c.start(SERVER, "-H", host) 329 | c.wait(host) 330 | } 331 | 332 | func TestPing(t *testing.T) { 333 | c := getContext(t) 334 | defer c.cleanup() 335 | c.start(SERVER) 336 | c.wait("") 337 | stdout, _ := c.execute(CLIENT, "ping") 338 | pingms, err := strconv.ParseFloat(strings.TrimSpace(stdout), 64) 339 | if err != nil { 340 | c.Fatalf("Failed to convert result of ping to float: %v", stdout) 341 | } 342 | if pingms > 10.0 { 343 | c.Fatalf("Ping took too long: %v", pingms) 344 | } 345 | } 346 | 347 | func TestPingUnix(t *testing.T) { 348 | c := getContext(t) 349 | defer c.cleanup() 350 | host := "unix://./socket" 351 | c.start(SERVER, "-H", host) 352 | c.wait(host) 353 | stdout, _ := c.execute(CLIENT, "-H", host, "ping") 354 | pingms, err := strconv.ParseFloat(strings.TrimSpace(stdout), 64) 355 | if err != nil { 356 | c.Fatalf("Failed to convert result of ping to float: %v", stdout) 357 | } 358 | if pingms > 10.0 { 359 | c.Fatalf("Ping took too long: %v", pingms) 360 | } 361 | } 362 | 363 | func TestPingRemote(t *testing.T) { 364 | c := getContext(t) 365 | defer c.cleanup() 366 | c.start(SERVER) 367 | c.wait("") 368 | host := ":6666" 369 | c.start(SERVER, "-H", host) 370 | c.wait(host) 371 | stdout, _ := c.execute(CLIENT, "ping", host) 372 | pingms, err := strconv.ParseFloat(strings.TrimSpace(stdout), 64) 373 | if err != nil { 374 | c.Fatalf("Failed to convert result of ping to float: %v", stdout) 375 | } 376 | if pingms > 10.0 { 377 | c.Fatalf("Ping took too long: %v", pingms) 378 | } 379 | } 380 | 381 | func TestTunnel(t *testing.T) { 382 | c := getContext(t) 383 | defer c.cleanup() 384 | c.start(SERVER, "-I", "127.0.0.1") 385 | c.wait("") 386 | host 
:= ":6666" 387 | c.start(SERVER, "-H", host, "-I", "127.0.0.2") 388 | c.wait(host) 389 | c.execute(CLIENT, "tunnel-create", ":6666") 390 | c.validateTunnel(false) 391 | c.execute(CLIENT, "tunnel-delete", ":6666") 392 | c.validateNoTunnel() 393 | } 394 | 395 | func TestTunnelDoubleCreate(t *testing.T) { 396 | c := getContext(t) 397 | defer c.cleanup() 398 | c.start(SERVER, "-I", "127.0.0.1") 399 | c.wait("") 400 | host := ":6666" 401 | c.start(SERVER, "-H", host, "-I", "127.0.0.2") 402 | c.wait(host) 403 | stdout1, _ := c.execute(CLIENT, "tunnel-create", ":6666") 404 | stdout2, _ := c.execute(CLIENT, "tunnel-create", ":6666") 405 | if stdout1 != stdout2 { 406 | c.Fatalf("Second tunnel create retuned new result: %s != %s", stdout1, stdout2) 407 | } 408 | c.validateTunnel(false) 409 | c.execute(CLIENT, "tunnel-delete", ":6666") 410 | c.validateNoTunnel() 411 | } 412 | 413 | func TestTunnelUdp(t *testing.T) { 414 | c := getContext(t) 415 | defer c.cleanup() 416 | c.start(SERVER, "-I", "127.0.0.1") 417 | c.wait("") 418 | host := ":6666" 419 | c.start(SERVER, "-H", host, "-I", "127.0.0.2", "-P", "4501") 420 | c.wait(host) 421 | c.execute(CLIENT, "tunnel-create", "--udp", ":6666") 422 | c.validateTunnel(true) 423 | c.execute(CLIENT, "tunnel-delete", ":6666") 424 | c.validateNoTunnel() 425 | } 426 | 427 | func TestCreateDelete(t *testing.T) { 428 | c := getContext(t) 429 | defer c.cleanup() 430 | c.start(SERVER) 431 | c.wait("") 432 | stdout, _ := c.execute(CLIENT, "create", "url", ":9000", "tail", "url", ":9001") 433 | parts := strings.Fields(strings.TrimSpace(stdout)) 434 | if len(parts) != 2 { 435 | c.Fatalf("Bad data returned from create: %s", stdout) 436 | } 437 | id := parts[0] 438 | if !c.listening(":9000") { 439 | c.Fatalf("Segment is not listening") 440 | } 441 | c.execute(CLIENT, "delete", id) 442 | if c.listening(":9000") { 443 | c.Fatalf("Segment is still listening") 444 | } 445 | } 446 | 447 | func TestDockerRun(t *testing.T) { 448 | c := getContext(t) 449 | 
defer c.cleanup() 450 | c.start(SERVER) 451 | c.wait("") 452 | c.execute(CLIENT, "create", "url", ":9000", "tail", "url", ":9001", "docker-run", "wormhole/pong") 453 | msg := "ping" 454 | result := c.sendTimeout(msg, ":9000") 455 | if result != msg { 456 | c.Fatalf("Incorrect response from ping: %s != %s", result, msg) 457 | } 458 | } 459 | 460 | func TestChain(t *testing.T) { 461 | c := getContext(t) 462 | defer c.cleanup() 463 | c.start(SERVER) 464 | c.wait("") 465 | c.execute(CLIENT, "create", "url", ":9000", "chain", "url", ":9001", "tail", "url", ":9002") 466 | c.start(PONG, ":9002") 467 | msg := "ping" 468 | result := c.sendTimeout(msg, ":9000") 469 | if result != msg { 470 | c.Fatalf("Incorrect response from ping: %s != %s", result, msg) 471 | } 472 | } 473 | 474 | func TestRemote(t *testing.T) { 475 | c := getContext(t) 476 | defer c.cleanup() 477 | c.start(SERVER, "-I", "127.0.0.1") 478 | c.wait("") 479 | host := ":6666" 480 | c.start(SERVER, "-H", host, "-I", "127.0.0.2") 481 | c.wait(host) 482 | c.execute(CLIENT, "create", "url", ":9000", "remote", ":6666", "url", ":9001", "tail", "url", ":9002") 483 | c.start(PONG, ":9002") 484 | msg := "ping" 485 | result := c.sendTimeout(msg, ":9000") 486 | if result != msg { 487 | c.Fatalf("Incorrect response from ping: %s != %s", result, msg) 488 | } 489 | } 490 | 491 | func TestRemoteTunnel(t *testing.T) { 492 | c := getContext(t) 493 | defer c.cleanup() 494 | c.start(SERVER, "-I", "127.0.0.1") 495 | c.wait("") 496 | host := ":6666" 497 | c.start(SERVER, "-H", host, "-I", "127.0.0.2") 498 | c.wait(host) 499 | c.execute(CLIENT, "create", "url", ":9000", "tunnel", ":6666", "url", ":9001", "tail", "url", ":9002") 500 | c.start(PONG, ":9002") 501 | msg := "ping" 502 | result := c.sendTimeout(msg, ":9000") 503 | if result != msg { 504 | c.Fatalf("Incorrect response from ping: %s != %s", result, msg) 505 | } 506 | c.validateTunnel(false) 507 | c.execute(CLIENT, "tunnel-delete", ":6666") 508 | c.validateNoTunnel() 509 | } 
510 | 511 | func TestRemoteUdptunnel(t *testing.T) { 512 | c := getContext(t) 513 | defer c.cleanup() 514 | c.start(SERVER, "-I", "127.0.0.1") 515 | c.wait("") 516 | host := ":6666" 517 | c.start(SERVER, "-H", host, "-I", "127.0.0.2", "-P", "4501") 518 | c.wait(host) 519 | c.execute(CLIENT, "create", "url", ":9000", "udptunnel", ":6666", "url", ":9001", "tail", "url", ":9002") 520 | c.start(PONG, ":9002") 521 | msg := "ping" 522 | result := c.sendTimeout(msg, ":9000") 523 | if result != msg { 524 | c.Fatalf("Incorrect response from ping: %s != %s", result, msg) 525 | } 526 | c.validateTunnel(true) 527 | c.execute(CLIENT, "tunnel-delete", ":6666") 528 | c.validateNoTunnel() 529 | } 530 | -------------------------------------------------------------------------------- /main/wormhole/wormhole.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/vishvananda/wormhole/cli" 5 | ) 6 | 7 | func main() { 8 | cli.Main() 9 | } 10 | -------------------------------------------------------------------------------- /main/wormholed/wormholed.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/vishvananda/wormhole/server" 5 | ) 6 | 7 | func main() { 8 | server.Main() 9 | } 10 | -------------------------------------------------------------------------------- /mysql/Dockerfile: -------------------------------------------------------------------------------- 1 | ############################################################ 2 | # Dockerfile for mysql 3 | # Based on tutum/mysql 4 | ############################################################ 5 | 6 | FROM tutum/mysql 7 | 8 | MAINTAINER Vishvananda Ishaya 9 | 10 | ENV MYSQL_PASS simple 11 | -------------------------------------------------------------------------------- /mysql/Makefile: -------------------------------------------------------------------------------- 1 | all: 
docker-mysql 2 | 3 | .PHONY: docker-mysql 4 | docker-mysql: 5 | docker build -t wormhole/mysql . 6 | -------------------------------------------------------------------------------- /pkg/netaddr/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2013, Michal Derkacz 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 1. Redistributions of source code must retain the above copyright 8 | notice, this list of conditions and the following disclaimer. 9 | 2. Redistributions in binary form must reproduce the above copyright 10 | notice, this list of conditions and the following disclaimer in the 11 | documentation and/or other materials provided with the distribution. 12 | 3. The name of the author may not be used to endorse or promote products 13 | derived from this software without specific prior written permission. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR 16 | IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 17 | OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 18 | IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, 19 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT 20 | NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 21 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 22 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF 24 | THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 | -------------------------------------------------------------------------------- /pkg/netaddr/ip.go: -------------------------------------------------------------------------------- 1 | package netaddr 2 | 3 | import ( 4 | "math" 5 | "net" 6 | ) 7 | 8 | func isZeros(p net.IP) bool { 9 | for _, b := range p { 10 | if b != 0 { 11 | return false 12 | } 13 | } 14 | return true 15 | } 16 | 17 | // IsIPv4 returns true if ip is IPv4 address. 18 | func IsIPv4(ip net.IP) bool { 19 | return len(ip) == net.IPv4len || 20 | isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff 21 | } 22 | 23 | func ipToI32(ip net.IP) uint32 { 24 | ip = ip.To4() 25 | return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3]) 26 | } 27 | 28 | func i32ToIP(a uint32) net.IP { 29 | return net.IPv4(byte(a>>24), byte(a>>16), byte(a>>8), byte(a)) 30 | } 31 | 32 | func ipToU64(ip net.IP) uint64 { 33 | return uint64(ip[0])<<56 | uint64(ip[1])<<48 | uint64(ip[2])<<40 | 34 | uint64(ip[3])<<32 | uint64(ip[4])<<24 | uint64(ip[5])<<16 | 35 | uint64(ip[6])<<8 | uint64(ip[7]) 36 | } 37 | 38 | func u64ToIP(ip net.IP, a uint64) { 39 | ip[0] = byte(a >> 56) 40 | ip[1] = byte(a >> 48) 41 | ip[2] = byte(a >> 40) 42 | ip[3] = byte(a >> 32) 43 | ip[4] = byte(a >> 24) 44 | ip[5] = byte(a >> 16) 45 | ip[6] = byte(a >> 8) 46 | ip[7] = byte(a) 47 | } 48 | 49 | // IPAdd adds offset to ip 50 | func IPAdd(ip net.IP, offset uint64) net.IP { 51 | if IsIPv4(ip) { 52 | a := uint64(ipToI32(ip[len(ip)-4:])) 53 | return i32ToIP(uint32(a + offset)) 54 | } 55 | a := ipToU64(ip[:net.IPv6len/2]) 56 | b := ipToU64(ip[net.IPv6len/2:]) 57 | o := uint64(offset) 58 | if math.MaxUint64-b < o { 59 | a++ 60 | } 61 | b += o 62 | if offset < 0 { 63 | a += math.MaxUint64 64 | } 65 | ip = make(net.IP, net.IPv6len) 66 | u64ToIP(ip[:net.IPv6len/2], a) 67 | u64ToIP(ip[net.IPv6len/2:], b) 68 | return ip 69 | } 70 | 71 | // IPMod calculates ip % d 72 | func IPMod(ip net.IP, d uint64) uint64 { 73 | if IsIPv4(ip) { 74 | return 
uint64(ipToI32(ip[len(ip)-4:])) % d 75 | } 76 | b := uint64(d) 77 | hi := ipToU64(ip[:net.IPv6len/2]) 78 | lo := ipToU64(ip[net.IPv6len/2:]) 79 | return uint64(((hi%b)*((0-b)%b) + lo%b) % b) 80 | } 81 | -------------------------------------------------------------------------------- /pkg/proxy/LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 
35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /pkg/proxy/config/api.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2014 Google Inc. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | package config 18 | 19 | import ( 20 | "time" 21 | 22 | "github.com/GoogleCloudPlatform/kubernetes/pkg/api" 23 | "github.com/GoogleCloudPlatform/kubernetes/pkg/labels" 24 | "github.com/GoogleCloudPlatform/kubernetes/pkg/util" 25 | "github.com/GoogleCloudPlatform/kubernetes/pkg/util/wait" 26 | "github.com/GoogleCloudPlatform/kubernetes/pkg/watch" 27 | "github.com/golang/glog" 28 | ) 29 | 30 | // Watcher is the interface needed to receive changes to services and endpoints. 31 | type Watcher interface { 32 | ListServices(label labels.Selector) (*api.ServiceList, error) 33 | ListEndpoints(label labels.Selector) (*api.EndpointsList, error) 34 | WatchServices(label, field labels.Selector, resourceVersion uint64) (watch.Interface, error) 35 | WatchEndpoints(label, field labels.Selector, resourceVersion uint64) (watch.Interface, error) 36 | } 37 | 38 | // SourceAPI implements a configuration source for services and endpoints that 39 | // uses the client watch API to efficiently detect changes. 40 | type SourceAPI struct { 41 | client Watcher 42 | services chan<- ServiceUpdate 43 | endpoints chan<- EndpointsUpdate 44 | 45 | waitDuration time.Duration 46 | reconnectDuration time.Duration 47 | } 48 | 49 | // NewSourceAPI creates a config source that watches for changes to the services and endpoints. 
50 | func NewSourceAPI(client Watcher, period time.Duration, services chan<- ServiceUpdate, endpoints chan<- EndpointsUpdate) *SourceAPI { 51 | config := &SourceAPI{ 52 | client: client, 53 | services: services, 54 | endpoints: endpoints, 55 | 56 | waitDuration: period, 57 | // prevent hot loops if the server starts to misbehave 58 | reconnectDuration: time.Second * 1, 59 | } 60 | serviceVersion := uint64(0) 61 | go util.Forever(func() { 62 | config.runServices(&serviceVersion) 63 | time.Sleep(wait.Jitter(config.reconnectDuration, 0.0)) 64 | }, period) 65 | endpointVersion := uint64(0) 66 | go util.Forever(func() { 67 | config.runEndpoints(&endpointVersion) 68 | time.Sleep(wait.Jitter(config.reconnectDuration, 0.0)) 69 | }, period) 70 | return config 71 | } 72 | 73 | // runServices loops forever looking for changes to services. 74 | func (s *SourceAPI) runServices(resourceVersion *uint64) { 75 | if *resourceVersion == 0 { 76 | services, err := s.client.ListServices(labels.Everything()) 77 | if err != nil { 78 | glog.Errorf("Unable to load services: %v", err) 79 | time.Sleep(wait.Jitter(s.waitDuration, 0.0)) 80 | return 81 | } 82 | *resourceVersion = services.ResourceVersion 83 | s.services <- ServiceUpdate{Op: SET, Services: services.Items} 84 | } 85 | 86 | watcher, err := s.client.WatchServices(labels.Everything(), labels.Everything(), *resourceVersion) 87 | if err != nil { 88 | glog.Errorf("Unable to watch for services changes: %v", err) 89 | time.Sleep(wait.Jitter(s.waitDuration, 0.0)) 90 | return 91 | } 92 | defer watcher.Stop() 93 | 94 | ch := watcher.ResultChan() 95 | handleServicesWatch(resourceVersion, ch, s.services) 96 | } 97 | 98 | // handleServicesWatch loops over an event channel and delivers config changes to an update channel. 
99 | func handleServicesWatch(resourceVersion *uint64, ch <-chan watch.Event, updates chan<- ServiceUpdate) { 100 | for { 101 | select { 102 | case event, ok := <-ch: 103 | if !ok { 104 | glog.V(2).Infof("WatchServices channel closed") 105 | return 106 | } 107 | 108 | service := event.Object.(*api.Service) 109 | *resourceVersion = service.ResourceVersion + 1 110 | 111 | switch event.Type { 112 | case watch.Added, watch.Modified: 113 | updates <- ServiceUpdate{Op: ADD, Services: []api.Service{*service}} 114 | 115 | case watch.Deleted: 116 | updates <- ServiceUpdate{Op: REMOVE, Services: []api.Service{*service}} 117 | } 118 | } 119 | } 120 | } 121 | 122 | // runEndpoints loops forever looking for changes to endpoints. 123 | func (s *SourceAPI) runEndpoints(resourceVersion *uint64) { 124 | if *resourceVersion == 0 { 125 | endpoints, err := s.client.ListEndpoints(labels.Everything()) 126 | if err != nil { 127 | glog.Errorf("Unable to load endpoints: %v", err) 128 | time.Sleep(wait.Jitter(s.waitDuration, 0.0)) 129 | return 130 | } 131 | *resourceVersion = endpoints.ResourceVersion 132 | s.endpoints <- EndpointsUpdate{Op: SET, Endpoints: endpoints.Items} 133 | } 134 | 135 | watcher, err := s.client.WatchEndpoints(labels.Everything(), labels.Everything(), *resourceVersion) 136 | if err != nil { 137 | glog.Errorf("Unable to watch for endpoints changes: %v", err) 138 | time.Sleep(wait.Jitter(s.waitDuration, 0.0)) 139 | return 140 | } 141 | defer watcher.Stop() 142 | 143 | ch := watcher.ResultChan() 144 | handleEndpointsWatch(resourceVersion, ch, s.endpoints) 145 | } 146 | 147 | // handleEndpointsWatch loops over an event channel and delivers config changes to an update channel. 
148 | func handleEndpointsWatch(resourceVersion *uint64, ch <-chan watch.Event, updates chan<- EndpointsUpdate) { 149 | for { 150 | select { 151 | case event, ok := <-ch: 152 | if !ok { 153 | glog.V(2).Infof("WatchEndpoints channel closed") 154 | return 155 | } 156 | 157 | endpoints := event.Object.(*api.Endpoints) 158 | *resourceVersion = endpoints.ResourceVersion + 1 159 | 160 | switch event.Type { 161 | case watch.Added, watch.Modified: 162 | updates <- EndpointsUpdate{Op: ADD, Endpoints: []api.Endpoints{*endpoints}} 163 | 164 | case watch.Deleted: 165 | updates <- EndpointsUpdate{Op: REMOVE, Endpoints: []api.Endpoints{*endpoints}} 166 | } 167 | } 168 | } 169 | } 170 | -------------------------------------------------------------------------------- /pkg/proxy/config/api_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2014 Google Inc. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | package config 18 | 19 | import ( 20 | "errors" 21 | "reflect" 22 | "testing" 23 | 24 | "github.com/GoogleCloudPlatform/kubernetes/pkg/api" 25 | "github.com/GoogleCloudPlatform/kubernetes/pkg/client" 26 | "github.com/GoogleCloudPlatform/kubernetes/pkg/watch" 27 | ) 28 | 29 | func TestServices(t *testing.T) { 30 | service := api.Service{JSONBase: api.JSONBase{ID: "bar", ResourceVersion: uint64(2)}} 31 | 32 | fakeWatch := watch.NewFake() 33 | fakeClient := &client.Fake{Watch: fakeWatch} 34 | services := make(chan ServiceUpdate) 35 | source := SourceAPI{client: fakeClient, services: services} 36 | resourceVersion := uint64(1) 37 | go func() { 38 | // called twice 39 | source.runServices(&resourceVersion) 40 | source.runServices(&resourceVersion) 41 | }() 42 | 43 | // test adding a service to the watch 44 | fakeWatch.Add(&service) 45 | if !reflect.DeepEqual(fakeClient.Actions, []client.FakeAction{{"watch-services", uint64(1)}}) { 46 | t.Errorf("expected call to watch-services, got %#v", fakeClient) 47 | } 48 | 49 | actual := <-services 50 | expected := ServiceUpdate{Op: ADD, Services: []api.Service{service}} 51 | if !reflect.DeepEqual(expected, actual) { 52 | t.Errorf("expected %#v, got %#v", expected, actual) 53 | } 54 | 55 | // verify that a delete results in a config change 56 | fakeWatch.Delete(&service) 57 | actual = <-services 58 | expected = ServiceUpdate{Op: REMOVE, Services: []api.Service{service}} 59 | if !reflect.DeepEqual(expected, actual) { 60 | t.Errorf("expected %#v, got %#v", expected, actual) 61 | } 62 | 63 | // verify that closing the channel results in a new call to WatchServices with a higher resource version 64 | newFakeWatch := watch.NewFake() 65 | fakeClient.Watch = newFakeWatch 66 | fakeWatch.Stop() 67 | 68 | newFakeWatch.Add(&service) 69 | if !reflect.DeepEqual(fakeClient.Actions, []client.FakeAction{{"watch-services", uint64(1)}, {"watch-services", uint64(3)}}) { 70 | t.Errorf("expected call to watch-endpoints, got %#v", 
fakeClient) 71 | } 72 | } 73 | 74 | func TestServicesFromZero(t *testing.T) { 75 | service := api.Service{JSONBase: api.JSONBase{ID: "bar", ResourceVersion: uint64(2)}} 76 | 77 | fakeWatch := watch.NewFake() 78 | fakeWatch.Stop() 79 | fakeClient := &client.Fake{Watch: fakeWatch} 80 | fakeClient.ServiceList = api.ServiceList{ 81 | JSONBase: api.JSONBase{ResourceVersion: 2}, 82 | Items: []api.Service{ 83 | service, 84 | }, 85 | } 86 | services := make(chan ServiceUpdate) 87 | source := SourceAPI{client: fakeClient, services: services} 88 | resourceVersion := uint64(0) 89 | ch := make(chan struct{}) 90 | go func() { 91 | source.runServices(&resourceVersion) 92 | close(ch) 93 | }() 94 | 95 | // should get services SET 96 | actual := <-services 97 | expected := ServiceUpdate{Op: SET, Services: []api.Service{service}} 98 | if !reflect.DeepEqual(expected, actual) { 99 | t.Errorf("expected %#v, got %#v", expected, actual) 100 | } 101 | 102 | // should have listed, then watched 103 | <-ch 104 | if resourceVersion != 2 { 105 | t.Errorf("unexpected resource version, got %#v", resourceVersion) 106 | } 107 | if !reflect.DeepEqual(fakeClient.Actions, []client.FakeAction{{"list-services", nil}, {"watch-services", uint64(2)}}) { 108 | t.Errorf("unexpected actions, got %#v", fakeClient) 109 | } 110 | } 111 | 112 | func TestServicesError(t *testing.T) { 113 | fakeClient := &client.Fake{Err: errors.New("test")} 114 | services := make(chan ServiceUpdate) 115 | source := SourceAPI{client: fakeClient, services: services} 116 | resourceVersion := uint64(1) 117 | ch := make(chan struct{}) 118 | go func() { 119 | source.runServices(&resourceVersion) 120 | close(ch) 121 | }() 122 | 123 | // should have listed only 124 | <-ch 125 | if resourceVersion != 1 { 126 | t.Errorf("unexpected resource version, got %#v", resourceVersion) 127 | } 128 | if !reflect.DeepEqual(fakeClient.Actions, []client.FakeAction{{"watch-services", uint64(1)}}) { 129 | t.Errorf("unexpected actions, got %#v", 
fakeClient) 130 | } 131 | } 132 | 133 | func TestServicesFromZeroError(t *testing.T) { 134 | fakeClient := &client.Fake{Err: errors.New("test")} 135 | services := make(chan ServiceUpdate) 136 | source := SourceAPI{client: fakeClient, services: services} 137 | resourceVersion := uint64(0) 138 | ch := make(chan struct{}) 139 | go func() { 140 | source.runServices(&resourceVersion) 141 | close(ch) 142 | }() 143 | 144 | // should have listed only 145 | <-ch 146 | if resourceVersion != 0 { 147 | t.Errorf("unexpected resource version, got %#v", resourceVersion) 148 | } 149 | if !reflect.DeepEqual(fakeClient.Actions, []client.FakeAction{{"list-services", nil}}) { 150 | t.Errorf("unexpected actions, got %#v", fakeClient) 151 | } 152 | } 153 | 154 | func TestEndpoints(t *testing.T) { 155 | endpoint := api.Endpoints{JSONBase: api.JSONBase{ID: "bar", ResourceVersion: uint64(2)}, Endpoints: []string{"127.0.0.1:9000"}} 156 | 157 | fakeWatch := watch.NewFake() 158 | fakeClient := &client.Fake{Watch: fakeWatch} 159 | endpoints := make(chan EndpointsUpdate) 160 | source := SourceAPI{client: fakeClient, endpoints: endpoints} 161 | resourceVersion := uint64(1) 162 | go func() { 163 | // called twice 164 | source.runEndpoints(&resourceVersion) 165 | source.runEndpoints(&resourceVersion) 166 | }() 167 | 168 | // test adding an endpoint to the watch 169 | fakeWatch.Add(&endpoint) 170 | if !reflect.DeepEqual(fakeClient.Actions, []client.FakeAction{{"watch-endpoints", uint64(1)}}) { 171 | t.Errorf("expected call to watch-endpoints, got %#v", fakeClient) 172 | } 173 | 174 | actual := <-endpoints 175 | expected := EndpointsUpdate{Op: ADD, Endpoints: []api.Endpoints{endpoint}} 176 | if !reflect.DeepEqual(expected, actual) { 177 | t.Errorf("expected %#v, got %#v", expected, actual) 178 | } 179 | 180 | // verify that a delete results in a config change 181 | fakeWatch.Delete(&endpoint) 182 | actual = <-endpoints 183 | expected = EndpointsUpdate{Op: REMOVE, Endpoints: 
[]api.Endpoints{endpoint}} 184 | if !reflect.DeepEqual(expected, actual) { 185 | t.Errorf("expected %#v, got %#v", expected, actual) 186 | } 187 | 188 | // verify that closing the channel results in a new call to WatchEndpoints with a higher resource version 189 | newFakeWatch := watch.NewFake() 190 | fakeClient.Watch = newFakeWatch 191 | fakeWatch.Stop() 192 | 193 | newFakeWatch.Add(&endpoint) 194 | if !reflect.DeepEqual(fakeClient.Actions, []client.FakeAction{{"watch-endpoints", uint64(1)}, {"watch-endpoints", uint64(3)}}) { 195 | t.Errorf("expected call to watch-endpoints, got %#v", fakeClient) 196 | } 197 | } 198 | 199 | func TestEndpointsFromZero(t *testing.T) { 200 | endpoint := api.Endpoints{JSONBase: api.JSONBase{ID: "bar", ResourceVersion: uint64(2)}, Endpoints: []string{"127.0.0.1:9000"}} 201 | 202 | fakeWatch := watch.NewFake() 203 | fakeWatch.Stop() 204 | fakeClient := &client.Fake{Watch: fakeWatch} 205 | fakeClient.EndpointsList = api.EndpointsList{ 206 | JSONBase: api.JSONBase{ResourceVersion: 2}, 207 | Items: []api.Endpoints{ 208 | endpoint, 209 | }, 210 | } 211 | endpoints := make(chan EndpointsUpdate) 212 | source := SourceAPI{client: fakeClient, endpoints: endpoints} 213 | resourceVersion := uint64(0) 214 | ch := make(chan struct{}) 215 | go func() { 216 | source.runEndpoints(&resourceVersion) 217 | close(ch) 218 | }() 219 | 220 | // should get endpoints SET 221 | actual := <-endpoints 222 | expected := EndpointsUpdate{Op: SET, Endpoints: []api.Endpoints{endpoint}} 223 | if !reflect.DeepEqual(expected, actual) { 224 | t.Errorf("expected %#v, got %#v", expected, actual) 225 | } 226 | 227 | // should have listed, then watched 228 | <-ch 229 | if resourceVersion != 2 { 230 | t.Errorf("unexpected resource version, got %#v", resourceVersion) 231 | } 232 | if !reflect.DeepEqual(fakeClient.Actions, []client.FakeAction{{"list-endpoints", nil}, {"watch-endpoints", uint64(2)}}) { 233 | t.Errorf("unexpected actions, got %#v", fakeClient) 234 | } 235 | } 236 
| 237 | func TestEndpointsError(t *testing.T) { 238 | fakeClient := &client.Fake{Err: errors.New("test")} 239 | endpoints := make(chan EndpointsUpdate) 240 | source := SourceAPI{client: fakeClient, endpoints: endpoints} 241 | resourceVersion := uint64(1) 242 | ch := make(chan struct{}) 243 | go func() { 244 | source.runEndpoints(&resourceVersion) 245 | close(ch) 246 | }() 247 | 248 | // should have listed only 249 | <-ch 250 | if resourceVersion != 1 { 251 | t.Errorf("unexpected resource version, got %#v", resourceVersion) 252 | } 253 | if !reflect.DeepEqual(fakeClient.Actions, []client.FakeAction{{"watch-endpoints", uint64(1)}}) { 254 | t.Errorf("unexpected actions, got %#v", fakeClient) 255 | } 256 | } 257 | 258 | func TestEndpointsFromZeroError(t *testing.T) { 259 | fakeClient := &client.Fake{Err: errors.New("test")} 260 | endpoints := make(chan EndpointsUpdate) 261 | source := SourceAPI{client: fakeClient, endpoints: endpoints} 262 | resourceVersion := uint64(0) 263 | ch := make(chan struct{}) 264 | go func() { 265 | source.runEndpoints(&resourceVersion) 266 | close(ch) 267 | }() 268 | 269 | // should have listed only 270 | <-ch 271 | if resourceVersion != 0 { 272 | t.Errorf("unexpected resource version, got %#v", resourceVersion) 273 | } 274 | if !reflect.DeepEqual(fakeClient.Actions, []client.FakeAction{{"list-endpoints", nil}}) { 275 | t.Errorf("unexpected actions, got %#v", fakeClient) 276 | } 277 | } 278 | -------------------------------------------------------------------------------- /pkg/proxy/config/config.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2014 Google Inc. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package config 18 | 19 | import ( 20 | "sync" 21 | 22 | "github.com/GoogleCloudPlatform/kubernetes/pkg/api" 23 | "github.com/GoogleCloudPlatform/kubernetes/pkg/util/config" 24 | "github.com/golang/glog" 25 | ) 26 | 27 | // Operation is a type of operation of services or endpoints. 28 | type Operation int 29 | 30 | // These are the available operation types. 31 | const ( 32 | SET Operation = iota 33 | ADD 34 | REMOVE 35 | ) 36 | 37 | // ServiceUpdate describes an operation of services, sent on the channel. 38 | // You can add or remove single services by sending an array of size one and Op == ADD|REMOVE. 39 | // For setting the state of the system to a given state for this source configuration, set Services as desired and Op to SET, 40 | // which will reset the system state to that specified in this operation for this source channel. 41 | // To remove all services, set Services to empty array and Op to SET 42 | type ServiceUpdate struct { 43 | Services []api.Service 44 | Op Operation 45 | } 46 | 47 | // EndpointsUpdate describes an operation of endpoints, sent on the channel. 48 | // You can add or remove single endpoints by sending an array of size one and Op == ADD|REMOVE. 49 | // For setting the state of the system to a given state for this source configuration, set Endpoints as desired and Op to SET, 50 | // which will reset the system state to that specified in this operation for this source channel. 
51 | // To remove all endpoints, set Endpoints to empty array and Op to SET 52 | type EndpointsUpdate struct { 53 | Endpoints []api.Endpoints 54 | Op Operation 55 | } 56 | 57 | // ServiceConfigHandler is an abstract interface of objects which receive update notifications for the set of services. 58 | type ServiceConfigHandler interface { 59 | // OnUpdate gets called when a configuration has been changed by one of the sources. 60 | // This is the union of all the configuration sources. 61 | OnUpdate(services []api.Service) 62 | } 63 | 64 | // EndpointsConfigHandler is an abstract interface of objects which receive update notifications for the set of endpoints. 65 | type EndpointsConfigHandler interface { 66 | // OnUpdate gets called when endpoints configuration is changed for a given 67 | // service on any of the configuration sources. An example is when a new 68 | // service comes up, or when containers come up or down for an existing service. 69 | OnUpdate(endpoints []api.Endpoints) 70 | } 71 | 72 | // EndpointsConfig tracks a set of endpoints configurations. 73 | // It accepts "set", "add" and "remove" operations of endpoints via channels, and invokes registered handlers on change. 74 | type EndpointsConfig struct { 75 | mux *config.Mux 76 | watcher *config.Watcher 77 | store *endpointsStore 78 | } 79 | 80 | // NewEndpointsConfig creates a new EndpointsConfig. 81 | // It immediately runs the created EndpointsConfig. 
82 | func NewEndpointsConfig() *EndpointsConfig { 83 | updates := make(chan struct{}) 84 | store := &endpointsStore{updates: updates, endpoints: make(map[string]map[string]api.Endpoints)} 85 | mux := config.NewMux(store) 86 | watcher := config.NewWatcher() 87 | go watchForUpdates(watcher, store, updates) 88 | return &EndpointsConfig{mux, watcher, store} 89 | } 90 | 91 | func (c *EndpointsConfig) RegisterHandler(handler EndpointsConfigHandler) { 92 | c.watcher.Add(config.ListenerFunc(func(instance interface{}) { 93 | handler.OnUpdate(instance.([]api.Endpoints)) 94 | })) 95 | } 96 | 97 | func (c *EndpointsConfig) Channel(source string) chan EndpointsUpdate { 98 | ch := c.mux.Channel(source) 99 | endpointsCh := make(chan EndpointsUpdate) 100 | go func() { 101 | for update := range endpointsCh { 102 | ch <- update 103 | } 104 | close(ch) 105 | }() 106 | return endpointsCh 107 | } 108 | 109 | func (c *EndpointsConfig) Config() map[string]map[string]api.Endpoints { 110 | return c.store.MergedState().(map[string]map[string]api.Endpoints) 111 | } 112 | 113 | type endpointsStore struct { 114 | endpointLock sync.RWMutex 115 | endpoints map[string]map[string]api.Endpoints 116 | updates chan<- struct{} 117 | } 118 | 119 | func (s *endpointsStore) Merge(source string, change interface{}) error { 120 | s.endpointLock.Lock() 121 | endpoints := s.endpoints[source] 122 | if endpoints == nil { 123 | endpoints = make(map[string]api.Endpoints) 124 | } 125 | update := change.(EndpointsUpdate) 126 | switch update.Op { 127 | case ADD: 128 | glog.Infof("Adding new endpoint from source %s : %v", source, update.Endpoints) 129 | for _, value := range update.Endpoints { 130 | endpoints[value.ID] = value 131 | } 132 | case REMOVE: 133 | glog.Infof("Removing an endpoint %v", update) 134 | for _, value := range update.Endpoints { 135 | delete(endpoints, value.ID) 136 | } 137 | case SET: 138 | glog.Infof("Setting endpoints %v", update) 139 | // Clear the old map entries by just creating a new map 
140 | endpoints = make(map[string]api.Endpoints) 141 | for _, value := range update.Endpoints { 142 | endpoints[value.ID] = value 143 | } 144 | default: 145 | glog.Infof("Received invalid update type: %v", update) 146 | } 147 | s.endpoints[source] = endpoints 148 | s.endpointLock.Unlock() 149 | if s.updates != nil { 150 | s.updates <- struct{}{} 151 | } 152 | return nil 153 | } 154 | 155 | func (s *endpointsStore) MergedState() interface{} { 156 | s.endpointLock.RLock() 157 | defer s.endpointLock.RUnlock() 158 | endpoints := make([]api.Endpoints, 0) 159 | for _, sourceEndpoints := range s.endpoints { 160 | for _, value := range sourceEndpoints { 161 | endpoints = append(endpoints, value) 162 | } 163 | } 164 | return endpoints 165 | } 166 | 167 | // ServiceConfig tracks a set of service configurations. 168 | // It accepts "set", "add" and "remove" operations of services via channels, and invokes registered handlers on change. 169 | type ServiceConfig struct { 170 | mux *config.Mux 171 | watcher *config.Watcher 172 | store *serviceStore 173 | } 174 | 175 | // NewServiceConfig creates a new ServiceConfig. 176 | // It immediately runs the created ServiceConfig. 
177 | func NewServiceConfig() *ServiceConfig { 178 | updates := make(chan struct{}) 179 | store := &serviceStore{updates: updates, services: make(map[string]map[string]api.Service)} 180 | mux := config.NewMux(store) 181 | watcher := config.NewWatcher() 182 | go watchForUpdates(watcher, store, updates) 183 | return &ServiceConfig{mux, watcher, store} 184 | } 185 | 186 | func (c *ServiceConfig) RegisterHandler(handler ServiceConfigHandler) { 187 | c.watcher.Add(config.ListenerFunc(func(instance interface{}) { 188 | handler.OnUpdate(instance.([]api.Service)) 189 | })) 190 | } 191 | 192 | func (c *ServiceConfig) Channel(source string) chan ServiceUpdate { 193 | ch := c.mux.Channel(source) 194 | serviceCh := make(chan ServiceUpdate) 195 | go func() { 196 | for update := range serviceCh { 197 | ch <- update 198 | } 199 | close(ch) 200 | }() 201 | return serviceCh 202 | } 203 | 204 | func (c *ServiceConfig) Config() map[string]map[string]api.Service { 205 | return c.store.MergedState().(map[string]map[string]api.Service) 206 | } 207 | 208 | type serviceStore struct { 209 | serviceLock sync.RWMutex 210 | services map[string]map[string]api.Service 211 | updates chan<- struct{} 212 | } 213 | 214 | func (s *serviceStore) Merge(source string, change interface{}) error { 215 | s.serviceLock.Lock() 216 | services := s.services[source] 217 | if services == nil { 218 | services = make(map[string]api.Service) 219 | } 220 | update := change.(ServiceUpdate) 221 | switch update.Op { 222 | case ADD: 223 | glog.Infof("Adding new service from source %s : %v", source, update.Services) 224 | for _, value := range update.Services { 225 | services[value.ID] = value 226 | } 227 | case REMOVE: 228 | glog.Infof("Removing a service %v", update) 229 | for _, value := range update.Services { 230 | delete(services, value.ID) 231 | } 232 | case SET: 233 | glog.Infof("Setting services %v", update) 234 | // Clear the old map entries by just creating a new map 235 | services = 
make(map[string]api.Service) 236 | for _, value := range update.Services { 237 | services[value.ID] = value 238 | } 239 | default: 240 | glog.Infof("Received invalid update type: %v", update) 241 | } 242 | s.services[source] = services 243 | s.serviceLock.Unlock() 244 | if s.updates != nil { 245 | s.updates <- struct{}{} 246 | } 247 | return nil 248 | } 249 | 250 | func (s *serviceStore) MergedState() interface{} { 251 | s.serviceLock.RLock() 252 | defer s.serviceLock.RUnlock() 253 | services := make([]api.Service, 0) 254 | for _, sourceServices := range s.services { 255 | for _, value := range sourceServices { 256 | services = append(services, value) 257 | } 258 | } 259 | return services 260 | } 261 | 262 | // watchForUpdates invokes watcher.Notify() with the latest version of an object 263 | // when changes occur. 264 | func watchForUpdates(watcher *config.Watcher, accessor config.Accessor, updates <-chan struct{}) { 265 | for _ = range updates { 266 | watcher.Notify(accessor.MergedState()) 267 | } 268 | } 269 | -------------------------------------------------------------------------------- /pkg/proxy/config/config_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2014 Google Inc. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
*/

package config_test

import (
	"reflect"
	"sort"
	"sync"
	"testing"

	"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
	. "github.com/GoogleCloudPlatform/kubernetes/pkg/proxy/config"
)

// Fixture constants. TomcatEndpoints/MysqlEndpoints are not referenced in the
// tests below — presumably reserved for future use; verify before removing.
const TomcatPort int = 8080
const TomcatName = "tomcat"

var TomcatEndpoints = map[string]string{"c0": "1.1.1.1:18080", "c1": "2.2.2.2:18081"}

const MysqlPort int = 3306
const MysqlName = "mysql"

var MysqlEndpoints = map[string]string{"c0": "1.1.1.1:13306", "c3": "2.2.2.2:13306"}

// sortedServices sorts api.Service slices by ID so handler output can be
// compared deterministically.
type sortedServices []api.Service

func (s sortedServices) Len() int {
	return len(s)
}
func (s sortedServices) Swap(i, j int) {
	s[i], s[j] = s[j], s[i]
}
func (s sortedServices) Less(i, j int) bool {
	return s[i].JSONBase.ID < s[j].JSONBase.ID
}

// ServiceHandlerMock records the last service list delivered by OnUpdate.
// Tests must call Wait(n) before triggering n updates; ValidateServices blocks
// until all expected OnUpdate calls have completed.
type ServiceHandlerMock struct {
	services []api.Service
	updated  sync.WaitGroup
}

func NewServiceHandlerMock() *ServiceHandlerMock {
	return &ServiceHandlerMock{services: make([]api.Service, 0)}
}

// OnUpdate sorts the delivered services (for deterministic comparison) and
// marks one expected update as done.
func (h *ServiceHandlerMock) OnUpdate(services []api.Service) {
	sort.Sort(sortedServices(services))
	h.services = services
	h.updated.Done()
}

// ValidateServices blocks until all updates registered via Wait have arrived,
// then compares the final service list to expectedServices.
func (h *ServiceHandlerMock) ValidateServices(t *testing.T, expectedServices []api.Service) {
	h.updated.Wait()
	if !reflect.DeepEqual(h.services, expectedServices) {
		t.Errorf("Expected %#v, Got %#v", expectedServices, h.services)
	}
}

// Wait registers the number of OnUpdate calls the next validation should
// block on.
func (h *ServiceHandlerMock) Wait(waits int) {
	h.updated.Add(waits)
}

// sortedEndpoints sorts api.Endpoints slices by ID so handler output can be
// compared deterministically.
type sortedEndpoints []api.Endpoints

func (s sortedEndpoints) Len() int {
	return len(s)
}
func (s sortedEndpoints) Swap(i, j int) {
	s[i], s[j] = s[j], s[i]
}
func (s sortedEndpoints) Less(i, j int) bool {
	return s[i].ID < s[j].ID
}

// EndpointsHandlerMock mirrors ServiceHandlerMock for endpoint updates.
type EndpointsHandlerMock struct {
	endpoints []api.Endpoints
	updated   sync.WaitGroup
}

func NewEndpointsHandlerMock() *EndpointsHandlerMock {
	return &EndpointsHandlerMock{endpoints: make([]api.Endpoints, 0)}
}

// OnUpdate sorts the delivered endpoints and marks one expected update done.
func (h *EndpointsHandlerMock) OnUpdate(endpoints []api.Endpoints) {
	sort.Sort(sortedEndpoints(endpoints))
	h.endpoints = endpoints
	h.updated.Done()
}

// ValidateEndpoints blocks until all updates registered via Wait have arrived,
// then compares the final endpoint list to expectedEndpoints.
func (h *EndpointsHandlerMock) ValidateEndpoints(t *testing.T, expectedEndpoints []api.Endpoints) {
	h.updated.Wait()
	if !reflect.DeepEqual(h.endpoints, expectedEndpoints) {
		t.Errorf("Expected %#v, Got %#v", expectedEndpoints, h.endpoints)
	}
}

// Wait registers the number of OnUpdate calls the next validation should
// block on.
func (h *EndpointsHandlerMock) Wait(waits int) {
	h.updated.Add(waits)
}

// CreateServiceUpdate builds a ServiceUpdate with the given op and services.
func CreateServiceUpdate(op Operation, services ...api.Service) ServiceUpdate {
	ret := ServiceUpdate{Op: op}
	ret.Services = make([]api.Service, len(services))
	for i, value := range services {
		ret.Services[i] = value
	}
	return ret
}

// CreateEndpointsUpdate builds an EndpointsUpdate with the given op and
// endpoints.
func CreateEndpointsUpdate(op Operation, endpoints ...api.Endpoints) EndpointsUpdate {
	ret := EndpointsUpdate{Op: op}
	ret.Endpoints = make([]api.Endpoints, len(endpoints))
	for i, value := range endpoints {
		ret.Endpoints[i] = value
	}
	return ret
}

func TestNewServiceAddedAndNotified(t *testing.T) {
	config := NewServiceConfig()
	channel := config.Channel("one")
	handler := NewServiceHandlerMock()
	handler.Wait(1)
	config.RegisterHandler(handler)
	serviceUpdate := CreateServiceUpdate(ADD, api.Service{JSONBase: api.JSONBase{ID: "foo"}, Port: 10})
	channel <- serviceUpdate
	handler.ValidateServices(t, serviceUpdate.Services)

}

func TestServiceAddedRemovedSetAndNotified(t *testing.T) {
	config := NewServiceConfig()
	channel := config.Channel("one")
	handler := NewServiceHandlerMock()
	config.RegisterHandler(handler)
	serviceUpdate := CreateServiceUpdate(ADD, api.Service{JSONBase: api.JSONBase{ID: "foo"}, Port: 10})
	handler.Wait(1)
	channel <- serviceUpdate
	handler.ValidateServices(t, serviceUpdate.Services)

	// Add a second service; handlers see the union of both, sorted by ID.
	serviceUpdate2 := CreateServiceUpdate(ADD, api.Service{JSONBase: api.JSONBase{ID: "bar"}, Port: 20})
	handler.Wait(1)
	channel <- serviceUpdate2
	services := []api.Service{serviceUpdate2.Services[0], serviceUpdate.Services[0]}
	handler.ValidateServices(t, services)

	// Remove "foo"; only "bar" remains.
	serviceUpdate3 := CreateServiceUpdate(REMOVE, api.Service{JSONBase: api.JSONBase{ID: "foo"}})
	handler.Wait(1)
	channel <- serviceUpdate3
	services = []api.Service{serviceUpdate2.Services[0]}
	handler.ValidateServices(t, services)

	// SET replaces the source's whole state with just "foobar".
	serviceUpdate4 := CreateServiceUpdate(SET, api.Service{JSONBase: api.JSONBase{ID: "foobar"}, Port: 99})
	handler.Wait(1)
	channel <- serviceUpdate4
	services = []api.Service{serviceUpdate4.Services[0]}
	handler.ValidateServices(t, services)
}

func TestNewMultipleSourcesServicesAddedAndNotified(t *testing.T) {
	config := NewServiceConfig()
	channelOne := config.Channel("one")
	channelTwo := config.Channel("two")
	if channelOne == channelTwo {
		t.Error("Same channel handed back for one and two")
	}
	handler := NewServiceHandlerMock()
	config.RegisterHandler(handler)
	serviceUpdate1 := CreateServiceUpdate(ADD, api.Service{JSONBase: api.JSONBase{ID: "foo"}, Port: 10})
	serviceUpdate2 := CreateServiceUpdate(ADD, api.Service{JSONBase: api.JSONBase{ID: "bar"}, Port: 20})
	handler.Wait(2)
	channelOne <- serviceUpdate1
	channelTwo <- serviceUpdate2
	services := []api.Service{serviceUpdate2.Services[0], serviceUpdate1.Services[0]}
	handler.ValidateServices(t, services)
}

func TestNewMultipleSourcesServicesMultipleHandlersAddedAndNotified(t *testing.T) {
	config := NewServiceConfig()
	channelOne := config.Channel("one")
	channelTwo := config.Channel("two")
	handler := NewServiceHandlerMock()
	handler2 := NewServiceHandlerMock()
	config.RegisterHandler(handler)
	config.RegisterHandler(handler2)
	serviceUpdate1 := CreateServiceUpdate(ADD, api.Service{JSONBase: api.JSONBase{ID: "foo"}, Port: 10})
	serviceUpdate2 := CreateServiceUpdate(ADD, api.Service{JSONBase: api.JSONBase{ID: "bar"}, Port: 20})
	handler.Wait(2)
	handler2.Wait(2)
	channelOne <- serviceUpdate1
	channelTwo <- serviceUpdate2
	services := []api.Service{serviceUpdate2.Services[0], serviceUpdate1.Services[0]}
	handler.ValidateServices(t, services)
	handler2.ValidateServices(t, services)
}

func TestNewMultipleSourcesEndpointsMultipleHandlersAddedAndNotified(t *testing.T) {
	config := NewEndpointsConfig()
	channelOne := config.Channel("one")
	channelTwo := config.Channel("two")
	handler := NewEndpointsHandlerMock()
	handler2 := NewEndpointsHandlerMock()
	config.RegisterHandler(handler)
	config.RegisterHandler(handler2)
	endpointsUpdate1 := CreateEndpointsUpdate(ADD, api.Endpoints{
		JSONBase:  api.JSONBase{ID: "foo"},
		Endpoints: []string{"endpoint1", "endpoint2"},
	})
	endpointsUpdate2 := CreateEndpointsUpdate(ADD, api.Endpoints{
		JSONBase:  api.JSONBase{ID: "bar"},
		Endpoints: []string{"endpoint3", "endpoint4"},
	})
	handler.Wait(2)
	handler2.Wait(2)
	channelOne <- endpointsUpdate1
	channelTwo <- endpointsUpdate2

	endpoints := []api.Endpoints{endpointsUpdate2.Endpoints[0], endpointsUpdate1.Endpoints[0]}
	handler.ValidateEndpoints(t, endpoints)
	handler2.ValidateEndpoints(t, endpoints)
}

func TestNewMultipleSourcesEndpointsMultipleHandlersAddRemoveSetAndNotified(t *testing.T) {
	config := NewEndpointsConfig()
	channelOne := config.Channel("one")
	channelTwo := config.Channel("two")
	handler := NewEndpointsHandlerMock()
	handler2 := NewEndpointsHandlerMock()
	config.RegisterHandler(handler)
	config.RegisterHandler(handler2)
	endpointsUpdate1 := CreateEndpointsUpdate(ADD, api.Endpoints{
		JSONBase:  api.JSONBase{ID: "foo"},
		Endpoints: []string{"endpoint1", "endpoint2"},
	})
	endpointsUpdate2 := CreateEndpointsUpdate(ADD, api.Endpoints{
		JSONBase:  api.JSONBase{ID: "bar"},
		Endpoints: []string{"endpoint3", "endpoint4"},
	})
	handler.Wait(2)
	handler2.Wait(2)
	channelOne <- endpointsUpdate1
	channelTwo <- endpointsUpdate2

	endpoints := []api.Endpoints{endpointsUpdate2.Endpoints[0], endpointsUpdate1.Endpoints[0]}
	handler.ValidateEndpoints(t, endpoints)
	handler2.ValidateEndpoints(t, endpoints)

	// Add one more
	endpointsUpdate3 := CreateEndpointsUpdate(ADD, api.Endpoints{
		JSONBase:  api.JSONBase{ID: "foobar"},
		Endpoints: []string{"endpoint5", "endpoint6"},
	})
	handler.Wait(1)
	handler2.Wait(1)
	channelTwo <- endpointsUpdate3
	endpoints = []api.Endpoints{endpointsUpdate2.Endpoints[0], endpointsUpdate1.Endpoints[0], endpointsUpdate3.Endpoints[0]}
	handler.ValidateEndpoints(t, endpoints)
	handler2.ValidateEndpoints(t, endpoints)

	// Update the "foo" service with new endpoints
	endpointsUpdate1 = CreateEndpointsUpdate(ADD, api.Endpoints{
		JSONBase:  api.JSONBase{ID: "foo"},
		Endpoints: []string{"endpoint77"},
	})
	handler.Wait(1)
	handler2.Wait(1)
	channelOne <- endpointsUpdate1
	endpoints = []api.Endpoints{endpointsUpdate2.Endpoints[0], endpointsUpdate1.Endpoints[0], endpointsUpdate3.Endpoints[0]}
	handler.ValidateEndpoints(t, endpoints)
	handler2.ValidateEndpoints(t, endpoints)

	// Remove "bar" service
	endpointsUpdate2 = CreateEndpointsUpdate(REMOVE, api.Endpoints{JSONBase: api.JSONBase{ID: "bar"}})
	handler.Wait(1)
	handler2.Wait(1)
	channelTwo <- endpointsUpdate2

	endpoints = []api.Endpoints{endpointsUpdate1.Endpoints[0], endpointsUpdate3.Endpoints[0]}
	handler.ValidateEndpoints(t, endpoints)
	handler2.ValidateEndpoints(t, endpoints)
}
-------------------------------------------------------------------------------- /pkg/proxy/config/doc.go: --------------------------------------------------------------------------------
/*
Copyright 2014 Google Inc. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

// Package config provides decoupling between various configuration sources (etcd, files,...) and
// the pieces that actually care about them (loadbalancer, proxy). Config takes 1 or more
// configuration sources and allows for incremental (add/remove) and full replace (set)
// changes from each of the sources, then creates a union of the configuration and provides
// a unified view for both service handlers as well as endpoint handlers. There is no attempt
// to resolve conflicts of any sort. Basic idea is that each configuration source gets a channel
// from the Config service and pushes updates to it via that channel. Config then keeps track of
// incremental & replace changes and distributes them to listeners as appropriate.
25 | package config 26 | -------------------------------------------------------------------------------- /pkg/proxy/config/etcd.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2014 Google Inc. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | // Watches etcd and gets the full configuration on preset intervals. 18 | // It expects the list of exposed services to live under: 19 | // registry/services 20 | // which in etcd is exposed like so: 21 | // http:///v2/keys/registry/services 22 | // 23 | // The port that proxy needs to listen in for each service is a value in: 24 | // registry/services/ 25 | // 26 | // The endpoints for each of the services found is a json string 27 | // representing that service at: 28 | // /registry/services//endpoint 29 | // and the format is: 30 | // '[ { "machine": , "name": }, 31 | // { "machine": , "name": } 32 | // ]', 33 | 34 | package config 35 | 36 | import ( 37 | "fmt" 38 | "strings" 39 | "time" 40 | 41 | "github.com/GoogleCloudPlatform/kubernetes/pkg/api" 42 | "github.com/GoogleCloudPlatform/kubernetes/pkg/runtime" 43 | "github.com/GoogleCloudPlatform/kubernetes/pkg/tools" 44 | "github.com/GoogleCloudPlatform/kubernetes/pkg/util" 45 | "github.com/coreos/go-etcd/etcd" 46 | "github.com/golang/glog" 47 | ) 48 | 49 | // registryRoot is the key prefix for service configs in etcd. 
50 | const registryRoot = "registry/services" 51 | 52 | // ConfigSourceEtcd communicates with a etcd via the client, and sends the change notification of services and endpoints to the specified channels. 53 | type ConfigSourceEtcd struct { 54 | client *etcd.Client 55 | serviceChannel chan ServiceUpdate 56 | endpointsChannel chan EndpointsUpdate 57 | interval time.Duration 58 | } 59 | 60 | // NewConfigSourceEtcd creates a new ConfigSourceEtcd and immediately runs the created ConfigSourceEtcd in a goroutine. 61 | func NewConfigSourceEtcd(client *etcd.Client, serviceChannel chan ServiceUpdate, endpointsChannel chan EndpointsUpdate) ConfigSourceEtcd { 62 | config := ConfigSourceEtcd{ 63 | client: client, 64 | serviceChannel: serviceChannel, 65 | endpointsChannel: endpointsChannel, 66 | interval: 2 * time.Second, 67 | } 68 | go config.Run() 69 | return config 70 | } 71 | 72 | // Run begins watching for new services and their endpoints on etcd. 73 | func (s ConfigSourceEtcd) Run() { 74 | // Initially, just wait for the etcd to come up before doing anything more complicated. 75 | var services []api.Service 76 | var endpoints []api.Endpoints 77 | var err error 78 | for { 79 | services, endpoints, err = s.GetServices() 80 | if err == nil { 81 | break 82 | } 83 | glog.V(1).Infof("Failed to get any services: %v", err) 84 | time.Sleep(s.interval) 85 | } 86 | 87 | if len(services) > 0 { 88 | serviceUpdate := ServiceUpdate{Op: SET, Services: services} 89 | s.serviceChannel <- serviceUpdate 90 | } 91 | if len(endpoints) > 0 { 92 | endpointsUpdate := EndpointsUpdate{Op: SET, Endpoints: endpoints} 93 | s.endpointsChannel <- endpointsUpdate 94 | } 95 | 96 | // Ok, so we got something back from etcd. 
Let's set up a watch for new services, and 97 | // their endpoints 98 | go util.Forever(s.WatchForChanges, 1*time.Second) 99 | 100 | for { 101 | services, endpoints, err = s.GetServices() 102 | if err != nil { 103 | glog.Errorf("ConfigSourceEtcd: Failed to get services: %v", err) 104 | } else { 105 | if len(services) > 0 { 106 | serviceUpdate := ServiceUpdate{Op: SET, Services: services} 107 | s.serviceChannel <- serviceUpdate 108 | } 109 | if len(endpoints) > 0 { 110 | endpointsUpdate := EndpointsUpdate{Op: SET, Endpoints: endpoints} 111 | s.endpointsChannel <- endpointsUpdate 112 | } 113 | } 114 | time.Sleep(30 * time.Second) 115 | } 116 | } 117 | 118 | // GetServices finds the list of services and their endpoints from etcd. 119 | // This operation is akin to a set a known good at regular intervals. 120 | func (s ConfigSourceEtcd) GetServices() ([]api.Service, []api.Endpoints, error) { 121 | response, err := s.client.Get(registryRoot+"/specs", true, false) 122 | if err != nil { 123 | if tools.IsEtcdNotFound(err) { 124 | glog.V(1).Infof("Failed to get the key %s: %v", registryRoot, err) 125 | } else { 126 | glog.Errorf("Failed to contact etcd for key %s: %v", registryRoot, err) 127 | } 128 | return []api.Service{}, []api.Endpoints{}, err 129 | } 130 | if response.Node.Dir == true { 131 | retServices := make([]api.Service, len(response.Node.Nodes)) 132 | retEndpoints := make([]api.Endpoints, len(response.Node.Nodes)) 133 | // Ok, so we have directories, this list should be the list 134 | // of services. Find the local port to listen on and remote endpoints 135 | // and create a Service entry for it. 
136 | for i, node := range response.Node.Nodes { 137 | var svc api.Service 138 | err = runtime.DefaultCodec.DecodeInto([]byte(node.Value), &svc) 139 | if err != nil { 140 | glog.Errorf("Failed to load Service: %s (%#v)", node.Value, err) 141 | continue 142 | } 143 | retServices[i] = svc 144 | endpoints, err := s.GetEndpoints(svc.ID) 145 | if err != nil { 146 | if tools.IsEtcdNotFound(err) { 147 | glog.V(1).Infof("Unable to get endpoints for %s : %v", svc.ID, err) 148 | } 149 | glog.Errorf("Couldn't get endpoints for %s : %v skipping", svc.ID, err) 150 | endpoints = api.Endpoints{} 151 | } else { 152 | glog.Infof("Got service: %s on localport %d mapping to: %s", svc.ID, svc.Port, endpoints) 153 | } 154 | retEndpoints[i] = endpoints 155 | } 156 | return retServices, retEndpoints, err 157 | } 158 | return nil, nil, fmt.Errorf("did not get the root of the registry %s", registryRoot) 159 | } 160 | 161 | // GetEndpoints finds the list of endpoints of the service from etcd. 162 | func (s ConfigSourceEtcd) GetEndpoints(service string) (api.Endpoints, error) { 163 | key := fmt.Sprintf(registryRoot + "/endpoints/" + service) 164 | response, err := s.client.Get(key, true, false) 165 | if err != nil { 166 | glog.Errorf("Failed to get the key: %s %v", key, err) 167 | return api.Endpoints{}, err 168 | } 169 | // Parse all the endpoint specifications in this value. 170 | var e api.Endpoints 171 | err = runtime.DefaultCodec.DecodeInto([]byte(response.Node.Value), &e) 172 | return e, err 173 | } 174 | 175 | // etcdResponseToService takes an etcd response and pulls it apart to find service. 
176 | func etcdResponseToService(response *etcd.Response) (*api.Service, error) { 177 | if response.Node == nil { 178 | return nil, fmt.Errorf("invalid response from etcd: %#v", response) 179 | } 180 | var svc api.Service 181 | err := runtime.DefaultCodec.DecodeInto([]byte(response.Node.Value), &svc) 182 | if err != nil { 183 | return nil, err 184 | } 185 | return &svc, err 186 | } 187 | 188 | func (s ConfigSourceEtcd) WatchForChanges() { 189 | glog.Info("Setting up a watch for new services") 190 | watchChannel := make(chan *etcd.Response) 191 | go s.client.Watch("/registry/services/", 0, true, watchChannel, nil) 192 | for { 193 | watchResponse, ok := <-watchChannel 194 | if !ok { 195 | break 196 | } 197 | s.ProcessChange(watchResponse) 198 | } 199 | } 200 | 201 | func (s ConfigSourceEtcd) ProcessChange(response *etcd.Response) { 202 | glog.Infof("Processing a change in service configuration... %s", *response) 203 | 204 | // If it's a new service being added (signified by a localport being added) 205 | // then process it as such 206 | if strings.Contains(response.Node.Key, "/endpoints/") { 207 | s.ProcessEndpointResponse(response) 208 | } else if response.Action == "set" { 209 | service, err := etcdResponseToService(response) 210 | if err != nil { 211 | glog.Errorf("Failed to parse %s Port: %s", response, err) 212 | return 213 | } 214 | 215 | glog.Infof("New service added/updated: %#v", service) 216 | serviceUpdate := ServiceUpdate{Op: ADD, Services: []api.Service{*service}} 217 | s.serviceChannel <- serviceUpdate 218 | return 219 | } 220 | if response.Action == "delete" { 221 | parts := strings.Split(response.Node.Key[1:], "/") 222 | if len(parts) == 4 { 223 | glog.Infof("Deleting service: %s", parts[3]) 224 | serviceUpdate := ServiceUpdate{Op: REMOVE, Services: []api.Service{{JSONBase: api.JSONBase{ID: parts[3]}}}} 225 | s.serviceChannel <- serviceUpdate 226 | return 227 | } 228 | glog.Infof("Unknown service delete: %#v", parts) 229 | } 230 | } 231 | 232 | func 
(s ConfigSourceEtcd) ProcessEndpointResponse(response *etcd.Response) { 233 | glog.Infof("Processing a change in endpoint configuration... %s", *response) 234 | var endpoints api.Endpoints 235 | err := runtime.DefaultCodec.DecodeInto([]byte(response.Node.Value), &endpoints) 236 | if err != nil { 237 | glog.Errorf("Failed to parse service out of etcd key: %v : %+v", response.Node.Value, err) 238 | return 239 | } 240 | endpointsUpdate := EndpointsUpdate{Op: ADD, Endpoints: []api.Endpoints{endpoints}} 241 | s.endpointsChannel <- endpointsUpdate 242 | } 243 | -------------------------------------------------------------------------------- /pkg/proxy/config/file.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2014 Google Inc. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | // Reads the configuration from the file. 
Example file for two services [nodejs & mysql] 18 | //{"Services": [ 19 | // { 20 | // "Name":"nodejs", 21 | // "Port":10000, 22 | // "Endpoints":["10.240.180.168:8000", "10.240.254.199:8000", "10.240.62.150:8000"] 23 | // }, 24 | // { 25 | // "Name":"mysql", 26 | // "Port":10001, 27 | // "Endpoints":["10.240.180.168:9000", "10.240.254.199:9000", "10.240.62.150:9000"] 28 | // } 29 | //] 30 | //} 31 | 32 | package config 33 | 34 | import ( 35 | "bytes" 36 | "encoding/json" 37 | "fmt" 38 | "io/ioutil" 39 | "reflect" 40 | "time" 41 | 42 | "github.com/GoogleCloudPlatform/kubernetes/pkg/api" 43 | "github.com/golang/glog" 44 | ) 45 | 46 | // serviceConfig is a deserialized form of the config file format which ConfigSourceFile accepts. 47 | type serviceConfig struct { 48 | Services []struct { 49 | Name string `json: "name"` 50 | Port int `json: "port"` 51 | Endpoints []string `json: "endpoints"` 52 | } `json: "service"` 53 | } 54 | 55 | // ConfigSourceFile periodically reads service configurations in JSON from a file, and sends the services and endpoints defined in the file to the specified channels. 56 | type ConfigSourceFile struct { 57 | serviceChannel chan ServiceUpdate 58 | endpointsChannel chan EndpointsUpdate 59 | filename string 60 | } 61 | 62 | // NewConfigSourceFile creates a new ConfigSourceFile and let it immediately runs the created ConfigSourceFile in a goroutine. 63 | func NewConfigSourceFile(filename string, serviceChannel chan ServiceUpdate, endpointsChannel chan EndpointsUpdate) ConfigSourceFile { 64 | config := ConfigSourceFile{ 65 | filename: filename, 66 | serviceChannel: serviceChannel, 67 | endpointsChannel: endpointsChannel, 68 | } 69 | go config.Run() 70 | return config 71 | } 72 | 73 | // Run begins watching the config file. 
74 | func (s ConfigSourceFile) Run() { 75 | glog.Infof("Watching file %s", s.filename) 76 | var lastData []byte 77 | var lastServices []api.Service 78 | var lastEndpoints []api.Endpoints 79 | 80 | sleep := 5 * time.Second 81 | // Used to avoid spamming the error log file, makes error logging edge triggered. 82 | hadSuccess := true 83 | for { 84 | data, err := ioutil.ReadFile(s.filename) 85 | if err != nil { 86 | msg := fmt.Sprintf("Couldn't read file: %s : %v", s.filename, err) 87 | if hadSuccess { 88 | glog.Error(msg) 89 | } else { 90 | glog.V(1).Info(msg) 91 | } 92 | hadSuccess = false 93 | time.Sleep(sleep) 94 | continue 95 | } 96 | hadSuccess = true 97 | 98 | if bytes.Equal(lastData, data) { 99 | time.Sleep(sleep) 100 | continue 101 | } 102 | lastData = data 103 | 104 | config := &serviceConfig{} 105 | if err = json.Unmarshal(data, config); err != nil { 106 | glog.Errorf("Couldn't unmarshal configuration from file : %s %v", data, err) 107 | continue 108 | } 109 | // Ok, we have a valid configuration, send to channel for 110 | // rejiggering. 
111 | newServices := make([]api.Service, len(config.Services)) 112 | newEndpoints := make([]api.Endpoints, len(config.Services)) 113 | for i, service := range config.Services { 114 | newServices[i] = api.Service{JSONBase: api.JSONBase{ID: service.Name}, Port: service.Port} 115 | newEndpoints[i] = api.Endpoints{JSONBase: api.JSONBase{ID: service.Name}, Endpoints: service.Endpoints} 116 | } 117 | if !reflect.DeepEqual(lastServices, newServices) { 118 | serviceUpdate := ServiceUpdate{Op: SET, Services: newServices} 119 | s.serviceChannel <- serviceUpdate 120 | lastServices = newServices 121 | } 122 | if !reflect.DeepEqual(lastEndpoints, newEndpoints) { 123 | endpointsUpdate := EndpointsUpdate{Op: SET, Endpoints: newEndpoints} 124 | s.endpointsChannel <- endpointsUpdate 125 | lastEndpoints = newEndpoints 126 | } 127 | 128 | time.Sleep(sleep) 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /pkg/proxy/doc.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2014 Google Inc. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | // Package proxy implements the layer-3 network proxy. 
package proxy
-------------------------------------------------------------------------------- /pkg/proxy/loadbalancer.go: --------------------------------------------------------------------------------
/*
Copyright 2014 Google Inc. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package proxy

import (
	"github.com/vishvananda/netns"
	"net"
)

// LoadBalancer is an interface for distributing incoming requests to service endpoints.
type LoadBalancer interface {
	// NextEndpoint returns the namespace and endpoint to handle a request for the given
	// service and source address. The returned endpoint is a "host:port" style
	// string dialed by the proxier; the NsHandle selects the network namespace
	// to dial from.
	NextEndpoint(service string, srcAddr net.Addr) (netns.NsHandle, string, error)
}
-------------------------------------------------------------------------------- /pkg/proxy/proxier.go: --------------------------------------------------------------------------------
/*
Copyright 2014 Google Inc. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package proxy 18 | 19 | import ( 20 | "fmt" 21 | "io" 22 | "net" 23 | "runtime" 24 | "strconv" 25 | "strings" 26 | "sync" 27 | "time" 28 | 29 | "github.com/GoogleCloudPlatform/kubernetes/pkg/api" 30 | "github.com/GoogleCloudPlatform/kubernetes/pkg/util" 31 | "github.com/golang/glog" 32 | "github.com/vishvananda/netns" 33 | ) 34 | 35 | type serviceInfo struct { 36 | name string 37 | port int 38 | protocol string 39 | socket proxySocket 40 | timeout time.Duration 41 | mu sync.Mutex // protects active 42 | active bool 43 | } 44 | 45 | func (si *serviceInfo) isActive() bool { 46 | si.mu.Lock() 47 | defer si.mu.Unlock() 48 | return si.active 49 | } 50 | 51 | func (si *serviceInfo) setActive(val bool) bool { 52 | si.mu.Lock() 53 | defer si.mu.Unlock() 54 | tmp := si.active 55 | si.active = val 56 | return tmp 57 | } 58 | 59 | // How long we wait for a connection to a backend. 60 | const endpointDialTimeout = 5 * time.Second 61 | 62 | // How long we retry when opening a listening socket. 63 | const listenTimeout = 5 * time.Second 64 | 65 | // Abstraction over TCP/UDP sockets which are proxied. 66 | type proxySocket interface { 67 | // Addr gets the net.Addr for a proxySocket. 68 | Addr() net.Addr 69 | // Close stops the proxySocket from accepting incoming connections. Each implementation should comment 70 | // on the impact of calling Close while sessions are active. 71 | Close() error 72 | // ProxyLoop proxies incoming connections for the specified service to the service endpoints. 
73 | ProxyLoop(service string, proxier *Proxier) 74 | } 75 | 76 | // tcpProxySocket implements proxySocket. Close() is implemented by net.Listener. When Close() is called, 77 | // no new connections are allowed but existing connections are left untouched. 78 | type tcpProxySocket struct { 79 | net.Listener 80 | } 81 | 82 | func retryDial(network, address string, timeout time.Duration) (net.Conn, error) { 83 | endTime := time.Now().Add(timeout) 84 | for { 85 | remaining := endTime.Sub(time.Now()) 86 | outConn, err := net.DialTimeout(network, address, remaining) 87 | if err != nil { 88 | if endTime.After(time.Now()) { 89 | glog.Infof("Dial retrying on error for %s: %v", remaining, err) 90 | time.Sleep(100 * time.Millisecond) 91 | continue 92 | } else { 93 | return nil, fmt.Errorf("Dial retry timed out") 94 | } 95 | } 96 | return outConn, err 97 | } 98 | } 99 | 100 | func (tcp *tcpProxySocket) ProxyLoop(service string, proxier *Proxier) { 101 | info, found := proxier.getServiceInfo(service) 102 | if !found { 103 | glog.Errorf("Failed to find service: %s", service) 104 | return 105 | } 106 | for { 107 | if !info.isActive() { 108 | break 109 | } 110 | 111 | // Block until a connection is made. 112 | inConn, err := tcp.Accept() 113 | if err != nil { 114 | glog.Errorf("Accept failed: %v", err) 115 | continue 116 | } 117 | glog.Infof("Accepted TCP connection from %v to %v", inConn.RemoteAddr(), inConn.LocalAddr()) 118 | ns, endpoint, err := proxier.loadBalancer.NextEndpoint(service, inConn.RemoteAddr()) 119 | if err != nil { 120 | glog.Errorf("Couldn't find an endpoint for %s %v", service, err) 121 | inConn.Close() 122 | continue 123 | } 124 | glog.Infof("Mapped service %s to endpoint %s", service, endpoint) 125 | // TODO: This could spin up a new goroutine to make the outbound connection, 126 | // and keep accepting inbound traffic. 
127 | if ns.IsOpen() { 128 | glog.Infof("Using namespace %v for endpoint %s", ns, endpoint) 129 | runtime.LockOSThread() 130 | defer runtime.UnlockOSThread() 131 | origns, err := netns.Get() 132 | if err != nil { 133 | glog.Errorf("Failed to get original ns: %v", err) 134 | continue 135 | } 136 | err = netns.Set(ns) 137 | if err != nil { 138 | glog.Errorf("Failed to set ns: %v", err) 139 | continue 140 | } 141 | defer netns.Set(origns) 142 | } 143 | outConn, err := retryDial("tcp", endpoint, endpointDialTimeout) 144 | if err != nil { 145 | // TODO: Try another endpoint? 146 | glog.Errorf("Dial failed: %v", err) 147 | inConn.Close() 148 | continue 149 | } 150 | // Spin up an async copy loop. 151 | proxyTCP(inConn.(*net.TCPConn), outConn.(*net.TCPConn)) 152 | } 153 | } 154 | 155 | // proxyTCP proxies data bi-directionally between in and out. 156 | func proxyTCP(in, out *net.TCPConn) { 157 | glog.Infof("Creating proxy between %v <-> %v <-> %v <-> %v", 158 | in.RemoteAddr(), in.LocalAddr(), out.LocalAddr(), out.RemoteAddr()) 159 | go copyBytes(in, out) 160 | go copyBytes(out, in) 161 | } 162 | 163 | // udpProxySocket implements proxySocket. Close() is implemented by net.UDPConn. When Close() is called, 164 | // no new connections are allowed and existing connections are broken. 165 | // TODO: We could lame-duck this ourselves, if it becomes important. 166 | type udpProxySocket struct { 167 | *net.UDPConn 168 | } 169 | 170 | func (udp *udpProxySocket) Addr() net.Addr { 171 | return udp.LocalAddr() 172 | } 173 | 174 | // Holds all the known UDP clients that have not timed out. 
175 | type clientCache struct { 176 | mu sync.Mutex 177 | clients map[string]net.Conn // addr string -> connection 178 | } 179 | 180 | func newClientCache() *clientCache { 181 | return &clientCache{clients: map[string]net.Conn{}} 182 | } 183 | 184 | func (udp *udpProxySocket) ProxyLoop(service string, proxier *Proxier) { 185 | info, found := proxier.getServiceInfo(service) 186 | if !found { 187 | glog.Errorf("Failed to find service: %s", service) 188 | return 189 | } 190 | activeClients := newClientCache() 191 | var buffer [4096]byte // 4KiB should be enough for most whole-packets 192 | for { 193 | if !info.isActive() { 194 | break 195 | } 196 | 197 | // Block until data arrives. 198 | // TODO: Accumulate a histogram of n or something, to fine tune the buffer size. 199 | n, cliAddr, err := udp.ReadFrom(buffer[0:]) 200 | if err != nil { 201 | if e, ok := err.(net.Error); ok { 202 | if e.Temporary() { 203 | glog.Infof("ReadFrom had a temporary failure: %v", err) 204 | continue 205 | } 206 | } 207 | glog.Errorf("ReadFrom failed, exiting ProxyLoop: %v", err) 208 | break 209 | } 210 | // If this is a client we know already, reuse the connection and goroutine. 211 | svrConn, err := udp.getBackendConn(activeClients, cliAddr, proxier, service, info.timeout) 212 | if err != nil { 213 | continue 214 | } 215 | // TODO: It would be nice to let the goroutine handle this write, but we don't 216 | // really want to copy the buffer. We could do a pool of buffers or something. 217 | _, err = svrConn.Write(buffer[0:n]) 218 | if err != nil { 219 | if !logTimeout(err) { 220 | glog.Errorf("Write failed: %v", err) 221 | // TODO: Maybe tear down the goroutine for this client/server pair? 
222 | } 223 | continue 224 | } 225 | svrConn.SetDeadline(time.Now().Add(info.timeout)) 226 | if err != nil { 227 | glog.Errorf("SetDeadline failed: %v", err) 228 | continue 229 | } 230 | } 231 | } 232 | 233 | func (udp *udpProxySocket) getBackendConn(activeClients *clientCache, cliAddr net.Addr, proxier *Proxier, service string, timeout time.Duration) (net.Conn, error) { 234 | activeClients.mu.Lock() 235 | defer activeClients.mu.Unlock() 236 | 237 | svrConn, found := activeClients.clients[cliAddr.String()] 238 | if !found { 239 | // TODO: This could spin up a new goroutine to make the outbound connection, 240 | // and keep accepting inbound traffic. 241 | glog.Infof("New UDP connection from %s", cliAddr) 242 | ns, endpoint, err := proxier.loadBalancer.NextEndpoint(service, cliAddr) 243 | if err != nil { 244 | glog.Errorf("Couldn't find an endpoint for %s %v", service, err) 245 | return nil, err 246 | } 247 | glog.Infof("Mapped service %s to endpoint %s", service, endpoint) 248 | if ns.IsOpen() { 249 | glog.Infof("Using namespace %v for endpoint %s", ns, endpoint) 250 | runtime.LockOSThread() 251 | netns.Set(ns) 252 | defer runtime.UnlockOSThread() 253 | } 254 | svrConn, err = retryDial("udp", endpoint, endpointDialTimeout) 255 | if err != nil { 256 | // TODO: Try another endpoint? 257 | glog.Errorf("Dial failed: %v", err) 258 | return nil, err 259 | } 260 | activeClients.clients[cliAddr.String()] = svrConn 261 | go func(cliAddr net.Addr, svrConn net.Conn, activeClients *clientCache, timeout time.Duration) { 262 | defer util.HandleCrash() 263 | udp.proxyClient(cliAddr, svrConn, activeClients, timeout) 264 | }(cliAddr, svrConn, activeClients, timeout) 265 | } 266 | return svrConn, nil 267 | } 268 | 269 | // This function is expected to be called as a goroutine. 
270 | func (udp *udpProxySocket) proxyClient(cliAddr net.Addr, svrConn net.Conn, activeClients *clientCache, timeout time.Duration) { 271 | defer svrConn.Close() 272 | var buffer [4096]byte 273 | for { 274 | n, err := svrConn.Read(buffer[0:]) 275 | if err != nil { 276 | if !logTimeout(err) { 277 | glog.Errorf("Read failed: %v", err) 278 | } 279 | break 280 | } 281 | svrConn.SetDeadline(time.Now().Add(timeout)) 282 | if err != nil { 283 | glog.Errorf("SetDeadline failed: %v", err) 284 | break 285 | } 286 | n, err = udp.WriteTo(buffer[0:n], cliAddr) 287 | if err != nil { 288 | if !logTimeout(err) { 289 | glog.Errorf("WriteTo failed: %v", err) 290 | } 291 | break 292 | } 293 | } 294 | activeClients.mu.Lock() 295 | delete(activeClients.clients, cliAddr.String()) 296 | activeClients.mu.Unlock() 297 | } 298 | 299 | func logTimeout(err error) bool { 300 | if e, ok := err.(net.Error); ok { 301 | if e.Timeout() { 302 | glog.Infof("connection to endpoint closed due to inactivity") 303 | return true 304 | } 305 | } 306 | return false 307 | } 308 | 309 | func newProxySocket(protocol string, host string, port int) (proxySocket, error) { 310 | endTime := time.Now().Add(listenTimeout) 311 | for { 312 | remaining := endTime.Sub(time.Now()) 313 | sock, err := innerProxySocket(protocol, host, port) 314 | if err != nil { 315 | // TODO(vish): don't retry if the socket is in use 316 | if endTime.After(time.Now()) { 317 | glog.Infof("ProxySocket retrying on error for %v: %v", remaining, err) 318 | time.Sleep(100 * time.Millisecond) 319 | continue 320 | } else { 321 | return nil, fmt.Errorf("ProxySocket retry timed out") 322 | } 323 | } 324 | return sock, err 325 | } 326 | 327 | } 328 | 329 | func innerProxySocket(protocol string, host string, port int) (proxySocket, error) { 330 | switch strings.ToUpper(protocol) { 331 | case "TCP": 332 | listener, err := net.Listen("tcp", net.JoinHostPort(host, strconv.Itoa(port))) 333 | if err != nil { 334 | return nil, err 335 | } 336 | return 
&tcpProxySocket{listener}, nil 337 | case "UDP": 338 | addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(host, strconv.Itoa(port))) 339 | if err != nil { 340 | return nil, err 341 | } 342 | conn, err := net.ListenUDP("udp", addr) 343 | if err != nil { 344 | return nil, err 345 | } 346 | return &udpProxySocket{conn}, nil 347 | } 348 | return nil, fmt.Errorf("Unknown protocol %q", protocol) 349 | } 350 | 351 | // Proxier is a simple proxy for TCP connections between a localhost:lport 352 | // and services that provide the actual implementations. 353 | type Proxier struct { 354 | loadBalancer LoadBalancer 355 | mu sync.Mutex // protects serviceMap 356 | serviceMap map[string]*serviceInfo 357 | address string 358 | // NOTE(vish): this ns probably should be part of the Service struct 359 | ns netns.NsHandle 360 | } 361 | 362 | // NOTE(vish): this ns probably should be part of the Service struct 363 | func (proxier *Proxier) SetNs(ns netns.NsHandle) { 364 | proxier.ns = ns 365 | } 366 | 367 | // NewProxier returns a new Proxier given a LoadBalancer and an 368 | // address on which to listen 369 | func NewProxier(loadBalancer LoadBalancer, address string) *Proxier { 370 | return &Proxier{ 371 | loadBalancer: loadBalancer, 372 | serviceMap: make(map[string]*serviceInfo), 373 | address: address, 374 | // NOTE(vish): this ns probably should be part of the Service struct 375 | ns: netns.None(), 376 | } 377 | } 378 | 379 | func copyBytes(in, out *net.TCPConn) { 380 | glog.Infof("Copying from %v <-> %v <-> %v <-> %v", 381 | in.RemoteAddr(), in.LocalAddr(), out.LocalAddr(), out.RemoteAddr()) 382 | if _, err := io.Copy(in, out); err != nil { 383 | glog.Errorf("I/O error: %v", err) 384 | } 385 | in.CloseRead() 386 | out.CloseWrite() 387 | } 388 | 389 | // StopProxy stops the proxy for the named service. 390 | func (proxier *Proxier) StopProxy(service string) error { 391 | // TODO: delete from map here? 
392 | info, found := proxier.getServiceInfo(service) 393 | if !found { 394 | return fmt.Errorf("unknown service: %s", service) 395 | } 396 | return proxier.stopProxyInternal(info) 397 | } 398 | 399 | func (proxier *Proxier) stopProxyInternal(info *serviceInfo) error { 400 | if !info.setActive(false) { 401 | return nil 402 | } 403 | glog.Infof("Removing service: %s", info.name) 404 | return info.socket.Close() 405 | } 406 | 407 | func (proxier *Proxier) getServiceInfo(service string) (*serviceInfo, bool) { 408 | proxier.mu.Lock() 409 | defer proxier.mu.Unlock() 410 | info, ok := proxier.serviceMap[service] 411 | return info, ok 412 | } 413 | 414 | func (proxier *Proxier) setServiceInfo(service string, info *serviceInfo) { 415 | proxier.mu.Lock() 416 | defer proxier.mu.Unlock() 417 | info.name = service 418 | proxier.serviceMap[service] = info 419 | } 420 | 421 | // used to globally lock around unused ports. Only used in testing. 422 | var unusedPortLock sync.Mutex 423 | 424 | // addServiceOnUnusedPort starts listening for a new service, returning the 425 | // port it's using. For testing on a system with unknown ports used. The timeout only applies to UDP 426 | // connections, for now. 
427 | func (proxier *Proxier) addServiceOnUnusedPort(service, protocol string, timeout time.Duration) (string, error) { 428 | unusedPortLock.Lock() 429 | defer unusedPortLock.Unlock() 430 | sock, err := newProxySocket(protocol, proxier.address, 0) 431 | if err != nil { 432 | return "", err 433 | } 434 | _, port, err := net.SplitHostPort(sock.Addr().String()) 435 | if err != nil { 436 | return "", err 437 | } 438 | portNum, err := strconv.Atoi(port) 439 | if err != nil { 440 | return "", err 441 | } 442 | proxier.setServiceInfo(service, &serviceInfo{ 443 | port: portNum, 444 | protocol: protocol, 445 | active: true, 446 | socket: sock, 447 | timeout: timeout, 448 | }) 449 | proxier.startAccepting(service, sock) 450 | return port, nil 451 | } 452 | 453 | func (proxier *Proxier) startAccepting(service string, sock proxySocket) { 454 | glog.Infof("Listening for %s on %s:%s", service, sock.Addr().Network(), sock.Addr().String()) 455 | go func(service string, proxier *Proxier) { 456 | defer util.HandleCrash() 457 | sock.ProxyLoop(service, proxier) 458 | }(service, proxier) 459 | } 460 | 461 | // How long we leave idle UDP connections open. 
462 | const udpIdleTimeout = 1 * time.Minute 463 | 464 | func (proxier *Proxier) AddService(service, protocol string, port int) (int, error) { 465 | glog.Infof("Adding proxy %s on %s:%d", service, proxier.address, port) 466 | if proxier.ns.IsOpen() { 467 | glog.Infof("Using namespace %v for proxy %s", proxier.ns, service) 468 | runtime.LockOSThread() 469 | defer runtime.UnlockOSThread() 470 | origns, err := netns.Get() 471 | if err != nil { 472 | return 0, err 473 | } 474 | err = netns.Set(proxier.ns) 475 | if err != nil { 476 | return 0, err 477 | } 478 | defer netns.Set(origns) 479 | } 480 | sock, err := newProxySocket(protocol, proxier.address, port) 481 | if err != nil { 482 | return 0, err 483 | } 484 | _, portStr, err := net.SplitHostPort(sock.Addr().String()) 485 | if err != nil { 486 | return 0, err 487 | } 488 | portNum, err := strconv.Atoi(portStr) 489 | if err != nil { 490 | return 0, err 491 | } 492 | proxier.setServiceInfo(service, &serviceInfo{ 493 | port: portNum, 494 | protocol: protocol, 495 | active: true, 496 | socket: sock, 497 | timeout: udpIdleTimeout, 498 | }) 499 | proxier.startAccepting(service, sock) 500 | return portNum, err 501 | } 502 | 503 | // OnUpdate manages the active set of service proxies. 504 | // Active service proxies are reinitialized if found in the update set or 505 | // shutdown if missing from the update set. 506 | func (proxier *Proxier) OnUpdate(services []api.Service) { 507 | glog.Infof("Received update notice: %+v", services) 508 | activeServices := util.StringSet{} 509 | for _, service := range services { 510 | activeServices.Insert(service.ID) 511 | info, exists := proxier.getServiceInfo(service.ID) 512 | // TODO: check health of the socket? What if ProxyLoop exited? 
513 | if exists && info.isActive() && info.port == service.Port { 514 | continue 515 | } 516 | if exists && info.port != service.Port { 517 | err := proxier.stopProxyInternal(info) 518 | if err != nil { 519 | glog.Errorf("error stopping %s: %v", info.name, err) 520 | } 521 | } 522 | glog.Infof("Adding a new service %s on %s port %d", service.ID, service.Protocol, service.Port) 523 | sock, err := newProxySocket(service.Protocol, proxier.address, service.Port) 524 | if err != nil { 525 | glog.Errorf("Failed to get a socket for %s: %+v", service.ID, err) 526 | continue 527 | } 528 | proxier.setServiceInfo(service.ID, &serviceInfo{ 529 | port: service.Port, 530 | protocol: service.Protocol, 531 | active: true, 532 | socket: sock, 533 | timeout: udpIdleTimeout, 534 | }) 535 | proxier.startAccepting(service.ID, sock) 536 | } 537 | proxier.mu.Lock() 538 | defer proxier.mu.Unlock() 539 | for name, info := range proxier.serviceMap { 540 | if !activeServices.Has(name) { 541 | err := proxier.stopProxyInternal(info) 542 | if err != nil { 543 | glog.Errorf("error stopping %s: %v", info.name, err) 544 | } 545 | } 546 | } 547 | } 548 | -------------------------------------------------------------------------------- /pkg/proxy/proxier_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2014 Google Inc. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | package proxy 18 | 19 | import ( 20 | "fmt" 21 | "io/ioutil" 22 | "net" 23 | "net/http" 24 | "net/http/httptest" 25 | "net/url" 26 | "strconv" 27 | "testing" 28 | "time" 29 | 30 | "github.com/GoogleCloudPlatform/kubernetes/pkg/api" 31 | ) 32 | 33 | func waitForClosedPortTCP(p *Proxier, proxyPort string) error { 34 | for i := 0; i < 50; i++ { 35 | conn, err := net.Dial("tcp", net.JoinHostPort("127.0.0.1", proxyPort)) 36 | if err != nil { 37 | return nil 38 | } 39 | conn.Close() 40 | time.Sleep(1 * time.Millisecond) 41 | } 42 | return fmt.Errorf("port %s still open", proxyPort) 43 | } 44 | 45 | func waitForClosedPortUDP(p *Proxier, proxyPort string) error { 46 | for i := 0; i < 50; i++ { 47 | conn, err := net.Dial("udp", net.JoinHostPort("127.0.0.1", proxyPort)) 48 | if err != nil { 49 | return nil 50 | } 51 | conn.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) 52 | // To detect a closed UDP port write, then read. 53 | _, err = conn.Write([]byte("x")) 54 | if err != nil { 55 | if e, ok := err.(net.Error); ok && !e.Timeout() { 56 | return nil 57 | } 58 | } 59 | var buf [4]byte 60 | _, err = conn.Read(buf[0:]) 61 | if err != nil { 62 | if e, ok := err.(net.Error); ok && !e.Timeout() { 63 | return nil 64 | } 65 | } 66 | conn.Close() 67 | time.Sleep(1 * time.Millisecond) 68 | } 69 | return fmt.Errorf("port %s still open", proxyPort) 70 | } 71 | 72 | var tcpServerPort string 73 | var udpServerPort string 74 | 75 | func init() { 76 | // TCP setup. 77 | tcp := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 78 | w.WriteHeader(http.StatusOK) 79 | w.Write([]byte(r.URL.Path[1:])) 80 | })) 81 | u, err := url.Parse(tcp.URL) 82 | if err != nil { 83 | panic(fmt.Sprintf("failed to parse: %v", err)) 84 | } 85 | _, tcpServerPort, err = net.SplitHostPort(u.Host) 86 | if err != nil { 87 | panic(fmt.Sprintf("failed to parse: %v", err)) 88 | } 89 | 90 | // UDP setup. 
91 | udp, err := newUDPEchoServer() 92 | if err != nil { 93 | panic(fmt.Sprintf("failed to make a UDP server: %v", err)) 94 | } 95 | _, udpServerPort, err = net.SplitHostPort(udp.LocalAddr().String()) 96 | if err != nil { 97 | panic(fmt.Sprintf("failed to parse: %v", err)) 98 | } 99 | go udp.Loop() 100 | } 101 | 102 | func testEchoTCP(t *testing.T, address, port string) { 103 | path := "aaaaa" 104 | res, err := http.Get("http://" + address + ":" + port + "/" + path) 105 | if err != nil { 106 | t.Fatalf("error connecting to server: %v", err) 107 | } 108 | defer res.Body.Close() 109 | data, err := ioutil.ReadAll(res.Body) 110 | if err != nil { 111 | t.Errorf("error reading data: %v %v", err, string(data)) 112 | } 113 | if string(data) != path { 114 | t.Errorf("expected: %s, got %s", path, string(data)) 115 | } 116 | } 117 | 118 | func testEchoUDP(t *testing.T, address, port string) { 119 | data := "abc123" 120 | 121 | conn, err := net.Dial("udp", net.JoinHostPort(address, port)) 122 | if err != nil { 123 | t.Fatalf("error connecting to server: %v", err) 124 | } 125 | if _, err := conn.Write([]byte(data)); err != nil { 126 | t.Fatalf("error sending to server: %v", err) 127 | } 128 | var resp [1024]byte 129 | n, err := conn.Read(resp[0:]) 130 | if err != nil { 131 | t.Errorf("error receiving data: %v", err) 132 | } 133 | if string(resp[0:n]) != data { 134 | t.Errorf("expected: %s, got %s", data, string(resp[0:n])) 135 | } 136 | } 137 | 138 | func TestTCPProxy(t *testing.T) { 139 | lb := NewLoadBalancerRR() 140 | lb.OnUpdate([]api.Endpoints{ 141 | { 142 | JSONBase: api.JSONBase{ID: "echo"}, 143 | Endpoints: []string{net.JoinHostPort("127.0.0.1", tcpServerPort)}, 144 | }, 145 | }) 146 | 147 | p := NewProxier(lb, "127.0.0.1") 148 | 149 | proxyPort, err := p.addServiceOnUnusedPort("echo", "TCP", 0) 150 | if err != nil { 151 | t.Fatalf("error adding new service: %#v", err) 152 | } 153 | testEchoTCP(t, "127.0.0.1", proxyPort) 154 | } 155 | 156 | func TestUDPProxy(t 
*testing.T) { 157 | lb := NewLoadBalancerRR() 158 | lb.OnUpdate([]api.Endpoints{ 159 | { 160 | JSONBase: api.JSONBase{ID: "echo"}, 161 | Endpoints: []string{net.JoinHostPort("127.0.0.1", udpServerPort)}, 162 | }, 163 | }) 164 | 165 | p := NewProxier(lb, "127.0.0.1") 166 | 167 | proxyPort, err := p.addServiceOnUnusedPort("echo", "UDP", time.Second) 168 | if err != nil { 169 | t.Fatalf("error adding new service: %#v", err) 170 | } 171 | testEchoUDP(t, "127.0.0.1", proxyPort) 172 | } 173 | 174 | func TestTCPProxyStop(t *testing.T) { 175 | lb := NewLoadBalancerRR() 176 | lb.OnUpdate([]api.Endpoints{ 177 | { 178 | JSONBase: api.JSONBase{ID: "echo"}, 179 | Endpoints: []string{net.JoinHostPort("127.0.0.1", tcpServerPort)}, 180 | }, 181 | }) 182 | 183 | p := NewProxier(lb, "127.0.0.1") 184 | 185 | proxyPort, err := p.addServiceOnUnusedPort("echo", "TCP", 0) 186 | if err != nil { 187 | t.Fatalf("error adding new service: %#v", err) 188 | } 189 | conn, err := net.Dial("tcp", net.JoinHostPort("127.0.0.1", proxyPort)) 190 | if err != nil { 191 | t.Fatalf("error connecting to proxy: %v", err) 192 | } 193 | conn.Close() 194 | 195 | p.StopProxy("echo") 196 | // Wait for the port to really close. 
197 | if err := waitForClosedPortTCP(p, proxyPort); err != nil { 198 | t.Fatalf(err.Error()) 199 | } 200 | } 201 | 202 | func TestUDPProxyStop(t *testing.T) { 203 | lb := NewLoadBalancerRR() 204 | lb.OnUpdate([]api.Endpoints{ 205 | { 206 | JSONBase: api.JSONBase{ID: "echo"}, 207 | Endpoints: []string{net.JoinHostPort("127.0.0.1", udpServerPort)}, 208 | }, 209 | }) 210 | 211 | p := NewProxier(lb, "127.0.0.1") 212 | 213 | proxyPort, err := p.addServiceOnUnusedPort("echo", "UDP", time.Second) 214 | if err != nil { 215 | t.Fatalf("error adding new service: %#v", err) 216 | } 217 | conn, err := net.Dial("udp", net.JoinHostPort("127.0.0.1", proxyPort)) 218 | if err != nil { 219 | t.Fatalf("error connecting to proxy: %v", err) 220 | } 221 | conn.Close() 222 | 223 | p.StopProxy("echo") 224 | // Wait for the port to really close. 225 | if err := waitForClosedPortUDP(p, proxyPort); err != nil { 226 | t.Fatalf(err.Error()) 227 | } 228 | } 229 | 230 | func TestTCPProxyUpdateDelete(t *testing.T) { 231 | lb := NewLoadBalancerRR() 232 | lb.OnUpdate([]api.Endpoints{ 233 | { 234 | JSONBase: api.JSONBase{ID: "echo"}, 235 | Endpoints: []string{net.JoinHostPort("127.0.0.1", tcpServerPort)}, 236 | }, 237 | }) 238 | 239 | p := NewProxier(lb, "127.0.0.1") 240 | 241 | proxyPort, err := p.addServiceOnUnusedPort("echo", "TCP", 0) 242 | if err != nil { 243 | t.Fatalf("error adding new service: %#v", err) 244 | } 245 | conn, err := net.Dial("tcp", net.JoinHostPort("127.0.0.1", proxyPort)) 246 | if err != nil { 247 | t.Fatalf("error connecting to proxy: %v", err) 248 | } 249 | conn.Close() 250 | 251 | p.OnUpdate([]api.Service{}) 252 | if err := waitForClosedPortTCP(p, proxyPort); err != nil { 253 | t.Fatalf(err.Error()) 254 | } 255 | } 256 | 257 | func TestUDPProxyUpdateDelete(t *testing.T) { 258 | lb := NewLoadBalancerRR() 259 | lb.OnUpdate([]api.Endpoints{ 260 | { 261 | JSONBase: api.JSONBase{ID: "echo"}, 262 | Endpoints: []string{net.JoinHostPort("127.0.0.1", udpServerPort)}, 263 | }, 264 | 
}) 265 | 266 | p := NewProxier(lb, "127.0.0.1") 267 | 268 | proxyPort, err := p.addServiceOnUnusedPort("echo", "UDP", time.Second) 269 | if err != nil { 270 | t.Fatalf("error adding new service: %#v", err) 271 | } 272 | conn, err := net.Dial("udp", net.JoinHostPort("127.0.0.1", proxyPort)) 273 | if err != nil { 274 | t.Fatalf("error connecting to proxy: %v", err) 275 | } 276 | conn.Close() 277 | 278 | p.OnUpdate([]api.Service{}) 279 | if err := waitForClosedPortUDP(p, proxyPort); err != nil { 280 | t.Fatalf(err.Error()) 281 | } 282 | } 283 | 284 | func TestTCPProxyUpdateDeleteUpdate(t *testing.T) { 285 | lb := NewLoadBalancerRR() 286 | lb.OnUpdate([]api.Endpoints{ 287 | { 288 | JSONBase: api.JSONBase{ID: "echo"}, 289 | Endpoints: []string{net.JoinHostPort("127.0.0.1", tcpServerPort)}, 290 | }, 291 | }) 292 | 293 | p := NewProxier(lb, "127.0.0.1") 294 | 295 | proxyPort, err := p.addServiceOnUnusedPort("echo", "TCP", 0) 296 | if err != nil { 297 | t.Fatalf("error adding new service: %#v", err) 298 | } 299 | conn, err := net.Dial("tcp", net.JoinHostPort("127.0.0.1", proxyPort)) 300 | if err != nil { 301 | t.Fatalf("error connecting to proxy: %v", err) 302 | } 303 | conn.Close() 304 | 305 | p.OnUpdate([]api.Service{}) 306 | if err := waitForClosedPortTCP(p, proxyPort); err != nil { 307 | t.Fatalf(err.Error()) 308 | } 309 | proxyPortNum, _ := strconv.Atoi(proxyPort) 310 | p.OnUpdate([]api.Service{ 311 | {JSONBase: api.JSONBase{ID: "echo"}, Port: proxyPortNum, Protocol: "TCP"}, 312 | }) 313 | testEchoTCP(t, "127.0.0.1", proxyPort) 314 | } 315 | 316 | func TestUDPProxyUpdateDeleteUpdate(t *testing.T) { 317 | lb := NewLoadBalancerRR() 318 | lb.OnUpdate([]api.Endpoints{ 319 | { 320 | JSONBase: api.JSONBase{ID: "echo"}, 321 | Endpoints: []string{net.JoinHostPort("127.0.0.1", udpServerPort)}, 322 | }, 323 | }) 324 | 325 | p := NewProxier(lb, "127.0.0.1") 326 | 327 | proxyPort, err := p.addServiceOnUnusedPort("echo", "UDP", time.Second) 328 | if err != nil { 329 | 
t.Fatalf("error adding new service: %#v", err) 330 | } 331 | conn, err := net.Dial("udp", net.JoinHostPort("127.0.0.1", proxyPort)) 332 | if err != nil { 333 | t.Fatalf("error connecting to proxy: %v", err) 334 | } 335 | conn.Close() 336 | 337 | p.OnUpdate([]api.Service{}) 338 | if err := waitForClosedPortUDP(p, proxyPort); err != nil { 339 | t.Fatalf(err.Error()) 340 | } 341 | proxyPortNum, _ := strconv.Atoi(proxyPort) 342 | p.OnUpdate([]api.Service{ 343 | {JSONBase: api.JSONBase{ID: "echo"}, Port: proxyPortNum, Protocol: "UDP"}, 344 | }) 345 | testEchoUDP(t, "127.0.0.1", proxyPort) 346 | } 347 | 348 | func TestTCPProxyUpdatePort(t *testing.T) { 349 | lb := NewLoadBalancerRR() 350 | lb.OnUpdate([]api.Endpoints{ 351 | { 352 | JSONBase: api.JSONBase{ID: "echo"}, 353 | Endpoints: []string{net.JoinHostPort("127.0.0.1", tcpServerPort)}, 354 | }, 355 | }) 356 | 357 | p := NewProxier(lb, "127.0.0.1") 358 | 359 | proxyPort, err := p.addServiceOnUnusedPort("echo", "TCP", 0) 360 | if err != nil { 361 | t.Fatalf("error adding new service: %#v", err) 362 | } 363 | 364 | // add a new dummy listener in order to get a port that is free 365 | l, _ := net.Listen("tcp", ":0") 366 | _, newPort, _ := net.SplitHostPort(l.Addr().String()) 367 | newPortNum, _ := strconv.Atoi(newPort) 368 | l.Close() 369 | 370 | // Wait for the socket to actually get free. 371 | if err := waitForClosedPortTCP(p, newPort); err != nil { 372 | t.Fatalf(err.Error()) 373 | } 374 | if proxyPort == newPort { 375 | t.Errorf("expected difference, got %s %s", newPort, proxyPort) 376 | } 377 | p.OnUpdate([]api.Service{ 378 | {JSONBase: api.JSONBase{ID: "echo"}, Port: newPortNum, Protocol: "TCP"}, 379 | }) 380 | if err := waitForClosedPortTCP(p, proxyPort); err != nil { 381 | t.Fatalf(err.Error()) 382 | } 383 | testEchoTCP(t, "127.0.0.1", newPort) 384 | 385 | // Ensure the old port is released and re-usable. 
386 | l, err = net.Listen("tcp", net.JoinHostPort("", proxyPort)) 387 | if err != nil { 388 | t.Fatalf("can't claim released port: %s", err) 389 | } 390 | l.Close() 391 | } 392 | 393 | func TestUDPProxyUpdatePort(t *testing.T) { 394 | lb := NewLoadBalancerRR() 395 | lb.OnUpdate([]api.Endpoints{ 396 | { 397 | JSONBase: api.JSONBase{ID: "echo"}, 398 | Endpoints: []string{net.JoinHostPort("127.0.0.1", udpServerPort)}, 399 | }, 400 | }) 401 | 402 | p := NewProxier(lb, "127.0.0.1") 403 | 404 | proxyPort, err := p.addServiceOnUnusedPort("echo", "UDP", time.Second) 405 | if err != nil { 406 | t.Fatalf("error adding new service: %#v", err) 407 | } 408 | 409 | // add a new dummy listener in order to get a port that is free 410 | pc, _ := net.ListenPacket("udp", ":0") 411 | _, newPort, _ := net.SplitHostPort(pc.LocalAddr().String()) 412 | newPortNum, _ := strconv.Atoi(newPort) 413 | pc.Close() 414 | 415 | // Wait for the socket to actually get free. 416 | if err := waitForClosedPortUDP(p, newPort); err != nil { 417 | t.Fatalf(err.Error()) 418 | } 419 | if proxyPort == newPort { 420 | t.Errorf("expected difference, got %s %s", newPort, proxyPort) 421 | } 422 | p.OnUpdate([]api.Service{ 423 | {JSONBase: api.JSONBase{ID: "echo"}, Port: newPortNum, Protocol: "UDP"}, 424 | }) 425 | if err := waitForClosedPortUDP(p, proxyPort); err != nil { 426 | t.Fatalf(err.Error()) 427 | } 428 | testEchoUDP(t, "127.0.0.1", newPort) 429 | 430 | // Ensure the old port is released and re-usable. 431 | pc, err = net.ListenPacket("udp", net.JoinHostPort("", proxyPort)) 432 | if err != nil { 433 | t.Fatalf("can't claim released port: %s", err) 434 | } 435 | pc.Close() 436 | } 437 | 438 | // TODO: Test UDP timeouts. 439 | -------------------------------------------------------------------------------- /pkg/proxy/roundrobin.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2014 Google Inc. All rights reserved. 
3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package proxy 18 | 19 | import ( 20 | "errors" 21 | "net" 22 | "reflect" 23 | "strconv" 24 | "sync" 25 | 26 | "github.com/GoogleCloudPlatform/kubernetes/pkg/api" 27 | "github.com/golang/glog" 28 | "github.com/vishvananda/netns" 29 | ) 30 | 31 | var ( 32 | ErrMissingServiceEntry = errors.New("missing service entry") 33 | ErrMissingEndpoints = errors.New("missing endpoints") 34 | ) 35 | 36 | // LoadBalancerRR is a round-robin load balancer. 37 | type LoadBalancerRR struct { 38 | lock sync.RWMutex 39 | endpointsMap map[string][]string 40 | rrIndex map[string]int 41 | } 42 | 43 | // NewLoadBalancerRR returns a new LoadBalancerRR. 44 | func NewLoadBalancerRR() *LoadBalancerRR { 45 | return &LoadBalancerRR{ 46 | endpointsMap: make(map[string][]string), 47 | rrIndex: make(map[string]int), 48 | } 49 | } 50 | 51 | // NextEndpoint returns a service endpoint. 52 | // The service endpoint is chosen using the round-robin algorithm. 
53 | func (lb *LoadBalancerRR) NextEndpoint(service string, srcAddr net.Addr) (netns.NsHandle, string, error) { 54 | ns := netns.None() 55 | lb.lock.RLock() 56 | endpoints, exists := lb.endpointsMap[service] 57 | index := lb.rrIndex[service] 58 | lb.lock.RUnlock() 59 | if !exists { 60 | return ns, "", ErrMissingServiceEntry 61 | } 62 | if len(endpoints) == 0 { 63 | return ns, "", ErrMissingEndpoints 64 | } 65 | endpoint := endpoints[index] 66 | lb.lock.Lock() 67 | lb.rrIndex[service] = (index + 1) % len(endpoints) 68 | lb.lock.Unlock() 69 | return ns, endpoint, nil 70 | } 71 | 72 | func isValidEndpoint(spec string) bool { 73 | _, port, err := net.SplitHostPort(spec) 74 | if err != nil { 75 | return false 76 | } 77 | value, err := strconv.Atoi(port) 78 | if err != nil { 79 | return false 80 | } 81 | return value > 0 82 | } 83 | 84 | func filterValidEndpoints(endpoints []string) []string { 85 | var result []string 86 | for _, spec := range endpoints { 87 | if isValidEndpoint(spec) { 88 | result = append(result, spec) 89 | } 90 | } 91 | return result 92 | } 93 | 94 | // OnUpdate manages the registered service endpoints. 95 | // Registered endpoints are updated if found in the update set or 96 | // unregistered if missing from the update set. 97 | func (lb *LoadBalancerRR) OnUpdate(endpoints []api.Endpoints) { 98 | registeredEndpoints := make(map[string]bool) 99 | lb.lock.Lock() 100 | defer lb.lock.Unlock() 101 | // Update endpoints for services. 102 | for _, endpoint := range endpoints { 103 | existingEndpoints, exists := lb.endpointsMap[endpoint.ID] 104 | validEndpoints := filterValidEndpoints(endpoint.Endpoints) 105 | if !exists || !reflect.DeepEqual(existingEndpoints, validEndpoints) { 106 | glog.Infof("LoadBalancerRR: Setting endpoints for %s to %+v", endpoint.ID, endpoint.Endpoints) 107 | lb.endpointsMap[endpoint.ID] = validEndpoints 108 | // Reset the round-robin index. 
109 | lb.rrIndex[endpoint.ID] = 0 110 | } 111 | registeredEndpoints[endpoint.ID] = true 112 | } 113 | // Remove endpoints missing from the update. 114 | for k, v := range lb.endpointsMap { 115 | if _, exists := registeredEndpoints[k]; !exists { 116 | glog.Infof("LoadBalancerRR: Removing endpoints for %s -> %+v", k, v) 117 | delete(lb.endpointsMap, k) 118 | } 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /pkg/proxy/roundrobin_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2014 Google Inc. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | package proxy 18 | 19 | import ( 20 | "testing" 21 | 22 | "github.com/GoogleCloudPlatform/kubernetes/pkg/api" 23 | ) 24 | 25 | func TestValidateWorks(t *testing.T) { 26 | if isValidEndpoint("") { 27 | t.Errorf("Didn't fail for empty string") 28 | } 29 | if isValidEndpoint("foobar") { 30 | t.Errorf("Didn't fail with no port") 31 | } 32 | if isValidEndpoint("foobar:-1") { 33 | t.Errorf("Didn't fail with a negative port") 34 | } 35 | if !isValidEndpoint("foobar:8080") { 36 | t.Errorf("Failed a valid config.") 37 | } 38 | } 39 | 40 | func TestFilterWorks(t *testing.T) { 41 | endpoints := []string{"foobar:1", "foobar:2", "foobar:-1", "foobar:3", "foobar:-2"} 42 | filtered := filterValidEndpoints(endpoints) 43 | 44 | if len(filtered) != 3 { 45 | t.Errorf("Failed to filter to the correct size") 46 | } 47 | if filtered[0] != "foobar:1" { 48 | t.Errorf("Index zero is not foobar:1") 49 | } 50 | if filtered[1] != "foobar:2" { 51 | t.Errorf("Index one is not foobar:2") 52 | } 53 | if filtered[2] != "foobar:3" { 54 | t.Errorf("Index two is not foobar:3") 55 | } 56 | } 57 | 58 | func TestLoadBalanceFailsWithNoEndpoints(t *testing.T) { 59 | loadBalancer := NewLoadBalancerRR() 60 | var endpoints []api.Endpoints 61 | loadBalancer.OnUpdate(endpoints) 62 | _, endpoint, err := loadBalancer.NextEndpoint("foo", nil) 63 | if err == nil { 64 | t.Errorf("Didn't fail with non-existent service") 65 | } 66 | if len(endpoint) != 0 { 67 | t.Errorf("Got an endpoint") 68 | } 69 | } 70 | 71 | func expectEndpoint(t *testing.T, loadBalancer *LoadBalancerRR, service string, expected string) { 72 | _, endpoint, err := loadBalancer.NextEndpoint(service, nil) 73 | if err != nil { 74 | t.Errorf("Didn't find a service for %s, expected %s, failed with: %v", service, expected, err) 75 | } 76 | if endpoint != expected { 77 | t.Errorf("Didn't get expected endpoint for service %s, expected %s, got: %s", service, expected, endpoint) 78 | } 79 | } 80 | 81 | func 
TestLoadBalanceWorksWithSingleEndpoint(t *testing.T) { 82 | loadBalancer := NewLoadBalancerRR() 83 | _, endpoint, err := loadBalancer.NextEndpoint("foo", nil) 84 | if err == nil || len(endpoint) != 0 { 85 | t.Errorf("Didn't fail with non-existent service") 86 | } 87 | endpoints := make([]api.Endpoints, 1) 88 | endpoints[0] = api.Endpoints{ 89 | JSONBase: api.JSONBase{ID: "foo"}, 90 | Endpoints: []string{"endpoint1:40"}, 91 | } 92 | loadBalancer.OnUpdate(endpoints) 93 | expectEndpoint(t, loadBalancer, "foo", "endpoint1:40") 94 | expectEndpoint(t, loadBalancer, "foo", "endpoint1:40") 95 | expectEndpoint(t, loadBalancer, "foo", "endpoint1:40") 96 | expectEndpoint(t, loadBalancer, "foo", "endpoint1:40") 97 | } 98 | 99 | func TestLoadBalanceWorksWithMultipleEndpoints(t *testing.T) { 100 | loadBalancer := NewLoadBalancerRR() 101 | _, endpoint, err := loadBalancer.NextEndpoint("foo", nil) 102 | if err == nil || len(endpoint) != 0 { 103 | t.Errorf("Didn't fail with non-existent service") 104 | } 105 | endpoints := make([]api.Endpoints, 1) 106 | endpoints[0] = api.Endpoints{ 107 | JSONBase: api.JSONBase{ID: "foo"}, 108 | Endpoints: []string{"endpoint:1", "endpoint:2", "endpoint:3"}, 109 | } 110 | loadBalancer.OnUpdate(endpoints) 111 | expectEndpoint(t, loadBalancer, "foo", "endpoint:1") 112 | expectEndpoint(t, loadBalancer, "foo", "endpoint:2") 113 | expectEndpoint(t, loadBalancer, "foo", "endpoint:3") 114 | expectEndpoint(t, loadBalancer, "foo", "endpoint:1") 115 | } 116 | 117 | func TestLoadBalanceWorksWithMultipleEndpointsAndUpdates(t *testing.T) { 118 | loadBalancer := NewLoadBalancerRR() 119 | _, endpoint, err := loadBalancer.NextEndpoint("foo", nil) 120 | if err == nil || len(endpoint) != 0 { 121 | t.Errorf("Didn't fail with non-existent service") 122 | } 123 | endpoints := make([]api.Endpoints, 1) 124 | endpoints[0] = api.Endpoints{ 125 | JSONBase: api.JSONBase{ID: "foo"}, 126 | Endpoints: []string{"endpoint:1", "endpoint:2", "endpoint:3"}, 127 | } 128 | 
loadBalancer.OnUpdate(endpoints) 129 | expectEndpoint(t, loadBalancer, "foo", "endpoint:1") 130 | expectEndpoint(t, loadBalancer, "foo", "endpoint:2") 131 | expectEndpoint(t, loadBalancer, "foo", "endpoint:3") 132 | expectEndpoint(t, loadBalancer, "foo", "endpoint:1") 133 | expectEndpoint(t, loadBalancer, "foo", "endpoint:2") 134 | // Then update the configuration with one fewer endpoints, make sure 135 | // we start in the beginning again 136 | endpoints[0] = api.Endpoints{JSONBase: api.JSONBase{ID: "foo"}, 137 | Endpoints: []string{"endpoint:8", "endpoint:9"}, 138 | } 139 | loadBalancer.OnUpdate(endpoints) 140 | expectEndpoint(t, loadBalancer, "foo", "endpoint:8") 141 | expectEndpoint(t, loadBalancer, "foo", "endpoint:9") 142 | expectEndpoint(t, loadBalancer, "foo", "endpoint:8") 143 | expectEndpoint(t, loadBalancer, "foo", "endpoint:9") 144 | // Clear endpoints 145 | endpoints[0] = api.Endpoints{JSONBase: api.JSONBase{ID: "foo"}, Endpoints: []string{}} 146 | loadBalancer.OnUpdate(endpoints) 147 | 148 | _, endpoint, err = loadBalancer.NextEndpoint("foo", nil) 149 | if err == nil || len(endpoint) != 0 { 150 | t.Errorf("Didn't fail with non-existent service") 151 | } 152 | } 153 | 154 | func TestLoadBalanceWorksWithServiceRemoval(t *testing.T) { 155 | loadBalancer := NewLoadBalancerRR() 156 | _, endpoint, err := loadBalancer.NextEndpoint("foo", nil) 157 | if err == nil || len(endpoint) != 0 { 158 | t.Errorf("Didn't fail with non-existent service") 159 | } 160 | endpoints := make([]api.Endpoints, 2) 161 | endpoints[0] = api.Endpoints{ 162 | JSONBase: api.JSONBase{ID: "foo"}, 163 | Endpoints: []string{"endpoint:1", "endpoint:2", "endpoint:3"}, 164 | } 165 | endpoints[1] = api.Endpoints{ 166 | JSONBase: api.JSONBase{ID: "bar"}, 167 | Endpoints: []string{"endpoint:4", "endpoint:5"}, 168 | } 169 | loadBalancer.OnUpdate(endpoints) 170 | expectEndpoint(t, loadBalancer, "foo", "endpoint:1") 171 | expectEndpoint(t, loadBalancer, "foo", "endpoint:2") 172 | expectEndpoint(t, 
loadBalancer, "foo", "endpoint:3") 173 | expectEndpoint(t, loadBalancer, "foo", "endpoint:1") 174 | expectEndpoint(t, loadBalancer, "foo", "endpoint:2") 175 | 176 | expectEndpoint(t, loadBalancer, "bar", "endpoint:4") 177 | expectEndpoint(t, loadBalancer, "bar", "endpoint:5") 178 | expectEndpoint(t, loadBalancer, "bar", "endpoint:4") 179 | expectEndpoint(t, loadBalancer, "bar", "endpoint:5") 180 | expectEndpoint(t, loadBalancer, "bar", "endpoint:4") 181 | 182 | // Then update the configuration by removing foo 183 | loadBalancer.OnUpdate(endpoints[1:]) 184 | _, endpoint, err = loadBalancer.NextEndpoint("foo", nil) 185 | if err == nil || len(endpoint) != 0 { 186 | t.Errorf("Didn't fail with non-existent service") 187 | } 188 | 189 | // but bar is still there, and we continue RR from where we left off. 190 | expectEndpoint(t, loadBalancer, "bar", "endpoint:5") 191 | expectEndpoint(t, loadBalancer, "bar", "endpoint:4") 192 | expectEndpoint(t, loadBalancer, "bar", "endpoint:5") 193 | expectEndpoint(t, loadBalancer, "bar", "endpoint:4") 194 | } 195 | -------------------------------------------------------------------------------- /pkg/proxy/udp_server.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2014 Google Inc. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | package proxy 18 | 19 | import ( 20 | "fmt" 21 | "net" 22 | ) 23 | 24 | // udpEchoServer is a simple echo server in UDP, intended for testing the proxy. 25 | type udpEchoServer struct { 26 | net.PacketConn 27 | } 28 | 29 | func (r *udpEchoServer) Loop() { 30 | var buffer [4096]byte 31 | for { 32 | n, cliAddr, err := r.ReadFrom(buffer[0:]) 33 | if err != nil { 34 | fmt.Printf("ReadFrom failed: %#v\n", err) 35 | continue 36 | } 37 | r.WriteTo(buffer[0:n], cliAddr) 38 | } 39 | } 40 | 41 | func newUDPEchoServer() (*udpEchoServer, error) { 42 | packetconn, err := net.ListenPacket("udp", ":0") 43 | if err != nil { 44 | return nil, err 45 | } 46 | return &udpEchoServer{packetconn}, nil 47 | } 48 | 49 | /* 50 | func main() { 51 | r,_ := newUDPEchoServer() 52 | r.Loop() 53 | } 54 | */ 55 | -------------------------------------------------------------------------------- /pong/.gitignore: -------------------------------------------------------------------------------- 1 | pong 2 | -------------------------------------------------------------------------------- /pong/Dockerfile: -------------------------------------------------------------------------------- 1 | ############################################################ 2 | # Dockerfile with simple nc pong script 3 | # Based on Ubuntu Image 4 | ############################################################ 5 | 6 | FROM ubuntu 7 | 8 | MAINTAINER Vishvananda Ishaya 9 | 10 | ADD pong /usr/bin/pong 11 | 12 | ENTRYPOINT pong 13 | -------------------------------------------------------------------------------- /pong/Makefile: -------------------------------------------------------------------------------- 1 | all: docker-pong 2 | 3 | pong: pong.go 4 | go build pong.go 5 | 6 | .PHONY: docker-pong 7 | docker-pong: pong 8 | docker build -t wormhole/pong . 
9 | -------------------------------------------------------------------------------- /pong/pong.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | "io" 6 | "os" 7 | "os/signal" 8 | "net" 9 | "syscall" 10 | ) 11 | 12 | func main() { 13 | host := ":9001" 14 | if len(os.Args) > 1 { 15 | host = os.Args[1] 16 | } 17 | listener, err := net.Listen("tcp", host) 18 | if err != nil { 19 | log.Fatalf("Listen Error: %v", err) 20 | } 21 | defer listener.Close() 22 | 23 | csig := make(chan os.Signal, 1) 24 | signal.Notify(csig, os.Interrupt, syscall.SIGTERM, syscall.SIGKILL) 25 | go func() { 26 | <-csig 27 | listener.Close() 28 | os.Exit(0) 29 | }() 30 | 31 | for { 32 | conn, err := listener.Accept() 33 | if err != nil { 34 | log.Fatalf("Accept Error: %v", err) 35 | return 36 | } 37 | go func() { 38 | defer conn.Close() 39 | buf := make([]byte, 1024) 40 | n, err := conn.Read(buf) 41 | if err == io.EOF { 42 | return 43 | } 44 | if err != nil { 45 | log.Fatalf("Read Error: %v", err) 46 | } 47 | if n != 0 { 48 | _, err := conn.Write(buf[:n]) 49 | if err != nil { 50 | log.Fatalf("Write Error: %v", err) 51 | } 52 | os.Exit(0) 53 | } 54 | }() 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /server/api.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "bufio" 5 | "encoding/gob" 6 | "github.com/raff/tls-ext" 7 | "github.com/vishvananda/wormhole/client" 8 | "github.com/vishvananda/wormhole/utils" 9 | "io" 10 | "log" 11 | "net" 12 | "net/rpc" 13 | ) 14 | 15 | type Api int 16 | 17 | func (t *Api) Echo(args *client.EchoArgs, reply *client.EchoReply) (err error) { 18 | reply.Value, err = echo(args.Host, args.Value) 19 | return err 20 | } 21 | 22 | func (t *Api) CreateTunnel(args *client.CreateTunnelArgs, reply *client.CreateTunnelReply) (err error) { 23 | reply.Src, reply.Dst, err = 
createTunnel(args.Host, args.Udp) 24 | return err 25 | } 26 | 27 | func (t *Api) DeleteTunnel(args *client.DeleteTunnelArgs, reply *client.DeleteTunnelReply) (err error) { 28 | return deleteTunnel(args.Host) 29 | return err 30 | } 31 | 32 | func (t *Api) CreateSegment(args *client.CreateSegmentArgs, reply *client.CreateSegmentReply) (err error) { 33 | reply.Url, err = createSegment(args.Id, args.Init, args.Trig) 34 | return err 35 | } 36 | 37 | func (t *Api) DeleteSegment(args *client.DeleteSegmentArgs, reply *client.DeleteSegmentReply) (err error) { 38 | err = deleteSegment(args.Id) 39 | return err 40 | } 41 | 42 | func (t *Api) GetSrcIP(args *client.GetSrcIPArgs, reply *client.GetSrcIPReply) (err error) { 43 | reply.Src, err = getSrcIP(args.Dst) 44 | return err 45 | } 46 | 47 | func (t *Api) BuildTunnel(args *client.BuildTunnelArgs, reply *client.BuildTunnelReply) (err error) { 48 | reply.Src, reply.Tunnel, err = buildTunnel(args.Dst, args.Tunnel) 49 | return err 50 | } 51 | 52 | func (t *Api) DestroyTunnel(args *client.DestroyTunnelArgs, reply *client.DestroyTunnelReply) (err error) { 53 | reply.Src, err = destroyTunnel(args.Dst) 54 | return err 55 | } 56 | 57 | type gobServerCodec struct { 58 | rwc io.ReadWriteCloser 59 | dec *gob.Decoder 60 | enc *gob.Encoder 61 | encBuf *bufio.Writer 62 | } 63 | 64 | func (c *gobServerCodec) ReadRequestHeader(r *rpc.Request) error { 65 | return c.dec.Decode(r) 66 | } 67 | 68 | func (c *gobServerCodec) ReadRequestBody(body interface{}) error { 69 | return c.dec.Decode(body) 70 | } 71 | 72 | func (c *gobServerCodec) WriteResponse(r *rpc.Response, body interface{}) (err error) { 73 | if err = c.enc.Encode(r); err != nil { 74 | return 75 | } 76 | if err = c.enc.Encode(body); err != nil { 77 | return 78 | } 79 | return c.encBuf.Flush() 80 | } 81 | 82 | func (c *gobServerCodec) Close() error { 83 | return c.rwc.Close() 84 | } 85 | 86 | func handle(conn net.Conn) { 87 | buf := bufio.NewWriter(conn) 88 | srv := &gobServerCodec{conn, 
gob.NewDecoder(conn), gob.NewEncoder(buf), buf} 89 | rpc.ServeCodec(srv) 90 | } 91 | 92 | var listener net.Listener 93 | 94 | func serveAPI() { 95 | rpc.Register(new(Api)) 96 | proto, address := utils.ParseAddr(opts.hosts[0]) 97 | var err error 98 | listener, err = tls.Listen(proto, address, opts.config) 99 | if err != nil { 100 | log.Fatalf("Listen: %v", err) 101 | } 102 | defer listener.Close() 103 | for { 104 | conn, err := listener.Accept() 105 | if err != nil { 106 | log.Fatalf("Accept: %v", err) 107 | return 108 | } 109 | go handle(conn) 110 | } 111 | } 112 | 113 | func shutdownAPI() { 114 | listener.Close() 115 | } 116 | -------------------------------------------------------------------------------- /server/echo.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | 7 | "github.com/golang/glog" 8 | "github.com/vishvananda/wormhole/client" 9 | "github.com/vishvananda/wormhole/utils" 10 | ) 11 | 12 | func echo(host string, value []byte) ([]byte, error) { 13 | glog.Infof("Echo called with: %v %v", host, value) 14 | if host == "" { 15 | return value, nil 16 | } else { 17 | host, err := utils.ValidateAddr(host) 18 | if err != nil { 19 | return nil, err 20 | } 21 | c, err := client.NewClient(host, opts.config) 22 | if err != nil { 23 | return nil, err 24 | } 25 | response, err := c.Echo(value, "") 26 | if err != nil { 27 | return nil, err 28 | } 29 | if !bytes.Equal(value, response) { 30 | return response, fmt.Errorf("Incorrect response from echo") 31 | } 32 | glog.Infof("Echo response is: %v", response) 33 | return response, nil 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /server/opts.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "flag" 5 | "io/ioutil" 6 | "log" 7 | "net" 8 | "os" 9 | "strconv" 10 | "strings" 11 | 12 | 
"github.com/raff/tls-ext" 13 | "github.com/raff/tls-psk" 14 | "github.com/vishvananda/wormhole/utils" 15 | ) 16 | 17 | type options struct { 18 | hosts []string 19 | src net.IP 20 | external net.IP 21 | cidr *net.IPNet 22 | config *tls.Config 23 | udpStartPort int 24 | udpEndPort int 25 | } 26 | 27 | var opts *options 28 | 29 | func parseFlags() { 30 | keyfile := flag.String("K", "/etc/wormhole/key.secret", "Keyfile for psk auth (if not found defaults to insecure key)") 31 | src := flag.String("I", "", "Internal Ip for tunnel (defaults to src of default route)") 32 | external := flag.String("E", "", "External Ip for tunnel (defaults to src of default route)") 33 | cidr := flag.String("C", "100.65.0.0/14", "Cidr for overlay ips (must be the same on all hosts)") 34 | ports := flag.String("P", "4500-4599", "Inclusive port range for udp tunnels") 35 | hosts := utils.NewListOpts(utils.ValidateAddr) 36 | flag.Var(&hosts, "H", "Multiple tcp://host:port or unix://path/to/socket to bind") 37 | 38 | flag.Parse() 39 | if hosts.Len() == 0 { 40 | hosts.Set("") 41 | } 42 | 43 | var srcIP net.IP 44 | if *src == "" { 45 | var err error 46 | srcIP, err = getSource(nil) 47 | if err != nil { 48 | log.Fatalf("Failed to find default route ip. 
Please specify -I") 49 | } 50 | } else { 51 | log.Printf("Got a source ip of %v", *src) 52 | srcIP = net.ParseIP(*src) 53 | if srcIP == nil { 54 | log.Fatalf("Invalid source IP for tunnels: %v", src) 55 | } 56 | } 57 | var externalIP net.IP 58 | if *external == "" { 59 | externalIP = srcIP 60 | } else { 61 | log.Printf("Got an external ip of %v", *external) 62 | externalIP = net.ParseIP(*external) 63 | if externalIP == nil { 64 | log.Fatalf("Invalid external IP for tunnels: %v", external) 65 | } 66 | } 67 | _, cidrNet, err := net.ParseCIDR(*cidr) 68 | if err != nil { 69 | log.Fatalf("Failed to parse -C: %v", err) 70 | } 71 | portParts := strings.Split(*ports, "-") 72 | startPort, err := strconv.Atoi(portParts[0]) 73 | if err != nil { 74 | log.Fatalf("Port range %s is not valid: %v", ports, err) 75 | } 76 | endPort := startPort 77 | if len(portParts) > 1 { 78 | endPort, err = strconv.Atoi(portParts[1]) 79 | if err != nil { 80 | log.Fatalf("Port range %s is not valid: %v", ports, err) 81 | } 82 | } 83 | 84 | key := "wormhole" 85 | b, err := ioutil.ReadFile(*keyfile) 86 | if err != nil { 87 | log.Printf("Failed to open keyfile %s: %v", *keyfile, err) 88 | log.Printf("** WARNING: USING INSECURE PRE-SHARED-KEY **") 89 | } else { 90 | key = string(b) 91 | } 92 | 93 | var config = &tls.Config{ 94 | CipherSuites: []uint16{psk.TLS_PSK_WITH_AES_128_CBC_SHA}, 95 | Certificates: []tls.Certificate{tls.Certificate{}}, 96 | Extra: psk.PSKConfig{ 97 | GetKey: func(id string) ([]byte, error) { 98 | return []byte(key), nil 99 | }, 100 | GetIdentity: func() string { 101 | name, err := os.Hostname() 102 | if err != nil { 103 | log.Printf("Failed to determine hostname: %v", err) 104 | return "wormhole" 105 | } 106 | return name 107 | }, 108 | }, 109 | } 110 | 111 | opts = &options{ 112 | hosts: hosts.GetAll(), 113 | src: srcIP, 114 | external: externalIP, 115 | cidr: cidrNet, 116 | config: config, 117 | udpStartPort: startPort, 118 | udpEndPort: endPort, 119 | } 120 | } 121 | 
-------------------------------------------------------------------------------- /server/segment.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "os/exec" 7 | "strconv" 8 | "strings" 9 | "sync" 10 | 11 | "github.com/golang/glog" 12 | "github.com/vishvananda/netns" 13 | "github.com/vishvananda/wormhole/client" 14 | "github.com/vishvananda/wormhole/pkg/proxy" 15 | "github.com/vishvananda/wormhole/utils" 16 | ) 17 | 18 | var segmentsMutex sync.Mutex 19 | var segments map[string]*Segment 20 | 21 | func initSegments() { 22 | segments = make(map[string]*Segment) 23 | } 24 | 25 | func cleanupSegments() { 26 | for id, s := range segments { 27 | glog.Infof("Cleaning segment %s", id) 28 | s.Cleanup() 29 | glog.Infof("Finished cleaning segment %s", id) 30 | } 31 | segmentsMutex.Lock() 32 | defer segmentsMutex.Unlock() 33 | for id, _ := range segments { 34 | glog.Infof("Deleting segment %s", id) 35 | delete(segments, id) 36 | glog.Infof("Finished deleting segment %s", id) 37 | } 38 | } 39 | 40 | func addSegment(key string, segment *Segment) { 41 | segmentsMutex.Lock() 42 | defer segmentsMutex.Unlock() 43 | segments[key] = segment 44 | } 45 | 46 | func getSegment(key string) *Segment { 47 | return segments[key] 48 | } 49 | 50 | func removeSegment(key string) { 51 | segmentsMutex.Lock() 52 | defer segmentsMutex.Unlock() 53 | delete(segments, key) 54 | } 55 | 56 | type ConnectionInfo struct { 57 | Proto string 58 | Ns netns.NsHandle 59 | Hostname string 60 | Port int 61 | } 62 | 63 | type Segment struct { 64 | Head ConnectionInfo 65 | Tail ConnectionInfo 66 | Init []client.SegmentCommand 67 | Trig []client.SegmentCommand 68 | ChildHost string 69 | ChildId string 70 | Proxy *proxy.Proxier 71 | DockerIds []string 72 | } 73 | 74 | func (s Segment) String() string { 75 | var initstring, trigstring string 76 | for _, a := range s.Init { 77 | initstring += fmt.Sprintf("%s: %v ", 
client.CommandName[a.Type], a) 78 | } 79 | for _, a := range s.Trig { 80 | trigstring += fmt.Sprintf("%s: %v ", client.CommandName[a.Type], a) 81 | } 82 | return fmt.Sprintf("{%v %v [%s] [%s]}", s.Head, s.Tail, strings.TrimSpace(initstring), strings.TrimSpace(trigstring)) 83 | } 84 | 85 | func (s *Segment) Cleanup() { 86 | if s.Proxy != nil { 87 | s.Proxy.StopProxy("segment") 88 | s.Proxy = nil 89 | } 90 | if s.ChildId != "" { 91 | if s.ChildHost == "" { 92 | deleteSegment(s.ChildId) 93 | } else { 94 | c, err := client.NewClient(s.ChildHost, opts.config) 95 | if err != nil { 96 | glog.Errorf("Failed to connect to child host at %s: %v", s.ChildHost, err) 97 | } else { 98 | c.DeleteSegment(s.ChildId) 99 | } 100 | } 101 | } 102 | if len(s.DockerIds) != 0 { 103 | args := []string{"rm", "-f"} 104 | args = append(args, s.DockerIds...) 105 | out, err := exec.Command("docker", args...).CombinedOutput() 106 | if err != nil { 107 | glog.Errorf("Error deleting docker container %v: %s", err, out) 108 | } 109 | } 110 | if s.Head.Ns.IsOpen() { 111 | s.Head.Ns.Close() 112 | } 113 | if s.Tail.Ns.IsOpen() { 114 | s.Tail.Ns.Close() 115 | } 116 | } 117 | 118 | func NewSegment() *Segment { 119 | return &Segment{Head: ConnectionInfo{Ns: netns.None()}, Tail: ConnectionInfo{Ns: netns.None()}} 120 | } 121 | 122 | func createSegment(id string, init []client.SegmentCommand, trig []client.SegmentCommand) (string, error) { 123 | cinfo, err := createSegmentLocal(id, init, trig, nil) 124 | if err != nil { 125 | return "", err 126 | } 127 | return fmt.Sprintf("%s://%s:%d", cinfo.Proto, cinfo.Hostname, cinfo.Port), nil 128 | } 129 | 130 | func createSegmentLocal(id string, init []client.SegmentCommand, trig []client.SegmentCommand, cinfo *ConnectionInfo) (*ConnectionInfo, error) { 131 | exists := getSegment(id) 132 | if exists != nil { 133 | return nil, fmt.Errorf("Segment %s already exists", id) 134 | } 135 | glog.Infof("Creating segment %s", id) 136 | s := NewSegment() 137 | if cinfo != nil { 
138 | s.Head = *cinfo 139 | } 140 | s.Init = init 141 | s.Trig = trig 142 | err := s.Initialize() 143 | if err != nil { 144 | return nil, err 145 | } 146 | s.Proxy = proxy.NewProxier(s, s.Head.Hostname) 147 | s.Proxy.SetNs(s.Head.Ns) 148 | s.Head.Port, err = s.Proxy.AddService("segment", s.Head.Proto, s.Head.Port) 149 | if err != nil { 150 | return nil, err 151 | } 152 | addSegment(id, s) 153 | glog.Infof("Finished creating segment %s", id) 154 | return &s.Head, nil 155 | } 156 | 157 | func deleteSegment(id string) error { 158 | glog.Infof("Deleting segment %s", id) 159 | s := getSegment(id) 160 | if s != nil { 161 | s.Cleanup() 162 | } 163 | removeSegment(id) 164 | glog.Infof("Finished deleting segment %s", id) 165 | return nil 166 | } 167 | 168 | func executeCommands(commands *[]client.SegmentCommand, seg *Segment) error { 169 | // Range is not used to skip a copy 170 | for i := 0; i < len(*commands); i++ { 171 | var err error 172 | switch (*commands)[i].Type { 173 | case client.NONE: 174 | case client.DOCKER_NS: 175 | err = executeDockerNs(&(*commands)[i], seg) 176 | case client.DOCKER_RUN: 177 | err = executeDockerRun(&(*commands)[i], seg) 178 | case client.CHILD: 179 | err = executeChild(&(*commands)[i], seg, false) 180 | case client.CHAIN: 181 | err = executeChild(&(*commands)[i], seg, true) 182 | case client.REMOTE: 183 | err = executeRemote(&(*commands)[i], seg) 184 | case client.TUNNEL: 185 | err = executeTunnel(&(*commands)[i], seg, false) 186 | case client.UDPTUNNEL: 187 | err = executeTunnel(&(*commands)[i], seg, true) 188 | case client.URL: 189 | err = executeUrl(&(*commands)[i], seg) 190 | default: 191 | err = fmt.Errorf("Command type %d recognized", (*commands)[i].Type) 192 | } 193 | if err != nil { 194 | return err 195 | } 196 | } 197 | *commands = make([]client.SegmentCommand, 0) 198 | return nil 199 | } 200 | 201 | func (s *Segment) Initialize() error { 202 | err := executeCommands(&s.Init, s) 203 | if err != nil { 204 | return err 205 | } 206 | 
// Head proto defaults to tcp if not set 207 | if s.Head.Proto == "" { 208 | s.Head.Proto = "tcp" 209 | } 210 | if s.Head.Hostname == "" { 211 | s.Head.Hostname = "127.0.0.1" 212 | } 213 | return nil 214 | } 215 | 216 | func hostEqual(proto string, h1 string, h2 string) bool { 217 | // TODO: check all local addresses if h1 is 0.0.0.0 218 | if h1 == h2 { 219 | return true 220 | } 221 | if proto[:3] == "udp" { 222 | a1, err := net.ResolveUDPAddr(proto, h1) 223 | if err != nil { 224 | return false 225 | } 226 | a2, err := net.ResolveUDPAddr(proto, h2) 227 | if err != nil { 228 | return false 229 | } 230 | return a1.IP.Equal(a2.IP) && a1.Zone == a2.Zone && a1.Port == a2.Port 231 | } else if proto[:3] == "tcp" { 232 | a1, err := net.ResolveTCPAddr(proto, h1) 233 | if err != nil { 234 | return false 235 | } 236 | a2, err := net.ResolveTCPAddr(proto, h2) 237 | if err != nil { 238 | return false 239 | } 240 | return a1.IP.Equal(a2.IP) && a1.Zone == a2.Zone && a1.Port == a2.Port 241 | } 242 | return false 243 | } 244 | 245 | func (s *Segment) Trigger() error { 246 | err := executeCommands(&s.Trig, s) 247 | if err != nil { 248 | return err 249 | } 250 | // Tail proto defaults to Head proto if not set 251 | if s.Tail.Proto == "" { 252 | s.Tail.Proto = s.Head.Proto 253 | } 254 | if s.Tail.Hostname == "" { 255 | s.Tail.Hostname = "127.0.0.1" 256 | } 257 | // Tail port defaults to Head port if not set 258 | if s.Tail.Port == 0 { 259 | s.Tail.Port = s.Head.Port 260 | } 261 | host1 := net.JoinHostPort(s.Head.Hostname, strconv.Itoa(s.Head.Port)) 262 | host2 := net.JoinHostPort(s.Tail.Hostname, strconv.Itoa(s.Tail.Port)) 263 | 264 | if hostEqual(s.Head.Proto, host1, host2) && s.Head.Ns.Equal(s.Tail.Ns) { 265 | return fmt.Errorf("Cannot proxy to self") 266 | } 267 | return nil 268 | } 269 | 270 | // NextEndpoint is an implementation of the loadbalancer interface for proxy. 
271 | func (s *Segment) NextEndpoint(service string, srcAddr net.Addr) (netns.NsHandle, string, error) { 272 | err := s.Trigger() 273 | if err != nil { 274 | return netns.None(), "", err 275 | } 276 | host := net.JoinHostPort(s.Tail.Hostname, strconv.Itoa(s.Tail.Port)) 277 | return s.Tail.Ns, host, nil 278 | } 279 | 280 | func executeUrl(command *client.SegmentCommand, seg *Segment) error { 281 | ci := &seg.Head 282 | if command.Tail { 283 | ci = &seg.Tail 284 | } 285 | proto, ns, hostname, port, err := utils.ParseUrl(command.Arg) 286 | if err != nil { 287 | return err 288 | } 289 | if proto != "" { 290 | ci.Proto = proto 291 | } 292 | if ns != "" { 293 | var err error 294 | ci.Ns, err = netns.GetFromName(ns) 295 | if err != nil { 296 | return err 297 | } 298 | } 299 | if hostname != "" { 300 | ci.Hostname = hostname 301 | } 302 | if port != 0 { 303 | ci.Port = port 304 | } 305 | return nil 306 | } 307 | 308 | func executeDockerNs(command *client.SegmentCommand, seg *Segment) error { 309 | ci := &seg.Head 310 | if command.Tail { 311 | ci = &seg.Tail 312 | } 313 | var err error 314 | ci.Ns, err = netns.GetFromDocker(command.Arg) 315 | return err 316 | } 317 | 318 | func executeDockerRun(command *client.SegmentCommand, seg *Segment) error { 319 | ci := &seg.Head 320 | if command.Tail { 321 | ci = &seg.Tail 322 | } 323 | args := strings.Fields(command.Arg) 324 | // TODO: use the api here instead of shelling out 325 | args = append([]string{"run", "-d"}, args...) 
326 | out, err := exec.Command("docker", args...).Output() 327 | if err != nil { 328 | return err 329 | } 330 | id := strings.TrimSpace(string(out)) 331 | seg.DockerIds = append(seg.DockerIds, id) 332 | 333 | ci.Ns, err = netns.GetFromDocker(id) 334 | return err 335 | } 336 | 337 | func executeChild(command *client.SegmentCommand, seg *Segment, chain bool) error { 338 | id := utils.Uuid() 339 | cinfo, err := createSegmentLocal(id, command.ChildInit, command.ChildTrig, &seg.Tail) 340 | if err != nil { 341 | return err 342 | } 343 | if chain { 344 | seg.Tail = *cinfo 345 | } 346 | seg.ChildId = id 347 | return nil 348 | } 349 | 350 | func executeRemote(command *client.SegmentCommand, seg *Segment) error { 351 | c, err := client.NewClient(command.Arg, opts.config) 352 | if err != nil { 353 | return err 354 | } 355 | dst, err := c.GetSrcIP(opts.src) 356 | if err != nil { 357 | return err 358 | } 359 | urlCommand := client.SegmentCommand{Type: client.URL, Arg: dst.String()} 360 | command.ChildInit = append(command.ChildInit, urlCommand) 361 | id := utils.Uuid() 362 | url, err := c.CreateSegment(id, command.ChildInit, command.ChildTrig) 363 | if err != nil { 364 | return err 365 | } 366 | seg.Tail.Proto, _, seg.Tail.Hostname, seg.Tail.Port, err = utils.ParseUrl(url) 367 | if err != nil { 368 | return err 369 | } 370 | seg.ChildHost = command.Arg 371 | seg.ChildId = id 372 | return nil 373 | } 374 | 375 | func executeTunnel(command *client.SegmentCommand, seg *Segment, udp bool) error { 376 | _, dst, err := createTunnel(command.Arg, udp) 377 | if err != nil { 378 | return err 379 | } 380 | urlCommand := client.SegmentCommand{Type: client.URL, Arg: dst.String()} 381 | command.ChildInit = append(command.ChildInit, urlCommand) 382 | c, err := client.NewClient(command.Arg, opts.config) 383 | if err != nil { 384 | return err 385 | } 386 | id := utils.Uuid() 387 | url, err := c.CreateSegment(id, command.ChildInit, command.ChildTrig) 388 | if err != nil { 389 | return err 390 | 
} 391 | seg.Tail.Proto, _, seg.Tail.Hostname, seg.Tail.Port, err = utils.ParseUrl(url) 392 | if err != nil { 393 | return err 394 | } 395 | seg.ChildHost = command.Arg 396 | seg.ChildId = id 397 | return nil 398 | } 399 | -------------------------------------------------------------------------------- /server/segment_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "github.com/vishvananda/wormhole/client" 5 | "testing" 6 | ) 7 | 8 | func TestInitializeModifyCurrent(t *testing.T) { 9 | seg := Segment{} 10 | seg.Init = append(seg.Init, client.SegmentCommand{Type: client.URL, Arg: ":1"}) 11 | seg.Initialize() 12 | if seg.Head.Port != 1 { 13 | t.Fatal("Command did not modify value") 14 | } 15 | } 16 | 17 | func TestInitializeModifyTail(t *testing.T) { 18 | seg := Segment{} 19 | seg.Init = append(seg.Init, client.SegmentCommand{Type: client.URL, Arg: ":1", Tail: true}) 20 | seg.Initialize() 21 | if seg.Tail.Port != 1 { 22 | t.Fatal("Command did not modify value") 23 | } 24 | } 25 | 26 | func TestInitializeDouble(t *testing.T) { 27 | seg := Segment{} 28 | seg.Init = append(seg.Init, client.SegmentCommand{Type: client.URL, Arg: ":1"}) 29 | seg.Init = append(seg.Init, client.SegmentCommand{Type: client.URL, Arg: ":2"}) 30 | seg.Initialize() 31 | if seg.Head.Port != 2 { 32 | t.Fatal("Second command did not modify value") 33 | } 34 | } 35 | 36 | func TestInitializeEmpties(t *testing.T) { 37 | seg := Segment{} 38 | seg.Init = append(seg.Init, client.SegmentCommand{Type: client.URL, Arg: ":1"}) 39 | seg.Initialize() 40 | if len(seg.Init) != 0 { 41 | t.Fatal("Initialize commands still in queue") 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /server/server.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "fmt" 5 | "github.com/vishvananda/netlink" 6 | "net" 7 | "os" 8 | 
"os/signal" 9 | "syscall" 10 | ) 11 | 12 | func getSource(dest net.IP) (net.IP, error) { 13 | var source net.IP 14 | routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL) 15 | if err != nil { 16 | return nil, fmt.Errorf("Failed to get routes") 17 | } 18 | var link netlink.Link 19 | for _, route := range routes { 20 | if route.Dst == nil { 21 | link = &netlink.Dummy{netlink.LinkAttrs{Index: route.LinkIndex}} 22 | source = route.Src 23 | 24 | } else if route.Dst.Contains(dest) { 25 | link = &netlink.Dummy{netlink.LinkAttrs{Index: route.LinkIndex}} 26 | source = route.Src 27 | break 28 | } 29 | } 30 | if link == nil { 31 | return nil, fmt.Errorf("Failed to find route to target: %s", dest) 32 | } 33 | if source == nil { 34 | // no source in route to target so use the first ip from interface 35 | addrs, err := netlink.AddrList(link, netlink.FAMILY_ALL) 36 | if err != nil || len(addrs) == 0 { 37 | return nil, fmt.Errorf("Failed to find source ip for interface: %s", link) 38 | } 39 | source = addrs[0].IP 40 | } 41 | return source, nil 42 | } 43 | 44 | func Main() { 45 | parseFlags() 46 | 47 | csig := make(chan os.Signal, 1) 48 | signal.Notify(csig, os.Interrupt, syscall.SIGTERM, syscall.SIGKILL) 49 | go func() { 50 | <-csig 51 | cleanupSegments() 52 | cleanupTunnels() 53 | shutdownAPI() 54 | os.Exit(0) 55 | }() 56 | 57 | initTunnels() 58 | defer cleanupTunnels() 59 | 60 | initSegments() 61 | defer cleanupSegments() 62 | 63 | serveAPI() 64 | } 65 | -------------------------------------------------------------------------------- /server/tunnel.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "crypto/rand" 5 | "fmt" 6 | "math/big" 7 | "net" 8 | "sync" 9 | "syscall" 10 | 11 | "github.com/golang/glog" 12 | "github.com/vishvananda/netlink" 13 | "github.com/vishvananda/wormhole/client" 14 | "github.com/vishvananda/wormhole/pkg/netaddr" 15 | ) 16 | 17 | var tunnelsMutex sync.Mutex 18 | var tunnels 
map[string]*client.Tunnel 19 | var listeners map[string]int 20 | 21 | var usedIPsMutex sync.Mutex 22 | var usedIPs map[string]bool 23 | 24 | var unusedPortsMutex sync.Mutex 25 | var unusedPorts []int 26 | 27 | type IPInUse error 28 | type NoPortsAvailable error 29 | 30 | func initTunnels() { 31 | tunnels = make(map[string]*client.Tunnel) 32 | listeners = make(map[string]int) 33 | usedIPs = make(map[string]bool) 34 | for p := opts.udpStartPort; p <= opts.udpEndPort; p++ { 35 | unusedPorts = append(unusedPorts, p) 36 | } 37 | discoverTunnels() 38 | } 39 | 40 | func cleanupTunnels() { 41 | // Currently we leave tunnels in place 42 | } 43 | 44 | func addTunnel(key string, tunnel *client.Tunnel, listener int) { 45 | tunnelsMutex.Lock() 46 | defer tunnelsMutex.Unlock() 47 | tunnels[key] = tunnel 48 | listeners[key] = listener 49 | } 50 | 51 | func getTunnel(key string) *client.Tunnel { 52 | return tunnels[key] 53 | } 54 | 55 | func getListener(key string) int { 56 | return listeners[key] 57 | } 58 | 59 | func removeTunnel(key string) { 60 | tunnelsMutex.Lock() 61 | defer tunnelsMutex.Unlock() 62 | delete(tunnels, key) 63 | delete(listeners, key) 64 | } 65 | 66 | func reserveIP(ip net.IP) error { 67 | usedIPsMutex.Lock() 68 | defer usedIPsMutex.Unlock() 69 | ipStr := ip.String() 70 | exists := usedIPs[ipStr] 71 | if exists { 72 | return IPInUse(fmt.Errorf("IP %s is in use", ip)) 73 | } 74 | usedIPs[ipStr] = true 75 | return nil 76 | } 77 | 78 | func unreserveIP(ip net.IP) { 79 | usedIPsMutex.Lock() 80 | defer usedIPsMutex.Unlock() 81 | delete(usedIPs, ip.String()) 82 | } 83 | 84 | func allocatePort() (int, error) { 85 | unusedPortsMutex.Lock() 86 | defer unusedPortsMutex.Unlock() 87 | 88 | if len(unusedPorts) == 0 { 89 | return 0, NoPortsAvailable(fmt.Errorf("No ports available")) 90 | } 91 | var port int 92 | port, unusedPorts = unusedPorts[0], unusedPorts[1:] 93 | return port, nil 94 | } 95 | 96 | func releasePort(port int) { 97 | unusedPortsMutex.Lock() 98 | defer 
unusedPortsMutex.Unlock()
	unusedPorts = append(unusedPorts, port)
}

// discoverTunnels rebuilds the in-memory tunnel registry after a
// daemon restart by correlating loopback addresses inside the tunnel
// cidr with kernel routes, xfrm policies and xfrm states.
func discoverTunnels() {
	glog.Infof("Discovering existing tunnels")
	lo, err := netlink.LinkByName("lo")
	if err != nil {
		glog.Errorf("Failed to get loopback device: %v", err)
		return
	}
	addrs, err := netlink.AddrList(lo, netlink.FAMILY_ALL)
	if err != nil {
		glog.Errorf("Failed to get addrs: %v", err)
		return
	}
	routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
	if err != nil {
		glog.Errorf("Failed to get routes: %v", err)
		return
	}
	policies, err := netlink.XfrmPolicyList(netlink.FAMILY_ALL)
	if err != nil {
		glog.Errorf("Failed to get xfrm policies: %v", err)
		return
	}
	states, err := netlink.XfrmStateList(netlink.FAMILY_ALL)
	if err != nil {
		glog.Errorf("Failed to get xfrm states: %v", err)
		return
	}
	for _, addr := range addrs {
		// Any loopback address inside the tunnel cidr is the local
		// end of a (potential) previously-created tunnel.
		if opts.cidr.Contains(addr.IP) {
			tunnel := client.Tunnel{}
			tunnel.Src = addr.IP
			err := reserveIP(tunnel.Src)
			if err != nil {
				glog.Warningf("Duplicate tunnel ip detected: %v", tunnel.Src)
			}
			tunnel.Dst = nil
			glog.Infof("Potential tunnel found from %s", tunnel.Src)
			// The source route added by buildTunnelLocal gives the
			// tunnel-internal destination ip.
			for _, route := range routes {
				if route.Src == nil || !route.Src.Equal(tunnel.Src) {
					continue
				}
				tunnel.Dst = route.Dst.IP
				break
			}
			if tunnel.Dst == nil {
				glog.Warningf("could not find dst for tunnel src %s", tunnel.Src)
				continue
			}
			err = reserveIP(tunnel.Dst)
			if err != nil {
				glog.Warningf("Duplicate tunnel ip detected: %v", tunnel.Dst)
			}
			// The xfrm policy template holds the outer (real) remote ip.
			var dst net.IP
			for _, policy := range policies {
				if !policy.Dst.IP.Equal(tunnel.Dst) {
					continue
				}
				if len(policy.Tmpls) == 0 {
					glog.Warningf("Tunnel policy has no associated template")
					continue
				}
				dst =
policy.Tmpls[0].Dst 164 | break 165 | } 166 | if dst == nil { 167 | glog.Warningf("could not find ip for tunnel between %s and %s", tunnel.Src, tunnel.Dst) 168 | continue 169 | } 170 | for _, state := range states { 171 | if !state.Dst.Equal(dst) { 172 | continue 173 | } 174 | tunnel.Reqid = state.Reqid 175 | if state.Auth == nil { 176 | glog.Warningf("Tunnel state has no associated authentication entry") 177 | continue 178 | } 179 | tunnel.AuthKey = state.Auth.Key 180 | if state.Crypt == nil { 181 | glog.Warningf("Tunnel state has no associated encryption entry") 182 | continue 183 | } 184 | tunnel.EncKey = state.Crypt.Key 185 | if state.Encap != nil { 186 | tunnel.SrcPort = state.Encap.SrcPort 187 | tunnel.SrcPort = state.Encap.DstPort 188 | } 189 | glog.Infof("Discovered tunnel between %v and %v over %v", tunnel.Src, tunnel.Dst, dst) 190 | var socket int 191 | if tunnel.SrcPort != 0 { 192 | socket, err = createEncapListener(tunnel.Src, tunnel.SrcPort) 193 | if err != nil { 194 | glog.Warningf("Failed to create udp listener: %v", err) 195 | } 196 | } 197 | addTunnel(dst.String(), &tunnel, socket) 198 | break 199 | } 200 | } 201 | } 202 | glog.Infof("Finished discovering existing tunnels") 203 | } 204 | 205 | func getLinkIndex(ip net.IP) (int, error) { 206 | links, err := netlink.LinkList() 207 | if err != nil { 208 | return -1, fmt.Errorf("Failed to get links") 209 | } 210 | for _, link := range links { 211 | addrs, err := netlink.AddrList(link, netlink.FAMILY_ALL) 212 | if err != nil { 213 | return -1, fmt.Errorf("Failed to get addrs") 214 | } 215 | for _, addr := range addrs { 216 | if addr.IP.Equal(ip) { 217 | return link.Attrs().Index, nil 218 | } 219 | } 220 | } 221 | return -1, fmt.Errorf("Could not find address") 222 | } 223 | 224 | func getSrcIP(dst net.IP) (net.IP, error) { 225 | if dst == nil { 226 | return opts.external, nil 227 | } 228 | tunnel := getTunnel(dst.String()) 229 | if tunnel == nil { 230 | return opts.src, nil 231 | } else { 232 | return 
tunnel.Src, nil 233 | } 234 | } 235 | 236 | func randomIPPair(cidr *net.IPNet) (first net.IP, second net.IP, err error) { 237 | ones, total := cidr.Mask.Size() 238 | max := int64((1 << uint64(total-ones-1))) 239 | value, err := rand.Int(rand.Reader, big.NewInt(max)) 240 | if err != nil { 241 | return 242 | } 243 | first = netaddr.IPAdd(cidr.IP, value.Uint64()*2+1) 244 | second = netaddr.IPAdd(first, 1) 245 | return 246 | } 247 | 248 | func randomKey() []byte { 249 | value := make([]byte, 32) 250 | rand.Read(value) 251 | return value 252 | } 253 | 254 | func getUnusedPort() int { 255 | // TODO: use a port range 256 | return 4500 257 | } 258 | 259 | func createTunnel(host string, udp bool) (net.IP, net.IP, error) { 260 | c, err := client.NewClient(host, opts.config) 261 | if err != nil { 262 | return nil, nil, err 263 | } 264 | defer c.Close() 265 | 266 | dst, err := c.GetSrcIP(nil) 267 | 268 | tunnel := &client.Tunnel{} 269 | 270 | exists := getTunnel(dst.String()) 271 | if exists != nil { 272 | glog.Infof("Tunnel already exists: %v, %v", exists.Src, exists.Dst) 273 | // tunnel dst and src are reversed from remote 274 | tunnel.Reqid = exists.Reqid 275 | tunnel.Src = exists.Dst 276 | tunnel.Dst = exists.Src 277 | tunnel.AuthKey = exists.AuthKey 278 | tunnel.EncKey = exists.EncKey 279 | tunnel.SrcPort = exists.DstPort 280 | tunnel.SrcPort = exists.SrcPort 281 | } else { 282 | tunnel = &client.Tunnel{} 283 | if udp { 284 | var err error 285 | tunnel.DstPort, err = allocatePort() 286 | if err != nil { 287 | glog.Errorf("No ports available: %v", dst) 288 | return nil, nil, err 289 | } 290 | glog.Infof("Using %d for encap port", tunnel.DstPort) 291 | } 292 | 293 | tunnel.AuthKey = randomKey() 294 | tunnel.EncKey = randomKey() 295 | // random number between 1 and 2^32 296 | bigreq, err := rand.Int(rand.Reader, big.NewInt(int64(^uint32(0)))) 297 | if err != nil { 298 | glog.Errorf("Failed to generate reqid: %v", err) 299 | return nil, nil, err 300 | } 301 | tunnel.Reqid = 
int(bigreq.Int64()) + 1
	}

	// While tail not created
	for {
		if tunnel.Src == nil {
			// Select random pair of addresses from cidr
			for {
				tunnel.Dst, tunnel.Src, err = randomIPPair(opts.cidr)
				if err != nil {
					return nil, nil, err
				}
				err = reserveIP(tunnel.Dst)
				if err != nil {
					glog.Infof("IP in use: %v", tunnel.Dst)
					continue
				}
				err = reserveIP(tunnel.Src)
				if err != nil {
					unreserveIP(tunnel.Dst)
					glog.Infof("IP in use: %v", tunnel.Src)
					continue
				}
				break
			}
		}
		// create tail of tunnel
		var out *client.Tunnel
		dst, out, err = c.BuildTunnel(opts.external, tunnel)
		if err != nil {
			// NOTE(review): IPInUse is declared as `type IPInUse error`,
			// an interface identical to error, so this assertion holds
			// for ANY non-nil error and every remote failure takes the
			// retry path — confirm whether a concrete error type was
			// intended.
			_, ok := err.(IPInUse)
			if ok {
				// Remote rejected our ip pair; release it and retry
				// with a fresh pair.
				unreserveIP(tunnel.Dst)
				unreserveIP(tunnel.Src)
				tunnel.Src = nil
				if exists != nil {
					glog.Warningf("Destroying local tunnel due to remote ip conflict")
					destroyTunnel(dst)
					exists = nil
				}
				continue
			}
			glog.Errorf("Remote BuildTunnel failed: %v", err)
			// cleanup partial tunnel
			c.DestroyTunnel(opts.external)
			return nil, nil, err
		}
		if exists != nil && !out.Equal(tunnel) {
			glog.Warningf("Destroying remote mismatched tunnel")
			c.DestroyTunnel(opts.external)
			continue
		}
		tunnel = out
		break
	}

	// tunnel dst and src are reversed from remote
	tunnel.Src, tunnel.Dst = tunnel.Dst, tunnel.Src
	tunnel.SrcPort, tunnel.DstPort = tunnel.DstPort, tunnel.SrcPort
	if exists == nil {
		_, tunnel, err = buildTunnelLocal(dst, tunnel)
		if err != nil {
			glog.Errorf("Local buildTunnel failed: %v", err)
			c.DestroyTunnel(opts.external)
			destroyTunnel(dst)
			return nil, nil, err
		}
	}
	return tunnel.Src, tunnel.Dst, nil
}

// deleteTunnel asks the daemon at host to tear down its end of the
// tunnel to us, then tears down the local end.
func deleteTunnel(host string) error {
	c, err := client.NewClient(host, opts.config)
374 | if err != nil { 375 | return err 376 | } 377 | defer c.Close() 378 | 379 | dst, _ := c.DestroyTunnel(opts.external) 380 | if dst != nil { 381 | destroyTunnel(dst) 382 | } 383 | return nil 384 | } 385 | 386 | func createEncapListener(ip net.IP, port int) (int, error) { 387 | const ( 388 | UDP_ENCAP = 100 389 | UDP_ENCAP_ESPINUDP = 2 390 | ) 391 | s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, 0) 392 | if err != nil { 393 | return 0, err 394 | } 395 | err = syscall.SetsockoptInt(s, syscall.IPPROTO_UDP, UDP_ENCAP, UDP_ENCAP_ESPINUDP) 396 | if err != nil { 397 | return 0, err 398 | } 399 | var family int 400 | if len(ip) <= net.IPv4len { 401 | family = syscall.AF_INET 402 | } else if ip.To4() != nil { 403 | family = syscall.AF_INET 404 | } else { 405 | family = syscall.AF_INET6 406 | } 407 | var bindaddr syscall.Sockaddr 408 | switch family { 409 | case syscall.AF_INET: 410 | if len(ip) == 0 { 411 | ip = net.IPv4zero 412 | } 413 | sa := new(syscall.SockaddrInet4) 414 | for i := 0; i < net.IPv4len; i++ { 415 | sa.Addr[i] = ip[i] 416 | } 417 | sa.Port = port 418 | bindaddr = sa 419 | case syscall.AF_INET6: 420 | sa := new(syscall.SockaddrInet6) 421 | for i := 0; i < net.IPv6len; i++ { 422 | sa.Addr[i] = ip[i] 423 | } 424 | sa.Port = port 425 | // TODO: optionally allow zone for ipv6 426 | // sa.ZoneId = uint32(zoneToInt(zone)) 427 | bindaddr = sa 428 | } 429 | err = syscall.Bind(s, bindaddr) 430 | if err != nil { 431 | return 0, err 432 | } 433 | return s, nil 434 | } 435 | 436 | func deleteEncapListener(socket int) { 437 | err := syscall.Close(socket) 438 | if err != nil { 439 | glog.Warningf("Failed to delete tunnel udp listener: %v", err) 440 | } 441 | } 442 | 443 | func buildTunnel(dst net.IP, tunnel *client.Tunnel) (net.IP, *client.Tunnel, error) { 444 | exists := getTunnel(dst.String()) 445 | if exists != nil { 446 | glog.Infof("Tunnel already exists: %v, %v", exists.Src, exists.Dst) 447 | return opts.external, exists, nil 448 | } 449 | var err 
error
	if tunnel.DstPort != 0 {
		// Remote uses encap; allocate our side's encap port too.
		tunnel.SrcPort, err = allocatePort()
		if err != nil {
			glog.Errorf("No ports available: %v", tunnel.Dst)
			return nil, nil, err
		}
		glog.Infof("Using %d for encap port", tunnel.SrcPort)
	}
	err = reserveIP(tunnel.Dst)
	if err != nil {
		glog.Infof("IP in use: %v", tunnel.Dst)
		return nil, nil, err
	}
	err = reserveIP(tunnel.Src)
	if err != nil {
		unreserveIP(tunnel.Dst)
		glog.Infof("IP in use: %v", tunnel.Src)
		return nil, nil, err
	}
	return buildTunnelLocal(dst, tunnel)
}

// buildTunnelLocal wires up the local end of a tunnel whose ips and
// ports have already been reserved: encap listener, loopback address,
// source route, and xfrm policies/states. dst is the real remote ip.
func buildTunnelLocal(dst net.IP, tunnel *client.Tunnel) (net.IP, *client.Tunnel, error) {
	var socket int
	if tunnel.SrcPort != 0 {
		var err error
		socket, err = createEncapListener(tunnel.Src, tunnel.SrcPort)
		if err != nil {
			glog.Errorf("Failed to create udp listener: %v", err)
			return nil, nil, err
		}
	}
	addTunnel(dst.String(), tunnel, socket)

	src := opts.src

	srcNet := netlink.NewIPNet(tunnel.Src)
	dstNet := netlink.NewIPNet(tunnel.Dst)

	glog.Infof("Building tunnel: %v, %v", tunnel.Src, tunnel.Dst)
	// add IP address to loopback device
	lo, err := netlink.LinkByName("lo")
	if err != nil {
		glog.Errorf("Failed to get loopback device: %v", err)
		return nil, nil, err
	}
	err = netlink.AddrAdd(lo, &netlink.Addr{IPNet: srcNet})
	if err != nil {
		glog.Errorf("Failed to add %v to loopback: %v", tunnel.Src, err)
		return nil, nil, err
	}

	index, err := getLinkIndex(src)
	if err != nil {
		glog.Errorf("Failed to get link for address: %v", err)
		return nil, nil, err
	}
	// add source route to tunnel ips device
	route := &netlink.Route{
		Scope:     netlink.SCOPE_LINK,
		Src:       tunnel.Src,
		Dst:       dstNet,
		LinkIndex: index,
	}
	err =
netlink.RouteAdd(route) 515 | if err != nil { 516 | glog.Errorf("Failed to add route %v: %v", route, err) 517 | return nil, nil, err 518 | } 519 | 520 | for _, policy := range getPolicies(tunnel.Reqid, src, dst, srcNet, dstNet) { 521 | glog.Infof("building Policy: %v", policy) 522 | // create xfrm policy rules 523 | err = netlink.XfrmPolicyAdd(&policy) 524 | if err != nil { 525 | if err == syscall.EEXIST { 526 | glog.Infof("Skipped adding policy %v because it already exists", policy) 527 | } else { 528 | glog.Errorf("Failed to add policy %v: %v", policy, err) 529 | return nil, nil, err 530 | } 531 | } 532 | } 533 | for _, state := range getStates(tunnel.Reqid, src, dst, tunnel.SrcPort, tunnel.DstPort, tunnel.AuthKey, tunnel.EncKey) { 534 | glog.Infof("building State: %v", state) 535 | // crate xfrm state rules 536 | err = netlink.XfrmStateAdd(&state) 537 | if err != nil { 538 | if err == syscall.EEXIST { 539 | glog.Infof("Skipped adding state %v because it already exists", state) 540 | } else { 541 | glog.Errorf("Failed to add state %v: %v", state, err) 542 | return nil, nil, err 543 | } 544 | } 545 | } 546 | glog.Infof("Finished building tunnel: %v, %v", tunnel.Src, tunnel.Dst) 547 | return opts.external, tunnel, nil 548 | } 549 | 550 | func destroyTunnel(dst net.IP) (net.IP, error) { 551 | // Determine the src and dst ips for the tunnel 552 | key := dst.String() 553 | tunnel := getTunnel(dst.String()) 554 | if tunnel == nil { 555 | s := fmt.Sprintf("Failed to find tunnel to dst %s", dst) 556 | glog.Errorf(s) 557 | return nil, fmt.Errorf(s) 558 | } 559 | 560 | src := opts.src 561 | 562 | srcNet := netlink.NewIPNet(tunnel.Src) 563 | dstNet := netlink.NewIPNet(tunnel.Dst) 564 | 565 | glog.Infof("Destroying Tunnel: %v, %v", tunnel.Src, tunnel.Dst) 566 | 567 | for _, state := range getStates(tunnel.Reqid, src, dst, 0, 0, nil, nil) { 568 | // crate xfrm state rules 569 | err := netlink.XfrmStateDel(&state) 570 | if err != nil { 571 | glog.Errorf("Failed to delete state 
%v: %v", state, err) 572 | } 573 | } 574 | 575 | for _, policy := range getPolicies(tunnel.Reqid, src, dst, srcNet, dstNet) { 576 | // create xfrm policy rules 577 | err := netlink.XfrmPolicyDel(&policy) 578 | if err != nil { 579 | glog.Errorf("Failed to delete policy %v: %v", policy, err) 580 | } 581 | } 582 | 583 | index, err := getLinkIndex(src) 584 | if err != nil { 585 | glog.Errorf("Failed to get link for address: %v", err) 586 | } else { 587 | 588 | // del source route to tunnel ips device 589 | route := &netlink.Route{ 590 | Scope: netlink.SCOPE_LINK, 591 | Src: tunnel.Src, 592 | Dst: dstNet, 593 | LinkIndex: index, 594 | } 595 | err = netlink.RouteDel(route) 596 | if err != nil { 597 | glog.Errorf("Failed to delete route %v: %v", route, err) 598 | } 599 | } 600 | 601 | // del IP address to loopback device 602 | lo, err := netlink.LinkByName("lo") 603 | if err != nil { 604 | glog.Errorf("Failed to get loopback device: %v", err) 605 | } else { 606 | err = netlink.AddrDel(lo, &netlink.Addr{IPNet: srcNet}) 607 | if err != nil { 608 | glog.Errorf("Failed to delete %v from loopback: %v", tunnel.Src, err) 609 | } 610 | } 611 | if tunnel.SrcPort != 0 { 612 | deleteEncapListener(getListener(key)) 613 | releasePort(tunnel.SrcPort) 614 | } 615 | unreserveIP(tunnel.Src) 616 | unreserveIP(tunnel.Dst) 617 | removeTunnel(key) 618 | glog.Infof("Finished destroying tunnel: %v, %v", tunnel.Src, tunnel.Dst) 619 | return opts.external, nil 620 | } 621 | 622 | func getPolicies(reqid int, src net.IP, dst net.IP, srcNet *net.IPNet, dstNet *net.IPNet) []netlink.XfrmPolicy { 623 | policies := make([]netlink.XfrmPolicy, 0) 624 | out := netlink.XfrmPolicy{ 625 | Src: srcNet, 626 | Dst: dstNet, 627 | Dir: netlink.XFRM_DIR_OUT, 628 | } 629 | otmpl := netlink.XfrmPolicyTmpl{ 630 | Src: src, 631 | Dst: dst, 632 | Proto: netlink.XFRM_PROTO_ESP, 633 | Mode: netlink.XFRM_MODE_TUNNEL, 634 | Reqid: reqid, 635 | } 636 | out.Tmpls = append(out.Tmpls, otmpl) 637 | policies = append(policies, 
out) 638 | in := netlink.XfrmPolicy{ 639 | Src: dstNet, 640 | Dst: srcNet, 641 | Dir: netlink.XFRM_DIR_IN, 642 | } 643 | itmpl := netlink.XfrmPolicyTmpl{ 644 | Src: dst, 645 | Dst: src, 646 | Proto: netlink.XFRM_PROTO_ESP, 647 | Mode: netlink.XFRM_MODE_TUNNEL, 648 | Reqid: reqid, 649 | } 650 | in.Tmpls = append(in.Tmpls, itmpl) 651 | policies = append(policies, in) 652 | return policies 653 | } 654 | 655 | func getStates(reqid int, src net.IP, dst net.IP, srcPort int, dstPort int, authKey []byte, encKey []byte) []netlink.XfrmState { 656 | states := make([]netlink.XfrmState, 0) 657 | out := netlink.XfrmState{ 658 | Src: src, 659 | Dst: dst, 660 | Proto: netlink.XFRM_PROTO_ESP, 661 | Mode: netlink.XFRM_MODE_TUNNEL, 662 | Spi: reqid, 663 | Reqid: reqid, 664 | ReplayWindow: 32, 665 | Auth: &netlink.XfrmStateAlgo{ 666 | Name: "hmac(sha256)", 667 | Key: authKey, 668 | }, 669 | Crypt: &netlink.XfrmStateAlgo{ 670 | Name: "cbc(aes)", 671 | Key: encKey, 672 | }, 673 | } 674 | if srcPort != 0 && dstPort != 0 { 675 | out.Encap = &netlink.XfrmStateEncap{ 676 | Type: netlink.XFRM_ENCAP_ESPINUDP, 677 | SrcPort: srcPort, 678 | DstPort: dstPort, 679 | } 680 | } 681 | states = append(states, out) 682 | in := netlink.XfrmState{ 683 | Src: dst, 684 | Dst: src, 685 | Proto: netlink.XFRM_PROTO_ESP, 686 | Mode: netlink.XFRM_MODE_TUNNEL, 687 | Spi: reqid, 688 | Reqid: reqid, 689 | ReplayWindow: 32, 690 | Auth: &netlink.XfrmStateAlgo{ 691 | Name: "hmac(sha256)", 692 | Key: authKey, 693 | }, 694 | Crypt: &netlink.XfrmStateAlgo{ 695 | Name: "cbc(aes)", 696 | Key: encKey, 697 | }, 698 | } 699 | if srcPort != 0 && dstPort != 0 { 700 | in.Encap = &netlink.XfrmStateEncap{ 701 | Type: netlink.XFRM_ENCAP_ESPINUDP, 702 | SrcPort: dstPort, 703 | DstPort: srcPort, 704 | } 705 | } 706 | states = append(states, in) 707 | return states 708 | } 709 | -------------------------------------------------------------------------------- /utils/utils.go: 
-------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "crypto/rand" 5 | "fmt" 6 | "strconv" 7 | "strings" 8 | ) 9 | 10 | const ( 11 | DEFAULT_PORT = 9999 12 | DEFAULT_HOST = "" 13 | DEFAULT_UNIX = "/var/run/wormhole" 14 | ) 15 | 16 | func Uuid() string { 17 | b := make([]byte, 16) 18 | rand.Read(b) 19 | b[6] = (b[6] & 0x0f) | 0x40 20 | b[8] = (b[8] & 0x3f) | 0x80 21 | return fmt.Sprintf("%x-%x-%x-%x-%x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:]) 22 | } 23 | 24 | type ValidatorFctType func(val string) (string, error) 25 | 26 | // ListOpts type 27 | type ListOpts struct { 28 | values []string 29 | validator ValidatorFctType 30 | } 31 | 32 | func NewListOpts(validator ValidatorFctType) ListOpts { 33 | return ListOpts{ 34 | validator: validator, 35 | } 36 | } 37 | 38 | func (opts *ListOpts) String() string { 39 | return fmt.Sprintf("%v", []string(opts.values)) 40 | } 41 | 42 | // Set validates if needed the input value and add it to the 43 | // internal slice. 44 | func (opts *ListOpts) Set(value string) error { 45 | if opts.validator != nil { 46 | v, err := opts.validator(value) 47 | if err != nil { 48 | return err 49 | } 50 | value = v 51 | } 52 | opts.values = append(opts.values, value) 53 | return nil 54 | } 55 | 56 | // Delete remove the given element from the slice. 57 | func (opts *ListOpts) Delete(key string) { 58 | for i, k := range opts.values { 59 | if k == key { 60 | opts.values = append(opts.values[:i], opts.values[i+1:]...) 61 | return 62 | } 63 | } 64 | } 65 | 66 | // GetAll returns the values' slice. 67 | func (opts *ListOpts) GetAll() []string { 68 | return opts.values 69 | } 70 | 71 | // Len returns the amount of element in the slice. 
func (opts *ListOpts) Len() int {
	return len(opts.values)
}

// ValidateAddr normalizes a bind address to proto://host[:port],
// defaulting the protocol to tcp, the unix path to DEFAULT_UNIX, and
// host/port to DEFAULT_HOST/DEFAULT_PORT. Unparseable or zero ports
// silently fall back to DEFAULT_PORT.
func ValidateAddr(addr string) (string, error) {
	var (
		proto string
		host  string
		port  int
	)

	switch {
	case strings.HasPrefix(addr, "unix://"):
		proto = "unix"
		addr = strings.TrimPrefix(addr, "unix://")
		if addr == "" {
			addr = DEFAULT_UNIX
		}
	case strings.HasPrefix(addr, "tcp://"):
		proto = "tcp"
		addr = strings.TrimPrefix(addr, "tcp://")
	default:
		if strings.Contains(addr, "://") {
			return "", fmt.Errorf("Invalid bind address protocol: %s", addr)
		}
		proto = "tcp"
	}

	if proto != "unix" && strings.Contains(addr, ":") {
		hostParts := strings.Split(addr, ":")
		if len(hostParts) != 2 {
			return "", fmt.Errorf("Invalid bind address format: %s", addr)
		}
		if hostParts[0] != "" {
			host = hostParts[0]
		} else {
			host = DEFAULT_HOST
		}

		if p, err := strconv.Atoi(hostParts[1]); err == nil && p != 0 {
			port = p
		} else {
			port = DEFAULT_PORT
		}

	} else {
		host = addr
		port = DEFAULT_PORT
	}
	if proto == "unix" {
		return fmt.Sprintf("%s://%s", proto, host), nil
	}
	return fmt.Sprintf("%s://%s:%d", proto, host, port), nil
}

// ParseAddr splits proto://rest into its two parts.
// NOTE(review): res[1] is indexed unconditionally, so this panics if
// host contains no "://"; callers appear to pass addresses already
// normalized by ValidateAddr — confirm.
func ParseAddr(host string) (string, string) {
	res := strings.SplitN(host, "://", 2)
	return res[0], res[1]
}

// ParseUrl splits a segment url of the form proto://ns@host:port into
// its components; absent components are returned as zero values.
func ParseUrl(url string) (proto string, ns string, hostname string, port int, err error) {
	url = strings.TrimSpace(url)
	if len(url) == 0 {
		return
	}
	switch {
	case strings.HasPrefix(url, "unix://"):
		url = strings.TrimPrefix(url, "unix://")
		proto = "unix"
	case strings.HasPrefix(url, "tcp://"):
		url = strings.TrimPrefix(url, "tcp://")
		proto = "tcp"
	case strings.HasPrefix(url, "udp://"):
		url = strings.TrimPrefix(url,
"udp://") 146 | proto = "udp" 147 | default: 148 | if strings.Contains(url, "://") { 149 | err = fmt.Errorf("Invalid segment protocol: %s", url) 150 | return 151 | } 152 | } 153 | if strings.Contains(url, "@") { 154 | if proto == "unix" { 155 | err = fmt.Errorf("Namespace not supported in unix protocol") 156 | } 157 | nsParts := strings.Split(url, "@") 158 | if len(nsParts) != 2 { 159 | err = fmt.Errorf("Only one namespace is allowed") 160 | } 161 | ns = nsParts[0] 162 | url = nsParts[1] 163 | } 164 | n := len(url) - 1 165 | if n > 0 && url[0] == '[' && url[n] == ']' { 166 | url = url[1:n] 167 | } else { 168 | i := strings.LastIndex(url, ":") 169 | if i != -1 { 170 | if proto == "unix" { 171 | err = fmt.Errorf("Port not supported in unix protocol") 172 | } 173 | strPort := url[i+1:] 174 | url = url[0:i] 175 | if len(strPort) != 0 { 176 | if p, err := strconv.Atoi(strPort); err == nil && p != 0 { 177 | port = p 178 | } else { 179 | err = fmt.Errorf("Invalid value for port: %v", strPort) 180 | } 181 | } 182 | n := len(url) - 1 183 | if n > 0 && url[0] == '[' && url[n] == ']' { 184 | url = url[1:n] 185 | } else if strings.Contains(url, ":") { 186 | err = fmt.Errorf("Only one port is allowed") 187 | } 188 | } 189 | } 190 | if proto != "unix" && (strings.Contains(url, "[") || strings.Contains(url, "]")) { 191 | err = fmt.Errorf("Invalid characters in hostname") 192 | } 193 | hostname = url 194 | return 195 | } 196 | -------------------------------------------------------------------------------- /utils/utils_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func validate(t *testing.T, url string, proto string, ns string, hostname string, port int) { 8 | pr, n, h, po, err := ParseUrl(url) 9 | if err != nil { 10 | t.Fatal("Error parsing %v: %v", url, err) 11 | } 12 | if proto != pr { 13 | t.Fatalf("Protocol doesn't match for %v: (Expected: '%v' != Actual '%v')", url, 
proto, pr) 14 | } 15 | if ns != n { 16 | t.Fatalf("Namespace doesn't match for %v: (Expected: '%v' != Actual '%v')", url, ns, n) 17 | } 18 | if hostname != h { 19 | t.Fatalf("Hostname doesn't match for %v: (Expected: '%v' != Actual '%v')", url, hostname, h) 20 | } 21 | if port != po { 22 | t.Fatalf("Port doesn't match for %v: (Expected: '%v' != Actual '%v')", url, port, po) 23 | } 24 | } 25 | 26 | func errors(t *testing.T, url string) { 27 | pr, n, h, po, err := ParseUrl(url) 28 | if err == nil { 29 | t.Fatalf("No error for %v: (Actual '%v', '%v', '%v', '%v')", url, pr, n, h, po) 30 | } 31 | } 32 | 33 | func TestParseUrl(t *testing.T) { 34 | validate(t, "", "", "", "", 0) 35 | validate(t, ":40", "", "", "", 40) 36 | validate(t, "foo", "", "", "foo", 0) 37 | validate(t, "foo:", "", "", "foo", 0) 38 | validate(t, "foo:40", "", "", "foo", 40) 39 | validate(t, "ns@", "", "ns", "", 0) 40 | validate(t, "ns@foo", "", "ns", "foo", 0) 41 | validate(t, "ns@:40", "", "ns", "", 40) 42 | validate(t, "ns@foo:40", "", "ns", "foo", 40) 43 | validate(t, "tcp://", "tcp", "", "", 0) 44 | validate(t, "udp://", "udp", "", "", 0) 45 | validate(t, "unix://", "unix", "", "", 0) 46 | validate(t, "tcp://foo", "tcp", "", "foo", 0) 47 | validate(t, "tcp://:40", "tcp", "", "", 40) 48 | validate(t, "tcp://foo:40", "tcp", "", "foo", 40) 49 | validate(t, "tcp://ns@", "tcp", "ns", "", 0) 50 | validate(t, "tcp://ns@foo", "tcp", "ns", "foo", 0) 51 | validate(t, "tcp://ns@:40", "tcp", "ns", "", 40) 52 | validate(t, "tcp://ns@foo:40", "tcp", "ns", "foo", 40) 53 | validate(t, "[::1]:40", "", "", "::1", 40) 54 | } 55 | 56 | func TestParseUrlErrors(t *testing.T) { 57 | errors(t, "multiple@namespace@foo") 58 | errors(t, "invalid://host") 59 | errors(t, "multiple:ports:foo") 60 | errors(t, "unix://with@namepace") 61 | errors(t, "unix://with:port") 62 | errors(t, "::1:40") 63 | errors(t, "[bad]bracketing") 64 | } 65 | -------------------------------------------------------------------------------- 
/wordpress/Dockerfile:
--------------------------------------------------------------------------------
############################################################
# Dockerfile for wordpress
# Based on tutum/wordpress-stackable
############################################################

FROM tutum/wordpress-stackable

MAINTAINER Vishvananda Ishaya

# Database connection defaults, presumably consumed by the
# tutum/wordpress-stackable entrypoint (verify against the base image).
# Override at run time with `docker run -e DB_HOST=...` to point the
# demo at a real MySQL instance.
ENV DB_HOST 127.0.0.1
ENV DB_PORT 3306
ENV DB_USER admin
ENV DB_PASS simple
--------------------------------------------------------------------------------
/wordpress/Makefile:
--------------------------------------------------------------------------------
# Build the wormhole/wordpress demo image from this directory.
all: docker-wordpress

.PHONY: docker-wordpress
docker-wordpress:
	docker build -t wormhole/wordpress .
--------------------------------------------------------------------------------