├── core ├── aof.go ├── object.go ├── pubsub.go ├── adlist.go ├── godis.go ├── zset.go ├── proto │ └── proto.go ├── geo.go └── geohash.go ├── godis-cli.go ├── util └── bufio2 │ └── bufio.go └── godis-server.go /core/aof.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io/ioutil" 7 | "log" 8 | "os" 9 | "syscall" 10 | ) 11 | 12 | //AppendToFile 写文件 13 | func AppendToFile(fileName string, content string) error { 14 | // 以只写的模式,打开文件 15 | f, err := os.OpenFile(fileName, os.O_WRONLY|syscall.O_CREAT, 0644) 16 | if err != nil { 17 | log.Println("aof file open failed" + err.Error()) 18 | } else { 19 | n, _ := f.Seek(0, os.SEEK_END) 20 | _, err = f.WriteAt([]byte(content), n) 21 | } 22 | defer f.Close() 23 | return err 24 | } 25 | 26 | func ReadAof(fileName string) []string { 27 | f, err := os.Open(fileName) 28 | if err != nil { 29 | fmt.Println("aof file open failed" + err.Error()) 30 | } 31 | defer f.Close() 32 | content, err := ioutil.ReadFile(fileName) 33 | if err != nil { 34 | fmt.Println("aof file read failed" + err.Error()) 35 | } 36 | ret := bytes.Split(content, []byte{'*'}) 37 | var pros = make([]string, len(ret)-1) 38 | for k, v := range ret[1:] { 39 | v := append(v[:0], append([]byte{'*'}, v[0:]...)...) 40 | pros[k] = string(v) 41 | } 42 | return pros 43 | } 44 | -------------------------------------------------------------------------------- /core/object.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | // GodisObject 是对特定类型的数据的包装 4 | type GodisObject struct { 5 | ObjectType int 6 | //encoding uint 7 | Ptr interface{} 8 | } 9 | 10 | const C_ERR = -1 11 | const C_OK = 0 12 | 13 | const ObjectTypeString = 0 14 | const OBJ_LIST = 1 15 | const OBJ_SET = 2 16 | const OBJ_ZSET = 3 17 | const OBJ_HASH = 4 18 | 19 | const OBJ_ENCODING_RAW = 0 /* Raw representation */ 20 | const OBJ_ENCODING_INT = 1 /* Encoded as integer */ 21 | const OBJ_ENCODING_HT = 2 /* Encoded as hash table */ 22 | const OBJ_ENCODING_ZIPMAP = 3 /* Encoded as zipmap */ 23 | const OBJ_ENCODING_LINKEDLIST = 4 /* No longer used: old list encoding. */ 24 | const OBJ_ENCODING_ZIPLIST = 5 /* Encoded as ziplist */ 25 | const OBJ_ENCODING_INTSET = 6 /* Encoded as intset */ 26 | const OBJ_ENCODING_SKIPLIST = 7 /* Encoded as skiplist */ 27 | const OBJ_ENCODING_EMBSTR = 8 /* Embedded sds string encoding */ 28 | const OBJ_ENCODING_QUICKLIST = 9 /* Encoded as linked list of ziplists */ 29 | 30 | // CreateObject 创建特定类型的object结构 31 | func CreateObject(t int, ptr interface{}) (o *GodisObject) { 32 | o = new(GodisObject) 33 | o.ObjectType = t 34 | o.Ptr = ptr 35 | return 36 | } 37 | -------------------------------------------------------------------------------- /core/pubsub.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import "strconv" 4 | 5 | func SubscribeCommand(c *Client, s *Server) { 6 | for j := 1; j < c.Argc; j++ { 7 | pubsubSubscribeChannel(c, c.Argv[j], s) 8 | } 9 | c.Flags |= CLIENT_PUBSUB 10 | 11 | } 12 | 13 | func pubsubSubscribeChannel(c *Client, obj *GodisObject, s *Server) { 14 | (*c.PubSubChannels)[obj.Ptr.(string)] = nil 15 | de := (*(s.PubSubChannels))[obj.Ptr.(string)] 16 | var clients *List 17 | if de == nil { 18 | clients = listCreate() 19 | (*(s.PubSubChannels))[obj.Ptr.(string)] = clients 20 | } else { 21 | clients = de 22 | } 23 | clients.listAddNodeTail(c) 24 | } 25 | 26 | func PublishCommand(c *Client, s *Server) { 27 | receivers := pubsubPublishMessage(c.Argv[1], c.Argv[2], s) 28 | //广播到其他集群上暂不支持 29 | //aof存储暂不支持 30 | addReplyStatus(c, strconv.Itoa(receivers)) 31 | } 32 | 33 | func pubsubPublishMessage(channel *GodisObject, message *GodisObject, s *Server) int { 34 | receivers := 0 35 | de := (*s.PubSubChannels)[channel.Ptr.(string)] 36 | if de != nil { 37 | for list := de.head; list != nil; list = list.next { 38 | c := list.value.(*Client) 39 | addReplyStatus(c, message.Ptr.(string)) 40 | receivers++ 41 | } 42 | } 43 | return receivers 44 | 45 | } 46 | -------------------------------------------------------------------------------- /godis-cli.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bufio" 5 | "fmt" 6 | "godis/core/proto" 7 | "log" 8 | "net" 9 | "os" 10 | "strings" 11 | ) 12 | 13 | func main() { 14 | IPPort := "127.0.0.1:9736" 15 | 16 | reader := bufio.NewReader(os.Stdin) 17 | fmt.Println("Hi Godis") 18 | tcpAddr, err := net.ResolveTCPAddr("tcp4", IPPort) 19 | checkError(err) 20 | 21 | //建立连接 如果第二个参数(本地地址)为nil,会自动生成一个本地地址 22 | conn, err := net.DialTCP("tcp", nil, tcpAddr) 23 | checkError(err) 24 | defer conn.Close() 25 | //log.Println(tcpAddr, conn.LocalAddr(), conn.RemoteAddr()) 26 | 27 | for { 28 | fmt.Print(IPPort + "> ") 29 | text, _ := reader.ReadString('\n') 30 | //清除掉回车换行符 31 | text = strings.Replace(text, "\n", "", -1) 32 | send2Server(text, conn) 33 | 34 | buff := make([]byte, 1024) 35 | n, err := conn.Read(buff) 36 | resp, er := proto.DecodeFromBytes(buff) 37 | checkError(err) 38 | if n == 0 { 39 | fmt.Println(IPPort+"> ", "nil") 40 | } else if er == nil { 41 | fmt.Println(IPPort+">", string(resp.Value)) 42 | } else { 43 | fmt.Println(IPPort+"> ", "err server response") 44 | } 45 | } 46 | 47 | } 48 | func send2Server(msg string, conn net.Conn) (n int, err error) { 49 | p, e := proto.EncodeCmd(msg) 50 | if e != nil { 51 | return 0, e 52 | } 53 | //fmt.Println("proto encode", p, string(p)) 54 | n, err = conn.Write(p) 55 | return n, err 56 | } 57 | func checkError(err error) { 58 | if err != nil { 59 | log.Println("err ", err.Error()) 60 | os.Exit(1) 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /core/adlist.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | type listNode struct { 4 | prev *listNode 5 | next *listNode 6 | value interface{} 7 | } 8 | 9 | type List struct { 10 | head *listNode 11 | tail *listNode 12 | len int 13 | } 14 | 15 | func (l List) listLength() int { 16 | return l.len 17 | } 18 | 19 | func (l List) listFirst() *listNode { 20 | return l.head 21 | } 22 | 23 | func (l List) listLast() *listNode { 24 | return l.tail 25 | } 26 | 27 | func (n listNode) listPrevNode() *listNode { 28 | return n.prev 29 | } 30 | 31 | func (n listNode) listNextNode() *listNode { 32 | return n.next 33 | } 34 | 35 | func (n listNode) listNodeValue() interface{} { 36 | return n.value 37 | } 38 | 39 | func listCreate() *List { 40 | list := new(List) 41 | list.head = nil 42 | list.tail = nil 43 | list.len = 0 44 | return list 45 | } 46 | 47 | func (l *List) listAddNodeHead(value interface{}) *List { 48 | node := new(listNode) 49 | 50 | node.value = value 51 | if l.len == 0 { 52 | l.head = node 53 | l.tail = node 54 | node.prev = nil 55 | node.next = nil 56 | } else { 57 | node.prev = nil 58 | node.next = l.head 59 | l.head.prev = node 60 | l.head = node 61 | } 62 | l.len++ 63 | return l 64 | } 65 | 66 | func (l *List) listAddNodeTail(value interface{}) *List { 67 | node := new(listNode) 68 | 69 | node.value = value 70 | if l.len == 0 { 71 | l.head = node 72 | l.tail = node 73 | node.prev = nil 74 | node.next = nil 75 | } else { 76 | node.prev = l.tail 77 | node.next = nil 78 | l.tail.next = node 79 | l.tail = node 80 | } 81 | l.len++ 82 | return l 83 | } 84 | 85 | func (l *List) listInsertNode(oldNode *listNode, value interface{}, after int) *List { 86 | node := new(listNode) 87 | node.value = value 88 | if after > 0 { 89 | 90 | } 91 | l.len++ 92 | return l 93 | } 94 | -------------------------------------------------------------------------------- /core/godis.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "fmt" 7 | "godis/core/proto" 8 | "log" 9 | "net" 10 | "os" 11 | ) 12 | 13 | //Client 与服务端连接之后即创建一个Client结构 14 | type Client struct { 15 | Cmd *GodisCommand 16 | Argv []*GodisObject 17 | Argc int 18 | Db *GodisDb 19 | QueryBuf string 20 | Buf string 21 | FakeFlag bool 22 | PubSubChannels *map[string]*List 23 | PubSubPatterns *List 24 | Flags int //client flags 25 | } 26 | 27 | //flags 模式 28 | const CLIENT_PUBSUB = (1 << 18) 29 | 30 | //GodisCommand redis命令结构 31 | type GodisCommand struct { 32 | Name string 33 | Proc cmdFunc 34 | } 35 | 36 | //命令函数指针 37 | type cmdFunc func(c *Client, s *Server) 38 | 39 | // Server 服务端实例结构体 40 | type Server struct { 41 | Db []*GodisDb 42 | DbNum int 43 | Start int64 44 | Port int32 45 | RdbFilename string 46 | AofFilename string 47 | NextClientID int32 48 | SystemMemorySize int32 49 | Clients int32 50 | Pid int 51 | Commands map[string]*GodisCommand 52 | Dirty int64 53 | AofBuf []string 54 | PubSubChannels *map[string]*List 55 | PubSubPatterns *List 56 | } 57 | 58 | //use map[string]* as type dict 59 | //使用Go原生数据结构map作为redis中dict结构体 暂不对dict造轮子 60 | type dict map[string]*GodisObject 61 | 62 | //GodisDb db结构体 63 | type GodisDb struct { 64 | Dict dict 65 | Expires dict 66 | ID int32 67 | } 68 | 69 | // SetCommand cmd of set 70 | func SetCommand(c *Client, s *Server) { 71 | objKey := c.Argv[1] 72 | objValue := c.Argv[2] 73 | if c.Argc != 3 { 74 | addReplyError(c, "(error) ERR wrong number of arguments for 'set' command") 75 | } 76 | if stringKey, ok1 := objKey.Ptr.(string); ok1 { 77 | if stringValue, ok2 := objValue.Ptr.(string); ok2 { 78 | c.Db.Dict[stringKey] = CreateObject(ObjectTypeString, stringValue) 79 | } 80 | } 81 | s.Dirty++ 82 | addReplyStatus(c, "OK") 83 | } 84 | 85 | // GetCommand get命令实现 86 | func GetCommand(c *Client, s *Server) { 87 | o := lookupKey(c.Db, c.Argv[1]) 88 | if o != nil { 89 | addReplyStatus(c, o.Ptr.(string)) 90 | } else { 91 | addReplyStatus(c, "nil") 92 | } 93 | } 94 | 95 | // addReply 添加回复 96 | func addReply(c *Client, o *GodisObject) { 97 | c.Buf = o.Ptr.(string) 98 | } 99 | 100 | func addReplyStatus(c *Client, s string) { 101 | r := proto.NewString([]byte(s)) 102 | addReplyString(c, r) 103 | } 104 | func addReplyError(c *Client, s string) { 105 | r := proto.NewError([]byte(s)) 106 | addReplyString(c, r) 107 | } 108 | func addReplyString(c *Client, r *proto.Resp) { 109 | if ret, err := proto.EncodeToBytes(r); err == nil { 110 | c.Buf = string(ret) 111 | } 112 | } 113 | 114 | // ProcessCommand 执行命令 115 | func (s *Server) ProcessCommand(c *Client) { 116 | v := c.Argv[0].Ptr 117 | name, ok := v.(string) 118 | if !ok { 119 | log.Println("error cmd") 120 | os.Exit(1) 121 | } 122 | cmd := lookupCommand(name, s) 123 | fmt.Println(cmd, name, s) 124 | if cmd != nil { 125 | c.Cmd = cmd 126 | call(c, s) 127 | } else { 128 | addReplyError(c, fmt.Sprintf("(error) ERR unknown command '%s'", name)) 129 | } 130 | } 131 | 132 | // lookupCommand查找命令 133 | func lookupCommand(name string, s *Server) *GodisCommand { 134 | if cmd, ok := s.Commands[name]; ok { 135 | return cmd 136 | } 137 | return nil 138 | } 139 | 140 | // call 真正调用命令 141 | func call(c *Client, s *Server) { 142 | dirty := s.Dirty 143 | c.Cmd.Proc(c, s) 144 | dirty = s.Dirty - dirty 145 | if dirty > 0 && !c.FakeFlag { 146 | AppendToFile(s.AofFilename, c.QueryBuf) 147 | } 148 | 149 | } 150 | func lookupKey(db *GodisDb, key *GodisObject) (ret *GodisObject) { 151 | if o, ok := db.Dict[key.Ptr.(string)]; ok { 152 | return o 153 | } 154 | return nil 155 | } 156 | 157 | // CreateClient 连接建立 创建client记录当前连接 158 | func (s *Server) CreateClient() (c *Client) { 159 | c = new(Client) 160 | c.Db = s.Db[0] 161 | c.QueryBuf = "" 162 | tmp := make(map[string]*List, 0) 163 | c.PubSubChannels = &tmp 164 | c.Flags = 0 165 | return c 166 | } 167 | 168 | // ReadQueryFromClient 读取客户端请求信息 169 | func (c *Client) ReadQueryFromClient(conn net.Conn) (err error) { 170 | buff := make([]byte, 512) 171 | n, err := conn.Read(buff) 172 | 173 | if err != nil { 174 | log.Println("conn.Read err!=nil", err, "---len---", n, conn) 175 | conn.Close() 176 | return err 177 | } 178 | c.QueryBuf = string(buff) 179 | return nil 180 | } 181 | 182 | // ProcessInputBuffer 处理客户端请求信息 183 | func (c *Client) ProcessInputBuffer() error { 184 | //r := regexp.MustCompile("[^\\s]+") 185 | decoder := proto.NewDecoder(bytes.NewReader([]byte(c.QueryBuf))) 186 | //decoder := proto.NewDecoder(bytes.NewReader([]byte("*2\r\n$3\r\nget\r\n"))) 187 | if resp, err := decoder.DecodeMultiBulk(); err == nil { 188 | c.Argc = len(resp) 189 | c.Argv = make([]*GodisObject, c.Argc) 190 | for k, s := range resp { 191 | c.Argv[k] = CreateObject(ObjectTypeString, string(s.Value)) 192 | } 193 | return nil 194 | } 195 | return errors.New("ProcessInputBuffer failed") 196 | } 197 | -------------------------------------------------------------------------------- /util/bufio2/bufio.go: -------------------------------------------------------------------------------- 1 | package bufio2 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "io" 7 | ) 8 | 9 | type Reader struct { 10 | err error 11 | buf []byte 12 | 13 | rd io.Reader 14 | rpos int 15 | wpos int 16 | 17 | slice sliceAlloc 18 | } 19 | type sliceAlloc struct { 20 | buf []byte 21 | } 22 | 23 | func NewReader(rd io.Reader) *Reader { 24 | return NewReaderSize(rd, 1024) 25 | } 26 | 27 | func NewReaderSize(rd io.Reader, size int) *Reader { 28 | if size <= 0 { 29 | size = 1024 30 | } 31 | return &Reader{rd: rd, buf: make([]byte, size)} 32 | } 33 | 34 | func (b *Reader) fill() error { 35 | if b.err != nil { 36 | return b.err 37 | } 38 | if b.rpos > 0 { 39 | n := copy(b.buf, b.buf[b.rpos:b.wpos]) 40 | b.rpos = 0 41 | b.wpos = n 42 | } 43 | n, err := b.rd.Read(b.buf[b.wpos:]) 44 | if err != nil { 45 | b.err = err 46 | } else if n == 0 { 47 | b.err = io.ErrNoProgress 48 | } else { 49 | b.wpos += n 50 | } 51 | return b.err 52 | } 53 | 54 | func (b *Reader) buffered() int { 55 | return b.wpos - b.rpos 56 | } 57 | func (b *Reader) ReadByte() (byte, error) { 58 | if b.err != nil { 59 | return 0, b.err 60 | } 61 | if b.buffered() == 0 { 62 | if b.fill() != nil { 63 | return 0, b.err 64 | } 65 | } 66 | c := b.buf[b.rpos] 67 | b.rpos++ 68 | return c, nil 69 | } 70 | 71 | func (b *Reader) ReadBytes(delim byte) ([]byte, error) { 72 | var full [][]byte 73 | var last []byte 74 | var size int 75 | for last == nil { 76 | f, err := b.ReadSlice(delim) 77 | if err != nil { 78 | if err != bufio.ErrBufferFull { 79 | return nil, b.err 80 | } 81 | dup := b.slice.Make(len(f)) 82 | copy(dup, f) 83 | full = append(full, dup) 84 | } else { 85 | last = f 86 | } 87 | size += len(f) 88 | } 89 | var n int 90 | var buf = b.slice.Make(size) 91 | for _, frag := range full { 92 | n += copy(buf[n:], frag) 93 | } 94 | copy(buf[n:], last) 95 | return buf, nil 96 | } 97 | func (b *Reader) ReadSlice(delim byte) ([]byte, error) { 98 | if b.err != nil { 99 | return nil, b.err 100 | } 101 | for { 102 | var index = bytes.IndexByte(b.buf[b.rpos:b.wpos], delim) 103 | if index >= 0 { 104 | limit := b.rpos + index + 1 105 | slice := b.buf[b.rpos:limit] 106 | b.rpos = limit 107 | return slice, nil 108 | } 109 | if b.buffered() == len(b.buf) { 110 | b.rpos = b.wpos 111 | return b.buf, bufio.ErrBufferFull 112 | } 113 | if b.fill() != nil { 114 | return nil, b.err 115 | } 116 | } 117 | } 118 | 119 | func (b *Reader) ReadFull(n int) ([]byte, error) { 120 | //return b.buf[] 121 | if b.err != nil || n == 0 { 122 | return nil, b.err 123 | } 124 | var buf = b.slice.Make(n) 125 | if _, err := io.ReadFull(bytes.NewReader(b.buf[b.rpos:]), buf); err != nil { 126 | return nil, err 127 | } 128 | b.rpos += n 129 | return buf, nil 130 | } 131 | 132 | type Writer struct { 133 | err error 134 | buf []byte 135 | 136 | wr io.Writer 137 | wpos int 138 | } 139 | 140 | func NewWriter(wr io.Writer) *Writer { 141 | return NewWriterSize(wr, 1024) 142 | } 143 | 144 | func NewWriterSize(wr io.Writer, size int) *Writer { 145 | if size <= 0 { 146 | size = 1024 147 | } 148 | return &Writer{wr: wr, buf: make([]byte, size)} 149 | } 150 | 151 | func (d *sliceAlloc) Make(n int) (ss []byte) { 152 | switch { 153 | case n == 0: 154 | return []byte{} 155 | case n >= 512: 156 | return make([]byte, n) 157 | default: 158 | if len(d.buf) < n { 159 | d.buf = make([]byte, 8192) 160 | } 161 | ss, d.buf = d.buf[:n:n], d.buf[n:] 162 | return ss 163 | } 164 | } 165 | 166 | //Flush api 167 | func (b *Writer) Flush() error { 168 | return b.flush() 169 | } 170 | 171 | func (b *Writer) flush() error { 172 | if b.err != nil { 173 | return b.err 174 | } 175 | if b.wpos == 0 { 176 | return nil 177 | } 178 | n, err := b.wr.Write(b.buf[:b.wpos]) 179 | if err != nil { 180 | b.err = err 181 | } else if n < b.wpos { 182 | b.err = io.ErrShortWrite 183 | } else { 184 | b.wpos = 0 185 | } 186 | return b.err 187 | } 188 | 189 | func (b *Writer) available() int { 190 | return len(b.buf) - b.wpos 191 | } 192 | func (b *Writer) Write(p []byte) (nn int, err error) { 193 | for b.err == nil && len(p) > b.available() { 194 | var n int 195 | if b.wpos == 0 { 196 | n, b.err = b.wr.Write(p) 197 | } else { 198 | n = copy(b.buf[b.wpos:], p) 199 | b.wpos += n 200 | b.flush() 201 | } 202 | nn, p = nn+n, p[n:] 203 | } 204 | if b.err != nil || len(p) == 0 { 205 | return nn, b.err 206 | } 207 | n := copy(b.buf[b.wpos:], p) 208 | b.wpos += n 209 | return nn + n, nil 210 | } 211 | 212 | // WriteByte write byte 213 | func (b *Writer) WriteByte(c byte) error { 214 | if b.err != nil { 215 | return b.err 216 | } 217 | if b.available() == 0 && b.flush() != nil { 218 | return b.err 219 | } 220 | b.buf[b.wpos] = c 221 | b.wpos++ 222 | return nil 223 | } 224 | 225 | // WriteString write buf 226 | func (b *Writer) WriteString(s string) (nn int, err error) { 227 | for b.err == nil && len(s) > b.available() { 228 | n := copy(b.buf[b.wpos:], s) 229 | b.wpos += n 230 | b.flush() 231 | nn, s = nn+n, s[n:] 232 | } 233 | if b.err != nil || len(s) == 0 { 234 | return nn, b.err 235 | } 236 | n := copy(b.buf[b.wpos:], s) 237 | b.wpos += n 238 | return nn + n, nil 239 | } 240 | func (b *Reader) PeekByte() (byte, error) { 241 | if b.err != nil { 242 | return 0, b.err 243 | } 244 | if b.buffered() == 0 { 245 | if b.fill() != nil { 246 | return 0, b.err 247 | } 248 | } 249 | c := b.buf[b.rpos] 250 | return c, nil 251 | } 252 | -------------------------------------------------------------------------------- /godis-server.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "godis/core" 6 | "log" 7 | "net" 8 | "os" 9 | "os/signal" 10 | "syscall" 11 | "time" 12 | ) 13 | 14 | const ( 15 | DefaultAofFile = "./godis.aof" 16 | ) 17 | 18 | // 服务端实例 19 | var godis = new(core.Server) 20 | 21 | func main() { 22 | /*---- 命令行参数处理 ----*/ 23 | argv := os.Args 24 | argc := len(os.Args) 25 | if argc >= 2 { 26 | /* Handle special options --help and --version */ 27 | if argv[1] == "-v" || argv[1] == "--version" { 28 | version() 29 | } 30 | if argv[1] == "--help" || argv[1] == "-h" { 31 | usage() 32 | } 33 | } 34 | 35 | /*---- 监听信号 平滑退出 ----*/ 36 | c := make(chan os.Signal) 37 | signal.Notify(c, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGUSR1, syscall.SIGUSR2) 38 | go sigHandler(c) 39 | 40 | /*---- 初始化服务端实例 ----*/ 41 | initServer() 42 | 43 | /*---- 网络处理 ----*/ 44 | netListen, err := net.Listen("tcp", "127.0.0.1:9736") 45 | if err != nil { 46 | log.Print("listen err ") 47 | } 48 | //checkError(err) 49 | defer netListen.Close() 50 | 51 | for { 52 | conn, err := netListen.Accept() 53 | 54 | if err != nil { 55 | continue 56 | } 57 | //log.Println(conn.LocalAddr(), conn.RemoteAddr()) 58 | go handle(conn) 59 | } 60 | } 61 | 62 | // 处理请求 63 | func handle(conn net.Conn) { 64 | c := godis.CreateClient() 65 | for { 66 | if c.Flags&core.CLIENT_PUBSUB > 0 { 67 | if c.Buf != "" { 68 | responseConn(conn, c) 69 | c.Buf = "" 70 | } 71 | time.Sleep(1) 72 | 73 | } else { 74 | err := c.ReadQueryFromClient(conn) 75 | 76 | if err != nil { 77 | log.Println("readQueryFromClient err", err) 78 | return 79 | } 80 | err = c.ProcessInputBuffer() 81 | if err != nil { 82 | log.Println("ProcessInputBuffer err", err) 83 | return 84 | } 85 | godis.ProcessCommand(c) 86 | responseConn(conn, c) 87 | } 88 | } 89 | } 90 | 91 | // 响应返回给客户端 92 | func responseConn(conn net.Conn, c *core.Client) { 93 | conn.Write([]byte(c.Buf)) 94 | } 95 | 96 | // 初始化服务端实例 97 | func initServer() { 98 | godis.Pid = os.Getpid() 99 | godis.DbNum = 16 100 | initDb() 101 | godis.Start = time.Now().UnixNano() / 1000000 102 | //var getf server.CmdFun 103 | godis.AofFilename = DefaultAofFile 104 | 105 | getCommand := &core.GodisCommand{Name: "get", Proc: core.GetCommand} 106 | setCommand := &core.GodisCommand{Name: "set", Proc: core.SetCommand} 107 | subscribeCommand := &core.GodisCommand{Name: "subscribe", Proc: core.SubscribeCommand} 108 | publishCommand := &core.GodisCommand{Name: "publish", Proc: core.PublishCommand} 109 | geoaddCommand := &core.GodisCommand{Name: "geoadd", Proc: core.GeoAddCommand} 110 | geohashCommand := &core.GodisCommand{Name: "geohash", Proc: core.GeoHashCommand} 111 | geoposCommand := &core.GodisCommand{Name: "geopos", Proc: core.GeoPosCommand} 112 | geodistCommand := &core.GodisCommand{Name: "geodist", Proc: core.GeoDistCommand} 113 | georadiusCommand := &core.GodisCommand{Name: "georadius", Proc: core.GeoRadiusCommand} 114 | georadiusbymemberCommand := &core.GodisCommand{Name: "georadiusbymember", Proc: core.GeoRadiusByMemberCommand} 115 | 116 | godis.Commands = map[string]*core.GodisCommand{ 117 | "get": getCommand, 118 | "set": setCommand, 119 | "geoadd": geoaddCommand, 120 | "geohash": geohashCommand, 121 | "geopos": geoposCommand, 122 | "geodist": geodistCommand, 123 | "georadius": georadiusCommand, 124 | "georadiusbymember": georadiusbymemberCommand, 125 | "subscribe": subscribeCommand, 126 | "publish": publishCommand, 127 | } 128 | tmp := make(map[string]*core.List) 129 | godis.PubSubChannels = &tmp 130 | LoadData() 131 | } 132 | 133 | // 初始化db 134 | func initDb() { 135 | godis.Db = make([]*core.GodisDb, godis.DbNum) 136 | for i := 0; i < godis.DbNum; i++ { 137 | godis.Db[i] = new(core.GodisDb) 138 | godis.Db[i].Dict = make(map[string]*core.GodisObject, 100) 139 | } 140 | } 141 | func LoadData() { 142 | c := godis.CreateClient() 143 | c.FakeFlag = true 144 | pros := core.ReadAof(godis.AofFilename) 145 | for _, v := range pros { 146 | c.QueryBuf = string(v) 147 | err := c.ProcessInputBuffer() 148 | if err != nil { 149 | log.Println("ProcessInputBuffer err", err) 150 | } 151 | godis.ProcessCommand(c) 152 | } 153 | } 154 | 155 | func sigHandler(c chan os.Signal) { 156 | for s := range c { 157 | switch s { 158 | case syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT: 159 | exitHandler() 160 | default: 161 | fmt.Println("signal ", s) 162 | } 163 | } 164 | } 165 | 166 | func exitHandler() { 167 | fmt.Println("exiting smoothly ...") 168 | fmt.Println("bye ") 169 | os.Exit(0) 170 | } 171 | 172 | func version() { 173 | println("Godis server v=0.0.1 sha=xxxxxxx:001 malloc=libc-go bits=64 ") 174 | os.Exit(0) 175 | } 176 | 177 | func usage() { 178 | println("Usage: ./godis-server [/path/to/redis.conf] [options]") 179 | println(" ./godis-server - (read config from stdin)") 180 | println(" ./godis-server -v or --version") 181 | println(" ./godis-server -h or --help") 182 | println("Examples:") 183 | println(" ./godis-server (run the server with default conf)") 184 | println(" ./godis-server /etc/redis/6379.conf") 185 | println(" ./godis-server --port 7777") 186 | println(" ./godis-server --port 7777 --slaveof 127.0.0.1 8888") 187 | println(" ./godis-server /etc/myredis.conf --loglevel verbose") 188 | println("Sentinel mode:") 189 | println(" ./godis-server /etc/sentinel.conf --sentinel") 190 | os.Exit(0) 191 | } 192 | -------------------------------------------------------------------------------- /core/zset.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "math/rand" 5 | ) 6 | 7 | /* Input flags. */ 8 | const ZADD_NONE = 0 9 | const ZADD_INCR = (1 << 0) /* Increment the score instead of setting it. */ 10 | const ZADD_NX = (1 << 1) /* Don't touch elements not already existing. */ 11 | const ZADD_XX = (1 << 2) 12 | 13 | /* Output flags. */ 14 | const ZADD_NOP = (1 << 3) /* Operation not performed because of conditionals.*/ 15 | const ZADD_NAN = (1 << 4) /* Only touch elements already exisitng. */ 16 | const ZADD_ADDED = (1 << 5) /* The element was new and was added. */ 17 | const ZADD_UPDATED = (1 << 6) /* The element already existed, score updated. */ 18 | 19 | const ZSKIPLIST_MAXLEVEL = 32 20 | const ZSKIPLIST_P = 0.25 /* Skiplist P = 1/4 */ 21 | 22 | type zSet struct { 23 | dict *dict 24 | zsl *zSkipList 25 | } 26 | 27 | type zSkipList struct { 28 | header *zSkipListNode 29 | tail *zSkipListNode 30 | length uint //节点层的数量 31 | level int //表示目前跳跃表内,层数最大的节点的层数 32 | } 33 | 34 | type zSkipListNode struct { //一层的节点 35 | ele string 36 | score float64 // 37 | backward *zSkipListNode // 后退指针 38 | level []zSkipListLevel 39 | } 40 | 41 | type zSkipListLevel struct { 42 | forward *zSkipListNode 43 | span uint 44 | } 45 | 46 | // ex结尾指示开区间还是闭区间 值为 1 表示开,值为 0 表示闭 47 | type zRangeSpec struct { 48 | min float64 49 | max float64 50 | minEx int 51 | maxEx int 52 | } 53 | 54 | func zaddCommand(c *Client) { 55 | zaddGenericCommand(c, ZADD_NONE) 56 | } 57 | 58 | func zincrbyCommand(c *Client) { 59 | zaddGenericCommand(c, ZADD_INCR) 60 | } 61 | 62 | /*----------------------------------------------------------------------------- 63 | * Sorted set commands 64 | *----------------------------------------------------------------------------*/ 65 | 66 | /* This generic command implements both ZADD and ZINCRBY. */ 67 | func zaddGenericCommand(c *Client, flags int) { 68 | key := c.Argv[1] 69 | scoreIdx := 2 70 | elements := c.Argc - scoreIdx 71 | elements /= 2 72 | scores := make([]float64, elements) 73 | 74 | for j := 0; j < elements; j++ { 75 | if value, ok := c.Argv[2+j*2].Ptr.(uint64); ok { 76 | scores[j] = float64(value) 77 | } 78 | } 79 | 80 | //这里首先在client对应的db中查找该key,即有序集 81 | zobj := lookupKey(c.Db, key) 82 | if zobj == nil { 83 | //hash+skiplist组合方式,后续再进行判断实现ziplist 84 | zobj = createZsetObject() 85 | //添加到c.db中 86 | c.Db.Dict[key.Ptr.(string)] = zobj 87 | } 88 | 89 | for j := 0; j < elements; j++ { 90 | var newScore float64 91 | score := scores[j] 92 | retFlags := flags 93 | if ele, ok := c.Argv[scoreIdx+1+j*2].Ptr.(string); ok { 94 | zSetAdd(zobj, score, ele, &retFlags, &newScore) 95 | } 96 | 97 | } 98 | 99 | } 100 | 101 | // create zset 102 | func createZsetObject() *GodisObject { 103 | val := new(zSet) 104 | val.dict = new(dict) 105 | dict := make(map[string]*GodisObject) 106 | *val.dict = dict 107 | 108 | val.zsl = zslCreate() //这里创建节点 109 | o := CreateObject(OBJ_ZSET, val) 110 | return o 111 | } 112 | 113 | /* Create a new skiplist. */ 114 | func zslCreate() *zSkipList { 115 | zsl := new(zSkipList) 116 | zsl.level = 1 117 | zsl.length = 0 118 | zsl.header = zslCreateNode(ZSKIPLIST_MAXLEVEL, 0, "") 119 | for j := 0; j < ZSKIPLIST_MAXLEVEL; j++ { 120 | zsl.header.level[j].forward = nil 121 | zsl.header.level[j].span = 0 122 | } 123 | zsl.header.backward = nil 124 | zsl.tail = nil 125 | return zsl 126 | } 127 | 128 | // 参数依次是有序集,要添加的元素的score,要添加的元素,操作模式,新的score 129 | func zSetAdd(zObj *GodisObject, score float64, ele string, flags *int, newScore *float64) bool { 130 | incr := (*flags & ZADD_INCR) != 0 131 | nx := (*flags & ZADD_NX) != 0 132 | xx := (*flags & ZADD_XX) != 0 133 | *flags = 0 134 | var curscore float64 135 | //暂只支持skiplist 136 | if zObj.ObjectType == OBJ_ZSET { 137 | //进行hash查找 138 | zs := zObj.Ptr.(*zSet) //使用*zSet好,还是zSet好 139 | 140 | dict := zs.dict 141 | de := dictFind(dict, ele) 142 | if de != nil { 143 | if nx { 144 | 145 | } 146 | if incr { 147 | 148 | } 149 | //获取存储的score 150 | if coreTemp, ok := de.Ptr.(float64); ok { 151 | curscore = coreTemp 152 | } else { 153 | //exit 154 | } 155 | 156 | //remove and in-insert when score changes 157 | if curscore != score { 158 | 159 | } 160 | 161 | } else if !xx { 162 | //insert 163 | zslInsert(zs.zsl, score, ele) 164 | //插入dict 165 | (*(zs.dict))[ele] = CreateObject(ObjectTypeString, score) 166 | *flags |= ZADD_ADDED 167 | return true 168 | } 169 | } else { 170 | //exit; 171 | } 172 | return false /* Never reached. */ 173 | } 174 | 175 | func dictFind(d *dict, key string) *GodisObject { 176 | if (*d)[key] != nil { 177 | return (*d)[key] 178 | } 179 | return nil 180 | } 181 | 182 | /* 183 | * 创建一个成员为 obj ,分值为 score 的新节点, 184 | * 并将这个新节点插入到跳跃表 zsl 中。 185 | * 186 | * 函数的返回值为新节点。 187 | * 188 | * T_wrost = O(N^2), T_avg = O(N log N) 189 | */ 190 | /*src/t_zset.c/zslInsert*/ 191 | func zslInsert(zsl *zSkipList, score float64, ele string) *zSkipListNode { 192 | update := make([]*zSkipListNode, ZSKIPLIST_MAXLEVEL) 193 | rank := make([]uint, ZSKIPLIST_MAXLEVEL) 194 | x := zsl.header 195 | 196 | for i := zsl.level - 1; i >= 0; i-- { 197 | if i == zsl.level-1 { 198 | rank[i] = 0 199 | } else { 200 | rank[i] = rank[i+1] 201 | } 202 | 203 | for x.level[i].forward != nil && (x.level[i].forward.score < score || 204 | (x.level[i].forward.score == score && (x.level[i].forward.ele < ele))) { 205 | rank[i] += x.level[i].span 206 | x = x.level[i].forward 207 | } 208 | update[i] = x 209 | } 210 | 211 | level := zslRandomLevel() 212 | if level > zsl.level { 213 | for i := zsl.level; i < level; i++ { 214 | rank[i] = 0 215 | update[i] = zsl.header 216 | update[i].level[i].span = zsl.length 217 | } 218 | zsl.level = level 219 | } 220 | 221 | x = zslCreateNode(level, score, ele) 222 | for i := 0; i < level; i++ { 223 | x.level[i].forward = update[i].level[i].forward 224 | update[i].level[i].forward = x 225 | x.level[i].span = update[i].level[i].span - (rank[0] - rank[i]) 226 | update[i].level[i].span = (rank[0] - rank[i]) + 1 227 | } 228 | 229 | for i := level; i < zsl.level; i++ { 230 | update[i].level[i].span++ 231 | } 232 | 233 | if update[0] == zsl.header { 234 | x.backward = nil 235 | } else { 236 | x.backward = update[0] 237 | } 238 | 239 | if x.level[0].forward != nil { 240 | x.level[0].forward.backward = x 241 | } else { 242 | zsl.tail = x 243 | } 244 | zsl.length++ 245 | return x 246 | } 247 | 248 | // 获取一个随机值作为新节点的层数 249 | func zslRandomLevel() int { 250 | level := 1 251 | for rand.Float64()*65535 < ZSKIPLIST_P*65535 { 252 | level++ 253 | } 254 | 255 | if level < ZSKIPLIST_MAXLEVEL { 256 | return level 257 | } 258 | return ZSKIPLIST_MAXLEVEL 259 | } 260 | 261 | // 创建新节点zSkipListNode 262 | func zslCreateNode(level int, score float64, ele string) *zSkipListNode { 263 | zn := new(zSkipListNode) 264 | zl := make([]zSkipListLevel, level) 265 | zn.level = zl 266 | zn.score = score 267 | zn.ele = ele 268 | return zn 269 | } 270 | 271 | func zslFirstInRange(zsl *zSkipList, zRange *zRangeSpec) *zSkipListNode { 272 | if !zslIsInRange(zsl, zRange) { 273 | return nil 274 | } 275 | x := zsl.header 276 | for i := zsl.level - 1; i >= 0; i-- { 277 | for x.level[i].forward != nil && !zslValueGteMin(x.level[i].forward.score, zRange) { 278 | x = x.level[i].forward 279 | } 280 | } 281 | 282 | x = x.level[0].forward 283 | if x == nil { 284 | return nil 285 | } 286 | 287 | if !zslValueLteMax(x.score, zRange) { 288 | return nil 289 | } 290 | return x 291 | } 292 | 293 | func zslValueGteMin(value float64, spec *zRangeSpec) bool { 294 | if spec.minEx != 0 { 295 | return value > spec.min 296 | } 297 | return value >= spec.min 298 | } 299 | 300 | func zslValueLteMax(value float64, spec *zRangeSpec) bool { 301 | if spec.maxEx != 0 { 302 | return value < spec.max 303 | } 304 | return value <= spec.max 305 | } 306 | 307 | func zslIsInRange(zsl *zSkipList, zRange *zRangeSpec) bool { 308 | //test invalid param 309 | if zRange.min > zRange.max || 310 | (zRange.min == zRange.max && (zRange.minEx != 0 || zRange.maxEx != 0)) { 311 | return false 312 | } 313 | x := zsl.tail 314 | if x == nil || !zslValueGteMin(x.score, zRange) { 315 | return false 316 | } 317 | x = zsl.header.level[0].forward 318 | if x == nil || !zslValueLteMax(x.score, zRange) { 319 | return false 320 | } 321 | return true 322 | } 323 | 324 | //从skiplist删除一个节点 325 | func zslDelete(zsl *zSkipList, score float64, ele string, node **zSkipListNode) bool { 326 | update := make([]*zSkipListNode, ZSKIPLIST_MAXLEVEL) 327 | rank := make([]uint, ZSKIPLIST_MAXLEVEL) 328 | x := zsl.header 329 | for i := zsl.level - 1; i >= 0; i-- { 330 | if i == zsl.level-1 { 331 | rank[i] = 0 332 | } else { 333 | rank[i] = rank[i+1] 334 | } 335 | 336 | for x.level[i].forward != nil && (x.level[i].forward.score < score || 337 | (x.level[i].forward.score == score && (x.level[i].forward.ele < ele))) { 338 | rank[i] += x.level[i].span 339 | x = x.level[i].forward 340 | } 341 | update[i] = x 342 | } 343 | x = x.level[0].forward 344 | if x != nil && score == x.score && ele == x.ele { 345 | zslDeleteNode(zsl, x, update) 346 | return true 347 | } 348 | return false 349 | } 350 | 351 | func zslDeleteNode(zsl *zSkipList, x *zSkipListNode, update []*zSkipListNode) { 352 | for i := 0; i < zsl.level; i++ { 353 | if update[i].level[i].forward == x { 354 | update[i].level[i].span += x.level[i].span - 1 355 | update[i].level[i].forward = x.level[i].forward 356 | } else { 357 | update[i].level[i].span -= 1 358 | } 359 | } 360 | 361 | if x.level[0].forward != nil { 362 | x.level[0].forward.backward = x.backward 363 | } else { 364 | zsl.tail = x.backward 365 | } 366 | 367 | for zsl.level > 1 && zsl.header.level[zsl.level-1].forward == nil { 368 | zsl.level-- 369 | } 370 | zsl.length-- 371 | } 372 | 373 | func zsetScore(zobj *GodisObject, member string, score *float64) int { 374 | if zobj == nil || member == "" { 375 | return C_ERR 376 | } 377 | // only search skiplist 378 | if zobj.ObjectType == OBJ_ZSET { 379 | zs := zobj.Ptr.(*zSet) 380 | dict := zs.dict 381 | de := dictFind(dict, member) 382 | 383 | if de == nil { 384 | return C_ERR 385 | } 386 | value := de.Ptr.(float64) 387 | *score = value 388 | } else { 389 | panic("Unknown sorted set encoding") 390 | } 391 | return C_OK 392 | } 393 | -------------------------------------------------------------------------------- /core/proto/proto.go: -------------------------------------------------------------------------------- 1 | package proto 2 | 3 | import ( 4 | "bytes" 5 | "godis/util/bufio2" 6 | "io" 7 | "log" 8 | "strconv" 9 | 10 | "errors" 11 | ) 12 | 13 | var ( 14 | ErrBadArrayLen = errors.New("bad array len") 15 | ErrBadArrayLenTooLong = errors.New("bad array len, too long") 16 | 17 | ErrBadBulkBytesLen = errors.New("bad bulk bytes len") 18 | ErrBadBulkBytesLenTooLong = errors.New("bad bulk bytes len, too long") 19 | 20 | ErrBadMultiBulkLen = errors.New("bad multi-bulk len") 21 | ErrBadMultiBulkContent = errors.New("bad multi-bulk content, should be bulkbytes") 22 | ) 23 | 24 | const ( 25 | // MaxBulkBytesLen 最大长度 26 | MaxBulkBytesLen = 1024 * 1024 * 512 27 | // MaxArrayLen 最大长度 28 | MaxArrayLen = 1024 * 1024 29 | ) 30 | 31 | type RespType byte 32 | 33 | const ( 34 | TypeString = '+' 35 | TypeError = '-' 36 | TypeInt = ':' 37 | TypeBulkBytes = '$' 38 | TypeArray = '*' 39 | ) 40 | 41 | // Btoi64 byte to int64 42 | func Btoi64(b []byte) (int64, error) { 43 | if len(b) != 0 && len(b) < 10 { 44 | var neg, i = false, 0 45 | switch b[0] { 46 | case '-': 47 | neg = true 48 | fallthrough 49 | case '+': 50 | i++ 51 | } 52 | if len(b) != i { 53 | var n int64 54 | for ; i < len(b) && b[i] >= '0' && b[i] <= '9'; i++ { 55 | n = int64(b[i]-'0') + n*10 56 | } 57 | if len(b) == i { 58 | if neg { 59 | n = -n 60 | } 61 | return n, nil 62 | } 63 | } 64 | } 65 | 66 | if n, err := strconv.ParseInt(string(b), 10, 64); err != nil { 67 | return 0, errorsTrace(err) 68 | } else { 69 | return n, nil 70 | } 71 | } 72 | 73 | /*---- Encoder ----*/ 74 | 75 | type Encoder struct { 76 | bw *bufio2.Writer 77 | 78 | Err error 79 | } 80 | 81 | // NewEncoder 82 | func NewEncoder(w io.Writer) *Encoder { 83 | return NewEncoderBuffer(bufio2.NewWriterSize(w, 8192)) 84 | } 85 | 86 | // NewEncoderSize new encoder by size 87 | func NewEncoderSize(w io.Writer, size int) *Encoder { 88 | return NewEncoderBuffer(bufio2.NewWriterSize(w, size)) 89 | } 90 | 91 | // NewEncoderBuffer new encoder by bufWriter 92 | func NewEncoderBuffer(bw *bufio2.Writer) *Encoder { 93 | return &Encoder{bw: bw} 94 | } 95 | 96 | // Encode 转换为协议 97 | func (e *Encoder) Encode(r *Resp, flush bool) error { 98 | if e.Err != nil { 99 | return errorsTrace(e.Err) 100 | } 101 | if err := e.encodeResp(r); err != nil { 102 | e.Err = err 103 | } else if flush { 104 | e.Err = errorsTrace(e.bw.Flush()) 105 | } 106 | return e.Err 107 | } 108 | 109 | // EncodeCmd 命令行编码协议 110 | func EncodeCmd(cmd string) ([]byte, error) { 111 | return EncodeBytes([]byte(cmd)) 112 | } 113 | 114 | // EncodeBytes Bytes编码协议 115 | func EncodeBytes(b []byte) ([]byte, error) { 116 | r := bytes.Split(b, []byte(" ")) 117 | if r == nil { 118 | return nil, errorsTrace(errorNew("empty split")) 119 | } 120 | resp := NewArray(nil) 121 | for _, v := range r { 122 | if len(v) > 0 { 123 | resp.Array = append(resp.Array, NewBulkBytes(v)) 124 | } 125 | } 126 | return EncodeToBytes(resp) 127 | } 128 | 129 | // EncodeMultiBulk encode 多条批量回复 130 | func (e *Encoder) EncodeMultiBulk(multi []*Resp, flush bool) error { 131 | if e.Err != nil { 132 | return errorsTrace(e.Err) 133 | } 134 | if err := e.encodeMultiBulk(multi); err != nil { 135 | e.Err = err 136 | } else if flush { 137 | e.Err = errorsTrace(e.Err) 138 | } 139 | return e.Err 140 | } 141 | 142 | // Flush buf to writer 143 | func (e *Encoder) Flush() error { 144 | if e.Err != nil { 145 | return errorsTrace(errorNew("Flush error")) 146 | } 147 | if err := e.bw.Flush(); err != nil { 148 | e.Err = errorsTrace(errorNew("bw.Flush error")) 149 | } 150 | return e.Err 151 | } 152 | 153 | // Encode 转换为协议接口 154 | func Encode(w io.Writer, r *Resp) error { 155 | return NewEncoder(w).Encode(r, true) 156 | } 157 | 158 | // EncodeToBytes Resp编码协议 159 | func EncodeToBytes(r *Resp) ([]byte, error) { 160 | var b = &bytes.Buffer{} 161 | if err := Encode(b, r); err != nil { 162 | return nil, err 163 | } 164 | return b.Bytes(), nil 165 | } 166 | 167 | // encodeResp 编码 168 | func (e *Encoder) encodeResp(r *Resp) error { 169 | if err := e.bw.WriteByte(byte(r.Type)); err != nil { 170 | return errorsTrace(err) 171 | } 172 | switch r.Type { 173 | case TypeString, TypeError, TypeInt: 174 | return e.encodeTextBytes(r.Value) 175 | case TypeBulkBytes: 176 | return e.encodeBulkBytes(r.Value) 177 | case TypeArray: 178 | return e.encodeArray(r.Array) 179 | default: 180 | return errorsTrace(e.Err) 181 | } 182 | } 183 | 184 | // encodeMultiBulk encode 多条批量回复 185 | func (e *Encoder) encodeMultiBulk(multi []*Resp) error { 186 | if err := e.bw.WriteByte(byte(TypeArray)); err != nil { 187 | return errorsTrace(err) 188 | } 189 | return e.encodeArray(multi) 190 | } 191 | 192 | // encodeTextBytes encode text type 193 | func (e *Encoder) encodeTextBytes(b []byte) error { 194 | if _, err := e.bw.Write(b); err != nil { 195 | return errorsTrace(err) 196 | } 197 | if _, err := e.bw.WriteString("\r\n"); err != nil { 198 | return errorsTrace(err) 199 | } 200 | return nil 201 | } 202 | 203 | // encode text type 204 | func (e *Encoder) encodeTextString(s string) error { 205 | if _, err := e.bw.WriteString(s); err != nil { 206 | return errorsTrace(err) 207 | } 208 | if _, err := e.bw.WriteString("\r\n"); err != nil { 209 | return errorsTrace(err) 210 | } 211 | return nil 212 | } 213 | 214 | // encodeInt encode整数 215 | func (e *Encoder) encodeInt(v int64) error { 216 | return e.encodeTextString(strconv.FormatInt(v, 10)) 217 | } 218 | 219 | // encodeBulkBytes 批量回复 220 | func (e *Encoder) encodeBulkBytes(b []byte) error { 221 | if b == nil { 222 | return e.encodeInt(-1) 223 | } else { 224 | if err := e.encodeInt(int64(len(b))); err != nil { 225 | return err 226 | } 227 | return e.encodeTextBytes(b) 228 | } 229 | } 230 | 231 | // encodeArray encode 多条批量回复 232 | func (e *Encoder) encodeArray(array []*Resp) error { 233 | if array == nil { 234 | return e.encodeInt(-1) 235 | } else { 236 | if err := e.encodeInt(int64(len(array))); err != nil { 237 | return err 238 | } 239 | for _, r := range array { 240 | if err := e.encodeResp(r); err != nil { 241 | return err 242 | } 243 | } 244 | return nil 245 | } 246 | } 247 | 248 | /*---- Decoder ----*/ 249 | type Decoder struct { 250 | br *bufio2.Reader 251 | 252 | Err error 253 | } 254 | 255 | // NewDecoder 256 | func NewDecoder(r io.Reader) *Decoder { 257 | return NewDecoderBuffer(bufio2.NewReaderSize(r, 8192)) 258 | } 259 | 260 | // NewDecoderSize by size 261 | func NewDecoderSize(r io.Reader, size int) *Decoder { 262 | return NewDecoderBuffer(bufio2.NewReaderSize(r, size)) 263 | } 264 | 265 | // NewDecoderBuffer by bufReader 266 | func NewDecoderBuffer(br *bufio2.Reader) *Decoder { 267 | return &Decoder{br: br} 268 | } 269 | 270 | // Decode 271 | func (d *Decoder) Decode() (*Resp, error) { 272 | if d.Err != nil { 273 | return nil, errorsTrace(errorNew("Decode err")) 274 | } 275 | r, err := d.decodeResp() 276 | if err != nil { 277 | d.Err = err 278 | } 279 | return r, d.Err 280 | } 281 | 282 | // DecodeMultiBulk decode批量回复 283 | func (d *Decoder) DecodeMultiBulk() ([]*Resp, error) { 284 | if d.Err != nil { 285 | return nil, errorsTrace(errorNew("DecodeMultibulk error")) 286 | } 287 | m, err := d.decodeMultiBulk() 288 | if err != nil { 289 | d.Err = err 290 | } 291 | return m, err 292 | } 293 | 294 | // Decode api 295 | func Decode(r io.Reader) (*Resp, error) { 296 | return NewDecoder(r).Decode() 297 | } 298 | 299 | // DecodeFromBytes bytes to resp 300 | func DecodeFromBytes(p []byte) (*Resp, error) { 301 | return NewDecoder(bytes.NewReader(p)).Decode() 302 | } 303 | 304 | // DecodeMultiBulkFromBytes format multibulk 305 | func DecodeMultiBulkFromBytes(p []byte) ([]*Resp, error) { 306 | return NewDecoder(bytes.NewReader(p)).DecodeMultiBulk() 307 | } 308 | 309 | // decodeResp 根据返回类型调用不同解析实现 310 | func (d *Decoder) decodeResp() (*Resp, error) { 311 | b, err := d.br.ReadByte() 312 | if err != nil { 313 | return nil, errorsTrace(err) 314 | } 315 | r := &Resp{} 316 | r.Type = byte(b) 317 | switch r.Type { 318 | default: 319 | return nil, errorsTrace(err) 320 | case TypeString, TypeError, TypeInt: 321 | r.Value, err = d.decodeTextBytes() 322 | case TypeBulkBytes: 323 | r.Value, err = d.decodeBulkBytes() 324 | case TypeArray: 325 | r.Array, err = d.decodeArray() 326 | } 327 | return r, err 328 | } 329 | 330 | // decodeTextBytes decode文本 331 | func (d *Decoder) decodeTextBytes() ([]byte, error) { 332 | b, err := d.br.ReadBytes('\n') 333 | if err != nil { 334 | return nil, errorsTrace(err) 335 | } 336 | if n := len(b) - 2; n < 0 || b[n] != '\r' { 337 | return nil, errorsTrace(err) 338 | } else { 339 | return b[:n], nil 340 | } 341 | } 342 | 343 | // decodeInt decode int 344 | func (d *Decoder) decodeInt() (int64, error) { 345 | b, err := d.br.ReadSlice('\n') 346 | if err != nil { 347 | return 0, errorsTrace(err) 348 | } 349 | if n := len(b) - 2; n < 0 || b[n] != '\r' { 350 | return 0, errorsTrace(err) 351 | } else { 352 | return Btoi64(b[:n]) 353 | } 354 | } 355 | 356 | // decodeBulkBytes decode 批量回复 357 | func (d *Decoder) decodeBulkBytes() ([]byte, error) { 358 | n, err := d.decodeInt() 359 | if err != nil { 360 | return nil, err 361 | } 362 | switch { 363 | case n < -1: 364 | return nil, errorsTrace(err) 365 | case n > MaxBulkBytesLen: 366 | return nil, errorsTrace(err) 367 | case n == -1: 368 | return nil, nil 369 | } 370 | b, err := d.br.ReadFull(int(n) + 2) 371 | if err != nil { 372 | return nil, errorsTrace(err) 373 | } 374 | if b[n] != '\r' || b[n+1] != '\n' { 375 | return nil, errorsTrace(err) 376 | } 377 | return b[:n], nil 378 | } 379 | 380 | // decodeArray decode 多条批量回复 381 | func (d *Decoder) decodeArray() ([]*Resp, error) { 382 | n, err := d.decodeInt() 383 | if err != nil { 384 | return nil, err 385 | } 386 | switch { 387 | case n < -1: 388 | return nil, errorsTrace(err) 389 | case n > MaxArrayLen: 390 | return nil, errorsTrace(err) 391 | case n == -1: 392 | return nil, nil 393 | } 394 | array := make([]*Resp, n) 395 | for i := range array { 396 | r, err := d.decodeResp() 397 | if err != nil { 398 | return nil, err 399 | } 400 | array[i] = r 401 | } 402 | return array, nil 403 | } 404 | 405 | func (d *Decoder) decodeSingleLineMultiBulk() ([]*Resp, error) { 406 | b, err := d.decodeTextBytes() 407 | if err != nil { 408 | return nil, err 409 | } 410 | multi := make([]*Resp, 0, 8) 411 | for l, r := 0, 0; r <= len(b); r++ { 412 | if r == len(b) || b[r] == ' ' { 413 | if l < r { 414 | multi = append(multi, NewBulkBytes(b[l:r])) 415 | } 416 | l = r + 1 417 | } 418 | } 419 | if len(multi) == 0 { 420 | return nil, errorsTrace(err) 421 | } 422 | return multi, nil 423 | } 424 | 425 | func (d *Decoder) decodeMultiBulk() ([]*Resp, error) { 426 | b, err := d.br.PeekByte() 427 | if err != nil { 428 | return nil, errorsTrace(err) 429 | } 430 | if RespType(b) != TypeArray { 431 | return d.decodeSingleLineMultiBulk() 432 | } 433 | if _, err := d.br.ReadByte(); err != nil { 434 | return nil, errorsTrace(err) 435 | } 436 | n, err := d.decodeInt() 437 | 438 | if err != nil { 439 | return nil, errorsTrace(err) 440 | } 441 | switch { 442 | case n <= 0: 443 | return nil, errorsTrace(ErrBadArrayLen) 444 | case n > MaxArrayLen: 445 | return nil, errorsTrace(ErrBadArrayLenTooLong) 446 | } 447 | multi := make([]*Resp, n) 448 | for i := range multi { 449 | r, err := d.decodeResp() 450 | if err != nil { 451 | return nil, err 452 | } 453 | if r.Type != TypeBulkBytes { 454 | return nil, errorsTrace(ErrBadMultiBulkContent) 455 | } 456 | multi[i] = r 457 | } 458 | return multi, nil 459 | } 460 | 461 | /*---- Response ----*/ 462 | type Resp struct { 463 | Type byte 464 | 465 | Value []byte 466 | Array []*Resp 467 | } 468 | 469 | func NewString(value []byte) *Resp { 470 | r := &Resp{} 471 | r.Type = TypeString 472 | r.Value = value 473 | return r 474 | } 475 | 476 | func NewError(value []byte) *Resp { 477 | r := &Resp{} 478 | r.Type = TypeError 479 | r.Value = value 480 | return r 481 | } 482 | 483 | func NewInt(value []byte) *Resp { 484 | r := &Resp{} 485 | r.Type = TypeInt 486 | r.Value = value 487 | return r 488 | } 489 | 490 | // NewBulkBytes 批量回复类型 491 | func NewBulkBytes(value []byte) *Resp { 492 | r := &Resp{} 493 | r.Type = TypeBulkBytes 494 | r.Value = value 495 | return r 496 | } 497 | 498 | // NewArray 多条批量回复类型 499 | func NewArray(array []*Resp) *Resp { 500 | r := &Resp{} 501 | r.Type = TypeArray 502 | r.Array = array 503 | return r 504 | } 505 | func errorsTrace(err error) error { 506 | if err != nil { 507 | log.Println("errors Tracing", err.Error()) 508 | } 509 | return err 510 | } 511 | 512 | func errorNew(msg string) error { 513 | return errors.New("error occur, msg ") 514 | } 515 | -------------------------------------------------------------------------------- /core/geo.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "strconv" 7 | "strings" 8 | ) 9 | 10 | const RADIUS_COORDS = (1 << 0) /* Search around coordinates. */ 11 | const RADIUS_MEMBER = (1 << 1) /* Search around member. */ 12 | const RADIUS_NOSTORE = (1 << 2) 13 | 14 | const SORT_NONE = 0 15 | const SORT_ASC = 1 16 | const SORT_DESC = 2 17 | 18 | // geoaddCommand 命令实现 19 | func GeoAddCommand(c *Client, s *Server) { 20 | // check params numbers 21 | if (c.Argc-2)%3 != 0 { 22 | /* Need an odd number of arguments if we got this far... */ 23 | addReplyError(c, "syntax error. Try GEOADD key [x1] [y1] [name1] "+ 24 | "[x2] [y2] [name2] ... ") 25 | } 26 | 27 | elements := (c.Argc - 2) / 3 //坐标数 28 | argc := 2 + elements*2 /* ZADD key score ele ... */ 29 | argv := make([]*GodisObject, argc) 30 | argv[0] = CreateObject(ObjectTypeString, "zadd") 31 | argv[1] = c.Argv[1] 32 | 33 | for i := 0; i < elements; i++ { 34 | var xy [2]float64 35 | var hash GeoHashBits 36 | //提取经纬度 37 | if lngObj, ok1 := c.Argv[i*3+2].Ptr.(string); ok1 { 38 | if latObj, ok2 := c.Argv[i*3+3].Ptr.(string); ok2 { 39 | var ok error 40 | xy[0], ok = strconv.ParseFloat(lngObj, 64) 41 | xy[1], ok = strconv.ParseFloat(latObj, 64) 42 | if ok != nil { 43 | addReplyError(c, "lng lat type error") 44 | os.Exit(0) 45 | } 46 | } 47 | } 48 | geohashEncodeWGS84(xy[0], xy[1], GEO_STEP_MAX, &hash) 49 | bits := geohashAlign52Bits(hash) 50 | score := CreateObject(ObjectTypeString, bits) 51 | 52 | val := c.Argv[2+i*3+2] 53 | argv[2+i*2] = score // 设置有序集合元素的分值和名字 54 | argv[3+i*2] = val 55 | } 56 | c.Argc = argc 57 | c.Argv = argv 58 | zaddCommand(c) 59 | 60 | addReplyStatus(c, "OK") 61 | } 62 | 63 | //获取特定位置的hash值 64 | func GeoHashCommand(c *Client, s *Server) { 65 | geoAlphabet := "0123456789bcdefghjkmnpqrstuvwxyz" 66 | zobj := lookupKey(c.Db, c.Argv[1]) 67 | if zobj != nil && zobj.ObjectType != OBJ_ZSET { 68 | return 69 | } 70 | buf := "" 71 | for j := 2; j < c.Argc; j++ { 72 | var score float64 73 | if zobj == nil || zsetScore(zobj, c.Argv[j].Ptr.(string), &score) == C_ERR { 74 | addReplyError(c, "score get error ") 75 | return 76 | } 77 | var xy [2]float64 78 | if !decodeGeohash(score, &xy) { 79 | addReplyError(c, "hash get error") 80 | continue 81 | } 82 | r := [2]GeoHashRange{} 83 | var hash GeoHashBits 84 | r[0].min = -180 85 | r[0].max = 180 86 | r[1].min = -90 87 | r[1].max = 90 88 | geohashEncode(&r[0], &r[1], xy[0], xy[1], 26, &hash) 89 | 90 | temp := "" 91 | for i := 0; i < 11; i++ { 92 | count := 52 - (i+1)*5 93 | idx := (hash.bits >> (uint(count))) & 0x1f 94 | temp += string(geoAlphabet[idx]) 95 | } 96 | buf += temp 97 | buf += ";" 98 | } 99 | addReplyStatus(c, buf) 100 | } 101 | 102 | //获取经纬度 103 | func GeoPosCommand(c *Client, s *Server) { 104 | zobj := lookupKey(c.Db, c.Argv[1]) 105 | if zobj != nil && zobj.ObjectType != OBJ_ZSET { 106 | return 107 | } 108 | buf := "lng:" 109 | 110 | for j := 2; j < c.Argc; j++ { 111 | var score float64 112 | if zobj == nil || zsetScore(zobj, c.Argv[j].Ptr.(string), &score) == C_ERR { 113 | addReplyError(c, "score get error ") 114 | return 115 | } 116 | var xy [2]float64 117 | if !decodeGeohash(score, &xy) { 118 | addReplyError(c, "hash get error") 119 | continue 120 | } 121 | 122 | buf += fmt.Sprint(xy[0]) 123 | buf += ",lat:" 124 | buf += fmt.Sprint(xy[1]) 125 | buf += ";" 126 | } 127 | addReplyStatus(c, buf) 128 | } 129 | 130 | //获取两个位置的距离 131 | func GeoDistCommand(c *Client, s *Server) { 132 | if c.Argc >= 5 { 133 | addReplyError(c, "params error") 134 | return 135 | } 136 | zobj := lookupKey(c.Db, c.Argv[1]) 137 | if zobj != nil && zobj.ObjectType != OBJ_ZSET { 138 | return 139 | } 140 | 141 | var score1, score2 float64 142 | var xyxy1, xyxy2 [2]float64 143 | if zsetScore(zobj, c.Argv[2].Ptr.(string), &score1) == C_ERR || 144 | zsetScore(zobj, c.Argv[3].Ptr.(string), &score2) == C_ERR { 145 | addReplyError(c, "score get error ") 146 | return 147 | } 148 | 149 | if !decodeGeohash(score1, &xyxy1) || !decodeGeohash(score2, &xyxy2) { 150 | addReplyError(c, "hash get error") 151 | return 152 | } 153 | 154 | buf := geohashGetDistance(xyxy1[0], xyxy1[1], xyxy2[0], xyxy2[1]) 155 | addReplyStatus(c, fmt.Sprint(buf)) 156 | } 157 | 158 | func GeoRadiusCommand(c *Client, s *Server) { 159 | georadiusGeneric(c, RADIUS_COORDS) 160 | } 161 | 162 | func GeoRadiusByMemberCommand(c *Client, s *Server) { 163 | georadiusGeneric(c, RADIUS_MEMBER) 164 | } 165 | 166 | //georadius Sicily 15 37 100 km 167 | func georadiusGeneric(c *Client, flags uint) { 168 | var storekey *GodisObject 169 | storedist := 0 /* 0 for STORE, 1 for STOREDIST. */ 170 | 171 | //获取有序集合 172 | zobj := lookupKey(c.Db, c.Argv[1]) 173 | if zobj != nil && zobj.ObjectType != OBJ_ZSET { 174 | return 175 | } 176 | 177 | var xy [2]float64 178 | var base_args int 179 | if flags&RADIUS_COORDS > 0 { 180 | base_args = 6 181 | arg2, ok1 := c.Argv[2].Ptr.(string) 182 | arg3, ok2 := c.Argv[3].Ptr.(string) 183 | if !ok1 || !ok2 { 184 | addReplyError(c, "get lng lat error") 185 | return 186 | } 187 | 188 | var err error 189 | xy[0], err = strconv.ParseFloat(arg2, 64) 190 | xy[1], err = strconv.ParseFloat(arg3, 64) 191 | if err != nil { 192 | addReplyError(c, "get lng lat float error") 193 | return 194 | } 195 | } else if flags&RADIUS_MEMBER > 0 { 196 | //member command 197 | base_args = 7 198 | } else { 199 | addReplyError(c, "Unknown georadius search type") 200 | return 201 | } 202 | 203 | //获取参数单位 204 | conversion := extractUnitOrReply(c, *c.Argv[base_args-1]) 205 | radius_meters, err := strconv.ParseFloat(c.Argv[base_args-2].Ptr.(string), 64) 206 | if err != nil { 207 | addReplyError(c, "radius_meters error") 208 | return 209 | } 210 | radius_meters = radius_meters * conversion 211 | 212 | // 提取所有可选参数 213 | withdist := 0 214 | withhash := 0 215 | withcoords := 0 216 | sort := SORT_NONE 217 | var count int64 = 0 218 | if c.Argc > base_args { 219 | remaining := c.Argc - base_args 220 | for i := 0; i < remaining; i++ { 221 | arg := c.Argv[base_args+i].Ptr.(string) 222 | if strings.EqualFold(arg, "withdist") { 223 | withdist = 1 224 | } else if strings.EqualFold(arg, "withhash") { 225 | withhash = 1 226 | } else if strings.EqualFold(arg, "withcoord") { 227 | withcoords = 1 228 | } else if strings.EqualFold(arg, "asc") { 229 | sort = SORT_ASC 230 | } else if strings.EqualFold(arg, "desc") { 231 | sort = SORT_DESC 232 | } else if strings.EqualFold(arg, "count") && (i+1) < remaining { 233 | 234 | if count < 0 { 235 | addReplyError(c, "COUNT must be > 0") 236 | return 237 | } 238 | i++ 239 | } else if strings.EqualFold(arg, "store") && (i+1) < remaining && (flags&RADIUS_NOSTORE == 0) { 240 | storekey = c.Argv[base_args+i+1] 241 | storedist = 0 242 | i++ 243 | } else if strings.EqualFold(arg, "storedist") && (i+1) < remaining && (flags&RADIUS_NOSTORE == 0) { 244 | storekey = c.Argv[base_args+i+1] 245 | storedist = 1 246 | i++ 247 | } else { 248 | addReplyError(c, "params error") 249 | return 250 | } 251 | } 252 | } 253 | 254 | if storekey != nil && (withdist > 0 || withhash > 0 || withcoords > 0) { 255 | addReplyError(c, 256 | "STORE option in GEORADIUS is not compatible with "+ 257 | "WITHDIST, WITHHASH and WITHCOORDS options") 258 | return 259 | } 260 | 261 | // 指定排序方式 262 | if count != 0 && sort == SORT_NONE { 263 | sort = SORT_ASC 264 | } 265 | 266 | // 定位中心点所处的范围 267 | georadius := geohashGetAreasByRadiusWGS84(xy[0], xy[1], radius_meters) 268 | 269 | /* Search the zset for all matching points */ 270 | ga := geoArrayCreate() // 对中心点以及它的八个方向进行查找,找出所有范围内的元素 271 | membersOfAllNeighbors(zobj, georadius, xy[0], xy[1], radius_meters, ga) 272 | 273 | if ga.used == 0 && storekey == nil { 274 | addReplyError(c, "emptymultibulk") 275 | return 276 | } 277 | 278 | result_length := ga.used 279 | var returned_items int 280 | if count == 0 || int64(result_length) < count { 281 | returned_items = int(result_length) 282 | } else { 283 | returned_items = int(count) 284 | } 285 | option_length := 0 286 | 287 | if sort == SORT_ASC { 288 | 289 | } else if sort == SORT_DESC { 290 | 291 | } 292 | 293 | if storekey == nil { 294 | if withdist > 0 { 295 | option_length++ 296 | } 297 | if withcoords > 0 { 298 | option_length++ 299 | } 300 | if withhash > 0 { 301 | option_length++ 302 | } 303 | 304 | /* Finally send results back to the caller */ 305 | for i := 0; i < returned_items; i++ { 306 | gp := ga.array[i] 307 | gp.dist /= conversion 308 | fmt.Println(gp) 309 | addReplyStatus(c, gp.member) 310 | } 311 | 312 | } else { 313 | fmt.Println(storedist) 314 | } 315 | 316 | } 317 | 318 | func geoArrayCreate() *geoArray { 319 | ga := new(geoArray) 320 | ga.array = make([]*geoPoint, 0) 321 | ga.buckets = 0 322 | ga.used = 0 323 | return ga 324 | } 325 | 326 | //单位 327 | func extractUnitOrReply(c *Client, uint GodisObject) float64 { 328 | u := uint.Ptr.(string) 329 | 330 | if strings.Compare(u, "m") == 0 { 331 | return 1 332 | } else if strings.Compare(u, "km") == 0 { 333 | return 1000 334 | } else if strings.Compare(u, "ft") == 0 { 335 | return 0.3048 336 | } else if strings.Compare(u, "mi") == 0 { 337 | return 1609.34 338 | } else { 339 | addReplyError(c, "unsupported unit provided. please use m, km, ft, mi") 340 | return -1 341 | } 342 | } 343 | 344 | func membersOfAllNeighbors(zobj *GodisObject, n GeoHashRadius, lon float64, lat float64, radius float64, ga *geoArray) int { 345 | neighbors := [9]GeoHashBits{} 346 | var count, last_processed int 347 | debugmsg := 0 348 | 349 | neighbors[0] = n.hash 350 | neighbors[1] = n.neighbors.north 351 | neighbors[2] = n.neighbors.south 352 | neighbors[3] = n.neighbors.east 353 | neighbors[4] = n.neighbors.west 354 | neighbors[5] = n.neighbors.north_east 355 | neighbors[6] = n.neighbors.north_west 356 | neighbors[7] = n.neighbors.south_east 357 | neighbors[8] = n.neighbors.south_west 358 | 359 | for i := 0; i < len(neighbors); i++ { 360 | if hashIsZero(neighbors[i]) { 361 | continue 362 | } 363 | 364 | /* Debugging info. */ 365 | if debugmsg > 0 { 366 | var long_range, lat_range GeoHashRange 367 | geohashGetCoordRange(&long_range, &lat_range) 368 | myarea := new(GeoHashArea) 369 | geohashDecode(long_range, lat_range, neighbors[i], myarea) 370 | 371 | /* Dump center square. */ 372 | fmt.Println("neighbors[%d]:\n", i) 373 | fmt.Println("area.longitude.min: %f\n", myarea.longitude.min) 374 | fmt.Println("area.longitude.max: %f\n", myarea.longitude.max) 375 | fmt.Println("area.latitude.min: %f\n", myarea.latitude.min) 376 | fmt.Println("area.latitude.max: %f\n", myarea.latitude.max) 377 | } 378 | 379 | /* When a huge Radius (in the 5000 km range or more) is used, 380 | * adjacent neighbors can be the same, leading to duplicated 381 | * elements. Skip every range which is the same as the one 382 | * processed previously. */ 383 | if last_processed > 0 && 384 | neighbors[i].bits == neighbors[last_processed].bits && 385 | neighbors[i].step == neighbors[last_processed].step { 386 | if debugmsg > 0 { 387 | fmt.Println("Skipping processing of %d, same as previous\n", i) 388 | } 389 | continue 390 | } 391 | count += membersOfGeoHashBox(zobj, neighbors[i], ga, lon, lat, radius) 392 | last_processed = i 393 | } 394 | return count 395 | } 396 | 397 | func membersOfGeoHashBox(zobj *GodisObject, hash GeoHashBits, ga *geoArray, lon float64, lat float64, radius float64) int { 398 | var min, max GeoHashFix52Bits 399 | 400 | scoresOfGeoHashBox(hash, &min, &max) 401 | return geoGetPointsInRange(zobj, float64(min), float64(max), lon, lat, radius, ga) 402 | } 403 | 404 | func scoresOfGeoHashBox(hash GeoHashBits, min *GeoHashFix52Bits, max *GeoHashFix52Bits) { 405 | *min = geohashAlign52Bits(hash) 406 | hash.bits++ 407 | *max = geohashAlign52Bits(hash) 408 | } 409 | 410 | func geoGetPointsInRange(zobj *GodisObject, min float64, max float64, lon float64, lat float64, radius float64, ga *geoArray) int { 411 | zrange := zRangeSpec{min: min, max: max, minEx: 0, maxEx: 1} 412 | var origincount uint = ga.used 413 | //var member string 414 | if zobj.ObjectType == OBJ_ZSET { 415 | zs := zobj.Ptr.(*zSet) //使用*zSet好,还是zSet 416 | zsl := zs.zsl 417 | var ln *zSkipListNode 418 | 419 | ln = zslFirstInRange(zsl, &zrange) 420 | if ln == nil { 421 | return 0 422 | } 423 | 424 | for ln != nil { 425 | ele := ln.ele 426 | if !zslValueLteMax(ln.score, &zrange) { 427 | break 428 | } 429 | geoAppendIfWithinRadius(ga, lon, lat, radius, ln.score, ele) 430 | ln = ln.level[0].forward 431 | } 432 | } else { 433 | //ziplist 434 | } 435 | return int(ga.used - origincount) 436 | } 437 | 438 | func geoAppendIfWithinRadius(ga *geoArray, lon float64, lat float64, radius float64, score float64, member string) int { 439 | var distance float64 440 | xy := [2]float64{} 441 | 442 | if !decodeGeohash(score, &xy) { 443 | return C_ERR 444 | } 445 | if !geohashGetDistanceIfInRadiusWGS84(lon, lat, xy[0], xy[1], radius, &distance) { 446 | return C_ERR 447 | } 448 | 449 | gp := geoArrayAppend(ga) 450 | gp.longitude = xy[0] 451 | gp.latitude = xy[1] 452 | gp.dist = distance 453 | gp.member = member 454 | gp.score = score 455 | return C_OK 456 | } 457 | 458 | func geoArrayAppend(ga *geoArray) *geoPoint { 459 | if ga.used == ga.buckets { 460 | if ga.buckets == 0 { 461 | ga.buckets = 8 462 | } else { 463 | ga.buckets = ga.buckets * 2 464 | } 465 | } 466 | gp := new(geoPoint) 467 | ga.array = append(ga.array, gp) 468 | ga.used++ 469 | return gp 470 | } 471 | -------------------------------------------------------------------------------- /core/geohash.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import "math" 4 | 5 | const GEO_STEP_MAX = 26 /* 26*2 = 52 bits. */ 6 | 7 | /* Limits from EPSG:900913 / EPSG:3785 / OSGEO:41001 */ 8 | const GEO_LAT_MIN = -85.05112878 9 | const GEO_LAT_MAX = 85.05112878 10 | const GEO_LONG_MIN = -180 11 | const GEO_LONG_MAX = 180 12 | const D_R = (math.Pi / 180.0) 13 | 14 | /// @brief Earth's quatratic mean radius for WGS-84 15 | const EARTH_RADIUS_IN_METERS = 6372797.560856 16 | 17 | const MERCATOR_MAX = 20037726.37 18 | const MERCATOR_MIN = -20037726.37 19 | 20 | type GeoHashFix52Bits = uint64 21 | type GeoHashVarBits = uint64 22 | 23 | type GeoHashBits struct { 24 | bits uint64 25 | step uint8 26 | } 27 | 28 | type GeoHashRange struct { 29 | min float64 30 | max float64 31 | } 32 | 33 | type GeoHashArea struct { 34 | hash GeoHashBits 35 | longitude GeoHashRange 36 | latitude GeoHashRange 37 | } 38 | 39 | type GeoHashRadius struct { 40 | hash GeoHashBits 41 | area GeoHashArea 42 | neighbors GeoHashNeighbors 43 | } 44 | type GeoHashNeighbors struct { 45 | north GeoHashBits 46 | east GeoHashBits 47 | west GeoHashBits 48 | south GeoHashBits 49 | north_east GeoHashBits 50 | south_east GeoHashBits 51 | north_west GeoHashBits 52 | south_west GeoHashBits 53 | } 54 | 55 | type geoArray struct { 56 | array []*geoPoint 57 | buckets uint 58 | used uint 59 | } 60 | 61 | type geoPoint struct { 62 | longitude float64 63 | latitude float64 64 | dist float64 65 | score float64 66 | member string 67 | } 68 | 69 | func deg_rad(ang float64) float64 { 70 | return ang * D_R 71 | } 72 | func rad_deg(ang float64) float64 { 73 | return ang / D_R 74 | } 75 | 76 | func geohashEncodeWGS84(longitude float64, latitude float64, step uint8, hash *GeoHashBits) int { 77 | return geohashEncodeType(longitude, latitude, step, hash) 78 | } 79 | 80 | func geohashEncodeType(longitude float64, latitude float64, step uint8, hash *GeoHashBits) int { 81 | r := [2]GeoHashRange{} 82 | geohashGetCoordRange(&r[0], &r[1]) 83 | return geohashEncode(&r[0], &r[1], longitude, latitude, step, hash) 84 | } 85 | 86 | /* These are constraints from EPSG:900913 / EPSG:3785 / OSGEO:41001 */ 87 | /* We can't geocode at the north/south pole. */ 88 | func geohashGetCoordRange(long_range *GeoHashRange, lat_range *GeoHashRange) { 89 | long_range.max = GEO_LONG_MAX 90 | long_range.min = GEO_LONG_MIN 91 | lat_range.max = GEO_LAT_MAX 92 | lat_range.min = GEO_LAT_MIN 93 | } 94 | 95 | func geohashEncode(long_range *GeoHashRange, lat_range *GeoHashRange, longitude float64, latitude float64, step uint8, 96 | hash *GeoHashBits) int { 97 | /* Check basic arguments sanity. */ 98 | 99 | /* Return an error when trying to index outside the supported 100 | * constraints. */ 101 | if longitude > 180 || longitude < -180 || 102 | latitude > 85.05112878 || latitude < -85.05112878 { 103 | return 0 104 | } 105 | 106 | hash.bits = 0 107 | hash.step = step 108 | 109 | if latitude < lat_range.min || latitude > lat_range.max || 110 | longitude < long_range.min || longitude > long_range.max { 111 | return 0 112 | } 113 | 114 | var lat_offset float64 115 | var long_offset float64 116 | lat_offset = 117 | (latitude - lat_range.min) / (lat_range.max - lat_range.min) 118 | long_offset = 119 | (longitude - long_range.min) / (long_range.max - long_range.min) 120 | 121 | /* convert to fixed point based on the step size */ 122 | mask := 1 << step 123 | lat_offset = lat_offset * float64(mask) 124 | long_offset = long_offset * float64(mask) 125 | hash.bits = interleave64(int32(lat_offset), int32(long_offset)) 126 | return 1 127 | } 128 | 129 | /* 130 | lat 放在偶数位,lng放在奇数位 131 | */ 132 | func interleave64(latOffset int32, lngOffset int32) uint64 { 133 | B := []uint64{0x5555555555555555, 0x3333333333333333, 134 | 0x0F0F0F0F0F0F0F0F, 0x00FF00FF00FF00FF, 135 | 0x0000FFFF0000FFFF} 136 | S := []uint8{1, 2, 4, 8, 16} 137 | x := uint64(latOffset) 138 | y := uint64(lngOffset) 139 | x = (x | (x << S[4])) & B[4] 140 | y = (y | (y << S[4])) & B[4] 141 | x = (x | (x << S[3])) & B[3] 142 | y = (y | (y << S[3])) & B[3] 143 | x = (x | (x << S[2])) & B[2] 144 | y = (y | (y << S[2])) & B[2] 145 | x = (x | (x << S[1])) & B[1] 146 | y = (y | (y << S[1])) & B[1] 147 | x = (x | (x << S[0])) & B[0] 148 | y = (y | (y << S[0])) & B[0] 149 | return x | (y << 1) 150 | } 151 | 152 | func deinterleave64(interleaved uint64) uint64 { 153 | B := []uint64{0x5555555555555555, 0x3333333333333333, 154 | 0x0F0F0F0F0F0F0F0F, 0x00FF00FF00FF00FF, 155 | 0x0000FFFF0000FFFF, 0x00000000FFFFFFFF} 156 | 157 | S := []uint8{0, 1, 2, 4, 8, 16} 158 | x := interleaved 159 | y := interleaved >> 1 160 | x = (x | (x >> S[0])) & B[0] 161 | y = (y | (y >> S[0])) & B[0] 162 | 163 | x = (x | (x >> S[1])) & B[1] 164 | y = (y | (y >> S[1])) & B[1] 165 | 166 | x = (x | (x >> S[2])) & B[2] 167 | y = (y | (y >> S[2])) & B[2] 168 | 169 | x = (x | (x >> S[3])) & B[3] 170 | y = (y | (y >> S[3])) & B[3] 171 | 172 | x = (x | (x >> S[4])) & B[4] 173 | y = (y | (y >> S[4])) & B[4] 174 | 175 | x = (x | (x >> S[5])) & B[5] 176 | y = (y | (y >> S[5])) & B[5] 177 | 178 | return x | (y << 32) 179 | } 180 | 181 | func geohashAlign52Bits(hash GeoHashBits) uint64 { 182 | bits := hash.bits 183 | bits <<= (52 - hash.step*2) 184 | return bits 185 | } 186 | 187 | func decodeGeohash(bits float64, xy *[2]float64) bool { 188 | hash := GeoHashBits{bits: uint64(bits), step: GEO_STEP_MAX} 189 | return geohashDecodeToLongLatWGS84(hash, xy) 190 | } 191 | func geohashDecodeToLongLatWGS84(hash GeoHashBits, xy *[2]float64) bool { 192 | return geohashDecodeToLongLatType(hash, xy) 193 | } 194 | 195 | func geohashDecodeToLongLatType(hash GeoHashBits, xy *[2]float64) bool { 196 | area := new(GeoHashArea) 197 | if xy == nil || !geohashDecodeType(hash, area) { 198 | return false 199 | } 200 | return geohashDecodeAreaToLongLat(area, xy) 201 | } 202 | 203 | func geohashDecodeType(hash GeoHashBits, area *GeoHashArea) bool { 204 | r := [2]GeoHashRange{} 205 | geohashGetCoordRange(&r[0], &r[1]) 206 | return geohashDecode(r[0], r[1], hash, area) 207 | } 208 | 209 | func geohashDecodeWGS84(hash GeoHashBits, area *GeoHashArea) bool { 210 | return geohashDecodeType(hash, area) 211 | } 212 | 213 | func geohashDecodeAreaToLongLat(area *GeoHashArea, xy *[2]float64) bool { 214 | if xy == nil { 215 | return false 216 | } 217 | xy[0] = (area.longitude.min + area.longitude.max) / 2 218 | xy[1] = (area.latitude.min + area.latitude.max) / 2 219 | return true 220 | 221 | } 222 | 223 | func hashIsZero(hash GeoHashBits) bool { 224 | return hash.bits == 0 && hash.step == 0 225 | } 226 | 227 | func rangeIsZero(r GeoHashRange) bool { 228 | return r.max == 0 && r.min == 0 229 | } 230 | 231 | func geohashDecode(long_range GeoHashRange, lat_range GeoHashRange, hash GeoHashBits, area *GeoHashArea) bool { 232 | if hashIsZero(hash) || area == nil || rangeIsZero(lat_range) || rangeIsZero(long_range) { 233 | return false 234 | } 235 | 236 | area.hash = hash 237 | step := hash.step 238 | hash_sep := deinterleave64(hash.bits) 239 | 240 | lat_scale := lat_range.max - lat_range.min 241 | long_scale := long_range.max - long_range.min 242 | 243 | ilato := uint32(hash_sep) 244 | ilono := uint32(hash_sep >> 32) 245 | 246 | area.latitude.min = lat_range.min + (float64(ilato)*1.0/float64(uint64(1)< 1 && decrease_step > 0 { 316 | steps-- 317 | geohashEncode(&long_range, &lat_range, longitude, latitude, uint8(steps), &hash) 318 | geohashNeighbors(&hash, &neighbors) 319 | geohashDecode(long_range, lat_range, hash, &area) 320 | } 321 | /* Exclude the search areas that are useless. */ 322 | if steps >= 2 { 323 | if area.latitude.min < min_lat { 324 | GZERO(&neighbors.south) 325 | GZERO(&neighbors.south_west) 326 | GZERO(&neighbors.south_east) 327 | } 328 | if area.latitude.max > max_lat { 329 | GZERO(&neighbors.north) 330 | GZERO(&neighbors.north_east) 331 | GZERO(&neighbors.north_west) 332 | } 333 | if area.longitude.min < min_lon { 334 | GZERO(&neighbors.west) 335 | GZERO(&neighbors.south_west) 336 | GZERO(&neighbors.north_west) 337 | } 338 | if area.longitude.max > max_lon { 339 | GZERO(&neighbors.east) 340 | GZERO(&neighbors.south_east) 341 | GZERO(&neighbors.north_east) 342 | } 343 | } 344 | radius.hash = hash 345 | radius.neighbors = neighbors 346 | radius.area = area 347 | return radius 348 | } 349 | func GZERO(s *GeoHashBits) { 350 | s.bits = 0 351 | s.step = 0 352 | } 353 | 354 | //计算经度、纬度为中心的搜索区域的边界框 355 | func geohashBoundingBox(longitude float64, latitude float64, radius_meters float64, bounds *[4]float64) bool { 356 | if bounds == nil { 357 | return false 358 | } 359 | bounds[0] = longitude - rad_deg(radius_meters/EARTH_RADIUS_IN_METERS/math.Cos(deg_rad(latitude))) 360 | bounds[2] = longitude + rad_deg(radius_meters/EARTH_RADIUS_IN_METERS/math.Cos(deg_rad(latitude))) 361 | bounds[1] = latitude - rad_deg(radius_meters/EARTH_RADIUS_IN_METERS) 362 | bounds[3] = latitude + rad_deg(radius_meters/EARTH_RADIUS_IN_METERS) 363 | return true 364 | } 365 | 366 | //计算bits 位的精度 367 | func geohashEstimateStepsByRadius(range_meters float64, lat float64) uint8 { 368 | if range_meters == 0 { 369 | return 26 370 | } 371 | step := uint8(1) 372 | for range_meters < MERCATOR_MAX { 373 | range_meters *= 2 374 | step++ 375 | } 376 | 377 | step -= 2 378 | if lat > 66 || lat < -66 { 379 | step-- 380 | if lat > 80 || lat < -80 { 381 | step-- 382 | } 383 | } 384 | 385 | /* Frame to valid range. */ 386 | if step < 1 { 387 | step = 1 388 | } 389 | if step > 26 { 390 | step = 26 391 | } 392 | return step 393 | } 394 | 395 | //计算其余8个框的geohash 396 | func geohashNeighbors(hash *GeoHashBits, neighbors *GeoHashNeighbors) { 397 | neighbors.east = *hash 398 | neighbors.west = *hash 399 | neighbors.north = *hash 400 | neighbors.south = *hash 401 | neighbors.south_east = *hash 402 | neighbors.south_west = *hash 403 | neighbors.north_east = *hash 404 | neighbors.north_west = *hash //8个方位的hash赋值 405 | 406 | Geohash_move_x(&neighbors.east, 1) 407 | Geohash_move_y(&neighbors.east, 0) 408 | 409 | Geohash_move_x(&neighbors.west, -1) 410 | Geohash_move_y(&neighbors.west, 0) 411 | 412 | Geohash_move_x(&neighbors.south, 0) 413 | Geohash_move_y(&neighbors.south, -1) 414 | 415 | Geohash_move_x(&neighbors.north, 0) 416 | Geohash_move_y(&neighbors.north, 1) 417 | 418 | Geohash_move_x(&neighbors.north_west, -1) 419 | Geohash_move_y(&neighbors.north_west, 1) 420 | 421 | Geohash_move_x(&neighbors.north_east, 1) 422 | Geohash_move_y(&neighbors.north_east, 1) 423 | 424 | Geohash_move_x(&neighbors.south_east, 1) 425 | Geohash_move_y(&neighbors.south_east, -1) 426 | 427 | Geohash_move_x(&neighbors.south_west, -1) 428 | Geohash_move_y(&neighbors.south_west, -1) 429 | } 430 | 431 | func Geohash_move_x(hash *GeoHashBits, d int8) { 432 | if d == 0 { 433 | return 434 | } 435 | 436 | x := hash.bits & 0xaaaaaaaaaaaaaaaa 437 | y := hash.bits & 0x5555555555555555 438 | 439 | zz := uint64(0x5555555555555555 >> (64 - hash.step*2)) 440 | if d > 0 { 441 | x = x + (zz + 1) 442 | } else { 443 | x = x | zz 444 | x = x - (zz + 1) 445 | } 446 | x &= (0xaaaaaaaaaaaaaaaa >> (64 - hash.step*2)) 447 | hash.bits = (x | y) 448 | } 449 | 450 | func Geohash_move_y(hash *GeoHashBits, d int8) { 451 | if d == 0 { 452 | return 453 | } 454 | 455 | x := hash.bits & 0xaaaaaaaaaaaaaaaa 456 | y := hash.bits & 0x5555555555555555 457 | 458 | zz := uint64(0xaaaaaaaaaaaaaaaa >> (64 - hash.step*2)) 459 | if d > 0 { 460 | y = y + (zz + 1) 461 | } else { 462 | y = y | zz 463 | y = y - (zz + 1) 464 | } 465 | y &= (0x5555555555555555 >> (64 - hash.step*2)) 466 | hash.bits = (x | y) 467 | } 468 | 469 | func geohashGetDistanceIfInRadius(x1 float64, y1 float64, x2 float64, y2 float64, radius float64, distance *float64) bool { 470 | *distance = geohashGetDistance(x1, y1, x2, y2) 471 | if *distance > radius { 472 | return false 473 | } 474 | return true 475 | } 476 | 477 | func geohashGetDistanceIfInRadiusWGS84(x1 float64, y1 float64, x2 float64, y2 float64, radius float64, distance *float64) bool { 478 | return geohashGetDistanceIfInRadius(x1, y1, x2, y2, radius, distance) 479 | } 480 | --------------------------------------------------------------------------------