├── bootstrap.sh ├── Makefile ├── cfg.json ├── README.md ├── .travis.yml ├── cmd └── main.go ├── util ├── conn.go ├── conn_pool_test.go ├── configure.go └── conn_pool.go ├── server.go ├── util.go ├── server_test.go └── proto.go /bootstrap.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | make clean 4 | 5 | go get github.com/garyburd/redigo/redis 6 | 7 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | all: build 2 | 3 | build: 4 | go build -o bin/proxy ./cmd 5 | 6 | clean: 7 | @rm -rf bin 8 | go clean -i ./... 9 | 10 | test: 11 | go test -v ./... 12 | -------------------------------------------------------------------------------- /cfg.json: -------------------------------------------------------------------------------- 1 | { 2 | "id":"1", 3 | "ip":"127.0.0.1", 4 | "port":"9000", 5 | "prof_port":"54321", 6 | "bucket_base":"2", 7 | "buckets":[0,0,1,1], 8 | "bucket_addr":{ 9 | "0":"127.0.0.1:6379", 10 | "1":"127.0.0.1:6380" 11 | } 12 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | minproxy 2 | ========== 3 | 4 | A simple proxy based high performance redis cluster solution written in Go. 5 | 6 | ToDo: 7 | Auto rebalance 8 | 9 | Features 10 | ========== 11 | 12 | * Supports most of Redis commands. 13 | * Supports proxying to multiple servers. 14 | 15 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - 1.3 5 | 6 | services: 7 | - redis-server --port 6379 8 | - redis-server --port 6380 9 | 10 | install: go get ./... && go build -v ./... 
// DefaultMinReadBufferSize is the size, in bytes, of the buffered reader that
// NewCon attaches to every connection.
const DefaultMinReadBufferSize = 1024

// Conn couples a TCP connection with a buffered reader over the same socket.
type Conn struct {
	// addr is the address that was passed to NewCon; Addr reports it.
	addr string
	// c is the underlying socket used for writes and socket options.
	c *net.TCPConn
	// R buffers reads from c; ReadBytes reads through it.
	R *bufio.Reader
}

// NewCon dials addr on the given network, waiting at most timeout, and wraps
// the result in a Conn.
//
// Only TCP networks are supported because Conn exposes TCP-specific options.
// If the dial succeeds but does not yield a *net.TCPConn (for example with
// "udp" or "unix"), the socket is closed and net.UnknownNetworkError(network)
// is returned instead of panicking on the type assertion.
func NewCon(network, addr string, timeout time.Duration) (*Conn, error) {
	c, err := net.DialTimeout(network, addr, timeout)
	if err != nil {
		return nil, err
	}

	tcp, ok := c.(*net.TCPConn)
	if !ok {
		// Do not leak the descriptor of a connection we cannot use.
		c.Close()
		return nil, net.UnknownNetworkError(network)
	}

	return &Conn{addr: addr, c: tcp, R: bufio.NewReaderSize(tcp, DefaultMinReadBufferSize)}, nil
}

// Write sends buf on the connection and returns the socket's write error, if
// any. The number of bytes written is discarded.
func (c *Conn) Write(buf []byte) (err error) {
	_, err = c.c.Write(buf)

	return err
}

// ReadBytes reads through the buffered reader until the first occurrence of
// p and returns the bytes read, including p.
func (c *Conn) ReadBytes(p byte) ([]byte, error) {
	return c.R.ReadBytes(p)
}

// SetReadDeadline sets the deadline for subsequent reads on the socket.
func (c *Conn) SetReadDeadline(t time.Time) error {
	return c.c.SetReadDeadline(t)
}

// SetKeepAlive turns TCP keep-alive on or off for the socket.
func (c *Conn) SetKeepAlive(b bool) error {
	return c.c.SetKeepAlive(b)
}

// SetNoDelay turns TCP_NODELAY on or off for the socket.
func (c *Conn) SetNoDelay(b bool) error {
	return c.c.SetNoDelay(b)
}

// Addr returns the address the connection was dialed with.
//
// The value receiver is kept so that existing callers holding a Conn value
// continue to compile.
func (c Conn) Addr() string {
	return c.addr
}

// Close closes the underlying socket. It is a no-op on a nil Conn or on a Conn
// that has no socket, so release paths can call it unconditionally.
func (c *Conn) Close() {
	if c == nil || c.c == nil {
		return
	}
	c.c.Close()
}
// Config is a read-only view over a decoded JSON object. Values keep the
// dynamic types produced by encoding/json: string, float64, bool,
// []interface{} and map[string]interface{}.
type Config struct {
	data map[string]interface{}
}

// newConfig returns an empty Config ready to be filled by a JSON decode.
func newConfig() *Config {
	return &Config{data: make(map[string]interface{})}
}

// LoadConfigFile loads config information from a JSON file. The process is
// terminated through log.Fatalf if the file cannot be read or decoded.
func LoadConfigFile(filename string) *Config {
	result := newConfig()
	if err := result.parse(filename); err != nil {
		log.Fatalf("error loading config file %s: %s", filename, err)
	}

	return result
}

// LoadConfigString loads config information from a JSON string. The process
// is terminated through log.Fatalf if the string cannot be decoded.
func LoadConfigString(s string) *Config {
	result := newConfig()
	if err := json.Unmarshal([]byte(s), &result.data); err != nil {
		log.Fatalf("error parsing config string %s: %s", s, err)
	}

	return result
}

// parse reads fileName and decodes its JSON content into c.
func (c *Config) parse(fileName string) error {
	jsonFileBytes, err := ioutil.ReadFile(fileName)
	if err != nil {
		return err
	}

	return json.Unmarshal(jsonFileBytes, &c.data)
}

// GetString returns the string stored under key, or "" when the key is
// missing or the value is not a string.
func (c *Config) GetString(key string) string {
	s, _ := c.data[key].(string)

	return s
}

// GetInt returns the integer stored under key. Both a decimal string ("9000")
// and an integral JSON number (9000) are accepted. It returns -1 when the key
// is missing or the value is not an integer.
func (c *Config) GetInt(key string) int {
	switch v := c.data[key].(type) {
	case string:
		if n, err := strconv.Atoi(v); err == nil {
			return n
		}
	case float64:
		// encoding/json decodes every JSON number as float64; accept it only
		// when it carries no fractional part.
		if n := int(v); float64(n) == v {
			return n
		}
	}

	return -1
}

// GetFloat returns the number stored under key, or -1 when the key is missing
// or the value is not a JSON number.
func (c *Config) GetFloat(key string) float64 {
	if f, ok := c.data[key].(float64); ok {
		return f
	}

	return -1
}

// GetBool returns the boolean stored under key, or false when the key is
// missing or the value is not a boolean.
func (c *Config) GetBool(key string) bool {
	b, _ := c.data[key].(bool)

	return b
}

// GetInterface returns the raw decoded value stored under key, or nil when
// the key is missing.
func (c *Config) GetInterface(key string) interface{} {
	return c.data[key]
}

// GetArray returns the array stored under key, or nil when the key is missing
// or the value is not a JSON array.
func (c *Config) GetArray(key string) []interface{} {
	arr, _ := c.data[key].([]interface{})

	return arr
}
chan *Conn 34 | } 35 | 36 | func NewConnPool() *ConnPool { 37 | return &ConnPool{unitPools: make(map[string]*UnitConnPool)} 38 | } 39 | 40 | func (connp *ConnPool) NewUnitPool(size int, addr string, timeout, trys int) (p *UnitConnPool, err error) { 41 | if size < 0 { 42 | size = DefaultSize 43 | } 44 | if trys <= 0 { 45 | trys = DefaultTrys 46 | } 47 | if addr == "" { 48 | return nil, ErrAddrEmpty 49 | } 50 | 51 | p = &UnitConnPool{addr: addr, size: size, timeout: timeout, trys: trys, pool: make(chan *Conn, size)} 52 | for i := 0; i < size; i++ { 53 | p.pool <- nil 54 | } 55 | if err = p.Ping(); err != nil { 56 | return 57 | } 58 | connp.SetUintPool(addr, p) 59 | 60 | return 61 | } 62 | 63 | func (p *UnitConnPool) Ping() (err error) { 64 | c, err := p.Get() 65 | if err != nil { 66 | return 67 | } 68 | defer p.Put(c) 69 | 70 | return 71 | } 72 | 73 | func (p *UnitConnPool) Get() (c *Conn, err error) { 74 | select { 75 | case c = <-p.pool: 76 | if c != nil { 77 | return 78 | } 79 | default: 80 | } 81 | 82 | for i := 0; i < p.trys; i++ { 83 | if c, err = NewCon(ConnType, p.addr, time.Duration(p.timeout)*time.Second); err == nil { 84 | break 85 | } 86 | } 87 | if err != nil { 88 | return 89 | } 90 | 91 | c.SetKeepAlive(true) 92 | c.SetNoDelay(true) 93 | 94 | return 95 | } 96 | 97 | func (p *UnitConnPool) Put(conn *Conn) (err error) { 98 | select { 99 | case p.pool <- conn: 100 | default: 101 | if conn.c != nil { 102 | conn.Close() 103 | } 104 | } 105 | 106 | return 107 | } 108 | 109 | func (connp *ConnPool) GetConn(addr string) (c *Conn, err error) { 110 | p, ok := connp.GetUintPool(addr) 111 | if !ok { 112 | return nil, ErrNotExistUnitPool 113 | } 114 | 115 | return p.Get() 116 | } 117 | 118 | func (connp *ConnPool) PutConn(addr string, conn *Conn) (err error) { 119 | p, ok := connp.GetUintPool(addr) 120 | if !ok { 121 | if conn.c != nil { 122 | conn.Close() 123 | } 124 | return 125 | } 126 | 127 | return p.Put(conn) 128 | } 129 | 130 | func (connp *ConnPool) 
SetUintPool(addr string, p *UnitConnPool) { 131 | connp.rwMu.Lock() 132 | connp.unitPools[addr] = p 133 | connp.rwMu.Unlock() 134 | } 135 | 136 | func (connp *ConnPool) GetUintPool(addr string) (p *UnitConnPool, ok bool) { 137 | connp.rwMu.RLock() 138 | p, ok = connp.unitPools[addr] 139 | connp.rwMu.RUnlock() 140 | 141 | return 142 | } 143 | -------------------------------------------------------------------------------- /server.go: -------------------------------------------------------------------------------- 1 | package minproxy 2 | 3 | import ( 4 | "bufio" 5 | "net" 6 | "sync" 7 | 8 | "github.com/zimulala/minproxy/util" 9 | ) 10 | 11 | type Server struct { 12 | id int 13 | ip string 14 | port string 15 | connPool *util.ConnPool 16 | 17 | bucketBase int 18 | buckets []int 19 | bucketAddrMap map[int]string //key: bucket, val: serverAddr 20 | bucketMux sync.RWMutex 21 | } 22 | 23 | func NewServer() *Server { 24 | return &Server{ 25 | connPool: util.NewConnPool(), 26 | bucketAddrMap: make(map[int]string)} 27 | } 28 | 29 | func (s *Server) Start(cfg *util.Config) error { 30 | if err := s.CheckConfig(cfg); err != nil { 31 | return err 32 | } 33 | 34 | if err := InitConnPool(s.bucketAddrMap, s.connPool); err != nil { 35 | return err 36 | } 37 | 38 | return s.ListenAndServe() 39 | } 40 | 41 | func (s *Server) ListenAndServe() (err error) { 42 | l, err := net.Listen("tcp", ":"+s.port) 43 | if err != nil { 44 | return 45 | } 46 | 47 | for { 48 | c, err := l.Accept() 49 | if err != nil { 50 | return err 51 | } 52 | go s.Serve(c) 53 | } 54 | 55 | l.Close() 56 | 57 | return 58 | } 59 | 60 | func (s *Server) Serve(c net.Conn) { 61 | conn, _ := c.(*net.TCPConn) 62 | conn.SetKeepAlive(true) 63 | conn.SetNoDelay(true) 64 | reader := bufio.NewReader(c) 65 | taskCh := make(chan *Task, 1024) 66 | exitCh := make(chan Sigal, 1) 67 | 68 | go s.handleReplys(conn, taskCh, exitCh) 69 | 70 | for { 71 | req, err := ReadReqs(conn, reader) 72 | if err != nil { 73 | close(exitCh) 74 | break 
75 | } 76 | if err = s.handleReqs(req); err != nil { 77 | close(exitCh) 78 | break 79 | } 80 | taskCh <- req 81 | } 82 | conn.Close() 83 | } 84 | 85 | func ReadReqs(c *net.TCPConn, reader *bufio.Reader) (t *Task, err error) { 86 | t = &Task{Id: GenerateId()} 87 | if t.Raw, err = ReadReqData(reader); err != nil { 88 | return 89 | } 90 | if len(t.Raw) <= 0 { 91 | err = ErrBadReqFormat 92 | } 93 | 94 | return 95 | } 96 | 97 | func (s *Server) handleReqs(req *Task) (err error) { 98 | if err = req.UnmarshalPkg(); err != nil { 99 | return 100 | } 101 | 102 | addrs, err := s.GetAddrs(req) 103 | if err != nil { 104 | req.PackErrorReply(err.Error()) 105 | return nil 106 | } 107 | 108 | if err = s.GetConns(addrs, req); err != nil { 109 | req.PackErrorReply(err.Error()) 110 | } 111 | 112 | return nil 113 | } 114 | 115 | func (s *Server) handleReplys(c *net.TCPConn, taskCh chan *Task, exitCh chan Sigal) { 116 | for { 117 | select { 118 | case task := <-taskCh: 119 | if task.IsErrTask() { 120 | Write(c, *task.Resp) 121 | s.ReleaseConns(task) 122 | break 123 | } 124 | 125 | ReadReplys(task) 126 | if err := task.MergeReplys(); err != nil { 127 | task.PackErrorReply(err.Error()) 128 | } 129 | Write(c, *task.Resp) 130 | s.ReleaseConns(task) 131 | case <-exitCh: 132 | return 133 | } 134 | } 135 | 136 | return 137 | } 138 | 139 | func ReadReplys(task *Task) { 140 | if len(task.OutInfos) == 1 { 141 | if err := task.OutInfos[0].ReadReply(); err != nil { 142 | task.OutInfos[0].connAddr = task.OutInfos[0].conn.Addr() 143 | } 144 | return 145 | } 146 | 147 | wg := sync.WaitGroup{} 148 | for _, info := range task.OutInfos { 149 | wg.Add(1) 150 | go func() { 151 | if err := info.ReadReply(); err != nil { 152 | info.connAddr = info.conn.Addr() 153 | } 154 | wg.Done() 155 | }() 156 | } 157 | wg.Wait() 158 | } 159 | -------------------------------------------------------------------------------- /util.go: -------------------------------------------------------------------------------- 1 | 
package minproxy 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "log" 7 | "strconv" 8 | "sync" 9 | "sync/atomic" 10 | "time" 11 | 12 | "github.com/zimulala/minproxy/util" 13 | ) 14 | 15 | const ( 16 | ConnTimeout = 5 17 | ConnRetrys = 2 18 | ConnSize = 600 19 | ConnReadDeadline = 5 20 | GetConnErr = 0 21 | WriteToConnErr = 1 22 | ConnOk = 2 23 | ConnOkStr = "" 24 | ) 25 | 26 | var ( 27 | ErrBadConfig = errors.New("bad config err") 28 | ErrBadBucketKey = errors.New("bad bucket key err") 29 | ErrGetConn = errors.New("get conn err") 30 | ErrWriteToConn = errors.New("write to conn err") 31 | ) 32 | 33 | type Sigal struct{} 34 | 35 | func (s *Server) CheckConfig(cfg *util.Config) error { 36 | s.id = cfg.GetInt("id") 37 | s.ip = cfg.GetString("ip") 38 | s.port = cfg.GetString("port") 39 | s.bucketBase = cfg.GetInt("bucket_base") 40 | buckets := cfg.GetArray("buckets") 41 | bucketAddrMap := cfg.GetInterface("bucket_addr").(map[string]interface{}) 42 | 43 | for _, b := range buckets { 44 | s.buckets = append(s.buckets, int(b.(float64))) 45 | } 46 | for b, addr := range bucketAddrMap { 47 | bInt, err := strconv.Atoi(b) 48 | if err != nil { 49 | return err 50 | } 51 | s.bucketAddrMap[bInt] = addr.(string) 52 | } 53 | 54 | if s.id == -1 || s.ip == "" || s.port == "" { 55 | return ErrBadConfig 56 | } 57 | 58 | return nil 59 | } 60 | 61 | func InitConnPool(addrMap map[int]string, connP *util.ConnPool) (err error) { 62 | for _, addr := range addrMap { 63 | if _, err = connP.NewUnitPool(ConnSize, addr, ConnTimeout, ConnRetrys); err != nil { 64 | break 65 | } 66 | } 67 | 68 | return 69 | } 70 | 71 | func GenerateId() int64 { 72 | return time.Now().UnixNano() 73 | } 74 | 75 | func (s *Server) GetAddrs(pkg *Task) (addrs []string, err error) { 76 | weights := make([]int64, len(pkg.OutInfos)) 77 | addrs = make([]string, len(pkg.OutInfos)) 78 | 79 | for i, info := range pkg.OutInfos { 80 | for _, k := range info.key { 81 | weights[i] += int64(k) 82 | } 83 | } 84 | 85 | 
s.bucketMux.RLock() 86 | defer s.bucketMux.RUnlock() 87 | for i, w := range weights { 88 | bucket := int(w % int64(len(s.buckets)/s.bucketBase)) 89 | addr, ok := s.bucketAddrMap[bucket] 90 | if !ok { 91 | return nil, ErrBadBucketKey 92 | } 93 | addrs[i] = addr 94 | } 95 | 96 | return 97 | } 98 | 99 | func (s *Server) GetConns(addrs []string, task *Task) (err error) { 100 | if len(task.OutInfos) == 1 { 101 | if task.OutInfos[0].conn, err = s.connPool.GetConn(addrs[0]); err == nil { 102 | err = task.OutInfos[0].conn.Write(task.OutInfos[0].data) 103 | } 104 | if err != nil { 105 | task.OutInfos[0].connAddr = addrs[0] 106 | } 107 | return 108 | } 109 | 110 | isErr := uint32(ConnOk) 111 | wg := sync.WaitGroup{} 112 | 113 | for i, info := range task.OutInfos { 114 | wg.Add(1) 115 | go func() { 116 | if info.conn, err = s.connPool.GetConn(addrs[i]); err != nil { 117 | info.connAddr = addrs[i] 118 | atomic.StoreUint32(&isErr, GetConnErr) 119 | wg.Done() 120 | return 121 | } 122 | if err = info.conn.Write(info.data); err != nil { 123 | info.connAddr = addrs[i] 124 | atomic.StoreUint32(&isErr, WriteToConnErr) 125 | } 126 | wg.Done() 127 | }() 128 | } 129 | wg.Wait() 130 | 131 | if isErr == GetConnErr { 132 | err = ErrGetConn 133 | } else if isErr == WriteToConnErr { 134 | err = ErrWriteToConn 135 | } 136 | 137 | return 138 | } 139 | 140 | func (s *Server) ReleaseConns(pkg *Task) { 141 | for _, info := range pkg.OutInfos { 142 | if info.connAddr == ConnOkStr { 143 | s.connPool.PutConn(info.conn.Addr(), info.conn) 144 | continue 145 | } 146 | s.connPool.PutConn(info.connAddr, nil) 147 | } 148 | } 149 | 150 | func GetVal(s []byte) (val []byte, err error) { 151 | size := len(s) 152 | idx := bytes.IndexByte(s, '\n') 153 | if idx < 0 || size < idx+1 || idx+1 > size-2 { 154 | log.Println("GetVal, s:", string(s), " idx:", idx, " size:", size) 155 | return nil, ErrBadReqFormat 156 | } 157 | 158 | val = s[idx+1 : size-2] 159 | 160 | return 161 | } 162 | 
// writeTests lists the commands issued by TestBasic. args holds the command
// name followed by its arguments; data is the request as it is encoded on the
// wire for those arguments (an array header, then one bulk string per
// argument).
var writeTests = []struct {
	args []interface{}
	data string
}{
	{
		[]interface{}{"DEL", "key"},
		"*2\r\n$3\r\nDEL\r\n$3\r\nkey\r\n",
	},
	{
		[]interface{}{"GETSET", "key", "val"},
		"*3\r\n$6\r\nGETSET\r\n$3\r\nkey\r\n$3\r\nval\r\n",
	},
	{
		[]interface{}{"GETSET", "key", ""},
		"*3\r\n$6\r\nGETSET\r\n$3\r\nkey\r\n$0\r\n\r\n",
	},
	{
		[]interface{}{"GET", "key"},
		"*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n",
	},
	{
		[]interface{}{"ZADD", "salary", 9000, "tom"},
		"*4\r\n$4\r\nZADD\r\n$6\r\nsalary\r\n$4\r\n9000\r\n$3\r\ntom\r\n",
	},
	{
		[]interface{}{"ZADD", "salary", 15000, "lily"},
		"*4\r\n$4\r\nZADD\r\n$6\r\nsalary\r\n$5\r\n15000\r\n$4\r\nlily\r\n",
	},
	{
		[]interface{}{"ZADD", "salary", 33000, "lala"},
		"*4\r\n$4\r\nZADD\r\n$6\r\nsalary\r\n$5\r\n33000\r\n$4\r\nlala\r\n",
	},
	{
		[]interface{}{"ZCARD", "salary"},
		"*2\r\n$5\r\nZCARD\r\n$6\r\nsalary\r\n",
	},
	{
		[]interface{}{"HGET", "myhash", "foo"},
		"*3\r\n$4\r\nHGET\r\n$6\r\nmyhash\r\n$3\r\nfoo\r\n",
	},
	{
		[]interface{}{"HDEL", "myhash", "foo"},
		"*3\r\n$4\r\nHDEL\r\n$6\r\nmyhash\r\n$3\r\nfoo\r\n",
	},
	{
		[]interface{}{"HSET", "myhash", "foo", 2},
		"*4\r\n$4\r\nHSET\r\n$6\r\nmyhash\r\n$3\r\nfoo\r\n$1\r\n2\r\n",
	},
	{
		[]interface{}{"HIncrBy", "myhash", "foo", 2},
		"*4\r\n$7\r\nHIncrBy\r\n$6\r\nmyhash\r\n$3\r\nfoo\r\n$1\r\n2\r\n",
	},
	{
		[]interface{}{"HGET", "myhash", "foo"},
		"*3\r\n$4\r\nHGET\r\n$6\r\nmyhash\r\n$3\r\nfoo\r\n",
	},
	{
		[]interface{}{"SET", "foo", "bar"},
		"*3\r\n$3\r\nSET\r\n$3\r\nfoo\r\n$3\r\nbar\r\n",
	},
	{
		[]interface{}{"SET", "foo", "bar"},
		"*3\r\n$3\r\nSET\r\n$3\r\nfoo\r\n$3\r\nbar\r\n",
	},
	{
		[]interface{}{"GET", "foo"},
		"*2\r\n$3\r\nGET\r\n$3\r\nfoo\r\n",
	},
	{
		// The encoding must name TYPE itself (4 bytes), not DEL.
		[]interface{}{"TYPE", "foo"},
		"*2\r\n$4\r\nTYPE\r\n$3\r\nfoo\r\n",
	},
	{
		[]interface{}{"EXISTS", "foo"},
		"*2\r\n$6\r\nEXISTS\r\n$3\r\nfoo\r\n",
	},
	{
		[]interface{}{"PERSIST", "foo"},
		"*2\r\n$7\r\nPERSIST\r\n$3\r\nfoo\r\n",
	},
	{
		[]interface{}{"EXPIRE", "foo", 10},
		"*3\r\n$6\r\nEXPIRE\r\n$3\r\nfoo\r\n$2\r\n10\r\n",
	},
	{
		[]interface{}{"PERSIST", "foo"},
		"*2\r\n$7\r\nPERSIST\r\n$3\r\nfoo\r\n",
	},
	{
		[]interface{}{"TTL", "foo"},
		"*2\r\n$3\r\nTTL\r\n$3\r\nfoo\r\n",
	},
	{
		[]interface{}{"EXPIREAT", "foo", 1293840000},
		"*3\r\n$8\r\nEXPIREAT\r\n$3\r\nfoo\r\n$10\r\n1293840000\r\n",
	},
	{
		[]interface{}{"EXPIREAT", "foo", 1293840000},
		"*3\r\n$8\r\nEXPIREAT\r\n$3\r\nfoo\r\n$10\r\n1293840000\r\n",
	},
	{
		[]interface{}{"DEL", "foo"},
		"*2\r\n$3\r\nDEL\r\n$3\r\nfoo\r\n",
	},
	{
		[]interface{}{"DEL", "foo"},
		"*2\r\n$3\r\nDEL\r\n$3\r\nfoo\r\n",
	},
	{
		[]interface{}{"SET", "foo", byte(100)},
		"*3\r\n$3\r\nSET\r\n$3\r\nfoo\r\n$3\r\n100\r\n",
	},
	{
		[]interface{}{"SET", "foo", 100},
		"*3\r\n$3\r\nSET\r\n$3\r\nfoo\r\n$3\r\n100\r\n",
	},
	{
		[]interface{}{"SET", "foo", int64(math.MinInt64)},
		"*3\r\n$3\r\nSET\r\n$3\r\nfoo\r\n$20\r\n-9223372036854775808\r\n",
	},
	{
		[]interface{}{"SET", "foo", float64(1349673917.939762)},
		"*3\r\n$3\r\nSET\r\n$3\r\nfoo\r\n$21\r\n1.349673917939762e+09\r\n",
	},
	{
		[]interface{}{"SET", "", []byte("foo")},
		"*3\r\n$3\r\nSET\r\n$0\r\n\r\n$3\r\nfoo\r\n",
	},
	{
		[]interface{}{"SET", nil, []byte("foo")},
		"*3\r\n$3\r\nSET\r\n$0\r\n\r\n$3\r\nfoo\r\n",
	},
}
string 46 | } 47 | 48 | type Task struct { 49 | Opcode uint8 50 | Id int64 51 | OutInfos []*UnitPkg 52 | Raw [][]byte 53 | Resp *[]byte 54 | } 55 | 56 | func (t *Task) IsErrTask() (err bool) { 57 | if t.Opcode == OpError { 58 | err = true 59 | } 60 | 61 | return 62 | } 63 | 64 | func (t *Task) PackErrorReply(msg string) { 65 | t.Opcode = OpError 66 | errMsg := []byte("-" + msg + "\r\n") 67 | t.Resp = &errMsg 68 | 69 | return 70 | } 71 | 72 | func (t *Task) getMKeys(e [][]byte) { 73 | interval := 2 74 | val := ArgSplitBytes 75 | if string(e[2]) == "mset" { 76 | interval = 4 77 | } 78 | 79 | begin := 3 80 | t.OutInfos = make([]*UnitPkg, (len(e)-begin)/interval) 81 | for i := begin; i < len(e); i += interval { 82 | if interval == 4 { 83 | val = append(bytes.Join([][]byte{e[i+2], e[i+3]}, ArgSplitBytes), ArgSplitBytes...) 84 | } 85 | 86 | info := &UnitPkg{uId: i - begin, key: e[i+1], 87 | data: bytes.Join([][]byte{[]byte("*3"), e[1], e[2], e[i], e[i+1], val}, ArgSplitBytes)} 88 | t.OutInfos = append(t.OutInfos, info) 89 | } 90 | 91 | return 92 | } 93 | 94 | func Append(data [][]byte) (buf []byte) { 95 | for _, b := range data { 96 | buf = append(buf, b...) 
/*
Example of a request as it arrives on the wire (HSET myhash field1 ""):

	*4\r\n
	$4\r\n
	HSET\r\n
	$6\r\n
	myhash\r\n
	$6\r\n
	field1\r\n
	$0\r\n
	\r\n
*/

// UnmarshalPkg turns the raw request in t.Raw into the sub-requests stored in
// t.OutInfos.
//
// t.Raw[0] is the first line of the request. When it does not start with '*'
// the request is forwarded untouched, keyed by that line. Otherwise t.Raw[0]
// is an array header, t.Raw[1] the command and t.Raw[2] the first argument,
// which is used as the routing key.
//
// It returns the strconv error when the array header is not numeric,
// ErrBadArgsNum when the header announces at most one element or fewer than
// three raw elements are present, the GetVal error when the command or key
// cannot be extracted, and ErrBadReqFormat when a key contains a brace but no
// usable tag can be cut out of it.
//
// NOTE(review): t.Raw[0] is indexed and sliced without a length check; this
// presumably relies on the reader never producing an empty Raw or a first
// line shorter than three bytes — confirm against the caller.
func (t *Task) UnmarshalPkg() (err error) {
	if !bytes.HasPrefix(t.Raw[0], LineNumBytes) { // not an array header, e.g. an inline ping
		t.OutInfos = append(t.OutInfos, &UnitPkg{uId: 0, key: t.Raw[0], data: Append(t.Raw)})
		return
	}

	// Element count announced by the header: the text between '*' and the
	// trailing "\r\n".
	lineN, err := strconv.Atoi(string(t.Raw[0][1 : len(t.Raw[0])-2]))
	if err != nil {
		return
	}
	switch lineN <= 1 {
	case true:
		// A command with no argument carries no key to route on.
		return ErrBadArgsNum
	case false:
		if len(t.Raw) < 3 {
			return ErrBadArgsNum
		}
		// The command name is compared exactly, so only lowercase "mset" and
		// "mget" take the multi-key path.
		if cmd, err := GetVal(t.Raw[1]); err != nil {
			return err
		} else if string(cmd) == "mset" || string(cmd) == "mget" {
			t.getMKeys(t.Raw)
			break
		}

		key, err := GetVal(t.Raw[2])
		if err != nil {
			return err
		}
		t.OutInfos = append(t.OutInfos, &UnitPkg{uId: 0, key: key, data: Append(t.Raw)})
		// NOTE(review): the code below addresses t.OutInfos[0], which is the
		// entry appended above only if OutInfos was empty on entry — confirm.
		if bytes.Contains(t.OutInfos[0].key, TagBeginByte) || bytes.Contains(t.OutInfos[0].key, TagEndBytes) {
			// The key carries a tag: route on the text after the first '{',
			// up to the first ',' if there is one, otherwise up to the first
			// '}'.
			start := bytes.Index(t.OutInfos[0].key, TagBeginByte)
			end := bytes.Index(t.OutInfos[0].key, TagSplitByte)
			if end < 0 {
				end = bytes.Index(t.OutInfos[0].key, TagEndBytes)
			}
			// No opening brace, or the terminator is missing or sits before
			// it: the request is rejected rather than routed on the full key.
			if start < 0 || end < start {
				return ErrBadReqFormat
			}
			t.OutInfos[0].key = t.OutInfos[0].key[start+1 : end]
		}
	}

	return
}
// readBulk reads the payload of a bulk string from r and appends it,
// together with its trailing "\r\n", to *data.
//
// d is the bulk header line that was already read ("$<len>\r\n"). A negative
// length denotes a null bulk string, which has no payload: nothing is read
// and *data is left untouched.
//
// Exactly len+2 bytes are consumed, so payloads that themselves contain '\n'
// are read whole and the stream stays aligned on the next reply.
//
// It returns io.ErrUnexpectedEOF when d is too short to be a header, the
// strconv error when the length is not numeric, and the read error when the
// payload cannot be read in full; *data is not modified on error.
func readBulk(r *bufio.Reader, d []byte, data *[]byte) (err error) {
	if len(d) < 3 {
		return io.ErrUnexpectedEOF
	}

	bufL, err := strconv.Atoi(string(d[1 : len(d)-2]))
	if err != nil {
		return err
	}
	if bufL < 0 {
		return nil
	}

	// d may alias *data; it is not used past this point.
	buf := make([]byte, bufL+2)
	if _, err = io.ReadFull(r, buf); err != nil {
		return err
	}
	*data = append(*data, buf...)

	return nil
}
224 | 225 | if bytes.HasPrefix(buf, DataSizeBytes) { 226 | readBulk(p.conn.R, buf, &p.data) 227 | continue 228 | } 229 | 230 | if bytes.HasPrefix(buf, LineNumBytes) { 231 | i = -1 232 | lines, err = strconv.Atoi(string(buf[1 : len(buf)-2])) 233 | } 234 | } 235 | 236 | return 237 | } 238 | 239 | func readLine(r *bufio.Reader) (b []byte, err error) { 240 | if b, err = r.ReadBytes('\n'); err != nil { 241 | return 242 | } 243 | l := len(b) - 2 244 | if l < 0 || b[l] != '\r' { 245 | err = ErrBadReqFormat 246 | } 247 | 248 | return 249 | } 250 | 251 | //*3\r\n$6\r\nGETSET\r\n$3\r\nkey\r\n$0\r\n\r\n 252 | func ReadReqData(r *bufio.Reader) (raws [][]byte, err error) { 253 | buf, err := readLine(r) 254 | if err != nil || len(buf) <= 2 { 255 | return 256 | } 257 | lines, err := strconv.Atoi(string(buf[1 : len(buf)-2])) 258 | if err != nil || lines < 0 { 259 | return 260 | } 261 | 262 | switch buf[0] { 263 | case '+', '-', ':': 264 | err = ErrBadReqFormat 265 | case '$': 266 | raws = make([][]byte, 1) 267 | s := len(buf) 268 | raws[0] = make([]byte, s+lines+2) 269 | copy(raws[0][:s], buf) 270 | if _, err = io.ReadFull(r, raws[0][s:s+lines]); err != nil { 271 | return nil, err 272 | } 273 | if b, err := readLine(r); err != nil || len(b) != 2 { 274 | return nil, err 275 | } else { 276 | copy(raws[0][(s+lines):(s+lines+2)], b) 277 | } 278 | case '*': 279 | raws = make([][]byte, lines+1) 280 | raws[0] = buf 281 | for i := 1; i <= lines; i++ { 282 | raw, err := ReadReqData(r) 283 | if err != nil || len(raw) <= 0 { 284 | return nil, err 285 | } 286 | raws[i] = raw[0] 287 | } 288 | case 'P': 289 | raws = make([][]byte, 1) 290 | raws[0] = buf 291 | default: 292 | err = ErrBadReqFormat 293 | } 294 | 295 | return 296 | } 297 | 298 | func Write(c *net.TCPConn, buf []byte) (err error) { 299 | _, err = c.Write(buf) 300 | 301 | return 302 | } 303 | --------------------------------------------------------------------------------