├── shuttle-cli ├── shuttle-cli └── main.go ├── .travis.yml ├── GLOCKFILE ├── .gitignore ├── utils.go ├── LICENSE ├── testdata ├── vhost1.pem ├── vhost2.pem ├── vhost1.key └── vhost2.key ├── config.go ├── Makefile ├── main.go ├── log └── log.go ├── README.md ├── client ├── client.go └── config.go ├── balancer.go ├── proxy_bench_test.go ├── admin.go ├── server_test.go ├── backend.go ├── reverseproxy.go ├── http.go ├── registry.go ├── shuttle_test.go ├── service.go └── admin_test.go /shuttle-cli/shuttle-cli: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/litl/shuttle/HEAD/shuttle-cli/shuttle-cli -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | go: 3 | - 1.5.3 4 | install: 5 | - go get github.com/robfig/glock 6 | - make deps 7 | script: 8 | - make all fmt test 9 | -------------------------------------------------------------------------------- /GLOCKFILE: -------------------------------------------------------------------------------- 1 | github.com/fatih/color 95b468b5f34882796c597b718955603a584a9bd4 2 | github.com/gorilla/context a08edd30ad9e104612741163dc087a613829a23c 3 | github.com/gorilla/mux 270c42505a11c779b5a5aaecfa5ec717adac996e 4 | gopkg.in/check.v1 871360013c92e1c715c2de6d06b54899468a8a2d 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | 24 | # dist builds 25 | dist 26 | 
*.tar.gz 27 | -------------------------------------------------------------------------------- /utils.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "crypto/rand" 5 | "encoding/json" 6 | "fmt" 7 | "log" 8 | "strings" 9 | ) 10 | 11 | // marshal encodes i as indented JSON followed by a newline. On an encoding 12 | // error it logs the error and returns nil, so callers can test len() == 0. 13 | func marshal(i interface{}) []byte { 14 | jsonBytes, err := json.MarshalIndent(i, "", " ") 15 | if err != nil { 16 | log.Println("error encoding json:", err) 17 | return nil 18 | } 19 | return append(jsonBytes, '\n') 20 | } 21 | 22 | // random 64bit ID 23 | func genId() string { 24 | b := make([]byte, 8) 25 | rand.Read(b) 26 | return fmt.Sprintf("%x", b) 27 | } 28 | 29 | // remove empty strings from a []string 30 | func filterEmpty(a []string) []string { 31 | removed := 0 32 | for i := 0; i < len(a); i++ { 33 | if removed > 0 { 34 | a[i-removed] = a[i] 35 | } 36 | if len(strings.TrimSpace(a[i])) == 0 { 37 | removed++ 38 | } 39 | 40 | } 41 | return a[:len(a)-removed] 42 | } 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2015 litl, LLC. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /testdata/vhost1.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIDHjCCAgagAwIBAgIQbbl4Hl/rDdACZzHMe1rLbjANBgkqhkiG9w0BAQsFADAS 3 | MRAwDgYDVQQKEwdBY21lIENvMB4XDTE2MDEyNTIxNDUyNFoXDTI2MDEyMjIxNDUy 4 | NFowEjEQMA4GA1UEChMHQWNtZSBDbzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCC 5 | AQoCggEBAMiQVT4/fzMIBb+HFG5nuK/78Iega2GO8D/aL7JNDeW1czNP7lMHRrO8 6 | Z+jiJ5llqojHUYIp2mPkYg07IxI+4ur+YKS9BKyr5msyK0ibFWrC8k9t2dTYlJLb 7 | VKW/ys/nd1F5MMHSUgBpCM6ilPtXlZGmUR8wwDsg3Ekvmm98pL2bCp8qnKZHE7IQ 8 | 4xwfsRwrVNImWsDGQXtM30dmvYfhveZp2H205hAwLEEWcOFLLvKVfCvExYOT2UJi 9 | tHY2CzgSIBYnY4/Y4e8diSzxgbVEkjc2QHoRfd3PSESPVEO1xb3pU2H4RcS+bedY 10 | M3wXNPectpQi8Y56+PHJkUKMDqd/VLsCAwEAAaNwMG4wDgYDVR0PAQH/BAQDAgKk 11 | MBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1UdEwEB/wQFMAMBAf8wNgYDVR0RBC8w 12 | LYILdmhvc3QxLnRlc3SCD2FsdC52aG9zdDEudGVzdIINKi52aG9zdDEudGVzdDAN 13 | BgkqhkiG9w0BAQsFAAOCAQEApgLHr8VY9qheQWxeOn17fO0DgQRoY2s6qdXqhBJe 14 | QwBC/RHrB62zzkO3Q37AU31BhYCIf3YzOKvFSKeWhWIjsSG/hZt8w/jvbCohzHX8 15 | 0MsK9LaLG8AypYD6Ztqr2UT2ekwG5B/zSUTmYysLQaqdChXJypQ4Oerdmy78x9ep 16 | Tlwc+DE+QANYrj4lvN8h7/1lo27PEMGyeH0xLRdXZihOqFK4ktj5NmTOgX7I+hAT 17 | IhbW5MFVFpGHchcsIUeCNVzm3NP/DfbdUK0HodYKdXj70g9Uq6/CznhosqKhApg1 18 | 3Cxr/EriFzYU6gif/PtEmSeRNrQzYt9EsKll0DHAbW6O2Q== 19 | -----END CERTIFICATE----- 20 | 
-------------------------------------------------------------------------------- /testdata/vhost2.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIDHzCCAgegAwIBAgIRALemi25z2dFdIKse33bVL6kwDQYJKoZIhvcNAQELBQAw 3 | EjEQMA4GA1UEChMHQWNtZSBDbzAeFw0xNjAxMjUyMTQ1NTZaFw0yNjAxMjIyMTQ1 4 | NTZaMBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw 5 | ggEKAoIBAQDkdL5jBt0qpFRoviKcqxQsPdKPSlos3Vcin1IlIJa9uO3Dl0vTsHQL 6 | od8RtzUgADUmkQiv1rm4rWKCLXXeStpubLf3/BJb+J3kVxPHzrCbdpQ9uinkjunT 7 | PizVX0pVI8Jh4+8qVdI7kdKn8MVXqe71dJzviMTXd4YuUD5ikXzga0LNmOThM/vU 8 | zk/fV/aRdp6qXzoRsJO9Kg7+I9nl2kNIHJi4Hyrkl1+JgiHFy6eL4/Nf3u3/o/zx 9 | pfF/SwO8GbdzX+UkTd+w5u566+zrTauhoFCo/bJ4pumyPSKo96HG0gzM64ih65DV 10 | J522SvF26DOADkbvYDNzUVLU67HVfXfdAgMBAAGjcDBuMA4GA1UdDwEB/wQEAwIC 11 | pDATBgNVHSUEDDAKBggrBgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MDYGA1UdEQQv 12 | MC2CC3Zob3N0Mi50ZXN0gg9hbHQudmhvc3QyLnRlc3SCDSoudmhvc3QyLnRlc3Qw 13 | DQYJKoZIhvcNAQELBQADggEBABqWrJdNgpsikMx98KMEsOB+7kGLnxMJtnR6YvLV 14 | 1MmzsSthkqOr8yaSq+m5pdf4/a7n770zuyp8aXBO8JqCD0cN4Qyf5+PUZoWwX+KC 15 | sPKKxdHL4bgGha0ujlS7JwKRLXzHh5g6TaQ/lKSnliWPllQOpmqjfEt3S53oDZ+I 16 | qEd0HR50S3jTJsUJJ+az+rvDV65d0DYQRMvMUv3KzB5n1hWOgkf18OYVqdTk2Z5w 17 | W/7sp11B7eMXPnLRFM+rcRSXpiJRB2Ojc36k//x8ZQJUztaSlWKoIluAN8ToI9GY 18 | JmoHI6+tMhhntSKFn49F++vR2mXsGukTg0QhCHfTGigDquk= 19 | -----END CERTIFICATE----- 20 | -------------------------------------------------------------------------------- /config.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "io/ioutil" 7 | "sync" 8 | 9 | "github.com/litl/shuttle/client" 10 | "github.com/litl/shuttle/log" 11 | ) 12 | 13 | func loadConfig() { 14 | for _, cfgPath := range []string{stateConfig, defaultConfig} { 15 | if cfgPath == "" { 16 | continue 17 | } 18 | 19 | cfgData, err := ioutil.ReadFile(cfgPath) 20 | if err != nil { 21 | 
log.Warnln("Error reading config:", err) 22 | continue 23 | } 24 | var cfg client.Config 25 | err = json.Unmarshal(cfgData, &cfg) 26 | if err != nil { 27 | log.Warnln("Config error:", err) 28 | continue 29 | } 30 | log.Debug("Loaded config from:", cfgPath) 31 | if err := Registry.UpdateConfig(cfg); err != nil { 32 | log.Printf("Unable to load config: error: %s", err) 33 | continue 34 | } 35 | return // the first config applied wins; the default must not override saved state 36 | } 37 | } 38 | 39 | // protects the state config file 40 | var configMutex sync.Mutex 41 | 42 | func writeStateConfig() { 43 | configMutex.Lock() 44 | defer configMutex.Unlock() 45 | 46 | if stateConfig == "" { 47 | log.Debug("No state file. Not saving changes") 48 | return 49 | } 50 | 51 | cfg := marshal(Registry.Config()) 52 | if len(cfg) == 0 { 53 | return 54 | } 55 | 56 | lastCfg, _ := ioutil.ReadFile(stateConfig) 57 | if bytes.Equal(cfg, lastCfg) { 58 | log.Println("No change in config") 59 | return 60 | } 61 | 62 | // We should probably write a temp file and mv for atomic update. 63 | err := ioutil.WriteFile(stateConfig, cfg, 0644) 64 | if err != nil { 65 | log.Println("Error saving config state:", err) 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /testdata/vhost1.key: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIEpQIBAAKCAQEAyJBVPj9/MwgFv4cUbme4r/vwh6BrYY7wP9ovsk0N5bVzM0/u 3 | UwdGs7xn6OInmWWqiMdRginaY+RiDTsjEj7i6v5gpL0ErKvmazIrSJsVasLyT23Z 4 | 1NiUkttUpb/Kz+d3UXkwwdJSAGkIzqKU+1eVkaZRHzDAOyDcSS+ab3ykvZsKnyqc 5 | pkcTshDjHB+xHCtU0iZawMZBe0zfR2a9h+G95mnYfbTmEDAsQRZw4Usu8pV8K8TF 6 | g5PZQmK0djYLOBIgFidjj9jh7x2JLPGBtUSSNzZAehF93c9IRI9UQ7XFvelTYfhF 7 | xL5t51gzfBc095y2lCLxjnr48cmRQowOp39UuwIDAQABAoIBAQDFZXjgeTJCEaVG 8 | qjYrq54UZwyHEBZfwIUo8x96h2gkK4Akgoj34vNtNwO2K8/5pBxB3pqUV4kAQ+lV 9 | SFzuOkKwMoj/2qFdKRrxakE7hpd/qjs+fcmlOTyRhZk8QRXlpdTDtVmNiej3SmlG 10 | prGm5r7oyR6SajLofyEQTu/axnyFtvJa3TDDDJfEeccRPROfRxSSUCkXgp9TNw6F 11 |
kouiri84zKVuuQXPyUYCviiFp/Tpt4Lh5OXA4wvcHspPUjdPK266USOmgikSZzTY 12 | +u+0SJ8nGMAeYDiC6F50viJPlpidEBTRe0oNUdXljGZYOPNCBQG61nqi8ACRsyo+ 13 | N6SGLGnBAoGBANqCaoK2lKOE00+P2G/QeGEgbaehKBHwNcixg7tO3nLygotd21k/ 14 | Dv6trlfAyfpikdtcNHFqGBkKJ5Df8srJOM83IHMgSjshgoixHf36GpREZa/K4ZZU 15 | 7lmGIxLh5VfOYEwiqATpuk8CI3+z+Hqy73MOgcur+00YaGbKE6xti4gLAoGBAOr5 16 | sPMcpJ+PjyV4p9qakMvjjTwpdDMftwLJSY0xF3qVTVUAS6UeeoKYy7Iku86w3+4G 17 | BoXYoP/sXhJAINK+a/lVtnNewwCsBxnaGoXahE1HhMllHmXwTcsTPkKMzy6yc5ze 18 | E2WEfLzb2Jni554yUt6qa4ng9tWmg/0uKFgzpWQRAoGAXmBJxJ87X8z0v75vSwwN 19 | klXBRs+SUP0hHceeD/6mkZswyyUEom1b+p/lVz2Lfzunp8kRVZLvSZFbOXWglfmH 20 | MeireU4PAa8dhBCL1bB6XmOUT/MesCGKuNv4tiUfO2eFrByj2UtiDtHrpzKCNeyn 21 | A1jWsrNbXRcXsJ3DFYxS4bUCgYEAwxaD/5SsaX27j5TZZ/okdeN7g5O3Uirmu317 22 | f6pen/wNtKEGLRVdCcjqdgFhnH3lra17BO2S3mjUwbpUhiRraRvs22S16nzpeGFI 23 | 3BFM/wx+BufZkTEupYhYjNBzw4WNz5Ph7stM9VBiSYHGY+XMP+qmVlddGI2j0DTe 24 | cjyO+MECgYEAq2Or8wX0zqCHYinObsOGXwaOCleWb7Mjore3VMAtzajVKdSYE8DH 25 | yUleOXIlbb77Q/N4QezoSQLBvRQWov48DRGzwlW9AWH0MRBTeagzu0uL7+pKK4fi 26 | ATgCLtiMwsKjkOxiuhn+ikd2dDw/5LySYVjIFeVOlHalQ5HfSvcM0X4= 27 | -----END RSA PRIVATE KEY----- 28 | -------------------------------------------------------------------------------- /testdata/vhost2.key: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIEpAIBAAKCAQEA5HS+YwbdKqRUaL4inKsULD3Sj0paLN1XIp9SJSCWvbjtw5dL 3 | 07B0C6HfEbc1IAA1JpEIr9a5uK1igi113krabmy39/wSW/id5FcTx86wm3aUPbop 4 | 5I7p0z4s1V9KVSPCYePvKlXSO5HSp/DFV6nu9XSc74jE13eGLlA+YpF84GtCzZjk 5 | 4TP71M5P31f2kXaeql86EbCTvSoO/iPZ5dpDSByYuB8q5JdfiYIhxcuni+PzX97t 6 | /6P88aXxf0sDvBm3c1/lJE3fsObueuvs602roaBQqP2yeKbpsj0iqPehxtIMzOuI 7 | oeuQ1SedtkrxdugzgA5G72Azc1FS1Oux1X133QIDAQABAoIBAFDARy9/lJtm/IMN 8 | efSAsB+3Nn75nAgxsIQHZqTC8SVcgYZaKy5HN62I6O09IeUOzbq1Fyn4Lyts9d3n 9 | rbsGIFFZ0mkwS1kA9uZoNRCyKVC6SEnNTNOCBHprhrNg/Eg93I53X+lJ7oap05kT 10 | DN4grdtK/dHZOSKkF+S07mgu3sH/3L5tofJhfV4dv2INLnM35uYIVazITZhS+V3O 11 | 
KXue7DPiHyj+vwRBHlIObmgj4gu5NOychhsxhGRpDSxfqY+WeIyuIl4rqL3HiGdV 12 | hlqHGBLU5DihRvvIPIGlGot+ZrXKi5bX/62X0bifU5n6YK+K9W0N2P8Bxe8n+J6g 13 | 6Z1WbM0CgYEA59lzL4CypozEOMIbTPYB/KtKE3F/ZJdunj5kW2gx1SAaHVq5X6wg 14 | rtMQImd3jJ55Y2uES+XkY6HKGG6nXPpmMCiVkXrnUqUq4yBni2ljrPbICREp4fQZ 15 | G3UgbCVXMm4Mju13T7U5/gGsZ4V566HkbhbwKkCBvsVAZ94/c95T7aMCgYEA/EDO 16 | FGqwhjs4A2xbY+c36p1BVi11k1oyYFn+hAH9s/UgaG7Tiu+c6O2PZkgUUNyEKSL2 17 | faKIYX5NSBSH9Tpm0iaoFCiODHuf3KROdGeE2RVNDn8+ft8G/UjodgNVncrqf3AV 18 | wysBb8+oNaJ//QZzUmPasDqJOAgwddY4gvz8XH8CgYEAqf1KMfMHiYu+NutCvwvE 19 | azBfsJ/Pyr4o8cdHJ6nel6fg3dLuBZKbp/LCaqc4BRcQY2+qYUeeS9qM5ZsEBOzm 20 | zbqD51WYk4TcTAkvQg00ctXB1rwJ3ExvuC0JZ6F9LFF5zbWYfA2hBnbNpF0+BiD9 21 | 7iXNUv1W47uWPFG8bkT9fkcCgYBnkcQLJfLsagwJe8faMOkIbyCQXYHUyke8v7Z8 22 | RMUByjdQKZC5jsAB8ufZuuZ8fM3WhgBmfQE55j2cxrE7worM5gpEnJIWFfwA/4Um 23 | zgoBh3ln5l6mgLPB8tle+ueALfwx7rdAtruUSNJrkxixrqBSx6TWjnIgi1w6RYZW 24 | YcDLyQKBgQCGIE95UREAb2AMlpyqJO2Qk/LWOQz0RWvGuogdyvwnd+wNaVYqA6V4 25 | I22vrusTF/3wbYqVBJqAo5/PMW/oZfEY3jKCsjrp+0zyqEjLg9v5nFYM6eelf0mS 26 | y2HkwymByX31QCwBQw+1uJXS+ZQ1n66MH8H4TyU/7L4niD6iXn6sNQ== 27 | -----END RSA PRIVATE KEY----- 28 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .SILENT : 2 | .PHONY : shuttle fmt test dist-clean 3 | 4 | TAG :=`git describe --tags` 5 | 6 | VERSION = `git describe --tags 2>/dev/null || git rev-parse --short HEAD 2>/dev/null` 7 | LDFLAGS = -X main.buildVersion=$(VERSION) 8 | 9 | all: shuttle 10 | 11 | deps: 12 | glock sync github.com/litl/shuttle 13 | 14 | shuttle: 15 | echo "Building shuttle" 16 | go install -x -ldflags "$(LDFLAGS)" github.com/litl/shuttle 17 | 18 | fmt: 19 | go fmt github.com/litl/shuttle/... 
20 | 21 | test: 22 | go test -v github.com/litl/shuttle 23 | 24 | dist-clean: 25 | rm -rf dist 26 | rm -f shuttle-*.tar.gz 27 | 28 | dist-init: 29 | mkdir -p dist/$$GOOS/$$GOARCH 30 | 31 | dist-build: dist-init 32 | echo "Compiling $$GOOS/$$GOARCH" 33 | go build -a -ldflags "$(LDFLAGS)" -o dist/$$GOOS/$$GOARCH/shuttle github.com/litl/shuttle 34 | 35 | dist-linux-amd64: 36 | export GOOS="linux"; \ 37 | export GOARCH="amd64"; \ 38 | $(MAKE) dist-build 39 | 40 | dist-linux-386: 41 | export GOOS="linux"; \ 42 | export GOARCH="386"; \ 43 | $(MAKE) dist-build 44 | 45 | dist-darwin-amd64: 46 | export GOOS="darwin"; \ 47 | export GOARCH="amd64"; \ 48 | $(MAKE) dist-build 49 | 50 | dist: dist-clean dist-init dist-linux-amd64 dist-linux-386 dist-darwin-amd64 51 | 52 | release-tarball: 53 | echo "Building $$GOOS-$$GOARCH-$(TAG).tar.gz" 54 | GZIP=-9 tar -cvzf shuttle-$$GOOS-$$GOARCH-$(TAG).tar.gz -C dist/$$GOOS/$$GOARCH shuttle >/dev/null 2>&1 55 | 56 | release-linux-amd64: 57 | export GOOS="linux"; \ 58 | export GOARCH="amd64"; \ 59 | $(MAKE) release-tarball 60 | 61 | release-linux-386: 62 | export GOOS="linux"; \ 63 | export GOARCH="386"; \ 64 | $(MAKE) release-tarball 65 | 66 | release-darwin-amd64: 67 | export GOOS="darwin"; \ 68 | export GOARCH="amd64"; \ 69 | $(MAKE) release-tarball 70 | 71 | release: deps dist release-linux-amd64 release-linux-386 release-darwin-amd64 72 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "sync" 6 | 7 | "github.com/litl/shuttle/log" 8 | ) 9 | 10 | var ( 11 | // Location of the default config. 12 | // This will not be overwritten by shuttle. 13 | defaultConfig string 14 | 15 | // Location of the live config which is updated on every state change. 16 | // The default config is loaded if this file does not exist. 
17 | stateConfig string 18 | 19 | // Listen addresses for the http and https servers. 20 | httpAddr string 21 | httpsAddr string 22 | 23 | // Listen address for the admin http server. 24 | adminListenAddr string 25 | 26 | // Debug logging 27 | debug bool 28 | 29 | // Redirect to HTTPS endpoint 30 | httpsRedirect bool 31 | 32 | // version flag, and the build version injected by the Makefile via -ldflags -X 33 | version bool 34 | buildVersion string 35 | 36 | // SSL Certificate directory 37 | certDir string 38 | ) 39 | 40 | func init() { 41 | flag.StringVar(&httpAddr, "http", "", "http server address") 42 | flag.StringVar(&httpsAddr, "https", "", "https server address") 43 | flag.StringVar(&adminListenAddr, "admin", "127.0.0.1:9090", "admin http server address") 44 | flag.StringVar(&defaultConfig, "config", "", "default config file") 45 | flag.StringVar(&stateConfig, "state", "", "updated config which reflects the internal state") 46 | flag.StringVar(&certDir, "certs", "./", "directory containing SSL Certificates and Keys") 47 | flag.BoolVar(&debug, "debug", false, "verbose logging") 48 | flag.BoolVar(&version, "v", false, "display version") 49 | 50 | flag.BoolVar(&httpsRedirect, "https-redirect", false, "redirect all http vhost requests to https") 51 | flag.BoolVar(&httpsRedirect, "sslOnly", false, "require https (deprecated)") 52 | 53 | flag.Parse() // NOTE(review): calling flag.Parse during package init may make `go test` fail on Go 1.13+; consider moving to main() 54 | } 55 | 56 | func main() { 57 | if debug { 58 | log.DefaultLogger.Level = log.DEBUG 59 | } 60 | 61 | if version { 62 | println(buildVersion) 63 | return 64 | } 65 | 66 | log.Printf("Starting shuttle %s", buildVersion) 67 | loadConfig() 68 | 69 | var wg sync.WaitGroup 70 | wg.Add(1) 71 | go startAdminHTTPServer(&wg) 72 | 73 | if httpAddr != "" { 74 | wg.Add(1) 75 | go startHTTPServer(&wg) 76 | } 77 | 78 | if httpsAddr != "" { 79 | wg.Add(1) 80 | go startHTTPSServer(&wg) 81 | } 82 | wg.Wait() 83 | } 84 | -------------------------------------------------------------------------------- /log/log.go: -------------------------------------------------------------------------------- 1 | package log 2 |
3 | import ( 4 | "io" 5 | golog "log" 6 | "os" 7 | 8 | "github.com/fatih/color" 9 | ) 10 | 11 | const ( 12 | ERROR = iota 13 | INFO 14 | WARN 15 | DEBUG 16 | ) 17 | 18 | type Logger struct { 19 | golog.Logger 20 | Level int 21 | Prefix string 22 | } 23 | 24 | var ( 25 | red = color.New(color.FgRed).SprintFunc() 26 | redln = color.New(color.FgRed).SprintlnFunc() 27 | redf = color.New(color.FgRed).SprintfFunc() 28 | yellow = color.New(color.FgYellow).SprintFunc() 29 | yellowln = color.New(color.FgYellow).SprintlnFunc() 30 | yellowf = color.New(color.FgYellow).SprintfFunc() 31 | ) 32 | 33 | func New(out io.Writer, prefix string, level int) *Logger { 34 | l := &Logger{ 35 | Level: level, 36 | Prefix: prefix, 37 | } 38 | l.Logger = *(golog.New(out, prefix, golog.LstdFlags)) 39 | return l 40 | } 41 | 42 | var DefaultLogger = New(os.Stderr, "", INFO) 43 | 44 | func (l *Logger) Debug(v ...interface{}) { 45 | if l.Level < DEBUG { 46 | return 47 | } 48 | l.Println(v...) 49 | } 50 | 51 | func (l *Logger) Debugf(fmt string, v ...interface{}) { 52 | if l.Level < DEBUG { 53 | return 54 | } 55 | l.Printf(fmt, v...) 56 | } 57 | // Write implements io.Writer: p is logged only at DEBUG level, but is always reported as fully written. 58 | func (l *Logger) Write(p []byte) (n int, err error) { 59 | if l.Level < DEBUG { 60 | return len(p), nil 61 | } 62 | l.Print(string(p)) 63 | return len(p), nil 64 | } 65 | 66 | func Debug(v ...interface{}) { DefaultLogger.Debug(v...) } 67 | func Debugf(format string, v ...interface{}) { DefaultLogger.Debugf(format, v...)
} 68 | func Fatal(v ...interface{}) { 69 | DefaultLogger.Fatal(red(v...)) 70 | } 71 | func Fatalf(format string, v ...interface{}) { 72 | DefaultLogger.Fatal(redf(format, v...)) 73 | } 74 | func Fatalln(v ...interface{}) { 75 | DefaultLogger.Fatal(redln(v...)) 76 | } 77 | func Panic(v ...interface{}) { 78 | DefaultLogger.Panic(red(v...)) 79 | } 80 | func Panicf(format string, v ...interface{}) { 81 | DefaultLogger.Panic(redf(format, v...)) 82 | } 83 | func Panicln(v ...interface{}) { 84 | DefaultLogger.Panic(redln(v...)) 85 | } 86 | 87 | func Error(v ...interface{}) { 88 | DefaultLogger.Print(red(v...)) 89 | } 90 | func Errorf(format string, v ...interface{}) { 91 | DefaultLogger.Print(redf(format, v...)) 92 | } 93 | func Errorln(v ...interface{}) { 94 | DefaultLogger.Print(redln(v...)) 95 | } 96 | 97 | func Warn(v ...interface{}) { 98 | DefaultLogger.Print(yellow(v...)) 99 | } 100 | func Warnf(format string, v ...interface{}) { 101 | DefaultLogger.Print(yellowf(format, v...)) 102 | } 103 | func Warnln(v ...interface{}) { 104 | DefaultLogger.Print(yellowln(v...)) 105 | } 106 | 107 | func Print(v ...interface{}) { DefaultLogger.Print(v...) } 108 | func Printf(format string, v ...interface{}) { DefaultLogger.Printf(format, v...) } 109 | func Println(v ...interface{}) { DefaultLogger.Println(v...) 
} 110 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | shuttle - Dynamic HTTP(S)/TCP/UDP Service Proxy 2 | ======= 3 | 4 | [![latest v0.1.0](https://img.shields.io/badge/latest-v0.1.0-green.svg?style=flat)](https://github.com/litl/shuttle/releases/tag/v0.1.0) 5 | [![Build Status](https://travis-ci.org/litl/shuttle.svg?branch=master)](https://travis-ci.org/litl/shuttle) 6 | [![License MIT](https://img.shields.io/badge/license-MIT-blue.svg?style=flat)](https://github.com/litl/shuttle/blob/master/LICENSE) 7 | [![GoDoc](https://godoc.org/github.com/litl/shuttle/client?status.png)](https://godoc.org/github.com/litl/shuttle/client) 8 | 9 | 10 | Shuttle is a proxy and load balancer, which can be updated live via an HTTP 11 | interface. It can proxy TCP and UDP, and HTTP(S) via virtual hosts. 12 | 13 | ## Features 14 | - TCP/UDP/HTTP/HTTPS (SNI) Proxying 15 | - Round robin/Least Connection/Weighted Load Balancing 16 | - Backend Health Checks 17 | - HTTP API for dynamic updating and querying 18 | - Stats API 19 | - HTTP(S) Virtual Host Routing 20 | - Configurable HTTP Error Pages 21 | - Optional proxy config state saving 22 | - Optional file config 23 | 24 | ## Install 25 | 26 | ``` 27 | $ wget https://github.com/litl/shuttle/releases/download/v0.1.0/shuttle-linux-amd64-v0.1.0.tar.gz 28 | $ tar xvzf shuttle-linux-amd64-v0.1.0.tar.gz 29 | ``` 30 | 31 | ## Usage 32 | 33 | Shuttle can be started with a default configuration, as well as its last 34 | configuration state. The -state configuration is updated on changes to the 35 | internal config. If the state config file doesn't exist, the default is loaded. 36 | The default config is never written to by shuttle. 37 | 38 | Shuttle can serve multiple HTTPS hosts via SNI.
Certs are loaded by providing 39 | a directory containing pairs of certificates and keys with the naming 40 | convention `vhost.name.pem` and `vhost.name.key`. 41 | 42 | 43 | Basic TCP proxy: 44 | 45 | $ ./shuttle -admin 127.0.0.1:9090 -config default_config.json -state state_config.json 46 | 47 | 48 | Proxy with a virtualhost HTTP proxy on port 8080: 49 | 50 | $ ./shuttle -admin 127.0.0.1:9090 -http :8080 -config default_config.json -state state_config.json 51 | 52 | 53 | The current config can be queried via the `/_config` endpoint. This returns a 54 | json list of Services and their Backends, which can be saved directly as a 55 | config file. The configuration itself is defined by `Config` in 56 | github.com/litl/shuttle/client. The running config can be updated by issuing a 57 | PUT or POST with a valid json config to `/_config`. 58 | 59 | A GET request to `/` or `/_stats` returns the live stats from all Services. 60 | Individual services can be queried by their name, `/service_name`, returning 61 | just the json stats for that service. Backend stats can be queried directly as 62 | well via the path `/service_name/backend_name`. 63 | 64 | Issuing a PUT with a json config to the service's endpoint will create or 65 | replace that service. Any changes to the running service require shutting down 66 | the listener and starting a new service, which will create a very small period 67 | where connections may be rejected. 68 | 69 | Issuing a PUT with a json config to the backend's endpoint will create or 70 | replace that backend. Existing connections relying on the old config will 71 | continue to run until the connection is closed. 72 | 73 | 74 | ## TODO 75 | 76 | - Documentation!
77 | - Configure individual hosts to require HTTPS 78 | - Connection limits (per service and/or per backend) 79 | - Rate limits 80 | - Mark backend down after non-check connection failures (still requires checks to bring it back up) 81 | - Health check via http, or tcp call/resp pattern 82 | - Protocol bridging? e.g. `TCP<->unix`, `UDP->TCP`?! 83 | - Better logging 84 | - Remove all dependency on galaxy (galaxy/log?) 85 | 86 | ## License 87 | 88 | MIT 89 | -------------------------------------------------------------------------------- /client/client.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "io/ioutil" 9 | "net/http" 10 | "time" 11 | ) 12 | 13 | // Client is an http client for communicating with the shuttle server api 14 | type Client struct { 15 | httpClient *http.Client 16 | addr string 17 | } 18 | 19 | // An http client for communicating with the shuttle server. 20 | func NewClient(addr string) *Client { 21 | return &Client{ 22 | httpClient: &http.Client{Timeout: 2 * time.Second}, 23 | addr: addr, 24 | } 25 | } 26 | 27 | // GetConfig retrieves the configuration for a running shuttle server. 28 | func (c *Client) GetConfig() (*Config, error) { 29 | 30 | req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/_config", c.addr), nil) 31 | if err != nil { 32 | return nil, err 33 | } 34 | 35 | resp, err := c.httpClient.Do(req) 36 | if err != nil { 37 | return nil, err 38 | } 39 | 40 | defer resp.Body.Close() 41 | body, err := ioutil.ReadAll(resp.Body) 42 | if err != nil { 43 | return nil, err 44 | } 45 | 46 | config := &Config{} 47 | err = json.Unmarshal(body, config) 48 | if err != nil { 49 | return nil, err 50 | } 51 | 52 | return config, nil 53 | } 54 | 55 | // UpdateConfig updates the running config on a shuttle server. 
This will 56 | // update globals settings and add services, but currently doesn't remove any 57 | // running service or backends. 58 | func (c *Client) UpdateConfig(config *Config) error { 59 | 60 | js, err := json.Marshal(config) 61 | if err != nil { 62 | return err 63 | } 64 | 65 | resp, err := c.httpClient.Post(fmt.Sprintf("http://%s/_config", c.addr), "application/json", 66 | bytes.NewBuffer(js)) 67 | if err != nil { 68 | return err 69 | } 70 | defer resp.Body.Close() 71 | 72 | if resp.StatusCode != http.StatusOK { 73 | return fmt.Errorf("failed to update shuttle config: %s", resp.Status) 74 | } 75 | return nil 76 | } 77 | 78 | // UpdateService adds or updates a service on a running shuttle server. 79 | func (c *Client) UpdateService(service *ServiceConfig) error { 80 | 81 | js, err := json.Marshal(service) 82 | if err != nil { 83 | return err 84 | } 85 | 86 | resp, err := c.httpClient.Post(fmt.Sprintf("http://%s/%s", c.addr, service.Name), "application/json", 87 | bytes.NewBuffer(js)) 88 | if err != nil { 89 | return err 90 | } 91 | defer resp.Body.Close() 92 | 93 | if resp.StatusCode != http.StatusOK { 94 | return fmt.Errorf("failed to update shuttle service '%s': %s", service.Name, resp.Status) 95 | } 96 | return nil 97 | } 98 | 99 | // RemoveService removes a service and its backends from a running shuttle server. 100 | func (c *Client) RemoveService(service string) error { 101 | req, err := http.NewRequest("DELETE", fmt.Sprintf("http://%s/%s", c.addr, service), nil) 102 | if err != nil { 103 | return err 104 | } 105 | 106 | resp, err := c.httpClient.Do(req) 107 | if err != nil { 108 | return err 109 | } 110 | defer resp.Body.Close() 111 | 112 | if resp.StatusCode != http.StatusOK { 113 | return errors.New(fmt.Sprintf("failed to remove shuttle service '%s': %s", service, resp.Status)) 114 | } 115 | return nil 116 | } 117 | 118 | // UpdateBackend adds or updates a single backend on a running shuttle server. 
119 | func (c *Client) UpdateBackend(service string, backend *BackendConfig) error { 120 | 121 | js, err := json.Marshal(backend) 122 | if err != nil { 123 | return err 124 | } 125 | 126 | resp, err := c.httpClient.Post(fmt.Sprintf("http://%s/%s/%s", c.addr, service, backend.Name), "application/json", 127 | bytes.NewBuffer(js)) 128 | if err != nil { 129 | return err 130 | } 131 | defer resp.Body.Close() 132 | 133 | if resp.StatusCode != http.StatusOK { 134 | return fmt.Errorf("failed to update shuttle backend '%s/%s': %s", service, backend.Name, resp.Status) 135 | } 136 | return nil 137 | } 138 | 139 | // RemoveBackend removes a backend from its service on a running shuttle server. 140 | func (c *Client) RemoveBackend(service, backend string) error { 141 | req, err := http.NewRequest("DELETE", fmt.Sprintf("http://%s/%s/%s", c.addr, service, backend), nil) 142 | if err != nil { 143 | return err 144 | } 145 | 146 | resp, err := c.httpClient.Do(req) 147 | if err != nil { 148 | return err 149 | } 150 | defer resp.Body.Close() 151 | 152 | if resp.StatusCode != http.StatusOK { 153 | return errors.New(fmt.Sprintf("failed to remove shuttle backend '%s/%s': %s", service, backend, resp.Status)) 154 | } 155 | return nil 156 | } 157 | -------------------------------------------------------------------------------- /balancer.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "sort" 5 | "sync/atomic" 6 | ) 7 | 8 | // Balancing functions return a slice of all known available backends, in 9 | // priority order. This way the service can cycle through backends if the 10 | // initial connections fails. 11 | 12 | // RR is always weighted. 
13 | // we don't reduce the weight, we just distribute exactly "Weight" calls in 14 | // a row 15 | func (s *Service) roundRobin() []*Backend { 16 | s.Lock() 17 | defer s.Unlock() 18 | 19 | count := len(s.Backends) 20 | switch count { 21 | case 0: 22 | return nil 23 | case 1: 24 | // fast track for the single backend case 25 | return s.Backends[0:1] 26 | } 27 | 28 | // we may be out of range if we lost a backend since last connections 29 | if s.lastBackend >= count { 30 | s.lastBackend = 0 31 | s.lastCount = 0 32 | } 33 | 34 | // if our backend was over-weight, but we can't find another, use this 35 | var reuse *Backend 36 | 37 | var balanced []*Backend 38 | // Find the next Up backend to call 39 | for i := 0; i < count; i++ { 40 | backend := s.Backends[s.lastBackend] 41 | 42 | if backend.Up() { 43 | if s.lastCount >= int(backend.Weight) { 44 | // used too many times, but save it just in case 45 | reuse = backend 46 | s.lastBackend = (s.lastBackend + 1) % count 47 | s.lastCount = 0 48 | continue 49 | } 50 | 51 | s.lastCount++ 52 | balanced = append(balanced, backend) 53 | 54 | break 55 | } 56 | 57 | s.lastBackend = (s.lastBackend + 1) % count 58 | } 59 | 60 | if len(balanced) == 0 { 61 | if reuse != nil { 62 | balanced = append(balanced, reuse) 63 | } else { 64 | return nil 65 | } 66 | } 67 | 68 | // Now add the rest of the available backends in order, in case the first 69 | // connect fails 70 | lastBackend := s.lastBackend 71 | for i := 0; i < count-1; i++ { 72 | lastBackend = (lastBackend + 1) % count 73 | backend := s.Backends[lastBackend] 74 | if backend.Up() { 75 | balanced = append(balanced, backend) 76 | } 77 | } 78 | 79 | return balanced 80 | } 81 | 82 | // LC returns the backend with the least number of active connections 83 | func (s *Service) leastConn() []*Backend { 84 | s.Lock() 85 | defer s.Unlock() 86 | 87 | count := len(s.Backends) 88 | switch count { 89 | case 0: 90 | return nil 91 | case 1: 92 | // fast track for the single backend case 93 | 
return s.Backends[0:1] 94 | } 95 | 96 | // return the backends in the order of least connections 97 | var balanced []*Backend 98 | 99 | // Accumulate all backends that are currently Up 100 | for _, b := range s.Backends { 101 | if b.Up() { 102 | balanced = append(balanced, b) 103 | } 104 | } 105 | 106 | if len(balanced) == 0 { 107 | return nil 108 | } 109 | 110 | sort.Sort(ByActive(balanced)) 111 | 112 | return balanced 113 | } 114 | 115 | // Simple, but still weighted, RR for UDP where we don't have active 116 | // connections or connection failures. With 2+ backends, returns nil if none is Up. 117 | func (s *Service) udpRoundRobin() *Backend { 118 | s.Lock() 119 | defer s.Unlock() 120 | 121 | count := len(s.Backends) 122 | switch count { 123 | case 0: 124 | return nil 125 | case 1: 126 | // fast track for the single backend case 127 | return s.Backends[0] 128 | } 129 | 130 | // we may be out of range if we lost a backend since last connections 131 | if s.lastBackend >= count { 132 | s.lastBackend = 0 133 | s.lastCount = 0 134 | } 135 | 136 | // if our backend was over-weight, but we can't find another, use this 137 | var backend, reuse *Backend 138 | 139 | // Find the next Up backend to call 140 | for i := 0; i < count; i++ { 141 | candidate := s.Backends[s.lastBackend] 142 | 143 | if candidate.Up() { 144 | if s.lastCount >= int(candidate.Weight) { 145 | // used too many times, but save it just in case 146 | reuse = candidate 147 | s.lastBackend = (s.lastBackend + 1) % count 148 | s.lastCount = 0 149 | continue 150 | } 151 | backend = candidate 152 | s.lastCount++ 153 | break 154 | } 155 | 156 | s.lastBackend = (s.lastBackend + 1) % count 157 | } 158 | // backend is only set when an Up, under-weight backend was chosen above 159 | if backend != nil { 160 | return backend 161 | } 162 | 163 | if reuse != nil { 164 | return reuse 165 | } 166 | 167 | return nil 168 | } 169 | 170 | type ByActive []*Backend 171 | 172 | func (s ByActive) Len() int { return len(s) } 173 | func (s ByActive) Swap(i, j int) { s[i], s[j] = s[j], s[i] } 174 | func (s ByActive) Less(i, j int) bool { 175 | iActive :=
atomic.LoadInt64(&(s[i].Active)) 176 | jActive := atomic.LoadInt64(&(s[j].Active)) 177 | return iActive < jActive 178 | } 179 | -------------------------------------------------------------------------------- /proxy_bench_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "io/ioutil" 5 | "net/http" 6 | "net/http/httptest" 7 | "runtime" 8 | "testing" 9 | 10 | "github.com/litl/shuttle/client" 11 | ) 12 | 13 | var ( 14 | benchServer *httptest.Server 15 | benchBackends []*testHTTPServer 16 | benchRouter *HostRouter 17 | ) 18 | 19 | func setupBench(b *testing.B) { 20 | Registry = ServiceRegistry{ 21 | svcs: make(map[string]*Service), 22 | vhosts: make(map[string]*VirtualHost), 23 | } 24 | 25 | benchServer = httptest.NewServer(nil) 26 | 27 | httpServer := &http.Server{ 28 | Addr: httpAddr, 29 | } 30 | 31 | benchRouter = NewHostRouter(httpServer) 32 | ready := make(chan bool) 33 | go benchRouter.Start(ready) 34 | <-ready 35 | 36 | for i := 0; i < 4; i++ { 37 | server, err := NewHTTPTestServer("127.0.0.1:0", b) 38 | if err != nil { 39 | b.Fatal(err) 40 | } 41 | 42 | benchBackends = append(benchBackends, server) 43 | } 44 | } 45 | 46 | func tearDownBench(b *testing.B) { 47 | for _, s := range benchBackends { 48 | s.Close() 49 | } 50 | benchBackends = nil 51 | 52 | for _, svc := range Registry.svcs { 53 | Registry.RemoveService(svc.Name) 54 | } 55 | 56 | benchServer.Close() 57 | benchRouter.Stop() 58 | } 59 | 60 | // Make HTTP calls over the TCP proxy for comparison to ReverseProxy 61 | func BenchmarkTCPProxy(b *testing.B) { 62 | setupBench(b) 63 | defer tearDownBench(b) 64 | 65 | svcCfg := client.ServiceConfig{ 66 | Name: "VHostTest", 67 | Addr: "127.0.0.1:9000", 68 | VirtualHosts: []string{"test-vhost"}, 69 | } 70 | 71 | for _, srv := range benchBackends { 72 | cfg := client.BackendConfig{ 73 | Addr: srv.addr, 74 | Name: srv.addr, 75 | } 76 | svcCfg.Backends = append(svcCfg.Backends, cfg) 77 | } 
78 | 79 | err := Registry.AddService(svcCfg) 80 | if err != nil { 81 | b.Fatal(err) 82 | } 83 | 84 | req, err := http.NewRequest("GET", "http://127.0.0.1:9000/addr", nil) 85 | if err != nil { 86 | b.Fatal(err) 87 | } 88 | 89 | req.Host = "test-vhost" 90 | 91 | http.DefaultTransport.(*http.Transport).DisableKeepAlives = true 92 | runtime.GC() 93 | b.ResetTimer() 94 | for i := 0; i < b.N; i++ { 95 | resp, err := http.DefaultClient.Do(req) 96 | if err != nil { 97 | b.Fatal("Error during GET:", err) 98 | } 99 | body, err := ioutil.ReadAll(resp.Body) 100 | resp.Body.Close() 101 | if err != nil { 102 | b.Fatal("Error during Read:", err) 103 | } 104 | if len(body) < 7 { 105 | b.Fatalf("Error in Response: %s", body) 106 | } 107 | } 108 | 109 | runtime.GC() 110 | } 111 | 112 | func BenchmarkHTTPProxy(b *testing.B) { 113 | setupBench(b) 114 | defer tearDownBench(b) 115 | 116 | svcCfg := client.ServiceConfig{ 117 | Name: "VHostTest", 118 | Addr: "127.0.0.1:9000", 119 | VirtualHosts: []string{"test-vhost"}, 120 | } 121 | 122 | for _, srv := range benchBackends { 123 | cfg := client.BackendConfig{ 124 | Addr: srv.addr, 125 | Name: srv.addr, 126 | } 127 | svcCfg.Backends = append(svcCfg.Backends, cfg) 128 | } 129 | 130 | err := Registry.AddService(svcCfg) 131 | if err != nil { 132 | b.Fatal(err) 133 | } 134 | 135 | req, err := http.NewRequest("GET", "http://"+httpAddr+"/addr", nil) 136 | if err != nil { 137 | b.Fatal(err) 138 | } 139 | 140 | req.Host = "test-vhost" 141 | http.DefaultTransport.(*http.Transport).DisableKeepAlives = true 142 | 143 | b.ResetTimer() 144 | for i := 0; i < b.N; i++ { 145 | resp, err := http.DefaultClient.Do(req) 146 | if err != nil { 147 | b.Fatal("Error during GET:", err) 148 | } 149 | body, err := ioutil.ReadAll(resp.Body) 150 | if err != nil { 151 | b.Fatal("Error during Read:", err) 152 | } 153 | if len(body) < 7 { 154 | b.Fatalf("Error in Response: %s", body) 155 | } 156 | } 157 | 158 | } 159 | 160 | func BenchmarkHTTPProxyKeepalive(b *testing.B) 
{ 161 | setupBench(b) 162 | defer tearDownBench(b) 163 | 164 | svcCfg := client.ServiceConfig{ 165 | Name: "VHostTest", 166 | Addr: "127.0.0.1:9000", 167 | VirtualHosts: []string{"test-vhost"}, 168 | } 169 | 170 | for _, srv := range benchBackends { 171 | cfg := client.BackendConfig{ 172 | Addr: srv.addr, 173 | Name: srv.addr, 174 | } 175 | svcCfg.Backends = append(svcCfg.Backends, cfg) 176 | } 177 | 178 | err := Registry.AddService(svcCfg) 179 | if err != nil { 180 | b.Fatal(err) 181 | } 182 | 183 | req, err := http.NewRequest("GET", "http://"+httpAddr+"/addr", nil) 184 | if err != nil { 185 | b.Fatal(err) 186 | } 187 | 188 | req.Host = "test-vhost" 189 | http.DefaultTransport.(*http.Transport).DisableKeepAlives = false 190 | 191 | b.ResetTimer() 192 | for i := 0; i < b.N; i++ { 193 | resp, err := http.DefaultClient.Do(req) 194 | if err != nil { 195 | b.Fatal("Error during GET:", err) 196 | } 197 | body, err := ioutil.ReadAll(resp.Body) 198 | if err != nil { 199 | b.Fatal("Error during Read:", err) 200 | } 201 | if len(body) < 7 { 202 | b.Fatalf("Error in Response: %s", body) 203 | } 204 | } 205 | 206 | } 207 | -------------------------------------------------------------------------------- /admin.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "io/ioutil" 6 | "net" 7 | "net/http" 8 | "os" 9 | "strings" 10 | "sync" 11 | 12 | "github.com/litl/shuttle/client" 13 | "github.com/litl/shuttle/log" 14 | 15 | "github.com/gorilla/mux" 16 | ) 17 | 18 | func getConfig(w http.ResponseWriter, r *http.Request) { 19 | w.Write(marshal(Registry.Config())) 20 | } 21 | 22 | func getStats(w http.ResponseWriter, r *http.Request) { 23 | if len(Registry.Config().Services) == 0 { 24 | w.WriteHeader(503) 25 | } 26 | w.Write(marshal(Registry.Stats())) 27 | } 28 | 29 | func getServiceStats(w http.ResponseWriter, r *http.Request) { 30 | vars := mux.Vars(r) 31 | 32 | serviceStats, err := 
Registry.ServiceStats(vars["service"]) 33 | if err != nil { 34 | http.Error(w, err.Error(), http.StatusNotFound) 35 | return 36 | } 37 | 38 | w.Write(marshal(serviceStats)) 39 | } 40 | 41 | func getServiceConfig(w http.ResponseWriter, r *http.Request) { 42 | vars := mux.Vars(r) 43 | 44 | serviceStats, err := Registry.ServiceConfig(vars["service"]) 45 | if err != nil { 46 | http.Error(w, err.Error(), http.StatusNotFound) 47 | return 48 | } 49 | 50 | w.Write(marshal(serviceStats)) 51 | } 52 | 53 | // Update the global config 54 | func postConfig(w http.ResponseWriter, r *http.Request) { 55 | cfg := client.Config{} 56 | 57 | body, err := ioutil.ReadAll(r.Body) 58 | if err != nil { 59 | log.Errorln(err) 60 | http.Error(w, err.Error(), http.StatusInternalServerError) 61 | return 62 | } 63 | defer r.Body.Close() 64 | 65 | err = json.Unmarshal(body, &cfg) 66 | if err != nil { 67 | log.Errorln(err) 68 | http.Error(w, err.Error(), http.StatusInternalServerError) 69 | return 70 | } 71 | 72 | if err := Registry.UpdateConfig(cfg); err != nil { 73 | log.Errorln(err) 74 | // TODO: differentiate between ServerError and BadRequest 75 | http.Error(w, err.Error(), http.StatusInternalServerError) 76 | return 77 | } 78 | } 79 | 80 | // Update a service and/or backends. 
81 | func postService(w http.ResponseWriter, r *http.Request) { 82 | vars := mux.Vars(r) 83 | 84 | body, err := ioutil.ReadAll(r.Body) 85 | if err != nil { 86 | log.Errorln(err) 87 | http.Error(w, err.Error(), http.StatusInternalServerError) 88 | return 89 | } 90 | defer r.Body.Close() 91 | 92 | svcCfg := client.ServiceConfig{Name: vars["service"]} 93 | err = json.Unmarshal(body, &svcCfg) 94 | if err != nil { 95 | log.Errorln(err) 96 | http.Error(w, err.Error(), http.StatusInternalServerError) 97 | return 98 | } 99 | 100 | // don't let someone update the wrong service 101 | if svcCfg.Name != vars["service"] { 102 | errMsg := "Mismatched service name in API call" 103 | log.Error(errMsg) 104 | http.Error(w, errMsg, http.StatusBadRequest) 105 | return 106 | } 107 | 108 | cfg := client.Config{ 109 | Services: []client.ServiceConfig{svcCfg}, 110 | } 111 | 112 | err = Registry.UpdateConfig(cfg) 113 | //FIXME: this doesn't return an error for an empty or broken service 114 | if err != nil { 115 | log.Error(err) 116 | http.Error(w, err.Error(), http.StatusBadRequest) 117 | return 118 | } 119 | 120 | w.Write(marshal(Registry.Config())) 121 | } 122 | 123 | func deleteService(w http.ResponseWriter, r *http.Request) { 124 | vars := mux.Vars(r) 125 | 126 | err := Registry.RemoveService(vars["service"]) 127 | if err != nil { 128 | http.Error(w, err.Error(), http.StatusNotFound) 129 | return 130 | } 131 | go writeStateConfig() 132 | w.Write(marshal(Registry.Config())) 133 | } 134 | 135 | func getBackendStats(w http.ResponseWriter, r *http.Request) { 136 | vars := mux.Vars(r) 137 | serviceName := vars["service"] 138 | backendName := vars["backend"] 139 | 140 | backend, err := Registry.BackendStats(serviceName, backendName) 141 | if err != nil { 142 | http.Error(w, err.Error(), http.StatusNotFound) 143 | return 144 | } 145 | 146 | w.Write(marshal(backend)) 147 | } 148 | 149 | func getBackend(w http.ResponseWriter, r *http.Request) { 150 | vars := mux.Vars(r) 151 | serviceName := 
vars["service"] 152 | backendName := vars["backend"] 153 | 154 | backend, err := Registry.BackendStats(serviceName, backendName) 155 | if err != nil { 156 | http.Error(w, err.Error(), http.StatusNotFound) 157 | return 158 | } 159 | 160 | w.Write(marshal(backend)) 161 | } 162 | 163 | func postBackend(w http.ResponseWriter, r *http.Request) { 164 | vars := mux.Vars(r) 165 | 166 | body, err := ioutil.ReadAll(r.Body) 167 | if err != nil { 168 | log.Errorln(err) 169 | http.Error(w, err.Error(), http.StatusInternalServerError) 170 | return 171 | } 172 | defer r.Body.Close() 173 | 174 | backendName := vars["backend"] 175 | serviceName := vars["service"] 176 | 177 | backendCfg := client.BackendConfig{Name: backendName} 178 | err = json.Unmarshal(body, &backendCfg) 179 | if err != nil { 180 | log.Errorln(err) 181 | http.Error(w, err.Error(), http.StatusInternalServerError) 182 | return 183 | } 184 | 185 | if err := Registry.AddBackend(serviceName, backendCfg); err != nil { 186 | http.Error(w, err.Error(), http.StatusBadRequest) 187 | return 188 | } 189 | 190 | go writeStateConfig() 191 | w.Write(marshal(Registry.Config())) 192 | } 193 | 194 | func deleteBackend(w http.ResponseWriter, r *http.Request) { 195 | vars := mux.Vars(r) 196 | 197 | serviceName := vars["service"] 198 | backendName := vars["backend"] 199 | 200 | if err := Registry.RemoveBackend(serviceName, backendName); err != nil { 201 | http.Error(w, err.Error(), http.StatusBadRequest) 202 | return 203 | } 204 | 205 | go writeStateConfig() 206 | w.Write(marshal(Registry.Config())) 207 | } 208 | 209 | func addHandlers() { 210 | r := mux.NewRouter() 211 | r.HandleFunc("/", getStats).Methods("GET") 212 | r.HandleFunc("/", postConfig).Methods("PUT", "POST") 213 | r.HandleFunc("/_config", getConfig).Methods("GET") 214 | r.HandleFunc("/_config", postConfig).Methods("PUT", "POST") 215 | r.HandleFunc("/_stats", getStats).Methods("GET") 216 | r.HandleFunc("/{service}", getServiceStats).Methods("GET") 217 | 
r.HandleFunc("/{service}/_config", getServiceConfig).Methods("GET") 218 | r.HandleFunc("/{service}/_stats", getServiceStats).Methods("GET") 219 | r.HandleFunc("/{service}", postService).Methods("PUT", "POST") 220 | r.HandleFunc("/{service}", deleteService).Methods("DELETE") 221 | r.HandleFunc("/{service}/{backend}", getBackend).Methods("GET") 222 | r.HandleFunc("/{service}/{backend}", postBackend).Methods("PUT", "POST") 223 | r.HandleFunc("/{service}/{backend}", deleteBackend).Methods("DELETE") 224 | http.Handle("/", r) 225 | } 226 | 227 | func startAdminHTTPServer(wg *sync.WaitGroup) { 228 | defer wg.Done() 229 | addHandlers() 230 | log.Println("Admin server listening on", adminListenAddr) 231 | 232 | netw := "tcp" 233 | 234 | if strings.HasPrefix(adminListenAddr, "/") { 235 | netw = "unix" 236 | 237 | // remove our old socket if we left it lying around 238 | if stats, err := os.Stat(adminListenAddr); err == nil { 239 | if stats.Mode()&os.ModeSocket != 0 { 240 | os.Remove(adminListenAddr) 241 | } 242 | } 243 | 244 | defer os.Remove(adminListenAddr) 245 | } 246 | 247 | listener, err := net.Listen(netw, adminListenAddr) 248 | if err != nil { 249 | log.Fatalln(err) 250 | } 251 | 252 | http.Serve(listener, nil) 253 | } 254 | -------------------------------------------------------------------------------- /server_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "crypto/tls" 5 | "crypto/x509" 6 | "fmt" 7 | "io" 8 | "io/ioutil" 9 | "net" 10 | "net/http" 11 | "net/http/httptest" 12 | "strconv" 13 | "strings" 14 | "sync" 15 | "time" 16 | 17 | . "gopkg.in/check.v1" 18 | ) 19 | 20 | type testServer struct { 21 | addr string 22 | sig string 23 | listener net.Listener 24 | wg *sync.WaitGroup 25 | } 26 | 27 | // Start a tcp server which responds with it's addr after every read. 
28 | func NewTestServer(addr string, c Tester) (*testServer, error) { 29 | s := &testServer{} 30 | s.wg = new(sync.WaitGroup) 31 | 32 | var err error 33 | 34 | // try really hard to bind this so we don't fail tests 35 | for i := 0; i < 3; i++ { 36 | s.listener, err = net.Listen("tcp", addr) 37 | if err == nil { 38 | break 39 | } 40 | c.Log("Listen error:", err) 41 | c.Log("Trying again in 1s...") 42 | time.Sleep(time.Second) 43 | } 44 | 45 | if err != nil { 46 | return nil, err 47 | } 48 | 49 | s.addr = s.listener.Addr().String() 50 | c.Log("listening on ", s.addr) 51 | 52 | s.wg.Add(1) 53 | go func() { 54 | defer s.wg.Done() 55 | for { 56 | conn, err := s.listener.Accept() 57 | if err != nil { 58 | return 59 | } 60 | 61 | conn.SetDeadline(time.Now().Add(5 * time.Second)) 62 | s.wg.Add(1) 63 | go func() { 64 | defer s.wg.Done() 65 | defer conn.Close() 66 | buff := make([]byte, 1024) 67 | for { 68 | if _, err := conn.Read(buff); err != nil { 69 | if err != io.EOF { 70 | c.Logf("test server '%s' error: %s", s.addr, err) 71 | } 72 | return 73 | } 74 | if _, err := io.WriteString(conn, s.addr); err != nil { 75 | if err != io.EOF { 76 | c.Logf("test server '%s' error: %s", s.addr, err) 77 | } 78 | return 79 | } 80 | } 81 | }() 82 | } 83 | }() 84 | return s, nil 85 | } 86 | 87 | func (s *testServer) Stop() { 88 | s.listener.Close() 89 | // We may be imediately creating another identical server. 90 | // Wait until all goroutines return to ensure we can bind again. 91 | s.wg.Wait() 92 | } 93 | 94 | type udpTestServer struct { 95 | sync.Mutex 96 | addr string 97 | conn *net.UDPConn 98 | count int 99 | packets [][]byte 100 | wg *sync.WaitGroup 101 | } 102 | 103 | // Start a tcp server which responds with it's addr after every read. 
104 | func NewUDPTestServer(addr string, c Tester) (*udpTestServer, error) { 105 | s := &udpTestServer{} 106 | s.wg = new(sync.WaitGroup) 107 | 108 | lAddr, err := net.ResolveUDPAddr("udp", addr) 109 | if err != nil { 110 | c.Fatal(err) 111 | } 112 | 113 | // try really hard to bind this so we don't fail tests 114 | for i := 0; i < 3; i++ { 115 | s.conn, err = net.ListenUDP("udp", lAddr) 116 | if err == nil { 117 | break 118 | } 119 | c.Log("Listen error:", err) 120 | c.Log("Trying again in 1s...") 121 | time.Sleep(time.Second) 122 | } 123 | 124 | if err != nil { 125 | return nil, err 126 | } 127 | 128 | s.addr = addr 129 | c.Log("listening on UDP:", s.addr) 130 | 131 | s.wg.Add(1) 132 | go func() { 133 | defer s.wg.Done() 134 | // receive packets into a single buffer so we don't waste time make'ing them 135 | buff := make([]byte, 1048576) 136 | pos := 0 137 | for { 138 | n, _, err := s.conn.ReadFromUDP(buff[pos:]) 139 | if err != nil { 140 | return 141 | } 142 | s.count++ 143 | 144 | // lock the packet slice so we can safely inspect it from tests 145 | s.Lock() 146 | s.packets = append(s.packets, buff[pos:pos+n]) 147 | s.Unlock() 148 | pos += n 149 | } 150 | }() 151 | return s, nil 152 | } 153 | 154 | func (s *udpTestServer) Stop() { 155 | s.conn.Close() 156 | // We may be imediately creating another identical server. 157 | // Wait until all goroutines return to ensure we can bind again. 
158 | s.wg.Wait() 159 | } 160 | 161 | // Backend server for testing HTTP proxies 162 | type testHTTPServer struct { 163 | *httptest.Server 164 | addr string 165 | name string 166 | } 167 | 168 | // make the handler a method of the server so we can get the server's address 169 | func (s *testHTTPServer) addrHandler(w http.ResponseWriter, r *http.Request) { 170 | io.WriteString(w, s.addr) 171 | } 172 | 173 | func (s *testHTTPServer) errorHandler(w http.ResponseWriter, r *http.Request) { 174 | code, _ := strconv.Atoi(r.FormValue("code")) 175 | if code > 0 { 176 | w.WriteHeader(code) 177 | io.WriteString(w, s.addr) 178 | return 179 | } 180 | 181 | // set a nonsense header to chech ErrorPage caching 182 | w.Header().Set("Last-Modified", s.addr) 183 | w.WriteHeader(400) 184 | io.WriteString(w, s.addr) 185 | } 186 | 187 | type fataler interface { 188 | Fatal(...interface{}) 189 | } 190 | 191 | // Start a tcp server which responds with it's addr after every read. 192 | func NewHTTPTestServer(addr string, c fataler) (*testHTTPServer, error) { 193 | s := &testHTTPServer{ 194 | Server: httptest.NewUnstartedServer(nil), 195 | } 196 | 197 | s.addr = s.Listener.Addr().String() 198 | if parts := strings.Split(s.addr, ":"); len(parts) == 2 { 199 | s.name = fmt.Sprintf("http-%s.server.test", parts[1]) 200 | } else { 201 | c.Fatal("error naming http server") 202 | } 203 | 204 | mux := http.NewServeMux() 205 | mux.HandleFunc("/addr", s.addrHandler) 206 | mux.HandleFunc("/error", s.errorHandler) 207 | 208 | s.Config.Handler = mux 209 | s.Start() 210 | 211 | return s, nil 212 | } 213 | 214 | // Dialer that always resolves to 127.0.0.1 215 | func localDial(netw, addr string) (net.Conn, error) { 216 | _, port, err := net.SplitHostPort(addr) 217 | if err != nil { 218 | return nil, err 219 | } 220 | 221 | return net.Dial("tcp", "127.0.0.1:"+port) 222 | } 223 | 224 | // Connect to http server, and check response for value 225 | func checkHTTP(url, host, expected string, status int, c 
Tester) { 226 | req, err := http.NewRequest("GET", url, nil) 227 | if err != nil { 228 | c.Fatal(err) 229 | } 230 | 231 | req.Host = host 232 | req.Header.Set("X-Request-Id", "foo") 233 | 234 | // Load our test certs as our RootCAs, so we can verify that we connect 235 | // with the correct Cert in an HTTPSRouter 236 | certs := x509.NewCertPool() 237 | pemData, err := ioutil.ReadFile("testdata/vhost1.pem") 238 | if err != nil { 239 | c.Fatal(err) 240 | } 241 | certs.AppendCertsFromPEM(pemData) 242 | pemData, err = ioutil.ReadFile("testdata/vhost2.pem") 243 | if err != nil { 244 | c.Fatal(err) 245 | } 246 | certs.AppendCertsFromPEM(pemData) 247 | 248 | client := &http.Client{ 249 | Transport: &http.Transport{ 250 | Dial: localDial, 251 | TLSClientConfig: &tls.Config{ 252 | RootCAs: certs, 253 | }, 254 | }, 255 | } 256 | 257 | c.Log("GET ", req.Host, req.URL.Path) 258 | 259 | resp, err := client.Do(req) 260 | if err != nil { 261 | c.Fatal(err) 262 | } 263 | defer resp.Body.Close() 264 | 265 | body, err := ioutil.ReadAll(resp.Body) 266 | if err != nil { 267 | c.Fatal(err) 268 | } 269 | 270 | reqID := resp.Header.Get("X-Request-Id") 271 | c.Assert(reqID[len(reqID)-4:], Equals, ".foo") 272 | 273 | c.Assert(resp.StatusCode, Equals, status) 274 | 275 | if resp.StatusCode == http.StatusOK { 276 | // check for our backend header, without possibly getting a cached error page 277 | c.Assert(resp.Header.Get("X-Backend"), Equals, expected) 278 | } 279 | c.Assert(string(body), Equals, expected) 280 | } 281 | -------------------------------------------------------------------------------- /shuttle-cli/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "flag" 6 | "fmt" 7 | "log" 8 | "os" 9 | "strconv" 10 | "strings" 11 | 12 | shuttle "github.com/litl/shuttle/client" 13 | ) 14 | 15 | var ( 16 | shuttleAddr string 17 | configData string 18 | configFile string 19 | 20 | buildVersion = 
"0.1.0" 21 | 22 | client *shuttle.Client 23 | 24 | cfg = &shuttle.Config{} 25 | configFS = flag.NewFlagSet("config", flag.ExitOnError) 26 | 27 | serviceCfg = &shuttle.ServiceConfig{} 28 | serviceFS = flag.NewFlagSet("service", flag.ExitOnError) 29 | vhosts = stringSlice{} 30 | errorPages = stringSlice{} 31 | 32 | backendCfg = &shuttle.BackendConfig{} 33 | backendFS = flag.NewFlagSet("backend", flag.ExitOnError) 34 | ) 35 | 36 | func init() { 37 | configFS.StringVar(&cfg.Balance, "balance", "", "balance algorithm, {RR|LC}") 38 | configFS.IntVar(&cfg.CheckInterval, "check-interval", 0, "interval between health checks in milliseconds") 39 | configFS.IntVar(&cfg.Fall, "fall", 0, "number of failed healthchecks before a backend is marked down") 40 | configFS.IntVar(&cfg.Rise, "rise", 0, "number of successful health checks before a down service is marked up") 41 | configFS.IntVar(&cfg.ClientTimeout, "client-timeout", 0, "innactivity timeout for client connections") 42 | configFS.IntVar(&cfg.ServerTimeout, "server-timeout", 0, "innactivity timeout for server connections") 43 | configFS.IntVar(&cfg.DialTimeout, "dial-timeout", 0, "timeout for dialing new connections connections") 44 | configFS.BoolVar(&cfg.HTTPSRedirect, "https-redirect", false, "rediect all http requests to https") 45 | 46 | serviceFS.StringVar(&serviceCfg.Addr, "address", "", "service listening address") 47 | serviceFS.StringVar(&serviceCfg.Network, "network", "", "service network type") 48 | serviceFS.StringVar(&serviceCfg.Balance, "balance", "", "balancing algorithm, {RR|LC}") 49 | serviceFS.IntVar(&serviceCfg.CheckInterval, "check-interval", 0, "interval between health checks in milliseconds") 50 | serviceFS.IntVar(&serviceCfg.Fall, "fall", 0, "number of failed healthchecks before a backend is marked down") 51 | serviceFS.IntVar(&serviceCfg.Rise, "rise", 0, "number of successful health checks before a down service is marked up") 52 | serviceFS.IntVar(&serviceCfg.ClientTimeout, "client-timeout", 0, 
"innactivity timeout for client connections") 53 | serviceFS.IntVar(&serviceCfg.ServerTimeout, "server-timeout", 0, "innactivity timeout for server connections") 54 | serviceFS.IntVar(&serviceCfg.DialTimeout, "dial-timeout", 0, "timeout for dialing new connections connections") 55 | serviceFS.BoolVar(&serviceCfg.HTTPSRedirect, "https-redirect", false, "rediect all http requests to https") 56 | serviceFS.Var(&vhosts, "vhost", "virtual host name. may be set multiple times") 57 | serviceFS.Var(&errorPages, "error-page", "location for http error code formatted as 'http://example.com/|500,503'. may be set multiple times") 58 | 59 | backendFS.StringVar(&backendCfg.Addr, "address", "", "service listening address") 60 | backendFS.StringVar(&backendCfg.Network, "network", "", "backend network type") 61 | backendFS.StringVar(&backendCfg.CheckAddr, "check-address", "", "health check address") 62 | backendFS.IntVar(&backendCfg.Weight, "weight", 0, "balance weight") 63 | } 64 | 65 | func usage() { 66 | flag.PrintDefaults() 67 | fmt.Println(`shuttle-cli {config|update|remove} [options] 68 | 69 | config [options] 70 | set or print global config 71 | example: update the default client-timeout to 10 seconds 72 | $ shuttle-cli config -client-timeout 10s 73 | options:`) 74 | configFS.PrintDefaults() 75 | 76 | fmt.Println(` 77 | update service [options] 78 | add or update a service 79 | example: update the server-timeout on "servicename" to 2 seconds 80 | $ shuttle-cli update servicename -server-timeout 2s 81 | options:`) 82 | serviceFS.PrintDefaults() 83 | 84 | fmt.Println(` 85 | udpate service/backend [options] 86 | add or update a backend 87 | example: update the round-robin weight on "service/backend" to 3 88 | $ shuttle-cli update service/backend -weight 3 89 | options:`) 90 | backendFS.PrintDefaults() 91 | 92 | fmt.Println(` 93 | remove: remove services or backends 94 | remove service 95 | remove service/backend`) 96 | 97 | os.Exit(1) 98 | } 99 | 100 | func main() { 101 | 
log.SetPrefix("") 102 | log.SetFlags(0) 103 | 104 | flag.StringVar(&shuttleAddr, "addr", "127.0.0.1:9090", "shuttle admin address") 105 | flag.Usage = usage 106 | 107 | flag.Parse() 108 | 109 | if flag.NArg() < 1 { 110 | usage() 111 | 112 | } 113 | 114 | client = shuttle.NewClient(shuttleAddr) 115 | 116 | switch flag.Args()[0] { 117 | case "version": 118 | fmt.Println(buildVersion) 119 | return 120 | case "config": 121 | config(flag.Args()[1:]) 122 | case "update", "add": 123 | update(flag.Args()[1:]) 124 | case "remove": 125 | remove(flag.Args()[1:]) 126 | default: 127 | usage() 128 | } 129 | 130 | } 131 | 132 | func config(args []string) { 133 | if len(args) == 0 { 134 | cfg, err := client.GetConfig() 135 | if err != nil { 136 | log.Fatal(err) 137 | } 138 | 139 | js, err := json.MarshalIndent(cfg, " ", "") 140 | if err != nil { 141 | log.Fatal(err) 142 | } 143 | 144 | fmt.Println(string(js)) 145 | return 146 | } 147 | 148 | configFS.Parse(args) 149 | 150 | err := client.UpdateConfig(cfg) 151 | if err != nil { 152 | log.Fatal(err) 153 | } 154 | } 155 | 156 | // slice for multiple string flags 157 | type stringSlice []string 158 | 159 | func (s *stringSlice) Set(arg string) error { 160 | *s = append(*s, arg) 161 | return nil 162 | } 163 | 164 | func (s stringSlice) String() string { 165 | return strings.Join(s, ",") 166 | } 167 | 168 | func update(args []string) { 169 | if len(args) < 1 { 170 | usage() 171 | } 172 | 173 | target := strings.SplitN(args[0], "/", 2) 174 | if len(target) == 1 { 175 | updateService(target[0], args[1:]) 176 | return 177 | } 178 | 179 | updateBackend(target[0], target[1], args[1:]) 180 | } 181 | 182 | func updateService(service string, args []string) { 183 | serviceFS.Parse(args) 184 | 185 | if len(vhosts) > 0 { 186 | serviceCfg.VirtualHosts = vhosts 187 | } 188 | 189 | if len(errorPages) > 1 { 190 | serviceCfg.ErrorPages = parseErrorPages(errorPages) 191 | } 192 | 193 | serviceCfg.Name = service 194 | 195 | err := 
client.UpdateService(serviceCfg) 196 | if err != nil { 197 | log.Fatal(err) 198 | } 199 | } 200 | 201 | func parseErrorPages(pages []string) map[string][]int { 202 | ep := make(map[string][]int) 203 | for _, p := range pages { 204 | parts := strings.SplitN(p, "|", 1) 205 | if len(parts) != 2 { 206 | log.Fatalf("invalid error-page %s", p) 207 | } 208 | 209 | codes := []int{} 210 | for _, code := range strings.Split(parts[1], ",") { 211 | i, err := strconv.Atoi(code) 212 | if err != nil { 213 | log.Fatalf("invalided error-page %s, %s", p, err.Error()) 214 | } 215 | codes = append(codes, i) 216 | } 217 | 218 | ep[p] = codes 219 | } 220 | return ep 221 | } 222 | 223 | func updateBackend(service, backend string, args []string) { 224 | backendFS.Parse(args) 225 | 226 | backendCfg.Name = backend 227 | err := client.UpdateBackend(service, backendCfg) 228 | if err != nil { 229 | log.Fatal(err) 230 | } 231 | } 232 | 233 | func remove(args []string) { 234 | if len(args) < 1 { 235 | usage() 236 | } 237 | 238 | target := strings.SplitN(args[0], "/", 1) 239 | if len(target) == 1 { 240 | err := client.RemoveService(target[0]) 241 | if err != nil { 242 | log.Fatal(err) 243 | } 244 | return 245 | } 246 | 247 | err := client.RemoveBackend(target[0], target[1]) 248 | if err != nil { 249 | log.Fatal(err) 250 | } 251 | 252 | } 253 | -------------------------------------------------------------------------------- /backend.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "io" 5 | "net" 6 | "sync" 7 | "sync/atomic" 8 | "time" 9 | 10 | "github.com/litl/shuttle/client" 11 | "github.com/litl/shuttle/log" 12 | ) 13 | 14 | type Backend struct { 15 | sync.Mutex 16 | Name string 17 | Addr string 18 | CheckAddr string 19 | up bool 20 | Weight int 21 | Sent int64 22 | Rcvd int64 23 | Errors int64 24 | Conns int64 25 | Active int64 26 | HTTPActive int64 27 | Network string 28 | 29 | // these are loaded from the service, so a 
backend doesn't need to access 30 | // the service struct at all. 31 | dialTimeout time.Duration 32 | rwTimeout time.Duration 33 | checkInterval time.Duration 34 | rise int 35 | riseCount int 36 | checkOK int 37 | fall int 38 | fallCount int 39 | checkFail int 40 | 41 | startCheck sync.Once 42 | // stop the health-check loop 43 | stopCheck chan interface{} 44 | 45 | // so we only need to ResolveUDPAddr once 46 | udpAddr *net.UDPAddr 47 | } 48 | 49 | // The json stats we return for the backend 50 | type BackendStat struct { 51 | Name string `json:"name"` 52 | Addr string `json:"address"` 53 | CheckAddr string `json:"check_address"` 54 | Up bool `json:"up"` 55 | Weight int `json:"weight"` 56 | Sent int64 `json:"sent"` 57 | Rcvd int64 `json:"received"` 58 | Errors int64 `json:"errors"` 59 | Conns int64 `json:"connections"` 60 | Active int64 `json:"active"` 61 | HTTPActive int64 `json:"http_active"` 62 | CheckOK int `json:"check_success"` 63 | CheckFail int `json:"check_fail"` 64 | } 65 | 66 | func NewBackend(cfg client.BackendConfig) *Backend { 67 | b := &Backend{ 68 | Name: cfg.Name, 69 | Addr: cfg.Addr, 70 | CheckAddr: cfg.CheckAddr, 71 | Weight: cfg.Weight, 72 | Network: cfg.Network, 73 | stopCheck: make(chan interface{}), 74 | } 75 | 76 | // don't want a weight of 0 77 | if b.Weight == 0 { 78 | b.Weight = 1 79 | } 80 | 81 | if b.Network == "" { 82 | b.Network = "tcp" 83 | } 84 | 85 | switch b.Network { 86 | case "udp", "udp4", "udp6": 87 | var err error 88 | b.udpAddr, err = net.ResolveUDPAddr(b.Network, b.Addr) 89 | if err != nil { 90 | log.Errorf("ERROR: %s", err.Error()) 91 | b.up = false 92 | } 93 | } 94 | 95 | return b 96 | } 97 | 98 | // Copy the backend state into a BackendStat struct. 
99 | func (b *Backend) Stats() BackendStat { 100 | b.Lock() 101 | defer b.Unlock() 102 | 103 | stats := BackendStat{ 104 | Name: b.Name, 105 | Addr: b.Addr, 106 | CheckAddr: b.CheckAddr, 107 | Up: b.up, 108 | Weight: b.Weight, 109 | Sent: atomic.LoadInt64(&b.Sent), 110 | Rcvd: atomic.LoadInt64(&b.Rcvd), 111 | Errors: atomic.LoadInt64(&b.Errors), 112 | Conns: atomic.LoadInt64(&b.Conns), 113 | Active: atomic.LoadInt64(&b.Active), 114 | HTTPActive: atomic.LoadInt64(&b.HTTPActive), 115 | CheckOK: b.checkOK, 116 | CheckFail: b.checkFail, 117 | } 118 | 119 | return stats 120 | } 121 | 122 | func (b *Backend) Up() bool { 123 | b.Lock() 124 | up := b.up 125 | b.Unlock() 126 | return up 127 | } 128 | 129 | // Return the struct for marshaling into a json config 130 | func (b *Backend) Config() client.BackendConfig { 131 | b.Lock() 132 | defer b.Unlock() 133 | 134 | cfg := client.BackendConfig{ 135 | Name: b.Name, 136 | Addr: b.Addr, 137 | CheckAddr: b.CheckAddr, 138 | Weight: b.Weight, 139 | } 140 | 141 | return cfg 142 | } 143 | 144 | // Backends and Servers Stringify themselves directly into their config format. 
145 | func (b *Backend) String() string { 146 | return string(marshal(b.Config())) 147 | } 148 | 149 | func (b *Backend) Start() { 150 | go b.startCheck.Do(b.healthCheck) 151 | } 152 | 153 | func (b *Backend) Stop() { 154 | close(b.stopCheck) 155 | } 156 | 157 | func (b *Backend) check() { 158 | if b.CheckAddr == "" { 159 | return 160 | } 161 | 162 | up := true 163 | if c, e := net.DialTimeout("tcp", b.CheckAddr, b.dialTimeout); e == nil { 164 | c.(*net.TCPConn).SetLinger(0) 165 | c.Close() 166 | } else { 167 | log.Debug("Check error:", e) 168 | up = false 169 | } 170 | 171 | b.Lock() 172 | defer b.Unlock() 173 | if up { 174 | log.Debugf("Check OK for %s/%s", b.Name, b.CheckAddr) 175 | b.fallCount = 0 176 | b.riseCount++ 177 | b.checkOK++ 178 | if b.riseCount >= b.rise { 179 | if !b.up { 180 | log.Debugf("Marking backend %s Up", b.Name) 181 | } 182 | b.up = true 183 | } 184 | } else { 185 | log.Debugf("Check failed for %s/%s", b.Name, b.CheckAddr) 186 | b.riseCount = 0 187 | b.fallCount++ 188 | b.checkFail++ 189 | if b.fallCount >= b.fall { 190 | if b.up { 191 | log.Debugf("Marking backend %s Down", b.Name) 192 | } 193 | b.up = false 194 | } 195 | } 196 | } 197 | 198 | // Periodically check the status of this backend 199 | func (b *Backend) healthCheck() { 200 | t := time.NewTicker(b.checkInterval) 201 | for { 202 | select { 203 | case <-b.stopCheck: 204 | log.Debug("Stopping backend", b.Name) 205 | t.Stop() 206 | return 207 | case <-t.C: 208 | b.check() 209 | } 210 | } 211 | } 212 | 213 | // use to identify embedded TCPConns 214 | type closeReader interface { 215 | CloseRead() error 216 | } 217 | 218 | func (b *Backend) Proxy(srvConn, cliConn net.Conn) { 219 | log.Debugf("Initiating proxy: %s/%s-%s/%s", 220 | cliConn.RemoteAddr(), 221 | cliConn.LocalAddr(), 222 | srvConn.LocalAddr(), 223 | srvConn.RemoteAddr(), 224 | ) 225 | 226 | // Backend is a pointer receiver so we can get the address of the fields, 227 | // but all updates will be done atomically. 
228 | 229 | bConn := &shuttleConn{ 230 | TCPConn: srvConn.(*net.TCPConn), 231 | rwTimeout: b.rwTimeout, 232 | read: &b.Rcvd, 233 | written: &b.Sent, 234 | } 235 | // TODO: No way to force shutdown. Do we need it, or should we always just 236 | // let a connection run out? 237 | 238 | atomic.AddInt64(&b.Conns, 1) 239 | atomic.AddInt64(&b.Active, 1) 240 | defer atomic.AddInt64(&b.Active, -1) 241 | 242 | // channels to wait on close event 243 | backendClosed := make(chan bool, 1) 244 | clientClosed := make(chan bool, 1) 245 | 246 | go broker(bConn, cliConn, clientClosed, &b.Sent, &b.Errors) 247 | go broker(cliConn, bConn, backendClosed, &b.Rcvd, &b.Errors) 248 | 249 | // wait for one half of the proxy to exit, then trigger a shutdown of the 250 | // other half by calling CloseRead(). This will break the read loop in the 251 | // broker and fully close the connection. 252 | var waitFor chan bool 253 | select { 254 | case <-clientClosed: 255 | log.Debugf("Client %s/%s closed connection", cliConn.RemoteAddr(), cliConn.LocalAddr()) 256 | // the client closed first, so any more packets here are invalid, and 257 | // we can SetLinger(0) to recycle the port faster. 258 | bConn.TCPConn.SetLinger(0) 259 | bConn.CloseRead() 260 | waitFor = backendClosed 261 | case <-backendClosed: 262 | log.Debugf("Server %s/%s closed connection", srvConn.RemoteAddr(), srvConn.LocalAddr()) 263 | cliConn.(closeReader).CloseRead() 264 | waitFor = clientClosed 265 | } 266 | // wait for the other connection to close 267 | <-waitFor 268 | } 269 | 270 | // This does the actual data transfer. 271 | // The broker only closes the Read side. 
272 | func broker(dst, src net.Conn, srcClosed chan bool, written, errors *int64) { 273 | _, err := io.Copy(dst, src) 274 | if err != nil { 275 | atomic.AddInt64(errors, 1) 276 | log.Printf("Copy error: %s", err) 277 | } 278 | if err := src.Close(); err != nil { 279 | atomic.AddInt64(errors, 1) 280 | log.Printf("Close error: %s", err) 281 | } 282 | srcClosed <- true 283 | } 284 | 285 | // A net.Conn that sets a deadline for every read or write operation. 286 | // This will allow the server to close connections that are broken at the 287 | // network level. 288 | type shuttleConn struct { 289 | *net.TCPConn 290 | rwTimeout time.Duration 291 | 292 | // count bytes read and written through this connection 293 | written *int64 294 | read *int64 295 | 296 | // decrement when closed 297 | connected *int64 298 | } 299 | 300 | func (c *shuttleConn) Read(b []byte) (int, error) { 301 | if c.rwTimeout > 0 { 302 | err := c.TCPConn.SetReadDeadline(time.Now().Add(c.rwTimeout)) 303 | if err != nil { 304 | return 0, err 305 | } 306 | } 307 | n, err := c.TCPConn.Read(b) 308 | atomic.AddInt64(c.read, int64(n)) 309 | return n, err 310 | } 311 | 312 | func (c *shuttleConn) Write(b []byte) (int, error) { 313 | if c.rwTimeout > 0 { 314 | err := c.TCPConn.SetWriteDeadline(time.Now().Add(c.rwTimeout)) 315 | if err != nil { 316 | return 0, err 317 | } 318 | } 319 | 320 | n, err := c.TCPConn.Write(b) 321 | atomic.AddInt64(c.written, int64(n)) 322 | return n, err 323 | } 324 | 325 | func (c *shuttleConn) Close() error { 326 | if c.connected != nil { 327 | atomic.AddInt64(c.connected, -1) 328 | } 329 | return c.TCPConn.Close() 330 | } 331 | 332 | // Empty function to override the ReadFrom in *net.TCPConn 333 | // io.Copy will attempt to use ReadFrom when it can, but there's no bennefit 334 | // for a TCPConn->TCPConn, and it prevents us from collecting Read/Write stats. 
335 | func (c *shuttleConn) ReadFrom() {} 336 | -------------------------------------------------------------------------------- /reverseproxy.go: -------------------------------------------------------------------------------- 1 | // Copyright 2011 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | package main 5 | 6 | import ( 7 | "bytes" 8 | "fmt" 9 | "io" 10 | "io/ioutil" 11 | "net" 12 | "net/http" 13 | "strings" 14 | "sync" 15 | "time" 16 | 17 | "github.com/litl/shuttle/log" 18 | ) 19 | 20 | // onExitFlushLoop is a callback set by tests to detect the state of the 21 | // flushLoop() goroutine. 22 | var onExitFlushLoop func() 23 | 24 | type ProxyCallback func(*ProxyRequest) bool 25 | 26 | // A Dialer can return an error wrapped in DialError to notify the ReverseProxy 27 | // that an error occured during the initial TCP connection, and it's safe to 28 | // try again. 29 | type DialError struct { 30 | error 31 | } 32 | 33 | // ReverseProxy is an HTTP Handler that takes an incoming request and 34 | // sends it to another server, proxying the response back to the 35 | // client. 36 | type ReverseProxy struct { 37 | // we need to protect our ErrorPage cache 38 | sync.Mutex 39 | 40 | // Director must be a function which modifies 41 | // the request into a new request to be sent 42 | // using Transport. Its response is then copied 43 | // back to the original client unmodified. 44 | Director func(*http.Request) 45 | 46 | // The transport used to perform proxy requests. 47 | // If nil, http.DefaultTransport is used. 48 | Transport http.RoundTripper 49 | 50 | // FlushInterval specifies the flush interval 51 | // to flush to the client while copying the 52 | // response body. 53 | // If zero, no periodic flushing is done. 54 | FlushInterval time.Duration 55 | 56 | // These are called in order on before any request is made to the backend server. 
57 | // Each Callback must return true to continue processing. 58 | OnRequest []ProxyCallback 59 | 60 | // These are called in order after the response is obtained from the remote 61 | // server. The http.Response will be valid even on error. Callbacks may 62 | // write directly to the client, or modify the response which will be 63 | // written to the client if all callbacks complete with True. If any 64 | // callback returns false to stop the chain, the response is discarded. 65 | OnResponse []ProxyCallback 66 | } 67 | 68 | // Create a new ReverseProxy 69 | // This will still need to have a Director and Transport assigned. 70 | func NewReverseProxy(t *http.Transport) *ReverseProxy { 71 | p := &ReverseProxy{ 72 | Transport: t, 73 | FlushInterval: 1109 * time.Millisecond, 74 | } 75 | return p 76 | } 77 | 78 | func singleJoiningSlash(a, b string) string { 79 | aslash := strings.HasSuffix(a, "/") 80 | bslash := strings.HasPrefix(b, "/") 81 | switch { 82 | case aslash && bslash: 83 | return a + b[1:] 84 | case !aslash && !bslash: 85 | return a + "/" + b 86 | } 87 | return a + b 88 | } 89 | 90 | func copyHeader(dst, src http.Header) { 91 | for k, vv := range src { 92 | for _, v := range vv { 93 | dst.Add(k, v) 94 | } 95 | } 96 | } 97 | 98 | // Hop-by-hop headers. These are removed when sent to the backend. 
99 | // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html 100 | var hopHeaders = []string{ 101 | "Connection", 102 | "Keep-Alive", 103 | "Proxy-Authenticate", 104 | "Proxy-Authorization", 105 | "Te", // canonicalized version of "TE" 106 | "Trailers", 107 | "Transfer-Encoding", 108 | "Upgrade", 109 | } 110 | 111 | // This probably shouldn't be called ServeHTTP anymore 112 | func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request, addrs []string) { 113 | 114 | pr := &ProxyRequest{ 115 | ResponseWriter: rw, 116 | Request: req, 117 | Backends: addrs, 118 | } 119 | 120 | for _, f := range p.OnRequest { 121 | cont := f(pr) 122 | if !cont { 123 | return 124 | } 125 | } 126 | 127 | pr.StartTime = time.Now() 128 | res, err := p.doRequest(pr) 129 | 130 | pr.Response = res 131 | pr.ProxyError = err 132 | pr.FinishTime = time.Now() 133 | 134 | if err != nil { 135 | log.Printf("http: proxy error: %v", err) 136 | 137 | // We want to ensure that we have a non-nil response even on error for 138 | // the OnResponse callbacks. If the Callback chain completes, this will 139 | // be written to the client. 140 | res = &http.Response{ 141 | Header: make(map[string][]string), 142 | StatusCode: http.StatusBadGateway, 143 | Status: http.StatusText(http.StatusBadGateway), 144 | // this ensures Body isn't nil 145 | Body: ioutil.NopCloser(bytes.NewReader(nil)), 146 | } 147 | pr.Response = res 148 | } 149 | 150 | for _, h := range hopHeaders { 151 | res.Header.Del(h) 152 | } 153 | 154 | copyHeader(rw.Header(), res.Header) 155 | 156 | for _, f := range p.OnResponse { 157 | cont := f(pr) 158 | if !cont { 159 | return 160 | } 161 | } 162 | 163 | // calls all completed with true, write the Response back to the client. 
164 | defer res.Body.Close() 165 | rw.WriteHeader(res.StatusCode) 166 | _, err = p.copyResponse(rw, res.Body) 167 | if err != nil { 168 | log.Warnf("id=%s transfer error: %s", req.Header.Get("X-Request-Id"), err) 169 | } 170 | } 171 | 172 | func (p *ReverseProxy) doRequest(pr *ProxyRequest) (*http.Response, error) { 173 | transport := p.Transport 174 | if transport == nil { 175 | transport = http.DefaultTransport 176 | } 177 | 178 | outreq := new(http.Request) 179 | *outreq = *pr.Request // includes shallow copies of maps, but okay 180 | 181 | p.Director(outreq) 182 | outreq.Proto = "HTTP/1.1" 183 | outreq.ProtoMajor = 1 184 | outreq.ProtoMinor = 1 185 | outreq.Close = false 186 | 187 | // Remove hop-by-hop headers to the backend. Especially 188 | // important is "Connection" because we want a persistent 189 | // connection, regardless of what the client sent to us. This 190 | // is modifying the same underlying map from req (shallow 191 | // copied above) so we only copy it if necessary. 192 | copiedHeaders := false 193 | for _, h := range hopHeaders { 194 | if outreq.Header.Get(h) != "" { 195 | if !copiedHeaders { 196 | outreq.Header = make(http.Header) 197 | copyHeader(outreq.Header, pr.Request.Header) 198 | copiedHeaders = true 199 | } 200 | 201 | outreq.Header.Del(h) 202 | } 203 | } 204 | 205 | if clientIP, _, err := net.SplitHostPort(pr.Request.RemoteAddr); err == nil { 206 | // If we aren't the first proxy retain prior 207 | // X-Forwarded-For information as a comma+space 208 | // separated list and fold multiple headers into one. 
209 | if prior, ok := outreq.Header["X-Forwarded-For"]; ok { 210 | clientIP = strings.Join(prior, ", ") + ", " + clientIP 211 | } 212 | outreq.Header.Set("X-Forwarded-For", clientIP) 213 | } 214 | 215 | var err error 216 | var resp *http.Response 217 | 218 | for _, addr := range pr.Backends { 219 | outreq.URL.Host = addr 220 | resp, err = transport.RoundTrip(outreq) 221 | 222 | if err == nil { 223 | pr.ResponseWriter.Header().Set("X-Backend", addr) 224 | return resp, nil 225 | } 226 | 227 | if _, ok := err.(DialError); ok { 228 | // only Dial failed, so we can try again 229 | continue 230 | } 231 | 232 | // not a DialError, so make this terminal. 233 | return nil, err 234 | } 235 | 236 | // In this case, our last backend returned a DialError 237 | if err != nil { 238 | return nil, err 239 | } 240 | 241 | // probably shouldn't get here 242 | return nil, fmt.Errorf("no http backends available") 243 | } 244 | 245 | func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) (int64, error) { 246 | if p.FlushInterval != 0 { 247 | if wf, ok := dst.(writeFlusher); ok { 248 | mlw := &maxLatencyWriter{ 249 | dst: wf, 250 | latency: p.FlushInterval, 251 | done: make(chan bool), 252 | } 253 | go mlw.flushLoop() 254 | defer mlw.stop() 255 | dst = mlw 256 | } 257 | } 258 | 259 | return io.Copy(dst, src) 260 | } 261 | 262 | type writeFlusher interface { 263 | io.Writer 264 | http.Flusher 265 | } 266 | 267 | type maxLatencyWriter struct { 268 | dst writeFlusher 269 | latency time.Duration 270 | 271 | lk sync.Mutex // protects Write + Flush 272 | done chan bool 273 | } 274 | 275 | func (m *maxLatencyWriter) Write(p []byte) (int, error) { 276 | m.lk.Lock() 277 | defer m.lk.Unlock() 278 | return m.dst.Write(p) 279 | } 280 | 281 | func (m *maxLatencyWriter) flushLoop() { 282 | t := time.NewTicker(m.latency) 283 | defer t.Stop() 284 | for { 285 | select { 286 | case <-m.done: 287 | if onExitFlushLoop != nil { 288 | onExitFlushLoop() 289 | } 290 | return 291 | case <-t.C: 292 | 
m.lk.Lock() 293 | m.dst.Flush() 294 | m.lk.Unlock() 295 | } 296 | } 297 | } 298 | 299 | func (m *maxLatencyWriter) stop() { m.done <- true } 300 | 301 | // Proxy Request stores a client request, backend response, error, and any 302 | // stats needed to complete a round trip. 303 | type ProxyRequest struct { 304 | // The incoming request from the client 305 | Request *http.Request 306 | 307 | // The Client's ResponseWriter 308 | ResponseWriter http.ResponseWriter 309 | 310 | // The response, if any, from the backend server 311 | Response *http.Response 312 | 313 | // The error, if any, from the http request to the backend server 314 | ProxyError error 315 | 316 | // backend hosts we can use 317 | Backends []string 318 | 319 | // Duration of the backend request 320 | StartTime time.Time 321 | FinishTime time.Time 322 | } 323 | -------------------------------------------------------------------------------- /client/config.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "encoding/json" 5 | "reflect" 6 | "sort" 7 | ) 8 | 9 | const ( 10 | // Balancing schemes 11 | RoundRobin = "RR" 12 | LeastConn = "LC" 13 | 14 | // Default timeout in milliseconds for clients and server connections 15 | DefaultTimeout = 2000 16 | 17 | // Default interval in milliseconds between health checks 18 | DefaultCheckInterval = 5000 19 | 20 | // Default network connections are TCP 21 | DefaultNet = "tcp" 22 | 23 | // All RoundRobin backends are weighted, with a default of 1 24 | DefaultWeight = 1 25 | 26 | // RoundRobin is the default balancing scheme 27 | DefaultBalance = RoundRobin 28 | 29 | // Default for Fall and Rise is 2 30 | DefaultFall = 2 31 | DefaultRise = 2 32 | ) 33 | 34 | var ( 35 | // Status400s is a set of response codes to set an Error page for all 4xx responses. 
36 | Status400s = []int{400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418} 37 | // Status500s is a set of response codes to set an Error page for all 5xx responses. 38 | Status500s = []int{500, 501, 502, 503, 504, 505} 39 | ) 40 | 41 | // Config is the global configuration for all Services. 42 | // Defaults set here can be overridden by individual services. 43 | type Config struct { 44 | // Balance method 45 | // Valid values are "RR" for RoundRobin, the default, and "LC" for 46 | // LeastConnected. 47 | Balance string `json:"balance,omitempty"` 48 | 49 | // CheckInterval is in time in milliseconds between service health checks. 50 | CheckInterval int `json:"check_interval"` 51 | 52 | // Fall is the number of failed health checks before a service is marked 53 | // down. 54 | Fall int `json:"fall"` 55 | 56 | // Rise is the number of successful health checks before a down service is 57 | // marked up. 58 | Rise int `json:"rise"` 59 | 60 | // ClientTimeout is the maximum inactivity time, in milliseconds, for a 61 | // connection to the client before it is closed. 62 | ClientTimeout int `json:"client_timeout"` 63 | 64 | // ServerTimeout is the maximum inactivity time, in milliseconds, for a 65 | // connection to the backend before it is closed. 66 | ServerTimeout int `json:"server_timeout"` 67 | 68 | // DialTimeout is the timeout in milliseconds for connections to the 69 | // backend service, including name resolution. 70 | DialTimeout int `json:"connect_timeout"` 71 | 72 | // HTTPSRedirect when set to true, redirects non-https request to https on 73 | // all services. The request may either have Scheme set to 'https', or 74 | // have an "X-Forwarded-Proto: https" header. 75 | HTTPSRedirect bool `json:"https-redirect"` 76 | 77 | // Services is a slice of ServiceConfig for each service. A service 78 | // corresponds to one listening connection, and a number of backends to 79 | // proxy. 
80 | Services []ServiceConfig `json:"services"` 81 | } 82 | 83 | // Marshal returns an entire config as a json []byte. 84 | func (c *Config) Marshal() []byte { 85 | sort.Sort(serviceSlice(c.Services)) 86 | js, _ := json.Marshal(c) 87 | return js 88 | } 89 | 90 | // The string representation of a config is in json. 91 | func (c *Config) String() string { 92 | return string(c.Marshal()) 93 | } 94 | 95 | // BackendConfig defines the parameters unique for individual backends. 96 | type BackendConfig struct { 97 | // Name must be unique for this service. 98 | // Used for reference and for the HTTP API. 99 | Name string `json:"name"` 100 | 101 | // Addr must in the form ip:port 102 | Addr string `json:"address"` 103 | 104 | // Network must be "tcp" or "udp". 105 | // Default is "tcp" 106 | Network string `json:"network,omitempty"` 107 | 108 | // CheckAddr must be in the form ip:port. 109 | // A TCP connect is performed against this address to determine server 110 | // availability. If this is empty, no checks will be performed. 111 | CheckAddr string `json:"check_address"` 112 | 113 | // Weight is always used for RoundRobin balancing. 
Default is 1 114 | Weight int `json:"weight"` 115 | } 116 | 117 | // return a copy of the BackendConfig with default values set 118 | func (b BackendConfig) SetDefaults() BackendConfig { 119 | if b.Weight == 0 { 120 | b.Weight = DefaultWeight 121 | } 122 | if b.Network == "" { 123 | b.Network = DefaultNet 124 | } 125 | return b 126 | } 127 | 128 | func (b BackendConfig) Equal(other BackendConfig) bool { 129 | b = b.SetDefaults() 130 | other = other.SetDefaults() 131 | return b == other 132 | } 133 | 134 | func (b *BackendConfig) Marshal() []byte { 135 | js, _ := json.Marshal(b) 136 | return js 137 | } 138 | 139 | func (b *BackendConfig) String() string { 140 | return string(b.Marshal()) 141 | } 142 | 143 | // keep things sorted for easy viewing and comparison 144 | type backendSlice []BackendConfig 145 | 146 | func (p backendSlice) Len() int { return len(p) } 147 | func (p backendSlice) Less(i, j int) bool { return p[i].Name < p[j].Name } 148 | func (p backendSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } 149 | 150 | type serviceSlice []ServiceConfig 151 | 152 | func (p serviceSlice) Len() int { return len(p) } 153 | func (p serviceSlice) Less(i, j int) bool { return p[i].Name < p[j].Name } 154 | func (p serviceSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } 155 | 156 | // Subset of service fields needed for configuration. 157 | type ServiceConfig struct { 158 | // Name is the unique name of the service. This is used only for reference 159 | // and in the HTTP API. 160 | Name string `json:"name"` 161 | 162 | // Addr is the listening address for this service. Must be in the form 163 | // "ip:addr" 164 | Addr string `json:"address"` 165 | 166 | // Network must be "tcp" or "udp". 167 | // Default is "tcp" 168 | Network string `json:"network,omitempty"` 169 | 170 | // Balance method 171 | // Valid values are "RR" for RoundRobin, the default, and "LC" for 172 | // LeastConnected. 
173 | Balance string `json:"balance,omitempty"` 174 | 175 | // CheckInterval is in time in milliseconds between service health checks. 176 | CheckInterval int `json:"check_interval"` 177 | 178 | // Fall is the number of failed health checks before a service is marked. 179 | Fall int `json:"fall"` 180 | 181 | // Rise is the number of successful health checks before a down service is 182 | // marked up. 183 | Rise int `json:"rise"` 184 | 185 | // ClientTimeout is the maximum inactivity time, in milliseconds, for a 186 | // connection to the client before it is closed. 187 | ClientTimeout int `json:"client_timeout"` 188 | 189 | // ServerTimeout is the maximum inactivity time, in milliseconds, for a 190 | // connection to the backend before it is closed. 191 | ServerTimeout int `json:"server_timeout"` 192 | 193 | // DialTimeout is the timeout in milliseconds for connections to the 194 | // backend service, including name resolution. 195 | DialTimeout int `json:"connect_timeout"` 196 | 197 | // HTTPSRedirect when set to true, redirects non-https request to https. The 198 | // request may either have Scheme set to 'https', or have an 199 | // "X-Forwarded-Proto: https" header. 200 | HTTPSRedirect bool `json:"https-redirect"` 201 | 202 | // Virtualhosts is a set of virtual hostnames for which this service should 203 | // handle HTTP requests. 204 | VirtualHosts []string `json:"virtual_hosts,omitempty"` 205 | 206 | // ErrorPages are responses to be returned for HTTP error codes. Each page 207 | // is defined by a URL mapped and is mapped to a list of error codes that 208 | // should return the content at the URL. Error pages are retrieved ahead of 209 | // time if possible, and cached. 210 | ErrorPages map[string][]int `json:"error_pages,omitempty"` 211 | 212 | // Backends is a list of all servers handling connections for this service. 
213 | Backends []BackendConfig `json:"backends,omitempty"` 214 | 215 | // Maintenance mode is a flag to return 503 status codes to clients 216 | // without visiting backends. 217 | MaintenanceMode bool `json:"maintenance_mode"` 218 | } 219 | 220 | // Return a copy of ServiceConfig with any unset fields to their default 221 | // values 222 | func (s ServiceConfig) SetDefaults() ServiceConfig { 223 | if s.Balance == "" { 224 | s.Balance = DefaultBalance 225 | } 226 | if s.CheckInterval == 0 { 227 | s.CheckInterval = DefaultCheckInterval 228 | } 229 | if s.Rise == 0 { 230 | s.Rise = DefaultRise 231 | } 232 | if s.Fall == 0 { 233 | s.Fall = DefaultFall 234 | } 235 | if s.Network == "" { 236 | s.Network = DefaultNet 237 | } 238 | return s 239 | } 240 | 241 | // Compare a service's settings, ignoring individual backends. 242 | func (s ServiceConfig) Equal(other ServiceConfig) bool { 243 | // just remove the backends and compare the rest 244 | s.Backends = nil 245 | other.Backends = nil 246 | 247 | s = s.SetDefaults() 248 | other = other.SetDefaults() 249 | 250 | sort.Strings(s.VirtualHosts) 251 | sort.Strings(s.VirtualHosts) 252 | 253 | // FIXME: ignoring VirtualHosts and ErrorPages equality 254 | return reflect.DeepEqual(s, other) 255 | } 256 | 257 | // Check for equality including backends 258 | func (s ServiceConfig) DeepEqual(other ServiceConfig) bool { 259 | if len(s.Backends) != len(other.Backends) { 260 | return false 261 | } 262 | 263 | if !s.Equal(other) { 264 | return false 265 | } 266 | 267 | if len(s.Backends) != len(other.Backends) { 268 | return false 269 | } 270 | 271 | sort.Sort(backendSlice(s.Backends)) 272 | sort.Sort(backendSlice(other.Backends)) 273 | 274 | for i := range s.Backends { 275 | if !s.Backends[i].Equal(other.Backends[i]) { 276 | return false 277 | } 278 | } 279 | return true 280 | } 281 | 282 | func (b *ServiceConfig) Marshal() []byte { 283 | sort.Sort(backendSlice(b.Backends)) 284 | js, _ := json.Marshal(b) 285 | return js 286 | } 287 | 
288 | func (b *ServiceConfig) String() string { 289 | return string(b.Marshal()) 290 | } 291 | 292 | // Create a new config by merging the values from the current config 293 | // with those set in the new config 294 | func (s ServiceConfig) Merge(cfg ServiceConfig) ServiceConfig { 295 | new := s 296 | 297 | // let's try not to change the name 298 | new.Name = cfg.Name 299 | 300 | if cfg.Addr != "" { 301 | new.Addr = cfg.Addr 302 | } 303 | if cfg.Network != "" { 304 | new.Network = cfg.Network 305 | } 306 | if cfg.Balance != "" { 307 | new.Balance = cfg.Balance 308 | } 309 | if cfg.CheckInterval != 0 { 310 | new.CheckInterval = cfg.CheckInterval 311 | } 312 | if cfg.Fall != 0 { 313 | new.Fall = cfg.Fall 314 | } 315 | if cfg.Rise != 0 { 316 | new.Rise = cfg.Rise 317 | } 318 | if cfg.ClientTimeout != 0 { 319 | new.ClientTimeout = cfg.ClientTimeout 320 | } 321 | if cfg.ServerTimeout != 0 { 322 | new.ServerTimeout = cfg.ServerTimeout 323 | } 324 | if cfg.DialTimeout != 0 { 325 | new.DialTimeout = cfg.DialTimeout 326 | } 327 | 328 | if cfg.VirtualHosts != nil { 329 | new.VirtualHosts = cfg.VirtualHosts 330 | } 331 | 332 | if cfg.ErrorPages != nil { 333 | new.ErrorPages = cfg.ErrorPages 334 | } 335 | 336 | if cfg.Backends != nil { 337 | new.Backends = cfg.Backends 338 | } 339 | 340 | new.HTTPSRedirect = cfg.HTTPSRedirect 341 | new.MaintenanceMode = cfg.MaintenanceMode 342 | 343 | return new 344 | } 345 | -------------------------------------------------------------------------------- /http.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "crypto/tls" 5 | "fmt" 6 | "io/ioutil" 7 | "net" 8 | "net/http" 9 | "path/filepath" 10 | "strings" 11 | "sync" 12 | "time" 13 | 14 | "github.com/litl/shuttle/log" 15 | ) 16 | 17 | var ( 18 | httpRouter *HostRouter 19 | ) 20 | 21 | // This works along with the ServiceRegistry, and the individual Services to 22 | // route http requests based on the Host header. 
The Resgistry hold the mapping 23 | // of VHost names to individual services, and each service has it's own 24 | // ReeverseProxy to fulfill the request. 25 | // HostRouter contains the ReverseProxy http Listener, and has an http.Handler 26 | // to service the requets. 27 | type HostRouter struct { 28 | sync.Mutex 29 | // the http frontend 30 | server *http.Server 31 | 32 | // HTTP/HTTPS 33 | Scheme string 34 | 35 | // track our listener so we can kill the server 36 | listener net.Listener 37 | } 38 | 39 | func NewHostRouter(httpServer *http.Server) *HostRouter { 40 | r := &HostRouter{ 41 | Scheme: "http", 42 | } 43 | httpServer.Handler = r 44 | r.server = httpServer 45 | return r 46 | } 47 | 48 | func (r *HostRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) { 49 | reqId := req.Header.Get("X-Request-Id") 50 | if reqId == "" { 51 | reqId = genId() 52 | } else { 53 | reqId = genId() + "." + reqId 54 | } 55 | req.Header.Set("X-Request-Id", reqId) 56 | w.Header().Add("X-Request-Id", reqId) 57 | 58 | var err error 59 | host := req.Host 60 | 61 | if strings.Contains(host, ":") { 62 | host, _, err = net.SplitHostPort(req.Host) 63 | if err != nil { 64 | log.Warnf("%s", err) 65 | } 66 | } 67 | 68 | svc := Registry.GetVHostService(host) 69 | 70 | if svc != nil && svc.httpProxy != nil { 71 | // The vhost has a service registered, give it to the proxy 72 | svc.ServeHTTP(w, req) 73 | return 74 | } 75 | 76 | r.noHostHandler(w, req) 77 | } 78 | 79 | func (r *HostRouter) noHostHandler(w http.ResponseWriter, req *http.Request) { 80 | w.WriteHeader(http.StatusNotFound) 81 | fmt.Fprintln(w, "Not found") 82 | } 83 | 84 | // TODO: collect more stats? 85 | 86 | // Start the HTTP Router frontend. 87 | // Takes a channel to notify when the listener is started 88 | // to safely synchronize tests. 
89 | func (r *HostRouter) Start(ready chan bool) { 90 | //FIXME: poor locking strategy 91 | r.Lock() 92 | var err error 93 | r.listener, err = newTimeoutListener("tcp", r.server.Addr, 300*time.Second) 94 | if err != nil { 95 | log.Errorf("%s", err) 96 | r.Unlock() 97 | return 98 | } 99 | 100 | listener := r.listener 101 | if r.Scheme == "https" { 102 | listener = tls.NewListener(listener, r.server.TLSConfig) 103 | } 104 | 105 | r.Unlock() 106 | 107 | log.Printf("%s server listening at %s", strings.ToUpper(r.Scheme), r.server.Addr) 108 | if ready != nil { 109 | close(ready) 110 | } 111 | 112 | // This will log a closed connection error every time we Stop 113 | // but that's mostly a testing issue. 114 | log.Errorf("%s", r.server.Serve(listener)) 115 | } 116 | 117 | func (r *HostRouter) Stop() { 118 | r.listener.Close() 119 | } 120 | 121 | func startHTTPServer(wg *sync.WaitGroup) { 122 | defer wg.Done() 123 | 124 | //TODO: configure these timeouts somewhere 125 | httpServer := &http.Server{ 126 | Addr: httpAddr, 127 | ReadTimeout: 10 * time.Minute, 128 | WriteTimeout: 10 * time.Minute, 129 | MaxHeaderBytes: 1 << 20, 130 | } 131 | 132 | httpRouter = NewHostRouter(httpServer) 133 | 134 | httpRouter.Start(nil) 135 | } 136 | 137 | // find certs in and is the named directory, and match them up by their base 138 | // name using '.pem' and '.key' as extensions. 
139 | func loadCerts(certDir string) (*tls.Config, error) { 140 | abs, err := filepath.Abs(certDir) 141 | if err != nil { 142 | return nil, err 143 | } 144 | 145 | dir, err := ioutil.ReadDir(abs) 146 | if err != nil { 147 | return nil, err 148 | } 149 | 150 | // [cert, key] pairs 151 | pairs := make(map[string][2]string) 152 | 153 | for _, f := range dir { 154 | name := f.Name() 155 | if strings.HasSuffix(name, ".pem") { 156 | p := pairs[name[:len(name)-4]] 157 | p[0] = filepath.Join(abs, name) 158 | pairs[name[:len(name)-4]] = p 159 | } 160 | if strings.HasSuffix(name, ".key") { 161 | p := pairs[name[:len(name)-4]] 162 | p[1] = filepath.Join(abs, name) 163 | pairs[name[:len(name)-4]] = p 164 | } 165 | } 166 | 167 | tlsCfg := &tls.Config{ 168 | NextProtos: []string{"http/1.1"}, 169 | } 170 | 171 | for key, pair := range pairs { 172 | if pair[0] == "" { 173 | log.Errorf("missing cert for key: %s", pair[1]) 174 | continue 175 | } 176 | if pair[1] == "" { 177 | log.Errorf("missing key for cert: %s", pair[0]) 178 | continue 179 | } 180 | 181 | cert, err := tls.LoadX509KeyPair(pair[0], pair[1]) 182 | if err != nil { 183 | log.Error(err) 184 | continue 185 | } 186 | tlsCfg.Certificates = append(tlsCfg.Certificates, cert) 187 | log.Debugf("loaded X509KeyPair for %s", key) 188 | } 189 | 190 | if len(tlsCfg.Certificates) == 0 { 191 | return nil, fmt.Errorf("no tls certificates loaded") 192 | } 193 | 194 | tlsCfg.BuildNameToCertificate() 195 | 196 | return tlsCfg, nil 197 | } 198 | 199 | func startHTTPSServer(wg *sync.WaitGroup) { 200 | defer wg.Done() 201 | 202 | tlsCfg, err := loadCerts(certDir) 203 | if err != nil { 204 | log.Error(err) 205 | return 206 | } 207 | 208 | //TODO: configure these timeouts somewhere 209 | httpsServer := &http.Server{ 210 | Addr: httpsAddr, 211 | ReadTimeout: 10 * time.Minute, 212 | WriteTimeout: 10 * time.Minute, 213 | MaxHeaderBytes: 1 << 20, 214 | TLSConfig: tlsCfg, 215 | } 216 | 217 | httpRouter = NewHostRouter(httpsServer) 218 | 
httpRouter.Scheme = "https" 219 | 220 | httpRouter.Start(nil) 221 | } 222 | 223 | type ErrorPage struct { 224 | // The Mutex protects access to the body slice, and headers 225 | // Everything else should be static once the ErrorPage is created. 226 | sync.Mutex 227 | 228 | Location string 229 | StatusCodes []int 230 | 231 | // body contains the cached error page 232 | body []byte 233 | // important headers 234 | header http.Header 235 | } 236 | 237 | func (e *ErrorPage) Body() []byte { 238 | e.Lock() 239 | defer e.Unlock() 240 | return e.body 241 | } 242 | 243 | func (e *ErrorPage) SetBody(b []byte) { 244 | e.Lock() 245 | defer e.Unlock() 246 | e.body = b 247 | } 248 | 249 | func (e *ErrorPage) Header() http.Header { 250 | e.Lock() 251 | defer e.Unlock() 252 | return e.header 253 | } 254 | 255 | func (e *ErrorPage) SetHeader(h http.Header) { 256 | e.Lock() 257 | defer e.Unlock() 258 | e.header = h 259 | } 260 | 261 | // List of headers we want to cache for ErrorPages 262 | var ErrorHeaders = []string{ 263 | "Content-Type", 264 | "Content-Encoding", 265 | "Cache-Control", 266 | "Last-Modified", 267 | "Retry-After", 268 | "Set-Cookie", 269 | } 270 | 271 | // ErrorResponse provides a ReverProxy callback to process a response and 272 | // insert custom error pages for a virtual host. 
273 | type ErrorResponse struct { 274 | sync.Mutex 275 | 276 | // map them by status for responses 277 | pages map[int]*ErrorPage 278 | 279 | // keep this handy to refresh the pages 280 | client *http.Client 281 | } 282 | 283 | func NewErrorResponse(pages map[string][]int) *ErrorResponse { 284 | errors := &ErrorResponse{ 285 | pages: make(map[int]*ErrorPage), 286 | } 287 | 288 | // aggressively timeout connections 289 | errors.client = &http.Client{ 290 | Transport: &http.Transport{ 291 | Dial: (&net.Dialer{ 292 | Timeout: 2 * time.Second, 293 | }).Dial, 294 | TLSHandshakeTimeout: 2 * time.Second, 295 | }, 296 | Timeout: 5 * time.Second, 297 | } 298 | 299 | if pages != nil { 300 | errors.Update(pages) 301 | } 302 | return errors 303 | } 304 | 305 | // Get the ErrorPage, returning nil if the page was incomplete. 306 | // We permanently cache error pages and headers once we've seen them. 307 | func (e *ErrorResponse) Get(code int) *ErrorPage { 308 | e.Lock() 309 | page, ok := e.pages[code] 310 | e.Unlock() 311 | 312 | if !ok { 313 | // this is a code we don't handle 314 | return nil 315 | } 316 | 317 | body := page.Body() 318 | if body != nil { 319 | return page 320 | } 321 | 322 | // we haven't successfully fetched this error 323 | e.fetch(page) 324 | return page 325 | } 326 | 327 | func (e *ErrorResponse) fetch(page *ErrorPage) { 328 | log.Debugf("Fetching error page from %s", page.Location) 329 | resp, err := e.client.Get(page.Location) 330 | if err != nil { 331 | log.Warnf("Could not fetch %s: %s", page.Location, err.Error()) 332 | return 333 | } 334 | defer resp.Body.Close() 335 | 336 | // If the StatusCode matches any of our registered codes, it's OK 337 | for _, code := range page.StatusCodes { 338 | if resp.StatusCode == code { 339 | resp.StatusCode = http.StatusOK 340 | break 341 | } 342 | } 343 | 344 | if resp.StatusCode != http.StatusOK { 345 | log.Warnf("Server returned %d when fetching %s", resp.StatusCode, page.Location) 346 | return 347 | } 348 | 349 | 
header := make(map[string][]string) 350 | for _, key := range ErrorHeaders { 351 | if hdr, ok := resp.Header[key]; ok { 352 | header[key] = hdr 353 | } 354 | } 355 | // set the headers along with the body below 356 | 357 | body, err := ioutil.ReadAll(resp.Body) 358 | if err != nil { 359 | log.Warnf("Error reading response from %s: %s", page.Location, err.Error()) 360 | return 361 | } 362 | 363 | if len(body) > 0 { 364 | page.SetHeader(header) 365 | page.SetBody(body) 366 | return 367 | } 368 | log.Warnf("Empty response from %s", page.Location) 369 | } 370 | 371 | // This replaces all existing ErrorPages 372 | func (e *ErrorResponse) Update(pages map[string][]int) { 373 | e.Lock() 374 | defer e.Unlock() 375 | 376 | e.pages = make(map[int]*ErrorPage) 377 | 378 | for loc, codes := range pages { 379 | page := &ErrorPage{ 380 | StatusCodes: codes, 381 | Location: loc, 382 | } 383 | 384 | for _, code := range codes { 385 | e.pages[code] = page 386 | } 387 | go e.fetch(page) 388 | } 389 | } 390 | 391 | func (e *ErrorResponse) CheckResponse(pr *ProxyRequest) bool { 392 | 393 | errPage := e.Get(pr.Response.StatusCode) 394 | if errPage != nil { 395 | // load the cached headers 396 | header := pr.ResponseWriter.Header() 397 | for key, val := range errPage.Header() { 398 | header[key] = val 399 | } 400 | 401 | pr.ResponseWriter.WriteHeader(pr.Response.StatusCode) 402 | pr.ResponseWriter.Write(errPage.Body()) 403 | return false 404 | } 405 | 406 | return true 407 | } 408 | 409 | func logRequest(req *http.Request, statusCode int, backend string, proxyError error, duration time.Duration) { 410 | id := req.Header.Get("X-Request-Id") 411 | method := req.Method 412 | url := req.Host + req.RequestURI 413 | agent := req.UserAgent() 414 | 415 | clientIP := req.Header.Get("X-Forwarded-For") 416 | if clientIP == "" { 417 | clientIP = req.RemoteAddr 418 | } 419 | 420 | errStr := fmt.Sprintf("%v", proxyError) 421 | fmtStr := "id=%s method=%s client-ip=%s url=%s backend=%s status=%d 
duration=%s agent=%s, err=%s" 422 | log.Printf(fmtStr, id, method, clientIP, url, backend, statusCode, duration, agent, errStr) 423 | } 424 | 425 | func logProxyRequest(pr *ProxyRequest) bool { 426 | // TODO: we may to be able to switch this off 427 | if pr == nil || pr.Request == nil { 428 | return true 429 | } 430 | 431 | duration := pr.FinishTime.Sub(pr.StartTime) 432 | 433 | var backend string 434 | if pr.Response != nil && pr.Response.Request != nil && pr.Response.Request.URL != nil { 435 | backend = pr.Response.Request.URL.Host 436 | } 437 | 438 | logRequest(pr.Request, pr.Response.StatusCode, backend, pr.ProxyError, duration) 439 | return true 440 | } 441 | -------------------------------------------------------------------------------- /registry.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "sort" 7 | "strings" 8 | "sync" 9 | 10 | "github.com/litl/shuttle/client" 11 | "github.com/litl/shuttle/log" 12 | ) 13 | 14 | var ( 15 | ErrNoService = fmt.Errorf("service does not exist") 16 | ErrNoBackend = fmt.Errorf("backend does not exist") 17 | ErrDuplicateService = fmt.Errorf("service already exists") 18 | ErrDuplicateBackend = fmt.Errorf("backend already exists") 19 | ) 20 | 21 | type multiError struct { 22 | errors []error 23 | } 24 | 25 | func (e *multiError) Add(err error) { 26 | e.errors = append(e.errors, err) 27 | } 28 | 29 | func (e multiError) Len() int { 30 | return len(e.errors) 31 | } 32 | 33 | func (e multiError) Error() string { 34 | msgs := make([]string, len(e.errors)) 35 | for i, err := range e.errors { 36 | msgs[i] = err.Error() 37 | } 38 | return strings.Join(msgs, ", ") 39 | } 40 | 41 | func (e multiError) String() string { 42 | return e.Error() 43 | } 44 | 45 | type VirtualHost struct { 46 | sync.Mutex 47 | Name string 48 | // All services registered under this vhost name. 
49 | services []*Service 50 | // The last one we returned so we can RoundRobin them. 51 | last int 52 | } 53 | 54 | func (v *VirtualHost) Len() int { 55 | v.Lock() 56 | defer v.Unlock() 57 | return len(v.services) 58 | } 59 | 60 | // Insert a service 61 | // do nothing if the service already is registered 62 | func (v *VirtualHost) Add(svc *Service) { 63 | v.Lock() 64 | defer v.Unlock() 65 | for _, s := range v.services { 66 | if s.Name == svc.Name { 67 | log.Debugf("Service %s already registered in VirtualHost %s", svc.Name, v.Name) 68 | return 69 | } 70 | } 71 | 72 | // TODO: is this the best place to log these? 73 | svcCfg := svc.Config() 74 | for _, backend := range svcCfg.Backends { 75 | log.Printf("Adding backend http://%s to VirtualHost %s", backend.Addr, v.Name) 76 | } 77 | v.services = append(v.services, svc) 78 | } 79 | 80 | func (v *VirtualHost) Remove(svc *Service) { 81 | v.Lock() 82 | defer v.Unlock() 83 | 84 | found := -1 85 | for i, s := range v.services { 86 | if s.Name == svc.Name { 87 | found = i 88 | break 89 | } 90 | } 91 | 92 | if found < 0 { 93 | log.Debugf("Service %s not found under VirtualHost %s", svc.Name, v.Name) 94 | return 95 | } 96 | 97 | // safe way to get the backends info for logging 98 | svcCfg := svc.Config() 99 | 100 | // Now removing this Service 101 | for _, backend := range svcCfg.Backends { 102 | log.Printf("Removing backend http://%s from VirtualHost %s", backend.Addr, v.Name) 103 | } 104 | 105 | v.services = append(v.services[:found], v.services[found+1:]...) 
106 | } 107 | 108 | // Return a *Service for this VirtualHost 109 | func (v *VirtualHost) Service() *Service { 110 | v.Lock() 111 | defer v.Unlock() 112 | 113 | if len(v.services) == 0 { 114 | log.Warnf("No Services registered for VirtualHost %s", v.Name) 115 | return nil 116 | } 117 | 118 | // start cycling through the services in case one has no backends available 119 | for i := 1; i <= len(v.services); i++ { 120 | idx := (v.last + i) % len(v.services) 121 | if v.services[idx].Available() > 0 { 122 | v.last = idx 123 | return v.services[idx] 124 | } 125 | } 126 | 127 | // even if all backends are down, return a service so that the request can 128 | // be processed normally (we may have a custom 502 error page for this) 129 | return v.services[v.last] 130 | } 131 | 132 | //TODO: notify or prevent vhost name conflicts between services. 133 | // ServiceRegistry is a global container for all configured services. 134 | type ServiceRegistry struct { 135 | sync.Mutex 136 | svcs map[string]*Service 137 | // Multiple services may respond from a single vhost 138 | vhosts map[string]*VirtualHost 139 | 140 | // Global config to apply to new services. 141 | cfg client.Config 142 | } 143 | 144 | // Update the global config state, including services and backends. 145 | // This does not remove any Services, but will add or update any provided in 146 | // the config. 
147 | func (s *ServiceRegistry) UpdateConfig(cfg client.Config) error { 148 | 149 | // Set globals 150 | // TODO: we might need to unset something 151 | // TODO: this should remove services and backends to match the submitted config 152 | 153 | if cfg.Balance != "" { 154 | s.cfg.Balance = cfg.Balance 155 | } 156 | if cfg.CheckInterval != 0 { 157 | s.cfg.CheckInterval = cfg.CheckInterval 158 | } 159 | if cfg.Fall != 0 { 160 | s.cfg.Fall = cfg.Fall 161 | } 162 | if cfg.Rise != 0 { 163 | s.cfg.Rise = cfg.Rise 164 | } 165 | if cfg.ClientTimeout != 0 { 166 | s.cfg.ClientTimeout = cfg.ClientTimeout 167 | } 168 | if cfg.ServerTimeout != 0 { 169 | s.cfg.ServerTimeout = cfg.ServerTimeout 170 | } 171 | if cfg.DialTimeout != 0 { 172 | s.cfg.DialTimeout = cfg.DialTimeout 173 | } 174 | 175 | // apply the https rediect flag 176 | if httpsRedirect { 177 | s.cfg.HTTPSRedirect = true 178 | } 179 | 180 | invalidPorts := []string{ 181 | // FIXME: lookup bound addresses some other way. We may have multiple 182 | // http listeners, as well as all listening Services. 183 | // listenAddr[strings.Index(listenAddr, ":")+1:], 184 | adminListenAddr[strings.Index(adminListenAddr, ":")+1:], 185 | } 186 | 187 | errors := &multiError{} 188 | 189 | for _, svc := range cfg.Services { 190 | for _, port := range invalidPorts { 191 | if strings.HasSuffix(svc.Addr, port) { 192 | // TODO: report conflicts between service listeners 193 | errors.Add(fmt.Errorf("Port conflict: %s port %s already bound by shuttle", svc.Name, port)) 194 | continue 195 | } 196 | } 197 | 198 | // Add a new service, or update an existing one. 
199 | if Registry.GetService(svc.Name) == nil { 200 | if err := Registry.AddService(svc); err != nil { 201 | log.Errorln("Unable to add service %s: %s", svc.Name, err.Error()) 202 | errors.Add(err) 203 | continue 204 | } 205 | } else if err := Registry.UpdateService(svc); err != nil { 206 | log.Errorln("Unable to update service %s: %s", svc.Name, err.Error()) 207 | errors.Add(err) 208 | continue 209 | } 210 | } 211 | 212 | go writeStateConfig() 213 | 214 | if errors.Len() == 0 { 215 | return nil 216 | } 217 | return errors 218 | } 219 | 220 | // Return a service by name. 221 | func (s *ServiceRegistry) GetService(name string) *Service { 222 | s.Lock() 223 | defer s.Unlock() 224 | return s.svcs[name] 225 | } 226 | 227 | // Return a service that handles a particular vhost by name. 228 | func (s *ServiceRegistry) GetVHostService(name string) *Service { 229 | s.Lock() 230 | defer s.Unlock() 231 | 232 | if vhost := s.vhosts[name]; vhost != nil { 233 | return vhost.Service() 234 | } 235 | return nil 236 | } 237 | 238 | func (s *ServiceRegistry) VHostsLen() int { 239 | s.Lock() 240 | defer s.Unlock() 241 | return len(s.vhosts) 242 | } 243 | 244 | // Add a new service to the Registry. 245 | // Do not replace an existing service. 
246 | func (s *ServiceRegistry) AddService(svcCfg client.ServiceConfig) error { 247 | s.Lock() 248 | defer s.Unlock() 249 | 250 | log.Debug("Adding service:", svcCfg.Name) 251 | if _, ok := s.svcs[svcCfg.Name]; ok { 252 | log.Debug("Service already exists:", svcCfg.Name) 253 | return ErrDuplicateService 254 | } 255 | 256 | s.setServiceDefaults(&svcCfg) 257 | svcCfg = svcCfg.SetDefaults() 258 | 259 | service := NewService(svcCfg) 260 | err := service.start() 261 | if err != nil { 262 | return err 263 | } 264 | 265 | s.svcs[service.Name] = service 266 | 267 | svcCfg.VirtualHosts = filterEmpty(svcCfg.VirtualHosts) 268 | for _, name := range svcCfg.VirtualHosts { 269 | vhost := s.vhosts[name] 270 | if vhost == nil { 271 | vhost = &VirtualHost{Name: name} 272 | s.vhosts[name] = vhost 273 | } 274 | vhost.Add(service) 275 | } 276 | 277 | return nil 278 | } 279 | 280 | // Replace the service's configuration, or update its list of backends. 281 | // Replacing a configuration will shutdown the existing service, and start a 282 | // new one, which will cause the listening socket to be temporarily 283 | // unavailable. 284 | func (s *ServiceRegistry) UpdateService(newCfg client.ServiceConfig) error { 285 | s.Lock() 286 | defer s.Unlock() 287 | 288 | log.Debug("Updating Service:", newCfg.Name) 289 | service, ok := s.svcs[newCfg.Name] 290 | if !ok { 291 | log.Debug("Service not found:", newCfg.Name) 292 | return ErrNoService 293 | } 294 | 295 | currentCfg := service.Config() 296 | newCfg = currentCfg.Merge(newCfg) 297 | 298 | if err := service.UpdateConfig(newCfg); err != nil { 299 | return err 300 | } 301 | 302 | // Lots of looping here (including fetching the Config, but the cardinality 303 | // of Backends shouldn't be very large, and the default RoundRobin balancing 304 | // is much simpler with a slice. 
305 | 306 | // we're going to update just the backends for this config 307 | // get a map of what's already running 308 | currentBackends := make(map[string]client.BackendConfig) 309 | for _, backendCfg := range currentCfg.Backends { 310 | currentBackends[backendCfg.Name] = backendCfg 311 | } 312 | 313 | // Keep existing backends when they have equivalent config. 314 | // Update changed backends, and add new ones. 315 | for _, newBackend := range newCfg.Backends { 316 | current, ok := currentBackends[newBackend.Name] 317 | if ok && current.Equal(newBackend) { 318 | log.Debugf("Backend %s/%s unchanged", service.Name, current.Name) 319 | // no change for this one 320 | delete(currentBackends, current.Name) 321 | continue 322 | } 323 | 324 | // we need to remove and re-add this backend 325 | log.Debugf("Updating Backend %s/%s", service.Name, newBackend.Name) 326 | service.remove(newBackend.Name) 327 | service.add(NewBackend(newBackend)) 328 | 329 | delete(currentBackends, newBackend.Name) 330 | } 331 | 332 | // remove any left over backends 333 | for name := range currentBackends { 334 | log.Debugf("Removing Backend %s/%s", service.Name, name) 335 | service.remove(name) 336 | } 337 | 338 | if currentCfg.Equal(newCfg) { 339 | log.Debugf("Service Unchanged %s", service.Name) 340 | return nil 341 | } 342 | 343 | // replace error pages if there's any change 344 | if !reflect.DeepEqual(service.errPagesCfg, newCfg.ErrorPages) { 345 | log.Debugf("Updating ErrorPages") 346 | service.errPagesCfg = newCfg.ErrorPages 347 | service.errorPages.Update(newCfg.ErrorPages) 348 | } 349 | 350 | s.updateVHosts(service, filterEmpty(newCfg.VirtualHosts)) 351 | 352 | return nil 353 | } 354 | 355 | // update the VirtualHost entries for this service 356 | // only to be called from UpdateService. 
357 | func (s *ServiceRegistry) updateVHosts(service *Service, newHosts []string) { 358 | // We could just clear the vhosts and the new list since we're doing 359 | // this all while the registry is locked, but because we want sane log 360 | // messages about adding remove endpoints, we have to diff the slices 361 | // anyway. 362 | 363 | oldHosts := service.VirtualHosts 364 | sort.Strings(oldHosts) 365 | sort.Strings(newHosts) 366 | 367 | // find the relative compliments of each set of hostnames 368 | var remove, add []string 369 | i, j := 0, 0 370 | for i < len(oldHosts) && j < len(newHosts) { 371 | if oldHosts[i] != newHosts[j] { 372 | if oldHosts[i] < newHosts[j] { 373 | // oldHosts[i] can't be in newHosts 374 | remove = append(remove, oldHosts[i]) 375 | i++ 376 | continue 377 | } else { 378 | // newHosts[j] can't be in oldHosts 379 | add = append(add, newHosts[j]) 380 | j++ 381 | continue 382 | } 383 | } 384 | i++ 385 | j++ 386 | } 387 | if i < len(oldHosts) { 388 | // there's more! 389 | remove = append(remove, oldHosts[i:]...) 390 | } 391 | if j < len(newHosts) { 392 | add = append(add, newHosts[j:]...) 
393 | } 394 | 395 | // remove existing vhost entries for this service, and add new ones 396 | for _, name := range remove { 397 | vhost := s.vhosts[name] 398 | if vhost != nil { 399 | vhost.Remove(service) 400 | } 401 | if vhost.Len() == 0 { 402 | log.Println("Removing empty VirtualHost", name) 403 | delete(s.vhosts, name) 404 | } 405 | } 406 | 407 | for _, name := range add { 408 | vhost := s.vhosts[name] 409 | if vhost == nil { 410 | vhost = &VirtualHost{Name: name} 411 | s.vhosts[name] = vhost 412 | } 413 | vhost.Add(service) 414 | } 415 | 416 | // and replace the list 417 | service.VirtualHosts = newHosts 418 | } 419 | 420 | func (s *ServiceRegistry) RemoveService(name string) error { 421 | s.Lock() 422 | defer s.Unlock() 423 | 424 | svc, ok := s.svcs[name] 425 | if ok { 426 | log.Debugf("Removing Service %s", svc.Name) 427 | delete(s.svcs, name) 428 | svc.stop() 429 | 430 | for host, vhost := range s.vhosts { 431 | vhost.Remove(svc) 432 | 433 | removeVhost := true 434 | for _, service := range s.svcs { 435 | for _, h := range service.VirtualHosts { 436 | if host == h { 437 | // FIXME: is this still correct? NOT TESTED! 
438 | // vhost exists in another service, so leave it 439 | removeVhost = false 440 | break 441 | } 442 | } 443 | } 444 | if removeVhost { 445 | log.Debugf("Removing VirtualHost %s", host) 446 | delete(s.vhosts, host) 447 | 448 | } 449 | } 450 | 451 | return nil 452 | } 453 | return ErrNoService 454 | } 455 | 456 | func (s *ServiceRegistry) ServiceStats(serviceName string) (ServiceStat, error) { 457 | s.Lock() 458 | defer s.Unlock() 459 | 460 | service, ok := s.svcs[serviceName] 461 | if !ok { 462 | return ServiceStat{}, ErrNoService 463 | } 464 | return service.Stats(), nil 465 | } 466 | 467 | func (s *ServiceRegistry) ServiceConfig(serviceName string) (client.ServiceConfig, error) { 468 | s.Lock() 469 | defer s.Unlock() 470 | 471 | service, ok := s.svcs[serviceName] 472 | if !ok { 473 | return client.ServiceConfig{}, ErrNoService 474 | } 475 | return service.Config(), nil 476 | } 477 | 478 | func (s *ServiceRegistry) BackendStats(serviceName, backendName string) (BackendStat, error) { 479 | s.Lock() 480 | defer s.Unlock() 481 | 482 | service, ok := s.svcs[serviceName] 483 | if !ok { 484 | return BackendStat{}, ErrNoService 485 | } 486 | 487 | for _, backend := range service.Backends { 488 | if backendName == backend.Name { 489 | return backend.Stats(), nil 490 | } 491 | } 492 | return BackendStat{}, ErrNoBackend 493 | } 494 | 495 | // Add or update a Backend on an existing Service. 496 | func (s *ServiceRegistry) AddBackend(svcName string, backendCfg client.BackendConfig) error { 497 | s.Lock() 498 | defer s.Unlock() 499 | 500 | service, ok := s.svcs[svcName] 501 | if !ok { 502 | return ErrNoService 503 | } 504 | 505 | log.Debugf("Adding Backend %s/%s", service.Name, backendCfg.Name) 506 | service.add(NewBackend(backendCfg)) 507 | return nil 508 | } 509 | 510 | // Remove a Backend from an existing Service. 
511 | func (s *ServiceRegistry) RemoveBackend(svcName, backendName string) error { 512 | s.Lock() 513 | defer s.Unlock() 514 | 515 | log.Debugf("Removing Backend %s/%s", svcName, backendName) 516 | service, ok := s.svcs[svcName] 517 | if !ok { 518 | return ErrNoService 519 | } 520 | 521 | if !service.remove(backendName) { 522 | return ErrNoBackend 523 | } 524 | return nil 525 | } 526 | 527 | func (s *ServiceRegistry) Stats() []ServiceStat { 528 | s.Lock() 529 | defer s.Unlock() 530 | 531 | stats := []ServiceStat{} 532 | for _, service := range s.svcs { 533 | stats = append(stats, service.Stats()) 534 | } 535 | 536 | return stats 537 | } 538 | 539 | func (s *ServiceRegistry) Config() client.Config { 540 | s.Lock() 541 | defer s.Unlock() 542 | 543 | // make sure the old ServiceConfigs are purged when we copy the slice 544 | s.cfg.Services = nil 545 | 546 | cfg := s.cfg 547 | for _, service := range s.svcs { 548 | cfg.Services = append(cfg.Services, service.Config()) 549 | } 550 | 551 | return cfg 552 | } 553 | 554 | func (s *ServiceRegistry) String() string { 555 | return string(marshal(s.Config())) 556 | } 557 | 558 | // set any missing global configuration on a new ServiceConfig. 559 | // ServiceRegistry *must* be locked. 
560 | func (s *ServiceRegistry) setServiceDefaults(svc *client.ServiceConfig) { 561 | if svc.Balance == "" && s.cfg.Balance != "" { 562 | svc.Balance = s.cfg.Balance 563 | } 564 | if svc.CheckInterval == 0 && s.cfg.CheckInterval != 0 { 565 | svc.CheckInterval = s.cfg.CheckInterval 566 | } 567 | if svc.Fall == 0 && s.cfg.Fall != 0 { 568 | svc.Fall = s.cfg.Fall 569 | } 570 | if svc.Rise == 0 && s.cfg.Rise != 0 { 571 | svc.Rise = s.cfg.Rise 572 | } 573 | if svc.ClientTimeout == 0 && s.cfg.ClientTimeout != 0 { 574 | svc.ClientTimeout = s.cfg.ClientTimeout 575 | } 576 | if svc.ServerTimeout == 0 && s.cfg.ServerTimeout != 0 { 577 | svc.ServerTimeout = s.cfg.ServerTimeout 578 | } 579 | if svc.DialTimeout == 0 && s.cfg.DialTimeout != 0 { 580 | svc.DialTimeout = s.cfg.DialTimeout 581 | } 582 | if s.cfg.HTTPSRedirect { 583 | svc.HTTPSRedirect = true 584 | } 585 | } 586 | -------------------------------------------------------------------------------- /shuttle_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "io/ioutil" 7 | "net" 8 | "os" 9 | "sync" 10 | "testing" 11 | "time" 12 | 13 | "github.com/litl/shuttle/client" 14 | "github.com/litl/shuttle/log" 15 | . "gopkg.in/check.v1" 16 | ) 17 | 18 | func init() { 19 | debug = os.Getenv("SHUTTLE_DEBUG") == "1" 20 | 21 | if debug { 22 | log.DefaultLogger.Level = log.DEBUG 23 | } else { 24 | log.DefaultLogger = log.New(ioutil.Discard, "", 0) 25 | } 26 | } 27 | 28 | // something that can wrap a gocheck.C testing.T or testing.B 29 | // Just add more methods as we need them. 
30 | type Tester interface { 31 | Fatal(args ...interface{}) 32 | Fatalf(format string, args ...interface{}) 33 | Log(args ...interface{}) 34 | Logf(format string, args ...interface{}) 35 | Assert(interface{}, Checker, ...interface{}) 36 | } 37 | 38 | func Test(t *testing.T) { TestingT(t) } 39 | 40 | type BasicSuite struct { 41 | servers []*testServer 42 | service *Service 43 | } 44 | 45 | var _ = Suite(&BasicSuite{}) 46 | 47 | // Make Setup and TearDown more generic, so we can bypass the gocheck Suite if 48 | // needed. 49 | func mySetup(s *BasicSuite, t Tester) { 50 | // start 4 possible backend servers 51 | for i := 0; i < 4; i++ { 52 | server, err := NewTestServer("127.0.0.1:0", t) 53 | if err != nil { 54 | t.Fatal(err) 55 | } 56 | s.servers = append(s.servers, server) 57 | } 58 | 59 | svcCfg := client.ServiceConfig{ 60 | Name: "testService", 61 | Addr: "127.0.0.1:2000", 62 | ClientTimeout: 1000, 63 | ServerTimeout: 1000, 64 | } 65 | 66 | if err := Registry.AddService(svcCfg); err != nil { 67 | t.Fatal(err) 68 | } 69 | 70 | s.service = Registry.GetService(svcCfg.Name) 71 | } 72 | 73 | // shutdown our backend servers 74 | func myTearDown(s *BasicSuite, t Tester) { 75 | for _, s := range s.servers { 76 | s.Stop() 77 | } 78 | 79 | // get rid of the servers refs too! 
80 | s.servers = nil 81 | 82 | // clear global defaults in Registry 83 | Registry.cfg.Balance = "" 84 | Registry.cfg.CheckInterval = 0 85 | Registry.cfg.Fall = 0 86 | Registry.cfg.Rise = 0 87 | Registry.cfg.ClientTimeout = 0 88 | Registry.cfg.ServerTimeout = 0 89 | Registry.cfg.DialTimeout = 0 90 | 91 | err := Registry.RemoveService(s.service.Name) 92 | if err != nil { 93 | t.Fatalf("could not remove service '%s': %s", s.service.Name, err) 94 | } 95 | } 96 | 97 | func (s *BasicSuite) SetUpTest(c *C) { 98 | mySetup(s, c) 99 | } 100 | 101 | func (s *BasicSuite) TearDownTest(c *C) { 102 | myTearDown(s, c) 103 | } 104 | 105 | // Add a default backend for the next server we have running 106 | func (s *BasicSuite) AddBackend(c Tester) { 107 | // get the backends via Config to use the Service's locking. 108 | svcCfg := s.service.Config() 109 | next := len(svcCfg.Backends) 110 | if next >= len(s.servers) { 111 | c.Fatal("no more servers") 112 | } 113 | 114 | name := fmt.Sprintf("backend_%d", next) 115 | cfg := client.BackendConfig{ 116 | Name: name, 117 | Addr: s.servers[next].addr, 118 | CheckAddr: s.servers[next].addr, 119 | } 120 | 121 | s.service.add(NewBackend(cfg)) 122 | } 123 | 124 | // Connect to address, and check response after write. 
125 | func checkResp(addr, expected string, c Tester) { 126 | conn, err := net.Dial("tcp", addr) 127 | if err != nil { 128 | c.Fatal(err) 129 | } 130 | defer conn.Close() 131 | 132 | if _, err := io.WriteString(conn, "testing\n"); err != nil { 133 | c.Fatal(err) 134 | } 135 | 136 | buff := make([]byte, 1024) 137 | n, err := conn.Read(buff) 138 | if err != nil { 139 | c.Fatal(err) 140 | } 141 | 142 | resp := string(buff[:n]) 143 | if resp == "" { 144 | c.Fatal("No response") 145 | } 146 | 147 | if expected != "" && resp != expected { 148 | c.Fatal("Expected ", expected, ", got ", resp) 149 | } 150 | } 151 | 152 | func (s *BasicSuite) TestSingleBackend(c *C) { 153 | s.AddBackend(c) 154 | 155 | checkResp(s.service.Addr, s.servers[0].addr, c) 156 | } 157 | 158 | func (s *BasicSuite) TestRoundRobin(c *C) { 159 | s.AddBackend(c) 160 | s.AddBackend(c) 161 | 162 | checkResp(s.service.Addr, s.servers[0].addr, c) 163 | checkResp(s.service.Addr, s.servers[1].addr, c) 164 | checkResp(s.service.Addr, s.servers[0].addr, c) 165 | checkResp(s.service.Addr, s.servers[1].addr, c) 166 | } 167 | 168 | func (s *BasicSuite) TestWeightedRoundRobin(c *C) { 169 | s.AddBackend(c) 170 | s.AddBackend(c) 171 | s.AddBackend(c) 172 | 173 | s.service.Backends[0].Weight = 1 174 | s.service.Backends[1].Weight = 2 175 | s.service.Backends[2].Weight = 3 176 | 177 | // we already checked that we connect to the correct backends, 178 | // so skip the tcp connection this time. 
179 | 180 | // one from the first server 181 | c.Assert(s.service.next()[0].Name, Equals, "backend_0") 182 | // A weight of 2 should return twice 183 | c.Assert(s.service.next()[0].Name, Equals, "backend_1") 184 | c.Assert(s.service.next()[0].Name, Equals, "backend_1") 185 | // And a weight of 3 should return thrice 186 | c.Assert(s.service.next()[0].Name, Equals, "backend_2") 187 | c.Assert(s.service.next()[0].Name, Equals, "backend_2") 188 | c.Assert(s.service.next()[0].Name, Equals, "backend_2") 189 | // and once around or good measure 190 | c.Assert(s.service.next()[0].Name, Equals, "backend_0") 191 | } 192 | 193 | func (s *BasicSuite) TestLeastConn(c *C) { 194 | // replace out default service with one using LeastConn balancing 195 | Registry.RemoveService("testService") 196 | svcCfg := client.ServiceConfig{ 197 | Name: "testService", 198 | Addr: "127.0.0.1:2223", 199 | Balance: "LC", 200 | } 201 | 202 | if err := Registry.AddService(svcCfg); err != nil { 203 | c.Fatal(err) 204 | } 205 | s.service = Registry.GetService("testService") 206 | 207 | s.AddBackend(c) 208 | s.AddBackend(c) 209 | 210 | // tie up 4 connections to the backends 211 | buff := make([]byte, 64) 212 | for i := 0; i < 4; i++ { 213 | conn, e := net.Dial("tcp", s.service.Addr) 214 | if e != nil { 215 | c.Fatal(e) 216 | } 217 | // we need to make a call on this proxy to ensure the backend 218 | // connection is complete. 
219 | if _, err := io.WriteString(conn, "connect\n"); err != nil { 220 | c.Fatal(err) 221 | } 222 | 223 | n, err := conn.Read(buff) 224 | if err != nil || n == 0 { 225 | c.Fatal("no response from backend") 226 | } 227 | 228 | defer conn.Close() 229 | } 230 | 231 | s.AddBackend(c) 232 | 233 | checkResp(s.service.Addr, s.servers[2].addr, c) 234 | checkResp(s.service.Addr, s.servers[2].addr, c) 235 | } 236 | 237 | // Test health check by taking down a server from a configured backend 238 | func (s *BasicSuite) TestFailedCheck(c *C) { 239 | s.service.CheckInterval = 500 240 | s.service.Fall = 1 241 | s.AddBackend(c) 242 | 243 | stats := s.service.Stats() 244 | c.Assert(stats.Backends[0].Up, Equals, true) 245 | 246 | // Stop the server, and see if the backend shows Down after our check 247 | // interval. 248 | s.servers[0].Stop() 249 | time.Sleep(800 * time.Millisecond) 250 | 251 | stats = s.service.Stats() 252 | c.Assert(stats.Backends[0].Up, Equals, false) 253 | c.Assert(stats.Backends[0].CheckFail, Equals, 1) 254 | 255 | // now try and connect to the service 256 | conn, err := net.Dial("tcp", s.service.Addr) 257 | if err != nil { 258 | // we should still get an initial connection 259 | c.Fatal(err) 260 | } 261 | 262 | b := make([]byte, 1024) 263 | n, err := conn.Read(b) 264 | if n != 0 || err != io.EOF { 265 | c.Fatal("connection should have been closed") 266 | } 267 | 268 | // now bring that server back up 269 | server, err := NewTestServer(s.servers[0].addr, c) 270 | if err != nil { 271 | c.Fatal(err) 272 | } 273 | s.servers[0] = server 274 | 275 | time.Sleep(800 * time.Millisecond) 276 | stats = s.service.Stats() 277 | c.Assert(stats.Backends[0].Up, Equals, true) 278 | } 279 | 280 | // Make sure the connection is re-dispatched when Dialing a backend fails 281 | func (s *BasicSuite) TestConnectAny(c *C) { 282 | s.service.CheckInterval = 2000 283 | s.service.Fall = 2 284 | s.AddBackend(c) 285 | s.AddBackend(c) 286 | 287 | // kill the first server 288 | 
s.servers[0].Stop() 289 | 290 | stats := s.service.Stats() 291 | c.Assert(stats.Backends[0].Up, Equals, true) 292 | 293 | // Backend 0 still shows up, but we should get connected to backend 1 294 | checkResp(s.service.Addr, s.servers[1].addr, c) 295 | } 296 | 297 | // Update a backend in place 298 | func (s *BasicSuite) TestUpdateBackend(c *C) { 299 | s.service.CheckInterval = 500 300 | s.service.Fall = 1 301 | s.AddBackend(c) 302 | 303 | cfg := s.service.Config() 304 | backendCfg := cfg.Backends[0] 305 | 306 | c.Assert(backendCfg.CheckAddr, Equals, backendCfg.Addr) 307 | 308 | backendCfg.CheckAddr = "" 309 | s.service.add(NewBackend(backendCfg)) 310 | 311 | // see if the config reflects the new backend 312 | cfg = s.service.Config() 313 | c.Assert(len(cfg.Backends), Equals, 1) 314 | c.Assert(cfg.Backends[0].CheckAddr, Equals, "") 315 | 316 | // Stopping the server should not take down the backend 317 | // since there is no longer a Check address. 318 | s.servers[0].Stop() 319 | time.Sleep(800 * time.Millisecond) 320 | 321 | stats := s.service.Stats() 322 | c.Assert(stats.Backends[0].Up, Equals, true) 323 | // should have been no check failures 324 | c.Assert(stats.Backends[0].CheckFail, Equals, 0) 325 | } 326 | 327 | // Test removal of a single Backend from a service with multiple. 
328 | func (s *BasicSuite) TestRemoveBackend(c *C) { 329 | s.AddBackend(c) 330 | s.AddBackend(c) 331 | 332 | stats, err := Registry.ServiceStats("testService") 333 | if err != nil { 334 | c.Fatal(err) 335 | } 336 | 337 | c.Assert(len(stats.Backends), Equals, 2) 338 | 339 | backend1 := stats.Backends[0].Name 340 | 341 | err = Registry.RemoveBackend("testService", backend1) 342 | if err != nil { 343 | c.Fatal(err) 344 | } 345 | 346 | stats, err = Registry.ServiceStats("testService") 347 | if err != nil { 348 | c.Fatal(err) 349 | } 350 | 351 | c.Assert(len(stats.Backends), Equals, 1) 352 | 353 | _, err = Registry.BackendStats("testService", backend1) 354 | c.Assert(err, Equals, ErrNoBackend) 355 | } 356 | 357 | func (s *BasicSuite) TestInvalidUpdateService(c *C) { 358 | svcCfg := client.ServiceConfig{ 359 | Name: "Update", 360 | Addr: "127.0.0.1:9324", 361 | } 362 | 363 | if err := Registry.AddService(svcCfg); err != nil { 364 | c.Fatal(err) 365 | } 366 | 367 | svc := Registry.GetService("Update") 368 | if svc == nil { 369 | c.Fatal(ErrNoService) 370 | } 371 | 372 | svcCfg.Addr = "127.0.0.1:9425" 373 | 374 | // Make sure we can't add the same service again 375 | if err := Registry.AddService(svcCfg); err == nil { 376 | c.Fatal(err) 377 | } 378 | 379 | // the update should fail, because it would require a new listener 380 | if err := Registry.UpdateService(svcCfg); err == nil { 381 | c.Fatal(err) 382 | } 383 | 384 | // change the addres back, and try to update ClientTimeout 385 | svcCfg.Addr = "127.0.0.1:9324" 386 | svcCfg.ClientTimeout = 1234 387 | 388 | // the update should fail, because it would require a new listener 389 | if err := Registry.UpdateService(svcCfg); err == nil { 390 | c.Fatal(err) 391 | } 392 | 393 | if err := Registry.RemoveService("Update"); err != nil { 394 | c.Fatal(err) 395 | } 396 | } 397 | 398 | // check valid service updates 399 | func (s *BasicSuite) TestUpdateService(c *C) { 400 | svcCfg := client.ServiceConfig{ 401 | Name: "Update2", 402 | 
Addr: "127.0.0.1:9324", 403 | } 404 | 405 | if err := Registry.AddService(svcCfg); err != nil { 406 | c.Fatal(err) 407 | } 408 | 409 | svc := Registry.GetService("Update2") 410 | if svc == nil { 411 | c.Fatal(ErrNoService) 412 | } 413 | 414 | svcCfg.ServerTimeout = 1234 415 | svcCfg.HTTPSRedirect = true 416 | svcCfg.Fall = 5 417 | svcCfg.Rise = 6 418 | svcCfg.Balance = "LC" 419 | 420 | // Now update the service for real 421 | if err := Registry.UpdateService(svcCfg); err != nil { 422 | c.Fatal(err) 423 | } 424 | 425 | svc = Registry.GetService("Update2") 426 | if svc == nil { 427 | c.Fatal(ErrNoService) 428 | } 429 | c.Assert(svc.ServerTimeout, Equals, 1234*time.Millisecond) 430 | c.Assert(svc.HTTPSRedirect, Equals, true) 431 | c.Assert(svc.Fall, Equals, 5) 432 | c.Assert(svc.Rise, Equals, 6) 433 | c.Assert(svc.Balance, Equals, "LC") 434 | 435 | if err := Registry.RemoveService("Update2"); err != nil { 436 | c.Fatal(err) 437 | } 438 | } 439 | 440 | // Add backends and run response tests in parallel 441 | func (s *BasicSuite) TestParallel(c *C) { 442 | var wg sync.WaitGroup 443 | 444 | client := func(i int) { 445 | s.AddBackend(c) 446 | // do a bunch of new connections in unison 447 | for i := 0; i < 100; i++ { 448 | checkResp(s.service.Addr, "", c) 449 | } 450 | 451 | conn, err := net.Dial("tcp", s.service.Addr) 452 | if err != nil { 453 | // we should still get an initial connection 454 | c.Fatal(err) 455 | } 456 | defer conn.Close() 457 | 458 | // now do some more continuous ping-pongs with the server 459 | buff := make([]byte, 1024) 460 | 461 | for i := 0; i < 1000; i++ { 462 | n, err := io.WriteString(conn, "Testing testing\n") 463 | if err != nil || n == 0 { 464 | c.Fatal("couldn't write:", err) 465 | } 466 | 467 | n, err = conn.Read(buff) 468 | if err != nil || n == 0 { 469 | c.Fatal("no response:", err) 470 | } 471 | } 472 | wg.Done() 473 | } 474 | 475 | for i := 0; i < 4; i++ { 476 | wg.Add(1) 477 | go client(i) 478 | } 479 | 480 | wg.Wait() 481 | } 482 | 
483 | type UDPSuite struct { 484 | servers []*udpTestServer 485 | service *Service 486 | } 487 | 488 | var _ = Suite(&UDPSuite{}) 489 | 490 | func (s *UDPSuite) SetUpTest(c *C) { 491 | svcCfg := client.ServiceConfig{ 492 | Name: "testService", 493 | Addr: "127.0.0.1:11110", 494 | Network: "udp", 495 | } 496 | 497 | if err := Registry.AddService(svcCfg); err != nil { 498 | c.Fatal(err) 499 | } 500 | 501 | s.service = Registry.GetService(svcCfg.Name) 502 | } 503 | 504 | func (s *UDPSuite) TearDownTest(c *C) { 505 | for _, s := range s.servers { 506 | s.Stop() 507 | } 508 | 509 | // get rid of the servers refs too! 510 | s.servers = nil 511 | 512 | // clear global defaults in Registry 513 | Registry.cfg.Balance = "" 514 | Registry.cfg.CheckInterval = 0 515 | Registry.cfg.Fall = 0 516 | Registry.cfg.Rise = 0 517 | Registry.cfg.ClientTimeout = 0 518 | Registry.cfg.ServerTimeout = 0 519 | Registry.cfg.DialTimeout = 0 520 | 521 | err := Registry.RemoveService(s.service.Name) 522 | if err != nil { 523 | c.Fatalf("could not remove service '%s': %s", s.service.Name, err) 524 | } 525 | } 526 | 527 | // Add a UDP service, make sure it works, and remove it 528 | func (s *UDPSuite) TestAddRemove(c *C) { 529 | bckCfg := client.BackendConfig{ 530 | Name: "UDPServer", 531 | Addr: "127.0.0.1:11111", 532 | Network: "udp", 533 | } 534 | 535 | s.service.add(NewBackend(bckCfg)) 536 | 537 | lAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0") 538 | rAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:11110") 539 | conn, err := net.ListenUDP("udp", lAddr) 540 | if err != nil { 541 | c.Fatal(err) 542 | } 543 | 544 | n, err := conn.WriteToUDP([]byte("TEST"), rAddr) 545 | if err != nil { 546 | c.Fatal(err) 547 | } 548 | 549 | // try to make sure packets were delivered and read 550 | time.Sleep(100 * time.Millisecond) 551 | 552 | stats := s.service.Stats() 553 | c.Assert(stats.Rcvd, Equals, int64(n)) 554 | 555 | ok := s.service.remove("UDPServer") 556 | c.Assert(ok, Equals, true) 557 | 558 | 
stats = s.service.Stats() 559 | c.Assert(len(stats.Backends), Equals, 0) 560 | 561 | } 562 | 563 | // Make sure UDP Services work, and check our WeightedRoundRobin since we're 564 | // already testing it. 565 | func (s *UDPSuite) TestWeightedRoundRobin(c *C) { 566 | servers := make([]*udpTestServer, 3) 567 | 568 | var err error 569 | for i, _ := range servers { 570 | servers[i], err = NewUDPTestServer(fmt.Sprintf("127.0.0.1:1111%d", i+1), c) 571 | if err != nil { 572 | c.Fatal(err) 573 | } 574 | bckCfg := client.BackendConfig{ 575 | Name: fmt.Sprintf("UDPServer%d", i+1), 576 | Addr: servers[i].addr, 577 | Weight: i + 1, 578 | Network: "udp", 579 | } 580 | s.service.add(NewBackend(bckCfg)) 581 | } 582 | 583 | defer func() { 584 | for _, s := range servers { 585 | s.Stop() 586 | } 587 | }() 588 | 589 | lAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0") 590 | rAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:11110") 591 | 592 | conn, err := net.ListenUDP("udp", lAddr) 593 | if err != nil { 594 | c.Fatal(err) 595 | } 596 | 597 | for i := 0; i < 12; i++ { 598 | msg := fmt.Sprintf("TEST_%d", i) 599 | _, err := conn.WriteToUDP([]byte(msg), rAddr) 600 | if err != nil { 601 | c.Fatal(err) 602 | } 603 | } 604 | 605 | // The order that packets are delivered to the 3 servers 606 | time.Sleep(100 * time.Millisecond) 607 | rcvOrder := []int{ 608 | 0, 6, // servers[0] 609 | 1, 2, 7, 8, // servers[1] 610 | 3, 4, 5, 9, 10, 11, // servers[2] 611 | } 612 | 613 | packetNum := 0 614 | for _, srv := range servers { 615 | srv.Lock() 616 | for _, p := range srv.packets { 617 | c.Assert(string(p), Equals, fmt.Sprintf("TEST_%d", rcvOrder[packetNum])) 618 | packetNum++ 619 | } 620 | srv.Unlock() 621 | } 622 | } 623 | 624 | // Throw a lot of packets at the proxy then count what went through 625 | // This doesn't pass or fail, just logs how much made it to the backend. 
626 | func (s *UDPSuite) TestSpew(c *C) { 627 | server, err := NewUDPTestServer("127.0.0.1:11111", c) 628 | if err != nil { 629 | c.Fatal(err) 630 | } 631 | defer server.Stop() 632 | 633 | bckCfg := client.BackendConfig{ 634 | Name: "UDPServer", 635 | Addr: server.addr, 636 | Network: "udp", 637 | } 638 | s.service.add(NewBackend(bckCfg)) 639 | 640 | lAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0") 641 | rAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:11110") 642 | 643 | conn, err := net.ListenUDP("udp", lAddr) 644 | if err != nil { 645 | c.Fatal(err) 646 | } 647 | 648 | msg := []byte("10 BYTES") 649 | toSend := 10000 650 | for i := 0; i < toSend; i++ { 651 | n, err := conn.WriteToUDP(msg, rAddr) 652 | if err != nil || n != len(msg) { 653 | c.Fatal(fmt.Sprintf("%d %s", n, err)) 654 | } 655 | } 656 | 657 | // make sure everything the service received made it to the backend. 658 | time.Sleep(100 * time.Millisecond) 659 | stats := s.service.Stats() 660 | c.Logf("Sent %d packets", toSend) 661 | c.Logf("Proxied %d packets", stats.Rcvd/10) 662 | c.Logf("Received %d packets", server.count) 663 | } 664 | -------------------------------------------------------------------------------- /service.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "net/http" 7 | "sync" 8 | "sync/atomic" 9 | "time" 10 | 11 | "github.com/litl/shuttle/client" 12 | "github.com/litl/shuttle/log" 13 | ) 14 | 15 | var ( 16 | Registry = ServiceRegistry{ 17 | svcs: make(map[string]*Service), 18 | vhosts: make(map[string]*VirtualHost), 19 | } 20 | 21 | ErrInvalidServiceUpdate = fmt.Errorf("configuration requires a new service") 22 | ) 23 | 24 | type Service struct { 25 | sync.Mutex 26 | Name string 27 | Addr string 28 | HTTPSRedirect bool 29 | VirtualHosts []string 30 | Backends []*Backend 31 | Balance string 32 | CheckInterval int 33 | Fall int 34 | Rise int 35 | ClientTimeout time.Duration 36 | 
ServerTimeout time.Duration 37 | DialTimeout time.Duration 38 | Sent int64 39 | Rcvd int64 40 | Errors int64 41 | HTTPConns int64 42 | HTTPErrors int64 43 | HTTPActive int64 44 | Network string 45 | MaintenanceMode bool 46 | 47 | // Next returns the backends in priority order. 48 | next func() []*Backend 49 | 50 | // the last backend we used and the number of times we used it 51 | lastBackend int 52 | lastCount int 53 | 54 | // Each Service owns it's own netowrk listener 55 | tcpListener net.Listener 56 | udpListener *net.UDPConn 57 | 58 | // reverse proxy for vhost routing 59 | httpProxy *ReverseProxy 60 | 61 | // Custom Pages to backend error responses 62 | errorPages *ErrorResponse 63 | 64 | // the original map of errors as loaded in by a config 65 | errPagesCfg map[string][]int 66 | 67 | // net.Dialer so we don't need to allocate one every time 68 | dialer *net.Dialer 69 | } 70 | 71 | // Stats returned about a service 72 | type ServiceStat struct { 73 | Name string `json:"name"` 74 | Addr string `json:"address"` 75 | VirtualHosts []string `json:"virtual_hosts"` 76 | Backends []BackendStat `json:"backends"` 77 | Balance string `json:"balance"` 78 | CheckInterval int `json:"check_interval"` 79 | Fall int `json:"fall"` 80 | Rise int `json:"rise"` 81 | ClientTimeout int `json:"client_timeout"` 82 | ServerTimeout int `json:"server_timeout"` 83 | DialTimeout int `json:"connect_timeout"` 84 | Sent int64 `json:"sent"` 85 | Rcvd int64 `json:"received"` 86 | Errors int64 `json:"errors"` 87 | Conns int64 `json:"connections"` 88 | Active int64 `json:"active"` 89 | HTTPActive int64 `json:"http_active"` 90 | HTTPConns int64 `json:"http_connections"` 91 | HTTPErrors int64 `json:"http_errors"` 92 | } 93 | 94 | // Create a Service from a config struct 95 | func NewService(cfg client.ServiceConfig) *Service { 96 | s := &Service{ 97 | Name: cfg.Name, 98 | Addr: cfg.Addr, 99 | Balance: cfg.Balance, 100 | CheckInterval: cfg.CheckInterval, 101 | Fall: cfg.Fall, 102 | Rise: cfg.Rise, 
103 | HTTPSRedirect: cfg.HTTPSRedirect, 104 | VirtualHosts: cfg.VirtualHosts, 105 | ClientTimeout: time.Duration(cfg.ClientTimeout) * time.Millisecond, 106 | ServerTimeout: time.Duration(cfg.ServerTimeout) * time.Millisecond, 107 | DialTimeout: time.Duration(cfg.DialTimeout) * time.Millisecond, 108 | errorPages: NewErrorResponse(cfg.ErrorPages), 109 | errPagesCfg: cfg.ErrorPages, 110 | Network: cfg.Network, 111 | MaintenanceMode: cfg.MaintenanceMode, 112 | } 113 | 114 | // TODO: insert this into the backends too 115 | s.dialer = &net.Dialer{ 116 | Timeout: s.DialTimeout, 117 | KeepAlive: 30 * time.Second, 118 | } 119 | 120 | // create our reverse proxy, using our load-balancing Dial method 121 | proxyTransport := &http.Transport{ 122 | Dial: s.Dial, 123 | MaxIdleConnsPerHost: 10, 124 | } 125 | s.httpProxy = NewReverseProxy(proxyTransport) 126 | s.httpProxy.FlushInterval = time.Second 127 | s.httpProxy.Director = func(req *http.Request) { 128 | req.URL.Scheme = "http" 129 | } 130 | 131 | s.httpProxy.OnResponse = []ProxyCallback{logProxyRequest, s.errStats, s.errorPages.CheckResponse} 132 | 133 | if s.CheckInterval == 0 { 134 | s.CheckInterval = client.DefaultCheckInterval 135 | } 136 | if s.Rise == 0 { 137 | s.Rise = client.DefaultRise 138 | } 139 | if s.Fall == 0 { 140 | s.Fall = client.DefaultFall 141 | } 142 | 143 | if s.Network == "" { 144 | s.Network = client.DefaultNet 145 | } 146 | 147 | for _, b := range cfg.Backends { 148 | s.add(NewBackend(b)) 149 | } 150 | 151 | switch cfg.Balance { 152 | case client.RoundRobin: 153 | s.next = s.roundRobin 154 | case client.LeastConn: 155 | s.next = s.leastConn 156 | default: 157 | if cfg.Balance != "" { 158 | log.Warnf("invalid balancing algorithm '%s'", cfg.Balance) 159 | } 160 | s.next = s.roundRobin 161 | } 162 | 163 | return s 164 | } 165 | 166 | // Update the running configuration. 
167 | func (s *Service) UpdateConfig(cfg client.ServiceConfig) error { 168 | s.Lock() 169 | defer s.Unlock() 170 | 171 | if s.ClientTimeout != time.Duration(cfg.ClientTimeout)*time.Millisecond { 172 | return ErrInvalidServiceUpdate 173 | } 174 | 175 | if s.Addr != "" && s.Addr != cfg.Addr { 176 | return ErrInvalidServiceUpdate 177 | } 178 | 179 | s.CheckInterval = cfg.CheckInterval 180 | s.Fall = cfg.Fall 181 | s.Rise = cfg.Rise 182 | s.ServerTimeout = time.Duration(cfg.ServerTimeout) * time.Millisecond 183 | s.DialTimeout = time.Duration(cfg.DialTimeout) * time.Millisecond 184 | s.HTTPSRedirect = cfg.HTTPSRedirect 185 | s.MaintenanceMode = cfg.MaintenanceMode 186 | 187 | if s.Balance != cfg.Balance { 188 | s.Balance = cfg.Balance 189 | switch s.Balance { 190 | case client.RoundRobin: 191 | s.next = s.roundRobin 192 | case client.LeastConn: 193 | s.next = s.leastConn 194 | default: 195 | if cfg.Balance != "" { 196 | log.Warnf("invalid balancing algorithm '%s'", cfg.Balance) 197 | } 198 | s.next = s.roundRobin 199 | } 200 | } 201 | 202 | return nil 203 | } 204 | 205 | func (s *Service) Stats() ServiceStat { 206 | s.Lock() 207 | defer s.Unlock() 208 | 209 | stats := ServiceStat{ 210 | Name: s.Name, 211 | Addr: s.Addr, 212 | VirtualHosts: s.VirtualHosts, 213 | Balance: s.Balance, 214 | CheckInterval: s.CheckInterval, 215 | Fall: s.Fall, 216 | Rise: s.Rise, 217 | ClientTimeout: int(s.ClientTimeout / time.Millisecond), 218 | ServerTimeout: int(s.ServerTimeout / time.Millisecond), 219 | DialTimeout: int(s.DialTimeout / time.Millisecond), 220 | HTTPConns: s.HTTPConns, 221 | HTTPErrors: s.HTTPErrors, 222 | HTTPActive: atomic.LoadInt64(&s.HTTPActive), 223 | Rcvd: atomic.LoadInt64(&s.Rcvd), 224 | Sent: atomic.LoadInt64(&s.Sent), 225 | } 226 | 227 | for _, b := range s.Backends { 228 | stats.Backends = append(stats.Backends, b.Stats()) 229 | stats.Sent += b.Sent 230 | stats.Rcvd += b.Rcvd 231 | stats.Errors += b.Errors 232 | stats.Conns += b.Conns 233 | stats.Active += 
b.Active 234 | } 235 | 236 | return stats 237 | } 238 | 239 | func (s *Service) Config() client.ServiceConfig { 240 | s.Lock() 241 | defer s.Unlock() 242 | return s.config() 243 | } 244 | 245 | func (s *Service) config() client.ServiceConfig { 246 | 247 | config := client.ServiceConfig{ 248 | Name: s.Name, 249 | Addr: s.Addr, 250 | VirtualHosts: s.VirtualHosts, 251 | HTTPSRedirect: s.HTTPSRedirect, 252 | Balance: s.Balance, 253 | CheckInterval: s.CheckInterval, 254 | Fall: s.Fall, 255 | Rise: s.Rise, 256 | ClientTimeout: int(s.ClientTimeout / time.Millisecond), 257 | ServerTimeout: int(s.ServerTimeout / time.Millisecond), 258 | DialTimeout: int(s.DialTimeout / time.Millisecond), 259 | ErrorPages: s.errPagesCfg, 260 | Network: s.Network, 261 | MaintenanceMode: s.MaintenanceMode, 262 | } 263 | for _, b := range s.Backends { 264 | config.Backends = append(config.Backends, b.Config()) 265 | } 266 | 267 | return config 268 | } 269 | 270 | func (s *Service) String() string { 271 | return string(marshal(s.Config())) 272 | } 273 | 274 | func (s *Service) get(name string) *Backend { 275 | s.Lock() 276 | defer s.Unlock() 277 | 278 | for _, b := range s.Backends { 279 | if b.Name == name { 280 | return b 281 | } 282 | } 283 | return nil 284 | } 285 | 286 | // Add or replace a Backend in this service 287 | func (s *Service) add(backend *Backend) { 288 | s.Lock() 289 | defer s.Unlock() 290 | 291 | log.Printf("Adding %s backend %s{%s} for %s at %s", backend.Network, backend.Name, backend.Addr, s.Name, s.Addr) 292 | backend.up = true 293 | backend.rwTimeout = s.ServerTimeout 294 | backend.dialTimeout = s.DialTimeout 295 | backend.checkInterval = time.Duration(s.CheckInterval) * time.Millisecond 296 | 297 | // We may add some allowed protocol bridging in the future, but for now just fail 298 | if s.Network[:3] != backend.Network[:3] { 299 | log.Errorf("ERROR: backend %s cannot use network '%s'", backend.Name, backend.Network) 300 | } 301 | 302 | // replace an existing backend if 
we have it. 303 | for i, b := range s.Backends { 304 | if b.Name == backend.Name { 305 | b.Stop() 306 | s.Backends[i] = backend 307 | backend.Start() 308 | return 309 | } 310 | } 311 | 312 | s.Backends = append(s.Backends, backend) 313 | 314 | backend.Start() 315 | } 316 | 317 | // Remove a Backend by name 318 | func (s *Service) remove(name string) bool { 319 | s.Lock() 320 | defer s.Unlock() 321 | 322 | for i, b := range s.Backends { 323 | if b.Name == name { 324 | log.Printf("Removing %s backend %s{%s} for %s at %s", b.Network, b.Name, b.Addr, s.Name, s.Addr) 325 | last := len(s.Backends) - 1 326 | deleted := b 327 | s.Backends[i], s.Backends[last] = s.Backends[last], nil 328 | s.Backends = s.Backends[:last] 329 | deleted.Stop() 330 | return true 331 | } 332 | } 333 | return false 334 | } 335 | 336 | // Fill out and verify service 337 | func (s *Service) start() (err error) { 338 | s.Lock() 339 | defer s.Unlock() 340 | 341 | if s.Backends == nil { 342 | s.Backends = make([]*Backend, 0) 343 | } 344 | 345 | switch s.Network { 346 | case "tcp", "tcp4", "tcp6": 347 | log.Printf("Starting TCP listener for %s on %s", s.Name, s.Addr) 348 | 349 | s.tcpListener, err = newTimeoutListener(s.Network, s.Addr, s.ClientTimeout) 350 | if err != nil { 351 | return err 352 | } 353 | 354 | go s.runTCP() 355 | case "udp", "udp4", "udp6": 356 | log.Printf("Starting UDP listener for %s on %s", s.Name, s.Addr) 357 | 358 | laddr, err := net.ResolveUDPAddr(s.Network, s.Addr) 359 | if err != nil { 360 | return err 361 | } 362 | s.udpListener, err = net.ListenUDP(s.Network, laddr) 363 | if err != nil { 364 | return err 365 | } 366 | 367 | go s.runUDP() 368 | default: 369 | return fmt.Errorf("Error: unknown network '%s'", s.Network) 370 | } 371 | 372 | return nil 373 | } 374 | 375 | // Start the Service's Accept loop 376 | func (s *Service) runTCP() { 377 | for { 378 | conn, err := s.tcpListener.Accept() 379 | if err != nil { 380 | if err, ok := err.(net.Error); ok && err.Temporary() { 381 
| log.Warnln("WARN:", err) 382 | continue 383 | } 384 | // we must be getting shut down 385 | return 386 | } 387 | 388 | go s.connectTCP(conn) 389 | } 390 | } 391 | 392 | func (s *Service) runUDP() { 393 | buff := make([]byte, 65536) 394 | conn := s.udpListener 395 | 396 | // for UDP, we can proxy the data right here. 397 | for { 398 | n, _, err := conn.ReadFromUDP(buff) 399 | if err != nil { 400 | // we can't cleanly signal the Read to stop, so we have to 401 | // string-match this error. 402 | if err.Error() == "use of closed network connection" { 403 | // normal shutdown 404 | return 405 | } else if err, ok := err.(net.Error); ok && err.Temporary() { 406 | log.Warnf("WARN: %s", err.Error()) 407 | } else { 408 | // unexpected error, log it before exiting 409 | log.Errorf("ERROR: %s", err.Error()) 410 | atomic.AddInt64(&s.Errors, 1) 411 | return 412 | } 413 | } 414 | 415 | if n == 0 { 416 | continue 417 | } 418 | 419 | atomic.AddInt64(&s.Rcvd, int64(n)) 420 | 421 | backend := s.udpRoundRobin() 422 | if backend == nil { 423 | // this could produce a lot of message 424 | // TODO: log some %, or max rate of messages 425 | continue 426 | } 427 | 428 | n, err = conn.WriteTo(buff[:n], backend.udpAddr) 429 | if err != nil { 430 | if err, ok := err.(net.Error); ok && err.Temporary() { 431 | log.Warnf("WARN: %s", err.Error()) 432 | continue 433 | } 434 | 435 | log.Errorf("ERROR: %s", err.Error()) 436 | atomic.AddInt64(&s.Errors, 1) 437 | } else { 438 | atomic.AddInt64(&s.Sent, int64(n)) 439 | } 440 | } 441 | } 442 | 443 | // Return the addresses of the current backends in the order they would be balanced 444 | func (s *Service) NextAddrs() []string { 445 | backends := s.next() 446 | 447 | addrs := make([]string, len(backends)) 448 | for i, b := range backends { 449 | addrs[i] = b.Addr 450 | } 451 | return addrs 452 | } 453 | 454 | // Available returns the number of backends marked as Up 455 | func (s *Service) Available() int { 456 | s.Lock() 457 | defer s.Unlock() 458 | 
459 | if s.MaintenanceMode { 460 | return 0 461 | } 462 | 463 | available := 0 464 | for _, b := range s.Backends { 465 | if b.Up() { 466 | available++ 467 | } 468 | } 469 | return available 470 | } 471 | 472 | // Dial a backend by address. 473 | // This way we can wrap the connection to provide our timeout settings, as well 474 | // as hook it into the backend stats. 475 | // We return an error if we don't have a backend which matches. 476 | // If Dial returns an error, we wrap it in DialError, so that a ReverseProxy 477 | // can determine if it's safe to call RoundTrip again on a new host. 478 | func (s *Service) Dial(nw, addr string) (net.Conn, error) { 479 | s.Lock() 480 | 481 | var backend *Backend 482 | for _, b := range s.Backends { 483 | if b.Addr == addr { 484 | backend = b 485 | break 486 | } 487 | } 488 | s.Unlock() 489 | 490 | if backend == nil { 491 | return nil, DialError{fmt.Errorf("no backend matching %s", addr)} 492 | } 493 | 494 | srvConn, err := s.dialer.Dial(nw, backend.Addr) 495 | if err != nil { 496 | log.Errorf("ERROR: connecting to backend %s/%s: %s", s.Name, backend.Name, err) 497 | atomic.AddInt64(&backend.Errors, 1) 498 | return nil, DialError{err} 499 | } 500 | 501 | conn := &shuttleConn{ 502 | TCPConn: srvConn.(*net.TCPConn), 503 | rwTimeout: s.ServerTimeout, 504 | written: &backend.Sent, 505 | read: &backend.Rcvd, 506 | connected: &backend.HTTPActive, 507 | } 508 | 509 | atomic.AddInt64(&backend.Conns, 1) 510 | 511 | // NOTE: this relies on conn.Close being called, which *should* happen in 512 | // all cases, but may be at fault in the active count becomes skewed in 513 | // some error case. 514 | atomic.AddInt64(&backend.HTTPActive, 1) 515 | return conn, nil 516 | } 517 | 518 | func (s *Service) connectTCP(cliConn net.Conn) { 519 | backends := s.next() 520 | 521 | // Try the first backend given, but if that fails, cycle through them all 522 | // to make a best effort to connect the client. 
523 | for _, b := range backends { 524 | srvConn, err := s.dialer.Dial(b.Network, b.Addr) 525 | if err != nil { 526 | log.Errorf("ERROR: connecting to backend %s/%s: %s", s.Name, b.Name, err) 527 | atomic.AddInt64(&b.Errors, 1) 528 | continue 529 | } 530 | 531 | b.Proxy(srvConn, cliConn) 532 | return 533 | } 534 | 535 | log.Errorf("ERROR: no backend for %s", s.Name) 536 | cliConn.Close() 537 | } 538 | 539 | // Stop the Service's Accept loop by closing the Listener, 540 | // and stop all backends for this service. 541 | func (s *Service) stop() { 542 | s.Lock() 543 | defer s.Unlock() 544 | 545 | log.Printf("Stopping Listener for %s on %s:%s", s.Name, s.Network, s.Addr) 546 | for _, backend := range s.Backends { 547 | backend.Stop() 548 | } 549 | 550 | switch s.Network { 551 | case "tcp", "tcp4", "tcp6": 552 | // the service may have been bad, and the listener failed 553 | if s.tcpListener == nil { 554 | return 555 | } 556 | 557 | err := s.tcpListener.Close() 558 | if err != nil { 559 | log.Println(err) 560 | } 561 | 562 | case "udp", "udp4", "udp6": 563 | if s.udpListener == nil { 564 | return 565 | } 566 | err := s.udpListener.Close() 567 | if err != nil { 568 | log.Println(err) 569 | } 570 | } 571 | 572 | } 573 | 574 | // Provide a ServeHTTP method for out ReverseProxy 575 | func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { 576 | atomic.AddInt64(&s.HTTPConns, 1) 577 | atomic.AddInt64(&s.HTTPActive, 1) 578 | defer atomic.AddInt64(&s.HTTPActive, -1) 579 | 580 | if s.HTTPSRedirect { 581 | if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") != "https" { 582 | //TODO: verify RequestURI 583 | redirLoc := "https://" + r.Host + r.RequestURI 584 | http.Redirect(w, r, redirLoc, http.StatusMovedPermanently) 585 | return 586 | } 587 | } 588 | 589 | if s.MaintenanceMode { 590 | // TODO: Should we increment HTTPErrors here as well? 
591 | logRequest(r, http.StatusServiceUnavailable, "", nil, 0) 592 | errPage := s.errorPages.Get(http.StatusServiceUnavailable) 593 | if errPage != nil { 594 | headers := w.Header() 595 | for key, val := range errPage.Header() { 596 | headers[key] = val 597 | } 598 | } 599 | w.WriteHeader(http.StatusServiceUnavailable) 600 | if errPage != nil { 601 | w.Write(errPage.Body()) 602 | } 603 | return 604 | } 605 | 606 | s.httpProxy.ServeHTTP(w, r, s.NextAddrs()) 607 | } 608 | 609 | func (s *Service) errStats(pr *ProxyRequest) bool { 610 | if pr.ProxyError != nil { 611 | atomic.AddInt64(&s.HTTPErrors, 1) 612 | } 613 | return true 614 | } 615 | 616 | // A net.Listener that provides a read/write timeout 617 | type timeoutListener struct { 618 | *net.TCPListener 619 | rwTimeout time.Duration 620 | 621 | // these aren't reported yet, but our new counting connections need to 622 | // update something 623 | read int64 624 | written int64 625 | } 626 | 627 | func newTimeoutListener(netw, addr string, timeout time.Duration) (net.Listener, error) { 628 | lAddr, err := net.ResolveTCPAddr(netw, addr) 629 | if err != nil { 630 | return nil, err 631 | } 632 | 633 | l, err := net.ListenTCP(netw, lAddr) 634 | if err != nil { 635 | return nil, err 636 | } 637 | 638 | tl := &timeoutListener{ 639 | TCPListener: l, 640 | rwTimeout: timeout, 641 | } 642 | return tl, nil 643 | } 644 | 645 | func (l *timeoutListener) Accept() (net.Conn, error) { 646 | conn, err := l.TCPListener.AcceptTCP() 647 | if err != nil { 648 | return nil, err 649 | } 650 | 651 | conn.SetKeepAlive(true) 652 | conn.SetKeepAlivePeriod(3 * time.Minute) 653 | 654 | sc := &shuttleConn{ 655 | TCPConn: conn, 656 | rwTimeout: l.rwTimeout, 657 | read: &l.read, 658 | written: &l.written, 659 | } 660 | return sc, nil 661 | } 662 | -------------------------------------------------------------------------------- /admin_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | 
import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | "io/ioutil" 8 | "net" 9 | "net/http" 10 | "net/http/httptest" 11 | "net/url" 12 | "sync" 13 | 14 | "github.com/litl/shuttle/client" 15 | . "gopkg.in/check.v1" 16 | ) 17 | 18 | type HTTPSuite struct { 19 | servers []*testServer 20 | backendServers []*testHTTPServer 21 | httpSvr *httptest.Server 22 | httpAddr string 23 | httpPort string 24 | httpsAddr string 25 | httpsPort string 26 | } 27 | 28 | var _ = Suite(&HTTPSuite{}) 29 | 30 | func (s *HTTPSuite) SetUpSuite(c *C) { 31 | Registry = ServiceRegistry{ 32 | svcs: make(map[string]*Service), 33 | vhosts: make(map[string]*VirtualHost), 34 | } 35 | 36 | addHandlers() 37 | s.httpSvr = httptest.NewServer(nil) 38 | 39 | httpServer := &http.Server{ 40 | Addr: "127.0.0.1:0", 41 | } 42 | 43 | httpRouter = NewHostRouter(httpServer) 44 | httpReady := make(chan bool) 45 | go httpRouter.Start(httpReady) 46 | <-httpReady 47 | 48 | // now build an HTTPS server 49 | tlsCfg, err := loadCerts("./testdata") 50 | if err != nil { 51 | c.Fatal(err) 52 | return 53 | } 54 | 55 | httpsServer := &http.Server{ 56 | Addr: "127.0.0.1:0", 57 | TLSConfig: tlsCfg, 58 | } 59 | 60 | httpsRouter := NewHostRouter(httpsServer) 61 | httpsRouter.Scheme = "https" 62 | 63 | httpsReady := make(chan bool) 64 | go httpsRouter.Start(httpsReady) 65 | <-httpsReady 66 | 67 | // assign the test router's addr to the glolbal 68 | s.httpAddr = httpRouter.listener.Addr().String() 69 | s.httpPort = fmt.Sprintf("%d", httpRouter.listener.Addr().(*net.TCPAddr).Port) 70 | s.httpsAddr = httpsRouter.listener.Addr().String() 71 | s.httpsPort = fmt.Sprintf("%d", httpsRouter.listener.Addr().(*net.TCPAddr).Port) 72 | } 73 | 74 | func (s *HTTPSuite) TearDownSuite(c *C) { 75 | s.httpSvr.Close() 76 | httpRouter.Stop() 77 | } 78 | 79 | func (s *HTTPSuite) SetUpTest(c *C) { 80 | // start 4 possible backend servers 81 | for i := 0; i < 4; i++ { 82 | server, err := NewTestServer("127.0.0.1:0", c) 83 | if err != nil { 84 | 
c.Fatal(err) 85 | } 86 | s.servers = append(s.servers, server) 87 | } 88 | 89 | for i := 0; i < 4; i++ { 90 | server, err := NewHTTPTestServer("127.0.0.1:0", c) 91 | if err != nil { 92 | c.Fatal(err) 93 | } 94 | 95 | s.backendServers = append(s.backendServers, server) 96 | } 97 | } 98 | 99 | // shutdown our backend servers 100 | func (s *HTTPSuite) TearDownTest(c *C) { 101 | for _, s := range s.servers { 102 | s.Stop() 103 | } 104 | 105 | s.servers = s.servers[:0] 106 | 107 | // clear global defaults in Registry 108 | Registry.cfg.Balance = "" 109 | Registry.cfg.CheckInterval = 0 110 | Registry.cfg.Fall = 0 111 | Registry.cfg.Rise = 0 112 | Registry.cfg.ClientTimeout = 0 113 | Registry.cfg.ServerTimeout = 0 114 | Registry.cfg.DialTimeout = 0 115 | 116 | for _, s := range s.backendServers { 117 | s.Close() 118 | } 119 | 120 | s.backendServers = s.backendServers[:0] 121 | 122 | for _, svc := range Registry.svcs { 123 | Registry.RemoveService(svc.Name) 124 | } 125 | 126 | } 127 | 128 | // These don't yet *really* test anything other than code coverage 129 | func (s *HTTPSuite) TestAddService(c *C) { 130 | svcDef := bytes.NewReader([]byte(`{"address": "127.0.0.1:9000"}`)) 131 | req, _ := http.NewRequest("PUT", s.httpSvr.URL+"/testService", svcDef) 132 | resp, err := http.DefaultClient.Do(req) 133 | if err != nil { 134 | c.Fatal(err) 135 | } 136 | defer resp.Body.Close() 137 | 138 | body, _ := ioutil.ReadAll(resp.Body) 139 | c.Assert(Registry.String(), DeepEquals, string(body)) 140 | } 141 | 142 | func (s *HTTPSuite) TestAddBackend(c *C) { 143 | svcDef := bytes.NewReader([]byte(`{"address": "127.0.0.1:9000"}`)) 144 | req, _ := http.NewRequest("PUT", s.httpSvr.URL+"/testService", svcDef) 145 | _, err := http.DefaultClient.Do(req) 146 | if err != nil { 147 | c.Fatal(err) 148 | } 149 | 150 | backendDef := bytes.NewReader([]byte(`{"address": "127.0.0.1:9001"}`)) 151 | req, _ = http.NewRequest("PUT", s.httpSvr.URL+"/testService/testBackend", backendDef) 152 | resp, err := 
http.DefaultClient.Do(req) 153 | if err != nil { 154 | c.Fatal(err) 155 | } 156 | defer resp.Body.Close() 157 | 158 | body, _ := ioutil.ReadAll(resp.Body) 159 | c.Assert(Registry.String(), DeepEquals, string(body)) 160 | } 161 | 162 | func (s *HTTPSuite) TestReAddBackend(c *C) { 163 | svcDef := bytes.NewReader([]byte(`{"address": "127.0.0.1:9000"}`)) 164 | req, _ := http.NewRequest("PUT", s.httpSvr.URL+"/testService", svcDef) 165 | _, err := http.DefaultClient.Do(req) 166 | if err != nil { 167 | c.Fatal(err) 168 | } 169 | 170 | backendDef := bytes.NewReader([]byte(`{"address": "127.0.0.1:9001"}`)) 171 | req, _ = http.NewRequest("PUT", s.httpSvr.URL+"/testService/testBackend", backendDef) 172 | firstResp, err := http.DefaultClient.Do(req) 173 | if err != nil { 174 | c.Fatal(err) 175 | } 176 | defer firstResp.Body.Close() 177 | 178 | firstBody, _ := ioutil.ReadAll(firstResp.Body) 179 | 180 | backendDef = bytes.NewReader([]byte(`{"address": "127.0.0.1:9001"}`)) 181 | req, _ = http.NewRequest("PUT", s.httpSvr.URL+"/testService/testBackend", backendDef) 182 | secResp, err := http.DefaultClient.Do(req) 183 | if err != nil { 184 | c.Fatal(err) 185 | } 186 | defer secResp.Body.Close() 187 | 188 | secBody, _ := ioutil.ReadAll(secResp.Body) 189 | 190 | c.Assert(string(secBody), DeepEquals, string(firstBody)) 191 | } 192 | 193 | func (s *HTTPSuite) TestSimulAdd(c *C) { 194 | start := make(chan struct{}) 195 | testWG := new(sync.WaitGroup) 196 | 197 | svcCfg := client.ServiceConfig{ 198 | Name: "TestService", 199 | Addr: "127.0.0.1:9000", 200 | VirtualHosts: []string{"test-vhost"}, 201 | Backends: []client.BackendConfig{ 202 | client.BackendConfig{ 203 | Name: "vhost1", 204 | Addr: "127.0.0.1:9001", 205 | }, 206 | client.BackendConfig{ 207 | Name: "vhost2", 208 | Addr: "127.0.0.1:9002", 209 | }, 210 | }, 211 | } 212 | 213 | for i := 0; i < 8; i++ { 214 | testWG.Add(1) 215 | go func() { 216 | defer testWG.Done() 217 | //wait to start all at once 218 | <-start 219 | svcDef := 
bytes.NewReader(svcCfg.Marshal()) 220 | req, _ := http.NewRequest("PUT", s.httpSvr.URL+"/TestService", svcDef) 221 | resp, err := http.DefaultClient.Do(req) 222 | if err != nil { 223 | c.Fatal(err) 224 | } 225 | 226 | body, _ := ioutil.ReadAll(resp.Body) 227 | 228 | respCfg := client.Config{} 229 | err = json.Unmarshal(body, &respCfg) 230 | 231 | // We're only checking to ensure we have 1 service with the proper number of backends 232 | c.Assert(len(respCfg.Services), Equals, 1) 233 | c.Assert(len(respCfg.Services[0].Backends), Equals, 2) 234 | c.Assert(len(respCfg.Services[0].VirtualHosts), Equals, 1) 235 | }() 236 | } 237 | 238 | close(start) 239 | testWG.Wait() 240 | } 241 | 242 | func (s *HTTPSuite) TestRouter(c *C) { 243 | svcCfg := client.ServiceConfig{ 244 | Name: "VHostTest", 245 | Addr: "127.0.0.1:9000", 246 | VirtualHosts: []string{"test-vhost"}, 247 | } 248 | 249 | for _, srv := range s.backendServers { 250 | cfg := client.BackendConfig{ 251 | Addr: srv.addr, 252 | Name: srv.addr, 253 | } 254 | svcCfg.Backends = append(svcCfg.Backends, cfg) 255 | } 256 | 257 | err := Registry.AddService(svcCfg) 258 | if err != nil { 259 | c.Fatal(err) 260 | } 261 | 262 | for _, srv := range s.backendServers { 263 | checkHTTP("http://"+s.httpAddr+"/addr", "test-vhost", srv.addr, 200, c) 264 | } 265 | } 266 | 267 | func (s *HTTPSuite) TestAddRemoveVHosts(c *C) { 268 | svcCfg := client.ServiceConfig{ 269 | Name: "VHostTest", 270 | Addr: "127.0.0.1:9000", 271 | VirtualHosts: []string{"test-vhost"}, 272 | } 273 | 274 | for _, srv := range s.backendServers { 275 | cfg := client.BackendConfig{ 276 | Addr: srv.addr, 277 | Name: srv.addr, 278 | } 279 | svcCfg.Backends = append(svcCfg.Backends, cfg) 280 | } 281 | 282 | err := Registry.AddService(svcCfg) 283 | if err != nil { 284 | c.Fatal(err) 285 | } 286 | 287 | // now update the service with another vhost 288 | svcCfg.VirtualHosts = append(svcCfg.VirtualHosts, "test-vhost-2") 289 | err = Registry.UpdateService(svcCfg) 290 | if 
err != nil { 291 | c.Fatal(err) 292 | } 293 | 294 | if Registry.VHostsLen() != 2 { 295 | c.Fatal("missing new vhost") 296 | } 297 | 298 | // remove the first vhost 299 | svcCfg.VirtualHosts = []string{"test-vhost-2"} 300 | err = Registry.UpdateService(svcCfg) 301 | if err != nil { 302 | c.Fatal(err) 303 | } 304 | 305 | if Registry.VHostsLen() != 1 { 306 | c.Fatal("extra vhost:", Registry.VHostsLen()) 307 | } 308 | 309 | // check responses from this new vhost 310 | for _, srv := range s.backendServers { 311 | checkHTTP("http://"+s.httpAddr+"/addr", "test-vhost-2", srv.addr, 200, c) 312 | } 313 | } 314 | 315 | // Add multiple services under the same VirtualHost 316 | // Each proxy request should round-robin through the two of them 317 | func (s *HTTPSuite) TestMultiServiceVHost(c *C) { 318 | svcCfgOne := client.ServiceConfig{ 319 | Name: "VHostTest", 320 | Addr: "127.0.0.1:9000", 321 | VirtualHosts: []string{"test-vhost"}, 322 | } 323 | 324 | svcCfgTwo := client.ServiceConfig{ 325 | Name: "VHostTest2", 326 | Addr: "127.0.0.1:9001", 327 | VirtualHosts: []string{"test-vhost-2"}, 328 | } 329 | 330 | var backends []client.BackendConfig 331 | for _, srv := range s.backendServers { 332 | cfg := client.BackendConfig{ 333 | Addr: srv.addr, 334 | Name: srv.addr, 335 | } 336 | backends = append(backends, cfg) 337 | } 338 | 339 | svcCfgOne.Backends = backends 340 | svcCfgTwo.Backends = backends 341 | 342 | err := Registry.AddService(svcCfgOne) 343 | if err != nil { 344 | c.Fatal(err) 345 | } 346 | 347 | err = Registry.AddService(svcCfgTwo) 348 | if err != nil { 349 | c.Fatal(err) 350 | } 351 | 352 | for _, srv := range s.backendServers { 353 | checkHTTP("http://"+s.httpAddr+"/addr", "test-vhost", srv.addr, 200, c) 354 | checkHTTP("http://"+s.httpAddr+"/addr", "test-vhost-2", srv.addr, 200, c) 355 | } 356 | } 357 | 358 | func (s *HTTPSuite) TestAddRemoveBackends(c *C) { 359 | svcCfg := client.ServiceConfig{ 360 | Name: "VHostTest", 361 | Addr: "127.0.0.1:9000", 362 | } 363 | 364 
| err := Registry.AddService(svcCfg) 365 | if err != nil { 366 | c.Fatal(err) 367 | } 368 | 369 | for _, srv := range s.backendServers { 370 | cfg := client.BackendConfig{ 371 | Addr: srv.addr, 372 | Name: srv.addr, 373 | } 374 | svcCfg.Backends = append(svcCfg.Backends, cfg) 375 | } 376 | 377 | err = Registry.UpdateService(svcCfg) 378 | if err != nil { 379 | c.Fatal(err) 380 | } 381 | 382 | cfg := Registry.Config() 383 | if !svcCfg.DeepEqual(cfg.Services[0]) { 384 | c.Errorf("we should have 1 service, we have %d", len(cfg.Services)) 385 | c.Errorf("we should have 4 backends, we have %d", len(cfg.Services[0].Backends)) 386 | } 387 | 388 | svcCfg.Backends = svcCfg.Backends[:3] 389 | err = Registry.UpdateService(svcCfg) 390 | if err != nil { 391 | c.Fatal(err) 392 | } 393 | 394 | cfg = Registry.Config() 395 | if !svcCfg.DeepEqual(cfg.Services[0]) { 396 | c.Errorf("we should have 1 service, we have %d", len(cfg.Services)) 397 | c.Errorf("we should have 3 backends, we have %d", len(cfg.Services[0].Backends)) 398 | } 399 | 400 | } 401 | 402 | func (s *HTTPSuite) TestHTTPAddRemoveBackends(c *C) { 403 | svcCfg := client.ServiceConfig{ 404 | Name: "VHostTest", 405 | Addr: "127.0.0.1:9000", 406 | } 407 | 408 | err := Registry.AddService(svcCfg) 409 | if err != nil { 410 | c.Fatal(err) 411 | } 412 | 413 | for _, srv := range s.backendServers { 414 | cfg := client.BackendConfig{ 415 | Addr: srv.addr, 416 | Name: srv.addr, 417 | } 418 | svcCfg.Backends = append(svcCfg.Backends, cfg) 419 | } 420 | 421 | req, _ := http.NewRequest("PUT", s.httpSvr.URL+"/VHostTest", bytes.NewReader(svcCfg.Marshal())) 422 | _, err = http.DefaultClient.Do(req) 423 | if err != nil { 424 | c.Fatal(err) 425 | } 426 | 427 | cfg := Registry.Config() 428 | if !svcCfg.DeepEqual(cfg.Services[0]) { 429 | c.Errorf("we should have 1 service, we have %d", len(cfg.Services)) 430 | c.Errorf("we should have 4 backends, we have %d", len(cfg.Services[0].Backends)) 431 | } 432 | 433 | // remove a backend from the 
config and submit it again 434 | svcCfg.Backends = svcCfg.Backends[:3] 435 | err = Registry.UpdateService(svcCfg) 436 | if err != nil { 437 | c.Fatal(err) 438 | } 439 | 440 | req, _ = http.NewRequest("PUT", s.httpSvr.URL+"/VHostTest", bytes.NewReader(svcCfg.Marshal())) 441 | _, err = http.DefaultClient.Do(req) 442 | if err != nil { 443 | c.Fatal(err) 444 | } 445 | 446 | // now check the config via what's returned from the http server 447 | resp, err := http.Get(s.httpSvr.URL + "/_config") 448 | if err != nil { 449 | c.Fatal(err) 450 | } 451 | defer resp.Body.Close() 452 | 453 | cfg = client.Config{} 454 | body, _ := ioutil.ReadAll(resp.Body) 455 | err = json.Unmarshal(body, &cfg) 456 | if err != nil { 457 | c.Fatal(err) 458 | } 459 | 460 | if !svcCfg.DeepEqual(cfg.Services[0]) { 461 | c.Errorf("we should have 1 service, we have %d", len(cfg.Services)) 462 | c.Errorf("we should have 3 backends, we have %d", len(cfg.Services[0].Backends)) 463 | } 464 | } 465 | 466 | func (s *HTTPSuite) TestErrorPage(c *C) { 467 | svcCfg := client.ServiceConfig{ 468 | Name: "VHostTest", 469 | Addr: "127.0.0.1:9000", 470 | VirtualHosts: []string{"test-vhost"}, 471 | } 472 | 473 | okServer := s.backendServers[0] 474 | errServer := s.backendServers[1] 475 | 476 | // Add one backend to service requests 477 | cfg := client.BackendConfig{ 478 | Addr: okServer.addr, 479 | Name: okServer.addr, 480 | } 481 | svcCfg.Backends = append(svcCfg.Backends, cfg) 482 | 483 | // use another backend to provide the error page 484 | svcCfg.ErrorPages = map[string][]int{ 485 | "http://" + errServer.addr + "/error": []int{400, 503}, 486 | } 487 | 488 | err := Registry.AddService(svcCfg) 489 | if err != nil { 490 | c.Fatal(err) 491 | } 492 | 493 | // check that the normal response comes from srv1 494 | checkHTTP("http://"+s.httpAddr+"/addr", "test-vhost", okServer.addr, 200, c) 495 | // verify that an unregistered error doesn't give the cached page 496 | checkHTTP("http://"+s.httpAddr+"/error?code=504", 
"test-vhost", okServer.addr, 504, c) 497 | // now see if the registered error comes from srv2 498 | checkHTTP("http://"+s.httpAddr+"/error?code=503", "test-vhost", errServer.addr, 503, c) 499 | 500 | // now check that we got the header cached in the error page as well 501 | req, err := http.NewRequest("GET", "http://"+s.httpAddr+"/error?code=503", nil) 502 | if err != nil { 503 | c.Fatal(err) 504 | } 505 | 506 | req.Host = "test-vhost" 507 | resp, err := http.DefaultClient.Do(req) 508 | if err != nil { 509 | c.Fatal(err) 510 | } 511 | 512 | c.Assert(resp.StatusCode, Equals, 503) 513 | c.Assert(resp.Header.Get("Last-Modified"), Equals, errServer.addr) 514 | } 515 | 516 | func (s *HTTPSuite) TestUpdateServiceDefaults(c *C) { 517 | svcCfg := client.ServiceConfig{ 518 | Name: "TestService", 519 | Addr: "127.0.0.1:9000", 520 | Backends: []client.BackendConfig{ 521 | client.BackendConfig{ 522 | Name: "Backend1", 523 | Addr: "127.0.0.1:9001", 524 | }, 525 | }, 526 | } 527 | 528 | svcDef := bytes.NewBuffer(svcCfg.Marshal()) 529 | req, _ := http.NewRequest("PUT", s.httpSvr.URL+"/TestService", svcDef) 530 | resp, err := http.DefaultClient.Do(req) 531 | if err != nil { 532 | c.Fatal(err) 533 | } 534 | resp.Body.Close() 535 | 536 | // Now update the Service in-place 537 | svcCfg.ServerTimeout = 1234 538 | svcDef.Reset() 539 | svcDef.Write(svcCfg.Marshal()) 540 | 541 | req, _ = http.NewRequest("PUT", s.httpSvr.URL+"/TestService", svcDef) 542 | resp, err = http.DefaultClient.Do(req) 543 | if err != nil { 544 | c.Fatal(err) 545 | } 546 | 547 | body, _ := ioutil.ReadAll(resp.Body) 548 | resp.Body.Close() 549 | 550 | config := client.Config{} 551 | err = json.Unmarshal(body, &config) 552 | if err != nil { 553 | c.Fatal(err) 554 | } 555 | 556 | // make sure we don't see a second value 557 | found := false 558 | 559 | for _, svc := range config.Services { 560 | if svc.Name == "TestService" { 561 | if svc.ServerTimeout != svcCfg.ServerTimeout { 562 | c.Fatal("Service not updated") 563 
| } else if found { 564 | c.Fatal("Multiple Service Definitions") 565 | } 566 | found = true 567 | } 568 | } 569 | } 570 | 571 | // Set some global defaults, and check that a new service inherits them all 572 | func (s *HTTPSuite) TestGlobalDefaults(c *C) { 573 | globalCfg := client.Config{ 574 | Balance: "LC", 575 | CheckInterval: 101, 576 | Fall: 7, 577 | Rise: 8, 578 | ClientTimeout: 102, 579 | ServerTimeout: 103, 580 | DialTimeout: 104, 581 | } 582 | 583 | globalDef := bytes.NewBuffer(globalCfg.Marshal()) 584 | req, _ := http.NewRequest("PUT", s.httpSvr.URL+"/", globalDef) 585 | resp, err := http.DefaultClient.Do(req) 586 | if err != nil { 587 | c.Fatal(err) 588 | } 589 | resp.Body.Close() 590 | 591 | svcCfg := client.ServiceConfig{ 592 | Name: "TestService", 593 | Addr: "127.0.0.1:9000", 594 | } 595 | 596 | svcDef := bytes.NewBuffer(svcCfg.Marshal()) 597 | req, _ = http.NewRequest("PUT", s.httpSvr.URL+"/TestService", svcDef) 598 | resp, err = http.DefaultClient.Do(req) 599 | if err != nil { 600 | c.Fatal(err) 601 | } 602 | resp.Body.Close() 603 | 604 | config := Registry.Config() 605 | 606 | c.Assert(len(config.Services), Equals, 1) 607 | service := config.Services[0] 608 | 609 | c.Assert(globalCfg.Balance, Equals, service.Balance) 610 | c.Assert(globalCfg.CheckInterval, Equals, service.CheckInterval) 611 | c.Assert(globalCfg.Fall, Equals, service.Fall) 612 | c.Assert(globalCfg.Rise, Equals, service.Rise) 613 | c.Assert(globalCfg.ClientTimeout, Equals, service.ClientTimeout) 614 | c.Assert(globalCfg.ServerTimeout, Equals, service.ServerTimeout) 615 | c.Assert(globalCfg.DialTimeout, Equals, service.DialTimeout) 616 | } 617 | 618 | // Test that we can route to Vhosts based on SNI 619 | func (s *HTTPSuite) TestHTTPSRouter(c *C) { 620 | srv1 := s.backendServers[0] 621 | srv2 := s.backendServers[1] 622 | 623 | svcCfgOne := client.ServiceConfig{ 624 | Name: "VHostTest1", 625 | Addr: "127.0.0.1:9000", 626 | VirtualHosts: []string{"vhost1.test", "alt.vhost1.test", 
"star.vhost1.test"}, 627 | Backends: []client.BackendConfig{ 628 | {Addr: srv1.addr}, 629 | }, 630 | } 631 | 632 | svcCfgTwo := client.ServiceConfig{ 633 | Name: "VHostTest2", 634 | Addr: "127.0.0.1:9001", 635 | VirtualHosts: []string{"vhost2.test", "alt.vhost2.test", "star.vhost2.test"}, 636 | Backends: []client.BackendConfig{ 637 | {Addr: srv2.addr}, 638 | }, 639 | } 640 | 641 | err := Registry.AddService(svcCfgOne) 642 | if err != nil { 643 | c.Fatal(err) 644 | } 645 | 646 | err = Registry.AddService(svcCfgTwo) 647 | if err != nil { 648 | c.Fatal(err) 649 | } 650 | 651 | // Our router has 2 certs, each with name.test and alt.name.test as DNS names. 652 | // checkHTTP has a fake dialer that resolves everything to 127.0.0.1. 653 | checkHTTP("https://vhost1.test:"+s.httpsPort+"/addr", "vhost1.test", srv1.addr, 200, c) 654 | checkHTTP("https://alt.vhost1.test:"+s.httpsPort+"/addr", "alt.vhost1.test", srv1.addr, 200, c) 655 | checkHTTP("https://star.vhost1.test:"+s.httpsPort+"/addr", "star.vhost1.test", srv1.addr, 200, c) 656 | 657 | checkHTTP("https://vhost2.test:"+s.httpsPort+"/addr", "vhost2.test", srv2.addr, 200, c) 658 | checkHTTP("https://alt.vhost2.test:"+s.httpsPort+"/addr", "alt.vhost2.test", srv2.addr, 200, c) 659 | checkHTTP("https://star.vhost2.test:"+s.httpsPort+"/addr", "star.vhost2.test", srv2.addr, 200, c) 660 | } 661 | 662 | // Verify that Settting HTTPSRedirect on a service works as expected for https 663 | // and for X-Forwarded-Proto:https. 
664 | func (s *HTTPSuite) TestHTTPSRedirect(c *C) { 665 | srv1 := s.backendServers[0] 666 | 667 | svcCfgOne := client.ServiceConfig{ 668 | Name: "VHostTest1", 669 | Addr: "127.0.0.1:9000", 670 | HTTPSRedirect: true, 671 | VirtualHosts: []string{"vhost1.test", "alt.vhost1.test", "star.vhost1.test"}, 672 | Backends: []client.BackendConfig{ 673 | {Addr: srv1.addr}, 674 | }, 675 | } 676 | 677 | err := Registry.AddService(svcCfgOne) 678 | if err != nil { 679 | c.Fatal(err) 680 | } 681 | 682 | client := &http.Client{ 683 | Transport: &http.Transport{ 684 | Dial: localDial, 685 | }, 686 | CheckRedirect: func(req *http.Request, via []*http.Request) error { 687 | return fmt.Errorf("redirected") 688 | }, 689 | } 690 | 691 | // this should redirect to https 692 | reqHTTP, _ := http.NewRequest("HEAD", "http://localhost:"+s.httpPort+"/addr", nil) 693 | reqHTTP.Host = "vhost1.test" 694 | 695 | resp, err := client.Do(reqHTTP) 696 | if err != nil { 697 | if err, ok := err.(*url.Error); ok { 698 | if err.Err.Error() != "redirected" { 699 | c.Fatal(err) 700 | } 701 | } else { 702 | c.Fatal(err) 703 | } 704 | } 705 | c.Assert(resp.StatusCode, Equals, http.StatusMovedPermanently) 706 | 707 | // this should be OK 708 | reqHTTP.Header = map[string][]string{ 709 | "X-Forwarded-Proto": {"https"}, 710 | } 711 | resp, err = client.Do(reqHTTP) 712 | if err != nil { 713 | c.Fatal(err) 714 | } 715 | c.Assert(resp.StatusCode, Equals, http.StatusOK) 716 | } 717 | 718 | func (s *HTTPSuite) TestMaintenanceMode(c *C) { 719 | mainServer := s.backendServers[0] 720 | errServer := s.backendServers[1] 721 | 722 | svcCfg := client.ServiceConfig{ 723 | Name: "VHostTest1", 724 | Addr: "127.0.0.1:9000", 725 | VirtualHosts: []string{"vhost1.test"}, 726 | Backends: []client.BackendConfig{ 727 | {Addr: mainServer.addr}, 728 | }, 729 | MaintenanceMode: true, 730 | } 731 | 732 | if err := Registry.AddService(svcCfg); err != nil { 733 | c.Fatal(err) 734 | } 735 | 736 | // No error page is registered, so we should 
just get a 503 error with no body 737 | checkHTTP("https://vhost1.test:"+s.httpsPort+"/addr", "vhost1.test", "", 503, c) 738 | 739 | // Use another backend to provide the error page 740 | svcCfg.ErrorPages = map[string][]int{ 741 | "http://" + errServer.addr + "/error?code=503": []int{503}, 742 | } 743 | 744 | if err := Registry.UpdateService(svcCfg); err != nil { 745 | c.Fatal(err) 746 | } 747 | 748 | // Get a 503 error with the cached body 749 | checkHTTP("https://vhost1.test:"+s.httpsPort+"/addr", "vhost1.test", errServer.addr, 503, c) 750 | 751 | // Turn maintenance mode off 752 | svcCfg.MaintenanceMode = false 753 | 754 | if err := Registry.UpdateService(svcCfg); err != nil { 755 | c.Fatal(err) 756 | } 757 | 758 | checkHTTP("https://vhost1.test:"+s.httpsPort+"/addr", "vhost1.test", mainServer.addr, 200, c) 759 | 760 | // Turn it back on 761 | svcCfg.MaintenanceMode = true 762 | 763 | if err := Registry.UpdateService(svcCfg); err != nil { 764 | c.Fatal(err) 765 | } 766 | 767 | checkHTTP("https://vhost1.test:"+s.httpsPort+"/addr", "vhost1.test", errServer.addr, 503, c) 768 | } 769 | --------------------------------------------------------------------------------