├── go.mod ├── .travis.yml ├── .gitignore ├── errors.go ├── balancer.go ├── client.go ├── const.go ├── CONTRIBUTING.md ├── LICENSE ├── doc.go ├── README.md ├── connection_test.go ├── transport_test.go ├── roundrobin ├── balancer.go └── balancer_test.go ├── connection.go └── transport.go /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/olivere/balancers 2 | 3 | go 1.12 4 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: false 2 | language: go 3 | env: 4 | - GO111MODULE=on 5 | - GO111MODULE=off 6 | go: 7 | - "1.11.x" 8 | - "1.12.x" 9 | install: 10 | - go get ./... 11 | script: 12 | - go test -v -race -run=. -bench=. ./... 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | -------------------------------------------------------------------------------- /errors.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2014-2015 Oliver Eilhard. All rights reserved. 2 | // Use of this source code is governed by the MIT license. 3 | // See LICENSE file for details. 4 | package balancers 5 | 6 | import ( 7 | "errors" 8 | ) 9 | 10 | // ErrNoConn must be returned when a Balancer does not find a (non-broken) connection. 
11 | var ErrNoConn = errors.New("no connection") 12 | -------------------------------------------------------------------------------- /balancer.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2014-2015 Oliver Eilhard. All rights reserved. 2 | // Use of this source code is governed by the MIT license. 3 | // See LICENSE file for details. 4 | package balancers 5 | 6 | // Balancer holds a list of connections to hosts. 7 | type Balancer interface { 8 | // Get returns a connection that can be used for the next request. 9 | Get() (Connection, error) 10 | 11 | // Connections is the list of available connections. 12 | Connections() []Connection 13 | } 14 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2014-2015 Oliver Eilhard. All rights reserved. 2 | // Use of this source code is governed by the MIT license. 3 | // See LICENSE file for details. 4 | package balancers 5 | 6 | import ( 7 | "net/http" 8 | ) 9 | 10 | // NewClient returns a http Client that applies a certain scheduling algorithm 11 | // (like round-robin) to load balance between several HTTP servers. 12 | func NewClient(b Balancer) *http.Client { 13 | return &http.Client{ 14 | Transport: &Transport{balancer: b}, 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /const.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2014-2015 Oliver Eilhard. All rights reserved. 2 | // Use of this source code is governed by the MIT license. 3 | // See LICENSE file for details. 4 | package balancers 5 | 6 | import ( 7 | "runtime" 8 | "time" 9 | ) 10 | 11 | const ( 12 | // Version is the current version of this package. 13 | Version = "1.0.0" 14 | ) 15 | 16 | var ( 17 | // UserAgent is sent with all heartbeat requests. 
18 | UserAgent = "balancers/" + Version + " (" + runtime.GOOS + "-" + runtime.GOARCH + ")" 19 | 20 | // DefaultHeartbeatDuration is the default time between heartbeat messages. 21 | DefaultHeartbeatDuration = 30 * time.Second 22 | ) 23 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | 3 | Balancers is an open-source project and we are looking forward to each 4 | contribution. 5 | 6 | ## Your Pull Request 7 | 8 | To make it easy to review and understand your changes, please keep the 9 | following things in mind before submitting your pull request: 10 | 11 | * Work on the latest possible state of `olivere/balancers`. 12 | * Create a branch dedicated to your change. 13 | * If possible, write a test case which confirms your change. 14 | * Make sure your changes and your tests work. 15 | * Test your changes before creating a pull request (`go test ./...`). 16 | * Don't mix several features or bug fixes in one pull request. 17 | * Create a meaningful commit message. 18 | * Explain your change, e.g. provide a link to the issue you are fixing. 19 | * Format your source with `go fmt`. 
20 | 21 | ## Additional Resources 22 | 23 | * [GitHub documentation](http://help.github.com/) 24 | * [GitHub pull request documentation](http://help.github.com/send-pull-requests/) 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | Copyright © 2014-2015 Oliver Eilhard 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy 5 | of this software and associated documentation files (the “Software”), to deal 6 | in the Software without restriction, including without limitation the rights 7 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | copies of the Software, and to permit persons to whom the Software is 9 | furnished to do so, subject to the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be included 12 | in all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 19 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 20 | IN THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2014-2015 Oliver Eilhard. All rights reserved. 2 | // Use of this source code is governed by the MIT license. 3 | // See LICENSE file for details. 4 | 5 | /* 6 | Package balancers provides implementations of HTTP load-balancers. 
7 | 8 | It has two key interfaces: a Balancer implements a load-balancing algorithm 9 | that chooses from a set of Connections, and a Connection represents a host, 10 | identified by its URL, that may currently be broken. 11 | 12 | For example, you can use the balancer from the roundrobin package to 13 | rewrite HTTP requests so that they are sent to one of a given set of HTTP 14 | connections. 15 | 16 | Suppose you have a cluster of two servers (on two different URLs) and you 17 | want to load-balance between the two in a round-robin fashion. Then you can 18 | use code like this: 19 | 20 | balancer, err := roundrobin.NewBalancerFromURL("https://server1.com", "https://server2.com") 21 | ... 22 | // Get an HTTP client for the round-robin balancer. 23 | client := balancers.NewClient(balancer) 24 | ... 25 | client.Get("http://example.com/path1?foo=bar") // will rewrite URL to https://server1.com/path1?foo=bar 26 | client.Get("http://example.com/path1?foo=bar") // will rewrite URL to https://server2.com/path1?foo=bar 27 | client.Get("http://example.com/path1?foo=bar") // will rewrite URL to https://server1.com/path1?foo=bar 28 | client.Get("/path1?foo=bar") // will rewrite URL to https://server2.com/path1?foo=bar 29 | */ 30 | package balancers 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Balancers 2 | 3 | Balancers provides implementations of HTTP load-balancers. 4 | 5 | [![Build Status](https://travis-ci.org/olivere/balancers.svg?branch=master)](https://travis-ci.org/olivere/balancers) 6 | [![Godoc](http://img.shields.io/badge/godoc-reference-blue.svg?style=flat)](https://godoc.org/github.com/olivere/balancers) 7 | [![license](http://img.shields.io/badge/license-MIT-red.svg?style=flat)](https://raw.githubusercontent.com/olivere/balancers/master/LICENSE) 8 | 9 | ## What does it do?
10 | 11 | Balancers gives you an `http.Client` from [net/http](http://golang.org/pkg/net/http) 12 | that rewrites your requests' scheme, host, and userinfo according to the 13 | rules of a balancer. A balancer is simply an algorithm that picks the host 14 | for the next request made by an `http.Client`. 15 | 16 | ## How does it work? 17 | 18 | Suppose you have a cluster of two servers (on two different URLs) and you 19 | want to load balance between them. A very simple implementation can be done 20 | with the [round-robin scheduling algorithm](http://en.wikipedia.org/wiki/Round-robin_scheduling). 21 | Round-robin iterates through the list of available hosts, skipping hosts that are currently broken, 22 | and starts over at the first one when the end is reached. Here's some code that illustrates that: 23 | 24 | ```go 25 | // Get a balancer that performs round-robin scheduling between two servers. 26 | balancer, err := roundrobin.NewBalancerFromURL("https://server1.com", "https://server2.com") 27 | 28 | // Get an HTTP client based on that balancer. 29 | client := balancers.NewClient(balancer) 30 | 31 | // Now request some data. The scheme, host, and user info will be rewritten 32 | // by the balancer; you'll never get data from http://example.com, only data 33 | // from https://server1.com or https://server2.com. 34 | client.Get("http://example.com/path1?foo=bar") // rewritten to https://server1.com/path1?foo=bar 35 | client.Get("http://example.com/path1?foo=bar") // rewritten to https://server2.com/path1?foo=bar 36 | client.Get("http://example.com/path1?foo=bar") // rewritten to https://server1.com/path1?foo=bar 37 | client.Get("/path1?foo=bar") // rewritten to https://server2.com/path1?foo=bar 38 | ``` 39 | 40 | ## Status 41 | 42 | The current state of Balancers is a proof-of-concept. 43 | It has not been used in production systems yet. 44 | 45 | ## Credits 46 | 47 | Thanks a lot to the great folks working on [Go](http://www.golang.org/). 48 | 49 | ## LICENSE 50 | 51 | MIT-LICENSE.
See [LICENSE](http://olivere.mit-license.org/) 52 | or the LICENSE file provided in the repository for details. 53 | -------------------------------------------------------------------------------- /connection_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2014-2015 Oliver Eilhard. All rights reserved. 2 | // Use of this source code is governed by the MIT license. 3 | // See LICENSE file for details. 4 | package balancers 5 | 6 | import ( 7 | "net/http" 8 | "net/http/httptest" 9 | "net/url" 10 | "testing" 11 | "time" 12 | ) 13 | 14 | func TestHttpConnection(t *testing.T) { 15 | var visited bool 16 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 17 | visited = true 18 | })) 19 | defer server.Close() 20 | 21 | url, _ := url.Parse(server.URL) 22 | conn := NewHttpConnection(url) 23 | if conn == nil { 24 | t.Fatal("expected connection") 25 | } 26 | if conn.URL() != url { 27 | t.Errorf("expected URL %v; got: %v", url, conn.URL()) 28 | } 29 | broken := conn.IsBroken() 30 | if broken { 31 | t.Error("expected connection to not be broken") 32 | } 33 | if !visited { 34 | t.Error("expected server to be pinged") 35 | } 36 | } 37 | 38 | func TestHttpConnectionReturningInternalServerErrorIsBroken(t *testing.T) { 39 | var visited bool 40 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 41 | visited = true 42 | w.WriteHeader(http.StatusInternalServerError) 43 | })) 44 | defer server.Close() 45 | 46 | url, _ := url.Parse(server.URL) 47 | conn := NewHttpConnection(url) 48 | if conn == nil { 49 | t.Fatal("expected connection") 50 | } 51 | if conn.URL() != url { 52 | t.Errorf("expected URL %v; got: %v", url, conn.URL()) 53 | } 54 | broken := conn.IsBroken() 55 | if !broken { 56 | t.Error("expected connection to be broken") 57 | } 58 | if !visited { 59 | t.Error("expected server to be pinged") 60 | } 61 | } 62 | 63 | func 
TestHttpConnectionToNonexistentServer(t *testing.T) { 64 | url, _ := url.Parse("http://localhost:12345") 65 | conn := NewHttpConnection(url) 66 | if conn == nil { 67 | t.Fatal("expected connection") 68 | } 69 | if conn.URL() != url { 70 | t.Errorf("expected URL %v; got: %v", url, conn.URL()) 71 | } 72 | broken := conn.IsBroken() 73 | if !broken { 74 | t.Error("expected connection to be broken") 75 | } 76 | } 77 | 78 | func TestHttpConnectionHeartbeat(t *testing.T) { 79 | var count int 80 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 81 | count += 1 82 | })) 83 | defer server.Close() 84 | 85 | url, _ := url.Parse(server.URL) 86 | conn := NewHttpConnection(url).HeartbeatDuration(2 * time.Second) 87 | if conn == nil { 88 | t.Fatal("expected connection") 89 | } 90 | if conn.URL() != url { 91 | t.Errorf("expected URL %v; got: %v", url, conn.URL()) 92 | } 93 | time.Sleep(3 * time.Second) 94 | err := conn.Close() 95 | if err != nil { 96 | t.Fatal(err) 97 | } 98 | if count != 2 { // 1 on NewConnection + 1 for a heartbeat 99 | t.Errorf("expected %d heartbeats; got: %d", 2, count) 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /transport_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2014-2015 Oliver Eilhard. All rights reserved. 2 | // Use of this source code is governed by the MIT license. 3 | // See LICENSE file for details. 
4 | package balancers 5 | 6 | import ( 7 | "net/http" 8 | "net/http/httptest" 9 | "net/url" 10 | "testing" 11 | ) 12 | 13 | func TestCloseRequest(t *testing.T) { 14 | orig, _ := http.NewRequest("GET", "https://user:pwd@localhost:12345/path?query=1#hash", nil) 15 | dup := cloneRequest(orig) 16 | if dup.URL.User != orig.URL.User { 17 | t.Errorf("expected userinfo %v; got: %v", orig.URL.User, dup.URL.User) 18 | } 19 | if dup.URL.Scheme != "https" { 20 | t.Errorf("expected scheme %q; got: %q", "https", dup.URL.Scheme) 21 | } 22 | if dup.URL.Host != "localhost:12345" { 23 | t.Errorf("expected host %q; got: %q", "localhost:12345", dup.URL.Host) 24 | } 25 | if dup.URL.Path != "/path" { 26 | t.Errorf("expected path %q; got: %q", "/path", dup.URL.Path) 27 | } 28 | if dup.URL.RawQuery != "query=1" { 29 | t.Errorf("expected raw query %q; got: %q", "query=1", dup.URL.RawQuery) 30 | } 31 | if dup.URL.Fragment != "hash" { 32 | t.Errorf("expected fragment %q; got: %q", "hash", dup.URL.Fragment) 33 | } 34 | } 35 | 36 | func TestModifyRequest(t *testing.T) { 37 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 38 | w.WriteHeader(http.StatusOK) 39 | })) 40 | defer server.Close() 41 | 42 | tests := []struct { 43 | Req string 44 | ConnURL string 45 | Expected error 46 | }{ 47 | { 48 | "http://localhost:12345/path?query=1#hash", 49 | server.URL, 50 | nil, 51 | }, 52 | } 53 | 54 | for _, test := range tests { 55 | orig, err := url.Parse(test.Req) 56 | if err != nil { 57 | t.Fatal(err) 58 | } 59 | 60 | req, err := http.NewRequest("GET", test.Req, nil) 61 | if err != nil { 62 | t.Fatal(err) 63 | } 64 | 65 | url, err := url.Parse(test.ConnURL) 66 | if err != nil { 67 | t.Fatal(err) 68 | } 69 | conn := NewHttpConnection(url) 70 | 71 | err = modifyRequest(req, conn) 72 | if err != test.Expected { 73 | t.Errorf("expected err = %v; got: %v", test.Expected, err) 74 | } else { 75 | if url.Scheme != "" && req.URL.Scheme != url.Scheme { 76 | 
t.Errorf("expected scheme %q; got: %q", url.Scheme, req.URL.Scheme) 77 | } 78 | if url.Host != "" && req.URL.Host != url.Host { 79 | t.Errorf("expected host %q; got: %q", url.Host, req.URL.Host) 80 | } 81 | if url.User != nil && req.URL.User != url.User { 82 | t.Errorf("expected userinfo %v; got: %v", url.User, req.URL.User) 83 | } 84 | if req.URL.Path != orig.Path { 85 | t.Errorf("expected path %q; got: %q", orig.Path, req.URL.Path) 86 | } 87 | if req.URL.RawQuery != orig.RawQuery { 88 | t.Errorf("expected raw query %q; got: %q", orig.RawQuery, req.URL.RawQuery) 89 | } 90 | if req.URL.Fragment != orig.Fragment { 91 | t.Errorf("expected fragment %q; got: %q", orig.Fragment, req.URL.Fragment) 92 | } 93 | } 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /roundrobin/balancer.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2014-2015 Oliver Eilhard. All rights reserved. 2 | // Use of this source code is governed by the MIT license. 3 | // See LICENSE file for details. 4 | package roundrobin 5 | 6 | import ( 7 | "net/url" 8 | "sync" 9 | 10 | "github.com/olivere/balancers" 11 | ) 12 | 13 | // Balancer implements a round-robin balancer. 14 | type Balancer struct { 15 | sync.Mutex // guards the following variables 16 | conns []balancers.Connection 17 | idx int // index into conns 18 | } 19 | 20 | // NewBalancer creates a new round-robin balancer. It can be initialized with 21 | // a variable number of connections. To use plain URLs instead of 22 | // connections, use NewBalancerFromURL. 23 | func NewBalancer(conns ...balancers.Connection) (balancers.Balancer, error) { 24 | b := &Balancer{ 25 | conns: make([]balancers.Connection, 0), 26 | } 27 | if len(conns) > 0 { 28 | b.conns = append(b.conns, conns...) 29 | } 30 | return b, nil 31 | } 32 | 33 | // NewBalancerFromURL creates a new round-robin balancer for the 34 | // given list of URLs.
It returns an error if any of the URLs is invalid. 35 | func NewBalancerFromURL(urls ...string) (*Balancer, error) { 36 | b := &Balancer{ 37 | conns: make([]balancers.Connection, 0, len(urls)), 38 | } 39 | for _, rawurl := range urls { 40 | u, err := url.Parse(rawurl) 41 | if err != nil { 42 | return nil, err 43 | } 44 | b.conns = append(b.conns, balancers.NewHttpConnection(u)) 45 | } 46 | return b, nil 47 | } 48 | 49 | // Get returns a connection from the balancer that can be used for the next request. 50 | // ErrNoConn is returned when there are no connections or all of them are broken. 51 | func (b *Balancer) Get() (balancers.Connection, error) { 52 | b.Lock() 53 | defer b.Unlock() 54 | 55 | if len(b.conns) == 0 { 56 | return nil, balancers.ErrNoConn 57 | } 58 | // Try each connection at most once, resuming after the one tried last. 59 | var conn balancers.Connection 60 | for i := 0; i < len(b.conns); i++ { 61 | candidate := b.conns[b.idx] 62 | b.idx = (b.idx + 1) % len(b.conns) 63 | if !candidate.IsBroken() { 64 | conn = candidate 65 | break 66 | } 67 | } 68 | 69 | if conn == nil { 70 | return nil, balancers.ErrNoConn 71 | } 72 | return conn, nil 73 | } 74 | 75 | // Connections returns a snapshot (URL and broken state) of all connections. 76 | func (b *Balancer) Connections() []balancers.Connection { 77 | b.Lock() 78 | defer b.Unlock() 79 | conns := make([]balancers.Connection, len(b.conns)) 80 | for i, c := range b.conns { 81 | // Snapshot every connection regardless of its concrete type, so that 82 | // callers never get a nil entry, whichever Connection implementations 83 | // the balancer was created with. 84 | conns[i] = &simpleConn{ 85 | url: c.URL(), 86 | broken: c.IsBroken(), 87 | } 88 | } 89 | 90 | return conns 91 | } 92 | 93 | var ( 94 | // Ensure that simpleConn implements balancers.Connection.
95 | _ balancers.Connection = (*simpleConn)(nil) 96 | ) 97 | 98 | type simpleConn struct { 99 | url *url.URL 100 | broken bool 101 | } 102 | 103 | func (c *simpleConn) URL() *url.URL { return c.url } 104 | func (c *simpleConn) IsBroken() bool { return c.broken } 105 | -------------------------------------------------------------------------------- /connection.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2014-2015 Oliver Eilhard. All rights reserved. 2 | // Use of this source code is governed by the MIT license. 3 | // See LICENSE file for details. 4 | package balancers 5 | 6 | import ( 7 | "net/http" 8 | "net/url" 9 | "sync" 10 | "time" 11 | ) 12 | 13 | // Connection is a single connection to a host. It is defined by a URL. 14 | // It also maintains state in the form that a connection can be broken. 15 | // TODO(oe) Not sure if this abstraction is necessary. 16 | type Connection interface { 17 | // URL to the host. 18 | URL() *url.URL 19 | // IsBroken must return true if the connection to URL is currently not available. 20 | IsBroken() bool 21 | } 22 | 23 | // heartbeatClient performs the health checks of all HttpConnections. 24 | // Its timeout bounds how long a single check can take. 25 | var heartbeatClient = &http.Client{Timeout: 5 * time.Second} 26 | 27 | // HttpConnection is an HTTP connection to a host. 28 | // It implements the Connection interface and can be used by balancer 29 | // implementations. A background goroutine periodically checks whether 30 | // the host responds; call Close to stop it. 31 | type HttpConnection struct { 32 | sync.Mutex // guards broken and heartbeatDuration 33 | url *url.URL 34 | broken bool 35 | heartbeatDuration time.Duration 36 | 37 | // heartbeatMu serializes starting and stopping the heartbeat goroutine 38 | // and guards heartbeatStop and closed. It is separate from the embedded 39 | // Mutex so that the heartbeat goroutine can still record a check result 40 | // while Close or HeartbeatDuration wait for it to stop. 41 | heartbeatMu sync.Mutex 42 | heartbeatStop chan bool // nil while no heartbeat goroutine is running 43 | closed bool 44 | } 45 | 46 | // NewHttpConnection creates a new HTTP connection to the given URL. 47 | // It checks the host once before returning and then keeps re-checking it 48 | // every DefaultHeartbeatDuration in the background. 49 | func NewHttpConnection(url *url.URL) *HttpConnection { 50 | c := &HttpConnection{ 51 | url: url, 52 | heartbeatDuration: DefaultHeartbeatDuration, 53 | } 54 | c.checkBroken() 55 | // c is not shared yet, so heartbeatMu does not need to be held here. 56 | c.startHeartbeatLocked(c.heartbeatDuration) 57 | return c 58 | }
59 | 60 | // Close stops the heartbeat of this connection and waits until it has 61 | // stopped. Calling Close more than once is safe. 62 | func (c *HttpConnection) Close() error { 63 | c.heartbeatMu.Lock() 64 | defer c.heartbeatMu.Unlock() 65 | c.stopHeartbeatLocked() 66 | c.closed = true 67 | c.Lock() 68 | c.broken = false 69 | c.Unlock() 70 | return nil 71 | } 72 | 73 | // HeartbeatDuration sets the duration in which the connection is checked 74 | // and restarts the heartbeat with it, unless the connection is closed. 75 | // A non-positive d selects DefaultHeartbeatDuration instead, because 76 | // time.NewTicker panics on non-positive intervals. 77 | func (c *HttpConnection) HeartbeatDuration(d time.Duration) *HttpConnection { 78 | if d <= 0 { 79 | d = DefaultHeartbeatDuration 80 | } 81 | c.heartbeatMu.Lock() 82 | defer c.heartbeatMu.Unlock() 83 | c.stopHeartbeatLocked() 84 | c.Lock() 85 | c.broken = false 86 | c.heartbeatDuration = d 87 | c.Unlock() 88 | if !c.closed { 89 | c.startHeartbeatLocked(d) 90 | } 91 | return c 92 | } 93 | 94 | // startHeartbeatLocked starts a goroutine that checks the connection 95 | // every d. c.heartbeatMu must be held. 96 | func (c *HttpConnection) startHeartbeatLocked(d time.Duration) { 97 | stop := make(chan bool) 98 | c.heartbeatStop = stop 99 | go c.heartbeat(d, stop) 100 | } 101 | 102 | // stopHeartbeatLocked stops the running heartbeat goroutine, if any, and 103 | // waits until it has received the stop signal. c.heartbeatMu must be 104 | // held, but the embedded Mutex must not be: the goroutine may need it to 105 | // finish an in-progress check before it can receive the signal. 106 | func (c *HttpConnection) stopHeartbeatLocked() { 107 | if c.heartbeatStop == nil { 108 | return 109 | } 110 | c.heartbeatStop <- true 111 | c.heartbeatStop = nil 112 | } 113 | 114 | // heartbeat checks every d if the connection is broken, until a value is 115 | // received on stop. 116 | func (c *HttpConnection) heartbeat(d time.Duration, stop <-chan bool) { 117 | ticker := time.NewTicker(d) 118 | defer ticker.Stop() 119 | for { 120 | select { 121 | case <-ticker.C: 122 | c.checkBroken() 123 | case <-stop: 124 | return 125 | } 126 | } 127 | } 128 | 129 | // checkBroken checks if the HTTP connection is alive: it is considered 130 | // broken unless a GET on its URL returns 200 OK. The request runs without 131 | // holding the lock, so IsBroken never waits for a health check. 132 | func (c *HttpConnection) checkBroken() { 133 | broken := true 134 | // TODO(oe) Can we use HEAD? 135 | req, err := http.NewRequest("GET", c.url.String(), nil) 136 | if err == nil { 137 | // Add UA to heartbeat requests. 138 | req.Header.Add("User-Agent", UserAgent) 139 | if res, err := heartbeatClient.Do(req); err == nil { 140 | broken = res.StatusCode != http.StatusOK 141 | res.Body.Close() 142 | } 143 | } 144 | 145 | c.Lock() 146 | c.broken = broken 147 | c.Unlock() 148 | } 149 | 150 | // URL returns the URL of the HTTP connection. 151 | func (c *HttpConnection) URL() *url.URL { 152 | return c.url 153 | } 154 | 155 | // IsBroken returns true if the HTTP connection is currently broken.
156 | func (c *HttpConnection) IsBroken() bool { 157 | c.Lock() 158 | defer c.Unlock() 159 | return c.broken 160 | } 161 | -------------------------------------------------------------------------------- /transport.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2014-2015 Oliver Eilhard. All rights reserved. 2 | // Use of this source code is governed by the MIT license. 3 | // See LICENSE file for details. 4 | 5 | // Most of the code here is taken from the Google OAuth2 client library 6 | // at https://github.com/golang/oauth2, 7 | // especially https://github.com/golang/oauth2/blob/master/transport.go. 8 | package balancers 9 | 10 | import ( 11 | "io" 12 | "net/http" 13 | "sync" 14 | ) 15 | 16 | // Transport implements an http.RoundTripper that load-balances requests across a Balancer. 17 | type Transport struct { 18 | Base http.RoundTripper 19 | 20 | balancer Balancer 21 | 22 | mu sync.Mutex 23 | modReq map[*http.Request]*http.Request 24 | } 25 | 26 | // RoundTrip is the core of the balancers package. It accepts a request, 27 | // replaces scheme, host (including port), and userinfo with the URL provided by the balancer, 28 | // executes it and returns the response to the caller. 29 | func (t *Transport) RoundTrip(r *http.Request) (*http.Response, error) { 30 | conn, err := t.balancer.Get() 31 | if err != nil { 32 | return nil, err 33 | } 34 | 35 | rc := cloneRequest(r) 36 | if err := modifyRequest(rc, conn); err != nil { 37 | return nil, err 38 | } 39 | t.setModReq(r, rc) 40 | 41 | res, err := t.base().RoundTrip(rc) 42 | if err != nil { 43 | t.setModReq(r, nil) 44 | return nil, err 45 | } 46 | res.Body = &onEOFReader{ 47 | rc: res.Body, 48 | fn: func() { t.setModReq(r, nil) }, 49 | } 50 | return res, nil 51 | } 52 | 53 | // CancelRequest cancels the given request (if canceling is available).
54 | func (t *Transport) CancelRequest(r *http.Request) { 55 | type canceler interface { 56 | CancelRequest(*http.Request) 57 | } 58 | if cr, ok := t.base().(canceler); ok { 59 | t.mu.Lock() 60 | modReq := t.modReq[r] 61 | delete(t.modReq, r) 62 | t.mu.Unlock() 63 | // r is unknown if it is not in flight (or has already completed); 64 | // there is nothing to cancel then, so don't pass nil to the base. 65 | if modReq != nil { 66 | cr.CancelRequest(modReq) 67 | } 68 | } 69 | } 70 | 71 | // base returns the RoundTripper that executes the rewritten requests: 72 | // Base, or http.DefaultTransport if Base is nil. 73 | func (t *Transport) base() http.RoundTripper { 74 | if t.Base != nil { 75 | return t.Base 76 | } 77 | return http.DefaultTransport 78 | } 79 | 80 | // modifyRequest replaces the scheme, host, and userinfo of r.URL with 81 | // those of the connection's URL; parts that are empty (or nil) in the 82 | // connection's URL leave r.URL unchanged. 83 | func modifyRequest(r *http.Request, conn Connection) error { 84 | url := conn.URL() 85 | if url.Scheme != "" { 86 | r.URL.Scheme = url.Scheme 87 | } 88 | if url.Host != "" { 89 | r.URL.Host = url.Host 90 | } 91 | if url.User != nil { 92 | r.URL.User = url.User 93 | } 94 | return nil 95 | } 96 | 97 | // cloneRequest returns a shallow copy of r that has its own URL and 98 | // Header, so that modifyRequest can rewrite the copy without touching r: 99 | // an http.RoundTripper must not modify the request it is given. 100 | func cloneRequest(r *http.Request) *http.Request { 101 | rc := new(http.Request) 102 | *rc = *r 103 | if r.URL != nil { 104 | u := *r.URL 105 | rc.URL = &u 106 | } 107 | rc.Header = make(http.Header, len(r.Header)) 108 | for k, s := range r.Header { 109 | rc.Header[k] = append([]string(nil), s...) 110 | } 111 | return rc 112 | } 113 | 114 | // setModReq records mod as the rewritten version of orig, or forgets 115 | // orig if mod is nil. 116 | func (t *Transport) setModReq(orig, mod *http.Request) { 117 | t.mu.Lock() 118 | defer t.mu.Unlock() 119 | 120 | if t.modReq == nil { 121 | t.modReq = make(map[*http.Request]*http.Request) 122 | } 123 | if mod == nil { 124 | delete(t.modReq, orig) 125 | } else { 126 | t.modReq[orig] = mod 127 | } 128 | } 129 | 130 | // onEOFReader is a reader that executes a function when io.EOF is read 131 | // or the reader is closed.
132 | type onEOFReader struct { 133 | rc io.ReadCloser 134 | fn func() 135 | } 136 | 137 | func (r *onEOFReader) Read(p []byte) (n int, err error) { 138 | n, err = r.rc.Read(p) 139 | if err == io.EOF { 140 | r.runFunc() 141 | } 142 | return 143 | } 144 | 145 | func (r *onEOFReader) Close() error { 146 | err := r.rc.Close() 147 | r.runFunc() 148 | return err 149 | } 150 | 151 | // runFunc runs fn at most once. 152 | func (r *onEOFReader) runFunc() { 153 | if fn := r.fn; fn != nil { 154 | fn() 155 | r.fn = nil 156 | } 157 | } 158 | -------------------------------------------------------------------------------- /roundrobin/balancer_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2014-2015 Oliver Eilhard. All rights reserved. 2 | // Use of this source code is governed by the MIT license. 3 | // See LICENSE file for details. 4 | package roundrobin 5 | 6 | import ( 7 | "net/http" 8 | "net/http/httptest" 9 | "net/url" 10 | "sync/atomic" 11 | "testing" 12 | 13 | "github.com/olivere/balancers" 14 | ) 15 | 16 | func TestNewBalancer(t *testing.T) { 17 | url1, _ := url.Parse("http://127.0.0.1:12345") 18 | url2, _ := url.Parse("http://127.0.0.1:23456") 19 | 20 | balancer, err := NewBalancer( 21 | balancers.NewHttpConnection(url1), 22 | balancers.NewHttpConnection(url2)) 23 | if err != nil { 24 | t.Fatal(err) 25 | } 26 | conns := balancer.Connections() 27 | if len(conns) != 2 { 28 | t.Errorf("expected %d connections; got: %v", 2, len(conns)) 29 | } 30 | url := conns[0].URL() 31 | if url.String() != "http://127.0.0.1:12345" { 32 | t.Errorf("expected %q; got: %q", "http://127.0.0.1:12345", url.String()) 33 | } 34 | } 35 | 36 | func TestBalancerErrNoConnWithoutConnections(t *testing.T) { 37 | balancer, err := NewBalancer() 38 | if err != nil { 39 | t.Fatal(err) 40 | } 41 | conns := balancer.Connections() 42 | if len(conns) != 0 { 43 | t.Errorf("expected %d connections; got: %v", 0, len(conns)) 44 | } 45 | _, err = balancer.Get() 46 | if err !=
balancers.ErrNoConn { 47 | t.Fatalf("expected %v; got: %v", balancers.ErrNoConn, err) 48 | } 49 | } 50 | 51 | func TestBalancer(t *testing.T) { 52 | var visited []int 53 | 54 | server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 55 | // Only count non-heartbeat requests 56 | if r.Header.Get("User-Agent") != balancers.UserAgent { 57 | visited = append(visited, 1) 58 | } 59 | })) 60 | defer server1.Close() 61 | 62 | server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 63 | // Only count non-heartbeat requests 64 | if r.Header.Get("User-Agent") != balancers.UserAgent { 65 | visited = append(visited, 2) 66 | } 67 | })) 68 | defer server2.Close() 69 | 70 | balancer, err := NewBalancerFromURL(server1.URL, server2.URL) 71 | if err != nil { 72 | t.Fatal(err) 73 | } 74 | 75 | client := balancers.NewClient(balancer) 76 | client.Get(server1.URL) 77 | client.Get(server1.URL) 78 | client.Get(server1.URL) 79 | 80 | if len(visited) != 3 { 81 | t.Fatalf("expected %d URLs to be visited; got: %d", 3, len(visited)) 82 | } 83 | if visited[0] != 1 { 84 | t.Errorf("expected 1st URL to be %q", server1.URL) 85 | } 86 | if visited[1] != 2 { 87 | t.Errorf("expected 2nd URL to be %q", server2.URL) 88 | } 89 | if visited[2] != 1 { 90 | t.Errorf("expected 3rd URL to be %q", server1.URL) 91 | } 92 | } 93 | 94 | func TestBalancerWithBrokenConnections(t *testing.T) { 95 | var visited []int 96 | 97 | server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 98 | // Only count non-heartbeat requests 99 | if r.Header.Get("User-Agent") != balancers.UserAgent { 100 | visited = append(visited, 1) 101 | } 102 | })) 103 | defer server1.Close() 104 | 105 | server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 106 | // Only count non-heartbeat requests 107 | if r.Header.Get("User-Agent") != balancers.UserAgent { 108 | visited = append(visited, 2) 109 
| } 110 | })) 111 | defer server2.Close() 112 | 113 | balancer, err := NewBalancerFromURL(server1.URL, "http://localhost:12345", server2.URL, "http://localhost:12346") 114 | if err != nil { 115 | t.Fatal(err) 116 | } 117 | 118 | client := balancers.NewClient(balancer) 119 | client.Get(server1.URL) 120 | client.Get(server1.URL) 121 | client.Get(server1.URL) 122 | client.Get(server1.URL) 123 | client.Get(server1.URL) 124 | 125 | if len(visited) != 5 { // 5 requests 126 | t.Fatalf("expected %d URLs to be visited; got: %d", 5, len(visited)) 127 | } 128 | if visited[0] != 1 { 129 | t.Errorf("expected 1st URL to be %q", server1.URL) 130 | } 131 | if visited[1] != 2 { 132 | t.Errorf("expected 2nd URL to be %q", server2.URL) 133 | } 134 | if visited[2] != 1 { 135 | t.Errorf("expected 3rd URL to be %q", server1.URL) 136 | } 137 | if visited[3] != 2 { 138 | t.Errorf("expected 4th URL to be %q", server2.URL) 139 | } 140 | if visited[4] != 1 { 141 | t.Errorf("expected 5th URL to be %q", server1.URL) 142 | } 143 | } 144 | 145 | func TestBalancerRewritesSchemeAndURLButNotPathOrQuery(t *testing.T) { 146 | var visited []string 147 | 148 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 149 | // Only count non-heartbeat requests 150 | if r.Header.Get("User-Agent") != balancers.UserAgent { 151 | visited = append(visited, r.URL.String()) 152 | } 153 | })) 154 | defer server.Close() 155 | 156 | balancer, err := NewBalancerFromURL(server.URL) 157 | if err != nil { 158 | t.Fatal(err) 159 | } 160 | 161 | client := balancers.NewClient(balancer) 162 | client.Get(server.URL + "/path?foo=bar&n=1") 163 | client.Get(server.URL + "/path?n=2") 164 | client.Get(server.URL + "/no/3") 165 | 166 | if len(visited) != 3 { 167 | t.Fatalf("expected %d URLs to be visited; got: %d", 3, len(visited)) 168 | } 169 | if visited[0] != "/path?foo=bar&n=1" { 170 | t.Errorf("expected 1st URL to be %q; got: %q", "/path?foo=bar&n=1", visited[0]) 171 | } 172 | if 
visited[1] != "/path?n=2" { 173 | t.Errorf("expected 2nd URL to be %q; got: %q", "/path?n=2", visited[1]) 174 | } 175 | if visited[2] != "/no/3" { 176 | t.Errorf("expected 3rd URL to be %q; got: %q", "/no/3", visited[2]) 177 | } 178 | } 179 | 180 | func BenchmarkBalancer(b *testing.B) { 181 | var ( 182 | visited1 int64 183 | visited2 int64 184 | ) 185 | 186 | server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 187 | // Only count non-heartbeat requests 188 | if r.Header.Get("User-Agent") != balancers.UserAgent { 189 | atomic.AddInt64(&visited1, 1) 190 | } 191 | })) 192 | defer server1.Close() 193 | 194 | server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 195 | // Only count non-heartbeat requests 196 | if r.Header.Get("User-Agent") != balancers.UserAgent { 197 | atomic.AddInt64(&visited2, 1) 198 | } 199 | })) 200 | defer server2.Close() 201 | 202 | balancer, err := NewBalancerFromURL(server1.URL, server2.URL) 203 | if err != nil { 204 | b.Fatal(err) 205 | } 206 | 207 | client := balancers.NewClient(balancer) 208 | 209 | b.ReportAllocs() 210 | 211 | for i := 0; i < b.N; i++ { 212 | res, err := client.Get(server1.URL) 213 | if err != nil { 214 | b.Fatal(err) 215 | } 216 | res.Body.Close() 217 | } 218 | 219 | if want, have := int64(b.N), visited1+visited2; want != have { 220 | b.Fatalf("expected %d visits; got: %d", want, have) 221 | } 222 | } 223 | --------------------------------------------------------------------------------