├── .github └── workflows │ └── go.yml ├── LICENSE ├── README.md ├── example ├── clone.go ├── go.mod ├── mux │ ├── clone.go │ └── handler.go └── tls │ └── clone.go ├── go.mod ├── go.sum ├── logo.png ├── redcon.go ├── redcon_test.go ├── resp.go └── resp_test.go /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | 11 | build: 12 | name: Build 13 | runs-on: ubuntu-latest 14 | steps: 15 | 16 | - name: Set up Go 1.x 17 | uses: actions/setup-go@v2 18 | with: 19 | go-version: ^1.13 20 | 21 | - name: Check out code into the Go module directory 22 | uses: actions/checkout@v2 23 | 24 | - name: Get dependencies 25 | run: | 26 | go get -v -t -d ./... 27 | if [ -f Gopkg.toml ]; then 28 | curl https://raw.githubusercontent.com/golang/dep/master/install.sh | sh 29 | dep ensure 30 | fi 31 | 32 | - name: Build 33 | run: go build -v . 34 | 35 | - name: Test 36 | run: go test -v . 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 Josh Baker 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
8 | 9 |Redis compatible server framework for Go
10 | 11 | Features 12 | -------- 13 | - Create a [Fast](#benchmarks) custom Redis compatible server in Go 14 | - Simple interface. One function `ListenAndServe` and two types `Conn` & `Command` 15 | - Support for pipelining and telnet commands 16 | - Works with Redis clients such as [redigo](https://github.com/garyburd/redigo), [redis-py](https://github.com/andymccurdy/redis-py), [node_redis](https://github.com/NodeRedis/node_redis), and [jedis](https://github.com/xetorthio/jedis) 17 | - [TLS Support](#tls-example) 18 | - Compatible pub/sub support 19 | - Multithreaded 20 | 21 | *This library is also available for [Rust](https://github.com/tidwall/redcon.rs) and [C](https://github.com/tidwall/redcon.c).* 22 | 23 | Installing 24 | ---------- 25 | 26 | ``` 27 | go get -u github.com/tidwall/redcon 28 | ``` 29 | 30 | Example 31 | ------- 32 | 33 | Here's a full example of a Redis clone that accepts: 34 | 35 | - SET key value 36 | - GET key 37 | - SETNX key value 38 | - DEL key 39 | - PING 40 | - QUIT 41 | - PUBLISH channel message 42 | - SUBSCRIBE channel / PSUBSCRIBE pattern 43 | 44 | You can run this example from a terminal: 45 | 46 | ```sh 47 | go run example/clone.go 48 | ``` 49 | 50 | ```go 51 | package main 52 | 53 | import ( 54 | "log" 55 | "strings" 56 | "sync" 57 | 58 | "github.com/tidwall/redcon" 59 | ) 60 | 61 | var addr = ":6380" 62 | 63 | func main() { 64 | var mu sync.RWMutex 65 | var items = make(map[string][]byte) 66 | var ps redcon.PubSub 67 | log.Printf("started server at %s", addr) 68 | err := redcon.ListenAndServe(addr, 69 | func(conn redcon.Conn, cmd redcon.Command) { 70 | switch strings.ToLower(string(cmd.Args[0])) { 71 | default: 72 | conn.WriteError("ERR unknown command '" + string(cmd.Args[0]) + "'") 73 | case "ping": 74 | conn.WriteString("PONG") 75 | case "quit": 76 | conn.WriteString("OK") 77 | conn.Close() 78 | case "set": 79 | if len(cmd.Args) != 3 { 80 | conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") 81 | return
82 | } 83 | mu.Lock() 84 | items[string(cmd.Args[1])] = cmd.Args[2] 85 | mu.Unlock() 86 | conn.WriteString("OK") 87 | case "get": 88 | if len(cmd.Args) != 2 { 89 | conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") 90 | return 91 | } 92 | mu.RLock() 93 | val, ok := items[string(cmd.Args[1])] 94 | mu.RUnlock() 95 | if !ok { 96 | conn.WriteNull() 97 | } else { 98 | conn.WriteBulk(val) 99 | } 100 | case "setnx": 101 | if len(cmd.Args) != 3 { 102 | conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") 103 | return 104 | } 105 | mu.Lock() // check-and-set under a single write lock so concurrent SETNX calls cannot both succeed 106 | _, ok := items[string(cmd.Args[1])] 107 | if !ok { 108 | items[string(cmd.Args[1])] = cmd.Args[2] 109 | } 110 | mu.Unlock() 111 | if ok { 112 | conn.WriteInt(0) 113 | return 114 | } 115 | conn.WriteInt(1) 116 | case "del": 117 | if len(cmd.Args) != 2 { 118 | conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") 119 | return 120 | } 121 | mu.Lock() 122 | _, ok := items[string(cmd.Args[1])] 123 | delete(items, string(cmd.Args[1])) 124 | mu.Unlock() 125 | if !ok { 126 | conn.WriteInt(0) 127 | } else { 128 | conn.WriteInt(1) 129 | } 130 | case "publish": 131 | if len(cmd.Args) != 3 { 132 | conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") 133 | return 134 | } 135 | conn.WriteInt(ps.Publish(string(cmd.Args[1]), string(cmd.Args[2]))) 136 | case "subscribe", "psubscribe": 137 | if len(cmd.Args) < 2 { 138 | conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") 139 | return 140 | } 141 | command := strings.ToLower(string(cmd.Args[0])) 142 | for i := 1; i < len(cmd.Args); i++ { 143 | if command == "psubscribe" { 144 | ps.Psubscribe(conn, string(cmd.Args[i])) 145 | } else { 146 | ps.Subscribe(conn, string(cmd.Args[i])) 147 | } 148 | } 149 | } 150 | }, 151 | func(conn redcon.Conn) bool { 152 | // Use this function to accept or deny the
connection. 153 | // log.Printf("accept: %s", conn.RemoteAddr()) 154 | return true 155 | }, 156 | func(conn redcon.Conn, err error) { 157 | // This is called when the connection has been closed 158 | // log.Printf("closed: %s, err: %v", conn.RemoteAddr(), err) 159 | }, 160 | ) 161 | if err != nil { 162 | log.Fatal(err) 163 | } 164 | } 165 | ``` 166 | 167 | TLS Example 168 | ----------- 169 | 170 | Redcon has full TLS support through the `ListenAndServeTLS` function. 171 | 172 | The [same example](example/tls/clone.go) is also provided for serving Redcon over TLS. 173 | 174 | ```sh 175 | go run example/tls/clone.go 176 | ``` 177 | 178 | Benchmarks 179 | ---------- 180 | 181 | **Redis**: Single-threaded, no disk persistence. 182 | 183 | ``` 184 | $ redis-server --port 6379 --appendonly no 185 | ``` 186 | ``` 187 | redis-benchmark -p 6379 -t set,get -n 10000000 -q -P 512 -c 512 188 | SET: 941265.12 requests per second 189 | GET: 1189909.50 requests per second 190 | ``` 191 | 192 | **Redcon**: Single-threaded, no disk persistence. 193 | 194 | ``` 195 | $ GOMAXPROCS=1 go run example/clone.go 196 | ``` 197 | ``` 198 | redis-benchmark -p 6380 -t set,get -n 10000000 -q -P 512 -c 512 199 | SET: 2018570.88 requests per second 200 | GET: 2403846.25 requests per second 201 | ``` 202 | 203 | **Redcon**: Multi-threaded, no disk persistence. 204 | 205 | ``` 206 | $ GOMAXPROCS=0 go run example/clone.go 207 | ``` 208 | ``` 209 | $ redis-benchmark -p 6380 -t set,get -n 10000000 -q -P 512 -c 512 210 | SET: 1944390.38 requests per second 211 | GET: 3993610.25 requests per second 212 | ``` 213 | 214 | *Running on a MacBook Pro 15" 2.8 GHz Intel Core i7 using Go 1.7* 215 | 216 | Contact 217 | ------- 218 | Josh Baker [@tidwall](http://twitter.com/tidwall) 219 | 220 | License 221 | ------- 222 | Redcon source code is available under the MIT [License](/LICENSE).
223 | -------------------------------------------------------------------------------- /example/clone.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | "strings" 6 | "sync" 7 | 8 | "github.com/tidwall/redcon" 9 | ) 10 | 11 | var addr = ":6380" 12 | 13 | func main() { 14 | var mu sync.RWMutex 15 | var items = make(map[string][]byte) 16 | var ps redcon.PubSub 17 | log.Printf("started server at %s", addr) 18 | 19 | err := redcon.ListenAndServe(addr, 20 | func(conn redcon.Conn, cmd redcon.Command) { 21 | switch strings.ToLower(string(cmd.Args[0])) { 22 | default: 23 | conn.WriteError("ERR unknown command '" + string(cmd.Args[0]) + "'") 24 | case "publish": 25 | // Publish to all pub/sub subscribers and return the number of 26 | // messages that were sent. 27 | if len(cmd.Args) != 3 { 28 | conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") 29 | return 30 | } 31 | count := ps.Publish(string(cmd.Args[1]), string(cmd.Args[2])) 32 | conn.WriteInt(count) 33 | case "subscribe", "psubscribe": 34 | // Subscribe to a pub/sub channel. The `Psubscribe` and 35 | // `Subscribe` operations will detach the connection from the 36 | // event handler and manage all network I/O for this connection 37 | // in the background.
38 | if len(cmd.Args) < 2 { 39 | conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") 40 | return 41 | } 42 | command := strings.ToLower(string(cmd.Args[0])) 43 | for i := 1; i < len(cmd.Args); i++ { 44 | if command == "psubscribe" { 45 | ps.Psubscribe(conn, string(cmd.Args[i])) 46 | } else { 47 | ps.Subscribe(conn, string(cmd.Args[i])) 48 | } 49 | } 50 | case "detach": 51 | hconn := conn.Detach() 52 | log.Printf("connection has been detached") 53 | go func() { 54 | defer hconn.Close() 55 | hconn.WriteString("OK") 56 | hconn.Flush() 57 | }() 58 | case "ping": 59 | conn.WriteString("PONG") 60 | case "quit": 61 | conn.WriteString("OK") 62 | conn.Close() 63 | case "set": 64 | if len(cmd.Args) != 3 { 65 | conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") 66 | return 67 | } 68 | mu.Lock() 69 | items[string(cmd.Args[1])] = cmd.Args[2] 70 | mu.Unlock() 71 | conn.WriteString("OK") 72 | case "get": 73 | if len(cmd.Args) != 2 { 74 | conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") 75 | return 76 | } 77 | mu.RLock() 78 | val, ok := items[string(cmd.Args[1])] 79 | mu.RUnlock() 80 | if !ok { 81 | conn.WriteNull() 82 | } else { 83 | conn.WriteBulk(val) 84 | } 85 | case "setnx": 86 | if len(cmd.Args) != 3 { 87 | conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") 88 | return 89 | } 90 | mu.Lock() // check-and-set under a single write lock so concurrent SETNX calls cannot both succeed 91 | _, ok := items[string(cmd.Args[1])] 92 | if !ok { 93 | items[string(cmd.Args[1])] = cmd.Args[2] 94 | } 95 | mu.Unlock() 96 | if ok { 97 | conn.WriteInt(0) 98 | return 99 | } 100 | conn.WriteInt(1) 101 | case "del": 102 | if len(cmd.Args) != 2 { 103 | conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") 104 | return 105 | } 106 | mu.Lock() 107 | _, ok := items[string(cmd.Args[1])] 108 | delete(items, string(cmd.Args[1])) 109 | mu.Unlock() 110 | if !ok { 111 | conn.WriteInt(0)
112 | } else { 113 | conn.WriteInt(1) 114 | } 115 | case "config": 116 | // This simple (blank) response is only here to allow for the 117 | // redis-benchmark command to work with this example. 118 | // Reject short forms such as a bare CONFIG or "CONFIG GET": 119 | // indexing cmd.Args[2] below would otherwise panic. 120 | if len(cmd.Args) < 3 { 121 | conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") 122 | return 123 | } 124 | conn.WriteArray(2) 125 | conn.WriteBulk(cmd.Args[2]) 126 | conn.WriteBulkString("") 127 | } 128 | }, 129 | func(conn redcon.Conn) bool { 130 | // Use this function to accept or deny the connection. 131 | // log.Printf("accept: %s", conn.RemoteAddr()) 132 | return true 133 | }, 134 | func(conn redcon.Conn, err error) { 135 | // This is called when the connection has been closed 136 | // log.Printf("closed: %s, err: %v", conn.RemoteAddr(), err) 137 | }, 138 | ) 139 | if err != nil { 140 | log.Fatal(err) 141 | } 142 | } 143 | -------------------------------------------------------------------------------- /example/go.mod: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidwall/redcon/9a9922f37eb1ea5420a7f0f6cb8acdf7a6e7142d/example/go.mod -------------------------------------------------------------------------------- /example/mux/clone.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | 6 | "github.com/tidwall/redcon" 7 | ) 8 | 9 | var addr = ":6380" 10 | 11 | func main() { 12 | log.Printf("started server at %s", addr) 13 | 14 | handler := NewHandler() 15 | 16 | mux := redcon.NewServeMux() 17 | mux.HandleFunc("detach", handler.detach) 18 | mux.HandleFunc("ping", handler.ping) 19 | mux.HandleFunc("quit", handler.quit) 20 | mux.HandleFunc("set", handler.set) 21 | mux.HandleFunc("get", handler.get) 22 | mux.HandleFunc("setnx", handler.setnx) 23 | mux.HandleFunc("del", handler.delete) 24 | 25 | err := redcon.ListenAndServe(addr, 26 | mux.ServeRESP, 27 | func(conn redcon.Conn) bool { 28 | // use this function to accept or deny the connection.
29 | // log.Printf("accept: %s", conn.RemoteAddr()) 30 | return true 31 | }, 32 | func(conn redcon.Conn, err error) { 33 | // this is called when the connection has been closed 34 | // log.Printf("closed: %s, err: %v", conn.RemoteAddr(), err) 35 | }, 36 | ) 37 | if err != nil { 38 | log.Fatal(err) 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /example/mux/handler.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | "sync" 6 | 7 | "github.com/tidwall/redcon" 8 | ) 9 | 10 | type Handler struct { 11 | itemsMux sync.RWMutex 12 | items map[string][]byte 13 | } 14 | 15 | func NewHandler() *Handler { 16 | return &Handler{ 17 | items: make(map[string][]byte), 18 | } 19 | } 20 | 21 | func (h *Handler) detach(conn redcon.Conn, cmd redcon.Command) { 22 | detachedConn := conn.Detach() 23 | log.Printf("connection has been detached") 24 | go func(c redcon.DetachedConn) { 25 | defer c.Close() 26 | 27 | c.WriteString("OK") 28 | c.Flush() 29 | }(detachedConn) 30 | } 31 | 32 | func (h *Handler) ping(conn redcon.Conn, cmd redcon.Command) { 33 | conn.WriteString("PONG") 34 | } 35 | 36 | func (h *Handler) quit(conn redcon.Conn, cmd redcon.Command) { 37 | conn.WriteString("OK") 38 | conn.Close() 39 | } 40 | 41 | func (h *Handler) set(conn redcon.Conn, cmd redcon.Command) { 42 | if len(cmd.Args) != 3 { 43 | conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") 44 | return 45 | } 46 | 47 | h.itemsMux.Lock() 48 | h.items[string(cmd.Args[1])] = cmd.Args[2] 49 | h.itemsMux.Unlock() 50 | 51 | conn.WriteString("OK") 52 | } 53 | 54 | func (h *Handler) get(conn redcon.Conn, cmd redcon.Command) { 55 | if len(cmd.Args) != 2 { 56 | conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") 57 | return 58 | } 59 | 60 | h.itemsMux.RLock() 61 | val, ok := h.items[string(cmd.Args[1])] 62 | 
h.itemsMux.RUnlock() 63 | 64 | if !ok { 65 | conn.WriteNull() 66 | } else { 67 | conn.WriteBulk(val) 68 | } 69 | } 70 | 71 | func (h *Handler) setnx(conn redcon.Conn, cmd redcon.Command) { 72 | if len(cmd.Args) != 3 { 73 | conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") 74 | return 75 | } 76 | 77 | h.itemsMux.RLock() 78 | _, ok := h.items[string(cmd.Args[1])] 79 | h.itemsMux.RUnlock() 80 | 81 | if ok { 82 | conn.WriteInt(0) 83 | return 84 | } 85 | 86 | h.itemsMux.Lock() 87 | h.items[string(cmd.Args[1])] = cmd.Args[2] 88 | h.itemsMux.Unlock() 89 | 90 | conn.WriteInt(1) 91 | } 92 | 93 | func (h *Handler) delete(conn redcon.Conn, cmd redcon.Command) { 94 | if len(cmd.Args) != 2 { 95 | conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") 96 | return 97 | } 98 | 99 | h.itemsMux.Lock() 100 | _, ok := h.items[string(cmd.Args[1])] 101 | delete(h.items, string(cmd.Args[1])) 102 | h.itemsMux.Unlock() 103 | 104 | if !ok { 105 | conn.WriteInt(0) 106 | } else { 107 | conn.WriteInt(1) 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /example/tls/clone.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "crypto/tls" 5 | "log" 6 | "strings" 7 | "sync" 8 | 9 | "github.com/tidwall/redcon" 10 | ) 11 | 12 | const serverKey = `-----BEGIN EC PARAMETERS----- 13 | BggqhkjOPQMBBw== 14 | -----END EC PARAMETERS----- 15 | -----BEGIN EC PRIVATE KEY----- 16 | MHcCAQEEIHg+g2unjA5BkDtXSN9ShN7kbPlbCcqcYdDu+QeV8XWuoAoGCCqGSM49 17 | AwEHoUQDQgAEcZpodWh3SEs5Hh3rrEiu1LZOYSaNIWO34MgRxvqwz1FMpLxNlx0G 18 | cSqrxhPubawptX5MSr02ft32kfOlYbaF5Q== 19 | -----END EC PRIVATE KEY----- 20 | ` 21 | 22 | const serverCert = `-----BEGIN CERTIFICATE----- 23 | MIIB+TCCAZ+gAwIBAgIJAL05LKXo6PrrMAoGCCqGSM49BAMCMFkxCzAJBgNVBAYT 24 | AkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRn 25 | 
aXRzIFB0eSBMdGQxEjAQBgNVBAMMCWxvY2FsaG9zdDAeFw0xNTEyMDgxNDAxMTNa 26 | Fw0yNTEyMDUxNDAxMTNaMFkxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0 27 | YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQxEjAQBgNVBAMM 28 | CWxvY2FsaG9zdDBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABHGaaHVod0hLOR4d 29 | 66xIrtS2TmEmjSFjt+DIEcb6sM9RTKS8TZcdBnEqq8YT7m2sKbV+TEq9Nn7d9pHz 30 | pWG2heWjUDBOMB0GA1UdDgQWBBR0fqrecDJ44D/fiYJiOeBzfoqEijAfBgNVHSME 31 | GDAWgBR0fqrecDJ44D/fiYJiOeBzfoqEijAMBgNVHRMEBTADAQH/MAoGCCqGSM49 32 | BAMCA0gAMEUCIEKzVMF3JqjQjuM2rX7Rx8hancI5KJhwfeKu1xbyR7XaAiEA2UT7 33 | 1xOP035EcraRmWPe7tO0LpXgMxlh2VItpc2uc2w= 34 | -----END CERTIFICATE----- 35 | ` 36 | 37 | var addr = ":6380" 38 | 39 | func main() { 40 | cer, err := tls.X509KeyPair([]byte(serverCert), []byte(serverKey)) 41 | if err != nil { 42 | log.Fatal(err) 43 | } 44 | config := &tls.Config{Certificates: []tls.Certificate{cer}} 45 | 46 | var mu sync.RWMutex 47 | var items = make(map[string][]byte) 48 | 49 | log.Printf("started server at %s", addr) 50 | err = redcon.ListenAndServeTLS(addr, 51 | func(conn redcon.Conn, cmd redcon.Command) { 52 | switch strings.ToLower(string(cmd.Args[0])) { 53 | default: 54 | conn.WriteError("ERR unknown command '" + string(cmd.Args[0]) + "'") 55 | case "detach": 56 | hconn := conn.Detach() 57 | log.Printf("connection has been detached") 58 | go func() { 59 | defer hconn.Close() 60 | hconn.WriteString("OK") 61 | hconn.Flush() 62 | }() 63 | return 64 | case "ping": 65 | conn.WriteString("PONG") 66 | case "quit": 67 | conn.WriteString("OK") 68 | conn.Close() 69 | case "set": 70 | if len(cmd.Args) != 3 { 71 | conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") 72 | return 73 | } 74 | mu.Lock() 75 | items[string(cmd.Args[1])] = cmd.Args[2] 76 | mu.Unlock() 77 | conn.WriteString("OK") 78 | case "get": 79 | if len(cmd.Args) != 2 { 80 | conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") 81 | return 82 | } 83 | mu.RLock() 84
| val, ok := items[string(cmd.Args[1])] 85 | mu.RUnlock() 86 | if !ok { 87 | conn.WriteNull() 88 | } else { 89 | conn.WriteBulk(val) 90 | } 91 | case "setnx": 92 | if len(cmd.Args) != 3 { 93 | conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") 94 | return 95 | } 96 | mu.Lock() // check-and-set under a single write lock so concurrent SETNX calls cannot both succeed 97 | _, ok := items[string(cmd.Args[1])] 98 | if !ok { 99 | items[string(cmd.Args[1])] = cmd.Args[2] 100 | } 101 | mu.Unlock() 102 | if ok { 103 | conn.WriteInt(0) 104 | return 105 | } 106 | conn.WriteInt(1) 107 | case "del": 108 | if len(cmd.Args) != 2 { 109 | conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") 110 | return 111 | } 112 | mu.Lock() 113 | _, ok := items[string(cmd.Args[1])] 114 | delete(items, string(cmd.Args[1])) 115 | mu.Unlock() 116 | if !ok { 117 | conn.WriteInt(0) 118 | } else { 119 | conn.WriteInt(1) 120 | } 121 | } 122 | }, 123 | func(conn redcon.Conn) bool { 124 | return true 125 | }, 126 | func(conn redcon.Conn, err error) { 127 | }, 128 | config, 129 | ) 130 | 131 | if err != nil { 132 | log.Fatal(err) 133 | } 134 | } 135 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/tidwall/redcon 2 | 3 | go 1.16 // net.ErrClosed (used by serve in redcon.go) requires Go 1.16+ 4 | 5 | require ( 6 | github.com/tidwall/btree v1.1.0 7 | github.com/tidwall/match v1.1.1 8 | ) 9 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/tidwall/btree v1.1.0 h1:5P+9WU8ui5uhmcg3SoPyTwoI0mVyZ1nps7YQzTZFkYM= 2 | github.com/tidwall/btree v1.1.0/go.mod h1:TzIRzen6yHbibdSfK6t8QimqbUnoxUSrZfeW7Uob0q4= 3 | github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= 4 | github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= 5 |
-------------------------------------------------------------------------------- /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidwall/redcon/9a9922f37eb1ea5420a7f0f6cb8acdf7a6e7142d/logo.png -------------------------------------------------------------------------------- /redcon.go: -------------------------------------------------------------------------------- 1 | // Package redcon implements a Redis compatible server framework 2 | package redcon 3 | 4 | import ( 5 | "bufio" 6 | "crypto/tls" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "net" 11 | "strings" 12 | "sync" 13 | "time" 14 | 15 | "github.com/tidwall/btree" 16 | "github.com/tidwall/match" 17 | ) 18 | 19 | var ( 20 | errUnbalancedQuotes = &errProtocol{"unbalanced quotes in request"} 21 | errInvalidBulkLength = &errProtocol{"invalid bulk length"} 22 | errInvalidMultiBulkLength = &errProtocol{"invalid multibulk length"} 23 | errDetached = errors.New("detached") 24 | errIncompleteCommand = errors.New("incomplete command") 25 | errTooMuchData = errors.New("too much data") 26 | ) 27 | 28 | const maxBufferCap = 262144 // 256 KiB 29 | // errProtocol reports a malformed client request; its Error() text is prefixed with "Protocol error: ". 30 | type errProtocol struct { 31 | msg string 32 | } 33 | 34 | func (err *errProtocol) Error() string { 35 | return "Protocol error: " + err.msg 36 | } 37 | 38 | // Conn represents a client connection. 39 | type Conn interface { 40 | // RemoteAddr returns the remote address of the client connection. 41 | RemoteAddr() string 42 | // Close closes the connection. 43 | Close() error 44 | // WriteError writes an error to the client. 45 | WriteError(msg string) 46 | // WriteString writes a string to the client. 47 | WriteString(str string) 48 | // WriteBulk writes bulk bytes to the client. 49 | WriteBulk(bulk []byte) 50 | // WriteBulkString writes a bulk string to the client. 51 | WriteBulkString(bulk string) 52 | // WriteInt writes an integer to the client.
53 | WriteInt(num int) 54 | // WriteInt64 writes a 64-bit signed integer to the client. 55 | WriteInt64(num int64) 56 | // WriteUint64 writes a 64-bit unsigned integer to the client. 57 | WriteUint64(num uint64) 58 | // WriteArray writes an array header. You must then write additional 59 | // sub-responses to the client to complete the response. 60 | // For example to write two strings: 61 | // 62 | // c.WriteArray(2) 63 | // c.WriteBulkString("item 1") 64 | // c.WriteBulkString("item 2") 65 | WriteArray(count int) 66 | // WriteNull writes a null to the client. 67 | WriteNull() 68 | // WriteRaw writes raw data to the client. 69 | WriteRaw(data []byte) 70 | // WriteAny writes any type to the client. 71 | // nil -> null 72 | // error -> error (adds "ERR " when first word is not uppercase) 73 | // string -> bulk-string 74 | // numbers -> bulk-string 75 | // []byte -> bulk-string 76 | // bool -> bulk-string ("0" or "1") 77 | // slice -> array 78 | // map -> array with key/value pairs 79 | // SimpleString -> string 80 | // SimpleInt -> integer 81 | // everything-else -> bulk-string representation using fmt.Sprint() 82 | WriteAny(any interface{}) 83 | // Context returns a user-defined context 84 | Context() interface{} 85 | // SetContext sets a user-defined context 86 | SetContext(v interface{}) 87 | // SetReadBuffer updates the buffer read size for the connection 88 | SetReadBuffer(bytes int) 89 | // Detach returns a connection that is detached from the server. 90 | // Useful for operations like PubSub.
91 | // 92 | // dconn := conn.Detach() 93 | // go func(){ 94 | // defer dconn.Close() 95 | // cmd, err := dconn.ReadCommand() 96 | // if err != nil{ 97 | // fmt.Printf("read failed: %v\n", err) 98 | // return 99 | // } 100 | // fmt.Printf("received command: %v", cmd) 101 | // dconn.WriteString("OK") 102 | // if err := dconn.Flush(); err != nil{ 103 | // fmt.Printf("write failed: %v\n", err) 104 | // return 105 | // } 106 | // }() 107 | Detach() DetachedConn 108 | // ReadPipeline returns all commands in the current pipeline, if any. 109 | // The commands are removed from the pipeline. 110 | ReadPipeline() []Command 111 | // PeekPipeline returns all commands in the current pipeline, if any. 112 | // The commands remain in the pipeline. 113 | PeekPipeline() []Command 114 | // NetConn returns the underlying net.Conn. 115 | NetConn() net.Conn 116 | // WriteBulkFrom writes a bulk of size n whose content is read from rb. 117 | WriteBulkFrom(n int64, rb io.Reader) 118 | } 119 | 120 | // NewServer returns a new Redcon server configured on "tcp" network net. 121 | func NewServer(addr string, 122 | handler func(conn Conn, cmd Command), 123 | accept func(conn Conn) bool, 124 | closed func(conn Conn, err error), 125 | ) *Server { 126 | return NewServerNetwork("tcp", addr, handler, accept, closed) 127 | } 128 | 129 | // NewServerTLS returns a new Redcon TLS server configured on "tcp" network net. 130 | func NewServerTLS(addr string, 131 | handler func(conn Conn, cmd Command), 132 | accept func(conn Conn) bool, 133 | closed func(conn Conn, err error), 134 | config *tls.Config, 135 | ) *TLSServer { 136 | return NewServerNetworkTLS("tcp", addr, handler, accept, closed, config) 137 | } 138 | 139 | // NewServerNetwork returns a new Redcon server.
The network net must be 140 | // a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket" 141 | func NewServerNetwork( 142 | net, laddr string, 143 | handler func(conn Conn, cmd Command), 144 | accept func(conn Conn) bool, 145 | closed func(conn Conn, err error), 146 | ) *Server { 147 | if handler == nil { 148 | panic("handler is nil") 149 | } 150 | s := newServer() 151 | s.net = net 152 | s.laddr = laddr 153 | s.handler = handler 154 | s.accept = accept 155 | s.closed = closed 156 | return s 157 | } 158 | 159 | // NewServerNetworkTLS returns a new TLS Redcon server. The network net must be 160 | // a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket" 161 | func NewServerNetworkTLS( 162 | net, laddr string, 163 | handler func(conn Conn, cmd Command), 164 | accept func(conn Conn) bool, 165 | closed func(conn Conn, err error), 166 | config *tls.Config, 167 | ) *TLSServer { 168 | if handler == nil { 169 | panic("handler is nil") 170 | } 171 | s := Server{ 172 | net: net, 173 | laddr: laddr, 174 | handler: handler, 175 | accept: accept, 176 | closed: closed, 177 | conns: make(map[*conn]bool), 178 | } 179 | 180 | tls := &TLSServer{ 181 | config: config, 182 | Server: &s, 183 | } 184 | return tls 185 | } 186 | 187 | // Close stops listening on the TCP address, or returns an error if the server is not serving. 188 | // Already accepted connections are closed once the serve loop exits. 189 | func (s *Server) Close() error { 190 | s.mu.Lock() 191 | defer s.mu.Unlock() 192 | if s.ln == nil { 193 | return errors.New("not serving") 194 | } 195 | s.done = true 196 | return s.ln.Close() 197 | } 198 | 199 | // ListenAndServe serves incoming connections. 200 | func (s *Server) ListenAndServe() error { 201 | return s.ListenServeAndSignal(nil) 202 | } 203 | 204 | // Addr returns the server's listen address. It reads s.ln without a lock or nil check, so call it only after listening has started. 205 | func (s *Server) Addr() net.Addr { 206 | return s.ln.Addr() 207 | } 208 | 209 | // Close stops listening on the TCP address, or returns an error if the server is not serving. 210 | // Already accepted connections are closed once the serve loop exits.
211 | func (s *TLSServer) Close() error { 212 | s.mu.Lock() 213 | defer s.mu.Unlock() 214 | if s.ln == nil { 215 | return errors.New("not serving") 216 | } 217 | s.done = true 218 | return s.ln.Close() 219 | } 220 | 221 | // ListenAndServe serves incoming connections. 222 | func (s *TLSServer) ListenAndServe() error { 223 | return s.ListenServeAndSignal(nil) 224 | } 225 | // newServer returns a Server with an initialized conns map; callers fill in the remaining fields. 226 | func newServer() *Server { 227 | s := &Server{ 228 | conns: make(map[*conn]bool), 229 | } 230 | return s 231 | } 232 | 233 | // Serve creates a new server and serves with the given net.Listener. 234 | func Serve(ln net.Listener, 235 | handler func(conn Conn, cmd Command), 236 | accept func(conn Conn) bool, 237 | closed func(conn Conn, err error), 238 | ) error { 239 | s := newServer() 240 | s.mu.Lock() 241 | s.net = ln.Addr().Network() 242 | s.laddr = ln.Addr().String() 243 | s.ln = ln 244 | s.handler = handler 245 | s.accept = accept 246 | s.closed = closed 247 | s.mu.Unlock() 248 | return serve(s) 249 | } 250 | 251 | // ListenAndServe creates a new server and binds to addr configured on "tcp" network net. 252 | func ListenAndServe(addr string, 253 | handler func(conn Conn, cmd Command), 254 | accept func(conn Conn) bool, 255 | closed func(conn Conn, err error), 256 | ) error { 257 | return ListenAndServeNetwork("tcp", addr, handler, accept, closed) 258 | } 259 | 260 | // ListenAndServeTLS creates a new TLS server and binds to addr configured on "tcp" network net. 261 | func ListenAndServeTLS(addr string, 262 | handler func(conn Conn, cmd Command), 263 | accept func(conn Conn) bool, 264 | closed func(conn Conn, err error), 265 | config *tls.Config, 266 | ) error { 267 | return ListenAndServeNetworkTLS("tcp", addr, handler, accept, closed, config) 268 | } 269 | 270 | // ListenAndServeNetwork creates a new server and binds to addr.
The network net must be 271 | // a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket" 272 | func ListenAndServeNetwork( 273 | net, laddr string, 274 | handler func(conn Conn, cmd Command), 275 | accept func(conn Conn) bool, 276 | closed func(conn Conn, err error), 277 | ) error { 278 | return NewServerNetwork(net, laddr, handler, accept, closed).ListenAndServe() 279 | } 280 | 281 | // ListenAndServeNetworkTLS creates a new TLS server and binds to addr. The network net must be 282 | // a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket" 283 | func ListenAndServeNetworkTLS( 284 | net, laddr string, 285 | handler func(conn Conn, cmd Command), 286 | accept func(conn Conn) bool, 287 | closed func(conn Conn, err error), 288 | config *tls.Config, 289 | ) error { 290 | return NewServerNetworkTLS(net, laddr, handler, accept, closed, config).ListenAndServe() 291 | } 292 | 293 | // ListenServeAndSignal listens on s.laddr, sends the listen result (nil or the error) on signal, and then 294 | // serves incoming connections. signal can be nil; otherwise it receives exactly one value and the send blocks until received. 295 | func (s *Server) ListenServeAndSignal(signal chan error) error { 296 | ln, err := net.Listen(s.net, s.laddr) 297 | if err != nil { 298 | if signal != nil { 299 | signal <- err 300 | } 301 | return err 302 | } 303 | s.mu.Lock() 304 | s.ln = ln 305 | s.mu.Unlock() 306 | if signal != nil { 307 | signal <- nil 308 | } 309 | return serve(s) 310 | } 311 | 312 | // Serve serves incoming connections with the given net.Listener. 313 | func (s *Server) Serve(ln net.Listener) error { 314 | s.mu.Lock() 315 | s.ln = ln 316 | s.net = ln.Addr().Network() 317 | s.laddr = ln.Addr().String() 318 | s.mu.Unlock() 319 | return serve(s) 320 | } 321 | 322 | // ListenServeAndSignal listens with TLS on s.laddr, sends the listen result (nil or the error) on signal, and then 323 | // serves incoming connections. signal can be nil; otherwise it receives exactly one value and the send blocks until received.
324 | func (s *TLSServer) ListenServeAndSignal(signal chan error) error { 325 | ln, err := tls.Listen(s.net, s.laddr, s.config) 326 | if err != nil { 327 | if signal != nil { 328 | signal <- err 329 | } 330 | return err 331 | } 332 | s.mu.Lock() 333 | s.ln = ln 334 | s.mu.Unlock() 335 | if signal != nil { 336 | signal <- nil 337 | } 338 | return serve(s.Server) 339 | } 340 | 341 | func serve(s *Server) error { 342 | defer func() { 343 | s.ln.Close() 344 | func() { 345 | s.mu.Lock() 346 | defer s.mu.Unlock() 347 | for c := range s.conns { 348 | c.conn.Close() 349 | } 350 | s.conns = nil 351 | }() 352 | }() 353 | for { 354 | lnconn, err := s.ln.Accept() 355 | if err != nil { 356 | s.mu.Lock() 357 | done := s.done 358 | s.mu.Unlock() 359 | if done { 360 | return nil 361 | } 362 | if errors.Is(err, net.ErrClosed) { 363 | // see https://github.com/tidwall/redcon/issues/46 364 | return nil 365 | } 366 | if s.AcceptError != nil { 367 | s.AcceptError(err) 368 | } 369 | continue 370 | } 371 | c := &conn{ 372 | conn: lnconn, 373 | addr: lnconn.RemoteAddr().String(), 374 | wr: NewWriter(lnconn), 375 | rd: NewReader(lnconn), 376 | } 377 | s.mu.Lock() 378 | c.idleClose = s.idleClose 379 | s.conns[c] = true 380 | s.mu.Unlock() 381 | if s.accept != nil && !s.accept(c) { 382 | s.mu.Lock() 383 | delete(s.conns, c) 384 | s.mu.Unlock() 385 | c.Close() 386 | continue 387 | } 388 | go handle(s, c) 389 | } 390 | } 391 | 392 | // handle manages the server connection. 393 | func handle(s *Server, c *conn) { 394 | var err error 395 | defer func() { 396 | if err != errDetached { 397 | // do not close the connection when a detach is detected. 
398 | c.conn.Close() 399 | } 400 | func() { 401 | // remove the conn from the server 402 | s.mu.Lock() 403 | defer s.mu.Unlock() 404 | delete(s.conns, c) 405 | if s.closed != nil { 406 | if err == io.EOF { 407 | err = nil 408 | } 409 | s.closed(c, err) 410 | } 411 | }() 412 | }() 413 | 414 | err = func() error { 415 | // read commands and feed back to the client 416 | for { 417 | // read pipeline commands 418 | if c.idleClose != 0 { 419 | c.conn.SetReadDeadline(time.Now().Add(c.idleClose)) 420 | } 421 | cmds, err := c.rd.readCommands(nil) 422 | if err != nil { 423 | if err, ok := err.(*errProtocol); ok { 424 | // All protocol errors should attempt a response to 425 | // the client. Ignore write errors. 426 | c.wr.WriteError("ERR " + err.Error()) 427 | c.wr.Flush() 428 | } 429 | return err 430 | } 431 | c.cmds = cmds 432 | for len(c.cmds) > 0 { 433 | cmd := c.cmds[0] 434 | if len(c.cmds) == 1 { 435 | c.cmds = nil 436 | } else { 437 | c.cmds = c.cmds[1:] 438 | } 439 | s.handler(c, cmd) 440 | } 441 | if c.detached { 442 | // client has been detached 443 | return errDetached 444 | } 445 | if c.closed { 446 | return nil 447 | } 448 | if err := c.wr.Flush(); err != nil { 449 | return err 450 | } 451 | } 452 | }() 453 | } 454 | 455 | // conn represents a client connection 456 | type conn struct { 457 | conn net.Conn 458 | wr *Writer 459 | rd *Reader 460 | addr string 461 | ctx interface{} 462 | detached bool 463 | closed bool 464 | cmds []Command 465 | idleClose time.Duration 466 | } 467 | 468 | func (c *conn) Close() error { 469 | c.wr.Flush() 470 | c.closed = true 471 | return c.conn.Close() 472 | } 473 | func (c *conn) Context() interface{} { return c.ctx } 474 | func (c *conn) SetContext(v interface{}) { c.ctx = v } 475 | func (c *conn) SetReadBuffer(n int) {} 476 | func (c *conn) WriteString(str string) { c.wr.WriteString(str) } 477 | func (c *conn) WriteBulk(bulk []byte) { c.wr.WriteBulk(bulk) } 478 | func (c *conn) WriteBulkString(bulk string) { 
c.wr.WriteBulkString(bulk) } 479 | func (c *conn) WriteInt(num int) { c.wr.WriteInt(num) } 480 | func (c *conn) WriteInt64(num int64) { c.wr.WriteInt64(num) } 481 | func (c *conn) WriteUint64(num uint64) { c.wr.WriteUint64(num) } 482 | func (c *conn) WriteError(msg string) { c.wr.WriteError(msg) } 483 | func (c *conn) WriteArray(count int) { c.wr.WriteArray(count) } 484 | func (c *conn) WriteNull() { c.wr.WriteNull() } 485 | func (c *conn) WriteRaw(data []byte) { c.wr.WriteRaw(data) } 486 | func (c *conn) WriteAny(v interface{}) { c.wr.WriteAny(v) } 487 | func (c *conn) RemoteAddr() string { return c.addr } 488 | func (c *conn) ReadPipeline() []Command { 489 | cmds := c.cmds 490 | c.cmds = nil 491 | return cmds 492 | } 493 | func (c *conn) PeekPipeline() []Command { 494 | return c.cmds 495 | } 496 | func (c *conn) NetConn() net.Conn { 497 | return c.conn 498 | } 499 | func (c *conn) WriteBulkFrom(n int64, rb io.Reader) { 500 | c.wr.WriteBulkFrom(n, rb) 501 | } 502 | 503 | // BaseWriter returns the underlying connection writer, if any 504 | func BaseWriter(c Conn) *Writer { 505 | if c, ok := c.(*conn); ok { 506 | return c.wr 507 | } 508 | return nil 509 | } 510 | 511 | // DetachedConn represents a connection that is detached from the server 512 | type DetachedConn interface { 513 | // Conn is the original connection 514 | Conn 515 | // ReadCommand reads the next client command. 516 | ReadCommand() (Command, error) 517 | // Flush flushes any writes to the network. 518 | Flush() error 519 | } 520 | 521 | // Detach removes the current connection from the server loop and returns 522 | // a detached connection. This is useful for operations such as PubSub. 523 | // The detached connection must be closed by calling Close() when done. 524 | // All writes such as WriteString() will not be written to the client 525 | // until Flush() is called. 
526 | func (c *conn) Detach() DetachedConn { 527 | c.detached = true 528 | cmds := c.cmds 529 | c.cmds = nil 530 | return &detachedConn{conn: c, cmds: cmds} 531 | } 532 | 533 | type detachedConn struct { 534 | *conn 535 | cmds []Command 536 | } 537 | 538 | // Flush writes and Write* calls to the client. 539 | func (dc *detachedConn) Flush() error { 540 | return dc.conn.wr.Flush() 541 | } 542 | 543 | // ReadCommand read the next command from the client. 544 | func (dc *detachedConn) ReadCommand() (Command, error) { 545 | if len(dc.cmds) > 0 { 546 | cmd := dc.cmds[0] 547 | if len(dc.cmds) == 1 { 548 | dc.cmds = nil 549 | } else { 550 | dc.cmds = dc.cmds[1:] 551 | } 552 | return cmd, nil 553 | } 554 | cmd, err := dc.rd.ReadCommand() 555 | if err != nil { 556 | return Command{}, err 557 | } 558 | return cmd, nil 559 | } 560 | 561 | // Command represent a command 562 | type Command struct { 563 | // Raw is a encoded RESP message. 564 | Raw []byte 565 | // Args is a series of arguments that make up the command. 566 | Args [][]byte 567 | } 568 | 569 | // Server defines a server for clients for managing client connections. 570 | type Server struct { 571 | mu sync.Mutex 572 | net string 573 | laddr string 574 | handler func(conn Conn, cmd Command) 575 | accept func(conn Conn) bool 576 | closed func(conn Conn, err error) 577 | conns map[*conn]bool 578 | ln net.Listener 579 | done bool 580 | idleClose time.Duration 581 | 582 | // AcceptError is an optional function used to handle Accept errors. 583 | AcceptError func(err error) 584 | } 585 | 586 | // TLSServer defines a server for clients for managing client connections. 587 | type TLSServer struct { 588 | *Server 589 | config *tls.Config 590 | } 591 | 592 | // Writer allows for writing RESP messages. 
593 | type Writer struct { 594 | w io.Writer 595 | b []byte 596 | err error 597 | 598 | // buff use io buffer write to w(io.Writer) 599 | // for io.Copy r(io.Reader) to w(io.Writer) 600 | buff *bufio.Writer 601 | } 602 | 603 | // NewWriter creates a new RESP writer. 604 | func NewWriter(wr io.Writer) *Writer { 605 | return &Writer{ 606 | w: wr, 607 | buff: bufio.NewWriter(wr), 608 | } 609 | } 610 | 611 | func (w *Writer) WriteBulkFrom(n int64, r io.Reader) { 612 | if w != nil && w.err != nil { 613 | return 614 | } 615 | w.buff.Write(appendPrefix(w.b, '$', n)) 616 | io.Copy(w.buff, r) 617 | w.buff.Write([]byte{'\r', '\n'}) 618 | } 619 | 620 | // WriteNull writes a null to the client 621 | func (w *Writer) WriteNull() { 622 | if w.err != nil { 623 | return 624 | } 625 | w.b = AppendNull(w.b) 626 | } 627 | 628 | // WriteArray writes an array header. You must then write additional 629 | // sub-responses to the client to complete the response. 630 | // For example to write two strings: 631 | // 632 | // c.WriteArray(2) 633 | // c.WriteBulkString("item 1") 634 | // c.WriteBulkString("item 2") 635 | func (w *Writer) WriteArray(count int) { 636 | if w.err != nil { 637 | return 638 | } 639 | w.b = AppendArray(w.b, count) 640 | } 641 | 642 | // WriteBulk writes bulk bytes to the client. 643 | func (w *Writer) WriteBulk(bulk []byte) { 644 | if w.err != nil { 645 | return 646 | } 647 | w.b = AppendBulk(w.b, bulk) 648 | } 649 | 650 | // WriteBulkString writes a bulk string to the client. 651 | func (w *Writer) WriteBulkString(bulk string) { 652 | if w.err != nil { 653 | return 654 | } 655 | w.b = AppendBulkString(w.b, bulk) 656 | } 657 | 658 | // Buffer returns the unflushed buffer. This is a copy so changes 659 | // to the resulting []byte will not affect the writer. 660 | func (w *Writer) Buffer() []byte { 661 | if w.err != nil { 662 | return nil 663 | } 664 | return append([]byte(nil), w.b...) 665 | } 666 | 667 | // SetBuffer replaces the unflushed buffer with new bytes. 
668 | func (w *Writer) SetBuffer(raw []byte) { 669 | if w.err != nil { 670 | return 671 | } 672 | w.b = w.b[:0] 673 | w.b = append(w.b, raw...) 674 | } 675 | 676 | // Flush writes all unflushed Write* calls to the underlying writer. 677 | func (w *Writer) Flush() error { 678 | if w.buff != nil { 679 | w.buff.Flush() 680 | } 681 | 682 | if w.err != nil { 683 | return w.err 684 | } 685 | _, w.err = w.w.Write(w.b) 686 | if cap(w.b) > maxBufferCap || w.err != nil { 687 | w.b = nil 688 | } else { 689 | w.b = w.b[:0] 690 | } 691 | return w.err 692 | } 693 | 694 | // WriteError writes an error to the client. 695 | func (w *Writer) WriteError(msg string) { 696 | if w.err != nil { 697 | return 698 | } 699 | w.b = AppendError(w.b, msg) 700 | } 701 | 702 | // WriteString writes a string to the client. 703 | func (w *Writer) WriteString(msg string) { 704 | if w.err != nil { 705 | return 706 | } 707 | w.b = AppendString(w.b, msg) 708 | } 709 | 710 | // WriteInt writes an integer to the client. 711 | func (w *Writer) WriteInt(num int) { 712 | if w.err != nil { 713 | return 714 | } 715 | w.WriteInt64(int64(num)) 716 | } 717 | 718 | // WriteInt64 writes a 64-bit signed integer to the client. 719 | func (w *Writer) WriteInt64(num int64) { 720 | if w.err != nil { 721 | return 722 | } 723 | w.b = AppendInt(w.b, num) 724 | } 725 | 726 | // WriteUint64 writes a 64-bit unsigned integer to the client. 727 | func (w *Writer) WriteUint64(num uint64) { 728 | if w.err != nil { 729 | return 730 | } 731 | w.b = AppendUint(w.b, num) 732 | } 733 | 734 | // WriteRaw writes raw data to the client. 735 | func (w *Writer) WriteRaw(data []byte) { 736 | if w.err != nil { 737 | return 738 | } 739 | w.b = append(w.b, data...) 740 | } 741 | 742 | // WriteAny writes any type to client. 
743 | // 744 | // nil -> null 745 | // error -> error (adds "ERR " when first word is not uppercase) 746 | // string -> bulk-string 747 | // numbers -> bulk-string 748 | // []byte -> bulk-string 749 | // bool -> bulk-string ("0" or "1") 750 | // slice -> array 751 | // map -> array with key/value pairs 752 | // SimpleString -> string 753 | // SimpleInt -> integer 754 | // everything-else -> bulk-string representation using fmt.Sprint() 755 | func (w *Writer) WriteAny(v interface{}) { 756 | if w.err != nil { 757 | return 758 | } 759 | w.b = AppendAny(w.b, v) 760 | } 761 | 762 | // Reader represent a reader for RESP or telnet commands. 763 | type Reader struct { 764 | rd *bufio.Reader 765 | buf []byte 766 | start int 767 | end int 768 | cmds []Command 769 | } 770 | 771 | // NewReader returns a command reader which will read RESP or telnet commands. 772 | func NewReader(rd io.Reader) *Reader { 773 | return &Reader{ 774 | rd: bufio.NewReader(rd), 775 | buf: make([]byte, 4096), 776 | } 777 | } 778 | 779 | func parseInt(b []byte) (int, bool) { 780 | if len(b) == 1 && b[0] >= '0' && b[0] <= '9' { 781 | return int(b[0] - '0'), true 782 | } 783 | var n int 784 | var sign bool 785 | var i int 786 | if len(b) > 0 && b[0] == '-' { 787 | sign = true 788 | i++ 789 | } 790 | for ; i < len(b); i++ { 791 | if b[i] < '0' || b[i] > '9' { 792 | return 0, false 793 | } 794 | n = n*10 + int(b[i]-'0') 795 | } 796 | if sign { 797 | n *= -1 798 | } 799 | return n, true 800 | } 801 | 802 | func (rd *Reader) readCommands(leftover *int) ([]Command, error) { 803 | var cmds []Command 804 | b := rd.buf[rd.start:rd.end] 805 | if rd.end-rd.start == 0 && len(rd.buf) > 4096 { 806 | rd.buf = rd.buf[:4096] 807 | rd.start = 0 808 | rd.end = 0 809 | } 810 | if len(b) > 0 { 811 | // we have data, yay! 812 | // but is this enough data for a complete command? or multiple? 
813 | next: 814 | switch b[0] { 815 | default: 816 | // just a plain text command 817 | for i := 0; i < len(b); i++ { 818 | if b[i] == '\n' { 819 | var line []byte 820 | if i > 0 && b[i-1] == '\r' { 821 | line = b[:i-1] 822 | } else { 823 | line = b[:i] 824 | } 825 | var cmd Command 826 | var quote bool 827 | var quotech byte 828 | var escape bool 829 | outer: 830 | for { 831 | nline := make([]byte, 0, len(line)) 832 | for i := 0; i < len(line); i++ { 833 | c := line[i] 834 | if !quote { 835 | if c == ' ' { 836 | if len(nline) > 0 { 837 | cmd.Args = append(cmd.Args, nline) 838 | } 839 | line = line[i+1:] 840 | continue outer 841 | } 842 | if c == '"' || c == '\'' { 843 | if i != 0 { 844 | return nil, errUnbalancedQuotes 845 | } 846 | quotech = c 847 | quote = true 848 | line = line[i+1:] 849 | continue outer 850 | } 851 | } else { 852 | if escape { 853 | escape = false 854 | switch c { 855 | case 'n': 856 | c = '\n' 857 | case 'r': 858 | c = '\r' 859 | case 't': 860 | c = '\t' 861 | } 862 | } else if c == quotech { 863 | quote = false 864 | quotech = 0 865 | cmd.Args = append(cmd.Args, nline) 866 | line = line[i+1:] 867 | if len(line) > 0 && line[0] != ' ' { 868 | return nil, errUnbalancedQuotes 869 | } 870 | continue outer 871 | } else if c == '\\' { 872 | escape = true 873 | continue 874 | } 875 | } 876 | nline = append(nline, c) 877 | } 878 | if quote { 879 | return nil, errUnbalancedQuotes 880 | } 881 | if len(line) > 0 { 882 | cmd.Args = append(cmd.Args, line) 883 | } 884 | break 885 | } 886 | if len(cmd.Args) > 0 { 887 | // convert this to resp command syntax 888 | var wr Writer 889 | wr.WriteArray(len(cmd.Args)) 890 | for i := range cmd.Args { 891 | wr.WriteBulk(cmd.Args[i]) 892 | cmd.Args[i] = append([]byte(nil), cmd.Args[i]...) 
893 | } 894 | cmd.Raw = wr.b 895 | cmds = append(cmds, cmd) 896 | } 897 | b = b[i+1:] 898 | if len(b) > 0 { 899 | goto next 900 | } else { 901 | goto done 902 | } 903 | } 904 | } 905 | case '*': 906 | // resp formatted command 907 | marks := make([]int, 0, 16) 908 | outer2: 909 | for i := 1; i < len(b); i++ { 910 | if b[i] == '\n' { 911 | if b[i-1] != '\r' { 912 | return nil, errInvalidMultiBulkLength 913 | } 914 | count, ok := parseInt(b[1 : i-1]) 915 | if !ok || count <= 0 { 916 | return nil, errInvalidMultiBulkLength 917 | } 918 | marks = marks[:0] 919 | for j := 0; j < count; j++ { 920 | // read bulk length 921 | i++ 922 | if i < len(b) { 923 | if b[i] != '$' { 924 | return nil, &errProtocol{"expected '$', got '" + 925 | string(b[i]) + "'"} 926 | } 927 | si := i 928 | for ; i < len(b); i++ { 929 | if b[i] == '\n' { 930 | if b[i-1] != '\r' { 931 | return nil, errInvalidBulkLength 932 | } 933 | size, ok := parseInt(b[si+1 : i-1]) 934 | if !ok || size < 0 { 935 | return nil, errInvalidBulkLength 936 | } 937 | if i+size+2 >= len(b) { 938 | // not ready 939 | break outer2 940 | } 941 | if b[i+size+2] != '\n' || 942 | b[i+size+1] != '\r' { 943 | return nil, errInvalidBulkLength 944 | } 945 | i++ 946 | marks = append(marks, i, i+size) 947 | i += size + 1 948 | break 949 | } 950 | } 951 | } 952 | } 953 | if len(marks) == count*2 { 954 | var cmd Command 955 | if rd.rd != nil { 956 | // make a raw copy of the entire command when 957 | // there's a underlying reader. 958 | cmd.Raw = append([]byte(nil), b[:i+1]...) 959 | } else { 960 | // just assign the slice 961 | cmd.Raw = b[:i+1] 962 | } 963 | cmd.Args = make([][]byte, len(marks)/2) 964 | // slice up the raw command into the args based on 965 | // the recorded marks. 
966 | for h := 0; h < len(marks); h += 2 { 967 | cmd.Args[h/2] = cmd.Raw[marks[h]:marks[h+1]] 968 | } 969 | cmds = append(cmds, cmd) 970 | b = b[i+1:] 971 | if len(b) > 0 { 972 | goto next 973 | } else { 974 | goto done 975 | } 976 | } 977 | } 978 | } 979 | } 980 | done: 981 | rd.start = rd.end - len(b) 982 | } 983 | if leftover != nil { 984 | *leftover = rd.end - rd.start 985 | } 986 | if len(cmds) > 0 { 987 | return cmds, nil 988 | } 989 | if rd.rd == nil { 990 | return nil, errIncompleteCommand 991 | } 992 | if rd.end == len(rd.buf) { 993 | // at the end of the buffer. 994 | if rd.start == rd.end { 995 | // rewind the to the beginning 996 | rd.start, rd.end = 0, 0 997 | } else { 998 | // must grow the buffer 999 | newbuf := make([]byte, len(rd.buf)*2) 1000 | copy(newbuf, rd.buf) 1001 | rd.buf = newbuf 1002 | } 1003 | } 1004 | n, err := rd.rd.Read(rd.buf[rd.end:]) 1005 | if err != nil { 1006 | return nil, err 1007 | } 1008 | rd.end += n 1009 | return rd.readCommands(leftover) 1010 | } 1011 | 1012 | // ReadCommands reads the next pipeline commands. 1013 | func (rd *Reader) ReadCommands() ([]Command, error) { 1014 | for { 1015 | if len(rd.cmds) > 0 { 1016 | cmds := rd.cmds 1017 | rd.cmds = nil 1018 | return cmds, nil 1019 | } 1020 | cmds, err := rd.readCommands(nil) 1021 | if err != nil { 1022 | return []Command{}, err 1023 | } 1024 | rd.cmds = cmds 1025 | } 1026 | } 1027 | 1028 | // ReadCommand reads the next command. 1029 | func (rd *Reader) ReadCommand() (Command, error) { 1030 | if len(rd.cmds) > 0 { 1031 | cmd := rd.cmds[0] 1032 | rd.cmds = rd.cmds[1:] 1033 | return cmd, nil 1034 | } 1035 | cmds, err := rd.readCommands(nil) 1036 | if err != nil { 1037 | return Command{}, err 1038 | } 1039 | rd.cmds = cmds 1040 | return rd.ReadCommand() 1041 | } 1042 | 1043 | // Parse parses a raw RESP message and returns a command. 
1044 | func Parse(raw []byte) (Command, error) { 1045 | rd := Reader{buf: raw, end: len(raw)} 1046 | var leftover int 1047 | cmds, err := rd.readCommands(&leftover) 1048 | if err != nil { 1049 | return Command{}, err 1050 | } 1051 | if leftover > 0 { 1052 | return Command{}, errTooMuchData 1053 | } 1054 | return cmds[0], nil 1055 | 1056 | } 1057 | 1058 | // A Handler responds to an RESP request. 1059 | type Handler interface { 1060 | ServeRESP(conn Conn, cmd Command) 1061 | } 1062 | 1063 | // The HandlerFunc type is an adapter to allow the use of 1064 | // ordinary functions as RESP handlers. If f is a function 1065 | // with the appropriate signature, HandlerFunc(f) is a 1066 | // Handler that calls f. 1067 | type HandlerFunc func(conn Conn, cmd Command) 1068 | 1069 | // ServeRESP calls f(w, r) 1070 | func (f HandlerFunc) ServeRESP(conn Conn, cmd Command) { 1071 | f(conn, cmd) 1072 | } 1073 | 1074 | // ServeMux is an RESP command multiplexer. 1075 | type ServeMux struct { 1076 | handlers map[string]Handler 1077 | } 1078 | 1079 | // NewServeMux allocates and returns a new ServeMux. 1080 | func NewServeMux() *ServeMux { 1081 | return &ServeMux{ 1082 | handlers: make(map[string]Handler), 1083 | } 1084 | } 1085 | 1086 | // HandleFunc registers the handler function for the given command. 1087 | func (m *ServeMux) HandleFunc(command string, handler func(conn Conn, cmd Command)) { 1088 | if handler == nil { 1089 | panic("redcon: nil handler") 1090 | } 1091 | m.Handle(command, HandlerFunc(handler)) 1092 | } 1093 | 1094 | // Handle registers the handler for the given command. 1095 | // If a handler already exists for command, Handle panics. 
1096 | func (m *ServeMux) Handle(command string, handler Handler) { 1097 | if command == "" { 1098 | panic("redcon: invalid command") 1099 | } 1100 | if handler == nil { 1101 | panic("redcon: nil handler") 1102 | } 1103 | if _, exist := m.handlers[command]; exist { 1104 | panic("redcon: multiple registrations for " + command) 1105 | } 1106 | 1107 | m.handlers[command] = handler 1108 | } 1109 | 1110 | // ServeRESP dispatches the command to the handler. 1111 | func (m *ServeMux) ServeRESP(conn Conn, cmd Command) { 1112 | command := strings.ToLower(string(cmd.Args[0])) 1113 | 1114 | if handler, ok := m.handlers[command]; ok { 1115 | handler.ServeRESP(conn, cmd) 1116 | } else { 1117 | conn.WriteError("ERR unknown command '" + command + "'") 1118 | } 1119 | } 1120 | 1121 | // PubSub is a Redis compatible pub/sub server 1122 | type PubSub struct { 1123 | mu sync.RWMutex 1124 | nextid uint64 1125 | initd bool 1126 | chans *btree.BTree 1127 | conns map[Conn]*pubSubConn 1128 | } 1129 | 1130 | // Subscribe a connection to PubSub 1131 | func (ps *PubSub) Subscribe(conn Conn, channel string) { 1132 | ps.subscribe(conn, false, channel) 1133 | } 1134 | 1135 | // Psubscribe a connection to PubSub 1136 | func (ps *PubSub) Psubscribe(conn Conn, channel string) { 1137 | ps.subscribe(conn, true, channel) 1138 | } 1139 | 1140 | // Publish a message to subscribers 1141 | func (ps *PubSub) Publish(channel, message string) int { 1142 | ps.mu.RLock() 1143 | defer ps.mu.RUnlock() 1144 | if !ps.initd { 1145 | return 0 1146 | } 1147 | var sent int 1148 | // write messages to all clients that are subscribed on the channel 1149 | pivot := &pubSubEntry{pattern: false, channel: channel} 1150 | ps.chans.Ascend(pivot, func(item interface{}) bool { 1151 | entry := item.(*pubSubEntry) 1152 | if entry.channel != pivot.channel || entry.pattern != pivot.pattern { 1153 | return false 1154 | } 1155 | entry.sconn.writeMessage(entry.pattern, "", channel, message) 1156 | sent++ 1157 | return true 1158 | }) 
1159 | 1160 | // match on and write all psubscribe clients 1161 | pivot = &pubSubEntry{pattern: true} 1162 | ps.chans.Ascend(pivot, func(item interface{}) bool { 1163 | entry := item.(*pubSubEntry) 1164 | if match.Match(channel, entry.channel) { 1165 | entry.sconn.writeMessage(entry.pattern, entry.channel, channel, 1166 | message) 1167 | } 1168 | sent++ 1169 | return true 1170 | }) 1171 | 1172 | return sent 1173 | } 1174 | 1175 | type pubSubConn struct { 1176 | id uint64 1177 | mu sync.Mutex 1178 | conn Conn 1179 | dconn DetachedConn 1180 | entries map[*pubSubEntry]bool 1181 | } 1182 | 1183 | type pubSubEntry struct { 1184 | pattern bool 1185 | sconn *pubSubConn 1186 | channel string 1187 | } 1188 | 1189 | func (sconn *pubSubConn) writeMessage(pat bool, pchan, channel, msg string) { 1190 | sconn.mu.Lock() 1191 | defer sconn.mu.Unlock() 1192 | if pat { 1193 | sconn.dconn.WriteArray(4) 1194 | sconn.dconn.WriteBulkString("pmessage") 1195 | sconn.dconn.WriteBulkString(pchan) 1196 | sconn.dconn.WriteBulkString(channel) 1197 | sconn.dconn.WriteBulkString(msg) 1198 | } else { 1199 | sconn.dconn.WriteArray(3) 1200 | sconn.dconn.WriteBulkString("message") 1201 | sconn.dconn.WriteBulkString(channel) 1202 | sconn.dconn.WriteBulkString(msg) 1203 | } 1204 | sconn.dconn.Flush() 1205 | } 1206 | 1207 | // bgrunner runs in the background and reads incoming commands from the 1208 | // detached client. 1209 | func (sconn *pubSubConn) bgrunner(ps *PubSub) { 1210 | defer func() { 1211 | // client connection has ended, disconnect from the PubSub instances 1212 | // and close the network connection. 
1213 | ps.mu.Lock() 1214 | defer ps.mu.Unlock() 1215 | for entry := range sconn.entries { 1216 | ps.chans.Delete(entry) 1217 | } 1218 | delete(ps.conns, sconn.conn) 1219 | sconn.mu.Lock() 1220 | defer sconn.mu.Unlock() 1221 | sconn.dconn.Close() 1222 | }() 1223 | for { 1224 | cmd, err := sconn.dconn.ReadCommand() 1225 | if err != nil { 1226 | return 1227 | } 1228 | if len(cmd.Args) == 0 { 1229 | continue 1230 | } 1231 | switch strings.ToLower(string(cmd.Args[0])) { 1232 | case "psubscribe", "subscribe": 1233 | if len(cmd.Args) < 2 { 1234 | func() { 1235 | sconn.mu.Lock() 1236 | defer sconn.mu.Unlock() 1237 | sconn.dconn.WriteError(fmt.Sprintf("ERR wrong number of "+ 1238 | "arguments for '%s'", cmd.Args[0])) 1239 | sconn.dconn.Flush() 1240 | }() 1241 | continue 1242 | } 1243 | command := strings.ToLower(string(cmd.Args[0])) 1244 | for i := 1; i < len(cmd.Args); i++ { 1245 | if command == "psubscribe" { 1246 | ps.Psubscribe(sconn.conn, string(cmd.Args[i])) 1247 | } else { 1248 | ps.Subscribe(sconn.conn, string(cmd.Args[i])) 1249 | } 1250 | } 1251 | case "unsubscribe", "punsubscribe": 1252 | pattern := strings.ToLower(string(cmd.Args[0])) == "punsubscribe" 1253 | if len(cmd.Args) == 1 { 1254 | ps.unsubscribe(sconn.conn, pattern, true, "") 1255 | } else { 1256 | for i := 1; i < len(cmd.Args); i++ { 1257 | channel := string(cmd.Args[i]) 1258 | ps.unsubscribe(sconn.conn, pattern, false, channel) 1259 | } 1260 | } 1261 | case "quit": 1262 | func() { 1263 | sconn.mu.Lock() 1264 | defer sconn.mu.Unlock() 1265 | sconn.dconn.WriteString("OK") 1266 | sconn.dconn.Flush() 1267 | sconn.dconn.Close() 1268 | }() 1269 | return 1270 | case "ping": 1271 | var msg string 1272 | switch len(cmd.Args) { 1273 | case 1: 1274 | case 2: 1275 | msg = string(cmd.Args[1]) 1276 | default: 1277 | func() { 1278 | sconn.mu.Lock() 1279 | defer sconn.mu.Unlock() 1280 | sconn.dconn.WriteError(fmt.Sprintf("ERR wrong number of "+ 1281 | "arguments for '%s'", cmd.Args[0])) 1282 | sconn.dconn.Flush() 1283 
| }() 1284 | continue 1285 | } 1286 | func() { 1287 | sconn.mu.Lock() 1288 | defer sconn.mu.Unlock() 1289 | sconn.dconn.WriteArray(2) 1290 | sconn.dconn.WriteBulkString("pong") 1291 | sconn.dconn.WriteBulkString(msg) 1292 | sconn.dconn.Flush() 1293 | }() 1294 | default: 1295 | func() { 1296 | sconn.mu.Lock() 1297 | defer sconn.mu.Unlock() 1298 | sconn.dconn.WriteError(fmt.Sprintf("ERR Can't execute '%s': "+ 1299 | "only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT are "+ 1300 | "allowed in this context", cmd.Args[0])) 1301 | sconn.dconn.Flush() 1302 | }() 1303 | } 1304 | } 1305 | } 1306 | 1307 | // byEntry is a "less" function that sorts the entries in a btree. The tree 1308 | // is sorted be (pattern, channel, conn.id). All pattern=true entries are at 1309 | // the end (right) of the tree. 1310 | func byEntry(a, b interface{}) bool { 1311 | aa := a.(*pubSubEntry) 1312 | bb := b.(*pubSubEntry) 1313 | if !aa.pattern && bb.pattern { 1314 | return true 1315 | } 1316 | if aa.pattern && !bb.pattern { 1317 | return false 1318 | } 1319 | if aa.channel < bb.channel { 1320 | return true 1321 | } 1322 | if aa.channel > bb.channel { 1323 | return false 1324 | } 1325 | var aid uint64 1326 | var bid uint64 1327 | if aa.sconn != nil { 1328 | aid = aa.sconn.id 1329 | } 1330 | if bb.sconn != nil { 1331 | bid = bb.sconn.id 1332 | } 1333 | return aid < bid 1334 | } 1335 | 1336 | func (ps *PubSub) subscribe(conn Conn, pattern bool, channel string) { 1337 | ps.mu.Lock() 1338 | defer ps.mu.Unlock() 1339 | 1340 | // initialize the PubSub instance 1341 | if !ps.initd { 1342 | ps.conns = make(map[Conn]*pubSubConn) 1343 | ps.chans = btree.New(byEntry) 1344 | ps.initd = true 1345 | } 1346 | 1347 | // fetch the pubSubConn 1348 | sconn, ok := ps.conns[conn] 1349 | if !ok { 1350 | // initialize a new pubSubConn, which runs on a detached connection, 1351 | // and attach it to the PubSub channels/conn btree 1352 | ps.nextid++ 1353 | dconn := conn.Detach() 1354 | sconn = &pubSubConn{ 1355 | id: 
ps.nextid, 1356 | conn: conn, 1357 | dconn: dconn, 1358 | entries: make(map[*pubSubEntry]bool), 1359 | } 1360 | ps.conns[conn] = sconn 1361 | } 1362 | sconn.mu.Lock() 1363 | defer sconn.mu.Unlock() 1364 | 1365 | // add an entry to the pubsub btree 1366 | entry := &pubSubEntry{ 1367 | pattern: pattern, 1368 | channel: channel, 1369 | sconn: sconn, 1370 | } 1371 | ps.chans.Set(entry) 1372 | sconn.entries[entry] = true 1373 | 1374 | // send a message to the client 1375 | sconn.dconn.WriteArray(3) 1376 | if pattern { 1377 | sconn.dconn.WriteBulkString("psubscribe") 1378 | } else { 1379 | sconn.dconn.WriteBulkString("subscribe") 1380 | } 1381 | sconn.dconn.WriteBulkString(channel) 1382 | var count int 1383 | for entry := range sconn.entries { 1384 | if entry.pattern == pattern { 1385 | count++ 1386 | } 1387 | } 1388 | sconn.dconn.WriteInt(count) 1389 | sconn.dconn.Flush() 1390 | 1391 | // start the background client operation 1392 | if !ok { 1393 | go sconn.bgrunner(ps) 1394 | } 1395 | } 1396 | 1397 | func (ps *PubSub) unsubscribe(conn Conn, pattern, all bool, channel string) { 1398 | ps.mu.Lock() 1399 | defer ps.mu.Unlock() 1400 | // fetch the pubSubConn. 
This must exist 1401 | sconn := ps.conns[conn] 1402 | sconn.mu.Lock() 1403 | defer sconn.mu.Unlock() 1404 | 1405 | removeEntry := func(entry *pubSubEntry) { 1406 | if entry != nil { 1407 | ps.chans.Delete(entry) 1408 | delete(sconn.entries, entry) 1409 | } 1410 | sconn.dconn.WriteArray(3) 1411 | if pattern { 1412 | sconn.dconn.WriteBulkString("punsubscribe") 1413 | } else { 1414 | sconn.dconn.WriteBulkString("unsubscribe") 1415 | } 1416 | if entry != nil { 1417 | sconn.dconn.WriteBulkString(entry.channel) 1418 | } else { 1419 | sconn.dconn.WriteNull() 1420 | } 1421 | var count int 1422 | for entry := range sconn.entries { 1423 | if entry.pattern == pattern { 1424 | count++ 1425 | } 1426 | } 1427 | sconn.dconn.WriteInt(count) 1428 | } 1429 | if all { 1430 | // unsubscribe from all (p)subscribe entries 1431 | var entries []*pubSubEntry 1432 | for entry := range sconn.entries { 1433 | if entry.pattern == pattern { 1434 | entries = append(entries, entry) 1435 | } 1436 | } 1437 | if len(entries) == 0 { 1438 | removeEntry(nil) 1439 | } else { 1440 | for _, entry := range entries { 1441 | removeEntry(entry) 1442 | } 1443 | } 1444 | } else { 1445 | // unsubscribe single channel from (p)subscribe. 1446 | for entry := range sconn.entries { 1447 | if entry.pattern == pattern && entry.channel == channel { 1448 | removeEntry(entry) 1449 | break 1450 | } 1451 | } 1452 | } 1453 | sconn.dconn.Flush() 1454 | } 1455 | 1456 | // SetIdleClose will automatically close idle connections after the specified 1457 | // duration. Use zero to disable this feature. 
1458 | func (s *Server) SetIdleClose(dur time.Duration) { 1459 | s.mu.Lock() 1460 | s.idleClose = dur 1461 | s.mu.Unlock() 1462 | } 1463 | -------------------------------------------------------------------------------- /redcon_test.go: -------------------------------------------------------------------------------- 1 | package redcon 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "fmt" 7 | "io" 8 | "log" 9 | "math/rand" 10 | "net" 11 | "os" 12 | "strconv" 13 | "strings" 14 | "sync" 15 | "testing" 16 | "time" 17 | ) 18 | 19 | // TestRandomCommands fills a bunch of random commands and test various 20 | // ways that the reader may receive data. 21 | func TestRandomCommands(t *testing.T) { 22 | rand.Seed(time.Now().UnixNano()) 23 | 24 | // build random commands. 25 | gcmds := make([][]string, 10000) 26 | for i := 0; i < len(gcmds); i++ { 27 | args := make([]string, (rand.Int()%50)+1) // 1-50 args 28 | for j := 0; j < len(args); j++ { 29 | n := rand.Int() % 10 30 | if j == 0 { 31 | n++ 32 | } 33 | arg := make([]byte, n) 34 | for k := 0; k < len(arg); k++ { 35 | arg[k] = byte(rand.Int() % 0xFF) 36 | } 37 | args[j] = string(arg) 38 | } 39 | gcmds[i] = args 40 | } 41 | // create a list of a buffers 42 | var bufs []string 43 | 44 | // pipe valid RESP commands 45 | for i := 0; i < len(gcmds); i++ { 46 | args := gcmds[i] 47 | msg := fmt.Sprintf("*%d\r\n", len(args)) 48 | for j := 0; j < len(args); j++ { 49 | msg += fmt.Sprintf("$%d\r\n%s\r\n", len(args[j]), args[j]) 50 | } 51 | bufs = append(bufs, msg) 52 | } 53 | bufs = append(bufs, "RESET THE INDEX\r\n") 54 | 55 | // pipe valid plain commands 56 | for i := 0; i < len(gcmds); i++ { 57 | args := gcmds[i] 58 | var msg string 59 | for j := 0; j < len(args); j++ { 60 | quotes := false 61 | var narg []byte 62 | arg := args[j] 63 | if len(arg) == 0 { 64 | quotes = true 65 | } 66 | for k := 0; k < len(arg); k++ { 67 | switch arg[k] { 68 | default: 69 | narg = append(narg, arg[k]) 70 | case ' ': 71 | quotes = true 72 | narg = 
append(narg, arg[k]) 73 | case '\\', '"', '*': 74 | quotes = true 75 | narg = append(narg, '\\', arg[k]) 76 | case '\r': 77 | quotes = true 78 | narg = append(narg, '\\', 'r') 79 | case '\n': 80 | quotes = true 81 | narg = append(narg, '\\', 'n') 82 | } 83 | } 84 | msg += " " 85 | if quotes { 86 | msg += "\"" 87 | } 88 | msg += string(narg) 89 | if quotes { 90 | msg += "\"" 91 | } 92 | } 93 | if msg != "" { 94 | msg = msg[1:] 95 | } 96 | msg += "\r\n" 97 | bufs = append(bufs, msg) 98 | } 99 | bufs = append(bufs, "RESET THE INDEX\r\n") 100 | 101 | // pipe valid RESP commands in broken chunks 102 | lmsg := "" 103 | for i := 0; i < len(gcmds); i++ { 104 | args := gcmds[i] 105 | msg := fmt.Sprintf("*%d\r\n", len(args)) 106 | for j := 0; j < len(args); j++ { 107 | msg += fmt.Sprintf("$%d\r\n%s\r\n", len(args[j]), args[j]) 108 | } 109 | msg = lmsg + msg 110 | if len(msg) > 0 { 111 | lmsg = msg[len(msg)/2:] 112 | msg = msg[:len(msg)/2] 113 | } 114 | bufs = append(bufs, msg) 115 | } 116 | bufs = append(bufs, lmsg) 117 | bufs = append(bufs, "RESET THE INDEX\r\n") 118 | 119 | // pipe valid RESP commands in large broken chunks 120 | lmsg = "" 121 | for i := 0; i < len(gcmds); i++ { 122 | args := gcmds[i] 123 | msg := fmt.Sprintf("*%d\r\n", len(args)) 124 | for j := 0; j < len(args); j++ { 125 | msg += fmt.Sprintf("$%d\r\n%s\r\n", len(args[j]), args[j]) 126 | } 127 | if len(lmsg) < 1500 { 128 | lmsg += msg 129 | continue 130 | } 131 | msg = lmsg + msg 132 | if len(msg) > 0 { 133 | lmsg = msg[len(msg)/2:] 134 | msg = msg[:len(msg)/2] 135 | } 136 | bufs = append(bufs, msg) 137 | } 138 | bufs = append(bufs, lmsg) 139 | bufs = append(bufs, "RESET THE INDEX\r\n") 140 | 141 | // Pipe the buffers in a background routine 142 | rd, wr := io.Pipe() 143 | go func() { 144 | defer wr.Close() 145 | for _, msg := range bufs { 146 | io.WriteString(wr, msg) 147 | } 148 | }() 149 | defer rd.Close() 150 | cnt := 0 151 | idx := 0 152 | start := time.Now() 153 | r := NewReader(rd) 154 | for { 155 
| cmd, err := r.ReadCommand() 156 | if err != nil { 157 | if err == io.EOF { 158 | break 159 | } 160 | log.Fatal(err) 161 | } 162 | if len(cmd.Args) == 3 && string(cmd.Args[0]) == "RESET" && 163 | string(cmd.Args[1]) == "THE" && string(cmd.Args[2]) == "INDEX" { 164 | if idx != len(gcmds) { 165 | t.Fatalf("did not process all commands") 166 | } 167 | idx = 0 168 | break 169 | } 170 | if len(cmd.Args) != len(gcmds[idx]) { 171 | t.Fatalf("len not equal for index %d -- %d != %d", idx, len(cmd.Args), len(gcmds[idx])) 172 | } 173 | for i := 0; i < len(cmd.Args); i++ { 174 | if i == 0 { 175 | if len(cmd.Args[i]) == len(gcmds[idx][i]) { 176 | ok := true 177 | for j := 0; j < len(cmd.Args[i]); j++ { 178 | c1, c2 := cmd.Args[i][j], gcmds[idx][i][j] 179 | if c1 >= 'A' && c1 <= 'Z' { 180 | c1 += 32 181 | } 182 | if c2 >= 'A' && c2 <= 'Z' { 183 | c2 += 32 184 | } 185 | if c1 != c2 { 186 | ok = false 187 | break 188 | } 189 | } 190 | if ok { 191 | continue 192 | } 193 | } 194 | } else if string(cmd.Args[i]) == string(gcmds[idx][i]) { 195 | continue 196 | } 197 | t.Fatalf("not equal for index %d/%d", idx, i) 198 | } 199 | idx++ 200 | cnt++ 201 | } 202 | if false { 203 | dur := time.Since(start) 204 | fmt.Printf("%d commands in %s - %.0f ops/sec\n", cnt, dur, float64(cnt)/(float64(dur)/float64(time.Second))) 205 | } 206 | } 207 | func testDetached(conn DetachedConn) { 208 | conn.WriteString("DETACHED") 209 | if err := conn.Flush(); err != nil { 210 | panic(err) 211 | } 212 | } 213 | func TestServerTCP(t *testing.T) { 214 | testServerNetwork(t, "tcp", ":12345") 215 | } 216 | func TestServerUnix(t *testing.T) { 217 | os.RemoveAll("/tmp/redcon-unix.sock") 218 | defer os.RemoveAll("/tmp/redcon-unix.sock") 219 | testServerNetwork(t, "unix", "/tmp/redcon-unix.sock") 220 | } 221 | 222 | func testServerNetwork(t *testing.T, network, laddr string) { 223 | s := NewServerNetwork(network, laddr, 224 | func(conn Conn, cmd Command) { 225 | switch strings.ToLower(string(cmd.Args[0])) { 226 | 
default: 227 | conn.WriteError("ERR unknown command '" + string(cmd.Args[0]) + "'") 228 | case "ping": 229 | conn.WriteString("PONG") 230 | case "quit": 231 | conn.WriteString("OK") 232 | conn.Close() 233 | case "detach": 234 | go testDetached(conn.Detach()) 235 | case "int": 236 | conn.WriteInt(100) 237 | case "bulk": 238 | conn.WriteBulkString("bulk") 239 | case "bulkbytes": 240 | conn.WriteBulk([]byte("bulkbytes")) 241 | case "null": 242 | conn.WriteNull() 243 | case "err": 244 | conn.WriteError("ERR error") 245 | case "array": 246 | conn.WriteArray(2) 247 | conn.WriteInt(99) 248 | conn.WriteString("Hi!") 249 | } 250 | }, 251 | func(conn Conn) bool { 252 | //log.Printf("accept: %s", conn.RemoteAddr()) 253 | return true 254 | }, 255 | func(conn Conn, err error) { 256 | //log.Printf("closed: %s [%v]", conn.RemoteAddr(), err) 257 | }, 258 | ) 259 | if err := s.Close(); err == nil { 260 | t.Fatalf("expected an error, should not be able to close before serving") 261 | } 262 | go func() { 263 | time.Sleep(time.Second / 4) 264 | if err := ListenAndServeNetwork(network, laddr, func(conn Conn, cmd Command) {}, nil, nil); err == nil { 265 | panic("expected an error, should not be able to listen on the same port") 266 | } 267 | time.Sleep(time.Second / 4) 268 | 269 | err := s.Close() 270 | if err != nil { 271 | panic(err) 272 | } 273 | err = s.Close() 274 | if err == nil { 275 | panic("expected an error") 276 | } 277 | }() 278 | done := make(chan bool) 279 | signal := make(chan error) 280 | go func() { 281 | defer func() { 282 | done <- true 283 | }() 284 | err := <-signal 285 | if err != nil { 286 | panic(err) 287 | } 288 | c, err := net.Dial(network, laddr) 289 | if err != nil { 290 | panic(err) 291 | } 292 | defer c.Close() 293 | do := func(cmd string) (string, error) { 294 | io.WriteString(c, cmd) 295 | buf := make([]byte, 1024) 296 | n, err := c.Read(buf) 297 | if err != nil { 298 | return "", err 299 | } 300 | return string(buf[:n]), nil 301 | } 302 | res, err := 
do("PING\r\n") 303 | if err != nil { 304 | panic(err) 305 | } 306 | if res != "+PONG\r\n" { 307 | panic(fmt.Sprintf("expecting '+PONG\r\n', got '%v'", res)) 308 | } 309 | res, err = do("BULK\r\n") 310 | if err != nil { 311 | panic(err) 312 | } 313 | if res != "$4\r\nbulk\r\n" { 314 | panic(fmt.Sprintf("expecting bulk, got '%v'", res)) 315 | } 316 | res, err = do("BULKBYTES\r\n") 317 | if err != nil { 318 | panic(err) 319 | } 320 | if res != "$9\r\nbulkbytes\r\n" { 321 | panic(fmt.Sprintf("expecting bulkbytes, got '%v'", res)) 322 | } 323 | res, err = do("INT\r\n") 324 | if err != nil { 325 | panic(err) 326 | } 327 | if res != ":100\r\n" { 328 | panic(fmt.Sprintf("expecting int, got '%v'", res)) 329 | } 330 | res, err = do("NULL\r\n") 331 | if err != nil { 332 | panic(err) 333 | } 334 | if res != "$-1\r\n" { 335 | panic(fmt.Sprintf("expecting nul, got '%v'", res)) 336 | } 337 | res, err = do("ARRAY\r\n") 338 | if err != nil { 339 | panic(err) 340 | } 341 | if res != "*2\r\n:99\r\n+Hi!\r\n" { 342 | panic(fmt.Sprintf("expecting array, got '%v'", res)) 343 | } 344 | res, err = do("ERR\r\n") 345 | if err != nil { 346 | panic(err) 347 | } 348 | if res != "-ERR error\r\n" { 349 | panic(fmt.Sprintf("expecting array, got '%v'", res)) 350 | } 351 | res, err = do("DETACH\r\n") 352 | if err != nil { 353 | panic(err) 354 | } 355 | if res != "+DETACHED\r\n" { 356 | panic(fmt.Sprintf("expecting string, got '%v'", res)) 357 | } 358 | }() 359 | go func() { 360 | err := s.ListenServeAndSignal(signal) 361 | if err != nil { 362 | panic(err) 363 | } 364 | }() 365 | <-done 366 | } 367 | 368 | func TestConnImpl(t *testing.T) { 369 | var i interface{} = &conn{} 370 | if _, ok := i.(Conn); !ok { 371 | t.Fatalf("conn does not implement Conn interface") 372 | } 373 | } 374 | 375 | func TestWriteBulkFrom(t *testing.T) { 376 | wbuf := &bytes.Buffer{} 377 | wr := NewWriter(wbuf) 378 | rbuf := &bytes.Buffer{} 379 | testStr := "hello world" 380 | rbuf.WriteString(testStr) 381 | 
wr.WriteBulkFrom(int64(len(testStr)), rbuf) 382 | wr.Flush() 383 | if wbuf.String() != fmt.Sprintf("$%d\r\n%s\r\n", len(testStr), testStr) { 384 | t.Fatal("failed") 385 | } 386 | wbuf.Reset() 387 | testStr1 := "hi world" 388 | rbuf.WriteString(testStr1) 389 | wr.WriteBulkFrom(int64(len(testStr1)), rbuf) 390 | wr.Flush() 391 | if wbuf.String() != fmt.Sprintf("$%d\r\n%s\r\n", len(testStr1), testStr1) { 392 | t.Fatal("failed") 393 | } 394 | wbuf.Reset() 395 | } 396 | 397 | func TestWriter(t *testing.T) { 398 | buf := &bytes.Buffer{} 399 | wr := NewWriter(buf) 400 | wr.WriteError("ERR bad stuff") 401 | wr.Flush() 402 | if buf.String() != "-ERR bad stuff\r\n" { 403 | t.Fatal("failed") 404 | } 405 | buf.Reset() 406 | wr.WriteString("HELLO") 407 | wr.Flush() 408 | if buf.String() != "+HELLO\r\n" { 409 | t.Fatal("failed") 410 | } 411 | buf.Reset() 412 | wr.WriteInt(-1234) 413 | wr.Flush() 414 | if buf.String() != ":-1234\r\n" { 415 | t.Fatal("failed") 416 | } 417 | buf.Reset() 418 | wr.WriteNull() 419 | wr.Flush() 420 | if buf.String() != "$-1\r\n" { 421 | t.Fatal("failed") 422 | } 423 | buf.Reset() 424 | wr.WriteBulk([]byte("HELLO\r\nPLANET")) 425 | wr.Flush() 426 | if buf.String() != "$13\r\nHELLO\r\nPLANET\r\n" { 427 | t.Fatal("failed") 428 | } 429 | buf.Reset() 430 | wr.WriteBulkString("HELLO\r\nPLANET") 431 | wr.Flush() 432 | if buf.String() != "$13\r\nHELLO\r\nPLANET\r\n" { 433 | t.Fatal("failed") 434 | } 435 | buf.Reset() 436 | wr.WriteArray(3) 437 | wr.WriteBulkString("THIS") 438 | wr.WriteBulkString("THAT") 439 | wr.WriteString("THE OTHER THING") 440 | wr.Flush() 441 | if buf.String() != "*3\r\n$4\r\nTHIS\r\n$4\r\nTHAT\r\n+THE OTHER THING\r\n" { 442 | t.Fatal("failed") 443 | } 444 | buf.Reset() 445 | } 446 | func testMakeRawCommands(rawargs [][]string) []string { 447 | var rawcmds []string 448 | for i := 0; i < len(rawargs); i++ { 449 | rawcmd := "*" + strconv.FormatUint(uint64(len(rawargs[i])), 10) + "\r\n" 450 | for j := 0; j < len(rawargs[i]); j++ { 451 | 
rawcmd += "$" + strconv.FormatUint(uint64(len(rawargs[i][j])), 10) + "\r\n" 452 | rawcmd += rawargs[i][j] + "\r\n" 453 | } 454 | rawcmds = append(rawcmds, rawcmd) 455 | } 456 | return rawcmds 457 | } 458 | 459 | func TestReaderRespRandom(t *testing.T) { 460 | rand.Seed(time.Now().UnixNano()) 461 | for h := 0; h < 10000; h++ { 462 | var rawargs [][]string 463 | for i := 0; i < 100; i++ { 464 | // var args []string 465 | n := int(rand.Int() % 16) 466 | for j := 0; j < n; j++ { 467 | arg := make([]byte, rand.Int()%512) 468 | rand.Read(arg) 469 | // args = append(args, string(arg)) 470 | } 471 | } 472 | rawcmds := testMakeRawCommands(rawargs) 473 | data := strings.Join(rawcmds, "") 474 | rd := NewReader(bytes.NewBufferString(data)) 475 | for i := 0; i < len(rawcmds); i++ { 476 | if len(rawargs[i]) == 0 { 477 | continue 478 | } 479 | cmd, err := rd.ReadCommand() 480 | if err != nil { 481 | t.Fatal(err) 482 | } 483 | if string(cmd.Raw) != rawcmds[i] { 484 | t.Fatalf("expected '%v', got '%v'", rawcmds[i], string(cmd.Raw)) 485 | } 486 | if len(cmd.Args) != len(rawargs[i]) { 487 | t.Fatalf("expected '%v', got '%v'", len(rawargs[i]), len(cmd.Args)) 488 | } 489 | for j := 0; j < len(rawargs[i]); j++ { 490 | if string(cmd.Args[j]) != rawargs[i][j] { 491 | t.Fatalf("expected '%v', got '%v'", rawargs[i][j], string(cmd.Args[j])) 492 | } 493 | } 494 | } 495 | } 496 | } 497 | 498 | func TestPlainReader(t *testing.T) { 499 | rawargs := [][]string{ 500 | {"HELLO", "WORLD"}, 501 | {"HELLO", "WORLD"}, 502 | {"HELLO", "PLANET"}, 503 | {"HELLO", "JELLO"}, 504 | {"HELLO ", "JELLO"}, 505 | } 506 | rawcmds := []string{ 507 | "HELLO WORLD\n", 508 | "HELLO WORLD\r\n", 509 | " HELLO PLANET \r\n", 510 | " \"HELLO\" \"JELLO\" \r\n", 511 | " \"HELLO \" JELLO \n", 512 | } 513 | rawres := []string{ 514 | "*2\r\n$5\r\nHELLO\r\n$5\r\nWORLD\r\n", 515 | "*2\r\n$5\r\nHELLO\r\n$5\r\nWORLD\r\n", 516 | "*2\r\n$5\r\nHELLO\r\n$6\r\nPLANET\r\n", 517 | "*2\r\n$5\r\nHELLO\r\n$5\r\nJELLO\r\n", 518 | 
"*2\r\n$6\r\nHELLO \r\n$5\r\nJELLO\r\n", 519 | } 520 | data := strings.Join(rawcmds, "") 521 | rd := NewReader(bytes.NewBufferString(data)) 522 | for i := 0; i < len(rawcmds); i++ { 523 | if len(rawargs[i]) == 0 { 524 | continue 525 | } 526 | cmd, err := rd.ReadCommand() 527 | if err != nil { 528 | t.Fatal(err) 529 | } 530 | if string(cmd.Raw) != rawres[i] { 531 | t.Fatalf("expected '%v', got '%v'", rawres[i], string(cmd.Raw)) 532 | } 533 | if len(cmd.Args) != len(rawargs[i]) { 534 | t.Fatalf("expected '%v', got '%v'", len(rawargs[i]), len(cmd.Args)) 535 | } 536 | for j := 0; j < len(rawargs[i]); j++ { 537 | if string(cmd.Args[j]) != rawargs[i][j] { 538 | t.Fatalf("expected '%v', got '%v'", rawargs[i][j], string(cmd.Args[j])) 539 | } 540 | } 541 | } 542 | } 543 | 544 | func TestParse(t *testing.T) { 545 | _, err := Parse(nil) 546 | if err != errIncompleteCommand { 547 | t.Fatalf("expected '%v', got '%v'", errIncompleteCommand, err) 548 | } 549 | _, err = Parse([]byte("*1\r\n")) 550 | if err != errIncompleteCommand { 551 | t.Fatalf("expected '%v', got '%v'", errIncompleteCommand, err) 552 | } 553 | _, err = Parse([]byte("*-1\r\n")) 554 | if err != errInvalidMultiBulkLength { 555 | t.Fatalf("expected '%v', got '%v'", errInvalidMultiBulkLength, err) 556 | } 557 | _, err = Parse([]byte("*0\r\n")) 558 | if err != errInvalidMultiBulkLength { 559 | t.Fatalf("expected '%v', got '%v'", errInvalidMultiBulkLength, err) 560 | } 561 | cmd, err := Parse([]byte("*1\r\n$1\r\nA\r\n")) 562 | if err != nil { 563 | t.Fatal(err) 564 | } 565 | if string(cmd.Raw) != "*1\r\n$1\r\nA\r\n" { 566 | t.Fatalf("expected '%v', got '%v'", "*1\r\n$1\r\nA\r\n", string(cmd.Raw)) 567 | } 568 | if len(cmd.Args) != 1 { 569 | t.Fatalf("expected '%v', got '%v'", 1, len(cmd.Args)) 570 | } 571 | if string(cmd.Args[0]) != "A" { 572 | t.Fatalf("expected '%v', got '%v'", "A", string(cmd.Args[0])) 573 | } 574 | cmd, err = Parse([]byte("A\r\n")) 575 | if err != nil { 576 | t.Fatal(err) 577 | } 578 | if 
string(cmd.Raw) != "*1\r\n$1\r\nA\r\n" { 579 | t.Fatalf("expected '%v', got '%v'", "*1\r\n$1\r\nA\r\n", string(cmd.Raw)) 580 | } 581 | if len(cmd.Args) != 1 { 582 | t.Fatalf("expected '%v', got '%v'", 1, len(cmd.Args)) 583 | } 584 | if string(cmd.Args[0]) != "A" { 585 | t.Fatalf("expected '%v', got '%v'", "A", string(cmd.Args[0])) 586 | } 587 | } 588 | 589 | func TestPubSub(t *testing.T) { 590 | addr := ":12346" 591 | done := make(chan bool) 592 | go func() { 593 | var ps PubSub 594 | go func() { 595 | tch := time.NewTicker(time.Millisecond * 5) 596 | defer tch.Stop() 597 | channels := []string{"achan1", "bchan2", "cchan3", "dchan4"} 598 | for i := 0; ; i++ { 599 | select { 600 | case <-tch.C: 601 | case <-done: 602 | for { 603 | var empty bool 604 | ps.mu.Lock() 605 | if len(ps.conns) == 0 { 606 | if ps.chans.Len() != 0 { 607 | panic("chans not empty") 608 | } 609 | empty = true 610 | } 611 | ps.mu.Unlock() 612 | if empty { 613 | break 614 | } 615 | time.Sleep(time.Millisecond * 10) 616 | } 617 | done <- true 618 | return 619 | } 620 | channel := channels[i%len(channels)] 621 | message := fmt.Sprintf("message %d", i) 622 | ps.Publish(channel, message) 623 | } 624 | }() 625 | panic(ListenAndServe(addr, func(conn Conn, cmd Command) { 626 | switch strings.ToLower(string(cmd.Args[0])) { 627 | default: 628 | conn.WriteError("ERR unknown command '" + 629 | string(cmd.Args[0]) + "'") 630 | case "publish": 631 | if len(cmd.Args) != 3 { 632 | conn.WriteError("ERR wrong number of arguments for '" + 633 | string(cmd.Args[0]) + "' command") 634 | return 635 | } 636 | count := ps.Publish(string(cmd.Args[1]), string(cmd.Args[2])) 637 | conn.WriteInt(count) 638 | case "subscribe", "psubscribe": 639 | if len(cmd.Args) < 2 { 640 | conn.WriteError("ERR wrong number of arguments for '" + 641 | string(cmd.Args[0]) + "' command") 642 | return 643 | } 644 | command := strings.ToLower(string(cmd.Args[0])) 645 | for i := 1; i < len(cmd.Args); i++ { 646 | if command == "psubscribe" { 647 
| ps.Psubscribe(conn, string(cmd.Args[i])) 648 | } else { 649 | ps.Subscribe(conn, string(cmd.Args[i])) 650 | } 651 | } 652 | } 653 | }, nil, nil)) 654 | }() 655 | 656 | final := make(chan bool) 657 | go func() { 658 | select { 659 | case <-time.Tick(time.Second * 30): 660 | panic("timeout") 661 | case <-final: 662 | return 663 | } 664 | }() 665 | 666 | // create 10 connections 667 | var wg sync.WaitGroup 668 | wg.Add(10) 669 | for i := 0; i < 10; i++ { 670 | go func(i int) { 671 | defer wg.Done() 672 | var conn net.Conn 673 | for i := 0; i < 5; i++ { 674 | var err error 675 | conn, err = net.Dial("tcp", addr) 676 | if err != nil { 677 | time.Sleep(time.Second / 10) 678 | continue 679 | } 680 | } 681 | if conn == nil { 682 | panic("could not connect to server") 683 | } 684 | defer conn.Close() 685 | 686 | regs := make(map[string]int) 687 | var maxp int 688 | var maxs int 689 | fmt.Fprintf(conn, "subscribe achan1\r\n") 690 | fmt.Fprintf(conn, "subscribe bchan2 cchan3\r\n") 691 | fmt.Fprintf(conn, "psubscribe a*1\r\n") 692 | fmt.Fprintf(conn, "psubscribe b*2 c*3\r\n") 693 | 694 | // collect 50 messages from each channel 695 | rd := bufio.NewReader(conn) 696 | var buf []byte 697 | for { 698 | line, err := rd.ReadBytes('\n') 699 | if err != nil { 700 | panic(err) 701 | } 702 | buf = append(buf, line...) 
703 | n, resp := ReadNextRESP(buf) 704 | if n == 0 { 705 | continue 706 | } 707 | buf = nil 708 | if resp.Type != Array { 709 | panic("expected array") 710 | } 711 | var vals []RESP 712 | resp.ForEach(func(item RESP) bool { 713 | vals = append(vals, item) 714 | return true 715 | }) 716 | 717 | name := string(vals[0].Data) 718 | switch name { 719 | case "subscribe": 720 | if len(vals) != 3 { 721 | panic("invalid count") 722 | } 723 | ch := string(vals[1].Data) 724 | regs[ch] = 0 725 | maxs, _ = strconv.Atoi(string(vals[2].Data)) 726 | case "psubscribe": 727 | if len(vals) != 3 { 728 | panic("invalid count") 729 | } 730 | ch := string(vals[1].Data) 731 | regs[ch] = 0 732 | maxp, _ = strconv.Atoi(string(vals[2].Data)) 733 | case "message": 734 | if len(vals) != 3 { 735 | panic("invalid count") 736 | } 737 | ch := string(vals[1].Data) 738 | regs[ch] = regs[ch] + 1 739 | case "pmessage": 740 | if len(vals) != 4 { 741 | panic("invalid count") 742 | } 743 | ch := string(vals[1].Data) 744 | regs[ch] = regs[ch] + 1 745 | } 746 | if len(regs) == 6 && maxp == 3 && maxs == 3 { 747 | ready := true 748 | for _, count := range regs { 749 | if count < 50 { 750 | ready = false 751 | break 752 | } 753 | } 754 | if ready { 755 | // all messages have been received 756 | return 757 | } 758 | } 759 | } 760 | }(i) 761 | } 762 | wg.Wait() 763 | // notify sender 764 | done <- true 765 | // wait for sender 766 | <-done 767 | // stop the timeout 768 | final <- true 769 | } 770 | -------------------------------------------------------------------------------- /resp.go: -------------------------------------------------------------------------------- 1 | package redcon 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "sort" 7 | "strconv" 8 | "strings" 9 | ) 10 | 11 | // Type of RESP 12 | type Type byte 13 | 14 | // Various RESP kinds 15 | const ( 16 | Integer = ':' 17 | String = '+' 18 | Bulk = '$' 19 | Array = '*' 20 | Error = '-' 21 | ) 22 | 23 | type RESP struct { 24 | Type Type 25 | Raw []byte 
26 | Data []byte 27 | Count int 28 | } 29 | 30 | // ForEach iterates over each Array element 31 | func (r RESP) ForEach(iter func(resp RESP) bool) { 32 | data := r.Data 33 | for i := 0; i < r.Count; i++ { 34 | n, resp := ReadNextRESP(data) 35 | if !iter(resp) { 36 | return 37 | } 38 | data = data[n:] 39 | } 40 | } 41 | 42 | func (r RESP) Bytes() []byte { 43 | return r.Data 44 | } 45 | 46 | func (r RESP) String() string { 47 | return string(r.Data) 48 | } 49 | 50 | func (r RESP) Int() int64 { 51 | x, _ := strconv.ParseInt(r.String(), 10, 64) 52 | return x 53 | } 54 | 55 | func (r RESP) Float() float64 { 56 | x, _ := strconv.ParseFloat(r.String(), 10) 57 | return x 58 | } 59 | 60 | // Map returns a key/value map of an Array. 61 | // The receiver RESP must be an Array with an equal number of values, where 62 | // the value of the key is followed by the key. 63 | // Example: key1,value1,key2,value2,key3,value3 64 | func (r RESP) Map() map[string]RESP { 65 | if r.Type != Array { 66 | return nil 67 | } 68 | var n int 69 | var key string 70 | m := make(map[string]RESP) 71 | r.ForEach(func(resp RESP) bool { 72 | if n&1 == 0 { 73 | key = resp.String() 74 | } else { 75 | m[key] = resp 76 | } 77 | n++ 78 | return true 79 | }) 80 | return m 81 | } 82 | 83 | func (r RESP) MapGet(key string) RESP { 84 | if r.Type != Array { 85 | return RESP{} 86 | } 87 | var val RESP 88 | var n int 89 | var ok bool 90 | r.ForEach(func(resp RESP) bool { 91 | if n&1 == 0 { 92 | ok = resp.String() == key 93 | } else if ok { 94 | val = resp 95 | return false 96 | } 97 | n++ 98 | return true 99 | }) 100 | return val 101 | } 102 | 103 | func (r RESP) Exists() bool { 104 | return r.Type != 0 105 | } 106 | 107 | // ReadNextRESP returns the next resp in b and returns the number of bytes the 108 | // took up the result. 
109 | func ReadNextRESP(b []byte) (n int, resp RESP) { 110 | if len(b) == 0 { 111 | return 0, RESP{} // no data to read 112 | } 113 | resp.Type = Type(b[0]) 114 | switch resp.Type { 115 | case Integer, String, Bulk, Array, Error: 116 | default: 117 | return 0, RESP{} // invalid kind 118 | } 119 | // read to end of line 120 | i := 1 121 | for ; ; i++ { 122 | if i == len(b) { 123 | return 0, RESP{} // not enough data 124 | } 125 | if b[i] == '\n' { 126 | if b[i-1] != '\r' { 127 | return 0, RESP{} //, missing CR character 128 | } 129 | i++ 130 | break 131 | } 132 | } 133 | resp.Raw = b[0:i] 134 | resp.Data = b[1 : i-2] 135 | if resp.Type == Integer { 136 | // Integer 137 | if len(resp.Data) == 0 { 138 | return 0, RESP{} //, invalid integer 139 | } 140 | var j int 141 | if resp.Data[0] == '-' { 142 | if len(resp.Data) == 1 { 143 | return 0, RESP{} //, invalid integer 144 | } 145 | j++ 146 | } 147 | for ; j < len(resp.Data); j++ { 148 | if resp.Data[j] < '0' || resp.Data[j] > '9' { 149 | return 0, RESP{} // invalid integer 150 | } 151 | } 152 | return len(resp.Raw), resp 153 | } 154 | if resp.Type == String || resp.Type == Error { 155 | // String, Error 156 | return len(resp.Raw), resp 157 | } 158 | var err error 159 | resp.Count, err = strconv.Atoi(string(resp.Data)) 160 | if resp.Type == Bulk { 161 | // Bulk 162 | if err != nil { 163 | return 0, RESP{} // invalid number of bytes 164 | } 165 | if resp.Count < 0 { 166 | resp.Data = nil 167 | resp.Count = 0 168 | return len(resp.Raw), resp 169 | } 170 | if len(b) < i+resp.Count+2 { 171 | return 0, RESP{} // not enough data 172 | } 173 | if b[i+resp.Count] != '\r' || b[i+resp.Count+1] != '\n' { 174 | return 0, RESP{} // invalid end of line 175 | } 176 | resp.Data = b[i : i+resp.Count] 177 | resp.Raw = b[0 : i+resp.Count+2] 178 | resp.Count = 0 179 | return len(resp.Raw), resp 180 | } 181 | // Array 182 | if err != nil { 183 | return 0, RESP{} // invalid number of elements 184 | } 185 | var tn int 186 | sdata := b[i:] 187 
| for j := 0; j < resp.Count; j++ { 188 | rn, rresp := ReadNextRESP(sdata) 189 | if rresp.Type == 0 { 190 | return 0, RESP{} 191 | } 192 | tn += rn 193 | sdata = sdata[rn:] 194 | } 195 | resp.Data = b[i : i+tn] 196 | resp.Raw = b[0 : i+tn] 197 | return len(resp.Raw), resp 198 | } 199 | 200 | // Kind is the kind of command 201 | type Kind int 202 | 203 | const ( 204 | // Redis is returned for Redis protocol commands 205 | Redis Kind = iota 206 | // Tile38 is returnd for Tile38 native protocol commands 207 | Tile38 208 | // Telnet is returnd for plain telnet commands 209 | Telnet 210 | ) 211 | 212 | var errInvalidMessage = &errProtocol{"invalid message"} 213 | 214 | // ReadNextCommand reads the next command from the provided packet. It's 215 | // possible that the packet contains multiple commands, or zero commands 216 | // when the packet is incomplete. 217 | // 'argsbuf' is an optional reusable buffer and it can be nil. 218 | // 'complete' indicates that a command was read. false means no more commands. 219 | // 'args' are the output arguments for the command. 220 | // 'kind' is the type of command that was read. 221 | // 'leftover' is any remaining unused bytes which belong to the next command. 222 | // 'err' is returned when a protocol error was encountered. 
223 | func ReadNextCommand(packet []byte, argsbuf [][]byte) ( 224 | complete bool, args [][]byte, kind Kind, leftover []byte, err error, 225 | ) { 226 | args = argsbuf[:0] 227 | if len(packet) > 0 { 228 | if packet[0] != '*' { 229 | if packet[0] == '$' { 230 | return readTile38Command(packet, args) 231 | } 232 | return readTelnetCommand(packet, args) 233 | } 234 | // standard redis command 235 | for s, i := 1, 1; i < len(packet); i++ { 236 | if packet[i] == '\n' { 237 | if packet[i-1] != '\r' { 238 | return false, args[:0], Redis, packet, errInvalidMultiBulkLength 239 | } 240 | count, ok := parseInt(packet[s : i-1]) 241 | if !ok || count < 0 { 242 | return false, args[:0], Redis, packet, errInvalidMultiBulkLength 243 | } 244 | i++ 245 | if count == 0 { 246 | return true, args[:0], Redis, packet[i:], nil 247 | } 248 | nextArg: 249 | for j := 0; j < count; j++ { 250 | if i == len(packet) { 251 | break 252 | } 253 | if packet[i] != '$' { 254 | return false, args[:0], Redis, packet, 255 | &errProtocol{"expected '$', got '" + 256 | string(packet[i]) + "'"} 257 | } 258 | for s := i + 1; i < len(packet); i++ { 259 | if packet[i] == '\n' { 260 | if packet[i-1] != '\r' { 261 | return false, args[:0], Redis, packet, errInvalidBulkLength 262 | } 263 | n, ok := parseInt(packet[s : i-1]) 264 | if !ok || count <= 0 { 265 | return false, args[:0], Redis, packet, errInvalidBulkLength 266 | } 267 | i++ 268 | if len(packet)-i >= n+2 { 269 | if packet[i+n] != '\r' || packet[i+n+1] != '\n' { 270 | return false, args[:0], Redis, packet, errInvalidBulkLength 271 | } 272 | args = append(args, packet[i:i+n]) 273 | i += n + 2 274 | if j == count-1 { 275 | // done reading 276 | return true, args, Redis, packet[i:], nil 277 | } 278 | continue nextArg 279 | } 280 | break 281 | } 282 | } 283 | break 284 | } 285 | break 286 | } 287 | } 288 | } 289 | return false, args[:0], Redis, packet, nil 290 | } 291 | 292 | func readTile38Command(packet []byte, argsbuf [][]byte) ( 293 | complete bool, args 
[][]byte, kind Kind, leftover []byte, err error, 294 | ) { 295 | for i := 1; i < len(packet); i++ { 296 | if packet[i] == ' ' { 297 | n, ok := parseInt(packet[1:i]) 298 | if !ok || n < 0 { 299 | return false, args[:0], Tile38, packet, errInvalidMessage 300 | } 301 | i++ 302 | if len(packet) >= i+n+2 { 303 | if packet[i+n] != '\r' || packet[i+n+1] != '\n' { 304 | return false, args[:0], Tile38, packet, errInvalidMessage 305 | } 306 | line := packet[i : i+n] 307 | reading: 308 | for len(line) != 0 { 309 | if line[0] == '{' { 310 | // The native protocol cannot understand json boundaries so it assumes that 311 | // a json element must be at the end of the line. 312 | args = append(args, line) 313 | break 314 | } 315 | if line[0] == '"' && line[len(line)-1] == '"' { 316 | if len(args) > 0 && 317 | strings.ToLower(string(args[0])) == "set" && 318 | strings.ToLower(string(args[len(args)-1])) == "string" { 319 | // Setting a string value that is contained inside double quotes. 320 | // This is only because of the boundary issues of the native protocol. 
321 | args = append(args, line[1:len(line)-1]) 322 | break 323 | } 324 | } 325 | i := 0 326 | for ; i < len(line); i++ { 327 | if line[i] == ' ' { 328 | value := line[:i] 329 | if len(value) > 0 { 330 | args = append(args, value) 331 | } 332 | line = line[i+1:] 333 | continue reading 334 | } 335 | } 336 | args = append(args, line) 337 | break 338 | } 339 | return true, args, Tile38, packet[i+n+2:], nil 340 | } 341 | break 342 | } 343 | } 344 | return false, args[:0], Tile38, packet, nil 345 | } 346 | func readTelnetCommand(packet []byte, argsbuf [][]byte) ( 347 | complete bool, args [][]byte, kind Kind, leftover []byte, err error, 348 | ) { 349 | // just a plain text command 350 | for i := 0; i < len(packet); i++ { 351 | if packet[i] == '\n' { 352 | var line []byte 353 | if i > 0 && packet[i-1] == '\r' { 354 | line = packet[:i-1] 355 | } else { 356 | line = packet[:i] 357 | } 358 | var quote bool 359 | var quotech byte 360 | var escape bool 361 | outer: 362 | for { 363 | nline := make([]byte, 0, len(line)) 364 | for i := 0; i < len(line); i++ { 365 | c := line[i] 366 | if !quote { 367 | if c == ' ' { 368 | if len(nline) > 0 { 369 | args = append(args, nline) 370 | } 371 | line = line[i+1:] 372 | continue outer 373 | } 374 | if c == '"' || c == '\'' { 375 | if i != 0 { 376 | return false, args[:0], Telnet, packet, errUnbalancedQuotes 377 | } 378 | quotech = c 379 | quote = true 380 | line = line[i+1:] 381 | continue outer 382 | } 383 | } else { 384 | if escape { 385 | escape = false 386 | switch c { 387 | case 'n': 388 | c = '\n' 389 | case 'r': 390 | c = '\r' 391 | case 't': 392 | c = '\t' 393 | } 394 | } else if c == quotech { 395 | quote = false 396 | quotech = 0 397 | args = append(args, nline) 398 | line = line[i+1:] 399 | if len(line) > 0 && line[0] != ' ' { 400 | return false, args[:0], Telnet, packet, errUnbalancedQuotes 401 | } 402 | continue outer 403 | } else if c == '\\' { 404 | escape = true 405 | continue 406 | } 407 | } 408 | nline = append(nline, c) 
409 | } 410 | if quote { 411 | return false, args[:0], Telnet, packet, errUnbalancedQuotes 412 | } 413 | if len(line) > 0 { 414 | args = append(args, line) 415 | } 416 | break 417 | } 418 | return true, args, Telnet, packet[i+1:], nil 419 | } 420 | } 421 | return false, args[:0], Telnet, packet, nil 422 | } 423 | 424 | // appendPrefix will append a "$3\r\n" style redis prefix for a message. 425 | func appendPrefix(b []byte, c byte, n int64) []byte { 426 | if n >= 0 && n <= 9 { 427 | return append(b, c, byte('0'+n), '\r', '\n') 428 | } 429 | b = append(b, c) 430 | b = strconv.AppendInt(b, n, 10) 431 | return append(b, '\r', '\n') 432 | } 433 | 434 | // AppendUint appends a Redis protocol uint64 to the input bytes. 435 | func AppendUint(b []byte, n uint64) []byte { 436 | b = append(b, ':') 437 | b = strconv.AppendUint(b, n, 10) 438 | return append(b, '\r', '\n') 439 | } 440 | 441 | // AppendInt appends a Redis protocol int64 to the input bytes. 442 | func AppendInt(b []byte, n int64) []byte { 443 | return appendPrefix(b, ':', n) 444 | } 445 | 446 | // AppendArray appends a Redis protocol array to the input bytes. 447 | func AppendArray(b []byte, n int) []byte { 448 | return appendPrefix(b, '*', int64(n)) 449 | } 450 | 451 | // AppendBulk appends a Redis protocol bulk byte slice to the input bytes. 452 | func AppendBulk(b []byte, bulk []byte) []byte { 453 | b = appendPrefix(b, '$', int64(len(bulk))) 454 | b = append(b, bulk...) 455 | return append(b, '\r', '\n') 456 | } 457 | 458 | // AppendBulkString appends a Redis protocol bulk string to the input bytes. 459 | func AppendBulkString(b []byte, bulk string) []byte { 460 | b = appendPrefix(b, '$', int64(len(bulk))) 461 | b = append(b, bulk...) 462 | return append(b, '\r', '\n') 463 | } 464 | 465 | // AppendString appends a Redis protocol string to the input bytes. 466 | func AppendString(b []byte, s string) []byte { 467 | b = append(b, '+') 468 | b = append(b, stripNewlines(s)...) 
469 | return append(b, '\r', '\n') 470 | } 471 | 472 | // AppendError appends a Redis protocol error to the input bytes. 473 | func AppendError(b []byte, s string) []byte { 474 | b = append(b, '-') 475 | b = append(b, stripNewlines(s)...) 476 | return append(b, '\r', '\n') 477 | } 478 | 479 | // AppendOK appends a Redis protocol OK to the input bytes. 480 | func AppendOK(b []byte) []byte { 481 | return append(b, '+', 'O', 'K', '\r', '\n') 482 | } 483 | func stripNewlines(s string) string { 484 | for i := 0; i < len(s); i++ { 485 | if s[i] == '\r' || s[i] == '\n' { 486 | s = strings.Replace(s, "\r", " ", -1) 487 | s = strings.Replace(s, "\n", " ", -1) 488 | break 489 | } 490 | } 491 | return s 492 | } 493 | 494 | // AppendTile38 appends a Tile38 message to the input bytes. 495 | func AppendTile38(b []byte, data []byte) []byte { 496 | b = append(b, '$') 497 | b = strconv.AppendInt(b, int64(len(data)), 10) 498 | b = append(b, ' ') 499 | b = append(b, data...) 500 | return append(b, '\r', '\n') 501 | } 502 | 503 | // AppendNull appends a Redis protocol null to the input bytes. 504 | func AppendNull(b []byte) []byte { 505 | return append(b, '$', '-', '1', '\r', '\n') 506 | } 507 | 508 | // AppendBulkFloat appends a float64, as bulk bytes. 509 | func AppendBulkFloat(dst []byte, f float64) []byte { 510 | return AppendBulk(dst, strconv.AppendFloat(nil, f, 'f', -1, 64)) 511 | } 512 | 513 | // AppendBulkInt appends an int64, as bulk bytes. 514 | func AppendBulkInt(dst []byte, x int64) []byte { 515 | return AppendBulk(dst, strconv.AppendInt(nil, x, 10)) 516 | } 517 | 518 | // AppendBulkUint appends an uint64, as bulk bytes. 
519 | func AppendBulkUint(dst []byte, x uint64) []byte { 520 | return AppendBulk(dst, strconv.AppendUint(nil, x, 10)) 521 | } 522 | 523 | func prefixERRIfNeeded(msg string) string { 524 | msg = strings.TrimSpace(msg) 525 | firstWord := strings.Split(msg, " ")[0] 526 | addERR := len(firstWord) == 0 527 | for i := 0; i < len(firstWord); i++ { 528 | if firstWord[i] < 'A' || firstWord[i] > 'Z' { 529 | addERR = true 530 | break 531 | } 532 | } 533 | if addERR { 534 | msg = strings.TrimSpace("ERR " + msg) 535 | } 536 | return msg 537 | } 538 | 539 | // SimpleString is for representing a non-bulk representation of a string 540 | // from an *Any call. 541 | type SimpleString string 542 | 543 | // SimpleInt is for representing a non-bulk representation of a int 544 | // from an *Any call. 545 | type SimpleInt int 546 | 547 | // SimpleError is for representing an error without adding the "ERR" prefix 548 | // from an *Any call. 549 | type SimpleError error 550 | 551 | // Marshaler is the interface implemented by types that 552 | // can marshal themselves into a Redis response type from an *Any call. 553 | // The return value is not check for validity. 554 | type Marshaler interface { 555 | MarshalRESP() []byte 556 | } 557 | 558 | // AppendAny appends any type to valid Redis type. 
559 | // 560 | // nil -> null 561 | // error -> error (adds "ERR " when first word is not uppercase) 562 | // string -> bulk-string 563 | // numbers -> bulk-string 564 | // []byte -> bulk-string 565 | // bool -> bulk-string ("0" or "1") 566 | // slice -> array 567 | // map -> array with key/value pairs 568 | // SimpleString -> string 569 | // SimpleInt -> integer 570 | // Marshaler -> raw bytes 571 | // everything-else -> bulk-string representation using fmt.Sprint() 572 | func AppendAny(b []byte, v interface{}) []byte { 573 | switch v := v.(type) { 574 | case SimpleString: 575 | b = AppendString(b, string(v)) 576 | case SimpleInt: 577 | b = AppendInt(b, int64(v)) 578 | case SimpleError: 579 | b = AppendError(b, v.Error()) 580 | case nil: 581 | b = AppendNull(b) 582 | case error: 583 | b = AppendError(b, prefixERRIfNeeded(v.Error())) 584 | case string: 585 | b = AppendBulkString(b, v) 586 | case []byte: 587 | if v == nil { 588 | b = AppendNull(b) 589 | } else { 590 | b = AppendBulk(b, v) 591 | } 592 | case bool: 593 | if v { 594 | b = AppendBulkString(b, "1") 595 | } else { 596 | b = AppendBulkString(b, "0") 597 | } 598 | case int: 599 | b = AppendBulkInt(b, int64(v)) 600 | case int8: 601 | b = AppendBulkInt(b, int64(v)) 602 | case int16: 603 | b = AppendBulkInt(b, int64(v)) 604 | case int32: 605 | b = AppendBulkInt(b, int64(v)) 606 | case int64: 607 | b = AppendBulkInt(b, int64(v)) 608 | case uint: 609 | b = AppendBulkUint(b, uint64(v)) 610 | case uint8: 611 | b = AppendBulkUint(b, uint64(v)) 612 | case uint16: 613 | b = AppendBulkUint(b, uint64(v)) 614 | case uint32: 615 | b = AppendBulkUint(b, uint64(v)) 616 | case uint64: 617 | b = AppendBulkUint(b, uint64(v)) 618 | case float32: 619 | b = AppendBulkFloat(b, float64(v)) 620 | case float64: 621 | b = AppendBulkFloat(b, float64(v)) 622 | case Marshaler: 623 | b = append(b, v.MarshalRESP()...) 
624 | default: 625 | vv := reflect.ValueOf(v) 626 | switch vv.Kind() { 627 | case reflect.Slice: 628 | n := vv.Len() 629 | b = AppendArray(b, n) 630 | for i := 0; i < n; i++ { 631 | b = AppendAny(b, vv.Index(i).Interface()) 632 | } 633 | case reflect.Map: 634 | n := vv.Len() 635 | b = AppendArray(b, n*2) 636 | var i int 637 | var strKey bool 638 | var strsKeyItems []strKeyItem 639 | 640 | iter := vv.MapRange() 641 | for iter.Next() { 642 | key := iter.Key().Interface() 643 | if i == 0 { 644 | if _, ok := key.(string); ok { 645 | strKey = true 646 | strsKeyItems = make([]strKeyItem, n) 647 | } 648 | } 649 | if strKey { 650 | strsKeyItems[i] = strKeyItem{ 651 | key.(string), iter.Value().Interface(), 652 | } 653 | } else { 654 | b = AppendAny(b, key) 655 | b = AppendAny(b, iter.Value().Interface()) 656 | } 657 | i++ 658 | } 659 | if strKey { 660 | sort.Slice(strsKeyItems, func(i, j int) bool { 661 | return strsKeyItems[i].key < strsKeyItems[j].key 662 | }) 663 | for _, item := range strsKeyItems { 664 | b = AppendBulkString(b, item.key) 665 | b = AppendAny(b, item.value) 666 | } 667 | } 668 | default: 669 | b = AppendBulkString(b, fmt.Sprint(v)) 670 | } 671 | } 672 | return b 673 | } 674 | 675 | type strKeyItem struct { 676 | key string 677 | value interface{} 678 | } 679 | -------------------------------------------------------------------------------- /resp_test.go: -------------------------------------------------------------------------------- 1 | package redcon 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "math/rand" 7 | "strconv" 8 | "testing" 9 | "time" 10 | ) 11 | 12 | func isEmptyRESP(resp RESP) bool { 13 | return resp.Type == 0 && resp.Count == 0 && 14 | resp.Data == nil && resp.Raw == nil 15 | } 16 | 17 | func expectBad(t *testing.T, payload string) { 18 | t.Helper() 19 | n, resp := ReadNextRESP([]byte(payload)) 20 | if n > 0 || !isEmptyRESP(resp) { 21 | t.Fatalf("expected empty resp") 22 | } 23 | } 24 | 25 | func respVOut(a RESP) string { 26 | var data 
string 27 | var raw string 28 | if a.Data == nil { 29 | data = "nil" 30 | } else { 31 | data = strconv.Quote(string(a.Data)) 32 | } 33 | if a.Raw == nil { 34 | raw = "nil" 35 | } else { 36 | raw = strconv.Quote(string(a.Raw)) 37 | } 38 | return fmt.Sprintf("{Type: %d, Count: %d, Data: %s, Raw: %s}", 39 | a.Type, a.Count, data, raw, 40 | ) 41 | } 42 | 43 | func respEquals(a, b RESP) bool { 44 | if a.Count != b.Count { 45 | return false 46 | } 47 | if a.Type != b.Type { 48 | return false 49 | } 50 | if (a.Data == nil && b.Data != nil) || (a.Data != nil && b.Data == nil) { 51 | return false 52 | } 53 | if string(a.Data) != string(b.Data) { 54 | return false 55 | } 56 | if (a.Raw == nil && b.Raw != nil) || (a.Raw != nil && b.Raw == nil) { 57 | return false 58 | } 59 | if string(a.Raw) != string(b.Raw) { 60 | return false 61 | } 62 | return true 63 | } 64 | 65 | func expectGood(t *testing.T, payload string, exp RESP) { 66 | t.Helper() 67 | n, resp := ReadNextRESP([]byte(payload)) 68 | if n != len(payload) || isEmptyRESP(resp) { 69 | t.Fatalf("expected good resp") 70 | } 71 | if string(resp.Raw) != payload { 72 | t.Fatalf("expected '%s', got '%s'", payload, resp.Raw) 73 | } 74 | exp.Raw = []byte(payload) 75 | switch exp.Type { 76 | case Integer, String, Error: 77 | exp.Data = []byte(payload[1 : len(payload)-2]) 78 | } 79 | if !respEquals(resp, exp) { 80 | t.Fatalf("expected %v, got %v", respVOut(exp), respVOut(resp)) 81 | } 82 | } 83 | 84 | func TestRESP(t *testing.T) { 85 | expectBad(t, "") 86 | expectBad(t, "^hello\r\n") 87 | expectBad(t, "+hello\r") 88 | expectBad(t, "+hello\n") 89 | expectBad(t, ":\r\n") 90 | expectBad(t, ":-\r\n") 91 | expectBad(t, ":-abc\r\n") 92 | expectBad(t, ":abc\r\n") 93 | expectGood(t, ":-123\r\n", RESP{Type: Integer}) 94 | expectGood(t, ":123\r\n", RESP{Type: Integer}) 95 | expectBad(t, "+\r") 96 | expectBad(t, "+\n") 97 | expectGood(t, "+\r\n", RESP{Type: String}) 98 | expectGood(t, "+hello world\r\n", RESP{Type: String}) 99 | expectBad(t, 
"-\r") 100 | expectBad(t, "-\n") 101 | expectGood(t, "-\r\n", RESP{Type: Error}) 102 | expectGood(t, "-hello world\r\n", RESP{Type: Error}) 103 | expectBad(t, "$") 104 | expectBad(t, "$\r") 105 | expectBad(t, "$\r\n") 106 | expectGood(t, "$-1\r\n", RESP{Type: Bulk}) 107 | expectGood(t, "$0\r\n\r\n", RESP{Type: Bulk, Data: []byte("")}) 108 | expectBad(t, "$5\r\nhello\r") 109 | expectBad(t, "$5\r\nhello\n\n") 110 | expectGood(t, "$5\r\nhello\r\n", RESP{Type: Bulk, Data: []byte("hello")}) 111 | expectBad(t, "*a\r\n") 112 | expectBad(t, "*3\r\n") 113 | expectBad(t, "*3\r\n:hello\r") 114 | expectGood(t, "*3\r\n:1\r\n:2\r\n:3\r\n", 115 | RESP{Type: Array, Count: 3, Data: []byte(":1\r\n:2\r\n:3\r\n")}) 116 | 117 | var xx int 118 | _, r := ReadNextRESP([]byte("*4\r\n:1\r\n:2\r\n:3\r\n:4\r\n")) 119 | r.ForEach(func(resp RESP) bool { 120 | xx++ 121 | x, _ := strconv.Atoi(string(resp.Data)) 122 | if x != xx { 123 | t.Fatalf("expected %v, got %v", x, xx) 124 | } 125 | if xx == 3 { 126 | return false 127 | } 128 | return true 129 | }) 130 | if xx != 3 { 131 | t.Fatalf("expected %v, got %v", 3, xx) 132 | } 133 | } 134 | 135 | func TestNextCommand(t *testing.T) { 136 | rand.Seed(time.Now().UnixNano()) 137 | start := time.Now() 138 | for time.Since(start) < time.Second { 139 | // keep copy of pipeline args for final compare 140 | var plargs [][][]byte 141 | 142 | // create a pipeline of random number of commands with random data. 
143 | N := rand.Int() % 10000 144 | var data []byte 145 | for i := 0; i < N; i++ { 146 | nargs := rand.Int() % 10 147 | data = AppendArray(data, nargs) 148 | var args [][]byte 149 | for j := 0; j < nargs; j++ { 150 | arg := make([]byte, rand.Int()%100) 151 | if _, err := rand.Read(arg); err != nil { 152 | t.Fatal(err) 153 | } 154 | data = AppendBulk(data, arg) 155 | args = append(args, arg) 156 | } 157 | plargs = append(plargs, args) 158 | } 159 | 160 | // break data into random number of chunks 161 | chunkn := rand.Int() % 100 162 | if chunkn == 0 { 163 | chunkn = 1 164 | } 165 | if len(data) < chunkn { 166 | continue 167 | } 168 | var chunks [][]byte 169 | var chunksz int 170 | for i := 0; i < len(data); i += chunksz { 171 | chunksz = rand.Int() % (len(data) / chunkn) 172 | var chunk []byte 173 | if i+chunksz < len(data) { 174 | chunk = data[i : i+chunksz] 175 | } else { 176 | chunk = data[i:] 177 | } 178 | chunks = append(chunks, chunk) 179 | } 180 | 181 | // process chunks 182 | var rbuf []byte 183 | var fargs [][][]byte 184 | for _, chunk := range chunks { 185 | var data []byte 186 | if len(rbuf) > 0 { 187 | data = append(rbuf, chunk...) 188 | } else { 189 | data = chunk 190 | } 191 | for { 192 | complete, args, _, leftover, err := ReadNextCommand(data, nil) 193 | data = leftover 194 | if err != nil { 195 | t.Fatal(err) 196 | } 197 | if !complete { 198 | break 199 | } 200 | fargs = append(fargs, args) 201 | } 202 | rbuf = append(rbuf[:0], data...) 
203 | } 204 | // compare final args to original 205 | if len(plargs) != len(fargs) { 206 | t.Fatalf("not equal size: %v != %v", len(plargs), len(fargs)) 207 | } 208 | for i := 0; i < len(plargs); i++ { 209 | if len(plargs[i]) != len(fargs[i]) { 210 | t.Fatalf("not equal size for item %v: %v != %v", i, len(plargs[i]), len(fargs[i])) 211 | } 212 | for j := 0; j < len(plargs[i]); j++ { 213 | if !bytes.Equal(plargs[i][j], plargs[i][j]) { 214 | t.Fatalf("not equal for item %v:%v: %v != %v", i, j, len(plargs[i][j]), len(fargs[i][j])) 215 | } 216 | } 217 | } 218 | } 219 | } 220 | 221 | func TestAppendBulkFloat(t *testing.T) { 222 | var b []byte 223 | b = AppendString(b, "HELLO") 224 | b = AppendBulkFloat(b, 9.123192839) 225 | b = AppendString(b, "HELLO") 226 | exp := "+HELLO\r\n$11\r\n9.123192839\r\n+HELLO\r\n" 227 | if string(b) != exp { 228 | t.Fatalf("expected '%s', got '%s'", exp, b) 229 | } 230 | } 231 | 232 | func TestAppendBulkInt(t *testing.T) { 233 | var b []byte 234 | b = AppendString(b, "HELLO") 235 | b = AppendBulkInt(b, -9182739137) 236 | b = AppendString(b, "HELLO") 237 | exp := "+HELLO\r\n$11\r\n-9182739137\r\n+HELLO\r\n" 238 | if string(b) != exp { 239 | t.Fatalf("expected '%s', got '%s'", exp, b) 240 | } 241 | } 242 | 243 | func TestAppendBulkUint(t *testing.T) { 244 | var b []byte 245 | b = AppendString(b, "HELLO") 246 | b = AppendBulkInt(b, 91827391370) 247 | b = AppendString(b, "HELLO") 248 | exp := "+HELLO\r\n$11\r\n91827391370\r\n+HELLO\r\n" 249 | if string(b) != exp { 250 | t.Fatalf("expected '%s', got '%s'", exp, b) 251 | } 252 | } 253 | 254 | func TestArrayMap(t *testing.T) { 255 | var dst []byte 256 | dst = AppendArray(dst, 4) 257 | dst = AppendBulkString(dst, "key1") 258 | dst = AppendBulkString(dst, "val1") 259 | dst = AppendBulkString(dst, "key2") 260 | dst = AppendBulkString(dst, "val2") 261 | n, resp := ReadNextRESP(dst) 262 | if n != len(dst) { 263 | t.Fatalf("expected '%d', got '%d'", len(dst), n) 264 | } 265 | m := resp.Map() 266 | if 
len(m) != 2 { 267 | t.Fatalf("expected '%d', got '%d'", 2, len(m)) 268 | } 269 | if m["key1"].String() != "val1" { 270 | t.Fatalf("expected '%s', got '%s'", "val1", m["key1"].String()) 271 | } 272 | if m["key2"].String() != "val2" { 273 | t.Fatalf("expected '%s', got '%s'", "val2", m["key2"].String()) 274 | } 275 | } 276 | --------------------------------------------------------------------------------