├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── cbreaker ├── cbreaker.go ├── cbreaker_test.go ├── effect.go ├── fallback.go ├── predicates.go ├── predicates_test.go ├── ratio.go └── ratio_test.go ├── connlimit ├── connlimit.go └── connlimit_test.go ├── forward ├── fwd.go ├── fwd_test.go ├── fwd_websocket_test.go ├── headers.go ├── responseflusher.go ├── rewrite.go └── rewrite_test.go ├── memmetrics ├── anomaly.go ├── anomaly_test.go ├── counter.go ├── counter_test.go ├── histogram.go ├── histogram_test.go ├── ratio.go ├── ratio_test.go ├── roundtrip.go └── roundtrip_test.go ├── ratelimit ├── bucket.go ├── bucket_test.go ├── bucketset.go ├── bucketset_test.go ├── tokenlimiter.go └── tokenlimiter_test.go ├── roundrobin ├── rebalancer.go ├── rebalancer_test.go ├── rr.go ├── rr_test.go ├── stickysessions.go └── stickysessions_test.go ├── stream ├── retry_test.go ├── stream.go ├── stream_test.go └── threshold.go ├── testutils └── utils.go ├── trace ├── trace.go └── trace_test.go └── utils ├── auth.go ├── auth_test.go ├── handler.go ├── handler_test.go ├── logging.go ├── netutils.go ├── netutils_test.go └── source.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | *.prof 25 | 26 | flymake_* 27 | 28 | vendor/ 29 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | test: clean 2 | go test -v ./... -cover 3 | 4 | clean: 5 | find . 
-name "flymake_*" -delete 6 | 7 | test-package: clean 8 | go test -v ./$(p) 9 | 10 | test-grep-package: clean 11 | go test -v ./$(p) -check.f=$(e) 12 | 13 | cover-package: clean 14 | go test -v ./$(p) -coverprofile=/tmp/coverage.out 15 | go tool cover -html=/tmp/coverage.out 16 | 17 | sloccount: 18 | find . -path ./Godeps -prune -o -name "*.go" -print0 | xargs -0 wc -l 19 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Oxy 2 | ===== 3 | 4 | Oxy is a Go library with HTTP handlers that enhance the HTTP standard library: 5 | 6 | * [Stream](http://godoc.org/github.com/vulcand/oxy/stream) retries and buffers requests and responses 7 | * [Forward](http://godoc.org/github.com/vulcand/oxy/forward) forwards requests to a remote location and rewrites headers 8 | * [Roundrobin](http://godoc.org/github.com/vulcand/oxy/roundrobin) is a round-robin load balancer 9 | * [Circuit Breaker](http://godoc.org/github.com/vulcand/oxy/cbreaker) Hystrix-style circuit breaker 10 | * [Connlimit](http://godoc.org/github.com/vulcand/oxy/connlimit) Simultaneous connections limiter 11 | * [Ratelimit](http://godoc.org/github.com/vulcand/oxy/ratelimit) Rate limiter (based on tokenbucket algo) 12 | * [Trace](http://godoc.org/github.com/vulcand/oxy/trace) Structured request and response logger 13 | 14 | It is designed to be fully compatible with the http standard library, easy to customize and reuse. 15 | 16 | Status 17 | ------ 18 | 19 | * Initial design is completed 20 | * Covered by tests 21 | * Used as a reverse proxy engine in [Vulcand](https://github.com/vulcand/vulcand) 22 | 23 | Quickstart 24 | ----------- 25 | 26 | Every handler is ``http.Handler``, so writing and plugging in a middleware is easy. 
Let us write a simple reverse proxy as an example: 27 | 28 | Simple reverse proxy 29 | ==================== 30 | 31 | ```go 32 | 33 | import ( 34 | "net/http" 35 | "github.com/vulcand/oxy/forward" 36 | "github.com/vulcand/oxy/testutils" 37 | ) 38 | 39 | // Forwards incoming requests to whatever location URL points to, adds proper forwarding headers 40 | fwd, _ := forward.New() 41 | 42 | redirect := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 43 | // let us forward this request to another server 44 | req.URL = testutils.ParseURI("http://localhost:63450") 45 | fwd.ServeHTTP(w, req) 46 | }) 47 | 48 | // that's it! our reverse proxy is ready! 49 | s := &http.Server{ 50 | Addr: ":8080", 51 | Handler: redirect, 52 | } 53 | s.ListenAndServe() 54 | ``` 55 | 56 | As a next step, let us add a round robin load-balancer: 57 | 58 | 59 | ```go 60 | 61 | import ( 62 | "net/http" 63 | "github.com/vulcand/oxy/forward" 64 | "github.com/vulcand/oxy/roundrobin" 65 | ) 66 | 67 | // Forwards incoming requests to whatever location URL points to, adds proper forwarding headers 68 | fwd, _ := forward.New() 69 | lb, _ := roundrobin.New(fwd) 70 | 71 | lb.UpsertServer(url1) 72 | lb.UpsertServer(url2) 73 | 74 | s := &http.Server{ 75 | Addr: ":8080", 76 | Handler: lb, 77 | } 78 | s.ListenAndServe() 79 | ``` 80 | 81 | What if we want to handle retries and replay the request in case of errors? The `stream` handler will help: 82 | 83 | 84 | ```go 85 | 86 | import ( 87 | "net/http" 88 | "github.com/vulcand/oxy/forward" 89 | "github.com/vulcand/oxy/roundrobin" 90 | "github.com/vulcand/oxy/stream" 91 | ) 92 | 93 | // Forwards incoming requests to whatever location URL points to, adds proper forwarding headers 94 | 95 | fwd, _ := forward.New() 96 | lb, _ := roundrobin.New(fwd) 97 | 98 | // stream will read the request body and replay the request in case forward returned a status 99 | // corresponding to a network error (e.g. 
Gateway Timeout) 100 | stream, _ := stream.New(lb, stream.Retry(`IsNetworkError() && Attempts() < 2`)) 101 | 102 | lb.UpsertServer(url1) 103 | lb.UpsertServer(url2) 104 | 105 | // that's it! our reverse proxy is ready! 106 | s := &http.Server{ 107 | Addr: ":8080", 108 | Handler: stream, 109 | } 110 | s.ListenAndServe() 111 | ``` 112 | -------------------------------------------------------------------------------- /cbreaker/cbreaker.go: -------------------------------------------------------------------------------- 1 | // Package cbreaker implements circuit breaker similar to https://github.com/Netflix/Hystrix/wiki/How-it-Works 2 | // 3 | // Vulcan circuit breaker watches the error condition to match 4 | // after which it activates the fallback scenario, e.g. returns the response code 5 | // or redirects the request to another location 6 | // 7 | // Circuit breakers start in the Standby state first, observing responses and watching location metrics. 8 | // 9 | // Once the Circuit breaker condition is met, it enters the "Tripped" state, where it activates fallback scenario 10 | // for all requests during the FallbackDuration time period and resets the stats for the location. 11 | // 12 | // After FallbackDuration time period passes, Circuit breaker enters "Recovering" state, during that state it will 13 | // start passing some traffic back to the endpoints, increasing the amount of passed requests using a linear function: 14 | // 15 | // allowedRequestsRatio = 0.5 * (Now() - StartRecovery())/RecoveryDuration 16 | // 17 | // Two scenarios are possible in the "Recovering" state: 18 | // 1. Condition matches again, this will reset the state to "Tripped" and reset the timer. 19 | // 2. Condition does not match, circuit breaker enters "Standby" state 20 | // 21 | // It is possible to define actions (e.g. 
webhooks) of transitions between states: 22 | // 23 | // * OnTripped action is called on transition (Standby -> Tripped) 24 | // * OnStandby action is called on transition (Recovering -> Standby) 25 | // 26 | package cbreaker 27 | 28 | import ( 29 | "fmt" 30 | "net/http" 31 | "sync" 32 | "time" 33 | 34 | "github.com/mailgun/timetools" 35 | "github.com/vulcand/oxy/memmetrics" 36 | "github.com/vulcand/oxy/utils" 37 | ) 38 | 39 | // CircuitBreaker is an http.Handler that implements the circuit breaker pattern 40 | type CircuitBreaker struct { 41 | m *sync.RWMutex 42 | metrics *memmetrics.RTMetrics 43 | 44 | condition hpredicate 45 | 46 | fallbackDuration time.Duration 47 | recoveryDuration time.Duration 48 | 49 | onTripped SideEffect 50 | onStandby SideEffect 51 | 52 | state cbState 53 | until time.Time 54 | 55 | rc *ratioController 56 | 57 | checkPeriod time.Duration 58 | lastCheck time.Time 59 | 60 | fallback http.Handler 61 | next http.Handler 62 | 63 | log utils.Logger 64 | clock timetools.TimeProvider 65 | } 66 | 67 | // New creates a new CircuitBreaker middleware wrapping next; expression is the tripping condition 68 | func New(next http.Handler, expression string, options ...CircuitBreakerOption) (*CircuitBreaker, error) { 69 | cb := &CircuitBreaker{ 70 | m: &sync.RWMutex{}, 71 | next: next, 72 | // Default values. Might be overwritten by options below. 
73 | clock: &timetools.RealTime{}, 74 | checkPeriod: defaultCheckPeriod, 75 | fallbackDuration: defaultFallbackDuration, 76 | recoveryDuration: defaultRecoveryDuration, 77 | fallback: defaultFallback, 78 | log: utils.NullLogger, 79 | } 80 | 81 | for _, s := range options { 82 | if err := s(cb); err != nil { 83 | return nil, err 84 | } 85 | } 86 | 87 | condition, err := parseExpression(expression) 88 | if err != nil { 89 | return nil, err 90 | } 91 | cb.condition = condition 92 | 93 | mt, err := memmetrics.NewRTMetrics() 94 | if err != nil { 95 | return nil, err 96 | } 97 | cb.metrics = mt 98 | 99 | return cb, nil 100 | } 101 | 102 | func (c *CircuitBreaker) ServeHTTP(w http.ResponseWriter, req *http.Request) { 103 | if c.activateFallback(w, req) { 104 | c.fallback.ServeHTTP(w, req) 105 | return 106 | } 107 | c.serve(w, req) 108 | } 109 | 110 | func (c *CircuitBreaker) Wrap(next http.Handler) { 111 | c.next = next 112 | } 113 | 114 | // activateFallback updates internal state and returns true if fallback should be used and false otherwise 115 | func (c *CircuitBreaker) activateFallback(w http.ResponseWriter, req *http.Request) bool { 116 | // Quick check with read locks optimized for normal operation use-case 117 | if c.isStandby() { 118 | return false 119 | } 120 | // Circuit breaker is in tripped or recovering state 121 | c.m.Lock() 122 | defer c.m.Unlock() 123 | 124 | c.log.Infof("%v is in error state", c) 125 | 126 | switch c.state { 127 | case stateStandby: 128 | // someone else has set it to standby just now 129 | return false 130 | case stateTripped: 131 | if c.clock.UtcNow().Before(c.until) { 132 | return true 133 | } 134 | // We have been in tripped state long enough, enter recovering state 135 | c.setRecovering() 136 | fallthrough 137 | case stateRecovering: 138 | // If we have been in recovering state long enough, enter standby and allow request 139 | if c.clock.UtcNow().After(c.until) { 140 | c.setState(stateStandby, c.clock.UtcNow()) 141 | return false 142 | } 143 | // 
ratio controller allows this request 144 | if c.rc.allowRequest() { 145 | return false 146 | } 147 | return true 148 | } 149 | return false 150 | } 151 | 152 | func (c *CircuitBreaker) serve(w http.ResponseWriter, req *http.Request) { 153 | start := c.clock.UtcNow() 154 | p := &utils.ProxyWriter{W: w} 155 | 156 | c.next.ServeHTTP(p, req) 157 | 158 | latency := c.clock.UtcNow().Sub(start) 159 | c.metrics.Record(p.Code, latency) 160 | 161 | // Note that this call is less expensive than it looks -- checkAndSet only performs the real check 162 | // periodically. Because of that we can afford to call it here on every single response. 163 | c.checkAndSet() 164 | } 165 | 166 | func (c *CircuitBreaker) isStandby() bool { 167 | c.m.RLock() 168 | defer c.m.RUnlock() 169 | return c.state == stateStandby 170 | } 171 | 172 | // String returns a log-friendly representation of the circuit breaker state; it reads fields without locking, so callers should hold c.m 173 | func (c *CircuitBreaker) String() string { 174 | switch c.state { 175 | case stateTripped, stateRecovering: 176 | return fmt.Sprintf("CircuitBreaker(state=%v, until=%v)", c.state, c.until) 177 | default: 178 | return fmt.Sprintf("CircuitBreaker(state=%v)", c.state) 179 | } 180 | } 181 | 182 | // exec runs the side effect in a new goroutine and logs any error; the description is captured before spawning because the goroutine does not hold c.m 183 | func (c *CircuitBreaker) exec(s SideEffect) { 184 | if s == nil { 185 | return 186 | } 187 | go func(desc string) { 188 | if err := s.Exec(); err != nil { 189 | c.log.Errorf("%v side effect failure: %v", desc, err) 190 | } 191 | }(c.String()) 192 | } 193 | 194 | func (c *CircuitBreaker) setState(state cbState, until time.Time) { 195 | c.log.Infof("%v setting state to %v, until %v", c, state, until) 196 | c.state = state 197 | c.until = until 198 | switch state { 199 | case stateTripped: 200 | c.exec(c.onTripped) 201 | case stateStandby: 202 | c.exec(c.onStandby) 203 | } 204 | } 205 | 206 | func 
(c *CircuitBreaker) timeToCheck() bool { 207 | c.m.RLock() 208 | defer c.m.RUnlock() 209 | return c.clock.UtcNow().After(c.lastCheck) 210 | } 211 | 212 | // checkAndSet checks if the tripping condition 
matches and sets the circuit breaker to the tripped state 213 | func (c *CircuitBreaker) checkAndSet() { 214 | if !c.timeToCheck() { 215 | return 216 | } 217 | 218 | c.m.Lock() 219 | defer c.m.Unlock() 220 | 221 | // Another goroutine could have updated the lastCheck variable before we grabbed the mutex 222 | if !c.clock.UtcNow().After(c.lastCheck) { 223 | return 224 | } 225 | c.lastCheck = c.clock.UtcNow().Add(c.checkPeriod) 226 | 227 | if c.state == stateTripped { 228 | c.log.Infof("%v skip set tripped", c) 229 | return 230 | } 231 | 232 | if !c.condition(c) { 233 | return 234 | } 235 | 236 | c.setState(stateTripped, c.clock.UtcNow().Add(c.fallbackDuration)) 237 | c.metrics.Reset() 238 | } 239 | 240 | func (c *CircuitBreaker) setRecovering() { 241 | c.setState(stateRecovering, c.clock.UtcNow().Add(c.recoveryDuration)) 242 | c.rc = newRatioController(c.clock, c.recoveryDuration) 243 | } 244 | 245 | // CircuitBreakerOption represents an option you can pass to New. 246 | // See the documentation for the individual options below. 247 | type CircuitBreakerOption func(*CircuitBreaker) error 248 | 249 | // Clock allows you to fake the CircuitBreaker's view of the current time. 250 | // Intended for unit tests. 251 | func Clock(clock timetools.TimeProvider) CircuitBreakerOption { 252 | return func(c *CircuitBreaker) error { 253 | c.clock = clock 254 | return nil 255 | } 256 | } 257 | 258 | // FallbackDuration is how long the CircuitBreaker will remain in the Tripped 259 | // state before trying to recover. 260 | func FallbackDuration(d time.Duration) CircuitBreakerOption { 261 | return func(c *CircuitBreaker) error { 262 | c.fallbackDuration = d 263 | return nil 264 | } 265 | } 266 | 267 | // RecoveryDuration is how long the CircuitBreaker will take to ramp up 268 | // requests during the Recovering state. 
269 | func RecoveryDuration(d time.Duration) CircuitBreakerOption { 270 | return func(c *CircuitBreaker) error { 271 | c.recoveryDuration = d 272 | return nil 273 | } 274 | } 275 | 276 | // CheckPeriod is how long the CircuitBreaker will wait between successive 277 | // checks of the breaker condition. 278 | func CheckPeriod(d time.Duration) CircuitBreakerOption { 279 | return func(c *CircuitBreaker) error { 280 | c.checkPeriod = d 281 | return nil 282 | } 283 | } 284 | 285 | // OnTripped sets a SideEffect to run when entering the Tripped state. 286 | // Only one SideEffect can be set for this hook. 287 | func OnTripped(s SideEffect) CircuitBreakerOption { 288 | return func(c *CircuitBreaker) error { 289 | c.onTripped = s 290 | return nil 291 | } 292 | } 293 | 294 | // OnStandby sets a SideEffect to run when entering the Standby state. 295 | // Only one SideEffect can be set for this hook. 296 | func OnStandby(s SideEffect) CircuitBreakerOption { 297 | return func(c *CircuitBreaker) error { 298 | c.onStandby = s 299 | return nil 300 | } 301 | } 302 | 303 | // Fallback defines the http.Handler that the CircuitBreaker should route 304 | // requests to when it prevents a request from taking its normal path. 305 | func Fallback(h http.Handler) CircuitBreakerOption { 306 | return func(c *CircuitBreaker) error { 307 | c.fallback = h 308 | return nil 309 | } 310 | } 311 | 312 | // Logger adds logging for the CircuitBreaker. 
313 | func Logger(l utils.Logger) CircuitBreakerOption { 314 | return func(c *CircuitBreaker) error { 315 | c.log = l 316 | return nil 317 | } 318 | } 319 | 320 | // cbState is the state of the circuit breaker 321 | type cbState int 322 | 323 | func (s cbState) String() string { 324 | switch s { 325 | case stateStandby: 326 | return "standby" 327 | case stateTripped: 328 | return "tripped" 329 | case stateRecovering: 330 | return "recovering" 331 | } 332 | return "undefined" 333 | } 334 | 335 | const ( 336 | // CircuitBreaker is passing all requests and watching stats 337 | stateStandby cbState = iota 338 | // CircuitBreaker activates fallback scenario for all requests 339 | stateTripped 340 | // CircuitBreaker lets some requests through, rejecting others 341 | stateRecovering 342 | ) 343 | 344 | const ( 345 | defaultFallbackDuration = 10 * time.Second 346 | defaultRecoveryDuration = 10 * time.Second 347 | defaultCheckPeriod = 100 * time.Millisecond 348 | ) 349 | 350 | var defaultFallback = &fallback{} 351 | 352 | type fallback struct { 353 | } 354 | 355 | func (f *fallback) ServeHTTP(w http.ResponseWriter, req *http.Request) { 356 | w.WriteHeader(http.StatusServiceUnavailable) 357 | w.Write([]byte(http.StatusText(http.StatusServiceUnavailable))) 358 | } 359 | -------------------------------------------------------------------------------- /cbreaker/cbreaker_test.go: -------------------------------------------------------------------------------- 1 | package cbreaker 2 | 3 | import ( 4 | "fmt" 5 | "io/ioutil" 6 | "net/http" 7 | "net/http/httptest" 8 | "net/url" 9 | "testing" 10 | "time" 11 | 12 | "github.com/mailgun/timetools" 13 | "github.com/vulcand/oxy/memmetrics" 14 | "github.com/vulcand/oxy/testutils" 15 | 16 | . 
"gopkg.in/check.v1" 17 | ) 18 | 19 | func TestCircuitBreaker(t *testing.T) { TestingT(t) } 20 | 21 | type CBSuite struct { 22 | clock *timetools.FreezedTime 23 | } 24 | 25 | var _ = Suite(&CBSuite{ 26 | clock: &timetools.FreezedTime{ 27 | CurrentTime: time.Date(2012, 3, 4, 5, 6, 7, 0, time.UTC), 28 | }, 29 | }) 30 | 31 | const triggerNetRatio = `NetworkErrorRatio() > 0.5` 32 | 33 | var fallbackResponse http.Handler 34 | var fallbackRedirect http.Handler 35 | 36 | func (s *CBSuite) SetUpSuite(c *C) { 37 | f, err := NewResponseFallback(Response{StatusCode: 400, Body: []byte("Come back later")}) 38 | c.Assert(err, IsNil) 39 | fallbackResponse = f 40 | 41 | rdr, err := NewRedirectFallback(Redirect{URL: "http://localhost:5000"}) 42 | c.Assert(err, IsNil) 43 | fallbackRedirect = rdr 44 | } 45 | 46 | func (s *CBSuite) advanceTime(d time.Duration) { 47 | s.clock.CurrentTime = s.clock.CurrentTime.Add(d) 48 | } 49 | 50 | func (s *CBSuite) TestStandbyCycle(c *C) { 51 | handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 52 | w.Write([]byte("hello")) 53 | }) 54 | 55 | cb, err := New(handler, triggerNetRatio) 56 | c.Assert(err, IsNil) 57 | 58 | srv := httptest.NewServer(cb) 59 | defer srv.Close() 60 | 61 | re, body, err := testutils.Get(srv.URL) 62 | c.Assert(err, IsNil) 63 | c.Assert(re.StatusCode, Equals, http.StatusOK) 64 | c.Assert(string(body), Equals, "hello") 65 | } 66 | 67 | func (s *CBSuite) TestFullCycle(c *C) { 68 | handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 69 | w.Write([]byte("hello")) 70 | }) 71 | 72 | cb, err := New(handler, triggerNetRatio, Clock(s.clock)) 73 | c.Assert(err, IsNil) 74 | 75 | srv := httptest.NewServer(cb) 76 | defer srv.Close() 77 | 78 | re, _, err := testutils.Get(srv.URL) 79 | c.Assert(err, IsNil) 80 | c.Assert(re.StatusCode, Equals, http.StatusOK) 81 | 82 | cb.metrics = statsNetErrors(0.6) 83 | s.advanceTime(defaultCheckPeriod + time.Millisecond) 84 | re, _, err = testutils.Get(srv.URL) 
85 | c.Assert(err, IsNil) 86 | c.Assert(cb.state, Equals, cbState(stateTripped)) 87 | 88 | // Some time has passed, but we are still in tripped state. 89 | s.advanceTime(9 * time.Second) 90 | re, _, err = testutils.Get(srv.URL) 91 | c.Assert(err, IsNil) 92 | c.Assert(re.StatusCode, Equals, http.StatusServiceUnavailable) 93 | c.Assert(cb.state, Equals, cbState(stateTripped)) 94 | 95 | // We should be in recovering state by now 96 | s.advanceTime(time.Second*1 + time.Millisecond) 97 | re, _, err = testutils.Get(srv.URL) 98 | c.Assert(err, IsNil) 99 | c.Assert(re.StatusCode, Equals, http.StatusServiceUnavailable) 100 | c.Assert(cb.state, Equals, cbState(stateRecovering)) 101 | 102 | // 5 seconds later we should be allowing some requests to pass 103 | s.advanceTime(5 * time.Second) 104 | allowed := 0 105 | for i := 0; i < 100; i++ { 106 | re, _, err = testutils.Get(srv.URL) 107 | if err == nil && re.StatusCode == http.StatusOK { 108 | allowed++ 109 | } 110 | } 111 | c.Assert(allowed, Not(Equals), 0) 112 | 113 | // After some time, all is good and we should be in standby mode again 114 | s.advanceTime(5*time.Second + time.Millisecond) 115 | re, _, err = testutils.Get(srv.URL) 116 | c.Assert(cb.state, Equals, cbState(stateStandby)) 117 | c.Assert(err, IsNil) 118 | c.Assert(re.StatusCode, Equals, http.StatusOK) 119 | } 120 | 121 | func (s *CBSuite) TestRedirect(c *C) { 122 | handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 123 | w.Write([]byte("hello")) 124 | }) 125 | 126 | cb, err := New(handler, triggerNetRatio, Clock(s.clock), Fallback(fallbackRedirect)) 127 | c.Assert(err, IsNil) 128 | 129 | srv := httptest.NewServer(cb) 130 | defer srv.Close() 131 | 132 | cb.metrics = statsNetErrors(0.6) 133 | re, _, err := testutils.Get(srv.URL) 134 | c.Assert(err, IsNil) 135 | 136 | client := &http.Client{ 137 | CheckRedirect: func(req *http.Request, via []*http.Request) error { 138 | return fmt.Errorf("no redirects") 139 | }, 140 | } 141 | 142 | re, err 
= client.Get(srv.URL) 143 | c.Assert(err, NotNil) 144 | c.Assert(re.StatusCode, Equals, http.StatusFound) 145 | c.Assert(re.Header.Get("Location"), Equals, "http://localhost:5000") 146 | } 147 | 148 | func (s *CBSuite) TestTriggerDuringRecovery(c *C) { 149 | handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 150 | w.Write([]byte("hello")) 151 | }) 152 | 153 | cb, err := New(handler, triggerNetRatio, Clock(s.clock), CheckPeriod(time.Microsecond)) 154 | c.Assert(err, IsNil) 155 | 156 | srv := httptest.NewServer(cb) 157 | defer srv.Close() 158 | 159 | cb.metrics = statsNetErrors(0.6) 160 | re, _, err := testutils.Get(srv.URL) 161 | c.Assert(err, IsNil) 162 | c.Assert(cb.state, Equals, cbState(stateTripped)) 163 | 164 | // We should be in recovering state by now 165 | s.advanceTime(10*time.Second + time.Millisecond) 166 | re, _, err = testutils.Get(srv.URL) 167 | c.Assert(err, IsNil) 168 | c.Assert(re.StatusCode, Equals, http.StatusServiceUnavailable) 169 | c.Assert(cb.state, Equals, cbState(stateRecovering)) 170 | 171 | // We have matched error condition during recovery state and are going back to tripped state 172 | s.advanceTime(5 * time.Second) 173 | cb.metrics = statsNetErrors(0.6) 174 | allowed := 0 175 | for i := 0; i < 100; i++ { 176 | re, _, err = testutils.Get(srv.URL) 177 | if err == nil && re.StatusCode == http.StatusOK { 178 | allowed++ 179 | } 180 | } 181 | c.Assert(allowed, Not(Equals), 0) 182 | c.Assert(cb.state, Equals, cbState(stateTripped)) 183 | } 184 | 185 | func (s *CBSuite) TestSideEffects(c *C) { 186 | srv1Chan := make(chan *http.Request, 1) 187 | var srv1Body []byte 188 | srv1 := testutils.NewHandler(func(w http.ResponseWriter, r *http.Request) { 189 | b, err := ioutil.ReadAll(r.Body) 190 | c.Assert(err, IsNil) 191 | srv1Body = b 192 | w.Write([]byte("srv1")) 193 | srv1Chan <- r 194 | }) 195 | defer srv1.Close() 196 | 197 | srv2Chan := make(chan *http.Request, 1) 198 | srv2 := testutils.NewHandler(func(w 
http.ResponseWriter, r *http.Request) { 199 | w.Write([]byte("srv2")) 200 | r.ParseForm() 201 | srv2Chan <- r 202 | }) 203 | defer srv2.Close() 204 | 205 | onTripped, err := NewWebhookSideEffect( 206 | Webhook{ 207 | URL: fmt.Sprintf("%s/post.json", srv1.URL), 208 | Method: "POST", 209 | Headers: map[string][]string{"Content-Type": []string{"application/json"}}, 210 | Body: []byte(`{"Key": ["val1", "val2"]}`), 211 | }) 212 | c.Assert(err, IsNil) 213 | 214 | onStandby, err := NewWebhookSideEffect( 215 | Webhook{ 216 | URL: fmt.Sprintf("%s/post", srv2.URL), 217 | Method: "POST", 218 | Form: map[string][]string{"key": []string{"val1", "val2"}}, 219 | }) 220 | c.Assert(err, IsNil) 221 | 222 | handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 223 | w.Write([]byte("hello")) 224 | }) 225 | 226 | cb, err := New(handler, triggerNetRatio, Clock(s.clock), CheckPeriod(time.Microsecond), OnTripped(onTripped), OnStandby(onStandby)) 227 | c.Assert(err, IsNil) 228 | 229 | srv := httptest.NewServer(cb) 230 | defer srv.Close() 231 | 232 | cb.metrics = statsNetErrors(0.6) 233 | 234 | _, _, err = testutils.Get(srv.URL) 235 | c.Assert(err, IsNil) 236 | c.Assert(cb.state, Equals, cbState(stateTripped)) 237 | 238 | select { 239 | case req := <-srv1Chan: 240 | c.Assert(req.Method, Equals, "POST") 241 | c.Assert(req.URL.Path, Equals, "/post.json") 242 | c.Assert(string(srv1Body), Equals, `{"Key": ["val1", "val2"]}`) 243 | c.Assert(req.Header.Get("Content-Type"), Equals, "application/json") 244 | case <-time.After(time.Second): 245 | c.Error("timeout waiting for side effect to kick off") 246 | } 247 | 248 | // Transition to recovering state 249 | s.advanceTime(10*time.Second + time.Millisecond) 250 | cb.metrics = statsOK() 251 | testutils.Get(srv.URL) 252 | c.Assert(cb.state, Equals, cbState(stateRecovering)) 253 | 254 | // Going back to standby 255 | s.advanceTime(10*time.Second + time.Millisecond) 256 | testutils.Get(srv.URL) 257 | c.Assert(cb.state, Equals, 
cbState(stateStandby)) 258 | 259 | select { 260 | case req := <-srv2Chan: 261 | c.Assert(req.Method, Equals, "POST") 262 | c.Assert(req.URL.Path, Equals, "/post") 263 | c.Assert(req.Form, DeepEquals, url.Values{"key": []string{"val1", "val2"}}) 264 | case <-time.After(time.Second): 265 | c.Error("timeout waiting for side effect to kick off") 266 | } 267 | } 268 | 269 | func statsOK() *memmetrics.RTMetrics { 270 | m, err := memmetrics.NewRTMetrics() 271 | if err != nil { 272 | panic(err) 273 | } 274 | return m 275 | } 276 | 277 | func statsNetErrors(threshold float64) *memmetrics.RTMetrics { 278 | m, err := memmetrics.NewRTMetrics() 279 | if err != nil { 280 | panic(err) 281 | } 282 | for i := 0; i < 100; i++ { 283 | if i < int(threshold*100) { 284 | m.Record(http.StatusGatewayTimeout, 0) 285 | } else { 286 | m.Record(http.StatusOK, 0) 287 | } 288 | } 289 | return m 290 | } 291 | 292 | func statsLatencyAtQuantile(quantile float64, value time.Duration) *memmetrics.RTMetrics { 293 | m, err := memmetrics.NewRTMetrics() 294 | if err != nil { 295 | panic(err) 296 | } 297 | m.Record(http.StatusOK, value) 298 | return m 299 | } 300 | 301 | func statsResponseCodes(codes ...statusCode) *memmetrics.RTMetrics { 302 | m, err := memmetrics.NewRTMetrics() 303 | if err != nil { 304 | panic(err) 305 | } 306 | for _, c := range codes { 307 | for i := int64(0); i < c.Count; i++ { 308 | m.Record(c.Code, 0) 309 | } 310 | } 311 | return m 312 | } 313 | 314 | type statusCode struct { 315 | Code int 316 | Count int64 317 | } 318 | -------------------------------------------------------------------------------- /cbreaker/effect.go: -------------------------------------------------------------------------------- 1 | package cbreaker 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "io/ioutil" 8 | "net/http" 9 | "net/url" 10 | "strings" 11 | 12 | "github.com/vulcand/oxy/utils" 13 | ) 14 | 15 | type SideEffect interface { 16 | Exec() error 17 | } 18 | 19 | type Webhook struct { 20 | URL 
string 21 | Method string 22 | Headers http.Header 23 | Form url.Values 24 | Body []byte 25 | } 26 | 27 | type WebhookSideEffect struct { 28 | w Webhook 29 | } 30 | 31 | func NewWebhookSideEffect(w Webhook) (*WebhookSideEffect, error) { 32 | if w.Method == "" { 33 | return nil, fmt.Errorf("supply method") 34 | } 35 | _, err := url.Parse(w.URL) 36 | if err != nil { 37 | return nil, err 38 | } 39 | 40 | return &WebhookSideEffect{w: w}, nil 41 | } 42 | 43 | func (w *WebhookSideEffect) getBody() io.Reader { 44 | if len(w.w.Form) != 0 { 45 | return strings.NewReader(w.w.Form.Encode()) 46 | } 47 | if len(w.w.Body) != 0 { 48 | return bytes.NewBuffer(w.w.Body) 49 | } 50 | return nil 51 | } 52 | 53 | func (w *WebhookSideEffect) Exec() error { 54 | r, err := http.NewRequest(w.w.Method, w.w.URL, w.getBody()) 55 | if err != nil { 56 | return err 57 | } 58 | if len(w.w.Headers) != 0 { 59 | utils.CopyHeaders(r.Header, w.w.Headers) 60 | } 61 | if len(w.w.Form) != 0 { 62 | r.Header.Set("Content-Type", "application/x-www-form-urlencoded") 63 | } 64 | re, err := http.DefaultClient.Do(r) 65 | if err != nil { 66 | return err 67 | } 68 | if re.Body != nil { 69 | defer re.Body.Close() 70 | } 71 | _, err = io.Copy(ioutil.Discard, re.Body) 72 | if err != nil { 73 | return err 74 | } 75 | return nil 76 | } 77 | -------------------------------------------------------------------------------- /cbreaker/fallback.go: -------------------------------------------------------------------------------- 1 | package cbreaker 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/url" 7 | "strconv" 8 | ) 9 | 10 | type Response struct { 11 | StatusCode int 12 | ContentType string 13 | Body []byte 14 | } 15 | 16 | type ResponseFallback struct { 17 | r Response 18 | } 19 | 20 | func NewResponseFallback(r Response) (*ResponseFallback, error) { 21 | if r.StatusCode == 0 { 22 | return nil, fmt.Errorf("response code should not be 0") 23 | } 24 | return &ResponseFallback{r: r}, nil 25 | } 26 | 27 | func (f 
*ResponseFallback) ServeHTTP(w http.ResponseWriter, req *http.Request) { 28 | if f.r.ContentType != "" { 29 | w.Header().Set("Content-Type", f.r.ContentType) 30 | } 31 | w.Header().Set("Content-Length", strconv.Itoa(len(f.r.Body))) 32 | w.WriteHeader(f.r.StatusCode) 33 | w.Write(f.r.Body) 34 | } 35 | 36 | type Redirect struct { 37 | URL string 38 | } 39 | 40 | type RedirectFallback struct { 41 | u *url.URL 42 | } 43 | 44 | func NewRedirectFallback(r Redirect) (*RedirectFallback, error) { 45 | u, err := url.ParseRequestURI(r.URL) 46 | if err != nil { 47 | return nil, err 48 | } 49 | return &RedirectFallback{u: u}, nil 50 | } 51 | 52 | func (f *RedirectFallback) ServeHTTP(w http.ResponseWriter, req *http.Request) { 53 | w.Header().Set("Location", f.u.String()) 54 | w.WriteHeader(http.StatusFound) 55 | w.Write([]byte(http.StatusText(http.StatusFound))) 56 | } 57 | -------------------------------------------------------------------------------- /cbreaker/predicates.go: -------------------------------------------------------------------------------- 1 | package cbreaker 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/vulcand/predicate" 8 | ) 9 | 10 | type hpredicate func(*CircuitBreaker) bool 11 | 12 | // parseExpression parses an expression written in Go syntax into a predicate. 
13 | func parseExpression(in string) (hpredicate, error) { 14 | p, err := predicate.NewParser(predicate.Def{ 15 | Operators: predicate.Operators{ 16 | AND: and, 17 | OR: or, 18 | EQ: eq, 19 | NEQ: neq, 20 | LT: lt, 21 | LE: le, 22 | GT: gt, 23 | GE: ge, 24 | }, 25 | Functions: map[string]interface{}{ 26 | "LatencyAtQuantileMS": latencyAtQuantile, 27 | "NetworkErrorRatio": networkErrorRatio, 28 | "ResponseCodeRatio": responseCodeRatio, 29 | }, 30 | }) 31 | if err != nil { 32 | return nil, err 33 | } 34 | out, err := p.Parse(in) 35 | if err != nil { 36 | return nil, err 37 | } 38 | pr, ok := out.(hpredicate) 39 | if !ok { 40 | return nil, fmt.Errorf("expected predicate, got %T", out) 41 | } 42 | return pr, nil 43 | } 44 | 45 | type toInt func(c *CircuitBreaker) int 46 | type toFloat64 func(c *CircuitBreaker) float64 47 | 48 | func latencyAtQuantile(quantile float64) toInt { 49 | return func(c *CircuitBreaker) int { 50 | h, err := c.metrics.LatencyHistogram() 51 | if err != nil { 52 | c.log.Errorf("Failed to get latency histogram, for %v error: %v", c, err) 53 | return 0 54 | } 55 | return int(h.LatencyAtQuantile(quantile) / time.Millisecond) 56 | } 57 | } 58 | 59 | func networkErrorRatio() toFloat64 { 60 | return func(c *CircuitBreaker) float64 { 61 | return c.metrics.NetworkErrorRatio() 62 | } 63 | } 64 | 65 | func responseCodeRatio(startA, endA, startB, endB int) toFloat64 { 66 | return func(c *CircuitBreaker) float64 { 67 | return c.metrics.ResponseCodeRatio(startA, endA, startB, endB) 68 | } 69 | } 70 | 71 | // or returns predicate by joining the passed predicates with logical 'or' 72 | func or(fns ...hpredicate) hpredicate { 73 | return func(c *CircuitBreaker) bool { 74 | for _, fn := range fns { 75 | if fn(c) { 76 | return true 77 | } 78 | } 79 | return false 80 | } 81 | } 82 | 83 | // and returns predicate by joining the passed predicates with logical 'and' 84 | func and(fns ...hpredicate) hpredicate { 85 | return func(c *CircuitBreaker) bool { 86 | for _, fn 
:= range fns { 87 | if !fn(c) { 88 | return false 89 | } 90 | } 91 | return true 92 | } 93 | } 94 | 95 | // not returns the negation of the passed predicate 96 | func not(p hpredicate) hpredicate { 97 | return func(c *CircuitBreaker) bool { 98 | return !p(c) 99 | } 100 | } 101 | 102 | // eq returns predicate that tests for equality of the value of the mapper and the constant 103 | func eq(m interface{}, value interface{}) (hpredicate, error) { 104 | switch mapper := m.(type) { 105 | case toInt: 106 | return intEQ(mapper, value) 107 | case toFloat64: 108 | return float64EQ(mapper, value) 109 | } 110 | return nil, fmt.Errorf("eq: unsupported argument: %T", m) 111 | } 112 | 113 | // neq returns predicate that tests for inequality of the value of the mapper and the constant 114 | func neq(m interface{}, value interface{}) (hpredicate, error) { 115 | p, err := eq(m, value) 116 | if err != nil { 117 | return nil, err 118 | } 119 | return not(p), nil 120 | } 121 | 122 | // lt returns predicate that tests that value of the mapper function is less than the constant 123 | func lt(m interface{}, value interface{}) (hpredicate, error) { 124 | switch mapper := m.(type) { 125 | case toInt: 126 | return intLT(mapper, value) 127 | case toFloat64: 128 | return float64LT(mapper, value) 129 | } 130 | return nil, fmt.Errorf("lt: unsupported argument: %T", m) 131 | } 132 | 133 | // le returns predicate that tests that value of the mapper function is less than or equal to the constant (lt OR eq) 134 | func le(m interface{}, value interface{}) (hpredicate, error) { 135 | l, err := lt(m, value) 136 | if err != nil { 137 | return nil, err 138 | } 139 | e, err := eq(m, value) 140 | if err != nil { 141 | return nil, err 142 | } 143 | return func(c *CircuitBreaker) bool { 144 | return l(c) || e(c) 145 | }, nil 146 | } 147 | 148 | // gt returns predicate that tests that value of the mapper function is greater than the constant 149 | func gt(m interface{}, value interface{}) (hpredicate, error) { 150 | switch 
mapper := m.(type) { 151 | case toInt: 152 | return intGT(mapper, value) 153 | case toFloat64: 154 | return float64GT(mapper, value) 155 | } 156 | return nil, fmt.Errorf("gt: unsupported argument: %T", m) 157 | } 158 | 159 | // ge returns predicate that tests that value of the mapper function is greater than or equal to the constant (gt OR eq) 160 | func ge(m interface{}, value interface{}) (hpredicate, error) { 161 | g, err := gt(m, value) 162 | if err != nil { 163 | return nil, err 164 | } 165 | e, err := eq(m, value) 166 | if err != nil { 167 | return nil, err 168 | } 169 | return func(c *CircuitBreaker) bool { 170 | return g(c) || e(c) 171 | }, nil 172 | } 173 | 174 | func intEQ(m toInt, val interface{}) (hpredicate, error) { 175 | value, ok := val.(int) 176 | if !ok { 177 | return nil, fmt.Errorf("expected int, got %T", val) 178 | } 179 | return func(c *CircuitBreaker) bool { 180 | return m(c) == value 181 | }, nil 182 | } 183 | 184 | func float64EQ(m toFloat64, val interface{}) (hpredicate, error) { 185 | value, ok := val.(float64) 186 | if !ok { 187 | return nil, fmt.Errorf("expected float64, got %T", val) 188 | } 189 | return func(c *CircuitBreaker) bool { 190 | return m(c) == value 191 | }, nil 192 | } 193 | 194 | func intLT(m toInt, val interface{}) (hpredicate, error) { 195 | value, ok := val.(int) 196 | if !ok { 197 | return nil, fmt.Errorf("expected int, got %T", val) 198 | } 199 | return func(c *CircuitBreaker) bool { 200 | return m(c) < value 201 | }, nil 202 | } 203 | 204 | func intGT(m toInt, val interface{}) (hpredicate, error) { 205 | value, ok := val.(int) 206 | if !ok { 207 | return nil, fmt.Errorf("expected int, got %T", val) 208 | } 209 | return func(c *CircuitBreaker) bool { 210 | return m(c) > value 211 | }, nil 212 | } 213 | 214 | func float64LT(m toFloat64, val interface{}) (hpredicate, error) { 215 | value, ok := val.(float64) 216 | if !ok { 217 | return nil, fmt.Errorf("expected int, got %T", val) 218 | } 219 | return func(c *CircuitBreaker) bool { 
220 
| return m(c) < value 221 | }, nil 222 | } 223 | 224 | func float64GT(m toFloat64, val interface{}) (hpredicate, error) { 225 | value, ok := val.(float64) 226 | if !ok { 227 | return nil, fmt.Errorf("expected int, got %T", val) 228 | } 229 | return func(c *CircuitBreaker) bool { 230 | return m(c) > value 231 | }, nil 232 | } 233 | -------------------------------------------------------------------------------- /cbreaker/predicates_test.go: -------------------------------------------------------------------------------- 1 | package cbreaker 2 | 3 | import ( 4 | "github.com/vulcand/oxy/memmetrics" 5 | "time" 6 | 7 | . "gopkg.in/check.v1" 8 | ) 9 | 10 | type PredicatesSuite struct { 11 | } 12 | 13 | var _ = Suite(&PredicatesSuite{}) 14 | 15 | func (s *PredicatesSuite) TestTripped(c *C) { 16 | predicates := []struct { 17 | Expression string 18 | M *memmetrics.RTMetrics 19 | V bool 20 | }{ 21 | { 22 | Expression: "NetworkErrorRatio() > 0.5", 23 | M: statsNetErrors(0.6), 24 | V: true, 25 | }, 26 | { 27 | Expression: "NetworkErrorRatio() < 0.5", 28 | M: statsNetErrors(0.6), 29 | V: false, 30 | }, 31 | { 32 | Expression: "LatencyAtQuantileMS(50.0) > 50", 33 | M: statsLatencyAtQuantile(50, time.Millisecond*51), 34 | V: true, 35 | }, 36 | { 37 | Expression: "LatencyAtQuantileMS(50.0) < 50", 38 | M: statsLatencyAtQuantile(50, time.Millisecond*51), 39 | V: false, 40 | }, 41 | { 42 | Expression: "ResponseCodeRatio(500, 600, 0, 600) > 0.5", 43 | M: statsResponseCodes(statusCode{Code: 200, Count: 5}, statusCode{Code: 500, Count: 6}), 44 | V: true, 45 | }, 46 | { 47 | Expression: "ResponseCodeRatio(500, 600, 0, 600) > 0.5", 48 | M: statsResponseCodes(statusCode{Code: 200, Count: 5}, statusCode{Code: 500, Count: 4}), 49 | V: false, 50 | }, 51 | } 52 | for _, t := range predicates { 53 | p, err := parseExpression(t.Expression) 54 | c.Assert(err, IsNil) 55 | c.Assert(p, NotNil) 56 | 57 | c.Assert(p(&CircuitBreaker{metrics: t.M}), Equals, t.V) 58 | } 59 | } 60 | 61 | func (s 
*PredicatesSuite) TestErrors(c *C) { 62 | predicates := []struct { 63 | Expression string 64 | M *memmetrics.RTMetrics 65 | }{ 66 | { 67 | Expression: "LatencyAtQuantileMS(40.0) > 50", // quantile not defined 68 | M: statsNetErrors(0.6), 69 | }, 70 | } 71 | for _, t := range predicates { 72 | p, err := parseExpression(t.Expression) 73 | c.Assert(err, IsNil) 74 | c.Assert(p, NotNil) 75 | 76 | c.Assert(p(&CircuitBreaker{metrics: t.M}), Equals, false) 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /cbreaker/ratio.go: -------------------------------------------------------------------------------- 1 | package cbreaker 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/mailgun/timetools" 8 | ) 9 | 10 | // ratioController allows passing portions traffic back to the endpoints, 11 | // increasing the amount of passed requests using linear function: 12 | // 13 | // allowedRequestsRatio = 0.5 * (Now() - Start())/Duration 14 | // 15 | type ratioController struct { 16 | duration time.Duration 17 | start time.Time 18 | tm timetools.TimeProvider 19 | allowed int 20 | denied int 21 | } 22 | 23 | func newRatioController(tm timetools.TimeProvider, rampUp time.Duration) *ratioController { 24 | return &ratioController{ 25 | duration: rampUp, 26 | tm: tm, 27 | start: tm.UtcNow(), 28 | } 29 | } 30 | 31 | func (r *ratioController) String() string { 32 | return fmt.Sprintf("RatioController(target=%f, current=%f, allowed=%d, denied=%d)", r.targetRatio(), r.computeRatio(r.allowed, r.denied), r.allowed, r.denied) 33 | } 34 | 35 | func (r *ratioController) allowRequest() bool { 36 | t := r.targetRatio() 37 | // This condition answers the question - would we satisfy the target ratio if we allow this request? 
38 | e := r.computeRatio(r.allowed+1, r.denied) 39 | if e < t { 40 | r.allowed++ 41 | return true 42 | } 43 | r.denied++ 44 | return false 45 | } 46 | 47 | func (r *ratioController) computeRatio(allowed, denied int) float64 { 48 | if denied+allowed == 0 { 49 | return 0 50 | } 51 | return float64(allowed) / float64(denied+allowed) 52 | } 53 | 54 | func (r *ratioController) targetRatio() float64 { 55 | // Here's why it's 0.5: 56 | // We are watching the following ratio 57 | // ratio = a / (a + d) 58 | // We can notice, that once we get to 0.5 59 | // 0.5 = a / (a + d) 60 | // we can evaluate that a = d 61 | // that means equilibrium, where we would allow all the requests 62 | // after this point to achieve ratio of 1 (that can never be reached unless d is 0) 63 | // so we stop from there 64 | multiplier := 0.5 / float64(r.duration) 65 | return multiplier * float64(r.tm.UtcNow().Sub(r.start)) 66 | } 67 | -------------------------------------------------------------------------------- /cbreaker/ratio_test.go: -------------------------------------------------------------------------------- 1 | package cbreaker 2 | 3 | import ( 4 | "math" 5 | "time" 6 | 7 | "github.com/mailgun/timetools" 8 | . 
"gopkg.in/check.v1" 9 | ) 10 | 11 | type RatioSuite struct { 12 | tm *timetools.FreezedTime 13 | } 14 | 15 | var _ = Suite(&RatioSuite{ 16 | tm: &timetools.FreezedTime{ 17 | CurrentTime: time.Date(2012, 3, 4, 5, 6, 7, 0, time.UTC), 18 | }, 19 | }) 20 | 21 | func (s *RatioSuite) advanceTime(d time.Duration) { 22 | s.tm.CurrentTime = s.tm.CurrentTime.Add(d) 23 | } 24 | 25 | func (s *RatioSuite) TestRampUp(c *C) { 26 | duration := 10 * time.Second 27 | rc := newRatioController(s.tm, duration) 28 | 29 | allowed, denied := 0, 0 30 | for i := 0; i < int(duration/time.Millisecond); i++ { 31 | ratio := s.sendRequest(&allowed, &denied, rc) 32 | expected := rc.targetRatio() 33 | diff := math.Abs(expected - ratio) 34 | c.Assert(round(diff, 0.5, 1), Equals, float64(0)) 35 | s.advanceTime(time.Millisecond) 36 | } 37 | } 38 | 39 | func (s *RatioSuite) sendRequest(allowed, denied *int, rc *ratioController) float64 { 40 | if rc.allowRequest() { 41 | *allowed++ 42 | } else { 43 | *denied++ 44 | } 45 | if *allowed+*denied == 0 { 46 | return 0 47 | } 48 | return float64(*allowed) / float64(*allowed+*denied) 49 | } 50 | 51 | func round(val float64, roundOn float64, places int) float64 { 52 | pow := math.Pow(10, float64(places)) 53 | digit := pow * val 54 | _, div := math.Modf(digit) 55 | var round float64 56 | if div >= roundOn { 57 | round = math.Ceil(digit) 58 | } else { 59 | round = math.Floor(digit) 60 | } 61 | return round / pow 62 | } 63 | -------------------------------------------------------------------------------- /connlimit/connlimit.go: -------------------------------------------------------------------------------- 1 | // package connlimit provides control over simultaneous connections coming from the same source 2 | package connlimit 3 | 4 | import ( 5 | "fmt" 6 | "net/http" 7 | "sync" 8 | 9 | "github.com/vulcand/oxy/utils" 10 | ) 11 | 12 | // Limiter tracks concurrent connection per token 13 | // and is capable of rejecting connections if they are failed 14 | type 
ConnLimiter struct { 15 | mutex *sync.Mutex 16 | extract utils.SourceExtractor 17 | connections map[string]int64 18 | maxConnections int64 19 | totalConnections int64 20 | next http.Handler 21 | 22 | errHandler utils.ErrorHandler 23 | log utils.Logger 24 | } 25 | 26 | func New(next http.Handler, extract utils.SourceExtractor, maxConnections int64, options ...ConnLimitOption) (*ConnLimiter, error) { 27 | if extract == nil { 28 | return nil, fmt.Errorf("Extract function can not be nil") 29 | } 30 | cl := &ConnLimiter{ 31 | mutex: &sync.Mutex{}, 32 | extract: extract, 33 | maxConnections: maxConnections, 34 | connections: make(map[string]int64), 35 | next: next, 36 | } 37 | 38 | for _, o := range options { 39 | if err := o(cl); err != nil { 40 | return nil, err 41 | } 42 | } 43 | if cl.log == nil { 44 | cl.log = utils.NullLogger 45 | } 46 | if cl.errHandler == nil { 47 | cl.errHandler = defaultErrHandler 48 | } 49 | return cl, nil 50 | } 51 | 52 | func (cl *ConnLimiter) Wrap(h http.Handler) { 53 | cl.next = h 54 | } 55 | 56 | func (cl *ConnLimiter) ServeHTTP(w http.ResponseWriter, r *http.Request) { 57 | token, amount, err := cl.extract.Extract(r) 58 | if err != nil { 59 | cl.log.Errorf("failed to extract source of the connection: %v", err) 60 | cl.errHandler.ServeHTTP(w, r, err) 61 | return 62 | } 63 | if err := cl.acquire(token, amount); err != nil { 64 | cl.log.Infof("limiting request source %s: %v", token, err) 65 | cl.errHandler.ServeHTTP(w, r, err) 66 | return 67 | } 68 | 69 | defer cl.release(token, amount) 70 | 71 | cl.next.ServeHTTP(w, r) 72 | } 73 | 74 | func (cl *ConnLimiter) acquire(token string, amount int64) error { 75 | cl.mutex.Lock() 76 | defer cl.mutex.Unlock() 77 | 78 | connections := cl.connections[token] 79 | if connections >= cl.maxConnections { 80 | return &MaxConnError{max: cl.maxConnections} 81 | } 82 | 83 | cl.connections[token] += amount 84 | cl.totalConnections += int64(amount) 85 | return nil 86 | } 87 | 88 | func (cl *ConnLimiter) 
release(token string, amount int64) { 89 | cl.mutex.Lock() 90 | defer cl.mutex.Unlock() 91 | 92 | cl.connections[token] -= amount 93 | cl.totalConnections -= int64(amount) 94 | 95 | // Otherwise it would grow forever 96 | if cl.connections[token] == 0 { 97 | delete(cl.connections, token) 98 | } 99 | } 100 | 101 | type MaxConnError struct { 102 | max int64 103 | } 104 | 105 | func (m *MaxConnError) Error() string { 106 | return fmt.Sprintf("max connections reached: %d", m.max) 107 | } 108 | 109 | type ConnErrHandler struct { 110 | } 111 | 112 | func (e *ConnErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) { 113 | if _, ok := err.(*MaxConnError); ok { 114 | w.WriteHeader(429) 115 | w.Write([]byte(err.Error())) 116 | return 117 | } 118 | utils.DefaultHandler.ServeHTTP(w, req, err) 119 | } 120 | 121 | type ConnLimitOption func(l *ConnLimiter) error 122 | 123 | // Logger sets the logger that will be used by this middleware. 124 | func Logger(l utils.Logger) ConnLimitOption { 125 | return func(cl *ConnLimiter) error { 126 | cl.log = l 127 | return nil 128 | } 129 | } 130 | 131 | // ErrorHandler sets error handler of the server 132 | func ErrorHandler(h utils.ErrorHandler) ConnLimitOption { 133 | return func(cl *ConnLimiter) error { 134 | cl.errHandler = h 135 | return nil 136 | } 137 | } 138 | 139 | var defaultErrHandler = &ConnErrHandler{} 140 | -------------------------------------------------------------------------------- /connlimit/connlimit_test.go: -------------------------------------------------------------------------------- 1 | package connlimit 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | 10 | "github.com/vulcand/oxy/testutils" 11 | "github.com/vulcand/oxy/utils" 12 | 13 | . 
"gopkg.in/check.v1" 14 | ) 15 | 16 | func TestConn(t *testing.T) { TestingT(t) } 17 | 18 | type ConnLimiterSuite struct { 19 | } 20 | 21 | var _ = Suite(&ConnLimiterSuite{}) 22 | 23 | func (s *ConnLimiterSuite) SetUpSuite(c *C) { 24 | } 25 | 26 | // We've hit the limit and were able to proceed once the request has completed 27 | func (s *ConnLimiterSuite) TestHitLimitAndRelease(c *C) { 28 | wait := make(chan bool) 29 | proceed := make(chan bool) 30 | handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 31 | if req.Header.Get("wait") != "" { 32 | proceed <- true 33 | <-wait 34 | } 35 | w.Write([]byte("hello")) 36 | }) 37 | 38 | l, err := New(handler, headerLimit, 1) 39 | c.Assert(err, Equals, nil) 40 | 41 | srv := httptest.NewServer(l) 42 | defer srv.Close() 43 | 44 | go func() { 45 | re, _, err := testutils.Get(srv.URL, testutils.Header("Limit", "a"), testutils.Header("wait", "yes")) 46 | c.Assert(err, IsNil) 47 | c.Assert(re.StatusCode, Equals, http.StatusOK) 48 | }() 49 | 50 | <-proceed 51 | 52 | re, _, err := testutils.Get(srv.URL, testutils.Header("Limit", "a")) 53 | c.Assert(err, IsNil) 54 | c.Assert(re.StatusCode, Equals, 429) 55 | 56 | // request from another source succeeds 57 | re, _, err = testutils.Get(srv.URL, testutils.Header("Limit", "b")) 58 | c.Assert(err, IsNil) 59 | c.Assert(re.StatusCode, Equals, http.StatusOK) 60 | 61 | // Once the first request finished, next one succeeds 62 | close(wait) 63 | 64 | re, _, err = testutils.Get(srv.URL, testutils.Header("Limit", "a")) 65 | c.Assert(err, IsNil) 66 | c.Assert(re.StatusCode, Equals, http.StatusOK) 67 | } 68 | 69 | // We've hit the limit and were able to proceed once the request has completed 70 | func (s *ConnLimiterSuite) TestCustomHandlers(c *C) { 71 | handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 72 | w.Write([]byte("hello")) 73 | }) 74 | 75 | errHandler := utils.ErrorHandlerFunc(func(w http.ResponseWriter, req *http.Request, err error) { 76 | 
w.WriteHeader(http.StatusTeapot) 77 | w.Write([]byte(http.StatusText(http.StatusTeapot))) 78 | }) 79 | 80 | buf := &bytes.Buffer{} 81 | log := utils.NewFileLogger(buf, utils.INFO) 82 | 83 | l, err := New(handler, headerLimit, 0, ErrorHandler(errHandler), Logger(log)) 84 | c.Assert(err, Equals, nil) 85 | 86 | srv := httptest.NewServer(l) 87 | defer srv.Close() 88 | 89 | re, _, err := testutils.Get(srv.URL, testutils.Header("Limit", "a")) 90 | c.Assert(err, IsNil) 91 | c.Assert(re.StatusCode, Equals, http.StatusTeapot) 92 | 93 | c.Assert(len(buf.String()), Not(Equals), 0) 94 | } 95 | 96 | // We've hit the limit and were able to proceed once the request has completed 97 | func (s *ConnLimiterSuite) TestFaultyExtract(c *C) { 98 | handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 99 | w.Write([]byte("hello")) 100 | }) 101 | 102 | l, err := New(handler, faultyExtract, 1) 103 | c.Assert(err, Equals, nil) 104 | 105 | srv := httptest.NewServer(l) 106 | defer srv.Close() 107 | 108 | re, _, err := testutils.Get(srv.URL) 109 | c.Assert(err, IsNil) 110 | c.Assert(re.StatusCode, Equals, http.StatusInternalServerError) 111 | } 112 | 113 | func headerLimiter(req *http.Request) (string, int64, error) { 114 | return req.Header.Get("Limit"), 1, nil 115 | } 116 | 117 | func faultyExtractor(req *http.Request) (string, int64, error) { 118 | return "", -1, fmt.Errorf("oops") 119 | } 120 | 121 | var headerLimit = utils.ExtractorFunc(headerLimiter) 122 | var faultyExtract = utils.ExtractorFunc(faultyExtractor) 123 | -------------------------------------------------------------------------------- /forward/headers.go: -------------------------------------------------------------------------------- 1 | package forward 2 | 3 | const ( 4 | XForwardedProto = "X-Forwarded-Proto" 5 | XForwardedFor = "X-Forwarded-For" 6 | XForwardedHost = "X-Forwarded-Host" 7 | XForwardedPort = "X-Forwarded-Port" 8 | XForwardedServer = "X-Forwarded-Server" 9 | XRealIp = "X-Real-Ip" 10 | 
Connection = "Connection" 11 | KeepAlive = "Keep-Alive" 12 | ProxyAuthenticate = "Proxy-Authenticate" 13 | ProxyAuthorization = "Proxy-Authorization" 14 | Te = "Te" // canonicalized version of "TE" 15 | Trailers = "Trailers" 16 | TransferEncoding = "Transfer-Encoding" 17 | Upgrade = "Upgrade" 18 | ContentLength = "Content-Length" 19 | ContentType = "Content-Type" 20 | SecWebsocketKey = "Sec-Websocket-Key" 21 | SecWebsocketVersion = "Sec-Websocket-Version" 22 | SecWebsocketExtensions = "Sec-Websocket-Extensions" 23 | SecWebsocketAccept = "Sec-Websocket-Accept" 24 | ) 25 | 26 | // Hop-by-hop headers. These are removed when sent to the backend. 27 | // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html 28 | // Copied from reverseproxy.go, too bad 29 | var HopHeaders = []string{ 30 | Connection, 31 | KeepAlive, 32 | ProxyAuthenticate, 33 | ProxyAuthorization, 34 | Te, // canonicalized version of "TE" 35 | Trailers, 36 | TransferEncoding, 37 | Upgrade, 38 | } 39 | 40 | var WebsocketDialHeaders = []string{ 41 | Upgrade, 42 | Connection, 43 | SecWebsocketKey, 44 | SecWebsocketVersion, 45 | SecWebsocketExtensions, 46 | SecWebsocketAccept, 47 | } 48 | 49 | var WebsocketUpgradeHeaders = []string{ 50 | Upgrade, 51 | Connection, 52 | SecWebsocketAccept, 53 | } 54 | 55 | var XHeaders = []string{ 56 | XForwardedProto, 57 | XForwardedFor, 58 | XForwardedHost, 59 | XForwardedPort, 60 | XForwardedServer, 61 | XRealIp, 62 | } 63 | -------------------------------------------------------------------------------- /forward/responseflusher.go: -------------------------------------------------------------------------------- 1 | package forward 2 | 3 | import ( 4 | "bufio" 5 | "fmt" 6 | "net" 7 | "net/http" 8 | ) 9 | 10 | var ( 11 | _ http.Hijacker = &responseFlusher{} 12 | _ http.Flusher = &responseFlusher{} 13 | _ http.CloseNotifier = &responseFlusher{} 14 | ) 15 | 16 | type responseFlusher struct { 17 | http.ResponseWriter 18 | flush bool 19 | } 20 | 21 | func newResponseFlusher(rw 
http.ResponseWriter, flush bool) *responseFlusher { 22 | return &responseFlusher{ 23 | ResponseWriter: rw, 24 | flush: flush, 25 | } 26 | } 27 | 28 | func (wf *responseFlusher) Write(p []byte) (int, error) { 29 | written, err := wf.ResponseWriter.Write(p) 30 | if wf.flush { 31 | wf.Flush() 32 | } 33 | return written, err 34 | } 35 | 36 | func (wf *responseFlusher) Hijack() (net.Conn, *bufio.ReadWriter, error) { 37 | hijacker, ok := wf.ResponseWriter.(http.Hijacker) 38 | if !ok { 39 | return nil, nil, fmt.Errorf("the ResponseWriter doesn't support the Hijacker interface") 40 | } 41 | return hijacker.Hijack() 42 | } 43 | 44 | func (wf *responseFlusher) CloseNotify() <-chan bool { 45 | return wf.ResponseWriter.(http.CloseNotifier).CloseNotify() 46 | } 47 | 48 | func (wf *responseFlusher) Flush() { 49 | flusher, ok := wf.ResponseWriter.(http.Flusher) 50 | if ok { 51 | flusher.Flush() 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /forward/rewrite.go: -------------------------------------------------------------------------------- 1 | package forward 2 | 3 | import ( 4 | "net" 5 | "net/http" 6 | "strings" 7 | 8 | "github.com/vulcand/oxy/utils" 9 | ) 10 | 11 | // Rewriter is responsible for removing hop-by-hop headers and setting forwarding headers 12 | type HeaderRewriter struct { 13 | TrustForwardHeader bool 14 | Hostname string 15 | } 16 | 17 | func (rw *HeaderRewriter) Rewrite(req *http.Request) { 18 | if !rw.TrustForwardHeader { 19 | utils.RemoveHeaders(req.Header, XHeaders...) 
20 | } 21 | 22 | if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { 23 | if prior, ok := req.Header[XForwardedFor]; ok { 24 | req.Header.Set(XForwardedFor, strings.Join(prior, ", ")+", "+clientIP) 25 | } else { 26 | req.Header.Set(XForwardedFor, clientIP) 27 | } 28 | 29 | if req.Header.Get(XRealIp) == "" { 30 | req.Header.Set(XRealIp, clientIP) 31 | } 32 | } 33 | 34 | xfProto := req.Header.Get(XForwardedProto) 35 | if xfProto == "" { 36 | if req.TLS != nil { 37 | req.Header.Set(XForwardedProto, "https") 38 | } else { 39 | req.Header.Set(XForwardedProto, "http") 40 | } 41 | } 42 | 43 | if xfp := req.Header.Get(XForwardedPort); xfp == "" { 44 | req.Header.Set(XForwardedPort, forwardedPort(req)) 45 | } 46 | 47 | if xfHost := req.Header.Get(XForwardedHost); xfHost == "" && req.Host != "" { 48 | req.Header.Set(XForwardedHost, req.Host) 49 | } 50 | 51 | if rw.Hostname != "" { 52 | req.Header.Set(XForwardedServer, rw.Hostname) 53 | } 54 | 55 | // Remove hop-by-hop headers to the backend. Especially important is "Connection" because we want a persistent 56 | // connection, regardless of what the client sent to us. 57 | utils.RemoveHeaders(req.Header, HopHeaders...) 
58 | } 59 | 60 | func forwardedPort(req *http.Request) string { 61 | if req == nil { 62 | return "" 63 | } 64 | 65 | if _, port, err := net.SplitHostPort(req.Host); err == nil && port != "" { 66 | return port 67 | } 68 | 69 | if req.TLS != nil { 70 | return "443" 71 | } 72 | 73 | return "80" 74 | } 75 | -------------------------------------------------------------------------------- /forward/rewrite_test.go: -------------------------------------------------------------------------------- 1 | package forward 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "testing" 7 | ) 8 | 9 | func TestRewriter(t *testing.T) { 10 | testCases := []struct { 11 | desc string 12 | url string 13 | remoteAddr string 14 | host string 15 | hostName string 16 | trustForwardHeader bool 17 | reqHeaders map[string]string 18 | expectedHeaders map[string]string 19 | }{ 20 | { 21 | desc: "don't trust X Headers", 22 | url: "http://foo.bar", 23 | remoteAddr: "fii.bir:800", 24 | hostName: "fuu.bur", 25 | trustForwardHeader: false, 26 | reqHeaders: dumbHeaders(XHeaders), 27 | expectedHeaders: map[string]string{ 28 | XForwardedProto: "http", 29 | XForwardedFor: "fii.bir", 30 | XForwardedHost: "foo.bar", 31 | XForwardedPort: "80", 32 | XForwardedServer: "fuu.bur", 33 | XRealIp: "fii.bir", 34 | }, 35 | }, 36 | { 37 | desc: "trust X Headers", 38 | url: "http://foo.bar", 39 | remoteAddr: "fii.bir:800", 40 | trustForwardHeader: true, 41 | reqHeaders: dumbHeaders(XHeaders), 42 | expectedHeaders: map[string]string{ 43 | XForwardedProto: "fake", 44 | XForwardedFor: "fake, fii.bir", 45 | XForwardedHost: "fake", 46 | XForwardedPort: "fake", 47 | XForwardedServer: "fake", 48 | XRealIp: "fake", 49 | }, 50 | }, 51 | { 52 | desc: "no X Headers", 53 | url: "http://foo.bar", 54 | remoteAddr: "fii.bir:800", 55 | hostName: "fuu.bur", 56 | trustForwardHeader: true, 57 | reqHeaders: make(map[string]string), 58 | expectedHeaders: map[string]string{ 59 | XForwardedProto: "http", 60 | XForwardedFor: "fii.bir", 61 | 
XForwardedHost: "foo.bar", 62 | XForwardedPort: "80", 63 | XForwardedServer: "fuu.bur", 64 | XRealIp: "fii.bir", 65 | }, 66 | }, 67 | { 68 | desc: "request host", 69 | url: "http://127.0.0.1:8000/", 70 | remoteAddr: "fii.bir:800", 71 | host: "fyy.byr", 72 | hostName: "fuu.bur", 73 | trustForwardHeader: false, 74 | expectedHeaders: map[string]string{ 75 | XForwardedProto: "http", 76 | XForwardedFor: "fii.bir", 77 | XForwardedHost: "fyy.byr", 78 | XForwardedPort: "80", 79 | XForwardedServer: "fuu.bur", 80 | XRealIp: "fii.bir", 81 | }, 82 | }, 83 | } 84 | 85 | for _, test := range testCases { 86 | test := test 87 | t.Run(test.desc, func(t *testing.T) { 88 | hr := HeaderRewriter{ 89 | TrustForwardHeader: test.trustForwardHeader, 90 | Hostname: test.hostName, 91 | } 92 | 93 | req, err := http.NewRequest(http.MethodGet, test.url, nil) 94 | if err != nil { 95 | t.Fatal(err) 96 | } 97 | if test.host != "" { 98 | req.Host = test.host 99 | } 100 | if test.remoteAddr != "" { 101 | req.RemoteAddr = test.remoteAddr 102 | } 103 | 104 | for key, value := range test.reqHeaders { 105 | req.Header.Add(key, value) 106 | } 107 | 108 | hr.Rewrite(req) 109 | 110 | for key, expectedValue := range test.expectedHeaders { 111 | currentValue := req.Header.Get(key) 112 | if currentValue != expectedValue { 113 | t.Errorf("key: %s, currentValue: %s, expectedValue: %s", key, currentValue, expectedValue) 114 | } 115 | } 116 | if t.Failed() { 117 | for key, currentValue := range req.Header { 118 | fmt.Println(key, currentValue) 119 | } 120 | } 121 | }) 122 | } 123 | 124 | } 125 | 126 | func TestRewriterCleanHopHeaders(t *testing.T) { 127 | hr := HeaderRewriter{} 128 | 129 | req, err := http.NewRequest(http.MethodGet, "http://foo.bar", nil) 130 | 131 | for key, value := range dumbHeaders(HopHeaders) { 132 | req.Header.Add(key, value) 133 | } 134 | 135 | if err != nil { 136 | t.Fatal(err) 137 | } 138 | 139 | hr.Rewrite(req) 140 | 141 | for _, hop := range HopHeaders { 142 | if req.Header.Get(hop) != 
"" { 143 | t.Errorf("error %s", hop) 144 | } 145 | } 146 | } 147 | 148 | func dumbHeaders(selectedHeaders []string) map[string]string { 149 | headers := make(map[string]string) 150 | for _, header := range selectedHeaders { 151 | headers[header] = "fake" 152 | } 153 | return headers 154 | } 155 | -------------------------------------------------------------------------------- /memmetrics/anomaly.go: -------------------------------------------------------------------------------- 1 | package memmetrics 2 | 3 | import ( 4 | "math" 5 | "sort" 6 | "time" 7 | ) 8 | 9 | // SplitRatios provides simple anomaly detection for requests latencies. 10 | // it splits values into good or bad category based on the threshold and the median value. 11 | // If all values are not far from the median, it will return all values in 'good' set. 12 | // Precision is the smallest value to consider, e.g. if set to millisecond, microseconds will be ignored. 13 | func SplitLatencies(values []time.Duration, precision time.Duration) (good map[time.Duration]bool, bad map[time.Duration]bool) { 14 | // Find the max latency M and then map each latency L to the ratio L/M and then call SplitFloat64 15 | v2r := map[float64]time.Duration{} 16 | ratios := make([]float64, len(values)) 17 | m := maxTime(values) 18 | for i, v := range values { 19 | ratio := float64(v/precision+1) / float64(m/precision+1) // +1 is to avoid division by 0 20 | v2r[ratio] = v 21 | ratios[i] = ratio 22 | } 23 | good, bad = make(map[time.Duration]bool), make(map[time.Duration]bool) 24 | // Note that multiplier makes this function way less sensitive than ratios detector, this is to avoid noise. 
25 | vgood, vbad := SplitFloat64(2, 0, ratios) 26 | for r, _ := range vgood { 27 | good[v2r[r]] = true 28 | } 29 | for r, _ := range vbad { 30 | bad[v2r[r]] = true 31 | } 32 | return good, bad 33 | } 34 | 35 | // SplitRatios provides simple anomaly detection for ratio values, that are all in the range [0, 1] 36 | // it splits values into good or bad category based on the threshold and the median value. 37 | // If all values are not far from the median, it will return all values in 'good' set. 38 | func SplitRatios(values []float64) (good map[float64]bool, bad map[float64]bool) { 39 | return SplitFloat64(1.5, 0, values) 40 | } 41 | 42 | // SplitFloat64 provides simple anomaly detection for skewed data sets with no particular distribution. 43 | // In essense it applies the formula if(v > median(values) + threshold * medianAbsoluteDeviation) -> anomaly 44 | // There's a corner case where there are just 2 values, so by definition there's no value that exceeds the threshold. 45 | // This case is solved by introducing additional value that we know is good, e.g. 0. That helps to improve the detection results 46 | // on such data sets. 
47 | func SplitFloat64(threshold, sentinel float64, values []float64) (good map[float64]bool, bad map[float64]bool) { 48 | good, bad = make(map[float64]bool), make(map[float64]bool) 49 | var newValues []float64 50 | if len(values)%2 == 0 { 51 | newValues = make([]float64, len(values)+1) 52 | copy(newValues, values) 53 | // Add a sentinel endpoint so we can distinguish outliers better 54 | newValues[len(newValues)-1] = sentinel 55 | } else { 56 | newValues = values 57 | } 58 | 59 | m := median(newValues) 60 | mAbs := medianAbsoluteDeviation(newValues) 61 | for _, v := range values { 62 | if v > (m+mAbs)*threshold { 63 | bad[v] = true 64 | } else { 65 | good[v] = true 66 | } 67 | } 68 | return good, bad 69 | } 70 | 71 | func median(values []float64) float64 { 72 | vals := make([]float64, len(values)) 73 | copy(vals, values) 74 | sort.Float64s(vals) 75 | l := len(vals) 76 | if l%2 != 0 { 77 | return vals[l/2] 78 | } 79 | return (vals[l/2-1] + vals[l/2]) / 2.0 80 | } 81 | 82 | func medianAbsoluteDeviation(values []float64) float64 { 83 | m := median(values) 84 | distances := make([]float64, len(values)) 85 | for i, v := range values { 86 | distances[i] = math.Abs(v - m) 87 | } 88 | return median(distances) 89 | } 90 | 91 | func maxTime(vals []time.Duration) time.Duration { 92 | val := vals[0] 93 | for _, v := range vals { 94 | if v > val { 95 | val = v 96 | } 97 | } 98 | return val 99 | } 100 | -------------------------------------------------------------------------------- /memmetrics/anomaly_test.go: -------------------------------------------------------------------------------- 1 | package memmetrics 2 | 3 | import ( 4 | "time" 5 | 6 | . 
"gopkg.in/check.v1" 7 | ) 8 | 9 | type AnomalySuite struct { 10 | } 11 | 12 | var _ = Suite(&AnomalySuite{}) 13 | 14 | func (s *AnomalySuite) TestMedian(c *C) { 15 | c.Assert(median([]float64{0.1, 0.2}), Equals, (float64(0.1)+float64(0.2))/2.0) 16 | c.Assert(median([]float64{0.3, 0.2, 0.5}), Equals, 0.3) 17 | } 18 | 19 | func (s *AnomalySuite) TestSplitRatios(c *C) { 20 | vals := []struct { 21 | values []float64 22 | good []float64 23 | bad []float64 24 | }{ 25 | { 26 | values: []float64{0, 0}, 27 | good: []float64{0}, 28 | bad: []float64{}, 29 | }, 30 | 31 | { 32 | values: []float64{0, 1}, 33 | good: []float64{0}, 34 | bad: []float64{1}, 35 | }, 36 | { 37 | values: []float64{0.1, 0.1}, 38 | good: []float64{0.1}, 39 | bad: []float64{}, 40 | }, 41 | 42 | { 43 | values: []float64{0.15, 0.1}, 44 | good: []float64{0.15, 0.1}, 45 | bad: []float64{}, 46 | }, 47 | { 48 | values: []float64{0.01, 0.01}, 49 | good: []float64{0.01}, 50 | bad: []float64{}, 51 | }, 52 | { 53 | values: []float64{0.012, 0.01, 1}, 54 | good: []float64{0.012, 0.01}, 55 | bad: []float64{1}, 56 | }, 57 | { 58 | values: []float64{0, 0, 1, 1}, 59 | good: []float64{0}, 60 | bad: []float64{1}, 61 | }, 62 | { 63 | values: []float64{0, 0.1, 0.1, 0}, 64 | good: []float64{0}, 65 | bad: []float64{0.1}, 66 | }, 67 | { 68 | values: []float64{0, 0.01, 0.1, 0}, 69 | good: []float64{0}, 70 | bad: []float64{0.01, 0.1}, 71 | }, 72 | { 73 | values: []float64{0, 0.01, 0.02, 1}, 74 | good: []float64{0, 0.01, 0.02}, 75 | bad: []float64{1}, 76 | }, 77 | { 78 | values: []float64{0, 0, 0, 0, 0, 0.01, 0.02, 1}, 79 | good: []float64{0}, 80 | bad: []float64{0.01, 0.02, 1}, 81 | }, 82 | } 83 | for _, v := range vals { 84 | good, bad := SplitRatios(v.values) 85 | vgood, vbad := make(map[float64]bool, len(v.good)), make(map[float64]bool, len(v.bad)) 86 | for _, v := range v.good { 87 | vgood[v] = true 88 | } 89 | for _, v := range v.bad { 90 | vbad[v] = true 91 | } 92 | 93 | c.Assert(good, DeepEquals, vgood) 94 | c.Assert(bad, 
DeepEquals, vbad) 95 | } 96 | } 97 | 98 | func (s *AnomalySuite) TestSplitLatencies(c *C) { 99 | vals := []struct { 100 | values []int 101 | good []int 102 | bad []int 103 | }{ 104 | { 105 | values: []int{0, 0}, 106 | good: []int{0}, 107 | bad: []int{}, 108 | }, 109 | { 110 | values: []int{1, 2}, 111 | good: []int{1, 2}, 112 | bad: []int{}, 113 | }, 114 | { 115 | values: []int{1, 2, 4}, 116 | good: []int{1, 2, 4}, 117 | bad: []int{}, 118 | }, 119 | { 120 | values: []int{8, 8, 18}, 121 | good: []int{8}, 122 | bad: []int{18}, 123 | }, 124 | { 125 | values: []int{32, 28, 11, 26, 19, 51, 25, 39, 28, 26, 8, 97}, 126 | good: []int{32, 28, 11, 26, 19, 51, 25, 39, 28, 26, 8}, 127 | bad: []int{97}, 128 | }, 129 | { 130 | values: []int{1, 2, 4, 40}, 131 | good: []int{1, 2, 4}, 132 | bad: []int{40}, 133 | }, 134 | { 135 | values: []int{40, 60, 1000}, 136 | good: []int{40, 60}, 137 | bad: []int{1000}, 138 | }, 139 | } 140 | for _, v := range vals { 141 | vvalues := make([]time.Duration, len(v.values)) 142 | for i, d := range v.values { 143 | vvalues[i] = time.Millisecond * time.Duration(d) 144 | } 145 | good, bad := SplitLatencies(vvalues, time.Millisecond) 146 | 147 | vgood, vbad := make(map[time.Duration]bool, len(v.good)), make(map[time.Duration]bool, len(v.bad)) 148 | for _, v := range v.good { 149 | vgood[time.Duration(v)*time.Millisecond] = true 150 | } 151 | for _, v := range v.bad { 152 | vbad[time.Duration(v)*time.Millisecond] = true 153 | } 154 | 155 | c.Assert(good, DeepEquals, vgood) 156 | c.Assert(bad, DeepEquals, vbad) 157 | } 158 | } 159 | -------------------------------------------------------------------------------- /memmetrics/counter.go: -------------------------------------------------------------------------------- 1 | package memmetrics 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/mailgun/timetools" 8 | ) 9 | // rcOptSetter configures an optional RollingCounter setting inside NewCounter. 10 | type rcOptSetter func(*RollingCounter) error 11 | // CounterClock sets the time provider the counter reads the current time from. 12 | func CounterClock(c timetools.TimeProvider) rcOptSetter { 13 | return
func(r *RollingCounter) error { 14 | r.clock = c 15 | return nil 16 | } 17 | } 18 | 19 | // RollingCounter counts events in memory over a rolling window made of a fixed number of buckets, each covering one resolution period. 20 | type RollingCounter struct { 21 | clock timetools.TimeProvider 22 | resolution time.Duration 23 | values []int 24 | countedBuckets int // how many samples in different buckets have we collected so far 25 | lastBucket int // last recorded bucket 26 | lastUpdated time.Time 27 | } 28 | 29 | // NewCounter creates a counter with fixed amount of buckets that are rotated every resolution period. 30 | // E.g. 10 buckets with 1 second means that every new second the bucket is refreshed, so it maintains 10 second rolling window. 31 | // It returns an error if buckets is not positive or resolution is below one second; the clock defaults to real time. 32 | func NewCounter(buckets int, resolution time.Duration, options ...rcOptSetter) (*RollingCounter, error) { 33 | if buckets <= 0 { 34 | return nil, fmt.Errorf("Buckets should be > 0") 35 | } 36 | if resolution < time.Second { 37 | return nil, fmt.Errorf("Resolution should be larger than a second") 38 | } 39 | 40 | rc := &RollingCounter{ 41 | lastBucket: -1, 42 | resolution: resolution, 43 | 44 | values: make([]int, buckets), 45 | } 46 | 47 | for _, o := range options { 48 | if err := o(rc); err != nil { 49 | return nil, err 50 | } 51 | } 52 | 53 | if rc.clock == nil { 54 | rc.clock = &timetools.RealTime{} 55 | } 56 | 57 | return rc, nil 58 | } 59 | // Append adds o's current windowed total to c's current bucket; it never returns an error. 60 | func (c *RollingCounter) Append(o *RollingCounter) error { 61 | c.Inc(int(o.Count())) 62 | return nil 63 | } 64 | // Clone expires stale buckets and returns an independent copy of the counter, including its usage statistics. 65 | func (c *RollingCounter) Clone() *RollingCounter { 66 | c.cleanup() 67 | other := &RollingCounter{ 68 | resolution: c.resolution, 69 | values: make([]int, len(c.values)), 70 | clock: c.clock, 71 | lastBucket: c.lastBucket, 72 | lastUpdated: c.lastUpdated, 73 | } 74 | copy(other.values, c.values) 75 | other.countedBuckets = c.countedBuckets // keep readiness statistics in the copy 76 | 77 | return other 78 | } 79 | // Reset clears all buckets and forgets the collected usage statistics. 80 | func (c *RollingCounter) Reset() { 81 | c.lastBucket = -1 82 |
c.countedBuckets = 0 83 | c.lastUpdated = time.Time{} 84 | for i := range c.values { 85 | c.values[i] = 0 86 | } 87 | } 88 | // CountedBuckets returns how many times the counter moved to a new bucket since creation or the last Reset, capped at the number of buckets. 89 | func (c *RollingCounter) CountedBuckets() int { 90 | return c.countedBuckets 91 | } 92 | // Count expires stale buckets and returns the total over the rolling window. 93 | func (c *RollingCounter) Count() int64 { 94 | c.cleanup() 95 | return c.sum() 96 | } 97 | // Resolution returns the time span covered by one bucket. 98 | func (c *RollingCounter) Resolution() time.Duration { 99 | return c.resolution 100 | } 101 | // Buckets returns the number of buckets in the window. 102 | func (c *RollingCounter) Buckets() int { 103 | return len(c.values) 104 | } 105 | // WindowSize returns the time span covered by the whole window (buckets * resolution). 106 | func (c *RollingCounter) WindowSize() time.Duration { 107 | return time.Duration(len(c.values)) * c.resolution 108 | } 109 | // Inc expires stale buckets and adds v to the bucket of the current period. 110 | func (c *RollingCounter) Inc(v int) { 111 | c.cleanup() 112 | c.incBucketValue(v) 113 | } 114 | // incBucketValue adds v to the current bucket and updates the usage statistics. 115 | func (c *RollingCounter) incBucketValue(v int) { 116 | now := c.clock.UtcNow() 117 | bucket := c.getBucket(now) 118 | c.values[bucket] += v 119 | c.lastUpdated = now 120 | // Update usage stats if we haven't collected enough data 121 | if c.countedBuckets < len(c.values) { 122 | // Only update if we have advanced to the next bucket and not incremented the value 123 | // in the current bucket.
124 | if c.lastBucket != bucket { 125 | c.lastBucket = bucket 126 | c.countedBuckets++ 127 | } 128 | } 129 | } 130 | 131 | // getBucket returns the index of the bucket covering t's resolution period; consecutive periods map to consecutive indices (mod the bucket count) for any resolution, not just one second. 132 | func (c *RollingCounter) getBucket(t time.Time) int { 133 | return int((t.Truncate(c.resolution).UnixNano() / int64(c.resolution)) % int64(len(c.values))) 134 | } 135 | 136 | // cleanup zeroes the bucket of every period that started after the last update, walking back one period at a time from now (at most one full window). 137 | func (c *RollingCounter) cleanup() { 138 | now := c.clock.UtcNow() 139 | for i := 0; i < len(c.values); i++ { 140 | checkPoint := now.Add(time.Duration(-i) * c.resolution) // i periods before now, not a running sum of offsets 141 | if checkPoint.Truncate(c.resolution).After(c.lastUpdated.Truncate(c.resolution)) { 142 | c.values[c.getBucket(checkPoint)] = 0 143 | } else { 144 | break 145 | } 146 | } 147 | } 148 | // sum returns the total of all bucket values. 149 | func (c *RollingCounter) sum() int64 { 150 | out := int64(0) 151 | for _, v := range c.values { 152 | out += int64(v) 153 | } 154 | return out 155 | } 156 | -------------------------------------------------------------------------------- /memmetrics/counter_test.go: -------------------------------------------------------------------------------- 1 | package memmetrics 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/mailgun/timetools" 7 | .
"gopkg.in/check.v1" 8 | ) 9 | 10 | type CounterSuite struct { 11 | clock *timetools.FreezedTime 12 | } 13 | 14 | var _ = Suite(&CounterSuite{}) 15 | 16 | func (s *CounterSuite) SetUpSuite(c *C) { 17 | s.clock = &timetools.FreezedTime{ 18 | CurrentTime: time.Date(2012, 3, 4, 5, 6, 7, 0, time.UTC), 19 | } 20 | } 21 | 22 | func (s *CounterSuite) TestCloneExpired(c *C) { 23 | cnt, err := NewCounter(3, time.Second, CounterClock(s.clock)) 24 | c.Assert(err, IsNil) 25 | cnt.Inc(1) 26 | s.clock.Sleep(time.Second) 27 | cnt.Inc(1) 28 | s.clock.Sleep(time.Second) 29 | cnt.Inc(1) 30 | s.clock.Sleep(time.Second) 31 | 32 | out := cnt.Clone() 33 | c.Assert(out.Count(), Equals, int64(2)) 34 | } 35 | -------------------------------------------------------------------------------- /memmetrics/histogram.go: -------------------------------------------------------------------------------- 1 | package memmetrics 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/codahale/hdrhistogram" 8 | "github.com/mailgun/timetools" 9 | ) 10 | 11 | // HDRHistogram is a tiny wrapper around github.com/codahale/hdrhistogram that provides convenience functions for measuring http latencies 12 | type HDRHistogram struct { 13 | // lowest trackable value 14 | low int64 15 | // highest trackable value 16 | high int64 17 | // significant figures 18 | sigfigs int 19 | 20 | h *hdrhistogram.Histogram 21 | } 22 | // NewHDRHistogram creates a histogram for values in [low, high] with sigfigs significant figures; a panic raised by hdrhistogram.New for invalid parameters is returned as an error instead. 23 | func NewHDRHistogram(low, high int64, sigfigs int) (h *HDRHistogram, err error) { 24 | defer func() { 25 | if msg := recover(); msg != nil { 26 | err = fmt.Errorf("%s", msg) 27 | } 28 | }() 29 | return &HDRHistogram{ 30 | low: low, 31 | high: high, 32 | sigfigs: sigfigs, 33 | h: hdrhistogram.New(low, high, sigfigs), 34 | }, nil 35 | } 36 | 37 | // LatencyAtQuantile returns the value at quantile q, interpreted as microseconds. 38 | func (h *HDRHistogram) LatencyAtQuantile(q float64) time.Duration { 39 | return time.Duration(h.ValueAtQuantile(q)) * time.Microsecond 40 | } 41 | 42 | // RecordLatencies records n occurrences of d with microsecond
precision 43 | func (h *HDRHistogram) RecordLatencies(d time.Duration, n int64) error { 44 | return h.RecordValues(int64(d/time.Microsecond), n) 45 | } 46 | // Reset clears all recorded values. 47 | func (h *HDRHistogram) Reset() { 48 | h.h.Reset() 49 | } 50 | // ValueAtQuantile returns the recorded value at quantile q. 51 | func (h *HDRHistogram) ValueAtQuantile(q float64) int64 { 52 | return h.h.ValueAtQuantile(q) 53 | } 54 | // RecordValues records n occurrences of the raw value v. 55 | func (h *HDRHistogram) RecordValues(v, n int64) error { 56 | return h.h.RecordValues(v, n) 57 | } 58 | // Merge adds every value recorded in other to h; it returns an error if other is nil. 59 | func (h *HDRHistogram) Merge(other *HDRHistogram) error { 60 | if other == nil { 61 | return fmt.Errorf("other is nil") 62 | } 63 | h.h.Merge(other.h) 64 | return nil 65 | } 66 | 67 | type rhOptSetter func(r *RollingHDRHistogram) error 68 | // RollingClock sets the time provider used to decide when buckets rotate. 69 | func RollingClock(clock timetools.TimeProvider) rhOptSetter { 70 | return func(r *RollingHDRHistogram) error { 71 | r.clock = clock 72 | return nil 73 | } 74 | } 75 | 76 | // RollingHDRHistogram holds multiple histograms and rotates every period. 77 | // It provides resulting histogram as a result of a call of 'Merged' function.
78 | type RollingHDRHistogram struct { 79 | idx int 80 | lastRoll time.Time 81 | period time.Duration 82 | bucketCount int 83 | low int64 84 | high int64 85 | sigfigs int 86 | buckets []*HDRHistogram 87 | clock timetools.TimeProvider 88 | } 89 | // NewRollingHDRHistogram creates bucketCount histograms over [low, high] that are rotated every period; bucketCount must be positive (it is not validated here). 90 | func NewRollingHDRHistogram(low, high int64, sigfigs int, period time.Duration, bucketCount int, options ...rhOptSetter) (*RollingHDRHistogram, error) { 91 | rh := &RollingHDRHistogram{ 92 | bucketCount: bucketCount, 93 | period: period, 94 | low: low, 95 | high: high, 96 | sigfigs: sigfigs, 97 | } 98 | 99 | for _, o := range options { 100 | if err := o(rh); err != nil { 101 | return nil, err 102 | } 103 | } 104 | 105 | if rh.clock == nil { 106 | rh.clock = &timetools.RealTime{} 107 | } 108 | 109 | buckets := make([]*HDRHistogram, rh.bucketCount) 110 | for i := range buckets { 111 | h, err := NewHDRHistogram(low, high, sigfigs) 112 | if err != nil { 113 | return nil, err 114 | } 115 | buckets[i] = h 116 | } 117 | rh.buckets = buckets 118 | return rh, nil 119 | } 120 | // Append merges o's buckets into r's, index by index; both must share the same parameters. 121 | func (r *RollingHDRHistogram) Append(o *RollingHDRHistogram) error { 122 | if r.bucketCount != o.bucketCount || r.period != o.period || r.low != o.low || r.high != o.high || r.sigfigs != o.sigfigs { 123 | return fmt.Errorf("can't merge") 124 | } 125 | 126 | for i := range r.buckets { 127 | if err := r.buckets[i].Merge(o.buckets[i]); err != nil { 128 | return err 129 | } 130 | } 131 | return nil 132 | } 133 | // Reset clears every bucket and restarts the rotation period from the current time. 134 | func (r *RollingHDRHistogram) Reset() { 135 | r.idx = 0 136 | r.lastRoll = r.clock.UtcNow() 137 | for _, b := range r.buckets { 138 | b.Reset() 139 | } 140 | } 141 | // rotate advances to the next bucket, wrapping around, and clears it. 142 | func (r *RollingHDRHistogram) rotate() { 143 | r.idx = (r.idx + 1) % len(r.buckets) 144 | r.buckets[r.idx].Reset() 145 | } 146 | // Merged returns a new histogram combining every bucket, or the first error encountered. 147 | func (r *RollingHDRHistogram) Merged() (*HDRHistogram, error) { 148 | m, err := NewHDRHistogram(r.low, r.high, r.sigfigs) 149 | if err != nil { 150 | return nil, err 151 | } 152 | for _, h := range r.buckets { 153 | if err := m.Merge(h); err !=
nil { 154 | return nil, err 155 | } 156 | } 157 | return m, nil 158 | } 159 | // getHist returns the current bucket, rotating once when at least one period has elapsed since the last roll. NOTE(review): only one bucket is cleared per call, so after several idle periods the other buckets can still hold data older than the window. 160 | func (r *RollingHDRHistogram) getHist() *HDRHistogram { 161 | if r.clock.UtcNow().Sub(r.lastRoll) >= r.period { 162 | r.rotate() 163 | r.lastRoll = r.clock.UtcNow() 164 | } 165 | return r.buckets[r.idx] 166 | } 167 | // RecordLatencies records n occurrences of latency v in the current bucket. 168 | func (r *RollingHDRHistogram) RecordLatencies(v time.Duration, n int64) error { 169 | return r.getHist().RecordLatencies(v, n) 170 | } 171 | // RecordValues records n occurrences of v in the current bucket. 172 | func (r *RollingHDRHistogram) RecordValues(v, n int64) error { 173 | return r.getHist().RecordValues(v, n) 174 | } 175 | -------------------------------------------------------------------------------- /memmetrics/histogram_test.go: -------------------------------------------------------------------------------- 1 | package memmetrics 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/mailgun/timetools" 7 | . "gopkg.in/check.v1" 8 | ) 9 | 10 | type HistogramSuite struct { 11 | tm *timetools.FreezedTime 12 | } 13 | 14 | var _ = Suite(&HistogramSuite{}) 15 | 16 | func (s *HistogramSuite) SetUpSuite(c *C) { 17 | s.tm = &timetools.FreezedTime{ 18 | CurrentTime: time.Date(2012, 3, 4, 5, 6, 7, 0, time.UTC), 19 | } 20 | } 21 | 22 | func (s *HistogramSuite) TestMerge(c *C) { 23 | a, err := NewHDRHistogram(1, 3600000, 2) 24 | c.Assert(err, IsNil) 25 | 26 | a.RecordValues(1, 2) 27 | 28 | b, err := NewHDRHistogram(1, 3600000, 2) 29 | c.Assert(err, IsNil) 30 | 31 | b.RecordValues(2, 1) 32 | 33 | c.Assert(a.Merge(b), IsNil) 34 | 35 | c.Assert(a.ValueAtQuantile(50), Equals, int64(1)) 36 | c.Assert(a.ValueAtQuantile(100), Equals, int64(2)) 37 | } 38 | 39 | func (s *HistogramSuite) TestInvalidParams(c *C) { 40 | _, err := NewHDRHistogram(1, 3600000, 0) 41 | c.Assert(err, NotNil) 42 | } 43 | 44 | func (s *HistogramSuite) TestMergeNil(c *C) { 45 | a, err := NewHDRHistogram(1, 3600000, 1) 46 | c.Assert(err, IsNil) 47 | 48 | c.Assert(a.Merge(nil), NotNil) 49 | } 50 | 51 | func (s *HistogramSuite) TestRotation(c *C) { 52 | h, err :=
NewRollingHDRHistogram( 53 | 1, // min value 54 | 3600000, // max value 55 | 3, // significant figures 56 | time.Second, // 1 second is a rolling period 57 | 2, // 2 histograms in a window 58 | RollingClock(s.tm)) 59 | 60 | c.Assert(err, IsNil) 61 | c.Assert(h, NotNil) 62 | 63 | h.RecordValues(5, 1) 64 | 65 | m, err := h.Merged() 66 | c.Assert(err, IsNil) 67 | c.Assert(m.ValueAtQuantile(100), Equals, int64(5)) 68 | 69 | s.tm.CurrentTime = s.tm.CurrentTime.Add(time.Second) 70 | h.RecordValues(2, 1) 71 | h.RecordValues(1, 1) 72 | 73 | m, err = h.Merged() 74 | c.Assert(err, IsNil) 75 | c.Assert(m.ValueAtQuantile(100), Equals, int64(5)) 76 | 77 | // rotate, this means that the old value would evaporate 78 | s.tm.CurrentTime = s.tm.CurrentTime.Add(time.Second) 79 | h.RecordValues(1, 1) 80 | m, err = h.Merged() 81 | c.Assert(err, IsNil) 82 | c.Assert(m.ValueAtQuantile(100), Equals, int64(2)) 83 | } 84 | 85 | func (s *HistogramSuite) TestReset(c *C) { 86 | h, err := NewRollingHDRHistogram( 87 | 1, // min value 88 | 3600000, // max value 89 | 3, // significant figures 90 | time.Second, // 1 second is a rolling period 91 | 2, // 2 histograms in a window 92 | RollingClock(s.tm)) 93 | 94 | c.Assert(err, IsNil) 95 | c.Assert(h, NotNil) 96 | 97 | h.RecordValues(5, 1) 98 | 99 | m, err := h.Merged() 100 | c.Assert(err, IsNil) 101 | c.Assert(m.ValueAtQuantile(100), Equals, int64(5)) 102 | 103 | s.tm.CurrentTime = s.tm.CurrentTime.Add(time.Second) 104 | h.RecordValues(2, 1) 105 | h.RecordValues(1, 1) 106 | 107 | m, err = h.Merged() 108 | c.Assert(err, IsNil) 109 | c.Assert(m.ValueAtQuantile(100), Equals, int64(5)) 110 | 111 | h.Reset() 112 | 113 | h.RecordValues(5, 1) 114 | 115 | m, err = h.Merged() 116 | c.Assert(err, IsNil) 117 | c.Assert(m.ValueAtQuantile(100), Equals, int64(5)) 118 | 119 | s.tm.CurrentTime = s.tm.CurrentTime.Add(time.Second) 120 | h.RecordValues(2, 1) 121 | h.RecordValues(1, 1) 122 | 123 | m, err = h.Merged() 124 | c.Assert(err, IsNil) 125 |
c.Assert(m.ValueAtQuantile(100), Equals, int64(5)) 126 | 127 | } 128 | -------------------------------------------------------------------------------- /memmetrics/ratio.go: -------------------------------------------------------------------------------- 1 | package memmetrics 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/mailgun/timetools" 7 | ) 8 | 9 | type ratioOptSetter func(r *RatioCounter) error 10 | // RatioClock sets the time provider shared by both underlying counters. 11 | func RatioClock(clock timetools.TimeProvider) ratioOptSetter { 12 | return func(r *RatioCounter) error { 13 | r.clock = clock 14 | return nil 15 | } 16 | } 17 | 18 | // RatioCounter calculates a ratio of a/(a+b) over a rolling window of predefined buckets 19 | type RatioCounter struct { 20 | clock timetools.TimeProvider 21 | a *RollingCounter 22 | b *RollingCounter 23 | } 24 | // NewRatioCounter creates the a and b rolling counters with the given bucket count and resolution; NewCounter's validation errors are returned as is. 25 | func NewRatioCounter(buckets int, resolution time.Duration, options ...ratioOptSetter) (*RatioCounter, error) { 26 | rc := &RatioCounter{} 27 | 28 | for _, o := range options { 29 | if err := o(rc); err != nil { 30 | return nil, err 31 | } 32 | } 33 | 34 | if rc.clock == nil { 35 | rc.clock = &timetools.RealTime{} 36 | } 37 | 38 | a, err := NewCounter(buckets, resolution, CounterClock(rc.clock)) 39 | if err != nil { 40 | return nil, err 41 | } 42 | 43 | b, err := NewCounter(buckets, resolution, CounterClock(rc.clock)) 44 | if err != nil { 45 | return nil, err 46 | } 47 | 48 | rc.a = a 49 | rc.b = b 50 | return rc, nil 51 | } 52 | // Reset clears both counters. 53 | func (r *RatioCounter) Reset() { 54 | r.a.Reset() 55 | r.b.Reset() 56 | } 57 | // IsReady reports whether the two counters together have counted at least as many buckets as the window holds. 58 | func (r *RatioCounter) IsReady() bool { 59 | return r.a.countedBuckets+r.b.countedBuckets >= len(r.a.values) 60 | } 61 | // CountA returns the windowed total of the a counter. 62 | func (r *RatioCounter) CountA() int64 { 63 | return r.a.Count() 64 | } 65 | // CountB returns the windowed total of the b counter. 66 | func (r *RatioCounter) CountB() int64 { 67 | return r.b.Count() 68 | } 69 | // Resolution returns the time span covered by one bucket. 70 | func (r *RatioCounter) Resolution() time.Duration { 71 | return r.a.Resolution() 72 | } 73 | // Buckets returns the number of buckets in the window. 74 | func (r *RatioCounter) Buckets() int { 75 | return r.a.Buckets() 76 | } 77 | // WindowSize returns the time span covered by the whole window. 78 | func
(r *RatioCounter) WindowSize() time.Duration { 79 | return r.a.WindowSize() 80 | } 81 | // ProcessedCount returns the combined windowed total of a and b. 82 | func (r *RatioCounter) ProcessedCount() int64 { 83 | return r.CountA() + r.CountB() 84 | } 85 | // Ratio returns a/(a+b) over the window, or 0 when nothing was recorded. 86 | func (r *RatioCounter) Ratio() float64 { 87 | a := r.a.Count() 88 | b := r.b.Count() 89 | // No data, return ok 90 | if a+b == 0 { 91 | return 0 92 | } 93 | return float64(a) / float64(a+b) 94 | } 95 | // IncA adds v to the a counter. 96 | func (r *RatioCounter) IncA(v int) { 97 | r.a.Inc(v) 98 | } 99 | // IncB adds v to the b counter. 100 | func (r *RatioCounter) IncB(v int) { 101 | r.b.Inc(v) 102 | } 103 | // TestMeter is a stub meter that reports the configured Rate, readiness (inverse of NotReady) and WindowSize. 104 | type TestMeter struct { 105 | Rate float64 106 | NotReady bool 107 | WindowSize time.Duration 108 | } 109 | 110 | func (tm *TestMeter) GetWindowSize() time.Duration { 111 | return tm.WindowSize 112 | } 113 | 114 | func (tm *TestMeter) IsReady() bool { 115 | return !tm.NotReady 116 | } 117 | 118 | func (tm *TestMeter) GetRate() float64 { 119 | return tm.Rate 120 | } 121 | -------------------------------------------------------------------------------- /memmetrics/ratio_test.go: -------------------------------------------------------------------------------- 1 | package memmetrics 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/mailgun/timetools" 8 | .
"gopkg.in/check.v1" 9 | ) 10 | 11 | func TestFailrate(t *testing.T) { TestingT(t) } 12 | 13 | type FailRateSuite struct { 14 | tm *timetools.FreezedTime 15 | } 16 | 17 | var _ = Suite(&FailRateSuite{}) 18 | 19 | func (s *FailRateSuite) SetUpSuite(c *C) { 20 | s.tm = &timetools.FreezedTime{ 21 | CurrentTime: time.Date(2012, 3, 4, 5, 6, 7, 0, time.UTC), 22 | } 23 | } 24 | 25 | func (s *FailRateSuite) TestInvalidParams(c *C) { 26 | // Bad buckets count 27 | _, err := NewRatioCounter(0, time.Second, RatioClock(s.tm)) 28 | c.Assert(err, Not(IsNil)) 29 | 30 | // Too precise resolution 31 | _, err = NewRatioCounter(10, time.Millisecond, RatioClock(s.tm)) 32 | c.Assert(err, Not(IsNil)) 33 | } 34 | 35 | func (s *FailRateSuite) TestNotReady(c *C) { 36 | // No data 37 | fr, err := NewRatioCounter(10, time.Second, RatioClock(s.tm)) 38 | c.Assert(err, IsNil) 39 | c.Assert(fr.IsReady(), Equals, false) 40 | c.Assert(fr.Ratio(), Equals, 0.0) 41 | 42 | // Not enough data 43 | fr, err = NewRatioCounter(10, time.Second, RatioClock(s.tm)) 44 | c.Assert(err, IsNil) 45 | fr.CountA() 46 | c.Assert(fr.IsReady(), Equals, false) 47 | } 48 | 49 | func (s *FailRateSuite) TestNoB(c *C) { 50 | fr, err := NewRatioCounter(1, time.Second, RatioClock(s.tm)) 51 | c.Assert(err, IsNil) 52 | fr.IncA(1) 53 | c.Assert(fr.IsReady(), Equals, true) 54 | c.Assert(fr.Ratio(), Equals, 1.0) 55 | } 56 | 57 | func (s *FailRateSuite) TestNoA(c *C) { 58 | fr, err := NewRatioCounter(1, time.Second, RatioClock(s.tm)) 59 | c.Assert(err, IsNil) 60 | fr.IncB(1) 61 | c.Assert(fr.IsReady(), Equals, true) 62 | c.Assert(fr.Ratio(), Equals, 0.0) 63 | } 64 | 65 | // Make sure that data is properly calculated over several buckets 66 | func (s *FailRateSuite) TestMultipleBuckets(c *C) { 67 | fr, err := NewRatioCounter(3, time.Second, RatioClock(s.tm)) 68 | c.Assert(err, IsNil) 69 | 70 | fr.IncB(1) 71 | s.tm.CurrentTime = s.tm.CurrentTime.Add(time.Second) 72 | fr.IncA(1) 73 | 74 | s.tm.CurrentTime =
s.tm.CurrentTime.Add(time.Second) 75 | fr.IncA(1) 76 | 77 | c.Assert(fr.IsReady(), Equals, true) 78 | c.Assert(fr.Ratio(), Equals, float64(2)/float64(3)) 79 | } 80 | 81 | // Make sure that data is properly calculated over several buckets 82 | // When we overwrite old data when the window is rolling 83 | func (s *FailRateSuite) TestOverwriteBuckets(c *C) { 84 | fr, err := NewRatioCounter(3, time.Second, RatioClock(s.tm)) 85 | c.Assert(err, IsNil) 86 | 87 | fr.IncB(1) 88 | 89 | s.tm.CurrentTime = s.tm.CurrentTime.Add(time.Second) 90 | fr.IncA(1) 91 | 92 | s.tm.CurrentTime = s.tm.CurrentTime.Add(time.Second) 93 | fr.IncA(1) 94 | 95 | // This time we should overwrite the old data points 96 | s.tm.CurrentTime = s.tm.CurrentTime.Add(time.Second) 97 | fr.IncA(1) 98 | fr.IncB(2) 99 | 100 | c.Assert(fr.IsReady(), Equals, true) 101 | c.Assert(fr.Ratio(), Equals, float64(3)/float64(5)) 102 | } 103 | 104 | // Make sure we cleanup the data after periods of inactivity 105 | // So it does not mess up the stats 106 | func (s *FailRateSuite) TestInactiveBuckets(c *C) { 107 | 108 | fr, err := NewRatioCounter(3, time.Second, RatioClock(s.tm)) 109 | c.Assert(err, IsNil) 110 | 111 | fr.IncB(1) 112 | 113 | s.tm.CurrentTime = s.tm.CurrentTime.Add(time.Second) 114 | fr.IncA(1) 115 | 116 | s.tm.CurrentTime = s.tm.CurrentTime.Add(time.Second) 117 | fr.IncA(1) 118 | 119 | // This time we should overwrite the old data points with new data 120 | s.tm.CurrentTime = s.tm.CurrentTime.Add(time.Second) 121 | fr.IncA(1) 122 | fr.IncB(2) 123 | 124 | // Jump to the last bucket and change the data 125 | s.tm.CurrentTime = s.tm.CurrentTime.Add(time.Second * 2) 126 | fr.IncB(1) 127 | 128 | c.Assert(fr.IsReady(), Equals, true) 129 | c.Assert(fr.Ratio(), Equals, float64(1)/float64(4)) 130 | } 131 | 132 | func (s *FailRateSuite) TestLongPeriodsOfInactivity(c *C) { 133 | fr, err := NewRatioCounter(2, time.Second, RatioClock(s.tm)) 134 | c.Assert(err, IsNil) 135 | 136 | fr.IncB(1) 137 | 138 | s.tm.CurrentTime
= s.tm.CurrentTime.Add(time.Second) 139 | fr.IncA(1) 140 | 141 | c.Assert(fr.IsReady(), Equals, true) 142 | c.Assert(fr.Ratio(), Equals, 0.5) 143 | 144 | // This time we should overwrite all data points 145 | s.tm.CurrentTime = s.tm.CurrentTime.Add(100 * time.Second) 146 | fr.IncA(1) 147 | c.Assert(fr.Ratio(), Equals, 1.0) 148 | } 149 | 150 | func (s *FailRateSuite) TestReset(c *C) { 151 | fr, err := NewRatioCounter(1, time.Second, RatioClock(s.tm)) 152 | c.Assert(err, IsNil) 153 | 154 | fr.IncB(1) 155 | fr.IncA(1) 156 | 157 | c.Assert(fr.IsReady(), Equals, true) 158 | c.Assert(fr.Ratio(), Equals, 0.5) 159 | 160 | // Reset the counter 161 | fr.Reset() 162 | c.Assert(fr.IsReady(), Equals, false) 163 | 164 | // Now add some stats 165 | fr.IncA(2) 166 | 167 | // We are game again! 168 | c.Assert(fr.IsReady(), Equals, true) 169 | c.Assert(fr.Ratio(), Equals, 1.0) 170 | } 171 | -------------------------------------------------------------------------------- /memmetrics/roundtrip.go: -------------------------------------------------------------------------------- 1 | package memmetrics 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "sync" 7 | "time" 8 | 9 | "github.com/mailgun/timetools" 10 | ) 11 | 12 | // RTMetrics provides aggregated performance metrics for HTTP request processing 13 | // such as round trip latency, response code counters, network errors and total requests. 14 | // All counters are collected as rolling window counters with defined precision; histograms 15 | // are rolling window histograms with defined precision as well. 16 | // See RTCounter, RTHistogram and RTClock for the available options.
17 | type RTMetrics struct { 18 | total *RollingCounter 19 | netErrors *RollingCounter 20 | statusCodes map[int]*RollingCounter 21 | statusCodesLock sync.RWMutex 22 | histogram *RollingHDRHistogram 23 | // factories for counters and histograms; NewRTMetrics installs defaults for any left nil 24 | newCounter NewCounterFn 25 | newHist NewRollingHistogramFn 26 | clock timetools.TimeProvider 27 | } 28 | // rrOptSetter configures an RTMetrics instance; see RTCounter, RTHistogram and RTClock. 29 | type rrOptSetter func(r *RTMetrics) error 30 | // Factory function types used to customize how metrics are constructed. 31 | type NewRTMetricsFn func() (*RTMetrics, error) 32 | type NewCounterFn func() (*RollingCounter, error) 33 | type NewRollingHistogramFn func() (*RollingHDRHistogram, error) 34 | // RTCounter sets the factory used for the total, network-error and per-status-code counters. 35 | func RTCounter(fn NewCounterFn) rrOptSetter { 36 | return func(r *RTMetrics) error { 37 | r.newCounter = fn 38 | return nil 39 | } 40 | } 41 | // RTHistogram sets the factory used for the latency histogram. 42 | func RTHistogram(fn NewRollingHistogramFn) rrOptSetter { 43 | return func(r *RTMetrics) error { 44 | r.newHist = fn 45 | return nil 46 | } 47 | } 48 | // RTClock sets the time provider; the default counter and histogram factories read it. 49 | func RTClock(clock timetools.TimeProvider) rrOptSetter { 50 | return func(r *RTMetrics) error { 51 | r.clock = clock 52 | return nil 53 | } 54 | } 55 | 56 | // NewRTMetrics returns a new metrics collector configured by the given options.
57 | func NewRTMetrics(settings ...rrOptSetter) (*RTMetrics, error) { 58 | m := &RTMetrics{ 59 | statusCodes: make(map[int]*RollingCounter), 60 | statusCodesLock: sync.RWMutex{}, 61 | } 62 | for _, s := range settings { 63 | if err := s(m); err != nil { 64 | return nil, err 65 | } 66 | } 67 | // Fill in defaults for anything the options did not set. 68 | if m.clock == nil { 69 | m.clock = &timetools.RealTime{} 70 | } 71 | 72 | if m.newCounter == nil { 73 | m.newCounter = func() (*RollingCounter, error) { 74 | return NewCounter(counterBuckets, counterResolution, CounterClock(m.clock)) 75 | } 76 | } 77 | 78 | if m.newHist == nil { 79 | m.newHist = func() (*RollingHDRHistogram, error) { 80 | return NewRollingHDRHistogram(histMin, histMax, histSignificantFigures, histPeriod, histBuckets, RollingClock(m.clock)) 81 | } 82 | } 83 | 84 | h, err := m.newHist() 85 | if err != nil { 86 | return nil, err 87 | } 88 | 89 | netErrors, err := m.newCounter() 90 | if err != nil { 91 | return nil, err 92 | } 93 | 94 | total, err := m.newCounter() 95 | if err != nil { 96 | return nil, err 97 | } 98 | 99 | m.histogram = h 100 | m.netErrors = netErrors 101 | m.total = total 102 | return m, nil 103 | } 104 | // CounterWindowSize returns the time span covered by the total-requests counter. 105 | func (m *RTMetrics) CounterWindowSize() time.Duration { 106 | return m.total.WindowSize() 107 | } 108 | 109 | // NetworkErrorRatio calculates the amount of network errors (502/504 responses, see Record) such as timeouts and dropped connections 110 | // that occurred in the given time window compared to the total requests count.
111 | func (m *RTMetrics) NetworkErrorRatio() float64 { 112 | total := m.total.Count() // read once: Count re-reads the clock, so two calls could disagree and divide by zero 113 | if total == 0 { 114 | return 0 115 | } 116 | return float64(m.netErrors.Count()) / float64(total) 117 | } 118 | // ResponseCodeRatio calculates ratio of count(startA to endA) / count(startB to endB); starts are inclusive, ends exclusive, and 0 is returned when the B count is zero. 119 | func (m *RTMetrics) ResponseCodeRatio(startA, endA, startB, endB int) float64 { 120 | a := int64(0) 121 | b := int64(0) 122 | m.statusCodesLock.RLock() 123 | defer m.statusCodesLock.RUnlock() 124 | for code, v := range m.statusCodes { 125 | if code < endA && code >= startA { 126 | a += v.Count() 127 | } 128 | if code < endB && code >= startB { 129 | b += v.Count() 130 | } 131 | } 132 | if b != 0 { 133 | return float64(a) / float64(b) 134 | } 135 | return 0 136 | } 137 | // Append adds other's total, network-error and per-status-code counts and latency buckets into m; appending m to itself is an error. 138 | func (m *RTMetrics) Append(other *RTMetrics) error { 139 | if m == other { 140 | return errors.New("RTMetrics cannot append to self") 141 | } 142 | 143 | if err := m.total.Append(other.total); err != nil { 144 | return err 145 | } 146 | 147 | if err := m.netErrors.Append(other.netErrors); err != nil { 148 | return err 149 | } 150 | 151 | m.statusCodesLock.Lock() 152 | defer m.statusCodesLock.Unlock() 153 | other.statusCodesLock.RLock() 154 | defer other.statusCodesLock.RUnlock() 155 | for code, c := range other.statusCodes { 156 | o, ok := m.statusCodes[code] 157 | if ok { 158 | if err := o.Append(c); err != nil { 159 | return err 160 | } 161 | } else { 162 | m.statusCodes[code] = c.Clone() 163 | } 164 | } 165 | 166 | return m.histogram.Append(other.histogram) 167 | } 168 | // Record registers one request with its status code and latency; 502 and 504 responses also count as network errors. 169 | func (m *RTMetrics) Record(code int, duration time.Duration) { 170 | m.total.Inc(1) 171 | if code == http.StatusGatewayTimeout || code == http.StatusBadGateway { 172 | m.netErrors.Inc(1) 173 | } 174 | m.recordStatusCode(code) 175 | m.recordLatency(duration) 176 | } 177 | 178 | // TotalCount returns total count of processed requests collected in the window.
179 | func (m *RTMetrics) TotalCount() int64 { 180 | return m.total.Count() 181 | } 182 | 183 | // NetworkErrorCount returns the number of network errors (502 and 504 responses) observed in the window. 184 | func (m *RTMetrics) NetworkErrorCount() int64 { 185 | return m.netErrors.Count() 186 | } 187 | 188 | // StatusCodesCounts returns a map with the non-zero counts of the response codes. 189 | func (m *RTMetrics) StatusCodesCounts() map[int]int64 { 190 | sc := make(map[int]int64) 191 | m.statusCodesLock.RLock() 192 | defer m.statusCodesLock.RUnlock() 193 | for k, v := range m.statusCodes { 194 | if n := v.Count(); n != 0 { // single read so the zero filter and the stored value agree 195 | sc[k] = n 196 | } 197 | } 198 | return sc 199 | } 200 | 201 | // LatencyHistogram computes and returns resulting histogram with latencies observed. 202 | func (m *RTMetrics) LatencyHistogram() (*HDRHistogram, error) { 203 | return m.histogram.Merged() 204 | } 205 | // Reset clears the histogram and counters and drops every per-status-code counter. 206 | func (m *RTMetrics) Reset() { 207 | m.histogram.Reset() 208 | m.total.Reset() 209 | m.netErrors.Reset() 210 | m.statusCodesLock.Lock() 211 | defer m.statusCodesLock.Unlock() 212 | m.statusCodes = make(map[int]*RollingCounter) 213 | } 214 | // recordNetError increments the network error counter. 215 | func (m *RTMetrics) recordNetError() error { 216 | m.netErrors.Inc(1) 217 | return nil 218 | } 219 | // recordLatency records d once in the latency histogram. 220 | func (m *RTMetrics) recordLatency(d time.Duration) error { 221 | return m.histogram.RecordLatencies(d, 1) 222 | } 223 | // recordStatusCode increments the counter for statusCode, creating it with newCounter on first use. NOTE(review): an existing counter is incremented while holding only the read lock and RollingCounter has no locking of its own, so concurrent calls for the same code can race. 224 | func (m *RTMetrics) recordStatusCode(statusCode int) error { 225 | m.statusCodesLock.RLock() 226 | if c, ok := m.statusCodes[statusCode]; ok { 227 | c.Inc(1) 228 | m.statusCodesLock.RUnlock() 229 | return nil 230 | } 231 | m.statusCodesLock.RUnlock() 232 | 233 | m.statusCodesLock.Lock() 234 | defer m.statusCodesLock.Unlock() 235 | 236 | // Check if another goroutine has written our counter already 237 | if c, ok := m.statusCodes[statusCode]; ok { 238 | c.Inc(1) 239 | return nil 240 | } 241 | 242 | c, err := m.newCounter() 243 | if err != nil { 244 | return err 245 | } 246 | c.Inc(1) 247 | m.statusCodes[statusCode] = c 248 | return nil 249 | }
250 | // Defaults used by NewRTMetrics when no custom counter or histogram factory is supplied. 251 | const ( 252 | counterBuckets = 10 253 | counterResolution = time.Second 254 | histMin = 1 255 | histMax = 3600000000 // 1 hour in microseconds 256 | histSignificantFigures = 2 // significant figures (1% precision) 257 | histBuckets = 6 // number of sub-histograms in a rolling histogram 258 | histPeriod = 10 * time.Second // roll time 259 | ) 260 | -------------------------------------------------------------------------------- /memmetrics/roundtrip_test.go: -------------------------------------------------------------------------------- 1 | package memmetrics 2 | 3 | import ( 4 | "runtime" 5 | "sync" 6 | "time" 7 | 8 | "github.com/mailgun/timetools" 9 | . "gopkg.in/check.v1" 10 | ) 11 | 12 | type RRSuite struct { 13 | tm *timetools.FreezedTime 14 | } 15 | 16 | var _ = Suite(&RRSuite{}) 17 | 18 | func (s *RRSuite) SetUpSuite(c *C) { 19 | s.tm = &timetools.FreezedTime{ 20 | CurrentTime: time.Date(2012, 3, 4, 5, 6, 7, 0, time.UTC), 21 | } 22 | } 23 | 24 | func (s *RRSuite) TestDefaults(c *C) { 25 | rr, err := NewRTMetrics(RTClock(s.tm)) 26 | c.Assert(err, IsNil) 27 | c.Assert(rr, NotNil) 28 | 29 | rr.Record(200, time.Second) 30 | rr.Record(502, 2*time.Second) 31 | rr.Record(200, time.Second) 32 | rr.Record(200, time.Second) 33 | 34 | c.Assert(rr.NetworkErrorCount(), Equals, int64(1)) 35 | c.Assert(rr.TotalCount(), Equals, int64(4)) 36 | c.Assert(rr.StatusCodesCounts(), DeepEquals, map[int]int64{502: 1, 200: 3}) 37 | c.Assert(rr.NetworkErrorRatio(), Equals, float64(1)/float64(4)) 38 | c.Assert(rr.ResponseCodeRatio(500, 503, 200, 300), Equals, 1.0/3.0) 39 | 40 | h, err := rr.LatencyHistogram() 41 | c.Assert(err, IsNil) 42 | c.Assert(int(h.LatencyAtQuantile(100)/time.Second), Equals, 2) 43 | 44 | rr.Reset() 45 | c.Assert(rr.NetworkErrorCount(), Equals, int64(0)) 46 | c.Assert(rr.TotalCount(), Equals, int64(0)) 47 | c.Assert(rr.StatusCodesCounts(), DeepEquals, map[int]int64{}) 48 | c.Assert(rr.NetworkErrorRatio(), Equals, float64(0)) 49 |
c.Assert(rr.ResponseCodeRatio(500, 503, 200, 300), Equals, float64(0)) 50 | 51 | h, err = rr.LatencyHistogram() 52 | c.Assert(err, IsNil) 53 | c.Assert(h.LatencyAtQuantile(100), Equals, time.Duration(0)) 54 | 55 | } 56 | 57 | func (s *RRSuite) TestAppend(c *C) { 58 | rr, err := NewRTMetrics(RTClock(s.tm)) 59 | c.Assert(err, IsNil) 60 | c.Assert(rr, NotNil) 61 | 62 | rr.Record(200, time.Second) 63 | rr.Record(502, 2*time.Second) 64 | rr.Record(200, time.Second) 65 | rr.Record(200, time.Second) 66 | 67 | rr2, err := NewRTMetrics(RTClock(s.tm)) 68 | c.Assert(err, IsNil) 69 | c.Assert(rr2, NotNil) 70 | 71 | rr2.Record(200, 3*time.Second) 72 | rr2.Record(501, 3*time.Second) 73 | rr2.Record(200, 3*time.Second) 74 | rr2.Record(200, 3*time.Second) 75 | 76 | c.Assert(rr2.Append(rr), IsNil) 77 | c.Assert(rr2.StatusCodesCounts(), DeepEquals, map[int]int64{501: 1, 502: 1, 200: 6}) 78 | c.Assert(rr2.NetworkErrorCount(), Equals, int64(1)) 79 | 80 | h, err := rr2.LatencyHistogram() 81 | c.Assert(err, IsNil) 82 | c.Assert(int(h.LatencyAtQuantile(100)/time.Second), Equals, 3) 83 | } 84 | 85 | func (s *RRSuite) TestConcurrentRecords(c *C) { 86 | // This test asserts a race condition which requires parallelism. NOTE(review): the goroutines below never call l.RUnlock, capture the loop variable code (shared across iterations before Go 1.22), and are not waited for before the test returns. 87 | runtime.GOMAXPROCS(100) 88 | 89 | rr, _ := NewRTMetrics(RTClock(s.tm)) 90 | 91 | for code := 0; code < 100; code++ { 92 | l := sync.RWMutex{} 93 | l.Lock() 94 | for numRecords := 0; numRecords < 10; numRecords++ { 95 | go func() { 96 | l.RLock() 97 | rr.recordStatusCode(code) 98 | }() 99 | } 100 | l.Unlock() 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /ratelimit/bucket.go: -------------------------------------------------------------------------------- 1 | package ratelimit 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/mailgun/timetools" 8 | ) 9 | // UndefinedDelay is a sentinel delay value of -1; presumably returned when no delay can be computed — TODO confirm against its users. 10 | const UndefinedDelay = -1 11 | 12 | // rate defines token bucket parameters.
13 | type rate struct { 14 | period time.Duration 15 | average int64 16 | burst int64 17 | } 18 | 19 | func (r *rate) String() string { 20 | return fmt.Sprintf("rate(%v/%v, burst=%v)", r.average, r.period, r.burst) 21 | } 22 | 23 | // Implements token bucket algorithm (http://en.wikipedia.org/wiki/Token_bucket) 24 | type tokenBucket struct { 25 | // The time period controlled by the bucket in nanoseconds. 26 | period time.Duration 27 | // The number of nanoseconds that takes to add one more token to the total 28 | // number of available tokens. It effectively caches the value that could 29 | // have been otherwise deduced from refillRate. 30 | timePerToken time.Duration 31 | // The maximum number of tokens that can be accumulate in the bucket. 32 | burst int64 33 | // The number of tokens available for consumption at the moment. It can 34 | // nether be larger then capacity. 35 | availableTokens int64 36 | // Interface that gives current time (so tests can override) 37 | clock timetools.TimeProvider 38 | // Tells when tokensAvailable was updated the last time. 39 | lastRefresh time.Time 40 | // The number of tokens consumed the last time. 41 | lastConsumed int64 42 | } 43 | 44 | // newTokenBucket crates a `tokenBucket` instance for the specified `Rate`. 45 | func newTokenBucket(rate *rate, clock timetools.TimeProvider) *tokenBucket { 46 | return &tokenBucket{ 47 | period: rate.period, 48 | timePerToken: time.Duration(int64(rate.period) / rate.average), 49 | burst: rate.burst, 50 | clock: clock, 51 | lastRefresh: clock.UtcNow(), 52 | availableTokens: rate.burst, 53 | } 54 | } 55 | 56 | // consume makes an attempt to consume the specified number of tokens from the 57 | // bucket. 
If there are enough tokens available then `0, nil` is returned; if 58 | // tokens to consume is larger than the burst size, then an error is returned 59 | // and the delay is not defined; otherwise returned a none zero delay that tells 60 | // how much time the caller needs to wait until the desired number of tokens 61 | // will become available for consumption. 62 | func (tb *tokenBucket) consume(tokens int64) (time.Duration, error) { 63 | tb.updateAvailableTokens() 64 | tb.lastConsumed = 0 65 | if tokens > tb.burst { 66 | return UndefinedDelay, fmt.Errorf("Requested tokens larger than max tokens") 67 | } 68 | if tb.availableTokens < tokens { 69 | return tb.timeTillAvailable(tokens), nil 70 | } 71 | tb.availableTokens -= tokens 72 | tb.lastConsumed = tokens 73 | return 0, nil 74 | } 75 | 76 | // rollback reverts effect of the most recent consumption. If the most recent 77 | // `consume` resulted in an error or a burst overflow, and therefore did not 78 | // modify the number of available tokens, then `rollback` won't do that either. 79 | // It is safe to call this method multiple times, for the second and all 80 | // following calls have no effect. 81 | func (tb *tokenBucket) rollback() { 82 | tb.availableTokens += tb.lastConsumed 83 | tb.lastConsumed = 0 84 | } 85 | 86 | // Update modifies `average` and `burst` fields of the token bucket according 87 | // to the provided `Rate` 88 | func (tb *tokenBucket) update(rate *rate) error { 89 | if rate.period != tb.period { 90 | return fmt.Errorf("Period mismatch: %v != %v", tb.period, rate.period) 91 | } 92 | tb.timePerToken = time.Duration(int64(tb.period) / rate.average) 93 | tb.burst = rate.burst 94 | if tb.availableTokens > rate.burst { 95 | tb.availableTokens = rate.burst 96 | } 97 | return nil 98 | } 99 | 100 | // timeTillAvailable returns the number of nanoseconds that we need to 101 | // wait until the specified number of tokens becomes available for consumption. 
102 | func (tb *tokenBucket) timeTillAvailable(tokens int64) time.Duration { 103 | missingTokens := tokens - tb.availableTokens 104 | return time.Duration(missingTokens) * tb.timePerToken 105 | } 106 | 107 | // updateAvailableTokens updates the number of tokens available for consumption. 108 | // It is calculated based on the refill rate, the time passed since last refresh, 109 | // and is limited by the bucket capacity. 110 | func (tb *tokenBucket) updateAvailableTokens() { 111 | now := tb.clock.UtcNow() 112 | timePassed := now.Sub(tb.lastRefresh) 113 | 114 | tokens := tb.availableTokens + int64(timePassed/tb.timePerToken) 115 | // If we haven't added any tokens that means that not enough time has passed, 116 | // in this case do not adjust last refill checkpoint, otherwise it will be 117 | // always moving in time in case of frequent requests that exceed the rate 118 | if tokens != tb.availableTokens { 119 | tb.lastRefresh = now 120 | tb.availableTokens = tokens 121 | } 122 | if tb.availableTokens > tb.burst { 123 | tb.availableTokens = tb.burst 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /ratelimit/bucket_test.go: -------------------------------------------------------------------------------- 1 | package ratelimit 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/mailgun/timetools" 8 | . 
"gopkg.in/check.v1" 9 | ) 10 | 11 | func TestTokenBucket(t *testing.T) { TestingT(t) } 12 | 13 | type BucketSuite struct { 14 | clock *timetools.FreezedTime 15 | } 16 | 17 | var _ = Suite(&BucketSuite{}) 18 | 19 | func (s *BucketSuite) SetUpSuite(c *C) { 20 | s.clock = &timetools.FreezedTime{ 21 | CurrentTime: time.Date(2012, 3, 4, 5, 6, 7, 0, time.UTC), 22 | } 23 | } 24 | 25 | func (s *BucketSuite) TestConsumeSingleToken(c *C) { 26 | tb := newTokenBucket(&rate{time.Second, 1, 1}, s.clock) 27 | 28 | // First request passes 29 | delay, err := tb.consume(1) 30 | c.Assert(err, IsNil) 31 | c.Assert(delay, Equals, time.Duration(0)) 32 | 33 | // Next request does not pass the same second 34 | delay, err = tb.consume(1) 35 | c.Assert(err, IsNil) 36 | c.Assert(delay, Equals, time.Second) 37 | 38 | // Second later, the request passes 39 | s.clock.Sleep(time.Second) 40 | delay, err = tb.consume(1) 41 | c.Assert(err, IsNil) 42 | c.Assert(delay, Equals, time.Duration(0)) 43 | 44 | // Five seconds later, still only one request is allowed 45 | // because maxBurst is 1 46 | s.clock.Sleep(5 * time.Second) 47 | delay, err = tb.consume(1) 48 | c.Assert(err, IsNil) 49 | c.Assert(delay, Equals, time.Duration(0)) 50 | 51 | // The next one is forbidden 52 | delay, err = tb.consume(1) 53 | c.Assert(err, IsNil) 54 | c.Assert(delay, Equals, time.Second) 55 | } 56 | 57 | func (s *BucketSuite) TestFastConsumption(c *C) { 58 | tb := newTokenBucket(&rate{time.Second, 1, 1}, s.clock) 59 | 60 | // First request passes 61 | delay, err := tb.consume(1) 62 | c.Assert(err, IsNil) 63 | c.Assert(delay, Equals, time.Duration(0)) 64 | 65 | // Try 200 ms later 66 | s.clock.Sleep(time.Millisecond * 200) 67 | delay, err = tb.consume(1) 68 | c.Assert(err, IsNil) 69 | c.Assert(delay, Equals, time.Second) 70 | 71 | // Try 700 ms later 72 | s.clock.Sleep(time.Millisecond * 700) 73 | delay, err = tb.consume(1) 74 | c.Assert(err, IsNil) 75 | c.Assert(delay, Equals, time.Second) 76 | 77 | // Try 100 ms later, 
success! 78 | s.clock.Sleep(time.Millisecond * 100) 79 | delay, err = tb.consume(1) 80 | c.Assert(err, IsNil) 81 | c.Assert(delay, Equals, time.Duration(0)) 82 | } 83 | 84 | func (s *BucketSuite) TestConsumeMultipleTokens(c *C) { 85 | tb := newTokenBucket(&rate{time.Second, 3, 5}, s.clock) 86 | 87 | delay, err := tb.consume(3) 88 | c.Assert(err, IsNil) 89 | c.Assert(delay, Equals, time.Duration(0)) 90 | 91 | delay, err = tb.consume(2) 92 | c.Assert(err, IsNil) 93 | c.Assert(delay, Equals, time.Duration(0)) 94 | 95 | delay, err = tb.consume(1) 96 | c.Assert(err, IsNil) 97 | c.Assert(delay, Not(Equals), time.Duration(0)) 98 | } 99 | 100 | func (s *BucketSuite) TestDelayIsCorrect(c *C) { 101 | tb := newTokenBucket(&rate{time.Second, 3, 5}, s.clock) 102 | 103 | // Exhaust initial capacity 104 | delay, err := tb.consume(5) 105 | c.Assert(err, IsNil) 106 | c.Assert(delay, Equals, time.Duration(0)) 107 | 108 | delay, err = tb.consume(3) 109 | c.Assert(err, IsNil) 110 | c.Assert(delay, Not(Equals), time.Duration(0)) 111 | 112 | // Now wait provided delay and make sure we can consume now 113 | s.clock.Sleep(delay) 114 | delay, err = tb.consume(3) 115 | c.Assert(err, IsNil) 116 | c.Assert(delay, Equals, time.Duration(0)) 117 | } 118 | 119 | // Make sure requests that exceed burst size are not allowed 120 | func (s *BucketSuite) TestExceedsBurst(c *C) { 121 | tb := newTokenBucket(&rate{time.Second, 1, 10}, s.clock) 122 | 123 | _, err := tb.consume(11) 124 | c.Assert(err, NotNil) 125 | } 126 | 127 | func (s *BucketSuite) TestConsumeBurst(c *C) { 128 | tb := newTokenBucket(&rate{time.Second, 2, 5}, s.clock) 129 | 130 | // In two seconds we would have 5 tokens 131 | s.clock.Sleep(2 * time.Second) 132 | 133 | // Lets consume 5 at once 134 | delay, err := tb.consume(5) 135 | c.Assert(delay, Equals, time.Duration(0)) 136 | c.Assert(err, IsNil) 137 | } 138 | 139 | func (s *BucketSuite) TestConsumeEstimate(c *C) { 140 | tb := newTokenBucket(&rate{time.Second, 2, 4}, s.clock) 141 | 
142 | // Consume all burst at once 143 | delay, err := tb.consume(4) 144 | c.Assert(err, IsNil) 145 | c.Assert(delay, Equals, time.Duration(0)) 146 | 147 | // Now try to consume it and face delay 148 | delay, err = tb.consume(4) 149 | c.Assert(err, IsNil) 150 | c.Assert(delay, Equals, time.Duration(2)*time.Second) 151 | } 152 | 153 | // If a rate with different period is passed to the `update` method, then an 154 | // error is returned but the state of the bucket remains valid and unchanged. 155 | func (s *BucketSuite) TestUpdateInvalidPeriod(c *C) { 156 | // Given 157 | tb := newTokenBucket(&rate{time.Second, 10, 20}, s.clock) 158 | tb.consume(15) // 5 tokens available 159 | // When 160 | err := tb.update(&rate{time.Second + 1, 30, 40}) // still 5 tokens available 161 | // Then 162 | c.Assert(err, NotNil) 163 | 164 | // ...check that rate did not change 165 | s.clock.Sleep(500 * time.Millisecond) 166 | delay, err := tb.consume(11) 167 | c.Assert(err, IsNil) 168 | c.Assert(delay, Equals, 100*time.Millisecond) 169 | delay, err = tb.consume(10) 170 | c.Assert(err, IsNil) 171 | c.Assert(delay, Equals, time.Duration(0)) // 0 available 172 | 173 | // ...check that burst did not change 174 | s.clock.Sleep(40 * time.Second) 175 | delay, err = tb.consume(21) 176 | c.Assert(err, NotNil) 177 | delay, err = tb.consume(20) 178 | c.Assert(err, IsNil) 179 | c.Assert(delay, Equals, time.Duration(0)) // 0 available 180 | } 181 | 182 | // If the capacity of the bucket is increased by the update then it takes some 183 | // time to fill the bucket with tokens up to the new capacity. 
184 | func (s *BucketSuite) TestUpdateBurstIncreased(c *C) { 185 | // Given 186 | tb := newTokenBucket(&rate{time.Second, 10, 20}, s.clock) 187 | tb.consume(15) // 5 tokens available 188 | // When 189 | err := tb.update(&rate{time.Second, 10, 50}) // still 5 tokens available 190 | // Then 191 | c.Assert(err, IsNil) 192 | delay, err := tb.consume(50) 193 | c.Assert(err, IsNil) 194 | c.Assert(delay, Equals, time.Duration(time.Second/10*45)) 195 | } 196 | 197 | // If the capacity of the bucket is increased by the update then it takes some 198 | // time to fill the bucket with tokens up to the new capacity. 199 | func (s *BucketSuite) TestUpdateBurstDecreased(c *C) { 200 | // Given 201 | tb := newTokenBucket(&rate{time.Second, 10, 50}, s.clock) 202 | tb.consume(15) // 35 tokens available 203 | // When 204 | err := tb.update(&rate{time.Second, 10, 20}) // the number of available tokens reduced to 20. 205 | // Then 206 | c.Assert(err, IsNil) 207 | delay, err := tb.consume(21) 208 | c.Assert(err, NotNil) 209 | c.Assert(delay, Equals, time.Duration(-1)) 210 | } 211 | 212 | // If rate is updated then it affects the bucket refill speed. 213 | func (s *BucketSuite) TestUpdateRateChanged(c *C) { 214 | // Given 215 | tb := newTokenBucket(&rate{time.Second, 10, 20}, s.clock) 216 | tb.consume(15) // 5 tokens available 217 | // When 218 | err := tb.update(&rate{time.Second, 20, 20}) // still 5 tokens available 219 | // Then 220 | delay, err := tb.consume(20) 221 | c.Assert(err, IsNil) 222 | c.Assert(delay, Equals, time.Duration(time.Second/20*15)) 223 | } 224 | 225 | // Only the most recent consumption is reverted by `Rollback`. 
226 | func (s *BucketSuite) TestRollback(c *C) { 227 | // Given 228 | tb := newTokenBucket(&rate{time.Second, 10, 20}, s.clock) 229 | tb.consume(8) // 12 tokens available 230 | tb.consume(7) // 5 tokens available 231 | // When 232 | tb.rollback() // 12 tokens available 233 | // Then 234 | delay, err := tb.consume(12) 235 | c.Assert(err, IsNil) 236 | c.Assert(delay, Equals, time.Duration(0)) 237 | delay, err = tb.consume(1) 238 | c.Assert(err, IsNil) 239 | c.Assert(delay, Equals, 100*time.Millisecond) 240 | } 241 | 242 | // It is safe to call `Rollback` several times. The second and all subsequent 243 | // calls just do nothing. 244 | func (s *BucketSuite) TestRollbackSeveralTimes(c *C) { 245 | // Given 246 | tb := newTokenBucket(&rate{time.Second, 10, 20}, s.clock) 247 | tb.consume(8) // 12 tokens available 248 | tb.rollback() // 20 tokens available 249 | // When 250 | tb.rollback() // still 20 tokens available 251 | tb.rollback() // still 20 tokens available 252 | tb.rollback() // still 20 tokens available 253 | // Then: all 20 tokens can be consumed 254 | delay, err := tb.consume(20) 255 | c.Assert(err, IsNil) 256 | c.Assert(delay, Equals, time.Duration(0)) 257 | delay, err = tb.consume(1) 258 | c.Assert(err, IsNil) 259 | c.Assert(delay, Equals, 100*time.Millisecond) 260 | } 261 | 262 | // If previous consumption returned a delay due to an attempt to consume more 263 | // tokens then there are available, then `Rollback` has no effect. 264 | func (s *BucketSuite) TestRollbackAfterAvailableExceeded(c *C) { 265 | // Given 266 | tb := newTokenBucket(&rate{time.Second, 10, 20}, s.clock) 267 | tb.consume(8) // 12 tokens available 268 | delay, err := tb.consume(15) // still 12 tokens available 269 | c.Assert(err, IsNil) 270 | c.Assert(delay, Equals, 300*time.Millisecond) 271 | // When 272 | tb.rollback() // Previous operation consumed 0 tokens, so rollback has no effect. 
273 | // Then 274 | delay, err = tb.consume(12) 275 | c.Assert(err, IsNil) 276 | c.Assert(delay, Equals, time.Duration(0)) 277 | delay, err = tb.consume(1) 278 | c.Assert(err, IsNil) 279 | c.Assert(delay, Equals, 100*time.Millisecond) 280 | } 281 | 282 | // If previous consumption returned a error due to an attempt to consume more 283 | // tokens then the bucket's burst size, then `Rollback` has no effect. 284 | func (s *BucketSuite) TestRollbackAfterError(c *C) { 285 | // Given 286 | tb := newTokenBucket(&rate{time.Second, 10, 20}, s.clock) 287 | tb.consume(8) // 12 tokens available 288 | delay, err := tb.consume(21) // still 12 tokens available 289 | c.Assert(err, NotNil) 290 | c.Assert(delay, Equals, time.Duration(-1)) 291 | // When 292 | tb.rollback() // Previous operation consumed 0 tokens, so rollback has no effect. 293 | // Then 294 | delay, err = tb.consume(12) 295 | c.Assert(err, IsNil) 296 | c.Assert(delay, Equals, time.Duration(0)) 297 | delay, err = tb.consume(1) 298 | c.Assert(err, IsNil) 299 | c.Assert(delay, Equals, 100*time.Millisecond) 300 | } 301 | -------------------------------------------------------------------------------- /ratelimit/bucketset.go: -------------------------------------------------------------------------------- 1 | package ratelimit 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "time" 7 | 8 | "github.com/mailgun/timetools" 9 | "sort" 10 | ) 11 | 12 | // TokenBucketSet represents a set of TokenBucket covering different time periods. 13 | type TokenBucketSet struct { 14 | buckets map[time.Duration]*tokenBucket 15 | maxPeriod time.Duration 16 | clock timetools.TimeProvider 17 | } 18 | 19 | // newTokenBucketSet creates a `TokenBucketSet` from the specified `rates`. 20 | func NewTokenBucketSet(rates *RateSet, clock timetools.TimeProvider) *TokenBucketSet { 21 | tbs := new(TokenBucketSet) 22 | tbs.clock = clock 23 | // In the majority of cases we will have only one bucket. 
24 | tbs.buckets = make(map[time.Duration]*tokenBucket, len(rates.m)) 25 | for _, rate := range rates.m { 26 | newBucket := newTokenBucket(rate, clock) 27 | tbs.buckets[rate.period] = newBucket 28 | tbs.maxPeriod = maxDuration(tbs.maxPeriod, rate.period) 29 | } 30 | return tbs 31 | } 32 | 33 | // Update brings the buckets in the set in accordance with the provided `rates`. 34 | func (tbs *TokenBucketSet) Update(rates *RateSet) { 35 | // Update existing buckets and delete those that have no corresponding spec. 36 | for _, bucket := range tbs.buckets { 37 | if rate, ok := rates.m[bucket.period]; ok { 38 | bucket.update(rate) 39 | } else { 40 | delete(tbs.buckets, bucket.period) 41 | } 42 | } 43 | // Add missing buckets. 44 | for _, rate := range rates.m { 45 | if _, ok := tbs.buckets[rate.period]; !ok { 46 | newBucket := newTokenBucket(rate, tbs.clock) 47 | tbs.buckets[rate.period] = newBucket 48 | } 49 | } 50 | // Identify the maximum period in the set 51 | tbs.maxPeriod = 0 52 | for _, bucket := range tbs.buckets { 53 | tbs.maxPeriod = maxDuration(tbs.maxPeriod, bucket.period) 54 | } 55 | } 56 | 57 | func (tbs *TokenBucketSet) Consume(tokens int64) (time.Duration, error) { 58 | var maxDelay time.Duration = UndefinedDelay 59 | var firstErr error = nil 60 | for _, tokenBucket := range tbs.buckets { 61 | // We keep calling `Consume` even after a error is returned for one of 62 | // buckets because that allows us to simplify the rollback procedure, 63 | // that is to just call `Rollback` for all buckets. 64 | delay, err := tokenBucket.consume(tokens) 65 | if firstErr == nil { 66 | if err != nil { 67 | firstErr = err 68 | } else { 69 | maxDelay = maxDuration(maxDelay, delay) 70 | } 71 | } 72 | } 73 | // If we could not make ALL buckets consume tokens for whatever reason, 74 | // then rollback consumption for all of them. 
75 | if firstErr != nil || maxDelay > 0 { 76 | for _, tokenBucket := range tbs.buckets { 77 | tokenBucket.rollback() 78 | } 79 | } 80 | return maxDelay, firstErr 81 | } 82 | 83 | func (tbs *TokenBucketSet) GetMaxPeriod() time.Duration { 84 | return tbs.maxPeriod 85 | } 86 | 87 | // debugState returns string that reflects the current state of all buckets in 88 | // this set. It is intended to be used for debugging and testing only. 89 | func (tbs *TokenBucketSet) debugState() string { 90 | periods := sort.IntSlice(make([]int, 0, len(tbs.buckets))) 91 | for period := range tbs.buckets { 92 | periods = append(periods, int(period)) 93 | } 94 | sort.Sort(periods) 95 | bucketRepr := make([]string, 0, len(tbs.buckets)) 96 | for _, period := range periods { 97 | bucket := tbs.buckets[time.Duration(period)] 98 | bucketRepr = append(bucketRepr, fmt.Sprintf("{%v: %v}", bucket.period, bucket.availableTokens)) 99 | } 100 | return strings.Join(bucketRepr, ", ") 101 | } 102 | 103 | func maxDuration(x time.Duration, y time.Duration) time.Duration { 104 | if x > y { 105 | return x 106 | } 107 | return y 108 | } 109 | -------------------------------------------------------------------------------- /ratelimit/bucketset_test.go: -------------------------------------------------------------------------------- 1 | package ratelimit 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/mailgun/timetools" 7 | . "gopkg.in/check.v1" 8 | ) 9 | 10 | type BucketSetSuite struct { 11 | clock *timetools.FreezedTime 12 | } 13 | 14 | var _ = Suite(&BucketSetSuite{}) 15 | 16 | func (s *BucketSetSuite) SetUpSuite(c *C) { 17 | s.clock = &timetools.FreezedTime{ 18 | CurrentTime: time.Date(2012, 3, 4, 5, 6, 7, 0, time.UTC), 19 | } 20 | } 21 | 22 | // A value returned by `MaxPeriod` corresponds to the longest bucket time period. 
23 | func (s *BucketSetSuite) TestLongestPeriod(c *C) { 24 | // Given 25 | rates := NewRateSet() 26 | rates.Add(1*time.Second, 10, 20) 27 | rates.Add(7*time.Second, 10, 20) 28 | rates.Add(5*time.Second, 11, 21) 29 | // When 30 | tbs := NewTokenBucketSet(rates, s.clock) 31 | // Then 32 | c.Assert(tbs.maxPeriod, Equals, 7*time.Second) 33 | } 34 | 35 | // Successful token consumption updates state of all buckets in the set. 36 | func (s *BucketSetSuite) TestConsume(c *C) { 37 | // Given 38 | rates := NewRateSet() 39 | rates.Add(1*time.Second, 10, 20) 40 | rates.Add(10*time.Second, 20, 50) 41 | tbs := NewTokenBucketSet(rates, s.clock) 42 | // When 43 | delay, err := tbs.Consume(15) 44 | // Then 45 | c.Assert(delay, Equals, time.Duration(0)) 46 | c.Assert(err, IsNil) 47 | c.Assert(tbs.debugState(), Equals, "{1s: 5}, {10s: 35}") 48 | } 49 | 50 | // As time goes by all set buckets are refilled with appropriate rates. 51 | func (s *BucketSetSuite) TestConsumeRefill(c *C) { 52 | // Given 53 | rates := NewRateSet() 54 | rates.Add(10*time.Second, 10, 20) 55 | rates.Add(100*time.Second, 20, 50) 56 | tbs := NewTokenBucketSet(rates, s.clock) 57 | tbs.Consume(15) 58 | c.Assert(tbs.debugState(), Equals, "{10s: 5}, {1m40s: 35}") 59 | // When 60 | s.clock.Sleep(10 * time.Second) 61 | delay, err := tbs.Consume(0) // Consumes nothing but forces an internal state update. 62 | // Then 63 | c.Assert(delay, Equals, time.Duration(0)) 64 | c.Assert(err, IsNil) 65 | c.Assert(tbs.debugState(), Equals, "{10s: 15}, {1m40s: 37}") 66 | } 67 | 68 | // If the first bucket in the set has no enough tokens to allow desired 69 | // consumption then an appropriate delay is returned. 
70 | func (s *BucketSetSuite) TestConsumeLimitedBy1st(c *C) { 71 | // Given 72 | rates := NewRateSet() 73 | rates.Add(10*time.Second, 10, 10) 74 | rates.Add(100*time.Second, 20, 20) 75 | tbs := NewTokenBucketSet(rates, s.clock) 76 | tbs.Consume(5) 77 | c.Assert(tbs.debugState(), Equals, "{10s: 5}, {1m40s: 15}") 78 | // When 79 | delay, err := tbs.Consume(10) 80 | // Then 81 | c.Assert(delay, Equals, 5*time.Second) 82 | c.Assert(err, IsNil) 83 | c.Assert(tbs.debugState(), Equals, "{10s: 5}, {1m40s: 15}") 84 | } 85 | 86 | // If the second bucket in the set has no enough tokens to allow desired 87 | // consumption then an appropriate delay is returned. 88 | func (s *BucketSetSuite) TestConsumeLimitedBy2st(c *C) { 89 | // Given 90 | rates := NewRateSet() 91 | rates.Add(10*time.Second, 10, 10) 92 | rates.Add(100*time.Second, 20, 20) 93 | tbs := NewTokenBucketSet(rates, s.clock) 94 | tbs.Consume(10) 95 | s.clock.Sleep(10 * time.Second) 96 | tbs.Consume(10) 97 | s.clock.Sleep(5 * time.Second) 98 | tbs.Consume(0) 99 | c.Assert(tbs.debugState(), Equals, "{10s: 5}, {1m40s: 3}") 100 | // When 101 | delay, err := tbs.Consume(10) 102 | // Then 103 | c.Assert(delay, Equals, 7*(5*time.Second)) 104 | c.Assert(err, IsNil) 105 | c.Assert(tbs.debugState(), Equals, "{10s: 5}, {1m40s: 3}") 106 | } 107 | 108 | // An attempt to consume more tokens then the smallest bucket capacity results 109 | // in error. 110 | func (s *BucketSetSuite) TestConsumeMoreThenBurst(c *C) { 111 | // Given 112 | rates := NewRateSet() 113 | rates.Add(1*time.Second, 10, 20) 114 | rates.Add(10*time.Second, 50, 100) 115 | tbs := NewTokenBucketSet(rates, s.clock) 116 | tbs.Consume(5) 117 | c.Assert(tbs.debugState(), Equals, "{1s: 15}, {10s: 95}") 118 | // When 119 | _, err := tbs.Consume(21) 120 | //Then 121 | c.Assert(tbs.debugState(), Equals, "{1s: 15}, {10s: 95}") 122 | c.Assert(err, NotNil) 123 | } 124 | 125 | // Update operation can add buckets. 
126 | func (s *BucketSetSuite) TestUpdateMore(c *C) { 127 | // Given 128 | rates := NewRateSet() 129 | rates.Add(1*time.Second, 10, 20) 130 | rates.Add(10*time.Second, 20, 50) 131 | rates.Add(20*time.Second, 45, 90) 132 | tbs := NewTokenBucketSet(rates, s.clock) 133 | tbs.Consume(5) 134 | c.Assert(tbs.debugState(), Equals, "{1s: 15}, {10s: 45}, {20s: 85}") 135 | rates = NewRateSet() 136 | rates.Add(10*time.Second, 30, 40) 137 | rates.Add(11*time.Second, 30, 40) 138 | rates.Add(12*time.Second, 30, 40) 139 | rates.Add(13*time.Second, 30, 40) 140 | // When 141 | tbs.Update(rates) 142 | // Then 143 | c.Assert(tbs.debugState(), Equals, "{10s: 40}, {11s: 40}, {12s: 40}, {13s: 40}") 144 | c.Assert(tbs.maxPeriod, Equals, 13*time.Second) 145 | } 146 | 147 | // Update operation can remove buckets. 148 | func (s *BucketSetSuite) TestUpdateLess(c *C) { 149 | // Given 150 | rates := NewRateSet() 151 | rates.Add(1*time.Second, 10, 20) 152 | rates.Add(10*time.Second, 20, 50) 153 | rates.Add(20*time.Second, 45, 90) 154 | rates.Add(30*time.Second, 50, 100) 155 | tbs := NewTokenBucketSet(rates, s.clock) 156 | tbs.Consume(5) 157 | c.Assert(tbs.debugState(), Equals, "{1s: 15}, {10s: 45}, {20s: 85}, {30s: 95}") 158 | rates = NewRateSet() 159 | rates.Add(10*time.Second, 25, 20) 160 | rates.Add(20*time.Second, 30, 21) 161 | // When 162 | tbs.Update(rates) 163 | // Then 164 | c.Assert(tbs.debugState(), Equals, "{10s: 20}, {20s: 21}") 165 | c.Assert(tbs.maxPeriod, Equals, 20*time.Second) 166 | } 167 | 168 | // Update operation can remove buckets. 
169 | func (s *BucketSetSuite) TestUpdateAllDifferent(c *C) { 170 | // Given 171 | rates := NewRateSet() 172 | rates.Add(10*time.Second, 20, 50) 173 | rates.Add(30*time.Second, 50, 100) 174 | tbs := NewTokenBucketSet(rates, s.clock) 175 | tbs.Consume(5) 176 | c.Assert(tbs.debugState(), Equals, "{10s: 45}, {30s: 95}") 177 | rates = NewRateSet() 178 | rates.Add(1*time.Second, 10, 40) 179 | rates.Add(60*time.Second, 100, 150) 180 | // When 181 | tbs.Update(rates) 182 | // Then 183 | c.Assert(tbs.debugState(), Equals, "{1s: 40}, {1m0s: 150}") 184 | c.Assert(tbs.maxPeriod, Equals, 60*time.Second) 185 | } 186 | -------------------------------------------------------------------------------- /ratelimit/tokenlimiter.go: -------------------------------------------------------------------------------- 1 | // Tokenbucket based request rate limiter 2 | package ratelimit 3 | 4 | import ( 5 | "fmt" 6 | "net/http" 7 | "sync" 8 | "time" 9 | 10 | "github.com/mailgun/timetools" 11 | "github.com/mailgun/ttlmap" 12 | "github.com/vulcand/oxy/utils" 13 | ) 14 | 15 | const DefaultCapacity = 65536 16 | 17 | // RateSet maintains a set of rates. It can contain only one rate per period at a time. 18 | type RateSet struct { 19 | m map[time.Duration]*rate 20 | } 21 | 22 | // NewRateSet crates an empty `RateSet` instance. 23 | func NewRateSet() *RateSet { 24 | rs := new(RateSet) 25 | rs.m = make(map[time.Duration]*rate) 26 | return rs 27 | } 28 | 29 | // Add adds a rate to the set. If there is a rate with the same period in the 30 | // set then the new rate overrides the old one. 
31 | func (rs *RateSet) Add(period time.Duration, average int64, burst int64) error { 32 | if period <= 0 { 33 | return fmt.Errorf("Invalid period: %v", period) 34 | } 35 | if average <= 0 { 36 | return fmt.Errorf("Invalid average: %v", average) 37 | } 38 | if burst <= 0 { 39 | return fmt.Errorf("Invalid burst: %v", burst) 40 | } 41 | rs.m[period] = &rate{period, average, burst} 42 | return nil 43 | } 44 | 45 | func (rs *RateSet) String() string { 46 | return fmt.Sprint(rs.m) 47 | } 48 | 49 | type RateExtractor interface { 50 | Extract(r *http.Request) (*RateSet, error) 51 | } 52 | 53 | type RateExtractorFunc func(r *http.Request) (*RateSet, error) 54 | 55 | func (e RateExtractorFunc) Extract(r *http.Request) (*RateSet, error) { 56 | return e(r) 57 | } 58 | 59 | // TokenLimiter implements rate limiting middleware. 60 | type TokenLimiter struct { 61 | defaultRates *RateSet 62 | extract utils.SourceExtractor 63 | extractRates RateExtractor 64 | clock timetools.TimeProvider 65 | mutex sync.Mutex 66 | bucketSets *ttlmap.TtlMap 67 | errHandler utils.ErrorHandler 68 | log utils.Logger 69 | capacity int 70 | next http.Handler 71 | } 72 | 73 | // New constructs a `TokenLimiter` middleware instance. 
74 | func New(next http.Handler, extract utils.SourceExtractor, defaultRates *RateSet, opts ...TokenLimiterOption) (*TokenLimiter, error) { 75 | if defaultRates == nil || len(defaultRates.m) == 0 { 76 | return nil, fmt.Errorf("Provide default rates") 77 | } 78 | if extract == nil { 79 | return nil, fmt.Errorf("Provide extract function") 80 | } 81 | tl := &TokenLimiter{ 82 | next: next, 83 | defaultRates: defaultRates, 84 | extract: extract, 85 | } 86 | 87 | for _, o := range opts { 88 | if err := o(tl); err != nil { 89 | return nil, err 90 | } 91 | } 92 | setDefaults(tl) 93 | bucketSets, err := ttlmap.NewMapWithProvider(tl.capacity, tl.clock) 94 | if err != nil { 95 | return nil, err 96 | } 97 | tl.bucketSets = bucketSets 98 | return tl, nil 99 | } 100 | 101 | func (tl *TokenLimiter) Wrap(next http.Handler) { 102 | tl.next = next 103 | } 104 | 105 | func (tl *TokenLimiter) ServeHTTP(w http.ResponseWriter, req *http.Request) { 106 | source, amount, err := tl.extract.Extract(req) 107 | if err != nil { 108 | tl.errHandler.ServeHTTP(w, req, err) 109 | return 110 | } 111 | 112 | if err := tl.consumeRates(req, source, amount); err != nil { 113 | tl.log.Infof("limiting request %v %v, limit: %v", req.Method, req.URL, err) 114 | tl.errHandler.ServeHTTP(w, req, err) 115 | return 116 | } 117 | 118 | tl.next.ServeHTTP(w, req) 119 | } 120 | 121 | func (tl *TokenLimiter) consumeRates(req *http.Request, source string, amount int64) error { 122 | tl.mutex.Lock() 123 | defer tl.mutex.Unlock() 124 | 125 | effectiveRates := tl.resolveRates(req) 126 | bucketSetI, exists := tl.bucketSets.Get(source) 127 | var bucketSet *TokenBucketSet 128 | 129 | if exists { 130 | bucketSet = bucketSetI.(*TokenBucketSet) 131 | bucketSet.Update(effectiveRates) 132 | } else { 133 | bucketSet = NewTokenBucketSet(effectiveRates, tl.clock) 134 | // We set ttl as 10 times rate period. E.g. 
if rate is 100 requests/second per client ip 135 | // the counters for this ip will expire after 10 seconds of inactivity 136 | tl.bucketSets.Set(source, bucketSet, int(bucketSet.maxPeriod/time.Second)*10+1) 137 | } 138 | delay, err := bucketSet.Consume(amount) 139 | if err != nil { 140 | return err 141 | } 142 | if delay > 0 { 143 | return &MaxRateError{delay: delay} 144 | } 145 | return nil 146 | } 147 | 148 | // effectiveRates retrieves rates to be applied to the request. 149 | func (tl *TokenLimiter) resolveRates(req *http.Request) *RateSet { 150 | // If configuration mapper is not specified for this instance, then return 151 | // the default bucket specs. 152 | if tl.extractRates == nil { 153 | return tl.defaultRates 154 | } 155 | 156 | rates, err := tl.extractRates.Extract(req) 157 | if err != nil { 158 | tl.log.Errorf("Failed to retrieve rates: %v", err) 159 | return tl.defaultRates 160 | } 161 | 162 | // If the returned rate set is empty then used the default one. 163 | if len(rates.m) == 0 { 164 | return tl.defaultRates 165 | } 166 | 167 | return rates 168 | } 169 | 170 | type MaxRateError struct { 171 | delay time.Duration 172 | } 173 | 174 | func (m *MaxRateError) Error() string { 175 | return fmt.Sprintf("max rate reached: retry-in %v", m.delay) 176 | } 177 | 178 | type RateErrHandler struct { 179 | } 180 | 181 | func (e *RateErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) { 182 | if rerr, ok := err.(*MaxRateError); ok { 183 | w.Header().Set("X-Retry-In", rerr.delay.String()) 184 | w.WriteHeader(429) 185 | w.Write([]byte(err.Error())) 186 | return 187 | } 188 | utils.DefaultHandler.ServeHTTP(w, req, err) 189 | } 190 | 191 | type TokenLimiterOption func(l *TokenLimiter) error 192 | 193 | // Logger sets the logger that will be used by this middleware. 
194 | func Logger(l utils.Logger) TokenLimiterOption { 195 | return func(cl *TokenLimiter) error { 196 | cl.log = l 197 | return nil 198 | } 199 | } 200 | 201 | // ErrorHandler sets error handler of the server 202 | func ErrorHandler(h utils.ErrorHandler) TokenLimiterOption { 203 | return func(cl *TokenLimiter) error { 204 | cl.errHandler = h 205 | return nil 206 | } 207 | } 208 | 209 | func ExtractRates(e RateExtractor) TokenLimiterOption { 210 | return func(cl *TokenLimiter) error { 211 | cl.extractRates = e 212 | return nil 213 | } 214 | } 215 | 216 | func Clock(clock timetools.TimeProvider) TokenLimiterOption { 217 | return func(cl *TokenLimiter) error { 218 | cl.clock = clock 219 | return nil 220 | } 221 | } 222 | 223 | func Capacity(cap int) TokenLimiterOption { 224 | return func(cl *TokenLimiter) error { 225 | if cap <= 0 { 226 | return fmt.Errorf("bad capacity: %v", cap) 227 | } 228 | cl.capacity = cap 229 | return nil 230 | } 231 | } 232 | 233 | var defaultErrHandler = &RateErrHandler{} 234 | 235 | func setDefaults(tl *TokenLimiter) { 236 | if tl.log == nil { 237 | tl.log = utils.NullLogger 238 | } 239 | if tl.capacity <= 0 { 240 | tl.capacity = DefaultCapacity 241 | } 242 | if tl.clock == nil { 243 | tl.clock = &timetools.RealTime{} 244 | } 245 | if tl.errHandler == nil { 246 | tl.errHandler = defaultErrHandler 247 | } 248 | } 249 | -------------------------------------------------------------------------------- /ratelimit/tokenlimiter_test.go: -------------------------------------------------------------------------------- 1 | package ratelimit 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "net/http" 7 | "net/http/httptest" 8 | "time" 9 | 10 | "github.com/mailgun/timetools" 11 | "github.com/vulcand/oxy/testutils" 12 | "github.com/vulcand/oxy/utils" 13 | 14 | . 
"gopkg.in/check.v1" 15 | ) 16 | 17 | type LimiterSuite struct { 18 | clock *timetools.FreezedTime 19 | } 20 | 21 | var _ = Suite(&LimiterSuite{}) 22 | 23 | func (s *LimiterSuite) SetUpSuite(c *C) { 24 | s.clock = &timetools.FreezedTime{ 25 | CurrentTime: time.Date(2012, 3, 4, 5, 6, 7, 0, time.UTC), 26 | } 27 | } 28 | 29 | func (s *LimiterSuite) TestRateSetAdd(c *C) { 30 | rs := NewRateSet() 31 | 32 | // Invalid period 33 | err := rs.Add(0, 1, 1) 34 | c.Assert(err, NotNil) 35 | 36 | // Invalid Average 37 | err = rs.Add(time.Second, 0, 1) 38 | c.Assert(err, NotNil) 39 | 40 | // Invalid Burst 41 | err = rs.Add(time.Second, 1, 0) 42 | c.Assert(err, NotNil) 43 | 44 | err = rs.Add(time.Second, 1, 1) 45 | c.Assert(err, IsNil) 46 | c.Assert("map[1s:rate(1/1s, burst=1)]", Equals, fmt.Sprint(rs)) 47 | } 48 | 49 | // We've hit the limit and were able to proceed on the next time run 50 | func (s *LimiterSuite) TestHitLimit(c *C) { 51 | handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 52 | w.Write([]byte("hello")) 53 | }) 54 | 55 | rates := NewRateSet() 56 | rates.Add(time.Second, 1, 1) 57 | 58 | l, err := New(handler, headerLimit, rates, Clock(s.clock)) 59 | c.Assert(err, IsNil) 60 | 61 | srv := httptest.NewServer(l) 62 | defer srv.Close() 63 | 64 | re, _, err := testutils.Get(srv.URL, testutils.Header("Source", "a")) 65 | c.Assert(err, IsNil) 66 | c.Assert(re.StatusCode, Equals, http.StatusOK) 67 | 68 | // Next request from the same source hits rate limit 69 | re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) 70 | c.Assert(err, IsNil) 71 | c.Assert(re.StatusCode, Equals, 429) 72 | 73 | // Second later, the request from this ip will succeed 74 | s.clock.Sleep(time.Second) 75 | re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) 76 | c.Assert(err, IsNil) 77 | c.Assert(re.StatusCode, Equals, http.StatusOK) 78 | } 79 | 80 | // We've failed to extract client ip 81 | func (s *LimiterSuite) TestFailure(c *C) { 82 | 
handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 83 | w.Write([]byte("hello")) 84 | }) 85 | 86 | rates := NewRateSet() 87 | rates.Add(time.Second, 1, 1) 88 | 89 | l, err := New(handler, faultyExtract, rates, Clock(s.clock)) 90 | c.Assert(err, IsNil) 91 | 92 | srv := httptest.NewServer(l) 93 | defer srv.Close() 94 | 95 | re, _, err := testutils.Get(srv.URL, testutils.Header("Source", "a")) 96 | c.Assert(err, IsNil) 97 | c.Assert(re.StatusCode, Equals, http.StatusInternalServerError) 98 | } 99 | 100 | // Make sure rates from different ips are controlled separatedly 101 | func (s *LimiterSuite) TestIsolation(c *C) { 102 | handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 103 | w.Write([]byte("hello")) 104 | }) 105 | 106 | rates := NewRateSet() 107 | rates.Add(time.Second, 1, 1) 108 | 109 | l, err := New(handler, headerLimit, rates, Clock(s.clock)) 110 | c.Assert(err, IsNil) 111 | 112 | srv := httptest.NewServer(l) 113 | defer srv.Close() 114 | 115 | re, _, err := testutils.Get(srv.URL, testutils.Header("Source", "a")) 116 | c.Assert(err, IsNil) 117 | c.Assert(re.StatusCode, Equals, http.StatusOK) 118 | 119 | // Next request from the same source hits rate limit 120 | re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) 121 | c.Assert(err, IsNil) 122 | c.Assert(re.StatusCode, Equals, 429) 123 | 124 | // The request from other source can proceed 125 | re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "b")) 126 | c.Assert(err, IsNil) 127 | c.Assert(re.StatusCode, Equals, http.StatusOK) 128 | } 129 | 130 | // Make sure that expiration works (Expiration is triggered after significant amount of time passes) 131 | func (s *LimiterSuite) TestExpiration(c *C) { 132 | handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 133 | w.Write([]byte("hello")) 134 | }) 135 | 136 | rates := NewRateSet() 137 | rates.Add(time.Second, 1, 1) 138 | 139 | l, err := New(handler, 
headerLimit, rates, Clock(s.clock)) 140 | c.Assert(err, IsNil) 141 | 142 | srv := httptest.NewServer(l) 143 | defer srv.Close() 144 | 145 | re, _, err := testutils.Get(srv.URL, testutils.Header("Source", "a")) 146 | c.Assert(err, IsNil) 147 | c.Assert(re.StatusCode, Equals, http.StatusOK) 148 | 149 | // Next request from the same source hits rate limit 150 | re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) 151 | c.Assert(err, IsNil) 152 | c.Assert(re.StatusCode, Equals, 429) 153 | 154 | // 24 hours later, the request from this ip will succeed 155 | s.clock.Sleep(24 * time.Hour) 156 | re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) 157 | c.Assert(err, IsNil) 158 | c.Assert(re.StatusCode, Equals, http.StatusOK) 159 | } 160 | 161 | // If rate limiting configuration is valid, then it is applied. 162 | func (s *LimiterSuite) TestExtractRates(c *C) { 163 | // Given 164 | extractRates := func(*http.Request) (*RateSet, error) { 165 | rates := NewRateSet() 166 | rates.Add(time.Second, 2, 2) 167 | rates.Add(60*time.Second, 10, 10) 168 | return rates, nil 169 | } 170 | rates := NewRateSet() 171 | rates.Add(time.Second, 1, 1) 172 | 173 | handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 174 | w.Write([]byte("hello")) 175 | }) 176 | 177 | tl, err := New(handler, headerLimit, rates, Clock(s.clock), ExtractRates(RateExtractorFunc(extractRates))) 178 | c.Assert(err, IsNil) 179 | 180 | srv := httptest.NewServer(tl) 181 | defer srv.Close() 182 | 183 | // When/Then: The configured rate is applied, which 2 req/second 184 | re, _, err := testutils.Get(srv.URL, testutils.Header("Source", "a")) 185 | c.Assert(err, IsNil) 186 | c.Assert(re.StatusCode, Equals, http.StatusOK) 187 | 188 | re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) 189 | c.Assert(err, IsNil) 190 | c.Assert(re.StatusCode, Equals, http.StatusOK) 191 | 192 | re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) 193 | 
c.Assert(err, IsNil) 194 | c.Assert(re.StatusCode, Equals, 429) 195 | 196 | s.clock.Sleep(time.Second) 197 | re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) 198 | c.Assert(err, IsNil) 199 | c.Assert(re.StatusCode, Equals, http.StatusOK) 200 | } 201 | 202 | // If configMapper returns error, then the default rate is applied. 203 | func (s *LimiterSuite) TestBadRateExtractor(c *C) { 204 | // Given 205 | extractor := func(*http.Request) (*RateSet, error) { 206 | return nil, fmt.Errorf("Boom!") 207 | } 208 | rates := NewRateSet() 209 | rates.Add(time.Second, 1, 1) 210 | 211 | handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 212 | w.Write([]byte("hello")) 213 | }) 214 | 215 | l, err := New(handler, headerLimit, rates, Clock(s.clock), ExtractRates(RateExtractorFunc(extractor))) 216 | c.Assert(err, IsNil) 217 | 218 | srv := httptest.NewServer(l) 219 | defer srv.Close() 220 | 221 | // When/Then: The default rate is applied, which 1 req/second 222 | re, _, err := testutils.Get(srv.URL, testutils.Header("Source", "a")) 223 | c.Assert(err, IsNil) 224 | c.Assert(re.StatusCode, Equals, http.StatusOK) 225 | 226 | re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) 227 | c.Assert(err, IsNil) 228 | c.Assert(re.StatusCode, Equals, 429) 229 | 230 | s.clock.Sleep(time.Second) 231 | re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) 232 | c.Assert(err, IsNil) 233 | c.Assert(re.StatusCode, Equals, http.StatusOK) 234 | } 235 | 236 | // If configMapper returns empty rates, then the default rate is applied. 
237 | func (s *LimiterSuite) TestExtractorEmpty(c *C) { 238 | // Given 239 | extractor := func(*http.Request) (*RateSet, error) { 240 | return NewRateSet(), nil 241 | } 242 | rates := NewRateSet() 243 | rates.Add(time.Second, 1, 1) 244 | 245 | handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 246 | w.Write([]byte("hello")) 247 | }) 248 | 249 | l, err := New(handler, headerLimit, rates, Clock(s.clock), ExtractRates(RateExtractorFunc(extractor))) 250 | c.Assert(err, IsNil) 251 | 252 | srv := httptest.NewServer(l) 253 | defer srv.Close() 254 | 255 | // When/Then: The default rate is applied, which 1 req/second 256 | re, _, err := testutils.Get(srv.URL, testutils.Header("Source", "a")) 257 | c.Assert(err, IsNil) 258 | c.Assert(re.StatusCode, Equals, http.StatusOK) 259 | 260 | re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) 261 | c.Assert(err, IsNil) 262 | c.Assert(re.StatusCode, Equals, 429) 263 | 264 | s.clock.Sleep(time.Second) 265 | re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) 266 | c.Assert(err, IsNil) 267 | c.Assert(re.StatusCode, Equals, http.StatusOK) 268 | } 269 | 270 | func (s *LimiterSuite) TestInvalidParams(c *C) { 271 | // Rates are missing 272 | rs := NewRateSet() 273 | rs.Add(time.Second, 1, 1) 274 | 275 | // Empty 276 | _, err := New(nil, nil, rs) 277 | c.Assert(err, NotNil) 278 | 279 | // Rates are empty 280 | _, err = New(nil, nil, NewRateSet()) 281 | c.Assert(err, NotNil) 282 | 283 | // Bad capacity 284 | _, err = New(nil, headerLimit, rs, Capacity(-1)) 285 | c.Assert(err, NotNil) 286 | } 287 | 288 | // We've hit the limit and were able to proceed on the next time run 289 | func (s *LimiterSuite) TestOptions(c *C) { 290 | handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 291 | w.Write([]byte("hello")) 292 | }) 293 | 294 | rates := NewRateSet() 295 | rates.Add(time.Second, 1, 1) 296 | 297 | errHandler := utils.ErrorHandlerFunc(func(w http.ResponseWriter, 
req *http.Request, err error) { 298 | w.WriteHeader(http.StatusTeapot) 299 | w.Write([]byte(http.StatusText(http.StatusTeapot))) 300 | }) 301 | 302 | buf := &bytes.Buffer{} 303 | log := utils.NewFileLogger(buf, utils.INFO) 304 | 305 | l, err := New(handler, headerLimit, rates, ErrorHandler(errHandler), Logger(log), Clock(s.clock)) 306 | c.Assert(err, IsNil) 307 | 308 | srv := httptest.NewServer(l) 309 | defer srv.Close() 310 | 311 | re, _, err := testutils.Get(srv.URL, testutils.Header("Source", "a")) 312 | c.Assert(err, IsNil) 313 | c.Assert(re.StatusCode, Equals, http.StatusOK) 314 | 315 | re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) 316 | c.Assert(err, IsNil) 317 | c.Assert(re.StatusCode, Equals, http.StatusTeapot) 318 | 319 | c.Assert(len(buf.String()), Not(Equals), 0) 320 | } 321 | 322 | func headerLimiter(req *http.Request) (string, int64, error) { 323 | return req.Header.Get("Source"), 1, nil 324 | } 325 | 326 | func faultyExtractor(req *http.Request) (string, int64, error) { 327 | return "", -1, fmt.Errorf("oops") 328 | } 329 | 330 | var headerLimit = utils.ExtractorFunc(headerLimiter) 331 | var faultyExtract = utils.ExtractorFunc(faultyExtractor) 332 | -------------------------------------------------------------------------------- /roundrobin/rr.go: -------------------------------------------------------------------------------- 1 | // package roundrobin implements dynamic weighted round robin load balancer http handler 2 | package roundrobin 3 | 4 | import ( 5 | "fmt" 6 | "net/http" 7 | "net/url" 8 | "sync" 9 | 10 | "github.com/vulcand/oxy/utils" 11 | ) 12 | 13 | // Weight is an optional functional argument that sets weight of the server 14 | func Weight(w int) ServerOption { 15 | return func(s *server) error { 16 | if w < 0 { 17 | return fmt.Errorf("Weight should be >= 0") 18 | } 19 | s.weight = w 20 | return nil 21 | } 22 | } 23 | 24 | // ErrorHandler is a functional argument that sets error handler of the server 25 | func 
ErrorHandler(h utils.ErrorHandler) LBOption { 26 | return func(s *RoundRobin) error { 27 | s.errHandler = h 28 | return nil 29 | } 30 | } 31 | 32 | func EnableStickySession(ss *StickySession) LBOption { 33 | return func(s *RoundRobin) error { 34 | s.ss = ss 35 | return nil 36 | } 37 | } 38 | 39 | type RoundRobin struct { 40 | mutex *sync.Mutex 41 | next http.Handler 42 | errHandler utils.ErrorHandler 43 | // Current index (starts from -1) 44 | index int 45 | servers []*server 46 | currentWeight int 47 | ss *StickySession 48 | } 49 | 50 | func New(next http.Handler, opts ...LBOption) (*RoundRobin, error) { 51 | rr := &RoundRobin{ 52 | next: next, 53 | index: -1, 54 | mutex: &sync.Mutex{}, 55 | servers: []*server{}, 56 | ss: nil, 57 | } 58 | for _, o := range opts { 59 | if err := o(rr); err != nil { 60 | return nil, err 61 | } 62 | } 63 | if rr.errHandler == nil { 64 | rr.errHandler = utils.DefaultHandler 65 | } 66 | return rr, nil 67 | } 68 | 69 | func (r *RoundRobin) Next() http.Handler { 70 | return r.next 71 | } 72 | 73 | func (r *RoundRobin) ServeHTTP(w http.ResponseWriter, req *http.Request) { 74 | // make shallow copy of request before chaning anything to avoid side effects 75 | newReq := *req 76 | stuck := false 77 | if r.ss != nil { 78 | cookie_url, present, err := r.ss.GetBackend(&newReq, r.Servers()) 79 | 80 | if err != nil { 81 | r.errHandler.ServeHTTP(w, req, err) 82 | return 83 | } 84 | 85 | if present { 86 | newReq.URL = cookie_url 87 | stuck = true 88 | } 89 | } 90 | 91 | if !stuck { 92 | url, err := r.NextServer() 93 | if err != nil { 94 | r.errHandler.ServeHTTP(w, req, err) 95 | return 96 | } 97 | 98 | if r.ss != nil { 99 | r.ss.StickBackend(url, &w) 100 | } 101 | newReq.URL = url 102 | } 103 | r.next.ServeHTTP(w, &newReq) 104 | } 105 | 106 | func (r *RoundRobin) NextServer() (*url.URL, error) { 107 | srv, err := r.nextServer() 108 | if err != nil { 109 | return nil, err 110 | } 111 | return utils.CopyURL(srv.url), nil 112 | } 113 | 114 | func (r 
*RoundRobin) nextServer() (*server, error) { 115 | r.mutex.Lock() 116 | defer r.mutex.Unlock() 117 | 118 | if len(r.servers) == 0 { 119 | return nil, fmt.Errorf("no servers in the pool") 120 | } 121 | 122 | // The algo below may look messy, but is actually very simple 123 | // it calculates the GCD and subtracts it on every iteration, what interleaves servers 124 | // and allows us not to build an iterator every time we readjust weights 125 | 126 | // GCD across all enabled servers 127 | gcd := r.weightGcd() 128 | // Maximum weight across all enabled servers 129 | max := r.maxWeight() 130 | 131 | for { 132 | r.index = (r.index + 1) % len(r.servers) 133 | if r.index == 0 { 134 | r.currentWeight = r.currentWeight - gcd 135 | if r.currentWeight <= 0 { 136 | r.currentWeight = max 137 | if r.currentWeight == 0 { 138 | return nil, fmt.Errorf("all servers have 0 weight") 139 | } 140 | } 141 | } 142 | srv := r.servers[r.index] 143 | if srv.weight >= r.currentWeight { 144 | return srv, nil 145 | } 146 | } 147 | // We did full circle and found no available servers 148 | return nil, fmt.Errorf("no available servers") 149 | } 150 | 151 | func (r *RoundRobin) RemoveServer(u *url.URL) error { 152 | r.mutex.Lock() 153 | defer r.mutex.Unlock() 154 | 155 | e, index := r.findServerByURL(u) 156 | if e == nil { 157 | return fmt.Errorf("server not found") 158 | } 159 | r.servers = append(r.servers[:index], r.servers[index+1:]...) 
160 | r.resetState() 161 | return nil 162 | } 163 | 164 | func (rr *RoundRobin) Servers() []*url.URL { 165 | rr.mutex.Lock() 166 | defer rr.mutex.Unlock() 167 | 168 | out := make([]*url.URL, len(rr.servers)) 169 | for i, srv := range rr.servers { 170 | out[i] = srv.url 171 | } 172 | return out 173 | } 174 | 175 | func (rr *RoundRobin) ServerWeight(u *url.URL) (int, bool) { 176 | rr.mutex.Lock() 177 | defer rr.mutex.Unlock() 178 | 179 | if s, _ := rr.findServerByURL(u); s != nil { 180 | return s.weight, true 181 | } 182 | return -1, false 183 | } 184 | 185 | // In case if server is already present in the load balancer, returns error 186 | func (rr *RoundRobin) UpsertServer(u *url.URL, options ...ServerOption) error { 187 | rr.mutex.Lock() 188 | defer rr.mutex.Unlock() 189 | 190 | if u == nil { 191 | return fmt.Errorf("server URL can't be nil") 192 | } 193 | 194 | if s, _ := rr.findServerByURL(u); s != nil { 195 | for _, o := range options { 196 | if err := o(s); err != nil { 197 | return err 198 | } 199 | } 200 | rr.resetState() 201 | return nil 202 | } 203 | 204 | srv := &server{url: utils.CopyURL(u)} 205 | for _, o := range options { 206 | if err := o(srv); err != nil { 207 | return err 208 | } 209 | } 210 | 211 | if srv.weight == 0 { 212 | srv.weight = defaultWeight 213 | } 214 | 215 | rr.servers = append(rr.servers, srv) 216 | rr.resetState() 217 | return nil 218 | } 219 | 220 | func (r *RoundRobin) resetIterator() { 221 | r.index = -1 222 | r.currentWeight = 0 223 | } 224 | 225 | func (r *RoundRobin) resetState() { 226 | r.resetIterator() 227 | } 228 | 229 | func (r *RoundRobin) findServerByURL(u *url.URL) (*server, int) { 230 | if len(r.servers) == 0 { 231 | return nil, -1 232 | } 233 | for i, s := range r.servers { 234 | if sameURL(u, s.url) { 235 | return s, i 236 | } 237 | } 238 | return nil, -1 239 | } 240 | 241 | func (rr *RoundRobin) maxWeight() int { 242 | max := -1 243 | for _, s := range rr.servers { 244 | if s.weight > max { 245 | max = s.weight 246 
| } 247 | } 248 | return max 249 | } 250 | 251 | func (rr *RoundRobin) weightGcd() int { 252 | divisor := -1 253 | for _, s := range rr.servers { 254 | if divisor == -1 { 255 | divisor = s.weight 256 | } else { 257 | divisor = gcd(divisor, s.weight) 258 | } 259 | } 260 | return divisor 261 | } 262 | 263 | func gcd(a, b int) int { 264 | for b != 0 { 265 | a, b = b, a%b 266 | } 267 | return a 268 | } 269 | 270 | // ServerOption provides various options for server, e.g. weight 271 | type ServerOption func(*server) error 272 | 273 | // LBOption provides options for load balancer 274 | type LBOption func(*RoundRobin) error 275 | 276 | // Set additional parameters for the server can be supplied when adding server 277 | type server struct { 278 | url *url.URL 279 | // Relative weight for the enpoint to other enpoints in the load balancer 280 | weight int 281 | } 282 | 283 | const defaultWeight = 1 284 | 285 | func sameURL(a, b *url.URL) bool { 286 | return a.Path == b.Path && a.Host == b.Host && a.Scheme == b.Scheme 287 | } 288 | 289 | type balancerHandler interface { 290 | Servers() []*url.URL 291 | ServeHTTP(w http.ResponseWriter, req *http.Request) 292 | ServerWeight(u *url.URL) (int, bool) 293 | RemoveServer(u *url.URL) error 294 | UpsertServer(u *url.URL, options ...ServerOption) error 295 | NextServer() (*url.URL, error) 296 | Next() http.Handler 297 | } 298 | -------------------------------------------------------------------------------- /roundrobin/rr_test.go: -------------------------------------------------------------------------------- 1 | package roundrobin 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | 8 | "github.com/vulcand/oxy/forward" 9 | "github.com/vulcand/oxy/testutils" 10 | "github.com/vulcand/oxy/utils" 11 | 12 | . 
"gopkg.in/check.v1" 13 | ) 14 | 15 | func TestRR(t *testing.T) { TestingT(t) } 16 | 17 | type RRSuite struct{} 18 | 19 | var _ = Suite(&RRSuite{}) 20 | 21 | func (s *RRSuite) TestNoServers(c *C) { 22 | fwd, err := forward.New() 23 | c.Assert(err, IsNil) 24 | 25 | lb, err := New(fwd) 26 | c.Assert(err, IsNil) 27 | 28 | proxy := httptest.NewServer(lb) 29 | defer proxy.Close() 30 | 31 | re, _, err := testutils.Get(proxy.URL) 32 | c.Assert(err, IsNil) 33 | c.Assert(re.StatusCode, Equals, http.StatusInternalServerError) 34 | } 35 | 36 | func (s *RRSuite) TestRemoveBadServer(c *C) { 37 | lb, err := New(nil) 38 | c.Assert(err, IsNil) 39 | 40 | c.Assert(lb.RemoveServer(testutils.ParseURI("http://google.com")), NotNil) 41 | } 42 | 43 | func (s *RRSuite) TestCustomErrHandler(c *C) { 44 | errHandler := utils.ErrorHandlerFunc(func(w http.ResponseWriter, req *http.Request, err error) { 45 | w.WriteHeader(http.StatusTeapot) 46 | w.Write([]byte(http.StatusText(http.StatusTeapot))) 47 | }) 48 | 49 | fwd, err := forward.New() 50 | c.Assert(err, IsNil) 51 | 52 | lb, err := New(fwd, ErrorHandler(errHandler)) 53 | c.Assert(err, IsNil) 54 | 55 | proxy := httptest.NewServer(lb) 56 | defer proxy.Close() 57 | 58 | re, _, err := testutils.Get(proxy.URL) 59 | c.Assert(err, IsNil) 60 | c.Assert(re.StatusCode, Equals, http.StatusTeapot) 61 | } 62 | 63 | func (s *RRSuite) TestOneServer(c *C) { 64 | a := testutils.NewResponder("a") 65 | defer a.Close() 66 | 67 | fwd, err := forward.New() 68 | c.Assert(err, IsNil) 69 | 70 | lb, err := New(fwd) 71 | c.Assert(err, IsNil) 72 | 73 | lb.UpsertServer(testutils.ParseURI(a.URL)) 74 | 75 | proxy := httptest.NewServer(lb) 76 | defer proxy.Close() 77 | 78 | c.Assert(seq(c, proxy.URL, 3), DeepEquals, []string{"a", "a", "a"}) 79 | } 80 | 81 | func (s *RRSuite) TestSimple(c *C) { 82 | a := testutils.NewResponder("a") 83 | defer a.Close() 84 | 85 | b := testutils.NewResponder("b") 86 | defer b.Close() 87 | 88 | fwd, err := forward.New() 89 | c.Assert(err, 
IsNil) 90 | 91 | lb, err := New(fwd) 92 | c.Assert(err, IsNil) 93 | 94 | lb.UpsertServer(testutils.ParseURI(a.URL)) 95 | lb.UpsertServer(testutils.ParseURI(b.URL)) 96 | 97 | proxy := httptest.NewServer(lb) 98 | defer proxy.Close() 99 | 100 | c.Assert(seq(c, proxy.URL, 3), DeepEquals, []string{"a", "b", "a"}) 101 | } 102 | 103 | func (s *RRSuite) TestRemoveServer(c *C) { 104 | a := testutils.NewResponder("a") 105 | defer a.Close() 106 | 107 | b := testutils.NewResponder("b") 108 | defer b.Close() 109 | 110 | fwd, err := forward.New() 111 | c.Assert(err, IsNil) 112 | 113 | lb, err := New(fwd) 114 | c.Assert(err, IsNil) 115 | 116 | lb.UpsertServer(testutils.ParseURI(a.URL)) 117 | lb.UpsertServer(testutils.ParseURI(b.URL)) 118 | 119 | proxy := httptest.NewServer(lb) 120 | defer proxy.Close() 121 | 122 | c.Assert(seq(c, proxy.URL, 3), DeepEquals, []string{"a", "b", "a"}) 123 | 124 | lb.RemoveServer(testutils.ParseURI(a.URL)) 125 | 126 | c.Assert(seq(c, proxy.URL, 3), DeepEquals, []string{"b", "b", "b"}) 127 | } 128 | 129 | func (s *RRSuite) TestUpsertSame(c *C) { 130 | a := testutils.NewResponder("a") 131 | defer a.Close() 132 | 133 | fwd, err := forward.New() 134 | c.Assert(err, IsNil) 135 | 136 | lb, err := New(fwd) 137 | c.Assert(err, IsNil) 138 | 139 | c.Assert(lb.UpsertServer(testutils.ParseURI(a.URL)), IsNil) 140 | c.Assert(lb.UpsertServer(testutils.ParseURI(a.URL)), IsNil) 141 | 142 | proxy := httptest.NewServer(lb) 143 | defer proxy.Close() 144 | 145 | c.Assert(seq(c, proxy.URL, 3), DeepEquals, []string{"a", "a", "a"}) 146 | } 147 | 148 | func (s *RRSuite) TestUpsertWeight(c *C) { 149 | a := testutils.NewResponder("a") 150 | defer a.Close() 151 | 152 | b := testutils.NewResponder("b") 153 | defer b.Close() 154 | 155 | fwd, err := forward.New() 156 | c.Assert(err, IsNil) 157 | 158 | lb, err := New(fwd) 159 | c.Assert(err, IsNil) 160 | 161 | c.Assert(lb.UpsertServer(testutils.ParseURI(a.URL)), IsNil) 162 | c.Assert(lb.UpsertServer(testutils.ParseURI(b.URL)), 
IsNil) 163 | 164 | proxy := httptest.NewServer(lb) 165 | defer proxy.Close() 166 | 167 | c.Assert(seq(c, proxy.URL, 3), DeepEquals, []string{"a", "b", "a"}) 168 | 169 | c.Assert(lb.UpsertServer(testutils.ParseURI(b.URL), Weight(3)), IsNil) 170 | 171 | c.Assert(seq(c, proxy.URL, 4), DeepEquals, []string{"b", "b", "a", "b"}) 172 | } 173 | 174 | func (s *RRSuite) TestWeighted(c *C) { 175 | a := testutils.NewResponder("a") 176 | defer a.Close() 177 | 178 | b := testutils.NewResponder("b") 179 | defer b.Close() 180 | 181 | fwd, err := forward.New() 182 | c.Assert(err, IsNil) 183 | 184 | lb, err := New(fwd) 185 | c.Assert(err, IsNil) 186 | 187 | lb.UpsertServer(testutils.ParseURI(a.URL), Weight(3)) 188 | lb.UpsertServer(testutils.ParseURI(b.URL), Weight(2)) 189 | 190 | proxy := httptest.NewServer(lb) 191 | defer proxy.Close() 192 | 193 | c.Assert(seq(c, proxy.URL, 6), DeepEquals, []string{"a", "a", "b", "a", "b", "a"}) 194 | 195 | w, ok := lb.ServerWeight(testutils.ParseURI(a.URL)) 196 | c.Assert(w, Equals, 3) 197 | c.Assert(ok, Equals, true) 198 | 199 | w, ok = lb.ServerWeight(testutils.ParseURI(b.URL)) 200 | c.Assert(w, Equals, 2) 201 | c.Assert(ok, Equals, true) 202 | 203 | w, ok = lb.ServerWeight(testutils.ParseURI("http://caramba:4000")) 204 | c.Assert(w, Equals, -1) 205 | c.Assert(ok, Equals, false) 206 | } 207 | 208 | func seq(c *C, url string, repeat int) []string { 209 | out := []string{} 210 | for i := 0; i < repeat; i++ { 211 | _, body, err := testutils.Get(url) 212 | c.Assert(err, IsNil) 213 | out = append(out, string(body)) 214 | } 215 | return out 216 | } 217 | -------------------------------------------------------------------------------- /roundrobin/stickysessions.go: -------------------------------------------------------------------------------- 1 | // package stickysession is a mixin for load balancers that implements layer 7 (http cookie) session affinity 2 | package roundrobin 3 | 4 | import ( 5 | "net/http" 6 | "net/url" 7 | ) 8 | 9 | type 
StickySession struct { 10 | cookiename string 11 | } 12 | 13 | func NewStickySession(c string) *StickySession { 14 | return &StickySession{c} 15 | } 16 | 17 | // GetBackend returns the backend URL stored in the sticky cookie, iff the backend is still in the valid list of servers. 18 | func (s *StickySession) GetBackend(req *http.Request, servers []*url.URL) (*url.URL, bool, error) { 19 | cookie, err := req.Cookie(s.cookiename) 20 | switch err { 21 | case nil: 22 | case http.ErrNoCookie: 23 | return nil, false, nil 24 | default: 25 | return nil, false, err 26 | } 27 | 28 | s_url, err := url.Parse(cookie.Value) 29 | if err != nil { 30 | return nil, false, err 31 | } 32 | 33 | if s.isBackendAlive(s_url, servers) { 34 | return s_url, true, nil 35 | } else { 36 | return nil, false, nil 37 | } 38 | } 39 | 40 | func (s *StickySession) StickBackend(backend *url.URL, w *http.ResponseWriter) { 41 | c := &http.Cookie{Name: s.cookiename, Value: backend.String(), Path: "/"} 42 | http.SetCookie(*w, c) 43 | return 44 | } 45 | 46 | func (s *StickySession) isBackendAlive(needle *url.URL, haystack []*url.URL) bool { 47 | if len(haystack) == 0 { 48 | return false 49 | } 50 | 51 | for _, s := range haystack { 52 | if sameURL(needle, s) { 53 | return true 54 | } 55 | } 56 | return false 57 | } 58 | -------------------------------------------------------------------------------- /roundrobin/stickysessions_test.go: -------------------------------------------------------------------------------- 1 | package roundrobin 2 | 3 | import ( 4 | "io/ioutil" 5 | "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | 9 | "github.com/vulcand/oxy/forward" 10 | "github.com/vulcand/oxy/testutils" 11 | 12 | . 
"gopkg.in/check.v1" 13 | ) 14 | 15 | func TestSS(t *testing.T) { TestingT(t) } 16 | 17 | type SSSuite struct{} 18 | 19 | var _ = Suite(&SSSuite{}) 20 | 21 | func (s *SSSuite) TestBasic(c *C) { 22 | a := testutils.NewResponder("a") 23 | b := testutils.NewResponder("b") 24 | 25 | defer a.Close() 26 | defer b.Close() 27 | 28 | fwd, err := forward.New() 29 | c.Assert(err, IsNil) 30 | 31 | sticky := NewStickySession("test") 32 | c.Assert(sticky, NotNil) 33 | 34 | lb, err := New(fwd, EnableStickySession(sticky)) 35 | c.Assert(err, IsNil) 36 | 37 | lb.UpsertServer(testutils.ParseURI(a.URL)) 38 | lb.UpsertServer(testutils.ParseURI(b.URL)) 39 | 40 | proxy := httptest.NewServer(lb) 41 | defer proxy.Close() 42 | 43 | http_cli := &http.Client{} 44 | 45 | for i := 0; i < 10; i++ { 46 | req, err := http.NewRequest("GET", proxy.URL, nil) 47 | c.Assert(err, IsNil) 48 | req.AddCookie(&http.Cookie{Name: "test", Value: a.URL}) 49 | 50 | resp, err := http_cli.Do(req) 51 | c.Assert(err, IsNil) 52 | 53 | defer resp.Body.Close() 54 | body, err := ioutil.ReadAll(resp.Body) 55 | 56 | c.Assert(err, IsNil) 57 | c.Assert(string(body), Equals, "a") 58 | } 59 | } 60 | 61 | func (s *SSSuite) TestStickCookie(c *C) { 62 | a := testutils.NewResponder("a") 63 | b := testutils.NewResponder("b") 64 | 65 | defer a.Close() 66 | defer b.Close() 67 | 68 | fwd, err := forward.New() 69 | c.Assert(err, IsNil) 70 | 71 | sticky := NewStickySession("test") 72 | c.Assert(sticky, NotNil) 73 | 74 | lb, err := New(fwd, EnableStickySession(sticky)) 75 | c.Assert(err, IsNil) 76 | 77 | lb.UpsertServer(testutils.ParseURI(a.URL)) 78 | lb.UpsertServer(testutils.ParseURI(b.URL)) 79 | 80 | proxy := httptest.NewServer(lb) 81 | defer proxy.Close() 82 | 83 | resp, err := http.Get(proxy.URL) 84 | c.Assert(err, IsNil) 85 | 86 | c_out := resp.Cookies()[0] 87 | c.Assert(c_out.Name, Equals, "test") 88 | c.Assert(c_out.Value, Equals, a.URL) 89 | } 90 | 91 | func (s *SSSuite) TestRemoveRespondingServer(c *C) { 92 | a := 
testutils.NewResponder("a") 93 | b := testutils.NewResponder("b") 94 | 95 | defer a.Close() 96 | defer b.Close() 97 | 98 | fwd, err := forward.New() 99 | c.Assert(err, IsNil) 100 | 101 | sticky := NewStickySession("test") 102 | c.Assert(sticky, NotNil) 103 | 104 | lb, err := New(fwd, EnableStickySession(sticky)) 105 | c.Assert(err, IsNil) 106 | 107 | lb.UpsertServer(testutils.ParseURI(a.URL)) 108 | lb.UpsertServer(testutils.ParseURI(b.URL)) 109 | 110 | proxy := httptest.NewServer(lb) 111 | defer proxy.Close() 112 | 113 | http_cli := &http.Client{} 114 | 115 | for i := 0; i < 10; i++ { 116 | req, err := http.NewRequest("GET", proxy.URL, nil) 117 | c.Assert(err, IsNil) 118 | req.AddCookie(&http.Cookie{Name: "test", Value: a.URL}) 119 | 120 | resp, err := http_cli.Do(req) 121 | c.Assert(err, IsNil) 122 | 123 | defer resp.Body.Close() 124 | body, err := ioutil.ReadAll(resp.Body) 125 | 126 | c.Assert(err, IsNil) 127 | c.Assert(string(body), Equals, "a") 128 | } 129 | 130 | lb.RemoveServer(testutils.ParseURI(a.URL)) 131 | 132 | // Now, use the organic cookie response in our next requests. 
133 | req, err := http.NewRequest("GET", proxy.URL, nil) 134 | req.AddCookie(&http.Cookie{Name: "test", Value: a.URL}) 135 | resp, err := http_cli.Do(req) 136 | c.Assert(err, IsNil) 137 | 138 | c.Assert(resp.Cookies()[0].Name, Equals, "test") 139 | c.Assert(resp.Cookies()[0].Value, Equals, b.URL) 140 | 141 | for i := 0; i < 10; i++ { 142 | req, err := http.NewRequest("GET", proxy.URL, nil) 143 | c.Assert(err, IsNil) 144 | 145 | resp, err := http_cli.Do(req) 146 | c.Assert(err, IsNil) 147 | 148 | defer resp.Body.Close() 149 | body, err := ioutil.ReadAll(resp.Body) 150 | 151 | c.Assert(err, IsNil) 152 | c.Assert(string(body), Equals, "b") 153 | } 154 | } 155 | 156 | func (s *SSSuite) TestRemoveAllServers(c *C) { 157 | a := testutils.NewResponder("a") 158 | b := testutils.NewResponder("b") 159 | 160 | defer a.Close() 161 | defer b.Close() 162 | 163 | fwd, err := forward.New() 164 | c.Assert(err, IsNil) 165 | 166 | sticky := NewStickySession("test") 167 | c.Assert(sticky, NotNil) 168 | 169 | lb, err := New(fwd, EnableStickySession(sticky)) 170 | c.Assert(err, IsNil) 171 | 172 | lb.UpsertServer(testutils.ParseURI(a.URL)) 173 | lb.UpsertServer(testutils.ParseURI(b.URL)) 174 | 175 | proxy := httptest.NewServer(lb) 176 | defer proxy.Close() 177 | 178 | http_cli := &http.Client{} 179 | 180 | for i := 0; i < 10; i++ { 181 | req, err := http.NewRequest("GET", proxy.URL, nil) 182 | c.Assert(err, IsNil) 183 | req.AddCookie(&http.Cookie{Name: "test", Value: a.URL}) 184 | 185 | resp, err := http_cli.Do(req) 186 | c.Assert(err, IsNil) 187 | 188 | defer resp.Body.Close() 189 | body, err := ioutil.ReadAll(resp.Body) 190 | 191 | c.Assert(err, IsNil) 192 | c.Assert(string(body), Equals, "a") 193 | } 194 | 195 | lb.RemoveServer(testutils.ParseURI(a.URL)) 196 | lb.RemoveServer(testutils.ParseURI(b.URL)) 197 | 198 | // Now, use the organic cookie response in our next requests. 
199 | req, err := http.NewRequest("GET", proxy.URL, nil) 200 | req.AddCookie(&http.Cookie{Name: "test", Value: a.URL}) 201 | resp, err := http_cli.Do(req) 202 | c.Assert(err, IsNil) 203 | c.Assert(resp.StatusCode, Equals, http.StatusInternalServerError) 204 | } 205 | 206 | func (s *SSSuite) TestBadCookieVal(c *C) { 207 | a := testutils.NewResponder("a") 208 | 209 | defer a.Close() 210 | 211 | fwd, err := forward.New() 212 | c.Assert(err, IsNil) 213 | 214 | sticky := NewStickySession("test") 215 | c.Assert(sticky, NotNil) 216 | 217 | lb, err := New(fwd, EnableStickySession(sticky)) 218 | c.Assert(err, IsNil) 219 | 220 | lb.UpsertServer(testutils.ParseURI(a.URL)) 221 | 222 | proxy := httptest.NewServer(lb) 223 | defer proxy.Close() 224 | 225 | http_cli := &http.Client{} 226 | 227 | req, err := http.NewRequest("GET", proxy.URL, nil) 228 | c.Assert(err, IsNil) 229 | req.AddCookie(&http.Cookie{Name: "test", Value: "This is a patently invalid url! You can't parse it! :-)"}) 230 | 231 | resp, err := http_cli.Do(req) 232 | c.Assert(err, IsNil) 233 | 234 | body, err := ioutil.ReadAll(resp.Body) 235 | c.Assert(string(body), Equals, "a") 236 | 237 | // Now, cycle off the good server to cause an error 238 | lb.RemoveServer(testutils.ParseURI(a.URL)) 239 | 240 | http_cli = &http.Client{} 241 | 242 | req, err = http.NewRequest("GET", proxy.URL, nil) 243 | c.Assert(err, IsNil) 244 | req.AddCookie(&http.Cookie{Name: "test", Value: "This is a patently invalid url! You can't parse it! 
:-)"}) 245 | 246 | resp, err = http_cli.Do(req) 247 | c.Assert(err, IsNil) 248 | 249 | body, err = ioutil.ReadAll(resp.Body) 250 | c.Assert(resp.StatusCode, Equals, http.StatusInternalServerError) 251 | } 252 | -------------------------------------------------------------------------------- /stream/retry_test.go: -------------------------------------------------------------------------------- 1 | package stream 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "os" 7 | 8 | "github.com/vulcand/oxy/forward" 9 | "github.com/vulcand/oxy/roundrobin" 10 | "github.com/vulcand/oxy/testutils" 11 | "github.com/vulcand/oxy/utils" 12 | 13 | . "gopkg.in/check.v1" 14 | ) 15 | 16 | type RTSuite struct{} 17 | 18 | var _ = Suite(&RTSuite{}) 19 | 20 | func (s *RTSuite) TestSuccess(c *C) { 21 | srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { 22 | w.Write([]byte("hello")) 23 | }) 24 | defer srv.Close() 25 | 26 | lb, rt := new(c, `IsNetworkError() && Attempts() <= 2`) 27 | 28 | proxy := httptest.NewServer(rt) 29 | defer proxy.Close() 30 | 31 | lb.UpsertServer(testutils.ParseURI(srv.URL)) 32 | 33 | re, body, err := testutils.Get(proxy.URL) 34 | c.Assert(err, IsNil) 35 | c.Assert(re.StatusCode, Equals, http.StatusOK) 36 | c.Assert(string(body), Equals, "hello") 37 | } 38 | 39 | func (s *RTSuite) TestRetryOnError(c *C) { 40 | srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { 41 | w.Write([]byte("hello")) 42 | }) 43 | defer srv.Close() 44 | 45 | lb, rt := new(c, `IsNetworkError() && Attempts() <= 2`) 46 | 47 | proxy := httptest.NewServer(rt) 48 | defer proxy.Close() 49 | 50 | lb.UpsertServer(testutils.ParseURI("http://localhost:64321")) 51 | lb.UpsertServer(testutils.ParseURI(srv.URL)) 52 | 53 | re, body, err := testutils.Get(proxy.URL, testutils.Body("some request parameters")) 54 | c.Assert(err, IsNil) 55 | c.Assert(re.StatusCode, Equals, http.StatusOK) 56 | c.Assert(string(body), Equals, "hello") 57 | } 58 | 59 | func (s 
*RTSuite) TestRetryExceedAttempts(c *C) { 60 | srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { 61 | w.Write([]byte("hello")) 62 | }) 63 | defer srv.Close() 64 | 65 | lb, rt := new(c, `IsNetworkError() && Attempts() <= 2`) 66 | 67 | proxy := httptest.NewServer(rt) 68 | defer proxy.Close() 69 | 70 | lb.UpsertServer(testutils.ParseURI("http://localhost:64321")) 71 | lb.UpsertServer(testutils.ParseURI("http://localhost:64322")) 72 | lb.UpsertServer(testutils.ParseURI("http://localhost:64323")) 73 | lb.UpsertServer(testutils.ParseURI(srv.URL)) 74 | 75 | re, _, err := testutils.Get(proxy.URL) 76 | c.Assert(err, IsNil) 77 | c.Assert(re.StatusCode, Equals, http.StatusBadGateway) 78 | } 79 | 80 | func new(c *C, p string) (*roundrobin.RoundRobin, *Streamer) { 81 | logger := utils.NewFileLogger(os.Stdout, utils.INFO) 82 | // forwarder will proxy the request to whatever destination 83 | fwd, err := forward.New(forward.Logger(logger)) 84 | c.Assert(err, IsNil) 85 | 86 | // load balancer will round robin request 87 | lb, err := roundrobin.New(fwd) 88 | c.Assert(err, IsNil) 89 | 90 | // stream handler will forward requests to redirect, make sure it uses files 91 | st, err := New(lb, Logger(logger), Retry(p), MemRequestBodyBytes(1)) 92 | c.Assert(err, IsNil) 93 | 94 | return lb, st 95 | } 96 | -------------------------------------------------------------------------------- /stream/stream.go: -------------------------------------------------------------------------------- 1 | /* 2 | package stream provides http.Handler middleware that solves several problems when dealing with http requests: 3 | 4 | Reads the entire request and response into buffer, optionally buffering it to disk for large requests. 5 | Checks the limits for the requests and responses, rejecting in case if the limit was exceeded. 6 | Changes request content-transfer-encoding from chunked and provides total size to the handlers. 
7 | 8 | Examples of a streaming middleware: 9 | 10 | // sample HTTP handler 11 | handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 12 | w.Write([]byte("hello")) 13 | }) 14 | 15 | // Stream will read the body in buffer before passing the request to the handler 16 | // calculate total size of the request and transform it from chunked encoding 17 | // before passing to the server 18 | stream.New(handler) 19 | 20 | // This version will buffer up to 2MB in memory and will serialize any extra 21 | // to a temporary file, if the request size exceeds 10MB it will reject the request 22 | stream.New(handler, 23 | stream.MemRequestBodyBytes(2 * 1024 * 1024), 24 | stream.MaxRequestBodyBytes(10 * 1024 * 1024)) 25 | 26 | // Will do the same as above, but with responses 27 | stream.New(handler, 28 | stream.MemResponseBodyBytes(2 * 1024 * 1024), 29 | stream.MaxResponseBodyBytes(10 * 1024 * 1024)) 30 | 31 | // Stream will replay the request if the handler returns error at least 3 times 32 | // before returning the response 33 | stream.New(handler, stream.Retry(`IsNetworkError() && Attempts() <= 2`)) 34 | 35 | */ 36 | package stream 37 | 38 | import ( 39 | "fmt" 40 | "io" 41 | "io/ioutil" 42 | "net/http" 43 | 44 | "github.com/mailgun/multibuf" 45 | "github.com/vulcand/oxy/utils" 46 | ) 47 | 48 | const ( 49 | // Store up to 1MB in RAM 50 | DefaultMemBodyBytes = 1048576 51 | // No limit by default 52 | DefaultMaxBodyBytes = -1 53 | // Maximum retry attempts 54 | DefaultMaxRetryAttempts = 10 55 | ) 56 | 57 | var errHandler utils.ErrorHandler = &SizeErrHandler{} 58 | 59 | // Streamer is responsible for streaming requests and responses 60 | // It buffers large reqeuests and responses to disk, 61 | type Streamer struct { 62 | maxRequestBodyBytes int64 63 | memRequestBodyBytes int64 64 | 65 | maxResponseBodyBytes int64 66 | memResponseBodyBytes int64 67 | 68 | retryPredicate hpredicate 69 | 70 | next http.Handler 71 | errHandler utils.ErrorHandler 72 | log 
utils.Logger 73 | } 74 | 75 | // New returns a new streamer middleware. New() function supports optional functional arguments 76 | func New(next http.Handler, setters ...optSetter) (*Streamer, error) { 77 | strm := &Streamer{ 78 | next: next, 79 | 80 | maxRequestBodyBytes: DefaultMaxBodyBytes, 81 | memRequestBodyBytes: DefaultMemBodyBytes, 82 | 83 | maxResponseBodyBytes: DefaultMaxBodyBytes, 84 | memResponseBodyBytes: DefaultMemBodyBytes, 85 | } 86 | for _, s := range setters { 87 | if err := s(strm); err != nil { 88 | return nil, err 89 | } 90 | } 91 | if strm.errHandler == nil { 92 | strm.errHandler = errHandler 93 | } 94 | 95 | if strm.log == nil { 96 | strm.log = utils.NullLogger 97 | } 98 | 99 | return strm, nil 100 | } 101 | 102 | type optSetter func(s *Streamer) error 103 | 104 | // Retry provides a predicate that allows stream middleware to replay the request 105 | // if it matches certain condition, e.g. returns special error code. Available functions are: 106 | // 107 | // Attempts() - limits the amount of retry attempts 108 | // ResponseCode() - returns http response code 109 | // IsNetworkError() - tests if response code is related to networking error 110 | // 111 | // Example of the predicate: 112 | // 113 | // `Attempts() <= 2 && ResponseCode() == 502` 114 | // 115 | func Retry(predicate string) optSetter { 116 | return func(s *Streamer) error { 117 | p, err := parseExpression(predicate) 118 | if err != nil { 119 | return err 120 | } 121 | s.retryPredicate = p 122 | return nil 123 | } 124 | } 125 | 126 | // Logger sets the logger that will be used by this middleware. 
127 | func Logger(l utils.Logger) optSetter { 128 | return func(s *Streamer) error { 129 | s.log = l 130 | return nil 131 | } 132 | } 133 | 134 | // ErrorHandler sets error handler of the server 135 | func ErrorHandler(h utils.ErrorHandler) optSetter { 136 | return func(s *Streamer) error { 137 | s.errHandler = h 138 | return nil 139 | } 140 | } 141 | 142 | // MaxRequestBodyBytes sets the maximum request body size in bytes 143 | func MaxRequestBodyBytes(m int64) optSetter { 144 | return func(s *Streamer) error { 145 | if m < 0 { 146 | return fmt.Errorf("max bytes should be >= 0 got %d", m) 147 | } 148 | s.maxRequestBodyBytes = m 149 | return nil 150 | } 151 | } 152 | 153 | // MaxRequestBody bytes sets the maximum request body to be stored in memory 154 | // stream middleware will serialize the excess to disk. 155 | func MemRequestBodyBytes(m int64) optSetter { 156 | return func(s *Streamer) error { 157 | if m < 0 { 158 | return fmt.Errorf("mem bytes should be >= 0 got %d", m) 159 | } 160 | s.memRequestBodyBytes = m 161 | return nil 162 | } 163 | } 164 | 165 | // MaxResponseBodyBytes sets the maximum request body size in bytes 166 | func MaxResponseBodyBytes(m int64) optSetter { 167 | return func(s *Streamer) error { 168 | if m < 0 { 169 | return fmt.Errorf("max bytes should be >= 0 got %d", m) 170 | } 171 | s.maxResponseBodyBytes = m 172 | return nil 173 | } 174 | } 175 | 176 | // MemResponseBodyBytes sets the maximum request body to be stored in memory 177 | // stream middleware will serialize the excess to disk. 178 | func MemResponseBodyBytes(m int64) optSetter { 179 | return func(s *Streamer) error { 180 | if m < 0 { 181 | return fmt.Errorf("mem bytes should be >= 0 got %d", m) 182 | } 183 | s.memResponseBodyBytes = m 184 | return nil 185 | } 186 | } 187 | 188 | // Wrap sets the next handler to be called by stream handler. 
189 | func (s *Streamer) Wrap(next http.Handler) error { 190 | s.next = next 191 | return nil 192 | } 193 | 194 | func (s *Streamer) ServeHTTP(w http.ResponseWriter, req *http.Request) { 195 | if err := s.checkLimit(req); err != nil { 196 | s.log.Infof("request body over limit: %v", err) 197 | s.errHandler.ServeHTTP(w, req, err) 198 | return 199 | } 200 | 201 | // Read the body while keeping limits in mind. This reader controls the maximum bytes 202 | // to read into memory and disk. This reader returns an error if the total request size exceeds the 203 | // prefefined MaxSizeBytes. This can occur if we got chunked request, in this case ContentLength would be set to -1 204 | // and the reader would be unbounded bufio in the http.Server 205 | body, err := multibuf.New(req.Body, multibuf.MaxBytes(s.maxRequestBodyBytes), multibuf.MemBytes(s.memRequestBodyBytes)) 206 | if err != nil || body == nil { 207 | s.errHandler.ServeHTTP(w, req, err) 208 | return 209 | } 210 | 211 | // Set request body to buffered reader that can replay the read and execute Seek 212 | // Note that we don't change the original request body as it's handled by the http server 213 | // and we don'w want to mess with standard library 214 | defer body.Close() 215 | 216 | // We need to set ContentLength based on known request size. 
The incoming request may have been 217 | // set without content length or using chunked TransferEncoding 218 | totalSize, err := body.Size() 219 | if err != nil { 220 | s.log.Errorf("failed to get size, err %v", err) 221 | s.errHandler.ServeHTTP(w, req, err) 222 | return 223 | } 224 | 225 | outreq := s.copyRequest(req, body, totalSize) 226 | 227 | attempt := 1 228 | for { 229 | // We create a special writer that will limit the response size, buffer it to disk if necessary 230 | writer, err := multibuf.NewWriterOnce(multibuf.MaxBytes(s.maxResponseBodyBytes), multibuf.MemBytes(s.memResponseBodyBytes)) 231 | if err != nil { 232 | s.errHandler.ServeHTTP(w, req, err) 233 | return 234 | } 235 | 236 | // We are mimicking http.ResponseWriter to replace writer with our special writer 237 | b := &bufferWriter{ 238 | header: make(http.Header), 239 | buffer: writer, 240 | } 241 | defer b.Close() 242 | 243 | s.next.ServeHTTP(b, outreq) 244 | 245 | var reader multibuf.MultiReader 246 | if b.expectBody(outreq) { 247 | rdr, err := writer.Reader() 248 | if err != nil { 249 | s.log.Errorf("failed to read response, err %v", err) 250 | s.errHandler.ServeHTTP(w, req, err) 251 | return 252 | } 253 | defer rdr.Close() 254 | reader = rdr 255 | } 256 | 257 | if (s.retryPredicate == nil || attempt > DefaultMaxRetryAttempts) || 258 | !s.retryPredicate(&context{r: req, attempt: attempt, responseCode: b.code, log: s.log}) { 259 | utils.CopyHeaders(w.Header(), b.Header()) 260 | w.WriteHeader(b.code) 261 | if reader != nil { 262 | io.Copy(w, reader) 263 | } 264 | return 265 | } 266 | 267 | attempt += 1 268 | if _, err := body.Seek(0, 0); err != nil { 269 | s.log.Errorf("Failed to rewind: error: %v", err) 270 | s.errHandler.ServeHTTP(w, req, err) 271 | return 272 | } 273 | outreq = s.copyRequest(req, body, totalSize) 274 | s.log.Infof("retry Request(%v %v) attempt %v", req.Method, req.URL, attempt) 275 | } 276 | } 277 | 278 | func (s *Streamer) copyRequest(req *http.Request, body io.ReadCloser, 
bodySize int64) *http.Request { 279 | o := *req 280 | o.URL = utils.CopyURL(req.URL) 281 | o.Header = make(http.Header) 282 | utils.CopyHeaders(o.Header, req.Header) 283 | o.ContentLength = bodySize 284 | // remove TransferEncoding that could have been previously set because we have transformed the request from chunked encoding 285 | o.TransferEncoding = []string{} 286 | // http.Transport will close the request body on any error, we are controlling the close process ourselves, so we override the closer here 287 | o.Body = ioutil.NopCloser(body) 288 | return &o 289 | } 290 | 291 | func (s *Streamer) checkLimit(req *http.Request) error { 292 | if s.maxRequestBodyBytes <= 0 { 293 | return nil 294 | } 295 | if req.ContentLength > s.maxRequestBodyBytes { 296 | return &multibuf.MaxSizeReachedError{MaxSize: s.maxRequestBodyBytes} 297 | } 298 | return nil 299 | } 300 | 301 | type bufferWriter struct { 302 | header http.Header 303 | code int 304 | buffer multibuf.WriterOnce 305 | } 306 | 307 | // RFC2616 #4.4 308 | func (b *bufferWriter) expectBody(r *http.Request) bool { 309 | if r.Method == "HEAD" { 310 | return false 311 | } 312 | if (b.code >= 100 && b.code < 200) || b.code == 204 || b.code == 304 { 313 | return false 314 | } 315 | if b.header.Get("Content-Length") == "" && b.header.Get("Transfer-Encoding") == "" { 316 | return false 317 | } 318 | if b.header.Get("Content-Length") == "0" { 319 | return false 320 | } 321 | return true 322 | } 323 | 324 | func (b *bufferWriter) Close() error { 325 | return b.buffer.Close() 326 | } 327 | 328 | func (b *bufferWriter) Header() http.Header { 329 | return b.header 330 | } 331 | 332 | func (b *bufferWriter) Write(buf []byte) (int, error) { 333 | return b.buffer.Write(buf) 334 | } 335 | 336 | // WriteHeader sets rw.Code. 
337 | func (b *bufferWriter) WriteHeader(code int) { 338 | b.code = code 339 | } 340 | 341 | type SizeErrHandler struct { 342 | } 343 | 344 | func (e *SizeErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) { 345 | if _, ok := err.(*multibuf.MaxSizeReachedError); ok { 346 | w.WriteHeader(http.StatusRequestEntityTooLarge) 347 | w.Write([]byte(http.StatusText(http.StatusRequestEntityTooLarge))) 348 | return 349 | } 350 | utils.DefaultHandler.ServeHTTP(w, req, err) 351 | } 352 | -------------------------------------------------------------------------------- /stream/stream_test.go: -------------------------------------------------------------------------------- 1 | package stream 2 | 3 | import ( 4 | "bufio" 5 | "crypto/tls" 6 | "fmt" 7 | "io/ioutil" 8 | "net" 9 | "net/http" 10 | "net/http/httptest" 11 | "os" 12 | "testing" 13 | 14 | "github.com/vulcand/oxy/forward" 15 | "github.com/vulcand/oxy/testutils" 16 | "github.com/vulcand/oxy/utils" 17 | 18 | . "gopkg.in/check.v1" 19 | ) 20 | 21 | func TestStream(t *testing.T) { TestingT(t) } 22 | 23 | type STSuite struct{} 24 | 25 | var _ = Suite(&STSuite{}) 26 | 27 | func (s *STSuite) TestSimple(c *C) { 28 | srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { 29 | w.Write([]byte("hello")) 30 | }) 31 | defer srv.Close() 32 | 33 | // forwarder will proxy the request to whatever destination 34 | fwd, err := forward.New() 35 | c.Assert(err, IsNil) 36 | 37 | // this is our redirect to server 38 | rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 39 | req.URL = testutils.ParseURI(srv.URL) 40 | fwd.ServeHTTP(w, req) 41 | }) 42 | 43 | // stream handler will forward requests to redirect 44 | st, err := New(rdr, Logger(utils.NewFileLogger(os.Stdout, utils.INFO))) 45 | c.Assert(err, IsNil) 46 | 47 | proxy := httptest.NewServer(st) 48 | defer proxy.Close() 49 | 50 | re, body, err := testutils.Get(proxy.URL) 51 | c.Assert(err, IsNil) 52 | c.Assert(re.StatusCode, 
Equals, http.StatusOK) 53 | c.Assert(string(body), Equals, "hello") 54 | } 55 | 56 | func (s *STSuite) TestChunkedEncodingSuccess(c *C) { 57 | var reqBody string 58 | var contentLength int64 59 | srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { 60 | body, err := ioutil.ReadAll(req.Body) 61 | c.Assert(err, IsNil) 62 | reqBody = string(body) 63 | contentLength = req.ContentLength 64 | w.Write([]byte("hello")) 65 | }) 66 | defer srv.Close() 67 | 68 | // forwarder will proxy the request to whatever destination 69 | fwd, err := forward.New() 70 | c.Assert(err, IsNil) 71 | 72 | // this is our redirect to server 73 | rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 74 | req.URL = testutils.ParseURI(srv.URL) 75 | fwd.ServeHTTP(w, req) 76 | }) 77 | 78 | // stream handler will forward requests to redirect 79 | st, err := New(rdr, Logger(utils.NewFileLogger(os.Stdout, utils.INFO))) 80 | c.Assert(err, IsNil) 81 | 82 | proxy := httptest.NewServer(st) 83 | defer proxy.Close() 84 | 85 | conn, err := net.Dial("tcp", testutils.ParseURI(proxy.URL).Host) 86 | c.Assert(err, IsNil) 87 | fmt.Fprintf(conn, "POST / HTTP/1.0\r\nTransfer-Encoding: chunked\r\n\r\n4\r\ntest\r\n5\r\ntest1\r\n5\r\ntest2\r\n0\r\n\r\n") 88 | status, err := bufio.NewReader(conn).ReadString('\n') 89 | 90 | c.Assert(reqBody, Equals, "testtest1test2") 91 | c.Assert(status, Equals, "HTTP/1.0 200 OK\r\n") 92 | c.Assert(contentLength, Equals, int64(len(reqBody))) 93 | } 94 | 95 | func (s *STSuite) TestChunkedEncodingLimitReached(c *C) { 96 | srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { 97 | w.Write([]byte("hello")) 98 | }) 99 | defer srv.Close() 100 | 101 | // forwarder will proxy the request to whatever destination 102 | fwd, err := forward.New() 103 | c.Assert(err, IsNil) 104 | 105 | // this is our redirect to server 106 | rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 107 | req.URL = 
testutils.ParseURI(srv.URL) 108 | fwd.ServeHTTP(w, req) 109 | }) 110 | 111 | // stream handler will forward requests to redirect 112 | st, err := New(rdr, Logger(utils.NewFileLogger(os.Stdout, utils.INFO)), MemRequestBodyBytes(4), MaxRequestBodyBytes(8)) 113 | c.Assert(err, IsNil) 114 | 115 | proxy := httptest.NewServer(st) 116 | defer proxy.Close() 117 | 118 | conn, err := net.Dial("tcp", testutils.ParseURI(proxy.URL).Host) 119 | c.Assert(err, IsNil) 120 | fmt.Fprintf(conn, "POST / HTTP/1.0\r\nTransfer-Encoding: chunked\r\n\r\n4\r\ntest\r\n5\r\ntest1\r\n5\r\ntest2\r\n0\r\n\r\n") 121 | status, err := bufio.NewReader(conn).ReadString('\n') 122 | 123 | c.Assert(status, Equals, "HTTP/1.0 413 Request Entity Too Large\r\n") 124 | } 125 | 126 | func (s *STSuite) TestRequestLimitReached(c *C) { 127 | srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { 128 | w.Write([]byte("hello")) 129 | }) 130 | defer srv.Close() 131 | 132 | // forwarder will proxy the request to whatever destination 133 | fwd, err := forward.New() 134 | c.Assert(err, IsNil) 135 | 136 | // this is our redirect to server 137 | rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 138 | req.URL = testutils.ParseURI(srv.URL) 139 | fwd.ServeHTTP(w, req) 140 | }) 141 | 142 | // stream handler will forward requests to redirect 143 | st, err := New(rdr, Logger(utils.NewFileLogger(os.Stdout, utils.INFO)), MaxRequestBodyBytes(4)) 144 | c.Assert(err, IsNil) 145 | 146 | proxy := httptest.NewServer(st) 147 | defer proxy.Close() 148 | 149 | re, _, err := testutils.Get(proxy.URL, testutils.Body("this request is too long")) 150 | c.Assert(err, IsNil) 151 | c.Assert(re.StatusCode, Equals, http.StatusRequestEntityTooLarge) 152 | } 153 | 154 | func (s *STSuite) TestResponseLimitReached(c *C) { 155 | srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { 156 | w.Write([]byte("hello, this response is too large")) 157 | }) 158 | defer srv.Close() 159 | 160 | 
// forwarder will proxy the request to whatever destination 161 | fwd, err := forward.New() 162 | c.Assert(err, IsNil) 163 | 164 | // this is our redirect to server 165 | rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 166 | req.URL = testutils.ParseURI(srv.URL) 167 | fwd.ServeHTTP(w, req) 168 | }) 169 | 170 | // stream handler will forward requests to redirect 171 | st, err := New(rdr, Logger(utils.NewFileLogger(os.Stdout, utils.INFO)), MaxResponseBodyBytes(4)) 172 | c.Assert(err, IsNil) 173 | 174 | proxy := httptest.NewServer(st) 175 | defer proxy.Close() 176 | 177 | re, _, err := testutils.Get(proxy.URL) 178 | c.Assert(err, IsNil) 179 | c.Assert(re.StatusCode, Equals, http.StatusInternalServerError) 180 | } 181 | 182 | func (s *STSuite) TestFileStreamingResponse(c *C) { 183 | srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { 184 | w.Write([]byte("hello, this response is too large to fit in memory")) 185 | }) 186 | defer srv.Close() 187 | 188 | // forwarder will proxy the request to whatever destination 189 | fwd, err := forward.New() 190 | c.Assert(err, IsNil) 191 | 192 | // this is our redirect to server 193 | rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 194 | req.URL = testutils.ParseURI(srv.URL) 195 | fwd.ServeHTTP(w, req) 196 | }) 197 | 198 | // stream handler will forward requests to redirect 199 | st, err := New(rdr, Logger(utils.NewFileLogger(os.Stdout, utils.INFO)), MemResponseBodyBytes(4)) 200 | c.Assert(err, IsNil) 201 | 202 | proxy := httptest.NewServer(st) 203 | defer proxy.Close() 204 | 205 | re, body, err := testutils.Get(proxy.URL) 206 | c.Assert(err, IsNil) 207 | c.Assert(re.StatusCode, Equals, http.StatusOK) 208 | c.Assert(string(body), Equals, "hello, this response is too large to fit in memory") 209 | } 210 | 211 | func (s *STSuite) TestCustomErrorHandler(c *C) { 212 | srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { 213 | 
w.Write([]byte("hello, this response is too large")) 214 | }) 215 | defer srv.Close() 216 | 217 | // forwarder will proxy the request to whatever destination 218 | fwd, err := forward.New() 219 | c.Assert(err, IsNil) 220 | 221 | // this is our redirect to server 222 | rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 223 | req.URL = testutils.ParseURI(srv.URL) 224 | fwd.ServeHTTP(w, req) 225 | }) 226 | 227 | // stream handler will forward requests to redirect 228 | errHandler := utils.ErrorHandlerFunc(func(w http.ResponseWriter, req *http.Request, err error) { 229 | w.WriteHeader(http.StatusTeapot) 230 | w.Write([]byte(http.StatusText(http.StatusTeapot))) 231 | }) 232 | st, err := New(rdr, Logger(utils.NewFileLogger(os.Stdout, utils.INFO)), MaxResponseBodyBytes(4), ErrorHandler(errHandler)) 233 | c.Assert(err, IsNil) 234 | 235 | proxy := httptest.NewServer(st) 236 | defer proxy.Close() 237 | 238 | re, _, err := testutils.Get(proxy.URL) 239 | c.Assert(err, IsNil) 240 | c.Assert(re.StatusCode, Equals, http.StatusTeapot) 241 | } 242 | 243 | func (s *STSuite) TestNotModified(c *C) { 244 | srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { 245 | w.WriteHeader(http.StatusNotModified) 246 | }) 247 | defer srv.Close() 248 | 249 | // forwarder will proxy the request to whatever destination 250 | fwd, err := forward.New() 251 | c.Assert(err, IsNil) 252 | 253 | // this is our redirect to server 254 | rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 255 | req.URL = testutils.ParseURI(srv.URL) 256 | fwd.ServeHTTP(w, req) 257 | }) 258 | 259 | // stream handler will forward requests to redirect 260 | st, err := New(rdr, Logger(utils.NewFileLogger(os.Stdout, utils.INFO))) 261 | c.Assert(err, IsNil) 262 | 263 | proxy := httptest.NewServer(st) 264 | defer proxy.Close() 265 | 266 | re, _, err := testutils.Get(proxy.URL) 267 | c.Assert(err, IsNil) 268 | c.Assert(re.StatusCode, Equals, http.StatusNotModified) 269 
| } 270 | 271 | func (s *STSuite) TestNoBody(c *C) { 272 | srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { 273 | w.WriteHeader(http.StatusOK) 274 | }) 275 | defer srv.Close() 276 | 277 | // forwarder will proxy the request to whatever destination 278 | fwd, err := forward.New() 279 | c.Assert(err, IsNil) 280 | 281 | // this is our redirect to server 282 | rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 283 | req.URL = testutils.ParseURI(srv.URL) 284 | fwd.ServeHTTP(w, req) 285 | }) 286 | 287 | // stream handler will forward requests to redirect 288 | st, err := New(rdr, Logger(utils.NewFileLogger(os.Stdout, utils.INFO))) 289 | c.Assert(err, IsNil) 290 | 291 | proxy := httptest.NewServer(st) 292 | defer proxy.Close() 293 | 294 | re, _, err := testutils.Get(proxy.URL) 295 | c.Assert(err, IsNil) 296 | c.Assert(re.StatusCode, Equals, http.StatusOK) 297 | } 298 | 299 | // Make sure that stream handler preserves TLS settings 300 | func (s *STSuite) TestPreservesTLS(c *C) { 301 | srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { 302 | w.WriteHeader(http.StatusOK) 303 | w.Write([]byte("ok")) 304 | }) 305 | defer srv.Close() 306 | 307 | // forwarder will proxy the request to whatever destination 308 | fwd, err := forward.New() 309 | c.Assert(err, IsNil) 310 | 311 | var t *tls.ConnectionState 312 | // this is our redirect to server 313 | rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 314 | t = req.TLS 315 | req.URL = testutils.ParseURI(srv.URL) 316 | fwd.ServeHTTP(w, req) 317 | }) 318 | 319 | // stream handler will forward requests to redirect 320 | st, err := New(rdr, Logger(utils.NewFileLogger(os.Stdout, utils.INFO))) 321 | c.Assert(err, IsNil) 322 | 323 | proxy := httptest.NewUnstartedServer(st) 324 | proxy.StartTLS() 325 | defer proxy.Close() 326 | 327 | re, _, err := testutils.Get(proxy.URL) 328 | c.Assert(err, IsNil) 329 | c.Assert(re.StatusCode, Equals, 
http.StatusOK) 330 | 331 | c.Assert(t, NotNil) 332 | } 333 | -------------------------------------------------------------------------------- /stream/threshold.go: -------------------------------------------------------------------------------- 1 | package stream 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | 7 | "github.com/vulcand/oxy/utils" 8 | "github.com/vulcand/predicate" 9 | ) 10 | 11 | func IsValidExpression(expr string) bool { 12 | _, err := parseExpression(expr) 13 | return err == nil 14 | } 15 | 16 | type context struct { 17 | r *http.Request 18 | attempt int 19 | responseCode int 20 | log utils.Logger 21 | } 22 | 23 | type hpredicate func(*context) bool 24 | 25 | // Parses expression in the go language into Failover predicates 26 | func parseExpression(in string) (hpredicate, error) { 27 | p, err := predicate.NewParser(predicate.Def{ 28 | Operators: predicate.Operators{ 29 | AND: and, 30 | OR: or, 31 | EQ: eq, 32 | NEQ: neq, 33 | LT: lt, 34 | GT: gt, 35 | LE: le, 36 | GE: ge, 37 | }, 38 | Functions: map[string]interface{}{ 39 | "RequestMethod": requestMethod, 40 | "IsNetworkError": isNetworkError, 41 | "Attempts": attempts, 42 | "ResponseCode": responseCode, 43 | }, 44 | }) 45 | if err != nil { 46 | return nil, err 47 | } 48 | out, err := p.Parse(in) 49 | if err != nil { 50 | return nil, err 51 | } 52 | pr, ok := out.(hpredicate) 53 | if !ok { 54 | return nil, fmt.Errorf("expected predicate, got %T", out) 55 | } 56 | return pr, nil 57 | } 58 | 59 | type toString func(c *context) string 60 | type toInt func(c *context) int 61 | 62 | // RequestMethod returns mapper of the request to its method e.g. 
POST 63 | func requestMethod() toString { 64 | return func(c *context) string { 65 | return c.r.Method 66 | } 67 | } 68 | 69 | // Attempts returns mapper of the request to the number of proxy attempts 70 | func attempts() toInt { 71 | return func(c *context) int { 72 | return c.attempt 73 | } 74 | } 75 | 76 | // ResponseCode returns mapper of the request to the last response code, returns 0 if there was no response code. 77 | func responseCode() toInt { 78 | return func(c *context) int { 79 | return c.responseCode 80 | } 81 | } 82 | 83 | // IsNetworkError returns a predicate that returns true if last attempt ended with network error. 84 | func isNetworkError() hpredicate { 85 | return func(c *context) bool { 86 | return c.responseCode == http.StatusBadGateway || c.responseCode == http.StatusGatewayTimeout 87 | } 88 | } 89 | 90 | // and returns predicate by joining the passed predicates with logical 'and' 91 | func and(fns ...hpredicate) hpredicate { 92 | return func(c *context) bool { 93 | for _, fn := range fns { 94 | if !fn(c) { 95 | return false 96 | } 97 | } 98 | return true 99 | } 100 | } 101 | 102 | // or returns predicate by joining the passed predicates with logical 'or' 103 | func or(fns ...hpredicate) hpredicate { 104 | return func(c *context) bool { 105 | for _, fn := range fns { 106 | if fn(c) { 107 | return true 108 | } 109 | } 110 | return false 111 | } 112 | } 113 | 114 | // not creates negation of the passed predicate 115 | func not(p hpredicate) hpredicate { 116 | return func(c *context) bool { 117 | return !p(c) 118 | } 119 | } 120 | 121 | // eq returns predicate that tests for equality of the value of the mapper and the constant 122 | func eq(m interface{}, value interface{}) (hpredicate, error) { 123 | switch mapper := m.(type) { 124 | case toString: 125 | return stringEQ(mapper, value) 126 | case toInt: 127 | return intEQ(mapper, value) 128 | } 129 | return nil, fmt.Errorf("unsupported argument: %T", m) 130 | } 131 | 132 | // neq returns 
predicate that tests for inequality of the value of the mapper and the constant
133 | func neq(m interface{}, value interface{}) (hpredicate, error) {
134 | 	p, err := eq(m, value)
135 | 	if err != nil {
136 | 		return nil, err
137 | 	}
138 | 	return not(p), nil
139 | }
140 | 
141 | // lt returns predicate that tests that value of the mapper function is less than the constant
142 | func lt(m interface{}, value interface{}) (hpredicate, error) {
143 | 	switch mapper := m.(type) {
144 | 	case toInt:
145 | 		return intLT(mapper, value)
146 | 	}
147 | 	return nil, fmt.Errorf("unsupported argument: %T", m)
148 | }
149 | 
150 | // le returns predicate that tests that value of the mapper function is less than or equal to the constant
151 | func le(m interface{}, value interface{}) (hpredicate, error) {
152 | 	l, err := lt(m, value)
153 | 	if err != nil {
154 | 		return nil, err
155 | 	}
156 | 	e, err := eq(m, value)
157 | 	if err != nil {
158 | 		return nil, err
159 | 	}
160 | 	return func(c *context) bool {
161 | 		return l(c) || e(c)
162 | 	}, nil
163 | }
164 | 
165 | // gt returns predicate that tests that value of the mapper function is greater than the constant
166 | func gt(m interface{}, value interface{}) (hpredicate, error) {
167 | 	switch mapper := m.(type) {
168 | 	case toInt:
169 | 		return intGT(mapper, value)
170 | 	}
171 | 	return nil, fmt.Errorf("unsupported argument: %T", m)
172 | }
173 | 
174 | // ge returns predicate that tests that value of the mapper function is greater than or equal to the constant
175 | func ge(m interface{}, value interface{}) (hpredicate, error) {
176 | 	g, err := gt(m, value)
177 | 	if err != nil {
178 | 		return nil, err
179 | 	}
180 | 	e, err := eq(m, value)
181 | 	if err != nil {
182 | 		return nil, err
183 | 	}
184 | 	return func(c *context) bool {
185 | 		return g(c) || e(c)
186 | 	}, nil
187 | }
188 | 
189 | // stringEQ returns predicate that tests that the mapped string equals val, which must be a string
190 | func stringEQ(m toString, val interface{}) (hpredicate, error) {
191 | 	value, ok := val.(string)
192 | 	if !ok {
193 | 		return nil, fmt.Errorf("expected string, got %T", val)
194 | 	}
195 | 	return func(c *context) bool {
196 | 		return m(c) == value
197 | 	}, nil
198 | }
199 | 
200 | // intEQ returns predicate that tests that the mapped int equals val, which must be an int
201 | func intEQ(m toInt, val interface{}) (hpredicate, error) {
202 | 	value, ok := val.(int)
203 | 	if !ok {
204 | 		return nil, fmt.Errorf("expected int, got %T", val)
205 | 	}
206 | 	return func(c *context) bool {
207 | 		return m(c) == value
208 | 	}, nil
209 | }
210 | 
211 | // intLT returns predicate that tests that the mapped int is less than val, which must be an int
212 | func intLT(m toInt, val interface{}) (hpredicate, error) {
213 | 	value, ok := val.(int)
214 | 	if !ok {
215 | 		return nil, fmt.Errorf("expected int, got %T", val)
216 | 	}
217 | 	return func(c *context) bool {
218 | 		return m(c) < value
219 | 	}, nil
220 | }
221 | 
222 | // intGT returns predicate that tests that the mapped int is greater than val, which must be an int
223 | func intGT(m toInt, val interface{}) (hpredicate, error) {
224 | 	value, ok := val.(int)
225 | 	if !ok {
226 | 		return nil, fmt.Errorf("expected int, got %T", val)
227 | 	}
228 | 	return func(c *context) bool {
229 | 		return m(c) > value
230 | 	}, nil
231 | }
232 | 
--------------------------------------------------------------------------------
/testutils/utils.go:
--------------------------------------------------------------------------------
1 | // Package testutils contains helpers for HTTP tests: throwaway servers and a configurable client request.
2 | package testutils
3 | 
4 | import (
5 | 	"crypto/tls"
6 | 	"fmt"
7 | 	"io/ioutil"
8 | 	"net/http"
9 | 	"net/http/httptest"
10 | 	"net/url"
11 | 	"strings"
12 | 
13 | 	"github.com/vulcand/oxy/utils"
14 | )
15 | 
16 | // NewHandler starts a test server that serves every request with handler.
17 | func NewHandler(handler http.HandlerFunc) *httptest.Server {
18 | 	return httptest.NewServer(handler)
19 | }
20 | 
21 | // NewResponder starts a test server that replies to every request with response as the body.
22 | func NewResponder(response string) *httptest.Server {
23 | 	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
24 | 		w.Write([]byte(response))
25 | 	}))
26 | }
27 | 
28 | // ParseURI is the version of url.ParseRequestURI that panics if incorrect, helpful to shorten the tests
29 | func ParseURI(uri string) *url.URL {
30 | 	out, err := url.ParseRequestURI(uri)
31 | 	if err != nil {
32 | 		panic(err)
33 | 	}
34 | 	return out
35 | }
36 | 
37 | // ReqOpts collects the settings applied by ReqOption values in MakeRequest.
38 | type ReqOpts struct {
39 | 	Host    string
40 | 	Method  string
41 | 	Body    string
42 | 	Headers http.Header
43 | 	Auth    *utils.BasicAuth
44 | }
45 | 
46 | // ReqOption configures the request built by MakeRequest.
47 | type ReqOption func(o *ReqOpts) error
48 | 
49 | // Method sets the request method; MakeRequest uses GET when it is left empty.
50 | func Method(m string) ReqOption {
51 | 	return func(o *ReqOpts) error {
52 | 		o.Method = m
53 | 		return nil
54 | 	}
55 | }
56 | 
57 | // Host overrides the Host of the request.
58 | func Host(h string) ReqOption {
59 | 	return func(o *ReqOpts) error {
60 | 		o.Host = h
61 | 		return nil
62 | 	}
63 | }
64 | 
65 | // Body sets the request body.
66 | func Body(b string) ReqOption {
67 | 	return func(o *ReqOpts) error {
68 | 		o.Body = b
69 | 		return nil
70 | 	}
71 | }
72 | 
73 | // Header adds val to the request header name; repeated calls accumulate values.
74 | func Header(name, val string) ReqOption {
75 | 	return func(o *ReqOpts) error {
76 | 		if o.Headers == nil {
77 | 			o.Headers = make(http.Header)
78 | 		}
79 | 		o.Headers.Add(name, val)
80 | 		return nil
81 | 	}
82 | }
83 | 
84 | // Headers adds every value of h to the request headers.
85 | func Headers(h http.Header) ReqOption {
86 | 	return func(o *ReqOpts) error {
87 | 		if o.Headers == nil {
88 | 			o.Headers = make(http.Header)
89 | 		}
90 | 		utils.CopyHeaders(o.Headers, h)
91 | 		return nil
92 | 	}
93 | }
94 | 
95 | // BasicAuth sets an Authorization header carrying the given basic credentials.
96 | func BasicAuth(username, password string) ReqOption {
97 | 	return func(o *ReqOpts) error {
98 | 		o.Auth = &utils.BasicAuth{
99 | 			Username: username,
100 | 			Password: password,
101 | 		}
102 | 		return nil
103 | 	}
104 | }
105 | 
106 | // MakeRequest sends a request to url configured by opts and returns the response along
107 | // with its fully read body; the body is closed before returning. Redirects are not
108 | // followed, keep-alives are disabled and, for https URLs, certificates are not verified.
109 | func MakeRequest(url string, opts ...ReqOption) (*http.Response, []byte, error) {
110 | 	o := &ReqOpts{}
111 | 	for _, s := range opts {
112 | 		if err := s(o); err != nil {
113 | 			return nil, nil, err
114 | 		}
115 | 	}
116 | 
117 | 	if o.Method == "" {
118 | 		o.Method = "GET"
119 | 	}
120 | 	request, err := http.NewRequest(o.Method, url, strings.NewReader(o.Body))
121 | 	if err != nil {
122 | 		// a malformed method or URL must be reported, not turned into a nil-pointer panic below
123 | 		return nil, nil, err
124 | 	}
125 | 	if o.Headers != nil {
126 | 		utils.CopyHeaders(request.Header, o.Headers)
127 | 	}
128 | 
129 | 	if o.Auth != nil {
130 | 		request.Header.Set("Authorization", o.Auth.String())
131 | 	}
132 | 
133 | 	if len(o.Host) != 0 {
134 | 		request.Host = o.Host
135 | 	}
136 | 
137 | 	tr := &http.Transport{
138 | 		DisableKeepAlives: true,
139 | 	}
140 | 	if strings.HasPrefix(url, "https") {
141 | 		// test servers use self-signed certificates
142 | 		tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
143 | 	}
144 | 
145 | 	client := &http.Client{
146 | 		Transport: tr,
147 | 		CheckRedirect: func(req *http.Request, via []*http.Request) error {
148 | 			return fmt.Errorf("No redirects")
149 | 		},
150 | 	}
151 | 	response, err := client.Do(request)
152 | 	if err != nil {
153 | 		return response, nil, err
154 | 	}
155 | 	defer response.Body.Close()
156 | 	bodyBytes, err := ioutil.ReadAll(response.Body)
157 | 	return response, bodyBytes, err
158 | }
159 | 
160 | // Get is MakeRequest with the method forced to GET, overriding any Method option in opts.
161 | func Get(url string, opts ...ReqOption) (*http.Response, []byte, error) {
162 | 	// cap the slice so append copies instead of writing into the caller's backing array
163 | 	opts = append(opts[:len(opts):len(opts)], Method("GET"))
164 | 	return MakeRequest(url, opts...)
165 | }
166 | 
--------------------------------------------------------------------------------
/trace/trace.go:
--------------------------------------------------------------------------------
1 | // Package trace implements structured logging of requests
2 | package trace
3 | 
4 | import (
5 | 	"crypto/tls"
6 | 	"encoding/json"
7 | 	"fmt"
8 | 	"io"
9 | 	"net/http"
10 | 	"strconv"
11 | 	"time"
12 | 
13 | 	"github.com/vulcand/oxy/utils"
14 | )
15 | 
16 | // Option is a functional option setter for Tracer
17 | type Option func(*Tracer) error
18 | 
19 | // ErrorHandler is a functional argument that sets error handler of the server
20 | func ErrorHandler(h utils.ErrorHandler) Option {
21 | 	return func(t *Tracer) error {
22 | 		t.errHandler = h
23 | 		return nil
24 | 	}
25 | }
26 | 
27 | // RequestHeaders adds request headers to capture
28 | func RequestHeaders(headers ...string) Option {
29 | 	return func(t *Tracer) error {
30 | 		t.reqHeaders = append(t.reqHeaders, headers...)
31 | 		return nil
32 | 	}
33 | }
34 | 
35 | // ResponseHeaders adds response headers to capture
36 | func ResponseHeaders(headers ...string) Option {
37 | 	return func(t *Tracer) error {
38 | 		t.respHeaders = append(t.respHeaders, headers...)
39 | 		return nil
40 | 	}
41 | }
42 | 
43 | // Logger sets optional logger for trace used to report errors
44 | func Logger(l utils.Logger) Option {
45 | 	return func(t *Tracer) error {
46 | 		t.log = l
47 | 		return nil
48 | 	}
49 | }
50 | 
51 | // Tracer records request and response emitting JSON structured data to the output
52 | type Tracer struct {
53 | 	errHandler  utils.ErrorHandler
54 | 	next        http.Handler
55 | 	reqHeaders  []string
56 | 	respHeaders []string
57 | 	writer      io.Writer
58 | 	log         utils.Logger
59 | }
60 | 
61 | // New creates a new Tracer middleware that emits all the request/response information in structured format
62 | // to writer and passes the request to the next handler. It can optionally capture request and response headers,
63 | // see RequestHeaders and ResponseHeaders options for details.
64 | func New(next http.Handler, writer io.Writer, opts ...Option) (*Tracer, error) {
65 | 	t := &Tracer{
66 | 		writer: writer,
67 | 		next:   next,
68 | 	}
69 | 	for _, o := range opts {
70 | 		if err := o(t); err != nil {
71 | 			return nil, err
72 | 		}
73 | 	}
74 | 	if t.errHandler == nil {
75 | 		t.errHandler = utils.DefaultHandler
76 | 	}
77 | 	if t.log == nil {
78 | 		t.log = utils.NullLogger
79 | 	}
80 | 	return t, nil
81 | }
82 | 
83 | // ServeHTTP passes the request to the next handler and, once it returns, writes one
84 | // JSON-encoded Record describing the exchange to the tracer's writer. Encoding or write
85 | // failures are logged and do not affect the response.
86 | func (t *Tracer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
87 | 	start := time.Now()
88 | 	pw := &utils.ProxyWriter{W: w}
89 | 	t.next.ServeHTTP(pw, req)
90 | 
91 | 	l := t.newRecord(req, pw, time.Since(start))
92 | 	if err := json.NewEncoder(t.writer).Encode(l); err != nil {
93 | 		t.log.Errorf("Failed to marshal request: %v", err)
94 | 	}
95 | }
96 | 
97 | // newRecord builds the Record for one exchange; diff is the time spent in the next handler.
98 | func (t *Tracer) newRecord(req *http.Request, pw *utils.ProxyWriter, diff time.Duration) *Record {
99 | 	return &Record{
100 | 		Request: Request{
101 | 			Method:    req.Method,
102 | 			URL:       req.URL.String(),
103 | 			TLS:       newTLS(req),
104 | 			BodyBytes: bodyBytes(req.Header),
105 | 			Headers:   captureHeaders(req.Header, t.reqHeaders),
106 | 		},
107 | 		Response: Response{
108 | 			Code:      pw.StatusCode(),
109 | 			BodyBytes: bodyBytes(pw.Header()),
110 | 			// duration expressed in (fractional) milliseconds
111 | 			Roundtrip: float64(diff) / float64(time.Millisecond),
112 | 			Headers:   captureHeaders(pw.Header(), t.respHeaders),
113 | 		},
114 | 	}
115 | }
116 | 
117 | // newTLS summarizes the TLS state of req, or returns nil when req was not received over TLS.
118 | func newTLS(req *http.Request) *TLS {
119 | 	if req.TLS == nil {
120 | 		return nil
121 | 	}
122 | 	return &TLS{
123 | 		Version:     versionToString(req.TLS.Version),
124 | 		Resume:      req.TLS.DidResume,
125 | 		CipherSuite: csToString(req.TLS.CipherSuite),
126 | 		Server:      req.TLS.ServerName,
127 | 	}
128 | }
129 | 
130 | // captureHeaders returns the values of the configured headers found in in, or nil when no
131 | // headers are configured or in is nil. Names are canonicalized before lookup, so
132 | // "x-request-id" matches "X-Request-Id"; a name configured twice is captured once.
133 | func captureHeaders(in http.Header, headers []string) http.Header {
134 | 	if len(headers) == 0 || in == nil {
135 | 		return nil
136 | 	}
137 | 	out := make(http.Header, len(headers))
138 | 	for _, h := range headers {
139 | 		name := http.CanonicalHeaderKey(h)
140 | 		vals, ok := in[name]
141 | 		if !ok {
142 | 			// maps filled in directly (not via Set/Add) may hold non-canonical keys
143 | 			vals = in[h]
144 | 		}
145 | 		if len(vals) == 0 || len(out[name]) != 0 {
146 | 			continue
147 | 		}
148 | 		out[name] = append([]string(nil), vals...)
149 | 	}
150 | 	return out
151 | }
152 | 
153 | // Record represents a structured request and response record
154 | type Record struct {
155 | 	Request  Request  `json:"request"`
156 | 	Response Response `json:"response"`
157 | }
158 | 
159 | // Request contains information about an HTTP request
160 | type Request struct {
161 | 	Method    string      `json:"method"`            // Method - request method
162 | 	BodyBytes int64       `json:"body_bytes"`        // BodyBytes - declared Content-Length of the request body, 0 if absent
163 | 	URL       string      `json:"url"`               // URL - Request URL
164 | 	Headers   http.Header `json:"headers,omitempty"` // Headers - optional request headers, will be recorded if configured
165 | 	TLS       *TLS        `json:"tls,omitempty"`     // TLS - optional TLS record, will be recorded if it's a TLS connection
166 | }
167 | 
168 | // Response contains information about an HTTP response
169 | type Response struct {
170 | 	Code      int         `json:"code"`              // Code - response status code
171 | 	Roundtrip float64     `json:"roundtrip"`         // Roundtrip - round trip time in milliseconds
172 | 	Headers   http.Header `json:"headers,omitempty"` // Headers - optional headers, will be recorded if configured
173 | 	BodyBytes int64       `json:"body_bytes"`        // BodyBytes - declared Content-Length of the response body, 0 if absent
174 | }
175 | 
176 | // TLS contains information about this TLS connection
177 | type TLS struct {
178 | 	Version     string `json:"version"`      // Version - TLS version
179 | 	Resume      bool   `json:"resume"`       // Resume tells if the session has been re-used (session tickets)
180 | 	CipherSuite string `json:"cipher_suite"` // CipherSuite contains cipher suite used for this connection
181 | 	Server      string `json:"server"`       // Server contains server name used in SNI
182 | }
183 | 
184 | // versionToString returns a short name for a TLS protocol version, or "unknown: <hex>".
185 | func versionToString(v uint16) string {
186 | 	switch v {
187 | 	case tls.VersionSSL30:
188 | 		return "SSL30"
189 | 	case tls.VersionTLS10:
190 | 		return "TLS10"
191 | 	case tls.VersionTLS11:
192 | 		return "TLS11"
193 | 	case tls.VersionTLS12:
194 | 		return "TLS12"
195 | 	}
196 | 	return fmt.Sprintf("unknown: %x", v)
197 | }
198 | 
199 | // csToString returns the constant name of a cipher suite ID, or "unknown: <hex>".
200 | func csToString(cs uint16) string {
201 | 	switch cs {
202 | 	case tls.TLS_RSA_WITH_RC4_128_SHA:
203 | 		return "TLS_RSA_WITH_RC4_128_SHA"
204 | 	case tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA:
205 | 		return "TLS_RSA_WITH_3DES_EDE_CBC_SHA"
206 | 	case tls.TLS_RSA_WITH_AES_128_CBC_SHA:
207 | 		return "TLS_RSA_WITH_AES_128_CBC_SHA"
208 | 	case tls.TLS_RSA_WITH_AES_256_CBC_SHA:
209 | 		return "TLS_RSA_WITH_AES_256_CBC_SHA"
210 | 	case tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA:
211 | 		return "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA"
212 | 	case tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA:
213 | 		return "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA"
214 | 	case tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA:
215 | 		return "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA"
216 | 	case tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA:
217 | 		return "TLS_ECDHE_RSA_WITH_RC4_128_SHA"
218 | 	case tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA:
219 | 		return "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA"
220 | 	case tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA:
221 | 		return "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA"
222 | 	case tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA:
223 | 		return "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA"
224 | 	case tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
225 | 		return "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"
226 | 	case tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
227 | 		return "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"
228 | 	}
229 | 	return fmt.Sprintf("unknown: %x", cs)
230 | }
231 | 
232 | // bodyBytes returns the Content-Length declared in h, or 0 when it is absent or malformed.
233 | func bodyBytes(h http.Header) int64 {
234 | 	length := h.Get("Content-Length")
235 | 	if length == "" {
236 | 		return 0
237 | 	}
238 | 	// 64-bit parse: bitSize 0 would reject lengths above 2 GiB on 32-bit platforms
239 | 	n, err := strconv.ParseInt(length, 10, 64)
240 | 	if err != nil {
241 | 		return 0
242 | 	}
243 | 	return n
244 | }
245 | 
--------------------------------------------------------------------------------
/trace/trace_test.go:
--------------------------------------------------------------------------------
1 | package trace
2 | 
3 | import (
4 | 	"bufio"
5 | 	"bytes"
6 | 	"crypto/tls"
7 | 	"encoding/json"
8 | 	"fmt"
9 | 	"net/http"
10 | 	"net/http/httptest"
11 | 	"net/url"
12 | 	"testing"
13 | 
14 | 	"github.com/vulcand/oxy/testutils"
15 | 	"github.com/vulcand/oxy/utils"
16 | 
17 | 	. "gopkg.in/check.v1"
18 | )
19 | 
20 | func TestTrace(t *testing.T) { TestingT(t) }
21 | 
22 | type TraceSuite struct{}
23 | 
24 | var _ = Suite(&TraceSuite{})
25 | 
26 | func (s *TraceSuite) TestTraceSimple(c *C) {
27 | 	handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
28 | 		w.Header().Set("Content-Length", "5")
29 | 		w.Write([]byte("hello"))
30 | 	})
31 | 	buf := &bytes.Buffer{}
32 | 	l := utils.NewFileLogger(buf, utils.INFO)
33 | 
34 | 	trace := &bytes.Buffer{}
35 | 
36 | 	t, err := New(handler, trace, Logger(l))
37 | 	c.Assert(err, IsNil)
38 | 
39 | 	srv := httptest.NewServer(t)
40 | 	defer srv.Close()
41 | 
42 | 	re, _, err := testutils.MakeRequest(srv.URL+"/hello", testutils.Method("POST"), testutils.Body("123456"))
43 | 	c.Assert(err, IsNil)
44 | 	c.Assert(re.StatusCode, Equals, http.StatusOK)
45 | 
46 | 	var r *Record
47 | 	c.Assert(json.Unmarshal(trace.Bytes(), &r), IsNil)
48 | 
49 | 	c.Assert(r.Request.Method, Equals, "POST")
50 | 	c.Assert(r.Request.URL, Equals, "/hello")
51 | 	c.Assert(r.Response.Code, Equals, http.StatusOK)
52 | 	c.Assert(r.Request.BodyBytes, Equals, int64(6))
53 | 	c.Assert(r.Response.Roundtrip, Not(Equals), float64(0))
54 | 	c.Assert(r.Response.BodyBytes, Equals, int64(5))
55 | }
56 | 
57 | func (s *TraceSuite) TestTraceCaptureHeaders(c *C) {
58 | 	respHeaders := http.Header{
59 | 		"X-Re-1": []string{"6", "7"},
60 | 		"X-Re-2": []string{"2", "3"},
61 | 	}
62 | 
63 | 	handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
64 | 		utils.CopyHeaders(w.Header(), respHeaders)
65 | 		w.Write([]byte("hello"))
66 | 	})
67 | 	buf := &bytes.Buffer{}
68 | 	l := utils.NewFileLogger(buf, utils.INFO)
69 | 
70 | 	trace := &bytes.Buffer{}
71 | 
72 | 	t, err := New(handler, trace, Logger(l), RequestHeaders("X-Req-B", "X-Req-A"), ResponseHeaders("X-Re-1", "X-Re-2"))
73 | 	c.Assert(err, IsNil)
74 | 
75 | 	srv := httptest.NewServer(t)
76 | 	defer srv.Close()
77 | 
78 | 	reqHeaders := http.Header{"X-Req-A": []string{"1", "2"}, "X-Req-B": []string{"3", "4"}}
79 | 	re, _, err := testutils.Get(srv.URL+"/hello", testutils.Headers(reqHeaders))
80 | 	c.Assert(err, IsNil)
81 | 	c.Assert(re.StatusCode, Equals, http.StatusOK)
82 | 
83 | 	var r *Record
84 | 	c.Assert(json.Unmarshal(trace.Bytes(), &r), IsNil)
85 | 
86 | 	c.Assert(r.Request.Headers, DeepEquals, reqHeaders)
87 | 	c.Assert(r.Response.Headers, DeepEquals, respHeaders)
88 | }
89 | 
90 | func (s *TraceSuite) TestTraceTLS(c *C) {
91 | 	handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
92 | 		w.Write([]byte("hello"))
93 | 	})
94 | 	buf := &bytes.Buffer{}
95 | 	l := utils.NewFileLogger(buf, utils.INFO)
96 | 
97 | 	trace := &bytes.Buffer{}
98 | 
99 | 	t, err := New(handler, trace, Logger(l))
100 | 	c.Assert(err, IsNil)
101 | 
102 | 	srv := httptest.NewUnstartedServer(t)
103 | 	srv.StartTLS()
104 | 	defer srv.Close()
105 | 
106 | 	config := &tls.Config{
107 | 		InsecureSkipVerify: true,
108 | 	}
109 | 
110 | 	u, err := url.Parse(srv.URL)
111 | 	c.Assert(err, IsNil)
112 | 
113 | 	conn, err := tls.Dial("tcp", u.Host, config)
114 | 	c.Assert(err, IsNil)
115 | 
116 | 	fmt.Fprintf(conn, "GET / HTTP/1.0\r\n\r\n")
117 | 	status, err := bufio.NewReader(conn).ReadString('\n')
118 | 	c.Assert(err, IsNil)
119 | 	c.Assert(status, Equals, "HTTP/1.0 200 OK\r\n")
120 | 	state := conn.ConnectionState()
121 | 	conn.Close()
122 | 
123 | 	var r *Record
124 | 	c.Assert(json.Unmarshal(trace.Bytes(), &r), IsNil)
125 | 	c.Assert(r.Request.TLS.Version, Equals, versionToString(state.Version))
126 | }
127 | 
128 | // configured header names are matched regardless of their case
129 | func (s *TraceSuite) TestCaptureHeadersCanonicalizesNames(c *C) {
130 | 	in := http.Header{"X-Req-A": []string{"1", "2"}}
131 | 	c.Assert(captureHeaders(in, []string{"x-req-a"}), DeepEquals, http.Header{"X-Req-A": []string{"1", "2"}})
132 | 	c.Assert(captureHeaders(in, nil), IsNil)
133 | }
134 | 
--------------------------------------------------------------------------------
/utils/auth.go:
--------------------------------------------------------------------------------
1 | package utils
2 | 
3 | import (
4 | 	"encoding/base64"
5 | 	"fmt"
6 | 	"strings"
7 | )
8 | 
9 | // BasicAuth holds the credentials carried by an HTTP Basic Authorization header.
10 | type BasicAuth struct {
11 | 	Username string
12 | 	Password string
13 | }
14 | 
15 | // String encodes the credentials as an Authorization header value: "Basic " followed by
16 | // the standard base64 encoding of "username:password".
17 | func (ba *BasicAuth) String() string {
18 | 	encoded := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", ba.Username, ba.Password)))
19 | 	return fmt.Sprintf("Basic %s", encoded)
20 | }
21 | 
22 | // ParseAuthHeader parses an Authorization header value of the form "Basic <base64(user:pass)>".
23 | // The scheme is matched case-insensitively and everything after the first ':' of the decoded
24 | // credentials is the password.
25 | // NOTE(review): the error messages echo the raw header, i.e. the encoded credentials; be careful
26 | // about logging them.
27 | func ParseAuthHeader(header string) (*BasicAuth, error) {
28 | 	values := strings.Fields(header)
29 | 	if len(values) != 2 {
30 | 		// the header is an argument, not the format: a '%' in it must not be read as a verb
31 | 		return nil, fmt.Errorf("Failed to parse header '%s'", header)
32 | 	}
33 | 
34 | 	authType := strings.ToLower(values[0])
35 | 	if authType != "basic" {
36 | 		return nil, fmt.Errorf("Expected basic auth type, got '%s'", authType)
37 | 	}
38 | 
39 | 	decoded, err := base64.StdEncoding.DecodeString(values[1])
40 | 	if err != nil {
41 | 		return nil, fmt.Errorf("Failed to parse header '%s', base64 failed: %s", header, err)
42 | 	}
43 | 
44 | 	values = strings.SplitN(string(decoded), ":", 2)
45 | 	if len(values) != 2 {
46 | 		return nil, fmt.Errorf("Failed to parse header '%s', expected separator ':'", header)
47 | 	}
48 | 	return &BasicAuth{Username: values[0], Password: values[1]}, nil
49 | }
50 | 
--------------------------------------------------------------------------------
/utils/auth_test.go:
--------------------------------------------------------------------------------
1 | package utils
2 | 
3 | import (
4 | 	. "gopkg.in/check.v1"
5 | )
6 | 
7 | type AuthSuite struct {
8 | }
9 | 
10 | var _ = Suite(&AuthSuite{})
11 | 
12 | // TestParseBadHeaders makes sure malformed headers are rejected with an
13 | // error (and without a panic) instead of yielding a username and password.
14 | func (s *AuthSuite) TestParseBadHeaders(c *C) {
15 | 	headers := []string{
16 | 		// just empty string
17 | 		"",
18 | 		// missing auth type
19 | 		"justplainstring",
20 | 		// unknown auth type
21 | 		"Whut justplainstring",
22 | 		// invalid base64
23 | 		"Basic Shmasic",
24 | 		// valid base64, but the decoded text has no ':' separator
25 | 		"Basic YW55IGNhcm5hbCBwbGVhcw==",
26 | 	}
27 | 	for _, h := range headers {
28 | 		_, err := ParseAuthHeader(h)
29 | 		c.Assert(err, NotNil)
30 | 	}
31 | }
32 | 
33 | // TestParseSuccess checks that well-formed headers, including the ones produced
34 | // by BasicAuth.String, are parsed into the expected credentials.
35 | func (s *AuthSuite) TestParseSuccess(c *C) {
36 | 	headers := []struct {
37 | 		Header   string
38 | 		Expected BasicAuth
39 | 	}{
40 | 		{
41 | 			"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==",
42 | 			BasicAuth{Username: "Aladdin", Password: "open sesame"},
43 | 		},
44 | 		// Make sure that String() produces valid header
45 | 		{
46 | 			(&BasicAuth{Username: "Alice", Password: "Here's bob"}).String(),
47 | 			BasicAuth{Username: "Alice", Password: "Here's bob"},
48 | 		},
49 | 		// empty pass
50 | 		{
51 | 			"Basic QWxhZGRpbjo=",
52 | 			BasicAuth{Username: "Aladdin", Password: ""},
53 | 		},
54 | 	}
55 | 	for _, h := range headers {
56 | 		auth, err := ParseAuthHeader(h.Header)
57 | 		c.Assert(err, IsNil)
58 | 		c.Assert(auth.Username, Equals, h.Expected.Username)
59 | 		c.Assert(auth.Password, Equals, h.Expected.Password)
60 | 	}
61 | }
62 | 
--------------------------------------------------------------------------------
/utils/handler.go:
--------------------------------------------------------------------------------
1 | package utils
2 | 
3 | import (
4 | 	"io"
5 | 	"net"
6 | 	"net/http"
7 | )
8 | 
9 | // ErrorHandler renders an error that occurred while serving req as an HTTP response on w.
10 | type ErrorHandler interface {
11 | 	ServeHTTP(w http.ResponseWriter, req *http.Request, err error)
12 | }
13 | 
14 | // DefaultHandler is the ErrorHandler used when none is configured.
15 | var DefaultHandler ErrorHandler = &StdHandler{}
16 | 
17 | // StdHandler maps errors to status codes: net.Error timeouts -> 504, other
18 | // net.Error values and io.EOF -> 502, anything else -> 500.
19 | type StdHandler struct {
20 | }
21 | 
22 | // ServeHTTP writes the status code chosen for err, with its standard status text as the body.
23 | func (h *StdHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) {
24 | 	statusCode := http.StatusInternalServerError
25 | 	if netErr, ok := err.(net.Error); ok {
26 | 		if netErr.Timeout() {
27 | 			statusCode = http.StatusGatewayTimeout
28 | 		} else {
29 | 			statusCode = http.StatusBadGateway
30 | 		}
31 | 	} else if err == io.EOF {
32 | 		statusCode = http.StatusBadGateway
33 | 	}
34 | 	w.WriteHeader(statusCode)
35 | 	w.Write([]byte(http.StatusText(statusCode)))
36 | }
37 | 
38 | // ErrorHandlerFunc adapts an ordinary function to the ErrorHandler interface.
39 | type ErrorHandlerFunc func(http.ResponseWriter, *http.Request, error)
40 | 
41 | // ServeHTTP calls f(w, r, err).
42 | func (f ErrorHandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request, err error) {
43 | 	f(w, r, err)
44 | }
45 | 
--------------------------------------------------------------------------------
/utils/handler_test.go:
--------------------------------------------------------------------------------
1 | package utils
2 | 
3 | import (
4 | 	"bytes"
5 | 	"net/http"
6 | 	"net/http/httptest"
7 | 	"strings"
8 | 
9 | 	. "gopkg.in/check.v1"
10 | )
11 | 
12 | type UtilsSuite struct{}
13 | 
14 | var _ = Suite(&UtilsSuite{})
15 | 
16 | // TestDefaultHandlerErrors checks that an upstream dropping the connection without
17 | // replying is reported as 502 Bad Gateway.
18 | func (s *UtilsSuite) TestDefaultHandlerErrors(c *C) {
19 | 	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
20 | 		conn, _, err := w.(http.Hijacker).Hijack()
21 | 		if err != nil {
22 | 			// c.Assert must not be called from the server goroutine; the NotNil check below fails instead
23 | 			return
24 | 		}
25 | 		conn.Close()
26 | 	}))
27 | 	defer srv.Close()
28 | 
29 | 	request, err := http.NewRequest("GET", srv.URL, strings.NewReader(""))
30 | 	c.Assert(err, IsNil)
31 | 
32 | 	_, err = http.DefaultTransport.RoundTrip(request)
33 | 	c.Assert(err, NotNil)
34 | 
35 | 	w := NewBufferWriter(NopWriteCloser(&bytes.Buffer{}))
36 | 
37 | 	DefaultHandler.ServeHTTP(w, nil, err)
38 | 
39 | 	c.Assert(w.Code, Equals, http.StatusBadGateway)
40 | }
41 | 
--------------------------------------------------------------------------------
/utils/logging.go:
--------------------------------------------------------------------------------
1 | package utils
2 | 
3 | import (
4 | 	"io"
5 | 	"log"
6 | )
7 | 
8 | // NullLogger is a Logger that discards all messages.
9 | var NullLogger Logger = &NOPLogger{}
10 | 
11 | // Logger defines a simple logging interface
12 | type Logger interface {
13 | 	Infof(format string, args ...interface{})
14 | 	Warningf(format string, args ...interface{})
15 | 	Errorf(format string, args ...interface{})
16 | }
17 | 
18 | // FileLogger is a Logger built on the standard log package; a nil per-level
19 | // logger means messages of that level are dropped.
20 | type FileLogger struct {
21 | 	info  *log.Logger
22 | 	warn  *log.Logger
23 | 	error *log.Logger
24 | }
25 | 
26 | // NewFileLogger returns a FileLogger writing to w that emits messages at level lvl and
27 | // above (INFO < WARN < ERROR), each prefixed with its level tag and a microsecond timestamp.
28 | func NewFileLogger(w io.Writer, lvl LogLevel) *FileLogger {
29 | 	l := &FileLogger{}
30 | 	flag := log.Ldate | log.Ltime | log.Lmicroseconds
31 | 	if lvl <= INFO {
32 | 		l.info = log.New(w, "INFO: ", flag)
33 | 	}
34 | 	if lvl <= WARN {
35 | 		l.warn = log.New(w, "WARN: ", flag)
36 | 	}
37 | 	if lvl <= ERROR {
38 | 		l.error = log.New(w, "ERR: ", flag)
39 | 	}
40 | 	return l
41 | }
42 | 
43 | // Infof logs an informational message if the logger was created at INFO level.
44 | func (f *FileLogger) Infof(format string, args ...interface{}) {
45 | 	if f.info == nil {
46 | 		return
47 | 	}
48 | 	f.info.Printf(format, args...)
49 | }
50 | 
51 | // Warningf logs a warning if the logger was created at WARN level or lower.
52 | func (f *FileLogger) Warningf(format string, args ...interface{}) {
53 | 	if f.warn == nil {
54 | 		return
55 | 	}
56 | 	f.warn.Printf(format, args...)
57 | }
58 | 
59 | // Errorf logs an error if the logger was created at ERROR level or lower.
60 | func (f *FileLogger) Errorf(format string, args ...interface{}) {
61 | 	if f.error == nil {
62 | 		return
63 | 	}
64 | 	f.error.Printf(format, args...)
65 | }
66 | 
67 | // NOPLogger is a Logger whose methods do nothing.
68 | type NOPLogger struct {
69 | }
70 | 
71 | func (*NOPLogger) Infof(format string, args ...interface{}) {
72 | }
73 | 
74 | func (*NOPLogger) Warningf(format string, args ...interface{}) {
75 | }
76 | 
77 | func (*NOPLogger) Errorf(format string, args ...interface{}) {
78 | }
79 | 
80 | func (*NOPLogger) Info(string) {
81 | }
82 | 
83 | func (*NOPLogger) Warning(string) {
84 | }
85 | 
86 | func (*NOPLogger) Error(string) {
87 | }
88 | 
89 | // LogLevel is the minimum severity a FileLogger emits.
90 | type LogLevel int
91 | 
92 | // Log levels in increasing order of severity.
93 | const (
94 | 	INFO = iota
95 | 	WARN
96 | 	ERROR
97 | )
98 | 
--------------------------------------------------------------------------------
/utils/netutils.go:
--------------------------------------------------------------------------------
1 | package utils
2 | 
3 | import (
4 | 	"bufio"
5 | 	"io"
6 | 	"mime"
7 | 	"net"
8 | 	"net/http"
9 | 	"net/url"
10 | )
11 | 
12 | // ProxyWriter wraps a real http.ResponseWriter W and remembers the status code passed to
13 | // WriteHeader, so the code and W's headers can be inspected after the handler returns.
14 | // It can be passed to a handler's ServeHTTP in place of W.
15 | type ProxyWriter struct {
16 | 	W    http.ResponseWriter
17 | 	Code int
18 | }
19 | 
20 | // StatusCode returns the status code passed to WriteHeader, or http.StatusOK when
21 | // WriteHeader was never called.
22 | func (p *ProxyWriter) StatusCode() int {
23 | 	if p.Code == 0 {
24 | 		// per contract standard lib will set this to http.StatusOK if not set
25 | 		// by user, here we avoid the confusion by mirroring this logic
26 | 		return http.StatusOK
27 | 	}
28 | 	return p.Code
29 | }
30 | 
31 | // Header returns the header map of the wrapped ResponseWriter.
32 | func (p *ProxyWriter) Header() http.Header {
33 | 	return p.W.Header()
34 | }
35 | 
36 | // Write forwards buf to the wrapped ResponseWriter.
37 | func (p *ProxyWriter) Write(buf []byte) (int, error) {
38 | 	return p.W.Write(buf)
39 | }
40 | 
41 | // WriteHeader records code and forwards it to the wrapped ResponseWriter.
42 | func (p *ProxyWriter) WriteHeader(code int) {
43 | 	p.Code = code
44 | 	p.W.WriteHeader(code)
45 | }
46 | 
47 | // Flush flushes the wrapped ResponseWriter if it implements http.Flusher; otherwise it is a no-op.
48 | func (p *ProxyWriter) Flush() {
49 | 	if f, ok := p.W.(http.Flusher); ok {
50 | 		f.Flush()
51 | 	}
52 | }
53 | 
54 | // Hijack hijacks the connection of the wrapped ResponseWriter. It returns
55 | // http.ErrNotSupported instead of panicking when W does not implement http.Hijacker.
56 | func (p *ProxyWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
57 | 	if h, ok := p.W.(http.Hijacker); ok {
58 | 		return h.Hijack()
59 | 	}
60 | 	return nil, nil, http.ErrNotSupported
61 | }
62 | 
63 | // NewBufferWriter returns a BufferWriter that sends body writes to w and keeps
64 | // headers and the status code in memory.
65 | func NewBufferWriter(w io.WriteCloser) *BufferWriter {
66 | 	return &BufferWriter{
67 | 		W: w,
68 | 		H: make(http.Header),
69 | 	}
70 | }
71 | 
72 | // BufferWriter is an http.ResponseWriter that records headers (H) and the status
73 | // code (Code) in memory and sends body writes to W.
74 | type BufferWriter struct {
75 | 	H    http.Header
76 | 	Code int
77 | 	W    io.WriteCloser
78 | }
79 | 
80 | // Close closes the underlying writer.
81 | func (b *BufferWriter) Close() error {
82 | 	return b.W.Close()
83 | }
84 | 
85 | // Header returns the in-memory header map.
86 | func (b *BufferWriter) Header() http.Header {
87 | 	return b.H
88 | }
89 | 
90 | // Write forwards buf to the underlying writer.
91 | func (b *BufferWriter) Write(buf []byte) (int, error) {
92 | 	return b.W.Write(buf)
93 | }
94 | 
95 | // WriteHeader sets b.Code; nothing is sent anywhere.
96 | func (b *BufferWriter) WriteHeader(code int) {
97 | 	b.Code = code
98 | }
99 | 
100 | // Hijack hijacks the connection of the underlying writer. It returns
101 | // http.ErrNotSupported instead of panicking when W does not implement http.Hijacker.
102 | func (b *BufferWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
103 | 	if h, ok := b.W.(http.Hijacker); ok {
104 | 		return h.Hijack()
105 | 	}
106 | 	return nil, nil, http.ErrNotSupported
107 | }
108 | 
109 | type nopWriteCloser struct {
110 | 	io.Writer
111 | }
112 | 
113 | func (*nopWriteCloser) Close() error { return nil }
114 | 
115 | // NopWriteCloser returns a WriteCloser with a no-op Close method wrapping
116 | // the provided Writer w.
117 | func NopWriteCloser(w io.Writer) io.WriteCloser {
118 | 	return &nopWriteCloser{w}
119 | }
120 | 
121 | // CopyURL returns a copy of i that can be modified without affecting i; unlike a plain
122 | // struct copy, the User field is duplicated rather than shared.
123 | func CopyURL(i *url.URL) *url.URL {
124 | 	out := *i
125 | 	if i.User != nil {
126 | 		// &(*i.User) would yield the very same pointer, so copy the value explicitly
127 | 		user := *i.User
128 | 		out.User = &user
129 | 	}
130 | 	return &out
131 | }
132 | 
133 | // CopyHeaders copies http headers from source to destination, it
134 | // does not override, but adds multiple headers
135 | func CopyHeaders(dst, src http.Header) {
136 | 	for k, vv := range src {
137 | 		for _, v := range vv {
138 | 			dst.Add(k, v)
139 | 		}
140 | 	}
141 | }
142 | 
143 | // HasHeaders reports whether any of the named headers has a non-empty value in headers
144 | func HasHeaders(names []string, headers http.Header) bool {
145 | 	for _, h := range names {
146 | 		if headers.Get(h) != "" {
147 | 			return true
148 | 		}
149 | 	}
150 | 	return false
151 | }
152 | 
153 | // RemoveHeaders removes the header with the given names from the headers map
154 | func RemoveHeaders(headers http.Header, names ...string) {
155 | 	for _, h := range names {
156 | 		headers.Del(h)
157 | 	}
158 | }
159 | 
160 | // GetHeaderMediaType parses the MIME media type of the named header, e.g. "text/html"
161 | // from "text/html; charset=utf-8"; mime.ParseMediaType returns it lower-cased.
162 | func GetHeaderMediaType(headers http.Header, name string) (string, error) {
163 | 	mediatype, _, err := mime.ParseMediaType(headers.Get(name))
164 | 	return mediatype, err
165 | }
166 | 
--------------------------------------------------------------------------------
/utils/netutils_test.go:
--------------------------------------------------------------------------------
1 | package utils
2 | 
3 | import (
4 | 	"bytes"
5 | 	"net/http"
6 | 	"net/http/httptest"
7 | 	"net/url"
8 | 	"testing"
9 | 
10 | 	. "gopkg.in/check.v1"
11 | )
12 | 
13 | func TestUtils(t *testing.T) { TestingT(t) }
14 | 
15 | type NetUtilsSuite struct{}
16 | 
17 | var _ = Suite(&NetUtilsSuite{})
18 | 
19 | // Make sure copy does it right, so the copied url
20 | // is safe to alter without modifying the other
21 | func (s *NetUtilsSuite) TestCopyUrl(c *C) {
22 | 	urlA := &url.URL{
23 | 		Scheme:   "http",
24 | 		Host:     "localhost:5000",
25 | 		Path:     "/upstream",
26 | 		Opaque:   "opaque",
27 | 		RawQuery: "a=1&b=2",
28 | 		Fragment: "#hello",
29 | 		User:     url.UserPassword("user", "pass"),
30 | 	}
31 | 	urlB := CopyURL(urlA)
32 | 	c.Assert(urlB, DeepEquals, urlA)
33 | 	// the user info must be a separate copy, not the same pointer
34 | 	c.Assert(urlB.User == urlA.User, Equals, false)
35 | 	urlB.Scheme = "https"
36 | 	c.Assert(urlB, Not(DeepEquals), urlA)
37 | }
38 | 
39 | // Make sure copy headers is not shallow and copies all headers
40 | func (s *NetUtilsSuite) TestCopyHeaders(c *C) {
41 | 	source, destination := make(http.Header), make(http.Header)
42 | 	source.Add("a", "b")
43 | 	source.Add("c", "d")
44 | 
45 | 	CopyHeaders(destination, source)
46 | 
47 | 	c.Assert(destination.Get("a"), Equals, "b")
48 | 	c.Assert(destination.Get("c"), Equals, "d")
49 | 
50 | 	// make sure that altering source does not affect the destination
51 | 	source.Del("a")
52 | 	c.Assert(source.Get("a"), Equals, "")
53 | 	c.Assert(destination.Get("a"), Equals, "b")
54 | }
55 | 
56 | func (s *NetUtilsSuite) TestHasHeaders(c *C) {
57 | 	source := make(http.Header)
58 | 	source.Add("a", "b")
59 | 	source.Add("c", "d")
60 | 	c.Assert(HasHeaders([]string{"a", "f"}, source), Equals, true)
61 | 	c.Assert(HasHeaders([]string{"i", "j"}, source), Equals, false)
62 | }
63 | 
64 | func (s *NetUtilsSuite) TestRemoveHeaders(c *C) {
65 | 	source := make(http.Header)
66 | 	source.Add("a", "b")
67 | 	source.Add("a", "m")
68 | 	source.Add("c", "d")
69 | 	RemoveHeaders(source, "a")
70 | 	c.Assert(source.Get("a"), Equals, "")
71 | 	c.Assert(source.Get("c"), Equals, "d")
72 | }
73 | 
74 | func (s *NetUtilsSuite) TestGetHeaderMediaType(c *C) {
75 | 	source := make(http.Header)
76 | 	source.Add("Content-Type", "text/event-stream")
77 | 
78 | 	mediatype, err := GetHeaderMediaType(source, "Content-Type")
79 | 	c.Assert(err, IsNil)
80 | 	c.Assert(mediatype, Equals, "text/event-stream")
81 | }
82 | 
83 | func (s *NetUtilsSuite) TestGetHeaderMediaTypeCharSet(c *C) {
84 | 	source := make(http.Header)
85 | 	source.Add("Content-Type", "text/event-stream; charset=utf-8")
86 | 
87 | 	mediatype, err := GetHeaderMediaType(source, "Content-Type")
88 | 	c.Assert(err, IsNil)
89 | 	c.Assert(mediatype, Equals, "text/event-stream")
90 | }
91 | 
92 | func (s *NetUtilsSuite) TestGetHeaderMediaTypeMixedCase(c *C) {
93 | 	source := make(http.Header)
94 | 	source.Add("Content-Type", "text/Event-Stream")
95 | 
96 | 	mediatype, err := GetHeaderMediaType(source, "Content-Type")
97 | 	c.Assert(err, IsNil)
98 | 	c.Assert(mediatype, Equals, "text/event-stream")
99 | }
100 | 
101 | // Hijack must report an error, not panic, when the wrapped writer cannot be hijacked
102 | func (s *NetUtilsSuite) TestHijackUnsupported(c *C) {
103 | 	pw := &ProxyWriter{W: httptest.NewRecorder()}
104 | 	_, _, err := pw.Hijack()
105 | 	c.Assert(err, Equals, http.ErrNotSupported)
106 | 
107 | 	bw := NewBufferWriter(NopWriteCloser(&bytes.Buffer{}))
108 | 	_, _, err = bw.Hijack()
109 | 	c.Assert(err, Equals, http.ErrNotSupported)
110 | }
111 | 
112 | // IPv4, bracketed IPv6 and port-less addresses all yield the bare client IP
113 | func (s *NetUtilsSuite) TestExtractClientIP(c *C) {
114 | 	for addr, expected := range map[string]string{
115 | 		"10.0.0.1:1234": "10.0.0.1",
116 | 		"[::1]:1234":    "::1",
117 | 		"10.0.0.1":      "10.0.0.1",
118 | 	} {
119 | 		ip, amount, err := extractClientIP(&http.Request{RemoteAddr: addr})
120 | 		c.Assert(err, IsNil)
121 | 		c.Assert(ip, Equals, expected)
122 | 		c.Assert(amount, Equals, int64(1))
123 | 	}
124 | 	_, _, err := extractClientIP(&http.Request{RemoteAddr: ""})
125 | 	c.Assert(err, NotNil)
126 | }
127 | 
--------------------------------------------------------------------------------
/utils/source.go:
--------------------------------------------------------------------------------
1 | package utils
2 | 
3 | import (
4 | 	"fmt"
5 | 	"net"
6 | 	"net/http"
7 | 	"strings"
8 | )
9 | 
10 | // SourceExtractor extracts the source from the request, e.g. that may be client ip, or particular header that
11 | // identifies the source. amount stands for amount of connections the source consumes, usually 1 for connection limiters
12 | // error should be returned when source can not be identified
13 | type SourceExtractor interface {
14 | 	Extract(req *http.Request) (token string, amount int64, err error)
15 | }
16 | 
17 | // ExtractorFunc adapts an ordinary function to the SourceExtractor interface.
18 | type ExtractorFunc func(req *http.Request) (token string, amount int64, err error)
19 | 
20 | // Extract calls f(req).
21 | func (f ExtractorFunc) Extract(req *http.Request) (string, int64, error) {
22 | 	return f(req)
23 | }
24 | 
25 | // ExtractSource is a request callback type; nothing in this file uses it.
26 | type ExtractSource func(req *http.Request)
27 | 
28 | // NewExtractor returns the SourceExtractor selected by variable: "client.ip",
29 | // "request.host" or "request.header.<Name>".
30 | func NewExtractor(variable string) (SourceExtractor, error) {
31 | 	switch {
32 | 	case variable == "client.ip":
33 | 		return ExtractorFunc(extractClientIP), nil
34 | 	case variable == "request.host":
35 | 		return ExtractorFunc(extractHost), nil
36 | 	case strings.HasPrefix(variable, "request.header."):
37 | 		header := strings.TrimPrefix(variable, "request.header.")
38 | 		if len(header) == 0 {
39 | 			return nil, fmt.Errorf("Wrong header: %s", header)
40 | 		}
41 | 		return makeHeaderExtractor(header), nil
42 | 	}
43 | 	return nil, fmt.Errorf("Unsupported limiting variable: '%s'", variable)
44 | }
45 | 
46 | // extractClientIP returns the host part of req.RemoteAddr with an amount of 1. It accepts
47 | // "ip:port", "[ipv6]:port" and a bare IP; an empty host is an error.
48 | func extractClientIP(req *http.Request) (string, int64, error) {
49 | 	host, _, err := net.SplitHostPort(req.RemoteAddr)
50 | 	if err != nil {
51 | 		// no port (or not host:port at all): keep a bare IP, otherwise fall back to
52 | 		// the text before the first ':' as earlier versions did
53 | 		host = req.RemoteAddr
54 | 		if net.ParseIP(host) == nil {
55 | 			host = strings.SplitN(host, ":", 2)[0]
56 | 		}
57 | 	}
58 | 	if len(host) == 0 {
59 | 		return "", 0, fmt.Errorf("Failed to parse client IP: %v", req.RemoteAddr)
60 | 	}
61 | 	return host, 1, nil
62 | }
63 | 
64 | // extractHost uses the request Host as the source, with an amount of 1.
65 | func extractHost(req *http.Request) (string, int64, error) {
66 | 	return req.Host, 1, nil
67 | }
68 | 
69 | // makeHeaderExtractor uses the first value of the named request header as the source,
70 | // with an amount of 1; a missing header yields an empty token, not an error.
71 | func makeHeaderExtractor(header string) SourceExtractor {
72 | 	return ExtractorFunc(func(req *http.Request) (string, int64, error) {
73 | 		return req.Header.Get(header), 1, nil
74 | 	})
75 | }
76 | 
--------------------------------------------------------------------------------