├── .gitignore ├── start.sh ├── img ├── frame.png ├── logic.png └── database.png ├── redis.conf ├── main.go ├── lib ├── time.go └── pool │ └── pool.go ├── app ├── app.go ├── factory.go └── config.go ├── go.mod ├── datastore ├── persist.go ├── expire.go ├── string.go ├── hash.go ├── set.go ├── list.go ├── set_test.go ├── list_test.go ├── hash_test.go ├── sorted_set.go ├── sorted_set_test.go └── kv_store.go ├── handler ├── struct.go ├── persister.go ├── handler.go └── reply.go ├── database ├── trigger.go ├── executor.go └── struct.go ├── README.md ├── persist ├── persist.go ├── aof_rewrite.go └── aof.go ├── go.sum ├── log └── log.go ├── server └── server.go ├── protocol └── parser.go └── goredis_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | *.log 2 | *.output 3 | *.aof -------------------------------------------------------------------------------- /start.sh: -------------------------------------------------------------------------------- 1 | nohup go run ./main.go >./nohup.output 2>&1 & -------------------------------------------------------------------------------- /img/frame.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaoxuxiansheng/goredis/HEAD/img/frame.png -------------------------------------------------------------------------------- /img/logic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaoxuxiansheng/goredis/HEAD/img/logic.png -------------------------------------------------------------------------------- /img/database.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaoxuxiansheng/goredis/HEAD/img/database.png -------------------------------------------------------------------------------- /redis.conf: 
-------------------------------------------------------------------------------- 1 | # IP address to bind 2 | bind 0.0.0.0 3 | # port 4 | port 6379 5 | 6 | # whether AOF persistence is enabled 7 | appendonly yes 8 | # AOF file name 9 | appendfilename appendonly.aof 10 | # AOF fsync policy: always | everysec | no 11 | appendfsync everysec 12 | # rewrite the AOF after this many persisted commands 13 | auto-aof-rewrite-after-cmds 1000 14 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "github.com/xiaoxuxiansheng/goredis/app" 4 | 5 | func main() { 6 | server, err := app.ConstructServer() 7 | if err != nil { 8 | panic(err) 9 | } 10 | 11 | app := app.NewApplication(server, app.SetUpConfig()) 12 | defer app.Stop() 13 | 14 | if err := app.Run(); err != nil { 15 | panic(err) 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /lib/time.go: -------------------------------------------------------------------------------- 1 | package lib 2 | 3 | import "time" 4 | 5 | const ( 6 | YYYY_MM_DD_HH_MM_SS = "2006-01-02 15:04:05" 7 | ) 8 | 9 | func TimeNow() time.Time { 10 | return time.Now() 11 | } 12 | 13 | func ParseTimeSecondFormat(timeStr string) (time.Time, error) { // parses YYYY_MM_DD_HH_MM_SS-formatted text in the local time zone 14 | return time.ParseInLocation(YYYY_MM_DD_HH_MM_SS, timeStr, time.Local) 15 | } 16 | 17 | func TimeSecondFormat(t time.Time) string { 18 | return t.Format(YYYY_MM_DD_HH_MM_SS) 19 | } 20 | -------------------------------------------------------------------------------- /app/app.go: -------------------------------------------------------------------------------- 1 | package app 2 | 3 | import "github.com/xiaoxuxiansheng/goredis/server" 4 | 5 | type Application struct { 6 | server *server.Server 7 | conf *Config 8 | } 9 | 10 | func NewApplication(server *server.Server, conf *Config) *Application { 11 | return &Application{ 12 | server: server, 13 | conf: conf, 14 | } 15 | } 16 | 17 | func (a
*Application) Run() error { 18 | return a.server.Serve(a.conf.Address()) 19 | } 20 | 21 | func (a *Application) Stop() { 22 | a.server.Stop() 23 | } 24 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/xiaoxuxiansheng/goredis 2 | 3 | go 1.19 4 | 5 | require ( 6 | go.uber.org/dig v1.17.1 7 | go.uber.org/zap v1.27.0 8 | gopkg.in/natefinch/lumberjack.v2 v2.2.1 9 | ) 10 | 11 | require ( 12 | github.com/davecgh/go-spew v1.1.1 // indirect 13 | github.com/panjf2000/ants v1.3.0 // indirect 14 | github.com/pmezard/go-difflib v1.0.0 // indirect 15 | github.com/spf13/cast v1.6.0 // indirect 16 | github.com/stretchr/testify v1.9.0 // indirect 17 | go.uber.org/multierr v1.10.0 // indirect 18 | gopkg.in/yaml.v3 v3.0.1 // indirect 19 | ) 20 | -------------------------------------------------------------------------------- /datastore/persist.go: -------------------------------------------------------------------------------- 1 | package datastore 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/xiaoxuxiansheng/goredis/database" 7 | "github.com/xiaoxuxiansheng/goredis/lib" 8 | ) 9 | 10 | func (k *KVStore) ForEach(f func(key string, adapter database.CmdAdapter, expireAt *time.Time)) { 11 | for key, data := range k.data { 12 | expiredAt, ok := k.expiredAt[key] 13 | if ok && expiredAt.Before(lib.TimeNow()) { 14 | continue 15 | } 16 | _adapter, _ := data.(database.CmdAdapter) 17 | if ok { 18 | f(key, _adapter, &expiredAt) 19 | } else { 20 | f(key, _adapter, nil) 21 | } 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /lib/pool/pool.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import ( 4 | "runtime/debug" 5 | "strings" 6 | 7 | "github.com/xiaoxuxiansheng/goredis/log" 8 | 9 | "github.com/panjf2000/ants" 10 | ) 11 | 12 | var pool
*ants.Pool 13 | 14 | func init() { 15 | _pool, err := ants.NewPool(50000, ants.WithPanicHandler(func(i interface{}) { 16 | stackInfo := strings.Replace(string(debug.Stack()), "\n", "", -1) 17 | log.GetDefaultLogger().Errorf("recover info: %v, stack info: %s", i, stackInfo) 18 | })) 19 | if err != nil { 20 | panic(err) 21 | } 22 | pool = _pool 23 | } 24 | 25 | // Submit runs task on the shared ants pool; a task the pool rejects is logged instead of being dropped silently. 26 | func Submit(task func()) { 27 | if err := pool.Submit(task); err != nil { 28 | log.GetDefaultLogger().Errorf("submit task to pool: %v", err) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /handler/struct.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "io" 7 | "net" 8 | "strings" 9 | ) 10 | 11 | var UnknownErrReplyBytes = []byte("-ERR unknown\r\n") 12 | 13 | type Reply interface { 14 | ToBytes() []byte 15 | } 16 | 17 | type MultiReply interface { 18 | Reply 19 | Args() [][]byte 20 | } 21 | 22 | type Droplet struct { 23 | Reply Reply 24 | Err error 25 | } 26 | 27 | // Terminated reports whether d.Err ends the stream: EOF, truncated input, or a closed network connection (wrapped errors included). 28 | func (d *Droplet) Terminated() bool { 29 | if errors.Is(d.Err, io.EOF) || errors.Is(d.Err, io.ErrUnexpectedEOF) || errors.Is(d.Err, net.ErrClosed) { 30 | return true 31 | } 32 | // fallback for closed-connection errors that do not wrap net.ErrClosed 33 | return d.Err != nil && strings.Contains(d.Err.Error(), "use of closed network connection") 34 | } 35 | 36 | type DB interface { 37 | Do(ctx context.Context, cmdLine [][]byte) Reply 38 | Close() 39 | } 40 | 41 | // Parser is the protocol parser: ParseStream decodes a byte stream into a channel of Droplets. 42 | type Parser interface { 43 | ParseStream(reader io.Reader) <-chan *Droplet 44 | } 45 | -------------------------------------------------------------------------------- /handler/persister.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "context" 5 | "io" 6 | ) 7 | 8 | var loadingPersisterPattern int 9 | var ctxKeyLoadingPersisterPattern = &loadingPersisterPattern 10 | 11 | func SetLoadingPattern(ctx context.Context) context.Context { 12 | return context.WithValue(ctx, ctxKeyLoadingPersisterPattern, true) 13 | } 14 | 15 | func IsLoadingPattern(ctx context.Context) bool { 16 | is, _ :=
ctx.Value(ctxKeyLoadingPersisterPattern).(bool) 17 | return is 18 | } 19 | 20 | type Persister interface { 21 | Reloader() (io.ReadCloser, error) 22 | PersistCmd(ctx context.Context, cmd [][]byte) 23 | Close() 24 | } 25 | 26 | type fakeReadWriter struct { // reads from the wrapped Reader and discards everything written to it 27 | io.Reader 28 | } 29 | 30 | func newFakeReaderWriter(reader io.Reader) io.ReadWriter { 31 | return &fakeReadWriter{ 32 | Reader: reader, 33 | } 34 | } 35 | 36 | func (f *fakeReadWriter) Write(p []byte) (n int, err error) { 37 | // discard p but report it fully written: io.Writer requires a non-nil error whenever n < len(p) 38 | return len(p), nil 39 | } 40 | -------------------------------------------------------------------------------- /datastore/expire.go: -------------------------------------------------------------------------------- 1 | package datastore 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/xiaoxuxiansheng/goredis/lib" 7 | ) 8 | 9 | func (k *KVStore) GC() { 10 | // find every key that has expired as of now and reclaim them in one batch 11 | nowUnix := lib.TimeNow().Unix() 12 | for _, expiredKey := range k.expireTimeWheel.Range(0, nowUnix) { 13 | k.expireProcess(expiredKey) 14 | } 15 | } 16 | 17 | func (k *KVStore) ExpirePreprocess(key string) { // deletes key right away if its expiry time has already passed 18 | expiredAt, ok := k.expiredAt[key] 19 | if !ok { 20 | return 21 | } 22 | 23 | if expiredAt.After(lib.TimeNow()) { 24 | return 25 | } 26 | 27 | k.expireProcess(key) 28 | } 29 | 30 | func (k *KVStore) expireProcess(key string) { 31 | delete(k.expiredAt, key) 32 | delete(k.data, key) 33 | k.expireTimeWheel.Rem(key) 34 | } 35 | 36 | func (k *KVStore) expire(key string, expiredAt time.Time) { 37 | if _, ok := k.data[key]; !ok { 38 | return 39 | } 40 | k.expiredAt[key] = expiredAt 41 | k.expireTimeWheel.Add(expiredAt.Unix(), key) 42 | } 43 | -------------------------------------------------------------------------------- /database/trigger.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync" 7 | 8 | "github.com/xiaoxuxiansheng/goredis/handler" 9 | ) 10 | 11 | type DBTrigger struct { 12 | once
sync.Once 13 | executor Executor 14 | } 15 | 16 | func NewDBTrigger(executor Executor) handler.DB { 17 | return &DBTrigger{executor: executor} 18 | } 19 | 20 | func (d *DBTrigger) Do(ctx context.Context, cmdLine [][]byte) handler.Reply { // validates cmdLine, forwards it to the executor and waits for its reply 21 | if len(cmdLine) < 2 { 22 | return handler.NewErrReply(fmt.Sprintf("invalid cmd line: %v", cmdLine)) 23 | } 24 | 25 | cmdType := CmdType(cmdLine[0]) 26 | if !d.executor.ValidCommand(cmdType) { 27 | return handler.NewErrReply(fmt.Sprintf("unknown cmd '%s'", cmdLine[0])) 28 | } 29 | 30 | cmd := Command{ 31 | ctx: ctx, 32 | cmd: cmdType, 33 | args: cmdLine[1:], 34 | receiver: make(CmdReceiver), 35 | } 36 | 37 | // hand the command over to the executor 38 | d.executor.Entrance() <- &cmd 39 | 40 | // block on the receiver channel until the reply comes back 41 | return <-cmd.Receiver() 42 | } 43 | 44 | func (d *DBTrigger) Close() { 45 | d.once.Do(d.executor.Close) 46 | } 47 | -------------------------------------------------------------------------------- /datastore/string.go: -------------------------------------------------------------------------------- 1 | package datastore 2 | 3 | import ( 4 | "github.com/xiaoxuxiansheng/goredis/database" 5 | "github.com/xiaoxuxiansheng/goredis/handler" 6 | ) 7 | 8 | func (k *KVStore) getAsString(key string) (String, error) { 9 | v, ok := k.data[key] 10 | if !ok { 11 | return nil, nil 12 | } 13 | 14 | str, ok := v.(String) 15 | if !ok { 16 | return nil, handler.NewWrongTypeErrReply() 17 | } 18 | 19 | return str, nil 20 | } 21 | 22 | func (k *KVStore) put(key, value string, insertStrategy bool) int64 { 23 | if _, ok := k.data[key]; ok && insertStrategy { 24 | return 0 25 | } 26 | 27 | k.data[key] = NewString(key, value) 28 | return 1 29 | } 30 | 31 | type String interface { 32 | Bytes() []byte 33 | database.CmdAdapter 34 | } 35 | 36 | type stringEntity struct { 37 | key, str string 38 | } 39 | 40 | func NewString(key, str string) String { 41 | return &stringEntity{key: key, str: str} 42 | } 43 | 44 | func (s *stringEntity) Bytes() []byte { 45 | return
[]byte(s.str) 46 | } 47 | 48 | func (s *stringEntity) ToCmd() [][]byte { 49 | return [][]byte{[]byte(database.CmdTypeSet), []byte(s.key), []byte(s.str)} 50 | } 51 | -------------------------------------------------------------------------------- /datastore/hash.go: -------------------------------------------------------------------------------- 1 | package datastore 2 | 3 | import ( 4 | "github.com/xiaoxuxiansheng/goredis/database" 5 | "github.com/xiaoxuxiansheng/goredis/handler" 6 | ) 7 | 8 | func (k *KVStore) getAsHashMap(key string) (HashMap, error) { 9 | v, ok := k.data[key] 10 | if !ok { 11 | return nil, nil 12 | } 13 | 14 | hmap, ok := v.(HashMap) 15 | if !ok { 16 | return nil, handler.NewWrongTypeErrReply() 17 | } 18 | 19 | return hmap, nil 20 | } 21 | 22 | func (k *KVStore) putAsHashMap(key string, hmap HashMap) { 23 | k.data[key] = hmap 24 | } 25 | 26 | type HashMap interface { 27 | Put(key string, value []byte) 28 | Get(key string) []byte 29 | Del(key string) int64 30 | database.CmdAdapter 31 | } 32 | 33 | type hashMapEntity struct { 34 | key string 35 | data map[string][]byte 36 | } 37 | 38 | func newHashMapEntity(key string) HashMap { 39 | return &hashMapEntity{ 40 | key: key, 41 | data: make(map[string][]byte), 42 | } 43 | } 44 | 45 | func (h *hashMapEntity) Put(key string, value []byte) { 46 | h.data[key] = value 47 | } 48 | 49 | func (h *hashMapEntity) Get(key string) []byte { 50 | return h.data[key] 51 | } 52 | 53 | func (h *hashMapEntity) Del(key string) int64 { 54 | if _, ok := h.data[key]; !ok { 55 | return 0 56 | } 57 | delete(h.data, key) 58 | return 1 59 | } 60 | 61 | func (h *hashMapEntity) ToCmd() [][]byte { 62 | args := make([][]byte, 0, 2+2*len(h.data)) 63 | args = append(args, []byte(database.CmdTypeHSet), []byte(h.key)) 64 | for k, v := range h.data { 65 | args = append(args, []byte(k), v) 66 | } 67 | return args 68 | } 69 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 |

2 | 3 | goredis: 基于 go 实现 redis 4 |

5 |

6 | 7 | ## 📚 前言 8 | 笔者在学习 goredis 实现方案的过程中,在很大程度上借鉴了 godis 项目,在此特别致敬一下作者. 9 | 附上传送门:https://github.com/HDT3213/godis/ 10 | 11 | ## 📖 简介 12 | 本着学习和实践的目标, 基于 100% 纯度 go 语言实现的 “低配仿制” redis ,实现到的功能点包括: 13 | - tcp服务端搭建 14 | - 基于go自带netpoller实现io多路复用 15 | - 还原redis数据解析协议 16 | - 常规数据类型与操作指令支持 17 | - string——get/mget/set/mset 18 | - list——lpush/lpop/rpush/rpop/lrange 19 | - set——sadd/sismember/srem 20 | - hashmap——hset/hget/hdel 21 | - sortedset——zadd/zrem/zrangebyscore 22 | - 数据持久化机制 23 | - appendonlyfile落盘与重写 24 | 25 | ## 💡 `goredis` 技术原理及源码实现 26 | 基于go实现redis之主干框架

27 | 基于go实现redis之指令分发

28 | 基于go实现redis之存储引擎

29 | 基于go实现redis之数据持久化

30 | 31 | ## 💻 核心架构 32 | 服务端与指令分发层 33 | 34 | 数据存储引擎层 35 | 36 | -------------------------------------------------------------------------------- /app/factory.go: -------------------------------------------------------------------------------- 1 | package app 2 | 3 | import ( 4 | "github.com/xiaoxuxiansheng/goredis/database" 5 | "github.com/xiaoxuxiansheng/goredis/datastore" 6 | "github.com/xiaoxuxiansheng/goredis/handler" 7 | "github.com/xiaoxuxiansheng/goredis/log" 8 | "github.com/xiaoxuxiansheng/goredis/persist" 9 | "github.com/xiaoxuxiansheng/goredis/protocol" 10 | "github.com/xiaoxuxiansheng/goredis/server" 11 | 12 | "go.uber.org/dig" 13 | ) 14 | 15 | var container = dig.New() 16 | 17 | // containerSetupErr keeps the first error returned by container.Provide during init; ConstructServer reports it. 18 | var containerSetupErr error 19 | 20 | // registerConstructor provides constructor to the container, recording the first failure instead of discarding it. 21 | func registerConstructor(constructor interface{}) { 22 | if err := container.Provide(constructor); err != nil && containerSetupErr == nil { 23 | containerSetupErr = err 24 | } 25 | } 26 | 27 | // init registers every constructor the server depends on with the dig container. 28 | func init() { 29 | /** 30 | miscellaneous 31 | **/ 32 | // configuration 33 | registerConstructor(SetUpConfig) 34 | registerConstructor(PersistThinker) 35 | // logger 36 | registerConstructor(log.GetDefaultLogger) 37 | 38 | /** 39 | storage engine 40 | **/ 41 | // persistence 42 | registerConstructor(persist.NewPersister) 43 | // storage medium 44 | registerConstructor(datastore.NewKVStore) 45 | // executor 46 | registerConstructor(database.NewDBExecutor) 47 | // trigger 48 | registerConstructor(database.NewDBTrigger) 49 | 50 | /** 51 | logic layer 52 | **/ 53 | // protocol parser 54 | registerConstructor(protocol.NewParser) 55 | // command handler 56 | registerConstructor(handler.NewHandler) 57 | 58 | /** 59 | server 60 | **/ 61 | registerConstructor(server.NewServer) 62 | } 63 | 64 | // ConstructServer resolves the server.Handler and log.Logger from the container and builds the server. 65 | // It returns the first registration error recorded during init, or the error from a failed Invoke. 66 | func ConstructServer() (*server.Server, error) { 67 | if containerSetupErr != nil { 68 | return nil, containerSetupErr 69 | } 70 | 71 | var h server.Handler 72 | if err := container.Invoke(func(_h server.Handler) { 73 | h = _h 74 | }); err != nil { 75 | return nil, err 76 | } 77 | 78 | var l log.Logger 79 | if err := container.Invoke(func(_l log.Logger) { 80 | l = _l 81 | }); err != nil { 82 | return nil, err 83 | } 84 | return server.NewServer(h, l), nil 85 | } 86 | -------------------------------------------------------------------------------- /persist/persist.go: 
-------------------------------------------------------------------------------- 1 | package persist 2 | 3 | import ( 4 | "context" 5 | "io" 6 | 7 | "github.com/xiaoxuxiansheng/goredis/handler" 8 | ) 9 | 10 | type Thinker interface { // Thinker supplies the append-only-file settings NewPersister reads (presumably backed by the appendonly/appendfilename/appendfsync/auto-aof-rewrite-after-cmds keys of redis.conf — confirm in app/config.go) 11 | AppendOnly() bool 12 | AppendFileName() string 13 | AppendFsync() string 14 | AutoAofRewriteAfterCmd() int 15 | } 16 | 17 | func NewPersister(thinker Thinker) (handler.Persister, error) { // NewPersister returns a no-op persister when AppendOnly() is false, otherwise whatever newAofPersister builds 18 | if !thinker.AppendOnly() { 19 | return newFakePersister(nil), nil 20 | } 21 | 22 | return newAofPersister(thinker) 23 | } 24 | 25 | type fakeReadCloser struct { // fakeReadCloser pairs an io.Reader with a custom close callback to form an io.ReadCloser 26 | io.Reader 27 | closef func() error 28 | } 29 | 30 | func readCloserAdapter(reader io.Reader, closef func() error) io.ReadCloser { 31 | return &fakeReadCloser{Reader: reader, closef: closef} 32 | } 33 | 34 | func (f *fakeReadCloser) Close() error { // delegates to closef; NOTE(review): panics if closef is nil 35 | return f.closef() 36 | } 37 | 38 | func newFakePersister(readCloser io.ReadCloser) handler.Persister { // newFakePersister persists nothing; its Reloader yields readCloser, or singleFakeReloader when readCloser is nil 39 | f := fakePersister{} 40 | if readCloser == nil { 41 | f.readCloser = singleFakeReloader 42 | return &f 43 | } 44 | f.readCloser = readCloser 45 | return &f 46 | } 47 | 48 | type fakePersister struct { 49 | readCloser io.ReadCloser 50 | } 51 | 52 | func (f *fakePersister) Reloader() (io.ReadCloser, error) { 53 | return f.readCloser, nil 54 | } 55 | 56 | func (f *fakePersister) PersistCmd(ctx context.Context, cmd [][]byte) {} // no-op: commands are not persisted 57 | 58 | func (f *fakePersister) Close() {} 59 | 60 | var singleFakeReloader = &fakeReloader{} // shared reloader with no content 61 | 62 | type fakeReloader struct { 63 | } 64 | 65 | func (f *fakeReloader) Read(p []byte) (n int, err error) { // always io.EOF: there is nothing to reload 66 | return 0, io.EOF 67 | } 68 | 69 | func (f *fakeReloader) Close() error { 70 | return nil 71 | } 72 | -------------------------------------------------------------------------------- /datastore/set.go: -------------------------------------------------------------------------------- 1 | package datastore 2 | 3 | import ( 4 | "github.com/xiaoxuxiansheng/goredis/database" 5 |
"github.com/xiaoxuxiansheng/goredis/handler" 6 | ) 7 | 8 | func (k *KVStore) getAsSet(key string) (Set, error) { // (nil, nil) for a missing key; a wrong-type error reply when the stored value is not a Set 9 | v, ok := k.data[key] 10 | if !ok { 11 | return nil, nil 12 | } 13 | 14 | set, ok := v.(Set) 15 | if !ok { 16 | return nil, handler.NewWrongTypeErrReply() 17 | } 18 | 19 | return set, nil 20 | } 21 | 22 | func (k *KVStore) putAsSet(key string, set Set) { 23 | k.data[key] = set 24 | } 25 | 26 | type Set interface { // Set is the set value type; its int64 results are 1/0 flags 27 | Add(value string) int64 28 | Exist(value string) int64 29 | Rem(value string) int64 30 | database.CmdAdapter 31 | } 32 | 33 | type setEntity struct { 34 | key string 35 | container map[string]struct{} 36 | } 37 | 38 | func newSetEntity(key string) Set { 39 | return &setEntity{ 40 | key: key, 41 | container: make(map[string]struct{}), 42 | } 43 | } 44 | 45 | func (s *setEntity) Add(value string) int64 { // 1 if value was newly added, 0 if it was already a member 46 | if _, ok := s.container[value]; ok { 47 | return 0 48 | } 49 | s.container[value] = struct{}{} 50 | return 1 51 | } 52 | 53 | func (s *setEntity) Exist(value string) int64 { // 1 if value is a member, 0 otherwise 54 | if _, ok := s.container[value]; ok { 55 | return 1 56 | } 57 | return 0 58 | } 59 | 60 | func (s *setEntity) Rem(value string) int64 { // 1 if value was removed, 0 if it was not a member 61 | if _, ok := s.container[value]; ok { 62 | delete(s.container, value) 63 | return 1 64 | } 65 | return 0 66 | } 67 | 68 | func (s *setEntity) ToCmd() [][]byte { // encodes the whole set as one SADD command: key, then members in map-iteration (unspecified) order 69 | args := make([][]byte, 0, 2+len(s.container)) 70 | args = append(args, []byte(database.CmdTypeSAdd), []byte(s.key)) 71 | for k := range s.container { 72 | args = append(args, []byte(k)) 73 | } 74 | 75 | return args 76 | } 77 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/panjf2000/ants v1.3.0 h1:8pQ+8leaLc9lys2viEEr8md0U4RN6uOSUCE9bOYjQ9M= 4 |
github.com/panjf2000/ants v1.3.0/go.mod h1:AaACblRPzq35m1g3enqYcxspbbiOJJYaxU2wMpm1cXY= 5 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 6 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 7 | github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= 8 | github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= 9 | github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= 10 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= 11 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 12 | go.uber.org/dig v1.17.1 h1:Tga8Lz8PcYNsWsyHMZ1Vm0OQOUaJNDyvPImgbAu9YSc= 13 | go.uber.org/dig v1.17.1/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE= 14 | go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= 15 | go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= 16 | go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= 17 | go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= 18 | go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= 19 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 20 | gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= 21 | gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= 22 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 23 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 24 | -------------------------------------------------------------------------------- /datastore/list.go: -------------------------------------------------------------------------------- 1 | package datastore 2 | 3 | import ( 4 |
"github.com/xiaoxuxiansheng/goredis/database" 5 | "github.com/xiaoxuxiansheng/goredis/handler" 6 | ) 7 | 8 | func (k *KVStore) getAsList(key string) (List, error) { 9 | v, ok := k.data[key] 10 | if !ok { 11 | return nil, nil 12 | } 13 | 14 | list, ok := v.(List) 15 | if !ok { 16 | return nil, handler.NewWrongTypeErrReply() 17 | } 18 | 19 | return list, nil 20 | } 21 | 22 | func (k *KVStore) putAsList(key string, list List) { 23 | k.data[key] = list 24 | } 25 | 26 | type List interface { 27 | LPush(value []byte) 28 | LPop(cnt int64) [][]byte 29 | RPush(value []byte) 30 | RPop(cnt int64) [][]byte 31 | Len() int64 32 | Range(start, stop int64) [][]byte 33 | database.CmdAdapter 34 | } 35 | 36 | type listEntity struct { 37 | key string 38 | data [][]byte 39 | } 40 | 41 | func newListEntity(key string, elements ...[]byte) List { 42 | return &listEntity{ 43 | key: key, 44 | data: elements, 45 | } 46 | } 47 | 48 | func (l *listEntity) LPush(value []byte) { 49 | l.data = append([][]byte{value}, l.data...)
50 | } 51 | 52 | func (l *listEntity) LPop(cnt int64) [][]byte { // removes and returns the first cnt elements; nil if cnt is negative or exceeds Len() 53 | if cnt < 0 || int64(len(l.data)) < cnt { 54 | return nil 55 | } 56 | 57 | poped := append(make([][]byte, 0, cnt), l.data[:cnt]...) // copy out so the reply never aliases l.data's backing array 58 | l.data = l.data[cnt:] 59 | return poped 60 | } 61 | 62 | func (l *listEntity) RPush(value []byte) { 63 | l.data = append(l.data, value) 64 | } 65 | 66 | func (l *listEntity) RPop(cnt int64) [][]byte { // removes and returns the last cnt elements in list order; nil if cnt is negative or exceeds Len() 67 | if cnt < 0 || int64(len(l.data)) < cnt { 68 | return nil 69 | } 70 | 71 | poped := append(make([][]byte, 0, cnt), l.data[int64(len(l.data))-cnt:]...) // copy: a later RPush reuses these slots of the shared backing array and would overwrite the reply 72 | l.data = l.data[:int64(len(l.data))-cnt] 73 | return poped 74 | } 75 | 76 | func (l *listEntity) Len() int64 { 77 | return int64(len(l.data)) 78 | } 79 | 80 | func (l *listEntity) Range(start, stop int64) [][]byte { // returns a copy of elements start..stop inclusive; stop == -1 means the last element; out-of-range bounds yield nil 81 | if stop == -1 { 82 | stop = int64(len(l.data) - 1) 83 | } 84 | 85 | if start < 0 || start >= int64(len(l.data)) { 86 | return nil 87 | } 88 | 89 | if stop < 0 || stop >= int64(len(l.data)) || stop < start { 90 | return nil 91 | } 92 | 93 | return append(make([][]byte, 0, stop-start+1), l.data[start:stop+1]...) // copy so later pops/pushes cannot mutate the returned view 94 | } 95 | 96 | func (l *listEntity) ToCmd() [][]byte { 97 | args := make([][]byte, 0, 2+l.Len()) 98 | args = append(args, []byte(database.CmdTypeRPush), []byte(l.key)) 99 | args = append(args, l.data...)
100 | return args 101 | } 102 | -------------------------------------------------------------------------------- /datastore/set_test.go: -------------------------------------------------------------------------------- 1 | package datastore 2 | 3 | import ( 4 | "math/rand" 5 | "sort" 6 | "testing" 7 | 8 | "github.com/spf13/cast" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/xiaoxuxiansheng/goredis/database" 11 | "github.com/xiaoxuxiansheng/goredis/lib" 12 | ) 13 | 14 | func Test_set_crud(t *testing.T) { 15 | set := newSetEntity("") 16 | s := make(map[int]struct{}, 1000) 17 | rander := rand.New(rand.NewSource(lib.TimeNow().UnixNano())) 18 | 19 | t.Run("add", func(t *testing.T) { 20 | for i := 0; i < 1000; i++ { 21 | member := rander.Intn(1000) 22 | _, ok := s[member] 23 | success := set.Add(cast.ToString(member)) 24 | s[member] = struct{}{} 25 | assert.Equal(t, ok, success == 0) 26 | } 27 | }) 28 | 29 | t.Run("rem", func(t *testing.T) { 30 | for i := 0; i < 1000; i++ { 31 | member := rander.Intn(1000) 32 | _, ok := s[member] 33 | exist := set.Rem(cast.ToString(member)) 34 | delete(s, member) 35 | assert.Equal(t, ok, exist == 1) 36 | } 37 | }) 38 | 39 | t.Run("exist", func(t *testing.T) { 40 | for i := 0; i < 1000; i++ { 41 | member := rander.Intn(1000) 42 | _, ok := s[member] 43 | exist := set.Exist(cast.ToString(member)) 44 | assert.Equal(t, ok, exist == 1) 45 | } 46 | }) 47 | } 48 | 49 | func Test_set_to_cmd(t *testing.T) { 50 | set := newSetEntity("") 51 | rander := rand.New(rand.NewSource(lib.TimeNow().UnixNano())) 52 | s := make(map[int]struct{}, 1000) 53 | // insert 1000 random members 54 | for i := 0; i < 1000; i++ { 55 | member := rander.Intn(1000) 56 | set.Add(cast.ToString(member)) 57 | s[member] = struct{}{} 58 | } 59 | 60 | cmd := set.ToCmd() 61 | t.Run("length", func(t *testing.T) { 62 | assert.Equal(t, len(s)+2, len(cmd)) 63 | }) 64 | t.Run("command", func(t *testing.T) { 65 | assert.Equal(t, database.CmdTypeSAdd, database.CmdType(cmd[0])) 66 | }) 67 |
t.Run("key", func(t *testing.T) { 68 | assert.Equal(t, "", string(cmd[1])) 69 | }) 70 | 71 | actual := make([]int, 0, len(cmd)-2) 72 | for i := 2; i < len(cmd); i++ { 73 | actual = append(actual, cast.ToInt(string(cmd[i]))) 74 | } 75 | sort.Slice(actual, func(i, j int) bool { 76 | return actual[i] < actual[j] 77 | }) 78 | 79 | expect := make([]int, 0, len(s)) 80 | for member := range s { 81 | expect = append(expect, member) 82 | } 83 | sort.Slice(expect, func(i, j int) bool { 84 | return expect[i] < expect[j] 85 | }) 86 | 87 | t.Run("member", func(t *testing.T) { 88 | assert.Equal(t, expect, actual) 89 | }) 90 | } 91 | -------------------------------------------------------------------------------- /datastore/list_test.go: -------------------------------------------------------------------------------- 1 | package datastore 2 | 3 | import ( 4 | "math/rand" 5 | "sort" 6 | "testing" 7 | 8 | "github.com/spf13/cast" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/xiaoxuxiansheng/goredis/database" 11 | "github.com/xiaoxuxiansheng/goredis/lib" 12 | ) 13 | 14 | func Test_list_crud(t *testing.T) { 15 | list := newListEntity("") 16 | l := make([][]byte, 0, 1000) 17 | rander := rand.New(rand.NewSource(lib.TimeNow().UnixNano())) 18 | for i := 0; i < 1000; i++ { 19 | member1 := rander.Intn(1000) 20 | member2 := rander.Intn(1000) 21 | list.LPush([]byte(cast.ToString(member1))) 22 | list.RPush([]byte(cast.ToString(member2))) 23 | l = append([][]byte{[]byte(cast.ToString(member1))}, l...)
24 | l = append(l, []byte(cast.ToString(member2))) 25 | } 26 | 27 | t.Run("range", func(t *testing.T) { 28 | for i := 0; i < 1000; i++ { 29 | start := rander.Intn(1001) 30 | end := start + rander.Intn(1000) 31 | actual := list.Range(int64(start), int64(end)) 32 | expect := l[start : end+1] 33 | assert.Equal(t, expect, actual) 34 | } 35 | }) 36 | 37 | t.Run("pop", func(t *testing.T) { 38 | for i := 0; i < 500; i++ { 39 | actual := list.LPop(2) 40 | expect := l[:2] 41 | l = l[2:] 42 | assert.Equal(t, expect, actual) 43 | 44 | actual = list.RPop(2) 45 | expect = l[len(l)-2:] 46 | l = l[:len(l)-2] 47 | assert.Equal(t, expect, actual) 48 | } 49 | }) 50 | } 51 | 52 | func Test_list_to_cmds(t *testing.T) { 53 | list := newListEntity("") 54 | rander := rand.New(rand.NewSource(lib.TimeNow().UnixNano())) 55 | l := make([]int, 0, 1000) 56 | // insert 1000 random members 57 | for i := 0; i < 1000; i++ { 58 | member := rander.Intn(1000) 59 | list.LPush([]byte(cast.ToString(member))) 60 | l = append(l, member) 61 | } 62 | 63 | cmd := list.ToCmd() 64 | t.Run("length", func(t *testing.T) { 65 | assert.Equal(t, len(l)+2, len(cmd)) 66 | }) 67 | t.Run("command", func(t *testing.T) { 68 | assert.Equal(t, database.CmdTypeRPush, database.CmdType(cmd[0])) 69 | }) 70 | t.Run("key", func(t *testing.T) { 71 | assert.Equal(t, "", string(cmd[1])) 72 | }) 73 | 74 | actual := make([]int, 0, len(cmd)-2) 75 | for i := 2; i < len(cmd); i++ { 76 | actual = append(actual, cast.ToInt(string(cmd[i]))) 77 | } 78 | sort.Slice(actual, func(i, j int) bool { 79 | return actual[i] < actual[j] 80 | }) 81 | 82 | expect := l 83 | sort.Slice(expect, func(i, j int) bool { 84 | return expect[i] < expect[j] 85 | }) 86 | 87 | t.Run("member", func(t *testing.T) { 88 | assert.Equal(t, expect, actual) 89 | }) 90 | } 91 | -------------------------------------------------------------------------------- /database/executor.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 |
import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/xiaoxuxiansheng/goredis/handler" 9 | "github.com/xiaoxuxiansheng/goredis/lib/pool" 10 | ) 11 | 12 | type DBExecutor struct { 13 | ctx context.Context 14 | cancel context.CancelFunc 15 | ch chan *Command 16 | 17 | cmdHandlers map[CmdType]CmdHandler 18 | dataStore DataStore 19 | 20 | gcTicker *time.Ticker 21 | } 22 | 23 | func NewDBExecutor(dataStore DataStore) Executor { 24 | ctx, cancel := context.WithCancel(context.Background()) 25 | e := DBExecutor{ 26 | dataStore: dataStore, 27 | ch: make(chan *Command), 28 | ctx: ctx, 29 | cancel: cancel, 30 | gcTicker: time.NewTicker(time.Minute), 31 | } 32 | e.cmdHandlers = map[CmdType]CmdHandler{ 33 | CmdTypeExpire: e.dataStore.Expire, 34 | CmdTypeExpireAt: e.dataStore.ExpireAt, 35 | 36 | // string 37 | CmdTypeGet: e.dataStore.Get, 38 | CmdTypeSet: e.dataStore.Set, 39 | CmdTypeMGet: e.dataStore.MGet, 40 | CmdTypeMSet: e.dataStore.MSet, 41 | 42 | // list 43 | CmdTypeLPush: e.dataStore.LPush, 44 | CmdTypeLPop: e.dataStore.LPop, 45 | CmdTypeRPush: e.dataStore.RPush, 46 | CmdTypeRPop: e.dataStore.RPop, 47 | CmdTypeLRange: e.dataStore.LRange, 48 | 49 | // set 50 | CmdTypeSAdd: e.dataStore.SAdd, 51 | CmdTypeSIsMember: e.dataStore.SIsMember, 52 | CmdTypeSRem: e.dataStore.SRem, 53 | 54 | // hash 55 | CmdTypeHSet: e.dataStore.HSet, 56 | CmdTypeHGet: e.dataStore.HGet, 57 | CmdTypeHDel: e.dataStore.HDel, 58 | 59 | // sorted set 60 | CmdTypeZAdd: e.dataStore.ZAdd, 61 | CmdTypeZRangeByScore: e.dataStore.ZRangeByScore, 62 | CmdTypeZRem: e.dataStore.ZRem, 63 | } 64 | 65 | pool.Submit(e.run) 66 | return &e 67 | } 68 | 69 | func (e *DBExecutor) Entrance() chan<- *Command { 70 | return e.ch 71 | } 72 | 73 | func (e *DBExecutor) ValidCommand(cmd CmdType) bool { 74 | _, valid := e.cmdHandlers[cmd] // the map is only read after NewDBExecutor builds it, so lookups need no lock 75 | return valid 76 | } 77 | 78 | func (e *DBExecutor) Close() { // cancels ctx, which makes run return 79 | e.cancel() 80 | } 81 | 82 | func (e *DBExecutor) run() { 83 | defer e.gcTicker.Stop() // release the ticker once the loop exits after Close 84 | for { 85 | select { 86 | case
<-e.ctx.Done(): 87 | return 88 | 89 | // every minute, reclaim expired keys in one batch 90 | case <-e.gcTicker.C: 91 | e.dataStore.GC() 92 | 93 | case cmd := <-e.ch: 94 | cmdFunc, ok := e.cmdHandlers[cmd.cmd] 95 | if !ok { 96 | cmd.receiver <- handler.NewErrReply(fmt.Sprintf("unknown command '%s'", cmd.cmd)) 97 | continue 98 | } 99 | 100 | e.dataStore.ExpirePreprocess(string(cmd.args[0])) // lazy expiration: drop the key first if it has already expired 101 | cmd.receiver <- cmdFunc(cmd) 102 | } 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /datastore/hash_test.go: -------------------------------------------------------------------------------- 1 | package datastore 2 | 3 | import ( 4 | "math/rand" 5 | "sort" 6 | "testing" 7 | 8 | "github.com/spf13/cast" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/xiaoxuxiansheng/goredis/database" 11 | "github.com/xiaoxuxiansheng/goredis/lib" 12 | ) 13 | 14 | func Test_hashmap_crud(t *testing.T) { 15 | hashmap := newHashMapEntity("") 16 | mp := make(map[int]int, 1000) 17 | 18 | rander := rand.New(rand.NewSource(lib.TimeNow().UnixNano())) 19 | for i := 0; i < 1000; i++ { 20 | k := rander.Intn(1000) 21 | v := rander.Intn(1000) 22 | hashmap.Put(cast.ToString(k), []byte(cast.ToString(v))) 23 | mp[k] = v 24 | } 25 | 26 | t.Run("delete", func(t *testing.T) { 27 | for i := 0; i < 1000; i++ { 28 | k := rander.Intn(1000) 29 | _, ok := mp[k] 30 | exist := hashmap.Del(cast.ToString(k)) 31 | assert.Equal(t, ok, exist == 1) 32 | delete(mp, k) 33 | } 34 | }) 35 | 36 | t.Run("get", func(t *testing.T) { 37 | for i := 0; i < 1000; i++ { 38 | k := rander.Intn(1000) 39 | value := hashmap.Get(cast.ToString(k)) 40 | v, ok := mp[k] 41 | assert.Equal(t, ok, value != nil) 42 | if ok { 43 | assert.Equal(t, v, cast.ToInt(string(value))) 44 | } 45 | } 46 | }) 47 | } 48 | 49 | func Test_hashmap_to_cmd(t *testing.T) { 50 | hashmap := newHashMapEntity("") 51 | rander := rand.New(rand.NewSource(lib.TimeNow().UnixNano())) 52 | mp := make(map[int]int, 1000) 53 | //
插入1000条数据 54 | for i := 0; i < 1000; i++ { 55 | k := rander.Intn(1000) 56 | v := rander.Intn(1000) 57 | hashmap.Put(cast.ToString(k), []byte(cast.ToString(v))) 58 | mp[k] = v 59 | } 60 | 61 | cmd := hashmap.ToCmd() 62 | t.Run("length", func(t *testing.T) { 63 | assert.Equal(t, 2*len(mp)+2, len(cmd)) 64 | }) 65 | t.Run("command", func(t *testing.T) { 66 | assert.Equal(t, database.CmdTypeHSet, database.CmdType(cmd[0])) 67 | }) 68 | t.Run("key", func(t *testing.T) { 69 | assert.Equal(t, "", string(cmd[1])) 70 | }) 71 | 72 | type kv struct { 73 | k, v int 74 | } 75 | actual := make([]kv, 0, 1000) 76 | for i := 2; i < len(cmd); i += 2 { 77 | actual = append(actual, kv{ 78 | k: cast.ToInt(string(cmd[i])), 79 | v: cast.ToInt(string(cmd[i+1])), 80 | }) 81 | } 82 | 83 | sort.Slice(actual, func(i, j int) bool { 84 | if actual[i].k == actual[j].k { 85 | return actual[i].v < actual[j].v 86 | } 87 | return actual[i].k < actual[j].k 88 | }) 89 | 90 | expect := make([]kv, 0, 2*len(mp)) 91 | for k, v := range mp { 92 | expect = append(expect, kv{ 93 | k: k, 94 | v: v, 95 | }) 96 | } 97 | sort.Slice(expect, func(i, j int) bool { 98 | if expect[i].k == expect[j].k { 99 | return expect[i].v < expect[j].v 100 | } 101 | return expect[i].k < expect[j].k 102 | }) 103 | 104 | t.Run("member", func(t *testing.T) { 105 | assert.Equal(t, expect, actual) 106 | }) 107 | } 108 | -------------------------------------------------------------------------------- /log/log.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "go.uber.org/zap" 5 | "go.uber.org/zap/zapcore" 6 | "gopkg.in/natefinch/lumberjack.v2" 7 | ) 8 | 9 | type Logger interface { 10 | Errorf(format string, v ...interface{}) 11 | Warnf(format string, v ...interface{}) 12 | Infof(format string, v ...interface{}) 13 | Debugf(format string, v ...interface{}) 14 | } 15 | 16 | var ( 17 | defaultLogger Logger 18 | ) 19 | 20 | func init() { 21 | defaultLogger = 
NewLogger(NewOptions()) 22 | } 23 | 24 | // Options 选项配置 25 | type Options struct { 26 | LogName string // 日志名称 27 | LogLevel string // 日志级别 28 | FileName string // 文件名称 29 | MaxAge int // 日志保留时间,以天为单位 30 | MaxSize int // 日志保留大小,以 M 为单位 31 | MaxBackups int // 保留文件个数 32 | Compress bool // 是否压缩 33 | } 34 | 35 | // Option 选项方法 36 | type Option func(*Options) 37 | 38 | // NewOptions 初始化 39 | func NewOptions(opts ...Option) Options { 40 | options := Options{ 41 | LogName: "app", 42 | LogLevel: "info", 43 | FileName: "app.log", 44 | MaxAge: 10, 45 | MaxSize: 100, 46 | MaxBackups: 3, 47 | Compress: true, 48 | } 49 | for _, opt := range opts { 50 | opt(&options) 51 | } 52 | return options 53 | } 54 | 55 | // WithLogLevel 日志级别 56 | func WithLogLevel(level string) Option { 57 | return func(o *Options) { 58 | o.LogLevel = level 59 | } 60 | } 61 | 62 | // WithFileName 日志文件 63 | func WithFileName(filename string) Option { 64 | return func(o *Options) { 65 | o.FileName = filename 66 | } 67 | } 68 | 69 | // Levels zapcore level 70 | var Levels = map[string]zapcore.Level{ 71 | "debug": zapcore.DebugLevel, 72 | "info": zapcore.InfoLevel, 73 | "warn": zapcore.WarnLevel, 74 | "error": zapcore.ErrorLevel, 75 | "fatal": zapcore.FatalLevel, 76 | } 77 | 78 | type zapLoggerWrapper struct { 79 | *zap.SugaredLogger 80 | options Options 81 | } 82 | 83 | func NewLogger(options Options) Logger { 84 | w := &zapLoggerWrapper{options: options} 85 | encoder := w.getEncoder() 86 | writeSyncer := w.getLogWriter() 87 | core := zapcore.NewCore(encoder, writeSyncer, Levels[options.LogLevel]) 88 | w.SugaredLogger = zap.New(core, zap.AddCaller(), zap.AddCallerSkip(1)).Sugar() 89 | return w 90 | } 91 | 92 | func (w *zapLoggerWrapper) getEncoder() zapcore.Encoder { 93 | encoderConfig := zap.NewProductionEncoderConfig() 94 | encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder 95 | 96 | // 在日志文件中使用大写字母记录日志级别 97 | encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder 98 | // NewConsoleEncoder 打印更符合人们观察的方式 
99 | return zapcore.NewConsoleEncoder(encoderConfig) 100 | } 101 | 102 | func (w *zapLoggerWrapper) getLogWriter() zapcore.WriteSyncer { 103 | return zapcore.AddSync(&lumberjack.Logger{ 104 | Filename: w.options.FileName, 105 | MaxAge: w.options.MaxAge, 106 | MaxSize: w.options.MaxSize, 107 | MaxBackups: w.options.MaxBackups, 108 | Compress: w.options.Compress, 109 | }) 110 | } 111 | 112 | // GetDefaultLogger 获取默认日志实现 113 | func GetDefaultLogger() Logger { 114 | return defaultLogger 115 | } 116 | -------------------------------------------------------------------------------- /database/struct.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | "time" 7 | 8 | "github.com/xiaoxuxiansheng/goredis/handler" 9 | ) 10 | 11 | type Executor interface { 12 | Entrance() chan<- *Command 13 | ValidCommand(cmd CmdType) bool 14 | Close() 15 | } 16 | 17 | type CmdType string 18 | 19 | func (c CmdType) String() string { 20 | return strings.ToLower(string(c)) 21 | } 22 | 23 | const ( 24 | CmdTypeExpire CmdType = "expire" 25 | CmdTypeExpireAt CmdType = "expireat" 26 | 27 | // string 28 | CmdTypeGet CmdType = "get" 29 | CmdTypeSet CmdType = "set" 30 | CmdTypeMGet CmdType = "mget" 31 | CmdTypeMSet CmdType = "mset" 32 | 33 | // list 34 | CmdTypeLPush CmdType = "lpush" 35 | CmdTypeLPop CmdType = "lpop" 36 | CmdTypeRPush CmdType = "rpush" 37 | CmdTypeRPop CmdType = "rpop" 38 | CmdTypeLRange CmdType = "lrange" 39 | 40 | // hash 41 | CmdTypeHSet CmdType = "hset" 42 | CmdTypeHGet CmdType = "hget" 43 | CmdTypeHDel CmdType = "hdel" 44 | 45 | // set 46 | CmdTypeSAdd CmdType = "sadd" 47 | CmdTypeSIsMember CmdType = "sismember" 48 | CmdTypeSRem CmdType = "srem" 49 | 50 | // sorted set 51 | CmdTypeZAdd CmdType = "zadd" 52 | CmdTypeZRangeByScore CmdType = "zrangebyscore" 53 | CmdTypeZRem CmdType = "zrem" 54 | ) 55 | 56 | type CmdAdapter interface { 57 | ToCmd() [][]byte 58 | } 59 | 60 | type 
DataStore interface { 61 | ForEach(task func(key string, adapter CmdAdapter, expireAt *time.Time)) 62 | 63 | ExpirePreprocess(key string) 64 | GC() 65 | 66 | Expire(*Command) handler.Reply 67 | ExpireAt(*Command) handler.Reply 68 | 69 | // string 70 | Get(*Command) handler.Reply 71 | MGet(*Command) handler.Reply 72 | Set(*Command) handler.Reply 73 | MSet(*Command) handler.Reply 74 | 75 | // list 76 | LPush(*Command) handler.Reply 77 | LPop(*Command) handler.Reply 78 | RPush(*Command) handler.Reply 79 | RPop(*Command) handler.Reply 80 | LRange(*Command) handler.Reply 81 | 82 | // set 83 | SAdd(*Command) handler.Reply 84 | SIsMember(*Command) handler.Reply 85 | SRem(*Command) handler.Reply 86 | 87 | // hash 88 | HSet(*Command) handler.Reply 89 | HGet(*Command) handler.Reply 90 | HDel(*Command) handler.Reply 91 | 92 | // sorted set 93 | ZAdd(*Command) handler.Reply 94 | ZRangeByScore(*Command) handler.Reply 95 | ZRem(*Command) handler.Reply 96 | } 97 | 98 | type CmdHandler func(*Command) handler.Reply 99 | 100 | type Command struct { 101 | ctx context.Context 102 | cmd CmdType 103 | args [][]byte 104 | receiver CmdReceiver 105 | } 106 | 107 | func NewCommand(cmd CmdType, args [][]byte) *Command { 108 | return &Command{ 109 | cmd: cmd, 110 | args: args, 111 | } 112 | } 113 | 114 | func (c *Command) Ctx() context.Context { 115 | return c.ctx 116 | } 117 | 118 | func (c *Command) Receiver() CmdReceiver { 119 | return c.receiver 120 | } 121 | 122 | func (c *Command) Args() [][]byte { 123 | return c.args 124 | } 125 | 126 | func (c *Command) Cmd() [][]byte { 127 | return append([][]byte{[]byte(c.cmd.String())}, c.args...) 
128 | } 129 | 130 | type CmdReceiver chan handler.Reply 131 | -------------------------------------------------------------------------------- /server/server.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "context" 5 | "net" 6 | "os" 7 | "os/signal" 8 | "sync" 9 | "syscall" 10 | "time" 11 | 12 | "github.com/xiaoxuxiansheng/goredis/lib/pool" 13 | "github.com/xiaoxuxiansheng/goredis/log" 14 | ) 15 | 16 | // 处理器 17 | type Handler interface { 18 | Start() error // 启动 handler 19 | // 处理到来的每一笔 tcp 连接 20 | Handle(ctx context.Context, conn net.Conn) 21 | // 关闭处理器 22 | Close() 23 | } 24 | 25 | type Server struct { 26 | runOnce sync.Once 27 | stopOnce sync.Once 28 | handler Handler 29 | logger log.Logger 30 | stopc chan struct{} 31 | } 32 | 33 | func NewServer(handler Handler, logger log.Logger) *Server { 34 | return &Server{ 35 | handler: handler, 36 | logger: logger, 37 | stopc: make(chan struct{}), 38 | } 39 | } 40 | 41 | func (s *Server) Serve(address string) error { 42 | if err := s.handler.Start(); err != nil { 43 | return err 44 | } 45 | var _err error 46 | s.runOnce.Do(func() { 47 | // 监听进程信号 48 | exitWords := []os.Signal{syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT} 49 | 50 | sigc := make(chan os.Signal, 1) 51 | signal.Notify(sigc, exitWords...) 
52 | closec := make(chan struct{}, 4) 53 | pool.Submit(func() { 54 | for { 55 | select { 56 | case signal := <-sigc: 57 | switch signal { 58 | case exitWords[0], exitWords[1], exitWords[2], exitWords[3]: 59 | closec <- struct{}{} 60 | return 61 | default: 62 | } 63 | case <-s.stopc: 64 | closec <- struct{}{} 65 | return 66 | } 67 | } 68 | }) 69 | 70 | listener, err := net.Listen("tcp", address) 71 | if err != nil { 72 | _err = err 73 | return 74 | } 75 | 76 | s.listenAndServe(listener, closec) 77 | }) 78 | 79 | return _err 80 | } 81 | 82 | func (s *Server) Stop() { 83 | s.stopOnce.Do(func() { 84 | close(s.stopc) 85 | }) 86 | } 87 | 88 | func (s *Server) listenAndServe(listener net.Listener, closec chan struct{}) { 89 | errc := make(chan error, 1) 90 | defer close(errc) 91 | 92 | // 遇到意外错误,则终止流程 93 | ctx, cancel := context.WithCancel(context.Background()) 94 | pool.Submit( 95 | func() { 96 | select { 97 | case <-closec: 98 | s.logger.Errorf("[server]server closing...") 99 | case err := <-errc: 100 | s.logger.Errorf("[server]server err: %s", err.Error()) 101 | } 102 | cancel() 103 | s.logger.Warnf("[server]server closeing...") 104 | s.handler.Close() 105 | if err := listener.Close(); err != nil { 106 | s.logger.Errorf("[server]server close listener err: %s", err.Error()) 107 | } 108 | }) 109 | 110 | s.logger.Warnf("[server]server starting...") 111 | var wg sync.WaitGroup 112 | // io 多路复用模型,goroutine for per conn 113 | for { 114 | conn, err := listener.Accept() 115 | if err != nil { 116 | // 超时类错误,忽略 117 | if ne, ok := err.(net.Error); ok && ne.Timeout() { 118 | time.Sleep(5 * time.Millisecond) 119 | continue 120 | } 121 | 122 | // 意外错误,则停止运行 123 | errc <- err 124 | break 125 | } 126 | 127 | // 为每个到来的 conn 分配一个 goroutine 处理 128 | wg.Add(1) 129 | pool.Submit(func() { 130 | defer wg.Done() 131 | s.handler.Handle(ctx, conn) 132 | }) 133 | } 134 | 135 | // 通过 waitGroup 保证优雅退出 136 | wg.Wait() 137 | } 138 | 
-------------------------------------------------------------------------------- /handler/handler.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "net" 7 | "sync" 8 | "sync/atomic" 9 | 10 | "github.com/xiaoxuxiansheng/goredis/log" 11 | "github.com/xiaoxuxiansheng/goredis/server" 12 | ) 13 | 14 | type Handler struct { 15 | sync.Once 16 | mu sync.RWMutex 17 | conns map[net.Conn]struct{} 18 | closed atomic.Bool 19 | 20 | db DB 21 | parser Parser 22 | persister Persister 23 | logger log.Logger 24 | } 25 | 26 | func NewHandler(db DB, persister Persister, parser Parser, logger log.Logger) (server.Handler, error) { 27 | h := Handler{ 28 | conns: make(map[net.Conn]struct{}), 29 | persister: persister, 30 | logger: logger, 31 | db: db, 32 | parser: parser, 33 | } 34 | 35 | return &h, nil 36 | } 37 | 38 | func (h *Handler) Start() error { 39 | // replay the persisted commands to restore the dataset 40 | reloader, err := h.persister.Reloader() 41 | if err != nil { 42 | return err 43 | } 44 | defer reloader.Close() 45 | h.handle(SetLoadingPattern(context.Background()), newFakeReaderWriter(reloader)) 46 | return nil 47 | } 48 | 49 | // Handle serves conn until its stream terminates or ctx is cancelled, then untracks and 50 | // closes it. Connections that arrive after Close are closed immediately. 51 | func (h *Handler) Handle(ctx context.Context, conn net.Conn) { 52 | h.mu.Lock() 53 | // refuse new connections once the handler has been closed 54 | if h.closed.Load() { 55 | h.mu.Unlock() 56 | _ = conn.Close() 57 | return 58 | } 59 | 60 | // track the conn so Close can shut it down 61 | h.conns[conn] = struct{}{} 62 | h.mu.Unlock() 63 | // close and untrack conn once handling ends; otherwise it would stay open in h.conns 64 | defer h.releaseConn(conn) 65 | 66 | h.handle(ctx, conn) 67 | } 68 | 69 | // releaseConn removes conn from h.conns and closes it, unless Close already did both. 70 | func (h *Handler) releaseConn(conn net.Conn) { 71 | h.mu.Lock() 72 | _, tracked := h.conns[conn] 73 | delete(h.conns, conn) // no-op on the nil map left behind by Close 74 | h.mu.Unlock() 75 | if !tracked { 76 | return 77 | } 78 | if err := conn.Close(); err != nil { 79 | h.logger.Errorf("[handler]close conn err, local addr: %s, err: %s", conn.LocalAddr().String(), err.Error()) 80 | } 81 | } 82 | 83 | func (h *Handler) handle(ctx context.Context, conn io.ReadWriter) { 84 | // keep processing until the stream terminates or ctx is done 85 | stream := h.parser.ParseStream(conn) 86 | for { 87 | select { 88 | case <-ctx.Done(): 89 | h.logger.Warnf("[handler]handle ctx err: %s", ctx.Err().Error()) 90 | return 91 | 92 | case droplet := <-stream: 93 | if err := h.handleDroplet(ctx, conn, droplet); err != nil { 94 | h.logger.Errorf("[handler]conn terminated, err: %s", err.Error()) 95 | return 96 | } 97 | } 98 | } 99 | } 100 | 101 | func (h *Handler)
handleDroplet(ctx context.Context, conn io.ReadWriter, droplet *Droplet) error { 102 | if droplet.Terminated() { 103 | return droplet.Err 104 | } 105 | 106 | if droplet.Err != nil { 107 | _, _ = conn.Write(droplet.Reply.ToBytes()) 108 | h.logger.Errorf("[handler]conn request, err: %s", droplet.Err.Error()) 109 | return nil 110 | } 111 | 112 | if droplet.Reply == nil { 113 | h.logger.Errorf("[handler]conn empty request") 114 | return nil 115 | } 116 | 117 | // requests must arrive as multi-bulk (array) payloads 118 | multiReply, ok := droplet.Reply.(MultiReply) 119 | if !ok { 120 | h.logger.Errorf("[handler]conn invalid request: %s", droplet.Reply.ToBytes()) 121 | return nil 122 | } 123 | 124 | if reply := h.db.Do(ctx, multiReply.Args()); reply != nil { 125 | _, _ = conn.Write(reply.ToBytes()) 126 | return nil 127 | } 128 | 129 | _, _ = conn.Write(UnknownErrReplyBytes) 130 | return nil 131 | } 132 | 133 | func (h *Handler) Close() { 134 | h.Once.Do(func() { 135 | h.logger.Warnf("[handler]handler closing...") 136 | h.closed.Store(true) 137 | // exclusive lock: h.conns is reset below and releaseConn mutates it concurrently 138 | h.mu.Lock() 139 | for conn := range h.conns { 140 | if err := conn.Close(); err != nil { 141 | h.logger.Errorf("[handler]close conn err, local addr: %s, err: %s", conn.LocalAddr().String(), err.Error()) 142 | } 143 | } 144 | h.conns = nil 145 | h.mu.Unlock() 146 | h.db.Close() 147 | h.persister.Close() 148 | }) 149 | } 150 | -------------------------------------------------------------------------------- /app/config.go: -------------------------------------------------------------------------------- 1 | package app 2 | 3 | import ( 4 | "bufio" 5 | "fmt" 6 | "io" 7 | "os" 8 | "reflect" 9 | "strconv" 10 | "strings" 11 | "sync" 12 | 13 | "github.com/xiaoxuxiansheng/goredis/persist" 14 | ) 15 | 16 | type Config struct { 17 | Bind string `cfg:"bind"` // IP address to bind 18 | Port int `cfg:"port"` // listening port 19 | AppendOnly_ bool `cfg:"appendonly"` // whether AOF persistence is enabled 20 | AppendFileName_ string `cfg:"appendfilename"` // AOF file name 21 | AppendFsync_ string `cfg:"appendfsync"` // AOF fsync policy: always | everysec | no 22 |
AutoAofRewriteAfterCmd_ int `cfg:"auto-aof-rewrite-after-cmds"` // rewrite the AOF after this many appended commands 23 | } 24 | 25 | func (c *Config) Address() string { 26 | return fmt.Sprintf("%s:%d", c.Bind, c.Port) 27 | } 28 | 29 | func (c *Config) AppendOnly() bool { 30 | return c.AppendOnly_ 31 | } 32 | 33 | func (c *Config) AppendFileName() string { 34 | return c.AppendFileName_ 35 | } 36 | 37 | func (c *Config) AppendFsync() string { 38 | return c.AppendFsync_ 39 | } 40 | 41 | func (c *Config) AutoAofRewriteAfterCmd() int { 42 | return c.AutoAofRewriteAfterCmd_ 43 | } 44 | 45 | var ( 46 | confOnce sync.Once 47 | globalConf *Config 48 | ) 49 | 50 | func PersistThinker() persist.Thinker { 51 | return SetUpConfig() 52 | } 53 | 54 | func SetUpConfig() *Config { 55 | confOnce.Do(func() { 56 | defer func() { 57 | if globalConf == nil { 58 | globalConf = defaultConf() 59 | } 60 | }() 61 | 62 | file, err := os.Open("./redis.conf") 63 | if err != nil { 64 | return 65 | } 66 | defer file.Close() 67 | globalConf = setUpConfig(file) 68 | }) 69 | 70 | return globalConf 71 | } 72 | 73 | // setUpConfig parses "key value" lines from src into a Config that starts from defaultConf(), 74 | // so keys missing from src (or whose integer value fails to parse) keep their defaults. 75 | // Comment lines are skipped. Returns nil if reading src fails. 76 | func setUpConfig(src io.Reader) *Config { 77 | tmpkv := make(map[string]string) 78 | scanner := bufio.NewScanner(src) 79 | for scanner.Scan() { 80 | line := scanner.Text() 81 | // skip comment lines 82 | trimmed := strings.TrimSpace(line) 83 | if len(trimmed) > 0 && trimmed[0] == '#' { 84 | continue 85 | } 86 | 87 | // split on the first space or tab; blank lines and keys without a value are skipped 88 | pivot := strings.IndexAny(trimmed, " \t") 89 | if pivot <= 0 { 90 | continue 91 | } 92 | 93 | key := trimmed[:pivot] 94 | // tolerate extra whitespace between key and value 95 | value := strings.TrimSpace(trimmed[pivot+1:]) 96 | tmpkv[key] = value 97 | } 98 | 99 | if err := scanner.Err(); err != nil { 100 | return nil 101 | } 102 | 103 | // start from the defaults so absent keys keep sane values (e.g. port 6379, not 0) 104 | conf := defaultConf() 105 | // populate conf fields via reflection, keyed by the cfg tag (or the field name) 106 | t := reflect.TypeOf(conf) 107 | v := reflect.ValueOf(conf) 108 | for i := 0; i < t.Elem().NumField(); i++ { 109 | field := t.Elem().Field(i) 110 | fieldVal := v.Elem().Field(i) 111 | key, ok := field.Tag.Lookup("cfg") 112 | if !ok || strings.TrimSpace(key) == "" { 113 |
key = field.Name 114 | } 115 | value, ok := tmpkv[key] 116 | if !ok { 117 | continue 118 | } 119 | switch field.Type.Kind() { 120 | case reflect.String: 121 | fieldVal.SetString(value) 122 | case reflect.Int: 123 | intv, err := strconv.ParseInt(value, 10, 64) 124 | if err != nil { 125 | // keep the default rather than silently storing 0 126 | continue 127 | } 128 | fieldVal.SetInt(intv) 129 | case reflect.Bool: 130 | fieldVal.SetBool(value == "yes") 131 | } 132 | } 133 | 134 | return conf 135 | } 136 | 137 | func defaultConf() *Config { 138 | return &Config{ 139 | Bind: "0.0.0.0", 140 | Port: 6379, 141 | AppendOnly_: false, // AOF disabled by default 142 | } 143 | } 144 | -------------------------------------------------------------------------------- /persist/aof_rewrite.go: -------------------------------------------------------------------------------- 1 | package persist 2 | 3 | import ( 4 | "io" 5 | "os" 6 | "time" 7 | 8 | "github.com/xiaoxuxiansheng/goredis/database" 9 | "github.com/xiaoxuxiansheng/goredis/datastore" 10 | "github.com/xiaoxuxiansheng/goredis/handler" 11 | "github.com/xiaoxuxiansheng/goredis/lib" 12 | "github.com/xiaoxuxiansheng/goredis/log" 13 | "github.com/xiaoxuxiansheng/goredis/protocol" 14 | ) 15 | 16 | // rewriteAOF compacts the AOF: it replays a snapshot of the file into a scratch store, 17 | // writes that store out as fresh commands, appends whatever was logged meanwhile and swaps 18 | // the result in as the new AOF. 19 | func (a *aofPersister) rewriteAOF() error { 20 | // 1 pre-rewrite work; holds the lock briefly 21 | tmpFile, fileSize, err := a.startRewrite() 22 | if err != nil { 23 | return err 24 | } 25 | 26 | // 2 rewrite the commands; runs concurrently with normal appends 27 | if err = a.doRewrite(tmpFile, fileSize); err != nil { 28 | // discard the partial temp file instead of leaking it 29 | _ = tmpFile.Close() 30 | _ = os.Remove(tmpFile.Name()) 31 | return err 32 | } 33 | 34 | // 3 finish the rewrite;
holds the lock briefly 35 | return a.endRewrite(tmpFile, fileSize) 36 | } 37 | 38 | // startRewrite syncs the AOF, records its current size as the snapshot boundary and creates 39 | // the temp file the rewrite is written to. 40 | func (a *aofPersister) startRewrite() (*os.File, int64, error) { 41 | a.mu.Lock() 42 | defer a.mu.Unlock() 43 | 44 | if err := a.aofFile.Sync(); err != nil { 45 | return nil, 0, err 46 | } 47 | 48 | // the error must be checked: a nil fileInfo would panic on Size() 49 | fileInfo, err := os.Stat(a.aofFileName) 50 | if err != nil { 51 | return nil, 0, err 52 | } 53 | fileSize := fileInfo.Size() 54 | 55 | // create a temporary aof file 56 | tmpFile, err := os.CreateTemp("./", "*.aof") 57 | if err != nil { 58 | return nil, 0, err 59 | } 60 | 61 | return tmpFile, fileSize, nil 62 | } 63 | 64 | // doRewrite writes the commands that rebuild the first fileSize bytes of the AOF into tmpFile 65 | // and returns the first write error, if any. 66 | func (a *aofPersister) doRewrite(tmpFile *os.File, fileSize int64) error { 67 | forkedDB, err := a.forkDB(fileSize) 68 | if err != nil { 69 | return err 70 | } 71 | 72 | // convert the forked data into aof commands, remembering the first write error 73 | var writeErr error 74 | write := func(cmd [][]byte) { 75 | if writeErr == nil { 76 | _, writeErr = tmpFile.Write(handler.NewMultiBulkReply(cmd).ToBytes()) 77 | } 78 | } 79 | forkedDB.ForEach(func(key string, adapter database.CmdAdapter, expireAt *time.Time) { 80 | write(adapter.ToCmd()) 81 | if expireAt == nil { 82 | return 83 | } 84 | write([][]byte{[]byte(database.CmdTypeExpireAt), []byte(key), []byte(lib.TimeSecondFormat(*expireAt))}) 85 | }) 86 | 87 | return writeErr 88 | } 89 | 90 | // forkDB rebuilds a scratch store by replaying the first fileSize bytes of the AOF. 91 | func (a *aofPersister) forkDB(fileSize int64) (database.DataStore, error) { 92 | file, err := os.Open(a.aofFileName) 93 | if err != nil { 94 | return nil, err 95 | } 96 | logger := log.GetDefaultLogger() 97 | // replay only the snapshot taken in startRewrite; later appends are copied in endRewrite 98 | reloader := readCloserAdapter(io.LimitReader(file, fileSize), file.Close) 99 | fakePerisister := newFakePersister(reloader) 100 | tmpKVStore := datastore.NewKVStore(fakePerisister) 101 | executor := database.NewDBExecutor(tmpKVStore) 102 | // the executor's worker goroutine is only needed for the replay; stop it on return 103 | defer executor.Close() 104 | trigger := database.NewDBTrigger(executor) 105 | h, err := handler.NewHandler(trigger, fakePerisister, protocol.NewParser(logger), logger) 106 | if err != nil { 107 | _ = file.Close() 108 | return nil, err 109 | } 110 | if err = h.Start(); err != nil { 111 | return nil, err 112 | } 113 | return tmpKVStore, nil 114 | } 115 | 116 | // endRewrite appends the commands logged since the snapshot to tmpFile, renames it over the 117 | // AOF and reopens the AOF for appending. tmpFile is always closed, and removed unless it was 118 | // renamed into place. 119 | func (a *aofPersister) endRewrite(tmpFile
*os.File, fileSize int64) error { 120 | a.mu.Lock() 121 | defer a.mu.Unlock() 122 | 123 | renamed := false 124 | defer func() { 125 | _ = tmpFile.Close() 126 | if !renamed { 127 | _ = os.Remove(tmpFile.Name()) 128 | } 129 | }() 130 | 131 | // copy commands executed during rewriting to tmpFile 132 | src, err := os.Open(a.aofFileName) 133 | if err != nil { 134 | return err 135 | } 136 | defer src.Close() 137 | 138 | if _, err = src.Seek(fileSize, io.SeekStart); err != nil { 139 | return err 140 | } 141 | 142 | // append everything the old aof file received after the snapshot 143 | if _, err = io.Copy(tmpFile, src); err != nil { 144 | return err 145 | } 146 | // make the new file durable before it replaces the old one 147 | if err = tmpFile.Sync(); err != nil { 148 | return err 149 | } 150 | 151 | // close the old aof file; it is about to be replaced 152 | _ = a.aofFile.Close() 153 | // rename the tmp file into place as the new aof file 154 | renameErr := os.Rename(tmpFile.Name(), a.aofFileName) 155 | renamed = renameErr == nil 156 | 157 | // reopen the aof file: the new one on success, the old one if the rename failed 158 | aofFile, err := os.OpenFile(a.aofFileName, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0600) 159 | if err != nil { 160 | panic(err) 161 | } 162 | a.aofFile = aofFile 163 | return renameErr 164 | } 165 | -------------------------------------------------------------------------------- /handler/reply.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "strconv" 5 | "strings" 6 | ) 7 | 8 | // CRLF is the line terminator used throughout the Redis wire protocol 9 | const CRLF = "\r\n" 10 | 11 | type OKReply struct{} 12 | 13 | func NewOKReply() *OKReply { 14 | return theOkReply 15 | } 16 | 17 | var okBytes = []byte("+OK\r\n") 18 | 19 | func (o *OKReply) ToBytes() []byte { 20 | return okBytes 21 | } 22 | 23 | var theOkReply = new(OKReply) 24 | 25 | // Simple string type. Wire format: 【+】【string】【CRLF】 26 | type SimpleStringReply struct { 27 | Str string 28 | } 29 | 30 | func NewSimpleStringReply(str string) *SimpleStringReply { 31 | return &SimpleStringReply{ 32 | Str: str, 33 | } 34 | } 35 | 36 | func (s *SimpleStringReply) ToBytes() []byte { 37 | return []byte("+" + s.Str + CRLF) 38 | } 39 | 40 | // Simple integer type.
Wire format: 【:】【int】【CRLF】 41 | type IntReply struct { 42 | Code int64 43 | } 44 | 45 | func NewIntReply(code int64) *IntReply { 46 | return &IntReply{ 47 | Code: code, 48 | } 49 | } 50 | 51 | func (i *IntReply) ToBytes() []byte { 52 | return []byte(":" + strconv.FormatInt(i.Code, 10) + CRLF) 53 | } 54 | 55 | // SyntaxErrReply reports malformed command arguments. 56 | type SyntaxErrReply struct{} 57 | 58 | var syntaxErrBytes = []byte("-Err syntax error\r\n") 59 | var theSyntaxErrReply = &SyntaxErrReply{} 60 | 61 | func NewSyntaxErrReply() *SyntaxErrReply { 62 | return theSyntaxErrReply 63 | } 64 | 65 | func (r *SyntaxErrReply) ToBytes() []byte { 66 | return syntaxErrBytes 67 | } 68 | 69 | func (r *SyntaxErrReply) Error() string { 70 | return "Err syntax error" 71 | } 72 | 73 | // WrongTypeErrReply reports an operation against a key holding the wrong kind of value. 74 | type WrongTypeErrReply struct{} 75 | 76 | var theWrongTypeErrReply = &WrongTypeErrReply{} 77 | 78 | var wrongTypeErrBytes = []byte("-WRONGTYPE Operation against a key holding the wrong kind of value\r\n") 79 | 80 | func NewWrongTypeErrReply() *WrongTypeErrReply { 81 | return theWrongTypeErrReply 82 | } 83 | 84 | func (r *WrongTypeErrReply) ToBytes() []byte { 85 | return wrongTypeErrBytes 86 | } 87 | 88 | func (r *WrongTypeErrReply) Error() string { 89 | return "WRONGTYPE Operation against a key holding the wrong kind of value" 90 | } 91 | 92 | // Error type.
Wire format: 【-】【err】【CRLF】 93 | type ErrReply struct { 94 | ErrStr string 95 | } 96 | 97 | func NewErrReply(errStr string) *ErrReply { 98 | return &ErrReply{ 99 | ErrStr: errStr, 100 | } 101 | } 102 | 103 | func (e *ErrReply) ToBytes() []byte { 104 | return []byte("-" + e.ErrStr + CRLF) 105 | } 106 | 107 | var ( 108 | nillReply = &NillReply{} 109 | nillBulkBytes = []byte("$-1\r\n") 110 | ) 111 | 112 | // Nil reply, shared as a global singleton; the wire format is always 【$】【-1】【CRLF】 113 | type NillReply struct { 114 | } 115 | 116 | func NewNillReply() *NillReply { 117 | return nillReply 118 | } 119 | 120 | func (n *NillReply) ToBytes() []byte { 121 | return nillBulkBytes 122 | } 123 | 124 | // Bulk (length-prefixed) string type; the wire format is 【$】【length】【CRLF】【content】【CRLF】 125 | type BulkReply struct { 126 | Arg []byte 127 | } 128 | 129 | func NewBulkReply(arg []byte) *BulkReply { 130 | return &BulkReply{ 131 | Arg: arg, 132 | } 133 | } 134 | 135 | func (b *BulkReply) ToBytes() []byte { 136 | if b.Arg == nil { 137 | return nillBulkBytes 138 | } 139 | return []byte("$" + strconv.Itoa(len(b.Arg)) + CRLF + string(b.Arg) + CRLF) 140 | } 141 | 142 | // Array type. Wire format: 【*】【arr.length】【CRLF】+ arr.length * (【$】【length】【CRLF】【content】【CRLF】) 143 | type MultiBulkReply struct { 144 | args [][]byte 145 | } 146 | 147 | func NewMultiBulkReply(args [][]byte) *MultiBulkReply { 148 | return &MultiBulkReply{ 149 | args: args, 150 | } 151 | } 152 | 153 | func (m *MultiBulkReply) Args() [][]byte { 154 | return m.args 155 | } 156 | 157 | func (m *MultiBulkReply) ToBytes() []byte { 158 | var strBuf strings.Builder 159 | strBuf.WriteString("*" + strconv.Itoa(len(m.args)) + CRLF) 160 | for _, arg := range m.args { 161 | if arg == nil { 162 | strBuf.WriteString(string(nillBulkBytes)) 163 | continue 164 | } 165 | strBuf.WriteString("$" + strconv.Itoa(len(arg)) + CRLF + string(arg) + CRLF) 166 | } 167 | return []byte(strBuf.String()) 168 | } 169 | 170 | var emptyMultiBulkBytes = []byte("*0\r\n") 171 | 172 | // Empty array type.
采用单例,协议固定为【*】【0】【CRLF】 173 | type EmptyMultiBulkReply struct{} 174 | 175 | func NewEmptyMultiBulkReply() *EmptyMultiBulkReply { 176 | return &EmptyMultiBulkReply{} 177 | } 178 | 179 | func (r *EmptyMultiBulkReply) ToBytes() []byte { 180 | return emptyMultiBulkBytes 181 | } 182 | -------------------------------------------------------------------------------- /persist/aof.go: -------------------------------------------------------------------------------- 1 | package persist 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "os" 7 | "sync" 8 | "sync/atomic" 9 | "time" 10 | 11 | "github.com/xiaoxuxiansheng/goredis/handler" 12 | "github.com/xiaoxuxiansheng/goredis/lib/pool" 13 | ) 14 | 15 | // always | everysec | no 16 | type appendSyncStrategy string 17 | 18 | func (a appendSyncStrategy) string() string { 19 | return string(a) 20 | } 21 | 22 | const ( 23 | alwaysAppendSyncStrategy appendSyncStrategy = "always" 24 | everysecAppendSyncStrategy appendSyncStrategy = "everysec" 25 | noAppendSyncStrategy appendSyncStrategy = "no" 26 | ) 27 | 28 | type aofPersister struct { 29 | ctx context.Context 30 | cancel context.CancelFunc 31 | 32 | buffer chan [][]byte 33 | aofFile *os.File 34 | aofFileName string 35 | appendFsync appendSyncStrategy 36 | autoAofRewriteAfterCmd int64 37 | aofCounter atomic.Int64 38 | 39 | mu sync.Mutex 40 | once sync.Once 41 | } 42 | 43 | func newAofPersister(thinker Thinker) (handler.Persister, error) { 44 | aofFileName := thinker.AppendFileName() 45 | aofFile, err := os.OpenFile(aofFileName, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0600) 46 | if err != nil { 47 | return nil, err 48 | } 49 | ctx, cancel := context.WithCancel(context.Background()) 50 | a := aofPersister{ 51 | ctx: ctx, 52 | cancel: cancel, 53 | buffer: make(chan [][]byte, 1<<10), 54 | aofFile: aofFile, 55 | aofFileName: aofFileName, 56 | } 57 | 58 | if autoAofRewriteAfterCmd := thinker.AutoAofRewriteAfterCmd(); autoAofRewriteAfterCmd > 1 { 59 | a.autoAofRewriteAfterCmd = 
int64(autoAofRewriteAfterCmd) 60 | } 61 | 62 | switch thinker.AppendFsync() { 63 | case alwaysAppendSyncStrategy.string(): 64 | a.appendFsync = alwaysAppendSyncStrategy 65 | case everysecAppendSyncStrategy.string(): 66 | a.appendFsync = everysecAppendSyncStrategy 67 | default: 68 | a.appendFsync = noAppendSyncStrategy // 默认策略 69 | } 70 | 71 | pool.Submit(a.run) 72 | return &a, nil 73 | } 74 | 75 | func (a *aofPersister) Reloader() (io.ReadCloser, error) { 76 | file, err := os.Open(a.aofFileName) 77 | if err != nil { 78 | return nil, err 79 | } 80 | _, _ = file.Seek(0, io.SeekStart) 81 | return file, nil 82 | } 83 | 84 | func (a *aofPersister) PersistCmd(ctx context.Context, cmd [][]byte) { 85 | if handler.IsLoadingPattern(ctx) { 86 | return 87 | } 88 | a.buffer <- cmd 89 | } 90 | 91 | func (a *aofPersister) Close() { 92 | a.once.Do(func() { 93 | a.cancel() 94 | _ = a.aofFile.Close() 95 | }) 96 | } 97 | 98 | func (a *aofPersister) run() { 99 | if a.appendFsync == everysecAppendSyncStrategy { 100 | pool.Submit(a.fsyncEverySecond) 101 | } 102 | 103 | for { 104 | select { 105 | case <-a.ctx.Done(): 106 | // log 107 | return 108 | case cmd := <-a.buffer: 109 | a.writeAof(cmd) 110 | a.aofTick() 111 | } 112 | } 113 | } 114 | 115 | // 记录执行的 aof 指令次数 116 | func (a *aofPersister) aofTick() { 117 | if a.autoAofRewriteAfterCmd <= 1 { 118 | return 119 | } 120 | 121 | if ticked := a.aofCounter.Add(1); ticked < int64(a.autoAofRewriteAfterCmd) { 122 | return 123 | } 124 | 125 | // 达到重写次数,扣减计数器,进行重写 126 | _ = a.aofCounter.Add(-a.autoAofRewriteAfterCmd) 127 | pool.Submit(func() { 128 | if err := a.rewriteAOF(); err != nil { 129 | // log 130 | } 131 | }) 132 | } 133 | 134 | func (a *aofPersister) fsyncEverySecond() { 135 | ticker := time.NewTicker(time.Second) 136 | for { 137 | select { 138 | case <-a.ctx.Done(): 139 | // log 140 | return 141 | case <-ticker.C: 142 | if err := a.fsync(); err != nil { 143 | // log 144 | } 145 | } 146 | } 147 | } 148 | 149 | func (a *aofPersister) 
writeAof(cmd [][]byte) { 150 | a.mu.Lock() 151 | defer a.mu.Unlock() 152 | 153 | persistCmd := handler.NewMultiBulkReply(cmd) 154 | if _, err := a.aofFile.Write(persistCmd.ToBytes()); err != nil { 155 | // log 156 | return 157 | } 158 | 159 | if a.appendFsync != alwaysAppendSyncStrategy { 160 | return 161 | } 162 | 163 | if err := a.fsyncLocked(); err != nil { 164 | // log 165 | } 166 | } 167 | 168 | func (a *aofPersister) fsync() error { 169 | a.mu.Lock() 170 | defer a.mu.Unlock() 171 | return a.fsyncLocked() 172 | } 173 | 174 | func (a *aofPersister) fsyncLocked() error { 175 | return a.aofFile.Sync() 176 | } 177 | -------------------------------------------------------------------------------- /protocol/parser.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "io" 7 | "strconv" 8 | 9 | "github.com/xiaoxuxiansheng/goredis/handler" 10 | "github.com/xiaoxuxiansheng/goredis/lib/pool" 11 | "github.com/xiaoxuxiansheng/goredis/log" 12 | ) 13 | 14 | type lineParser func(header []byte, reader *bufio.Reader) *handler.Droplet 15 | 16 | type Parser struct { 17 | lineParsers map[byte]lineParser 18 | logger log.Logger 19 | } 20 | 21 | func NewParser(logger log.Logger) handler.Parser { 22 | p := Parser{ 23 | logger: logger, 24 | } 25 | p.lineParsers = map[byte]lineParser{ 26 | '+': p.parseSimpleString, 27 | '-': p.parseError, 28 | ':': p.parseInt, 29 | '$': p.parseBulk, 30 | '*': p.parseMultiBulk, 31 | } 32 | return &p 33 | } 34 | 35 | func (p *Parser) ParseStream(reader io.Reader) <-chan *handler.Droplet { 36 | ch := make(chan *handler.Droplet) 37 | pool.Submit( 38 | func() { 39 | p.parse(reader, ch) 40 | }) 41 | return ch 42 | } 43 | 44 | func (p *Parser) parse(rawReader io.Reader, ch chan<- *handler.Droplet) { 45 | reader := bufio.NewReader(rawReader) 46 | for { 47 | firstLine, err := reader.ReadBytes('\n') 48 | if err != nil { 49 | ch <- &handler.Droplet{ 50 | Reply: 
handler.NewErrReply(err.Error()), 51 | Err: err, 52 | } 53 | return 54 | } 55 | 56 | length := len(firstLine) 57 | if length <= 2 || firstLine[length-1] != '\n' || firstLine[length-2] != '\r' { 58 | continue 59 | } 60 | 61 | firstLine = bytes.TrimSuffix(firstLine, []byte{'\r', '\n'}) 62 | lineParseFunc, ok := p.lineParsers[firstLine[0]] 63 | if !ok { 64 | p.logger.Errorf("[parser] invalid line handler: %s", firstLine[0]) 65 | continue 66 | } 67 | 68 | ch <- lineParseFunc(firstLine, reader) 69 | } 70 | } 71 | 72 | // 解析简单 string 类型 73 | func (p *Parser) parseSimpleString(header []byte, reader *bufio.Reader) *handler.Droplet { 74 | content := header[1:] 75 | return &handler.Droplet{ 76 | Reply: handler.NewSimpleStringReply(string(content)), 77 | } 78 | } 79 | 80 | // 解析简单 int 类型 81 | func (p *Parser) parseInt(header []byte, reader *bufio.Reader) *handler.Droplet { 82 | 83 | i, err := strconv.ParseInt(string(header[1:]), 10, 64) 84 | if err != nil { 85 | return &handler.Droplet{ 86 | Err: err, 87 | Reply: handler.NewErrReply(err.Error()), 88 | } 89 | } 90 | 91 | return &handler.Droplet{ 92 | Reply: handler.NewIntReply(i), 93 | } 94 | } 95 | 96 | // 解析错误类型 97 | func (p *Parser) parseError(header []byte, reader *bufio.Reader) *handler.Droplet { 98 | return &handler.Droplet{ 99 | Reply: handler.NewErrReply(string(header[1:])), 100 | } 101 | } 102 | 103 | // 解析定长 string 类型 104 | func (p *Parser) parseBulk(header []byte, reader *bufio.Reader) *handler.Droplet { 105 | // 解析定长 string 106 | body, err := p.parseBulkBody(header, reader) 107 | if err != nil { 108 | return &handler.Droplet{ 109 | Reply: handler.NewErrReply(err.Error()), 110 | Err: err, 111 | } 112 | } 113 | return &handler.Droplet{ 114 | Reply: handler.NewBulkReply(body), 115 | } 116 | } 117 | 118 | // 解析定长 string 119 | func (p *Parser) parseBulkBody(header []byte, reader *bufio.Reader) ([]byte, error) { 120 | // 获取 string 长度 121 | strLen, err := strconv.ParseInt(string(header[1:]), 10, 64) 122 | if err != nil { 
123 | return nil, err 124 | } 125 | 126 | // 长度 + 2,把 CRLF 也考虑在内 127 | body := make([]byte, strLen+2) 128 | // 从 reader 中读取对应长度 129 | if _, err = io.ReadFull(reader, body); err != nil { 130 | return nil, err 131 | } 132 | return body[:len(body)-2], nil 133 | } 134 | 135 | // 解析 136 | func (p *Parser) parseMultiBulk(header []byte, reader *bufio.Reader) (droplet *handler.Droplet) { 137 | var _err error 138 | defer func() { 139 | if _err != nil { 140 | droplet = &handler.Droplet{ 141 | Reply: handler.NewErrReply(_err.Error()), 142 | Err: _err, 143 | } 144 | } 145 | }() 146 | 147 | // 获取数组长度 148 | length, err := strconv.ParseInt(string(header[1:]), 10, 64) 149 | if err != nil { 150 | _err = err 151 | return 152 | } 153 | 154 | if length <= 0 { 155 | return &handler.Droplet{ 156 | Reply: handler.NewEmptyMultiBulkReply(), 157 | } 158 | } 159 | 160 | lines := make([][]byte, 0, length) 161 | for i := int64(0); i < length; i++ { 162 | // 获取每个 bulk 首行 163 | firstLine, err := reader.ReadBytes('\n') 164 | if err != nil { 165 | _err = err 166 | return 167 | } 168 | 169 | // bulk 首行格式校验 170 | length := len(firstLine) 171 | if length < 4 || firstLine[length-2] != '\r' || firstLine[length-1] != '\n' || firstLine[0] != '$' { 172 | continue 173 | } 174 | 175 | // bulk 解析 176 | bulkBody, err := p.parseBulkBody(firstLine[:length-2], reader) 177 | if err != nil { 178 | _err = err 179 | return 180 | } 181 | 182 | lines = append(lines, bulkBody) 183 | } 184 | 185 | return &handler.Droplet{ 186 | Reply: handler.NewMultiBulkReply(lines), 187 | } 188 | } 189 | -------------------------------------------------------------------------------- /datastore/sorted_set.go: -------------------------------------------------------------------------------- 1 | package datastore 2 | 3 | import ( 4 | "math" 5 | "math/rand" 6 | "strconv" 7 | 8 | "github.com/xiaoxuxiansheng/goredis/database" 9 | "github.com/xiaoxuxiansheng/goredis/handler" 10 | "github.com/xiaoxuxiansheng/goredis/lib" 11 | ) 12 | 13 | func 
(k *KVStore) getAsSortedSet(key string) (SortedSet, error) { 14 | v, ok := k.data[key] 15 | if !ok { 16 | return nil, nil 17 | } 18 | 19 | zset, ok := v.(SortedSet) 20 | if !ok { 21 | return nil, handler.NewWrongTypeErrReply() 22 | } 23 | 24 | return zset, nil 25 | } 26 | 27 | func (k *KVStore) putAsSortedSet(key string, zset SortedSet) { 28 | k.data[key] = zset 29 | } 30 | 31 | type SortedSet interface { 32 | Add(score int64, member string) 33 | Rem(member string) int64 34 | Range(score1, score2 int64) []string 35 | database.CmdAdapter 36 | } 37 | 38 | type skiplist struct { 39 | key string 40 | scoreToNode map[int64]*skipnode 41 | memberToScore map[string]int64 42 | head *skipnode 43 | rander *rand.Rand 44 | } 45 | 46 | func newSkiplist(key string) SortedSet { 47 | return &skiplist{ 48 | key: key, 49 | memberToScore: make(map[string]int64), 50 | scoreToNode: make(map[int64]*skipnode), 51 | head: newSkipnode(0, 0), 52 | rander: rand.New((rand.NewSource(lib.TimeNow().UnixNano()))), 53 | } 54 | } 55 | 56 | func (s *skiplist) Add(score int64, member string) { 57 | // 之前存在,需要删除 58 | oldScore, ok := s.memberToScore[member] 59 | if ok { 60 | if oldScore == score { 61 | return 62 | } 63 | s.rem(oldScore, member) 64 | } 65 | 66 | s.memberToScore[member] = score 67 | node, ok := s.scoreToNode[score] 68 | if ok { 69 | node.members[member] = struct{}{} 70 | return 71 | } 72 | 73 | // 新插入,roll 出高度 74 | height := s.roll() 75 | for int64(len(s.head.nexts)) < height+1 { 76 | s.head.nexts = append(s.head.nexts, nil) 77 | } 78 | 79 | inserted := newSkipnode(score, height+1) 80 | inserted.members[member] = struct{}{} 81 | s.scoreToNode[score] = inserted 82 | 83 | move := s.head 84 | for i := height; i >= 0; i-- { 85 | for move.nexts[i] != nil && move.nexts[i].score < score { 86 | move = move.nexts[i] 87 | continue 88 | } 89 | 90 | inserted.nexts[i] = move.nexts[i] 91 | move.nexts[i] = inserted 92 | } 93 | } 94 | 95 | func (s *skiplist) Rem(member string) int64 { 96 | // 之前存在,需要删除 
97 | score, ok := s.memberToScore[member] 98 | if !ok { 99 | return 0 100 | } 101 | s.rem(score, member) 102 | return 1 103 | } 104 | 105 | // [score1,score2] 106 | func (s *skiplist) Range(score1, score2 int64) []string { 107 | if score2 == -1 { 108 | score2 = math.MaxInt64 109 | } 110 | 111 | if score1 > score2 { 112 | return []string{} 113 | } 114 | 115 | move := s.head 116 | for i := len(s.head.nexts) - 1; i >= 0; i-- { 117 | for move.nexts[i] != nil && move.nexts[i].score < score1 { 118 | move = move.nexts[i] 119 | } 120 | } 121 | 122 | // 来到了 level0 层,move.nexts[i] 如果存在,就是首个 >= score1 的元素 123 | if len(move.nexts) == 0 || move.nexts[0] == nil { 124 | return []string{} 125 | } 126 | 127 | res := []string{} 128 | for move.nexts[0] != nil && move.nexts[0].score >= score1 && move.nexts[0].score <= score2 { 129 | for member := range move.nexts[0].members { 130 | res = append(res, member) 131 | } 132 | move = move.nexts[0] 133 | } 134 | return res 135 | } 136 | 137 | func (s *skiplist) roll() int64 { 138 | var level int64 139 | for s.rander.Intn(2) > 0 { 140 | level++ 141 | } 142 | return level 143 | } 144 | 145 | func (s *skiplist) rem(score int64, member string) { 146 | delete(s.memberToScore, member) 147 | skipnode := s.scoreToNode[score] 148 | 149 | delete(skipnode.members, member) 150 | if len(skipnode.members) > 0 { 151 | return 152 | } 153 | 154 | delete(s.scoreToNode, score) 155 | move := s.head 156 | for i := len(s.head.nexts) - 1; i >= 0; i-- { 157 | for move.nexts[i] != nil && move.nexts[i].score < score { 158 | move = move.nexts[i] 159 | } 160 | 161 | if move.nexts[i] == nil || move.nexts[i].score > score { 162 | continue 163 | } 164 | 165 | remed := move.nexts[i] 166 | move.nexts[i] = move.nexts[i].nexts[i] 167 | remed.nexts[i] = nil 168 | } 169 | } 170 | 171 | func (s *skiplist) ToCmd() [][]byte { 172 | args := make([][]byte, 0, 2+2*len(s.memberToScore)) 173 | args = append(args, []byte(database.CmdTypeZAdd), []byte(s.key)) 174 | for member, score := 
range s.memberToScore { 175 | scoreStr := strconv.FormatInt(score, 10) 176 | args = append(args, []byte(scoreStr), []byte(member)) 177 | } 178 | return args 179 | } 180 | 181 | type skipnode struct { 182 | score int64 183 | members map[string]struct{} 184 | nexts []*skipnode 185 | } 186 | 187 | func newSkipnode(score, height int64) *skipnode { 188 | return &skipnode{ 189 | score: score, 190 | members: make(map[string]struct{}), 191 | nexts: make([]*skipnode, height), 192 | } 193 | } 194 | -------------------------------------------------------------------------------- /goredis_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bufio" 5 | "fmt" 6 | "io" 7 | "math/rand" 8 | "net" 9 | "os" 10 | "sync" 11 | "testing" 12 | "time" 13 | 14 | "github.com/spf13/cast" 15 | "github.com/stretchr/testify/assert" 16 | "github.com/xiaoxuxiansheng/goredis/app" 17 | "github.com/xiaoxuxiansheng/goredis/lib" 18 | "github.com/xiaoxuxiansheng/goredis/lib/pool" 19 | ) 20 | 21 | // goredis 质检员 22 | type QualityInspector struct { 23 | times int 24 | app *app.Application 25 | t *testing.T 26 | rander *rand.Rand 27 | } 28 | 29 | func NewQualityInspector(t *testing.T, times int) *QualityInspector { 30 | return &QualityInspector{ 31 | t: t, 32 | times: times, 33 | rander: rand.New(rand.NewSource(lib.TimeNow().UnixNano())), 34 | } 35 | } 36 | 37 | func (q *QualityInspector) prepareApp(clean bool) error { 38 | // 1 移除 aof 文件,避免读取脏数据 39 | if clean { 40 | _ = os.Remove("./appendonly.aof") 41 | } 42 | // 1 创建应用 43 | server, err := app.ConstructServer() 44 | if err != nil { 45 | return err 46 | } 47 | q.app = app.NewApplication(server, app.SetUpConfig()) 48 | // 2 异步启动应用 49 | pool.Submit(func() { 50 | if err := q.app.Run(); err != nil { 51 | q.t.Error(err) 52 | } 53 | }) 54 | return nil 55 | } 56 | 57 | func (q *QualityInspector) connApp() (*net.TCPConn, error) { 58 | <-time.After(100 * time.Millisecond) 59 | // 建立 tcp 连接 
60 | return net.DialTCP("tcp", nil, &net.TCPAddr{ 61 | IP: net.IPv4(127, 0, 0, 1), 62 | Port: 6379, 63 | }) 64 | } 65 | 66 | func (q *QualityInspector) execSet(w io.Writer) { 67 | writer := bufio.NewWriter(w) 68 | for i := 0; i < 2*q.times; i++ { 69 | k := cast.ToString(i % q.times) 70 | v := cast.ToString(i % q.times) 71 | _, _ = writer.WriteString("*3\r\n") 72 | _, _ = writer.WriteString("$3\r\n") 73 | _, _ = writer.WriteString("set\r\n") 74 | _, _ = writer.WriteString(fmt.Sprintf("$%d\r\n", len(k))) 75 | _, _ = writer.WriteString(fmt.Sprintf("%s\r\n", k)) 76 | _, _ = writer.WriteString(fmt.Sprintf("$%d\r\n", len(v))) 77 | _, _ = writer.WriteString(fmt.Sprintf("%s\r\n", v)) 78 | if err := writer.Flush(); err != nil { 79 | q.t.Error(err) 80 | } 81 | q.t.Logf("set k: %s, v: %s", k, v) 82 | } 83 | } 84 | 85 | func (q *QualityInspector) readSetResp(r io.Reader) { 86 | reader := bufio.NewReader(r) 87 | for i := 0; i < q.times; i++ { 88 | line, _, _ := reader.ReadLine() 89 | q.t.Logf("set resp: %s\n", line) 90 | } 91 | } 92 | 93 | func (q *QualityInspector) execGet(w io.Writer) { 94 | writer := bufio.NewWriter(w) 95 | for i := 0; i < q.times; i++ { 96 | k := cast.ToString(i) 97 | _, _ = writer.WriteString("*2\r\n") 98 | _, _ = writer.WriteString("$3\r\n") 99 | _, _ = writer.WriteString("get\r\n") 100 | _, _ = writer.WriteString(fmt.Sprintf("$%d\r\n", len(k))) 101 | _, _ = writer.WriteString(fmt.Sprintf("%s\r\n", k)) 102 | if err := writer.Flush(); err != nil { 103 | q.t.Error(err) 104 | } 105 | q.t.Logf("get k: %s", k) 106 | } 107 | } 108 | 109 | func (q *QualityInspector) readGetResp(r io.Reader) { 110 | reader := bufio.NewReader(r) 111 | for i := 0; i < q.times; i++ { 112 | _, _, _ = reader.ReadLine() // $n 113 | line, _, _ := reader.ReadLine() 114 | q.t.Logf("get resp: %s\n", line) 115 | assert.Equal(q.t, cast.ToString(i), string(line)) 116 | } 117 | } 118 | 119 | func Test_Goredis_Set_Get(t *testing.T) { 120 | q := NewQualityInspector(t, 100) 121 | 122 | // 1 启动 go 
redis. 保证全局唯一 123 | if err := q.prepareApp(true); err != nil { 124 | t.Error(err) 125 | return 126 | } 127 | defer q.app.Stop() 128 | 129 | // 2 连接到 go redis 130 | conn, err := q.connApp() 131 | if err != nil { 132 | t.Error(err) 133 | return 134 | } 135 | defer conn.Close() 136 | 137 | var wg sync.WaitGroup 138 | // 3 读取set结果 139 | wg.Add(1) 140 | pool.Submit(func() { 141 | defer wg.Done() 142 | q.readSetResp(conn) 143 | }) 144 | 145 | // 4 发送set指令 146 | wg.Add(1) 147 | pool.Submit(func() { 148 | defer wg.Done() 149 | q.execSet(conn) 150 | }) 151 | 152 | wg.Wait() 153 | 154 | // 5 读取get结果 155 | wg.Add(1) 156 | pool.Submit(func() { 157 | defer wg.Done() 158 | q.readGetResp(conn) 159 | }) 160 | 161 | // 6 发送get指令 162 | wg.Add(1) 163 | pool.Submit(func() { 164 | defer wg.Done() 165 | q.execGet(conn) 166 | }) 167 | wg.Wait() 168 | 169 | <-time.After(time.Second) 170 | } 171 | 172 | func Test_Goredis_Set(t *testing.T) { 173 | test_goredis_set(t) // 1 启动 goredis 2 set 数据 3 停止 goredis 174 | } 175 | 176 | func test_goredis_set(t *testing.T) { 177 | q := NewQualityInspector(t, 100) 178 | 179 | // 1 启动 go redis. 
保证全局唯一 180 | if err := q.prepareApp(true); err != nil { 181 | t.Error(err) 182 | return 183 | } 184 | defer q.app.Stop() 185 | 186 | // 2 连接到 go redis 187 | conn, err := q.connApp() 188 | if err != nil { 189 | t.Error(err) 190 | return 191 | } 192 | defer conn.Close() 193 | 194 | var wg sync.WaitGroup 195 | // 3 读取set结果 196 | wg.Add(1) 197 | pool.Submit(func() { 198 | defer wg.Done() 199 | q.readSetResp(conn) 200 | }) 201 | 202 | // 4 发送set指令 203 | wg.Add(1) 204 | pool.Submit(func() { 205 | defer wg.Done() 206 | q.execSet(conn) 207 | }) 208 | 209 | wg.Wait() 210 | <-time.After(2 * time.Second) 211 | } 212 | 213 | func Test_Aof_Get(t *testing.T) { 214 | test_goredis_aof_get(t) // 1 启动 goredis(通过 aof 恢复数据)2 get 数据 3 停止 goredis 215 | } 216 | 217 | func test_goredis_aof_get(t *testing.T) { 218 | q := NewQualityInspector(t, 100) 219 | 220 | <-time.After(time.Second) 221 | // 1 启动 go redis. 不删除 aof 文件 222 | if err := q.prepareApp(false); err != nil { 223 | t.Error(err) 224 | return 225 | } 226 | defer q.app.Stop() 227 | 228 | // 2 连接到 go redis 229 | <-time.After(time.Second) 230 | conn, err := q.connApp() 231 | if err != nil { 232 | t.Error(err) 233 | return 234 | } 235 | defer conn.Close() 236 | 237 | var wg sync.WaitGroup 238 | // 5 读取get结果 239 | wg.Add(1) 240 | pool.Submit(func() { 241 | defer wg.Done() 242 | q.readGetResp(conn) 243 | }) 244 | 245 | // 6 发送get指令 246 | wg.Add(1) 247 | pool.Submit(func() { 248 | defer wg.Done() 249 | q.execGet(conn) 250 | }) 251 | wg.Wait() 252 | } 253 | -------------------------------------------------------------------------------- /datastore/sorted_set_test.go: -------------------------------------------------------------------------------- 1 | package datastore 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "sort" 7 | "strings" 8 | "testing" 9 | 10 | "github.com/spf13/cast" 11 | "github.com/stretchr/testify/assert" 12 | "github.com/xiaoxuxiansheng/goredis/database" 13 | "github.com/xiaoxuxiansheng/goredis/lib" 14 | ) 15 | 16 | func 
Test_skiplist_add_rem_range(t *testing.T) { 17 | skiplist := newSkiplist("") 18 | // 添加 1000 条指令 19 | for i := 0; i < 1000; i++ { 20 | skiplist.Add(int64(i), fmt.Sprintf("%d_0", i)) 21 | skiplist.Add(int64(i), fmt.Sprintf("%d_1", i)) 22 | } 23 | 24 | // 随机移除 1000 个 member 25 | rander := rand.New(rand.NewSource(lib.TimeNow().UnixNano())) 26 | remSet := make(map[string]struct{}, 1000) 27 | for i := 0; i < 1000; i++ { 28 | score := rander.Intn(1000) 29 | index := rander.Intn(2) 30 | member := fmt.Sprintf("%d_%d", score, index) 31 | remSet[member] = struct{}{} 32 | skiplist.Rem(member) 33 | } 34 | 35 | t.Run("single_score", func(t *testing.T) { 36 | for i := 0; i < 100; i++ { 37 | score := int64(rander.Intn(1000)) 38 | member := skiplist.Range(score, score) 39 | sort.Slice(member, func(i, j int) bool { 40 | return member[i] < member[j] 41 | }) 42 | expected := make([]string, 0, 2) 43 | member1 := fmt.Sprintf("%d_0", score) 44 | member2 := fmt.Sprintf("%d_1", score) 45 | if _, ok := remSet[member1]; !ok { 46 | expected = append(expected, member1) 47 | } 48 | if _, ok := remSet[member2]; !ok { 49 | expected = append(expected, member2) 50 | } 51 | 52 | assert.Equal(t, expected, member) 53 | } 54 | }) 55 | 56 | t.Run("normal_score_range", func(t *testing.T) { 57 | for i := 0; i < 100; i++ { 58 | leftScore := int64(rander.Intn(501)) 59 | rightScore := leftScore + int64(rander.Intn(500)) 60 | member := skiplist.Range(leftScore, rightScore) 61 | sort.Slice(member, func(i, j int) bool { 62 | splitted1 := strings.Split(member[i], "_") 63 | splitted2 := strings.Split(member[j], "_") 64 | if splitted1[0] == splitted2[0] { 65 | return cast.ToInt(splitted1[1]) < cast.ToInt(splitted2[1]) 66 | } 67 | return cast.ToInt(splitted1[0]) < cast.ToInt(splitted2[0]) 68 | }) 69 | 70 | expected := make([]string, 0, 2*(rightScore-leftScore+1)) 71 | for j := leftScore; j <= rightScore; j++ { 72 | member1 := fmt.Sprintf("%d_0", j) 73 | member2 := fmt.Sprintf("%d_1", j) 74 | if _, ok := 
remSet[member1]; !ok { 75 | expected = append(expected, member1) 76 | } 77 | if _, ok := remSet[member2]; !ok { 78 | expected = append(expected, member2) 79 | } 80 | } 81 | assert.Equal(t, expected, member) 82 | } 83 | }) 84 | 85 | t.Run("with_maximum_right_range", func(t *testing.T) { 86 | for i := 0; i < 100; i++ { 87 | leftScore := int64(rander.Intn(1000)) 88 | rightScore := int64(-1) 89 | member := skiplist.Range(leftScore, rightScore) 90 | sort.Slice(member, func(i, j int) bool { 91 | splitted1 := strings.Split(member[i], "_") 92 | splitted2 := strings.Split(member[j], "_") 93 | if splitted1[0] == splitted2[0] { 94 | return cast.ToInt(splitted1[1]) < cast.ToInt(splitted2[1]) 95 | } 96 | return cast.ToInt(splitted1[0]) < cast.ToInt(splitted2[0]) 97 | }) 98 | 99 | expected := make([]string, 0, 2*(1000-leftScore)) 100 | for j := leftScore; j < 1000; j++ { 101 | member1 := fmt.Sprintf("%d_0", j) 102 | member2 := fmt.Sprintf("%d_1", j) 103 | if _, ok := remSet[member1]; !ok { 104 | expected = append(expected, member1) 105 | } 106 | if _, ok := remSet[member2]; !ok { 107 | expected = append(expected, member2) 108 | } 109 | } 110 | assert.Equal(t, expected, member) 111 | } 112 | }) 113 | } 114 | 115 | func Test_skiplist_upsert_member_with_dif_score(t *testing.T) { 116 | skiplist := newSkiplist("") 117 | rander := rand.New(rand.NewSource(lib.TimeNow().UnixNano())) 118 | scoreToMembers := make(map[int64][]string) 119 | memberSet := make(map[string]struct{}) 120 | for i := 0; i < 1000; i++ { 121 | score1 := int64(rander.Intn(1000)) 122 | member := cast.ToString(score1) 123 | if _, ok := memberSet[member]; ok { 124 | continue 125 | } 126 | memberSet[member] = struct{}{} 127 | skiplist.Add(score1, member) 128 | score2 := int64(rander.Intn(1000)) 129 | skiplist.Add(score2, member) 130 | scoreToMembers[score2] = append(scoreToMembers[score2], member) 131 | } 132 | 133 | t.Run("score_to_members", func(t *testing.T) { 134 | for score, members := range scoreToMembers { 135 | 
sort.Slice(members, func(i, j int) bool { 136 | return cast.ToInt(members[i]) < cast.ToInt(members[j]) 137 | }) 138 | 139 | actualMembers := skiplist.Range(score, score) 140 | sort.Slice(actualMembers, func(i, j int) bool { 141 | return cast.ToInt(actualMembers[i]) < cast.ToInt(actualMembers[j]) 142 | }) 143 | 144 | assert.Equal(t, members, actualMembers) 145 | 146 | // member 对应的前一个 score 不能查询得到 member 147 | for _, member := range members { 148 | oldScore := cast.ToInt64(member) 149 | if oldScore == score { 150 | continue 151 | } 152 | for _, gotMember := range skiplist.Range(oldScore, oldScore) { 153 | if gotMember == member { 154 | t.Errorf("old score: %d, members: %s", oldScore, gotMember) 155 | } 156 | } 157 | } 158 | } 159 | }) 160 | } 161 | 162 | func Test_skiplist_to_cmd(t *testing.T) { 163 | skiplist := newSkiplist("") 164 | 165 | rander := rand.New(rand.NewSource(lib.TimeNow().UnixNano())) 166 | memberToScore := make(map[int]int, 1000) 167 | // 插入1000条数据 168 | for i := 0; i < 1000; i++ { 169 | score := rander.Intn(1000) 170 | member := rander.Intn(1000) 171 | skiplist.Add(int64(score), cast.ToString(member)) 172 | memberToScore[member] = score 173 | } 174 | 175 | cmd := skiplist.ToCmd() 176 | t.Run("length", func(t *testing.T) { 177 | assert.Equal(t, 2*len(memberToScore)+2, len(cmd)) 178 | }) 179 | t.Run("command", func(t *testing.T) { 180 | assert.Equal(t, database.CmdTypeZAdd, database.CmdType(cmd[0])) 181 | }) 182 | t.Run("key", func(t *testing.T) { 183 | assert.Equal(t, "", string(cmd[1])) 184 | }) 185 | 186 | type scoreToMember struct { 187 | score, member int 188 | } 189 | actual := make([]scoreToMember, 0, 1000) 190 | for i := 2; i < len(cmd); i += 2 { 191 | actual = append(actual, scoreToMember{ 192 | score: cast.ToInt(string(cmd[i])), 193 | member: cast.ToInt(string(cmd[i+1])), 194 | }) 195 | } 196 | 197 | sort.Slice(actual, func(i, j int) bool { 198 | if actual[i].score == actual[j].score { 199 | return actual[i].member < actual[j].member 200 | 
} 201 | return actual[i].score < actual[j].score 202 | }) 203 | 204 | expect := make([]scoreToMember, 0, 2*len(memberToScore)) 205 | for member, score := range memberToScore { 206 | expect = append(expect, scoreToMember{ 207 | score: score, 208 | member: member, 209 | }) 210 | } 211 | sort.Slice(expect, func(i, j int) bool { 212 | if expect[i].score == expect[j].score { 213 | return expect[i].member < expect[j].member 214 | } 215 | return expect[i].score < expect[j].score 216 | }) 217 | 218 | t.Run("member", func(t *testing.T) { 219 | assert.Equal(t, expect, actual) 220 | }) 221 | } 222 | -------------------------------------------------------------------------------- /datastore/kv_store.go: -------------------------------------------------------------------------------- 1 | package datastore 2 | 3 | import ( 4 | "context" 5 | "strconv" 6 | "strings" 7 | "time" 8 | 9 | "github.com/xiaoxuxiansheng/goredis/database" 10 | "github.com/xiaoxuxiansheng/goredis/handler" 11 | "github.com/xiaoxuxiansheng/goredis/lib" 12 | ) 13 | 14 | type KVStore struct { 15 | data map[string]interface{} 16 | expiredAt map[string]time.Time 17 | 18 | expireTimeWheel SortedSet 19 | 20 | persister handler.Persister 21 | } 22 | 23 | func NewKVStore(persister handler.Persister) database.DataStore { 24 | return &KVStore{ 25 | data: make(map[string]interface{}), 26 | expiredAt: make(map[string]time.Time), 27 | expireTimeWheel: newSkiplist("expireTimeWheel"), 28 | persister: persister, 29 | } 30 | } 31 | 32 | // expire 33 | func (k *KVStore) Expire(cmd *database.Command) handler.Reply { 34 | args := cmd.Args() 35 | key := string(args[0]) 36 | ttl, err := strconv.ParseInt(string(args[1]), 10, 64) 37 | if err != nil { 38 | return handler.NewSyntaxErrReply() 39 | } 40 | if ttl <= 0 { 41 | return handler.NewErrReply("ERR invalid expire time") 42 | } 43 | 44 | expireAt := lib.TimeNow().Add(time.Duration(ttl) * time.Second) 45 | _cmd := [][]byte{[]byte(database.CmdTypeExpireAt), []byte(key), 
[]byte(lib.TimeSecondFormat(expireAt))} 46 | return k.expireAt(cmd.Ctx(), _cmd, key, expireAt) 47 | } 48 | 49 | func (k *KVStore) ExpireAt(cmd *database.Command) handler.Reply { 50 | args := cmd.Args() 51 | key := string(args[0]) 52 | expiredAt, err := lib.ParseTimeSecondFormat(string((args[1]))) 53 | if err != nil { 54 | return handler.NewSyntaxErrReply() 55 | } 56 | if expiredAt.Before(lib.TimeNow()) { 57 | return handler.NewErrReply("ERR invalid expire time") 58 | } 59 | 60 | return k.expireAt(cmd.Ctx(), cmd.Cmd(), key, expiredAt) 61 | } 62 | 63 | func (k *KVStore) expireAt(ctx context.Context, cmd [][]byte, key string, expireAt time.Time) handler.Reply { 64 | k.expire(key, expireAt) 65 | k.persister.PersistCmd(ctx, cmd) // 持久化 66 | return handler.NewOKReply() 67 | } 68 | 69 | // string 70 | func (k *KVStore) Get(cmd *database.Command) handler.Reply { 71 | args := cmd.Args() 72 | key := string(args[0]) 73 | v, err := k.getAsString(key) 74 | if err != nil { 75 | return handler.NewErrReply(err.Error()) 76 | } 77 | if v == nil { 78 | return handler.NewNillReply() 79 | } 80 | return handler.NewBulkReply(v.Bytes()) 81 | } 82 | 83 | func (k *KVStore) MGet(cmd *database.Command) handler.Reply { 84 | args := cmd.Args() 85 | res := make([][]byte, 0, len(args)) 86 | for _, arg := range args { 87 | v, err := k.getAsString(string(arg)) 88 | if err != nil { 89 | return handler.NewErrReply(err.Error()) 90 | } 91 | if v == nil { 92 | res = append(res, []byte("(nil)")) 93 | continue 94 | } 95 | res = append(res, v.Bytes()) 96 | } 97 | 98 | return handler.NewMultiBulkReply(res) 99 | } 100 | 101 | func (k *KVStore) Set(cmd *database.Command) handler.Reply { 102 | args := cmd.Args() 103 | key := string(args[0]) 104 | value := string(args[1]) 105 | 106 | // 支持 NX EX 107 | var ( 108 | insertStrategy bool 109 | ttlStrategy bool 110 | ttlSeconds int64 111 | ttlIndex = -1 112 | ) 113 | 114 | for i := 2; i < len(args); i++ { 115 | flag := strings.ToLower(string(args[i])) 116 | switch 
flag { 117 | case "nx": 118 | insertStrategy = true 119 | case "ex": 120 | // 重复的 ex 指令 121 | if ttlStrategy { 122 | return handler.NewSyntaxErrReply() 123 | } 124 | if i == len(args)-1 { 125 | return handler.NewSyntaxErrReply() 126 | } 127 | ttl, err := strconv.ParseInt(string(args[i+1]), 10, 64) 128 | if err != nil { 129 | return handler.NewSyntaxErrReply() 130 | } 131 | if ttl <= 0 { 132 | return handler.NewErrReply("ERR invalid expire time") 133 | } 134 | 135 | ttlStrategy = true 136 | ttlSeconds = ttl 137 | ttlIndex = i 138 | i++ 139 | default: 140 | return handler.NewSyntaxErrReply() 141 | } 142 | } 143 | 144 | // 将 args 剔除 ex 部分,进行持久化 145 | if ttlIndex != -1 { 146 | args = append(args[:ttlIndex], args[ttlIndex+2:]...) 147 | } 148 | 149 | // 设置 150 | affected := k.put(key, value, insertStrategy) 151 | if affected > 0 && ttlStrategy { 152 | expireAt := lib.TimeNow().Add(time.Duration(ttlSeconds) * time.Second) 153 | _cmd := [][]byte{[]byte(database.CmdTypeExpireAt), []byte(key), []byte(lib.TimeSecondFormat(expireAt))} 154 | _ = k.expireAt(cmd.Ctx(), _cmd, key, expireAt) // 其中会完成 ex 信息的持久化 155 | } 156 | 157 | // 过期时间处理 158 | if affected > 0 { 159 | k.persister.PersistCmd(cmd.Ctx(), append([][]byte{[]byte(database.CmdTypeSet)}, args...)) 160 | return handler.NewIntReply(affected) 161 | } 162 | 163 | return handler.NewNillReply() 164 | } 165 | 166 | func (k *KVStore) MSet(cmd *database.Command) handler.Reply { 167 | args := cmd.Args() 168 | if len(args)&1 == 1 { 169 | return handler.NewSyntaxErrReply() 170 | } 171 | 172 | for i := 0; i < len(args); i += 2 { 173 | _ = k.put(string(args[i]), string(args[i+1]), false) 174 | } 175 | 176 | k.persister.PersistCmd(cmd.Ctx(), cmd.Cmd()) 177 | return handler.NewIntReply(int64(len(args) >> 1)) 178 | } 179 | 180 | // list 181 | func (k *KVStore) LPush(cmd *database.Command) handler.Reply { 182 | args := cmd.Args() 183 | key := string(args[0]) 184 | list, err := k.getAsList(key) 185 | if err != nil { 186 | return 
handler.NewErrReply(err.Error()) 187 | } 188 | 189 | if list == nil { 190 | list = newListEntity(key) 191 | k.putAsList(key, list) 192 | } 193 | 194 | for i := 1; i < len(args); i++ { 195 | list.LPush(args[i]) 196 | } 197 | 198 | k.persister.PersistCmd(cmd.Ctx(), cmd.Cmd()) 199 | return handler.NewIntReply(list.Len()) 200 | } 201 | 202 | func (k *KVStore) LPop(cmd *database.Command) handler.Reply { 203 | args := cmd.Args() 204 | key := string(args[0]) 205 | var cnt int64 206 | if len(args) > 1 { 207 | rawCnt, err := strconv.ParseInt(string(args[1]), 10, 64) 208 | if err != nil { 209 | return handler.NewSyntaxErrReply() 210 | } 211 | if rawCnt < 1 { 212 | return handler.NewSyntaxErrReply() 213 | } 214 | cnt = rawCnt 215 | } 216 | 217 | list, err := k.getAsList(key) 218 | if err != nil { 219 | return handler.NewErrReply(err.Error()) 220 | } 221 | 222 | if list == nil { 223 | return handler.NewNillReply() 224 | } 225 | 226 | if cnt == 0 { 227 | cnt = 1 228 | } 229 | 230 | poped := list.LPop(cnt) 231 | if poped == nil { 232 | return handler.NewNillReply() 233 | } 234 | 235 | k.persister.PersistCmd(cmd.Ctx(), cmd.Cmd()) // 持久化 236 | 237 | if len(poped) == 1 { 238 | return handler.NewBulkReply(poped[0]) 239 | } 240 | 241 | return handler.NewMultiBulkReply(poped) 242 | } 243 | 244 | func (k *KVStore) RPush(cmd *database.Command) handler.Reply { 245 | args := cmd.Args() 246 | key := string(args[0]) 247 | list, err := k.getAsList(key) 248 | if err != nil { 249 | return handler.NewErrReply(err.Error()) 250 | } 251 | 252 | if list == nil { 253 | list = newListEntity(key, args[1:]...) 
254 | k.putAsList(key, list) 255 | return handler.NewIntReply(list.Len()) 256 | } 257 | 258 | for i := 1; i < len(args); i++ { 259 | list.RPush(args[i]) 260 | } 261 | 262 | k.persister.PersistCmd(cmd.Ctx(), cmd.Cmd()) // 持久化 263 | return handler.NewIntReply(list.Len()) 264 | } 265 | 266 | func (k *KVStore) RPop(cmd *database.Command) handler.Reply { 267 | args := cmd.Args() 268 | key := string(args[0]) 269 | var cnt int64 270 | if len(args) > 1 { 271 | rawCnt, err := strconv.ParseInt(string(args[1]), 10, 64) 272 | if err != nil { 273 | return handler.NewSyntaxErrReply() 274 | } 275 | if rawCnt < 1 { 276 | return handler.NewSyntaxErrReply() 277 | } 278 | cnt = rawCnt 279 | } 280 | 281 | list, err := k.getAsList(key) 282 | if err != nil { 283 | return handler.NewErrReply(err.Error()) 284 | } 285 | 286 | if list == nil { 287 | return handler.NewNillReply() 288 | } 289 | 290 | if cnt == 0 { 291 | cnt = 1 292 | } 293 | 294 | poped := list.RPop(cnt) 295 | if poped == nil { 296 | return handler.NewNillReply() 297 | } 298 | 299 | k.persister.PersistCmd(cmd.Ctx(), cmd.Cmd()) // 持久化 300 | if len(poped) == 1 { 301 | return handler.NewBulkReply(poped[0]) 302 | } 303 | 304 | return handler.NewMultiBulkReply(poped) 305 | } 306 | 307 | func (k *KVStore) LRange(cmd *database.Command) handler.Reply { 308 | args := cmd.Args() 309 | if len(args) != 3 { 310 | return handler.NewSyntaxErrReply() 311 | } 312 | 313 | key := string(args[0]) 314 | start, err := strconv.ParseInt(string(args[1]), 10, 64) 315 | if err != nil { 316 | return handler.NewSyntaxErrReply() 317 | } 318 | 319 | stop, err := strconv.ParseInt(string(args[2]), 10, 64) 320 | if err != nil { 321 | return handler.NewSyntaxErrReply() 322 | } 323 | 324 | list, err := k.getAsList(key) 325 | if err != nil { 326 | return handler.NewErrReply(err.Error()) 327 | } 328 | 329 | if list == nil { 330 | return handler.NewNillReply() 331 | } 332 | 333 | if got := list.Range(start, stop); got != nil { 334 | return 
handler.NewMultiBulkReply(got) 335 | } 336 | 337 | return handler.NewNillReply() 338 | } 339 | 340 | // set 341 | func (k *KVStore) SAdd(cmd *database.Command) handler.Reply { 342 | args := cmd.Args() 343 | key := string(args[0]) 344 | set, err := k.getAsSet(key) 345 | if err != nil { 346 | return handler.NewErrReply(err.Error()) 347 | } 348 | 349 | if set == nil { 350 | set = newSetEntity(key) 351 | k.putAsSet(key, set) 352 | } 353 | 354 | var added int64 355 | for _, arg := range args[1:] { 356 | added += set.Add(string(arg)) 357 | } 358 | 359 | k.persister.PersistCmd(cmd.Ctx(), cmd.Cmd()) // 持久化 360 | return handler.NewIntReply(added) 361 | } 362 | 363 | func (k *KVStore) SIsMember(cmd *database.Command) handler.Reply { 364 | args := cmd.Args() 365 | if len(args) != 2 { 366 | return handler.NewSyntaxErrReply() 367 | } 368 | 369 | key := string(args[0]) 370 | set, err := k.getAsSet(key) 371 | if err != nil { 372 | return handler.NewErrReply(err.Error()) 373 | } 374 | 375 | if set == nil { 376 | return handler.NewIntReply(0) 377 | } 378 | 379 | return handler.NewIntReply(set.Exist(string(args[1]))) 380 | } 381 | 382 | func (k *KVStore) SRem(cmd *database.Command) handler.Reply { 383 | args := cmd.Args() 384 | key := string(args[0]) 385 | set, err := k.getAsSet(key) 386 | if err != nil { 387 | return handler.NewErrReply(err.Error()) 388 | } 389 | 390 | if set == nil { 391 | return handler.NewIntReply(0) 392 | } 393 | 394 | var remed int64 395 | for _, arg := range args[1:] { 396 | remed += set.Rem(string(arg)) 397 | } 398 | 399 | if remed > 0 { 400 | k.persister.PersistCmd(cmd.Ctx(), cmd.Cmd()) // 持久化 401 | } 402 | return handler.NewIntReply(remed) 403 | } 404 | 405 | // hash 406 | func (k *KVStore) HSet(cmd *database.Command) handler.Reply { 407 | args := cmd.Args() 408 | if len(args)&1 != 1 { 409 | return handler.NewSyntaxErrReply() 410 | } 411 | 412 | key := string(args[0]) 413 | hmap, err := k.getAsHashMap(key) 414 | if err != nil { 415 | return 
handler.NewErrReply(err.Error()) 416 | } 417 | 418 | if hmap == nil { 419 | hmap = newHashMapEntity(key) 420 | k.putAsHashMap(key, hmap) 421 | } 422 | 423 | for i := 0; i < len(args)-1; i += 2 { 424 | hkey := string(args[i+1]) 425 | hvalue := args[i+2] 426 | hmap.Put(hkey, hvalue) 427 | } 428 | 429 | k.persister.PersistCmd(cmd.Ctx(), cmd.Cmd()) // 持久化 430 | return handler.NewIntReply(int64((len(args) - 1) >> 1)) 431 | } 432 | 433 | func (k *KVStore) HGet(cmd *database.Command) handler.Reply { 434 | args := cmd.Args() 435 | key := string(args[0]) 436 | hmap, err := k.getAsHashMap(key) 437 | if err != nil { 438 | return handler.NewErrReply(err.Error()) 439 | } 440 | 441 | if hmap == nil { 442 | return handler.NewNillReply() 443 | } 444 | 445 | if v := hmap.Get(string(args[1])); v != nil { 446 | return handler.NewBulkReply(v) 447 | } 448 | 449 | return handler.NewNillReply() 450 | } 451 | 452 | func (k *KVStore) HDel(cmd *database.Command) handler.Reply { 453 | args := cmd.Args() 454 | key := string(args[0]) 455 | hmap, err := k.getAsHashMap(key) 456 | if err != nil { 457 | return handler.NewErrReply(err.Error()) 458 | } 459 | 460 | if hmap == nil { 461 | return handler.NewIntReply(0) 462 | } 463 | 464 | var remed int64 465 | for _, arg := range args[1:] { 466 | remed += hmap.Del(string(arg)) 467 | } 468 | 469 | if remed > 0 { 470 | k.persister.PersistCmd(cmd.Ctx(), cmd.Cmd()) // 持久化 471 | } 472 | return handler.NewIntReply(remed) 473 | } 474 | 475 | // sorted set 476 | func (k *KVStore) ZAdd(cmd *database.Command) handler.Reply { 477 | args := cmd.Args() 478 | if len(args)&1 != 1 { 479 | return handler.NewSyntaxErrReply() 480 | } 481 | 482 | key := string(args[0]) 483 | var ( 484 | scores = make([]int64, 0, (len(args)-1)>>1) 485 | members = make([]string, 0, (len(args)-1)>>1) 486 | ) 487 | 488 | for i := 0; i < len(args)-1; i += 2 { 489 | score, err := strconv.ParseInt(string(args[i+1]), 10, 64) 490 | if err != nil { 491 | return handler.NewSyntaxErrReply() 492 | } 
493 | 494 | scores = append(scores, score) 495 | members = append(members, string(args[i+2])) 496 | } 497 | 498 | zset, err := k.getAsSortedSet(key) 499 | if err != nil { 500 | return handler.NewErrReply(err.Error()) 501 | } 502 | 503 | if zset == nil { 504 | zset = newSkiplist(key) 505 | k.putAsSortedSet(key, zset) 506 | } 507 | 508 | for i := 0; i < len(scores); i++ { 509 | zset.Add(scores[i], members[i]) 510 | } 511 | 512 | k.persister.PersistCmd(cmd.Ctx(), cmd.Cmd()) // 持久化 513 | return handler.NewIntReply(int64(len(scores))) 514 | } 515 | 516 | func (k *KVStore) ZRangeByScore(cmd *database.Command) handler.Reply { 517 | args := cmd.Args() 518 | if len(args) < 3 { 519 | return handler.NewSyntaxErrReply() 520 | } 521 | 522 | key := string(args[0]) 523 | score1, err := strconv.ParseInt(string(args[1]), 10, 64) 524 | if err != nil { 525 | return handler.NewSyntaxErrReply() 526 | } 527 | score2, err := strconv.ParseInt(string(args[2]), 10, 64) 528 | if err != nil { 529 | return handler.NewSyntaxErrReply() 530 | } 531 | 532 | zset, err := k.getAsSortedSet(key) 533 | if err != nil { 534 | return handler.NewErrReply(err.Error()) 535 | } 536 | 537 | if zset == nil { 538 | return handler.NewNillReply() 539 | } 540 | 541 | rawRes := zset.Range(score1, score2) 542 | if len(rawRes) == 0 { 543 | return handler.NewNillReply() 544 | } 545 | 546 | res := make([][]byte, 0, len(rawRes)) 547 | for _, item := range rawRes { 548 | res = append(res, []byte(item)) 549 | } 550 | 551 | return handler.NewMultiBulkReply(res) 552 | } 553 | 554 | func (k *KVStore) ZRem(cmd *database.Command) handler.Reply { 555 | args := cmd.Args() 556 | key := string(args[0]) 557 | zset, err := k.getAsSortedSet(key) 558 | if err != nil { 559 | return handler.NewErrReply(err.Error()) 560 | } 561 | 562 | if zset == nil { 563 | return handler.NewIntReply(0) 564 | } 565 | 566 | var remed int64 567 | for _, arg := range args { 568 | remed += zset.Rem(string(arg)) 569 | } 570 | 571 | if remed > 0 { 572 | 
k.persister.PersistCmd(cmd.Ctx(), cmd.Cmd()) // 持久化 573 | } 574 | return handler.NewIntReply(remed) 575 | } 576 | --------------------------------------------------------------------------------