├── .gitignore ├── go.mod ├── resp ├── resp_test.go ├── util.go ├── resp.go └── resp2 │ └── bench_test.go ├── radix_test.go ├── bench ├── go.mod ├── go.sum └── bench_test.go ├── cluster_scanner_test.go ├── timer.go ├── go.sum ├── CONTRIBUTING.md ├── LICENSE.txt ├── cluster_crc16_test.go ├── .github └── workflows │ └── push.yml ├── trace ├── cluster.go └── pool.go ├── internal └── bytesutil │ ├── bytesutil_test.go │ ├── bench_test.go │ └── bytesutil.go ├── cluster_scanner.go ├── cluster_crc16.go ├── pubsub_persistent_test.go ├── stub_test.go ├── pubsub_stub_test.go ├── scanner.go ├── conn_test.go ├── pubsub_stub.go ├── scanner_test.go ├── stub.go ├── pipeliner.go ├── tls_test.go ├── pipeliner_test.go ├── cluster_topo.go ├── CHANGELOG.md ├── radix.go ├── cluster_topo_test.go ├── pubsub_persistent.go ├── README.md ├── pool_test.go ├── pubsub_test.go ├── cluster_test.go ├── conn.go ├── stream.go └── sentinel_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | test-tmp 2 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/mediocregopher/radix/v3 2 | 3 | require ( 4 | github.com/davecgh/go-spew v1.1.1 // indirect 5 | github.com/pmezard/go-difflib v1.0.0 // indirect 6 | github.com/stretchr/testify v1.2.2 7 | golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898 8 | ) 9 | 10 | go 1.13 11 | -------------------------------------------------------------------------------- /resp/resp_test.go: -------------------------------------------------------------------------------- 1 | package resp 2 | 3 | import ( 4 | . 
"testing" 5 | 6 | "errors" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestErrDiscarded(t *T) { 12 | err := errors.New("foo") 13 | assert.False(t, errors.As(err, new(ErrDiscarded))) 14 | assert.True(t, errors.As(ErrDiscarded{Err: err}, new(ErrDiscarded))) 15 | assert.True(t, errors.Is(ErrDiscarded{Err: err}, err)) 16 | } 17 | -------------------------------------------------------------------------------- /radix_test.go: -------------------------------------------------------------------------------- 1 | package radix 2 | 3 | import ( 4 | "crypto/rand" 5 | "encoding/hex" 6 | ) 7 | 8 | func randStr() string { 9 | b := make([]byte, 16) 10 | if _, err := rand.Read(b); err != nil { 11 | panic(err) 12 | } 13 | return hex.EncodeToString(b) 14 | } 15 | 16 | func dial(opts ...DialOpt) Conn { 17 | c, err := Dial("tcp", "127.0.0.1:6379", opts...) 18 | if err != nil { 19 | panic(err) 20 | } 21 | return c 22 | } 23 | -------------------------------------------------------------------------------- /bench/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/mediocregopher/radix/bench 2 | 3 | require ( 4 | github.com/garyburd/redigo v1.6.0 // indirect 5 | github.com/gomodule/redigo v2.0.0+incompatible 6 | github.com/joomcode/errorx v0.8.0 // indirect 7 | github.com/joomcode/redispipe v0.9.0 8 | github.com/mediocregopher/radix.v2 v0.0.0-20181115013041-b67df6e626f9 // indirect 9 | github.com/mediocregopher/radix/v3 v3.2.0 10 | ) 11 | 12 | replace github.com/mediocregopher/radix/v3 => ../. 13 | -------------------------------------------------------------------------------- /cluster_scanner_test.go: -------------------------------------------------------------------------------- 1 | package radix 2 | 3 | import ( 4 | . 
"testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | "github.com/stretchr/testify/require" 8 | ) 9 | 10 | func TestClusterScanner(t *T) { 11 | c, _ := newTestCluster() 12 | defer c.Close() 13 | exp := map[string]bool{} 14 | for _, k := range clusterSlotKeys { 15 | exp[k] = true 16 | require.Nil(t, c.Do(Cmd(nil, "SET", k, "1"))) 17 | } 18 | 19 | scanner := c.NewScanner(ScanAllKeys) 20 | var k string 21 | got := map[string]bool{} 22 | for scanner.Next(&k) { 23 | got[k] = true 24 | } 25 | 26 | assert.Equal(t, exp, got) 27 | } 28 | -------------------------------------------------------------------------------- /timer.go: -------------------------------------------------------------------------------- 1 | package radix 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | ) 7 | 8 | // global pool of *time.Timer's. 9 | var timerPool sync.Pool 10 | 11 | // get returns a timer that completes after the given duration. 12 | func getTimer(d time.Duration) *time.Timer { 13 | if t, _ := timerPool.Get().(*time.Timer); t != nil { 14 | t.Reset(d) 15 | return t 16 | } 17 | 18 | return time.NewTimer(d) 19 | } 20 | 21 | // putTimer pools the given timer. putTimer stops the timer and handles any left over data in the channel. 
22 | func putTimer(t *time.Timer) { 23 | if !t.Stop() { 24 | select { 25 | case <-t.C: 26 | default: 27 | } 28 | } 29 | 30 | timerPool.Put(t) 31 | } 32 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 4 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 5 | github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= 6 | github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 7 | golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898 h1:/atklqdjdhuosWIl6AIbOeHJjicWYPqR9bpxqxYG2pA= 8 | golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 9 | -------------------------------------------------------------------------------- /resp/util.go: -------------------------------------------------------------------------------- 1 | package resp 2 | 3 | import ( 4 | "io" 5 | ) 6 | 7 | // LenReader adds an additional method to io.Reader, returning how many bytes 8 | // are left till be read until an io.EOF is reached. 9 | type LenReader interface { 10 | io.Reader 11 | Len() int64 12 | } 13 | 14 | type lenReader struct { 15 | r io.Reader 16 | l int64 17 | } 18 | 19 | // NewLenReader wraps an existing io.Reader whose length is known so that it 20 | // implements LenReader. 
21 | func NewLenReader(r io.Reader, l int64) LenReader { 22 | return &lenReader{r: r, l: l} 23 | } 24 | 25 | func (lr *lenReader) Read(b []byte) (int, error) { 26 | n, err := lr.r.Read(b) 27 | lr.l -= int64(n) 28 | return n, err 29 | } 30 | 31 | func (lr *lenReader) Len() int64 { 32 | return lr.l 33 | } 34 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # The rulez 2 | 3 | There's a couple. They're not even really rules, more just telling you what you 4 | can expect. 5 | 6 | * Issues are ALWAYS welcome, whether or not you think it's a dumb question or 7 | it's been asked before. I make a very real attempt to respond to all issues in 8 | 24 hours. You can email me directly if I don't make this deadline. 9 | 10 | * Please always preface a pull request by making an issue. It can save you some 11 | time if it turns out that something you consider an issue is actually intended 12 | behavior, and saves me the difficult task of telling you that I'm going to let 13 | the work you put in go to waste. 14 | 15 | * The API never breaks. All PRs which aren't backwards compatible will not be 16 | accepted. Similarly, if I do commit something which isn't backwards compatible 17 | with an older behavior please submit an issue ASAP so I can fix it. 
18 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR 17 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 18 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 19 | IN THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /cluster_crc16_test.go: -------------------------------------------------------------------------------- 1 | package radix 2 | 3 | import ( 4 | "fmt" 5 | . 
"testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestClusterSlot(t *T) { 11 | // basic test 12 | assert.Equal(t, uint16(0x31c3), ClusterSlot([]byte("123456789"))) 13 | 14 | // this is more to test that the hash tag checking works than anything 15 | k := []byte(randStr()) 16 | crcSlot := ClusterSlot(k) 17 | 18 | // ClusterSlot without handling curly braces 19 | rawClusterSlot := func(s string) uint16 { 20 | return CRC16([]byte(s)) % numSlots 21 | } 22 | 23 | clusterSlotf := func(s string, args ...interface{}) uint16 { 24 | return ClusterSlot([]byte(fmt.Sprintf(s, args...))) 25 | } 26 | 27 | kk := randStr() 28 | assert.Equal(t, crcSlot, clusterSlotf("{%s}%s", k, kk)) 29 | assert.Equal(t, crcSlot, clusterSlotf("{%s}}%s", k, kk)) 30 | assert.Equal(t, crcSlot, clusterSlotf("%s{%s}", kk, k)) 31 | assert.Equal(t, crcSlot, clusterSlotf("%s{%s}}%s", kk, k, kk)) 32 | assert.Equal(t, rawClusterSlot(string(k)+"{"), clusterSlotf("%s{", k)) 33 | // if the braces are empty it should match the whole string 34 | assert.Equal(t, rawClusterSlot("foo{}{bar}"), ClusterSlot([]byte(`foo{}{bar}`))) 35 | } 36 | -------------------------------------------------------------------------------- /.github/workflows/push.yml: -------------------------------------------------------------------------------- 1 | on: push 2 | jobs: 3 | lint: 4 | runs-on: ubuntu-latest 5 | steps: 6 | - uses: actions/checkout@v2 7 | - name: golangci-lint 8 | uses: golangci/golangci-lint-action@v2 9 | with: 10 | args: >- 11 | --timeout 10m 12 | --max-same-issues 0 13 | --max-issues-per-linter 0 14 | -E durationcheck 15 | -E errorlint 16 | -E exportloopref 17 | -E forbidigo 18 | -E gochecknoinits 19 | -E godot 20 | -E godox 21 | -E goimports 22 | -E misspell 23 | -E revive 24 | -E unconvert 25 | -E unparam 26 | 27 | test: 28 | runs-on: ubuntu-latest 29 | strategy: 30 | matrix: 31 | go: [ '1.17', '1.16' ] # should always be latest two go version 32 | services: 33 | redis: 34 | image: redis 35 | # 
Set health checks to wait until redis has started 36 | options: >- 37 | --health-cmd "redis-cli ping" 38 | --health-interval 10s 39 | --health-timeout 5s 40 | --health-retries 5 41 | ports: 42 | - 6379:6379 43 | steps: 44 | - uses: actions/setup-go@v2 45 | with: 46 | go-version: ${{ matrix.go }} 47 | - uses: actions/checkout@v2 48 | - run: go test -race ./... 49 | -------------------------------------------------------------------------------- /resp/resp.go: -------------------------------------------------------------------------------- 1 | // Package resp is an umbrella package which covers both the old RESP protocol 2 | // (resp2) and the new one (resp3), allowing clients to choose which one they 3 | // care to use 4 | package resp 5 | 6 | import ( 7 | "bufio" 8 | "io" 9 | ) 10 | 11 | // Marshaler is the interface implemented by types that can marshal themselves 12 | // into valid RESP. 13 | type Marshaler interface { 14 | MarshalRESP(io.Writer) error 15 | } 16 | 17 | // Unmarshaler is the interface implemented by types that can unmarshal a RESP 18 | // description of themselves. UnmarshalRESP should _always_ fully consume a RESP 19 | // message off the reader, unless there is an error returned from the reader 20 | // itself. 21 | // 22 | // Note that, unlike Marshaler, Unmarshaler _must_ take in a *bufio.Reader. 23 | type Unmarshaler interface { 24 | UnmarshalRESP(*bufio.Reader) error 25 | } 26 | 27 | // ErrDiscarded is used to wrap an error encountered while unmarshaling a 28 | // message. If an error was encountered during unmarshaling but the rest of the 29 | // message was successfully discarded off of the wire, then the error can be 30 | // wrapped in this type. 31 | type ErrDiscarded struct { 32 | Err error 33 | } 34 | 35 | func (ed ErrDiscarded) Error() string { 36 | return ed.Err.Error() 37 | } 38 | 39 | // Unwrap implements the errors.Wrapper interface. 
40 | func (ed ErrDiscarded) Unwrap() error { 41 | return ed.Err 42 | } 43 | -------------------------------------------------------------------------------- /trace/cluster.go: -------------------------------------------------------------------------------- 1 | // Package trace contains all the types provided for tracing within the radix 2 | // package. With tracing a user is able to pull out fine-grained runtime events 3 | // as they happen, which is useful for gathering metrics, logging, performance 4 | // analysis, etc... 5 | // 6 | // BIG LOUD DISCLAIMER DO NOT IGNORE THIS: while the main radix package is 7 | // stable and will always remain backwards compatible, trace is still under 8 | // active development and may undergo changes to its types and other features. 9 | // The methods in the main radix package which invoke trace types are guaranteed 10 | // to remain stable. 11 | package trace 12 | 13 | //////////////////////////////////////////////////////////////////////////////// 14 | 15 | type ClusterTrace struct { 16 | // StateChange is called when the cluster becomes down or becomes available again. 17 | StateChange func(ClusterStateChange) 18 | // TopoChanged is called when the cluster's topology changed. 19 | TopoChanged func(ClusterTopoChanged) 20 | // Redirected is called when radix.Do responded 'MOVED' or 'ASKED'. 21 | Redirected func(ClusterRedirected) 22 | } 23 | 24 | type ClusterStateChange struct { 25 | IsDown bool 26 | } 27 | 28 | type ClusterNodeInfo struct { 29 | Addr string 30 | Slots [][2]uint16 31 | IsPrimary bool 32 | } 33 | 34 | type ClusterTopoChanged struct { 35 | Added []ClusterNodeInfo 36 | Removed []ClusterNodeInfo 37 | Changed []ClusterNodeInfo 38 | } 39 | 40 | type ClusterRedirected struct { 41 | Addr string 42 | Key string 43 | Moved, Ask bool 44 | RedirectCount int 45 | 46 | // If true, then the MOVED/ASK error which was received will not be honored, 47 | // and the call to Do will be returning the MOVED/ASK error. 
48 | Final bool 49 | } 50 | -------------------------------------------------------------------------------- /bench/go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/garyburd/redigo v1.6.0 h1:0VruCpn7yAIIu7pWVClQC8wxCJEcG3nyzpMSHKi1PQc= 4 | github.com/garyburd/redigo v1.6.0/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY= 5 | github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0= 6 | github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= 7 | github.com/joomcode/errorx v0.8.0 h1:GhAqPtcYuo1O7TOIbtzEIDzPGQ3SrKJ3tdjXNmUtDNo= 8 | github.com/joomcode/errorx v0.8.0/go.mod h1:kgco15ekB6cs+4Xjzo7SPeXzx38PbJzBwbnu9qfVNHQ= 9 | github.com/joomcode/redispipe v0.9.0 h1:NukwwIvxhg6r2lVxa1RJhEZXYPZZF/OX9WZJk+2cK1Q= 10 | github.com/joomcode/redispipe v0.9.0/go.mod h1:4S/gpBCZ62pB/3+XLNWDH7jQnB0vxmpddAMBva2adpM= 11 | github.com/mediocregopher/mediocre-go-lib v0.0.0-20181029021733-cb65787f37ed h1:3dQJqqDouawQgl3gBE1PNHKFkJYGEuFb1DbSlaxdosE= 12 | github.com/mediocregopher/mediocre-go-lib v0.0.0-20181029021733-cb65787f37ed/go.mod h1:dSsfyI2zABAdhcbvkXqgxOxrCsbYeHCPgrZkku60dSg= 13 | github.com/mediocregopher/radix.v2 v0.0.0-20181115013041-b67df6e626f9 h1:ViNuGS149jgnttqhc6XQNPwdupEMBXqCx9wtlW7P3sA= 14 | github.com/mediocregopher/radix.v2 v0.0.0-20181115013041-b67df6e626f9/go.mod h1:fLRUbhbSd5Px2yKUaGYYPltlyxi1guJz1vCmo1RQL50= 15 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 16 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 17 | github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= 18 | github.com/stretchr/testify v1.2.2/go.mod 
h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 19 | golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 h1:bhOzK9QyoD0ogCnFro1m2mz41+Ib0oOhfJnBp5MR4K4= 20 | golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 21 | -------------------------------------------------------------------------------- /internal/bytesutil/bytesutil_test.go: -------------------------------------------------------------------------------- 1 | package bytesutil 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | crand "crypto/rand" 7 | "io" 8 | "math/rand" 9 | . "testing" 10 | "time" 11 | 12 | "github.com/stretchr/testify/assert" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | func TestReadNAppend(t *T) { 17 | buf := []byte("hello") 18 | buf, err := ReadNAppend(bytes.NewReader([]byte(" world!")), buf, len(" world")) 19 | require.Nil(t, err) 20 | assert.Len(t, buf, len("hello world")) 21 | assert.Equal(t, buf, []byte("hello world")) 22 | } 23 | 24 | type discarder struct { 25 | didDiscard bool 26 | *bufio.Reader 27 | } 28 | 29 | func (d *discarder) Discard(n int) (int, error) { 30 | d.didDiscard = true 31 | return d.Reader.Discard(n) 32 | } 33 | 34 | func TestReadNDiscard(t *T) { 35 | rand.Seed(time.Now().UnixNano()) 36 | type testT struct { 37 | n int 38 | discarder bool 39 | } 40 | 41 | assert := func(test testT) { 42 | buf := bytes.NewBuffer(make([]byte, 0, test.n)) 43 | if _, err := io.CopyN(buf, crand.Reader, int64(test.n)); err != nil { 44 | t.Fatal(err) 45 | } 46 | 47 | var r io.Reader = buf 48 | var d *discarder 49 | if test.discarder { 50 | d = &discarder{Reader: bufio.NewReader(r)} 51 | r = d 52 | } 53 | 54 | if err := ReadNDiscard(r, test.n); err != nil { 55 | t.Fatalf("error calling readNDiscard: %s (%#v)", err, test) 56 | 57 | } else if test.discarder && test.n > 0 && !d.didDiscard { 58 | t.Fatalf("Discard not called on discarder (%#v)", test) 59 | 60 | } else if test.discarder && test.n == 0 && d.didDiscard { 61 
| t.Fatalf("Unnecessary Discard call (%#v)", test) 62 | 63 | } else if buf.Len() > 0 { 64 | t.Fatalf("%d bytes not discarded (%#v)", buf.Len(), test) 65 | } 66 | } 67 | 68 | // randomly generate test cases 69 | for i := 0; i < 1000; i++ { 70 | test := testT{ 71 | n: rand.Intn(16384), 72 | discarder: rand.Intn(2) == 0, 73 | } 74 | assert(test) 75 | } 76 | 77 | // edge cases 78 | assert(testT{n: 0, discarder: true}) 79 | } 80 | -------------------------------------------------------------------------------- /resp/resp2/bench_test.go: -------------------------------------------------------------------------------- 1 | package resp2 2 | 3 | import ( 4 | "bufio" 5 | "fmt" 6 | "strings" 7 | "testing" 8 | ) 9 | 10 | func BenchmarkIntUnmarshalRESP(b *testing.B) { 11 | tests := []struct { 12 | In string 13 | }{ 14 | {"-1"}, 15 | {"-123"}, 16 | {"1"}, 17 | {"123"}, 18 | {"+1"}, 19 | {"+123"}, 20 | } 21 | 22 | for _, test := range tests { 23 | input := ":" + test.In + "\r\n" 24 | 25 | b.Run(fmt.Sprint(test.In), func(b *testing.B) { 26 | var sr strings.Reader 27 | br := bufio.NewReader(&sr) 28 | 29 | for i := 0; i < b.N; i++ { 30 | sr.Reset(input) 31 | br.Reset(&sr) 32 | 33 | var i Int 34 | if err := i.UnmarshalRESP(br); err != nil { 35 | b.Fatalf("failed to unmarshal %q: %s", input, err) 36 | } 37 | } 38 | }) 39 | } 40 | } 41 | 42 | func BenchmarkAnyUnmarshalRESP(b *testing.B) { 43 | b.Run("Map", func(b *testing.B) { 44 | b.ReportAllocs() 45 | 46 | const input = "*8\r\n" + 47 | "$3\r\nFoo\r\n" + "$1\r\n1\r\n" + 48 | "$3\r\nBAZ\r\n" + "$2\r\n22\r\n" + 49 | "$3\r\nBoz\r\n" + "$4\r\n4444\r\n" + 50 | "$3\r\nBiz\r\n" + "$8\r\n88888888\r\n" 51 | 52 | var sr strings.Reader 53 | br := bufio.NewReader(&sr) 54 | 55 | for i := 0; i < b.N; i++ { 56 | sr.Reset(input) 57 | br.Reset(&sr) 58 | 59 | var m map[string]string 60 | if err := (Any{I: &m}).UnmarshalRESP(br); err != nil { 61 | b.Fatalf("failed to unmarshal %q: %s", input, err) 62 | } 63 | } 64 | }) 65 | 66 | b.Run("Struct", func(b 
*testing.B) { 67 | b.ReportAllocs() 68 | 69 | const input = "*8\r\n" + 70 | "$3\r\nFoo\r\n" + ":1\r\n" + 71 | "$3\r\nBAZ\r\n" + "$1\r\n3\r\n" + 72 | "$3\r\nBoz\r\n" + ":5\r\n" + 73 | "$3\r\nBiz\r\n" + "$2\r\n10\r\n" 74 | 75 | var sr strings.Reader 76 | br := bufio.NewReader(&sr) 77 | 78 | for i := 0; i < b.N; i++ { 79 | sr.Reset(input) 80 | br.Reset(&sr) 81 | 82 | var s testStructA 83 | if err := (Any{I: &s}).UnmarshalRESP(br); err != nil { 84 | b.Fatalf("failed to unmarshal %q: %s", input, err) 85 | } 86 | } 87 | }) 88 | } 89 | -------------------------------------------------------------------------------- /cluster_scanner.go: -------------------------------------------------------------------------------- 1 | package radix 2 | 3 | import ( 4 | "strings" 5 | ) 6 | 7 | type clusterScanner struct { 8 | cluster *Cluster 9 | opts ScanOpts 10 | 11 | addrs []string 12 | currScanner Scanner 13 | lastErr error 14 | } 15 | 16 | // NewScanner will return a Scanner which will scan over every node in the 17 | // cluster. This will panic if the ScanOpt's Command isn't "SCAN". For scanning 18 | // operations other than "SCAN" (e.g. "HSCAN", "ZSCAN") use the normal 19 | // NewScanner function. 20 | // 21 | // If the cluster topology changes during a scan the Scanner may or may not 22 | // error out due to it, depending on the nature of the change. 
23 | func (c *Cluster) NewScanner(o ScanOpts) Scanner { 24 | if strings.ToUpper(o.Command) != "SCAN" { 25 | panic("Cluster.NewScanner can only perform SCAN operations") 26 | } 27 | 28 | var addrs []string 29 | for _, node := range c.Topo().Primaries() { 30 | addrs = append(addrs, node.Addr) 31 | } 32 | 33 | cs := &clusterScanner{ 34 | cluster: c, 35 | opts: o, 36 | addrs: addrs, 37 | } 38 | cs.nextScanner() 39 | 40 | return cs 41 | } 42 | 43 | func (cs *clusterScanner) closeCurr() { 44 | if cs.currScanner != nil { 45 | if err := cs.currScanner.Close(); err != nil && cs.lastErr == nil { 46 | cs.lastErr = err 47 | } 48 | cs.currScanner = nil 49 | } 50 | } 51 | 52 | func (cs *clusterScanner) scannerForAddr(addr string) bool { 53 | client, _ := cs.cluster.rpool(addr) 54 | if client != nil { 55 | cs.closeCurr() 56 | cs.currScanner = NewScanner(client, cs.opts) 57 | return true 58 | } 59 | return false 60 | } 61 | 62 | func (cs *clusterScanner) nextScanner() { 63 | for { 64 | if len(cs.addrs) == 0 { 65 | cs.closeCurr() 66 | return 67 | } 68 | addr := cs.addrs[0] 69 | cs.addrs = cs.addrs[1:] 70 | if cs.scannerForAddr(addr) { 71 | return 72 | } 73 | } 74 | } 75 | 76 | func (cs *clusterScanner) Next(res *string) bool { 77 | for { 78 | if cs.currScanner == nil { 79 | return false 80 | } else if out := cs.currScanner.Next(res); out { 81 | return true 82 | } 83 | cs.nextScanner() 84 | } 85 | } 86 | 87 | func (cs *clusterScanner) Close() error { 88 | cs.closeCurr() 89 | return cs.lastErr 90 | } 91 | -------------------------------------------------------------------------------- /internal/bytesutil/bench_test.go: -------------------------------------------------------------------------------- 1 | package bytesutil 2 | 3 | import ( 4 | "fmt" 5 | "strconv" 6 | "strings" 7 | "testing" 8 | ) 9 | 10 | var bfloat float64 11 | 12 | func BenchmarkReadFloat(b *testing.B) { 13 | tests := []struct { 14 | In string 15 | N int 16 | }{ 17 | {"1", 1}, 18 | {"1.23", 4}, 19 | {"-1", 2}, 20 | 
{"-1.23", 5}, 21 | {"+1", 2}, 22 | {"+1.23", 5}, 23 | } 24 | 25 | for _, test := range tests { 26 | input, n := test.In, test.N 27 | 28 | b.Run(fmt.Sprint(test.In), func(b *testing.B) { 29 | var r strings.Reader 30 | 31 | for i := 0; i < b.N; i++ { 32 | r.Reset(input) 33 | bfloat, _ = ReadFloat(&r, 64, n) 34 | } 35 | }) 36 | } 37 | } 38 | 39 | var bint int64 40 | 41 | func BenchmarkReadInt(b *testing.B) { 42 | tests := []struct { 43 | In string 44 | N int 45 | }{ 46 | {"1", 1}, 47 | {"123", 3}, 48 | {"-1", 2}, 49 | {"-123", 4}, 50 | {"+1", 2}, 51 | {"+123", 4}, 52 | } 53 | 54 | for _, test := range tests { 55 | input, n := test.In, test.N 56 | 57 | b.Run(fmt.Sprint(test.In), func(b *testing.B) { 58 | var r strings.Reader 59 | 60 | for i := 0; i < b.N; i++ { 61 | r.Reset(input) 62 | bint, _ = ReadInt(&r, n) 63 | } 64 | }) 65 | } 66 | } 67 | 68 | var buint uint64 69 | 70 | func BenchmarkReadUint(b *testing.B) { 71 | tests := []struct { 72 | In string 73 | N int 74 | }{ 75 | {"1", 1}, 76 | {"123", 123}, 77 | } 78 | 79 | for _, test := range tests { 80 | input, n := test.In, test.N 81 | 82 | b.Run(fmt.Sprint(test.In), func(b *testing.B) { 83 | var r strings.Reader 84 | 85 | for i := 0; i < b.N; i++ { 86 | r.Reset(input) 87 | buint, _ = ReadUint(&r, n) 88 | } 89 | }) 90 | } 91 | } 92 | 93 | type nothingReader struct{} 94 | 95 | func (nothingReader) Read(p []byte) (n int, err error) { 96 | return len(p), nil 97 | } 98 | 99 | func BenchmarkReadNAppend(b *testing.B) { 100 | for _, n := range []int{0, 64, 512, 4096} { 101 | b.Run("N="+strconv.Itoa(n), func(b *testing.B) { 102 | var r nothingReader 103 | buf := *GetBytes() 104 | 105 | b.ResetTimer() 106 | 107 | for i := 0; i < b.N; i++ { 108 | if _, err := ReadNAppend(&r, buf, n); err != nil { 109 | b.Fatal(err) 110 | } 111 | } 112 | }) 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /cluster_crc16.go: 
-------------------------------------------------------------------------------- 1 | package radix 2 | 3 | import ( 4 | "bytes" 5 | ) 6 | 7 | var tab = [256]uint16{ 8 | 0x0000, 0x1021, 0x2042, 0x3063, 0x4084, 0x50a5, 0x60c6, 0x70e7, 9 | 0x8108, 0x9129, 0xa14a, 0xb16b, 0xc18c, 0xd1ad, 0xe1ce, 0xf1ef, 10 | 0x1231, 0x0210, 0x3273, 0x2252, 0x52b5, 0x4294, 0x72f7, 0x62d6, 11 | 0x9339, 0x8318, 0xb37b, 0xa35a, 0xd3bd, 0xc39c, 0xf3ff, 0xe3de, 12 | 0x2462, 0x3443, 0x0420, 0x1401, 0x64e6, 0x74c7, 0x44a4, 0x5485, 13 | 0xa56a, 0xb54b, 0x8528, 0x9509, 0xe5ee, 0xf5cf, 0xc5ac, 0xd58d, 14 | 0x3653, 0x2672, 0x1611, 0x0630, 0x76d7, 0x66f6, 0x5695, 0x46b4, 15 | 0xb75b, 0xa77a, 0x9719, 0x8738, 0xf7df, 0xe7fe, 0xd79d, 0xc7bc, 16 | 0x48c4, 0x58e5, 0x6886, 0x78a7, 0x0840, 0x1861, 0x2802, 0x3823, 17 | 0xc9cc, 0xd9ed, 0xe98e, 0xf9af, 0x8948, 0x9969, 0xa90a, 0xb92b, 18 | 0x5af5, 0x4ad4, 0x7ab7, 0x6a96, 0x1a71, 0x0a50, 0x3a33, 0x2a12, 19 | 0xdbfd, 0xcbdc, 0xfbbf, 0xeb9e, 0x9b79, 0x8b58, 0xbb3b, 0xab1a, 20 | 0x6ca6, 0x7c87, 0x4ce4, 0x5cc5, 0x2c22, 0x3c03, 0x0c60, 0x1c41, 21 | 0xedae, 0xfd8f, 0xcdec, 0xddcd, 0xad2a, 0xbd0b, 0x8d68, 0x9d49, 22 | 0x7e97, 0x6eb6, 0x5ed5, 0x4ef4, 0x3e13, 0x2e32, 0x1e51, 0x0e70, 23 | 0xff9f, 0xefbe, 0xdfdd, 0xcffc, 0xbf1b, 0xaf3a, 0x9f59, 0x8f78, 24 | 0x9188, 0x81a9, 0xb1ca, 0xa1eb, 0xd10c, 0xc12d, 0xf14e, 0xe16f, 25 | 0x1080, 0x00a1, 0x30c2, 0x20e3, 0x5004, 0x4025, 0x7046, 0x6067, 26 | 0x83b9, 0x9398, 0xa3fb, 0xb3da, 0xc33d, 0xd31c, 0xe37f, 0xf35e, 27 | 0x02b1, 0x1290, 0x22f3, 0x32d2, 0x4235, 0x5214, 0x6277, 0x7256, 28 | 0xb5ea, 0xa5cb, 0x95a8, 0x8589, 0xf56e, 0xe54f, 0xd52c, 0xc50d, 29 | 0x34e2, 0x24c3, 0x14a0, 0x0481, 0x7466, 0x6447, 0x5424, 0x4405, 30 | 0xa7db, 0xb7fa, 0x8799, 0x97b8, 0xe75f, 0xf77e, 0xc71d, 0xd73c, 31 | 0x26d3, 0x36f2, 0x0691, 0x16b0, 0x6657, 0x7676, 0x4615, 0x5634, 32 | 0xd94c, 0xc96d, 0xf90e, 0xe92f, 0x99c8, 0x89e9, 0xb98a, 0xa9ab, 33 | 0x5844, 0x4865, 0x7806, 0x6827, 0x18c0, 0x08e1, 0x3882, 0x28a3, 34 | 0xcb7d, 0xdb5c, 0xeb3f, 0xfb1e, 
0x8bf9, 0x9bd8, 0xabbb, 0xbb9a, 35 | 0x4a75, 0x5a54, 0x6a37, 0x7a16, 0x0af1, 0x1ad0, 0x2ab3, 0x3a92, 36 | 0xfd2e, 0xed0f, 0xdd6c, 0xcd4d, 0xbdaa, 0xad8b, 0x9de8, 0x8dc9, 37 | 0x7c26, 0x6c07, 0x5c64, 0x4c45, 0x3ca2, 0x2c83, 0x1ce0, 0x0cc1, 38 | 0xef1f, 0xff3e, 0xcf5d, 0xdf7c, 0xaf9b, 0xbfba, 0x8fd9, 0x9ff8, 39 | 0x6e17, 0x7e36, 0x4e55, 0x5e74, 0x2e93, 0x3eb2, 0x0ed1, 0x1ef0, 40 | } 41 | 42 | const numSlots = 16384 43 | 44 | // CRC16 returns checksum for a given set of bytes based on the crc algorithm 45 | // defined for hashing redis keys in a cluster setup. 46 | func CRC16(buf []byte) uint16 { 47 | crc := uint16(0) 48 | for _, b := range buf { 49 | index := byte(crc>>8) ^ b 50 | crc = (crc << 8) ^ tab[index] 51 | } 52 | return crc 53 | } 54 | 55 | // ClusterSlot returns the slot number the key belongs to in any redis cluster, 56 | // taking into account key hash tags. 57 | func ClusterSlot(key []byte) uint16 { 58 | if start := bytes.Index(key, []byte("{")); start >= 0 { 59 | if end := bytes.Index(key[start+1:], []byte("}")); end > 0 { 60 | key = key[start+1 : start+1+end] 61 | } 62 | } 63 | return CRC16(key) % numSlots 64 | } 65 | -------------------------------------------------------------------------------- /pubsub_persistent_test.go: -------------------------------------------------------------------------------- 1 | package radix 2 | 3 | import ( 4 | . 
"testing" 5 | "time" 6 | 7 | "errors" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func closablePersistentPubSub() (PubSubConn, func()) { 13 | closeCh := make(chan chan bool) 14 | p := PersistentPubSub("", "", func(_, _ string) (Conn, error) { 15 | c := dial() 16 | go func() { 17 | closeRetCh := <-closeCh 18 | c.Close() 19 | closeRetCh <- true 20 | }() 21 | return c, nil 22 | }) 23 | 24 | return p, func() { 25 | closeRetCh := make(chan bool) 26 | closeCh <- closeRetCh 27 | <-closeRetCh 28 | } 29 | } 30 | 31 | func TestPersistentPubSub(t *T) { 32 | p, closeFn := closablePersistentPubSub() 33 | pubCh := make(chan int) 34 | go func() { 35 | for i := 0; i < 1000; i++ { 36 | pubCh <- i 37 | if i%100 == 0 { 38 | time.Sleep(100 * time.Millisecond) 39 | closeFn() 40 | assert.Nil(t, p.Ping()) 41 | } 42 | } 43 | close(pubCh) 44 | }() 45 | 46 | testSubscribe(t, p, pubCh) 47 | } 48 | 49 | func TestPersistentPubSubAbortAfter(t *T) { 50 | var errNope = errors.New("nope") 51 | var attempts int 52 | connFn := func(_, _ string) (Conn, error) { 53 | attempts++ 54 | if attempts%3 != 0 { 55 | return nil, errNope 56 | } 57 | return dial(), nil 58 | } 59 | 60 | _, err := PersistentPubSubWithOpts("", "", 61 | PersistentPubSubConnFunc(connFn), 62 | PersistentPubSubAbortAfter(2)) 63 | assert.Equal(t, errNope, err) 64 | 65 | attempts = 0 66 | p, err := PersistentPubSubWithOpts("", "", 67 | PersistentPubSubConnFunc(connFn), 68 | PersistentPubSubAbortAfter(3)) 69 | assert.NoError(t, err) 70 | assert.NoError(t, p.Ping()) 71 | p.Close() 72 | } 73 | 74 | // https://github.com/mediocregopher/radix/issues/184 75 | func TestPersistentPubSubClose(t *T) { 76 | channel := "TestPersistentPubSubClose:" + randStr() 77 | 78 | stopCh := make(chan struct{}) 79 | defer close(stopCh) 80 | go func() { 81 | pubConn := dial() 82 | for { 83 | err := pubConn.Do(Cmd(nil, "PUBLISH", channel, randStr())) 84 | assert.NoError(t, err) 85 | time.Sleep(10 * time.Millisecond) 86 | 87 | select { 88 | case 
<-stopCh: 89 | return 90 | default: 91 | } 92 | } 93 | }() 94 | 95 | for i := 0; i < 1000; i++ { 96 | p := PersistentPubSub("", "", func(_, _ string) (Conn, error) { 97 | return dial(), nil 98 | }) 99 | msgCh := make(chan PubSubMessage) 100 | err := p.Subscribe(msgCh, channel) 101 | assert.NoError(t, err) 102 | // drain msgCh till it closes 103 | go func() { 104 | for range msgCh { 105 | } 106 | }() 107 | p.Close() 108 | close(msgCh) 109 | } 110 | } 111 | 112 | func TestPersistentPubSubUseAfterCloseDeadlock(t *T) { 113 | channel := "TestPersistentPubSubUseAfterCloseDeadlock:" + randStr() 114 | 115 | p := PersistentPubSub("", "", func(_, _ string) (Conn, error) { 116 | return dial(), nil 117 | }) 118 | msgCh := make(chan PubSubMessage) 119 | err := p.Subscribe(msgCh, channel) 120 | assert.NoError(t, err) 121 | p.Close() 122 | 123 | errch := make(chan error) 124 | go func() { 125 | errch <- p.PUnsubscribe(msgCh, channel) 126 | }() 127 | 128 | select { 129 | case <-time.After(time.Second): 130 | assert.Fail(t, "PUnsubscribe call timeout") 131 | case err := <-errch: 132 | assert.Error(t, err) 133 | } 134 | 135 | } 136 | -------------------------------------------------------------------------------- /stub_test.go: -------------------------------------------------------------------------------- 1 | package radix 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net" 7 | "strconv" 8 | "sync" 9 | . 
// Watching the watchmen

// testStub returns a stub Conn implementing GET/SET (backed by an in-memory
// map) and ECHO, erroring on any other command.
func testStub() Conn {
	m := map[string]string{}
	return Stub("tcp", "127.0.0.1:6379", func(args []string) interface{} {
		switch args[0] {
		case "GET":
			return m[args[1]]
		case "SET":
			m[args[1]] = args[2]
			return nil
		case "ECHO":
			return args[1]
		default:
			return fmt.Errorf("testStub doesn't support command %q", args[0])
		}
	})
}

// TestStub sanity-checks that Do round-trips through a stub Conn.
func TestStub(t *T) {
	stub := testStub()

	{ // Basic test
		var foo string
		require.Nil(t, stub.Do(Cmd(nil, "SET", "foo", "a")))
		require.Nil(t, stub.Do(Cmd(&foo, "GET", "foo")))
		assert.Equal(t, "a", foo)
	}

	{ // Basic test with an int, to ensure marshalling/unmarshalling all works
		var foo int
		require.Nil(t, stub.Do(FlatCmd(nil, "SET", "foo", 1)))
		require.Nil(t, stub.Do(Cmd(&foo, "GET", "foo")))
		assert.Equal(t, 1, foo)
	}
}

// TestStubPipeline checks that a pipelined SET/GET works against a stub.
func TestStubPipeline(t *T) {
	stub := testStub()
	var out string
	err := stub.Do(Pipeline(
		Cmd(nil, "SET", "foo", "bar"),
		Cmd(&out, "GET", "foo"),
	))

	require.Nil(t, err)
	assert.Equal(t, "bar", out)
}

// TestStubLockingTimeout runs concurrent Encode and Decode goroutines against
// the stub's internal buffer, then exercises read-deadline (timeout) behavior.
func TestStubLockingTimeout(t *T) {
	stub := testStub()
	wg := new(sync.WaitGroup)
	c := 1000

	// writer: encode c ECHO commands as fast as possible.
	wg.Add(1)
	go func() {
		defer wg.Done()
		for i := 0; i < c; i++ {
			require.Nil(t, stub.Encode(Cmd(nil, "ECHO", strconv.Itoa(i))))
		}
	}()

	// reader: decode the c replies concurrently, asserting order is preserved.
	wg.Add(1)
	go func() {
		defer wg.Done()
		for i := 0; i < c; i++ {
			var j int
			require.Nil(t, stub.Decode(resp2.Any{I: &j}))
			assert.Equal(t, i, j)
		}
	}()

	wg.Wait()

	// test out timeout. do a write-then-read to ensure nothing bad happens
	// when there's actually data to read
	now := time.Now()
	conn := stub.NetConn()
	err := conn.SetDeadline(now.Add(2 * time.Second))
	assert.NoError(t, err)
	require.Nil(t, stub.Encode(Cmd(nil, "ECHO", "1")))
	require.Nil(t, stub.Decode(resp2.Any{}))

	// now there's no data to read, should return after 2-ish seconds with a
	// timeout error
	err = stub.Decode(resp2.Any{})

	var nerr *net.OpError
	assert.True(t, errors.As(err, &nerr))
	assert.True(t, nerr.Timeout())
}

func ExampleStub() {
	m := map[string]string{}
	stub := Stub("tcp", "127.0.0.1:6379", func(args []string) interface{} {
		switch args[0] {
		case "GET":
			return m[args[1]]
		case "SET":
			m[args[1]] = args[2]
			return nil
		default:
			return fmt.Errorf("this stub doesn't support command %q", args[0])
		}
	})

	if err := stub.Do(Cmd(nil, "SET", "foo", "1")); err != nil {
		// handle error
	}

	var foo int
	if err := stub.Do(Cmd(&foo, "GET", "foo")); err != nil {
		// handle error
	}

	fmt.Printf("foo: %d\n", foo)
}
// TestPubSubStub walks a PubSubStub through the full pubsub lifecycle:
// normal commands, (P)SUBSCRIBE, message/pmessage delivery, PING while in
// pubsub mode, and (P)UNSUBSCRIBE. The numeric third element in each
// (p)(un)subscribe reply is the stub's current total subscription count.
func TestPubSubStub(t *T) {
	conn, stubCh := PubSubStub("tcp", "127.0.0.1:6379", func(in []string) interface{} {
		return in
	})
	// message/pmessage push a PubSubMessage into the stub, then wait on
	// mDoneCh so the stub has definitely processed it before we assert.
	message := func(channel, val string) {
		stubCh <- PubSubMessage{Type: "message", Channel: channel, Message: []byte(val)}
		<-conn.(*pubSubStub).mDoneCh
	}
	pmessage := func(pattern, channel, val string) {
		stubCh <- PubSubMessage{Type: "pmessage", Pattern: pattern, Channel: channel, Message: []byte(val)}
		<-conn.(*pubSubStub).mDoneCh
	}

	assertEncode := func(in ...string) {
		require.Nil(t, conn.Encode(resp2.Any{I: in}))
	}
	assertDecode := func(exp ...string) {
		var into []string
		require.Nil(t, conn.Decode(resp2.Any{I: &into}))
		assert.Equal(t, exp, into)
	}

	assertEncode("foo")
	assertDecode("foo")

	// shouldn't do anything
	message("foo", "a")

	assertEncode("SUBSCRIBE", "foo", "bar")
	assertDecode("subscribe", "foo", "1")
	assertDecode("subscribe", "bar", "2")

	// should error because we're in pubsub mode
	assertEncode("wat")
	assert.Equal(t, errPubSubMode.Error(), conn.Decode(resp2.Any{}).Error())

	assertEncode("PING")
	assertDecode("pong", "")

	// "baz" isn't subscribed to, so its message should be dropped.
	message("foo", "b")
	message("bar", "c")
	message("baz", "c")
	assertDecode("message", "foo", "b")
	assertDecode("message", "bar", "c")

	assertEncode("PSUBSCRIBE", "b*z")
	assertDecode("psubscribe", "b*z", "3")
	assertEncode("PSUBSCRIBE", "b[au]z")
	assertDecode("psubscribe", "b[au]z", "4")
	pmessage("b*z", "buz", "d")
	pmessage("b[au]z", "buz", "d")
	pmessage("b*z", "biz", "e")
	assertDecode("pmessage", "b*z", "buz", "d")
	assertDecode("pmessage", "b[au]z", "buz", "d")
	assertDecode("pmessage", "b*z", "biz", "e")

	assertEncode("UNSUBSCRIBE", "foo")
	assertDecode("unsubscribe", "foo", "3")
	// "foo" is no longer subscribed; only "bar" should come through.
	message("foo", "f")
	message("bar", "g")
	assertDecode("message", "bar", "g")

	assertEncode("UNSUBSCRIBE", "bar")
	assertDecode("unsubscribe", "bar", "2")
	assertEncode("PUNSUBSCRIBE", "b*z")
	assertDecode("punsubscribe", "b*z", "1")
	assertEncode("PUNSUBSCRIBE", "b[au]z")
	assertDecode("punsubscribe", "b[au]z", "0")

	// No longer in pubsub mode, normal requests should work again
	assertEncode("wat")
	assertDecode("wat")
}

func ExamplePubSubStub() {
	// Make a pubsub stub conn which will return nil for everything except
	// pubsub commands (which will be handled automatically)
	stub, stubCh := PubSubStub("tcp", "127.0.0.1:6379", func([]string) interface{} {
		return nil
	})

	// These writes shouldn't do anything, initially, since we haven't
	// subscribed to anything
	go func() {
		for {
			stubCh <- PubSubMessage{
				Channel: "foo",
				Message: []byte("bar"),
			}
			time.Sleep(1 * time.Second)
		}
	}()

	// Use PubSub to wrap the stub like we would for a normal redis connection
	pstub := PubSub(stub)

	// Subscribe msgCh to "foo"
	msgCh := make(chan PubSubMessage)
	if err := pstub.Subscribe(msgCh, "foo"); err != nil {
		log.Fatal(err)
	}

	// now msgCh is subscribed the publishes being made by the go-routine above
	// will start being written to it
	for m := range msgCh {
		log.Printf("read m: %#v", m)
	}
}
| // Scanner is used to iterate through the results of a SCAN call (or HSCAN, 14 | // SSCAN, etc...) 15 | // 16 | // Once created, repeatedly call Next() on it to fill the passed in string 17 | // pointer with the next result. Next will return false if there's no more 18 | // results to retrieve or if an error occurred, at which point Close should be 19 | // called to retrieve any error. 20 | type Scanner interface { 21 | Next(*string) bool 22 | Close() error 23 | } 24 | 25 | // ScanOpts are various parameters which can be passed into ScanWithOpts. Some 26 | // fields are required depending on which type of scan is being done. 27 | type ScanOpts struct { 28 | // The scan command to do, e.g. "SCAN", "HSCAN", etc... 29 | Command string 30 | 31 | // The key to perform the scan on. Only necessary when Command isn't "SCAN" 32 | Key string 33 | 34 | // An optional pattern to filter returned keys by 35 | Pattern string 36 | 37 | // An optional count hint to send to redis to indicate number of keys to 38 | // return per call. This does not affect the actual results of the scan 39 | // command, but it may be useful for optimizing certain datasets 40 | Count int 41 | 42 | // An optional type name to filter for values of the given type. 43 | // The type names are the same as returned by the "TYPE" command. 44 | // This if only available in Redis 6 or newer and only works with "SCAN". 45 | // If used with an older version of Redis or with a Command other than 46 | // "SCAN", scanning will fail. 
47 | Type string 48 | } 49 | 50 | func (o ScanOpts) cmd(rcv interface{}, cursor string) CmdAction { 51 | cmdStr := strings.ToUpper(o.Command) 52 | args := make([]string, 0, 8) 53 | if cmdStr != "SCAN" { 54 | args = append(args, o.Key) 55 | } 56 | 57 | args = append(args, cursor) 58 | if o.Pattern != "" { 59 | args = append(args, "MATCH", o.Pattern) 60 | } 61 | if o.Count > 0 { 62 | args = append(args, "COUNT", strconv.Itoa(o.Count)) 63 | } 64 | if o.Type != "" { 65 | args = append(args, "TYPE", o.Type) 66 | } 67 | 68 | return Cmd(rcv, cmdStr, args...) 69 | } 70 | 71 | // ScanAllKeys is a shortcut ScanOpts which can be used to scan all keys. 72 | var ScanAllKeys = ScanOpts{ 73 | Command: "SCAN", 74 | } 75 | 76 | type scanner struct { 77 | Client 78 | ScanOpts 79 | res scanResult 80 | resIdx int 81 | err error 82 | } 83 | 84 | // NewScanner creates a new Scanner instance which will iterate over the redis 85 | // instance's Client using the ScanOpts. 86 | // 87 | // NOTE if Client is a *Cluster this will not work correctly, use the NewScanner 88 | // method on Cluster instead. 
func NewScanner(c Client, o ScanOpts) Scanner {
	return &scanner{
		Client:   c,
		ScanOpts: o,
		// start at cursor "0"; Next uses res.keys == nil to tell this initial
		// state apart from a finished scan (which also has cursor "0").
		res: scanResult{
			cur: "0",
		},
	}
}

// Next fills res with the next scan result, issuing further scan commands to
// the server as each buffered page is exhausted. It returns false once the
// scan is complete or an error occurred (retrieve it via Close).
func (s *scanner) Next(res *string) bool {
	for {
		if s.err != nil {
			return false
		}

		// drain the currently buffered page first.
		for s.resIdx < len(s.res.keys) {
			*res = s.res.keys[s.resIdx]
			s.resIdx++
			// NOTE(review): empty results are skipped here — presumably the
			// server never returns meaningful empty elements; confirm.
			if *res != "" {
				return true
			}
		}

		// cursor "0" from the server means the iteration is finished; keys
		// being non-nil distinguishes this from the not-yet-started state.
		if s.res.cur == "0" && s.res.keys != nil {
			return false
		}

		s.err = s.Client.Do(s.cmd(&s.res, s.res.cur))
		s.resIdx = 0
	}
}

// Close returns any error which occurred during iteration.
func (s *scanner) Close() error {
	return s.err
}

// scanResult holds one page of a scan reply: the next cursor plus the keys.
type scanResult struct {
	cur  string
	keys []string
}

// UnmarshalRESP unmarshals a two element array reply of [cursor, keys].
func (s *scanResult) UnmarshalRESP(br *bufio.Reader) error {
	var ah resp2.ArrayHeader
	if err := ah.UnmarshalRESP(br); err != nil {
		return err
	} else if ah.N != 2 {
		return errors.New("not enough parts returned")
	}

	var c resp2.BulkString
	if err := c.UnmarshalRESP(br); err != nil {
		return err
	}

	s.cur = c.S
	// reuse the keys slice's backing array across pages.
	s.keys = s.keys[:0]

	return (resp2.Any{I: &s.keys}).UnmarshalRESP(br)
}
"testing" 7 | "time" 8 | 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func TestCloseBehavior(t *T) { 14 | c := dial() 15 | 16 | // sanity check 17 | var out string 18 | require.Nil(t, c.Do(Cmd(&out, "ECHO", "foo"))) 19 | assert.Equal(t, "foo", out) 20 | 21 | c.Close() 22 | require.NotNil(t, c.Do(Cmd(&out, "ECHO", "foo"))) 23 | require.NotNil(t, c.NetConn().SetDeadline(time.Now())) 24 | } 25 | 26 | func TestDialURI(t *T) { 27 | c, err := Dial("tcp", "redis://127.0.0.1:6379") 28 | if err != nil { 29 | t.Fatal(err) 30 | } else if err := c.Do(Cmd(nil, "PING")); err != nil { 31 | t.Fatal(err) 32 | } 33 | } 34 | 35 | func TestDialAuth(t *T) { 36 | type testCase struct { 37 | url, dialOptUser, dialOptPass string 38 | } 39 | 40 | runTests := func(t *T, tests []testCase, allowedErrs []string) { 41 | for _, test := range tests { 42 | var opts []DialOpt 43 | if test.dialOptUser != "" { 44 | opts = append(opts, DialAuthUser(test.dialOptUser, test.dialOptPass)) 45 | } else if test.dialOptPass != "" { 46 | opts = append(opts, DialAuthPass(test.dialOptPass)) 47 | } 48 | _, err := Dial("tcp", test.url, opts...) 49 | 50 | // It's difficult to test _which_ password is being sent, but it's easy 51 | // enough to tell that one was sent because redis returns an error if one 52 | // isn't set in the config 53 | assert.Errorf(t, err, "expected authentication error, got nil") 54 | assert.Containsf(t, allowedErrs, err.Error(), "one of %v expected, got %v (test:%#v)", allowedErrs, err, test) 55 | } 56 | } 57 | 58 | t.Run("Password only", func(t *T) { 59 | runTests(t, []testCase{ 60 | {url: "redis://:myPass@127.0.0.1:6379"}, 61 | {url: "redis://127.0.0.1:6379?password=myPass"}, 62 | {url: "127.0.0.1:6379", dialOptPass: "myPass"}, 63 | }, []string{ 64 | "ERR Client sent AUTH, but no password is set", 65 | // Redis 6 only 66 | "ERR AUTH called without any password configured for the default user. 
Are you sure your configuration is correct?", 67 | }) 68 | }) 69 | 70 | t.Run("Username and password", func(t *T) { 71 | conn := dial() 72 | defer conn.Close() 73 | 74 | requireRedisVersion(t, conn, 6, 0, 0) 75 | 76 | runTests(t, []testCase{ 77 | {url: "redis://user:myPass@127.0.0.1:6379"}, 78 | {url: "redis://127.0.0.1:6379?username=mediocregopher"}, 79 | {url: "127.0.0.1:6379", dialOptUser: "mediocregopher"}, 80 | {url: "redis://127.0.0.1:6379?username=mediocregopher&password=myPass"}, 81 | {url: "127.0.0.1:6379", dialOptUser: "mediocregopher", dialOptPass: "myPass"}, 82 | }, []string{ 83 | "WRONGPASS invalid username-password pair", 84 | "WRONGPASS invalid username-password pair or user is disabled.", 85 | }) 86 | }) 87 | } 88 | 89 | func TestDialSelect(t *T) { 90 | 91 | // unfortunately this is the best way to discover the currently selected 92 | // database, and it's janky af 93 | assertDB := func(c Conn) bool { 94 | name := randStr() 95 | if err := c.Do(Cmd(nil, "CLIENT", "SETNAME", name)); err != nil { 96 | t.Fatal(err) 97 | } 98 | 99 | var list string 100 | if err := c.Do(Cmd(&list, "CLIENT", "LIST")); err != nil { 101 | t.Fatal(err) 102 | } 103 | 104 | line := regexp.MustCompile(".*name=" + name + ".*").FindString(list) 105 | if line == "" { 106 | t.Fatalf("line messed up:%q (list:%q name:%q)", line, list, name) 107 | } 108 | 109 | return strings.Index(line, " db=9 ") > 0 110 | } 111 | 112 | tests := []struct { 113 | url string 114 | dialOptSelect int 115 | }{ 116 | {url: "redis://127.0.0.1:6379/9"}, 117 | {url: "redis://127.0.0.1:6379?db=9"}, 118 | {url: "redis://127.0.0.1:6379", dialOptSelect: 9}, 119 | // DialOpt should overwrite URI 120 | {url: "redis://127.0.0.1:6379/8", dialOptSelect: 9}, 121 | } 122 | 123 | for _, test := range tests { 124 | var opts []DialOpt 125 | if test.dialOptSelect > 0 { 126 | opts = append(opts, DialSelectDB(test.dialOptSelect)) 127 | } 128 | c, err := Dial("tcp", test.url, opts...) 
129 | if err != nil { 130 | t.Fatalf("got err connecting:%v (test:%#v)", err, test) 131 | } 132 | 133 | if !assertDB(c) { 134 | t.Fatalf("db not set to 9 (test:%#v)", test) 135 | } 136 | } 137 | } 138 | -------------------------------------------------------------------------------- /pubsub_stub.go: -------------------------------------------------------------------------------- 1 | package radix 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "strings" 7 | "sync" 8 | 9 | "errors" 10 | 11 | "github.com/mediocregopher/radix/v3/resp" 12 | "github.com/mediocregopher/radix/v3/resp/resp2" 13 | ) 14 | 15 | var errPubSubMode = resp2.Error{ 16 | E: errors.New("ERR only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT allowed in this context"), 17 | } 18 | 19 | type multiMarshal []resp.Marshaler 20 | 21 | func (mm multiMarshal) MarshalRESP(w io.Writer) error { 22 | for _, m := range mm { 23 | if err := m.MarshalRESP(w); err != nil { 24 | return err 25 | } 26 | } 27 | return nil 28 | } 29 | 30 | type pubSubStub struct { 31 | Conn 32 | fn func([]string) interface{} 33 | inCh <-chan PubSubMessage 34 | 35 | closeOnce sync.Once 36 | closeCh chan struct{} 37 | closeErr error 38 | 39 | l sync.Mutex 40 | pubsubMode bool 41 | subbed, psubbed map[string]bool 42 | 43 | // this is only used for tests 44 | mDoneCh chan struct{} 45 | } 46 | 47 | // PubSubStub returns a (fake) Conn, much like Stub does, which pretends it is a 48 | // Conn to a real redis instance, but is instead using the given callback to 49 | // service requests. It is primarily useful for writing tests. 50 | // 51 | // PubSubStub differes from Stub in that Encode calls for (P)SUBSCRIBE, 52 | // (P)UNSUBSCRIBE, MESSAGE, and PING will be intercepted and handled as per 53 | // redis' expected pubsub functionality. 
// PubSubStub returns a (fake) Conn, much like Stub does, which pretends it is
// a Conn to a real redis instance, but is instead using the given callback to
// service requests. It is primarily useful for writing tests.
//
// PubSubStub differs from Stub in that Encode calls for (P)SUBSCRIBE,
// (P)UNSUBSCRIBE, MESSAGE, and PING will be intercepted and handled as per
// redis' expected pubsub functionality. A PubSubMessage may be written to the
// returned channel at any time, and if the PubSubStub has had (P)SUBSCRIBE
// called matching that PubSubMessage it will be written to the PubSubStub's
// internal buffer as expected.
//
// This is intended to be used so that it can mock services which can perform
// both normal redis commands and pubsub (e.g. a real redis instance, redis
// sentinel). Once created this stub can be passed into PubSub and treated like
// a real connection.
func PubSubStub(remoteNetwork, remoteAddr string, fn func([]string) interface{}) (Conn, chan<- PubSubMessage) {
	ch := make(chan PubSubMessage)
	s := &pubSubStub{
		fn:      fn,
		inCh:    ch,
		closeCh: make(chan struct{}),
		subbed:  map[string]bool{},
		psubbed: map[string]bool{},
		mDoneCh: make(chan struct{}, 1),
	}
	s.Conn = Stub(remoteNetwork, remoteAddr, s.innerFn)
	go s.spin()
	return s, ch
}

// innerFn is the command handler given to the underlying Stub. It implements
// the pubsub commands itself and delegates everything else to s.fn (unless in
// pubsub mode, in which case non-pubsub commands error).
func (s *pubSubStub) innerFn(ss []string) interface{} {
	s.l.Lock()
	defer s.l.Unlock()

	// writeRes appends a [cmd, subj, count] reply, where count is the current
	// total number of (p)subscriptions; pubsubMode is updated from it.
	writeRes := func(mm multiMarshal, cmd, subj string) multiMarshal {
		c := len(s.subbed) + len(s.psubbed)
		s.pubsubMode = c > 0
		return append(mm, resp2.Any{I: []interface{}{cmd, subj, c}})
	}

	switch strings.ToUpper(ss[0]) {
	case "PING":
		if !s.pubsubMode {
			return s.fn(ss)
		}
		return []string{"pong", ""}
	case "SUBSCRIBE":
		var mm multiMarshal
		for _, channel := range ss[1:] {
			s.subbed[channel] = true
			mm = writeRes(mm, "subscribe", channel)
		}
		return mm
	case "UNSUBSCRIBE":
		var mm multiMarshal
		for _, channel := range ss[1:] {
			delete(s.subbed, channel)
			mm = writeRes(mm, "unsubscribe", channel)
		}
		return mm
	case "PSUBSCRIBE":
		var mm multiMarshal
		for _, pattern := range ss[1:] {
			s.psubbed[pattern] = true
			mm = writeRes(mm, "psubscribe", pattern)
		}
		return mm
	case "PUNSUBSCRIBE":
		var mm multiMarshal
		for _, pattern := range ss[1:] {
			delete(s.psubbed, pattern)
			mm = writeRes(mm, "punsubscribe", pattern)
		}
		return mm
	case "MESSAGE":
		m := PubSubMessage{
			Type:    "message",
			Channel: ss[1],
			Message: []byte(ss[2]),
		}

		// only delivered if the channel is actually subscribed to.
		var mm multiMarshal
		if s.subbed[m.Channel] {
			mm = append(mm, m)
		}
		return mm
	case "PMESSAGE":
		m := PubSubMessage{
			Type:    "pmessage",
			Pattern: ss[1],
			Channel: ss[2],
			Message: []byte(ss[3]),
		}

		// only delivered if the exact pattern was psubscribed to.
		var mm multiMarshal
		if s.psubbed[m.Pattern] {
			mm = append(mm, m)
		}
		return mm
	default:
		if s.pubsubMode {
			return errPubSubMode
		}
		return s.fn(ss)
	}
}

// Close shuts down the spin goroutine and closes the underlying Stub exactly
// once; repeated calls return the first close's error.
func (s *pubSubStub) Close() error {
	s.closeOnce.Do(func() {
		close(s.closeCh)
		s.closeErr = s.Conn.Close()
	})
	return s.closeErr
}

// spin forwards PubSubMessages from the public channel into the stub as
// MESSAGE/PMESSAGE commands until Close is called. After each message it
// signals mDoneCh (non-blocking), which tests use for synchronization.
func (s *pubSubStub) spin() {
	for {
		select {
		case m, ok := <-s.inCh:
			if !ok {
				panic("PubSubStub message channel was closed")
			}
			// default the Type based on whether a Pattern was given.
			if m.Type == "" {
				if m.Pattern == "" {
					m.Type = "message"
				} else {
					m.Type = "pmessage"
				}
			}
			if err := s.Conn.Encode(m); err != nil {
				panic(fmt.Sprintf("error encoding message in PubSubStub: %s", err))
			}
			select {
			case s.mDoneCh <- struct{}{}:
			default:
			}
		case <-s.closeCh:
			return
		}
	}
}
// redisVersionPat extracts the major.minor.patch version from the "INFO
// server" reply.
var redisVersionPat = regexp.MustCompile(`(?m)^redis_version:(\d+)\.(\d+)\.(\d+).*$`)

// requireRedisVersion skips the calling test if the server's version is older
// than major.minor.patch.
func requireRedisVersion(tb TB, c Client, major, minor, patch int) {
	tb.Helper()

	var info string
	require.NoError(tb, c.Do(Cmd(&info, "INFO", "server")))

	m := redisVersionPat.FindStringSubmatch(info)
	if m == nil {
		panic("failed to get redis server version")
	}

	// Atoi errors are ignored; the regexp guarantees each group is digits.
	gotMajor, _ := strconv.Atoi(m[1])
	gotMinor, _ := strconv.Atoi(m[2])
	gotPatch, _ := strconv.Atoi(m[3])

	if gotMajor < major ||
		(gotMajor == major && gotMinor < minor) ||
		(gotMajor == major && gotMinor == minor && gotPatch < patch) {
		tb.Skipf("not supported with current redis version %d.%d.%d, need at least %d.%d.%d",
			gotMajor,
			gotMinor,
			gotPatch,
			major,
			minor,
			patch)
	}
}

// TestScanner checks that SCAN with a MATCH pattern returns exactly the keys
// which were inserted, and nothing for a non-matching pattern.
func TestScanner(t *T) {
	c := dial()

	// Make a random dataset
	prefix := randStr()
	fullMap := map[string]bool{}
	for i := 0; i < 100; i++ {
		key := prefix + ":" + strconv.Itoa(i)
		fullMap[key] = true
		require.Nil(t, c.Do(Cmd(nil, "SET", key, "1")))
	}

	// make sure we get all results when scanning with an existing prefix
	sc := NewScanner(c, ScanOpts{Command: "SCAN", Pattern: prefix + ":*"})
	var key string
	for sc.Next(&key) {
		delete(fullMap, key)
	}
	require.Nil(t, sc.Close())
	assert.Empty(t, fullMap)

	// make sure we don't get any results when scanning with a non-existing
	// prefix
	sc = NewScanner(c, ScanOpts{Command: "SCAN", Pattern: prefix + "DNE:*"})
	assert.False(t, sc.Next(nil))
	require.Nil(t, sc.Close())
}
72 | func TestScannerSet(t *T) { 73 | c := dial() 74 | 75 | key := randStr() 76 | fullMap := map[string]bool{} 77 | for i := 0; i < 100; i++ { 78 | elem := strconv.Itoa(i) 79 | fullMap[elem] = true 80 | require.Nil(t, c.Do(Cmd(nil, "SADD", key, elem))) 81 | } 82 | 83 | // make sure we get all results when scanning an existing set 84 | sc := NewScanner(c, ScanOpts{Command: "SSCAN", Key: key}) 85 | var val string 86 | for sc.Next(&val) { 87 | delete(fullMap, val) 88 | } 89 | require.Nil(t, sc.Close()) 90 | assert.Empty(t, fullMap) 91 | 92 | // make sure we don't get any results when scanning a non-existent set 93 | sc = NewScanner(c, ScanOpts{Command: "SSCAN", Key: key + "DNE"}) 94 | assert.False(t, sc.Next(nil)) 95 | require.Nil(t, sc.Close()) 96 | } 97 | 98 | func TestScannerType(t *T) { 99 | c := dial() 100 | requireRedisVersion(t, c, 6, 0, 0) 101 | 102 | for i := 0; i < 100; i++ { 103 | require.NoError(t, c.Do(Cmd(nil, "SET", randStr(), "string"))) 104 | require.NoError(t, c.Do(Cmd(nil, "LPUSH", randStr(), "list"))) 105 | require.NoError(t, c.Do(Cmd(nil, "HMSET", randStr(), "hash", "hash"))) 106 | require.NoError(t, c.Do(Cmd(nil, "SADD", randStr(), "set"))) 107 | require.NoError(t, c.Do(Cmd(nil, "ZADD", randStr(), "1000", "zset"))) 108 | } 109 | 110 | scanType := func(type_ string) { 111 | sc := NewScanner(c, ScanOpts{Command: "SCAN", Type: type_}) 112 | 113 | var key string 114 | for sc.Next(&key) { 115 | var got string 116 | require.NoError(t, c.Do(Cmd(&got, "TYPE", key))) 117 | assert.Equalf(t, type_, got, "key %s has wrong type %q, expected %q", got, type_) 118 | } 119 | require.NoError(t, sc.Close()) 120 | } 121 | 122 | scanType("string") 123 | scanType("list") 124 | scanType("hash") 125 | scanType("set") 126 | scanType("zset") 127 | } 128 | 129 | func BenchmarkScanner(b *B) { 130 | c := dial() 131 | 132 | const total = 10 * 1000 133 | 134 | // Make a random dataset 135 | prefix := randStr() 136 | for i := 0; i < total; i++ { 137 | key := prefix + ":" + 
strconv.Itoa(i) 138 | require.Nil(b, c.Do(Cmd(nil, "SET", key, "1"))) 139 | } 140 | 141 | b.ResetTimer() 142 | 143 | for i := 0; i < b.N; i++ { 144 | // make sure we get all results when scanning with an existing prefix 145 | sc := NewScanner(c, ScanOpts{Command: "SCAN", Pattern: prefix + ":*"}) 146 | var key string 147 | var got int 148 | for sc.Next(&key) { 149 | got++ 150 | } 151 | if got != total { 152 | require.Failf(b, "mismatched between inserted and scanned keys", "expected %d keys, got %d", total, got) 153 | } 154 | } 155 | } 156 | 157 | func ExampleNewScanner_scan() { 158 | client, err := DefaultClientFunc("tcp", "126.0.0.1:6379") 159 | if err != nil { 160 | log.Fatal(err) 161 | } 162 | 163 | s := NewScanner(client, ScanAllKeys) 164 | var key string 165 | for s.Next(&key) { 166 | log.Printf("key: %q", key) 167 | } 168 | if err := s.Close(); err != nil { 169 | log.Fatal(err) 170 | } 171 | } 172 | 173 | func ExampleNewScanner_hscan() { 174 | client, err := DefaultClientFunc("tcp", "126.0.0.1:6379") 175 | if err != nil { 176 | log.Fatal(err) 177 | } 178 | 179 | s := NewScanner(client, ScanOpts{Command: "HSCAN", Key: "somekey"}) 180 | var key string 181 | for s.Next(&key) { 182 | log.Printf("key: %q", key) 183 | } 184 | if err := s.Close(); err != nil { 185 | log.Fatal(err) 186 | } 187 | } 188 | -------------------------------------------------------------------------------- /trace/pool.go: -------------------------------------------------------------------------------- 1 | package trace 2 | 3 | import "time" 4 | 5 | // PoolTrace is passed into radix.NewPool via radix.PoolWithTrace, and contains 6 | // callbacks which will be triggered for specific events during the Pool's 7 | // runtime. 8 | // 9 | // All callbacks are called synchronously. 10 | type PoolTrace struct { 11 | // ConnCreated is called when the Pool creates a new connection. The 12 | // provided Err indicates whether the connection successfully completed. 
13 | ConnCreated func(PoolConnCreated) 14 | 15 | // ConnClosed is called before closing the connection. 16 | ConnClosed func(PoolConnClosed) 17 | 18 | // DoCompleted is called after command execution. Must consider race condition 19 | // for manipulating variables in DoCompleted callback since DoComplete 20 | // function can be called in many go-routines. 21 | DoCompleted func(PoolDoCompleted) 22 | 23 | // InitCompleted is called after pool fills its connections 24 | InitCompleted func(PoolInitCompleted) 25 | } 26 | 27 | // PoolCommon contains information which is passed into all Pool-related 28 | // callbacks. 29 | type PoolCommon struct { 30 | // Network and Addr indicate the network/address the Pool was created with 31 | // (useful for differentiating different redis instances in a Cluster). 32 | Network, Addr string 33 | 34 | // PoolSize and BufferSize indicate the Pool size and buffer size that the 35 | // Pool was initialized with. 36 | PoolSize, BufferSize int 37 | } 38 | 39 | // PoolConnCreatedReason enumerates all the different reasons a connection might 40 | // be created and trigger a ConnCreated trace. 41 | type PoolConnCreatedReason string 42 | 43 | // All possible values of PoolConnCreatedReason. 44 | const ( 45 | // PoolConnCreatedReasonInitialization indicates a connection was being 46 | // created during initialization of the Pool (i.e. within NewPool). 47 | PoolConnCreatedReasonInitialization PoolConnCreatedReason = "initialization" 48 | 49 | // PoolConnCreatedReasonRefill indicates a connection was being created 50 | // during a refill event (see radix.PoolRefillInterval). 51 | PoolConnCreatedReasonRefill PoolConnCreatedReason = "refill" 52 | 53 | // PoolConnCreatedReasonPoolEmpty indicates a connection was being created 54 | // because the Pool was empty and an Action requires one. See the 55 | // radix.PoolOnEmpty options. 
56 | PoolConnCreatedReasonPoolEmpty PoolConnCreatedReason = "pool empty" 57 | ) 58 | 59 | // PoolConnCreated is passed into the PoolTrace.ConnCreated callback whenever 60 | // the Pool creates a new connection. 61 | type PoolConnCreated struct { 62 | PoolCommon 63 | 64 | // The reason the connection was created. 65 | Reason PoolConnCreatedReason 66 | 67 | // How long it took to create the connection. 68 | ConnectTime time.Duration 69 | 70 | // If connection creation failed, this is the error it failed with. 71 | Err error 72 | } 73 | 74 | // PoolConnClosedReason enumerates all the different reasons a connection might 75 | // be closed and trigger a ConnClosed trace. 76 | type PoolConnClosedReason string 77 | 78 | // All possible values of PoolConnClosedReason. 79 | const ( 80 | // PoolConnClosedReasonPoolClosed indicates a connection was closed because 81 | // the Close method was called on Pool. 82 | PoolConnClosedReasonPoolClosed PoolConnClosedReason = "pool closed" 83 | 84 | // PoolConnClosedReasonBufferDrain indicates a connection was closed due to 85 | // a buffer drain event. See radix.PoolOnFullBuffer. 86 | PoolConnClosedReasonBufferDrain PoolConnClosedReason = "buffer drained" 87 | 88 | // PoolConnClosedReasonPoolFull indicates a connection was closed due to 89 | // the Pool already being full. See The radix.PoolOnFullClose options. 90 | PoolConnClosedReasonPoolFull PoolConnClosedReason = "pool full" 91 | 92 | // PoolConnClosedReasonConnExpired indicates a connection was closed because 93 | // the connection was expired. See The radix.PoolMaxLifetime options. 94 | PoolConnClosedReasonConnExpired PoolConnClosedReason = "conn expired" 95 | ) 96 | 97 | // PoolConnClosed is passed into the PoolTrace.ConnClosed callback whenever the 98 | // Pool closes a connection. 
99 | type PoolConnClosed struct { 100 | PoolCommon 101 | 102 | // AvailCount indicates the total number of connections the Pool is holding 103 | // on to which are available for usage at the moment the trace occurs. 104 | AvailCount int 105 | 106 | // The reason the connection was closed. 107 | Reason PoolConnClosedReason 108 | } 109 | 110 | // PoolDoCompleted is passed into the PoolTrace.DoCompleted callback whenever Pool finished to run 111 | // Do function. 112 | type PoolDoCompleted struct { 113 | PoolCommon 114 | 115 | // AvailCount indicates the total number of connections the Pool is holding 116 | // on to which are available for usage at the moment the trace occurs. 117 | AvailCount int 118 | 119 | // How long it took to send command. 120 | ElapsedTime time.Duration 121 | 122 | // This is the error returned from redis. 123 | Err error 124 | } 125 | 126 | // PoolInitCompleted is passed into the PoolTrace.InitCompleted callback whenever Pool initialized. 127 | // This must be called once. 128 | type PoolInitCompleted struct { 129 | PoolCommon 130 | 131 | // AvailCount indicates the total number of connections the Pool is holding 132 | // on to which are available for usage at the moment the trace occurs. 133 | AvailCount int 134 | 135 | // How long it took to fill all connections. 
136 | ElapsedTime time.Duration 137 | } 138 | -------------------------------------------------------------------------------- /stub.go: -------------------------------------------------------------------------------- 1 | package radix 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "net" 7 | "sync" 8 | "time" 9 | 10 | "errors" 11 | 12 | "github.com/mediocregopher/radix/v3/resp" 13 | "github.com/mediocregopher/radix/v3/resp/resp2" 14 | ) 15 | 16 | type bufferAddr struct { 17 | network, addr string 18 | } 19 | 20 | func (sa bufferAddr) Network() string { 21 | return sa.network 22 | } 23 | 24 | func (sa bufferAddr) String() string { 25 | return sa.addr 26 | } 27 | 28 | type buffer struct { 29 | net.Conn // always nil 30 | remoteAddr bufferAddr 31 | 32 | bufL *sync.Cond 33 | buf *bytes.Buffer 34 | bufbr *bufio.Reader 35 | closed bool 36 | readDeadline time.Time 37 | } 38 | 39 | func newBuffer(remoteNetwork, remoteAddr string) *buffer { 40 | buf := new(bytes.Buffer) 41 | return &buffer{ 42 | remoteAddr: bufferAddr{network: remoteNetwork, addr: remoteAddr}, 43 | bufL: sync.NewCond(new(sync.Mutex)), 44 | buf: buf, 45 | bufbr: bufio.NewReader(buf), 46 | } 47 | } 48 | 49 | func (b *buffer) Encode(m resp.Marshaler) error { 50 | b.bufL.L.Lock() 51 | var err error 52 | if b.closed { 53 | err = b.err("write", errClosed) 54 | } else { 55 | err = m.MarshalRESP(b.buf) 56 | } 57 | b.bufL.L.Unlock() 58 | if err != nil { 59 | return err 60 | } 61 | 62 | b.bufL.Broadcast() 63 | return nil 64 | } 65 | 66 | func (b *buffer) Decode(u resp.Unmarshaler) error { 67 | b.bufL.L.Lock() 68 | defer b.bufL.L.Unlock() 69 | 70 | var timeoutCh chan struct{} 71 | if b.readDeadline.IsZero() { 72 | // no readDeadline, timeoutCh will never be written to 73 | } else if now := time.Now(); b.readDeadline.Before(now) { 74 | return b.err("read", new(timeoutError)) 75 | } else { 76 | timeoutCh = make(chan struct{}, 2) 77 | sleep := b.readDeadline.Sub(now) 78 | go func() { 79 | time.Sleep(sleep) 80 | timeoutCh 
<- struct{}{} 81 | b.bufL.Broadcast() 82 | }() 83 | } 84 | 85 | for b.buf.Len() == 0 && b.bufbr.Buffered() == 0 { 86 | if b.closed { 87 | return b.err("read", errClosed) 88 | } 89 | 90 | select { 91 | case <-timeoutCh: 92 | return b.err("read", new(timeoutError)) 93 | default: 94 | } 95 | 96 | // we have to periodically wakeup to double-check the timeoutCh, if 97 | // there is one 98 | if timeoutCh != nil { 99 | go func() { 100 | time.Sleep(1 * time.Second) 101 | b.bufL.Broadcast() 102 | }() 103 | } 104 | 105 | b.bufL.Wait() 106 | } 107 | 108 | return u.UnmarshalRESP(b.bufbr) 109 | } 110 | 111 | func (b *buffer) Close() error { 112 | b.bufL.L.Lock() 113 | defer b.bufL.L.Unlock() 114 | if b.closed { 115 | return b.err("close", errClosed) 116 | } 117 | b.closed = true 118 | b.bufL.Broadcast() 119 | return nil 120 | } 121 | 122 | func (b *buffer) RemoteAddr() net.Addr { 123 | return b.remoteAddr 124 | } 125 | 126 | func (b *buffer) SetDeadline(t time.Time) error { 127 | return b.SetReadDeadline(t) 128 | } 129 | 130 | func (b *buffer) SetReadDeadline(t time.Time) error { 131 | b.bufL.L.Lock() 132 | defer b.bufL.L.Unlock() 133 | if b.closed { 134 | return b.err("set", errClosed) 135 | } 136 | b.readDeadline = t 137 | return nil 138 | } 139 | 140 | func (b *buffer) err(op string, err error) error { 141 | return &net.OpError{ 142 | Op: op, 143 | Net: "tcp", 144 | Source: nil, 145 | Addr: b.remoteAddr, 146 | Err: err, 147 | } 148 | } 149 | 150 | var errClosed = errors.New("use of closed network connection") 151 | 152 | type timeoutError struct{} 153 | 154 | func (e *timeoutError) Error() string { return "i/o timeout" } 155 | func (e *timeoutError) Timeout() bool { return true } 156 | func (e *timeoutError) Temporary() bool { return true } 157 | 158 | //////////////////////////////////////////////////////////////////////////////// 159 | 160 | type stub struct { 161 | *buffer 162 | fn func([]string) interface{} 163 | } 164 | 165 | // Stub returns a (fake) Conn which pretends 
it is a Conn to a real redis 166 | // instance, but is instead using the given callback to service requests. It is 167 | // primarily useful for writing tests. 168 | // 169 | // When Encode is called the given value is marshalled into bytes then 170 | // unmarshalled into a []string, which is passed to the callback. The return 171 | // from the callback is then marshalled and buffered interanlly, and will be 172 | // unmarshalled in the next call to Decode. 173 | // 174 | // remoteNetwork and remoteAddr can be empty, but if given will be used as the 175 | // return from the RemoteAddr method. 176 | // 177 | // If the internal buffer is empty then Decode will block until Encode is called 178 | // in a separate go-routine. The SetDeadline and SetReadDeadline methods can be 179 | // used as usual to limit how long Decode blocks. All other inherited net.Conn 180 | // methods will panic. 181 | func Stub(remoteNetwork, remoteAddr string, fn func([]string) interface{}) Conn { 182 | return &stub{ 183 | buffer: newBuffer(remoteNetwork, remoteAddr), 184 | fn: fn, 185 | } 186 | } 187 | 188 | func (s *stub) Do(a Action) error { 189 | return a.Run(s) 190 | } 191 | 192 | func (s *stub) Encode(m resp.Marshaler) error { 193 | // first marshal into a RawMessage 194 | buf := new(bytes.Buffer) 195 | if err := m.MarshalRESP(buf); err != nil { 196 | return err 197 | } 198 | br := bufio.NewReader(buf) 199 | 200 | var rm resp2.RawMessage 201 | for { 202 | if buf.Len() == 0 && br.Buffered() == 0 { 203 | break 204 | } else if err := rm.UnmarshalRESP(br); err != nil { 205 | return err 206 | } 207 | // unmarshal that into a string slice 208 | var ss []string 209 | if err := rm.UnmarshalInto(resp2.Any{I: &ss}); err != nil { 210 | return err 211 | } 212 | 213 | // get return from callback. Results implementing resp.Marshaler are 214 | // assumed to be wanting to be written in all cases, otherwise if the 215 | // result is an error it is assumed to want to be returned directly. 
216 | ret := s.fn(ss) 217 | if m, ok := ret.(resp.Marshaler); ok { 218 | return s.buffer.Encode(m) 219 | } else if err, _ := ret.(error); err != nil { 220 | return err 221 | } else if err = s.buffer.Encode(resp2.Any{I: ret}); err != nil { 222 | return err 223 | } 224 | } 225 | 226 | return nil 227 | } 228 | 229 | func (s *stub) NetConn() net.Conn { 230 | return s.buffer 231 | } 232 | -------------------------------------------------------------------------------- /pipeliner.go: -------------------------------------------------------------------------------- 1 | package radix 2 | 3 | import ( 4 | "bufio" 5 | "fmt" 6 | "strings" 7 | "sync" 8 | "time" 9 | 10 | "github.com/mediocregopher/radix/v3/resp" 11 | ) 12 | 13 | var blockingCmds = map[string]bool{ 14 | "WAIT": true, 15 | 16 | // taken from https://github.com/joomcode/redispipe#limitations 17 | "BLPOP": true, 18 | "BRPOP": true, 19 | "BRPOPLPUSH": true, 20 | 21 | "BZPOPMIN": true, 22 | "BZPOPMAX": true, 23 | 24 | "XREAD": true, 25 | "XREADGROUP": true, 26 | 27 | "SAVE": true, 28 | } 29 | 30 | type pipeliner struct { 31 | c Client 32 | 33 | limit int 34 | window time.Duration 35 | 36 | // reqsBufCh contains buffers for collecting commands and acts as a semaphore 37 | // to limit the number of concurrent flushes. 
38 | reqsBufCh chan []CmdAction 39 | 40 | reqCh chan *pipelinerCmd 41 | reqWG sync.WaitGroup 42 | 43 | l sync.RWMutex 44 | closed bool 45 | } 46 | 47 | var _ Client = (*pipeliner)(nil) 48 | 49 | func newPipeliner(c Client, concurrency, limit int, window time.Duration) *pipeliner { 50 | if concurrency < 1 { 51 | concurrency = 1 52 | } 53 | 54 | p := &pipeliner{ 55 | c: c, 56 | 57 | limit: limit, 58 | window: window, 59 | 60 | reqsBufCh: make(chan []CmdAction, concurrency), 61 | 62 | reqCh: make(chan *pipelinerCmd, 32), // https://xkcd.com/221/ 63 | } 64 | 65 | p.reqWG.Add(1) 66 | go func() { 67 | defer p.reqWG.Done() 68 | p.reqLoop() 69 | }() 70 | 71 | for i := 0; i < cap(p.reqsBufCh); i++ { 72 | if p.limit > 0 { 73 | p.reqsBufCh <- make([]CmdAction, 0, limit) 74 | } else { 75 | p.reqsBufCh <- nil 76 | } 77 | } 78 | 79 | return p 80 | } 81 | 82 | // CanDo checks if the given Action can be executed / passed to p.Do. 83 | // 84 | // If CanDo returns false, the Action must not be given to Do. 85 | func (p *pipeliner) CanDo(a Action) bool { 86 | // there is currently no way to get the command for CmdAction implementations 87 | // from outside the radix package so we can not multiplex those commands. User 88 | // defined pipelines are not pipelined to let the user better control them. 89 | if cmdA, ok := a.(*cmdAction); ok { 90 | return !blockingCmds[strings.ToUpper(cmdA.cmd)] 91 | } 92 | return false 93 | } 94 | 95 | // Do executes the given Action as part of the pipeline. 96 | // 97 | // If a is not a CmdAction, Do panics. 
98 | func (p *pipeliner) Do(a Action) error { 99 | req := getPipelinerCmd(a.(CmdAction)) // get this outside the lock to avoid 100 | 101 | p.l.RLock() 102 | if p.closed { 103 | p.l.RUnlock() 104 | return errClientClosed 105 | } 106 | p.reqCh <- req 107 | p.l.RUnlock() 108 | 109 | err := <-req.resCh 110 | poolPipelinerCmd(req) 111 | return err 112 | } 113 | 114 | // Close closes the pipeliner and makes sure that all background goroutines 115 | // are stopped before returning. 116 | // 117 | // Close does *not* close the underlying Client. 118 | func (p *pipeliner) Close() error { 119 | p.l.Lock() 120 | defer p.l.Unlock() 121 | 122 | if p.closed { 123 | return nil 124 | } 125 | 126 | close(p.reqCh) 127 | p.reqWG.Wait() 128 | 129 | for i := 0; i < cap(p.reqsBufCh); i++ { 130 | <-p.reqsBufCh 131 | } 132 | 133 | p.c, p.closed = nil, true 134 | return nil 135 | } 136 | 137 | func (p *pipeliner) reqLoop() { 138 | t := getTimer(time.Hour) 139 | defer putTimer(t) 140 | 141 | t.Stop() 142 | 143 | reqs := <-p.reqsBufCh 144 | defer func() { 145 | p.reqsBufCh <- reqs 146 | }() 147 | 148 | for { 149 | select { 150 | case req, ok := <-p.reqCh: 151 | if !ok { 152 | reqs = p.flush(reqs) 153 | return 154 | } 155 | 156 | reqs = append(reqs, req) 157 | 158 | if p.limit > 0 && len(reqs) == p.limit { 159 | // if we reached the pipeline limit, execute now to avoid unnecessary waiting 160 | t.Stop() 161 | 162 | reqs = p.flush(reqs) 163 | } else if len(reqs) == 1 { 164 | t.Reset(p.window) 165 | } 166 | case <-t.C: 167 | reqs = p.flush(reqs) 168 | } 169 | } 170 | } 171 | 172 | func (p *pipeliner) flush(reqs []CmdAction) []CmdAction { 173 | if len(reqs) == 0 { 174 | return reqs 175 | } 176 | 177 | go func() { 178 | defer func() { 179 | p.reqsBufCh <- reqs[:0] 180 | }() 181 | 182 | pp := &pipelinerPipeline{pipeline: pipeline(reqs)} 183 | defer pp.flush() 184 | 185 | if err := p.c.Do(pp); err != nil { 186 | pp.doErr = err 187 | } 188 | }() 189 | 190 | return <-p.reqsBufCh 191 | } 192 | 193 | 
type pipelinerCmd struct { 194 | CmdAction 195 | 196 | resCh chan error 197 | 198 | unmarshalCalled bool 199 | unmarshalErr error 200 | } 201 | 202 | var ( 203 | _ resp.Unmarshaler = (*pipelinerCmd)(nil) 204 | ) 205 | 206 | func (p *pipelinerCmd) sendRes(err error) { 207 | p.resCh <- err 208 | } 209 | 210 | func (p *pipelinerCmd) UnmarshalRESP(br *bufio.Reader) error { 211 | p.unmarshalErr = p.CmdAction.UnmarshalRESP(br) 212 | p.unmarshalCalled = true // important: we set this after unmarshalErr in case the call to UnmarshalRESP panics 213 | return p.unmarshalErr 214 | } 215 | 216 | var pipelinerCmdPool sync.Pool 217 | 218 | func getPipelinerCmd(cmd CmdAction) *pipelinerCmd { 219 | req, _ := pipelinerCmdPool.Get().(*pipelinerCmd) 220 | if req != nil { 221 | *req = pipelinerCmd{ 222 | CmdAction: cmd, 223 | resCh: req.resCh, 224 | } 225 | return req 226 | } 227 | return &pipelinerCmd{ 228 | CmdAction: cmd, 229 | // using a buffer of 1 is faster than no buffer in most cases 230 | resCh: make(chan error, 1), 231 | } 232 | } 233 | 234 | func poolPipelinerCmd(req *pipelinerCmd) { 235 | req.CmdAction = nil 236 | pipelinerCmdPool.Put(req) 237 | } 238 | 239 | type pipelinerPipeline struct { 240 | pipeline 241 | doErr error 242 | } 243 | 244 | func (p *pipelinerPipeline) flush() { 245 | for _, req := range p.pipeline { 246 | var err error 247 | 248 | cmd := req.(*pipelinerCmd) 249 | if cmd.unmarshalCalled { 250 | err = cmd.unmarshalErr 251 | } else { 252 | err = p.doErr 253 | } 254 | cmd.sendRes(err) 255 | } 256 | } 257 | 258 | func (p *pipelinerPipeline) Run(c Conn) (err error) { 259 | defer func() { 260 | if v := recover(); v != nil { 261 | err = fmt.Errorf("%s", v) 262 | } 263 | }() 264 | if err := c.Encode(p); err != nil { 265 | return err 266 | } 267 | errConn := ioErrConn{Conn: c} 268 | for _, req := range p.pipeline { 269 | if _ = errConn.Decode(req); errConn.lastIOErr != nil { 270 | return errConn.lastIOErr 271 | } 272 | } 273 | return nil 274 | } 275 | 
-------------------------------------------------------------------------------- /tls_test.go: -------------------------------------------------------------------------------- 1 | package radix 2 | 3 | import ( 4 | "crypto/tls" 5 | "net" 6 | "strings" 7 | "sync" 8 | "testing" 9 | "time" 10 | 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | func TestDialUseTLS(t *testing.T) { 16 | // In order to test a TLS connection we need to start a TLS terminating proxy 17 | 18 | // Both the key and the certificate were generated by running the following command: 19 | // go run $GOROOT/src/crypto/tls/generate_cert.go --host localhost 20 | 21 | // This function is used to avoid static code analysis from identifying the private key 22 | testingKey := func(s string) string { return strings.Replace(s, "TESTING KEY", "PRIVATE KEY", 2) } 23 | 24 | var rsaCertPEM = `-----BEGIN CERTIFICATE----- 25 | MIIC+TCCAeGgAwIBAgIQJ0gZjEJuKoZtra6oAYs54zANBgkqhkiG9w0BAQsFADAS 26 | MRAwDgYDVQQKEwdBY21lIENvMB4XDTE5MDkxMjE5MzAyN1oXDTIwMDkxMTE5MzAy 27 | N1owEjEQMA4GA1UEChMHQWNtZSBDbzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCC 28 | AQoCggEBAOSDBT4IYzLU1lAbLMU+JmkiZilfkJ+/iUEoSz2jTVyyntY6+r2x2Sfc 29 | HVVKTo08qGsNdgSr09GPBHytWBWXbgH1h4ipnRt8iBtDPFzmqMlK/SVn2fFwlkl4 30 | XqDJSQzeuK2LrbjaiI7TFNJ7mwGDgOIsqi/8am2Te/sQGmZomkR6Pysr92jbZLEm 31 | zxEvv7vQjknNVbRsotincEVtkhT3vAstl1YZPOflsP6J0XtmOXst9WhE96U2Lsh5 32 | cJK1Xi8y1q2u4yScljhrnsURHQKoF0WXyT+5vo1NcZsECscsjiVqqXUdwX91J753 33 | UAM/r75Zc8lMyfU7QPQIafebT4rk1hMCAwEAAaNLMEkwDgYDVR0PAQH/BAQDAgWg 34 | MBMGA1UdJQQMMAoGCCsGAQUFBwMBMAwGA1UdEwEB/wQCMAAwFAYDVR0RBA0wC4IJ 35 | bG9jYWxob3N0MA0GCSqGSIb3DQEBCwUAA4IBAQBZf1tHwem38Cp4tTUyhBV849fz 36 | GVs1FlDLX11PRF9TaAyQf4QKpWDXQV9baQF90krwBDTMjb8f5pVfI1uaEFu3zQZ7 37 | DFNnw628wzGOKPr0fivXaycN3Gt0Qs9UvM5uiI+cNU4tKofd1dkVrnPzJXYaTbAn 38 | lJf4OgVAHa6RtNpZXicARXb+jqKiMOWZH8A3Tj1jQIXgv+orW3ha1R2y2HzZEbnj 39 | NyklAu0YelMXI5nbkptdXBsWVMU/2z/d00AEQRlQoDRXamE0FCURL+J1odzifk80 40 | 
PdMm11Wq+2LeY0h/4SGwP+cmpNMOV5bMvHBohmGxMZMVISyvSuw7JMMcydR4 41 | -----END CERTIFICATE----- 42 | 43 | ` 44 | 45 | var rsaKeyPEM = testingKey(`-----BEGIN TESTING KEY----- 46 | MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDkgwU+CGMy1NZQ 47 | GyzFPiZpImYpX5Cfv4lBKEs9o01csp7WOvq9sdkn3B1VSk6NPKhrDXYEq9PRjwR8 48 | rVgVl24B9YeIqZ0bfIgbQzxc5qjJSv0lZ9nxcJZJeF6gyUkM3riti6242oiO0xTS 49 | e5sBg4DiLKov/Gptk3v7EBpmaJpEej8rK/do22SxJs8RL7+70I5JzVW0bKLYp3BF 50 | bZIU97wLLZdWGTzn5bD+idF7Zjl7LfVoRPelNi7IeXCStV4vMtatruMknJY4a57F 51 | ER0CqBdFl8k/ub6NTXGbBArHLI4laql1HcF/dSe+d1ADP6++WXPJTMn1O0D0CGn3 52 | m0+K5NYTAgMBAAECggEABdx6ePHcIYSmDp3z0wdaEt5IAo2p9v8BtUMkUutqY5NN 53 | Ua9nmRADwur5caObSjIhG8XXnh0OLNTfR5dmp/8fWjuDA3VeS0MxdomN9dAQykD7 54 | J0d3pqK9qBrHSpZ/Ii5gTEtF5HTuhcNSSGfVPP+zgZmlr99ol3DuAC2Uj8XlFxaD 55 | NNhgLsB5v3vPGSqiW7joKRaSa2OGqbXdOz6dTWkS0PWYXGSfIMI4Js4EmU7sqi6w 56 | aVG25XWjiQ57VNH2ZoE0rY3yVbWH5CJ2CD58jIO/LpWfWCXvwGMRHNl6GN0cSv/h 57 | g+BBsTz2VzvoN6/ZvdxccQ63KBDpb0/Ovd2Ri/kjYQKBgQDkmSsUBrOjHWCSZqpg 58 | BIdcFGYBjTFCm4JXrBBZsEHjxBqKXKcCY5DkwW580Yug/Vmn8j2bJgJ3mbYccPqB 59 | WkjrnnjS+lhe9ciNNdA/YqmN0ONlyEvEtb1fOOilZDq+SPrVnaQ8xT/+UCfWnbQI 60 | NNySy4rlMERAcv2G0z9EmqK6AwKBgQD/5zKH1mekV/3i+IjL8KizMwH6ddJwpnMT 61 | 3Iwx11PjPM7tmrlrzRK3rQOqDHVTo9kV37L1rKfC9JXjFhxydj/IOLD6cKUiJJUI 62 | UJVEHb21HNzommnhHyq5AGYtv6hiWPvbMEtws7Y3QmnAbP9UbjZSiTs8hZoHUrWY 63 | uMub1hy+sQKBgDiaxNP8pNarG5Kk4WNNO8dNNcUElUINB8V10cajom0nzfqc3q30 64 | wZgjXZyCtrRyh5TSovab/thms3VvdFg7ZvsRDpIPc3pwGez9ekd3wsxfAS/e3QQk 65 | jHPbv5/UpccggxwKIPT7UtFCP9sgyceOb1/aDtaZkQz0bFrKTExMjibJAoGAE7ID 66 | nZjO2UM8cx+Vx7x5/3DJkjFHRQxKhxjOYXelKTQg6QCjjLx32FMkmQ3kac+OgbR5 67 | 3ZawQrz4XEXzYovfVNWoKV5KF1qhbcZl9pwjYbEa/3wC8iSn8R0qwBKkLw2SNMh+ 68 | xenO+GnQIdNBw4nH/Io7WOkfdbjT6TEv2oqcI8ECgYBIppEhekL3lzN5qNqUqaQS 69 | 64gtm/esLUQLkXFmrv/KZ/QMhOtGNb2Hipc1KOomMTm5zJf2gRMZice97EoBpIiq 70 | /syezw2OV/TjSCLzFrikz8W/lHkpbzwk71s1f0FKMIK863lB4fqj5bCXMXGyiXUt 71 | Baas4jyR6hQ0qRSe4PmQrA== 72 | -----END TESTING KEY----- 73 | `) 74 | pem 
:= []byte(rsaCertPEM + rsaKeyPEM) 75 | cert, err := tls.X509KeyPair(pem, pem) 76 | require.NoError(t, err) 77 | 78 | // The following TLS proxy is based on https://gist.github.com/cs8425/a742349a55596f1b251a#file-tls2tcp_server-go 79 | listener, err := tls.Listen("tcp", ":63790", &tls.Config{ 80 | Certificates: []tls.Certificate{cert}, 81 | }) 82 | require.NoError(t, err) 83 | // Used to prevent a race during shutdown failing the test 84 | m := sync.Mutex{} 85 | shuttingDown := false 86 | defer func() { 87 | m.Lock() 88 | shuttingDown = true 89 | m.Unlock() 90 | listener.Close() 91 | }() 92 | 93 | // Dials 127.0.0.1:6379 and proxies traffic 94 | proxyConnection := func(lConn net.Conn) { 95 | defer lConn.Close() 96 | 97 | rConn, err := net.DialTCP("tcp", nil, &net.TCPAddr{ 98 | IP: net.IPv4(127, 0, 0, 1), 99 | Port: 6379, 100 | }) 101 | require.NoError(t, err) 102 | defer rConn.Close() 103 | 104 | chanFromConn := func(conn net.Conn) chan []byte { 105 | c := make(chan []byte) 106 | 107 | go func() { 108 | b := make([]byte, 1024) 109 | 110 | for { 111 | n, err := conn.Read(b) 112 | if n > 0 { 113 | res := make([]byte, n) 114 | // Copy the buffer so it doesn't get changed while read by the recipient. 
115 | copy(res, b[:n]) 116 | c <- res 117 | } 118 | if err != nil { 119 | c <- nil 120 | break 121 | } 122 | } 123 | }() 124 | 125 | return c 126 | } 127 | 128 | lChan := chanFromConn(lConn) 129 | rChan := chanFromConn(rConn) 130 | 131 | for { 132 | select { 133 | case b1 := <-lChan: 134 | if b1 == nil { 135 | return 136 | } 137 | _, err = rConn.Write(b1) 138 | require.NoError(t, err) 139 | case b2 := <-rChan: 140 | if b2 == nil { 141 | return 142 | } 143 | _, err = lConn.Write(b2) 144 | require.NoError(t, err) 145 | } 146 | } 147 | 148 | } 149 | 150 | // Accept new connections 151 | go func() { 152 | for { 153 | lConn, err := listener.Accept() 154 | if err != nil { 155 | // Accept unblocks and returns an error after Shutdown is called on listener 156 | m.Lock() 157 | defer m.Unlock() 158 | if shuttingDown { 159 | // Exit 160 | break 161 | } else { 162 | require.NoError(t, err) 163 | } 164 | } 165 | go proxyConnection(lConn) 166 | } 167 | }() 168 | 169 | // Connect to the proxy, passing in an insecure flag as we are self-signed 170 | c, err := Dial("tcp", "127.0.0.1:63790", DialUseTLS(&tls.Config{ 171 | InsecureSkipVerify: true, 172 | })) 173 | if err != nil { 174 | t.Fatal(err) 175 | } else if err := c.Do(Cmd(nil, "PING")); err != nil { 176 | t.Fatal(err) 177 | } 178 | 179 | // Confirm that the connection fails if verifying certificate 180 | _, err = Dial("tcp", "127.0.0.1:63790", DialUseTLS(nil), DialConnectTimeout(60*time.Minute)) 181 | assert.Error(t, err) 182 | } 183 | -------------------------------------------------------------------------------- /internal/bytesutil/bytesutil.go: -------------------------------------------------------------------------------- 1 | // Package bytesutil provides utility functions for working with bytes and byte streams that are useful when 2 | // working with the RESP protocol. 
3 | package bytesutil 4 | 5 | import ( 6 | "bufio" 7 | "fmt" 8 | "io" 9 | "strconv" 10 | "sync" 11 | 12 | "errors" 13 | 14 | "github.com/mediocregopher/radix/v3/resp" 15 | ) 16 | 17 | // AnyIntToInt64 converts a value of any of Go's integer types (signed and unsigned) into a signed int64. 18 | // 19 | // If m is not of one of Go's built in integer types the call will panic. 20 | func AnyIntToInt64(m interface{}) int64 { 21 | switch mt := m.(type) { 22 | case int: 23 | return int64(mt) 24 | case int8: 25 | return int64(mt) 26 | case int16: 27 | return int64(mt) 28 | case int32: 29 | return int64(mt) 30 | case int64: 31 | return mt 32 | case uint: 33 | return int64(mt) 34 | case uint8: 35 | return int64(mt) 36 | case uint16: 37 | return int64(mt) 38 | case uint32: 39 | return int64(mt) 40 | case uint64: 41 | return int64(mt) 42 | } 43 | panic(fmt.Sprintf("anyIntToInt64 got bad arg: %#v", m)) 44 | } 45 | 46 | var bytePool = sync.Pool{ 47 | New: func() interface{} { 48 | b := make([]byte, 0, 64) 49 | return &b 50 | }, 51 | } 52 | 53 | // GetBytes returns a non-nil pointer to a byte slice from a pool of byte slices. 54 | // 55 | // The returned byte slice should be put back into the pool using PutBytes after usage. 56 | func GetBytes() *[]byte { 57 | return bytePool.Get().(*[]byte) 58 | } 59 | 60 | // PutBytes puts the given byte slice pointer into a pool that can be accessed via GetBytes. 61 | // 62 | // After calling PutBytes the given pointer and byte slice must not be accessed anymore. 63 | func PutBytes(b *[]byte) { 64 | *b = (*b)[:0] 65 | bytePool.Put(b) 66 | } 67 | 68 | // ParseInt is a specialized version of strconv.ParseInt that parses a base-10 encoded signed integer from a []byte. 69 | // 70 | // This can be used to avoid allocating a string, since strconv.ParseInt only takes a string. 
func ParseInt(b []byte) (int64, error) {
	if len(b) == 0 {
		return 0, errors.New("empty slice given to parseInt")
	}

	// Strip an optional leading sign.
	var neg bool
	if b[0] == '-' || b[0] == '+' {
		neg = b[0] == '-'
		b = b[1:]
	}

	n, err := ParseUint(b)
	if err != nil {
		return 0, err
	}

	// Reject magnitudes that do not fit into an int64; previously such
	// inputs silently wrapped around. The negative range extends one
	// further than the positive one.
	const maxInt64 = 1<<63 - 1
	if neg {
		if n > maxInt64+1 {
			return 0, errors.New("value out of range for int64 in parseInt")
		}
		return -int64(n), nil
	}
	if n > maxInt64 {
		return 0, errors.New("value out of range for int64 in parseInt")
	}
	return int64(n), nil
}

// ParseUint is a specialized version of strconv.ParseUint that parses a base-10 encoded integer from a []byte.
//
// This can be used to avoid allocating a string, since strconv.ParseUint only takes a string.
// Values which overflow an uint64 result in an error rather than wrapping around.
func ParseUint(b []byte) (uint64, error) {
	if len(b) == 0 {
		return 0, errors.New("empty slice given to parseUint")
	}

	var n uint64

	for i, c := range b {
		if c < '0' || c > '9' {
			return 0, fmt.Errorf("invalid character %c at position %d in parseUint", c, i)
		}

		// Detect overflow of both the multiply and the add; the previous
		// implementation silently wrapped around on inputs > MaxUint64.
		if n > (1<<64-1)/10 {
			return 0, errors.New("value out of range for uint64 in parseUint")
		}
		n *= 10
		next := n + uint64(c-'0')
		if next < n {
			return 0, errors.New("value out of range for uint64 in parseUint")
		}
		n = next
	}

	return n, nil
}

// Expand expands the given byte slice to exactly n bytes.
//
// If cap(b) < n, a new slice will be allocated and filled with the bytes from b.
func Expand(b []byte, n int) []byte {
	if cap(b) < n {
		nb := make([]byte, n)
		copy(nb, b)
		return nb
	}
	return b[:n]
}

// BufferedBytesDelim reads a line from br and checks that the line ends with \r\n, returning the line without \r\n.
func BufferedBytesDelim(br *bufio.Reader) ([]byte, error) {
	b, err := br.ReadSlice('\n')
	if err != nil {
		return nil, err
	} else if len(b) < 2 || b[len(b)-2] != '\r' {
		return nil, fmt.Errorf("malformed resp %q", b)
	}
	// Trim the trailing \r\n.
	return b[:len(b)-2], err
}

// BufferedIntDelim reads the current line from br as an integer.
140 | func BufferedIntDelim(br *bufio.Reader) (int64, error) { 141 | b, err := BufferedBytesDelim(br) 142 | if err != nil { 143 | return 0, err 144 | } 145 | return ParseInt(b) 146 | } 147 | 148 | // ReadNAppend appends exactly n bytes from r into b. 149 | func ReadNAppend(r io.Reader, b []byte, n int) ([]byte, error) { 150 | if n == 0 { 151 | return b, nil 152 | } 153 | m := len(b) 154 | b = Expand(b, len(b)+n) 155 | _, err := io.ReadFull(r, b[m:]) 156 | return b, err 157 | } 158 | 159 | // ReadNDiscard discards exactly n bytes from r. 160 | func ReadNDiscard(r io.Reader, n int) error { 161 | type discarder interface { 162 | Discard(int) (int, error) 163 | } 164 | 165 | if n == 0 { 166 | return nil 167 | } 168 | 169 | switch v := r.(type) { 170 | case discarder: 171 | _, err := v.Discard(n) 172 | return err 173 | case io.Seeker: 174 | _, err := v.Seek(int64(n), io.SeekCurrent) 175 | return err 176 | } 177 | 178 | scratch := GetBytes() 179 | defer PutBytes(scratch) 180 | *scratch = (*scratch)[:cap(*scratch)] 181 | if len(*scratch) < n { 182 | *scratch = make([]byte, 8192) 183 | } 184 | 185 | for { 186 | buf := *scratch 187 | if len(buf) > n { 188 | buf = buf[:n] 189 | } 190 | nr, err := r.Read(buf) 191 | n -= nr 192 | if n == 0 || err != nil { 193 | return err 194 | } 195 | } 196 | } 197 | 198 | // ReadInt reads the next n bytes from r as a signed 64 bit integer. 199 | func ReadInt(r io.Reader, n int) (int64, error) { 200 | scratch := GetBytes() 201 | defer PutBytes(scratch) 202 | 203 | var err error 204 | if *scratch, err = ReadNAppend(r, *scratch, n); err != nil { 205 | return 0, err 206 | } 207 | i, err := ParseInt(*scratch) 208 | if err != nil { 209 | return 0, resp.ErrDiscarded{Err: err} 210 | } 211 | return i, nil 212 | } 213 | 214 | // ReadUint reads the next n bytes from r as an unsigned 64 bit integer. 
215 | func ReadUint(r io.Reader, n int) (uint64, error) { 216 | scratch := GetBytes() 217 | defer PutBytes(scratch) 218 | 219 | var err error 220 | if *scratch, err = ReadNAppend(r, *scratch, n); err != nil { 221 | return 0, err 222 | } 223 | ui, err := ParseUint(*scratch) 224 | if err != nil { 225 | return 0, resp.ErrDiscarded{Err: err} 226 | } 227 | return ui, nil 228 | } 229 | 230 | // ReadFloat reads the next n bytes from r as a 64 bit floating point number with the given precision. 231 | func ReadFloat(r io.Reader, precision, n int) (float64, error) { 232 | scratch := GetBytes() 233 | defer PutBytes(scratch) 234 | 235 | var err error 236 | if *scratch, err = ReadNAppend(r, *scratch, n); err != nil { 237 | return 0, err 238 | } 239 | f, err := strconv.ParseFloat(string(*scratch), precision) 240 | if err != nil { 241 | return 0, resp.ErrDiscarded{Err: err} 242 | } 243 | return f, nil 244 | } 245 | -------------------------------------------------------------------------------- /pipeliner_test.go: -------------------------------------------------------------------------------- 1 | package radix 2 | 3 | import ( 4 | "bufio" 5 | "errors" 6 | "io" 7 | "net" 8 | . 
"testing" 9 | "time" 10 | 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | type panicingCmdAction struct { 16 | panicOnMarshal bool 17 | } 18 | 19 | func (p panicingCmdAction) Keys() []string { 20 | return nil 21 | } 22 | 23 | func (p panicingCmdAction) Run(c Conn) error { 24 | return c.Do(p) 25 | } 26 | 27 | func (p panicingCmdAction) MarshalRESP(io.Writer) error { 28 | if p.panicOnMarshal { 29 | panic("MarshalRESP called") 30 | } 31 | return nil 32 | } 33 | 34 | func (p panicingCmdAction) UnmarshalRESP(*bufio.Reader) error { 35 | panic("UnmarshalRESP called") 36 | } 37 | 38 | func TestPipeliner(t *T) { 39 | dialOpts := []DialOpt{DialReadTimeout(time.Second)} 40 | 41 | testMarshalPanic := func(t *T, p *pipeliner) { 42 | key := randStr() 43 | 44 | setCmd := getPipelinerCmd(Cmd(nil, "SET", key, key)) 45 | 46 | var firstGetResult string 47 | firstGetCmd := getPipelinerCmd(Cmd(&firstGetResult, "GET", key)) 48 | 49 | panicingCmd := getPipelinerCmd(&panicingCmdAction{panicOnMarshal: true}) 50 | 51 | var secondGetResult string 52 | secondGetCmd := getPipelinerCmd(Cmd(&secondGetResult, "GET", key)) 53 | 54 | p.flush([]CmdAction{setCmd, firstGetCmd, panicingCmd, secondGetCmd}) 55 | 56 | require.NotNil(t, <-setCmd.resCh) 57 | require.NotNil(t, <-firstGetCmd.resCh) 58 | require.NotNil(t, <-panicingCmd.resCh) 59 | require.NotNil(t, <-secondGetCmd.resCh) 60 | } 61 | 62 | testUnmarshalPanic := func(t *T, p *pipeliner) { 63 | key := randStr() 64 | 65 | setCmd := getPipelinerCmd(Cmd(nil, "SET", key, key)) 66 | 67 | var firstGetResult string 68 | firstGetCmd := getPipelinerCmd(Cmd(&firstGetResult, "GET", key)) 69 | 70 | panicingCmd := getPipelinerCmd(&panicingCmdAction{}) 71 | 72 | var secondGetResult string 73 | secondGetCmd := getPipelinerCmd(Cmd(&secondGetResult, "GET", key)) 74 | 75 | p.flush([]CmdAction{setCmd, firstGetCmd, panicingCmd, secondGetCmd}) 76 | 77 | require.Nil(t, <-setCmd.resCh) 78 | 79 | require.Nil(t, 
<-firstGetCmd.resCh) 80 | require.Equal(t, key, firstGetResult) 81 | 82 | require.NotNil(t, <-panicingCmd.resCh) 83 | 84 | require.NotNil(t, <-secondGetCmd.resCh) 85 | } 86 | 87 | testRecoverableError := func(t *T, p *pipeliner) { 88 | key := randStr() 89 | 90 | setCmd := getPipelinerCmd(Cmd(nil, "SET", key, key)) 91 | 92 | var firstGetResult string 93 | firstGetCmd := getPipelinerCmd(Cmd(&firstGetResult, "GET", key)) 94 | 95 | invalidCmd := getPipelinerCmd(Cmd(nil, "RADIXISAWESOME")) 96 | 97 | var secondGetResult string 98 | secondGetCmd := getPipelinerCmd(Cmd(&secondGetResult, "GET", key)) 99 | 100 | p.flush([]CmdAction{setCmd, firstGetCmd, invalidCmd, secondGetCmd}) 101 | 102 | require.Nil(t, <-setCmd.resCh) 103 | 104 | require.Nil(t, <-firstGetCmd.resCh) 105 | require.Equal(t, key, firstGetResult) 106 | 107 | require.NotNil(t, <-invalidCmd.resCh) 108 | 109 | require.Nil(t, <-secondGetCmd.resCh) 110 | require.Equal(t, key, secondGetResult) 111 | } 112 | 113 | testTimeout := func(t *T, p *pipeliner) { 114 | key := randStr() 115 | 116 | delCmd := getPipelinerCmd(Cmd(nil, "DEL", key)) 117 | pushCmd := getPipelinerCmd(Cmd(nil, "LPUSH", key, "3", "2", "1")) 118 | p.flush([]CmdAction{delCmd, pushCmd}) 119 | require.Nil(t, <-delCmd.resCh) 120 | require.Nil(t, <-pushCmd.resCh) 121 | 122 | var firstPopResult string 123 | firstPopCmd := getPipelinerCmd(Cmd(&firstPopResult, "LPOP", key)) 124 | 125 | var pauseResult string 126 | pauseCmd := getPipelinerCmd(Cmd(&pauseResult, "CLIENT", "PAUSE", "1100")) 127 | 128 | var secondPopResult string 129 | secondPopCmd := getPipelinerCmd(Cmd(&secondPopResult, "LPOP", key)) 130 | 131 | var thirdPopResult string 132 | thirdPopCmd := getPipelinerCmd(Cmd(&thirdPopResult, "LPOP", key)) 133 | 134 | p.flush([]CmdAction{firstPopCmd, pauseCmd, secondPopCmd, thirdPopCmd}) 135 | 136 | require.Nil(t, <-firstPopCmd.resCh) 137 | require.Equal(t, "1", firstPopResult) 138 | 139 | require.Nil(t, <-pauseCmd.resCh) 140 | require.Equal(t, "OK", 
pauseResult) 141 | 142 | secondPopErr := <-secondPopCmd.resCh 143 | require.IsType(t, (*net.OpError)(nil), secondPopErr) 144 | 145 | var secondPopNetErr net.Error 146 | assert.True(t, errors.As(secondPopErr, &secondPopNetErr)) 147 | 148 | require.True(t, secondPopNetErr.Timeout()) 149 | assert.Empty(t, secondPopResult) 150 | 151 | thirdPopErr := <-thirdPopCmd.resCh 152 | 153 | var thirdPopNetErr *net.OpError 154 | assert.True(t, errors.As(thirdPopErr, &thirdPopNetErr)) 155 | 156 | require.True(t, thirdPopNetErr.Timeout()) 157 | assert.Empty(t, thirdPopResult) 158 | } 159 | 160 | t.Run("Conn", func(t *T) { 161 | t.Run("MarshalPanic", func(t *T) { 162 | conn := dial(dialOpts...) 163 | defer conn.Close() 164 | 165 | p := newPipeliner(conn, 0, 0, 0) 166 | defer p.Close() 167 | 168 | testMarshalPanic(t, p) 169 | }) 170 | 171 | t.Run("UnmarshalPanic", func(t *T) { 172 | conn := dial(dialOpts...) 173 | defer conn.Close() 174 | 175 | p := newPipeliner(conn, 0, 0, 0) 176 | defer p.Close() 177 | 178 | testUnmarshalPanic(t, p) 179 | }) 180 | 181 | t.Run("RecoverableError", func(t *T) { 182 | conn := dial(dialOpts...) 183 | defer conn.Close() 184 | 185 | p := newPipeliner(conn, 0, 0, 0) 186 | defer p.Close() 187 | 188 | testRecoverableError(t, p) 189 | }) 190 | 191 | t.Run("Timeout", func(t *T) { 192 | conn := dial(dialOpts...) 193 | defer conn.Close() 194 | 195 | p := newPipeliner(conn, 0, 0, 0) 196 | defer p.Close() 197 | 198 | testTimeout(t, p) 199 | }) 200 | }) 201 | 202 | // Pool may have potentially different semantics because it uses ioErrConn 203 | // directly, so we test it separately. 204 | t.Run("Pool", func(t *T) { 205 | poolOpts := []PoolOpt{ 206 | PoolConnFunc(func(string, string) (Conn, error) { 207 | return dial(dialOpts...), nil 208 | }), 209 | PoolPipelineConcurrency(1), 210 | PoolPipelineWindow(time.Hour, 0), 211 | } 212 | 213 | t.Run("MarshalPanic", func(t *T) { 214 | pool := testPool(1, poolOpts...) 
215 | defer pool.Close() 216 | 217 | testMarshalPanic(t, pool.pipeliner) 218 | }) 219 | 220 | t.Run("UnmarshalPanic", func(t *T) { 221 | pool := testPool(1, poolOpts...) 222 | defer pool.Close() 223 | 224 | testUnmarshalPanic(t, pool.pipeliner) 225 | }) 226 | 227 | t.Run("RecoverableError", func(t *T) { 228 | pool := testPool(1, poolOpts...) 229 | defer pool.Close() 230 | 231 | testRecoverableError(t, pool.pipeliner) 232 | }) 233 | 234 | t.Run("Timeout", func(t *T) { 235 | pool := testPool(1, poolOpts...) 236 | defer pool.Close() 237 | 238 | testTimeout(t, pool.pipeliner) 239 | }) 240 | }) 241 | } 242 | -------------------------------------------------------------------------------- /cluster_topo.go: -------------------------------------------------------------------------------- 1 | package radix 2 | 3 | import ( 4 | "bufio" 5 | "fmt" 6 | "io" 7 | "net" 8 | "sort" 9 | "strconv" 10 | 11 | "github.com/mediocregopher/radix/v3/resp" 12 | "github.com/mediocregopher/radix/v3/resp/resp2" 13 | ) 14 | 15 | // ClusterNode describes a single node in the cluster at a moment in time. 16 | type ClusterNode struct { 17 | // older versions of redis might not actually send back the id, so it may be 18 | // blank 19 | Addr, ID string 20 | // start is inclusive, end is exclusive 21 | Slots [][2]uint16 22 | // address and id this node is the secondary of, if it's a secondary 23 | SecondaryOfAddr, SecondaryOfID string 24 | } 25 | 26 | // ClusterTopo describes the cluster topology at a given moment. It will be 27 | // sorted first by slot number of each node and then by secondary status, so 28 | // primaries will come before secondaries. 29 | type ClusterTopo []ClusterNode 30 | 31 | // MarshalRESP implements the resp.Marshaler interface, and will marshal the 32 | // ClusterTopo in the same format as the return from CLUSTER SLOTS. 
func (tt ClusterTopo) MarshalRESP(w io.Writer) error {
	// Group the nodes by slot range so each range is emitted once together
	// with every node serving it, mirroring the CLUSTER SLOTS layout.
	m := map[[2]uint16]topoSlotSet{}
	for _, t := range tt {
		for _, slots := range t.Slots {
			tss := m[slots]
			tss.slots = slots
			tss.nodes = append(tss.nodes, t)
			m[slots] = tss
		}
	}

	// we sort the topoSlotSets by their slot number so that the order is
	// deterministic, mostly so tests pass consistently, I'm not sure if actual
	// redis has any contract on the order
	allTSS := make([]topoSlotSet, 0, len(m))
	for _, tss := range m {
		allTSS = append(allTSS, tss)
	}
	sort.Slice(allTSS, func(i, j int) bool {
		return allTSS[i].slots[0] < allTSS[j].slots[0]
	})

	// Top-level RESP array: one element per slot set.
	if err := (resp2.ArrayHeader{N: len(allTSS)}).MarshalRESP(w); err != nil {
		return err
	}
	for _, tss := range allTSS {
		if err := tss.MarshalRESP(w); err != nil {
			return err
		}
	}
	return nil
}

// UnmarshalRESP implements the resp.Unmarshaler interface, but only supports
// unmarshaling the return from CLUSTER SLOTS. The unmarshaled nodes will be
// sorted before they are returned.
func (tt *ClusterTopo) UnmarshalRESP(br *bufio.Reader) error {
	var arrHead resp2.ArrayHeader
	if err := arrHead.UnmarshalRESP(br); err != nil {
		return err
	}
	slotSets := make([]topoSlotSet, arrHead.N)
	for i := range slotSets {
		if err := (&(slotSets[i])).UnmarshalRESP(br); err != nil {
			return err
		}
	}

	// The same node may appear in multiple slot sets; merge its slot ranges so
	// each address contributes exactly one ClusterNode.
	nodeAddrM := map[string]ClusterNode{}
	for _, tss := range slotSets {
		for _, n := range tss.nodes {
			if existingN, ok := nodeAddrM[n.Addr]; ok {
				existingN.Slots = append(existingN.Slots, n.Slots...)
				nodeAddrM[n.Addr] = existingN
			} else {
				nodeAddrM[n.Addr] = n
			}
		}
	}

	for _, n := range nodeAddrM {
		*tt = append(*tt, n)
	}
	tt.sort()
	return nil
}

// sort orders the topology in-place: each node's slot ranges are sorted by
// their starting slot, then the nodes themselves are ordered by first slot
// range with primaries placed before secondaries.
func (tt ClusterTopo) sort() {
	// first go through each node and make sure the individual slot sets are
	// sorted
	for _, node := range tt {
		sort.Slice(node.Slots, func(i, j int) bool {
			return node.Slots[i][0] < node.Slots[j][0]
		})
	}

	sort.Slice(tt, func(i, j int) bool {
		if tt[i].Slots[0] != tt[j].Slots[0] {
			return tt[i].Slots[0][0] < tt[j].Slots[0][0]
		}
		// we want secondaries to come after, which actually means they should
		// be sorted as greater
		return tt[i].SecondaryOfAddr == ""
	})

}

// Map returns the topology as a mapping of node address to its ClusterNode.
func (tt ClusterTopo) Map() map[string]ClusterNode {
	m := make(map[string]ClusterNode, len(tt))
	for _, t := range tt {
		m[t.Addr] = t
	}
	return m
}

// Primaries returns a ClusterTopo instance containing only the primary nodes
// from the ClusterTopo being called on.
func (tt ClusterTopo) Primaries() ClusterTopo {
	mtt := make(ClusterTopo, 0, len(tt))
	for _, node := range tt {
		if node.SecondaryOfAddr == "" {
			mtt = append(mtt, node)
		}
	}
	return mtt
}

// we only use this type during unmarshalling, the topo Unmarshal method will
// convert these into ClusterNodes.
// topoSlotSet is a single slot range together with the nodes serving it
// (primary first, then secondaries).
type topoSlotSet struct {
	slots [2]uint16
	nodes []ClusterNode
}

func (tss topoSlotSet) MarshalRESP(w io.Writer) error {
	var err error
	// marshal latches the first error and turns later calls into no-ops, so
	// the sequence of writes below can read linearly.
	marshal := func(m resp.Marshaler) {
		if err == nil {
			err = m.MarshalRESP(w)
		}
	}

	marshal(resp2.ArrayHeader{N: 2 + len(tss.nodes)})
	marshal(resp2.Any{I: tss.slots[0]})
	// slots stores an exclusive end, but CLUSTER SLOTS uses an inclusive end,
	// so subtract one on the way out.
	marshal(resp2.Any{I: tss.slots[1] - 1})

	for _, n := range tss.nodes {

		// NOTE(review): the SplitHostPort error is deliberately dropped; a
		// malformed Addr still surfaces below when Atoi fails on the port.
		host, portStr, _ := net.SplitHostPort(n.Addr)

		port, err := strconv.Atoi(portStr)
		if err != nil {
			return err
		}

		node := []interface{}{host, port}
		if n.ID != "" {
			// older redis instances don't have an ID, so it's omitted
			node = append(node, n.ID)
		}
		marshal(resp2.Any{I: node})
	}

	return err
}

func (tss *topoSlotSet) UnmarshalRESP(br *bufio.Reader) error {
	var arrHead resp2.ArrayHeader
	if err := arrHead.UnmarshalRESP(br); err != nil {
		return err
	}

	// first two array elements are the slot numbers. We increment the second to
	// preserve inclusive start/exclusive end, which redis doesn't
	for i := range tss.slots {
		if err := (resp2.Any{I: &tss.slots[i]}).UnmarshalRESP(br); err != nil {
			return err
		}
	}
	tss.slots[1]++
	arrHead.N -= len(tss.slots)

	// The remaining elements each describe one node serving this slot range;
	// the first is the primary, the rest are its secondaries.
	var primaryNode ClusterNode
	for i := 0; i < arrHead.N; i++ {

		var nodeArrHead resp2.ArrayHeader
		if err := nodeArrHead.UnmarshalRESP(br); err != nil {
			return err
		} else if nodeArrHead.N < 2 {
			return fmt.Errorf("expected at least 2 array elements, got %d", nodeArrHead.N)
		}

		var ip resp2.BulkString
		if err := ip.UnmarshalRESP(br); err != nil {
			return err
		}

		var port resp2.Int
		if err := port.UnmarshalRESP(br); err != nil {
			return err
		}

		nodeArrHead.N -= 2

		// the node ID is optional; older redis versions don't send it
		var id resp2.BulkString
		if nodeArrHead.N > 0 {
			if err := id.UnmarshalRESP(br); err != nil {
				return err
			}
			nodeArrHead.N--
		}

		// discard anything after (e.g. the extra metadata newer redis
		// versions append to each node entry)
		for i := 0; i < nodeArrHead.N; i++ {
			if err := (resp2.Any{}).UnmarshalRESP(br); err != nil {
				return err
			}
		}

		node := ClusterNode{
			Addr:  net.JoinHostPort(ip.S, strconv.FormatInt(port.I, 10)),
			ID:    id.S,
			Slots: [][2]uint16{tss.slots},
		}

		if i == 0 {
			primaryNode = node
		} else {
			node.SecondaryOfAddr = primaryNode.Addr
			node.SecondaryOfID = primaryNode.ID
		}

		tss.nodes = append(tss.nodes, node)
	}

	return nil
}
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
Changelog from v3.0.1 and up. Prior changes don't have a changelog.
2 | 3 | # v3.8.1 4 | 5 | * Fixed `NewCluster` not returning an error if it can't connect to any of the 6 | redis instances given. (#319) 7 | 8 | * Fix deadlock in `Cluster` when using `DoSecondary`. (#317) 9 | 10 | * Fix parsing for `CLUSTER SLOTS` command, which changed slightly with redis 11 | 7.0. (#322) 12 | 13 | # v3.8.0 14 | 15 | **New** 16 | 17 | * Add `PoolMaxLifetime` option for `Pool`. (PR #294) 18 | 19 | **Fixes And Improvements** 20 | 21 | * Switched to using `errors` package, rather than `golang.org/x/xerrors`. (PR 22 | #300) 23 | 24 | * Switch to using Github Actions from travis. (PR #300) 25 | 26 | * Fixed IPv6 addresses breaking `Cluster`. (Issue #288) 27 | 28 | # v3.7.1 29 | 30 | * Release the RLock in `Sentinel`'s `Do`. (PR #272) 31 | 32 | # v3.7.0 33 | 34 | **New** 35 | 36 | * Add `FallbackToUndelivered` option to `StreamReaderOpts`. (PR #244) 37 | 38 | * Add `ClusterOnInitAllowUnavailable`. (PR #247) 39 | 40 | **Fixes and Improvements** 41 | 42 | * Fix reading a RESP error into a `*interface{}` panicking. (PR #240) 43 | 44 | # v3.6.0 45 | 46 | **New** 47 | 48 | * Add `Tuple` type, which makes unmarshaling `EXEC` and `EVAL` results easier. 49 | 50 | * Add `PersistentPubSubErrCh`, so that asynchronous errors within 51 | `PersistentPubSub` can be exposed to the user. 52 | 53 | * Add `FlatCmd` method to `EvalScript`. 54 | 55 | * Add `StreamEntries` unmarshaler to make unmarshaling `XREAD` and `XREADGROUP` 56 | results easier. 57 | 58 | **Fixes and Improvements** 59 | 60 | * Fix wrapped errors not being handled correctly by `Cluster`. (PR #229) 61 | 62 | * Fix `PersistentPubSub` deadlocking when a method was called after `Close`. 63 | (PR #230) 64 | 65 | * Fix `StreamReader` not correctly handling the case of reading from multiple 66 | streams when one is empty. (PR #224) 67 | 68 | # v3.5.2 69 | 70 | * Improve docs for `WithConn` and `PubSubConn`. 
71 | 72 | * Fix `PubSubConn`'s `Subscribe` and `PSubscribe` methods potentially mutating 73 | the passed in array of strings. (Issue #217) 74 | 75 | * Fix `StreamEntry` not properly handling unmarshaling an entry with a nil 76 | fields array. (PR #218) 77 | 78 | # v3.5.1 79 | 80 | * Add `EmptyArray` field to `MaybeNil`. (PR #211) 81 | 82 | * Fix `Cluster` not properly re-initializing itself when the cluster goes 83 | completely down. (PR #209) 84 | 85 | # v3.5.0 86 | 87 | Huge thank you to @nussjustin for all the work he's been doing on this project, 88 | this release is almost entirely his doing. 89 | 90 | **New** 91 | 92 | * Add support for `TYPE` option to `Scanner`. (PR #187) 93 | 94 | * Add `Sentinel.DoSecondary` method. (PR #197) 95 | 96 | * Add `DialAuthUser`, to support username+password authentication. (PR #195) 97 | 98 | * Add `Cluster.DoSecondary` method. (PR #198) 99 | 100 | **Fixes and Improvements** 101 | 102 | * Fix pipeline behavior when a decode error is encountered. (PR #180) 103 | 104 | * Fix `Reason` in `PoolConnClosed` in the case of the Pool being full. (PR #186) 105 | 106 | * Refactor `PersistentPubSub` to be cleaner, fixing a panic in the process. 107 | (PR #185, Issue #184) 108 | 109 | * Fix marshaling of nil pointers in structs. (PR #192) 110 | 111 | * Wrap errors which get returned from pipeline decoding. (PR #191) 112 | 113 | * Simplify and improve pipeline error handling. (PR #190) 114 | 115 | * Dodge a `[]byte` allocation when in `StreamReader.Next`. (PR #196) 116 | 117 | * Remove excess lock in Pool. 
(PR #202) 118 | 119 | 120 | # v3.4.2 121 | 122 | * Fix alignment for atomic values in structs (PR #171) 123 | 124 | * Fix closing of sentinel instances while updating state (PR #173) 125 | 126 | # v3.4.1 127 | 128 | * Update xerrors package (PR #165) 129 | 130 | * Have cluster Pools be closed outside of lock, to reduce contention during 131 | failover events (PR #168) 132 | 133 | # v3.4.0 134 | 135 | * Add `PersistentPubSubWithOpts` function, deprecating the old 136 | `PersistentPubSub` function. (PR #156) 137 | 138 | * Make decode errors a bit more helpful. (PR #157) 139 | 140 | * Refactor Pool to rely on its inner lock less, simplifying the code quite a bit 141 | and hopefully speeding up certain actions. (PR #160) 142 | 143 | * Various documentation updates. (PR #138, Issue #162) 144 | 145 | # v3.3.2 146 | 147 | * Have `resp2.Error` match with a `resp.ErrDiscarded` when using `errors.As`. 148 | Fixes EVAL, among probably other problems. (PR #152) 149 | 150 | # v3.3.1 151 | 152 | * Use `xerrors` internally. (PR #113) 153 | 154 | * Handle unmarshal errors better. Previously an unmarshaling error could leave 155 | the connection in an inconsistent state, because the full message wouldn't get 156 | completely read off the wire. After a lot of work, this has been fixed. (PR 157 | #127, #139, #145) 158 | 159 | * Handle CLUSTERDOWN errors better. Upon seeing a CLUSTERDOWN, all commands will 160 | be delayed by a small amount of time. The delay will be stopped as soon as the 161 | first non-CLUSTERDOWN result is seen from the Cluster. The idea is that, if a 162 | failover happens, commands which are incoming will be paused long enough for 163 | the cluster to regain it sanity, thus minimizing the number of failed commands 164 | during the failover. (PR #137) 165 | 166 | * Fix cluster redirect tracing. (PR #142) 167 | 168 | # v3.3.0 169 | 170 | **New** 171 | 172 | * Add `trace` package with tracing callbacks for `Pool` and `Cluster`. 173 | (`Sentinel` coming soon!) 
(PR #100, PR #108, PR #111) 174 | 175 | * Add `SentinelAddrs` method to `Sentinel` (PR #118) 176 | 177 | * Add `DialUseTLS` option. (PR #104) 178 | 179 | **Fixes and Improvements** 180 | 181 | * Fix `NewSentinel` not handling URL AUTH parameters correctly (PR #120) 182 | 183 | * Change `DefaultClientFunc`'s pool size from 20 to 4, on account of pipelining 184 | being enabled by default. (Issue #107) 185 | 186 | * Reuse `reflect.Value` instances when unmarshaling into certain map types. (PR 187 | #96). 188 | 189 | * Fix a panic in `FlatCmd`. (PR #97) 190 | 191 | * Reuse field name `string` when unmarshaling into a struct. (PR #95) 192 | 193 | * Reduce PubSub allocations significantly. (PR #92 + Issue #91) 194 | 195 | * Reduce allocations in `Conn`. (PR #84) 196 | 197 | # v3.2.3 198 | 199 | * Optimize Scanner implementation. 200 | 201 | * Fix bug with using types which implement resp.LenReader, encoding.TextMarshaler, and encoding.BinaryMarshaler. The encoder wasn't properly taking into account the interfaces when counting the number of elements in the message. 202 | 203 | # v3.2.2 204 | 205 | * Give Pool an ErrCh so that errors which happen internally may be reported to 206 | the user, if they care. 207 | 208 | * Fix `PubSubConn`'s deadlock problems during Unsubscribe commands. 209 | 210 | * Small speed optimizations in network protocol code. 211 | 212 | # v3.2.1 213 | 214 | * Move benchmarks to a submodule in order to clean up `go.mod` a bit. 215 | 216 | # v3.2.0 217 | 218 | * Add `StreamReader` type to make working with redis' new [Stream][stream] 219 | functionality easier. 220 | 221 | * Make `Sentinel` properly respond to `Client` method calls. Previously it 222 | always created a new `Client` instance when a secondary was requested, now it 223 | keeps track of instances internally. 224 | 225 | * Make default `Dial` call have a timeout for connect/read/write. At the same 226 | time, normalize default timeout values across the project. 
227 | 228 | * Implicitly pipeline commands in the default Pool implementation whenever 229 | possible. This gives a throughput increase of nearly 5x for a normal parallel 230 | workload. 231 | 232 | [stream]: https://redis.io/topics/streams-intro 233 | 234 | # v3.1.0 235 | 236 | * Add support for marshaling/unmarshaling structs. 237 | 238 | # v3.0.1 239 | 240 | * Make `Stub` support `Pipeline` properly. 241 | -------------------------------------------------------------------------------- /bench/bench_test.go: -------------------------------------------------------------------------------- 1 | package bench 2 | 3 | import ( 4 | "context" 5 | "runtime" 6 | "strings" 7 | . "testing" 8 | "time" 9 | 10 | "errors" 11 | 12 | redigo "github.com/gomodule/redigo/redis" 13 | redispipe "github.com/joomcode/redispipe/redis" 14 | redispipeconn "github.com/joomcode/redispipe/redisconn" 15 | "github.com/mediocregopher/radix/v3" 16 | ) 17 | 18 | func newRedigo() redigo.Conn { 19 | c, err := redigo.Dial("tcp", "127.0.0.1:6379") 20 | if err != nil { 21 | panic(err) 22 | } 23 | return c 24 | } 25 | 26 | func newRedisPipe(writePause time.Duration) redispipe.Sync { 27 | pipe, err := redispipeconn.Connect(context.Background(), "127.0.0.1:6379", redispipeconn.Opts{ 28 | Logger: redispipeconn.NoopLogger{}, 29 | WritePause: writePause, 30 | }) 31 | if err != nil { 32 | panic(err) 33 | } 34 | return redispipe.Sync{S: pipe} 35 | } 36 | 37 | func radixGetSet(client radix.Client, key, val string) error { 38 | if err := client.Do(radix.Cmd(nil, "SET", key, val)); err != nil { 39 | return err 40 | } 41 | var out string 42 | if err := client.Do(radix.Cmd(&out, "GET", key)); err != nil { 43 | return err 44 | } else if out != val { 45 | return errors.New("got wrong value") 46 | } 47 | return nil 48 | } 49 | 50 | func BenchmarkSerialGetSet(b *B) { 51 | b.Run("radix", func(b *B) { 52 | rad, err := radix.Dial("tcp", "127.0.0.1:6379") 53 | if err != nil { 54 | b.Fatal(err) 55 | } 56 | defer rad.Close() 
57 | // avoid overhead of converting from radix.Conn to radix.Client on each loop iteration 58 | client := radix.Client(rad) 59 | b.ResetTimer() 60 | for i := 0; i < b.N; i++ { 61 | if err := radixGetSet(client, "foo", "bar"); err != nil { 62 | b.Fatal(err) 63 | } 64 | } 65 | }) 66 | 67 | b.Run("redigo", func(b *B) { 68 | red := newRedigo() 69 | b.ResetTimer() 70 | for i := 0; i < b.N; i++ { 71 | if _, err := red.Do("SET", "foo", "bar"); err != nil { 72 | b.Fatal(err) 73 | } 74 | if _, err := redigo.String(red.Do("GET", "foo")); err != nil { 75 | b.Fatal(err) 76 | } 77 | } 78 | }) 79 | 80 | b.Run("redispipe", func(b *B) { 81 | sync := newRedisPipe(150 * time.Microsecond) 82 | defer sync.S.Close() 83 | b.ResetTimer() 84 | for i := 0; i < b.N; i++ { 85 | if res := sync.Do("SET", "foo", "bar"); redispipe.AsError(res) != nil { 86 | b.Fatal(res) 87 | } else if res := sync.Do("GET", "foo"); redispipe.AsError(res) != nil { 88 | b.Fatal(res) 89 | } 90 | } 91 | }) 92 | 93 | b.Run("redispipe_pause0", func(b *B) { 94 | sync := newRedisPipe(-1) 95 | defer sync.S.Close() 96 | b.ResetTimer() 97 | for i := 0; i < b.N; i++ { 98 | if res := sync.Do("SET", "foo", "bar"); redispipe.AsError(res) != nil { 99 | b.Fatal(res) 100 | } 101 | if res := sync.Do("GET", "foo"); redispipe.AsError(res) != nil { 102 | b.Fatal(res) 103 | } 104 | } 105 | }) 106 | } 107 | 108 | func BenchmarkSerialGetSetLargeArgs(b *B) { 109 | key := strings.Repeat("foo", 24) 110 | val := strings.Repeat("bar", 4096) 111 | 112 | b.Run("radix", func(b *B) { 113 | rad, err := radix.Dial("tcp", "127.0.0.1:6379") 114 | if err != nil { 115 | b.Fatal(err) 116 | } 117 | defer rad.Close() 118 | // avoid overhead of converting from radix.Conn to radix.Client on each loop iteration 119 | client := radix.Client(rad) 120 | b.ResetTimer() 121 | for i := 0; i < b.N; i++ { 122 | if err := radixGetSet(client, key, val); err != nil { 123 | b.Fatal(err) 124 | } 125 | } 126 | }) 127 | 128 | b.Run("redigo", func(b *B) { 129 | red := 
newRedigo() 130 | b.ResetTimer() 131 | for i := 0; i < b.N; i++ { 132 | if _, err := red.Do("SET", key, val); err != nil { 133 | b.Fatal(err) 134 | } 135 | if _, err := redigo.String(red.Do("GET", key)); err != nil { 136 | b.Fatal(err) 137 | } 138 | } 139 | }) 140 | 141 | b.Run("redispipe", func(b *B) { 142 | sync := newRedisPipe(150 * time.Microsecond) 143 | defer sync.S.Close() 144 | b.ResetTimer() 145 | for i := 0; i < b.N; i++ { 146 | if res := sync.Do("SET", key, val); redispipe.AsError(res) != nil { 147 | b.Fatal(res) 148 | } 149 | if res := sync.Do("GET", key); redispipe.AsError(res) != nil { 150 | b.Fatal(res) 151 | } 152 | } 153 | }) 154 | 155 | b.Run("redispipe_pause0", func(b *B) { 156 | sync := newRedisPipe(-1) 157 | defer sync.S.Close() 158 | b.ResetTimer() 159 | for i := 0; i < b.N; i++ { 160 | if res := sync.Do("SET", key, val); redispipe.AsError(res) != nil { 161 | b.Fatal(res) 162 | } 163 | if res := sync.Do("GET", key); redispipe.AsError(res) != nil { 164 | b.Fatal(res) 165 | } 166 | } 167 | }) 168 | } 169 | 170 | func BenchmarkParallelGetSet(b *B) { 171 | // parallel defines a multiplicand used for determining the number of goroutines 172 | // for running benchmarks. this value will be multiplied by GOMAXPROCS inside RunParallel. 173 | // since these benchmarks are mostly I/O bound and applications tend to have more 174 | // active goroutines accessing Redis than cores, especially with higher core numbers, 175 | // we set this to GOMAXPROCS so that we get GOMAXPROCS^2 connections. 176 | parallel := runtime.GOMAXPROCS(0) 177 | 178 | // multiply parallel with GOMAXPROCS to get the actual number of goroutines and thus 179 | // connections needed for the benchmarks. 
180 | poolSize := parallel * runtime.GOMAXPROCS(0) 181 | 182 | do := func(b *B, fn func() error) { 183 | b.ResetTimer() 184 | b.SetParallelism(parallel) 185 | b.RunParallel(func(pb *PB) { 186 | for pb.Next() { 187 | if err := fn(); err != nil { 188 | b.Fatal(err) 189 | } 190 | } 191 | }) 192 | } 193 | 194 | b.Run("radix", func(b *B) { 195 | mkRadixBench := func(opts ...radix.PoolOpt) func(b *B) { 196 | return func(b *B) { 197 | pool, err := radix.NewPool("tcp", "127.0.0.1:6379", poolSize, opts...) 198 | if err != nil { 199 | b.Fatal(err) 200 | } 201 | defer pool.Close() 202 | 203 | // wait for the pool to fill up 204 | for { 205 | time.Sleep(50 * time.Millisecond) 206 | if pool.NumAvailConns() >= poolSize { 207 | break 208 | } 209 | } 210 | 211 | // avoid overhead of boxing the pool on each loop iteration 212 | client := radix.Client(pool) 213 | b.ResetTimer() 214 | do(b, func() error { 215 | return radixGetSet(client, "foo", "bar") 216 | }) 217 | } 218 | } 219 | 220 | b.Run("no pipeline", mkRadixBench(radix.PoolPipelineWindow(0, 0))) 221 | b.Run("one pipeline", mkRadixBench(radix.PoolPipelineConcurrency(1))) 222 | b.Run("default", mkRadixBench()) 223 | }) 224 | 225 | b.Run("redigo", func(b *B) { 226 | red := &redigo.Pool{MaxIdle: poolSize, Dial: func() (redigo.Conn, error) { 227 | return newRedigo(), nil 228 | }} 229 | defer red.Close() 230 | 231 | { // make sure the pool is full 232 | var conns []redigo.Conn 233 | for red.MaxIdle > red.ActiveCount() { 234 | conns = append(conns, red.Get()) 235 | } 236 | for _, conn := range conns { 237 | _ = conn.Close() 238 | } 239 | } 240 | 241 | do(b, func() error { 242 | conn := red.Get() 243 | if _, err := conn.Do("SET", "foo", "bar"); err != nil { 244 | conn.Close() 245 | return err 246 | } 247 | if out, err := redigo.String(conn.Do("GET", "foo")); err != nil { 248 | conn.Close() 249 | return err 250 | } else if out != "bar" { 251 | conn.Close() 252 | return errors.New("got wrong value") 253 | } 254 | return conn.Close() 
255 | }) 256 | }) 257 | 258 | b.Run("redispipe", func(b *B) { 259 | sync := newRedisPipe(150 * time.Microsecond) 260 | defer sync.S.Close() 261 | do(b, func() error { 262 | if res := sync.Do("SET", "foo", "bar"); redispipe.AsError(res) != nil { 263 | return redispipe.AsError(res) 264 | } else if res := sync.Do("GET", "foo"); redispipe.AsError(res) != nil { 265 | return redispipe.AsError(res) 266 | } 267 | return nil 268 | }) 269 | }) 270 | } 271 | -------------------------------------------------------------------------------- /radix.go: -------------------------------------------------------------------------------- 1 | // Package radix implements all functionality needed to work with redis and all 2 | // things related to it, including redis cluster, pubsub, sentinel, scanning, 3 | // lua scripting, and more. 4 | // 5 | // Creating a client 6 | // 7 | // For a single node redis instance use NewPool to create a connection pool. The 8 | // connection pool is thread-safe and will automatically create, reuse, and 9 | // recreate connections as needed: 10 | // 11 | // pool, err := radix.NewPool("tcp", "127.0.0.1:6379", 10) 12 | // if err != nil { 13 | // // handle error 14 | // } 15 | // 16 | // If you're using sentinel or cluster you should use NewSentinel or NewCluster 17 | // (respectively) to create your client instead. 18 | // 19 | // Commands 20 | // 21 | // Any redis command can be performed by passing a Cmd into a Client's Do 22 | // method. Each Cmd should only be used once. The return from the Cmd can be 23 | // captured into any appopriate go primitive type, or a slice, map, or struct, 24 | // if the command returns an array. 
25 | // 26 | // err := client.Do(radix.Cmd(nil, "SET", "foo", "someval")) 27 | // 28 | // var fooVal string 29 | // err := client.Do(radix.Cmd(&fooVal, "GET", "foo")) 30 | // 31 | // var fooValB []byte 32 | // err := client.Do(radix.Cmd(&fooValB, "GET", "foo")) 33 | // 34 | // var barI int 35 | // err := client.Do(radix.Cmd(&barI, "INCR", "bar")) 36 | // 37 | // var bazEls []string 38 | // err := client.Do(radix.Cmd(&bazEls, "LRANGE", "baz", "0", "-1")) 39 | // 40 | // var buzMap map[string]string 41 | // err := client.Do(radix.Cmd(&buzMap, "HGETALL", "buz")) 42 | // 43 | // FlatCmd can also be used if you wish to use non-string arguments like 44 | // integers, slices, maps, or structs, and have them automatically be flattened 45 | // into a single string slice. 46 | // 47 | // Struct Scanning 48 | // 49 | // Cmd and FlatCmd can unmarshal results into a struct. The results must be a 50 | // key/value array, such as that returned by HGETALL. Exported field names will 51 | // be used as keys, unless the fields have the "redis" tag: 52 | // 53 | // type MyType struct { 54 | // Foo string // Will be populated with the value for key "Foo" 55 | // Bar string `redis:"BAR"` // Will be populated with the value for key "BAR" 56 | // Baz string `redis:"-"` // Will not be populated 57 | // } 58 | // 59 | // Embedded structs will inline that struct's fields into the parent's: 60 | // 61 | // type MyOtherType struct { 62 | // // adds fields "Foo" and "BAR" (from above example) to MyOtherType 63 | // MyType 64 | // Biz int 65 | // } 66 | // 67 | // The same rules for field naming apply when a struct is passed into FlatCmd as 68 | // an argument. 69 | // 70 | // Actions 71 | // 72 | // Cmd and FlatCmd both implement the Action interface. Other Actions include 73 | // Pipeline, WithConn, and EvalScript.Cmd. Any of these may be passed into any 74 | // Client's Do method. 
75 | // 76 | // var fooVal string 77 | // p := radix.Pipeline( 78 | // radix.FlatCmd(nil, "SET", "foo", 1), 79 | // radix.Cmd(&fooVal, "GET", "foo"), 80 | // ) 81 | // if err := client.Do(p); err != nil { 82 | // panic(err) 83 | // } 84 | // fmt.Printf("fooVal: %q\n", fooVal) 85 | // 86 | // Transactions 87 | // 88 | // There are two ways to perform transactions in redis. The first is with the 89 | // MULTI/EXEC commands, which can be done using the WithConn Action (see its 90 | // example). The second is using EVAL with lua scripting, which can be done 91 | // using the EvalScript Action (again, see its example). 92 | // 93 | // EVAL with lua scripting is recommended in almost all cases. It only requires 94 | // a single round-trip, it's infinitely more flexible than MULTI/EXEC, it's 95 | // simpler to code, and for complex transactions, which would otherwise need a 96 | // WATCH statement with MULTI/EXEC, it's significantly faster. 97 | // 98 | // AUTH and other settings via ConnFunc and ClientFunc 99 | // 100 | // All the client creation functions (e.g. NewPool) take in either a ConnFunc or 101 | // a ClientFunc via their options. These can be used in order to set up timeouts 102 | // on connections, perform authentication commands, or even implement custom 103 | // pools. 104 | // 105 | // // this is a ConnFunc which will set up a connection which is authenticated 106 | // // and has a 1 minute timeout on all operations 107 | // customConnFunc := func(network, addr string) (radix.Conn, error) { 108 | // return radix.Dial(network, addr, 109 | // radix.DialTimeout(1 * time.Minute), 110 | // radix.DialAuthPass("mySuperSecretPassword"), 111 | // ) 112 | // } 113 | // 114 | // // this pool will use our ConnFunc for all connections it creates 115 | // pool, err := radix.NewPool("tcp", redisAddr, 10, PoolConnFunc(customConnFunc)) 116 | // 117 | // // this cluster will use the ClientFunc to create a pool to each node in the 118 | // // cluster. 
The pools also use our customConnFunc, but have more connections 119 | // poolFunc := func(network, addr string) (radix.Client, error) { 120 | // return radix.NewPool(network, addr, 100, PoolConnFunc(customConnFunc)) 121 | // } 122 | // cluster, err := radix.NewCluster([]string{redisAddr1, redisAddr2}, ClusterPoolFunc(poolFunc)) 123 | // 124 | // Custom implementations 125 | // 126 | // All interfaces in this package were designed such that they could have custom 127 | // implementations. There is no dependency within radix that demands any 128 | // interface be implemented by a particular underlying type, so feel free to 129 | // create your own Pools or Conns or Actions or whatever makes your life easier. 130 | // 131 | // Errors 132 | // 133 | // Errors returned from redis can be explicitly checked for using the the 134 | // resp2.Error type. Note that the errors.As function, introduced in go 1.13, 135 | // should be used. 136 | // 137 | // var redisErr resp2.Error 138 | // err := client.Do(radix.Cmd(nil, "AUTH", "wrong password")) 139 | // if errors.As(err, &redisErr) { 140 | // log.Printf("redis error returned: %s", redisErr.E) 141 | // } 142 | // 143 | // Use the golang.org/x/xerrors package if you're using an older version of go. 144 | // 145 | // Implicit pipelining 146 | // 147 | // Implicit pipelining is an optimization implemented and enabled in the default 148 | // Pool implementation (and therefore also used by Cluster and Sentinel) which 149 | // involves delaying concurrent Cmds and FlatCmds a small amount of time and 150 | // sending them to redis in a single batch, similar to manually using a Pipeline. 151 | // By doing this radix significantly reduces the I/O and CPU overhead for 152 | // concurrent requests. 153 | // 154 | // Note that only commands which do not block are eligible for implicit pipelining. 
155 | // 156 | // See the documentation on Pool for more information about the current 157 | // implementation of implicit pipelining and for how to configure or disable 158 | // the feature. 159 | // 160 | // For a performance comparisons between Clients with and without implicit 161 | // pipelining see the benchmark results in the README.md. 162 | // 163 | package radix 164 | 165 | import ( 166 | "errors" 167 | ) 168 | 169 | var errClientClosed = errors.New("client is closed") 170 | 171 | // Client describes an entity which can carry out Actions, e.g. a connection 172 | // pool for a single redis instance or the cluster client. 173 | // 174 | // Implementations of Client are expected to be thread-safe, except in cases 175 | // like Conn where they specify otherwise. 176 | type Client interface { 177 | // Do performs an Action, returning any error. 178 | Do(Action) error 179 | 180 | // Once Close() is called all future method calls on the Client will return 181 | // an error 182 | Close() error 183 | } 184 | 185 | // ClientFunc is a function which can be used to create a Client for a single 186 | // redis instance on the given network/address. 187 | type ClientFunc func(network, addr string) (Client, error) 188 | 189 | // DefaultClientFunc is a ClientFunc which will return a Client for a redis 190 | // instance using sane defaults. 191 | var DefaultClientFunc = func(network, addr string) (Client, error) { 192 | return NewPool(network, addr, 4) 193 | } 194 | -------------------------------------------------------------------------------- /cluster_topo_test.go: -------------------------------------------------------------------------------- 1 | package radix 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | . 
"testing" 7 | 8 | "github.com/mediocregopher/radix/v3/resp" 9 | "github.com/mediocregopher/radix/v3/resp/resp2" 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func respArr(ii ...interface{}) resp.Marshaler { 15 | var ar resp2.Array 16 | for _, i := range ii { 17 | if m, ok := i.(resp.Marshaler); ok { 18 | ar.A = append(ar.A, m) 19 | } else { 20 | ar.A = append(ar.A, resp2.Any{I: i}) 21 | } 22 | } 23 | return ar 24 | } 25 | 26 | var testTopoResp = func() resp.Marshaler { 27 | return respArr( 28 | respArr(13653, 16383, 29 | respArr("10.128.0.30", 6379, "f7e95c8730634159bc79f9edac566f7b1c964cdd"), 30 | respArr("10.128.0.27", 6379), // older redis instances don't return id 31 | ), 32 | respArr(5461, 8190, 33 | respArr("10.128.0.36", 6379, "a3c69665bb05c8d5770407cad5b35af29e740586"), 34 | 35 | // redis >7.0 returns an extra map as a fourth arg 36 | respArr("10.128.0.24", 6379, "bef29809fbfe964d3b7c3ad02d3d9a40e55de317", 37 | respArr("hostname", "whatever")), 38 | ), 39 | respArr(10923, 13652, 40 | respArr("10.128.0.20", 6379, "e0abc57f65496368e73a9b52b55efd00668adab7"), 41 | respArr("10.128.0.35", 6379, "3e231d265d6ec0c5aa11614eb86704b65f7f909e"), 42 | ), 43 | respArr(8191, 10922, 44 | respArr("10.128.0.29", 6379, "43f1b46d2772fd7bb78b144ddfc3fe77a9f21748"), 45 | respArr("10.128.0.26", 6379, "25339aee29100492d73cbfb1518e318ce1f2fd57"), 46 | ), 47 | respArr(2730, 5460, 48 | respArr("10.128.0.25", 6379, "78e4bb43f68cdc929815a65b4db0697fdda2a9fa"), 49 | respArr("10.128.0.28", 6379, "5a57538cd8ae102daee1dd7f34070e133ff92173"), 50 | ), 51 | respArr(0, 2729, 52 | respArr("10.128.0.34", 6379, "062d8ca98db4deb6b2a3fc776a774dbb710c1a24"), 53 | respArr("10.128.0.3", 6379, "7be2403f92c00d4907da742ffa4c84b935228350"), 54 | ), 55 | ) 56 | }() 57 | 58 | var testTopo = func() ClusterTopo { 59 | buf := new(bytes.Buffer) 60 | if err := testTopoResp.MarshalRESP(buf); err != nil { 61 | panic(err) 62 | } 63 | var tt ClusterTopo 64 | if 
err := tt.UnmarshalRESP(bufio.NewReader(buf)); err != nil { 65 | panic(err) 66 | } 67 | return tt 68 | }() 69 | 70 | func TestClusterTopo(t *T) { 71 | testTopoExp := ClusterTopo{ 72 | ClusterNode{ 73 | Slots: [][2]uint16{{0, 2730}}, 74 | Addr: "10.128.0.34:6379", ID: "062d8ca98db4deb6b2a3fc776a774dbb710c1a24", 75 | }, 76 | ClusterNode{ 77 | Slots: [][2]uint16{{0, 2730}}, 78 | Addr: "10.128.0.3:6379", ID: "7be2403f92c00d4907da742ffa4c84b935228350", 79 | SecondaryOfAddr: "10.128.0.34:6379", 80 | SecondaryOfID: "062d8ca98db4deb6b2a3fc776a774dbb710c1a24", 81 | }, 82 | 83 | ClusterNode{ 84 | Slots: [][2]uint16{{2730, 5461}}, 85 | Addr: "10.128.0.25:6379", ID: "78e4bb43f68cdc929815a65b4db0697fdda2a9fa", 86 | }, 87 | ClusterNode{ 88 | Slots: [][2]uint16{{2730, 5461}}, 89 | Addr: "10.128.0.28:6379", ID: "5a57538cd8ae102daee1dd7f34070e133ff92173", 90 | SecondaryOfAddr: "10.128.0.25:6379", 91 | SecondaryOfID: "78e4bb43f68cdc929815a65b4db0697fdda2a9fa", 92 | }, 93 | 94 | ClusterNode{ 95 | Slots: [][2]uint16{{5461, 8191}}, 96 | Addr: "10.128.0.36:6379", ID: "a3c69665bb05c8d5770407cad5b35af29e740586", 97 | }, 98 | ClusterNode{ 99 | Slots: [][2]uint16{{5461, 8191}}, 100 | Addr: "10.128.0.24:6379", ID: "bef29809fbfe964d3b7c3ad02d3d9a40e55de317", 101 | SecondaryOfAddr: "10.128.0.36:6379", 102 | SecondaryOfID: "a3c69665bb05c8d5770407cad5b35af29e740586", 103 | }, 104 | 105 | ClusterNode{ 106 | Slots: [][2]uint16{{8191, 10923}}, 107 | Addr: "10.128.0.29:6379", ID: "43f1b46d2772fd7bb78b144ddfc3fe77a9f21748", 108 | }, 109 | ClusterNode{ 110 | Slots: [][2]uint16{{8191, 10923}}, 111 | Addr: "10.128.0.26:6379", ID: "25339aee29100492d73cbfb1518e318ce1f2fd57", 112 | SecondaryOfAddr: "10.128.0.29:6379", 113 | SecondaryOfID: "43f1b46d2772fd7bb78b144ddfc3fe77a9f21748", 114 | }, 115 | 116 | ClusterNode{ 117 | Slots: [][2]uint16{{10923, 13653}}, 118 | Addr: "10.128.0.20:6379", ID: "e0abc57f65496368e73a9b52b55efd00668adab7", 119 | }, 120 | ClusterNode{ 121 | Slots: [][2]uint16{{10923, 13653}}, 
122 | Addr: "10.128.0.35:6379", ID: "3e231d265d6ec0c5aa11614eb86704b65f7f909e", 123 | SecondaryOfAddr: "10.128.0.20:6379", 124 | SecondaryOfID: "e0abc57f65496368e73a9b52b55efd00668adab7", 125 | }, 126 | 127 | ClusterNode{ 128 | Slots: [][2]uint16{{13653, 16384}}, 129 | Addr: "10.128.0.30:6379", ID: "f7e95c8730634159bc79f9edac566f7b1c964cdd", 130 | }, 131 | ClusterNode{ 132 | Slots: [][2]uint16{{13653, 16384}}, 133 | Addr: "10.128.0.27:6379", ID: "", 134 | SecondaryOfAddr: "10.128.0.30:6379", 135 | SecondaryOfID: "f7e95c8730634159bc79f9edac566f7b1c964cdd", 136 | }, 137 | } 138 | 139 | // make sure, to start with, the testTopo matches what we expect it to 140 | assert.Equal(t, testTopoExp, testTopo) 141 | 142 | // Make sure both Marshal/UnmarshalRESP on it are working correctly (the 143 | // calls in testTopoResp aren't actually on Topo's methods) 144 | buf := new(bytes.Buffer) 145 | require.Nil(t, testTopo.MarshalRESP(buf)) 146 | var testTopo2 ClusterTopo 147 | require.Nil(t, testTopo2.UnmarshalRESP(bufio.NewReader(buf))) 148 | assert.Equal(t, testTopoExp, testTopo2) 149 | } 150 | 151 | // Test parsing a topology where a node in the cluster has two different sets of 152 | // slots, as well as a secondary. 
153 | func TestClusterTopoSplitSlots(t *T) { 154 | clusterSlotsResp := respArr( 155 | respArr(0, 0, 156 | respArr("127.0.0.1", 7001, "90900dd4ef2182825bc853c448737b2ba9975a50"), 157 | respArr("127.0.0.1", 7011, "073a013f8886b6cf4c1b018612102601534912e9"), 158 | ), 159 | respArr(1, 8191, 160 | respArr("127.0.0.1", 7000, "3ff1ddc420cfceeb4c42dc4b1f8f85c3acf984fe"), 161 | ), 162 | respArr(8192, 16383, 163 | respArr("127.0.0.1", 7001, "90900dd4ef2182825bc853c448737b2ba9975a50"), 164 | respArr("127.0.0.1", 7011, "073a013f8886b6cf4c1b018612102601534912e9"), 165 | ), 166 | ) 167 | expTopo := ClusterTopo{ 168 | ClusterNode{ 169 | Slots: [][2]uint16{{0, 1}, {8192, 16384}}, 170 | Addr: "127.0.0.1:7001", ID: "90900dd4ef2182825bc853c448737b2ba9975a50", 171 | }, 172 | ClusterNode{ 173 | Slots: [][2]uint16{{0, 1}, {8192, 16384}}, 174 | Addr: "127.0.0.1:7011", ID: "073a013f8886b6cf4c1b018612102601534912e9", 175 | SecondaryOfAddr: "127.0.0.1:7001", 176 | SecondaryOfID: "90900dd4ef2182825bc853c448737b2ba9975a50", 177 | }, 178 | ClusterNode{ 179 | Slots: [][2]uint16{{1, 8192}}, 180 | Addr: "127.0.0.1:7000", ID: "3ff1ddc420cfceeb4c42dc4b1f8f85c3acf984fe", 181 | }, 182 | } 183 | 184 | // unmarshal the resp into a Topo and make sure it matches expTopo 185 | { 186 | buf := new(bytes.Buffer) 187 | require.Nil(t, clusterSlotsResp.MarshalRESP(buf)) 188 | var topo ClusterTopo 189 | require.Nil(t, topo.UnmarshalRESP(bufio.NewReader(buf))) 190 | assert.Equal(t, expTopo, topo) 191 | } 192 | 193 | // marshal Topo, then re-unmarshal, and make sure it still matches 194 | { 195 | buf := new(bytes.Buffer) 196 | require.Nil(t, expTopo.MarshalRESP(buf)) 197 | var topo ClusterTopo 198 | require.Nil(t, topo.UnmarshalRESP(bufio.NewReader(buf))) 199 | assert.Equal(t, expTopo, topo) 200 | } 201 | 202 | } 203 | 204 | func TestIPV6ClusterTopo(t *T) { 205 | clusterSlotsResp := respArr( 206 | respArr(0, 0, 207 | respArr("8ffd:50::d4eb", 7001, "90900dd4ef2182825bc853c448737b2ba9975a50"), 208 | ), 209 | ) 210 | 
expTopo := ClusterTopo{ 211 | ClusterNode{ 212 | Slots: [][2]uint16{{0, 1}}, 213 | Addr: "[8ffd:50::d4eb]:7001", ID: "90900dd4ef2182825bc853c448737b2ba9975a50", 214 | }, 215 | } 216 | 217 | // unmarshal the resp into a Topo and make sure it matches expTopo 218 | { 219 | buf := new(bytes.Buffer) 220 | require.Nil(t, clusterSlotsResp.MarshalRESP(buf)) 221 | var topo ClusterTopo 222 | require.Nil(t, topo.UnmarshalRESP(bufio.NewReader(buf))) 223 | assert.Equal(t, expTopo, topo) 224 | } 225 | 226 | // marshal Topo, then re-unmarshal, and make sure it still matches 227 | { 228 | buf := new(bytes.Buffer) 229 | require.Nil(t, expTopo.MarshalRESP(buf)) 230 | var topo ClusterTopo 231 | require.Nil(t, topo.UnmarshalRESP(bufio.NewReader(buf))) 232 | assert.Equal(t, expTopo, topo) 233 | } 234 | } 235 | -------------------------------------------------------------------------------- /pubsub_persistent.go: -------------------------------------------------------------------------------- 1 | package radix 2 | 3 | import ( 4 | "fmt" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | type persistentPubSubOpts struct { 10 | connFn ConnFunc 11 | abortAfter int 12 | errCh chan<- error 13 | } 14 | 15 | // PersistentPubSubOpt is an optional parameter which can be passed into 16 | // PersistentPubSub in order to affect its behavior. 17 | type PersistentPubSubOpt func(*persistentPubSubOpts) 18 | 19 | // PersistentPubSubConnFunc causes PersistentPubSub to use the given ConnFunc 20 | // when connecting to its destination. 21 | func PersistentPubSubConnFunc(connFn ConnFunc) PersistentPubSubOpt { 22 | return func(opts *persistentPubSubOpts) { 23 | opts.connFn = connFn 24 | } 25 | } 26 | 27 | // PersistentPubSubAbortAfter changes PersistentPubSub's reconnect behavior. 28 | // Usually PersistentPubSub will try to reconnect forever upon a disconnect, 29 | // blocking any methods which have been called until reconnect is successful. 
30 | // 31 | // When PersistentPubSubAbortAfter is used, it will give up after that many 32 | // attempts and return the error to the method which has been blocked the 33 | // longest. Another method will need to be called in order for PersistentPubSub 34 | // to resume trying to reconnect. 35 | func PersistentPubSubAbortAfter(attempts int) PersistentPubSubOpt { 36 | return func(opts *persistentPubSubOpts) { 37 | opts.abortAfter = attempts 38 | } 39 | } 40 | 41 | // PersistentPubSubErrCh takes a channel which asynchronous errors 42 | // encountered by the PersistentPubSub can be read off of. If the channel blocks 43 | // the error will be dropped. The channel will be closed when PersistentPubSub 44 | // is closed. 45 | func PersistentPubSubErrCh(errCh chan<- error) PersistentPubSubOpt { 46 | return func(opts *persistentPubSubOpts) { 47 | opts.errCh = errCh 48 | } 49 | } 50 | 51 | type pubSubCmd struct { 52 | // msgCh can be set along with one of subscribe/unsubscribe/etc... 53 | msgCh chan<- PubSubMessage 54 | subscribe, unsubscribe, psubscribe, punsubscribe []string 55 | 56 | // ... or one of ping or close can be set 57 | ping, close bool 58 | 59 | // resCh is always set 60 | resCh chan error 61 | } 62 | 63 | type persistentPubSub struct { 64 | dial func() (Conn, error) 65 | opts persistentPubSubOpts 66 | 67 | subs, psubs chanSet 68 | 69 | curr PubSubConn 70 | currErrCh chan error 71 | 72 | cmdCh chan pubSubCmd 73 | 74 | closeErr error 75 | closeCh chan struct{} 76 | closeOnce sync.Once 77 | } 78 | 79 | // PersistentPubSubWithOpts is like PubSub, but instead of taking in an existing 80 | // Conn to wrap it will create one on the fly. If the connection is ever 81 | // terminated then a new one will be created and will be reset to the previous 82 | // connection's state. 
83 | // 84 | // This is effectively a way to have a permanent PubSubConn established which 85 | // supports subscribing/unsubscribing but without the hassle of implementing 86 | // reconnect/re-subscribe logic. 87 | // 88 | // With default options, neither this function nor any of the methods on the 89 | // returned PubSubConn will ever return an error, they will instead block until 90 | // a connection can be successfully reinstated. 91 | // 92 | // PersistentPubSubWithOpts takes in a number of options which can overwrite its 93 | // default behavior. The default options PersistentPubSubWithOpts uses are: 94 | // 95 | // PersistentPubSubConnFunc(DefaultConnFunc) 96 | // 97 | func PersistentPubSubWithOpts( 98 | network, addr string, options ...PersistentPubSubOpt, 99 | ) ( 100 | PubSubConn, error, 101 | ) { 102 | opts := persistentPubSubOpts{ 103 | connFn: DefaultConnFunc, 104 | } 105 | for _, opt := range options { 106 | opt(&opts) 107 | } 108 | 109 | p := &persistentPubSub{ 110 | dial: func() (Conn, error) { return opts.connFn(network, addr) }, 111 | opts: opts, 112 | subs: chanSet{}, 113 | psubs: chanSet{}, 114 | cmdCh: make(chan pubSubCmd), 115 | closeCh: make(chan struct{}), 116 | } 117 | if err := p.refresh(); err != nil { 118 | return nil, err 119 | } 120 | go p.spin() 121 | return p, nil 122 | } 123 | 124 | // PersistentPubSub is deprecated in favor of PersistentPubSubWithOpts instead. 125 | func PersistentPubSub(network, addr string, connFn ConnFunc) PubSubConn { 126 | var opts []PersistentPubSubOpt 127 | if connFn != nil { 128 | opts = append(opts, PersistentPubSubConnFunc(connFn)) 129 | } 130 | // since PersistentPubSubAbortAfter isn't used, this will never return an 131 | // error, panic if it does 132 | p, err := PersistentPubSubWithOpts(network, addr, opts...) 
133 | if err != nil { 134 | panic(fmt.Sprintf("PersistentPubSubWithOpts impossibly returned an error: %v", err)) 135 | } 136 | return p 137 | } 138 | 139 | // refresh only returns an error if the connection could not be made. 140 | func (p *persistentPubSub) refresh() error { 141 | if p.curr != nil { 142 | p.curr.Close() 143 | <-p.currErrCh 144 | p.curr = nil 145 | p.currErrCh = nil 146 | } 147 | 148 | attempt := func() (PubSubConn, chan error, error) { 149 | c, err := p.dial() 150 | if err != nil { 151 | return nil, nil, err 152 | } 153 | errCh := make(chan error, 1) 154 | pc := newPubSub(c, errCh) 155 | 156 | for msgCh, channels := range p.subs.inverse() { 157 | if err := pc.Subscribe(msgCh, channels...); err != nil { 158 | pc.Close() 159 | return nil, nil, err 160 | } 161 | } 162 | 163 | for msgCh, patterns := range p.psubs.inverse() { 164 | if err := pc.PSubscribe(msgCh, patterns...); err != nil { 165 | pc.Close() 166 | return nil, nil, err 167 | } 168 | } 169 | return pc, errCh, nil 170 | } 171 | 172 | var attempts int 173 | for { 174 | var err error 175 | if p.curr, p.currErrCh, err = attempt(); err == nil { 176 | return nil 177 | } 178 | attempts++ 179 | if p.opts.abortAfter > 0 && attempts >= p.opts.abortAfter { 180 | return err 181 | } 182 | time.Sleep(200 * time.Millisecond) 183 | } 184 | } 185 | 186 | func (p *persistentPubSub) execCmd(cmd pubSubCmd) error { 187 | if p.curr == nil { 188 | if err := p.refresh(); err != nil { 189 | return err 190 | } 191 | } 192 | 193 | // For all subscribe/unsubscribe/etc... commands the modifications to 194 | // p.subs/p.psubs are made first, so that if the actual call to curr fails 195 | // then refresh will still instate the new desired subscription. 196 | var err error 197 | switch { 198 | case len(cmd.subscribe) > 0: 199 | for _, channel := range cmd.subscribe { 200 | p.subs.add(channel, cmd.msgCh) 201 | } 202 | err = p.curr.Subscribe(cmd.msgCh, cmd.subscribe...) 
203 | 204 | case len(cmd.unsubscribe) > 0: 205 | for _, channel := range cmd.unsubscribe { 206 | p.subs.del(channel, cmd.msgCh) 207 | } 208 | err = p.curr.Unsubscribe(cmd.msgCh, cmd.unsubscribe...) 209 | 210 | case len(cmd.psubscribe) > 0: 211 | for _, channel := range cmd.psubscribe { 212 | p.psubs.add(channel, cmd.msgCh) 213 | } 214 | err = p.curr.PSubscribe(cmd.msgCh, cmd.psubscribe...) 215 | 216 | case len(cmd.punsubscribe) > 0: 217 | for _, channel := range cmd.punsubscribe { 218 | p.psubs.del(channel, cmd.msgCh) 219 | } 220 | err = p.curr.PUnsubscribe(cmd.msgCh, cmd.punsubscribe...) 221 | 222 | case cmd.ping: 223 | err = p.curr.Ping() 224 | 225 | case cmd.close: 226 | if p.curr != nil { 227 | err = p.curr.Close() 228 | <-p.currErrCh 229 | } 230 | 231 | default: 232 | // don't do anything I guess 233 | } 234 | 235 | if err != nil { 236 | return p.refresh() 237 | } 238 | return nil 239 | } 240 | 241 | func (p *persistentPubSub) err(err error) { 242 | select { 243 | case p.opts.errCh <- err: 244 | default: 245 | } 246 | } 247 | 248 | func (p *persistentPubSub) spin() { 249 | for { 250 | select { 251 | case err := <-p.currErrCh: 252 | p.err(err) 253 | if err := p.refresh(); err != nil { 254 | p.err(err) 255 | } 256 | case cmd := <-p.cmdCh: 257 | cmd.resCh <- p.execCmd(cmd) 258 | if cmd.close { 259 | return 260 | } 261 | } 262 | } 263 | } 264 | 265 | func (p *persistentPubSub) cmd(cmd pubSubCmd) error { 266 | cmd.resCh = make(chan error, 1) 267 | select { 268 | case p.cmdCh <- cmd: 269 | return <-cmd.resCh 270 | case <-p.closeCh: 271 | return fmt.Errorf("closed") 272 | } 273 | } 274 | 275 | func (p *persistentPubSub) Subscribe(msgCh chan<- PubSubMessage, channels ...string) error { 276 | return p.cmd(pubSubCmd{ 277 | msgCh: msgCh, 278 | subscribe: channels, 279 | }) 280 | } 281 | 282 | func (p *persistentPubSub) Unsubscribe(msgCh chan<- PubSubMessage, channels ...string) error { 283 | return p.cmd(pubSubCmd{ 284 | msgCh: msgCh, 285 | unsubscribe: channels, 286 | 
}) 287 | } 288 | 289 | func (p *persistentPubSub) PSubscribe(msgCh chan<- PubSubMessage, channels ...string) error { 290 | return p.cmd(pubSubCmd{ 291 | msgCh: msgCh, 292 | psubscribe: channels, 293 | }) 294 | } 295 | 296 | func (p *persistentPubSub) PUnsubscribe(msgCh chan<- PubSubMessage, channels ...string) error { 297 | return p.cmd(pubSubCmd{ 298 | msgCh: msgCh, 299 | punsubscribe: channels, 300 | }) 301 | } 302 | 303 | func (p *persistentPubSub) Ping() error { 304 | return p.cmd(pubSubCmd{ping: true}) 305 | } 306 | 307 | func (p *persistentPubSub) Close() error { 308 | p.closeOnce.Do(func() { 309 | p.closeErr = p.cmd(pubSubCmd{close: true}) 310 | close(p.closeCh) 311 | if p.opts.errCh != nil { 312 | close(p.opts.errCh) 313 | } 314 | }) 315 | return p.closeErr 316 | } 317 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Radix 2 | 3 | Radix is a full-featured [Redis][redis] client for Go. See the reference links 4 | below for documentation and general usage examples. 5 | 6 | **[v3 Documentation](https://pkg.go.dev/github.com/mediocregopher/radix/v3#section-documentation)** 7 | 8 | **[v4 Documentation](https://pkg.go.dev/github.com/mediocregopher/radix/v4#section-documentation)** 9 | 10 | Please open an issue, or start a discussion in the chat, before opening a pull request! 11 | 12 | ## Features 13 | 14 | * Standard print-like API which supports **all current and future redis commands**. 15 | 16 | * Connection pool which uses **connection sharing** to minimize system calls. 17 | 18 | * Full support for [Sentinel][sentinel] and [Cluster][cluster]. 19 | 20 | * Helpers for [EVAL][eval], [SCAN][scan], [Streams][stream], and [Pipelining][pipelining]. 21 | 22 | * Support for [pubsub][pubsub], as well as persistent pubsub wherein if a 23 | connection is lost a new one transparently replaces it. 
24 | 25 | * API design allows for custom implementations of nearly anything. 26 | 27 | ## Versions 28 | 29 | There are two major versions of radix being supported: 30 | 31 | * v3 is the more mature version, but lacks the polished API of v4. v3 is only accepting bug fixes at this point. 32 | 33 | * v4 has feature parity with v3 and more! The biggest selling points are: 34 | 35 | * More polished API. 36 | * Full [RESP3][resp3] support. 37 | * Support for [context.Context][context] on all blocking operations. 38 | * Connection sharing (called "implicit pipelining" in v3) now works with Pipeline and EvalScript. 39 | 40 | View the [CHANGELOG][v4changelog] for more details. 41 | 42 | [v4changelog]: https://github.com/mediocregopher/radix/blob/v4/CHANGELOG.md 43 | 44 | ## Installation and Usage 45 | 46 | Radix always aims to support the most recent two versions of go, and is likely 47 | to support others prior to those two. 48 | 49 | [Module][module]-aware mode: 50 | 51 | go get github.com/mediocregopher/radix/v3 52 | // import github.com/mediocregopher/radix/v3 53 | 54 | go get github.com/mediocregopher/radix/v4 55 | // import github.com/mediocregopher/radix/v4 56 | 57 | ## Testing 58 | 59 | # requires a redis server running on 127.0.0.1:6379 60 | go test github.com/mediocregopher/radix/v3 61 | go test github.com/mediocregopher/radix/v4 62 | 63 | ## Benchmarks 64 | 65 | Benchmarks were run in as close to a "real" environment as possible. Two GCE 66 | instances were booted up, one hosting the redis server with 2vCPUs, the other 67 | running the benchmarks (found in the `bench` directory) with 16vCPUs. 68 | 69 | The benchmarks test a variety of situations against many different redis 70 | drivers, and the results are very large. You can view them [here][bench 71 | results]. 
Below are some highlights (I've tried to be fair here): 72 | 73 | For a typical workload, which is lots of concurrent commands with relatively 74 | small amounts of data, radix outperforms all tested drivers except 75 | [redispipe][redispipe]: 76 | 77 | ``` 78 | BenchmarkDrivers/parallel/no_pipeline/small_kv/radixv4-64 17815254 2917 ns/op 199 B/op 6 allocs/op 79 | BenchmarkDrivers/parallel/no_pipeline/small_kv/radixv3-64 16688293 3120 ns/op 109 B/op 4 allocs/op 80 | BenchmarkDrivers/parallel/no_pipeline/small_kv/redigo-64 3504063 15092 ns/op 168 B/op 9 allocs/op 81 | BenchmarkDrivers/parallel/no_pipeline/small_kv/redispipe_pause150us-64 31668576 1680 ns/op 217 B/op 11 allocs/op 82 | BenchmarkDrivers/parallel/no_pipeline/small_kv/redispipe_pause0-64 31149280 1685 ns/op 218 B/op 11 allocs/op 83 | BenchmarkDrivers/parallel/no_pipeline/small_kv/go-redis-64 3768988 14409 ns/op 411 B/op 13 allocs/op 84 | ``` 85 | 86 | The story is similar for pipelining commands concurrently (radixv3 doesn't do as 87 | well here, because it doesn't support connection sharing for pipeline commands): 88 | 89 | ``` 90 | BenchmarkDrivers/parallel/pipeline/small_kv/radixv4-64 24720337 2245 ns/op 508 B/op 13 allocs/op 91 | BenchmarkDrivers/parallel/pipeline/small_kv/radixv3-64 6921868 7757 ns/op 165 B/op 7 allocs/op 92 | BenchmarkDrivers/parallel/pipeline/small_kv/redigo-64 6738849 8080 ns/op 170 B/op 9 allocs/op 93 | BenchmarkDrivers/parallel/pipeline/small_kv/redispipe_pause150us-64 44479539 1148 ns/op 316 B/op 12 allocs/op 94 | BenchmarkDrivers/parallel/pipeline/small_kv/redispipe_pause0-64 45290868 1126 ns/op 315 B/op 12 allocs/op 95 | BenchmarkDrivers/parallel/pipeline/small_kv/go-redis-64 6740984 7903 ns/op 475 B/op 15 allocs/op 96 | ``` 97 | 98 | For larger amounts of data being transferred the differences become less 99 | noticeable, but both radix versions come out on top: 100 | 101 | ``` 102 | BenchmarkDrivers/parallel/no_pipeline/large_kv/radixv4-64 2395707 22766 ns/op 12553 B/op 4 
allocs/op 103 | BenchmarkDrivers/parallel/no_pipeline/large_kv/radixv3-64 3150398 17087 ns/op 12745 B/op 4 allocs/op 104 | BenchmarkDrivers/parallel/no_pipeline/large_kv/redigo-64 1593054 34038 ns/op 24742 B/op 9 allocs/op 105 | BenchmarkDrivers/parallel/no_pipeline/large_kv/redispipe_pause150us-64 2105118 25085 ns/op 16962 B/op 11 allocs/op 106 | BenchmarkDrivers/parallel/no_pipeline/large_kv/redispipe_pause0-64 2354427 24280 ns/op 17295 B/op 11 allocs/op 107 | BenchmarkDrivers/parallel/no_pipeline/large_kv/go-redis-64 1519354 35745 ns/op 14033 B/op 14 allocs/op 108 | ``` 109 | 110 | All results above show the high-concurrency results (`-cpu 64`). Concurrencies 111 | of 16 and 32 are also included in the results, but didn't show anything 112 | different. 113 | 114 | For serial workloads, which involve a single connection performing commands 115 | one after the other, radix is either as fast or within a couple % of the other 116 | drivers tested. This use-case is much less common, and so when tradeoffs have 117 | been made between parallel and serial performance radix has generally leaned 118 | towards parallel.
119 | 120 | Serial non-pipelined: 121 | 122 | ``` 123 | BenchmarkDrivers/serial/no_pipeline/small_kv/radixv4-16 346915 161493 ns/op 67 B/op 4 allocs/op 124 | BenchmarkDrivers/serial/no_pipeline/small_kv/radixv3-16 428313 138011 ns/op 67 B/op 4 allocs/op 125 | BenchmarkDrivers/serial/no_pipeline/small_kv/redigo-16 416103 134438 ns/op 134 B/op 8 allocs/op 126 | BenchmarkDrivers/serial/no_pipeline/small_kv/redispipe_pause150us-16 86734 635637 ns/op 217 B/op 11 allocs/op 127 | BenchmarkDrivers/serial/no_pipeline/small_kv/redispipe_pause0-16 340320 158732 ns/op 216 B/op 11 allocs/op 128 | BenchmarkDrivers/serial/no_pipeline/small_kv/go-redis-16 429703 138854 ns/op 408 B/op 13 allocs/op 129 | ``` 130 | 131 | Serial pipelined: 132 | 133 | ``` 134 | BenchmarkDrivers/serial/pipeline/small_kv/radixv4-16 624417 82336 ns/op 83 B/op 5 allocs/op 135 | BenchmarkDrivers/serial/pipeline/small_kv/radixv3-16 784947 68540 ns/op 163 B/op 7 allocs/op 136 | BenchmarkDrivers/serial/pipeline/small_kv/redigo-16 770983 69976 ns/op 134 B/op 8 allocs/op 137 | BenchmarkDrivers/serial/pipeline/small_kv/redispipe_pause150us-16 175623 320512 ns/op 312 B/op 12 allocs/op 138 | BenchmarkDrivers/serial/pipeline/small_kv/redispipe_pause0-16 642673 82225 ns/op 312 B/op 12 allocs/op 139 | BenchmarkDrivers/serial/pipeline/small_kv/go-redis-16 787364 72240 ns/op 472 B/op 15 allocs/op 140 | ``` 141 | 142 | Serial large values: 143 | 144 | ``` 145 | BenchmarkDrivers/serial/no_pipeline/large_kv/radixv4-16 253586 217600 ns/op 12521 B/op 4 allocs/op 146 | BenchmarkDrivers/serial/no_pipeline/large_kv/radixv3-16 317356 179608 ns/op 12717 B/op 4 allocs/op 147 | BenchmarkDrivers/serial/no_pipeline/large_kv/redigo-16 244226 231179 ns/op 24704 B/op 8 allocs/op 148 | BenchmarkDrivers/serial/no_pipeline/large_kv/redispipe_pause150us-16 80174 674066 ns/op 13780 B/op 11 allocs/op 149 | BenchmarkDrivers/serial/no_pipeline/large_kv/redispipe_pause0-16 251810 209890 ns/op 13778 B/op 11 allocs/op 150 | 
BenchmarkDrivers/serial/no_pipeline/large_kv/go-redis-16 236379 225677 ns/op 13976 B/op 14 allocs/op 151 | ``` 152 | 153 | [bench results]: https://github.com/mediocregopher/radix/blob/v4/bench/bench_results.txt 154 | 155 | ## Copyright and licensing 156 | 157 | Unless otherwise noted, the source files are distributed under the *MIT License* 158 | found in the LICENSE.txt file. 159 | 160 | [redis]: http://redis.io 161 | [eval]: https://redis.io/commands/eval 162 | [scan]: https://redis.io/commands/scan 163 | [stream]: https://redis.io/topics/streams-intro 164 | [pipelining]: https://redis.io/topics/pipelining 165 | [pubsub]: https://redis.io/topics/pubsub 166 | [sentinel]: http://redis.io/topics/sentinel 167 | [cluster]: http://redis.io/topics/cluster-spec 168 | [module]: https://github.com/golang/go/wiki/Modules 169 | [redispipe]: https://github.com/joomcode/redispipe 170 | [context]: https://pkg.go.dev/context 171 | [resp3]: https://github.com/antirez/RESP3/blob/master/spec.md 172 | -------------------------------------------------------------------------------- /pool_test.go: -------------------------------------------------------------------------------- 1 | package radix 2 | 3 | import ( 4 | "io" 5 | "sync" 6 | "sync/atomic" 7 | . "testing" 8 | "time" 9 | 10 | "errors" 11 | 12 | "github.com/stretchr/testify/assert" 13 | "github.com/stretchr/testify/require" 14 | 15 | "github.com/mediocregopher/radix/v3/resp" 16 | "github.com/mediocregopher/radix/v3/resp/resp2" 17 | "github.com/mediocregopher/radix/v3/trace" 18 | ) 19 | 20 | func testPool(size int, opts ...PoolOpt) *Pool { 21 | pool, err := NewPool("tcp", "localhost:6379", size, opts...) 
22 | if err != nil { 23 | panic(err) 24 | } 25 | <-pool.initDone 26 | return pool 27 | } 28 | 29 | func TestPool(t *T) { 30 | 31 | testEcho := func(c Conn) error { 32 | exp := randStr() 33 | var out string 34 | if err := c.Do(Cmd(&out, "ECHO", exp)); err != nil { 35 | return err 36 | } 37 | assert.Equal(t, exp, out) 38 | return nil 39 | } 40 | 41 | do := func(opts ...PoolOpt) { 42 | opts = append(opts, PoolOnFullClose()) 43 | size := 10 44 | pool := testPool(size, opts...) 45 | var wg sync.WaitGroup 46 | for i := 0; i < size*4; i++ { 47 | wg.Add(1) 48 | go func() { 49 | for i := 0; i < 100; i++ { 50 | assert.NoError(t, pool.Do(WithConn("", testEcho))) 51 | } 52 | wg.Done() 53 | }() 54 | } 55 | wg.Wait() 56 | assert.Equal(t, size, pool.NumAvailConns()) 57 | pool.Close() 58 | assert.Equal(t, 0, pool.NumAvailConns()) 59 | } 60 | 61 | t.Run("onEmptyWait", func(t *T) { do(PoolOnEmptyWait()) }) 62 | t.Run("onEmptyCreate", func(t *T) { do(PoolOnEmptyCreateAfter(0)) }) 63 | t.Run("onEmptyCreateAfter", func(t *T) { do(PoolOnEmptyCreateAfter(1 * time.Second)) }) 64 | // This one is expected to error, since this test empties the pool by design 65 | //t.Run("onEmptyErr", func(t *T) { do(PoolOnEmptyErrAfter(0)) }) 66 | t.Run("onEmptyErrAfter", func(t *T) { do(PoolOnEmptyErrAfter(1 * time.Second)) }) 67 | t.Run("maxLifetime", func(t *T) { do(PoolMaxLifetime(1 * time.Second)) }) 68 | 69 | t.Run("withTrace", func(t *T) { 70 | var connCreatedCount int 71 | var connClosedCount int 72 | var doCompletedCount uint32 73 | var initializedAvailCount int 74 | pt := trace.PoolTrace{ 75 | ConnCreated: func(done trace.PoolConnCreated) { 76 | connCreatedCount++ 77 | }, 78 | ConnClosed: func(closed trace.PoolConnClosed) { 79 | connClosedCount++ 80 | }, 81 | DoCompleted: func(completed trace.PoolDoCompleted) { 82 | atomic.AddUint32(&doCompletedCount, 1) 83 | }, 84 | InitCompleted: func(completed trace.PoolInitCompleted) { 85 | initializedAvailCount = completed.AvailCount 86 | }, 87 | } 88 | 
do(PoolWithTrace(pt)) 89 | if initializedAvailCount != 10 { 90 | t.Fail() 91 | } 92 | if connCreatedCount != connClosedCount { 93 | t.Fail() 94 | } 95 | if doCompletedCount == 0 { 96 | t.Fail() 97 | } 98 | }) 99 | } 100 | 101 | // Test all the different OnEmpty behaviors. 102 | func TestPoolGet(t *T) { 103 | getBlock := func(p *Pool) (time.Duration, error) { 104 | start := time.Now() 105 | _, err := p.get() 106 | return time.Since(start), err 107 | } 108 | 109 | // this one is a bit weird, cause it would block infinitely if we let it 110 | t.Run("onEmptyWait", func(t *T) { 111 | pool := testPool(1, PoolOnEmptyWait()) 112 | conn, err := pool.get() 113 | assert.NoError(t, err) 114 | 115 | go func() { 116 | time.Sleep(2 * time.Second) 117 | pool.put(conn) 118 | }() 119 | took, err := getBlock(pool) 120 | assert.NoError(t, err) 121 | assert.True(t, took-2*time.Second < 20*time.Millisecond) 122 | }) 123 | 124 | // the rest are pretty straightforward 125 | gen := func(mkOpt func(time.Duration) PoolOpt, d time.Duration, expErr error) func(*T) { 126 | return func(t *T) { 127 | pool := testPool(0, PoolOnFullClose(), mkOpt(d)) 128 | took, err := getBlock(pool) 129 | assert.Equal(t, expErr, err) 130 | assert.True(t, took-d < 20*time.Millisecond) 131 | } 132 | } 133 | 134 | t.Run("onEmptyCreate", gen(PoolOnEmptyCreateAfter, 0, nil)) 135 | t.Run("onEmptyCreateAfter", gen(PoolOnEmptyCreateAfter, 1*time.Second, nil)) 136 | t.Run("onEmptyErr", gen(PoolOnEmptyErrAfter, 0, ErrPoolEmpty)) 137 | t.Run("onEmptyErrAfter", gen(PoolOnEmptyErrAfter, 1*time.Second, ErrPoolEmpty)) 138 | } 139 | 140 | func TestPoolOnFull(t *T) { 141 | t.Run("onFullClose", func(t *T) { 142 | var reason trace.PoolConnClosedReason 143 | pool := testPool(1, 144 | PoolOnFullClose(), 145 | PoolWithTrace(trace.PoolTrace{ConnClosed: func(c trace.PoolConnClosed) { 146 | reason = c.Reason 147 | }}), 148 | ) 149 | defer pool.Close() 150 | assert.Equal(t, 1, len(pool.pool)) 151 | 152 | spc, err := pool.newConn("TEST") 
153 | assert.NoError(t, err) 154 | pool.put(spc) 155 | assert.Equal(t, 1, len(pool.pool)) 156 | assert.Equal(t, trace.PoolConnClosedReasonPoolFull, reason) 157 | }) 158 | 159 | t.Run("onFullBuffer", func(t *T) { 160 | pool := testPool(1, PoolOnFullBuffer(1, 1*time.Second)) 161 | defer pool.Close() 162 | assert.Equal(t, 1, len(pool.pool)) 163 | 164 | // putting a conn should overflow 165 | spc, err := pool.newConn("TEST") 166 | assert.NoError(t, err) 167 | pool.put(spc) 168 | assert.Equal(t, 2, len(pool.pool)) 169 | 170 | // another shouldn't, overflow is full 171 | spc, err = pool.newConn("TEST") 172 | assert.NoError(t, err) 173 | pool.put(spc) 174 | assert.Equal(t, 2, len(pool.pool)) 175 | 176 | // retrieve from the pool, drain shouldn't do anything because the 177 | // overflow is empty now 178 | <-pool.pool 179 | assert.Equal(t, 1, len(pool.pool)) 180 | time.Sleep(2 * time.Second) 181 | assert.Equal(t, 1, len(pool.pool)) 182 | 183 | // if both are full then drain should remove the overflow one 184 | spc, err = pool.newConn("TEST") 185 | assert.NoError(t, err) 186 | pool.put(spc) 187 | assert.Equal(t, 2, len(pool.pool)) 188 | time.Sleep(2 * time.Second) 189 | assert.Equal(t, 1, len(pool.pool)) 190 | }) 191 | } 192 | 193 | func TestPoolPut(t *T) { 194 | size := 10 195 | pool := testPool(size) 196 | 197 | assertPoolConns := func(exp int) { 198 | assert.Equal(t, exp, pool.NumAvailConns()) 199 | } 200 | assertPoolConns(10) 201 | 202 | // Make sure that put does not accept a connection which has had a critical 203 | // network error 204 | err := pool.Do(WithConn("", func(conn Conn) error { 205 | assertPoolConns(9) 206 | conn.(*ioErrConn).lastIOErr = io.EOF 207 | return nil 208 | })) 209 | assert.NoError(t, err) 210 | assertPoolConns(9) 211 | 212 | // Make sure that a put _does_ accept a connection which had a 213 | // marshal/unmarshal error 214 | err = pool.Do(WithConn("", func(conn Conn) error { 215 | assert.NotNil(t, conn.Do(FlatCmd(nil, "ECHO", "", func() {}))) 
216 | assert.Nil(t, conn.(*ioErrConn).lastIOErr) 217 | return nil 218 | })) 219 | assert.NoError(t, err) 220 | assertPoolConns(9) 221 | 222 | // Make sure that a put _does_ accept a connection which had an app level 223 | // resp error 224 | err = pool.Do(WithConn("", func(conn Conn) error { 225 | assert.NotNil(t, Cmd(nil, "CMDDNE")) 226 | assert.Nil(t, conn.(*ioErrConn).lastIOErr) 227 | return nil 228 | })) 229 | assert.NoError(t, err) 230 | assertPoolConns(9) 231 | 232 | // Make sure that closing the pool closes outstanding connections as well 233 | closeCh := make(chan bool) 234 | go func() { 235 | <-closeCh 236 | assert.Nil(t, pool.Close()) 237 | closeCh <- true 238 | }() 239 | err = pool.Do(WithConn("", func(conn Conn) error { 240 | closeCh <- true 241 | <-closeCh 242 | return nil 243 | })) 244 | assert.NoError(t, err) 245 | assertPoolConns(0) 246 | } 247 | 248 | // TestPoolDoDoesNotBlock checks that with a positive onEmptyWait Pool.Do() 249 | // does not block longer than the timeout period given by user. 
250 | func TestPoolDoDoesNotBlock(t *T) { 251 | size := 10 252 | requestTimeout := 200 * time.Millisecond 253 | redialInterval := 100 * time.Millisecond 254 | 255 | connFunc := PoolConnFunc(func(string, string) (Conn, error) { 256 | return dial(DialTimeout(requestTimeout)), nil 257 | }) 258 | pool := testPool(size, 259 | PoolOnEmptyCreateAfter(redialInterval), 260 | PoolPipelineWindow(0, 0), 261 | connFunc, 262 | ) 263 | 264 | assertPoolConns := func(exp int) { 265 | assert.Equal(t, exp, pool.NumAvailConns()) 266 | } 267 | assertPoolConns(size) 268 | 269 | var wg sync.WaitGroup 270 | var timeExceeded uint32 271 | 272 | // here we try to imitate external requests which come one at a time 273 | // and exceed the number of connections in pool 274 | for i := 0; i < 5*size; i++ { 275 | wg.Add(1) 276 | go func(i int) { 277 | time.Sleep(time.Duration(i*10) * time.Millisecond) 278 | 279 | timeStart := time.Now() 280 | err := pool.Do(WithConn("", func(conn Conn) error { 281 | time.Sleep(requestTimeout) 282 | conn.(*ioErrConn).lastIOErr = errors.New("i/o timeout") 283 | return nil 284 | })) 285 | assert.NoError(t, err) 286 | 287 | if time.Since(timeStart)-requestTimeout-redialInterval > 20*time.Millisecond { 288 | atomic.AddUint32(&timeExceeded, 1) 289 | } 290 | wg.Done() 291 | }(i) 292 | } 293 | 294 | wg.Wait() 295 | assert.True(t, timeExceeded == 0) 296 | } 297 | 298 | func TestPoolClose(t *T) { 299 | pool := testPool(1) 300 | assert.NoError(t, pool.Do(Cmd(nil, "PING"))) 301 | assert.NoError(t, pool.Close()) 302 | assert.Error(t, errClientClosed, pool.Do(Cmd(nil, "PING"))) 303 | } 304 | 305 | func TestIoErrConn(t *T) { 306 | t.Run("NotReusableAfterError", func(t *T) { 307 | dummyError := errors.New("i am error") 308 | 309 | ioc := newIOErrConn(Stub("tcp", "127.0.0.1:6379", nil)) 310 | ioc.lastIOErr = dummyError 311 | 312 | require.Equal(t, dummyError, ioc.Encode(&resp2.Any{})) 313 | require.Equal(t, dummyError, ioc.Decode(&resp2.Any{})) 314 | require.Nil(t, ioc.Close()) 
315 | }) 316 | 317 | t.Run("ReusableAfterRESPError", func(t *T) { 318 | ioc := newIOErrConn(dial()) 319 | defer ioc.Close() 320 | 321 | err1 := ioc.Do(Cmd(nil, "EVAL", "Z", "0")) 322 | require.True(t, errors.As(err1, new(resp.ErrDiscarded))) 323 | require.True(t, errors.As(err1, new(resp2.Error))) 324 | 325 | err2 := ioc.Do(Cmd(nil, "GET", randStr())) 326 | require.Nil(t, err2) 327 | }) 328 | } 329 | -------------------------------------------------------------------------------- /pubsub_test.go: -------------------------------------------------------------------------------- 1 | package radix 2 | 3 | import ( 4 | "log" 5 | "math/rand" 6 | "strconv" 7 | "sync" 8 | . "testing" 9 | "time" 10 | 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | func publish(t *T, c Conn, ch, msg string) { 16 | require.Nil(t, c.Do(Cmd(nil, "PUBLISH", ch, msg))) 17 | } 18 | 19 | func assertMsgRead(_ *T, msgCh <-chan PubSubMessage) PubSubMessage { 20 | select { 21 | case m := <-msgCh: 22 | return m 23 | case <-time.After(5 * time.Second): 24 | panic("timedout reading") 25 | } 26 | } 27 | 28 | func assertMsgNoRead(t *T, msgCh <-chan PubSubMessage) { 29 | select { 30 | case msg, ok := <-msgCh: 31 | if !ok { 32 | assert.Fail(t, "msgCh closed") 33 | } else { 34 | assert.Fail(t, "unexpected PubSubMessage off msgCh", "msg:%#v", msg) 35 | } 36 | default: 37 | } 38 | } 39 | 40 | func testSubscribe(t *T, c PubSubConn, pubCh chan int) { 41 | pubC := dial() 42 | msgCh := make(chan PubSubMessage, 1) 43 | 44 | ch1, ch2, msgStr := randStr(), randStr(), randStr() 45 | require.Nil(t, c.Subscribe(msgCh, ch1, ch2)) 46 | 47 | pubChs := make([]chan int, 3) 48 | { 49 | for i := range pubChs { 50 | pubChs[i] = make(chan int) 51 | } 52 | go func() { 53 | for i := range pubCh { 54 | for _, innerPubCh := range pubChs { 55 | innerPubCh <- i 56 | } 57 | } 58 | for _, innerPubCh := range pubChs { 59 | close(innerPubCh) 60 | } 61 | }() 62 | } 63 | 64 | wg := 
new(sync.WaitGroup) 65 | wg.Add(1) 66 | go func() { 67 | for i := range pubChs[0] { 68 | publish(t, pubC, ch1, msgStr+"_"+strconv.Itoa(i)) 69 | } 70 | wg.Done() 71 | }() 72 | 73 | wg.Add(1) 74 | go func() { 75 | for i := range pubChs[1] { 76 | msg := assertMsgRead(t, msgCh) 77 | assert.Equal(t, PubSubMessage{ 78 | Type: "message", 79 | Channel: ch1, 80 | Message: []byte(msgStr + "_" + strconv.Itoa(i)), 81 | }, msg) 82 | } 83 | wg.Done() 84 | }() 85 | 86 | wg.Add(1) 87 | go func() { 88 | for range pubChs[2] { 89 | require.Nil(t, c.Ping()) 90 | } 91 | wg.Done() 92 | }() 93 | wg.Wait() 94 | 95 | require.Nil(t, c.Unsubscribe(msgCh, ch1)) 96 | publish(t, pubC, ch1, msgStr) 97 | publish(t, pubC, ch2, msgStr) 98 | msg := assertMsgRead(t, msgCh) 99 | assert.Equal(t, PubSubMessage{ 100 | Type: "message", 101 | Channel: ch2, 102 | Message: []byte(msgStr), 103 | }, msg) 104 | 105 | } 106 | 107 | func TestPubSubSubscribe(t *T) { 108 | pubCh := make(chan int) 109 | go func() { 110 | for i := 0; i < 1000; i++ { 111 | pubCh <- i 112 | } 113 | close(pubCh) 114 | }() 115 | c := PubSub(dial()) 116 | testSubscribe(t, c, pubCh) 117 | 118 | c.Close() 119 | assert.NotNil(t, c.Ping()) 120 | assert.NotNil(t, c.Ping()) 121 | assert.NotNil(t, c.Ping()) 122 | } 123 | 124 | func TestPubSubPSubscribe(t *T) { 125 | pubC := dial() 126 | c := PubSub(dial()) 127 | msgCh := make(chan PubSubMessage, 1) 128 | 129 | p1, p2, msgStr := randStr()+"_*", randStr()+"_*", randStr() 130 | ch1, ch2 := p1+"_"+randStr(), p2+"_"+randStr() 131 | p3, p4 := randStr()+"_?", randStr()+"_[ae]" 132 | ch3, ch4 := p3[:len(p3)-len("?")]+"a", p4[:len(p4)-len("[ae]")]+"a" 133 | require.Nil(t, c.PSubscribe(msgCh, p1, p2, p3, p4)) 134 | 135 | count := 1000 136 | wg := new(sync.WaitGroup) 137 | 138 | wg.Add(1) 139 | go func() { 140 | for i := 0; i < count; i++ { 141 | publish(t, pubC, ch1, msgStr) 142 | } 143 | wg.Done() 144 | }() 145 | 146 | wg.Add(1) 147 | go func() { 148 | for i := 0; i < count; i++ { 149 | msg := 
assertMsgRead(t, msgCh) 150 | assert.Equal(t, PubSubMessage{ 151 | Type: "pmessage", 152 | Pattern: p1, 153 | Channel: ch1, 154 | Message: []byte(msgStr), 155 | }, msg) 156 | } 157 | wg.Done() 158 | }() 159 | 160 | wg.Add(1) 161 | go func() { 162 | for i := 0; i < count; i++ { 163 | require.Nil(t, c.Ping()) 164 | } 165 | wg.Done() 166 | }() 167 | 168 | wg.Wait() 169 | 170 | require.Nil(t, c.PUnsubscribe(msgCh, p1)) 171 | publish(t, pubC, ch1, msgStr) 172 | publish(t, pubC, ch2, msgStr) 173 | msg := assertMsgRead(t, msgCh) 174 | assert.Equal(t, PubSubMessage{ 175 | Type: "pmessage", 176 | Pattern: p2, 177 | Channel: ch2, 178 | Message: []byte(msgStr), 179 | }, msg) 180 | 181 | publish(t, pubC, ch3, msgStr) 182 | msg = assertMsgRead(t, msgCh) 183 | assert.Equal(t, PubSubMessage{ 184 | Type: "pmessage", 185 | Pattern: p3, 186 | Channel: ch3, 187 | Message: []byte(msgStr), 188 | }, msg) 189 | 190 | publish(t, pubC, ch4, msgStr) 191 | msg = assertMsgRead(t, msgCh) 192 | assert.Equal(t, PubSubMessage{ 193 | Type: "pmessage", 194 | Pattern: p4, 195 | Channel: ch4, 196 | Message: []byte(msgStr), 197 | }, msg) 198 | 199 | c.Close() 200 | assert.NotNil(t, c.Ping()) 201 | assert.NotNil(t, c.Ping()) 202 | assert.NotNil(t, c.Ping()) 203 | publish(t, pubC, ch2, msgStr) 204 | time.Sleep(250 * time.Millisecond) 205 | assertMsgNoRead(t, msgCh) 206 | } 207 | 208 | func TestPubSubMixedSubscribe(t *T) { 209 | pubC := dial() 210 | defer pubC.Close() 211 | 212 | c := PubSub(dial()) 213 | defer c.Close() 214 | 215 | msgCh := make(chan PubSubMessage, 2) 216 | 217 | const msgStr = "bar" 218 | 219 | require.Nil(t, c.Subscribe(msgCh, "foo")) 220 | require.Nil(t, c.PSubscribe(msgCh, "f[aeiou]o")) 221 | 222 | publish(t, pubC, "foo", msgStr) 223 | 224 | msg1, msg2 := assertMsgRead(t, msgCh), assertMsgRead(t, msgCh) 225 | 226 | // If we received the pmessage first we must swap msg1 and msg1. 
227 | if msg1.Type == "pmessage" { 228 | msg1, msg2 = msg2, msg1 229 | } 230 | 231 | assert.Equal(t, PubSubMessage{ 232 | Type: "message", 233 | Channel: "foo", 234 | Message: []byte(msgStr), 235 | }, msg1) 236 | 237 | assert.Equal(t, PubSubMessage{ 238 | Type: "pmessage", 239 | Channel: "foo", 240 | Pattern: "f[aeiou]o", 241 | Message: []byte(msgStr), 242 | }, msg2) 243 | } 244 | 245 | // Ensure that PubSubConn properly handles the case where the Conn it's reading 246 | // from returns a timeout error. 247 | func TestPubSubTimeout(t *T) { 248 | c, pubC := PubSub(dial(DialReadTimeout(1*time.Second))), dial() 249 | c.(*pubSubConn).testEventCh = make(chan string, 1) 250 | 251 | ch, msgCh := randStr(), make(chan PubSubMessage, 1) 252 | require.Nil(t, c.Subscribe(msgCh, ch)) 253 | 254 | msgStr := randStr() 255 | go func() { 256 | time.Sleep(2 * time.Second) 257 | assert.Nil(t, pubC.Do(Cmd(nil, "PUBLISH", ch, msgStr))) 258 | }() 259 | 260 | assert.Equal(t, "timeout", <-c.(*pubSubConn).testEventCh) 261 | msg := assertMsgRead(t, msgCh) 262 | assert.Equal(t, msgStr, string(msg.Message)) 263 | } 264 | 265 | // This attempts to catch weird race conditions which might occur due to 266 | // subscribing/unsubscribing quickly on an active channel. 
267 | func TestPubSubChaotic(t *T) { 268 | c, pubC := PubSub(dial()), dial() 269 | ch, msgStr := randStr(), randStr() 270 | 271 | stopCh := make(chan struct{}) 272 | defer close(stopCh) 273 | go func() { 274 | for { 275 | select { 276 | case <-stopCh: 277 | return 278 | default: 279 | publish(t, pubC, ch, msgStr) 280 | time.Sleep(10 * time.Millisecond) 281 | } 282 | } 283 | }() 284 | 285 | msgCh := make(chan PubSubMessage, 100) 286 | require.Nil(t, c.Subscribe(msgCh, ch)) 287 | 288 | stopAfter := time.After(10 * time.Second) 289 | toggleTimer := time.Tick(250 * time.Millisecond) 290 | subbed := true 291 | for { 292 | waitFor := time.NewTimer(100 * time.Millisecond) 293 | select { 294 | case <-stopAfter: 295 | return 296 | case <-waitFor.C: 297 | if subbed { 298 | t.Fatal("waited too long to receive message") 299 | } 300 | case msg := <-msgCh: 301 | waitFor.Stop() 302 | assert.Equal(t, msgStr, string(msg.Message)) 303 | case <-toggleTimer: 304 | waitFor.Stop() 305 | if subbed { 306 | require.Nil(t, c.Unsubscribe(msgCh, ch)) 307 | } else { 308 | require.Nil(t, c.Subscribe(msgCh, ch)) 309 | } 310 | subbed = !subbed 311 | } 312 | } 313 | } 314 | 315 | func BenchmarkPubSub(b *B) { 316 | c, pubC := PubSub(dial()), dial() 317 | defer c.Close() 318 | defer pubC.Close() 319 | 320 | msg := randStr() 321 | msgCh := make(chan PubSubMessage, 1) 322 | require.Nil(b, c.Subscribe(msgCh, "benchmark")) 323 | 324 | b.ResetTimer() 325 | 326 | for i := 0; i < b.N; i++ { 327 | if err := pubC.Do(Cmd(nil, "PUBLISH", "benchmark", msg)); err != nil { 328 | b.Fatal(err) 329 | } 330 | <-msgCh 331 | } 332 | } 333 | 334 | func ExamplePubSub() { 335 | // Create a normal redis connection 336 | conn, err := Dial("tcp", "127.0.0.1:6379") 337 | if err != nil { 338 | panic(err) 339 | } 340 | 341 | // Pass that connection into PubSub, conn should never get used after this 342 | ps := PubSub(conn) 343 | defer ps.Close() // this will close Conn as well 344 | 345 | // Subscribe to a channel called 
"myChannel". All publishes to "myChannel" 346 | // will get sent to msgCh after this 347 | msgCh := make(chan PubSubMessage) 348 | if err := ps.Subscribe(msgCh, "myChannel"); err != nil { 349 | panic(err) 350 | } 351 | 352 | // It's optional, but generally advisable, to periodically Ping the 353 | // connection to ensure it's still alive. This should be done in a separate 354 | // go-routine from that which is reading from msgCh. 355 | errCh := make(chan error, 1) 356 | go func() { 357 | ticker := time.NewTicker(5 * time.Second) 358 | defer ticker.Stop() 359 | for range ticker.C { 360 | if err := ps.Ping(); err != nil { 361 | errCh <- err 362 | return 363 | } 364 | } 365 | }() 366 | 367 | for { 368 | select { 369 | case msg := <-msgCh: 370 | log.Printf("publish to channel %q received: %q", msg.Channel, msg.Message) 371 | case err := <-errCh: 372 | panic(err) 373 | } 374 | } 375 | } 376 | 377 | func ExamplePersistentPubSub_cluster() { 378 | // Example of how to use PersistentPubSub with a Cluster instance. 379 | 380 | // Initialize the cluster in any way you see fit 381 | cluster, err := NewCluster([]string{"127.0.0.1:6379"}) 382 | if err != nil { 383 | // handle error 384 | } 385 | 386 | // Have PersistentPubSub pick a random cluster node everytime it wants to 387 | // make a new connection. If the node fails PersistentPubSub will 388 | // automatically pick a new node to connect to. 389 | ps := PersistentPubSub("", "", func(string, string) (Conn, error) { 390 | topo := cluster.Topo() 391 | node := topo[rand.Intn(len(topo))] 392 | return Dial("tcp", node.Addr) 393 | }) 394 | 395 | // Use the PubSubConn as normal. 
396 | msgCh := make(chan PubSubMessage) 397 | 398 | if err = ps.Subscribe(msgCh, "myChannel"); err != nil { 399 | // handle error 400 | } 401 | 402 | for msg := range msgCh { 403 | log.Printf("publish to channel %q received: %q", msg.Channel, msg.Message) 404 | } 405 | } 406 | -------------------------------------------------------------------------------- /cluster_test.go: -------------------------------------------------------------------------------- 1 | package radix 2 | 3 | import ( 4 | . "testing" 5 | "time" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | 10 | "github.com/mediocregopher/radix/v3/trace" 11 | ) 12 | 13 | // clusterSlotKeys contains a random key for every slot. Unfortunately I haven't 14 | // come up with a better way to do this than brute force. It takes less than a 15 | // second on my laptop, so whatevs. 16 | var clusterSlotKeys = func() [numSlots]string { 17 | var a [numSlots]string 18 | var found int 19 | for found < len(a) { 20 | // we get a set of random characters and try increasingly larger subsets 21 | // of that set until one is in a slot which hasn't been set yet. This is 22 | // optimal because it minimizes the number of reads from random needed 23 | // to fill a slot, and the keys being filled are of minimal size. 24 | k := []byte(randStr()) 25 | for i := 1; i <= len(k); i++ { 26 | ksmall := k[:i] 27 | if a[ClusterSlot(ksmall)] == "" { 28 | a[ClusterSlot(ksmall)] = string(ksmall) 29 | found++ 30 | break 31 | } 32 | } 33 | } 34 | return a 35 | }() 36 | 37 | func newTestCluster(opts ...ClusterOpt) (*Cluster, *clusterStub) { 38 | scl := newStubCluster(testTopo) 39 | return scl.newCluster(opts...), scl 40 | } 41 | 42 | // sanity check that Cluster is a client. 
func TestClusterClient(t *T) {
	c, _ := newTestCluster()
	defer c.Close()
	assert.Implements(t, new(Client), c)
}

// makeFailedFlagMap returns a set-style map flagging every given address as
// failed, for use with newInitSyncErrorCluster.
func makeFailedFlagMap(addrs []string) map[string]bool {
	failedAddrsFlag := make(map[string]bool)
	for _, addr := range addrs {
		failedAddrsFlag[addr] = true
	}
	return failedAddrsFlag
}

func TestClusterInitSync(t *T) {
	scl := newStubCluster(testTopo)
	serverAddrs := scl.addrs()

	// part of the addresses are unavailable during the initialization
	// and recover after that, call Sync to test whether it can work
	{
		c, err := scl.newInitSyncErrorCluster(serverAddrs,
			makeFailedFlagMap(serverAddrs[0:len(serverAddrs)/2]),
			ClusterOnInitAllowUnavailable(true))
		require.Nil(t, err)
		defer c.Close()
		require.Nil(t, c.Sync())
	}

	// part of the addresses are unavailable during the initialization
	// and recover after that, try Set and Get cmd to test whether it can work
	{
		c, err := scl.newInitSyncErrorCluster(serverAddrs,
			makeFailedFlagMap(serverAddrs[len(serverAddrs)/2:]),
			ClusterOnInitAllowUnavailable(true))
		require.Nil(t, err)
		defer c.Close()
		// find the address's slot; pick a stub (from the initially-failed
		// half) which actually serves some slot range
		var targetStub *clusterNodeStub
		for i := len(serverAddrs) / 2; i < len(serverAddrs); i++ {
			targetStub = scl.stubs[serverAddrs[i]]
			if slotRanges := targetStub.slotRanges(); len(slotRanges) != 0 {
				break
			}
		}
		require.NotNil(t, targetStub)
		// no pool should exist yet for the node that was down during init
		client, _ := c.rpool(targetStub.addr)
		require.Nil(t, client)

		slotRanges := targetStub.slotRanges()[0]
		targetSlotNum := (slotRanges[1] + slotRanges[0]) / 2
		t.Logf("the target addr for set and get is %s, slotnum= %d", targetStub.addr, targetSlotNum)
		k, v := clusterSlotKeys[targetSlotNum], randStr()
		t.Logf("call set, key=%s, v=%s", k, v)
		require.Nil(t, c.Do(Cmd(nil, "SET", k, v)))
		var vgot string
		require.Nil(t, c.Do(Cmd(&vgot, "GET", k)))
		assert.Equal(t, v, vgot)
	}
	// all addresses are unavailable and the call of NewCluster will get an error
	{
		_, err := scl.newInitSyncErrorCluster(serverAddrs,
			makeFailedFlagMap(serverAddrs),
			ClusterOnInitAllowUnavailable(true))
		assert.NotNil(t, err)
	}
}

func TestClusterSync(t *T) {
	c, scl := newTestCluster()
	defer c.Close()
	// assertClusterState verifies the client's view (topo + one pool per node)
	// matches the stub cluster after a Sync
	assertClusterState := func() {
		require.Nil(t, c.Sync())
		c.l.RLock()
		defer c.l.RUnlock()
		assert.Equal(t, c.topo, scl.topo())
		assert.Len(t, c.pools, len(c.topo))
		for _, node := range c.topo {
			assert.Contains(t, c.pools, node.Addr)
		}
	}
	assertClusterState()

	// the cluster is unstable: repeatedly migrate slot ranges and re-check
	for i := 0; i < 10; i++ {
		// find a usable src/dst
		var srcStub, dstStub *clusterNodeStub
		for {
			srcStub = scl.randStub()
			dstStub = scl.randStub()
			if srcStub.addr == dstStub.addr {
				continue
			} else if slotRanges := srcStub.slotRanges(); len(slotRanges) == 0 {
				continue
			}
			break
		}

		// move src's first slot range to dst
		slotRange := srcStub.slotRanges()[0]
		t.Logf("moving %d:%d from %s to %s", slotRange[0], slotRange[1], srcStub.addr, dstStub.addr)
		scl.migrateSlotRange(dstStub.addr, slotRange[0], slotRange[1])
		assertClusterState()
	}
}

func TestClusterGet(t *T) {
	c, _ := newTestCluster()
	defer c.Close()
	for s := uint16(0); s < numSlots; s++ {
		require.Nil(t, c.Do(Cmd(nil, "GET", clusterSlotKeys[s])))
	}
}

func TestClusterDo(t *T) {
	var lastRedirect trace.ClusterRedirected
	c, scl := newTestCluster(ClusterWithTrace(trace.ClusterTrace{
		Redirected: func(r trace.ClusterRedirected) { lastRedirect = r },
	}))
	defer c.Close()
	stub0 := scl.stubForSlot(0)
	stub16k := scl.stubForSlot(16000)

	// sanity check before we start, these shouldn't have the same address
	require.NotEqual(t, stub0.addr, stub16k.addr)

	// basic Cmd
	k, v := clusterSlotKeys[0], randStr()
	require.Nil(t, c.Do(Cmd(nil, "SET", k, v)))
	{
		var vgot string
		require.Nil(t, c.Do(Cmd(&vgot, "GET", k)))
		assert.Equal(t, v, vgot)
		assert.Equal(t, trace.ClusterRedirected{}, lastRedirect)
	}

	// use doInner to hit the wrong node originally, Do should get a MOVED error
	// and end up at the correct node
	{
		var vgot string
		cmd := Cmd(&vgot, "GET", k)
		require.Nil(t, c.doInner(cmd, stub16k.addr, k, false, doAttempts))
		assert.Equal(t, v, vgot)
		assert.Equal(t, trace.ClusterRedirected{
			Addr:          stub16k.addr,
			Key:           k,
			Moved:         true,
			RedirectCount: 1,
		}, lastRedirect)
	}

	// start a migration and migrate the key, which should trigger an ASK when
	// we hit stub0 for the key
	{
		scl.migrateInit(stub16k.addr, 0)
		scl.migrateKey(k)
		var vgot string
		require.Nil(t, c.Do(Cmd(&vgot, "GET", k)))
		assert.Equal(t, v, vgot)
		assert.Equal(t, trace.ClusterRedirected{
			Addr:          stub0.addr,
			Key:           k,
			Ask:           true,
			RedirectCount: 1,
		}, lastRedirect)
	}

	// Finish the migration, there should not be anymore redirects
	{
		scl.migrateAllKeys(0)
		scl.migrateDone(0)
		lastRedirect = trace.ClusterRedirected{}
		var vgot string
		require.Nil(t, c.Sync())
		require.Nil(t, c.Do(Cmd(&vgot, "GET", k)))
		assert.Equal(t, v, vgot)
		assert.Equal(t, trace.ClusterRedirected{}, lastRedirect)
	}
}

func TestClusterDoWhenDown(t *T) {
	var stub *clusterNodeStub

	var isDown bool

	c, scl := newTestCluster(
		ClusterOnDownDelayActionsBy(50*time.Millisecond),
		ClusterWithTrace(trace.ClusterTrace{
			StateChange: func(d trace.ClusterStateChange) {
				isDown = d.IsDown

				// once the cluster is noticed as down, heal it shortly after
				// (75ms > the 50ms action delay) by re-adding the slot
				if d.IsDown {
					time.AfterFunc(75*time.Millisecond, func() {
						stub.addSlot(0)
					})
				}
			},
		}),
	)
	defer c.Close()

	stub = scl.stubForSlot(0)
	stub.removeSlot(0)

	k := clusterSlotKeys[0]

	err := c.Do(Cmd(nil, "GET", k))
	assert.EqualError(t, err, "CLUSTERDOWN Hash slot not served")
	assert.True(t, isDown)

	// second attempt happens after the heal, so the delayed action succeeds
	err = c.Do(Cmd(nil, "GET", k))
	assert.Nil(t, err)
	assert.False(t, isDown)
}

func BenchmarkClusterDo(b *B) {
	c, _ := newTestCluster()
	defer c.Close()

	k, v := clusterSlotKeys[0], randStr()
	require.Nil(b, c.Do(Cmd(nil, "SET", k, v)))

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		require.Nil(b, c.Do(Cmd(nil, "GET", k)))
	}
}

func TestClusterEval(t *T) {
	c, scl := newTestCluster()
	defer c.Close()
	key := clusterSlotKeys[0]
	dst := scl.stubForSlot(10000)
	scl.migrateInit(dst.addr, 0)
	// now, when interacting with key, the stub should return an ASK error

	eval := NewEvalScript(1, `return nil`)
	var rcv string
	err := c.Do(eval.Cmd(&rcv, key, "foo"))

	assert.Nil(t, err)
	assert.Equal(t, "EVAL: success!", rcv)
}

func TestClusterEvalRcvInterface(t *T) {
	c, scl := newTestCluster()
	defer c.Close()
	key := clusterSlotKeys[0]
	dst := scl.stubForSlot(10000)
	scl.migrateInit(dst.addr, 0)
	// now, when interacting with key, the stub should return an ASK error

	eval := NewEvalScript(1, `return nil`)
	var rcv interface{}
	err := c.Do(eval.Cmd(&rcv, key, "foo"))

	assert.Nil(t, err)
	assert.Equal(t, []byte("EVAL: success!"), rcv)
}

func TestClusterDoSecondary(t *T) {
	var redirects int
	c, _ := newTestCluster(
		ClusterWithTrace(trace.ClusterTrace{
			Redirected: func(trace.ClusterRedirected) {
				redirects++
			},
		}),
	)
	defer c.Close()

	key := clusterSlotKeys[0]
	value := randStr()

	require.NoError(t, c.Do(Cmd(nil, "SET", key, value)))

	var res1 string
	require.NoError(t, c.Do(Cmd(&res1, "GET", key)))
	require.Equal(t, value, res1)

	require.Zero(t, redirects)

	// without READONLY on the secondary, DoSecondary gets redirected once
	var res2 string
	assert.NoError(t, c.DoSecondary(Cmd(&res2, "GET", key)))
	assert.Equal(t, value, res2)
	assert.Equal(t, 1, redirects)

	// grab any one secondary of the key's primary
	var secAddr string
	for secAddr = range c.secondaries[c.addrForKey(key)] {
		break
	}
	sec, err := c.Client(secAddr)
	require.NoError(t, err)

	// with READONLY set the secondary serves the read directly: no redirect
	assert.NoError(t, sec.Do(Cmd(nil, "READONLY")))
	assert.Equal(t, 1, redirects)

	var res3 string
	assert.NoError(t, c.DoSecondary(Cmd(&res3, "GET", key)))
	assert.Equal(t, value, res3)
	assert.Equal(t, 1, redirects)

	// READWRITE restores redirect behavior for secondary reads
	assert.NoError(t, sec.Do(Cmd(nil, "READWRITE")))
	assert.Equal(t, 1, redirects)

	var res4 string
	assert.NoError(t, c.DoSecondary(Cmd(&res4, "GET", key)))
	assert.Equal(t, value, res4)
	assert.Equal(t, 2, redirects)
}

var clusterAddrs []string

func ExampleClusterPoolFunc_defaultClusterConnFunc() {

	cluster, err := NewCluster(clusterAddrs, ClusterPoolFunc(func(network, addr string) (Client, error) {
		return NewPool(network, addr, 4, PoolConnFunc(DefaultClusterConnFunc))
	}))

	if err != nil {
		// handle err
	}

	_ = cluster // use the cluster
}
--------------------------------------------------------------------------------
/conn.go:
--------------------------------------------------------------------------------
package radix

import (
	"bufio"
"crypto/tls" 6 | "net" 7 | "net/url" 8 | "strconv" 9 | "strings" 10 | "time" 11 | 12 | "github.com/mediocregopher/radix/v3/resp" 13 | ) 14 | 15 | // Conn is a Client wrapping a single network connection which synchronously 16 | // reads/writes data using the redis resp protocol. 17 | // 18 | // A Conn can be used directly as a Client, but in general you probably want to 19 | // use a *Pool instead. 20 | type Conn interface { 21 | // The Do method of a Conn is _not_ expected to be thread-safe with the 22 | // other methods of Conn, and merely calls the Action's Run method with 23 | // itself as the argument. 24 | Client 25 | 26 | // Encode and Decode may be called at the same time by two different 27 | // go-routines, but each should only be called once at a time (i.e. two 28 | // routines shouldn't call Encode at the same time, same with Decode). 29 | // 30 | // Encode and Decode should _not_ be called at the same time as Do. 31 | // 32 | // If either Encode or Decode encounter a net.Error the Conn will be 33 | // automatically closed. 34 | // 35 | // Encode is expected to encode an entire resp message, not a partial one. 36 | // In other words, when sending commands to redis, Encode should only be 37 | // called once per command. Similarly, Decode is expected to decode an 38 | // entire resp response. 39 | Encode(resp.Marshaler) error 40 | Decode(resp.Unmarshaler) error 41 | 42 | // Returns the underlying network connection, as-is. Read, Write, and Close 43 | // should not be called on the returned Conn. 44 | NetConn() net.Conn 45 | } 46 | 47 | // ConnFunc is a function which returns an initialized, ready-to-be-used Conn. 48 | // Functions like NewPool or NewCluster take in a ConnFunc in order to allow for 49 | // things like calls to AUTH on each new connection, setting timeouts, custom 50 | // Conn implementations, etc... See the package docs for more details. 
51 | type ConnFunc func(network, addr string) (Conn, error) 52 | 53 | // DefaultConnFunc is a ConnFunc which will return a Conn for a redis instance 54 | // using sane defaults. 55 | var DefaultConnFunc = func(network, addr string) (Conn, error) { 56 | return Dial(network, addr) 57 | } 58 | 59 | func wrapDefaultConnFunc(addr string) ConnFunc { 60 | _, opts := parseRedisURL(addr) 61 | return func(network, addr string) (Conn, error) { 62 | return Dial(network, addr, opts...) 63 | } 64 | } 65 | 66 | type connWrap struct { 67 | net.Conn 68 | brw *bufio.ReadWriter 69 | } 70 | 71 | // NewConn takes an existing net.Conn and wraps it to support the Conn interface 72 | // of this package. The Read and Write methods on the original net.Conn should 73 | // not be used after calling this method. 74 | func NewConn(conn net.Conn) Conn { 75 | return &connWrap{ 76 | Conn: conn, 77 | brw: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)), 78 | } 79 | } 80 | 81 | func (cw *connWrap) Do(a Action) error { 82 | return a.Run(cw) 83 | } 84 | 85 | func (cw *connWrap) Encode(m resp.Marshaler) error { 86 | if err := m.MarshalRESP(cw.brw); err != nil { 87 | return err 88 | } 89 | return cw.brw.Flush() 90 | } 91 | 92 | func (cw *connWrap) Decode(u resp.Unmarshaler) error { 93 | return u.UnmarshalRESP(cw.brw.Reader) 94 | } 95 | 96 | func (cw *connWrap) NetConn() net.Conn { 97 | return cw.Conn 98 | } 99 | 100 | type dialOpts struct { 101 | connectTimeout, readTimeout, writeTimeout time.Duration 102 | authUser, authPass string 103 | selectDB string 104 | useTLSConfig bool 105 | tlsConfig *tls.Config 106 | } 107 | 108 | // DialOpt is an optional behavior which can be applied to the Dial function to 109 | // effect its behavior, or the behavior of the Conn it creates. 110 | type DialOpt func(*dialOpts) 111 | 112 | // DialConnectTimeout determines the timeout value to pass into net.DialTimeout 113 | // when creating the connection. If not set then net.Dial is called instead. 
114 | func DialConnectTimeout(d time.Duration) DialOpt { 115 | return func(do *dialOpts) { 116 | do.connectTimeout = d 117 | } 118 | } 119 | 120 | // DialReadTimeout determines the deadline to set when reading from a dialed 121 | // connection. If not set then SetReadDeadline is never called. 122 | func DialReadTimeout(d time.Duration) DialOpt { 123 | return func(do *dialOpts) { 124 | do.readTimeout = d 125 | } 126 | } 127 | 128 | // DialWriteTimeout determines the deadline to set when writing to a dialed 129 | // connection. If not set then SetWriteDeadline is never called. 130 | func DialWriteTimeout(d time.Duration) DialOpt { 131 | return func(do *dialOpts) { 132 | do.writeTimeout = d 133 | } 134 | } 135 | 136 | // DialTimeout is the equivalent to using DialConnectTimeout, DialReadTimeout, 137 | // and DialWriteTimeout all with the same value. 138 | func DialTimeout(d time.Duration) DialOpt { 139 | return func(do *dialOpts) { 140 | DialConnectTimeout(d)(do) 141 | DialReadTimeout(d)(do) 142 | DialWriteTimeout(d)(do) 143 | } 144 | } 145 | 146 | const defaultAuthUser = "default" 147 | 148 | // DialAuthPass will cause Dial to perform an AUTH command once the connection 149 | // is created, using the given pass. 150 | // 151 | // If this is set and a redis URI is passed to Dial which also has a password 152 | // set, this takes precedence. 153 | // 154 | // Using DialAuthPass is equivalent to calling DialAuthUser with user "default" 155 | // and is kept for compatibility with older package versions. 156 | func DialAuthPass(pass string) DialOpt { 157 | return DialAuthUser(defaultAuthUser, pass) 158 | } 159 | 160 | // DialAuthUser will cause Dial to perform an AUTH command once the connection 161 | // is created, using the given user and pass. 162 | // 163 | // If this is set and a redis URI is passed to Dial which also has a username 164 | // and password set, this takes precedence. 
165 | func DialAuthUser(user, pass string) DialOpt { 166 | return func(do *dialOpts) { 167 | do.authUser = user 168 | do.authPass = pass 169 | } 170 | } 171 | 172 | // DialSelectDB will cause Dial to perform a SELECT command once the connection 173 | // is created, using the given database index. 174 | // 175 | // If this is set and a redis URI is passed to Dial which also has a database 176 | // index set, this takes precedence. 177 | func DialSelectDB(db int) DialOpt { 178 | return func(do *dialOpts) { 179 | do.selectDB = strconv.Itoa(db) 180 | } 181 | } 182 | 183 | // DialUseTLS will cause Dial to perform a TLS handshake using the provided 184 | // config. If config is nil the config is interpreted as equivalent to the zero 185 | // configuration. See https://golang.org/pkg/crypto/tls/#Config 186 | func DialUseTLS(config *tls.Config) DialOpt { 187 | return func(do *dialOpts) { 188 | do.tlsConfig = config 189 | do.useTLSConfig = true 190 | } 191 | } 192 | 193 | type timeoutConn struct { 194 | net.Conn 195 | readTimeout, writeTimeout time.Duration 196 | } 197 | 198 | func (tc *timeoutConn) Read(b []byte) (int, error) { 199 | if tc.readTimeout > 0 { 200 | err := tc.Conn.SetReadDeadline(time.Now().Add(tc.readTimeout)) 201 | if err != nil { 202 | return 0, err 203 | } 204 | } 205 | return tc.Conn.Read(b) 206 | } 207 | 208 | func (tc *timeoutConn) Write(b []byte) (int, error) { 209 | if tc.writeTimeout > 0 { 210 | err := tc.Conn.SetWriteDeadline(time.Now().Add(tc.writeTimeout)) 211 | if err != nil { 212 | return 0, err 213 | } 214 | } 215 | return tc.Conn.Write(b) 216 | } 217 | 218 | var defaultDialOpts = []DialOpt{ 219 | DialTimeout(10 * time.Second), 220 | } 221 | 222 | func parseRedisURL(urlStr string) (string, []DialOpt) { 223 | // do a quick check before we bust out url.Parse, in case that is very 224 | // unperformant 225 | if !strings.HasPrefix(urlStr, "redis://") { 226 | return urlStr, nil 227 | } 228 | 229 | u, err := url.Parse(urlStr) 230 | if err != nil { 
231 | return urlStr, nil 232 | } 233 | 234 | q := u.Query() 235 | 236 | username := defaultAuthUser 237 | if n := u.User.Username(); n != "" { 238 | username = n 239 | } else if n := q.Get("username"); n != "" { 240 | username = n 241 | } 242 | 243 | password := q.Get("password") 244 | if p, ok := u.User.Password(); ok { 245 | password = p 246 | } 247 | 248 | opts := []DialOpt{ 249 | DialAuthUser(username, password), 250 | } 251 | 252 | dbStr := q.Get("db") 253 | if u.Path != "" && u.Path != "/" { 254 | dbStr = u.Path[1:] 255 | } 256 | 257 | if dbStr, err := strconv.Atoi(dbStr); err == nil { 258 | opts = append(opts, DialSelectDB(dbStr)) 259 | } 260 | 261 | return u.Host, opts 262 | } 263 | 264 | // Dial is a ConnFunc which creates a Conn using net.Dial and NewConn. It takes 265 | // in a number of options which can overwrite its default behavior as well. 266 | // 267 | // In place of a host:port address, Dial also accepts a URI, as per: 268 | // https://www.iana.org/assignments/uri-schemes/prov/redis 269 | // If the URI has an AUTH password or db specified Dial will attempt to perform 270 | // the AUTH and/or SELECT as well. 271 | // 272 | // If either DialAuthPass or DialSelectDB is used it overwrites the associated 273 | // value passed in by the URI. 
//
// The default options Dial uses are:
//
//	DialTimeout(10 * time.Second)
//
func Dial(network, addr string, opts ...DialOpt) (Conn, error) {
	var do dialOpts
	// Options are applied in increasing priority: package defaults first,
	// then options implied by a redis:// URI in addr, then the caller's opts
	// (applied last, so they take precedence).
	for _, opt := range defaultDialOpts {
		opt(&do)
	}
	addr, addrOpts := parseRedisURL(addr)
	for _, opt := range addrOpts {
		opt(&do)
	}
	for _, opt := range opts {
		opt(&do)
	}

	var netConn net.Conn
	var err error
	dialer := net.Dialer{}
	if do.connectTimeout > 0 {
		dialer.Timeout = do.connectTimeout
	}
	// TLS and plain connections share the same dialer so the connect timeout
	// applies to both.
	if do.useTLSConfig {
		netConn, err = tls.DialWithDialer(&dialer, network, addr, do.tlsConfig)
	} else {
		netConn, err = dialer.Dial(network, addr)
	}

	if err != nil {
		return nil, err
	}

	// If the netConn is a net.TCPConn (or some wrapper for it) and so can have
	// keepalive enabled, do so with a sane (though slightly aggressive)
	// default.
	{
		type keepaliveConn interface {
			SetKeepAlive(bool) error
			SetKeepAlivePeriod(time.Duration) error
		}

		if kaConn, ok := netConn.(keepaliveConn); ok {
			if err = kaConn.SetKeepAlive(true); err != nil {
				netConn.Close()
				return nil, err
			} else if err = kaConn.SetKeepAlivePeriod(10 * time.Second); err != nil {
				netConn.Close()
				return nil, err
			}
		}
	}

	// Wrap the raw connection so the configured read/write timeouts (if any)
	// are refreshed on every Read/Write.
	conn := NewConn(&timeoutConn{
		readTimeout:  do.readTimeout,
		writeTimeout: do.writeTimeout,
		Conn:         netConn,
	})

	// AUTH: a non-default user requires the two-argument AUTH form
	// (user + pass); otherwise a bare password uses the legacy
	// single-argument AUTH. Any failure closes the connection.
	if do.authUser != "" && do.authUser != defaultAuthUser {
		if err := conn.Do(Cmd(nil, "AUTH", do.authUser, do.authPass)); err != nil {
			conn.Close()
			return nil, err
		}
	} else if do.authPass != "" {
		if err := conn.Do(Cmd(nil, "AUTH", do.authPass)); err != nil {
			conn.Close()
			return nil, err
		}
	}

	// SELECT the requested database index, if one was configured.
	if do.selectDB != "" {
		if err := conn.Do(Cmd(nil, "SELECT", do.selectDB)); err != nil {
			conn.Close()
			return nil, err
		}
	}

	return conn, nil
}
-------------------------------------------------------------------------------- /stream.go: --------------------------------------------------------------------------------
package radix

import (
	"bufio"
	"bytes"
	"fmt"
	"io"
	"math"
	"strconv"
	"time"

	"errors"

	"github.com/mediocregopher/radix/v3/internal/bytesutil"
	"github.com/mediocregopher/radix/v3/resp"
	"github.com/mediocregopher/radix/v3/resp/resp2"
)

// StreamEntryID represents an ID used in a Redis stream with the format