├── .github └── workflows │ └── test.yml ├── .golangci.yaml ├── LICENSE ├── README.md ├── example └── http_breaker.go ├── go.mod ├── go.sum ├── gobreaker.go ├── gobreaker_test.go └── v2 ├── distributed_gobreaker.go ├── distributed_gobreaker_test.go ├── example └── http_breaker.go ├── go.mod ├── go.sum ├── gobreaker.go ├── gobreaker_test.go └── redis_store.go /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | on: [push, pull_request] 2 | name: Test 3 | jobs: 4 | test: 5 | strategy: 6 | matrix: 7 | go-version: [1.22.x, 1.23.x] 8 | os: [ubuntu-latest] 9 | work-dir: ["./", "./v2"] 10 | runs-on: ${{matrix.os}} 11 | defaults: 12 | run: 13 | working-directory: ${{matrix.work-dir}} 14 | steps: 15 | - name: Setup Go 16 | uses: actions/setup-go@v5 17 | with: 18 | go-version: ${{matrix.go-version}} 19 | - name: Checkout 20 | uses: actions/checkout@v4 21 | - name: Lint 22 | uses: golangci/golangci-lint-action@v6 23 | with: 24 | working-directory: ${{matrix.work-dir}} 25 | - name: go test 26 | run: go test -v ./... 
27 | - name: Run example 28 | run: cd example && go build -o http_breaker && ./http_breaker 29 | -------------------------------------------------------------------------------- /.golangci.yaml: -------------------------------------------------------------------------------- 1 | linters: 2 | enable: 3 | - gofmt 4 | - goimports 5 | - gosec 6 | - misspell 7 | issues: 8 | exclude-rules: 9 | - path: _test\.go 10 | linters: 11 | - gosec 12 | - govet 13 | - path: example/ 14 | linters: 15 | - gosec 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright 2015 Sony Corporation 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | gobreaker 2 | ========= 3 | 4 | [![GoDoc](https://godoc.org/github.com/sony/gobreaker/v2?status.svg)](https://godoc.org/github.com/sony/gobreaker/v2) 5 | 6 | [gobreaker][repo-url] implements the [Circuit Breaker pattern](https://msdn.microsoft.com/en-us/library/dn589784.aspx) in Go. 7 | 8 | Installation 9 | ------------ 10 | 11 | ``` 12 | go get github.com/sony/gobreaker/v2 13 | ``` 14 | 15 | Usage 16 | ----- 17 | 18 | The struct `CircuitBreaker` is a state machine to prevent sending requests that are likely to fail. 19 | The function `NewCircuitBreaker` creates a new `CircuitBreaker`. 20 | The type parameter `T` specifies the return type of requests. 21 | 22 | ```go 23 | func NewCircuitBreaker[T any](st Settings) *CircuitBreaker[T] 24 | ``` 25 | 26 | You can configure `CircuitBreaker` with the struct `Settings`: 27 | 28 | ```go 29 | type Settings struct { 30 | Name string 31 | MaxRequests uint32 32 | Interval time.Duration 33 | Timeout time.Duration 34 | ReadyToTrip func(counts Counts) bool 35 | OnStateChange func(name string, from State, to State) 36 | IsSuccessful func(err error) bool 37 | } 38 | ``` 39 | 40 | - `Name` is the name of the `CircuitBreaker`. 41 | 42 | - `MaxRequests` is the maximum number of requests allowed to pass through 43 | when the `CircuitBreaker` is half-open. 44 | If `MaxRequests` is 0, `CircuitBreaker` allows only 1 request. 45 | 46 | - `Interval` is the cyclic period of the closed state 47 | for `CircuitBreaker` to clear the internal `Counts`, described later in this section. 48 | If `Interval` is less than or equal to 0, `CircuitBreaker` doesn't clear the internal `Counts` during the closed state. 49 | 50 | - `Timeout` is the period of the open state, 51 | after which the state of `CircuitBreaker` becomes half-open.
52 | If `Timeout` is less than or equal to 0, the timeout value of `CircuitBreaker` is set to 60 seconds. 53 | 54 | - `ReadyToTrip` is called with a copy of `Counts` whenever a request fails in the closed state. 55 | If `ReadyToTrip` returns true, `CircuitBreaker` will be placed into the open state. 56 | If `ReadyToTrip` is `nil`, default `ReadyToTrip` is used. 57 | Default `ReadyToTrip` returns true when the number of consecutive failures is more than 5. 58 | 59 | - `OnStateChange` is called whenever the state of `CircuitBreaker` changes. 60 | 61 | - `IsSuccessful` is called with the error returned from a request. 62 | If `IsSuccessful` returns true, the error is counted as a success. 63 | Otherwise, the error is counted as a failure. 64 | If `IsSuccessful` is `nil`, default `IsSuccessful` is used, which returns false for all non-nil errors. 65 | 66 | The struct `Counts` holds the numbers of requests and their successes/failures: 67 | 68 | ```go 69 | type Counts struct { 70 | Requests uint32 71 | TotalSuccesses uint32 72 | TotalFailures uint32 73 | ConsecutiveSuccesses uint32 74 | ConsecutiveFailures uint32 75 | } 76 | ``` 77 | 78 | `CircuitBreaker` clears the internal `Counts` either 79 | on the change of the state or at the closed-state intervals. 80 | `Counts` ignores the results of the requests sent before clearing. 81 | 82 | `CircuitBreaker` can wrap any function to send a request: 83 | 84 | ```go 85 | func (cb *CircuitBreaker[T]) Execute(req func() (T, error)) (T, error) 86 | ``` 87 | 88 | The method `Execute` runs the given request if `CircuitBreaker` accepts it. 89 | `Execute` returns an error instantly if `CircuitBreaker` rejects the request. 90 | Otherwise, `Execute` returns the result of the request. 91 | If a panic occurs in the request, `CircuitBreaker` counts it as a failure 92 | and then re-panics with the same value.
93 | 94 | Example 95 | ------- 96 | 97 | ```go 98 | var cb *gobreaker.CircuitBreaker[[]byte] 99 | 100 | func Get(url string) ([]byte, error) { 101 | body, err := cb.Execute(func() ([]byte, error) { 102 | resp, err := http.Get(url) 103 | if err != nil { 104 | return nil, err 105 | } 106 | 107 | defer resp.Body.Close() 108 | body, err := io.ReadAll(resp.Body) 109 | if err != nil { 110 | return nil, err 111 | } 112 | 113 | return body, nil 114 | }) 115 | if err != nil { 116 | return nil, err 117 | } 118 | 119 | return body, nil 120 | } 121 | ``` 122 | 123 | See [example](https://github.com/sony/gobreaker/blob/master/v2/example) for details. 124 | 125 | License 126 | ------- 127 | 128 | The MIT License (MIT) 129 | 130 | See [LICENSE](https://github.com/sony/gobreaker/blob/master/LICENSE) for details. 131 | 132 | 133 | [repo-url]: https://github.com/sony/gobreaker 134 | -------------------------------------------------------------------------------- /example/http_breaker.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "io/ioutil" 6 | "log" 7 | "net/http" 8 | 9 | "github.com/sony/gobreaker" 10 | ) 11 | 12 | var cb *gobreaker.CircuitBreaker 13 | 14 | func init() { 15 | var st gobreaker.Settings 16 | st.Name = "HTTP GET" 17 | st.ReadyToTrip = func(counts gobreaker.Counts) bool { 18 | failureRatio := float64(counts.TotalFailures) / float64(counts.Requests) 19 | return counts.Requests >= 3 && failureRatio >= 0.6 20 | } 21 | 22 | cb = gobreaker.NewCircuitBreaker(st) 23 | } 24 | 25 | // Get wraps http.Get in CircuitBreaker. 
26 | func Get(url string) ([]byte, error) { 27 | body, err := cb.Execute(func() (interface{}, error) { 28 | resp, err := http.Get(url) 29 | if err != nil { 30 | return nil, err 31 | } 32 | 33 | defer resp.Body.Close() 34 | body, err := ioutil.ReadAll(resp.Body) 35 | if err != nil { 36 | return nil, err 37 | } 38 | 39 | return body, nil 40 | }) 41 | if err != nil { 42 | return nil, err 43 | } 44 | 45 | return body.([]byte), nil 46 | } 47 | 48 | func main() { 49 | body, err := Get("http://www.google.com/robots.txt") 50 | if err != nil { 51 | log.Fatal(err) 52 | } 53 | 54 | fmt.Println(string(body)) 55 | } 56 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/sony/gobreaker 2 | 3 | go 1.12 4 | 5 | require github.com/stretchr/testify v1.3.0 6 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 2 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 4 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 5 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 6 | github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= 7 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 8 | -------------------------------------------------------------------------------- /gobreaker.go: -------------------------------------------------------------------------------- 1 | // Package gobreaker implements the Circuit Breaker pattern. 
2 | // See https://msdn.microsoft.com/en-us/library/dn589784.aspx. 3 | package gobreaker 4 | 5 | import ( 6 | "errors" 7 | "fmt" 8 | "sync" 9 | "time" 10 | ) 11 | 12 | // State is a type that represents a state of CircuitBreaker. 13 | type State int 14 | 15 | // These constants are states of CircuitBreaker. 16 | const ( 17 | StateClosed State = iota 18 | StateHalfOpen 19 | StateOpen 20 | ) 21 | 22 | var ( 23 | // ErrTooManyRequests is returned when the CB state is half open and the requests count is over the cb maxRequests 24 | ErrTooManyRequests = errors.New("too many requests") 25 | // ErrOpenState is returned when the CB state is open 26 | ErrOpenState = errors.New("circuit breaker is open") 27 | ) 28 | 29 | // String implements stringer interface. 30 | func (s State) String() string { 31 | switch s { 32 | case StateClosed: 33 | return "closed" 34 | case StateHalfOpen: 35 | return "half-open" 36 | case StateOpen: 37 | return "open" 38 | default: 39 | return fmt.Sprintf("unknown state: %d", s) 40 | } 41 | } 42 | 43 | // Counts holds the numbers of requests and their successes/failures. 44 | // CircuitBreaker clears the internal Counts either 45 | // on the change of the state or at the closed-state intervals. 46 | // Counts ignores the results of the requests sent before clearing.
47 | type Counts struct { 48 | Requests uint32 49 | TotalSuccesses uint32 50 | TotalFailures uint32 51 | ConsecutiveSuccesses uint32 52 | ConsecutiveFailures uint32 53 | } 54 | 55 | func (c *Counts) onRequest() { 56 | c.Requests++ 57 | } 58 | 59 | func (c *Counts) onSuccess() { 60 | c.TotalSuccesses++ 61 | c.ConsecutiveSuccesses++ 62 | c.ConsecutiveFailures = 0 63 | } 64 | 65 | func (c *Counts) onFailure() { 66 | c.TotalFailures++ 67 | c.ConsecutiveFailures++ 68 | c.ConsecutiveSuccesses = 0 69 | } 70 | 71 | func (c *Counts) clear() { 72 | c.Requests = 0 73 | c.TotalSuccesses = 0 74 | c.TotalFailures = 0 75 | c.ConsecutiveSuccesses = 0 76 | c.ConsecutiveFailures = 0 77 | } 78 | 79 | // Settings configures CircuitBreaker: 80 | // 81 | // Name is the name of the CircuitBreaker. 82 | // 83 | // MaxRequests is the maximum number of requests allowed to pass through 84 | // when the CircuitBreaker is half-open. 85 | // If MaxRequests is 0, the CircuitBreaker allows only 1 request. 86 | // 87 | // Interval is the cyclic period of the closed state 88 | // for the CircuitBreaker to clear the internal Counts. 89 | // If Interval is less than or equal to 0, the CircuitBreaker doesn't clear internal Counts during the closed state. 90 | // 91 | // Timeout is the period of the open state, 92 | // after which the state of the CircuitBreaker becomes half-open. 93 | // If Timeout is less than or equal to 0, the timeout value of the CircuitBreaker is set to 60 seconds. 94 | // 95 | // ReadyToTrip is called with a copy of Counts whenever a request fails in the closed state. 96 | // If ReadyToTrip returns true, the CircuitBreaker will be placed into the open state. 97 | // If ReadyToTrip is nil, default ReadyToTrip is used. 98 | // Default ReadyToTrip returns true when the number of consecutive failures is more than 5. 99 | // 100 | // OnStateChange is called whenever the state of the CircuitBreaker changes.
101 | // 102 | // IsSuccessful is called with the error returned from a request. 103 | // If IsSuccessful returns true, the error is counted as a success. 104 | // Otherwise the error is counted as a failure. 105 | // If IsSuccessful is nil, default IsSuccessful is used, which returns false for all non-nil errors. 106 | type Settings struct { 107 | Name string 108 | MaxRequests uint32 109 | Interval time.Duration 110 | Timeout time.Duration 111 | ReadyToTrip func(counts Counts) bool 112 | OnStateChange func(name string, from State, to State) 113 | IsSuccessful func(err error) bool 114 | } 115 | 116 | // CircuitBreaker is a state machine to prevent sending requests that are likely to fail. 117 | type CircuitBreaker struct { 118 | name string 119 | maxRequests uint32 120 | interval time.Duration 121 | timeout time.Duration 122 | readyToTrip func(counts Counts) bool 123 | isSuccessful func(err error) bool 124 | onStateChange func(name string, from State, to State) 125 | 126 | mutex sync.Mutex 127 | state State 128 | generation uint64 129 | counts Counts 130 | expiry time.Time 131 | } 132 | 133 | // TwoStepCircuitBreaker is like CircuitBreaker but instead of surrounding a function 134 | // with the breaker functionality, it only checks whether a request can proceed and 135 | // expects the caller to report the outcome in a separate step using a callback. 136 | type TwoStepCircuitBreaker struct { 137 | cb *CircuitBreaker 138 | } 139 | 140 | // NewCircuitBreaker returns a new CircuitBreaker configured with the given Settings.
141 | func NewCircuitBreaker(st Settings) *CircuitBreaker { 142 | cb := new(CircuitBreaker) 143 | 144 | cb.name = st.Name 145 | cb.onStateChange = st.OnStateChange 146 | 147 | if st.MaxRequests == 0 { 148 | cb.maxRequests = 1 149 | } else { 150 | cb.maxRequests = st.MaxRequests 151 | } 152 | 153 | if st.Interval <= 0 { 154 | cb.interval = defaultInterval 155 | } else { 156 | cb.interval = st.Interval 157 | } 158 | 159 | if st.Timeout <= 0 { 160 | cb.timeout = defaultTimeout 161 | } else { 162 | cb.timeout = st.Timeout 163 | } 164 | 165 | if st.ReadyToTrip == nil { 166 | cb.readyToTrip = defaultReadyToTrip 167 | } else { 168 | cb.readyToTrip = st.ReadyToTrip 169 | } 170 | 171 | if st.IsSuccessful == nil { 172 | cb.isSuccessful = defaultIsSuccessful 173 | } else { 174 | cb.isSuccessful = st.IsSuccessful 175 | } 176 | 177 | cb.toNewGeneration(time.Now()) 178 | 179 | return cb 180 | } 181 | 182 | // NewTwoStepCircuitBreaker returns a new TwoStepCircuitBreaker configured with the given Settings. 183 | func NewTwoStepCircuitBreaker(st Settings) *TwoStepCircuitBreaker { 184 | return &TwoStepCircuitBreaker{ 185 | cb: NewCircuitBreaker(st), 186 | } 187 | } 188 | 189 | const defaultInterval = time.Duration(0) * time.Second 190 | const defaultTimeout = time.Duration(60) * time.Second 191 | 192 | func defaultReadyToTrip(counts Counts) bool { 193 | return counts.ConsecutiveFailures > 5 194 | } 195 | 196 | func defaultIsSuccessful(err error) bool { 197 | return err == nil 198 | } 199 | 200 | // Name returns the name of the CircuitBreaker. 201 | func (cb *CircuitBreaker) Name() string { 202 | return cb.name 203 | } 204 | 205 | // State returns the current state of the CircuitBreaker.
206 | func (cb *CircuitBreaker) State() State { 207 | cb.mutex.Lock() 208 | defer cb.mutex.Unlock() 209 | 210 | now := time.Now() 211 | state, _ := cb.currentState(now) 212 | return state 213 | } 214 | 215 | // Counts returns a copy of the internal counters. 216 | func (cb *CircuitBreaker) Counts() Counts { 217 | cb.mutex.Lock() 218 | defer cb.mutex.Unlock() 219 | 220 | return cb.counts 221 | } 222 | 223 | // Execute runs the given request if the CircuitBreaker accepts it. 224 | // Execute returns an error instantly if the CircuitBreaker rejects the request. 225 | // Otherwise, Execute returns the result of the request. 226 | // If a panic occurs in the request, the CircuitBreaker handles it as an error 227 | // and causes the same panic again. 228 | func (cb *CircuitBreaker) Execute(req func() (interface{}, error)) (interface{}, error) { 229 | generation, err := cb.beforeRequest() 230 | if err != nil { 231 | return nil, err 232 | } 233 | 234 | defer func() { 235 | e := recover() 236 | if e != nil { 237 | cb.afterRequest(generation, false) 238 | panic(e) 239 | } 240 | }() 241 | 242 | result, err := req() 243 | cb.afterRequest(generation, cb.isSuccessful(err)) 244 | return result, err 245 | } 246 | 247 | // Name returns the name of the TwoStepCircuitBreaker. 248 | func (tscb *TwoStepCircuitBreaker) Name() string { 249 | return tscb.cb.Name() 250 | } 251 | 252 | // State returns the current state of the TwoStepCircuitBreaker. 253 | func (tscb *TwoStepCircuitBreaker) State() State { 254 | return tscb.cb.State() 255 | } 256 | 257 | // Counts returns a copy of the internal counters. 258 | func (tscb *TwoStepCircuitBreaker) Counts() Counts { 259 | return tscb.cb.Counts() 260 | } 261 | 262 | // Allow checks if a new request can proceed. It returns a callback that should be used to 263 | // register the success or failure in a separate step. If the circuit breaker doesn't allow 264 | // requests, it returns an error.
265 | func (tscb *TwoStepCircuitBreaker) Allow() (done func(success bool), err error) { 266 | generation, err := tscb.cb.beforeRequest() 267 | if err != nil { 268 | return nil, err 269 | } 270 | 271 | return func(success bool) { 272 | tscb.cb.afterRequest(generation, success) 273 | }, nil 274 | } 275 | 276 | func (cb *CircuitBreaker) beforeRequest() (uint64, error) { 277 | cb.mutex.Lock() 278 | defer cb.mutex.Unlock() 279 | 280 | now := time.Now() 281 | state, generation := cb.currentState(now) 282 | 283 | if state == StateOpen { 284 | return generation, ErrOpenState 285 | } else if state == StateHalfOpen && cb.counts.Requests >= cb.maxRequests { 286 | return generation, ErrTooManyRequests 287 | } 288 | 289 | cb.counts.onRequest() 290 | return generation, nil 291 | } 292 | 293 | func (cb *CircuitBreaker) afterRequest(before uint64, success bool) { 294 | cb.mutex.Lock() 295 | defer cb.mutex.Unlock() 296 | 297 | now := time.Now() 298 | state, generation := cb.currentState(now) 299 | if generation != before { 300 | return 301 | } 302 | 303 | if success { 304 | cb.onSuccess(state, now) 305 | } else { 306 | cb.onFailure(state, now) 307 | } 308 | } 309 | 310 | func (cb *CircuitBreaker) onSuccess(state State, now time.Time) { 311 | switch state { 312 | case StateClosed: 313 | cb.counts.onSuccess() 314 | case StateHalfOpen: 315 | cb.counts.onSuccess() 316 | if cb.counts.ConsecutiveSuccesses >= cb.maxRequests { 317 | cb.setState(StateClosed, now) 318 | } 319 | } 320 | } 321 | 322 | func (cb *CircuitBreaker) onFailure(state State, now time.Time) { 323 | switch state { 324 | case StateClosed: 325 | cb.counts.onFailure() 326 | if cb.readyToTrip(cb.counts) { 327 | cb.setState(StateOpen, now) 328 | } 329 | case StateHalfOpen: 330 | cb.setState(StateOpen, now) 331 | } 332 | } 333 | 334 | func (cb *CircuitBreaker) currentState(now time.Time) (State, uint64) { 335 | switch cb.state { 336 | case StateClosed: 337 | if !cb.expiry.IsZero() && cb.expiry.Before(now) {
338 | cb.toNewGeneration(now) 339 | } 340 | case StateOpen: 341 | if cb.expiry.Before(now) { 342 | cb.setState(StateHalfOpen, now) 343 | } 344 | } 345 | return cb.state, cb.generation 346 | } 347 | 348 | func (cb *CircuitBreaker) setState(state State, now time.Time) { 349 | if cb.state == state { 350 | return 351 | } 352 | 353 | prev := cb.state 354 | cb.state = state 355 | 356 | cb.toNewGeneration(now) 357 | 358 | if cb.onStateChange != nil { 359 | cb.onStateChange(cb.name, prev, state) 360 | } 361 | } 362 | 363 | func (cb *CircuitBreaker) toNewGeneration(now time.Time) { 364 | cb.generation++ 365 | cb.counts.clear() 366 | 367 | var zero time.Time 368 | switch cb.state { 369 | case StateClosed: 370 | if cb.interval == 0 { 371 | cb.expiry = zero 372 | } else { 373 | cb.expiry = now.Add(cb.interval) 374 | } 375 | case StateOpen: 376 | cb.expiry = now.Add(cb.timeout) 377 | default: // StateHalfOpen 378 | cb.expiry = zero 379 | } 380 | } 381 | -------------------------------------------------------------------------------- /gobreaker_test.go: -------------------------------------------------------------------------------- 1 | package gobreaker 2 | 3 | import ( 4 | "errors" 5 | "runtime" 6 | "testing" 7 | "time" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | var defaultCB *CircuitBreaker 13 | var customCB *CircuitBreaker 14 | 15 | type StateChange struct { 16 | name string 17 | from State 18 | to State 19 | } 20 | 21 | var stateChange StateChange 22 | 23 | func pseudoSleep(cb *CircuitBreaker, period time.Duration) { 24 | if !cb.expiry.IsZero() { 25 | cb.expiry = cb.expiry.Add(-period) 26 | } 27 | } 28 | 29 | func succeed(cb *CircuitBreaker) error { 30 | _, err := cb.Execute(func() (interface{}, error) { return nil, nil }) 31 | return err 32 | } 33 | 34 | func succeedLater(cb *CircuitBreaker, delay time.Duration) <-chan error { 35 | ch := make(chan error) 36 | go func() { 37 | _, err := cb.Execute(func() (interface{}, error) { 38 | time.Sleep(delay) 39 | return
nil, nil 40 | }) 41 | ch <- err 42 | }() 43 | return ch 44 | } 45 | 46 | func succeed2Step(cb *TwoStepCircuitBreaker) error { 47 | done, err := cb.Allow() 48 | if err != nil { 49 | return err 50 | } 51 | 52 | done(true) 53 | return nil 54 | } 55 | 56 | func fail(cb *CircuitBreaker) error { 57 | msg := "fail" 58 | _, err := cb.Execute(func() (interface{}, error) { return nil, errors.New(msg) }) 59 | if err.Error() == msg { 60 | return nil 61 | } 62 | return err 63 | } 64 | 65 | func fail2Step(cb *TwoStepCircuitBreaker) error { 66 | done, err := cb.Allow() 67 | if err != nil { 68 | return err 69 | } 70 | 71 | done(false) 72 | return nil 73 | } 74 | 75 | func causePanic(cb *CircuitBreaker) error { 76 | _, err := cb.Execute(func() (interface{}, error) { panic("oops") }) 77 | return err 78 | } 79 | 80 | func newCustom() *CircuitBreaker { 81 | var customSt Settings 82 | customSt.Name = "cb" 83 | customSt.MaxRequests = 3 84 | customSt.Interval = time.Duration(30) * time.Second 85 | customSt.Timeout = time.Duration(90) * time.Second 86 | customSt.ReadyToTrip = func(counts Counts) bool { 87 | numReqs := counts.Requests 88 | failureRatio := float64(counts.TotalFailures) / float64(numReqs) 89 | 90 | counts.clear() // no effect on customCB.counts 91 | 92 | return numReqs >= 3 && failureRatio >= 0.6 93 | } 94 | customSt.OnStateChange = func(name string, from State, to State) { 95 | stateChange = StateChange{name, from, to} 96 | } 97 | 98 | return NewCircuitBreaker(customSt) 99 | } 100 | 101 | func newNegativeDurationCB() *CircuitBreaker { 102 | var negativeSt Settings 103 | negativeSt.Name = "ncb" 104 | negativeSt.Interval = time.Duration(-30) * time.Second 105 | negativeSt.Timeout = time.Duration(-90) * time.Second 106 | 107 | return NewCircuitBreaker(negativeSt) 108 | } 109 | 110 | func init() { 111 | defaultCB = NewCircuitBreaker(Settings{}) 112 | customCB = newCustom() 113 | } 114 | 115 | func TestStateConstants(t *testing.T) { 116 | assert.Equal(t,
State(0), StateClosed) 117 | assert.Equal(t, State(1), StateHalfOpen) 118 | assert.Equal(t, State(2), StateOpen) 119 | 120 | assert.Equal(t, StateClosed.String(), "closed") 121 | assert.Equal(t, StateHalfOpen.String(), "half-open") 122 | assert.Equal(t, StateOpen.String(), "open") 123 | assert.Equal(t, State(100).String(), "unknown state: 100") 124 | } 125 | 126 | func TestNewCircuitBreaker(t *testing.T) { 127 | defaultCB := NewCircuitBreaker(Settings{}) 128 | assert.Equal(t, "", defaultCB.name) 129 | assert.Equal(t, uint32(1), defaultCB.maxRequests) 130 | assert.Equal(t, time.Duration(0), defaultCB.interval) 131 | assert.Equal(t, time.Duration(60)*time.Second, defaultCB.timeout) 132 | assert.NotNil(t, defaultCB.readyToTrip) 133 | assert.Nil(t, defaultCB.onStateChange) 134 | assert.Equal(t, StateClosed, defaultCB.state) 135 | assert.Equal(t, Counts{0, 0, 0, 0, 0}, defaultCB.counts) 136 | assert.True(t, defaultCB.expiry.IsZero()) 137 | 138 | customCB := newCustom() 139 | assert.Equal(t, "cb", customCB.name) 140 | assert.Equal(t, uint32(3), customCB.maxRequests) 141 | assert.Equal(t, time.Duration(30)*time.Second, customCB.interval) 142 | assert.Equal(t, time.Duration(90)*time.Second, customCB.timeout) 143 | assert.NotNil(t, customCB.readyToTrip) 144 | assert.NotNil(t, customCB.onStateChange) 145 | assert.Equal(t, StateClosed, customCB.state) 146 | assert.Equal(t, Counts{0, 0, 0, 0, 0}, customCB.counts) 147 | assert.False(t, customCB.expiry.IsZero()) 148 | 149 | negativeDurationCB := newNegativeDurationCB() 150 | assert.Equal(t, "ncb", negativeDurationCB.name) 151 | assert.Equal(t, uint32(1), negativeDurationCB.maxRequests) 152 | assert.Equal(t, time.Duration(0)*time.Second, negativeDurationCB.interval) 153 | assert.Equal(t, time.Duration(60)*time.Second, negativeDurationCB.timeout) 154 | assert.NotNil(t, negativeDurationCB.readyToTrip) 155 | assert.Nil(t, negativeDurationCB.onStateChange) 156 | assert.Equal(t, StateClosed, negativeDurationCB.state) 157 |
assert.Equal(t, Counts{0, 0, 0, 0, 0}, negativeDurationCB.counts) 158 | assert.True(t, negativeDurationCB.expiry.IsZero()) 159 | } 160 | 161 | func TestDefaultCircuitBreaker(t *testing.T) { 162 | assert.Equal(t, "", defaultCB.Name()) 163 | 164 | for i := 0; i < 5; i++ { 165 | assert.Nil(t, fail(defaultCB)) 166 | } 167 | assert.Equal(t, StateClosed, defaultCB.State()) 168 | assert.Equal(t, Counts{5, 0, 5, 0, 5}, defaultCB.counts) 169 | 170 | assert.Nil(t, succeed(defaultCB)) 171 | assert.Equal(t, StateClosed, defaultCB.State()) 172 | assert.Equal(t, Counts{6, 1, 5, 1, 0}, defaultCB.counts) 173 | 174 | assert.Nil(t, fail(defaultCB)) 175 | assert.Equal(t, StateClosed, defaultCB.State()) 176 | assert.Equal(t, Counts{7, 1, 6, 0, 1}, defaultCB.counts) 177 | 178 | // StateClosed to StateOpen 179 | for i := 0; i < 5; i++ { 180 | assert.Nil(t, fail(defaultCB)) // 6 consecutive failures 181 | } 182 | assert.Equal(t, StateOpen, defaultCB.State()) 183 | assert.Equal(t, Counts{0, 0, 0, 0, 0}, defaultCB.counts) 184 | assert.False(t, defaultCB.expiry.IsZero()) 185 | 186 | assert.Error(t, succeed(defaultCB)) 187 | assert.Error(t, fail(defaultCB)) 188 | assert.Equal(t, Counts{0, 0, 0, 0, 0}, defaultCB.counts) 189 | 190 | pseudoSleep(defaultCB, time.Duration(59)*time.Second) 191 | assert.Equal(t, StateOpen, defaultCB.State()) 192 | 193 | // StateOpen to StateHalfOpen 194 | pseudoSleep(defaultCB, time.Duration(1)*time.Second) // over Timeout 195 | assert.Equal(t, StateHalfOpen, defaultCB.State()) 196 | assert.True(t, defaultCB.expiry.IsZero()) 197 | 198 | // StateHalfOpen to StateOpen 199 | assert.Nil(t, fail(defaultCB)) 200 | assert.Equal(t, StateOpen, defaultCB.State()) 201 | assert.Equal(t, Counts{0, 0, 0, 0, 0}, defaultCB.counts) 202 | assert.False(t, defaultCB.expiry.IsZero()) 203 | 204 | // StateOpen to StateHalfOpen 205 | pseudoSleep(defaultCB, time.Duration(60)*time.Second) 206 | assert.Equal(t, StateHalfOpen, defaultCB.State()) 207 | assert.True(t, defaultCB.expiry.IsZero())
208 | 209 | // StateHalfOpen to StateClosed 210 | assert.Nil(t, succeed(defaultCB)) 211 | assert.Equal(t, StateClosed, defaultCB.State()) 212 | assert.Equal(t, Counts{0, 0, 0, 0, 0}, defaultCB.counts) 213 | assert.True(t, defaultCB.expiry.IsZero()) 214 | } 215 | 216 | func TestCustomCircuitBreaker(t *testing.T) { 217 | assert.Equal(t, "cb", customCB.Name()) 218 | 219 | for i := 0; i < 5; i++ { 220 | assert.Nil(t, succeed(customCB)) 221 | assert.Nil(t, fail(customCB)) 222 | } 223 | assert.Equal(t, StateClosed, customCB.State()) 224 | assert.Equal(t, Counts{10, 5, 5, 0, 1}, customCB.counts) 225 | 226 | pseudoSleep(customCB, time.Duration(29)*time.Second) 227 | assert.Nil(t, succeed(customCB)) 228 | assert.Equal(t, StateClosed, customCB.State()) 229 | assert.Equal(t, Counts{11, 6, 5, 1, 0}, customCB.counts) 230 | 231 | pseudoSleep(customCB, time.Duration(1)*time.Second) // over Interval 232 | assert.Nil(t, fail(customCB)) 233 | assert.Equal(t, StateClosed, customCB.State()) 234 | assert.Equal(t, Counts{1, 0, 1, 0, 1}, customCB.counts) 235 | 236 | // StateClosed to StateOpen 237 | assert.Nil(t, succeed(customCB)) 238 | assert.Nil(t, fail(customCB)) // failure ratio: 2/3 >= 0.6 239 | assert.Equal(t, StateOpen, customCB.State()) 240 | assert.Equal(t, Counts{0, 0, 0, 0, 0}, customCB.counts) 241 | assert.False(t, customCB.expiry.IsZero()) 242 | assert.Equal(t, StateChange{"cb", StateClosed, StateOpen}, stateChange) 243 | 244 | // StateOpen to StateHalfOpen 245 | pseudoSleep(customCB, time.Duration(90)*time.Second) 246 | assert.Equal(t, StateHalfOpen, customCB.State()) 247 | assert.True(t, customCB.expiry.IsZero()) 248 | assert.Equal(t, StateChange{"cb", StateOpen, StateHalfOpen}, stateChange) 249 | 250 | assert.Nil(t, succeed(customCB)) 251 | assert.Nil(t, succeed(customCB)) 252 | assert.Equal(t, StateHalfOpen, customCB.State()) 253 | assert.Equal(t, Counts{2, 2, 0, 2, 0}, customCB.counts) 254 | 255 | // StateHalfOpen to StateClosed 256 | ch := succeedLater(customCB,
time.Duration(100)*time.Millisecond) // 3 consecutive successes 257 | time.Sleep(time.Duration(50) * time.Millisecond) 258 | assert.Equal(t, Counts{3, 2, 0, 2, 0}, customCB.counts) 259 | assert.Error(t, succeed(customCB)) // over MaxRequests 260 | assert.Nil(t, <-ch) 261 | assert.Equal(t, StateClosed, customCB.State()) 262 | assert.Equal(t, Counts{0, 0, 0, 0, 0}, customCB.counts) 263 | assert.False(t, customCB.expiry.IsZero()) 264 | assert.Equal(t, StateChange{"cb", StateHalfOpen, StateClosed}, stateChange) 265 | } 266 | 267 | func TestTwoStepCircuitBreaker(t *testing.T) { 268 | tscb := NewTwoStepCircuitBreaker(Settings{Name: "tscb"}) 269 | assert.Equal(t, "tscb", tscb.Name()) 270 | 271 | for i := 0; i < 5; i++ { 272 | assert.Nil(t, fail2Step(tscb)) 273 | } 274 | 275 | assert.Equal(t, StateClosed, tscb.State()) 276 | assert.Equal(t, Counts{5, 0, 5, 0, 5}, tscb.cb.counts) 277 | 278 | assert.Nil(t, succeed2Step(tscb)) 279 | assert.Equal(t, StateClosed, tscb.State()) 280 | assert.Equal(t, Counts{6, 1, 5, 1, 0}, tscb.cb.counts) 281 | 282 | assert.Nil(t, fail2Step(tscb)) 283 | assert.Equal(t, StateClosed, tscb.State()) 284 | assert.Equal(t, Counts{7, 1, 6, 0, 1}, tscb.cb.counts) 285 | 286 | // StateClosed to StateOpen 287 | for i := 0; i < 5; i++ { 288 | assert.Nil(t, fail2Step(tscb)) // 6 consecutive failures 289 | } 290 | assert.Equal(t, StateOpen, tscb.State()) 291 | assert.Equal(t, Counts{0, 0, 0, 0, 0}, tscb.cb.counts) 292 | assert.False(t, tscb.cb.expiry.IsZero()) 293 | 294 | assert.Error(t, succeed2Step(tscb)) 295 | assert.Error(t, fail2Step(tscb)) 296 | assert.Equal(t, Counts{0, 0, 0, 0, 0}, tscb.cb.counts) 297 | 298 | pseudoSleep(tscb.cb, time.Duration(59)*time.Second) 299 | assert.Equal(t, StateOpen, tscb.State()) 300 | 301 | // StateOpen to StateHalfOpen 302 | pseudoSleep(tscb.cb, time.Duration(1)*time.Second) // over Timeout 303 | assert.Equal(t, StateHalfOpen, tscb.State()) 304 | assert.True(t, tscb.cb.expiry.IsZero()) 305 | 306 | // StateHalfOpen to
StateOpen 307 | assert.Nil(t, fail2Step(tscb)) 308 | assert.Equal(t, StateOpen, tscb.State()) 309 | assert.Equal(t, Counts{0, 0, 0, 0, 0}, tscb.cb.counts) 310 | assert.False(t, tscb.cb.expiry.IsZero()) 311 | 312 | // StateOpen to StateHalfOpen 313 | pseudoSleep(tscb.cb, time.Duration(60)*time.Second) 314 | assert.Equal(t, StateHalfOpen, tscb.State()) 315 | assert.True(t, tscb.cb.expiry.IsZero()) 316 | 317 | // StateHalfOpen to StateClosed 318 | assert.Nil(t, succeed2Step(tscb)) 319 | assert.Equal(t, StateClosed, tscb.State()) 320 | assert.Equal(t, Counts{0, 0, 0, 0, 0}, tscb.cb.counts) 321 | assert.True(t, tscb.cb.expiry.IsZero()) 322 | } 323 | 324 | func TestPanicInRequest(t *testing.T) { 325 | assert.Panics(t, func() { _ = causePanic(defaultCB) }) 326 | assert.Equal(t, Counts{1, 0, 1, 0, 1}, defaultCB.counts) 327 | } 328 | 329 | func TestGeneration(t *testing.T) { 330 | pseudoSleep(customCB, time.Duration(29)*time.Second) 331 | assert.Nil(t, succeed(customCB)) 332 | ch := succeedLater(customCB, time.Duration(1500)*time.Millisecond) 333 | time.Sleep(time.Duration(500) * time.Millisecond) 334 | assert.Equal(t, Counts{2, 1, 0, 1, 0}, customCB.counts) 335 | 336 | time.Sleep(time.Duration(500) * time.Millisecond) // over Interval 337 | assert.Equal(t, StateClosed, customCB.State()) 338 | assert.Equal(t, Counts{0, 0, 0, 0, 0}, customCB.counts) 339 | 340 | // the request from the previous generation has no effect on customCB.counts 341 | assert.Nil(t, <-ch) 342 | assert.Equal(t, Counts{0, 0, 0, 0, 0}, customCB.counts) 343 | } 344 | 345 | func TestCustomIsSuccessful(t *testing.T) { 346 | isSuccessful := func(error) bool { 347 | return true 348 | } 349 | cb := NewCircuitBreaker(Settings{IsSuccessful: isSuccessful}) 350 | 351 | for i := 0; i < 5; i++ { 352 | assert.Nil(t, fail(cb)) 353 | } 354 | assert.Equal(t, StateClosed, cb.State()) 355 | assert.Equal(t, Counts{5, 5, 0, 5, 0}, cb.counts) 356 | 357 | cb.counts.clear() 358 | 359 | cb.isSuccessful = func(err error) bool {
360 | return err == nil 361 | } 362 | for i := 0; i < 6; i++ { 363 | assert.Nil(t, fail(cb)) 364 | } 365 | assert.Equal(t, StateOpen, cb.State()) 366 | 367 | } 368 | 369 | func TestCircuitBreakerInParallel(t *testing.T) { 370 | runtime.GOMAXPROCS(runtime.NumCPU()) 371 | 372 | ch := make(chan error) 373 | 374 | const numReqs = 10000 375 | routine := func() { 376 | for i := 0; i < numReqs; i++ { 377 | ch <- succeed(customCB) 378 | } 379 | } 380 | 381 | const numRoutines = 10 382 | for i := 0; i < numRoutines; i++ { 383 | go routine() 384 | } 385 | 386 | total := uint32(numReqs * numRoutines) 387 | for i := uint32(0); i < total; i++ { 388 | err := <-ch 389 | assert.Nil(t, err) 390 | } 391 | assert.Equal(t, Counts{total, total, 0, total, 0}, customCB.counts) 392 | } 393 | -------------------------------------------------------------------------------- /v2/distributed_gobreaker.go: -------------------------------------------------------------------------------- 1 | package gobreaker 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "time" 7 | ) 8 | 9 | var ( 10 | // ErrNoSharedStore is returned when there is no shared store. 11 | ErrNoSharedStore = errors.New("no shared store") 12 | // ErrNoSharedState is returned when there is no shared state. 13 | ErrNoSharedState = errors.New("no shared state") 14 | ) 15 | 16 | // SharedState represents the shared state of DistributedCircuitBreaker. 17 | type SharedState struct { 18 | State State `json:"state"` 19 | Generation uint64 `json:"generation"` 20 | Counts Counts `json:"counts"` 21 | Expiry time.Time `json:"expiry"` 22 | } 23 | 24 | // SharedDataStore stores the shared state of DistributedCircuitBreaker. 25 | type SharedDataStore interface { 26 | Lock(name string) error 27 | Unlock(name string) error 28 | GetData(name string) ([]byte, error) 29 | SetData(name string, data []byte) error 30 | } 31 | 32 | // DistributedCircuitBreaker extends CircuitBreaker with SharedDataStore. 
33 | type DistributedCircuitBreaker[T any] struct { 34 | *CircuitBreaker[T] 35 | store SharedDataStore 36 | } 37 | 38 | // NewDistributedCircuitBreaker returns a new DistributedCircuitBreaker. If store holds no shared state for this breaker yet, the breaker's initial state is written to it. 39 | func NewDistributedCircuitBreaker[T any](store SharedDataStore, settings Settings) (dcb *DistributedCircuitBreaker[T], err error) { 40 | if store == nil { 41 | return nil, ErrNoSharedStore 42 | } 43 | 44 | dcb = &DistributedCircuitBreaker[T]{ 45 | CircuitBreaker: NewCircuitBreaker[T](settings), 46 | store: store, 47 | } 48 | 49 | err = dcb.lock() 50 | if err != nil { 51 | return nil, err 52 | } 53 | defer func(locked *DistributedCircuitBreaker[T]) { 54 | e := locked.unlock() 55 | if err == nil { 56 | err = e 57 | } 58 | }(dcb) // bind dcb now: "return nil, err" below sets the named result to nil before this deferred call runs 59 | 60 | _, err = dcb.getSharedState() 61 | if errors.Is(err, ErrNoSharedState) { 62 | err = dcb.setSharedState(dcb.extract()) 63 | } 64 | if err != nil { 65 | return nil, err 66 | } 67 | 68 | return dcb, nil 69 | } 70 | 71 | const ( 72 | mutexTimeout = 5 * time.Second 73 | mutexWaitTime = 500 * time.Millisecond 74 | ) 75 | 76 | func (dcb *DistributedCircuitBreaker[T]) mutexKey() string { 77 | return "gobreaker:mutex:" + dcb.name 78 | } 79 | 80 | func (dcb *DistributedCircuitBreaker[T]) lock() error { // retries store.Lock every mutexWaitTime until mutexTimeout has elapsed, then returns the last error 81 | if dcb.store == nil { 82 | return ErrNoSharedStore 83 | } 84 | 85 | var err error 86 | expiry := time.Now().Add(mutexTimeout) 87 | for time.Now().Before(expiry) { 88 | err = dcb.store.Lock(dcb.mutexKey()) 89 | if err == nil { 90 | return nil 91 | } 92 | 93 | time.Sleep(mutexWaitTime) 94 | } 95 | return err 96 | } 97 | 98 | func (dcb *DistributedCircuitBreaker[T]) unlock() error { 99 | if dcb.store == nil { 100 | return ErrNoSharedStore 101 | } 102 | 103 | return 
dcb.store.Unlock(dcb.mutexKey()) 104 | } 105 | 106 | func (dcb *DistributedCircuitBreaker[T]) sharedStateKey() string { 107 | return "gobreaker:state:" + dcb.name 108 | } 109 | 110 | func (dcb *DistributedCircuitBreaker[T]) getSharedState() (SharedState, error) { 111 | var state SharedState 112 | if dcb.store == nil { 113 | return state, 
ErrNoSharedStore 114 | } 115 | 116 | data, err := dcb.store.GetData(dcb.sharedStateKey()) 117 | if len(data) == 0 { // NOTE(review): checked before err, so a failed GetData that returns no data is reported as ErrNoSharedState (presumably to map a store's key-not-found error); confirm 118 | return state, ErrNoSharedState 119 | } else if err != nil { 120 | return state, err 121 | } 122 | 123 | err = json.Unmarshal(data, &state) 124 | return state, err 125 | } 126 | 127 | func (dcb *DistributedCircuitBreaker[T]) setSharedState(state SharedState) error { 128 | if dcb.store == nil { 129 | return ErrNoSharedStore 130 | } 131 | 132 | data, err := json.Marshal(state) 133 | if err != nil { 134 | return err 135 | } 136 | 137 | return dcb.store.SetData(dcb.sharedStateKey(), data) 138 | } 139 | 140 | func (dcb *DistributedCircuitBreaker[T]) inject(shared SharedState) { // copies the shared state into the local CircuitBreaker 141 | dcb.mutex.Lock() 142 | defer dcb.mutex.Unlock() 143 | 144 | dcb.state = shared.State 145 | dcb.generation = shared.Generation 146 | dcb.counts = shared.Counts 147 | dcb.expiry = shared.Expiry 148 | } 149 | 150 | func (dcb *DistributedCircuitBreaker[T]) extract() SharedState { // snapshots the local CircuitBreaker's state for the shared store 151 | dcb.mutex.Lock() 152 | defer dcb.mutex.Unlock() 153 | 154 | return SharedState{ 155 | State: dcb.state, 156 | Generation: dcb.generation, 157 | Counts: dcb.counts, 158 | Expiry: dcb.expiry, 159 | } 160 | } 161 | 162 | // State returns the State of DistributedCircuitBreaker. The shared state is read, advanced and written back while the distributed lock is held. 163 | func (dcb *DistributedCircuitBreaker[T]) State() (state State, err error) { 164 | err = dcb.lock() 165 | if err != nil { 166 | return state, err 167 | } 168 | defer func() { 169 | e := dcb.unlock() 170 | if err == nil { 171 | err = e 172 | } 173 | }() 174 | 175 | shared, err := dcb.getSharedState() // read under the lock so a concurrent update cannot be overwritten with stale data 
176 | if err != nil { 177 | return shared.State, err 178 | } 179 | 180 | dcb.inject(shared) 181 | state = dcb.CircuitBreaker.State() 182 | shared = dcb.extract() 183 | 184 | err = dcb.setSharedState(shared) 185 | return state, err 186 | } 187 | 188 | // Execute runs the given request if the DistributedCircuitBreaker accepts it.
189 | func (dcb *DistributedCircuitBreaker[T]) Execute(req func() (T, error)) (t T, err error) { 190 | err = dcb.lock() 191 | if err != nil { 192 | return t, err 193 | } 194 | defer func() { 195 | e := dcb.unlock() 196 | if err == nil { 197 | err = e 198 | } 199 | }() 200 | 201 | shared, err := dcb.getSharedState() // read under the lock so a concurrent update cannot be overwritten with stale data 202 | if err != nil { 203 | return t, err 204 | } 205 | 206 | dcb.inject(shared) 207 | t, err = dcb.CircuitBreaker.Execute(req) // NOTE(review): req runs while the distributed lock is held, and a panic in req skips the write-back below 208 | shared = dcb.extract() 209 | 210 | e := dcb.setSharedState(shared) 211 | if e != nil { 212 | return t, e 213 | } 214 | 215 | return t, err 216 | } 217 | -------------------------------------------------------------------------------- /v2/distributed_gobreaker_test.go: -------------------------------------------------------------------------------- 1 | package gobreaker 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | "time" 7 | 8 | "github.com/alicebob/miniredis/v2" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | var redisServer *miniredis.Miniredis 13 | 14 | func setUpDCB() *DistributedCircuitBreaker[any] { 15 | var err error 16 | redisServer, err = miniredis.Run() // assign the package-level server so tearDownDCB can close it 17 | if err != nil { 18 | panic(err) 19 | } 20 | 21 | store := NewRedisStore(redisServer.Addr()) 22 | 23 | dcb, err := NewDistributedCircuitBreaker[any](store, Settings{ 24 | Name: "TestBreaker", 25 | MaxRequests: 3, 26 | Interval: time.Second, 27 | Timeout: time.Second * 2, 28 | ReadyToTrip: func(counts Counts) bool { 29 | return counts.ConsecutiveFailures > 5 30 | }, 31 | }) 32 | if err != nil { 33 | panic(err) 34 | } 35 | 36 | return dcb 37 | } 38 | 39 | func tearDownDCB(dcb *DistributedCircuitBreaker[any]) { 40 | if dcb != nil { 41 | store := dcb.store.(*RedisStore) 42 | store.Close() 43 | } 44 | 
45 | if redisServer != nil { 46 | redisServer.Close() 47 | redisServer = nil 48 | } 49 | } 50 | 51 | func dcbPseudoSleep(dcb *DistributedCircuitBreaker[any], period time.Duration) { 52 | state, err := dcb.getSharedState() 53 | if err != nil { 54 | 
panic(err) 55 | } 56 | 57 | state.Expiry = state.Expiry.Add(-period) 58 | // Reset counts if the interval has passed 59 | if time.Now().After(state.Expiry) { 60 | state.Counts = Counts{} 61 | } 62 | 63 | err = dcb.setSharedState(state) 64 | if err != nil { 65 | panic(err) 66 | } 67 | } 68 | 69 | func successRequest(dcb *DistributedCircuitBreaker[any]) error { 70 | _, err := dcb.Execute(func() (interface{}, error) { return nil, nil }) 71 | return err 72 | } 73 | 74 | func failRequest(dcb *DistributedCircuitBreaker[any]) error { 75 | _, err := dcb.Execute(func() (interface{}, error) { return nil, errors.New("fail") }) 76 | if err != nil && err.Error() == "fail" { 77 | return nil 78 | } 79 | return err 80 | } 81 | 82 | func assertState(t *testing.T, dcb *DistributedCircuitBreaker[any], expected State) { 83 | state, err := dcb.State() 84 | assert.Equal(t, expected, state) 85 | assert.NoError(t, err) 86 | } 87 | 88 | func TestDistributedCircuitBreakerInitialization(t *testing.T) { 89 | dcb := setUpDCB() 90 | defer tearDownDCB(dcb) 91 | 92 | assert.Equal(t, "TestBreaker", dcb.Name()) 93 | assert.Equal(t, uint32(3), dcb.maxRequests) 94 | assert.Equal(t, time.Second, dcb.interval) 95 | assert.Equal(t, time.Second*2, dcb.timeout) 96 | assert.NotNil(t, dcb.readyToTrip) 97 | 98 | assertState(t, dcb, StateClosed) 99 | } 100 | 101 | func TestDistributedCircuitBreakerStateTransitions(t *testing.T) { 102 | dcb := setUpDCB() 103 | defer tearDownDCB(dcb) 104 | 105 | // Check if initial state is closed 106 | assertState(t, dcb, StateClosed) 107 | 108 | // StateClosed to StateOpen 109 | for i := 0; i < 6; i++ { 110 | assert.NoError(t, failRequest(dcb)) 111 | } 112 | assertState(t, dcb, StateOpen) 113 | 114 | // Ensure requests fail when the circuit is open 115 | err := failRequest(dcb) 116 | assert.Equal(t, ErrOpenState, err) 117 | 118 | // Wait for timeout so that the state will move to half-open 119 | dcbPseudoSleep(dcb, dcb.timeout) 120 | assertState(t, dcb, StateHalfOpen) 121 | 
122 | // StateHalfOpen to StateClosed 123 | for i := 0; i < int(dcb.maxRequests); i++ { 124 | assert.NoError(t, successRequest(dcb)) 125 | } 126 | assertState(t, dcb, StateClosed) 127 | 128 | // StateClosed to StateOpen (again) 129 | for i := 0; i < 6; i++ { 130 | assert.NoError(t, failRequest(dcb)) 131 | } 132 | assertState(t, dcb, StateOpen) 133 | } 134 | 135 | func TestDistributedCircuitBreakerExecution(t *testing.T) { 136 | dcb := setUpDCB() 137 | defer tearDownDCB(dcb) 138 | 139 | // Test successful execution 140 | result, err := dcb.Execute(func() (interface{}, error) { 141 | return "success", nil 142 | }) 143 | assert.NoError(t, err) 144 | assert.Equal(t, "success", result) 145 | 146 | // Test failed execution 147 | _, err = dcb.Execute(func() (interface{}, error) { 148 | return nil, errors.New("test error") 149 | }) 150 | assert.Error(t, err) 151 | assert.Equal(t, "test error", err.Error()) 152 | } 153 | 154 | func TestDistributedCircuitBreakerCounts(t *testing.T) { 155 | dcb := setUpDCB() 156 | defer tearDownDCB(dcb) 157 | 158 | for i := 0; i < 5; i++ { 159 | assert.Nil(t, successRequest(dcb)) 160 | } 161 | 162 | state, err := dcb.getSharedState() 163 | assert.Equal(t, Counts{5, 5, 0, 5, 0}, state.Counts) 164 | assert.NoError(t, err) 165 | 166 | assert.Nil(t, failRequest(dcb)) 167 | state, err = dcb.getSharedState() 168 | assert.Equal(t, Counts{6, 5, 1, 0, 1}, state.Counts) 169 | assert.NoError(t, err) 170 | } 171 | 172 | var customDCB *DistributedCircuitBreaker[any] 173 | 174 | func TestCustomDistributedCircuitBreaker(t *testing.T) { 175 | mr, err := miniredis.Run() 176 | if err != nil { 177 | panic(err) 178 | } 179 | defer mr.Close() 180 | 181 | store := NewRedisStore(mr.Addr()) 182 | 183 | customDCB, err = NewDistributedCircuitBreaker[any](store, Settings{ 184 | Name: "CustomBreaker", 185 | MaxRequests: 3, 186 | Interval: time.Second * 30, 187 | Timeout: time.Second * 90, 188 | ReadyToTrip: func(counts Counts) bool { 189 | numReqs := counts.Requests 190
| failureRatio := float64(counts.TotalFailures) / float64(numReqs) 191 | return numReqs >= 3 && failureRatio >= 0.6 192 | }, 193 | }) 194 | assert.NoError(t, err) 195 | 196 | t.Run("Initialization", func(t *testing.T) { 197 | assert.Equal(t, "CustomBreaker", customDCB.Name()) 198 | assertState(t, customDCB, StateClosed) 199 | }) 200 | 201 | t.Run("Counts and State Transitions", func(t *testing.T) { 202 | // Perform 5 successful and 5 failed requests 203 | for i := 0; i < 5; i++ { 204 | assert.NoError(t, successRequest(customDCB)) 205 | assert.NoError(t, failRequest(customDCB)) 206 | } 207 | 208 | state, err := customDCB.getSharedState() 209 | assert.NoError(t, err) 210 | assert.Equal(t, StateClosed, state.State) 211 | assert.Equal(t, Counts{10, 5, 5, 0, 1}, state.Counts) 212 | 213 | // Perform one more successful request 214 | assert.NoError(t, successRequest(customDCB)) 215 | state, err = customDCB.getSharedState() 216 | assert.NoError(t, err) 217 | assert.Equal(t, Counts{11, 6, 5, 1, 0}, state.Counts) 218 | 219 | // Simulate time passing to reset counts 220 | dcbPseudoSleep(customDCB, time.Second*30) 221 | 222 | // Perform requests to trigger StateOpen 223 | assert.NoError(t, successRequest(customDCB)) 224 | assert.NoError(t, failRequest(customDCB)) 225 | assert.NoError(t, failRequest(customDCB)) 226 | 227 | // Check if the circuit breaker is now open 228 | assertState(t, customDCB, StateOpen) 229 | 230 | state, err = customDCB.getSharedState() 231 | assert.NoError(t, err) 232 | assert.Equal(t, Counts{0, 0, 0, 0, 0}, state.Counts) 233 | }) 234 | 235 | t.Run("Timeout and Half-Open State", func(t *testing.T) { 236 | // Simulate timeout to transition to half-open state 237 | dcbPseudoSleep(customDCB, time.Second*90) 238 | assertState(t, customDCB, StateHalfOpen) 239 | 240 | // Successful requests in half-open state should close the circuit 241 | for i := 0; i < 3; i++ { 242 | assert.NoError(t, successRequest(customDCB)) 243 | } 244 | assertState(t, customDCB, 
StateClosed) 245 | }) 246 | } 247 | 248 | func TestCustomDistributedCircuitBreakerStateTransitions(t *testing.T) { 249 | // Setup 250 | var stateChange StateChange 251 | customSt := Settings{ 252 | Name: "cb", 253 | MaxRequests: 3, 254 | Interval: 5 * time.Second, 255 | Timeout: 5 * time.Second, 256 | ReadyToTrip: func(counts Counts) bool { 257 | return counts.ConsecutiveFailures >= 2 258 | }, 259 | OnStateChange: func(name string, from State, to State) { 260 | stateChange = StateChange{name, from, to} 261 | }, 262 | } 263 | 264 | mr, err := miniredis.Run() 265 | if err != nil { 266 | t.Fatalf("Failed to start miniredis: %v", err) 267 | } 268 | defer mr.Close() 269 | 270 | store := NewRedisStore(mr.Addr()) 271 | 272 | dcb, err := NewDistributedCircuitBreaker[any](store, customSt) 273 | assert.NoError(t, err) 274 | 275 | // Test case 276 | t.Run("Circuit Breaker State Transitions", func(t *testing.T) { 277 | // Initial state should be Closed 278 | assertState(t, dcb, StateClosed) 279 | 280 | // Cause two consecutive failures to trip the circuit 281 | for i := 0; i < 2; i++ { 282 | err := failRequest(dcb) 283 | assert.NoError(t, err, "Fail request should not return an error") 284 | } 285 | 286 | // Circuit should now be Open 287 | assertState(t, dcb, StateOpen) 288 | assert.Equal(t, StateChange{"cb", StateClosed, StateOpen}, stateChange) 289 | 290 | // Requests should fail immediately when circuit is Open 291 | err := successRequest(dcb) 292 | assert.Error(t, err) 293 | assert.Equal(t, ErrOpenState, err) 294 | 295 | // Simulate timeout to transition to Half-Open 296 | dcbPseudoSleep(dcb, 6*time.Second) 297 | assertState(t, dcb, StateHalfOpen) 298 | assert.Equal(t, StateChange{"cb", StateOpen, StateHalfOpen}, stateChange) 299 | 300 | // Successful requests in Half-Open state should close the circuit 301 | for i := 0; i < int(dcb.maxRequests); i++ { 302 | err := successRequest(dcb) 303 | assert.NoError(t, err) 304 | } 305 | 306 | // Circuit should now be Closed 307 | 
assertState(t, dcb, StateClosed) 308 | assert.Equal(t, StateChange{"cb", StateHalfOpen, StateClosed}, stateChange) 309 | }) 310 | } 311 | -------------------------------------------------------------------------------- /v2/example/http_breaker.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "log" 7 | "net/http" 8 | 9 | "github.com/sony/gobreaker/v2" 10 | ) 11 | 12 | var cb *gobreaker.CircuitBreaker[[]byte] 13 | 14 | func init() { 15 | var st gobreaker.Settings 16 | st.Name = "HTTP GET" 17 | st.ReadyToTrip = func(counts gobreaker.Counts) bool { 18 | failureRatio := float64(counts.TotalFailures) / float64(counts.Requests) 19 | return counts.Requests >= 3 && failureRatio >= 0.6 20 | } 21 | 22 | cb = gobreaker.NewCircuitBreaker[[]byte](st) 23 | } 24 | 25 | // Get wraps http.Get in CircuitBreaker. 26 | func Get(url string) ([]byte, error) { 27 | body, err := cb.Execute(func() ([]byte, error) { 28 | resp, err := http.Get(url) 29 | if err != nil { 30 | return nil, err 31 | } 32 | 33 | defer resp.Body.Close() 34 | body, err := io.ReadAll(resp.Body) 35 | if err != nil { 36 | return nil, err 37 | } 38 | 39 | return body, nil 40 | }) 41 | if err != nil { 42 | return nil, err 43 | } 44 | 45 | return body, nil 46 | } 47 | 48 | func main() { 49 | body, err := Get("http://www.google.com/robots.txt") 50 | if err != nil { 51 | log.Fatal(err) 52 | } 53 | 54 | fmt.Println(string(body)) 55 | } 56 | -------------------------------------------------------------------------------- /v2/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/sony/gobreaker/v2 2 | 3 | go 1.22.0 4 | 5 | toolchain go1.22.10 6 | 7 | require ( 8 | github.com/alicebob/miniredis/v2 v2.33.0 9 | github.com/go-redsync/redsync/v4 v4.13.0 10 | github.com/redis/go-redis/v9 v9.7.3 11 | github.com/stretchr/testify v1.8.4 12 | ) 13 | 14 | require ( 15 | 
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect 16 | github.com/cespare/xxhash/v2 v2.2.0 // indirect 17 | github.com/davecgh/go-spew v1.1.1 // indirect 18 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect 19 | github.com/hashicorp/errwrap v1.1.0 // indirect 20 | github.com/hashicorp/go-multierror v1.1.1 // indirect 21 | github.com/pmezard/go-difflib v1.0.0 // indirect 22 | github.com/yuin/gopher-lua v1.1.1 // indirect 23 | gopkg.in/yaml.v3 v3.0.1 // indirect 24 | ) 25 | -------------------------------------------------------------------------------- /v2/go.sum: -------------------------------------------------------------------------------- 1 | github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk= 2 | github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= 3 | github.com/alicebob/miniredis/v2 v2.33.0 h1:uvTF0EDeu9RLnUEG27Db5I68ESoIxTiXbNUiji6lZrA= 4 | github.com/alicebob/miniredis/v2 v2.33.0/go.mod h1:MhP4a3EU7aENRi9aO+tHfTBZicLqQevyi/DJpoj6mi0= 5 | github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= 6 | github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= 7 | github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= 8 | github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= 9 | github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= 10 | github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= 11 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 12 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 13 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= 14 | 
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= 15 | github.com/go-redis/redis v6.15.9+incompatible h1:K0pv1D7EQUjfyoMql+r/jZqCLizCGKFlFgcHWWmHQjg= 16 | github.com/go-redis/redis v6.15.9+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= 17 | github.com/go-redis/redis/v7 v7.4.1 h1:PASvf36gyUpr2zdOUS/9Zqc80GbM+9BDyiJSJDDOrTI= 18 | github.com/go-redis/redis/v7 v7.4.1/go.mod h1:JDNMw23GTyLNC4GZu9njt15ctBQVn7xjRfnwdHj/Dcg= 19 | github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= 20 | github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= 21 | github.com/go-redsync/redsync/v4 v4.13.0 h1:49X6GJfnbLGaIpBBREM/zA4uIMDXKAh1NDkvQ1EkZKA= 22 | github.com/go-redsync/redsync/v4 v4.13.0/go.mod h1:HMW4Q224GZQz6x1Xc7040Yfgacukdzu7ifTDAKiyErQ= 23 | github.com/gomodule/redigo v1.8.9 h1:Sl3u+2BI/kk+VEatbj0scLdrFhjPmbxOc1myhDP41ws= 24 | github.com/gomodule/redigo v1.8.9/go.mod h1:7ArFNvsTjH8GMMzB4uy1snslv2BwmginuMs06a1uzZE= 25 | github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= 26 | github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= 27 | github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= 28 | github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= 29 | github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= 30 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 31 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 32 | github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM= 33 | github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA= 34 | github.com/redis/rueidis v1.0.19 
h1:s65oWtotzlIFN8eMPhyYwxlwLR1lUdhza2KtWprKYSo= 35 | github.com/redis/rueidis v1.0.19/go.mod h1:8B+r5wdnjwK3lTFml5VtxjzGOQAC+5UmujoD12pDrEo= 36 | github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= 37 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 38 | github.com/stvp/tempredis v0.0.0-20181119212430-b82af8480203 h1:QVqDTf3h2WHt08YuiTGPZLls0Wq99X9bWd0Q5ZSBesM= 39 | github.com/stvp/tempredis v0.0.0-20181119212430-b82af8480203/go.mod h1:oqN97ltKNihBbwlX8dLpwxCl3+HnXKV/R0e+sRLd9C8= 40 | github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= 41 | github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= 42 | golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= 43 | golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= 44 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 45 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 46 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 47 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 48 | -------------------------------------------------------------------------------- /v2/gobreaker.go: -------------------------------------------------------------------------------- 1 | // Package gobreaker implements the Circuit Breaker pattern. 2 | // See https://msdn.microsoft.com/en-us/library/dn589784.aspx. 3 | package gobreaker 4 | 5 | import ( 6 | "errors" 7 | "fmt" 8 | "sync" 9 | "time" 10 | ) 11 | 12 | // State represents the state of a CircuitBreaker: StateClosed, StateHalfOpen or StateOpen. 13 | type State int 14 | 15 | // These constants are the possible states of a CircuitBreaker.
16 | const ( 17 | StateClosed State = iota 18 | StateHalfOpen 19 | StateOpen 20 | ) 21 | 22 | var ( 23 | // ErrTooManyRequests is returned when the CircuitBreaker is half-open and the number of requests admitted in the current generation has already reached MaxRequests. 24 | ErrTooManyRequests = errors.New("too many requests") 25 | // ErrOpenState is returned when the CircuitBreaker is open and rejects the request. 26 | ErrOpenState = errors.New("circuit breaker is open") 27 | ) 28 | 29 | // String implements the fmt.Stringer interface. 30 | func (s State) String() string { 31 | switch s { 32 | case StateClosed: 33 | return "closed" 34 | case StateHalfOpen: 35 | return "half-open" 36 | case StateOpen: 37 | return "open" 38 | default: 39 | return fmt.Sprintf("unknown state: %d", s) 40 | } 41 | } 42 | 43 | // Counts holds the numbers of requests and their successes/failures. 44 | // CircuitBreaker clears the internal Counts either 45 | // on the change of the state or at the closed-state intervals. 46 | // Counts ignores the results of the requests sent before clearing. 47 | type Counts struct { 48 | Requests uint32 49 | TotalSuccesses uint32 50 | TotalFailures uint32 51 | ConsecutiveSuccesses uint32 52 | ConsecutiveFailures uint32 53 | } 54 | 55 | func (c *Counts) onRequest() { // records one admitted request 56 | c.Requests++ 57 | } 58 | 59 | func (c *Counts) onSuccess() { // a success ends any failure streak 60 | c.TotalSuccesses++ 61 | c.ConsecutiveSuccesses++ 62 | c.ConsecutiveFailures = 0 63 | } 64 | 65 | func (c *Counts) onFailure() { // a failure ends any success streak 66 | c.TotalFailures++ 67 | c.ConsecutiveFailures++ 68 | c.ConsecutiveSuccesses = 0 69 | } 70 | 71 | func (c *Counts) clear() { // zeroes every counter; called when a new generation starts 72 | c.Requests = 0 73 | c.TotalSuccesses = 0 74 | c.TotalFailures = 0 75 | c.ConsecutiveSuccesses = 0 76 | c.ConsecutiveFailures = 0 77 | } 78 | 79 | // Settings configures CircuitBreaker: 80 | // 81 | // Name is the name of the CircuitBreaker. 
82 | // 83 | // MaxRequests is the maximum number of requests allowed to pass through 84 | // when the CircuitBreaker is half-open. 85 | // If MaxRequests is 0, the CircuitBreaker allows only 1 request.
86 | // 87 | // Interval is the cyclic period of the closed state 88 | // for the CircuitBreaker to clear the internal Counts. 89 | // If Interval is less than or equal to 0, the CircuitBreaker doesn't clear internal Counts during the closed state. 90 | // 91 | // Timeout is the period of the open state, 92 | // after which the state of the CircuitBreaker becomes half-open. 93 | // If Timeout is less than or equal to 0, the timeout value of the CircuitBreaker is set to 60 seconds. 94 | // 95 | // ReadyToTrip is called with a copy of Counts whenever a request fails in the closed state. 96 | // If ReadyToTrip returns true, the CircuitBreaker will be placed into the open state. 97 | // If ReadyToTrip is nil, default ReadyToTrip is used. 98 | // Default ReadyToTrip returns true when the number of consecutive failures is more than 5. 99 | // 100 | // OnStateChange is called whenever the state of the CircuitBreaker changes. It runs while the CircuitBreaker's internal mutex is held, so it must not call methods of the same CircuitBreaker. 101 | // 102 | // IsSuccessful is called with the error returned from a request. 103 | // If IsSuccessful returns true, the error is counted as a success. 104 | // Otherwise the error is counted as a failure. 105 | // If IsSuccessful is nil, default IsSuccessful is used, which returns false for all non-nil errors. 106 | type Settings struct { 107 | Name string 108 | MaxRequests uint32 109 | Interval time.Duration 110 | Timeout time.Duration 111 | ReadyToTrip func(counts Counts) bool 112 | OnStateChange func(name string, from State, to State) 113 | IsSuccessful func(err error) bool 114 | } 115 | 116 | // CircuitBreaker is a state machine to prevent sending requests that are likely to fail.
117 | type CircuitBreaker[T any] struct { 118 | name string 119 | maxRequests uint32 120 | interval time.Duration 121 | timeout time.Duration 122 | readyToTrip func(counts Counts) bool 123 | isSuccessful func(err error) bool 124 | onStateChange func(name string, from State, to State) 125 | 126 | mutex sync.Mutex // guards the fields below 127 | state State 128 | generation uint64 // incremented by toNewGeneration; outcomes reported for an older generation are ignored 129 | counts Counts 130 | expiry time.Time // zero means no deadline (half-open, or closed with a zero interval) 131 | } 132 | 133 | // TwoStepCircuitBreaker is like CircuitBreaker but instead of surrounding a function 134 | // with the breaker functionality, it only checks whether a request can proceed and 135 | // expects the caller to report the outcome in a separate step using a callback. 136 | type TwoStepCircuitBreaker[T any] struct { 137 | cb *CircuitBreaker[T] 138 | } 139 | 140 | // NewCircuitBreaker returns a new CircuitBreaker configured with the given Settings. MaxRequests, Interval, Timeout, ReadyToTrip and IsSuccessful fall back to their defaults (see Settings) when left zero or nil. 141 | func NewCircuitBreaker[T any](st Settings) *CircuitBreaker[T] { 142 | cb := new(CircuitBreaker[T]) 143 | 144 | cb.name = st.Name 145 | cb.onStateChange = st.OnStateChange 146 | 147 | if st.MaxRequests == 0 { 148 | cb.maxRequests = 1 149 | } else { 150 | cb.maxRequests = st.MaxRequests 151 | } 152 | 153 | if st.Interval <= 0 { 154 | cb.interval = defaultInterval 155 | } else { 156 | cb.interval = st.Interval 157 | } 158 | 159 | if st.Timeout <= 0 { 160 | cb.timeout = defaultTimeout 161 | } else { 162 | cb.timeout = st.Timeout 163 | } 164 | 165 | if st.ReadyToTrip == nil { 166 | cb.readyToTrip = defaultReadyToTrip 167 | } else { 168 | cb.readyToTrip = st.ReadyToTrip 169 | } 170 | 171 | if st.IsSuccessful == nil { 172 | cb.isSuccessful = defaultIsSuccessful 173 | } else { 174 | cb.isSuccessful = st.IsSuccessful 175 | } 176 | 177 | cb.toNewGeneration(time.Now()) // starts the first 
generation in StateClosed, the zero value of State 178 | 179 | return cb 180 | } 181 | 182 | // NewTwoStepCircuitBreaker returns a new TwoStepCircuitBreaker configured with the given Settings.
183 | func NewTwoStepCircuitBreaker[T any](st Settings) *TwoStepCircuitBreaker[T] { 184 | return &TwoStepCircuitBreaker[T]{ 185 | cb: NewCircuitBreaker[T](st), 186 | } 187 | } 188 | 189 | const defaultInterval = time.Duration(0) * time.Second // 0: Counts are never cleared while closed 190 | const defaultTimeout = time.Duration(60) * time.Second // used when Settings.Timeout <= 0 191 | 192 | func defaultReadyToTrip(counts Counts) bool { // trips after more than 5 consecutive failures 193 | return counts.ConsecutiveFailures > 5 194 | } 195 | 196 | func defaultIsSuccessful(err error) bool { // only a nil error counts as a success 197 | return err == nil 198 | } 199 | 200 | // Name returns the name of the CircuitBreaker. 201 | func (cb *CircuitBreaker[T]) Name() string { 202 | return cb.name 203 | } 204 | 205 | // State returns the current state of the CircuitBreaker. 206 | func (cb *CircuitBreaker[T]) State() State { 207 | cb.mutex.Lock() 208 | defer cb.mutex.Unlock() 209 | 210 | now := time.Now() 211 | state, _ := cb.currentState(now) // may first apply a due transition (open -> half-open, or a closed-interval reset) 212 | return state 213 | } 214 | 215 | // Counts returns a copy of the internal counters. Unlike State, it does not apply due time-based transitions, so counters from an expired closed-state interval are returned until a state-checking call (State, Execute, Allow) resets them. 216 | func (cb *CircuitBreaker[T]) Counts() Counts { 217 | cb.mutex.Lock() 218 | defer cb.mutex.Unlock() 219 | 220 | return cb.counts 221 | } 222 | 223 | // Execute runs the given request if the CircuitBreaker accepts it. 224 | // Execute returns an error instantly if the CircuitBreaker rejects the request. 225 | // Otherwise, Execute returns the result of the request. 226 | // If a panic occurs in the request, the CircuitBreaker handles it as an error 227 | // and causes the same panic again. 
228 | func (cb *CircuitBreaker[T]) Execute(req func() (T, error)) (T, error) { 229 | generation, err := cb.beforeRequest() 230 | if err != nil { 231 | var defaultValue T 232 | return defaultValue, err 233 | } 234 | 235 | defer func() { // counts a panicking request as a failure, then re-panics 236 | e := recover() 237 | if e != nil { 238 | cb.afterRequest(generation, false) 239 | panic(e) 240 | } 241 | }() 242 | 243 | result, err := req() 244 | cb.afterRequest(generation, cb.isSuccessful(err)) 245 | return result, err 246 | } 247 | 248 | // Name returns the name of the TwoStepCircuitBreaker.
249 | func (tscb *TwoStepCircuitBreaker[T]) Name() string { 250 | return tscb.cb.Name() 251 | } 252 | 253 | // State returns the current state of the TwoStepCircuitBreaker. 254 | func (tscb *TwoStepCircuitBreaker[T]) State() State { 255 | return tscb.cb.State() 256 | } 257 | 258 | // Counts returns internal counters 259 | func (tscb *TwoStepCircuitBreaker[T]) Counts() Counts { 260 | return tscb.cb.Counts() 261 | } 262 | 263 | // Allow checks if a new request can proceed. It returns a callback that should be used to 264 | // register the success or failure in a separate step. If the circuit breaker doesn't allow 265 | // requests, it returns an error. 266 | func (tscb *TwoStepCircuitBreaker[T]) Allow() (done func(success bool), err error) { 267 | generation, err := tscb.cb.beforeRequest() 268 | if err != nil { 269 | return nil, err 270 | } 271 | 272 | return func(success bool) { 273 | tscb.cb.afterRequest(generation, success) 274 | }, nil 275 | } 276 | 277 | func (cb *CircuitBreaker[T]) beforeRequest() (uint64, error) { 278 | cb.mutex.Lock() 279 | defer cb.mutex.Unlock() 280 | 281 | now := time.Now() 282 | state, generation := cb.currentState(now) 283 | 284 | if state == StateOpen { 285 | return generation, ErrOpenState 286 | } else if state == StateHalfOpen && cb.counts.Requests >= cb.maxRequests { 287 | return generation, ErrTooManyRequests 288 | } 289 | 290 | cb.counts.onRequest() 291 | return generation, nil 292 | } 293 | 294 | func (cb *CircuitBreaker[T]) afterRequest(before uint64, success bool) { 295 | cb.mutex.Lock() 296 | defer cb.mutex.Unlock() 297 | 298 | now := time.Now() 299 | state, generation := cb.currentState(now) 300 | if generation != before { 301 | return 302 | } 303 | 304 | if success { 305 | cb.onSuccess(state, now) 306 | } else { 307 | cb.onFailure(state, now) 308 | } 309 | } 310 | 311 | func (cb *CircuitBreaker[T]) onSuccess(state State, now time.Time) { 312 | switch state { 313 | case StateClosed: 314 | cb.counts.onSuccess() 315 | case 
StateHalfOpen: 316 | cb.counts.onSuccess() 317 | if cb.counts.ConsecutiveSuccesses >= cb.maxRequests { 318 | cb.setState(StateClosed, now) 319 | } 320 | } 321 | } 322 | 323 | func (cb *CircuitBreaker[T]) onFailure(state State, now time.Time) { 324 | switch state { 325 | case StateClosed: 326 | cb.counts.onFailure() 327 | if cb.readyToTrip(cb.counts) { 328 | cb.setState(StateOpen, now) 329 | } 330 | case StateHalfOpen: 331 | cb.setState(StateOpen, now) 332 | } 333 | } 334 | 335 | func (cb *CircuitBreaker[T]) currentState(now time.Time) (State, uint64) { 336 | switch cb.state { 337 | case StateClosed: 338 | if !cb.expiry.IsZero() && cb.expiry.Before(now) { 339 | cb.toNewGeneration(now) 340 | } 341 | case StateOpen: 342 | if cb.expiry.Before(now) { 343 | cb.setState(StateHalfOpen, now) 344 | } 345 | } 346 | return cb.state, cb.generation 347 | } 348 | 349 | func (cb *CircuitBreaker[T]) setState(state State, now time.Time) { 350 | if cb.state == state { 351 | return 352 | } 353 | 354 | prev := cb.state 355 | cb.state = state 356 | 357 | cb.toNewGeneration(now) 358 | 359 | if cb.onStateChange != nil { 360 | cb.onStateChange(cb.name, prev, state) 361 | } 362 | } 363 | 364 | func (cb *CircuitBreaker[T]) toNewGeneration(now time.Time) { 365 | cb.generation++ 366 | cb.counts.clear() 367 | 368 | var zero time.Time 369 | switch cb.state { 370 | case StateClosed: 371 | if cb.interval == 0 { 372 | cb.expiry = zero 373 | } else { 374 | cb.expiry = now.Add(cb.interval) 375 | } 376 | case StateOpen: 377 | cb.expiry = now.Add(cb.timeout) 378 | default: // StateHalfOpen 379 | cb.expiry = zero 380 | } 381 | } 382 | -------------------------------------------------------------------------------- /v2/gobreaker_test.go: -------------------------------------------------------------------------------- 1 | package gobreaker 2 | 3 | import ( 4 | "errors" 5 | "runtime" 6 | "testing" 7 | "time" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | var defaultCB *CircuitBreaker[bool] 
13 | var customCB *CircuitBreaker[bool] 14 | 15 | type StateChange struct { 16 | name string 17 | from State 18 | to State 19 | } 20 | 21 | var stateChange StateChange 22 | 23 | func pseudoSleep(cb *CircuitBreaker[bool], period time.Duration) { 24 | if !cb.expiry.IsZero() { 25 | cb.expiry = cb.expiry.Add(-period) 26 | } 27 | } 28 | 29 | func succeed(cb *CircuitBreaker[bool]) error { 30 | _, err := cb.Execute(func() (bool, error) { return true, nil }) 31 | return err 32 | } 33 | 34 | func succeedLater(cb *CircuitBreaker[bool], delay time.Duration) <-chan error { 35 | ch := make(chan error) 36 | go func() { 37 | _, err := cb.Execute(func() (bool, error) { 38 | time.Sleep(delay) 39 | return true, nil 40 | }) 41 | ch <- err 42 | }() 43 | return ch 44 | } 45 | 46 | func succeed2Step(cb *TwoStepCircuitBreaker[bool]) error { 47 | done, err := cb.Allow() 48 | if err != nil { 49 | return err 50 | } 51 | 52 | done(true) 53 | return nil 54 | } 55 | 56 | func fail(cb *CircuitBreaker[bool]) error { 57 | msg := "fail" 58 | _, err := cb.Execute(func() (bool, error) { return false, errors.New(msg) }) 59 | if err.Error() == msg { 60 | return nil 61 | } 62 | return err 63 | } 64 | 65 | func fail2Step(cb *TwoStepCircuitBreaker[bool]) error { 66 | done, err := cb.Allow() 67 | if err != nil { 68 | return err 69 | } 70 | 71 | done(false) 72 | return nil 73 | } 74 | 75 | func causePanic(cb *CircuitBreaker[bool]) error { 76 | _, err := cb.Execute(func() (bool, error) { panic("oops"); return false, nil }) 77 | return err 78 | } 79 | 80 | func newCustom() *CircuitBreaker[bool] { 81 | var customSt Settings 82 | customSt.Name = "cb" 83 | customSt.MaxRequests = 3 84 | customSt.Interval = time.Duration(30) * time.Second 85 | customSt.Timeout = time.Duration(90) * time.Second 86 | customSt.ReadyToTrip = func(counts Counts) bool { 87 | numReqs := counts.Requests 88 | failureRatio := float64(counts.TotalFailures) / float64(numReqs) 89 | 90 | counts.clear() // no effect on customCB.counts 91 | 92 | 
return numReqs >= 3 && failureRatio >= 0.6 93 | } 94 | customSt.OnStateChange = func(name string, from State, to State) { 95 | stateChange = StateChange{name, from, to} 96 | } 97 | 98 | return NewCircuitBreaker[bool](customSt) 99 | } 100 | 101 | func newNegativeDurationCB() *CircuitBreaker[bool] { 102 | var negativeSt Settings 103 | negativeSt.Name = "ncb" 104 | negativeSt.Interval = time.Duration(-30) * time.Second 105 | negativeSt.Timeout = time.Duration(-90) * time.Second 106 | 107 | return NewCircuitBreaker[bool](negativeSt) 108 | } 109 | 110 | func init() { 111 | defaultCB = NewCircuitBreaker[bool](Settings{}) 112 | customCB = newCustom() 113 | } 114 | 115 | func TestStateConstants(t *testing.T) { 116 | assert.Equal(t, State(0), StateClosed) 117 | assert.Equal(t, State(1), StateHalfOpen) 118 | assert.Equal(t, State(2), StateOpen) 119 | 120 | assert.Equal(t, StateClosed.String(), "closed") 121 | assert.Equal(t, StateHalfOpen.String(), "half-open") 122 | assert.Equal(t, StateOpen.String(), "open") 123 | assert.Equal(t, State(100).String(), "unknown state: 100") 124 | } 125 | 126 | func TestNewCircuitBreaker(t *testing.T) { 127 | defaultCB := NewCircuitBreaker[bool](Settings{}) 128 | assert.Equal(t, "", defaultCB.name) 129 | assert.Equal(t, uint32(1), defaultCB.maxRequests) 130 | assert.Equal(t, time.Duration(0), defaultCB.interval) 131 | assert.Equal(t, time.Duration(60)*time.Second, defaultCB.timeout) 132 | assert.NotNil(t, defaultCB.readyToTrip) 133 | assert.Nil(t, defaultCB.onStateChange) 134 | assert.Equal(t, StateClosed, defaultCB.state) 135 | assert.Equal(t, Counts{0, 0, 0, 0, 0}, defaultCB.counts) 136 | assert.True(t, defaultCB.expiry.IsZero()) 137 | 138 | customCB := newCustom() 139 | assert.Equal(t, "cb", customCB.name) 140 | assert.Equal(t, uint32(3), customCB.maxRequests) 141 | assert.Equal(t, time.Duration(30)*time.Second, customCB.interval) 142 | assert.Equal(t, time.Duration(90)*time.Second, customCB.timeout) 143 | assert.NotNil(t, 
customCB.readyToTrip) 144 | assert.NotNil(t, customCB.onStateChange) 145 | assert.Equal(t, StateClosed, customCB.state) 146 | assert.Equal(t, Counts{0, 0, 0, 0, 0}, customCB.counts) 147 | assert.False(t, customCB.expiry.IsZero()) 148 | 149 | negativeDurationCB := newNegativeDurationCB() 150 | assert.Equal(t, "ncb", negativeDurationCB.name) 151 | assert.Equal(t, uint32(1), negativeDurationCB.maxRequests) 152 | assert.Equal(t, time.Duration(0)*time.Second, negativeDurationCB.interval) 153 | assert.Equal(t, time.Duration(60)*time.Second, negativeDurationCB.timeout) 154 | assert.NotNil(t, negativeDurationCB.readyToTrip) 155 | assert.Nil(t, negativeDurationCB.onStateChange) 156 | assert.Equal(t, StateClosed, negativeDurationCB.state) 157 | assert.Equal(t, Counts{0, 0, 0, 0, 0}, negativeDurationCB.counts) 158 | assert.True(t, negativeDurationCB.expiry.IsZero()) 159 | } 160 | 161 | func TestDefaultCircuitBreaker(t *testing.T) { 162 | assert.Equal(t, "", defaultCB.Name()) 163 | 164 | for i := 0; i < 5; i++ { 165 | assert.Nil(t, fail(defaultCB)) 166 | } 167 | assert.Equal(t, StateClosed, defaultCB.State()) 168 | assert.Equal(t, Counts{5, 0, 5, 0, 5}, defaultCB.counts) 169 | 170 | assert.Nil(t, succeed(defaultCB)) 171 | assert.Equal(t, StateClosed, defaultCB.State()) 172 | assert.Equal(t, Counts{6, 1, 5, 1, 0}, defaultCB.counts) 173 | 174 | assert.Nil(t, fail(defaultCB)) 175 | assert.Equal(t, StateClosed, defaultCB.State()) 176 | assert.Equal(t, Counts{7, 1, 6, 0, 1}, defaultCB.counts) 177 | 178 | // StateClosed to StateOpen 179 | for i := 0; i < 5; i++ { 180 | assert.Nil(t, fail(defaultCB)) // 6 consecutive failures 181 | } 182 | assert.Equal(t, StateOpen, defaultCB.State()) 183 | assert.Equal(t, Counts{0, 0, 0, 0, 0}, defaultCB.counts) 184 | assert.False(t, defaultCB.expiry.IsZero()) 185 | 186 | assert.Error(t, succeed(defaultCB)) 187 | assert.Error(t, fail(defaultCB)) 188 | assert.Equal(t, Counts{0, 0, 0, 0, 0}, defaultCB.counts) 189 | 190 | pseudoSleep(defaultCB, 
time.Duration(59)*time.Second) 191 | assert.Equal(t, StateOpen, defaultCB.State()) 192 | 193 | // StateOpen to StateHalfOpen 194 | pseudoSleep(defaultCB, time.Duration(1)*time.Second) // over Timeout 195 | assert.Equal(t, StateHalfOpen, defaultCB.State()) 196 | assert.True(t, defaultCB.expiry.IsZero()) 197 | 198 | // StateHalfOpen to StateOpen 199 | assert.Nil(t, fail(defaultCB)) 200 | assert.Equal(t, StateOpen, defaultCB.State()) 201 | assert.Equal(t, Counts{0, 0, 0, 0, 0}, defaultCB.counts) 202 | assert.False(t, defaultCB.expiry.IsZero()) 203 | 204 | // StateOpen to StateHalfOpen 205 | pseudoSleep(defaultCB, time.Duration(60)*time.Second) 206 | assert.Equal(t, StateHalfOpen, defaultCB.State()) 207 | assert.True(t, defaultCB.expiry.IsZero()) 208 | 209 | // StateHalfOpen to StateClosed 210 | assert.Nil(t, succeed(defaultCB)) 211 | assert.Equal(t, StateClosed, defaultCB.State()) 212 | assert.Equal(t, Counts{0, 0, 0, 0, 0}, defaultCB.counts) 213 | assert.True(t, defaultCB.expiry.IsZero()) 214 | } 215 | 216 | func TestCustomCircuitBreaker(t *testing.T) { 217 | assert.Equal(t, "cb", customCB.Name()) 218 | 219 | for i := 0; i < 5; i++ { 220 | assert.Nil(t, succeed(customCB)) 221 | assert.Nil(t, fail(customCB)) 222 | } 223 | assert.Equal(t, StateClosed, customCB.State()) 224 | assert.Equal(t, Counts{10, 5, 5, 0, 1}, customCB.counts) 225 | 226 | pseudoSleep(customCB, time.Duration(29)*time.Second) 227 | assert.Nil(t, succeed(customCB)) 228 | assert.Equal(t, StateClosed, customCB.State()) 229 | assert.Equal(t, Counts{11, 6, 5, 1, 0}, customCB.counts) 230 | 231 | pseudoSleep(customCB, time.Duration(1)*time.Second) // over Interval 232 | assert.Nil(t, fail(customCB)) 233 | assert.Equal(t, StateClosed, customCB.State()) 234 | assert.Equal(t, Counts{1, 0, 1, 0, 1}, customCB.counts) 235 | 236 | // StateClosed to StateOpen 237 | assert.Nil(t, succeed(customCB)) 238 | assert.Nil(t, fail(customCB)) // failure ratio: 2/3 >= 0.6 239 | assert.Equal(t, StateOpen, customCB.State()) 240 
| assert.Equal(t, Counts{0, 0, 0, 0, 0}, customCB.counts) 241 | assert.False(t, customCB.expiry.IsZero()) 242 | assert.Equal(t, StateChange{"cb", StateClosed, StateOpen}, stateChange) 243 | 244 | // StateOpen to StateHalfOpen 245 | pseudoSleep(customCB, time.Duration(90)*time.Second) 246 | assert.Equal(t, StateHalfOpen, customCB.State()) 247 | assert.True(t, defaultCB.expiry.IsZero()) 248 | assert.Equal(t, StateChange{"cb", StateOpen, StateHalfOpen}, stateChange) 249 | 250 | assert.Nil(t, succeed(customCB)) 251 | assert.Nil(t, succeed(customCB)) 252 | assert.Equal(t, StateHalfOpen, customCB.State()) 253 | assert.Equal(t, Counts{2, 2, 0, 2, 0}, customCB.counts) 254 | 255 | // StateHalfOpen to StateClosed 256 | ch := succeedLater(customCB, time.Duration(100)*time.Millisecond) // 3 consecutive successes 257 | time.Sleep(time.Duration(50) * time.Millisecond) 258 | assert.Equal(t, Counts{3, 2, 0, 2, 0}, customCB.counts) 259 | assert.Error(t, succeed(customCB)) // over MaxRequests 260 | assert.Nil(t, <-ch) 261 | assert.Equal(t, StateClosed, customCB.State()) 262 | assert.Equal(t, Counts{0, 0, 0, 0, 0}, customCB.counts) 263 | assert.False(t, customCB.expiry.IsZero()) 264 | assert.Equal(t, StateChange{"cb", StateHalfOpen, StateClosed}, stateChange) 265 | } 266 | 267 | func TestTwoStepCircuitBreaker(t *testing.T) { 268 | tscb := NewTwoStepCircuitBreaker[bool](Settings{Name: "tscb"}) 269 | assert.Equal(t, "tscb", tscb.Name()) 270 | 271 | for i := 0; i < 5; i++ { 272 | assert.Nil(t, fail2Step(tscb)) 273 | } 274 | 275 | assert.Equal(t, StateClosed, tscb.State()) 276 | assert.Equal(t, Counts{5, 0, 5, 0, 5}, tscb.cb.counts) 277 | 278 | assert.Nil(t, succeed2Step(tscb)) 279 | assert.Equal(t, StateClosed, tscb.State()) 280 | assert.Equal(t, Counts{6, 1, 5, 1, 0}, tscb.cb.counts) 281 | 282 | assert.Nil(t, fail2Step(tscb)) 283 | assert.Equal(t, StateClosed, tscb.State()) 284 | assert.Equal(t, Counts{7, 1, 6, 0, 1}, tscb.cb.counts) 285 | 286 | // StateClosed to StateOpen 287 | for i 
:= 0; i < 5; i++ { 288 | assert.Nil(t, fail2Step(tscb)) // 6 consecutive failures 289 | } 290 | assert.Equal(t, StateOpen, tscb.State()) 291 | assert.Equal(t, Counts{0, 0, 0, 0, 0}, tscb.cb.counts) 292 | assert.False(t, tscb.cb.expiry.IsZero()) 293 | 294 | assert.Error(t, succeed2Step(tscb)) 295 | assert.Error(t, fail2Step(tscb)) 296 | assert.Equal(t, Counts{0, 0, 0, 0, 0}, tscb.cb.counts) 297 | 298 | pseudoSleep(tscb.cb, time.Duration(59)*time.Second) 299 | assert.Equal(t, StateOpen, tscb.State()) 300 | 301 | // StateOpen to StateHalfOpen 302 | pseudoSleep(tscb.cb, time.Duration(1)*time.Second) // over Timeout 303 | assert.Equal(t, StateHalfOpen, tscb.State()) 304 | assert.True(t, tscb.cb.expiry.IsZero()) 305 | 306 | // StateHalfOpen to StateOpen 307 | assert.Nil(t, fail2Step(tscb)) 308 | assert.Equal(t, StateOpen, tscb.State()) 309 | assert.Equal(t, Counts{0, 0, 0, 0, 0}, tscb.cb.counts) 310 | assert.False(t, tscb.cb.expiry.IsZero()) 311 | 312 | // StateOpen to StateHalfOpen 313 | pseudoSleep(tscb.cb, time.Duration(60)*time.Second) 314 | assert.Equal(t, StateHalfOpen, tscb.State()) 315 | assert.True(t, tscb.cb.expiry.IsZero()) 316 | 317 | // StateHalfOpen to StateClosed 318 | assert.Nil(t, succeed2Step(tscb)) 319 | assert.Equal(t, StateClosed, tscb.State()) 320 | assert.Equal(t, Counts{0, 0, 0, 0, 0}, tscb.cb.counts) 321 | assert.True(t, tscb.cb.expiry.IsZero()) 322 | } 323 | 324 | func TestPanicInRequest(t *testing.T) { 325 | assert.Panics(t, func() { _ = causePanic(defaultCB) }) 326 | assert.Equal(t, Counts{1, 0, 1, 0, 1}, defaultCB.counts) 327 | } 328 | 329 | func TestGeneration(t *testing.T) { 330 | pseudoSleep(customCB, time.Duration(29)*time.Second) 331 | assert.Nil(t, succeed(customCB)) 332 | ch := succeedLater(customCB, time.Duration(1500)*time.Millisecond) 333 | time.Sleep(time.Duration(500) * time.Millisecond) 334 | assert.Equal(t, Counts{2, 1, 0, 1, 0}, customCB.counts) 335 | 336 | time.Sleep(time.Duration(500) * time.Millisecond) // over Interval 337 
| assert.Equal(t, StateClosed, customCB.State()) 338 | assert.Equal(t, Counts{0, 0, 0, 0, 0}, customCB.counts) 339 | 340 | // the request from the previous generation has no effect on customCB.counts 341 | assert.Nil(t, <-ch) 342 | assert.Equal(t, Counts{0, 0, 0, 0, 0}, customCB.counts) 343 | } 344 | 345 | func TestCustomIsSuccessful(t *testing.T) { 346 | isSuccessful := func(error) bool { 347 | return true 348 | } 349 | cb := NewCircuitBreaker[bool](Settings{IsSuccessful: isSuccessful}) 350 | 351 | for i := 0; i < 5; i++ { 352 | assert.Nil(t, fail(cb)) 353 | } 354 | assert.Equal(t, StateClosed, cb.State()) 355 | assert.Equal(t, Counts{5, 5, 0, 5, 0}, cb.counts) 356 | 357 | cb.counts.clear() 358 | 359 | cb.isSuccessful = func(err error) bool { 360 | return err == nil 361 | } 362 | for i := 0; i < 6; i++ { 363 | assert.Nil(t, fail(cb)) 364 | } 365 | assert.Equal(t, StateOpen, cb.State()) 366 | 367 | } 368 | 369 | func TestCircuitBreakerInParallel(t *testing.T) { 370 | runtime.GOMAXPROCS(runtime.NumCPU()) 371 | 372 | ch := make(chan error) 373 | 374 | const numReqs = 10000 375 | routine := func() { 376 | for i := 0; i < numReqs; i++ { 377 | ch <- succeed(customCB) 378 | } 379 | } 380 | 381 | const numRoutines = 10 382 | for i := 0; i < numRoutines; i++ { 383 | go routine() 384 | } 385 | 386 | total := uint32(numReqs * numRoutines) 387 | for i := uint32(0); i < total; i++ { 388 | err := <-ch 389 | assert.Nil(t, err) 390 | } 391 | assert.Equal(t, Counts{total, total, 0, total, 0}, customCB.counts) 392 | } 393 | -------------------------------------------------------------------------------- /v2/redis_store.go: -------------------------------------------------------------------------------- 1 | package gobreaker 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | 7 | "github.com/go-redsync/redsync/v4" 8 | "github.com/go-redsync/redsync/v4/redis/goredis/v9" 9 | "github.com/redis/go-redis/v9" 10 | ) 11 | 12 | type RedisStore struct { 13 | ctx context.Context 14 | client 
*redis.Client 15 | rs *redsync.Redsync 16 | mutex map[string]*redsync.Mutex 17 | } 18 | 19 | func NewRedisStore(addr string) *RedisStore { 20 | client := redis.NewClient(&redis.Options{ 21 | Addr: addr, 22 | }) 23 | return &RedisStore{ 24 | ctx: context.Background(), 25 | client: client, 26 | rs: redsync.New(goredis.NewPool(client)), 27 | mutex: map[string]*redsync.Mutex{}, 28 | } 29 | } 30 | 31 | func (rs *RedisStore) Lock(name string) error { 32 | mutex, ok := rs.mutex[name] 33 | if ok { 34 | return mutex.Lock() 35 | } 36 | 37 | mutex = rs.rs.NewMutex(name, redsync.WithExpiry(mutexTimeout)) 38 | rs.mutex[name] = mutex 39 | return mutex.Lock() 40 | } 41 | 42 | func (rs *RedisStore) Unlock(name string) error { 43 | mutex, ok := rs.mutex[name] 44 | if ok { 45 | var err error 46 | ok, err = mutex.Unlock() 47 | if ok && err == nil { 48 | return nil 49 | } 50 | } 51 | return errors.New("unlock failed") 52 | } 53 | 54 | func (rs *RedisStore) GetData(name string) ([]byte, error) { 55 | return rs.client.Get(rs.ctx, name).Bytes() 56 | } 57 | 58 | func (rs *RedisStore) SetData(name string, data []byte) error { 59 | return rs.client.Set(rs.ctx, name, data, 0).Err() 60 | } 61 | 62 | func (rs *RedisStore) Close() { 63 | rs.client.Close() 64 | } 65 | --------------------------------------------------------------------------------