├── README.md ├── bitcask ├── bitcask.go ├── filelock_unix.go ├── filelock_windows.go └── status.go ├── client ├── client.go └── main.go ├── clusters ├── node1 │ └── config.yaml ├── node2 │ └── config.yaml ├── node3 │ └── config.yaml ├── node4 │ └── config.yaml ├── node5 │ └── config.yaml ├── run.bat └── run.sh ├── config └── db.yaml ├── doc ├── Group.png ├── Having.png ├── Transaction.md ├── and.png ├── asofSelect.png ├── asofSelect2.png ├── cnf.png ├── column.png ├── columnConstraint.png ├── create.png ├── crossjoin.png ├── dataType.png ├── delete.png ├── drop.png ├── eq.png ├── explain.png ├── expr.png ├── func.png ├── innerjoin.png ├── insert.png ├── leftjoin.png ├── like.png ├── limit.png ├── noteq.png ├── or.png ├── rightjoin.png ├── select.png ├── sql.md ├── status.png ├── status2.png └── update.png ├── go.mod ├── go.sum ├── gobReg └── init.go ├── log ├── Leader.go ├── candidate.go ├── follower.go ├── log.go ├── message.go ├── node.go ├── server.go └── state.go ├── logger └── logger.go ├── main.go ├── server └── server.go ├── sourcegoyacc └── main.go ├── sql ├── catalog │ ├── aggregation.go │ ├── execinterface.go │ ├── interface.go │ ├── join.go │ ├── mutation.go │ ├── node.go │ ├── optimizer.go │ ├── planner.go │ ├── query.go │ ├── schema.go │ ├── schemaexec.go │ └── source.go ├── engine │ ├── kv.go │ └── raft.go ├── expr │ ├── expression.go │ └── fn.go └── value.go ├── sqlparser ├── ast │ ├── ast.go │ ├── expression.go │ ├── fn.go │ └── stmt.go ├── model │ └── model.go └── parser │ ├── charset.go │ ├── handle.go │ ├── lexer.go │ ├── misc.go │ ├── parser.go │ ├── parser.y │ ├── sourcegoyacc.exe │ ├── y.go │ └── y.output ├── storage ├── mvcc.go ├── mvccTransaction.go └── mvcckey.go └── util ├── conn.go └── util.go /README.md: -------------------------------------------------------------------------------- 1 | # CabbageDB 2 | 3 | 本项目是一个轻量级分布式 SQL 数据库(Go),专为学习数据库核心技术设计,核心模块自主实现,支持分布式容错、ACID 事务及标准 SQL 查询。 4 | 5 | ### **核心模块与功能** 6 | 7 | #### **1. 
存储引擎** 8 | 9 | - **Bitcask 日志合并存储**: 10 | - 数据以仅追加(append-only)方式持久化到磁盘,支持快速写入。 11 | - 启动时若无效数据(被覆盖的旧值、删除标记)占日志文件的比例达到阈值(`compact_threshold`),则自动合并压缩,减少磁盘占用。 12 | - 内存中维护 B 树索引(KeyDir,基于 `github.com/google/btree`),加速主键查询与范围扫描。 13 | 14 | #### **2. SQL 语法解析** 15 | 16 | - **解析器**: 17 | - **手写词法分析器(Lexer)**:逐字符解析 SQL 字符串,生成 Token 流(如识别关键字、标识符、运算符等)。 18 | - **基于 Goyacc 的语法分析器**:通过自定义文法规则(Yacc 规范),将 Token 流转换为抽象语法树(AST)。 19 | 20 | #### **3. SQL 查询与执行** 21 | 22 | [具体SQL语法](./doc/sql.md) 23 | 24 | - **SQL 兼容性**: 25 | - 支持标准 DDL/DML(`CREATE TABLE`, `INSERT`, `UPDATE`, `DELETE`, `SELECT`)。 26 | - 提供聚合函数(`COUNT`, `SUM`, `MIN`, `MAX`, `AVG`)、多表连接(`INNER JOIN`、`OUTER JOIN`)及排序分页(`ORDER BY`, `LIMIT`)。 27 | - **查询优化**: 28 | - 自动优化规则:谓词下推、常量折叠、索引查找优化、冗余节点清理、join连接优化。 29 | - **物理执行**: 30 | - 基础算子:索引扫描(`IndexScan`)、过滤(`Filter`)、哈希连接(`HashJoin`)、聚合(`Aggregate`)。 31 | 32 | #### **4. 分布式架构** 33 | 34 | - **Raft 多副本共识**: 35 | - 手动实现 Raft 协议,支持 Leader 选举与日志复制,保证多副本数据一致性。 36 | - 通过 Leader 节点提供线性一致性读写(强一致性)。 37 | - 持久化 Raft 状态机日志至 Bitcask 存储,支持节点崩溃恢复。 38 | 39 | #### **5. 事务与并发控制** 40 | 41 | [具体命令及案例](./doc/Transaction.md) 42 | 43 | - **隔离级别**: 44 | - 支持可重复读(Repeatable Read),确保事务内多次读取结果一致。 45 | - 读写事务支持写冲突检测。 46 | - **MVCC 多版本管理**: 47 | - 数据版本化存储,允许时间点快照查询(如历史数据回溯)。 48 | - 只读事务无需锁竞争,直接访问快照版本。 49 | 50 | #### **6. 网络通信** 51 | 52 | - **高并发通信框架**: 53 | - 基于 Go 的 `goroutine` 与 `channel` 实现异步消息处理。 54 | - 自定义二进制协议(Header + Body 格式),支持高效序列化与反序列化事件。 55 | - 事件类型:客户端 SQL 请求、Raft 日志同步、集群心跳检测。 56 | 57 | #### **7. 
扩展性设计** 58 | 59 | - **可插拔存储引擎**: 60 | - 默认集成 Bitcask,支持替换为内存存储(`BTreeMap`)及其他引擎适配不同场景。 61 | - **轻量级部署**: 62 | - 单节点模式(无需分布式依赖)与集群模式一键切换。 63 | 64 | ### 启动方式 65 | 66 | - 客户端 67 | 68 | ```go 69 | cd client 70 | go run ./main.go ./client.go 71 | ``` 72 | 73 | - 服务端单机 74 | 75 | ```go 76 | go run ./main.go --config ./config/db.yaml 77 | ``` 78 | 79 | - 服务端集群(WINDOWS) 80 | 81 | ```go 82 | ./clusters/run.bat 83 | ``` 84 | 85 | - 服务端集群(LINUX) 86 | 87 | ```go 88 | ./clusters/run.sh 89 | ``` 90 | 91 | ### 客户端命令 92 | 93 | - !tables 查看所有表 94 | - !table tablename 查看tablename表结构 95 | - !status 查看raft节点状态 96 | 97 | 98 | -------------------------------------------------------------------------------- /bitcask/bitcask.go: -------------------------------------------------------------------------------- 1 | package bitcask 2 | 3 | import ( 4 | "bytes" 5 | "cabbageDB/logger" 6 | "cabbageDB/util" 7 | "encoding/gob" 8 | "fmt" 9 | "github.com/google/btree" 10 | "io" 11 | "os" 12 | "path/filepath" 13 | ) 14 | 15 | // Pos代表在文件中的偏移量 16 | // Len表示Value的长度 17 | type ValueOffset struct { 18 | Pos uint64 19 | Len uint32 20 | } 21 | 22 | // KeyDir里面的一个元素,对应序列化到日志的数据结构Key->(ValuePos,ValueLen) 23 | type ByteItem struct { 24 | Key []byte 25 | Value *ValueOffset 26 | } 27 | 28 | // 将ByteItem转为原始Key->Value 29 | type ByteMap struct { 30 | Key []byte 31 | Value []byte 32 | } 33 | 34 | func (bi *ByteItem) Less(than btree.Item) bool { 35 | other := than.(*ByteItem) 36 | return bytes.Compare(bi.Key, other.Key) < 0 37 | } 38 | 39 | // BitCask的追加日志文件 40 | type Log struct { 41 | Path string 42 | File *os.File 43 | } 44 | 45 | // 打开一个日志文件,如果文件不存在则创建一个。在文件关闭之前对其持有排他锁,如果锁已被持有则返回错误 46 | func NewLog(path string) (*Log, error) { 47 | dir := filepath.Dir(path) 48 | if err := os.MkdirAll(dir, os.ModePerm); err != nil { 49 | return nil, fmt.Errorf("failed to create directory: %w", err) 50 | } 51 | 52 | file, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_RDWR, 0666) 53 | if err != nil { 54 | return nil, 
fmt.Errorf("failed to open/create file: %w", err) 55 | } 56 | 57 | if err = LockFileNonBlocking(file); err != nil { 58 | return nil, fmt.Errorf("get lockfile err: %w", err) 59 | } 60 | 61 | return &Log{ 62 | Path: path, 63 | File: file, 64 | }, nil 65 | } 66 | 67 | // 扫描日志文件创建keydir 68 | // 如果遇到从日志读取具体值时出错,可能是由不完整的写入引起,则截断 69 | // 如果遇见其他错误,则报panic,阻止继续运行 70 | func (log *Log) buildKeyDir() (*btree.BTree, error) { 71 | var lenBuf = make([]byte, 4) 72 | keyDir := btree.New(2) 73 | 74 | file := log.File 75 | 76 | fileInfo, err := file.Stat() 77 | if err != nil { 78 | return nil, err 79 | } 80 | fileLen := fileInfo.Size() 81 | 82 | pos, err := file.Seek(0, io.SeekStart) 83 | if err != nil { 84 | return nil, err 85 | } 86 | 87 | for pos < fileLen { 88 | 89 | _, err = file.Read(lenBuf) 90 | 91 | if err != nil { 92 | if err == io.EOF { 93 | os.Truncate(log.File.Name(), pos) 94 | return keyDir, nil 95 | } 96 | return nil, err 97 | } 98 | var keyLen uint32 99 | err = util.ByteToInt(lenBuf, &keyLen) 100 | if err != nil { 101 | return nil, err 102 | } 103 | 104 | // 读取valuelen 可能是-1 105 | _, err = file.Read(lenBuf) 106 | if err != nil { 107 | if err == io.EOF { 108 | os.Truncate(log.File.Name(), pos) 109 | return keyDir, nil 110 | } 111 | return nil, err 112 | } 113 | var valueLenOrTombstone int32 114 | err = util.ByteToInt(lenBuf, &valueLenOrTombstone) 115 | if err != nil { 116 | return nil, err 117 | } 118 | 119 | // 读取key值 120 | key := make([]byte, keyLen) 121 | _, err = file.Read(key) 122 | if err != nil { 123 | if err == io.EOF { 124 | os.Truncate(log.File.Name(), pos) 125 | return keyDir, nil 126 | } 127 | return nil, err 128 | } 129 | valuePos := pos + 4 + 4 + int64(keyLen) 130 | 131 | if valueLenOrTombstone > 0 { 132 | // 跳过value seek时如果大于文件长度并不会报错 133 | file.Seek(int64(valueLenOrTombstone), io.SeekCurrent) 134 | 135 | if valuePos+int64(valueLenOrTombstone) > fileLen { 136 | os.Truncate(log.File.Name(), pos) 137 | return keyDir, nil 138 | } 139 | } 140 | byteItem := 
&ByteItem{ 141 | Key: key, 142 | Value: &ValueOffset{ 143 | Pos: uint64(valuePos), 144 | }, 145 | } 146 | if valueLenOrTombstone > 0 { 147 | byteItem.Value.Len = uint32(valueLenOrTombstone) 148 | keyDir.ReplaceOrInsert(byteItem) 149 | pos = valuePos + int64(valueLenOrTombstone) 150 | } else { 151 | byteItem.Value.Len = 0 152 | keyDir.Delete(byteItem) 153 | pos = valuePos 154 | } 155 | 156 | } 157 | 158 | return keyDir, nil 159 | } 160 | 161 | // 从日志文件中读取value 162 | func (log *Log) ReadValue(valuePos uint64, valueLen uint32) (buffer []byte, err error) { 163 | buffer = make([]byte, valueLen) 164 | _, err = log.File.Seek(int64(valuePos), io.SeekStart) 165 | if err != nil { 166 | return buffer, err 167 | } 168 | _, err = log.File.Read(buffer) 169 | return buffer, err 170 | } 171 | 172 | // 将Key->Value值追加到日志文件中,Value为nil值表示墓碑条目。它返回条目的位置和长度。 173 | func (log *Log) writeEntry(key, value []byte) (uint64, uint32) { 174 | keyLen := uint32(len(key)) 175 | valueLen := uint32(len(value)) 176 | valueLenOrTombstone := int32(-1) 177 | if value != nil { 178 | valueLenOrTombstone = int32(len(value)) 179 | } 180 | itemLen := 4 + 4 + keyLen + valueLen 181 | file := log.File 182 | pos, _ := file.Seek(0, io.SeekEnd) 183 | 184 | keyLenByte := util.BinaryToByte(keyLen) 185 | file.Write(keyLenByte) 186 | valueLenByte := util.BinaryToByte(valueLenOrTombstone) 187 | file.Write(valueLenByte) 188 | 189 | file.Write(key) 190 | file.Write(value) 191 | file.Sync() 192 | 193 | return uint64(pos), itemLen 194 | } 195 | 196 | // 这是一个非常简化的BitCask 197 | // BitCask将键值对写入一个追加型的日志文件,并在内存中保持Key->(ValuePos,ValueLen)的映射关系 198 | // 删除一个键时,会在日志追加一条表示删除的特殊标记 199 | type BitCask struct { 200 | Log *Log 201 | KeyDir *btree.BTree 202 | } 203 | 204 | // 自定义序列化方法(仅保存 Log.Path) 205 | func (bc *BitCask) GobEncode() ([]byte, error) { 206 | // 定义临时结构体,仅包含需要序列化的字段 207 | tmp := struct { 208 | LogPath string // 只保留 Log 的 Path 字段 209 | }{ 210 | LogPath: bc.Log.Path, // 提取 Path 211 | } 212 | 213 | var buf bytes.Buffer 214 | 
encoder := gob.NewEncoder(&buf) 215 | if err := encoder.Encode(tmp); err != nil { 216 | return nil, err 217 | } 218 | return buf.Bytes(), nil 219 | } 220 | 221 | // 自定义反序列化方法(仅恢复 Log.Path) 222 | func (bc *BitCask) GobDecode(data []byte) error { 223 | // 定义临时结构体,仅读取 Log.Path 224 | tmp := struct { 225 | LogPath string 226 | }{} 227 | 228 | decoder := gob.NewDecoder(bytes.NewReader(data)) 229 | if err := decoder.Decode(&tmp); err != nil { 230 | return err 231 | } 232 | 233 | // 重建 Log 结构(其他字段如 File 初始化为零值) 234 | bc.Log = &Log{ 235 | Path: tmp.LogPath, 236 | File: nil, // File 不序列化,需后续手动初始化 237 | } 238 | 239 | // KeyDir 不反序列化,需后续手动重建 240 | bc.KeyDir = nil 241 | 242 | return nil 243 | } 244 | 245 | // 创建一个Bitcask,并自动压缩 246 | func NewCompact(path string, GarbageRatioThreshold float64) *BitCask { 247 | bitCask := NewBitCask(path) 248 | status := bitCask.Status() 249 | GarbageRatio := float64(status.GarbageDiskSize) / float64(status.TotalDiskSize) 250 | if status.GarbageDiskSize > 0 && GarbageRatio >= GarbageRatioThreshold { 251 | logger.Info("start compact") 252 | if err := bitCask.Compact(); err != nil { 253 | panic(err) 254 | } 255 | } 256 | return bitCask 257 | } 258 | 259 | // 打开或者创建一个BitCask 260 | // 读取到不完整条目时需截断 261 | func NewBitCask(path string) *BitCask { 262 | log, err := NewLog(path) 263 | if err != nil { 264 | panic(err) 265 | } 266 | 267 | keyDir, err := log.buildKeyDir() 268 | if err != nil { 269 | panic(err) 270 | } 271 | return &BitCask{ 272 | Log: log, 273 | KeyDir: keyDir, 274 | } 275 | } 276 | 277 | func (bitCask *BitCask) Compact() (err error) { 278 | 279 | type value struct { 280 | valueByte []byte 281 | valueLen uint32 282 | } 283 | 284 | type itemMap struct { 285 | key []byte 286 | value *value 287 | } 288 | 289 | tmpItemList := make([]*itemMap, 0, bitCask.KeyDir.Len()) 290 | 291 | bitCask.KeyDir.Ascend(func(i btree.Item) bool { 292 | item := i.(*ByteItem) 293 | valueByte, err1 := bitCask.Log.ReadValue(item.Value.Pos, item.Value.Len) 294 | if err1 != 
nil { 295 | return false 296 | } 297 | tmpItem := &itemMap{ 298 | key: item.Key, 299 | value: &value{ 300 | valueByte: valueByte, 301 | valueLen: item.Value.Len, 302 | }, 303 | } 304 | tmpItemList = append(tmpItemList, tmpItem) 305 | return true 306 | }) 307 | 308 | err = os.Truncate(bitCask.Log.File.Name(), 0) 309 | if err != nil { 310 | return err 311 | } 312 | _, err = bitCask.Log.File.Seek(0, io.SeekStart) 313 | if err != nil { 314 | return err 315 | } 316 | 317 | for _, tmpItem := range tmpItemList { 318 | pos, itemLen := bitCask.Log.writeEntry(tmpItem.key, tmpItem.value.valueByte) 319 | bitCask.KeyDir.ReplaceOrInsert(&ByteItem{ 320 | Key: tmpItem.key, 321 | Value: &ValueOffset{ 322 | Pos: pos + uint64(itemLen) - uint64(tmpItem.value.valueLen), 323 | Len: tmpItem.value.valueLen, 324 | }, 325 | }) 326 | } 327 | return err 328 | 329 | } 330 | 331 | func (bitCask *BitCask) Set(key, value []byte) { 332 | pos, itemLen := bitCask.Log.writeEntry(key, value) 333 | valueLen := uint32(len(value)) 334 | 335 | valuePos := pos + uint64(itemLen) - uint64(valueLen) 336 | bitCask.KeyDir.ReplaceOrInsert(&ByteItem{ 337 | Key: key, 338 | Value: &ValueOffset{ 339 | Pos: valuePos, 340 | Len: valueLen, 341 | }, 342 | }) 343 | } 344 | 345 | func (bitCask *BitCask) Get(key []byte) []byte { 346 | valueOffset := bitCask.KeyDir.Get(&ByteItem{Key: key}) 347 | var valueByte []byte 348 | if valueOffset != nil { 349 | byteItem := valueOffset.(*ByteItem) 350 | valueByte, _ = bitCask.Log.ReadValue(byteItem.Value.Pos, byteItem.Value.Len) 351 | } 352 | return valueByte 353 | } 354 | 355 | func (bitCask *BitCask) Delete(key []byte) { 356 | bitCask.KeyDir.Delete(&ByteItem{ 357 | Key: key, 358 | }) 359 | bitCask.Log.writeEntry(key, nil) 360 | } 361 | 362 | func (bitCask *BitCask) Scan(from, to []byte) []*ByteMap { 363 | 364 | byteMapList := []*ByteMap{} 365 | 366 | // 标志 避免一直循环 367 | needStop := false 368 | 369 | bitCask.KeyDir.Ascend(func(i btree.Item) bool { 370 | item := i.(*ByteItem) 371 | 372 
| // 如果to为nil 只需要判断key大于from即可 373 | // bytes.Compare(item.Key, from) item>from 是1 item= 0 && bytes.Compare(item.Key, to) <= 0 { 382 | byteMapList = append(byteMapList, &ByteMap{Key: item.Key, Value: value}) 383 | needStop = true 384 | } else { 385 | if needStop { 386 | return false 387 | } 388 | } 389 | 390 | return true 391 | }) 392 | return byteMapList 393 | } 394 | 395 | func (bitCask *BitCask) ScanPrefix(prefix []byte) []*ByteMap { 396 | to := make([]byte, len(prefix)) 397 | endSum := 3 398 | 399 | // 只有前缀 400 | if len(prefix) == 2 { 401 | endSum = 1 402 | } 403 | // 只有版本 Version 404 | if len(prefix) == 10 { 405 | endSum = 8 406 | } 407 | 408 | copy(to, prefix) 409 | flag := false 410 | 411 | for i := len(to) - endSum; i >= 0; i-- { 412 | if to[i] != 0xff { 413 | to[i]++ 414 | flag = true 415 | break 416 | } 417 | } 418 | if !flag { 419 | return bitCask.Scan(prefix, nil) 420 | } 421 | return bitCask.Scan(prefix, to) 422 | 423 | } 424 | 425 | func (bitCask *BitCask) Status() *Status { 426 | keys := uint64(bitCask.KeyDir.Len()) 427 | size := uint64(0) 428 | bitCask.KeyDir.Ascend(func(i btree.Item) bool { 429 | item := i.(*ByteItem) 430 | size = size + uint64(len(item.Key)) + uint64(item.Value.Len) 431 | return true 432 | }) 433 | stat, _ := bitCask.Log.File.Stat() 434 | totalDiskSize := uint64(stat.Size()) 435 | liveDiskSize := size + 8*keys 436 | garbageDiskSize := totalDiskSize - liveDiskSize 437 | return &Status{ 438 | Name: "bitcask", 439 | Keys: keys, 440 | Size: size, 441 | TotalDiskSize: totalDiskSize, 442 | GarbageDiskSize: garbageDiskSize, 443 | LiveDiskSize: liveDiskSize, 444 | FileName: bitCask.FileName(), 445 | } 446 | } 447 | 448 | func (bitCask *BitCask) FlushFile() { 449 | bitCask.Log.File.Sync() 450 | } 451 | 452 | func (bitCask *BitCask) FileName() string { 453 | path, _ := filepath.Abs(bitCask.Log.File.Name()) 454 | return path 455 | } 456 | 457 | func init() { 458 | gob.Register(&BitCask{}) 459 | } 460 | 
-------------------------------------------------------------------------------- /bitcask/filelock_unix.go: -------------------------------------------------------------------------------- 1 | //go:build darwin || dragonfly || freebsd || illumos || linux || netbsd || openbsd 2 | 3 | package bitcask 4 | 5 | import ( 6 | "errors" 7 | "os" 8 | "syscall" 9 | ) 10 | 11 | func LockFileNonBlocking(file *os.File) error { 12 | if err := syscall.Flock(int(file.Fd()), syscall.LOCK_EX|syscall.LOCK_NB); err != nil { 13 | return errors.New("file is already locked") 14 | } 15 | return nil 16 | } 17 | -------------------------------------------------------------------------------- /bitcask/filelock_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package bitcask 4 | 5 | import ( 6 | "errors" 7 | "golang.org/x/sys/windows" 8 | "os" 9 | ) 10 | 11 | func LockFileNonBlocking(file *os.File) error { 12 | 13 | // 使用 LOCKFILE_FAIL_IMMEDIATELY 和 LOCKFILE_EXCLUSIVE_LOCK 标志 14 | flags := windows.LOCKFILE_FAIL_IMMEDIATELY | windows.LOCKFILE_EXCLUSIVE_LOCK 15 | 16 | err := windows.LockFileEx(windows.Handle(file.Fd()), uint32(flags), 0, 1, 0, &windows.Overlapped{}) 17 | if err != nil { 18 | return errors.New("file is already locked") 19 | } 20 | 21 | return err 22 | } 23 | -------------------------------------------------------------------------------- /bitcask/status.go: -------------------------------------------------------------------------------- 1 | package bitcask 2 | 3 | type Status struct { 4 | Name string 5 | Keys uint64 6 | Size uint64 7 | TotalDiskSize uint64 8 | LiveDiskSize uint64 9 | GarbageDiskSize uint64 10 | FileName string 11 | } 12 | -------------------------------------------------------------------------------- /client/client.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "cabbageDB/server" 5 | "cabbageDB/sql/catalog" 6 | 
"cabbageDB/util" 7 | "errors" 8 | "fmt" 9 | "net" 10 | ) 11 | 12 | type Client struct { 13 | Conn net.Conn 14 | Txn *Txn 15 | } 16 | 17 | type Txn struct { 18 | TxnID uint64 19 | Status bool 20 | } 21 | 22 | func (c *Client) Call(request server.RaftRequest) server.RaftResponse { 23 | prefix := [2]byte{server.ClientPrefix} 24 | var err error 25 | switch v := request.(type) { 26 | case *server.Execute: 27 | prefix[1] = server.ExecutePrefix 28 | reqByte := util.BinaryStructToByte(v) 29 | err = util.SendPrefixMsg(c.Conn, prefix, reqByte) 30 | 31 | case *server.GetTable: 32 | prefix[1] = server.GetTablePrefix 33 | reqByte := []byte(v.Data) 34 | err = util.SendPrefixMsg(c.Conn, prefix, reqByte) 35 | 36 | case *server.ListTables: 37 | prefix[1] = server.ListTablesPrefix 38 | _, err = c.Conn.Write(prefix[:]) 39 | 40 | case *server.Status: 41 | prefix[1] = server.StatusPrefix 42 | _, err = c.Conn.Write(prefix[:]) 43 | } 44 | if err != nil { 45 | fmt.Println("client disconnected, restart client:", err.Error()) 46 | return nil 47 | } 48 | 49 | var respPrefix [2]byte 50 | _, _ = c.Conn.Read(respPrefix[:]) 51 | 52 | if respPrefix[0] != server.ClientPrefix { 53 | fmt.Println("protocol validation failed: invalid packet header") 54 | return nil 55 | } 56 | switch respPrefix[1] { 57 | case server.ExecutePrefix: 58 | resultSet := server.ExecuteResp{} 59 | respByte := util.ReceiveMsg(c.Conn) 60 | util.ByteToStruct(respByte, &resultSet) 61 | return &resultSet 62 | case server.ListTablesPrefix: 63 | resultSet := server.ListTables{} 64 | respByte := util.ReceiveMsg(c.Conn) 65 | util.ByteToStruct(respByte, &resultSet) 66 | return &resultSet 67 | case server.GetTablePrefix: 68 | resultSet := server.GetTableResp{} 69 | respByte := util.ReceiveMsg(c.Conn) 70 | util.ByteToStruct(respByte, &resultSet) 71 | return &resultSet 72 | case server.StatusPrefix: 73 | resultSet := server.Status{} 74 | respByte := util.ReceiveMsg(c.Conn) 75 | util.ByteToStruct(respByte, &resultSet) 76 | return 
&resultSet 77 | case server.RespErrPrefix: 78 | errResultSet := server.RespError{} 79 | respByte := util.ReceiveMsg(c.Conn) 80 | util.ByteToStruct(respByte, &errResultSet) 81 | return &errResultSet 82 | } 83 | return nil 84 | } 85 | 86 | func (c *Client) Execute(query string) (catalog.ResultSet, error) { 87 | resp := c.Call(&server.Execute{ 88 | Data: query, 89 | }) 90 | if resp == nil { 91 | return nil, errors.New("server is not responding") 92 | } 93 | errMsg, ok := resp.(*server.RespError) 94 | if ok { 95 | return nil, errors.New(errMsg.Errmsg) 96 | } 97 | 98 | resultSet := resp.(*server.ExecuteResp).Data 99 | 100 | switch v := resultSet.(type) { 101 | case *catalog.QueryResultSet: 102 | 103 | return &catalog.QueryResultSet{ 104 | Columns: v.Columns, 105 | Rows: v.Rows, 106 | }, nil 107 | case *catalog.BeginResultSet: 108 | c.Txn = &Txn{ 109 | TxnID: v.Version, 110 | Status: v.ReadOnly, 111 | } 112 | return resultSet, nil 113 | default: 114 | return resultSet, nil 115 | } 116 | } 117 | 118 | func (c *Client) GetTable(table string) (*catalog.Table, error) { 119 | result := c.Call(&server.GetTable{ 120 | Data: table, 121 | }) 122 | if v, ok := result.(*server.GetTableResp); ok { 123 | return v.Data, nil 124 | } 125 | 126 | errMsg, ok := result.(*server.RespError) 127 | if ok { 128 | return nil, errors.New(errMsg.Errmsg) 129 | } 130 | return nil, nil 131 | } 132 | 133 | func (c *Client) ListTables() []string { 134 | result := c.Call(&server.ListTables{}) 135 | if v, ok := result.(*server.ListTables); ok { 136 | return v.Data 137 | } 138 | return nil 139 | } 140 | 141 | func (c *Client) Status() *catalog.Status { 142 | result := c.Call(&server.Status{}) 143 | if v, ok := result.(*server.Status); ok { 144 | return v.Data 145 | } 146 | return nil 147 | } 148 | -------------------------------------------------------------------------------- /client/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import 
( 4 | "cabbageDB/gobReg" 5 | "cabbageDB/sql/catalog" 6 | "encoding/json" 7 | "errors" 8 | "flag" 9 | "fmt" 10 | "github.com/chzyer/readline" 11 | "log" 12 | "net" 13 | "os" 14 | "strings" 15 | "text/tabwriter" 16 | ) 17 | 18 | // 命令行参数解析 19 | type Options struct { 20 | Command string 21 | Host string 22 | Port uint 23 | } 24 | 25 | func ParseArgs() *Options { 26 | opts := &Options{} 27 | flag.StringVar(&opts.Host, "H", "127.0.0.1", "Host to connect to") 28 | flag.StringVar(&opts.Host, "host", "127.0.0.1", "Host to connect to") 29 | flag.UintVar(&opts.Port, "p", 9605, "Port number to connect to") 30 | flag.UintVar(&opts.Port, "port", 9605, "Port number to connect to") 31 | flag.Parse() 32 | 33 | if args := flag.Args(); len(args) > 0 { 34 | opts.Command = strings.Join(args, " ") 35 | } 36 | return opts 37 | } 38 | 39 | type SQLClient struct { 40 | Client *Client 41 | Editor *readline.Instance 42 | HistoryPath string 43 | ShowHeaders bool 44 | } 45 | 46 | func NewSQLClient(host string, port uint) (*SQLClient, error) { 47 | conn, err := net.Dial("tcp", (fmt.Sprintf("%s:%d", host, port))) 48 | client := Client{ 49 | Conn: conn, 50 | } 51 | if err != nil { 52 | return nil, err 53 | } 54 | 55 | home, _ := os.UserHomeDir() 56 | return &SQLClient{ 57 | Client: &client, 58 | HistoryPath: fmt.Sprintf("%s/.toysql_history", home), 59 | ShowHeaders: true, 60 | }, nil 61 | } 62 | 63 | // 执行命令入口 64 | func (c *SQLClient) Execute(input string) error { 65 | input = strings.TrimSpace(input) 66 | if input == "" { 67 | return nil 68 | } 69 | 70 | if strings.HasPrefix(input, "!") { 71 | return c.ExecuteCommand(input) 72 | } 73 | return c.ExecuteQuery(input) 74 | } 75 | 76 | // 处理元命令 77 | func (c *SQLClient) ExecuteCommand(input string) error { 78 | parts := strings.Fields(input) 79 | if len(parts) == 0 { 80 | return nil 81 | } 82 | 83 | cmd := parts[0] 84 | args := parts[1:] 85 | 86 | err1 := errors.New("Invalid return value") 87 | 88 | switch cmd { 89 | case "!headers": 90 | if 
len(args) != 1 { 91 | return fmt.Errorf("usage: !headers ") 92 | } 93 | if args[0] == "false" { 94 | c.ShowHeaders = false 95 | } else { 96 | c.ShowHeaders = true 97 | } 98 | //c.ShowHeaders = (args[0] == "on") 99 | fmt.Printf("Headers %s\n", args[0]) 100 | 101 | case "!help": 102 | fmt.Print(` 103 | Enter a SQL statement terminated by a semicolon (;) to execute it. 104 | Available commands: 105 | 106 | !help Show this help 107 | !status Show server status 108 | !table [name] Show table schema 109 | !tables List tables 110 | `) 111 | 112 | case "!status": 113 | status := c.Client.Status() 114 | if status == nil { 115 | return err1 116 | } 117 | statusJson, _ := json.Marshal(status) 118 | fmt.Println(string(statusJson)) 119 | 120 | case "!table": 121 | if len(args) != 1 { 122 | return fmt.Errorf("usage: !table ") 123 | } 124 | schema, err := c.Client.GetTable(args[0]) 125 | if err != nil { 126 | return err 127 | } 128 | if schema == nil { 129 | return nil 130 | } 131 | if schema.Name != "" { 132 | fmt.Println(schema.String()) 133 | } 134 | case "!tables": 135 | tables := c.Client.ListTables() 136 | if tables == nil { 137 | return nil 138 | } 139 | for _, table := range tables { 140 | fmt.Println(table) 141 | } 142 | 143 | default: 144 | return fmt.Errorf("unknown command: %s", cmd) 145 | } 146 | return nil 147 | } 148 | 149 | // 执行SQL查询 150 | func (c *SQLClient) ExecuteQuery(query string) error { 151 | result, err := c.Client.Execute(query) 152 | if err != nil { 153 | return err 154 | } 155 | 156 | switch res := result.(type) { 157 | case *catalog.BeginResultSet: 158 | if res.ReadOnly == false { 159 | fmt.Printf("Began read-write transaction at version %d\n", res.Version) 160 | } else { 161 | fmt.Printf("Began read-only transaction at version %d\n", res.Version) 162 | } 163 | case *catalog.CommitResultSet: 164 | fmt.Printf("Committed transaction %d\n", res.Version) 165 | case *catalog.RollbackResultSet: 166 | fmt.Printf("Rolled back transaction %d\n", res.Version) 
167 | case *catalog.CreateResultSet: 168 | fmt.Printf("Created %d rows\n", res.Count) 169 | case *catalog.DeleteResultSet: 170 | fmt.Printf("Delete %d rows\n", res.Count) 171 | case *catalog.UpdateResultSet: 172 | fmt.Printf("Update %d rows\n", res.Count) 173 | case *catalog.CreateTableResultSet: 174 | fmt.Printf("Created table %s\n", res.Name) 175 | case *catalog.DropTableResultSet: 176 | fmt.Printf("Dropped table %s\n", res.Name) 177 | case *catalog.ExplainResultSet: 178 | fmt.Printf("------Explain info------\n%s", res.NodeInfo) 179 | case *catalog.QueryResultSet: 180 | w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) 181 | defer w.Flush() 182 | 183 | if c.ShowHeaders { 184 | if len(res.Columns) != 0 { 185 | fmt.Fprintln(w, strings.Join(res.Columns, "\t")) 186 | } 187 | } 188 | for _, row := range res.Rows { 189 | rowStr := []string{} 190 | for _, value := range row { 191 | rowStr = append(rowStr, value.String()) 192 | } 193 | fmt.Fprintln(w, strings.Join(rowStr, "\t")) 194 | } 195 | 196 | default: 197 | jsonByte, _ := json.Marshal(res) 198 | fmt.Printf("%+v\n", string(jsonByte)) 199 | } 200 | return nil 201 | } 202 | 203 | // REPL交互循环 204 | func (c *SQLClient) Run() error { 205 | rl, err := readline.NewEx(&readline.Config{ 206 | HistoryFile: c.HistoryPath, 207 | AutoComplete: c.CreateCompleter(), 208 | InterruptPrompt: "^C", 209 | EOFPrompt: "exit", 210 | }) 211 | if err != nil { 212 | return err 213 | } 214 | defer rl.Close() 215 | c.Editor = rl 216 | 217 | // 加载历史记录 218 | if _, err := os.Stat(c.HistoryPath); err == nil { 219 | rl.SetHistoryPath(c.HistoryPath) 220 | } 221 | 222 | var ( 223 | multiLineBuffer strings.Builder 224 | prompt = "sql> " 225 | ) 226 | 227 | for { 228 | // 动态设置提示符 229 | rl.SetPrompt(prompt) 230 | 231 | input, err := rl.Readline() 232 | if err != nil { // 处理Ctrl+D/Ctrl+C 233 | break 234 | } 235 | 236 | line := strings.TrimSpace(input) 237 | if line == "" { 238 | continue 239 | } 240 | 241 | // 累积多行输入 242 | if multiLineBuffer.Len() > 0 
{ 243 | multiLineBuffer.WriteByte(' ') 244 | } 245 | multiLineBuffer.WriteString(line) 246 | 247 | // 检查语句结束 248 | if strings.HasSuffix(line, ";") { 249 | // 执行完整SQL语句 250 | query := strings.TrimSuffix(multiLineBuffer.String(), ";") 251 | if err := c.Execute(query); err != nil { 252 | fmt.Printf("Error: %v\n", err) 253 | } 254 | 255 | // 重置状态 256 | multiLineBuffer.Reset() 257 | prompt = "sql> " 258 | rl.SaveHistory(c.HistoryPath) 259 | } else { 260 | // 进入多行模式 261 | prompt = " -> " // 或者用 "...> " 等其他提示符 262 | } 263 | } 264 | return nil 265 | } 266 | 267 | // 创建自动补全器 268 | func (c *SQLClient) CreateCompleter() *readline.PrefixCompleter { 269 | return readline.NewPrefixCompleter( 270 | readline.PcItem("SELECT", 271 | readline.PcItem("FROM"), 272 | readline.PcItem("WHERE"), 273 | ), 274 | readline.PcItem("INSERT"), 275 | readline.PcItem("UPDATE"), 276 | readline.PcItem("BEGIN"), 277 | readline.PcItem("COMMIT"), 278 | readline.PcItem("ROLLBACK"), 279 | readline.PcItem("!help"), 280 | readline.PcItem("!exit"), 281 | ) 282 | } 283 | 284 | func main() { 285 | gobReg.GobRegMain() 286 | opts := ParseArgs() 287 | 288 | sqlClient, err := NewSQLClient(opts.Host, (opts.Port)) 289 | if err != nil { 290 | log.Fatal(err) 291 | } 292 | 293 | if opts.Command != "" { 294 | if err := sqlClient.Execute(opts.Command); err != nil { 295 | log.Fatal(err) 296 | } 297 | return 298 | } 299 | 300 | if err := sqlClient.Run(); err != nil { 301 | log.Fatal(err) 302 | } 303 | } 304 | -------------------------------------------------------------------------------- /clusters/node1/config.yaml: -------------------------------------------------------------------------------- 1 | id: 1 2 | peers: { 3 | "2":"127.0.0.1:9702", 4 | "3":"127.0.0.1:9703", 5 | "4":"127.0.0.1:9704", 6 | "5":"127.0.0.1:9705", 7 | } 8 | log_level: INFO 9 | listen_sql: 0.0.0.0:9601 10 | listen_raft: 0.0.0.0:9701 11 | data_dir: data 12 | compact_threshold: 0.2 13 | storage_raft: bitcask 14 | storage_sql: bitcask 
-------------------------------------------------------------------------------- /clusters/node2/config.yaml: -------------------------------------------------------------------------------- 1 | id: 2 2 | peers: { 3 | "1":"127.0.0.1:9701", 4 | "3":"127.0.0.1:9703", 5 | "4":"127.0.0.1:9704", 6 | "5":"127.0.0.1:9705", 7 | } 8 | log_level: INFO 9 | listen_sql: 0.0.0.0:9602 10 | listen_raft: 0.0.0.0:9702 11 | data_dir: data 12 | compact_threshold: 0.2 13 | storage_raft: bitcask 14 | storage_sql: bitcask -------------------------------------------------------------------------------- /clusters/node3/config.yaml: -------------------------------------------------------------------------------- 1 | id: 3 2 | peers: { 3 | "2":"127.0.0.1:9702", 4 | "1":"127.0.0.1:9701", 5 | "4":"127.0.0.1:9704", 6 | "5":"127.0.0.1:9705", 7 | } 8 | log_level: INFO 9 | listen_sql: 0.0.0.0:9603 10 | listen_raft: 0.0.0.0:9703 11 | data_dir: data 12 | compact_threshold: 0.2 13 | storage_raft: bitcask 14 | storage_sql: bitcask -------------------------------------------------------------------------------- /clusters/node4/config.yaml: -------------------------------------------------------------------------------- 1 | id: 4 2 | peers: { 3 | "2":"127.0.0.1:9702", 4 | "3":"127.0.0.1:9703", 5 | "1":"127.0.0.1:9701", 6 | "5":"127.0.0.1:9705", 7 | } 8 | log_level: INFO 9 | listen_sql: 0.0.0.0:9604 10 | listen_raft: 0.0.0.0:9704 11 | data_dir: data 12 | compact_threshold: 0.2 13 | storage_raft: bitcask 14 | storage_sql: bitcask -------------------------------------------------------------------------------- /clusters/node5/config.yaml: -------------------------------------------------------------------------------- 1 | id: 5 2 | peers: { 3 | "2":"127.0.0.1:9702", 4 | "3":"127.0.0.1:9703", 5 | "4":"127.0.0.1:9704", 6 | "1":"127.0.0.1:9701", 7 | } 8 | log_level: INFO 9 | listen_sql: 0.0.0.0:9605 10 | listen_raft: 0.0.0.0:9705 11 | data_dir: data 12 | compact_threshold: 0.2 13 | storage_raft: bitcask 14 | 
storage_sql: bitcask -------------------------------------------------------------------------------- /clusters/run.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | setlocal enabledelayedexpansion 3 | 4 | 5 | pushd "%~dp0.." 6 | set "ROOT_DIR=!CD!" 7 | popd 8 | 9 | 10 | for /l %%i in (1,1,5) do ( 11 | start "Node%%i" cmd /k ^"cd /d "!ROOT_DIR!" ^&^& go run main.go --config "clusters\node%%i\config.yaml"^" 12 | ) -------------------------------------------------------------------------------- /clusters/run.sh: --------------------------------------------------------------------------------
#!/bin/bash

ROOT_DIR="$(cd "$(dirname "$0")/.." && pwd)"

# Background jobs in a non-interactive shell ignore SIGINT, so Ctrl+C alone
# would leave every node running. On INT/TERM, reset the trap and send SIGTERM
# to the whole process group: this script, each subshell and each node.
trap 'trap - INT TERM; kill 0' INT TERM

for node in {1..5}; do
  (cd "$ROOT_DIR" && go run ./main.go --config "clusters/node${node}/config.yaml") &
done

# Stay in the foreground until all nodes exit so the trap above can fire.
wait
-------------------------------------------------------------------------------- /config/db.yaml: -------------------------------------------------------------------------------- 1 | id: 1 2 | peers: { 3 | } 4 | log_level: INFO 5 | listen_sql: 0.0.0.0:9605 6 | listen_raft: 0.0.0.0:9705 7 | data_dir: data 8 | compact_threshold: 0.2 9 | sync: true 10 | storage_raft: bitcask 11 | storage_sql: bitcask 12 | -------------------------------------------------------------------------------- /doc/Group.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/doc/Group.png -------------------------------------------------------------------------------- /doc/Having.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/doc/Having.png -------------------------------------------------------------------------------- /doc/Transaction.md: 
-------------------------------------------------------------------------------- 1 | # 事务 2 | 3 | #### 1. 事务类型分类 4 | 5 | 本系统支持两种基础事务类型,并根据操作特性进行扩展: 6 | 7 | **1.1 只读事务 (Read-Only Transaction)** 8 | 9 | - 特性:仅执行数据查询操作 10 | - 版本控制:提交(commit)/回滚(rollback)操作不会触发数据库版本号递增 11 | - 扩展类型: 12 | - 基于版本号的只读事务: 13 | - 通过指定版本号访问历史数据快照 14 | - 支持跨多版本数据一致性读取 15 | 16 | **1.2 读写事务 (Read-Write Transaction)** 17 | 18 | - 特性:包含数据修改操作(DML) 19 | - 版本控制:成功提交时将递增数据库版本号 20 | - ACID保障: 21 | - 原子性(Atomicity) 22 | - 一致性(Consistency) 23 | - 隔离性(Isolation) 24 | - 持久性(Durability) 25 | 26 | #### 2. 事务控制模式 27 | 28 | **2.1 显式事务 (Explicit Transaction)** 29 | 30 | - 控制方式:需手动声明事务边界 31 | 32 | - 管理命令: 33 | 34 | ```sql 35 | BEGIN; -- 开启只读事务 36 | COMMIT; -- 提交变更 37 | ROLLBACK; -- 回滚操作 38 | START TRANSACTION READ ONLY; -- 开启只读事务 39 | START TRANSACTION READ WRITE; -- 开启读写事务 40 | START TRANSACTION READ ONLY 3; -- 开启基于版本的只读事务 41 | ``` 42 | 43 | **2.2 隐式事务 (Implicit Transaction)** 44 | 45 | - 控制方式:由系统自动管理 46 | 47 | - 运行特征: 48 | 49 | - 单条SQL语句自动构成独立事务 50 | - 执行成功自动提交 51 | - 异常时自动回滚 52 | - 无版本号变更(仅限只读操作) 53 | 54 | #### 3. 
版本控制机制 55 | 56 | - 版本号生成规则: 57 | - 读写事务成功提交时+1 58 | - 只读事务不改变当前版本号 59 | - 快照隔离实现: 60 | - 基于MVCC多版本并发控制 61 | - 读操作获取事务开始时的版本快照 62 | - 写操作创建新数据版本 63 | 64 | #### 4.案例 65 | 66 | - 创建表: 67 | 68 | ```sql 69 | CREATE TABLE employees ( 70 | id INT PRIMARY KEY, 71 | firstname VARCHAR(50), 72 | lastname VARCHAR(500), 73 | email VARCHAR(50) UNIQUE 74 | ); 75 | ``` 76 | 77 | ![image-20250428213704560](./create.png) 78 | 79 | - 查看当前版本状态 80 | 81 | ```go 82 | !status; 83 | ``` 84 | 85 | ![image-20250428215031845](./status.png) 86 | 87 | 可知当前版本为2 88 | 89 | - 插入数据 90 | 91 | ```sql 92 | INSERT INTO Employees (ID, FirstName, LastName, Email) VALUES(3, 'Alice', 'Brown', 'alice@example.com'), (4, 'Bob', 'Johnson', NULL), (5, 'Charlie', 'Lee', 'charlie.lee@example.com'); 93 | ``` 94 | 95 | ​ ![image-20250428215510886](./insert.png) 96 | 97 | - 查看当前版本状态 98 | 99 | ```go 100 | !status; 101 | ``` 102 | 103 | ![image-20250428215649972](./status2.png)版本已更新为Version:3 104 | 105 | - 查询数据 106 | 107 | ![image-20250428215810958](./select.png) 108 | 109 | - 指定版本查询数据 110 | 111 | ![image-20250428215932300](./asofSelect.png) 112 | 113 | - 查看最新数据 114 | 115 | ![image-20250428220029183](./asofSelect2.png) 116 | -------------------------------------------------------------------------------- /doc/and.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/doc/and.png -------------------------------------------------------------------------------- /doc/asofSelect.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/doc/asofSelect.png -------------------------------------------------------------------------------- /doc/asofSelect2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/doc/asofSelect2.png -------------------------------------------------------------------------------- /doc/cnf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/doc/cnf.png -------------------------------------------------------------------------------- /doc/column.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/doc/column.png -------------------------------------------------------------------------------- /doc/columnConstraint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/doc/columnConstraint.png -------------------------------------------------------------------------------- /doc/create.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/doc/create.png -------------------------------------------------------------------------------- /doc/crossjoin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/doc/crossjoin.png -------------------------------------------------------------------------------- /doc/dataType.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/doc/dataType.png -------------------------------------------------------------------------------- /doc/delete.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/doc/delete.png -------------------------------------------------------------------------------- /doc/drop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/doc/drop.png -------------------------------------------------------------------------------- /doc/eq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/doc/eq.png -------------------------------------------------------------------------------- /doc/explain.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/doc/explain.png -------------------------------------------------------------------------------- /doc/expr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/doc/expr.png -------------------------------------------------------------------------------- /doc/func.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/doc/func.png -------------------------------------------------------------------------------- /doc/innerjoin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/doc/innerjoin.png 
-------------------------------------------------------------------------------- /doc/insert.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/doc/insert.png -------------------------------------------------------------------------------- /doc/leftjoin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/doc/leftjoin.png -------------------------------------------------------------------------------- /doc/like.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/doc/like.png -------------------------------------------------------------------------------- /doc/limit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/doc/limit.png -------------------------------------------------------------------------------- /doc/noteq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/doc/noteq.png -------------------------------------------------------------------------------- /doc/or.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/doc/or.png -------------------------------------------------------------------------------- /doc/rightjoin.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/doc/rightjoin.png -------------------------------------------------------------------------------- /doc/select.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/doc/select.png -------------------------------------------------------------------------------- /doc/sql.md: -------------------------------------------------------------------------------- 1 | ## 1.可解析的数据类型 2 | 3 | - CHAR 4 | 5 | - VARCHAR 6 | 7 | - TEXT 8 | 9 | - BOOLEAN 10 | 11 | - BOOL 12 | 13 | - FLOAT 14 | 15 | - DOUBLE 16 | 17 | - INT 18 | 19 | - INTEGER 20 | 21 | ## 2.可解析的常量 22 | 23 | - TRUE 24 | 25 | - FALSE 26 | - NULL 27 | 28 | ```sql 29 | CREATE TABLE example_table ( 30 | id INT PRIMARY KEY, 31 | char_col CHAR(10), 32 | varchar_col VARCHAR(255), 33 | text_col TEXT, 34 | boolean_col BOOLEAN, 35 | bool_col BOOL, 36 | float_col FLOAT, 37 | double_col DOUBLE, 38 | int_col INT, 39 | integer_col INTEGER, 40 | nullable_col VARCHAR(100) 41 | ); 42 | 43 | INSERT INTO example_table ( 44 | id, char_col, varchar_col, text_col, 45 | boolean_col, bool_col, float_col, 46 | double_col, int_col, integer_col, nullable_col 47 | ) VALUES ( 48 | 1, 49 | 'ABC', 50 | 'ABC', 51 | 'ABC', 52 | TRUE, 53 | FALSE, 54 | 3.14, 55 | 2.71828, 56 | 42, 57 | 100, 58 | NULL 59 | ); 60 | ``` 61 | 62 | 63 | 64 | ## 3.可解析的列定义 65 | 66 | - NOT NULL 67 | - NULL 68 | - PRIMARY KEY 69 | - UNIQUE 70 | - DEFAULT 71 | 72 | ```sql 73 | CREATE TABLE example_table2 ( 74 | id INT PRIMARY KEY NOT NULL, 75 | username VARCHAR(100) UNIQUE NOT NULL, 76 | email VARCHAR(100) UNIQUE NULL DEFAULT NULL, 77 | is_active BOOLEAN NOT NULL DEFAULT TRUE 78 | ); 79 | INSERT INTO example_table2 (id,username) 80 | VALUES (1,'A'); 81 | ``` 82 | 83 | 84 | 85 | ## 4.可解析的列约束 86 | 87 | - PRIMARY KEY 88 | - UNIQUE KEY 89 | - KEY 90 | - FOREIGN
KEY 91 | 92 | ```sql 93 | CREATE TABLE departments ( 94 | department_id INT NOT NULL, 95 | department_name VARCHAR(50) NOT NULL, 96 | PRIMARY KEY (department_id), 97 | UNIQUE KEY (department_name) 98 | ); 99 | 100 | INSERT INTO departments (department_id,department_name) 101 | VALUES 102 | (1,'Human Resources'), 103 | (2,'Engineering'); 104 | 105 | CREATE TABLE employees ( 106 | employee_id INT NOT NULL, 107 | salary INT NOT NULL, 108 | name VARCHAR(50) NOT NULL, 109 | email VARCHAR(100) NOT NULL, 110 | ext_id INT, 111 | 112 | PRIMARY KEY (employee_id), 113 | UNIQUE KEY (email), 114 | KEY salary (salary), 115 | FOREIGN KEY (name) 116 | REFERENCES departments(department_name) 117 | ); 118 | 119 | INSERT INTO employees (employee_id,salary, name,email,ext_id) 120 | VALUES 121 | (1,100, 'Human Resources', '123@.com', 11), 122 | (2,200, 'Engineering', '456@.com', 22); 123 | ``` 124 | 125 | 126 | 127 | ## 5.可解析的函数 128 | 129 | - SUM 130 | - MIN 131 | - MAX 132 | - AVG 133 | - COUNT 134 | 135 | ```sql 136 | CREATE TABLE sales ( 137 | id INT PRIMARY KEY, 138 | product_name VARCHAR(50), 139 | price DOUBLE, 140 | quantity INT 141 | ); 142 | 143 | INSERT INTO sales (id,product_name, price, quantity) VALUES 144 | (1,'Laptop', 999.99, 3), 145 | (2,'Phone', 699.50, 5), 146 | (3,'Tablet', 299.00, 2), 147 | (4,'Headphones', 149.99, 8), 148 | (5,'Speaker', 199.95, 4); 149 | 150 | ``` 151 | 152 | 153 | 154 | ## 6.可解析的表达式 155 | 156 | - \+ 157 | 158 | - \- 159 | 160 | - \* 161 | 162 | - / 163 | 164 | - ^ 165 | 166 | - % 167 | 168 | 169 | 170 | - = 171 | 172 | - \> 173 | 174 | - \>= / <> 175 | 176 | ![image-20250522191657295](./eq.png) 177 | 178 | - < 179 | 180 | - \<= 181 | 182 | - != / <> 183 | 184 | ![image-20250522191856850](./noteq.png) 185 | 186 | - and / && 187 | 188 | ![image-20250522192005141](./and.png) 189 | 190 | - or / || 191 | 192 | ![image-20250522192434068](./or.png) 193 | 194 | - 嵌套and、or 195 | 196 | ![image-20250522202100905](./cnf.png) 197 | 198 | ## 7.可解析关键词 199 | 
200 | - ORDER BY 201 | 202 | - GROUP BY 203 | 204 | - DESC ASC 205 | 206 | ```sql 207 | CREATE TABLE employees ( 208 | employee_id INT PRIMARY KEY, 209 | name VARCHAR(50), 210 | department VARCHAR(50), 211 | position VARCHAR(50), 212 | salary FLOAT 213 | ); 214 | 215 | INSERT INTO employees (employee_id, name, department, position, salary) 216 | VALUES 217 | (1, 'Alice', 'HR', 'Manager', 5000.00), 218 | (2, 'Bob', 'IT', 'Developer', 6000.00), 219 | (3, 'Charlie', 'HR', 'Analyst', 5500.00), 220 | (4, 'David', 'IT', 'Developer', 7000.00), 221 | (5, 'Eva', 'Finance', 'Accountant', 4800.00); 222 | ``` 223 | 224 | 225 | 226 | - HAVING 227 | 228 | 229 | 230 | - LIMIT 231 | 232 | - OFFSET 233 | 234 | 235 | 236 | - LIKE 237 | 238 | 239 | 240 | ## 8.可解析的join 类型 241 | 242 | - LEFT JOIN 243 | 244 | ```sql 245 | CREATE TABLE users ( 246 | id INT PRIMARY KEY, 247 | username VARCHAR(50), 248 | email VARCHAR(50) 249 | ); 250 | 251 | CREATE TABLE orders ( 252 | order_id INT PRIMARY KEY, 253 | user_id INT, 254 | product VARCHAR(50), 255 | amount FLOAT 256 | ); 257 | 258 | INSERT INTO users (id, username, email) VALUES 259 | (1, 'John Doe', 'john@example.com'), 260 | (2, 'Jane Smith', 'jane@example.com'), 261 | (3, 'Bob Johnson', 'bob@example.com'), 262 | (4, 'Alice Brown', 'alice@example.com'); 263 | 264 | INSERT INTO orders (order_id, user_id, product, amount) VALUES 265 | (101, 1, 'Laptop', 999.99), 266 | (102, 1, 'Mouse', 19.99), 267 | (103, 2, 'Keyboard', 49.99), 268 | (104, 1, 'Monitor', 199.99), 269 | (105, 999, 'Headphones', 89.99), 270 | (106, NULL, 'USB Cable', 9.99); 271 | ``` 272 | 273 | 274 | 275 | - Right join 276 | 277 | 278 | 279 | - CROSS JOIN 280 | 281 | ```sql 282 | CREATE TABLE colors ( 283 | color_id INT PRIMARY KEY, 284 | color_name VARCHAR(20) 285 | ); 286 | 287 | CREATE TABLE sizes ( 288 | size_id INT PRIMARY KEY, 289 | size_code VARCHAR(10) 290 | ); 291 | 292 | INSERT INTO colors (color_id, color_name) VALUES 293 | (1, 'Red'), 294 | (2, 'Blue'), 295 | (3, 
'Green'); 296 | 297 | INSERT INTO sizes (size_id, size_code) VALUES 298 | (101, 'S'), 299 | (102, 'M'), 300 | (103, 'L'), 301 | (104, 'XL'); 302 | ``` 303 | 304 | 305 | 306 | - inner join(join) 307 | 308 | 309 | 310 | ## 9.删改 311 | 312 | - UPDATE 313 | 314 | 315 | 316 | - DELETE 317 | 318 | 319 | 320 | - DROP TABLE 321 | 322 | 323 | 324 | ## 10.执行计划 325 | 326 | 建表语句见4 327 | 328 | 329 | 330 | -------------------------------------------------------------------------------- /doc/status.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/doc/status.png -------------------------------------------------------------------------------- /doc/status2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/doc/status2.png -------------------------------------------------------------------------------- /doc/update.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/doc/update.png -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module cabbageDB 2 | 3 | go 1.22.0 4 | 5 | require ( 6 | github.com/google/btree v1.1.3 7 | github.com/google/uuid v1.4.0 8 | github.com/pkg/errors v0.9.1 9 | github.com/spf13/viper v1.19.0 10 | go.uber.org/zap v1.27.0 11 | golang.org/x/sys v0.25.0 12 | modernc.org/mathutil v1.7.1 13 | modernc.org/parser v1.0.8 14 | modernc.org/sortutil v1.2.1 15 | modernc.org/strutil v1.2.1 16 | modernc.org/y v1.1.0 17 | ) 18 | 19 | require ( 20 | github.com/chzyer/readline v1.5.1 // indirect 21 | github.com/fsnotify/fsnotify v1.7.0 // indirect 22 | 
github.com/hashicorp/hcl v1.0.0 // indirect 23 | github.com/magiconair/properties v1.8.7 // indirect 24 | github.com/mitchellh/mapstructure v1.5.0 // indirect 25 | github.com/pelletier/go-toml/v2 v2.2.3 // indirect 26 | github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect 27 | github.com/sagikazarmark/locafero v0.6.0 // indirect 28 | github.com/sagikazarmark/slog-shim v0.1.0 // indirect 29 | github.com/sourcegraph/conc v0.3.0 // indirect 30 | github.com/spf13/afero v1.11.0 // indirect 31 | github.com/spf13/cast v1.7.0 // indirect 32 | github.com/spf13/pflag v1.0.5 // indirect 33 | github.com/subosito/gotenv v1.6.0 // indirect 34 | go.uber.org/multierr v1.11.0 // indirect 35 | golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 // indirect 36 | golang.org/x/text v0.18.0 // indirect 37 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect 38 | gopkg.in/ini.v1 v1.67.0 // indirect 39 | gopkg.in/yaml.v3 v3.0.1 // indirect 40 | modernc.org/golex v1.0.5 // indirect 41 | ) 42 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ= 2 | github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI= 3 | github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= 4 | github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= 5 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= 7 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 8 | github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= 9 | 
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= 10 | github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= 11 | github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= 12 | github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= 13 | github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= 14 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= 15 | github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 16 | github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4= 17 | github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 18 | github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= 19 | github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= 20 | github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= 21 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 22 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= 23 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 24 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 25 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 26 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 27 | github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= 28 | github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= 29 | github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= 30 | github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= 31 | github.com/pelletier/go-toml/v2 
v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M= 32 | github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= 33 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 34 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 35 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= 36 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 37 | github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= 38 | github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= 39 | github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= 40 | github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= 41 | github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= 42 | github.com/sagikazarmark/locafero v0.6.0 h1:ON7AQg37yzcRPU69mt7gwhFEBwxI6P9T4Qu3N51bwOk= 43 | github.com/sagikazarmark/locafero v0.6.0/go.mod h1:77OmuIc6VTraTXKXIs/uvUxKGUXjE1GbemJYHqdNjX0= 44 | github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= 45 | github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= 46 | github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= 47 | github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= 48 | github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= 49 | github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= 50 | github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w= 51 | github.com/spf13/cast 
v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= 52 | github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= 53 | github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= 54 | github.com/spf13/viper v1.19.0 h1:RWq5SEjt8o25SROyN3z2OrDB9l7RPd3lwTWU8EcEdcI= 55 | github.com/spf13/viper v1.19.0/go.mod h1:GQUN9bilAbhU/jgc1bKs99f/suXKeUMct8Adx5+Ntkg= 56 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= 57 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 58 | github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= 59 | github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= 60 | go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= 61 | go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= 62 | go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= 63 | go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= 64 | go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= 65 | go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= 66 | golang.org/x/exp v0.0.0-20181106170214-d68db9428509/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= 67 | golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWBtpfqks2bwGcexMxgtk= 68 | golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY= 69 | golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 70 | golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= 71 | golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 72 | golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= 73 | golang.org/x/text v0.18.0/go.mod 
h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= 74 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 75 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 76 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 77 | gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= 78 | gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= 79 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 80 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 81 | modernc.org/fileutil v1.1.2/go.mod h1:HdjlliqRHrMAI4nVOvvpYVzVgvRSK7WnoCiG0GUWJNo= 82 | modernc.org/golex v1.0.5 h1:M+4kIjbDMvKN4pAuh5gJBOfG7Emi9WXGpg2Eay1dlGI= 83 | modernc.org/golex v1.0.5/go.mod h1:pTY7KKjdvZbv2ROjfp6FFX5BXMM9QWZEnmCsl60aCfI= 84 | modernc.org/lex v1.1.1/go.mod h1:6r8o8DLJkAnOsQaGi8fMoi+Vt6LTbDaCrkUK729D8xM= 85 | modernc.org/lexer v1.0.4/go.mod h1:tOajb8S4sdfOYitzCgXDFmbVJ/LE0v1fNJ7annTw36U= 86 | modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= 87 | modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo= 88 | modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= 89 | modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= 90 | modernc.org/parser v1.0.8 h1:teneMxK6cqIJeowCj+dMrlSu6mWxQykgVSChJkb7CDo= 91 | modernc.org/parser v1.0.8/go.mod h1:gSb1YDm/lCtL9U4M6+HIk2JfFiGgguEJUjZlIdK3I0g= 92 | modernc.org/scanner v1.1.0/go.mod h1:pDSh3vhQZeHFCjpcSzhDsvDIDOku2b/DdagPGXkK35o= 93 | modernc.org/sortutil v1.2.0/go.mod h1:TKU2s7kJMf1AE84OoiGppNHJwvB753OYfNl2WRb++Ss= 94 | modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= 95 | modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= 96 | modernc.org/strutil 
v1.2.0/go.mod h1:/mdcBmfOibveCTBxUl5B5l6W+TTH1FXPLHZE6bTosX0= 97 | modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= 98 | modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= 99 | modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= 100 | modernc.org/y v1.1.0 h1:JdIvLry+rKeSsVNRCdr6YWYimwwNm0GXtzxid77VfWc= 101 | modernc.org/y v1.1.0/go.mod h1:Iz3BmyIS4OwAbwGaUS7cqRrLsSsfp2sFWtpzX+P4CsE= 102 | -------------------------------------------------------------------------------- /gobReg/init.go: -------------------------------------------------------------------------------- 1 | package gobReg 2 | 3 | import ( 4 | "cabbageDB/log" 5 | "cabbageDB/server" 6 | "cabbageDB/sql/catalog" 7 | "cabbageDB/sql/engine" 8 | "cabbageDB/sql/expr" 9 | "cabbageDB/sqlparser/ast" 10 | "cabbageDB/storage" 11 | ) 12 | 13 | // gob注册结构体并初始化 避免序列化时字节变化 14 | func GobRegMain() { 15 | storage.GobReg() 16 | engine.GobReg() 17 | catalog.GobReg() 18 | server.GobReg() 19 | expr.GobReg() 20 | log.GobReg() 21 | ast.GobReg() 22 | 23 | } 24 | -------------------------------------------------------------------------------- /log/Leader.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "math" 5 | "sort" 6 | ) 7 | 8 | type Progress struct { 9 | Next Index 10 | Last Index 11 | } 12 | 13 | type Leader struct { 14 | Progress map[NodeID]*Progress 15 | SinceHeartBeat Ticks 16 | NodeInfo *NodeInfo 17 | } 18 | 19 | func (l *Leader) Step(msg Message) Node { 20 | if msg.Term < l.NodeInfo.Term && msg.Term > 0 { 21 | return l 22 | } 23 | if msg.Term > l.NodeInfo.Term { 24 | return l.NodeInfo.BecomeFollower(msg.Term, 0, nil).Step(msg) 25 | } 26 | 27 | switch v := msg.Event.(type) { 28 | case *ConfirmLeader: 29 | from := GetNodeID(msg.From) 30 | l.NodeInfo.StateTx <- &Vote{ 31 | Term: msg.Term, 32 | Index: v.CommitIndex, 33 | Address: msg.From, 34 | } 35 | if 
!v.HasCommitted { 36 | l.SendLog(from) 37 | } 38 | case *AcceptEntries: 39 | from := GetNodeID(msg.From) 40 | l.Progress[from].Last = v.LastIndex 41 | l.Progress[from].Next = v.LastIndex + 1 42 | l.MaybeCommit() 43 | case *RejectEntries: 44 | from := GetNodeID(msg.From) 45 | if l.Progress[from].Next > 1 { 46 | l.Progress[from].Next = l.Progress[from].Next - 1 47 | } 48 | l.SendLog(from) 49 | case *ClientRequest: 50 | switch req := v.Request.(type) { 51 | case *RaftQuery: 52 | commitIndex, _ := l.NodeInfo.Log.GetCommitIndex() 53 | l.NodeInfo.StateTx <- &Query{ 54 | ID: v.ID, 55 | Address: msg.From, 56 | Command: req.Command, 57 | Term: l.NodeInfo.Term, 58 | Index: commitIndex, 59 | Quorum: uint64(l.NodeInfo.Quorum()), 60 | } 61 | l.NodeInfo.StateTx <- &Vote{ 62 | Term: l.NodeInfo.Term, 63 | Index: commitIndex, 64 | Address: SetNodeID(l.NodeInfo.ID), 65 | } 66 | l.HeartBeat() 67 | case *RaftMutate: 68 | index := l.Propose(req.Command) 69 | l.NodeInfo.StateTx <- &Notify{ 70 | ID: v.ID, 71 | Address: msg.From, 72 | Index: index, 73 | } 74 | if len(l.NodeInfo.Peers) == 0 { 75 | l.MaybeCommit() 76 | } 77 | case *RaftStatus: 78 | engineStatus := l.NodeInfo.Log.Status() 79 | 80 | lastIndexMap := make(map[NodeID]Index) 81 | for id, progress := range l.Progress { 82 | lastIndexMap[id] = progress.Last 83 | } 84 | lastIndex, _ := l.NodeInfo.Log.GetLastIndex() 85 | lastIndexMap[l.NodeInfo.ID] = lastIndex 86 | 87 | commitIndex, _ := l.NodeInfo.Log.GetCommitIndex() 88 | 89 | status := NodeStatus{ 90 | Server: l.NodeInfo.ID, 91 | Leader: l.NodeInfo.ID, 92 | Term: l.NodeInfo.Term, 93 | NodeLastIndex: lastIndexMap, 94 | ApplyIndex: 0, 95 | Storage: engineStatus.Name, 96 | StorageSize: engineStatus.Size, 97 | CommitIndex: commitIndex, 98 | FileName: l.NodeInfo.Log.Engine.FileName(), 99 | } 100 | 101 | l.NodeInfo.StateTx <- &Status{ 102 | ID: v.ID, 103 | Address: msg.From, 104 | Statue: &status, 105 | } 106 | } 107 | case *ClientResponse: 108 | if status, ok := 
v.Response.(*RaftStatus); ok { 109 | status.Status.Server = l.NodeInfo.ID 110 | } 111 | l.NodeInfo.Send([]byte{AddressPrefix, ClientPrefix}, &ClientResponse{ 112 | ID: v.ID, 113 | Response: v.Response, 114 | }) 115 | 116 | } 117 | 118 | return l 119 | } 120 | func (l *Leader) Tick() Node { 121 | l.SinceHeartBeat += 1 122 | if l.SinceHeartBeat >= HEARTBEAT_INTERVAL { 123 | l.HeartBeat() 124 | l.SinceHeartBeat = 0 125 | } 126 | return l 127 | } 128 | func (l *Leader) Info() *NodeInfo { 129 | return l.NodeInfo 130 | } 131 | // HeartBeat broadcasts the leader's commit index and commit term to all peers. 132 | func (l *Leader) HeartBeat() { 133 | commitIndex, commitTerm := l.NodeInfo.Log.GetCommitIndex() 134 | l.NodeInfo.Send([]byte{AddressPrefix, BroadcastPrefix}, &HeartBeat{ 135 | CommitIndex: commitIndex, 136 | CommitTerm: commitTerm, 137 | }) 138 | } 139 | // Propose appends command to the local log in the current term, sends the log to every peer and returns the new entry's index. 140 | func (l *Leader) Propose(command []byte) Index { 141 | index := l.Info().Log.Append(l.NodeInfo.Term, command) 142 | for id := range l.NodeInfo.Peers { 143 | l.SendLog(id) 144 | } 145 | return index 146 | } 147 | // MaybeCommit commits the highest log index stored on a quorum of nodes (this leader included) and hands the newly committed entries to the state machine driver. 148 | func (l *Leader) MaybeCommit() Index { 149 | indexes := []int{} 150 | for _, v := range l.Progress { 151 | indexes = append(indexes, int(v.Last)) 152 | } 153 | 154 | lastIndex, _ := l.NodeInfo.Log.GetLastIndex() 155 | 156 | indexes = append(indexes, int(lastIndex)) 157 | 158 | sort.Sort(sort.Reverse(sort.IntSlice(indexes))) 159 | commitIndex := Index(indexes[l.NodeInfo.Quorum()-1]) 160 | if commitIndex == 0 { 161 | return commitIndex 162 | } 163 | 164 | prevCommitIndex, _ := l.NodeInfo.Log.GetCommitIndex() 165 | if commitIndex < prevCommitIndex { 166 | return prevCommitIndex 167 | } 168 | // Only an entry this leader actually holds, from its own term, may be committed by counting replicas (Raft §5.4.2); going on with a missing entry would apply entries that were never committed. 169 | entry := l.NodeInfo.Log.Get(commitIndex) 170 | if entry == nil || entry.Term != l.NodeInfo.Term { 171 | return prevCommitIndex 172 | } 173 | if commitIndex > prevCommitIndex { 174 | l.NodeInfo.Log.Commit(commitIndex) 175 | scan := l.NodeInfo.Log.Scan(prevCommitIndex, commitIndex, false, true) 176 | for _, entryItem := range scan { 177 | l.NodeInfo.StateTx <- &Apply{ 178 | Entry:
entryItem, 179 | } 180 | } 181 | 182 | } 183 | return commitIndex 184 | } 185 | // SendLog sends peer every entry after its base entry (index Next-1, or 0 when Next is 1 or the peer has no progress) together with the base entry's index and term. 186 | func (l *Leader) SendLog(peer NodeID) { 187 | 188 | var baseIndex Index 189 | var baseTerm Term 190 | if progress, ok := l.Progress[peer]; ok { 191 | if progress.Next > 1 { 192 | entry := l.NodeInfo.Log.Get(progress.Next - 1) 193 | baseIndex, baseTerm = entry.Index, entry.Term 194 | } 195 | } else { 196 | baseIndex, baseTerm = 0, 0 197 | } 198 | 199 | entires := l.NodeInfo.Log.Scan(baseIndex+1, math.MaxUint64, true, true) 200 | to := SetNodeID(peer) 201 | l.NodeInfo.Send(to, &AppendEntries{ 202 | BaseIndex: baseIndex, 203 | BaseTerm: baseTerm, 204 | Entries: entires, 205 | }) 206 | } 207 | // NewLeader creates a leader whose progress initially expects every peer to need the entries after lastIndex. 208 | func NewLeader(peers map[NodeID]struct{}, lastIndex Index, nodeInfo *NodeInfo) *Leader { 209 | next := lastIndex + 1 210 | progress := make(map[NodeID]*Progress) 211 | 212 | for nodeID := range peers { 213 | progress[nodeID] = &Progress{ 214 | Next: next, 215 | Last: 0, 216 | } 217 | } 218 | return &Leader{ 219 | Progress: progress, 220 | SinceHeartBeat: 0, 221 | NodeInfo: nodeInfo, 222 | } 223 | } 224 | -------------------------------------------------------------------------------- /log/candidate.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "cabbageDB/logger" 5 | ) 6 | // Candidate is a node campaigning to become leader of the current term. 7 | type Candidate struct { 8 | Votes map[NodeID]struct{} 9 | ElectionDuration Ticks 10 | ElectionTimeout Ticks 11 | NodeInfo *NodeInfo 12 | } 13 | // NewCandidate creates a candidate with no votes and a fresh randomized election timeout. 14 | func NewCandidate(nodeInfo *NodeInfo) *Candidate { 15 | return &Candidate{ 16 | Votes: make(map[NodeID]struct{}), 17 | ElectionDuration: 0, 18 | ElectionTimeout: randElectionTimeout(), 19 | NodeInfo: nodeInfo, 20 | } 21 | } 22 | // Step handles one message as a candidate: a newer term demotes it to follower and a quorum of votes makes it leader. 23 | func (c *Candidate) Step(msg Message) Node { 24 | info := c.Info() 25 | if msg.Term < info.Term && msg.Term > 0 { 26 | logger.Info("Dropping message from past term ", msg) 27 | return c 28 | } 29 | if msg.Term > info.Term { 30 | return
c.NodeInfo.BecomeFollower(msg.Term, 0, nil).Step(msg) 31 | } 32 | 33 | switch v := msg.Event.(type) { 34 | case *GrantVote: 35 | id := GetNodeID(msg.From) 36 | if id != 0 { 37 | c.Votes[id] = struct{}{} 38 | } 39 | if len(c.Votes) >= c.NodeInfo.Quorum() { 40 | return c.NodeInfo.BecomeLeader() 41 | } 42 | case *HeartBeat, *AppendEntries: 43 | id := GetNodeID(msg.From) 44 | return c.NodeInfo.BecomeFollower(msg.Term, id, nil).Step(msg) 45 | case *ClientRequest: 46 | c.NodeInfo.Send(msg.From, &ClientResponse{ 47 | ID: v.ID, 48 | Response: &RaftError{}, 49 | }) 50 | } 51 | return c 52 | } 53 | // Tick starts a new election (new term, fresh timeout) once the election timeout elapses without a winner. 54 | func (c *Candidate) Tick() Node { 55 | c.ElectionDuration += 1 56 | 57 | if c.ElectionDuration >= c.ElectionTimeout { 58 | c = NewCandidate(c.NodeInfo) 59 | c.Campaign() 60 | } 61 | return c 62 | } 63 | 64 | func (c *Candidate) Info() *NodeInfo { 65 | return c.NodeInfo 66 | } 67 | // Campaign starts a new term, votes for itself, persists that vote and asks every peer for theirs. 68 | func (c *Candidate) Campaign() { 69 | c.NodeInfo.Term += 1 70 | c.Votes[c.NodeInfo.ID] = struct{}{} 71 | c.NodeInfo.Log.SetTerm(c.NodeInfo.Term, c.NodeInfo.ID) 72 | lastIndex, lastTerm := c.NodeInfo.Log.GetLastIndex() 73 | c.NodeInfo.Send([]byte{AddressPrefix, BroadcastPrefix}, &SolicitVote{ 74 | LastIndex: lastIndex, 75 | LastTerm: lastTerm, 76 | }) 77 | } 78 | -------------------------------------------------------------------------------- /log/follower.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "cabbageDB/logger" 5 | ) 6 | // Follower is a node that replicates the leader's log (Leader is 0 while unknown) and votes in elections. 7 | type Follower struct { 8 | Leader NodeID 9 | LeaderSeen Ticks 10 | ElectionTimeout Ticks 11 | VotedFor NodeID 12 | Forwarded map[RequestID]struct{} 13 | NodeInfo *NodeInfo 14 | } 15 | // Step handles one message as a follower; any message from the known leader resets the election timer. 16 | func (f *Follower) Step(msg Message) Node { 17 | 18 | info := f.NodeInfo 19 | if msg.Term < info.Term && msg.Term > 0 { 20 | logger.Info("Dropping message from past term ", msg) 21 | return f 22 | } 23 | if msg.Term > f.NodeInfo.Term { 24 | return f.NodeInfo.BecomeFollower(msg.Term, 0,
&f.VotedFor).Step(msg) 25 | } 26 | if f.IsLeader(msg.From) { 27 | f.LeaderSeen = 0 28 | } 29 | switch v := msg.Event.(type) { 30 | case *HeartBeat: 31 | if f.Leader == 0 { 32 | return f.Info().BecomeFollower(msg.Term, GetNodeID(msg.From), &f.VotedFor).Step(msg) 33 | } 34 | hasCommited := f.NodeInfo.Log.Has(v.CommitIndex, v.CommitTerm) 35 | oldCommitIndex, _ := f.NodeInfo.Log.GetCommitIndex() 36 | if hasCommited && v.CommitIndex > oldCommitIndex { 37 | f.NodeInfo.Log.Commit(v.CommitIndex) 38 | scan := f.NodeInfo.Log.Scan(oldCommitIndex+1, v.CommitIndex, true, true) 39 | for _, entry := range scan { 40 | f.NodeInfo.StateTx <- &Apply{Entry: entry} 41 | } 42 | } 43 | // Tell the leader whether we hold its commit entry (the leader resends its log when HasCommitted is false). 44 | f.NodeInfo.Send(msg.From, &ConfirmLeader{ 45 | CommitIndex: v.CommitIndex, 46 | HasCommitted: hasCommited, 47 | }) 48 | case *AppendEntries: 49 | from := GetNodeID(msg.From) 50 | if f.Leader == 0 { 51 | f.Leader = from 52 | } else if f.Leader != from { 53 | logger.Info("Multiple leaders in term") 54 | return f 55 | } 56 | if v.BaseIndex > 0 && !f.NodeInfo.Log.Has(v.BaseIndex, v.BaseTerm) { 57 | f.NodeInfo.Send(msg.From, &RejectEntries{}) 58 | } else { 59 | lastIndex := f.NodeInfo.Log.Splice(v.Entries) 60 | f.NodeInfo.Send(msg.From, &AcceptEntries{ 61 | LastIndex: lastIndex, 62 | }) 63 | } 64 | case *SolicitVote: 65 | from := GetNodeID(msg.From) 66 | if f.VotedFor != 0 && from != f.VotedFor { 67 | return f 68 | } 69 | logIndex, logTerm := f.NodeInfo.Log.GetLastIndex() 70 | // Grant only if the candidate's log is at least as up-to-date as ours (Raft §5.4.1). 71 | if v.LastTerm > logTerm || (v.LastTerm == logTerm && v.LastIndex >= logIndex) { 72 | f.NodeInfo.Send(msg.From, &GrantVote{}) 73 | f.NodeInfo.Log.SetTerm(f.NodeInfo.Term, from) 74 | f.VotedFor = from 75 | } 76 | case *ClientRequest: 77 | if msg.From[0] != AddressPrefix || msg.From[1] != ClientPrefix { 78 | return f 79 | } 80 | if f.Leader != 0 { 81 | f.Forwarded[v.ID] = struct{}{} 82 | to := SetNodeID(f.Leader) 83 | f.NodeInfo.Send(to, msg.Event) 84 | } else { 85 | f.NodeInfo.Send(msg.From, &ClientResponse{ 86 | ID: v.ID, 87 |
Response: &RaftError{}, 88 | }) 89 | } 90 | case *ClientResponse: 91 | if !f.IsLeader(msg.From) { 92 | return f 93 | } 94 | 95 | if status, ok := v.Response.(*RaftStatus); ok { 96 | status.Status.Server = f.NodeInfo.ID 97 | } 98 | if _, ok := f.Forwarded[v.ID]; ok { 99 | f.NodeInfo.Send([]byte{AddressPrefix, ClientPrefix}, &ClientResponse{ 100 | ID: v.ID, 101 | Response: v.Response, 102 | }) 103 | delete(f.Forwarded, v.ID) 104 | } 105 | 106 | } 107 | 108 | return f 109 | } 110 | // IsLeader reports whether from is the address of the known leader. With no known leader (0) nothing matches: client and broadcast addresses decode to 0 via GetNodeID and must not be mistaken for a leader, which would reset the election timer. 111 | func (f *Follower) IsLeader(from Address) bool { 112 | if f.Leader == 0 { 113 | return false 114 | } 115 | return GetNodeID(from) == f.Leader 116 | } 117 | 118 | 119 | func (f *Follower) Info() *NodeInfo { 120 | return f.NodeInfo 121 | } 122 | func (f *Follower) Tick() Node { 123 | f.LeaderSeen += 1 124 | if f.LeaderSeen >= f.ElectionTimeout { 125 | return f.BecomeCandidate() 126 | } 127 | return f 128 | } 129 | func (f *Follower) BecomeCandidate() Node { 130 | f.AbortForwarded() 131 | candidate := NewCandidate(f.NodeInfo) 132 | candidate.Campaign() 133 | return candidate 134 | } 135 | // NewFollower creates a follower of leader (0 if unknown) that has voted for votedFor, with a fresh randomized election timeout. 136 | func NewFollower(leader NodeID, votedFor NodeID, nodeInfo *NodeInfo) Node { 137 | return &Follower{ 138 | Leader: leader, 139 | VotedFor: votedFor, 140 | LeaderSeen: 0, 141 | ElectionTimeout: randElectionTimeout(), 142 | Forwarded: make(map[RequestID]struct{}), 143 | NodeInfo: nodeInfo, 144 | } 145 | } 146 | // AbortForwarded fails every client request that was forwarded to the leader and not yet answered. 147 | func (f *Follower) AbortForwarded() { 148 | for id := range f.Forwarded { 149 | f.NodeInfo.Send([]byte{AddressPrefix, ClientPrefix}, &ClientResponse{ 150 | ID: id, 151 | Response: &RaftError{}, 152 | }) 153 | } 154 | } 155 | -------------------------------------------------------------------------------- /log/log.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "cabbageDB/bitcask" 5 | "cabbageDB/logger" 6 | "cabbageDB/util" 7 | ) 8 | 9 | type Index uint64 10 | type Term uint64 11 | 12 | const ( 13 | LogKeyPrefix
byte = 0x02 14 | EntryPrefix byte = 0x02 15 | TermVotePrefix byte = 0x03 16 | CommitIndexPrefix byte = 0x04 17 | ) 18 | 19 | // Entry is a single Raft log entry. 20 | type Entry struct { 21 | Index Index `json:"index"` // entry index 22 | Term Term `json:"term"` // term in which the entry was appended 23 | Command []byte `json:"command"` // state machine command 24 | } 25 | 26 | // RaftLog is a Raft log persisted in a key/value storage engine. 27 | type RaftLog struct { 28 | Engine Engine // underlying storage engine 29 | LastIndex Index // index of the last stored entry 30 | LastTerm Term // term of the last stored entry 31 | CommitIndex Index // index of the last committed entry 32 | CommitTerm Term // term of the last committed entry 33 | } 34 | // NewRaftLog restores the last stored and last committed positions from the entries and the commit record held in engine. 35 | func NewRaftLog(engine Engine) *RaftLog { 36 | byteMapList := engine.ScanPrefix([]byte{LogKeyPrefix, EntryPrefix}) 37 | 38 | var lastIndex Index 39 | var lastTerm Term 40 | // ScanPrefix bumps the prefix {2,2} to {2,3} and includes both bounds, so its result may end with the term/vote key {LogKeyPrefix, TermVotePrefix}; pick the last key that really is an entry key instead of assuming its position. 41 | var byteMap *bitcask.ByteMap 42 | for i := len(byteMapList) - 1; i >= 0; i-- { 43 | key := byteMapList[i].Key 44 | if len(key) > 2 && key[0] == LogKeyPrefix && key[1] == EntryPrefix { 45 | byteMap = byteMapList[i] 46 | break 47 | } 48 | } 49 | 50 | if byteMap != nil { 51 | lastEntry := DecodeEntry(byteMap.Key, byteMap.Value) 52 | lastIndex, lastTerm = lastEntry.Index, lastEntry.Term 53 | } else { 54 | lastIndex, lastTerm = 0, 0 55 | } 56 | commitEntryByte := engine.Get([]byte{LogKeyPrefix, CommitIndexPrefix}) 57 | commitEntry := DecodeCommitEntry(commitEntryByte) 58 | commitIndex, commitTerm := commitEntry.Index, commitEntry.Term 59 | return &RaftLog{ 60 | Engine: engine, 61 | LastIndex: lastIndex, 62 | LastTerm: lastTerm, 63 | CommitIndex: commitIndex, 64 | CommitTerm: commitTerm, 65 | } 66 | 67 | } 68 | // SetTerm persists the current term together with the node voted for in it. 69 | func (log *RaftLog) SetTerm(term Term, votedFor NodeID) { 70 | termByte := util.BinaryToByte(uint64(term)) 71 | votedForByte := util.BinaryToByte(votedFor) 72 | termByte = append(termByte, votedForByte...) 73 | 74 | log.Engine.Set([]byte{LogKeyPrefix, TermVotePrefix}, termByte) 75 | } 76 | // Get returns the entry stored at index, or nil if there is none. 77 | func (log *RaftLog) Get(index Index) *Entry { 78 | keyByte := append([]byte{LogKeyPrefix, EntryPrefix}, util.BinaryToByte(uint64(index))...)
79 | valueByte := log.Engine.Get(keyByte) 80 | if len(valueByte) == 0 { 81 | return nil 82 | } 83 | return log.DecodeEntryValue(index, valueByte) 84 | } 85 | func (log *RaftLog) Commit(index Index) Index { 86 | if index < log.CommitIndex { 87 | logger.Info("Commit index regression ", log.CommitIndex, " -> ", index) 88 | return 0 89 | } 90 | entry := log.Get(index) 91 | if entry == nil { 92 | logger.Info("Can't commit non-existant index ", index) 93 | return 0 94 | } 95 | log.Engine.Set([]byte{LogKeyPrefix, CommitIndexPrefix}, util.BinaryStructToByte(entry)) 96 | log.CommitIndex = entry.Index 97 | log.CommitTerm = entry.Term 98 | return index 99 | 100 | } 101 | // GetLastIndex returns the index and term of the last stored entry. 102 | func (log *RaftLog) GetLastIndex() (Index, Term) { 103 | return log.LastIndex, log.LastTerm 104 | } 105 | func (log *RaftLog) Status() *bitcask.Status { 106 | return log.Engine.Status() 107 | } 108 | // GetCommitIndex returns the index and term of the last committed entry. 109 | func (log *RaftLog) GetCommitIndex() (Index, Term) { 110 | return log.CommitIndex, log.CommitTerm 111 | } 112 | // GetTerm returns the persisted current term and the node voted for in it, or (0, 0) if none is stored. 113 | func (log *RaftLog) GetTerm() (Term, NodeID) { 114 | valueByte := log.Engine.Get([]byte{LogKeyPrefix, TermVotePrefix}) 115 | if len(valueByte) == 0 { 116 | return 0, 0 117 | } 118 | var term uint64 119 | util.ByteToInt(valueByte[:8], &term) 120 | var nodeID uint8 121 | util.ByteToInt(valueByte[8:], &nodeID) 122 | 123 | return Term(term), nodeID 124 | } 125 | // Scan returns the stored entries between from and to; includeFrom/includeTo choose whether each bound is inclusive. 126 | func (log *RaftLog) Scan(from, to Index, includeFrom, includeTo bool) []*Entry { 127 | 128 | if !includeFrom { 129 | from += 1 130 | } 131 | if !includeTo { 132 | to -= 1 133 | } 134 | 135 | fromKey := append([]byte{LogKeyPrefix, EntryPrefix}, util.BinaryToByte(uint64(from))...) 136 | toKey := append([]byte{LogKeyPrefix, EntryPrefix}, util.BinaryToByte(uint64(to))...)
137 | 138 | items := log.Engine.Scan(fromKey, toKey) 139 | entryList := make([]*Entry, len(items)) 140 | for i, item := range items { 141 | entryList[i] = log.DecodeEntry(item.Key, item.Value) 142 | } 143 | return entryList 144 | 145 | } 146 | // DecodeEntry decodes an engine key/value pair into an entry, or returns nil if key is not an entry key. 147 | func (log *RaftLog) DecodeEntry(key []byte, value []byte) *Entry { 148 | 149 | if key[0] == LogKeyPrefix && key[1] == EntryPrefix { 150 | var index uint64 151 | util.ByteToInt(key[2:], &index) 152 | return log.DecodeEntryValue(Index(index), value) 153 | 154 | } else { 155 | logger.Info("Invalid key error") 156 | return nil 157 | } 158 | 159 | } 160 | // DecodeEntryValue builds the entry at index from a stored value: an 8-byte term followed by the command. 161 | func (log *RaftLog) DecodeEntryValue(index Index, value []byte) *Entry { 162 | if len(value) == 0 { 163 | return &Entry{ 164 | Index: index, 165 | Term: 0, 166 | Command: []byte{}, 167 | } 168 | } 169 | var term uint64 170 | util.ByteToInt(value[:8], &term) 171 | return &Entry{ 172 | Index: index, 173 | Term: Term(term), 174 | Command: value[8:], 175 | } 176 | } 177 | // Append stores command as a new entry of term right after the last entry and returns its index. 178 | func (log *RaftLog) Append(term Term, command []byte) Index { 179 | index := log.LastIndex + 1 180 | indexByte := append([]byte{LogKeyPrefix, EntryPrefix}, util.BinaryToByte(uint64(index))...)
181 | 182 | valueByte := util.BinaryToByte(uint64(term)) 183 | log.Engine.Set(indexByte, append(valueByte, command...)) 184 | //log.Engine.FlushFile() 185 | log.LastIndex = index 186 | log.LastTerm = term 187 | return index 188 | } 189 | func (log *RaftLog) Has(index Index, term Term) bool { 190 | // Has reports whether the log holds an entry at index with the given term. 191 | // Index 0 with term 0 is the implicit base of an empty log and is never stored, so it must be answered before the lookup. 192 | if index == 0 && term == 0 { 193 | return true 194 | } 195 | entry := log.Get(index) 196 | if entry == nil { 197 | return false 198 | } 199 | return entry.Term == term 200 | } 201 | // Splice merges entries received from the leader into the log: entries already present with the same term are skipped, the rest are written, and any conflicting old tail past them is deleted. It returns the last log index. 202 | func (log *RaftLog) Splice(entries []*Entry) Index { 203 | if len(entries) == 0 { 204 | return log.LastIndex 205 | } 206 | lastIndex, lastTerm := entries[len(entries)-1].Index, entries[len(entries)-1].Term 207 | 208 | // Skip the prefix of entries the log already holds with the same index and term. 209 | for _, existing := range log.Scan(entries[0].Index, lastIndex, true, true) { 210 | if existing.Index != entries[0].Index || existing.Term != entries[0].Term { 211 | break 212 | } 213 | entries = entries[1:] 214 | } 215 | // Nothing new: never truncate on a pure match, since a delayed AppendEntries may carry only a prefix of what the log already holds (Raft §5.3). 216 | if len(entries) == 0 { 217 | if lastIndex > log.LastIndex { // every one of them was found in the engine, so they belong to the log 218 | log.LastIndex, log.LastTerm = lastIndex, lastTerm 219 | } 220 | return log.LastIndex 221 | } 222 | 223 | for _, e := range entries { 224 | valueByte := append(util.BinaryToByte(uint64(e.Term)), e.Command...) 225 | log.Engine.Set(append([]byte{LogKeyPrefix, EntryPrefix}, util.BinaryToByte(uint64(e.Index))...), valueByte) 226 | }
227 | // Delete the old tail past the new last entry: it conflicted with the leader's log. 228 | for index := lastIndex + 1; index <= log.LastIndex; index++ { 229 | log.Engine.Delete(append([]byte{LogKeyPrefix, EntryPrefix}, util.BinaryToByte(uint64(index))...)) 230 | } 231 | //log.Engine.FlushFile() 232 | log.LastIndex, log.LastTerm = lastIndex, lastTerm 233 | return log.LastIndex 234 | } 235 | // Engine is the key/value store the Raft log is persisted in. 236 | type Engine interface { 237 | Delete(key []byte) 238 | Get(key []byte) []byte 239 | Scan(from, to []byte) []*bitcask.ByteMap 240 | ScanPrefix(prefix []byte) []*bitcask.ByteMap 241 | Set(key, value []byte) 242 | Status() *bitcask.Status 243 | FlushFile() 244 | FileName() string 245 | } 246 | // DecodeEntry decodes an entry from a raw engine key (2-byte prefix followed by the index) and value (8-byte term followed by the command). 247 | func DecodeEntry(key, value []byte) *Entry { 248 | 249 | var index, term uint64 250 | util.ByteToInt(key[2:], &index) 251 | util.ByteToInt(value[:8], &term) 252 | 253 | entry := Entry{ 254 | Command: value[8:], 255 | Term: Term(term), 256 | Index: Index(index), 257 | } 258 | return &entry 259 | } 260 | 261 | // DecodeCommitEntry decodes the commit record written by RaftLog.Commit (a serialized Entry). 262 | func DecodeCommitEntry(value []byte) *Entry { 263 | entry := Entry{} 264 | 265 | util.ByteToStruct(value, &entry) 266 | return &entry 267 | } 268 | -------------------------------------------------------------------------------- /log/message.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "bytes" 5 | "cabbageDB/util" 6 | "encoding/gob" 7 | "github.com/google/uuid" 8 | ) 9 | // Message is the envelope exchanged between nodes and clients. 10 | type Message struct { 11 | Term Term 12 | From Address 13 | To Address 14 | Event Event 15 | } 16 | 17 | type Address = []byte 18 | 19 | type RequestID = uuid.UUID 20 | 21 | const ( 22 | AddressPrefix byte = 0x07 23 | BroadcastPrefix byte = 0x02 24 | NodePrefix byte = 0x03 25 | ClientPrefix byte = 0x04 26 | ) 27 | // GetNodeID extracts the node ID from a node address and returns 0 for any other address. 28 | func GetNodeID(addr Address) NodeID { 29 | var id NodeID 30 | if len(addr) > 2 && addr[0] == AddressPrefix && addr[1] == NodePrefix { 31 | util.ByteToInt(addr[2:], &id) 32 | } 33 | return id 34 | } 35 | func SetNodeID(id NodeID) []byte { 36 | return append([]byte{AddressPrefix, NodePrefix}, util.BinaryToByte(id)...)
37 | } 38 | 39 | type Event interface { 40 | event() 41 | } 42 | // HeartBeat is broadcast by the leader with its commit index and commit term. 43 | type HeartBeat struct { 44 | CommitIndex Index 45 | CommitTerm Term 46 | } 47 | 48 | func (*HeartBeat) event() {} 49 | // ConfirmLeader answers a HeartBeat, telling the leader whether the follower holds its commit entry. 50 | type ConfirmLeader struct { 51 | CommitIndex Index 52 | HasCommitted bool 53 | } 54 | 55 | func (*ConfirmLeader) event() {} 56 | // SolicitVote asks for votes, carrying the candidate's last log index and term. 57 | type SolicitVote struct { 58 | LastIndex Index 59 | LastTerm Term 60 | } 61 | 62 | func (*SolicitVote) event() {} 63 | 64 | type GrantVote struct { 65 | } 66 | 67 | func (*GrantVote) event() {} 68 | // AppendEntries carries the leader's entries that follow the base entry (BaseIndex, BaseTerm). 69 | type AppendEntries struct { 70 | BaseIndex Index 71 | BaseTerm Term 72 | Entries []*Entry 73 | } 74 | 75 | func (*AppendEntries) event() {} 76 | // AcceptEntries acknowledges an AppendEntries with the follower's last log index. 77 | type AcceptEntries struct { 78 | LastIndex Index 79 | } 80 | 81 | func (*AcceptEntries) event() {} 82 | // RejectEntries tells the leader that the follower lacks the AppendEntries base entry. 83 | type RejectEntries struct { 84 | } 85 | 86 | func (*RejectEntries) event() {} 87 | 88 | type ClientRequest struct { 89 | ID RequestID 90 | Request Request 91 | } 92 | 93 | func (*ClientRequest) event() {} 94 | 95 | type ClientResponse struct { 96 | ID RequestID 97 | Response Response 98 | } 99 | 100 | func (*ClientResponse) event() {} 101 | 102 | type Request interface { 103 | requestType() []byte 104 | } 105 | 106 | type Response interface { 107 | responseType() []byte 108 | } 109 | 110 | type RaftQuery struct { 111 | Command []byte 112 | } 113 | 114 | func (q *RaftQuery) requestType() []byte { 115 | return q.Command 116 | } 117 | 118 | func (q *RaftQuery) responseType() []byte { 119 | return q.Command 120 | } 121 | 122 | type RaftMutate struct { 123 | Command []byte 124 | } 125 | 126 | func (m *RaftMutate) requestType() []byte { 127 | return m.Command 128 | } 129 | 130 | func (m *RaftMutate) responseType() []byte { 131 | return m.Command 132 | } 133 | 134 | type RaftStatus struct { 135 | Status *NodeStatus 136 | } 137 | 138 | func (s *RaftStatus) requestType() []byte { 139 | return nil 140 | } 141 | 142 | func (s *RaftStatus) responseType() []byte { 143 | return nil 144 | } 145 | 146 |
type RaftError struct { 147 | Errmsg error 148 | } 149 | 150 | func (e *RaftError) responseType() []byte { 151 | return nil 152 | } 153 | 154 | type RaftMessage struct { 155 | Request Request 156 | ResponseTx chan<- Response 157 | } 158 | 159 | func GobReg() { 160 | 161 | gob.Register(&HeartBeat{}) 162 | gob.Register(&ConfirmLeader{}) 163 | gob.Register(&SolicitVote{}) 164 | gob.Register(&GrantVote{}) 165 | gob.Register(&AppendEntries{}) 166 | gob.Register(&AcceptEntries{}) 167 | gob.Register(&RejectEntries{}) 168 | gob.Register(&ClientRequest{}) 169 | gob.Register(&ClientResponse{}) 170 | gob.Register(&RaftQuery{}) 171 | gob.Register(&RaftMutate{}) 172 | gob.Register(&RaftStatus{}) 173 | gob.Register(&RaftError{}) 174 | 175 | // 预先编码保证顺序不变 176 | var buf bytes.Buffer 177 | enc := gob.NewEncoder(&buf) 178 | _ = enc.Encode(&HeartBeat{}) 179 | _ = enc.Encode(&ConfirmLeader{}) 180 | _ = enc.Encode(&SolicitVote{}) 181 | _ = enc.Encode(&GrantVote{}) 182 | _ = enc.Encode(&AppendEntries{}) 183 | _ = enc.Encode(&AppendEntries{}) 184 | _ = enc.Encode(&RejectEntries{}) 185 | _ = enc.Encode(&ClientRequest{}) 186 | _ = enc.Encode(&RaftQuery{}) 187 | _ = enc.Encode(&RaftMutate{}) 188 | _ = enc.Encode(&RaftStatus{}) 189 | _ = enc.Encode(&RaftError{}) 190 | } 191 | -------------------------------------------------------------------------------- /log/node.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "cabbageDB/logger" 5 | "math/rand" 6 | "strconv" 7 | "time" 8 | ) 9 | 10 | type NodeID = uint8 11 | 12 | // 定义选举超时范围 13 | const ( 14 | ELECTION_TIMEOUT_MIN Ticks = 10 15 | ELECTION_TIMEOUT_MAX Ticks = 20 16 | ) 17 | 18 | type Ticks = uint8 19 | 20 | const HEARTBEAT_INTERVAL Ticks = 3 21 | 22 | // randElectionTimeout 生成一个随机的选举超时时间 23 | func randElectionTimeout() Ticks { 24 | // 初始化随机数生成器,确保每次运行程序时产生的序列不同 25 | rand.Seed(time.Now().UnixNano()) 26 | return 
Ticks(rand.Intn(int(ELECTION_TIMEOUT_MAX-ELECTION_TIMEOUT_MIN)) + int(ELECTION_TIMEOUT_MIN)) 27 | } 28 | 29 | type NodeInfo struct { 30 | ID NodeID 31 | Peers map[NodeID]struct{} 32 | Term Term 33 | Log *RaftLog 34 | NodeTx chan<- Message 35 | StateTx chan<- Instruction 36 | } 37 | 38 | type Node interface { 39 | Step(message Message) Node 40 | Tick() Node 41 | Info() *NodeInfo 42 | } 43 | 44 | func (info *NodeInfo) Quorum() int { 45 | return (len(info.Peers)+1)/2 + 1 46 | } 47 | 48 | func (info *NodeInfo) Send(to Address, event Event) { 49 | msg := Message{ 50 | Term: info.Term, 51 | From: SetNodeID(info.ID), 52 | To: to, 53 | Event: event, 54 | } 55 | info.NodeTx <- msg 56 | 57 | } 58 | 59 | func (info *NodeInfo) BecomeFollower(term Term, leader NodeID, votedFor *NodeID) Node { 60 | if leader != 0 { 61 | logger.Info("Lost election, following leader " + strconv.Itoa(int(leader)) + " in term " + strconv.Itoa(int(term))) 62 | id := info.ID 63 | if votedFor != nil { 64 | id = *votedFor 65 | } 66 | return NewFollower(leader, id, info) 67 | } else { 68 | info.Term = term 69 | info.Log.SetTerm(term, 0) 70 | return NewFollower(0, 0, info) 71 | } 72 | } 73 | 74 | func (info *NodeInfo) BecomeLeader() Node { 75 | last_index, _ := info.Log.GetLastIndex() 76 | node := NewLeader(info.Peers, last_index, info) 77 | node.HeartBeat() 78 | node.Propose([]byte{}) 79 | return node 80 | } 81 | 82 | func NewNode(id NodeID, peers map[NodeID]struct{}, log *RaftLog, state RaftTxnState, nodeTx chan<- Message) Node { 83 | stateChan := make(chan Instruction) 84 | driver := NewDriver(id, stateChan, nodeTx) 85 | driver.ApplyLog(state, log) 86 | go driver.Drive(state) 87 | term, votedFor := log.GetTerm() 88 | 89 | nodeInfo := NodeInfo{ 90 | ID: id, 91 | Peers: peers, 92 | Term: term, 93 | Log: log, 94 | NodeTx: nodeTx, 95 | StateTx: stateChan, 96 | } 97 | 98 | if len(nodeInfo.Peers) == 0 { 99 | if votedFor != id { 100 | nodeInfo.Term += 1 101 | nodeInfo.Log.SetTerm(nodeInfo.Term, id) 102 | } 
103 | lastIndex, _ := nodeInfo.Log.GetLastIndex() 104 | 105 | return NewLeader(make(map[NodeID]struct{}), lastIndex, &nodeInfo) 106 | } else { 107 | return NewFollower(0, votedFor, &nodeInfo) 108 | } 109 | 110 | } 111 | 112 | type NodeStatus struct { 113 | Server NodeID 114 | Leader NodeID 115 | Term Term 116 | NodeLastIndex map[NodeID]Index 117 | CommitIndex Index 118 | ApplyIndex Index 119 | Storage string 120 | StorageSize uint64 121 | FileName string 122 | } 123 | -------------------------------------------------------------------------------- /log/server.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "cabbageDB/logger" 5 | "cabbageDB/util" 6 | "github.com/google/uuid" 7 | "net" 8 | "time" 9 | ) 10 | 11 | type Server struct { 12 | Node Node 13 | Peers map[NodeID]string 14 | NodeRx <-chan Message 15 | } 16 | 17 | func NewServer(id NodeID, peers map[NodeID]string, log *RaftLog, state RaftTxnState) *Server { 18 | nodechan := make(chan Message, 10) 19 | 20 | peersHashSet := map[NodeID]struct{}{} 21 | for peer, _ := range peers { 22 | peersHashSet[peer] = struct{}{} 23 | } 24 | node := NewNode(id, peersHashSet, log, state, nodechan) 25 | return &Server{ 26 | Node: node, 27 | Peers: peers, 28 | NodeRx: nodechan, 29 | } 30 | } 31 | 32 | func (s *Server) Serve(listener net.Listener, clientRx <-chan RaftMessage) { 33 | tcpin := make(chan Message, 10) 34 | tcpout := make(chan Message, 10) 35 | go s.TcpReceive(listener, tcpin) 36 | go s.TcpSender(s.Peers, tcpout) 37 | go s.EventLoop(s.Node, s.NodeRx, clientRx, tcpin, tcpout) 38 | } 39 | 40 | func (s *Server) EventLoop(node Node, nodeRx <-chan Message, clientRx <-chan RaftMessage, tcpRx <-chan Message, tcpTx chan<- Message) { 41 | timer := time.NewTicker(100 * time.Millisecond) 42 | requests := make(map[string]chan<- Response) 43 | 44 | for { 45 | select { 46 | case <-timer.C: 47 | node = node.Tick() 48 | case msg := <-tcpRx: 49 | // 从另一个节点接收信息 50 | 
node = node.Step(msg) 51 | case msg := <-nodeRx: 52 | if msg.To[0] != AddressPrefix { 53 | continue 54 | } 55 | switch msg.To[1] { 56 | case NodePrefix: 57 | tcpTx <- msg 58 | // 发送给另一个节点 59 | case BroadcastPrefix: 60 | // 广播给config里的节点 61 | tcpTx <- msg 62 | case ClientPrefix: 63 | if event, ok := msg.Event.(*ClientResponse); ok { 64 | uuidStr := event.ID.String() 65 | responseTx, ok1 := requests[uuidStr] 66 | if ok1 { 67 | responseTx <- event.Response 68 | delete(requests, uuidStr) 69 | } 70 | } 71 | } 72 | case clientMsg := <-clientRx: 73 | newUUID := uuid.New() 74 | msg := Message{ 75 | From: []byte{AddressPrefix, ClientPrefix}, 76 | To: SetNodeID(node.Info().ID), 77 | Term: 0, 78 | Event: &ClientRequest{ 79 | ID: newUUID, 80 | Request: clientMsg.Request, 81 | }, 82 | } 83 | node = node.Step(msg) 84 | requests[newUUID.String()] = clientMsg.ResponseTx 85 | } 86 | } 87 | } 88 | 89 | func (s *Server) TcpSender(peers map[NodeID]string, outrx <-chan Message) { 90 | peersTxs := make(map[NodeID]chan<- Message) 91 | for nodeID, addr := range peers { 92 | msgchan := make(chan Message, 10) 93 | peersTxs[nodeID] = msgchan 94 | go s.TcpSendPeer(addr, msgchan) 95 | } 96 | 97 | for msg := range outrx { 98 | 99 | to := []NodeID{} 100 | 101 | if msg.To[0] != AddressPrefix { 102 | continue 103 | } 104 | 105 | switch msg.To[1] { 106 | case BroadcastPrefix: 107 | for nodeID, _ := range peersTxs { 108 | to = append(to, nodeID) 109 | } 110 | 111 | case NodePrefix: 112 | var nodeID NodeID 113 | util.ByteToInt(msg.To[2:], &nodeID) 114 | to = append(to, nodeID) 115 | } 116 | 117 | for _, id := range to { 118 | peersTxs[id] <- msg 119 | } 120 | } 121 | } 122 | 123 | func (s *Server) TcpSendPeer(addr string, outrx <-chan Message) { 124 | 125 | var conn net.Conn 126 | var err error 127 | for msg := range outrx { 128 | for i := 0; i < 3; i++ { 129 | conn, err = net.Dial("tcp", addr) 130 | 131 | if err != nil { 132 | time.Sleep(1 * time.Second) 133 | conn = nil 134 | continue 135 | } 136 | 
137 | defer conn.Close() 138 | 139 | msgByte := util.BinaryStructToByte(&msg) 140 | err = util.SendPrefixMsg(conn, [2]byte{AddressPrefix, NodePrefix}, msgByte) 141 | if err == nil { 142 | break 143 | } 144 | } 145 | } 146 | 147 | } 148 | 149 | func (s *Server) TcpReceive(listener net.Listener, tcpintx chan<- Message) { 150 | for { 151 | conn, err := listener.Accept() 152 | if err != nil { 153 | logger.Info("TcpReceive err:" + err.Error()) 154 | } 155 | 156 | var preFix [2]byte 157 | conn.Read(preFix[:]) 158 | if preFix[0] != AddressPrefix && preFix[1] != NodePrefix { 159 | conn.Close() 160 | continue 161 | } 162 | msgByte := util.ReceiveMsg(conn) 163 | var message Message 164 | util.ByteToStruct(msgByte, &message) 165 | tcpintx <- message 166 | } 167 | } 168 | -------------------------------------------------------------------------------- /log/state.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "bytes" 5 | "cabbageDB/util" 6 | "github.com/google/btree" 7 | ) 8 | 9 | type RaftTxnState interface { 10 | GetAppliedIndex() uint64 11 | Apply(entry *Entry) ([]byte, error) 12 | Query(command []byte) ([]byte, error) 13 | } 14 | 15 | type Instruction interface { 16 | instruction() 17 | } 18 | 19 | type Abort struct { 20 | } 21 | 22 | func (*Abort) instruction() {} 23 | func (*Abort) responseType() {} 24 | 25 | type Apply struct { 26 | Entry *Entry 27 | } 28 | 29 | func (*Apply) instruction() {} 30 | 31 | type Notify struct { 32 | ID RequestID 33 | Address Address 34 | Index Index 35 | } 36 | 37 | func (*Notify) instruction() { 38 | } 39 | 40 | type Query struct { 41 | ID RequestID 42 | Address Address 43 | Command []byte 44 | Term Term 45 | Index Index 46 | Quorum uint64 47 | } 48 | 49 | func (*Query) instruction() { 50 | } 51 | 52 | type Status struct { 53 | ID RequestID 54 | Address Address 55 | Statue *NodeStatus 56 | } 57 | 58 | func (*Status) instruction() { 59 | } 60 | 61 | type Vote struct { 
62 | Term Term 63 | Index Index 64 | Address Address 65 | } 66 | 67 | func (*Vote) instruction() { 68 | } 69 | 70 | type AddressValue struct { 71 | ToAddress Address 72 | ID RequestID 73 | } 74 | 75 | type Driver struct { 76 | NodeID NodeID 77 | StateRx <-chan Instruction 78 | NodeTX chan<- Message 79 | Notify map[Index]*AddressValue 80 | Queries Queries 81 | } 82 | type DriverQuery struct { 83 | ID RequestID 84 | Term Term 85 | Address Address 86 | Command []byte 87 | Quorum uint64 88 | Votes map[NodeID]struct{} 89 | } 90 | 91 | type Queries struct { 92 | Tree *btree.BTree 93 | } 94 | 95 | type QueriesItem struct { 96 | Key Index 97 | Value *btree.BTree 98 | } 99 | 100 | func (q *QueriesItem) Less(than btree.Item) bool { 101 | other := than.(*QueriesItem) 102 | return q.Key < other.Key 103 | } 104 | 105 | type QueriesItemValueItem struct { 106 | ID RequestID 107 | Query DriverQuery 108 | } 109 | 110 | func (q *QueriesItemValueItem) Less(than btree.Item) bool { 111 | other := than.(*QueriesItemValueItem) 112 | qid := ([16]byte)(q.ID) 113 | otherid := ([16]byte)(other.ID) 114 | return bytes.Compare(qid[:], otherid[:]) < 0 115 | } 116 | 117 | func (driver *Driver) NotifyApplied(appliedIndex Index, result []byte) { 118 | if v, ok := driver.Notify[appliedIndex]; ok { 119 | clientResp := ClientResponse{ 120 | ID: v.ID, 121 | Response: &RaftMutate{Command: result}, 122 | } 123 | driver.Send(v.ToAddress, &clientResp) 124 | } 125 | 126 | } 127 | 128 | func (driver *Driver) NotifyAppliedError(appliedIndex Index, err error) { 129 | if v, ok := driver.Notify[appliedIndex]; ok { 130 | clientResp := ClientResponse{ 131 | ID: v.ID, 132 | Response: &RaftError{Errmsg: err}, 133 | } 134 | driver.Send(v.ToAddress, &clientResp) 135 | } 136 | 137 | } 138 | 139 | func (driver *Driver) QueryRead(appliedIndex Index) []*DriverQuery { 140 | ready := []*DriverQuery{} 141 | empty := []*QueriesItem{} 142 | 143 | driver.Queries.Tree.Ascend(func(i btree.Item) bool { 144 | queryItem := 
i.(*QueriesItem) 145 | if queryItem.Key > appliedIndex { 146 | return false 147 | } 148 | queryItem.Value.Ascend(func(i btree.Item) bool { 149 | valueItem := i.(*QueriesItemValueItem) 150 | 151 | if uint64(len(valueItem.Query.Votes)) >= valueItem.Query.Quorum { 152 | ready = append(ready, &valueItem.Query) 153 | } 154 | 155 | if len(ready) > 0 { 156 | empty = append(empty, queryItem) 157 | } 158 | 159 | return true 160 | }) 161 | 162 | return false 163 | }) 164 | for _, queryItem := range empty { 165 | driver.Queries.Tree.Delete(queryItem) 166 | } 167 | 168 | return ready 169 | } 170 | 171 | // 执行查询 172 | func (driver *Driver) QueryExecute(state RaftTxnState) { 173 | for _, query := range driver.QueryRead(Index(state.GetAppliedIndex())) { 174 | result, err := state.Query(query.Command) 175 | event := ClientResponse{ 176 | ID: query.ID, 177 | Response: &RaftQuery{Command: result}, 178 | } 179 | if err != nil { 180 | event.Response = &RaftError{Errmsg: err} 181 | } 182 | 183 | driver.Send(query.Address, &event) 184 | } 185 | } 186 | 187 | func (driver *Driver) Apply(state RaftTxnState, entry *Entry) Index { 188 | result, err := state.Apply(entry) 189 | if err != nil { 190 | driver.NotifyAppliedError(Index(state.GetAppliedIndex()), err) 191 | driver.QueryExecute(state) 192 | return Index(state.GetAppliedIndex()) 193 | } 194 | driver.NotifyApplied(Index(state.GetAppliedIndex()), result) 195 | driver.QueryExecute(state) 196 | return Index(state.GetAppliedIndex()) 197 | } 198 | 199 | func (driver *Driver) Send(to Address, event Event) { 200 | msg := Message{ 201 | From: SetNodeID(driver.NodeID), 202 | To: to, 203 | Term: 0, 204 | Event: event, 205 | } 206 | driver.NodeTX <- msg 207 | } 208 | 209 | func (driver *Driver) ApplyLog(state RaftTxnState, log *RaftLog) Index { 210 | appliedIndex := Index(state.GetAppliedIndex()) 211 | commitIndex, _ := log.GetCommitIndex() 212 | 213 | if appliedIndex > commitIndex { 214 | panic("applied index above commit index") 215 | } 216 | 
if appliedIndex < commitIndex { 217 | entryList := log.Scan(appliedIndex, commitIndex, false, true) 218 | 219 | for _, entry := range entryList { 220 | driver.Apply(state, entry) 221 | } 222 | 223 | } 224 | 225 | return Index(state.GetAppliedIndex()) 226 | } 227 | 228 | func (driver *Driver) Drive(state RaftTxnState) { 229 | for instruction := range driver.StateRx { 230 | driver.Execute(instruction, state) 231 | } 232 | } 233 | 234 | func (driver *Driver) Execute(i Instruction, state RaftTxnState) { 235 | switch v := i.(type) { 236 | case *Abort: 237 | driver.NotifyAbort() 238 | driver.QueryAbort() 239 | case *Apply: 240 | driver.Apply(state, v.Entry) 241 | case *Notify: 242 | if v.Index > Index(state.GetAppliedIndex()) { 243 | driver.Notify[v.Index] = &AddressValue{ 244 | ToAddress: v.Address, 245 | ID: v.ID, 246 | } 247 | } else { 248 | event := ClientResponse{ 249 | ID: v.ID, 250 | Response: &RaftError{}, 251 | } 252 | driver.Send(v.Address, &event) 253 | } 254 | case *Query: 255 | 256 | queriesItemValue := btree.New(2) 257 | queriesItemValue.ReplaceOrInsert(&QueriesItemValueItem{ 258 | ID: v.ID, 259 | Query: DriverQuery{ 260 | ID: v.ID, 261 | Term: v.Term, 262 | Address: v.Address, 263 | Quorum: v.Quorum, 264 | Votes: make(map[NodeID]struct{}), 265 | Command: v.Command, 266 | }, 267 | }) 268 | driver.Queries.Tree.ReplaceOrInsert(&QueriesItem{ 269 | Key: v.Index, 270 | Value: queriesItemValue, 271 | }) 272 | case *Status: 273 | v.Statue.ApplyIndex = Index(state.GetAppliedIndex()) 274 | event := ClientResponse{ 275 | ID: v.ID, 276 | Response: &RaftStatus{ 277 | Status: v.Statue, 278 | }, 279 | } 280 | driver.Send(v.Address, &event) 281 | case *Vote: 282 | driver.QueryVote(v.Term, v.Index, v.Address) 283 | driver.QueryExecute(state) 284 | } 285 | } 286 | 287 | // 某地址对某一任期内截止并包含指定提交索引的查询所投的票数 288 | func (driver *Driver) QueryVote(term Term, commitIndex Index, address Address) { 289 | driver.Queries.Tree.Ascend(func(i btree.Item) bool { 290 | queryies := 
i.(*QueriesItem) 291 | if queryies.Key != commitIndex { 292 | return false 293 | } 294 | queryies.Value.Ascend(func(i btree.Item) bool { 295 | valueItem := i.(*QueriesItemValueItem) 296 | if term >= valueItem.Query.Term { 297 | if address[0] != AddressPrefix || address[1] != NodePrefix { 298 | return false 299 | } 300 | var nodeID uint8 301 | util.ByteToInt(address[2:], &nodeID) 302 | valueItem.Query.Votes[nodeID] = struct{}{} 303 | } 304 | return true 305 | }) 306 | return true 307 | }) 308 | 309 | } 310 | 311 | func (driver *Driver) NotifyAbort() { 312 | for _, v := range driver.Notify { 313 | event := ClientResponse{ 314 | ID: v.ID, 315 | Response: &RaftError{}, 316 | } 317 | driver.Send(v.ToAddress, &event) 318 | } 319 | } 320 | 321 | func (driver *Driver) QueryAbort() { 322 | driver.Queries.Tree.Ascend(func(i btree.Item) bool { 323 | queryItem := i.(*QueriesItem) 324 | queryItem.Value.Ascend(func(i btree.Item) bool { 325 | valueItem := i.(*QueriesItemValueItem) 326 | event := ClientResponse{ 327 | ID: valueItem.ID, 328 | Response: &RaftError{}, 329 | } 330 | driver.Send(valueItem.Query.Address, &event) 331 | return true 332 | }) 333 | return true 334 | }) 335 | } 336 | 337 | func NewDriver(nodeID NodeID, stateRx <-chan Instruction, nodeTX chan<- Message) *Driver { 338 | 339 | queries := btree.New(2) 340 | return &Driver{ 341 | NodeID: nodeID, 342 | StateRx: stateRx, 343 | NodeTX: nodeTX, 344 | Notify: make(map[Index]*AddressValue), 345 | Queries: Queries{ 346 | Tree: queries, 347 | }, 348 | } 349 | } 350 | -------------------------------------------------------------------------------- /logger/logger.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "go.uber.org/zap" 5 | "go.uber.org/zap/zapcore" 6 | "os" 7 | "path/filepath" 8 | ) 9 | 10 | var log *zap.SugaredLogger 11 | 12 | func InitLogger(cfgID string, cfgLogLevel string) { 13 | 14 | logDir := filepath.Join(".", "logs") 15 | if err 
:= os.MkdirAll(logDir, 0755); err != nil { 16 | log.Fatalf("failed to create logs directory: %v", err) 17 | } 18 | 19 | // 创建日志文件 20 | file, err := os.OpenFile("./logs/server_"+cfgID+".log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) 21 | if err != nil { 22 | panic("failed to open log file: " + err.Error()) 23 | } 24 | 25 | // 正确初始化 EncoderConfig 26 | config := zap.NewProductionEncoderConfig() 27 | config.EncodeTime = zapcore.ISO8601TimeEncoder 28 | 29 | // 创建编码器 30 | consoleEncoder := zapcore.NewConsoleEncoder(config) 31 | fileEncoder := zapcore.NewJSONEncoder(config) 32 | 33 | // 获取日志级别 34 | logLevel := getLoggerLevel(cfgLogLevel) 35 | atomicLevel := zap.NewAtomicLevelAt(logLevel) 36 | 37 | // 正确构造 Tee 核心 38 | core := zapcore.NewTee( 39 | zapcore.NewCore(consoleEncoder, zapcore.AddSync(os.Stdout), atomicLevel), 40 | zapcore.NewCore(fileEncoder, zapcore.AddSync(file), atomicLevel), 41 | ) 42 | 43 | logger := zap.New(core, zap.AddCaller(), zap.AddCallerSkip(2)) 44 | log = logger.Sugar() 45 | } 46 | 47 | func getLoggerLevel(lvl string) zapcore.Level { 48 | 49 | switch lvl { 50 | case "debug": 51 | return zap.DebugLevel 52 | case "info": 53 | return zap.InfoLevel 54 | case "warn": 55 | return zap.WarnLevel 56 | case "error": 57 | return zap.ErrorLevel 58 | case "panic": 59 | return zap.PanicLevel 60 | case "dpanic": 61 | return zap.DPanicLevel 62 | case "fatal": 63 | return zap.FatalLevel 64 | } 65 | 66 | return zapcore.InfoLevel 67 | } 68 | 69 | func Debug(args ...interface{}) { 70 | log.Debug(args...) 71 | } 72 | 73 | func Debugf(format string, args ...interface{}) { 74 | log.Debugf(format, args...) 75 | } 76 | 77 | func Info(args ...interface{}) { 78 | log.Info(args...) 79 | } 80 | 81 | func Infof(format string, args ...interface{}) { 82 | log.Infof(format, args...) 83 | } 84 | 85 | func Warn(args ...interface{}) { 86 | log.Warn(args...) 87 | } 88 | 89 | func Warnf(format string, args ...interface{}) { 90 | log.Warnf(format, args...) 
91 | } 92 | 93 | func Error(args ...interface{}) { 94 | log.Error(args...) 95 | } 96 | 97 | func Errorf(format string, args ...interface{}) { 98 | log.Errorf(format, args...) 99 | } 100 | 101 | func DPanic(args ...interface{}) { 102 | log.DPanic(args...) 103 | } 104 | 105 | func DPanicf(format string, args ...interface{}) { 106 | log.DPanicf(format, args...) 107 | } 108 | 109 | func Panic(args ...interface{}) { 110 | log.Panic(args...) 111 | } 112 | 113 | func Panicf(format string, args ...interface{}) { 114 | log.Panicf(format, args...) 115 | } 116 | 117 | func Fatal(args ...interface{}) { 118 | log.Fatal(args...) 119 | } 120 | 121 | func Fatalf(format string, args ...interface{}) { 122 | log.Fatalf(format, args...) 123 | } 124 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "cabbageDB/bitcask" 5 | "cabbageDB/gobReg" 6 | "cabbageDB/log" 7 | "cabbageDB/logger" 8 | server2 "cabbageDB/server" 9 | "cabbageDB/sql/engine" 10 | "flag" 11 | "fmt" 12 | "github.com/spf13/viper" 13 | "path/filepath" 14 | "strconv" 15 | ) 16 | 17 | func main() { 18 | 19 | configFile := flag.String("config", "config/db.yaml", "Configuration file path") 20 | flag.Parse() 21 | if *configFile != "" { 22 | fmt.Printf("config is: %s\n", *configFile) 23 | } else { 24 | fmt.Println("No configuration file provided.") 25 | } 26 | 27 | root := filepath.Dir(*configFile) 28 | 29 | cfg := LoadConfig(*configFile) 30 | if cfg.ID == 0 { 31 | panic("id not allow equal 0") 32 | } 33 | gobReg.GobRegMain() 34 | 35 | logger.InitLogger(strconv.Itoa(int(cfg.ID)), cfg.LogLevel) 36 | 37 | var raftLog *log.RaftLog 38 | if cfg.StorageRaft == "bitcask" { 39 | raftLog = log.NewRaftLog(bitcask.NewCompact(filepath.Join(root, cfg.DataDir, "log"), cfg.CompactThresh)) 40 | } else { 41 | panic(fmt.Sprintf("Unknown Raft storage engine %s", cfg.StorageRaft)) 42 | } 43 
| 44 | var raftState log.RaftTxnState 45 | if cfg.StorageRaft == "bitcask" { 46 | stateEngine := bitcask.NewCompact(filepath.Join(root, cfg.DataDir, "state"), cfg.CompactThresh) 47 | raftState = engine.NewState(stateEngine) 48 | } else { 49 | panic(fmt.Sprintf("Unknown Raft storage engine %s", cfg.StorageRaft)) 50 | } 51 | 52 | ser := server2.NewServer(cfg.ID, cfg.Peers, raftLog, raftState) 53 | ser.Listen(cfg.ListenSQL, cfg.ListenRaft) 54 | ser.Serve() 55 | } 56 | 57 | type Config struct { 58 | ID log.NodeID `json:"id" mapstructure:"id"` 59 | Peers map[log.NodeID]string `json:"peers" mapstructure:"peers"` 60 | ListenSQL string `json:"listen_sql" mapstructure:"listen_sql"` 61 | ListenRaft string `json:"listen_raft" mapstructure:"listen_raft"` 62 | LogLevel string `json:"log_level" mapstructure:"log_level"` 63 | DataDir string `json:"data_dir" mapstructure:"data_dir"` 64 | CompactThresh float64 `json:"compact_threshold" mapstructure:"compact_threshold"` 65 | StorageRaft string `json:"storage_raft" mapstructure:"storage_raft"` 66 | StorageSQL string `json:"storage_sql" mapstructure:"storage_sql"` 67 | } 68 | 69 | func DefaultConfig() *Config { 70 | return &Config{ 71 | ID: 1, 72 | Peers: map[log.NodeID]string{}, 73 | ListenSQL: "0.0.0.0:9605", 74 | ListenRaft: "0.0.0.0:9705", 75 | LogLevel: "INFO", 76 | DataDir: "data", 77 | CompactThresh: 0.2, 78 | StorageRaft: "bitcask", 79 | StorageSQL: "bitcask", 80 | } 81 | } 82 | 83 | func LoadConfig(configFile string) *Config { 84 | viperCfg := viper.New() 85 | viperCfg.AddConfigPath(".") 86 | viperCfg.SetConfigFile(configFile) 87 | if err := viperCfg.ReadInConfig(); err != nil { 88 | fmt.Println("Read Config error:", err.Error()) 89 | } 90 | 91 | config := DefaultConfig() 92 | if err := viperCfg.Unmarshal(config); err != nil { 93 | fmt.Println(err) 94 | return DefaultConfig() 95 | } 96 | return config 97 | } 98 | -------------------------------------------------------------------------------- /server/server.go: 
-------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "bytes" 5 | "cabbageDB/log" 6 | "cabbageDB/logger" 7 | "cabbageDB/sql" 8 | "cabbageDB/sql/catalog" 9 | "cabbageDB/sql/engine" 10 | "cabbageDB/util" 11 | "encoding/gob" 12 | "errors" 13 | "net" 14 | "strings" 15 | ) 16 | 17 | const ( 18 | ClientPrefix = 0x08 19 | ExecutePrefix = 0x02 20 | GetTablePrefix = 0x03 21 | ListTablesPrefix = 0x04 22 | StatusPrefix = 0x05 23 | RespErrPrefix = 0x06 24 | ) 25 | 26 | type Server struct { 27 | Raft *log.Server 28 | RaftListener net.Listener 29 | SQLListener net.Listener 30 | } 31 | 32 | func NewServer(id log.NodeID, peers map[log.NodeID]string, raftLog *log.RaftLog, state log.RaftTxnState) *Server { 33 | return &Server{ 34 | Raft: log.NewServer(id, peers, raftLog, state), 35 | } 36 | } 37 | 38 | func (s *Server) Listen(sqlAddr string, raftAddr string) { 39 | sqlListener, err := net.Listen("tcp", sqlAddr) 40 | if err != nil { 41 | panic("listen sqladdr " + sqlAddr + " err:" + err.Error()) 42 | } 43 | raftLitener, err := net.Listen("tcp", raftAddr) 44 | if err != nil { 45 | panic("listen raftaddr" + raftAddr + " err:" + err.Error()) 46 | } 47 | s.SQLListener = sqlListener 48 | s.RaftListener = raftLitener 49 | 50 | } 51 | 52 | func (s *Server) Serve() { 53 | raftChan := make(chan log.RaftMessage, 10) 54 | sqlEngine := engine.NewRaft(raftChan) 55 | 56 | s.Raft.Serve(s.RaftListener, raftChan) 57 | s.ServeSQL(s.SQLListener, sqlEngine) 58 | } 59 | 60 | func (s *Server) ServeSQL(listener net.Listener, engine *engine.Raft) { 61 | for { 62 | conn, err := listener.Accept() 63 | if err != nil { 64 | logger.Info("ServeSQL err:" + err.Error()) 65 | } 66 | 67 | session := NewSession(engine) 68 | go session.Handle(conn) 69 | 70 | } 71 | } 72 | 73 | type RaftRequest interface { 74 | requestType() 75 | } 76 | type Execute struct { 77 | Data string 78 | } 79 | 80 | func (*Execute) requestType() {} 81 | 82 | type GetTable struct 
{ 83 | Data string 84 | } 85 | 86 | func (*GetTable) requestType() {} 87 | 88 | type RespError struct { 89 | Errmsg string 90 | } 91 | 92 | func (*RespError) requestType() {} 93 | func (*RespError) responseType() {} 94 | 95 | type ListTables struct { 96 | Data []string 97 | } 98 | 99 | func (*ListTables) requestType() {} 100 | func (*ListTables) responseType() {} 101 | 102 | type Status struct { 103 | Data *catalog.Status 104 | } 105 | 106 | func (*Status) requestType() {} 107 | func (*Status) responseType() {} 108 | 109 | type RaftResponse interface { 110 | responseType() 111 | } 112 | 113 | type ExecuteResp struct { 114 | Data catalog.ResultSet 115 | } 116 | 117 | func (*ExecuteResp) responseType() {} 118 | 119 | type Row struct { 120 | Data []*sql.ValueData 121 | } 122 | 123 | func (*Row) responseType() {} 124 | 125 | type GetTableResp struct { 126 | Data *catalog.Table 127 | } 128 | 129 | func (*GetTableResp) responseType() {} 130 | 131 | type ClientSession struct { 132 | Engine *engine.Raft 133 | SQL *catalog.Session 134 | } 135 | 136 | func NewSession(engine *engine.Raft) *ClientSession { 137 | return &ClientSession{ 138 | Engine: engine, 139 | SQL: engine.Session(), 140 | } 141 | } 142 | 143 | func (s *ClientSession) Handle(conn net.Conn) { 144 | defer conn.Close() 145 | for { 146 | resp, err := s.Request(conn) 147 | if _, ok := err.(*NetConn); ok { 148 | break 149 | } 150 | 151 | prefix := [2]byte{ClientPrefix} 152 | respByte := []byte{} 153 | 154 | if err != nil { 155 | prefix[1] = RespErrPrefix 156 | respByte = util.BinaryStructToByte(&RespError{Errmsg: err.Error()}) 157 | util.SendPrefixMsg(conn, prefix, respByte) 158 | continue 159 | } 160 | 161 | switch v := resp.(type) { 162 | case *ExecuteResp: 163 | prefix[1] = ExecutePrefix 164 | respByte = util.BinaryStructToByte(v) 165 | case *GetTableResp: 166 | prefix[1] = GetTablePrefix 167 | respByte = util.BinaryStructToByte(v) 168 | case *ListTables: 169 | prefix[1] = ListTablesPrefix 170 | respByte = 
util.BinaryStructToByte(v) 171 | case *Status: 172 | prefix[1] = StatusPrefix 173 | respByte = util.BinaryStructToByte(v) 174 | } 175 | 176 | util.SendPrefixMsg(conn, prefix, respByte) 177 | } 178 | 179 | } 180 | 181 | type NetConn struct { 182 | Err error 183 | } 184 | 185 | func (n *NetConn) Error() string { 186 | return n.Err.Error() 187 | } 188 | 189 | func (s *ClientSession) Request(conn net.Conn) (RaftResponse, error) { 190 | 191 | var prefix [2]byte 192 | _, err := conn.Read(prefix[:]) 193 | if err != nil { 194 | return nil, &NetConn{ 195 | Err: err, 196 | } 197 | } 198 | 199 | if prefix[0] != ClientPrefix { 200 | return nil, errors.New("conn protocol validation failed: invalid packet header") 201 | } 202 | 203 | switch prefix[1] { 204 | case ExecutePrefix: 205 | execute := Execute{} 206 | reqByte := util.ReceiveMsg(conn) 207 | util.ByteToStruct(reqByte, &execute) 208 | sqlStr := strings.Split(execute.Data, ";")[0] 209 | result, err1 := s.SQL.Execute(sqlStr + ";") 210 | if err1 != nil { 211 | return nil, err1 212 | } 213 | return &ExecuteResp{ 214 | Data: result, 215 | }, nil 216 | 217 | case GetTablePrefix: 218 | tableNameByte := util.ReceiveMsg(conn) 219 | tableStr := string(tableNameByte) 220 | table, err1 := s.SQL.ReadWithTxn(func(txn catalog.Transaction) (any, error) { 221 | return txn.MustReadTable(strings.ToLower(tableStr)) 222 | }) 223 | if err1 != nil { 224 | return nil, err1 225 | } 226 | return &GetTableResp{ 227 | Data: table.(*catalog.Table), 228 | }, nil 229 | case ListTablesPrefix: 230 | alltables, _ := s.SQL.ReadWithTxn(func(txn catalog.Transaction) (any, error) { 231 | return txn.ScanTables(), nil 232 | }) 233 | var tables []string 234 | for _, table := range alltables.([]*catalog.Table) { 235 | tables = append(tables, table.Name) 236 | } 237 | return &ListTables{ 238 | Data: tables, 239 | }, nil 240 | case StatusPrefix: 241 | return &Status{ 242 | Data: s.Engine.Status(), 243 | }, nil 244 | } 245 | 246 | return nil, errors.New("conn 
protocol validation failed: invalid packet header") 247 | 248 | } 249 | 250 | func GobReg() { 251 | gob.Register(&ListTables{}) 252 | gob.Register(&Execute{}) 253 | gob.Register(&ExecuteResp{}) 254 | gob.Register(&Status{}) 255 | gob.Register(&GetTable{}) 256 | gob.Register(&Row{}) 257 | 258 | var buf bytes.Buffer 259 | enc := gob.NewEncoder(&buf) 260 | _ = enc.Encode(&ListTables{}) 261 | _ = enc.Encode(&Execute{}) 262 | _ = enc.Encode(&ExecuteResp{}) 263 | _ = enc.Encode(&Status{}) 264 | _ = enc.Encode(&GetTable{}) 265 | _ = enc.Encode(&Row{}) 266 | 267 | } 268 | -------------------------------------------------------------------------------- /sql/catalog/aggregation.go: -------------------------------------------------------------------------------- 1 | package catalog 2 | 3 | import ( 4 | "bytes" 5 | "cabbageDB/logger" 6 | "cabbageDB/sql" 7 | "errors" 8 | "fmt" 9 | "strconv" 10 | "strings" 11 | ) 12 | 13 | type AggregationExec struct { 14 | Source Executor 15 | Aggregates []Aggregate 16 | Accumulators map[string][]Accumulator 17 | } 18 | 19 | type Accumulator interface { 20 | Accumulate(value *sql.ValueData) 21 | Aggregate() *sql.ValueData 22 | } 23 | 24 | func (a *AggregationExec) Execute(txn Transaction) (ResultSet, error) { 25 | aggCount := len(a.Aggregates) 26 | 27 | resultSet, err := a.Source.Execute(txn) 28 | if err != nil { 29 | return nil, err 30 | } 31 | 32 | switch v := resultSet.(type) { 33 | case *QueryResultSet: 34 | if len(v.Rows) == 0 && aggCount > 0 { 35 | // 创建默认空分组键的累加器 36 | keyStr := "" // 空分组键的特殊标识 37 | accs := make([]Accumulator, aggCount) 38 | for i, agg := range a.Aggregates { 39 | accs[i] = NewAccumulator(agg) // 初始化所有聚合器 40 | } 41 | a.Accumulators[keyStr] = accs 42 | } else { 43 | // 正常处理每一行数据 44 | for i := range v.Rows { 45 | // 防御性检查:确保行数据足够分割 46 | if len(v.Rows[i]) < aggCount { 47 | return nil, errors.New("column len err") 48 | } 49 | // 分割当前行的聚合参数和分组键 50 | aggValues := v.Rows[i][:aggCount] // 前N列是聚合函数的输入值 51 | groupKey := 
v.Rows[i][aggCount:] // 后续列是分组键 52 | // 修复点2: 使用确定性编码代替JSON序列化 53 | keyStr := encodeGroupKey(groupKey) // 自定义编码保证唯一性 54 | // 获取或创建当前分组的累加器 55 | accs, exists := a.Accumulators[keyStr] 56 | if !exists { 57 | accs = make([]Accumulator, aggCount) 58 | for i, agg := range a.Aggregates { 59 | accs[i] = NewAccumulator(agg) // 根据聚合类型初始化 60 | } 61 | a.Accumulators[keyStr] = accs 62 | } 63 | for i2, acc := range accs { 64 | acc.Accumulate(aggValues[i2]) 65 | } 66 | } 67 | } 68 | // 构建结果列 69 | columns := make([]string, len(v.Columns)) 70 | for i := range columns { 71 | if i < aggCount { 72 | agg := a.Aggregates[i] 73 | columns[i] = fmt.Sprintf("%s(%s)", agg.String(), v.Columns[i]) 74 | } else { 75 | columns[i] = v.Columns[i] // 保留分组列原名 76 | } 77 | } 78 | 79 | // 生成最终结果行 80 | var rows [][]*sql.ValueData 81 | for keyStr, accs := range a.Accumulators { 82 | var groupKey []*sql.ValueData 83 | if keyStr != "" { 84 | groupKey = decodeGroupKey(keyStr) 85 | } 86 | 87 | // 构建结果行:聚合结果 + 分组键 88 | row := make([]*sql.ValueData, 0, len(accs)+len(groupKey)) 89 | for _, acc := range accs { 90 | val := acc.Aggregate() // 获取最终聚合值 91 | row = append(row, val) 92 | } 93 | row = append(row, groupKey...) 
94 | rows = append(rows, row) 95 | } 96 | 97 | return &QueryResultSet{ 98 | Columns: columns, 99 | Rows: rows, 100 | }, nil 101 | 102 | } 103 | return nil, errors.New("AggregationExec Invalid return ResultSet") 104 | } 105 | 106 | func encodeKey(k *sql.ValueData) string { 107 | var buf bytes.Buffer 108 | switch k.Type { 109 | case sql.NullType: 110 | buf.WriteString("nNULL|") // n前缀表示null 111 | 112 | case sql.FloatType: 113 | // 保持与String()相同的精度(4位小数) 114 | f := k.Value.(float64) 115 | buf.WriteString(fmt.Sprintf("f%.4f|", f)) // f前缀表示float 116 | 117 | case sql.BoolType: 118 | // 严格对应String()的TRUE/FALSE全大写 119 | if k.Value.(bool) { 120 | buf.WriteString("bTRUE|") // b前缀表示bool 121 | } else { 122 | buf.WriteString("bFALSE|") 123 | } 124 | 125 | case sql.IntType: 126 | // 统一转换为十进制字符串 127 | var num int64 128 | switch n := k.Value.(type) { 129 | case int: 130 | num = int64(n) 131 | case int64: 132 | num = n 133 | case uint64: 134 | num = int64(n) // 可能溢出,但保持与String()一致 135 | } 136 | buf.WriteString(fmt.Sprintf("i%d|", num)) // i前缀表示integer 137 | 138 | case sql.StringType: 139 | // 包含长度防御(防止s3:ab和s2:abc碰撞) 140 | s := k.Value.(string) 141 | buf.WriteString(fmt.Sprintf("s%d:%s|", len(s), s)) // s前缀表示string 142 | 143 | default: 144 | logger.Info("unsupported type: %v", k.Type) 145 | return "" 146 | } 147 | return buf.String() 148 | } 149 | 150 | func encodeGroupKey(keys []*sql.ValueData) string { 151 | var buf bytes.Buffer 152 | for _, k := range keys { 153 | // 根据String()方法逻辑处理各类型 154 | keyStr := encodeKey(k) 155 | buf.WriteString(keyStr) 156 | } 157 | return buf.String() 158 | } 159 | func decodeGroupKey(s string) []*sql.ValueData { 160 | // 步骤1:按分隔符"|"切分编码字符串 161 | parts := strings.Split(s, "|") 162 | var values []*sql.ValueData 163 | 164 | for _, part := range parts { 165 | // 跳过空段(最后一个竖线后的空字符串) 166 | if part == "" { 167 | continue 168 | } 169 | 170 | // 步骤2:验证基础格式 171 | if len(part) < 1 { 172 | logger.Info("invalid encoded part: empty string") 173 | return nil 174 | } 
175 | 176 | // 步骤3:提取类型前缀和数据部分 177 | prefix := part[0] 178 | data := part[1:] 179 | 180 | switch prefix { 181 | // 案例1:Null类型处理 182 | case 'n': 183 | if data != "NULL" { 184 | logger.Info("invalid null encoding: %s", part) 185 | return nil 186 | } 187 | values = append(values, &sql.ValueData{ 188 | Type: sql.NullType, 189 | Value: nil, 190 | }) 191 | 192 | // 案例2:Float类型处理 193 | case 'f': 194 | val, err := strconv.ParseFloat(data, 64) 195 | if err != nil { 196 | logger.Info("float parse error: %s, %v", data, err) 197 | return nil 198 | } 199 | values = append(values, &sql.ValueData{ 200 | Type: sql.FloatType, 201 | Value: val, 202 | }) 203 | 204 | // 案例3:Bool类型处理 205 | case 'b': 206 | switch data { 207 | case "TRUE": 208 | values = append(values, &sql.ValueData{ 209 | Type: sql.BoolType, 210 | Value: true, 211 | }) 212 | case "FALSE": 213 | values = append(values, &sql.ValueData{ 214 | Type: sql.BoolType, 215 | Value: false, 216 | }) 217 | default: 218 | logger.Info("invalid bool value: %s", data) 219 | return nil 220 | } 221 | 222 | // 案例4:Int类型处理 223 | case 'i': 224 | val, err := strconv.ParseInt(data, 10, 64) 225 | if err != nil { 226 | logger.Info("int parse error: %s, %v", data, err) 227 | return nil 228 | } 229 | values = append(values, &sql.ValueData{ 230 | Type: sql.IntType, 231 | Value: val, 232 | }) 233 | 234 | // 案例5:String类型处理(核心难点) 235 | case 's': 236 | // 步骤5.1:切分长度和实际内容 237 | split := strings.SplitN(data, ":", 2) 238 | if len(split) != 2 { 239 | logger.Info("invalid string format: %s", data) 240 | return nil 241 | } 242 | 243 | // 步骤5.2:解析声明长度 244 | length, err := strconv.Atoi(split[0]) 245 | if err != nil { 246 | 247 | logger.Info("invalid length: %s", split[0]) 248 | return nil 249 | } 250 | 251 | // 步骤5.3:验证实际长度 252 | str := split[1] 253 | if len(str) != length { 254 | 255 | logger.Info("length mismatch: declared %d, actual %d", length, len(str)) 256 | return nil 257 | } 258 | 259 | values = append(values, &sql.ValueData{ 260 | Type: 
sql.StringType, 261 | Value: str, 262 | }) 263 | 264 | // 案例6:未知类型处理 265 | default: 266 | logger.Info("unknown prefix: %c", prefix) 267 | return nil 268 | } 269 | } 270 | 271 | return values 272 | } 273 | 274 | func NewAccumulator(agg Aggregate) Accumulator { 275 | switch agg { 276 | case Average: 277 | return &AverageAcc{ 278 | Count: &CountAcc{}, 279 | Sum: &SumAcc{}, 280 | } 281 | case Max: 282 | return &MaxAcc{} 283 | case Min: 284 | return &MinAcc{} 285 | case Count: 286 | return &CountAcc{} 287 | case Sum: 288 | return &SumAcc{} 289 | } 290 | return nil 291 | } 292 | 293 | type AverageAcc struct { 294 | Count *CountAcc 295 | Sum *SumAcc 296 | } 297 | 298 | func (a *AverageAcc) Accumulate(cmp *sql.ValueData) { 299 | a.Count.Accumulate(cmp) 300 | a.Sum.Accumulate(cmp) 301 | } 302 | 303 | func (a *AverageAcc) Aggregate() *sql.ValueData { 304 | sum := a.Sum.Aggregate() 305 | count := a.Count.Aggregate() 306 | if sum.Type == sql.IntType && count.Type == sql.IntType { 307 | value := &sql.ValueData{ 308 | Type: sql.IntType, 309 | } 310 | 311 | switch sum.Value.(type) { 312 | case int64: 313 | value.Value = sum.Value.(int64) / count.Value.(int64) 314 | case int: 315 | value.Value = sum.Value.(int) / count.Value.(int) 316 | 317 | } 318 | return value 319 | } 320 | if sum.Type == sql.FloatType && count.Type == sql.IntType { 321 | 322 | var intV float64 323 | switch count.Value.(type) { 324 | case int64: 325 | intV = float64(count.Value.(int64)) 326 | case int: 327 | intV = float64(count.Value.(int)) 328 | } 329 | 330 | return &sql.ValueData{ 331 | Type: sql.FloatType, 332 | Value: sum.Value.(float64) / intV, 333 | } 334 | } 335 | 336 | return &sql.ValueData{ 337 | Type: sql.NullType, 338 | } 339 | 340 | } 341 | 342 | type MaxAcc struct { 343 | MaxV *sql.ValueData 344 | } 345 | 346 | func (m *MaxAcc) Accumulate(cmp *sql.ValueData) { 347 | if cmp == nil { 348 | return 349 | } 350 | if m.MaxV == nil { 351 | m.MaxV = &sql.ValueData{ 352 | Type: cmp.Type, 353 | Value: 
cmp.Value, 354 | } 355 | return 356 | } 357 | if m.MaxV.Type != cmp.Type { 358 | m.MaxV = nil 359 | return 360 | } 361 | if m.MaxV.Type == sql.IntType { 362 | switch m.MaxV.Value.(type) { 363 | case int64: 364 | maxv := m.MaxV.Value.(int64) 365 | cmpv := cmp.Value.(int64) 366 | if maxv < cmpv { 367 | m.MaxV = cmp 368 | } 369 | return 370 | case int: 371 | maxv := m.MaxV.Value.(int) 372 | cmpv := cmp.Value.(int) 373 | if maxv < cmpv { 374 | m.MaxV = cmp 375 | } 376 | return 377 | 378 | } 379 | 380 | } 381 | 382 | if m.MaxV.Type == sql.FloatType { 383 | maxv := m.MaxV.Value.(float64) 384 | cmpv := cmp.Value.(float64) 385 | if maxv < cmpv { 386 | m.MaxV = cmp 387 | } 388 | return 389 | } 390 | 391 | } 392 | 393 | func (m *MaxAcc) Aggregate() *sql.ValueData { 394 | return m.MaxV 395 | } 396 | 397 | type MinAcc struct { 398 | MinV *sql.ValueData 399 | } 400 | 401 | func (m *MinAcc) Accumulate(cmp *sql.ValueData) { 402 | if cmp == nil { 403 | return 404 | } 405 | if m.MinV == nil { 406 | m.MinV = &sql.ValueData{ 407 | Type: cmp.Type, 408 | Value: cmp.Value, 409 | } 410 | return 411 | } 412 | if m.MinV.Type != cmp.Type { 413 | m.MinV = nil 414 | return 415 | } 416 | if m.MinV.Type == sql.IntType { 417 | switch m.MinV.Value.(type) { 418 | case int64: 419 | minv := m.MinV.Value.(int64) 420 | cmpv := cmp.Value.(int64) 421 | if minv > cmpv { 422 | m.MinV = cmp 423 | } 424 | return 425 | case int: 426 | minv := m.MinV.Value.(int64) 427 | cmpv := cmp.Value.(int64) 428 | if minv > cmpv { 429 | m.MinV = cmp 430 | } 431 | return 432 | } 433 | 434 | } 435 | 436 | if m.MinV.Type == sql.FloatType { 437 | minv := m.MinV.Value.(float64) 438 | cmpv := cmp.Value.(float64) 439 | if minv > cmpv { 440 | m.MinV = cmp 441 | } 442 | return 443 | } 444 | 445 | } 446 | 447 | func (m *MinAcc) Aggregate() *sql.ValueData { 448 | return m.MinV 449 | } 450 | 451 | type SumAcc struct { 452 | SumV *sql.ValueData 453 | } 454 | 455 | func (s *SumAcc) Accumulate(cmp *sql.ValueData) { 456 | if cmp == nil { 
457 | return 458 | } 459 | 460 | if s.SumV == nil && cmp != nil { 461 | s.SumV = &sql.ValueData{ 462 | Type: cmp.Type, 463 | Value: cmp.Value, 464 | } 465 | return 466 | } 467 | 468 | if s.SumV.Type == sql.IntType { 469 | switch s.SumV.Value.(type) { 470 | case int64: 471 | sum := s.SumV.Value.(int64) 472 | cmpv := cmp.Value.(int64) 473 | s.SumV.Value = sum + cmpv 474 | case int: 475 | sum := s.SumV.Value.(int) 476 | cmpv := cmp.Value.(int) 477 | s.SumV.Value = sum + cmpv 478 | } 479 | 480 | return 481 | } 482 | 483 | if s.SumV.Type == sql.FloatType { 484 | sum := s.SumV.Value.(float64) 485 | cmpv := cmp.Value.(float64) 486 | s.SumV.Value = sum + cmpv 487 | return 488 | } 489 | 490 | } 491 | 492 | func (s *SumAcc) Aggregate() *sql.ValueData { 493 | return s.SumV 494 | } 495 | 496 | type CountAcc struct { 497 | CountV int 498 | } 499 | 500 | func (c *CountAcc) Accumulate(cmp *sql.ValueData) { 501 | if cmp == nil { 502 | return 503 | } 504 | if cmp.Type == sql.NullType { 505 | return 506 | } 507 | c.CountV += 1 508 | } 509 | func (c *CountAcc) Aggregate() *sql.ValueData { 510 | return &sql.ValueData{ 511 | Type: sql.IntType, 512 | Value: c.CountV, 513 | } 514 | } 515 | -------------------------------------------------------------------------------- /sql/catalog/execinterface.go: -------------------------------------------------------------------------------- 1 | package catalog 2 | 3 | type Executor interface { 4 | Execute(txc Transaction) (ResultSet, error) 5 | } 6 | 7 | func ExecBuild(node Node) Executor { 8 | switch v := node.(type) { 9 | case *AggregationNode: 10 | return &AggregationExec{ 11 | Source: ExecBuild(v.Source), 12 | Aggregates: v.Aggregates, 13 | Accumulators: make(map[string][]Accumulator), 14 | } 15 | case *CreateTableNode: 16 | return &CreateTableExec{Table: v.Schema} 17 | case *DeleteNode: 18 | return &DeteleExec{Table: v.TableName, Source: ExecBuild(v.Source)} 19 | case *DropTableNode: 20 | return &DropTableExec{Table: v.TableName} 21 | case 
*FilterNode: 22 | return &FilterExec{Source: ExecBuild(v.Source), Predicate: v.Predicate} 23 | case *HashJoinNode: 24 | return &HashJoinExec{ 25 | Left: ExecBuild(v.Left), 26 | LeftField: v.LeftField.Index, 27 | Right: ExecBuild(v.Right), 28 | RightField: v.RightField.Index, 29 | Outer: v.Outer, 30 | } 31 | case *IndexLookupNode: 32 | return &IndexLookupExec{ 33 | Table: v.Table, 34 | Column: v.ColumnName, 35 | Values: v.Values, 36 | } 37 | case *InsertNode: 38 | return &InsertExec{ 39 | Table: v.TableName, 40 | Columns: v.ColumnNames, 41 | Rows: v.Expressions, 42 | } 43 | 44 | case *KeyLookupNode: 45 | return &KeyLookupExec{ 46 | Table: v.TableName, 47 | Keys: v.Keys, 48 | } 49 | case *LimitNode: 50 | return &LimitExec{ 51 | Source: ExecBuild(v.Source), 52 | Limit: v.Limit, 53 | } 54 | case *NestedLoopJoinNode: 55 | return &NestedLoopJoinExec{ 56 | Left: ExecBuild(v.Left), 57 | Right: ExecBuild(v.Right), 58 | Predicate: v.Predicate, 59 | Outer: v.Outer, 60 | } 61 | case *NothingNode: 62 | return &NothingExec{} 63 | case *OffsetNode: 64 | return &OffsetExec{ 65 | Source: ExecBuild(v.Source), 66 | Offset: v.Num, 67 | } 68 | case *OrderNode: 69 | return &OrderExec{ 70 | Source: ExecBuild(v.Source), 71 | Orders: v.Orders, 72 | } 73 | case *ProjectionNode: 74 | return &ProjectionExec{ 75 | Source: ExecBuild(v.Source), 76 | Expressions: v.Expressions, 77 | } 78 | case *ScanNode: 79 | return &ScanExec{ 80 | Table: v.TableName, 81 | Filter: v.Filter, 82 | } 83 | case *UpdateNode: 84 | return &UpdateExec{ 85 | Table: v.TableName, 86 | Source: ExecBuild(v.Source), 87 | Expressions: v.Expressions, 88 | } 89 | } 90 | return nil 91 | } 92 | -------------------------------------------------------------------------------- /sql/catalog/interface.go: -------------------------------------------------------------------------------- 1 | package catalog 2 | 3 | import ( 4 | "bytes" 5 | "cabbageDB/log" 6 | "cabbageDB/sql" 7 | "cabbageDB/sql/expr" 8 | "cabbageDB/sqlparser/ast" 9 | 
// Remaining imports of interface.go; the import block opens before this chunk.
	"cabbageDB/sqlparser/parser"
	"cabbageDB/storage"
	"cabbageDB/util"
	"encoding/gob"
	"errors"
)

// ValueHashSet is a set of values keyed by pointer identity.
type ValueHashSet = map[*sql.ValueData]struct{}

// Engine starts transactions and SQL sessions.
type Engine interface {
	Begin() Transaction
	BeginReadOnly() Transaction
	BeginAsOf(version uint64) Transaction
	Session() *Session
}

// Transaction is a SQL transaction over the catalog and table rows.
type Transaction interface {
	Catalog
	Version() uint64
	ReadOnly() bool
	Commit() bool
	Rollback() bool
	Create(table string, row []*sql.ValueData) error
	Delete(table string, id *sql.ValueData) error
	Read(table string, id *sql.ValueData) ([]*sql.ValueData, error)
	ReadIndex(table string, column string, value *sql.ValueData) (ValueHashSet, error)
	Scan(table string, filter expr.Expression) ([][]*sql.ValueData, error)
	ScanIndex(table string, column string) ([]*IndexValue, error)
	Update(table string, id *sql.ValueData, row []*sql.ValueData) error
}

// Session runs SQL statements and tracks the explicit transaction (Txn), if
// one was opened with BEGIN.
type Session struct {
	Engine Engine
	Txn    Transaction
}

// Execute parses query and runs its first statement. BEGIN, COMMIT and
// ROLLBACK manage the session's explicit transaction. Other statements run
// inside that transaction when one is open. Otherwise they run in a temporary
// transaction: a read-only one (always rolled back) for SELECT, or a
// read-write one (committed on success, rolled back on failure) for
// everything else.
func (s *Session) Execute(query string) (ResultSet, error) {
	p := parser.Parser{}
	p.Reset(query)
	if ret := p.ParseSQL(); ret != 0 {
		if p.ErrorMsg != nil {
			return nil, p.ErrorMsg
		}
		return nil, errors.New("parse sql error")
	}

	for _, result := range p.Result {
		// runIn plans, optimizes and executes the current statement in txn.
		runIn := func(txn Transaction) (ResultSet, error) {
			plan, err := Build(result, txn)
			if err != nil {
				return nil, err
			}
			return plan.Optimize(txn).Execute(txn)
		}

		switch v := result.(type) {
		case *ast.BeginStmt:
			if s.Txn != nil {
				return nil, errors.New("Already in a transaction")
			}
			if !v.ReadOnly {
				if v.AsOf != 0 {
					return nil, errors.New("Can't start read-write transaction in a given version")
				}
				txn := s.Engine.Begin()
				s.Txn = txn
				return &BeginResultSet{Version: txn.Version(), ReadOnly: false}, nil
			}
			if v.AsOf == 0 {
				txn := s.Engine.BeginReadOnly()
				s.Txn = txn
				return &BeginResultSet{Version: txn.Version(), ReadOnly: true}, nil
			}
			s.Txn = s.Engine.BeginAsOf(v.AsOf)
			return &BeginResultSet{Version: v.AsOf, ReadOnly: true}, nil

		case *ast.CommitStmt:
			if s.Txn == nil {
				return nil, errors.New("Not in a transaction")
			}
			txn := s.Txn
			s.Txn = nil
			version := txn.Version()
			txn.Commit()
			return &CommitResultSet{Version: version}, nil

		case *ast.RollbackStmt:
			if s.Txn == nil {
				return nil, errors.New("Not in a transaction")
			}
			txn := s.Txn
			s.Txn = nil
			version := txn.Version()
			txn.Rollback()
			return &RollbackResultSet{Version: version}, nil

		case *ast.ExplainStmt:
			node, err := s.ReadWithTxn(func(txn Transaction) (any, error) {
				plan, err1 := Build(v.Stmt, txn)
				if err1 != nil {
					return nil, err1
				}
				return plan.Optimize(txn).Node, nil
			})
			if err != nil {
				return nil, err
			}
			return &ExplainResultSet{
				NodeInfo: FormatNode(node.(Node), 1),
			}, nil

		case *ast.SelectStmt:
			if s.Txn != nil {
				return runIn(s.Txn)
			}
			// The temporary read-only transaction is released on every path,
			// including planning or execution errors.
			txn := s.Engine.BeginReadOnly()
			resultSet, err := runIn(txn)
			txn.Rollback()
			return resultSet, err

		default:
			if s.Txn != nil {
				return runIn(s.Txn)
			}
			txn := s.Engine.Begin()
			resultSet, err := runIn(txn)
			if err != nil {
				// Also covers Build failures, which previously leaked the transaction.
				txn.Rollback()
				return nil, err
			}
			txn.Commit()
			return resultSet, nil
		}
	}

	return nil, nil
}

// ReadWithTxn runs f in the session's explicit transaction if one is open.
// Otherwise it uses a temporary read-only transaction, which is rolled back
// whether or not f succeeds. On error the result is discarded.
func (s *Session) ReadWithTxn(f func(transaction Transaction) (any, error)) (any, error) {
	if s.Txn != nil {
		return f(s.Txn)
	}
	txn := s.Engine.BeginReadOnly()
	defer txn.Rollback()
	res, err := f(txn)
	if err != nil {
		return nil, err
	}
	return res, nil
}

// IndexValue pairs an indexed value with the set of entries that hold it.
type IndexValue struct {
	Value        interface{}
	ValueHashSet storage.VersionHashSet
}

// Status combines Raft node status and MVCC storage status.
type Status struct {
	RaftStatus *log.NodeStatus
	MVCCStatus *storage.Status
}

// Result-set type tags used on the wire.
// NOTE(review): ResultSetPrefix (the fallback for unknown result sets) shares
// 0x09 with QueryResultSetPrefix, and ExplainResultSetPrefix is 0x10 rather
// than 0x0a. The values are kept unchanged for wire compatibility.
const (
	ResultSetPrefix            = 0x09
	BeginResultSetPrefix       = 0x02
	CommitResultSetPrefix      = 0x03
	RollbackResultSetPrefix    = 0x04
	CreateResultSetPrefix      = 0x05
	DeleteResultSetPrefix      = 0x06
	CreateTableResultSetPrefix = 0x07
	DropTableResultSetPrefix   = 0x08
	QueryResultSetPrefix       = 0x09
	ExplainResultSetPrefix     = 0x10
)

// ResultSet is the result of executing one statement.
type ResultSet interface {
	ResultSetEncode() []byte
}

// GetReusltSetPrefix returns the wire tag for result's concrete type, or
// ResultSetPrefix for unknown types.
func GetReusltSetPrefix(result ResultSet) byte {
	switch result.(type) {
	case *BeginResultSet:
		return BeginResultSetPrefix
	case *CommitResultSet:
		return CommitResultSetPrefix
	case *RollbackResultSet:
		return RollbackResultSetPrefix
	case *CreateResultSet:
		return CreateResultSetPrefix
	case *DeleteResultSet:
		return DeleteResultSetPrefix
	case *CreateTableResultSet:
		return CreateTableResultSetPrefix
	case *DropTableResultSet:
		return DropTableResultSetPrefix
	case *QueryResultSet:
		return QueryResultSetPrefix
	case *ExplainResultSet:
		return ExplainResultSetPrefix
	}
	return ResultSetPrefix
}

type BeginResultSet struct {
	Version
uint64 236 | ReadOnly bool 237 | } 238 | 239 | func (b *BeginResultSet) ResultSetEncode() []byte { 240 | return append([]byte{ResultSetPrefix, BeginResultSetPrefix}, util.BinaryStructToByte(b)...) 241 | } 242 | 243 | type CommitResultSet struct { 244 | Version uint64 245 | } 246 | 247 | func (c *CommitResultSet) ResultSetEncode() []byte { 248 | return append([]byte{ResultSetPrefix, CommitResultSetPrefix}, util.BinaryStructToByte(c)...) 249 | 250 | } 251 | 252 | type RollbackResultSet struct { 253 | Version uint64 254 | } 255 | 256 | func (r *RollbackResultSet) ResultSetEncode() []byte { 257 | return append([]byte{ResultSetPrefix, RollbackResultSetPrefix}, util.BinaryStructToByte(r)...) 258 | } 259 | 260 | type CreateResultSet struct { 261 | Count int 262 | } 263 | 264 | func (c *CreateResultSet) ResultSetEncode() []byte { 265 | return append([]byte{ResultSetPrefix, CreateResultSetPrefix}, util.BinaryStructToByte(c)...) 266 | } 267 | 268 | type DeleteResultSet struct { 269 | Count int 270 | } 271 | 272 | func (d *DeleteResultSet) ResultSetEncode() []byte { 273 | return append([]byte{ResultSetPrefix, DeleteResultSetPrefix}, util.BinaryStructToByte(d)...) 274 | } 275 | 276 | type UpdateResultSet struct { 277 | Count int 278 | } 279 | 280 | func (u *UpdateResultSet) ResultSetEncode() []byte { 281 | return append([]byte{ResultSetPrefix, DeleteResultSetPrefix}, util.BinaryStructToByte(u)...) 282 | } 283 | 284 | type CreateTableResultSet struct { 285 | Name string 286 | } 287 | 288 | func (c *CreateTableResultSet) ResultSetEncode() []byte { 289 | return append([]byte{ResultSetPrefix, CreateTableResultSetPrefix}, util.BinaryStructToByte(c)...) 290 | } 291 | 292 | type DropTableResultSet struct { 293 | Name string 294 | } 295 | 296 | func (d *DropTableResultSet) ResultSetEncode() []byte { 297 | return append([]byte{ResultSetPrefix, DropTableResultSetPrefix}, util.BinaryStructToByte(d)...) 
298 | } 299 | 300 | type QueryResultSet struct { 301 | Columns []string 302 | Rows [][]*sql.ValueData 303 | } 304 | 305 | func (q *QueryResultSet) ResultSetEncode() []byte { 306 | return append([]byte{ResultSetPrefix, QueryResultSetPrefix}, util.BinaryStructToByte(q)...) 307 | 308 | } 309 | 310 | type ExplainResultSet struct { 311 | NodeInfo string 312 | } 313 | 314 | func (e *ExplainResultSet) ResultSetEncode() []byte { 315 | return append([]byte{ResultSetPrefix, ExplainResultSetPrefix}, util.BinaryStructToByte(e)...) 316 | 317 | } 318 | 319 | func GobReg() { 320 | //gob.Register(&ReferenceField{}) 321 | 322 | gob.Register(&BeginResultSet{}) 323 | gob.Register(&CommitResultSet{}) 324 | gob.Register(&RollbackResultSet{}) 325 | gob.Register(&CreateResultSet{}) 326 | gob.Register(&DeleteResultSet{}) 327 | gob.Register(&UpdateResultSet{}) 328 | gob.Register(&CreateTableResultSet{}) 329 | gob.Register(&DropTableResultSet{}) 330 | gob.Register(&QueryResultSet{}) 331 | gob.Register(&ExplainResultSet{}) 332 | gob.Register(&sql.ValueData{}) 333 | gob.Register([][]*sql.ValueData{}) 334 | gob.Register(ValueHashSet{}) 335 | 336 | gob.Register(&AggregationNode{}) 337 | gob.Register(&CreateTableNode{}) 338 | gob.Register(&NothingNode{}) 339 | gob.Register(&DeleteNode{}) 340 | gob.Register(&DropTableNode{}) 341 | gob.Register(&FilterNode{}) 342 | gob.Register(&HashJoinNode{}) 343 | gob.Register(&IndexLookupNode{}) 344 | gob.Register(&InsertNode{}) 345 | gob.Register(&KeyLookupNode{}) 346 | gob.Register(&LimitNode{}) 347 | gob.Register(&NestedLoopJoinNode{}) 348 | gob.Register(&OffsetNode{}) 349 | gob.Register(&OrderNode{}) 350 | gob.Register(&ProjectionNode{}) 351 | gob.Register(&ScanNode{}) 352 | gob.Register(&UpdateNode{}) 353 | 354 | var buf bytes.Buffer 355 | enc := gob.NewEncoder(&buf) 356 | //_ = enc.Encode(&ReferenceField{}) 357 | 358 | _ = enc.Encode(&BeginResultSet{}) 359 | _ = enc.Encode(&CommitResultSet{}) 360 | _ = enc.Encode(&RollbackResultSet{}) 361 | _ = 
enc.Encode(&CreateResultSet{}) 362 | _ = enc.Encode(&DeleteResultSet{}) 363 | _ = enc.Encode(&UpdateResultSet{}) 364 | _ = enc.Encode(&CreateTableResultSet{}) 365 | _ = enc.Encode(&DropTableResultSet{}) 366 | _ = enc.Encode(&QueryResultSet{}) 367 | _ = enc.Encode(&ExplainResultSet{}) 368 | _ = enc.Encode(&sql.ValueData{}) 369 | _ = enc.Encode([][]*sql.ValueData{}) 370 | _ = enc.Encode(ValueHashSet{}) 371 | 372 | _ = enc.Encode(&AggregationNode{}) 373 | _ = enc.Encode(&CreateTableNode{}) 374 | _ = enc.Encode(&NothingNode{}) 375 | _ = enc.Encode(&DeleteNode{}) 376 | _ = enc.Encode(&DropTableNode{}) 377 | _ = enc.Encode(&FilterNode{}) 378 | _ = enc.Encode(&HashJoinNode{}) 379 | _ = enc.Encode(&IndexLookupNode{}) 380 | _ = enc.Encode(&InsertNode{}) 381 | _ = enc.Encode(&KeyLookupNode{}) 382 | _ = enc.Encode(&LimitNode{}) 383 | _ = enc.Encode(&NestedLoopJoinNode{}) 384 | _ = enc.Encode(&OffsetNode{}) 385 | _ = enc.Encode(&OrderNode{}) 386 | _ = enc.Encode(&ProjectionNode{}) 387 | _ = enc.Encode(&ScanNode{}) 388 | _ = enc.Encode(&UpdateNode{}) 389 | 390 | } 391 | -------------------------------------------------------------------------------- /sql/catalog/join.go: -------------------------------------------------------------------------------- 1 | package catalog 2 | 3 | import ( 4 | "cabbageDB/sql" 5 | "cabbageDB/sql/expr" 6 | "errors" 7 | ) 8 | 9 | type NestedLoopJoinExec struct { 10 | Left Executor 11 | Right Executor 12 | Predicate expr.Expression 13 | Outer bool 14 | } 15 | 16 | func (n *NestedLoopJoinExec) Execute(txn Transaction) (ResultSet, error) { 17 | leftResult, err1 := n.Left.Execute(txn) 18 | if err1 != nil { 19 | return nil, err1 20 | } 21 | rightResult, err2 := n.Right.Execute(txn) 22 | if err2 != nil { 23 | return nil, err2 24 | } 25 | 26 | leftQuery, ok1 := leftResult.(*QueryResultSet) 27 | rightQuery, ok2 := rightResult.(*QueryResultSet) 28 | if !ok1 || !ok2 { 29 | 30 | return nil, errors.New("HashJoinExec Invalid return ResultSet") 31 | } 32 | columns 
:= append(leftQuery.Columns, rightQuery.Columns...) 33 | nestRow := &NestedLoopRowsExec{ 34 | Left: leftQuery.Rows, 35 | Right: rightQuery.Rows, 36 | Predicate: n.Predicate, 37 | Outer: n.Outer, 38 | RightWidth: len(rightQuery.Columns), 39 | } 40 | return &QueryResultSet{ 41 | Rows: GetNestedLoopRowsExecRows(nestRow), 42 | Columns: columns, 43 | }, nil 44 | } 45 | 46 | type NestedLoopRowsExec struct { 47 | Left [][]*sql.ValueData 48 | Right [][]*sql.ValueData 49 | Predicate expr.Expression 50 | RightWidth int 51 | Outer bool 52 | } 53 | 54 | func GetNestedLoopRowsExecRows(n *NestedLoopRowsExec) [][]*sql.ValueData { 55 | var results [][]*sql.ValueData 56 | rightEmpty := make([]*sql.ValueData, n.RightWidth) 57 | for i := range rightEmpty { 58 | rightEmpty[i] = &sql.ValueData{ 59 | Type: sql.NullType, 60 | } 61 | } 62 | for leftRow := range n.Left { 63 | hasRightMatches := false 64 | 65 | for _, rightRow := range n.Right { 66 | combinedRow := append(n.Left[leftRow], rightRow...) 67 | if n.Predicate != nil { 68 | value := n.Predicate.Evaluate(combinedRow) 69 | if value == nil { 70 | continue 71 | } 72 | if !(value.Type == sql.BoolType && value.Value == true) { 73 | continue 74 | } 75 | } 76 | results = append(results, combinedRow) 77 | hasRightMatches = true 78 | } 79 | 80 | // 当左表的某一行在右表中没有匹配(hasRightMatches == false) 81 | // 且是外连接(outer == true)时,将左表行(leftRow)与 rightEmpty 合并,生成一个“左行 + 右空值”的完整行,表示右表无匹配。 82 | if n.Outer && !hasRightMatches { 83 | nullRow := append(n.Left[leftRow], rightEmpty...) 
84 | results = append(results, nullRow) 85 | } 86 | } 87 | return results 88 | 89 | } 90 | 91 | type HashJoinExec struct { 92 | Left Executor 93 | LeftField int 94 | Right Executor 95 | RightField int 96 | Outer bool 97 | } 98 | 99 | func (h *HashJoinExec) Execute(txn Transaction) (ResultSet, error) { 100 | leftResult, err1 := h.Left.Execute(txn) 101 | if err1 != nil { 102 | return nil, err1 103 | } 104 | rightResult, err2 := h.Right.Execute(txn) 105 | if err2 != nil { 106 | return nil, err2 107 | } 108 | 109 | leftQuery, ok1 := leftResult.(*QueryResultSet) 110 | rightQuery, ok2 := rightResult.(*QueryResultSet) 111 | if !ok1 || !ok2 { 112 | return nil, errors.New("HashJoinExec Invalid return ResultSet") 113 | } 114 | 115 | right := make(map[string][][]*sql.ValueData) 116 | for i := range rightQuery.Rows { 117 | keyStr := encodeKey(rightQuery.Rows[i][h.RightField]) 118 | right[keyStr] = append(right[keyStr], rightQuery.Rows[i]) 119 | } 120 | 121 | columns := append(leftQuery.Columns, rightQuery.Columns...) 122 | 123 | newRows := [][]*sql.ValueData{} 124 | 125 | for i := range leftQuery.Rows { 126 | row := leftQuery.Rows[i] 127 | keyStr := encodeKey(leftQuery.Rows[i][h.LeftField]) 128 | if hit, ok := right[keyStr]; ok { 129 | 130 | for _, rightRow := range hit { 131 | newRow := append(row, rightRow...) 132 | newRows = append(newRows, newRow) 133 | } 134 | 135 | } else if h.Outer { 136 | 137 | empty := make([]*sql.ValueData, len(rightQuery.Columns)) 138 | for i1 := range empty { 139 | empty[i1] = &sql.ValueData{ 140 | Type: sql.NullType, 141 | } 142 | } 143 | newRow := append(leftQuery.Rows[i], empty...) 
144 | newRows = append(newRows, newRow) 145 | } 146 | } 147 | return &QueryResultSet{ 148 | Columns: columns, 149 | Rows: newRows, 150 | }, nil 151 | } 152 | -------------------------------------------------------------------------------- /sql/catalog/mutation.go: -------------------------------------------------------------------------------- 1 | package catalog 2 | 3 | import ( 4 | "cabbageDB/sql" 5 | "cabbageDB/sql/expr" 6 | "errors" 7 | ) 8 | 9 | type InsertExec struct { 10 | Table string 11 | Columns []string 12 | Rows [][]expr.Expression 13 | } 14 | 15 | func PadRow(table *Table, row []*sql.ValueData) []*sql.ValueData { 16 | 17 | if len(row) > len(table.Columns) { 18 | return nil 19 | } 20 | for _, defaule := range table.Columns[len(row)-1:] { 21 | if defaule.Default != nil { 22 | row = append(row, defaule.Default) 23 | } 24 | 25 | } 26 | return row 27 | } 28 | 29 | func (i *InsertExec) Execute(txn Transaction) (ResultSet, error) { 30 | table, err := txn.MustReadTable(i.Table) 31 | if err != nil { 32 | return nil, err 33 | } 34 | count := 0 35 | for _, exprs := range i.Rows { 36 | var row []*sql.ValueData 37 | for _, expr := range exprs { 38 | row = append(row, expr.Evaluate(nil)) 39 | } 40 | 41 | if len(i.Columns) == 0 { 42 | row = PadRow(table, row) 43 | } else { 44 | row, err = MakeRow(table, i.Columns, row) 45 | if err != nil { 46 | return nil, err 47 | } 48 | } 49 | err = txn.Create(i.Table, row) 50 | if err != nil { 51 | return nil, err 52 | } 53 | count += 1 54 | } 55 | return &CreateResultSet{ 56 | Count: count, 57 | }, nil 58 | } 59 | 60 | func MakeRow(table *Table, columns []string, values []*sql.ValueData) ([]*sql.ValueData, error) { 61 | if len(columns) != len(values) { 62 | return nil, errors.New("Column and value counts do not match") 63 | } 64 | inputs := make(map[string]*sql.ValueData) 65 | for i := range values { 66 | _, err := table.GetColumn(columns[i]) 67 | if err != nil { 68 | return nil, err 69 | } 70 | 71 | inputs[columns[i]] = 
values[i] 72 | 73 | } 74 | row := []*sql.ValueData{} 75 | for _, column := range table.Columns { 76 | if v, ok := inputs[column.Name]; ok { 77 | row = append(row, v) 78 | continue 79 | } else if column.Default != nil { 80 | row = append(row, column.Default) 81 | } 82 | 83 | } 84 | 85 | return row, nil 86 | } 87 | 88 | type UpdateExec struct { 89 | Table string 90 | Source Executor 91 | Expressions []*UpdateExpr 92 | } 93 | 94 | func (u *UpdateExec) Execute(txn Transaction) (ResultSet, error) { 95 | resultSet, err := u.Source.Execute(txn) 96 | if err != nil { 97 | return nil, err 98 | } 99 | 100 | switch v := resultSet.(type) { 101 | case *QueryResultSet: 102 | table, err1 := txn.MustReadTable(u.Table) 103 | if err1 != nil { 104 | return nil, err 105 | } 106 | updated := make(map[*sql.ValueData]struct{}) 107 | for i := range v.Rows { 108 | id := table.GetRowKey(v.Rows[i]) 109 | if _, ok := updated[id]; ok { 110 | continue 111 | } 112 | for _, expr := range u.Expressions { 113 | v.Rows[i][expr.Index] = expr.Expr.Evaluate(v.Rows[i]) 114 | } 115 | err2 := txn.Update(table.Name, id, v.Rows[i]) 116 | if err2 != nil { 117 | return nil, err2 118 | } 119 | updated[id] = struct{}{} 120 | } 121 | return &UpdateResultSet{ 122 | Count: len(updated), 123 | }, nil 124 | } 125 | return nil, errors.New("HashJoinExec Invalid return ResultSet") 126 | 127 | } 128 | 129 | type DeteleExec struct { 130 | Table string 131 | Source Executor 132 | } 133 | 134 | func (d *DeteleExec) Execute(txn Transaction) (ResultSet, error) { 135 | table, err := txn.MustReadTable(d.Table) 136 | if err != nil { 137 | return nil, err 138 | } 139 | count := 0 140 | resultSet, err := d.Source.Execute(txn) 141 | if err != nil { 142 | return nil, err 143 | } 144 | 145 | switch v := resultSet.(type) { 146 | case *QueryResultSet: 147 | for i := range v.Rows { 148 | err = txn.Delete(table.Name, table.GetRowKey(v.Rows[i])) 149 | if err != nil { 150 | return nil, err 151 | } 152 | count += 1 153 | } 154 | return 
&DeleteResultSet{ 155 | Count: count, 156 | }, nil 157 | } 158 | return nil, errors.New("DeteleExec Invalid return ResultSet") 159 | } 160 | -------------------------------------------------------------------------------- /sql/catalog/query.go: -------------------------------------------------------------------------------- 1 | package catalog 2 | 3 | import ( 4 | "cabbageDB/sql" 5 | "cabbageDB/sql/expr" 6 | "errors" 7 | "sort" 8 | ) 9 | 10 | type FilterExec struct { 11 | Source Executor 12 | Predicate expr.Expression 13 | } 14 | 15 | func (f *FilterExec) Execute(txn Transaction) (ResultSet, error) { 16 | resultSet, err := f.Source.Execute(txn) 17 | if err != nil { 18 | return nil, err 19 | } 20 | switch v := resultSet.(type) { 21 | 22 | case *QueryResultSet: 23 | 24 | var filteredRows [][]*sql.ValueData 25 | 26 | for _, row := range v.Rows { 27 | result := f.Predicate.Evaluate(row) 28 | if result == nil { 29 | continue 30 | } 31 | switch result.Type { 32 | case sql.BoolType: 33 | if result.Value == true { 34 | filteredRows = append(filteredRows, row) 35 | } 36 | 37 | } 38 | } 39 | 40 | return &QueryResultSet{ 41 | Columns: v.Columns, 42 | Rows: filteredRows, 43 | }, nil 44 | } 45 | return nil, errors.New("FilterExec Invalid return ResultSet") 46 | } 47 | 48 | type ProjectionExec struct { 49 | Source Executor 50 | Expressions []*ExprAs 51 | } 52 | 53 | func (p *ProjectionExec) Execute(txn Transaction) (ResultSet, error) { 54 | 55 | resultSet, err := p.Source.Execute(txn) 56 | if err != nil { 57 | return nil, err 58 | } 59 | switch v := resultSet.(type) { 60 | 61 | case *QueryResultSet: 62 | columns := []string{} 63 | 64 | exprs := []expr.Expression{} 65 | for _, expr1 := range p.Expressions { 66 | ex, ok := expr1.Expr.(*expr.Field) 67 | if ok && expr1.As == "" && len(v.Columns) > ex.Index { 68 | columns = append(columns, v.Columns[ex.Index]) 69 | } else if expr1.As != "" { 70 | columns = append(columns, expr1.As) 71 | } else { 72 | columns = append(columns, "") 73 
| } 74 | exprs = append(exprs, expr1.Expr) 75 | 76 | } 77 | rows := [][]*sql.ValueData{} 78 | for i := range v.Rows { 79 | row := []*sql.ValueData{} 80 | for i1 := range exprs { 81 | row = append(row, exprs[i1].Evaluate(v.Rows[i])) 82 | } 83 | rows = append(rows, row) 84 | } 85 | 86 | return &QueryResultSet{ 87 | Columns: columns, 88 | Rows: rows, 89 | }, nil 90 | 91 | } 92 | return nil, errors.New("ProjectionExec Invalid return ResultSet") 93 | } 94 | 95 | type LimitExec struct { 96 | Source Executor 97 | Limit int 98 | } 99 | 100 | func (l *LimitExec) Execute(txn Transaction) (ResultSet, error) { 101 | 102 | resultSet, err := l.Source.Execute(txn) 103 | if err != nil { 104 | return nil, err 105 | } 106 | switch v := resultSet.(type) { 107 | 108 | case *QueryResultSet: 109 | if l.Limit > len(v.Rows) { 110 | l.Limit = len(v.Rows) 111 | } 112 | return &QueryResultSet{ 113 | Columns: v.Columns, 114 | Rows: v.Rows[:l.Limit], 115 | }, nil 116 | } 117 | return nil, errors.New("LimitExec Invalid return ResultSet") 118 | } 119 | 120 | type OffsetExec struct { 121 | Source Executor 122 | Offset int 123 | } 124 | 125 | func (o *OffsetExec) Execute(txn Transaction) (ResultSet, error) { 126 | resultSet, err := o.Source.Execute(txn) 127 | if err != nil { 128 | return nil, err 129 | } 130 | 131 | switch v := resultSet.(type) { 132 | case *QueryResultSet: 133 | if o.Offset > len(v.Rows) { 134 | o.Offset = len(v.Rows) 135 | } 136 | 137 | return &QueryResultSet{ 138 | Columns: v.Columns, 139 | Rows: v.Rows[o.Offset:], 140 | }, nil 141 | } 142 | return nil, errors.New("OffsetExec Invalid return ResultSet") 143 | } 144 | 145 | type OrderExec struct { 146 | Source Executor 147 | Orders []*DirectionExpr 148 | } 149 | 150 | type Item struct { 151 | Row []*sql.ValueData 152 | Values []*sql.ValueData 153 | } 154 | 155 | func (o *OrderExec) Execute(txn Transaction) (ResultSet, error) { 156 | resultSet, err := o.Source.Execute(txn) 157 | if err != nil { 158 | return nil, err 159 | } 160 | 
switch v := resultSet.(type) { 161 | case *QueryResultSet: 162 | items := []*Item{} 163 | for i := range v.Rows { 164 | values := []*sql.ValueData{} 165 | for _, exprAs := range o.Orders { 166 | values = append(values, exprAs.Expr.Evaluate(v.Rows[i])) 167 | } 168 | items = append(items, &Item{ 169 | Row: v.Rows[i], 170 | Values: values, 171 | }) 172 | } 173 | 174 | sorter := Sorter{ 175 | items: items, 176 | orders: o.Orders, 177 | } 178 | sort.Sort(sorter) 179 | 180 | result := make([][]*sql.ValueData, len(items)) 181 | for i, item := range items { 182 | result[i] = item.Row 183 | } 184 | 185 | return &QueryResultSet{ 186 | Columns: v.Columns, 187 | Rows: result, 188 | }, nil 189 | 190 | } 191 | return nil, errors.New("OrderExec Invalid return ResultSet") 192 | 193 | } 194 | 195 | type Sorter struct { 196 | items []*Item 197 | orders []*DirectionExpr 198 | } 199 | 200 | func (s Sorter) Len() int { 201 | return len(s.items) 202 | } 203 | 204 | func (s Sorter) Swap(i, j int) { 205 | s.items[i], s.items[j] = s.items[j], s.items[i] 206 | } 207 | 208 | func (s Sorter) Less(i, j int) bool { 209 | a, b := s.items[i], s.items[j] 210 | for idx, order := range s.orders { 211 | valA := a.Values[idx] 212 | valB := b.Values[idx] 213 | 214 | res, comparable1 := valA.Compare(valB) 215 | if !comparable1 { 216 | continue // 跳过不可比的条件 217 | } 218 | 219 | if res == 0 { 220 | continue // 继续下一个排序条件 221 | } 222 | 223 | // 根据方向调整结果 224 | if order.Desc { 225 | res = -res 226 | } 227 | return res < 0 228 | } 229 | return false // 所有条件均相等 230 | } 231 | -------------------------------------------------------------------------------- /sql/catalog/schema.go: -------------------------------------------------------------------------------- 1 | package catalog 2 | 3 | import ( 4 | "cabbageDB/sql" 5 | "errors" 6 | "fmt" 7 | "strings" 8 | ) 9 | 10 | type Table struct { 11 | Name string 12 | Columns []*Column 13 | } 14 | 15 | func (t *Table) ValiDateRow(row []*sql.ValueData, txn Transaction) error { 
16 | if len(row) != len(t.Columns) { 17 | return errors.New("Invalid row size for table") 18 | } 19 | pk := t.GetRowKey(row) 20 | for i, column := range t.Columns { 21 | err := column.ValidateValue(t, pk, row[i], txn) 22 | if err != nil { 23 | return err 24 | } 25 | } 26 | return nil 27 | } 28 | 29 | func (t *Table) GetColumn(columnName string) (*Column, error) { 30 | for _, column := range t.Columns { 31 | if column.Name == columnName { 32 | return column, nil 33 | } 34 | } 35 | return nil, errors.New("Column " + columnName + " not found in table " + t.Name) 36 | } 37 | 38 | func (t *Table) GetColumnIndex(name string) (int, error) { 39 | for i, column := range t.Columns { 40 | if column.Name == name { 41 | return i, nil 42 | } 43 | } 44 | 45 | return -1, errors.New("Column " + name + " not found in table " + t.Name) 46 | } 47 | 48 | func (t *Table) ValiDate(txn Transaction) error { 49 | if t.Columns == nil || len(t.Columns) == 0 { 50 | //todo 在这里需要输出错误信息 要不使用log呢,或者专门一个错误通道呢? 51 | return errors.New("table columns is nil") 52 | } 53 | 54 | count := 0 55 | for _, column := range t.Columns { 56 | if column.PrimaryKey { 57 | count++ 58 | } 59 | } 60 | if count != 1 { 61 | 62 | return errors.New("No primary key in table") 63 | } 64 | for _, column := range t.Columns { 65 | err := column.Validate(t, txn) 66 | if err != nil { 67 | return err 68 | } 69 | } 70 | return nil 71 | } 72 | 73 | func (t *Table) GetPrimaryKey() *Column { 74 | for _, column := range t.Columns { 75 | if column.PrimaryKey { 76 | return column 77 | } 78 | } 79 | return nil 80 | } 81 | 82 | func (t *Table) GetRowKey(row []*sql.ValueData) *sql.ValueData { 83 | 84 | for i, column := range t.Columns { 85 | if column.PrimaryKey { 86 | return row[i] 87 | } 88 | } 89 | return nil 90 | } 91 | func (t *Table) String() string { 92 | // 1. 收集列定义 93 | var columns []string 94 | for _, col := range t.Columns { 95 | columns = append(columns, col.String()) 96 | } 97 | 98 | // 2. 
收集表级约束(外键、索引) 99 | var constraints []string 100 | for _, col := range t.Columns { 101 | if constraint := col.ConstraintString(); constraint != "" { 102 | for _, line := range strings.Split(constraint, "\n") { 103 | if trimmed := strings.TrimSpace(line); trimmed != "" { 104 | constraints = append(constraints, trimmed) 105 | } 106 | } 107 | } 108 | } 109 | 110 | // 3. 合并列和约束,确保约束在列之后 111 | var parts []string 112 | parts = append(parts, columns...) 113 | parts = append(parts, constraints...) 114 | 115 | // 4. 生成最终 SQL 116 | return fmt.Sprintf( 117 | "CREATE TABLE %s (\n %s\n);", 118 | t.Name, 119 | strings.Join(parts, ",\n "), 120 | ) 121 | } 122 | 123 | type ReferenceField struct { 124 | TableName string 125 | ColumnName string 126 | } 127 | 128 | // 这里的DataType 先这么写,可以在sql解析的时候强制他的类型 129 | type Column struct { 130 | Name string 131 | DataType sql.DataType 132 | PrimaryKey bool 133 | NullAble bool 134 | Default *sql.ValueData 135 | Unique bool 136 | //Reference string 137 | Index bool 138 | Reference *ReferenceField 139 | } 140 | 141 | func (c *Column) Validate(table *Table, txnEngine Transaction) error { 142 | if c.PrimaryKey && c.NullAble { 143 | return errors.New("Primary key " + c.Name + " cannot be nullable") 144 | } 145 | if c.PrimaryKey && !c.Unique { 146 | return errors.New("Primary key " + c.Name + " cannot be unique") 147 | } 148 | 149 | if c.Default != nil { 150 | 151 | dataType := sql.DecodeDataType(c.Default.Value) 152 | 153 | // 只有当Default为NULL 并且c.NullAble的时候才通过 154 | if dataType == sql.NullType { 155 | if !c.NullAble { 156 | return errors.New("Can't use NULL as default value for non-nullable column " + c.Name) 157 | } 158 | } else if c.DataType != dataType { 159 | return errors.New("Default value for column " + c.Name + " has datatype " + c.DataType.String() + ", must be " + dataType.String()) 160 | } 161 | 162 | } 163 | 164 | if c.Reference == nil { 165 | return nil 166 | } 167 | 168 | var target *Table 169 | if c.Reference.TableName != table.Name { 
170 | target = txnEngine.ReadTable(c.Reference.TableName) 171 | } else { 172 | target = table 173 | } 174 | if target.Name == "" { 175 | return errors.New("Referenced Table " + c.Reference.TableName + " is not exist") 176 | } 177 | 178 | index, err := target.GetColumnIndex(c.Reference.ColumnName) 179 | if err != nil { 180 | return nil 181 | } 182 | if index == -1 { 183 | return errors.New("Referenced Column " + c.Reference.ColumnName + " is not exist Table" + c.Reference.TableName) 184 | } 185 | 186 | if !target.Columns[index].Index && !target.Columns[index].PrimaryKey { 187 | return errors.New("Referenced Column " + c.Reference.ColumnName + " Must Be Primary Key or Index") 188 | } 189 | 190 | //if c.DataType != target.GetPrimaryKey().DataType { 191 | // return errors.New("Referenced primary key dataType is not equal") 192 | //} 193 | 194 | if c.DataType != target.Columns[index].DataType { 195 | return errors.New("Referenced primary key dataType is not equal") 196 | } 197 | return nil 198 | 199 | } 200 | 201 | func (c *Column) ValidateValue(table *Table, pk *sql.ValueData, value *sql.ValueData, txn Transaction) error { 202 | 203 | datatype := sql.DecodeDataType(value.Value) 204 | switch datatype { 205 | case sql.NullType: 206 | if !c.NullAble { 207 | return errors.New("NULL value not allowed for column:" + c.Name) 208 | } 209 | case sql.IntType, sql.BoolType, sql.StringType, sql.FloatType: 210 | if datatype != c.DataType { 211 | return errors.New("Invalid datatype " + datatype.String() + " for " + c.DataType.String() + " column " + c.Name) 212 | } 213 | } 214 | 215 | if c.Reference != nil { 216 | 217 | // 获取主键值 218 | rowData, err := txn.Scan(c.Reference.TableName, nil) 219 | if err != nil { 220 | return err 221 | } 222 | if len(rowData) == 0 { 223 | return errors.New("Referenced column value " + value.String() + "in table " + c.Reference.TableName + " does not exist") 224 | } 225 | 226 | // 获取t 227 | t := txn.ReadTable(c.Reference.TableName) 228 | index, err1 := 
t.GetColumnIndex(c.Reference.ColumnName) 229 | if err1 != nil { 230 | return err1 231 | } 232 | if index == -1 { 233 | return errors.New("Referenced Column " + c.Reference.ColumnName + " is not exist Table" + c.Reference.TableName) 234 | } 235 | if t.Columns[index].Index { 236 | valueData, err2 := txn.ReadIndex(t.Name, c.Reference.ColumnName, value) 237 | if err2 != nil { 238 | return err2 239 | } 240 | if len(valueData) == 0 { 241 | return errors.New("Referenced Column Value " + value.String() + " is not exist Table " + c.Reference.TableName) 242 | } 243 | } 244 | if t.Columns[index].PrimaryKey { 245 | flag := false 246 | for _, row := range rowData { 247 | if sql.EqualValue(value, t.GetRowKey(row)) { 248 | flag = true 249 | } 250 | } 251 | if !flag { 252 | return errors.New("Referenced Column Value " + value.String() + " is not exist Table " + c.Reference.TableName) 253 | } 254 | } 255 | } 256 | 257 | if c.Unique && !c.PrimaryKey && value != nil { 258 | //index, err := table.GetColumnIndex(c.Name) 259 | //if err != nil { 260 | // return err 261 | //} 262 | scan, err := txn.ReadIndex(table.Name, c.Name, value) 263 | if err != nil { 264 | return err 265 | } 266 | 267 | if len(scan) != 0 { 268 | return errors.New("Unique value " + value.String() + " already exists for column " + c.Name) 269 | } 270 | //for _, row := range scan { 271 | // if sql.EqualValue(row[index], value) && !sql.EqualValue(table.GetRowKey(row), pk) { 272 | // return errors.New("Unique value " + value.String() + " already exists for column " + c.Name) 273 | // } 274 | //} 275 | } 276 | return nil 277 | 278 | } 279 | 280 | func (c *Column) ConstraintString() string { 281 | var constraints []string 282 | // 外键约束 283 | if c.Reference != nil { 284 | constraints = append(constraints, fmt.Sprintf("FOREIGN KEY (%s) REFERENCES %s(%s)", 285 | c.Name, c.Reference.TableName, c.Reference.ColumnName)) 286 | } 287 | // 索引约束(表级) 288 | if !c.Unique && c.Index { 289 | constraints = append(constraints, 
fmt.Sprintf("KEY %s (%s)", 290 | c.Name, c.Name)) 291 | } 292 | return strings.Join(constraints, "\n") 293 | } 294 | 295 | func (c *Column) String() string { 296 | var builder strings.Builder 297 | 298 | // 基础列名和类型 299 | builder.WriteString(fmt.Sprintf("%s %s", c.Name, c.DataType.String())) 300 | 301 | // 主键约束 302 | if c.PrimaryKey { 303 | builder.WriteString(" PRIMARY KEY") 304 | } 305 | 306 | // 非空约束 (主键默认隐含 NOT NULL,故不重复添加) 307 | if !c.NullAble && !c.PrimaryKey { 308 | builder.WriteString(" NOT NULL") 309 | } 310 | 311 | // 默认值 312 | if c.Default != nil { 313 | builder.WriteString(fmt.Sprintf(" DEFAULT %v", c.Default)) 314 | } 315 | 316 | // 唯一约束 (主键隐含唯一性,故不重复添加) 317 | if c.Unique && !c.PrimaryKey { 318 | builder.WriteString(" UNIQUE") 319 | } 320 | 321 | return builder.String() 322 | } 323 | 324 | type Catalog interface { 325 | CreateTable(table *Table) error 326 | DeleteTable(tableName string) error 327 | ReadTable(tableName string) *Table 328 | ScanTables() []*Table 329 | MustReadTable(tableName string) (*Table, error) 330 | TableReferences(tableName string, withSelf bool) []*TableReferences 331 | } 332 | 333 | type TableReferences struct { 334 | TableName string 335 | ColumnReferences []string 336 | } 337 | -------------------------------------------------------------------------------- /sql/catalog/schemaexec.go: -------------------------------------------------------------------------------- 1 | package catalog 2 | 3 | type CreateTableExec struct { 4 | Table *Table 5 | } 6 | 7 | func (c *CreateTableExec) Execute(txn Transaction) (ResultSet, error) { 8 | err := txn.CreateTable(c.Table) 9 | if err != nil { 10 | return nil, err 11 | } 12 | return &CreateTableResultSet{ 13 | Name: c.Table.Name, 14 | }, nil 15 | } 16 | 17 | type DropTableExec struct { 18 | Table string 19 | } 20 | 21 | func (d *DropTableExec) Execute(txn Transaction) (ResultSet, error) { 22 | err := txn.DeleteTable(d.Table) 23 | if err != nil { 24 | return nil, err 25 | } 26 | return 
&DropTableResultSet{ 27 | Name: d.Table, 28 | }, nil 29 | } 30 | -------------------------------------------------------------------------------- /sql/catalog/source.go: -------------------------------------------------------------------------------- 1 | package catalog 2 | 3 | import ( 4 | "cabbageDB/sql" 5 | "cabbageDB/sql/expr" 6 | ) 7 | 8 | type ScanExec struct { 9 | Table string 10 | Filter expr.Expression 11 | } 12 | 13 | func (s *ScanExec) Execute(txn Transaction) (ResultSet, error) { 14 | table, err := txn.MustReadTable(s.Table) 15 | if err != nil { 16 | return nil, err 17 | } 18 | columns := make([]string, len(table.Columns)) 19 | for i := range table.Columns { 20 | columns[i] = table.Columns[i].Name 21 | } 22 | rows, err1 := txn.Scan(s.Table, s.Filter) 23 | if err1 != nil { 24 | return nil, err1 25 | } 26 | return &QueryResultSet{ 27 | Columns: columns, 28 | Rows: rows, 29 | }, nil 30 | 31 | } 32 | 33 | type KeyLookupExec struct { 34 | Table string 35 | Keys []*sql.ValueData 36 | } 37 | 38 | func (k *KeyLookupExec) Execute(txn Transaction) (ResultSet, error) { 39 | table, err := txn.MustReadTable(k.Table) 40 | if err != nil { 41 | return nil, err 42 | } 43 | 44 | rows := [][]*sql.ValueData{} 45 | for _, key := range k.Keys { 46 | value, err1 := txn.Read(k.Table, key) 47 | if err1 != nil { 48 | return nil, err1 49 | } 50 | rows = append(rows, value) 51 | } 52 | columns := []string{} 53 | for _, column := range table.Columns { 54 | columns = append(columns, column.Name) 55 | } 56 | return &QueryResultSet{ 57 | Columns: columns, 58 | Rows: rows, 59 | }, nil 60 | } 61 | 62 | type IndexLookupExec struct { 63 | Table string 64 | Column string 65 | Values []*sql.ValueData 66 | } 67 | 68 | func (i *IndexLookupExec) Execute(txn Transaction) (ResultSet, error) { 69 | table, err := txn.MustReadTable(i.Table) 70 | if err != nil { 71 | return nil, err 72 | } 73 | pks := make(map[*sql.ValueData]struct{}) 74 | for _, value := range i.Values { 75 | idx, err1 := 
txn.ReadIndex(i.Table, i.Column, value) 76 | if err1 != nil { 77 | return nil, err1 78 | } 79 | for i1, _ := range idx { 80 | pks[i1] = struct{}{} 81 | } 82 | } 83 | 84 | rows := [][]*sql.ValueData{} 85 | for i1, _ := range pks { 86 | value, err1 := txn.Read(i.Table, i1) 87 | if err1 != nil { 88 | return nil, err1 89 | } 90 | rows = append(rows, value) 91 | } 92 | colums := []string{} 93 | for i1 := range table.Columns { 94 | colums = append(colums, table.Columns[i1].Name) 95 | } 96 | 97 | return &QueryResultSet{ 98 | Columns: colums, 99 | Rows: rows, 100 | }, nil 101 | } 102 | 103 | type NothingExec struct { 104 | } 105 | 106 | func (n *NothingExec) Execute(txn Transaction) (ResultSet, error) { 107 | return &QueryResultSet{ 108 | Columns: []string{}, 109 | Rows: make([][]*sql.ValueData, 1), 110 | }, nil 111 | } 112 | -------------------------------------------------------------------------------- /sql/expr/fn.go: -------------------------------------------------------------------------------- 1 | package expr 2 | 3 | type ExprFnBool func(expression Expression) bool 4 | type ExprFn func(expression Expression) Expression 5 | 6 | func Contains(expr Expression, fn ExprFnBool) bool { 7 | return !Walk(expr, func(expr1 Expression) bool { 8 | return !fn(expr1) 9 | }) 10 | } 11 | 12 | func Walk(expr Expression, visitor ExprFnBool) bool { 13 | visitorBool := visitor(expr) 14 | 15 | var exprBool bool 16 | switch v := expr.(type) { 17 | case *Add: 18 | exprBool = Walk(v.L, visitor) && Walk(v.R, visitor) 19 | case *And: 20 | exprBool = Walk(v.L, visitor) && Walk(v.R, visitor) 21 | case *Divide: 22 | exprBool = Walk(v.L, visitor) && Walk(v.R, visitor) 23 | case *Equal: 24 | exprBool = Walk(v.L, visitor) && Walk(v.R, visitor) 25 | case *Exponentiate: 26 | exprBool = Walk(v.L, visitor) && Walk(v.R, visitor) 27 | case *GreaterThan: 28 | exprBool = Walk(v.L, visitor) && Walk(v.R, visitor) 29 | case *LessThan: 30 | exprBool = Walk(v.L, visitor) && Walk(v.R, visitor) 31 | case *Like: 
32 | exprBool = Walk(v.L, visitor) && Walk(v.R, visitor) 33 | case *Modulo: 34 | exprBool = Walk(v.L, visitor) && Walk(v.R, visitor) 35 | case *Multiply: 36 | exprBool = Walk(v.L, visitor) && Walk(v.R, visitor) 37 | case *Or: 38 | exprBool = Walk(v.L, visitor) && Walk(v.R, visitor) 39 | case *Subtract: 40 | exprBool = Walk(v.L, visitor) && Walk(v.R, visitor) 41 | case *Assert: 42 | exprBool = Walk(v.L, visitor) 43 | case *Factorial: 44 | exprBool = Walk(v.L, visitor) 45 | case *IsNull: 46 | exprBool = Walk(v.L, visitor) 47 | case *Negate: 48 | exprBool = Walk(v.L, visitor) 49 | case *Not: 50 | exprBool = Walk(v.L, visitor) 51 | case *Constant, *Field: 52 | 53 | exprBool = true 54 | 55 | } 56 | 57 | return visitorBool && exprBool 58 | } 59 | 60 | func ReplaceWith(expr Expression, fn ExprFn) Expression { 61 | return fn(expr) 62 | } 63 | 64 | func Transform(expr Expression, before ExprFn, after ExprFn) Expression { 65 | expr = before(expr) 66 | switch v := expr.(type) { 67 | case *Add: 68 | expr = &Add{ 69 | L: ReplaceWith(v.L, func(expression Expression) Expression { 70 | return Transform(expression, before, after) 71 | }), 72 | R: ReplaceWith(v.R, func(expression Expression) Expression { 73 | return Transform(expression, before, after) 74 | }), 75 | } 76 | 77 | case *And: 78 | expr = &And{ 79 | L: ReplaceWith(v.L, func(expression Expression) Expression { 80 | return Transform(expression, before, after) 81 | }), 82 | R: ReplaceWith(v.R, func(expression Expression) Expression { 83 | return Transform(expression, before, after) 84 | }), 85 | } 86 | case *Divide: 87 | expr = &Divide{ 88 | L: ReplaceWith(v.L, func(expression Expression) Expression { 89 | return Transform(expression, before, after) 90 | }), 91 | R: ReplaceWith(v.R, func(expression Expression) Expression { 92 | return Transform(expression, before, after) 93 | }), 94 | } 95 | case *Equal: 96 | expr = &Equal{ 97 | L: ReplaceWith(v.L, func(expression Expression) Expression { 98 | return Transform(expression, 
before, after) 99 | }), 100 | R: ReplaceWith(v.R, func(expression Expression) Expression { 101 | return Transform(expression, before, after) 102 | }), 103 | } 104 | case *Exponentiate: 105 | expr = &Exponentiate{ 106 | L: ReplaceWith(v.L, func(expression Expression) Expression { 107 | return Transform(expression, before, after) 108 | }), 109 | R: ReplaceWith(v.R, func(expression Expression) Expression { 110 | return Transform(expression, before, after) 111 | }), 112 | } 113 | case *GreaterThan: 114 | expr = &GreaterThan{ 115 | L: ReplaceWith(v.L, func(expression Expression) Expression { 116 | return Transform(expression, before, after) 117 | }), 118 | R: ReplaceWith(v.R, func(expression Expression) Expression { 119 | return Transform(expression, before, after) 120 | }), 121 | } 122 | case *LessThan: 123 | expr = &LessThan{ 124 | L: ReplaceWith(v.L, func(expression Expression) Expression { 125 | return Transform(expression, before, after) 126 | }), 127 | R: ReplaceWith(v.R, func(expression Expression) Expression { 128 | return Transform(expression, before, after) 129 | }), 130 | } 131 | case *Like: 132 | expr = &Like{ 133 | L: ReplaceWith(v.L, func(expression Expression) Expression { 134 | return Transform(expression, before, after) 135 | }), 136 | R: ReplaceWith(v.R, func(expression Expression) Expression { 137 | return Transform(expression, before, after) 138 | }), 139 | } 140 | case *Modulo: 141 | expr = &Modulo{ 142 | L: ReplaceWith(v.L, func(expression Expression) Expression { 143 | return Transform(expression, before, after) 144 | }), 145 | R: ReplaceWith(v.R, func(expression Expression) Expression { 146 | return Transform(expression, before, after) 147 | }), 148 | } 149 | case *Multiply: 150 | expr = &Multiply{ 151 | L: ReplaceWith(v.L, func(expression Expression) Expression { 152 | return Transform(expression, before, after) 153 | }), 154 | R: ReplaceWith(v.R, func(expression Expression) Expression { 155 | return Transform(expression, before, after) 156 | 
}), 157 | } 158 | case *Or: 159 | expr = &Or{ 160 | L: ReplaceWith(v.L, func(expression Expression) Expression { 161 | return Transform(expression, before, after) 162 | }), 163 | R: ReplaceWith(v.R, func(expression Expression) Expression { 164 | return Transform(expression, before, after) 165 | }), 166 | } 167 | case *Subtract: 168 | expr = &Subtract{ 169 | L: ReplaceWith(v.L, func(expression Expression) Expression { 170 | return Transform(expression, before, after) 171 | }), 172 | R: ReplaceWith(v.R, func(expression Expression) Expression { 173 | return Transform(expression, before, after) 174 | }), 175 | } 176 | case *Assert: 177 | expr = &Assert{ 178 | L: ReplaceWith(v.L, func(expression Expression) Expression { 179 | return Transform(expression, before, after) 180 | }), 181 | } 182 | case *Factorial: 183 | expr = &Factorial{ 184 | L: ReplaceWith(v.L, func(expression Expression) Expression { 185 | return Transform(expression, before, after) 186 | }), 187 | } 188 | case *IsNull: 189 | expr = &IsNull{ 190 | L: ReplaceWith(v.L, func(expression Expression) Expression { 191 | return Transform(expression, before, after) 192 | }), 193 | } 194 | case *Negate: 195 | expr = &Negate{ 196 | L: ReplaceWith(v.L, func(expression Expression) Expression { 197 | return Transform(expression, before, after) 198 | }), 199 | } 200 | case *Not: 201 | expr = &Not{ 202 | L: ReplaceWith(v.L, func(expression Expression) Expression { 203 | return Transform(expression, before, after) 204 | }), 205 | } 206 | 207 | } 208 | 209 | return after(expr) 210 | } 211 | -------------------------------------------------------------------------------- /sql/value.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import "strconv" 4 | 5 | type DataType byte 6 | 7 | func (d DataType) String() string { 8 | switch d { 9 | case NullType: 10 | return "NULL" 11 | case IntType: 12 | return "INT" 13 | case StringType: 14 | return "VARCHAR(100)" 15 | case 
FloatType: 16 | return "FLOAF" 17 | case BoolType: 18 | return "BOOL" 19 | default: 20 | return "" 21 | } 22 | } 23 | 24 | type ValueData struct { 25 | Type DataType 26 | Value any 27 | } 28 | 29 | func (v *ValueData) String() string { 30 | switch v.Type { 31 | case NullType: 32 | return "NULL" 33 | case FloatType: 34 | return strconv.FormatFloat(v.Value.(float64), 'f', 4, 64) 35 | case BoolType: 36 | if v.Value == true { 37 | return "TRUE" 38 | } else { 39 | return "FALSE" 40 | } 41 | case IntType: 42 | switch num := v.Value.(type) { 43 | case int: 44 | return strconv.Itoa(num) 45 | case int64: 46 | return strconv.Itoa(int(num)) 47 | case uint64: 48 | return strconv.Itoa(int(num)) 49 | } 50 | return "" 51 | case StringType: 52 | return v.Value.(string) 53 | } 54 | return "" 55 | } 56 | 57 | func (v *ValueData) Compare(other *ValueData) (int, bool) { 58 | if v.Type == other.Type && v.Type == NullType { 59 | return 0, true 60 | } 61 | if v.Type == NullType { 62 | return -1, true 63 | } 64 | if other.Type == NullType { 65 | return 1, true 66 | } 67 | if v.Type == other.Type && v.Type == BoolType { 68 | a := v.Value.(bool) 69 | b := other.Value.(bool) 70 | if a == b { 71 | return 0, true 72 | } 73 | if !a && b { // a 是 false,b 是 true → a < b 74 | return -1, true 75 | } 76 | return 1, true 77 | } 78 | if v.Type == other.Type && v.Type == StringType { 79 | a := v.Value.(string) 80 | b := other.Value.(string) 81 | if a == b { 82 | return 0, true 83 | } 84 | if a < b { 85 | return -1, true 86 | } 87 | return 1, true 88 | } 89 | if v.Type == other.Type && v.Type == IntType { 90 | a := v.Value.(int64) 91 | b := other.Value.(int64) 92 | if a == b { 93 | return 0, true 94 | } 95 | if a < b { 96 | return -1, true 97 | } 98 | return 1, true 99 | 100 | } 101 | if v.Type == other.Type && v.Type == FloatType { 102 | a := v.Value.(float64) 103 | b := other.Value.(float64) 104 | if a == b { 105 | return 0, true 106 | } 107 | if a < b { 108 | return -1, true 109 | } 110 | return 1, 
true 111 | } 112 | if (v.Type == FloatType && other.Type == IntType) || (v.Type == IntType && other.Type == FloatType) { 113 | a := v.Value.(float64) 114 | b := other.Value.(float64) 115 | if a == b { 116 | return 0, true 117 | } 118 | if a < b { 119 | return -1, true 120 | } 121 | return 1, true 122 | } 123 | return 0, false 124 | } 125 | 126 | const ( 127 | NullType DataType = 0x01 128 | BoolType DataType = 0x02 129 | IntType DataType = 0x03 130 | FloatType DataType = 0x04 131 | StringType DataType = 0x05 132 | ) 133 | 134 | func DecodeDataType(v interface{}) DataType { 135 | switch v.(type) { 136 | case int, int64, int8, int16, int32: 137 | return IntType 138 | case bool: 139 | return BoolType 140 | case string: 141 | return StringType 142 | case float64, float32: 143 | return FloatType 144 | case nil: 145 | return NullType 146 | } 147 | return StringType 148 | } 149 | 150 | func EqualValue(data1 *ValueData, data2 *ValueData) bool { 151 | if data1 == nil || data2 == nil { 152 | return false 153 | } 154 | if data1.Type != data2.Type { 155 | return false 156 | } 157 | if data1.String() != data2.String() { 158 | return false 159 | } 160 | return true 161 | } 162 | -------------------------------------------------------------------------------- /sqlparser/ast/ast.go: -------------------------------------------------------------------------------- 1 | package ast 2 | 3 | import ( 4 | "cabbageDB/sql" 5 | ) 6 | 7 | type Column struct { 8 | Name string 9 | ColumnType sql.DataType 10 | PrimaryKey bool 11 | Nullable bool 12 | Default Expression 13 | Unique bool 14 | Index bool 15 | References *Field 16 | } 17 | 18 | type ColumnOptionType int 19 | 20 | const ( 21 | NOTNULL ColumnOptionType = iota 22 | NULL 23 | PRIMARYKEY 24 | DEFAULT 25 | UNIQUE 26 | ) 27 | 28 | type ColumnOption struct { 29 | Type ColumnOptionType 30 | Value interface{} 31 | } 32 | 33 | type ConstraintType int 34 | 35 | const ( 36 | PRIMARYKEYConstraint ConstraintType = iota 37 | UNIQUEKEYConstraint 38 | 
KEYConstraint 39 | FORREGINKEYConstraint 40 | ) 41 | 42 | type Constraint struct { 43 | Type ConstraintType 44 | IndexName string 45 | ColumnName string 46 | TableName string 47 | SubColumnName string 48 | } 49 | -------------------------------------------------------------------------------- /sqlparser/ast/expression.go: -------------------------------------------------------------------------------- 1 | package ast 2 | 3 | import ( 4 | "bytes" 5 | "cabbageDB/sql" 6 | "encoding/gob" 7 | ) 8 | 9 | type Order struct { 10 | Expr Expression 11 | Desc bool 12 | } 13 | 14 | type FromItem interface { 15 | fromItem() 16 | } 17 | 18 | type FromItemTable struct { 19 | Name string 20 | Alias string 21 | } 22 | 23 | func (f *FromItemTable) fromItem() {} 24 | 25 | type JoinType int 26 | 27 | const ( 28 | CrossJoin JoinType = iota + 1 29 | InnerJoin 30 | LeftJoin 31 | RightJoin 32 | ) 33 | 34 | type FromItemJoinTable struct { 35 | Left FromItem 36 | Right FromItem 37 | Type JoinType 38 | Predicate Expression 39 | } 40 | 41 | func (f *FromItemJoinTable) fromItem() {} 42 | 43 | // 表达式 44 | type ExprAS struct { 45 | Expr Expression 46 | As string 47 | } 48 | 49 | type ExprColumn struct { 50 | ColumnName string 51 | Expr Expression 52 | } 53 | 54 | type Expression interface { 55 | expression() 56 | } 57 | 58 | type Field struct { 59 | TableName string 60 | ColumnName string 61 | } 62 | 63 | func (f *Field) expression() { 64 | } 65 | 66 | type ColumnIdx struct { 67 | Index int 68 | } 69 | 70 | func (c *ColumnIdx) expression() { 71 | } 72 | 73 | //type DataType int 74 | // 75 | //const ( 76 | // Null DataType = iota 77 | // Bool 78 | // Int 79 | // Float 80 | // String 81 | //) 82 | 83 | // Type 84 | // 0 Null 85 | // 1 Bool 86 | // 2 Int 87 | // 3 Float 88 | // 4 String 89 | type Literal struct { 90 | Type sql.DataType 91 | Value interface{} 92 | } 93 | 94 | func (l *Literal) expression() { 95 | } 96 | 97 | type Function struct { 98 | FuncName string 99 | Args []Expression 100 | } 
101 | // expression tags *Function as an Expression node. 102 | func (*Function) expression() { 103 | } 104 | // Operation is an Expression node that wraps a single operator node. 105 | type Operation struct { 106 | Operation OperationType 107 | } 108 | // expression tags *Operation as an Expression node. 109 | func (*Operation) expression() { 110 | } 111 | // OperationType is implemented by every concrete operator node type. 112 | type OperationType interface { 113 | operationType() 114 | } 115 | // AndOper is the operator node for L AND R. 116 | type AndOper struct { 117 | L Expression 118 | R Expression 119 | } 120 | 121 | func (*AndOper) operationType() { 122 | } 123 | func (*AndOper) expression() { 124 | } 125 | // NotOper is the operator node for NOT L. 126 | type NotOper struct { 127 | L Expression 128 | } 129 | 130 | func (*NotOper) operationType() { 131 | 132 | } 133 | func (*NotOper) expression() { 134 | } 135 | // OrOper is the operator node for L OR R. 136 | type OrOper struct { 137 | L Expression 138 | R Expression 139 | } 140 | 141 | func (*OrOper) operationType() { 142 | 143 | } 144 | func (*OrOper) expression() { 145 | } 146 | // EqualOper is the operator node for L = R. 147 | type EqualOper struct { 148 | L Expression 149 | R Expression 150 | } 151 | 152 | func (*EqualOper) operationType() { 153 | 154 | } 155 | func (*EqualOper) expression() { 156 | } 157 | // GreaterThanOper is the operator node for L > R. 158 | type GreaterThanOper struct { 159 | L Expression 160 | R Expression 161 | } 162 | 163 | func (*GreaterThanOper) operationType() { 164 | 165 | } 166 | func (*GreaterThanOper) expression() { 167 | } 168 | // GreaterThanOrEqualOper is the operator node for L >= R. 169 | type GreaterThanOrEqualOper struct { 170 | L Expression 171 | R Expression 172 | } 173 | 174 | func (*GreaterThanOrEqualOper) operationType() { 175 | 176 | } 177 | func (*GreaterThanOrEqualOper) expression() { 178 | } 179 | // IsNullOper is the operator node for L IS NULL. 180 | type IsNullOper struct { 181 | L Expression 182 | } 183 | 184 | func (*IsNullOper) operationType() { 185 | 186 | } 187 | func (*IsNullOper) expression() { 188 | } 189 | // LessThanOper is the operator node for L < R. 190 | type LessThanOper struct { 191 | L Expression 192 | R Expression 193 | } 194 | 195 | func (*LessThanOper) operationType() { 196 | 197 | } 198 | func (*LessThanOper) expression() { 199 | } 200 | // LessThanOrEqualOper is the operator node for L <= R. 201 | type LessThanOrEqualOper struct { 202 | L Expression 203 | R Expression 204 | } 205 | 206 | func (*LessThanOrEqualOper) operationType() { 207 | 208 | } 209 | 210 |
func (g *LessThanOrEqualOper) expression() { 211 | } 212 | 213 | type NotEqualOper struct { 214 | L Expression 215 | R Expression 216 | } 217 | 218 | func (n *NotEqualOper) operationType() { 219 | } 220 | 221 | func (n *NotEqualOper) expression() { 222 | } 223 | 224 | type AddOper struct { 225 | L Expression 226 | R Expression 227 | } 228 | 229 | func (a *AddOper) operationType() { 230 | } 231 | func (a *AddOper) expression() { 232 | } 233 | 234 | type AssertOper struct { 235 | L Expression 236 | } 237 | 238 | func (a *AssertOper) operationType() { 239 | } 240 | 241 | func (a *AssertOper) expression() { 242 | } 243 | 244 | type DivideOper struct { 245 | L Expression 246 | R Expression 247 | } 248 | 249 | func (d *DivideOper) operationType() { 250 | } 251 | func (d *DivideOper) expression() { 252 | } 253 | 254 | type ExponentiateOper struct { 255 | L Expression 256 | R Expression 257 | } 258 | 259 | func (e *ExponentiateOper) operationType() { 260 | } 261 | func (e *ExponentiateOper) expression() { 262 | } 263 | 264 | type FactorialOper struct { 265 | L Expression 266 | } 267 | 268 | func (f *FactorialOper) operationType() { 269 | } 270 | func (f *FactorialOper) expression() { 271 | } 272 | 273 | type ModuloOper struct { 274 | L Expression 275 | R Expression 276 | } 277 | 278 | func (m *ModuloOper) operationType() { 279 | } 280 | func (m *ModuloOper) expression() { 281 | } 282 | 283 | type MultiplyOper struct { 284 | L Expression 285 | R Expression 286 | } 287 | 288 | func (m *MultiplyOper) operationType() { 289 | } 290 | func (m *MultiplyOper) expression() { 291 | } 292 | 293 | type NegateOper struct { 294 | L Expression 295 | } 296 | 297 | func (n *NegateOper) operationType() { 298 | } 299 | func (n *NegateOper) expression() { 300 | } 301 | 302 | type SubtractOper struct { 303 | L Expression 304 | R Expression 305 | } 306 | 307 | func (s *SubtractOper) operationType() { 308 | } 309 | 310 | func (s *SubtractOper) expression() { 311 | } 312 | 313 | type LikeOper
struct { 314 | L Expression 315 | R Expression 316 | } 317 | 318 | func (l *LikeOper) operationType() { 319 | } 320 | func (l *LikeOper) expression() { 321 | } 322 | 323 | func GobReg() { 324 | 325 | gob.Register(&Literal{}) 326 | var buf bytes.Buffer 327 | enc := gob.NewEncoder(&buf) 328 | _ = enc.Encode(&Literal{}) 329 | } 330 | -------------------------------------------------------------------------------- /sqlparser/ast/fn.go: -------------------------------------------------------------------------------- 1 | package ast 2 | 3 | // ExprFnBool is a visitor/predicate applied to AST expression nodes. 4 | type ExprFnBool func(expression Expression) bool 5 | 6 | // ExprFn maps an AST expression node to a (possibly different) node. 7 | type ExprFn func(expression Expression) Expression 8 | 9 | // Contains reports whether fn returns true for expr or for any node that 10 | // Walk visits beneath it. Walk reports false for node types it does not 11 | // recognize (including nil), so Contains also reports true when it meets one. 12 | func Contains(expr Expression, fn ExprFnBool) bool { 13 | return !Walk(expr, func(expr Expression) bool { 14 | return !fn(expr) 15 | }) 16 | } 17 |
18 | // Walk calls visitor on expr and then recursively on its operands, and 19 | // reports whether every visited node returned true. For two-operand nodes the 20 | // right operand is skipped once the left subtree reports false, and a 21 | // *Function stops at the first argument that reports false. Node types with 22 | // no case below (including nil) make Walk report false. 23 | func Walk(expr Expression, visitor ExprFnBool) bool { 24 | visitorBool := visitor(expr) 25 | 26 | var exprBool bool 27 | switch v := expr.(type) { 28 | case *Operation: 29 | // Transform builds operators wrapped in *Operation, and every concrete 30 | // operator also implements Expression, so descend into the wrapped node 31 | // (visitor sees both the wrapper and the operator) instead of reporting false. 32 | inner, ok := v.Operation.(Expression) 33 | exprBool = ok && Walk(inner, visitor)
34 | case *AddOper: 35 | exprBool = Walk(v.L, visitor) && Walk(v.R, visitor) 36 | case *AndOper: 37 | exprBool = Walk(v.L, visitor) && Walk(v.R, visitor) 38 | case *DivideOper: 39 | exprBool = Walk(v.L, visitor) && Walk(v.R, visitor) 40 | case *EqualOper: 41 | exprBool = Walk(v.L, visitor) && Walk(v.R, visitor) 42 | case *ExponentiateOper: 43 | exprBool = Walk(v.L, visitor) && Walk(v.R, visitor) 44 | case *GreaterThanOrEqualOper: 45 | exprBool = Walk(v.L, visitor) && Walk(v.R, visitor) 46 | case *GreaterThanOper: 47 | exprBool = Walk(v.L, visitor) && Walk(v.R, visitor) 48 | case *LessThanOper: 49 | exprBool = Walk(v.L, visitor) && Walk(v.R, visitor) 50 | case *LessThanOrEqualOper: 51 | exprBool = Walk(v.L, visitor) && Walk(v.R, visitor) 52 | case *LikeOper: 53 | exprBool = Walk(v.L, visitor) && Walk(v.R, visitor) 54 | case *ModuloOper: 55 | exprBool = Walk(v.L, visitor) && Walk(v.R, visitor) 56 | case *MultiplyOper: 57 | exprBool = Walk(v.L, visitor) && Walk(v.R, visitor) 58 | case *NotEqualOper: 59 | exprBool = Walk(v.L, visitor) && Walk(v.R, visitor) 60 | case *OrOper: 61 | exprBool = Walk(v.L, visitor) && Walk(v.R, visitor) 62 | case *SubtractOper: 63 | exprBool = Walk(v.L, visitor) && Walk(v.R, visitor) 64 | case *AssertOper: 65 | exprBool = Walk(v.L, visitor) 66 | case *FactorialOper: 67 | exprBool = Walk(v.L, visitor) 68 | case *IsNullOper: 69 | exprBool = Walk(v.L, visitor) 70 | case *NegateOper: 71 | exprBool = Walk(v.L, visitor) 72 | case *NotOper: 73 | exprBool = Walk(v.L, visitor) 74 | case *Function: 75 | for _, fn := range v.Args { 76 | if !Walk(fn, visitor) { 77 | return false 78 | } 79 | } 80 | exprBool = true 81 | case *Literal, *Field, *ColumnIdx: 82 | exprBool = true 83 | } 84 | 85 | return visitorBool && exprBool 86 | } 87 |
88 | // ReplaceWith applies fn to expr and returns the result. 89 | func ReplaceWith(expr Expression, fn ExprFn) Expression { 90 | return fn(expr) 91 | } 92 |
93 | // Transform rewrites expr. before is applied to a node first; then, for 94 | // *Operation and *Function nodes, a new node is built from the recursively 95 | // transformed operands/arguments (left operand before right); finally after 96 | // is applied to the resulting node. Other node types are passed through 97 | // before and after only. 98 | func Transform(expr Expression, before ExprFn, after ExprFn) Expression { 99 | recurse := func(e Expression) Expression { 100 | return Transform(e, before, after) 101 | } 102 | 103 | expr = before(expr) 104 | switch v := expr.(type) { 105 | case *Operation: 106 | switch op := v.Operation.(type) {
107 | case *AddOper: 108 | expr = &Operation{Operation: &AddOper{L: recurse(op.L), R: recurse(op.R)}} 109 | case *AndOper: 110 | expr = &Operation{Operation: &AndOper{L: recurse(op.L), R: recurse(op.R)}} 111 | case *DivideOper: 112 | expr = &Operation{Operation: &DivideOper{L: recurse(op.L), R: recurse(op.R)}} 113 | case *EqualOper: 114 | expr = &Operation{Operation: &EqualOper{L: recurse(op.L), R: recurse(op.R)}} 115 | case *ExponentiateOper: 116 | expr = &Operation{Operation: &ExponentiateOper{L: recurse(op.L), R: recurse(op.R)}} 117 | case *GreaterThanOrEqualOper: 118 | expr = &Operation{Operation: &GreaterThanOrEqualOper{L: recurse(op.L), R: recurse(op.R)}} 119 | case *GreaterThanOper: 120 | expr = &Operation{Operation: &GreaterThanOper{L: recurse(op.L), R: recurse(op.R)}} 121 | case *LessThanOper: 122 | expr = &Operation{Operation: &LessThanOper{L: recurse(op.L), R: recurse(op.R)}} 123 | case *LessThanOrEqualOper: 124 | expr = &Operation{Operation: &LessThanOrEqualOper{L: recurse(op.L), R: recurse(op.R)}} 125 | case *LikeOper: 126 | expr = &Operation{Operation: &LikeOper{L: recurse(op.L), R: recurse(op.R)}} 127 | case *ModuloOper: 128 | expr = &Operation{Operation: &ModuloOper{L: recurse(op.L), R: recurse(op.R)}} 129 | case *MultiplyOper: 130 | expr = &Operation{Operation: &MultiplyOper{L: recurse(op.L), R: recurse(op.R)}} 131 | case *NotEqualOper: 132 | expr = &Operation{Operation: &NotEqualOper{L: recurse(op.L), R: recurse(op.R)}} 133 | case *OrOper: 134 | expr = &Operation{Operation: &OrOper{L: recurse(op.L), R: recurse(op.R)}} 135 | case *SubtractOper: 136 | expr = &Operation{Operation: &SubtractOper{L: recurse(op.L), R: recurse(op.R)}}
137 | case *AssertOper: 138 | expr = &Operation{Operation: &AssertOper{L: recurse(op.L)}} 139 | case *FactorialOper: 140 | expr = &Operation{Operation: &FactorialOper{L: recurse(op.L)}} 141 | case *IsNullOper: 142 | expr = &Operation{Operation: &IsNullOper{L: recurse(op.L)}} 143 | case *NegateOper: 144 | expr = &Operation{Operation: &NegateOper{L: recurse(op.L)}} 145 | case *NotOper: 146 | expr = &Operation{Operation: &NotOper{L: recurse(op.L)}} 147 | } 148 | 149 | case *Function: 150 | var args []Expression 151 | for _, arg := range v.Args { 152 | args = append(args, recurse(arg)) 153 | } 154 | expr = &Function{ 155 | FuncName: v.FuncName, 156 | Args: args, 157 | } 158 | } 159 | 160 | return after(expr) 161 | } 162 |
-------------------------------------------------------------------------------- /sqlparser/ast/stmt.go: -------------------------------------------------------------------------------- 1 | package ast 2 | 3 | type Stmt interface { 4 | StmtIter() 5 | } 6 | 7 | type BeginStmt struct { 8 | ReadOnly bool 9 | AsOf uint64 10 | } 11 | 12 | func (b *BeginStmt) StmtIter() { 13 | } 14 | 15 | type CreateStmt struct { 16 | Name string 17 | Columns []*Column 18 | } 19 | 20 | func (c *CreateStmt) StmtIter() { 21 | } 22 | 23 | type SelectStmt struct { 24 | Select []*ExprAS 25 | From []FromItem 26 | Where Expression 27 | GroupBy []Expression 28
| Having Expression 29 | Order []*Order 30 | Offset Expression 31 | Limit Expression 32 | } 33 | 34 | func (s *SelectStmt) StmtIter() { 35 | } 36 | 37 | type CommitStmt struct { 38 | } 39 | 40 | func (c *CommitStmt) StmtIter() { 41 | } 42 | 43 | type RollbackStmt struct { 44 | } 45 | 46 | func (r *RollbackStmt) StmtIter() { 47 | } 48 | 49 | type ExplainStmt struct { 50 | Stmt Stmt 51 | } 52 | 53 | func (e *ExplainStmt) StmtIter() { 54 | } 55 | 56 | type DropTableStmt struct { 57 | TableName string 58 | } 59 | 60 | func (d *DropTableStmt) StmtIter() { 61 | } 62 | 63 | type DeleteStmt struct { 64 | TableName string 65 | Where Expression 66 | } 67 | 68 | func (d *DeleteStmt) StmtIter() { 69 | } 70 | 71 | type InsertStmt struct { 72 | TableName string 73 | Columns []string 74 | Values [][]Expression 75 | } 76 | 77 | func (i *InsertStmt) StmtIter() { 78 | } 79 | 80 | type UpdateStmt struct { 81 | TableName string 82 | Set []*ExprColumn 83 | Where Expression 84 | } 85 | 86 | func (u *UpdateStmt) StmtIter() { 87 | } 88 | -------------------------------------------------------------------------------- /sqlparser/model/model.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import "strings" 4 | 5 | func LowStr(s string) string { 6 | return strings.ToLower(s) 7 | } 8 | -------------------------------------------------------------------------------- /sqlparser/parser/charset.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "github.com/pkg/errors" 5 | "strings" 6 | ) 7 | 8 | type Charset struct { 9 | Name string 10 | DefaultCollation string 11 | Collations map[string]*Collation 12 | Desc string 13 | Maxlen int 14 | } 15 | 16 | type Collation struct { 17 | ID int 18 | CharsetName string 19 | Name string 20 | IsDefault bool 21 | } 22 | 23 | var charsets = make(map[string]*Charset) 24 | 25 | func GetCharsetInfo(cs string) (string, string, error)
{ 26 | c, ok := charsets[strings.ToLower(cs)] 27 | if !ok { 28 | return "", "", errors.Errorf("Unknown charset %s", cs) 29 | } 30 | return c.Name, c.DefaultCollation, nil 31 | } 32 | -------------------------------------------------------------------------------- /sqlparser/parser/handle.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "cabbageDB/sqlparser/ast" 5 | "errors" 6 | ) 7 | 8 | var Sqlerr error 9 | 10 | func handleColumn(columnOptionList []*ast.ColumnOption, column *ast.Column) error { 11 | for _, columnOption := range columnOptionList { 12 | if columnOption == nil { 13 | continue 14 | } 15 | switch columnOption.Type { 16 | case ast.NULL: 17 | column.Nullable = true 18 | case ast.NOTNULL: 19 | column.Nullable = false 20 | case ast.PRIMARYKEY: 21 | column.PrimaryKey = true 22 | case ast.DEFAULT: 23 | column.Default = columnOption.Value.(ast.Expression) 24 | case ast.UNIQUE: 25 | column.Unique = true 26 | } 27 | } 28 | return nil 29 | } 30 | 31 | func handleConstraint(columns []*ast.Column, c *ast.Constraint) error { 32 | if c == nil { 33 | return nil 34 | } 35 | 36 | flag := false 37 | for _, col := range columns { 38 | if c.Type == ast.FORREGINKEYConstraint && c.SubColumnName == col.Name { 39 | flag = true 40 | if c.TableName == "" { 41 | return errors.New("References Table Name is empty") 42 | } 43 | if c.ColumnName == "" { 44 | return errors.New("References Column Name is empty") 45 | } 46 | col.References = &ast.Field{ 47 | TableName: c.TableName, 48 | ColumnName: c.ColumnName, 49 | } 50 | } 51 | 52 | if col.Name == c.ColumnName { 53 | flag = true 54 | switch c.Type { 55 | case ast.PRIMARYKEYConstraint: 56 | col.PrimaryKey = true 57 | case ast.UNIQUEKEYConstraint: 58 | col.Unique = true 59 | case ast.KEYConstraint: 60 | col.Index = true 61 | } 62 | } 63 | } 64 | if flag == false { 65 | return errors.New("not found column " + c.ColumnName) 66 | } 67 | return nil 68 | } 69 | 70 |
func getUint64FromNUM(num interface{}) uint64 { 71 | switch v := num.(type) { 72 | case int64: 73 | return uint64(v) 74 | case uint64: 75 | return v 76 | } 77 | return 0 78 | } 79 | 80 | func valiDate(columns []*ast.Column) error { 81 | hasPrimaryKey := false 82 | columnNames := make(map[string]struct{}) 83 | for _, col := range columns { 84 | if _, ok := columnNames[col.Name]; ok { 85 | return errors.New("column " + col.Name + " alread exists") 86 | } 87 | columnNames[col.Name] = struct{}{} 88 | 89 | if col.PrimaryKey { 90 | if hasPrimaryKey { 91 | // 发现第二个主键时,立即报错并指明列名 92 | return errors.New("multiple primary keys defined (column " + col.Name + ")") 93 | } 94 | hasPrimaryKey = true // 标记已存在主键 95 | } 96 | } 97 | if !hasPrimaryKey { 98 | // 如果没有主键 99 | return errors.New("must be have a primary key") 100 | } 101 | return nil 102 | } 103 | -------------------------------------------------------------------------------- /sqlparser/parser/lexer.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "bytes" 5 | "unicode" 6 | "unicode/utf8" 7 | ) 8 | 9 | type Scanner struct { 10 | r reader 11 | buf bytes.Buffer 12 | errs []error 13 | identifierDot bool 14 | } 15 | 16 | type reader struct { 17 | s string 18 | p Pos 19 | w int 20 | } 21 | 22 | type Pos struct { 23 | Line int 24 | Col int 25 | Offset int 26 | } 27 | 28 | func (s *Scanner) reset(sql string) { 29 | s.r = reader{s: sql, p: Pos{Line: 1}} 30 | s.buf.Reset() 31 | s.errs = s.errs[:0] 32 | } 33 | 34 | func (s *Scanner) skipWhitespace() rune { 35 | return s.r.incAsLongAs(unicode.IsSpace) 36 | } 37 | 38 | func (s *Scanner) scan() (tok int, pos Pos, lit string) { 39 | ch0 := s.r.peek() 40 | if unicode.IsSpace(ch0) { 41 | ch0 = s.skipWhitespace() 42 | } 43 | 44 | pos = s.r.pos() 45 | if s.r.eof() { 46 | return 0, pos, "" 47 | } 48 | if !s.r.eof() && isIdentExtend(ch0) { 49 | return scanIdentifier(s) 50 | } 51 | 52 | node := &ruleTable 53 | for 
!(node.childs[ch0] == nil || s.r.eof()) { 54 | node = node.childs[ch0] 55 | if node.fn != nil { 56 | return node.fn(s) 57 | } 58 | s.r.inc() 59 | ch0 = s.r.peek() 60 | } 61 | 62 | tok, lit = node.token, s.r.data(&pos) 63 | return 64 | } 65 | 66 | func (r *reader) data(from *Pos) string { 67 | return r.s[from.Offset:r.p.Offset] 68 | } 69 | 70 | func scanIdentifier(s *Scanner) (int, Pos, string) { 71 | pos := s.r.pos() 72 | s.r.inc() 73 | s.r.incAsLongAs(isIdentChar) 74 | 75 | return identifier, pos, s.r.data(&pos) 76 | } 77 | 78 | var eof = Pos{-1, -1, -1} 79 | 80 | func (r *reader) eof() bool { 81 | // 如果Offset>reader的s长度则到了eof 82 | return r.p.Offset >= len(r.s) 83 | } 84 | 85 | func (r *reader) peek() rune { 86 | if r.eof() { 87 | return unicode.ReplacementChar 88 | } 89 | 90 | // s.r.incAsLongAs(isIdentChar) 在这里迭代 91 | v, w := rune(r.s[r.p.Offset]), 1 92 | switch { 93 | case v == 0: 94 | r.w = w 95 | return v 96 | case v >= 0x80: 97 | v, w = utf8.DecodeRuneInString(r.s[r.p.Offset:]) 98 | if v == utf8.RuneError && w == 1 { 99 | v = rune(r.s[r.p.Offset]) 100 | } 101 | } 102 | r.w = w 103 | 104 | return v 105 | } 106 | 107 | func (r *reader) incAsLongAs(fn func(rune) bool) rune { 108 | for { 109 | ch := r.peek() 110 | if !fn(ch) { 111 | return ch 112 | } 113 | if ch == unicode.ReplacementChar && r.eof() { 114 | return 0 115 | } 116 | r.inc() 117 | } 118 | } 119 | 120 | func (r *reader) inc() { 121 | if r.s[r.p.Offset] == '\n' { 122 | r.p.Line++ 123 | r.p.Col = 0 124 | } 125 | r.p.Offset += r.w 126 | r.p.Col++ 127 | } 128 | 129 | func (r *reader) pos() Pos { 130 | return r.p 131 | } 132 | 133 | func (s *Scanner) isTokenIdentifier(lit string, offset int) int { 134 | if s.r.peek() == '.' { 135 | return 0 136 | } 137 | for idx := offset - 1; idx >= 0; idx-- { 138 | if s.r.s[idx] == ' ' { 139 | continue 140 | } else if s.r.s[idx] == '.' 
{ 141 | return 0 142 | } 143 | break 144 | } 145 | buf := &s.buf 146 | buf.Reset() 147 | buf.Grow(len(lit)) 148 | data := buf.Bytes()[:len(lit)] 149 | 150 | for i := 0; i < len(lit); i++ { 151 | c := lit[i] 152 | if c >= 'a' && c <= 'z' { 153 | data[i] = c + 'A' - 'a' 154 | } else { 155 | data[i] = c 156 | } 157 | } 158 | 159 | tokenStr := string(data) 160 | tok, _ := tokenMap[tokenStr] 161 | 162 | return tok 163 | } 164 | 165 | func (p *Parser) GetResult() interface{} { 166 | ret := yyParse(p) 167 | 168 | return ret 169 | } 170 | 171 | func (r *reader) readByte() (ch byte) { 172 | ch = byte(r.peek()) 173 | if r.eof() { 174 | return 175 | } 176 | r.inc() 177 | return 178 | } 179 | 180 | func (s *Scanner) scanDigits() string { 181 | pos := s.r.pos() 182 | s.r.incAsLongAs(isDigit) 183 | return s.r.data(&pos) 184 | } 185 | 186 | func startWithDot(s *Scanner) (tok int, pos Pos, lit string) { 187 | pos = s.r.pos() 188 | s.r.inc() 189 | if s.identifierDot { 190 | return int('.'), pos, "." 191 | } 192 | if isDigit(s.r.peek()) { 193 | tok, p, l := s.scanFloat(&pos) 194 | if tok == identifier { 195 | return 0, p, l 196 | } 197 | return tok, p, l 198 | } 199 | tok, lit = int('.'), "." 200 | return 201 | } 202 | 203 | func (r *reader) updatePos(pos Pos) { 204 | r.p = pos 205 | } 206 | 207 | func (s *Scanner) scanFloat(beg *Pos) (tok int, pos Pos, lit string) { 208 | s.r.updatePos(*beg) 209 | // float = D1 . D2 e D3 210 | s.scanDigits() 211 | ch0 := s.r.peek() 212 | if ch0 == '.' 
{ 213 | s.r.inc() 214 | s.scanDigits() 215 | ch0 = s.r.peek() 216 | } 217 | if isDigit(s.r.peek()) { 218 | s.scanDigits() 219 | } 220 | pos, lit = *beg, s.r.data(beg) 221 | tok = floatLit 222 | return 223 | } 224 | -------------------------------------------------------------------------------- /sqlparser/parser/misc.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "bytes" 5 | "strings" 6 | ) 7 | 8 | func isIdentChar(ch rune) bool { 9 | return isLetter(ch) || isDigit(ch) || ch == '_' || ch == '$' || isIdentExtend(ch) 10 | } 11 | 12 | func isIdentExtend(ch rune) bool { 13 | return ch >= 0x80 && ch <= '\uffff' 14 | } 15 | func isLetter(ch rune) bool { 16 | return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') 17 | } 18 | func isDigit(ch rune) bool { 19 | return (ch >= '0' && ch <= '9') 20 | } 21 | 22 | type trieNode struct { 23 | childs [256]*trieNode 24 | token int 25 | fn func(s *Scanner) (int, Pos, string) 26 | } 27 | 28 | var ruleTable trieNode 29 | 30 | func initTokenFunc(str string, fn func(s *Scanner) (int, Pos, string)) { 31 | for i := 0; i < len(str); i++ { 32 | c := str[i] 33 | if ruleTable.childs[c] == nil { 34 | ruleTable.childs[c] = &trieNode{} 35 | } 36 | ruleTable.childs[c].fn = fn 37 | } 38 | } 39 | 40 | func initTokenByte(c byte, tok int) { 41 | if ruleTable.childs[c] == nil { 42 | ruleTable.childs[c] = &trieNode{} 43 | } 44 | ruleTable.childs[c].token = tok 45 | } 46 | 47 | func handleIdent(lval *yySymType) int { 48 | s := lval.ident 49 | 50 | if !strings.HasPrefix(s, "_") { 51 | return identifier 52 | } 53 | cs, _, err := GetCharsetInfo(s[1:]) 54 | if err != nil { 55 | return identifier 56 | } 57 | 58 | lval.ident = cs 59 | 60 | return 0 61 | } 62 | 63 | var tokenMap = map[string]int{ 64 | 65 | "AND": and, 66 | "AS": as, 67 | "ASC": asc, 68 | "AVG": avg, 69 | "BY": by, 70 | "COUNT": count, 71 | "CROSS": cross, 72 | "DESC": desc, 73 | "FROM": from, 74 | "GROUP": group, 
75 | "HAVING": having, 76 | "INNER": inner, 77 | "IS": is, 78 | "JOIN": join, 79 | "LEFT": left, 80 | "LIKE": like, 81 | "LIMIT": limit, 82 | "NOT": not, 83 | "NULL": null, 84 | "OFFSET": offset, 85 | "ON": on, 86 | "OR": or, 87 | "ORDER": order, 88 | "OUTER": outer, 89 | "RIGHT": right, 90 | "SELECT": selectKwd, 91 | "WHERE": where, 92 | "SUM": sum, 93 | "MAX": maxKwd, 94 | "MIN": minKwd, 95 | "CREATE": create, 96 | "TABLE": table, 97 | "IF": ifkwd, 98 | "EXISTS": exists, 99 | "CHAR": char, 100 | "VARCHAR": varchar, 101 | "TEXT": text, 102 | "BOOLEAN": boolean, 103 | "BOOL": boolkwd, 104 | "FLOAT": float, 105 | "DOUBLE": double, 106 | "INT": intkwd, 107 | "INTEGER": integer, 108 | "INSERT": insert, 109 | "INTO": into, 110 | "VALUES": values, 111 | "Integer": integer, 112 | "PRIMARY": primary, 113 | "KEY": key, 114 | "DEFAULT": defaultKwd, 115 | "UNIQUE": unique, 116 | "UPDATE": update, 117 | "SET": set, 118 | "DELETE": delete, 119 | "DROP": drop, 120 | "FOREIGN": foreign, 121 | "REFERENCES": references, 122 | "EXPLAIN": explain, 123 | "START": start, 124 | "COMMIT": commit, 125 | "READ": read, 126 | "ONLY": only, 127 | "WRITE": write, 128 | "BEGIN": begin, 129 | "TRANSACTION": transaction, 130 | "ROLLBACK": rollback, 131 | "TRUE": trueKwd, 132 | "FALSE": falseKwd, 133 | } 134 | 135 | func init() { 136 | initTokenByte('(', int('(')) 137 | initTokenByte(')', int(')')) 138 | initTokenByte(',', int(',')) 139 | initTokenByte('>', int('>')) 140 | initTokenByte('<', int('<')) 141 | initTokenByte('+', int('+')) 142 | initTokenByte('-', int('-')) 143 | initTokenByte('*', int('*')) 144 | initTokenByte('/', int('/')) 145 | initTokenByte('%', int('%')) 146 | initTokenByte('^', int('^')) 147 | 148 | initTokenByte('=', eq) 149 | initTokenString(">=", ge) 150 | initTokenString("<=", le) 151 | initTokenString("!=", neq) 152 | initTokenString("<>", neqSynonym) 153 | initTokenString("&&", and) 154 | initTokenString("||", or) 155 | 156 | initTokenFunc("0123456789", startWithNumber) 
157 | initTokenFunc("'\"", scanString) 158 | initTokenFunc(".", startWithDot) 159 | 160 | initTokenFunc("_$ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz", scanIdentifier) 161 | } 162 | 163 | func initTokenString(str string, tok int) { 164 | node := &ruleTable 165 | for _, c := range str { 166 | if node.childs[c] == nil { 167 | node.childs[c] = &trieNode{} 168 | } 169 | node = node.childs[c] 170 | } 171 | node.token = tok 172 | } 173 | 174 | func scanString(s *Scanner) (tok int, pos Pos, lit string) { 175 | tok, pos = stringLit, s.r.pos() 176 | ending := s.r.readByte() 177 | s.buf.Reset() 178 | for !s.r.eof() { 179 | ch0 := s.r.readByte() 180 | if ch0 == ending { 181 | if byte(s.r.peek()) != ending { 182 | lit = s.buf.String() 183 | return 184 | } 185 | s.r.inc() 186 | s.buf.WriteByte(ch0) 187 | } else if ch0 == '\\' { 188 | if s.r.eof() { 189 | break 190 | } 191 | s.handleEscape(byte(s.r.peek()), &s.buf) 192 | s.r.inc() 193 | } else { 194 | s.buf.WriteByte(ch0) 195 | } 196 | 197 | } 198 | tok = 0 199 | return 200 | } 201 | 202 | // handleEscape handles the case in scanString when previous char is '\'. 203 | func (*Scanner) handleEscape(b byte, buf *bytes.Buffer) { 204 | var ch0 byte 205 | /* 206 | \" \' \\ \n \0 \b \Z \r \t ==> escape to one char 207 | \% \_ ==> preserve both char 208 | other ==> remove \ 209 | */ 210 | switch b { 211 | case 'n': 212 | ch0 = '\n' 213 | case '0': 214 | ch0 = 0 215 | case 'b': 216 | ch0 = 8 217 | case 'Z': 218 | ch0 = 26 219 | case 'r': 220 | ch0 = '\r' 221 | case 't': 222 | ch0 = '\t' 223 | case '%', '_': 224 | buf.WriteByte('\\') 225 | ch0 = b 226 | default: 227 | ch0 = b 228 | } 229 | buf.WriteByte(ch0) 230 | } 231 | 232 | func startWithNumber(s *Scanner) (tok int, pos Pos, lit string) { 233 | 234 | pos = s.r.pos() 235 | tok = intLit 236 | ch0 := s.r.readByte() 237 | 238 | s.scanDigits() 239 | ch0 = byte(s.r.peek()) 240 | if ch0 == '.' 
|| ch0 == 'e' || ch0 == 'E' { 241 | return s.scanFloat(&pos) 242 | } 243 | 244 | if !s.r.eof() && isIdentChar(rune(ch0)) { 245 | s.r.incAsLongAs(isIdentChar) 246 | return identifier, pos, s.r.data(&pos) 247 | } 248 | lit = s.r.data(&pos) 249 | return 250 | } 251 | -------------------------------------------------------------------------------- /sqlparser/parser/parser.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "cabbageDB/sqlparser/ast" 5 | "errors" 6 | "strconv" 7 | ) 8 | 9 | type Parser struct { 10 | Scan *Scanner 11 | Result []ast.Stmt 12 | ErrorMsg error 13 | } 14 | 15 | func (p *Parser) Reset(sql string) { 16 | scan := Scanner{} 17 | scan.reset(sql) 18 | p.Scan = &scan 19 | } 20 | 21 | func (p *Parser) Lex(lval *yySymType) int { 22 | tok, pos, lit := p.Scan.scan() 23 | 24 | if tok == identifier { 25 | tok = handleIdent(lval) 26 | } 27 | 28 | if tok == identifier { 29 | if tok1 := p.Scan.isTokenIdentifier(lit, pos.Offset); tok1 != 0 { 30 | tok = tok1 31 | } 32 | } 33 | if tok == intLit { 34 | n, _ := strconv.ParseUint(lit, 10, 64) 35 | lval.item = int64(n) 36 | } 37 | if tok == floatLit { 38 | n, _ := strconv.ParseFloat(lit, 64) 39 | lval.item = float64(n) 40 | } 41 | switch tok { 42 | 43 | case identifier: 44 | lval.ident = lit 45 | case stringLit: 46 | lval.ident = lit 47 | } 48 | 49 | return tok 50 | } 51 | 52 | func (p *Parser) Error(s string) { 53 | p.ErrorMsg = errors.New(s) 54 | } 55 | 56 | func (p *Parser) ParseSQL() int { 57 | ret := yyParse(p) 58 | return ret 59 | } 60 | -------------------------------------------------------------------------------- /sqlparser/parser/sourcegoyacc.exe: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shihuricha/cabbageDB/215fb495e9cc322a967e78fb4947476ff084736a/sqlparser/parser/sourcegoyacc.exe -------------------------------------------------------------------------------- 
/storage/mvcc.go: -------------------------------------------------------------------------------- 1 | package storage 2 | 3 | import ( 4 | "cabbageDB/bitcask" 5 | "cabbageDB/log" 6 | "cabbageDB/logger" 7 | "cabbageDB/util" 8 | "errors" 9 | "math" 10 | "sync" 11 | ) 12 | 13 | type MVCC struct { 14 | Engine log.Engine 15 | Mu sync.Mutex 16 | } 17 | 18 | func (mvcc *MVCC) Resume(state *TransactionState) *MVCCTransactionEngine { 19 | return MVCCTxnEngineResume(mvcc.Engine, state) 20 | } 21 | 22 | type Status struct { 23 | Version uint64 24 | ActiveTxns uint64 25 | Storage *bitcask.Status 26 | } 27 | 28 | func (mvcc *MVCC) Status() *Status { 29 | engine := mvcc.Engine 30 | nextVersion := NextVersion{} 31 | versionByte := engine.Get(nextVersion.MVCCEncode()) 32 | var version uint64 33 | if len(versionByte) == 0 { 34 | version = 0 35 | } else { 36 | err := util.ByteToInt(versionByte, &version) 37 | if err != nil { 38 | logger.Info("err:", err.Error()) 39 | } 40 | } 41 | txnActive := TxnActive{} 42 | activeTxns := uint64(len(engine.ScanPrefix(txnActive.MVCCEncode()))) 43 | return &Status{ 44 | Version: version, 45 | ActiveTxns: activeTxns, 46 | Storage: engine.Status(), 47 | } 48 | } 49 | 50 | func (mvcc *MVCC) GetUnversioned(key []byte) []byte { 51 | mvcc.Mu.Lock() 52 | defer mvcc.Mu.Unlock() 53 | versioned := &UnVersioned{Key: key} 54 | return mvcc.Engine.Get(versioned.MVCCEncode()) 55 | } 56 | 57 | func (mvcc *MVCC) SetUnversioned(key, value []byte) { 58 | mvcc.Mu.Lock() 59 | defer mvcc.Mu.Unlock() 60 | versioned := &UnVersioned{Key: key} 61 | mvcc.Engine.Set(versioned.MVCCEncode(), value) 62 | } 63 | 64 | func (mvcc *MVCC) Begin() *MVCCTransactionEngine { 65 | return MVCCTxnEngineBegin(mvcc.Engine) 66 | } 67 | 68 | func (mvcc *MVCC) BeginReadOnly() *MVCCTransactionEngine { 69 | return MVCCTxnEngineBeginReadOnly(mvcc.Engine, 0) 70 | } 71 | 72 | func (mvcc *MVCC) BeginAsOf(version Version) *MVCCTransactionEngine { 73 | return MVCCTxnEngineBeginReadOnly(mvcc.Engine, 
version) 74 | } 75 | 76 | func NewMVCC(engine log.Engine) *MVCC { 77 | return &MVCC{ 78 | Engine: engine, 79 | Mu: sync.Mutex{}, 80 | } 81 | } 82 | 83 | type MVCCTransactionEngine struct { 84 | Engine log.Engine 85 | St *TransactionState 86 | } 87 | 88 | func (mvccTxnEngine *MVCCTransactionEngine) Version() Version { 89 | return mvccTxnEngine.St.Version 90 | } 91 | 92 | func (mvccTxnEngine *MVCCTransactionEngine) ReadOnly() bool { 93 | return mvccTxnEngine.St.ReadOnly 94 | } 95 | 96 | func (mvccTxnEngine *MVCCTransactionEngine) Set(keyPrefix [2]byte, key []byte, value []byte) error { 97 | return mvccTxnEngine.WriteVersion(keyPrefix, key, value) 98 | } 99 | 100 | func (mvccTxnEngine *MVCCTransactionEngine) DDLSet(key []byte, value []byte) bool { 101 | mvccTxnEngine.Engine.Set(key, value) 102 | return true 103 | } 104 | func (mvccTxnEngine *MVCCTransactionEngine) DDLGet(key []byte) []byte { 105 | return mvccTxnEngine.Engine.Get(key) 106 | } 107 | func (mvccTxnEngine *MVCCTransactionEngine) DDLDelete(key []byte) { 108 | mvccTxnEngine.Engine.Delete(key) 109 | } 110 | 111 | func (mvccTxnEngine *MVCCTransactionEngine) DDLScanPrefix(keyPrefix [2]byte) []*bitcask.ByteMap { 112 | return mvccTxnEngine.Engine.ScanPrefix(keyPrefix[:]) 113 | } 114 | 115 | func (mvccTxnEngine *MVCCTransactionEngine) ScanPrefix(keyPrefix [2]byte) []*bitcask.ByteMap { 116 | from := &Versioned{ 117 | KeyPrefix: keyPrefix, 118 | Version: 0, 119 | Key: []byte{}, 120 | } 121 | to := &Versioned{ 122 | KeyPrefix: keyPrefix, 123 | Version: mvccTxnEngine.St.Version + 1, 124 | Key: []byte{}, 125 | } 126 | 127 | itemList := mvccTxnEngine.Engine.Scan(from.MVCCEncode(), to.MVCCEncode()) 128 | itemListResp := []*bitcask.ByteMap{} 129 | for i := len(itemList) - 1; i >= 0; i-- { 130 | versionKey := DecodeKey(itemList[i].Key) 131 | if v, ok := versionKey.(*Versioned); ok { 132 | if mvccTxnEngine.St.IsVisible(v.Version) { 133 | itemListResp = append(itemListResp, itemList[i]) 134 | } 135 | } 136 | } 137 | return 
itemListResp 138 | } 139 | func (mvccTxnEngine *MVCCTransactionEngine) Delete(keyPrefix [2]byte, key []byte) bool { 140 | 141 | from := &Versioned{KeyPrefix: keyPrefix, Version: 0, Key: key} 142 | to := &Versioned{KeyPrefix: keyPrefix, Version: math.MaxInt, Key: key} 143 | versionList := mvccTxnEngine.Engine.Scan(from.MVCCEncode(), to.MVCCEncode()) 144 | for _, item := range versionList { 145 | versioned, ok := DecodeKey(item.Key).(*Versioned) 146 | if !ok { 147 | continue 148 | } 149 | if slicesEqual(versioned.Key, key) { 150 | mvccTxnEngine.Engine.Delete(item.Key) 151 | } 152 | } 153 | 154 | return true 155 | } 156 | func (mvccTxnEngine *MVCCTransactionEngine) Get(keyPrefix [2]byte, key []byte) []byte { 157 | from := &Versioned{ 158 | KeyPrefix: keyPrefix, 159 | Key: key, 160 | Version: 0, 161 | } 162 | to := &Versioned{ 163 | KeyPrefix: keyPrefix, 164 | Key: key, 165 | Version: mvccTxnEngine.St.Version, 166 | } 167 | 168 | itemList := mvccTxnEngine.Engine.Scan(from.MVCCEncode(), to.MVCCEncode()) 169 | 170 | for i := len(itemList) - 1; i >= 0; i-- { 171 | versionKey := DecodeKey(itemList[i].Key) 172 | v, ok := versionKey.(*Versioned) 173 | if !ok { 174 | continue 175 | } 176 | 177 | if mvccTxnEngine.St.IsVisible(v.Version) { 178 | if slicesEqual(v.Key, key) { 179 | return itemList[i].Value 180 | } 181 | } 182 | 183 | } 184 | return []byte{} 185 | } 186 | 187 | func slicesEqual(a, b []byte) bool { 188 | if len(a) != len(b) { 189 | return false 190 | } 191 | for i := range a { 192 | if a[i] != b[i] { 193 | return false 194 | } 195 | } 196 | return true 197 | } 198 | 199 | func (mvccTxnEngine *MVCCTransactionEngine) WriteVersion(keyPrefix [2]byte, key []byte, value []byte) error { 200 | if mvccTxnEngine.St.ReadOnly { 201 | return nil 202 | } 203 | 204 | version := mvccTxnEngine.St.Version + 1 205 | 206 | for tmpVersion, _ := range mvccTxnEngine.St.Active { 207 | if tmpVersion < version { 208 | version = tmpVersion 209 | } 210 | } 211 | from := &Versioned{KeyPrefix: 
keyPrefix, Version: version, Key: key} 212 | to := &Versioned{KeyPrefix: keyPrefix, Version: math.MaxInt, Key: key} 213 | versionList := mvccTxnEngine.Engine.Scan(from.MVCCEncode(), to.MVCCEncode()) 214 | 215 | for _, v := range versionList { 216 | versionByte := DecodeKey(v.Key) 217 | if versioned, ok := versionByte.(*Versioned); ok { 218 | if !mvccTxnEngine.St.IsVisible(versioned.Version) { 219 | return errors.New("Error Serialization") 220 | } 221 | } 222 | } 223 | 224 | txnWrite := &TxnWrite{ 225 | Version: mvccTxnEngine.St.Version, 226 | Key: append(keyPrefix[:], key...), 227 | } 228 | mvccTxnEngine.Engine.Set(txnWrite.MVCCEncode(), []byte{'1'}) 229 | 230 | versioned := &Versioned{ 231 | KeyPrefix: keyPrefix, 232 | Key: key, 233 | Version: mvccTxnEngine.St.Version, 234 | } 235 | 236 | mvccTxnEngine.Engine.Set(versioned.MVCCEncode(), value) 237 | 238 | return nil 239 | } 240 | 241 | func (mvccTxnEngine *MVCCTransactionEngine) Commit() bool { 242 | if mvccTxnEngine.St.ReadOnly { 243 | return true 244 | } 245 | fromTxnWrite := TxnWrite{ 246 | Version: mvccTxnEngine.St.Version, 247 | } 248 | toTxnWrite := TxnWrite{ 249 | Version: mvccTxnEngine.St.Version + 1, 250 | } 251 | removeKV := mvccTxnEngine.Engine.Scan(fromTxnWrite.MVCCEncode(), toTxnWrite.MVCCEncode()) 252 | for _, v := range removeKV { 253 | mvccTxnEngine.Engine.Delete(v.Key) 254 | } 255 | txnActive := TxnActive{ 256 | Version: mvccTxnEngine.St.Version, 257 | } 258 | mvccTxnEngine.Engine.Delete(txnActive.MVCCEncode()) 259 | 260 | return true 261 | } 262 | 263 | func (mvccTxnEngine *MVCCTransactionEngine) RollBack() bool { 264 | if mvccTxnEngine.St.ReadOnly { 265 | return true 266 | } 267 | 268 | fromTxnWrite := TxnWrite{ 269 | Version: mvccTxnEngine.St.Version, 270 | } 271 | toTxnWrite := TxnWrite{ 272 | Version: mvccTxnEngine.St.Version + 1, 273 | } 274 | 275 | scan := mvccTxnEngine.Engine.Scan(fromTxnWrite.MVCCEncode(), toTxnWrite.MVCCEncode()) 276 | for _, v := range scan { 277 | key := 
DecodeKey(v.Key) 278 | if txnWrite, ok := key.(*TxnWrite); ok { 279 | versionByte := &Versioned{ 280 | Version: mvccTxnEngine.St.Version, 281 | Key: txnWrite.Key[2:], 282 | KeyPrefix: [2]byte{txnWrite.Key[0], txnWrite.Key[1]}, 283 | } 284 | mvccTxnEngine.Engine.Delete(versionByte.MVCCEncode()) 285 | } else { 286 | return false 287 | } 288 | mvccTxnEngine.Engine.Delete(v.Key) 289 | } 290 | 291 | txnActive := TxnActive{ 292 | Version: mvccTxnEngine.St.Version, 293 | } 294 | mvccTxnEngine.Engine.Delete(txnActive.MVCCEncode()) 295 | return true 296 | } 297 | 298 | func MVCCTxnEngineBegin(engine log.Engine) *MVCCTransactionEngine { 299 | nextVersion := NextVersion{} 300 | versionByte := engine.Get(nextVersion.MVCCEncode()) 301 | var versionValue uint64 302 | versionValue = 1 303 | if len(versionByte) == 0 { 304 | versionValue = uint64(1) 305 | } else { 306 | _ = util.ByteToInt(versionByte, &versionValue) 307 | } 308 | newVersionByte := util.BinaryToByte(versionValue + 1) 309 | engine.Set(nextVersion.MVCCEncode(), newVersionByte) 310 | 311 | active := ScanActive(engine) 312 | if len(active) > 0 { 313 | txnActiveSnap := TxnActiveSnapshot{Version: Version(versionValue)} 314 | engine.Set(txnActiveSnap.MVCCEncode(), util.BinaryStructToByte(&active)) 315 | } 316 | txnActive := TxnActive{ 317 | Version: Version(versionValue), 318 | } 319 | engine.Set(txnActive.MVCCEncode(), []byte{'1'}) 320 | return &MVCCTransactionEngine{ 321 | Engine: engine, 322 | St: &TransactionState{ 323 | Version: Version(versionValue), 324 | ReadOnly: false, 325 | Active: active, 326 | }, 327 | } 328 | } 329 | 330 | func MVCCTxnEngineBeginReadOnly(engine log.Engine, asOf Version) *MVCCTransactionEngine { 331 | nextVersion := NextVersion{} 332 | versionByte := engine.Get(nextVersion.MVCCEncode()) 333 | var versionValue uint64 334 | if len(versionByte) == 0 { 335 | versionValue = 1 336 | } else { 337 | _ = util.ByteToInt(versionByte, &versionValue) 338 | } 339 | var active VersionHashSet 340 | if asOf != 
0 { 341 | if uint64(asOf) >= versionValue { 342 | logger.Info(" specified version>last version,return last data") 343 | } 344 | versionValue = uint64(asOf) 345 | txnActiveSnap := TxnActiveSnapshot{ 346 | Version: Version(versionValue), 347 | } 348 | txnActiveSnapVersionByte := engine.Get(txnActiveSnap.MVCCEncode()) 349 | snapVersion := DecodeKey(txnActiveSnapVersionByte) 350 | if _, ok := snapVersion.(*TxnActiveSnapshot); ok { 351 | util.ByteToStruct(txnActiveSnapVersionByte, &active) 352 | } 353 | } else { 354 | active = ScanActive(engine) 355 | 356 | } 357 | 358 | return &MVCCTransactionEngine{Engine: engine, St: &TransactionState{ 359 | Version: Version(versionValue), 360 | ReadOnly: true, 361 | Active: active, 362 | }} 363 | } 364 | 365 | func ScanActive(session log.Engine) VersionHashSet { 366 | active := VersionHashSet{} 367 | txnActive := TxnActive{} 368 | scan := session.ScanPrefix(txnActive.MVCCEncode()) 369 | for _, v := range scan { 370 | txnActiveInterface := DecodeKey(v.Key) 371 | 372 | txnActive1, ok := txnActiveInterface.(*TxnActive) 373 | if ok { 374 | active[txnActive1.Version] = struct{}{} 375 | } 376 | } 377 | return active 378 | } 379 | 380 | func MVCCTxnEngineResume(engine log.Engine, state *TransactionState) *MVCCTransactionEngine { 381 | active := TxnActive{state.Version} 382 | value := engine.Get(active.MVCCEncode()) 383 | if !state.ReadOnly && len(value) == 0 { 384 | logger.Info("No active transaction at version ", state.Version) 385 | return nil 386 | } 387 | return &MVCCTransactionEngine{ 388 | Engine: engine, 389 | St: state, 390 | } 391 | } 392 | -------------------------------------------------------------------------------- /storage/mvccTransaction.go: -------------------------------------------------------------------------------- 1 | package storage 2 | 3 | type VersionHashSet = map[Version]struct{} 4 | type TransactionState struct { 5 | Version Version 6 | ReadOnly bool 7 | Active VersionHashSet 8 | } 9 | 10 | func (txnState 
*TransactionState) IsVisible(version Version) bool { 11 | if _, ok := txnState.Active[version]; ok { 12 | if version == txnState.Version { 13 | // 虽然活动但是应该对自己可见 14 | return true 15 | } 16 | return false 17 | } else if txnState.ReadOnly { 18 | return version < txnState.Version 19 | } else { 20 | return version <= txnState.Version 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /storage/mvcckey.go: -------------------------------------------------------------------------------- 1 | package storage 2 | 3 | import ( 4 | "bytes" 5 | "cabbageDB/util" 6 | "encoding/gob" 7 | ) 8 | 9 | type Key interface { 10 | MVCCEncode() []byte 11 | } 12 | 13 | type Version uint64 14 | 15 | func DecodeKey(key []byte) Key { 16 | if len(key) == 0 { 17 | return nil 18 | } 19 | 20 | if key[0] != MVCCKeyPrefix { 21 | return nil 22 | } 23 | 24 | switch key[1] { 25 | case NextVersionPrefix: 26 | return &NextVersion{} 27 | case TxnActivePrefix: 28 | txnActive := TxnActive{} 29 | var versionNum uint64 30 | util.ByteToInt(key[2:], &versionNum) 31 | txnActive.Version = Version(versionNum) 32 | return &txnActive 33 | case TxnActiveSnapshotPrefix: 34 | txnActiveSnap := TxnActiveSnapshot{} 35 | var versionNum uint64 36 | util.ByteToInt(key[2:], &versionNum) 37 | txnActiveSnap.Version = Version(versionNum) 38 | return &txnActiveSnap 39 | case TxnWritePrefix: 40 | txnWrite := TxnWrite{} 41 | var versionNum uint64 42 | util.ByteToInt(key[2:10], &versionNum) 43 | txnWrite.Version = Version(versionNum) 44 | txnWrite.Key = key[10:] 45 | return &txnWrite 46 | case VersionedPrefix: 47 | version := Versioned{} 48 | version.KeyPrefix = [2]byte{key[2], key[3]} 49 | var versionNum uint64 50 | util.ByteToInt(key[4:12], &versionNum) 51 | version.Version = Version(versionNum) 52 | version.Key = key[12:] 53 | return &version 54 | case UnVersionedPrefix: 55 | unVersion := UnVersioned{} 56 | unVersion.Key = key[2:] 57 | return &unVersion 58 | } 59 | 60 | return nil 61 
| } 62 | 63 | const ( 64 | MVCCKeyPrefix byte = 0x03 65 | NextVersionPrefix byte = 0x02 66 | TxnActivePrefix byte = 0x03 67 | TxnActiveSnapshotPrefix byte = 0x04 68 | TxnWritePrefix byte = 0x05 69 | VersionedPrefix byte = 0x06 70 | UnVersionedPrefix byte = 0x07 71 | ) 72 | 73 | type NextVersion struct{} 74 | 75 | func (n *NextVersion) MVCCEncode() []byte { 76 | return []byte{MVCCKeyPrefix, NextVersionPrefix} 77 | } 78 | 79 | type TxnActive struct { 80 | Version Version 81 | } 82 | 83 | func (t *TxnActive) MVCCEncode() []byte { 84 | return append([]byte{MVCCKeyPrefix, TxnActivePrefix}, util.BinaryToByte(uint64(t.Version))...) 85 | } 86 | 87 | type TxnActiveSnapshot struct { 88 | Version Version 89 | } 90 | 91 | func (t *TxnActiveSnapshot) MVCCEncode() []byte { 92 | return append([]byte{MVCCKeyPrefix, TxnActiveSnapshotPrefix}, util.BinaryToByte(uint64(t.Version))...) 93 | } 94 | 95 | type TxnWrite struct { 96 | Version Version 97 | Key []byte 98 | } 99 | 100 | func (t *TxnWrite) MVCCEncode() []byte { 101 | versionByte := util.BinaryToByte(uint64(t.Version)) 102 | valueByte := append(versionByte, t.Key...) 103 | return append([]byte{MVCCKeyPrefix, TxnWritePrefix}, valueByte...) 104 | } 105 | 106 | type Versioned struct { 107 | KeyPrefix [2]byte 108 | Version Version 109 | Key []byte 110 | } 111 | 112 | // 这里自定义编码 避免因为长度问题,导致遍历失败 113 | func (v *Versioned) MVCCEncode() []byte { 114 | byte1 := append([]byte{MVCCKeyPrefix, VersionedPrefix}, v.KeyPrefix[:]...) 115 | byte1 = append(byte1, util.BinaryToByte(uint64(v.Version))...) 116 | byte1 = append(byte1, v.Key...) 117 | return byte1 118 | } 119 | 120 | type UnVersioned struct { 121 | Key []byte 122 | } 123 | 124 | func (u *UnVersioned) MVCCEncode() []byte { 125 | return append([]byte{MVCCKeyPrefix, UnVersionedPrefix}, u.Key...) 
126 | } 127 | 128 | func GobReg() { 129 | 130 | gob.Register(&Versioned{}) 131 | gob.Register(&UnVersioned{}) 132 | gob.Register(Version(0)) 133 | gob.Register(&NextVersion{}) 134 | gob.Register(&TxnActive{}) 135 | gob.Register(&TxnActiveSnapshot{}) 136 | gob.Register(&TxnWrite{}) 137 | gob.Register(&TransactionState{}) 138 | gob.Register(&VersionHashSet{}) 139 | 140 | // 预先编码保证顺序不变 141 | var buf bytes.Buffer 142 | enc := gob.NewEncoder(&buf) 143 | _ = enc.Encode(&Versioned{}) 144 | _ = enc.Encode(&UnVersioned{}) 145 | _ = enc.Encode(Version(0)) 146 | _ = enc.Encode(&NextVersion{}) 147 | _ = enc.Encode(&TxnActive{}) 148 | _ = enc.Encode(&TxnActiveSnapshot{}) 149 | _ = enc.Encode(&TxnWrite{}) 150 | _ = enc.Encode(&TransactionState{}) 151 | _ = enc.Encode(&VersionHashSet{}) 152 | 153 | } 154 | -------------------------------------------------------------------------------- /util/conn.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "net" 5 | ) 6 | 7 | func SendPrefixMsg(conn net.Conn, prefix [2]byte, req []byte) error { 8 | reqLen := len(req) 9 | reqLenByte := BinaryToByte(uint64(reqLen)) 10 | reqByte := append(prefix[:], reqLenByte...) 11 | reqByte = append(reqByte, req...) 
12 | _, err := conn.Write(reqByte) 13 | return err 14 | } 15 | 16 | func ReceiveMsg(conn net.Conn) []byte { 17 | var msgLen uint64 18 | tmp := [8]byte{} 19 | conn.Read(tmp[:]) 20 | ByteToInt(tmp[:], &msgLen) 21 | 22 | cnt := 0 23 | msgByte := make([]byte, int(msgLen)) 24 | for cnt < int(msgLen) { 25 | n, _ := conn.Read(msgByte) 26 | cnt += n 27 | } 28 | return msgByte 29 | } 30 | -------------------------------------------------------------------------------- /util/util.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "bytes" 5 | "cabbageDB/logger" 6 | "encoding/binary" 7 | "encoding/gob" 8 | "reflect" 9 | ) 10 | 11 | func BinaryToByte[T ~uint32 | int32 | uint64 | uint8 | int64](value T) []byte { 12 | var buffer bytes.Buffer 13 | err := binary.Write(&buffer, binary.BigEndian, value) 14 | if err != nil { 15 | logger.Info("err===>", err.Error()) 16 | } 17 | return buffer.Bytes() 18 | } 19 | 20 | func ByteToInt[T ~uint32 | int32 | uint64 | uint8 | int64](buf []byte, value *T) error { 21 | err := binary.Read(bytes.NewReader(buf), binary.BigEndian, value) 22 | return err 23 | } 24 | 25 | func BufferAppend(args ...[]byte) []byte { 26 | var buffer bytes.Buffer 27 | for _, v := range args { 28 | buffer.Write(v) 29 | } 30 | return buffer.Bytes() 31 | } 32 | 33 | func BoolToByte(b bool) byte { 34 | if b { 35 | return 0x01 36 | } else { 37 | return 0x00 38 | } 39 | } 40 | func ByteIterToStruct[T any](value []byte, toStruct *T) { 41 | if len(value) == 0 { 42 | return 43 | } 44 | decoder := gob.NewDecoder(bytes.NewBuffer(value)) 45 | err := decoder.Decode(toStruct) 46 | if err != nil { 47 | logger.Info("Byte to Struct "+reflect.TypeOf(toStruct).String()+"err:", err.Error()) 48 | } 49 | 50 | } 51 | 52 | func ByteToStruct[T any](value []byte, toStruct *T) { 53 | if len(value) == 0 { 54 | return 55 | } 56 | decoder := gob.NewDecoder(bytes.NewBuffer(value)) 57 | err := decoder.Decode(&toStruct) 58 | if err != 
nil { 59 | logger.Info("Byte to Struct "+reflect.TypeOf(toStruct).String()+"err:", err.Error()) 60 | } 61 | 62 | } 63 | 64 | func BinaryStructToByte[T any](v *T) []byte { 65 | var buffer bytes.Buffer 66 | encoder := gob.NewEncoder(&buffer) 67 | err := encoder.Encode(v) 68 | if err != nil { 69 | 70 | logger.Info("Struct to Byte err:", err.Error()) 71 | return nil 72 | } 73 | 74 | return buffer.Bytes() 75 | } 76 | --------------------------------------------------------------------------------