├── LICENSE ├── README.md ├── btree ├── btree.go └── btree_test.go ├── btree_iter ├── btree_iter.go └── btree_iter_test.go ├── freelist ├── freelist.go └── freelist_test.go ├── go.mod ├── go.sum ├── kv ├── kv.go └── kv_test.go ├── ql ├── ql_exec.go └── ql_parse.go ├── table ├── table.go └── table_test.go └── transactions ├── tx.go └── tx_test.go /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Aditya Kumar Singh 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | ![9ad3c058-d029-47c4-a2df-68cbcbf41e0a](https://github.com/user-attachments/assets/4b77a097-ea68-4e95-85f0-639785b93c7b) 3 | # AdiDB - Building a Database from Scratch 4 | 5 | ## Introduction 6 | 7 | Hey there! 👋 This is my journey of building a database from scratch using Go. It started as a way to understand how databases actually work under the hood, and it quickly turned into a fascinating deep-dive into database internals. 8 | 9 | ## Features I've Implemented 10 | 11 | ### ACID Transactions 12 | After many late nights and countless cups of coffee, I've implemented full ACID compliance: 13 | - **Atomicity**: All or nothing - transactions either fully complete or fully roll back 14 | - **Consistency**: The database remains valid after every transaction 15 | - **Isolation**: Concurrent transactions don't step on each other's toes 16 | - **Durability**: Once committed, data stays committed (yes, even if your machine crashes!) 17 | 18 | ### Concurrent Transaction Handling 19 | This was probably the trickiest part! The engine supports: 20 | - Multiple transactions running simultaneously 21 | - Deadlock detection 22 | - Lock management to prevent dirty reads/writes 23 | - Transaction isolation levels 24 | 25 | ### B+ Trees for Indexing 26 | I chose B+ Trees because they're practically the industry standard for database indexes. 
My implementation includes: 27 | - Efficient range queries 28 | - Auto-balancing on insert/delete 29 | - Disk-friendly node structure 30 | - Configurable node size 31 | 32 | ### SQL-like Query Language 33 | While it's not full SQL (yet!), the query language supports basic operations: 34 | ```sql 35 | CREATE TABLE users ( 36 | id INT PRIMARY KEY, 37 | name VARCHAR(50), 38 | age INT 39 | ) 40 | 41 | INSERT INTO users (name, age) VALUES ("Alice", 30) 42 | 43 | UPDATE users SET age = 31 WHERE id = 1 44 | 45 | SELECT name, age FROM users 46 | ``` 47 | 48 | ## Project Status 49 | 50 | This is very much a learning project and a work in progress. While it works, there's still a lot I want to add: 51 | - [x] Basic CRUD operations 52 | - [x] ACID transactions 53 | - [x] B+ Tree indexes 54 | - [x] Table creation and schema management 55 | - [x] Simple query parser 56 | - [ ] Implement Raft Algorithm 57 | - [ ] Make it a proper distributed database 58 | 59 | ## Why Build This? 60 | 61 | I believe the best way to truly understand how something works is to build it from scratch. While I wouldn't recommend using this in production, building this has taught me more about databases than using them. 62 | 63 | ## Learning Resources 64 | 65 | If you're interested in building your own database, here are some resources I found incredibly helpful: 66 | - [Database Internals](https://www.databass.dev) book 67 | - [Build Your Own](https://build-your-own.org/) book 68 | 69 | ## Contributing 70 | 71 | Feel free to open issues or PRs if you spot bugs or have suggestions! While this is primarily a learning project, I'm always happy to collaborate and learn from others. 72 | 73 | ## License 74 | 75 | MIT 76 | 77 | --- 78 | 79 | *Built with Go and lots of debugging sessions* 🔍 80 | -------------------------------------------------------------------------------- /btree/btree.go: -------------------------------------------------------------------------------- 1 | package btree 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "errors" 7 | ) 8 | 9 | const HEADER = 4 10 | const BTREE_PAGE_SIZE = 4096 11 | const BTREE_MAX_KEY_SIZE = 1000 12 | const BTREE_MAX_VAL_SIZE = 3000 13 | 14 | func assert(cond bool) { 15 | if !cond { 16 | panic("assertion failure") 17 | } 18 | } 19 | 20 | func init() { 21 | node1max := HEADER + 8 + 2 + 4 + BTREE_MAX_KEY_SIZE + BTREE_MAX_VAL_SIZE 22 | assert(node1max <= BTREE_PAGE_SIZE) 23 | } 24 | 25 | // in-memory data type 26 | type BNode []byte 27 | 28 | type BTree struct { 29 | root uint64 30 | get func(uint64) []byte // dereference a pointer -- reads a page from disk 31 | new func([]byte) uint64 //allocates & writes a new page 32 | del func(uint64) //deallocate a page 33 | } 34 | 35 | const ( 36 | BNODE_NODE = 1 //internal nodes without values 37 | BNODE_LEAF = 2 //leaf nodes with values 38 | ) 39 | 40 | func (node BNode) btype() uint16 { 41 | return binary.LittleEndian.Uint16(node[0:2]) 42 | } 43 | 44 | func (node BNode) nkeys() uint16 { 45 | return binary.LittleEndian.Uint16(node[2:4]) 46 | } 47 | 48 | func (node BNode) setHeader(btype uint16, nkeys uint16) { 49 | binary.LittleEndian.PutUint16(node[0:2], btype) 50 | binary.LittleEndian.PutUint16(node[2:4], nkeys) 51 | } 52 | 53 | // pointers 54 | func (node BNode) getPtr(idx uint16) uint64 { 55 | assert(idx < node.nkeys()) 56 | pos := HEADER + 8*idx 57 | return binary.LittleEndian.Uint64(node[pos:]) 58 | } 59 | 60 | func (node BNode) setPtr(idx uint16, val uint64) { 61 | assert(idx < node.nkeys()) 62 | pos := HEADER + 8*idx 63 | 
binary.LittleEndian.PutUint64(node[pos:], val) 64 | } 65 | 66 | // offset list 67 | func offsetPos(node BNode, idx uint16) uint16 { 68 | assert(1 <= idx && idx <= node.nkeys()) 69 | 70 | return HEADER + 8*node.nkeys() + 2*(idx-1) 71 | } 72 | 73 | func (node BNode) getOffset(idx uint16) uint16 { 74 | if idx == 0 { 75 | return 0 76 | } 77 | 78 | return binary.LittleEndian.Uint16(node[offsetPos(node, idx):]) 79 | } 80 | 81 | func (node BNode) setOffset(idx uint16, offset uint16) { 82 | binary.LittleEndian.PutUint16(node[offsetPos(node, idx):], offset) 83 | } 84 | 85 | func (node BNode) kvPos(idx uint16) uint16 { 86 | assert(idx <= node.nkeys()) 87 | 88 | return HEADER + 8*node.nkeys() + 2*node.nkeys() + node.getOffset(idx) 89 | } 90 | 91 | func (node BNode) getKey(idx uint16) []byte { 92 | assert(idx < node.nkeys()) 93 | pos := node.kvPos(idx) 94 | klen := binary.LittleEndian.Uint16(node[pos:]) 95 | 96 | return node[pos+4:][:klen] 97 | } 98 | func (node BNode) getVal(idx uint16) []byte { 99 | assert(idx < node.nkeys()) 100 | pos := node.kvPos(idx) 101 | klen := binary.LittleEndian.Uint16(node[pos+0:]) 102 | vlen := binary.LittleEndian.Uint16(node[pos+2:]) 103 | return node[pos+4+klen:][:vlen] 104 | } 105 | 106 | func (node BNode) nbytes() uint16 { 107 | return node.kvPos(node.nkeys()) 108 | } 109 | 110 | func nodeLookupLE(node BNode, key []byte) uint16 { 111 | nkeys := node.nkeys() 112 | found := uint16(0) 113 | 114 | for i := uint16(1); i < nkeys; i++ { 115 | cmp := bytes.Compare(node.getKey(i), key) 116 | if cmp <= 0 { 117 | found = i 118 | } 119 | if cmp >= 0 { 120 | break 121 | } 122 | } 123 | return found 124 | } 125 | 126 | // copies a KV pair 127 | func nodeAppendKV(new BNode, idx uint16, ptr uint64, key []byte, val []byte) { 128 | new.setPtr(idx, ptr) 129 | 130 | pos := new.kvPos(idx) 131 | binary.LittleEndian.PutUint16(new[pos+0:], uint16(len(key))) 132 | binary.LittleEndian.PutUint16(new[pos+2:], uint16(len(val))) 133 | copy(new[pos+4:], key) 134 | copy(new[pos+4+uint16(len(key)):], val) 135 | 136 | new.setOffset(idx+1, new.getOffset(idx)+4+uint16((len(key)+len(val)))) 137 | } 138 | 139 | // copies multiple KVs into position from the old node 140 | func nodeAppendRange(new BNode, old BNode, dstNew uint16, srcOld uint16, n uint16) { 141 | assert(srcOld+n <= old.nkeys()) 142 | assert(dstNew+n <= new.nkeys()) 143 | if n == 0 { 144 | return 145 | } 146 | 147 | for i := uint16(0); i < n; i++ { 148 | new.setPtr(dstNew+i, old.getPtr(srcOld+i)) 149 | } 150 | 151 | dstBegin := new.getOffset(dstNew) 152 | srcBegin := old.getOffset(srcOld) 153 | for i := uint16(1); i <= n; i++ { 154 | offset := dstBegin + old.getOffset(srcOld+i) - srcBegin 155 | new.setOffset(dstNew+i, offset) 156 | } 157 | 158 | // kv's 159 | begin := old.kvPos(srcOld) 160 | end := old.kvPos(srcOld + n) 161 | copy(new[new.kvPos(dstNew):], old[begin:end]) 162 | } 163 | 164 | // adding a new key to a leaf node 165 | func leafInsert(new BNode, old BNode, idx uint16, key []byte, val []byte) { 166 | new.setHeader(BNODE_LEAF, old.nkeys()+1) 167 | nodeAppendRange(new, old, 0, 0, idx) 168 | nodeAppendKV(new, idx, 0, key, val) 169 | nodeAppendRange(new, old, idx+1, idx, old.nkeys()-idx) 170 | } 171 | 172 | func leafUpdate(new BNode, old BNode, idx uint16, key []byte, val []byte) { 173 | new.setHeader(BNODE_LEAF, old.nkeys()) 174 | nodeAppendRange(new, old, 0, 0, idx) 175 | nodeAppendKV(new, idx, 0, key, val) 176 | nodeAppendRange(new, old, idx+1, idx+1, old.nkeys()-(idx+1)) 177 | } 178 | 179 | func nodeReplaceKid1ptr(new BNode, old BNode, 
idx uint16, ptr uint64) { 180 | copy(new, old[:old.nbytes()]) 181 | new.setPtr(idx, ptr) // only the pointer is changed 182 | } 183 | 184 | func nodeReplaceKidN(tree *BTree, new BNode, old BNode, idx uint16, kids ...BNode) { 185 | inc := uint16(len(kids)) 186 | if inc == 1 && bytes.Equal(kids[0].getKey(0), old.getKey(idx)) { 187 | nodeReplaceKid1ptr(new, old, idx, tree.new(kids[0])) 188 | return 189 | } 190 | 191 | new.setHeader(BNODE_NODE, old.nkeys()+inc-1) 192 | nodeAppendRange(new, old, 0, 0, idx) 193 | 194 | for i, node := range kids { 195 | nodeAppendKV(new, idx+uint16(i), tree.new(node), node.getKey(0), nil) 196 | } 197 | nodeAppendRange(new, old, idx+inc, idx+1, old.nkeys()-(idx+1)) 198 | } 199 | 200 | // split a big node into 2 201 | func nodeSplit2(left BNode, right BNode, old BNode) { 202 | assert(old.nkeys() >= 2) 203 | 204 | nleft := old.nkeys() / 2 205 | 206 | left_bytes := func() uint16 { 207 | return HEADER + 8*nleft + 2*nleft + old.getOffset(nleft) 208 | } 209 | 210 | for left_bytes() > BTREE_PAGE_SIZE { 211 | nleft-- 212 | } 213 | assert(nleft >= 1) 214 | 215 | right_bytes := func() uint16 { 216 | return old.nbytes() - left_bytes() + HEADER 217 | } 218 | for right_bytes() > BTREE_PAGE_SIZE { 219 | nleft++ 220 | } 221 | assert(nleft < old.nkeys()) 222 | nright := old.nkeys() - nleft 223 | 224 | left.setHeader(old.btype(), nleft) 225 | right.setHeader(old.btype(), nright) 226 | nodeAppendRange(left, old, 0, 0, nleft) 227 | nodeAppendRange(right, old, 0, nleft, nright) 228 | 229 | assert(right.nbytes() <= BTREE_PAGE_SIZE) 230 | } 231 | 232 | // splits an oversized node 233 | func nodeSplit3(old BNode) (uint16, [3]BNode) { 234 | if old.nbytes() <= BTREE_PAGE_SIZE { 235 | old = old[:BTREE_PAGE_SIZE] 236 | return 1, [3]BNode{old} //wont split 237 | } 238 | 239 | left := BNode(make([]byte, 2*BTREE_PAGE_SIZE)) 240 | right := BNode(make([]byte, BTREE_PAGE_SIZE)) 241 | nodeSplit2(left, right, old) 242 | 243 | if left.nbytes() <= BTREE_PAGE_SIZE { 244 | left = left[:BTREE_PAGE_SIZE] 245 | return 2, [3]BNode{left, right} 246 | } 247 | 248 | mostLeft := BNode(make([]byte, BTREE_PAGE_SIZE)) 249 | middle := BNode(make([]byte, BTREE_PAGE_SIZE)) 250 | nodeSplit2(mostLeft, middle, left) 251 | assert(mostLeft.nbytes() <= BTREE_PAGE_SIZE) 252 | 253 | return 3, [3]BNode{mostLeft, middle, right} 254 | } 255 | 256 | const ( 257 | MODE_UPSERT = 0 //insert or replace 258 | MODE_UPDATE_ONLY = 1 // update existing keys 259 | MODE_INSERT_ONLY = 2 //add only new keys 260 | ) 261 | 262 | type UpdateReq struct { 263 | tree *BTree 264 | Added bool // new key 265 | Updated bool 266 | Old []byte //value before update 267 | Key []byte 268 | Val []byte 269 | Mode int 270 | } 271 | 272 | type DeleteReq struct { 273 | tree *BTree 274 | Key []byte 275 | Old []byte 276 | } 277 | 278 | // tree insertion- inserts a KV into a node 279 | func treeInsert(req *UpdateReq, node BNode) BNode { 280 | new := BNode(make([]byte, 2*BTREE_PAGE_SIZE)) 281 | 282 | idx := nodeLookupLE(node, req.Key) 283 | switch node.btype() { 284 | case BNODE_LEAF: 285 | if bytes.Equal(req.Key, node.getKey(idx)) { 286 | // updating the key 287 | leafUpdate(new, node, idx, req.Key, req.Val) 288 | } else { 289 | leafInsert(new, node, idx+1, req.Key, req.Val) 290 | } 291 | 292 | case BNODE_NODE: 293 | nodeInsert(req, new, node, idx) 294 | default: 295 | panic("bad node!") 296 | } 297 | 298 | return new 299 | } 300 | 301 | // KV insertion to an internal node 302 | func nodeInsert(req *UpdateReq, new BNode, node BNode, idx uint16) BNode { 303 | kptr 
:= node.getPtr(idx) 304 | 305 | // insertion to kid node 306 | updated := treeInsert(req, req.tree.get(kptr)) 307 | if len(updated) == 0 { 308 | return BNode{} 309 | } 310 | 311 | nsplit, split := nodeSplit3(updated) 312 | // deallocate kid node 313 | req.tree.del(kptr) 314 | nodeReplaceKidN(req.tree, new, node, idx, split[:nsplit]...) 315 | 316 | return new 317 | } 318 | 319 | func checkLimit(key []byte, val []byte) error { 320 | if len(key) == 0 { 321 | return errors.New("empty key") 322 | } 323 | 324 | if len(key) > BTREE_MAX_KEY_SIZE { 325 | return errors.New("key too long") 326 | } 327 | if len(val) > BTREE_MAX_VAL_SIZE { 328 | return errors.New("value too long") 329 | } 330 | 331 | return nil 332 | } 333 | 334 | func nodeReplace2Kid(new BNode, old BNode, idx uint16, ptr uint64, key []byte) { 335 | new.setHeader(BNODE_NODE, old.nkeys()-1) 336 | nodeAppendRange(new, old, 0, 0, idx) 337 | nodeAppendKV(new, idx, ptr, key, nil) 338 | nodeAppendRange(new, old, idx+1, idx+2, old.nkeys()-(idx+2)) 339 | } 340 | 341 | // tree deletion 342 | // remove key from leaf node 343 | func leafDelete(new BNode, old BNode, idx uint16) { 344 | new.setHeader(BNODE_LEAF, old.nkeys()-1) 345 | nodeAppendRange(new, old, 0, 0, idx) 346 | nodeAppendRange(new, old, idx, idx+1, old.nkeys()-(idx+1)) 347 | } 348 | 349 | // merging 2 nodes into 1 350 | func nodeMerge(new BNode, left BNode, right BNode) { 351 | new.setHeader(left.btype(), left.nkeys()+right.nkeys()) 352 | nodeAppendRange(new, left, 0, 0, left.nkeys()) 353 | nodeAppendRange(new, right, left.nkeys(), 0, right.nkeys()) 354 | assert(new.nbytes() <= BTREE_PAGE_SIZE) 355 | } 356 | 357 | func shouldMerge(tree *BTree, node BNode, idx uint16, updated BNode) (int, BNode) { 358 | if updated.nbytes() > BTREE_PAGE_SIZE/4 { 359 | return 0, BNode{} 360 | } 361 | if idx > 0 { 362 | sibling := BNode(tree.get(node.getPtr(idx - 1))) 363 | merged := sibling.nbytes() + updated.nbytes() - HEADER 364 | if merged <= BTREE_PAGE_SIZE { 365 | return -1, sibling //left 366 | } 367 | } 368 | 369 | if idx+1 < node.nkeys() { 370 | sibling := BNode(tree.get(node.getPtr(idx + 1))) 371 | merged := sibling.nbytes() + updated.nbytes() - HEADER 372 | if merged <= BTREE_PAGE_SIZE { 373 | return +1, sibling // right 374 | } 375 | } 376 | 377 | return 0, BNode{} 378 | } 379 | 380 | func treeDelete(req *DeleteReq, node BNode) BNode { 381 | idx := nodeLookupLE(node, req.Key) 382 | 383 | switch node.btype() { 384 | case BNODE_LEAF: 385 | if !bytes.Equal(req.Key, node.getKey(idx)) { 386 | return BNode{} // not found 387 | } 388 | // delete the key in the leaf 389 | req.Old = node.getVal(idx) 390 | new := BNode(make([]byte, BTREE_PAGE_SIZE)) 391 | leafDelete(new, node, idx) 392 | return new 393 | case BNODE_NODE: 394 | return nodeDelete(req, node, idx) 395 | default: 396 | panic("bad node") 397 | } 398 | } 399 | 400 | func nodeDelete(req *DeleteReq, node BNode, idx uint16) BNode { 401 | tree := req.tree 402 | 403 | kptr := node.getPtr(idx) 404 | updated := treeDelete(req, tree.get(kptr)) 405 | if len(updated) == 0 { 406 | return BNode{} 407 | } 408 | tree.del(kptr) 409 | 410 | new := BNode(make([]byte, BTREE_PAGE_SIZE)) 411 | 412 | mergeDir, sibling := shouldMerge(tree, node, idx, updated) 413 | switch { 414 | case mergeDir < 0: 415 | merged := BNode(make([]byte, BTREE_PAGE_SIZE)) 416 | nodeMerge(merged, sibling, updated) 417 | tree.del(node.getPtr(idx - 1)) 418 | nodeReplace2Kid(new, node, idx-1, tree.new(merged), merged.getKey(0)) 419 | case mergeDir > 0: 420 | merged := BNode(make([]byte, 
BTREE_PAGE_SIZE)) 421 | nodeMerge(merged, updated, sibling) 422 | tree.del(node.getPtr(idx + 1)) 423 | nodeReplace2Kid(new, node, idx, tree.new(merged), merged.getKey(0)) 424 | 425 | case mergeDir == 0 && updated.nkeys() == 0: 426 | assert(node.nkeys() == 1 && idx == 0) 427 | new.setHeader(BNODE_NODE, 0) 428 | case mergeDir == 0 && updated.nkeys() > 0: // no merge 429 | nodeReplaceKidN(tree, new, node, idx, updated) 430 | } 431 | 432 | return new 433 | } 434 | 435 | func (tree *BTree) Upsert(key []byte, val []byte) (bool, error) { 436 | return tree.Update(&UpdateReq{Key: key, Val: val}) 437 | } 438 | 439 | func (tree *BTree) Update(req *UpdateReq) (bool, error) { 440 | if err := checkLimit(req.Key, req.Val); err != nil { 441 | return false, err 442 | } 443 | 444 | if tree.root == 0 { 445 | // create first node 446 | root := BNode(make([]byte, BTREE_PAGE_SIZE)) 447 | root.setHeader(BNODE_LEAF, 2) 448 | 449 | nodeAppendKV(root, 0, 0, nil, nil) 450 | nodeAppendKV(root, 1, 0, req.Key, req.Val) 451 | tree.root = tree.new(root) 452 | req.Added = true 453 | req.Updated = true 454 | return true, nil 455 | } 456 | 457 | req.tree = tree 458 | updated := treeInsert(req, tree.get(tree.root)) 459 | if len(updated) == 0 { 460 | return false, nil 461 | } 462 | 463 | nsplit, split := nodeSplit3(updated) 464 | tree.del(tree.root) 465 | if nsplit > 1 { 466 | root := BNode(make([]byte, BTREE_PAGE_SIZE)) 467 | root.setHeader(BNODE_NODE, nsplit) 468 | for i, knode := range split[:nsplit] { 469 | ptr, key := tree.new(knode), knode.getKey(0) 470 | nodeAppendKV(root, uint16(i), ptr, key, nil) 471 | } 472 | tree.root = tree.new(root) 473 | } else { 474 | tree.root = tree.new(split[0]) 475 | } 476 | 477 | return true, nil 478 | } 479 | 480 | func (tree *BTree) Delete(req *DeleteReq) (bool, error) { 481 | if err := checkLimit(req.Key, nil); err != nil { 482 | return false, err 483 | } 484 | 485 | if tree.root == 0 { 486 | return false, nil 487 | } 488 | 489 | req.tree = tree 490 | updated := treeDelete(req, tree.get(tree.root)) 491 | if len(updated) == 0 { 492 | return false, nil 493 | } 494 | 495 | tree.del(tree.root) 496 | if updated.btype() == BNODE_NODE && updated.nkeys() == 1 { 497 | tree.root = updated.getPtr(0) 498 | } else { 499 | tree.root = tree.new(updated) 500 | } 501 | 502 | return true, nil 503 | } 504 | 505 | func nodeGetKey(tree *BTree, node BNode, key []byte) ([]byte, bool) { 506 | idx := nodeLookupLE(node, key) 507 | switch node.btype() { 508 | case BNODE_LEAF: 509 | if bytes.Equal(key, node.getKey(idx)) { 510 | return node.getVal(idx), true 511 | } else { 512 | return nil, false 513 | } 514 | case BNODE_NODE: 515 | return nodeGetKey(tree, tree.get(node.getPtr(idx)), key) 516 | 517 | default: 518 | panic("bad node") 519 | } 520 | } 521 | 522 | func (tree *BTree) Get(key []byte) ([]byte, bool) { 523 | if tree.root == 0 { 524 | return nil, false 525 | } 526 | 527 | return nodeGetKey(tree, tree.get(tree.root), key) 528 | } 529 | -------------------------------------------------------------------------------- /btree/btree_test.go: -------------------------------------------------------------------------------- 1 | package btree 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "sort" 7 | "testing" 8 | "unsafe" 9 | 10 | is "github.com/stretchr/testify/require" 11 | ) 12 | 13 | type C struct { 14 | tree BTree 15 | ref map[string]string 16 | pages map[uint64]BNode 17 | } 18 | 19 | func newC() *C { 20 | pages := map[uint64]BNode{} 21 | return &C{ 22 | tree: BTree{ 23 | get: func(ptr uint64) []byte { 24 | 
node, ok := pages[ptr] 25 | assert(ok) 26 | return node 27 | }, 28 | new: func(node []byte) uint64 { 29 | assert(BNode(node).nbytes() <= BTREE_PAGE_SIZE) 30 | ptr := uint64(uintptr(unsafe.Pointer(&node[0]))) 31 | assert(pages[ptr] == nil) 32 | pages[ptr] = node 33 | return ptr 34 | }, 35 | del: func(ptr uint64) { 36 | assert(pages[ptr] != nil) 37 | delete(pages, ptr) 38 | }, 39 | }, 40 | ref: map[string]string{}, 41 | pages: pages, 42 | } 43 | } 44 | 45 | func (c *C) add(key string, val string) { 46 | _, err := c.tree.Upsert([]byte(key), []byte(val)) 47 | assert(err == nil) 48 | c.ref[key] = val 49 | } 50 | 51 | func (c *C) del(key string) bool { 52 | delete(c.ref, key) 53 | deleted, err := c.tree.Delete(&DeleteReq{Key: []byte(key)}) 54 | assert(err == nil) 55 | return deleted 56 | } 57 | 58 | func (c *C) dump() ([]string, []string) { 59 | keys := []string{} 60 | vals := []string{} 61 | 62 | var nodeDump func(uint64) 63 | nodeDump = func(ptr uint64) { 64 | node := BNode(c.tree.get(ptr)) 65 | nkeys := node.nkeys() 66 | if node.btype() == BNODE_LEAF { 67 | for i := uint16(0); i < nkeys; i++ { 68 | keys = append(keys, string(node.getKey(i))) 69 | vals = append(vals, string(node.getVal(i))) 70 | } 71 | } else { 72 | for i := uint16(0); i < nkeys; i++ { 73 | ptr := node.getPtr(i) 74 | nodeDump(ptr) 75 | } 76 | } 77 | } 78 | 79 | nodeDump(c.tree.root) 80 | assert(keys[0] == "") 81 | assert(vals[0] == "") 82 | return keys[1:], vals[1:] 83 | } 84 | 85 | type sortIF struct { 86 | len int 87 | less func(i, j int) bool 88 | swap func(i, j int) 89 | } 90 | 91 | func (self sortIF) Len() int { 92 | return self.len 93 | } 94 | func (self sortIF) Less(i, j int) bool { 95 | return self.less(i, j) 96 | } 97 | func (self sortIF) Swap(i, j int) { 98 | self.swap(i, j) 99 | } 100 | 101 | func (c *C) verify(t *testing.T) { 102 | keys, vals := c.dump() 103 | 104 | rkeys, rvals := []string{}, []string{} 105 | for k, v := range c.ref { 106 | rkeys = append(rkeys, k) 107 | rvals = append(rvals, v) 108 | } 109 | is.Equal(t, len(rkeys), len(keys)) 110 | sort.Stable(sortIF{ 111 | len: len(rkeys), 112 | less: func(i, j int) bool { return rkeys[i] < rkeys[j] }, 113 | swap: func(i, j int) { 114 | k, v := rkeys[i], rvals[i] 115 | rkeys[i], rvals[i] = rkeys[j], rvals[j] 116 | rkeys[j], rvals[j] = k, v 117 | }, 118 | }) 119 | 120 | is.Equal(t, rkeys, keys) 121 | is.Equal(t, rvals, vals) 122 | 123 | var nodeVerify func(BNode) 124 | nodeVerify = func(node BNode) { 125 | nkeys := node.nkeys() 126 | assert(nkeys >= 1) 127 | if node.btype() == BNODE_LEAF { 128 | return 129 | } 130 | for i := uint16(0); i < nkeys; i++ { 131 | key := node.getKey(i) 132 | kid := BNode(c.tree.get(node.getPtr(i))) 133 | is.Equal(t, key, kid.getKey(0)) 134 | nodeVerify(kid) 135 | } 136 | } 137 | 138 | nodeVerify(c.tree.get(c.tree.root)) 139 | } 140 | 141 | func fmix32(h uint32) uint32 { 142 | h ^= h >> 16 143 | h *= 0x85ebca6b 144 | h ^= h >> 13 145 | h *= 0xc2b2ae35 146 | h ^= h >> 16 147 | return h 148 | } 149 | 150 | func commonTestBasic(t *testing.T, hasher func(uint32) uint32) { 151 | c := newC() 152 | c.add("k", "v") 153 | c.verify(t) 154 | 155 | // insert 156 | for i := 0; i < 250000; i++ { 157 | key := fmt.Sprintf("key%d", hasher(uint32(i))) 158 | val := fmt.Sprintf("vvv%d", hasher(uint32(-i))) 159 | c.add(key, val) 160 | if i < 2000 { 161 | c.verify(t) 162 | } 163 | } 164 | c.verify(t) 165 | 166 | // del 167 | for i := 2000; i < 250000; i++ { 168 | key := fmt.Sprintf("key%d", hasher(uint32(i))) 169 | is.True(t, c.del(key)) 170 | } 171 | c.verify(t) 172 | 173 | // 
overwrite 174 | for i := 0; i < 2000; i++ { 175 | key := fmt.Sprintf("key%d", hasher(uint32(i))) 176 | val := fmt.Sprintf("vvv%d", hasher(uint32(+i))) 177 | c.add(key, val) 178 | c.verify(t) 179 | } 180 | 181 | is.False(t, c.del("kk")) 182 | 183 | for i := 0; i < 2000; i++ { 184 | key := fmt.Sprintf("key%d", hasher(uint32(i))) 185 | is.True(t, c.del(key)) 186 | c.verify(t) 187 | } 188 | 189 | c.add("k", "v2") 190 | c.verify(t) 191 | c.del("k") 192 | c.verify(t) 193 | 194 | // the dummy empty key 195 | is.Equal(t, 1, len(c.pages)) 196 | is.Equal(t, uint16(1), BNode(c.tree.get(c.tree.root)).nkeys()) 197 | } 198 | 199 | func TestBTreeBasicAscending(t *testing.T) { 200 | commonTestBasic(t, func(h uint32) uint32 { return +h }) 201 | } 202 | 203 | func TestBTreeBasicDescending(t *testing.T) { 204 | commonTestBasic(t, func(h uint32) uint32 { return -h }) 205 | } 206 | 207 | func TestBTreeBasicRand(t *testing.T) { 208 | commonTestBasic(t, fmix32) 209 | } 210 | 211 | func TestBTreeRandLength(t *testing.T) { 212 | c := newC() 213 | for i := 0; i < 2000; i++ { 214 | klen := fmix32(uint32(2*i+0)) % BTREE_MAX_KEY_SIZE 215 | vlen := fmix32(uint32(2*i+1)) % BTREE_MAX_VAL_SIZE 216 | if klen == 0 { 217 | continue 218 | } 219 | 220 | key := make([]byte, klen) 221 | rand.Read(key) 222 | val := make([]byte, vlen) 223 | // rand.Read(val) 224 | c.add(string(key), string(val)) 225 | c.verify(t) 226 | } 227 | } 228 | 229 | func TestBTreeIncLength(t *testing.T) { 230 | for l := 1; l < BTREE_MAX_KEY_SIZE+BTREE_MAX_VAL_SIZE; l++ { 231 | c := newC() 232 | 233 | klen := l 234 | if klen > BTREE_MAX_KEY_SIZE { 235 | klen = BTREE_MAX_KEY_SIZE 236 | } 237 | vlen := l - klen 238 | key := make([]byte, klen) 239 | val := make([]byte, vlen) 240 | 241 | factor := BTREE_PAGE_SIZE / l 242 | size := factor * factor * 2 243 | if size > 4000 { 244 | size = 4000 245 | } 246 | if size < 10 { 247 | size = 10 248 | } 249 | for i := 0; i < size; i++ { 250 | rand.Read(key) 251 | c.add(string(key), string(val)) 252 | } 253 | c.verify(t) 254 | } 255 | } 256 | -------------------------------------------------------------------------------- /btree_iter/btree_iter.go: -------------------------------------------------------------------------------- 1 | package btree_iter 2 | 3 | import ( 4 | "bytes" 5 | "github.com/Adit0507/AdiDB/btree" 6 | ) 7 | 8 | type BIter struct { 9 | tree *btree.BTree 10 | path []btree.BNode 11 | pos []uint16 12 | } 13 | 14 | // movin backward & forward 15 | func (iter *BIter) Next() { 16 | iterNext(iter, len(iter.path)-1) 17 | } 18 | 19 | func iterIsFirst(iter *BIter) bool { 20 | for _, pos := range iter.pos { 21 | if pos != 0 { 22 | return false 23 | } 24 | } 25 | return true 26 | } 27 | 28 | func iterPrev(iter *BIter, level int) { 29 | if iter.pos[level] > 0 { 30 | iter.pos[level]-- //move within node 31 | } else if level > 0 { 32 | iterPrev(iter, level-1) //move to sibling noe 33 | } else { 34 | panic("unreachable") 35 | } 36 | 37 | if level+1 < len(iter.pos) { 38 | node := iter.path[level] 39 | kid := BNode(iter.tree.get(node.getPtr(iter.pos[level]))) 40 | iter.path[level+1] = kid 41 | iter.pos[level+1] = kid.nkeys() - 1 42 | } 43 | } 44 | 45 | func (iter *BIter) Prev() { 46 | if !iterIsFirst(iter) { 47 | iterPrev(iter, len(iter.path)-1) 48 | } 49 | } 50 | 51 | func iterNext(iter *BIter, level int) { 52 | if iter.pos[level]+1 < iter.path[level].nkeys() { 53 | iter.pos[level]++ //move within node 54 | } else if level > 0 { 55 | iterNext(iter, level-1) //move to sibling node 56 | } else { 57 | 
iter.pos[len(iter.pos)-1]++ //past last key 58 | return 59 | } 60 | if level+1 < len(iter.pos) { //update child node 61 | node := iter.path[level] 62 | kid := btree.BNode(iter.tree.get(node.getPtr(iter.pos[level]))) 63 | iter.path[level+1] = kid 64 | iter.pos[level+1] = 0 65 | } 66 | } 67 | 68 | // find the closest position that is less than or equal to the input key 69 | func (tree BTreeWrap) SeekLE(key []byte) *BIter { 70 | iter := &BIter{tree: tree.BTree} 71 | for ptr := tree.root; ptr != 0; { 72 | node := (tree.get(ptr)) 73 | idx := nodeLookupLE(node, key) 74 | iter.path = append(iter.path, node) 75 | iter.pos = append(iter.pos, idx) 76 | ptr = node.getPtr(idx) 77 | } 78 | return iter 79 | } 80 | 81 | func assert(cond bool) { 82 | if !cond { 83 | panic("assertion failure") 84 | } 85 | } 86 | 87 | // get current KV pair 88 | func (iter *BIter) Deref() ([]byte, []byte) { 89 | assert(iter.Valid()) 90 | last := len(iter.path) - 1 91 | node := iter.path[last] 92 | pos := iter.pos[last] 93 | 94 | return node.getKey(pos), node.getVal(pos) 95 | } 96 | 97 | func iterIsEnd(iter *BIter) bool { 98 | last := len(iter.path) - 1 99 | return last < 0 || iter.pos[last] >= iter.path[last].nkeys() 100 | } 101 | 102 | // precondition of Deref() 103 | func (iter *BIter) Valid() bool { 104 | return !(iterIsFirst(iter) || iterIsEnd(iter)) 105 | } 106 | 107 | const ( 108 | CMP_GE = +3 // >= 109 | CMP_GT = +2 // > 110 | CMP_LT = -2 // < 111 | CMP_LE = -3 // <= 112 | ) 113 | 114 | func cmpOk(key []byte, cmp int, ref []byte) bool { 115 | r := bytes.Compare(key, ref) 116 | 117 | switch cmp { 118 | case CMP_GE: 119 | return r >= 0 120 | 121 | case CMP_GT: 122 | return r > 0 123 | case CMP_LE: 124 | return r <= 0 125 | case CMP_LT: 126 | return r < 0 127 | default: 128 | panic("bad cmp") 129 | } 130 | } 131 | 132 | type BTreeWrap struct { 133 | *btree.BTree 134 | } 135 | 136 | func (tree BTreeWrap) Seek(key []byte, cmp int) *BIter { 137 | iter := tree.SeekLE(key) 138 | assert(iterIsFirst(iter) || !iterIsEnd(iter)) 139 | if cmp != CMP_LE { 140 | cur := []byte(nil) 141 | if !iterIsFirst(iter) { 142 | cur, _ = iter.Deref() 143 | } 144 | 145 | if len(key) == 0 || !cmpOk(cur, cmp, key) { 146 | if cmp > 0 { 147 | iter.Next() 148 | } else { 149 | iter.Prev() 150 | } 151 | } 152 | } 153 | 154 | if iter.Valid() { 155 | cur, _ := iter.Deref() 156 | assert(cmpOk(cur, cmp, key)) 157 | } 158 | 159 | return iter 160 | } 161 | -------------------------------------------------------------------------------- /btree_iter/btree_iter_test.go: -------------------------------------------------------------------------------- 1 | package btree_iter 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | is "github.com/stretchr/testify/require" 8 | ) 9 | 10 | func TestBTreeIter(t *testing.T) { 11 | { 12 | c := newC() 13 | iter := c.tree.SeekLE(nil) 14 | is.False(t, iter.Valid()) 15 | } 16 | 17 | sizes := []int{5, 2500} 18 | for _, sz := range sizes { 19 | c := newC() 20 | 21 | for i := 0; i < sz; i++ { 22 | key := fmt.Sprintf("key%010d", i) 23 | val := fmt.Sprintf("vvv%d", fmix32(uint32(-i))) 24 | c.add(key, val) 25 | } 26 | c.verify(t) 27 | 28 | prevk, prevv := []byte(nil), []byte(nil) 29 | for i := 0; i < sz; i++ { 30 | key := []byte(fmt.Sprintf("key%010d", i)) 31 | val := []byte(fmt.Sprintf("vvv%d", fmix32(uint32(-i)))) 32 | // fmt.Println(i, string(key), val) 33 | 34 | iter := c.tree.SeekLE(key) 35 | is.True(t, iter.Valid()) 36 | gotk, gotv := iter.Deref() 37 | is.Equal(t, key, gotk) 38 | is.Equal(t, val, gotv) 39 | 40 | iter.Prev() 41 | if i > 0 { 
is.True(t, iter.Valid()) 43 | gotk, gotv := iter.Deref() 44 | is.Equal(t, prevk, gotk) 45 | is.Equal(t, prevv, gotv) 46 | } else { 47 | is.False(t, iter.Valid()) 48 | } 49 | 50 | iter.Next() 51 | { 52 | is.True(t, iter.Valid()) 53 | gotk, gotv := iter.Deref() 54 | is.Equal(t, key, gotk) 55 | is.Equal(t, val, gotv) 56 | } 57 | 58 | if i+1 == sz { 59 | iter.Next() 60 | is.False(t, iter.Valid()) 61 | } 62 | 63 | prevk, prevv = key, val 64 | } 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /freelist/freelist.go: -------------------------------------------------------------------------------- 1 | package freelist 2 | 3 | import ( 4 | "encoding/binary" 5 | "github.com/Adit0507/AdiDB/btree" 6 | ) 7 | 8 | // node format: 9 | // |next| pointers | unused 10 | type LNode []byte 11 | 12 | const FREE_LIST_HEADER = 8 13 | const FREE_LIST_CAP = (btree.BTREE_PAGE_SIZE - FREE_LIST_HEADER) / 8 14 | 15 | func assert(cond bool) { 16 | if !cond { 17 | panic("assertion failure") 18 | } 19 | } 20 | 21 | 22 | // settin& gettin 23 | func (node LNode) getNext() uint64 { 24 | return binary.LittleEndian.Uint64(node[0:8]) 25 | } 26 | func (node LNode) setNext(next uint64) { 27 | binary.LittleEndian.PutUint64(node[0:8], next) 28 | } 29 | func (node LNode) getPtr(idx int) (uint64, uint64) { 30 | offset := FREE_LIST_HEADER + 16*idx 31 | return binary.LittleEndian.Uint64(node[offset:]), binary.LittleEndian.Uint64(node[offset+8:]) 32 | } 33 | func (node LNode) setPtr(idx int, ptr uint64, version uint64) { 34 | assert(idx < FREE_LIST_CAP) 35 | offset := FREE_LIST_HEADER + 16*idx 36 | binary.LittleEndian.PutUint64(node[offset+0:], ptr) 37 | binary.LittleEndian.PutUint64(node[offset+8:], version) 38 | } 39 | 40 | type FreeList struct { 41 | // read a page 42 | get func(uint64) []byte 43 | // updating an existing page 44 | set func(uint64) []byte 45 | // append a new page 46 | new func([]byte) uint64 47 | 48 | // pointer to head node 49 | headPage uint64 50 | // seq. no. to index into list head 51 | headSeq uint64 52 | tailPage uint64 53 | tailSeq uint64 54 | 55 | // in-memory states 56 | maxSeq uint64 // saved tailSeq to prevnt consuming newly added items 57 | maxVer uint64 //oldest reader version 58 | curVer uint64 //version no. 
when commiting 59 | } 60 | 61 | func seq2idx(seq uint64) int { 62 | return int(seq % FREE_LIST_CAP) 63 | } 64 | 65 | func versionBefore(a, b uint64) bool { 66 | return a-b > 1<<63 67 | } 68 | 69 | // makin newly added items available for consumption 70 | func (fl *FreeList) SetMaxVer(maxVer uint64) { 71 | fl.maxSeq = fl.tailSeq 72 | fl.maxVer = maxVer 73 | } 74 | 75 | // get 1 item form list head 76 | func (fl *FreeList) PopHead() uint64 { 77 | ptr, head := flPop(fl) 78 | if head != 0 { 79 | fl.PushTail(head) 80 | } 81 | 82 | return ptr 83 | } 84 | 85 | func (fl *FreeList) check() { 86 | assert(fl.headPage != 0 && fl.tailPage != 0) 87 | assert(fl.headSeq != fl.tailSeq || fl.headPage == fl.tailPage) 88 | } 89 | 90 | func (fl *FreeList) PushTail(ptr uint64) { 91 | fl.check() 92 | // addin to tail node 93 | LNode(fl.set(fl.tailPage)).setPtr(seq2idx(fl.tailSeq), ptr, fl.curVer) 94 | fl.tailSeq++ 95 | 96 | if seq2idx(fl.tailSeq) == 0 { 97 | next, head := flPop(fl) 98 | if next == 0 { 99 | // allocate new node by appending 100 | next = fl.new(make([]byte, btree.BTREE_PAGE_SIZE)) 101 | } 102 | 103 | // link to new tail node 104 | LNode(fl.set(fl.tailPage)).setNext(next) 105 | fl.tailPage = next 106 | 107 | // add head node if its removed 108 | if head != 0 { 109 | LNode(fl.set(fl.tailPage)).setPtr(0, head, fl.curVer) 110 | fl.tailSeq++ 111 | } 112 | } 113 | } 114 | 115 | // rmeove 1 item from head node & remove head node if empty 116 | func flPop(fl *FreeList) (ptr uint64, head uint64) { 117 | fl.check() 118 | if fl.headSeq == fl.maxSeq { 119 | return 0, 0 120 | } 121 | 122 | node := LNode(fl.get(fl.headPage)) 123 | ptr, version := node.getPtr(seq2idx(fl.headSeq)) 124 | if versionBefore(fl.maxVer, version) { 125 | return 0, 0 126 | } 127 | fl.headSeq++ 128 | 129 | if seq2idx(fl.headSeq) == 0 { 130 | head, fl.headPage = fl.headPage, node.getNext() 131 | assert(fl.headPage != 0) 132 | } 133 | 134 | return 135 | } 136 | -------------------------------------------------------------------------------- /freelist/freelist_test.go: -------------------------------------------------------------------------------- 1 | package freelist 2 | 3 | import ( 4 | "slices" 5 | "testing" 6 | "github.com/Adit0507/AdiDB/btree" 7 | ) 8 | 9 | type L struct { 10 | free FreeList 11 | pages map[uint64][]byte // simulate disk pages 12 | added []uint64 13 | removed []uint64 14 | } 15 | 16 | func newL() *L { 17 | pages := map[uint64][]byte{} 18 | pages[1] = make([]byte, btree.BTREE_PAGE_SIZE) 19 | append := uint64(1000) 20 | return &L{ 21 | free: FreeList{ 22 | get: func(ptr uint64) []byte { 23 | assert(pages[ptr] != nil) 24 | return pages[ptr] 25 | }, 26 | set: func(ptr uint64) []byte { 27 | assert(pages[ptr] != nil) 28 | return pages[ptr] 29 | }, 30 | new: func(node []byte) uint64 { 31 | assert(pages[append] == nil) 32 | pages[append] = node 33 | append++ 34 | return append - 1 35 | }, 36 | headPage: 1, // initial node 37 | tailPage: 1, 38 | }, 39 | pages: pages, 40 | } 41 | } 42 | 43 | // returns the content and the list nodes 44 | func flDump(free *FreeList) (list []uint64, nodes []uint64) { 45 | ptr := free.headPage 46 | nodes = append(nodes, ptr) 47 | for seq := free.headSeq; seq != free.tailSeq; { 48 | assert(ptr != 0) 49 | node := LNode(free.get(ptr)) 50 | item, _ := node.getPtr(seq2idx(seq)) 51 | list = append(list, item) 52 | seq++ 53 | if seq2idx(seq) == 0 { 54 | ptr = node.getNext() 55 | nodes = append(nodes, ptr) 56 | } 57 | } 58 | return 59 | } 60 | 61 | func (l *L) push(ptr uint64) { 62 | assert(l.pages[ptr] == 
nil) 63 | l.pages[ptr] = make([]byte, btree.BTREE_PAGE_SIZE) 64 | l.free.PushTail(ptr) 65 | l.added = append(l.added, ptr) 66 | } 67 | 68 | func (l *L) pop() uint64 { 69 | ptr := l.free.PopHead() 70 | if ptr != 0 { 71 | l.removed = append(l.removed, ptr) 72 | } 73 | return ptr 74 | } 75 | 76 | func (l *L) verify() { 77 | l.free.check() 78 | 79 | // dump all pointers from `l.pages` 80 | appended := []uint64{} 81 | ptrs := []uint64{} 82 | for ptr := range l.pages { 83 | if 1000 <= ptr && ptr < 10000 { 84 | appended = append(appended, ptr) 85 | } else if ptr != 1 { 86 | assert(slices.Contains(l.added, ptr)) 87 | } 88 | ptrs = append(ptrs, ptr) 89 | } 90 | // dump all pointers from the free list 91 | list, nodes := flDump(&l.free) 92 | 93 | // any pointer is either in the free list, a list node, or removed. 94 | assert(len(l.pages) == len(list)+len(nodes)+len(l.removed)) 95 | combined := slices.Concat(list, nodes, l.removed) 96 | slices.Sort(combined) 97 | slices.Sort(ptrs) 98 | assert(slices.Equal(combined, ptrs)) 99 | 100 | // any pointer is either the initial node, an allocated node, or added 101 | assert(len(l.pages) == 1+len(appended)+len(l.added)) 102 | combined = slices.Concat([]uint64{1}, appended, l.added) 103 | slices.Sort(combined) 104 | assert(slices.Equal(combined, ptrs)) 105 | } 106 | 107 | func TestFreeListEmptyFullEmpty(t *testing.T) { 108 | for N := 0; N < 2000; N++ { 109 | l := newL() 110 | for i := 0; i < N; i++ { 111 | l.push(10000 + uint64(i)) 112 | } 113 | l.verify() 114 | 115 | assert(l.pop() == 0) 116 | l.free.SetMaxVer(0) 117 | ptr := l.pop() 118 | for ptr != 0 { 119 | l.free.SetMaxVer(0) 120 | ptr = l.pop() 121 | } 122 | l.verify() 123 | 124 | list, nodes := flDump(&l.free) 125 | assert(len(list) == 0) 126 | assert(len(nodes) == 1) 127 | // println("N", N) 128 | } 129 | } 130 | 131 | func TestFreeListEmptyFullEmpty2(t *testing.T) { 132 | for N := 0; N < 2000; N++ { 133 | l := newL() 134 | for i := 0; i < N; i++ { 135 | l.push(10000 + uint64(i)) 136 | l.free.SetMaxVer(0) // allow self-reuse 137 | } 138 | l.verify() 139 | 140 | ptr := l.pop() 141 | for ptr != 0 { 142 | l.free.SetMaxVer(0) 143 | ptr = l.pop() 144 | } 145 | l.verify() 146 | 147 | list, nodes := flDump(&l.free) 148 | assert(len(list) == 0) 149 | assert(len(nodes) == 1) 150 | // println("N", N) 151 | } 152 | } 153 | 154 | func TestFreeListRandom(t *testing.T) { 155 | for N := 0; N < 1000; N++ { 156 | l := newL() 157 | for i := 0; i < 2000; i++ { 158 | ptr := uint64(10000 + fmix32(uint32(i))) 159 | if ptr%2 == 0 { 160 | l.push(ptr) 161 | } else { 162 | l.pop() 163 | } 164 | } 165 | l.verify() 166 | } 167 | } 168 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/Adit0507/AdiDB 2 | 3 | go 1.22.5 4 | 5 | require ( 6 | github.com/stretchr/testify v1.9.0 7 | golang.org/x/sys v0.26.0 8 | ) 9 | 10 | require ( 11 | github.com/davecgh/go-spew v1.1.1 // indirect 12 | github.com/pmezard/go-difflib v1.0.0 // indirect 13 | gopkg.in/yaml.v3 v3.0.1 // indirect 14 | ) 15 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/pmezard/go-difflib v1.0.0 
h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 4 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 5 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= 6 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 7 | golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= 8 | golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 9 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 10 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 11 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 12 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 13 | -------------------------------------------------------------------------------- /kv/kv.go: -------------------------------------------------------------------------------- 1 | package kv 2 | 3 | // go:build (linux && 386) || (darwin && !cgo) 4 | 5 | import ( 6 | "bytes" 7 | "encoding/binary" 8 | "errors" 9 | "fmt" 10 | "os" 11 | "path" 12 | "sync" 13 | "syscall" 14 | 15 | "github.com/Adit0507/AdiDB/btree" 16 | "github.com/Adit0507/AdiDB/freelist" 17 | "golang.org/x/sys/unix" 18 | ) 19 | 20 | type KV struct { 21 | Path string 22 | Fsync func(int) error // overridable; for testing 23 | // internals 24 | fd int 25 | tree btree.BTree 26 | free freelist.FreeList 27 | mmap struct { 28 | total int // mmap size, can be larger than the file size 29 | chunks [][]byte // multiple mmaps, can be non-continuous 30 | } 31 | page struct { 32 | flushed uint64 // database size in number of pages 33 | nappend uint64 // number of pages to be appended 34 | updates map[uint64][]byte // pending updates, including appended pages 35 | } 36 | failed bool // Did the last update fail? 37 | // concurrency control 38 | mutex sync.Mutex // serialize TX methods 39 | version uint64 // monotonic version number 40 | ongoing []uint64 // version numbers of concurrent TXs 41 | history []CommittedTX // chanages keys; for detecting conflicts 42 | } 43 | 44 | type CommittedTX struct { 45 | version uint64 46 | writes []KeyRange // sorted 47 | } 48 | 49 | // `BTree.get`, read a page. 50 | func (db *KV) pageRead(ptr uint64) []byte { 51 | assert(ptr < db.page.flushed+db.page.nappend) 52 | if node, ok := db.page.updates[ptr]; ok { 53 | return node // pending update 54 | } 55 | return mmapRead(ptr, db.mmap.chunks) 56 | } 57 | 58 | func mmapRead(ptr uint64, chunks [][]byte) []byte { 59 | start := uint64(0) 60 | for _, chunk := range chunks { 61 | end := start + uint64(len(chunk))/BTREE_PAGE_SIZE 62 | if ptr < end { 63 | offset := BTREE_PAGE_SIZE * (ptr - start) 64 | return chunk[offset : offset+BTREE_PAGE_SIZE] 65 | } 66 | start = end 67 | } 68 | panic("bad ptr") 69 | } 70 | 71 | func assert(cond bool) { 72 | if !cond { 73 | panic("assertion failure") 74 | } 75 | } 76 | 77 | // `BTree.new`, allocate a new page. 78 | func (db *KV) pageAlloc(node []byte) uint64 { 79 | assert(len(node) == btree.BTREE_PAGE_SIZE) 80 | if ptr := db.free.PopHead(); ptr != 0 { // try the free list 81 | assert(db.page.updates[ptr] == nil) 82 | db.page.updates[ptr] = node 83 | return ptr 84 | } 85 | return db.pageAppend(node) // append 86 | } 87 | 88 | // `FreeList.new`, append a new page. 
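// Descriptive note: the node is not written to the file at this point. pageAppend hands out the next page number past the flushed and pending pages (`flushed+nappend`), records the data in `page.updates`, and `writePages` later flushes every pending page with pwrite; `pageAlloc` above falls back to this path whenever the free list has nothing to reuse.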
89 | func (db *KV) pageAppend(node []byte) uint64 { 90 | assert(len(node) == btree.BTREE_PAGE_SIZE) 91 | ptr := db.page.flushed + db.page.nappend 92 | db.page.nappend++ 93 | assert(db.page.updates[ptr] == nil) 94 | db.page.updates[ptr] = node 95 | return ptr 96 | } 97 | 98 | // `FreeList.set`, update an existing page. 99 | func (db *KV) pageWrite(ptr uint64) []byte { 100 | assert(ptr < db.page.flushed+db.page.nappend) 101 | if node, ok := db.page.updates[ptr]; ok { 102 | return node // pending update 103 | } 104 | // initialize from the file 105 | node := make([]byte, btree.BTREE_PAGE_SIZE) 106 | if !(ptr == 1 && db.page.flushed == 2) { 107 | // special case: page 1 doesn't exist after creating an empty DB 108 | copy(node, mmapRead(ptr, db.mmap.chunks)) 109 | } 110 | db.page.updates[ptr] = node 111 | return node 112 | } 113 | 114 | // open or create a file and fsync the directory 115 | func createFileSync(file string) (int, error) { 116 | // obtain the directory fd 117 | flags := os.O_RDONLY | syscall.O_DIRECTORY 118 | dirfd, err := syscall.Open(path.Dir(file), flags, 0o644) 119 | if err != nil { 120 | return -1, fmt.Errorf("open directory: %w", err) 121 | } 122 | defer syscall.Close(dirfd) 123 | // open or create the file 124 | flags = os.O_RDWR | os.O_CREATE 125 | fd, err := syscall.Openat(dirfd, path.Base(file), flags, 0o644) 126 | if err != nil { 127 | return -1, fmt.Errorf("open file: %w", err) 128 | } 129 | // fsync the directory 130 | err = syscall.Fsync(dirfd) 131 | if err != nil { // may leave an empty file 132 | _ = syscall.Close(fd) 133 | return -1, fmt.Errorf("fsync directory: %w", err) 134 | } 135 | // done 136 | return fd, nil 137 | } 138 | 139 | // open or create a DB file 140 | func (db *KV) Open() error { 141 | if db.Fsync == nil { 142 | db.Fsync = syscall.Fsync 143 | } 144 | var err error 145 | db.page.updates = map[uint64][]byte{} 146 | // B+tree callbacks 147 | db.tree.get = db.pageRead 148 | db.tree.new = db.pageAlloc 149 | db.tree.del = db.free.PushTail 150 | // free list callbacks 151 | db.free.get = db.pageRead 152 | db.free.new = db.pageAppend 153 | db.free.set = db.pageWrite 154 | // open or create the DB file 155 | if db.fd, err = createFileSync(db.Path); err != nil { 156 | return err 157 | } 158 | // get the file size 159 | finfo := syscall.Stat_t{} 160 | if err = syscall.Fstat(db.fd, &finfo); err != nil { 161 | goto fail 162 | } 163 | // create the initial mmap 164 | if err = extendMmap(db, int(finfo.Size)); err != nil { 165 | goto fail 166 | } 167 | // read the meta page 168 | if err = readRoot(db, finfo.Size); err != nil { 169 | goto fail 170 | } 171 | return nil 172 | // error 173 | fail: 174 | db.Close() 175 | return fmt.Errorf("KV.Open: %w", err) 176 | } 177 | 178 | const DB_SIG = "BuildYourOwnDB12" 179 | 180 | /* 181 | the 1st page stores the root pointer and other auxiliary data. 
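These fields add up to 72 bytes: a 16-byte signature followed by seven 64-bit integers. saveMeta/loadMeta below serialize and parse exactly this layout, and updateRoot rewrites it in place with a single pwrite at offset 0.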
182 | | sig | root | page_used | head_page | head_seq | tail_page | tail_seq | ver | 183 | | 16B | 8B | 8B | 8B | 8B | 8B | 8B | 8B | 184 | */ 185 | func loadMeta(db *KV, data []byte) { 186 | db.tree.root = binary.LittleEndian.Uint64(data[16:24]) 187 | db.page.flushed = binary.LittleEndian.Uint64(data[24:32]) 188 | db.free.headPage = binary.LittleEndian.Uint64(data[32:40]) 189 | db.free.headSeq = binary.LittleEndian.Uint64(data[40:48]) 190 | db.free.tailPage = binary.LittleEndian.Uint64(data[48:56]) 191 | db.free.tailSeq = binary.LittleEndian.Uint64(data[56:64]) 192 | db.version = binary.LittleEndian.Uint64(data[64:72]) 193 | } 194 | 195 | func saveMeta(db *KV) []byte { 196 | var data [72]byte 197 | copy(data[:16], []byte(DB_SIG)) 198 | binary.LittleEndian.PutUint64(data[16:24], db.tree.root) 199 | binary.LittleEndian.PutUint64(data[24:32], db.page.flushed) 200 | binary.LittleEndian.PutUint64(data[32:40], db.free.headPage) 201 | binary.LittleEndian.PutUint64(data[40:48], db.free.headSeq) 202 | binary.LittleEndian.PutUint64(data[48:56], db.free.tailPage) 203 | binary.LittleEndian.PutUint64(data[56:64], db.free.tailSeq) 204 | binary.LittleEndian.PutUint64(data[64:72], db.version) 205 | return data[:] 206 | } 207 | 208 | func readRoot(db *KV, fileSize int64) error { 209 | if fileSize%btree.BTREE_PAGE_SIZE != 0 { 210 | return errors.New("file is not a multiple of pages") 211 | } 212 | if fileSize == 0 { // empty file 213 | // reserve 2 pages: the meta page and a free list node 214 | db.page.flushed = 2 215 | // add an initial node to the free list so it's never empty 216 | db.free.headPage = 1 // the 2nd page 217 | db.free.tailPage = 1 218 | return nil // the meta page will be written in the 1st update 219 | } 220 | // read the page 221 | data := db.mmap.chunks[0] 222 | loadMeta(db, data) 223 | // initialize the free list 224 | db.free.SetMaxVer(db.version) 225 | // verify the page 226 | bad := !bytes.Equal([]byte(DB_SIG), data[:16]) 227 | // pointers are within range? 228 | maxpages := uint64(fileSize / btree.BTREE_PAGE_SIZE) 229 | bad = bad || !(0 < db.page.flushed && db.page.flushed <= maxpages) 230 | bad = bad || !(0 < db.tree.root && db.tree.root < db.page.flushed) 231 | bad = bad || !(0 < db.free.headPage && db.free.headPage < db.page.flushed) 232 | bad = bad || !(0 < db.free.tailPage && db.free.tailPage < db.page.flushed) 233 | if bad { 234 | return errors.New("bad meta page") 235 | } 236 | return nil 237 | } 238 | 239 | // update the meta page. it must be atomic. 240 | func updateRoot(db *KV) error { 241 | // NOTE: atomic? 242 | if _, err := syscall.Pwrite(db.fd, saveMeta(db), 0); err != nil { 243 | return fmt.Errorf("write meta page: %w", err) 244 | } 245 | return nil 246 | } 247 | 248 | // extend the mmap by adding new mappings. 249 | func extendMmap(db *KV, size int) error { 250 | if size <= db.mmap.total { 251 | return nil // enough range 252 | } 253 | alloc := max(db.mmap.total, 64<<20) // double the current address space 254 | for db.mmap.total+alloc < size { 255 | alloc *= 2 // still not enough? 256 | } 257 | chunk, err := syscall.Mmap( 258 | db.fd, int64(db.mmap.total), alloc, 259 | syscall.PROT_READ, syscall.MAP_SHARED, // read-only 260 | ) 261 | if err != nil { 262 | return fmt.Errorf("mmap: %w", err) 263 | } 264 | db.mmap.total += alloc 265 | db.mmap.chunks = append(db.mmap.chunks, chunk) 266 | return nil 267 | } 268 | 269 | func updateFile(db *KV) error { 270 | // 1. Write new nodes. 271 | if err := writePages(db); err != nil { 272 | return err 273 | } 274 | // 2. 
`fsync` to enforce the order between 1 and 3. 275 | if err := db.Fsync(db.fd); err != nil { 276 | return err 277 | } 278 | // 3. Update the root pointer atomically. 279 | if err := updateRoot(db); err != nil { 280 | return err 281 | } 282 | // 4. `fsync` to make everything persistent. 283 | if err := db.Fsync(db.fd); err != nil { 284 | return err 285 | } 286 | return nil 287 | } 288 | 289 | func updateOrRevert(db *KV, meta []byte) error { 290 | // ensure the on-disk meta page matches the in-memory one after an error 291 | if db.failed { 292 | if _, err := syscall.Pwrite(db.fd, meta, 0); err != nil { 293 | return fmt.Errorf("rewrite meta page: %w", err) 294 | } 295 | if err := db.Fsync(db.fd); err != nil { 296 | return err 297 | } 298 | db.failed = false 299 | } 300 | // 2-phase update 301 | err := updateFile(db) 302 | // revert on error 303 | if err != nil { 304 | // the on-disk meta page is in an unknown state. 305 | // mark it to be rewritten on later recovery. 306 | db.failed = true 307 | // in-memory states are reverted immediately to allow reads 308 | loadMeta(db, meta) 309 | // discard temporaries 310 | db.page.nappend = 0 311 | db.page.updates = map[uint64][]byte{} 312 | } 313 | return err 314 | } 315 | 316 | func writePages(db *KV) error { 317 | // extend the mmap if needed 318 | size := (db.page.flushed + db.page.nappend) * btree.BTREE_PAGE_SIZE 319 | if err := extendMmap(db, int(size)); err != nil { 320 | return err 321 | } 322 | // write data pages to the file 323 | for ptr, node := range db.page.updates { 324 | offset := int64(ptr * btree.BTREE_PAGE_SIZE) 325 | if _, err := unix.Pwrite(db.fd, node, offset); err != nil { 326 | return err 327 | } 328 | } 329 | // discard in-memory data 330 | db.page.flushed += db.page.nappend 331 | db.page.nappend = 0 332 | db.page.updates = map[uint64][]byte{} 333 | return nil 334 | } 335 | 336 | // cleanups 337 | func (db *KV) Close() { 338 | for _, chunk := range db.mmap.chunks { 339 | err := syscall.Munmap(chunk) 340 | assert(err == nil) 341 | } 342 | _ = syscall.Close(db.fd) 343 | } 344 | -------------------------------------------------------------------------------- /kv/kv_test.go: -------------------------------------------------------------------------------- 1 | package kv 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "os" 7 | "sort" 8 | "testing" 9 | 10 | is "github.com/stretchr/testify/require" 11 | ) 12 | 13 | type D struct { 14 | db KV 15 | ref map[string]string 16 | } 17 | 18 | func nofsync(int) error { 19 | return nil 20 | } 21 | 22 | func newD() *D { 23 | os.Remove("test.db") 24 | 25 | d := &D{} 26 | d.ref = map[string]string{} 27 | d.db.Path = "test.db" 28 | d.db.Fsync = nofsync // faster 29 | err := d.db.Open() 30 | assert(err == nil) 31 | return d 32 | } 33 | 34 | func (d *D) reopen() { 35 | d.db.Close() 36 | d.db = KV{Path: d.db.Path, Fsync: d.db.Fsync} 37 | err := d.db.Open() 38 | assert(err == nil) 39 | } 40 | 41 | func (d *D) dispose() { 42 | d.db.Close() 43 | os.Remove("test.db") 44 | } 45 | 46 | func (d *D) add(key string, val string) { 47 | tx := KVTX{} 48 | d.db.Begin(&tx) 49 | _, err := tx.Set([]byte(key), []byte(val)) 50 | assert(err == nil) 51 | err = d.db.Commit(&tx) 52 | assert(err == nil) 53 | d.ref[key] = val 54 | } 55 | 56 | func (d *D) del(key string) bool { 57 | delete(d.ref, key) 58 | tx := KVTX{} 59 | d.db.Begin(&tx) 60 | deleted, err := tx.Del(&DeleteReq{Key: []byte(key)}) 61 | assert(err == nil) 62 | err = d.db.Commit(&tx) 63 | assert(err == nil) 64 | return deleted 65 | } 66 | 67 | func (d *D) dump() 
([]string, []string) { 68 | keys := []string{} 69 | vals := []string{} 70 | 71 | var nodeDump func(uint64) 72 | nodeDump = func(ptr uint64) { 73 | node := BNode(d.db.tree.get(ptr)) 74 | nkeys := node.nkeys() 75 | if node.btype() == BNODE_LEAF { 76 | for i := uint16(0); i < nkeys; i++ { 77 | keys = append(keys, string(node.getKey(i))) 78 | vals = append(vals, string(node.getVal(i))) 79 | } 80 | } else { 81 | for i := uint16(0); i < nkeys; i++ { 82 | ptr := node.getPtr(i) 83 | nodeDump(ptr) 84 | } 85 | } 86 | } 87 | 88 | nodeDump(d.db.tree.root) 89 | assert(keys[0] == "") 90 | assert(vals[0] == "") 91 | return keys[1:], vals[1:] 92 | } 93 | 94 | func (d *D) verify(t *testing.T) { 95 | // KV data 96 | keys, vals := d.dump() 97 | // reference data 98 | rkeys, rvals := []string{}, []string{} 99 | for k, v := range d.ref { 100 | rkeys = append(rkeys, k) 101 | rvals = append(rvals, v) 102 | } 103 | is.Equal(t, len(rkeys), len(keys)) 104 | sort.Stable(sortIF{ 105 | len: len(rkeys), 106 | less: func(i, j int) bool { return rkeys[i] < rkeys[j] }, 107 | swap: func(i, j int) { 108 | k, v := rkeys[i], rvals[i] 109 | rkeys[i], rvals[i] = rkeys[j], rvals[j] 110 | rkeys[j], rvals[j] = k, v 111 | }, 112 | }) 113 | // compare with the reference 114 | is.Equal(t, rkeys, keys) 115 | is.Equal(t, rvals, vals) 116 | 117 | // track visited pages 118 | pages := make([]uint8, d.db.page.flushed) 119 | pages[0] = 1 120 | pages[d.db.tree.root] = 1 121 | // verify node structures 122 | var nodeVerify func(BNode) 123 | nodeVerify = func(node BNode) { 124 | nkeys := node.nkeys() 125 | assert(nkeys >= 1) 126 | if node.btype() == BNODE_LEAF { 127 | return 128 | } 129 | for i := uint16(0); i < nkeys; i++ { 130 | ptr := node.getPtr(i) 131 | is.Zero(t, pages[ptr]) 132 | pages[ptr] = 1 // tree node 133 | key := node.getKey(i) 134 | kid := BNode(d.db.tree.get(node.getPtr(i))) 135 | is.Equal(t, key, kid.getKey(0)) 136 | nodeVerify(kid) 137 | } 138 | } 139 | 140 | nodeVerify(d.db.tree.get(d.db.tree.root)) 141 | 142 | // free list 143 | list, nodes := flDump(&d.db.free) 144 | for _, ptr := range nodes { 145 | is.Zero(t, pages[ptr]) 146 | pages[ptr] = 2 // free list node 147 | } 148 | for _, ptr := range list { 149 | is.Zero(t, pages[ptr]) 150 | pages[ptr] = 3 // free list content 151 | } 152 | for _, flag := range pages { 153 | is.NotZero(t, flag) // every page is accounted for 154 | } 155 | } 156 | 157 | func funcTestKVBasic(t *testing.T, reopen bool) { 158 | c := newD() 159 | defer c.dispose() 160 | 161 | c.add("k", "v") 162 | c.verify(t) 163 | 164 | // insert 165 | for i := 0; i < 25000; i++ { 166 | key := fmt.Sprintf("key%d", fmix32(uint32(i))) 167 | val := fmt.Sprintf("vvv%d", fmix32(uint32(-i))) 168 | c.add(key, val) 169 | if i < 2000 { 170 | c.verify(t) 171 | } 172 | } 173 | c.verify(t) 174 | if reopen { 175 | c.reopen() 176 | c.verify(t) 177 | } 178 | t.Log("insertion done") 179 | 180 | // del 181 | for i := 2000; i < 25000; i++ { 182 | key := fmt.Sprintf("key%d", fmix32(uint32(i))) 183 | is.True(t, c.del(key)) 184 | } 185 | c.verify(t) 186 | if reopen { 187 | c.reopen() 188 | c.verify(t) 189 | } 190 | t.Log("deletion done") 191 | 192 | // overwrite 193 | for i := 0; i < 2000; i++ { 194 | key := fmt.Sprintf("key%d", fmix32(uint32(i))) 195 | val := fmt.Sprintf("vvv%d", fmix32(uint32(+i))) 196 | c.add(key, val) 197 | c.verify(t) 198 | } 199 | 200 | is.False(t, c.del("kk")) 201 | 202 | // remove all 203 | for i := 0; i < 2000; i++ { 204 | key := fmt.Sprintf("key%d", fmix32(uint32(i))) 205 | is.True(t, c.del(key)) 206 | 
c.verify(t) 207 | } 208 | if reopen { 209 | c.reopen() 210 | c.verify(t) 211 | } 212 | 213 | c.add("k", "v2") 214 | c.verify(t) 215 | c.del("k") 216 | c.verify(t) 217 | } 218 | 219 | func TestKVBasic(t *testing.T) { 220 | funcTestKVBasic(t, false) 221 | funcTestKVBasic(t, true) 222 | } 223 | 224 | func fsyncErr(errlist ...int) func(int) error { 225 | return func(int) error { 226 | fail := errlist[0] 227 | errlist = errlist[1:] 228 | if fail != 0 { 229 | return fmt.Errorf("fsync error!") 230 | } else { 231 | return nil 232 | } 233 | } 234 | } 235 | 236 | func TestKVFsyncErr(t *testing.T) { 237 | c := newD() 238 | defer c.dispose() 239 | 240 | set := func(key []byte, val []byte) error { 241 | tx := KVTX{} 242 | c.db.Begin(&tx) 243 | tx.Set(key, val) 244 | return c.db.Commit(&tx) 245 | } 246 | get := func(key []byte) ([]byte, bool) { 247 | tx := KVTX{} 248 | c.db.Begin(&tx) 249 | val, ok := tx.Get(key) 250 | c.db.Abort(&tx) 251 | return val, ok 252 | } 253 | 254 | err := set([]byte("k"), []byte("1")) 255 | assert(err == nil) 256 | val, ok := get([]byte("k")) 257 | assert(ok && string(val) == "1") 258 | 259 | c.db.Fsync = fsyncErr(1) 260 | err = set([]byte("k"), []byte("2")) 261 | assert(err != nil) 262 | val, ok = get([]byte("k")) 263 | assert(ok && string(val) == "1") 264 | 265 | c.db.Fsync = nofsync 266 | err = set([]byte("k"), []byte("3")) 267 | assert(err == nil) 268 | val, ok = get([]byte("k")) 269 | assert(ok && string(val) == "3") 270 | 271 | c.db.Fsync = fsyncErr(0, 1) 272 | err = set([]byte("k"), []byte("4")) 273 | assert(err != nil) 274 | val, ok = get([]byte("k")) 275 | assert(ok && string(val) == "3") 276 | 277 | c.db.Fsync = nofsync 278 | err = set([]byte("k"), []byte("5")) 279 | assert(err == nil) 280 | val, ok = get([]byte("k")) 281 | assert(ok && string(val) == "5") 282 | 283 | c.db.Fsync = fsyncErr(0, 1) 284 | err = set([]byte("k"), []byte("6")) 285 | assert(err != nil) 286 | val, ok = get([]byte("k")) 287 | assert(ok && string(val) == "5") 288 | } 289 | 290 | func TestKVRandLength(t *testing.T) { 291 | c := newD() 292 | defer c.dispose() 293 | 294 | for i := 0; i < 2000; i++ { 295 | klen := fmix32(uint32(2*i+0)) % BTREE_MAX_KEY_SIZE 296 | vlen := fmix32(uint32(2*i+1)) % BTREE_MAX_VAL_SIZE 297 | if klen == 0 { 298 | continue 299 | } 300 | 301 | key := make([]byte, klen) 302 | rand.Read(key) 303 | val := make([]byte, vlen) 304 | // rand.Read(val) 305 | c.add(string(key), string(val)) 306 | c.verify(t) 307 | } 308 | } 309 | 310 | func TestKVIncLength(t *testing.T) { 311 | for l := 1; l < BTREE_MAX_KEY_SIZE+BTREE_MAX_VAL_SIZE; l++ { 312 | c := newD() 313 | 314 | klen := l 315 | if klen > BTREE_MAX_KEY_SIZE { 316 | klen = BTREE_MAX_KEY_SIZE 317 | } 318 | vlen := l - klen 319 | key := make([]byte, klen) 320 | val := make([]byte, vlen) 321 | 322 | factor := BTREE_PAGE_SIZE / l 323 | size := factor * factor * 2 324 | if size > 4000 { 325 | size = 4000 326 | } 327 | if size < 10 { 328 | size = 10 329 | } 330 | for i := 0; i < size; i++ { 331 | rand.Read(key) 332 | c.add(string(key), string(val)) 333 | } 334 | c.verify(t) 335 | 336 | c.dispose() 337 | } 338 | } 339 | 340 | func fileSize(path string) int64 { 341 | finfo, err := os.Stat(path) 342 | assert(err == nil) 343 | return finfo.Size() 344 | } 345 | 346 | // test the free list: file size do not increase under various operations 347 | func TestKVFileSize(t *testing.T) { 348 | c := newD() 349 | fill := func(seed int) { 350 | for i := 0; i < 2000; i++ { 351 | key := fmt.Sprintf("key%d", fmix32(uint32(i))) 352 | val := 
fmt.Sprintf("vvv%010d", fmix32(uint32(seed*2000+i))) 353 | c.add(key, val) 354 | } 355 | } 356 | fill(0) 357 | fill(1) 358 | size := fileSize(c.db.Path) 359 | 360 | // update the same key 361 | fill(2) 362 | assert(size == fileSize(c.db.Path)) 363 | 364 | // remove everything 365 | for i := 0; i < 2000; i++ { 366 | key := fmt.Sprintf("key%d", fmix32(uint32(i))) 367 | c.del(key) 368 | } 369 | assert(size == fileSize(c.db.Path)) 370 | 371 | // add them back 372 | fill(3) 373 | assert(size == fileSize(c.db.Path)) 374 | } 375 | -------------------------------------------------------------------------------- /ql/ql_exec.go: -------------------------------------------------------------------------------- 1 | package ql 2 | 3 | import ( 4 | "bytes" 5 | "cmp" 6 | "errors" 7 | "fmt" 8 | "slices" 9 | "strconv" 10 | ) 11 | 12 | // evaluating expressions 13 | type QLEvalContext struct { 14 | env Record 15 | out Value 16 | err error 17 | } 18 | 19 | func qlErr(ctx *QLEvalContext, format string, args ...interface{}) { 20 | if ctx.err == nil { 21 | ctx.out.Type = QL_ERR 22 | ctx.err = fmt.Errorf(format, args...) 23 | } 24 | } 25 | 26 | func b2i(b bool) int64 { 27 | if b { 28 | return 1 29 | } else { 30 | return 0 31 | } 32 | } 33 | 34 | func qlEval(ctx *QLEvalContext, node QLNODE) { 35 | switch node.Type { 36 | // refer to col. 37 | case QL_SYM: 38 | if v := ctx.env.Get(string(node.Str)); v != nil { 39 | ctx.out = *v 40 | } else { 41 | qlErr(ctx, "unknown col.: %s", node.Str) 42 | } 43 | //literla value 44 | case QL_I64, QL_STR: 45 | ctx.out = node.Value 46 | case QL_TUP: 47 | qlErr(ctx, "unexpected tuple") 48 | 49 | // operators 50 | case QL_NEG: 51 | qlEval(ctx, node.Kids[0]) 52 | if ctx.out.Type == TYPE_INT64 { 53 | ctx.out.I64 = -ctx.out.I64 54 | } else { 55 | qlErr(ctx, "QL_NEG type error") 56 | } 57 | case QL_NOT: 58 | qlEval(ctx, node.Kids[0]) 59 | if ctx.out.Type == TYPE_INT64 { 60 | ctx.out.I64 = b2i(ctx.out.I64 == 0) 61 | } else { 62 | qlErr(ctx, "QL_NOT type error") 63 | } 64 | 65 | // binary ops. 66 | case QL_CMP_GE, QL_CMP_GT, QL_CMP_LT, QL_CMP_LE, QL_CMP_EQ, QL_CMP_NE: 67 | fallthrough 68 | case QL_ADD, QL_SUB, QL_MUL, QL_DIV, QL_MOD, QL_AND, QL_OR: 69 | qlBinop(ctx, node) 70 | 71 | default: 72 | panic("unreachable") 73 | 74 | } 75 | } 76 | 77 | func qlBinopI64(ctx *QLEvalContext, op uint32, a1 int64, a2 int64) int64 { 78 | switch op { 79 | case QL_ADD: 80 | return a1 + a2 81 | case QL_SUB: 82 | return a1 - a2 83 | case QL_MUL: 84 | return a1 * a2 85 | case QL_DIV: 86 | if a2 == 0 { 87 | qlErr(ctx, "div. by zero") 88 | return 0 89 | } 90 | return a1 / a2 91 | 92 | case QL_MOD: 93 | if a2 == 0 { 94 | qlErr(ctx, "div. 
by zero") 95 | return 0 96 | } 97 | 98 | return a1 % a2 99 | 100 | case QL_AND: 101 | return b2i(a1&a2 != 0) 102 | case QL_OR: 103 | return b2i(a1|a2 != 0) 104 | 105 | default: 106 | qlErr(ctx, "bad i64 binop") 107 | return 0 108 | } 109 | } 110 | 111 | func qlBinopStr(ctx *QLEvalContext, op uint32, a1 []byte, a2 []byte) { 112 | switch op { 113 | case QL_ADD: 114 | ctx.out.Type = TYPE_BYTES 115 | ctx.out.Str = slices.Concat(a1, a2) 116 | default: 117 | qlErr(ctx, "bad str binop") 118 | } 119 | } 120 | 121 | // binary operators 122 | func qlBinop(ctx *QLEvalContext, node QLNODE) { 123 | isCmp := false 124 | switch node.Type { 125 | case QL_CMP_GE, QL_CMP_GT, QL_CMP_LT, QL_CMP_LE, QL_CMP_EQ, QL_CMP_NE: 126 | isCmp = true 127 | } 128 | 129 | // tuple comparision 130 | if isCmp && node.Kids[0].Type == QL_TUP && node.Kids[1].Type == QL_TUP { 131 | r := qlTupleCmp(ctx, node.Kids[0], node.Kids[1]) 132 | ctx.out.Type = QL_I64 133 | ctx.out.I64 = b2i(cmp2bool(r, node.Type)) 134 | return 135 | } 136 | 137 | // subexpressions 138 | qlEval(ctx, node.Kids[0]) 139 | a1 := ctx.out 140 | qlEval(ctx, node.Kids[1]) 141 | a2 := ctx.out 142 | 143 | // scalar comparision 144 | if isCmp { 145 | r := qlValueCmp(ctx, a1, a2) 146 | ctx.out.Type = QL_I64 147 | ctx.out.I64 = b2i(cmp2bool(r, node.Type)) 148 | return 149 | } 150 | 151 | switch { 152 | case ctx.err != nil: 153 | return 154 | case a1.Type == TYPE_INT64: 155 | ctx.out.Type = QL_I64 156 | ctx.out.I64 = qlBinopI64(ctx, node.Type, a1.I64, a2.I64) 157 | 158 | case a1.Type != a2.Type: 159 | qlErr(ctx, "binop type mismatch") 160 | 161 | case a1.Type == TYPE_BYTES: 162 | ctx.out.Type = QL_STR 163 | qlBinopStr(ctx, node.Type, a1.Str, a2.Str) 164 | 165 | default: 166 | panic("unreachable") 167 | } 168 | 169 | } 170 | 171 | func cmp2bool(res int, cmd uint32) bool { 172 | switch cmd { 173 | case QL_CMP_GE: 174 | return res >= 0 175 | case QL_CMP_GT: 176 | return res > 0 177 | case QL_CMP_LT: 178 | return res < 0 179 | case QL_CMP_LE: 180 | return res <= 0 181 | case QL_CMP_EQ: 182 | return res == 0 183 | case QL_CMP_NE: 184 | return res != 0 185 | 186 | default: 187 | panic("unreachable") 188 | } 189 | } 190 | 191 | func qlValueCmp(ctx *QLEvalContext, a1 Value, a2 Value) int { 192 | switch { 193 | case ctx.err != nil: 194 | return 0 195 | 196 | case a1.Type != a2.Type: 197 | qlErr(ctx, "comparison of different types") 198 | return 0 199 | 200 | case a1.Type == TYPE_INT64: 201 | return cmp.Compare(a1.I64, a2.I64) 202 | 203 | case a1.Type == TYPE_BYTES: 204 | return bytes.Compare(a1.Str, a2.Str) 205 | 206 | default: 207 | panic("unreachable") 208 | } 209 | } 210 | 211 | // comparin 2 tuples of equal length 212 | func qlTupleCmp(ctx *QLEvalContext, n1 QLNODE, n2 QLNODE) int { 213 | if len(n1.Kids) != len(n2.Kids) { 214 | qlErr(ctx, "tuple comp. 
of different lengths") 215 | } 216 | 217 | for i := 0; i < len(n1.Kids) && ctx.err == nil; i++ { 218 | qlEval(ctx, n1.Kids[i]) 219 | a1 := ctx.out 220 | qlEval(ctx, n2.Kids[i]) 221 | a2 := ctx.out 222 | if cmp := qlValueCmp(ctx, a1, a2); cmp != 0 { 223 | return cmp 224 | } 225 | } 226 | 227 | return 0 228 | } 229 | 230 | func qlEvelMulti(env Record, exprs []QLNODE) ([]Value, error) { 231 | vals := []Value{} 232 | 233 | for _, node := range exprs { 234 | ctx := QLEvalContext{env: env} 235 | qlEval(&ctx, node) 236 | 237 | if ctx.err != nil { 238 | return nil, ctx.err 239 | } 240 | vals = append(vals, ctx.out) 241 | } 242 | 243 | return vals, nil 244 | } 245 | 246 | func qlEvalScanKey(node QLNODE) (Record, int, error) { 247 | cmp := 0 248 | 249 | switch node.Type { 250 | case QL_CMP_GE: 251 | cmp = CMP_GE 252 | case QL_CMP_GT: 253 | cmp = CMP_GT 254 | case QL_CMP_LT: 255 | cmp = CMP_LT 256 | case QL_CMP_LE: 257 | cmp = CMP_LE 258 | case QL_CMP_EQ: 259 | cmp = 0 260 | 261 | default: 262 | panic("unreachable") 263 | } 264 | 265 | names, exprs := node.Kids[0], node.Kids[1] 266 | assert(names.Type == QL_TUP && exprs.Type == QL_TUP) 267 | assert(len(names.Kids) == len(exprs.Kids)) 268 | 269 | vals, err := qlEvelMulti(Record{}, exprs.Kids) 270 | if err != nil { 271 | return Record{}, 0, err 272 | } 273 | 274 | cols := []string{} 275 | for i := range names.Kids { 276 | assert(names.Kids[i].Type == QL_SYM) 277 | cols = append(cols, string(names.Kids[i].Str)) 278 | } 279 | 280 | return Record{cols, vals}, cmp, nil 281 | } 282 | 283 | // scanner implements INDEX BY 284 | func qlScanInit(req *QLScan, sc *Scanner) (err error) { 285 | // convert QLNODE to Record 286 | if sc.Key1, sc.Cmp1, err = qlEvalScanKey(req.Key1); err != nil { 287 | return err 288 | } 289 | if sc.Key2, sc.Cmp2, err = qlEvalScanKey(req.Key2); err != nil { 290 | return err 291 | } 292 | 293 | // convert keys to range 294 | switch { 295 | case req.Key1.Type == 0 && req.Key2.Type == 0: 296 | sc.Cmp1, sc.Cmp2 = CMP_GE, CMP_LE //full table scan by primary key 297 | 298 | case req.Key1.Type == QL_CMP_EQ && req.Key2.Type == 0: 299 | // INDEX BY key= val 300 | sc.Key2 = sc.Key1 301 | sc.Cmp1, sc.Cmp2 = CMP_GE, CMP_LE 302 | 303 | case req.Key1.Type != 0 && req.Key2.Type == 0: //open ended range 304 | if sc.Cmp1 > 0 { 305 | sc.Cmp2 = CMP_LE 306 | } else { 307 | sc.Cmp2 = CMP_GE 308 | } 309 | 310 | case req.Key1.Type != 0 && req.Key2.Type != 0: 311 | // nothing 312 | default: 313 | panic("unreachable") 314 | } 315 | 316 | return nil 317 | } 318 | 319 | type RecordIter interface { 320 | Valid() bool 321 | Next() 322 | Deref(*Record) error 323 | } 324 | 325 | // evaluate expressions in SELECT 326 | type qlSelectIter struct { 327 | iter RecordIter //input 328 | names []string 329 | exprs []QLNODE 330 | } 331 | 332 | func (iter *qlSelectIter) Valid() bool { 333 | return iter.iter.Valid() 334 | } 335 | 336 | func (iter *qlSelectIter) Next() { 337 | iter.iter.Next() 338 | } 339 | 340 | func (iter *qlSelectIter) Deref(rec *Record) error { 341 | if err := iter.iter.Deref(rec); err != nil { 342 | return err 343 | } 344 | 345 | vals, err := qlEvelMulti(*rec, iter.exprs) 346 | if err != nil { 347 | return err 348 | } 349 | 350 | *rec = Record{iter.names, vals} 351 | 352 | return nil 353 | } 354 | 355 | type qlScanIter struct { 356 | // input 357 | req *QLScan 358 | sc Scanner 359 | // state 360 | idx int64 361 | end bool 362 | 363 | // cached output item 364 | rec Record 365 | err error 366 | } 367 | 368 | func qlScanPull(iter *qlScanIter, rec *Record) 
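// qlScanPull checks a single row from the underlying scanner: rows before
// the OFFSET are skipped, then the FILTER expression (if any) is evaluated
// against the row and must yield a non-zero int64 for the row to pass.
// It reports whether the current row should be emitted by the iterator.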
(bool, error) { 369 | if iter.idx < iter.req.Offset { 370 | return false, nil 371 | } 372 | 373 | iter.sc.Deref(rec) 374 | if iter.req.Filter.Type != 0 { 375 | ctx := QLEvalContext{env: *rec} 376 | qlEval(&ctx, iter.req.Filter) 377 | 378 | if ctx.err != nil { 379 | return false, ctx.err 380 | } 381 | if ctx.out.Type != TYPE_INT64 { 382 | return false, errors.New("filter is not of boolean type") 383 | } 384 | if ctx.out.I64 == 0 { 385 | return false, nil 386 | } 387 | } 388 | 389 | return true, nil 390 | } 391 | 392 | func (iter *qlScanIter) Next() { 393 | for iter.idx < iter.req.Limit && iter.sc.Valid() { 394 | // check current iten 395 | got, err := qlScanPull(iter, &iter.rec) 396 | iter.err = err 397 | 398 | // next item 399 | iter.idx++ 400 | iter.sc.Next() 401 | if got || err != nil { 402 | return 403 | } 404 | } 405 | 406 | iter.end = true 407 | } 408 | 409 | func (iter *qlScanIter) Valid() bool { 410 | return !iter.end 411 | } 412 | 413 | func (iter *qlScanIter) Deref(rec *Record) error { 414 | assert(iter.Valid()) 415 | if iter.err == nil { 416 | *rec = iter.rec 417 | } 418 | return iter.err 419 | } 420 | 421 | // execute query 422 | func qlScan(req *QLScan, tx *DBTX) (RecordIter, error) { 423 | iter := qlScanIter{req: req} 424 | if err := qlScanInit(req, &iter.sc); err != nil { 425 | return nil, err 426 | } 427 | if err := tx.Scan(req.Table, &iter.sc); err != nil { 428 | return nil, err 429 | } 430 | iter.Next() 431 | return &iter, nil 432 | } 433 | 434 | // stmt: select 435 | func qlSelect(req *QLSelect, tx *DBTX) (RecordIter, error) { 436 | // records 437 | records, err := qlScan(&req.QLScan, tx) 438 | if err != nil { 439 | return nil, err 440 | } 441 | 442 | tdef := getTableDef(tx, req.Table) 443 | names, exprs := []string{}, []QLNODE{} 444 | for i := range req.Names { 445 | if req.Names[i] != "*" { 446 | names = append(names, req.Names[i]) 447 | exprs = append(exprs, req.Output[i]) 448 | } else { 449 | names = append(names, tdef.Cols...) 
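// expand `SELECT *`: every column of the table becomes a QL_SYM node,
// so qlSelectIter.Deref simply evaluates the column reference against
// each row and copies its value through.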
450 | for _, col := range tdef.Cols { 451 | node := QLNODE{Value: Value{Type: QL_SYM, Str: []byte(col)}} 452 | exprs = append(exprs, node) 453 | } 454 | } 455 | } 456 | assert(len(names) == len(exprs)) 457 | 458 | for i := range names { 459 | if names[i] != "" { 460 | continue 461 | } 462 | if exprs[i].Type == QL_SYM { 463 | names[i] = string(exprs[i].Str) 464 | } else { 465 | names[i] = strconv.Itoa(i) 466 | } 467 | } 468 | 469 | return &qlSelectIter{iter: records, names: names, exprs: exprs}, nil 470 | } 471 | 472 | // stmt :create table 473 | func qlCreateTable(req *QLCreateTable, tx *DBTX) error { 474 | return tx.TableNew(&req.Def) 475 | } 476 | 477 | // stmt: Insert 478 | func qlInsert(req *QLInsert, tx *DBTX) (uint64, uint64, error) { 479 | added, updated := uint64(0), uint64(0) 480 | 481 | for _, nodes := range req.Values { 482 | vals, err := qlEvelMulti(Record{}, nodes) 483 | if err != nil { 484 | return 0, 0, err 485 | } 486 | 487 | dbReq := DBUpdateReq{Record: Record{req.Names, vals}, Mode: req.Mode} 488 | _, err = tx.Set(req.Table, &dbReq) 489 | if err != nil { 490 | return 0, 0, err 491 | } 492 | 493 | if dbReq.Added { 494 | added++ 495 | } 496 | if dbReq.Updated { 497 | updated++ 498 | } 499 | } 500 | 501 | return added, updated, nil 502 | } 503 | 504 | // stmt: delete 505 | func qlDelete(req *QLDelete, tx *DBTX) (uint64, error) { 506 | records, err := qlScan(&req.QLScan, tx) 507 | if err != nil { 508 | return 0, err 509 | } 510 | 511 | tdef := getTableDef(tx, req.Table) 512 | deleted := uint64(0) 513 | 514 | for ; records.Valid(); records.Next() { 515 | rec := Record{} 516 | if err := records.Deref(&rec); err != nil { 517 | return 0, err 518 | } 519 | deleted++ 520 | 521 | vals, err := getValues(tdef, rec, tdef.Indexes[0]) 522 | assert(err == nil) 523 | deleted, err := tx.Delete(req.Table, Record{tdef.Indexes[0], vals}) 524 | assert(err == nil && deleted) 525 | } 526 | 527 | return deleted, nil 528 | } 529 | 530 | // stmt Update 531 | func qlUpdate(req *QLUPdate, tx *DBTX) (uint64, error) { 532 | // no update to primary key 533 | assert(len(req.Names) == len(req.Values)) 534 | 535 | tdef := getTableDef(tx, req.Table) 536 | for _, col := range req.Names { 537 | if slices.Index(tdef.Cols, col) < 0 { 538 | return 0, fmt.Errorf("unknown col.: %s", col) 539 | } 540 | if slices.Index(tdef.Indexes[0], col) >= 0 { 541 | return 0, errors.New("cannot update the primary key") 542 | } 543 | } 544 | 545 | records, err := qlScan(&req.QLScan, tx) 546 | if err != nil { 547 | return 0, err 548 | } 549 | 550 | updated := uint64(0) 551 | for ; records.Valid(); records.Next() { 552 | // old record 553 | rec := Record{} 554 | if err := records.Deref(&rec); err != nil { 555 | return 0, err 556 | } 557 | 558 | // new record 559 | vals, err := qlEvelMulti(rec, req.Values) 560 | if err != nil { 561 | return 0, err 562 | } 563 | for i, col := range req.Names { 564 | rec.Vals[slices.Index(tdef.Cols, col)] = vals[i] 565 | } 566 | 567 | // perform updtae 568 | dbReq := DBUpdateReq{Record: rec, Mode: MODE_UPDATE_ONLY} 569 | if _, err := tx.Set(req.Table, &dbReq); err != nil { 570 | return 0, err 571 | } 572 | 573 | if dbReq.Updated { 574 | updated++ 575 | } 576 | } 577 | 578 | return 0, nil 579 | } 580 | 581 | type QLResult struct { 582 | Records RecordIter 583 | Added uint64 584 | Updated uint64 585 | Deleted uint64 586 | } 587 | 588 | // execute a single statement 589 | func qlExec(tx *DBTX, stmt interface{}) (res QLResult, err error) { 590 | save := TXSave{} 591 | tx.Save(&save) 592 | 593 | switch req 
:= stmt.(type) { 594 | case *QLSelect: 595 | res.Records, err = qlSelect(req, tx) 596 | case *QLCreateTable: 597 | err = qlCreateTable(req, tx) 598 | case *QLInsert: 599 | res.Added, res.Updated, err = qlInsert(req, tx) 600 | case *QLDelete: 601 | res.Deleted, err = qlDelete(req, tx) 602 | case *QLUPdate: 603 | res.Updated, err = qlUpdate(req, tx) 604 | 605 | default: 606 | panic("unreachable") 607 | } 608 | 609 | if err != nil { 610 | tx.Revert(&save) 611 | } 612 | 613 | return 614 | } -------------------------------------------------------------------------------- /ql/ql_parse.go: -------------------------------------------------------------------------------- 1 | package ql 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "strings" 7 | "unicode" 8 | 9 | "github.com/Adit0507/AdiDB/btree" 10 | "github.com/Adit0507/AdiDB/table" 11 | ) 12 | 13 | const ( 14 | // syntax tree node types 15 | QL_UNINIT = 0 16 | QL_STR = table.TYPE_BYTES 17 | QL_I64 = table.TYPE_INT64 18 | QL_CMP_GE = 10 // >= 19 | QL_CMP_GT = 11 // > 20 | QL_CMP_LT = 12 // < 21 | QL_CMP_LE = 13 // <= 22 | QL_CMP_EQ = 14 // == 23 | QL_CMP_NE = 15 // != 24 | QL_ADD = 20 25 | QL_SUB = 21 26 | QL_MUL = 22 27 | QL_DIV = 23 28 | QL_MOD = 24 29 | QL_AND = 30 30 | QL_OR = 31 31 | QL_NOT = 50 32 | QL_NEG = 51 33 | QL_SYM = 100 34 | QL_TUP = 101 // tuple 35 | QL_STAR = 102 // select * 36 | QL_ERR = 200 // error; from parsing or evaluation 37 | ) 38 | 39 | // tree node 40 | type QLNODE struct { 41 | Value 42 | Kids []QLNODE //operands 43 | } 44 | 45 | // common structure for stmt.s: 'INDEX BY', ”FILTER', 'LIMIT' 46 | type QLScan struct { 47 | Table string 48 | Key1 QLNODE //index by 49 | Key2 QLNODE 50 | Filter QLNODE //filter expression 51 | Offset int64 52 | Limit int64 53 | } 54 | 55 | // statements: SELECT UPDATE DELETE 56 | type QLSelect struct { 57 | QLScan 58 | Names []string // expr. 
AS name 59 | Output []QLNODE 60 | } 61 | 62 | type QLUPdate struct { 63 | QLScan 64 | Names []string 65 | Values []QLNODE 66 | } 67 | 68 | // smt: insert 69 | type QLInsert struct { 70 | Table string 71 | Mode int 72 | Names []string 73 | Values [][]QLNODE 74 | } 75 | 76 | type QLDelete struct { 77 | QLScan 78 | } 79 | 80 | type QLCreateTable struct { 81 | Def table.TableDef 82 | } 83 | 84 | type Parser struct { 85 | input []byte 86 | idx int 87 | err error 88 | } 89 | 90 | func isSpace(ch byte) bool { 91 | return unicode.IsSpace(rune(ch)) 92 | } 93 | 94 | func skipSpace(p *Parser) { 95 | for p.idx < len(p.input) && isSpace(p.input[p.idx]) { 96 | p.idx++ 97 | } 98 | } 99 | 100 | func isSym(ch byte) bool { 101 | r := rune(ch) 102 | return unicode.IsLetter(r) || unicode.IsNumber(r) || r == '_' 103 | } 104 | 105 | func isSymStart(ch byte) bool { 106 | return unicode.IsLetter(rune(ch)) || ch == '@' || ch == '_' 107 | } 108 | 109 | // matching multiple keywords sequentially 110 | func pKeyword(p *Parser, kwds ...string) bool { 111 | save := p.idx 112 | 113 | for _, kw := range kwds { 114 | skipSpace(p) 115 | end := p.idx + len(kw) 116 | if end > len(p.input) { 117 | p.idx = save 118 | return false 119 | } 120 | 121 | ok := strings.EqualFold(string(p.input[p.idx:end]), kw) 122 | 123 | if ok && isSym(kw[len(kw)-1]) && end < len(p.input) { 124 | ok = !isSym(p.input[end]) 125 | } 126 | 127 | if !ok { 128 | p.idx = save 129 | return false 130 | } 131 | p.idx += len(kw) 132 | } 133 | 134 | return true 135 | } 136 | 137 | func pSym(p *Parser, node *QLNODE) bool { 138 | skipSpace(p) 139 | 140 | end := p.idx 141 | if !(end < len(p.input) && isSymStart(p.input[end])) { 142 | return false 143 | } 144 | 145 | end++ 146 | for end < len(p.input) && isSym(p.input[end]) { 147 | end++ 148 | } 149 | 150 | if pKeywordSet[strings.ToLower(string(p.input[p.idx:end]))] { 151 | return false 152 | } 153 | 154 | node.Type = QL_SYM 155 | node.Str = p.input[p.idx:end] 156 | p.idx = end 157 | 158 | return true 159 | } 160 | 161 | var pKeywordSet = map[string]bool{ 162 | "from": true, 163 | "index": true, 164 | "filter": true, 165 | "limit": true, 166 | } 167 | 168 | func pErr(p *Parser, format string, args ...interface{}) { 169 | if p.err == nil { 170 | p.err = fmt.Errorf(format, args...) 
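// only the first error is recorded; later parsing steps keep running but
// cannot overwrite the original message.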
171 | } 172 | } 173 | 174 | func pMustSym(p *Parser) string { 175 | name := QLNODE{} 176 | if !pSym(p, &name) { 177 | pErr(p, "expect name") 178 | } 179 | 180 | return string(name.Str) 181 | } 182 | 183 | func pCreateTable(p *Parser) *QLCreateTable { 184 | stmt := QLCreateTable{} 185 | stmt.Def.Name = pMustSym(p) 186 | pExpect(p, "(", "expect parenthesis") 187 | 188 | // primary key 189 | stmt.Def.Indexes = append(stmt.Def.Indexes, nil) 190 | comma := true 191 | 192 | for p.err == nil && !pKeyword(p, ")") { 193 | if !comma { 194 | pErr(p, "expect comma") 195 | } 196 | 197 | switch { 198 | case pKeyword(p, "index"): 199 | stmt.Def.Indexes = append(stmt.Def.Indexes, pNameList(p)) 200 | case pKeyword(p, "primary", "key"): 201 | if stmt.Def.Indexes[0] != nil { 202 | pErr(p, "duplicate primary key") 203 | } 204 | stmt.Def.Indexes[0] = pNameList(p) 205 | 206 | default: 207 | stmt.Def.Cols = append(stmt.Def.Cols, pMustSym(p)) 208 | stmt.Def.Types = append(stmt.Def.Types, pColType(p)) 209 | } 210 | comma = pKeyword(p, ",") 211 | } 212 | 213 | return &stmt 214 | } 215 | 216 | func pColType(p *Parser) uint32 { 217 | typedef := pMustSym(p) 218 | 219 | switch strings.ToLower(typedef) { 220 | case "strings", "bytes": 221 | return TYPE_BYTES 222 | case "int", "int64": 223 | return TYPE_INT64 224 | default: 225 | pErr(p, "bad column type: %s", typedef) 226 | return 0 227 | } 228 | } 229 | 230 | 231 | func pSelectExpr(p *Parser, node *QLSelect) { 232 | if pKeyword(p, "*") { 233 | node.Names = append(node.Names, "*") 234 | node.Output = append(node.Output, QLNODE{Value: Value{Type: QL_STAR}}) 235 | return 236 | } 237 | } 238 | 239 | func pSelectExprList(p *Parser, node *QLSelect) { 240 | pSelectExpr(p, node) 241 | 242 | for pKeyword(p, ",") { 243 | pSelectExpr(p, node) 244 | } 245 | } 246 | 247 | func pExpect(p *Parser, tok string, format string, args ...interface{}) { 248 | if !pKeyword(p, tok) { 249 | pErr(p, format, args...) 
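// pExpect consumes a required token (a parenthesis, `FROM`, `SET`, ...)
// via pKeyword and records a parse error when it is missing.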
250 | } 251 | } 252 | 253 | func pScan(p *Parser, node *QLScan) { 254 | // INDEX BY 255 | if pKeyword(p, "index", "by") { 256 | pIndexBy(p, node) 257 | } 258 | 259 | if pKeyword(p, "filter") { 260 | pExprOr(p, &node.Filter) 261 | } 262 | 263 | node.Offset, node.Limit = 0, math.MaxInt64 264 | if pKeyword(p, "index", "by") { 265 | pLimit(p, node) 266 | } 267 | } 268 | 269 | func pExprOr(p *Parser, node *QLNODE) { 270 | pExprBinop(p, node, []string{"or"}, []uint32{QL_OR}, pExprAnd) 271 | } 272 | 273 | func pExprAnd(p *Parser, node *QLNODE) { 274 | pExprBinop(p, node, []string{"and"}, []uint32{QL_ADD}, pExprNot) 275 | } 276 | 277 | func pExprNot(p *Parser, node *QLNODE) { 278 | switch { 279 | case pKeyword(p, "not"): 280 | node.Type = QL_NOT 281 | node.Kids = []QLNODE{{}} 282 | pExprCmp(p, &node.Kids[0]) 283 | } 284 | } 285 | 286 | func pExprCmp(p *Parser, node *QLNODE) { 287 | pExprBinop(p, node, []string{ 288 | "=", "!=", 289 | ">=", "<=", ">", "<", 290 | }, 291 | []uint32{ 292 | QL_CMP_EQ, QL_CMP_NE, QL_CMP_GE, QL_CMP_LE, QL_CMP_GT, QL_CMP_LT, 293 | }, 294 | pExprAdd) 295 | } 296 | 297 | func pExprAdd(p *Parser, node *QLNODE) { 298 | pExprBinop(p, node, []string{"+", "-"}, []uint32{QL_ADD, QL_SUB}, pExprMul) 299 | } 300 | 301 | func pExprMul(p *Parser, node *QLNODE) { 302 | pExprBinop(p, node, []string{"*", "/", "%"}, []uint32{QL_MUL, QL_DIV, QL_MOD}, pExprUnop) 303 | } 304 | 305 | func pExprUnop(p *Parser, node *QLNODE) { 306 | switch { 307 | case pKeyword(p, "-"): 308 | node.Type = QL_NEG 309 | node.Kids = []QLNODE{{}} 310 | pExprAtom(p, &node.Kids[0]) 311 | 312 | default: 313 | pExprAtom(p, node) 314 | } 315 | } 316 | 317 | func pExprAtom(p *Parser, node *QLNODE) { 318 | switch { 319 | case pKeyword(p, "("): 320 | pExprTuple(p, node) 321 | case pSym(p, node): 322 | case pNum(p, node): 323 | case pStr(p, node): 324 | default: 325 | pErr(p, "expect symbol, number or string") 326 | } 327 | } 328 | 329 | func pStr(p *Parser, node *QLNODE) bool { 330 | skipSpace(p) 331 | 332 | cur := p.idx 333 | quote := byte(0) 334 | if cur < len(p.input) { 335 | quote = p.input[cur] 336 | cur++ 337 | } 338 | if !(quote == '*' || quote == '\'') { 339 | return false 340 | } 341 | 342 | var s []byte 343 | for cur < len(p.input) && p.input[cur] != quote { 344 | switch p.input[cur] { 345 | case '\\': 346 | cur++ 347 | if cur >= len(p.input) { 348 | pErr(p, "string not terminated") 349 | return false 350 | } 351 | switch p.input[cur] { 352 | case '"', '\'', '\\': 353 | s = append(s, p.input[cur]) 354 | cur++ 355 | default: 356 | pErr(p, "unknown escape") 357 | return false 358 | } 359 | default: 360 | s = append(s, p.input[cur]) 361 | cur++ 362 | } 363 | } 364 | 365 | if !(cur < len(p.input) && p.input[cur] == quote) { 366 | pErr(p, "string not terminated") 367 | return false 368 | } 369 | 370 | cur++ 371 | node.Type = QL_STR 372 | node.Str = s 373 | p.idx = cur 374 | return true 375 | } 376 | 377 | func pExprTuple(p *Parser, node *QLNODE) { 378 | kids := []QLNODE{} 379 | comma := true 380 | 381 | for p.err == nil && !pKeyword(p, ")") { 382 | if !comma { 383 | pErr(p, "expect comma") 384 | } 385 | 386 | kids = append(kids, QLNODE{}) 387 | pExprOr(p, &kids[len(kids)-1]) 388 | comma = pKeyword(p, ",") 389 | } 390 | 391 | if len(kids) == 1 && !comma { 392 | node = &kids[0] 393 | } else { 394 | node.Type = QL_TUP 395 | node.Kids = kids 396 | } 397 | } 398 | 399 | func pExprBinop(p *Parser, node *QLNODE, ops []string, types []uint32, next func(*Parser, *QLNODE)) { 400 | assert(len(ops) == len(types)) 401 | left := 
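// pExprBinop parses a left-associative chain of binary operators: it parses
// one operand with `next` (the next tighter precedence level), then
// repeatedly matches one of `ops` and folds the result into a new parent
// node. Operator precedence comes from the call chain
// pExprOr -> pExprAnd -> pExprNot -> pExprCmp -> pExprAdd -> pExprMul -> pExprUnop -> pExprAtom.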
QLNODE{} 402 | next(p, &left) 403 | 404 | for { 405 | i := 0 406 | for i < len(ops) && !pKeyword(p, ops[i]) { 407 | i++ 408 | } 409 | 410 | if i >= len(ops) { 411 | *node = left 412 | return 413 | } 414 | 415 | new := QLNODE{Value: Value{Type: types[i]}} 416 | new.Kids = []QLNODE{left, {}} 417 | next(p, &new.Kids[1]) 418 | left = new 419 | } 420 | } 421 | 422 | func pLimit(p *Parser, node *QLScan) { 423 | offset, count := QLNODE{}, QLNODE{} 424 | ok := pNum(p, &count) 425 | if pKeyword(p, ",") { 426 | offset = count 427 | ok = ok && pNum(p, &count) 428 | } 429 | 430 | if !ok || offset.I64 < 0 || count.I64 < 0 || offset.I64+count.I64 < 0 { 431 | pErr(p, "bad `LIMIT`") 432 | } 433 | 434 | node.Offset = offset.I64 435 | if count.Type != 0 { 436 | node.Limit = node.Offset + count.I64 437 | } 438 | } 439 | 440 | func pNum(p *Parser, node *QLNODE) bool { 441 | skipSpace(p) 442 | 443 | end := p.idx 444 | for end < len(p.input) && unicode.IsNumber(rune(p.input[end])) { 445 | end++ 446 | } 447 | 448 | if end == p.idx { 449 | return false 450 | } 451 | if end < len(p.input) && isSym(p.input[end]) { 452 | return false 453 | } 454 | 455 | return true 456 | } 457 | 458 | // INDEX BY colsvals 459 | func pIndexBy(p *Parser, node *QLScan) { 460 | index := QLNODE{} 461 | pExprAnd(p, &index) 462 | 463 | if index.Type == QL_AND { 464 | node.Key1, node.Key2 = index.Kids[0], index.Kids[1] 465 | } else { 466 | node.Key1 = index 467 | } 468 | 469 | pVerifyScanKey(p, &node.Key1) 470 | if node.Key2.Type != 0 { 471 | pVerifyScanKey(p, &node.Key2) 472 | } 473 | 474 | if node.Key1.Type == QL_CMP_EQ && node.Key2.Type != 0 { 475 | pErr(p, "bad `INDEX BY`: expect only a single `=`") 476 | } 477 | } 478 | 479 | func pVerifyScanKey(p *Parser, node *QLNODE) { 480 | switch node.Type { 481 | case QL_CMP_EQ, QL_CMP_GE, QL_CMP_GT, QL_CMP_LT, QL_CMP_LE: 482 | default: 483 | pErr(p, "bad `INDEX BY`: not a comparision") 484 | return 485 | } 486 | 487 | l, r := node.Kids[0], node.Kids[1] 488 | if l.Type != QL_TUP && r.Type != QL_TUP { 489 | l = QLNODE{Value: Value{Type: QL_TUP}, Kids: []QLNODE{l}} 490 | r = QLNODE{Value: Value{Type: QL_TUP}, Kids: []QLNODE{r}} 491 | } 492 | 493 | if l.Type != QL_TUP || r.Type != QL_TUP { 494 | pErr(p, "bad `INDEX BY`: bad comparison") 495 | } 496 | 497 | if len(l.Kids) != len(r.Kids) { 498 | pErr(p, "bad `INDEX BY`: bad comparison") 499 | } 500 | 501 | for _, name := range l.Kids { 502 | if name.Type != QL_SYM { 503 | pErr(p, "bad `INDEX BY`: expect column name") 504 | } 505 | } 506 | node.Kids[0], node.Kids[1] = l, r 507 | } 508 | 509 | func pSelect(p *Parser) *QLSelect { 510 | stmt := QLSelect{} 511 | pSelectExprList(p, &stmt) 512 | 513 | pExpect(p, "from", "expect `FROM` table") 514 | stmt.Table = pMustSym(p) 515 | 516 | pScan(p, &stmt.QLScan) 517 | return &stmt 518 | } 519 | 520 | func pNameList(p *Parser) []string { 521 | pExpect(p, "(", "expect parenthesis") 522 | names := []string{pMustSym(p)} 523 | comma := pKeyword(p, "expect comma") 524 | 525 | for p.err == nil && !pKeyword(p, ")") { 526 | if !comma { 527 | pErr(p, "expect commma") 528 | } 529 | 530 | names = append(names, pMustSym(p)) 531 | comma = pKeyword(p, ")") 532 | } 533 | 534 | return names 535 | } 536 | 537 | func pValueList(p *Parser) []QLNODE { 538 | pExpect(p, "(", "expect value list") 539 | 540 | var vals []QLNODE 541 | comma := true 542 | for p.err == nil && !pKeyword(p, ")") { 543 | if !comma { 544 | pErr(p, "expect comma") 545 | } 546 | 547 | node := QLNODE{} 548 | pExprOr(p, &node) 549 | vals = append(vals, node) 550 | 
comma = pKeyword(p, ",") 551 | } 552 | 553 | return vals 554 | } 555 | 556 | func pStmt(p *Parser) (r interface{}) { 557 | switch { 558 | case pKeyword(p, "create", "table"): 559 | r = pCreateTable(p) 560 | case pKeyword(p, "select"): 561 | r = pSelect(p) 562 | case pKeyword(p, "insert", "into"): 563 | r = pInsert(p, btree.MODE_INSERT_ONLY) 564 | case pKeyword(p, "replace", "into"): 565 | r = pInsert(p, btree.MODE_UPDATE_ONLY) 566 | case pKeyword(p, "upsert", "into"): 567 | r = pInsert(p, btree.MODE_UPSERT) 568 | case pKeyword(p, "update"): 569 | r = pUpdate(p) 570 | case pKeyword(p, "delete", "from"): 571 | r = pDelete(p) 572 | 573 | default: 574 | pErr(p, "unknown stmt") 575 | } 576 | 577 | if p.err != nil { 578 | return nil 579 | } 580 | 581 | return r 582 | } 583 | 584 | func pDelete(p *Parser) *QLDelete { 585 | stmt := QLDelete{} 586 | stmt.Table = pMustSym(p) 587 | pScan(p, &stmt.QLScan) 588 | 589 | return &stmt 590 | } 591 | 592 | func pUpdate(p *Parser) *QLUPdate { 593 | stmt := QLUPdate{} 594 | stmt.Table = pMustSym(p) 595 | 596 | pExpect(p, "set", "expect `SET`") 597 | pAssign(p, &stmt) 598 | for pKeyword(p, ",") { 599 | pAssign(p, &stmt) 600 | } 601 | 602 | pScan(p, &stmt.QLScan) 603 | return &stmt 604 | } 605 | 606 | func pAssign(p *Parser, stmt *QLUPdate) { 607 | stmt.Names = append(stmt.Names, pMustSym(p)) 608 | pExpect(p, "=", "expect `=`") 609 | stmt.Values = append(stmt.Values, QLNODE{}) 610 | pExprOr(p, &stmt.Values[len(stmt.Values) - 1]) 611 | } 612 | 613 | func pInsert(p *Parser, mode int) *QLInsert { 614 | stmt := QLInsert{} 615 | stmt.Table = pMustSym(p) 616 | stmt.Mode = mode 617 | stmt.Names = pNameList(p) 618 | 619 | pExpect(p, "values", "expect `VALUES`") 620 | stmt.Values = append(stmt.Values, pValueList(p)) 621 | for pKeyword(p, ",") { 622 | stmt.Values = append(stmt.Values, pValueList(p)) 623 | } 624 | 625 | for _, row := range stmt.Values { 626 | if len(row) != len(stmt.Names) { 627 | pErr(p, "values length dont match") 628 | } 629 | } 630 | 631 | return &stmt 632 | } 633 | -------------------------------------------------------------------------------- /table/table.go: -------------------------------------------------------------------------------- 1 | package table 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "encoding/json" 7 | "fmt" 8 | "slices" 9 | "sync" 10 | 11 | "github.com/Adit0507/AdiDB/btree_iter" 12 | "github.com/Adit0507/AdiDB/btree" 13 | "github.com/Adit0507/AdiDB/kv" 14 | "github.com/Adit0507/AdiDB/transactions" 15 | ) 16 | 17 | const ( 18 | TYPE_ERROR = 0 19 | TYPE_BYTES = 1 20 | TYPE_INT64 = 2 21 | TYPE_INF = 0xff 22 | ) 23 | 24 | type DB struct { 25 | Path string 26 | kv kv.KV 27 | mu sync.Mutex 28 | tables map[string]*TableDef 29 | } 30 | 31 | type DBTX struct { 32 | kv transactions.KVTX 33 | db *DB 34 | } 35 | 36 | func (db *DB) Begin(tx *DBTX) { 37 | tx.db = db 38 | db.kv.Begin(&tx.kv) 39 | } 40 | 41 | func (db *DB) Commit(tx *DBTX) error { 42 | return db.kv.Commit(&tx.kv) 43 | } 44 | 45 | func (db *DB) Abort(tx *DBTX) { 46 | db.kv.Abort(&tx.kv) 47 | } 48 | 49 | func (tx *DBTX) Save(save *transactions.TXSave) { 50 | tx.kv.Save(save) 51 | } 52 | 53 | func (tx*DBTX) Revert(save *transactions.TXSave) { 54 | tx.kv.Revert(save) 55 | } 56 | 57 | type TableDef struct { 58 | Name string 59 | Types []uint32 //col type 60 | Cols []string //col name 61 | Prefixes []uint32 62 | Indexes [][]string 63 | } 64 | 65 | // table cell 66 | type Value struct { 67 | Type uint32 68 | I64 int64 69 | Str []byte 70 | } 71 | 72 | // represents a list of 
col names and values 73 | type Record struct { 74 | Cols []string 75 | Vals []Value 76 | } 77 | 78 | func (rec *Record) AddStr(col string, val []byte) *Record { 79 | rec.Cols = append(rec.Cols, col) 80 | rec.Vals = append(rec.Vals, Value{Type: TYPE_BYTES, Str: val}) 81 | 82 | return rec 83 | } 84 | 85 | func (rec *Record) AddInt64(col string, val int64) *Record { 86 | rec.Cols = append(rec.Cols, col) 87 | rec.Vals = append(rec.Vals, Value{Type: TYPE_INT64, I64: val}) 88 | 89 | return rec 90 | } 91 | func (rec *Record) Get(key string) *Value { 92 | for i, c := range rec.Cols { 93 | if c == key { 94 | return &rec.Vals[i] 95 | } 96 | } 97 | 98 | return nil 99 | } 100 | 101 | // INTERNAL TABLES 102 | // store metadata 103 | var TDEF_META = &TableDef{ 104 | Name: "@meta", 105 | Types: []uint32{TYPE_BYTES, TYPE_BYTES}, 106 | Cols: []string{"key", "val"}, 107 | Prefixes: []uint32{1}, 108 | Indexes: [][]string{{"key"}}, 109 | } 110 | 111 | // store table schemas 112 | var TDEF_TABLE = &TableDef{ 113 | Name: "@table", 114 | Types: []uint32{TYPE_BYTES, TYPE_BYTES}, 115 | Cols: []string{"name", "def"}, 116 | Prefixes: []uint32{2}, 117 | Indexes: [][]string{{"name"}}, 118 | } 119 | 120 | var INTERNAL_TABLES map[string]*TableDef = map[string]*TableDef{ 121 | "@meta": TDEF_META, 122 | "@table": TDEF_TABLE, 123 | } 124 | 125 | func assert(cond bool ){ 126 | if !cond { 127 | panic("assertion failure") 128 | } 129 | } 130 | 131 | // reorder records to defined col. order 132 | func reorderRecord(tdef *TableDef, rec Record) ([]Value, error) { 133 | assert(len(rec.Cols) == len(rec.Vals)) 134 | out := make([]Value, len(tdef.Cols)) 135 | for i, c := range tdef.Cols { 136 | v := rec.Get(c) 137 | if v == nil { 138 | continue 139 | } 140 | if v.Type != tdef.Types[i] { 141 | return nil, fmt.Errorf("bad column type: %s", c) 142 | } 143 | out[i] = *v 144 | } 145 | 146 | return out, nil 147 | } 148 | 149 | func valuesComplete(tdef *TableDef, vals []Value, n int) error { 150 | for i, v := range vals { 151 | if i < n && v.Type == 0 { 152 | return fmt.Errorf("missing column: %s", tdef.Cols[i]) 153 | } else if i >= n && v.Type != 0 { 154 | return fmt.Errorf("extra column: %s", tdef.Cols[i]) 155 | } 156 | } 157 | 158 | return nil 159 | } 160 | 161 | // escape null byte so string doesnt contain no null byte 162 | func escapeString(in []byte) []byte { 163 | toEscape := bytes.Count(in, []byte{0}) + bytes.Count(in, []byte{1}) 164 | if toEscape == 0 { 165 | return in 166 | } 167 | 168 | out := make([]byte, len(in)+toEscape) 169 | pos := 0 170 | for _, ch := range in { 171 | if ch <= 1 { 172 | out[pos+0] = 0x01 173 | out[pos+1] = ch + 1 174 | pos += 2 175 | } else { 176 | out[pos] = ch 177 | pos += 1 178 | } 179 | } 180 | return out 181 | } 182 | 183 | func unescapeString(in []byte) []byte { 184 | if bytes.Count(in, []byte{1}) == 0 { 185 | return in 186 | } 187 | 188 | out := make([]byte, 0, len(in)) 189 | for i := 0; i < len(in); i++ { 190 | if in[i] == 0x01 { 191 | // 01 01 -> 00 192 | i++ 193 | assert(in[i] == 1 || in[i] == 2) 194 | out = append(out, in[i]-1) 195 | } else { 196 | out = append(out, in[i]) 197 | } 198 | 199 | } 200 | 201 | return out 202 | } 203 | 204 | // order preserving encoding 205 | func encodeValues(out []byte, vals []Value) []byte { 206 | for _, v := range vals { 207 | out = append(out, byte(v.Type)) 208 | switch v.Type { 209 | case TYPE_INT64: 210 | var buf [8]byte 211 | u := uint64(v.I64) + (1 << 63) // flip the sign bit 212 | binary.BigEndian.PutUint64(buf[:], u) // big endian 213 | out = append(out, 
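// order-preserving trick for int64: flipping the sign bit maps the signed
// range onto unsigned order, and big-endian byte order makes bytes.Compare
// agree with numeric comparison. e.g. -1 encodes to 7f ff ff ff ff ff ff ff
// and 2 to 80 00 00 00 00 00 00 02, so -1 sorts before 2.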
buf[:]...) 214 | case TYPE_BYTES: 215 | out = append(out, escapeString(v.Str)...) 216 | out = append(out, 0) // null-terminated 217 | default: 218 | panic("what?") 219 | } 220 | } 221 | 222 | return out 223 | } 224 | 225 | // for input range, which can be prefix of index key 226 | func encodeKeyPartial(out []byte, prefix uint32, vals []Value, cmp int) []byte { 227 | out = encodeKey(out, prefix, vals) 228 | if cmp == btree_iter.CMP_GT || cmp == btree_iter.CMP_LE { 229 | out = append(out, 0xff) 230 | } 231 | return out 232 | } 233 | 234 | func encodeKey(out []byte, prefix uint32, vals []Value) []byte { 235 | var buf [4]byte 236 | binary.BigEndian.PutUint32(buf[:], prefix) 237 | out = append(out, buf[:]...) 238 | out = encodeValues(out, vals) 239 | 240 | return out 241 | } 242 | 243 | func decodeKey(in []byte, out []Value) { 244 | decodeValues(in[4:], out) 245 | } 246 | 247 | func decodeValues(in []byte, out []Value) { 248 | for i := range out { 249 | switch out[i].Type { 250 | case TYPE_INT64: 251 | u := binary.BigEndian.Uint64(in[:8]) 252 | out[i].I64 = int64(u - (1 << 63)) 253 | in = in[8:] 254 | case TYPE_BYTES: 255 | idx := bytes.IndexByte(in, 0) 256 | assert(idx >= 0) 257 | out[i].Str = unescapeString(in[:idx]) 258 | in = in[idx+1:] 259 | default: 260 | panic("what?") 261 | } 262 | } 263 | 264 | assert(len(in) == 0) 265 | } 266 | 267 | // check for missing columns 268 | func checkRecord(tdef *TableDef, rec Record, n int) ([]Value, error) { 269 | vals, err := reorderRecord(tdef, rec) 270 | if err != nil { 271 | return nil, err 272 | } 273 | 274 | err = valuesComplete(tdef, vals, n) 275 | if err != nil { 276 | return nil, err 277 | } 278 | return vals, nil 279 | } 280 | 281 | // extract multiple col. values 282 | func getValues(tdef *TableDef, rec Record, cols []string) ([]Value, error) { 283 | vals := make([]Value, len(cols)) 284 | for i, c := range cols { 285 | v := rec.Get(c) 286 | if v == nil { 287 | return nil, fmt.Errorf("missing col.: %s", tdef.Cols[i]) 288 | } 289 | 290 | if v.Type != tdef.Types[slices.Index(tdef.Cols, c)] { 291 | return nil, fmt.Errorf("bad column type: %s", c) 292 | } 293 | vals[i] = *v 294 | } 295 | return vals, nil 296 | } 297 | 298 | // get a single row by primary key 299 | func dbGet(tx *DBTX, tdef *TableDef, rec *Record) (bool, error) { 300 | vals, err := getValues(tdef, *rec, tdef.Indexes[0]) 301 | if err != nil { 302 | return false, err 303 | } 304 | 305 | //scan operation 306 | sc := Scanner{ 307 | Cmp1: btree_iter.CMP_GE, 308 | Cmp2: btree_iter.CMP_LE, 309 | Key1: Record{tdef.Indexes[0], vals}, 310 | Key2: Record{tdef.Indexes[0], vals}, 311 | } 312 | 313 | if err := dbScan(tx, tdef, &sc); err != nil || !sc.Valid() { 314 | return false, err 315 | } 316 | sc.Deref(rec) 317 | return true, nil 318 | } 319 | 320 | func (tx *DBTX) Get(table string, rec *Record) (bool, error) { 321 | tdef := getTableDef(tx, table) 322 | if tdef == nil { 323 | return false, fmt.Errorf("table not found: %s", table) 324 | } 325 | 326 | return dbGet(tx, tdef, rec) 327 | } 328 | 329 | const TABLE_PREFIX_MIN = 100 330 | 331 | func tableDefCheck(tdef *TableDef) error { 332 | // very table schema 333 | bad := tdef.Name == "" || len(tdef.Cols) == 0 334 | bad = bad || len(tdef.Cols) != len(tdef.Types) 335 | if bad { 336 | return fmt.Errorf("bad table schema: %s", tdef.Name) 337 | } 338 | 339 | // verifyin indexes 340 | for i, index := range tdef.Indexes { 341 | index, err := checkIndexCols(tdef, index) 342 | if err != nil { 343 | return err 344 | } 345 | tdef.Indexes[i] = index 346 | } 
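// note: checkIndexCols (below) also appends the primary-key columns to each
// secondary index, so an index entry alone is enough to locate the full row
// when Scanner.Deref follows it back to the primary key.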
347 | 348 | return nil 349 | } 350 | 351 | func checkIndexCols(tdef *TableDef, index []string) ([]string, error) { 352 | if len(index) == 0 { 353 | return nil, fmt.Errorf("empty index") 354 | } 355 | 356 | seen := map[string]bool{} 357 | for _, c := range index { 358 | // check index cols 359 | if slices.Index(tdef.Cols, c) < 0 { 360 | return nil, fmt.Errorf("unknown index column: %s", c) 361 | } 362 | if seen[c] { 363 | return nil, fmt.Errorf("duplicated column index: %s", c) 364 | } 365 | 366 | seen[c] = true 367 | } 368 | 369 | // addin primary key to index 370 | for _, c := range tdef.Indexes[0] { 371 | if !seen[c] { 372 | index = append(index, c) 373 | } 374 | } 375 | assert(len(index) <= len(tdef.Cols)) 376 | return index, nil 377 | } 378 | 379 | func (tx *DBTX) TableNew(tdef *TableDef) error { 380 | if err := tableDefCheck(tdef); err != nil { 381 | return err 382 | } 383 | 384 | // check existing table 385 | table := (&Record{}).AddStr("name", []byte(tdef.Name)) 386 | ok, err := dbGet(tx, TDEF_TABLE, table) 387 | assert(err == nil) 388 | if ok { 389 | return fmt.Errorf("table exists: %s", tdef.Name) 390 | } 391 | 392 | // alllocating new prefixes 393 | prefix := uint32(TABLE_PREFIX_MIN) 394 | meta := (&Record{}).AddStr("key", []byte("next_prefix")) 395 | ok, err = dbGet(tx, TDEF_META, meta) 396 | assert(err == nil) 397 | if ok { 398 | prefix = binary.LittleEndian.Uint32(meta.Get("val").Str) 399 | assert(prefix > TABLE_PREFIX_MIN) 400 | } else { 401 | meta.AddStr("val", make([]byte, 4)) 402 | } 403 | assert(len(tdef.Prefixes) == 0) 404 | for i := range tdef.Indexes { 405 | tdef.Prefixes = append(tdef.Prefixes, prefix+uint32(i)) 406 | } 407 | 408 | // updatin next prefix 409 | next := prefix + uint32(len(tdef.Prefixes)) 410 | binary.LittleEndian.PutUint32(meta.Get("val").Str, next) 411 | _, err = dbUpdate(tx, TDEF_META, &DBUpdateReq{Record: *meta}) 412 | if err != nil { 413 | return err 414 | } 415 | 416 | // storin schema 417 | val, err := json.Marshal(tdef) 418 | assert(err == nil) 419 | table.AddStr("def", val) 420 | _, err = dbUpdate(tx, TDEF_TABLE, &DBUpdateReq{Record: *table}) 421 | 422 | return err 423 | } 424 | 425 | // get table schema by naem 426 | func getTableDef(tx *DBTX, name string) *TableDef { 427 | if tdef, ok := INTERNAL_TABLES[name]; ok { 428 | return tdef // expose internal tables 429 | } 430 | tdef := tx.db.tables[name] 431 | if tdef == nil { 432 | if tdef = getTableDefDB(tx, name); tdef != nil { 433 | tx.db.tables[name] = tdef 434 | } 435 | } 436 | return tdef 437 | } 438 | 439 | func getTableDefDB(tx *DBTX, name string) *TableDef { 440 | rec := (&Record{}).AddStr("name", []byte(name)) 441 | ok, err := dbGet(tx, TDEF_TABLE, rec) 442 | assert(err == nil) 443 | if !ok { 444 | return nil 445 | } 446 | 447 | tdef := &TableDef{} 448 | err = json.Unmarshal(rec.Get("def").Str, tdef) 449 | assert(err == nil) 450 | 451 | return tdef 452 | } 453 | 454 | type DBUpdateReq struct { 455 | Record Record 456 | Mode int 457 | Updated bool 458 | Added bool 459 | } 460 | 461 | func nonPrimaryKeyCols(tdef *TableDef) (out []string) { 462 | for _, c := range tdef.Cols { 463 | if slices.Index(tdef.Indexes[0], c) < 0 { 464 | out = append(out, c) 465 | } 466 | } 467 | return 468 | } 469 | 470 | const ( 471 | INDEX_ADD = 1 472 | INDEX_DEL = 2 473 | ) 474 | 475 | // ADD OR REMOVE SECONDARY INDEX KEYS 476 | func indexOP(tx *DBTX, tdef *TableDef, op int, rec Record) error { 477 | for i := 1; i < len(tdef.Indexes); i++ { 478 | vals, err := getValues(tdef, rec, tdef.Indexes[i]) 479 | 
assert(err == nil) 480 | key := encodeKey(nil, tdef.Prefixes[i], vals) 481 | 482 | switch op { 483 | case INDEX_ADD: 484 | req := UpdateReq{Key: key, Val: nil} 485 | if _, err := tx.kv.Update(&req); err != nil { 486 | return err 487 | } 488 | assert(req.Added) // internal consistency 489 | case INDEX_DEL: 490 | deleted, err := tx.kv.Del(&DeleteReq{Key: key}) 491 | assert(err == nil) 492 | assert(deleted) 493 | default: 494 | panic("unreachable") 495 | } 496 | if err != nil { 497 | return err 498 | } 499 | } 500 | 501 | return nil 502 | } 503 | 504 | // add row to table 505 | func dbUpdate(tx *DBTX, tdef *TableDef, dbreq *DBUpdateReq) (bool, error) { 506 | cols := slices.Concat(tdef.Indexes[0], nonPrimaryKeyCols(tdef)) 507 | values, err := getValues(tdef, dbreq.Record, cols) 508 | if err != nil { 509 | return false, err 510 | } 511 | 512 | // insert row 513 | np := len(tdef.Indexes[0]) 514 | key := encodeKey(nil, tdef.Prefixes[0], values[:np]) 515 | val := encodeValues(nil, values[np:]) 516 | req := UpdateReq{Key: key, Val: val, Mode: dbreq.Mode} 517 | if _, err := tx.kv.Update(&req); err != nil { 518 | return false, err 519 | } 520 | 521 | dbreq.Added, dbreq.Updated = req.Added, req.Updated 522 | 523 | // maintain secondary indexes 524 | if req.Updated && !req.Added { 525 | decodeValues(req.Old, values[np:]) 526 | oldRec := Record{cols, values} 527 | // delete indexed keys 528 | err := indexOP(tx, tdef, INDEX_DEL, oldRec) 529 | assert(err == nil) 530 | } 531 | 532 | if req.Updated { 533 | if err = indexOP(tx, tdef, INDEX_ADD, dbreq.Record); err != nil { 534 | return false, err 535 | } 536 | } 537 | 538 | return req.Updated, nil 539 | } 540 | 541 | // addin a record 542 | func (tx *DBTX) Set(table string, dbreq *DBUpdateReq) (bool, error) { 543 | tdef := getTableDef(tx, table) 544 | if tdef == nil { 545 | return false, fmt.Errorf("table not found: %s", table) 546 | } 547 | 548 | return dbUpdate(tx, tdef, dbreq) 549 | } 550 | 551 | func (tx *DBTX) Insert(table string, rec Record) (bool, error) { 552 | return tx.Set(table, &DBUpdateReq{Record: rec, Mode: btree.MODE_INSERT_ONLY}) 553 | } 554 | func (tx *DBTX) Update(table string, rec Record) (bool, error) { 555 | return tx.Set(table, &DBUpdateReq{Record: rec, Mode: btree.MODE_UPDATE_ONLY}) 556 | } 557 | func (tx *DBTX) Upsert(table string, rec Record) (bool, error) { 558 | return tx.Set(table, &DBUpdateReq{Record: rec, Mode: btree.MODE_UPSERT}) 559 | } 560 | 561 | // delete a record by primary key 562 | func dbDelete(tx *DBTX, tdef *TableDef, rec Record) (bool, error) { 563 | vals, err := getValues(tdef, rec, tdef.Indexes[0]) 564 | if err != nil { 565 | return false, err 566 | } 567 | 568 | // delete row 569 | req := DeleteReq{Key: encodeKey(nil, tdef.Prefixes[0], vals)} 570 | if deleted, _ := tx.kv.Del(&req); !deleted { 571 | return false, nil 572 | } 573 | 574 | for _, c := range nonPrimaryKeyCols(tdef) { 575 | tp := tdef.Types[slices.Index(tdef.Cols, c)] 576 | vals = append(vals, Value{Type: tp}) 577 | } 578 | 579 | decodeValues(req.Old, vals[len(tdef.Indexes[0]):]) 580 | err = indexOP(tx, tdef, INDEX_DEL, Record{tdef.Cols, vals}) 581 | assert(err == nil) 582 | 583 | return true, nil 584 | } 585 | 586 | func (tx *DBTX) Delete(table string, rec Record) (bool, error) { 587 | tdef := getTableDef(tx, table) 588 | if tdef == nil { 589 | return false, fmt.Errorf("table not found: %s", table) 590 | } 591 | 592 | return dbDelete(tx, tdef, rec) 593 | } 594 | 595 | func (db *DB) Open() error { 596 | db.kv.Path = db.Path 597 | db.tables = 
map[string]*TableDef{} 598 | 599 | // opening kv store 600 | return db.kv.Open() 601 | } 602 | 603 | func (db *DB) Close() { 604 | db.kv.Close() 605 | } 606 | 607 | // scanner decodes KV's into rows 608 | // iterator for range queries 609 | // Scanner is a wrapper for B+ Tree iterator 610 | type Scanner struct { 611 | Cmp1 int 612 | Cmp2 int 613 | 614 | // range from Key 1 to key2 615 | Key1 Record 616 | Key2 Record 617 | 618 | // internal 619 | tx *DBTX 620 | index int 621 | tdef *TableDef 622 | iter transactions.KVIter 623 | keyEnd []byte 624 | } 625 | 626 | // within range or not 627 | func (sc *Scanner) Valid() bool { 628 | return sc.iter.Valid() 629 | } 630 | 631 | // movin underlying B+ tree iterator 632 | func (sc *Scanner) Next() { 633 | sc.iter.Next() 634 | } 635 | 636 | // return current row 637 | func (sc *Scanner) Deref(rec *Record) { 638 | assert(sc.Valid()) 639 | tdef := sc.tdef 640 | 641 | // fetch KV from iterator 642 | key, val := sc.iter.Deref() 643 | 644 | // prepare output record 645 | rec.Cols = slices.Concat(tdef.Indexes[0], nonPrimaryKeyCols(tdef)) 646 | rec.Vals = rec.Vals[:0] 647 | for _, c := range rec.Cols { 648 | tp := tdef.Types[slices.Index(tdef.Cols, c)] 649 | rec.Vals = append(rec.Vals, Value{Type: tp}) 650 | } 651 | 652 | if sc.index == 0 { 653 | // decode full row 654 | np := len(tdef.Indexes[0]) 655 | decodeKey(key, rec.Vals[:np]) 656 | decodeValues(val, rec.Vals[np:]) 657 | } else { 658 | // decode index key 659 | assert(len(val) == 0) 660 | index := tdef.Indexes[sc.index] 661 | irec := Record{index, make([]Value, len(index))} 662 | 663 | for i, c := range index { 664 | irec.Vals[i].Type = tdef.Types[slices.Index(tdef.Cols, c)] 665 | } 666 | decodeKey(key, irec.Vals) 667 | 668 | // extract primary key 669 | for i, c := range tdef.Indexes[0] { 670 | rec.Vals[i] = *irec.Get(c) 671 | } 672 | 673 | // fetch row by primary key 674 | ok, err := dbGet(sc.tx, tdef, rec) 675 | assert(ok && err == nil) 676 | } 677 | } 678 | 679 | // check col. 
types 680 | func checkTypes(tdef *TableDef, rec Record) error { 681 | if len(rec.Cols) != len(rec.Vals) { 682 | return fmt.Errorf("bad record") 683 | } 684 | 685 | for i, c := range rec.Cols { 686 | j := slices.Index(tdef.Cols, c) 687 | if j < 0 || tdef.Types[j] != rec.Vals[i].Type { 688 | return fmt.Errorf("bad column: %s", c) 689 | } 690 | } 691 | return nil 692 | } 693 | 694 | func dbScan(tx *DBTX, tdef *TableDef, req *Scanner) error { 695 | switch { 696 | case req.Cmp1 > 0 && req.Cmp2 < 0: 697 | case req.Cmp1 < 0 && req.Cmp2 > 0: 698 | default: 699 | return fmt.Errorf("bad range") 700 | } 701 | 702 | if err := checkTypes(tdef, req.Key1); err != nil { 703 | return err 704 | } 705 | if err := checkTypes(tdef, req.Key2); err != nil { 706 | return err 707 | } 708 | 709 | req.tx = tx 710 | req.tdef = tdef 711 | 712 | // select index 713 | isCovered := func(key []string,index []string) bool { 714 | return len(index) >= len(key) && slices.Equal(index[:len(key)], key) 715 | } 716 | 717 | req.index = slices.IndexFunc(tdef.Indexes, func (index []string)bool { 718 | return isCovered(req.Key1.Cols, index) && isCovered(req.Key2.Cols, index) 719 | }) 720 | if req.index < 0 { 721 | return fmt.Errorf("no index") 722 | } 723 | 724 | // encode start key 725 | prefix := tdef.Prefixes[req.index] 726 | keyStart := encodeKeyPartial(nil, prefix, req.Key1.Vals, req.Cmp1) 727 | keyEnd := encodeKeyPartial(nil, prefix, req.Key2.Vals, req.Cmp2) 728 | 729 | // seek to start key 730 | req.iter = tx.kv.Seek(keyStart, req.Cmp1, keyEnd, req.Cmp2) 731 | return nil 732 | } 733 | 734 | func (tx *DBTX) Scan(table string, req *Scanner) error { 735 | tdef := getTableDef(tx, table) 736 | if tdef == nil { 737 | return fmt.Errorf("table not found: %s", table) 738 | } 739 | 740 | return dbScan(tx, tdef, req) 741 | } 742 | -------------------------------------------------------------------------------- /table/table_test.go: -------------------------------------------------------------------------------- 1 | package table 2 | 3 | import ( 4 | "math" 5 | "os" 6 | "reflect" 7 | "sort" 8 | "testing" 9 | 10 | "github.com/Adit0507/AdiDB/btree_iter" 11 | is "github.com/stretchr/testify/require" 12 | ) 13 | 14 | type R struct { 15 | db DB 16 | ref map[string][]Record 17 | } 18 | 19 | func newR() *R { 20 | os.Remove("r.db") 21 | r := &R{ 22 | db: DB{Path: "r.db"}, 23 | ref: map[string][]Record{}, 24 | } 25 | err := r.db.Open() 26 | assert(err == nil) 27 | return r 28 | } 29 | 30 | func (r *R) dispose() { 31 | r.db.Close() 32 | os.Remove("r.db") 33 | } 34 | 35 | func (r *R) begin() *DBTX { 36 | tx := DBTX{} 37 | r.db.Begin(&tx) 38 | return &tx 39 | } 40 | 41 | func (r *R) commit(tx *DBTX) { 42 | err := r.db.Commit(tx) 43 | assert(err == nil) 44 | } 45 | 46 | func (r *R) create(tdef *TableDef) { 47 | tx := r.begin() 48 | err := tx.TableNew(tdef) 49 | r.commit(tx) 50 | assert(err == nil) 51 | } 52 | 53 | func (r *R) findRef(table string, rec Record) int { 54 | pkeys := len(r.db.tables[table].Indexes[0]) 55 | records := r.ref[table] 56 | found := -1 57 | for i, old := range records { 58 | if reflect.DeepEqual(old.Vals[:pkeys], rec.Vals[:pkeys]) { 59 | assert(found == -1) 60 | found = i 61 | } 62 | } 63 | return found 64 | } 65 | 66 | func (r *R) add(table string, rec Record) bool { 67 | tx := r.begin() 68 | dbreq := DBUpdateReq{Record: rec} 69 | _, err := tx.Set(table, &dbreq) 70 | assert(err == nil) 71 | r.commit(tx) 72 | 73 | records := r.ref[table] 74 | idx := r.findRef(table, rec) 75 | assert((idx < 0) == dbreq.Added) 76 | if idx < 0 { 
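// a new primary key: the engine must have reported Added (asserted above),
// so the record is appended to the reference slice; otherwise the existing
// entry is overwritten in place below.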
77 | r.ref[table] = append(records, rec) 78 | } else { 79 | records[idx] = rec 80 | } 81 | return dbreq.Added 82 | } 83 | 84 | func (r *R) del(table string, rec Record) bool { 85 | tx := r.begin() 86 | deleted, err := tx.Delete(table, rec) 87 | assert(err == nil) 88 | r.commit(tx) 89 | 90 | idx := r.findRef(table, rec) 91 | if deleted { 92 | assert(idx >= 0) 93 | records := r.ref[table] 94 | copy(records[idx:], records[idx+1:]) 95 | r.ref[table] = records[:len(records)-1] 96 | } else { 97 | assert(idx == -1) 98 | } 99 | 100 | return deleted 101 | } 102 | 103 | func (r *R) get(table string, rec *Record) bool { 104 | tx := r.begin() 105 | ok, err := tx.Get(table, rec) 106 | assert(err == nil) 107 | r.commit(tx) 108 | idx := r.findRef(table, *rec) 109 | if ok { 110 | assert(idx >= 0) 111 | records := r.ref[table] 112 | assert(reflect.DeepEqual(records[idx], *rec)) 113 | } else { 114 | assert(idx < 0) 115 | } 116 | return ok 117 | } 118 | 119 | func TestTableCreate(t *testing.T) { 120 | r := newR() 121 | tdef := &TableDef{ 122 | Name: "tbl_test", 123 | Cols: []string{"ki1", "ks2", "s1", "i2"}, 124 | Types: []uint32{TYPE_INT64, TYPE_BYTES, TYPE_BYTES, TYPE_INT64}, 125 | Indexes: [][]string{{"ki1", "ks2"}}, 126 | } 127 | r.create(tdef) 128 | 129 | tdef = &TableDef{ 130 | Name: "tbl_test2", 131 | Cols: []string{"ki1", "ks2"}, 132 | Types: []uint32{TYPE_INT64, TYPE_BYTES}, 133 | Indexes: [][]string{{"ki1", "ks2"}}, 134 | } 135 | r.create(tdef) 136 | 137 | tx := r.begin() 138 | { 139 | rec := (&Record{}).AddStr("key", []byte("next_prefix")) 140 | ok, err := tx.Get("@meta", rec) 141 | assert(ok && err == nil) 142 | is.Equal(t, []byte{102, 0, 0, 0}, rec.Get("val").Str) 143 | } 144 | { 145 | rec := (&Record{}).AddStr("name", []byte("tbl_test")) 146 | ok, err := tx.Get("@table", rec) 147 | assert(ok && err == nil) 148 | expected := `{"Name":"tbl_test","Types":[2,1,1,2],"Cols":["ki1","ks2","s1","i2"],"Indexes":[["ki1","ks2"]],"Prefixes":[100]}` 149 | is.Equal(t, expected, string(rec.Get("def").Str)) 150 | } 151 | r.commit(tx) 152 | 153 | r.dispose() 154 | } 155 | 156 | func TestTableBasic(t *testing.T) { 157 | r := newR() 158 | tdef := &TableDef{ 159 | Name: "tbl_test", 160 | Cols: []string{"ki1", "ks2", "s1", "i2"}, 161 | Types: []uint32{TYPE_INT64, TYPE_BYTES, TYPE_BYTES, TYPE_INT64}, 162 | Indexes: [][]string{{"ki1", "ks2"}}, 163 | } 164 | r.create(tdef) 165 | 166 | rec := Record{} 167 | rec.AddInt64("ki1", 1).AddStr("ks2", []byte("hello")) 168 | rec.AddStr("s1", []byte("world")).AddInt64("i2", 2) 169 | added := r.add("tbl_test", rec) 170 | is.True(t, added) 171 | 172 | { 173 | got := Record{} 174 | got.AddInt64("ki1", 1).AddStr("ks2", []byte("hello")) 175 | ok := r.get("tbl_test", &got) 176 | is.True(t, ok) 177 | } 178 | { 179 | got := Record{} 180 | got.AddInt64("ki1", 1).AddStr("ks2", []byte("hello2")) 181 | ok := r.get("tbl_test", &got) 182 | is.False(t, ok) 183 | } 184 | 185 | rec.Get("s1").Str = []byte("www") 186 | added = r.add("tbl_test", rec) 187 | is.False(t, added) 188 | 189 | { 190 | got := Record{} 191 | got.AddInt64("ki1", 1).AddStr("ks2", []byte("hello")) 192 | ok := r.get("tbl_test", &got) 193 | is.True(t, ok) 194 | } 195 | 196 | { 197 | key := Record{} 198 | key.AddInt64("ki1", 1).AddStr("ks2", []byte("hello2")) 199 | deleted := r.del("tbl_test", key) 200 | is.False(t, deleted) 201 | 202 | key.Get("ks2").Str = []byte("hello") 203 | deleted = r.del("tbl_test", key) 204 | is.True(t, deleted) 205 | } 206 | 207 | r.dispose() 208 | } 209 | 210 | func TestStringEscape(t *testing.T) { 211 | 
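// escapeString maps 0x00 -> 0x01 0x01 and 0x01 -> 0x01 0x02 so encoded
// strings never contain a raw null byte (which terminates TYPE_BYTES keys);
// the vectors below cover exactly these cases plus the empty string.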
in := [][]byte{ 212 | {}, 213 | {0}, 214 | {1}, 215 | } 216 | out := [][]byte{ 217 | {}, 218 | {1, 1}, 219 | {1, 2}, 220 | } 221 | for i, s := range in { 222 | b := escapeString(s) 223 | is.Equal(t, out[i], b) 224 | s2 := unescapeString(b) 225 | is.Equal(t, s, s2) 226 | } 227 | } 228 | 229 | func TestTableEncoding(t *testing.T) { 230 | input := []int{-1, 0, +1, math.MinInt64, math.MaxInt64} 231 | sort.Ints(input) 232 | 233 | encoded := []string{} 234 | for _, i := range input { 235 | v := Value{Type: TYPE_INT64, I64: int64(i)} 236 | b := encodeValues(nil, []Value{v}) 237 | out := []Value{v} 238 | decodeValues(b, out) 239 | assert(out[0].I64 == int64(i)) 240 | encoded = append(encoded, string(b)) 241 | } 242 | 243 | is.True(t, sort.StringsAreSorted(encoded)) 244 | } 245 | 246 | func TestTableScan(t *testing.T) { 247 | r := newR() 248 | tdef := &TableDef{ 249 | Name: "tbl_test", 250 | Cols: []string{"ki1", "ks2", "s1", "i2"}, 251 | Types: []uint32{TYPE_INT64, TYPE_BYTES, TYPE_BYTES, TYPE_INT64}, 252 | Indexes: [][]string{ 253 | {"ki1", "ks2"}, 254 | {"i2"}, 255 | }, 256 | } 257 | r.create(tdef) 258 | 259 | size := 100 260 | for i := 0; i < size; i += 2 { 261 | rec := Record{} 262 | rec.AddInt64("ki1", int64(i)).AddStr("ks2", []byte("hello")) 263 | rec.AddStr("s1", []byte("world")).AddInt64("i2", int64(i/2)) 264 | added := r.add("tbl_test", rec) 265 | assert(added) 266 | } 267 | 268 | // full table scan without a key 269 | tx := r.begin() 270 | { 271 | rec := Record{} // empty 272 | req := Scanner{ 273 | Cmp1: btree_iter.CMP_GE, Cmp2: btree_iter.CMP_LE, 274 | Key1: rec, Key2: rec, 275 | } 276 | err := tx.Scan("tbl_test", &req) 277 | assert(err == nil) 278 | 279 | got := []Record{} 280 | for req.Valid() { 281 | rec := Record{} 282 | req.Deref(&rec) 283 | got = append(got, rec) 284 | req.Next() 285 | } 286 | is.Equal(t, r.ref["tbl_test"], got) 287 | } 288 | r.commit(tx) 289 | 290 | tmpkey := func(n int) Record { 291 | rec := Record{} 292 | rec.AddInt64("ki1", int64(n)) // partial primary key 293 | return rec 294 | } 295 | i2key := func(n int) Record { 296 | rec := Record{} 297 | rec.AddInt64("i2", int64(n)/2) // secondary index 298 | return rec 299 | } 300 | 301 | tx = r.begin() 302 | for i := 0; i < size; i += 2 { 303 | ref := []int64{} 304 | for j := i; j < size; j += 2 { 305 | ref = append(ref, int64(j)) 306 | 307 | scanners := []Scanner{ 308 | { 309 | Cmp1: btree_iter.CMP_GE, 310 | Cmp2: btree_iter.CMP_LE, 311 | Key1: tmpkey(i), 312 | Key2: tmpkey(j), 313 | }, 314 | { 315 | Cmp1: btree_iter.CMP_GE, 316 | Cmp2: btree_iter.CMP_LE, 317 | Key1: tmpkey(i - 1), 318 | Key2: tmpkey(j + 1), 319 | }, 320 | { 321 | Cmp1: btree_iter.CMP_GT, 322 | Cmp2: btree_iter.CMP_LT, 323 | Key1: tmpkey(i - 1), 324 | Key2: tmpkey(j + 1), 325 | }, 326 | { 327 | Cmp1: btree_iter.CMP_GT, 328 | Cmp2: btree_iter.CMP_LT, 329 | Key1: tmpkey(i - 2), 330 | Key2: tmpkey(j + 2), 331 | }, 332 | { 333 | Cmp1: btree_iter.CMP_GE, 334 | Cmp2: btree_iter.CMP_LE, 335 | Key1: i2key(i), 336 | Key2: i2key(j), 337 | }, 338 | { 339 | Cmp1: btree_iter.CMP_GT, 340 | Cmp2: btree_iter.CMP_LT, 341 | Key1: i2key(i - 2), 342 | Key2: i2key(j + 2), 343 | }, 344 | } 345 | for _, tmp := range scanners { 346 | tmp.Cmp1, tmp.Cmp2 = tmp.Cmp2, tmp.Cmp1 347 | tmp.Key1, tmp.Key2 = tmp.Key2, tmp.Key1 348 | scanners = append(scanners, tmp) 349 | } 350 | 351 | for _, sc := range scanners { 352 | err := tx.Scan("tbl_test", &sc) 353 | assert(err == nil) 354 | 355 | keys := []int64{} 356 | got := Record{} 357 | for sc.Valid() { 358 | sc.Deref(&got) 359 | keys = 
append(keys, got.Get("ki1").I64) 360 | sc.Next() 361 | } 362 | if sc.Cmp1 < sc.Cmp2 { 363 | // reverse 364 | for a := 0; a < len(keys)/2; a++ { 365 | b := len(keys) - 1 - a 366 | keys[a], keys[b] = keys[b], keys[a] 367 | } 368 | } 369 | 370 | is.Equal(t, ref, keys) 371 | } // scanners 372 | } // j 373 | } // i 374 | r.commit(tx) 375 | 376 | r.dispose() 377 | } 378 | 379 | func TestTableIndex(t *testing.T) { 380 | r := newR() 381 | tdef := &TableDef{ 382 | Name: "tbl_test", 383 | Cols: []string{"ki1", "ks2", "s1", "i2"}, 384 | Types: []uint32{TYPE_INT64, TYPE_BYTES, TYPE_BYTES, TYPE_INT64}, 385 | Indexes: [][]string{ 386 | {"ki1", "ks2"}, 387 | {"ks2", "ki1"}, 388 | {"i2"}, 389 | {"ki1", "i2"}, 390 | }, 391 | } 392 | r.create(tdef) 393 | 394 | record := func(ki1 int64, ks2 string, s1 string, i2 int64) Record { 395 | rec := Record{} 396 | rec.AddInt64("ki1", ki1).AddStr("ks2", []byte(ks2)) 397 | rec.AddStr("s1", []byte(s1)).AddInt64("i2", i2) 398 | return rec 399 | } 400 | 401 | r1 := record(1, "a1", "v1", 2) 402 | r2 := record(2, "a2", "v2", -2) 403 | r.add("tbl_test", r1) 404 | r.add("tbl_test", r2) 405 | 406 | tx := r.begin() 407 | { 408 | rec := Record{} 409 | rec.AddInt64("i2", 2) 410 | req := Scanner{ 411 | Cmp1: btree_iter.CMP_GE, Cmp2: btree_iter.CMP_LE, 412 | Key1: rec, Key2: rec, 413 | } 414 | err := tx.Scan("tbl_test", &req) 415 | assert(err == nil) 416 | is.True(t, req.Valid()) 417 | 418 | out := Record{} 419 | req.Deref(&out) 420 | is.Equal(t, r1, out) 421 | 422 | req.Next() 423 | is.False(t, req.Valid()) 424 | } 425 | r.commit(tx) 426 | 427 | tx = r.begin() 428 | { 429 | rec1 := Record{} 430 | rec1.AddInt64("i2", 2) 431 | rec2 := Record{} 432 | rec2.AddInt64("i2", 4) 433 | req := Scanner{ 434 | Cmp1: btree_iter.CMP_GT, Cmp2: btree_iter.CMP_LE, 435 | Key1: rec1, Key2: rec2, 436 | } 437 | err := tx.Scan("tbl_test", &req) 438 | assert(err == nil) 439 | is.False(t, req.Valid()) 440 | } 441 | r.commit(tx) 442 | 443 | r.add("tbl_test", record(1, "a1", "v1", 1)) 444 | tx = r.begin() 445 | { 446 | rec := Record{} 447 | rec.AddInt64("i2", 2) 448 | req := Scanner{ 449 | Cmp1: btree_iter.CMP_GE, Cmp2: btree_iter.CMP_LE, 450 | Key1: rec, Key2: rec, 451 | } 452 | err := tx.Scan("tbl_test", &req) 453 | assert(err == nil) 454 | is.False(t, req.Valid()) 455 | } 456 | r.commit(tx) 457 | 458 | tx = r.begin() 459 | { 460 | rec := Record{} 461 | rec.AddInt64("i2", 1) 462 | req := Scanner{ 463 | Cmp1: btree_iter.CMP_GE, Cmp2: btree_iter.CMP_LE, 464 | Key1: rec, Key2: rec, 465 | } 466 | err := tx.Scan("tbl_test", &req) 467 | assert(err == nil) 468 | is.True(t, req.Valid()) 469 | } 470 | r.commit(tx) 471 | 472 | { 473 | rec := Record{} 474 | rec.AddInt64("ki1", 1).AddStr("ks2", []byte("a1")) 475 | ok := r.del("tbl_test", rec) 476 | assert(ok) 477 | } 478 | 479 | tx = r.begin() 480 | { 481 | rec := Record{} 482 | rec.AddInt64("i2", 1) 483 | req := Scanner{ 484 | Cmp1: btree_iter.CMP_GE, Cmp2: btree_iter.CMP_LE, 485 | Key1: rec, Key2: rec, 486 | } 487 | err := tx.Scan("tbl_test", &req) 488 | assert(err == nil) 489 | is.False(t, req.Valid()) 490 | } 491 | r.commit(tx) 492 | 493 | r.dispose() 494 | } 495 | -------------------------------------------------------------------------------- /transactions/tx.go: -------------------------------------------------------------------------------- 1 | package transactions 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "runtime" 7 | "slices" 8 | 9 | "github.com/Adit0507/AdiDB/btree" 10 | "github.com/Adit0507/AdiDB/btree_iter" 11 | "github.com/Adit0507/AdiDB/kv" 12 | ) 
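// Design note: a KVTX implements snapshot isolation with optimistic concurrency
// control. Begin copies the current tree root into a read-only snapshot and
// buffers all writes in an in-memory "pending" B+tree; every read records a key
// range in reads. Commit replays the pending entries onto the main tree only
// after detectConflicts has checked the recorded reads against the write sets of
// transactions committed since this transaction's version; any overlap fails the
// commit with ErrorConflict, and Abort simply discards the pending tree.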
13 | 14 | type KVTX struct { 15 | snapshot btree.BTree // read-only snapshot (copy-on-write) 16 | version uint64 17 | // local updates are held in an in-memory B+tree 18 | pending btree.BTree //captured KV updates 19 | reads []KeyRange //list of involved key intervals for detecting conflicts 20 | // checks for conflicts even if the update changes nothing 21 | updateAttempted bool 22 | done bool 23 | } 24 | 25 | // start <= key <= stop 26 | type KeyRange struct { 27 | start []byte 28 | stop []byte 29 | } 30 | 31 | type TXSave struct { 32 | root uint64 33 | reads []KeyRange 34 | } 35 | 36 | func (tx *KVTX) Save(save *TXSave) { 37 | save.root, save.reads = tx.pending.root, tx.reads 38 | } 39 | 40 | func (tx *KVTX) Revert(save *TXSave) { 41 | tx.pending.root, tx.reads = save.root, save.reads 42 | } 43 | 44 | const ( 45 | FLAG_DELETED = byte(1) 46 | FLAG_UPDATED = byte(2) 47 | ) 48 | 49 | func assert(cond bool) { 50 | if !cond { 51 | panic("assertion failure") 52 | } 53 | } 54 | 55 | // begin a transaction 56 | func (kv KVWrap) Begin(tx *KVTX) { 57 | kv.mutex.Lock() 58 | defer kv.mutex.Unlock() 59 | 60 | tx.snapshot.root = kv.tree.root 61 | chunks := kv.mmap.chunks 62 | tx.snapshot.get = func(ptr uint64) []byte { return mmapRead(ptr, chunks) } 63 | tx.version = kv.version 64 | 65 | // in-memory tree to capture updates 66 | pages := [][]byte(nil) 67 | tx.pending.get = func(ptr uint64) []byte { return pages[ptr-1] } 68 | tx.pending.new = func(b []byte) uint64 { 69 | pages = append(pages, b) 70 | return uint64(len(pages)) 71 | } 72 | tx.pending.del = func(uint64) {} 73 | // keeping track of concurrent TXs 74 | kv.ongoing = append(kv.ongoing, tx.version) 75 | runtime.SetFinalizer(tx, func(tx *KVTX) { assert(tx.done) }) 76 | } 77 | 78 | // commit captured updates; rolls back on error 79 | func (kv KVWrap) Commit(tx *KVTX) error { 80 | assert(!tx.done) 81 | tx.done = true 82 | kv.mutex.Lock() 83 | defer kv.mutex.Unlock() 84 | 85 | // check conflicts 86 | if tx.updateAttempted && detectConflicts(kv, tx) { 87 | return ErrorConflict 88 | } 89 | 90 | // save meta page 91 | meta, root := saveMeta(kv), kv.tree.root 92 | kv.free.curVer = kv.version + 1 //transfer current updates to the current tree 93 | writes := []KeyRange(nil) 94 | for iter := tx.pending.Seek(nil, btree_iter.CMP_GT); iter.Valid(); iter.Next() { 95 | modified := false 96 | key, val := iter.Deref() 97 | oldVal, isOld := tx.snapshot.Get(key) 98 | switch val[0] { 99 | case FLAG_DELETED: 100 | modified = isOld 101 | deleted, err := kv.tree.Delete(&DeleteReq{Key: key}) 102 | assert(err == nil) // can only fail by length limit 103 | assert(deleted == modified) // assured by conflict detection 104 | 105 | case FLAG_UPDATED: 106 | modified = (!isOld || !bytes.Equal(oldVal, val[1:])) 107 | updated, err := kv.tree.Update(&UpdateReq{Key: key, Val: val[1:]}) 108 | assert(err == nil) 109 | assert(updated == modified) 110 | 111 | default: 112 | panic("unreachable") 113 | } 114 | 115 | if modified && len(kv.ongoing) > 1 { 116 | writes = append(writes, KeyRange{key, key}) 117 | } 118 | } 119 | 120 | // committing the update 121 | if root != kv.tree.root { 122 | kv.version++ 123 | if err := updateOrRevert(kv, meta); err != nil { 124 | return err 125 | } 126 | } 127 | 128 | if len(writes) > 0 { 129 | slices.SortFunc(writes, func(r1, r2 KeyRange) int { 130 | return bytes.Compare(r1.start, r2.start) 131 | }) 132 | kv.history = append(kv.history, CommittedTX{kv.version, writes}) 133 | } 134 | return nil 135 | } 136 | 137 | type KVWrap struct { 138 | *kv.KV 139 | } 140 | 141 | // end the transaction without committing 142 | func 
(kv KVWrap) Abort(tx *KVTX) { 143 | assert(!tx.done) 144 | tx.done = true 145 | 146 | kv.mutex.Lock() 147 | txFinalize(kv, tx) 148 | kv.mutex.Unlock() 149 | } 150 | 151 | var ErrorConflict = errors.New("cannot commit due to conflict") 152 | 153 | func detectConflicts(kv KVWrap, tx *KVTX) bool { 154 | slices.SortFunc(tx.reads, func(r1, r2 KeyRange) int { 155 | return bytes.Compare(r1.start, r2.start) 156 | }) 157 | 158 | for i := len(kv.history) - 1; i >= 0; i-- { 159 | if !versionBefore(tx.version, kv.history[i].version) { 160 | break 161 | } 162 | if sortedRangesOverlap(tx.reads, kv.history[i].writes) { 163 | return true 164 | } 165 | } 166 | 167 | return false 168 | } 169 | 170 | func sortedRangesOverlap(s1, s2 []KeyRange) bool { 171 | for len(s1) > 0 && len(s2) > 0 { 172 | if bytes.Compare(s1[0].stop, s2[0].start) < 0 { 173 | s1 = s1[1:] 174 | } else if bytes.Compare(s2[0].stop, s1[0].start) < 0 { 175 | s2 = s2[1:] 176 | } else { 177 | return true 178 | } 179 | } 180 | 181 | return false 182 | } 183 | 184 | // routines when exiting a transaction 185 | func txFinalize(kv KVWrap, tx *KVTX) { 186 | idx := slices.Index(kv.ongoing, tx.version) 187 | last := len(kv.ongoing) - 1 188 | kv.ongoing[idx], kv.ongoing = kv.ongoing[last], kv.ongoing[:last] 189 | 190 | // oldest in-use version 191 | minVer := kv.version 192 | for _, other := range kv.ongoing { 193 | if versionBefore(other, minVer) { 194 | minVer = other 195 | } 196 | } 197 | 198 | // release free list 199 | kv.free.SetMaxVer(minVer) 200 | 201 | for idx = 0; idx < len(kv.history); idx++ { 202 | if versionBefore(minVer, kv.history[idx].version) { 203 | break 204 | } 205 | } 206 | 207 | kv.history = kv.history[idx:] 208 | } 209 | 210 | // KV interfaces 211 | type KVIter interface { 212 | Deref() (key []byte, val []byte) 213 | Valid() bool 214 | Next() 215 | } 216 | 217 | // combines pending updates and the snapshot 218 | type CombinedIterator struct { 219 | top *btree_iter.BIter //kvtx pending 220 | bot *btree_iter.BIter //kvtx snapshot 221 | dir int //+1 for greater-or-equal/greater-than (ascending), -1 for less-or-equal/less-than (descending) 222 | 223 | //end of range 224 | cmp int 225 | end []byte 226 | } 227 | 228 | func (iter *CombinedIterator) Deref() ([]byte, []byte) { 229 | var k1, k2, v1, v2 []byte 230 | top, bot := iter.top.Valid(), iter.bot.Valid() 231 | assert(top || bot) 232 | if top { 233 | k1, v1 = iter.top.Deref() 234 | } 235 | if bot { 236 | k2, v2 = iter.bot.Deref() 237 | } 238 | 239 | // use the min (ascending) or max (descending) key of the two 240 | if top && bot && bytes.Compare(k1, k2) == +iter.dir { 241 | return k2, v2 242 | } 243 | if top { 244 | return k1, v1[1:] 245 | } else { 246 | return k2, v2 247 | } 248 | } 249 | 250 | func (iter *CombinedIterator) Valid() bool { 251 | if iter.top.Valid() || iter.bot.Valid() { 252 | key, _ := iter.Deref() 253 | return cmpOk(key, iter.cmp, iter.end) 254 | } 255 | 256 | return false 257 | } 258 | 259 | func (iter *CombinedIterator) Next() { 260 | top, bot := iter.top.Valid(), iter.bot.Valid() 261 | if top && bot { 262 | k1, _ := iter.top.Deref() 263 | k2, _ := iter.bot.Deref() 264 | 265 | switch bytes.Compare(k1, k2) { 266 | case -iter.dir: 267 | top, bot = true, false 268 | case +iter.dir: 269 | top, bot = false, true 270 | case 0: // equal; move both 271 | } 272 | } 273 | 274 | assert(top || bot) 275 | if top { 276 | if iter.dir > 0 { 277 | iter.top.Next() 278 | } else { 279 | iter.top.Prev() 280 | } 281 | } 282 | 283 | if bot { 284 | if iter.dir > 0 { 285 | iter.bot.Next() 286 | } else { 287 | iter.bot.Prev() 288 | } 289 | } 290 | 291 | 
} 292 | 293 | func cmp2Dir(cmp int) int { 294 | if cmp > 0 { 295 | return +1 296 | } else { 297 | return -1 298 | } 299 | } 300 | 301 | // range query; combines captured updates with the snapshot 302 | func (tx *KVTX) Seek(key1 []byte, cmp1 int, key2 []byte, cmp2 int) KVIter { 303 | assert(cmp2Dir(cmp1) != cmp2Dir(cmp2)) 304 | lo, hi := key1, key2 305 | if cmp2Dir(cmp1) < 0 { 306 | lo, hi = hi, lo 307 | } 308 | tx.reads = append(tx.reads, KeyRange{lo, hi}) 309 | 310 | return &CombinedIterator{ 311 | top: tx.pending.Seek(key1, cmp1), 312 | bot: tx.snapshot.Seek(key1, cmp1), 313 | dir: cmp2Dir(cmp1), 314 | cmp: cmp2, 315 | end: key2, 316 | } 317 | } 318 | 319 | func (tx *KVTX) Update(req *UpdateReq) (bool, error) { 320 | tx.updateAttempted = true 321 | 322 | old, exists := tx.Get(req.Key) 323 | if req.Mode == btree.MODE_UPDATE_ONLY && !exists { 324 | return false, nil 325 | } 326 | if req.Mode == btree.MODE_INSERT_ONLY && exists { 327 | return false, nil 328 | } 329 | if exists && bytes.Equal(old, req.Val) { 330 | return false, nil 331 | } 332 | 333 | flaggedVal := append([]byte{FLAG_UPDATED}, req.Val...) 334 | _, err := tx.pending.Update(&UpdateReq{Key: req.Key, Val: flaggedVal}) 335 | if err != nil { 336 | return false, err 337 | } 338 | 339 | req.Added = !exists 340 | req.Updated = true 341 | req.Old = old 342 | 343 | return true, nil 344 | } 345 | 346 | func (tx *KVTX) Del(req *DeleteReq) (bool, error) { 347 | tx.updateAttempted = true 348 | exists := false 349 | if req.Old, exists = tx.Get(req.Key); !exists { 350 | return false, nil 351 | } 352 | 353 | return tx.pending.Update(&UpdateReq{Key: req.Key, Val: []byte{FLAG_DELETED}}) 354 | } 355 | 356 | func (tx *KVTX) Set(key []byte, val []byte) (bool, error) { 357 | return tx.Update(&UpdateReq{Key: key, Val: val}) 358 | } 359 | 360 | // point query; combines captured updates with the snapshot 361 | func (tx *KVTX) Get(key []byte) ([]byte, bool) { 362 | tx.reads = append(tx.reads, KeyRange{key, key}) 363 | val, ok := tx.pending.Get(key) 364 | 365 | switch { 366 | case ok && val[0] == FLAG_UPDATED: // updated in this TX 367 | return val[1:], true 368 | case ok && val[0] == FLAG_DELETED: // deleted in this TX 369 | return nil, false 370 | case !ok: 371 | return tx.snapshot.Get(key) 372 | 373 | default: 374 | panic("unreachable") 375 | } 376 | } 377 | -------------------------------------------------------------------------------- /transactions/tx_test.go: -------------------------------------------------------------------------------- 1 | package transactions 2 | 3 | import ( 4 | "fmt" 5 | "slices" 6 | "sort" 7 | "testing" 8 | 9 | is "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func TestKVTXSequential(t *testing.T) { 13 | d := newD() 14 | 15 | d.add("k1", "v1") 16 | 17 | tx := KVTX{} 18 | d.db.Begin(&tx) 19 | 20 | tx.Update(&UpdateReq{Key: []byte("k1"), Val: []byte("xxx")}) 21 | tx.Update(&UpdateReq{Key: []byte("k2"), Val: []byte("xxx")}) 22 | 23 | val, ok := tx.Get([]byte("k1")) 24 | is.True(t, ok) 25 | is.Equal(t, []byte("xxx"), val) 26 | val, ok = tx.Get([]byte("k2")) 27 | is.True(t, ok) 28 | is.Equal(t, []byte("xxx"), val) 29 | 30 | d.db.Abort(&tx) 31 | 32 | d.verify(t) 33 | 34 | d.reopen() 35 | d.verify(t) 36 | { 37 | tx := KVTX{} 38 | d.db.Begin(&tx) 39 | _, ok = tx.Get([]byte("k2")) 40 | is.False(t, ok) 41 | d.db.Abort(&tx) 42 | } 43 | 44 | d.dispose() 45 | } 46 | 47 | func TestKVTXInterleave(t *testing.T) { 48 | d := newD() 49 | 50 | { 51 | tx1, tx2 := KVTX{}, KVTX{} 52 | d.db.Begin(&tx1) 53 | d.db.Begin(&tx2) 54 | 
tx1.Set([]byte("k1"), []byte("v1")) 55 | tx2.Set([]byte("k2"), []byte("v2")) 56 | 57 | val, ok := tx1.Get([]byte("k1")) 58 | assert(ok && string(val) == "v1") // read its own uncommitted write 59 | val, ok = tx2.Get([]byte("k2")) 60 | assert(ok && string(val) == "v2") 61 | 62 | err := d.db.Commit(&tx1) 63 | assert(err == nil) 64 | 65 | err = d.db.Commit(&tx2) 66 | assert(err == nil) 67 | assert(len(d.db.ongoing)+len(d.db.history) == 0) 68 | } 69 | 70 | { 71 | tx1, tx2 := KVTX{}, KVTX{} 72 | d.db.Begin(&tx1) 73 | d.db.Begin(&tx2) 74 | tx1.Set([]byte("k1"), []byte("v2")) 75 | err := d.db.Commit(&tx1) 76 | assert(err == nil) 77 | 78 | val, ok := tx2.Get([]byte("k1")) 79 | assert(ok && string(val) == "v1") // isolation 80 | 81 | err = d.db.Commit(&tx2) 82 | assert(err == nil) // read-only 83 | assert(len(d.db.ongoing)+len(d.db.history) == 0) 84 | } 85 | 86 | { 87 | tx1, tx2 := KVTX{}, KVTX{} 88 | d.db.Begin(&tx1) 89 | d.db.Begin(&tx2) 90 | tx1.Set([]byte("k1"), []byte("v3")) 91 | err := d.db.Commit(&tx1) 92 | assert(err == nil) 93 | 94 | val, ok := tx2.Get([]byte("k1")) 95 | assert(ok && string(val) == "v2") 96 | tx2.Set([]byte("k2"), val) 97 | 98 | err = d.db.Commit(&tx2) 99 | assert(err == ErrorConflict) // read conflict 100 | assert(len(d.db.ongoing)+len(d.db.history) == 0) 101 | } 102 | 103 | { 104 | tx1, tx2 := KVTX{}, KVTX{} 105 | d.db.Begin(&tx1) 106 | d.db.Begin(&tx2) 107 | tx1.Set([]byte("k3"), []byte("v1")) 108 | tx2.Del(&DeleteReq{Key: []byte("k3")}) 109 | err := d.db.Commit(&tx1) 110 | assert(err == nil) 111 | err = d.db.Commit(&tx2) 112 | assert(err == ErrorConflict) // write conflict 113 | assert(len(d.db.ongoing)+len(d.db.history) == 0) 114 | } 115 | 116 | { 117 | d.add("k4", "v1") 118 | tx1, tx2 := KVTX{}, KVTX{} 119 | d.db.Begin(&tx1) 120 | d.db.Begin(&tx2) 121 | tx1.Set([]byte("k4"), []byte("v2")) 122 | tx2.Del(&DeleteReq{Key: []byte("k4")}) 123 | err := d.db.Commit(&tx2) 124 | assert(err == nil) 125 | err = d.db.Commit(&tx1) 126 | assert(err == ErrorConflict) // write conflict 127 | assert(len(d.db.ongoing)+len(d.db.history) == 0) 128 | } 129 | 130 | { 131 | tx1, tx2 := KVTX{}, KVTX{} 132 | d.db.Begin(&tx1) 133 | d.db.Begin(&tx2) 134 | tx1.Set([]byte("k5"), []byte("v2")) 135 | tx2.Del(&DeleteReq{Key: []byte("k5")}) 136 | err := d.db.Commit(&tx2) // no write 137 | assert(err == nil) 138 | err = d.db.Commit(&tx1) 139 | assert(err == nil) // no conflict 140 | assert(len(d.db.ongoing)+len(d.db.history) == 0) 141 | } 142 | 143 | { 144 | tx1, tx2 := KVTX{}, KVTX{} 145 | d.db.Begin(&tx1) 146 | d.add("k6", "v1") // 3rd TX 147 | d.add("k7", "v1") 148 | 149 | d.db.Begin(&tx2) 150 | tx2.Set([]byte("k6"), []byte("v2")) 151 | err := d.db.Commit(&tx2) // no conflict 152 | assert(err == nil) 153 | 154 | _, ok := tx1.Get([]byte("k7")) 155 | assert(!ok) 156 | tx1.Set([]byte("k8"), []byte("v3")) 157 | err = d.db.Commit(&tx1) // read conflict 158 | assert(err == ErrorConflict) 159 | assert(len(d.db.ongoing)+len(d.db.history) == 0) 160 | } 161 | 162 | d.dispose() 163 | } 164 | 165 | func TestKVTXRand(t *testing.T) { 166 | d := newD() 167 | order := []uint32{} 168 | funcs := []func(){} 169 | 170 | N := uint32(50_000) 171 | for i := uint32(0); i < N; i++ { 172 | tx := KVTX{} 173 | key, val := fmt.Sprintf("k%v", i), fmt.Sprintf("v%v", i) 174 | funcs = append(funcs, func() { d.db.Begin(&tx) }) 175 | funcs = append(funcs, func() { tx.Set([]byte(key), []byte(val)) }) 176 | funcs = append(funcs, func() { 177 | err := d.db.Commit(&tx) 178 | assert(err == nil) 179 | }) 180 | 181 | nums := []uint32{fmix32(3*i + 
0), fmix32(3*i + 1), fmix32(3*i + 2)} 182 | slices.Sort(nums) 183 | order = append(order, nums...) 184 | } 185 | sort.Sort(sortIF{ 186 | len: len(order), 187 | less: func(i, j int) bool { return order[i] < order[j] }, 188 | swap: func(i, j int) { 189 | order[i], order[j] = order[j], order[i] 190 | funcs[i], funcs[j] = funcs[j], funcs[i] 191 | }, 192 | }) 193 | 194 | for _, f := range funcs { 195 | f() 196 | } 197 | assert(len(d.db.ongoing)+len(d.db.history) == 0) 198 | 199 | d.dispose() 200 | } 201 | --------------------------------------------------------------------------------
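The conflict check exercised by these tests reduces to a single sweep over two lists of key ranges sorted by start key: the committing transaction's reads and the write set of each later commit. The sketch below re-implements that sweep outside the repository to make the two outcomes concrete; `keyRange` and `overlap` are illustrative stand-ins for the package's unexported `KeyRange` and `sortedRangesOverlap`, not part of its API.

```go
package main

import (
	"bytes"
	"fmt"
)

// keyRange mirrors the start <= key <= stop intervals recorded by a transaction.
type keyRange struct{ start, stop []byte }

// overlap walks two lists sorted by start key and reports whether any pair of
// ranges intersects -- the same linear sweep used by sortedRangesOverlap.
func overlap(s1, s2 []keyRange) bool {
	for len(s1) > 0 && len(s2) > 0 {
		switch {
		case bytes.Compare(s1[0].stop, s2[0].start) < 0:
			s1 = s1[1:] // s1[0] ends before s2[0] starts; discard it
		case bytes.Compare(s2[0].stop, s1[0].start) < 0:
			s2 = s2[1:] // s2[0] ends before s1[0] starts; discard it
		default:
			return true // neither ends before the other starts: they overlap
		}
	}
	return false
}

func main() {
	reads := []keyRange{{[]byte("k1"), []byte("k1")}, {[]byte("k5"), []byte("k9")}}
	writesA := []keyRange{{[]byte("k7"), []byte("k7")}} // inside [k5, k9]
	writesB := []keyRange{{[]byte("k2"), []byte("k4")}} // between the two reads

	fmt.Println(overlap(reads, writesA)) // true  -> commit would conflict
	fmt.Println(overlap(reads, writesB)) // false -> commit would succeed
}
```

Because both lists are sorted and each step discards one range, the check is linear in the total number of ranges per committed transaction in the history window.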