├── .travis.yml ├── README.md ├── common ├── limits.go ├── limits_test.go ├── priority_queue.go └── priority_queue_test.go ├── config.yaml ├── config_test.yaml ├── golang-load-balancer ├── main.go └── main_test.go /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - 1.8.x 5 | - 1.9.x 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A simple Golang HTTP load balancer 2 | 3 | [![Build Status](https://travis-ci.org/intrip/golang-load-balancer.svg?branch=http)](https://travis-ci.org/intrip/golang-load-balancer) 4 | 5 | A very simple HTTP load balancer written in Golang. 6 | 7 | ## Usage 8 | 9 | ``` 10 | $ go build 11 | $ ./golang-load-balancer 12 | ``` 13 | 14 | connect to http://localhost:8080 (default settings) 15 | 16 | ## Configuration 17 | 18 | Edit `config.yaml` to customize your settings. 
// ---- common/limits.go (package common; imports: net/http) ----

// limitHandler wraps another http.Handler and caps how many requests may be
// inside that handler at the same time.
type limitHandler struct {
	// connc holds one token per free slot. A request takes a token before it
	// is forwarded and puts it back once the wrapped handler is done.
	connc chan struct{}
	// handler is the wrapped handler that does the real work.
	handler http.Handler
}

// ServeHTTP forwards req to the wrapped handler when a slot is free and
// answers "503 too busy" right away otherwise. It never waits for a slot.
func (h *limitHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	select {
	case <-h.connc:
		// Return the token in a defer so it is handed back even when the
		// wrapped handler panics. net/http recovers handler panics and keeps
		// serving, so a token lost here would shrink the limit for good.
		defer func() { h.connc <- struct{}{} }()
		h.handler.ServeHTTP(w, req)
	default:
		http.Error(w, "503 too busy", http.StatusServiceUnavailable)
	}
}

// NewLimitHandler returns a handler that lets at most maxConns requests run
// inside handler concurrently and rejects the rest with status 503.
//
// A maxConns of 0 produces a handler that rejects every request. A negative
// maxConns makes the channel allocation panic.
func NewLimitHandler(maxConns int, handler http.Handler) http.Handler {
	h := &limitHandler{
		connc:   make(chan struct{}, maxConns),
		handler: handler,
	}
	// Pre-fill the channel: every token stands for one free slot.
	for i := 0; i < maxConns; i++ {
		h.connc <- struct{}{}
	}
	return h
}

// ---- common/limits_test.go (package common; imports: fmt, net/http, sync, testing, time) ----

// FakeProxy is a stand-in backend that answers slowly enough for concurrent
// requests to overlap.
type FakeProxy struct{}

// ServeHTTP sleeps for 100ms and then writes "hello".
func (h *FakeProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	time.Sleep(100 * time.Millisecond)
	fmt.Fprintf(w, "hello")
}

// TestMaxConnections fires maxConnections+1 overlapping requests at a limited
// server and expects at least one of them to be rejected with 503.
func TestMaxConnections(t *testing.T) {
	maxConnections := 2
	addr := "localhost:8088"

	srv := &http.Server{
		Addr:    addr,
		Handler: NewLimitHandler(maxConnections, &FakeProxy{}),
	}
	// Stop the listener when the test ends instead of leaking it.
	defer srv.Close()
	go func() {
		// After the deferred Close, ListenAndServe reports ErrServerClosed;
		// anything else (e.g. the port being taken) is a real failure.
		if err := srv.ListenAndServe(); err != http.ErrServerClosed {
			t.Errorf("test server stopped unexpectedly: %v", err)
		}
	}()
	// Give the listener a moment to come up.
	time.Sleep(100 * time.Millisecond)

	var (
		wg         sync.WaitGroup
		mu         sync.Mutex // guards reachedMax, written by several goroutines
		reachedMax bool
	)
	for i := 0; i <= maxConnections; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			res, err := http.Get(fmt.Sprintf("http://%s/", addr))
			if err != nil {
				// res is nil here; report instead of dereferencing it.
				t.Errorf("request failed: %v", err)
				return
			}
			defer res.Body.Close()
			if res.StatusCode == http.StatusServiceUnavailable {
				mu.Lock()
				reachedMax = true
				mu.Unlock()
			}
		}()
	}
	wg.Wait()

	if !reachedMax {
		t.Errorf("Expected to reach maxConnections but did not")
	}
}

// ---- common/priority_queue.go (package common; no imports) ----

// PriorityQueuer is implemented by anything that can pick the backend the
// next request should go to.
type PriorityQueuer interface {
	Next() Backend
}

// Backend describes one upstream server.
type Backend struct {
	// Url is the backend address as written in the configuration.
	Url string
	// ActiveConnections is carried along with the backend; Next does not
	// consult it.
	ActiveConnections int
}

// RoundRobin hands out its Backends one after another, starting over at the
// first one after the last.
type RoundRobin struct {
	// CurrentIndex is the position of the backend the next call returns.
	CurrentIndex int
	Backends     []Backend
}

// RoundRobin satisfies PriorityQueuer through its Next method.
var _ PriorityQueuer = (*RoundRobin)(nil)

// Next returns the current backend and advances to the following one. It is
// the method form of the package-level Next function.
func (b *RoundRobin) Next() Backend {
	return Next(b)
}

// Next returns the backend at b.CurrentIndex and moves the index one step
// forward, wrapping around at the end of b.Backends.
//
// With no backends configured it returns the zero Backend instead of
// panicking. An index that has drifted outside the slice is wrapped back
// into range first.
//
// Next mutates b without any locking, so concurrent callers must provide
// their own synchronisation.
func Next(b *RoundRobin) Backend {
	n := len(b.Backends)
	if n == 0 {
		return Backend{}
	}

	i := b.CurrentIndex % n
	if i < 0 {
		// Go's % keeps the sign of the dividend.
		i += n
	}

	res := b.Backends[i]
	b.CurrentIndex = (i + 1) % n

	return res
}

// ---- common/priority_queue_test.go (package common; imports: testing) ----

// TestNext checks that two backends are handed out in order and that the
// third call wraps around to the first one.
func TestNext(t *testing.T) {
	backendA := Backend{Url: "http://localhost:8081", ActiveConnections: 0}
	backendB := Backend{Url: "http://localhost:8082", ActiveConnections: 0}
	backends := RoundRobin{0, []Backend{backendA, backendB}}

	firstBackend := Next(&backends)
	secondBackend := Next(&backends)
	thirdBackend := Next(&backends)

	if firstBackend != backendA {
		t.Errorf("Wrong order of Next(), expected %v, got: %v", backendA, firstBackend)
	}

	if secondBackend != backendB {
		t.Errorf("Wrong order of Next(), expected %v, got: %v", backendB, secondBackend)
	}

	if thirdBackend != backendA {
		t.Errorf("Wrong order of Next(), expected %v, got: %v", backendA, thirdBackend)
	}
}
bind: '0.0.0.0' 4 | port: '8080' 5 | maxconnections: '100' 6 | # in seconds 7 | readTimeout: '30' 8 | # in seconds 9 | writeTimeout: '30' 10 | # The list of server you want to balance on separated by comma 11 | balancers: 'http://0.0.0.0:3003,http://0.0.0.0:3004' 12 | -------------------------------------------------------------------------------- /config_test.yaml: -------------------------------------------------------------------------------- 1 | # The load balancer server configuration 2 | server: 3 | bind: 'localhost' 4 | port: '8080' 5 | maxconnections: '100' 6 | readTimeout: '30' 7 | writeTimeout: '30' 8 | # The list of server you want to balance on 9 | balancers: 'http://localhost:8081' 10 | -------------------------------------------------------------------------------- /golang-load-balancer: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intrip/golang-load-balancer/8d9495966af19053103ca60b18b8f17852836223/golang-load-balancer -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "github.com/intrip/golang-load-balancer/common" 7 | "github.com/spf13/viper" 8 | "log" 9 | "net/http" 10 | "net/http/httputil" 11 | "net/url" 12 | "strconv" 13 | "strings" 14 | "time" 15 | ) 16 | 17 | var ( 18 | bind, balance string 19 | port, maxConnections int 20 | readTimeout = 10 21 | writeTimeout = 10 22 | backends []common.Backend 23 | testEnv bool 24 | ) 25 | 26 | func init() { 27 | loadConfig("config") 28 | 29 | if flag.Lookup("test.v") == nil { 30 | testEnv = false 31 | } else { 32 | testEnv = true 33 | } 34 | } 35 | 36 | // loads config from ./config.yaml 37 | func loadConfig(config string) { 38 | viper.SetConfigType("yaml") 39 | viper.SetConfigName(config) 40 | viper.AddConfigPath(".") 41 | err := viper.ReadInConfig() 42 | if err 
!= nil { 43 | panic(fmt.Errorf("Error in config file: %s \n", err)) 44 | } 45 | 46 | server := viper.GetStringMapString("server") 47 | // read port 48 | if v, ok := server["port"]; ok { 49 | port, err = strconv.Atoi(v) 50 | if err != nil { 51 | panic(fmt.Errorf("Server port is not valid: %s \n", err)) 52 | } 53 | } else { 54 | panic(fmt.Errorf("Server port is required")) 55 | } 56 | // listen 57 | if v, ok := server["bind"]; ok { 58 | bind = v 59 | } else { 60 | panic(fmt.Errorf("Server bind is required")) 61 | } 62 | // maxConnections 63 | if v, ok := server["maxconnections"]; ok { 64 | maxConnections, err = strconv.Atoi(v) 65 | if err != nil { 66 | panic(fmt.Errorf("Server maxConnections is not valid: %s \n", err)) 67 | } 68 | } else { 69 | panic(fmt.Errorf("Server maxConnections is required")) 70 | } 71 | 72 | // timeout 73 | if v, ok := server["readtimeout"]; ok { 74 | readTimeout, err = strconv.Atoi(v) 75 | if err != nil { 76 | panic(fmt.Errorf("server readtimeout is not valid: %s \n", err)) 77 | } 78 | } 79 | if v, ok := server["writetimeout"]; ok { 80 | writeTimeout, err = strconv.Atoi(v) 81 | if err != nil { 82 | panic(fmt.Errorf("server writetimeout is not valid: %s \n", err)) 83 | } 84 | } 85 | balance = viper.GetString("balancers") 86 | backends = parseBalance(balance) 87 | } 88 | 89 | func main() { 90 | s := &http.Server{ 91 | Addr: serverUrl(), 92 | Handler: common.NewLimitHandler(maxConnections, &Proxy{&common.RoundRobin{0, backends}}), 93 | ReadTimeout: time.Duration(readTimeout) * time.Second, 94 | WriteTimeout: time.Duration(writeTimeout) * time.Second, 95 | MaxHeaderBytes: 1 << 20, 96 | } 97 | 98 | s.ListenAndServe() 99 | } 100 | 101 | func serverUrl() string { 102 | return fmt.Sprintf("%s:%d", bind, port) 103 | } 104 | 105 | type Proxy struct{ backendStruct *common.RoundRobin } 106 | 107 | func (h *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { 108 | next := common.Next(h.backendStruct) 109 | doBalance(w, r, &next) 110 | } 111 | 112 | 
func doBalance(w http.ResponseWriter, r *http.Request, backend *common.Backend) { 113 | u, err := url.Parse(backend.Url) 114 | if err != nil { 115 | log.Panic("Error parsing backend Url: ", err) 116 | } 117 | 118 | if !testEnv { 119 | log.Printf("Request from: %s forwarded to: %s path: %s", r.RemoteAddr, backend.Url, r.RequestURI) 120 | } 121 | 122 | proxy := httputil.NewSingleHostReverseProxy(u) 123 | proxy.ServeHTTP(w, r) 124 | } 125 | 126 | func parseBalance(balancers string) (backends []common.Backend) { 127 | urls := strings.Split(balancers, ",") 128 | backends = make([]common.Backend, len(urls)) 129 | 130 | for index, backend := range urls { 131 | backends[index] = common.Backend{Url: backend, ActiveConnections: 0} 132 | } 133 | 134 | return 135 | } 136 | -------------------------------------------------------------------------------- /main_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "github.com/intrip/golang-load-balancer/common" 6 | "io/ioutil" 7 | "log" 8 | "net" 9 | "net/http" 10 | "testing" 11 | "time" 12 | ) 13 | 14 | func TestParseBalance(t *testing.T) { 15 | balance := "http://localhost:3000/home,http://localhost:3001/info" 16 | expectedBackends := []common.Backend{common.Backend{"http://localhost:3000/home", 0}, common.Backend{"http://localhost:3001/info", 0}} 17 | 18 | backends := parseBalance(balance) 19 | 20 | if len(backends) != len(expectedBackends) { 21 | t.Errorf("Result size differ, expected: %d, got: %d", len(expectedBackends), len(backends)) 22 | return 23 | } 24 | 25 | for index, backend := range backends { 26 | if backend != expectedBackends[index] { 27 | t.Errorf("Backend %d differ, expected: %q got %q", index, expectedBackends[index], backend) 28 | } 29 | } 30 | } 31 | 32 | func TestLoadConfig(t *testing.T) { 33 | loadConfig("config_test") 34 | expectedPort := 8080 35 | expectedBind := "localhost" 36 | expectedMaxConnections := 100 37 | 
expectedReadTimeout := 30 38 | expectedWriteTimeout := 30 39 | expectedBalance := "http://localhost:8081" 40 | 41 | if expectedPort != port { 42 | t.Errorf("Port differ, expected %d got %d", expectedPort, port) 43 | } 44 | if expectedBind != bind { 45 | t.Errorf("Bind differ, expected %d got %d", expectedBind, bind) 46 | } 47 | if expectedMaxConnections != maxConnections { 48 | t.Errorf("MaxConnections differ, expected %d got %d", expectedMaxConnections, maxConnections) 49 | } 50 | if expectedReadTimeout != readTimeout { 51 | t.Errorf("readTimeout differ, expected %d got %d", expectedReadTimeout, readTimeout) 52 | } 53 | if expectedWriteTimeout != writeTimeout { 54 | t.Errorf("WriteTimeout differ, expected %d got %d", expectedWriteTimeout, writeTimeout) 55 | } 56 | if expectedBalance != balance { 57 | t.Errorf("Balance differ, expected %d got %d", expectedBalance, balance) 58 | } 59 | } 60 | 61 | func TestDoBalance(t *testing.T) { 62 | msg := "Hello world!" 63 | 64 | // listen backend 65 | beListen := "localhost:8081" 66 | beRemoteAddr := "" 67 | beServeMux := http.NewServeMux() 68 | beServeMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 69 | if r.Header["X-Forwarded-For"][0] != beRemoteAddr { 70 | t.Errorf("Expected X-Forwarded-For: %s, got: %s", r.Header["X-Forwarded-For"][0], beRemoteAddr) 71 | } 72 | 73 | // send msg to the caller 74 | fmt.Fprintf(w, msg) 75 | }) 76 | beServer := &http.Server{ 77 | Addr: beListen, 78 | Handler: beServeMux, 79 | } 80 | 81 | // listen balancer 82 | serveMux := http.NewServeMux() 83 | serveMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 84 | beRemoteAddr, _, _ = net.SplitHostPort(r.RemoteAddr) 85 | doBalance(w, r, &common.Backend{Url: fmt.Sprintf("http://%s", beListen), ActiveConnections: 0}) 86 | }) 87 | server := &http.Server{ 88 | Addr: serverUrl(), 89 | Handler: serveMux, 90 | } 91 | 92 | // backend 93 | go func() { 94 | beServer.ListenAndServe() 95 | }() 96 | // balancer 97 | go func() { 
98 | server.ListenAndServe() 99 | }() 100 | 101 | time.Sleep(time.Duration(100) * time.Millisecond) 102 | res, err := http.Get(fmt.Sprintf("http://%s/", serverUrl())) 103 | if err != nil { 104 | log.Panic("[test] Error connecting to balancer: ", err) 105 | } 106 | bodyBytes, _ := ioutil.ReadAll(res.Body) 107 | 108 | if string(bodyBytes) != msg { 109 | t.Errorf("Expected to read %s, got: %s", msg, bodyBytes) 110 | } 111 | 112 | defer beServer.Close() 113 | defer server.Close() 114 | } 115 | 116 | func TestBackendUnavailable(t *testing.T) { 117 | // backend 118 | serveMux := http.NewServeMux() 119 | serveMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 120 | doBalance(w, r, &common.Backend{Url: "http://localhost:9999"}) 121 | }) 122 | server := &http.Server{ 123 | Addr: serverUrl(), 124 | Handler: serveMux, 125 | } 126 | // balancer 127 | go func() { 128 | server.ListenAndServe() 129 | }() 130 | 131 | time.Sleep(time.Duration(100) * time.Millisecond) 132 | res, _ := http.Get(fmt.Sprintf("http://%s/", serverUrl())) 133 | 134 | if res.StatusCode != 502 { 135 | t.Errorf("Expected status code 502, got: %d", res.StatusCode) 136 | } 137 | 138 | defer server.Close() 139 | } 140 | --------------------------------------------------------------------------------