├── .gitignore ├── .travis.yml ├── LICENSE ├── changelog.md ├── conn ├── conn.go ├── conn_test.go └── policy.go ├── heartbeat ├── detector.go ├── hash_expiry_strategy.go ├── hash_expiry_strategy_test.go ├── heart.go ├── heartbeater.go ├── heartbeater_test.go ├── simple_detector.go ├── simple_detector_test.go ├── simple_heart.go ├── simple_heart_test.go ├── strategy.go └── strategy_test.go ├── pubsub ├── client.go ├── events.go ├── listener.go ├── pubsub_test.go └── state.go ├── pubsub2 ├── counters.go ├── emitter.go ├── emitter_test.go ├── event.go ├── event_test.go ├── fuzz_record_list │ └── fuzz.go ├── pumps.go ├── redis.go └── redis_test.go ├── queue ├── base_queue.go ├── base_queue_test.go ├── byte_queue.go ├── byte_queue_test.go ├── durable_queue.go ├── durable_queue_test.go ├── fifo_processor.go ├── fifo_processor_test.go ├── lifo_processor.go ├── lifo_processor_test.go ├── processor.go ├── processor_test.go ├── queue.go ├── scripts.go └── util.go ├── readme.md ├── test └── redis_suite.go └── worker ├── default_lifecycle.go ├── default_lifecycle_test.go ├── default_worker.go ├── default_worker_test.go ├── janitor.go ├── janitor_test.go ├── lifecycle.go ├── lifecycle_test.go ├── task.go ├── task_test.go ├── util.go ├── util_test.go └── worker.go /.gitignore: -------------------------------------------------------------------------------- 1 | #### joe made this: http://goel.io/joe 2 | 3 | #####=== Go ===##### 4 | 5 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 6 | *.o 7 | *.a 8 | *.so 9 | 10 | # Folders 11 | _obj 12 | _test 13 | 14 | # Architecture specific extensions/prefixes 15 | *.[568vq] 16 | [568vq].out 17 | 18 | *.cgo1.go 19 | *.cgo2.c 20 | _cgo_defun.c 21 | _cgo_gotypes.go 22 | _cgo_export.* 23 | 24 | _testmain.go 25 | 26 | *.exe 27 | *.test 28 | *.prof 29 | 30 | /.idea 31 | /fuzz_record_list-fuzz.zip 32 | /pubsub2/fuzz_record_list_output 33 | /*.txt 34 | /*.zip 35 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | services: 3 | - redis-server 4 | go: 5 | - "1.9" 6 | - "1.10" 7 | notifications: 8 | email: false 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 Beam Interactive, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/changelog.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | ## pubsub
4 |
5 | ### 2.4.2 (02-19-2016)
6 | * Extract a `worker.Worker` interface, and create a `worker.DefaultWorker`
7 | implementation.
8 |
9 | ### 2.4.1 (02-17-2016)
10 | * Fix a critical issue in the worker package which prevented janitors from
11 | being able to clean up dead workers.
12 |
13 | ### 2.4 (02-14-2016)
14 | * Implement the `worker` package.
15 |
16 | ### 2.3.1 (02-13-2016)
17 | * **Breaking Change**: Allow the specification of a timeout parameter to
18 | several (previously) infinitely-blocking methods.
19 |
20 | ### 2.3 (01-30-2016)
21 | * Implement the `queue` package.
22 |
23 | ### 2.2 (01-26-2016)
24 | * Implement the `heartbeat` package.
25 |
26 | ### 2.1 (01-23-2016)
27 |
28 | * **Breaking Change**: `ConnectionParam` has moved from the `pubsub` package to
29 | the `conn` package.
30 | * **Breaking Change**: `pubsub.New` no longer takes a `ConnectionParam`, rather
31 | it takes a `*redis.Pool` and a `conn.ReconnectPolicy`.
32 |
33 | ### 2.0 (25-08-2015) rc
34 |
35 | * **Breaking Change**: New() now takes a ConnectionParam value rather than a pointer.
36 | * **Breaking Change**: GetState() now returns a uint8 rather than a user-defined type, for greater compatibility with [fsm](https://github.com/mixer/fsm). _mumble mumble generics_
37 | * Fix potential data races on the internal subscription registry.
38 | * Fix potential data race resulting in subscription duplication during multiple reconnections.
39 | * Allow specification of connection timeouts (deadlines).
40 | * Allow specification of reconnection policies.
41 | * Cause subscriptions, unsubscriptions, and teardowns to happen more quickly.
42 | * Improve events system for increased flexibility.
43 | * Significantly improve conciseness and speed.
44 |
45 |
46 | ### 1.1 (25-08-2015)
47 |
48 | * **Breaking Change**: New() now takes a *ConnectionParam struct as its first argument.
49 | * Add password authentication options (from @janeczku).
50 | * Fix failing tests in Go 1.3.
51 | * Prevent panicking when tearing down a client which was not set up.
52 |
53 |
54 | ### 1.0 (07-04-2015)
55 |
56 | Initial release.
57 |
--------------------------------------------------------------------------------
/conn/conn.go:
--------------------------------------------------------------------------------
1 | package conn
2 |
3 | import (
4 |     "crypto/tls"
5 |     "net"
6 |     "time"
7 |
8 |     "github.com/gomodule/redigo/redis"
9 |     "github.com/mna/redisc"
10 | )
11 |
12 | // ConnectionParam denotes the parameters of the Redis connection.
13 | type ConnectionParam struct {
14 |     // Host:port
15 |     Address string
16 |     // Optional password. Defaults to no authentication.
17 |     Password string
18 |     // Policy to use for reconnections (defaults to
19 |     // LogReconnectPolicy with a base of 10 and factor of 1 ms)
20 |     Policy ReconnectPolicy
21 |     // Dial timeout for redis (defaults to no timeout)
22 |     Timeout time.Duration
23 |     // Whether or not to secure the connection with TLS
24 |     UseTLS bool
25 |     // Whether to use clustering (redisc)
26 |     UseCluster bool
27 | }
28 |
29 | // RedUtilPool is a common interface over redis.Pool and redisc.Cluster.
30 | type RedUtilPool interface {
31 |     Get() redis.Conn
32 | }
33 |
34 | // NewWithActiveLimit makes and returns a connection pool (or cluster, when
35 | // UseCluster is set) along with the reconnect policy. It sets some defaults
36 | // on the ConnectionParam object, such as the policy, which defaults to a
37 | // LogReconnectPolicy with a base of 10ms. A call to this function does not produce a connection.
38 | func NewWithActiveLimit(param ConnectionParam, maxIdle int, maxActive int) (RedUtilPool, ReconnectPolicy) {
39 |     if param.Policy == nil {
40 |         param.Policy = &LogReconnectPolicy{Base: 10, Factor: time.Millisecond}
41 |     }
42 |
43 |     options := make([]redis.DialOption, 0)
44 |     if param.UseTLS {
45 |         options = append(options, redis.DialUseTLS(param.UseTLS))
46 |     }
47 |
48 |     if param.Password != "" {
49 |         options = append(options, redis.DialPassword(param.Password))
50 |     }
51 |
52 |     if param.Timeout > 0 {
53 |         options = append(options, []redis.DialOption{
54 |             redis.DialConnectTimeout(param.Timeout),
55 |             redis.DialReadTimeout(param.Timeout),
56 |             redis.DialWriteTimeout(param.Timeout),
57 |         }...)
58 |     }
59 |
60 |     if param.UseCluster {
61 |         host, _, err := net.SplitHostPort(param.Address)
62 |         if err != nil {
63 |             panic(err)
64 |         }
65 |
66 |         options = append(options, redis.DialTLSConfig(&tls.Config{
67 |             InsecureSkipVerify: false,
68 |             ServerName:         host,
69 |         }))
70 |
71 |         return &redisc.Cluster{
72 |             StartupNodes: []string{param.Address},
73 |             DialOptions:  options,
74 |             CreatePool:   connectPool(maxIdle, maxActive),
75 |         }, param.Policy
76 |     }
77 |
78 |     return &redis.Pool{Dial: connect(param.Address, options...), MaxIdle: maxIdle, MaxActive: maxActive}, param.Policy
79 | }
80 |
81 | // New makes and returns a connection pool and reconnect policy with no limit
82 | // on active connections. It sets the same defaults on the ConnectionParam
83 | // object as NewWithActiveLimit. A call to this function does not
84 | // produce a connection.
85 | func New(param ConnectionParam, maxIdle int) (RedUtilPool, ReconnectPolicy) {
86 |     return NewWithActiveLimit(param, maxIdle, 0)
87 | }
88 |
89 | // connect is a higher-order function that returns a function that dials,
90 | // connects, and authenticates a Redis connection.
91 | //
92 | // It attempts to dial a TCP connection to the address specified, timing out if
93 | // no connection was able to be established within the given time-frame. If no
94 | // timeout was given, it will wait indefinitely.
95 | //
96 | // If a password was specified in the ConnectionParam object, then an `AUTH`
97 | // command (see: http://redis.io/commands/auth) is issued with that password.
98 | //
99 | // If an error is incurred either dialing the TCP connection, or sending the
100 | // `AUTH` command, then it will be returned immediately, and the client can be
101 | // considered useless.
102 | func connect(address string, options ...redis.DialOption) func() (redis.Conn, error) {
103 |     return func() (redis.Conn, error) {
104 |         return redis.Dial("tcp", address, options...)
105 |     }
106 | }
107 |
108 | func connectPool(maxIdle int, maxActive int) func(string, ...redis.DialOption) (*redis.Pool, error) {
109 |     return func(address string, options ...redis.DialOption) (*redis.Pool, error) {
110 |         return &redis.Pool{Dial: connect(address, options...), MaxIdle: maxIdle, MaxActive: maxActive}, nil
111 |     }
112 | }
113 |
--------------------------------------------------------------------------------
/conn/conn_test.go:
--------------------------------------------------------------------------------
1 | package conn
2 |
3 | import (
4 |     "testing"
5 |     "time"
6 |
7 |     "github.com/stretchr/testify/assert"
8 | )
9 |
10 | func TestValidConnections(t *testing.T) {
11 |     pool, _ := New(ConnectionParam{
12 |         Address: "127.0.0.1:6379",
13 |     }, 1)
14 |
15 |     cnx := pool.Get()
16 |     err := cnx.Err()
17 |
18 |     assert.NotNil(t, cnx)
19 |     assert.Nil(t, err)
20 | }
21 |
22 | func TestInvalidConnection(t *testing.T) {
23 |     pool, _ := New(ConnectionParam{
24 |         Address: "127.0.0.2:6379",
25 |         Timeout: time.Nanosecond,
26 |     }, 1)
27 |
28 |     err := pool.Get().Err()
29 |
30 |     assert.Equal(t, "dial tcp 127.0.0.2:6379: i/o timeout", err.Error())
31 | }
32 |
--------------------------------------------------------------------------------
/conn/policy.go:
--------------------------------------------------------------------------------
1 | package conn
2 |
3 | import (
4 |     "math"
5 |     "time"
6 | )
7 |
8 | // The ReconnectPolicy defines the pattern of delay times used
9 | // after a connection is lost, before attempting to reconnect.
10 | type ReconnectPolicy interface {
11 |     // Gets the next reconnect time, usually incrementing some
12 |     // counter so the next attempt is longer.
13 |     Next() time.Duration
14 |     // Resets the number of attempts.
15 |     Reset()
16 | }
17 |
18 | // Reconnect policy which waits a set period of time on each connect.
19 | type StaticReconnectPolicy struct {
20 |     // Delay each time.
21 |     Delay time.Duration
22 | }
23 |
24 | var _ ReconnectPolicy = new(StaticReconnectPolicy)
25 |
26 | func (s *StaticReconnectPolicy) Next() time.Duration {
27 |     return s.Delay
28 | }
29 |
30 | func (s *StaticReconnectPolicy) Reset() {}
31 |
32 | // Reconnect policy which increases the reconnect delay in a logarithmic
33 | // fashion. The equation used is delay = log(tries) / log(base) * Factor
34 | type LogReconnectPolicy struct {
35 |     // Base for the logarithm
36 |     Base float64
37 |     // Duration multiplier for the calculated value.
38 |     Factor time.Duration
39 |     tries float64
40 | }
41 |
42 | var _ ReconnectPolicy = new(LogReconnectPolicy)
43 |
44 | func (l *LogReconnectPolicy) Next() time.Duration {
45 |     l.tries += 1
46 |     return time.Duration(math.Log(l.tries) / math.Log(l.Base) * float64(l.Factor))
47 | }
48 |
49 | func (l *LogReconnectPolicy) Reset() {
50 |     l.tries = 0
51 | }
52 |
--------------------------------------------------------------------------------
/heartbeat/detector.go:
--------------------------------------------------------------------------------
1 | package heartbeat
2 |
3 | // Detector is an interface to a type responsible for collecting "dead" items in
4 | // Redis.
5 | type Detector interface {
6 |     // Detect returns all dead items according to the implementation
7 |     // definition of "dead". If an error is encountered while processing, it
8 |     // will be returned, and execution of this function will be halted.
9 |     Detect() (expired []string, err error)
10 |
11 |     // Purge removes a (supposedly dead) worker from the heartbeating data structure.
12 |     Purge(id string) error
13 | }
14 |
--------------------------------------------------------------------------------
/heartbeat/hash_expiry_strategy.go:
--------------------------------------------------------------------------------
1 | package heartbeat
2 |
3 | import (
4 |     "time"
5 |
6 |     "github.com/garyburd/redigo/redis"
7 | )
8 |
9 | const (
10 |     // DefaultTimeFormat is a time format string according to the time
11 |     // package and is used to marshal and unmarshal time.Time instances
12 |     // into ISO8601-compliant strings.
13 |     //
14 |     // See: https://www.ietf.org/rfc/rfc3339.txt for more details.
15 |     DefaultTimeFormat string = "2006-01-02T15:04:05"
16 | )
17 |
18 | // HashExpireyStrategy is an implementation of Strategy that stores items in a
19 | // hash.
20 | type HashExpireyStrategy struct {
21 |     MaxAge time.Duration
22 | }
23 |
24 | var _ Strategy = &HashExpireyStrategy{}
25 |
26 | // Touch implements the `func Touch` defined in the Strategy interface. It
27 | // assumes a HASH type is used in Redis to map the IDs of various Hearts to the
28 | // last time that they were updated.
29 | //
30 | // It uses the Heart's `Location` and `ID` fields respectively to determine
31 | // where to both place, and name the hash as well as the items within it.
32 | //
33 | // Times are marshalled using the `const DefaultTimeFormat` which stores times
34 | // in the ISO8601 format.
35 | func (s HashExpireyStrategy) Touch(location, ID string, pool *redis.Pool) error {
36 |     now := time.Now().UTC().Format(DefaultTimeFormat)
37 |
38 |     cnx := pool.Get()
39 |     defer cnx.Close()
40 |
41 |     if _, err := cnx.Do("HSET", location, ID, now); err != nil {
42 |         return err
43 |     }
44 |
45 |     return nil
46 | }
47 |
48 | // Purge implements the `func Purge` defined in the Strategy interface. It
49 | // assumes a HASH type is used in Redis to map the IDs of various Hearts,
50 | // and removes the record for the specified ID from the hash.
51 | func (s HashExpireyStrategy) Purge(location, ID string, pool *redis.Pool) error {
52 |     cnx := pool.Get()
53 |     defer cnx.Close()
54 |
55 |     if _, err := cnx.Do("HDEL", location, ID); err != nil {
56 |         return err
57 |     }
58 |
59 |     return nil
60 | }
61 |
62 | // Expired implements the `func Expired` defined on the Strategy interface. It
63 | // scans iteratively over the Heart's `location` field to look for items that
64 | // have expired. An item is marked as expired iff its last update happened
65 | // before the current time minus MaxAge.
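// For example, with MaxAge = 5*time.Second, an entry whose recorded tick is
// 12:00:00 UTC is reported as expired by any call made after 12:00:05 UTC.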
66 | func (s HashExpireyStrategy) Expired(location string,
67 |     pool *redis.Pool) (expired []string, err error) {
68 |
69 |     now := time.Now().UTC()
70 |
71 |     cnx := pool.Get()
72 |     defer cnx.Close()
73 |
74 |     reply, err := redis.StringMap(cnx.Do("HGETALL", location))
75 |     if err != nil {
76 |         return
77 |     }
78 |
79 |     for id, tick := range reply {
80 |         lastUpdate, err := time.Parse(DefaultTimeFormat, tick)
81 |
82 |         if err != nil {
83 |             continue
84 |         } else if lastUpdate.Add(s.MaxAge).Before(now) {
85 |             expired = append(expired, id)
86 |         }
87 |     }
88 |
89 |     return
90 | }
91 |
--------------------------------------------------------------------------------
/heartbeat/hash_expiry_strategy_test.go:
--------------------------------------------------------------------------------
1 | package heartbeat_test
2 |
3 | import (
4 |     "testing"
5 |     "time"
6 |
7 |     "github.com/garyburd/redigo/redis"
8 |     "github.com/mixer/redutil/conn"
9 |     "github.com/mixer/redutil/heartbeat"
10 |     "github.com/mixer/redutil/test"
11 |     "github.com/stretchr/testify/suite"
12 | )
13 |
14 | type HashExpireyStrategySuite struct {
15 |     *test.RedisSuite
16 | }
17 |
18 | func TestHashExpireyStrategySuite(t *testing.T) {
19 |     pool, _ := conn.New(conn.ConnectionParam{
20 |         Address: "127.0.0.1:6379",
21 |     }, 1)
22 |
23 |     suite.Run(t, &HashExpireyStrategySuite{test.NewSuite(pool)})
24 | }
25 |
26 | func (suite *HashExpireyStrategySuite) TestTouchAddsValues() {
27 |     s := &heartbeat.HashExpireyStrategy{MaxAge: time.Second}
28 |
29 |     s.Touch("foo", "bar", suite.Pool)
30 |
31 |     suite.WithRedis(func(cnx redis.Conn) {
32 |         exists, err := redis.Bool(cnx.Do("HEXISTS", "foo", "bar"))
33 |
34 |         suite.Assert().Nil(err, "HEXISTS: expected no error but got one")
35 |         suite.Assert().True(exists, "HEXISTS: expected to find foo:bar but didn't")
36 |     })
37 | }
38 |
39 | func (suite *HashExpireyStrategySuite) TestExpiredFindsDeadValues() {
40 |     s := &heartbeat.HashExpireyStrategy{MaxAge: time.Second}
41 |
42 |     suite.WithRedis(func(cnx redis.Conn) {
43 |         tick := time.Now().UTC().Add(-10 * time.Second).Format(heartbeat.DefaultTimeFormat)
44 |         cnx.Do("HSET", "foo", "bar", tick)
45 |
46 |         expired, err := s.Expired("foo", suite.Pool)
47 |
48 |         suite.Assert().Nil(err)
49 |         suite.Assert().Len(expired, 1)
50 |         suite.Assert().Equal(expired[0], "bar")
51 |     })
52 | }
53 |
54 | func (suite *HashExpireyStrategySuite) TestExpiredIgnoresUnreadableValues() {
55 |     s := &heartbeat.HashExpireyStrategy{MaxAge: time.Second}
56 |
57 |     suite.WithRedis(func(cnx redis.Conn) {
58 |         tick := time.Now().UTC().Add(-10 * time.Second).Format(heartbeat.DefaultTimeFormat)
59 |         cnx.Do("HSET", "foo", "bar", tick)
60 |         cnx.Do("HSET", "foo", "baz", "not a real time")
61 |
62 |         expired, err := s.Expired("foo", suite.Pool)
63 |
64 |         suite.Assert().Nil(err)
65 |         suite.Assert().Len(expired, 1)
66 |         suite.Assert().Equal(expired[0], "bar")
67 |     })
68 | }
69 |
--------------------------------------------------------------------------------
/heartbeat/heart.go:
--------------------------------------------------------------------------------
1 | package heartbeat
2 |
3 | // Heart is an interface representing a process which runs a particular task at a
4 | // given interval within its own goroutine.
5 | type Heart interface {
6 |     // Close stops the update-task from running, and frees up any owned
7 |     // resources.
8 |     Close()
9 |
10 |     // Errs returns a read-only `<-chan error` which contains all errors
11 |     // encountered while running the update task.
12 |     Errs() <-chan error
13 | }
14 |
--------------------------------------------------------------------------------
/heartbeat/heartbeater.go:
--------------------------------------------------------------------------------
1 | package heartbeat
2 |
3 | import (
4 |     "time"
5 |
6 |     "github.com/garyburd/redigo/redis"
7 | )
8 |
9 | // Heartbeater serves as a factory-type, in essence, to create both Hearts and
10 | // Detectors. It maintains information about the ID used to heartbeat with, the
11 | // Location in which to store ticks, the interval at which to update that
12 | // location, the pool from which Redis connections are drawn and recycled, and
13 | // the strategy to use to tick and recover items in Redis.
14 | type Heartbeater struct {
15 |     // ID is a unique identifier used by the Heart to tick with.
16 |     ID string
17 |
18 |     // Location is the absolute path of the keyspace in Redis in which to
19 |     // store all of the heartbeat updates.
20 |     Location string
21 |
22 |     // interval is the Interval at which to tell the Heart to tick itself.
23 |     interval time.Duration
24 |
25 |     // pool is the *redis.Pool that connections are derived from.
26 |     pool *redis.Pool
27 |
28 |     // Strategy is the strategy used both to tick the items in Redis and
29 |     // to determine which ones are still alive.
30 |     Strategy Strategy
31 | }
32 |
33 | // New allocates and returns a pointer to a new instance of a Heartbeater. It
34 | // takes in the id, location, interval and pool used to create
35 | // Hearts and Detectors.
36 | func New(id, location string, interval time.Duration, pool *redis.Pool) *Heartbeater {
37 |     h := &Heartbeater{
38 |         ID:       id,
39 |         Location: location,
40 |         interval: interval,
41 |         pool:     pool,
42 |     }
43 |     h.Strategy = HashExpireyStrategy{h.MaxAge()}
44 |
45 |     return h
46 | }
47 |
48 | // Interval returns the interval at which the heart should tick itself.
49 | func (h *Heartbeater) Interval() time.Duration {
50 |     return h.interval
51 | }
52 |
53 | // Heart creates and returns a new instance of the Heart type with the
54 | // parameters used by the Heartbeater for consistency.
55 | func (h *Heartbeater) Heart() Heart {
56 |     // The Heartbeater's Strategy is passed through for consistency.
57 |     return NewSimpleHeart(h.ID, h.Location, h.interval, h.pool, h.Strategy)
58 | }
59 |
60 | // Detector creates and returns a new instance of the Detector type with the
61 | // parameters used by the Heartbeater for consistency.
62 | func (h *Heartbeater) Detector() Detector {
63 |     return NewDetector(h.Location, h.pool, h.Strategy)
64 | }
65 |
66 | // MaxAge returns the maximum amount of time that can pass from the last tick of
67 | // an item before that item may be considered dead. By default, it is three
68 | // halves of the Interval() time; for intervals shorter than five seconds it is
69 | // the interval plus one second.
70 | func (h *Heartbeater) MaxAge() time.Duration {
71 |     if h.Interval() < 5*time.Second {
72 |         return h.Interval() + time.Second
73 |     }
74 |
75 |     return h.Interval() * 3 / 2
76 | }
77 |
78 | // SetStrategy changes the strategy used by all future Heart and Detector
79 | // instantiations.
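//
// A usage sketch (assuming an existing *redis.Pool named pool; the names used
// here are hypothetical):
//
//     hb := heartbeat.New("worker-1", "heartbeats", 10*time.Second, pool)
//     hb.SetStrategy(heartbeat.HashExpireyStrategy{MaxAge: hb.MaxAge()})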
80 | func (h *Heartbeater) SetStrategy(strategy Strategy) *Heartbeater { 81 | h.Strategy = strategy 82 | 83 | return h 84 | } 85 | -------------------------------------------------------------------------------- /heartbeat/heartbeater_test.go: -------------------------------------------------------------------------------- 1 | package heartbeat_test 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/garyburd/redigo/redis" 8 | "github.com/mixer/redutil/conn" 9 | "github.com/mixer/redutil/heartbeat" 10 | "github.com/mixer/redutil/test" 11 | "github.com/stretchr/testify/suite" 12 | ) 13 | 14 | type HeartbeaterSuite struct { 15 | *test.RedisSuite 16 | } 17 | 18 | type StrategyStub struct{} 19 | 20 | func (_ *StrategyStub) Touch(location, ID string, pool *redis.Pool) (err error) { 21 | return nil 22 | } 23 | 24 | func (_ *StrategyStub) Purge(location, ID string, pool *redis.Pool) (err error) { 25 | return nil 26 | } 27 | 28 | func (_ *StrategyStub) Expired(location string, pool *redis.Pool) (expired []string, err error) { 29 | return make([]string, 0), nil 30 | } 31 | 32 | func TestHeartbeaterSuite(t *testing.T) { 33 | pool, _ := conn.New(conn.ConnectionParam{ 34 | Address: "127.0.0.1:6379", 35 | }, 1) 36 | 37 | suite.Run(t, &HeartbeaterSuite{test.NewSuite(pool)}) 38 | } 39 | 40 | func (suite *HeartbeaterSuite) TestConstruction() { 41 | h := heartbeat.New("foo", "bar", time.Second, suite.Pool) 42 | 43 | suite.Assert().Equal(h.ID, "foo") 44 | suite.Assert().Equal(h.Location, "bar") 45 | suite.Assert().Equal(h.Interval(), time.Second) 46 | } 47 | 48 | func (suite *HeartbeaterSuite) TestMaxAgePadsSmallValues() { 49 | h := heartbeat.New("foo", "bar", 3*time.Second, suite.Pool) 50 | 51 | suite.Assert().Equal(h.MaxAge(), 4*time.Second) 52 | } 53 | 54 | func (suite *HeartbeaterSuite) TestMaxAgeScalesLargeValues() { 55 | h := heartbeat.New("foo", "bar", 10*time.Second, suite.Pool) 56 | 57 | suite.Assert().Equal(h.MaxAge(), 15*time.Second) 58 | } 59 | 60 | func (suite *HeartbeaterSuite) TestSetStrategy() { 61 | h := heartbeat.New("foo", "bar", 10*time.Second, suite.Pool) 62 | strategy := &StrategyStub{} 63 | 64 | h.SetStrategy(strategy) 65 | 66 | suite.Assert().Equal(h.Strategy, strategy) 67 | } 68 | 69 | func (suite *HeartbeaterSuite) TestHeartCreation() { 70 | h := heartbeat.New("foo", "bar", 10*time.Second, suite.Pool) 71 | 72 | heart, ok := h.Heart().(heartbeat.SimpleHeart) 73 | 74 | suite.Assert().True(ok, "heart: expected heart to be heartbeat.SimpleHeart") 75 | suite.Assert().Equal(heart.ID, "foo") 76 | suite.Assert().Equal(heart.Location, "bar") 77 | suite.Assert().Equal(heart.Interval, 10*time.Second) 78 | } 79 | 80 | func (suite *HeartbeaterSuite) TestDetectorCreation() { 81 | h := heartbeat.New("foo", "bar", 10*time.Second, suite.Pool) 82 | strategy := &StrategyStub{} 83 | h.SetStrategy(strategy) 84 | 85 | detector, ok := h.Detector().(heartbeat.SimpleDetector) 86 | 87 | suite.Assert().True(ok, "detector: expected detector to be heartbeat.SimpleDetector") 88 | suite.Assert().Equal(detector.Location(), "bar") 89 | suite.Assert().Equal(detector.Strategy(), strategy) 90 | } 91 | -------------------------------------------------------------------------------- /heartbeat/simple_detector.go: -------------------------------------------------------------------------------- 1 | package heartbeat 2 | 3 | import "github.com/garyburd/redigo/redis" 4 | 5 | // SimpleDetector is an implementation of the Detector interface which uses the 6 | // provided Strategy in order to determine which items 
may be considered dead. 7 | type SimpleDetector struct { 8 | location string 9 | pool *redis.Pool 10 | strategy Strategy 11 | } 12 | 13 | // NewDetector initializes and returns a new SimpleDetector instance with the 14 | // given parameters. 15 | func NewDetector(location string, pool *redis.Pool, strategy Strategy) Detector { 16 | return SimpleDetector{ 17 | location: location, 18 | pool: pool, 19 | strategy: strategy, 20 | } 21 | } 22 | 23 | // Location returns the keyspace in which the detector searches. 24 | func (d SimpleDetector) Location() string { return d.location } 25 | 26 | // Strategy returns the strategy that the detector uses to search in the 27 | // keyspace returned by Location(). 28 | func (d SimpleDetector) Strategy() Strategy { return d.strategy } 29 | 30 | // Detect implements the `func Detect` on the `type Detector interface`. This 31 | // implementation simply delegates into the provided Strategy. 32 | func (d SimpleDetector) Detect() (expired []string, err error) { 33 | return d.strategy.Expired(d.location, d.pool) 34 | } 35 | 36 | // Purge implements the `func Purge` on the `type Detector interface`. This 37 | // implementation simply delegates into the provided Strategy. 38 | func (d SimpleDetector) Purge(id string) (err error) { 39 | return d.strategy.Purge(d.location, id, d.pool) 40 | } 41 | -------------------------------------------------------------------------------- /heartbeat/simple_detector_test.go: -------------------------------------------------------------------------------- 1 | package heartbeat_test 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | "time" 7 | 8 | "github.com/mixer/redutil/conn" 9 | "github.com/mixer/redutil/heartbeat" 10 | "github.com/mixer/redutil/test" 11 | "github.com/stretchr/testify/suite" 12 | ) 13 | 14 | type SimpleDetectorSuite struct { 15 | *test.RedisSuite 16 | } 17 | 18 | func TestSimpleDetectorSuite(t *testing.T) { 19 | pool, _ := conn.New(conn.ConnectionParam{ 20 | Address: "127.0.0.1:6379", 21 | }, 1) 22 | 23 | suite.Run(t, &SimpleDetectorSuite{test.NewSuite(pool)}) 24 | } 25 | 26 | func (suite *SimpleDetectorSuite) TestConstruction() { 27 | d := heartbeat.NewDetector("foo", suite.Pool, heartbeat.HashExpireyStrategy{time.Second}) 28 | 29 | suite.Assert().IsType(heartbeat.SimpleDetector{}, d) 30 | } 31 | 32 | func (suite *SimpleDetectorSuite) TestDetectDelegatesToStrategy() { 33 | strategy := &TestStrategy{} 34 | strategy.On("Expired", "foo", suite.Pool).Return([]string{}, nil) 35 | 36 | d := heartbeat.NewDetector("foo", suite.Pool, strategy) 37 | d.Detect() 38 | 39 | strategy.AssertCalled(suite.T(), "Expired", "foo", suite.Pool) 40 | } 41 | 42 | func (suite *SimpleDetectorSuite) TestDetectPropogatesValues() { 43 | strategy := &TestStrategy{} 44 | strategy.On("Expired", "foo", suite.Pool).Return([]string{"foo", "bar"}, errors.New("baz")) 45 | 46 | d := heartbeat.NewDetector("foo", suite.Pool, strategy) 47 | expired, err := d.Detect() 48 | 49 | suite.Assert().Equal(expired, []string{"foo", "bar"}) 50 | suite.Assert().Equal(err.Error(), "baz") 51 | } 52 | 53 | func (suite *SimpleDetectorSuite) TestDetectPurgesData() { 54 | strategy := &TestStrategy{} 55 | strategy.On("Purge", "foo", "id1", suite.Pool).Return(nil).Once() 56 | strategy.On("Purge", "foo", "id2", suite.Pool).Return(errors.New("baz")).Once() 57 | 58 | d := heartbeat.NewDetector("foo", suite.Pool, strategy) 59 | suite.Assert().Nil(d.Purge("id1")) 60 | suite.Assert().Equal(d.Purge("id2").Error(), "baz") 61 | } 62 | 
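Taken together, the heartbeat types above form a small liveness system: a Heart ticks its ID into a hash, while a Detector (typically in a janitor process) reaps IDs whose ticks have gone stale. Below is a minimal sketch, assuming a Redis server on 127.0.0.1:6379 and hypothetical names ("worker-1", "heartbeats"); the pool is built directly with redigo here to match this package's import:

package main

import (
    "log"
    "time"

    "github.com/garyburd/redigo/redis"
    "github.com/mixer/redutil/heartbeat"
)

func main() {
    pool := &redis.Pool{
        MaxIdle: 1,
        Dial: func() (redis.Conn, error) {
            return redis.Dial("tcp", "127.0.0.1:6379")
        },
    }

    hb := heartbeat.New("worker-1", "heartbeats", 10*time.Second, pool)

    heart := hb.Heart() // starts ticking immediately, in its own goroutine
    defer heart.Close()

    detector := hb.Detector()
    for range time.Tick(hb.MaxAge()) {
        dead, err := detector.Detect()
        if err != nil {
            log.Println("detect:", err)
            continue
        }
        for _, id := range dead {
            log.Println("reaping dead worker:", id)
            detector.Purge(id)
        }
    }
}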
--------------------------------------------------------------------------------
/heartbeat/simple_heart.go:
--------------------------------------------------------------------------------
1 | package heartbeat
2 |
3 | import (
4 |     "time"
5 |
6 |     "github.com/garyburd/redigo/redis"
7 | )
8 |
9 | // SimpleHeart is an implementation of the `type Heart interface` which uses a
10 | // given ID, Location and Interval, to call the provided Strategy in order to
11 | // tick items in Redis.
12 | type SimpleHeart struct {
13 |     ID       string
14 |     Location string
15 |     Interval time.Duration
16 |     strategy Strategy
17 |
18 |     pool   *redis.Pool
19 |     closer chan struct{}
20 |     errs   chan error
21 | }
22 |
23 | // NewSimpleHeart allocates and returns a new instance of a SimpleHeart,
24 | // initialized with the given parameters.
25 | //
26 | // pool is the *redis.Pool from which the Heart will take connections.
27 | //
28 | // location is the location in the Redis keyspace wherein the heartbeat will be
29 | // stored. Similarly, `id` is the ID of the given Heart that will be touched in
30 | // that location.
31 | //
32 | // strategy is the Strategy that will be used to Touch() the keyspace in Redis
33 | // at the given interval.
34 | //
35 | // Both the closer and errs channel are initialized to new, empty channels using
36 | // the builtin `make()`.
37 | //
38 | // It begins ticking immediately.
39 | func NewSimpleHeart(id, location string, interval time.Duration, pool *redis.Pool,
40 |     strategy Strategy) SimpleHeart {
41 |
42 |     sh := SimpleHeart{
43 |         ID:       id,
44 |         Location: location,
45 |         Interval: interval,
46 |
47 |         pool: pool,
48 |
49 |         closer:   make(chan struct{}),
50 |         errs:     make(chan error, 1),
51 |         strategy: strategy,
52 |     }
53 |
54 |     go sh.heartbeat()
55 |     sh.touch()
56 |
57 |     return sh
58 | }
59 |
60 | // Close implements the `func Close` defined in the `type Heart interface`. It
61 | // stops the heartbeat process at the next tick.
62 | func (s SimpleHeart) Close() {
63 |     s.closer <- struct{}{}
64 | }
65 |
66 | // Errs implements the `func Errs` defined in the `type Heart interface`. It
67 | // returns a read-only channel of errors encountered during the heartbeat
68 | // operation.
69 | func (s SimpleHeart) Errs() <-chan error {
70 |     return s.errs
71 | }
72 |
73 | // touch calls .Touch() on the Heart's strategy and pushes any error that
74 | // occurred to the errs channel.
75 | func (s SimpleHeart) touch() {
76 |     if err := s.strategy.Touch(s.Location, s.ID, s.pool); err != nil {
77 |         s.errs <- err
78 |     }
79 | }
80 |
81 | // heartbeat is a function responsible for ticking the updater.
82 | //
83 | // It uses a `select` statement to either receive a tick from the time.Ticker
84 | // or a close signal. On each tick it calls `touch()`, which delegates to the
85 | // configured Strategy.
86 | //
87 | // If any error occurs during operation, it is propagated up to the
88 | // `s.errs` channel, accessible by the `Errs()` function.
89 | //
90 | // It runs in its own goroutine.
91 | func (s *SimpleHeart) heartbeat() { 92 | ticker := time.NewTicker(s.Interval) 93 | defer ticker.Stop() 94 | 95 | for { 96 | select { 97 | case <-ticker.C: 98 | s.touch() 99 | case <-s.closer: 100 | return 101 | } 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /heartbeat/simple_heart_test.go: -------------------------------------------------------------------------------- 1 | package heartbeat_test 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | "time" 7 | 8 | "github.com/mixer/redutil/conn" 9 | "github.com/mixer/redutil/heartbeat" 10 | "github.com/mixer/redutil/test" 11 | "github.com/stretchr/testify/suite" 12 | ) 13 | 14 | type SimpleHeartbeatSuite struct { 15 | *test.RedisSuite 16 | } 17 | 18 | func TestSimpleHeartbeatSuite(t *testing.T) { 19 | pool, _ := conn.New(conn.ConnectionParam{ 20 | Address: "127.0.0.1:6379", 21 | }, 1) 22 | 23 | suite.Run(t, &SimpleHeartbeatSuite{test.NewSuite(pool)}) 24 | } 25 | 26 | func (suite *SimpleHeartbeatSuite) TestConstruction() { 27 | h := heartbeat.NewSimpleHeart("bar", "foo", time.Second, suite.Pool, heartbeat.HashExpireyStrategy{time.Second}) 28 | defer h.Close() 29 | 30 | suite.Assert().IsType(heartbeat.SimpleHeart{}, h) 31 | suite.Assert().Equal(h.ID, "bar") 32 | suite.Assert().Equal(h.Location, "foo") 33 | suite.Assert().Equal(h.Interval, time.Second) 34 | } 35 | 36 | func (suite *SimpleHeartbeatSuite) TestStrategyIsCalledAtInitializationAndInterval() { 37 | strategy := &TestStrategy{} 38 | strategy.On("Touch", "foo", "bar", suite.Pool).Return(nil) 39 | 40 | h := heartbeat.NewSimpleHeart("bar", "foo", 50*time.Millisecond, suite.Pool, strategy) 41 | strategy.AssertNumberOfCalls(suite.T(), "Touch", 1) 42 | defer h.Close() 43 | 44 | time.Sleep(60 * time.Millisecond) 45 | 46 | strategy.AssertNumberOfCalls(suite.T(), "Touch", 2) 47 | } 48 | 49 | func (suite *SimpleHeartbeatSuite) TestStrategyPropogatesErrors() { 50 | strategy := &TestStrategy{} 51 | err := errors.New("some error") 52 | strategy.On("Touch", "foo", "bar", suite.Pool).Twice().Return(err) 53 | 54 | h := heartbeat.NewSimpleHeart("bar", "foo", 100*time.Millisecond, suite.Pool, strategy) 55 | defer h.Close() 56 | 57 | errs := h.Errs() 58 | suite.Assert().Len(errs, 1) 59 | suite.Assert().Equal(err, <-errs) 60 | 61 | time.Sleep(150 * time.Millisecond) 62 | 63 | suite.Assert().Equal(err, <-errs) 64 | suite.Assert().Len(errs, 0) 65 | } 66 | 67 | func (suite *SimpleHeartbeatSuite) TestCloseStopsCallingStrategy() { 68 | strategy := &TestStrategy{} 69 | strategy.On("Touch", "foo", "bar", suite.Pool).Return(nil) 70 | 71 | h := heartbeat.NewSimpleHeart("bar", "foo", 5*time.Millisecond, suite.Pool, strategy) 72 | h.Close() 73 | 74 | time.Sleep(10 * time.Millisecond) 75 | 76 | strategy.AssertNumberOfCalls(suite.T(), "Touch", 1) 77 | } 78 | -------------------------------------------------------------------------------- /heartbeat/strategy.go: -------------------------------------------------------------------------------- 1 | package heartbeat 2 | 3 | import "github.com/garyburd/redigo/redis" 4 | 5 | // Strategy is an interface to a type which is responsible for both ticking a 6 | // heart's "pulse," and detecting what items in a particular section of Redis 7 | // are to be considered dead. 8 | type Strategy interface { 9 | // Touch ticks the item at `location:ID` in Redis. 10 | Touch(location, ID string, pool *redis.Pool) (err error) 11 | 12 | // Removes an item at `location:ID` in Redis. 
13 |     Purge(location, ID string, pool *redis.Pool) (err error)
14 |
15 |     // Expired returns an array of strings that represent the expired IDs in
16 |     // a given keyspace as specified by the `location` parameter.
17 |     Expired(location string, pool *redis.Pool) (expired []string, err error)
18 | }
19 |
--------------------------------------------------------------------------------
/heartbeat/strategy_test.go:
--------------------------------------------------------------------------------
1 | package heartbeat_test
2 |
3 | import (
4 |     "github.com/garyburd/redigo/redis"
5 |     "github.com/stretchr/testify/mock"
6 | )
7 |
8 | type TestStrategy struct {
9 |     mock.Mock
10 | }
11 |
12 | func (s *TestStrategy) Touch(location, ID string, pool *redis.Pool) (err error) {
13 |     args := s.Called(location, ID, pool)
14 |     return args.Error(0)
15 | }
16 |
17 | func (s *TestStrategy) Purge(location, ID string, pool *redis.Pool) (err error) {
18 |     args := s.Called(location, ID, pool)
19 |     return args.Error(0)
20 | }
21 |
22 | func (s *TestStrategy) Expired(location string, pool *redis.Pool) (expired []string, err error) {
23 |     args := s.Called(location, pool)
24 |     return args.Get(0).([]string), args.Error(1)
25 | }
26 |
--------------------------------------------------------------------------------
/pubsub/client.go:
--------------------------------------------------------------------------------
1 | // The pubsub package provides a useful stable Redis pubsub connection.
2 | // After opening a connection, it allows you to subscribe to and receive
3 | // events even in the case of network failures - you don't have to deal
4 | // with that in your code!
5 | //
6 | // Basic usage, prints any messages it gets in the "foobar" channel:
7 | //
8 | //     client := pubsub.New(conn.New(conn.ConnectionParam{
9 | //         Address: "127.0.0.1:6379",
10 | //     }, 1))
11 | //     defer client.TearDown()
12 | //     go client.Connect()
13 | //
14 | //     listener := client.Listener(Channel, "foobar")
15 | //     for {
16 | //         fmt.Println(<-listener.Messages)
17 | //     }
18 | //
19 | // Events are emitted down the client's "Event" channel. If you wanted to
20 | // wait until the client was open (not necessary, but may be useful):
21 | //
22 | //     client := pubsub.New(pool, policy)
23 | //     go client.Connect()
24 | //     client.WaitFor(ConnectedEvent)
25 | //
26 | // You can also subscribe to patterns and unsubscribe, of course:
27 | //
28 | //     listener := client.Listener(Pattern, "foo:*:bar")
29 | //     doStuff()
30 | //     listener.Unsubscribe()
31 |
32 | package pubsub
33 |
34 | import (
35 |     "net"
36 |     "sync"
37 |     "time"
38 |
39 |     "github.com/garyburd/redigo/redis"
40 |     "github.com/mixer/fsm"
41 |     "github.com/mixer/redutil/conn"
42 | )
43 |
44 | // Used to denote the type of listener - channel or pattern.
45 | type ListenerType uint8
46 |
47 | const (
48 |     Channel ListenerType = iota
49 |     Pattern
50 | )
51 |
52 | // Tasks we send to the main pubsub thread to subscribe/unsubscribe.
53 | type task struct {
54 |     Action   action
55 |     Listener *Listener
56 |     // If true, the *action* (subscribed/unsubscribed) will be undertaken
57 |     // even if we think we might have already done it.
58 |     Force bool
59 | }
60 |
61 | // actions are "things" we can do in tasks
62 | type action uint8
63 |
64 | // List of actions we can use in tasks. Internal use only.
65 | const (
66 |     subscribeAction action = iota
67 |     unsubscribeAction
68 |     disruptAction
69 | )
70 |
71 | // The Client is responsible for maintaining a subscribed Redis client,
72 | // reconnecting and resubscribing if it drops.
73 | type Client struct {
74 |     eventEmitter
75 |     // The current state that the client is in.
76 |     state     *fsm.Machine
77 |     stateLock *sync.Mutex
78 |     // Connection we're subbed to.
79 |     pool *redis.Pool
80 |     // The subscription client we're currently using.
81 |     pubsub *redis.PubSubConn
82 |     // A list of events that we're subscribed to. If the connection closes,
83 |     // we'll reestablish it and resubscribe to events.
84 |     subscribed map[string][]*Listener
85 |     // Reconnection policy
86 |     policy conn.ReconnectPolicy
87 |     // Channel of sub/unsub tasks
88 |     tasks chan task
89 | }
90 |
91 | // New creates and returns a new pubsub client. Call Connect to activate it.
92 | func New(pool *redis.Pool, reconnectPolicy conn.ReconnectPolicy) *Client {
93 |     return &Client{
94 |         eventEmitter: newEventEmitter(),
95 |         state:        blueprint.Machine(),
96 |         stateLock:    new(sync.Mutex),
97 |         pool:         pool,
98 |         subscribed:   map[string][]*Listener{},
99 |         policy:       reconnectPolicy,
100 |         tasks:        make(chan task, 128),
101 |     }
102 | }
103 |
104 | // Convenience function to create a new listener for an event.
105 | func (c *Client) Listener(kind ListenerType, event string) *Listener {
106 |     listener := &Listener{
107 |         Type:      kind,
108 |         Event:     event,
109 |         Messages:  make(chan redis.Message),
110 |         PMessages: make(chan redis.PMessage),
111 |         Client:    c,
112 |     }
113 |     c.Subscribe(listener)
114 |
115 |     return listener
116 | }
117 |
118 | // GetState gets the current client state.
119 | func (c *Client) GetState() uint8 {
120 |     c.stateLock.Lock()
121 |     defer c.stateLock.Unlock()
122 |
123 |     return c.state.State()
124 | }
125 |
126 | // Sets the client state in a thread-safe manner.
127 | func (c *Client) setState(s uint8) error {
128 |     c.stateLock.Lock()
129 |     err := c.state.Goto(s)
130 |     c.stateLock.Unlock()
131 |
132 |     if err != nil {
133 |         return err
134 |     }
135 |
136 |     switch s {
137 |     case ConnectedState:
138 |         c.emit(ConnectedEvent, nil)
139 |     case DisconnectedState:
140 |         c.emit(DisconnectedEvent, nil)
141 |     case ClosingState:
142 |         c.emit(ClosingEvent, nil)
143 |     case ClosedState:
144 |         c.emit(ClosedEvent, nil)
145 |     }
146 |
147 |     return nil
148 | }
149 |
150 | // Tries to reconnect to pubsub, looping until we're able to do so
151 | // successfully. This must be called to activate the client.
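//
// For example (sketch, mirroring the package documentation above):
//
//     go client.Connect()
//     client.WaitFor(ConnectedEvent)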
152 | func (c *Client) Connect() {
153 |     for c.GetState() == DisconnectedState {
154 |         go c.resubscribe()
155 |         c.doConnection()
156 |         time.Sleep(c.policy.Next())
157 |     }
158 |
159 |     c.setState(ClosedState)
160 | }
161 |
162 | func (c *Client) doConnection() {
163 |     cnx := c.pool.Get()
164 |     defer cnx.Close()
165 |
166 |     err := cnx.Err()
167 |
168 |     if err != nil {
169 |         c.emit(ErrorEvent, err)
170 |         return
171 |     }
172 |
173 |     c.pubsub = &redis.PubSubConn{Conn: cnx}
174 |     c.policy.Reset()
175 |     c.setState(ConnectedState)
176 |
177 |     end := make(chan bool)
178 |     go func() {
179 |         for {
180 |             select {
181 |             case <-end:
182 |                 return
183 |             case t := <-c.tasks:
184 |                 c.workOnTask(t)
185 |             }
186 |         }
187 |     }()
188 |
189 | READ:
190 |     for c.GetState() == ConnectedState {
191 |         switch reply := c.pubsub.Receive().(type) {
192 |         case redis.Message:
193 |             go c.dispatchMessage(reply)
194 |         case redis.PMessage:
195 |             go c.dispatchPMessage(reply)
196 |         case redis.Subscription:
197 |             switch reply.Kind {
198 |             case "subscribe", "psubscribe":
199 |                 c.emit(SubscribeEvent, reply)
200 |             case "unsubscribe", "punsubscribe":
201 |                 c.emit(UnsubscribeEvent, reply)
202 |             }
203 |         case error:
204 |             if err, ok := reply.(net.Error); ok && err.Timeout() {
205 |                 // don't emit errors for timeouts
206 |             } else if c.GetState() != ConnectedState {
207 |                 // if we already closed, don't really care
208 |             } else {
209 |                 c.emit(ErrorEvent, reply)
210 |             }
211 |             break READ
212 |         }
213 |     }
214 |
215 |     end <- true
216 |     c.pubsub.Close()
217 |     c.setState(DisconnectedState)
218 | }
219 |
220 | // Takes a task and modifies Redis and the internal subscription registry.
221 | // This is done here (called from doConnection) since Go's maps are not thread safe.
222 | func (c *Client) workOnTask(t task) {
223 |     switch t.Action {
224 |     case subscribeAction:
225 |         event := t.Listener.Event
226 |         // Check to see if it already exists in the subscribed map.
227 |         // If not, then we need to create a new listener list and
228 |         // start listening on the pubsub client. Otherwise just
229 |         // append it.
230 |         if _, exists := c.subscribed[event]; !exists || t.Force {
231 |             c.subscribed[event] = []*Listener{t.Listener}
232 |             switch t.Listener.Type {
233 |             case Channel:
234 |                 c.pubsub.Subscribe(event)
235 |             case Pattern:
236 |                 c.pubsub.PSubscribe(event)
237 |             }
238 |         } else {
239 |             c.subscribed[event] = append(c.subscribed[event], t.Listener)
240 |         }
241 |     case unsubscribeAction:
242 |         event := t.Listener.Event
243 |         // Look for the listener in the subscribed list and,
244 |         // once found, remove it.
245 |         list := c.subscribed[event]
246 |         for i, l := range list {
247 |             if l == t.Listener {
248 |                 c.subscribed[event] = append(list[:i], list[i+1:]...)
249 |                 break
250 |             }
251 |         }
252 |
253 |         // If the list is now empty, we can unsubscribe on Redis and
254 |         // remove it from our registry.
255 |         if len(c.subscribed[event]) == 0 || t.Force {
256 |             switch t.Listener.Type {
257 |             case Channel:
258 |                 c.pubsub.Unsubscribe(event)
259 |             case Pattern:
260 |                 c.pubsub.PUnsubscribe(event)
261 |             }
262 |             delete(c.subscribed, event)
263 |         }
264 |     case disruptAction:
265 |         c.pubsub.Conn.Close()
266 |     default:
267 |         panic("unknown task")
268 |     }
269 | }
270 |
271 | // Takes in a Redis message and sends it out to any "listening" channels.
272 | func (c *Client) dispatchMessage(message redis.Message) {
273 |     if listeners, exists := c.subscribed[message.Channel]; exists {
274 |         for _, l := range listeners {
275 |             l.Messages <- message
276 |         }
277 |     }
278 | }
279 |
280 | // Takes in a Redis pmessage and sends it out to any "listening" channels.
281 | func (c *Client) dispatchPMessage(message redis.PMessage) {
282 |     if listeners, exists := c.subscribed[message.Pattern]; exists {
283 |         for _, l := range listeners {
284 |             l.PMessages <- message
285 |         }
286 |     }
287 | }
288 |
289 | // Resubscribes to all Redis events. Good to do after a disconnection.
290 | func (c *Client) resubscribe() {
291 |     // Swap so that if we reconnect before all tasks are done,
292 |     // we don't duplicate things.
293 |     subs := c.subscribed
294 |     c.subscribed = map[string][]*Listener{}
295 |
296 |     for _, events := range subs {
297 |         // We only need to subscribe on one event, so that the Redis
298 |         // connection gets registered. We don't care about the others.
299 |         c.tasks <- task{Listener: events[0], Action: subscribeAction, Force: true}
300 |     }
301 | }
302 |
303 | // Tears down the client - closes the connection and stops
304 | // listening for messages.
305 | func (c *Client) TearDown() {
306 |     if c.GetState() != ConnectedState {
307 |         return
308 |     }
309 |     c.setState(ClosingState)
310 |     c.tasks <- task{Action: disruptAction}
311 |
312 |     c.WaitFor(ClosedEvent)
313 | }
314 |
315 | // Subscribe subscribes to a Redis event. Messages are sent down the
316 | // listener's channel as they come in.
317 | func (c *Client) Subscribe(listener *Listener) {
318 |     listener.Active = true
319 |     c.tasks <- task{Listener: listener, Action: subscribeAction}
320 | }
321 |
322 | // Unsubscribe removes the listener from the list of subscribers. If it's the last one
323 | // listening to that Redis event, we unsubscribe entirely.
324 | func (c *Client) Unsubscribe(listener *Listener) {
325 |     listener.Active = false
326 |     c.tasks <- task{Listener: listener, Action: unsubscribeAction}
327 | }
328 |
--------------------------------------------------------------------------------
/pubsub/events.go:
--------------------------------------------------------------------------------
1 | package pubsub
2 |
3 | import (
4 |     "sync"
5 | )
6 |
7 | // The event is sent down the client's Events channel when something happens!
8 | type Event struct {
9 |     Type   EventType
10 |     Packet interface{}
11 | }
12 |
13 | // A function which handles an incoming event.
14 | type EventHandler func(Event)
15 |
16 | // Events that are sent down the "Events" channel.
17 | type EventType uint8
18 |
19 | const (
20 |     ConnectedEvent EventType = iota
21 |     DisconnectedEvent
22 |     ClosingEvent
23 |     ClosedEvent
24 |     MessageEvent
25 |     SubscribeEvent
26 |     UnsubscribeEvent
27 |     ErrorEvent
28 |     AnyEvent
29 | )
30 |
31 | // Simple implementation of a Node-like event emitter.
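//
// Handlers registered with On fire on every matching event, while handlers
// registered with Once fire a single time and are then discarded. A brief
// sketch of internal use (the handler body is hypothetical):
//
//     e := newEventEmitter()
//     e.Once(ConnectedEvent, func(ev Event) { log.Println("connected") })
//     e.emit(ConnectedEvent, nil)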
32 | type eventEmitter struct {
33 |     lock      *sync.Mutex
34 |     listeners map[EventType][]EventHandler
35 |     once      map[EventType][]EventHandler
36 | }
37 |
38 | func newEventEmitter() eventEmitter {
39 |     return eventEmitter{
40 |         lock:      new(sync.Mutex),
41 |         listeners: map[EventType][]EventHandler{},
42 |         once:      map[EventType][]EventHandler{},
43 |     }
44 | }
45 |
46 | func (e *eventEmitter) addHandlerToMap(ev EventType, h EventHandler, m map[EventType][]EventHandler) {
47 |     e.lock.Lock()
48 |     defer e.lock.Unlock()
49 |
50 |     if handlers, ok := m[ev]; ok {
51 |         m[ev] = append(handlers, h)
52 |     } else {
53 |         m[ev] = []EventHandler{h}
54 |     }
55 | }
56 |
57 | // Once adds a handler that's executed once when an event is emitted.
58 | func (e *eventEmitter) Once(ev EventType, h EventHandler) {
59 |     e.addHandlerToMap(ev, h, e.once)
60 | }
61 |
62 | // On adds a handler that's executed when an event happens.
63 | func (e *eventEmitter) On(ev EventType, h EventHandler) {
64 |     e.addHandlerToMap(ev, h, e.listeners)
65 | }
66 |
67 | // Creates a channel that gets written to when a new event comes in.
68 | func (e *eventEmitter) OnChannel(ev EventType) chan Event {
69 |     ch := make(chan Event, 1)
70 |     e.On(ev, func(e Event) {
71 |         ch <- e
72 |     })
73 |
74 |     return ch
75 | }
76 |
77 | // Triggers an event to be sent out to listeners.
78 | func (e *eventEmitter) emit(typ EventType, data interface{}) {
79 |     ev := Event{Type: typ, Packet: data}
80 |
81 |     e.lock.Lock()
82 |     lists := [][]EventHandler{}
83 |     if handlers, ok := e.listeners[ev.Type]; ok {
84 |         lists = append(lists, handlers)
85 |     }
86 |     if handlers, ok := e.once[ev.Type]; ok {
87 |         lists = append(lists, handlers)
88 |         delete(e.once, ev.Type)
89 |     }
90 |     e.lock.Unlock()
91 |
92 |     for _, list := range lists {
93 |         for _, handler := range list {
94 |             go handler(ev)
95 |         }
96 |     }
97 | }
98 |
99 | // Blocks until an event is received. Mainly for backwards-compatibility.
100 | func (e *eventEmitter) WaitFor(ev EventType) {
101 |     done := make(chan bool)
102 |
103 |     go func() {
104 |         e.Once(ev, func(e Event) {
105 |             done <- true
106 |         })
107 |     }()
108 |
109 |     <-done
110 | }
111 |
--------------------------------------------------------------------------------
/pubsub/listener.go:
--------------------------------------------------------------------------------
1 | package pubsub
2 |
3 | import "github.com/garyburd/redigo/redis"
4 |
5 | // The listener is used to keep track of events that the client is
6 | // listening to.
7 | type Listener struct {
8 |     // The event slug we're listening for.
9 |     Event string
10 |     // and its type
11 |     Type ListenerType
12 |     // The channel we send events down for plain subscriptions.
13 |     Messages chan redis.Message
14 |     // The channel we send events down for pattern subscriptions.
15 |     PMessages chan redis.PMessage
16 |     // The client it's attached to.
17 |     Client *Client
18 |     // Whether it's active. True by default, false after unsubscribing.
19 |     Active bool
20 | }
21 |
22 | // Unsubscribes the listener.
23 | func (l *Listener) Unsubscribe() {
24 |     l.Client.Unsubscribe(l)
25 | }
26 |
27 | // Resubscribes the listener.
28 | func (l *Listener) Resubscribe() { 29 | l.Client.Subscribe(l) 30 | } 31 | -------------------------------------------------------------------------------- /pubsub/pubsub_test.go: -------------------------------------------------------------------------------- 1 | package pubsub 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | "time" 7 | 8 | "github.com/garyburd/redigo/redis" 9 | "github.com/mixer/redutil/conn" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func publish(event string, data string) { 14 | c, _ := redis.Dial("tcp", "127.0.0.1:6379") 15 | defer c.Close() 16 | c.Do("publish", event, data) 17 | } 18 | 19 | func disrupt(client *Client) { 20 | client.tasks <- task{Action: disruptAction} 21 | client.WaitFor(ConnectedEvent) 22 | } 23 | 24 | func create(t *testing.T) *Client { 25 | client := New(conn.New(conn.ConnectionParam{ 26 | Address: "127.0.0.1:6379", 27 | Timeout: time.Second, 28 | }, 1)) 29 | go client.Connect() 30 | client.WaitFor(ConnectedEvent) 31 | 32 | return client 33 | } 34 | 35 | func TestBasic(t *testing.T) { 36 | client := create(t) 37 | defer client.TearDown() 38 | 39 | listener := client.Listener(Channel, "foobar") 40 | client.WaitFor(SubscribeEvent) 41 | publish("foobar", "heyo!") 42 | assert.Equal(t, "heyo!", string((<-listener.Messages).Data)) 43 | } 44 | 45 | func TestReconnects(t *testing.T) { 46 | client := create(t) 47 | defer client.TearDown() 48 | 49 | listener := client.Listener(Channel, "foobar") 50 | client.WaitFor(SubscribeEvent) 51 | publish("foobar", "heyo!") 52 | assert.Equal(t, "heyo!", string((<-listener.Messages).Data)) 53 | 54 | go func() { 55 | time.Sleep(10 * time.Millisecond) 56 | disrupt(client) 57 | }() 58 | 59 | client.WaitFor(SubscribeEvent) 60 | publish("foobar", "we're back!") 61 | assert.Equal(t, "we're back!", string((<-listener.Messages).Data)) 62 | } 63 | 64 | func TestIncreasesReconnectTimeAndResets(t *testing.T) { 65 | t.Skip("skipping flakey test") 66 | 67 | client := create(t) 68 | defer client.TearDown() 69 | client.Listener(Channel, "foobar") 70 | 71 | for i, prev := 0, -1; i < 5; i++ { 72 | assert.True(t, time.Duration(prev) < client.policy.Next()) 73 | } 74 | 75 | disrupt(client) 76 | 77 | assert.Equal(t, 0, int(client.policy.Next())) 78 | } 79 | 80 | func TestUnsubscribe(t *testing.T) { 81 | client := create(t) 82 | defer client.TearDown() 83 | 84 | listener := client.Listener(Channel, "foobar") 85 | client.WaitFor(SubscribeEvent) 86 | publish("foobar", "heyo!") 87 | assert.Equal(t, "heyo!", string((<-listener.Messages).Data)) 88 | 89 | // Unsubscribe, then publish and listen for a second to make sure 90 | // the event doesn't come in. 91 | listener.Unsubscribe() 92 | client.WaitFor(UnsubscribeEvent) 93 | publish("foobar", "heyo!") 94 | 95 | select { 96 | case packet := <-client.OnChannel(AnyEvent): 97 | assert.Fail(t, fmt.Sprintf("Got 'some' packet after unsubscribe: %#v", packet)) 98 | case <-time.After(time.Millisecond * 100): 99 | } 100 | 101 | disrupt(client) 102 | 103 | // Make sure we don't resubscribe after a disruption. 
104 | publish("foobar", "heyo!") 105 | select { 106 | case packet := <-client.OnChannel(AnyEvent): 107 | assert.Fail(t, fmt.Sprintf("Got 'some' packet after unsubscribe reconnect: %#v", packet)) 108 | case <-time.After(time.Millisecond * 100): 109 | } 110 | } 111 | 112 | func TestDoesNotReconnectAfterGracefulClose(t *testing.T) { 113 | client := create(t) 114 | defer client.TearDown() 115 | 116 | client.Listener(Channel, "foobar") 117 | 118 | reconnect := make(chan bool) 119 | go (func() { 120 | client.WaitFor(ConnectedEvent) 121 | reconnect <- true 122 | })() 123 | 124 | select { 125 | case <-reconnect: 126 | assert.Fail(t, "Client should not have reconnected.") 127 | case <-time.After(time.Millisecond * 100): 128 | } 129 | } 130 | 131 | func TestPatternConnects(t *testing.T) { 132 | client := create(t) 133 | defer client.TearDown() 134 | 135 | listener := client.Listener(Pattern, "foo:*:bar") 136 | client.WaitFor(SubscribeEvent) 137 | publish("foo:2:bar", "heyo!") 138 | assert.Equal(t, "heyo!", string((<-listener.PMessages).Data)) 139 | 140 | // Make sure we don't resubscribe after a disruption. 141 | listener.Unsubscribe() 142 | client.WaitFor(UnsubscribeEvent) 143 | publish("foo:2:bar", "oh no!") 144 | select { 145 | case packet := <-client.OnChannel(AnyEvent): 146 | assert.Fail(t, fmt.Sprintf("Got 'some' packet after unsubscribe: %#v", packet)) 147 | case <-time.After(time.Millisecond * 100): 148 | } 149 | } 150 | -------------------------------------------------------------------------------- /pubsub/state.go: -------------------------------------------------------------------------------- 1 | package pubsub 2 | 3 | import ( 4 | "github.com/mixer/fsm" 5 | ) 6 | 7 | const ( 8 | // Not currently connected to a server. 9 | DisconnectedState uint8 = iota 10 | // Connected to a server, but not yet linked as a pubsub client. 11 | ConnectedState 12 | // We're in the process of closing the client. 13 | ClosingState 14 | // We were connected, but closed gracefully. 
15 |     ClosedState
16 | )
17 |
18 | var blueprint *fsm.Blueprint
19 |
20 | func init() {
21 |     bp := fsm.New()
22 |     bp.Start(DisconnectedState)
23 |     bp.From(DisconnectedState).To(ConnectedState)
24 |     bp.From(DisconnectedState).To(ClosingState)
25 |     bp.From(ConnectedState).To(DisconnectedState)
26 |     bp.From(ConnectedState).To(ClosingState)
27 |     bp.From(ClosingState).To(ClosedState)
28 |
29 |     blueprint = bp
30 | }
31 |
--------------------------------------------------------------------------------
/pubsub2/counters.go:
--------------------------------------------------------------------------------
1 | package pubsub
2 |
3 | import (
4 |     "time"
5 |
6 |     "github.com/prometheus/client_golang/prometheus"
7 | )
8 |
9 | var (
10 |     PromSubscriptions = prometheus.NewGauge(prometheus.GaugeOpts{
11 |         Name: "redutil_pubsub_subscriptions",
12 |         Help: "Number of subscriptions held by Redutil",
13 |     })
14 |     PromReconnections = prometheus.NewGauge(prometheus.GaugeOpts{
15 |         Name: "redutil_pubsub_reconnections",
16 |         Help: "Number of times pubsub has reconnected",
17 |     })
18 |     PromSendLatency = prometheus.NewGauge(prometheus.GaugeOpts{
19 |         Name: "redutil_pubsub_send_latency",
20 |         Help: "Amount of time it last took to fire a pubsub event",
21 |     })
22 |     PromReconnectLatency = prometheus.NewGauge(prometheus.GaugeOpts{
23 |         Name: "redutil_pubsub_reconnect_latency",
24 |         Help: "Amount of time it last took to reconnect",
25 |     })
26 |     PromSubLatency = prometheus.NewGauge(prometheus.GaugeOpts{
27 |         Name: "redutil_pubsub_sub_latency",
28 |         Help: "Amount of time it last took to subscribe or unsubscribe",
29 |     })
30 | )
31 |
32 | func gaugeLatency(g prometheus.Gauge) (stop func()) {
33 |     start := time.Now()
34 |
35 |     return func() {
36 |         g.Set(float64(time.Since(start)) / float64(time.Millisecond))
37 |     }
38 | }
39 |
--------------------------------------------------------------------------------
/pubsub2/emitter.go:
--------------------------------------------------------------------------------
1 | // Package pubsub implements helpers to connect to and read events from
2 | // Redis in a reliable way.
3 | //
4 | // Subscriptions are handled on a Redis connection; multiple listeners can
5 | // be attached to single events, and subscription and unsubscription on
6 | // the connection are handled automatically. If the connection goes down,
7 | // a new connection will be established and events will be automatically
8 | // re-subscribed to. All methods are thread-safe.
9 | //
10 | // Sometimes you wish to subscribe to a simple, single event in Redis, but
11 | // more often than not you want to subscribe to an event on a resource.
12 | // Pubsub gives you an easy way to handle these kinds of subscriptions,
13 | // without the need for `fmt` on your part or any kind of reflection-based
14 | // formatting on ours. You create Event structures with typed fields
15 | // and can later inspect those fields when you receive an event from
16 | // Redis, in a reasonably fluid and very type-safe manner.
17 | //
18 | // Simple Event
19 | //
20 | //     pub := pubsub.New(cnx)
21 | //     pub.Subscribe(pubsub.NewEvent("foo"), pubsub.ListenerFunc(func(e pubsub.Event, b []byte) {
22 | //         // You're passed the subscribed event and any byte payload
23 | //         fmt.Printf("foo happened with payload %#v\n", b)
24 | //     }))
25 | //
26 | // Patterns and Inspections
27 | //
28 | //     pub := pubsub.New(cnx)
29 | //     event := pubsub.NewPattern().
30 | //         String("foo:").
31 | //         Star().As("id").
32 | //		String(":bar")
33 | //
34 | //	pub.Subscribe(event, pubsub.ListenerFunc(func(e pubsub.Event, b []byte) {
35 | //		// You can alias fields and look them up by their names. Pattern
36 | //		// events will have their "stars" filled in.
37 | //		id, _ := e.Find("id").Int()
38 | //		fmt.Printf("just got foo:%d:bar!\n", id)
39 | //	}))
40 | package pubsub
41 | 
42 | // The Listener contains a function which, when passed to an Emitter, is
43 | // invoked when an event occurs. It's invoked with the corresponding
44 | // subscribed Event and the event's payload.
45 | //
46 | // Note that if a listener is subscribed to multiple overlapping events such
47 | // as `foo:bar` and `foo:*`, the listener will be called multiple times.
48 | type Listener interface {
49 | 	// Handle is invoked when the event occurs, with
50 | 	// the byte payload it contains.
51 | 	Handle(e Event, b []byte)
52 | }
53 | 
54 | type listenerFunc struct{ handle func(e Event, b []byte) }
55 | 
56 | func (l listenerFunc) Handle(e Event, b []byte) { l.handle(e, b) }
57 | 
58 | // ListenerFunc creates a Listener which invokes the provided function.
59 | func ListenerFunc(fn func(e Event, b []byte)) Listener { return &listenerFunc{fn} }
60 | 
61 | // Emitter is the primary interface to interact with pubsub.
62 | type Emitter interface {
63 | 	// Subscribe registers that the provided listener wants to be notified
64 | 	// of the given Event. If the Listener is already subscribed to the
65 | 	// event, it will be added again and the listener will be invoked
66 | 	// multiple times when that event occurs.
67 | 	Subscribe(e EventBuilder, l Listener)
68 | 
69 | 	// Unsubscribe unregisters a listener from an event. If the listener
70 | 	// is not subscribed to the event, this will be a noop. Note that this
71 | 	// only unsubscribes *one* listener from the event; if it's subscribed
72 | 	// multiple times, it will need to unsubscribe multiple times.
73 | 	Unsubscribe(e EventBuilder, l Listener)
74 | 
75 | 	// Errs returns a channel of errors that occur asynchronously on
76 | 	// the Redis connection.
77 | 	Errs() <-chan error
78 | 
79 | 	// Close frees resources associated with the Emitter.
80 | 	Close()
81 | }
--------------------------------------------------------------------------------
/pubsub2/emitter_test.go:
--------------------------------------------------------------------------------
 1 | package pubsub
 2 | 
 3 | import (
 4 | 	"testing"
 5 | 
 6 | 	"github.com/stretchr/testify/assert"
 7 | )
 8 | 
 9 | func TestListenerFunc(t *testing.T) {
10 | 	called := false
11 | 	fn := func(e Event, b []byte) {
12 | 		assert.Equal(t, "foo", e.Channel())
13 | 		assert.Equal(t, []byte{1, 2, 3}, b)
14 | 		called = true
15 | 	}
16 | 
17 | 	listener := ListenerFunc(fn)
18 | 	listener.Handle(NewEvent().String("foo").ToEvent("foo", "foo"), []byte{1, 2, 3})
19 | 	assert.True(t, called)
20 | }
--------------------------------------------------------------------------------
/pubsub2/event.go:
--------------------------------------------------------------------------------
 1 | package pubsub
 2 | 
 3 | import (
 4 | 	"fmt"
 5 | 	"strconv"
 6 | 	"strings"
 7 | )
 8 | 
 9 | // EventType is used to distinguish between pattern and plain text events.
10 | type EventType int
11 | 
12 | const (
13 | 	// PlainEvent is the event type for events that are simply subscribed
14 | 	// to in Redis without pattern matching.
15 | 	PlainEvent EventType = iota
16 | 	// PatternEvent is the event type for events in Redis which will be
17 | 	// subscribed to using patterns.
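	// As an illustrative sketch: a builder such as
	// NewPattern("foo:").Star().String(":bar") produces the name "foo:*:bar",
	// which is subscribed to with PSUBSCRIBE rather than SUBSCRIBE.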
18 | 	PatternEvent
19 | )
20 | 
21 | type patternType uint8
22 | 
23 | const (
24 | 	patternPlain patternType = iota
25 | 	patternStar
26 | 	patternAlts
27 | 	patternPlaceholder
28 | )
29 | 
30 | // SubCommand returns the command issued to subscribe to the event in Redis.
31 | func (e EventType) SubCommand() string {
32 | 	switch e {
33 | 	case PlainEvent:
34 | 		return "SUBSCRIBE"
35 | 	case PatternEvent:
36 | 		return "PSUBSCRIBE"
37 | 	default:
38 | 		panic(fmt.Sprintf("redutil/pubsub: unknown event type %d", e))
39 | 	}
40 | }
41 | 
42 | // UnsubCommand returns the command issued
43 | // to unsubscribe from the event in Redis.
44 | func (e EventType) UnsubCommand() string {
45 | 	switch e {
46 | 	case PlainEvent:
47 | 		return "UNSUBSCRIBE"
48 | 	case PatternEvent:
49 | 		return "PUNSUBSCRIBE"
50 | 	default:
51 | 		panic(fmt.Sprintf("redutil/pubsub: unknown event type %d", e))
52 | 	}
53 | }
54 | 
55 | // Field is one component of an event name; fields are concatenated into the event names listened to over Redis.
56 | type Field struct {
57 | 	valid   bool
58 | 	alias   string
59 | 	value   string
60 | 	pattern patternType
61 | }
62 | 
63 | // IsZero returns true if the field is empty. A call to Event.Find() or
64 | // Event.Get() with a non-existent alias or index will return such a struct.
65 | func (f Field) IsZero() bool { return !f.valid }
66 | 
67 | // String returns the field value as a string.
68 | func (f Field) String() string { return f.value }
69 | 
70 | // Bytes returns the field value as a byte slice.
71 | func (f Field) Bytes() []byte { return []byte(f.value) }
72 | 
73 | // Int attempts to parse and return the field value as an integer.
74 | func (f Field) Int() (int, error) {
75 | 	x, err := strconv.ParseInt(f.value, 10, 32)
76 | 	return int(x), err
77 | }
78 | 
79 | // Uint64 attempts to parse and return the field value as a uint64.
80 | func (f Field) Uint64() (uint64, error) { return strconv.ParseUint(f.value, 10, 64) }
81 | 
82 | // Int64 attempts to parse and return the field value as an int64.
83 | func (f Field) Int64() (int64, error) { return strconv.ParseInt(f.value, 10, 64) }
84 | 
85 | // An EventBuilder is passed to an Emitter to manage which
86 | // events a Listener is subscribed to.
87 | type EventBuilder struct {
88 | 	fields []Field
89 | 	kind   EventType
90 | }
91 | 
92 | // As sets the alias of the last field in the event list. You may then call
93 | // Event.Find(alias) to look up the value of the field in the event.
94 | func (e EventBuilder) As(alias string) EventBuilder {
95 | 	e.fields[len(e.fields)-1].alias = alias
96 | 	return e
97 | }
98 | 
99 | // String creates a Field containing a string.
100 | func (e EventBuilder) String(str string) EventBuilder {
101 | 	e.fields = append(e.fields, Field{valid: true, value: str})
102 | 	return e
103 | }
104 | 
105 | // Int creates a Field containing an integer.
106 | func (e EventBuilder) Int(x int) EventBuilder {
107 | 	e.fields = append(e.fields, Field{valid: true, value: strconv.Itoa(x)})
108 | 	return e
109 | }
110 | 
111 | // Star creates a Field containing the Kleene star `*` for pattern subscription,
112 | // and chains it on to the EventBuilder.
113 | func (e EventBuilder) Star() EventBuilder {
114 | 	e.assertPattern()
115 | 	e.fields = append(e.fields, Field{valid: true, value: "*", pattern: patternStar})
116 | 	return e
117 | }
118 | 
119 | // Placeholder creates a field containing a `?` for a placeholder in Redis patterns,
120 | // and chains it on to the event.
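	// For instance (an illustrative pattern, not one this package ships):
	// NewPattern("user:").Placeholder() builds "user:?", which matches the
	// channel "user:1" or "user:a" but not "user:10", since `?` stands for
	// exactly one character.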
121 | func (e EventBuilder) Placeholder() EventBuilder {
122 | 	e.assertPattern()
123 | 	e.fields = append(e.fields, Field{valid: true, value: "?", pattern: patternPlaceholder})
124 | 	return e
125 | }
126 | 
127 | // Alternatives creates a field with the alts wrapped in brackets, to match
128 | // one of them in a Redis pattern, and chains it on to the event.
129 | func (e EventBuilder) Alternatives(alts string) EventBuilder {
130 | 	e.assertPattern()
131 | 	e.fields = append(e.fields, Field{
132 | 		valid:   true,
133 | 		value:   "[" + alts + "]",
134 | 		pattern: patternAlts,
135 | 	})
136 | 	return e
137 | }
138 | 
139 | // assertPattern panics if the event is not a Pattern type.
140 | func (e EventBuilder) assertPattern() {
141 | 	if e.kind != PatternEvent {
142 | 		panic("That operation is only valid on pattern events created with NewPattern()")
143 | 	}
144 | }
145 | 
146 | // Name returns the name of the event, formed by a concatenation of all the
147 | // event fields.
148 | func (e EventBuilder) Name() string {
149 | 	strs := make([]string, len(e.fields))
150 | 	for i, field := range e.fields {
151 | 		strs[i] = field.value
152 | 	}
153 | 
154 | 	return strings.Join(strs, "")
155 | }
156 | 
157 | // ToEvent converts an EventBuilder into an immutable event which appears
158 | // to have been sent down the provided channel and pattern. This is primarily
159 | // used internally but may also be useful for unit testing.
160 | func (e EventBuilder) ToEvent(channel, pattern string) Event {
161 | 	fields := make([]Field, len(e.fields))
162 | 	copy(fields, e.fields)
163 | 
164 | 	return Event{
165 | 		fields:  fields,
166 | 		kind:    e.kind,
167 | 		channel: channel,
168 | 		pattern: pattern,
169 | 	}
170 | }
171 | 
172 | func (e EventBuilder) concat(other EventBuilder) EventBuilder {
173 | 	e.fields = append(e.fields, other.fields...)
174 | 	return e
175 | }
176 | 
177 | func (e EventBuilder) slice(start, end int) EventBuilder {
178 | 	e.fields = e.fields[start:end]
179 | 	return e
180 | }
181 | 
182 | // applyFields attempts to convert the values from a string or byte slice into
183 | // a Field. It panics if a value is neither of those types.
184 | func applyFields(event EventBuilder, values []interface{}) EventBuilder {
185 | 	for _, v := range values {
186 | 		switch t := v.(type) {
187 | 		case string:
188 | 			event = event.String(t)
189 | 		case []byte:
190 | 			event = event.String(string(t))
191 | 		default:
192 | 			panic(fmt.Sprintf("Expected string or field when creating an event, got %T", v))
193 | 		}
194 | 	}
195 | 
196 | 	return event
197 | }
198 | 
199 | // NewEvent creates and returns a new event based off the series of fields.
200 | // This translates to a Redis SUBSCRIBE call.
201 | func NewEvent(fields ...interface{}) EventBuilder {
202 | 	return applyFields(EventBuilder{kind: PlainEvent}, fields)
203 | }
204 | 
205 | // NewPattern creates and returns a new event pattern off the series
206 | // of fields. This translates to a Redis PSUBSCRIBE call.
207 | func NewPattern(fields ...interface{}) EventBuilder {
208 | 	return applyFields(EventBuilder{kind: PatternEvent}, fields)
209 | }
210 | 
211 | // An Event is passed down to a Listener's Handle function.
212 | type Event struct {
213 | 	channel, pattern string
214 | 	fields           []Field
215 | 	kind             EventType
216 | }
217 | 
218 | // Type returns the type of the event (either a PlainEvent or a PatternEvent).
219 | func (e Event) Type() EventType { return e.kind }
220 | 
221 | // Channel returns the channel that the event was sent down. For plain
222 | // events, this is equivalent to the event name. For pattern events, this is
223 | // the fulfilled name of the event.
224 | func (e Event) Channel() string { return e.channel }
225 | 
226 | // Pattern returns the pattern which was originally used to subscribe to
227 | // this event. For plain events, this will simply be the event name.
228 | func (e Event) Pattern() string { return e.pattern }
229 | 
230 | // Len returns the number of fields contained in the event.
231 | func (e Event) Len() int { return len(e.fields) }
232 | 
233 | // Get returns the value of a field at index `i` within the event. If the
234 | // field does not exist, an empty struct will be returned.
235 | func (e Event) Get(i int) Field {
236 | 	if len(e.fields) <= i {
237 | 		return Field{valid: false}
238 | 	}
239 | 
240 | 	return e.fields[i]
241 | }
242 | 
243 | // Find looks up a field value by its alias. This is most useful in pattern
244 | // subscriptions, where you might use Find to look up a parameterized property.
245 | // If the alias does not exist, an empty struct will be returned.
246 | func (e Event) Find(alias string) Field {
247 | 	for _, field := range e.fields {
248 | 		if field.alias == alias {
249 | 			return field
250 | 		}
251 | 	}
252 | 
253 | 	return Field{valid: false}
254 | }
255 | 
256 | // matchPatternAgainst matches the special characters of an event against a
257 | // specific channel, returning a new event with the ambiguities (*, ?, [xy])
258 | // resolved. It returns a zero event if it cannot match the string.
259 | //
260 | // This is partly based on Redis' own matching from:
261 | // https://github.com/antirez/redis/blob/unstable/src/util.c#L47
262 | func matchPatternAgainst(ev EventBuilder, channel string) (EventBuilder, bool) {
263 | 	if ev.kind != PatternEvent {
264 | 		panic("pattern matching is only valid against pattern events")
265 | 	}
266 | 
267 | 	out := EventBuilder{
268 | 		fields: make([]Field, 0, len(ev.fields)),
269 | 		kind:   PatternEvent,
270 | 	}
271 | 
272 | 	pos := 0
273 | 	chlen := len(channel)
274 | 	num := len(ev.fields)
275 | 	for i := 0; i < num; i++ {
276 | 		field := ev.fields[i]
277 | 		vallen := len(field.value)
278 | 
279 | 		// Fail if the field is an alternatives pattern and the current
280 | 		// character is not one of the alternatives.
281 | 		if field.pattern == patternAlts && pos < chlen &&
282 | 			strings.IndexByte(field.value[1:vallen-1], channel[pos]) == -1 {
283 | 			return EventBuilder{}, false
284 | 		}
285 | 
286 | 		// Swap out placeholders or alts for the next character, if there is one.
287 | 		if field.pattern == patternPlaceholder || field.pattern == patternAlts {
288 | 			if pos == chlen {
289 | 				return EventBuilder{}, false
290 | 			}
291 | 			out = out.String(string(channel[pos]))
292 | 			pos++
293 | 			continue
294 | 		}
295 | 
296 | 		// Eat stars, yum
297 | 		if field.pattern == patternStar {
298 | 			// If this is the last component, eat the rest of the channel.
299 | 			if i == num-1 {
300 | 				out = out.String(channel[pos:])
301 | 				return out, true
302 | 			}
303 | 
304 | 			// Start eating up the channel until we get to a point where
305 | 			// we can match the rest. Then add what we ate as a string --
306 | 			// that's what the star contained -- and concat everything
307 | 			// else on the end.
308 | 			tail := ev.slice(i+1, num)
309 | 			for end := pos; end < chlen; end++ {
310 | 				if tail, ok := matchPatternAgainst(tail, channel[end:]); ok {
311 | 					out = out.String(channel[pos:end])
312 | 					return out.slice(0, i+1).concat(tail), true
313 | 				}
314 | 			}
315 | 
316 | 			// Can't find anything? This is invalid.
317 | 			return EventBuilder{}, false
318 | 		}
319 | 
320 | 		// Otherwise it's a plain text match.
Make sure it actually matches, 321 | // then add it on. 322 | if vallen+pos > chlen || channel[pos:pos+vallen] != field.value { 323 | return EventBuilder{}, false 324 | } 325 | 326 | out.fields = append(out.fields, field) 327 | pos += vallen 328 | } 329 | 330 | return out, true 331 | } 332 | -------------------------------------------------------------------------------- /pubsub2/event_test.go: -------------------------------------------------------------------------------- 1 | package pubsub 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "os" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestEventTypes(t *testing.T) { 13 | assert.Equal(t, "SUBSCRIBE", PlainEvent.SubCommand()) 14 | assert.Equal(t, "UNSUBSCRIBE", PlainEvent.UnsubCommand()) 15 | assert.Equal(t, "PSUBSCRIBE", PatternEvent.SubCommand()) 16 | assert.Equal(t, "PUNSUBSCRIBE", PatternEvent.UnsubCommand()) 17 | } 18 | 19 | func TestEventBuildsString(t *testing.T) { 20 | e := NewEvent("foo") 21 | assert.Equal(t, PlainEvent, e.kind) 22 | assert.Equal(t, e.Name(), "foo") 23 | } 24 | 25 | func TestEventBuildsPattern(t *testing.T) { 26 | e := NewPattern("foo") 27 | assert.Equal(t, PatternEvent, e.kind) 28 | assert.Equal(t, e.Name(), "foo") 29 | } 30 | 31 | func TestEventBuildsMultipart(t *testing.T) { 32 | e := NewEvent("prefix:").String("foo:").Int(42) 33 | assert.Equal(t, "prefix:foo:42", e.Name()) 34 | 35 | b := e.ToEvent("prefix:foo:42", "prefix:foo:42") 36 | assert.Equal(t, 3, b.Len()) 37 | 38 | assert.Equal(t, "prefix:", b.Get(0).String()) 39 | id, _ := b.Get(2).Int() 40 | assert.Equal(t, "foo:", b.Get(1).String()) 41 | assert.Equal(t, 42, id) 42 | assert.Equal(t, "prefix:foo:42", b.Channel()) 43 | assert.Equal(t, "prefix:foo:42", b.Pattern()) 44 | } 45 | 46 | func TestEventReturnsZeroOnDNE(t *testing.T) { 47 | assert.True(t, NewEvent("foo").ToEvent("", "").Get(1).IsZero()) 48 | assert.False(t, NewEvent("foo").ToEvent("", "").Get(0).IsZero()) 49 | assert.True(t, NewEvent("foo").Int(1).As("bar").ToEvent("", "").Find("bleh").IsZero()) 50 | assert.False(t, NewEvent("foo").Int(1).As("bar").ToEvent("", "").Find("bar").IsZero()) 51 | } 52 | 53 | func TestEventMatchesPattern(t *testing.T) { 54 | tt := []struct { 55 | isMatch bool 56 | event EventBuilder 57 | channel string 58 | }{ 59 | {true, NewPattern("foo"), "foo"}, 60 | {false, NewPattern("foo"), "bar"}, 61 | {false, NewPattern("fooo"), "foo"}, 62 | {false, NewPattern("foo"), "fooo"}, 63 | 64 | {true, NewPattern("foo").Star(), "foo"}, 65 | {true, NewPattern("foo").Star(), "fooasdf"}, 66 | {true, NewPattern("foo").Star().String("bar"), "foo42bar"}, 67 | {false, NewPattern("foo").Star().String("nar"), "foo42bar"}, 68 | {true, NewPattern("foo").Star().String("bar").Star(), "foo42bar"}, 69 | {true, NewPattern("foo").Star().String("bar").Star(), "foo42bar42"}, 70 | {false, NewPattern("foo").Star().String("baz").Star(), "foo42bar42"}, 71 | 72 | {false, NewPattern("foo").Alternatives("123"), "foo6"}, 73 | {true, NewPattern("foo").Alternatives("123"), "foo2"}, 74 | } 75 | 76 | for _, test := range tt { 77 | actual, _ := matchPatternAgainst(test.event, test.channel) 78 | matches := test.channel == actual.Name() 79 | if test.isMatch { 80 | assert.True(t, matches, fmt.Sprintf("%s ∉ %s", test.channel, test.event.Name())) 81 | } else { 82 | assert.False(t, matches, fmt.Sprintf("%s ∈ %s", test.channel, test.event.Name())) 83 | } 84 | } 85 | } 86 | 87 | func TestFuzz(t *testing.T) { 88 | if os.Getenv("FUZZ") == "" { 89 | return 90 | } 91 | 92 | fields := []struct 
{ 93 | Transform func(e EventBuilder) EventBuilder 94 | Matching string 95 | }{ 96 | {func(e EventBuilder) EventBuilder { return e.Star() }, "adsf"}, 97 | {func(e EventBuilder) EventBuilder { return e.String("foo") }, "foo"}, 98 | {func(e EventBuilder) EventBuilder { return e.String("bar") }, "bar"}, 99 | {func(e EventBuilder) EventBuilder { return e.Alternatives("123") }, "2"}, 100 | } 101 | // Transition matrix for fields, by index. Given an [x, y], field[y] has 102 | // a matrix[x][y] chance of transitioning into field[x] next 103 | transitions := [][]float64{ 104 | {0, 0.34, 0.34, 0.32}, 105 | {0.3, 0.2, 0.3, 0.2}, 106 | {0.3, 0.3, 0.2, 0.2}, 107 | {0.25, 0.25, 0.25, 0.25}, 108 | } 109 | 110 | fmt.Println("") 111 | 112 | for k := 0; true; k++ { 113 | event := NewPattern() 114 | matching := "" 115 | 116 | // 1. build the event 117 | for i := rand.Intn(len(fields)); len(event.fields) == 0 || rand.Float64() < 0.8; { 118 | event = fields[i].Transform(event) 119 | matching += fields[i].Matching 120 | 121 | x := rand.Float64() 122 | sum := float64(0) 123 | for idx, p := range transitions[i] { 124 | sum += p 125 | if x < sum { 126 | i = idx 127 | break 128 | } 129 | } 130 | } 131 | 132 | actual, _ := matchPatternAgainst(event, matching) 133 | if matching != actual.Name() { 134 | panic(fmt.Sprintf("%s ∉ %s", matching, event.Name())) 135 | } 136 | 137 | if k%100000 == 0 { 138 | fmt.Printf("\033[2K\r%d tests run -- %s ∈ %s", k, matching, event.Name()) 139 | } 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /pubsub2/fuzz_record_list/fuzz.go: -------------------------------------------------------------------------------- 1 | package fuzz 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/mixer/redutil/pubsub2" 7 | ) 8 | 9 | var ( 10 | listeners = []pubsub.Listener{} 11 | event = pubsub.NewEvent() 12 | ) 13 | 14 | func init() { 15 | for i := 0; i <= 0xFF; i++ { 16 | listeners = append(listeners, pubsub.ListenerFunc(func(_ pubsub.Event, _ []byte) {})) 17 | } 18 | } 19 | 20 | // Fuzz is the main function for go-fuzz: https://github.com/dvyukov/go-fuzz. 21 | // It adds and remove users according to the sequence of bytes and ensures 22 | // state is consistent at the end. 23 | func Fuzz(data []byte) int { 24 | list := pubsub.NewRecordList() 25 | isAdded := map[pubsub.Listener]bool{} 26 | for i := 0; i < len(data); i++ { 27 | listener := listeners[data[i]] 28 | if isAdded[listener] { 29 | list.Remove(event, listener) 30 | isAdded[listener] = false 31 | } else { 32 | list.Add(event, listener) 33 | isAdded[listener] = true 34 | } 35 | } 36 | 37 | result := list.ListenersFor(event) 38 | for i, l := range listeners { 39 | exists := false 40 | for _, l2 := range result { 41 | if l2 == l { 42 | exists = true 43 | break 44 | } 45 | } 46 | 47 | if exists != isAdded[l] { 48 | panic(fmt.Sprintf("expected listener %d exists=%+v, but it was not", i, isAdded[l])) 49 | } 50 | } 51 | 52 | return 0 53 | } 54 | -------------------------------------------------------------------------------- /pubsub2/pumps.go: -------------------------------------------------------------------------------- 1 | package pubsub 2 | 3 | import ( 4 | "net" 5 | 6 | "github.com/garyburd/redigo/redis" 7 | ) 8 | 9 | // readPump tries to read out data from the connection, sending it down 10 | // the data channel, until it's closed. It's a conversion from 11 | // synchronous to channel-based code. 
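	// A rough consumption sketch (illustrative only; the real wiring lives
	// in Pubsub.work in redis.go):
	//
	//	pump := newReadPump(cnx)
	//	go pump.Work()
	//	for {
	//		select {
	//		case msg := <-pump.Data():
	//			_ = msg // a redis.Message, redis.PMessage, etc.
	//		case err := <-pump.Errs():
	//			_ = err // tear down and reconnect
	//		}
	//	}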
12 | type readPump struct { 13 | cnx redis.Conn 14 | 15 | data chan interface{} 16 | errs chan error 17 | closer chan struct{} 18 | } 19 | 20 | // newReadPump creates a new pump that operates on the single Redis connection. 21 | func newReadPump(cnx redis.Conn) *readPump { 22 | return &readPump{ 23 | cnx: cnx, 24 | data: make(chan interface{}), 25 | errs: make(chan error), 26 | closer: make(chan struct{}), 27 | } 28 | } 29 | 30 | // Work starts reading from the connection and blocks until it is closed. 31 | func (r *readPump) Work() { 32 | cnx := redis.PubSubConn{Conn: r.cnx} 33 | defer close(r.closer) 34 | 35 | for { 36 | msg := cnx.Receive() 37 | 38 | if err, isErr := msg.(error); isErr && shouldNotifyUser(err) { 39 | select { 40 | case r.errs <- err: 41 | case <-r.closer: 42 | return 43 | } 44 | } else if !isErr { 45 | select { 46 | case r.data <- msg: 47 | case <-r.closer: 48 | return 49 | } 50 | } 51 | } 52 | } 53 | 54 | // Errs returns a channel of errors from the connection 55 | func (r *readPump) Errs() <-chan error { return r.errs } 56 | 57 | // Data returns a channel of pubsub data from the connection 58 | func (r *readPump) Data() <-chan interface{} { return r.data } 59 | 60 | // Close tears down the connection 61 | func (r *readPump) Close() { r.closer <- struct{}{} } 62 | 63 | type command struct { 64 | command string 65 | channel string 66 | written chan<- struct{} 67 | } 68 | 69 | // writePump tries to write data sent over the channel to a connection. 70 | type writePump struct { 71 | cnx redis.Conn 72 | 73 | errs chan error 74 | data chan command 75 | closer chan struct{} 76 | } 77 | 78 | // newWritePump creates a new pump that operates on the single Redis connection. 79 | func newWritePump(cnx redis.Conn) *writePump { 80 | return &writePump{ 81 | cnx: cnx, 82 | data: make(chan command), 83 | errs: make(chan error), 84 | closer: make(chan struct{}), 85 | } 86 | } 87 | 88 | // Work starts writing to the connection and blocks until it is closed. 89 | func (r *writePump) Work() { 90 | defer close(r.closer) 91 | 92 | for { 93 | select { 94 | case data := <-r.data: 95 | r.cnx.Send(data.command, data.channel) 96 | if err := r.cnx.Flush(); err != nil { 97 | select { 98 | case r.errs <- err: 99 | case <-r.closer: 100 | return 101 | } 102 | } 103 | case <-r.closer: 104 | return 105 | } 106 | } 107 | } 108 | 109 | // Errs returns a channel of errors from the connection 110 | func (r *writePump) Errs() <-chan error { return r.errs } 111 | 112 | // Data returns a channel of pubsub data to be written to the connection 113 | func (r *writePump) Data() chan<- command { return r.data } 114 | 115 | // Close tears down the connection 116 | func (r *writePump) Close() { r.closer <- struct{}{} } 117 | 118 | // shouldNotifyUser return true if the user should be notified of the 119 | // given error; if it's not a temporary network error or a timeout. 120 | func shouldNotifyUser(err error) bool { 121 | if nerr, ok := err.(net.Error); ok && (nerr.Timeout() || nerr.Temporary()) { 122 | return false 123 | } 124 | 125 | return true 126 | } 127 | -------------------------------------------------------------------------------- /pubsub2/redis.go: -------------------------------------------------------------------------------- 1 | package pubsub 2 | 3 | import ( 4 | "sync" 5 | "sync/atomic" 6 | "unsafe" 7 | 8 | "github.com/garyburd/redigo/redis" 9 | ) 10 | 11 | // compactionThreshold defines how many inactiveItems / total length 12 | // have to be in the record before we'll do a full cleanup sweep. 
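// For example, with a backing list of 8 slots of which 3 are already
// inactive, removing one more listener makes (3+1)/8 = 0.5; that meets the
// threshold, so Remove rebuilds the list rather than just nil-ing a slot.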
13 | const compactionThreshold = 0.5
14 | 
15 | type record struct {
16 | 	name          string
17 | 	ev            EventBuilder
18 | 	list          unsafe.Pointer // to a []Listener
19 | 	inactiveItems int
20 | }
21 | 
22 | // Emit invokes all attached listeners with the provided event.
23 | func (r *record) Emit(ev Event, b []byte) {
24 | 	// This looks absolutely _horrible_. It'd be better in C. Anyway, the idea
25 | 	// is that we have a list of unsafe pointers, which is stored in the list
26 | 	// which is an atomic pointer itself. We can swap the pointer held at each
27 | 	// address in the list, so _that_ address needs to be loaded atomically.
28 | 	// This means we need to actually load the data from the address at
29 | 	// offset X from the start of the array underlying the slice.
30 | 	//
31 | 	// The cost of the atomic loading makes this 5-10% slower, but allows us to
32 | 	// atomically insert and remove listeners in many cases.
33 | 
34 | 	list := r.getUnsafeList()
35 | 	for i := 0; i < len(list); i++ {
36 | 		addr := uintptr(unsafe.Pointer(&list[i]))
37 | 		value := atomic.LoadPointer(*(**unsafe.Pointer)(unsafe.Pointer(&addr)))
38 | 		if value != nil {
39 | 			(*(*Listener)(value)).Handle(ev, b)
40 | 		}
41 | 	}
42 | }
43 | 
44 | func (r *record) setList(l []unsafe.Pointer) { atomic.StorePointer(&r.list, (unsafe.Pointer)(&l)) }
45 | 
46 | func (r *record) storeAtIndex(index int, listener Listener) {
47 | 	// see Emit for details about this code.
48 | 	list := r.getUnsafeList()
49 | 	addr := uintptr(unsafe.Pointer(&list[index]))
50 | 
51 | 	var storedPtr unsafe.Pointer
52 | 	if listener != nil {
53 | 		storedPtr = unsafe.Pointer(&listener)
54 | 	}
55 | 
56 | 	atomic.StorePointer(*(**unsafe.Pointer)(unsafe.Pointer(&addr)), storedPtr)
57 | }
58 | 
59 | func (r *record) getList() []Listener {
60 | 	original := r.getUnsafeList()
61 | 	output := make([]Listener, 0, len(original)-r.inactiveItems)
62 | 	for _, ptr := range original {
63 | 		if ptr != nil {
64 | 			output = append(output, *(*Listener)(ptr))
65 | 		}
66 | 	}
67 | 
68 | 	return output
69 | }
70 | 
71 | func (r *record) getUnsafeList() []unsafe.Pointer {
72 | 	return *(*[]unsafe.Pointer)(atomic.LoadPointer(&r.list))
73 | }
74 | 
75 | // RecordList is used internally for recording which listeners are listening
76 | // to events. It's exposed for testing/fuzzing purposes, but you should not
77 | // use it directly as a consumer.
78 | type RecordList struct {
79 | 	list unsafe.Pointer // to a []*record
80 | }
81 | 
82 | // NewRecordList creates a new, empty record list.
83 | func NewRecordList() *RecordList {
84 | 	init := []*record{}
85 | 	return &RecordList{unsafe.Pointer(&init)}
86 | }
87 | 
88 | // Find looks up the index for the record corresponding to the provided
89 | // event name. It returns -1 if one was not found. Thread-safe.
90 | func (r *RecordList) Find(ev string) (index int, rec *record) {
91 | 	for i, rec := range r.getList() {
92 | 		if rec.name == ev {
93 | 			return i, rec
94 | 		}
95 | 	}
96 | 
97 | 	return -1, nil
98 | }
99 | 
100 | // Add inserts a new listener for an event. Returns the incremented
101 | // number of listeners. Not thread-safe with other write operations.
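// A minimal usage sketch (names and events are illustrative):
//
//	list := NewRecordList()
//	n := list.Add(NewEvent("foo"), ListenerFunc(func(e Event, b []byte) {}))
//	// n == 1: "foo" now has a single listener, so a caller such as Pubsub
//	// would issue the actual SUBSCRIBE command.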
102 | func (r *RecordList) Add(ev EventBuilder, fn Listener) int {
103 | 	idx, rec := r.Find(ev.Name())
104 | 	if idx == -1 {
105 | 		rec := &record{ev: ev, name: ev.Name()}
106 | 		rec.setList([]unsafe.Pointer{unsafe.Pointer(&fn)})
107 | 		r.append(rec)
108 | 		return 1
109 | 	}
110 | 
111 | 	oldList := rec.getUnsafeList()
112 | 	newCount := len(oldList) - rec.inactiveItems + 1
113 | 	if rec.inactiveItems == 0 {
114 | 		rec.inactiveItems = len(oldList)
115 | 		newList := make([]unsafe.Pointer, len(oldList)*2+1)
116 | 		// copy so that the existing items sit at the end of the list;
117 | 		// this lets the slot search below find free space faster.
118 | 		copy(newList[len(oldList)+1:], oldList)
119 | 		newList[len(oldList)] = unsafe.Pointer(&fn)
120 | 		rec.setList(newList)
121 | 		return newCount
122 | 	}
123 | 
124 | 	for i, ptr := range oldList {
125 | 		if ptr == nil {
126 | 			rec.storeAtIndex(i, fn)
127 | 			rec.inactiveItems--
128 | 			return newCount
129 | 		}
130 | 	}
131 | 
132 | 	panic("unreachable")
133 | }
134 | 
135 | // Remove deletes the listener from the event. Returns the event's remaining
136 | // listeners. Not thread-safe with other write operations.
137 | func (r *RecordList) Remove(ev EventBuilder, fn Listener) int {
138 | 	idx, rec := r.Find(ev.Name())
139 | 	if idx == -1 {
140 | 		return 0
141 | 	}
142 | 
143 | 	// 1. Find the index of the listener in the list
144 | 	oldList := rec.getUnsafeList()
145 | 	spliceIndex := -1
146 | 	for i, l := range oldList {
147 | 		if l != nil && (*(*Listener)(l)) == fn {
148 | 			spliceIndex = i
149 | 			break
150 | 		}
151 | 	}
152 | 
153 | 	if spliceIndex == -1 {
154 | 		return len(oldList)
155 | 	}
156 | 
157 | 	newCount := len(oldList) - rec.inactiveItems - 1
158 | 	// 2. If that's the only listener, just remove that record entirely.
159 | 	if newCount == 0 {
160 | 		r.remove(idx)
161 | 		return 0
162 | 	}
163 | 
164 | 	// 3. Otherwise, wipe the pointer, or make a new list copied from the parts
165 | 	// of the old if we want to compact it.
166 | 	if float32(rec.inactiveItems+1)/float32(len(oldList)) < compactionThreshold {
167 | 		rec.storeAtIndex(spliceIndex, nil)
168 | 		rec.inactiveItems++
169 | 		return newCount
170 | 	}
171 | 
172 | 	newList := make([]unsafe.Pointer, 0, newCount)
173 | 	for i, ptr := range oldList {
174 | 		if ptr != nil && i != spliceIndex {
175 | 			newList = append(newList, ptr)
176 | 		}
177 | 	}
178 | 	rec.inactiveItems = 0
179 | 	rec.setList(newList)
180 | 
181 | 	return newCount
182 | }
183 | 
184 | // ListenersFor returns the list of listeners attached to the given event.
185 | func (r *RecordList) ListenersFor(ev EventBuilder) []Listener {
186 | 	idx, rec := r.Find(ev.Name())
187 | 	if idx == -1 {
188 | 		return nil
189 | 	}
190 | 
191 | 	return rec.getList()
192 | }
193 | 
194 | func (r *RecordList) setList(l []*record) { atomic.StorePointer(&r.list, (unsafe.Pointer)(&l)) }
195 | 
196 | func (r *RecordList) getList() []*record { return *(*[]*record)(atomic.LoadPointer(&r.list)) }
197 | 
198 | func (r *RecordList) append(rec *record) {
199 | 	oldList := r.getList()
200 | 	newList := make([]*record, len(oldList)+1)
201 | 	copy(newList, oldList)
202 | 	newList[len(oldList)] = rec
203 | 	r.setList(newList)
204 | }
205 | 
206 | func (r *RecordList) remove(index int) {
207 | 	oldList := r.getList()
208 | 	newList := make([]*record, len(oldList)-1)
209 | 	copy(newList, oldList[:index])
210 | 	copy(newList[index:], oldList[index+1:])
211 | 	r.setList(newList)
212 | }
213 | 
214 | // Pubsub is an implementation of the Emitter interface using
215 | // Redis pubsub.
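// A short usage sketch (assuming an existing *redis.Pool named pool):
//
//	ps := NewPubsub(pool)
//	defer ps.Close()
//	ps.Subscribe(NewEvent("foo"), ListenerFunc(func(e Event, b []byte) {
//		fmt.Printf("foo fired with payload %q\n", b)
//	}))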
216 | type Pubsub struct { 217 | pool *redis.Pool 218 | errs chan error 219 | closer chan struct{} 220 | send chan command 221 | 222 | // Lists of listeners for subscribers and pattern subscribers 223 | subsMu sync.Mutex 224 | subs []*RecordList 225 | } 226 | 227 | // NewPubsub creates a new Emitter based on pubsub on the provided 228 | // Redis pool. 229 | func NewPubsub(pool *redis.Pool) *Pubsub { 230 | ps := &Pubsub{ 231 | pool: pool, 232 | errs: make(chan error), 233 | closer: make(chan struct{}), 234 | send: make(chan command), 235 | subs: []*RecordList{ 236 | PlainEvent: NewRecordList(), 237 | PatternEvent: NewRecordList(), 238 | }, 239 | } 240 | 241 | go ps.work() 242 | 243 | return ps 244 | } 245 | 246 | var _ Emitter = new(Pubsub) 247 | 248 | // Inner working loop for the emitter, runs until .Close() is called. 249 | func (p *Pubsub) work() { 250 | var ( 251 | cnx redis.Conn 252 | read *readPump 253 | write *writePump 254 | ) 255 | 256 | teardown := func() { 257 | read.Close() 258 | write.Close() 259 | cnx.Close() 260 | cnx = nil 261 | } 262 | 263 | defer teardown() 264 | 265 | for { 266 | if cnx == nil { 267 | cnx = p.pool.Get() 268 | read = newReadPump(cnx) 269 | write = newWritePump(cnx) 270 | go write.Work() 271 | p.resubscribe(write) 272 | 273 | go read.Work() 274 | } 275 | 276 | select { 277 | case <-p.closer: 278 | return 279 | case data := <-p.send: 280 | write.Data() <- data 281 | data.written <- struct{}{} 282 | case event := <-read.Data(): 283 | go p.handleEvent(event) 284 | case err := <-read.Errs(): 285 | teardown() 286 | p.errs <- err 287 | case err := <-write.Errs(): 288 | teardown() 289 | p.errs <- err 290 | } 291 | } 292 | } 293 | 294 | // resubscribe flushes the `send` queue and replaces it with commands 295 | // to resubscribe to all previously-subscribed-to channels. This will 296 | // NOT block until all subs are resubmitted, only until we get a lock. 
297 | func (p *Pubsub) resubscribe(write *writePump) { 298 | timer := gaugeLatency(PromReconnectLatency) 299 | PromReconnections.Inc() 300 | 301 | p.subsMu.Lock() 302 | defer p.subsMu.Unlock() 303 | 304 | for kind, recs := range p.subs { 305 | if recs == nil { 306 | continue 307 | } 308 | 309 | for _, ev := range recs.getList() { 310 | write.Data() <- command{ 311 | command: EventType(kind).SubCommand(), 312 | channel: ev.name, 313 | } 314 | } 315 | } 316 | 317 | timer() 318 | } 319 | 320 | func (p *Pubsub) handleEvent(data interface{}) { 321 | timer := gaugeLatency(PromSendLatency) 322 | defer timer() 323 | 324 | switch t := data.(type) { 325 | case redis.Message: 326 | _, rec := p.subs[PlainEvent].Find(t.Channel) 327 | if rec == nil { 328 | return 329 | } 330 | 331 | rec.Emit(rec.ev.ToEvent(t.Channel, t.Channel), t.Data) 332 | 333 | case redis.PMessage: 334 | _, rec := p.subs[PatternEvent].Find(t.Pattern) 335 | if rec == nil { 336 | return 337 | } 338 | 339 | match, ok := matchPatternAgainst(rec.ev, t.Channel) 340 | if !ok { 341 | rec.Emit(rec.ev.ToEvent(t.Channel, t.Pattern), t.Data) 342 | } else { 343 | rec.Emit(match.ToEvent(t.Channel, t.Pattern), t.Data) 344 | } 345 | } 346 | } 347 | 348 | // Errs implements Emitter.Errs 349 | func (p *Pubsub) Errs() <-chan error { 350 | return p.errs 351 | } 352 | 353 | // Subscribe implements Emitter.Subscribe 354 | func (p *Pubsub) Subscribe(ev EventBuilder, l Listener) { 355 | timer := gaugeLatency(PromSubLatency) 356 | defer timer() 357 | 358 | p.subsMu.Lock() 359 | count := p.subs[ev.kind].Add(ev, l) 360 | p.subsMu.Unlock() 361 | 362 | if count == 1 { 363 | PromSubscriptions.Inc() 364 | written := make(chan struct{}, 1) 365 | p.send <- command{ 366 | command: ev.kind.SubCommand(), 367 | channel: ev.Name(), 368 | written: written, 369 | } 370 | 371 | <-written 372 | } 373 | } 374 | 375 | // Unsubscribe implements Emitter.Unsubscribe 376 | func (p *Pubsub) Unsubscribe(ev EventBuilder, l Listener) { 377 | timer := gaugeLatency(PromSubLatency) 378 | defer timer() 379 | 380 | p.subsMu.Lock() 381 | count := p.subs[ev.kind].Remove(ev, l) 382 | p.subsMu.Unlock() 383 | 384 | if count == 0 { 385 | PromSubscriptions.Dec() 386 | written := make(chan struct{}, 1) 387 | p.send <- command{ 388 | command: ev.kind.UnsubCommand(), 389 | channel: ev.Name(), 390 | written: written, 391 | } 392 | 393 | <-written 394 | } 395 | } 396 | 397 | // Close implements Emitter.Close 398 | func (p *Pubsub) Close() { 399 | p.closer <- struct{}{} 400 | } 401 | -------------------------------------------------------------------------------- /pubsub2/redis_test.go: -------------------------------------------------------------------------------- 1 | package pubsub 2 | 3 | import ( 4 | "math/rand" 5 | "sync" 6 | "testing" 7 | "time" 8 | 9 | "unsafe" 10 | 11 | "github.com/garyburd/redigo/redis" 12 | "github.com/mixer/redutil/conn" 13 | "github.com/mixer/redutil/test" 14 | "github.com/stretchr/testify/assert" 15 | "github.com/stretchr/testify/mock" 16 | "github.com/stretchr/testify/suite" 17 | ) 18 | 19 | type mockListener struct { 20 | mock.Mock 21 | called chan struct{} 22 | } 23 | 24 | func (m *mockListener) Handle(ev Event, b []byte) { 25 | m.Called(ev, b) 26 | m.called <- struct{}{} 27 | } 28 | 29 | func (m *mockListener) waitForCall() { 30 | select { 31 | case <-m.called: 32 | case <-time.After(time.Second): 33 | panic("expected to get a call to the listener") 34 | } 35 | } 36 | 37 | func newMockListener() *mockListener { 38 | return &mockListener{called: make(chan struct{}, 
1)} 39 | } 40 | 41 | func newTestRecordList() (recs *RecordList, l1 Listener, l2 Listener) { 42 | recs = NewRecordList() 43 | l1 = newMockListener() 44 | l2 = newMockListener() 45 | 46 | recs.Add(NewEvent("foo"), l1) 47 | recs.Add(NewEvent("foo"), l2) 48 | 49 | return recs, l1, l2 50 | } 51 | 52 | func assertListenersEqual(t *testing.T, r *record, event EventBuilder, name string, listeners []Listener) { 53 | assert.Equal(t, r.ev, event) 54 | assert.Equal(t, r.name, name) 55 | assert.Equal(t, listeners, r.getList()) 56 | } 57 | 58 | func TestRecordsAddListeners(t *testing.T) { 59 | list := NewRecordList() 60 | ev := NewEvent("foo") 61 | l1 := newMockListener() 62 | l2 := newMockListener() 63 | assert.Len(t, list.getList(), 0) 64 | 65 | list.Add(ev, l1) 66 | 67 | list1 := list.getList() 68 | assert.Len(t, list1, 1) 69 | assertListenersEqual(t, list1[0], ev, "foo", []Listener{l1}) 70 | 71 | list.Add(ev, l2) 72 | 73 | list2 := list.getList() 74 | assert.Len(t, list2, 1) 75 | assertListenersEqual(t, list2[0], ev, "foo", []Listener{l2, l1}) 76 | } 77 | 78 | func TestRecordsRemoves(t *testing.T) { 79 | recs, l1, l2 := newTestRecordList() 80 | ev := NewEvent("foo") 81 | assertListenersEqual(t, recs.getList()[0], ev, "foo", []Listener{l2, l1}) 82 | recs.Remove(NewEvent("foo"), l2) 83 | assertListenersEqual(t, recs.getList()[0], ev, "foo", []Listener{l1}) 84 | recs.Remove(NewEvent("foo"), l1) 85 | assert.Len(t, recs.getList(), 0) 86 | } 87 | 88 | func TestRecordFindCopyGetsEmptyByDefault(t *testing.T) { 89 | i, recs := NewRecordList().Find("foo") 90 | assert.Equal(t, -1, i) 91 | assert.Nil(t, recs) 92 | } 93 | 94 | func TestRecordsGetCopies(t *testing.T) { 95 | recs, l1, l2 := newTestRecordList() 96 | ev := NewEvent("foo") 97 | _, out := recs.Find("foo") 98 | l3 := newMockListener() 99 | 100 | originalList := out.getList() 101 | assert.Equal(t, originalList, []Listener{l2, l1}) 102 | recs.Add(ev, l3) 103 | assert.Equal(t, originalList, []Listener{l2, l1}) 104 | 105 | updatedList := out.getList() 106 | assert.Equal(t, updatedList, []Listener{l3, l2, l1}) 107 | recs.Remove(ev, l1) 108 | recs.Remove(ev, l2) 109 | assert.Equal(t, updatedList, []Listener{l3, l2, l1}) 110 | assert.Equal(t, originalList, []Listener{l2, l1}) 111 | assert.Equal(t, out.getList(), []Listener{l3}) 112 | } 113 | 114 | func doRandomly(fns []func()) { 115 | for _, i := range rand.Perm(len(fns)) { 116 | fns[i]() 117 | } 118 | } 119 | 120 | func TestRaceGuarantees(t *testing.T) { 121 | // this test will fail with the race detector enabled if anything that's 122 | // not thread-safe happens. 
123 | 124 | recs := NewRecordList() 125 | ev1 := NewEvent("foo") 126 | ev2 := NewEvent("bar") 127 | until := time.Now().Add(500 * time.Millisecond) 128 | 129 | var wg sync.WaitGroup 130 | wg.Add(2) 131 | 132 | go func() { 133 | defer wg.Done() 134 | for time.Now().Before(until) { 135 | listener1 := newMockListener() 136 | listener2 := newMockListener() 137 | doRandomly([]func(){ 138 | func() { recs.Add(ev1, listener1) }, 139 | func() { recs.Add(ev2, listener1) }, 140 | func() { recs.Add(ev1, listener2) }, 141 | func() { recs.Add(ev2, listener2) }, 142 | }) 143 | 144 | doRandomly([]func(){ 145 | func() { recs.Remove(ev1, listener1) }, 146 | func() { recs.Remove(ev2, listener1) }, 147 | func() { recs.Remove(ev1, listener2) }, 148 | func() { recs.Remove(ev2, listener2) }, 149 | }) 150 | } 151 | }() 152 | 153 | go func() { 154 | defer wg.Done() 155 | for time.Now().Before(until) { 156 | recs.Find("bar") 157 | recs.Find("foo") 158 | } 159 | }() 160 | 161 | wg.Wait() 162 | } 163 | 164 | type RedisPubsubSuite struct { 165 | *test.RedisSuite 166 | emitter *Pubsub 167 | } 168 | 169 | func TestRedisPubsubSuite(t *testing.T) { 170 | pool, _ := conn.New(conn.ConnectionParam{ 171 | Address: "127.0.0.1:6379", 172 | }, 1) 173 | 174 | suite.Run(t, &RedisPubsubSuite{RedisSuite: test.NewSuite(pool)}) 175 | } 176 | 177 | func (r *RedisPubsubSuite) SetupTest() { 178 | r.emitter = NewPubsub(r.Pool) 179 | } 180 | 181 | func (r *RedisPubsubSuite) TearDownTest() { 182 | r.emitter.Close() 183 | } 184 | 185 | // waitForSubscribers blocks until there are at least `num` subscribers 186 | // listening on the `channel` in Redis. 187 | func (r *RedisPubsubSuite) waitForSubscribers(channel string, num int) { 188 | cnx := r.Pool.Get() 189 | defer cnx.Close() 190 | 191 | for { 192 | list, err := redis.Values(cnx.Do("PUBSUB", "NUMSUB", channel)) 193 | if err != nil { 194 | panic(err) 195 | } 196 | 197 | var count int 198 | if _, err := redis.Scan(list, nil, &count); err != nil { 199 | panic(err) 200 | } 201 | 202 | if count >= num { 203 | return 204 | } 205 | 206 | time.Sleep(50 * time.Millisecond) 207 | } 208 | } 209 | 210 | func (r *RedisPubsubSuite) TestBasicReception() { 211 | cnx := r.Pool.Get() 212 | defer cnx.Close() 213 | 214 | l := newMockListener() 215 | defer l.AssertExpectations(r.T()) 216 | 217 | ev := NewEvent("foo") 218 | body := []byte("bar") 219 | r.emitter.Subscribe(ev, l) 220 | l.On("Handle", ev.ToEvent("foo", "foo"), body).Return() 221 | r.MustDo("PUBLISH", "foo", body) 222 | l.waitForCall() 223 | } 224 | 225 | func (r *RedisPubsubSuite) TestCallsMultipleListeners() { 226 | cnx := r.Pool.Get() 227 | defer cnx.Close() 228 | 229 | l1 := newMockListener() 230 | defer l1.AssertExpectations(r.T()) 231 | l2 := newMockListener() 232 | defer l2.AssertExpectations(r.T()) 233 | 234 | ev := NewEvent("foo") 235 | body1 := []byte("bar1") 236 | body2 := []byte("bar2") 237 | body3 := []byte("bar3") 238 | l1.On("Handle", ev.ToEvent("foo", "foo"), body1).Return() 239 | l2.On("Handle", ev.ToEvent("foo", "foo"), body1).Return() 240 | l2.On("Handle", ev.ToEvent("foo", "foo"), body2).Return() 241 | 242 | r.emitter.Subscribe(ev, l1) 243 | r.emitter.Subscribe(ev, l2) 244 | 245 | r.MustDo("PUBLISH", "foo", body1) 246 | l1.waitForCall() 247 | l2.waitForCall() 248 | 249 | r.emitter.Unsubscribe(ev, l1) 250 | 251 | r.MustDo("PUBLISH", "foo", body2) 252 | l2.waitForCall() 253 | 254 | r.emitter.Unsubscribe(ev, l2) 255 | r.MustDo("PUBLISH", "foo", body3) 256 | } 257 | 258 | func (r *RedisPubsubSuite) TestsCallsWithFancyPatterns() { 
259 | cnx := r.Pool.Get() 260 | defer cnx.Close() 261 | 262 | l1 := newMockListener() 263 | defer l1.AssertExpectations(r.T()) 264 | 265 | ev := NewPattern("foo:").Int(42).String(":bar") 266 | body1 := []byte("bar1") 267 | l1.On("Handle", ev.ToEvent("foo:42:bar", "foo:*:bar"), body1).Return() 268 | r.emitter.Subscribe(NewPattern().String("foo:").Star().String(":bar"), l1) 269 | 270 | r.MustDo("PUBLISH", "foo:42:bar", body1) 271 | l1.waitForCall() 272 | } 273 | 274 | func (r *RedisPubsubSuite) TestResubscribesWhenDies() { 275 | cnx := r.Pool.Get() 276 | defer cnx.Close() 277 | 278 | body := []byte("bar") 279 | ev := NewEvent("foo") 280 | l := newMockListener() 281 | defer l.AssertExpectations(r.T()) 282 | 283 | r.emitter.Subscribe(ev, l) 284 | 285 | r.MustDo("CLIENT", "KILL", "SKIPME", "yes") 286 | select { 287 | case <-r.emitter.Errs(): 288 | case <-time.After(time.Second): 289 | r.Fail("timeout: expected to get an error after kill Redis conns") 290 | } 291 | 292 | r.waitForSubscribers("foo", 1) 293 | l.On("Handle", ev.ToEvent("foo", "foo"), body).Return() 294 | r.MustDo("PUBLISH", "foo", body) 295 | l.waitForCall() 296 | } 297 | 298 | func createBenchmarkList(count int, removeEvery int) (listeners []*Listener, recordInst *record, recordList *RecordList) { 299 | listeners = make([]*Listener, count) 300 | for i := 0; i < count; i++ { 301 | wrapped := ListenerFunc(func(_ Event, _ []byte) {}) 302 | listeners[i] = &wrapped 303 | } 304 | 305 | recordInst = &record{list: unsafe.Pointer(&listeners)} 306 | recordInner := []*record{recordInst} 307 | recordList = NewRecordList() 308 | recordList.list = unsafe.Pointer(&recordInner) 309 | 310 | for i := removeEvery; i < count; i += removeEvery { 311 | recordList.Remove(NewEvent(), *listeners[i]) 312 | } 313 | 314 | return 315 | } 316 | 317 | func runBenchmarkAddBenchmark(count int, b *testing.B) { 318 | listeners, recordInst, recordList := createBenchmarkList(count, 3) 319 | b.ResetTimer() 320 | 321 | ev := NewEvent() 322 | fn := ListenerFunc(func(_ Event, _ []byte) {}) 323 | for i := 0; i < b.N; i++ { 324 | recordInst.list = unsafe.Pointer(&listeners) 325 | recordList.Add(ev, fn) 326 | } 327 | } 328 | 329 | func runBenchmarkRemoveBenchmark(count int, b *testing.B) { 330 | listeners, recordInst, recordList := createBenchmarkList(count, 3) 331 | first := listeners[0] 332 | b.ResetTimer() 333 | 334 | ev := NewEvent() 335 | for i := 0; i < b.N; i++ { 336 | listeners[0] = first 337 | recordInst.list = unsafe.Pointer(&listeners) 338 | recordList.Remove(ev, *first) 339 | } 340 | } 341 | 342 | func runBenchmarkBroadcastBenchmark(count int, b *testing.B) { 343 | _, recordInst, _ := createBenchmarkList(count, 3) 344 | b.ResetTimer() 345 | 346 | ev := NewEvent().ToEvent("", "") 347 | for i := 0; i < b.N; i++ { 348 | recordInst.Emit(ev, nil) 349 | } 350 | } 351 | 352 | func BenchmarkBroadcast1K(b *testing.B) { runBenchmarkBroadcastBenchmark(1000, b) } 353 | func BenchmarkBroadcast10K(b *testing.B) { runBenchmarkBroadcastBenchmark(10000, b) } 354 | func BenchmarkBroadcast100K(b *testing.B) { runBenchmarkBroadcastBenchmark(100000, b) } 355 | 356 | func BenchmarkRecordAdd1K(b *testing.B) { runBenchmarkAddBenchmark(1000, b) } 357 | func BenchmarkRecordAdd10K(b *testing.B) { runBenchmarkAddBenchmark(10000, b) } 358 | func BenchmarkRecordAdd100K(b *testing.B) { runBenchmarkAddBenchmark(100000, b) } 359 | 360 | func BenchmarkRecordRemove1K(b *testing.B) { runBenchmarkRemoveBenchmark(1000, b) } 361 | func BenchmarkRecordRemove10K(b *testing.B) { 
runBenchmarkRemoveBenchmark(10000, b) }
362 | func BenchmarkRecordRemove100K(b *testing.B) { runBenchmarkRemoveBenchmark(100000, b) }
--------------------------------------------------------------------------------
/queue/base_queue.go:
--------------------------------------------------------------------------------
 1 | package queue
 2 | 
 3 | import (
 4 | 	"sync"
 5 | 	"time"
 6 | 
 7 | 	"github.com/garyburd/redigo/redis"
 8 | )
 9 | 
10 | // BaseQueue provides a basic implementation of the Queue interface. Its basic
11 | // methodology is to perform updates using a Processor interface, which itself
12 | // defines how updates are handled.
13 | type BaseQueue struct {
14 | 	pool   *redis.Pool
15 | 	source string
16 | 
17 | 	pmu       sync.RWMutex
18 | 	processor Processor
19 | }
20 | 
21 | var _ Queue = new(BaseQueue)
22 | 
23 | func NewBaseQueue(pool *redis.Pool, source string) *BaseQueue {
24 | 	return &BaseQueue{
25 | 		pool:   pool,
26 | 		source: source,
27 | 	}
28 | }
29 | 
30 | // Source implements the Source method on the Queue interface.
31 | func (q *BaseQueue) Source() string {
32 | 	return q.source
33 | }
34 | 
35 | // Push pushes the given payload (a byte slice) into the specified keyspace by
36 | // delegating into the `Processor`'s `func Push`. It obtains a connection to
37 | // Redis using the pool, which is passed into the Processor, and recycles that
38 | // connection after the function has returned.
39 | //
40 | // If an error occurs during Pushing, it will be returned, and it can be assumed
41 | // that the payload is not in Redis.
42 | func (q *BaseQueue) Push(payload []byte) (err error) {
43 | 	cnx := q.pool.Get()
44 | 	defer cnx.Close()
45 | 
46 | 	return q.Processor().Push(cnx, q.Source(), payload)
47 | }
48 | 
49 | // Pull implements the Pull method on the Queue interface.
50 | func (q *BaseQueue) Pull(timeout time.Duration) (payload []byte, err error) {
51 | 	cnx := q.pool.Get()
52 | 	defer cnx.Close()
53 | 
54 | 	return q.Processor().Pull(cnx, q.Source(), timeout)
55 | }
56 | 
57 | // Processor implements the Processor method on the Queue interface. It functions
58 | // by requesting a read-level lock from the guarding mutex and returning that value
59 | // once obtained. If no processor is set, the default FIFO implementation is
60 | // returned.
61 | func (q *BaseQueue) Processor() Processor {
62 | 	q.pmu.RLock()
63 | 	defer q.pmu.RUnlock()
64 | 
65 | 	if q.processor == nil {
66 | 		return FIFO
67 | 	}
68 | 
69 | 	return q.processor
70 | }
71 | 
72 | // SetProcessor implements the SetProcessor method on the Queue interface. It
73 | // functions by requesting write-level access from the guarding mutex and
74 | // performs the update atomically.
75 | func (q *BaseQueue) SetProcessor(processor Processor) {
76 | 	q.pmu.Lock()
77 | 	defer q.pmu.Unlock()
78 | 
79 | 	q.processor = processor
80 | }
81 | 
82 | // Concat takes all elements from the source queue and adds them to this one. This
83 | // can be a long-running operation. If a persistent error is returned while
84 | // moving things, then it will be returned and the concat will stop, though
85 | // the concat operation can be safely resumed at any time.
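// A drain sketch (the pool and queue names here are illustrative):
//
//	q := queue.NewBaseQueue(pool, "jobs")
//	moved, err := q.Concat("jobs:dead") // move everything from jobs:dead into jobs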
86 | func (q *BaseQueue) Concat(src string) (moved int, err error) {
87 | 	cnx := q.pool.Get()
88 | 	defer cnx.Close()
89 | 
90 | 	errCount := 0
91 | 	for {
92 | 		err = q.Processor().Concat(cnx, src, q.Source())
93 | 		if err == nil {
94 | 			errCount = 0
95 | 			moved++
96 | 			continue
97 | 		}
98 | 
99 | 		// ErrNil is returned when there are no more items to concat
100 | 		if err == redis.ErrNil {
101 | 			return
102 | 		}
103 | 
104 | 		// Command errors are bad; something is wrong in the db and we should
105 | 		// return the problem to the caller.
106 | 		if _, cmdErr := err.(redis.Error); cmdErr {
107 | 			return
108 | 		}
109 | 
110 | 		// Otherwise this is probably some temporary network error. Close
111 | 		// the old connection and try getting a new one.
112 | 		errCount++
113 | 		if errCount >= concatRetries {
114 | 			return
115 | 		}
116 | 
117 | 		cnx.Close()
118 | 		cnx = q.pool.Get()
119 | 	}
120 | }
--------------------------------------------------------------------------------
/queue/base_queue_test.go:
--------------------------------------------------------------------------------
 1 | package queue_test
 2 | 
 3 | import (
 4 | 	"errors"
 5 | 	"testing"
 6 | 	"time"
 7 | 
 8 | 	"github.com/garyburd/redigo/redis"
 9 | 	"github.com/mixer/redutil/conn"
10 | 	"github.com/mixer/redutil/queue"
11 | 	"github.com/mixer/redutil/test"
12 | 	"github.com/stretchr/testify/mock"
13 | 	"github.com/stretchr/testify/suite"
14 | )
15 | 
16 | type BaseQueueSuite struct {
17 | 	*test.RedisSuite
18 | }
19 | 
20 | func TestBaseQueueSuite(t *testing.T) {
21 | 	pool, _ := conn.New(conn.ConnectionParam{
22 | 		Address: "127.0.0.1:6379",
23 | 	}, 1)
24 | 
25 | 	suite.Run(t, &BaseQueueSuite{test.NewSuite(pool)})
26 | }
27 | 
28 | func (suite *BaseQueueSuite) TestPushDelegatesToProcessor() {
29 | 	processor := &MockProcessor{}
30 | 	processor.
31 | 		On("Push",
32 | 		mock.Anything, "foo", []byte("payload")).
33 | 		Return(errors.New("error"))
34 | 
35 | 	q := queue.NewBaseQueue(suite.Pool, "foo")
36 | 	q.SetProcessor(processor)
37 | 
38 | 	err := q.Push([]byte("payload"))
39 | 
40 | 	suite.Assert().Equal("error", err.Error())
41 | 	processor.AssertNumberOfCalls(suite.T(), "Push", 1)
42 | }
43 | 
44 | func (suite *BaseQueueSuite) TestPullDelegatesToProcessor() {
45 | 	processor := &MockProcessor{}
46 | 	processor.On("Pull",
47 | 		mock.Anything, "foo", time.Second).
48 | 		Return([]byte("bar"), errors.New("error"))
49 | 
50 | 	q := queue.NewBaseQueue(suite.Pool, "foo")
51 | 	q.SetProcessor(processor)
52 | 
53 | 	payload, err := q.Pull(time.Second)
54 | 
55 | 	suite.Assert().Equal([]byte("bar"), payload)
56 | 	suite.Assert().Equal("error", err.Error())
57 | }
58 | 
59 | func (suite *BaseQueueSuite) TestConcatsDelegatesToProcessor() {
60 | 	processor := &MockProcessor{}
61 | 	processor.On("Concat",
62 | 		mock.Anything, "bar", "foo").
63 | 		Return(nil).Once()
64 | 	processor.On("Concat",
65 | 		mock.Anything, "bar", "foo").
66 | 		Return(redis.ErrNil).Once()
67 | 
68 | 	q := queue.NewBaseQueue(suite.Pool, "foo")
69 | 	q.SetProcessor(processor)
70 | 
71 | 	_, err := q.Concat("bar")
72 | 
73 | 	suite.Assert().Equal(redis.ErrNil, err)
74 | 	processor.AssertExpectations(suite.T())
75 | }
76 | 
77 | func (suite *BaseQueueSuite) TestConcatAbortsOnCommandError() {
78 | 	err := redis.Error("oh no!")
79 | 	processor := &MockProcessor{}
80 | 	processor.On("Concat",
81 | 		mock.Anything, "bar", "foo").
82 | 		Return(err).Once()
83 | 
84 | 	q := queue.NewBaseQueue(suite.Pool, "foo")
85 | 
86 | 	q.SetProcessor(processor)
87 | 	_, qerr := q.Concat("bar")
88 | 	suite.Assert().Equal(err, qerr)
89 | 
90 | 	processor.AssertExpectations(suite.T())
91 | }
92 | 
93 | func (suite *BaseQueueSuite) TestConcatRetriesOnCnxError() {
94 | 	processor := &MockProcessor{}
95 | 	processor.On("Concat",
96 | 		mock.Anything, "bar", "foo").
97 | 		Return(errors.New("some net error or something")).Once()
98 | 	processor.On("Concat",
99 | 		mock.Anything, "bar", "foo").
100 | 		Return(redis.ErrNil).Once()
101 | 
102 | 	q := queue.NewBaseQueue(suite.Pool, "foo")
103 | 	q.SetProcessor(processor)
104 | 
105 | 	_, err := q.Concat("bar")
106 | 
107 | 	suite.Assert().Equal(redis.ErrNil, err)
108 | 	processor.AssertExpectations(suite.T())
109 | }
--------------------------------------------------------------------------------
/queue/byte_queue.go:
--------------------------------------------------------------------------------
 1 | package queue
 2 | 
 3 | import "github.com/garyburd/redigo/redis"
 4 | 
 5 | // concatRetries is the number of errors we can get in a row while concatting
 6 | // before giving up and just returning.
 7 | const concatRetries int = 3
 8 | 
 9 | // ByteQueue represents either a LIFO or FIFO queue contained in a particular
10 | // Redis keyspace. It allows callers to push `[]byte` payloads, and pull them
11 | // back out with Pull. It is typically used in a distributed
12 | // setting, where the pusher may not always get the item back.
13 | type ByteQueue struct {
14 | 	BaseQueue
15 | }
16 | 
17 | // NewByteQueue allocates and returns a pointer to a new instance of a
18 | // ByteQueue. It initializes itself using the given *redis.Pool, and the name,
19 | // which refers to the keyspace wherein these values will be stored.
20 | //
21 | // No internal goroutines or channels are created here.
22 | func NewByteQueue(pool *redis.Pool, name string) *ByteQueue {
23 | 	return &ByteQueue{BaseQueue{
24 | 		source: name,
25 | 		pool:   pool,
26 | 	}}
27 | }
--------------------------------------------------------------------------------
/queue/byte_queue_test.go:
--------------------------------------------------------------------------------
 1 | package queue_test
 2 | 
 3 | import (
 4 | 	"testing"
 5 | 
 6 | 	"github.com/mixer/redutil/conn"
 7 | 	"github.com/mixer/redutil/queue"
 8 | 	"github.com/mixer/redutil/test"
 9 | 	"github.com/stretchr/testify/suite"
10 | )
11 | 
12 | type ByteQueueSuite struct {
13 | 	*test.RedisSuite
14 | }
15 | 
16 | func TestByteQueueSuite(t *testing.T) {
17 | 	pool, _ := conn.New(conn.ConnectionParam{
18 | 		Address: "127.0.0.1:6379",
19 | 	}, 1)
20 | 
21 | 	suite.Run(t, &ByteQueueSuite{test.NewSuite(pool)})
22 | }
23 | 
24 | func (suite *ByteQueueSuite) TestConstruction() {
25 | 	q := queue.NewByteQueue(suite.Pool, "foo")
26 | 
27 | 	suite.Assert().IsType(&queue.ByteQueue{}, q)
28 | }
--------------------------------------------------------------------------------
/queue/durable_queue.go:
--------------------------------------------------------------------------------
 1 | package queue
 2 | 
 3 | import (
 4 | 	"sync"
 5 | 	"time"
 6 | 
 7 | 	"github.com/garyburd/redigo/redis"
 8 | )
 9 | 
10 | // DurableQueue is an implementation of the Queue interface which takes items
11 | // from a source queue and pushes them into the destination queue when Pull() is
12 | // called.
13 | type DurableQueue struct {
14 | 	// dmu is a sync.RWMutex that guards the destination string
15 | 	dmu sync.RWMutex
16 | 	// dest is the Redis keyspace where Pulled() items end up.
17 | 	dest string
18 | 
19 | 	// DurableQueue extends a BaseQueue
20 | 	BaseQueue
21 | }
22 | 
23 | // DurableQueue implements the Queue type.
24 | var _ Queue = new(DurableQueue)
25 | 
26 | // NewDurableQueue initializes and returns a new pointer to an instance of a
27 | // DurableQueue. It is initialized with the given Redis pool, and the source and
28 | // destination queues. By default the FIFO tactic is used, but a call to
29 | // SetProcessor can change this in a safe fashion.
30 | //
31 | // DurableQueues own no goroutines, so this method does not spawn any
32 | // goroutines or create any channels.
33 | func NewDurableQueue(pool *redis.Pool, source, dest string) *DurableQueue {
34 | 	return &DurableQueue{
35 | 		dest:      dest,
36 | 		BaseQueue: BaseQueue{source: source, pool: pool},
37 | 	}
38 | }
39 | 
40 | // Pull implements the Pull function on the Queue interface. Unlike common
41 | // implementations of the Queue type, it mutates the Redis keyspace twice, by
42 | // removing an item from one LIST and pushing it onto another. It does so by
43 | // delegating into the processor, thus blocking until the processor returns.
44 | func (q *DurableQueue) Pull(timeout time.Duration) (payload []byte, err error) {
45 | 	cnx := q.pool.Get()
46 | 	defer cnx.Close()
47 | 
48 | 	return q.Processor().PullTo(cnx, q.Source(), q.Dest(), timeout)
49 | }
50 | 
51 | // Dest returns the destination keyspace in Redis where pulled items end up. It
52 | // first obtains a read-level lock on the member `dest` variable before
53 | // returning.
54 | func (q *DurableQueue) Dest() string {
55 | 	q.dmu.RLock()
56 | 	defer q.dmu.RUnlock()
57 | 
58 | 	return q.dest
59 | }
60 | 
61 | // SetDest updates the destination where items are "pulled" to in a safe,
62 | // blocking manner. It does this by first obtaining a write-level lock on the
63 | // internal member variable wherein the destination is stored, updating, and
64 | // then relinquishing the lock.
65 | //
66 | // It returns the new destination that was just set.
67 | func (q *DurableQueue) SetDest(dest string) string {
68 | 	q.dmu.Lock()
69 | 	defer q.dmu.Unlock()
70 | 
71 | 	q.dest = dest
72 | 
73 | 	return q.dest
74 | }
--------------------------------------------------------------------------------
/queue/durable_queue_test.go:
--------------------------------------------------------------------------------
 1 | package queue_test
 2 | 
 3 | import (
 4 | 	"fmt"
 5 | 	"testing"
 6 | 	"time"
 7 | 
 8 | 	"github.com/mixer/redutil/conn"
 9 | 	"github.com/mixer/redutil/queue"
10 | 	"github.com/mixer/redutil/test"
11 | 	"github.com/stretchr/testify/mock"
12 | 	"github.com/stretchr/testify/suite"
13 | )
14 | 
15 | type DurableQueueSuite struct {
16 | 	*test.RedisSuite
17 | }
18 | 
19 | func TestDurableQueueSuite(t *testing.T) {
20 | 	pool, _ := conn.New(conn.ConnectionParam{
21 | 		Address: "127.0.0.1:6379",
22 | 	}, 1)
23 | 
24 | 	suite.Run(t, &DurableQueueSuite{test.NewSuite(pool)})
25 | }
26 | 
27 | func (suite *DurableQueueSuite) TestPullDelegatesToProcessor() {
28 | 	p := &MockProcessor{}
29 | 	p.On("PullTo",
30 | 		mock.Anything, "foo", "bar", time.Second).
82 | -------------------------------------------------------------------------------- /queue/durable_queue_test.go: --------------------------------------------------------------------------------
1 | package queue_test
2 |
3 | import (
4 | "fmt"
5 | "testing"
6 | "time"
7 |
8 | "github.com/mixer/redutil/conn"
9 | "github.com/mixer/redutil/queue"
10 | "github.com/mixer/redutil/test"
11 | "github.com/stretchr/testify/mock"
12 | "github.com/stretchr/testify/suite"
13 | )
14 |
15 | type DurableQueueSuite struct {
16 | *test.RedisSuite
17 | }
18 |
19 | func TestDurableQueueSuite(t *testing.T) {
20 | pool, _ := conn.New(conn.ConnectionParam{
21 | Address: "127.0.0.1:6379",
22 | }, 1)
23 |
24 | suite.Run(t, &DurableQueueSuite{test.NewSuite(pool)})
25 | }
26 |
27 | func (suite *DurableQueueSuite) TestPullDelegatesToProcessor() {
28 | p := &MockProcessor{}
29 | p.On("PullTo",
30 | mock.Anything, "foo", "bar", time.Second).
31 | Return([]byte("baz"), fmt.Errorf("woot")).Once()
32 |
33 | q := queue.NewDurableQueue(suite.Pool, "foo", "bar")
34 | q.SetProcessor(p)
35 | data, err := q.Pull(time.Second)
36 |
37 | suite.Assert().Equal([]byte("baz"), data)
38 | suite.Assert().Equal("woot", err.Error())
39 | }
40 |
41 | func (suite *DurableQueueSuite) TestDestReturnsTheDestination() {
42 | q := queue.NewDurableQueue(suite.Pool, "foo", "bar")
43 |
44 | dest := q.Dest()
45 |
46 | suite.Assert().Equal("bar", dest)
47 | }
48 |
49 | func (suite *DurableQueueSuite) TestSetDestUpdatesTheDestination() {
50 | q := queue.NewDurableQueue(suite.Pool, "foo", "old")
51 |
52 | old := q.Dest()
53 | updated := q.SetDest("new")
54 |
55 | suite.Assert().Equal("old", old)
56 | suite.Assert().Equal("new", updated)
57 | }
58 | -------------------------------------------------------------------------------- /queue/fifo_processor.go: --------------------------------------------------------------------------------
1 | package queue
2 |
3 | import (
4 | "time"
5 |
6 | "github.com/garyburd/redigo/redis"
7 | )
8 |
9 | type fifoProcessor struct{}
10 |
11 | // FIFO is a first in, first out implementation of the Processor interface.
12 | var FIFO Processor = &fifoProcessor{}
13 |
14 | // Push implements the `func Push` from `Processor`. It pushes to the left side
15 | // of the Redis list using LPUSH, and returns any errors encountered while
16 | // running that command.
17 | func (f *fifoProcessor) Push(cnx redis.Conn, src string, payload []byte) (err error) {
18 | _, err = cnx.Do("LPUSH", src, payload)
19 | return
20 | }
21 |
22 | // Pull implements the `func Pull` from `Processor`. It pulls from the right
23 | // side of the Redis list in a blocking fashion, using BRPOP.
24 | //
25 | // If the timeout elapses before an item becomes available, a redis.ErrNil is
26 | // returned along with a nil payload. Any other error encountered while running
27 | // the command is likewise returned, along with a nil payload.
28 | //
29 | // If an item can successfully be removed from the keyspace, it is returned
30 | // without error.
31 | func (f *fifoProcessor) Pull(cnx redis.Conn, src string,
32 | timeout time.Duration) ([]byte, error) {
33 |
34 | slices, err := redis.ByteSlices(cnx.Do("BRPOP", src, block(timeout)))
35 | if err != nil {
36 | return nil, err
37 | }
38 |
39 | return slices[1], nil
40 | }
41 |
42 | // PullTo implements the `func PullTo` from the `Processor` interface. It pulls
43 | // from the right side of the Redis source (src) structure, and pushes to the
44 | // left side of the Redis destination (dest) structure, using BRPOPLPUSH.
45 | func (f *fifoProcessor) PullTo(cnx redis.Conn, src, dest string,
46 | timeout time.Duration) ([]byte, error) {
47 |
48 | bytes, err := redis.Bytes(cnx.Do("BRPOPLPUSH", src, dest, block(timeout)))
49 | if err != nil {
50 | return nil, err
51 | }
52 |
53 | return bytes, nil
54 | }
55 |
56 | // Concat moves one element from the tail of the source list to the head of the
57 | // destination list via RPOPLPUSH. redis.ErrNil is returned when the source is empty.
58 | func (f *fifoProcessor) Concat(cnx redis.Conn, src, dest string) (err error) {
59 | return rlConcat(cnx, src, dest)
60 | }
61 | -------------------------------------------------------------------------------- /queue/fifo_processor_test.go: --------------------------------------------------------------------------------
1 | package queue_test
2 |
3 | import (
4 | "testing"
5 | "time"
6 |
7 | "github.com/garyburd/redigo/redis"
8 | "github.com/mixer/redutil/conn"
9 | "github.com/mixer/redutil/queue"
10 | "github.com/mixer/redutil/test"
11 | "github.com/stretchr/testify/suite"
12 | )
13 |
14 | type FIFOProcessorTest struct {
15 | *test.RedisSuite
16 | }
17 |
18 | func TestFIFOProcessorSuite(t *testing.T) {
19 | pool, _ := conn.New(conn.ConnectionParam{
20 | Address: "127.0.0.1:6379",
21 | }, 1)
22 |
23 | suite.Run(t, &FIFOProcessorTest{test.NewSuite(pool)})
24 | }
25 |
26 | func (suite *FIFOProcessorTest) assertOrder(cnx redis.Conn) {
27 | first, e1 := queue.FIFO.Pull(cnx, "keyspace", time.Second)
28 | second, e2 := queue.FIFO.Pull(cnx, "keyspace", time.Second)
29 | third, e3 := queue.FIFO.Pull(cnx, "keyspace", time.Second)
30 |
31 | suite.Assert().Equal([]byte("first"), first)
32 | suite.Assert().Equal([]byte("second"), second)
33 | suite.Assert().Equal([]byte("third"), third)
34 |
35 | suite.Assert().Nil(e1)
36 | suite.Assert().Nil(e2)
37 | suite.Assert().Nil(e3)
38 | }
39 |
40 | func (suite *FIFOProcessorTest) TestPullToOrder() {
41 | cnx := suite.Pool.Get()
42 | defer cnx.Close()
43 |
44 | queue.FIFO.Push(cnx, "keyspace", []byte("first"))
45 | queue.FIFO.Push(cnx, "keyspace", []byte("second"))
46 | queue.FIFO.Push(cnx, "keyspace2", []byte("third"))
47 |
48 | queue.FIFO.PullTo(cnx, "keyspace2", "keyspace", time.Second)
49 |
50 | suite.assertOrder(cnx)
51 | }
52 |
53 | func (suite *FIFOProcessorTest) TestProcessingOrder() {
54 | cnx := suite.Pool.Get()
55 | defer cnx.Close()
56 |
57 | queue.FIFO.Push(cnx, "keyspace", []byte("first"))
58 | queue.FIFO.Push(cnx, "keyspace", []byte("second"))
59 | queue.FIFO.Push(cnx, "keyspace", []byte("third"))
60 |
61 | suite.assertOrder(cnx)
62 | }
63 |
64 | func (suite *FIFOProcessorTest) TestConcats() {
65 | cnx := suite.Pool.Get()
66 | defer cnx.Close()
67 |
68 | queue.FIFO.Push(cnx, "keyspace", []byte("first"))
69 | queue.FIFO.Push(cnx, "keyspace2", []byte("second"))
70 | queue.FIFO.Push(cnx, "keyspace2", []byte("third"))
71 |
72 | suite.Assert().Nil(queue.FIFO.Concat(cnx, "keyspace2", "keyspace"))
73 | suite.Assert().Nil(queue.FIFO.Concat(cnx, "keyspace2", "keyspace"))
74 | suite.Assert().Equal(redis.ErrNil, queue.FIFO.Concat(cnx, "keyspace2", "keyspace"))
75 |
76 | suite.assertOrder(cnx)
77 | }
78 |
79 | // Unfortunately, this test takes a lot of time to run since redis does not
80 | // support floating point timeouts.
81 | func (suite *FIFOProcessorTest) TestPullRespectsTimeouts() {
82 | cnx := suite.Pool.Get()
83 | defer cnx.Close()
84 |
85 | b, err := queue.FIFO.Pull(cnx, "keyspace", 250*time.Millisecond)
86 |
87 | suite.Assert().Empty(b)
88 | suite.Assert().Equal(redis.ErrNil, err)
89 | }
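90 |
91 | // Editorial note: the sub-second timeout above is rounded *up* to a whole
92 | // second, since BRPOP only accepts integer timeouts (see block() in
93 | // processor.go):
94 | //
95 | //	block(250 * time.Millisecond) // == 1
96 | //	block(3 * time.Second)        // == 3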
97 | -------------------------------------------------------------------------------- /queue/lifo_processor.go: --------------------------------------------------------------------------------
1 | package queue
2 |
3 | import (
4 | "time"
5 |
6 | "github.com/garyburd/redigo/redis"
7 | )
8 |
9 | type lifoProcessor struct{}
10 |
11 | // LIFO is a last in, first out implementation of the Processor interface.
12 | var LIFO Processor = &lifoProcessor{}
13 |
14 | // Push implements the `func Push` from `Processor`. It pushes to the right side
15 | // of the Redis list using RPUSH, and returns any errors encountered while
16 | // running that command.
17 | func (l *lifoProcessor) Push(cnx redis.Conn, src string, payload []byte) (err error) {
18 | _, err = cnx.Do("RPUSH", src, payload)
19 | return
20 | }
21 |
22 | // Pull implements the `func Pull` from `Processor`. It pulls from the
23 | // right side of the Redis list in a blocking fashion, using BRPOP.
24 | //
25 | // If the timeout elapses before an item becomes available, a redis.ErrNil is
26 | // returned along with a nil payload. Any other error encountered while running
27 | // the command is likewise returned, along with a nil payload.
28 | //
29 | // If an item can successfully be removed from the keyspace, it is returned
30 | // without error.
31 | func (l *lifoProcessor) Pull(cnx redis.Conn, src string,
32 | timeout time.Duration) ([]byte, error) {
33 |
34 | slices, err := redis.ByteSlices(cnx.Do("BRPOP", src, block(timeout)))
35 | if err != nil {
36 | return nil, err
37 | }
38 |
39 | return slices[1], nil
40 | }
41 |
42 | // PullTo implements the `func PullTo` from the `Processor` interface. It pulls
43 | // from the left side of the Redis source (src) structure, and pushes to the
44 | // right side of the Redis destination (dest) structure.
45 | //
46 | // Warning: unlike Pull() and the PullTo() method on the FIFO processor, this
47 | // is NOT blocking and will return redis.ErrNil if there is not anything on
48 | // the queue when the method is called.
49 | func (l *lifoProcessor) PullTo(cnx redis.Conn, src, dest string,
50 | _ time.Duration) ([]byte, error) {
51 |
52 | bytes, err := redis.Bytes(LPOPRPUSH.Do(cnx, src, dest))
53 | if err != nil {
54 | return nil, err
55 | }
56 |
57 | return bytes, nil
58 | }
59 |
60 | // Concat removes the first element from the source list and adds it to the end
61 | // of the destination list. redis.ErrNil is returned when the source is empty.
62 | func (l *lifoProcessor) Concat(cnx redis.Conn, src, dest string) (err error) {
63 | bytes, err := l.PullTo(cnx, src, dest, 0*time.Second)
64 | if err == nil && bytes == nil {
65 | err = redis.ErrNil
66 | }
67 |
68 | return
69 | }
70 | -------------------------------------------------------------------------------- /queue/lifo_processor_test.go: --------------------------------------------------------------------------------
1 | package queue_test
2 |
3 | import (
4 | "testing"
5 | "time"
6 |
7 | "github.com/garyburd/redigo/redis"
8 | "github.com/mixer/redutil/conn"
9 | "github.com/mixer/redutil/queue"
10 | "github.com/mixer/redutil/test"
11 | "github.com/stretchr/testify/suite"
12 | )
13 |
14 | type LIFOProcessorTest struct {
15 | *test.RedisSuite
16 | }
17 |
18 | func TestLIFOProcessorSuite(t *testing.T) {
19 | pool, _ := conn.New(conn.ConnectionParam{
20 | Address: "127.0.0.1:6379",
21 | }, 1)
22 |
23 | suite.Run(t, &LIFOProcessorTest{test.NewSuite(pool)})
24 | }
25 |
26 | func (suite *LIFOProcessorTest) assertOrder(cnx redis.Conn) {
27 | first, e1 := queue.LIFO.Pull(cnx, "keyspace", time.Second)
28 | second, e2 := queue.LIFO.Pull(cnx, "keyspace", time.Second)
29 | third, e3 := queue.LIFO.Pull(cnx, "keyspace", time.Second)
30 |
31 | suite.Assert().Equal([]byte("third"), first)
32 | suite.Assert().Equal([]byte("second"), second)
33 | suite.Assert().Equal([]byte("first"), third)
34 |
35 | suite.Assert().Nil(e1)
36 | suite.Assert().Nil(e2)
37 | suite.Assert().Nil(e3)
38 | }
39 |
40 | func (suite *LIFOProcessorTest) TestPullToOrder() {
41 | cnx := suite.Pool.Get()
42 | defer cnx.Close()
43 |
44 | queue.LIFO.Push(cnx, "keyspace", []byte("first"))
45 | queue.LIFO.Push(cnx, "keyspace", []byte("second"))
46 | queue.LIFO.Push(cnx, "keyspace2", []byte("third"))
47 |
48 | queue.LIFO.PullTo(cnx, "keyspace2", "keyspace", time.Second)
49 |
50 | suite.assertOrder(cnx)
51 | }
52 |
53 | func (suite *LIFOProcessorTest) TestProcessingOrder() {
54 | cnx := suite.Pool.Get()
55 | defer cnx.Close()
56 |
57 | queue.LIFO.Push(cnx, "keyspace", []byte("first"))
58 | queue.LIFO.Push(cnx, "keyspace", []byte("second"))
59 | queue.LIFO.Push(cnx, "keyspace", []byte("third"))
60 |
61 | suite.assertOrder(cnx)
62 | }
63 |
64 | func (suite *LIFOProcessorTest) TestConcats() {
65 | cnx := suite.Pool.Get()
66 | defer cnx.Close()
67 |
68 | queue.LIFO.Push(cnx, "keyspace", []byte("first"))
69 | queue.LIFO.Push(cnx, "keyspace2", []byte("second"))
70 | queue.LIFO.Push(cnx, "keyspace2", []byte("third"))
71 |
72 | suite.Assert().Nil(queue.LIFO.Concat(cnx, "keyspace2", "keyspace"))
73 | suite.Assert().Nil(queue.LIFO.Concat(cnx, "keyspace2", "keyspace"))
74 | suite.Assert().Equal(redis.ErrNil, queue.LIFO.Concat(cnx, "keyspace2", "keyspace"))
75 |
76 | suite.assertOrder(cnx)
77 | }
78 |
79 | // Unfortunately, this test takes a lot of time to run since redis does not
80 | // support floating point timeouts.
81 | func (suite *LIFOProcessorTest) TestPullRespectsTimeouts() {
82 | cnx := suite.Pool.Get()
83 | defer cnx.Close()
84 |
85 | b, err := queue.LIFO.Pull(cnx, "keyspace", 250*time.Millisecond)
86 |
87 | suite.Assert().Empty(b)
88 | suite.Assert().Equal(redis.ErrNil, err)
89 | }
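90 |
91 | // Editorial note: unlike its FIFO counterpart, LIFO.PullTo never blocks; on
92 | // an empty source it returns redis.ErrNil immediately, whatever the timeout:
93 | //
94 | //	payload, err := queue.LIFO.PullTo(cnx, "empty", "dest", time.Minute)
95 | //	// payload == nil, err == redis.ErrNil, returned without waiting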
96 | -------------------------------------------------------------------------------- /queue/processor.go: --------------------------------------------------------------------------------
1 | package queue
2 |
3 | import (
4 | "math"
5 | "time"
6 |
7 | "github.com/garyburd/redigo/redis"
8 | )
9 |
10 | // Processor is an interface to a type encapsulating the interaction between a
11 | // queue.ByteQueue and a data structure in Redis.
12 | type Processor interface {
13 | // Push pushes a given `payload` into the keyspace at `src` over the
14 | // given `redis.Conn`. This function should block until the item can
15 | // successfully be confirmed to have been pushed.
16 | Push(conn redis.Conn, src string, payload []byte) (err error)
17 |
18 | // Pull pulls a given `payload` from the keyspace at `src` over the
19 | // given `redis.Conn`. This function should block until the given
20 | // timeout has elapsed, or an item is available. If the timeout has
21 | // passed, a redis.ErrNil will be returned.
22 | Pull(conn redis.Conn, src string, timeout time.Duration) (payload []byte, err error)
23 |
24 | // PullTo transfers a given payload from the source (src) keyspace to
25 | // the destination (dest) keyspace and returns the moved item in the
26 | // payload space. If an error was encountered, then it will be returned
27 | // immediately. Timeout semantics are identical to those on Pull, unless
28 | // noted otherwise in implementation.
29 | PullTo(conn redis.Conn, src, dest string, timeout time.Duration) (payload []byte, err error)
30 |
31 | // Concat moves a single element from the src queue to the destination
32 | // queue per call. It should return a redis.ErrNil when the source queue is empty.
33 | Concat(conn redis.Conn, src, dest string) (err error)
34 | }
35 |
36 | func block(timeout time.Duration) float64 {
37 | return math.Ceil(timeout.Seconds())
38 | }
39 | -------------------------------------------------------------------------------- /queue/processor_test.go: --------------------------------------------------------------------------------
1 | package queue_test
2 |
3 | import (
4 | "time"
5 |
6 | "github.com/garyburd/redigo/redis"
7 | "github.com/stretchr/testify/mock"
8 | )
9 |
10 | type MockProcessor struct {
11 | mock.Mock
12 | }
13 |
14 | func (m *MockProcessor) Push(cnx redis.Conn, src string, payload []byte) error {
15 | args := m.Called(cnx, src, payload)
16 | return args.Error(0)
17 | }
18 |
19 | func (m *MockProcessor) Pull(cnx redis.Conn, src string,
20 | timeout time.Duration) ([]byte, error) {
21 |
22 | args := m.Called(cnx, src, timeout)
23 | return args.Get(0).([]byte), args.Error(1)
24 | }
25 |
26 | func (m *MockProcessor) PullTo(cnx redis.Conn, src, dest string,
27 | timeout time.Duration) ([]byte, error) {
28 |
29 | args := m.Called(cnx, src, dest, timeout)
30 | return args.Get(0).([]byte), args.Error(1)
31 | }
32 |
33 | func (m *MockProcessor) Concat(cnx redis.Conn, src, dest string) error {
34 | args := m.Called(cnx, src, dest)
35 | return args.Error(0)
36 | }
37 | -------------------------------------------------------------------------------- /queue/queue.go: --------------------------------------------------------------------------------
1 | package queue
2 |
3 | import "time"
4 |
5 | type Queue interface {
6 | // Source returns the keyspace in Redis from which this queue is
7 | // populated.
8 | Source() string
9 |
10 | // Push pushes the given payload (a byte slice) into the specified
11 | // keyspace by delegating into the `Processor`'s `func Push`. It obtains
12 | // a connection to Redis using the pool, which is passed into the
13 | // Processor, and recycles that connection after the function has
14 | // returned.
15 | //
16 | // If an error occurs during Pushing, it will be returned, and it can be
17 | // assumed that the payload is not in Redis.
18 | Push(payload []byte) (err error)
19 |
20 | // Pull returns the next available payload, blocking until data can be
21 | // returned.
22 | Pull(timeout time.Duration) (payload []byte, err error)
23 |
24 | // Concat takes all elements from the source queue and adds them to this
25 | // one. This can be a long-running operation. If a persistent error is
26 | // returned while moving items, then Concat will stop, though the concat
27 | // operation can be safely resumed at any time.
28 | //
29 | // Returns the number of items successfully moved and any error that
30 | // occurred.
31 | Concat(src string) (moved int, err error)
32 |
33 | // Processor returns the processor that is being used to push and pull.
34 | // If no processor is specified, a first-in-first-out processor is
35 | // returned by default.
36 | Processor() Processor
37 |
38 | // SetProcessor sets the current processor to the specified processor by
39 | // acquiring a write lock on the mutex guarding that field. The
40 | // processor will be switched over during the next iteration of a
41 | // Pull-cycle, or a call to Push.
42 | SetProcessor(processor Processor)
43 | }
44 | -------------------------------------------------------------------------------- /queue/scripts.go: --------------------------------------------------------------------------------
1 | package queue
2 |
3 | import "github.com/garyburd/redigo/redis"
4 |
5 | var LPOPRPUSH = redis.NewScript(2, `
6 | local v = redis.call('lpop', KEYS[1])
7 | if v == nil or v == false then
8 | return nil
9 | end
10 |
11 | redis.call('rpush', KEYS[2], v)
12 | return v
13 | `)
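14 |
15 | // Editorial note: Redis ships BRPOPLPUSH but no left-pop/right-push
16 | // equivalent, so this Lua script performs the two steps atomically on the
17 | // server. Usage sketch, given a live redis.Conn cnx:
18 | //
19 | //	data, err := redis.Bytes(LPOPRPUSH.Do(cnx, "src", "dest"))
20 | //	// err == redis.ErrNil when "src" is empty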
21 | -------------------------------------------------------------------------------- /queue/util.go: --------------------------------------------------------------------------------
1 | package queue
2 |
3 | import "github.com/garyburd/redigo/redis"
4 |
5 | // Concat implementation using RPOPLPUSH, compatible
6 | // with the behaviour of Processor.Concat.
7 | func rlConcat(cnx redis.Conn, src, dest string) error {
8 | data, err := cnx.Do("RPOPLPUSH", src, dest)
9 | if err != nil {
10 | return err
11 | }
12 | if data == nil {
13 | return redis.ErrNil
14 | }
15 |
16 | return nil
17 | }
18 | -------------------------------------------------------------------------------- /readme.md: --------------------------------------------------------------------------------
1 | # redutil [![Build Status](https://travis-ci.org/mixer/redutil.svg?branch=master)](https://travis-ci.org/mixer/redutil) [![Coverage Status](https://coveralls.io/repos/MCProHosting/redutil/badge.svg?branch=master)](https://coveralls.io/r/MCProHosting/redutil?branch=master) [![godoc reference](https://godoc.org/github.com/mixer/redutil?status.png)](https://godoc.org/github.com/mixer/redutil)
2 |
3 |
4 | This package consists of several utilities to make Redis easier and more consistent in Go.
5 |
6 | ## pubsub
7 |
8 | Traditional Redis libraries allow you to subscribe to events, and maybe even pool connections. But there's often no mechanism for maintaining subscribed state in the event of a connection failure, and many packages aren't thread-safe. This package, `redutil/pubsub`, solves these issues.
9 |
10 | It is fully thread safe and unit tested. We're currently using it in production, though it has not yet been entirely battle-tested. Feel free to open issues on this repository.
11 |
12 | ```go
13 | package main
14 |
15 | import (
16 | "time"
17 | "gopkg.in/mixer/redutil.v2/conn"
18 | "gopkg.in/mixer/redutil.v2/pubsub"
19 | )
20 |
21 | func main() {
22 | // Create a new pubsub client. This will create and manage connections,
23 | // even if you disconnect.
24 | c := pubsub.New(conn.New(conn.ConnectionParam{
25 | Address: "127.0.0.1:6379",
26 | // optional password
27 | Password: "secret",
28 | }, 1))
29 | go c.Connect()
30 | defer c.TearDown()
31 |
32 | go listenChannel(c)
33 | go listenPattern(c)
34 |
35 | // Wait forever!
36 | select {}
37 | }
38 |
39 |
40 | // Simple example function that listens for all events broadcast
41 | // in the channel "chan".
42 | func listenChannel(c *pubsub.Client) {
43 | listener := c.Listen(pubsub.Channel, "chan")
44 | defer listener.Unsubscribe()
45 | for message := range listener.Messages {
46 | doStuff(message)
47 | }
48 | }
49 |
50 | // Example that listens for events that match the pattern
51 | // "foo:*:bar". Note that we listen to the `PMessages` channel, not `Messages`.
52 | func listenPattern(c *pubsub.Client) {
53 | listener := c.Listen(pubsub.Pattern, "foo:*:bar")
54 | defer listener.Unsubscribe()
55 |
56 | for message := range listener.PMessages {
57 | _ = message // You got mail!
58 | }
59 | }
60 | ```
61 |
62 | ## License
63 |
64 | Copyright 2015-2016 by Beam LLC. Distributed under the MIT license.
65 | -------------------------------------------------------------------------------- /test/redis_suite.go: --------------------------------------------------------------------------------
1 | package test
2 |
3 | import (
4 | "github.com/garyburd/redigo/redis"
5 | "github.com/stretchr/testify/suite"
6 | )
7 |
8 | // RedisSuite is a type to be used during testing that wraps testify's
9 | // `suite.Suite` type and gives a *redis.Pool for us to work with as well.
10 | type RedisSuite struct {
11 | // Pool is the pool that Redis connections should be pulled from during
12 | // test.
13 | Pool *redis.Pool
14 |
15 | suite.Suite
16 | }
17 |
18 | // NewSuite constructs a suite with the given pool.
19 | func NewSuite(pool *redis.Pool) *RedisSuite {
20 | return &RedisSuite{Pool: pool}
21 | }
22 |
23 | // SetupTest implements the SetupTest function and entirely clears Redis of
24 | // items before each test run to prevent order-related test issues.
25 | func (s *RedisSuite) SetupTest() {
26 | s.WithRedis(func(cnx redis.Conn) {
27 | cnx.Do("FLUSHALL")
28 | })
29 | }
30 |
31 | // MustDo executes the command on a new Redis connection and panics if there's
32 | // an error executing it.
33 | func (s *RedisSuite) MustDo(cmd string, args ...interface{}) (reply interface{}) {
34 | cnx := s.Pool.Get()
35 | defer cnx.Close()
36 |
37 | reply, err := cnx.Do(cmd, args...)
38 | if err != nil {
39 | panic(err)
40 | }
41 |
42 | return reply
43 | }
44 |
45 | // WithRedis runs a function and passes it a valid redis.Conn instance. It does
46 | // so by obtaining the redis.Conn instance from the owned *redis.Pool and then
47 | // closing it once the outer function has returned.
48 | func (s *RedisSuite) WithRedis(fn func(redis.Conn)) {
49 | cnx := s.Pool.Get()
50 | defer cnx.Close()
51 |
52 | fn(cnx)
53 | }
54 | -------------------------------------------------------------------------------- /worker/default_lifecycle.go: --------------------------------------------------------------------------------
1 | package worker
2 |
3 | import (
4 | "bytes"
5 | "errors"
6 | "sync"
7 | "time"
8 |
9 | "github.com/garyburd/redigo/redis"
10 | "github.com/mixer/redutil/queue"
11 | )
12 |
13 | // ErrNotFound is returned if you attempt to mark a task as complete or abandoned
14 | // that wasn't registered in the lifecycle.
15 | var ErrNotFound = errors.New("Attempted to complete a task that we aren't working on.")
16 |
17 | // deleteToken is the value deleted items are set to pending removal: 16
18 | // random bytes. This is used since removal of items in Redis is necessarily
19 | // a two-step operation; we can only delete items by value, not by index in
20 | // the queue, but we *can* set items by their index.
21 | //
22 | // Hopefully Redis, at some point, implements removal by index, allowing
23 | // us to do away with this ugliness.
24 | var deleteToken = []byte{0x16, 0xea, 0x58, 0x1f, 0xbd, 0x4a, 0x23, 0xc2,
25 | 0x66, 0x97, 0x8a, 0x35, 0xb7, 0xd0, 0x22, 0xef}
26 |
27 | // DefaultLifecycle provides a default implementation of the Lifecycle
28 | // interface. It moves tasks from a source, which provides available tasks, into
29 | // a specific worker queue, which is the list of items that this worker is
30 | // currently working on.
31 | //
32 | // Completed tasks leave the individualized worker queue, while abandoned tasks
33 | // move back to the task source.
34 | type DefaultLifecycle struct {
35 | // pool is a *redis.Pool used to maintain and use connections into
36 | // Redis.
37 | pool *redis.Pool
38 |
39 | // availableTasks is a Redis queue that contains an in-order list of
40 | // tasks that need to be worked on. Workers race into this list.
41 | availableTasks queue.Queue
42 | // workingTasks contains the list of tasks that this particular worker
43 | // is currently working on. See above semantics as to where these items
44 | // move to and from.
45 | workingTasks *queue.DurableQueue
46 |
47 | // rmu guards registry
48 | rmu sync.Mutex
49 | // registry is a local copy of the workingTasks queue so we can easily
50 | // delete items by index.
51 | registry []*Task
52 |
53 | // tasks is a channel of tasks that this worker has taken ownership of
54 | // and needs to work on.
55 | tasks chan *Task
56 | // errs is a channel of errors that gets written to when an error is
57 | // encountered by the recv() function.
58 | errs chan error
59 |
60 | // closer is a channel that receives an empty message when it is time to
61 | // close.
62 | closer chan struct{}
63 |
64 | // wg is a WaitGroup that keeps track of all actively owned tasks. When
65 | // a task is pulled, the count of this WaitGroup increases, and it
66 | // conversely decreases when the task is either completed or abandoned.
67 | wg sync.WaitGroup
68 | }
69 |
70 | // NewLifecycle allocates and returns a pointer to a new instance of
71 | // DefaultLifecycle. It uses the specified pool to make connections into
72 | // Redis. The queue of available tasks and the working tasks queue, which
73 | // stores the items the lifecycle is currently processing, are supplied
74 | // afterwards via SetQueues.
75 | func NewLifecycle(pool *redis.Pool) *DefaultLifecycle {
76 | return &DefaultLifecycle{pool: pool}
77 | }
78 |
79 | var _ Lifecycle = new(DefaultLifecycle)
80 |
81 | // Await implements Lifecycle's Await func, blocking until all tasks are either
82 | // completed or abandoned. If Await is called in conjunction with AbandonAll, it
83 | // will wait until all tasks have been successfully abandoned before returning.
84 | func (l *DefaultLifecycle) Await() {
85 | l.wg.Wait()
86 | }
87 |
88 | func (l *DefaultLifecycle) SetQueues(availableTasks queue.Queue,
89 | workingTasks *queue.DurableQueue) {
90 | l.availableTasks = availableTasks
91 | l.workingTasks = workingTasks
92 | }
93 |
94 | // Listen returns a channel of tasks and a channel of errors, pulled from the
95 | // main processing queue. Once a *Task is able to be read from the <-chan *Task,
96 | // that Task is ready to be worked on and is in the appropriate locations in
97 | // Redis. StopListening() can be called to terminate.
98 | func (l *DefaultLifecycle) Listen() (<-chan *Task, <-chan error) {
99 | l.tasks = make(chan *Task)
100 | l.errs = make(chan error)
101 | l.closer = make(chan struct{})
102 |
103 | go l.recv()
104 |
105 | return l.tasks, l.errs
106 | }
107 |
108 | // StopListening closes the queue "pull" task at the next opportunity. Tasks
109 | // can still be marked as completed or abandoned, but no new tasks will
110 | // be generated.
111 | func (l *DefaultLifecycle) StopListening() {
112 | l.closer <- struct{}{}
113 | }
114 |
115 | // Complete marks a task as having been completed, removing it from the
116 | // worker's queue.
117 | func (l *DefaultLifecycle) Complete(task *Task) error {
118 | return l.removeTask(task)
119 | }
120 |
121 | // Abandon marks a task as having failed, pushing it back onto the primary
122 | // task queue and removing it from our worker queue.
123 | func (l *DefaultLifecycle) Abandon(task *Task) error {
124 | if err := l.availableTasks.Push(task.Bytes()); err != nil {
125 | return err
126 | }
127 |
128 | return l.removeTask(task)
129 | }
130 |
131 | // AbandonAll marks *all* tasks in the queue as having been abandoned. Called
132 | // by the worker in the Halt() method.
133 | func (l *DefaultLifecycle) AbandonAll() error {
134 | l.rmu.Lock()
135 | defer l.rmu.Unlock()
136 |
137 | moved, err := l.availableTasks.Concat(l.workingTasks.Dest())
138 | remaining := len(l.registry) - moved
139 |
140 | l.registry = l.registry[0:remaining]
141 | l.wg.Add(-1 * moved)
142 |
143 | return err
144 | }
145 |
146 | // recv pulls items from the queue and publishes them into the <-chan *Task. If
147 | // a message is sent over the `closer` channel, then this process will be
148 | // stopped.
149 | func (l *DefaultLifecycle) recv() {
150 | for {
151 | select {
152 | case <-l.closer:
153 | close(l.errs)
154 | return
155 | default:
156 | payload, err := l.workingTasks.Pull(time.Second)
157 | if err == redis.ErrNil {
158 | continue
159 | }
160 |
161 | if err != nil {
162 | l.errs <- err
163 | continue
164 | }
165 |
166 | if payload == nil || bytes.Equal(deleteToken, payload) {
167 | continue
168 | }
169 |
170 | task := NewTask(l, payload)
171 | l.addTask(task)
172 | l.tasks <- task
173 | }
174 | }
175 | }
176 |
177 | // addTask inserts a newly created task into the internal tasks registry.
178 | func (l *DefaultLifecycle) addTask(t *Task) {
179 | l.rmu.Lock()
180 | defer l.rmu.Unlock()
181 |
182 | l.registry = append([]*Task{t}, l.registry...)
183 | l.wg.Add(1)
184 | }
185 |
186 | // removeTask removes a task from the worker's task queue.
187 | func (l *DefaultLifecycle) removeTask(task *Task) (err error) {
188 | l.rmu.Lock()
189 | defer l.rmu.Unlock()
190 | cnx := l.pool.Get()
191 | defer cnx.Close()
192 |
193 | i := l.findTaskIndex(task)
194 | if i == -1 {
195 | return ErrNotFound
196 | }
197 |
198 | count := len(l.registry)
199 | l.registry = append(l.registry[:i], l.registry[i+1:]...)
200 |
201 | // We set the item relative to the end position of the list. Since the
202 | // queue is running BRPOPLPUSH, the index relative to the start of the
203 | // list (left side) might change in the meantime.
204 | _, err = cnx.Do("LSET", l.workingTasks.Dest(), i-count, deleteToken)
205 | if err != nil {
206 | return
207 | }
208 |
209 | // Ignore errors from the LREM. If it fails, it's unfortunate, but the
210 | // task was still removed successfully. A later LREM will clear the
211 | // token or, if not, we'll just skip it when we read it from the queue.
212 | cnx.Do("LREM", l.workingTasks.Dest(), 0, deleteToken)
213 |
214 | l.wg.Done()
215 | return nil
216 | }
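217 |
218 | // Editorial sketch of the two-step delete performed above, for a worker
219 | // list ["a", "b", "c"] from which "b" (index 1 of 3) must be removed:
220 | //
221 | //	LSET worker_queue -2 <deleteToken> // 1 - 3 == -2, counted from the right
222 | //	LREM worker_queue  0 <deleteToken> // then purge the token wherever it sits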
223 | // findTaskIndex returns the index of the task in the tasks list, or -1 if the
224 | // task was not in the list.
225 | func (l *DefaultLifecycle) findTaskIndex(task *Task) int {
226 | for i, t := range l.registry {
227 | if t == task {
228 | return i
229 | }
230 | }
231 |
232 | return -1
233 | }
234 | -------------------------------------------------------------------------------- /worker/default_lifecycle_test.go: --------------------------------------------------------------------------------
1 | package worker_test
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/garyburd/redigo/redis"
7 | "github.com/mixer/redutil/conn"
8 | "github.com/mixer/redutil/queue"
9 | "github.com/mixer/redutil/test"
10 | "github.com/mixer/redutil/worker"
11 | "github.com/stretchr/testify/suite"
12 | )
13 |
14 | type DefaultLifecycleSuite struct {
15 | *test.RedisSuite
16 | }
17 |
18 | func TestDefaultLifecycleSuite(t *testing.T) {
19 | pool, _ := conn.New(conn.ConnectionParam{
20 | Address: "127.0.0.1:6379",
21 | }, 1)
22 |
23 | suite.Run(t, &DefaultLifecycleSuite{test.NewSuite(pool)})
24 | }
25 |
26 | func (d *DefaultLifecycleSuite) makeLifecycle(src, working string) worker.Lifecycle {
27 | l := worker.NewLifecycle(d.Pool)
28 |
29 | l.SetQueues(
30 | queue.NewByteQueue(d.Pool, src),
31 | queue.NewDurableQueue(d.Pool, src, working),
32 | )
33 |
34 | return l
35 | }
36 |
37 | func (suite *DefaultLifecycleSuite) TestConstruction() {
38 | l := suite.makeLifecycle("queue", "worker_1")
39 | suite.Assert().IsType(&worker.DefaultLifecycle{}, l)
40 | }
41 |
42 | func (suite *DefaultLifecycleSuite) TestListenReturnsTasks() {
43 | l := suite.makeLifecycle("queue", "worker_1")
44 | queue := queue.NewByteQueue(suite.Pool, "queue")
45 | tasks, _ := l.Listen()
46 | defer func() {
47 | l.AbandonAll()
48 | }()
49 |
50 | queue.Push([]byte("hello, world!"))
51 |
52 | task := <-tasks
53 | l.StopListening()
54 |
55 | suite.Assert().Equal([]byte("hello, world!"), task.Bytes())
56 | }
57 |
58 | func (suite *DefaultLifecycleSuite) TestCompletedTasksRemovedFromAllQueues() {
59 | l := suite.makeLifecycle("queue", "worker_1")
60 | defer l.StopListening()
61 |
62 | queue := queue.NewByteQueue(suite.Pool, "queue")
63 | queue.Push([]byte("some_task"))
64 |
65 | tasks, _ := l.Listen()
66 |
67 | l.Complete(<-tasks)
68 |
69 | suite.WithRedis(func(_ redis.Conn) {
70 | ql := suite.RedisLength("queue")
71 | wl := suite.RedisLength("worker_1")
72 |
73 | suite.Assert().Equal(0, ql, "redutil: main queue should be empty, but wasn't")
74 | suite.Assert().Equal(0, wl, "redutil: worker (worker_1) queue should be empty, but wasn't")
75 | })
76 | }
77 |
78 | func (suite *DefaultLifecycleSuite) TestAbandonedTasksRemovedFromWorkerQueue() {
79 | cnx := suite.Pool.Get()
80 | defer cnx.Close()
81 |
82 | l := suite.makeLifecycle("queue", "worker_1")
83 |
84 | queue := queue.NewByteQueue(suite.Pool, "queue")
85 | queue.Push([]byte("some_task"))
86 |
87 | tasks, _ := l.Listen()
88 | task := <-tasks
89 | l.StopListening()
90 |
91 | suite.Assert().Equal(1, suite.RedisLength("worker_1"),
92 | "redutil: worker (worker_1) queue should have one item, but doesn't")
93 | suite.Assert().Equal(0, suite.RedisLength("queue"),
94 | "redutil: main queue should be empty, but wasn't")
95 |
96 | l.Abandon(task)
97 |
98 | suite.Assert().Equal(0, suite.RedisLength("worker_1"),
99 | "redutil: worker (worker_1) should be empty, but isn't")
100 | suite.Assert().Equal(1, suite.RedisLength("queue"),
101 | "redutil: main queue should have one item, but doesn't")
102 | }
103 |
104 | func (suite *DefaultLifecycleSuite) TestAbandonAllMovesAllTasksToMainQueue() {
105 | l := suite.makeLifecycle("queue", "worker_1")
106 |
107 | queue := queue.NewByteQueue(suite.Pool, "queue")
108 | queue.Push([]byte("task_1"))
109 | queue.Push([]byte("task_2"))
110 | queue.Push([]byte("task_3"))
111 |
112 | tasks, _ := l.Listen()
113 | for i := 0; i < 3; i++ {
114 | <-tasks
115 | }
116 | l.StopListening()
117 |
118 | suite.Assert().Equal(3, suite.RedisLength("worker_1"))
119 | suite.Assert().Equal(0, suite.RedisLength("queue"))
120 |
121 | l.AbandonAll()
122 |
123 | suite.Assert().Equal(0, suite.RedisLength("worker_1"))
124 | suite.Assert().Equal(3, suite.RedisLength("queue"))
125 | }
126 |
127 | func (suite *DefaultLifecycleSuite) RedisLength(keyspace string) int {
128 | cnx := suite.Pool.Get()
129 | defer cnx.Close()
130 |
131 | length, err := redis.Int(cnx.Do("LLEN", keyspace))
132 | if err != nil {
133 | return -1
134 | }
135 |
136 | return length
137 | }
138 | -------------------------------------------------------------------------------- /worker/default_worker.go: --------------------------------------------------------------------------------
1 | package worker
2 |
3 | import (
4 | "fmt"
5 | "sync"
6 | "time"
7 |
8 | "github.com/garyburd/redigo/redis"
9 | "github.com/mixer/redutil/heartbeat"
10 | "github.com/mixer/redutil/queue"
11 | )
12 |
13 | // Internal state tracking used in the worker.
14 | type state uint8
15 |
16 | const (
17 | idle state = iota // we've not yet started
18 | open // tasks are running or can be run on the worker
19 | halting // we're terminating all ongoing tasks
20 | closing // we're waiting for tasks to gracefully close
21 | closed // all tasks have terminated
22 | )
23 |
24 | // A DefaultWorker is the bridge between Redutil Queues and the worker pattern. Items
25 | // can be moved around between different queues using a lifecycle (see
26 | // Lifecycle, DefaultLifecycle), and worked on by clients. "Dead" workers' items
27 | // are recovered by other, living ones, providing an in-order, reliable,
28 | // distributed implementation of the worker pattern.
29 | type DefaultWorker struct {
30 | pool *redis.Pool
31 |
32 | // availableTasks is a Redis queue that contains an in-order list of
33 | // tasks that need to be worked on. Workers race into this list.
34 | availableTasks queue.Queue
35 | // workingTasks contains the list of tasks that this particular worker
36 | // is currently working on. See above semantics as to where these items
37 | // move to and from.
38 | workingTasks *queue.DurableQueue
39 |
40 | lifecycle Lifecycle
41 |
42 | // The heartbeat components are used to maintain the state of the worker
43 | // and to detect dead workers to clean up.
44 | detector heartbeat.Detector
45 | heart heartbeat.Heart
46 |
47 | // The janitor is responsible for cleaning up dead workers.
48 | janitor Janitor
49 | janitorRunner *janitorRunner
50 |
51 | // smu wraps Worker#state in a loving, mutex-y embrace.
52 | smu sync.Mutex
53 | // The open or closed state of the worker. Locked by the cond.
54 | state state
55 | }
56 |
57 | const (
58 | // Default interval passed into heartbeat.New
59 | defaultHeartInterval = 10 * time.Second
60 | // Default interval we use to check for dead workers. Note that the first
61 | // check will be anywhere in the range [0, monitor interval]; this is
62 | // randomized so that workers that start at the same time will not
63 | // contest the same locks.
64 | defaultMonitorInterval = 15 * time.Second
65 | )
66 |
67 | // Returns the name of the working queue based on the worker's processing
68 | // source and worker ID. This is purposely NOT readily configurable; this
69 | // is not something you have to touch for 99% of redutil usage, and
70 | // incorrectly configuring this can result in Bad Things (dropped
71 | // jobs, duplicate jobs, etc).
72 | func getWorkingQueueName(src, id string) string {
73 | return fmt.Sprintf("%s:worker_%s", src, id)
74 | }
75 |
76 | // New creates and returns a pointer to a new instance of a DefaultWorker. It uses the
77 | // given redis.Pool, the main queue to pull from (`src`), and is given a
78 | // unique ID through the `id` parameter.
79 | func New(pool *redis.Pool, src, id string) *DefaultWorker {
80 | heartbeater := heartbeat.New(
81 | id,
82 | fmt.Sprintf("%s:%s", src, "ticks"),
83 | defaultHeartInterval, pool)
84 |
85 | return &DefaultWorker{
86 | pool: pool,
87 | availableTasks: queue.NewByteQueue(pool, src),
88 | workingTasks: queue.NewDurableQueue(pool, src, getWorkingQueueName(src, id)),
89 |
90 | lifecycle: NewLifecycle(pool),
91 | detector: heartbeater.Detector(),
92 | heart: heartbeater.Heart(),
93 | janitor: nilJanitor{},
94 |
95 | state: closed,
96 | }
97 | }
98 |
99 | // Sets the Lifecycle used for managing job states. Note: this is only safe
100 | // to call BEFORE calling Start().
101 | func (w *DefaultWorker) SetLifecycle(lf Lifecycle) {
102 | w.ensureUnstarted()
103 | w.lifecycle = lf
104 | }
105 |
106 | // Sets the Janitor interface used to dispose of old workers. This is optional;
107 | // if you do not need to hook in extra functionality, you don't need to
108 | // provide a janitor.
109 | func (w *DefaultWorker) SetJanitor(janitor Janitor) {
110 | w.ensureUnstarted()
111 | w.janitor = janitor
112 | }
113 |
114 | func (w *DefaultWorker) ensureUnstarted() {
115 | w.smu.Lock()
116 | defer w.smu.Unlock()
117 |
118 | if w.state == open {
119 | panic("Attempted to alter the worker while it was running.")
120 | }
121 | }
122 |
123 | // Start signals the worker to begin receiving tasks from the main queue.
124 | func (w *DefaultWorker) Start() (<-chan *Task, <-chan error) {
125 | w.smu.Lock()
126 | defer w.smu.Unlock()
127 |
128 | w.state = open
129 | w.lifecycle.SetQueues(w.availableTasks, w.workingTasks)
130 | w.janitorRunner = newJanitorRunner(w.pool, w.detector, w.janitor, w.availableTasks)
131 |
132 | errs1 := w.janitorRunner.Start()
133 | tasks, errs2 := w.lifecycle.Listen()
134 |
135 | return tasks, concatErrs(errs1, errs2)
136 | }
137 |
138 | // Close stops polling the queue immediately and waits for all tasks to complete
139 | // before stopping the heartbeat.
140 | func (w *DefaultWorker) Close() {
141 | w.startClosing(func() {
142 | w.state = closing
143 | })
144 | }
145 |
146 | // Halt stops the heartbeat and queue polling goroutines immediately and cancels
147 | // all tasks, marking them as FAILED before returning.
148 | func (w *DefaultWorker) Halt() {
149 | w.startClosing(func() {
150 | w.state = halting
151 | w.lifecycle.AbandonAll()
152 | })
153 | }
154 |
155 | // Starts closing the worker if it was not already closed. Invokes the passed
156 | // function to help in the teardown, and blocks until all tasks are done.
157 | func (w *DefaultWorker) startClosing(fn func()) {
158 | w.smu.Lock()
159 | defer w.smu.Unlock()
160 | if w.state != open {
161 | return
162 | }
163 |
164 | w.lifecycle.StopListening()
165 | w.heart.Close()
166 | w.janitorRunner.Close()
167 |
168 | fn()
169 |
170 | w.lifecycle.Await()
171 |
172 | w.state = closed
173 | }
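174 |
175 | // Editorial sketch of end-to-end usage (assumptions: a *redis.Pool named
176 | // pool, a hypothetical process() func, and a worker ID unique per process):
177 | //
178 | //	w := New(pool, "tasks", "worker_1")
179 | //	tasks, errs := w.Start()
180 | //	go func() {
181 | //		for task := range tasks {
182 | //			if process(task.Bytes()) == nil {
183 | //				task.Succeed()
184 | //			} else {
185 | //				task.Fail() // back onto the main queue for a retry
186 | //			}
187 | //		}
188 | //	}()
189 | //	go func() { for err := range errs { log.Println(err) } }()
190 | //	// later: w.Close() to drain gracefully, or w.Halt() to abandon all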
191 | -------------------------------------------------------------------------------- /worker/default_worker_test.go: --------------------------------------------------------------------------------
1 | package worker_test
2 |
3 | import (
4 | "errors"
5 | "testing"
6 |
7 | "github.com/mixer/redutil/conn"
8 | "github.com/mixer/redutil/queue"
9 | "github.com/mixer/redutil/test"
10 | "github.com/mixer/redutil/worker"
11 | "github.com/stretchr/testify/mock"
12 | "github.com/stretchr/testify/suite"
13 | )
14 |
15 | type WorkerSuite struct {
16 | *test.RedisSuite
17 | }
18 |
19 | func TestWorkerSuite(t *testing.T) {
20 | pool, _ := conn.New(conn.ConnectionParam{
21 | Address: "127.0.0.1:6379",
22 | }, 1)
23 |
24 | suite.Run(t, &WorkerSuite{test.NewSuite(pool)})
25 | }
26 |
27 | func (suite *WorkerSuite) TestConstruction() {
28 | w := worker.New(suite.Pool, "queue", "worker_1")
29 | defer w.Close()
30 |
31 | suite.Assert().IsType(&worker.DefaultWorker{}, w)
32 | }
33 |
34 | func (suite *WorkerSuite) TestStartPropagatesProcessor() {
35 | c1, c2 := make(chan *worker.Task), make(chan error)
36 | defer func() {
37 | close(c1)
38 | close(c2)
39 | }()
40 |
41 | l := &MockLifecycle{}
42 | task := worker.NewTask(l, []byte("payload"))
43 |
44 | go func() {
45 | c1 <- task
46 | c2 <- errors.New("error")
47 | }()
48 |
49 | l.On("SetQueues", mock.Anything, mock.Anything).Return()
50 | l.On("Listen").Return(c1, c2).Once()
51 |
52 | w := worker.New(suite.Pool, "queue", "worker_1")
53 | w.SetLifecycle(l)
54 |
55 | t, e := w.Start()
56 |
57 | suite.Assert().Equal(task, <-t)
58 | suite.Assert().Equal("error", (<-e).Error())
59 | }
60 |
61 | func (suite *WorkerSuite) TestCloseWaitsForCompletion() {
62 | c1, c2 := make(chan *worker.Task), make(chan error)
63 | defer func() {
64 | close(c1)
65 | close(c2)
66 | }()
67 |
68 | l := &MockLifecycle{}
69 | l.On("Await").Return()
70 | l.On("Listen").Return(c1, c2)
71 | l.On("StopListening").Return()
72 | l.On("SetQueues", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
73 | src := args.Get(0).(queue.Queue)
74 | suite.Assert().Equal("queue", src.Source())
75 | dest := args.Get(1).(*queue.DurableQueue)
76 | suite.Assert().Equal("queue", dest.Source())
77 | suite.Assert().Equal("queue:worker_worker_1", dest.Dest())
78 | })
79 |
80 | w := worker.New(suite.Pool, "queue", "worker_1")
81 | w.SetLifecycle(l)
82 |
83 | w.Start()
84 | w.Close()
85 |
86 | l.AssertExpectations(suite.T())
87 | }
88 |
89 | func (suite *WorkerSuite) TestHaltDoesNotWaitForCompletion() {
90 | c1, c2 := make(chan *worker.Task), make(chan error)
91 | defer func() {
92 | close(c1)
93 | close(c2)
94 | }()
95 |
96 | l := &MockLifecycle{}
97 | l.On("Await").Return()
98 | l.On("AbandonAll").Return(nil).Once()
99 | l.On("Listen").Return(c1, c2)
100 | l.On("StopListening").Return()
101 | l.On("SetQueues", mock.Anything, mock.Anything).Return()
102 |
103 | w := worker.New(suite.Pool, "queue", "worker_1")
104 | w.SetLifecycle(l)
105 |
106 | w.Start()
107 | w.Halt()
108 |
109 | l.AssertCalled(suite.T(), "Await")
110 | l.AssertCalled(suite.T(), "AbandonAll")
111 | }
112 | -------------------------------------------------------------------------------- /worker/janitor.go: --------------------------------------------------------------------------------
1 | package worker
2 |
3 | import (
4 | "math/rand"
5 | "sync"
6 | "time"
7 |
8 | "github.com/benbjohnson/clock"
9 | "github.com/garyburd/redigo/redis"
10 | "github.com/hjr265/redsync.go/redsync"
11 | "github.com/mixer/redutil/heartbeat"
12 | "github.com/mixer/redutil/queue"
13 | )
14 |
15 | // The Janitor is used to assist in the tear down of dead workers. It can
16 | // be provided to the worker to hook additional functionality that will
17 | // occur when a dead worker is cleaned up.
18 | type Janitor interface {
19 | // Called when a worker dies after we have acquired a lock and before
20 | // we start moving the worker's queue back to the main processing
21 | // queue. Note that if we die before the worker's queue is moved
22 | // over, this function *can be called multiple times on the
23 | // same worker*.
24 | //
25 | // If an error is returned from the function, the queue concatenation
26 | // will be aborted and we'll release the lock.
27 | OnPreConcat(cnx redis.Conn, worker string) error
28 |
29 | // Called when a worker dies after we have acquired a lock and finished
30 | // moving the worker's queue back to the main processing queue. Note
31 | // that, in the event of a panic or power failure, this function
32 | // may never be called, and errors resulting from this function
33 | // will not roll back the concatenation.
34 | OnPostConcat(cnx redis.Conn, worker string) error
35 | }
36 |
37 | // Base janitor used unless the user provides a replacement.
38 | type nilJanitor struct{}
39 |
40 | var _ Janitor = nilJanitor{}
41 |
42 | func (n nilJanitor) OnPreConcat(cnx redis.Conn, worker string) error { return nil }
43 | func (n nilJanitor) OnPostConcat(cnx redis.Conn, worker string) error { return nil }
44 |
45 | // The janitor is responsible for cleaning up dead workers.
46 | type janitorRunner struct {
47 | // pool is a *redis.Pool used to maintain and use connections into
48 | // Redis.
49 | pool *redis.Pool
50 |
51 | availableTasks queue.Queue
52 |
53 | // Associated heartbeater detector
54 | detector heartbeat.Detector
55 | janitor Janitor
56 |
57 | // Duration between dead checks. The first check will come at a time
58 | // between 0 and time.Duration, so that workers started at the same time
59 | // don't try to contest the same locks.
60 | interval time.Duration
61 |
62 | // interval checker used to prune dead workers
63 | clock clock.Clock
64 |
65 | errs chan error
66 | closer chan struct{}
67 | }
68 |
69 | func newJanitorRunner(pool *redis.Pool, detector heartbeat.Detector, janitor Janitor,
70 | availableTasks queue.Queue) *janitorRunner {
71 |
72 | return &janitorRunner{
73 | pool: pool,
74 | availableTasks: availableTasks,
75 | detector: detector,
76 | janitor: janitor,
77 | clock: clock.New(),
78 | interval: defaultMonitorInterval,
79 | errs: make(chan error),
80 | closer: make(chan struct{}),
81 | }
82 | }
83 |
84 | func (j *janitorRunner) watchDead() {
85 | defer close(j.errs)
86 |
87 | // Sleep for a random interval so that janitors started at the same time
88 | // don't try to contest the same locks.
89 | select {
90 | case <-j.closer:
91 | return
92 | case <-j.clock.After(time.Duration(float64(j.interval) * rand.Float64())):
93 | }
94 |
95 | ticker := j.clock.Ticker(j.interval)
96 | defer ticker.Stop()
97 |
98 | for {
99 | select {
100 | case <-j.closer:
101 | return
102 | case <-ticker.C:
103 | j.runCleaning()
104 | }
105 | }
106 | }
107 |
108 | // Detects expired heartbeat records and starts goroutines to move each dead
109 | // worker's abandoned tasks back to the main queue.
110 | func (j *janitorRunner) runCleaning() {
111 | dead, err := j.detector.Detect()
112 | if err != nil {
113 | j.errs <- err
114 | return
115 | }
116 |
117 | var wg sync.WaitGroup
118 | wg.Add(len(dead))
119 |
120 | for _, worker := range dead {
121 | go func(worker string) {
122 | defer wg.Done()
123 |
124 | err := j.handleDeath(worker)
125 | if err != nil && err != redsync.ErrFailed {
126 | j.errs <- err
127 | }
128 | }(worker)
129 | }
130 |
131 | wg.Wait()
132 | }
133 |
134 | // Creates a mutex and attempts to acquire a redlock to dispose of the worker.
135 | func (j *janitorRunner) getLock(worker string) (*redsync.Mutex, error) {
136 | mu, err := redsync.NewMutexWithPool("redutil:lock:"+worker, []*redis.Pool{j.pool})
137 | if err != nil {
138 | return nil, err
139 | }
140 |
141 | return mu, mu.Lock()
142 | }
143 |
144 | // Processes a dead worker, moving its queue back to the main queue and
145 | // calling the janitor hooks if we get a lock on it.
146 | func (j *janitorRunner) handleDeath(worker string) error {
147 | mu, err := j.getLock(worker)
148 | if err != nil {
149 | return err
150 | }
151 | defer mu.Unlock()
152 |
153 | cnx := j.pool.Get()
154 | defer cnx.Close()
155 |
156 | if err := j.janitor.OnPreConcat(cnx, worker); err != nil {
157 | return err
158 | }
159 |
160 | _, err = j.availableTasks.Concat(
161 | getWorkingQueueName(j.availableTasks.Source(), worker))
162 | if err != nil && err != redis.ErrNil {
163 | return err
164 | }
165 | j.detector.Purge(worker)
166 |
167 | return j.janitor.OnPostConcat(cnx, worker)
168 | }
169 |
170 | func (j *janitorRunner) Close() {
171 | j.closer <- struct{}{}
172 | }
173 |
174 | func (j *janitorRunner) Start() <-chan error {
175 | j.errs = make(chan error)
176 | j.closer = make(chan struct{})
177 |
178 | go j.watchDead()
179 |
180 | return j.errs
181 | }
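182 |
183 | // Editorial sketch: a custom Janitor that logs reclaimed workers, assuming a
184 | // stdlib log import; wire it up with DefaultWorker.SetJanitor before Start():
185 | //
186 | //	type loggingJanitor struct{}
187 | //
188 | //	func (loggingJanitor) OnPreConcat(cnx redis.Conn, w string) error { return nil }
189 | //	func (loggingJanitor) OnPostConcat(cnx redis.Conn, w string) error {
190 | //		log.Printf("reclaimed dead worker %s", w)
191 | //		return nil
192 | //	}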
193 | -------------------------------------------------------------------------------- /worker/janitor_test.go: --------------------------------------------------------------------------------
1 | package worker
2 |
3 | import (
4 | "errors"
5 | "testing"
6 | "time"
7 |
8 | "github.com/garyburd/redigo/redis"
9 | "github.com/mixer/redutil/conn"
10 | "github.com/mixer/redutil/heartbeat"
11 | "github.com/mixer/redutil/queue"
12 | "github.com/mixer/redutil/test"
13 | "github.com/stretchr/testify/mock"
14 | "github.com/stretchr/testify/suite"
15 | )
16 |
17 | type JanitorSuite struct {
18 | *test.RedisSuite
19 | }
20 |
21 | func TestJanitorSuite(t *testing.T) {
22 | pool, _ := conn.New(conn.ConnectionParam{
23 | Address: "127.0.0.1:6379",
24 | }, 1)
25 |
26 | suite.Run(t, &JanitorSuite{test.NewSuite(pool)})
27 | }
28 |
29 | type mockJanitor struct {
30 | mock.Mock
31 | }
32 |
33 | func (m *mockJanitor) OnPreConcat(cnx redis.Conn, worker string) error {
34 | return m.Called(cnx, worker).Error(0)
35 | }
36 |
37 | func (m *mockJanitor) OnPostConcat(cnx redis.Conn, worker string) error {
38 | return m.Called(cnx, worker).Error(0)
39 | }
40 |
41 | func (suite *JanitorSuite) generate() (*janitorRunner, *mockJanitor) {
42 | cnx := suite.Pool.Get()
43 | defer cnx.Close()
44 |
45 | run := [][]interface{}{
46 | {"HSET", "beats", "old", time.Now().UTC().Add(-time.Hour).Format(heartbeat.DefaultTimeFormat)},
47 | {"HSET", "beats", "new", time.Now().UTC().Add(time.Hour).Format(heartbeat.DefaultTimeFormat)},
48 | {"DEL", "available"},
49 | {"DEL", "available:worker_old"},
50 | {"RPUSH", "available:worker_old", []byte{1, 2, 3}},
51 | }
52 |
53 | for _, r := range run {
54 | if _, err := cnx.Do(r[0].(string), r[1:]...); err != nil {
55 | panic(err)
56 | }
57 | }
58 |
59 | available := queue.NewByteQueue(suite.Pool, "available")
60 | janitor := new(mockJanitor)
61 |
62 | runner := newJanitorRunner(
63 | suite.Pool,
64 | heartbeat.NewDetector("beats", suite.Pool, heartbeat.HashExpireyStrategy{}),
65 | janitor,
66 | available,
67 | )
68 |
69 | return runner, janitor
70 | }
71 |
72 | func (suite *JanitorSuite) TestJanitorSweepsTheDustSuccessfully() {
73 | runner, m := suite.generate()
74 | m.On("OnPreConcat", mock.Anything, "old").Return(nil).Once()
75 | m.On("OnPostConcat", mock.Anything, "old").Return(nil).Once()
76 |
77 | go func() {
78 | for err := range runner.errs {
79 | panic(err)
80 | }
81 | }()
82 |
83 | runner.runCleaning()
84 |
85 | suite.WithRedis(func(cnx redis.Conn) {
86 | // moved the queues successfully:
87 | size, err := redis.Int(cnx.Do("LLEN", "available:worker_old"))
88 | suite.Assert().Nil(err)
89 | suite.Assert().Zero(size)
90 |
91 | data, err := redis.Bytes(cnx.Do("RPOP", "available"))
92 | suite.Assert().Nil(err)
93 | suite.Assert().Equal([]byte{1, 2, 3}, data)
94 |
95 | _, err = redis.Bytes(cnx.Do("HGET", "beats", "old"))
96 | suite.Assert().Equal(redis.ErrNil, err)
97 |
98 | _, err = redis.Bytes(cnx.Do("HGET", "beats", "new"))
99 | suite.Assert().Nil(err)
100 | })
101 |
102 | m.AssertExpectations(suite.T())
103 | }
104 |
105 | func (suite *JanitorSuite) TestAbortsIfPreConcatFails() {
106 | expectedErr := errors.New("oh no!")
107 | runner, m := suite.generate()
108 | m.On("OnPreConcat", mock.Anything, "old").Return(expectedErr).Once()
109 |
110 | go runner.runCleaning()
111 | suite.Assert().Equal(expectedErr, <-runner.errs)
112 |
113 | suite.WithRedis(func(cnx redis.Conn) {
114 | size, err := redis.Int(cnx.Do("LLEN", "available:worker_old"))
115 | suite.Assert().Nil(err)
116 | suite.Assert().Equal(1, size)
117 |
118 | data, err := redis.Bytes(cnx.Do("HGET", "beats", "old"))
119 | suite.Assert().NotNil(data)
120 | })
121 |
122 | m.AssertExpectations(suite.T())
123 | }
124 | -------------------------------------------------------------------------------- /worker/lifecycle.go: --------------------------------------------------------------------------------
1 | package worker
2 |
3 | import (
4 | "github.com/mixer/redutil/queue"
5 | )
6 |
7 | type Lifecycle interface {
8 | // Sets the queues for the lifecycle's tasks. The lifecycle should
9 | // pull jobs from the `available` queue, then add them to the
10 | // `working` queue until they're complete. When they're complete,
11 | // the jobs can be deleted from that queue.
12 | SetQueues(available queue.Queue, working *queue.DurableQueue)
13 |
14 | // Marks a task as being completed. This is called by the Task.Succeed
15 | // method; you should not use this directly.
16 | Complete(task *Task) (err error)
17 |
18 | // Abandon marks a task as having failed, pushing it back onto the
19 | // primary task queue and removing it from our worker queue. This is
20 | // called by the Task.Fail method; you should not use this directly.
21 | Abandon(task *Task) (err error)
22 |
23 | // Marks *all* tasks in the queue as having been abandoned. Called by
24 | // the worker in the Halt() method.
25 | AbandonAll() (err error)
26 |
27 | // Starts pulling from the processing queue, returning a channel of
28 | // tasks and a channel of errors. Can be halted with StopListening().
29 | Listen() (<-chan *Task, <-chan error)
30 |
31 | // Stops an ongoing listening loop.
32 | StopListening()
33 |
34 | // Await blocks until all tasks currently being worked on by the Worker
35 | // are completed. If there are no tasks being worked on, this method
36 | // will return instantly.
37 | Await()
38 | }
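39 |
40 | // Editorial note: DefaultLifecycle is the reference implementation of this
41 | // interface; a custom implementation can be swapped in before Start() via
42 | // DefaultWorker.SetLifecycle:
43 | //
44 | //	w := New(pool, "tasks", "worker_1")
45 | //	w.SetLifecycle(myLifecycle) // myLifecycle satisfies Lifecycle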
32 | func NewTask(lifecycle Lifecycle, payload []byte) *Task {
33 |     return &Task{lifecycle: lifecycle, payload: payload}
34 | }
35 | 
36 | // Bytes returns the bytes that this Task is holding, and is the "data" to be
37 | // worked on.
38 | func (t *Task) Bytes() []byte { return t.payload }
39 | 
40 | // SetBytes sets the bytes that this task is holding.
41 | func (t *Task) SetBytes(bytes []byte) { t.payload = bytes }
42 | 
43 | // Succeed signals the lifecycle that work on this task has been completed,
44 | // and removes the task from the worker queue.
45 | func (t *Task) Succeed() error {
46 |     return t.guardResolution(func() error {
47 |         return t.lifecycle.Complete(t)
48 |     })
49 | }
50 | 
51 | // Fail signals the lifecycle that work on this task has failed, causing
52 | // it to return the task to the main processing queue to be retried.
53 | func (t *Task) Fail() error {
54 |     return t.guardResolution(func() error {
55 |         return t.lifecycle.Abandon(t)
56 |     })
57 | }
58 | 
59 | // IsResolved returns true if the task has already been marked as having
60 | // succeeded or failed.
61 | func (t *Task) IsResolved() bool {
62 |     t.resolvedMu.Lock()
63 |     defer t.resolvedMu.Unlock()
64 | 
65 |     return t.resolved
66 | }
67 | 
68 | // HexDump returns a byte dump of the task, in the same format as `hexdump -C`.
69 | // This is useful for debugging/logging purposes.
70 | func (t *Task) HexDump() string {
71 |     return hex.Dump(t.Bytes())
72 | }
73 | 
74 | // String returns the stringified contents of the task payload.
75 | func (t *Task) String() string {
76 |     return string(t.Bytes())
77 | }
78 | 
79 | // guardResolution runs the inner fn only if the task has not already been
80 | // resolved, returning ErrAlreadyResolved if it has. If the inner function
81 | // returns no error, the task is subsequently marked as resolved.
82 | func (t *Task) guardResolution(fn func() error) error {
83 |     t.resolvedMu.Lock()
84 |     defer t.resolvedMu.Unlock()
85 |     if t.resolved {
86 |         return ErrAlreadyResolved
87 |     }
88 | 
89 |     err := fn()
90 |     if err == nil {
91 |         t.resolved = true
92 |     }
93 | 
94 |     return err
95 | }
96 | 
--------------------------------------------------------------------------------
/worker/task_test.go:
--------------------------------------------------------------------------------
1 | package worker_test
2 | 
3 | import (
4 |     "testing"
5 | 
6 |     "github.com/mixer/redutil/conn"
7 |     "github.com/mixer/redutil/test"
8 |     "github.com/mixer/redutil/worker"
9 |     "github.com/stretchr/testify/mock"
10 |     "github.com/stretchr/testify/suite"
11 | )
12 | 
13 | type TaskSuite struct {
14 |     *test.RedisSuite
15 | }
16 | 
17 | func TestTaskSuite(t *testing.T) {
18 |     pool, _ := conn.New(conn.ConnectionParam{
19 |         Address: "127.0.0.1:6379",
20 |     }, 1)
21 | 
22 |     suite.Run(t, &TaskSuite{test.NewSuite(pool)})
23 | }
24 | 
25 | func (suite *TaskSuite) TestConstruction() {
26 |     task := worker.NewTask(&MockLifecycle{}, []byte{})
27 | 
28 |     suite.Assert().IsType(&worker.Task{}, task)
29 | }
30 | 
31 | func (suite *TaskSuite) TestBytesReturnsPayload() {
32 |     task := worker.NewTask(&MockLifecycle{}, []byte("hello world"))
33 | 
34 |     payload := task.Bytes()
35 | 
36 |     suite.Assert().Equal([]byte("hello world"), payload)
37 | }
38 | 
39 | func (suite *TaskSuite) TestDumpsBody() {
40 |     task := worker.NewTask(&MockLifecycle{}, []byte("hello world"))
41 |     expected := "00000000  68 65 6c 6c 6f 20 77 6f  72 6c 64                    |hello world|\n"
42 |     suite.Assert().Equal(expected, task.HexDump())
43 | }
44 | 
45 | func (suite *TaskSuite) TestStringReturns() {
46 |     task := worker.NewTask(&MockLifecycle{}, []byte("hello world"))
47 |     suite.Assert().Equal("hello world", task.String())
48 | }
49 | 
50 | func (suite *TaskSuite) TestSucceedingDelegatesToLifecycle() {
51 |     lifecycle := &MockLifecycle{}
52 |     lifecycle.On("Complete", mock.Anything).Return(nil)
53 | 
54 |     task := worker.NewTask(lifecycle, []byte{})
55 | 
56 |     task.Succeed()
57 | 
58 |     lifecycle.AssertCalled(suite.T(), "Complete", task)
59 |     suite.Assert().True(task.IsResolved())
60 | }
61 | 
62 | func (suite *TaskSuite) TestFailingDelegatesToLifecycle() {
63 |     lifecycle := &MockLifecycle{}
64 |     lifecycle.On("Abandon", mock.Anything).Return(nil)
65 | 
66 |     task := worker.NewTask(lifecycle, []byte{})
67 | 
68 |     task.Fail()
69 | 
70 |     lifecycle.AssertCalled(suite.T(), "Abandon", task)
71 |     suite.Assert().True(task.IsResolved())
72 | }
73 | 
74 | func (suite *TaskSuite) TestSucceedingMultipleTimes() {
75 |     lifecycle := &MockLifecycle{}
76 |     lifecycle.On("Complete", mock.Anything).Return(nil).Once()
77 | 
78 |     task := worker.NewTask(lifecycle, []byte{})
79 | 
80 |     task.Succeed()
81 |     task.Succeed()
82 | 
83 |     lifecycle.AssertNumberOfCalls(suite.T(), "Complete", 1)
84 |     suite.Assert().True(task.IsResolved())
85 | }
86 | 
87 | func (suite *TaskSuite) TestFailingMultipleTimes() {
88 |     lifecycle := &MockLifecycle{}
89 |     lifecycle.On("Abandon", mock.Anything).Return(nil).Once()
90 | 
91 |     task := worker.NewTask(lifecycle, []byte{})
92 | 
93 |     task.Fail()
94 |     task.Fail()
95 | 
96 |     lifecycle.AssertNumberOfCalls(suite.T(), "Abandon", 1)
97 |     suite.Assert().True(task.IsResolved())
98 | }
--------------------------------------------------------------------------------
/worker/util.go:
--------------------------------------------------------------------------------
1 | package worker
2 | 
3 | import "reflect"
4 | 
5 | // concatErrs concatenates the output from several error channels into a single one.
6 | // It stops and closes the resulting channel when all of its inputs are closed.
7 | func concatErrs(errs ...<-chan error) <-chan error {
8 |     cases := make([]reflect.SelectCase, len(errs))
9 |     for i, ch := range errs {
10 |         cases[i] = reflect.SelectCase{
11 |             Dir:  reflect.SelectRecv,
12 |             Chan: reflect.ValueOf(ch),
13 |         }
14 |     }
15 | 
16 |     out := make(chan error)
17 |     go func() {
18 |         for len(cases) > 0 {
19 |             chosen, value, ok := reflect.Select(cases)
20 |             if !ok {
21 |                 // The chosen channel was closed: stop selecting on it.
22 |                 cases = append(cases[:chosen], cases[chosen+1:]...)
23 |             } else {
24 |                 out <- value.Interface().(error)
25 |             }
26 |         }
27 | 
28 |         close(out)
29 |     }()
30 | 
31 |     return out
32 | }
--------------------------------------------------------------------------------
/worker/util_test.go:
--------------------------------------------------------------------------------
1 | package worker
2 | 
3 | import (
4 |     "errors"
5 |     "testing"
6 |     "time"
7 | 
8 |     "github.com/stretchr/testify/assert"
9 | )
10 | 
11 | func TestConcatsErrors(t *testing.T) {
12 |     ch1 := make(chan error)
13 |     go func() {
14 |         ch1 <- errors.New("a1")
15 |         time.Sleep(10 * time.Millisecond)
16 |         ch1 <- errors.New("a2")
17 |         time.Sleep(10 * time.Millisecond)
18 |         ch1 <- errors.New("a3")
19 |         close(ch1)
20 |     }()
21 | 
22 |     ch2 := make(chan error)
23 |     go func() {
24 |         time.Sleep(5 * time.Millisecond)
25 |         ch2 <- errors.New("b1")
26 |         close(ch2)
27 |     }()
28 | 
29 |     tt := []struct {
30 |         err    string
31 |         closed bool
32 |     }{
33 |         {"a1", false},
34 |         {"b1", false},
35 |         {"a2", false},
36 |         {"a3", false},
37 |         {"", true},
38 |     }
39 |     out := concatErrs(ch1, ch2)
40 | 
41 |     for _, test := range tt {
42 |         err, ok := <-out
43 |         if test.closed {
44 |             assert.False(t, ok)
45 |         } else {
46 |             assert.True(t, ok)
47 |             assert.Equal(t, test.err, err.Error())
48 |         }
49 |     }
50 | }
--------------------------------------------------------------------------------
/worker/worker.go:
--------------------------------------------------------------------------------
1 | package worker
2 | 
3 | type Worker interface {
4 |     // Start begins the process of pulling []byte from the Queue, returning
5 |     // them as `*Task`s on a channel of Tasks (`<-chan *Task`).
6 |     //
7 |     // If any errors are encountered along the way, they are sent across the
8 |     // (unbuffered) `<-chan error`.
9 |     Start() (<-chan *Task, <-chan error)
10 | 
11 |     // Close stops the pulling goroutine and waits for all in-flight tasks
12 |     // to finish before returning.
13 |     Close()
14 | 
15 |     // Halt stops the pulling goroutine, but does not wait for in-flight
16 |     // tasks to finish before returning.
17 |     Halt()
18 | }
--------------------------------------------------------------------------------
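
Usage sketch (illustrative only, not a file in this repository): the worker package's surface is small — a Worker emits *Task values from Start(), and each Task must be resolved exactly once via Succeed() or Fail(). The sketch below shows one way a consumer might drive that loop; the workerexample package name and the process handler are hypothetical, and construction of a concrete Worker is elided.

package workerexample

import (
    "log"

    "github.com/mixer/redutil/worker"
)

// process is a hypothetical stand-in for whatever domain-specific work a
// consumer performs on a task payload.
func process(payload []byte) error {
    log.Printf("working on %d bytes", len(payload))
    return nil
}

// consume drains a started Worker until its task channel closes, resolving
// every task exactly once. Task itself guards against double resolution by
// returning ErrAlreadyResolved.
func consume(w worker.Worker) {
    tasks, errs := w.Start()
    for {
        select {
        case task, ok := <-tasks:
            if !ok {
                return // the Worker was closed or halted
            }
            if err := process(task.Bytes()); err != nil {
                // Abandon the task; the lifecycle pushes it back onto
                // the main queue to be retried.
                task.Fail()
            } else {
                // Complete the task, removing it from the durable
                // working queue.
                task.Succeed()
            }
        case err, ok := <-errs:
            if !ok {
                errs = nil // a nil channel blocks forever: stop selecting on it
                continue
            }
            log.Println("worker error:", err)
        }
    }
}

Whether the error channel closes before or after the task channel is up to the concrete Worker implementation, hence the defensive nil-ing of errs above.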