├── .gitignore ├── cluster ├── crc16_test.go ├── loadtest │ ├── README.md │ └── loadtest.go ├── crc16.go ├── cluster_test.go └── cluster.go ├── radix.go ├── LICENSE.txt ├── util ├── lua.go ├── lua_test.go ├── util.go ├── scan.go └── scan_test.go ├── pool ├── pool_test.go ├── doc.go └── pool.go ├── redis ├── client_test.go ├── doc.go ├── resp_test.go ├── client.go └── resp.go ├── README.md ├── pubsub ├── sub_test.go └── sub.go └── sentinel └── sentinel.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Generic 2 | *~ 3 | [._]* 4 | 5 | # Go 6 | *.[ao] 7 | *.[568vq] 8 | [568vq].out 9 | main 10 | *.test 11 | 12 | # Cgo 13 | *.cgo* 14 | *.so 15 | -------------------------------------------------------------------------------- /cluster/crc16_test.go: -------------------------------------------------------------------------------- 1 | package cluster 2 | 3 | import ( 4 | . "testing" 5 | ) 6 | 7 | func TestCRC16(t *T) { 8 | if c := CRC16([]byte("123456789")); c != 0x31c3 { 9 | t.Fatalf("checksum came out to %x not %x", c, 0x31c3) 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /radix.go: -------------------------------------------------------------------------------- 1 | // Package radix is a simple redis driver. This top-level package is a wrapper 2 | // for its sub-packages and doesn't actually contain any code. You likely want to 3 | // look at the redis sub-package for a straightforward redis client 4 | package radix 5 | -------------------------------------------------------------------------------- /cluster/loadtest/README.md: -------------------------------------------------------------------------------- 1 | # loadtest 2 | 3 | This is a little script that can be used for load-testing and making sure that 4 | reconnect/redistribution logic works alright.
It will make a connection to the 5 | test cluster (ports 7000 and 7001) and continuously make GET and SET requests on 6 | random keys, sometimes going back and doing GET/SET on keys it looked at in the 7 | past. 8 | 9 | While this is running you can kill redis nodes, redistribute the nodes, etc... 10 | to see how the cluster reacts. All errors will be reported to the console. 11 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR 17 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 18 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 19 | IN THE SOFTWARE. 
20 | -------------------------------------------------------------------------------- /cluster/loadtest/loadtest.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "crypto/rand" 5 | "encoding/hex" 6 | "io" 7 | "log" 8 | "time" 9 | 10 | "github.com/pl/radix.v2/cluster" 11 | ) 12 | 13 | func randString() string { 14 | b := make([]byte, 16) 15 | if _, err := io.ReadFull(rand.Reader, b); err != nil { 16 | log.Fatal(err) 17 | } 18 | return hex.EncodeToString(b) 19 | } 20 | 21 | func main() { 22 | c, err := cluster.New("localhost:7000") 23 | if err != nil { 24 | log.Fatal(err) 25 | } 26 | 27 | oldKeys := make(chan string, 1000) 28 | 29 | doRand := time.Tick(100 * time.Millisecond) 30 | doOldRand := time.Tick(1 * time.Second) 31 | 32 | for { 33 | select { 34 | case <-doRand: 35 | key := randString() 36 | doGetSet(c, key) 37 | select { 38 | case oldKeys <- key: 39 | default: 40 | } 41 | 42 | case <-doOldRand: 43 | select { 44 | case key := <-oldKeys: 45 | doGetSet(c, key) 46 | default: 47 | } 48 | } 49 | } 50 | } 51 | 52 | func doGetSet(c *cluster.Cluster, key string) { 53 | if err := c.Cmd("GET", key).Err; err != nil { 54 | log.Printf("GET %s -> %s", key, err) 55 | } 56 | val := randString() 57 | if err := c.Cmd("SET", key, val).Err; err != nil { 58 | log.Printf("SET %s %s -> %s", key, val, err) 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /util/lua.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "crypto/sha1" 5 | "encoding/hex" 6 | "strings" 7 | 8 | "github.com/pl/radix.v2/redis" 9 | ) 10 | 11 | // LuaEval calls EVAL on the given Cmder for the given script, passing the key 12 | // count and argument list in as well. See http://redis.io/commands/eval for 13 | // more on how EVAL works and for the meaning of the keys argument. 
14 | // 15 | // LuaEval will automatically try to call EVALSHA first in order to preserve 16 | // bandwidth, and only falls back on EVAL if the script has never been used 17 | // before. 18 | // 19 | // This method works with any of the Cmder's implemented in radix.v2, including 20 | // Client, Pool, and Cluster. 21 | // 22 | // r := util.LuaEval(c, `return redis.call('GET', KEYS[1])`, 1, "foo") 23 | // 24 | func LuaEval(c Cmder, script string, keys int, args ...interface{}) *redis.Resp { 25 | mainKey, _ := redis.KeyFromArgs(args...) 26 | 27 | sumRaw := sha1.Sum([]byte(script)) 28 | sum := hex.EncodeToString(sumRaw[:]) 29 | 30 | var r *redis.Resp 31 | if err := withClientForKey(c, mainKey, func(cc Cmder) { // cc is the single client selected for mainKey 32 | r = cc.Cmd("EVALSHA", sum, keys, args) 33 | if r.Err != nil && strings.HasPrefix(r.Err.Error(), "NOSCRIPT") { 34 | r = cc.Cmd("EVAL", script, keys, args) // same client, so the fallback goes to the node that reported NOSCRIPT 35 | } 36 | }); err != nil { 37 | return redis.NewResp(err) 38 | } 39 | 40 | return r 41 | } 42 | -------------------------------------------------------------------------------- /pool/pool_test.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import ( 4 | "sync" 5 | .
"testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func TestPool(t *T) { 12 | size := 10 13 | pool, err := New("tcp", "localhost:6379", size) 14 | require.Nil(t, err) 15 | 16 | var wg sync.WaitGroup 17 | for i := 0; i < size*4; i++ { 18 | wg.Add(1) 19 | go func() { 20 | for i := 0; i < 100; i++ { 21 | conn, err := pool.Get() 22 | assert.Nil(t, err) 23 | 24 | assert.Nil(t, conn.Cmd("ECHO", "HI").Err) 25 | 26 | pool.Put(conn) 27 | } 28 | wg.Done() 29 | }() 30 | } 31 | wg.Wait() 32 | 33 | assert.Equal(t, size, len(pool.pool)) 34 | 35 | pool.Empty() 36 | assert.Equal(t, 0, len(pool.pool)) 37 | } 38 | 39 | func TestCmd(t *T) { 40 | size := 10 41 | pool, err := New("tcp", "localhost:6379", 10) 42 | require.Nil(t, err) 43 | 44 | var wg sync.WaitGroup 45 | for i := 0; i < size*4; i++ { 46 | wg.Add(1) 47 | go func() { 48 | for i := 0; i < 100; i++ { 49 | assert.Nil(t, pool.Cmd("ECHO", "HI").Err) 50 | } 51 | wg.Done() 52 | }() 53 | } 54 | wg.Wait() 55 | assert.Equal(t, size, len(pool.pool)) 56 | } 57 | 58 | func TestPut(t *T) { 59 | pool, err := New("tcp", "localhost:6379", 10) 60 | require.Nil(t, err) 61 | 62 | conn, err := pool.Get() 63 | require.Nil(t, err) 64 | assert.Equal(t, 9, len(pool.pool)) 65 | 66 | conn.Close() 67 | assert.NotNil(t, conn.Cmd("PING").Err) 68 | pool.Put(conn) 69 | 70 | // Make sure that Put does not accept a connection which has had a critical 71 | // network error 72 | assert.Equal(t, 9, len(pool.pool)) 73 | } 74 | -------------------------------------------------------------------------------- /util/lua_test.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "crypto/rand" 5 | "fmt" 6 | . 
"testing" 7 | 8 | "github.com/pl/radix.v2/cluster" 9 | "github.com/pl/radix.v2/redis" 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | // randTestScript returns a script, and a key and value which can be input into 15 | // the script. The script will always have a different sha1 than the last. 16 | func randTestScript() (string, string, string) { 17 | r := [][]byte{ 18 | make([]byte, 10), // shard 19 | make([]byte, 10), 20 | make([]byte, 10), 21 | make([]byte, 10), 22 | make([]byte, 10), 23 | } 24 | for i := range r { 25 | if _, err := rand.Read(r[i]); err != nil { 26 | panic(err) 27 | } 28 | if i > 0 { 29 | r[i] = []byte(fmt.Sprintf(`{%x}%x`, string(r[0]), string(r[i]))) 30 | } 31 | } 32 | script := fmt.Sprintf( 33 | `return redis.call('MSET', KEYS[1], ARGV[1], '%s', '%s')`, 34 | r[1], r[2], 35 | ) 36 | return script, string(r[3]), string(r[4]) 37 | } 38 | 39 | func TestLuaEval(t *T) { 40 | c1, err := redis.Dial("tcp", "127.0.0.1:6379") 41 | require.Nil(t, err) 42 | c2, err := cluster.New("127.0.0.1:7000") 43 | require.Nil(t, err) 44 | 45 | cs := []Cmder{c1, c2} 46 | 47 | for _, c := range cs { 48 | script, key, val := randTestScript() 49 | s, err := LuaEval(c, script, 1, key, val).Str() 50 | require.Nil(t, err) 51 | assert.Equal(t, "OK", s) 52 | 53 | // The second time the command will be hashed 54 | script, key, val = randTestScript() 55 | s, err = LuaEval(c, script, 1, key, val).Str() 56 | require.Nil(t, err) 57 | assert.Equal(t, "OK", s) 58 | 59 | s, err = c.Cmd("GET", key).Str() 60 | require.Nil(t, err) 61 | assert.Equal(t, val, s) 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /util/util.go: -------------------------------------------------------------------------------- 1 | // Package util is a collection of helper functions for interacting with various 2 | // parts of the radix.v2 package 3 | package util 4 | 5 | import ( 6 | 
"github.com/pl/radix.v2/cluster" 7 | "github.com/pl/radix.v2/pool" 8 | "github.com/pl/radix.v2/redis" 9 | ) 10 | 11 | // Cmder is an interface which can be used to interchangeably work with either 12 | // redis.Client (the basic, single connection redis client), pool.Pool, or 13 | // cluster.Cluster. All three implement a Cmd method (although, as is the case 14 | // with Cluster, sometimes with different limitations), and therefore all three 15 | // are Cmders 16 | type Cmder interface { 17 | Cmd(cmd string, args ...interface{}) *redis.Resp 18 | } 19 | 20 | // withClientForKey is useful for retrieving a single client which can handle 21 | // the given key and perform one or more requests on them, especially when the 22 | // passed in Cmder is actually a Cluster or Pool. 23 | // 24 | // The function given takes a Cmder and not a Client because the passed in Cmder 25 | // may not be one implemented in radix.v2, and in that case may not actually 26 | // have a way of mapping to a Client. In that case it is simply passed directly 27 | // through to fn. 28 | func withClientForKey(c Cmder, key string, fn func(c Cmder)) error { 29 | var singleC Cmder 30 | 31 | switch cc := c.(type) { 32 | case *cluster.Cluster: 33 | client, err := cc.GetForKey(key) 34 | if err != nil { 35 | return err 36 | } 37 | defer cc.Put(client) 38 | singleC = client 39 | 40 | case *pool.Pool: 41 | client, err := cc.Get() 42 | if err != nil { 43 | return err 44 | } 45 | defer cc.Put(client) 46 | singleC = client 47 | 48 | default: 49 | singleC = cc 50 | } 51 | 52 | fn(singleC) 53 | return nil 54 | } 55 | -------------------------------------------------------------------------------- /pool/doc.go: -------------------------------------------------------------------------------- 1 | // Package pool implements a connection pool for redis connections which is 2 | // thread-safe. 
3 | // 4 | // Basic usage 5 | // 6 | // The basic use-case is to create a pool and then pass that pool amongst 7 | // multiple go-routines, each of which can use it safely. To retrieve a 8 | // connection you use Get, and to return the connection to the pool when you're 9 | // done with it you use Put. 10 | // 11 | // p, err := pool.New("tcp", "localhost:6379", 10) 12 | // if err != nil { 13 | // // handle error 14 | // } 15 | // 16 | // // In another go-routine 17 | // 18 | // conn, err := p.Get() 19 | // if err != nil { 20 | // // handle error 21 | // } 22 | // 23 | // if conn.Cmd("SOME", "CMD").Err != nil { 24 | // // handle error 25 | // } 26 | // 27 | // p.Put(conn) 28 | // 29 | // Shortcuts 30 | // 31 | // If you're doing multiple operations you may find it useful to defer the Put 32 | // right after retrieving a connection, so that you don't have to always 33 | // remember to do so 34 | // 35 | // conn, err := p.Get() 36 | // if err != nil { 37 | // // handle error 38 | // } 39 | // defer p.Put(conn) 40 | // 41 | // if conn.Cmd("SOME", "CMD").Err != nil { 42 | // // handle error 43 | // } 44 | // 45 | // if conn.Cmd("SOME", "OTHER", "CMD").Err != nil { 46 | // // handle error 47 | // } 48 | // 49 | // Additionally there is the Cmd method on Pool, which handles Get-ing and 50 | // Put-ing for you in the case of only wanting to execute a single command 51 | // 52 | // r := p.Cmd("SOME", "CMD") 53 | // if r.Err != nil { 54 | // // handle error 55 | // } 56 | // 57 | // Custom connections 58 | // 59 | // Sometimes it's necessary to run some code on each connection in a pool upon 60 | // its creation, for example in the case of AUTH. 
This can be done with 61 | // NewCustom, like so 62 | // 63 | // df := func(network, addr string) (*redis.Client, error) { 64 | // client, err := redis.Dial(network, addr) 65 | // if err != nil { 66 | // return nil, err 67 | // } 68 | // if err = client.Cmd("AUTH", "SUPERSECRET").Err; err != nil { 69 | // client.Close() 70 | // return nil, err 71 | // } 72 | // return client, nil 73 | // } 74 | // p, err := pool.NewCustom("tcp", "127.0.0.1:6379", 10, df) 75 | package pool 76 | -------------------------------------------------------------------------------- /redis/client_test.go: -------------------------------------------------------------------------------- 1 | package redis 2 | 3 | import ( 4 | . "testing" 5 | "time" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func dial(t *T) *Client { 12 | client, err := DialTimeout("tcp", "127.0.0.1:6379", 10*time.Second) 13 | require.Nil(t, err) 14 | return client 15 | } 16 | 17 | func TestCmd(t *T) { 18 | c := dial(t) 19 | v, err := c.Cmd("echo", "Hello, World!").Str() 20 | require.Nil(t, err) 21 | assert.Equal(t, "Hello, World!", v) 22 | 23 | // Test that a bad command properly returns an AppErr 24 | r := c.Cmd("non-existant-cmd") 25 | assert.Equal(t, AppErr, r.typ) 26 | assert.NotNil(t, r.Err) 27 | 28 | // Test that application level errors propagate correctly 29 | require.Nil(t, c.Cmd("sadd", "foo", "bar").Err) 30 | _, err = c.Cmd("get", "foo").Str() 31 | assert.NotNil(t, err) // GET on the set created above must surface an error 32 | } 33 | 34 | func TestPipeline(t *T) { 35 | c := dial(t) 36 | // Do this multiple times to make sure pipeline resetting happens correctly 37 | for i := 0; i < 10; i++ { 38 | c.PipeAppend("echo", "foo") 39 | c.PipeAppend("echo", "bar") 40 | c.PipeAppend("echo", "zot") 41 | 42 | v, err := c.PipeResp().Str() 43 | require.Nil(t, err) 44 | assert.Equal(t, "foo", v) 45 | 46 | v, err = c.PipeResp().Str() 47 | require.Nil(t, err) 48 | assert.Equal(t, "bar", v) 49 | 50 | v, err =
c.PipeResp().Str() 51 | require.Nil(t, err) 52 | assert.Equal(t, "zot", v) 53 | 54 | r := c.PipeResp() 55 | assert.Equal(t, AppErr, r.typ) 56 | assert.Equal(t, ErrPipelineEmpty, r.Err) 57 | } 58 | } 59 | 60 | func TestLastCritical(t *T) { 61 | c := dial(t) 62 | 63 | // LastCritical shouldn't get set for application errors 64 | assert.NotNil(t, c.Cmd("WHAT").Err) 65 | assert.Nil(t, c.LastCritical) 66 | 67 | c.Close() 68 | r := c.Cmd("WHAT") 69 | assert.Equal(t, true, r.IsType(IOErr)) 70 | assert.NotNil(t, r.Err) 71 | assert.NotNil(t, c.LastCritical) 72 | } 73 | 74 | func TestKeyFromArg(t *T) { 75 | m := map[string]interface{}{ 76 | "foo0": "foo0", 77 | "foo1": []byte("foo1"), 78 | "1": 1, 79 | "1.1": 1.1, 80 | "foo2": []string{"foo2", "bar"}, 81 | "foo3": [][]string{{"foo3", "bar"}, {"baz", "buz"}}, 82 | } 83 | 84 | for out, in := range m { 85 | key, err := KeyFromArgs(in) 86 | assert.Nil(t, err) 87 | assert.Equal(t, out, key) 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /cluster/crc16.go: -------------------------------------------------------------------------------- 1 | package cluster 2 | 3 | var tab = [256]uint16{ 4 | 0x0000, 0x1021, 0x2042, 0x3063, 0x4084, 0x50a5, 0x60c6, 0x70e7, 5 | 0x8108, 0x9129, 0xa14a, 0xb16b, 0xc18c, 0xd1ad, 0xe1ce, 0xf1ef, 6 | 0x1231, 0x0210, 0x3273, 0x2252, 0x52b5, 0x4294, 0x72f7, 0x62d6, 7 | 0x9339, 0x8318, 0xb37b, 0xa35a, 0xd3bd, 0xc39c, 0xf3ff, 0xe3de, 8 | 0x2462, 0x3443, 0x0420, 0x1401, 0x64e6, 0x74c7, 0x44a4, 0x5485, 9 | 0xa56a, 0xb54b, 0x8528, 0x9509, 0xe5ee, 0xf5cf, 0xc5ac, 0xd58d, 10 | 0x3653, 0x2672, 0x1611, 0x0630, 0x76d7, 0x66f6, 0x5695, 0x46b4, 11 | 0xb75b, 0xa77a, 0x9719, 0x8738, 0xf7df, 0xe7fe, 0xd79d, 0xc7bc, 12 | 0x48c4, 0x58e5, 0x6886, 0x78a7, 0x0840, 0x1861, 0x2802, 0x3823, 13 | 0xc9cc, 0xd9ed, 0xe98e, 0xf9af, 0x8948, 0x9969, 0xa90a, 0xb92b, 14 | 0x5af5, 0x4ad4, 0x7ab7, 0x6a96, 0x1a71, 0x0a50, 0x3a33, 0x2a12, 15 | 0xdbfd, 0xcbdc, 0xfbbf, 0xeb9e, 0x9b79, 0x8b58, 
0xbb3b, 0xab1a, 16 | 0x6ca6, 0x7c87, 0x4ce4, 0x5cc5, 0x2c22, 0x3c03, 0x0c60, 0x1c41, 17 | 0xedae, 0xfd8f, 0xcdec, 0xddcd, 0xad2a, 0xbd0b, 0x8d68, 0x9d49, 18 | 0x7e97, 0x6eb6, 0x5ed5, 0x4ef4, 0x3e13, 0x2e32, 0x1e51, 0x0e70, 19 | 0xff9f, 0xefbe, 0xdfdd, 0xcffc, 0xbf1b, 0xaf3a, 0x9f59, 0x8f78, 20 | 0x9188, 0x81a9, 0xb1ca, 0xa1eb, 0xd10c, 0xc12d, 0xf14e, 0xe16f, 21 | 0x1080, 0x00a1, 0x30c2, 0x20e3, 0x5004, 0x4025, 0x7046, 0x6067, 22 | 0x83b9, 0x9398, 0xa3fb, 0xb3da, 0xc33d, 0xd31c, 0xe37f, 0xf35e, 23 | 0x02b1, 0x1290, 0x22f3, 0x32d2, 0x4235, 0x5214, 0x6277, 0x7256, 24 | 0xb5ea, 0xa5cb, 0x95a8, 0x8589, 0xf56e, 0xe54f, 0xd52c, 0xc50d, 25 | 0x34e2, 0x24c3, 0x14a0, 0x0481, 0x7466, 0x6447, 0x5424, 0x4405, 26 | 0xa7db, 0xb7fa, 0x8799, 0x97b8, 0xe75f, 0xf77e, 0xc71d, 0xd73c, 27 | 0x26d3, 0x36f2, 0x0691, 0x16b0, 0x6657, 0x7676, 0x4615, 0x5634, 28 | 0xd94c, 0xc96d, 0xf90e, 0xe92f, 0x99c8, 0x89e9, 0xb98a, 0xa9ab, 29 | 0x5844, 0x4865, 0x7806, 0x6827, 0x18c0, 0x08e1, 0x3882, 0x28a3, 30 | 0xcb7d, 0xdb5c, 0xeb3f, 0xfb1e, 0x8bf9, 0x9bd8, 0xabbb, 0xbb9a, 31 | 0x4a75, 0x5a54, 0x6a37, 0x7a16, 0x0af1, 0x1ad0, 0x2ab3, 0x3a92, 32 | 0xfd2e, 0xed0f, 0xdd6c, 0xcd4d, 0xbdaa, 0xad8b, 0x9de8, 0x8dc9, 33 | 0x7c26, 0x6c07, 0x5c64, 0x4c45, 0x3ca2, 0x2c83, 0x1ce0, 0x0cc1, 34 | 0xef1f, 0xff3e, 0xcf5d, 0xdf7c, 0xaf9b, 0xbfba, 0x8fd9, 0x9ff8, 35 | 0x6e17, 0x7e36, 0x4e55, 0x5e74, 0x2e93, 0x3eb2, 0x0ed1, 0x1ef0, 36 | } 37 | 38 | // CRC16 returns checksum for a given set of bytes based on the crc algorithm 39 | // defined for hashing redis keys in a cluster setup 40 | func CRC16(buf []byte) uint16 { 41 | crc := uint16(0) 42 | for _, b := range buf { 43 | index := byte(crc>>8) ^ b 44 | crc = (crc << 8) ^ tab[index] 45 | } 46 | return crc 47 | } 48 | -------------------------------------------------------------------------------- /util/scan.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "errors" 5 | "strings" 6 | 7 |
"github.com/pl/radix.v2/cluster" 8 | ) 9 | 10 | func scanSingle(r Cmder, ch chan string, cmd, key, pattern string) error { 11 | defer close(ch) 12 | cmd = strings.ToUpper(cmd) 13 | 14 | var keys []string 15 | cursor := "0" 16 | for { 17 | args := make([]interface{}, 0, 4) 18 | if cmd != "SCAN" { 19 | args = append(args, key) 20 | } 21 | args = append(args, cursor, "MATCH", pattern) 22 | 23 | parts, err := r.Cmd(cmd, args...).Array() 24 | if err != nil { 25 | return err 26 | } 27 | 28 | if len(parts) < 2 { 29 | return errors.New("not enough parts returned") 30 | } 31 | 32 | if cursor, err = parts[0].Str(); err != nil { 33 | return err 34 | } 35 | 36 | if keys, err = parts[1].List(); err != nil { 37 | return err 38 | } 39 | 40 | for i := range keys { 41 | ch <- keys[i] 42 | } 43 | 44 | if cursor == "0" { 45 | return nil 46 | } 47 | } 48 | } 49 | 50 | // scanCluster is like Scan except it operates over a whole cluster. Unlike Scan 51 | // it only works with SCAN and as such only takes in a pattern string. 52 | func scanCluster(c *cluster.Cluster, ch chan string, pattern string) error { 53 | defer close(ch) 54 | clients, err := c.GetEvery() 55 | if err != nil { 56 | return err 57 | } 58 | for _, client := range clients { 59 | defer c.Put(client) 60 | } 61 | 62 | for _, client := range clients { 63 | cch := make(chan string) 64 | errCh := make(chan error, 1) // buffered so the goroutine can always deliver its result and exit 65 | go func() { 66 | errCh <- scanSingle(client, cch, "SCAN", "", pattern) 67 | }() 68 | for key := range cch { 69 | ch <- key 70 | } 71 | if err := <-errCh; err != nil { // wait for scanSingle's result instead of reading a shared variable 72 | return err 73 | } 74 | } 75 | 76 | return nil 77 | } 78 | 79 | // Scan is a helper function for performing any of the redis *SCAN functions. It 80 | // takes in a channel which keys returned by redis will be written to, and 81 | // returns an error should any occur. The input channel will always be closed 82 | // when Scan returns, and *must* be read until it is closed.
83 | // 84 | // The key argument is only needed if cmd isn't SCAN 85 | // 86 | // Example SCAN command 87 | // 88 | // ch := make(chan string) 89 | // var err error 90 | // go func() { 91 | // err = util.Scan(r, ch, "SCAN", "", "*") 92 | // }() 93 | // for key := range ch { 94 | // // do something with key 95 | // } 96 | // if err != nil { 97 | // // handle error 98 | // } 99 | // 100 | // Example HSCAN command 101 | // 102 | // ch := make(chan string) 103 | // var err error 104 | // go func() { 105 | // err = util.Scan(r, ch, "HSCAN", "somekey", "*") 106 | // }() 107 | // for key := range ch { 108 | // // do something with key 109 | // } 110 | // if err != nil { 111 | // // handle error 112 | // } 113 | // 114 | func Scan(r Cmder, ch chan string, cmd, key, pattern string) error { 115 | if rr, ok := r.(*cluster.Cluster); ok && strings.ToUpper(cmd) == "SCAN" { 116 | return scanCluster(rr, ch, pattern) 117 | } 118 | var cmdErr error 119 | err := withClientForKey(r, key, func(c Cmder) { 120 | cmdErr = scanSingle(c, ch, cmd, key, pattern) // scan on the single client selected for key 121 | }) 122 | if err != nil { 123 | close(ch) // the callback never ran, so scanSingle's deferred close did not happen 124 | return err 125 | } 126 | return cmdErr 127 | } 128 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Radix 2 | 3 | [![GoDoc](https://godoc.org/github.com/mediocregopher/radix.v2?status.svg)](https://godoc.org/github.com/mediocregopher/radix.v2) 4 | 5 | Radix is a minimalistic [Redis][redis] client for Go. It is broken up into 6 | small, single-purpose packages for ease of use. 7 | 8 | * [redis](http://godoc.org/github.com/mediocregopher/radix.v2/redis) - A wrapper 9 | around a single redis connection. Supports normal commands/response as well as 10 | pipelining. 11 | 12 | * [pool](http://godoc.org/github.com/mediocregopher/radix.v2/pool) - a simple, 13 | automatically expanding/cleaning connection pool.
14 | 15 | * [pubsub](http://godoc.org/github.com/mediocregopher/radix.v2/pubsub) - a 16 | simple wrapper providing convenient access to Redis Pub/Sub functionality. 17 | 18 | * [sentinel](http://godoc.org/github.com/mediocregopher/radix.v2/sentinel) - a 19 | client for [redis sentinel][sentinel] which acts as a connection pool for a 20 | cluster of redis nodes. A sentinel client connects to a sentinel instance and 21 | any master redis instances that instance is monitoring. If a master becomes 22 | unavailable, the sentinel client will automatically start distributing 23 | connections from the slave chosen by the sentinel instance. 24 | 25 | * [cluster](http://godoc.org/github.com/mediocregopher/radix.v2/cluster) - a 26 | client for a [redis cluster][cluster] which automatically handles interacting 27 | with a redis cluster, transparently handling redirects and pooling. This 28 | client keeps a mapping of slots to nodes internally, and automatically keeps 29 | it up-to-date. 30 | 31 | * [util](http://godoc.org/github.com/mediocregopher/radix.v2/util) - a 32 | package containing a number of helper methods for doing common tasks with the 33 | radix package, such as SCANing either a single redis instance or every one in 34 | a cluster, or executing server-side lua 35 | 36 | ## Installation 37 | 38 | go get github.com/mediocregopher/radix.v2/... 39 | 40 | ## Testing 41 | 42 | go test github.com/mediocregopher/radix.v2/... 43 | 44 | The test action assumes you have the following running: 45 | 46 | * A redis server listening on port 6379 47 | 48 | * A redis cluster node listening on port 7000, handling slots 0 through 8191 49 | 50 | * A redis cluster node listening on port 7001, handling slots 8192 through 16383 51 | 52 | The slot numbers are particularly important as the tests for the cluster package 53 | do some trickery which depends on certain keys being assigned to certain nodes 54 | 55 | ## Why is this V2?
56 | 57 | V1 of radix was started by [fzzy](https://github.com/fzzy) and can be found 58 | [here](https://github.com/fzzy/radix). Some time in 2014 I took over the project 59 | and reached a point where I couldn't make improvements that I wanted to make due 60 | to past design decisions (mostly my own). So I've started V2, which has 61 | redesigned some core aspects of the api and hopefully made things easier to use 62 | and faster. 63 | 64 | Here are some of the major changes since V1: 65 | 66 | * Combining resp and redis packages 67 | 68 | * Reply is now Resp 69 | 70 | * Hash is now Map 71 | 72 | * Append is now PipeAppend, GetReply is now PipeResp 73 | 74 | * PipelineQueueEmptyError is now ErrPipelineEmpty 75 | 76 | * Significant changes to pool, making it easier to use 77 | 78 | * More functionality in cluster 79 | 80 | ## Copyright and licensing 81 | 82 | Unless otherwise noted, the source files are distributed under the *MIT License* 83 | found in the LICENSE.txt file. 84 | 85 | [redis]: http://redis.io 86 | [sentinel]: http://redis.io/topics/sentinel 87 | [cluster]: http://redis.io/topics/cluster-spec 88 | -------------------------------------------------------------------------------- /pool/pool.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import ( 4 | "github.com/pl/radix.v2/redis" 5 | ) 6 | 7 | // Pool is a simple connection pool for redis Clients. It will create a small 8 | // pool of initial connections, and if more connections are needed they will be 9 | // created on demand. If a connection is Put back and the pool is full it will 10 | // be closed. 11 | type Pool struct { 12 | pool chan *redis.Client 13 | df DialFunc 14 | 15 | // The network/address that the pool is connecting to. These are going to be 16 | // whatever was passed into the New function. 
These should not be 17 | // changed after the pool is initialized 18 | Network, Addr string 19 | } 20 | 21 | // DialFunc is a function which can be passed into NewCustom 22 | type DialFunc func(network, addr string) (*redis.Client, error) 23 | 24 | // NewCustom is like New except you can specify a DialFunc which will be 25 | // used when creating new connections for the pool. The common use-case is to do 26 | // authentication for new connections. 27 | func NewCustom(network, addr string, size int, df DialFunc) (*Pool, error) { 28 | var client *redis.Client 29 | var err error 30 | pool := make([]*redis.Client, 0, size) 31 | for i := 0; i < size; i++ { 32 | client, err = df(network, addr) 33 | if err != nil { 34 | for _, client = range pool { 35 | client.Close() 36 | } 37 | pool = pool[:0] // discard the clients that were just closed so none of them is handed out later 38 | break 39 | } 40 | pool = append(pool, client) 41 | } 42 | p := Pool{ 43 | Network: network, 44 | Addr: addr, 45 | pool: make(chan *redis.Client, size), // capacity stays size even when dialing failed, so Put can still retain idle clients 46 | df: df, 47 | } 48 | for i := range pool { 49 | p.pool <- pool[i] 50 | } 51 | return &p, err 52 | } 53 | 54 | // New creates a new Pool whose connections are all created using 55 | // redis.Dial(network, addr). The size indicates the maximum number of idle 56 | // connections to have waiting to be used at any given moment. If an error is 57 | // encountered an empty (but still usable) pool is returned alongside that error 58 | func New(network, addr string, size int) (*Pool, error) { 59 | return NewCustom(network, addr, size, redis.Dial) 60 | } 61 | 62 | // Get retrieves an available redis client. If there are none available it will 63 | // create a new one on the fly 64 | func (p *Pool) Get() (*redis.Client, error) { 65 | select { 66 | case conn := <-p.pool: 67 | return conn, nil 68 | default: 69 | return p.df(p.Network, p.Addr) 70 | } 71 | } 72 | 73 | // Put returns a client back to the pool. If the pool is full the client is 74 | // closed instead.
If the client is already closed (due to connection failure or 75 | // what-have-you) it will not be put back in the pool 76 | func (p *Pool) Put(conn *redis.Client) { 77 | if conn.LastCritical == nil { 78 | select { 79 | case p.pool <- conn: 80 | default: 81 | conn.Close() 82 | } 83 | } 84 | } 85 | 86 | // Cmd automatically gets one client from the pool, executes the given command 87 | // (returning its result), and puts the client back in the pool 88 | func (p *Pool) Cmd(cmd string, args ...interface{}) *redis.Resp { 89 | c, err := p.Get() 90 | if err != nil { 91 | return redis.NewResp(err) 92 | } 93 | defer p.Put(c) 94 | 95 | return c.Cmd(cmd, args...) 96 | } 97 | 98 | // Empty removes and calls Close() on all the connections currently in the pool. 99 | // Assuming there are no other connections waiting to be Put back this method 100 | // effectively closes and cleans up the pool. 101 | func (p *Pool) Empty() { 102 | var conn *redis.Client 103 | for { 104 | select { 105 | case conn = <-p.pool: 106 | conn.Close() 107 | default: 108 | return 109 | } 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /util/scan_test.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "strconv" 5 | . 
"testing" 6 | 7 | "github.com/pl/radix.v2/cluster" 8 | "github.com/pl/radix.v2/redis" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func TestScan(t *T) { 14 | client, err := redis.Dial("tcp", "127.0.0.1:6379") 15 | require.Nil(t, err) 16 | 17 | prefix := "scanTestPrefix" 18 | 19 | fullMap := map[string]bool{} 20 | for i := 0; i < 100; i++ { 21 | key := prefix + ":" + strconv.Itoa(i) 22 | fullMap[key] = true 23 | require.Nil(t, client.Cmd("SET", key, "1").Err) 24 | } 25 | 26 | // make sure we get all results when scanning with an existing prefix 27 | ch := make(chan string) 28 | go func() { 29 | err = Scan(client, ch, "SCAN", "", prefix+":*") 30 | }() 31 | testMap := map[string]bool{} 32 | for key := range ch { 33 | testMap[key] = true 34 | } 35 | require.Nil(t, err) 36 | assert.Equal(t, fullMap, testMap) 37 | 38 | // make sure we don't get any results when scanning with a non-existing 39 | // prefix 40 | ch = make(chan string) 41 | go func() { 42 | err = Scan(client, ch, "SCAN", "", prefix+"DNE:*") 43 | }() 44 | testMap = map[string]bool{} 45 | for key := range ch { 46 | testMap[key] = true 47 | } 48 | require.Nil(t, err) 49 | assert.Equal(t, map[string]bool{}, testMap) 50 | } 51 | 52 | // Similar to TestScan, but scans over a set instead of the whole key space 53 | func TestSScan(t *T) { 54 | client, err := redis.Dial("tcp", "127.0.0.1:6379") 55 | require.Nil(t, err) 56 | 57 | key := "scanTestSet" 58 | 59 | fullMap := map[string]bool{} 60 | for i := 0; i < 100; i++ { 61 | elem := strconv.Itoa(i) 62 | fullMap[elem] = true 63 | require.Nil(t, client.Cmd("SADD", key, elem).Err) 64 | } 65 | 66 | // make sure we get all results when scanning with an existing prefix 67 | ch := make(chan string) 68 | go func() { 69 | err = Scan(client, ch, "SSCAN", key, "*") 70 | }() 71 | testMap := map[string]bool{} 72 | for elem := range ch { 73 | testMap[elem] = true 74 | } 75 | require.Nil(t, err) 76 | assert.Equal(t, fullMap, 
testMap) 77 | 78 | // make sure we don't get any results when scanning with a non-existing 79 | // prefix 80 | ch = make(chan string) 81 | go func() { 82 | err = Scan(client, ch, "SSCAN", key+"DNE", "*") 83 | }() 84 | testMap = map[string]bool{} 85 | for elem := range ch { 86 | testMap[elem] = true 87 | } 88 | require.Nil(t, err) 89 | assert.Equal(t, map[string]bool{}, testMap) 90 | } 91 | 92 | // Similar to TestScan, but scans over a whole cluster 93 | func TestClusterScan(t *T) { 94 | cluster, err := cluster.New("127.0.0.1:7000") 95 | require.Nil(t, err) 96 | 97 | prefix := "scanTestPrefix" 98 | 99 | fullMap := map[string]bool{} 100 | for i := 0; i < 100; i++ { 101 | key := prefix + ":" + strconv.Itoa(i) 102 | fullMap[key] = true 103 | require.Nil(t, cluster.Cmd("SET", key, "1").Err) 104 | } 105 | 106 | // make sure we get all results when scanning with an existing prefix 107 | ch := make(chan string) 108 | go func() { 109 | err = Scan(cluster, ch, "SCAN", "", prefix+":*") 110 | }() 111 | testMap := map[string]bool{} 112 | for key := range ch { 113 | testMap[key] = true 114 | } 115 | require.Nil(t, err) 116 | assert.Equal(t, fullMap, testMap) 117 | 118 | // make sure we don't get any results when scanning with a non-existing 119 | // prefix 120 | ch = make(chan string) 121 | go func() { 122 | err = Scan(cluster, ch, "SCAN", "", prefix+"DNE:*") 123 | }() 124 | testMap = map[string]bool{} 125 | for key := range ch { 126 | testMap[key] = true 127 | } 128 | require.Nil(t, err) 129 | assert.Equal(t, map[string]bool{}, testMap) 130 | } 131 | -------------------------------------------------------------------------------- /redis/doc.go: -------------------------------------------------------------------------------- 1 | // Package redis is a simple client for connecting and interacting with a single 2 | // redis instance. 
3 | // 4 | // To import inside your package do: 5 | // 6 | // import "github.com/pl/radix.v2/redis" 7 | // 8 | // Connecting 9 | // 10 | // Use either Dial or DialTimeout: 11 | // 12 | // client, err := redis.Dial("tcp", "localhost:6379") 13 | // if err != nil { 14 | // // handle err 15 | // } 16 | // 17 | // Make sure to call Close on the client if you want to clean it up before the 18 | // end of the program. 19 | // 20 | // Cmd and Resp 21 | // 22 | // The Cmd method returns a Resp, which has methods for converting to various 23 | // types. Each of these methods returns an error which can either be a 24 | // connection error (e.g. timeout), an application error (e.g. key is wrong 25 | // type), or a conversion error (e.g. cannot convert to integer). You can also 26 | // directly check the error using the Err field: 27 | // 28 | // foo, err := client.Cmd("GET", "foo").Str() 29 | // if err != nil { 30 | // // handle err 31 | // } 32 | // 33 | // // Checking Err field directly 34 | // 35 | // err = client.Cmd("SET", "foo", "bar", "EX", 3600).Err 36 | // if err != nil { 37 | // // handle err 38 | // } 39 | // 40 | // Array Replies 41 | // 42 | // The elements to Array replies can be accessed as strings using List or 43 | // ListBytes, or you can use the Array method for more low level access: 44 | // 45 | // r := client.Cmd("MGET", "foo", "bar", "baz") 46 | // if r.Err != nil { 47 | // // handle error 48 | // } 49 | // 50 | // // This: 51 | // l, _ := r.List() 52 | // for _, elemStr := range l { 53 | // fmt.Println(elemStr) 54 | // } 55 | // 56 | // // is equivalent to this: 57 | // elems, err := r.Array() 58 | // for i := range elems { 59 | // elemStr, _ := elems[i].Str() 60 | // fmt.Println(elemStr) 61 | // } 62 | // 63 | // Pipelining 64 | // 65 | // Pipelining is when the client sends a bunch of commands to the server at 66 | // once, and only once all the commands have been sent does it start reading the 67 | // replies off the socket. 
This is supported using the PipeAppend and PipeResp 68 | // methods. PipeAppend will simply append the command to a buffer without 69 | // sending it, the first time PipeResp is called it will send all the commands 70 | // in the buffer and return the Resp for the first command that was sent. 71 | // Subsequent calls to PipeResp return Resps for subsequent commands: 72 | // 73 | // client.PipeAppend("GET", "foo") 74 | // client.PipeAppend("SET", "bar", "foo") 75 | // client.PipeAppend("DEL", "baz") 76 | // 77 | // // Read GET foo reply 78 | // foo, err := client.PipeResp().Str() 79 | // if err != nil { 80 | // // handle err 81 | // } 82 | // 83 | // // Read SET bar foo reply 84 | // if err := client.PipeResp().Err; err != nil { 85 | // // handle err 86 | // } 87 | // 88 | // // Read DEL baz reply 89 | // if err := client.PipeResp().Err; err != nil { 90 | // // handle err 91 | // } 92 | // 93 | // Flattening 94 | // 95 | // Radix will automatically flatten passed in maps and slices into the argument 96 | // list. For example, the following are all equivalent: 97 | // 98 | // client.Cmd("HMSET", "myhash", "key1", "val1", "key2", "val2") 99 | // client.Cmd("HMSET", "myhash", []string{"key1", "val1", "key2", "val2"}) 100 | // client.Cmd("HMSET", "myhash", map[string]string{ 101 | // "key1": "val1", 102 | // "key2": "val2", 103 | // }) 104 | // client.Cmd("HMSET", "myhash", [][]string{ 105 | // []string{"key1", "val1"}, 106 | // []string{"key2", "val2"}, 107 | // }) 108 | // 109 | // Radix is not picky about the types inside or outside the maps/slices, if they 110 | // don't match a subset of primitive types it will fall back to reflection to 111 | // figure out what they are and encode them. 
112 | package redis 113 | -------------------------------------------------------------------------------- /pubsub/sub_test.go: -------------------------------------------------------------------------------- 1 | package pubsub 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | "time" 7 | 8 | "github.com/pl/radix.v2/redis" 9 | ) 10 | 11 | func TestSubscribe(t *testing.T) { 12 | pub, err := redis.DialTimeout("tcp", "localhost:6379", time.Duration(10)*time.Second) 13 | if err != nil { 14 | t.Fatal(err) 15 | } 16 | 17 | client, err := redis.DialTimeout("tcp", "localhost:6379", time.Duration(10)*time.Second) 18 | if err != nil { 19 | t.Fatal(err) 20 | } 21 | sub := NewSubClient(client) 22 | 23 | channel := "subTestChannel" 24 | message := "Hello, World!" 25 | 26 | sr := sub.Subscribe(channel) 27 | if sr.Err != nil { 28 | t.Fatal(sr.Err) 29 | } 30 | 31 | if sr.Type != Subscribe { 32 | t.Fatal("Did not receive a subscribe reply") 33 | } 34 | 35 | if sr.SubCount != 1 { 36 | t.Fatal(fmt.Sprintf("Unexpected subscription count, Expected: 0, Found: %d", sr.SubCount)) 37 | } 38 | 39 | r := pub.Cmd("PUBLISH", channel, message) 40 | if r.Err != nil { 41 | t.Fatal(r.Err) 42 | } 43 | 44 | subChan := make(chan *SubResp) 45 | go func() { 46 | subChan <- sub.Receive() 47 | }() 48 | 49 | select { 50 | case sr = <-subChan: 51 | case <-time.After(time.Duration(10) * time.Second): 52 | t.Fatal("Took too long to Receive message") 53 | } 54 | 55 | if sr.Err != nil { 56 | t.Fatal(sr.Err) 57 | } 58 | 59 | if sr.Type != Message { 60 | t.Fatal("Did not receive a message reply") 61 | } 62 | 63 | if sr.Message != message { 64 | t.Fatal(fmt.Sprintf("Did not recieve expected message '%s', instead got: '%s'", message, sr.Message)) 65 | } 66 | 67 | sr = sub.Unsubscribe(channel) 68 | if sr.Err != nil { 69 | t.Fatal(sr.Err) 70 | } 71 | 72 | if sr.Type != Unsubscribe { 73 | t.Fatal("Did not receive a unsubscribe reply") 74 | } 75 | 76 | if sr.SubCount != 0 { 77 | t.Fatal(fmt.Sprintf("Unexpected subscription 
count, Expected: 0, Found: %d", sr.SubCount)) 78 | } 79 | } 80 | // TestPSubscribe subscribes to a pattern, publishes one message on a matching channel from a second connection, expects to receive it (with the pattern reported) within 10 seconds and then unsubscribes. It needs a redis instance on localhost:6379. 81 | func TestPSubscribe(t *testing.T) { 82 | pub, err := redis.DialTimeout("tcp", "localhost:6379", time.Duration(10)*time.Second) 83 | if err != nil { 84 | t.Fatal(err) 85 | } 86 | 87 | client, err := redis.DialTimeout("tcp", "localhost:6379", time.Duration(10)*time.Second) 88 | if err != nil { 89 | t.Fatal(err) 90 | } 91 | sub := NewSubClient(client) 92 | 93 | pattern := "patternThen*" 94 | message := "Hello, World!" 95 | 96 | sr := sub.PSubscribe(pattern) 97 | if sr.Err != nil { 98 | t.Fatal(sr.Err) 99 | } 100 | 101 | if sr.Type != Subscribe { 102 | t.Fatal("Did not receive a subscribe reply") 103 | } 104 | 105 | if sr.SubCount != 1 { 106 | t.Fatal(fmt.Sprintf("Unexpected subscription count, Expected: 1, Found: %d", sr.SubCount)) 107 | } 108 | 109 | r := pub.Cmd("PUBLISH", "patternThenHello", message) 110 | if r.Err != nil { 111 | t.Fatal(r.Err) 112 | } 113 | 114 | subChan := make(chan *SubResp) 115 | go func() { 116 | subChan <- sub.Receive() 117 | }() 118 | 119 | select { 120 | case sr = <-subChan: 121 | case <-time.After(time.Duration(10) * time.Second): 122 | t.Fatal("Took too long to Receive message") 123 | } 124 | 125 | if sr.Err != nil { 126 | t.Fatal(sr.Err) 127 | } 128 | 129 | if sr.Type != Message { 130 | t.Fatal("Did not receive a message reply") 131 | } 132 | 133 | if sr.Pattern != pattern { 134 | t.Fatal(fmt.Sprintf("Did not recieve expected pattern '%s', instead got: '%s'", pattern, sr.Pattern)) 135 | } 136 | 137 | if sr.Message != message { 138 | t.Fatal(fmt.Sprintf("Did not recieve expected message '%s', instead got: '%s'", message, sr.Message)) 139 | } 140 | 141 | sr = sub.PUnsubscribe(pattern) 142 | if sr.Err != nil { 143 | t.Fatal(sr.Err) 144 | } 145 | 146 | if sr.Type != Unsubscribe { 147 | t.Fatal("Did not receive a unsubscribe reply") 148 | } 149 | 150 | if sr.SubCount != 0 { 151 | t.Fatal(fmt.Sprintf("Unexpected subscription count, Expected: 0, Found: %d", sr.SubCount)) 152 
| } 153 | } 154 | -------------------------------------------------------------------------------- /cluster/cluster_test.go: -------------------------------------------------------------------------------- 1 | package cluster 2 | 3 | import ( 4 | "strings" 5 | . "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | 9 | "github.com/pl/radix.v2/pool" 10 | "github.com/pl/radix.v2/redis" 11 | ) 12 | 13 | // These tests assume there is a cluster running on ports 7000 and 7001, with 14 | // the first half of the slots assigned to 7000 and the second half assigned to 15 | // 7001. Calling `make up` inside of extra/cluster/testconfs will set this up 16 | // for you. This is automatically done if you are just calling `make test` in 17 | // the project root. 18 | // 19 | // It is also assumed that there is an unrelated redis instance on port 6379, 20 | // which will be connected to but not modified in any way 21 | 22 | func getCluster(t *T) *Cluster { 23 | cluster, err := New("127.0.0.1:7000") 24 | if err != nil { 25 | t.Fatal(err) 26 | } 27 | // Pretend there is no throttle initially, so tests can get at least one 28 | // more reset call 29 | cluster.resetThrottle.Stop() 30 | cluster.resetThrottle = nil 31 | return cluster 32 | } 33 | 34 | func TestReset(t *T) { 35 | // Simply initializing a cluster proves Reset works to some degree, since 36 | // NewCluster calls Reset 37 | cluster := getCluster(t) 38 | old7000Pool := cluster.pools["127.0.0.1:7000"] 39 | old7001Pool := cluster.pools["127.0.0.1:7001"] 40 | 41 | // We make a bogus client and add it to the cluster to prove that it gets 42 | // removed, since it's not needed 43 | p, err := pool.New("tcp", "127.0.0.1:6379", 10) 44 | assert.Nil(t, err) 45 | cluster.pools["127.0.0.1:6379"] = p 46 | 47 | // We use resetInnerUsingPool so that we can specifically specify the pool 48 | // being used, so we don't accidentally use the 6379 one (which doesn't have 49 | // CLUSTER commands) 50 | respCh := make(chan bool) 51 | 
cluster.callCh <- func(c *Cluster) { 52 | err := cluster.resetInnerUsingPool(old7000Pool) 53 | assert.Nil(t, err) 54 | respCh <- true 55 | } 56 | <-respCh 57 | 58 | // Prove that the bogus client is closed 59 | _, ok := cluster.pools["127.0.0.1:6379"] 60 | assert.Equal(t, false, ok) 61 | 62 | // Prove that the remaining two addresses are still in clients, were not 63 | // reconnected, and still work 64 | assert.Equal(t, 2, len(cluster.pools)) 65 | assert.Equal(t, old7000Pool, cluster.pools["127.0.0.1:7000"]) 66 | assert.Equal(t, old7001Pool, cluster.pools["127.0.0.1:7001"]) 67 | assert.Nil(t, cluster.Cmd("GET", "foo").Err) 68 | assert.Nil(t, cluster.Cmd("GET", "bar").Err) 69 | } 70 | 71 | func TestCmd(t *T) { 72 | cluster := getCluster(t) 73 | assert.Nil(t, cluster.Cmd("SET", "foo", "bar").Err) 74 | assert.Nil(t, cluster.Cmd("SET", "bar", "foo").Err) 75 | 76 | s, err := cluster.Cmd("GET", "foo").Str() 77 | assert.Nil(t, err) 78 | assert.Equal(t, "bar", s) 79 | 80 | s, err = cluster.Cmd("GET", "bar").Str() 81 | assert.Nil(t, err) 82 | assert.Equal(t, "foo", s) 83 | } 84 | 85 | func TestCmdMiss(t *T) { 86 | cluster := getCluster(t) 87 | // foo and bar are on different nodes in our configuration. We set foo to 88 | // something, then try to retrieve it with a client pointed at a different 89 | // node. It should be redirected and returned correctly 90 | 91 | assert.Nil(t, cluster.Cmd("SET", "foo", "baz").Err) 92 | 93 | barClient, err := cluster.GetForKey("bar") 94 | assert.Nil(t, err) 95 | 96 | args := []interface{}{"foo"} 97 | r := cluster.clientCmd(barClient, "GET", args, false, nil, false) 98 | s, err := r.Str() 99 | assert.Nil(t, err) 100 | assert.Equal(t, "baz", s) 101 | } 102 | 103 | // This one is kind of a toughy. We have to set a certain slot (which isn't 104 | // being used anywhere else) to be migrating, and test that it does the right 105 | // thing. 
We'll use a key which isn't set so that we don't have to actually 106 | // migrate the key to get an ASK response 107 | func TestCmdAsk(t *T) { 108 | cluster := getCluster(t) 109 | key := "wat" 110 | slot := CRC16([]byte(key)) % numSlots 111 | 112 | assert.Nil(t, cluster.Cmd("DEL", key).Err) 113 | 114 | // the key "wat" originally belongs on 7000 115 | src, err := cluster.getConn("", "127.0.0.1:7000") 116 | assert.Nil(t, err) 117 | dst, err := cluster.getConn("", "127.0.0.1:7001") 118 | assert.Nil(t, err) 119 | 120 | // We need the node ids. Unfortunately, this is the best way to get them 121 | nodes, err := src.Cmd("CLUSTER", "NODES").Str() 122 | assert.Nil(t, err) 123 | lines := strings.Split(nodes, "\n") 124 | var srcID, dstID string 125 | for _, line := range lines { 126 | id := strings.Split(line, " ")[0] 127 | if id == "" { 128 | continue 129 | } 130 | if strings.Index(line, "myself,") > -1 { 131 | srcID = id 132 | } else { 133 | dstID = id 134 | } 135 | } 136 | 137 | // Start the "migration" 138 | assert.Nil(t, dst.Cmd("CLUSTER", "SETSLOT", slot, "IMPORTING", srcID).Err) 139 | assert.Nil(t, src.Cmd("CLUSTER", "SETSLOT", slot, "MIGRATING", dstID).Err) 140 | 141 | // Make sure we can still "get" the value 142 | assert.Equal(t, true, cluster.Cmd("GET", key).IsType(redis.Nil)) 143 | 144 | // Bail on the migration TODO this doesn't totally bail for some reason 145 | assert.Nil(t, dst.Cmd("CLUSTER", "SETSLOT", slot, "NODE", srcID).Err) 146 | assert.Nil(t, src.Cmd("CLUSTER", "SETSLOT", slot, "NODE", srcID).Err) 147 | } 148 | -------------------------------------------------------------------------------- /pubsub/sub.go: -------------------------------------------------------------------------------- 1 | // Package pubsub provides a wrapper around a normal redis client which makes 2 | // interacting with publish/subscribe commands much easier 3 | package pubsub 4 | 5 | import ( 6 | "container/list" 7 | "errors" 8 | "fmt" 9 | 10 | "github.com/pl/radix.v2/redis" 11 | 
) 12 | 13 | // SubRespType describes the type of the response being returned from one of 14 | // the methods in this package 15 | type SubRespType uint8 16 | 17 | // The different kinds of SubRespTypes 18 | const ( 19 | Error SubRespType = iota 20 | Subscribe 21 | Unsubscribe 22 | Message 23 | ) 24 | 25 | // SubClient wraps a Redis client to provide convenience methods for Pub/Sub 26 | // functionality. 27 | type SubClient struct { 28 | Client *redis.Client 29 | messages *list.List 30 | } 31 | 32 | // SubResp wraps a Redis resp and provides convenient access to Pub/Sub info. 33 | type SubResp struct { 34 | *redis.Resp // Original Redis resp 35 | 36 | Type SubRespType 37 | Channel string // Channel resp is on (Subscribe, Unsubscribe or Message) 38 | Pattern string // Pattern which was matched for publishes captured by a PSubscribe 39 | SubCount int // Count of subs active after this action (Subscribe or Unsubscribe) 40 | Message string // Publish message (Message) 41 | Err error // SubResp error (Error) 42 | } 43 | 44 | // Timeout determines if this SubResp is an error type 45 | // due to a timeout reading from the network 46 | func (r *SubResp) Timeout() bool { 47 | return redis.IsTimeout(r.Resp) 48 | } 49 | 50 | // NewSubClient takes an existing, connected redis.Client and wraps it in a 51 | // SubClient, returning that. The passed in redis.Client should not be used as 52 | // long as the SubClient is also being used 53 | func NewSubClient(client *redis.Client) *SubClient { 54 | return &SubClient{client, &list.List{}} 55 | } 56 | 57 | // Subscribe sends a Redis "SUBSCRIBE" command for the provided channels. It only writes the command and returns any write error; replies are read with Receive 58 | func (c *SubClient) Subscribe(channels ...interface{}) error { 59 | return c.Client.SendCmd("SUBSCRIBE", channels...) 60 | } 61 | 62 | // PSubscribe sends a Redis "PSUBSCRIBE" command for the provided patterns. It only writes the command and returns any write error; replies are read with Receive 63 | func (c *SubClient) PSubscribe(patterns ...interface{}) error { 64 | return c.Client.SendCmd("PSUBSCRIBE", patterns...) 
65 | } 66 | 67 | // Unsubscribe makes a Redis "UNSUBSCRIBE" command on the provided channels 68 | func (c *SubClient) Unsubscribe(channels ...interface{}) error { 69 | return c.Client.SendCmd("UNSUBSCRIBE", channels...) 70 | } 71 | 72 | // PUnsubscribe makes a Redis "PUNSUBSCRIBE" command on the provided patterns 73 | func (c *SubClient) PUnsubscribe(patterns ...interface{}) error { 74 | return c.Client.SendCmd("PUNSUBSCRIBE", patterns...) 75 | } 76 | 77 | // Receive returns the next publish resp on the Redis client. It is possible 78 | // Receive will timeout, and the *SubResp will be an Error. You can use the 79 | // Timeout() method on SubResp to easily determine if that is the case. If this 80 | // is the case you can call Receive again to continue listening for publishes 81 | func (c *SubClient) Receive() *SubResp { 82 | return c.receive(false) 83 | } 84 | 85 | func (c *SubClient) receive(skipBuffer bool) *SubResp { 86 | if c.messages.Len() > 0 && !skipBuffer { 87 | v := c.messages.Remove(c.messages.Front()) 88 | return v.(*SubResp) 89 | } 90 | r := c.Client.ReadResp() 91 | return c.parseResp(r) 92 | } 93 | 94 | func (c *SubClient) parseResp(resp *redis.Resp) *SubResp { 95 | sr := &SubResp{Resp: resp} 96 | var elems []*redis.Resp 97 | 98 | switch { 99 | case resp.IsType(redis.Array): 100 | elems, _ = resp.Array() 101 | if len(elems) < 3 { 102 | sr.Err = errors.New("resp is not formatted as a subscription resp") 103 | sr.Type = Error 104 | return sr 105 | } 106 | 107 | case resp.IsType(redis.Err): 108 | sr.Err = resp.Err 109 | sr.Type = Error 110 | return sr 111 | 112 | default: 113 | sr.Err = errors.New("resp is not formatted as a subscription resp") 114 | sr.Type = Error 115 | return sr 116 | } 117 | 118 | rtype, err := elems[0].Str() 119 | if err != nil { 120 | sr.Err = fmt.Errorf("resp type: %s", err) 121 | sr.Type = Error 122 | return sr 123 | } 124 | 125 | //first element 126 | switch rtype { 127 | case "subscribe", "psubscribe": 128 | sr.Type = Subscribe 
129 | channel, channelErr := elems[1].Str() 130 | if channelErr != nil { 131 | sr.Err = fmt.Errorf("subscribe channel: %s", channelErr) 132 | sr.Type = Error 133 | return sr 134 | } else { 135 | sr.Channel = channel 136 | } 137 | count, countErr := elems[2].Int() 138 | if countErr != nil { 139 | sr.Err = fmt.Errorf("subscribe count: %s", countErr) // countErr, not err: err is the (nil) result of reading the reply type 140 | sr.Type = Error 141 | return sr 142 | } else { 143 | sr.SubCount = int(count) 144 | } 145 | 146 | case "unsubscribe", "punsubscribe": 147 | sr.Type = Unsubscribe 148 | sr.Channel = elems[1].String() // NOTE(review): the subscribe branch reads the channel with Str(); confirm String() yields the bare channel name here 149 | count, err := elems[2].Int() 150 | if err != nil { 151 | sr.Err = fmt.Errorf("unsubscribe count: %s", err) 152 | sr.Type = Error 153 | } else { 154 | sr.SubCount = int(count) 155 | } 156 | 157 | case "message", "pmessage": 158 | var chanI, msgI int 159 | 160 | if rtype == "message" { 161 | chanI, msgI = 1, 2 162 | } else { // "pmessage" 163 | // A pmessage reply has four elements (type, pattern, channel, 164 | // message); only a length of three was verified earlier, so check 165 | // again before elems[3] is indexed below 166 | if len(elems) < 4 { 167 | sr.Err = errors.New("resp is not formatted as a subscription resp") 168 | sr.Type = Error 169 | return sr 170 | } 171 | chanI, msgI = 2, 3 172 | pattern, err := elems[1].Str() 173 | if err != nil { 174 | sr.Err = fmt.Errorf("message pattern: %s", err) 175 | sr.Type = Error 176 | return sr 177 | } 178 | sr.Pattern = pattern 179 | } 180 | 181 | sr.Type = Message 182 | channel, err := elems[chanI].Str() 183 | if err != nil { 184 | sr.Err = fmt.Errorf("message channel: %s", err) 185 | sr.Type = Error 186 | return sr 187 | } 188 | sr.Channel = channel 189 | msg, err := elems[msgI].Str() 190 | if err != nil { 191 | sr.Err = fmt.Errorf("message msg: %s", err) 192 | sr.Type = Error 193 | } else { 194 | sr.Message = msg 195 | } 196 | default: 197 | sr.Err = errors.New("suscription multiresp has invalid type: " + rtype) 198 | sr.Type = Error 199 | } 200 | return sr 201 | } 202 | -------------------------------------------------------------------------------- /redis/resp_test.go: -------------------------------------------------------------------------------- 1 | package redis 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "fmt" 7 | . 
"testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func pretendRead(s string) *Resp { 13 | buf := bytes.NewBufferString(s) 14 | return NewRespReader(buf).Read() 15 | } 16 | 17 | func TestRead(t *T) { 18 | 19 | r := pretendRead("") 20 | assert.NotNil(t, r.Err) 21 | 22 | // Simple string 23 | r = pretendRead("+ohey\r\n") 24 | assert.Equal(t, SimpleStr, r.typ) 25 | assert.Exactly(t, []byte("ohey"), r.val) 26 | s, err := r.Str() 27 | assert.Nil(t, err) 28 | assert.Equal(t, "ohey", s) 29 | 30 | // Empty simple string 31 | r = pretendRead("+\r\n") 32 | assert.Equal(t, SimpleStr, r.typ) 33 | assert.Exactly(t, []byte(""), r.val) 34 | s, err = r.Str() 35 | assert.Nil(t, err) 36 | assert.Equal(t, "", s) 37 | 38 | // Error 39 | r = pretendRead("-ohey\r\n") 40 | assert.Equal(t, AppErr, r.typ) 41 | assert.Exactly(t, errors.New("ohey"), r.val) 42 | assert.Equal(t, "ohey", r.Err.Error()) 43 | 44 | // Empty error 45 | r = pretendRead("-\r\n") 46 | assert.Equal(t, AppErr, r.typ) 47 | assert.Exactly(t, errors.New(""), r.val) 48 | assert.Equal(t, "", r.Err.Error()) 49 | 50 | // Int 51 | r = pretendRead(":1024\r\n") 52 | assert.Equal(t, Int, r.typ) 53 | assert.Exactly(t, int64(1024), r.val) 54 | i, err := r.Int() 55 | assert.Nil(t, err) 56 | assert.Equal(t, 1024, i) 57 | 58 | // Bulk string 59 | r = pretendRead("$3\r\nfoo\r\n") 60 | assert.Equal(t, BulkStr, r.typ) 61 | assert.Exactly(t, []byte("foo"), r.val) 62 | s, err = r.Str() 63 | assert.Nil(t, err) 64 | assert.Equal(t, "foo", s) 65 | 66 | // Empty bulk string 67 | r = pretendRead("$0\r\n\r\n") 68 | assert.Equal(t, BulkStr, r.typ) 69 | assert.Exactly(t, []byte(""), r.val) 70 | s, err = r.Str() 71 | assert.Nil(t, err) 72 | assert.Equal(t, "", s) 73 | 74 | // Nil bulk string 75 | r = pretendRead("$-1\r\n") 76 | assert.Equal(t, Nil, r.typ) 77 | 78 | // Array 79 | r = pretendRead("*2\r\n+foo\r\n+bar\r\n") 80 | assert.Equal(t, Array, r.typ) 81 | assert.Equal(t, 2, len(r.val.([]Resp))) 82 | assert.Equal(t, 
SimpleStr, r.val.([]Resp)[0].typ) 83 | assert.Exactly(t, []byte("foo"), r.val.([]Resp)[0].val) 84 | assert.Equal(t, SimpleStr, r.val.([]Resp)[1].typ) 85 | assert.Exactly(t, []byte("bar"), r.val.([]Resp)[1].val) 86 | l, err := r.List() 87 | assert.Nil(t, err) 88 | assert.Equal(t, []string{"foo", "bar"}, l) 89 | b, err := r.ListBytes() 90 | assert.Nil(t, err) 91 | assert.Equal(t, [][]byte{[]byte("foo"), []byte("bar")}, b) 92 | m, err := r.Map() 93 | assert.Nil(t, err) 94 | assert.Equal(t, map[string]string{"foo": "bar"}, m) 95 | 96 | // Empty Array 97 | r = pretendRead("*0\r\n") 98 | assert.Equal(t, Array, r.typ) 99 | assert.Equal(t, 0, len(r.val.([]Resp))) 100 | 101 | // Nil Array 102 | r = pretendRead("*-1\r\n") 103 | assert.Equal(t, Nil, r.typ) 104 | 105 | // Embedded Array 106 | r = pretendRead("*3\r\n+foo\r\n+bar\r\n*2\r\n+foo\r\n+bar\r\n") 107 | assert.Equal(t, Array, r.typ) 108 | assert.Equal(t, 3, len(r.val.([]Resp))) 109 | assert.Equal(t, SimpleStr, r.val.([]Resp)[0].typ) 110 | assert.Exactly(t, []byte("foo"), r.val.([]Resp)[0].val) 111 | assert.Equal(t, SimpleStr, r.val.([]Resp)[1].typ) 112 | assert.Exactly(t, []byte("bar"), r.val.([]Resp)[1].val) 113 | r = &r.val.([]Resp)[2] 114 | assert.Equal(t, 2, len(r.val.([]Resp))) 115 | assert.Equal(t, SimpleStr, r.val.([]Resp)[0].typ) 116 | assert.Exactly(t, []byte("foo"), r.val.([]Resp)[0].val) 117 | assert.Equal(t, SimpleStr, r.val.([]Resp)[1].typ) 118 | assert.Exactly(t, []byte("bar"), r.val.([]Resp)[1].val) 119 | 120 | // Test that two bulks in a row read correctly 121 | r = pretendRead("*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n") 122 | assert.Equal(t, Array, r.typ) 123 | assert.Equal(t, 2, len(r.val.([]Resp))) 124 | assert.Equal(t, BulkStr, r.val.([]Resp)[0].typ) 125 | assert.Exactly(t, []byte("foo"), r.val.([]Resp)[0].val) 126 | assert.Equal(t, BulkStr, r.val.([]Resp)[1].typ) 127 | assert.Exactly(t, []byte("bar"), r.val.([]Resp)[1].val) 128 | } 129 | 130 | type arbitraryTest struct { 131 | val interface{} 132 | expect 
[]byte 133 | } 134 | 135 | var nilResp = pretendRead("$-1\r\n") 136 | 137 | var arbitraryTests = []arbitraryTest{ 138 | {[]byte("OHAI"), []byte("$4\r\nOHAI\r\n")}, 139 | {"OHAI", []byte("$4\r\nOHAI\r\n")}, 140 | {true, []byte("$1\r\n1\r\n")}, 141 | {false, []byte("$1\r\n0\r\n")}, 142 | {nil, []byte("$-1\r\n")}, 143 | {80, []byte(":80\r\n")}, 144 | {int64(-80), []byte(":-80\r\n")}, 145 | {uint64(80), []byte(":80\r\n")}, 146 | {float32(0.1234), []byte("$6\r\n0.1234\r\n")}, 147 | {float64(0.1234), []byte("$6\r\n0.1234\r\n")}, 148 | {errors.New("hi"), []byte("-hi\r\n")}, 149 | 150 | {nilResp, []byte("$-1\r\n")}, 151 | 152 | {[]int{1, 2, 3}, []byte("*3\r\n:1\r\n:2\r\n:3\r\n")}, 153 | {map[int]int{1: 2}, []byte("*2\r\n:1\r\n:2\r\n")}, 154 | 155 | {NewRespSimple("OK"), []byte("+OK\r\n")}, 156 | } 157 | 158 | var arbitraryAsStringTests = []arbitraryTest{ 159 | {[]byte("OHAI"), []byte("$4\r\nOHAI\r\n")}, 160 | {"OHAI", []byte("$4\r\nOHAI\r\n")}, 161 | {true, []byte("$1\r\n1\r\n")}, 162 | {false, []byte("$1\r\n0\r\n")}, 163 | {nil, []byte("$0\r\n\r\n")}, 164 | {80, []byte("$2\r\n80\r\n")}, 165 | {int64(-80), []byte("$3\r\n-80\r\n")}, 166 | {uint64(80), []byte("$2\r\n80\r\n")}, 167 | {float32(0.1234), []byte("$6\r\n0.1234\r\n")}, 168 | {float64(0.1234), []byte("$6\r\n0.1234\r\n")}, 169 | {errors.New("hi"), []byte("$2\r\nhi\r\n")}, 170 | 171 | {nilResp, []byte("$-1\r\n")}, 172 | 173 | {[]int{1, 2, 3}, []byte("*3\r\n$1\r\n1\r\n$1\r\n2\r\n$1\r\n3\r\n")}, 174 | {map[int]int{1: 2}, []byte("*2\r\n$1\r\n1\r\n$1\r\n2\r\n")}, 175 | 176 | {NewRespSimple("OK"), []byte("+OK\r\n")}, 177 | } 178 | 179 | var arbitraryAsFlattenedStringsTests = []arbitraryTest{ 180 | { 181 | []interface{}{"wat", map[string]interface{}{ 182 | "foo": 1, 183 | }}, 184 | []byte("*3\r\n$3\r\nwat\r\n$3\r\nfoo\r\n$1\r\n1\r\n"), 185 | }, 186 | } 187 | 188 | func TestWriteArbitrary(t *T) { 189 | var err error 190 | buf := bytes.NewBuffer([]byte{}) 191 | for _, test := range arbitraryTests { 192 | buf.Reset() 193 | _, 
err = NewResp(test.val).WriteTo(buf) 194 | assert.Nil(t, err) 195 | assert.Equal(t, test.expect, buf.Bytes()) 196 | } 197 | } 198 | 199 | func TestWriteArbitraryAsFlattenedStrings(t *T) { 200 | var err error 201 | buf := bytes.NewBuffer([]byte{}) 202 | for _, test := range arbitraryAsFlattenedStringsTests { 203 | buf.Reset() 204 | _, err = NewRespFlattenedStrings(test.val).WriteTo(buf) 205 | assert.Nil(t, err) 206 | assert.Equal(t, test.expect, buf.Bytes()) 207 | } 208 | } 209 | 210 | func TestFloat64(t *T) { 211 | r := NewResp(4) 212 | _, err := r.Float64() 213 | assert.NotNil(t, err) 214 | 215 | testErr := fmt.Errorf("test") 216 | r = NewResp(testErr) 217 | _, err = r.Float64() 218 | assert.Equal(t, testErr, err) 219 | 220 | r = NewResp("test") 221 | _, err = r.Float64() 222 | assert.NotNil(t, err) 223 | 224 | r = NewResp("5.0") 225 | f, err := r.Float64() 226 | assert.Nil(t, err) 227 | assert.Equal(t, float64(5.0), f) 228 | 229 | } 230 | -------------------------------------------------------------------------------- /redis/client.go: -------------------------------------------------------------------------------- 1 | package redis 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "fmt" 7 | "net" 8 | "reflect" 9 | "time" 10 | ) 11 | 12 | const ( 13 | bufSize int = 4096 14 | ) 15 | 16 | // ErrPipelineEmpty is returned from PipeResp() to indicate that all commands 17 | // which were put into the pipeline have had their responses read 18 | var ErrPipelineEmpty = errors.New("pipeline queue empty") 19 | 20 | // Client describes a Redis client. 21 | type Client struct { 22 | conn net.Conn 23 | respReader *RespReader 24 | timeout time.Duration 25 | pending []request 26 | writeScratch []byte 27 | writeBuf *bytes.Buffer 28 | 29 | completed, completedHead []*Resp 30 | 31 | // The network/address of the redis instance this client is connected to. 
32 | // These will be whatever strings were passed into the Dial function when 33 | // creating this connection 34 | Network, Addr string 35 | 36 | // The most recent critical network error which occurred when either reading 37 | // or writing. A critical network error is one in which the connection was 38 | // found to be no longer usable; in essence, any error except a timeout. 39 | // Close is automatically called on the client when it encounters a critical 40 | // network error 41 | LastCritical error 42 | } 43 | 44 | // request describes a client's request to the redis server 45 | type request struct { 46 | cmd string 47 | args []interface{} 48 | } 49 | 50 | // DialTimeout connects to the given Redis server with the given timeout, which 51 | // is used for dialing and as the read/write timeout when communicating with redis 52 | func DialTimeout(network, addr string, timeout time.Duration) (*Client, error) { 53 | // establish a connection 54 | conn, err := net.DialTimeout(network, addr, timeout) 55 | if err != nil { 56 | return nil, err 57 | } 58 | 59 | completed := make([]*Resp, 0, 10) 60 | return &Client{ 61 | conn: conn, 62 | respReader: NewRespReader(conn), 63 | timeout: timeout, 64 | writeScratch: make([]byte, 0, 128), 65 | writeBuf: bytes.NewBuffer(make([]byte, 0, 128)), 66 | completed: completed, 67 | completedHead: completed, 68 | Network: network, 69 | Addr: addr, 70 | }, nil 71 | } 72 | 73 | // Dial connects to the given Redis server, passing a zero timeout to DialTimeout. 74 | func Dial(network, addr string) (*Client, error) { 75 | return DialTimeout(network, addr, time.Duration(0)) 76 | } 77 | 78 | // Close closes the connection. 79 | func (c *Client) Close() error { 80 | return c.conn.Close() 81 | } 82 | 83 | // SendCmd sends the command, but does not wait for a response. 
84 | func (c *Client) SendCmd(cmd string, args ...interface{}) error { 85 | err := c.writeRequest(request{cmd, args}) 86 | if err != nil { 87 | return err 88 | } 89 | return nil 90 | } 91 | 92 | // Cmd sends the command and waits for a response. 93 | func (c *Client) Cmd(cmd string, args ...interface{}) *Resp { 94 | err := c.SendCmd(cmd, args...) 95 | if err != nil { 96 | return newRespIOErr(err) 97 | } 98 | return c.ReadResp() 99 | } 100 | 101 | // PipeAppend adds the given call to the pipeline queue. 102 | // Use PipeResp() to read the response. 103 | func (c *Client) PipeAppend(cmd string, args ...interface{}) { 104 | c.pending = append(c.pending, request{cmd, args}) 105 | } 106 | 107 | // PipeResp returns the reply for the next request in the pipeline queue. Err 108 | // with ErrPipelineEmpty is returned if the pipeline queue is empty. 109 | func (c *Client) PipeResp() *Resp { 110 | if len(c.completed) > 0 { 111 | r := c.completed[0] 112 | c.completed = c.completed[1:] 113 | return r 114 | } 115 | 116 | if len(c.pending) == 0 { 117 | return NewResp(ErrPipelineEmpty) 118 | } 119 | 120 | nreqs := len(c.pending) 121 | err := c.writeRequest(c.pending...) 122 | c.pending = nil 123 | if err != nil { 124 | return newRespIOErr(err) 125 | } 126 | c.completed = c.completedHead // rewind to the empty slice kept from DialTimeout so append reuses its backing array 127 | for i := 0; i < nreqs; i++ { 128 | r := c.ReadResp() 129 | c.completed = append(c.completed, r) 130 | } 131 | 132 | // At this point c.completed should have something in it 133 | return c.PipeResp() 134 | } 135 | 136 | // ReadResp will read a Resp off of the connection without sending anything 137 | // first (useful after you've sent a SUBSCRIBE command). This will block until 138 | // a reply is received or the timeout is reached (returning the IOErr). 
var errBadCmdNoKey = errors.New("bad command, no key")

// KeyFromArgs is a helper function which other library packages which wrap this
// one might find useful. It takes in a set of arguments which might be passed
// into Cmd and returns the first key for the command. Since radix supports
// complicated arguments (like slices, slices of slices, maps, etc...) this is
// not always as straightforward as it might seem, so this helper function is
// provided.
//
// Only the first argument is inspected:
//   - a string or []byte is returned as-is
//   - a slice is descended into, and the key of its first element is returned
//   - any other non-map value is rendered with fmt.Sprint
//
// An error is returned if no key can be determined: when there are no
// arguments, when the first argument is nil, an empty slice, or a map (maps
// have no order, so no first key exists).
func KeyFromArgs(args ...interface{}) (string, error) {
	if len(args) == 0 {
		return "", errBadCmdNoKey
	}

	switch first := args[0].(type) {
	case nil:
		// A nil interface has no dynamic type: reflect.TypeOf returns a nil
		// Type for it and calling Kind on that panics. Treat it as "no key".
		return "", errBadCmdNoKey
	case string:
		return first, nil
	case []byte:
		return string(first), nil
	default:
		val := reflect.ValueOf(first)
		switch val.Kind() {
		case reflect.Slice:
			if val.Len() < 1 {
				return "", errBadCmdNoKey
			}
			// The element may itself be a slice (or nil), so recurse
			return KeyFromArgs(val.Index(0).Interface())
		case reflect.Map:
			// Maps have no order, we can't possibly choose a key out of one
			return "", errBadCmdNoKey
		default:
			return fmt.Sprint(first), nil
		}
	}
}
3 | // 4 | // Here's an example of creating a sentinel client and then using it to perform 5 | // some commands 6 | // 7 | // func example() error { 8 | // // If there exists sentinel masters "bucket0" and "bucket1", and we want 9 | // // out client to create pools for both: 10 | // client, err := sentinel.NewClient("tcp", "localhost:6379", 100, "bucket0", "bucket1") 11 | // if err != nil { 12 | // return err 13 | // } 14 | // 15 | // if err := exampleCmd(client); err != nil { 16 | // return err 17 | // } 18 | // 19 | // return nil 20 | // } 21 | // 22 | // func exampleCmd(client *sentinel.Client) error { 23 | // conn, err := client.GetMaster("bucket0") 24 | // if err != nil { 25 | // return redisErr 26 | // } 27 | // defer client.PutMaster("bucket0", conn) 28 | // 29 | // i, err := conn.Cmd("GET", "foo").Int() 30 | // if err != nil { 31 | // return err 32 | // } 33 | // 34 | // if err := conn.Cmd("SET", "foo", i+1); err != nil { 35 | // return err 36 | // } 37 | // 38 | // return nil 39 | // } 40 | // 41 | // This package only gaurantees that when GetMaster is called the returned 42 | // connection will be a connection to the master as of the moment that method is 43 | // called. It is still possible that there is a failover as that connection is 44 | // being used by the application. 45 | // 46 | // As a final note, a Client can be interacted with from multiple routines at 47 | // once safely, except for the Close method. To safely Close, ensure that only 48 | // one routine ever makes the call and that once the call is made no other 49 | // methods are ever called by any routines. 50 | package sentinel 51 | 52 | import ( 53 | "errors" 54 | "strings" 55 | 56 | "github.com/pl/radix.v2/pool" 57 | "github.com/pl/radix.v2/pubsub" 58 | "github.com/pl/radix.v2/redis" 59 | ) 60 | 61 | // ClientError is an error wrapper returned by operations in this package. It 62 | // implements the error interface and can therefore be passed around as a normal 63 | // error. 
// ClientError is an error wrapper returned by operations in this package. It
// implements the error interface and can therefore be passed around as a normal
// error. The wrapped error is reachable through Unwrap, so errors.Is and
// errors.As see through the wrapper.
type ClientError struct {
	err error

	// If this is true the error is due to a problem with the sentinel
	// connection, either it being closed or otherwise unavailable. If false the
	// error is due to some other circumstances. This is useful if you want to
	// implement some kind of reconnecting to sentinel on an error.
	SentinelErr bool
}

// Error implements the error protocol. It returns the message of the wrapped
// error, or a generic message if there is no wrapped error (for example when
// the struct was built outside this package, where err cannot be set).
func (ce *ClientError) Error() string {
	if ce == nil || ce.err == nil {
		return "sentinel client error"
	}
	return ce.err.Error()
}

// Unwrap returns the wrapped error, or nil if there is none.
func (ce *ClientError) Unwrap() error {
	if ce == nil {
		return nil
	}
	return ce.err
}
120 | func NewClient( 121 | network, address string, poolSize int, names ...string, 122 | ) ( 123 | *Client, error, 124 | ) { 125 | 126 | // We use this to fetch initial details about masters before we upgrade it 127 | // to a pubsub client 128 | client, err := redis.Dial(network, address) 129 | if err != nil { 130 | return nil, &ClientError{err: err} 131 | } 132 | 133 | masterPools := map[string]*pool.Pool{} 134 | for _, name := range names { 135 | r := client.Cmd("SENTINEL", "MASTER", name) 136 | l, err := r.List() 137 | if err != nil { 138 | return nil, &ClientError{err: err, SentinelErr: true} 139 | } 140 | addr := l[3] + ":" + l[5] 141 | pool, err := pool.New("tcp", addr, poolSize) 142 | if err != nil { 143 | return nil, &ClientError{err: err} 144 | } 145 | masterPools[name] = pool 146 | } 147 | 148 | subClient := pubsub.NewSubClient(client) 149 | subSendErr := subClient.Subscribe("+switch-master") 150 | if subSendErr != nil { 151 | return nil, &ClientError{err: subSendErr, SentinelErr: true} 152 | } 153 | r := client.ReadResp() 154 | if r.Err != nil { 155 | return nil, &ClientError{err: r.Err, SentinelErr: true} 156 | } 157 | 158 | c := &Client{ 159 | poolSize: poolSize, 160 | masterPools: masterPools, 161 | subClient: subClient, 162 | getCh: make(chan *getReq), 163 | putCh: make(chan *putReq), 164 | closeCh: make(chan struct{}), 165 | alwaysErrCh: make(chan *ClientError), 166 | switchMasterCh: make(chan *switchMaster), 167 | } 168 | 169 | go c.subSpin() 170 | go c.spin() 171 | return c, nil 172 | } 173 | 174 | func (c *Client) subSpin() { 175 | for { 176 | r := c.subClient.Receive() 177 | if r.Timeout() { 178 | continue 179 | } 180 | if r.Err != nil { 181 | select { 182 | case c.alwaysErrCh <- &ClientError{err: r.Err, SentinelErr: true}: 183 | case <-c.closeCh: 184 | } 185 | return 186 | } 187 | sMsg := strings.Split(r.Message, " ") 188 | name := sMsg[0] 189 | newAddr := sMsg[3] + ":" + sMsg[4] 190 | select { 191 | case c.switchMasterCh <- 
// spin is the client's main loop; NewClient starts it in its own goroutine.
// Every read and write of masterPools and alwaysErr in this file happens here,
// driven by messages on the client's channels:
//
//   - getCh: hand out a connection from the named master's pool
//   - putCh: return a connection to the named master's pool
//   - alwaysErrCh: record an error which every later get request receives
//   - switchMasterCh: replace the pool of a master which has failed over
//   - closeCh: release everything and stop
func (c *Client) spin() {
	for {
		select {
		case req := <-c.getCh:
			// Once alwaysErr has been recorded every get request fails with it
			if c.alwaysErr != nil {
				req.retCh <- &getReqRet{nil, c.alwaysErr}
				continue
			}
			pool, ok := c.masterPools[req.name]
			if !ok {
				err := errors.New("unknown name: " + req.name)
				req.retCh <- &getReqRet{nil, &ClientError{err: err}}
				continue
			}
			conn, err := pool.Get()
			if err != nil {
				req.retCh <- &getReqRet{nil, &ClientError{err: err}}
				continue
			}
			req.retCh <- &getReqRet{conn, nil}

		case req := <-c.putCh:
			// Connections put back under an unknown name are silently dropped
			// (they are neither pooled nor closed here)
			if pool, ok := c.masterPools[req.name]; ok {
				pool.Put(req.conn)
			}

		case err := <-c.alwaysErrCh:
			c.alwaysErr = err

		case sm := <-c.switchMasterCh:
			// Only masters this client was created with are tracked; switch
			// notifications for any other name are ignored
			if p, ok := c.masterPools[sm.name]; ok {
				p.Empty()
				// NOTE(review): the error from pool.New is discarded and
				// whatever pool it returned is stored as-is. Presumably
				// pool.New returns a usable pool even on error; confirm
				// against the pool package.
				p, _ = pool.New("tcp", sm.addr, c.poolSize)
				c.masterPools[sm.name] = p
			}

		case <-c.closeCh:
			for name := range c.masterPools {
				c.masterPools[name].Empty()
			}
			c.subClient.Client.Close()
			// After this point a send on getCh or putCh (GetMaster/PutMaster)
			// panics, since sending on a closed channel panics
			close(c.getCh)
			close(c.putCh)
			return
		}
	}
}
249 | func (c *Client) GetMaster(name string) (*redis.Client, error) { 250 | req := getReq{name, make(chan *getReqRet)} 251 | c.getCh <- &req 252 | ret := <-req.retCh 253 | if ret.err != nil { 254 | return nil, ret.err 255 | } 256 | return ret.conn, nil 257 | } 258 | 259 | // PutMaster return a connection for a master of a given name 260 | func (c *Client) PutMaster(name string, client *redis.Client) { 261 | c.putCh <- &putReq{name, client} 262 | } 263 | -------------------------------------------------------------------------------- /cluster/cluster.go: -------------------------------------------------------------------------------- 1 | // Package cluster implements an almost drop-in replacement for a normal Client 2 | // which accounts for a redis cluster setup. It will transparently redirect 3 | // requests to the correct nodes, as well as keep track of which slots are 4 | // mapped to which nodes and updating them accordingly so requests can remain as 5 | // fast as possible. 6 | // 7 | // This package will initially call `cluster slots` in order to retrieve an 8 | // initial idea of the topology of the cluster, but other than that will not 9 | // make any other extraneous calls. 
10 | // 11 | // All methods on a Cluster are thread-safe, and connections are automatically 12 | // pooled 13 | package cluster 14 | 15 | import ( 16 | "errors" 17 | "fmt" 18 | "strconv" 19 | "strings" 20 | "time" 21 | 22 | "github.com/pl/radix.v2/pool" 23 | "github.com/pl/radix.v2/redis" 24 | ) 25 | 26 | const numSlots = 16384 27 | 28 | type mapping [numSlots]string 29 | 30 | func errorResp(err error) *redis.Resp { 31 | return redis.NewResp(err) 32 | } 33 | 34 | func errorRespf(format string, args ...interface{}) *redis.Resp { 35 | return errorResp(fmt.Errorf(format, args...)) 36 | } 37 | 38 | var ( 39 | // ErrBadCmdNoKey is an error reply returned when no key is given to the Cmd 40 | // method 41 | ErrBadCmdNoKey = errors.New("bad command, no key") 42 | 43 | errNoPools = errors.New("no pools to pull from") 44 | ) 45 | 46 | // Cluster wraps a Client and accounts for all redis cluster logic 47 | type Cluster struct { 48 | o Opts 49 | mapping 50 | pools map[string]*pool.Pool 51 | poolThrottles map[string]<-chan time.Time 52 | resetThrottle *time.Ticker 53 | callCh chan func(*Cluster) 54 | stopCh chan struct{} 55 | 56 | // This is written to whenever a slot miss (either a MOVED or ASK) is 57 | // encountered. This is mainly for informational purposes, it's not meant to 58 | // be actionable. If nothing is listening the message is dropped 59 | MissCh chan struct{} 60 | 61 | // This is written to whenever the cluster discovers there's been some kind 62 | // of re-ordering/addition/removal of cluster nodes. If nothing is listening 63 | // the message is dropped 64 | ChangeCh chan struct{} 65 | } 66 | 67 | // Opts are Options which can be passed in to NewWithOpts. If any are set to 68 | // their zero value the default value will be used instead 69 | type Opts struct { 70 | 71 | // Required. The address of a single node in the cluster 72 | Addr string 73 | 74 | // Read and write timeout which should be used on individual redis clients. 
75 | // Default is to not set the timeout and let the connection use it's default 76 | Timeout time.Duration 77 | 78 | // The size of the connection pool to use for each host. Default is 10 79 | PoolSize int 80 | 81 | // The time which must elapse between subsequent calls to create a new 82 | // connection pool (on a per redis instance basis) in certain circumstances. 83 | // The default is 500 milliseconds 84 | PoolThrottle time.Duration 85 | 86 | // The time which must elapse between subsequent calls to Reset(). The 87 | // default is 500 milliseconds 88 | ResetThrottle time.Duration 89 | } 90 | 91 | // New will perform the following steps to initialize: 92 | // 93 | // - Connect to the node given in the argument 94 | // 95 | // - Use that node to call CLUSTER SLOTS. The return from this is used to build 96 | // a mapping of slot number -> connection. At the same time any new connections 97 | // which need to be made are created here. 98 | // 99 | // - *Cluster is returned 100 | // 101 | // At this point the Cluster has a complete view of the cluster's topology and 102 | // can immediately start performing commands with (theoretically) zero slot 103 | // misses 104 | func New(addr string) (*Cluster, error) { 105 | return NewWithOpts(Opts{ 106 | Addr: addr, 107 | }) 108 | } 109 | 110 | // NewWithOpts is the same as NewCluster, but with more fine-tuned 111 | // configuration options. 
See Opts for more available options 112 | func NewWithOpts(o Opts) (*Cluster, error) { 113 | if o.PoolSize == 0 { 114 | o.PoolSize = 10 115 | } 116 | if o.PoolThrottle == 0 { 117 | o.PoolThrottle = 500 * time.Millisecond 118 | } 119 | if o.ResetThrottle == 0 { 120 | o.ResetThrottle = 500 * time.Millisecond 121 | } 122 | 123 | c := Cluster{ 124 | o: o, 125 | mapping: mapping{}, 126 | pools: map[string]*pool.Pool{}, 127 | poolThrottles: map[string]<-chan time.Time{}, 128 | callCh: make(chan func(*Cluster)), 129 | stopCh: make(chan struct{}), 130 | MissCh: make(chan struct{}), 131 | ChangeCh: make(chan struct{}), 132 | } 133 | 134 | initialPool, err := c.newPool(o.Addr, true) 135 | if err != nil { 136 | return nil, err 137 | } 138 | c.pools[o.Addr] = initialPool 139 | 140 | go c.spin() 141 | if err := c.Reset(); err != nil { 142 | return nil, err 143 | } 144 | return &c, nil 145 | } 146 | 147 | func (c *Cluster) newPool(addr string, clearThrottle bool) (*pool.Pool, error) { 148 | if clearThrottle { 149 | delete(c.poolThrottles, addr) 150 | } else if throttle, ok := c.poolThrottles[addr]; ok { 151 | select { 152 | case <-throttle: 153 | delete(c.poolThrottles, addr) 154 | default: 155 | return nil, fmt.Errorf("newPool(%s) throttled", addr) 156 | } 157 | } 158 | 159 | df := func(network, addr string) (*redis.Client, error) { 160 | return redis.DialTimeout(network, addr, c.o.Timeout) 161 | } 162 | p, err := pool.NewCustom("tcp", addr, c.o.PoolSize, df) 163 | if err != nil { 164 | c.poolThrottles[addr] = time.After(c.o.PoolThrottle) 165 | return nil, err 166 | } 167 | return p, err 168 | } 169 | 170 | // Anything which requires creating/deleting pools must be done in here 171 | func (c *Cluster) spin() { 172 | for { 173 | select { 174 | case f := <-c.callCh: 175 | f(c) 176 | case <-c.stopCh: 177 | return 178 | } 179 | } 180 | } 181 | 182 | // Returns a connection for the given key or given address, depending on which 183 | // is set. 
If the given pool couldn't be used a connection from a random pool 184 | // will (attempt) to be returned 185 | func (c *Cluster) getConn(key, addr string) (*redis.Client, error) { 186 | type resp struct { 187 | conn *redis.Client 188 | err error 189 | } 190 | respCh := make(chan *resp) 191 | c.callCh <- func(c *Cluster) { 192 | if key != "" { 193 | addr = c.addrForKeyInner(key) 194 | } 195 | 196 | var err error 197 | p, ok := c.pools[addr] 198 | if !ok { 199 | p, err = c.newPool(addr, false) 200 | } 201 | 202 | var conn *redis.Client 203 | if err == nil { 204 | conn, err = p.Get() 205 | if err == nil { 206 | respCh <- &resp{conn, nil} 207 | return 208 | } 209 | } 210 | 211 | // If there's an error try one more time retrieving from a random pool 212 | // before bailing 213 | p = c.getRandomPoolInner() 214 | if p == nil { 215 | respCh <- &resp{err: errNoPools} 216 | return 217 | } 218 | conn, err = p.Get() 219 | if err != nil { 220 | respCh <- &resp{err: err} 221 | return 222 | } 223 | 224 | respCh <- &resp{conn, nil} 225 | } 226 | r := <-respCh 227 | return r.conn, r.err 228 | } 229 | 230 | // Put putss the connection back in its pool. To be used alongside any of the 231 | // Get* methods once use of the redis.Client is done 232 | func (c *Cluster) Put(conn *redis.Client) { 233 | c.callCh <- func(c *Cluster) { 234 | p := c.pools[conn.Addr] 235 | if p == nil { 236 | conn.Close() 237 | return 238 | } 239 | 240 | p.Put(conn) 241 | } 242 | } 243 | 244 | func (c *Cluster) getRandomPoolInner() *pool.Pool { 245 | for _, pool := range c.pools { 246 | return pool 247 | } 248 | return nil 249 | } 250 | 251 | // Reset will re-retrieve the cluster topology and set up/teardown connections 252 | // as necessary. It begins by calling CLUSTER SLOTS on a random known 253 | // connection. The return from that is used to re-create the topology, create 254 | // any missing clients, and close any clients which are no longer needed. 
255 | // 256 | // This call is inherently throttled, so that multiple clients can call it at 257 | // the same time and it will only actually occur once (subsequent clients will 258 | // have nil returned immediately). 259 | func (c *Cluster) Reset() error { 260 | respCh := make(chan error) 261 | c.callCh <- func(c *Cluster) { 262 | respCh <- c.resetInner() 263 | } 264 | return <-respCh 265 | } 266 | 267 | func (c *Cluster) resetInner() error { 268 | // Throttle resetting so a bunch of routines can call Reset at once and the 269 | // server won't be spammed. We don't a throttle until the second Reset is 270 | // called, so the initial call inside New goes through correctly 271 | if c.resetThrottle != nil { 272 | select { 273 | case <-c.resetThrottle.C: 274 | default: 275 | return nil 276 | } 277 | } else { 278 | c.resetThrottle = time.NewTicker(c.o.ResetThrottle) 279 | } 280 | 281 | p := c.getRandomPoolInner() 282 | if p == nil { 283 | return fmt.Errorf("no available nodes to call CLUSTER SLOTS on") 284 | } 285 | 286 | return c.resetInnerUsingPool(p) 287 | } 288 | 289 | func (c *Cluster) resetInnerUsingPool(p *pool.Pool) error { 290 | 291 | // If we move the throttle check to be in here we'll have to fix the test in 292 | // TestReset, since it depends on being able to call Reset right after 293 | // initializing the cluster 294 | 295 | client, err := p.Get() 296 | if err != nil { 297 | return err 298 | } 299 | defer p.Put(client) 300 | 301 | pools := map[string]*pool.Pool{} 302 | 303 | elems, err := client.Cmd("CLUSTER", "SLOTS").Array() 304 | if err != nil { 305 | return err 306 | } else if len(elems) == 0 { 307 | return errors.New("empty CLUSTER SLOTS response") 308 | } 309 | 310 | var start, end, port int 311 | var ip, slotAddr string 312 | var slotPool *pool.Pool 313 | var ok, changed bool 314 | for _, slotGroup := range elems { 315 | slotElems, err := slotGroup.Array() 316 | if err != nil { 317 | return err 318 | } 319 | if start, err = slotElems[0].Int(); 
err != nil { 320 | return err 321 | } 322 | if end, err = slotElems[1].Int(); err != nil { 323 | return err 324 | } 325 | slotAddrElems, err := slotElems[2].Array() 326 | if err != nil { 327 | return err 328 | } 329 | if ip, err = slotAddrElems[0].Str(); err != nil { 330 | return err 331 | } 332 | if port, err = slotAddrElems[1].Int(); err != nil { 333 | return err 334 | } 335 | 336 | // cluster slots returns a blank ip for the node we're currently 337 | // connected to. I guess the node doesn't know its own ip? I guess that 338 | // makes sense 339 | if ip == "" { 340 | slotAddr = p.Addr 341 | } else { 342 | slotAddr = ip + ":" + strconv.Itoa(port) 343 | } 344 | for i := start; i <= end; i++ { 345 | c.mapping[i] = slotAddr 346 | } 347 | if slotPool, ok = c.pools[slotAddr]; ok { 348 | pools[slotAddr] = slotPool 349 | } else { 350 | slotPool, err = c.newPool(slotAddr, true) 351 | if err != nil { 352 | return err 353 | } 354 | changed = true 355 | pools[slotAddr] = slotPool 356 | } 357 | } 358 | 359 | for addr := range c.pools { 360 | if _, ok := pools[addr]; !ok { 361 | c.pools[addr].Empty() 362 | delete(c.poolThrottles, addr) 363 | changed = true 364 | } 365 | } 366 | c.pools = pools 367 | 368 | if changed { 369 | select { 370 | case c.ChangeCh <- struct{}{}: 371 | default: 372 | } 373 | } 374 | 375 | return nil 376 | } 377 | 378 | // Logic for doing a command: 379 | // * Get client for command's slot, try it 380 | // * If err == nil, return reply 381 | // * If err is a client error: 382 | // * If MOVED: 383 | // * If node not tried before, go to top with that node 384 | // * Otherwise if we haven't Reset, do that and go to top with random 385 | // node 386 | // * Otherwise error out 387 | // * If ASK (same as MOVED, but call ASKING beforehand and don't modify 388 | // slots) 389 | // * Otherwise return the error 390 | // * Otherwise it is a network error 391 | // * If we haven't reconnected to this node yet, do that and go to top 392 | // * If we haven't reset yet 
// haveTried reports whether addr has been marked in the tried set. A nil set
// is treated as empty.
func haveTried(tried map[string]bool, addr string) bool {
	// Indexing a nil map is allowed and yields the zero value, false
	attempted := tried[addr]
	return attempted
}

// justTried marks addr in the tried set and returns the set. A nil set is
// replaced by a freshly allocated one; a non-nil set is modified in place and
// returned as-is.
func justTried(tried map[string]bool, addr string) map[string]bool {
	seen := tried
	if seen == nil {
		seen = make(map[string]bool)
	}
	seen[addr] = true
	return seen
}
The above 458 | // code is what will be run 99% of the time and is pretty streamlined, 459 | // everything after this point is allowed to be hairy and gross 460 | 461 | haveTriedBefore := haveTried(tried, client.Addr) 462 | tried = justTried(tried, client.Addr) 463 | 464 | // Deal with network error 465 | if r.IsType(redis.IOErr) { 466 | // If this is the first time trying this node, try it again 467 | if !haveTriedBefore { 468 | if client, try2err := c.getConn("", client.Addr); try2err == nil { 469 | return c.clientCmd(client, cmd, args, false, tried, haveReset) 470 | } 471 | } 472 | // Otherwise try calling Reset() and getting a random client 473 | if !haveReset { 474 | if resetErr := c.Reset(); resetErr != nil { 475 | return errorRespf("Could not get cluster info: %s", resetErr) 476 | } 477 | client, getErr := c.getConn("", "") 478 | if getErr != nil { 479 | return errorResp(getErr) 480 | } 481 | return c.clientCmd(client, cmd, args, false, tried, true) 482 | } 483 | // Otherwise give up and return the most recent error 484 | return r 485 | } 486 | 487 | // Here we deal with application errors that are either MOVED or ASK 488 | msg := err.Error() 489 | moved := strings.HasPrefix(msg, "MOVED ") 490 | ask = strings.HasPrefix(msg, "ASK ") 491 | if moved || ask { 492 | _, addr := redirectInfo(msg) 493 | c.callCh <- func(c *Cluster) { 494 | select { 495 | case c.MissCh <- struct{}{}: 496 | default: 497 | } 498 | } 499 | 500 | // If we've already called Reset and we're getting MOVED again than the 501 | // cluster is having problems, likely telling us to try a node which is 502 | // not reachable. Not much which can be done at this point 503 | if haveReset { 504 | return errorRespf("Cluster doesn't make sense, %s might be gone", addr) 505 | } 506 | if resetErr := c.Reset(); resetErr != nil { 507 | return errorRespf("Could not get cluster info: %s", resetErr) 508 | } 509 | haveReset = true 510 | 511 | // At this point addr is whatever redis told us it should be. 
However, 512 | // if we can't get a connection to it we'll never actually mark it as 513 | // tried, resulting in an infinite loop. Here we mark it as tried 514 | // regardless of if it actually was or not 515 | tried = justTried(tried, addr) 516 | 517 | client, getErr := c.getConn("", addr) 518 | if getErr != nil { 519 | return errorResp(getErr) 520 | } 521 | return c.clientCmd(client, cmd, args, ask, tried, haveReset) 522 | } 523 | 524 | // It's a normal application error (like WRONG KEY TYPE or whatever), return 525 | // that to the client 526 | return r 527 | } 528 | 529 | func redirectInfo(msg string) (int, string) { 530 | parts := strings.Split(msg, " ") 531 | slotStr := parts[1] 532 | slot, err := strconv.Atoi(slotStr) 533 | if err != nil { 534 | // if redis is returning bad integers, we have problems 535 | panic(err) 536 | } 537 | addr := parts[2] 538 | return slot, addr 539 | } 540 | 541 | func (c *Cluster) addrForKeyInner(key string) string { 542 | if start := strings.Index(key, "{"); start >= 0 { 543 | if end := strings.Index(key[start+2:], "}"); end >= 0 { 544 | key = key[start+1 : start+2+end] 545 | } 546 | } 547 | i := CRC16([]byte(key)) % numSlots 548 | return c.mapping[i] 549 | } 550 | 551 | // GetForKey returns the Client which *ought* to handle the given key, based 552 | // on Cluster's understanding of the cluster topology at the given moment. If 553 | // the slot isn't known or there is an error contacting the correct node, a 554 | // random client is returned. The client must be returned back to its pool using 555 | // Put when through 556 | func (c *Cluster) GetForKey(key string) (*redis.Client, error) { 557 | return c.getConn(key, "") 558 | } 559 | 560 | // GetEvery returns a single *redis.Client per master that the cluster currently 561 | // knows about. The map returned maps the address of the client to the client 562 | // itself. 
If there is an error retrieving any of the clients (for instance if a 563 | // new connection has to be made to get it) only that error is returned. Each 564 | // client must be returned back to its pools using Put when through 565 | func (c *Cluster) GetEvery() (map[string]*redis.Client, error) { 566 | type resp struct { 567 | m map[string]*redis.Client 568 | err error 569 | } 570 | respCh := make(chan resp) 571 | c.callCh <- func(c *Cluster) { 572 | m := map[string]*redis.Client{} 573 | for addr, p := range c.pools { 574 | client, err := p.Get() 575 | if err != nil { 576 | respCh <- resp{nil, err} 577 | return 578 | } 579 | m[addr] = client 580 | } 581 | respCh <- resp{m, nil} 582 | } 583 | 584 | r := <-respCh 585 | return r.m, r.err 586 | } 587 | 588 | // GetAddrForKey returns the address which would be used to handle the given key 589 | // in the cluster. 590 | func (c *Cluster) GetAddrForKey(key string) string { 591 | respCh := make(chan string) 592 | c.callCh <- func(c *Cluster) { 593 | respCh <- c.addrForKeyInner(key) 594 | } 595 | return <-respCh 596 | } 597 | 598 | // Close calls Close on all connected clients. 
Once this is called no other 599 | // methods should be called on this instance of Cluster 600 | func (c *Cluster) Close() { 601 | c.callCh <- func(c *Cluster) { 602 | for addr, p := range c.pools { 603 | p.Empty() 604 | delete(c.pools, addr) 605 | } 606 | if c.resetThrottle != nil { 607 | c.resetThrottle.Stop() 608 | } 609 | } 610 | close(c.stopCh) 611 | } 612 | -------------------------------------------------------------------------------- /redis/resp.go: -------------------------------------------------------------------------------- 1 | package redis 2 | 3 | import ( 4 | "bufio" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "net" 9 | "reflect" 10 | "strconv" 11 | "strings" 12 | ) 13 | 14 | var ( 15 | delim = []byte{'\r', '\n'} 16 | delimEnd = delim[len(delim)-1] 17 | ) 18 | 19 | // RespType is a field on every Resp which indicates the type of the data it 20 | // contains 21 | type RespType int 22 | 23 | // Different RespTypes. You can check if a message is of one or more types using 24 | // the IsType method on Resp 25 | const ( 26 | SimpleStr RespType = 1 << iota 27 | BulkStr 28 | IOErr // An error which prevented reading/writing, e.g. connection close 29 | AppErr // An error returned by redis, e.g. WRONGTYPE 30 | Int 31 | Array 32 | Nil 33 | 34 | // Str combines both SimpleStr and BulkStr, which are considered strings to 35 | // the Str() method. This is what you want to give to IsType when 36 | // determining if a response is a string 37 | Str = SimpleStr | BulkStr 38 | 39 | // Err combines both IOErr and AppErr, which both indicate that the Err 40 | // field on their Resp is filled. 
To determine if a Resp is an error you'll 41 | // most often want to simply check if the Err field on it is nil 42 | Err = IOErr | AppErr 43 | ) 44 | 45 | var ( 46 | simpleStrPrefix = []byte{'+'} 47 | errPrefix = []byte{'-'} 48 | intPrefix = []byte{':'} 49 | bulkStrPrefix = []byte{'$'} 50 | arrayPrefix = []byte{'*'} 51 | nilFormatted = []byte("$-1\r\n") 52 | ) 53 | 54 | // Parse errors 55 | var ( 56 | errBadType = errors.New("wrong type") 57 | errParse = errors.New("parse error") 58 | errNotStr = errors.New("could not convert to string") 59 | errNotInt = errors.New("could not convert to int") 60 | errNotArray = errors.New("could not convert to array") 61 | ) 62 | 63 | // Resp represents a single response or message being sent to/from a redis 64 | // server. Each Resp has a type (see RespType and IsType) and a value. Values 65 | // can be retrieved using any of the casting methods on this type (e.g. Str) 66 | type Resp struct { 67 | typ RespType 68 | val interface{} 69 | 70 | // Err indicates that this Resp signals some kind of error, either on the 71 | // connection level or the application level. Use IsType if you need to 72 | // determine which, otherwise you can simply check if this is nil 73 | Err error 74 | } 75 | 76 | // NewResp takes the given value and interprets it into a resp encoded byte 77 | // stream 78 | func NewResp(v interface{}) *Resp { 79 | r := format(v, false) 80 | return &r 81 | } 82 | 83 | // NewRespSimple is like NewResp except it encodes its string as a resp 84 | // SimpleStr type, whereas NewResp will encode all strings as BulkStr 85 | func NewRespSimple(s string) *Resp { 86 | return &Resp{ 87 | typ: SimpleStr, 88 | val: []byte(s), 89 | } 90 | } 91 | 92 | // NewRespFlattenedStrings is like NewResp except it looks through the given 93 | // value and converts any types (except slices/maps) into strings, and flatten any 94 | // embedded slices/maps into a single slice. 
This is useful because commands to 95 | // a redis server must be given as an array of bulk strings. If the argument 96 | // isn't already in a slice/map it will be wrapped so that it is written as a 97 | // Array of size one 98 | func NewRespFlattenedStrings(v interface{}) *Resp { 99 | fv := flatten(v) 100 | r := format(fv, true) 101 | return &r 102 | } 103 | 104 | // newRespIOErr is a convenience method for making Resps to wrap io errors 105 | func newRespIOErr(err error) *Resp { 106 | r := NewResp(err) 107 | r.typ = IOErr 108 | return r 109 | } 110 | 111 | // RespReader is a wrapper around an io.Reader which will read Resp messages off 112 | // of the io.Reader 113 | type RespReader struct { 114 | r *bufio.Reader 115 | } 116 | 117 | // NewRespReader creates and returns a new RespReader which will read from the 118 | // given io.Reader. Once passed in the io.Reader shouldn't be read from by any 119 | // other processes 120 | func NewRespReader(r io.Reader) *RespReader { 121 | br, ok := r.(*bufio.Reader) 122 | if !ok { 123 | br = bufio.NewReader(r) 124 | } 125 | return &RespReader{br} 126 | } 127 | 128 | // ReadResp attempts to read a message object from the given io.Reader, parse 129 | // it, and return a Resp representing it 130 | func (rr *RespReader) Read() *Resp { 131 | res, err := bufioReadResp(rr.r) 132 | if err != nil { 133 | res = Resp{typ: IOErr, val: err, Err: err} 134 | } 135 | return &res 136 | } 137 | 138 | func bufioReadResp(r *bufio.Reader) (Resp, error) { 139 | b, err := r.Peek(1) 140 | if err != nil { 141 | return Resp{}, err 142 | } 143 | switch b[0] { 144 | case simpleStrPrefix[0]: 145 | return readSimpleStr(r) 146 | case errPrefix[0]: 147 | return readError(r) 148 | case intPrefix[0]: 149 | return readInt(r) 150 | case bulkStrPrefix[0]: 151 | return readBulkStr(r) 152 | case arrayPrefix[0]: 153 | return readArray(r) 154 | default: 155 | return Resp{}, errBadType 156 | } 157 | } 158 | 159 | func readSimpleStr(r *bufio.Reader) (Resp, error) { 
160 | b, err := r.ReadBytes(delimEnd) 161 | if err != nil { 162 | return Resp{}, err 163 | } 164 | return Resp{typ: SimpleStr, val: b[1 : len(b)-2]}, nil 165 | } 166 | 167 | func readError(r *bufio.Reader) (Resp, error) { 168 | b, err := r.ReadBytes(delimEnd) 169 | if err != nil { 170 | return Resp{}, err 171 | } 172 | err = errors.New(string(b[1 : len(b)-2])) 173 | return Resp{typ: AppErr, val: err, Err: err}, nil 174 | } 175 | 176 | func readInt(r *bufio.Reader) (Resp, error) { 177 | b, err := r.ReadBytes(delimEnd) 178 | if err != nil { 179 | return Resp{}, err 180 | } 181 | i, err := strconv.ParseInt(string(b[1:len(b)-2]), 10, 64) 182 | if err != nil { 183 | return Resp{}, errParse 184 | } 185 | return Resp{typ: Int, val: i}, nil 186 | } 187 | 188 | func readBulkStr(r *bufio.Reader) (Resp, error) { 189 | b, err := r.ReadBytes(delimEnd) 190 | if err != nil { 191 | return Resp{}, err 192 | } 193 | size, err := strconv.ParseInt(string(b[1:len(b)-2]), 10, 64) 194 | if err != nil { 195 | return Resp{}, errParse 196 | } 197 | if size < 0 { 198 | return Resp{typ: Nil}, nil 199 | } 200 | total := make([]byte, size) 201 | b2 := total 202 | var n int 203 | for len(b2) > 0 { 204 | n, err = r.Read(b2) 205 | if err != nil { 206 | return Resp{}, err 207 | } 208 | b2 = b2[n:] 209 | } 210 | 211 | // There's a hanging \r\n there, gotta read past it 212 | trail := make([]byte, 2) 213 | for i := 0; i < 2; i++ { 214 | c, err := r.ReadByte() 215 | if err != nil { 216 | return Resp{}, err 217 | } 218 | trail[i] = c 219 | } 220 | 221 | return Resp{typ: BulkStr, val: total}, nil 222 | } 223 | 224 | func readArray(r *bufio.Reader) (Resp, error) { 225 | b, err := r.ReadBytes(delimEnd) 226 | if err != nil { 227 | return Resp{}, err 228 | } 229 | size, err := strconv.ParseInt(string(b[1:len(b)-2]), 10, 64) 230 | if err != nil { 231 | return Resp{}, errParse 232 | } 233 | if size < 0 { 234 | return Resp{typ: Nil}, nil 235 | } 236 | 237 | arr := make([]Resp, size) 238 | for i := range arr { 
239 | m, err := bufioReadResp(r) 240 | if err != nil { 241 | return Resp{}, err 242 | } 243 | arr[i] = m 244 | } 245 | return Resp{typ: Array, val: arr}, nil 246 | } 247 | 248 | // IsType returns whether or or not the reply is of a given type 249 | // 250 | // isStr := r.IsType(redis.Str) 251 | // 252 | // Multiple types can be checked at the same time by or'ing the desired types 253 | // 254 | // isStrOrInt := r.IsType(redis.Str | redis.Int) 255 | // 256 | func (r *Resp) IsType(t RespType) bool { 257 | return r.typ&t > 0 258 | } 259 | 260 | // WriteTo writes the resp encoded form of the Resp to the given writer, 261 | // implementing the WriterTo interface 262 | func (r *Resp) WriteTo(w io.Writer) (int64, error) { 263 | 264 | // SimpleStr is a special case, writeTo always writes strings as BulkStrs, 265 | // so we just manually do SimpleStr here 266 | if r.typ == SimpleStr { 267 | s := r.val.([]byte) 268 | b := append(make([]byte, 0, len(s)+3), simpleStrPrefix...) 269 | b = append(b, s...) 270 | b = append(b, delim...) 271 | written, err := w.Write(b) 272 | return int64(written), err 273 | } 274 | 275 | return writeTo(w, nil, r.val, false, false) 276 | } 277 | 278 | // Bytes returns a byte slice representing the value of the Resp. Only valid for 279 | // a Resp of type Str. If r.Err != nil that will be returned. 280 | func (r *Resp) Bytes() ([]byte, error) { 281 | if r.Err != nil { 282 | return nil, r.Err 283 | } else if !r.IsType(Str) { 284 | return nil, errBadType 285 | } 286 | 287 | if b, ok := r.val.([]byte); ok { 288 | return b, nil 289 | } 290 | return nil, errNotStr 291 | } 292 | 293 | // Str is a wrapper around Bytes which returns the result as a string instead of 294 | // a byte slice 295 | func (r *Resp) Str() (string, error) { 296 | b, err := r.Bytes() 297 | if err != nil { 298 | return "", err 299 | } 300 | return string(b), nil 301 | } 302 | 303 | // Int returns an int representing the value of the Resp. Only valid for a 304 | // Resp of type Int. 
If r.Err != nil that will be returned 305 | func (r *Resp) Int() (int, error) { 306 | i, err := r.Int64() 307 | return int(i), err 308 | } 309 | 310 | // Int64 is like Int, but returns int64 instead of Int 311 | func (r *Resp) Int64() (int64, error) { 312 | if r.Err != nil { 313 | return 0, r.Err 314 | } 315 | if i, ok := r.val.(int64); ok { 316 | return i, nil 317 | } 318 | return 0, errNotInt 319 | } 320 | 321 | // Float64 returns a float64 representing the value of the Resp. Only valud for 322 | // a Resp of type Str which represents an actual float. If r.Err != nil that 323 | // will be returned 324 | func (r *Resp) Float64() (float64, error) { 325 | if r.Err != nil { 326 | return 0, r.Err 327 | } 328 | if b, ok := r.val.([]byte); ok { 329 | f, err := strconv.ParseFloat(string(b), 64) 330 | if err != nil { 331 | return 0, err 332 | } 333 | return f, nil 334 | } 335 | return 0, errNotStr 336 | } 337 | 338 | func (r *Resp) betterArray() ([]Resp, error) { 339 | if r.Err != nil { 340 | return nil, r.Err 341 | } 342 | if a, ok := r.val.([]Resp); ok { 343 | return a, nil 344 | } 345 | return nil, errNotArray 346 | } 347 | 348 | // Array returns the Resp slice encompassed by this Resp. Only valid for a Resp 349 | // of type Array. If r.Err != nil that will be returned 350 | func (r *Resp) Array() ([]*Resp, error) { 351 | a, err := r.betterArray() 352 | if err != nil { 353 | return nil, err 354 | } 355 | abad := make([]*Resp, len(a)) 356 | for i := range a { 357 | abad[i] = &a[i] 358 | } 359 | return abad, nil 360 | } 361 | 362 | // List is a wrapper around Array which returns the result as a list of strings, 363 | // calling Str() on each Resp which Array returns. Any errors encountered are 364 | // immediately returned. 
Any Nil replies are interpreted as empty strings 365 | func (r *Resp) List() ([]string, error) { 366 | m, err := r.betterArray() 367 | if err != nil { 368 | return nil, err 369 | } 370 | l := make([]string, len(m)) 371 | for i := range m { 372 | if m[i].IsType(Nil) { 373 | l[i] = "" 374 | continue 375 | } 376 | s, err := m[i].Str() 377 | if err != nil { 378 | return nil, err 379 | } 380 | l[i] = s 381 | } 382 | return l, nil 383 | } 384 | 385 | // ListBytes is a wrapper around Array which returns the result as a list of 386 | // byte slices, calling Bytes() on each Resp which Array returns. Any errors 387 | // encountered are immediately returned. Any Nil replies are interpreted as nil 388 | func (r *Resp) ListBytes() ([][]byte, error) { 389 | m, err := r.betterArray() 390 | if err != nil { 391 | return nil, err 392 | } 393 | l := make([][]byte, len(m)) 394 | for i := range m { 395 | if m[i].IsType(Nil) { 396 | l[i] = nil 397 | continue 398 | } 399 | b, err := m[i].Bytes() 400 | if err != nil { 401 | return nil, err 402 | } 403 | l[i] = b 404 | } 405 | return l, nil 406 | } 407 | 408 | // Map is a wrapper around Array which returns the result as a map of strings, 409 | // calling Str() on alternating key/values for the map. 
All value fields of type 410 | // Nil will be treated as empty strings, keys must all be of type Str 411 | func (r *Resp) Map() (map[string]string, error) { 412 | l, err := r.betterArray() 413 | if err != nil { 414 | return nil, err 415 | } 416 | if len(l)%2 != 0 { 417 | return nil, errors.New("reply has odd number of elements") 418 | } 419 | 420 | m := map[string]string{} 421 | for { 422 | if len(l) == 0 { 423 | return m, nil 424 | } 425 | k, v := l[0], l[1] 426 | l = l[2:] 427 | 428 | ks, err := k.Str() 429 | if err != nil { 430 | return nil, err 431 | } 432 | 433 | var vs string 434 | if v.IsType(Nil) { 435 | vs = "" 436 | } else if vs, err = v.Str(); err != nil { 437 | return nil, err 438 | } 439 | m[ks] = vs 440 | } 441 | } 442 | 443 | // String returns a string representation of the Resp. This method is for 444 | // debugging, use Str() for reading a Str reply 445 | func (r *Resp) String() string { 446 | var inner string 447 | switch r.typ { 448 | case AppErr: 449 | inner = fmt.Sprintf("AppErr %s", r.Err) 450 | case IOErr: 451 | inner = fmt.Sprintf("IOErr %s", r.Err) 452 | case BulkStr, SimpleStr: 453 | inner = fmt.Sprintf("Str %s", string(r.val.([]byte))) 454 | case Int: 455 | inner = fmt.Sprintf("Int %d", r.val.(int64)) 456 | case Nil: 457 | inner = fmt.Sprintf("Nil") 458 | case Array: 459 | kids := r.val.([]*Resp) 460 | kidsStr := make([]string, len(kids)) 461 | for i := range kids { 462 | kidsStr[i] = kids[i].String() 463 | } 464 | inner = strings.Join(kidsStr, " ") 465 | } 466 | return fmt.Sprintf("Resp(%s)", inner) 467 | } 468 | 469 | var typeOfBytes = reflect.TypeOf([]byte(nil)) 470 | 471 | func flattenedLength(mm ...interface{}) int { 472 | 473 | total := 0 474 | 475 | for _, m := range mm { 476 | switch m.(type) { 477 | case []byte, string, bool, nil, int, int8, int16, int32, int64, uint, 478 | uint8, uint16, uint32, uint64, float32, float64, error: 479 | total++ 480 | 481 | case Resp: 482 | total += flattenedLength(m.(Resp).val) 483 | case *Resp: 
484 | total += flattenedLength(m.(*Resp).val) 485 | 486 | case []interface{}: 487 | total += flattenedLength(m.([]interface{})...) 488 | 489 | default: 490 | t := reflect.TypeOf(m) 491 | 492 | switch t.Kind() { 493 | case reflect.Slice: 494 | rm := reflect.ValueOf(m) 495 | l := rm.Len() 496 | for i := 0; i < l; i++ { 497 | total += flattenedLength(rm.Index(i).Interface()) 498 | } 499 | 500 | case reflect.Map: 501 | rm := reflect.ValueOf(m) 502 | keys := rm.MapKeys() 503 | for _, k := range keys { 504 | kv := k.Interface() 505 | vv := rm.MapIndex(k).Interface() 506 | total += flattenedLength(kv) 507 | total += flattenedLength(vv) 508 | } 509 | 510 | default: 511 | total++ 512 | } 513 | } 514 | } 515 | 516 | return total 517 | } 518 | 519 | func flatten(m interface{}) []interface{} { 520 | t := reflect.TypeOf(m) 521 | 522 | // If it's a byte-slice we don't want to flatten 523 | if t == typeOfBytes { 524 | return []interface{}{m} 525 | } 526 | 527 | switch t.Kind() { 528 | case reflect.Slice: 529 | rm := reflect.ValueOf(m) 530 | l := rm.Len() 531 | ret := make([]interface{}, 0, l) 532 | for i := 0; i < l; i++ { 533 | ret = append(ret, flatten(rm.Index(i).Interface())...) 534 | } 535 | return ret 536 | 537 | case reflect.Map: 538 | rm := reflect.ValueOf(m) 539 | l := rm.Len() * 2 540 | keys := rm.MapKeys() 541 | ret := make([]interface{}, 0, l) 542 | for _, k := range keys { 543 | kv := k.Interface() 544 | vv := rm.MapIndex(k).Interface() 545 | ret = append(ret, flatten(kv)...) 546 | ret = append(ret, flatten(vv)...) 
547 | } 548 | return ret 549 | 550 | default: 551 | return []interface{}{m} 552 | } 553 | } 554 | 555 | func anyIntToInt64(m interface{}) int64 { 556 | switch mt := m.(type) { 557 | case int: 558 | return int64(mt) 559 | case int8: 560 | return int64(mt) 561 | case int16: 562 | return int64(mt) 563 | case int32: 564 | return int64(mt) 565 | case int64: 566 | return mt 567 | case uint: 568 | return int64(mt) 569 | case uint8: 570 | return int64(mt) 571 | case uint16: 572 | return int64(mt) 573 | case uint32: 574 | return int64(mt) 575 | case uint64: 576 | return int64(mt) 577 | } 578 | panic(fmt.Sprintf("anyIntToInt64 got bad arg: %#v", m)) 579 | } 580 | 581 | func writeBytesHelper( 582 | w io.Writer, b []byte, lastWritten int64, lastErr error, 583 | ) ( 584 | int64, error, 585 | ) { 586 | if lastErr != nil { 587 | return lastWritten, lastErr 588 | } 589 | i, err := w.Write(b) 590 | return int64(i) + lastWritten, err 591 | } 592 | 593 | func writeArrayHeader(w io.Writer, buf []byte, l int64) (int64, error) { 594 | buf = strconv.AppendInt(buf, l, 10) 595 | var err error 596 | var written int64 597 | written, err = writeBytesHelper(w, arrayPrefix, written, err) 598 | written, err = writeBytesHelper(w, buf, written, err) 599 | written, err = writeBytesHelper(w, delim, written, err) 600 | return written, err 601 | } 602 | 603 | // Given a preallocated byte buffer and a string, this will copy the string's 604 | // contents into buf starting at index 0, and returns two slices from buf: The 605 | // first is a slice of the string data, the second is a slice of the "rest" of 606 | // buf following the first slice 607 | func stringSlicer(buf []byte, s string) ([]byte, []byte) { 608 | sbuf := append(buf[:0], s...) 609 | return sbuf, sbuf[len(sbuf):] 610 | } 611 | 612 | // takes in something, m, and encodes it and writes it to w. 
buf is used as a 613 | // pre-alloated byte buffer for encoding integers (expected to have a length of 614 | // 0), so we don't have to re-allocate a new one every time we convert an 615 | // integer to a string. forceString means all types will be converted to 616 | // strings, noArrayHeader means don't write out the headers to any arrays, just 617 | // inline all the elements in the array 618 | func writeTo( 619 | w io.Writer, buf []byte, m interface{}, forceString, noArrayHeader bool, 620 | ) ( 621 | int64, error, 622 | ) { 623 | switch mt := m.(type) { 624 | case []byte: 625 | return writeStr(w, buf, mt) 626 | case string: 627 | sbuf, buf := stringSlicer(buf, mt) 628 | return writeStr(w, buf, sbuf) 629 | case bool: 630 | if mt { 631 | buf[0] = '1' 632 | } else { 633 | buf[0] = '0' 634 | } 635 | return writeStr(w, buf[1:], buf[:1]) 636 | case nil: 637 | if forceString { 638 | return writeStr(w, buf, nil) 639 | } 640 | return writeNil(w) 641 | case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: 642 | i := anyIntToInt64(mt) 643 | return writeInt(w, buf, i, forceString) 644 | case float32: 645 | return writeFloat(w, buf, float64(mt), 32) 646 | case float64: 647 | return writeFloat(w, buf, mt, 64) 648 | case error: 649 | return writeErr(w, buf, mt, forceString) 650 | 651 | // We duplicate the below code here a bit, since this is the common case and 652 | // it'd be better to not get the reflect package involved here 653 | case []interface{}: 654 | l := len(mt) 655 | var totalWritten int64 656 | 657 | if !noArrayHeader { 658 | written, err := writeArrayHeader(w, buf, int64(l)) 659 | totalWritten += written 660 | if err != nil { 661 | return totalWritten, err 662 | } 663 | } 664 | for i := 0; i < l; i++ { 665 | written, err := writeTo(w, buf, mt[i], forceString, noArrayHeader) 666 | totalWritten += written 667 | if err != nil { 668 | return totalWritten, err 669 | } 670 | } 671 | return totalWritten, nil 672 | 673 | case *Resp: 674 | return 
writeTo(w, buf, mt.val, forceString, noArrayHeader) 675 | 676 | case Resp: 677 | return writeTo(w, buf, mt.val, forceString, noArrayHeader) 678 | 679 | default: 680 | // Fallback to reflect-based. 681 | switch reflect.TypeOf(m).Kind() { 682 | case reflect.Slice: 683 | rm := reflect.ValueOf(mt) 684 | l := rm.Len() 685 | var totalWritten, written int64 686 | var err error 687 | 688 | if !noArrayHeader { 689 | written, err = writeArrayHeader(w, buf, int64(l)) 690 | totalWritten += written 691 | if err != nil { 692 | return totalWritten, err 693 | } 694 | } 695 | for i := 0; i < l; i++ { 696 | vv := rm.Index(i).Interface() 697 | written, err = writeTo(w, buf, vv, forceString, noArrayHeader) 698 | totalWritten += written 699 | if err != nil { 700 | return totalWritten, err 701 | } 702 | } 703 | return totalWritten, nil 704 | 705 | case reflect.Map: 706 | rm := reflect.ValueOf(mt) 707 | l := rm.Len() * 2 708 | var totalWritten, written int64 709 | var err error 710 | 711 | if !noArrayHeader { 712 | written, err = writeArrayHeader(w, buf, int64(l)) 713 | totalWritten += written 714 | if err != nil { 715 | return totalWritten, err 716 | } 717 | } 718 | keys := rm.MapKeys() 719 | for _, k := range keys { 720 | kv := k.Interface() 721 | written, err = writeTo(w, buf, kv, forceString, noArrayHeader) 722 | totalWritten += written 723 | if err != nil { 724 | return totalWritten, err 725 | } 726 | 727 | vv := rm.MapIndex(k).Interface() 728 | written, err = writeTo(w, buf, vv, forceString, noArrayHeader) 729 | if err != nil { 730 | return totalWritten, err 731 | } 732 | } 733 | return totalWritten, nil 734 | 735 | default: 736 | return writeStr(w, buf, []byte(fmt.Sprint(m))) 737 | } 738 | } 739 | } 740 | 741 | func writeStr(w io.Writer, buf, b []byte) (int64, error) { 742 | var err error 743 | var written int64 744 | buf = strconv.AppendInt(buf[:0], int64(len(b)), 10) 745 | 746 | written, err = writeBytesHelper(w, bulkStrPrefix, written, err) 747 | written, err = 
writeBytesHelper(w, buf, written, err) 748 | written, err = writeBytesHelper(w, delim, written, err) 749 | written, err = writeBytesHelper(w, b, written, err) 750 | written, err = writeBytesHelper(w, delim, written, err) 751 | return written, err 752 | } 753 | 754 | func writeErr( 755 | w io.Writer, buf []byte, ierr error, forceString bool, 756 | ) ( 757 | int64, error, 758 | ) { 759 | ierrStr := []byte(ierr.Error()) 760 | if forceString { 761 | return writeStr(w, buf, ierrStr) 762 | } 763 | var err error 764 | var written int64 765 | written, err = writeBytesHelper(w, errPrefix, written, err) 766 | written, err = writeBytesHelper(w, []byte(ierr.Error()), written, err) 767 | written, err = writeBytesHelper(w, delim, written, err) 768 | return written, err 769 | } 770 | 771 | func writeInt( 772 | w io.Writer, buf []byte, i int64, forceString bool, 773 | ) ( 774 | int64, error, 775 | ) { 776 | buf = strconv.AppendInt(buf[:0], i, 10) 777 | if forceString { 778 | return writeStr(w, buf[len(buf):], buf) 779 | } 780 | 781 | var err error 782 | var written int64 783 | written, err = writeBytesHelper(w, intPrefix, written, err) 784 | written, err = writeBytesHelper(w, buf, written, err) 785 | written, err = writeBytesHelper(w, delim, written, err) 786 | return written, err 787 | } 788 | 789 | func writeFloat(w io.Writer, buf []byte, f float64, bits int) (int64, error) { 790 | buf = strconv.AppendFloat(buf[:0], f, 'f', -1, bits) 791 | return writeStr(w, buf[len(buf):], buf) 792 | } 793 | 794 | func writeNil(w io.Writer) (int64, error) { 795 | written, err := w.Write(nilFormatted) 796 | return int64(written), err 797 | } 798 | 799 | // IsTimeout is a helper function for determining if an IOErr Resp was caused by 800 | // a network timeout 801 | func IsTimeout(r *Resp) bool { 802 | if r.IsType(IOErr) { 803 | t, ok := r.Err.(*net.OpError) 804 | return ok && t.Timeout() 805 | } 806 | return false 807 | } 808 | 809 | // format takes any data structure and attempts to turn it 
into a Resp or 810 | // multiple embedded Resps in the form of an Array. This is only used for 811 | // NewResp and NewRespFlattenedStrings 812 | func format(m interface{}, forceString bool) Resp { 813 | switch mt := m.(type) { 814 | case []byte: 815 | return Resp{typ: BulkStr, val: mt} 816 | case string: 817 | return Resp{typ: BulkStr, val: []byte(mt)} 818 | case bool: 819 | if mt { 820 | return Resp{typ: BulkStr, val: []byte{'1'}} 821 | } 822 | return Resp{typ: BulkStr, val: []byte{'0'}} 823 | case nil: 824 | if forceString { 825 | return Resp{typ: BulkStr, val: []byte{'0'}} 826 | } 827 | return Resp{typ: Nil} 828 | case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: 829 | i := anyIntToInt64(mt) 830 | if forceString { 831 | istr := strconv.FormatInt(i, 10) 832 | return Resp{typ: BulkStr, val: []byte(istr)} 833 | } 834 | return Resp{typ: Int, val: i} 835 | case float32: 836 | ft := strconv.FormatFloat(float64(mt), 'f', -1, 32) 837 | return Resp{typ: BulkStr, val: []byte(ft)} 838 | case float64: 839 | ft := strconv.FormatFloat(mt, 'f', -1, 64) 840 | return Resp{typ: BulkStr, val: []byte(ft)} 841 | case error: 842 | if forceString { 843 | return Resp{typ: BulkStr, val: []byte(mt.Error())} 844 | } 845 | return Resp{typ: AppErr, val: mt, Err: mt} 846 | 847 | // We duplicate the below code here a bit, since this is the common case and 848 | // it'd be better to not get the reflect package involved here 849 | case []interface{}: 850 | l := len(mt) 851 | rl := make([]Resp, 0, l) 852 | for i := 0; i < l; i++ { 853 | r := format(mt[i], forceString) 854 | rl = append(rl, r) 855 | } 856 | return Resp{typ: Array, val: rl} 857 | 858 | case *Resp: 859 | return *mt 860 | 861 | case Resp: 862 | return mt 863 | 864 | default: 865 | // Fallback to reflect-based. 
866 | switch reflect.TypeOf(m).Kind() { 867 | case reflect.Slice: 868 | rm := reflect.ValueOf(mt) 869 | l := rm.Len() 870 | rl := make([]Resp, 0, l) 871 | for i := 0; i < l; i++ { 872 | vv := rm.Index(i).Interface() 873 | r := format(vv, forceString) 874 | rl = append(rl, r) 875 | } 876 | return Resp{typ: Array, val: rl} 877 | 878 | case reflect.Map: 879 | rm := reflect.ValueOf(mt) 880 | l := rm.Len() * 2 881 | rl := make([]Resp, 0, l) 882 | keys := rm.MapKeys() 883 | for _, k := range keys { 884 | kv := k.Interface() 885 | vv := rm.MapIndex(k).Interface() 886 | 887 | kr := format(kv, forceString) 888 | rl = append(rl, kr) 889 | 890 | vr := format(vv, forceString) 891 | rl = append(rl, vr) 892 | } 893 | return Resp{typ: Array, val: rl} 894 | 895 | default: 896 | return Resp{typ: BulkStr, val: []byte(fmt.Sprint(m))} 897 | } 898 | } 899 | } 900 | --------------------------------------------------------------------------------