├── .drone.yml ├── .gitignore ├── .travis.yml ├── LICENSE ├── Makefile ├── README.md ├── base_test.go ├── circuitbreaker ├── cbreaker.go ├── cbreaker_test.go ├── effect.go ├── fallback.go ├── predicates.go ├── predicates_test.go ├── ratio.go └── ratio_test.go ├── endpoint └── endpoint.go ├── errors └── error.go ├── headers └── headers.go ├── limit ├── connlimit │ ├── connlimiter.go │ └── connlimiter_test.go ├── limiter.go ├── limiter_test.go └── tokenbucket │ ├── bucket.go │ ├── bucket_test.go │ ├── bucketset.go │ ├── bucketset_test.go │ ├── tokenlimiter.go │ └── tokenlimiter_test.go ├── loadbalance ├── balance.go ├── loadbalance_test.go └── roundrobin │ ├── fsm.go │ ├── fsm_test.go │ ├── recovery.go │ ├── roundrobin.go │ ├── roundrobin_test.go │ └── wendpoint.go ├── location ├── httploc │ ├── httploc.go │ ├── httploc_test.go │ └── rewrite.go └── location.go ├── metrics ├── anomaly.go ├── anomaly_test.go ├── counter.go ├── failrate.go ├── failrate_test.go ├── histogram.go ├── histogram_test.go ├── roundtrip.go └── rr_test.go ├── middleware ├── chain.go ├── chain_test.go └── middleware.go ├── netutils ├── buffer.go ├── buffer_test.go ├── netutils.go ├── netutils_test.go └── response.go ├── proxy.go ├── proxy_test.go ├── request ├── request.go └── request_test.go ├── route ├── exproute │ ├── exproute.go │ └── exproute_test.go ├── hostroute │ ├── host.go │ └── host_test.go ├── pathroute │ ├── route.go │ └── route_test.go └── router.go ├── template ├── template.go └── template_test.go ├── testutils ├── requests.go └── rndstring.go └── threshold ├── parse.go ├── parse_test.go └── threshold.go /.drone.yml: -------------------------------------------------------------------------------- 1 | image: mailgun/gobase2 2 | env: 3 | - GOROOT=/opt/go 4 | - PATH=$PATH:/opt/go/bin 5 | - GOPATH=/var/cache/drone 6 | script: 7 | - echo "gopath is: $GOPATH" 8 | - echo $PATH 9 | - go get -v -u code.google.com/p/go.tools/cover 10 | - go get -v -u github.com/axw/gocov 11 | - go install 
github.com/axw/gocov/gocov 12 | - go get -v -u github.com/golang/glog 13 | - go get -v -u github.com/mailgun/glogutils 14 | - go get -v -u github.com/axw/gocov 15 | - go get -v -u launchpad.net/gocheck 16 | - go get -v -u github.com/mailgun/gocql 17 | - go get -v -u github.com/robertkrimen/otto 18 | - go get -v -u github.com/coreos/go-etcd/etcd 19 | - go get -v -u github.com/mailgun/minheap 20 | - go test -v ./... 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.test 2 | flymake_* 3 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | go: 3 | - 1.3 4 | - tip 5 | 6 | install: 7 | - export PATH=$HOME/gopath/bin:$PATH 8 | - make deps 9 | 10 | script: 11 | - go test -v ./... 12 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | test: clean 2 | go test -v ./... -cover 3 | 4 | deps: 5 | go list -f '{{join .Deps "\n"}} \ 6 | {{join .TestImports "\n"}}' ./... | xargs go list -e -f '{{if not .Standard}}{{.ImportPath}}{{end}}' | grep -v `go list` | xargs go get -u -v 7 | 8 | clean: 9 | find . -name flymake_* -delete 10 | 11 | test-package: clean 12 | go test -v ./$(p) 13 | 14 | bench-package: clean 15 | go test ./$(p) -check.bmem -check.b -test.bench=. 16 | 17 | cover-package: clean 18 | go test -v ./$(p) -coverprofile=/tmp/coverage.out 19 | go tool cover -html=/tmp/coverage.out 20 | 21 | sloccount: 22 | find . 
-name "*.go" -print0 | xargs -0 wc -l 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | Status 3 | ------ 4 | 5 | This library has been refactored into: https://github.com/mailgun/oxy and is being deprecated 6 | 7 | 8 | [OXY](https://github.com/mailgun/oxy) is compatible with HTTP standard library, provides the same features as Vulcan and is simpler to use. Please consider using it instead. 9 | 10 | Vulcan library will stay there for a while in case if you are using it, but I would suggest consider migrating. Vulcand project is currently migrating to oxy library. 11 | 12 | -------------------------------------------------------------------------------- /base_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Declares gocheck's test suites 3 | */ 4 | package vulcan 5 | 6 | import ( 7 | . "gopkg.in/check.v1" 8 | "testing" 9 | ) 10 | 11 | func Test(t *testing.T) { TestingT(t) } 12 | 13 | //This is a simple suite to use if tests dont' need anything 14 | //special 15 | type MainSuite struct { 16 | } 17 | 18 | func (s *MainSuite) SetUpTest(c *C) { 19 | } 20 | 21 | var _ = Suite(&MainSuite{}) 22 | -------------------------------------------------------------------------------- /circuitbreaker/effect.go: -------------------------------------------------------------------------------- 1 | package circuitbreaker 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "io/ioutil" 8 | "net/http" 9 | "net/url" 10 | "strings" 11 | 12 | "github.com/mailgun/log" 13 | "github.com/mailgun/vulcan/netutils" 14 | ) 15 | 16 | type SideEffect interface { 17 | Exec() error 18 | } 19 | 20 | type Webhook struct { 21 | URL string 22 | Method string 23 | Headers http.Header 24 | Form url.Values 25 | Body []byte 26 | } 27 | 28 | type WebhookSideEffect struct { 29 | w Webhook 30 | } 31 | 32 | func 
NewWebhookSideEffect(w Webhook) (*WebhookSideEffect, error) { 33 | if w.Method == "" { 34 | return nil, fmt.Errorf("Supply method") 35 | } 36 | _, err := netutils.ParseUrl(w.URL) 37 | if err != nil { 38 | return nil, err 39 | } 40 | 41 | return &WebhookSideEffect{w: w}, nil 42 | } 43 | 44 | func (w *WebhookSideEffect) getBody() io.Reader { 45 | if len(w.w.Form) != 0 { 46 | return strings.NewReader(w.w.Form.Encode()) 47 | } 48 | if len(w.w.Body) != 0 { 49 | return bytes.NewBuffer(w.w.Body) 50 | } 51 | return nil 52 | } 53 | 54 | func (w *WebhookSideEffect) Exec() error { 55 | r, err := http.NewRequest(w.w.Method, w.w.URL, w.getBody()) 56 | if err != nil { 57 | return err 58 | } 59 | if len(w.w.Headers) != 0 { 60 | netutils.CopyHeaders(r.Header, w.w.Headers) 61 | } 62 | if len(w.w.Form) != 0 { 63 | r.Header.Set("Content-Type", "application/x-www-form-urlencoded") 64 | } 65 | re, err := http.DefaultClient.Do(r) 66 | if err != nil { 67 | return err 68 | } 69 | if re.Body != nil { 70 | defer re.Body.Close() 71 | } 72 | body, err := ioutil.ReadAll(re.Body) 73 | if err != nil { 74 | return err 75 | } 76 | log.Infof("%v got response: (%s): %s", w, re.Status, string(body)) 77 | return nil 78 | } 79 | -------------------------------------------------------------------------------- /circuitbreaker/fallback.go: -------------------------------------------------------------------------------- 1 | package circuitbreaker 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/url" 7 | 8 | "github.com/mailgun/vulcan/errors" 9 | "github.com/mailgun/vulcan/netutils" 10 | "github.com/mailgun/vulcan/request" 11 | ) 12 | 13 | type Response struct { 14 | StatusCode int 15 | ContentType string 16 | Body []byte 17 | } 18 | 19 | func (re *Response) getHTTPResponse(r request.Request) *http.Response { 20 | return netutils.NewHttpResponse(r.GetHttpRequest(), re.StatusCode, re.Body, re.ContentType) 21 | } 22 | 23 | type ResponseFallback struct { 24 | r Response 25 | } 26 | 27 | func 
NewResponseFallback(r Response) (*ResponseFallback, error) { 28 | if r.StatusCode == 0 { 29 | return nil, fmt.Errorf("response code should not be 0") 30 | } 31 | return &ResponseFallback{r: r}, nil 32 | } 33 | 34 | func (f *ResponseFallback) ProcessRequest(r request.Request) (*http.Response, error) { 35 | return f.r.getHTTPResponse(r), nil 36 | } 37 | 38 | func (f *ResponseFallback) ProcessResponse(r request.Request, a request.Attempt) { 39 | } 40 | 41 | type Redirect struct { 42 | URL string 43 | } 44 | 45 | type RedirectFallback struct { 46 | u *url.URL 47 | } 48 | 49 | func NewRedirectFallback(r Redirect) (*RedirectFallback, error) { 50 | u, err := netutils.ParseUrl(r.URL) 51 | if err != nil { 52 | return nil, err 53 | } 54 | return &RedirectFallback{u: u}, nil 55 | } 56 | 57 | func (f *RedirectFallback) ProcessRequest(r request.Request) (*http.Response, error) { 58 | return nil, &errors.RedirectError{URL: netutils.CopyUrl(f.u)} 59 | } 60 | 61 | func (f *RedirectFallback) ProcessResponse(r request.Request, a request.Attempt) { 62 | } 63 | -------------------------------------------------------------------------------- /circuitbreaker/predicates.go: -------------------------------------------------------------------------------- 1 | package circuitbreaker 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/mailgun/log" 8 | "github.com/mailgun/predicate" 9 | "github.com/mailgun/vulcan/metrics" 10 | "github.com/mailgun/vulcan/request" 11 | "github.com/mailgun/vulcan/threshold" 12 | ) 13 | 14 | // MustParseExpresison calls ParseExpression and panics if expression is incorrect, for use in tests 15 | func MustParseExpression(in string) threshold.Predicate { 16 | e, err := ParseExpression(in) 17 | if err != nil { 18 | panic(err) 19 | } 20 | return e 21 | } 22 | 23 | // ParseExpression parses expression in the go language into predicates. 
24 | func ParseExpression(in string) (threshold.Predicate, error) { 25 | p, err := predicate.NewParser(predicate.Def{ 26 | Operators: predicate.Operators{ 27 | AND: threshold.AND, 28 | OR: threshold.OR, 29 | EQ: threshold.EQ, 30 | NEQ: threshold.NEQ, 31 | LT: threshold.LT, 32 | LE: threshold.LE, 33 | GT: threshold.GT, 34 | GE: threshold.GE, 35 | }, 36 | Functions: map[string]interface{}{ 37 | "LatencyAtQuantileMS": latencyAtQuantile, 38 | "NetworkErrorRatio": networkErrorRatio, 39 | "ResponseCodeRatio": responseCodeRatio, 40 | }, 41 | }) 42 | if err != nil { 43 | return nil, err 44 | } 45 | out, err := p.Parse(in) 46 | if err != nil { 47 | return nil, err 48 | } 49 | pr, ok := out.(threshold.Predicate) 50 | if !ok { 51 | return nil, fmt.Errorf("expected predicate, got %T", out) 52 | } 53 | return pr, nil 54 | } 55 | 56 | // latencyAtQuantile returns a function reporting the request latency at the given quantile in whole 57 | // milliseconds, or 0 when metrics are missing or the latency histogram cannot be read (the error is logged). 58 | func latencyAtQuantile(quantile float64) threshold.RequestToInt { 59 | return func(r request.Request) int { 60 | m := getMetrics(r) 61 | if m == nil { 62 | return 0 63 | } 64 | h, err := m.GetLatencyHistogram() 65 | if err != nil { 66 | log.Errorf("Failed to get latency histogram, for %v error: %v", r, err) 67 | return 0 68 | } 69 | return int(h.LatencyAtQuantile(quantile) / time.Millisecond) 70 | } 71 | } 72 | 73 | // networkErrorRatio returns a function reporting the network error ratio from the request's 74 | // round-trip metrics, or 0 when no metrics are attached to the request. 75 | func networkErrorRatio() threshold.RequestToFloat64 { 76 | return func(r request.Request) float64 { 77 | m := getMetrics(r) 78 | if m == nil { 79 | return 0 80 | } 81 | return m.GetNetworkErrorRatio() 82 | } 83 | } 84 | 85 | // responseCodeRatio returns a function reporting m.GetResponseCodeRatio(startA, endA, startB, endB), presumably the share of 86 | // codes in range A relative to range B (bound inclusivity: confirm in metrics), or 0 when no metrics are attached. 87 | func responseCodeRatio(startA, endA, startB, endB int) threshold.RequestToFloat64 { 88 | return func(r request.Request) float64 { 89 | m := getMetrics(r) 90 | if m == nil { 91 | return 0 92 | } 93 | return m.GetResponseCodeRatio(startA, endA, startB, endB) 94 | } 95 | } 96 | 97 | // getMetrics returns the round-trip metrics stored in the request's user data under cbreakerMetrics, 98 | // or nil when nothing was stored; it panics if the stored value is not a *metrics.RoundTripMetrics. 99 | func getMetrics(r request.Request) *metrics.RoundTripMetrics { 100 | m, ok := r.GetUserData(cbreakerMetrics) 101 | if !ok { 102 | return nil 103 | } 104 | return m.(*metrics.RoundTripMetrics) 105 | } 106 |
-------------------------------------------------------------------------------- /circuitbreaker/predicates_test.go: -------------------------------------------------------------------------------- 1 | package circuitbreaker 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/mailgun/vulcan/request" 7 | . "gopkg.in/check.v1" 8 | ) 9 | 10 | type PredicatesSuite struct { 11 | } 12 | 13 | var _ = Suite(&PredicatesSuite{}) 14 | 15 | func (s *PredicatesSuite) TestTriggered(c *C) { 16 | predicates := []struct { 17 | Expression string 18 | Request request.Request 19 | V bool 20 | }{ 21 | { 22 | Expression: "NetworkErrorRatio() > 0.5", 23 | Request: makeRequest(O{stats: statsNetErrors(0.6)}), 24 | V: true, 25 | }, 26 | { 27 | Expression: "NetworkErrorRatio() < 0.5", 28 | Request: makeRequest(O{stats: statsNetErrors(0.6)}), 29 | V: false, 30 | }, 31 | { 32 | Expression: "LatencyAtQuantileMS(50.0) > 50", 33 | Request: makeRequest(O{stats: statsLatencyAtQuantile(50, time.Millisecond*51)}), 34 | V: true, 35 | }, 36 | { 37 | Expression: "LatencyAtQuantileMS(50.0) < 50", 38 | Request: makeRequest(O{stats: statsLatencyAtQuantile(50, time.Millisecond*51)}), 39 | V: false, 40 | }, 41 | { 42 | Expression: "ResponseCodeRatio(500, 600, 0, 600) > 0.5", 43 | Request: makeRequest(O{stats: statsResponseCodes(statusCode{Code: 200, Count: 5}, statusCode{Code: 500, Count: 6})}), 44 | V: true, 45 | }, 46 | { 47 | Expression: "ResponseCodeRatio(500, 600, 0, 600) > 0.5", 48 | Request: makeRequest(O{stats: statsResponseCodes(statusCode{Code: 200, Count: 5}, statusCode{Code: 500, Count: 4})}), 49 | V: false, 50 | }, 51 | } 52 | for _, t := range predicates { 53 | p, err := ParseExpression(t.Expression) 54 | c.Assert(err, IsNil) 55 | c.Assert(p, NotNil) 56 | 57 | c.Assert(p(t.Request), Equals, t.V) 58 | } 59 | } 60 | 61 | func (s *PredicatesSuite) TestErrors(c *C) { 62 | predicates := []struct { 63 | Expression string 64 | Request request.Request 65 | }{ 66 | { 67 | Expression:
"LatencyAtQuantileMS(40.0) > 50", // quantile not defined 68 | Request: makeRequest(O{stats: statsNetErrors(0.6)}), 69 | }, 70 | { 71 | Expression: "LatencyAtQuantileMS(40.0) > 50", // stats are not defined 72 | Request: makeRequest(O{stats: nil}), 73 | }, 74 | { 75 | Expression: "NetworkErrorRatio() > 20.0", // stats are not defined 76 | Request: makeRequest(O{stats: nil}), 77 | }, 78 | { 79 | Expression: "NetworkErrorRatio() > 20.0", // no last attempt 80 | Request: makeRequest(O{noAttempts: true}), 81 | }, 82 | } 83 | for _, t := range predicates { 84 | p, err := ParseExpression(t.Expression) 85 | c.Assert(err, IsNil) 86 | c.Assert(p, NotNil) 87 | 88 | c.Assert(p(t.Request), Equals, false) 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /circuitbreaker/ratio.go: -------------------------------------------------------------------------------- 1 | package circuitbreaker 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/mailgun/log" 8 | "github.com/mailgun/timetools" 9 | ) 10 | 11 | // ratioController allows passing portions of traffic back to the endpoints, 12 | // increasing the amount of passed requests using linear function: 13 | // 14 | // allowedRequestsRatio = 0.5 * (Now() - Start())/Duration 15 | // 16 | type ratioController struct { 17 | duration time.Duration 18 | start time.Time 19 | tm timetools.TimeProvider 20 | allowed int 21 | denied int 22 | } 23 | 24 | func newRatioController(tm timetools.TimeProvider, rampUp time.Duration) *ratioController { 25 | return &ratioController{ 26 | duration: rampUp, 27 | tm: tm, 28 | start: tm.UtcNow(), 29 | } 30 | } 31 | 32 | func (r *ratioController) String() string { 33 | return fmt.Sprintf("RatioController(target=%f, current=%f, allowed=%d, denied=%d)", r.targetRatio(), r.computeRatio(r.allowed, r.denied), r.allowed, r.denied) 34 | } 35 | 36 | // allowRequest admits the next request only if doing so keeps the allowed ratio strictly below targetRatio(), records and logs the decision. 37 | func (r *ratioController) allowRequest() bool { 38 | t := r.targetRatio() 39 | // This
condition answers the question - would we satisfy the target ratio if we allow this request? 40 | e := r.computeRatio(r.allowed+1, r.denied) 41 | if e < t { 42 | r.allowed++ 43 | log.Infof("%v allowed", r) 44 | return true 45 | } 46 | r.denied++ 47 | log.Infof("%v denied", r) 48 | return false 49 | } 50 | 51 | func (r *ratioController) computeRatio(allowed, denied int) float64 { 52 | if denied+allowed == 0 { 53 | return 0 54 | } 55 | return float64(allowed) / float64(denied+allowed) 56 | } 57 | 58 | func (r *ratioController) targetRatio() float64 { 59 | // Here's why it's 0.5: 60 | // We are watching the following ratio 61 | // ratio = a / (a + d) 62 | // We can notice, that once we get to 0.5 63 | // 0.5 = a / (a + d) 64 | // we can evaluate that a = d 65 | // that means equilibrium, where we would allow all the requests 66 | // after this point to achieve ratio of 1 (that can never be reached unless d is 0) 67 | // so we stop from there 68 | multiplier := 0.5 / float64(r.duration) 69 | return multiplier * float64(r.tm.UtcNow().Sub(r.start)) 70 | } 71 | -------------------------------------------------------------------------------- /circuitbreaker/ratio_test.go: -------------------------------------------------------------------------------- 1 | package circuitbreaker 2 | 3 | import ( 4 | "math" 5 | "time" 6 | 7 | "github.com/mailgun/timetools" 8 | .
"gopkg.in/check.v1" 9 | ) 10 | 11 | type RatioSuite struct { 12 | tm *timetools.FreezedTime 13 | } 14 | 15 | var _ = Suite(&RatioSuite{ 16 | tm: &timetools.FreezedTime{ 17 | CurrentTime: time.Date(2012, 3, 4, 5, 6, 7, 0, time.UTC), 18 | }, 19 | }) 20 | 21 | func (s *RatioSuite) advanceTime(d time.Duration) { 22 | s.tm.CurrentTime = s.tm.CurrentTime.Add(d) 23 | } 24 | 25 | func (s *RatioSuite) TestRampUp(c *C) { 26 | duration := 10 * time.Second 27 | rc := newRatioController(s.tm, duration) 28 | 29 | allowed, denied := 0, 0 30 | for i := 0; i < int(duration/time.Millisecond); i++ { 31 | ratio := s.sendRequest(&allowed, &denied, rc) 32 | expected := rc.targetRatio() 33 | diff := math.Abs(expected - ratio) 34 | c.Assert(round(diff, 0.5, 1), Equals, float64(0)) 35 | s.advanceTime(time.Millisecond) 36 | } 37 | } 38 | 39 | func (s *RatioSuite) sendRequest(allowed, denied *int, rc *ratioController) float64 { 40 | if rc.allowRequest() { 41 | *allowed++ 42 | } else { 43 | *denied++ 44 | } 45 | if *allowed+*denied == 0 { 46 | return 0 47 | } 48 | return float64(*allowed) / float64(*allowed+*denied) 49 | } 50 | 51 | func round(val float64, roundOn float64, places int) float64 { 52 | pow := math.Pow(10, float64(places)) 53 | digit := pow * val 54 | _, div := math.Modf(digit) 55 | var round float64 56 | if div >= roundOn { 57 | round = math.Ceil(digit) 58 | } else { 59 | round = math.Floor(digit) 60 | } 61 | return round / pow 62 | } 63 | -------------------------------------------------------------------------------- /endpoint/endpoint.go: -------------------------------------------------------------------------------- 1 | /*Endpoints - final destination of the http request 2 | */ 3 | package endpoint 4 | 5 | import ( 6 | "fmt" 7 | "github.com/mailgun/vulcan/netutils" 8 | "net/url" 9 | ) 10 | 11 | type Endpoint interface { 12 | GetId() string 13 | GetUrl() *url.URL 14 | String() string 15 | } 16 | 17 | type HttpEndpoint struct { 18 | url *url.URL 19 | id string 20 | } 21 | 22
| func ParseUrl(in string) (*HttpEndpoint, error) { 23 | u, err := netutils.ParseUrl(in) 24 | if err != nil { 25 | return nil, err 26 | } 27 | return &HttpEndpoint{url: u, id: fmt.Sprintf("%s://%s", u.Scheme, u.Host)}, nil 28 | } 29 | 30 | func MustParseUrl(in string) *HttpEndpoint { 31 | u, err := ParseUrl(in) 32 | if err != nil { 33 | panic(err) 34 | } 35 | return u 36 | } 37 | 38 | func NewHttpEndpoint(in *url.URL) (*HttpEndpoint, error) { 39 | if in == nil { 40 | return nil, fmt.Errorf("Provide url") 41 | } 42 | return &HttpEndpoint{ 43 | url: netutils.CopyUrl(in), 44 | id: fmt.Sprintf("%s://%s", in.Scheme, in.Host)}, nil 45 | } 46 | 47 | func (e *HttpEndpoint) String() string { 48 | return e.url.String() 49 | } 50 | 51 | func (e *HttpEndpoint) GetId() string { 52 | return e.id 53 | } 54 | 55 | func (e *HttpEndpoint) GetUrl() *url.URL { 56 | return e.url 57 | } 58 | -------------------------------------------------------------------------------- /errors/error.go: -------------------------------------------------------------------------------- 1 | // Utility functions for producing erroneous http responses 2 | package errors 3 | 4 | import ( 5 | "encoding/json" 6 | "fmt" 7 | "net/http" 8 | "net/url" 9 | 10 | "github.com/mailgun/log" 11 | ) 12 | 13 | const ( 14 | StatusTooManyRequests = 429 15 | ) 16 | 17 | type ProxyError interface { 18 | GetStatusCode() int 19 | Error() string 20 | Headers() http.Header 21 | } 22 | 23 | type Formatter interface { 24 | Format(ProxyError) (statusCode int, body []byte, contentType string) 25 | } 26 | 27 | type JsonFormatter struct { 28 | } 29 | 30 | func (f *JsonFormatter) Format(err ProxyError) (int, []byte, string) { 31 | encodedError, e := json.Marshal(map[string]interface{}{ 32 | "error": err.Error(), 33 | }) 34 | if e != nil { 35 | log.Errorf("Failed to serialize: %s", e) 36 | encodedError = []byte("{}") 37 | } 38 | return err.GetStatusCode(), encodedError, "application/json" 39 | } 40 | 41 | type HttpError
struct { 42 | StatusCode int 43 | Body string 44 | } 45 | 46 | func FromStatus(statusCode int) *HttpError { 47 | return &HttpError{statusCode, http.StatusText(statusCode)} 48 | } 49 | 50 | func (r *HttpError) Headers() http.Header { 51 | return nil 52 | } 53 | 54 | func (r *HttpError) Error() string { 55 | return r.Body 56 | } 57 | 58 | func (r *HttpError) GetStatusCode() int { 59 | return r.StatusCode 60 | } 61 | 62 | type RedirectError struct { 63 | URL *url.URL 64 | } 65 | 66 | func (r *RedirectError) Error() string { 67 | return fmt.Sprintf("Redirect(url=%v)", r.URL) 68 | } 69 | 70 | func (r *RedirectError) GetStatusCode() int { 71 | return http.StatusFound 72 | } 73 | 74 | func (r *RedirectError) Headers() http.Header { 75 | h := make(http.Header) 76 | h.Set("Location", r.URL.String()) 77 | return h 78 | } 79 | -------------------------------------------------------------------------------- /headers/headers.go: -------------------------------------------------------------------------------- 1 | // Constants with common HTTP headers 2 | package headers 3 | 4 | const ( 5 | XForwardedProto = "X-Forwarded-Proto" 6 | XForwardedFor = "X-Forwarded-For" 7 | XForwardedHost = "X-Forwarded-Host" 8 | XForwardedServer = "X-Forwarded-Server" 9 | Connection = "Connection" 10 | KeepAlive = "Keep-Alive" 11 | ProxyAuthenticate = "Proxy-Authenticate" 12 | ProxyAuthorization = "Proxy-Authorization" 13 | Te = "Te" // canonicalized version of "TE" 14 | Trailers = "Trailers" // non-standard spelling from RFC 2616; the actual header is "Trailer" (RFC 7230) 15 | TransferEncoding = "Transfer-Encoding" 16 | Upgrade = "Upgrade" 17 | ContentLength = "Content-Length" 18 | ) 19 | 20 | // Hop-by-hop headers. These are removed when sent to the backend.
21 | // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html 22 | // Copied from reverseproxy.go, too bad 23 | var HopHeaders = []string{ 24 | Connection, 25 | KeepAlive, 26 | ProxyAuthenticate, 27 | ProxyAuthorization, 28 | Te, // canonicalized version of "TE" 29 | Trailers, 30 | TransferEncoding, 31 | Upgrade, 32 | } 33 | -------------------------------------------------------------------------------- /limit/connlimit/connlimiter.go: -------------------------------------------------------------------------------- 1 | // Simultaneous connection limiter 2 | package connlimit 3 | 4 | import ( 5 | "fmt" 6 | "github.com/mailgun/vulcan/errors" 7 | "github.com/mailgun/vulcan/limit" 8 | "github.com/mailgun/vulcan/netutils" 9 | "github.com/mailgun/vulcan/request" 10 | "net/http" 11 | "sync" 12 | ) 13 | 14 | // This limiter tracks concurrent connection per token 15 | // and is capable of rejecting connections if they are failed 16 | type ConnectionLimiter struct { 17 | mutex *sync.Mutex 18 | mapper limit.MapperFn 19 | connections map[string]int64 20 | maxConnections int64 21 | totalConnections int64 22 | } 23 | 24 | func NewClientIpLimiter(maxConnections int64) (*ConnectionLimiter, error) { 25 | return NewConnectionLimiter(limit.MapClientIp, maxConnections) 26 | } 27 | 28 | func NewConnectionLimiter(mapper limit.MapperFn, maxConnections int64) (*ConnectionLimiter, error) { 29 | if mapper == nil { 30 | return nil, fmt.Errorf("Mapper function can not be nil") 31 | } 32 | if maxConnections <= 0 { 33 | return nil, fmt.Errorf("Max connections should be >= 0") 34 | } 35 | return &ConnectionLimiter{ 36 | mutex: &sync.Mutex{}, 37 | mapper: mapper, 38 | maxConnections: maxConnections, 39 | connections: make(map[string]int64), 40 | }, nil 41 | } 42 | 43 | func (cl *ConnectionLimiter) ProcessRequest(r request.Request) (*http.Response, error) { 44 | cl.mutex.Lock() 45 | defer cl.mutex.Unlock() 46 | 47 | token, amount, err := cl.mapper(r) 48 | if err != nil { 49 | return nil, err 50 
| } 51 | 52 | connections := cl.connections[token] 53 | if connections >= cl.maxConnections { 54 | return netutils.NewTextResponse( 55 | r.GetHttpRequest(), 56 | errors.StatusTooManyRequests, 57 | fmt.Sprintf("Connection limit reached. Max is: %d, yours: %d", cl.maxConnections, connections)), nil 58 | } 59 | 60 | cl.connections[token] += amount 61 | cl.totalConnections += int64(amount) 62 | return nil, nil 63 | } 64 | 65 | func (cl *ConnectionLimiter) ProcessResponse(r request.Request, a request.Attempt) { 66 | cl.mutex.Lock() 67 | defer cl.mutex.Unlock() 68 | 69 | token, amount, err := cl.mapper(r) 70 | if err != nil { 71 | return 72 | } 73 | cl.connections[token] -= amount 74 | cl.totalConnections -= int64(amount) 75 | 76 | // Otherwise it would grow forever 77 | if cl.connections[token] == 0 { 78 | delete(cl.connections, token) 79 | } 80 | } 81 | 82 | func (cl *ConnectionLimiter) GetConnectionCount() int64 { 83 | cl.mutex.Lock() 84 | defer cl.mutex.Unlock() 85 | return cl.totalConnections 86 | } 87 | 88 | func (cl *ConnectionLimiter) GetMaxConnections() int64 { 89 | return cl.maxConnections 90 | } 91 | 92 | func (cl *ConnectionLimiter) SetMaxConnections(max int64) { 93 | cl.maxConnections = max 94 | } 95 | -------------------------------------------------------------------------------- /limit/connlimit/connlimiter_test.go: -------------------------------------------------------------------------------- 1 | package connlimit 2 | 3 | import ( 4 | "net/http" 5 | "testing" 6 | 7 | "github.com/mailgun/vulcan/request" 8 | . 
"gopkg.in/check.v1" 9 | ) 10 | 11 | func TestConn(t *testing.T) { TestingT(t) } 12 | 13 | type ConnLimiterSuite struct { 14 | } 15 | 16 | var _ = Suite(&ConnLimiterSuite{}) 17 | 18 | func (s *ConnLimiterSuite) SetUpSuite(c *C) { 19 | } 20 | 21 | // We've hit the limit and were able to proceed once the request has completed 22 | func (s *ConnLimiterSuite) TestHitLimitAndRelease(c *C) { 23 | l, err := NewClientIpLimiter(1) 24 | c.Assert(err, Equals, nil) 25 | 26 | r := makeRequest("1.2.3.4") 27 | 28 | re, err := l.ProcessRequest(r) 29 | c.Assert(re, IsNil) 30 | c.Assert(err, IsNil) 31 | 32 | // Next request from the same ip hits rate limit, because the active connections > 1 33 | re, err = l.ProcessRequest(r) 34 | c.Assert(re, NotNil) 35 | c.Assert(err, IsNil) 36 | 37 | // Once the first request finished, next one succeeds 38 | l.ProcessResponse(r, nil) 39 | 40 | re, err = l.ProcessRequest(r) 41 | c.Assert(err, IsNil) 42 | c.Assert(re, IsNil) 43 | } 44 | 45 | // Make sure connections are counted independently for different ips 46 | func (s *ConnLimiterSuite) TestDifferentIps(c *C) { 47 | l, err := NewClientIpLimiter(1) 48 | c.Assert(err, Equals, nil) 49 | 50 | r := makeRequest("1.2.3.4") 51 | r2 := makeRequest("1.2.3.5") 52 | 53 | re, err := l.ProcessRequest(r) 54 | c.Assert(re, IsNil) 55 | c.Assert(err, IsNil) 56 | 57 | re, err = l.ProcessRequest(r) 58 | c.Assert(re, NotNil) 59 | c.Assert(err, IsNil) 60 | 61 | re, err = l.ProcessRequest(r2) 62 | c.Assert(re, IsNil) 63 | c.Assert(err, IsNil) 64 | } 65 | 66 | // Make sure connections are counted independently for different ips 67 | func (s *ConnLimiterSuite) TestConnectionCount(c *C) { 68 | l, err := NewClientIpLimiter(1) 69 | c.Assert(err, Equals, nil) 70 | 71 | r := makeRequest("1.2.3.4") 72 | r2 := makeRequest("1.2.3.5") 73 | 74 | re, err := l.ProcessRequest(r) 75 | c.Assert(re, IsNil) 76 | c.Assert(err, IsNil) 77 | c.Assert(l.GetConnectionCount(), Equals, int64(1)) 78 | 79 | re, err = l.ProcessRequest(r) 80 | 
c.Assert(re, NotNil) 81 | c.Assert(err, IsNil) 82 | c.Assert(l.GetConnectionCount(), Equals, int64(1)) 83 | 84 | re, err = l.ProcessRequest(r2) 85 | c.Assert(re, IsNil) 86 | c.Assert(err, IsNil) 87 | c.Assert(l.GetConnectionCount(), Equals, int64(2)) 88 | 89 | l.ProcessResponse(r, nil) 90 | c.Assert(l.GetConnectionCount(), Equals, int64(1)) 91 | 92 | l.ProcessResponse(r2, nil) 93 | c.Assert(l.GetConnectionCount(), Equals, int64(0)) 94 | } 95 | 96 | // We've failed to extract client ip, everything crashes, bam! 97 | func (s *ConnLimiterSuite) TestFailure(c *C) { 98 | l, err := NewClientIpLimiter(1) 99 | c.Assert(err, IsNil) 100 | re, err := l.ProcessRequest(makeRequest("")) 101 | c.Assert(err, NotNil) 102 | c.Assert(re, IsNil) 103 | } 104 | 105 | func (s *ConnLimiterSuite) TestWrongParams(c *C) { 106 | _, err := NewConnectionLimiter(nil, 1) 107 | c.Assert(err, NotNil) 108 | 109 | _, err = NewClientIpLimiter(0) 110 | c.Assert(err, NotNil) 111 | 112 | _, err = NewClientIpLimiter(-1) 113 | c.Assert(err, NotNil) 114 | } 115 | 116 | func makeRequest(ip string) request.Request { 117 | return &request.BaseRequest{ 118 | HttpRequest: &http.Request{ 119 | RemoteAddr: ip, 120 | }, 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /limit/limiter.go: -------------------------------------------------------------------------------- 1 | // Interfaces for request limiting 2 | package limit 3 | 4 | import ( 5 | "fmt" 6 | "github.com/mailgun/vulcan/middleware" 7 | "github.com/mailgun/vulcan/request" 8 | "strings" 9 | ) 10 | 11 | // Limiter is an interface for request limiters (e.g. rate/connection) limiters 12 | type Limiter interface { 13 | // In case if limiter wants to reject request, it should return http response 14 | // will be proxied to the client. 15 | // In case if limiter returns an error, it will be treated as a request error and will 16 | // potentially activate failure recovery and failover algorithms. 
17 | // In case the limiter wants to delay the request, it should return duration > 0 18 | // Otherwise limiter should return (0, nil) to allow request to proceed 19 | middleware.Middleware 20 | } 21 | 22 | // MapperFn takes the request and returns token that corresponds to the request and the amount of tokens this request is going to consume, e.g. 23 | // * Client ip rate limiter - token is a client ip, amount is 1 request 24 | // * Client ip bandwidth limiter - token is a client ip, amount is number of bytes to consume 25 | // In case of error returns non nil error, in this case rate limiter will reject the request. 26 | type MapperFn func(r request.Request) (token string, amount int64, err error) 27 | 28 | // TokenMapperFn maps the request to limiting token 29 | type TokenMapperFn func(r request.Request) (token string, err error) 30 | 31 | // AmountMapperFn maps the request to the amount of tokens to consume 32 | type AmountMapperFn func(r request.Request) (amount int64, err error) 33 | 34 | // MapClientIp is a MapperFn that limits requests per client IP; every request consumes 1 token. 35 | func MapClientIp(req request.Request) (string, int64, error) { 36 | t, err := RequestToClientIp(req) 37 | return t, 1, err 38 | } 39 | // MapRequestHost is a MapperFn that limits requests per Host value; every request consumes 1 token. 40 | func MapRequestHost(req request.Request) (string, int64, error) { 41 | t, err := RequestToHost(req) 42 | return t, 1, err 43 | } 44 | // MakeMapRequestHeader returns a MapperFn that limits requests per value of the given header; every request consumes 1 token. 45 | func MakeMapRequestHeader(header string) MapperFn { 46 | return MakeMapper(MakeRequestToHeader(header), RequestToCount) 47 | } 48 | // VariableToMapper builds a MapperFn from a limiting variable (see MakeTokenMapperFromVariable); every request consumes 1 token. 49 | func VariableToMapper(variable string) (MapperFn, error) { 50 | tokenMapper, err := MakeTokenMapperFromVariable(variable) 51 | if err != nil { 52 | return nil, err 53 | } 54 | return MakeMapper(tokenMapper, RequestToCount), nil 55 | } 56 | 57 | // MakeMapper constructs a MapperFn out of a token mapper and an amount mapper; the amount mapper only runs if the token mapper succeeds. 58 | func MakeMapper(t TokenMapperFn, a AmountMapperFn) MapperFn { 59 | return func(r request.Request) (string, int64, error) { 60 | token, err
:= t(r) 61 | if err != nil { 62 | return "", -1, err 63 | } 64 | amount, err := a(r) 65 | if err != nil { 66 | return "", -1, err 67 | } 68 | return token, amount, nil 69 | } 70 | } 71 | 72 | // RequestToClientIp is a TokenMapperFn that maps the request to the client IP, i.e. RemoteAddr up to its first ':'; an empty result is an error. 73 | func RequestToClientIp(req request.Request) (string, error) { 74 | vals := strings.SplitN(req.GetHttpRequest().RemoteAddr, ":", 2) // NOTE(review): IPv6 addresses such as "[::1]:8080" yield "[", so all IPv6 clients share one token; consider net.SplitHostPort (keeping support for port-less addresses). 75 | if len(vals[0]) == 0 { 76 | return "", fmt.Errorf("Failed to parse client IP") 77 | } 78 | return vals[0], nil 79 | } 80 | 81 | // RequestToHost is a TokenMapperFn that maps the request to its Host value; it never fails. 82 | func RequestToHost(req request.Request) (string, error) { 83 | return req.GetHttpRequest().Host, nil 84 | } 85 | 86 | // RequestToCount is an AmountMapperFn that charges every request exactly 1 token. 87 | func RequestToCount(req request.Request) (int64, error) { 88 | return 1, nil 89 | } 90 | 91 | // RequestToBytes is an AmountMapperFn that maps the request to its body size in bytes. 92 | func RequestToBytes(req request.Request) (int64, error) { 93 | return req.GetBody().TotalSize() 94 | } 95 | 96 | // MakeRequestToHeader creates a TokenMapperFn that maps the incoming request to the value of the given header ("" when absent).
97 | func MakeRequestToHeader(header string) TokenMapperFn { 98 | return func(req request.Request) (string, error) { 99 | return req.GetHttpRequest().Header.Get(header), nil 100 | } 101 | } 102 | 103 | // MakeTokenMapperFromVariable converts a limiting variable ("client.ip", "request.host" or "request.header.<Name>") to a TokenMapperFn; any other value, or an empty header name, is an error. 104 | func MakeTokenMapperFromVariable(variable string) (TokenMapperFn, error) { 105 | if variable == "client.ip" { 106 | return RequestToClientIp, nil 107 | } 108 | if variable == "request.host" { 109 | return RequestToHost, nil 110 | } 111 | if strings.HasPrefix(variable, "request.header.") { 112 | header := strings.TrimPrefix(variable, "request.header.") 113 | if len(header) == 0 { 114 | return nil, fmt.Errorf("Wrong header: %s", header) // NOTE(review): header is always empty here, so the message carries no detail; consider reporting variable instead. 115 | } 116 | return MakeRequestToHeader(header), nil 117 | } 118 | return nil, fmt.Errorf("Unsupported limiting variable: '%s'", variable) 119 | } 120 | -------------------------------------------------------------------------------- /limit/limiter_test.go: -------------------------------------------------------------------------------- 1 | package limit 2 | 3 | import ( 4 | .
"gopkg.in/check.v1" 5 | "testing" 6 | ) 7 | 8 | func TestLimit(t *testing.T) { TestingT(t) } 9 | 10 | type LimitSuite struct { 11 | } 12 | 13 | var _ = Suite(&LimitSuite{}) 14 | // TestVariableToMapper checks that every supported variable yields a mapper and an unknown one yields an error. 15 | func (s *LimitSuite) TestVariableToMapper(c *C) { 16 | m, err := VariableToMapper("client.ip") 17 | c.Assert(err, IsNil) 18 | c.Assert(m, NotNil) 19 | 20 | m, err = VariableToMapper("request.host") 21 | c.Assert(err, IsNil) 22 | c.Assert(m, NotNil) 23 | 24 | m, err = VariableToMapper("request.header.X-Header-Name") 25 | c.Assert(err, IsNil) 26 | c.Assert(m, NotNil) 27 | 28 | m, err = VariableToMapper("rsom") 29 | c.Assert(err, NotNil) 30 | c.Assert(m, IsNil) 31 | } 32 | -------------------------------------------------------------------------------- /limit/tokenbucket/bucket.go: -------------------------------------------------------------------------------- 1 | package tokenbucket 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/mailgun/timetools" 8 | ) 9 | // UndefinedDelay is returned as the delay when none can be computed, e.g. together with an error from consume. 10 | const UndefinedDelay = -1 11 | 12 | // rate defines token bucket parameters: average tokens are refilled per period, and at most burst tokens are stored. 13 | type rate struct { 14 | period time.Duration 15 | average int64 16 | burst int64 17 | } 18 | // String renders the rate as "rate(average/period, burst=N)". 19 | func (r *rate) String() string { 20 | return fmt.Sprintf("rate(%v/%v, burst=%v)", r.average, r.period, r.burst) 21 | } 22 | 23 | // tokenBucket implements the token bucket algorithm (http://en.wikipedia.org/wiki/Token_bucket). 24 | type tokenBucket struct { 25 | // The time period controlled by the bucket in nanoseconds. 26 | period time.Duration 27 | // The time it takes to add one more token to the total 28 | // number of available tokens. It effectively caches period/average, which 29 | // is recomputed whenever the rate changes (see update). 30 | timePerToken time.Duration 31 | // The maximum number of tokens that can accumulate in the bucket. 32 | burst int64 33 | // The number of tokens available for consumption at the moment. It can 34 | // never be larger than burst.
35 | availableTokens int64 36 | // Interface that gives current time (so tests can override) 37 | clock timetools.TimeProvider 38 | // Tells when availableTokens was last refilled; it only advances when the elapsed time adds at least one token. 39 | lastRefresh time.Time 40 | // The number of tokens taken by the most recent consume call (0 if it took none, or after rollback); rollback returns them. 41 | lastConsumed int64 42 | } 43 | 44 | // newTokenBucket creates a `tokenBucket` instance for the specified `rate`; the bucket starts full (burst tokens). 45 | func newTokenBucket(rate *rate, clock timetools.TimeProvider) *tokenBucket { 46 | return &tokenBucket{ 47 | period: rate.period, 48 | timePerToken: time.Duration(int64(rate.period) / rate.average), // NOTE(review): integer division yields 0 when average exceeds the period in nanoseconds, which would make updateAvailableTokens divide by zero. 49 | burst: rate.burst, 50 | clock: clock, 51 | lastRefresh: clock.UtcNow(), 52 | availableTokens: rate.burst, 53 | } 54 | } 55 | 56 | // consume makes an attempt to consume the specified number of tokens from the 57 | // bucket. If there are enough tokens available then `0, nil` is returned; if 58 | // tokens to consume is larger than the burst size, then an error is returned 59 | // and the delay is not defined; otherwise it returns a non-zero delay that tells 60 | // how much time the caller needs to wait until the desired number of tokens 61 | // will become available for consumption. 62 | func (tb *tokenBucket) consume(tokens int64) (time.Duration, error) { 63 | tb.updateAvailableTokens() 64 | tb.lastConsumed = 0 65 | if tokens > tb.burst { 66 | return UndefinedDelay, fmt.Errorf("Requested tokens larger than max tokens") 67 | } 68 | if tb.availableTokens < tokens { 69 | return tb.timeTillAvailable(tokens), nil 70 | } 71 | tb.availableTokens -= tokens 72 | tb.lastConsumed = tokens 73 | return 0, nil 74 | } 75 | 76 | // rollback reverts effect of the most recent consumption. If the most recent 77 | // `consume` resulted in an error or a burst overflow, and therefore did not 78 | // modify the number of available tokens, then `rollback` won't do that either. 79 | // It is safe to call this method multiple times, for the second and all 80 | // following calls have no effect.
81 | func (tb *tokenBucket) rollback() { 82 | tb.availableTokens += tb.lastConsumed 83 | tb.lastConsumed = 0 84 | } 85 | 86 | // update recomputes timePerToken and burst from the provided `rate` and clamps 87 | // availableTokens to the new burst; it fails if the rate's period differs from the bucket's. 88 | func (tb *tokenBucket) update(rate *rate) error { 89 | if rate.period != tb.period { 90 | return fmt.Errorf("Period mismatch: %v != %v", tb.period, rate.period) 91 | } 92 | tb.timePerToken = time.Duration(int64(tb.period) / rate.average) 93 | tb.burst = rate.burst 94 | if tb.availableTokens > rate.burst { 95 | tb.availableTokens = rate.burst 96 | } 97 | return nil 98 | } 99 | 100 | // timeTillAvailable returns how long we need to wait until the specified number of tokens becomes 101 | // available for consumption. It ignores time already elapsed since lastRefresh, so the result is an upper bound. 102 | func (tb *tokenBucket) timeTillAvailable(tokens int64) time.Duration { 103 | missingTokens := tokens - tb.availableTokens 104 | return time.Duration(missingTokens) * tb.timePerToken 105 | } 106 | 107 | // updateAvailableTokens updates the number of tokens available for consumption. 108 | // It is calculated based on the refill rate, the time passed since last refresh, 109 | // and is limited by the bucket's burst size.
110 | func (tb *tokenBucket) updateAvailableTokens() { 111 | now := tb.clock.UtcNow() 112 | timePassed := now.Sub(tb.lastRefresh) 113 | // NOTE(review): the remainder of timePassed/timePerToken is dropped when lastRefresh jumps to now, so refill can lag slightly behind the configured rate. 114 | tokens := tb.availableTokens + int64(timePassed/tb.timePerToken) 115 | // If we haven't added any tokens that means that not enough time has passed, 116 | // in this case do not adjust last refill checkpoint, otherwise it will be 117 | // always moving in time in case of frequent requests that exceed the rate 118 | if tokens != tb.availableTokens { 119 | tb.lastRefresh = now 120 | tb.availableTokens = tokens 121 | } 122 | if tb.availableTokens > tb.burst { 123 | tb.availableTokens = tb.burst 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /limit/tokenbucket/bucketset.go: -------------------------------------------------------------------------------- 1 | package tokenbucket 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "time" 7 | 8 | "github.com/mailgun/timetools" 9 | "sort") 10 | 11 | // tokenBucketSet represents a set of token buckets covering different time periods, keyed by period. 12 | type tokenBucketSet struct { 13 | buckets map[time.Duration]*tokenBucket 14 | maxPeriod time.Duration // longest period among the buckets 15 | clock timetools.TimeProvider 16 | } 17 | 18 | // newTokenBucketSet creates a `tokenBucketSet` from the specified `rates`, one bucket per period. 19 | func newTokenBucketSet(rates *RateSet, clock timetools.TimeProvider) *tokenBucketSet { 20 | tbs := new(tokenBucketSet) 21 | tbs.clock = clock 22 | // In the majority of cases we will have only one bucket. 23 | tbs.buckets = make(map[time.Duration]*tokenBucket, len(rates.m)) 24 | for _, rate := range rates.m { 25 | newBucket := newTokenBucket(rate, clock) 26 | tbs.buckets[rate.period] = newBucket 27 | tbs.maxPeriod = maxDuration(tbs.maxPeriod, rate.period) 28 | } 29 | return tbs 30 | } 31 | 32 | // update brings the buckets in the set in accordance with the provided `rates`; buckets whose period is kept retain their available tokens (clamped to the new burst).
33 | func (tbs *tokenBucketSet) update(rates *RateSet) { 34 | // Update existing buckets and delete those that have no corresponding spec. 35 | for _, bucket := range tbs.buckets { 36 | if rate, ok := rates.m[bucket.period]; ok { 37 | bucket.update(rate) // cannot fail: the rate was looked up by this bucket's period 38 | } else { 39 | delete(tbs.buckets, bucket.period) // deleting during range is allowed by the Go spec 40 | } 41 | } 42 | // Add missing buckets. 43 | for _, rate := range rates.m { 44 | if _, ok := tbs.buckets[rate.period]; !ok { 45 | newBucket := newTokenBucket(rate, tbs.clock) 46 | tbs.buckets[rate.period] = newBucket 47 | } 48 | } 49 | // Identify the maximum period in the set 50 | tbs.maxPeriod = 0 51 | for _, bucket := range tbs.buckets { 52 | tbs.maxPeriod = maxDuration(tbs.maxPeriod, bucket.period) 53 | } 54 | } 55 | // consume takes tokens from every bucket: it returns 0, nil if all had enough; otherwise it rolls every bucket back and returns the longest delay reported or an error from one of the buckets. 56 | func (tbs *tokenBucketSet) consume(tokens int64) (time.Duration, error) { 57 | var maxDelay time.Duration = UndefinedDelay 58 | var firstErr error = nil 59 | for _, tokenBucket := range tbs.buckets { 60 | // We keep calling `consume` even after an error is returned for one of 61 | // buckets because that allows us to simplify the rollback procedure, 62 | // that is to just call `rollback` for all buckets. 63 | delay, err := tokenBucket.consume(tokens) 64 | if firstErr == nil { 65 | if err != nil { 66 | firstErr = err 67 | } else { 68 | maxDelay = maxDuration(maxDelay, delay) 69 | } 70 | } 71 | } 72 | // If we could not make ALL buckets consume tokens for whatever reason, 73 | // then rollback consumption for all of them. 74 | if firstErr != nil || maxDelay > 0 { 75 | for _, tokenBucket := range tbs.buckets { 76 | tokenBucket.rollback() 77 | } 78 | } 79 | return maxDelay, firstErr 80 | } 81 | 82 | // debugState returns a string that reflects the current state of all buckets in 83 | // this set, ordered by period. It is intended to be used for debugging and testing only.
84 | func (tbs *tokenBucketSet) debugState() string { 85 | periods := sort.IntSlice(make([]int, 0, len(tbs.buckets))) 86 | for period := range tbs.buckets { 87 | periods = append(periods, int(period)) 88 | } 89 | sort.Sort(periods) 90 | bucketRepr := make([]string, 0, len(tbs.buckets)) 91 | for _, period := range periods { 92 | bucket := tbs.buckets[time.Duration(period)] 93 | bucketRepr = append(bucketRepr, fmt.Sprintf("{%v: %v}", bucket.period, bucket.availableTokens)) 94 | } 95 | return strings.Join(bucketRepr, ", ") 96 | } 97 | // maxDuration returns the larger of x and y. 98 | func maxDuration(x time.Duration, y time.Duration) time.Duration { 99 | if x > y { 100 | return x 101 | } 102 | return y 103 | } 104 | -------------------------------------------------------------------------------- /limit/tokenbucket/bucketset_test.go: -------------------------------------------------------------------------------- 1 | package tokenbucket 2 | 3 | import ( 4 | // "fmt" 5 | "time" 6 | 7 | "github.com/mailgun/timetools" 8 | . "gopkg.in/check.v1" 9 | ) 10 | 11 | type BucketSetSuite struct { 12 | clock *timetools.FreezedTime 13 | } 14 | 15 | var _ = Suite(&BucketSetSuite{}) 16 | // SetUpSuite freezes the clock at a fixed instant; the clock is shared by every test in the suite (SetUpSuite runs once). 17 | func (s *BucketSetSuite) SetUpSuite(c *C) { 18 | s.clock = &timetools.FreezedTime{ 19 | CurrentTime: time.Date(2012, 3, 4, 5, 6, 7, 0, time.UTC), 20 | } 21 | } 22 | 23 | // The `maxPeriod` of a set corresponds to the longest bucket time period. 24 | func (s *BucketSetSuite) TestLongestPeriod(c *C) { 25 | // Given 26 | rates := NewRateSet() 27 | rates.Add(1*time.Second, 10, 20) 28 | rates.Add(7*time.Second, 10, 20) 29 | rates.Add(5*time.Second, 11, 21) 30 | // When 31 | tbs := newTokenBucketSet(rates, s.clock) 32 | // Then 33 | c.Assert(tbs.maxPeriod, Equals, 7*time.Second) 34 | } 35 | 36 | // Successful token consumption updates state of all buckets in the set.
37 | func (s *BucketSetSuite) TestConsume(c *C) { 38 | // Given 39 | rates := NewRateSet() 40 | rates.Add(1*time.Second, 10, 20) 41 | rates.Add(10*time.Second, 20, 50) 42 | tbs := newTokenBucketSet(rates, s.clock) 43 | // When 44 | delay, err := tbs.consume(15) 45 | // Then 46 | c.Assert(delay, Equals, time.Duration(0)) 47 | c.Assert(err, IsNil) 48 | c.Assert(tbs.debugState(), Equals, "{1s: 5}, {10s: 35}") 49 | } 50 | 51 | // As time goes by all set buckets are refilled with appropriate rates. 52 | func (s *BucketSetSuite) TestConsumeRefill(c *C) { 53 | // Given 54 | rates := NewRateSet() 55 | rates.Add(10*time.Second, 10, 20) 56 | rates.Add(100*time.Second, 20, 50) 57 | tbs := newTokenBucketSet(rates, s.clock) 58 | tbs.consume(15) 59 | c.Assert(tbs.debugState(), Equals, "{10s: 5}, {1m40s: 35}") 60 | // When 61 | s.clock.Sleep(10 * time.Second) 62 | delay, err := tbs.consume(0) // Consumes nothing but forces an internal state update. 63 | // Then 64 | c.Assert(delay, Equals, time.Duration(0)) 65 | c.Assert(err, IsNil) 66 | c.Assert(tbs.debugState(), Equals, "{10s: 15}, {1m40s: 37}") // 10s/10 refills 1 token per second, 100s/20 one per 5 seconds 67 | } 68 | 69 | // If the first bucket in the set does not have enough tokens to allow the desired 70 | // consumption then an appropriate delay is returned. 71 | func (s *BucketSetSuite) TestConsumeLimitedBy1st(c *C) { 72 | // Given 73 | rates := NewRateSet() 74 | rates.Add(10*time.Second, 10, 10) 75 | rates.Add(100*time.Second, 20, 20) 76 | tbs := newTokenBucketSet(rates, s.clock) 77 | tbs.consume(5) 78 | c.Assert(tbs.debugState(), Equals, "{10s: 5}, {1m40s: 15}") 79 | // When 80 | delay, err := tbs.consume(10) 81 | // Then 82 | c.Assert(delay, Equals, 5*time.Second) // 5 missing tokens at 1s each in the 10s bucket 83 | c.Assert(err, IsNil) 84 | c.Assert(tbs.debugState(), Equals, "{10s: 5}, {1m40s: 15}") 85 | } 86 | 87 | // If the second bucket in the set does not have enough tokens to allow the desired 88 | // consumption then an appropriate delay is returned.
89 | func (s *BucketSetSuite) TestConsumeLimitedBy2st(c *C) { 90 | // Given 91 | rates := NewRateSet() 92 | rates.Add(10*time.Second, 10, 10) 93 | rates.Add(100*time.Second, 20, 20) 94 | tbs := newTokenBucketSet(rates, s.clock) 95 | tbs.consume(10) 96 | s.clock.Sleep(10 * time.Second) 97 | tbs.consume(10) 98 | s.clock.Sleep(5 * time.Second) 99 | tbs.consume(0) 100 | c.Assert(tbs.debugState(), Equals, "{10s: 5}, {1m40s: 3}") 101 | // When 102 | delay, err := tbs.consume(10) 103 | // Then 104 | c.Assert(delay, Equals, 7*(5*time.Second)) // 7 missing tokens at 5s each in the 100s bucket 105 | c.Assert(err, IsNil) 106 | c.Assert(tbs.debugState(), Equals, "{10s: 5}, {1m40s: 3}") 107 | } 108 | 109 | // An attempt to consume more tokens than the smallest bucket burst results 110 | // in an error and leaves every bucket unchanged. 111 | func (s *BucketSetSuite) TestConsumeMoreThenBurst(c *C) { 112 | // Given 113 | rates := NewRateSet() 114 | rates.Add(1*time.Second, 10, 20) 115 | rates.Add(10*time.Second, 50, 100) 116 | tbs := newTokenBucketSet(rates, s.clock) 117 | tbs.consume(5) 118 | c.Assert(tbs.debugState(), Equals, "{1s: 15}, {10s: 95}") 119 | // When 120 | _, err := tbs.consume(21) 121 | // Then 122 | c.Assert(tbs.debugState(), Equals, "{1s: 15}, {10s: 95}") 123 | c.Assert(err, NotNil) 124 | } 125 | 126 | // Update operation can add buckets; kept buckets are clamped to their new burst.
127 | func (s *BucketSetSuite) TestUpdateMore(c *C) { 128 | // Given 129 | rates := NewRateSet() 130 | rates.Add(1*time.Second, 10, 20) 131 | rates.Add(10*time.Second, 20, 50) 132 | rates.Add(20*time.Second, 45, 90) 133 | tbs := newTokenBucketSet(rates, s.clock) 134 | tbs.consume(5) 135 | c.Assert(tbs.debugState(), Equals, "{1s: 15}, {10s: 45}, {20s: 85}") 136 | rates = NewRateSet() 137 | rates.Add(10*time.Second, 30, 40) 138 | rates.Add(11*time.Second, 30, 40) 139 | rates.Add(12*time.Second, 30, 40) 140 | rates.Add(13*time.Second, 30, 40) 141 | // When 142 | tbs.update(rates) 143 | // Then 144 | c.Assert(tbs.debugState(), Equals, "{10s: 40}, {11s: 40}, {12s: 40}, {13s: 40}") 145 | c.Assert(tbs.maxPeriod, Equals, 13*time.Second) 146 | } 147 | 148 | // Update operation can remove buckets. 149 | func (s *BucketSetSuite) TestUpdateLess(c *C) { 150 | // Given 151 | rates := NewRateSet() 152 | rates.Add(1*time.Second, 10, 20) 153 | rates.Add(10*time.Second, 20, 50) 154 | rates.Add(20*time.Second, 45, 90) 155 | rates.Add(30*time.Second, 50, 100) 156 | tbs := newTokenBucketSet(rates, s.clock) 157 | tbs.consume(5) 158 | c.Assert(tbs.debugState(), Equals, "{1s: 15}, {10s: 45}, {20s: 85}, {30s: 95}") 159 | rates = NewRateSet() 160 | rates.Add(10*time.Second, 25, 20) 161 | rates.Add(20*time.Second, 30, 21) 162 | // When 163 | tbs.update(rates) 164 | // Then 165 | c.Assert(tbs.debugState(), Equals, "{10s: 20}, {20s: 21}") 166 | c.Assert(tbs.maxPeriod, Equals, 20*time.Second) 167 | } 168 | 169 | // Update operation can replace all buckets when none of the periods match.
170 | func (s *BucketSetSuite) TestUpdateAllDifferent(c *C) { 171 | // Given 172 | rates := NewRateSet() 173 | rates.Add(10*time.Second, 20, 50) 174 | rates.Add(30*time.Second, 50, 100) 175 | tbs := newTokenBucketSet(rates, s.clock) 176 | tbs.consume(5) 177 | c.Assert(tbs.debugState(), Equals, "{10s: 45}, {30s: 95}") 178 | rates = NewRateSet() 179 | rates.Add(1*time.Second, 10, 40) 180 | rates.Add(60*time.Second, 100, 150) 181 | // When 182 | tbs.update(rates) 183 | // Then 184 | c.Assert(tbs.debugState(), Equals, "{1s: 40}, {1m0s: 150}") // new buckets start full 185 | c.Assert(tbs.maxPeriod, Equals, 60*time.Second) 186 | } 187 | -------------------------------------------------------------------------------- /limit/tokenbucket/tokenlimiter.go: -------------------------------------------------------------------------------- 1 | // Package tokenbucket implements a token-bucket based request rate limiter. 2 | package tokenbucket 3 | 4 | import ( 5 | "fmt" 6 | "net/http" 7 | "sync" 8 | "time" 9 | 10 | "github.com/mailgun/log" 11 | "github.com/mailgun/timetools" 12 | "github.com/mailgun/ttlmap" 13 | "github.com/mailgun/vulcan/errors" 14 | "github.com/mailgun/vulcan/limit" 15 | "github.com/mailgun/vulcan/netutils" 16 | "github.com/mailgun/vulcan/request" 17 | ) 18 | // DefaultCapacity is the capacity NewLimiter falls back to when it is given a non-positive one. 19 | const DefaultCapacity = 65536 20 | 21 | // RateSet maintains a set of rates. It can contain only one rate per period at a time. 22 | type RateSet struct { 23 | m map[time.Duration]*rate 24 | } 25 | 26 | // NewRateSet creates an empty `RateSet` instance. 27 | func NewRateSet() *RateSet { 28 | rs := new(RateSet) 29 | rs.m = make(map[time.Duration]*rate) 30 | return rs 31 | } 32 | 33 | // Add adds a rate to the set. If there is a rate with the same period in the 34 | // set then the new rate overrides the old one.
35 | func (rs *RateSet) Add(period time.Duration, average int64, burst int64) error { 36 | if period <= 0 { 37 | return fmt.Errorf("Invalid period: %v", period) 38 | } 39 | if average <= 0 { 40 | return fmt.Errorf("Invalid average: %v", average) 41 | } 42 | if burst <= 0 { 43 | return fmt.Errorf("Invalid burst: %v", burst) 44 | } 45 | rs.m[period] = &rate{period, average, burst} 46 | return nil 47 | } 48 | // String renders the rates keyed by period, e.g. "map[1s:rate(1/1s, burst=1)]". 49 | func (rs *RateSet) String() string { 50 | return fmt.Sprint(rs.m) 51 | } 52 | 53 | // ConfigMapperFn is a mapper function that is used by the `TokenLimiter` 54 | // middleware to retrieve `RateSet` from HTTP requests. 55 | type ConfigMapperFn func(r request.Request) (*RateSet, error) 56 | 57 | // TokenLimiter implements rate limiting middleware. 58 | type TokenLimiter struct { 59 | defaultRates *RateSet 60 | mapper limit.MapperFn 61 | configMapper ConfigMapperFn 62 | clock timetools.TimeProvider 63 | mutex sync.Mutex // serializes ProcessRequest 64 | bucketSets *ttlmap.TtlMap // token -> *tokenBucketSet, stored with a TTL 65 | } 66 | 67 | // NewLimiter constructs a `TokenLimiter` middleware instance. defaultRates and mapper are required; a capacity <= 0 selects DefaultCapacity and a nil clock selects the real-time clock. 68 | func NewLimiter(defaultRates *RateSet, capacity int, mapper limit.MapperFn, configMapper ConfigMapperFn, clock timetools.TimeProvider) (*TokenLimiter, error) { 69 | if defaultRates == nil || len(defaultRates.m) == 0 { 70 | return nil, fmt.Errorf("Provide default rates") 71 | } 72 | if mapper == nil { 73 | return nil, fmt.Errorf("Provide mapper function") 74 | } 75 | 76 | // Set default values for optional fields. 77 | if capacity <= 0 { 78 | capacity = DefaultCapacity 79 | } 80 | if clock == nil { 81 | clock = &timetools.RealTime{} 82 | } 83 | 84 | bucketSets, err := ttlmap.NewMapWithProvider(capacity, clock) // honor the caller-supplied (or defaulted) capacity 85 | if err != nil { 86 | return nil, err 87 | } 88 | 89 | return &TokenLimiter{ 90 | defaultRates: defaultRates, 91 | mapper: mapper, 92 | configMapper: configMapper, 93 | clock: clock, 94 | bucketSets: bucketSets, 95 | }, nil 96 | } 97 | 98 | // DefaultRates returns the default rate set of the limiter.
The only reason to 99 | // provide this method is to facilitate testing. It returns a copy, so callers cannot modify the limiter's rates. 100 | func (tl *TokenLimiter) DefaultRates() *RateSet { 101 | defaultRates := NewRateSet() 102 | for _, r := range tl.defaultRates.m { 103 | defaultRates.Add(r.period, r.average, r.burst) 104 | } 105 | return defaultRates 106 | } 107 | // ProcessRequest charges the request's amount against the bucket set of its token; it returns a StatusTooManyRequests response if any bucket would have to delay it, and returns mapper or bucket errors as-is. 108 | func (tl *TokenLimiter) ProcessRequest(r request.Request) (*http.Response, error) { 109 | tl.mutex.Lock() 110 | defer tl.mutex.Unlock() 111 | 112 | token, amount, err := tl.mapper(r) 113 | if err != nil { 114 | return nil, err 115 | } 116 | 117 | effectiveRates := tl.effectiveRates(r) 118 | bucketSetI, exists := tl.bucketSets.Get(token) 119 | var bucketSet *tokenBucketSet 120 | 121 | if exists { 122 | bucketSet = bucketSetI.(*tokenBucketSet) 123 | bucketSet.update(effectiveRates) 124 | } else { 125 | bucketSet = newTokenBucketSet(effectiveRates, tl.clock) 126 | // We set ttl as 10 times rate period. E.g. if rate is 100 requests/second per client ip 127 | // the counters for this ip will expire after 10 seconds of inactivity 128 | tl.bucketSets.Set(token, bucketSet, int(bucketSet.maxPeriod/time.Second)*10+1) // NOTE(review): the TTL is fixed at creation and not extended when update grows maxPeriod; whether Get refreshes it depends on ttlmap - confirm. 129 | } 130 | 131 | delay, err := bucketSet.consume(amount) 132 | if err != nil { 133 | return nil, err 134 | } 135 | if delay > 0 { 136 | return netutils.NewTextResponse(r.GetHttpRequest(), errors.StatusTooManyRequests, "Too many requests"), nil 137 | } 138 | return nil, nil 139 | } 140 | // ProcessResponse is a no-op: the limiter does not inspect responses. 141 | func (tl *TokenLimiter) ProcessResponse(r request.Request, a request.Attempt) { 142 | } 143 | 144 | // effectiveRates retrieves rates to be applied to the request. 145 | func (tl *TokenLimiter) effectiveRates(r request.Request) *RateSet { 146 | // If configuration mapper is not specified for this instance, then return 147 | // the default bucket specs.
148 | if tl.configMapper == nil { 149 | return tl.defaultRates 150 | } 151 | 152 | rates, err := tl.configMapper(r) 153 | if err != nil { 154 | log.Errorf("Failed to retrieve rates: %v", err) 155 | return tl.defaultRates 156 | } 157 | 158 | // If the returned rate set is nil or empty then use the default one. 159 | if rates == nil || len(rates.m) == 0 { 160 | return tl.defaultRates 161 | } 162 | 163 | return rates 164 | } 165 | -------------------------------------------------------------------------------- /limit/tokenbucket/tokenlimiter_test.go: -------------------------------------------------------------------------------- 1 | package tokenbucket 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "time" 7 | 8 | "github.com/mailgun/timetools" 9 | "github.com/mailgun/vulcan/limit" 10 | "github.com/mailgun/vulcan/request" 11 | . "gopkg.in/check.v1" 12 | ) 13 | 14 | type LimiterSuite struct { 15 | clock *timetools.FreezedTime 16 | } 17 | 18 | var _ = Suite(&LimiterSuite{}) 19 | 20 | func (s *LimiterSuite) SetUpSuite(c *C) { 21 | s.clock = &timetools.FreezedTime{ 22 | CurrentTime: time.Date(2012, 3, 4, 5, 6, 7, 0, time.UTC), 23 | } 24 | } 25 | // TestRateSetAdd checks that Add rejects non-positive parameters and accepts valid ones. 26 | func (s *LimiterSuite) TestRateSetAdd(c *C) { 27 | rs := NewRateSet() 28 | 29 | // Invalid period 30 | err := rs.Add(0, 1, 1) 31 | c.Assert(err, NotNil) 32 | 33 | // Invalid Average 34 | err = rs.Add(time.Second, 0, 1) 35 | c.Assert(err, NotNil) 36 | 37 | // Invalid Burst 38 | err = rs.Add(time.Second, 1, 0) 39 | c.Assert(err, NotNil) 40 | 41 | err = rs.Add(time.Second, 1, 1) 42 | c.Assert(err, IsNil) 43 | c.Assert(fmt.Sprint(rs), Equals, "map[1s:rate(1/1s, burst=1)]") 44 | } 45 | 46 | // A second request within the period hits the limit; once the period has passed, requests are admitted again. 47 | func (s *LimiterSuite) TestHitLimit(c *C) { 48 | rates := NewRateSet() 49 | rates.Add(time.Second, 1, 1) 50 | tl, err := NewLimiter(rates, 0, limit.MapClientIp, nil, s.clock) 51 | c.Assert(err, IsNil) 52 | 53 | re, err := tl.ProcessRequest(makeRequest("1.2.3.4")) 54 | c.Assert(re, IsNil) 55 |
c.Assert(err, IsNil) 56 | 57 | // Next request from the same ip hits rate limit 58 | re, err = tl.ProcessRequest(makeRequest("1.2.3.4")) 59 | c.Assert(re, NotNil) 60 | c.Assert(err, IsNil) 61 | 62 | // A second later, the request from this ip will succeed 63 | s.clock.Sleep(time.Second) 64 | re, err = tl.ProcessRequest(makeRequest("1.2.3.4")) 65 | c.Assert(re, IsNil) 66 | c.Assert(err, IsNil) 67 | } 68 | 69 | // A request whose client ip cannot be extracted is rejected with an error. 70 | func (s *LimiterSuite) TestFailure(c *C) { 71 | rates := NewRateSet() 72 | rates.Add(time.Second, 1, 1) 73 | tl, err := NewLimiter(rates, 0, limit.MapClientIp, nil, s.clock) 74 | c.Assert(err, IsNil) 75 | 76 | _, err = tl.ProcessRequest(makeRequest("")) 77 | c.Assert(err, NotNil) 78 | } 79 | // TestInvalidParams checks that NewLimiter rejects missing rates or a missing mapper. 80 | func (s *LimiterSuite) TestInvalidParams(c *C) { 81 | // Rates are missing 82 | _, err := NewLimiter(nil, 0, limit.MapClientIp, nil, s.clock) 83 | c.Assert(err, NotNil) 84 | 85 | // Rates are empty 86 | _, err = NewLimiter(NewRateSet(), 0, limit.MapClientIp, nil, s.clock) 87 | c.Assert(err, NotNil) 88 | 89 | // Mapper is not provided 90 | rates := NewRateSet() 91 | rates.Add(time.Second, 1, 1) 92 | _, err = NewLimiter(rates, 0, nil, nil, s.clock) 93 | c.Assert(err, NotNil) 94 | 95 | // All required parameters are provided 96 | tl, err := NewLimiter(rates, 0, limit.MapClientIp, nil, s.clock) 97 | c.Assert(tl, NotNil) 98 | c.Assert(err, IsNil) 99 | } 100 | 101 | // Make sure rates from different ips are controlled separately 102 | func (s *LimiterSuite) TestIsolation(c *C) { 103 | rates := NewRateSet() 104 | rates.Add(time.Second, 1, 1) 105 | tl, err := NewLimiter(rates, 0, limit.MapClientIp, nil, s.clock) 106 | c.Assert(err, IsNil) 107 | re, err := tl.ProcessRequest(makeRequest("1.2.3.4")) 108 | c.Assert(err, IsNil) 109 | c.Assert(re, IsNil) 110 | 111 | // Next request from the same ip hits rate limit 112 | re, err = tl.ProcessRequest(makeRequest("1.2.3.4")) 113 | c.Assert(re, NotNil) 114 | c.Assert(err, IsNil) 115 | 116 | // The request from other ip can proceed
117 | re, err = tl.ProcessRequest(makeRequest("1.2.3.5")) 118 | c.Assert(err, IsNil) 119 | c.Assert(re, IsNil) 120 | } 121 | 122 | // Make sure that expiration works (Expiration is triggered after significant amount of time passes) 123 | func (s *LimiterSuite) TestExpiration(c *C) { 124 | rates := NewRateSet() 125 | rates.Add(time.Second, 1, 1) 126 | tl, err := NewLimiter(rates, 0, limit.MapClientIp, nil, s.clock) 127 | c.Assert(err, IsNil) 128 | re, err := tl.ProcessRequest(makeRequest("1.2.3.4")) 129 | c.Assert(re, IsNil) 130 | c.Assert(err, IsNil) 131 | 132 | // Next request from the same ip hits rate limit 133 | re, err = tl.ProcessRequest(makeRequest("1.2.3.4")) 134 | c.Assert(re, NotNil) 135 | c.Assert(err, IsNil) 136 | 137 | // 24 hours later, the request from this ip will succeed (the bucket would also have refilled, so this alone does not prove the entry expired) 138 | s.clock.Sleep(24 * time.Hour) 139 | re, err = tl.ProcessRequest(makeRequest("1.2.3.4")) 140 | c.Assert(err, IsNil) 141 | c.Assert(re, IsNil) 142 | } 143 | 144 | // If configMapper returns error, then the default rate is applied. 145 | func (s *LimiterSuite) TestBadConfigMapper(c *C) { 146 | // Given 147 | configMapper := func(r request.Request) (*RateSet, error) { 148 | return nil, fmt.Errorf("Boom!") 149 | } 150 | rates := NewRateSet() 151 | rates.Add(time.Second, 1, 1) 152 | tl, _ := NewLimiter(rates, 0, limit.MapClientIp, configMapper, s.clock) 153 | req := makeRequest("1.2.3.4") 154 | // When/Then: The default rate is applied, which is 1 req/second 155 | response, err := tl.ProcessRequest(req) // Processed 156 | c.Assert(response, IsNil) 157 | c.Assert(err, IsNil) 158 | response, err = tl.ProcessRequest(req) // Rejected 159 | c.Assert(response, NotNil) 160 | c.Assert(err, IsNil) 161 | 162 | s.clock.Sleep(time.Second) 163 | response, err = tl.ProcessRequest(req) // Processed 164 | c.Assert(response, IsNil) 165 | c.Assert(err, IsNil) 166 | } 167 | 168 | // If configMapper returns empty rates, then the default rate is applied.
169 | func (s *LimiterSuite) TestEmptyConfig(c *C) { 170 | // Given 171 | configMapper := func(r request.Request) (*RateSet, error) { 172 | return NewRateSet(), nil 173 | } 174 | rates := NewRateSet() 175 | rates.Add(time.Second, 1, 1) 176 | tl, _ := NewLimiter(rates, 0, limit.MapClientIp, configMapper, s.clock) 177 | req := makeRequest("1.2.3.4") 178 | // When/Then: The default rate is applied, which is 1 req/second 179 | response, err := tl.ProcessRequest(req) // Processed 180 | c.Assert(response, IsNil) 181 | c.Assert(err, IsNil) 182 | response, err = tl.ProcessRequest(req) // Rejected 183 | c.Assert(response, NotNil) 184 | c.Assert(err, IsNil) 185 | 186 | s.clock.Sleep(time.Second) 187 | response, err = tl.ProcessRequest(req) // Processed 188 | c.Assert(response, IsNil) 189 | c.Assert(err, IsNil) 190 | } 191 | 192 | // If rate limiting configuration is valid, then it is applied. 193 | func (s *LimiterSuite) TestConfigApplied(c *C) { 194 | // Given 195 | configMapper := func(request.Request) (*RateSet, error) { 196 | rates := NewRateSet() 197 | rates.Add(time.Second, 2, 2) 198 | rates.Add(60*time.Second, 10, 10) 199 | return rates, nil 200 | } 201 | rates := NewRateSet() 202 | rates.Add(time.Second, 1, 1) 203 | tl, _ := NewLimiter(rates, 0, limit.MapClientIp, configMapper, s.clock) 204 | req := makeRequest("1.2.3.4") 205 | // When/Then: The configured rate is applied, which is 2 req/second 206 | response, err := tl.ProcessRequest(req) // Processed 207 | c.Assert(response, IsNil) 208 | c.Assert(err, IsNil) 209 | response, err = tl.ProcessRequest(req) // Processed 210 | c.Assert(response, IsNil) 211 | c.Assert(err, IsNil) 212 | response, err = tl.ProcessRequest(req) // Rejected 213 | c.Assert(response, NotNil) 214 | c.Assert(err, IsNil) 215 | 216 | s.clock.Sleep(time.Second) 217 | response, err = tl.ProcessRequest(req) // Processed 218 | c.Assert(response, IsNil) 219 | c.Assert(err, IsNil) 220 | } 221 | // makeRequest builds a request whose RemoteAddr is exactly ip (no port). 222 | func makeRequest(ip string) request.Request { 223 | return
&request.BaseRequest{ 224 | HttpRequest: &http.Request{ 225 | RemoteAddr: ip, 226 | }, 227 | } 228 | } 229 | -------------------------------------------------------------------------------- /loadbalance/balance.go: -------------------------------------------------------------------------------- 1 | // Package loadbalance defines load balancers, which control how requests are distributed among multiple endpoints. 2 | package loadbalance 3 | 4 | import ( 5 | . "github.com/mailgun/vulcan/endpoint" 6 | . "github.com/mailgun/vulcan/middleware" 7 | . "github.com/mailgun/vulcan/request" 8 | ) 9 | // LoadBalancer chooses the endpoint for each request and may also act as middleware and as an observer of request stats. 10 | type LoadBalancer interface { 11 | // This function will be called each time location would need to choose the next endpoint for the request 12 | NextEndpoint(req Request) (Endpoint, error) 13 | // Load balancer can intercept the request 14 | Middleware 15 | // Load balancer may observe the request stats to get some runtime metrics 16 | Observer 17 | } 18 | -------------------------------------------------------------------------------- /loadbalance/loadbalance_test.go: -------------------------------------------------------------------------------- 1 | package loadbalance 2 | -------------------------------------------------------------------------------- /loadbalance/roundrobin/fsm.go: -------------------------------------------------------------------------------- 1 | package roundrobin 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/mailgun/timetools" 8 | "github.com/mailgun/vulcan/metrics" 9 | ) 10 | 11 | // FSMHandler increases weights of endpoints that perform better than others; 12 | // it also moves weights back toward the originals when endpoints no longer differ in quality.
13 | type FSMHandler struct { 14 | // As usual, control time in tests 15 | timeProvider timetools.TimeProvider 16 | // Time that freezes state machine to accumulate stats after updating the weights 17 | backoffDuration time.Duration 18 | // Timer is set to give probing some time to take place 19 | timer time.Time 20 | // Endpoints for this round 21 | endpoints []*WeightedEndpoint 22 | // Precalculated original weights 23 | originalWeights []SuggestedWeight 24 | // Last returned weights 25 | lastWeights []SuggestedWeight 26 | } 27 | 28 | const ( 29 | // This is the maximum weight that handler will set for the endpoint 30 | FSMMaxWeight = 4096 31 | // Multiplier for the endpoint weight 32 | FSMGrowFactor = 16 33 | ) 34 | 35 | func NewFSMHandler() (*FSMHandler, error) { 36 | return NewFSMHandlerWithOptions(&timetools.RealTime{}) 37 | } 38 | 39 | func NewFSMHandlerWithOptions(timeProvider timetools.TimeProvider) (*FSMHandler, error) { 40 | if timeProvider == nil { 41 | return nil, fmt.Errorf("time provider can not be nil") 42 | } 43 | return &FSMHandler{ 44 | timeProvider: timeProvider, 45 | }, nil 46 | } 47 | 48 | func (fsm *FSMHandler) Init(endpoints []*WeightedEndpoint) { 49 | fsm.originalWeights = makeOriginalWeights(endpoints) 50 | fsm.lastWeights = fsm.originalWeights 51 | fsm.endpoints = endpoints 52 | if len(endpoints) > 0 { 53 | fsm.backoffDuration = endpoints[0].meter.GetWindowSize() / 2 54 | } 55 | fsm.timer = fsm.timeProvider.UtcNow().Add(-1 * time.Second) 56 | } 57 | 58 | // Called on every load balancer NextEndpoint call, returns the suggested weights 59 | // on every call, can adjust weights if needed. 
60 | func (fsm *FSMHandler) AdjustWeights() ([]SuggestedWeight, error) { 61 | // In this case adjusting weights would have no effect, so do nothing 62 | if len(fsm.endpoints) < 2 { 63 | return fsm.originalWeights, nil 64 | } 65 | // Metrics are not ready 66 | if !metricsReady(fsm.endpoints) { 67 | return fsm.originalWeights, nil 68 | } 69 | if !fsm.timerExpired() { 70 | return fsm.lastWeights, nil 71 | } 72 | // Select endpoints with highest error rates and lower their weight 73 | good, bad := splitEndpoints(fsm.endpoints) 74 | // No endpoints that are different by their quality, so converge weights 75 | if len(bad) == 0 || len(good) == 0 { 76 | weights, changed := fsm.convergeWeights() 77 | if changed { 78 | fsm.lastWeights = weights 79 | fsm.setTimer() 80 | } 81 | return fsm.lastWeights, nil 82 | } 83 | fsm.lastWeights = fsm.adjustWeights(good, bad) 84 | fsm.setTimer() 85 | return fsm.lastWeights, nil 86 | } 87 | 88 | func (fsm *FSMHandler) convergeWeights() ([]SuggestedWeight, bool) { 89 | weights := make([]SuggestedWeight, len(fsm.endpoints)) 90 | // If we have previoulsy changed endpoints try to restore weights to the original state 91 | changed := false 92 | for i, e := range fsm.endpoints { 93 | weights[i] = &EndpointWeight{e, decrease(e.GetOriginalWeight(), e.GetEffectiveWeight())} 94 | if e.GetEffectiveWeight() != e.GetOriginalWeight() { 95 | changed = true 96 | } 97 | } 98 | return normalizeWeights(weights), changed 99 | } 100 | 101 | func (fsm *FSMHandler) adjustWeights(good map[string]bool, bad map[string]bool) []SuggestedWeight { 102 | // Increase weight on good endpoints 103 | weights := make([]SuggestedWeight, len(fsm.endpoints)) 104 | for i, e := range fsm.endpoints { 105 | if good[e.GetId()] && increase(e.GetEffectiveWeight()) <= FSMMaxWeight { 106 | weights[i] = &EndpointWeight{e, increase(e.GetEffectiveWeight())} 107 | } else { 108 | weights[i] = &EndpointWeight{e, e.GetEffectiveWeight()} 109 | } 110 | } 111 | return normalizeWeights(weights) 112 
| } 113 | 114 | func weightsGcd(weights []SuggestedWeight) int { 115 | divisor := -1 116 | for _, w := range weights { 117 | if divisor == -1 { 118 | divisor = w.GetWeight() 119 | } else { 120 | divisor = gcd(divisor, w.GetWeight()) 121 | } 122 | } 123 | return divisor 124 | } 125 | 126 | func normalizeWeights(weights []SuggestedWeight) []SuggestedWeight { 127 | gcd := weightsGcd(weights) 128 | if gcd <= 1 { 129 | return weights 130 | } 131 | for _, w := range weights { 132 | w.SetWeight(w.GetWeight() / gcd) 133 | } 134 | return weights 135 | } 136 | 137 | func (fsm *FSMHandler) setTimer() { 138 | fsm.timer = fsm.timeProvider.UtcNow().Add(fsm.backoffDuration) 139 | } 140 | 141 | func (fsm *FSMHandler) timerExpired() bool { 142 | return fsm.timer.Before(fsm.timeProvider.UtcNow()) 143 | } 144 | 145 | func metricsReady(endpoints []*WeightedEndpoint) bool { 146 | for _, e := range endpoints { 147 | if !e.meter.IsReady() { 148 | return false 149 | } 150 | } 151 | return true 152 | } 153 | 154 | func increase(weight int) int { 155 | return weight * FSMGrowFactor 156 | } 157 | 158 | func decrease(target, current int) int { 159 | adjusted := current / FSMGrowFactor 160 | if adjusted < target { 161 | return target 162 | } else { 163 | return adjusted 164 | } 165 | } 166 | 167 | func makeOriginalWeights(endpoints []*WeightedEndpoint) []SuggestedWeight { 168 | weights := make([]SuggestedWeight, len(endpoints)) 169 | for i, e := range endpoints { 170 | weights[i] = &EndpointWeight{ 171 | Weight: e.GetOriginalWeight(), 172 | Endpoint: e, 173 | } 174 | } 175 | return weights 176 | } 177 | 178 | // splitEndpoints splits endpoints into two groups of endpoints with bad and good failure rate. 179 | // It does compare relative performances of the endpoints though, so if all endpoints have approximately the same error rate 180 | // this function returns the result as if all endpoints are equally good. 
181 | func splitEndpoints(endpoints []*WeightedEndpoint) (map[string]bool, map[string]bool) { 182 | 183 | failRates := make([]float64, len(endpoints)) 184 | 185 | for i, e := range endpoints { 186 | failRates[i] = e.failRate() 187 | } 188 | 189 | g, b := metrics.SplitFloat64(1.5, 0, failRates) 190 | good, bad := make(map[string]bool, len(g)), make(map[string]bool, len(b)) 191 | 192 | for _, e := range endpoints { 193 | if g[e.failRate()] { 194 | good[e.GetId()] = true 195 | } else { 196 | bad[e.GetId()] = true 197 | } 198 | } 199 | 200 | return good, bad 201 | } 202 | -------------------------------------------------------------------------------- /loadbalance/roundrobin/fsm_test.go: -------------------------------------------------------------------------------- 1 | package roundrobin 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/mailgun/timetools" 8 | "github.com/mailgun/vulcan/endpoint" 9 | "github.com/mailgun/vulcan/metrics" 10 | . "gopkg.in/check.v1" 11 | ) 12 | 13 | type FSMSuite struct { 14 | tm *timetools.FreezedTime 15 | } 16 | 17 | var _ = Suite(&FSMSuite{}) 18 | 19 | func (s *FSMSuite) SetUpTest(c *C) { 20 | s.tm = &timetools.FreezedTime{ 21 | CurrentTime: time.Date(2012, 3, 4, 5, 6, 7, 0, time.UTC), 22 | } 23 | } 24 | 25 | func (s *FSMSuite) newF(endpoints []*WeightedEndpoint) *FSMHandler { 26 | o, err := NewFSMHandlerWithOptions(s.tm) 27 | if err != nil { 28 | panic(err) 29 | } 30 | o.Init(endpoints) 31 | return o 32 | } 33 | 34 | func (s *FSMSuite) advanceTime(d time.Duration) { 35 | s.tm.CurrentTime = s.tm.CurrentTime.Add(d) 36 | } 37 | 38 | // Check our special greater function that neglects insigificant differences 39 | func (s *FSMSuite) TestFSMSplit(c *C) { 40 | vals := []struct { 41 | endpoints []*WeightedEndpoint 42 | good []int 43 | bad []int 44 | }{ 45 | { 46 | endpoints: newW(0, 0), 47 | good: []int{0, 1}, 48 | bad: []int{}, 49 | }, 50 | { 51 | endpoints: newW(0, 1), 52 | good: []int{0}, 53 | bad: []int{1}, 54 | }, 55 | { 56 | 
endpoints: newW(0.1, 0.1), 57 | good: []int{0, 1}, 58 | bad: []int{}, 59 | }, 60 | { 61 | endpoints: newW(0.15, 0.1), 62 | good: []int{0, 1}, 63 | bad: []int{}, 64 | }, 65 | { 66 | endpoints: newW(0.01, 0.01), 67 | good: []int{0, 1}, 68 | bad: []int{}, 69 | }, 70 | { 71 | endpoints: newW(0.012, 0.01, 1), 72 | good: []int{0, 1}, 73 | bad: []int{2}, 74 | }, 75 | { 76 | endpoints: newW(0, 0, 1, 1), 77 | good: []int{0, 1}, 78 | bad: []int{2, 3}, 79 | }, 80 | { 81 | endpoints: newW(0, 0.1, 0.1, 0), 82 | good: []int{0, 3}, 83 | bad: []int{1, 2}, 84 | }, 85 | { 86 | endpoints: newW(0, 0.01, 0.1, 0), 87 | good: []int{0, 3}, 88 | bad: []int{1, 2}, 89 | }, 90 | { 91 | endpoints: newW(0, 0.01, 0.02, 1), 92 | good: []int{0, 1, 2}, 93 | bad: []int{3}, 94 | }, 95 | { 96 | endpoints: newW(0, 0, 0, 0, 0, 0.01, 0.02, 1), 97 | good: []int{0, 1, 2, 3, 4}, 98 | bad: []int{5, 6, 7}, 99 | }, 100 | } 101 | for _, v := range vals { 102 | good, bad := splitEndpoints(v.endpoints) 103 | for _, id := range v.good { 104 | c.Assert(good[fmt.Sprintf("http://localhost:500%d", id)], Equals, true) 105 | } 106 | for _, id := range v.bad { 107 | c.Assert(bad[fmt.Sprintf("http://localhost:500%d", id)], Equals, true) 108 | } 109 | } 110 | } 111 | 112 | func (s *FSMSuite) TestInvalidParameters(c *C) { 113 | _, err := NewFSMHandlerWithOptions(nil) 114 | c.Assert(err, NotNil) 115 | } 116 | 117 | func (s *FSMSuite) TestNoEndpoints(c *C) { 118 | adjusted, err := s.newF(newW()).AdjustWeights() 119 | c.Assert(err, IsNil) 120 | c.Assert(len(adjusted), Equals, 0) 121 | } 122 | 123 | func (s *FSMSuite) TestOneEndpoint(c *C) { 124 | adjusted, err := s.newF(newW(1)).AdjustWeights() 125 | c.Assert(err, IsNil) 126 | c.Assert(getWeights(adjusted), DeepEquals, []int{1}) 127 | } 128 | 129 | func (s *FSMSuite) TestAllEndpointsAreGood(c *C) { 130 | adjusted, err := s.newF(newW(0, 0)).AdjustWeights() 131 | c.Assert(err, IsNil) 132 | c.Assert(getWeights(adjusted), DeepEquals, []int{1, 1}) 133 | } 134 | 135 | func (s 
*FSMSuite) TestAllEndpointsAreBad(c *C) { 136 | adjusted, err := s.newF(newW(0.13, 0.14, 0.14)).AdjustWeights() 137 | c.Assert(err, IsNil) 138 | c.Assert(getWeights(adjusted), DeepEquals, []int{1, 1, 1}) 139 | } 140 | 141 | func (s *FSMSuite) TestMetricsAreNotReady(c *C) { 142 | endpoints := []*WeightedEndpoint{ 143 | &WeightedEndpoint{ 144 | meter: &metrics.TestMeter{Rate: 0.5, NotReady: true}, 145 | endpoint: endpoint.MustParseUrl("http://localhost:5000"), 146 | weight: 1, 147 | effectiveWeight: 1, 148 | }, 149 | &WeightedEndpoint{ 150 | meter: &metrics.TestMeter{Rate: 0, NotReady: true}, 151 | endpoint: endpoint.MustParseUrl("http://localhost:5001"), 152 | weight: 1, 153 | effectiveWeight: 1, 154 | }, 155 | } 156 | adjusted, err := s.newF(endpoints).AdjustWeights() 157 | c.Assert(err, IsNil) 158 | c.Assert(getWeights(adjusted), DeepEquals, []int{1, 1}) 159 | } 160 | 161 | func (s *FSMSuite) TestWeightIncrease(c *C) { 162 | endpoints := newW(0.5, 0) 163 | f := s.newF(endpoints) 164 | 165 | adjusted, err := f.AdjustWeights() 166 | 167 | // It will adjust weights and set timer 168 | c.Assert(err, IsNil) 169 | c.Assert(len(adjusted), Equals, 2) 170 | c.Assert(getWeights(adjusted), DeepEquals, []int{1, FSMGrowFactor}) 171 | for _, a := range adjusted { 172 | a.GetEndpoint().setEffectiveWeight(a.GetWeight()) 173 | } 174 | 175 | // We will wait some time until we gather some stats 176 | adjusted, err = f.AdjustWeights() 177 | c.Assert(err, IsNil) 178 | c.Assert(getWeights(adjusted), DeepEquals, []int{1, FSMGrowFactor}) 179 | 180 | // As time passes, let's repeat this procedure to see if we hit the ceiling 181 | for i := 0; i < 6; i += 1 { 182 | adjusted, err := f.AdjustWeights() 183 | c.Assert(err, IsNil) 184 | for _, a := range adjusted { 185 | a.GetEndpoint().setEffectiveWeight(a.GetWeight()) 186 | } 187 | s.advanceTime(endpoints[0].meter.GetWindowSize()/2 + time.Second) 188 | } 189 | 190 | // Algo has not changed the weight of the bad endpoint 191 | 
c.Assert(endpoints[0].GetEffectiveWeight(), Equals, 1) 192 | // Algo has adjusted the weight of the good endpoint to the maximum number 193 | c.Assert(endpoints[1].GetEffectiveWeight(), Equals, FSMMaxWeight) 194 | } 195 | 196 | func (s *FSMSuite) TestRevert(c *C) { 197 | endpoints := newW(0.5, 0) 198 | f := s.newF(endpoints) 199 | 200 | bad := endpoints[0] 201 | adjusted, err := f.AdjustWeights() 202 | c.Assert(err, IsNil) 203 | c.Assert(getWeights(adjusted), DeepEquals, []int{1, FSMGrowFactor}) 204 | for _, a := range adjusted { 205 | a.GetEndpoint().setEffectiveWeight(a.GetWeight()) 206 | } 207 | 208 | // The situation have recovered, so FSM will try to bring back the bad endpoint into life by reverting the weights back 209 | s.advanceTime(endpoints[0].meter.GetWindowSize()/2 + time.Second) 210 | bad.GetMeter().(*metrics.TestMeter).Rate = 0 211 | f.AdjustWeights() 212 | 213 | adjusted, err = f.AdjustWeights() 214 | c.Assert(err, IsNil) 215 | c.Assert(getWeights(adjusted), DeepEquals, []int{1, 1}) 216 | } 217 | 218 | // Case when the increasing weights went wrong and the good endpoints started failing 219 | func (s *FSMSuite) TestProbingUnsuccessfull(c *C) { 220 | endpoints := newW(0.5, 0.5, 0, 0, 0) 221 | f := s.newF(endpoints) 222 | 223 | adjusted, err := f.AdjustWeights() 224 | 225 | // It will adjust weight and set timer 226 | c.Assert(err, IsNil) 227 | c.Assert(getWeights(adjusted), DeepEquals, []int{1, 1, FSMGrowFactor, FSMGrowFactor, FSMGrowFactor}) 228 | for _, a := range adjusted { 229 | a.GetEndpoint().setEffectiveWeight(a.GetWeight()) 230 | } 231 | // Times has passed and good endpoint appears to behave worse now, oh no! 
232 | for _, e := range endpoints { 233 | e.GetMeter().(*metrics.TestMeter).Rate = 0.5 234 | } 235 | s.advanceTime(endpoints[0].meter.GetWindowSize()/2 + time.Second) 236 | 237 | // As long as all endpoints are equally bad now, we will revert weights back 238 | adjusted, err = f.AdjustWeights() 239 | c.Assert(err, IsNil) 240 | c.Assert(getWeights(adjusted), DeepEquals, []int{1, 1, 1, 1, 1}) 241 | } 242 | 243 | func (s *FSMSuite) TestNormalize(c *C) { 244 | weights := newWeights(1, 2, 3, 4) 245 | c.Assert(weights, DeepEquals, normalizeWeights(weights)) 246 | c.Assert(newWeights(1, 1, 1, 4), DeepEquals, normalizeWeights(newWeights(4, 4, 4, 16))) 247 | } 248 | 249 | func newW(failRates ...float64) []*WeightedEndpoint { 250 | out := make([]*WeightedEndpoint, len(failRates)) 251 | for i, r := range failRates { 252 | out[i] = &WeightedEndpoint{ 253 | meter: &metrics.TestMeter{Rate: r, WindowSize: time.Second * 10}, 254 | endpoint: endpoint.MustParseUrl(fmt.Sprintf("http://localhost:500%d", i)), 255 | weight: 1, 256 | effectiveWeight: 1, 257 | } 258 | } 259 | return out 260 | } 261 | 262 | func getWeights(weights []SuggestedWeight) []int { 263 | out := make([]int, len(weights)) 264 | for i, w := range weights { 265 | out[i] = w.GetWeight() 266 | } 267 | return out 268 | } 269 | 270 | func newWeights(weights ...int) []SuggestedWeight { 271 | out := make([]SuggestedWeight, len(weights)) 272 | for i, w := range weights { 273 | out[i] = &EndpointWeight{ 274 | Weight: w, 275 | Endpoint: &WeightedEndpoint{endpoint: endpoint.MustParseUrl(fmt.Sprintf("http://localhost:500%d", i))}, 276 | } 277 | } 278 | return out 279 | } 280 | -------------------------------------------------------------------------------- /loadbalance/roundrobin/recovery.go: -------------------------------------------------------------------------------- 1 | package roundrobin 2 | 3 | type FailureHandler interface { 4 | // Returns error if something bad happened, returns suggested weights 5 | AdjustWeights() 
([]SuggestedWeight, error) 6 | // Initializes handler with current set of endpoints. Will be called 7 | // each time endpoints are added or removed from the load balancer 8 | // to give failure handler a chance to set it's itenral state 9 | Init(endpoints []*WeightedEndpoint) 10 | } 11 | 12 | type SuggestedWeight interface { 13 | GetEndpoint() *WeightedEndpoint 14 | GetWeight() int 15 | SetWeight(int) 16 | } 17 | 18 | type EndpointWeight struct { 19 | Endpoint *WeightedEndpoint 20 | Weight int 21 | } 22 | 23 | func (ew *EndpointWeight) GetEndpoint() *WeightedEndpoint { 24 | return ew.Endpoint 25 | } 26 | 27 | func (ew *EndpointWeight) GetWeight() int { 28 | return ew.Weight 29 | } 30 | 31 | func (ew *EndpointWeight) SetWeight(w int) { 32 | ew.Weight = w 33 | } 34 | -------------------------------------------------------------------------------- /loadbalance/roundrobin/roundrobin_test.go: -------------------------------------------------------------------------------- 1 | package roundrobin 2 | 3 | import ( 4 | "fmt" 5 | "github.com/mailgun/timetools" 6 | . "github.com/mailgun/vulcan/endpoint" 7 | . "github.com/mailgun/vulcan/metrics" 8 | . "github.com/mailgun/vulcan/request" 9 | . 
"gopkg.in/check.v1" 10 | "testing" 11 | "time" 12 | ) 13 | 14 | func Test(t *testing.T) { TestingT(t) } 15 | 16 | type RoundRobinSuite struct { 17 | tm *timetools.FreezedTime 18 | req Request 19 | } 20 | 21 | var _ = Suite(&RoundRobinSuite{}) 22 | 23 | func (s *RoundRobinSuite) SetUpSuite(c *C) { 24 | s.tm = &timetools.FreezedTime{ 25 | CurrentTime: time.Date(2012, 3, 4, 5, 6, 7, 0, time.UTC), 26 | } 27 | s.req = &BaseRequest{} 28 | } 29 | 30 | func (s *RoundRobinSuite) newRR() *RoundRobin { 31 | handler, err := NewFSMHandlerWithOptions(s.tm) 32 | if err != nil { 33 | panic(err) 34 | } 35 | 36 | r, err := NewRoundRobinWithOptions(Options{TimeProvider: s.tm, FailureHandler: handler}) 37 | if err != nil { 38 | panic(err) 39 | } 40 | return r 41 | } 42 | 43 | func (s *RoundRobinSuite) TestNoEndpoints(c *C) { 44 | r := s.newRR() 45 | _, err := r.NextEndpoint(s.req) 46 | c.Assert(err, NotNil) 47 | } 48 | 49 | func (s *RoundRobinSuite) TestDefaultArgs(c *C) { 50 | r, err := NewRoundRobin() 51 | c.Assert(err, IsNil) 52 | 53 | a := MustParseUrl("http://localhost:5000") 54 | b := MustParseUrl("http://localhost:5001") 55 | 56 | r.AddEndpoint(a) 57 | r.AddEndpoint(b) 58 | 59 | u, err := r.NextEndpoint(s.req) 60 | c.Assert(err, IsNil) 61 | c.Assert(u, Equals, a) 62 | 63 | u, err = r.NextEndpoint(s.req) 64 | c.Assert(err, IsNil) 65 | c.Assert(u, Equals, b) 66 | 67 | u, err = r.NextEndpoint(s.req) 68 | c.Assert(err, IsNil) 69 | c.Assert(u, Equals, a) 70 | } 71 | 72 | // Subsequent calls to load balancer with 1 endpoint are ok 73 | func (s *RoundRobinSuite) TestSingleEndpoint(c *C) { 74 | r := s.newRR() 75 | 76 | u := MustParseUrl("http://localhost:5000") 77 | r.AddEndpoint(u) 78 | 79 | u2, err := r.NextEndpoint(s.req) 80 | c.Assert(err, IsNil) 81 | c.Assert(u2, Equals, u) 82 | 83 | u3, err := r.NextEndpoint(s.req) 84 | c.Assert(err, IsNil) 85 | c.Assert(u3, Equals, u) 86 | } 87 | 88 | // Make sure that load balancer round robins requests 89 | func (s *RoundRobinSuite) 
TestMultipleEndpoints(c *C) { 90 | r := s.newRR() 91 | 92 | uA := MustParseUrl("http://localhost:5000") 93 | uB := MustParseUrl("http://localhost:5001") 94 | r.AddEndpoint(uA) 95 | r.AddEndpoint(uB) 96 | 97 | u, err := r.NextEndpoint(s.req) 98 | c.Assert(err, IsNil) 99 | c.Assert(u, Equals, uA) 100 | 101 | u, err = r.NextEndpoint(s.req) 102 | c.Assert(err, IsNil) 103 | c.Assert(u, Equals, uB) 104 | 105 | u, err = r.NextEndpoint(s.req) 106 | c.Assert(err, IsNil) 107 | c.Assert(u, Equals, uA) 108 | } 109 | 110 | // Make sure that adding endpoints during load balancing works fine 111 | func (s *RoundRobinSuite) TestAddEndpoints(c *C) { 112 | r := s.newRR() 113 | 114 | uA := MustParseUrl("http://localhost:5000") 115 | uB := MustParseUrl("http://localhost:5001") 116 | r.AddEndpoint(uA) 117 | 118 | u, err := r.NextEndpoint(s.req) 119 | c.Assert(err, IsNil) 120 | c.Assert(u, Equals, uA) 121 | 122 | r.AddEndpoint(uB) 123 | 124 | // index was reset after altering endpoints 125 | u, err = r.NextEndpoint(s.req) 126 | c.Assert(err, IsNil) 127 | c.Assert(u, Equals, uA) 128 | 129 | u, err = r.NextEndpoint(s.req) 130 | c.Assert(err, IsNil) 131 | c.Assert(u, Equals, uB) 132 | } 133 | 134 | // Removing endpoints from the load balancer works fine as well 135 | func (s *RoundRobinSuite) TestRemoveEndpoint(c *C) { 136 | r := s.newRR() 137 | 138 | uA := MustParseUrl("http://localhost:5000") 139 | uB := MustParseUrl("http://localhost:5001") 140 | r.AddEndpoint(uA) 141 | r.AddEndpoint(uB) 142 | 143 | u, err := r.NextEndpoint(s.req) 144 | c.Assert(err, IsNil) 145 | c.Assert(u, Equals, uA) 146 | 147 | // Removing endpoint resets the counter 148 | r.RemoveEndpoint(uB) 149 | 150 | u, err = r.NextEndpoint(s.req) 151 | c.Assert(err, IsNil) 152 | c.Assert(u, Equals, uA) 153 | } 154 | 155 | func (s *RoundRobinSuite) TestAddSameEndpoint(c *C) { 156 | r := s.newRR() 157 | 158 | uA := MustParseUrl("http://localhost:5000") 159 | uB := MustParseUrl("http://localhost:5000") 160 | r.AddEndpoint(uA) 161 
| c.Assert(r.AddEndpoint(uB), NotNil) 162 | } 163 | 164 | func (s *RoundRobinSuite) TestFindEndpoint(c *C) { 165 | r := s.newRR() 166 | 167 | uA := MustParseUrl("http://localhost:5000") 168 | uB := MustParseUrl("http://localhost:5001") 169 | r.AddEndpoint(uA) 170 | r.AddEndpoint(uB) 171 | 172 | c.Assert(r.FindEndpointById(""), IsNil) 173 | c.Assert(r.FindEndpointById(uA.GetId()).GetId(), Equals, uA.GetId()) 174 | c.Assert(r.FindEndpointByUrl(uA.GetUrl().String()).GetId(), Equals, uA.GetId()) 175 | c.Assert(r.FindEndpointByUrl(""), IsNil) 176 | c.Assert(r.FindEndpointByUrl("http://localhost wrong url 5000"), IsNil) 177 | } 178 | 179 | func (s *RoundRobinSuite) advanceTime(d time.Duration) { 180 | s.tm.CurrentTime = s.tm.CurrentTime.Add(d) 181 | } 182 | 183 | func (s *RoundRobinSuite) TestReactsOnFailures(c *C) { 184 | handler, err := NewFSMHandlerWithOptions(s.tm) 185 | c.Assert(err, IsNil) 186 | 187 | r, err := NewRoundRobinWithOptions( 188 | Options{ 189 | TimeProvider: s.tm, 190 | FailureHandler: handler, 191 | }) 192 | c.Assert(err, IsNil) 193 | 194 | a := MustParseUrl("http://localhost:5000") 195 | aM := &TestMeter{Rate: 0.5} 196 | 197 | b := MustParseUrl("http://localhost:5001") 198 | bM := &TestMeter{Rate: 0} 199 | 200 | r.AddEndpointWithOptions(a, EndpointOptions{Meter: aM}) 201 | r.AddEndpointWithOptions(b, EndpointOptions{Meter: bM}) 202 | 203 | countA, countB := 0, 0 204 | for i := 0; i < 100; i += 1 { 205 | e, err := r.NextEndpoint(s.req) 206 | if e.GetId() == a.GetId() { 207 | countA += 1 208 | } else { 209 | countB += 1 210 | } 211 | c.Assert(e, NotNil) 212 | c.Assert(err, IsNil) 213 | s.advanceTime(time.Duration(time.Second)) 214 | r.ObserveResponse(s.req, &BaseAttempt{Endpoint: e}) 215 | } 216 | c.Assert(countB > countA*2, Equals, true) 217 | } 218 | 219 | // Make sure that failover avoids to hit the same endpoint 220 | func (s *RoundRobinSuite) TestFailoverAvoidsSameEndpoint(c *C) { 221 | r := s.newRR() 222 | 223 | uA := 
MustParseUrl("http://localhost:5000") 224 | uB := MustParseUrl("http://localhost:5001") 225 | r.AddEndpoint(uA) 226 | r.AddEndpoint(uB) 227 | 228 | failedRequest := &BaseRequest{ 229 | Attempts: []Attempt{ 230 | &BaseAttempt{ 231 | Endpoint: uA, 232 | Error: fmt.Errorf("Something failed"), 233 | }, 234 | }, 235 | } 236 | 237 | u, err := r.NextEndpoint(failedRequest) 238 | c.Assert(err, IsNil) 239 | c.Assert(u, Equals, uB) 240 | } 241 | 242 | // Make sure that failover avoids to hit the same endpoints in case if there are multiple consequent failures 243 | func (s *RoundRobinSuite) TestFailoverAvoidsSameEndpointMultipleFailures(c *C) { 244 | r := s.newRR() 245 | 246 | uA := MustParseUrl("http://localhost:5000") 247 | uB := MustParseUrl("http://localhost:5001") 248 | uC := MustParseUrl("http://localhost:5002") 249 | r.AddEndpoint(uA) 250 | r.AddEndpoint(uB) 251 | r.AddEndpoint(uC) 252 | 253 | failedRequest := &BaseRequest{ 254 | Attempts: []Attempt{ 255 | &BaseAttempt{ 256 | Endpoint: uA, 257 | Error: fmt.Errorf("Something failed"), 258 | }, 259 | &BaseAttempt{ 260 | Endpoint: uB, 261 | Error: fmt.Errorf("Something failed"), 262 | }, 263 | }, 264 | } 265 | 266 | u, err := r.NextEndpoint(failedRequest) 267 | c.Assert(err, IsNil) 268 | c.Assert(u, Equals, uC) 269 | } 270 | 271 | // Removing endpoints from the load balancer works fine as well 272 | func (s *RoundRobinSuite) TestRemoveMultipleEndpoints(c *C) { 273 | r := s.newRR() 274 | 275 | uA := MustParseUrl("http://localhost:5000") 276 | uB := MustParseUrl("http://localhost:5001") 277 | uC := MustParseUrl("http://localhost:5002") 278 | r.AddEndpoint(uA) 279 | r.AddEndpoint(uB) 280 | r.AddEndpoint(uC) 281 | 282 | u, err := r.NextEndpoint(s.req) 283 | c.Assert(err, IsNil) 284 | u, err = r.NextEndpoint(s.req) 285 | c.Assert(err, IsNil) 286 | u, err = r.NextEndpoint(s.req) 287 | c.Assert(err, IsNil) 288 | c.Assert(u, Equals, uC) 289 | 290 | // There's only one endpoint left 291 | r.RemoveEndpoint(uA) 292 | 
r.RemoveEndpoint(uB) 293 | u, err = r.NextEndpoint(s.req) 294 | c.Assert(err, IsNil) 295 | c.Assert(u, Equals, uC) 296 | } 297 | -------------------------------------------------------------------------------- /loadbalance/roundrobin/wendpoint.go: -------------------------------------------------------------------------------- 1 | package roundrobin 2 | 3 | import ( 4 | "fmt" 5 | "github.com/mailgun/log" 6 | "github.com/mailgun/vulcan/endpoint" 7 | "github.com/mailgun/vulcan/metrics" 8 | "net/url" 9 | ) 10 | 11 | // WeightedEndpoint wraps the endpoint and adds support for weights and failure detection. 12 | type WeightedEndpoint struct { 13 | // meter accumulates endpoint stats and for failure detection 14 | meter metrics.FailRateMeter 15 | 16 | // endpoint is an original endpoint supplied by user 17 | endpoint endpoint.Endpoint 18 | 19 | // weight holds original weight supplied by user 20 | weight int 21 | 22 | // effectiveWeight is the weights assigned by the load balancer based on failure 23 | effectiveWeight int 24 | 25 | // rr is a reference to the parent load balancer 26 | rr *RoundRobin 27 | } 28 | 29 | func (we *WeightedEndpoint) String() string { 30 | return fmt.Sprintf("WeightedEndpoint(id=%s, url=%s, weight=%d, effectiveWeight=%d, failRate=%f)", 31 | we.GetId(), we.GetUrl(), we.weight, we.effectiveWeight, we.meter.GetRate()) 32 | } 33 | 34 | func (we *WeightedEndpoint) GetId() string { 35 | return we.endpoint.GetId() 36 | } 37 | 38 | func (we *WeightedEndpoint) GetUrl() *url.URL { 39 | return we.endpoint.GetUrl() 40 | } 41 | 42 | func (we *WeightedEndpoint) setEffectiveWeight(w int) { 43 | log.Infof("%s setting effective weight to: %d", we, w) 44 | we.effectiveWeight = w 45 | } 46 | 47 | func (we *WeightedEndpoint) GetOriginalEndpoint() endpoint.Endpoint { 48 | return we.endpoint 49 | } 50 | 51 | func (we *WeightedEndpoint) GetOriginalWeight() int { 52 | return we.weight 53 | } 54 | 55 | func (we *WeightedEndpoint) GetEffectiveWeight() int { 56 | return 
we.effectiveWeight 57 | } 58 | 59 | func (we *WeightedEndpoint) GetMeter() metrics.FailRateMeter { 60 | return we.meter 61 | } 62 | 63 | func (we *WeightedEndpoint) failRate() float64 { 64 | return we.meter.GetRate() 65 | } 66 | 67 | type WeightedEndpoints []*WeightedEndpoint 68 | 69 | func (we WeightedEndpoints) Len() int { 70 | return len(we) 71 | } 72 | 73 | func (we WeightedEndpoints) Swap(i, j int) { 74 | we[i], we[j] = we[j], we[i] 75 | } 76 | 77 | func (we WeightedEndpoints) Less(i, j int) bool { 78 | return we[i].meter.GetRate() < we[j].meter.GetRate() 79 | } 80 | -------------------------------------------------------------------------------- /location/httploc/rewrite.go: -------------------------------------------------------------------------------- 1 | package httploc 2 | 3 | import ( 4 | "net" 5 | "net/http" 6 | "strings" 7 | 8 | "github.com/mailgun/vulcan/headers" 9 | "github.com/mailgun/vulcan/netutils" 10 | "github.com/mailgun/vulcan/request" 11 | ) 12 | 13 | // Rewriter is responsible for removing hop-by-hop headers, fixing encodings and content-length 14 | type Rewriter struct { 15 | TrustForwardHeader bool 16 | Hostname string 17 | } 18 | 19 | func (rw *Rewriter) ProcessRequest(r request.Request) (*http.Response, error) { 20 | req := r.GetHttpRequest() 21 | 22 | if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { 23 | if rw.TrustForwardHeader { 24 | if prior, ok := req.Header[headers.XForwardedFor]; ok { 25 | clientIP = strings.Join(prior, ", ") + ", " + clientIP 26 | } 27 | } 28 | req.Header.Set(headers.XForwardedFor, clientIP) 29 | } 30 | 31 | if xfp := req.Header.Get(headers.XForwardedProto); xfp != "" && rw.TrustForwardHeader { 32 | req.Header.Set(headers.XForwardedProto, xfp) 33 | } else if req.TLS != nil { 34 | req.Header.Set(headers.XForwardedProto, "https") 35 | } else { 36 | req.Header.Set(headers.XForwardedProto, "http") 37 | } 38 | 39 | if req.Host != "" { 40 | req.Header.Set(headers.XForwardedHost, req.Host) 41 | } 
42 | req.Header.Set(headers.XForwardedServer, rw.Hostname) 43 | 44 | // Remove hop-by-hop headers to the backend. Especially important is "Connection" because we want a persistent 45 | // connection, regardless of what the client sent to us. 46 | netutils.RemoveHeaders(headers.HopHeaders, req.Header) 47 | 48 | // We need to set ContentLength based on known request size. The incoming request may have been 49 | // set without content length or using chunked TransferEncoding 50 | totalSize, err := r.GetBody().TotalSize() 51 | if err != nil { 52 | return nil, err 53 | } 54 | req.ContentLength = totalSize 55 | // Remove TransferEncoding that could have been previously set 56 | req.TransferEncoding = []string{} 57 | 58 | return nil, nil 59 | } 60 | 61 | func (tl *Rewriter) ProcessResponse(r request.Request, a request.Attempt) { 62 | } 63 | -------------------------------------------------------------------------------- /location/location.go: -------------------------------------------------------------------------------- 1 | // Interfaces for location - round trip the http request to backends 2 | package location 3 | 4 | import ( 5 | "github.com/mailgun/vulcan/netutils" 6 | "github.com/mailgun/vulcan/request" 7 | "net/http" 8 | ) 9 | 10 | // Location accepts proxy request and round trips it to the backend 11 | type Location interface { 12 | // Unique identifier of this location 13 | GetId() string 14 | // Forward the request to a specific location and return the response 15 | RoundTrip(request.Request) (*http.Response, error) 16 | } 17 | 18 | // This location is used in tests 19 | type Loc struct { 20 | Id string 21 | Name string 22 | } 23 | 24 | func (*Loc) RoundTrip(request.Request) (*http.Response, error) { 25 | return nil, nil 26 | } 27 | 28 | func (l *Loc) GetId() string { 29 | return l.Id 30 | } 31 | 32 | // The simplest HTTP location implementation that adds no additional logic 33 | // on top of simple http round trip function call 34 | type ConstHttpLocation 
struct { 35 | Url string 36 | } 37 | 38 | func (l *ConstHttpLocation) RoundTrip(r request.Request) (*http.Response, error) { 39 | req := r.GetHttpRequest() 40 | req.URL = netutils.MustParseUrl(l.Url) 41 | return http.DefaultTransport.RoundTrip(req) 42 | } 43 | 44 | func (l *ConstHttpLocation) GetId() string { 45 | return l.Url 46 | } 47 | -------------------------------------------------------------------------------- /metrics/anomaly.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "math" 5 | "sort" 6 | "time" 7 | ) 8 | 9 | // SplitRatios provides simple anomaly detection for requests latencies. 10 | // it splits values into good or bad category based on the threshold and the median value. 11 | // If all values are not far from the median, it will return all values in 'good' set. 12 | // Precision is the smallest value to consider, e.g. if set to millisecond, microseconds will be ignored. 13 | func SplitLatencies(values []time.Duration, precision time.Duration) (good map[time.Duration]bool, bad map[time.Duration]bool) { 14 | // Find the max latency M and then map each latency L to the ratio L/M and then call SplitFloat64 15 | v2r := map[float64]time.Duration{} 16 | ratios := make([]float64, len(values)) 17 | m := maxTime(values) 18 | for i, v := range values { 19 | ratio := float64(v/precision+1) / float64(m/precision+1) // +1 is to avoid division by 0 20 | v2r[ratio] = v 21 | ratios[i] = ratio 22 | } 23 | good, bad = make(map[time.Duration]bool), make(map[time.Duration]bool) 24 | // Note that multiplier makes this function way less sensitive than ratios detector, this is to avoid noise. 
25 | vgood, vbad := SplitFloat64(2, 0, ratios) 26 | for r, _ := range vgood { 27 | good[v2r[r]] = true 28 | } 29 | for r, _ := range vbad { 30 | bad[v2r[r]] = true 31 | } 32 | return good, bad 33 | } 34 | 35 | // SplitRatios provides simple anomaly detection for ratio values, that are all in the range [0, 1] 36 | // it splits values into good or bad category based on the threshold and the median value. 37 | // If all values are not far from the median, it will return all values in 'good' set. 38 | func SplitRatios(values []float64) (good map[float64]bool, bad map[float64]bool) { 39 | return SplitFloat64(1.5, 0, values) 40 | } 41 | 42 | // SplitFloat64 provides simple anomaly detection for skewed data sets with no particular distribution. 43 | // In essense it applies the formula if(v > median(values) + threshold * medianAbsoluteDeviation) -> anomaly 44 | // There's a corner case where there are just 2 values, so by definition there's no value that exceeds the threshold. 45 | // This case is solved by introducing additional value that we know is good, e.g. 0. That helps to improve the detection results 46 | // on such data sets. 
47 | func SplitFloat64(threshold, sentinel float64, values []float64) (good map[float64]bool, bad map[float64]bool) { 48 | good, bad = make(map[float64]bool), make(map[float64]bool) 49 | var newValues []float64 50 | if len(values)%2 == 0 { 51 | newValues = make([]float64, len(values)+1) 52 | copy(newValues, values) 53 | // Add a sentinel endpoint so we can distinguish outliers better 54 | newValues[len(newValues)-1] = sentinel 55 | } else { 56 | newValues = values 57 | } 58 | 59 | m := median(newValues) 60 | mAbs := medianAbsoluteDeviation(newValues) 61 | for _, v := range values { 62 | if v > (m+mAbs)*threshold { 63 | bad[v] = true 64 | } else { 65 | good[v] = true 66 | } 67 | } 68 | return good, bad 69 | } 70 | 71 | func median(values []float64) float64 { 72 | vals := make([]float64, len(values)) 73 | copy(vals, values) 74 | sort.Float64s(vals) 75 | l := len(vals) 76 | if l%2 != 0 { 77 | return vals[l/2] 78 | } 79 | return (vals[l/2-1] + vals[l/2]) / 2.0 80 | } 81 | 82 | func medianAbsoluteDeviation(values []float64) float64 { 83 | m := median(values) 84 | distances := make([]float64, len(values)) 85 | for i, v := range values { 86 | distances[i] = math.Abs(v - m) 87 | } 88 | return median(distances) 89 | } 90 | 91 | func maxTime(vals []time.Duration) time.Duration { 92 | val := vals[0] 93 | for _, v := range vals { 94 | if v > val { 95 | val = v 96 | } 97 | } 98 | return val 99 | } 100 | -------------------------------------------------------------------------------- /metrics/anomaly_test.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "time" 5 | 6 | . 
"gopkg.in/check.v1" 7 | ) 8 | 9 | type AnomalySuite struct { 10 | } 11 | 12 | var _ = Suite(&AnomalySuite{}) 13 | 14 | func (s *AnomalySuite) TestMedian(c *C) { 15 | c.Assert(median([]float64{0.1, 0.2}), Equals, (float64(0.1)+float64(0.2))/2.0) 16 | c.Assert(median([]float64{0.3, 0.2, 0.5}), Equals, 0.3) 17 | } 18 | 19 | func (s *AnomalySuite) TestSplitRatios(c *C) { 20 | vals := []struct { 21 | values []float64 22 | good []float64 23 | bad []float64 24 | }{ 25 | { 26 | values: []float64{0, 0}, 27 | good: []float64{0}, 28 | bad: []float64{}, 29 | }, 30 | 31 | { 32 | values: []float64{0, 1}, 33 | good: []float64{0}, 34 | bad: []float64{1}, 35 | }, 36 | { 37 | values: []float64{0.1, 0.1}, 38 | good: []float64{0.1}, 39 | bad: []float64{}, 40 | }, 41 | 42 | { 43 | values: []float64{0.15, 0.1}, 44 | good: []float64{0.15, 0.1}, 45 | bad: []float64{}, 46 | }, 47 | { 48 | values: []float64{0.01, 0.01}, 49 | good: []float64{0.01}, 50 | bad: []float64{}, 51 | }, 52 | { 53 | values: []float64{0.012, 0.01, 1}, 54 | good: []float64{0.012, 0.01}, 55 | bad: []float64{1}, 56 | }, 57 | { 58 | values: []float64{0, 0, 1, 1}, 59 | good: []float64{0}, 60 | bad: []float64{1}, 61 | }, 62 | { 63 | values: []float64{0, 0.1, 0.1, 0}, 64 | good: []float64{0}, 65 | bad: []float64{0.1}, 66 | }, 67 | { 68 | values: []float64{0, 0.01, 0.1, 0}, 69 | good: []float64{0}, 70 | bad: []float64{0.01, 0.1}, 71 | }, 72 | { 73 | values: []float64{0, 0.01, 0.02, 1}, 74 | good: []float64{0, 0.01, 0.02}, 75 | bad: []float64{1}, 76 | }, 77 | { 78 | values: []float64{0, 0, 0, 0, 0, 0.01, 0.02, 1}, 79 | good: []float64{0}, 80 | bad: []float64{0.01, 0.02, 1}, 81 | }, 82 | } 83 | for _, v := range vals { 84 | good, bad := SplitRatios(v.values) 85 | vgood, vbad := make(map[float64]bool, len(v.good)), make(map[float64]bool, len(v.bad)) 86 | for _, v := range v.good { 87 | vgood[v] = true 88 | } 89 | for _, v := range v.bad { 90 | vbad[v] = true 91 | } 92 | 93 | c.Assert(good, DeepEquals, vgood) 94 | c.Assert(bad, 
DeepEquals, vbad) 95 | } 96 | } 97 | 98 | func (s *AnomalySuite) TestSplitLatencies(c *C) { 99 | vals := []struct { 100 | values []int 101 | good []int 102 | bad []int 103 | }{ 104 | { 105 | values: []int{0, 0}, 106 | good: []int{0}, 107 | bad: []int{}, 108 | }, 109 | { 110 | values: []int{1, 2}, 111 | good: []int{1, 2}, 112 | bad: []int{}, 113 | }, 114 | { 115 | values: []int{1, 2, 4}, 116 | good: []int{1, 2, 4}, 117 | bad: []int{}, 118 | }, 119 | { 120 | values: []int{8, 8, 18}, 121 | good: []int{8}, 122 | bad: []int{18}, 123 | }, 124 | { 125 | values: []int{32, 28, 11, 26, 19, 51, 25, 39, 28, 26, 8, 97}, 126 | good: []int{32, 28, 11, 26, 19, 51, 25, 39, 28, 26, 8}, 127 | bad: []int{97}, 128 | }, 129 | { 130 | values: []int{1, 2, 4, 40}, 131 | good: []int{1, 2, 4}, 132 | bad: []int{40}, 133 | }, 134 | { 135 | values: []int{40, 60, 1000}, 136 | good: []int{40, 60}, 137 | bad: []int{1000}, 138 | }, 139 | } 140 | for _, v := range vals { 141 | vvalues := make([]time.Duration, len(v.values)) 142 | for i, d := range v.values { 143 | vvalues[i] = time.Millisecond * time.Duration(d) 144 | } 145 | good, bad := SplitLatencies(vvalues, time.Millisecond) 146 | 147 | vgood, vbad := make(map[time.Duration]bool, len(v.good)), make(map[time.Duration]bool, len(v.bad)) 148 | for _, v := range v.good { 149 | vgood[time.Duration(v)*time.Millisecond] = true 150 | } 151 | for _, v := range v.bad { 152 | vbad[time.Duration(v)*time.Millisecond] = true 153 | } 154 | 155 | c.Assert(good, DeepEquals, vgood) 156 | c.Assert(bad, DeepEquals, vbad) 157 | } 158 | } 159 | -------------------------------------------------------------------------------- /metrics/counter.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/mailgun/timetools" 8 | ) 9 | 10 | // NewRollingCounterFn is a constructor of rolling counters. 
11 | type NewRollingCounterFn func() (*RollingCounter, error) 12 | 13 | // Calculates in memory failure rate of an endpoint using rolling window of a predefined size 14 | type RollingCounter struct { 15 | timeProvider timetools.TimeProvider 16 | resolution time.Duration 17 | values []int 18 | countedBuckets int // how many samples in different buckets have we collected so far 19 | lastBucket int // last recorded bucket 20 | lastUpdated time.Time 21 | } 22 | 23 | // NewRollingCounter creates a counter with fixed amount of buckets that are rotated every resolition period. 24 | // E.g. 10 buckets with 1 second means that every new second the bucket is refreshed, so it maintains 10 second rolling window. 25 | func NewRollingCounter(buckets int, resolution time.Duration, timeProvider timetools.TimeProvider) (*RollingCounter, error) { 26 | if buckets <= 0 { 27 | return nil, fmt.Errorf("Buckets should be >= 0") 28 | } 29 | if resolution < time.Second { 30 | return nil, fmt.Errorf("Resolution should be larger than a second") 31 | } 32 | 33 | return &RollingCounter{ 34 | resolution: resolution, 35 | timeProvider: timeProvider, 36 | values: make([]int, buckets), 37 | lastBucket: -1, 38 | }, nil 39 | } 40 | 41 | func (c *RollingCounter) Reset() { 42 | c.lastBucket = -1 43 | c.countedBuckets = 0 44 | c.lastUpdated = time.Time{} 45 | for i := range c.values { 46 | c.values[i] = 0 47 | } 48 | } 49 | 50 | func (c *RollingCounter) CountedBuckets() int { 51 | return c.countedBuckets 52 | } 53 | 54 | func (c *RollingCounter) Count() int64 { 55 | c.cleanup() 56 | return c.sum() 57 | } 58 | 59 | func (c *RollingCounter) Resolution() time.Duration { 60 | return c.resolution 61 | } 62 | 63 | func (c *RollingCounter) Buckets() int { 64 | return len(c.values) 65 | } 66 | 67 | func (c *RollingCounter) GetWindowSize() time.Duration { 68 | return time.Duration(len(c.values)) * c.resolution 69 | } 70 | 71 | func (c *RollingCounter) Inc() { 72 | c.cleanup() 73 | c.incBucketValue() 74 | } 75 | 
76 | func (c *RollingCounter) incBucketValue() { 77 | now := c.timeProvider.UtcNow() 78 | bucket := c.getBucket(now) 79 | c.values[bucket]++ 80 | c.lastUpdated = now 81 | // Update usage stats if we haven't collected enough data 82 | if c.countedBuckets < len(c.values) { 83 | // Only update if we have advanced to the next bucket and not incremented the value 84 | // in the current bucket. 85 | if c.lastBucket != bucket { 86 | c.lastBucket = bucket 87 | c.countedBuckets++ 88 | } 89 | } 90 | } 91 | 92 | // Returns the number in the moving window bucket that this slot occupies 93 | func (c *RollingCounter) getBucket(t time.Time) int { 94 | return int(t.Truncate(c.resolution).Unix() % int64(len(c.values))) 95 | } 96 | 97 | // Reset buckets that were not updated 98 | func (c *RollingCounter) cleanup() { 99 | now := c.timeProvider.UtcNow() 100 | for i := 0; i < len(c.values); i++ { 101 | now = now.Add(time.Duration(-1*i) * c.resolution) 102 | if now.Truncate(c.resolution).After(c.lastUpdated.Truncate(c.resolution)) { 103 | c.values[c.getBucket(now)] = 0 104 | } else { 105 | break 106 | } 107 | } 108 | } 109 | 110 | func (c *RollingCounter) sum() int64 { 111 | out := int64(0) 112 | for _, v := range c.values { 113 | out += int64(v) 114 | } 115 | return out 116 | } 117 | -------------------------------------------------------------------------------- /metrics/failrate.go: -------------------------------------------------------------------------------- 1 | // In memory request performance metrics 2 | package metrics 3 | 4 | import ( 5 | "fmt" 6 | "time" 7 | 8 | "github.com/mailgun/timetools" 9 | "github.com/mailgun/vulcan/endpoint" 10 | "github.com/mailgun/vulcan/middleware" 11 | "github.com/mailgun/vulcan/request" 12 | ) 13 | 14 | type FailRateMeter interface { 15 | GetRate() float64 16 | IsReady() bool 17 | GetWindowSize() time.Duration 18 | middleware.Observer 19 | } 20 | 21 | // Predicate that helps to see if the attempt resulted in error 22 | type FailPredicate 
func(request.Attempt) bool 23 | 24 | func IsNetworkError(attempt request.Attempt) bool { 25 | return attempt != nil && attempt.GetError() != nil 26 | } 27 | 28 | // Calculates various performance metrics about the endpoint using counters of the predefined size 29 | type RollingMeter struct { 30 | endpoint endpoint.Endpoint 31 | isError FailPredicate 32 | 33 | errors *RollingCounter 34 | successes *RollingCounter 35 | } 36 | 37 | func NewRollingMeter(endpoint endpoint.Endpoint, buckets int, resolution time.Duration, timeProvider timetools.TimeProvider, isError FailPredicate) (*RollingMeter, error) { 38 | if endpoint == nil { 39 | return nil, fmt.Errorf("Select an endpoint") 40 | } 41 | if isError == nil { 42 | isError = IsNetworkError 43 | } 44 | 45 | e, err := NewRollingCounter(buckets, resolution, timeProvider) 46 | if err != nil { 47 | return nil, err 48 | } 49 | 50 | s, err := NewRollingCounter(buckets, resolution, timeProvider) 51 | if err != nil { 52 | return nil, err 53 | } 54 | 55 | return &RollingMeter{ 56 | endpoint: endpoint, 57 | errors: e, 58 | successes: s, 59 | isError: isError, 60 | }, nil 61 | } 62 | 63 | func (r *RollingMeter) Reset() { 64 | r.errors.Reset() 65 | r.successes.Reset() 66 | } 67 | 68 | func (r *RollingMeter) IsReady() bool { 69 | return r.errors.countedBuckets+r.successes.countedBuckets >= len(r.errors.values) 70 | } 71 | 72 | func (r *RollingMeter) SuccessCount() int64 { 73 | return r.successes.Count() 74 | } 75 | 76 | func (r *RollingMeter) FailureCount() int64 { 77 | return r.errors.Count() 78 | } 79 | 80 | func (r *RollingMeter) Resolution() time.Duration { 81 | return r.errors.Resolution() 82 | } 83 | 84 | func (r *RollingMeter) Buckets() int { 85 | return r.errors.Buckets() 86 | } 87 | 88 | func (r *RollingMeter) GetWindowSize() time.Duration { 89 | return r.errors.GetWindowSize() 90 | } 91 | 92 | func (r *RollingMeter) ProcessedCount() int64 { 93 | return r.SuccessCount() + r.FailureCount() 94 | } 95 | 96 | func (r 
*RollingMeter) GetRate() float64 { 97 | success := r.SuccessCount() 98 | failure := r.FailureCount() 99 | // No data, return ok 100 | if success+failure == 0 { 101 | return 0 102 | } 103 | return float64(failure) / float64(success+failure) 104 | } 105 | 106 | func (r *RollingMeter) ObserveRequest(request.Request) { 107 | } 108 | 109 | func (r *RollingMeter) ObserveResponse(req request.Request, lastAttempt request.Attempt) { 110 | if lastAttempt == nil || lastAttempt.GetEndpoint() != r.endpoint { 111 | return 112 | } 113 | 114 | if r.isError(lastAttempt) { 115 | r.errors.Inc() 116 | } else { 117 | r.successes.Inc() 118 | } 119 | } 120 | 121 | type TestMeter struct { 122 | Rate float64 123 | NotReady bool 124 | WindowSize time.Duration 125 | } 126 | 127 | func (tm *TestMeter) GetWindowSize() time.Duration { 128 | return tm.WindowSize 129 | } 130 | 131 | func (tm *TestMeter) IsReady() bool { 132 | return !tm.NotReady 133 | } 134 | 135 | func (tm *TestMeter) GetRate() float64 { 136 | return tm.Rate 137 | } 138 | 139 | func (em *TestMeter) ObserveRequest(r request.Request) { 140 | } 141 | 142 | func (em *TestMeter) ObserveResponse(r request.Request, lastAttempt request.Attempt) { 143 | } 144 | -------------------------------------------------------------------------------- /metrics/failrate_test.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | "time" 7 | 8 | "github.com/mailgun/timetools" 9 | . "github.com/mailgun/vulcan/endpoint" 10 | . "github.com/mailgun/vulcan/request" 11 | 12 | . 
"gopkg.in/check.v1" 13 | ) 14 | 15 | func TestFailrate(t *testing.T) { TestingT(t) } 16 | 17 | type FailRateSuite struct { 18 | tm *timetools.FreezedTime 19 | } 20 | 21 | var _ = Suite(&FailRateSuite{}) 22 | 23 | func (s *FailRateSuite) SetUpSuite(c *C) { 24 | s.tm = &timetools.FreezedTime{ 25 | CurrentTime: time.Date(2012, 3, 4, 5, 6, 7, 0, time.UTC), 26 | } 27 | } 28 | 29 | func (s *FailRateSuite) TestInvalidParams(c *C) { 30 | e := MustParseUrl("http://localhost:5000") 31 | 32 | // Invalid endpoint 33 | _, err := NewRollingMeter(nil, 10, time.Second, s.tm, nil) 34 | c.Assert(err, Not(IsNil)) 35 | 36 | // Bad buckets count 37 | _, err = NewRollingMeter(e, 0, time.Second, s.tm, nil) 38 | c.Assert(err, Not(IsNil)) 39 | 40 | // Too precise resolution 41 | _, err = NewRollingMeter(e, 10, time.Millisecond, s.tm, nil) 42 | c.Assert(err, Not(IsNil)) 43 | } 44 | 45 | func (s *FailRateSuite) TestNotReady(c *C) { 46 | e := MustParseUrl("http://localhost:5000") 47 | 48 | // No data 49 | fr, err := NewRollingMeter(e, 10, time.Second, s.tm, nil) 50 | c.Assert(err, IsNil) 51 | c.Assert(fr.IsReady(), Equals, false) 52 | c.Assert(fr.GetRate(), Equals, 0.0) 53 | 54 | // Not enough data 55 | fr, err = NewRollingMeter(e, 10, time.Second, s.tm, nil) 56 | c.Assert(err, IsNil) 57 | fr.ObserveResponse(makeFailRequest(e)) 58 | c.Assert(fr.IsReady(), Equals, false) 59 | } 60 | 61 | // Make sure we don't count the stats from the endpoints we don't care or requests with no attempts 62 | func (s *FailRateSuite) TestIgnoreOtherEndpoints(c *C) { 63 | e := MustParseUrl("http://localhost:5000") 64 | e2 := MustParseUrl("http://localhost:5001") 65 | 66 | fr, err := NewRollingMeter(e, 1, time.Second, s.tm, nil) 67 | c.Assert(err, IsNil) 68 | fr.ObserveResponse(makeFailRequest(e)) 69 | fr.ObserveResponse(makeFailRequest(e2)) 70 | 71 | c.Assert(fr.IsReady(), Equals, true) 72 | c.Assert(fr.GetRate(), Equals, 1.0) 73 | } 74 | 75 | func (s *FailRateSuite) TestIgnoreRequestsWithoutAttempts(c *C) { 76 | 
e := MustParseUrl("http://localhost:5000") 77 | 78 | fr, err := NewRollingMeter(e, 1, time.Second, s.tm, nil) 79 | c.Assert(err, IsNil) 80 | fr.ObserveResponse(makeFailRequest(e)) 81 | fr.ObserveResponse(&BaseRequest{}, nil) 82 | 83 | c.Assert(fr.IsReady(), Equals, true) 84 | c.Assert(fr.GetRate(), Equals, 1.0) 85 | } 86 | 87 | func (s *FailRateSuite) TestNoSuccesses(c *C) { 88 | e := MustParseUrl("http://localhost:5000") 89 | 90 | fr, err := NewRollingMeter(e, 1, time.Second, s.tm, nil) 91 | c.Assert(err, IsNil) 92 | fr.ObserveResponse(makeFailRequest(e)) 93 | 94 | c.Assert(fr.IsReady(), Equals, true) 95 | c.Assert(fr.GetRate(), Equals, 1.0) 96 | } 97 | 98 | func (s *FailRateSuite) TestNoFailures(c *C) { 99 | e := MustParseUrl("http://localhost:5000") 100 | 101 | fr, err := NewRollingMeter(e, 1, time.Second, s.tm, nil) 102 | c.Assert(err, IsNil) 103 | fr.ObserveResponse(makeOkRequest(e)) 104 | 105 | c.Assert(fr.IsReady(), Equals, true) 106 | c.Assert(fr.GetRate(), Equals, 0.0) 107 | } 108 | 109 | // Make sure that data is properly calculated over several buckets 110 | func (s *FailRateSuite) TestMultipleBuckets(c *C) { 111 | e := MustParseUrl("http://localhost:5000") 112 | 113 | fr, err := NewRollingMeter(e, 3, time.Second, s.tm, nil) 114 | c.Assert(err, IsNil) 115 | 116 | fr.ObserveResponse(makeOkRequest(e)) 117 | 118 | s.tm.CurrentTime = s.tm.CurrentTime.Add(time.Second) 119 | fr.ObserveResponse(makeFailRequest(e)) 120 | 121 | s.tm.CurrentTime = s.tm.CurrentTime.Add(time.Second) 122 | fr.ObserveResponse(makeFailRequest(e)) 123 | 124 | c.Assert(fr.IsReady(), Equals, true) 125 | c.Assert(fr.GetRate(), Equals, float64(2)/float64(3)) 126 | } 127 | 128 | // Make sure that data is properly calculated over several buckets 129 | // When we overwrite old data when the window is rolling 130 | func (s *FailRateSuite) TestOverwriteBuckets(c *C) { 131 | e := MustParseUrl("http://localhost:5000") 132 | 133 | fr, err := NewRollingMeter(e, 3, time.Second, s.tm, nil) 134 | 
c.Assert(err, IsNil) 135 | 136 | fr.ObserveResponse(makeOkRequest(e)) 137 | 138 | s.tm.CurrentTime = s.tm.CurrentTime.Add(time.Second) 139 | fr.ObserveResponse(makeFailRequest(e)) 140 | 141 | s.tm.CurrentTime = s.tm.CurrentTime.Add(time.Second) 142 | fr.ObserveResponse(makeFailRequest(e)) 143 | 144 | // This time we should overwrite the old data points 145 | s.tm.CurrentTime = s.tm.CurrentTime.Add(time.Second) 146 | fr.ObserveResponse(makeFailRequest(e)) 147 | fr.ObserveResponse(makeOkRequest(e)) 148 | fr.ObserveResponse(makeOkRequest(e)) 149 | 150 | c.Assert(fr.IsReady(), Equals, true) 151 | c.Assert(fr.GetRate(), Equals, float64(3)/float64(5)) 152 | } 153 | 154 | // Make sure we cleanup the data after periods of inactivity 155 | // So it does not mess up the stats 156 | func (s *FailRateSuite) TestInactiveBuckets(c *C) { 157 | e := MustParseUrl("http://localhost:5000") 158 | 159 | fr, err := NewRollingMeter(e, 3, time.Second, s.tm, nil) 160 | c.Assert(err, IsNil) 161 | 162 | fr.ObserveResponse(makeOkRequest(e)) 163 | 164 | s.tm.CurrentTime = s.tm.CurrentTime.Add(time.Second) 165 | fr.ObserveResponse(makeFailRequest(e)) 166 | 167 | s.tm.CurrentTime = s.tm.CurrentTime.Add(time.Second) 168 | fr.ObserveResponse(makeFailRequest(e)) 169 | 170 | // This time we should overwrite the old data points with new data 171 | s.tm.CurrentTime = s.tm.CurrentTime.Add(time.Second) 172 | fr.ObserveResponse(makeFailRequest(e)) 173 | fr.ObserveResponse(makeOkRequest(e)) 174 | fr.ObserveResponse(makeOkRequest(e)) 175 | 176 | // Jump to the last bucket and change the data 177 | s.tm.CurrentTime = s.tm.CurrentTime.Add(time.Second * 2) 178 | fr.ObserveResponse(makeOkRequest(e)) 179 | 180 | c.Assert(fr.IsReady(), Equals, true) 181 | c.Assert(fr.GetRate(), Equals, float64(1)/float64(4)) 182 | } 183 | 184 | func (s *FailRateSuite) TestLongPeriodsOfInactivity(c *C) { 185 | e := MustParseUrl("http://localhost:5000") 186 | 187 | fr, err := NewRollingMeter(e, 2, time.Second, s.tm, nil) 188 | 
c.Assert(err, IsNil) 189 | 190 | fr.ObserveResponse(makeOkRequest(e)) 191 | 192 | s.tm.CurrentTime = s.tm.CurrentTime.Add(time.Second) 193 | fr.ObserveResponse(makeFailRequest(e)) 194 | 195 | c.Assert(fr.IsReady(), Equals, true) 196 | c.Assert(fr.GetRate(), Equals, 0.5) 197 | 198 | // This time we should overwrite all data points 199 | s.tm.CurrentTime = s.tm.CurrentTime.Add(100 * time.Second) 200 | fr.ObserveResponse(makeFailRequest(e)) 201 | c.Assert(fr.GetRate(), Equals, 1.0) 202 | } 203 | 204 | func (s *FailRateSuite) TestReset(c *C) { 205 | e := MustParseUrl("http://localhost:5000") 206 | 207 | fr, err := NewRollingMeter(e, 1, time.Second, s.tm, nil) 208 | c.Assert(err, IsNil) 209 | 210 | fr.ObserveResponse(makeOkRequest(e)) 211 | fr.ObserveResponse(makeFailRequest(e)) 212 | 213 | c.Assert(fr.IsReady(), Equals, true) 214 | c.Assert(fr.GetRate(), Equals, 0.5) 215 | 216 | // Reset the counter 217 | fr.Reset() 218 | c.Assert(fr.IsReady(), Equals, false) 219 | 220 | // Now add some stats 221 | fr.ObserveResponse(makeFailRequest(e)) 222 | fr.ObserveResponse(makeFailRequest(e)) 223 | 224 | // We are game again! 
225 | c.Assert(fr.IsReady(), Equals, true) 226 | c.Assert(fr.GetRate(), Equals, 1.0) 227 | } 228 | 229 | func makeRequest(endpoint Endpoint, err error) Request { 230 | return &BaseRequest{ 231 | Attempts: []Attempt{ 232 | &BaseAttempt{ 233 | Error: err, 234 | Endpoint: endpoint, 235 | }, 236 | }, 237 | } 238 | } 239 | 240 | func makeFailRequest(endpoint Endpoint) (Request, Attempt) { 241 | r := makeRequest(endpoint, fmt.Errorf("Oops")) 242 | return r, r.GetLastAttempt() 243 | } 244 | 245 | func makeOkRequest(endpoint Endpoint) (Request, Attempt) { 246 | r := makeRequest(endpoint, nil) 247 | return r, r.GetLastAttempt() 248 | } 249 | -------------------------------------------------------------------------------- /metrics/histogram.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/codahale/hdrhistogram" 8 | "github.com/mailgun/timetools" 9 | ) 10 | 11 | type Histogram interface { 12 | // Returns latency at quantile with microsecond precision 13 | LatencyAtQuantile(float64) time.Duration 14 | // Records latencies with microsecond precision 15 | RecordLatencies(d time.Duration, n int64) error 16 | 17 | ValueAtQuantile(q float64) int64 18 | RecordValues(v, n int64) error 19 | // Merge updates this histogram with values of another histogram 20 | Merge(Histogram) error 21 | // Resets state of the histogram 22 | Reset() 23 | } 24 | 25 | // RollingHistogram holds multiple histograms and rotates every period. 26 | // It provides resulting histogram as a result of a call of 'Merged' function. 
27 | type RollingHistogram interface { 28 | RecordValues(v, n int64) error 29 | RecordLatencies(d time.Duration, n int64) error 30 | Merged() (Histogram, error) 31 | Reset() 32 | } 33 | 34 | // NewHistogramFn is a constructor that can be passed to NewRollingHistogram 35 | type NewHistogramFn func() (Histogram, error) 36 | 37 | // NewHDRHistogramFn creates a constructor of HDR histograms with predefined parameters. 38 | func NewHDRHistogramFn(low, high int64, sigfigs int) NewHistogramFn { 39 | return func() (Histogram, error) { 40 | return NewHDRHistogram(low, high, sigfigs) 41 | } 42 | } 43 | 44 | type rollingHistogram struct { 45 | maker NewHistogramFn 46 | idx int 47 | lastRoll time.Time 48 | period time.Duration 49 | buckets []Histogram 50 | timeProvider timetools.TimeProvider 51 | } 52 | 53 | func NewRollingHistogram(maker NewHistogramFn, bucketCount int, period time.Duration, timeProvider timetools.TimeProvider) (RollingHistogram, error) { 54 | buckets := make([]Histogram, bucketCount) 55 | for i := range buckets { 56 | h, err := maker() 57 | if err != nil { 58 | return nil, err 59 | } 60 | buckets[i] = h 61 | } 62 | 63 | return &rollingHistogram{ 64 | maker: maker, 65 | buckets: buckets, 66 | period: period, 67 | timeProvider: timeProvider, 68 | }, nil 69 | } 70 | 71 | func (r *rollingHistogram) Reset() { 72 | r.idx = 0 73 | r.lastRoll = r.timeProvider.UtcNow() 74 | for _, b := range r.buckets { 75 | b.Reset() 76 | } 77 | } 78 | 79 | func (r *rollingHistogram) rotate() { 80 | r.idx = (r.idx + 1) % len(r.buckets) 81 | r.buckets[r.idx].Reset() 82 | } 83 | 84 | func (r *rollingHistogram) Merged() (Histogram, error) { 85 | m, err := r.maker() 86 | if err != nil { 87 | return m, err 88 | } 89 | for _, h := range r.buckets { 90 | if m.Merge(h); err != nil { 91 | return nil, err 92 | } 93 | } 94 | return m, nil 95 | } 96 | 97 | func (r *rollingHistogram) getHist() Histogram { 98 | if r.timeProvider.UtcNow().Sub(r.lastRoll) >= r.period { 99 | r.rotate() 100 | 
r.lastRoll = r.timeProvider.UtcNow() 101 | } 102 | return r.buckets[r.idx] 103 | } 104 | 105 | func (r *rollingHistogram) RecordLatencies(v time.Duration, n int64) error { 106 | return r.getHist().RecordLatencies(v, n) 107 | } 108 | 109 | func (r *rollingHistogram) RecordValues(v, n int64) error { 110 | return r.getHist().RecordValues(v, n) 111 | } 112 | 113 | type HDRHistogram struct { 114 | // lowest trackable value 115 | low int64 116 | // highest trackable value 117 | high int64 118 | // significant figures 119 | sigfigs int 120 | 121 | h *hdrhistogram.Histogram 122 | } 123 | 124 | func NewHDRHistogram(low, high int64, sigfigs int) (h *HDRHistogram, err error) { 125 | defer func() { 126 | if msg := recover(); msg != nil { 127 | err = fmt.Errorf("%s", msg) 128 | } 129 | }() 130 | 131 | hdr := hdrhistogram.New(low, high, sigfigs) 132 | h = &HDRHistogram{ 133 | low: low, 134 | high: high, 135 | sigfigs: sigfigs, 136 | h: hdr, 137 | } 138 | return h, err 139 | } 140 | 141 | // Returns latency at quantile with microsecond precision 142 | func (h *HDRHistogram) LatencyAtQuantile(q float64) time.Duration { 143 | return time.Duration(h.ValueAtQuantile(q)) * time.Microsecond 144 | } 145 | 146 | // Records latencies with microsecond precision 147 | func (h *HDRHistogram) RecordLatencies(d time.Duration, n int64) error { 148 | return h.RecordValues(int64(d/time.Microsecond), n) 149 | } 150 | 151 | func (h *HDRHistogram) Reset() { 152 | h.h.Reset() 153 | } 154 | 155 | func (h *HDRHistogram) ValueAtQuantile(q float64) int64 { 156 | return h.h.ValueAtQuantile(q) 157 | } 158 | 159 | func (h *HDRHistogram) RecordValues(v, n int64) error { 160 | return h.h.RecordValues(v, n) 161 | } 162 | 163 | func (h *HDRHistogram) Merge(o Histogram) error { 164 | other, ok := o.(*HDRHistogram) 165 | if !ok { 166 | return fmt.Errorf("can merge only with other HDRHistogram, got %T", o) 167 | } 168 | 169 | h.h.Merge(other.h) 170 | return nil 171 | } 172 | 
-------------------------------------------------------------------------------- /metrics/histogram_test.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/mailgun/timetools" 7 | . "gopkg.in/check.v1" 8 | ) 9 | 10 | type HistogramSuite struct { 11 | tm *timetools.FreezedTime 12 | } 13 | 14 | var _ = Suite(&HistogramSuite{}) 15 | 16 | func (s *HistogramSuite) SetUpSuite(c *C) { 17 | s.tm = &timetools.FreezedTime{ 18 | CurrentTime: time.Date(2012, 3, 4, 5, 6, 7, 0, time.UTC), 19 | } 20 | } 21 | 22 | func (s *HistogramSuite) TestMerge(c *C) { 23 | a, err := NewHDRHistogram(1, 3600000, 2) 24 | c.Assert(err, IsNil) 25 | 26 | a.RecordValues(1, 2) 27 | 28 | b, err := NewHDRHistogram(1, 3600000, 2) 29 | c.Assert(err, IsNil) 30 | 31 | b.RecordValues(2, 1) 32 | 33 | c.Assert(a.Merge(b), IsNil) 34 | 35 | c.Assert(a.ValueAtQuantile(50), Equals, int64(1)) 36 | c.Assert(a.ValueAtQuantile(100), Equals, int64(2)) 37 | } 38 | 39 | func (s *HistogramSuite) TestInvalidParams(c *C) { 40 | _, err := NewHDRHistogram(1, 3600000, 0) 41 | c.Assert(err, NotNil) 42 | } 43 | 44 | func (s *HistogramSuite) TestMergeNil(c *C) { 45 | a, err := NewHDRHistogram(1, 3600000, 1) 46 | c.Assert(err, IsNil) 47 | 48 | c.Assert(a.Merge(nil), NotNil) 49 | } 50 | 51 | func (s *HistogramSuite) TestRotation(c *C) { 52 | h, err := NewRollingHistogram( 53 | NewHDRHistogramFn(1, 3600000, 3), 54 | 2, // 2 histograms in a window 55 | time.Second, // 1 second is a rolling period 56 | s.tm) 57 | 58 | c.Assert(err, IsNil) 59 | c.Assert(h, NotNil) 60 | 61 | h.RecordValues(5, 1) 62 | 63 | m, err := h.Merged() 64 | c.Assert(err, IsNil) 65 | c.Assert(m.ValueAtQuantile(100), Equals, int64(5)) 66 | 67 | s.tm.CurrentTime = s.tm.CurrentTime.Add(time.Second) 68 | h.RecordValues(2, 1) 69 | h.RecordValues(1, 1) 70 | 71 | m, err = h.Merged() 72 | c.Assert(err, IsNil) 73 | c.Assert(m.ValueAtQuantile(100), Equals, int64(5)) 74 | 75 
| // rotate, this means that the old value would evaporate 76 | s.tm.CurrentTime = s.tm.CurrentTime.Add(time.Second) 77 | h.RecordValues(1, 1) 78 | m, err = h.Merged() 79 | c.Assert(err, IsNil) 80 | c.Assert(m.ValueAtQuantile(100), Equals, int64(2)) 81 | } 82 | 83 | func (s *HistogramSuite) TestReset(c *C) { 84 | h, err := NewRollingHistogram( 85 | NewHDRHistogramFn(1, 3600000, 3), 86 | 2, // 2 histograms in a window 87 | time.Second, // 1 second is a rolling period 88 | s.tm) 89 | 90 | c.Assert(err, IsNil) 91 | c.Assert(h, NotNil) 92 | 93 | h.RecordValues(5, 1) 94 | 95 | m, err := h.Merged() 96 | c.Assert(err, IsNil) 97 | c.Assert(m.ValueAtQuantile(100), Equals, int64(5)) 98 | 99 | s.tm.CurrentTime = s.tm.CurrentTime.Add(time.Second) 100 | h.RecordValues(2, 1) 101 | h.RecordValues(1, 1) 102 | 103 | m, err = h.Merged() 104 | c.Assert(err, IsNil) 105 | c.Assert(m.ValueAtQuantile(100), Equals, int64(5)) 106 | 107 | h.Reset() 108 | 109 | h.RecordValues(5, 1) 110 | 111 | m, err = h.Merged() 112 | c.Assert(err, IsNil) 113 | c.Assert(m.ValueAtQuantile(100), Equals, int64(5)) 114 | 115 | s.tm.CurrentTime = s.tm.CurrentTime.Add(time.Second) 116 | h.RecordValues(2, 1) 117 | h.RecordValues(1, 1) 118 | 119 | m, err = h.Merged() 120 | c.Assert(err, IsNil) 121 | c.Assert(m.ValueAtQuantile(100), Equals, int64(5)) 122 | 123 | } 124 | -------------------------------------------------------------------------------- /metrics/roundtrip.go: --------------------------------------------------------------------------------
package metrics

import (
	"time"

	"github.com/mailgun/log"
	"github.com/mailgun/timetools"
	"github.com/mailgun/vulcan/request"
)

// RoundTripMetrics provides aggregated performance metrics for HTTP request processing:
// round trip latency, per-response-code counters, network error and total request counts.
// All counters are rolling window counters with a configured precision, and the latency
// histogram is a rolling window histogram with a configured precision as well.
// See RoundTripOptions for details on the parameters.
//
// NOTE(review): no locking is visible in this type; statusCodes is a plain map written by
// RecordMetrics, so concurrent use presumably must be serialized by the caller — confirm.
type RoundTripMetrics struct {
	o           *RoundTripOptions       // options with defaults applied
	total       *RollingCounter         // incremented once per recorded attempt
	netErrors   *RollingCounter         // attempts classified by IsNetworkError
	statusCodes map[int]*RollingCounter // per-status-code counters, created on first use
	histogram   RollingHistogram        // latency histogram
}

// RoundTripOptions configures a RoundTripMetrics. A zero value in any field selects
// the default documented for that field.
type RoundTripOptions struct {
	// CounterBuckets is how many buckets to allocate for each rolling counter. Defaults to 10 buckets.
	CounterBuckets int
	// CounterResolution is the time span covered by a single counter bucket
	// (e.g. time.Second means that a bucket counts events for one second).
	// Defaults to time.Second.
	CounterResolution time.Duration
	// HistMin is the minimum non-zero value a histogram can record (default 1).
	HistMin int64
	// HistMax is the maximum value a histogram can record (default 3,600,000,000,
	// i.e. one hour in microseconds).
	HistMax int64
	// HistSignificantFigures defines the precision of recorded values, e.g. 3 gives 0.1% precision.
	// Default is 2, i.e. 1% precision.
	HistSignificantFigures int
	// HistBuckets is how many sub-histograms to keep in the rolling histogram, default is 6.
	HistBuckets int
	// HistPeriod is the rotation period of the rolling histogram, default is 10 seconds.
	HistPeriod time.Duration
	// TimeProvider supplies the current time, e.g. a frozen clock in tests. Default is RealTime.
	TimeProvider timetools.TimeProvider
}

// NewRoundTripMetrics returns a new metrics collector configured by o, with zero-valued
// fields of o replaced by defaults. It fails if the histogram or a counter cannot be created.
46 | func NewRoundTripMetrics(o RoundTripOptions) (*RoundTripMetrics, error) { 47 | o = setDefaults(o) 48 | 49 | h, err := NewRollingHistogram( 50 | // this will create subhistograms 51 | NewHDRHistogramFn(o.HistMin, o.HistMax, o.HistSignificantFigures), 52 | // number of buckets in a rolling histogram 53 | o.HistBuckets, 54 | // rolling period for a histogram 55 | o.HistPeriod, 56 | o.TimeProvider) 57 | if err != nil { 58 | return nil, err 59 | } 60 | 61 | m := &RoundTripMetrics{ 62 | statusCodes: make(map[int]*RollingCounter), 63 | histogram: h, 64 | o: &o, 65 | } 66 | 67 | netErrors, err := m.newCounter() 68 | if err != nil { 69 | return nil, err 70 | } 71 | 72 | total, err := m.newCounter() 73 | if err != nil { 74 | return nil, err 75 | } 76 | 77 | m.netErrors = netErrors 78 | m.total = total 79 | return m, nil 80 | } 81 | 82 | // GetOptions returns settings used for this instance 83 | func (m *RoundTripMetrics) GetOptions() *RoundTripOptions { 84 | return m.o 85 | } 86 | 87 | // GetNetworkErrorRatio calculates the amont of network errors such as time outs and dropped connection 88 | // that occured in the given time window compared to the total requests count. 89 | func (m *RoundTripMetrics) GetNetworkErrorRatio() float64 { 90 | if m.total.Count() == 0 { 91 | return 0 92 | } 93 | return float64(m.netErrors.Count()) / float64(m.total.Count()) 94 | } 95 | 96 | // GetResponseCodeRatio calculates ratio of count(startA to endA) / count(startB to endB) 97 | func (m *RoundTripMetrics) GetResponseCodeRatio(startA, endA, startB, endB int) float64 { 98 | a := int64(0) 99 | b := int64(0) 100 | for code, v := range m.statusCodes { 101 | if code < endA && code >= startA { 102 | a += v.Count() 103 | } 104 | if code < endB && code >= startB { 105 | b += v.Count() 106 | } 107 | } 108 | if b != 0 { 109 | return float64(a) / float64(b) 110 | } 111 | return 0 112 | } 113 | 114 | // RecordMetrics updates internal metrics collection based on the data from passed request. 
115 | func (m *RoundTripMetrics) RecordMetrics(a request.Attempt) { 116 | m.total.Inc() 117 | m.recordNetError(a) 118 | m.recordLatency(a) 119 | m.recordStatusCode(a) 120 | } 121 | 122 | // GetTotalCount returns total count of processed requests collected. 123 | func (m *RoundTripMetrics) GetTotalCount() int64 { 124 | return m.total.Count() 125 | } 126 | 127 | // GetNetworkErrorCount returns total count of processed requests observed 128 | func (m *RoundTripMetrics) GetNetworkErrorCount() int64 { 129 | return m.netErrors.Count() 130 | } 131 | 132 | // GetStatusCodesCounts returns map with counts of the response codes 133 | func (m *RoundTripMetrics) GetStatusCodesCounts() map[int]int64 { 134 | sc := make(map[int]int64) 135 | for k, v := range m.statusCodes { 136 | if v.Count() != 0 { 137 | sc[k] = v.Count() 138 | } 139 | } 140 | return sc 141 | } 142 | 143 | // GetLatencyHistogram computes and returns resulting histogram with latencies observed. 144 | func (m *RoundTripMetrics) GetLatencyHistogram() (Histogram, error) { 145 | return m.histogram.Merged() 146 | } 147 | 148 | func (m *RoundTripMetrics) Reset() { 149 | m.histogram.Reset() 150 | m.total.Reset() 151 | m.netErrors.Reset() 152 | m.statusCodes = make(map[int]*RollingCounter) 153 | } 154 | 155 | func (m *RoundTripMetrics) newCounter() (*RollingCounter, error) { 156 | return NewRollingCounter(m.o.CounterBuckets, m.o.CounterResolution, m.o.TimeProvider) 157 | } 158 | 159 | func (m *RoundTripMetrics) recordNetError(a request.Attempt) { 160 | if IsNetworkError(a) { 161 | m.netErrors.Inc() 162 | } 163 | } 164 | 165 | func (m *RoundTripMetrics) recordLatency(a request.Attempt) { 166 | if err := m.histogram.RecordLatencies(a.GetDuration(), 1); err != nil { 167 | log.Errorf("Failed to record latency: %v", err) 168 | } 169 | } 170 | 171 | func (m *RoundTripMetrics) recordStatusCode(a request.Attempt) { 172 | if a.GetResponse() == nil { 173 | return 174 | } 175 | statusCode := a.GetResponse().StatusCode 176 | if c, ok 
:= m.statusCodes[statusCode]; ok { 177 | c.Inc() 178 | return 179 | } 180 | c, err := m.newCounter() 181 | if err != nil { 182 | log.Errorf("failed to create a counter: %v", err) 183 | return 184 | } 185 | c.Inc() 186 | m.statusCodes[statusCode] = c 187 | } 188 | 189 | const ( 190 | counterBuckets = 10 191 | counterResolution = time.Second 192 | histMin = 1 193 | histMax = 3600000000 // 1 hour in microseconds 194 | histSignificantFigures = 2 // signigicant figures (1% precision) 195 | histBuckets = 6 // number of sub-histograms in a rolling histogram 196 | histPeriod = 10 * time.Second // roll time 197 | ) 198 | 199 | func setDefaults(o RoundTripOptions) RoundTripOptions { 200 | if o.CounterBuckets == 0 { 201 | o.CounterBuckets = counterBuckets 202 | } 203 | if o.CounterResolution == 0 { 204 | o.CounterResolution = time.Second 205 | } 206 | if o.HistMin == 0 { 207 | o.HistMin = histMin 208 | } 209 | if o.HistMax == 0 { 210 | o.HistMax = histMax 211 | } 212 | if o.HistBuckets == 0 { 213 | o.HistBuckets = histBuckets 214 | } 215 | if o.HistSignificantFigures == 0 { 216 | o.HistSignificantFigures = histSignificantFigures 217 | } 218 | if o.HistPeriod == 0 { 219 | o.HistPeriod = histPeriod 220 | } 221 | if o.TimeProvider == nil { 222 | o.TimeProvider = &timetools.RealTime{} 223 | } 224 | return o 225 | } 226 | -------------------------------------------------------------------------------- /metrics/rr_test.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "time" 7 | 8 | "github.com/mailgun/timetools" 9 | "github.com/mailgun/vulcan/request" 10 | . 
"gopkg.in/check.v1" 11 | ) 12 | 13 | type RRSuite struct { 14 | tm *timetools.FreezedTime 15 | } 16 | 17 | var _ = Suite(&RRSuite{}) 18 | 19 | func (s *RRSuite) SetUpSuite(c *C) { 20 | s.tm = &timetools.FreezedTime{ 21 | CurrentTime: time.Date(2012, 3, 4, 5, 6, 7, 0, time.UTC), 22 | } 23 | } 24 | 25 | func (s *RRSuite) TestDefaults(c *C) { 26 | rr, err := NewRoundTripMetrics(RoundTripOptions{TimeProvider: s.tm}) 27 | c.Assert(err, IsNil) 28 | c.Assert(rr, NotNil) 29 | 30 | rr.RecordMetrics(makeAttempt(O{err: fmt.Errorf("o"), duration: time.Second})) 31 | rr.RecordMetrics(makeAttempt(O{statusCode: 500, duration: 2 * time.Second})) 32 | rr.RecordMetrics(makeAttempt(O{statusCode: 200, duration: time.Second})) 33 | rr.RecordMetrics(makeAttempt(O{statusCode: 200, duration: time.Second})) 34 | 35 | c.Assert(rr.GetNetworkErrorCount(), Equals, int64(1)) 36 | c.Assert(rr.GetTotalCount(), Equals, int64(4)) 37 | c.Assert(rr.GetStatusCodesCounts(), DeepEquals, map[int]int64{500: 1, 200: 2}) 38 | c.Assert(rr.GetNetworkErrorRatio(), Equals, float64(1)/float64(4)) 39 | c.Assert(rr.GetResponseCodeRatio(500, 501, 200, 300), Equals, 0.5) 40 | 41 | h, err := rr.GetLatencyHistogram() 42 | c.Assert(err, IsNil) 43 | c.Assert(int(h.LatencyAtQuantile(100)/time.Second), Equals, 2) 44 | 45 | rr.Reset() 46 | c.Assert(rr.GetNetworkErrorCount(), Equals, int64(0)) 47 | c.Assert(rr.GetTotalCount(), Equals, int64(0)) 48 | c.Assert(rr.GetStatusCodesCounts(), DeepEquals, map[int]int64{}) 49 | c.Assert(rr.GetNetworkErrorRatio(), Equals, float64(0)) 50 | c.Assert(rr.GetResponseCodeRatio(500, 501, 200, 300), Equals, float64(0)) 51 | 52 | h, err = rr.GetLatencyHistogram() 53 | c.Assert(err, IsNil) 54 | c.Assert(h.LatencyAtQuantile(100), Equals, time.Duration(0)) 55 | } 56 | 57 | func makeAttempt(o O) *request.BaseAttempt { 58 | a := &request.BaseAttempt{ 59 | Error: o.err, 60 | Duration: o.duration, 61 | } 62 | if o.statusCode != 0 { 63 | a.Response = &http.Response{StatusCode: o.statusCode} 64 | } 65 
| return a 66 | } 67 | 68 | type O struct { 69 | statusCode int 70 | err error 71 | duration time.Duration 72 | } 73 | -------------------------------------------------------------------------------- /middleware/chain.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "fmt" 5 | . "github.com/mailgun/vulcan/request" 6 | "sort" 7 | "sync" 8 | ) 9 | 10 | // Middleware chain implements middleware interface and acts as a container 11 | // for multiple middlewares chained together in deterministic order. 12 | type MiddlewareChain struct { 13 | chain *chain 14 | } 15 | 16 | func NewMiddlewareChain() *MiddlewareChain { 17 | return &MiddlewareChain{ 18 | chain: newChain(), 19 | } 20 | } 21 | 22 | func (c *MiddlewareChain) Add(id string, priority int, m Middleware) error { 23 | return c.chain.append(id, priority, m) 24 | } 25 | 26 | func (c *MiddlewareChain) Upsert(id string, priority int, m Middleware) { 27 | c.chain.upsert(id, priority, m) 28 | } 29 | 30 | func (c *MiddlewareChain) Remove(id string) error { 31 | return c.chain.remove(id) 32 | } 33 | 34 | func (c *MiddlewareChain) Update(id string, priority int, m Middleware) error { 35 | return c.chain.update(id, priority, m) 36 | } 37 | 38 | func (c *MiddlewareChain) Get(id string) Middleware { 39 | m := c.chain.get(id) 40 | if m != nil { 41 | return m.(Middleware) 42 | } 43 | return nil 44 | } 45 | 46 | func (c *MiddlewareChain) GetIter() *MiddlewareIter { 47 | return &MiddlewareIter{ 48 | iter: c.chain.getIter(), 49 | } 50 | } 51 | 52 | type MiddlewareIter struct { 53 | iter *iter 54 | } 55 | 56 | func (m *MiddlewareIter) Next() Middleware { 57 | val := m.iter.next() 58 | if val == nil { 59 | return nil 60 | } 61 | return val.(Middleware) 62 | } 63 | 64 | func (m *MiddlewareIter) Prev() Middleware { 65 | val := m.iter.prev() 66 | if val == nil { 67 | return nil 68 | } 69 | return val.(Middleware) 70 | } 71 | 72 | type ObserverChain struct { 73 
| chain *chain 74 | } 75 | 76 | func NewObserverChain() *ObserverChain { 77 | return &ObserverChain{ 78 | chain: newChain(), 79 | } 80 | } 81 | 82 | func (c *ObserverChain) Add(id string, o Observer) error { 83 | return c.chain.append(id, 0, o) 84 | } 85 | 86 | func (c *ObserverChain) Upsert(id string, o Observer) { 87 | c.chain.upsert(id, 0, o) 88 | } 89 | 90 | func (c *ObserverChain) Remove(id string) error { 91 | return c.chain.remove(id) 92 | } 93 | 94 | func (c *ObserverChain) Update(id string, o Observer) error { 95 | return c.chain.update(id, 0, o) 96 | } 97 | 98 | func (c *ObserverChain) Get(id string) Observer { 99 | o := c.chain.get(id) 100 | if o != nil { 101 | return o.(Observer) 102 | } 103 | return nil 104 | } 105 | 106 | func (c *ObserverChain) ObserveRequest(r Request) { 107 | it := c.chain.getIter() 108 | for v := it.next(); v != nil; v = it.next() { 109 | v.(Observer).ObserveRequest(r) 110 | } 111 | } 112 | 113 | func (c *ObserverChain) ObserveResponse(r Request, a Attempt) { 114 | it := c.chain.getReverseIter() 115 | for v := it.next(); v != nil; v = it.next() { 116 | v.(Observer).ObserveResponse(r, a) 117 | } 118 | } 119 | 120 | // Map with guaranteed iteration order, in place updates that do not change the order 121 | // and iterator that does not hold locks 122 | type chain struct { 123 | mutex *sync.RWMutex 124 | callbacks []*callback 125 | indexes map[string]int // Indexes for in place updates 126 | iter *iter //current version of iterator 127 | } 128 | 129 | type callback struct { 130 | id string 131 | priority int 132 | cb interface{} 133 | } 134 | 135 | type callbacks []*callback 136 | 137 | func (c callbacks) Len() int { 138 | return len(c) 139 | } 140 | 141 | func (c callbacks) Less(i, j int) bool { 142 | return c[i].priority < c[j].priority 143 | } 144 | 145 | func (c callbacks) Swap(i, j int) { 146 | c[i], c[j] = c[j], c[i] 147 | } 148 | 149 | func newChain() *chain { 150 | return &chain{ 151 | mutex: &sync.RWMutex{}, 152 | callbacks: 
callbacks{}, 153 | } 154 | } 155 | 156 | func (c *chain) append(id string, priority int, cb interface{}) error { 157 | c.mutex.Lock() 158 | defer c.mutex.Unlock() 159 | 160 | if p, _ := c.find(id); p != nil { 161 | return fmt.Errorf("Callback with id: %s already exists", id) 162 | } 163 | c.callbacks = append(c.callbacks, &callback{id, priority, cb}) 164 | sort.Stable((callbacks)(c.callbacks)) 165 | return nil 166 | } 167 | 168 | func (c *chain) find(id string) (*callback, int) { 169 | for i, c := range c.callbacks { 170 | if c.id == id { 171 | return c, i 172 | } 173 | } 174 | return nil, -1 175 | } 176 | 177 | func (c *chain) update(id string, priority int, cb interface{}) error { 178 | c.mutex.Lock() 179 | defer c.mutex.Unlock() 180 | 181 | p, _ := c.find(id) 182 | if p == nil { 183 | return fmt.Errorf("Callback with id: %s not found", id) 184 | } 185 | p.cb = cb 186 | p.priority = priority 187 | sort.Stable((callbacks)(c.callbacks)) 188 | return nil 189 | } 190 | 191 | func (c *chain) upsert(id string, priority int, cb interface{}) { 192 | c.mutex.Lock() 193 | defer c.mutex.Unlock() 194 | 195 | p, _ := c.find(id) 196 | if p == nil { 197 | c.callbacks = append(c.callbacks, &callback{id, priority, cb}) 198 | } else { 199 | p.cb = cb 200 | p.priority = priority 201 | } 202 | sort.Stable((callbacks)(c.callbacks)) 203 | } 204 | 205 | func (c *chain) get(id string) interface{} { 206 | c.mutex.Lock() 207 | defer c.mutex.Unlock() 208 | 209 | p, _ := c.find(id) 210 | if p == nil { 211 | return nil 212 | } else { 213 | return p.cb 214 | } 215 | } 216 | 217 | func (c *chain) remove(id string) error { 218 | c.mutex.Lock() 219 | defer c.mutex.Unlock() 220 | 221 | p, i := c.find(id) 222 | if p == nil { 223 | return fmt.Errorf("Callback with id: %s not found", id) 224 | } 225 | c.callbacks = append(c.callbacks[:i], c.callbacks[i+1:]...) 
226 | sort.Stable((callbacks)(c.callbacks)) 227 | return nil 228 | } 229 | 230 | // Note that we hold read lock to get access to the current iterator 231 | func (c *chain) getIter() *iter { 232 | c.mutex.RLock() 233 | defer c.mutex.RUnlock() 234 | return newIter(c.callbacks) 235 | } 236 | 237 | func (c *chain) getReverseIter() *reverseIter { 238 | c.mutex.RLock() 239 | defer c.mutex.RUnlock() 240 | return &reverseIter{callbacks: c.callbacks} 241 | } 242 | 243 | func newIter(callbacks []*callback) *iter { 244 | return &iter{ 245 | index: -1, 246 | callbacks: callbacks, 247 | } 248 | } 249 | 250 | type iter struct { 251 | index int 252 | callbacks []*callback 253 | } 254 | 255 | func (it *iter) next() interface{} { 256 | if it.index >= len(it.callbacks) { 257 | return nil 258 | } 259 | it.index += 1 260 | if it.index >= len(it.callbacks) { 261 | return nil 262 | } 263 | return it.callbacks[it.index].cb 264 | } 265 | 266 | func (it *iter) prev() interface{} { 267 | if it.index < 0 { 268 | return nil 269 | } 270 | it.index -= 1 271 | if it.index < 0 { 272 | return nil 273 | } 274 | return it.callbacks[it.index].cb 275 | } 276 | 277 | type reverseIter struct { 278 | index int 279 | callbacks []*callback 280 | } 281 | 282 | func (it *reverseIter) next() interface{} { 283 | if it.index >= len(it.callbacks) { 284 | return nil 285 | } 286 | val := it.callbacks[len(it.callbacks)-it.index-1].cb 287 | it.index += 1 288 | return val 289 | } 290 | -------------------------------------------------------------------------------- /middleware/middleware.go: -------------------------------------------------------------------------------- 1 | // Middlewares can modify or intercept requests and responses 2 | package middleware 3 | 4 | import ( 5 | . 
"github.com/mailgun/vulcan/request" 6 | "net/http" 7 | ) 8 | 9 | // Middlewares are allowed to observe, modify and intercept http requests and responses 10 | type Middleware interface { 11 | // Called before the request is going to be proxied to the endpoint selected by the load balancer. 12 | // If it returns an error, request will be treated as erorrneous (e.g. failover will be initated). 13 | // If it returns a non nil response, proxy will return the response without proxying to the endpoint. 14 | // If it returns nil response and nil error request will be proxied to the upstream. 15 | // It's ok to modify request headers and body as a side effect of the funciton call. 16 | ProcessRequest(r Request) (*http.Response, error) 17 | 18 | // If request has been completed or intercepted by middleware and response has been received 19 | // attempt would contain non nil response or non nil error. 20 | ProcessResponse(r Request, a Attempt) 21 | } 22 | 23 | // Unlinke middlewares, observers are not able to intercept or change any requests 24 | // and will be called on every request to endpoint regardless of the middlewares side effects 25 | type Observer interface { 26 | // Will be called before every request to the endpoint 27 | ObserveRequest(r Request) 28 | 29 | // Will be called after every request to the endpoint 30 | ObserveResponse(r Request, a Attempt) 31 | } 32 | 33 | type ProcessRequestFn func(r Request) (*http.Response, error) 34 | type ProcessResponseFn func(r Request, a Attempt) 35 | 36 | // Wraps the functions to create a middleware compatible interface 37 | type MiddlewareWrapper struct { 38 | OnRequest ProcessRequestFn 39 | OnResponse ProcessResponseFn 40 | } 41 | 42 | func (cb *MiddlewareWrapper) ProcessRequest(r Request) (*http.Response, error) { 43 | if cb.OnRequest != nil { 44 | return cb.OnRequest(r) 45 | } 46 | return nil, nil 47 | } 48 | 49 | func (cb *MiddlewareWrapper) ProcessResponse(r Request, a Attempt) { 50 | if cb.OnResponse != nil { 51 | 
cb.OnResponse(r, a) 52 | } 53 | } 54 | 55 | type ObserveRequestFn func(r Request) 56 | type ObserveResponseFn func(r Request, a Attempt) 57 | 58 | // Wraps the functions to create a observer compatible interface 59 | type ObserverWrapper struct { 60 | OnRequest ObserveRequestFn 61 | OnResponse ObserveResponseFn 62 | } 63 | 64 | func (cb *ObserverWrapper) ObserveRequest(r Request) { 65 | if cb.OnRequest != nil { 66 | cb.OnRequest(r) 67 | } 68 | } 69 | 70 | func (cb *ObserverWrapper) ObserveResponse(r Request, a Attempt) { 71 | if cb.OnResponse != nil { 72 | cb.OnResponse(r, a) 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /netutils/buffer.go: -------------------------------------------------------------------------------- 1 | package netutils 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "io/ioutil" 8 | "os" 9 | ) 10 | 11 | // MultiReader provides Read, Close, Seek and TotalSize methods. In addition to that it supports WriterTo interface 12 | // to provide efficient writing schemes, as functions like io.Copy use WriterTo when it's available. 13 | type MultiReader interface { 14 | io.Reader 15 | io.Seeker 16 | io.Closer 17 | io.WriterTo 18 | 19 | // TotalSize calculates and returns the total size of the reader and not the length remaining. 20 | TotalSize() (int64, error) 21 | } 22 | 23 | const ( 24 | DefaultMemBufferBytes = 1048576 25 | DefaultMaxSizeBytes = -1 26 | // Equivalent of bytes.MinRead used in ioutil.ReadAll 27 | DefaultBufferBytes = 512 28 | ) 29 | 30 | // Constraints: 31 | // - Implements io.Reader 32 | // - Implements Seek(0, 0) 33 | // - Designed for Write once, Read many times. 
34 | type multiReaderSeek struct { 35 | length int64 36 | readers []io.ReadSeeker 37 | mr io.Reader 38 | cleanup CleanupFunc 39 | } 40 | 41 | type CleanupFunc func() error 42 | 43 | func NewMultiReaderSeeker(length int64, cleanup CleanupFunc, readers ...io.ReadSeeker) *multiReaderSeek { 44 | converted := make([]io.Reader, len(readers)) 45 | for i, r := range readers { 46 | // This conversion is safe as ReadSeeker includes Reader 47 | converted[i] = r.(io.Reader) 48 | } 49 | 50 | return &multiReaderSeek{ 51 | length: length, 52 | readers: readers, 53 | mr: io.MultiReader(converted...), 54 | cleanup: cleanup, 55 | } 56 | } 57 | 58 | func (mr *multiReaderSeek) Close() (err error) { 59 | if mr.cleanup != nil { 60 | return mr.cleanup() 61 | } 62 | return nil 63 | } 64 | 65 | func (mr *multiReaderSeek) WriteTo(w io.Writer) (int64, error) { 66 | b := make([]byte, DefaultBufferBytes) 67 | var total int64 68 | for { 69 | n, err := mr.mr.Read(b) 70 | // Recommended way is to always handle non 0 reads despite the errors 71 | if n > 0 { 72 | nw, errw := w.Write(b[:n]) 73 | total += int64(nw) 74 | // Write must return a non-nil error if it returns nw < n 75 | if nw != n || errw != nil { 76 | return total, errw 77 | } 78 | } 79 | if err != nil { 80 | if err == io.EOF { 81 | return total, nil 82 | } 83 | return total, err 84 | } 85 | } 86 | } 87 | 88 | func (mr *multiReaderSeek) Read(p []byte) (n int, err error) { 89 | return mr.mr.Read(p) 90 | } 91 | 92 | func (mr *multiReaderSeek) TotalSize() (int64, error) { 93 | return mr.length, nil 94 | } 95 | 96 | func (mr *multiReaderSeek) Seek(offset int64, whence int) (int64, error) { 97 | // TODO: implement other whence 98 | // TODO: implement real offsets 99 | 100 | if whence != 0 { 101 | return 0, fmt.Errorf("multiReaderSeek: unsupported whence") 102 | } 103 | 104 | if offset != 0 { 105 | return 0, fmt.Errorf("multiReaderSeek: unsupported offset") 106 | } 107 | 108 | for _, seeker := range mr.readers { 109 | seeker.Seek(0, 0) 110 | } 
111 | 112 | ior := make([]io.Reader, len(mr.readers)) 113 | for i, arg := range mr.readers { 114 | ior[i] = arg.(io.Reader) 115 | } 116 | mr.mr = io.MultiReader(ior...) 117 | 118 | return 0, nil 119 | } 120 | 121 | type BodyBufferOptions struct { 122 | // MemBufferBytes sets up the size of the memory buffer for this request. 123 | // If the data size exceeds the limit, the remaining request part will be saved on the file system. 124 | MemBufferBytes int64 125 | // Max size bytes, ignored if set to value <= 0, if request exceeds the specified limit, the reader will fail. 126 | MaxSizeBytes int64 127 | } 128 | 129 | func NewBodyBuffer(input io.Reader) (MultiReader, error) { 130 | return NewBodyBufferWithOptions( 131 | input, BodyBufferOptions{ 132 | MemBufferBytes: DefaultMemBufferBytes, 133 | MaxSizeBytes: DefaultMaxSizeBytes, 134 | }) 135 | } 136 | 137 | func NewBodyBufferWithOptions(input io.Reader, o BodyBufferOptions) (MultiReader, error) { 138 | memReader := &io.LimitedReader{ 139 | R: input, // Read from this reader 140 | N: o.MemBufferBytes, // Maximum amount of data to read 141 | } 142 | readers := make([]io.ReadSeeker, 0, 2) 143 | 144 | buffer, err := ioutil.ReadAll(memReader) 145 | if err != nil { 146 | return nil, err 147 | } 148 | readers = append(readers, bytes.NewReader(buffer)) 149 | 150 | var file *os.File 151 | // This means that we have exceeded all the memory capacity and we will start buffering the body to disk. 
152 | totalBytes := int64(len(buffer)) 153 | if memReader.N <= 0 { 154 | file, err = ioutil.TempFile("", "vulcan-bodies-") 155 | if err != nil { 156 | return nil, err 157 | } 158 | os.Remove(file.Name()) 159 | 160 | readSrc := input 161 | if o.MaxSizeBytes > 0 { 162 | readSrc = &MaxReader{R: input, Max: o.MaxSizeBytes - o.MemBufferBytes} 163 | } 164 | 165 | writtenBytes, err := io.Copy(file, readSrc) 166 | if err != nil { 167 | return nil, err 168 | } 169 | totalBytes += writtenBytes 170 | file.Seek(0, 0) 171 | readers = append(readers, file) 172 | } 173 | 174 | var cleanupFn CleanupFunc 175 | if file != nil { 176 | cleanupFn = func() error { 177 | file.Close() 178 | return nil 179 | } 180 | } 181 | return NewMultiReaderSeeker(totalBytes, cleanupFn, readers...), nil 182 | } 183 | 184 | // MaxReader does not allow to read more than Max bytes and returns error if this limit has been exceeded. 185 | type MaxReader struct { 186 | R io.Reader // underlying reader 187 | N int64 // bytes read 188 | Max int64 // max bytes to read 189 | } 190 | 191 | func (r *MaxReader) Read(p []byte) (int, error) { 192 | readBytes, err := r.R.Read(p) 193 | if err != nil && err != io.EOF { 194 | return readBytes, err 195 | } 196 | 197 | r.N += int64(readBytes) 198 | if r.N > r.Max { 199 | return readBytes, &MaxSizeReachedError{MaxSize: r.Max} 200 | } 201 | return readBytes, err 202 | } 203 | 204 | type MaxSizeReachedError struct { 205 | MaxSize int64 206 | } 207 | 208 | func (e *MaxSizeReachedError) Error() string { 209 | return fmt.Sprintf("Maximum size %d was reached", e) 210 | } 211 | -------------------------------------------------------------------------------- /netutils/buffer_test.go: -------------------------------------------------------------------------------- 1 | package netutils 2 | 3 | import ( 4 | "bytes" 5 | "crypto/md5" 6 | "encoding/hex" 7 | . 
"gopkg.in/check.v1" 8 | "io" 9 | "io/ioutil" 10 | "os" 11 | ) 12 | 13 | type BufferSuite struct{} 14 | 15 | var _ = Suite(&BufferSuite{}) 16 | 17 | func createReaderOfSize(size int64) (reader io.Reader, hash string) { 18 | f, err := os.Open("/dev/urandom") 19 | if err != nil { 20 | panic(err) 21 | } 22 | 23 | b := make([]byte, int(size)) 24 | 25 | _, err = io.ReadFull(f, b) 26 | 27 | if err != nil { 28 | panic(err) 29 | } 30 | 31 | h := md5.New() 32 | h.Write(b) 33 | return bytes.NewReader(b), hex.EncodeToString(h.Sum(nil)) 34 | } 35 | 36 | func hashOfReader(r io.Reader) string { 37 | h := md5.New() 38 | tr := io.TeeReader(r, h) 39 | _, _ = io.Copy(ioutil.Discard, tr) 40 | return hex.EncodeToString(h.Sum(nil)) 41 | } 42 | 43 | func (s *BufferSuite) TestSmallBuffer(c *C) { 44 | r, hash := createReaderOfSize(1) 45 | bb, err := NewBodyBuffer(r) 46 | c.Assert(err, IsNil) 47 | c.Assert(hashOfReader(bb), Equals, hash) 48 | bb.Close() 49 | } 50 | 51 | func (s *BufferSuite) TestBigBuffer(c *C) { 52 | r, hash := createReaderOfSize(13631488) 53 | bb, err := NewBodyBuffer(r) 54 | c.Assert(err, IsNil) 55 | c.Assert(hashOfReader(bb), Equals, hash) 56 | } 57 | 58 | func (s *BufferSuite) TestSeek(c *C) { 59 | tlen := int64(1057576) 60 | r, hash := createReaderOfSize(tlen) 61 | bb, err := NewBodyBuffer(r) 62 | 63 | c.Assert(err, IsNil) 64 | c.Assert(hashOfReader(bb), Equals, hash) 65 | l, err := bb.TotalSize() 66 | c.Assert(err, IsNil) 67 | c.Assert(l, Equals, tlen) 68 | 69 | bb.Seek(0, 0) 70 | c.Assert(hashOfReader(bb), Equals, hash) 71 | l, err = bb.TotalSize() 72 | c.Assert(err, IsNil) 73 | c.Assert(l, Equals, tlen) 74 | } 75 | 76 | func (s *BufferSuite) TestSeekFirst(c *C) { 77 | tlen := int64(1057576) 78 | r, hash := createReaderOfSize(tlen) 79 | bb, err := NewBodyBuffer(r) 80 | 81 | l, err := bb.TotalSize() 82 | c.Assert(err, IsNil) 83 | c.Assert(l, Equals, tlen) 84 | 85 | c.Assert(err, IsNil) 86 | c.Assert(hashOfReader(bb), Equals, hash) 87 | 88 | bb.Seek(0, 0) 89 | 90 | 
c.Assert(hashOfReader(bb), Equals, hash) 91 | l, err = bb.TotalSize() 92 | c.Assert(err, IsNil) 93 | c.Assert(l, Equals, tlen) 94 | } 95 | 96 | func (s *BufferSuite) TestLimitDoesNotExceed(c *C) { 97 | requestSize := int64(1057576) 98 | r, hash := createReaderOfSize(requestSize) 99 | bb, err := NewBodyBufferWithOptions(r, BodyBufferOptions{MemBufferBytes: 1024, MaxSizeBytes: requestSize + 1}) 100 | c.Assert(err, IsNil) 101 | c.Assert(hashOfReader(bb), Equals, hash) 102 | size, err := bb.TotalSize() 103 | c.Assert(err, IsNil) 104 | c.Assert(size, Equals, requestSize) 105 | bb.Close() 106 | } 107 | 108 | func (s *BufferSuite) TestLimitExceeds(c *C) { 109 | requestSize := int64(1057576) 110 | r, _ := createReaderOfSize(requestSize) 111 | bb, err := NewBodyBufferWithOptions(r, BodyBufferOptions{MemBufferBytes: 1024, MaxSizeBytes: requestSize - 1}) 112 | c.Assert(err, FitsTypeOf, &MaxSizeReachedError{}) 113 | c.Assert(bb, IsNil) 114 | } 115 | 116 | func (s *BufferSuite) TestWriteToBigBuffer(c *C) { 117 | l := int64(13631488) 118 | r, hash := createReaderOfSize(l) 119 | bb, err := NewBodyBuffer(r) 120 | c.Assert(err, IsNil) 121 | 122 | other := &bytes.Buffer{} 123 | wrote, err := bb.WriteTo(other) 124 | c.Assert(err, IsNil) 125 | c.Assert(wrote, Equals, l) 126 | c.Assert(hashOfReader(other), Equals, hash) 127 | } 128 | 129 | func (s *BufferSuite) TestWriteToSmallBuffer(c *C) { 130 | l := int64(1) 131 | r, hash := createReaderOfSize(l) 132 | bb, err := NewBodyBuffer(r) 133 | c.Assert(err, IsNil) 134 | 135 | other := &bytes.Buffer{} 136 | wrote, err := bb.WriteTo(other) 137 | c.Assert(err, IsNil) 138 | c.Assert(wrote, Equals, l) 139 | c.Assert(hashOfReader(other), Equals, hash) 140 | } 141 | -------------------------------------------------------------------------------- /netutils/netutils.go: -------------------------------------------------------------------------------- 1 | // Network related utilities 2 | package netutils 3 | 4 | import ( 5 | "encoding/base64" 6 | 
"fmt" 7 | "net/http" 8 | "net/url" 9 | "strings" 10 | ) 11 | 12 | // Provides update safe copy by avoiding 13 | // shallow copying certain fields (like user data) 14 | func CopyUrl(in *url.URL) *url.URL { 15 | out := new(url.URL) 16 | *out = *in 17 | if in.User != nil { 18 | *out.User = *in.User 19 | } 20 | return out 21 | } 22 | 23 | // RawPath returns escaped url path section 24 | func RawPath(in string) (string, error) { 25 | u, err := url.ParseRequestURI(in) 26 | if err != nil { 27 | return "", err 28 | } 29 | path := "" 30 | if u.Opaque != "" { 31 | path = u.Opaque 32 | } else if u.Host == "" { 33 | path = in 34 | } else { 35 | vals := strings.SplitN(in, u.Host, 2) 36 | if len(vals) != 2 { 37 | return "", fmt.Errorf("failed to parse url") 38 | } 39 | path = vals[1] 40 | } 41 | idx := strings.IndexRune(path, '?') 42 | if idx == -1 { 43 | return path, nil 44 | } 45 | return path[:idx], nil 46 | } 47 | 48 | // RawURL returns URL built out of the provided request's Request-URI, to avoid un-escaping. 49 | // Note: it assumes that scheme and host for the provided request's URL are defined. 
50 | func RawURL(request *http.Request) string { 51 | return strings.Join([]string{request.URL.Scheme, "://", request.URL.Host, request.RequestURI}, "") 52 | } 53 | 54 | // Copies http headers from source to destination 55 | // does not overide, but adds multiple headers 56 | func CopyHeaders(dst, src http.Header) { 57 | for k, vv := range src { 58 | for _, v := range vv { 59 | dst.Add(k, v) 60 | } 61 | } 62 | } 63 | 64 | // Determines whether any of the header names is present 65 | // in the http headers 66 | func HasHeaders(names []string, headers http.Header) bool { 67 | for _, h := range names { 68 | if headers.Get(h) != "" { 69 | return true 70 | } 71 | } 72 | return false 73 | } 74 | 75 | // Removes the header with the given names from the headers map 76 | func RemoveHeaders(names []string, headers http.Header) { 77 | for _, h := range names { 78 | headers.Del(h) 79 | } 80 | } 81 | 82 | func MustParseUrl(inUrl string) *url.URL { 83 | u, err := ParseUrl(inUrl) 84 | if err != nil { 85 | panic(err) 86 | } 87 | return u 88 | } 89 | 90 | // Standard parse url is very generous, 91 | // parseUrl wrapper makes it more strict 92 | // and demands scheme and host to be set 93 | func ParseUrl(inUrl string) (*url.URL, error) { 94 | parsedUrl, err := url.Parse(inUrl) 95 | if err != nil { 96 | return nil, err 97 | } 98 | 99 | if parsedUrl.Host == "" || parsedUrl.Scheme == "" { 100 | return nil, fmt.Errorf("Empty Url is not allowed") 101 | } 102 | return parsedUrl, nil 103 | } 104 | 105 | type BasicAuth struct { 106 | Username string 107 | Password string 108 | } 109 | 110 | func (ba *BasicAuth) String() string { 111 | encoded := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", ba.Username, ba.Password))) 112 | return fmt.Sprintf("Basic %s", encoded) 113 | } 114 | 115 | func ParseAuthHeader(header string) (*BasicAuth, error) { 116 | 117 | values := strings.Fields(header) 118 | if len(values) != 2 { 119 | return nil, fmt.Errorf(fmt.Sprintf("Failed to parse header 
'%s'", header)) 120 | } 121 | 122 | auth_type := strings.ToLower(values[0]) 123 | if auth_type != "basic" { 124 | return nil, fmt.Errorf("Expected basic auth type, got '%s'", auth_type) 125 | } 126 | 127 | encoded_string := values[1] 128 | decoded_string, err := base64.StdEncoding.DecodeString(encoded_string) 129 | if err != nil { 130 | return nil, fmt.Errorf("Failed to parse header '%s', base64 failed: %s", header, err) 131 | } 132 | 133 | values = strings.SplitN(string(decoded_string), ":", 2) 134 | if len(values) != 2 { 135 | return nil, fmt.Errorf("Failed to parse header '%s', expected separator ':'", header) 136 | } 137 | return &BasicAuth{Username: values[0], Password: values[1]}, nil 138 | } 139 | -------------------------------------------------------------------------------- /netutils/netutils_test.go: -------------------------------------------------------------------------------- 1 | package netutils 2 | 3 | import ( 4 | "net/http" 5 | "net/url" 6 | "testing" 7 | 8 | . "gopkg.in/check.v1" 9 | ) 10 | 11 | func TestUtils(t *testing.T) { TestingT(t) } 12 | 13 | type NetUtilsSuite struct{} 14 | 15 | var _ = Suite(&NetUtilsSuite{}) 16 | 17 | // Make sure parseUrl is strict enough not to accept total garbage 18 | func (s *NetUtilsSuite) TestParseBadUrl(c *C) { 19 | badUrls := []string{ 20 | "", 21 | " some random text ", 22 | "http---{}{\\bad bad url", 23 | } 24 | for _, badUrl := range badUrls { 25 | _, err := ParseUrl(badUrl) 26 | c.Assert(err, NotNil) 27 | } 28 | } 29 | 30 | // Make sure parseUrl is strict enough not to accept total garbage 31 | func (s *NetUtilsSuite) TestURLRawPath(c *C) { 32 | vals := []struct { 33 | URL string 34 | Expected string 35 | }{ 36 | {"http://google.com/", "/"}, 37 | {"http://google.com/a?q=b", "/a"}, 38 | {"http://google.com/%2Fvalue/hello", "/%2Fvalue/hello"}, 39 | {"/home", "/home"}, 40 | {"/home?a=b", "/home"}, 41 | {"/home%2F", "/home%2F"}, 42 | } 43 | for _, v := range vals { 44 | out, err := RawPath(v.URL) 45 | 
c.Assert(err, IsNil) 46 | c.Assert(out, Equals, v.Expected) 47 | } 48 | } 49 | 50 | func (s *NetUtilsSuite) TestRawURL(c *C) { 51 | request := &http.Request{URL: &url.URL{Scheme: "http", Host: "localhost:8080"}, RequestURI: "/foo/bar"} 52 | c.Assert("http://localhost:8080/foo/bar", Equals, RawURL(request)) 53 | } 54 | 55 | //Just to make sure we don't panic, return err and not 56 | //username and pass and cover the function 57 | func (s *NetUtilsSuite) TestParseBadHeaders(c *C) { 58 | headers := []string{ 59 | //just empty string 60 | "", 61 | //missing auth type 62 | "justplainstring", 63 | //unknown auth type 64 | "Whut justplainstring", 65 | //invalid base64 66 | "Basic Shmasic", 67 | //random encoded string 68 | "Basic YW55IGNhcm5hbCBwbGVhcw==", 69 | } 70 | for _, h := range headers { 71 | _, err := ParseAuthHeader(h) 72 | c.Assert(err, NotNil) 73 | } 74 | } 75 | 76 | //Just to make sure we don't panic, return err and not 77 | //username and pass and cover the function 78 | func (s *NetUtilsSuite) TestParseSuccess(c *C) { 79 | headers := []struct { 80 | Header string 81 | Expected BasicAuth 82 | }{ 83 | { 84 | "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==", 85 | BasicAuth{Username: "Aladdin", Password: "open sesame"}, 86 | }, 87 | // Make sure that String() produces valid header 88 | { 89 | (&BasicAuth{Username: "Alice", Password: "Here's bob"}).String(), 90 | BasicAuth{Username: "Alice", Password: "Here's bob"}, 91 | }, 92 | //empty pass 93 | { 94 | "Basic QWxhZGRpbjo=", 95 | BasicAuth{Username: "Aladdin", Password: ""}, 96 | }, 97 | } 98 | for _, h := range headers { 99 | request, err := ParseAuthHeader(h.Header) 100 | c.Assert(err, IsNil) 101 | c.Assert(request.Username, Equals, h.Expected.Username) 102 | c.Assert(request.Password, Equals, h.Expected.Password) 103 | 104 | } 105 | } 106 | 107 | // Make sure copy does it right, so the copied url 108 | // is safe to alter without modifying the other 109 | func (s *NetUtilsSuite) TestCopyUrl(c *C) { 110 | urlA := 
&url.URL{ 111 | Scheme: "http", 112 | Host: "localhost:5000", 113 | Path: "/upstream", 114 | Opaque: "opaque", 115 | RawQuery: "a=1&b=2", 116 | Fragment: "#hello", 117 | User: &url.Userinfo{}, 118 | } 119 | urlB := CopyUrl(urlA) 120 | c.Assert(urlB, DeepEquals, urlB) 121 | urlB.Scheme = "https" 122 | c.Assert(urlB, Not(DeepEquals), urlA) 123 | } 124 | 125 | // Make sure copy headers is not shallow and copies all headers 126 | func (s *NetUtilsSuite) TestCopyHeaders(c *C) { 127 | source, destination := make(http.Header), make(http.Header) 128 | source.Add("a", "b") 129 | source.Add("c", "d") 130 | 131 | CopyHeaders(destination, source) 132 | 133 | c.Assert(destination.Get("a"), Equals, "b") 134 | c.Assert(destination.Get("c"), Equals, "d") 135 | 136 | // make sure that altering source does not affect the destination 137 | source.Del("a") 138 | c.Assert(source.Get("a"), Equals, "") 139 | c.Assert(destination.Get("a"), Equals, "b") 140 | } 141 | 142 | func (s *NetUtilsSuite) TestHasHeaders(c *C) { 143 | source := make(http.Header) 144 | source.Add("a", "b") 145 | source.Add("c", "d") 146 | c.Assert(HasHeaders([]string{"a", "f"}, source), Equals, true) 147 | c.Assert(HasHeaders([]string{"i", "j"}, source), Equals, false) 148 | } 149 | 150 | func (s *NetUtilsSuite) TestRemoveHeaders(c *C) { 151 | source := make(http.Header) 152 | source.Add("a", "b") 153 | source.Add("a", "m") 154 | source.Add("c", "d") 155 | RemoveHeaders([]string{"a"}, source) 156 | c.Assert(source.Get("a"), Equals, "") 157 | c.Assert(source.Get("c"), Equals, "d") 158 | } 159 | -------------------------------------------------------------------------------- /netutils/response.go: -------------------------------------------------------------------------------- 1 | package netutils 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | "io/ioutil" 8 | "net/http" 9 | ) 10 | 11 | func NewHttpResponse(request *http.Request, statusCode int, body []byte, contentType string) *http.Response { 12 | resp 
:= &http.Response{ 13 | Status: fmt.Sprintf("%d %s", statusCode, http.StatusText(statusCode)), 14 | StatusCode: statusCode, 15 | Proto: "HTTP/1.0", 16 | ProtoMajor: 1, 17 | ProtoMinor: 0, 18 | Header: make(http.Header), 19 | } 20 | resp.Header.Add("Content-Type", contentType) 21 | resp.Body = ioutil.NopCloser(bytes.NewBuffer(body)) 22 | resp.ContentLength = int64(len(body)) 23 | resp.Request = request 24 | return resp 25 | } 26 | 27 | func NewTextResponse(request *http.Request, statusCode int, body string) *http.Response { 28 | return NewHttpResponse(request, statusCode, []byte(body), "text/plain") 29 | } 30 | 31 | func NewJsonResponse(request *http.Request, statusCode int, message interface{}) *http.Response { 32 | bytes, err := json.Marshal(message) 33 | if err != nil { 34 | bytes = []byte("{}") 35 | } 36 | return NewHttpResponse(request, statusCode, bytes, "application/json") 37 | } 38 | -------------------------------------------------------------------------------- /proxy.go: -------------------------------------------------------------------------------- 1 | // This package contains the reverse proxy that implements http.HandlerFunc 2 | package vulcan 3 | 4 | import ( 5 | "io" 6 | "net" 7 | "net/http" 8 | "sync/atomic" 9 | 10 | "github.com/mailgun/log" 11 | "github.com/mailgun/vulcan/errors" 12 | "github.com/mailgun/vulcan/netutils" 13 | "github.com/mailgun/vulcan/request" 14 | "github.com/mailgun/vulcan/route" 15 | ) 16 | 17 | type Proxy struct { 18 | // Router selects a location for each request 19 | router route.Router 20 | // Options like ErrorFormatter 21 | options Options 22 | // Counter that is used to provide unique identifiers for requests 23 | lastRequestId int64 24 | } 25 | 26 | type Options struct { 27 | // Takes a status code and formats it into proxy response 28 | ErrorFormatter errors.Formatter 29 | } 30 | 31 | // Accepts requests, round trips it to the endpoint, and writes back the response. 
32 | func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { 33 | err := p.proxyRequest(w, r) 34 | if err == nil { 35 | return 36 | } 37 | 38 | switch e := err.(type) { 39 | case *errors.RedirectError: 40 | // In case if it's redirect error, try the request one more time, but with different URL 41 | r.URL = e.URL 42 | r.Host = e.URL.Host 43 | r.RequestURI = e.URL.String() 44 | if err := p.proxyRequest(w, r); err != nil { 45 | p.replyError(err, w, r) 46 | } 47 | default: 48 | p.replyError(err, w, r) 49 | } 50 | } 51 | 52 | // Creates a proxy with a given router 53 | func NewProxy(router route.Router) (*Proxy, error) { 54 | return NewProxyWithOptions(router, Options{}) 55 | } 56 | 57 | // Creates reverse proxy that acts like http request handler 58 | func NewProxyWithOptions(router route.Router, o Options) (*Proxy, error) { 59 | o, err := validateOptions(o) 60 | if err != nil { 61 | return nil, err 62 | } 63 | 64 | p := &Proxy{ 65 | options: o, 66 | router: router, 67 | } 68 | return p, nil 69 | } 70 | 71 | func (p *Proxy) GetRouter() route.Router { 72 | return p.router 73 | } 74 | 75 | // Round trips the request to the selected location and writes back the response 76 | func (p *Proxy) proxyRequest(w http.ResponseWriter, r *http.Request) error { 77 | 78 | // Create a unique request with sequential ids that will be passed to all interfaces. 79 | req := request.NewBaseRequest(r, atomic.AddInt64(&p.lastRequestId, 1), nil) 80 | location, err := p.router.Route(req) 81 | if err != nil { 82 | return err 83 | } 84 | 85 | // Router could not find a matching location, we can do nothing else. 
86 | if location == nil { 87 | log.Errorf("%s failed to route", req) 88 | return errors.FromStatus(http.StatusBadGateway) 89 | } 90 | 91 | response, err := location.RoundTrip(req) 92 | if response != nil { 93 | netutils.CopyHeaders(w.Header(), response.Header) 94 | w.WriteHeader(response.StatusCode) 95 | io.Copy(w, response.Body) 96 | response.Body.Close() 97 | return nil 98 | } else { 99 | return err 100 | } 101 | } 102 | 103 | // replyError is a helper function that takes error and replies with HTTP compatible error to the client. 104 | func (p *Proxy) replyError(err error, w http.ResponseWriter, req *http.Request) { 105 | proxyError := convertError(err) 106 | statusCode, body, contentType := p.options.ErrorFormatter.Format(proxyError) 107 | w.Header().Set("Content-Type", contentType) 108 | if proxyError.Headers() != nil { 109 | netutils.CopyHeaders(w.Header(), proxyError.Headers()) 110 | } 111 | w.WriteHeader(statusCode) 112 | w.Write(body) 113 | } 114 | 115 | func validateOptions(o Options) (Options, error) { 116 | if o.ErrorFormatter == nil { 117 | o.ErrorFormatter = &errors.JsonFormatter{} 118 | } 119 | return o, nil 120 | } 121 | 122 | func convertError(err error) errors.ProxyError { 123 | switch e := err.(type) { 124 | case errors.ProxyError: 125 | return e 126 | case net.Error: 127 | if e.Timeout() { 128 | return errors.FromStatus(http.StatusRequestTimeout) 129 | } 130 | case *netutils.MaxSizeReachedError: 131 | return errors.FromStatus(http.StatusRequestEntityTooLarge) 132 | } 133 | return errors.FromStatus(http.StatusBadGateway) 134 | } 135 | -------------------------------------------------------------------------------- /proxy_test.go: -------------------------------------------------------------------------------- 1 | package vulcan 2 | 3 | import ( 4 | "github.com/mailgun/timetools" 5 | . "github.com/mailgun/vulcan/location" 6 | . "github.com/mailgun/vulcan/route" 7 | . "github.com/mailgun/vulcan/testutils" 8 | . 
"gopkg.in/check.v1" 9 | "net/http" 10 | "net/http/httptest" 11 | "time" 12 | ) 13 | 14 | type ProxySuite struct { 15 | authHeaders http.Header 16 | tm *timetools.FreezedTime 17 | } 18 | 19 | var _ = Suite(&ProxySuite{ 20 | tm: &timetools.FreezedTime{ 21 | CurrentTime: time.Date(2012, 3, 4, 5, 6, 7, 0, time.UTC), 22 | }, 23 | }) 24 | 25 | // Success, make sure we've successfully proxied the response 26 | func (s *ProxySuite) TestSuccess(c *C) { 27 | server := NewTestServer(func(w http.ResponseWriter, r *http.Request) { 28 | w.Write([]byte("Hi, I'm endpoint")) 29 | }) 30 | defer server.Close() 31 | 32 | proxy, err := NewProxy(&ConstRouter{&ConstHttpLocation{server.URL}}) 33 | c.Assert(err, IsNil) 34 | proxyServer := httptest.NewServer(proxy) 35 | defer proxyServer.Close() 36 | 37 | response, bodyBytes, err := MakeRequest(proxyServer.URL, Opts{}) 38 | c.Assert(err, IsNil) 39 | c.Assert(response.StatusCode, Equals, http.StatusOK) 40 | c.Assert(string(bodyBytes), Equals, "Hi, I'm endpoint") 41 | } 42 | 43 | func (s *ProxySuite) TestFailure(c *C) { 44 | proxy, err := NewProxy(&ConstRouter{&ConstHttpLocation{"http://localhost:63999"}}) 45 | c.Assert(err, IsNil) 46 | proxyServer := httptest.NewServer(proxy) 47 | defer proxyServer.Close() 48 | 49 | response, _, err := MakeRequest(proxyServer.URL, Opts{}) 50 | c.Assert(err, IsNil) 51 | c.Assert(response.StatusCode, Equals, http.StatusBadGateway) 52 | } 53 | 54 | func (s *ProxySuite) TestReadTimeout(c *C) { 55 | c.Skip("This test is not stable") 56 | 57 | server := NewTestServer(func(w http.ResponseWriter, r *http.Request) { 58 | w.Write([]byte("Hi, I'm endpoint")) 59 | }) 60 | defer server.Close() 61 | 62 | proxy, err := NewProxy(&ConstRouter{&ConstHttpLocation{server.URL}}) 63 | c.Assert(err, IsNil) 64 | 65 | // Set a very short read timeout 66 | proxyServer := httptest.NewUnstartedServer(proxy) 67 | proxyServer.Config.ReadTimeout = time.Millisecond 68 | proxyServer.Start() 69 | defer proxyServer.Close() 70 | 71 | value := 
make([]byte, 65636) 72 | for i := 0; i < len(value); i += 1 { 73 | value[i] = byte(i % 255) 74 | } 75 | 76 | response, _, err := MakeRequest(proxyServer.URL, Opts{Body: string(value)}) 77 | c.Assert(err, IsNil) 78 | c.Assert(response.StatusCode, Equals, http.StatusRequestTimeout) 79 | } 80 | -------------------------------------------------------------------------------- /request/request.go: -------------------------------------------------------------------------------- 1 | // Wrapper around http.Request with additional features 2 | package request 3 | 4 | import ( 5 | "fmt" 6 | "net/http" 7 | "sync" 8 | "time" 9 | 10 | "github.com/mailgun/vulcan/endpoint" 11 | "github.com/mailgun/vulcan/netutils" 12 | ) 13 | 14 | // Request is a rapper around http request that provides more info about http.Request 15 | type Request interface { 16 | GetHttpRequest() *http.Request // Original http request 17 | SetHttpRequest(*http.Request) // Can be used to set http request 18 | GetId() int64 // Request id that is unique to this running process 19 | SetBody(netutils.MultiReader) // Sets request body 20 | GetBody() netutils.MultiReader // Request body fully read and stored in effective manner (buffered to disk for large requests) 21 | AddAttempt(Attempt) // Add last proxy attempt to the request 22 | GetAttempts() []Attempt // Returns last attempts to proxy request, may be nil if there are no attempts 23 | GetLastAttempt() Attempt // Convenience method returning the last attempt, may be nil if there are no attempts 24 | String() string // Debugging string representation of the request 25 | SetUserData(key string, baton interface{}) // Provide storage space for data that survives with the request 26 | GetUserData(key string) (interface{}, bool) // Fetch user data set from previously SetUserData call 27 | DeleteUserData(key string) // Clean up user data set from previously SetUserData call 28 | } 29 | 30 | type Attempt interface { 31 | GetError() error 32 | GetDuration() time.Duration 
33 | GetResponse() *http.Response 34 | GetEndpoint() endpoint.Endpoint 35 | } 36 | 37 | type BaseAttempt struct { 38 | Error error 39 | Duration time.Duration 40 | Response *http.Response 41 | Endpoint endpoint.Endpoint 42 | } 43 | 44 | func (ba *BaseAttempt) GetResponse() *http.Response { 45 | return ba.Response 46 | } 47 | 48 | func (ba *BaseAttempt) GetError() error { 49 | return ba.Error 50 | } 51 | 52 | func (ba *BaseAttempt) GetDuration() time.Duration { 53 | return ba.Duration 54 | } 55 | 56 | func (ba *BaseAttempt) GetEndpoint() endpoint.Endpoint { 57 | return ba.Endpoint 58 | } 59 | 60 | type BaseRequest struct { 61 | HttpRequest *http.Request 62 | Id int64 63 | Body netutils.MultiReader 64 | Attempts []Attempt 65 | userDataMutex *sync.RWMutex 66 | userData map[string]interface{} 67 | } 68 | 69 | func NewBaseRequest(r *http.Request, id int64, body netutils.MultiReader) *BaseRequest { 70 | return &BaseRequest{ 71 | HttpRequest: r, 72 | Id: id, 73 | Body: body, 74 | userDataMutex: &sync.RWMutex{}, 75 | } 76 | 77 | } 78 | 79 | func (br *BaseRequest) String() string { 80 | return fmt.Sprintf("Request(id=%d, method=%s, url=%s, attempts=%d)", br.Id, br.HttpRequest.Method, br.HttpRequest.URL.String(), len(br.Attempts)) 81 | } 82 | 83 | func (br *BaseRequest) GetHttpRequest() *http.Request { 84 | return br.HttpRequest 85 | } 86 | 87 | func (br *BaseRequest) SetHttpRequest(r *http.Request) { 88 | br.HttpRequest = r 89 | } 90 | 91 | func (br *BaseRequest) GetId() int64 { 92 | return br.Id 93 | } 94 | 95 | func (br *BaseRequest) SetBody(b netutils.MultiReader) { 96 | br.Body = b 97 | } 98 | 99 | func (br *BaseRequest) GetBody() netutils.MultiReader { 100 | return br.Body 101 | } 102 | 103 | func (br *BaseRequest) AddAttempt(a Attempt) { 104 | br.Attempts = append(br.Attempts, a) 105 | } 106 | 107 | func (br *BaseRequest) GetAttempts() []Attempt { 108 | return br.Attempts 109 | } 110 | 111 | func (br *BaseRequest) GetLastAttempt() Attempt { 112 | if len(br.Attempts) 
== 0 { 113 | return nil 114 | } 115 | return br.Attempts[len(br.Attempts)-1] 116 | } 117 | func (br *BaseRequest) SetUserData(key string, baton interface{}) { 118 | br.userDataMutex.Lock() 119 | defer br.userDataMutex.Unlock() 120 | if br.userData == nil { 121 | br.userData = make(map[string]interface{}) 122 | } 123 | br.userData[key] = baton 124 | } 125 | func (br *BaseRequest) GetUserData(key string) (i interface{}, b bool) { 126 | br.userDataMutex.RLock() 127 | defer br.userDataMutex.RUnlock() 128 | if br.userData == nil { 129 | return i, false 130 | } 131 | i, b = br.userData[key] 132 | return i, b 133 | } 134 | func (br *BaseRequest) DeleteUserData(key string) { 135 | br.userDataMutex.Lock() 136 | defer br.userDataMutex.Unlock() 137 | if br.userData == nil { 138 | return 139 | } 140 | 141 | delete(br.userData, key) 142 | } 143 | -------------------------------------------------------------------------------- /request/request_test.go: -------------------------------------------------------------------------------- 1 | package request 2 | 3 | import ( 4 | . 
"gopkg.in/check.v1" 5 | "net/http" 6 | "testing" 7 | ) 8 | 9 | func TestRequest(t *testing.T) { TestingT(t) } 10 | 11 | type RequestSuite struct { 12 | } 13 | 14 | var _ = Suite(&RequestSuite{}) 15 | 16 | func (s *RequestSuite) SetUpSuite(c *C) { 17 | } 18 | 19 | func (s *RequestSuite) TestUserDataInt(c *C) { 20 | br := NewBaseRequest(&http.Request{}, 0, nil) 21 | br.SetUserData("caller1", 100) 22 | data, present := br.GetUserData("caller1") 23 | 24 | c.Assert(present, Equals, true) 25 | c.Assert(data.(int), Equals, 100) 26 | 27 | br.SetUserData("caller2", 200) 28 | data, present = br.GetUserData("caller1") 29 | c.Assert(present, Equals, true) 30 | c.Assert(data.(int), Equals, 100) 31 | 32 | data, present = br.GetUserData("caller2") 33 | c.Assert(present, Equals, true) 34 | c.Assert(data.(int), Equals, 200) 35 | 36 | br.DeleteUserData("caller2") 37 | _, present = br.GetUserData("caller2") 38 | c.Assert(present, Equals, false) 39 | } 40 | 41 | func (s *RequestSuite) TestUserDataNil(c *C) { 42 | br := NewBaseRequest(&http.Request{}, 0, nil) 43 | _, present := br.GetUserData("caller1") 44 | c.Assert(present, Equals, false) 45 | } 46 | -------------------------------------------------------------------------------- /route/exproute/exproute.go: -------------------------------------------------------------------------------- 1 | /* 2 | see http://godoc.org/github.com/mailgun/route for documentation on the language 3 | */ 4 | package exproute 5 | 6 | import ( 7 | "fmt" 8 | "regexp" 9 | "strings" 10 | 11 | "github.com/mailgun/route" 12 | "github.com/mailgun/vulcan/location" 13 | "github.com/mailgun/vulcan/request" 14 | ) 15 | 16 | type ExpRouter struct { 17 | r route.Router 18 | } 19 | 20 | func NewExpRouter() *ExpRouter { 21 | return &ExpRouter{ 22 | r: route.New(), 23 | } 24 | } 25 | 26 | func (e *ExpRouter) GetLocationByExpression(expr string) location.Location { 27 | v := e.r.GetRoute(convertPath(expr)) 28 | if v == nil { 29 | return nil 30 | } 31 | return 
v.(location.Location) 32 | } 33 | 34 | func (e *ExpRouter) AddLocation(expr string, l location.Location) error { 35 | return e.r.AddRoute(convertPath(expr), l) 36 | } 37 | 38 | func (e *ExpRouter) RemoveLocationByExpression(expr string) error { 39 | return e.r.RemoveRoute(convertPath(expr)) 40 | } 41 | 42 | func (e *ExpRouter) Route(req request.Request) (location.Location, error) { 43 | l, err := e.r.Route(req.GetHttpRequest()) 44 | if err != nil { 45 | return nil, err 46 | } 47 | if l == nil { 48 | return nil, nil 49 | } 50 | return l.(location.Location), nil 51 | } 52 | 53 | // convertPath changes strings to structured format /hello -> RegexpRoute("/hello") and leaves structured strings unchanged. 54 | func convertPath(in string) string { 55 | if !strings.Contains(in, "(") { 56 | return fmt.Sprintf(`PathRegexp(%#v)`, in) 57 | } 58 | fn, args, matched := extractFunction(in) 59 | if !matched { 60 | return in 61 | } 62 | pathMatcher := "" 63 | if fn == "TrieRoute" { 64 | pathMatcher = "Path" 65 | } else { 66 | pathMatcher = "PathRegexp" 67 | } 68 | if len(args) == 1 { 69 | return fmt.Sprintf(`%s("%s")`, pathMatcher, args[0]) 70 | } 71 | if len(args) == 2 { 72 | return fmt.Sprintf(`Method("%s") && %s("%s")`, args[0], pathMatcher, args[1]) 73 | } 74 | path := args[len(args)-1] 75 | methods := args[0 : len(args)-1] 76 | return fmt.Sprintf(`MethodRegexp("%s") && %s("%s")`, strings.Join(methods, "|"), pathMatcher, path) 77 | } 78 | 79 | func extractFunction(f string) (string, []string, bool) { 80 | match := regexp.MustCompile(`(TrieRoute|RegexpRoute)\(([^\(\)]+)\)`).FindStringSubmatch(f) 81 | if len(match) != 3 { 82 | return "", nil, false 83 | } 84 | fn := match[1] 85 | args := strings.Split(match[2], ",") 86 | arguments := make([]string, len(args)) 87 | for i, a := range args { 88 | arguments[i] = strings.Trim(a, " ,\"") 89 | } 90 | return fn, arguments, true 91 | } 92 | -------------------------------------------------------------------------------- 
/route/exproute/exproute_test.go: -------------------------------------------------------------------------------- 1 | package exproute 2 | 3 | import ( 4 | "net/http" 5 | "testing" 6 | 7 | . "gopkg.in/check.v1" 8 | 9 | "github.com/mailgun/vulcan/location" 10 | "github.com/mailgun/vulcan/netutils" 11 | "github.com/mailgun/vulcan/request" 12 | ) 13 | 14 | func TestRoute(t *testing.T) { TestingT(t) } 15 | 16 | type RouteSuite struct { 17 | } 18 | 19 | var _ = Suite(&RouteSuite{}) 20 | 21 | func (s *RouteSuite) TestConvertPath(c *C) { 22 | tc := []struct { 23 | in string 24 | out string 25 | }{ 26 | {"/hello", `PathRegexp("/hello")`}, 27 | {`TrieRoute("/hello")`, `Path("/hello")`}, 28 | {`TrieRoute("POST", "/hello")`, `Method("POST") && Path("/hello")`}, 29 | {`TrieRoute("POST", "PUT", "/v2/path")`, `MethodRegexp("POST|PUT") && Path("/v2/path")`}, 30 | {`RegexpRoute("/hello")`, `PathRegexp("/hello")`}, 31 | {`RegexpRoute("POST", "/hello")`, `Method("POST") && PathRegexp("/hello")`}, 32 | {`RegexpRoute("POST", "PUT", "/v2/path")`, `MethodRegexp("POST|PUT") && PathRegexp("/v2/path")`}, 33 | {`Path("/hello")`, `Path("/hello")`}, 34 | } 35 | for i, t := range tc { 36 | comment := Commentf("tc%d", i) 37 | c.Assert(convertPath(t.in), Equals, t.out, comment) 38 | } 39 | } 40 | 41 | func (s *RouteSuite) TestEmptyOperationsSucceed(c *C) { 42 | r := NewExpRouter() 43 | 44 | c.Assert(r.GetLocationByExpression("bla"), IsNil) 45 | c.Assert(r.RemoveLocationByExpression("bla"), IsNil) 46 | 47 | l, err := r.Route(makeReq("http://google.com/blabla")) 48 | c.Assert(err, IsNil) 49 | c.Assert(l, IsNil) 50 | } 51 | 52 | func (s *RouteSuite) TestCRUD(c *C) { 53 | r := NewExpRouter() 54 | 55 | l1 := makeLoc("loc1") 56 | c.Assert(r.AddLocation(`TrieRoute("/r1")`, l1), IsNil) 57 | c.Assert(r.GetLocationByExpression(`TrieRoute("/r1")`), Equals, l1) 58 | c.Assert(r.RemoveLocationByExpression(`TrieRoute("/r1")`), IsNil) 59 | c.Assert(r.GetLocationByExpression(`TrieRoute("/r1")`), IsNil) 60 | } 
61 | 62 | func (s *RouteSuite) TestAddTwiceFails(c *C) { 63 | r := NewExpRouter() 64 | 65 | l1 := makeLoc("loc1") 66 | c.Assert(r.AddLocation(`TrieRoute("/r1")`, l1), IsNil) 67 | c.Assert(r.AddLocation(`TrieRoute("/r1")`, l1), NotNil) 68 | 69 | // Make sure that error did not have side effects 70 | out, err := r.Route(makeReq("http://google.com/r1")) 71 | c.Assert(err, IsNil) 72 | c.Assert(out, Equals, l1) 73 | } 74 | 75 | func (s *RouteSuite) TestBadExpression(c *C) { 76 | r := NewExpRouter() 77 | 78 | l1 := makeLoc("loc1") 79 | c.Assert(r.AddLocation(`TrieRoute("/r1")`, l1), IsNil) 80 | c.Assert(r.AddLocation(`Path(blabla`, l1), NotNil) 81 | 82 | // Make sure that error did not have side effects 83 | out, err := r.Route(makeReq("http://google.com/r1")) 84 | c.Assert(err, IsNil) 85 | c.Assert(out, Equals, l1) 86 | } 87 | 88 | func (s *RouteSuite) TestTrieLegacyOperations(c *C) { 89 | r := NewExpRouter() 90 | 91 | l1 := makeLoc("loc1") 92 | c.Assert(r.AddLocation(`TrieRoute("/r1")`, l1), IsNil) 93 | 94 | l2 := makeLoc("loc2") 95 | c.Assert(r.AddLocation(`TrieRoute("/r2")`, l2), IsNil) 96 | 97 | out1, err := r.Route(makeReq("http://google.com/r1")) 98 | c.Assert(err, IsNil) 99 | c.Assert(out1, Equals, l1) 100 | 101 | out2, err := r.Route(makeReq("http://google.com/r2")) 102 | c.Assert(err, IsNil) 103 | c.Assert(out2, Equals, l2) 104 | } 105 | 106 | func (s *RouteSuite) TestTrieNewOperations(c *C) { 107 | r := NewExpRouter() 108 | 109 | l1 := makeLoc("loc1") 110 | c.Assert(r.AddLocation(`Path("/r1")`, l1), IsNil) 111 | 112 | l2 := makeLoc("loc2") 113 | c.Assert(r.AddLocation(`Path("/r2")`, l2), IsNil) 114 | 115 | out1, err := r.Route(makeReq("http://google.com/r1")) 116 | c.Assert(err, IsNil) 117 | c.Assert(out1, Equals, l1) 118 | 119 | out2, err := r.Route(makeReq("http://google.com/r2")) 120 | c.Assert(err, IsNil) 121 | c.Assert(out2, Equals, l2) 122 | } 123 | 124 | func (s *RouteSuite) TestTrieMiss(c *C) { 125 | r := NewExpRouter() 126 | 127 | 
c.Assert(r.AddLocation(`TrieRoute("/r1")`, makeLoc("loc1")), IsNil) 128 | 129 | out, err := r.Route(makeReq("http://google.com/r2")) 130 | c.Assert(err, IsNil) 131 | c.Assert(out, IsNil) 132 | } 133 | 134 | func (s *RouteSuite) TestRegexpOperations(c *C) { 135 | r := NewExpRouter() 136 | 137 | l1 := makeLoc("loc1") 138 | c.Assert(r.AddLocation(`PathRegexp("/r1")`, l1), IsNil) 139 | 140 | l2 := makeLoc("loc2") 141 | c.Assert(r.AddLocation(`PathRegexp("/r2")`, l2), IsNil) 142 | 143 | out, err := r.Route(makeReq("http://google.com/r1")) 144 | c.Assert(err, IsNil) 145 | c.Assert(out, Equals, l1) 146 | 147 | out, err = r.Route(makeReq("http://google.com/r2")) 148 | c.Assert(err, IsNil) 149 | c.Assert(out, Equals, l2) 150 | 151 | out, err = r.Route(makeReq("http://google.com/r3")) 152 | c.Assert(err, IsNil) 153 | c.Assert(out, IsNil) 154 | } 155 | 156 | func (s *RouteSuite) TestRegexpLegacyOperations(c *C) { 157 | r := NewExpRouter() 158 | 159 | l1 := makeLoc("loc1") 160 | c.Assert(r.AddLocation(`RegexpRoute("/r1")`, l1), IsNil) 161 | 162 | l2 := makeLoc("loc2") 163 | c.Assert(r.AddLocation(`RegexpRoute("/r2")`, l2), IsNil) 164 | 165 | out, err := r.Route(makeReq("http://google.com/r1")) 166 | c.Assert(err, IsNil) 167 | c.Assert(out, Equals, l1) 168 | 169 | out, err = r.Route(makeReq("http://google.com/r2")) 170 | c.Assert(err, IsNil) 171 | c.Assert(out, Equals, l2) 172 | 173 | out, err = r.Route(makeReq("http://google.com/r3")) 174 | c.Assert(err, IsNil) 175 | c.Assert(out, IsNil) 176 | } 177 | 178 | func (s *RouteSuite) TestMixedOperations(c *C) { 179 | r := NewExpRouter() 180 | 181 | l1 := makeLoc("loc1") 182 | c.Assert(r.AddLocation(`PathRegexp("/r1")`, l1), IsNil) 183 | 184 | l2 := makeLoc("loc2") 185 | c.Assert(r.AddLocation(`Path("/r2")`, l2), IsNil) 186 | 187 | out, err := r.Route(makeReq("http://google.com/r1")) 188 | c.Assert(err, IsNil) 189 | c.Assert(out, Equals, l1) 190 | 191 | out, err = r.Route(makeReq("http://google.com/r2")) 192 | c.Assert(err, IsNil) 
193 | c.Assert(out, Equals, l2) 194 | 195 | out, err = r.Route(makeReq("http://google.com/r3")) 196 | c.Assert(err, IsNil) 197 | c.Assert(out, IsNil) 198 | } 199 | 200 | func (s *RouteSuite) TestMatchByMethodLegacy(c *C) { 201 | r := NewExpRouter() 202 | 203 | l1 := makeLoc("loc1") 204 | c.Assert(r.AddLocation(`TrieRoute("POST", "/r1")`, l1), IsNil) 205 | 206 | l2 := makeLoc("loc2") 207 | c.Assert(r.AddLocation(`TrieRoute("GET", "/r1")`, l2), IsNil) 208 | 209 | req := makeReq("http://google.com/r1") 210 | req.GetHttpRequest().Method = "POST" 211 | 212 | out, err := r.Route(req) 213 | c.Assert(err, IsNil) 214 | c.Assert(out, Equals, l1) 215 | 216 | req.GetHttpRequest().Method = "GET" 217 | out, err = r.Route(req) 218 | c.Assert(err, IsNil) 219 | c.Assert(out, Equals, l2) 220 | } 221 | 222 | func (s *RouteSuite) TestMatchByMethod(c *C) { 223 | r := NewExpRouter() 224 | 225 | l1 := makeLoc("loc1") 226 | c.Assert(r.AddLocation(`Method("POST") && Path("/r1")`, l1), IsNil) 227 | 228 | l2 := makeLoc("loc2") 229 | c.Assert(r.AddLocation(`Method("GET") && Path("/r1")`, l2), IsNil) 230 | 231 | req := makeReq("http://google.com/r1") 232 | req.GetHttpRequest().Method = "POST" 233 | 234 | out, err := r.Route(req) 235 | c.Assert(err, IsNil) 236 | c.Assert(out, Equals, l1) 237 | 238 | req.GetHttpRequest().Method = "GET" 239 | out, err = r.Route(req) 240 | c.Assert(err, IsNil) 241 | c.Assert(out, Equals, l2) 242 | } 243 | 244 | func (s *RouteSuite) TestTrieMatchLongestPath(c *C) { 245 | r := NewExpRouter() 246 | 247 | l1 := makeLoc("loc1") 248 | c.Assert(r.AddLocation(`Method("POST") && Path("/r")`, l1), IsNil) 249 | 250 | l2 := makeLoc("loc2") 251 | c.Assert(r.AddLocation(`Method("POST") && Path("/r/hello")`, l2), IsNil) 252 | 253 | req := makeReq("http://google.com/r/hello") 254 | req.GetHttpRequest().Method = "POST" 255 | 256 | out, err := r.Route(req) 257 | c.Assert(err, IsNil) 258 | c.Assert(out, Equals, l2) 259 | } 260 | 261 | func (s *RouteSuite) 
TestRegexpMatchLongestPath(c *C) { 262 | r := NewExpRouter() 263 | 264 | l1 := makeLoc("loc1") 265 | c.Assert(r.AddLocation(`PathRegexp("/r")`, l1), IsNil) 266 | 267 | l2 := makeLoc("loc2") 268 | c.Assert(r.AddLocation(`PathRegexp("/r/hello")`, l2), IsNil) 269 | 270 | req := makeReq("http://google.com/r/hello") 271 | 272 | out, err := r.Route(req) 273 | c.Assert(err, IsNil) 274 | c.Assert(out, Equals, l2) 275 | } 276 | 277 | func makeReq(url string) request.Request { 278 | u := netutils.MustParseUrl(url) 279 | return &request.BaseRequest{ 280 | HttpRequest: &http.Request{URL: u, RequestURI: url}, 281 | } 282 | } 283 | 284 | func makeLoc(url string) location.Location { 285 | return &location.ConstHttpLocation{Url: url} 286 | } 287 | -------------------------------------------------------------------------------- /route/hostroute/host.go: -------------------------------------------------------------------------------- 1 | // Package hostroute routes requests by their hostname. 2 | package hostroute 3 | 4 | import ( 5 | "fmt" 6 | . "github.com/mailgun/vulcan/location" 7 | . "github.com/mailgun/vulcan/request" 8 | . 
"github.com/mailgun/vulcan/route" 9 | "strings" 10 | "sync" 11 | ) 12 | 13 | // HostRouter matches a request by its hostname (lower-cased, port stripped) and 14 | // delegates further matching to the inner router registered for that hostname. 15 | type HostRouter struct { 16 | routers map[string]Router 17 | mutex *sync.Mutex 18 | } 19 | 20 | // NewHostRouter returns a HostRouter with no routers registered. 21 | func NewHostRouter() *HostRouter { 22 | return &HostRouter{ 23 | mutex: &sync.Mutex{}, 24 | routers: make(map[string]Router), 25 | } 26 | } 27 | 28 | // Route delegates to the router registered for the request's hostname and 29 | // returns (nil, nil) when there is none. 30 | // NOTE(review): IPv6 literal hosts such as "[::1]:80" reduce to "[" here. 31 | func (h *HostRouter) Route(req Request) (Location, error) { 32 | hostname := strings.Split(strings.ToLower(req.GetHttpRequest().Host), ":")[0] 33 | 34 | // Hold the lock only for the lookup: delegating while locked would make every 35 | // request on every host wait behind the slowest inner router. 36 | h.mutex.Lock() 37 | matcher, exists := h.routers[hostname] 38 | h.mutex.Unlock() 39 | 40 | if !exists { 41 | return nil, nil 42 | } 43 | return matcher.Route(req) 44 | } 45 | 46 | // SetRouter registers router for hostname, replacing any previous one. 47 | // The hostname is lower-cased so it matches the lower-cased lookup in Route. 48 | func (h *HostRouter) SetRouter(hostname string, router Router) error { 49 | if router == nil { 50 | return fmt.Errorf("Router can not be nil") 51 | } 52 | 53 | h.mutex.Lock() 54 | defer h.mutex.Unlock() 55 | 56 | h.routers[strings.ToLower(hostname)] = router 57 | return nil 58 | } 59 | 60 | // GetRouter returns the router registered for hostname (case-insensitive), or nil. 61 | func (h *HostRouter) GetRouter(hostname string) Router { 62 | h.mutex.Lock() 63 | defer h.mutex.Unlock() 64 | 65 | return h.routers[strings.ToLower(hostname)] 66 | } 67 | 68 | // RemoveRouter unregisters the router for hostname (case-insensitive), if any. 69 | func (h *HostRouter) RemoveRouter(hostname string) { 70 | h.mutex.Lock() 71 | defer h.mutex.Unlock() 72 | 73 | delete(h.routers, strings.ToLower(hostname)) 74 | } 75 | -------------------------------------------------------------------------------- /route/hostroute/host_test.go: -------------------------------------------------------------------------------- 1 | package hostroute 2 | 3 | import ( 4 | . "github.com/mailgun/vulcan/location" 5 | . "github.com/mailgun/vulcan/netutils" 6 | . "github.com/mailgun/vulcan/request" 7 | . "github.com/mailgun/vulcan/route" 8 | . 
"gopkg.in/check.v1" 9 | "net/http" 10 | "testing" 11 | ) 12 | 13 | func TestHostRoute(t *testing.T) { TestingT(t) } 14 | 15 | type HostSuite struct { 16 | } 17 | 18 | var _ = Suite(&HostSuite{}) 19 | 20 | func (s *HostSuite) SetUpSuite(c *C) { 21 | } 22 | 23 | func (s *HostSuite) TestRouteEmpty(c *C) { 24 | m := NewHostRouter() 25 | 26 | out, err := m.Route(request("google.com", "http://google.com/")) 27 | c.Assert(err, IsNil) 28 | c.Assert(out, Equals, nil) 29 | } 30 | 31 | func (s *HostSuite) TestSetNil(c *C) { 32 | m := NewHostRouter() 33 | c.Assert(m.SetRouter("google.com", nil), NotNil) 34 | } 35 | 36 | func (s *HostSuite) TestRouteMatching(c *C) { 37 | m := NewHostRouter() 38 | r := &ConstRouter{Location: &Loc{Name: "a"}} 39 | m.SetRouter("google.com", r) 40 | 41 | out, err := m.Route(request("google.com", "http://google.com/")) 42 | c.Assert(err, IsNil) 43 | c.Assert(out, Equals, r.Location) 44 | } 45 | 46 | func (s *HostSuite) TestRouteMatchingMultiple(c *C) { 47 | m := NewHostRouter() 48 | rA := &ConstRouter{Location: &Loc{Name: "a"}} 49 | rB := &ConstRouter{Location: &Loc{Name: "b"}} 50 | m.SetRouter("google.com", rA) 51 | m.SetRouter("yahoo.com", rB) 52 | 53 | out, err := m.Route(request("google.com", "http://google.com/")) 54 | c.Assert(err, IsNil) 55 | c.Assert(out, Equals, rA.Location) 56 | 57 | out, err = m.Route(request("yahoo.com", "http://yahoo.com/")) 58 | c.Assert(err, IsNil) 59 | c.Assert(out, Equals, rB.Location) 60 | } 61 | 62 | func (s *HostSuite) TestRemove(c *C) { 63 | m := NewHostRouter() 64 | rA := &ConstRouter{Location: &Loc{Name: "a"}} 65 | rB := &ConstRouter{Location: &Loc{Name: "b"}} 66 | m.SetRouter("google.com", rA) 67 | m.SetRouter("yahoo.com", rB) 68 | 69 | out, err := m.Route(request("google.com", "http://google.com/")) 70 | c.Assert(err, IsNil) 71 | c.Assert(out, Equals, rA.Location) 72 | 73 | out, err = m.Route(request("yahoo.com", "http://yahoo.com/")) 74 | c.Assert(err, IsNil) 75 | c.Assert(out, Equals, 
rB.Location) 76 | 77 | m.RemoveRouter("yahoo.com") 78 | 79 | out, err = m.Route(request("google.com", "http://google.com/")) 80 | c.Assert(err, IsNil) 81 | c.Assert(out, Equals, rA.Location) 82 | 83 | out, err = m.Route(request("yahoo.com", "http://yahoo.com/")) 84 | c.Assert(err, IsNil) 85 | c.Assert(out, Equals, nil) 86 | } 87 | 88 | // Hostnames match case-insensitively on registration and lookup, and the port is ignored. 89 | func (s *HostSuite) TestRouteIgnoresCaseAndPort(c *C) { 90 | m := NewHostRouter() 91 | r := &ConstRouter{Location: &Loc{Name: "a"}} 92 | c.Assert(m.SetRouter("Google.com", r), IsNil) 93 | 94 | out, err := m.Route(request("GOOGLE.com:8080", "http://google.com/")) 95 | c.Assert(err, IsNil) 96 | c.Assert(out, Equals, r.Location) 97 | c.Assert(m.GetRouter("google.COM"), Equals, r) 98 | 99 | m.RemoveRouter("GOOGLE.COM") 100 | c.Assert(m.GetRouter("google.com"), IsNil) 101 | } 102 | 103 | // request builds a Request for url whose Host is set to hostname. 104 | func request(hostname, url string) Request { 105 | u := MustParseUrl(url) 106 | hr := &http.Request{URL: u, Header: make(http.Header), Host: hostname} 107 | return &BaseRequest{ 108 | HttpRequest: hr, 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /route/pathroute/route.go: -------------------------------------------------------------------------------- 1 | // Package pathroute routes requests by matching the URL path against regular expressions. 2 | package pathroute 3 | 4 | import ( 5 | "bytes" 6 | "fmt" 7 | . "github.com/mailgun/vulcan/location" 8 | . "github.com/mailgun/vulcan/request" 9 | "regexp" 10 | "sort" 11 | "sync" 12 | ) 13 | 14 | // PathRouter matches a location by path regular expression.
15 | // When several patterns match, the longest pattern wins. 16 | type PathRouter struct { 17 | locations []locPair 18 | expression *regexp.Regexp 19 | mutex *sync.Mutex 20 | } 21 | 22 | // locPair ties a path pattern to its location. 23 | type locPair struct { 24 | pattern string 25 | location Location 26 | // group is the index of the capture group that wraps pattern in the combined 27 | // expression; buildMapping assigns it. 28 | group int 29 | } 30 | 31 | // ByPattern sorts locations so that longer patterns come first. 32 | type ByPattern []locPair 33 | 34 | func (a ByPattern) Len() int { return len(a) } 35 | func (a ByPattern) Swap(i, j int) { a[i], a[j] = a[j], a[i] } 36 | func (a ByPattern) Less(i, j int) bool { return len(a[i].pattern) > len(a[j].pattern) } 37 | 38 | // NewPathRouter returns a PathRouter with no locations. 39 | func NewPathRouter() *PathRouter { 40 | return &PathRouter{ 41 | mutex: &sync.Mutex{}, 42 | } 43 | } 44 | 45 | // Route returns the location whose pattern matches the whole request path (an 46 | // optional trailing slash is allowed), or (nil, nil) when nothing matches. 47 | func (m *PathRouter) Route(req Request) (Location, error) { 48 | m.mutex.Lock() 49 | defer m.mutex.Unlock() 50 | 51 | if m.expression == nil { 52 | return nil, nil 53 | } 54 | 55 | path := req.GetHttpRequest().URL.Path 56 | if len(path) == 0 { 57 | path = "/" 58 | } 59 | 60 | matches := m.expression.FindStringSubmatchIndex(path) 61 | if matches == nil { 62 | return nil, nil 63 | } 64 | // Only the alternative that matched has its wrapping group set. 65 | for _, p := range m.locations { 66 | start := 2 * p.group 67 | if start >= len(matches) { 68 | return nil, fmt.Errorf("Internal logic error: %d", p.group) 69 | } 70 | if matches[start] != -1 { 71 | return p.location, nil 72 | } 73 | } 74 | 75 | return nil, nil 76 | } 77 | 78 | // AddLocation registers location under pattern. It fails, leaving the router 79 | // unchanged, if pattern does not compile or is already registered. 80 | func (m *PathRouter) AddLocation(pattern string, location Location) error { 81 | m.mutex.Lock() 82 | defer m.mutex.Unlock() 83 | 84 | if _, err := regexp.Compile(pattern); err != nil { 85 | return fmt.Errorf("Pattern '%s' does not compile into regular expression: %s", pattern, err) 86 | } 87 | 88 | for _, p := range m.locations { 89 | if p.pattern == pattern { 90 | return fmt.Errorf("Pattern: %s already exists", pattern) 91 | } 92 | } 93 | 94 | // Work on a copy: appending to m.locations may reuse its backing array, and 95 | // sorting that in place would corrupt the live list if buildMapping failed. 96 | locations := make([]locPair, 0, len(m.locations)+1) 97 | locations = append(locations, m.locations...) 98 | locations = append(locations, locPair{pattern: pattern, location: location}) 99 | 100 | // Stable keeps registration order among patterns of equal length. 101 | sort.Stable(ByPattern(locations)) 102 | expression, err := buildMapping(locations) 103 | if err != nil { 104 | return err 105 | } 106 | 107 | m.locations = locations 108 | m.expression = expression 109 | 110 | return nil 111 | }
112 | 113 | // GetLocationByPattern returns the location registered under pattern, or nil. 114 | func (m *PathRouter) GetLocationByPattern(pattern string) Location { 115 | m.mutex.Lock() 116 | defer m.mutex.Unlock() 117 | 118 | for _, p := range m.locations { 119 | if p.pattern == pattern { 120 | return p.location 121 | } 122 | } 123 | return nil 124 | } 125 | 126 | // GetLocationById returns the first location whose GetId() equals id, or nil. 127 | func (m *PathRouter) GetLocationById(id string) Location { 128 | m.mutex.Lock() 129 | defer m.mutex.Unlock() 130 | 131 | for _, p := range m.locations { 132 | if p.location.GetId() == id { 133 | return p.location 134 | } 135 | } 136 | return nil 137 | } 138 | 139 | // RemoveLocation unregisters location (compared with ==) and rebuilds the 140 | // combined expression. It returns an error if location is nil. 141 | func (m *PathRouter) RemoveLocation(location Location) error { 142 | m.mutex.Lock() 143 | defer m.mutex.Unlock() 144 | 145 | if location == nil { 146 | return fmt.Errorf("Pass location to remove") 147 | } 148 | 149 | for i, p := range m.locations { 150 | if p.location == location { 151 | // Safe: iteration stops right after the slice is modified. 152 | m.locations = append(m.locations[:i], m.locations[i+1:]...) 153 | break 154 | } 155 | } 156 | if len(m.locations) != 0 { 157 | sort.Stable(ByPattern(m.locations)) 158 | } 159 | 160 | expression, err := buildMapping(m.locations) 161 | if err == nil { 162 | m.expression = expression 163 | } else { 164 | m.expression = nil 165 | } 166 | return err 167 | }
168 | 169 | // buildMapping joins all patterns into one anchored alternation, 170 | // ^(?:(p0)|(p1)|...)/?$, and records in each locPair which capture group wraps 171 | // its pattern. It returns (nil, nil) for an empty list. 172 | func buildMapping(locations []locPair) (*regexp.Regexp, error) { 173 | if len(locations) == 0 { 174 | return nil, nil 175 | } 176 | out := &bytes.Buffer{} 177 | // The non-capturing group makes ^ and /?$ apply to every alternative; without 178 | // it ^ binds only to the first pattern and $ only to the last one. 179 | out.WriteString("^(?:") 180 | group := 1 181 | for i, p := range locations { 182 | re, err := regexp.Compile(p.pattern) 183 | if err != nil { 184 | return nil, err 185 | } 186 | // Patterns may contain capture groups of their own, so the group wrapping 187 | // pattern i is not simply i+1. 188 | locations[i].group = group 189 | group += 1 + re.NumSubexp() 190 | if i != 0 { 191 | out.WriteString("|") 192 | } 193 | out.WriteString("(") 194 | out.WriteString(p.pattern) 195 | out.WriteString(")") 196 | } 197 | // Close the group and allow an optional trailing slash. 198 | out.WriteString(")/?$") 199 | return regexp.Compile(out.String()) 200 | } 201 | -------------------------------------------------------------------------------- /route/pathroute/route_test.go: -------------------------------------------------------------------------------- 1 | package pathroute 2 | 3 | import ( 4 | "fmt" 5 | . "github.com/mailgun/vulcan/location" 6 | . "github.com/mailgun/vulcan/netutils" 7 | . "github.com/mailgun/vulcan/request" 8 | "github.com/mailgun/vulcan/testutils" 9 | . 
"gopkg.in/check.v1" 10 | "net/http" 11 | "testing" 12 | ) 13 | 14 | func TestPathRoute(t *testing.T) { TestingT(t) } 15 | 16 | type MatchSuite struct { 17 | } 18 | 19 | var _ = Suite(&MatchSuite{}) 20 | 21 | func (s *MatchSuite) SetUpSuite(c *C) { 22 | } 23 | 24 | func (s *MatchSuite) TestRouteEmpty(c *C) { 25 | m := NewPathRouter() 26 | out, err := m.Route(request("http://google.com/")) 27 | c.Assert(err, IsNil) 28 | c.Assert(out, Equals, nil) 29 | } 30 | 31 | func (s *MatchSuite) TestRemoveNonExistent(c *C) { 32 | m := NewPathRouter() 33 | c.Assert(m.RemoveLocation(m.GetLocationByPattern("ooo")), Not(Equals), nil) 34 | } 35 | 36 | func (s *MatchSuite) TestAddTwice(c *C) { 37 | m := NewPathRouter() 38 | loc := &Loc{Name: "a"} 39 | c.Assert(m.AddLocation("/a", loc), IsNil) 40 | c.Assert(m.AddLocation("/a", loc), Not(Equals), nil) 41 | } 42 | 43 | func (s *MatchSuite) TestSingleLocation(c *C) { 44 | m := NewPathRouter() 45 | loc := &Loc{Name: "a"} 46 | c.Assert(m.AddLocation("/", loc), IsNil) 47 | out, err := m.Route(request("http://google.com/")) 48 | c.Assert(err, IsNil) 49 | c.Assert(out, Equals, loc) 50 | } 51 | 52 | func (s *MatchSuite) TestEmptyPath(c *C) { 53 | m := NewPathRouter() 54 | loc := &Loc{Name: "a"} 55 | c.Assert(m.AddLocation("/", loc), IsNil) 56 | out, err := m.Route(request("http://google.com")) 57 | c.Assert(err, IsNil) 58 | c.Assert(out, Equals, loc) 59 | } 60 | 61 | func (s *MatchSuite) TestMatchNothing(c *C) { 62 | m := NewPathRouter() 63 | loc := &Loc{Name: "a"} 64 | c.Assert(m.AddLocation("/", loc), IsNil) 65 | out, err := m.Route(request("http://google.com/hello/there")) 66 | c.Assert(err, IsNil) 67 | c.Assert(out, Equals, nil) 68 | } 69 | 70 | // A request matches whether or not it ends with a trailing slash. 71 | func (s *MatchSuite) TestTrailingSlashes(c *C) { 72 | m := NewPathRouter() 73 | loc := &Loc{Name: "a"} 74 | c.Assert(m.AddLocation("/a/b", loc), IsNil) 75 | 76 | out, err := m.Route(request("http://google.com/a/b")) 77 
| c.Assert(err, IsNil) 78 | c.Assert(out, Equals, loc) 79 | 80 | out, err = m.Route(request("http://google.com/a/b/")) 81 | c.Assert(err, IsNil) 82 | c.Assert(out, Equals, loc) 83 | } 84 | 85 | // A pattern that ends with a slash only matches requests that end with one too. 86 | func (s *MatchSuite) TestPatternTrailingSlashes(c *C) { 87 | m := NewPathRouter() 88 | loc := &Loc{Name: "a"} 89 | c.Assert(m.AddLocation("/a/b/", loc), IsNil) 90 | 91 | out, err := m.Route(request("http://google.com/a/b")) 92 | c.Assert(err, IsNil) 93 | c.Assert(out, Equals, nil) 94 | 95 | out, err = m.Route(request("http://google.com/a/b/")) 96 | c.Assert(err, IsNil) 97 | c.Assert(out, Equals, loc) 98 | } 99 | 100 | func (s *MatchSuite) TestMultipleLocations(c *C) { 101 | m := NewPathRouter() 102 | locA := &Loc{Name: "a"} 103 | locB := &Loc{Name: "b"} 104 | 105 | c.Assert(m.AddLocation("/a/there", locA), IsNil) 106 | c.Assert(m.AddLocation("/c", locB), IsNil) 107 | 108 | out, err := m.Route(request("http://google.com/a/there")) 109 | c.Assert(err, IsNil) 110 | c.Assert(out, Equals, locA) 111 | 112 | out, err = m.Route(request("http://google.com/c")) 113 | c.Assert(err, IsNil) 114 | c.Assert(out, Equals, locB) 115 | } 116 | 117 | func (s *MatchSuite) TestChooseLongest(c *C) { 118 | m := NewPathRouter() 119 | locA := &Loc{Name: "a"} 120 | locB := &Loc{Name: "b"} 121 | 122 | c.Assert(m.AddLocation("/a/there", locA), IsNil) 123 | c.Assert(m.AddLocation("/a", locB), IsNil) 124 | 125 | out, err := m.Route(request("http://google.com/a/there")) 126 | c.Assert(err, IsNil) 127 | c.Assert(out, Equals, locA) 128 | 129 | out, err = m.Route(request("http://google.com/a")) 130 | c.Assert(err, IsNil) 131 | c.Assert(out, Equals, locB) 132 | } 133 | 134 | func (s *MatchSuite) TestRemove(c *C) { 135 | m := NewPathRouter() 136 | locA := &Loc{Name: "a"} 137 | locB := &Loc{Name: "b"} 138 | 139 | c.Assert(m.AddLocation("/a", locA), IsNil) 140 | c.Assert(m.AddLocation("/b", locB), IsNil) 141 | 142 | out, err := 
m.Route(request("http://google.com/a")) 143 | c.Assert(err, IsNil) 144 | c.Assert(out, Equals, locA) 145 | 146 | out, err = m.Route(request("http://google.com/b")) 147 | c.Assert(err, IsNil) 148 | c.Assert(out, Equals, locB) 149 | 150 | // Remove the location and make sure the matcher is still valid 151 | c.Assert(m.RemoveLocation(m.GetLocationByPattern("/b")), IsNil) 152 | 153 | out, err = m.Route(request("http://google.com/a")) 154 | c.Assert(err, IsNil) 155 | c.Assert(out, Equals, locA) 156 | 157 | out, err = m.Route(request("http://google.com/b")) 158 | c.Assert(err, IsNil) 159 | c.Assert(out, Equals, nil) 160 | } 161 | 162 | func (s *MatchSuite) TestAddBad(c *C) { 163 | m := NewPathRouter() 164 | locA := &Loc{Name: "a"} 165 | locB := &Loc{Name: "b"} 166 | 167 | c.Assert(m.AddLocation("/a/there", locA), IsNil) 168 | 169 | out, err := m.Route(request("http://google.com/a/there")) 170 | c.Assert(err, IsNil) 171 | c.Assert(out, Equals, locA) 172 | 173 | c.Assert(m.AddLocation("--(", locB), Not(Equals), nil) 174 | 175 | out, err = m.Route(request("http://google.com/a/there")) 176 | c.Assert(err, IsNil) 177 | c.Assert(out, Equals, locA) 178 | } 179 | 180 | // Every pattern must match the whole path, however many patterns are registered. 181 | func (s *MatchSuite) TestMatchWholePath(c *C) { 182 | m := NewPathRouter() 183 | locA := &Loc{Name: "a"} 184 | locB := &Loc{Name: "b"} 185 | 186 | c.Assert(m.AddLocation("/a/there", locA), IsNil) 187 | c.Assert(m.AddLocation("/c", locB), IsNil) 188 | 189 | out, err := m.Route(request("http://google.com/a/there/more")) 190 | c.Assert(err, IsNil) 191 | c.Assert(out, Equals, nil) 192 | 193 | out, err = m.Route(request("http://google.com/x/c")) 194 | c.Assert(err, IsNil) 195 | c.Assert(out, Equals, nil) 196 | } 197 | 198 | // Capture groups inside a pattern must not shift the lookup of other locations. 199 | func (s *MatchSuite) TestPatternWithGroups(c *C) { 200 | m := NewPathRouter() 201 | locA := &Loc{Name: "a"} 202 | locB := &Loc{Name: "b"} 203 | 204 | c.Assert(m.AddLocation(`/v(1|2)/users`, locA), IsNil) 205 | c.Assert(m.AddLocation("/b", locB), IsNil) 206 | 207 | out, err := m.Route(request("http://google.com/v2/users")) 208 | c.Assert(err, IsNil) 209 | c.Assert(out, Equals, locA) 210 | 211 | out, err = m.Route(request("http://google.com/b")) 212 | c.Assert(err, IsNil) 213 | c.Assert(out, Equals, locB) 214 | } 215 | 216 | func (s *MatchSuite) BenchmarkMatching(c *C) { 217 | rndString := testutils.NewRndString() 218 | 219 | m := NewPathRouter() 220 | loc := &Loc{Name: "a"} 221 | 222 | for i := 0; i < 100; i++ { 223 | err := m.AddLocation(rndString.MakePath(20, 10), loc) 224 | c.Assert(err, IsNil) 225 | } 226 | 227 | // MakePath already starts with "/", so it is appended to the host directly. 228 | req := request(fmt.Sprintf("http://google.com%s", rndString.MakePath(20, 10))) 229 | for i := 0; i < c.N; i++ { 230 | m.Route(req) 231 | } 232 | } 233 | 234 | // request wraps url in a BaseRequest; no Host header is set. 235 | func request(url string) Request { 236 | u := MustParseUrl(url) 237 | return &BaseRequest{ 238 | HttpRequest: &http.Request{URL: u}, 239 | } 240 | } 241 | -------------------------------------------------------------------------------- /route/router.go: -------------------------------------------------------------------------------- 1 | // Route the request to a
location 2 | package route 3 | 4 | import ( 5 | . "github.com/mailgun/vulcan/location" 6 | . "github.com/mailgun/vulcan/request" 7 | ) 8 | 9 | // Router matches an incoming request to a specific location. 10 | type Router interface { 11 | // If error is not nil, the request will be aborted and the error will be proxied to the client. 12 | // If both location and error are nil, the router did not find any matching location. 13 | Route(req Request) (Location, error) 14 | } 15 | 16 | // ConstRouter is a helper router that always returns the same location. 17 | type ConstRouter struct { 18 | Location Location 19 | } 20 | 21 | func (m *ConstRouter) Route(req Request) (Location, error) { // Route ignores req and returns m.Location with a nil error. 22 | return m.Location, nil 23 | } 24 | -------------------------------------------------------------------------------- /template/template.go: -------------------------------------------------------------------------------- 1 | // Package template consolidates various templating utilities used throughout different 2 | // parts of vulcan. 3 | package template 4 | 5 | import ( 6 | "io" 7 | "io/ioutil" 8 | "net/http" 9 | "text/template" 10 | ) 11 | 12 | // data is the value templates are executed against; the request is available as .Request. 13 | type data struct { 14 | Request *http.Request 15 | } 16 | 17 | // Apply reads a template string from the provided reader, applies variables 18 | // from the provided request object to it and writes the result into 19 | // the provided writer. 20 | // 21 | // The template syntax is Go's standard text/template: http://golang.org/pkg/text/template/. 22 | func Apply(in io.Reader, out io.Writer, request *http.Request) error { 23 | body, err := ioutil.ReadAll(in) 24 | if err != nil { 25 | return err 26 | } 27 | 28 | return ApplyString(string(body), out, request) 29 | } 30 | 31 | // ApplyString applies variables from the provided request object to the provided 32 | // template string and writes the result into the provided writer. 33 | // 34 | // The template syntax is Go's standard text/template: http://golang.org/pkg/text/template/.
35 | func ApplyString(in string, out io.Writer, request *http.Request) error { 36 | t, err := template.New("t").Parse(in) 37 | if err != nil { 38 | return err 39 | } 40 | 41 | if err = t.Execute(out, data{request}); err != nil { // partial output may already have been written to out 42 | return err 43 | } 44 | 45 | return nil 46 | } 47 | -------------------------------------------------------------------------------- /template/template_test.go: -------------------------------------------------------------------------------- 1 | package template 2 | 3 | import ( 4 | "bytes" 5 | "net/http" 6 | "strings" 7 | "testing" 8 | 9 | . "gopkg.in/check.v1" 10 | ) 11 | 12 | func TestTemplate(t *testing.T) { TestingT(t) } 13 | 14 | type TemplateSuite struct{} 15 | 16 | var _ = Suite(&TemplateSuite{}) 17 | 18 | func (s *TemplateSuite) SetUpSuite(c *C) { 19 | } 20 | 21 | func (s *TemplateSuite) TestTemplateOkay(c *C) { 22 | request, _ := http.NewRequest("GET", "http://foo", nil) 23 | request.Header.Add("X-Header", "bar") 24 | 25 | out := &bytes.Buffer{} 26 | err := Apply(strings.NewReader(`foo {{.Request.Header.Get "X-Header"}}`), out, request) 27 | c.Assert(err, IsNil) 28 | c.Assert(out.String(), Equals, "foo bar") 29 | } 30 | 31 | func (s *TemplateSuite) TestBadTemplate(c *C) { 32 | request, _ := http.NewRequest("GET", "http://foo", nil) 33 | request.Header.Add("X-Header", "bar") 34 | 35 | out := &bytes.Buffer{} 36 | err := Apply(strings.NewReader(`foo {{.Request.Header.Get "X-Header"`), out, request) 37 | c.Assert(err, NotNil) 38 | c.Assert(out.String(), Equals, "") 39 | } 40 | 41 | func (s *TemplateSuite) TestNoVariables(c *C) { 42 | request, _ := http.NewRequest("GET", "http://foo", nil) 43 | request.Header.Add("X-Header", "bar") 44 | 45 | out := &bytes.Buffer{} 46 | err := Apply(strings.NewReader(`foo baz`), out, request) 47 | c.Assert(err, IsNil) 48 | c.Assert(out.String(), Equals, "foo baz") 49 | } 50 | 51 | func (s *TemplateSuite) TestNonexistentVariable(c *C) { 52 | request, _ := 
http.NewRequest("GET", "http://foo", nil) 53 
| request.Header.Add("X-Header", "bar") 54 | 55 | out := &bytes.Buffer{} 56 | err := Apply(strings.NewReader(`foo {{.Request.Header.Get "Y-Header"}}`), out, request) 57 | c.Assert(err, IsNil) 58 | c.Assert(out.String(), Equals, "foo ") 59 | } 60 | -------------------------------------------------------------------------------- /testutils/requests.go: -------------------------------------------------------------------------------- 1 | package testutils 2 | 3 | import ( 4 | "crypto/tls" 5 | "fmt" 6 | "io/ioutil" 7 | "net/http" 8 | "net/http/httptest" 9 | "strings" 10 | 11 | "github.com/mailgun/vulcan/netutils" 12 | ) 13 | 14 | // Opts customizes a request sent by MakeRequest. 15 | type Opts struct { 16 | Host string 17 | Method string 18 | Body string 19 | Headers http.Header 20 | } 21 | 22 | // MakeRequest sends a request to url (GET unless opts.Method is set) without 23 | // following redirects, and returns the response together with its fully read 24 | // body. The response body is closed before MakeRequest returns. 25 | func MakeRequest(url string, opts Opts) (*http.Response, []byte, error) { 26 | method := "GET" 27 | if opts.Method != "" { 28 | method = opts.Method 29 | } 30 | request, err := http.NewRequest(method, url, strings.NewReader(opts.Body)) 31 | if err != nil { 32 | return nil, nil, err 33 | } 34 | if opts.Headers != nil { 35 | netutils.CopyHeaders(request.Header, opts.Headers) 36 | } 37 | 38 | if len(opts.Host) != 0 { 39 | request.Host = opts.Host 40 | } 41 | 42 | var tr *http.Transport 43 | if strings.HasPrefix(url, "https") { 44 | tr = &http.Transport{ 45 | DisableKeepAlives: true, 46 | TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, 47 | } 48 | } else { 49 | tr = &http.Transport{ 50 | DisableKeepAlives: true, 51 | } 52 | } 53 | 54 | client := &http.Client{ 55 | Transport: tr, 56 | CheckRedirect: func(req *http.Request, via []*http.Request) error { 57 | return fmt.Errorf("No redirects") 58 | }, 59 | } 60 | response, err := client.Do(request) 61 | if err != nil { 62 | // When CheckRedirect refuses a redirect, Do returns the redirect response 63 | // (body already closed) along with the error. 64 | return response, nil, err 65 | } 66 | defer response.Body.Close() 67 | 68 | bodyBytes, err := ioutil.ReadAll(response.Body) 69 | return response, bodyBytes, err 70 | } 71 | 72 | // GET is MakeRequest with the method forced to GET. 73 | func GET(url string, o Opts) (*http.Response, []byte, error) { 74 | o.Method = "GET" 75 | return MakeRequest(url, o) 76 | } 77 | 78 | // WebHandler is the signature of a plain HTTP handler function. 79 | type WebHandler func(http.ResponseWriter, *http.Request) 80 | 81 | // NewTestServer starts an httptest.Server that serves handler. 82 | func NewTestServer(handler WebHandler) *httptest.Server { 83 | return httptest.NewServer(http.HandlerFunc(handler)) 84 | } 85 | 86 | // NewTestResponder starts an httptest.Server that answers every request with response. 87 | func NewTestResponder(response string) *httptest.Server { 88 | return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 89 | w.Write([]byte(response)) 90 | })) 91 | } 92 | 
-------------------------------------------------------------------------------- /testutils/rndstring.go: -------------------------------------------------------------------------------- 1 | package testutils 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "math/rand" 8 | "time" 9 | ) 10 | 11 | // RndString generates random lowercase ASCII strings from its own time-seeded source. 12 | type RndString struct { 13 | src rand.Source 14 | } 15 | 16 | // NewRndString returns a RndString seeded with the current time. 17 | func NewRndString() *RndString { 18 | return &RndString{rand.NewSource(time.Now().UTC().UnixNano())} 19 | } 20 | 21 | // Read fills p with random letters 'a'..'z' and never fails. 22 | func (r *RndString) Read(p []byte) (n int, err error) { 23 | for i := range p { 24 | p[i] = byte(r.src.Int63()%26 + 97) 25 | } 26 | return len(p), nil 27 | } 28 | 29 | // MakeString returns a random string of n letters. 30 | func (r *RndString) MakeString(n int) string { 31 | buffer := &bytes.Buffer{} 32 | io.CopyN(buffer, r, int64(n)) 33 | return buffer.String() 34 | } 35 | 36 | // MakePath returns "/" followed by between minlen and minlen+varlen-1 random 37 | // letters. The length is drawn from r's own source rather than the global one, 38 | // so all output comes from the same seeded generator. varlen must be positive. 39 | func (r *RndString) MakePath(varlen, minlen int) string { 40 | return fmt.Sprintf("/%s", r.MakeString(rand.New(r.src).Intn(varlen)+minlen)) 41 | } 42 | -------------------------------------------------------------------------------- /threshold/parse.go: -------------------------------------------------------------------------------- 1 | package threshold 2 | 3 | import ( 4 | "fmt" 5 | "regexp" 6 | 7 | "github.com/mailgun/predicate" 8 | ) 9 | 10 | // ParseExpression parses a Go-syntax expression into a Predicate, after first rewriting legacy forms such as ResponseCodeEq(503). 11 | func ParseExpression(in string) (Predicate, error) { 12 | p, err := predicate.NewParser(predicate.Def{ 13 | Operators: predicate.Operators{ 14 | AND: AND, 15 | OR: OR, 16 | EQ: EQ, 17 | NEQ: NEQ, 18 | LT: LT, 19 | LE: LE, 20 | GT: GT, 21 | GE: GE, 22 | }, 23 | Functions: map[string]interface{}{ 24 | "RequestMethod": 
RequestMethod, 25 | "IsNetworkError": IsNetworkError, 26 | "Attempts": Attempts, 27 | "ResponseCode": ResponseCode, 28 | }, 29 | }) 30 | if err != nil { 31 | return nil, err 32 | } 33 | out, err := p.Parse(convertLegacy(in)) 34 | if err != nil { 35 | return nil, err 36 | } 37 | pr, ok := out.(Predicate) 38 | if !ok { 39 | return nil, fmt.Errorf("expected predicate, got %T", out) 40 | } 41 | return pr, nil 42 | } 43 | 44 | // legacyRules rewrite the older function-style predicates into the current 45 | // expression syntax. They are compiled once, at package initialization. 46 | var legacyRules = []struct { 47 | re *regexp.Regexp 48 | replacement string 49 | }{ 50 | // $1 re-emits the character consumed after a bare IsNetworkError; dropping it 51 | // would swallow e.g. the ")" in "(IsNetworkError)" or one "&" of "&&". 52 | {regexp.MustCompile(`IsNetworkError([^\(]|$)`), "IsNetworkError()$1"}, 53 | {regexp.MustCompile(`ResponseCodeEq\((\d+)\)`), "ResponseCode() == $1"}, 54 | {regexp.MustCompile(`RequestMethodEq\(("[^"]+")\)`), `RequestMethod() == $1`}, 55 | {regexp.MustCompile(`AttemptsLe\((\d+)\)`), "Attempts() <= $1"}, 56 | } 57 | 58 | // convertLegacy rewrites legacy predicates in its input (e.g. ResponseCodeEq(503), 59 | // a bare IsNetworkError) into the current expression syntax. 60 | func convertLegacy(in string) string { 61 | for _, rule := range legacyRules { 62 | in = rule.re.ReplaceAllString(in, rule.replacement) 63 | } 64 | return in 65 | } 66 | -------------------------------------------------------------------------------- /threshold/parse_test.go: -------------------------------------------------------------------------------- 1 | package threshold 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "testing" 7 | 8 | . "github.com/mailgun/vulcan/request" 9 | . 
"gopkg.in/check.v1" 10 | ) 11 | 12 | func Test(t *testing.T) { TestingT(t) } 13 | 14 | type ThresholdSuite struct { 15 | } 16 | 17 | var _ = Suite(&ThresholdSuite{}) 18 | 19 | func (s *ThresholdSuite) TestSuccessOnGets(c *C) { 20 | p, err := ParseExpression(`RequestMethod() == "GET"`) 21 | c.Assert(err, IsNil) 22 | 23 | c.Assert(p(&BaseRequest{HttpRequest: &http.Request{Method: "GET"}}), Equals, true) 24 | c.Assert(p(&BaseRequest{HttpRequest: &http.Request{Method: "POST"}}), Equals, false) 25 | } 26 | 27 | func (s *ThresholdSuite) TestSuccessOnGetsLegacy(c *C) { 28 | p, err := ParseExpression(`RequestMethodEq("GET")`) 29 | c.Assert(err, IsNil) 30 | 31 | c.Assert(p(&BaseRequest{HttpRequest: &http.Request{Method: "GET"}}), Equals, true) 32 | c.Assert(p(&BaseRequest{HttpRequest: &http.Request{Method: "POST"}}), Equals, false) 33 | } 34 | 35 | func (s *ThresholdSuite) TestSuccessOnGetsAndErrors(c *C) { 36 | p, err := ParseExpression(`(RequestMethod() == "GET") && IsNetworkError()`) 37 | c.Assert(err, IsNil) 38 | 39 | // There's no error 40 | c.Assert(p(&BaseRequest{HttpRequest: &http.Request{Method: "GET"}}), Equals, false) 41 | 42 | // This one allows error 43 | req := &BaseRequest{ 44 | HttpRequest: &http.Request{Method: "GET"}, 45 | Attempts: []Attempt{ 46 | &BaseAttempt{ 47 | Error: fmt.Errorf("Something failed"), 48 | }, 49 | }, 50 | } 51 | c.Assert(p(req), Equals, true) 52 | } 53 | 54 | func (s *ThresholdSuite) TestLegacyIsNetworkError(c *C) { 55 | p, err := ParseExpression(`ResponseCodeEq(503) || IsNetworkError`) 56 | c.Assert(err, IsNil) 57 | 58 | // There's no error 59 | c.Assert(p(&BaseRequest{}), Equals, false) 60 | 61 | // There's a network error 62 | req := &BaseRequest{ 63 | Attempts: []Attempt{ 64 | &BaseAttempt{ 65 | Error: fmt.Errorf("Something failed"), 66 | }, 67 | }, 68 | } 69 | c.Assert(p(req), Equals, true) 70 | 71 | // There's a 503 response code 72 | req = &BaseRequest{ 73 | Attempts: []Attempt{ 74 | &BaseAttempt{ 75 | Response: 
&http.Response{StatusCode: 503}, 76 | }, 77 | }, 78 | } 79 | c.Assert(p(req), Equals, true) 80 | 81 | // Different response code does not work 82 | req = &BaseRequest{ 83 | Attempts: []Attempt{ 84 | &BaseAttempt{ 85 | Response: &http.Response{StatusCode: 504}, 86 | }, 87 | }, 88 | } 89 | c.Assert(p(req), Equals, false) 90 | } 91 | 92 | func (s *ThresholdSuite) TestResponseCodeOrError(c *C) { 93 | p, err := ParseExpression(`ResponseCode() == 503 || IsNetworkError()`) 94 | c.Assert(err, IsNil) 95 | 96 | // There's no error 97 | c.Assert(p(&BaseRequest{}), Equals, false) 98 | 99 | // There's a network error 100 | req := &BaseRequest{ 101 | Attempts: []Attempt{ 102 | &BaseAttempt{ 103 | Error: fmt.Errorf("Something failed"), 104 | }, 105 | }, 106 | } 107 | c.Assert(p(req), Equals, true) 108 | 109 | // There's a 503 response code 110 | req = &BaseRequest{ 111 | Attempts: []Attempt{ 112 | &BaseAttempt{ 113 | Response: &http.Response{StatusCode: 503}, 114 | }, 115 | }, 116 | } 117 | c.Assert(p(req), Equals, true) 118 | 119 | // Different response code does not work 120 | req = &BaseRequest{ 121 | Attempts: []Attempt{ 122 | &BaseAttempt{ 123 | Response: &http.Response{StatusCode: 504}, 124 | }, 125 | }, 126 | } 127 | c.Assert(p(req), Equals, false) 128 | } 129 | 130 | func (s *ThresholdSuite) TestAttemptsLeLegacy(c *C) { 131 | p, err := ParseExpression(`AttemptsLe(1)`) 132 | c.Assert(err, IsNil) 133 | 134 | req := &BaseRequest{ 135 | Attempts: []Attempt{}, 136 | } 137 | c.Assert(p(req), Equals, true) 138 | 139 | req = &BaseRequest{ 140 | Attempts: []Attempt{ 141 | &BaseAttempt{ 142 | Response: &http.Response{StatusCode: 503}, 143 | }, 144 | &BaseAttempt{ 145 | Response: &http.Response{StatusCode: 503}, 146 | }, 147 | }, 148 | } 149 | c.Assert(p(req), Equals, false) 150 | } 151 | 152 | func (s *ThresholdSuite) TestAttemptsLT(c *C) { 153 | p, err := ParseExpression(`Attempts() < 1`) 154 | c.Assert(err, IsNil) 155 | 156 | req := &BaseRequest{ 157 | Attempts: []Attempt{}, 158 
| } 159 | c.Assert(p(req), Equals, true) 160 | 161 | req = &BaseRequest{ 162 | Attempts: []Attempt{ 163 | &BaseAttempt{ 164 | Response: &http.Response{StatusCode: 503}, 165 | }, 166 | &BaseAttempt{ 167 | Response: &http.Response{StatusCode: 503}, 168 | }, 169 | }, 170 | } 171 | c.Assert(p(req), Equals, false) 172 | } 173 | 174 | func (s *ThresholdSuite) TestAttemptsGT(c *C) { 175 | p, err := ParseExpression(`Attempts() > 1`) 176 | c.Assert(err, IsNil) 177 | 178 | req := &BaseRequest{ 179 | Attempts: []Attempt{}, 180 | } 181 | c.Assert(p(req), Equals, false) 182 | 183 | req = &BaseRequest{ 184 | Attempts: []Attempt{ 185 | &BaseAttempt{ 186 | Response: &http.Response{StatusCode: 503}, 187 | }, 188 | &BaseAttempt{ 189 | Response: &http.Response{StatusCode: 503}, 190 | }, 191 | }, 192 | } 193 | c.Assert(p(req), Equals, true) 194 | } 195 | 196 | func (s *ThresholdSuite) TestAttemptsGE(c *C) { 197 | p, err := ParseExpression(`Attempts() >= 1`) 198 | c.Assert(err, IsNil) 199 | 200 | req := &BaseRequest{ 201 | Attempts: []Attempt{}, 202 | } 203 | c.Assert(p(req), Equals, false) 204 | 205 | req = &BaseRequest{ 206 | Attempts: []Attempt{ 207 | &BaseAttempt{ 208 | Response: &http.Response{StatusCode: 503}, 209 | }, 210 | }, 211 | } 212 | c.Assert(p(req), Equals, true) 213 | 214 | req = &BaseRequest{ 215 | Attempts: []Attempt{ 216 | &BaseAttempt{ 217 | Response: &http.Response{StatusCode: 503}, 218 | }, 219 | &BaseAttempt{ 220 | Response: &http.Response{StatusCode: 503}, 221 | }, 222 | }, 223 | } 224 | c.Assert(p(req), Equals, true) 225 | } 226 | 227 | func (s *ThresholdSuite) TestAttemptsNE(c *C) { 228 | p, err := ParseExpression(`Attempts() != 1`) 229 | c.Assert(err, IsNil) 230 | 231 | req := &BaseRequest{ 232 | Attempts: []Attempt{}, 233 | } 234 | c.Assert(p(req), Equals, true) 235 | 236 | req = &BaseRequest{ 237 | Attempts: []Attempt{ 238 | &BaseAttempt{ 239 | Response: &http.Response{StatusCode: 503}, 240 | }, 241 | }, 242 | } 243 | c.Assert(p(req), Equals, false) 244 | }
245 | 246 | func (s *ThresholdSuite) TestComplexExpression(c *C) { 247 | p, err := ParseExpression(`(ResponseCode() == 503 || IsNetworkError()) && Attempts() <= 1`) 248 | c.Assert(err, IsNil) 249 | 250 | // 503 error and one attempt 251 | req := &BaseRequest{ 252 | Attempts: []Attempt{ 253 | &BaseAttempt{ 254 | Response: &http.Response{StatusCode: 503}, 255 | }, 256 | }, 257 | } 258 | c.Assert(p(req), Equals, true) 259 | 260 | // 503 error and more than one attempt 261 | req = &BaseRequest{ 262 | Attempts: []Attempt{ 263 | &BaseAttempt{ 264 | Response: &http.Response{StatusCode: 503}, 265 | }, 266 | &BaseAttempt{ 267 | Response: &http.Response{StatusCode: 503}, 268 | }, 269 | }, 270 | } 271 | c.Assert(p(req), Equals, false) 272 | } 273 | 274 | func (s *ThresholdSuite) TestComplexLegacyExpression(c *C) { 275 | p, err := ParseExpression(`(IsNetworkError || ResponseCodeEq(503)) && AttemptsLe(2)`) 276 | c.Assert(err, IsNil) 277 | 278 | // 503 error and one attempt 279 | req := &BaseRequest{ 280 | Attempts: []Attempt{ 281 | &BaseAttempt{ 282 | Response: &http.Response{StatusCode: 503}, 283 | }, 284 | }, 285 | } 286 | c.Assert(p(req), Equals, true) 287 | 288 | // Network error and one attempt 289 | req = &BaseRequest{ 290 | Attempts: []Attempt{ 291 | &BaseAttempt{ 292 | Error: fmt.Errorf("Something failed"), 293 | }, 294 | }, 295 | } 296 | c.Assert(p(req), Equals, true) 297 | 298 | // 503 error and three attempts 299 | req = &BaseRequest{ 300 | Attempts: []Attempt{ 301 | &BaseAttempt{ 302 | Response: &http.Response{StatusCode: 503}, 303 | }, 304 | &BaseAttempt{ 305 | Response: &http.Response{StatusCode: 503}, 306 | }, 307 | &BaseAttempt{ 308 | Response: &http.Response{StatusCode: 503}, 309 | }, 310 | }, 311 | } 312 | c.Assert(p(req), Equals, false) 313 | } 314 | 315 | // A bare legacy IsNetworkError must keep the character that follows it. 316 | func (s *ThresholdSuite) TestLegacyIsNetworkErrorFollowedBySymbol(c *C) { 317 | req := &BaseRequest{ 318 | Attempts: []Attempt{ 319 | &BaseAttempt{ 320 | Error: fmt.Errorf("Something failed"), 321 | }, 322 | }, 323 | } 324 | 325 | p, err := ParseExpression(`(ResponseCodeEq(503) || IsNetworkError) && AttemptsLe(2)`) 326 | c.Assert(err, IsNil) 327 | c.Assert(p(req), Equals, true) 328 | 329 | p, err = ParseExpression(`IsNetworkError&&AttemptsLe(1)`) 330 | c.Assert(err, IsNil) 331 | c.Assert(p(req), Equals, true) 332 | } 333 | 334 | func (s *ThresholdSuite) TestInvalidCases(c *C) { 335 | cases := []string{ 336 | ")(", // invalid expression 337 | "1", // standalone literal 338 | "SomeFunc", // unsupported id 339 | "RequestMethod() == 
banana", // unsupported argument 340 | "RequestMethod() == RequestMethod()", // unsupported argument 341 | "RequestMethod() == 0.2", // unsupported argument 342 | "RequestMethod(200) == 200", // wrong number of arguments 343 | `RequestMethod() == "POST" && 1`, // standalone literal in expression 344 | `1 && RequestMethod() == "POST"`, // standalone literal in expression 345 | `Req(1)`, // unknown method call 346 | `RequestMethod(1)`, // bad parameter type 347 | } 348 | for _, tc := range cases { 349 | p, err := ParseExpression(tc) 350 | c.Assert(err, NotNil) 351 | c.Assert(p, IsNil) 352 | } 353 | } 354 | -------------------------------------------------------------------------------- /threshold/threshold.go: -------------------------------------------------------------------------------- 1 | /* 2 | Package threshold contains predicates that can define various request thresholds 3 | 4 | Examples: 5 | 6 | * RequestMethod() == "GET" triggers action when request method equals "GET" 7 | * IsNetworkError() - triggers action on network errors 8 | * RequestMethod() == "GET" && Attempts() <= 2 && (IsNetworkError() || ResponseCode() == 408) 9 | This predicate triggers for GET requests with maximum 2 attempts 10 | on network errors or when upstream returns special http response code 408 11 | */ 12 | package threshold 13 | 14 | import ( 15 | "fmt" 16 | 17 | "github.com/mailgun/vulcan/request" 18 | ) 19 | 20 | // Predicate reports whether a request meets a condition, e.g. whether it may fail over after an error or HTTP response. 21 | type Predicate func(request.Request) bool 22 | 23 | // RequestToString defines mapper function that maps a request to some string (e.g extracts method name) 24 | type RequestToString func(req request.Request) string 25 | 26 | // RequestToInt defines mapper function that maps a request to some int (e.g extracts response code) 27 | type RequestToInt func(req request.Request) int 28 | 29 | // RequestToFloat64 defines mapper function that maps a request to some float64 (e.g extracts some 
ratio) 30 | type RequestToFloat64 func(req request.Request) float64 31 | 32 | // RequestMethod returns mapper of the request to its method e.g. POST 33 | func RequestMethod() RequestToString { 34 | return func(r request.Request) string { 35 | return r.GetHttpRequest().Method 36 | } 37 | } 38 | 39 | // Attempts returns mapper of the request to the number of proxy attempts 40 | func Attempts() RequestToInt { 41 | return func(r request.Request) int { 42 | return len(r.GetAttempts()) 43 | } 44 | } 45 | 46 | // ResponseCode returns mapper of the request to the last response code, returns 0 if there was no response code. 47 | func ResponseCode() RequestToInt { 48 | return func(r request.Request) int { 49 | attempts := len(r.GetAttempts()) 50 | if attempts == 0 { 51 | return 0 52 | } 53 | lastResponse := r.GetAttempts()[attempts-1].GetResponse() 54 | if lastResponse == nil { 55 | return 0 56 | } 57 | return lastResponse.StatusCode 58 | } 59 | } 60 | 61 | // IsNetworkError returns a predicate that returns true if last attempt ended with network error. 
62 | func IsNetworkError() Predicate { 63 | return func(r request.Request) bool { 64 | attempts := len(r.GetAttempts()) 65 | return attempts != 0 && r.GetAttempts()[attempts-1].GetError() != nil 66 | } 67 | } 68 | 69 | // AND returns predicate by joining the passed predicates with logical 'and' 70 | func AND(fns ...Predicate) Predicate { 71 | return func(req request.Request) bool { 72 | for _, fn := range fns { 73 | if !fn(req) { 74 | return false 75 | } 76 | } 77 | return true 78 | } 79 | } 80 | 81 | // OR returns predicate by joining the passed predicates with logical 'or' 82 | func OR(fns ...Predicate) Predicate { 83 | return func(req request.Request) bool { 84 | for _, fn := range fns { 85 | if fn(req) { 86 | return true 87 | } 88 | } 89 | return false 90 | } 91 | } 92 | 93 | // NOT creates negation of the passed predicate 94 | func NOT(p Predicate) Predicate { 95 | return func(r request.Request) bool { 96 | return !p(r) 97 | } 98 | } 99 | 100 | // EQ returns predicate that tests for equality of the value of the mapper and the constant 101 | func EQ(m interface{}, value interface{}) (Predicate, error) { 102 | switch mapper := m.(type) { 103 | case RequestToString: 104 | return stringEQ(mapper, value) 105 | case RequestToInt: 106 | return intEQ(mapper, value) 107 | } 108 | return nil, fmt.Errorf("unsupported argument: %T", m) 109 | } 110 | 111 | // NEQ returns predicate that tests for inequality of the value of the mapper and the constant 112 | func NEQ(m interface{}, value interface{}) (Predicate, error) { 113 | p, err := EQ(m, value) 114 | if err != nil { 115 | return nil, err 116 | } 117 | return NOT(p), nil 118 | } 119 | 120 | // LT returns predicate that tests that value of the mapper function is less than the constant 121 | func LT(m interface{}, value interface{}) (Predicate, error) { 122 | switch mapper := m.(type) { 123 | case RequestToInt: 124 | return intLT(mapper, value) 125 | case RequestToFloat64: 126 | return float64LT(mapper, value) 127 | } 128 | 
return nil, fmt.Errorf("unsupported argument: %T", m) 129 | } 130 | 131 | // GT returns predicate that tests that value of the mapper function is greater than the constant 132 | func GT(m interface{}, value interface{}) (Predicate, error) { 133 | switch mapper := m.(type) { 134 | case RequestToInt: 135 | return intGT(mapper, value) 136 | case RequestToFloat64: 137 | return float64GT(mapper, value) 138 | } 139 | return nil, fmt.Errorf("unsupported argument: %T", m) 140 | } 141 | 142 | // LE returns predicate that tests that value of the mapper function is less than or equal to the constant 143 | func LE(m interface{}, value interface{}) (Predicate, error) { 144 | switch mapper := m.(type) { 145 | case RequestToInt: 146 | return intLE(mapper, value) 147 | case RequestToFloat64: 148 | return float64LE(mapper, value) 149 | } 150 | return nil, fmt.Errorf("unsupported argument: %T", m) 151 | } 152 | 153 | // GE returns predicate that tests that value of the mapper function is greater than or equal to the constant 154 | func GE(m interface{}, value interface{}) (Predicate, error) { 155 | switch mapper := m.(type) { 156 | case RequestToInt: 157 | return intGE(mapper, value) 158 | case RequestToFloat64: 159 | return float64GE(mapper, value) 160 | } 161 | return nil, fmt.Errorf("unsupported argument: %T", m) 162 | } 163 | 164 | func stringEQ(m RequestToString, val interface{}) (Predicate, error) { 165 | value, ok := val.(string) 166 | if !ok { 167 | return nil, fmt.Errorf("expected string, got %T", val) 168 | } 169 | return func(req request.Request) bool { 170 | return m(req) == value 171 | }, nil 172 | } 173 | 174 | func intEQ(m RequestToInt, val interface{}) (Predicate, error) { 175 | value, ok := val.(int) 176 | if !ok { 177 | return nil, fmt.Errorf("expected int, got %T", val) 178 | } 179 | return func(req request.Request) bool { 180 | return m(req) == value 181 | }, nil 182 | } 183 | 184 | func intLT(m RequestToInt, val interface{}) (Predicate, error) { 185 | value, ok 
:= val.(int) 186 | if !ok { 187 | return nil, fmt.Errorf("expected int, got %T", val) 188 | } 189 | return func(req request.Request) bool { 190 | return m(req) < value 191 | }, nil 192 | } 193 | 194 | func intGT(m RequestToInt, val interface{}) (Predicate, error) { 195 | value, ok := val.(int) 196 | if !ok { 197 | return nil, fmt.Errorf("expected int, got %T", val) 198 | } 199 | return func(req request.Request) bool { 200 | return m(req) > value 201 | }, nil 202 | } 203 | 204 | func intLE(m RequestToInt, val interface{}) (Predicate, error) { 205 | value, ok := val.(int) 206 | if !ok { 207 | return nil, fmt.Errorf("expected int, got %T", val) 208 | } 209 | return func(req request.Request) bool { 210 | return m(req) <= value 211 | }, nil 212 | } 213 | 214 | func intGE(m RequestToInt, val interface{}) (Predicate, error) { 215 | value, ok := val.(int) 216 | if !ok { 217 | return nil, fmt.Errorf("expected int, got %T", val) 218 | } 219 | return func(req request.Request) bool { 220 | return m(req) >= value 221 | }, nil 222 | } 223 | 224 | func float64EQ(m RequestToFloat64, val interface{}) (Predicate, error) { 225 | value, ok := val.(float64) 226 | if !ok { 227 | return nil, fmt.Errorf("expected float64, got %T", val) 228 | } 229 | return func(req request.Request) bool { 230 | return m(req) == value 231 | }, nil 232 | } 233 | 234 | func float64LT(m RequestToFloat64, val interface{}) (Predicate, error) { 235 | value, ok := val.(float64) 236 | if !ok { 237 | return nil, fmt.Errorf("expected float64, got %T", val) 238 | } 239 | return func(req request.Request) bool { 240 | return m(req) < value 241 | }, nil 242 | } 243 | 244 | func float64GT(m RequestToFloat64, val interface{}) (Predicate, error) { 245 | value, ok := val.(float64) 246 | if !ok { 247 | return nil, fmt.Errorf("expected float64, got %T", val) 248 | } 249 | return func(req request.Request) bool { 250 | return m(req) > value 251 | }, nil 252 | } 253 | 254 | func float64LE(m RequestToFloat64, val interface{}) 
(Predicate, error) { 255 | value, ok := val.(float64) 256 | if !ok { 257 | return nil, fmt.Errorf("expected float64, got %T", val) 258 | } 259 | return func(req request.Request) bool { 260 | return m(req) <= value 261 | }, nil 262 | } 263 | 264 | func float64GE(m RequestToFloat64, val interface{}) (Predicate, error) { 265 | value, ok := val.(float64) 266 | if !ok { 267 | return nil, fmt.Errorf("expected float64, got %T", val) 268 | } 269 | return func(req request.Request) bool { 270 | return m(req) >= value 271 | }, nil 272 | } 273 | --------------------------------------------------------------------------------