├── .gitignore ├── start.sh ├── img ├── frame.png ├── logic.png └── database.png ├── redis.conf ├── main.go ├── lib ├── time.go └── pool │ └── pool.go ├── app ├── app.go ├── factory.go └── config.go ├── go.mod ├── datastore ├── persist.go ├── expire.go ├── string.go ├── hash.go ├── set.go ├── list.go ├── set_test.go ├── list_test.go ├── hash_test.go ├── sorted_set.go ├── sorted_set_test.go └── kv_store.go ├── handler ├── struct.go ├── persister.go ├── handler.go └── reply.go ├── database ├── trigger.go ├── executor.go └── struct.go ├── README.md ├── persist ├── persist.go ├── aof_rewrite.go └── aof.go ├── go.sum ├── log └── log.go ├── server └── server.go ├── protocol └── parser.go └── goredis_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | *.log 2 | *.output 3 | *.aof -------------------------------------------------------------------------------- /start.sh: -------------------------------------------------------------------------------- 1 | nohup go run ./main.go >./nohup.output 2>&1 & -------------------------------------------------------------------------------- /img/frame.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaoxuxiansheng/goredis/HEAD/img/frame.png -------------------------------------------------------------------------------- /img/logic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaoxuxiansheng/goredis/HEAD/img/logic.png -------------------------------------------------------------------------------- /img/database.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaoxuxiansheng/goredis/HEAD/img/database.png -------------------------------------------------------------------------------- /redis.conf: 
-------------------------------------------------------------------------------- 1 | # ip 地址 2 | bind 0.0.0.0 3 | # 端口 4 | port 6379 5 | 6 | # 是否启用 aof 7 | appendonly yes 8 | # aof 文件名称 9 | appendfilename appendonly.aof 10 | # aof 级别. always | everysec | no 11 | appendfsync everysec 12 | # 每执行多少次 aof 操作后,进行一次重写 13 | auto-aof-rewrite-after-cmds 1000 14 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "github.com/xiaoxuxiansheng/goredis/app" 4 | 5 | func main() { 6 | server, err := app.ConstructServer() 7 | if err != nil { 8 | panic(err) 9 | } 10 | 11 | app := app.NewApplication(server, app.SetUpConfig()) 12 | defer app.Stop() 13 | 14 | if err := app.Run(); err != nil { 15 | panic(err) 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /lib/time.go: -------------------------------------------------------------------------------- 1 | package lib 2 | 3 | import "time" 4 | 5 | const ( 6 | YYYY_MM_DD_HH_MM_SS = "2006-01-02 15:04:05" 7 | ) 8 | 9 | func TimeNow() time.Time { 10 | return time.Now() 11 | } 12 | 13 | func ParseTimeSecondFormat(timeStr string) (time.Time, error) { 14 | return time.ParseInLocation(YYYY_MM_DD_HH_MM_SS, timeStr, time.Local) 15 | } 16 | 17 | func TimeSecondFormat(t time.Time) string { 18 | return t.Format(YYYY_MM_DD_HH_MM_SS) 19 | } 20 | -------------------------------------------------------------------------------- /app/app.go: -------------------------------------------------------------------------------- 1 | package app 2 | 3 | import "github.com/xiaoxuxiansheng/goredis/server" 4 | 5 | type Application struct { 6 | server *server.Server 7 | conf *Config 8 | } 9 | 10 | func NewApplication(server *server.Server, conf *Config) *Application { 11 | return &Application{ 12 | server: server, 13 | conf: conf, 14 | } 15 | } 16 | 17 | func (a 
*Application) Run() error { 18 | return a.server.Serve(a.conf.Address()) 19 | } 20 | 21 | func (a *Application) Stop() { 22 | a.server.Stop() 23 | } 24 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/xiaoxuxiansheng/goredis 2 | 3 | go 1.19 4 | 5 | require ( 6 | go.uber.org/dig v1.17.1 7 | go.uber.org/zap v1.27.0 8 | gopkg.in/natefinch/lumberjack.v2 v2.2.1 9 | ) 10 | 11 | require ( 12 | github.com/davecgh/go-spew v1.1.1 // indirect 13 | github.com/panjf2000/ants v1.3.0 // indirect 14 | github.com/pmezard/go-difflib v1.0.0 // indirect 15 | github.com/spf13/cast v1.6.0 // indirect 16 | github.com/stretchr/testify v1.9.0 // indirect 17 | go.uber.org/multierr v1.10.0 // indirect 18 | gopkg.in/yaml.v3 v3.0.1 // indirect 19 | ) 20 | -------------------------------------------------------------------------------- /datastore/persist.go: -------------------------------------------------------------------------------- 1 | package datastore 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/xiaoxuxiansheng/goredis/database" 7 | "github.com/xiaoxuxiansheng/goredis/lib" 8 | ) 9 | 10 | func (k *KVStore) ForEach(f func(key string, adapter database.CmdAdapter, expireAt *time.Time)) { 11 | for key, data := range k.data { 12 | expiredAt, ok := k.expiredAt[key] 13 | if ok && expiredAt.Before(lib.TimeNow()) { 14 | continue 15 | } 16 | _adapter, _ := data.(database.CmdAdapter) 17 | if ok { 18 | f(key, _adapter, &expiredAt) 19 | } else { 20 | f(key, _adapter, nil) 21 | } 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /lib/pool/pool.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import ( 4 | "runtime/debug" 5 | "strings" 6 | 7 | "github.com/xiaoxuxiansheng/goredis/log" 8 | 9 | "github.com/panjf2000/ants" 10 | ) 11 | 12 | var pool 
*ants.Pool 13 | 14 | func init() { 15 | _pool, err := ants.NewPool(50000, ants.WithPanicHandler(func(i interface{}) { 16 | stackInfo := strings.Replace(string(debug.Stack()), "\n", "", -1) 17 | log.GetDefaultLogger().Errorf("recover info: %v, stack info: %s", i, stackInfo) 18 | })) 19 | if err != nil { 20 | panic(err) 21 | } 22 | pool = _pool 23 | } 24 | 25 | func Submit(task func()) { 26 | pool.Submit(task) 27 | } 28 | -------------------------------------------------------------------------------- /handler/struct.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "strings" 7 | ) 8 | 9 | var UnknownErrReplyBytes = []byte("-ERR unknown\r\n") 10 | 11 | type Reply interface { 12 | ToBytes() []byte 13 | } 14 | 15 | type MultiReply interface { 16 | Reply 17 | Args() [][]byte 18 | } 19 | 20 | type Droplet struct { 21 | Reply Reply 22 | Err error 23 | } 24 | 25 | func (d *Droplet) Terminated() bool { 26 | if d.Err == io.EOF || d.Err == io.ErrUnexpectedEOF { 27 | return true 28 | } 29 | return d.Err != nil && strings.Contains(d.Err.Error(), "use of closed network connection") 30 | } 31 | 32 | type DB interface { 33 | Do(ctx context.Context, cmdLine [][]byte) Reply 34 | Close() 35 | } 36 | 37 | // 协议解析器 38 | type Parser interface { 39 | ParseStream(reader io.Reader) <-chan *Droplet 40 | } 41 | -------------------------------------------------------------------------------- /handler/persister.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "context" 5 | "io" 6 | ) 7 | 8 | var loadingPersisterPattern int 9 | var ctxKeyLoadingPersisterPattern = &loadingPersisterPattern 10 | 11 | func SetLoadingPattern(ctx context.Context) context.Context { 12 | return context.WithValue(ctx, ctxKeyLoadingPersisterPattern, true) 13 | } 14 | 15 | func IsLoadingPattern(ctx context.Context) bool { 16 | is, _ := 
ctx.Value(ctxKeyLoadingPersisterPattern).(bool) 17 | return is 18 | } 19 | 20 | type Persister interface { 21 | Reloader() (io.ReadCloser, error) 22 | PersistCmd(ctx context.Context, cmd [][]byte) 23 | Close() 24 | } 25 | 26 | type fakeReadWriter struct { 27 | io.Reader 28 | } 29 | 30 | func newFakeReaderWriter(reader io.Reader) io.ReadWriter { 31 | return &fakeReadWriter{ 32 | Reader: reader, 33 | } 34 | } 35 | 36 | func (f *fakeReadWriter) Write(p []byte) (n int, err error) { 37 | // log ... 38 | return 0, nil 39 | } 40 | -------------------------------------------------------------------------------- /datastore/expire.go: -------------------------------------------------------------------------------- 1 | package datastore 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/xiaoxuxiansheng/goredis/lib" 7 | ) 8 | 9 | func (k *KVStore) GC() { 10 | // 找出当前所有已过期的 key,批量回收 11 | nowUnix := lib.TimeNow().Unix() 12 | for _, expiredKey := range k.expireTimeWheel.Range(0, nowUnix) { 13 | k.expireProcess(expiredKey) 14 | } 15 | } 16 | 17 | func (k *KVStore) ExpirePreprocess(key string) { 18 | expiredAt, ok := k.expiredAt[key] 19 | if !ok { 20 | return 21 | } 22 | 23 | if expiredAt.After(lib.TimeNow()) { 24 | return 25 | } 26 | 27 | k.expireProcess(key) 28 | } 29 | 30 | func (k *KVStore) expireProcess(key string) { 31 | delete(k.expiredAt, key) 32 | delete(k.data, key) 33 | k.expireTimeWheel.Rem(key) 34 | } 35 | 36 | func (k *KVStore) expire(key string, expiredAt time.Time) { 37 | if _, ok := k.data[key]; !ok { 38 | return 39 | } 40 | k.expiredAt[key] = expiredAt 41 | k.expireTimeWheel.Add(expiredAt.Unix(), key) 42 | } 43 | -------------------------------------------------------------------------------- /database/trigger.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync" 7 | 8 | "github.com/xiaoxuxiansheng/goredis/handler" 9 | ) 10 | 11 | type DBTrigger struct { 12 | once 
sync.Once 13 | executor Executor 14 | } 15 | 16 | func NewDBTrigger(executor Executor) handler.DB { 17 | return &DBTrigger{executor: executor} 18 | } 19 | 20 | func (d *DBTrigger) Do(ctx context.Context, cmdLine [][]byte) handler.Reply { 21 | if len(cmdLine) < 2 { 22 | return handler.NewErrReply(fmt.Sprintf("invalid cmd line: %v", cmdLine)) 23 | } 24 | 25 | cmdType := CmdType(cmdLine[0]) 26 | if !d.executor.ValidCommand(cmdType) { 27 | return handler.NewErrReply(fmt.Sprintf("unknown cmd '%s'", cmdLine[0])) 28 | } 29 | 30 | cmd := Command{ 31 | ctx: ctx, 32 | cmd: cmdType, 33 | args: cmdLine[1:], 34 | receiver: make(CmdReceiver), 35 | } 36 | 37 | // 投递给到 executor 38 | d.executor.Entrance() <- &cmd 39 | 40 | // 监听 chan,直到接收到返回的 reply 41 | return <-cmd.Receiver() 42 | } 43 | 44 | func (d *DBTrigger) Close() { 45 | d.once.Do(d.executor.Close) 46 | } 47 | -------------------------------------------------------------------------------- /datastore/string.go: -------------------------------------------------------------------------------- 1 | package datastore 2 | 3 | import ( 4 | "github.com/xiaoxuxiansheng/goredis/database" 5 | "github.com/xiaoxuxiansheng/goredis/handler" 6 | ) 7 | 8 | func (k *KVStore) getAsString(key string) (String, error) { 9 | v, ok := k.data[key] 10 | if !ok { 11 | return nil, nil 12 | } 13 | 14 | str, ok := v.(String) 15 | if !ok { 16 | return nil, handler.NewWrongTypeErrReply() 17 | } 18 | 19 | return str, nil 20 | } 21 | 22 | func (k *KVStore) put(key, value string, insertStrategy bool) int64 { 23 | if _, ok := k.data[key]; ok && insertStrategy { 24 | return 0 25 | } 26 | 27 | k.data[key] = NewString(key, value) 28 | return 1 29 | } 30 | 31 | type String interface { 32 | Bytes() []byte 33 | database.CmdAdapter 34 | } 35 | 36 | type stringEntity struct { 37 | key, str string 38 | } 39 | 40 | func NewString(key, str string) String { 41 | return &stringEntity{key: key, str: str} 42 | } 43 | 44 | func (s *stringEntity) Bytes() []byte { 45 | return 
[]byte(s.str) 46 | } 47 | 48 | func (s *stringEntity) ToCmd() [][]byte { 49 | return [][]byte{[]byte(database.CmdTypeSet), []byte(s.key), []byte(s.str)} 50 | } 51 | -------------------------------------------------------------------------------- /datastore/hash.go: -------------------------------------------------------------------------------- 1 | package datastore 2 | 3 | import ( 4 | "github.com/xiaoxuxiansheng/goredis/database" 5 | "github.com/xiaoxuxiansheng/goredis/handler" 6 | ) 7 | 8 | func (k *KVStore) getAsHashMap(key string) (HashMap, error) { 9 | v, ok := k.data[key] 10 | if !ok { 11 | return nil, nil 12 | } 13 | 14 | hmap, ok := v.(HashMap) 15 | if !ok { 16 | return nil, handler.NewWrongTypeErrReply() 17 | } 18 | 19 | return hmap, nil 20 | } 21 | 22 | func (k *KVStore) putAsHashMap(key string, hmap HashMap) { 23 | k.data[key] = hmap 24 | } 25 | 26 | type HashMap interface { 27 | Put(key string, value []byte) 28 | Get(key string) []byte 29 | Del(key string) int64 30 | database.CmdAdapter 31 | } 32 | 33 | type hashMapEntity struct { 34 | key string 35 | data map[string][]byte 36 | } 37 | 38 | func newHashMapEntity(key string) HashMap { 39 | return &hashMapEntity{ 40 | key: key, 41 | data: make(map[string][]byte), 42 | } 43 | } 44 | 45 | func (h *hashMapEntity) Put(key string, value []byte) { 46 | h.data[key] = value 47 | } 48 | 49 | func (h *hashMapEntity) Get(key string) []byte { 50 | return h.data[key] 51 | } 52 | 53 | func (h *hashMapEntity) Del(key string) int64 { 54 | if _, ok := h.data[key]; !ok { 55 | return 0 56 | } 57 | delete(h.data, key) 58 | return 1 59 | } 60 | 61 | func (h *hashMapEntity) ToCmd() [][]byte { 62 | args := make([][]byte, 0, 2+2*len(h.data)) 63 | args = append(args, []byte(database.CmdTypeHSet), []byte(h.key)) 64 | for k, v := range h.data { 65 | args = append(args, []byte(k), v) 66 | } 67 | return args 68 | } 69 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 |
2 |
3 | goredis: 基于 go 实现 redis
4 |
5 |
34 | 数据存储引擎层
35 |
36 |
--------------------------------------------------------------------------------
/app/factory.go:
--------------------------------------------------------------------------------
1 | package app
2 |
3 | import (
4 | "github.com/xiaoxuxiansheng/goredis/database"
5 | "github.com/xiaoxuxiansheng/goredis/datastore"
6 | "github.com/xiaoxuxiansheng/goredis/handler"
7 | "github.com/xiaoxuxiansheng/goredis/log"
8 | "github.com/xiaoxuxiansheng/goredis/persist"
9 | "github.com/xiaoxuxiansheng/goredis/protocol"
10 | "github.com/xiaoxuxiansheng/goredis/server"
11 |
12 | "go.uber.org/dig"
13 | )
14 |
15 | var container = dig.New()
16 |
17 | func init() {
18 | /**
19 | 其它
20 | **/
21 | // 配置加载 conf
22 | _ = container.Provide(SetUpConfig)
23 | _ = container.Provide(PersistThinker)
24 | // 日志打印 logger
25 | _ = container.Provide(log.GetDefaultLogger)
26 |
27 | /**
28 | 存储引擎
29 | **/
30 | // 数据持久化
31 | _ = container.Provide(persist.NewPersister)
32 | // 存储介质
33 | _ = container.Provide(datastore.NewKVStore)
34 | // 执行器
35 | _ = container.Provide(database.NewDBExecutor)
36 | // 触发器
37 | _ = container.Provide(database.NewDBTrigger)
38 |
39 | /**
40 | 逻辑处理层
41 | **/
42 | // 协议解析
43 | _ = container.Provide(protocol.NewParser)
44 | // 指令处理
45 | _ = container.Provide(handler.NewHandler)
46 |
47 | /**
48 | 服务端
49 | **/
50 | _ = container.Provide(server.NewServer)
51 | }
52 |
53 | func ConstructServer() (*server.Server, error) {
54 | var h server.Handler
55 | if err := container.Invoke(func(_h server.Handler) {
56 | h = _h
57 | }); err != nil {
58 | return nil, err
59 | }
60 |
61 | var l log.Logger
62 | if err := container.Invoke(func(_l log.Logger) {
63 | l = _l
64 | }); err != nil {
65 | return nil, err
66 | }
67 | return server.NewServer(h, l), nil
68 | }
69 |
--------------------------------------------------------------------------------
/persist/persist.go:
--------------------------------------------------------------------------------
1 | package persist
2 |
3 | import (
4 | "context"
5 | "io"
6 |
7 | "github.com/xiaoxuxiansheng/goredis/handler"
8 | )
9 |
// Thinker supplies the persistence-related settings consumed by this package.
// NOTE(review): the methods presumably mirror the appendonly, appendfilename,
// appendfsync and auto-aof-rewrite-after-cmds entries of redis.conf — confirm
// against the implementation in the app package.
type Thinker interface {
	// AppendOnly reports whether AOF persistence is enabled; NewPersister
	// returns a no-op persister when it is false.
	AppendOnly() bool
	// AppendFileName returns the name of the AOF file.
	AppendFileName() string
	// AppendFsync returns the fsync policy as a string.
	AppendFsync() string
	// AutoAofRewriteAfterCmd returns the command count that triggers an AOF
	// rewrite (presumably; verify against the AOF persister).
	AutoAofRewriteAfterCmd() int
}
16 |
17 | func NewPersister(thinker Thinker) (handler.Persister, error) {
18 | if !thinker.AppendOnly() {
19 | return newFakePersister(nil), nil
20 | }
21 |
22 | return newAofPersister(thinker)
23 | }
24 |
// fakeReadCloser turns a plain io.Reader into an io.ReadCloser by pairing it
// with a caller-supplied close function.
type fakeReadCloser struct {
	io.Reader
	closef func() error // invoked by Close; may be nil
}

// readCloserAdapter wraps reader so that Close delegates to closef. A nil
// closef is allowed and makes Close a no-op.
func readCloserAdapter(reader io.Reader, closef func() error) io.ReadCloser {
	return &fakeReadCloser{Reader: reader, closef: closef}
}

// Close runs the close function supplied at construction and returns its
// error. When no close function was supplied it returns nil instead of
// panicking on a nil function call.
func (f *fakeReadCloser) Close() error {
	if f.closef == nil {
		return nil
	}
	return f.closef()
}
37 |
38 | func newFakePersister(readCloser io.ReadCloser) handler.Persister {
39 | f := fakePersister{}
40 | if readCloser == nil {
41 | f.readCloser = singleFakeReloader
42 | return &f
43 | }
44 | f.readCloser = readCloser
45 | return &f
46 | }
47 |
// fakePersister is a persister that never writes anything; it only exposes
// the reload source it was constructed with.
type fakePersister struct {
	readCloser io.ReadCloser
}

// Reloader returns the reload source held by the persister; the error is
// always nil.
func (p *fakePersister) Reloader() (io.ReadCloser, error) {
	return p.readCloser, nil
}

// PersistCmd does nothing.
func (p *fakePersister) PersistCmd(ctx context.Context, cmd [][]byte) {
}

// Close does nothing.
func (p *fakePersister) Close() {
}
59 |
// singleFakeReloader is the shared empty reload source.
var singleFakeReloader = new(fakeReloader)

// fakeReloader is a stateless reload source that contains no data.
type fakeReloader struct{}

// Read never produces data: it always returns 0 and io.EOF.
func (*fakeReloader) Read(p []byte) (int, error) {
	return 0, io.EOF
}

// Close does nothing and always returns nil.
func (*fakeReloader) Close() error {
	return nil
}
72 |
--------------------------------------------------------------------------------
/datastore/set.go:
--------------------------------------------------------------------------------
1 | package datastore
2 |
3 | import (
4 | "github.com/xiaoxuxiansheng/goredis/database"
5 | "github.com/xiaoxuxiansheng/goredis/handler"
6 | )
7 |
8 | func (k *KVStore) getAsSet(key string) (Set, error) {
9 | v, ok := k.data[key]
10 | if !ok {
11 | return nil, nil
12 | }
13 |
14 | set, ok := v.(Set)
15 | if !ok {
16 | return nil, handler.NewWrongTypeErrReply()
17 | }
18 |
19 | return set, nil
20 | }
21 |
// putAsSet stores set under key in the store's data map, replacing whatever
// value was previously held for that key.
func (k *KVStore) putAsSet(key string, set Set) {
	k.data[key] = set
}
25 |
// Set is the behaviour required of a stored set value. The membership methods
// report their outcome as 0 or 1.
type Set interface {
	// Add inserts value; in setEntity it returns 1 when the value was new and
	// 0 when it was already present.
	Add(value string) int64
	// Exist returns 1 when value is a member (as implemented by setEntity),
	// otherwise 0.
	Exist(value string) int64
	// Rem removes value; in setEntity it returns 1 when the value was present
	// and 0 otherwise.
	Rem(value string) int64
	// CmdAdapter lets the set be converted back into a command line.
	database.CmdAdapter
}
32 |
33 | type setEntity struct {
34 | key string
35 | container map[string]struct{}
36 | }
37 |
38 | func newSetEntity(key string) Set {
39 | return &setEntity{
40 | key: key,
41 | container: make(map[string]struct{}),
42 | }
43 | }
44 |
45 | func (s *setEntity) Add(value string) int64 {
46 | if _, ok := s.container[value]; ok {
47 | return 0
48 | }
49 | s.container[value] = struct{}{}
50 | return 1
51 | }
52 |
53 | func (s *setEntity) Exist(value string) int64 {
54 | if _, ok := s.container[value]; ok {
55 | return 1
56 | }
57 | return 0
58 | }
59 |
60 | func (s *setEntity) Rem(value string) int64 {
61 | if _, ok := s.container[value]; ok {
62 | delete(s.container, value)
63 | return 1
64 | }
65 | return 0
66 | }
67 |
68 | func (s *setEntity) ToCmd() [][]byte {
69 | args := make([][]byte, 0, 2+len(s.container))
70 | args = append(args, []byte(database.CmdTypeSAdd), []byte(s.key))
71 | for k := range s.container {
72 | args = append(args, []byte(k))
73 | }
74 |
75 | return args
76 | }
77 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
3 | github.com/panjf2000/ants v1.3.0 h1:8pQ+8leaLc9lys2viEEr8md0U4RN6uOSUCE9bOYjQ9M=
4 | github.com/panjf2000/ants v1.3.0/go.mod h1:AaACblRPzq35m1g3enqYcxspbbiOJJYaxU2wMpm1cXY=
5 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
6 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
7 | github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
8 | github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
9 | github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
10 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
11 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
12 | go.uber.org/dig v1.17.1 h1:Tga8Lz8PcYNsWsyHMZ1Vm0OQOUaJNDyvPImgbAu9YSc=
13 | go.uber.org/dig v1.17.1/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE=
14 | go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
15 | go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
16 | go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
17 | go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
18 | go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
19 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
20 | gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
21 | gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
22 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
23 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
24 |
--------------------------------------------------------------------------------
/datastore/list.go:
--------------------------------------------------------------------------------
1 | package datastore
2 |
3 | import (
4 | "github.com/xiaoxuxiansheng/goredis/database"
5 | "github.com/xiaoxuxiansheng/goredis/handler"
6 | )
7 |
8 | func (k *KVStore) getAsList(key string) (List, error) {
9 | v, ok := k.data[key]
10 | if !ok {
11 | return nil, nil
12 | }
13 |
14 | list, ok := v.(List)
15 | if !ok {
16 | return nil, handler.NewWrongTypeErrReply()
17 | }
18 |
19 | return list, nil
20 | }
21 |
// putAsList stores list under key in the store's data map, replacing whatever
// value was previously held for that key.
func (k *KVStore) putAsList(key string, list List) {
	k.data[key] = list
}
25 |
// List is the behaviour required of a stored list value.
type List interface {
	// LPush inserts value at the head of the list.
	LPush(value []byte)
	// LPop removes and returns cnt elements from the head. The listEntity
	// implementation returns nil when the list holds fewer than cnt elements.
	LPop(cnt int64) [][]byte
	// RPush appends value at the tail of the list.
	RPush(value []byte)
	// RPop removes and returns cnt elements from the tail, in list order. The
	// listEntity implementation returns nil when the list holds fewer than
	// cnt elements.
	RPop(cnt int64) [][]byte
	// Len returns the number of elements.
	Len() int64
	// Range returns the elements from index start to index stop, both
	// inclusive. In listEntity a stop of -1 means the last element and any
	// other out-of-range index yields nil.
	Range(start, stop int64) [][]byte
	// CmdAdapter lets the list be converted back into a command line.
	database.CmdAdapter
}
35 |
36 | type listEntity struct {
37 | key string
38 | data [][]byte
39 | }
40 |
41 | func newListEntity(key string, elements ...[]byte) List {
42 | return &listEntity{
43 | key: key,
44 | data: elements,
45 | }
46 | }
47 |
48 | func (l *listEntity) LPush(value []byte) {
49 | l.data = append([][]byte{value}, l.data...)
50 | }
51 |
52 | func (l *listEntity) LPop(cnt int64) [][]byte {
53 | if int64(len(l.data)) < cnt {
54 | return nil
55 | }
56 |
57 | poped := l.data[:cnt]
58 | l.data = l.data[cnt:]
59 | return poped
60 | }
61 |
62 | func (l *listEntity) RPush(value []byte) {
63 | l.data = append(l.data, value)
64 | }
65 |
66 | func (l *listEntity) RPop(cnt int64) [][]byte {
67 | if int64(len(l.data)) < cnt {
68 | return nil
69 | }
70 |
71 | poped := l.data[int64(len(l.data))-cnt:]
72 | l.data = l.data[:int64(len(l.data))-cnt]
73 | return poped
74 | }
75 |
76 | func (l *listEntity) Len() int64 {
77 | return int64(len(l.data))
78 | }
79 |
80 | func (l *listEntity) Range(start, stop int64) [][]byte {
81 | if stop == -1 {
82 | stop = int64(len(l.data) - 1)
83 | }
84 |
85 | if start < 0 || start >= int64(len(l.data)) {
86 | return nil
87 | }
88 |
89 | if stop < 0 || stop >= int64(len(l.data)) || stop < start {
90 | return nil
91 | }
92 |
93 | return l.data[start : stop+1]
94 | }
95 |
96 | func (l *listEntity) ToCmd() [][]byte {
97 | args := make([][]byte, 0, 2+l.Len())
98 | args = append(args, []byte(database.CmdTypeRPush), []byte(l.key))
99 | args = append(args, l.data...)
100 | return args
101 | }
102 |
--------------------------------------------------------------------------------
/datastore/set_test.go:
--------------------------------------------------------------------------------
1 | package datastore
2 |
3 | import (
4 | "math/rand"
5 | "sort"
6 | "testing"
7 |
8 | "github.com/spf13/cast"
9 | "github.com/stretchr/testify/assert"
10 | "github.com/xiaoxuxiansheng/goredis/database"
11 | "github.com/xiaoxuxiansheng/goredis/lib"
12 | )
13 |
14 | func Test_set_crud(t *testing.T) {
15 | set := newSetEntity("")
16 | s := make(map[int]struct{}, 1000)
17 | rander := rand.New(rand.NewSource(lib.TimeNow().UnixNano()))
18 |
19 | t.Run("add", func(t *testing.T) {
20 | for i := 0; i < 1000; i++ {
21 | member := rander.Intn(1000)
22 | _, ok := s[member]
23 | success := set.Add(cast.ToString(member))
24 | s[member] = struct{}{}
25 | assert.Equal(t, ok, success == 0)
26 | }
27 | })
28 |
29 | t.Run("rem", func(t *testing.T) {
30 | for i := 0; i < 1000; i++ {
31 | member := rander.Intn(1000)
32 | _, ok := s[member]
33 | exist := set.Rem(cast.ToString(member))
34 | delete(s, member)
35 | assert.Equal(t, ok, exist == 1)
36 | }
37 | })
38 |
39 | t.Run("exist", func(t *testing.T) {
40 | for i := 0; i < 1000; i++ {
41 | member := rander.Intn(1000)
42 | _, ok := s[member]
43 | exist := set.Exist(cast.ToString(member))
44 | assert.Equal(t, ok, exist == 1)
45 | }
46 | })
47 | }
48 |
49 | func Test_set_to_cmd(t *testing.T) {
50 | set := newSetEntity("")
51 | rander := rand.New(rand.NewSource(lib.TimeNow().UnixNano()))
52 | s := make(map[int]struct{}, 1000)
53 | // 插入1000条数据
54 | for i := 0; i < 1000; i++ {
55 | member := rander.Intn(1000)
56 | set.Add(cast.ToString(member))
57 | s[member] = struct{}{}
58 | }
59 |
60 | cmd := set.ToCmd()
61 | t.Run("length", func(t *testing.T) {
62 | assert.Equal(t, len(s)+2, len(cmd))
63 | })
64 | t.Run("command", func(t *testing.T) {
65 | assert.Equal(t, database.CmdTypeSAdd, database.CmdType(cmd[0]))
66 | })
67 | t.Run("key", func(t *testing.T) {
68 | assert.Equal(t, "", string(cmd[1]))
69 | })
70 |
71 | actual := make([]int, 0, len(cmd)-2)
72 | for i := 2; i < len(cmd); i++ {
73 | actual = append(actual, cast.ToInt(string(cmd[i])))
74 | }
75 | sort.Slice(actual, func(i, j int) bool {
76 | return actual[i] < actual[j]
77 | })
78 |
79 | expect := make([]int, 0, len(s))
80 | for member := range s {
81 | expect = append(expect, member)
82 | }
83 | sort.Slice(expect, func(i, j int) bool {
84 | return expect[i] < expect[j]
85 | })
86 |
87 | t.Run("member", func(t *testing.T) {
88 | assert.Equal(t, expect, actual)
89 | })
90 | }
91 |
--------------------------------------------------------------------------------
/datastore/list_test.go:
--------------------------------------------------------------------------------
1 | package datastore
2 |
3 | import (
4 | "math/rand"
5 | "sort"
6 | "testing"
7 |
8 | "github.com/spf13/cast"
9 | "github.com/stretchr/testify/assert"
10 | "github.com/xiaoxuxiansheng/goredis/database"
11 | "github.com/xiaoxuxiansheng/goredis/lib"
12 | )
13 |
14 | func Test_list_crud(t *testing.T) {
15 | list := newListEntity("")
16 | l := make([][]byte, 0, 1000)
17 | rander := rand.New(rand.NewSource(lib.TimeNow().UnixNano()))
18 | for i := 0; i < 1000; i++ {
19 | member1 := rander.Intn(1000)
20 | member2 := rander.Intn(1000)
21 | list.LPush([]byte(cast.ToString(member1)))
22 | list.RPush([]byte(cast.ToString(member2)))
23 | l = append([][]byte{[]byte(cast.ToString(member1))}, l...)
24 | l = append(l, []byte(cast.ToString(member2)))
25 | }
26 |
27 | t.Run("range", func(t *testing.T) {
28 | for i := 0; i < 1000; i++ {
29 | start := rander.Intn(1001)
30 | end := start + rander.Intn(1000)
31 | actual := list.Range(int64(start), int64(end))
32 | expect := l[start : end+1]
33 | assert.Equal(t, expect, actual)
34 | }
35 | })
36 |
37 | t.Run("pop", func(t *testing.T) {
38 | for i := 0; i < 500; i++ {
39 | actual := list.LPop(2)
40 | expect := l[:2]
41 | l = l[2:]
42 | assert.Equal(t, expect, actual)
43 |
44 | actual = list.RPop(2)
45 | expect = l[len(l)-2:]
46 | l = l[:len(l)-2]
47 | assert.Equal(t, expect, actual)
48 | }
49 | })
50 | }
51 |
52 | func Test_list_to_cmds(t *testing.T) {
53 | list := newListEntity("")
54 | rander := rand.New(rand.NewSource(lib.TimeNow().UnixNano()))
55 | l := make([]int, 0, 1000)
56 | // 插入1000条数据
57 | for i := 0; i < 1000; i++ {
58 | member := rander.Intn(1000)
59 | list.LPush([]byte(cast.ToString(member)))
60 | l = append(l, member)
61 | }
62 |
63 | cmd := list.ToCmd()
64 | t.Run("length", func(t *testing.T) {
65 | assert.Equal(t, len(l)+2, len(cmd))
66 | })
67 | t.Run("command", func(t *testing.T) {
68 | assert.Equal(t, database.CmdTypeRPush, database.CmdType(cmd[0]))
69 | })
70 | t.Run("key", func(t *testing.T) {
71 | assert.Equal(t, "", string(cmd[1]))
72 | })
73 |
74 | actual := make([]int, 0, len(cmd)-2)
75 | for i := 2; i < len(cmd); i++ {
76 | actual = append(actual, cast.ToInt(string(cmd[i])))
77 | }
78 | sort.Slice(actual, func(i, j int) bool {
79 | return actual[i] < actual[j]
80 | })
81 |
82 | expect := l
83 | sort.Slice(expect, func(i, j int) bool {
84 | return expect[i] < expect[j]
85 | })
86 |
87 | t.Run("member", func(t *testing.T) {
88 | assert.Equal(t, expect, actual)
89 | })
90 | }
91 |
--------------------------------------------------------------------------------
/database/executor.go:
--------------------------------------------------------------------------------
1 | package database
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "time"
7 |
8 | "github.com/xiaoxuxiansheng/goredis/handler"
9 | "github.com/xiaoxuxiansheng/goredis/lib/pool"
10 | )
11 |
// DBExecutor receives commands over a channel and executes them one at a time
// in its run loop against the data store.
type DBExecutor struct {
	// ctx is cancelled by Close; its Done channel ends the run loop.
	ctx context.Context
	// cancel cancels ctx.
	cancel context.CancelFunc
	// ch is the unbuffered channel commands arrive on (exposed via Entrance).
	ch chan *Command

	// cmdHandlers maps each supported command type to its handler; it is
	// filled once in NewDBExecutor.
	cmdHandlers map[CmdType]CmdHandler
	// dataStore is the storage the handlers operate on.
	dataStore DataStore

	// gcTicker triggers dataStore.GC from the run loop.
	gcTicker *time.Ticker
}
22 |
23 | func NewDBExecutor(dataStore DataStore) Executor {
24 | ctx, cancel := context.WithCancel(context.Background())
25 | e := DBExecutor{
26 | dataStore: dataStore,
27 | ch: make(chan *Command),
28 | ctx: ctx,
29 | cancel: cancel,
30 | gcTicker: time.NewTicker(time.Minute),
31 | }
32 | e.cmdHandlers = map[CmdType]CmdHandler{
33 | CmdTypeExpire: e.dataStore.Expire,
34 | CmdTypeExpireAt: e.dataStore.ExpireAt,
35 |
36 | // string
37 | CmdTypeGet: e.dataStore.Get,
38 | CmdTypeSet: e.dataStore.Set,
39 | CmdTypeMGet: e.dataStore.MGet,
40 | CmdTypeMSet: e.dataStore.MSet,
41 |
42 | // list
43 | CmdTypeLPush: e.dataStore.LPush,
44 | CmdTypeLPop: e.dataStore.LPop,
45 | CmdTypeRPush: e.dataStore.RPush,
46 | CmdTypeRPop: e.dataStore.RPop,
47 | CmdTypeLRange: e.dataStore.LRange,
48 |
49 | // set
50 | CmdTypeSAdd: e.dataStore.SAdd,
51 | CmdTypeSIsMember: e.dataStore.SIsMember,
52 | CmdTypeSRem: e.dataStore.SRem,
53 |
54 | // hash
55 | CmdTypeHSet: e.dataStore.HSet,
56 | CmdTypeHGet: e.dataStore.HGet,
57 | CmdTypeHDel: e.dataStore.HDel,
58 |
59 | // sorted set
60 | CmdTypeZAdd: e.dataStore.ZAdd,
61 | CmdTypeZRangeByScore: e.dataStore.ZRangeByScore,
62 | CmdTypeZRem: e.dataStore.ZRem,
63 | }
64 |
65 | pool.Submit(e.run)
66 | return &e
67 | }
68 |
// Entrance returns the send-only side of the channel the run loop reads
// commands from.
func (e *DBExecutor) Entrance() chan<- *Command {
	return e.ch
}
72 |
73 | func (e *DBExecutor) ValidCommand(cmd CmdType) bool {
74 | _, valid := e.cmdHandlers[cmd] // map 只读,不考虑并发问题
75 | return valid
76 | }
77 |
78 | func (e *DBExecutor) Close() {
79 | e.cancel()
80 | }
81 |
82 | func (e *DBExecutor) run() {
83 | for {
84 | select {
85 | case <-e.ctx.Done():
86 | return
87 |
88 | // 每隔 1 分钟批量一次过期的 key
89 | case <-e.gcTicker.C:
90 | e.dataStore.GC()
91 |
92 | case cmd := <-e.ch:
93 | cmdFunc, ok := e.cmdHandlers[cmd.cmd]
94 | if !ok {
95 | cmd.receiver <- handler.NewErrReply(fmt.Sprintf("unknown command '%s'", cmd.cmd))
96 | continue
97 | }
98 |
99 | e.dataStore.ExpirePreprocess(string(cmd.args[0])) // 懒加载机制实现过期 key 删除
100 | cmd.receiver <- cmdFunc(cmd)
101 | }
102 | }
103 | }
104 |
--------------------------------------------------------------------------------
/datastore/hash_test.go:
--------------------------------------------------------------------------------
1 | package datastore
2 |
3 | import (
4 | "math/rand"
5 | "sort"
6 | "testing"
7 |
8 | "github.com/spf13/cast"
9 | "github.com/stretchr/testify/assert"
10 | "github.com/xiaoxuxiansheng/goredis/database"
11 | "github.com/xiaoxuxiansheng/goredis/lib"
12 | )
13 |
14 | func Test_hashmap_crud(t *testing.T) {
15 | hashmap := newHashMapEntity("")
16 | mp := make(map[int]int, 1000)
17 |
18 | rander := rand.New(rand.NewSource(lib.TimeNow().UnixNano()))
19 | for i := 0; i < 1000; i++ {
20 | k := rander.Intn(1000)
21 | v := rander.Intn(1000)
22 | hashmap.Put(cast.ToString(k), []byte(cast.ToString(v)))
23 | mp[k] = v
24 | }
25 |
26 | t.Run("delete", func(t *testing.T) {
27 | for i := 0; i < 1000; i++ {
28 | k := rander.Intn(1000)
29 | _, ok := mp[k]
30 | exist := hashmap.Del(cast.ToString(k))
31 | assert.Equal(t, ok, exist == 1)
32 | delete(mp, k)
33 | }
34 | })
35 |
36 | t.Run("get", func(t *testing.T) {
37 | for i := 0; i < 1000; i++ {
38 | k := rander.Intn(1000)
39 | value := hashmap.Get(cast.ToString(k))
40 | v, ok := mp[k]
41 | assert.Equal(t, ok, value != nil)
42 | if ok {
43 | assert.Equal(t, v, cast.ToInt(string(value)))
44 | }
45 | }
46 | })
47 | }
48 |
49 | func Test_hashmap_to_cmd(t *testing.T) {
50 | hashmap := newHashMapEntity("")
51 | rander := rand.New(rand.NewSource(lib.TimeNow().UnixNano()))
52 | mp := make(map[int]int, 1000)
53 | // 插入1000条数据
54 | for i := 0; i < 1000; i++ {
55 | k := rander.Intn(1000)
56 | v := rander.Intn(1000)
57 | hashmap.Put(cast.ToString(k), []byte(cast.ToString(v)))
58 | mp[k] = v
59 | }
60 |
61 | cmd := hashmap.ToCmd()
62 | t.Run("length", func(t *testing.T) {
63 | assert.Equal(t, 2*len(mp)+2, len(cmd))
64 | })
65 | t.Run("command", func(t *testing.T) {
66 | assert.Equal(t, database.CmdTypeHSet, database.CmdType(cmd[0]))
67 | })
68 | t.Run("key", func(t *testing.T) {
69 | assert.Equal(t, "", string(cmd[1]))
70 | })
71 |
72 | type kv struct {
73 | k, v int
74 | }
75 | actual := make([]kv, 0, 1000)
76 | for i := 2; i < len(cmd); i += 2 {
77 | actual = append(actual, kv{
78 | k: cast.ToInt(string(cmd[i])),
79 | v: cast.ToInt(string(cmd[i+1])),
80 | })
81 | }
82 |
83 | sort.Slice(actual, func(i, j int) bool {
84 | if actual[i].k == actual[j].k {
85 | return actual[i].v < actual[j].v
86 | }
87 | return actual[i].k < actual[j].k
88 | })
89 |
90 | expect := make([]kv, 0, 2*len(mp))
91 | for k, v := range mp {
92 | expect = append(expect, kv{
93 | k: k,
94 | v: v,
95 | })
96 | }
97 | sort.Slice(expect, func(i, j int) bool {
98 | if expect[i].k == expect[j].k {
99 | return expect[i].v < expect[j].v
100 | }
101 | return expect[i].k < expect[j].k
102 | })
103 |
104 | t.Run("member", func(t *testing.T) {
105 | assert.Equal(t, expect, actual)
106 | })
107 | }
108 |
--------------------------------------------------------------------------------
/log/log.go:
--------------------------------------------------------------------------------
1 | package log
2 |
3 | import (
4 | "go.uber.org/zap"
5 | "go.uber.org/zap/zapcore"
6 | "gopkg.in/natefinch/lumberjack.v2"
7 | )
8 |
// Logger is the leveled, printf-style logging interface used by the project.
// Methods are listed from least to most severe.
type Logger interface {
	Debugf(format string, args ...interface{})
	Infof(format string, args ...interface{})
	Warnf(format string, args ...interface{})
	Errorf(format string, args ...interface{})
}
15 |
16 | var (
17 | defaultLogger Logger
18 | )
19 |
20 | func init() {
21 | defaultLogger = NewLogger(NewOptions())
22 | }
23 |
// Options holds the settings used to build a logger.
type Options struct {
	// LogName is the name of the log.
	LogName string
	// LogLevel is the level name, looked up in Levels.
	LogLevel string
	// FileName is the log file name.
	FileName string
	// MaxAge is how long log files are kept, in days.
	MaxAge int
	// MaxSize is the size limit of a log file, in megabytes.
	MaxSize int
	// MaxBackups is the number of files to keep.
	MaxBackups int
	// Compress tells whether files are compressed.
	Compress bool
}

// Option is a function that adjusts an Options value.
type Option func(o *Options)

// NewOptions returns the default options with every opt applied in order.
func NewOptions(opts ...Option) Options {
	var options Options
	options.LogName = "app"
	options.LogLevel = "info"
	options.FileName = "app.log"
	options.MaxAge = 10
	options.MaxSize = 100
	options.MaxBackups = 3
	options.Compress = true

	for i := range opts {
		opts[i](&options)
	}
	return options
}

// WithLogLevel returns an Option that sets the log level.
func WithLogLevel(level string) Option {
	return func(o *Options) { o.LogLevel = level }
}

// WithFileName returns an Option that sets the log file name.
func WithFileName(filename string) Option {
	return func(o *Options) { o.FileName = filename }
}
68 |
69 | // Levels zapcore level
70 | var Levels = map[string]zapcore.Level{
71 | "debug": zapcore.DebugLevel,
72 | "info": zapcore.InfoLevel,
73 | "warn": zapcore.WarnLevel,
74 | "error": zapcore.ErrorLevel,
75 | "fatal": zapcore.FatalLevel,
76 | }
77 |
78 | type zapLoggerWrapper struct {
79 | *zap.SugaredLogger
80 | options Options
81 | }
82 |
83 | func NewLogger(options Options) Logger {
84 | w := &zapLoggerWrapper{options: options}
85 | encoder := w.getEncoder()
86 | writeSyncer := w.getLogWriter()
87 | core := zapcore.NewCore(encoder, writeSyncer, Levels[options.LogLevel])
88 | w.SugaredLogger = zap.New(core, zap.AddCaller(), zap.AddCallerSkip(1)).Sugar()
89 | return w
90 | }
91 |
92 | func (w *zapLoggerWrapper) getEncoder() zapcore.Encoder {
93 | encoderConfig := zap.NewProductionEncoderConfig()
94 | encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
95 |
96 | // 在日志文件中使用大写字母记录日志级别
97 | encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder
98 | // NewConsoleEncoder 打印更符合人们观察的方式
99 | return zapcore.NewConsoleEncoder(encoderConfig)
100 | }
101 |
102 | func (w *zapLoggerWrapper) getLogWriter() zapcore.WriteSyncer {
103 | return zapcore.AddSync(&lumberjack.Logger{
104 | Filename: w.options.FileName,
105 | MaxAge: w.options.MaxAge,
106 | MaxSize: w.options.MaxSize,
107 | MaxBackups: w.options.MaxBackups,
108 | Compress: w.options.Compress,
109 | })
110 | }
111 |
112 | // GetDefaultLogger 获取默认日志实现
113 | func GetDefaultLogger() Logger {
114 | return defaultLogger
115 | }
116 |
--------------------------------------------------------------------------------
/database/struct.go:
--------------------------------------------------------------------------------
1 | package database
2 |
3 | import (
4 | "context"
5 | "strings"
6 | "time"
7 |
8 | "github.com/xiaoxuxiansheng/goredis/handler"
9 | )
10 |
11 | type Executor interface {
12 | Entrance() chan<- *Command
13 | ValidCommand(cmd CmdType) bool
14 | Close()
15 | }
16 |
// CmdType is the name of a command.
type CmdType string

// String returns the command name converted to lowercase.
func (c CmdType) String() string {
	name := string(c)
	return strings.ToLower(name)
}

// Supported commands, grouped by the kind of data they operate on.
const (
	// expiration
	CmdTypeExpire   = CmdType("expire")
	CmdTypeExpireAt = CmdType("expireat")

	// string
	CmdTypeGet  = CmdType("get")
	CmdTypeSet  = CmdType("set")
	CmdTypeMGet = CmdType("mget")
	CmdTypeMSet = CmdType("mset")

	// list
	CmdTypeLPush  = CmdType("lpush")
	CmdTypeLPop   = CmdType("lpop")
	CmdTypeRPush  = CmdType("rpush")
	CmdTypeRPop   = CmdType("rpop")
	CmdTypeLRange = CmdType("lrange")

	// hash
	CmdTypeHSet = CmdType("hset")
	CmdTypeHGet = CmdType("hget")
	CmdTypeHDel = CmdType("hdel")

	// set
	CmdTypeSAdd      = CmdType("sadd")
	CmdTypeSIsMember = CmdType("sismember")
	CmdTypeSRem      = CmdType("srem")

	// sorted set
	CmdTypeZAdd          = CmdType("zadd")
	CmdTypeZRangeByScore = CmdType("zrangebyscore")
	CmdTypeZRem          = CmdType("zrem")
)
55 |
// CmdAdapter is implemented by values that can be converted into a command,
// expressed as a list of byte-slice arguments.
type CmdAdapter interface {
	ToCmd() (cmd [][]byte)
}
59 |
60 | type DataStore interface {
61 | ForEach(task func(key string, adapter CmdAdapter, expireAt *time.Time))
62 |
63 | ExpirePreprocess(key string)
64 | GC()
65 |
66 | Expire(*Command) handler.Reply
67 | ExpireAt(*Command) handler.Reply
68 |
69 | // string
70 | Get(*Command) handler.Reply
71 | MGet(*Command) handler.Reply
72 | Set(*Command) handler.Reply
73 | MSet(*Command) handler.Reply
74 |
75 | // list
76 | LPush(*Command) handler.Reply
77 | LPop(*Command) handler.Reply
78 | RPush(*Command) handler.Reply
79 | RPop(*Command) handler.Reply
80 | LRange(*Command) handler.Reply
81 |
82 | // set
83 | SAdd(*Command) handler.Reply
84 | SIsMember(*Command) handler.Reply
85 | SRem(*Command) handler.Reply
86 |
87 | // hash
88 | HSet(*Command) handler.Reply
89 | HGet(*Command) handler.Reply
90 | HDel(*Command) handler.Reply
91 |
92 | // sorted set
93 | ZAdd(*Command) handler.Reply
94 | ZRangeByScore(*Command) handler.Reply
95 | ZRem(*Command) handler.Reply
96 | }
97 |
98 | type CmdHandler func(*Command) handler.Reply
99 |
100 | type Command struct {
101 | ctx context.Context
102 | cmd CmdType
103 | args [][]byte
104 | receiver CmdReceiver
105 | }
106 |
107 | func NewCommand(cmd CmdType, args [][]byte) *Command {
108 | return &Command{
109 | cmd: cmd,
110 | args: args,
111 | }
112 | }
113 |
114 | func (c *Command) Ctx() context.Context {
115 | return c.ctx
116 | }
117 |
118 | func (c *Command) Receiver() CmdReceiver {
119 | return c.receiver
120 | }
121 |
122 | func (c *Command) Args() [][]byte {
123 | return c.args
124 | }
125 |
126 | func (c *Command) Cmd() [][]byte {
127 | return append([][]byte{[]byte(c.cmd.String())}, c.args...)
128 | }
129 |
130 | type CmdReceiver chan handler.Reply
131 |
--------------------------------------------------------------------------------
/server/server.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "context"
5 | "net"
6 | "os"
7 | "os/signal"
8 | "sync"
9 | "syscall"
10 | "time"
11 |
12 | "github.com/xiaoxuxiansheng/goredis/lib/pool"
13 | "github.com/xiaoxuxiansheng/goredis/log"
14 | )
15 |
// Handler is what the server drives: it is started once, given every accepted
// TCP connection, and closed on shutdown.
type Handler interface {
	// Start starts the handler.
	Start() error
	// Handle processes one incoming TCP connection.
	Handle(ctx context.Context, conn net.Conn)
	// Close shuts the handler down.
	Close()
}
24 |
25 | type Server struct {
26 | runOnce sync.Once
27 | stopOnce sync.Once
28 | handler Handler
29 | logger log.Logger
30 | stopc chan struct{}
31 | }
32 |
33 | func NewServer(handler Handler, logger log.Logger) *Server {
34 | return &Server{
35 | handler: handler,
36 | logger: logger,
37 | stopc: make(chan struct{}),
38 | }
39 | }
40 |
41 | func (s *Server) Serve(address string) error {
42 | if err := s.handler.Start(); err != nil {
43 | return err
44 | }
45 | var _err error
46 | s.runOnce.Do(func() {
47 | // 监听进程信号
48 | exitWords := []os.Signal{syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT}
49 |
50 | sigc := make(chan os.Signal, 1)
51 | signal.Notify(sigc, exitWords...)
52 | closec := make(chan struct{}, 4)
53 | pool.Submit(func() {
54 | for {
55 | select {
56 | case signal := <-sigc:
57 | switch signal {
58 | case exitWords[0], exitWords[1], exitWords[2], exitWords[3]:
59 | closec <- struct{}{}
60 | return
61 | default:
62 | }
63 | case <-s.stopc:
64 | closec <- struct{}{}
65 | return
66 | }
67 | }
68 | })
69 |
70 | listener, err := net.Listen("tcp", address)
71 | if err != nil {
72 | _err = err
73 | return
74 | }
75 |
76 | s.listenAndServe(listener, closec)
77 | })
78 |
79 | return _err
80 | }
81 |
82 | func (s *Server) Stop() {
83 | s.stopOnce.Do(func() {
84 | close(s.stopc)
85 | })
86 | }
87 |
88 | func (s *Server) listenAndServe(listener net.Listener, closec chan struct{}) {
89 | errc := make(chan error, 1)
90 | defer close(errc)
91 |
92 | // 遇到意外错误,则终止流程
93 | ctx, cancel := context.WithCancel(context.Background())
94 | pool.Submit(
95 | func() {
96 | select {
97 | case <-closec:
98 | s.logger.Errorf("[server]server closing...")
99 | case err := <-errc:
100 | s.logger.Errorf("[server]server err: %s", err.Error())
101 | }
102 | cancel()
103 | s.logger.Warnf("[server]server closeing...")
104 | s.handler.Close()
105 | if err := listener.Close(); err != nil {
106 | s.logger.Errorf("[server]server close listener err: %s", err.Error())
107 | }
108 | })
109 |
110 | s.logger.Warnf("[server]server starting...")
111 | var wg sync.WaitGroup
112 | // io 多路复用模型,goroutine for per conn
113 | for {
114 | conn, err := listener.Accept()
115 | if err != nil {
116 | // 超时类错误,忽略
117 | if ne, ok := err.(net.Error); ok && ne.Timeout() {
118 | time.Sleep(5 * time.Millisecond)
119 | continue
120 | }
121 |
122 | // 意外错误,则停止运行
123 | errc <- err
124 | break
125 | }
126 |
127 | // 为每个到来的 conn 分配一个 goroutine 处理
128 | wg.Add(1)
129 | pool.Submit(func() {
130 | defer wg.Done()
131 | s.handler.Handle(ctx, conn)
132 | })
133 | }
134 |
135 | // 通过 waitGroup 保证优雅退出
136 | wg.Wait()
137 | }
138 |
--------------------------------------------------------------------------------
/handler/handler.go:
--------------------------------------------------------------------------------
1 | package handler
2 |
3 | import (
4 | "context"
5 | "io"
6 | "net"
7 | "sync"
8 | "sync/atomic"
9 |
10 | "github.com/xiaoxuxiansheng/goredis/log"
11 | "github.com/xiaoxuxiansheng/goredis/server"
12 | )
13 |
14 | type Handler struct {
15 | sync.Once
16 | mu sync.RWMutex
17 | conns map[net.Conn]struct{}
18 | closed atomic.Bool
19 |
20 | db DB
21 | parser Parser
22 | persister Persister
23 | logger log.Logger
24 | }
25 |
26 | func NewHandler(db DB, persister Persister, parser Parser, logger log.Logger) (server.Handler, error) {
27 | h := Handler{
28 | conns: make(map[net.Conn]struct{}),
29 | persister: persister,
30 | logger: logger,
31 | db: db,
32 | parser: parser,
33 | }
34 |
35 | return &h, nil
36 | }
37 |
38 | func (h *Handler) Start() error {
39 | // 加载持久化文件,还原内容
40 | reloader, err := h.persister.Reloader()
41 | if err != nil {
42 | return err
43 | }
44 | defer reloader.Close()
45 | h.handle(SetLoadingPattern(context.Background()), newFakeReaderWriter(reloader))
46 | return nil
47 | }
48 |
49 | func (h *Handler) Handle(ctx context.Context, conn net.Conn) {
50 | h.mu.Lock()
51 | // 判断 db 是否已经关闭
52 | if h.closed.Load() {
53 | h.mu.Unlock()
54 | return
55 | }
56 |
57 | // 当前 conn 缓存起来
58 | h.conns[conn] = struct{}{}
59 | h.mu.Unlock()
60 |
61 | h.handle(ctx, conn)
62 | }
63 |
64 | func (h *Handler) handle(ctx context.Context, conn io.ReadWriter) {
65 | // 持续处理
66 | stream := h.parser.ParseStream(conn)
67 | for {
68 | select {
69 | case <-ctx.Done():
70 | h.logger.Warnf("[handler]handle ctx err: %s", ctx.Err().Error())
71 | return
72 |
73 | case droplet := <-stream:
74 | if err := h.handleDroplet(ctx, conn, droplet); err != nil {
75 | h.logger.Errorf("[handler]conn terminated, err: %s", droplet.Err.Error())
76 | return
77 | }
78 | }
79 | }
80 | }
81 |
82 | func (h *Handler) handleDroplet(ctx context.Context, conn io.ReadWriter, droplet *Droplet) error {
83 | if droplet.Terminated() {
84 | return droplet.Err
85 | }
86 |
87 | if droplet.Err != nil {
88 | _, _ = conn.Write(droplet.Reply.ToBytes())
89 | h.logger.Errorf("[handler]conn request, err: %s", droplet.Err.Error())
90 | return nil
91 | }
92 |
93 | if droplet.Reply == nil {
94 | h.logger.Errorf("[handler]conn empty request")
95 | return nil
96 | }
97 |
98 | // 请求参数必须为 multiBulkReply 类型
99 | multiReply, ok := droplet.Reply.(MultiReply)
100 | if !ok {
101 | h.logger.Errorf("[handler]conn invalid request: %s", droplet.Reply.ToBytes())
102 | return nil
103 | }
104 |
105 | if reply := h.db.Do(ctx, multiReply.Args()); reply != nil {
106 | _, _ = conn.Write(reply.ToBytes())
107 | return nil
108 | }
109 |
110 | _, _ = conn.Write(UnknownErrReplyBytes)
111 | return nil
112 | }
113 |
114 | func (h *Handler) Close() {
115 | h.Once.Do(func() {
116 | h.logger.Warnf("[handler]handler closing...")
117 | h.closed.Store(true)
118 | h.mu.RLock()
119 | defer h.mu.RUnlock()
120 | for conn := range h.conns {
121 | if err := conn.Close(); err != nil {
122 | h.logger.Errorf("[handler]close conn err, local addr: %s, err: %s", conn.LocalAddr().String(), err.Error())
123 | }
124 | }
125 | h.conns = nil
126 | h.db.Close()
127 | h.persister.Close()
128 | })
129 | }
130 |
--------------------------------------------------------------------------------
/app/config.go:
--------------------------------------------------------------------------------
1 | package app
2 |
3 | import (
4 | "bufio"
5 | "fmt"
6 | "io"
7 | "os"
8 | "reflect"
9 | "strconv"
10 | "strings"
11 | "sync"
12 |
13 | "github.com/xiaoxuxiansheng/goredis/persist"
14 | )
15 |
// Config holds the server settings. The cfg tag of a field names the
// configuration key it is filled from.
type Config struct {
	// Bind is the IP address to listen on.
	Bind string `cfg:"bind"`
	// Port is the port to listen on.
	Port int `cfg:"port"`
	// AppendOnly_ tells whether AOF is enabled.
	AppendOnly_ bool `cfg:"appendonly"`
	// AppendFileName_ is the AOF file name.
	AppendFileName_ string `cfg:"appendfilename"`
	// AppendFsync_ is the AOF sync level.
	AppendFsync_ string `cfg:"appendfsync"`
	// AutoAofRewriteAfterCmd_ is the number of AOF operations after which a
	// rewrite is performed.
	AutoAofRewriteAfterCmd_ int `cfg:"auto-aof-rewrite-after-cmds"`
}

// Address returns the listen address in "bind:port" form.
func (c *Config) Address() string {
	host, port := c.Bind, c.Port
	return fmt.Sprintf("%s:%d", host, port)
}

// AppendOnly reports whether AOF is enabled.
func (c *Config) AppendOnly() bool { return c.AppendOnly_ }

// AppendFileName returns the AOF file name.
func (c *Config) AppendFileName() string { return c.AppendFileName_ }

// AppendFsync returns the configured AOF sync level.
func (c *Config) AppendFsync() string { return c.AppendFsync_ }

// AutoAofRewriteAfterCmd returns the configured rewrite threshold.
func (c *Config) AutoAofRewriteAfterCmd() int { return c.AutoAofRewriteAfterCmd_ }
44 |
45 | var (
46 | confOnce sync.Once
47 | globalConf *Config
48 | )
49 |
50 | func PersistThinker() persist.Thinker {
51 | return SetUpConfig()
52 | }
53 |
54 | func SetUpConfig() *Config {
55 | confOnce.Do(func() {
56 | defer func() {
57 | if globalConf == nil {
58 | globalConf = defaultConf()
59 | }
60 | }()
61 |
62 | file, err := os.Open("./redis.conf")
63 | if err != nil {
64 | return
65 | }
66 | defer file.Close()
67 | globalConf = setUpConfig(file)
68 | })
69 |
70 | return globalConf
71 | }
72 |
73 | func setUpConfig(src io.Reader) *Config {
74 | tmpkv := make(map[string]string)
75 | scanner := bufio.NewScanner(src)
76 | for scanner.Scan() {
77 | line := scanner.Text()
78 | // 注释行,跳过
79 | trimmed := strings.TrimSpace(line)
80 | if len(trimmed) > 0 && trimmed[0] == '#' {
81 | continue
82 | }
83 |
84 | // 寻找合法的空格分隔符位置
85 | pivot := strings.Index(trimmed, " ")
86 | if pivot <= 0 || pivot >= len(trimmed)-1 {
87 | continue
88 | }
89 |
90 | key := trimmed[:pivot]
91 | value := trimmed[pivot+1:]
92 | tmpkv[key] = value
93 | }
94 |
95 | if err := scanner.Err(); err != nil {
96 | return nil
97 | }
98 |
99 | conf := &Config{}
100 | // 通过反射设置 conf 属性值
101 | t := reflect.TypeOf(conf)
102 | v := reflect.ValueOf(conf)
103 | for i := 0; i < t.Elem().NumField(); i++ {
104 | field := t.Elem().Field(i)
105 | fieldVal := v.Elem().Field(i)
106 | key, ok := field.Tag.Lookup("cfg")
107 | if !ok || strings.TrimSpace(key) == "" {
108 | key = field.Name
109 | }
110 | value, ok := tmpkv[key]
111 | if !ok {
112 | continue
113 | }
114 | switch field.Type.Kind() {
115 | case reflect.String:
116 | fieldVal.SetString(value)
117 | case reflect.Int:
118 | intv, _ := strconv.ParseInt(value, 10, 64)
119 | fieldVal.SetInt(intv)
120 | case reflect.Bool:
121 | fieldVal.SetBool(value == "yes")
122 | }
123 | }
124 |
125 | return conf
126 | }
127 |
128 | func defaultConf() *Config {
129 | return &Config{
130 | Bind: "0.0.0.0",
131 | Port: 6379,
132 | AppendOnly_: false, // 默认不启用 aof
133 | }
134 | }
135 |
--------------------------------------------------------------------------------
/persist/aof_rewrite.go:
--------------------------------------------------------------------------------
1 | package persist
2 |
3 | import (
4 | "io"
5 | "os"
6 | "time"
7 |
8 | "github.com/xiaoxuxiansheng/goredis/database"
9 | "github.com/xiaoxuxiansheng/goredis/datastore"
10 | "github.com/xiaoxuxiansheng/goredis/handler"
11 | "github.com/xiaoxuxiansheng/goredis/lib"
12 | "github.com/xiaoxuxiansheng/goredis/log"
13 | "github.com/xiaoxuxiansheng/goredis/protocol"
14 | )
15 |
16 | // 重写 aof 文件
17 | func (a *aofPersister) rewriteAOF() error {
18 | // 1 重写前处理. 需要短暂加锁
19 | tmpFile, fileSize, err := a.startRewrite()
20 | if err != nil {
21 | return err
22 | }
23 |
24 | // 2 aof 指令重写. 与主流程并发执行
25 | if err = a.doRewrite(tmpFile, fileSize); err != nil {
26 | return err
27 | }
28 |
29 | // 3 完成重写. 需要短暂加锁
30 | return a.endRewrite(tmpFile, fileSize)
31 | }
32 |
33 | func (a *aofPersister) startRewrite() (*os.File, int64, error) {
34 | a.mu.Lock()
35 | defer a.mu.Unlock()
36 |
37 | if err := a.aofFile.Sync(); err != nil {
38 | return nil, 0, err
39 | }
40 |
41 | fileInfo, _ := os.Stat(a.aofFileName)
42 | fileSize := fileInfo.Size()
43 |
44 | // 创建一个临时的 aof 文件
45 | tmpFile, err := os.CreateTemp("./", "*.aof")
46 | if err != nil {
47 | return nil, 0, err
48 | }
49 |
50 | return tmpFile, fileSize, nil
51 | }
52 |
53 | func (a *aofPersister) doRewrite(tmpFile *os.File, fileSize int64) error {
54 | forkedDB, err := a.forkDB(fileSize)
55 | if err != nil {
56 | return err
57 | }
58 |
59 | // 将 db 数据转为 aof cmd
60 | forkedDB.ForEach(func(key string, adapter database.CmdAdapter, expireAt *time.Time) {
61 | _, _ = tmpFile.Write(handler.NewMultiBulkReply(adapter.ToCmd()).ToBytes())
62 |
63 | if expireAt == nil {
64 | return
65 | }
66 |
67 | expireCmd := [][]byte{[]byte(database.CmdTypeExpireAt), []byte(key), []byte(lib.TimeSecondFormat(*expireAt))}
68 | _, _ = tmpFile.Write(handler.NewMultiBulkReply(expireCmd).ToBytes())
69 | })
70 |
71 | return nil
72 | }
73 |
74 | func (a *aofPersister) forkDB(fileSize int64) (database.DataStore, error) {
75 | file, err := os.Open(a.aofFileName)
76 | if err != nil {
77 | return nil, err
78 | }
79 | file.Seek(0, io.SeekStart)
80 | logger := log.GetDefaultLogger()
81 | reloader := readCloserAdapter(io.LimitReader(file, fileSize), file.Close)
82 | fakePerisister := newFakePersister(reloader)
83 | tmpKVStore := datastore.NewKVStore(fakePerisister)
84 | executor := database.NewDBExecutor(tmpKVStore)
85 | trigger := database.NewDBTrigger(executor)
86 | h, err := handler.NewHandler(trigger, fakePerisister, protocol.NewParser(logger), logger)
87 | if err != nil {
88 | return nil, err
89 | }
90 | if err = h.Start(); err != nil {
91 | return nil, err
92 | }
93 | return tmpKVStore, nil
94 | }
95 |
96 | func (a *aofPersister) endRewrite(tmpFile *os.File, fileSize int64) error {
97 | a.mu.Lock()
98 | defer a.mu.Unlock()
99 |
100 | // copy commands executed during rewriting to tmpFile
101 | /* read write commands executed during rewriting */
102 | src, err := os.Open(a.aofFileName)
103 | if err != nil {
104 | return err
105 | }
106 | defer func() {
107 | _ = src.Close()
108 | _ = tmpFile.Close()
109 | }()
110 |
111 | if _, err = src.Seek(fileSize, 0); err != nil {
112 | return err
113 | }
114 |
115 | // 把老的 aof 文件中后续内容 copy 到 tmp 中
116 | if _, err = io.Copy(tmpFile, src); err != nil {
117 | return err
118 | }
119 |
120 | // 关闭老的 aof 文件,准备废弃
121 | _ = a.aofFile.Close()
122 | // 重命名 tmp 文件,作为新的 aof 文件
123 | if err := os.Rename(tmpFile.Name(), a.aofFileName); err != nil {
124 | // log
125 | }
126 |
127 | // 重新开启
128 | aofFile, err := os.OpenFile(a.aofFileName, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0600)
129 | if err != nil {
130 | panic(err)
131 | }
132 | a.aofFile = aofFile
133 | return nil
134 | }
135 |
--------------------------------------------------------------------------------
/handler/reply.go:
--------------------------------------------------------------------------------
1 | package handler
2 |
3 | import (
4 | "strconv"
5 | "strings"
6 | )
7 |
// CRLF is the line separator used by the redis protocol.
const CRLF = "\r\n"

// OKReply is the fixed "+OK" reply.
type OKReply struct{}

var (
	// okBytes is the wire form shared by every OKReply.
	okBytes = []byte("+OK\r\n")
	// theOkReply is the single shared OKReply instance.
	theOkReply = new(OKReply)
)

// NewOKReply returns the shared OKReply instance.
func NewOKReply() *OKReply { return theOkReply }

// ToBytes returns the wire form "+OK\r\n".
func (o *OKReply) ToBytes() []byte { return okBytes }
24 |
25 | // 简单字符串类型. 协议为 【+】【string】【CRLF】
26 | type SimpleStringReply struct {
27 | Str string
28 | }
29 |
30 | func NewSimpleStringReply(str string) *SimpleStringReply {
31 | return &SimpleStringReply{
32 | Str: str,
33 | }
34 | }
35 |
36 | func (s *SimpleStringReply) ToBytes() []byte {
37 | return []byte("+" + s.Str + CRLF)
38 | }
39 |
40 | // 简单数字类型. 协议为 【:】【int】【CRLF】
41 | type IntReply struct {
42 | Code int64
43 | }
44 |
45 | func NewIntReply(code int64) *IntReply {
46 | return &IntReply{
47 | Code: code,
48 | }
49 | }
50 |
51 | func (i *IntReply) ToBytes() []byte {
52 | return []byte(":" + strconv.FormatInt(i.Code, 10) + CRLF)
53 | }
54 |
// SyntaxErrReply reports a syntax error in the arguments.
type SyntaxErrReply struct{}

var (
	// syntaxErrBytes is the wire form shared by every SyntaxErrReply.
	syntaxErrBytes = []byte("-Err syntax error\r\n")
	// theSyntaxErrReply is the single shared SyntaxErrReply instance.
	theSyntaxErrReply = new(SyntaxErrReply)
)

// NewSyntaxErrReply returns the shared SyntaxErrReply instance.
func NewSyntaxErrReply() *SyntaxErrReply { return theSyntaxErrReply }

// ToBytes returns the wire form of the error.
func (r *SyntaxErrReply) ToBytes() []byte { return syntaxErrBytes }

// Error returns the error text without protocol framing.
func (r *SyntaxErrReply) Error() string { return "Err syntax error" }
72 |
// WrongTypeErrReply reports an operation against a key holding the wrong kind
// of value.
type WrongTypeErrReply struct{}

var (
	// theWrongTypeErrReply is the single shared WrongTypeErrReply instance.
	theWrongTypeErrReply = new(WrongTypeErrReply)
	// wrongTypeErrBytes is the wire form shared by every WrongTypeErrReply.
	wrongTypeErrBytes = []byte("-WRONGTYPE Operation against a key holding the wrong kind of value\r\n")
)

// NewWrongTypeErrReply returns the shared WrongTypeErrReply instance.
func NewWrongTypeErrReply() *WrongTypeErrReply { return theWrongTypeErrReply }

// ToBytes returns the wire form of the error.
func (r *WrongTypeErrReply) ToBytes() []byte { return wrongTypeErrBytes }

// Error returns the error text without protocol framing.
func (r *WrongTypeErrReply) Error() string {
	return "WRONGTYPE Operation against a key holding the wrong kind of value"
}
91 |
92 | // 错误类型. 协议为 【-】【err】【CRLF】
93 | type ErrReply struct {
94 | ErrStr string
95 | }
96 |
97 | func NewErrReply(errStr string) *ErrReply {
98 | return &ErrReply{
99 | ErrStr: errStr,
100 | }
101 | }
102 |
103 | func (e *ErrReply) ToBytes() []byte {
104 | return []byte("-" + e.ErrStr + CRLF)
105 | }
106 |
// NillReply is the nil bulk reply. A single shared instance is used and the
// wire form is always '$' -1 CRLF.
type NillReply struct{}

// nillReply is the single shared NillReply instance.
var nillReply = new(NillReply)

// nillBulkBytes is the wire form of a nil bulk string.
var nillBulkBytes = []byte("$-1\r\n")

// NewNillReply returns the shared NillReply instance.
func NewNillReply() *NillReply { return nillReply }

// ToBytes returns the wire form "$-1\r\n".
func (n *NillReply) ToBytes() []byte { return nillBulkBytes }
123 |
124 | // 定长字符串类型,协议固定为 【$】【length】【CRLF】【content】【CRLF】
125 | type BulkReply struct {
126 | Arg []byte
127 | }
128 |
129 | func NewBulkReply(arg []byte) *BulkReply {
130 | return &BulkReply{
131 | Arg: arg,
132 | }
133 | }
134 |
135 | func (b *BulkReply) ToBytes() []byte {
136 | if b.Arg == nil {
137 | return nillBulkBytes
138 | }
139 | return []byte("$" + strconv.Itoa(len(b.Arg)) + CRLF + string(b.Arg) + CRLF)
140 | }
141 |
142 | // 数组类型. 协议固定为 【*】【arr.length】【CRLF】+ arr.length * (【$】【length】【CRLF】【content】【CRLF】)
143 | type MultiBulkReply struct {
144 | args [][]byte
145 | }
146 |
147 | func NewMultiBulkReply(args [][]byte) *MultiBulkReply {
148 | return &MultiBulkReply{
149 | args: args,
150 | }
151 | }
152 |
153 | func (m *MultiBulkReply) Args() [][]byte {
154 | return m.args
155 | }
156 |
157 | func (m *MultiBulkReply) ToBytes() []byte {
158 | var strBuf strings.Builder
159 | strBuf.WriteString("*" + strconv.Itoa(len(m.args)) + CRLF)
160 | for _, arg := range m.args {
161 | if arg == nil {
162 | strBuf.WriteString(string(nillBulkBytes))
163 | continue
164 | }
165 | strBuf.WriteString("$" + strconv.Itoa(len(arg)) + CRLF + string(arg) + CRLF)
166 | }
167 | return []byte(strBuf.String())
168 | }
169 |
var (
	// emptyMultiBulkBytes is the wire form of an empty array.
	emptyMultiBulkBytes = []byte("*0\r\n")
	// theEmptyMultiBulkReply is the single shared EmptyMultiBulkReply instance.
	theEmptyMultiBulkReply = &EmptyMultiBulkReply{}
)

// EmptyMultiBulkReply is the empty array reply. A single shared instance is
// used and the wire form is always '*' 0 CRLF.
type EmptyMultiBulkReply struct{}

// NewEmptyMultiBulkReply returns the shared EmptyMultiBulkReply instance, in
// line with the other fixed replies in this file.
func NewEmptyMultiBulkReply() *EmptyMultiBulkReply {
	return theEmptyMultiBulkReply
}

// ToBytes returns the wire form "*0\r\n".
func (r *EmptyMultiBulkReply) ToBytes() []byte {
	return emptyMultiBulkBytes
}
182 |
--------------------------------------------------------------------------------
/persist/aof.go:
--------------------------------------------------------------------------------
1 | package persist
2 |
3 | import (
4 | "context"
5 | "io"
6 | "os"
7 | "sync"
8 | "sync/atomic"
9 | "time"
10 |
11 | "github.com/xiaoxuxiansheng/goredis/handler"
12 | "github.com/xiaoxuxiansheng/goredis/lib/pool"
13 | )
14 |
// appendSyncStrategy names an AOF sync level: always | everysec | no.
type appendSyncStrategy string

// string returns the strategy as a plain string.
func (a appendSyncStrategy) string() string { return string(a) }

// The supported sync strategies.
const (
	alwaysAppendSyncStrategy   = appendSyncStrategy("always")
	everysecAppendSyncStrategy = appendSyncStrategy("everysec")
	noAppendSyncStrategy       = appendSyncStrategy("no")
)
27 |
28 | type aofPersister struct {
29 | ctx context.Context
30 | cancel context.CancelFunc
31 |
32 | buffer chan [][]byte
33 | aofFile *os.File
34 | aofFileName string
35 | appendFsync appendSyncStrategy
36 | autoAofRewriteAfterCmd int64
37 | aofCounter atomic.Int64
38 |
39 | mu sync.Mutex
40 | once sync.Once
41 | }
42 |
43 | func newAofPersister(thinker Thinker) (handler.Persister, error) {
44 | aofFileName := thinker.AppendFileName()
45 | aofFile, err := os.OpenFile(aofFileName, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0600)
46 | if err != nil {
47 | return nil, err
48 | }
49 | ctx, cancel := context.WithCancel(context.Background())
50 | a := aofPersister{
51 | ctx: ctx,
52 | cancel: cancel,
53 | buffer: make(chan [][]byte, 1<<10),
54 | aofFile: aofFile,
55 | aofFileName: aofFileName,
56 | }
57 |
58 | if autoAofRewriteAfterCmd := thinker.AutoAofRewriteAfterCmd(); autoAofRewriteAfterCmd > 1 {
59 | a.autoAofRewriteAfterCmd = int64(autoAofRewriteAfterCmd)
60 | }
61 |
62 | switch thinker.AppendFsync() {
63 | case alwaysAppendSyncStrategy.string():
64 | a.appendFsync = alwaysAppendSyncStrategy
65 | case everysecAppendSyncStrategy.string():
66 | a.appendFsync = everysecAppendSyncStrategy
67 | default:
68 | a.appendFsync = noAppendSyncStrategy // 默认策略
69 | }
70 |
71 | pool.Submit(a.run)
72 | return &a, nil
73 | }
74 |
75 | func (a *aofPersister) Reloader() (io.ReadCloser, error) {
76 | file, err := os.Open(a.aofFileName)
77 | if err != nil {
78 | return nil, err
79 | }
80 | _, _ = file.Seek(0, io.SeekStart)
81 | return file, nil
82 | }
83 |
84 | func (a *aofPersister) PersistCmd(ctx context.Context, cmd [][]byte) {
85 | if handler.IsLoadingPattern(ctx) {
86 | return
87 | }
88 | a.buffer <- cmd
89 | }
90 |
91 | func (a *aofPersister) Close() {
92 | a.once.Do(func() {
93 | a.cancel()
94 | _ = a.aofFile.Close()
95 | })
96 | }
97 |
98 | func (a *aofPersister) run() {
99 | if a.appendFsync == everysecAppendSyncStrategy {
100 | pool.Submit(a.fsyncEverySecond)
101 | }
102 |
103 | for {
104 | select {
105 | case <-a.ctx.Done():
106 | // log
107 | return
108 | case cmd := <-a.buffer:
109 | a.writeAof(cmd)
110 | a.aofTick()
111 | }
112 | }
113 | }
114 |
115 | // 记录执行的 aof 指令次数
116 | func (a *aofPersister) aofTick() {
117 | if a.autoAofRewriteAfterCmd <= 1 {
118 | return
119 | }
120 |
121 | if ticked := a.aofCounter.Add(1); ticked < int64(a.autoAofRewriteAfterCmd) {
122 | return
123 | }
124 |
125 | // 达到重写次数,扣减计数器,进行重写
126 | _ = a.aofCounter.Add(-a.autoAofRewriteAfterCmd)
127 | pool.Submit(func() {
128 | if err := a.rewriteAOF(); err != nil {
129 | // log
130 | }
131 | })
132 | }
133 |
134 | func (a *aofPersister) fsyncEverySecond() {
135 | ticker := time.NewTicker(time.Second)
136 | for {
137 | select {
138 | case <-a.ctx.Done():
139 | // log
140 | return
141 | case <-ticker.C:
142 | if err := a.fsync(); err != nil {
143 | // log
144 | }
145 | }
146 | }
147 | }
148 |
149 | func (a *aofPersister) writeAof(cmd [][]byte) {
150 | a.mu.Lock()
151 | defer a.mu.Unlock()
152 |
153 | persistCmd := handler.NewMultiBulkReply(cmd)
154 | if _, err := a.aofFile.Write(persistCmd.ToBytes()); err != nil {
155 | // log
156 | return
157 | }
158 |
159 | if a.appendFsync != alwaysAppendSyncStrategy {
160 | return
161 | }
162 |
163 | if err := a.fsyncLocked(); err != nil {
164 | // log
165 | }
166 | }
167 |
168 | func (a *aofPersister) fsync() error {
169 | a.mu.Lock()
170 | defer a.mu.Unlock()
171 | return a.fsyncLocked()
172 | }
173 |
174 | func (a *aofPersister) fsyncLocked() error {
175 | return a.aofFile.Sync()
176 | }
177 |
--------------------------------------------------------------------------------
/protocol/parser.go:
--------------------------------------------------------------------------------
1 | package protocol
2 |
3 | import (
4 | "bufio"
5 | "bytes"
6 | "io"
7 | "strconv"
8 |
9 | "github.com/xiaoxuxiansheng/goredis/handler"
10 | "github.com/xiaoxuxiansheng/goredis/lib/pool"
11 | "github.com/xiaoxuxiansheng/goredis/log"
12 | )
13 |
14 | type lineParser func(header []byte, reader *bufio.Reader) *handler.Droplet
15 |
// Parser reads protocol elements from a stream and turns each of them into a
// handler.Droplet.
type Parser struct {
	// lineParsers maps the first byte of a line to the parser for that element type.
	lineParsers map[byte]lineParser
	// logger receives diagnostics about lines that cannot be dispatched.
	logger log.Logger
}
20 |
21 | func NewParser(logger log.Logger) handler.Parser {
22 | p := Parser{
23 | logger: logger,
24 | }
25 | p.lineParsers = map[byte]lineParser{
26 | '+': p.parseSimpleString,
27 | '-': p.parseError,
28 | ':': p.parseInt,
29 | '$': p.parseBulk,
30 | '*': p.parseMultiBulk,
31 | }
32 | return &p
33 | }
34 |
35 | func (p *Parser) ParseStream(reader io.Reader) <-chan *handler.Droplet {
36 | ch := make(chan *handler.Droplet)
37 | pool.Submit(
38 | func() {
39 | p.parse(reader, ch)
40 | })
41 | return ch
42 | }
43 |
44 | func (p *Parser) parse(rawReader io.Reader, ch chan<- *handler.Droplet) {
45 | reader := bufio.NewReader(rawReader)
46 | for {
47 | firstLine, err := reader.ReadBytes('\n')
48 | if err != nil {
49 | ch <- &handler.Droplet{
50 | Reply: handler.NewErrReply(err.Error()),
51 | Err: err,
52 | }
53 | return
54 | }
55 |
56 | length := len(firstLine)
57 | if length <= 2 || firstLine[length-1] != '\n' || firstLine[length-2] != '\r' {
58 | continue
59 | }
60 |
61 | firstLine = bytes.TrimSuffix(firstLine, []byte{'\r', '\n'})
62 | lineParseFunc, ok := p.lineParsers[firstLine[0]]
63 | if !ok {
64 | p.logger.Errorf("[parser] invalid line handler: %s", firstLine[0])
65 | continue
66 | }
67 |
68 | ch <- lineParseFunc(firstLine, reader)
69 | }
70 | }
71 |
72 | // 解析简单 string 类型
73 | func (p *Parser) parseSimpleString(header []byte, reader *bufio.Reader) *handler.Droplet {
74 | content := header[1:]
75 | return &handler.Droplet{
76 | Reply: handler.NewSimpleStringReply(string(content)),
77 | }
78 | }
79 |
80 | // 解析简单 int 类型
81 | func (p *Parser) parseInt(header []byte, reader *bufio.Reader) *handler.Droplet {
82 |
83 | i, err := strconv.ParseInt(string(header[1:]), 10, 64)
84 | if err != nil {
85 | return &handler.Droplet{
86 | Err: err,
87 | Reply: handler.NewErrReply(err.Error()),
88 | }
89 | }
90 |
91 | return &handler.Droplet{
92 | Reply: handler.NewIntReply(i),
93 | }
94 | }
95 |
96 | // 解析错误类型
97 | func (p *Parser) parseError(header []byte, reader *bufio.Reader) *handler.Droplet {
98 | return &handler.Droplet{
99 | Reply: handler.NewErrReply(string(header[1:])),
100 | }
101 | }
102 |
103 | // 解析定长 string 类型
104 | func (p *Parser) parseBulk(header []byte, reader *bufio.Reader) *handler.Droplet {
105 | // 解析定长 string
106 | body, err := p.parseBulkBody(header, reader)
107 | if err != nil {
108 | return &handler.Droplet{
109 | Reply: handler.NewErrReply(err.Error()),
110 | Err: err,
111 | }
112 | }
113 | return &handler.Droplet{
114 | Reply: handler.NewBulkReply(body),
115 | }
116 | }
117 |
118 | // 解析定长 string
119 | func (p *Parser) parseBulkBody(header []byte, reader *bufio.Reader) ([]byte, error) {
120 | // 获取 string 长度
121 | strLen, err := strconv.ParseInt(string(header[1:]), 10, 64)
122 | if err != nil {
123 | return nil, err
124 | }
125 |
126 | // 长度 + 2,把 CRLF 也考虑在内
127 | body := make([]byte, strLen+2)
128 | // 从 reader 中读取对应长度
129 | if _, err = io.ReadFull(reader, body); err != nil {
130 | return nil, err
131 | }
132 | return body[:len(body)-2], nil
133 | }
134 |
135 | // 解析
136 | func (p *Parser) parseMultiBulk(header []byte, reader *bufio.Reader) (droplet *handler.Droplet) {
137 | var _err error
138 | defer func() {
139 | if _err != nil {
140 | droplet = &handler.Droplet{
141 | Reply: handler.NewErrReply(_err.Error()),
142 | Err: _err,
143 | }
144 | }
145 | }()
146 |
147 | // 获取数组长度
148 | length, err := strconv.ParseInt(string(header[1:]), 10, 64)
149 | if err != nil {
150 | _err = err
151 | return
152 | }
153 |
154 | if length <= 0 {
155 | return &handler.Droplet{
156 | Reply: handler.NewEmptyMultiBulkReply(),
157 | }
158 | }
159 |
160 | lines := make([][]byte, 0, length)
161 | for i := int64(0); i < length; i++ {
162 | // 获取每个 bulk 首行
163 | firstLine, err := reader.ReadBytes('\n')
164 | if err != nil {
165 | _err = err
166 | return
167 | }
168 |
169 | // bulk 首行格式校验
170 | length := len(firstLine)
171 | if length < 4 || firstLine[length-2] != '\r' || firstLine[length-1] != '\n' || firstLine[0] != '$' {
172 | continue
173 | }
174 |
175 | // bulk 解析
176 | bulkBody, err := p.parseBulkBody(firstLine[:length-2], reader)
177 | if err != nil {
178 | _err = err
179 | return
180 | }
181 |
182 | lines = append(lines, bulkBody)
183 | }
184 |
185 | return &handler.Droplet{
186 | Reply: handler.NewMultiBulkReply(lines),
187 | }
188 | }
189 |
--------------------------------------------------------------------------------
/datastore/sorted_set.go:
--------------------------------------------------------------------------------
1 | package datastore
2 |
3 | import (
4 | "math"
5 | "math/rand"
6 | "strconv"
7 |
8 | "github.com/xiaoxuxiansheng/goredis/database"
9 | "github.com/xiaoxuxiansheng/goredis/handler"
10 | "github.com/xiaoxuxiansheng/goredis/lib"
11 | )
12 |
13 | func (k *KVStore) getAsSortedSet(key string) (SortedSet, error) {
14 | v, ok := k.data[key]
15 | if !ok {
16 | return nil, nil
17 | }
18 |
19 | zset, ok := v.(SortedSet)
20 | if !ok {
21 | return nil, handler.NewWrongTypeErrReply()
22 | }
23 |
24 | return zset, nil
25 | }
26 |
// putAsSortedSet stores zset under key in k.data, replacing any value
// previously stored there.
func (k *KVStore) putAsSortedSet(key string, zset SortedSet) {
	k.data[key] = zset
}
30 |
// SortedSet is a collection of string members, each associated with an int64
// score. The behavior notes below describe the skiplist implementation in this
// file.
type SortedSet interface {
	// Add associates member with score, moving it if it already has a different score.
	Add(score int64, member string)
	// Rem removes member and returns 1, or returns 0 when member is not present.
	Rem(member string) int64
	// Range returns the members whose score lies in [score1, score2];
	// a score2 of -1 means "no upper bound".
	Range(score1, score2 int64) []string
	database.CmdAdapter
}
37 |
// skiplist is a skip-list backed SortedSet. Each node represents one distinct
// score and holds every member that currently has that score.
type skiplist struct {
	// key is the name the set is stored under; ToCmd emits it as the command key.
	key string
	// scoreToNode indexes the node of every score present in the list.
	scoreToNode map[int64]*skipnode
	// memberToScore records the current score of every member.
	memberToScore map[string]int64
	// head is the sentinel node; head.nexts[i] is the first node on level i.
	head *skipnode
	// rander is the random source used by roll to pick node heights.
	rander *rand.Rand
}
45 |
46 | func newSkiplist(key string) SortedSet {
47 | return &skiplist{
48 | key: key,
49 | memberToScore: make(map[string]int64),
50 | scoreToNode: make(map[int64]*skipnode),
51 | head: newSkipnode(0, 0),
52 | rander: rand.New((rand.NewSource(lib.TimeNow().UnixNano()))),
53 | }
54 | }
55 |
56 | func (s *skiplist) Add(score int64, member string) {
57 | // 之前存在,需要删除
58 | oldScore, ok := s.memberToScore[member]
59 | if ok {
60 | if oldScore == score {
61 | return
62 | }
63 | s.rem(oldScore, member)
64 | }
65 |
66 | s.memberToScore[member] = score
67 | node, ok := s.scoreToNode[score]
68 | if ok {
69 | node.members[member] = struct{}{}
70 | return
71 | }
72 |
73 | // 新插入,roll 出高度
74 | height := s.roll()
75 | for int64(len(s.head.nexts)) < height+1 {
76 | s.head.nexts = append(s.head.nexts, nil)
77 | }
78 |
79 | inserted := newSkipnode(score, height+1)
80 | inserted.members[member] = struct{}{}
81 | s.scoreToNode[score] = inserted
82 |
83 | move := s.head
84 | for i := height; i >= 0; i-- {
85 | for move.nexts[i] != nil && move.nexts[i].score < score {
86 | move = move.nexts[i]
87 | continue
88 | }
89 |
90 | inserted.nexts[i] = move.nexts[i]
91 | move.nexts[i] = inserted
92 | }
93 | }
94 |
95 | func (s *skiplist) Rem(member string) int64 {
96 | // 之前存在,需要删除
97 | score, ok := s.memberToScore[member]
98 | if !ok {
99 | return 0
100 | }
101 | s.rem(score, member)
102 | return 1
103 | }
104 |
105 | // [score1,score2]
106 | func (s *skiplist) Range(score1, score2 int64) []string {
107 | if score2 == -1 {
108 | score2 = math.MaxInt64
109 | }
110 |
111 | if score1 > score2 {
112 | return []string{}
113 | }
114 |
115 | move := s.head
116 | for i := len(s.head.nexts) - 1; i >= 0; i-- {
117 | for move.nexts[i] != nil && move.nexts[i].score < score1 {
118 | move = move.nexts[i]
119 | }
120 | }
121 |
122 | // 来到了 level0 层,move.nexts[i] 如果存在,就是首个 >= score1 的元素
123 | if len(move.nexts) == 0 || move.nexts[0] == nil {
124 | return []string{}
125 | }
126 |
127 | res := []string{}
128 | for move.nexts[0] != nil && move.nexts[0].score >= score1 && move.nexts[0].score <= score2 {
129 | for member := range move.nexts[0].members {
130 | res = append(res, member)
131 | }
132 | move = move.nexts[0]
133 | }
134 | return res
135 | }
136 |
137 | func (s *skiplist) roll() int64 {
138 | var level int64
139 | for s.rander.Intn(2) > 0 {
140 | level++
141 | }
142 | return level
143 | }
144 |
145 | func (s *skiplist) rem(score int64, member string) {
146 | delete(s.memberToScore, member)
147 | skipnode := s.scoreToNode[score]
148 |
149 | delete(skipnode.members, member)
150 | if len(skipnode.members) > 0 {
151 | return
152 | }
153 |
154 | delete(s.scoreToNode, score)
155 | move := s.head
156 | for i := len(s.head.nexts) - 1; i >= 0; i-- {
157 | for move.nexts[i] != nil && move.nexts[i].score < score {
158 | move = move.nexts[i]
159 | }
160 |
161 | if move.nexts[i] == nil || move.nexts[i].score > score {
162 | continue
163 | }
164 |
165 | remed := move.nexts[i]
166 | move.nexts[i] = move.nexts[i].nexts[i]
167 | remed.nexts[i] = nil
168 | }
169 | }
170 |
171 | func (s *skiplist) ToCmd() [][]byte {
172 | args := make([][]byte, 0, 2+2*len(s.memberToScore))
173 | args = append(args, []byte(database.CmdTypeZAdd), []byte(s.key))
174 | for member, score := range s.memberToScore {
175 | scoreStr := strconv.FormatInt(score, 10)
176 | args = append(args, []byte(scoreStr), []byte(member))
177 | }
178 | return args
179 | }
180 |
// skipnode is one node of the skiplist; it represents a single score.
type skipnode struct {
	// score is the score shared by every member of this node.
	score int64
	// members is the set of members that currently have this score.
	members map[string]struct{}
	// nexts[i] is the next node on level i, or nil at the end of that level.
	nexts []*skipnode
}
186 |
187 | func newSkipnode(score, height int64) *skipnode {
188 | return &skipnode{
189 | score: score,
190 | members: make(map[string]struct{}),
191 | nexts: make([]*skipnode, height),
192 | }
193 | }
194 |
--------------------------------------------------------------------------------
/goredis_test.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "bufio"
5 | "fmt"
6 | "io"
7 | "math/rand"
8 | "net"
9 | "os"
10 | "sync"
11 | "testing"
12 | "time"
13 |
14 | "github.com/spf13/cast"
15 | "github.com/stretchr/testify/assert"
16 | "github.com/xiaoxuxiansheng/goredis/app"
17 | "github.com/xiaoxuxiansheng/goredis/lib"
18 | "github.com/xiaoxuxiansheng/goredis/lib/pool"
19 | )
20 |
// QualityInspector is the end-to-end test driver for goredis: it starts the
// application, connects to it over TCP and issues SET/GET commands.
type QualityInspector struct {
	// times is the number of distinct keys exercised per run.
	times int
	// app is the application under test; it is set by prepareApp.
	app *app.Application
	// t is the test that owns this inspector.
	t *testing.T
	// rander is a time-seeded random source.
	rander *rand.Rand
}
28 |
29 | func NewQualityInspector(t *testing.T, times int) *QualityInspector {
30 | return &QualityInspector{
31 | t: t,
32 | times: times,
33 | rander: rand.New(rand.NewSource(lib.TimeNow().UnixNano())),
34 | }
35 | }
36 |
37 | func (q *QualityInspector) prepareApp(clean bool) error {
38 | // 1 移除 aof 文件,避免读取脏数据
39 | if clean {
40 | _ = os.Remove("./appendonly.aof")
41 | }
42 | // 1 创建应用
43 | server, err := app.ConstructServer()
44 | if err != nil {
45 | return err
46 | }
47 | q.app = app.NewApplication(server, app.SetUpConfig())
48 | // 2 异步启动应用
49 | pool.Submit(func() {
50 | if err := q.app.Run(); err != nil {
51 | q.t.Error(err)
52 | }
53 | })
54 | return nil
55 | }
56 |
57 | func (q *QualityInspector) connApp() (*net.TCPConn, error) {
58 | <-time.After(100 * time.Millisecond)
59 | // 建立 tcp 连接
60 | return net.DialTCP("tcp", nil, &net.TCPAddr{
61 | IP: net.IPv4(127, 0, 0, 1),
62 | Port: 6379,
63 | })
64 | }
65 |
66 | func (q *QualityInspector) execSet(w io.Writer) {
67 | writer := bufio.NewWriter(w)
68 | for i := 0; i < 2*q.times; i++ {
69 | k := cast.ToString(i % q.times)
70 | v := cast.ToString(i % q.times)
71 | _, _ = writer.WriteString("*3\r\n")
72 | _, _ = writer.WriteString("$3\r\n")
73 | _, _ = writer.WriteString("set\r\n")
74 | _, _ = writer.WriteString(fmt.Sprintf("$%d\r\n", len(k)))
75 | _, _ = writer.WriteString(fmt.Sprintf("%s\r\n", k))
76 | _, _ = writer.WriteString(fmt.Sprintf("$%d\r\n", len(v)))
77 | _, _ = writer.WriteString(fmt.Sprintf("%s\r\n", v))
78 | if err := writer.Flush(); err != nil {
79 | q.t.Error(err)
80 | }
81 | q.t.Logf("set k: %s, v: %s", k, v)
82 | }
83 | }
84 |
85 | func (q *QualityInspector) readSetResp(r io.Reader) {
86 | reader := bufio.NewReader(r)
87 | for i := 0; i < q.times; i++ {
88 | line, _, _ := reader.ReadLine()
89 | q.t.Logf("set resp: %s\n", line)
90 | }
91 | }
92 |
93 | func (q *QualityInspector) execGet(w io.Writer) {
94 | writer := bufio.NewWriter(w)
95 | for i := 0; i < q.times; i++ {
96 | k := cast.ToString(i)
97 | _, _ = writer.WriteString("*2\r\n")
98 | _, _ = writer.WriteString("$3\r\n")
99 | _, _ = writer.WriteString("get\r\n")
100 | _, _ = writer.WriteString(fmt.Sprintf("$%d\r\n", len(k)))
101 | _, _ = writer.WriteString(fmt.Sprintf("%s\r\n", k))
102 | if err := writer.Flush(); err != nil {
103 | q.t.Error(err)
104 | }
105 | q.t.Logf("get k: %s", k)
106 | }
107 | }
108 |
109 | func (q *QualityInspector) readGetResp(r io.Reader) {
110 | reader := bufio.NewReader(r)
111 | for i := 0; i < q.times; i++ {
112 | _, _, _ = reader.ReadLine() // $n
113 | line, _, _ := reader.ReadLine()
114 | q.t.Logf("get resp: %s\n", line)
115 | assert.Equal(q.t, cast.ToString(i), string(line))
116 | }
117 | }
118 |
// Test_Goredis_Set_Get starts goredis without an existing AOF file, writes
// keys over one TCP connection with SET and then reads them back with GET on
// the same connection, checking the returned values in readGetResp.
func Test_Goredis_Set_Get(t *testing.T) {
	q := NewQualityInspector(t, 100)

	// 1. Start goredis (the AOF file is removed first).
	if err := q.prepareApp(true); err != nil {
		t.Error(err)
		return
	}
	defer q.app.Stop()

	// 2. Connect to goredis.
	conn, err := q.connApp()
	if err != nil {
		t.Error(err)
		return
	}
	defer conn.Close()

	var wg sync.WaitGroup
	// 3. Read the SET replies.
	wg.Add(1)
	pool.Submit(func() {
		defer wg.Done()
		q.readSetResp(conn)
	})

	// 4. Send the SET commands.
	wg.Add(1)
	pool.Submit(func() {
		defer wg.Done()
		q.execSet(conn)
	})

	// Wait for the SET phase to finish before starting the GET phase.
	wg.Wait()

	// 5. Read the GET replies.
	wg.Add(1)
	pool.Submit(func() {
		defer wg.Done()
		q.readGetResp(conn)
	})

	// 6. Send the GET commands.
	wg.Add(1)
	pool.Submit(func() {
		defer wg.Done()
		q.execGet(conn)
	})
	wg.Wait()

	// Pause for one second before the deferred Close/Stop run.
	<-time.After(time.Second)
}
171 |
// Test_Goredis_Set runs the SET-only scenario: start goredis, set the data,
// stop goredis.
func Test_Goredis_Set(t *testing.T) {
	test_goredis_set(t) // 1. start goredis  2. set data  3. stop goredis
}
175 |
// test_goredis_set starts goredis without an existing AOF file, writes keys
// with SET over one TCP connection while reading the replies, and then waits
// two seconds before the deferred Close/Stop run.
func test_goredis_set(t *testing.T) {
	q := NewQualityInspector(t, 100)

	// 1. Start goredis (the AOF file is removed first).
	if err := q.prepareApp(true); err != nil {
		t.Error(err)
		return
	}
	defer q.app.Stop()

	// 2. Connect to goredis.
	conn, err := q.connApp()
	if err != nil {
		t.Error(err)
		return
	}
	defer conn.Close()

	var wg sync.WaitGroup
	// 3. Read the SET replies.
	wg.Add(1)
	pool.Submit(func() {
		defer wg.Done()
		q.readSetResp(conn)
	})

	// 4. Send the SET commands.
	wg.Add(1)
	pool.Submit(func() {
		defer wg.Done()
		q.execSet(conn)
	})

	wg.Wait()
	// NOTE(review): presumably this pause gives the every-second AOF fsync
	// time to run before shutdown — confirm.
	<-time.After(2 * time.Second)
}
212 |
// Test_Aof_Get runs the AOF-recovery scenario: start goredis with the
// existing AOF file, read the data back with GET, stop goredis.
func Test_Aof_Get(t *testing.T) {
	test_goredis_aof_get(t) // 1. start goredis (data restored from the AOF)  2. get data  3. stop goredis
}
216 |
// test_goredis_aof_get starts goredis while keeping the existing AOF file and
// then reads keys "0".."99" back with GET, checking the values in
// readGetResp. It relies on an AOF file left behind by an earlier SET run.
func test_goredis_aof_get(t *testing.T) {
	q := NewQualityInspector(t, 100)

	<-time.After(time.Second)
	// 1. Start goredis without deleting the AOF file.
	if err := q.prepareApp(false); err != nil {
		t.Error(err)
		return
	}
	defer q.app.Stop()

	// 2. Connect to goredis after a one-second pause.
	<-time.After(time.Second)
	conn, err := q.connApp()
	if err != nil {
		t.Error(err)
		return
	}
	defer conn.Close()

	var wg sync.WaitGroup
	// 3. Read the GET replies.
	wg.Add(1)
	pool.Submit(func() {
		defer wg.Done()
		q.readGetResp(conn)
	})

	// 4. Send the GET commands.
	wg.Add(1)
	pool.Submit(func() {
		defer wg.Done()
		q.execGet(conn)
	})
	wg.Wait()
}
253 |
--------------------------------------------------------------------------------
/datastore/sorted_set_test.go:
--------------------------------------------------------------------------------
1 | package datastore
2 |
3 | import (
4 | "fmt"
5 | "math/rand"
6 | "sort"
7 | "strings"
8 | "testing"
9 |
10 | "github.com/spf13/cast"
11 | "github.com/stretchr/testify/assert"
12 | "github.com/xiaoxuxiansheng/goredis/database"
13 | "github.com/xiaoxuxiansheng/goredis/lib"
14 | )
15 |
// Test_skiplist_add_rem_range fills a skiplist with scores 0..999 (two members
// "<score>_0" and "<score>_1" per score), removes 1000 randomly chosen members
// and then checks Range against the expected survivors for single scores,
// bounded ranges and ranges with an open upper bound (-1).
func Test_skiplist_add_rem_range(t *testing.T) {
	skiplist := newSkiplist("")
	// Add 1000 scores with two members each.
	for i := 0; i < 1000; i++ {
		skiplist.Add(int64(i), fmt.Sprintf("%d_0", i))
		skiplist.Add(int64(i), fmt.Sprintf("%d_1", i))
	}

	// Remove 1000 randomly picked members (the same member may be picked twice).
	rander := rand.New(rand.NewSource(lib.TimeNow().UnixNano()))
	remSet := make(map[string]struct{}, 1000)
	for i := 0; i < 1000; i++ {
		score := rander.Intn(1000)
		index := rander.Intn(2)
		member := fmt.Sprintf("%d_%d", score, index)
		remSet[member] = struct{}{}
		skiplist.Rem(member)
	}

	t.Run("single_score", func(t *testing.T) {
		for i := 0; i < 100; i++ {
			score := int64(rander.Intn(1000))
			member := skiplist.Range(score, score)
			// Range gives no order within a score, so sort before comparing.
			sort.Slice(member, func(i, j int) bool {
				return member[i] < member[j]
			})
			expected := make([]string, 0, 2)
			member1 := fmt.Sprintf("%d_0", score)
			member2 := fmt.Sprintf("%d_1", score)
			if _, ok := remSet[member1]; !ok {
				expected = append(expected, member1)
			}
			if _, ok := remSet[member2]; !ok {
				expected = append(expected, member2)
			}

			assert.Equal(t, expected, member)
		}
	})

	t.Run("normal_score_range", func(t *testing.T) {
		for i := 0; i < 100; i++ {
			leftScore := int64(rander.Intn(501))
			rightScore := leftScore + int64(rander.Intn(500))
			member := skiplist.Range(leftScore, rightScore)
			// Order by numeric score first, then by the member suffix.
			sort.Slice(member, func(i, j int) bool {
				splitted1 := strings.Split(member[i], "_")
				splitted2 := strings.Split(member[j], "_")
				if splitted1[0] == splitted2[0] {
					return cast.ToInt(splitted1[1]) < cast.ToInt(splitted2[1])
				}
				return cast.ToInt(splitted1[0]) < cast.ToInt(splitted2[0])
			})

			expected := make([]string, 0, 2*(rightScore-leftScore+1))
			for j := leftScore; j <= rightScore; j++ {
				member1 := fmt.Sprintf("%d_0", j)
				member2 := fmt.Sprintf("%d_1", j)
				if _, ok := remSet[member1]; !ok {
					expected = append(expected, member1)
				}
				if _, ok := remSet[member2]; !ok {
					expected = append(expected, member2)
				}
			}
			assert.Equal(t, expected, member)
		}
	})

	t.Run("with_maximum_right_range", func(t *testing.T) {
		for i := 0; i < 100; i++ {
			leftScore := int64(rander.Intn(1000))
			// -1 asks Range for everything from leftScore upwards.
			rightScore := int64(-1)
			member := skiplist.Range(leftScore, rightScore)
			// Order by numeric score first, then by the member suffix.
			sort.Slice(member, func(i, j int) bool {
				splitted1 := strings.Split(member[i], "_")
				splitted2 := strings.Split(member[j], "_")
				if splitted1[0] == splitted2[0] {
					return cast.ToInt(splitted1[1]) < cast.ToInt(splitted2[1])
				}
				return cast.ToInt(splitted1[0]) < cast.ToInt(splitted2[0])
			})

			expected := make([]string, 0, 2*(1000-leftScore))
			for j := leftScore; j < 1000; j++ {
				member1 := fmt.Sprintf("%d_0", j)
				member2 := fmt.Sprintf("%d_1", j)
				if _, ok := remSet[member1]; !ok {
					expected = append(expected, member1)
				}
				if _, ok := remSet[member2]; !ok {
					expected = append(expected, member2)
				}
			}
			assert.Equal(t, expected, member)
		}
	})
}
114 |
// Test_skiplist_upsert_member_with_dif_score adds each member twice with two
// random scores and checks that the member is found under its latest score
// and no longer under the previous one. Every member is named after its first
// score, which is how the old score is recovered later.
func Test_skiplist_upsert_member_with_dif_score(t *testing.T) {
	skiplist := newSkiplist("")
	rander := rand.New(rand.NewSource(lib.TimeNow().UnixNano()))
	scoreToMembers := make(map[int64][]string)
	memberSet := make(map[string]struct{})
	for i := 0; i < 1000; i++ {
		score1 := int64(rander.Intn(1000))
		member := cast.ToString(score1)
		// Use every member only once.
		if _, ok := memberSet[member]; ok {
			continue
		}
		memberSet[member] = struct{}{}
		skiplist.Add(score1, member)
		score2 := int64(rander.Intn(1000))
		skiplist.Add(score2, member)
		scoreToMembers[score2] = append(scoreToMembers[score2], member)
	}

	t.Run("score_to_members", func(t *testing.T) {
		for score, members := range scoreToMembers {
			sort.Slice(members, func(i, j int) bool {
				return cast.ToInt(members[i]) < cast.ToInt(members[j])
			})

			actualMembers := skiplist.Range(score, score)
			sort.Slice(actualMembers, func(i, j int) bool {
				return cast.ToInt(actualMembers[i]) < cast.ToInt(actualMembers[j])
			})

			assert.Equal(t, members, actualMembers)

			// A member must not be returned for its previous score any more.
			for _, member := range members {
				oldScore := cast.ToInt64(member)
				if oldScore == score {
					continue
				}
				for _, gotMember := range skiplist.Range(oldScore, oldScore) {
					if gotMember == member {
						t.Errorf("old score: %d, members: %s", oldScore, gotMember)
					}
				}
			}
		}
	})
}
161 |
// Test_skiplist_to_cmd inserts random (score, member) pairs and checks that
// ToCmd yields a ZADD command with the expected length, command name, key and
// exactly the latest score of every member.
func Test_skiplist_to_cmd(t *testing.T) {
	skiplist := newSkiplist("")

	rander := rand.New(rand.NewSource(lib.TimeNow().UnixNano()))
	memberToScore := make(map[int]int, 1000)
	// Perform 1000 inserts; a repeated member keeps its latest score.
	for i := 0; i < 1000; i++ {
		score := rander.Intn(1000)
		member := rander.Intn(1000)
		skiplist.Add(int64(score), cast.ToString(member))
		memberToScore[member] = score
	}

	cmd := skiplist.ToCmd()
	t.Run("length", func(t *testing.T) {
		assert.Equal(t, 2*len(memberToScore)+2, len(cmd))
	})
	t.Run("command", func(t *testing.T) {
		assert.Equal(t, database.CmdTypeZAdd, database.CmdType(cmd[0]))
	})
	t.Run("key", func(t *testing.T) {
		assert.Equal(t, "", string(cmd[1]))
	})

	type scoreToMember struct {
		score, member int
	}
	// Decode the "score member" pairs that follow the command name and key.
	actual := make([]scoreToMember, 0, 1000)
	for i := 2; i < len(cmd); i += 2 {
		actual = append(actual, scoreToMember{
			score:  cast.ToInt(string(cmd[i])),
			member: cast.ToInt(string(cmd[i+1])),
		})
	}

	// ToCmd iterates a map, so sort both sides before comparing.
	sort.Slice(actual, func(i, j int) bool {
		if actual[i].score == actual[j].score {
			return actual[i].member < actual[j].member
		}
		return actual[i].score < actual[j].score
	})

	expect := make([]scoreToMember, 0, 2*len(memberToScore))
	for member, score := range memberToScore {
		expect = append(expect, scoreToMember{
			score:  score,
			member: member,
		})
	}
	sort.Slice(expect, func(i, j int) bool {
		if expect[i].score == expect[j].score {
			return expect[i].member < expect[j].member
		}
		return expect[i].score < expect[j].score
	})

	t.Run("member", func(t *testing.T) {
		assert.Equal(t, expect, actual)
	})
}
222 |
--------------------------------------------------------------------------------
/datastore/kv_store.go:
--------------------------------------------------------------------------------
1 | package datastore
2 |
3 | import (
4 | "context"
5 | "strconv"
6 | "strings"
7 | "time"
8 |
9 | "github.com/xiaoxuxiansheng/goredis/database"
10 | "github.com/xiaoxuxiansheng/goredis/handler"
11 | "github.com/xiaoxuxiansheng/goredis/lib"
12 | )
13 |
// KVStore is the in-memory key-value store behind the command handlers in
// this file.
type KVStore struct {
	// data maps a key to its value; the concrete value type depends on the
	// data structure stored under the key.
	data map[string]interface{}
	// expiredAt holds a time per key.
	// NOTE(review): presumably the key's expiration deadline — confirm in expire.go.
	expiredAt map[string]time.Time

	// expireTimeWheel is a sorted set created in NewKVStore.
	// NOTE(review): presumably orders keys by expiration time — confirm in expire.go.
	expireTimeWheel SortedSet

	// persister receives every command that must be persisted.
	persister handler.Persister
}
22 |
23 | func NewKVStore(persister handler.Persister) database.DataStore {
24 | return &KVStore{
25 | data: make(map[string]interface{}),
26 | expiredAt: make(map[string]time.Time),
27 | expireTimeWheel: newSkiplist("expireTimeWheel"),
28 | persister: persister,
29 | }
30 | }
31 |
32 | // expire
33 | func (k *KVStore) Expire(cmd *database.Command) handler.Reply {
34 | args := cmd.Args()
35 | key := string(args[0])
36 | ttl, err := strconv.ParseInt(string(args[1]), 10, 64)
37 | if err != nil {
38 | return handler.NewSyntaxErrReply()
39 | }
40 | if ttl <= 0 {
41 | return handler.NewErrReply("ERR invalid expire time")
42 | }
43 |
44 | expireAt := lib.TimeNow().Add(time.Duration(ttl) * time.Second)
45 | _cmd := [][]byte{[]byte(database.CmdTypeExpireAt), []byte(key), []byte(lib.TimeSecondFormat(expireAt))}
46 | return k.expireAt(cmd.Ctx(), _cmd, key, expireAt)
47 | }
48 |
49 | func (k *KVStore) ExpireAt(cmd *database.Command) handler.Reply {
50 | args := cmd.Args()
51 | key := string(args[0])
52 | expiredAt, err := lib.ParseTimeSecondFormat(string((args[1])))
53 | if err != nil {
54 | return handler.NewSyntaxErrReply()
55 | }
56 | if expiredAt.Before(lib.TimeNow()) {
57 | return handler.NewErrReply("ERR invalid expire time")
58 | }
59 |
60 | return k.expireAt(cmd.Ctx(), cmd.Cmd(), key, expiredAt)
61 | }
62 |
// expireAt applies the expiration time expireAt to key via k.expire, hands cmd
// to the persister and returns an OK reply.
func (k *KVStore) expireAt(ctx context.Context, cmd [][]byte, key string, expireAt time.Time) handler.Reply {
	k.expire(key, expireAt)
	k.persister.PersistCmd(ctx, cmd) // persist the command
	return handler.NewOKReply()
}
68 |
69 | // string
70 | func (k *KVStore) Get(cmd *database.Command) handler.Reply {
71 | args := cmd.Args()
72 | key := string(args[0])
73 | v, err := k.getAsString(key)
74 | if err != nil {
75 | return handler.NewErrReply(err.Error())
76 | }
77 | if v == nil {
78 | return handler.NewNillReply()
79 | }
80 | return handler.NewBulkReply(v.Bytes())
81 | }
82 |
83 | func (k *KVStore) MGet(cmd *database.Command) handler.Reply {
84 | args := cmd.Args()
85 | res := make([][]byte, 0, len(args))
86 | for _, arg := range args {
87 | v, err := k.getAsString(string(arg))
88 | if err != nil {
89 | return handler.NewErrReply(err.Error())
90 | }
91 | if v == nil {
92 | res = append(res, []byte("(nil)"))
93 | continue
94 | }
95 | res = append(res, v.Bytes())
96 | }
97 |
98 | return handler.NewMultiBulkReply(res)
99 | }
100 |
101 | func (k *KVStore) Set(cmd *database.Command) handler.Reply {
102 | args := cmd.Args()
103 | key := string(args[0])
104 | value := string(args[1])
105 |
106 | // 支持 NX EX
107 | var (
108 | insertStrategy bool
109 | ttlStrategy bool
110 | ttlSeconds int64
111 | ttlIndex = -1
112 | )
113 |
114 | for i := 2; i < len(args); i++ {
115 | flag := strings.ToLower(string(args[i]))
116 | switch flag {
117 | case "nx":
118 | insertStrategy = true
119 | case "ex":
120 | // 重复的 ex 指令
121 | if ttlStrategy {
122 | return handler.NewSyntaxErrReply()
123 | }
124 | if i == len(args)-1 {
125 | return handler.NewSyntaxErrReply()
126 | }
127 | ttl, err := strconv.ParseInt(string(args[i+1]), 10, 64)
128 | if err != nil {
129 | return handler.NewSyntaxErrReply()
130 | }
131 | if ttl <= 0 {
132 | return handler.NewErrReply("ERR invalid expire time")
133 | }
134 |
135 | ttlStrategy = true
136 | ttlSeconds = ttl
137 | ttlIndex = i
138 | i++
139 | default:
140 | return handler.NewSyntaxErrReply()
141 | }
142 | }
143 |
144 | // 将 args 剔除 ex 部分,进行持久化
145 | if ttlIndex != -1 {
146 | args = append(args[:ttlIndex], args[ttlIndex+2:]...)
147 | }
148 |
149 | // 设置
150 | affected := k.put(key, value, insertStrategy)
151 | if affected > 0 && ttlStrategy {
152 | expireAt := lib.TimeNow().Add(time.Duration(ttlSeconds) * time.Second)
153 | _cmd := [][]byte{[]byte(database.CmdTypeExpireAt), []byte(key), []byte(lib.TimeSecondFormat(expireAt))}
154 | _ = k.expireAt(cmd.Ctx(), _cmd, key, expireAt) // 其中会完成 ex 信息的持久化
155 | }
156 |
157 | // 过期时间处理
158 | if affected > 0 {
159 | k.persister.PersistCmd(cmd.Ctx(), append([][]byte{[]byte(database.CmdTypeSet)}, args...))
160 | return handler.NewIntReply(affected)
161 | }
162 |
163 | return handler.NewNillReply()
164 | }
165 |
166 | func (k *KVStore) MSet(cmd *database.Command) handler.Reply {
167 | args := cmd.Args()
168 | if len(args)&1 == 1 {
169 | return handler.NewSyntaxErrReply()
170 | }
171 |
172 | for i := 0; i < len(args); i += 2 {
173 | _ = k.put(string(args[i]), string(args[i+1]), false)
174 | }
175 |
176 | k.persister.PersistCmd(cmd.Ctx(), cmd.Cmd())
177 | return handler.NewIntReply(int64(len(args) >> 1))
178 | }
179 |
180 | // list
181 | func (k *KVStore) LPush(cmd *database.Command) handler.Reply {
182 | args := cmd.Args()
183 | key := string(args[0])
184 | list, err := k.getAsList(key)
185 | if err != nil {
186 | return handler.NewErrReply(err.Error())
187 | }
188 |
189 | if list == nil {
190 | list = newListEntity(key)
191 | k.putAsList(key, list)
192 | }
193 |
194 | for i := 1; i < len(args); i++ {
195 | list.LPush(args[i])
196 | }
197 |
198 | k.persister.PersistCmd(cmd.Ctx(), cmd.Cmd())
199 | return handler.NewIntReply(list.Len())
200 | }
201 |
202 | func (k *KVStore) LPop(cmd *database.Command) handler.Reply {
203 | args := cmd.Args()
204 | key := string(args[0])
205 | var cnt int64
206 | if len(args) > 1 {
207 | rawCnt, err := strconv.ParseInt(string(args[1]), 10, 64)
208 | if err != nil {
209 | return handler.NewSyntaxErrReply()
210 | }
211 | if rawCnt < 1 {
212 | return handler.NewSyntaxErrReply()
213 | }
214 | cnt = rawCnt
215 | }
216 |
217 | list, err := k.getAsList(key)
218 | if err != nil {
219 | return handler.NewErrReply(err.Error())
220 | }
221 |
222 | if list == nil {
223 | return handler.NewNillReply()
224 | }
225 |
226 | if cnt == 0 {
227 | cnt = 1
228 | }
229 |
230 | poped := list.LPop(cnt)
231 | if poped == nil {
232 | return handler.NewNillReply()
233 | }
234 |
235 | k.persister.PersistCmd(cmd.Ctx(), cmd.Cmd()) // 持久化
236 |
237 | if len(poped) == 1 {
238 | return handler.NewBulkReply(poped[0])
239 | }
240 |
241 | return handler.NewMultiBulkReply(poped)
242 | }
243 |
244 | func (k *KVStore) RPush(cmd *database.Command) handler.Reply {
245 | args := cmd.Args()
246 | key := string(args[0])
247 | list, err := k.getAsList(key)
248 | if err != nil {
249 | return handler.NewErrReply(err.Error())
250 | }
251 |
252 | if list == nil {
253 | list = newListEntity(key, args[1:]...)
254 | k.putAsList(key, list)
255 | return handler.NewIntReply(list.Len())
256 | }
257 |
258 | for i := 1; i < len(args); i++ {
259 | list.RPush(args[i])
260 | }
261 |
262 | k.persister.PersistCmd(cmd.Ctx(), cmd.Cmd()) // 持久化
263 | return handler.NewIntReply(list.Len())
264 | }
265 |
266 | func (k *KVStore) RPop(cmd *database.Command) handler.Reply {
267 | args := cmd.Args()
268 | key := string(args[0])
269 | var cnt int64
270 | if len(args) > 1 {
271 | rawCnt, err := strconv.ParseInt(string(args[1]), 10, 64)
272 | if err != nil {
273 | return handler.NewSyntaxErrReply()
274 | }
275 | if rawCnt < 1 {
276 | return handler.NewSyntaxErrReply()
277 | }
278 | cnt = rawCnt
279 | }
280 |
281 | list, err := k.getAsList(key)
282 | if err != nil {
283 | return handler.NewErrReply(err.Error())
284 | }
285 |
286 | if list == nil {
287 | return handler.NewNillReply()
288 | }
289 |
290 | if cnt == 0 {
291 | cnt = 1
292 | }
293 |
294 | poped := list.RPop(cnt)
295 | if poped == nil {
296 | return handler.NewNillReply()
297 | }
298 |
299 | k.persister.PersistCmd(cmd.Ctx(), cmd.Cmd()) // 持久化
300 | if len(poped) == 1 {
301 | return handler.NewBulkReply(poped[0])
302 | }
303 |
304 | return handler.NewMultiBulkReply(poped)
305 | }
306 |
307 | func (k *KVStore) LRange(cmd *database.Command) handler.Reply {
308 | args := cmd.Args()
309 | if len(args) != 3 {
310 | return handler.NewSyntaxErrReply()
311 | }
312 |
313 | key := string(args[0])
314 | start, err := strconv.ParseInt(string(args[1]), 10, 64)
315 | if err != nil {
316 | return handler.NewSyntaxErrReply()
317 | }
318 |
319 | stop, err := strconv.ParseInt(string(args[2]), 10, 64)
320 | if err != nil {
321 | return handler.NewSyntaxErrReply()
322 | }
323 |
324 | list, err := k.getAsList(key)
325 | if err != nil {
326 | return handler.NewErrReply(err.Error())
327 | }
328 |
329 | if list == nil {
330 | return handler.NewNillReply()
331 | }
332 |
333 | if got := list.Range(start, stop); got != nil {
334 | return handler.NewMultiBulkReply(got)
335 | }
336 |
337 | return handler.NewNillReply()
338 | }
339 |
340 | // set
341 | func (k *KVStore) SAdd(cmd *database.Command) handler.Reply {
342 | args := cmd.Args()
343 | key := string(args[0])
344 | set, err := k.getAsSet(key)
345 | if err != nil {
346 | return handler.NewErrReply(err.Error())
347 | }
348 |
349 | if set == nil {
350 | set = newSetEntity(key)
351 | k.putAsSet(key, set)
352 | }
353 |
354 | var added int64
355 | for _, arg := range args[1:] {
356 | added += set.Add(string(arg))
357 | }
358 |
359 | k.persister.PersistCmd(cmd.Ctx(), cmd.Cmd()) // 持久化
360 | return handler.NewIntReply(added)
361 | }
362 |
363 | func (k *KVStore) SIsMember(cmd *database.Command) handler.Reply {
364 | args := cmd.Args()
365 | if len(args) != 2 {
366 | return handler.NewSyntaxErrReply()
367 | }
368 |
369 | key := string(args[0])
370 | set, err := k.getAsSet(key)
371 | if err != nil {
372 | return handler.NewErrReply(err.Error())
373 | }
374 |
375 | if set == nil {
376 | return handler.NewIntReply(0)
377 | }
378 |
379 | return handler.NewIntReply(set.Exist(string(args[1])))
380 | }
381 |
382 | func (k *KVStore) SRem(cmd *database.Command) handler.Reply {
383 | args := cmd.Args()
384 | key := string(args[0])
385 | set, err := k.getAsSet(key)
386 | if err != nil {
387 | return handler.NewErrReply(err.Error())
388 | }
389 |
390 | if set == nil {
391 | return handler.NewIntReply(0)
392 | }
393 |
394 | var remed int64
395 | for _, arg := range args[1:] {
396 | remed += set.Rem(string(arg))
397 | }
398 |
399 | if remed > 0 {
400 | k.persister.PersistCmd(cmd.Ctx(), cmd.Cmd()) // 持久化
401 | }
402 | return handler.NewIntReply(remed)
403 | }
404 |
405 | // hash
406 | func (k *KVStore) HSet(cmd *database.Command) handler.Reply {
407 | args := cmd.Args()
408 | if len(args)&1 != 1 {
409 | return handler.NewSyntaxErrReply()
410 | }
411 |
412 | key := string(args[0])
413 | hmap, err := k.getAsHashMap(key)
414 | if err != nil {
415 | return handler.NewErrReply(err.Error())
416 | }
417 |
418 | if hmap == nil {
419 | hmap = newHashMapEntity(key)
420 | k.putAsHashMap(key, hmap)
421 | }
422 |
423 | for i := 0; i < len(args)-1; i += 2 {
424 | hkey := string(args[i+1])
425 | hvalue := args[i+2]
426 | hmap.Put(hkey, hvalue)
427 | }
428 |
429 | k.persister.PersistCmd(cmd.Ctx(), cmd.Cmd()) // 持久化
430 | return handler.NewIntReply(int64((len(args) - 1) >> 1))
431 | }
432 |
433 | func (k *KVStore) HGet(cmd *database.Command) handler.Reply {
434 | args := cmd.Args()
435 | key := string(args[0])
436 | hmap, err := k.getAsHashMap(key)
437 | if err != nil {
438 | return handler.NewErrReply(err.Error())
439 | }
440 |
441 | if hmap == nil {
442 | return handler.NewNillReply()
443 | }
444 |
445 | if v := hmap.Get(string(args[1])); v != nil {
446 | return handler.NewBulkReply(v)
447 | }
448 |
449 | return handler.NewNillReply()
450 | }
451 |
452 | func (k *KVStore) HDel(cmd *database.Command) handler.Reply {
453 | args := cmd.Args()
454 | key := string(args[0])
455 | hmap, err := k.getAsHashMap(key)
456 | if err != nil {
457 | return handler.NewErrReply(err.Error())
458 | }
459 |
460 | if hmap == nil {
461 | return handler.NewIntReply(0)
462 | }
463 |
464 | var remed int64
465 | for _, arg := range args[1:] {
466 | remed += hmap.Del(string(arg))
467 | }
468 |
469 | if remed > 0 {
470 | k.persister.PersistCmd(cmd.Ctx(), cmd.Cmd()) // 持久化
471 | }
472 | return handler.NewIntReply(remed)
473 | }
474 |
475 | // sorted set
476 | func (k *KVStore) ZAdd(cmd *database.Command) handler.Reply {
477 | args := cmd.Args()
478 | if len(args)&1 != 1 {
479 | return handler.NewSyntaxErrReply()
480 | }
481 |
482 | key := string(args[0])
483 | var (
484 | scores = make([]int64, 0, (len(args)-1)>>1)
485 | members = make([]string, 0, (len(args)-1)>>1)
486 | )
487 |
488 | for i := 0; i < len(args)-1; i += 2 {
489 | score, err := strconv.ParseInt(string(args[i+1]), 10, 64)
490 | if err != nil {
491 | return handler.NewSyntaxErrReply()
492 | }
493 |
494 | scores = append(scores, score)
495 | members = append(members, string(args[i+2]))
496 | }
497 |
498 | zset, err := k.getAsSortedSet(key)
499 | if err != nil {
500 | return handler.NewErrReply(err.Error())
501 | }
502 |
503 | if zset == nil {
504 | zset = newSkiplist(key)
505 | k.putAsSortedSet(key, zset)
506 | }
507 |
508 | for i := 0; i < len(scores); i++ {
509 | zset.Add(scores[i], members[i])
510 | }
511 |
512 | k.persister.PersistCmd(cmd.Ctx(), cmd.Cmd()) // 持久化
513 | return handler.NewIntReply(int64(len(scores)))
514 | }
515 |
516 | func (k *KVStore) ZRangeByScore(cmd *database.Command) handler.Reply {
517 | args := cmd.Args()
518 | if len(args) < 3 {
519 | return handler.NewSyntaxErrReply()
520 | }
521 |
522 | key := string(args[0])
523 | score1, err := strconv.ParseInt(string(args[1]), 10, 64)
524 | if err != nil {
525 | return handler.NewSyntaxErrReply()
526 | }
527 | score2, err := strconv.ParseInt(string(args[2]), 10, 64)
528 | if err != nil {
529 | return handler.NewSyntaxErrReply()
530 | }
531 |
532 | zset, err := k.getAsSortedSet(key)
533 | if err != nil {
534 | return handler.NewErrReply(err.Error())
535 | }
536 |
537 | if zset == nil {
538 | return handler.NewNillReply()
539 | }
540 |
541 | rawRes := zset.Range(score1, score2)
542 | if len(rawRes) == 0 {
543 | return handler.NewNillReply()
544 | }
545 |
546 | res := make([][]byte, 0, len(rawRes))
547 | for _, item := range rawRes {
548 | res = append(res, []byte(item))
549 | }
550 |
551 | return handler.NewMultiBulkReply(res)
552 | }
553 |
554 | func (k *KVStore) ZRem(cmd *database.Command) handler.Reply {
555 | args := cmd.Args()
556 | key := string(args[0])
557 | zset, err := k.getAsSortedSet(key)
558 | if err != nil {
559 | return handler.NewErrReply(err.Error())
560 | }
561 |
562 | if zset == nil {
563 | return handler.NewIntReply(0)
564 | }
565 |
566 | var remed int64
567 | for _, arg := range args {
568 | remed += zset.Rem(string(arg))
569 | }
570 |
571 | if remed > 0 {
572 | k.persister.PersistCmd(cmd.Ctx(), cmd.Cmd()) // 持久化
573 | }
574 | return handler.NewIntReply(remed)
575 | }
576 |
--------------------------------------------------------------------------------