├── .gitignore ├── README.md ├── abstract ├── connection.go └── engine.go ├── aof ├── aof.go ├── cmd.go └── rewrite.go ├── cluster ├── cluster.go ├── conn_pool.go ├── mset.go ├── relay.go ├── router.go ├── tcc.go ├── tcc_test.go ├── transaction.go └── utils.go ├── datastruct ├── dict │ ├── concurrent.go │ └── dict.go ├── list │ └── linkedlist.go └── sortedset │ ├── border.go │ ├── skiplist.go │ ├── skiplist_test.go │ └── sortedset.go ├── doc ├── 1.tcp服务 │ ├── image-1.png │ ├── image-2.png │ ├── image-3.png │ ├── image-4.png │ ├── image.png │ └── tcp服务.md ├── 10.对象池 │ ├── image-1.png │ ├── image.png │ └── pool.md ├── 11.分布式集群 │ ├── cluster.md │ ├── image-1.png │ ├── image-2.png │ └── image.png ├── 12.分布式事务TCC │ ├── image.png │ ├── tcc.md │ └── transaction.png ├── 2.Redis序列化协议 │ ├── RESP.md │ └── image.png ├── 3.内存数据库 │ ├── image-1.png │ ├── image-2.png │ ├── image.png │ └── 内存数据库.md ├── 4.延迟算法(时间轮) │ ├── image-1.png │ ├── image-2.png │ ├── image-3.png │ ├── image.png │ └── 时间轮.md ├── 5.持久化之AOF │ ├── aof.md │ ├── image-1.png │ ├── image-2.png │ ├── image-3.png │ └── image.png ├── 6.发布订阅 │ ├── image-1.png │ ├── image-2.png │ ├── image-3.png │ ├── image.png │ └── 发布订阅.md ├── 7.跳表的实现 │ ├── image-1.png │ ├── image-2.png │ ├── image-3.png │ ├── image-4.png │ ├── image-5.png │ ├── image.png │ └── skiplist.md ├── 8.pipeline客户端 │ ├── client.md │ └── image.png └── 9.事务 │ ├── image-1.png │ ├── image.png │ └── 事务.md ├── engine ├── commoncmd.go ├── database.go ├── engine.go ├── keys.go ├── payload │ └── payload.go ├── register.go ├── sortedset.go ├── string.go ├── systemcmd.go ├── transaction.go └── utils.go ├── go.mod ├── go.sum ├── image-1.png ├── image-2.png ├── image-3.png ├── image.png ├── main.go ├── pubhub └── pubhub.go ├── redis-cli.sh ├── redis-cluster0.sh ├── redis-cluster1.sh ├── redis-cluster2.sh ├── redis ├── client │ ├── client.go │ └── client_test.go ├── connection │ ├── conn.go │ └── virtualconn.go ├── handler.go ├── parser │ └── parser.go └── protocol │ ├── basic.go │ ├── bulk.go │ ├── errors.go │ └── interface.go ├── redis0.conf ├── redis1.conf ├── redis2.conf ├── tcpserver └── tcpserver.go ├── test.conf ├── test.sh ├── tool ├── conf │ ├── config.go │ └── config_test.go ├── consistenthash │ ├── consistenthash.go │ └── consistenthash_test.go ├── idgenerator │ ├── snowflake.go │ └── snowflake_test.go ├── locker │ ├── locker.go │ └── locker_test.go ├── logger │ ├── logger.go │ └── logger_test.go ├── pool │ ├── pool.go │ └── pool_test.go ├── timewheel │ ├── delay.go │ ├── delay_test.go │ └── timewheel.go ├── wait │ └── wait.go └── wildcard │ ├── wildcard.go │ └── wildcard_test.go └── utils ├── cmdline.go ├── const.go ├── hash.go ├── logo.go ├── path.go └── rand.go /.gitignore: -------------------------------------------------------------------------------- 1 | logs 2 | *.aof 3 | tmp 4 | data -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Golang实现自己的Redis 2 | 3 | 4 | 用 11篇文章实现一个遵循标准的Redis服务,姑且叫**EasyRedis**吧,希望通过文章将Redis掰开撕碎了让大家有更直观的理解,而不是仅仅停留在八股文的层面,而是非常爽的感觉,欢迎持续关注学习。 5 | 6 | ## 单机版 7 | - [x] [easyredis之TCP服务](https://github.com/gofish2020/easyredis/blob/main/doc/1.tcp%E6%9C%8D%E5%8A%A1/tcp%E6%9C%8D%E5%8A%A1.md) 8 | - [x] [easyredis之网络请求序列化协议(RESP)](https://github.com/gofish2020/easyredis/blob/main/doc/2.Redis%E5%BA%8F%E5%88%97%E5%8C%96%E5%8D%8F%E8%AE%AE/RESP.md) 9 | - [x] [easyredis之内存数据库](https://github.com/gofish2020/easyredis/blob/main/doc/3.%E5%86%85%E5%AD%98%E6%95%B0%E6%8D%AE%E5%BA%93/%E5%86%85%E5%AD%98%E6%95%B0%E6%8D%AE%E5%BA%93.md) 10 | - [x] [easyredis之过期时间 (时间轮实现)](https://github.com/gofish2020/easyredis/blob/main/doc/4.%E5%BB%B6%E8%BF%9F%E7%AE%97%E6%B3%95(%E6%97%B6%E9%97%B4%E8%BD%AE)/%E6%97%B6%E9%97%B4%E8%BD%AE.md) 11 | - [x] [easyredis之持久化 (AOF实现)](https://github.com/gofish2020/easyredis/blob/main/doc/5.%E6%8C%81%E4%B9%85%E5%8C%96%E4%B9%8BAOF/aof.md) 12 | - [x] [easyredis之发布订阅功能](https://github.com/gofish2020/easyredis/blob/main/doc/6.%E5%8F%91%E5%B8%83%E8%AE%A2%E9%98%85/%E5%8F%91%E5%B8%83%E8%AE%A2%E9%98%85.md) 13 | - [x] [easyredis之有序集合(跳表实现)](https://github.com/gofish2020/easyredis/blob/main/doc/7.%E8%B7%B3%E8%A1%A8%E7%9A%84%E5%AE%9E%E7%8E%B0/skiplist.md) 14 | - [x] [easyredis之 pipeline 客户端实现](https://github.com/gofish2020/easyredis/blob/main/doc/8.pipeline%E5%AE%A2%E6%88%B7%E7%AB%AF/client.md) 15 | - [x] [easyredis之事务(原子性/回滚)](https://github.com/gofish2020/easyredis/blob/main/doc/9.%E4%BA%8B%E5%8A%A1/%E4%BA%8B%E5%8A%A1.md) 16 | 17 | ## 分布式 18 | - [x] [easyredis之连接池](https://github.com/gofish2020/easyredis/blob/main/doc/10.%E5%AF%B9%E8%B1%A1%E6%B1%A0/pool.md) 19 | - [x] [easyredis之分布式集群存储](https://github.com/gofish2020/easyredis/blob/main/doc/11.%E5%88%86%E5%B8%83%E5%BC%8F%E9%9B%86%E7%BE%A4/cluster.md) 20 | 21 | ## 补充篇 22 | - [x] [分布式事务 TCC](https://github.com/gofish2020/easyredis/blob/main/doc/12.%E5%88%86%E5%B8%83%E5%BC%8F%E4%BA%8B%E5%8A%A1TCC/tcc.md) 23 | 24 | 25 | 26 | 27 | ## 使用说明 28 | 29 | ### 单机版 30 | - 使用`./test.sh`命令启动单机版服务端 31 | - 使用`./redis-cli.sh`命令启动官方端redis客户端,连接服务(需要你本机自己安装redis-cli并加入到环境变量中) 32 | 33 | 效果图如下: 34 | 启动服务端 35 | ![](image.png) 36 | 37 | 客户端连接: 38 | 39 | ![](image-1.png) 40 | 41 | 42 | ### 分布式 43 | 44 | - 使用`./redis-cluster0.sh` `./redis-cluster1.sh` `./redis-cluster2.sh`命令启动3个服务端 45 | - 使用`./redis-cli.sh`命令启动官方端redis客户端,连接服务(需要你本机自己安装redis-cli并加入到环境变量中) 46 | 47 | 效果图如下 48 | 启动服务端 49 | ![](image-2.png) 50 | 51 | 客户端连接: 52 | 53 | ![](image-3.png) -------------------------------------------------------------------------------- /abstract/connection.go: -------------------------------------------------------------------------------- 1 | package abstract 2 | 3 | type Connection interface { 4 | GetDBIndex() int 5 | SetDBIndex(int) 6 | SetPassword(string) 7 | GetPassword() string 8 | Write([]byte) (int, error) 9 | 10 | IsClosed() bool 11 | // pub/sub 12 | Subscribe(channel string) 13 | Unsubscribe(channel string) 14 | SubCount() int 15 | GetChannels() []string 16 | 17 | // transaction 18 | 19 | IsTransaction() bool 20 | SetTransaction(bool) 21 | 22 | EnqueueCmd(redisCommand [][]byte) 23 | GetQueuedCmdLine() [][][]byte 24 | 25 | GetWatchKey() map[string]int64 26 | CleanWatchKey() 27 | 28 | AddTxError(err error) 29 | GetTxErrors() []error 30 | } 31 | -------------------------------------------------------------------------------- /abstract/engine.go: -------------------------------------------------------------------------------- 1 | package abstract 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/gofish2020/easyredis/engine/payload" 7 | "github.com/gofish2020/easyredis/redis/protocol" 8 | ) 9 | 10 | type Engine interface { 11 | Exec(c Connection, redisCommand [][]byte) (result protocol.Reply) 12 | ForEach(dbIndex int, cb func(key string, data *payload.DataEntity, expiration *time.Time) bool) 13 | Close() 14 | } 15 | -------------------------------------------------------------------------------- /aof/aof.go: -------------------------------------------------------------------------------- 1 | package aof 2 | 3 | import ( 4 | "io" 5 | "os" 6 | "strconv" 7 | "strings" 8 | "sync" 9 | "sync/atomic" 10 | "time" 11 | 12 | "github.com/gofish2020/easyredis/abstract" 13 | "github.com/gofish2020/easyredis/redis/connection" 14 | "github.com/gofish2020/easyredis/redis/parser" 15 | "github.com/gofish2020/easyredis/redis/protocol" 16 | "github.com/gofish2020/easyredis/tool/logger" 17 | ) 18 | 19 | /* 20 | 原理类似于写日志,engine 中只会持有一个aof对象,通过生产者消费者模型,将数据写入到磁盘文件中 21 | */ 22 | 23 | const ( 24 | aofChanSize = 1 << 20 25 | ) 26 | 27 | const ( 28 | // 每次写入命令 & 刷盘 29 | FsyncAlways = "always" 30 | // 每秒刷盘 31 | FsyncEverySec = "everysec" 32 | // 不主动刷盘,取决于操作系统刷盘 33 | FsyncNo = "no" 34 | ) 35 | 36 | type Command = [][]byte 37 | 38 | type aofRecord struct { 39 | dbIndex int 40 | command Command 41 | } 42 | 43 | type AOF struct { 44 | // aof 文件句柄 45 | aofFile *os.File 46 | // aof 文件路径 47 | aofFileName string 48 | // 刷盘间隔 49 | aofFsync string 50 | // 最后写入aof日志的数据库索引 51 | lastDBIndex int 52 | // 保存aof记录通道 53 | aofChan chan aofRecord 54 | // 互斥锁 55 | mu sync.Mutex 56 | 57 | // aofChan读取完毕 58 | aofFinished chan struct{} 59 | // 关闭定时刷盘 60 | closed chan struct{} 61 | // 禁止aofChan的写入 62 | atomicClose atomic.Bool 63 | // 引擎 *Engine 64 | engine abstract.Engine 65 | } 66 | 67 | // 构建AOF对象 68 | func NewAOF(aofFileName string, engine abstract.Engine, load bool, fsync string) (*AOF, error) { 69 | aof := &AOF{} 70 | aof.aofFileName = aofFileName 71 | aof.aofFsync = strings.ToLower(fsync) 72 | aof.lastDBIndex = 0 73 | aof.aofChan = make(chan aofRecord, aofChanSize) 74 | aof.closed = make(chan struct{}) 75 | aof.aofFinished = make(chan struct{}) 76 | aof.engine = engine 77 | aof.atomicClose.Store(false) 78 | 79 | // 启动加载aof文件 80 | if load { 81 | aof.LoadAof(0) 82 | } 83 | 84 | // 打开文件(追加写/创建/读写) 85 | aofFile, err := os.OpenFile(aof.aofFileName, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0600) 86 | if err != nil { 87 | return nil, err 88 | } 89 | aof.aofFile = aofFile 90 | 91 | // 启动协程:每秒刷盘 92 | if aof.aofFsync == FsyncEverySec { 93 | aof.fsyncEverySec() 94 | } 95 | // 启动协程:检测aofChan 96 | go aof.watchChan() 97 | return aof, nil 98 | } 99 | 100 | func (aof *AOF) watchChan() { 101 | 102 | for record := range aof.aofChan { 103 | aof.writeAofRecord(record) 104 | } 105 | aof.aofFinished <- struct{}{} 106 | } 107 | 108 | func (aof *AOF) SaveRedisCommand(dbIndex int, command Command) { 109 | 110 | // 关闭 111 | if aof.atomicClose.Load() { 112 | return 113 | } 114 | // 写入文件 & 刷盘 115 | if aof.aofFsync == FsyncAlways { 116 | record := aofRecord{ 117 | dbIndex: dbIndex, 118 | command: command, 119 | } 120 | aof.writeAofRecord(record) 121 | return 122 | } 123 | // 写入缓冲 124 | aof.aofChan <- aofRecord{ 125 | dbIndex: dbIndex, 126 | command: command, 127 | } 128 | } 129 | 130 | func (aof *AOF) writeAofRecord(record aofRecord) { 131 | 132 | aof.mu.Lock() 133 | defer aof.mu.Unlock() 134 | 135 | // 因为aof对象是所有数据库对象【复用】写入文件方法,每个数据库的索引不同 136 | // 所以,每个命令的执行,有个前提就是操作的不同的数据库 137 | if record.dbIndex != aof.lastDBIndex { 138 | // 构建select index 命令 & 写入文件 139 | selectCommand := [][]byte{[]byte("select"), []byte(strconv.Itoa(record.dbIndex))} 140 | data := protocol.NewMultiBulkReply(selectCommand).ToBytes() 141 | _, err := aof.aofFile.Write(data) 142 | if err != nil { 143 | logger.Warn(err) 144 | return 145 | } 146 | aof.lastDBIndex = record.dbIndex 147 | } 148 | 149 | // redis命令 150 | data := protocol.NewMultiBulkReply(record.command).ToBytes() 151 | _, err := aof.aofFile.Write(data) 152 | if err != nil { 153 | logger.Warn(err) 154 | } 155 | logger.Debugf("write aof command:%q", data) 156 | // 每次写入刷盘 157 | if aof.aofFsync == FsyncAlways { 158 | aof.aofFile.Sync() 159 | } 160 | } 161 | 162 | func (aof *AOF) Fsync() { 163 | aof.mu.Lock() 164 | defer aof.mu.Unlock() 165 | if err := aof.aofFile.Sync(); err != nil { 166 | logger.Errorf("aof sync err:%+v", err) 167 | } 168 | } 169 | 170 | func (aof *AOF) fsyncEverySec() { 171 | // 每秒刷盘 172 | ticker := time.NewTicker(1 * time.Second) 173 | go func() { 174 | for { 175 | select { 176 | case <-ticker.C: 177 | aof.Fsync() 178 | case <-aof.closed: 179 | return 180 | } 181 | } 182 | }() 183 | } 184 | 185 | // 加载aof文件,maxBytes限定读取的字节数 186 | func (aof *AOF) LoadAof(maxBytes int) { 187 | 188 | // 目的:当加载aof文件的时候,因为需要复用engine对象,内部重放命令的时候会自动写aof日志,加载aof 禁用 SaveRedisCommand的写入 189 | aof.atomicClose.Store(true) 190 | defer func() { 191 | aof.atomicClose.Store(false) 192 | }() 193 | 194 | // 只读打开文件 195 | file, err := os.Open(aof.aofFileName) 196 | if err != nil { 197 | logger.Error(err.Error()) 198 | return 199 | } 200 | defer file.Close() 201 | file.Seek(0, io.SeekStart) 202 | 203 | var reader io.Reader 204 | if maxBytes > 0 { // 限定读取的字节大小 205 | reader = io.LimitReader(file, int64(maxBytes)) 206 | } else { // 不限定,直接读取到文件结尾(为止) 207 | reader = file 208 | } 209 | 210 | // 文件中保存的格式和网络传输的格式一致 211 | ch := parser.ParseStream(reader) 212 | virtualConn := connection.NewVirtualConn() 213 | 214 | for payload := range ch { 215 | if payload.Err != nil { 216 | // 文件已经读取到“完成“ 217 | if payload.Err == io.EOF { 218 | break 219 | } 220 | // 读取到非法的格式 221 | logger.Errorf("LoadAof parser error %+v:", payload.Err) 222 | continue 223 | } 224 | 225 | if payload.Reply == nil { 226 | logger.Error("empty payload data") 227 | continue 228 | } 229 | // 从文件中读取到命令 230 | reply, ok := payload.Reply.(*protocol.MultiBulkReply) 231 | if !ok { 232 | logger.Error("require multi bulk protocol") 233 | continue 234 | } 235 | 236 | // 利用数据库引擎,将命令数据保存到内存中(命令重放) 237 | ret := aof.engine.Exec(virtualConn, reply.RedisCommand) 238 | // 判断是否执行失败 239 | if protocol.IsErrReply(ret) { 240 | logger.Error("exec err ", string(ret.ToBytes())) 241 | } 242 | // 判断命令是否是"select" 243 | if strings.ToLower(string(reply.RedisCommand[0])) == "select" { 244 | dbIndex, err := strconv.Atoi(string(reply.RedisCommand[1])) 245 | if err == nil { 246 | aof.lastDBIndex = dbIndex // 记录下数据恢复过程中,选中的数据库索引 247 | } 248 | } 249 | } 250 | } 251 | 252 | func (aof *AOF) Close() { 253 | 254 | if aof.aofFile != nil { 255 | // 禁止写入 256 | aof.atomicClose.CompareAndSwap(false, true) 257 | // 停止每秒刷盘 258 | close(aof.closed) 259 | // aofChan关闭,chan缓冲中可能还有数据 260 | close(aof.aofChan) 261 | // 等待缓冲中处理完成 aofFile句柄才会不被使用 262 | <-aof.aofFinished 263 | // 关闭文件句柄 264 | aof.aofFile.Close() 265 | } 266 | } 267 | -------------------------------------------------------------------------------- /aof/cmd.go: -------------------------------------------------------------------------------- 1 | package aof 2 | 3 | import ( 4 | "strconv" 5 | "time" 6 | 7 | "github.com/gofish2020/easyredis/datastruct/sortedset" 8 | "github.com/gofish2020/easyredis/engine/payload" 9 | "github.com/gofish2020/easyredis/redis/protocol" 10 | "github.com/gofish2020/easyredis/utils" 11 | ) 12 | 13 | // 指定key绝对时间戳过期 milliseconds 14 | func PExpireAtCmd(key string, expireAt time.Time) [][]byte { 15 | return utils.BuildCmdLine("PEXPIREAT", [][]byte{[]byte(key), []byte(strconv.FormatInt((expireAt.UnixNano() / 1e6), 10))}...) 16 | } 17 | 18 | func SetCmd(args ...[]byte) [][]byte { 19 | return utils.BuildCmdLine("SET", args...) 20 | } 21 | 22 | func MSetCmd(args ...[]byte) [][]byte { 23 | return utils.BuildCmdLine("MSet", args...) 24 | } 25 | 26 | func SelectCmd(args ...[]byte) [][]byte { 27 | return utils.BuildCmdLine("SELECT", args...) 28 | } 29 | 30 | func ZAddCmd(args ...[]byte) [][]byte { 31 | return utils.BuildCmdLine("ZADD", args...) 32 | } 33 | 34 | // ZIncrBy 35 | func ZIncrByCmd(args ...[]byte) [][]byte { 36 | return utils.BuildCmdLine("ZINCRBY", args...) 37 | } 38 | 39 | // zpopmin 40 | func ZPopMin(args ...[]byte) [][]byte { 41 | return utils.BuildCmdLine("ZPOPMIN", args...) 42 | } 43 | 44 | // ZRem 45 | func ZRem(args ...[]byte) [][]byte { 46 | return utils.BuildCmdLine("ZREM", args...) 47 | } 48 | 49 | // zremrangebyscore 50 | func ZRemRangeByScore(args ...[]byte) [][]byte { 51 | return utils.BuildCmdLine("ZREMRANGEBYSCORE", args...) 52 | } 53 | 54 | // zremrangebyrank 55 | func ZRemRangeByRank(args ...[]byte) [][]byte { 56 | return utils.BuildCmdLine("ZREMRANGEBYRANK", args...) 57 | } 58 | 59 | // del 60 | func Del(args ...[]byte) [][]byte { 61 | return utils.BuildCmdLine("DEL", args...) 62 | } 63 | 64 | // persist 65 | func Persist(args ...[]byte) [][]byte { 66 | return utils.BuildCmdLine("PERSIST", args...) 67 | } 68 | 69 | // Auth 70 | func Auth(args ...[]byte) [][]byte { 71 | return utils.BuildCmdLine("AUTH", args...) 72 | } 73 | 74 | // Get 75 | 76 | func Get(args ...[]byte) [][]byte { 77 | return utils.BuildCmdLine("GET", args...) 78 | } 79 | 80 | // 内存对象转换成 redis命令 81 | func EntityToCmd(key string, entity *payload.DataEntity) *protocol.MultiBulkReply { 82 | if entity == nil { 83 | return nil 84 | } 85 | var cmd *protocol.MultiBulkReply 86 | switch val := entity.RedisObject.(type) { 87 | case []byte: 88 | cmd = protocol.NewMultiBulkReply(SetCmd([]byte(key), val)) 89 | // case List.List: 90 | // cmd = listToCmd(key, val) 91 | // case *set.Set: 92 | // cmd = setToCmd(key, val) 93 | // case dict.Dict: 94 | // cmd = hashToCmd(key, val) 95 | case *sortedset.SortedSet: 96 | cmd = zSetToCmd(key, val) 97 | } 98 | return cmd 99 | } 100 | 101 | func zSetToCmd(key string, set *sortedset.SortedSet) *protocol.MultiBulkReply { 102 | size := set.Len() 103 | args := make([][]byte, 1+2*size) 104 | args[0] = []byte(key) // key 105 | i := 0 106 | 107 | set.ForEachByRank(0, size, true, func(pair *sortedset.Pair) bool { 108 | score := strconv.FormatFloat(pair.Score, 'f', -1, 64) 109 | args[2*i+1] = []byte(score) // score 110 | args[2*i+2] = []byte(pair.Member) // member 111 | i++ 112 | return true 113 | }) 114 | // zadd key score member [score member..] 115 | return protocol.NewMultiBulkReply(ZAddCmd(args...)) 116 | } 117 | -------------------------------------------------------------------------------- /aof/rewrite.go: -------------------------------------------------------------------------------- 1 | package aof 2 | 3 | import ( 4 | "io" 5 | "os" 6 | "strconv" 7 | "time" 8 | 9 | "github.com/gofish2020/easyredis/abstract" 10 | "github.com/gofish2020/easyredis/engine/payload" 11 | "github.com/gofish2020/easyredis/redis/protocol" 12 | "github.com/gofish2020/easyredis/tool/conf" 13 | "github.com/gofish2020/easyredis/tool/logger" 14 | ) 15 | 16 | type snapshotAOF struct { 17 | fileSize int64 // 文件大小 18 | dbIndex int // 数据库索引 19 | tempFile *os.File 20 | } 21 | 22 | func (aof *AOF) Rewrite(engine abstract.Engine) { 23 | //1.对现有的aof文件做一次快照 24 | snapShot, err := aof.startRewrite() 25 | if err != nil { 26 | logger.Errorf("StartRewrite err: %+v", err) 27 | return 28 | } 29 | 30 | //2. 将现在的aof文件数据,加在到新(内存)对象中,并重写入新aof文件中 31 | err = aof.doRewrite(snapShot, engine) 32 | if err != nil { 33 | logger.Errorf("doRewrite err: %+v", err) 34 | return 35 | } 36 | 37 | //3. 将重写过程中的增量命令写入到新文件中 38 | err = aof.finishRewrite(snapShot) 39 | if err != nil { 40 | logger.Errorf("finishRewrite err: %+v", err) 41 | } 42 | } 43 | 44 | func (aof *AOF) startRewrite() (*snapshotAOF, error) { 45 | 46 | // 加锁 47 | aof.mu.Lock() 48 | defer aof.mu.Unlock() 49 | 50 | // 文件刷盘 51 | err := aof.aofFile.Sync() 52 | if err != nil { 53 | return nil, err 54 | } 55 | 56 | // 获取当前文件信息 57 | fileInfo, err := aof.aofFile.Stat() 58 | if err != nil { 59 | return nil, err 60 | } 61 | 62 | // 创建新的aof文件 63 | file, err := os.CreateTemp(conf.TmpDir(), "*.aof") 64 | if err != nil { 65 | return nil, err 66 | } 67 | 68 | // 当前aof重写前的信息 69 | snapShot := &snapshotAOF{} 70 | // 大小 71 | snapShot.fileSize = fileInfo.Size() 72 | // 选中的数据库 73 | snapShot.dbIndex = aof.lastDBIndex 74 | snapShot.tempFile = file 75 | return snapShot, nil 76 | } 77 | 78 | func (aof *AOF) doRewrite(snapShot *snapshotAOF, engine abstract.Engine) error { 79 | // 临时aof对象 80 | tmpAof := &AOF{} 81 | tmpAof.aofFileName = aof.aofFileName 82 | tmpAof.engine = engine 83 | // 临时aof,加载aof文件,并将数据保存到临时内存中(engine) 84 | tmpAof.LoadAof(int(snapShot.fileSize)) 85 | // 扫描临时内存,将结果保存到新的aof文件中 86 | tmpFile := snapShot.tempFile 87 | for i := 0; i < conf.GlobalConfig.Databases; i++ { 88 | // 写入 select index 89 | data := protocol.NewMultiBulkReply(SelectCmd([]byte(strconv.Itoa(i)))) 90 | _, err := tmpFile.Write(data.ToBytes()) 91 | if err != nil { 92 | return err 93 | } 94 | 95 | // 遍历数据库 96 | tmpAof.engine.ForEach(i, func(key string, data *payload.DataEntity, expiration *time.Time) bool { 97 | // 写入 redis命令 98 | cmd := EntityToCmd(key, data) 99 | tmpFile.Write(cmd.ToBytes()) 100 | // 写入过期时间(如果存在的话) 101 | if expiration != nil { 102 | cmd := protocol.NewMultiBulkReply(PExpireAtCmd(key, *expiration)) 103 | tmpFile.Write(cmd.ToBytes()) 104 | } 105 | return true 106 | }) 107 | } 108 | return nil 109 | 110 | } 111 | 112 | func (aof *AOF) finishRewrite(snapshot *snapshotAOF) error { 113 | aof.mu.Lock() 114 | defer aof.mu.Unlock() 115 | 116 | err := aof.aofFile.Sync() 117 | if err != nil { 118 | return err 119 | } 120 | 121 | tmpFile := snapshot.tempFile 122 | 123 | lastCopy := func() error { 124 | // 打开现有aof 125 | src, err := os.Open(aof.aofFileName) 126 | if err != nil { 127 | return err 128 | } 129 | 130 | defer func() { 131 | src.Close() 132 | tmpFile.Close() 133 | }() 134 | // 将游标移动到上次的快照位置 135 | _, err = src.Seek(snapshot.fileSize, io.SeekStart) 136 | if err != nil { 137 | return err 138 | } 139 | 140 | // 将快照的dbindex保存到tmpFile中(因为增量的命令是在当时快照的时候数据库下生成的) 141 | data := protocol.NewMultiBulkReply(SelectCmd([]byte(strconv.Itoa(snapshot.dbIndex)))) 142 | _, err = tmpFile.Write(data.ToBytes()) 143 | if err != nil { 144 | return err 145 | } 146 | 147 | // 将增量的数据保存到tmpFile中 148 | _, err = io.Copy(tmpFile, src) 149 | if err != nil { 150 | return err 151 | } 152 | return nil 153 | } 154 | 155 | if err := lastCopy(); err != nil { 156 | return err 157 | } 158 | // 执行到这里说明数据复制完毕 159 | 160 | // 关闭当前的aof 161 | aof.aofFile.Close() 162 | 163 | // 将临时文件移动到aof文件 164 | if err := os.Rename(tmpFile.Name(), aof.aofFileName); err != nil { 165 | logger.Warn(err) 166 | } 167 | 168 | // 重新打开 169 | aofFile, err := os.OpenFile(aof.aofFileName, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0600) 170 | if err != nil { 171 | panic(err) 172 | } 173 | aof.aofFile = aofFile 174 | 175 | // 将当前的数据库索引写入 176 | data := protocol.NewMultiBulkReply(SelectCmd([]byte(strconv.Itoa(aof.lastDBIndex)))) 177 | _, err = aof.aofFile.Write(data.ToBytes()) 178 | if err != nil { 179 | panic(err) 180 | } 181 | 182 | return nil 183 | } 184 | -------------------------------------------------------------------------------- /cluster/cluster.go: -------------------------------------------------------------------------------- 1 | package cluster 2 | 3 | import ( 4 | "fmt" 5 | "runtime/debug" 6 | "strings" 7 | "sync" 8 | "time" 9 | 10 | "github.com/gofish2020/easyredis/abstract" 11 | "github.com/gofish2020/easyredis/engine" 12 | "github.com/gofish2020/easyredis/engine/payload" 13 | "github.com/gofish2020/easyredis/redis/protocol" 14 | "github.com/gofish2020/easyredis/tool/conf" 15 | "github.com/gofish2020/easyredis/tool/consistenthash" 16 | "github.com/gofish2020/easyredis/tool/idgenerator" 17 | "github.com/gofish2020/easyredis/tool/logger" 18 | "github.com/gofish2020/easyredis/tool/timewheel" 19 | ) 20 | 21 | /* 22 | Redis集群 23 | */ 24 | 25 | type CmdLine = [][]byte 26 | 27 | const ( 28 | replicas = 100 // 副本数量 29 | ) 30 | 31 | type Cluster struct { 32 | // 当前的ip地址 33 | self string 34 | // socket连接池 35 | clientFactory Factory 36 | // Redis存储引擎 37 | engine *engine.Engine 38 | 39 | // 一致性hash 40 | consistHash *consistenthash.Map 41 | 42 | // 雪花算法,生成唯一guid 43 | snowflake *idgenerator.IDGenerator 44 | 45 | // 分布式事务 46 | transactionLock sync.RWMutex 47 | transactions map[string]*Transaction 48 | 49 | delay *timewheel.Delay 50 | } 51 | 52 | func NewCluster() *Cluster { 53 | cluster := Cluster{ 54 | clientFactory: NewRedisConnPool(), 55 | engine: engine.NewEngine(), 56 | consistHash: consistenthash.New(replicas, nil), 57 | self: conf.GlobalConfig.Self, 58 | snowflake: idgenerator.MakeGenerator(conf.GlobalConfig.Self), 59 | delay: timewheel.NewDelay(), 60 | transactions: make(map[string]*Transaction), 61 | } 62 | 63 | // 一致性hash初始化 64 | contains := make(map[string]struct{}) 65 | peers := make([]string, 0, len(conf.GlobalConfig.Peers)+1) 66 | // 去重 67 | for _, peer := range conf.GlobalConfig.Peers { 68 | if _, ok := contains[peer]; ok { 69 | continue 70 | } 71 | peers = append(peers, peer) 72 | } 73 | 74 | if _, ok := contains[cluster.self]; !ok { 75 | peers = append(peers, cluster.self) 76 | } 77 | // 添加到集群 78 | cluster.consistHash.Add(peers...) 79 | return &cluster 80 | } 81 | 82 | func (cluster *Cluster) Exec(c abstract.Connection, redisCommand [][]byte) (result protocol.Reply) { 83 | defer func() { 84 | if err := recover(); err != nil { 85 | logger.Warn(fmt.Sprintf("error occurs: %v\n%s", err, string(debug.Stack()))) 86 | result = protocol.NewUnknownErrReply() 87 | } 88 | }() 89 | 90 | name := strings.ToLower(string(redisCommand[0])) 91 | routerFunc, ok := clusterRouter[name] 92 | if !ok { 93 | return protocol.NewGenericErrReply("unknown command '" + name + "' or not support command in cluster mode") 94 | } 95 | return routerFunc(cluster, c, redisCommand) 96 | } 97 | 98 | func (cluster *Cluster) Close() { 99 | cluster.engine.Close() 100 | } 101 | 102 | func (cluster *Cluster) ForEach(dbIndex int, cb func(key string, data *payload.DataEntity, expiration *time.Time) bool) { 103 | cluster.engine.ForEach(dbIndex, cb) 104 | } 105 | -------------------------------------------------------------------------------- /cluster/conn_pool.go: -------------------------------------------------------------------------------- 1 | package cluster 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/gofish2020/easyredis/aof" 7 | "github.com/gofish2020/easyredis/datastruct/dict" 8 | "github.com/gofish2020/easyredis/redis/client" 9 | "github.com/gofish2020/easyredis/redis/protocol" 10 | "github.com/gofish2020/easyredis/tool/conf" 11 | "github.com/gofish2020/easyredis/tool/pool" 12 | ) 13 | 14 | type Factory interface { 15 | GetConn(addr string) (Client, error) 16 | ReturnConn(peer string, cli Client) error 17 | } 18 | 19 | type Client interface { 20 | Send(command [][]byte) (protocol.Reply, error) 21 | } 22 | 23 | type RedisConnPool struct { 24 | connDict *dict.ConcurrentDict // addr -> *pool.Pool 25 | } 26 | 27 | func NewRedisConnPool() *RedisConnPool { 28 | 29 | return &RedisConnPool{ 30 | connDict: dict.NewConcurrentDict(16), 31 | } 32 | } 33 | 34 | func (r *RedisConnPool) GetConn(addr string) (Client, error) { 35 | 36 | var connectionPool *pool.Pool // 对象池 37 | 38 | // 通过不同的地址addr,获取不同的对象池 39 | raw, ok := r.connDict.Get(addr) 40 | if ok { 41 | connectionPool = raw.(*pool.Pool) 42 | } else { 43 | 44 | // 创建对象函数 45 | newClient := func() (any, error) { 46 | // redis的客户端连接 47 | cli, err := client.NewRedisClient(addr) 48 | if err != nil { 49 | return nil, err 50 | } 51 | // 启动 52 | cli.Start() 53 | if conf.GlobalConfig.RequirePass != "" { // 说明服务需要密码 54 | reply, err := cli.Send(aof.Auth([]byte(conf.GlobalConfig.RequirePass))) 55 | if err != nil { 56 | return nil, err 57 | } 58 | if !protocol.IsOKReply(reply) { 59 | return nil, errors.New("auth failed:" + string(reply.ToBytes())) 60 | } 61 | return cli, nil 62 | } 63 | return cli, nil 64 | } 65 | 66 | // 释放对象函数 67 | freeClient := func(x any) { 68 | cli, ok := x.(*client.RedisClient) 69 | if ok { 70 | cli.Stop() // 释放 71 | } 72 | } 73 | 74 | // 针对addr地址,创建一个新的对象池 75 | connectionPool = pool.NewPool(newClient, freeClient, pool.Config{ 76 | MaxIdles: 1, 77 | MaxActive: 20, 78 | }) 79 | // addr -> *pool.Pool 80 | r.connDict.Put(addr, connectionPool) 81 | } 82 | 83 | // 从对象池中获取一个对象 84 | raw, err := connectionPool.Get() 85 | if err != nil { 86 | return nil, err 87 | } 88 | conn, ok := raw.(*client.RedisClient) 89 | if !ok { 90 | return nil, errors.New("connection pool make wrong type") 91 | } 92 | return conn, nil 93 | } 94 | 95 | func (r *RedisConnPool) ReturnConn(peer string, cli Client) error { 96 | raw, ok := r.connDict.Get(peer) 97 | if !ok { 98 | return errors.New("connection pool not found") 99 | } 100 | raw.(*pool.Pool).Put(cli) 101 | return nil 102 | } 103 | -------------------------------------------------------------------------------- /cluster/mset.go: -------------------------------------------------------------------------------- 1 | package cluster 2 | 3 | import ( 4 | "github.com/gofish2020/easyredis/abstract" 5 | "github.com/gofish2020/easyredis/redis/protocol" 6 | ) 7 | 8 | // mset key value [key value...] 9 | func mset(cluster *Cluster, c abstract.Connection, redisCommand [][]byte) protocol.Reply { 10 | 11 | // 基础校验 12 | if len(redisCommand) < 3 { 13 | return protocol.NewArgNumErrReply("mset") 14 | } 15 | 16 | argsNum := len(redisCommand) - 1 17 | if argsNum%2 != 0 { 18 | return protocol.NewArgNumErrReply("mset") 19 | } 20 | 21 | //1.从命令中,提取出 key value 22 | size := argsNum / 2 23 | keys := make([]string, 0, size) 24 | values := make(map[string]string) 25 | for i := 0; i < size; i++ { 26 | keys = append(keys, string(redisCommand[2*i+1])) 27 | values[keys[i]] = string(redisCommand[2*i+2]) 28 | } 29 | 30 | //2.计算key映射的ip地址; ip -> []string 31 | ipMap := cluster.groupByKeys(keys) 32 | 33 | // 3.说明keys映射为同一个ip地址,直接转发执行(不需要走分布式事务) 34 | if len(ipMap) == 1 { 35 | for ip := range ipMap { 36 | return cluster.Relay(ip, c, pushCmd(redisCommand, "Direct")) 37 | } 38 | } 39 | 40 | // 4.prepare阶段 41 | var respReply protocol.Reply = protocol.NewOkReply() 42 | // 事务id 43 | txId := cluster.newTxId() 44 | rollback := false 45 | for ip, keys := range ipMap { 46 | // txid mset key value [key value...] 47 | argsGroup := [][]byte{[]byte(txId), []byte("mset")} 48 | for _, key := range keys { 49 | argsGroup = append(argsGroup, []byte(key), []byte(values[key])) 50 | } 51 | //发送命令: prepare txid mset key value [key value...] 52 | reply := cluster.Relay(ip, c, pushCmd(argsGroup, "Prepare")) 53 | if protocol.IsErrReply(reply) { // 说明失败 54 | respReply = reply 55 | rollback = true 56 | break 57 | } 58 | } 59 | 60 | if rollback { // 如果prepare阶段失败, 向所有节点请求回滚 61 | rollbackTransaction(cluster, c, txId, ipMap) 62 | } else { // 所有节点都可以提交 63 | _, reply := commitTransaction(cluster, c, txId, ipMap) 64 | if reply != nil { 65 | respReply = reply 66 | } 67 | } 68 | return respReply 69 | } 70 | -------------------------------------------------------------------------------- /cluster/relay.go: -------------------------------------------------------------------------------- 1 | package cluster 2 | 3 | import ( 4 | "github.com/gofish2020/easyredis/abstract" 5 | "github.com/gofish2020/easyredis/redis/protocol" 6 | "github.com/gofish2020/easyredis/tool/logger" 7 | ) 8 | 9 | func (cluster *Cluster) Relay(peer string, conn abstract.Connection, redisCommand [][]byte) protocol.Reply { 10 | 11 | // ******本地执行****** 12 | if cluster.self == peer { 13 | //return cluster.engine.Exec(conn, redisCommand) 14 | return cluster.Exec(conn, redisCommand) 15 | } 16 | 17 | // ******发送到远端执行****** 18 | 19 | client, err := cluster.clientFactory.GetConn(peer) // 从连接池中获取一个连接 20 | if err != nil { 21 | logger.Error(err) 22 | return protocol.NewGenericErrReply(err.Error()) 23 | } 24 | 25 | defer func() { 26 | cluster.clientFactory.ReturnConn(peer, client) // 归还连接 27 | }() 28 | 29 | logger.Debugf("命令:%q,转发至ip:%s", protocol.NewMultiBulkReply(redisCommand).ToBytes(), peer) 30 | reply, err := client.Send(redisCommand) // 发送命令 31 | if err != nil { 32 | logger.Error(err) 33 | return protocol.NewGenericErrReply(err.Error()) 34 | } 35 | 36 | return reply 37 | } 38 | -------------------------------------------------------------------------------- /cluster/router.go: -------------------------------------------------------------------------------- 1 | package cluster 2 | 3 | import ( 4 | "strings" 5 | 6 | "github.com/gofish2020/easyredis/abstract" 7 | "github.com/gofish2020/easyredis/redis/protocol" 8 | ) 9 | 10 | type clusterFunc func(cluster *Cluster, conn abstract.Connection, args [][]byte) protocol.Reply 11 | 12 | var clusterRouter = make(map[string]clusterFunc) 13 | 14 | func registerClusterRouter(cmd string, f clusterFunc) { 15 | cmd = strings.ToLower(cmd) 16 | clusterRouter[cmd] = f 17 | } 18 | func init() { 19 | 20 | // 在集群节点上注册的命令 21 | registerClusterRouter("Set", defultFunc) 22 | registerClusterRouter("Get", defultFunc) 23 | registerClusterRouter("MSet", mset) 24 | 25 | registerClusterRouter("Prepare", prepareFunc) 26 | registerClusterRouter("Rollback", rollbackFunc) 27 | registerClusterRouter("Commit", commitFunc) 28 | 29 | // 表示命令直接在存储引擎上执行命令 30 | registerClusterRouter("Direct", directFunc) 31 | } 32 | 33 | func defultFunc(cluster *Cluster, conn abstract.Connection, redisCommand [][]byte) protocol.Reply { 34 | key := string(redisCommand[1]) 35 | // 计算key所属的节点 36 | peer := cluster.consistHash.Get(key) 37 | return cluster.Relay(peer, conn, pushCmd(redisCommand, "Direct")) // 将命令转发至节点,直接执行(不用再重复计算key所属节点) 38 | } 39 | 40 | // 直接在存储引擎上执行命令 41 | func directFunc(cluster *Cluster, conn abstract.Connection, redisCommand [][]byte) protocol.Reply { 42 | return cluster.engine.Exec(conn, popCmd(redisCommand)) 43 | } 44 | -------------------------------------------------------------------------------- /cluster/tcc.go: -------------------------------------------------------------------------------- 1 | package cluster 2 | 3 | import ( 4 | "github.com/gofish2020/easyredis/abstract" 5 | "github.com/gofish2020/easyredis/redis/protocol" 6 | ) 7 | 8 | // 回滚事务 9 | func rollbackTransaction(cluster *Cluster, c abstract.Connection, txId string, ipMap map[string][]string) { 10 | argsGroup := [][]byte{[]byte(txId)} 11 | // 向所有的ip发送回滚请求 12 | for ip := range ipMap { 13 | cluster.Relay(ip, c, pushCmd(argsGroup, "Rollback")) // Rollback txid 14 | } 15 | } 16 | 17 | // 提交事务 18 | func commitTransaction(cluster *Cluster, c abstract.Connection, txId string, ipMap map[string][]string) ([]protocol.Reply, protocol.Reply) { 19 | 20 | result := make([]protocol.Reply, len(ipMap)) 21 | var errReply protocol.Reply = nil 22 | argsGroup := [][]byte{[]byte(txId)} 23 | // 向所有的ip发送提交请求 24 | for ip := range ipMap { 25 | reply := cluster.Relay(ip, c, pushCmd(argsGroup, "Commit")) 26 | if protocol.IsErrReply(reply) { // 说明提交的时候失败了 27 | errReply = reply 28 | break 29 | } 30 | // 保存提交结果 31 | result = append(result, reply) 32 | } 33 | 34 | if errReply != nil { 35 | rollbackTransaction(cluster, c, txId, ipMap) 36 | result = nil 37 | } 38 | return result, errReply 39 | } 40 | 41 | func genTxKey(txId string) string { 42 | return "tx:" + txId 43 | } 44 | 45 | // ***********************Prepare/Commit/Rollback命令处理函数*********************** 46 | // prepare txid mset key value [key value...] 47 | func prepareFunc(cluster *Cluster, conn abstract.Connection, redisCommand [][]byte) protocol.Reply { 48 | 49 | if len(redisCommand) < 3 { 50 | return protocol.NewArgNumErrReply("prepare") 51 | } 52 | 53 | txId := string(redisCommand[1]) 54 | 55 | // 创建事务对象 56 | tx := NewTransaction(txId, redisCommand[2:], cluster, conn) 57 | 58 | // 存储对象 59 | cluster.transactionLock.Lock() 60 | cluster.transactions[txId] = tx 61 | cluster.transactionLock.Unlock() 62 | 63 | // prepare事务 64 | err := tx.prepare() 65 | if err != nil { 66 | return protocol.NewGenericErrReply(err.Error()) 67 | } 68 | 69 | // 3s后如果事务还没有提交,自动回滚(避免长时间锁定) 70 | cluster.delay.Add(maxPrepareTime, genTxKey(txId), func() { 71 | tx.mu.Lock() 72 | defer tx.mu.Unlock() 73 | if tx.status == preparedStatus { 74 | tx.rollback() 75 | cluster.transactionLock.Lock() 76 | defer cluster.transactionLock.Unlock() 77 | delete(cluster.transactions, tx.txId) 78 | } 79 | }) 80 | return protocol.NewOkReply() 81 | } 82 | 83 | 84 | // rollback txid 85 | func rollbackFunc(cluster *Cluster, conn abstract.Connection, redisCommand [][]byte) protocol.Reply { 86 | 87 | if len(redisCommand) != 2 { 88 | return protocol.NewArgNumErrReply("rollback") 89 | } 90 | cluster.transactionLock.RLock() 91 | tx, ok := cluster.transactions[string(redisCommand[1])] 92 | cluster.transactionLock.RUnlock() 93 | if !ok { 94 | return protocol.NewIntegerReply(0) // 事务不存在 95 | } 96 | 97 | tx.mu.Lock() 98 | defer tx.mu.Unlock() 99 | 100 | // 回滚事务 101 | err := tx.rollback() 102 | if err != nil { 103 | return protocol.NewGenericErrReply(err.Error()) 104 | } 105 | 106 | // 延迟6s删除事务对象 107 | cluster.delay.Add(waitBeforeCleanTx, "", func() { 108 | cluster.transactionLock.Lock() 109 | defer cluster.transactionLock.Unlock() 110 | delete(cluster.transactions, tx.txId) 111 | }) 112 | 113 | return protocol.NewIntegerReply(1) 114 | } 115 | 116 | // commit txid 117 | func commitFunc(cluster *Cluster, conn abstract.Connection, redisCommand [][]byte) protocol.Reply { 118 | 119 | if len(redisCommand) != 2 { 120 | return protocol.NewArgNumErrReply("commit") 121 | } 122 | 123 | cluster.transactionLock.RLock() 124 | tx, ok := cluster.transactions[string(redisCommand[1])] 125 | cluster.transactionLock.RUnlock() 126 | if !ok { 127 | return protocol.NewIntegerReply(0) // 事务不存在 128 | } 129 | 130 | return tx.commit() 131 | } 132 | -------------------------------------------------------------------------------- /cluster/tcc_test.go: -------------------------------------------------------------------------------- 1 | package cluster 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/gofish2020/easyredis/aof" 8 | "github.com/gofish2020/easyredis/redis/connection" 9 | "github.com/gofish2020/easyredis/redis/protocol" 10 | "github.com/gofish2020/easyredis/tool/conf" 11 | ) 12 | 13 | var cluster map[string]*Cluster = make(map[string]*Cluster) 14 | 15 | type mockFactory struct { 16 | } 17 | 18 | func (f *mockFactory) GetConn(addr string) (Client, error) { 19 | return &fakeClient{cluster: cluster[addr]}, nil 20 | } 21 | 22 | func (f *mockFactory) ReturnConn(peer string, cli Client) error { 23 | return nil 24 | } 25 | 26 | type fakeClient struct { 27 | cluster *Cluster 28 | } 29 | 30 | func (f *fakeClient) Send(command [][]byte) (protocol.Reply, error) { 31 | 32 | return f.cluster.Exec(connection.NewVirtualConn(), command), nil 33 | } 34 | 35 | func TestTCC(t *testing.T) { 36 | conf.GlobalConfig.Peers = []string{"127.0.0.1:6379", "127.0.0.1:7379", "127.0.0.1:8379"} 37 | 38 | for _, v := range conf.GlobalConfig.Peers { 39 | conf.GlobalConfig.Self = v 40 | clusterX := NewCluster() 41 | clusterX.clientFactory = &mockFactory{} 42 | cluster[v] = clusterX 43 | } 44 | 45 | // 选中一个节点,作为协调者 46 | oneCluster := cluster[conf.GlobalConfig.Peers[0]] 47 | conn := connection.NewVirtualConn() 48 | 49 | txId := oneCluster.newTxId() 50 | keys := []string{"1", "6", "10"} 51 | values := make(map[string]string) 52 | values["1"] = "300" 53 | values["6"] = "300" 54 | values["10"] = "300" 55 | 56 | ipMap := oneCluster.groupByKeys(keys) 57 | for ip, keys := range ipMap { 58 | // txid mset key value [key value...] 59 | argsGroup := [][]byte{[]byte(txId), []byte("mset")} 60 | for _, key := range keys { 61 | argsGroup = append(argsGroup, []byte(key), []byte(values[key])) 62 | } 63 | //发送命令: prepare txid mset key value [key value...] 64 | oneCluster.Relay(ip, conn, pushCmd(argsGroup, "Prepare")) 65 | } 66 | 67 | // test commit 68 | commitTransaction(oneCluster, conn, txId, ipMap) 69 | t.Logf("%q", oneCluster.Exec(conn, aof.Get([]byte("1"))).ToBytes()) 70 | t.Logf("%q", oneCluster.Exec(conn, aof.Get([]byte("6"))).ToBytes()) 71 | t.Logf("%q", oneCluster.Exec(conn, aof.Get([]byte("10"))).ToBytes()) 72 | time.Sleep(1 * time.Second) 73 | // test rollback 74 | rollbackTransaction(oneCluster, conn, txId, ipMap) 75 | t.Logf("%q", oneCluster.Exec(conn, aof.Get([]byte("1"))).ToBytes()) 76 | t.Logf("%q", oneCluster.Exec(conn, aof.Get([]byte("6"))).ToBytes()) 77 | t.Logf("%q", oneCluster.Exec(conn, aof.Get([]byte("10"))).ToBytes()) 78 | } 79 | -------------------------------------------------------------------------------- /cluster/transaction.go: -------------------------------------------------------------------------------- 1 | package cluster 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | 7 | "github.com/gofish2020/easyredis/abstract" 8 | "github.com/gofish2020/easyredis/engine" 9 | "github.com/gofish2020/easyredis/redis/protocol" 10 | ) 11 | 12 | /* 13 | 事务对象 14 | */ 15 | 16 | const ( 17 | maxPrepareTime = 3 * time.Second 18 | waitBeforeCleanTx = 2 * maxPrepareTime 19 | ) 20 | 21 | type transactionStatus int8 22 | 23 | const ( 24 | createdStatus transactionStatus = 0 25 | preparedStatus transactionStatus = 1 26 | committedStatus transactionStatus = 2 27 | rolledBackStatus transactionStatus = 3 28 | ) 29 | 30 | type Transaction struct { 31 | txId string // transaction id 32 | redisCommand [][]byte // redis命令 33 | cluster *Cluster // 集群对象 34 | conn abstract.Connection // socket连接 35 | dbIndex int // 数据库索引 36 | 37 | writeKeys []string // 写key 38 | readKeys []string // 读key 39 | keysLocked bool // 是否对写key/读key已经上锁 40 | undoLog []CmdLine // 回滚日志 41 | 42 | status transactionStatus // 事务状态 43 | mu *sync.Mutex // 事务锁(操作事务对象的时候上锁) 44 | } 45 | 46 | func NewTransaction(txId string, cmdLine [][]byte, cluster *Cluster, c abstract.Connection) *Transaction { 47 | return &Transaction{ 48 | txId: txId, 49 | redisCommand: cmdLine, 50 | cluster: cluster, 51 | conn: c, 52 | dbIndex: c.GetDBIndex(), 53 | status: createdStatus, 54 | mu: &sync.Mutex{}, 55 | } 56 | } 57 | 58 | func (tx *Transaction) prepare() error { 59 | 60 | // 1.上锁 61 | tx.mu.Lock() 62 | defer tx.mu.Unlock() 63 | // 2.获取读写key 64 | readKeys, writeKeys := engine.GetRelatedKeys(tx.redisCommand) 65 | tx.readKeys = readKeys 66 | tx.writeKeys = writeKeys 67 | // 3. 锁定节点资源 68 | tx.locks() 69 | // 4.生成回滚日志 70 | tx.undoLog = tx.cluster.engine.GetUndoLogs(tx.dbIndex, tx.redisCommand) 71 | tx.status = preparedStatus 72 | return nil 73 | } 74 | 75 | func (tx *Transaction) rollback() error { 76 | if tx.status == rolledBackStatus { // no need to rollback a rolled-back transaction 77 | return nil 78 | } 79 | tx.locks() 80 | for _, cmdLine := range tx.undoLog { // 执行回滚日志 81 | tx.cluster.engine.ExecWithLock(tx.dbIndex, cmdLine) 82 | } 83 | tx.unlocks() 84 | tx.status = rolledBackStatus 85 | return nil 86 | } 87 | 88 | func (tx *Transaction) commit() protocol.Reply { 89 | tx.mu.Lock() 90 | defer tx.mu.Unlock() 91 | if tx.status == committedStatus { 92 | return protocol.NewIntegerReply(0) 93 | } 94 | 95 | tx.locks() 96 | reply := tx.cluster.engine.ExecWithLock(tx.dbIndex, tx.redisCommand) 97 | if protocol.IsErrReply(reply) { 98 | tx.rollback() // commit 失败,自动回滚 99 | return reply 100 | } 101 | tx.status = committedStatus 102 | tx.unlocks() 103 | 104 | // 保留事务对象6s 105 | tx.cluster.delay.Add(waitBeforeCleanTx, "", func() { 106 | tx.cluster.transactionLock.Lock() 107 | delete(tx.cluster.transactions, tx.txId) 108 | tx.cluster.transactionLock.Unlock() 109 | }) 110 | return reply 111 | } 112 | 113 | func (tx *Transaction) locks() { 114 | if !tx.keysLocked { 115 | tx.cluster.engine.RWLocks(tx.dbIndex, tx.readKeys, tx.writeKeys) 116 | tx.keysLocked = true 117 | } 118 | } 119 | 120 | func (tx *Transaction) unlocks() { 121 | if tx.keysLocked { 122 | tx.cluster.engine.RWUnLocks(tx.dbIndex, tx.readKeys, tx.writeKeys) 123 | tx.keysLocked = false 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /cluster/utils.go: -------------------------------------------------------------------------------- 1 | package cluster 2 | 3 | import "strconv" 4 | 5 | // 计算key应该存储的节点ip 6 | func (cluster *Cluster) groupByKeys(keys []string) map[string][]string { 7 | var result = make(map[string][]string) 8 | for _, key := range keys { 9 | ip := cluster.consistHash.Get(key) 10 | result[ip] = append(result[ip], key) 11 | } 12 | return result 13 | } 14 | 15 | // 替换命令名 16 | 17 | func replaceCmd(redisCommand [][]byte, newCmd string) [][]byte { 18 | newRedisCommand := make([][]byte, len(redisCommand)) 19 | copy(newRedisCommand, redisCommand) 20 | newRedisCommand[0] = []byte(newCmd) 21 | return newRedisCommand 22 | } 23 | 24 | // 头部添加命令名 25 | func pushCmd(redisCommand [][]byte, newCmd string) [][]byte { 26 | result := make([][]byte, len(redisCommand)+1) 27 | result[0] = []byte(newCmd) 28 | for i := 0; i < len(redisCommand); i++ { 29 | result[i+1] = []byte(redisCommand[i]) 30 | } 31 | return result 32 | } 33 | 34 | // 删除头部的命令 35 | func popCmd(redisCommand [][]byte) [][]byte { 36 | result := make([][]byte, len(redisCommand)-1) 37 | copy(result, redisCommand[1:]) 38 | return result 39 | } 40 | 41 | // 生成事务id 42 | func (cluster *Cluster) newTxId() string { 43 | id := cluster.snowflake.NextID() 44 | return strconv.FormatInt(id, 10) 45 | } 46 | -------------------------------------------------------------------------------- /datastruct/dict/concurrent.go: -------------------------------------------------------------------------------- 1 | package dict 2 | 3 | import ( 4 | "sort" 5 | "sync" 6 | "sync/atomic" 7 | 8 | "github.com/gofish2020/easyredis/utils" 9 | ) 10 | 11 | // 并发安全的字典 12 | type ConcurrentDict struct { 13 | shds []*shard // 底层shard切片 14 | mask uint32 // 掩码 15 | count *atomic.Int32 // 元素个数 16 | } 17 | 18 | type shard struct { 19 | m map[string]interface{} 20 | mu sync.RWMutex 21 | } 22 | 23 | func (sh *shard) forEach(consumer Consumer) bool { 24 | sh.mu.RLock() 25 | defer sh.mu.RUnlock() 26 | for k, v := range sh.m { 27 | res := consumer(k, v) 28 | if !res { 29 | return false 30 | } 31 | } 32 | return true 33 | } 34 | 35 | // 构造字典对象 36 | func NewConcurrentDict(shardCount int) *ConcurrentDict { 37 | shardCount = utils.ComputeCapacity(shardCount) 38 | 39 | dict := &ConcurrentDict{} 40 | shds := make([]*shard, shardCount) 41 | 42 | for i := range shds { 43 | shds[i] = &shard{ 44 | m: make(map[string]interface{}), 45 | } 46 | } 47 | dict.shds = shds 48 | dict.mask = uint32(shardCount - 1) 49 | dict.count = &atomic.Int32{} 50 | return dict 51 | } 52 | 53 | // code 对应的索引 54 | func (c *ConcurrentDict) index(code uint32) uint32 { 55 | return c.mask & code 56 | } 57 | 58 | // 获取key对应的shard 59 | func (c *ConcurrentDict) getShard(key string) *shard { 60 | return c.shds[c.index(utils.Fnv32(key))] 61 | } 62 | 63 | // AddVersion 对key增加版本号 64 | func (c *ConcurrentDict) AddVersion(key string, delta int64) (val interface{}, exist bool) { 65 | shd := c.getShard(key) 66 | shd.mu.RLock() 67 | defer shd.mu.RUnlock() 68 | val, exist = shd.m[key] 69 | if !exist { 70 | shd.m[key] = delta 71 | return delta, exist 72 | } 73 | 74 | v, ok := val.(int64) 75 | if ok { 76 | v += delta 77 | } else { 78 | v = delta 79 | } 80 | shd.m[key] = v 81 | return v, exist 82 | } 83 | 84 | // 获取key保存的值 85 | func (c *ConcurrentDict) Get(key string) (val interface{}, exist bool) { 86 | shd := c.getShard(key) 87 | shd.mu.RLock() 88 | defer shd.mu.RUnlock() 89 | val, exist = shd.m[key] 90 | return 91 | } 92 | 93 | func (c *ConcurrentDict) GetWithLock(key string) (val interface{}, exists bool) { 94 | if c == nil { 95 | panic("dict is nil") 96 | } 97 | shd := c.getShard(key) 98 | val, exists = shd.m[key] 99 | return 100 | } 101 | 102 | // 元素个数 103 | func (c *ConcurrentDict) Count() int { 104 | return int(c.count.Load()) 105 | } 106 | 107 | // 数量+1 108 | func (c *ConcurrentDict) addCount() { 109 | c.count.Add(1) 110 | } 111 | 112 | // 数量-1 113 | func (c *ConcurrentDict) subCount() { 114 | c.count.Add(-1) 115 | } 116 | 117 | // 删除key 118 | func (c *ConcurrentDict) Delete(key string) (interface{}, int) { 119 | shd := c.getShard(key) 120 | shd.mu.Lock() 121 | defer shd.mu.Unlock() 122 | 123 | if val, ok := shd.m[key]; ok { 124 | delete(shd.m, key) 125 | c.subCount() 126 | return val, 1 127 | } 128 | return nil, 0 129 | } 130 | 131 | func (c *ConcurrentDict) DeleteWithLock(key string) (val interface{}, result int) { 132 | shd := c.getShard(key) 133 | // 删除 & 个数 -1 134 | if val, ok := shd.m[key]; ok { 135 | delete(shd.m, key) 136 | c.subCount() 137 | return val, 1 138 | } 139 | // 返回被删除的 value 140 | return val, 0 141 | } 142 | 143 | // 保存key(insert or update) 144 | func (c *ConcurrentDict) Put(key string, val interface{}) int { 145 | shd := c.getShard(key) 146 | shd.mu.Lock() 147 | defer shd.mu.Unlock() 148 | if _, ok := shd.m[key]; ok { 149 | shd.m[key] = val 150 | return 0 // 更新 151 | } 152 | c.addCount() 153 | shd.m[key] = val 154 | return 1 // 插入 155 | } 156 | 157 | // 保存key(insert or update) 158 | func (c *ConcurrentDict) PutWithLock(key string, val interface{}) int { 159 | shd := c.getShard(key) 160 | if _, ok := shd.m[key]; ok { 161 | shd.m[key] = val 162 | return 0 // 更新 163 | } 164 | c.addCount() 165 | shd.m[key] = val 166 | return 1 // 插入 167 | } 168 | 169 | // 保存key( only insert) 170 | func (c *ConcurrentDict) PutIfAbsent(key string, val interface{}) int { 171 | 172 | shd := c.getShard(key) 173 | 174 | shd.mu.Lock() 175 | defer shd.mu.Unlock() 176 | 177 | if _, ok := shd.m[key]; ok { 178 | return 0 179 | } 180 | c.addCount() 181 | shd.m[key] = val 182 | return 1 // 插入 183 | } 184 | 185 | // 保存key( only insert) 186 | func (c *ConcurrentDict) PutIfAbsentWithLock(key string, val interface{}) int { 187 | shd := c.getShard(key) 188 | if _, ok := shd.m[key]; ok { 189 | return 0 190 | } 191 | c.addCount() 192 | shd.m[key] = val 193 | return 1 // 插入 194 | } 195 | 196 | // 保存key (only update) 197 | func (c *ConcurrentDict) PutIfExist(key string, val interface{}) int { 198 | shd := c.getShard(key) 199 | 200 | shd.mu.Lock() 201 | defer shd.mu.Unlock() 202 | if _, ok := shd.m[key]; ok { 203 | shd.m[key] = val 204 | return 1 // 更新 205 | } 206 | return 0 207 | } 208 | 209 | // 保存key (only update) 210 | func (c *ConcurrentDict) PutIfExistWithLock(key string, val interface{}) int { 211 | shd := c.getShard(key) 212 | 213 | if _, ok := shd.m[key]; ok { 214 | shd.m[key] = val 215 | return 1 // 更新 216 | } 217 | return 0 218 | } 219 | 220 | // 遍历 221 | func (c *ConcurrentDict) ForEach(consumer Consumer) { 222 | if c == nil { 223 | panic("dict is nil") 224 | } 225 | for _, sh := range c.shds { 226 | keepContinue := sh.forEach(consumer) 227 | if !keepContinue { 228 | break 229 | } 230 | } 231 | } 232 | 233 | // 加【读写锁】 234 | func (c *ConcurrentDict) RWLock(readKeys, writeKeys []string) { 235 | 236 | // 所有key映射的索引 237 | keys := append(readKeys, writeKeys...) 238 | allIndexs := c.toLockIndex(keys...) 239 | 240 | // 写key映射的索引 241 | writeIndexs := c.toLockIndexMap(writeKeys...) 242 | for _, index := range allIndexs { 243 | _, ok := writeIndexs[index] // 判断是否写 244 | rwMutex := &c.shds[index].mu 245 | if ok { // 写锁 246 | rwMutex.Lock() 247 | } else { // 读锁 248 | rwMutex.RLock() 249 | } 250 | } 251 | } 252 | 253 | // 解【读写锁】 254 | func (c *ConcurrentDict) RWUnLock(readKeys, writeKeys []string) { 255 | // 所有key映射的索引 256 | keys := append(readKeys, writeKeys...) 257 | allIndexs := c.toLockIndex(keys...) 258 | 259 | // 写key映射的索引 260 | writeIndexs := c.toLockIndexMap(writeKeys...) 261 | for _, index := range allIndexs { 262 | _, ok := writeIndexs[index] // 判断是否写 263 | rwMutex := &c.shds[index].mu 264 | if ok { // 写锁 265 | rwMutex.Unlock() 266 | } else { // 读锁 267 | rwMutex.RUnlock() 268 | } 269 | } 270 | } 271 | 272 | func (c *ConcurrentDict) toLockIndex(keys ...string) []uint32 { 273 | mapIndex := make(map[uint32]struct{}) // 去重 274 | for _, key := range keys { 275 | mapIndex[c.index(utils.Fnv32(key))] = struct{}{} // 将key转成索引 276 | } 277 | indices := make([]uint32, 0, len(mapIndex)) 278 | for k := range mapIndex { 279 | indices = append(indices, k) 280 | } 281 | // 对索引排序 282 | sort.Slice(indices, func(i, j int) bool { 283 | return indices[i] < indices[j] 284 | }) 285 | return indices 286 | } 287 | 288 | func (c *ConcurrentDict) toLockIndexMap(keys ...string) map[uint32]struct{} { 289 | 290 | result := make(map[uint32]struct{}) 291 | for _, key := range keys { 292 | result[c.index(utils.Fnv32(key))] = struct{}{} 293 | } 294 | return result 295 | } 296 | -------------------------------------------------------------------------------- /datastruct/dict/dict.go: -------------------------------------------------------------------------------- 1 | package dict 2 | 3 | type Consumer func(key string, val interface{}) bool 4 | -------------------------------------------------------------------------------- /datastruct/list/linkedlist.go: -------------------------------------------------------------------------------- 1 | package list 2 | 3 | import ( 4 | "errors" 5 | ) 6 | 7 | type Consumer func(i int, val interface{}) bool 8 | 9 | type Expected func(actual interface{}) bool 10 | 11 | // 双向链表,实现增Add/删Del/改Modify/查 Get 12 | 13 | type LinkedList struct { 14 | first *node 15 | last *node 16 | 17 | size int 18 | } 19 | 20 | type node struct { 21 | pre *node 22 | next *node 23 | val any 24 | } 25 | 26 | func newNode(val any) *node { 27 | 28 | return &node{val: val} 29 | } 30 | 31 | // Add push new node to the tail 32 | func (l *LinkedList) Add(val interface{}) { 33 | n := newNode(val) 34 | 35 | if l.last == nil { // 空链表 36 | l.first = n 37 | l.last = n 38 | } else { 39 | n.pre = l.last 40 | l.last.next = n 41 | l.last = n 42 | } 43 | l.size++ 44 | } 45 | 46 | func (l *LinkedList) find(index int) *node { 47 | // 要找的节点在链表的前半部分 48 | if index < l.Len()/2 { 49 | n := l.first 50 | for i := 0; i < index; i++ { 51 | n = n.next 52 | } 53 | return n 54 | } 55 | // 要找的节点在链表的后半部分 56 | n := l.last 57 | for i := l.Len() - 1; i > index; i-- { 58 | n = n.pre 59 | } 60 | return n 61 | } 62 | 63 | // 获取指定索引节点的值 64 | func (l *LinkedList) Get(index int) (any, error) { 65 | if index < 0 || index >= l.size { 66 | return nil, errors.New("out of range") 67 | } 68 | n := l.find(index) 69 | return n.val, nil 70 | 71 | } 72 | 73 | // 修改指定节点的值 74 | func (l *LinkedList) Modify(index int, val any) error { 75 | if index < 0 || index >= l.size { 76 | return errors.New("out of range") 77 | } 78 | 79 | n := l.find(index) 80 | n.val = val 81 | return nil 82 | } 83 | 84 | func (l *LinkedList) delNode(n *node) { 85 | // n 的前驱节点 86 | pre := n.pre 87 | // n 的后驱节点 88 | next := n.next 89 | 90 | if pre != nil { 91 | pre.next = next 92 | } else { // 说明n就是第一个节点 93 | l.first = next 94 | } 95 | 96 | if next != nil { 97 | next.pre = pre 98 | } else { // 说明n就是最后一个节点 99 | l.last = pre 100 | } 101 | 102 | // for gc 103 | n.pre = nil 104 | n.next = nil 105 | 106 | l.size-- 107 | } 108 | 109 | // 删除指定节点 110 | func (l *LinkedList) Del(index int) (any, error) { 111 | if index < 0 || index >= l.size { 112 | return nil, errors.New("out of range") 113 | } 114 | n := l.find(index) 115 | l.delNode(n) 116 | return n.val, nil 117 | } 118 | 119 | // 删除最后一个节点 120 | func (l *LinkedList) DelLast() (any, error) { 121 | if l.Len() == 0 { // do nothing 122 | return nil, nil 123 | } 124 | return l.Del(l.Len() - 1) 125 | } 126 | 127 | // 遍历链表中的元素 128 | func (l *LinkedList) ForEach(consumer Consumer) { 129 | i := 0 130 | for n := l.first; n != nil; n = n.next { 131 | if !consumer(i, n.val) { 132 | break 133 | } 134 | } 135 | } 136 | 137 | // 判断是否包含指定值 138 | func (l *LinkedList) Contain(expect Expected) bool { 139 | result := false 140 | l.ForEach(func(index int, val interface{}) bool { 141 | if expect(val) { 142 | result = true 143 | return false 144 | } 145 | return true 146 | }) 147 | return result 148 | } 149 | 150 | // 删除链表中的指定值(所有) 151 | func (l *LinkedList) DelAllByVal(expected Expected) int { 152 | 153 | removed := 0 154 | for n := l.first; n != nil; { 155 | next := n.next 156 | if expected(n.val) { 157 | l.delNode(n) 158 | removed++ 159 | } 160 | n = next 161 | } 162 | return removed 163 | } 164 | 165 | // 链表的长度 166 | func (l *LinkedList) Len() int { 167 | return l.size 168 | } 169 | 170 | // 构建新链表 171 | func NewLinkedList() *LinkedList { 172 | l := &LinkedList{} 173 | return l 174 | } 175 | -------------------------------------------------------------------------------- /datastruct/sortedset/border.go: -------------------------------------------------------------------------------- 1 | package sortedset 2 | 3 | import ( 4 | "errors" 5 | "strconv" 6 | ) 7 | 8 | /* 9 | * ScoreBorder is a struct represents `min` `max` parameter of redis command `ZRANGEBYSCORE` 10 | * can accept: 11 | * int or float value, such as 2.718, 2, -2.718, -2 ... 12 | * exclusive int or float value, such as (2.718, (2, (-2.718, (-2 ... 13 | * infinity: +inf, -inf, inf(same as +inf) 14 | */ 15 | 16 | const ( 17 | scoreNegativeInf int8 = -1 18 | scorePositiveInf int8 = 1 19 | lexNegativeInf int8 = '-' 20 | lexPositiveInf int8 = '+' 21 | ) 22 | 23 | type Border interface { 24 | greater(element *Pair) bool 25 | less(element *Pair) bool 26 | getValue() interface{} 27 | getExclude() bool 28 | isIntersected(max Border) bool 29 | } 30 | 31 | // ScoreBorder represents range of a float value, including: <, <=, >, >=, +inf, -inf 32 | type ScoreBorder struct { 33 | Inf int8 34 | Value float64 35 | Exclude bool // 不包含(排除的意思) 36 | } 37 | 38 | // if max.greater(score) then the score is within the upper border 39 | // do not use min.greater() 40 | func (border *ScoreBorder) greater(element *Pair) bool { 41 | 42 | value := element.Score 43 | if border.Inf == scoreNegativeInf { // -inf 44 | return false 45 | } else if border.Inf == scorePositiveInf { // +inf 46 | return true 47 | } 48 | if border.Exclude { 49 | return border.Value > value 50 | } 51 | return border.Value >= value 52 | } 53 | 54 | func (border *ScoreBorder) less(element *Pair) bool { 55 | value := element.Score 56 | if border.Inf == scoreNegativeInf { // -inf 57 | return true 58 | } else if border.Inf == scorePositiveInf { // +inf 59 | return false 60 | } 61 | if border.Exclude { 62 | return border.Value < value 63 | } 64 | return border.Value <= value 65 | } 66 | 67 | func (border *ScoreBorder) getValue() interface{} { 68 | return border.Value 69 | } 70 | 71 | func (border *ScoreBorder) getExclude() bool { 72 | return border.Exclude 73 | } 74 | 75 | var scorePositiveInfBorder = &ScoreBorder{ 76 | Inf: scorePositiveInf, 77 | } 78 | 79 | var scoreNegativeInfBorder = &ScoreBorder{ 80 | Inf: scoreNegativeInf, 81 | } 82 | 83 | // 模拟score的范围 84 | // ParseScoreBorder creates ScoreBorder from redis arguments 85 | func ParseScoreBorder(s string) (Border, error) { 86 | if s == "inf" || s == "+inf" { 87 | return scorePositiveInfBorder, nil 88 | } 89 | if s == "-inf" { 90 | return scoreNegativeInfBorder, nil 91 | } 92 | if s[0] == '(' { 93 | value, err := strconv.ParseFloat(s[1:], 64) 94 | if err != nil { 95 | return nil, errors.New("min or max is not a float") 96 | } 97 | return &ScoreBorder{ 98 | Inf: 0, 99 | Value: value, 100 | Exclude: true, 101 | }, nil 102 | } 103 | value, err := strconv.ParseFloat(s, 64) 104 | if err != nil { 105 | return nil, errors.New("min or max is not a float") 106 | } 107 | return &ScoreBorder{ 108 | Inf: 0, 109 | Value: value, 110 | Exclude: false, 111 | }, nil 112 | } 113 | 114 | // 校验两个边界是否有重叠 115 | func (border *ScoreBorder) isIntersected(max Border) bool { //是否重叠,重叠无效 116 | minValue := border.Value 117 | maxValue := max.(*ScoreBorder).Value 118 | return minValue > maxValue || (minValue == maxValue && (border.getExclude() || max.getExclude())) // [ min ,max ) 119 | } 120 | 121 | // 模拟字符串的范围 122 | // LexBorder represents range of a string value, including: <, <=, >, >=, +, - 123 | type LexBorder struct { 124 | Inf int8 125 | Value string 126 | Exclude bool 127 | } 128 | 129 | // if max.greater(lex) then the lex is within the upper border 130 | // do not use min.greater() 131 | func (border *LexBorder) greater(element *Pair) bool { 132 | value := element.Member 133 | if border.Inf == lexNegativeInf { // -inf 134 | return false 135 | } else if border.Inf == lexPositiveInf { // +inf 136 | return true 137 | } 138 | if border.Exclude { 139 | return border.Value > value 140 | } 141 | return border.Value >= value 142 | } 143 | 144 | func (border *LexBorder) less(element *Pair) bool { 145 | value := element.Member 146 | if border.Inf == lexNegativeInf { 147 | return true 148 | } else if border.Inf == lexPositiveInf { 149 | return false 150 | } 151 | if border.Exclude { 152 | return border.Value < value 153 | } 154 | return border.Value <= value 155 | } 156 | 157 | func (border *LexBorder) getValue() interface{} { 158 | return border.Value 159 | } 160 | 161 | func (border *LexBorder) getExclude() bool { 162 | return border.Exclude 163 | } 164 | 165 | var lexPositiveInfBorder = &LexBorder{ 166 | Inf: lexPositiveInf, 167 | } 168 | 169 | var lexNegativeInfBorder = &LexBorder{ 170 | Inf: lexNegativeInf, 171 | } 172 | 173 | // ParseLexBorder creates LexBorder from redis arguments 174 | func ParseLexBorder(s string) (Border, error) { 175 | if s == "+" { 176 | return lexPositiveInfBorder, nil 177 | } 178 | if s == "-" { 179 | return lexNegativeInfBorder, nil 180 | } 181 | if s[0] == '(' { 182 | return &LexBorder{ 183 | Inf: 0, 184 | Value: s[1:], 185 | Exclude: true, 186 | }, nil 187 | } 188 | 189 | if s[0] == '[' { 190 | return &LexBorder{ 191 | Inf: 0, 192 | Value: s[1:], 193 | Exclude: false, 194 | }, nil 195 | } 196 | 197 | return nil, errors.New("ERR min or max not valid string range item") 198 | } 199 | 200 | func (border *LexBorder) isIntersected(max Border) bool { 201 | minValue := border.Value 202 | maxValue := max.(*LexBorder).Value 203 | return border.Inf == '+' || minValue > maxValue || (minValue == maxValue && (border.getExclude() || max.getExclude())) 204 | } 205 | -------------------------------------------------------------------------------- /datastruct/sortedset/skiplist_test.go: -------------------------------------------------------------------------------- 1 | package sortedset 2 | 3 | import "testing" 4 | 5 | func TestRandomLevel(t *testing.T) { 6 | m := make(map[int16]int) 7 | for i := 0; i < 10000; i++ { 8 | level := randomLevel() // [1,16] 9 | m[level]++ 10 | } 11 | 12 | // 每个层级之间大约1/2的比例 13 | for i := 0; i <= defaultMaxLevel; i++ { 14 | t.Logf("level %d, count %d", i, m[int16(i)]) 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /datastruct/sortedset/sortedset.go: -------------------------------------------------------------------------------- 1 | package sortedset 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/gofish2020/easyredis/tool/logger" 7 | ) 8 | 9 | // score可以相同,member不能重复 10 | type SortedSet struct { 11 | dict map[string]*Pair // 利用map基于member去重 12 | skl *skiplist // 利用skl进行排序 13 | } 14 | 15 | func NewSortedSet() *SortedSet { 16 | 17 | ss := SortedSet{} 18 | ss.dict = make(map[string]*Pair) 19 | ss.skl = newSkipList() 20 | return &ss 21 | } 22 | 23 | // bool 为true表示新增, false表示修改 24 | func (s *SortedSet) Add(member string, score float64) bool { 25 | pair, ok := s.dict[member] 26 | 27 | s.dict[member] = &Pair{ 28 | Member: member, 29 | Score: score, 30 | } 31 | 32 | // 说明是重复添加 33 | if ok { 34 | // 分值不同 35 | if score != pair.Score { 36 | // 将原来的从跳表中删除 37 | s.skl.remove(pair.Member, pair.Score) 38 | // 插入新值 39 | s.skl.insert(member, score) 40 | } 41 | // 分值相同,do nothing... 42 | return false 43 | } 44 | // 新增 45 | s.skl.insert(member, score) 46 | return true 47 | } 48 | 49 | func (s *SortedSet) Len() int64 { 50 | return int64(len(s.dict)) 51 | } 52 | 53 | func (s *SortedSet) Get(member string) (*Pair, bool) { 54 | pair, ok := s.dict[member] 55 | if !ok { 56 | return nil, false 57 | } 58 | return pair, true 59 | } 60 | 61 | func (s *SortedSet) Remove(member string) bool { 62 | 63 | pair, ok := s.dict[member] 64 | if ok { 65 | s.skl.remove(pair.Member, pair.Score) 66 | delete(s.dict, member) 67 | return true 68 | } 69 | 70 | return false 71 | } 72 | 73 | // 获取在链表中的排序索引号 74 | func (s *SortedSet) GetRank(member string, desc bool) (rank int64) { 75 | pair, ok := s.dict[member] 76 | if !ok { 77 | return -1 78 | } 79 | r := s.skl.getRank(pair.Member, pair.Score) 80 | if desc { 81 | r = s.skl.length - r 82 | } else { 83 | r-- 84 | } 85 | return r 86 | } 87 | 88 | // start / stop都是正数 89 | // 进行范围查询,扫描有序链表[start,stop)索引范围的节点,desc表示按照正序还是倒序 90 | func (s *SortedSet) ForEachByRank(start, stop int64, desc bool, consumer func(pair *Pair) bool) error { 91 | // 节点个数 92 | size := s.Len() 93 | 94 | // start不能越界 95 | if start < 0 || start >= size { 96 | return errors.New("start out of range") 97 | } 98 | // stop不能越界 99 | if start > stop || stop > size { 100 | return errors.New("stop is illegal or out of range") 101 | } 102 | 103 | // 肯定要先找到该范围内的第一个节点 104 | var node *node 105 | if desc { // 表示倒着遍历链表 106 | node = s.skl.tailer //start==0,表示链表的倒数第一个节点 107 | if start > 0 { 108 | // size-start 就是正向的排序编号 109 | node = s.skl.getByRank(size - start) // start表示从链表尾部向前的索引(倒数),start=0表示链表倒数第一个节点,start=1表示链表倒数第二个节点 110 | } 111 | 112 | } else { // 正序遍历链表 113 | // start==0 ,表示正向的第一个节点 114 | node = s.skl.header.levels[0].forward 115 | // 如果索引>0 116 | if start > 0 { 117 | // 从skl链表中找到该节点(skl内部是按照从1开始计数),start索引是从0 118 | node = s.skl.getByRank(start + 1) // 所以这里要+1 119 | } 120 | 121 | } 122 | 123 | // 找到第一个节点后,就按照链表的方式扫描链表 124 | 125 | count := stop - start // 需要扫面的节点个数 126 | 127 | for i := 0; i < int(count); i++ { 128 | 129 | if !consumer(&node.Pair) { 130 | break 131 | } 132 | if desc { 133 | node = node.backward 134 | } else { 135 | node = node.levels[0].forward 136 | } 137 | } 138 | 139 | return nil 140 | } 141 | 142 | // 扫描[start,stop)范围的节点,起始索引从0开始 143 | func (s *SortedSet) RangeByRank(start, stop int64, desc bool) []*Pair { 144 | sliceSize := stop - start 145 | 146 | slice := make([]*Pair, sliceSize) 147 | i := 0 148 | err := s.ForEachByRank(start, stop, desc, func(pair *Pair) bool { 149 | slice[i] = &Pair{ 150 | Member: pair.Member, 151 | Score: pair.Score, 152 | } 153 | i++ 154 | return true 155 | }) 156 | 157 | if err != nil { 158 | logger.Error("RangeByRank err", err) 159 | } 160 | return slice 161 | } 162 | 163 | // // 统计满足条件的节点个数 164 | // func (s *SortedSet) RangeCount(min, max Border) int64 { 165 | // var i int64 = 0 166 | // // 遍历整个链表[0,s.Len()) 167 | // s.ForEachByRank(0, s.Len(), false, func(pair *Pair) bool { 168 | 169 | // gtMin := min.less(pair) 170 | // if !gtMin { // pair < min ,不符合 171 | // return true // 小于左边界,继续遍历 172 | // } 173 | 174 | // ltMax := max.greater(pair) 175 | // if !ltMax { // pair > max 176 | // return false // 超过右边界,停止遍历 177 | // } 178 | // // min <= pair <= max 179 | // i++ 180 | // return true 181 | // }) 182 | 183 | // return i 184 | // } 185 | 186 | func (s *SortedSet) RangeCount(min, max Border) int64 { 187 | 188 | // 找到范围内的第一个节点 189 | var node *node 190 | node = s.skl.getFirstInRange(min, max) 191 | 192 | var i int64 = 0 193 | // 扫描链表 194 | for node != nil { 195 | gtMin := min.less(&node.Pair) 196 | ltMax := max.greater(&node.Pair) 197 | // 不在范围内,跳出 198 | if !gtMin || !ltMax { 199 | break 200 | } 201 | i++ 202 | node = node.levels[0].forward 203 | } 204 | return i 205 | } 206 | 207 | // 扫描[min,max] 范围内的节点,从偏移min + offset位置开始,扫面 count 个元素 208 | func (s *SortedSet) ForEach(min, max Border, offset int64, count int64, desc bool, consumer func(pair *Pair) bool) { 209 | 210 | var node *node 211 | 212 | // 查找边界节点 213 | if desc { 214 | node = s.skl.getLastInRange(min, max) 215 | } else { 216 | node = s.skl.getFirstInRange(min, max) 217 | } 218 | 219 | // 让node偏移offset 220 | for node != nil && offset > 0 { 221 | if desc { 222 | node = node.backward 223 | } else { 224 | node = node.levels[0].forward 225 | } 226 | offset-- 227 | } 228 | 229 | if node == nil { 230 | return 231 | } 232 | 233 | // 扫描limit个元素(count 可能是负数,表示不限制个数,一直扫描到边界位置) 234 | for i := int64(0); i < count || count < 0; i++ { 235 | if !consumer(&node.Pair) { 236 | break 237 | } 238 | 239 | if desc { 240 | node = node.backward 241 | } else { 242 | node = node.levels[0].forward 243 | } 244 | 245 | // 如果下一个为nil,跳出 246 | if node == nil { 247 | break 248 | } 249 | // 判断node值是否在范围内 250 | gtMin := min.less(&node.Pair) 251 | ltMax := max.greater(&node.Pair) 252 | // 不在范围内,跳出 253 | if !gtMin || !ltMax { 254 | break 255 | } 256 | } 257 | } 258 | 259 | func (s *SortedSet) Range(min Border, max Border, offset int64, count int64, desc bool) []*Pair { 260 | if count == 0 || offset < 0 { 261 | return make([]*Pair, 0) 262 | } 263 | slice := make([]*Pair, 0) 264 | s.ForEach(min, max, offset, count, desc, func(element *Pair) bool { 265 | slice = append(slice, element) 266 | return true 267 | }) 268 | return slice 269 | } 270 | 271 | // 删除范围[min,max]的元素,这里的Border表示是Score或者Member 272 | func (s *SortedSet) RemoveRange(min Border, max Border) int64 { 273 | // 从链表中删除 274 | removed := s.skl.RemoveRange(min, max, 0) 275 | 276 | // 从map中删除 277 | for _, pair := range removed { 278 | delete(s.dict, pair.Member) 279 | } 280 | return int64(len(removed)) 281 | } 282 | 283 | // 删除最小的值 284 | func (s *SortedSet) PopMin(count int) []*Pair { 285 | // 获取范围内的最小节点 286 | first := s.skl.getFirstInRange(scoreNegativeInfBorder, scorePositiveInfBorder) 287 | if first == nil { 288 | return nil 289 | } 290 | // 将最小值作为左边界 291 | border := &ScoreBorder{ 292 | Value: first.Score, 293 | Exclude: false, // 包含 294 | } 295 | // 删除范围内的count个元素 296 | removed := s.skl.RemoveRange(border, scorePositiveInfBorder, count) 297 | for _, pair := range removed { 298 | delete(s.dict, pair.Member) 299 | } 300 | return removed 301 | } 302 | 303 | // 表示删除索引[start,stop)的节点 304 | func (s *SortedSet) RemoveByRank(start int64, stop int64) int64 { 305 | 306 | // 跳表的位置编号从1开始 [start+1,stop+1) 307 | removed := s.skl.RemoveRangeByRank(start+1, stop+1) 308 | for _, element := range removed { 309 | delete(s.dict, element.Member) 310 | } 311 | return int64(len(removed)) 312 | } 313 | -------------------------------------------------------------------------------- /doc/1.tcp服务/image-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/1.tcp服务/image-1.png -------------------------------------------------------------------------------- /doc/1.tcp服务/image-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/1.tcp服务/image-2.png -------------------------------------------------------------------------------- /doc/1.tcp服务/image-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/1.tcp服务/image-3.png -------------------------------------------------------------------------------- /doc/1.tcp服务/image-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/1.tcp服务/image-4.png -------------------------------------------------------------------------------- /doc/1.tcp服务/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/1.tcp服务/image.png -------------------------------------------------------------------------------- /doc/10.对象池/image-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/10.对象池/image-1.png -------------------------------------------------------------------------------- /doc/10.对象池/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/10.对象池/image.png -------------------------------------------------------------------------------- /doc/10.对象池/pool.md: -------------------------------------------------------------------------------- 1 | 用11篇文章实现一个可用的Redis服务,姑且叫**EasyRedis**吧,希望通过文章将Redis掰开撕碎了呈现给大家,而不是仅仅停留在八股文的层面,并且有非常爽的感觉,欢迎持续关注学习。 2 | 3 | 项目代码地址: https://github.com/gofish2020/easyredis 欢迎Fork & Star 4 | 5 | - [x] easyredis之TCP服务 6 | - [x] easyredis之网络请求序列化协议(RESP) 7 | - [x] easyredis之内存数据库 8 | - [x] easyredis之过期时间 (时间轮实现) 9 | - [x] easyredis之持久化 (AOF实现) 10 | - [x] easyredis之发布订阅功能 11 | - [x] easyredis之有序集合(跳表实现) 12 | - [x] easyredis之 pipeline 客户端实现 13 | - [x] easyredis之事务(原子性/回滚) 14 | - [x] easyredis之连接池 15 | - [ ] easyredis之分布式集群存储 16 | 17 | 18 | 19 | ### 【第十篇】Redis之连接池 20 | 21 | 通过本篇可以学到什么? 22 | - 通道的应用 23 | - 连接池的封装 24 | 25 | 从本篇开始,实现分布式相关的代码。既然是分布式,那么`redis key`就会分布(分散)在不同的集群节点上。 26 | 27 | ![](image.png) 28 | 29 | 当客户端发送`set key value`命令给`Redis0`服务,通过hash计算如果该key应该保存在`Redis2`服务,那么`Redis0`就要连接`Redis2`服务,并将命令转发给`Redis2`进行处理。 30 | 31 | 在命令的转发的过程中,需要频繁的连接**分布式节点**,所以我们需要先实现连接池的基本功能,复用连接。 32 | 33 | 在[第八篇pipeline客户端](https://mp.weixin.qq.com/s?__biz=MzkwMTE3NTY5MQ==&mid=2247484296&idx=1&sn=44aca704812e0e71386348769c190c03&chksm=c0b9836cf7ce0a7aafb150d7c59415ffbc9a20a865d5baf96bec24dc4f8d66ccbc33decfd4f0&token=1787556969&lang=zh_CN#rd)我们已经实现了**客户端连接**,本篇需要实现一个**池子的功能**将已经使用完的连接缓存起来,等到需要使用的时候,再取出来继续使用。 34 | 35 | 36 | 代码路径`tool/pool/pool.go`,代码量**160行** 37 | 38 | ### 池子结构体定义 39 | - 既然是池子,那定义的数据结构里面肯定要有个**缓冲的变量**,这里就是`idles chan any` 40 | - 一开始池子中肯定是没有对象的,所以需要有个能够创建对象的函数 `newObject` 41 | - 配套有个释放对象的函数`freeObject` 42 | - 池子中的对象不可能让他无限的增多,当达到`activeCount`个对象的时候,就不再继续用`newObject`生成新对象,需要等之前的对象回收以后,才能获取到对象(这里不理解往下继续看) 43 | 44 | ```go 45 | 46 | type Pool struct { 47 | Config 48 | 49 | // 创建对象 50 | newObject func() (any, error) 51 | // 释放对象 52 | freeObject func(x any) 53 | 54 | // 空闲对象池 55 | idles chan any 56 | 57 | mu sync.Mutex 58 | activeCount int // 已经创建的对象个数 59 | waiting []chan any // 阻塞等待 60 | 61 | closed bool // 是否已关闭 62 | } 63 | 64 | 65 | func NewPool(new func() (any, error), free func(x any), conf Config) *Pool { 66 | 67 | if new == nil { 68 | logger.Error("NewPool argument new func is nil") 69 | return nil 70 | } 71 | 72 | if free == nil { 73 | free = func(x any) {} 74 | } 75 | 76 | p := Pool{ 77 | Config: conf, 78 | newObject: new, 79 | freeObject: free, 80 | activeCount: 0, 81 | closed: false, 82 | } 83 | p.idles = make(chan any, p.MaxIdles) 84 | return &p 85 | } 86 | 87 | ``` 88 | 89 | ### 从池子中获取对象 90 | 91 | - `p.mu.Lock()`加锁(race condition) 92 | - 从空闲缓冲`idles`中获取一个之前缓冲的对象 93 | - 如果没有获取到就调用`p.getOne()`新创建一个 94 | - 在`func (p *Pool) getOne() (any, error)`函数中,会判断当前池子中是否(历史上)已经创建了足够多的对象`p.activeCount >= p.Config.MaxActive `,那就不创建新对象,**阻塞等待回收**;否则调用`newObject`函数创建新对象 95 | 96 | ```go 97 | func (p *Pool) Get() (any, error) { 98 | p.mu.Lock() 99 | if p.closed { 100 | p.mu.Unlock() 101 | return nil, ErrClosed 102 | } 103 | select { 104 | case x := <-p.idles: // 从空闲中获取 105 | p.mu.Unlock() // 解锁 106 | return x, nil 107 | default: 108 | return p.getOne() // 获取一个新的 109 | } 110 | } 111 | 112 | func (p *Pool) getOne() (any, error) { 113 | 114 | // 说明已经创建了太多对象 115 | if p.activeCount >= p.Config.MaxActive { 116 | 117 | wait := make(chan any, 1) 118 | p.waiting = append(p.waiting, wait) 119 | p.mu.Unlock() 120 | // 阻塞等待 121 | x, ok := <-wait 122 | if !ok { 123 | return nil, ErrClosed 124 | } 125 | return x, nil 126 | } 127 | 128 | p.activeCount++ 129 | p.mu.Unlock() 130 | // 创建新对象 131 | x, err := p.newObject() 132 | if err != nil { 133 | p.mu.Lock() 134 | p.activeCount-- 135 | p.mu.Unlock() 136 | return nil, err 137 | } 138 | return x, nil 139 | } 140 | 141 | 142 | ``` 143 | 144 | ### 池子对象回收 145 | 146 | 回收的过程就是对象缓存的过程,当然也要有个“度” 147 | - 先加锁 148 | - 回收前先判断是否有阻塞等待回收`len(p.waiting) > 0`,这里的逻辑和上面的等待阻塞逻辑对应起来了 149 | - 如果没有阻塞等待的,那就直接将对象保存到缓冲中`idles`中 150 | - 这里还有一个逻辑,缓冲有个大小限制(不可能无限的缓冲,多余不使用的对象,我们将它释放了,占用内存也没啥意义) 151 | 152 | ```go 153 | func (p *Pool) Put(x any) { 154 | p.mu.Lock() 155 | if p.closed { 156 | p.mu.Unlock() 157 | p.freeObject(x) // 直接释放 158 | return 159 | } 160 | 161 | //1.先判断等待中 162 | if len(p.waiting) > 0 { 163 | // 弹出一个(从头部) 164 | wait := p.waiting[0] 165 | temp := make([]chan any, len(p.waiting)-1) 166 | copy(temp, p.waiting[1:]) 167 | p.waiting = temp 168 | wait <- x // 取消阻塞 169 | p.mu.Unlock() 170 | return 171 | 172 | } 173 | // 2.直接放回空闲缓冲 174 | select { 175 | case p.idles <- x: 176 | p.mu.Unlock() 177 | default: // 说明空闲已满 178 | p.activeCount-- // 对象个数-1 179 | p.mu.Unlock() 180 | p.freeObject(x) // 释放 181 | } 182 | 183 | } 184 | 185 | ``` 186 | 187 | 188 | 189 | ### 再次封装(socket连接池) 190 | 191 | ![](image-1.png) 192 | 193 | 上面的代码已经完全实现了一个池子的功能;但是我们在实际使用的时候,每个ip地址对应一个连接池,所以这里又增加了一个结构体`RedisConnPool`,结合上面的池子功能,再配合之前的pipleline客户端的功能,实现socket连接池。 194 | 195 | 196 | 代码路径:`cluster/conn_pool.go` 197 | 代码逻辑: 198 | - 用一个字典key表示ip地址,value表示上面实现的池对象 199 | - `GetConn`获取一个ip地址对应的连接 200 | - `ReturnConn`归还连接到连接池中 201 | ```go 202 | type RedisConnPool struct { 203 | connDict *dict.ConcurrentDict // addr -> *pool.Pool 204 | } 205 | 206 | func NewRedisConnPool() *RedisConnPool { 207 | 208 | return &RedisConnPool{ 209 | connDict: dict.NewConcurrentDict(16), 210 | } 211 | } 212 | 213 | func (r *RedisConnPool) GetConn(addr string) (*client.RedisClent, error) { 214 | 215 | var connectionPool *pool.Pool // 对象池 216 | 217 | // 通过不同的地址addr,获取不同的对象池 218 | raw, ok := r.connDict.Get(addr) 219 | if ok { 220 | connectionPool = raw.(*pool.Pool) 221 | } else { 222 | 223 | // 创建对象函数 224 | newClient := func() (any, error) { 225 | // redis的客户端连接 226 | cli, err := client.NewRedisClient(addr) 227 | if err != nil { 228 | return nil, err 229 | } 230 | // 启动 231 | cli.Start() 232 | if conf.GlobalConfig.RequirePass != "" { // 说明服务需要密码 233 | reply, err := cli.Send(aof.Auth([]byte(conf.GlobalConfig.RequirePass))) 234 | if err != nil { 235 | return nil, err 236 | } 237 | if !protocol.IsOKReply(reply) { 238 | return nil, errors.New("auth failed:" + string(reply.ToBytes())) 239 | } 240 | return cli, nil 241 | } 242 | return cli, nil 243 | } 244 | 245 | // 释放对象函数 246 | freeClient := func(x any) { 247 | cli, ok := x.(*client.RedisClent) 248 | if ok { 249 | cli.Stop() // 释放 250 | } 251 | } 252 | 253 | // 针对addr地址,创建一个新的对象池 254 | connectionPool = pool.NewPool(newClient, freeClient, pool.Config{ 255 | MaxIdles: 1, 256 | MaxActive: 20, 257 | }) 258 | // addr -> *pool.Pool 259 | r.connDict.Put(addr, connectionPool) 260 | } 261 | 262 | // 从对象池中获取一个对象 263 | raw, err := connectionPool.Get() 264 | if err != nil { 265 | return nil, err 266 | } 267 | conn, ok := raw.(*client.RedisClent) 268 | if !ok { 269 | return nil, errors.New("connection pool make wrong type") 270 | } 271 | return conn, nil 272 | } 273 | 274 | func (r *RedisConnPool) ReturnConn(peer string, cli *client.RedisClent) error { 275 | raw, ok := r.connDict.Get(peer) 276 | if !ok { 277 | return errors.New("connection pool not found") 278 | } 279 | raw.(*pool.Pool).Put(cli) 280 | return nil 281 | } 282 | 283 | ``` -------------------------------------------------------------------------------- /doc/11.分布式集群/cluster.md: -------------------------------------------------------------------------------- 1 | 2 | 用11篇文章实现一个可用的Redis服务,姑且叫**EasyRedis**吧,希望通过文章将Redis掰开撕碎了呈现给大家,而不是仅仅停留在八股文的层面,并且有非常爽的感觉,欢迎持续关注学习。 3 | 4 | 项目代码地址: https://github.com/gofish2020/easyredis 欢迎Fork & Star 5 | 6 | - [x] easyredis之TCP服务 7 | - [x] easyredis之网络请求序列化协议(RESP) 8 | - [x] easyredis之内存数据库 9 | - [x] easyredis之过期时间 (时间轮实现) 10 | - [x] easyredis之持久化 (AOF实现) 11 | - [x] easyredis之发布订阅功能 12 | - [x] easyredis之有序集合(跳表实现) 13 | - [x] easyredis之 pipeline 客户端实现 14 | - [x] easyredis之事务(原子性/回滚) 15 | - [x] easyredis之连接池 16 | - [x] easyredis之分布式集群存储 17 | - [ ] easyredis之分布式事务 18 | 19 | ## 分布式集群 20 | 21 | 22 | ### 一致性hash算法 23 | 24 | 为什么需要一致性 hash? 25 | 在采用分片方式建立分布式缓存时,我们面临的第一个问题是如何决定存储数据的节点。最自然的方式是参考 hash 表的做法,假设集群中存在 n 个节点,我们用 `node = hashCode(key) % n` 来决定所属的节点。 26 | 27 | 普通 hash 算法解决了如何选择节点的问题,但在分布式系统中经常出现增加节点或某个节点宕机的情况。若节点数 n 发生变化, 大多数 `key` 根据 `node = hashCode(key) % n` 计算出的节点都会改变。这意味着若要在 n 变化后维持系统正常运转,需要将大多数数据在节点间进行重新分布。这个操作会消耗大量的时间和带宽等资源,这在生产环境下是不可接受的。 28 | 29 | 算法原理 30 | 一致性 hash 算法的目的是在节点数量 n 变化时, 使尽可能少的 key 需要进行节点间重新分布。一致性 hash 算法将数据 key 和服务器地址 addr 散列到 2^32 的空间中。 31 | 32 | 我们将 2^32 个整数首尾相连形成一个环,首先计算服务器地址 addr 的 hash 值放置在环上。然后计算 key 的 hash 值放置在环上,**顺时针查找**,将数据放在找到的的第一个节点上。 33 | ![](image.png) 34 | 35 | `key1 key4`归属于 `192.168.1.20`节点 36 | `key2`归属于 `192.168.1.21`节点 37 | `key3`归属于 `192.168.1.23`节点 38 | 39 | 在增加或删除节点时只有该节点附近的数据需要重新分布,从而解决了上述问题。 40 | ![](image-1.png) 41 | 新增 节点`192.168.1.24` 后,`key4` 从 `192.168.1.20` 转移到 `192.168.1.24`其它 key 不变 42 | 43 | 44 | 一般来说环上的节点越多数据分布越均匀,不过我们不需要真的增加一台服务器,只需要将实际的服务器节点映射为几个**虚拟节点**放在环上即可。 45 | 46 | 47 | 48 | ### 代码实现 49 | 50 | 代码路径`tool/consistenthash/consistenthash.go` 51 | 52 | 数据结构体定义: 53 | 54 | ```go 55 | type HashFunc func(data []byte) uint32 56 | 57 | type Map struct { 58 | hashFunc HashFunc // 计算hash函数 59 | replicas int // 每个节点的虚拟节点数量 60 | hashValue []int // hash值 61 | hashMap map[int]string // hash值映射的真实节点 62 | } 63 | 64 | /* 65 | replicas:副本数量 66 | fn:hash函数 67 | */ 68 | func New(replicas int, fn HashFunc) *Map { 69 | m := &Map{ 70 | replicas: replicas, 71 | hashFunc: fn, 72 | hashMap: make(map[int]string), 73 | } 74 | if m.hashFunc == nil { 75 | m.hashFunc = crc32.ChecksumIEEE 76 | } 77 | return m 78 | } 79 | ``` 80 | 81 | 服务启动时,添加主机节点 82 | 83 | ```go 84 | 85 | // 添加 节点 86 | func (m *Map) Add(ipAddrs ...string) { 87 | for _, ipAddr := range ipAddrs { 88 | if ipAddr == "" { 89 | continue 90 | } 91 | // 每个ipAddr 生成 m.replicas个哈希值副本 92 | for i := 0; i < m.replicas; i++ { 93 | hash := int(m.hashFunc([]byte(strconv.Itoa(i) + ipAddr))) 94 | // 记录hash值 95 | m.hashValue = append(m.hashValue, hash) 96 | // 映射hash为同一个ipAddr 97 | m.hashMap[hash] = ipAddr 98 | } 99 | } 100 | sort.Ints(m.hashValue) 101 | } 102 | ``` 103 | 104 | 获取key归属的节点 105 | 106 | ```go 107 | 108 | // Get gets the closest item in the hash to the provided key. 109 | func (m *Map) Get(key string) string { 110 | if m.IsEmpty() { 111 | return "" 112 | } 113 | 114 | partitionKey := getPartitionKey(key) 115 | hash := int(m.hashFunc([]byte(partitionKey))) 116 | 117 | // 查找 m.keys中第一个大于or等于hash值的元素索引 118 | idx := sort.Search(len(m.hashValue), func(i int) bool { return m.hashValue[i] >= hash }) // 119 | 120 | // 表示找了一圈没有找到大于or等于hash值的元素,那么默认是第0号元素 121 | if idx == len(m.hashValue) { 122 | idx = 0 123 | } 124 | 125 | // 返回 key应该存储的ipAddr 126 | return m.hashMap[m.hashValue[idx]] 127 | } 128 | 129 | // support hash tag example :{key} 130 | func getPartitionKey(key string) string { 131 | beg := strings.Index(key, "{") 132 | if beg == -1 { 133 | return key 134 | } 135 | end := strings.Index(key, "}") 136 | if end == -1 || end == beg+1 { 137 | return key 138 | } 139 | return key[beg+1 : end] 140 | } 141 | ``` 142 | 143 | 144 | ### 集群实现 145 | 146 | 代码路径 `cluster/cluster.go` 147 | 148 | 149 | 集群启动的时候,基于配置文件中的`peers`,初始化一致性hash对象`consistHash *consistenthash.Map` 150 | ```go 151 | const ( 152 | replicas = 100 // 副本数量 153 | ) 154 | 155 | type Cluster struct { 156 | // 当前的ip地址 157 | self string 158 | // socket连接池 159 | clientFactory *RedisConnPool 160 | // Redis存储引擎 161 | engine *engine.Engine 162 | 163 | // 一致性hash 164 | consistHash *consistenthash.Map 165 | } 166 | 167 | func NewCluster() *Cluster { 168 | cluster := Cluster{ 169 | clientFactory: NewRedisConnPool(), 170 | engine: engine.NewEngine(), 171 | consistHash: consistenthash.New(replicas, nil), 172 | self: conf.GlobalConfig.Self, 173 | } 174 | 175 | // 一致性hash初始化 176 | contains := make(map[string]struct{}) 177 | peers := make([]string, 0, len(conf.GlobalConfig.Peers)+1) 178 | // 去重 179 | for _, peer := range conf.GlobalConfig.Peers { 180 | if _, ok := contains[peer]; ok { 181 | continue 182 | } 183 | peers = append(peers, peer) 184 | } 185 | peers = append(peers, cluster.self) 186 | cluster.consistHash.Add(peers...) 187 | return &cluster 188 | } 189 | 190 | ``` 191 | 192 | 当节点接收到客户端发送来的Redis命令的时候,从注册中心`clusterRouter`,获取命令处理函数 193 | ```go 194 | func (cluster *Cluster) Exec(c abstract.Connection, redisCommand [][]byte) (result protocol.Reply) { 195 | defer func() { 196 | if err := recover(); err != nil { 197 | logger.Warn(fmt.Sprintf("error occurs: %v\n%s", err, string(debug.Stack()))) 198 | result = protocol.NewUnknownErrReply() 199 | } 200 | }() 201 | 202 | name := strings.ToLower(string(redisCommand[0])) 203 | routerFunc, ok := clusterRouter[name] 204 | if !ok { 205 | return protocol.NewGenericErrReply("unknown command '" + name + "' or not support command in cluster mode") 206 | } 207 | return routerFunc(cluster, c, redisCommand) 208 | } 209 | ``` 210 | 211 | 这里暂时只注册了`set get `命令,在处理函数`defultFunc`中,会调用`cluster.consistHash.Get(key)`函数基于一致性hash算法,计算key应该由哪个节点处理(其实就是节点的ip地址) 212 | 213 | ```go 214 | 215 | type clusterFunc func(cluster *Cluster, conn abstract.Connection, args [][]byte) protocol.Reply 216 | 217 | var clusterRouter = make(map[string]clusterFunc) 218 | 219 | func init() { 220 | 221 | clusterRouter["set"] = defultFunc 222 | clusterRouter["get"] = defultFunc 223 | } 224 | 225 | func defultFunc(cluster *Cluster, conn abstract.Connection, redisCommand [][]byte) protocol.Reply { 226 | key := string(redisCommand[1]) 227 | peer := cluster.consistHash.Get(key) 228 | return cluster.Relay(peer, conn, redisCommand) // 将命令转发 229 | 230 | } 231 | ``` 232 | 233 | 234 | 最后在 `Relay`函数中,基于`peer string`参数,判断该ip地址是当前节点的ip还是其他的节点ip;如果是远程节点,将使用上篇文件介绍的连接池,连接节点并将命令转发 235 | ```go 236 | func (cluster *Cluster) Relay(peer string, conn abstract.Connection, redisCommand [][]byte) protocol.Reply { 237 | 238 | // ******本地执行****** 239 | if cluster.self == peer { 240 | return cluster.engine.Exec(conn, redisCommand) 241 | } 242 | 243 | // ******发送到远端执行****** 244 | 245 | client, err := cluster.clientFactory.GetConn(peer) // 从连接池中获取一个连接 246 | if err != nil { 247 | logger.Error(err) 248 | return protocol.NewGenericErrReply(err.Error()) 249 | } 250 | 251 | defer func() { 252 | cluster.clientFactory.ReturnConn(peer, client) // 归还连接 253 | }() 254 | 255 | logger.Debugf("命令:%q,转发至ip:%s", protocol.NewMultiBulkReply(redisCommand).ToBytes(), peer) 256 | reply, err := client.Send(redisCommand) // 发送命令 257 | if err != nil { 258 | logger.Error(err) 259 | return protocol.NewGenericErrReply(err.Error()) 260 | } 261 | 262 | return reply 263 | } 264 | ``` 265 | 266 | ### 效果图如下: 267 | 268 | ![](image-2.png) 269 | 270 | 271 | 272 | 这里实现的集群其实比较简陋,集群的元数据信息都是在配置文件中写死,实际线上的产品会使用 `gossip or raft`协议维护集群(也就是可以动态的增加/较少节点),这个等我学会了,会再把这块重新写一下。 -------------------------------------------------------------------------------- /doc/11.分布式集群/image-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/11.分布式集群/image-1.png -------------------------------------------------------------------------------- /doc/11.分布式集群/image-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/11.分布式集群/image-2.png -------------------------------------------------------------------------------- /doc/11.分布式集群/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/11.分布式集群/image.png -------------------------------------------------------------------------------- /doc/12.分布式事务TCC/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/12.分布式事务TCC/image.png -------------------------------------------------------------------------------- /doc/12.分布式事务TCC/transaction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/12.分布式事务TCC/transaction.png -------------------------------------------------------------------------------- /doc/2.Redis序列化协议/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/2.Redis序列化协议/image.png -------------------------------------------------------------------------------- /doc/3.内存数据库/image-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/3.内存数据库/image-1.png -------------------------------------------------------------------------------- /doc/3.内存数据库/image-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/3.内存数据库/image-2.png -------------------------------------------------------------------------------- /doc/3.内存数据库/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/3.内存数据库/image.png -------------------------------------------------------------------------------- /doc/4.延迟算法(时间轮)/image-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/4.延迟算法(时间轮)/image-1.png -------------------------------------------------------------------------------- /doc/4.延迟算法(时间轮)/image-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/4.延迟算法(时间轮)/image-2.png -------------------------------------------------------------------------------- /doc/4.延迟算法(时间轮)/image-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/4.延迟算法(时间轮)/image-3.png -------------------------------------------------------------------------------- /doc/4.延迟算法(时间轮)/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/4.延迟算法(时间轮)/image.png -------------------------------------------------------------------------------- /doc/4.延迟算法(时间轮)/时间轮.md: -------------------------------------------------------------------------------- 1 | # Golang实现自己的Redis(过期时间) 2 | 3 | 用11篇文章实现一个可用的Redis服务,姑且叫**EasyRedis**吧,希望通过文章将Redis掰开撕碎了呈现给大家,而不是仅仅停留在八股文的层面,并且有非常爽的感觉,欢迎持续关注学习。 4 | 5 | 项目代码地址: https://github.com/gofish2020/easyredis 欢迎Fork & Star 6 | 7 | - [x] easyredis之TCP服务 8 | - [x] easyredis之网络请求序列化协议(RESP) 9 | - [x] easyredis之内存数据库 10 | - [x] easyredis之过期时间 (时间轮实现) 11 | - [ ] easyredis之持久化 (AOF实现) 12 | - [ ] easyredis之发布订阅功能 13 | - [ ] easyredis之有序集合(跳表实现) 14 | - [ ] easyredis之 pipeline 客户端实现 15 | - [ ] easyredis之事务(原子性/回滚) 16 | - [ ] easyredis之连接池 17 | - [ ] easyredis之分布式集群存储 18 | 19 | 20 | ## 【第四篇】EasyRedis之过期时间 21 | 22 | 在使用`Redis`的时候经常会对缓存设定过期时间,例如`set key value ex 3`,设定过期时间`3s`,等到过期以后,我们再执行`get key`正常情况下是得不到数据的。不同的key会设定不同的过期时间`1s 5s 2s`等等。按照八股文我们知道`key`过期的时候,有两种删除策略: 23 | - 惰性删除:不主动删除过期key,当访问该key的时候,如果发现过期了再删除 24 | 好处:对CPU友好,不用频繁执行删除,但是对内存不友好,都过期了还占用内存 25 | - 定时删除:主动删除key,到了key的过期时间,立即执行删除 26 | 好处:对内存友好,可以缓解内存压力,对CPU不友好,需要频繁的执行删除 27 | 28 | 所以redis就把两种策略都实现了,我们看下代码如何使下? 29 | 30 | 31 | ### 惰性删除 32 | 本质就是访问的时候判断下key是否过期,过期就删除并返回空。 33 | 代码路径`engine/database.go` 34 | 在获取key的值时候,我们会执行一次 `db.IsExpire(key)`判断key是否过期 35 | ```go 36 | func (db *DB) GetEntity(key string) (*payload.DataEntity, bool) { 37 | 38 | // key 不存在 39 | val, exist := db.dataDict.Get(key) 40 | if !exist { 41 | return nil, false 42 | } 43 | // key是否过期(主动检测一次) 44 | if db.IsExpire(key) { 45 | return nil, false 46 | } 47 | // 返回内存数据 48 | dataEntity, ok := val.(*payload.DataEntity) 49 | if !ok { 50 | return nil, false 51 | } 52 | return dataEntity, true 53 | } 54 | ``` 55 | 就是从过期字典`ttlDict`中获取key的过期时间 56 | - 如果没有获取到,说明没有设定过期时间(do nothing) 57 | - 如果有过期时间,并且时间已经过期,主动删除之 58 | 59 | ```go 60 | // 判断key是否已过期 61 | func (db *DB) IsExpire(key string) bool { 62 | val, result := db.ttlDict.Get(key) 63 | if !result { 64 | return false 65 | } 66 | expireTime, _ := val.(time.Time) 67 | isExpire := time.Now().After(expireTime) 68 | if isExpire { // 如果过期,主动删除 69 | db.Remove(key) 70 | } 71 | return isExpire 72 | } 73 | 74 | ``` 75 | 76 | ### 定时删除 77 | 78 | 本质是对key设定一个过期时间,时间一到立即执行删除的任务。 79 | 正常的思路肯定是设定一个固定的定时器,例如`3s`检测一次,这种思路可以,但是存在一个问题, 80 | - 如果key的过期时间为`1s`,那你`3s`才检测是否太不够及时了? 81 | - 那就把检测间隔设定为`1s`吧,那如果`key`的过期时间都为`3s`,到执行时间检测一遍发现任务都没过期,那不就白白浪费CPU时间了吗? 82 | 83 | 这就要推出我们的时间轮算法了,时间轮算法就是在模拟现实世界**钟表**的原理 84 | 85 | ![](image.png) 86 | 87 | - 我想里面增加2个3s的任务,那就将任务添加到距离当前位置`pos + 3`的位置 88 | - 同时再加1个5s的任务,那就将任务添加到距离当前位置`pos + 5`的位置 89 | 90 | ![](image-1.png) 91 | 当钟表的指针指向`pos + 3`的位置,就执行**任务链表**的任务即可。 92 | 因为钟表是循环往复的运行,那如果我再添加11s的任务,可以发现该任务也是放置到 `pos+3`的位置,那任务就要区分下,到底是3s的任务还是11s的任务 93 | 94 | ![](image-2.png) 95 | 所以里面又有了一个`circle`的标记,表示当前任务是第几圈的任务 96 | 97 | 98 | 代码路径`tool/timewheel` 99 | 100 | 代码中通过切片模型环,通过链表模拟任务链表 101 | ```go 102 | // 循环队列 + 链表 103 | type TimeWheel struct { 104 | 105 | // 间隔 106 | interval time.Duration 107 | // 定时器 108 | ticker *time.Ticker 109 | 110 | // 游标 111 | curSlotPos int 112 | // 循环队列大小 113 | slotNum int 114 | // 底层存储 115 | slots []*list.List 116 | m map[string]*taskPos 117 | 118 | // 任务通道 119 | addChannel chan *task 120 | cacelChannel chan string 121 | // 停止 122 | stopChannel chan struct{} 123 | } 124 | 125 | ``` 126 | 127 | 当添加任务的时候,需要通过延迟时间计算当前任务的圈数`circle` 128 | 129 | ```go 130 | 131 | func (tw *TimeWheel) posAndCircle(d time.Duration) (pos, circle int) { 132 | 133 | // 延迟(秒) 134 | delaySecond := int(d.Seconds()) 135 | // 间隔(秒) 136 | intervalSecond := int(tw.interval.Seconds()) 137 | // delaySecond/intervalSecond 表示从curSlotPos位置偏移 138 | pos = (tw.curSlotPos + delaySecond/intervalSecond) % tw.slotNum 139 | circle = (delaySecond / intervalSecond) / tw.slotNum 140 | return 141 | } 142 | 143 | func (tw *TimeWheel) addTask(t *task) { 144 | 145 | // 定位任务应该保存在循环队列的位置 & 圈数 146 | pos, circle := tw.posAndCircle(t.delay) 147 | t.circle = circle 148 | 149 | // 将任务保存到循环队列pos位置 150 | ele := tw.slots[pos].PushBack(t) 151 | // 在map中记录 key -> { pos, ele } 的映射 152 | if t.key != "" { 153 | // 已经存在重复的key 154 | if _, ok := tw.m[t.key]; ok { 155 | tw.cancelTask(t.key) 156 | } 157 | tw.m[t.key] = &taskPos{pos: pos, ele: ele} 158 | } 159 | } 160 | 161 | ``` 162 | 代码中注释的很清晰,也就**100多行**建议看代码结合上图体会下(很简单) 163 | 164 | ### 额外补充 165 | 166 | 我们在执行`set key value ex 3`的时候,先设定过期时间为3s,但是在1s的时候,我们又执行了`set key value`,请问key还会过期吗?? 167 | 168 | 答案:不会过期了。相当于对key去掉了过期时间。所以在代码处理中,我们需要考虑这种情况,重复设定的问题 169 | 170 | 171 | 代码细节位于`engine/string.go` set命令处理函数`func cmdSet(db *DB, args [][]byte) protocol.Reply `的尾部位置 172 | ![](image-3.png) -------------------------------------------------------------------------------- /doc/5.持久化之AOF/image-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/5.持久化之AOF/image-1.png -------------------------------------------------------------------------------- /doc/5.持久化之AOF/image-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/5.持久化之AOF/image-2.png -------------------------------------------------------------------------------- /doc/5.持久化之AOF/image-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/5.持久化之AOF/image-3.png -------------------------------------------------------------------------------- /doc/5.持久化之AOF/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/5.持久化之AOF/image.png -------------------------------------------------------------------------------- /doc/6.发布订阅/image-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/6.发布订阅/image-1.png -------------------------------------------------------------------------------- /doc/6.发布订阅/image-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/6.发布订阅/image-2.png -------------------------------------------------------------------------------- /doc/6.发布订阅/image-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/6.发布订阅/image-3.png -------------------------------------------------------------------------------- /doc/6.发布订阅/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/6.发布订阅/image.png -------------------------------------------------------------------------------- /doc/6.发布订阅/发布订阅.md: -------------------------------------------------------------------------------- 1 | # Golang实现自己的Redis(发布订阅功能) 2 | 3 | 用11篇文章实现一个可用的Redis服务,姑且叫**EasyRedis**吧,希望通过文章将Redis掰开撕碎了呈现给大家,而不是仅仅停留在八股文的层面,并且有非常爽的感觉,欢迎持续关注学习。 4 | 5 | 项目代码地址: https://github.com/gofish2020/easyredis 欢迎Fork & Star 6 | 7 | - [x] easyredis之TCP服务 8 | - [x] easyredis之网络请求序列化协议(RESP) 9 | - [x] easyredis之内存数据库 10 | - [x] easyredis之过期时间 (时间轮实现) 11 | - [x] easyredis之持久化 (AOF实现) 12 | - [x] easyredis之发布订阅功能 13 | - [ ] easyredis之有序集合(跳表实现) 14 | - [ ] easyredis之 pipeline 客户端实现 15 | - [ ] easyredis之事务(原子性/回滚) 16 | - [ ] easyredis之连接池 17 | - [ ] easyredis之分布式集群存储 18 | 19 | ## 【第六篇】EasyRedis之发布订阅 20 | 21 | 代码路径: `pubhub/pubhub.go`这个代码很简单,总共就**200行** 22 | 23 | 发布订阅的基本原理:客户端A/B/C订阅通道,客户端D往通道中发送消息后,客户端A/B/C可以接收到通道中的消息 24 | 25 | ![](image.png) 26 | 27 | **效果演示**: 28 | ![](image-3.png) 29 | 30 | 31 | 底层实现的数据结构采用`map + list`,map中的`key`表示channel,value则用list来存储同一个channel下的多个客户端`clientN` 32 | 33 | ```go 34 | 35 | type Pubhub struct { 36 | 37 | // 自定义实现的map 38 | dataDict dict.ConcurrentDict 39 | 40 | // 该锁的颗粒度太大 41 | //locker sync.RWMutex 42 | 43 | locker *locker.Locker // 自定义一个分布锁 44 | } 45 | ``` 46 | - `dataDict`就是我们自己实现的`map` 47 | - `locker` 用来对操作**同一个链表**的不同客户端加锁,避免并发问题 48 | 49 | 50 | ### 订阅Subscribe 51 | 52 | - 获取客户端发送来的通道名 53 | - 加锁(锁的原理看文章最后) 54 | - 遍历通道,获取该通道下的客户端链表 55 | - 将当前的客户端加入到链表中即可(前提:没有订阅过) 56 | 57 | ```go 58 | // SUBSCRIBE channel [channel ...] 59 | func (p *Pubhub) Subscribe(c abstract.Connection, args [][]byte) protocol.Reply { 60 | 61 | if len(args) < 1 { 62 | return protocol.NewArgNumErrReply("subscribe") 63 | } 64 | 65 | // 通道名 66 | keys := make([]string, 0, len(args)) 67 | for _, arg := range args { 68 | keys = append(keys, string(arg)) 69 | } 70 | // 加锁 71 | p.locker.Locks(keys...) 72 | defer p.locker.Unlocks(keys...) 73 | 74 | for _, arg := range args { 75 | chanName := string(arg) 76 | // 记录当前客户端连接订阅的通道 77 | c.Subscribe(chanName) 78 | 79 | // 双向链表,记录通道下的客户端连接 80 | var l *list.LinkedList 81 | raw, exist := p.dataDict.Get(chanName) 82 | if !exist { // 说明该channel第一次使用 83 | l = list.NewLinkedList() 84 | p.dataDict.Put(chanName, l) 85 | } else { 86 | l, _ = raw.(*list.LinkedList) 87 | } 88 | 89 | // 未订阅 90 | if !l.Contain(func(actual interface{}) bool { 91 | return c == actual 92 | }) { 93 | // 如果不重复,那就记录订阅 94 | logger.Debug("subscribe channel [" + chanName + "] success") 95 | l.Add(c) 96 | } 97 | 98 | // 回复客户端消息 99 | _, err := c.Write(channelMsg(_subscribe, chanName, c.SubCount())) 100 | if err != nil { 101 | logger.Warn(err) 102 | } 103 | } 104 | 105 | return protocol.NewNoReply() 106 | } 107 | ``` 108 | 109 | ### 取消订阅 Unsubscribe 110 | 111 | - 获取通道名(如果没有指定,就是取消当前客户端的所有通道) 112 | - 加锁(锁的原理看文章最后) 113 | - 获取该通道下的客户端链表 114 | - 从链表中删除当前的客户端 115 | 116 | ```go 117 | // 取消订阅 118 | // unsubscribes itself from all the channels using the UNSUBSCRIBE command without additional arguments 119 | func (p *Pubhub) Unsubscribe(c abstract.Connection, args [][]byte) protocol.Reply { 120 | 121 | var channels []string 122 | if len(args) < 1 { // 取消全部 123 | channels = c.GetChannels() 124 | } else { // 取消指定channel 125 | channels = make([]string, len(args)) 126 | for i, v := range args { 127 | channels[i] = string(v) 128 | } 129 | } 130 | 131 | p.locker.Locks(channels...) 132 | defer p.locker.Unlocks(channels...) 133 | 134 | // 说明已经没有订阅的通道 135 | if len(channels) == 0 { 136 | c.Write(noChannelMsg()) 137 | } 138 | for _, channel := range channels { 139 | 140 | // 从客户端中删除当前通道 141 | c.Unsubscribe(channel) 142 | // 获取链表 143 | raw, ok := p.dataDict.Get(channel) 144 | if ok { 145 | // 从链表中删除当前客户端 146 | l, _ := raw.(*list.LinkedList) 147 | l.DelAllByVal(func(actual interface{}) bool { 148 | return c == actual 149 | }) 150 | 151 | // 如果链表为空,清理map 152 | if l.Len() == 0 { 153 | p.dataDict.Delete(channel) 154 | } 155 | } 156 | c.Write(channelMsg(_unsubscribe, channel, c.SubCount())) 157 | } 158 | 159 | return protocol.NewNoReply() 160 | } 161 | ``` 162 | 163 | ### 发布 publish 164 | 165 | - 获取客户端的channel 166 | - 从map将channel作为key得到客户端链表 167 | - 对链表的所有客户端发送数据即可 168 | 169 | ```go 170 | 171 | func (p *Pubhub) Publish(self abstract.Connection, args [][]byte) protocol.Reply { 172 | 173 | if len(args) != 2 { 174 | return protocol.NewArgNumErrReply("publish") 175 | } 176 | 177 | channelName := string(args[0]) 178 | // 加锁 179 | p.locker.Locks(channelName) 180 | defer p.locker.Unlocks(channelName) 181 | 182 | raw, ok := p.dataDict.Get(channelName) 183 | if ok { 184 | 185 | var sendSuccess int64 186 | var failedClient = make(map[interface{}]struct{}) 187 | // 取出链表 188 | l, _ := raw.(*list.LinkedList) 189 | // 遍历链表 190 | l.ForEach(func(i int, val interface{}) bool { 191 | 192 | conn, _ := val.(abstract.Connection) 193 | 194 | if conn.IsClosed() { 195 | failedClient[val] = struct{}{} 196 | return true 197 | } 198 | 199 | if val == self { //不给自己发送 200 | return true 201 | } 202 | // 发送数据 203 | conn.Write(publisMsg(channelName, string(args[1]))) 204 | sendSuccess++ 205 | return true 206 | }) 207 | 208 | // 剔除客户端 209 | if len(failedClient) > 0 { 210 | removed := l.DelAllByVal(func(actual interface{}) bool { 211 | _, ok := failedClient[actual] 212 | return ok 213 | }) 214 | logger.Debugf("del %d closed client", removed) 215 | } 216 | 217 | // 返回发送的客户端数量 218 | return protocol.NewIntegerReply(sendSuccess) 219 | } 220 | // 如果channel不存在 221 | return protocol.NewIntegerReply(0) 222 | } 223 | 224 | ``` 225 | 226 | 227 | ### 锁的原理 228 | 代码路径 `tool/locker/locker.go` 229 | 230 | ```go 231 | type Pubhub struct { 232 | 233 | // 自定义实现的map 234 | dataDict dict.ConcurrentDict 235 | 236 | // 该锁的颗粒度太大 237 | //locker sync.RWMutex 238 | 239 | locker *locker.Locker // 自定义一个分布锁 240 | } 241 | 242 | 243 | ``` 244 | 245 | 在结构体中,当有【多个客户端同时订阅不同的通道】,通过通道名,可以获取到不同的客户端**链表**,也就是不同的客户端操作不同的链表可以并行操作(只有操作同一个链表才是互斥),如果我们使用 `locker sync.RWMutex` 锁,那就是所有的客户端持有同一把锁,一个客户端只有操作完成一个链表,才能允许另一个客户端操作另外一个链表,整个操作只能是串行的。所以我们需要实现一个颗粒度更小的锁 246 | 247 | ![](image-1.png) 248 | 249 | 通过不同的通道名,加不同的锁即可(尽可能的减小锁的粒度),同时为了避免死锁,并行的协程加锁的顺序要一致。所以代码中有个排序。 250 | 251 | 这里做了一个技巧,通过hash将通道名映射成不同的hash值,再通过取余,将锁固定在一个范围内(将无限多的channel名 转成 有限范围的值),所以可能存在不同的通道名取余的结果,用的同一个锁 252 | 253 | ![](image-2.png) 254 | ```go 255 | 256 | type Locker struct { 257 | mu []*sync.RWMutex 258 | mask uint32 259 | } 260 | 261 | // 顺序加锁(互斥) 262 | func (l *Locker) Locks(keys ...string) { 263 | indexs := l.toLockIndex(keys...) 264 | for _, index := range indexs { 265 | mu := l.mu[index] 266 | mu.Lock() 267 | } 268 | } 269 | 270 | // 顺序解锁(互斥) 271 | func (l *Locker) Unlocks(keys ...string) { 272 | indexs := l.toLockIndex(keys...) 273 | for _, index := range indexs { 274 | mu := l.mu[index] 275 | mu.Unlock() 276 | } 277 | } 278 | 279 | func (l *Locker) toLockIndex(keys ...string) []uint32 { 280 | 281 | // 将key转成 切片索引[0,mask] 282 | mapIndex := make(map[uint32]struct{}) // 去重 283 | for _, key := range keys { 284 | mapIndex[l.spread(utils.Fnv32(key))] = struct{}{} 285 | } 286 | 287 | indices := make([]uint32, 0, len(mapIndex)) 288 | for k := range mapIndex { 289 | indices = append(indices, k) 290 | } 291 | // 对索引排序 292 | sort.Slice(indices, func(i, j int) bool { 293 | return indices[i] < indices[j] 294 | }) 295 | return indices 296 | } 297 | 298 | ``` 299 | 300 | ### 总结 301 | 锁相关的代码很有实践意义,建议大家自己的手动敲一下,平时工作中作为自己的代码小组件使用,绝对可以让人眼前一亮。 -------------------------------------------------------------------------------- /doc/7.跳表的实现/image-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/7.跳表的实现/image-1.png -------------------------------------------------------------------------------- /doc/7.跳表的实现/image-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/7.跳表的实现/image-2.png -------------------------------------------------------------------------------- /doc/7.跳表的实现/image-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/7.跳表的实现/image-3.png -------------------------------------------------------------------------------- /doc/7.跳表的实现/image-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/7.跳表的实现/image-4.png -------------------------------------------------------------------------------- /doc/7.跳表的实现/image-5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/7.跳表的实现/image-5.png -------------------------------------------------------------------------------- /doc/7.跳表的实现/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/7.跳表的实现/image.png -------------------------------------------------------------------------------- /doc/8.pipeline客户端/client.md: -------------------------------------------------------------------------------- 1 | # Golang实现自己的Redis (pipeline客户端) 2 | 3 | 用11篇文章实现一个可用的Redis服务,姑且叫**EasyRedis**吧,希望通过文章将Redis掰开撕碎了呈现给大家,而不是仅仅停留在八股文的层面,并且有非常爽的感觉,欢迎持续关注学习。 4 | 5 | 项目代码地址: https://github.com/gofish2020/easyredis 欢迎Fork & Star 6 | 7 | - [x] easyredis之TCP服务 8 | - [x] easyredis之网络请求序列化协议(RESP) 9 | - [x] easyredis之内存数据库 10 | - [x] easyredis之过期时间 (时间轮实现) 11 | - [x] easyredis之持久化 (AOF实现) 12 | - [x] easyredis之发布订阅功能 13 | - [x] easyredis之有序集合(跳表实现) 14 | - [x] easyredis之 pipeline 客户端实现 15 | - [ ] easyredis之事务(原子性/回滚) 16 | - [ ] easyredis之连接池 17 | - [ ] easyredis之分布式集群存储 18 | 19 | ## 【第八篇】EasyRedis之pipeline客户端 20 | 21 | 22 | 网络编程的一个基础知识:用同一个sokcet连接发送多个数据包的时候,我们一般的做法是,发送并立刻接收结果,在没有接收到,是不会继续发送数据包。这种方法简单,但是效率太低。时间都浪费在等待上了... 23 | 24 | socket的【发送缓冲区和接收缓冲区】是分离的,也就是发送不用等待接收,接收也不用等待发送。 25 | 26 | 所以我们可以把我们要发送的多个数据包【数据包1/数据包2...数据包N】复用同一个连接,通过**发送缓冲区**按顺序都发送给服务端。服务端处理请求的顺序,也是按照【数据包1/数据包2...数据包N】这个顺序处理的。当处理完以后,处理结果将按照【数据包结果1/数据包结果2...数据包结果N】顺序发送给客户端的**接收缓冲区**。客户端只需要从接收缓冲区中读取数据,并保存到请求数据包上,即可。这样我们就可以将发送和接收分离开来。一个协程只负责发送,一个协程只负责接收,互相不用等待。关键在于保证**发送和接收的顺序是相同的** 27 | 设计逻辑图如下: 28 | ![](image.png) 29 | 30 | 代码路径`redis/client/client.go` 31 | 整个代码也就是200多行,结合上图非常容易理解 32 | 33 | ### 创建客户端 34 | 35 | ```go 36 | type RedisClent struct { 37 | // socket连接 38 | conn net.Conn 39 | 40 | addr string 41 | // 客户端当前状态 42 | connStatus atomic.Int32 43 | 44 | // heartbeat 45 | ticker time.Ticker 46 | 47 | // buffer cache 48 | waitSend chan *request 49 | waitResult chan *request 50 | 51 | // 有请求正在处理中... 52 | working sync.WaitGroup 53 | } 54 | 55 | // 创建redis客户端socket 56 | func NewRedisClient(addr string) (*RedisClent, error) { 57 | conn, err := net.Dial("tcp", addr) 58 | if err != nil { 59 | return nil, err 60 | } 61 | 62 | rc := RedisClent{} 63 | rc.conn = conn 64 | rc.waitSend = make(chan *request, maxChanSize) 65 | rc.waitResult = make(chan *request, maxChanSize) 66 | rc.addr = addr 67 | return &rc, nil 68 | } 69 | 70 | // 启动 71 | func (rc *RedisClent) Start() error { 72 | rc.ticker = *time.NewTicker(heartBeatInterval) 73 | // 将waitSend缓冲区进行发送 74 | go rc.execSend() 75 | // 获取服务端结果 76 | go rc.execReceive() 77 | // 定时发送心跳 78 | go rc.execHeardBeat() 79 | rc.connStatus.Store(connRunning) // 启动状态 80 | return nil 81 | } 82 | ``` 83 | 84 | ### 发送Redis命令 85 | 86 | 将`command [][]byte`保存到缓冲区 `rc.waitSend`中 87 | 88 | ```go 89 | // 将redis命令保存到 waitSend 中 90 | func (rc *RedisClent) Send(command [][]byte) (protocol.Reply, error) { 91 | 92 | // 已关闭 93 | if rc.connStatus.Load() == connClosed { 94 | return nil, errors.New("client closed") 95 | } 96 | 97 | req := &request{ 98 | command: command, 99 | wait: wait.Wait{}, 100 | } 101 | // 单个请求 102 | req.wait.Add(1) 103 | 104 | // 所有请求 105 | rc.working.Add(1) 106 | defer rc.working.Done() 107 | 108 | // 将数据保存到缓冲中 109 | rc.waitSend <- req 110 | 111 | // 等待处理结束 112 | if req.wait.WaitWithTimeOut(maxWait) { 113 | return nil, errors.New("time out") 114 | } 115 | // 出错 116 | if req.err != nil { 117 | err := req.err 118 | return nil, err 119 | } 120 | // 正常 121 | return req.reply, nil 122 | } 123 | ``` 124 | 125 | 126 | ### 发送Redis命令到服务端 127 | 128 | ```go 129 | // 将waitSend缓冲区进行发送 130 | func (rc *RedisClent) execSend() { 131 | for req := range rc.waitSend { 132 | rc.sendReq(req) 133 | } 134 | } 135 | 136 | func (rc *RedisClent) sendReq(req *request) { 137 | // 无效请求 138 | if req == nil || len(req.command) == 0 { 139 | return 140 | } 141 | 142 | var err error 143 | // 网络请求(重试3次) 144 | for i := 0; i < 3; i++ { 145 | _, err = rc.conn.Write(req.Bytes()) 146 | // 发送成功 or 发送错误(除了超时错误和deadline错误)跳出 147 | if err == nil || 148 | (!strings.Contains(err.Error(), "timeout") && // only retry timeout 149 | !strings.Contains(err.Error(), "deadline exceeded")) { 150 | break 151 | } 152 | } 153 | 154 | if err == nil { // 发送成功,异步等待结果 155 | rc.waitResult <- req 156 | } else { // 发送失败,请求直接失败 157 | req.err = err 158 | req.wait.Done() 159 | } 160 | } 161 | 162 | 163 | ``` 164 | 165 | ### 从服务端读取数据 166 | 167 | ```go 168 | func (rc *RedisClent) execReceive() { 169 | 170 | ch := parser.ParseStream(rc.conn) 171 | 172 | for payload := range ch { 173 | 174 | if payload.Err != nil { 175 | if rc.connStatus.Load() == connClosed { // 连接已关闭 176 | return 177 | } 178 | 179 | // 否则,重新连接(可能因为网络抖动临时断开了) 180 | 181 | rc.reconnect() 182 | return 183 | } 184 | 185 | // 说明一切正常 186 | 187 | rc.handleResult(payload.Reply) 188 | } 189 | } 190 | 191 | func (rc *RedisClent) handleResult(reply protocol.Reply) { 192 | // 从rc.waitResult 获取一个等待中的请求,将结果保存进去 193 | req := <-rc.waitResult 194 | if req == nil { 195 | return 196 | } 197 | req.reply = reply 198 | req.wait.Done() // 通知已经获取到结果 199 | } 200 | 201 | ``` 202 | 203 | 204 | ### 断线重连 205 | 206 | 因为网络抖动可能存在连接断开的情况,所以需要有重连的功能 207 | 208 | ```go 209 | func (rc *RedisClent) reconnect() { 210 | logger.Info("redis client reconnect...") 211 | rc.conn.Close() 212 | 213 | var conn net.Conn 214 | // 重连(重试3次) 215 | for i := 0; i < 3; i++ { 216 | var err error 217 | conn, err = net.Dial("tcp", rc.addr) 218 | if err != nil { 219 | logger.Error("reconnect error: " + err.Error()) 220 | time.Sleep(time.Second) 221 | continue 222 | } else { 223 | break 224 | } 225 | } 226 | // 服务端连不上,说明服务可能挂了(or 网络问题 and so on...) 227 | if conn == nil { 228 | rc.Stop() 229 | return 230 | } 231 | 232 | // 这里关闭没问题,因为rc.conn.Close已经关闭,函数Send中保存的请求因为发送不成功,不会写入到waitResult 233 | close(rc.waitResult) 234 | // 清理 waitResult(因为连接重置,新连接上只能处理新请求,老的请求的数据结果在老连接上,老连接已经关了,新连接上肯定是没有结果的) 235 | for req := range rc.waitResult { 236 | req.err = errors.New("connect reset") 237 | req.wait.Done() 238 | } 239 | 240 | // 新连接(新气象) 241 | rc.waitResult = make(chan *request, maxWait) 242 | rc.conn = conn 243 | 244 | // 重新启动接收协程 245 | go rc.execReceive() 246 | } 247 | 248 | ``` 249 | -------------------------------------------------------------------------------- /doc/8.pipeline客户端/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/8.pipeline客户端/image.png -------------------------------------------------------------------------------- /doc/9.事务/image-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/9.事务/image-1.png -------------------------------------------------------------------------------- /doc/9.事务/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/doc/9.事务/image.png -------------------------------------------------------------------------------- /engine/commoncmd.go: -------------------------------------------------------------------------------- 1 | package engine 2 | 3 | import ( 4 | "strconv" 5 | 6 | "github.com/gofish2020/easyredis/abstract" 7 | "github.com/gofish2020/easyredis/redis/protocol" 8 | "github.com/gofish2020/easyredis/tool/conf" 9 | ) 10 | 11 | /* 12 | 基础命令 13 | */ 14 | func execSelect(c abstract.Connection, redisArgs [][]byte) protocol.Reply { 15 | if len(redisArgs) != 1 { 16 | return protocol.NewArgNumErrReply("select") 17 | } 18 | dbIndex, err := strconv.ParseInt(string(redisArgs[0]), 10, 64) 19 | if err != nil { 20 | return protocol.NewGenericErrReply("invaild db index") 21 | } 22 | if dbIndex < 0 || dbIndex >= int64(conf.GlobalConfig.Databases) { 23 | return protocol.NewGenericErrReply("db index out of range") 24 | } 25 | c.SetDBIndex(int(dbIndex)) 26 | return protocol.NewOkReply() 27 | 28 | } 29 | 30 | // 异步方式重写aof 31 | func BGRewriteAOF(engine *Engine) protocol.Reply { 32 | go engine.aof.Rewrite(newAuxiliaryEngine()) 33 | return protocol.NewSimpleReply("Background append only file rewriting started") 34 | } 35 | -------------------------------------------------------------------------------- /engine/engine.go: -------------------------------------------------------------------------------- 1 | package engine 2 | 3 | import ( 4 | "fmt" 5 | "runtime/debug" 6 | "strings" 7 | "sync/atomic" 8 | "time" 9 | 10 | "github.com/gofish2020/easyredis/abstract" 11 | "github.com/gofish2020/easyredis/aof" 12 | "github.com/gofish2020/easyredis/engine/payload" 13 | "github.com/gofish2020/easyredis/pubhub" 14 | "github.com/gofish2020/easyredis/redis/protocol" 15 | "github.com/gofish2020/easyredis/tool/conf" 16 | "github.com/gofish2020/easyredis/tool/logger" 17 | "github.com/gofish2020/easyredis/tool/timewheel" 18 | ) 19 | 20 | // 存储引擎,负责数据的CRUD 21 | type Engine struct { 22 | // *DB 23 | dbSet []*atomic.Value 24 | // 时间轮(延迟任务) 25 | delay *timewheel.Delay 26 | // Append Only File 27 | aof *aof.AOF 28 | 29 | // 订阅 30 | 31 | hub *pubhub.Pubhub 32 | } 33 | 34 | func NewEngine() *Engine { 35 | 36 | engine := &Engine{} 37 | 38 | engine.delay = timewheel.NewDelay() 39 | // 多个dbSet 40 | engine.dbSet = make([]*atomic.Value, conf.GlobalConfig.Databases) 41 | for i := 0; i < conf.GlobalConfig.Databases; i++ { 42 | // 创建 *db 43 | db := newDB(engine.delay) 44 | db.SetIndex(i) 45 | // 保存到 atomic.Value中 46 | dbset := &atomic.Value{} 47 | dbset.Store(db) 48 | // 赋值到 dbSet中 49 | engine.dbSet[i] = dbset 50 | } 51 | 52 | engine.hub = pubhub.NewPubsub() 53 | // 启用AOF日志 54 | if conf.GlobalConfig.AppendOnly { 55 | // 创建*AOF对象 56 | aof, err := aof.NewAOF(conf.GlobalConfig.Dir+"/"+conf.GlobalConfig.AppendFilename, engine, true, conf.GlobalConfig.AppendFsync) 57 | if err != nil { 58 | panic(err) 59 | } 60 | engine.aof = aof 61 | // 设定每个db,使用aof写入日志 62 | engine.aofBindEveryDB() 63 | } 64 | return engine 65 | } 66 | 67 | func (e *Engine) aofBindEveryDB() { 68 | for _, dbSet := range e.dbSet { 69 | db := dbSet.Load().(*DB) 70 | db.writeAof = func(redisCommand [][]byte) { 71 | if conf.GlobalConfig.AppendOnly { 72 | // 调用e.aof对象方法,保存命令 73 | e.aof.SaveRedisCommand(db.index, aof.Command(redisCommand)) 74 | } 75 | } 76 | } 77 | } 78 | 79 | // 选中指定的 *DB 80 | func (e *Engine) selectDB(index int) (*DB, *protocol.GenericErrReply) { 81 | if index < 0 || index >= len(e.dbSet) { 82 | return nil, protocol.NewGenericErrReply("db index is out of range") 83 | } 84 | return e.dbSet[index].Load().(*DB), nil 85 | } 86 | 87 | // redisCommand 待执行的命令 protocol.Reply 执行结果 88 | func (e *Engine) Exec(c abstract.Connection, redisCommand [][]byte) (result protocol.Reply) { 89 | 90 | defer func() { 91 | if err := recover(); err != nil { 92 | logger.Warn(fmt.Sprintf("error occurs: %v\n%s", err, string(debug.Stack()))) 93 | result = protocol.NewUnknownErrReply() 94 | } 95 | }() 96 | // 命令小写 97 | commandName := strings.ToLower(string(redisCommand[0])) 98 | if commandName == "ping" { // https://redis.io/commands/ping/ 99 | return Ping(redisCommand[1:]) 100 | } 101 | if commandName == "auth" { // https://redis.io/commands/auth/ 102 | return Auth(c, redisCommand[1:]) 103 | } 104 | // 校验密码 105 | if !checkPasswd(c) { 106 | return protocol.NewGenericErrReply("Authentication required") 107 | } 108 | 109 | // 基础命令 110 | switch commandName { 111 | case "select": // 表示当前连接,要选中哪个db https://redis.io/commands/select/ 112 | if c != nil && c.IsTransaction() { // 事务模式,不能切换数据库 113 | return protocol.NewGenericErrReply("cannot select database within multi") 114 | } 115 | return execSelect(c, redisCommand[1:]) 116 | case "bgrewriteaof": // https://redis.io/commands/bgrewriteaof/ 117 | if !conf.GlobalConfig.AppendOnly { 118 | return protocol.NewGenericErrReply("AppendOnly is false, you can't rewrite aof file") 119 | } 120 | return BGRewriteAOF(e) 121 | case "subscribe": 122 | return e.hub.Subscribe(c, redisCommand[1:]) 123 | case "unsubscribe": 124 | return e.hub.Unsubscribe(c, redisCommand[1:]) 125 | case "publish": 126 | return e.hub.Publish(c, redisCommand[1:]) 127 | } 128 | 129 | // redis 命令处理 130 | dbIndex := c.GetDBIndex() 131 | logger.Debugf("db index:%d", dbIndex) 132 | db, errReply := e.selectDB(dbIndex) 133 | if errReply != nil { 134 | return errReply 135 | } 136 | return db.Exec(c, redisCommand) 137 | } 138 | 139 | func (e *Engine) Close() { 140 | e.aof.Close() 141 | } 142 | 143 | func (e *Engine) RWLocks(dbIndex int, readKeys, writeKeys []string) { 144 | db, err := e.selectDB(dbIndex) 145 | if err != nil { 146 | logger.Error("RWLocks err:", err.Status) 147 | return 148 | } 149 | db.RWLock(readKeys, writeKeys) 150 | } 151 | 152 | func (e *Engine) RWUnLocks(dbIndex int, readKeys, writeKeys []string) { 153 | db, err := e.selectDB(dbIndex) 154 | if err != nil { 155 | logger.Error("RWLocks err:", err.Status) 156 | return 157 | } 158 | db.RWUnLock(readKeys, writeKeys) 159 | } 160 | 161 | func (e *Engine) GetUndoLogs(dbIndex int, redisCommand [][]byte) []CmdLine { 162 | db, err := e.selectDB(dbIndex) 163 | if err != nil { 164 | logger.Error("RWLocks err:", err.Status) 165 | return nil 166 | } 167 | return db.GetUndoLog(redisCommand) 168 | } 169 | 170 | func (e *Engine) ExecWithLock(dbIndex int, redisCommand [][]byte) protocol.Reply { 171 | db, err := e.selectDB(dbIndex) 172 | if err != nil { 173 | logger.Error("RWLocks err:", err.Status) 174 | return err 175 | } 176 | 177 | return db.execWithLock(redisCommand) 178 | } 179 | 180 | // 遍历引擎的所有数据 181 | func (e *Engine) ForEach(dbIndex int, cb func(key string, data *payload.DataEntity, expiration *time.Time) bool) { 182 | 183 | db, errReply := e.selectDB(dbIndex) 184 | if errReply != nil { 185 | logger.Error("ForEach err ", errReply.ToBytes()) 186 | return 187 | } 188 | 189 | db.dataDict.ForEach(func(key string, val interface{}) bool { 190 | entity, _ := val.(*payload.DataEntity) 191 | var expiration *time.Time 192 | rawExpireTime, ok := db.ttlDict.Get(key) 193 | if ok { 194 | expireTime, _ := rawExpireTime.(time.Time) 195 | expiration = &expireTime 196 | } 197 | return cb(key, entity, expiration) 198 | }) 199 | } 200 | 201 | func newAuxiliaryEngine() *Engine { 202 | engine := &Engine{} 203 | engine.delay = timewheel.NewDelay() 204 | engine.dbSet = make([]*atomic.Value, conf.GlobalConfig.Databases) 205 | for i := range engine.dbSet { 206 | 207 | db := newBasicDB(engine.delay) 208 | db.SetIndex(i) 209 | 210 | holder := &atomic.Value{} 211 | holder.Store(db) 212 | engine.dbSet[i] = holder 213 | } 214 | return engine 215 | } 216 | -------------------------------------------------------------------------------- /engine/payload/payload.go: -------------------------------------------------------------------------------- 1 | package payload 2 | 3 | // 定义底层的数据存储对象 4 | type DataEntity struct { 5 | RedisObject interface{} // 字符串 跳表 链表 quicklist 集合 etc... 6 | } 7 | -------------------------------------------------------------------------------- /engine/register.go: -------------------------------------------------------------------------------- 1 | package engine 2 | 3 | import ( 4 | "strings" 5 | 6 | "github.com/gofish2020/easyredis/redis/protocol" 7 | ) 8 | 9 | /* 10 | 命令注册中心:记录命令和命令执行函数之间的映射关系 11 | */ 12 | 13 | type ExecFunc func(db *DB, args [][]byte) protocol.Reply 14 | 15 | type KeysFunc func(args [][]byte) ([]string, []string) // read/write 16 | 17 | type UndoFunc func(db *DB, args [][]byte) [][][]byte 18 | 19 | var commandCenter map[string]*command = make(map[string]*command) 20 | 21 | type command struct { 22 | commandName string 23 | execFunc ExecFunc // 命令执行函数 24 | keyFunc KeysFunc // 获取命令中的key 25 | argsNum int // redis命令组成个数;例如 get key就是由2部分组成; 如果是负数-2表示要>=2;如果是正数2表示 = 2 26 | undoFunc UndoFunc // 生成回滚命令 27 | } 28 | 29 | func registerCommand(name string, execFunc ExecFunc, keyFunc KeysFunc, argsNum int, undoFunc UndoFunc) { 30 | name = strings.ToLower(name) 31 | cmd := &command{} 32 | cmd.commandName = name 33 | cmd.execFunc = execFunc 34 | cmd.keyFunc = keyFunc 35 | cmd.argsNum = argsNum 36 | cmd.undoFunc = undoFunc 37 | commandCenter[name] = cmd 38 | } 39 | 40 | func GetRelatedKeys(redisCommand [][]byte) ([]string, []string) { 41 | 42 | cmdName := strings.ToLower(string(redisCommand[0])) 43 | 44 | cmd, ok := commandCenter[cmdName] 45 | if !ok { 46 | return nil, nil 47 | } 48 | 49 | keyFunc := cmd.keyFunc 50 | if keyFunc == nil { 51 | return nil, nil 52 | } 53 | return keyFunc(redisCommand[1:]) 54 | } 55 | -------------------------------------------------------------------------------- /engine/string.go: -------------------------------------------------------------------------------- 1 | package engine 2 | 3 | import ( 4 | "strconv" 5 | "strings" 6 | "time" 7 | 8 | "github.com/gofish2020/easyredis/aof" 9 | "github.com/gofish2020/easyredis/engine/payload" 10 | "github.com/gofish2020/easyredis/redis/protocol" 11 | ) 12 | 13 | const ( 14 | defaultPolicy = iota + 1 // 插入 or 更新 15 | insertPolicy // 只插入 16 | updatePolicy // 只更新 17 | 18 | ) 19 | 20 | const nolimitedTTL int64 = 0 // 过期时间 21 | 22 | // 获取底层存储对象【字节流】 23 | func (db *DB) getStringObject(key string) ([]byte, protocol.Reply) { 24 | payload, exist := db.GetEntity(key) 25 | if !exist { 26 | return nil, protocol.NewNullBulkReply() 27 | } 28 | // 判断底层对象是否为【字节流】 29 | bytes, ok := payload.RedisObject.([]byte) 30 | if !ok { 31 | return nil, protocol.NewWrongTypeErrReply() 32 | } 33 | return bytes, nil 34 | } 35 | 36 | // https://redis.io/commands/get/ key 37 | func cmdGet(db *DB, args [][]byte) protocol.Reply { 38 | if len(args) != 1 { 39 | return protocol.NewSyntaxErrReply() 40 | } 41 | 42 | key := string(args[0]) 43 | bytes, reply := db.getStringObject(key) 44 | if reply != nil { 45 | return reply 46 | } 47 | return protocol.NewBulkReply(bytes) 48 | } 49 | 50 | // https://redis.io/commands/set/ key value nx/xx ex/px 60 51 | func cmdSet(db *DB, args [][]byte) protocol.Reply { 52 | key := string(args[0]) 53 | value := args[1] 54 | 55 | policy := defaultPolicy 56 | ttl := nolimitedTTL 57 | if len(args) > 2 { 58 | 59 | for i := 2; i < len(args); i++ { 60 | arg := strings.ToUpper(string(args[i])) 61 | if arg == "NX" { // 插入 62 | if policy == updatePolicy { // 说明policy 已经被设置过,重复设置(语法错误) 63 | return protocol.NewSyntaxErrReply() 64 | } 65 | policy = insertPolicy 66 | } else if arg == "XX" { // 更新 67 | if policy == insertPolicy { // 说明policy 已经被设置过,重复设置(语法错误) 68 | return protocol.NewSyntaxErrReply() 69 | } 70 | policy = updatePolicy 71 | } else if arg == "EX" { // ex in seconds 72 | 73 | if ttl != nolimitedTTL { // 说明 ttl 已经被设置过,重复设置(语法错误) 74 | return protocol.NewSyntaxErrReply() 75 | } 76 | if i+1 >= len(args) { // 过期时间后面要跟上正整数 77 | return protocol.NewSyntaxErrReply() 78 | } 79 | ttlArg, err := strconv.ParseInt(string(args[i+1]), 10, 64) 80 | if err != nil { 81 | return protocol.NewSyntaxErrReply() 82 | } 83 | 84 | if ttlArg <= 0 { 85 | return protocol.NewGenericErrReply("expire time is not a positive integer") 86 | } 87 | // 转成 ms 88 | ttl = ttlArg * 1000 89 | i++ // 跳过下一个参数 90 | } else if arg == "PX" { // px in milliseconds 91 | 92 | if ttl != nolimitedTTL { // 说明 ttl 已经被设置过,重复设置(语法错误) 93 | return protocol.NewSyntaxErrReply() 94 | } 95 | if i+1 >= len(args) { // 过期时间后面要跟上正整数 96 | return protocol.NewSyntaxErrReply() 97 | } 98 | ttlArg, err := strconv.ParseInt(string(args[i+1]), 10, 64) 99 | if err != nil { 100 | return protocol.NewSyntaxErrReply() 101 | } 102 | 103 | if ttlArg <= 0 { 104 | return protocol.NewGenericErrReply("expire time is not a positive integer") 105 | } 106 | 107 | ttl = ttlArg 108 | i++ //跳过下一个参数 109 | } else { 110 | // 发现不符合要求的参数 111 | return protocol.NewSyntaxErrReply() 112 | } 113 | } 114 | } 115 | 116 | // 构建存储实体 117 | entity := payload.DataEntity{ 118 | RedisObject: value, 119 | } 120 | 121 | // 保存到内存字典中 122 | var result int 123 | if policy == defaultPolicy { 124 | db.PutEntity(key, &entity) 125 | result = 1 126 | } else if policy == insertPolicy { 127 | result = db.PutIfAbsent(key, &entity) 128 | } else if policy == updatePolicy { 129 | result = db.PutIfExist(key, &entity) 130 | } 131 | 132 | if result > 0 { // 1 表示存储成功 133 | //TODO: 过期时间处理 134 | if ttl != nolimitedTTL { // 设定key过期 135 | expireTime := time.Now().Add(time.Duration(ttl) * time.Millisecond) 136 | db.ExpireAt(key, expireTime) 137 | //写入日志 138 | db.writeAof(aof.SetCmd([][]byte{args[0], args[1]}...)) 139 | db.writeAof(aof.PExpireAtCmd(string(args[0]), expireTime)) 140 | } else { // 设定key不过期 141 | db.Persist(key) 142 | //写入日志 143 | db.writeAof(aof.SetCmd(args...)) 144 | } 145 | return protocol.NewOkReply() 146 | } 147 | 148 | return protocol.NewNullBulkReply() 149 | } 150 | 151 | func cmdMSet(db *DB, args [][]byte) protocol.Reply { 152 | size := len(args) / 2 153 | // 提取出key value 154 | keys := make([]string, 0, size) 155 | values := make([][]byte, 0, size) 156 | for i := 0; i < size; i++ { 157 | keys = append(keys, string(args[2*i])) 158 | values = append(values, args[2*i+1]) 159 | } 160 | // 保存到内存中 161 | for i, key := range keys { 162 | value := values[i] 163 | entity := payload.DataEntity{ 164 | RedisObject: value, 165 | } 166 | db.PutEntity(key, &entity) 167 | } 168 | // 写日志 169 | db.writeAof(aof.MSetCmd(args...)) 170 | return protocol.NewOkReply() 171 | } 172 | func init() { 173 | // 获取值 174 | registerCommand("Get", cmdGet, readFirstKey, 2, nil) 175 | // 设置值 176 | registerCommand("Set", cmdSet, writeFirstKey, -3, rollbackFirstKey) 177 | // 设置多个值 178 | registerCommand("MSet", cmdMSet, writeMultiKey, -3, undoMSet) 179 | } 180 | -------------------------------------------------------------------------------- /engine/systemcmd.go: -------------------------------------------------------------------------------- 1 | package engine 2 | 3 | import ( 4 | "github.com/gofish2020/easyredis/abstract" 5 | "github.com/gofish2020/easyredis/redis/protocol" 6 | "github.com/gofish2020/easyredis/tool/conf" 7 | ) 8 | 9 | /* 10 | 常用:系统命令 11 | */ 12 | func Ping(redisArgs [][]byte) protocol.Reply { 13 | 14 | if len(redisArgs) == 0 { // 不带参数 15 | return protocol.NewPONGReply() 16 | } else if len(redisArgs) == 1 { // 带参数1个 17 | return protocol.NewBulkReply(redisArgs[0]) 18 | } 19 | // 否则,回复命令格式错误 20 | return protocol.NewArgNumErrReply("ping") 21 | } 22 | 23 | func checkPasswd(c abstract.Connection) bool { 24 | // 如果没有配置密码 25 | if conf.GlobalConfig.RequirePass == "" { 26 | return true 27 | } 28 | // 密码是否一致 29 | return c.GetPassword() == conf.GlobalConfig.RequirePass 30 | } 31 | 32 | func Auth(c abstract.Connection, redisArgs [][]byte) protocol.Reply { 33 | if len(redisArgs) != 1 { 34 | return protocol.NewArgNumErrReply("auth") 35 | } 36 | 37 | if conf.GlobalConfig.RequirePass == "" { 38 | return protocol.NewGenericErrReply("No authorization is required") 39 | } 40 | 41 | password := string(redisArgs[0]) 42 | if conf.GlobalConfig.RequirePass != password { 43 | return protocol.NewGenericErrReply("Auth failed, password is wrong") 44 | } 45 | 46 | c.SetPassword(password) 47 | return protocol.NewOkReply() 48 | } 49 | -------------------------------------------------------------------------------- /engine/transaction.go: -------------------------------------------------------------------------------- 1 | package engine 2 | 3 | import ( 4 | "errors" 5 | "strings" 6 | 7 | "github.com/gofish2020/easyredis/abstract" 8 | "github.com/gofish2020/easyredis/redis/protocol" 9 | ) 10 | 11 | // 事务文档: https://redis.io/docs/interact/transactions/ 12 | 13 | // 开启事务 14 | func StartMulti(c abstract.Connection) protocol.Reply { 15 | if c.IsTransaction() { 16 | return protocol.NewGenericErrReply("multi is already start,do not repeat it") 17 | } 18 | // 设定开启 19 | c.SetTransaction(true) 20 | return protocol.NewOkReply() 21 | } 22 | 23 | // 取消事务 24 | func DiscardMulti(c abstract.Connection) protocol.Reply { 25 | if !c.IsTransaction() { 26 | return protocol.NewGenericErrReply("DISCARD without MULTI") 27 | } 28 | // 取消开启 29 | c.SetTransaction(false) 30 | return protocol.NewOkReply() 31 | } 32 | 33 | // 入队:保证命令在格式正确&& 存在的情况下入队 34 | func EnqueueCmd(c abstract.Connection, redisCommand [][]byte) protocol.Reply { 35 | 36 | cmdName := strings.ToLower(string(redisCommand[0])) 37 | 38 | // 从命令注册中心,获取命令的执行函数 39 | cmd, ok := commandCenter[cmdName] 40 | if !ok { // 命令不存在 41 | c.AddTxError(errors.New("unknown command '" + cmdName + "'")) 42 | return protocol.NewGenericErrReply("unknown command '" + cmdName + "'") 43 | } 44 | 45 | // 获取key的函数未设置 46 | if cmd.keyFunc == nil { 47 | c.AddTxError(errors.New("ERR command '" + cmdName + "' cannot be used in MULTI")) 48 | return protocol.NewGenericErrReply("ERR command '" + cmdName + "' cannot be used in MULTI") 49 | } 50 | 51 | // 参数个数不对 52 | if !validateArity(cmd.argsNum, redisCommand) { 53 | c.AddTxError(errors.New("ERR wrong number of arguments for '" + cmdName + "' command")) 54 | return protocol.NewArgNumErrReply(cmdName) 55 | } 56 | // 入队命令 57 | c.EnqueueCmd(redisCommand) 58 | return protocol.NewQueuedReply() 59 | } 60 | 61 | // 监视 key [key...] 62 | func Watch(db *DB, conn abstract.Connection, args [][]byte) protocol.Reply { 63 | if len(args) < 1 { 64 | return protocol.NewArgNumErrReply("WATCH") 65 | } 66 | if conn.IsTransaction() { 67 | return protocol.NewGenericErrReply("WATCH inside MULTI is not allowed") 68 | } 69 | watching := conn.GetWatchKey() 70 | for _, bkey := range args { 71 | key := string(bkey) 72 | watching[key] = db.GetVersion(key) // 保存当前key的版本号(利用版本号机制判断key是否有变化) 73 | } 74 | return protocol.NewOkReply() 75 | } 76 | 77 | // 清空watch key 78 | func UnWatch(db *DB, conn abstract.Connection) protocol.Reply { 79 | conn.CleanWatchKey() 80 | return protocol.NewOkReply() 81 | } 82 | 83 | // 执行事务 exec rb 84 | func ExecMulti(db *DB, conn abstract.Connection, args [][]byte) protocol.Reply { 85 | 86 | // 说明当前不是【事务模式】 87 | if !conn.IsTransaction() { 88 | return protocol.NewGenericErrReply("EXEC without MULTI") 89 | } 90 | // 执行完,自动退出事务模式 91 | defer conn.SetTransaction(false) 92 | 93 | // 如果在入队的时候,就有格式错误,直接返回 94 | if len(conn.GetTxErrors()) > 0 { 95 | return protocol.NewGenericErrReply("EXECABORT Transaction discarded because of previous errors.") 96 | } 97 | 98 | // 是否自动回滚(这里是自定义的一个参数,标准redis中没有) 99 | isRollBack := false 100 | if len(args) > 0 && strings.ToUpper(string(args[0])) == "RB" { // 有rb参数,说明要自动回滚 101 | isRollBack = true 102 | } 103 | // 获取所有的待执行命令 104 | cmdLines := conn.GetQueuedCmdLine() 105 | return db.execMulti(conn, cmdLines, isRollBack) 106 | } 107 | 108 | // 获取命令的回滚命令 109 | func (db *DB) GetUndoLog(cmdLine [][]byte) []CmdLine { 110 | cmdName := strings.ToLower(string(cmdLine[0])) 111 | cmd, ok := commandCenter[cmdName] 112 | if !ok { 113 | return nil 114 | } 115 | undo := cmd.undoFunc 116 | if undo == nil { 117 | return nil 118 | } 119 | return undo(db, cmdLine[1:]) 120 | } 121 | 122 | // 执行事务:本质就是一堆命令一起执行, isRollback 表示出错是否回滚 123 | func (db *DB) execMulti(conn abstract.Connection, cmdLines []CmdLine, isRollback bool) protocol.Reply { 124 | 125 | // 命令的执行结果 126 | results := make([]protocol.Reply, len(cmdLines)) 127 | 128 | versionKeys := make([][]string, len(cmdLines)) 129 | 130 | var writeKeys []string 131 | var readKeys []string 132 | for idx, cmdLine := range cmdLines { 133 | cmdName := strings.ToLower(string(cmdLine[0])) 134 | cmd, ok := commandCenter[cmdName] 135 | if !ok { 136 | // 这里正常不会执行 137 | continue 138 | } 139 | keyFunc := cmd.keyFunc 140 | readKs, writeKs := keyFunc(cmdLine[1:]) 141 | // 读写key 142 | readKeys = append(readKeys, readKs...) 143 | writeKeys = append(writeKeys, writeKs...) 144 | // 写key需要 变更版本号 145 | versionKeys[idx] = append(versionKeys[idx], writeKs...) 146 | } 147 | 148 | watchingKey := conn.GetWatchKey() 149 | if isWatchingChanged(db, watchingKey) { // 判断watch key是否发生了变更 150 | return protocol.NewEmptyMultiBulkReply() 151 | } 152 | 153 | // 所有key上锁(原子性) 154 | db.RWLock(readKeys, writeKeys) 155 | defer db.RWUnLock(readKeys, writeKeys) 156 | 157 | undoCmdLines := [][]CmdLine{} 158 | aborted := false 159 | for idx, cmdLine := range cmdLines { 160 | 161 | // 生成回滚命令 162 | if isRollback { 163 | undoCmdLines = append(undoCmdLines, db.GetUndoLog(cmdLine)) 164 | } 165 | 166 | // 执行命令 167 | reply := db.execWithLock(cmdLine) 168 | if protocol.IsErrReply(reply) { // 执行出错 169 | if isRollback { // 需要回滚 170 | undoCmdLines = undoCmdLines[:len(undoCmdLines)-1] // 命令执行失败(不用回滚),剔除最后一个回滚命令 171 | aborted = true 172 | break 173 | } 174 | } 175 | // 执行结果 176 | results[idx] = reply 177 | } 178 | // 中断,执行回滚 179 | if aborted { 180 | size := len(undoCmdLines) 181 | // 倒序执行回滚指令(完成回滚) 182 | for i := size - 1; i >= 0; i-- { 183 | curCmdLines := undoCmdLines[i] 184 | if len(curCmdLines) == 0 { 185 | continue 186 | } 187 | for _, cmdLine := range curCmdLines { 188 | db.execWithLock(cmdLine) 189 | } 190 | } 191 | return protocol.NewGenericErrReply("EXECABORT Transaction discarded because of previous errors.") 192 | } 193 | 194 | // 执行到这里,说明命令执行完成(可能全部成功,也可能部分成功) 195 | for idx, keys := range versionKeys { 196 | if !protocol.IsErrReply(results[idx]) { // 针对执行成功的命令(写命令),变更版本号 197 | db.addVersion(keys...) 198 | } 199 | } 200 | // 将多个命令执行的结果,进行合并返回 201 | mixReply := protocol.NewMixReply() 202 | mixReply.Append(results...) 203 | return mixReply 204 | } 205 | 206 | func isWatchingChanged(db *DB, watching map[string]int64) bool { 207 | for key, ver := range watching { 208 | currentVersion := db.GetVersion(key) 209 | if ver != currentVersion { 210 | return true 211 | } 212 | } 213 | return false 214 | } 215 | -------------------------------------------------------------------------------- /engine/utils.go: -------------------------------------------------------------------------------- 1 | package engine 2 | 3 | import ( 4 | "strconv" 5 | 6 | "github.com/gofish2020/easyredis/aof" 7 | ) 8 | 9 | func readFirstKey(args [][]byte) ([]string, []string) { 10 | return []string{string(args[0])}, nil 11 | } 12 | 13 | func writeFirstKey(args [][]byte) ([]string, []string) { 14 | return nil, []string{string(args[0])} 15 | } 16 | 17 | func readAllKey(args [][]byte) ([]string, []string) { 18 | readKeys := make([]string, len(args)) 19 | for i, arg := range args { 20 | readKeys[i] = string(arg) 21 | } 22 | return readKeys, nil 23 | } 24 | 25 | func writeAllKey(args [][]byte) ([]string, []string) { 26 | writeKeys := make([]string, len(args)) 27 | for i, arg := range args { 28 | writeKeys[i] = string(arg) 29 | } 30 | return nil, writeKeys 31 | } 32 | 33 | func noKey(args [][]byte) ([]string, []string) { 34 | return nil, nil 35 | } 36 | 37 | func writeMultiKey(args [][]byte) ([]string, []string) { 38 | 39 | size := len(args) / 2 40 | 41 | writeKeys := make([]string, size) 42 | 43 | for i := 0; i < size; i++ { 44 | writeKeys = append(writeKeys, string(args[2*i])) 45 | } 46 | 47 | return nil, writeKeys 48 | } 49 | 50 | // ********* 回滚 *********** 51 | 52 | // 通用的回滚(其实就是将整个内存数据都记录下来) 53 | func rollbackFirstKey(db *DB, args [][]byte) []CmdLine { 54 | key := string(args[0]) 55 | return rollbackGivenKeys(db, key) 56 | } 57 | 58 | func rollbackGivenKeys(db *DB, keys ...string) []CmdLine { 59 | var undoCmdLines []CmdLine 60 | for _, key := range keys { 61 | // 获取内存对象 62 | entity, ok := db.GetEntity(key) 63 | if !ok { 64 | undoCmdLines = append(undoCmdLines, 65 | aof.Del([]byte(key)), // key不存在,del 66 | ) 67 | } else { 68 | undoCmdLines = append(undoCmdLines, 69 | aof.Del([]byte(key)), // 先清理 70 | aof.EntityToCmd(key, entity).RedisCommand, // redis命令 71 | toTTLCmd(db, key).RedisCommand, // 过期 72 | ) 73 | } 74 | } 75 | return undoCmdLines 76 | } 77 | 78 | func rollbackZSetMembers(db *DB, key string, members ...string) []CmdLine { 79 | var undoCmdLines [][][]byte 80 | // 获取有序集合对象 81 | zset, errReply := db.getSortedSetObject(key) 82 | if errReply != nil { 83 | return nil 84 | } 85 | // 说明集合对象不存在(所以要生成删除回滚) 86 | if zset == nil { 87 | undoCmdLines = append(undoCmdLines, 88 | aof.Del([]byte(key)), 89 | ) 90 | return undoCmdLines 91 | } 92 | for _, member := range members { 93 | elem, ok := zset.Get(member) 94 | if !ok { // member不存在(回滚:就是删除) 95 | undoCmdLines = append(undoCmdLines, 96 | aof.ZRem([]byte(key), []byte(member)), 97 | ) 98 | } else { 99 | // 记录原始值 100 | score := strconv.FormatFloat(elem.Score, 'f', -1, 64) 101 | undoCmdLines = append(undoCmdLines, 102 | aof.ZAddCmd([]byte(key), []byte(score), []byte(member)), 103 | ) 104 | } 105 | } 106 | return undoCmdLines 107 | } 108 | 109 | func undoMSet(db *DB, args [][]byte) []CmdLine { 110 | _, writeKeys := writeMultiKey(args) 111 | return rollbackGivenKeys(db, writeKeys...) 112 | } 113 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/gofish2020/easyredis 2 | 3 | go 1.19 4 | 5 | require ( 6 | github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 7 | github.com/stretchr/testify v1.8.4 8 | ) 9 | 10 | require ( 11 | github.com/davecgh/go-spew v1.1.1 // indirect 12 | github.com/pmezard/go-difflib v1.0.0 // indirect 13 | gopkg.in/yaml.v3 v3.0.1 // indirect 14 | ) 15 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 h1:iQTw/8FWTuc7uiaSepXwyf3o52HaUYcV+Tu66S3F5GA= 4 | github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= 5 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 6 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 7 | github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= 8 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 9 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 10 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 11 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 12 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 13 | -------------------------------------------------------------------------------- /image-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/image-1.png -------------------------------------------------------------------------------- /image-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/image-2.png -------------------------------------------------------------------------------- /image-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/image-3.png -------------------------------------------------------------------------------- /image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gofish2020/easyredis/a2c6654de878632c390c7c0aa9d11266d397ff3c/image.png -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "log" 7 | "os" 8 | 9 | "github.com/gofish2020/easyredis/redis" 10 | "github.com/gofish2020/easyredis/tcpserver" 11 | "github.com/gofish2020/easyredis/tool/conf" 12 | "github.com/gofish2020/easyredis/tool/logger" 13 | "github.com/gofish2020/easyredis/utils" 14 | ) 15 | 16 | func main() { 17 | //1. 打印logo 18 | println(utils.Logo()) 19 | 20 | //2. 初始化配置 21 | initConfig() 22 | 23 | //3. 日志库初始化 24 | initLogger() 25 | 26 | logger.Info("start easyredis server") 27 | 28 | //4. 服务对象 29 | tcp := tcpserver.NewTCPServer(tcpserver.TCPConfig{ 30 | Addr: fmt.Sprintf("%s:%d", conf.GlobalConfig.Bind, conf.GlobalConfig.Port), 31 | }, redis.NewRedisHandler()) 32 | 33 | //5. 启动服务 34 | err := tcp.Start() 35 | if err != nil { 36 | log.Printf("%+v", err) 37 | os.Exit(1) 38 | } 39 | 40 | //6. 关闭服务 41 | tcp.Close() 42 | } 43 | 44 | func initLogger() { 45 | logger.Setup(&logger.Settings{ 46 | Path: conf.GlobalConfig.Dir + "/logs", 47 | Name: "easyredis", 48 | Ext: "log", 49 | DateFormat: utils.DateFormat, 50 | }) 51 | 52 | logger.SetLoggerLevel(logger.DEBUG) 53 | } 54 | 55 | func initConfig() { 56 | configFileName := "" 57 | flag.StringVar(&configFileName, "conf", "", "Usage: -conf=./redis.conf") 58 | flag.Parse() 59 | 60 | // 解析配置文件 61 | if configFileName == "" { 62 | configFileName = utils.ExecDir() + "/redis.conf" 63 | } 64 | if utils.FileExists(configFileName) { 65 | conf.LoadConfig(configFileName) 66 | } else { 67 | // 默认的配置 68 | conf.GlobalConfig = &conf.RedisConfig{ 69 | Bind: "0.0.0.0", 70 | Port: 6379, 71 | AppendOnly: false, 72 | AppendFilename: "", 73 | RunID: utils.RandString(40), 74 | } 75 | } 76 | //logger.Debugf("%#v", conf.GlobalConfig) 77 | } 78 | -------------------------------------------------------------------------------- /pubhub/pubhub.go: -------------------------------------------------------------------------------- 1 | package pubhub 2 | 3 | import ( 4 | "strconv" 5 | 6 | "github.com/gofish2020/easyredis/abstract" 7 | "github.com/gofish2020/easyredis/datastruct/dict" 8 | "github.com/gofish2020/easyredis/datastruct/list" 9 | "github.com/gofish2020/easyredis/redis/protocol" 10 | "github.com/gofish2020/easyredis/tool/locker" 11 | "github.com/gofish2020/easyredis/tool/logger" 12 | "github.com/gofish2020/easyredis/utils" 13 | ) 14 | 15 | /* 16 | 发布订阅的底层数据结构: map + list 17 | map中的key表示channel, list记录订阅该channel的客户端 18 | */ 19 | 20 | const ( 21 | _subscribe = "subscribe" 22 | _unsubscribe = "unsubscribe" 23 | _message = "message" 24 | ) 25 | 26 | // https://redis.io/docs/interact/pubsub/ 27 | // chanName 通道名 count总共订阅成功的数量 28 | func channelMsg(action, chanName string, count int) []byte { 29 | // Wire protocol 30 | return []byte("*3" + utils.CRLF + 31 | "$" + strconv.Itoa(len(action)) + utils.CRLF + action + utils.CRLF + 32 | "$" + strconv.Itoa(len(chanName)) + utils.CRLF + chanName + utils.CRLF + 33 | ":" + strconv.Itoa(count) + utils.CRLF) 34 | } 35 | 36 | func noChannelMsg() []byte { 37 | return []byte("*3" + utils.CRLF + 38 | "$" + strconv.Itoa(len(_unsubscribe)) + utils.CRLF + _unsubscribe + utils.CRLF + 39 | "$-1" + utils.CRLF + 40 | ":0" + utils.CRLF) 41 | } 42 | 43 | func publisMsg(channel string, msg string) []byte { 44 | 45 | return []byte("*3" + utils.CRLF + 46 | "$" + strconv.Itoa(len(_message)) + utils.CRLF + _message + utils.CRLF + 47 | "$" + strconv.Itoa(len(channel)) + utils.CRLF + channel + utils.CRLF + 48 | "$" + strconv.Itoa(len(msg)) + utils.CRLF + msg + utils.CRLF) 49 | } 50 | 51 | type Pubhub struct { 52 | 53 | // 自定义实现的map 54 | dataDict dict.ConcurrentDict 55 | 56 | // 该锁的颗粒度太大 57 | //locker sync.RWMutex 58 | 59 | locker *locker.Locker // 自定义一个分布锁 60 | } 61 | 62 | func NewPubsub() *Pubhub { 63 | pubsub := &Pubhub{ 64 | dataDict: *dict.NewConcurrentDict(16), 65 | locker: locker.NewLocker(16), 66 | } 67 | return pubsub 68 | } 69 | 70 | // SUBSCRIBE channel [channel ...] 71 | func (p *Pubhub) Subscribe(c abstract.Connection, args [][]byte) protocol.Reply { 72 | 73 | if len(args) < 1 { 74 | return protocol.NewArgNumErrReply("subscribe") 75 | } 76 | 77 | // 通道名 78 | keys := make([]string, 0, len(args)) 79 | for _, arg := range args { 80 | keys = append(keys, string(arg)) 81 | } 82 | // 加锁 83 | p.locker.Locks(keys...) 84 | defer p.locker.Unlocks(keys...) 85 | 86 | for _, arg := range args { 87 | chanName := string(arg) 88 | // 记录当前客户端连接订阅的通道 89 | c.Subscribe(chanName) 90 | 91 | // 双向链表,记录通道下的客户端连接 92 | var l *list.LinkedList 93 | raw, exist := p.dataDict.Get(chanName) 94 | if !exist { // 说明该channel第一次使用 95 | l = list.NewLinkedList() 96 | p.dataDict.Put(chanName, l) 97 | } else { 98 | l, _ = raw.(*list.LinkedList) 99 | } 100 | 101 | // 未订阅 102 | if !l.Contain(func(actual interface{}) bool { 103 | return c == actual 104 | }) { 105 | // 如果不重复,那就记录订阅 106 | logger.Debug("subscribe channel [" + chanName + "] success") 107 | l.Add(c) 108 | } 109 | 110 | // 回复客户端消息 111 | _, err := c.Write(channelMsg(_subscribe, chanName, c.SubCount())) 112 | if err != nil { 113 | logger.Warn(err) 114 | } 115 | } 116 | 117 | return protocol.NewNoReply() 118 | } 119 | 120 | // 取消订阅 121 | // unsubscribes itself from all the channels using the UNSUBSCRIBE command without additional arguments 122 | func (p *Pubhub) Unsubscribe(c abstract.Connection, args [][]byte) protocol.Reply { 123 | 124 | var channels []string 125 | if len(args) < 1 { // 取消全部 126 | channels = c.GetChannels() 127 | } else { // 取消指定channel 128 | channels = make([]string, len(args)) 129 | for i, v := range args { 130 | channels[i] = string(v) 131 | } 132 | } 133 | 134 | p.locker.Locks(channels...) 135 | defer p.locker.Unlocks(channels...) 136 | 137 | // 说明已经没有订阅的通道 138 | if len(channels) == 0 { 139 | c.Write(noChannelMsg()) 140 | } 141 | for _, channel := range channels { 142 | 143 | // 从客户端中删除当前通道 144 | c.Unsubscribe(channel) 145 | // 获取链表 146 | raw, ok := p.dataDict.Get(channel) 147 | if ok { 148 | // 从链表中删除当前客户端 149 | l, _ := raw.(*list.LinkedList) 150 | l.DelAllByVal(func(actual interface{}) bool { 151 | return c == actual 152 | }) 153 | 154 | // 如果链表为空,清理map 155 | if l.Len() == 0 { 156 | p.dataDict.Delete(channel) 157 | } 158 | } 159 | c.Write(channelMsg(_unsubscribe, channel, c.SubCount())) 160 | } 161 | 162 | return protocol.NewNoReply() 163 | } 164 | 165 | func (p *Pubhub) Publish(self abstract.Connection, args [][]byte) protocol.Reply { 166 | 167 | if len(args) != 2 { 168 | return protocol.NewArgNumErrReply("publish") 169 | } 170 | 171 | channelName := string(args[0]) 172 | // 加锁 173 | p.locker.Locks(channelName) 174 | defer p.locker.Unlocks(channelName) 175 | 176 | raw, ok := p.dataDict.Get(channelName) 177 | if ok { 178 | 179 | var sendSuccess int64 180 | var failedClient = make(map[interface{}]struct{}) 181 | // 取出链表 182 | l, _ := raw.(*list.LinkedList) 183 | // 遍历链表 184 | l.ForEach(func(i int, val interface{}) bool { 185 | 186 | conn, _ := val.(abstract.Connection) 187 | 188 | if conn.IsClosed() { 189 | failedClient[val] = struct{}{} 190 | return true 191 | } 192 | 193 | if val == self { //不给自己发送 194 | return true 195 | } 196 | // 发送数据 197 | conn.Write(publisMsg(channelName, string(args[1]))) 198 | sendSuccess++ 199 | return true 200 | }) 201 | 202 | // 剔除客户端 203 | if len(failedClient) > 0 { 204 | removed := l.DelAllByVal(func(actual interface{}) bool { 205 | _, ok := failedClient[actual] 206 | return ok 207 | }) 208 | logger.Debugf("del %d closed client", removed) 209 | } 210 | 211 | // 返回发送的客户端数量 212 | return protocol.NewIntegerReply(sendSuccess) 213 | } 214 | // 如果channel不存在 215 | return protocol.NewIntegerReply(0) 216 | } 217 | -------------------------------------------------------------------------------- /redis-cli.sh: -------------------------------------------------------------------------------- 1 | # redis-cli :redis客户端,记得加到自己的环境变量PATH中 2 | redis-cli -h 127.0.0.1 -p 6379 3 | 4 | # redis-cli -h 127.0.0.1 -p 6379 -a 1 -------------------------------------------------------------------------------- /redis-cluster0.sh: -------------------------------------------------------------------------------- 1 | go run main.go -conf=./redis0.conf -------------------------------------------------------------------------------- /redis-cluster1.sh: -------------------------------------------------------------------------------- 1 | go run main.go -conf=./redis1.conf -------------------------------------------------------------------------------- /redis-cluster2.sh: -------------------------------------------------------------------------------- 1 | go run main.go -conf=./redis2.conf -------------------------------------------------------------------------------- /redis/client/client.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "errors" 5 | "net" 6 | "strings" 7 | "sync" 8 | "sync/atomic" 9 | "time" 10 | 11 | "github.com/gofish2020/easyredis/redis/parser" 12 | "github.com/gofish2020/easyredis/redis/protocol" 13 | "github.com/gofish2020/easyredis/tool/logger" 14 | "github.com/gofish2020/easyredis/tool/wait" 15 | ) 16 | 17 | const ( 18 | maxChanSize = 1 << 10 19 | maxWait = 3 * time.Second 20 | 21 | heartBeatInterval = 1 * time.Second 22 | ) 23 | 24 | const ( 25 | connCreated = iota 26 | connRunning 27 | connClosed 28 | ) 29 | 30 | type request struct { 31 | command [][]byte // redis命令 32 | err error // 处理出错 33 | reply protocol.Reply // 处理结果 34 | 35 | wait wait.Wait // 等待处理完成 36 | } 37 | 38 | func (r *request) Bytes() []byte { 39 | return protocol.NewMultiBulkReply(r.command).ToBytes() 40 | } 41 | 42 | type RedisClient struct { 43 | // socket连接 44 | conn net.Conn 45 | 46 | addr string 47 | // 客户端当前状态 48 | connStatus atomic.Int32 49 | 50 | // heartbeat 51 | ticker time.Ticker 52 | 53 | // buffer cache 54 | waitSend chan *request 55 | waitResult chan *request 56 | 57 | // 有请求正在处理中... 58 | working sync.WaitGroup 59 | } 60 | 61 | // 创建redis客户端socket 62 | func NewRedisClient(addr string) (*RedisClient, error) { 63 | conn, err := net.Dial("tcp", addr) 64 | if err != nil { 65 | return nil, err 66 | } 67 | 68 | rc := RedisClient{} 69 | rc.conn = conn 70 | rc.waitSend = make(chan *request, maxChanSize) 71 | rc.waitResult = make(chan *request, maxChanSize) 72 | rc.addr = addr 73 | rc.connStatus.Store(connCreated) 74 | return &rc, nil 75 | } 76 | 77 | // 启动 78 | func (rc *RedisClient) Start() error { 79 | rc.ticker = *time.NewTicker(heartBeatInterval) 80 | // 将waitSend缓冲区进行发送 81 | go rc.execSend() 82 | // 获取服务端结果 83 | go rc.execReceive() 84 | // 定时发送心跳 85 | //go rc.execHeardBeat() 86 | rc.connStatus.Store(connRunning) // 启动状态 87 | return nil 88 | } 89 | 90 | func (rc *RedisClient) execReceive() { 91 | 92 | ch := parser.ParseStream(rc.conn) 93 | 94 | for payload := range ch { 95 | 96 | if payload.Err != nil { 97 | if rc.connStatus.Load() == connClosed { // 连接已关闭 98 | return 99 | } 100 | 101 | // 否则,重新连接(可能因为网络抖动临时断开了) 102 | 103 | rc.reconnect() 104 | return 105 | } 106 | 107 | // 说明一切正常 108 | 109 | rc.handleResult(payload.Reply) 110 | } 111 | } 112 | 113 | func (rc *RedisClient) reconnect() { 114 | logger.Info("redis client reconnect...") 115 | rc.conn.Close() 116 | 117 | var conn net.Conn 118 | // 重连(重试3次) 119 | for i := 0; i < 3; i++ { 120 | var err error 121 | conn, err = net.Dial("tcp", rc.addr) 122 | if err != nil { 123 | logger.Error("reconnect error: " + err.Error()) 124 | time.Sleep(time.Second) 125 | continue 126 | } else { 127 | break 128 | } 129 | } 130 | // 服务端连不上,说明服务可能挂了(or 网络问题 and so on...) 131 | if conn == nil { 132 | rc.Stop() 133 | return 134 | } 135 | 136 | // 这里关闭没问题,因为rc.conn.Close已经关闭,函数Send中保存的请求因为发送不成功,不会写入到waitResult 137 | close(rc.waitResult) 138 | // 清理 waitResult(因为连接重置,新连接上只能处理新请求,老的请求的数据结果在老连接上,老连接已经关了,新连接上肯定是没有结果的) 139 | for req := range rc.waitResult { 140 | req.err = errors.New("connect reset") 141 | req.wait.Done() 142 | } 143 | 144 | // 新连接(新气象) 145 | rc.waitResult = make(chan *request, maxWait) 146 | rc.conn = conn 147 | 148 | // 重新启动接收协程 149 | go rc.execReceive() 150 | } 151 | 152 | func (rc *RedisClient) handleResult(reply protocol.Reply) { 153 | // 从rc.waitResult 获取一个等待中的请求,将结果保存进去 154 | req := <-rc.waitResult 155 | if req == nil { 156 | return 157 | } 158 | req.reply = reply 159 | req.wait.Done() // 通知已经获取到结果 160 | } 161 | 162 | // 将waitSend缓冲区进行发送 163 | func (rc *RedisClient) execSend() { 164 | for req := range rc.waitSend { 165 | rc.sendReq(req) 166 | } 167 | } 168 | 169 | func (rc *RedisClient) sendReq(req *request) { 170 | // 无效请求 171 | if req == nil || len(req.command) == 0 { 172 | return 173 | } 174 | 175 | var err error 176 | // 网络请求(重试3次) 177 | for i := 0; i < 3; i++ { 178 | _, err = rc.conn.Write(req.Bytes()) 179 | // 发送成功 or 发送错误(除了超时错误和deadline错误)跳出 180 | if err == nil || 181 | (!strings.Contains(err.Error(), "timeout") && // only retry timeout 182 | !strings.Contains(err.Error(), "deadline exceeded")) { 183 | break 184 | } 185 | } 186 | 187 | if err == nil { // 发送成功,异步等待结果 188 | rc.waitResult <- req 189 | } else { // 发送失败,请求直接失败 190 | req.err = err 191 | req.wait.Done() 192 | } 193 | } 194 | 195 | // 定时发送心跳 196 | func (rc *RedisClient) execHeardBeat() { 197 | for range rc.ticker.C { 198 | rc.Send([][]byte{[]byte("PING")}) 199 | } 200 | 201 | } 202 | 203 | // 将redis命令保存到 waitSend 中 204 | func (rc *RedisClient) Send(command [][]byte) (protocol.Reply, error) { 205 | 206 | // 已关闭 207 | if rc.connStatus.Load() == connClosed { 208 | return nil, errors.New("client closed") 209 | } 210 | 211 | req := &request{ 212 | command: command, 213 | wait: wait.Wait{}, 214 | } 215 | // 单个请求 216 | req.wait.Add(1) 217 | 218 | // 所有请求 219 | rc.working.Add(1) 220 | defer rc.working.Done() 221 | 222 | // 将数据保存到缓冲中 223 | rc.waitSend <- req 224 | 225 | // 等待处理结束 226 | if req.wait.WaitWithTimeOut(maxWait) { 227 | return nil, errors.New("time out") 228 | } 229 | // 出错 230 | if req.err != nil { 231 | err := req.err 232 | return nil, err 233 | } 234 | // 正常 235 | return req.reply, nil 236 | } 237 | 238 | func (rc *RedisClient) Stop() { 239 | // 设置已关闭 240 | rc.connStatus.Store(connClosed) 241 | rc.ticker.Stop() 242 | 243 | // 保证发送协程停止 244 | close(rc.waitSend) 245 | // 说明等待网络请求结果的request客户端不阻塞了(也就是剩下的req不需要等待了,可以关闭网络连接) 246 | rc.working.Wait() 247 | rc.conn.Close() 248 | close(rc.waitResult) 249 | } 250 | -------------------------------------------------------------------------------- /redis/client/client_test.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | "time" 7 | 8 | "github.com/gofish2020/easyredis/tool/logger" 9 | ) 10 | 11 | func TestReconnect(t *testing.T) { 12 | logger.Setup(&logger.Settings{ 13 | Path: "logs", 14 | Name: "easyredis", 15 | Ext: ".log", 16 | DateFormat: "2006-01-02", 17 | }) 18 | client, err := NewRedisClient("localhost:6379") 19 | if err != nil { 20 | t.Error(err) 21 | } 22 | client.Start() 23 | 24 | // 模拟连接断开 25 | _ = client.conn.Close() 26 | time.Sleep(time.Second) // wait for reconnecting 27 | success := false 28 | for i := 0; i < 3; i++ { 29 | result, err := client.Send([][]byte{ 30 | []byte("PING"), 31 | }) 32 | if err == nil && bytes.Equal(result.ToBytes(), []byte("+PONG\r\n")) { 33 | success = true 34 | break 35 | } 36 | } 37 | if !success { 38 | t.Error("reconnect error") 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /redis/connection/conn.go: -------------------------------------------------------------------------------- 1 | package connection 2 | 3 | import ( 4 | "net" 5 | "sync" 6 | "sync/atomic" 7 | "time" 8 | 9 | "github.com/gofish2020/easyredis/tool/logger" 10 | "github.com/gofish2020/easyredis/tool/wait" 11 | ) 12 | 13 | const ( 14 | timeout = 10 * time.Second 15 | ) 16 | 17 | // 连接池对象 18 | var connPool = sync.Pool{ 19 | New: func() interface{} { 20 | return &KeepConnection{ 21 | dbIndex: 0, 22 | c: nil, 23 | password: "", 24 | closed: atomic.Bool{}, 25 | trx: atomic.Bool{}, 26 | } 27 | }, 28 | } 29 | 30 | // 记录连接的相关信息 31 | type KeepConnection struct { 32 | // 网络conn 33 | c net.Conn 34 | // 服务密码 35 | password string 36 | // 当前连接指定的数据库 37 | dbIndex int 38 | 39 | // 当要关闭连接,如果连接还在使用中【等待...】 wait.Wait 是对 sync.WaitGroup的封装 40 | writeDataWaitGroup wait.Wait 41 | 42 | // 记录当前连接,订阅的channel 43 | 44 | mu sync.Mutex 45 | subs map[string]struct{} 46 | 47 | closed atomic.Bool 48 | 49 | // 事务模式 50 | trx atomic.Bool 51 | queue [][][]byte 52 | watchKey map[string]int64 53 | txErrors []error 54 | } 55 | 56 | // 本质就是构建 *KeepConnection对象,存储c net.Conn 以及相关信息 57 | func NewKeepConnection(c net.Conn) *KeepConnection { 58 | 59 | conn, ok := connPool.Get().(*KeepConnection) 60 | if !ok { 61 | logger.Error("connection pool make wrong type") 62 | return &KeepConnection{ 63 | dbIndex: 0, 64 | c: nil, 65 | password: "", 66 | closed: atomic.Bool{}, 67 | trx: atomic.Bool{}, 68 | } 69 | } 70 | conn.c = c 71 | conn.closed.Store(false) 72 | conn.trx.Store(false) 73 | conn.queue = nil 74 | conn.txErrors = nil 75 | conn.watchKey = nil 76 | return conn 77 | } 78 | 79 | // 当前连接选定的数据库 80 | func (k *KeepConnection) SetDBIndex(index int) { 81 | k.dbIndex = index 82 | } 83 | 84 | func (k *KeepConnection) GetDBIndex() int { 85 | return k.dbIndex 86 | } 87 | 88 | // 连接远程地址信息 89 | func (k *KeepConnection) RemoteAddr() string { 90 | 91 | return k.c.RemoteAddr().String() 92 | } 93 | 94 | // 关闭 *KeepConnection 对象 95 | func (k *KeepConnection) Close() error { 96 | 97 | k.closed.Store(true) 98 | k.writeDataWaitGroup.WaitWithTimeOut(timeout) // 等待write结束 99 | k.c.Close() 100 | k.dbIndex = 0 101 | connPool.Put(k) 102 | return nil 103 | } 104 | 105 | func (k *KeepConnection) IsClosed() bool { 106 | return k.closed.Load() 107 | } 108 | 109 | func (k *KeepConnection) Write(b []byte) (int, error) { 110 | 111 | if len(b) == 0 { 112 | return 0, nil 113 | } 114 | 115 | k.writeDataWaitGroup.Add(1) // 说明在write 116 | defer k.writeDataWaitGroup.Done() // 说明write结束 117 | return k.c.Write(b) 118 | } 119 | 120 | // 密码信息 121 | func (k *KeepConnection) SetPassword(password string) { 122 | k.password = password 123 | } 124 | 125 | func (k *KeepConnection) GetPassword() string { 126 | return k.password 127 | } 128 | 129 | func (k *KeepConnection) Subscribe(channel string) { 130 | k.mu.Lock() 131 | defer k.mu.Unlock() 132 | if k.subs == nil { 133 | k.subs = map[string]struct{}{} 134 | } 135 | 136 | k.subs[channel] = struct{}{} 137 | 138 | } 139 | 140 | func (k *KeepConnection) Unsubscribe(channel string) { 141 | k.mu.Lock() 142 | defer k.mu.Unlock() 143 | 144 | if len(k.subs) == 0 { 145 | return 146 | } 147 | 148 | delete(k.subs, channel) 149 | } 150 | 151 | func (k *KeepConnection) SubCount() int { 152 | return len(k.subs) 153 | } 154 | 155 | func (k *KeepConnection) GetChannels() []string { 156 | 157 | k.mu.Lock() 158 | defer k.mu.Unlock() 159 | 160 | var result []string 161 | for channel := range k.subs { 162 | result = append(result, channel) 163 | } 164 | return result 165 | } 166 | 167 | func (k *KeepConnection) IsTransaction() bool { 168 | return k.trx.Load() 169 | } 170 | 171 | func (k *KeepConnection) SetTransaction(val bool) { 172 | if !val { // 取消事务模式,清空队列和watch key 173 | k.queue = nil 174 | k.watchKey = nil 175 | k.txErrors = nil 176 | } 177 | // 开启事务状态 178 | k.trx.Store(val) 179 | } 180 | 181 | func (k *KeepConnection) EnqueueCmd(redisCommand [][]byte) { 182 | k.queue = append(k.queue, redisCommand) 183 | } 184 | 185 | func (k *KeepConnection) GetQueuedCmdLine() [][][]byte { 186 | return k.queue 187 | } 188 | func (k *KeepConnection) GetWatchKey() map[string]int64 { 189 | if k.watchKey == nil { 190 | k.watchKey = make(map[string]int64) 191 | } 192 | return k.watchKey 193 | } 194 | 195 | func (k *KeepConnection) CleanWatchKey() { 196 | k.watchKey = nil 197 | } 198 | 199 | func (k *KeepConnection) GetTxErrors() []error { 200 | return k.txErrors 201 | } 202 | 203 | func (k *KeepConnection) AddTxError(err error) { 204 | k.txErrors = append(k.txErrors, err) 205 | } 206 | -------------------------------------------------------------------------------- /redis/connection/virtualconn.go: -------------------------------------------------------------------------------- 1 | package connection 2 | 3 | import "github.com/gofish2020/easyredis/tool/conf" 4 | 5 | type VirtualConnection struct { 6 | KeepConnection 7 | dbIndex int 8 | } 9 | 10 | func NewVirtualConn() *VirtualConnection { 11 | c := &VirtualConnection{} 12 | return c 13 | } 14 | 15 | func (v *VirtualConnection) SetDBIndex(index int) { 16 | v.dbIndex = index 17 | } 18 | 19 | func (v *VirtualConnection) GetDBIndex() int { 20 | return v.dbIndex 21 | } 22 | 23 | func (v *VirtualConnection) GetPassword() string { 24 | return conf.GlobalConfig.RequirePass 25 | } 26 | -------------------------------------------------------------------------------- /redis/handler.go: -------------------------------------------------------------------------------- 1 | package redis 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "net" 7 | "strings" 8 | "sync" 9 | 10 | "github.com/gofish2020/easyredis/abstract" 11 | "github.com/gofish2020/easyredis/cluster" 12 | "github.com/gofish2020/easyredis/engine" 13 | "github.com/gofish2020/easyredis/redis/connection" 14 | "github.com/gofish2020/easyredis/redis/parser" 15 | "github.com/gofish2020/easyredis/redis/protocol" 16 | "github.com/gofish2020/easyredis/tool/conf" 17 | "github.com/gofish2020/easyredis/tool/logger" 18 | ) 19 | 20 | type Handler interface { 21 | Handle(ctx context.Context, conn net.Conn) 22 | Close() error 23 | } 24 | 25 | type RedisHandler struct { 26 | activeConn sync.Map 27 | 28 | //engine *engine.Engine 29 | 30 | engine abstract.Engine 31 | } 32 | 33 | func NewRedisHandler() *RedisHandler { 34 | 35 | var abEngine abstract.Engine 36 | if len(conf.GlobalConfig.Peers) > 0 { 37 | // 分布式 38 | logger.Debug("启动集群版") 39 | abEngine = cluster.NewCluster() 40 | } else { 41 | // 单机版 42 | logger.Debug("启动单机版") 43 | abEngine = engine.NewEngine() 44 | } 45 | return &RedisHandler{ 46 | engine: abEngine, 47 | } 48 | } 49 | 50 | // 该方法是不同的conn复用的方法,要做的事情就是从conn中读取出符合RESP格式的数据; 51 | // 然后针对消息格式,进行不同的业务处理 52 | func (h *RedisHandler) Handle(ctx context.Context, conn net.Conn) { 53 | 54 | // 因为需要记录和conn相关的各种信息呢,所以定义 KeepConnection对象,将conn保存 55 | keepConn := connection.NewKeepConnection(conn) 56 | h.activeConn.Store(keepConn, struct{}{}) 57 | 58 | outChan := parser.ParseStream(conn) 59 | for payload := range outChan { 60 | if payload.Err != nil { 61 | // 网络conn关闭 62 | if payload.Err == io.EOF || payload.Err == io.ErrUnexpectedEOF || strings.Contains(payload.Err.Error(), "use of closed network connection") { 63 | h.activeConn.Delete(keepConn) 64 | logger.Warn("client closed:" + keepConn.RemoteAddr()) 65 | keepConn.Close() 66 | return 67 | } 68 | 69 | // 解析出错 protocol error 70 | errReply := protocol.NewGenericErrReply(payload.Err.Error()) 71 | _, err := keepConn.Write(errReply.ToBytes()) 72 | if err != nil { 73 | h.activeConn.Delete(keepConn) 74 | logger.Warn("client closed:" + keepConn.RemoteAddr() + " err info: " + err.Error()) 75 | keepConn.Close() 76 | return 77 | } 78 | continue 79 | } 80 | 81 | if payload.Reply == nil { 82 | logger.Error("empty payload") 83 | continue 84 | } 85 | 86 | reply, ok := payload.Reply.(*protocol.MultiBulkReply) 87 | if !ok { 88 | logger.Error("require multi bulk protocol") 89 | continue 90 | } 91 | 92 | logger.Debugf("%q", string(reply.ToBytes())) 93 | // 解析出redis命令,丢给存储引擎处理 94 | result := h.engine.Exec(keepConn, reply.RedisCommand) 95 | if result != nil { 96 | keepConn.Write(result.ToBytes()) 97 | } else { 98 | keepConn.Write(protocol.NewUnknownErrReply().ToBytes()) 99 | } 100 | } 101 | } 102 | 103 | func (h *RedisHandler) Close() error { 104 | 105 | logger.Info("handler shutting down...") 106 | 107 | h.activeConn.Range(func(key, value any) bool { 108 | keepConn := key.(*connection.KeepConnection) 109 | keepConn.Close() 110 | h.activeConn.Delete(key) 111 | return true 112 | }) 113 | h.engine.Close() 114 | return nil 115 | } 116 | -------------------------------------------------------------------------------- /redis/parser/parser.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "errors" 7 | "io" 8 | "runtime/debug" 9 | "strconv" 10 | 11 | "github.com/gofish2020/easyredis/redis/protocol" 12 | "github.com/gofish2020/easyredis/tool/logger" 13 | ) 14 | 15 | type Payload struct { 16 | Err error 17 | Reply protocol.Reply 18 | } 19 | 20 | // 从reader读取数据&解析,并保存到chan中,供外部读取 21 | func ParseStream(reader io.Reader) <-chan *Payload { 22 | dataStream := make(chan *Payload) 23 | // 启动协程 24 | go parse(reader, dataStream) 25 | return dataStream 26 | } 27 | 28 | // 从r中读取数据,将读取的结果通过 out chan 发送给外部使用(包括:正常的数据包 or 网络错误) 29 | func parse(r io.Reader, out chan<- *Payload) { 30 | 31 | // 异常恢复,避免未知异常 32 | defer func() { 33 | if err := recover(); err != nil { 34 | logger.Error(err, string(debug.Stack())) 35 | } 36 | }() 37 | 38 | reader := bufio.NewReader(r) 39 | for { 40 | 41 | // 按照 \n 分隔符读取一行数据 42 | line, err := reader.ReadBytes('\n') 43 | if err != nil { // 一般是 io.EOF错误(说明conn关闭or文件尾部) 44 | out <- &Payload{Err: err} 45 | close(out) 46 | return 47 | } 48 | // 读取到的line中包括 \n 分割符 49 | length := len(line) 50 | 51 | // RESP协议是按照 \r\n 分割数据 52 | if length <= 2 || line[length-2] != '\r' { // 说明是空白行,忽略 53 | continue 54 | } 55 | 56 | // 去掉尾部 \r\n 57 | line = bytes.TrimSuffix(line, []byte{'\r', '\n'}) 58 | 59 | // 协议文档 :https://redis.io/docs/reference/protocol-spec/ 60 | // The first byte in an RESP-serialized payload always identifies its type. Subsequent bytes constitute the type's contents. 61 | switch line[0] { 62 | case '*': // * 表示数组 63 | err := parseArrays(line, reader, out) 64 | if err != nil { 65 | out <- &Payload{Err: err} 66 | close(out) 67 | return 68 | } 69 | // + 成功 70 | case '+': 71 | out <- &Payload{ 72 | Reply: protocol.NewSimpleReply(string(line[1:])), 73 | } 74 | // - 错误 75 | case '-': 76 | out <- &Payload{ 77 | Reply: protocol.NewSimpleErrReply(string(line[1:])), 78 | } 79 | 80 | // $ 二进制安全,字符串 81 | case '$': 82 | err = parseBulkString(line, reader, out) 83 | if err != nil { 84 | out <- &Payload{Err: err} 85 | close(out) 86 | return 87 | } 88 | default: 89 | args := bytes.Split(line, []byte{' '}) 90 | out <- &Payload{ 91 | Reply: protocol.NewMultiBulkReply(args), 92 | } 93 | } 94 | } 95 | } 96 | 97 | // 格式: $5\r\nvalue\r\n 98 | func parseBulkString(header []byte, reader *bufio.Reader, out chan<- *Payload) error { 99 | 100 | byteNum, err := strconv.ParseInt(string(header[1:]), 10, 64) 101 | if err != nil || byteNum < -1 { 102 | protocolError(out, "illegal bulk string header: "+string(header)) 103 | return nil 104 | } else if byteNum == -1 { // 空字符串 105 | out <- &Payload{ 106 | Reply: protocol.NewNullBulkReply(), 107 | } 108 | return nil 109 | } 110 | 111 | body := make([]byte, byteNum+2) 112 | _, err = io.ReadFull(reader, body) 113 | if err != nil { 114 | return err 115 | } 116 | out <- &Payload{ 117 | Reply: protocol.NewBulkReply(body[:len(body)-2]), 118 | } 119 | return nil 120 | } 121 | 122 | /* 123 | 数组格式: 124 | 125 | *2\r\n 126 | $5\r\n 127 | hello\r\n 128 | $5\r\n 129 | world\r\n 130 | 131 | */ 132 | 133 | func parseArrays(header []byte, reader *bufio.Reader, out chan<- *Payload) error { 134 | // 解析 *2 , bodyNum 表示后序有多少个数据等待解析 135 | 136 | bodyNum, err := strconv.ParseInt(string(header[1:]), 10, 64) 137 | if err != nil || bodyNum < 0 { 138 | protocolError(out, "illegal array header"+string(header[1:])) 139 | return nil 140 | } 141 | 142 | // lines最终保存的解析出来的结果 143 | lines := make([][]byte, 0, bodyNum) 144 | // 解析后序数据 145 | for i := int64(0); i < bodyNum; i++ { 146 | // 继续读取一行 147 | var line []byte 148 | line, err = reader.ReadBytes('\n') 149 | if err != nil { 150 | return err 151 | } 152 | 153 | // 解析 $5\r\n 154 | length := len(line) 155 | if length < 4 || line[length-2] != '\r' || line[0] != '$' { 156 | protocolError(out, "illegal bulk string header "+string(line)) 157 | return nil 158 | } 159 | // 得到数字 $5中的数字5 160 | dataLen, err := strconv.ParseInt(string(line[1:length-2]), 10, 64) 161 | if err != nil || dataLen < -1 { 162 | protocolError(out, "illegal bulk string length "+string(line)) 163 | return nil 164 | } else if dataLen == -1 { // 这里的-1 表示 Null elements in arrays 165 | lines = append(lines, nil) 166 | } else { 167 | // 基于数字5 读取 5+2 长度的数据,这里的2表示\r\n 168 | body := make([]byte, dataLen+2) 169 | // 注意:这里直接读取指定长度的字节 170 | _, err := io.ReadFull(reader, body) 171 | if err != nil { 172 | return err 173 | } 174 | // 所以最终读取到的是 hello\r\n,去掉\r\n 保存到 lines中 175 | lines = append(lines, body[:len(body)-2]) 176 | } 177 | } 178 | 179 | out <- &Payload{ 180 | Err: nil, 181 | Reply: protocol.NewMultiBulkReply(lines), 182 | } 183 | return nil 184 | } 185 | 186 | func protocolError(out chan<- *Payload, msg string) { 187 | err := errors.New("protocol error: " + msg) 188 | out <- &Payload{Err: err} 189 | } 190 | -------------------------------------------------------------------------------- /redis/protocol/basic.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import "github.com/gofish2020/easyredis/utils" 4 | 5 | // +OK\r\n 6 | var okReply = &OKReply{} 7 | 8 | type OKReply struct{} 9 | 10 | func (r *OKReply) ToBytes() []byte { 11 | return []byte("+OK" + utils.CRLF) 12 | } 13 | 14 | func NewOkReply() *OKReply { 15 | return okReply 16 | } 17 | 18 | // +PONG\r\n 19 | var pongReply = &PONGReply{} 20 | 21 | type PONGReply struct{} 22 | 23 | func (r *PONGReply) ToBytes() []byte { 24 | return []byte("+PONG" + utils.CRLF) 25 | } 26 | 27 | func NewPONGReply() *PONGReply { 28 | return pongReply 29 | } 30 | 31 | // 简单字符串 32 | type SimpleReply struct { 33 | Str string 34 | } 35 | 36 | func (s *SimpleReply) ToBytes() []byte { 37 | return []byte("+" + s.Str + utils.CRLF) 38 | } 39 | 40 | func NewSimpleReply(str string) *SimpleReply { 41 | return &SimpleReply{ 42 | Str: str, 43 | } 44 | } 45 | 46 | // 空回复 47 | type NoReply struct{} 48 | 49 | var noBytes = []byte("") 50 | 51 | // ToBytes marshal redis.Reply 52 | func (r *NoReply) ToBytes() []byte { 53 | return noBytes 54 | } 55 | 56 | func NewNoReply() *NoReply { 57 | return &NoReply{} 58 | } 59 | 60 | // +QUEUED 61 | var queuedReply = &QueuedReply{} 62 | 63 | type QueuedReply struct{} 64 | 65 | func (r *QueuedReply) ToBytes() []byte { 66 | return []byte("+QUEUED" + utils.CRLF) 67 | } 68 | 69 | func NewQueuedReply() *QueuedReply { 70 | return queuedReply 71 | } 72 | 73 | func IsOKReply(reply Reply) bool { 74 | return string(reply.ToBytes()) == "+OK\r\n" 75 | } 76 | -------------------------------------------------------------------------------- /redis/protocol/bulk.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import ( 4 | "bytes" 5 | "strconv" 6 | "strings" 7 | 8 | "github.com/gofish2020/easyredis/utils" 9 | ) 10 | 11 | // 空数组 empty array 12 | var emptyMultiBulkReply = &EmptyMultiBulkReply{} 13 | 14 | type EmptyMultiBulkReply struct { 15 | } 16 | 17 | func (e *EmptyMultiBulkReply) ToBytes() []byte { 18 | return []byte("*0\r\n") 19 | } 20 | 21 | func NewEmptyMultiBulkReply() *EmptyMultiBulkReply { 22 | return emptyMultiBulkReply 23 | } 24 | 25 | // 二进制安全 多个bulk *2\r\n$5\r\nhello\r\n$5\r\nworld\r\n 26 | type MultiBulkReply struct { 27 | RedisCommand [][]byte 28 | } 29 | 30 | func NewMultiBulkReply(command [][]byte) *MultiBulkReply { 31 | return &MultiBulkReply{ 32 | RedisCommand: command, 33 | } 34 | } 35 | 36 | func (r *MultiBulkReply) ToBytes() []byte { 37 | num := len(r.RedisCommand) 38 | var buf bytes.Buffer 39 | buf.WriteString("*" + strconv.Itoa(num) + utils.CRLF) 40 | for _, command := range r.RedisCommand { 41 | if command == nil { 42 | buf.WriteString("$-1" + utils.CRLF) 43 | } else { 44 | length := len(command) 45 | buf.WriteString("$" + strconv.Itoa(length) + utils.CRLF + string(command) + utils.CRLF) 46 | } 47 | } 48 | return buf.Bytes() 49 | } 50 | 51 | // 二进制安全 单个bulk $3\r\nkey\r\n 52 | type BulkReply struct { 53 | Arg []byte 54 | } 55 | 56 | func NewBulkReply(arg []byte) *BulkReply { 57 | 58 | return &BulkReply{ 59 | Arg: arg, 60 | } 61 | } 62 | func (b *BulkReply) ToBytes() []byte { 63 | if b.Arg == nil { 64 | return NewNullBulkReply().ToBytes() 65 | } 66 | return []byte("$" + strconv.Itoa(len(b.Arg)) + utils.CRLF + string(b.Arg) + utils.CRLF) 67 | } 68 | 69 | // null bulk $-1\r\n 70 | var nullBulkReply = &NullBulkReply{} 71 | 72 | type NullBulkReply struct{} 73 | 74 | func (n *NullBulkReply) ToBytes() []byte { 75 | return []byte("$-1" + utils.CRLF) 76 | } 77 | 78 | func NewNullBulkReply() *NullBulkReply { 79 | return nullBulkReply 80 | } 81 | 82 | // Integer :3\r\n 83 | type IntegerReply struct { 84 | Integer int64 85 | } 86 | 87 | func (i *IntegerReply) ToBytes() []byte { 88 | return []byte(":" + strconv.FormatInt(i.Integer, 10) + utils.CRLF) 89 | } 90 | 91 | func NewIntegerReply(integer int64) *IntegerReply { 92 | return &IntegerReply{Integer: integer} 93 | } 94 | 95 | // Mix 96 | func NewMixReply() *MixReply { 97 | return &MixReply{} 98 | } 99 | 100 | type MixReply struct { 101 | replies []Reply 102 | } 103 | 104 | func (m *MixReply) ToBytes() []byte { 105 | num := len(m.replies) 106 | var str strings.Builder 107 | str.WriteString("*" + strconv.Itoa(num) + utils.CRLF) // example: *3\r\n 108 | for _, reply := range m.replies { 109 | str.Write(reply.ToBytes()) 110 | } 111 | return []byte(str.String()) 112 | } 113 | 114 | func (m *MixReply) Append(replies ...Reply) { 115 | m.replies = append(m.replies, replies...) 116 | } 117 | -------------------------------------------------------------------------------- /redis/protocol/errors.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import "github.com/gofish2020/easyredis/utils" 4 | 5 | // 自定义错误 - xxxxx 6 | type SimpleErrReply struct { 7 | Status string 8 | } 9 | 10 | func NewSimpleErrReply(status string) *SimpleErrReply { 11 | return &SimpleErrReply{ 12 | Status: status, 13 | } 14 | } 15 | 16 | func (s *SimpleErrReply) ToBytes() []byte { 17 | return []byte("-" + s.Status) 18 | } 19 | 20 | // 一般错误 -ERR xxxxx 21 | type GenericErrReply struct { 22 | Status string 23 | } 24 | 25 | func NewGenericErrReply(status string) *GenericErrReply { 26 | return &GenericErrReply{ 27 | Status: status, 28 | } 29 | } 30 | 31 | func (s *GenericErrReply) ToBytes() []byte { 32 | return []byte("-ERR " + s.Status + utils.CRLF) 33 | } 34 | 35 | // 未知错误 -ERR unknown 36 | type UnknownErrReply struct{} 37 | 38 | func (r *UnknownErrReply) ToBytes() []byte { 39 | return []byte("-ERR unknown\r\n") 40 | } 41 | 42 | func NewUnknownErrReply() *UnknownErrReply { 43 | return &UnknownErrReply{} 44 | } 45 | 46 | // 命令参数数量错误 47 | type ArgNumErrReply struct { 48 | Cmd string 49 | } 50 | 51 | func (r *ArgNumErrReply) ToBytes() []byte { 52 | return []byte("-ERR wrong number of arguments for '" + r.Cmd + "' command\r\n") 53 | } 54 | 55 | func NewArgNumErrReply(cmd string) *ArgNumErrReply { 56 | return &ArgNumErrReply{ 57 | Cmd: cmd, 58 | } 59 | } 60 | 61 | // 底层数据类型错误 62 | type WrongTypeErrReply struct{} 63 | 64 | var wrongTypeErrBytes = []byte("-WRONGTYPE Operation against a key holding the wrong kind of value\r\n") 65 | 66 | func (r *WrongTypeErrReply) ToBytes() []byte { 67 | return wrongTypeErrBytes 68 | } 69 | 70 | func NewWrongTypeErrReply() *WrongTypeErrReply { 71 | return &WrongTypeErrReply{} 72 | } 73 | 74 | // 语法错误 75 | type SyntaxErrReply struct{} 76 | 77 | var syntaxErrBytes = []byte("-Err syntax error\r\n") 78 | var syntaxErrReply = &SyntaxErrReply{} 79 | 80 | func (s *SyntaxErrReply) ToBytes() []byte { 81 | return syntaxErrBytes 82 | } 83 | 84 | func NewSyntaxErrReply() *SyntaxErrReply { 85 | return syntaxErrReply 86 | } 87 | 88 | // 是否为Err 89 | 90 | func IsErrReply(reply Reply) bool { 91 | return reply.ToBytes()[0] == '-' 92 | } 93 | -------------------------------------------------------------------------------- /redis/protocol/interface.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | type Reply interface { 4 | ToBytes() []byte 5 | } 6 | -------------------------------------------------------------------------------- /redis0.conf: -------------------------------------------------------------------------------- 1 | Bind 127.0.0.1 2 | Port 6379 3 | 4 | Dir ./data/redis0/ 5 | 6 | AppendOnly yes 7 | AppendFilename append.aof 8 | AppendFsync everysec 9 | # 密码 10 | # RequirePass 1 11 | 12 | 13 | Peers 127.0.0.1:7379,127.0.0.1:8379 14 | Self 127.0.0.1:6379 -------------------------------------------------------------------------------- /redis1.conf: -------------------------------------------------------------------------------- 1 | Bind 127.0.0.1 2 | Port 7379 3 | 4 | Dir ./data/redis1/ 5 | 6 | AppendOnly yes 7 | AppendFilename append.aof 8 | AppendFsync everysec 9 | # 密码 10 | # RequirePass 1 11 | 12 | 13 | Peers 127.0.0.1:6379,127.0.0.1:8379 14 | Self 127.0.0.1:7379 -------------------------------------------------------------------------------- /redis2.conf: -------------------------------------------------------------------------------- 1 | Bind 127.0.0.1 2 | Port 8379 3 | 4 | Dir ./data/redis2/ 5 | 6 | AppendOnly yes 7 | AppendFilename append.aof 8 | AppendFsync everysec 9 | # 密码 10 | # RequirePass 1 11 | 12 | 13 | Peers 127.0.0.1:6379,127.0.0.1:7379 14 | Self 127.0.0.1:8379 -------------------------------------------------------------------------------- /tcpserver/tcpserver.go: -------------------------------------------------------------------------------- 1 | package tcpserver 2 | 3 | import ( 4 | "context" 5 | "net" 6 | "os" 7 | "os/signal" 8 | "sync" 9 | "sync/atomic" 10 | "syscall" 11 | "time" 12 | 13 | "github.com/gofish2020/easyredis/redis" 14 | "github.com/gofish2020/easyredis/tool/logger" 15 | ) 16 | 17 | type TCPConfig struct { 18 | Addr string 19 | } 20 | 21 | type TCPServer struct { 22 | listener net.Listener // 监听句柄 23 | waitDone sync.WaitGroup // 优雅关闭(等待) 24 | clientCounter int64 // 有多少个客户端在执行中 25 | conf TCPConfig // 配置 26 | closeTcp int32 // 关闭标识 27 | quit chan os.Signal // 监听进程信号 28 | redisHander redis.Handler // 实际处理连接对象 29 | } 30 | 31 | func NewTCPServer(conf TCPConfig, handler redis.Handler) *TCPServer { 32 | server := &TCPServer{ 33 | conf: conf, 34 | closeTcp: 0, 35 | clientCounter: 0, 36 | quit: make(chan os.Signal, 1), 37 | redisHander: handler, 38 | } 39 | return server 40 | } 41 | 42 | func (t *TCPServer) Start() error { 43 | // 开启监听 44 | listen, err := net.Listen("tcp", t.conf.Addr) 45 | if err != nil { 46 | return err 47 | } 48 | t.listener = listen 49 | logger.Infof("bind %s listening...", t.conf.Addr) 50 | // 接收连接 51 | go t.accept() 52 | // 阻塞于信号 53 | signal.Notify(t.quit, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT) 54 | <-t.quit 55 | return nil 56 | } 57 | 58 | // accept 死循环接收新连接的到来 59 | func (t *TCPServer) accept() error { 60 | 61 | for { 62 | conn, err := t.listener.Accept() 63 | if err != nil { 64 | if ne, ok := err.(net.Error); ok && ne.Timeout() { 65 | logger.Infof("accept occurs temporary error: %v, retry in 5ms", err) 66 | time.Sleep(5 * time.Millisecond) 67 | continue 68 | } 69 | // 说明监听listener关闭,无法接收新连接 70 | logger.Warn(err.Error()) 71 | atomic.CompareAndSwapInt32(&t.closeTcp, 0, 1) 72 | // 整个进程退出 73 | t.quit <- syscall.SIGTERM 74 | // 结束 for循环 75 | break 76 | } 77 | // 启动一个协程处理conn 78 | go t.handleConn(conn) 79 | } 80 | 81 | return nil 82 | } 83 | 84 | func (t *TCPServer) handleConn(conn net.Conn) { 85 | // 如果已关闭,新连接不再处理 86 | if atomic.LoadInt32(&t.closeTcp) == 1 { 87 | // 直接关闭 88 | conn.Close() 89 | return 90 | } 91 | 92 | logger.Debugf("accept new conn %s", conn.RemoteAddr().String()) 93 | t.waitDone.Add(1) 94 | atomic.AddInt64(&t.clientCounter, 1) 95 | defer func() { 96 | t.waitDone.Done() 97 | atomic.AddInt64(&t.clientCounter, -1) 98 | }() 99 | 100 | // TODO :处理连接 101 | t.redisHander.Handle(context.Background(), conn) 102 | } 103 | 104 | // 退出前,清理 105 | func (t *TCPServer) Close() { 106 | logger.Info("graceful shutdown easyredis server") 107 | 108 | atomic.CompareAndSwapInt32(&t.closeTcp, 0, 1) 109 | // 关闭监听 110 | t.listener.Close() 111 | // 关闭处理对象 112 | t.redisHander.Close() 113 | // 阻塞中...(优雅关闭) 114 | t.waitDone.Wait() 115 | } 116 | -------------------------------------------------------------------------------- /test.conf: -------------------------------------------------------------------------------- 1 | Bind 127.0.0.1 2 | Port 6379 3 | 4 | Dir ./data/redis/ 5 | 6 | AppendOnly yes 7 | AppendFilename append.aof 8 | AppendFsync everysec 9 | # 密码 10 | # RequirePass 1 11 | 12 | 13 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | go run main.go -conf=./test.conf -------------------------------------------------------------------------------- /tool/conf/config.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | "bufio" 5 | "io" 6 | "os" 7 | "reflect" 8 | "strconv" 9 | "strings" 10 | 11 | "github.com/gofish2020/easyredis/utils" 12 | ) 13 | 14 | const runidMaxLen = 40 15 | const defaultDatabasesNum = 16 16 | 17 | /* 18 | purpose:读取conf配置文件 19 | */ 20 | 21 | type RedisConfig struct { 22 | // 基础配置 23 | Bind string `conf:"bind"` 24 | Port int `conf:"port"` 25 | Dir string `conf:"dir"` 26 | RunID string `conf:"runid"` 27 | 28 | // 数据库个数 29 | Databases int `conf:"databases"` 30 | 31 | // aof 相关 32 | AppendOnly bool `conf:"appendonly"` // 是否启用aof 33 | AppendFilename string `conf:"appendfilename"` // aof文件名 34 | AppendFsync string `conf:"appendfsync"` // aof刷盘间隔 35 | 36 | // 服务器密码 37 | RequirePass string `conf:"requirepass,omitempty"` 38 | 39 | // 集群 40 | Peers []string `conf:"peers"` 41 | Self string `conf:"self"` 42 | } 43 | 44 | // 全局配置 45 | var GlobalConfig *RedisConfig 46 | 47 | func init() { 48 | GlobalConfig = &RedisConfig{ 49 | Bind: "127.0.0.1", 50 | Port: 6379, 51 | AppendOnly: false, 52 | Dir: ".", 53 | RunID: utils.RandString(runidMaxLen), 54 | Databases: defaultDatabasesNum, 55 | } 56 | } 57 | 58 | // 加载配置文件,更新 GlobalConfig 对象 59 | func LoadConfig(configFile string) error { 60 | 61 | //1.打开文件 62 | file, err := os.Open(configFile) 63 | if err != nil { 64 | panic(err) 65 | } 66 | defer file.Close() 67 | //2.解析文件 68 | GlobalConfig = parse(file) 69 | 70 | //3.补充信息 71 | GlobalConfig.RunID = utils.RandString(runidMaxLen) 72 | if GlobalConfig.Dir == "" { 73 | GlobalConfig.Dir = utils.ExecDir() 74 | } 75 | 76 | if GlobalConfig.Databases == 0 { 77 | GlobalConfig.Databases = defaultDatabasesNum 78 | } 79 | 80 | utils.MakeDir(GlobalConfig.Dir) 81 | return nil 82 | } 83 | 84 | func parse(r io.Reader) *RedisConfig { 85 | 86 | newRedisConfig := &RedisConfig{} 87 | 88 | //1.按行扫描文件 89 | lineMap := make(map[string]string) 90 | scanner := bufio.NewScanner(r) 91 | 92 | for scanner.Scan() { 93 | line := scanner.Text() 94 | line = strings.TrimLeft(line, " ") 95 | 96 | // 空行 or 注释行 97 | if len(line) == 0 || (len(line) > 0 && line[0] == '#') { 98 | continue 99 | } 100 | 101 | // 解析行 例如: Bind 127.0.0.1 102 | idx := strings.IndexAny(line, " ") 103 | if idx > 0 && idx < len(line)-1 { 104 | key := line[:idx] 105 | value := strings.Trim(line[idx+1:], " ") 106 | // 将每行的结果,保存到lineMap中 107 | lineMap[strings.ToLower(key)] = value 108 | } 109 | } 110 | 111 | if err := scanner.Err(); err != nil { 112 | panic(err.Error()) 113 | } 114 | 115 | //2.将扫描结果保存到newRedisConfig 对象中 116 | 117 | configValue := reflect.ValueOf(newRedisConfig).Elem() 118 | configType := reflect.TypeOf(newRedisConfig).Elem() 119 | 120 | // 遍历结构体字段(类型) 121 | for i := 0; i < configType.NumField(); i++ { 122 | 123 | fieldType := configType.Field(i) 124 | // 读取字段名 125 | fieldName := strings.Trim(fieldType.Tag.Get("conf"), " ") 126 | if fieldName == "" { 127 | fieldName = fieldType.Name 128 | } else { 129 | fieldName = strings.Split(fieldName, ",")[0] 130 | } 131 | fieldName = strings.ToLower(fieldName) 132 | // 判断该字段是否在config中有配置 133 | fieldValue, ok := lineMap[fieldName] 134 | 135 | if ok { 136 | // 将结果保存到字段中 137 | switch fieldType.Type.Kind() { 138 | case reflect.String: 139 | configValue.Field(i).SetString(fieldValue) 140 | case reflect.Bool: 141 | configValue.Field(i).SetBool("yes" == fieldValue) 142 | case reflect.Int: 143 | intValue, err := strconv.ParseInt(fieldValue, 10, 64) 144 | if err == nil { 145 | configValue.Field(i).SetInt(intValue) 146 | } 147 | case reflect.Slice: 148 | // 切片的元素是字符串 149 | if fieldType.Type.Elem().Kind() == reflect.String { 150 | tmpSlice := strings.Split(fieldValue, ",") 151 | configValue.Field(i).Set(reflect.ValueOf(tmpSlice)) 152 | } 153 | } 154 | } 155 | } 156 | return newRedisConfig 157 | } 158 | 159 | func TmpDir() string { 160 | dir := GlobalConfig.Dir + "/tmp" 161 | utils.MakeDir(dir) 162 | return dir 163 | } 164 | -------------------------------------------------------------------------------- /tool/conf/config_test.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestLoad(t *testing.T) { 10 | 11 | err := LoadConfig("/Users/mac/source/easyredis/test.conf") 12 | 13 | assert.Equal(t, nil, err) 14 | 15 | assert.Equal(t, "127.0.0.1", GlobalConfig.Bind) 16 | assert.Equal(t, 3000, GlobalConfig.Port) 17 | assert.Equal(t, "/opt/data", GlobalConfig.Dir) 18 | assert.Equal(t, true, GlobalConfig.AppendOnly) 19 | assert.Equal(t, "", GlobalConfig.AppendFilename) 20 | assert.Equal(t, "", GlobalConfig.AppendFsync) 21 | assert.Equal(t, "20231224", GlobalConfig.RequirePass) 22 | assert.Equal(t, []string{"192.168.1.10", "192.168.1.10"}, GlobalConfig.Peers) 23 | 24 | t.Log(GlobalConfig.RunID) 25 | 26 | } 27 | 28 | 29 | -------------------------------------------------------------------------------- /tool/consistenthash/consistenthash.go: -------------------------------------------------------------------------------- 1 | package consistenthash 2 | 3 | import ( 4 | "hash/crc32" 5 | "sort" 6 | "strconv" 7 | "strings" 8 | ) 9 | 10 | type HashFunc func(data []byte) uint32 11 | 12 | type Map struct { 13 | hashFunc HashFunc // 计算hash函数 14 | replicas int // 每个节点的虚拟节点数量 15 | hashValue []int // hash值 16 | hashMap map[int]string // hash值映射的真实节点 17 | } 18 | 19 | /* 20 | replicas:副本数量 21 | fn:hash函数 22 | */ 23 | func New(replicas int, fn HashFunc) *Map { 24 | m := &Map{ 25 | replicas: replicas, 26 | hashFunc: fn, 27 | hashMap: make(map[int]string), 28 | } 29 | if m.hashFunc == nil { 30 | m.hashFunc = crc32.ChecksumIEEE 31 | } 32 | return m 33 | } 34 | 35 | func (m *Map) IsEmpty() bool { 36 | return len(m.hashValue) == 0 37 | } 38 | 39 | // 添加 节点 40 | func (m *Map) Add(ipAddrs ...string) { 41 | for _, ipAddr := range ipAddrs { 42 | if ipAddr == "" { 43 | continue 44 | } 45 | // 每个ipAddr 生成 m.replicas个哈希值副本 46 | for i := 0; i < m.replicas; i++ { 47 | hash := int(m.hashFunc([]byte(strconv.Itoa(i) + ipAddr))) 48 | // 记录hash值 49 | m.hashValue = append(m.hashValue, hash) 50 | // 映射hash为同一个ipAddr 51 | m.hashMap[hash] = ipAddr 52 | } 53 | } 54 | sort.Ints(m.hashValue) 55 | } 56 | 57 | // support hash tag example :{key} 58 | func getPartitionKey(key string) string { 59 | beg := strings.Index(key, "{") 60 | if beg == -1 { 61 | return key 62 | } 63 | end := strings.Index(key, "}") 64 | if end == -1 || end == beg+1 { 65 | return key 66 | } 67 | return key[beg+1 : end] 68 | } 69 | 70 | // Get gets the closest item in the hash to the provided key. 71 | func (m *Map) Get(key string) string { 72 | if m.IsEmpty() { 73 | return "" 74 | } 75 | 76 | partitionKey := getPartitionKey(key) 77 | hash := int(m.hashFunc([]byte(partitionKey))) 78 | 79 | // 查找 m.keys中第一个大于or等于hash值的元素索引 80 | idx := sort.Search(len(m.hashValue), func(i int) bool { return m.hashValue[i] >= hash }) // 81 | 82 | // 表示找了一圈没有找到大于or等于hash值的元素,那么默认是第0号元素 83 | if idx == len(m.hashValue) { 84 | idx = 0 85 | } 86 | 87 | // 返回 key应该存储的ipAddr 88 | return m.hashMap[m.hashValue[idx]] 89 | } 90 | -------------------------------------------------------------------------------- /tool/consistenthash/consistenthash_test.go: -------------------------------------------------------------------------------- 1 | package consistenthash 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/gofish2020/easyredis/utils" 7 | ) 8 | 9 | func TestConsistentHash(t *testing.T) { 10 | 11 | hashMap := New(100, nil) // replicas越大,分布越均匀 12 | 13 | ipAddrs := []string{"127.0.0.1:7379", "127.0.0.1:8379", "127.0.0.1:6379"} 14 | hashMap.Add(ipAddrs...) 15 | 16 | constains := make(map[string]int) 17 | for i := 0; i < 100; i++ { 18 | value := hashMap.Get(utils.RandString(10)) 19 | constains[value]++ 20 | } 21 | 22 | for k, v := range constains { 23 | t.Log(k, v) 24 | } 25 | 26 | } 27 | -------------------------------------------------------------------------------- /tool/idgenerator/snowflake.go: -------------------------------------------------------------------------------- 1 | package idgenerator 2 | 3 | import ( 4 | "hash/fnv" 5 | "log" 6 | "sync" 7 | "time" 8 | ) 9 | 10 | const ( 11 | // epoch0 is set to the twitter snowflake epoch of Nov 04 2010 01:42:54 UTC in milliseconds 12 | // You may customize this to set a different epoch for your application. 13 | epoch0 int64 = 1288834974657 14 | maxSequence int64 = -1 ^ (-1 << uint64(nodeLeft)) 15 | timeLeft uint8 = 22 16 | nodeLeft uint8 = 10 17 | nodeMask int64 = -1 ^ (-1 << uint64(timeLeft-nodeLeft)) 18 | ) 19 | 20 | // IDGenerator generates unique uint64 ID using snowflake algorithm 21 | type IDGenerator struct { 22 | mu *sync.Mutex 23 | lastStamp int64 24 | nodeID int64 25 | sequence int64 26 | epoch time.Time 27 | } 28 | 29 | // MakeGenerator creates a new IDGenerator 30 | func MakeGenerator(node string) *IDGenerator { 31 | fnv64 := fnv.New64() 32 | _, _ = fnv64.Write([]byte(node)) 33 | nodeID := int64(fnv64.Sum64()) & nodeMask 34 | 35 | var curTime = time.Now() 36 | epoch := curTime.Add(time.Unix(epoch0/1000, (epoch0%1000)*1000000).Sub(curTime)) 37 | 38 | return &IDGenerator{ 39 | mu: &sync.Mutex{}, 40 | lastStamp: -1, 41 | nodeID: nodeID, 42 | sequence: 1, 43 | epoch: epoch, 44 | } 45 | } 46 | 47 | // NextID returns next unique ID 48 | func (w *IDGenerator) NextID() int64 { 49 | w.mu.Lock() 50 | defer w.mu.Unlock() 51 | 52 | timestamp := time.Since(w.epoch).Nanoseconds() / 1000000 53 | if timestamp < w.lastStamp { 54 | log.Fatal("can not generate id") 55 | } 56 | if w.lastStamp == timestamp { 57 | w.sequence = (w.sequence + 1) & maxSequence 58 | if w.sequence == 0 { 59 | for timestamp <= w.lastStamp { 60 | timestamp = time.Since(w.epoch).Nanoseconds() / 1000000 61 | } 62 | } 63 | } else { 64 | w.sequence = 0 65 | } 66 | w.lastStamp = timestamp 67 | id := (timestamp << timeLeft) | (w.nodeID << nodeLeft) | w.sequence 68 | //fmt.Printf("%d %d %d\n", timestamp, w.sequence, id) 69 | return id 70 | } 71 | -------------------------------------------------------------------------------- /tool/idgenerator/snowflake_test.go: -------------------------------------------------------------------------------- 1 | package idgenerator 2 | 3 | import "testing" 4 | 5 | func TestMGenerator(t *testing.T) { 6 | gen := MakeGenerator("a") 7 | ids := make(map[int64]struct{}) 8 | size := int(1e6) 9 | for i := 0; i < size; i++ { 10 | id := gen.NextID() 11 | _, ok := ids[id] 12 | if ok { 13 | t.Errorf("duplicated id: %d, time: %d, seq: %d", id, gen.lastStamp, gen.sequence) 14 | } 15 | ids[id] = struct{}{} 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /tool/locker/locker.go: -------------------------------------------------------------------------------- 1 | package locker 2 | 3 | import ( 4 | "sort" 5 | "sync" 6 | 7 | "github.com/gofish2020/easyredis/utils" 8 | ) 9 | 10 | type Locker struct { 11 | mu []*sync.RWMutex 12 | 13 | mask uint32 14 | } 15 | 16 | func NewLocker(count int) *Locker { 17 | l := &Locker{} 18 | 19 | count = utils.ComputeCapacity(count) 20 | l.mask = uint32(count) - 1 21 | l.mu = make([]*sync.RWMutex, count) 22 | for i := 0; i < count; i++ { 23 | l.mu[i] = &sync.RWMutex{} 24 | } 25 | return l 26 | } 27 | 28 | // 顺序加锁(互斥) 29 | func (l *Locker) Locks(keys ...string) { 30 | indexs := l.toLockIndex(keys...) 31 | for _, index := range indexs { 32 | mu := l.mu[index] 33 | mu.Lock() 34 | } 35 | } 36 | 37 | // 顺序解锁(互斥) 38 | func (l *Locker) Unlocks(keys ...string) { 39 | indexs := l.toLockIndex(keys...) 40 | for _, index := range indexs { 41 | mu := l.mu[index] 42 | mu.Unlock() 43 | } 44 | } 45 | 46 | // 顺序加锁(只读) 47 | func (l *Locker) RLocks(keys ...string) { 48 | indexs := l.toLockIndex(keys...) 49 | for _, index := range indexs { 50 | mu := l.mu[index] 51 | mu.RLock() 52 | } 53 | } 54 | 55 | // 顺序解锁(只读) 56 | func (l *Locker) RUnlocks(keys ...string) { 57 | indexs := l.toLockIndex(keys...) 58 | for _, index := range indexs { 59 | mu := l.mu[index] 60 | mu.RUnlock() 61 | } 62 | } 63 | 64 | // 顺序加锁 wkeys和rkeys可以有重叠 65 | func (l *Locker) RWLocks(wkeys []string, rkeys []string) { 66 | 67 | // 所有的key的索引 (内部会去重) 68 | allKeys := append(wkeys, rkeys...) 69 | allIndexs := l.toLockIndex(allKeys...) 70 | 71 | // 只写key的索引 72 | wMapIndex := make(map[uint32]struct{}) 73 | for _, key := range wkeys { 74 | wMapIndex[l.spread(utils.Fnv32(key))] = struct{}{} 75 | } 76 | 77 | for _, index := range allIndexs { 78 | mu := l.mu[index] 79 | 80 | if _, ok := wMapIndex[index]; ok { // 索引是写 81 | mu.Lock() // 加互斥锁 82 | } else { 83 | mu.RLock() // 加只读锁 84 | } 85 | } 86 | } 87 | 88 | // 顺序解锁 wkeys和rkeys可以有重叠 89 | func (l *Locker) RWUnlocks(wkeys []string, rkeys []string) { 90 | // 所有的key的索引 (内部会去重) 91 | allKeys := append(wkeys, rkeys...) 92 | allIndexs := l.toLockIndex(allKeys...) 93 | 94 | // 只写key的索引 95 | wMapIndex := make(map[uint32]struct{}) 96 | for _, key := range wkeys { 97 | wMapIndex[l.spread(utils.Fnv32(key))] = struct{}{} 98 | } 99 | 100 | for _, index := range allIndexs { 101 | mu := l.mu[index] 102 | 103 | if _, ok := wMapIndex[index]; ok { // 索引是写 104 | mu.Unlock() // 解锁 105 | } else { 106 | mu.RUnlock() // 解锁 107 | } 108 | } 109 | } 110 | 111 | func (l *Locker) spread(hashcode uint32) uint32 { 112 | return hashcode & l.mask 113 | } 114 | 115 | func (l *Locker) toLockIndex(keys ...string) []uint32 { 116 | 117 | // 将key转成 切片索引[0,mask] 118 | mapIndex := make(map[uint32]struct{}) // 去重 119 | for _, key := range keys { 120 | mapIndex[l.spread(utils.Fnv32(key))] = struct{}{} 121 | } 122 | 123 | indices := make([]uint32, 0, len(mapIndex)) 124 | for k := range mapIndex { 125 | indices = append(indices, k) 126 | } 127 | // 对索引排序 128 | sort.Slice(indices, func(i, j int) bool { 129 | return indices[i] < indices[j] 130 | }) 131 | return indices 132 | } 133 | -------------------------------------------------------------------------------- /tool/locker/locker_test.go: -------------------------------------------------------------------------------- 1 | package locker 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/gofish2020/easyredis/tool/logger" 8 | ) 9 | 10 | var locker *Locker 11 | 12 | func TestMain(m *testing.M) { 13 | 14 | locker = NewLocker(8) 15 | m.Run() 16 | } 17 | 18 | func Test2LockIndex(t *testing.T) { 19 | 20 | keys := []string{"1", "2", "3"} 21 | logger.Debug(locker.toLockIndex(keys...)) 22 | 23 | locker.Locks(keys...) 24 | locker.Unlocks(keys...) 25 | 26 | locker.RLocks(keys...) 27 | locker.RUnlocks(keys...) 28 | time.Sleep(3 * time.Second) 29 | } 30 | 31 | func TestRWLocker(t *testing.T) { 32 | 33 | wKeys := []string{"1", "2", "3"} 34 | logger.Debug(locker.toLockIndex(wKeys...)) 35 | rKeys := []string{"3", "4", "5"} 36 | logger.Debug(locker.toLockIndex(rKeys...)) 37 | 38 | locker.RWLocks(wKeys, rKeys) 39 | locker.RWUnlocks(wKeys, rKeys) 40 | 41 | time.Sleep(1 * time.Second) 42 | } 43 | -------------------------------------------------------------------------------- /tool/logger/logger.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "os" 7 | "path" 8 | "runtime" 9 | "sync" 10 | "time" 11 | 12 | "github.com/gofish2020/easyredis/utils" 13 | ) 14 | 15 | /* 16 | purpose: 日志库 17 | */ 18 | 19 | const ( 20 | maxLogMessageNum = 1e5 21 | callerDepth = 2 22 | 23 | Reset = "\033[0m" 24 | Red = "\033[31m" 25 | Green = "\033[32m" 26 | Blue = "\033[34m" 27 | Yellow = "\033[33m" 28 | ) 29 | 30 | // config for logger example: redis-20231225.log 31 | type Settings struct { 32 | Path string `yaml:"path"` // 路径 33 | Name string `yaml:"name"` // 文件名 34 | Ext string `yaml:"ext"` // 文件后缀 35 | DateFormat string `yaml:"date-format"` // 日期格式 36 | } 37 | 38 | // 日志级别 39 | type LogLevel int 40 | 41 | const ( 42 | NULL LogLevel = iota 43 | 44 | FATAL 45 | ERROR 46 | WARN 47 | INFO 48 | DEBUG 49 | ) 50 | 51 | var levelFlags = []string{"", "Fatal", "Error", "Warn", "Info", "Debug"} 52 | 53 | // 日志消息 54 | type logMessage struct { 55 | level LogLevel 56 | msg string 57 | } 58 | 59 | func (m *logMessage) reset() { 60 | m.level = NULL 61 | m.msg = "" 62 | } 63 | 64 | // 日志底层操作对象 65 | type logger struct { 66 | logFile *os.File 67 | logStd *log.Logger 68 | logMsgChan chan *logMessage 69 | logMsgPool *sync.Pool 70 | logLevel LogLevel 71 | close chan struct{} 72 | } 73 | 74 | func (l *logger) Close() { 75 | close(l.close) 76 | } 77 | func (l *logger) writeLog(level LogLevel, callerDepth int, msg string) { 78 | var formattedMsg string 79 | _, file, line, ok := runtime.Caller(callerDepth) 80 | if ok { 81 | formattedMsg = fmt.Sprintf("[%s][%s:%d] %s", levelFlags[level], file, line, msg) 82 | } else { 83 | formattedMsg = fmt.Sprintf("[%s] %s", levelFlags[level], msg) 84 | } 85 | 86 | // 对象池,复用*logMessage对象 87 | logMsg := l.logMsgPool.Get().(*logMessage) 88 | logMsg.level = level 89 | logMsg.msg = formattedMsg 90 | // 保存到chan缓冲中 91 | l.logMsgChan <- logMsg 92 | } 93 | 94 | var defaultLogger *logger = newStdLogger() 95 | 96 | // 构造标准输出日志对象 97 | func newStdLogger() *logger { 98 | 99 | stdLogger := &logger{ 100 | logFile: nil, 101 | logStd: log.New(os.Stdout, "", log.LstdFlags), 102 | logMsgChan: make(chan *logMessage, maxLogMessageNum), 103 | logLevel: DEBUG, 104 | close: make(chan struct{}), 105 | logMsgPool: &sync.Pool{ 106 | New: func() any { 107 | return &logMessage{} 108 | }, 109 | }, 110 | } 111 | 112 | go func() { 113 | // 从缓冲中读取数据 114 | for { 115 | select { 116 | case <-stdLogger.close: 117 | return 118 | case logMsg := <-stdLogger.logMsgChan: 119 | msg := logMsg.msg 120 | // 根据日志级别,增加不同的颜色 121 | switch logMsg.level { 122 | 123 | case DEBUG: 124 | msg = Blue + msg + Reset 125 | case INFO: 126 | msg = Green + msg + Reset 127 | case WARN: 128 | msg = Yellow + msg + Reset 129 | case ERROR, FATAL: 130 | msg = Red + msg + Reset 131 | } 132 | stdLogger.logStd.Output(0, msg) 133 | // 对象池,复用*logMessage对象 134 | logMsg.reset() 135 | stdLogger.logMsgPool.Put(logMsg) 136 | } 137 | } 138 | }() 139 | 140 | return stdLogger 141 | } 142 | 143 | // 生成输出到文件的日志对象 144 | func newFileLogger(settings *Settings) (*logger, error) { 145 | 146 | fileName := fmt.Sprintf("%s-%s.%s", settings.Name, time.Now().Format(settings.DateFormat), settings.Ext) 147 | 148 | fd, err := utils.OpenFile(fileName, settings.Path) 149 | if err != nil { 150 | return nil, fmt.Errorf("newFileLogger.OpenFile err: %s", err) 151 | } 152 | 153 | fileLogger := &logger{ 154 | logFile: fd, 155 | logStd: log.New(os.Stdout, "", log.LstdFlags), 156 | logMsgChan: make(chan *logMessage, maxLogMessageNum), 157 | logLevel: DEBUG, 158 | logMsgPool: &sync.Pool{ 159 | New: func() any { 160 | return &logMessage{} 161 | }, 162 | }, 163 | close: make(chan struct{}), 164 | } 165 | 166 | go func() { 167 | 168 | for { 169 | select { 170 | case <-fileLogger.close: 171 | return 172 | case logMsg := <-fileLogger.logMsgChan: 173 | //检查是否跨天,重新生成日志文件 174 | logFilename := fmt.Sprintf("%s-%s.%s", settings.Name, time.Now().Format(settings.DateFormat), settings.Ext) 175 | 176 | if path.Join(settings.Path, logFilename) != fileLogger.logFile.Name() { 177 | 178 | fd, err := utils.OpenFile(logFilename, settings.Path) 179 | if err != nil { 180 | panic("open log " + logFilename + " failed: " + err.Error()) 181 | } 182 | 183 | fileLogger.logFile.Close() 184 | fileLogger.logFile = fd 185 | } 186 | 187 | msg := logMsg.msg 188 | // 根据日志级别,增加不同的颜色 189 | switch logMsg.level { 190 | case DEBUG: 191 | msg = Blue + msg + Reset 192 | case INFO: 193 | msg = Green + msg + Reset 194 | case WARN: 195 | msg = Yellow + msg + Reset 196 | case ERROR, FATAL: 197 | msg = Red + msg + Reset 198 | } 199 | // 标准输出 200 | fileLogger.logStd.Output(0, msg) 201 | // 输出到文件 202 | fileLogger.logFile.WriteString(time.Now().Format(utils.DateTimeFormat) + " " + logMsg.msg + utils.CRLF) 203 | } 204 | } 205 | 206 | }() 207 | return fileLogger, nil 208 | } 209 | 210 | // 程序初始运行的时候调用 211 | func Setup(settings *Settings) { 212 | defaultLogger.Close() 213 | logger, err := newFileLogger(settings) 214 | if err != nil { 215 | panic(err) 216 | } 217 | defaultLogger = logger 218 | } 219 | 220 | // 设置日志级别 221 | func SetLoggerLevel(logLevel LogLevel) { 222 | defaultLogger.logLevel = logLevel 223 | } 224 | 225 | // ***********外部调用的日志函数*************** 226 | func Debug(v ...any) { 227 | if defaultLogger.logLevel >= DEBUG { 228 | msg := fmt.Sprint(v...) 229 | defaultLogger.writeLog(DEBUG, callerDepth, msg) 230 | } 231 | } 232 | func Debugf(format string, v ...any) { 233 | if defaultLogger.logLevel >= DEBUG { 234 | msg := fmt.Sprintf(format, v...) 235 | defaultLogger.writeLog(DEBUG, callerDepth, msg) 236 | } 237 | } 238 | 239 | func Info(v ...any) { 240 | if defaultLogger.logLevel >= INFO { 241 | msg := fmt.Sprint(v...) 242 | defaultLogger.writeLog(INFO, callerDepth, msg) 243 | } 244 | } 245 | 246 | func Infof(format string, v ...any) { 247 | if defaultLogger.logLevel >= INFO { 248 | msg := fmt.Sprintf(format, v...) 249 | defaultLogger.writeLog(INFO, callerDepth, msg) 250 | } 251 | } 252 | 253 | func Warn(v ...any) { 254 | if defaultLogger.logLevel >= WARN { 255 | msg := fmt.Sprint(v...) 256 | defaultLogger.writeLog(WARN, callerDepth, msg) 257 | } 258 | } 259 | 260 | func Warnf(format string, v ...any) { 261 | if defaultLogger.logLevel >= WARN { 262 | msg := fmt.Sprintf(format, v...) 263 | defaultLogger.writeLog(WARN, callerDepth, msg) 264 | } 265 | } 266 | 267 | func Error(v ...any) { 268 | if defaultLogger.logLevel >= ERROR { 269 | msg := fmt.Sprint(v...) 270 | defaultLogger.writeLog(ERROR, callerDepth, msg) 271 | } 272 | } 273 | 274 | func Errorf(format string, v ...any) { 275 | if defaultLogger.logLevel >= ERROR { 276 | msg := fmt.Sprintf(format, v...) 277 | defaultLogger.writeLog(ERROR, callerDepth, msg) 278 | } 279 | } 280 | 281 | func Fatal(v ...any) { 282 | if defaultLogger.logLevel >= FATAL { 283 | msg := fmt.Sprint(v...) 284 | defaultLogger.writeLog(FATAL, callerDepth, msg) 285 | } 286 | } 287 | 288 | func Fatalf(format string, v ...any) { 289 | if defaultLogger.logLevel >= FATAL { 290 | msg := fmt.Sprintf(format, v...) 291 | defaultLogger.writeLog(FATAL, callerDepth, msg) 292 | } 293 | } 294 | -------------------------------------------------------------------------------- /tool/logger/logger_test.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/gofish2020/easyredis/utils" 8 | ) 9 | 10 | func TestStd(t *testing.T) { 11 | 12 | Debug("hello redis") 13 | Debugf("%s", "hello redis format") 14 | 15 | Info("hello redis") 16 | Infof("%s", "hello redis format") 17 | Error("hello redis") 18 | Errorf("%s", "hello redis format") 19 | Fatal("hello redis") 20 | Fatalf("%s", "hello redis format") 21 | Warn("hello redis") 22 | Warnf("%s", "hello redis format") 23 | 24 | SetLoggerLevel(INFO) // modify log level 25 | 26 | Debug("hello redis") // don't print 27 | Debugf("%s", "hello redis format") // don't print 28 | Info("hello redis") 29 | Infof("%s", "hello redis format") 30 | Error("hello redis") 31 | Errorf("%s", "hello redis format") 32 | Fatal("hello redis") 33 | Fatalf("%s", "hello redis format") 34 | Warn("hello redis") 35 | Warnf("%s", "hello redis format") 36 | time.Sleep(3 * time.Second) 37 | } 38 | 39 | func TestFile(t *testing.T) { 40 | Setup(&Settings{ 41 | Path: "logs", 42 | Name: "easyredis", 43 | Ext: "log", 44 | DateFormat: utils.DateFormat, 45 | }) 46 | 47 | SetLoggerLevel(ERROR) // modify log level 48 | Debug("hello redis") 49 | Debugf("%s", "hello redis format") 50 | Info("hello redis") 51 | Infof("%s", "hello redis format") 52 | Error("hello redis") 53 | Errorf("%s", "hello redis format") 54 | Fatal("hello redis") 55 | Fatalf("%s", "hello redis format") 56 | Warn("hello redis") 57 | Warnf("%s", "hello redis format") 58 | <-time.After(3 * time.Second) 59 | } 60 | -------------------------------------------------------------------------------- /tool/pool/pool.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import ( 4 | "errors" 5 | "sync" 6 | 7 | "github.com/gofish2020/easyredis/tool/logger" 8 | ) 9 | 10 | /* 11 | 对象池: 12 | 1.最多只能缓存 MaxIdles 个空闲对象 13 | 2.获取对象的时候,如果没有空闲对象,就创建新对象,对象池中最多只能创建 MaxActive 个对象 14 | */ 15 | 16 | var ( 17 | ErrClosed = errors.New("pool closed") 18 | ) 19 | 20 | type Config struct { 21 | MaxIdles int 22 | MaxActive int 23 | } 24 | 25 | type Pool struct { 26 | Config 27 | 28 | // 创建对象 29 | newObject func() (any, error) 30 | // 释放对象 31 | freeObject func(x any) 32 | 33 | // 空闲对象池 34 | idles chan any 35 | 36 | mu sync.Mutex 37 | activeCount int // 已经创建的对象个数 38 | waiting []chan any // 阻塞等待 39 | 40 | closed bool // 是否已关闭 41 | } 42 | 43 | func NewPool(new func() (any, error), free func(x any), conf Config) *Pool { 44 | 45 | if new == nil { 46 | logger.Error("NewPool argument new func is nil") 47 | return nil 48 | } 49 | 50 | if free == nil { 51 | free = func(x any) {} 52 | } 53 | 54 | p := Pool{ 55 | Config: conf, 56 | newObject: new, 57 | freeObject: free, 58 | activeCount: 0, 59 | closed: false, 60 | } 61 | p.idles = make(chan any, p.MaxIdles) 62 | return &p 63 | } 64 | 65 | func (p *Pool) Put(x any) { 66 | p.mu.Lock() 67 | if p.closed { 68 | p.mu.Unlock() 69 | p.freeObject(x) // 直接释放 70 | return 71 | } 72 | 73 | //1.先判断等待中 74 | if len(p.waiting) > 0 { 75 | // 弹出一个(从头部) 76 | wait := p.waiting[0] 77 | temp := make([]chan any, len(p.waiting)-1) 78 | copy(temp, p.waiting[1:]) 79 | p.waiting = temp 80 | wait <- x // 取消阻塞 81 | p.mu.Unlock() 82 | return 83 | 84 | } 85 | // 2.直接放回空闲缓冲 86 | select { 87 | case p.idles <- x: 88 | p.mu.Unlock() 89 | default: // 说明空闲已满 90 | p.activeCount-- // 对象个数-1 91 | p.mu.Unlock() 92 | p.freeObject(x) // 释放 93 | } 94 | 95 | } 96 | 97 | func (p *Pool) Get() (any, error) { 98 | p.mu.Lock() 99 | if p.closed { 100 | p.mu.Unlock() 101 | return nil, ErrClosed 102 | } 103 | select { 104 | case x := <-p.idles: // 从空闲中获取 105 | p.mu.Unlock() // 解锁 106 | return x, nil 107 | default: 108 | return p.getOne() // 获取一个新的 109 | } 110 | } 111 | 112 | func (p *Pool) getOne() (any, error) { 113 | 114 | // 说明已经创建了太多对象 115 | if p.activeCount >= p.Config.MaxActive { 116 | 117 | wait := make(chan any, 1) 118 | p.waiting = append(p.waiting, wait) 119 | p.mu.Unlock() 120 | // 阻塞等待 121 | x, ok := <-wait 122 | if !ok { 123 | return nil, ErrClosed 124 | } 125 | return x, nil 126 | } 127 | 128 | p.activeCount++ 129 | p.mu.Unlock() 130 | // 创建新对象 131 | x, err := p.newObject() 132 | if err != nil { 133 | p.mu.Lock() 134 | p.activeCount-- 135 | p.mu.Unlock() 136 | return nil, err 137 | } 138 | return x, nil 139 | } 140 | 141 | // 关闭对象池 142 | func (p *Pool) Close() { 143 | 144 | p.mu.Lock() 145 | if p.closed { 146 | p.mu.Unlock() 147 | return 148 | } 149 | p.closed = true 150 | close(p.idles) 151 | for _, wait := range p.waiting { 152 | close(wait) // 关闭等待的通道 153 | } 154 | p.waiting = nil 155 | p.mu.Unlock() 156 | 157 | // 释放空闲对象 158 | for x := range p.idles { 159 | p.freeObject(x) 160 | } 161 | } 162 | -------------------------------------------------------------------------------- /tool/pool/pool_test.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | type mockConn struct { 10 | open bool 11 | } 12 | 13 | func TestPool(t *testing.T) { 14 | connNum := 0 15 | factory := func() (interface{}, error) { 16 | connNum++ 17 | return &mockConn{ 18 | open: true, 19 | }, nil 20 | } 21 | finalizer := func(x interface{}) { 22 | connNum-- 23 | c := x.(*mockConn) 24 | c.open = false 25 | } 26 | cfg := Config{ 27 | MaxIdles: 20, 28 | MaxActive: 40, 29 | } 30 | pool := NewPool(factory, finalizer, cfg) 31 | var borrowed []*mockConn 32 | for i := 0; i < int(cfg.MaxActive); i++ { // Get 33 | x, err := pool.Get() 34 | if err != nil { 35 | t.Error(err) 36 | return 37 | } 38 | c := x.(*mockConn) 39 | if !c.open { 40 | t.Error("conn is not open") 41 | return 42 | } 43 | borrowed = append(borrowed, c) 44 | } 45 | for _, c := range borrowed { // Put 46 | pool.Put(c) 47 | } 48 | 49 | borrowed = nil 50 | // borrow returned 51 | for i := 0; i < int(cfg.MaxActive); i++ { // Get 52 | 53 | x, err := pool.Get() 54 | if err != nil { 55 | t.Error(err) 56 | return 57 | } 58 | c := x.(*mockConn) 59 | if !c.open { 60 | t.Error("conn is not open") 61 | return 62 | } 63 | borrowed = append(borrowed, c) 64 | } 65 | for i, c := range borrowed { // Put 66 | if i < len(borrowed)-1 { 67 | pool.Put(c) 68 | } 69 | } 70 | pool.Close() 71 | pool.Close() // test close twice 72 | pool.Put(borrowed[len(borrowed)-1]) 73 | if connNum != 0 { 74 | t.Errorf("%d connections has not closed", connNum) 75 | } 76 | _, err := pool.Get() 77 | if err != ErrClosed { 78 | t.Error("expect err closed") 79 | } 80 | } 81 | 82 | func TestPool_Waiting(t *testing.T) { 83 | factory := func() (interface{}, error) { 84 | return &mockConn{ 85 | open: true, 86 | }, nil 87 | } 88 | finalizer := func(x interface{}) { 89 | c := x.(*mockConn) 90 | c.open = false 91 | } 92 | cfg := Config{ 93 | MaxIdles: 2, 94 | MaxActive: 4, 95 | } 96 | pool := NewPool(factory, finalizer, cfg) 97 | var borrowed []*mockConn 98 | for i := 0; i < int(cfg.MaxActive); i++ { // Get 99 | x, err := pool.Get() 100 | if err != nil { 101 | t.Error(err) 102 | return 103 | } 104 | c := x.(*mockConn) 105 | if !c.open { 106 | t.Error("conn is not open") 107 | return 108 | } 109 | borrowed = append(borrowed, c) 110 | } 111 | getResult := make(chan bool) 112 | go func() { 113 | x, err := pool.Get() // 阻塞 114 | if err != nil { 115 | t.Error(err) 116 | getResult <- false 117 | return 118 | } 119 | c := x.(*mockConn) 120 | if !c.open { 121 | t.Error("conn is not open") 122 | getResult <- false 123 | return 124 | } 125 | getResult <- true 126 | }() 127 | time.Sleep(time.Second) 128 | pool.Put(borrowed[0]) // 放回一个 129 | if ret := <-getResult; !ret { 130 | t.Error("get and waiting returned failed") 131 | } 132 | } 133 | 134 | func TestPool_CreateErr(t *testing.T) { 135 | makeErr := true 136 | factory := func() (interface{}, error) { 137 | if makeErr { 138 | makeErr = false 139 | return nil, errors.New("mock err") 140 | } 141 | return &mockConn{ 142 | open: true, 143 | }, nil 144 | } 145 | finalizer := func(x interface{}) { 146 | c := x.(*mockConn) 147 | c.open = false 148 | } 149 | cfg := Config{ 150 | MaxIdles: 2, 151 | MaxActive: 4, 152 | } 153 | pool := NewPool(factory, finalizer, cfg) 154 | _, err := pool.Get() //第一次获取-错误 155 | if err == nil { 156 | t.Error("expecting err") 157 | return 158 | } 159 | x, err := pool.Get() // 第二次获取成功 160 | if err != nil { 161 | t.Error("get err") 162 | return 163 | } 164 | pool.Put(x) // 放回 165 | _, err = pool.Get() // 再获取回来 166 | if err != nil { 167 | t.Error("get err") 168 | return 169 | } 170 | 171 | } 172 | -------------------------------------------------------------------------------- /tool/timewheel/delay.go: -------------------------------------------------------------------------------- 1 | package timewheel 2 | 3 | import "time" 4 | 5 | type Delay struct { 6 | tw *TimeWheel 7 | } 8 | 9 | func NewDelay() *Delay { 10 | delay := &Delay{} 11 | delay.tw = New(1*time.Second, 3600) 12 | delay.tw.Start() 13 | return delay 14 | } 15 | 16 | // 添加延迟任务 绝对时间 17 | func (d *Delay) AddAt(expire time.Time, key string, callback func()) { 18 | interval := time.Until(expire) 19 | d.Add(interval, key, callback) 20 | } 21 | 22 | // 添加延迟任务 相对时间 23 | func (d *Delay) Add(interval time.Duration, key string, callback func()) { 24 | d.tw.Add(interval, key, callback) 25 | } 26 | 27 | // 取消延迟任务 28 | func (d *Delay) Cancel(key string) { 29 | d.tw.Cancel(key) 30 | } 31 | -------------------------------------------------------------------------------- /tool/timewheel/delay_test.go: -------------------------------------------------------------------------------- 1 | package timewheel 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/gofish2020/easyredis/tool/logger" 8 | ) 9 | 10 | func TestAdd(t *testing.T) { 11 | ch := make(chan time.Time) 12 | beginTime := time.Now() 13 | delay := NewDelay() 14 | delay.Add(time.Second, "", func() { 15 | logger.Debug("exec task...") 16 | ch <- time.Now() 17 | }) 18 | execAt := <-ch 19 | delayDuration := execAt.Sub(beginTime) 20 | // usually 1.0~2.0 s 21 | if delayDuration < time.Second || delayDuration > 3*time.Second { 22 | t.Error("wrong execute time") 23 | } 24 | } 25 | 26 | func TestAddTask(t *testing.T) { 27 | delay := NewDelay() 28 | delay.Add(0*time.Second, "test0", func() { 29 | logger.Info("0 time.Second running") 30 | time.Sleep(10 * time.Second) 31 | }) 32 | 33 | time.Sleep(1500 * time.Millisecond) 34 | 35 | delay.Add(9*time.Second, "testKey", func() { 36 | logger.Info("9 time.Second running") 37 | time.Sleep(5 * time.Second) 38 | }) 39 | 40 | time.Sleep(14 * time.Second) 41 | } 42 | 43 | func TestCancelTask(t *testing.T) { 44 | delay := NewDelay() 45 | delay.Add(0*time.Second, "test0", func() { 46 | logger.Info("0 time.Second running") 47 | time.Sleep(10 * time.Second) 48 | }) 49 | 50 | time.Sleep(1500 * time.Millisecond) 51 | 52 | delay.Add(9*time.Second, "testKey", func() { 53 | logger.Info("9 time.Second running") 54 | time.Sleep(5 * time.Second) 55 | }) 56 | 57 | delay.Cancel("testKey") 58 | time.Sleep(14 * time.Second) 59 | } 60 | -------------------------------------------------------------------------------- /tool/timewheel/timewheel.go: -------------------------------------------------------------------------------- 1 | package timewheel 2 | 3 | import ( 4 | "container/list" 5 | "time" 6 | 7 | "github.com/gofish2020/easyredis/tool/logger" 8 | ) 9 | 10 | // 记录:任务task位于循环队列的哪个(pos游标)链表上 11 | type taskPos struct { 12 | pos int 13 | ele *list.Element 14 | } 15 | 16 | type task struct { 17 | delay time.Duration 18 | key string 19 | circle int 20 | callback func() 21 | } 22 | 23 | // 循环队列 + 链表 24 | type TimeWheel struct { 25 | 26 | // 间隔 27 | interval time.Duration 28 | // 定时器 29 | ticker *time.Ticker 30 | 31 | // 游标 32 | curSlotPos int 33 | // 循环队列大小 34 | slotNum int 35 | // 底层存储 36 | slots []*list.List 37 | m map[string]*taskPos 38 | 39 | // 任务通道 40 | addChannel chan *task 41 | cacelChannel chan string 42 | // 停止 43 | stopChannel chan struct{} 44 | } 45 | 46 | func New(interval time.Duration, slotNum int) *TimeWheel { 47 | 48 | timeWheel := &TimeWheel{ 49 | ticker: nil, 50 | interval: interval, 51 | slots: make([]*list.List, slotNum), 52 | slotNum: slotNum, 53 | m: make(map[string]*taskPos), 54 | addChannel: make(chan *task), 55 | cacelChannel: make(chan string), 56 | stopChannel: make(chan struct{}), 57 | } 58 | 59 | for i := 0; i < slotNum; i++ { 60 | timeWheel.slots[i] = list.New() 61 | } 62 | return timeWheel 63 | } 64 | 65 | func (tw *TimeWheel) doTask() { 66 | 67 | for { 68 | select { 69 | case <-tw.ticker.C: 70 | tw.execTask() 71 | case t := <-tw.addChannel: 72 | tw.addTask(t) 73 | case key := <-tw.cacelChannel: 74 | tw.cancelTask(key) 75 | case <-tw.stopChannel: 76 | tw.ticker.Stop() 77 | return 78 | } 79 | } 80 | } 81 | 82 | func (tw *TimeWheel) execTask() { 83 | l := tw.slots[tw.curSlotPos] 84 | if tw.curSlotPos == tw.slotNum-1 { 85 | tw.curSlotPos = 0 86 | } else { 87 | tw.curSlotPos++ 88 | } 89 | go tw.scanList(l) 90 | } 91 | 92 | func (tw *TimeWheel) scanList(l *list.List) { 93 | 94 | for e := l.Front(); e != nil; { 95 | 96 | t := e.Value.(*task) 97 | // 任务不在当前圈执行 98 | if t.circle > 0 { 99 | t.circle-- 100 | continue 101 | } 102 | 103 | // 执行任务 104 | go func() { 105 | // 异常恢复 106 | defer func() { 107 | if err := recover(); err != nil { 108 | logger.Error(err) 109 | } 110 | }() 111 | // 协程中执行任务 112 | call := t.callback 113 | call() 114 | }() 115 | 116 | // 下一个记录 117 | next := e.Next() 118 | // 链表中删除 119 | l.Remove(e) 120 | // map中删除 121 | if t.key != "" { 122 | delete(tw.m, t.key) 123 | } 124 | e = next 125 | } 126 | } 127 | func (tw *TimeWheel) posAndCircle(d time.Duration) (pos, circle int) { 128 | 129 | // 延迟(秒) 130 | delaySecond := int(d.Seconds()) 131 | // 间隔(秒) 132 | intervalSecond := int(tw.interval.Seconds()) 133 | // delaySecond/intervalSecond 表示从curSlotPos位置偏移 134 | pos = (tw.curSlotPos + delaySecond/intervalSecond) % tw.slotNum 135 | circle = (delaySecond / intervalSecond) / tw.slotNum 136 | return 137 | } 138 | 139 | func (tw *TimeWheel) addTask(t *task) { 140 | 141 | // 定位任务应该保存在循环队列的位置 & 圈数 142 | pos, circle := tw.posAndCircle(t.delay) 143 | t.circle = circle 144 | 145 | // 将任务保存到循环队列pos位置 146 | ele := tw.slots[pos].PushBack(t) 147 | // 在map中记录 key -> { pos, ele } 的映射 148 | if t.key != "" { 149 | // 已经存在重复的key 150 | if _, ok := tw.m[t.key]; ok { 151 | tw.cancelTask(t.key) 152 | } 153 | tw.m[t.key] = &taskPos{pos: pos, ele: ele} 154 | } 155 | } 156 | 157 | func (tw *TimeWheel) cancelTask(key string) { 158 | taskPos, ok := tw.m[key] 159 | if !ok { 160 | return 161 | } 162 | // 从循环队列链表中删除任务 163 | tw.slots[taskPos.pos].Remove(taskPos.ele) 164 | // 从map中删除 165 | delete(tw.m, key) 166 | } 167 | 168 | /************ 外部调用 ************/ 169 | 170 | func (tw *TimeWheel) Start() { 171 | tw.ticker = time.NewTicker(tw.interval) 172 | go tw.doTask() 173 | } 174 | 175 | func (tw *TimeWheel) Stop() { 176 | tw.stopChannel <- struct{}{} 177 | } 178 | 179 | func (tw *TimeWheel) Add(delay time.Duration, key string, callback func()) { 180 | 181 | // 延迟时间 182 | if delay < 0 { 183 | return 184 | } 185 | // 新建任务 186 | t := task{ 187 | delay: delay, 188 | key: key, 189 | callback: callback, 190 | } 191 | // 发送到channel中 192 | tw.addChannel <- &t 193 | } 194 | 195 | func (tw *TimeWheel) Cancel(key string) { 196 | tw.cacelChannel <- key 197 | } 198 | -------------------------------------------------------------------------------- /tool/wait/wait.go: -------------------------------------------------------------------------------- 1 | package wait 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | ) 7 | 8 | /* 9 | 对系统 WaitGroup的封装 10 | */ 11 | type Wait struct { 12 | wait sync.WaitGroup 13 | } 14 | 15 | func (w *Wait) Add(delta int) { 16 | w.wait.Add(delta) 17 | } 18 | 19 | func (w *Wait) Done() { 20 | w.wait.Done() 21 | } 22 | 23 | func (w *Wait) Wait() { 24 | w.wait.Wait() 25 | } 26 | 27 | // 超时等待 28 | func (w *Wait) WaitWithTimeOut(timeout time.Duration) bool { 29 | 30 | ch := make(chan struct{}) 31 | go func() { 32 | defer close(ch) 33 | w.Wait() 34 | }() 35 | 36 | select { 37 | case <-ch: 38 | return false // 正常 39 | case <-time.After(timeout): 40 | return true // 超时 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /tool/wildcard/wildcard.go: -------------------------------------------------------------------------------- 1 | package wildcard 2 | 3 | import ( 4 | "errors" 5 | "regexp" 6 | "strings" 7 | ) 8 | 9 | /* 10 | 11 | 将redis支持的匹配模式,转换成go支持的正则表达式 12 | 13 | 14 | Supported glob-style patterns: 15 | 16 | h?llo matches hello, hallo and hxllo 17 | h*llo matches hllo and heeeello 18 | h[ae]llo matches hello and hallo, but not hillo 19 | h[^e]llo matches hallo, hbllo, ... but not hello 20 | h[a-b]llo matches hallo and hbllo 21 | 22 | 除了 ? * [] [^] [-] 这些字符,都作为正常的字符进行匹配(也就是要转义) 23 | 24 | */ 25 | 26 | type Pattern struct { 27 | exp *regexp.Regexp 28 | } 29 | 30 | var replaceMap = map[byte]string{ 31 | // characters in the wildcard that must be escaped in the regexp 32 | '+': `\+`, // 这些字符在正则表达式中是有意的,为了作为普通的字符匹配,需要 \ 转义为普通字符 33 | ')': `\)`, 34 | '$': `\$`, 35 | '.': `\.`, 36 | '{': `\{`, 37 | '}': `\}`, 38 | '|': `\|`, 39 | '*': ".*", // * 对应go的 .* 40 | '?': ".", // ? 对应go的 . 41 | } 42 | 43 | var errEndWithEscape = "end with escape \\" 44 | 45 | func CompilePattern(src string) (*Pattern, error) { 46 | regSrc := strings.Builder{} 47 | 48 | regSrc.WriteByte('^') // 正则表达式-开头 49 | 50 | for i := 0; i < len(src); i++ { 51 | ch := src[i] 52 | if ch == '\\' { 53 | if i == len(src)-1 { // 例如: a\ 54 | return nil, errors.New(errEndWithEscape) 55 | } 56 | // 例如 \b 57 | regSrc.WriteByte('\\') 58 | regSrc.WriteByte(src[i+1]) 59 | i++ // 将\b一次性全保存 60 | } else if ch == '^' { // redis中 [^ 是配套的一对 61 | if i == 0 { // 说明前面没有 [ 62 | regSrc.WriteString(`\^`) // 那^就转义为普通字符 \^ 63 | } else if i == 1 { 64 | if src[i-1] == '[' { 65 | regSrc.WriteString(`^`) // src is: [^ 66 | } else { 67 | regSrc.WriteString(`\^`) // example: a^ 68 | } 69 | } else { 70 | if src[i-1] == '[' && src[i-2] != '\\' { 71 | regSrc.WriteString(`^`) // example: a[^ 72 | } else { 73 | regSrc.WriteString(`\^`) // example: \[^ 74 | } 75 | } 76 | } else if escaped, toEscape := replaceMap[ch]; toEscape { 77 | regSrc.WriteString(escaped) 78 | } else { 79 | regSrc.WriteByte(ch) 80 | } 81 | } 82 | regSrc.WriteByte('$') // 正则表达式-结尾 83 | 84 | re, err := regexp.Compile(regSrc.String()) 85 | if err != nil { 86 | return nil, err 87 | } 88 | 89 | return &Pattern{ 90 | exp: re, 91 | }, nil 92 | } 93 | 94 | func (p *Pattern) IsMatch(s string) bool { 95 | return p.exp.MatchString(s) 96 | } 97 | -------------------------------------------------------------------------------- /tool/wildcard/wildcard_test.go: -------------------------------------------------------------------------------- 1 | package wildcard 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | ) 7 | 8 | func TestWildCard(t *testing.T) { 9 | p, err := CompilePattern("") 10 | if err != nil { 11 | t.Error(err) 12 | return 13 | } 14 | if !p.IsMatch("") { 15 | t.Error("expect true actually false") 16 | } 17 | p, err = CompilePattern("a") 18 | if err != nil { 19 | t.Error(err) 20 | return 21 | } 22 | if !p.IsMatch("a") { 23 | t.Error("expect true actually false") 24 | } 25 | if p.IsMatch("b") { 26 | t.Error("expect false actually true") 27 | } 28 | 29 | // test '?' 30 | p, err = CompilePattern("a?") 31 | if err != nil { 32 | t.Error(err) 33 | return 34 | } 35 | if !p.IsMatch("ab") { 36 | t.Error("expect true actually false") 37 | } 38 | if p.IsMatch("a") { 39 | t.Error("expect false actually true") 40 | } 41 | if p.IsMatch("abb") { 42 | t.Error("expect false actually true") 43 | } 44 | if p.IsMatch("bb") { 45 | t.Error("expect false actually true") 46 | } 47 | 48 | // test * 49 | p, err = CompilePattern("a*") 50 | if err != nil { 51 | t.Error(err) 52 | return 53 | } 54 | if !p.IsMatch("ab") { 55 | t.Error("expect true actually false") 56 | } 57 | if !p.IsMatch("a") { 58 | t.Error("expect true actually false") 59 | } 60 | if !p.IsMatch("abb") { 61 | t.Error("expect true actually false") 62 | } 63 | if p.IsMatch("bb") { 64 | t.Error("expect false actually true") 65 | } 66 | 67 | // test [] 68 | p, err = CompilePattern("a[ab[]") 69 | if err != nil { 70 | t.Error(err) 71 | return 72 | } 73 | if !p.IsMatch("ab") { 74 | t.Error("expect true actually false") 75 | } 76 | if !p.IsMatch("aa") { 77 | t.Error("expect true actually false") 78 | } 79 | if !p.IsMatch("a[") { 80 | t.Error("expect true actually false") 81 | } 82 | if p.IsMatch("abb") { 83 | t.Error("expect false actually true") 84 | } 85 | if p.IsMatch("bb") { 86 | t.Error("expect false actually true") 87 | } 88 | 89 | // test [a-c] 90 | p, err = CompilePattern("h[a-c]llo") 91 | if err != nil { 92 | t.Error(err) 93 | return 94 | } 95 | if !p.IsMatch("hallo") { 96 | t.Error("expect true actually false") 97 | } 98 | if !p.IsMatch("hbllo") { 99 | t.Error("expect true actually false") 100 | } 101 | if !p.IsMatch("hcllo") { 102 | t.Error("expect true actually false") 103 | } 104 | if p.IsMatch("hdllo") { 105 | t.Error("expect false actually true") 106 | } 107 | if p.IsMatch("hello") { 108 | t.Error("expect false actually true") 109 | } 110 | 111 | //test [^] 112 | p, err = CompilePattern("h[^ab]llo") 113 | if err != nil { 114 | t.Error(err) 115 | return 116 | } 117 | if p.IsMatch("hallo") { 118 | t.Error("expect false actually true") 119 | } 120 | if p.IsMatch("hbllo") { 121 | t.Error("expect false actually true") 122 | } 123 | if !p.IsMatch("hcllo") { 124 | t.Error("expect true actually false") 125 | } 126 | 127 | p, err = CompilePattern("[^ab]c") 128 | if err != nil { 129 | t.Error(err) 130 | return 131 | } 132 | if p.IsMatch("abc") { 133 | t.Error("expect false actually true") 134 | } 135 | if !p.IsMatch("1c") { 136 | t.Error("expect true actually false") 137 | } 138 | 139 | p, err = CompilePattern("1^2") 140 | if err != nil { 141 | t.Error(err) 142 | return 143 | } 144 | if !p.IsMatch("1^2") { 145 | t.Error("expect true actually false") 146 | } 147 | 148 | p, err = CompilePattern(`\[^1]2`) 149 | if err != nil { 150 | t.Error(err) 151 | return 152 | } 153 | if !p.IsMatch("[^1]2") { 154 | t.Error("expect true actually false") 155 | } 156 | 157 | p, err = CompilePattern(`^1`) 158 | if err != nil { 159 | t.Error(err) 160 | return 161 | } 162 | if !p.IsMatch("^1") { 163 | t.Error("expect true actually false") 164 | } 165 | 166 | // test escape 167 | p, err = CompilePattern(`\\\\`) 168 | if err != nil { 169 | t.Error(err) 170 | return 171 | } 172 | if !p.IsMatch(`\\`) { 173 | t.Error("expect true actually false") 174 | } 175 | 176 | p, err = CompilePattern("\\*") 177 | if err != nil { 178 | t.Error(err) 179 | return 180 | } 181 | if !p.IsMatch("*") { 182 | t.Error("expect true actually false") 183 | } 184 | if p.IsMatch("a") { 185 | t.Error("expect false actually true") 186 | } 187 | 188 | p, err = CompilePattern(`\`) 189 | if err == nil || err.Error() != errEndWithEscape { 190 | t.Error(err) 191 | return 192 | } 193 | } 194 | 195 | func TestCompile(t *testing.T) { 196 | 197 | // CompilePattern("h[^ab]llo") 198 | // CompilePattern("a*") 199 | CompilePattern(`\[^`) 200 | time.Sleep(3 * time.Second) 201 | 202 | } 203 | -------------------------------------------------------------------------------- /utils/cmdline.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | func BuildCmdLine(commandName string, args ...[]byte) [][]byte { 4 | result := make([][]byte, len(args)+1) 5 | result[0] = []byte(commandName) 6 | for i, s := range args { 7 | result[i+1] = s 8 | } 9 | return result 10 | 11 | } 12 | -------------------------------------------------------------------------------- /utils/const.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | const ( 4 | Yellow = "\033[33m" 5 | Reset = "\033[0m" 6 | ) 7 | 8 | const ( 9 | DateTimeFormat = "2006-01-02 15:04:05" 10 | DateFormat = "2006-01-02" 11 | TimeFormat = "15:04:05" 12 | 13 | CRLF = "\r\n" 14 | ) 15 | -------------------------------------------------------------------------------- /utils/hash.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import "math" 4 | 5 | // 计算比param参数大,并满足是2的N次幂, 最近接近param的数值size 6 | func ComputeCapacity(param int) (size int) { 7 | if param <= 16 { 8 | return 16 9 | } 10 | n := param - 1 11 | n |= n >> 1 12 | n |= n >> 2 13 | n |= n >> 4 14 | n |= n >> 8 15 | n |= n >> 16 16 | if n < 0 { 17 | return math.MaxInt32 18 | } 19 | return n + 1 20 | } 21 | 22 | // 计算key的hashcode 23 | const prime32 = uint32(16777619) 24 | 25 | func Fnv32(key string) uint32 { 26 | hash := uint32(2166136261) 27 | for i := 0; i < len(key); i++ { 28 | hash *= prime32 29 | hash ^= uint32(key[i]) 30 | } 31 | return hash 32 | } 33 | -------------------------------------------------------------------------------- /utils/logo.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "math/rand" 5 | "time" 6 | ) 7 | 8 | // ascii艺术文字: https://tool.cccyun.cc/ascii_art 9 | 10 | var logos = []string{ 11 | 12 | ` 13 | ____ __ ____ _ _ ____ ____ ____ __ ____ 14 | ( __) / _\ / ___)( \/ )( _ \( __)( \( )/ ___) 15 | ) _) / \\___ \ ) / ) / ) _) ) D ( )( \___ \ 16 | (____)\_/\_/(____/(__/ (__\_)(____)(____/(__)(____/ 17 | `, 18 | 19 | ` 20 | 888 Y8P 21 | 888 22 | .d88b. 8888b. .d8888b 888 888 888d888 .d88b. .d88888 888 .d8888b 23 | d8P Y8b "88b 88K 888 888 888P" d8P Y8b d88" 888 888 88K 24 | 88888888 .d888888 "Y8888b. 888 888 888 88888888 888 888 888 "Y8888b. 25 | Y8b. 888 888 X88 Y88b 888 888 Y8b. Y88b 888 888 X88 26 | "Y8888 "Y888888 88888P' "Y88888 888 "Y8888 "Y88888 888 88888P' 27 | 888 28 | Y8b d88P 29 | "Y88P" 30 | `, 31 | ` 32 | @@@@@@@@ @@@@@@ @@@@@@ @@@ @@@ @@@@@@@ @@@@@@@@ @@@@@@@ @@@ @@@@@@ 33 | @@@@@@@@ @@@@@@@@ @@@@@@@ @@@ @@@ @@@@@@@@ @@@@@@@@ @@@@@@@@ @@@ @@@@@@@ 34 | @@! @@! @@@ !@@ @@! !@@ @@! @@@ @@! @@! @@@ @@! !@@ 35 | !@! !@! @!@ !@! !@! @!! !@! @!@ !@! !@! @!@ !@! !@! 36 | @!!!:! @!@!@!@! !!@@!! !@!@! @!@!!@! @!!!:! @!@ !@! !!@ !!@@!! 37 | !!!!!: !!!@!!!! !!@!!! @!!! !!@!@! !!!!!: !@! !!! !!! !!@!!! 38 | !!: !!: !!! !:! !!: !!: :!! !!: !!: !!! !!: !:! 39 | :!: :!: !:! !:! :!: :!: !:! :!: :!: !:! :!: !:! 40 | :: :::: :: ::: :::: :: :: :: ::: :: :::: :::: :: :: :::: :: 41 | : :: :: : : : :: : : : : : : : :: :: :: : : : :: : : `, 42 | } 43 | 44 | // 随机logo, just for fun,yoho!!! 45 | func Logo() string { 46 | rand.Seed(time.Now().UnixNano()) 47 | return Yellow + logos[rand.Intn(len(logos))] + Reset 48 | } 49 | -------------------------------------------------------------------------------- /utils/path.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path" 7 | 8 | "github.com/kardianos/osext" 9 | ) 10 | 11 | // ExecDir 当前可执行程序目录 12 | func ExecDir() string { 13 | 14 | path, err := osext.ExecutableFolder() 15 | if err != nil { 16 | return "" 17 | } 18 | return path 19 | } 20 | 21 | func FileExists(filename string) bool { 22 | info, err := os.Stat(filename) 23 | return err == nil && !info.IsDir() 24 | } 25 | 26 | // 打开文件(如果文件不存在自动创建) 27 | func OpenFile(fileName, dir string) (*os.File, error) { 28 | // 校验是否有该目录权限 29 | if checkPermission(dir) { 30 | return nil, fmt.Errorf("permission denied dir: %s", dir) 31 | } 32 | // 创建目录(目录存在啥也不做) 33 | if err := MakeDir(dir); err != nil { 34 | return nil, fmt.Errorf("error during make dir %s, err: %s", dir, err) 35 | } 36 | 37 | // 打开文件(不存在会自动创建),O_APPEND 追加写(权限 读/写) 38 | f, err := os.OpenFile(path.Join(dir, fileName), os.O_APPEND|os.O_CREATE|os.O_RDWR, 0644) 39 | if err != nil { 40 | return nil, fmt.Errorf("fail to open file, err: %s", err) 41 | } 42 | return f, nil 43 | } 44 | 45 | func checkPermission(src string) bool { 46 | _, err := os.Stat(src) 47 | return os.IsPermission(err) 48 | } 49 | 50 | func MakeDir(src string) error { 51 | return os.MkdirAll(src, os.ModePerm) 52 | } 53 | -------------------------------------------------------------------------------- /utils/rand.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "math/rand" 5 | "time" 6 | ) 7 | 8 | // 随机对象*rand 9 | var r = rand.New(rand.NewSource(time.Now().UnixNano())) 10 | 11 | var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") 12 | 13 | // 生成随机字符串 14 | func RandString(n int) string { 15 | result := make([]rune, n) 16 | for i := 0; i < len(result); i++ { 17 | // 从letters中随机获取一个字符,保存到result中 18 | result[i] = rune(letters[r.Intn(len(letters))]) 19 | } 20 | return string(result) 21 | } 22 | --------------------------------------------------------------------------------