├── .gitignore ├── .idea ├── .gitignore ├── dropdb.iml ├── modules.xml └── vcs.xml ├── README.md ├── buffer ├── buffer.go ├── manager.go ├── manager_test.go ├── naive_replacement_strategy.go └── replacement_strategy.go ├── driver ├── conn.go ├── driver.go ├── driver_test.go ├── result.go ├── rows.go ├── statement.go └── transaction.go ├── dropdb.png ├── example.go ├── file ├── blockid.go ├── manager.go ├── manager_test.go ├── page.go └── page_test.go ├── go.mod ├── go.sum ├── index ├── btree │ ├── directory.go │ ├── directory_entry.go │ ├── index.go │ ├── index_test.go │ ├── leaf.go │ └── page.go ├── common │ └── constants.go ├── hash │ ├── index.go │ └── index_test.go └── index.go ├── log ├── iterator.go ├── manager.go └── manager_test.go ├── materialize └── temp_table.go ├── metadata ├── index_info.go ├── index_info_test.go ├── index_manager.go ├── index_manager_test.go ├── metadata_manager.go ├── stat_info.go ├── stat_manager.go ├── stat_manager_test.go ├── table_manager.go ├── table_manager_test.go ├── view_manager.go └── view_manager_test.go ├── parse ├── create_index_data.go ├── create_table_data.go ├── create_view_data.go ├── delete_data.go ├── insert_data.go ├── lexer.go ├── lexer_test.go ├── modify_data.go ├── parser.go ├── parser_test.go ├── predicate_parser.go └── query_data.go ├── plan └── plan.go ├── plan_impl ├── basic_query_planner.go ├── basic_query_planner_test.go ├── basic_update_planner.go ├── basic_update_planner_test.go ├── group_by_plan.go ├── group_by_plan_test.go ├── index_join_plan.go ├── index_join_plan_test.go ├── index_select_plan.go ├── index_select_plan_test.go ├── index_update_planner.go ├── index_update_planner_test.go ├── materialize_plan.go ├── materialize_plan_test.go ├── planner.go ├── planner_test.go ├── product_plan.go ├── product_plan_test.go ├── project_plan.go ├── project_plan_test.go ├── query_planner.go ├── select_plan.go ├── select_plan_test.go ├── sort_plan.go ├── sort_plan_test.go ├── table_plan.go ├── 
table_plan_test.go └── update_planner.go ├── query ├── expression.go ├── functions │ ├── aggregation_function.go │ ├── avg_function.go │ ├── count_function.go │ ├── max_function.go │ ├── min_function.go │ └── sum_function.go ├── group_by_scan.go ├── group_by_scan_test.go ├── group_value.go ├── index_join_scan.go ├── index_join_scan_test.go ├── index_select_scan.go ├── index_select_scan_test.go ├── predicate.go ├── product_scan.go ├── product_scan_test.go ├── project_scan.go ├── project_scan_test.go ├── record_comparator.go ├── select_scan.go ├── select_scan_test.go ├── sort_scan.go ├── sort_scan_test.go └── term.go ├── record ├── alignment.go ├── id.go ├── layout.go ├── layout_test.go ├── page.go ├── page_test.go ├── schema.go └── schema_test.go ├── scan ├── scan.go └── update_scan.go ├── server └── dropdb.go ├── table ├── table_scan.go └── table_scan_test.go ├── tx ├── buffer_list.go ├── checkpoint.go ├── commit.go ├── concurrency │ ├── lock_table.go │ └── manager.go ├── concurrency_test.go ├── log_record.go ├── log_record_test.go ├── recovery_manager.go ├── rollback.go ├── set_bool.go ├── set_date.go ├── set_int.go ├── set_long.go ├── set_short.go ├── set_string.go ├── start.go └── transaction.go ├── types ├── comparisons.go ├── field_info.go ├── hash.go ├── int_type.go └── operators.go └── utils ├── hash_value.go └── hash_value_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | # If you prefer the allow list template instead of the deny list, see community template: 2 | # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore 3 | # 4 | # Binaries for programs and plugins 5 | *.exe 6 | *.exe~ 7 | *.dll 8 | *.so 9 | *.dylib 10 | 11 | # Test binary, built with `go test -c` 12 | *.test 13 | 14 | # Output of the go coverage tool, specifically when used with LiteIDE 15 | *.out 16 | 17 | # Dependency directories (remove the comment below to include it) 18 | # vendor/ 19 | 20 | # Go 
workspace file 21 | go.work 22 | go.work.sum 23 | 24 | # env file 25 | .env -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/dropdb.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | DropDB logo 2 | 3 | ## Overview 4 | 5 | DropDB is a fully-featured database written in Go, designed as an educational project inspired by Edward 6 | Sciore's [Database Design and Implementation](https://link.springer.com/book/10.1007/978-3-030-33836-7). The project 7 | extends beyond the book's implementation with enhanced features, optimizations, and additional capabilities. 8 | 9 | The goal is to implement a fairly feature-complete database while exploring database internals, including storage 10 | management, query processing, transaction handling, and optimization techniques. 
11 | 12 | ## Core Features 13 | 14 | ### Implemented 15 | 16 | - **Disk and File Management** 17 | - Efficient storage and retrieval with optimized disk layout 18 | - Byte alignment optimization for improved performance 19 | 20 | - **Memory Management** 21 | - Intelligent memory allocation for data and metadata 22 | - Statistics tracking for query planning optimization 23 | 24 | - **Transaction Management** 25 | - ACID compliance with atomicity and durability 26 | - Robust transaction processing and recovery 27 | 28 | - **Record and Metadata Management** 29 | - Efficient record access and updates 30 | - Comprehensive schema and table definition management 31 | 32 | - **Query Processing** 33 | - SQL parsing and execution 34 | - Materialization and sorting capabilities 35 | - Support for complex queries with multiple clauses 36 | 37 | - **Indexing** 38 | - B-tree index implementation 39 | - Performance optimization for data retrieval 40 | 41 | ### In Development 42 | 43 | - **Buffer Management** 44 | - Smart buffer pool management 45 | - I/O optimization strategies 46 | 47 | - **Query Optimization** 48 | - Cost-based query planning 49 | - Execution plan optimization 50 | 51 | ## Data Types and Operations 52 | 53 | ### Supported Types 54 | 55 | - `int`, `short`, `long` 56 | - `string` 57 | - `bool` 58 | - `date` 59 | 60 | ### Query Capabilities 61 | 62 | - **Comparison Operators**: `=`, `!=`, `>`, `<`, `>=`, `<=` 63 | - **Aggregation Functions**: 64 | - `COUNT`, `SUM`, `AVG`, `MIN`, `MAX` 65 | - Note: AVG and SUM results use integer casting due to current floating-point limitations 66 | - Precision issues may occur with 64-bit integers on 32-bit machines 67 | 68 | ### Query Features 69 | 70 | - Aggregations 71 | - `GROUP BY` clauses 72 | - `HAVING` clauses 73 | - `ORDER BY` (currently ascending only) 74 | 75 | ## SQL Support 76 | 77 | ### Supported Commands 78 | 79 | #### SELECT 80 | 81 | ```sql 82 | -- Basic query with conditions 83 | SELECT name 84 | FROM users 85 | 
WHERE id = 1 86 | 87 | -- Join query 88 | SELECT name, dept_name 89 | FROM users, 90 | departments 91 | WHERE users_dept_id = dept_id 92 | 93 | -- Aggregation with grouping 94 | SELECT dept, avg(salary) 95 | FROM employees 96 | GROUP BY dept 97 | 98 | -- Complex query 99 | SELECT category, date, sum (amount) 100 | FROM orders 101 | WHERE amount > 500 102 | GROUP BY category, date 103 | HAVING sum (amount) > 2000 104 | ORDER BY total asc 105 | ``` 106 | 107 | #### Data Definition 108 | 109 | - `CREATE TABLE` - Define new tables with specified fields and types 110 | - `CREATE VIEW` - Create stored queries 111 | - `CREATE INDEX` - Build indexes for performance optimization 112 | 113 | #### Data Manipulation 114 | 115 | - `INSERT` - Add new records 116 | - `UPDATE` - Modify existing records 117 | - `DELETE` - Remove records based on conditions 118 | 119 | ## Project Goals 120 | 121 | DropDB serves as both a learning platform and a practical implementation of database concepts. While primarily developed 122 | for educational purposes, it aims to provide real-world database functionality and performance. 123 | 124 | ## Contributing 125 | 126 | Contributions are welcome! Feel free to: 127 | 128 | - Open issues for bugs or feature requests 129 | - Submit pull requests for improvements 130 | - Share feedback on implementation approaches 131 | 132 | The project particularly welcomes contributions in areas like buffer management and query optimization. -------------------------------------------------------------------------------- /buffer/buffer.go: -------------------------------------------------------------------------------- 1 | package buffer 2 | 3 | import ( 4 | "fmt" 5 | "github.com/JyotinderSingh/dropdb/file" 6 | "github.com/JyotinderSingh/dropdb/log" 7 | ) 8 | 9 | // Buffer represents an individual buffer. 
A data buffer wraps a page 10 | // and stores information about its status, 11 | // such as the associated disk block, 12 | // the number of times the buffer has been pinned, 13 | // whether its contents have been modified, 14 | // and if so, the id and lsn of the modifying transaction. 15 | type Buffer struct { 16 | fileManager *file.Manager 17 | logManager *log.Manager 18 | contents *file.Page 19 | block *file.BlockId 20 | pins int 21 | txnNum int 22 | lsn int 23 | } 24 | 25 | func NewBuffer(fileManager *file.Manager, logManager *log.Manager) *Buffer { 26 | return &Buffer{ 27 | fileManager: fileManager, 28 | logManager: logManager, 29 | contents: file.NewPage(fileManager.BlockSize()), 30 | block: nil, 31 | pins: 0, 32 | txnNum: -1, 33 | lsn: -1, 34 | } 35 | } 36 | 37 | func (b *Buffer) Contents() *file.Page { 38 | return b.contents 39 | } 40 | 41 | // Block returns a reference to the disk block allocated to the buffer. 42 | func (b *Buffer) Block() *file.BlockId { 43 | return b.block 44 | } 45 | 46 | func (b *Buffer) SetModified(txnNum, lsn int) { 47 | b.txnNum = txnNum 48 | 49 | // If LSN is smaller than 0, it indicates that a log record was not generated for this update. 50 | if lsn >= 0 { 51 | b.lsn = lsn 52 | } 53 | } 54 | 55 | // isPinned returns true if the buffer is currently pinned (that is, if it has a nonzero pin count). 56 | func (b *Buffer) isPinned() bool { 57 | return b.pins > 0 58 | } 59 | 60 | func (b *Buffer) modifyingTxn() int { 61 | return b.txnNum 62 | } 63 | 64 | // assignToBlock reads the contents of the specified block into the contents of the buffer. 65 | // If the buffer was dirty, then its previous contents are first written to disk. 
66 | func (b *Buffer) assignToBlock(block *file.BlockId) error { 67 | if err := b.flush(); err != nil { 68 | return fmt.Errorf("failed to flush buffer for block %s: %v", b.block.String(), err) 69 | } 70 | b.block = block 71 | if err := b.fileManager.Read(block, b.contents); err != nil { 72 | return fmt.Errorf("failed to read block %s to buffer: %v", block.String(), err) 73 | } 74 | b.pins = 0 75 | return nil 76 | } 77 | 78 | // flush writes the buffer to its disk block if it is dirty. The method first writes the log record to the log file, 79 | // and then writes the contents of the buffer to disk. 80 | func (b *Buffer) flush() error { 81 | if b.txnNum >= 0 { 82 | if err := b.logManager.Flush(b.lsn); err != nil { 83 | return fmt.Errorf("failed to flush log record for txn %d: %v", b.txnNum, err) 84 | } 85 | if err := b.fileManager.Write(b.block, b.contents); err != nil { 86 | return fmt.Errorf("failed to write block: %v", err) 87 | } 88 | b.txnNum = -1 89 | } 90 | return nil 91 | } 92 | 93 | // pin increases the buffer's pin count. 94 | func (b *Buffer) pin() { 95 | b.pins++ 96 | } 97 | 98 | // unpin decreases the buffer's pin count. 99 | func (b *Buffer) unpin() { 100 | b.pins-- 101 | } 102 | -------------------------------------------------------------------------------- /buffer/naive_replacement_strategy.go: -------------------------------------------------------------------------------- 1 | package buffer 2 | 3 | import "sync" 4 | 5 | // NaiveStrategy is a simple buffer replacement strategy that selects the first unpinned buffer. 6 | type NaiveStrategy struct { 7 | ReplacementStrategy 8 | buffers []*Buffer 9 | mu sync.Mutex 10 | } 11 | 12 | // NewNaiveStrategy creates a new NaiveStrategy. 13 | func NewNaiveStrategy() *NaiveStrategy { 14 | return &NaiveStrategy{} 15 | } 16 | 17 | // Initialize initializes the strategy with the buffer pool. 
18 | func (ns *NaiveStrategy) initialize(buffers []*Buffer) { 19 | ns.mu.Lock() 20 | defer ns.mu.Unlock() 21 | ns.buffers = buffers 22 | } 23 | 24 | // PinBuffer notifies the strategy that a buffer has been pinned. 25 | // No action needed for naive strategy. 26 | func (ns *NaiveStrategy) pinBuffer(buff *Buffer) { 27 | // No action needed. 28 | } 29 | 30 | // UnpinBuffer notifies the strategy that a buffer has been unpinned. 31 | // No action needed for naive strategy. 32 | func (ns *NaiveStrategy) unpinBuffer(buff *Buffer) { 33 | // No action needed. 34 | } 35 | 36 | // ChooseUnpinnedBuffer selects an unpinned buffer to replace. 37 | func (ns *NaiveStrategy) chooseUnpinnedBuffer() *Buffer { 38 | ns.mu.Lock() 39 | defer ns.mu.Unlock() 40 | for _, buff := range ns.buffers { 41 | if !buff.isPinned() { 42 | return buff 43 | } 44 | } 45 | return nil 46 | } 47 | -------------------------------------------------------------------------------- /buffer/replacement_strategy.go: -------------------------------------------------------------------------------- 1 | package buffer 2 | 3 | // ReplacementStrategy defines the interface for buffer replacement strategies. 4 | type ReplacementStrategy interface { 5 | // initialize initializes the strategy with the buffer pool. 6 | initialize(buffers []*Buffer) 7 | // pinBuffer notifies the strategy that a buffer has been pinned. 8 | pinBuffer(buff *Buffer) 9 | // unpinBuffer notifies the strategy that a buffer has been unpinned. 10 | unpinBuffer(buff *Buffer) 11 | // chooseUnpinnedBuffer selects an unpinned buffer to replace. 
12 | chooseUnpinnedBuffer() *Buffer 13 | } 14 | -------------------------------------------------------------------------------- /driver/conn.go: -------------------------------------------------------------------------------- 1 | package driver 2 | 3 | import ( 4 | "database/sql/driver" 5 | "errors" 6 | "github.com/JyotinderSingh/dropdb/server" 7 | "github.com/JyotinderSingh/dropdb/tx" 8 | ) 9 | 10 | // DropDBConn implements driver.Conn. 11 | type DropDBConn struct { 12 | db *server.DropDB 13 | 14 | // activeTx is non-nil if we are in an explicit transaction 15 | activeTx *tx.Transaction 16 | } 17 | 18 | // Prepare returns a prepared statement, but we'll simply store the SQL string. 19 | // Actual planning happens in Stmt.Exec / Stmt.Query (auto-commit style). 20 | func (c *DropDBConn) Prepare(query string) (driver.Stmt, error) { 21 | return &DropDBStmt{ 22 | conn: c, 23 | query: query, 24 | }, nil 25 | } 26 | 27 | // Close is called when database/sql is done with this connection. 28 | func (c *DropDBConn) Close() error { 29 | // There's no real "closing" an embedded DB, but if you had 30 | // a long-running Tx or resources pinned, you could clean them up here. 31 | return nil 32 | } 33 | 34 | // Begin starts a transaction 35 | func (c *DropDBConn) Begin() (driver.Tx, error) { 36 | if c.activeTx != nil { 37 | // either error or nested transactions if supported 38 | return nil, errors.New("already in a transaction") 39 | } 40 | newTx := c.db.NewTx() 41 | c.activeTx = newTx 42 | return &DropDBTx{ 43 | conn: c, 44 | tx: newTx, 45 | }, nil 46 | } 47 | -------------------------------------------------------------------------------- /driver/driver.go: -------------------------------------------------------------------------------- 1 | package driver 2 | 3 | import ( 4 | "database/sql" 5 | "database/sql/driver" 6 | "github.com/JyotinderSingh/dropdb/server" 7 | ) 8 | 9 | const dbName = "dropdb" 10 | 11 | // Register the driver when this package is imported. 
12 | func init() { 13 | sql.Register(dbName, &DropDBDriver{}) 14 | } 15 | 16 | // DropDBDriver implements database/sql/driver.Driver. 17 | var _ driver.Driver = (*DropDBDriver)(nil) 18 | 19 | type DropDBDriver struct{} 20 | 21 | // Open is the entry point. The directory is the path to the DB directory. 22 | func (d *DropDBDriver) Open(directory string) (driver.Conn, error) { 23 | db, err := server.NewDropDB(directory) 24 | if err != nil { 25 | return nil, err 26 | } 27 | return &DropDBConn{ 28 | db: db, 29 | // We do not open a transaction here. We'll open a new one for each statement (auto-commit). 30 | }, nil 31 | } 32 | -------------------------------------------------------------------------------- /driver/result.go: -------------------------------------------------------------------------------- 1 | package driver 2 | 3 | import "errors" 4 | 5 | // DropDBResult implements driver.Result for the Exec path. 6 | type DropDBResult struct { 7 | rowsAffected int64 8 | } 9 | 10 | // LastInsertId is not implemented yet. 11 | func (r *DropDBResult) LastInsertId() (int64, error) { 12 | return 0, errors.New("LastInsertId not supported") 13 | } 14 | 15 | // RowsAffected returns how many rows were changed by the statement. 16 | func (r *DropDBResult) RowsAffected() (int64, error) { 17 | return r.rowsAffected, nil 18 | } 19 | -------------------------------------------------------------------------------- /driver/rows.go: -------------------------------------------------------------------------------- 1 | package driver 2 | 3 | import ( 4 | "database/sql/driver" 5 | "fmt" 6 | "github.com/JyotinderSingh/dropdb/plan" 7 | "github.com/JyotinderSingh/dropdb/scan" 8 | "github.com/JyotinderSingh/dropdb/tx" 9 | "github.com/JyotinderSingh/dropdb/types" 10 | "io" 11 | ) 12 | 13 | type DropDBRows struct { 14 | stmt *DropDBStmt 15 | tx *tx.Transaction 16 | 17 | scan scan.Scan 18 | plan plan.Plan 19 | done bool 20 | 21 | // We'll extract column names once. 
22 | columns []string 23 | } 24 | 25 | // Columns returns the column names from the schema. 26 | func (r *DropDBRows) Columns() []string { 27 | if r.columns == nil { 28 | sch := r.plan.Schema() 29 | fields := sch.Fields() 30 | r.columns = make([]string, len(fields)) 31 | copy(r.columns, fields) 32 | } 33 | return r.columns 34 | } 35 | 36 | // Close is called by database/sql when the result set is done. 37 | // We need to release the underlying scan and commit the transaction (auto-commit). 38 | func (r *DropDBRows) Close() error { 39 | if r.done { 40 | return nil 41 | } 42 | r.done = true 43 | r.scan.Close() 44 | // We can commit the transaction to auto-commit. 45 | return r.tx.Commit() 46 | } 47 | 48 | // Next is called to advance the cursor and populate one row of data into 'dest'. 49 | // 'Dest' must match the number and types of the columns. 50 | func (r *DropDBRows) Next(dest []driver.Value) error { 51 | if r.done { 52 | return io.EOF 53 | } 54 | // Attempt to move to the next record 55 | hasNext, err := r.scan.Next() 56 | if err != nil { 57 | // On error, rollback so no partial commit 58 | _ = r.tx.Rollback() 59 | r.done = true 60 | return err 61 | } 62 | if !hasNext { 63 | // no more rows 64 | r.done = true 65 | // auto-commit 66 | if commitErr := r.tx.Commit(); commitErr != nil { 67 | return commitErr 68 | } 69 | return io.EOF 70 | } 71 | 72 | // We have another row. Extract each column from the scan. 
73 | cols := r.Columns() 74 | for i, col := range cols { 75 | columnType := r.plan.Schema().Type(col) 76 | 77 | // Convert from scan's type to driver.Value 78 | var v interface{} 79 | switch columnType { 80 | case types.Integer: 81 | v, err = r.scan.GetInt(col) 82 | if err != nil { 83 | return err 84 | } 85 | case types.Varchar: 86 | v, err = r.scan.GetString(col) 87 | if err != nil { 88 | return err 89 | } 90 | case types.Boolean: 91 | v, err = r.scan.GetBool(col) 92 | if err != nil { 93 | return err 94 | } 95 | case types.Long: 96 | v, err = r.scan.GetLong(col) 97 | if err != nil { 98 | return err 99 | } 100 | case types.Short: 101 | v, err = r.scan.GetShort(col) 102 | if err != nil { 103 | return err 104 | } 105 | case types.Date: 106 | v, err = r.scan.GetDate(col) 107 | if err != nil { 108 | return err 109 | } 110 | default: 111 | return fmt.Errorf("unsupported field type: %v", columnType) 112 | } 113 | dest[i] = v 114 | } 115 | return nil 116 | } 117 | -------------------------------------------------------------------------------- /driver/statement.go: -------------------------------------------------------------------------------- 1 | package driver 2 | 3 | import ( 4 | "database/sql/driver" 5 | "fmt" 6 | "github.com/JyotinderSingh/dropdb/tx" 7 | "strings" 8 | ) 9 | 10 | // DropDBStmt implements driver.Stmt. 11 | type DropDBStmt struct { 12 | conn *DropDBConn 13 | query string 14 | } 15 | 16 | // Close is a no-op for this simple driver. 17 | func (s *DropDBStmt) Close() error { 18 | return nil 19 | } 20 | 21 | // NumInput returns -1 indicating we don't do bound parameters in this example. 22 | func (s *DropDBStmt) NumInput() int { 23 | return -1 24 | } 25 | 26 | // Exec executes a non-SELECT statement (INSERT, UPDATE, DELETE, CREATE, etc). 27 | // If the statement is actually a SELECT, we throw an error or ignore. 
28 | func (s *DropDBStmt) Exec(args []driver.Value) (driver.Result, error) { 29 | var t *tx.Transaction 30 | if s.conn.activeTx == nil { 31 | // create transaction for auto-commit 32 | t = s.conn.db.NewTx() 33 | } else { 34 | // use the existing transaction 35 | t = s.conn.activeTx 36 | } 37 | 38 | planner := s.conn.db.Planner() 39 | 40 | // Simple detection if it's a "SELECT" (for a real driver, you'd parse properly). 41 | lower := strings.ToLower(strings.TrimSpace(s.query)) 42 | if strings.HasPrefix(lower, "select") { 43 | // By the tests’ logic, Exec() is for CREATE/INSERT/UPDATE/DELETE. 44 | // You could either: 45 | // 1. Return an error, or 46 | // 2. Forward to Query() if you prefer 47 | return nil, fmt.Errorf("Exec called with SELECT statement: %s", s.query) 48 | } 49 | 50 | // For all other statements (CREATE, INSERT, UPDATE, DELETE, etc.), 51 | // use planner.ExecuteUpdate 52 | rowsAffected, err := planner.ExecuteUpdate(s.query, t) 53 | 54 | if err != nil { 55 | // if it was an auto-commit transaction, rollback 56 | if s.conn.activeTx == nil { 57 | _ = t.Rollback() 58 | } 59 | return nil, err 60 | } 61 | 62 | if s.conn.activeTx == nil { 63 | // auto-commit 64 | if err := t.Commit(); err != nil { 65 | return nil, err 66 | } 67 | } 68 | 69 | // Return a driver.Result containing rows-affected count 70 | return &DropDBResult{rowsAffected: int64(rowsAffected)}, nil 71 | } 72 | 73 | // Query executes a SELECT statement and returns the resulting rows. 
74 | func (s *DropDBStmt) Query(args []driver.Value) (driver.Rows, error) { 75 | // Decide whether we're in an explicit transaction or need to auto-commit 76 | var t *tx.Transaction 77 | if s.conn.activeTx == nil { 78 | // No active transaction => create a new one for auto-commit 79 | t = s.conn.db.NewTx() 80 | } else { 81 | // We already have an open transaction 82 | t = s.conn.activeTx 83 | } 84 | 85 | // We'll detect SELECT queries by prefix: 86 | lower := strings.ToLower(strings.TrimSpace(s.query)) 87 | if !strings.HasPrefix(lower, "select") { 88 | // By the test logic, Query is only for SELECT statements. 89 | // For everything else (CREATE, INSERT, etc.) we do Exec. 90 | return nil, fmt.Errorf("Query called with non-SELECT statement: %s", s.query) 91 | } 92 | 93 | planner := s.conn.db.Planner() 94 | 95 | // Use the Planner to build a query plan 96 | plan, err := planner.CreateQueryPlan(s.query, t) 97 | if err != nil { 98 | // Roll back on error 99 | _ = t.Rollback() 100 | return nil, err 101 | } 102 | 103 | sc, err := plan.Open() 104 | if err != nil { 105 | if s.conn.activeTx == nil { 106 | _ = t.Rollback() 107 | } 108 | return nil, err 109 | } 110 | 111 | // Return the Rows object. We'll commit/rollback inside Rows.Close() 112 | // (or when the result set is exhausted). 113 | return &DropDBRows{ 114 | stmt: s, 115 | tx: t, 116 | scan: sc, 117 | plan: plan, 118 | }, nil 119 | } 120 | -------------------------------------------------------------------------------- /driver/transaction.go: -------------------------------------------------------------------------------- 1 | package driver 2 | 3 | import "github.com/JyotinderSingh/dropdb/tx" 4 | 5 | // DropDBTx implements driver.Tx so that database/sql can manage 6 | // a transaction with Commit() and Rollback(). 
7 | // It just holds a reference to the connection so we can clear activeTx on commit/rollback 8 | type DropDBTx struct { 9 | conn *DropDBConn 10 | tx *tx.Transaction 11 | } 12 | 13 | func (t *DropDBTx) Commit() error { 14 | err := t.tx.Commit() 15 | t.conn.activeTx = nil 16 | return err 17 | } 18 | 19 | func (t *DropDBTx) Rollback() error { 20 | err := t.tx.Rollback() 21 | t.conn.activeTx = nil 22 | return err 23 | } 24 | -------------------------------------------------------------------------------- /dropdb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JyotinderSingh/dropdb/edaf375d74f5df7213e90904eb88c227ec768e67/dropdb.png -------------------------------------------------------------------------------- /example.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "log" 7 | "os" 8 | 9 | _ "github.com/JyotinderSingh/dropdb/driver" // Import the driver for side effects 10 | ) 11 | 12 | func main() { 13 | // Specify the directory for DropDB database files 14 | dbDir := "./mydb" 15 | defer func() { 16 | if err := os.RemoveAll(dbDir); err != nil { 17 | log.Fatalf("Failed to clean up database directory: %v\n", err) 18 | } 19 | }() 20 | 21 | // Open a connection to DropDB 22 | db, err := sql.Open("dropdb", dbDir) 23 | if err != nil { 24 | log.Fatalf("Failed to open DropDB: %v\n", err) 25 | } 26 | defer db.Close() 27 | 28 | // ---------------------------------------------------------------- 29 | // 1. 
Create a table (auto-commit mode) 30 | // ---------------------------------------------------------------- 31 | fmt.Println("Creating table in auto-commit mode...") 32 | createTableSQL := ` 33 | CREATE TABLE student ( 34 | sname VARCHAR(10), 35 | gradyear INT 36 | ) 37 | ` 38 | if _, err = db.Exec(createTableSQL); err != nil { 39 | log.Fatalf("Failed to create table: %v\n", err) 40 | } 41 | fmt.Println("Table 'student' created successfully.\n") 42 | 43 | // ---------------------------------------------------------------- 44 | // 2. Demonstrate a ROLLBACK 45 | // ---------------------------------------------------------------- 46 | fmt.Println("Starting an explicit transaction and rolling back...") 47 | tx1, err := db.Begin() 48 | if err != nil { 49 | log.Fatalf("Failed to begin transaction tx1: %v\n", err) 50 | } 51 | 52 | // Insert a row that we'll never commit 53 | _, err = tx1.Exec(`INSERT INTO student (sname, gradyear) VALUES ('Zoe', 9999)`) 54 | if err != nil { 55 | // If any error occurs, roll back and exit 56 | _ = tx1.Rollback() 57 | log.Fatalf("Failed to insert in tx1: %v\n", err) 58 | } 59 | 60 | // Now intentionally rollback 61 | if err := tx1.Rollback(); err != nil { 62 | log.Fatalf("Failed to roll back tx1: %v\n", err) 63 | } 64 | 65 | fmt.Println("Rolled back transaction. Row for 'Zoe' should NOT be in the table.\n") 66 | 67 | // ---------------------------------------------------------------- 68 | // 3. 
Demonstrate a COMMIT with multiple inserts 69 | // ---------------------------------------------------------------- 70 | fmt.Println("Starting a second explicit transaction and committing...") 71 | tx2, err := db.Begin() 72 | if err != nil { 73 | log.Fatalf("Failed to begin transaction tx2: %v\n", err) 74 | } 75 | 76 | // Insert rows into the table inside tx2 77 | insertStatements := []string{ 78 | `INSERT INTO student (sname, gradyear) VALUES ('Alice', 2023)`, 79 | `INSERT INTO student (sname, gradyear) VALUES ('Bob', 2024)`, 80 | `INSERT INTO student (sname, gradyear) VALUES ('Charlie', 2025)`, 81 | } 82 | 83 | for _, stmt := range insertStatements { 84 | if _, err := tx2.Exec(stmt); err != nil { 85 | // If insert fails, roll back 86 | _ = tx2.Rollback() 87 | log.Fatalf("Failed to insert row in tx2: %v\n", err) 88 | } 89 | } 90 | 91 | // Commit tx2 to persist the inserts 92 | if err := tx2.Commit(); err != nil { 93 | log.Fatalf("Failed to commit tx2: %v\n", err) 94 | } 95 | fmt.Println("Transaction tx2 committed successfully.\n") 96 | 97 | // ---------------------------------------------------------------- 98 | // 4. Query the table to confirm the results 99 | // ---------------------------------------------------------------- 100 | fmt.Println("Querying rows...") 101 | querySQL := "SELECT sname, gradyear FROM student ORDER BY gradyear" 102 | rows, err := db.Query(querySQL) 103 | if err != nil { 104 | log.Fatalf("Failed to query rows: %v\n", err) 105 | } 106 | defer rows.Close() 107 | 108 | fmt.Println("Query results:") 109 | for rows.Next() { 110 | var name string 111 | var year int 112 | // NOTE: The order of columns is reversed due to how SortPlan is implemented. This is a known issue and should be fixed using some changes in the SortPlan implementation. 
113 | if err := rows.Scan(&year, &name); err != nil { 114 | log.Fatalf("Failed to scan row: %v\n", err) 115 | } 116 | fmt.Printf(" - Name: %s, Graduation Year: %d\n", name, year) 117 | } 118 | 119 | // Check for any errors encountered during iteration 120 | if err := rows.Err(); err != nil { 121 | log.Fatalf("Rows iteration error: %v\n", err) 122 | } 123 | 124 | fmt.Println("\nQuery completed successfully. Notice that 'Zoe' is missing because her insert was rolled back, but 'Alice', 'Bob', and 'Charlie' are present.") 125 | } 126 | -------------------------------------------------------------------------------- /file/blockid.go: -------------------------------------------------------------------------------- 1 | package file 2 | 3 | import "fmt" 4 | 5 | // BlockId identifies a disk block by its filename and block number. 6 | type BlockId struct { 7 | File string 8 | BlockNumber int 9 | } 10 | 11 | func NewBlockId(filename string, blockNumber int) *BlockId { 12 | return &BlockId{ 13 | File: filename, 14 | BlockNumber: blockNumber, 15 | } 16 | } 17 | 18 | func (b *BlockId) Filename() string { 19 | return b.File 20 | } 21 | 22 | func (b *BlockId) Number() int { 23 | return b.BlockNumber 24 | } 25 | 26 | func (b *BlockId) String() string { 27 | return fmt.Sprintf("[file %s, block %d]", b.File, b.BlockNumber) 28 | } 29 | 30 | func (b *BlockId) Equals(other *BlockId) bool { 31 | return b.File == other.File && b.BlockNumber == other.BlockNumber 32 | } 33 | -------------------------------------------------------------------------------- /file/page.go: -------------------------------------------------------------------------------- 1 | package file 2 | 3 | import ( 4 | "encoding/binary" 5 | "errors" 6 | "github.com/JyotinderSingh/dropdb/types" 7 | "runtime" 8 | "time" 9 | "unicode/utf8" 10 | ) 11 | 12 | // Page represents a page in the database file. 13 | // A page is a fixed-size block of data that is read from or written to disk as a unit. 
// The size of a page is determined by the file manager and is typically a multiple of the disk block size.
// Pages are the unit of transfer between disk and main memory.
type Page struct {
	buffer []byte
}

// NewPage creates a Page backed by a zeroed buffer of the given block size.
func NewPage(blockSize int) *Page {
	buf := make([]byte, blockSize)
	return &Page{buffer: buf}
}

// NewPageFromBytes creates a Page that wraps (does not copy) the provided byte slice.
func NewPageFromBytes(bytes []byte) *Page {
	return &Page{buffer: bytes}
}

// GetInt retrieves an integer from the buffer at the specified offset.
// On 32-bit architectures ints are read as 4 bytes; elsewhere as 8 bytes.
func (p *Page) GetInt(offset int) int {
	switch runtime.GOARCH {
	case "386", "arm":
		return int(binary.BigEndian.Uint32(p.buffer[offset:]))
	default:
		// arm64 (M1/M2 Macs) and amd64 use 64-bit ints.
		return int(binary.BigEndian.Uint64(p.buffer[offset:]))
	}
}

// SetInt writes an integer to the buffer at the specified offset,
// using 4 bytes on 32-bit architectures and 8 bytes elsewhere.
func (p *Page) SetInt(offset int, n int) {
	switch runtime.GOARCH {
	case "386", "arm":
		binary.BigEndian.PutUint32(p.buffer[offset:], uint32(n))
	default:
		binary.BigEndian.PutUint64(p.buffer[offset:], uint64(n))
	}
}

// GetLong retrieves a 64-bit integer from the buffer at the specified offset.
func (p *Page) GetLong(offset int) int64 {
	return int64(binary.BigEndian.Uint64(p.buffer[offset:]))
}

// SetLong writes a 64-bit integer to the buffer at the specified offset.
func (p *Page) SetLong(offset int, n int64) {
	binary.BigEndian.PutUint64(p.buffer[offset:], uint64(n))
}

// GetBytes retrieves a byte slice from the buffer starting at the specified offset.
59 | func (p *Page) GetBytes(offset int) []byte { 60 | length := p.GetInt(offset) 61 | start := offset + types.IntSize 62 | end := start + int(length) 63 | b := make([]byte, length) 64 | copy(b, p.buffer[start:end]) 65 | return b 66 | } 67 | 68 | // SetBytes writes a byte slice to the buffer starting at the specified offset. 69 | func (p *Page) SetBytes(offset int, b []byte) { 70 | length := len(b) 71 | p.SetInt(offset, length) 72 | start := offset + types.IntSize 73 | copy(p.buffer[start:], b) 74 | } 75 | 76 | // GetString retrieves a string from the buffer at the specified offset. 77 | func (p *Page) GetString(offset int) (string, error) { 78 | b := p.GetBytes(offset) 79 | if !utf8.Valid(b) { 80 | return "", errors.New("invalid UTF-8 encoding") 81 | } 82 | return string(b), nil 83 | } 84 | 85 | // SetString writes a string to the buffer at the specified offset. 86 | func (p *Page) SetString(offset int, s string) error { 87 | if !utf8.ValidString(s) { 88 | return errors.New("string contains invalid UTF-8 characters") 89 | } 90 | p.SetBytes(offset, []byte(s)) 91 | return nil 92 | } 93 | 94 | // GetShort retrieves a 16-bit integer from the buffer at the specified offset. 95 | func (p *Page) GetShort(offset int) int16 { 96 | return int16(binary.BigEndian.Uint16(p.buffer[offset:])) 97 | } 98 | 99 | // SetShort writes a 16-bit integer to the buffer at the specified offset. 100 | func (p *Page) SetShort(offset int, n int16) { 101 | binary.BigEndian.PutUint16(p.buffer[offset:], uint16(n)) 102 | } 103 | 104 | // GetBool retrieves a boolean from the buffer at the specified offset. 105 | func (p *Page) GetBool(offset int) bool { 106 | return p.buffer[offset] != 0 107 | } 108 | 109 | // SetBool writes a boolean to the buffer at the specified offset. 
110 | func (p *Page) SetBool(offset int, b bool) { 111 | if b { 112 | p.buffer[offset] = 1 113 | } else { 114 | p.buffer[offset] = 0 115 | } 116 | } 117 | 118 | // GetDate retrieves a date (stored as a Unix timestamp) from the buffer at the specified offset. 119 | func (p *Page) GetDate(offset int) time.Time { 120 | unixTimestamp := int64(binary.BigEndian.Uint64(p.buffer[offset:])) 121 | return time.Unix(unixTimestamp, 0) 122 | } 123 | 124 | // SetDate writes a date (as a Unix timestamp) to the buffer at the specified offset. 125 | func (p *Page) SetDate(offset int, date time.Time) { 126 | binary.BigEndian.PutUint64(p.buffer[offset:], uint64(date.Unix())) 127 | } 128 | 129 | // MaxLength calculates the maximum number of bytes required to store a string of a given length. 130 | func MaxLength(strlen int) int { 131 | // Golang uses UTF-8 encoding 132 | // Add utils.IntSize bytes for the length prefix. 133 | return types.IntSize + strlen*utf8.UTFMax 134 | } 135 | 136 | // Contents returns the byte buffer maintained by the Page. 
137 | func (p *Page) Contents() []byte { 138 | return p.buffer 139 | } 140 | -------------------------------------------------------------------------------- /file/page_test.go: -------------------------------------------------------------------------------- 1 | package file 2 | 3 | import ( 4 | "github.com/JyotinderSingh/dropdb/types" 5 | "github.com/stretchr/testify/assert" 6 | "math" 7 | "testing" 8 | "unicode/utf8" 9 | ) 10 | 11 | func TestPage(t *testing.T) { 12 | t.Run("NewPage", func(t *testing.T) { 13 | assert := assert.New(t) 14 | blockSize := 400 15 | page := NewPage(blockSize) 16 | assert.Equal(blockSize, len(page.Contents()), "Buffer size should match block size") 17 | }) 18 | 19 | t.Run("NewPageFromBytes", func(t *testing.T) { 20 | assert := assert.New(t) 21 | data := []byte{1, 2, 3, 4} 22 | page := NewPageFromBytes(data) 23 | 24 | assert.Equal(len(data), len(page.Contents()), "Buffer size should match input data size") 25 | assert.Equal(data, page.Contents(), "Buffer contents should match input data") 26 | }) 27 | 28 | t.Run("IntOperations", func(t *testing.T) { 29 | assert := assert.New(t) 30 | page := NewPage(100) 31 | testCases := []struct { 32 | offset int 33 | value int 34 | }{ 35 | {0, 42}, 36 | {4, -123}, 37 | {8, 0}, 38 | {12, math.MaxInt}, 39 | {16, math.MinInt}, 40 | } 41 | 42 | for _, tc := range testCases { 43 | page.SetInt(tc.offset, tc.value) 44 | got := page.GetInt(tc.offset) 45 | assert.Equal(tc.value, got, "Integer value at offset %d should match", tc.offset) 46 | } 47 | }) 48 | 49 | t.Run("BytesOperations", func(t *testing.T) { 50 | assert := assert.New(t) 51 | page := NewPage(100) 52 | testCases := []struct { 53 | offset int 54 | data []byte 55 | }{ 56 | {0, []byte{1, 2, 3, 4}}, 57 | {20, []byte{}}, // empty array 58 | {40, []byte{255, 0, 255}}, 59 | {60, make([]byte, 20)}, // zero bytes 60 | } 61 | 62 | for _, tc := range testCases { 63 | page.SetBytes(tc.offset, tc.data) 64 | got := page.GetBytes(tc.offset) 65 | 
assert.Equal(tc.data, got, "Byte data at offset %d should match", tc.offset) 66 | } 67 | }) 68 | 69 | t.Run("StringOperations", func(t *testing.T) { 70 | assert := assert.New(t) 71 | page := NewPage(1000) 72 | testCases := []struct { 73 | offset string 74 | value string 75 | valid bool 76 | }{ 77 | {offset: "basic", value: "Hello, World!", valid: true}, 78 | {offset: "empty", value: "", valid: true}, 79 | {offset: "unicode", value: "Hello, 世界!", valid: true}, 80 | {offset: "emoji", value: "🌍🌎🌏", valid: true}, 81 | {offset: "multiline", value: "Line 1\nLine 2", valid: true}, 82 | } 83 | 84 | offset := 0 85 | for _, tc := range testCases { 86 | t.Run(tc.offset, func(t *testing.T) { 87 | err := page.SetString(offset, tc.value) 88 | if tc.valid { 89 | assert.NoError(err, "SetString should not fail for valid string") 90 | got, err := page.GetString(offset) 91 | assert.NoError(err, "GetString should not fail for valid string") 92 | assert.Equal(tc.value, got, "String value should match") 93 | } 94 | offset += MaxLength(len(tc.value)) + 8 // add some padding 95 | }) 96 | } 97 | }) 98 | 99 | t.Run("InvalidUTF8", func(t *testing.T) { 100 | assert := assert.New(t) 101 | page := NewPage(100) 102 | offset := 0 103 | 104 | // Create invalid UTF-8 sequence 105 | invalidUTF8 := []byte{0xFF, 0xFE, 0xFD} 106 | page.SetBytes(offset, invalidUTF8) 107 | 108 | _, err := page.GetString(offset) 109 | assert.Error(err, "GetString should fail for invalid UTF-8 sequence") 110 | }) 111 | 112 | t.Run("MaxLength", func(t *testing.T) { 113 | assert := assert.New(t) 114 | testCases := []struct { 115 | strlen int 116 | want int 117 | }{ 118 | {0, types.IntSize}, // empty string 119 | {1, types.IntSize + utf8.UTFMax}, // single character 120 | {10, types.IntSize + 10*utf8.UTFMax}, // 10 characters 121 | {1000, types.IntSize + 1000*utf8.UTFMax}, // 1000 characters 122 | } 123 | 124 | for _, tc := range testCases { 125 | got := MaxLength(tc.strlen) 126 | assert.Equal(tc.want, got, "MaxLength for 
string length %d should match", tc.strlen) 127 | } 128 | }) 129 | 130 | t.Run("BufferBoundary", func(t *testing.T) { 131 | assert := assert.New(t) 132 | blockSize := 20 133 | page := NewPage(blockSize) 134 | 135 | // Test writing at the end of buffer 136 | lastValidOffset := blockSize - 8 // space for one int64, this test assumes that it runs on a 64-bit machine. 137 | page.SetInt(lastValidOffset, 42) 138 | got := page.GetInt(lastValidOffset) 139 | assert.Equal(42, got, "Value at buffer boundary should match") 140 | }) 141 | 142 | t.Run("LargeData", func(t *testing.T) { 143 | assert := assert.New(t) 144 | blockSize := 1000 145 | page := NewPage(blockSize) 146 | 147 | // Create large string 148 | largeString := make([]byte, 500) 149 | for i := range largeString { 150 | largeString[i] = byte('A' + (i % 26)) 151 | } 152 | 153 | err := page.SetString(0, string(largeString)) 154 | assert.NoError(err, "Setting large string should not fail") 155 | 156 | got, err := page.GetString(0) 157 | assert.NoError(err, "Getting large string should not fail") 158 | assert.Equal(string(largeString), got, "Large string content should match") 159 | }) 160 | } 161 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/JyotinderSingh/dropdb 2 | 3 | go 1.23.2 4 | 5 | require ( 6 | github.com/davecgh/go-spew v1.1.1 // indirect 7 | github.com/pmezard/go-difflib v1.0.0 // indirect 8 | github.com/stretchr/testify v1.9.0 // indirect 9 | gopkg.in/yaml.v3 v3.0.1 // indirect 10 | ) 11 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | 
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 4 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 5 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= 6 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 7 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 8 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 9 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 10 | -------------------------------------------------------------------------------- /index/btree/directory_entry.go: -------------------------------------------------------------------------------- 1 | package btree 2 | 3 | type DirectoryEntry struct { 4 | dataValue any 5 | blockNumber int 6 | } 7 | 8 | // NewDirectoryEntry creates a new DirectoryEntry with the specified data value and block number. 9 | func NewDirectoryEntry(dataValue any, blockNumber int) *DirectoryEntry { 10 | return &DirectoryEntry{dataValue, blockNumber} 11 | } 12 | 13 | // DataValue returns the data value of this directory entry. 14 | func (de *DirectoryEntry) DataValue() any { 15 | return de.dataValue 16 | } 17 | 18 | // BlockNumber returns the block number of this directory entry. 
19 | func (de *DirectoryEntry) BlockNumber() int { 20 | return de.blockNumber 21 | } 22 | -------------------------------------------------------------------------------- /index/common/constants.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | const ( 4 | BlockField = "block" 5 | IDField = "id" 6 | DataValueField = "data_value" 7 | ) 8 | -------------------------------------------------------------------------------- /index/hash/index.go: -------------------------------------------------------------------------------- 1 | package hash 2 | 3 | import ( 4 | "fmt" 5 | "github.com/JyotinderSingh/dropdb/index" 6 | "github.com/JyotinderSingh/dropdb/index/common" 7 | "github.com/JyotinderSingh/dropdb/record" 8 | "github.com/JyotinderSingh/dropdb/table" 9 | "github.com/JyotinderSingh/dropdb/tx" 10 | "github.com/JyotinderSingh/dropdb/utils" 11 | ) 12 | 13 | const ( 14 | numBuckets = 100 15 | ) 16 | 17 | // ensure index interface is implemented 18 | var _ index.Index = (*Index)(nil) 19 | 20 | type Index struct { 21 | transaction *tx.Transaction 22 | indexName string 23 | layout *record.Layout 24 | searchKey any 25 | tableScan *table.Scan 26 | } 27 | 28 | // NewIndex opens a hash index for the specified index. 29 | func NewIndex(transaction *tx.Transaction, indexName string, layout *record.Layout) index.Index { 30 | return &Index{ 31 | transaction: transaction, 32 | indexName: indexName, 33 | layout: layout, 34 | searchKey: nil, 35 | tableScan: nil, 36 | } 37 | } 38 | 39 | // BeforeFirst positions the index before the first index record having the specified search key. 40 | // The method hashes the search key to determine the bucket, 41 | // and then opens a table scan on the file corresponding to that bucket. 42 | // The table scan for the previous bucket (if any) is closed. 
43 | func (idx *Index) BeforeFirst(searchKey any) error { 44 | idx.Close() 45 | idx.searchKey = searchKey 46 | hashValue, err := utils.HashValue(searchKey) 47 | if err != nil { 48 | return err 49 | } 50 | bucket := hashValue % numBuckets 51 | tableName := fmt.Sprintf("%s-%d", idx.indexName, bucket) 52 | idx.tableScan, err = table.NewTableScan(idx.transaction, tableName, idx.layout) 53 | return err 54 | } 55 | 56 | // Next moves to the next index record having the search key. 57 | // The method loops through the table scan for the bucket, looking for a matching record, 58 | // and returns false if there are no more such records. 59 | func (idx *Index) Next() (bool, error) { 60 | for { 61 | hasNext, err := idx.tableScan.Next() 62 | if err != nil || !hasNext { 63 | return false, err 64 | } 65 | 66 | currentValue, err := idx.tableScan.GetVal(common.DataValueField) 67 | if err != nil { 68 | return false, err 69 | } 70 | if currentValue == idx.searchKey { 71 | return true, nil 72 | } 73 | } 74 | } 75 | 76 | // GetDataRecordID retrieves the data record ID from the current record in the table scan for the bucket. 77 | func (idx *Index) GetDataRecordID() (*record.ID, error) { 78 | blockNumber, err := idx.tableScan.GetInt(common.BlockField) 79 | if err != nil { 80 | return nil, err 81 | } 82 | id, err := idx.tableScan.GetInt(common.IDField) 83 | if err != nil { 84 | return nil, err 85 | } 86 | 87 | return record.NewID(blockNumber, id), nil 88 | } 89 | 90 | // Insert inserts a new record into the table scan for the bucket. 
91 | func (idx *Index) Insert(dataValue any, dataRecordID *record.ID) error { 92 | if err := idx.BeforeFirst(dataValue); err != nil { 93 | return err 94 | } 95 | 96 | if err := idx.tableScan.Insert(); err != nil { 97 | return err 98 | } 99 | if err := idx.tableScan.SetInt(common.BlockField, dataRecordID.BlockNumber()); err != nil { 100 | return err 101 | } 102 | if err := idx.tableScan.SetInt(common.IDField, dataRecordID.Slot()); err != nil { 103 | return err 104 | } 105 | return idx.tableScan.SetVal(common.DataValueField, dataValue) 106 | } 107 | 108 | // Delete deletes the specified record from the table scan for the bucket. 109 | // The method starts at the beginning of the scan, and loops through the 110 | // records until the specified record is found. If the record is found, it is deleted. 111 | // If the record is not found, the method does nothing and does not return an error. 112 | func (idx *Index) Delete(dataValue any, dataRecordID *record.ID) error { 113 | if err := idx.BeforeFirst(dataValue); err != nil { 114 | return err 115 | } 116 | 117 | for { 118 | hasNext, err := idx.tableScan.Next() 119 | if err != nil { 120 | return err 121 | } 122 | if !hasNext { 123 | break 124 | } 125 | 126 | currentRecordID, err := idx.GetDataRecordID() 127 | if err != nil { 128 | return err 129 | } 130 | 131 | if currentRecordID.Equals(dataRecordID) { 132 | return idx.tableScan.Delete() 133 | } 134 | } 135 | 136 | return nil 137 | } 138 | 139 | // Close closes the index by closing the current table scan. 140 | func (idx *Index) Close() { 141 | if idx.tableScan != nil { 142 | idx.tableScan.Close() 143 | idx.tableScan = nil 144 | } 145 | } 146 | 147 | // SearchCost returns the cost of searching an index file having 148 | // the specified number of blocks. 149 | // the method assumes that all buckets are about the same size, 150 | // so the cost is simply the size of the bucket. 
151 | func SearchCost(numBlocks, recordsPerBucket int) int { 152 | return numBlocks / numBuckets 153 | } 154 | -------------------------------------------------------------------------------- /index/index.go: -------------------------------------------------------------------------------- 1 | package index 2 | 3 | import "github.com/JyotinderSingh/dropdb/record" 4 | 5 | type Index interface { 6 | // BeforeFirst positions the index before the 7 | // first record having the specified search key. 8 | BeforeFirst(searchKey any) error 9 | 10 | // Next moves the index to the next record having the search key specified in the BeforeFirst method. 11 | // Returns false if there are no more such index records. 12 | Next() (bool, error) 13 | 14 | // GetDataRecordID returns the data record ID stored in the current index record. 15 | GetDataRecordID() (*record.ID, error) 16 | 17 | // Insert inserts a new index record having the specified dataValue and dataRecordID values. 18 | Insert(dataValue any, dataRecordID *record.ID) error 19 | 20 | // Delete deletes the index record having the specified dataValue and dataRecordID values. 21 | Delete(dataValue any, dataRecordID *record.ID) error 22 | 23 | // Close closes the index. 24 | Close() 25 | } 26 | -------------------------------------------------------------------------------- /log/iterator.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "github.com/JyotinderSingh/dropdb/file" 7 | "github.com/JyotinderSingh/dropdb/types" 8 | ) 9 | 10 | // Iterator provides the ability to move through the records of the log file in reverse order. 11 | type Iterator struct { 12 | fileManager *file.Manager 13 | block *file.BlockId 14 | page *file.Page 15 | currentPosition int 16 | boundary int 17 | } 18 | 19 | // NewIterator creates an iterator for the records in the log file, positioned after the last log record. 
20 | func NewIterator(fileManager *file.Manager, block *file.BlockId) (*Iterator, error) { 21 | page := file.NewPage(fileManager.BlockSize()) 22 | iterator := &Iterator{ 23 | fileManager: fileManager, 24 | block: block, 25 | page: page, 26 | } 27 | if err := iterator.moveToBlock(block); err != nil { 28 | return nil, fmt.Errorf("failed to move to block: %v", err) 29 | } 30 | return iterator, nil 31 | } 32 | 33 | // HasNext determines if the current log record is the earliest record in the log file. Returns true if there is an earlier record. 34 | func (it *Iterator) HasNext() bool { 35 | return it.currentPosition < it.fileManager.BlockSize() || it.block.Number() > 0 36 | } 37 | 38 | // Next moves to the next log record in the block. 39 | // If there are no more log records in the block, then move to the previous block and return the log record from there. 40 | // Returns the next earliest log record. 41 | func (it *Iterator) Next() ([]byte, error) { 42 | // Check if there are no more records left in the current block. 43 | if it.currentPosition == it.fileManager.BlockSize() { 44 | // Check if this is the first block. 45 | if it.block.Number() == 0 { 46 | return nil, errors.New("no more log records") 47 | } 48 | 49 | // Move to the previous block in the log file. 50 | it.block = &file.BlockId{File: it.block.Filename(), BlockNumber: it.block.Number() - 1} 51 | if err := it.moveToBlock(it.block); err != nil { 52 | return nil, fmt.Errorf("failed to move to block: %v", err) 53 | } 54 | } 55 | 56 | record := it.page.GetBytes(it.currentPosition) 57 | it.currentPosition += types.IntSize + len(record) // (size of record) + (length of record) 58 | return record, nil 59 | } 60 | 61 | // moveToBlock moves to the specified log block and positions it at the first record in that block (i.e., the most recent one). 
62 | func (it *Iterator) moveToBlock(block *file.BlockId) error { 63 | if err := it.fileManager.Read(block, it.page); err != nil { 64 | return fmt.Errorf("failed to read block: %v", err) 65 | } 66 | 67 | it.boundary = it.page.GetInt(0) 68 | it.currentPosition = it.boundary 69 | return nil 70 | } 71 | -------------------------------------------------------------------------------- /log/manager_test.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "fmt" 5 | "github.com/JyotinderSingh/dropdb/file" 6 | "github.com/stretchr/testify/assert" 7 | "os" 8 | "testing" 9 | ) 10 | 11 | // Helper function to create a new temporary FileMgr 12 | func createTempFileMgr(blocksize int) (*file.Manager, func(), error) { 13 | tmpDir, err := os.MkdirTemp("", "filemgr_test") 14 | if err != nil { 15 | return nil, nil, fmt.Errorf("failed to create temp directory: %v", err) 16 | } 17 | 18 | fm, err := file.NewManager(tmpDir, blocksize) 19 | if err != nil { 20 | os.RemoveAll(tmpDir) 21 | return nil, nil, fmt.Errorf("failed to create FileMgr: %v", err) 22 | } 23 | 24 | cleanup := func() { os.RemoveAll(tmpDir) } 25 | return fm, cleanup, nil 26 | } 27 | 28 | func TestLogMgr_AppendAndIteratorConsistency(t *testing.T) { 29 | assert := assert.New(t) 30 | blockSize := 4096 31 | fm, cleanup, err := createTempFileMgr(blockSize) 32 | defer cleanup() 33 | assert.NoErrorf(err, "Error creating FileMgr: %v", err) 34 | 35 | logfile := "testlog" 36 | lm, err := NewManager(fm, logfile) 37 | assert.NoErrorf(err, "Error creating LogMgr: %v", err) 38 | 39 | // Append and flush multiple records, then verify consistency 40 | recordCount := 100 41 | records := make([][]byte, recordCount) 42 | for i := 0; i < recordCount; i++ { 43 | records[i] = []byte(fmt.Sprintf("log record %d", i+1)) 44 | _, err := lm.Append(records[i]) 45 | assert.NoErrorf(err, "Error appending record %d: %v", i+1, err) 46 | } 47 | 48 | // Verify with iterator in reverse 
order 49 | iterator, err := lm.Iterator() 50 | assert.NoErrorf(err, "Error creating log iterator: %v", err) 51 | 52 | for i := recordCount - 1; i >= 0; i-- { 53 | assert.Truef(iterator.HasNext(), "Expected more records, but iterator has none") 54 | 55 | rec, err := iterator.Next() 56 | assert.NoErrorf(err, "Error getting next record from iterator: %v", err) 57 | 58 | assert.Equal(rec, records[i]) 59 | } 60 | 61 | assert.Falsef(iterator.HasNext(), "Expected no more records, but iterator has more") 62 | } 63 | -------------------------------------------------------------------------------- /materialize/temp_table.go: -------------------------------------------------------------------------------- 1 | package materialize 2 | 3 | import ( 4 | "fmt" 5 | "github.com/JyotinderSingh/dropdb/record" 6 | "github.com/JyotinderSingh/dropdb/scan" 7 | "github.com/JyotinderSingh/dropdb/table" 8 | "github.com/JyotinderSingh/dropdb/tx" 9 | "sync" 10 | ) 11 | 12 | const tempTablePrefix = "temp" 13 | 14 | // TempTable represents a temporary table not registered in the catalog. 15 | type TempTable struct { 16 | tx *tx.Transaction 17 | tblName string 18 | layout *record.Layout 19 | } 20 | 21 | var ( 22 | nextTableNum = 0 23 | nextTableNumMu sync.Mutex 24 | ) 25 | 26 | // NewTempTable creates a new temporary table with the specified schema and transaction. 27 | func NewTempTable(tx *tx.Transaction, schema *record.Schema) *TempTable { 28 | return &TempTable{ 29 | tx: tx, 30 | tblName: nextTableName(), 31 | layout: record.NewLayout(schema), 32 | } 33 | } 34 | 35 | // Open opens a table scan for the temporary table. 36 | func (tt *TempTable) Open() (scan.UpdateScan, error) { 37 | return table.NewTableScan(tt.tx, tt.tblName, tt.layout) 38 | } 39 | 40 | // TableName returns the name of the temporary table. 41 | func (tt *TempTable) TableName() string { 42 | return tt.tblName 43 | } 44 | 45 | // GetLayout returns the table's metadata (layout). 
46 | func (tt *TempTable) GetLayout() *record.Layout { 47 | return tt.layout 48 | } 49 | 50 | // nextTableName generates a unique name for the next temporary table. 51 | func nextTableName() string { 52 | nextTableNumMu.Lock() 53 | defer nextTableNumMu.Unlock() 54 | nextTableNum++ 55 | return fmt.Sprintf("%s%d", tempTablePrefix, nextTableNum) 56 | } 57 | -------------------------------------------------------------------------------- /metadata/index_info.go: -------------------------------------------------------------------------------- 1 | package metadata 2 | 3 | import ( 4 | "github.com/JyotinderSingh/dropdb/index" 5 | "github.com/JyotinderSingh/dropdb/index/common" 6 | "github.com/JyotinderSingh/dropdb/index/hash" 7 | "github.com/JyotinderSingh/dropdb/record" 8 | "github.com/JyotinderSingh/dropdb/tx" 9 | "github.com/JyotinderSingh/dropdb/types" 10 | ) 11 | 12 | type IndexInfo struct { 13 | indexName string 14 | fieldName string 15 | transaction *tx.Transaction 16 | tableSchema *record.Schema 17 | indexLayout *record.Layout 18 | statInfo *StatInfo 19 | } 20 | 21 | // NewIndexInfo creates an IndexInfo object for the specified index. 22 | func NewIndexInfo(indexName, fieldName string, tableSchema *record.Schema, 23 | transaction *tx.Transaction, statInfo *StatInfo) *IndexInfo { 24 | ii := &IndexInfo{ 25 | indexName: indexName, 26 | fieldName: fieldName, 27 | transaction: transaction, 28 | tableSchema: tableSchema, 29 | statInfo: statInfo, 30 | } 31 | ii.indexLayout = ii.CreateIndexLayout() 32 | return ii 33 | } 34 | 35 | // Open opens the index described by this object. 36 | func (ii *IndexInfo) Open() index.Index { 37 | return hash.NewIndex(ii.transaction, ii.indexName, ii.indexLayout) 38 | //return NewBtreeIndex(ii.transaction, ii.indexName, ii.indexLayout) 39 | } 40 | 41 | // BlocksAccessed estimates the number of block accesses required to 42 | // find all the index records having a particular search key. 
43 | // The method uses the table's metadata to estimate the size of the 44 | // index file and the number of index records per block. 45 | // It then passes this information to the traversalCost method of the 46 | // appropriate index type, which then provides the estimate. 47 | func (ii *IndexInfo) BlocksAccessed() int { 48 | recordsPerBlock := ii.transaction.BlockSize() / ii.indexLayout.SlotSize() 49 | numBlocks := ii.statInfo.RecordsOutput() / recordsPerBlock 50 | return hash.SearchCost(numBlocks, recordsPerBlock) 51 | //return BtreeIndex.SearchCost(numBlocks, recordsPerBlock) 52 | } 53 | 54 | // RecordsOutput returns the estimated number of records having a search key. 55 | // This value is the same as doing a select query; that is, it is the number of records in the table 56 | // divided by the number of distinct values of the indexed field. 57 | func (ii *IndexInfo) RecordsOutput() int { 58 | return ii.statInfo.RecordsOutput() / ii.statInfo.DistinctValues(ii.fieldName) 59 | } 60 | 61 | // DistinctValues returns the number of distinct values for the indexed field 62 | // in the underlying table, or 1 for the indexed field. 63 | func (ii *IndexInfo) DistinctValues(fieldName string) int { 64 | if ii.fieldName == fieldName { 65 | return 1 66 | } 67 | return ii.statInfo.DistinctValues(fieldName) 68 | } 69 | 70 | // CreateIndexLayout returns the layout of the index records. 71 | // The schema consists of the dataRecordID (which is represented as two integers, 72 | // the block number and the record ID) and the dataValue (which is the indexed field). 73 | // Schema information about the indexed field is obtained from the table's schema. 
74 | func (ii *IndexInfo) CreateIndexLayout() *record.Layout { 75 | schema := record.NewSchema() 76 | schema.AddIntField(common.BlockField) 77 | schema.AddIntField(common.IDField) 78 | switch ii.tableSchema.Type(ii.fieldName) { 79 | case types.Integer: 80 | schema.AddIntField(common.DataValueField) 81 | case types.Varchar: 82 | schema.AddStringField(common.DataValueField, ii.tableSchema.Length(ii.fieldName)) 83 | case types.Boolean: 84 | schema.AddBoolField(common.DataValueField) 85 | case types.Long: 86 | schema.AddLongField(common.DataValueField) 87 | case types.Short: 88 | schema.AddShortField(common.DataValueField) 89 | case types.Date: 90 | schema.AddDateField(common.DataValueField) 91 | } 92 | 93 | return record.NewLayout(schema) 94 | } 95 | -------------------------------------------------------------------------------- /metadata/index_info_test.go: -------------------------------------------------------------------------------- 1 | package metadata 2 | 3 | import ( 4 | "github.com/JyotinderSingh/dropdb/buffer" 5 | "github.com/JyotinderSingh/dropdb/file" 6 | "github.com/JyotinderSingh/dropdb/index/common" 7 | "github.com/JyotinderSingh/dropdb/log" 8 | "github.com/JyotinderSingh/dropdb/record" 9 | "github.com/JyotinderSingh/dropdb/tx" 10 | "github.com/JyotinderSingh/dropdb/tx/concurrency" 11 | "github.com/JyotinderSingh/dropdb/types" 12 | "github.com/stretchr/testify/assert" 13 | "github.com/stretchr/testify/require" 14 | "os" 15 | "testing" 16 | ) 17 | 18 | func setupIndexInfoTest(t *testing.T) (*IndexInfo, *tx.Transaction, func()) { 19 | t.Helper() 20 | 21 | dbDir := t.TempDir() 22 | 23 | fm, err := file.NewManager(dbDir, 400) 24 | require.NoError(t, err) 25 | 26 | lm, err := log.NewManager(fm, "logfile") 27 | require.NoError(t, err) 28 | 29 | bm := buffer.NewManager(fm, lm, 8) 30 | 31 | transaction := tx.NewTransaction(fm, lm, bm, concurrency.NewLockTable()) 32 | 33 | tableSchema := record.NewSchema() 34 | tableSchema.AddIntField("block") 35 | 
	tableSchema.AddIntField("id")
	tableSchema.AddStringField("data_value", 20)

	// Fixed fake statistics: 10 blocks, 100 records, and per-field distinct
	// counts that the tests below assert against.
	statInfo := NewStatInfo(10, 100, map[string]int{
		"block":      10,
		"id":         100,
		"data_value": 20,
	})

	indexInfo := NewIndexInfo(
		"test_index",
		"data_value",
		tableSchema,
		transaction,
		statInfo,
	)

	// cleanup commits the transaction and removes the on-disk test database.
	cleanup := func() {
		if err := transaction.Commit(); err != nil {
			t.Error(err)
		}
		if err := os.RemoveAll(dbDir); err != nil {
			t.Error(err)
		}
	}

	return indexInfo, transaction, cleanup
}

// TestIndexInfo_InsertAndValidate inserts entries (including a duplicate key)
// and checks the cost estimates derived from the fake statistics.
func TestIndexInfo_InsertAndValidate(t *testing.T) {
	indexInfo, _, cleanup := setupIndexInfoTest(t)
	defer cleanup()

	idx := indexInfo.Open()

	// Insert records into the index
	err := idx.Insert("key1", record.NewID(1, 1))
	require.NoError(t, err)
	err = idx.Insert("key2", record.NewID(2, 2))
	require.NoError(t, err)
	err = idx.Insert("key1", record.NewID(3, 3)) // Duplicate key with different ID
	require.NoError(t, err)

	// Validate RecordsOutput and DistinctValues
	assert.Equal(t, 100/20, indexInfo.RecordsOutput(), "RecordsOutput mismatch") // numRecords / distinctValues
	assert.Equal(t, 1, indexInfo.DistinctValues("data_value"), "DistinctValues mismatch for indexed field")
	assert.Equal(t, 10, indexInfo.DistinctValues("block"), "DistinctValues mismatch for non-indexed field")
}

// TestIndexInfo_DeleteAndValidate checks that the estimates come from the
// static StatInfo and are unaffected by an insert-then-delete cycle.
func TestIndexInfo_DeleteAndValidate(t *testing.T) {
	indexInfo, _, cleanup := setupIndexInfoTest(t)
	defer cleanup()

	idx := indexInfo.Open()

	// Insert and delete a record
	err := idx.Insert("key1", record.NewID(1, 1))
	require.NoError(t, err)
	err = idx.Delete("key1", record.NewID(1, 1))
	require.NoError(t, err)

	// Verify RecordsOutput and DistinctValues remain consistent
	assert.Equal(t, 100/20, indexInfo.RecordsOutput(), "RecordsOutput mismatch after deletion")
	assert.Equal(t, 1, indexInfo.DistinctValues("data_value"), "DistinctValues mismatch for indexed field after deletion")
}

// TestIndexInfo_CreateIndexLayout verifies the schema of the generated index
// record layout: block, id, and data-value fields, with a varchar data value
// (the indexed field here is a string field).
func TestIndexInfo_CreateIndexLayout(t *testing.T) {
	indexInfo, _, cleanup := setupIndexInfoTest(t)
	defer cleanup()

	layout := indexInfo.CreateIndexLayout()
	require.NotNil(t, layout)

	schema := layout.Schema()
	assert.True(t, schema.HasField(common.BlockField))
	assert.True(t, schema.HasField(common.IDField))
	assert.True(t, schema.HasField(common.DataValueField))
	assert.Equal(t, types.Varchar, schema.Type(common.DataValueField))
}
--------------------------------------------------------------------------------
/metadata/index_manager.go:
--------------------------------------------------------------------------------
package metadata

import (
	"fmt"
	"github.com/JyotinderSingh/dropdb/record"
	"github.com/JyotinderSingh/dropdb/table"
	"github.com/JyotinderSingh/dropdb/tx"
)

const (
	indexCatalogTable = "index_catalog"
	indexNameField    = "index_name"
)

// IndexManager is responsible for managing indexes in the database.
type IndexManager struct {
	layout       *record.Layout // layout of the index catalog table
	tableManager *TableManager  // used to resolve table layouts
	StatManager  *StatManager   // used to fetch table statistics for IndexInfo
}

// NewIndexManager creates a new IndexManager instance.
// This method is called during system startup.
// If the database is new, then the idxCatalog is created.
25 | func NewIndexManager(isNew bool, tableManager *TableManager, statManager *StatManager, transaction *tx.Transaction) (*IndexManager, error) { 26 | if isNew { 27 | schema := record.NewSchema() 28 | schema.AddStringField(indexNameField, maxNameLength) 29 | schema.AddStringField(tableNameField, maxNameLength) 30 | schema.AddStringField(fieldNameField, maxNameLength) 31 | 32 | if err := tableManager.CreateTable(indexCatalogTable, schema, transaction); err != nil { 33 | return nil, err 34 | } 35 | } 36 | 37 | layout, err := tableManager.GetLayout(indexCatalogTable, transaction) 38 | if err != nil { 39 | return nil, err 40 | } 41 | 42 | return &IndexManager{ 43 | layout: layout, 44 | tableManager: tableManager, 45 | StatManager: statManager, 46 | }, nil 47 | } 48 | 49 | // CreateIndex creates a new index of the specified type for the specified field. 50 | // A unique ID is assigned to this index, and its information is stored in the indexCatalogTable. 51 | func (im *IndexManager) CreateIndex(indexName, tableName, fieldName string, transaction *tx.Transaction) error { 52 | tableScan, err := table.NewTableScan(transaction, indexCatalogTable, im.layout) 53 | if err != nil { 54 | return fmt.Errorf("failed to create table scan: %w", err) 55 | } 56 | defer tableScan.Close() 57 | 58 | if err := tableScan.Insert(); err != nil { 59 | return fmt.Errorf("failed to insert into table scan: %w", err) 60 | } 61 | 62 | if err := tableScan.SetString(indexNameField, indexName); err != nil { 63 | return fmt.Errorf("failed to set string: %w", err) 64 | } 65 | 66 | if err := tableScan.SetString(tableNameField, tableName); err != nil { 67 | return fmt.Errorf("failed to set string: %w", err) 68 | } 69 | 70 | if err := tableScan.SetString(fieldNameField, fieldName); err != nil { 71 | return fmt.Errorf("failed to set string: %w", err) 72 | } 73 | 74 | return nil 75 | } 76 | 77 | // GetIndexInfo returns a map containing the index info for all indexes on the specified table. 
78 | func (im *IndexManager) GetIndexInfo(tableName string, transaction *tx.Transaction) (map[string]*IndexInfo, error) { 79 | tableScan, err := table.NewTableScan(transaction, indexCatalogTable, im.layout) 80 | if err != nil { 81 | return nil, err 82 | } 83 | defer tableScan.Close() 84 | 85 | result := make(map[string]*IndexInfo) 86 | 87 | for { 88 | hasNext, err := tableScan.Next() 89 | if err != nil { 90 | return nil, err 91 | } 92 | if !hasNext { 93 | break 94 | } 95 | 96 | currentTableName, err := tableScan.GetString(tableNameField) 97 | if err != nil { 98 | return nil, err 99 | } 100 | if currentTableName != tableName { 101 | continue 102 | } 103 | 104 | var indexName, fieldName string 105 | 106 | if indexName, err = tableScan.GetString(indexNameField); err != nil { 107 | return nil, err 108 | } 109 | if fieldName, err = tableScan.GetString(fieldNameField); err != nil { 110 | return nil, err 111 | } 112 | 113 | tableLayout, err := im.tableManager.GetLayout(tableName, transaction) 114 | if err != nil { 115 | return nil, err 116 | } 117 | 118 | tableStatInfo, err := im.StatManager.GetStatInfo(tableName, tableLayout, transaction) 119 | if err != nil { 120 | return nil, err 121 | } 122 | 123 | indexInfo := NewIndexInfo(indexName, fieldName, tableLayout.Schema(), transaction, tableStatInfo) 124 | result[fieldName] = indexInfo 125 | } 126 | 127 | return result, nil 128 | } 129 | -------------------------------------------------------------------------------- /metadata/index_manager_test.go: -------------------------------------------------------------------------------- 1 | package metadata 2 | 3 | import ( 4 | "github.com/JyotinderSingh/dropdb/table" 5 | "testing" 6 | 7 | "github.com/JyotinderSingh/dropdb/record" 8 | "github.com/JyotinderSingh/dropdb/tx" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func setupIndexManagerTest(t *testing.T) (*TableManager, *IndexManager, *tx.Transaction, func()) { 14 | 
t.Helper() 15 | 16 | tm, txn, cleanup := setupTestMetadata(400, t) 17 | sm, err := NewStatManager(tm, txn, 100) 18 | require.NoError(t, err) 19 | indexManager, err := NewIndexManager(true, tm, sm, txn) 20 | require.NoError(t, err) 21 | return tm, indexManager, txn, cleanup 22 | } 23 | 24 | func TestIndexManager_CreateIndex(t *testing.T) { 25 | tm, indexManager, txn, cleanup := setupIndexManagerTest(t) 26 | defer cleanup() 27 | 28 | // Define schema and create a table 29 | schema := record.NewSchema() 30 | schema.AddIntField("id") 31 | schema.AddStringField("name", 20) 32 | err := tm.CreateTable("test_table", schema, txn) 33 | require.NoError(t, err) 34 | 35 | // Create an index on the "id" field 36 | err = indexManager.CreateIndex("test_index", "test_table", "id", txn) 37 | require.NoError(t, err) 38 | 39 | // Verify index metadata in index_catalog 40 | indexCatalogLayout := indexManager.layout 41 | ts, err := table.NewTableScan(txn, indexCatalogTable, indexCatalogLayout) 42 | require.NoError(t, err) 43 | defer ts.Close() 44 | 45 | err = ts.BeforeFirst() 46 | require.NoError(t, err) 47 | 48 | found := false 49 | for { 50 | hasNext, err := ts.Next() 51 | require.NoError(t, err) 52 | if !hasNext { 53 | break 54 | } 55 | 56 | indexName, err := ts.GetString(indexNameField) 57 | require.NoError(t, err) 58 | 59 | if indexName != "test_index" { 60 | continue 61 | } 62 | 63 | tableName, err := ts.GetString("table_name") 64 | require.NoError(t, err) 65 | assert.Equal(t, "test_table", tableName, "Table name mismatch in index_catalog") 66 | 67 | fieldName, err := ts.GetString("field_name") 68 | require.NoError(t, err) 69 | assert.Equal(t, "id", fieldName, "Field name mismatch in index_catalog") 70 | 71 | found = true 72 | break 73 | } 74 | 75 | assert.True(t, found, "Index not found in index_catalog") 76 | } 77 | 78 | func TestIndexManager_GetIndexInfo(t *testing.T) { 79 | tm, indexManager, txn, cleanup := setupIndexManagerTest(t) 80 | defer cleanup() 81 | 82 | // Define 
 schema and create a table
	schema := record.NewSchema()
	schema.AddIntField("id")
	err := tm.CreateTable("test_table", schema, txn)
	require.NoError(t, err)

	// Create an index on the "id" field
	err = indexManager.CreateIndex("test_index", "test_table", "id", txn)
	require.NoError(t, err)

	// Retrieve index info
	indexInfos, err := indexManager.GetIndexInfo("test_table", txn)
	require.NoError(t, err)
	assert.Contains(t, indexInfos, "id", "IndexInfo for 'id' field not found")

	indexInfo := indexInfos["id"]
	assert.Equal(t, "test_index", indexInfo.indexName)
	assert.Equal(t, "id", indexInfo.fieldName)

	// Open the index and perform operations
	idx := indexInfo.Open()
	err = idx.Insert(1234, record.NewID(1, 1))
	require.NoError(t, err)
	err = idx.BeforeFirst(1234)
	require.NoError(t, err)
	hasNext, err := idx.Next()
	require.NoError(t, err)
	assert.True(t, hasNext)
}
--------------------------------------------------------------------------------
/metadata/metadata_manager.go:
--------------------------------------------------------------------------------
package metadata

import (
	"github.com/JyotinderSingh/dropdb/record"
	"github.com/JyotinderSingh/dropdb/tx"
)

// Manager is the facade over all metadata sub-managers: tables, views,
// statistics, and indexes. It simply delegates each call to the
// corresponding sub-manager.
type Manager struct {
	tableManager *TableManager
	viewManager  *ViewManager
	statManager  *StatManager
	indexManager *IndexManager
}

// NewManager constructs all four sub-managers in dependency order
// (tables first, since the others store their catalogs in tables).
// If isNew is true, each sub-manager creates its catalog tables.
func NewManager(isNew bool, transaction *tx.Transaction) (*Manager, error) {
	m := &Manager{}

	var err error
	if m.tableManager, err = NewTableManager(isNew, transaction); err != nil {
		return nil, err
	}
	if m.viewManager, err = NewViewManager(isNew, m.tableManager, transaction); err != nil {
		return nil, err
	}
	if m.statManager, err = NewStatManager(m.tableManager, transaction, 0); err != nil {
		return nil, err
	}
	if m.indexManager, err = NewIndexManager(isNew, m.tableManager, m.statManager, transaction); err != nil {
		return nil, err
	}

	return m, nil
}

// CreateTable creates a new table having the specified name and schema.
func (m *Manager) CreateTable(tableName string, schema *record.Schema, transaction *tx.Transaction) error {
	return m.tableManager.CreateTable(tableName, schema, transaction)
}

// GetLayout returns the layout of the specified table from the catalog.
func (m *Manager) GetLayout(tableName string, transaction *tx.Transaction) (*record.Layout, error) {
	return m.tableManager.GetLayout(tableName, transaction)
}

// CreateView creates a view.
func (m *Manager) CreateView(viewName, viewDefinition string, transaction *tx.Transaction) error {
	return m.viewManager.CreateView(viewName, viewDefinition, transaction)
}

// GetViewDefinition returns the definition of the specified view.
func (m *Manager) GetViewDefinition(viewName string, transaction *tx.Transaction) (string, error) {
	return m.viewManager.GetViewDefinition(viewName, transaction)
}

// CreateIndex records a new index for the specified field of the specified
// table; its information is stored in the indexCatalogTable.
func (m *Manager) CreateIndex(indexName, tableName, fieldName string, transaction *tx.Transaction) error {
	return m.indexManager.CreateIndex(indexName, tableName, fieldName, transaction)
}

// GetIndexInfo returns a map containing the index info for all indexes on the specified table.
func (m *Manager) GetIndexInfo(tableName string, transaction *tx.Transaction) (map[string]*IndexInfo, error) {
	return m.indexManager.GetIndexInfo(tableName, transaction)
}

// GetStatInfo returns statistical information about the specified table.
// It refreshes statistics periodically based on the refreshLimit.
func (m *Manager) GetStatInfo(tableName string, layout *record.Layout, transaction *tx.Transaction) (*StatInfo, error) {
	return m.statManager.GetStatInfo(tableName, layout, transaction)
}
--------------------------------------------------------------------------------
/metadata/stat_info.go:
--------------------------------------------------------------------------------
package metadata

// StatInfo is an immutable snapshot of a table's statistics: its block
// count, record count, and per-field distinct-value counts. The planner
// uses these numbers as cost estimates, not exact values.
type StatInfo struct {
	numBlocks      int            // estimated number of blocks in the table
	numRecords     int            // estimated number of records in the table
	distinctValues map[string]int // field name -> estimated distinct-value count
}

// NewStatInfo creates a new StatInfo object holding the given block count,
// record count, and pre-computed distinct-value counts (it does not compute
// anything itself; see StatManager.calcTableStats for the computation).
func NewStatInfo(numBlocks, numRecords int, distinctValues map[string]int) *StatInfo {
	return &StatInfo{
		numBlocks:      numBlocks,
		numRecords:     numRecords,
		distinctValues: distinctValues,
	}
}

// BlocksAccessed returns the estimated number of blocks in the table.
func (si *StatInfo) BlocksAccessed() int {
	return si.numBlocks
}

// RecordsOutput returns the estimated number of records in the table.
func (si *StatInfo) RecordsOutput() int {
	return si.numRecords
}

// DistinctValues returns the estimated number of distinct values for a given field in the table.
// Returns -1 if the field is not found.
30 | func (si *StatInfo) DistinctValues(fieldName string) int { 31 | if val, ok := si.distinctValues[fieldName]; ok { 32 | return val 33 | } 34 | return -1 // Default to -1 if the field is not found 35 | } 36 | -------------------------------------------------------------------------------- /metadata/stat_manager.go: -------------------------------------------------------------------------------- 1 | package metadata 2 | 3 | import ( 4 | "github.com/JyotinderSingh/dropdb/record" 5 | "github.com/JyotinderSingh/dropdb/table" 6 | "github.com/JyotinderSingh/dropdb/tx" 7 | "sync" 8 | ) 9 | 10 | type StatManager struct { 11 | tableManager *TableManager 12 | tableStats map[string]*StatInfo 13 | numCalls int 14 | mu sync.Mutex 15 | refreshLimit int 16 | } 17 | 18 | // NewStatManager creates a new StatManager instance, initializing statistics by scanning the entire database. 19 | func NewStatManager(tableManager *TableManager, transaction *tx.Transaction, refreshLimit int) (*StatManager, error) { 20 | statMgr := &StatManager{ 21 | tableManager: tableManager, 22 | tableStats: make(map[string]*StatInfo), 23 | refreshLimit: refreshLimit, 24 | } 25 | if err := statMgr.RefreshStatistics(transaction); err != nil { 26 | return nil, err 27 | } 28 | return statMgr, nil 29 | } 30 | 31 | // GetStatInfo returns statistical information about the specified table. 32 | // It refreshes statistics periodically based on the refreshLimit. 
33 | func (sm *StatManager) GetStatInfo(tableName string, layout *record.Layout, transaction *tx.Transaction) (*StatInfo, error) { 34 | sm.mu.Lock() 35 | defer sm.mu.Unlock() 36 | 37 | sm.numCalls++ 38 | if sm.numCalls > sm.refreshLimit { 39 | // Call the internal refresh that expects the lock to already be held 40 | if err := sm._refreshStatistics(transaction); err != nil { 41 | return nil, err 42 | } 43 | } 44 | 45 | if statInfo, exists := sm.tableStats[tableName]; exists { 46 | return statInfo, nil 47 | } 48 | 49 | // Calculate statistics if not already available 50 | statInfo, err := sm.calcTableStats(tableName, layout, transaction) 51 | if err != nil { 52 | return nil, err 53 | } 54 | sm.tableStats[tableName] = statInfo 55 | return statInfo, nil 56 | } 57 | 58 | // RefreshStatistics publicly forces a refresh of all table statistics. 59 | // This is useful if something external triggers a refresh. 60 | func (sm *StatManager) RefreshStatistics(transaction *tx.Transaction) error { 61 | sm.mu.Lock() 62 | defer sm.mu.Unlock() 63 | return sm._refreshStatistics(transaction) 64 | } 65 | 66 | // _refreshStatistics recalculates statistics for all tables in the database. 67 | // It assumes the caller already holds sm.mu. 68 | func (sm *StatManager) _refreshStatistics(transaction *tx.Transaction) error { 69 | // Since the caller already holds the lock, do NOT lock here. 
70 | 71 | sm.tableStats = make(map[string]*StatInfo) 72 | sm.numCalls = 0 73 | 74 | tableCatalogLayout, err := sm.tableManager.GetLayout(tableCatalogTable, transaction) 75 | if err != nil { 76 | return err 77 | } 78 | tableCatalogTableScan, err := table.NewTableScan(transaction, tableCatalogTable, tableCatalogLayout) 79 | if err != nil { 80 | return err 81 | } 82 | defer tableCatalogTableScan.Close() 83 | 84 | for { 85 | hasNext, err := tableCatalogTableScan.Next() 86 | if err != nil { 87 | return err 88 | } 89 | if !hasNext { 90 | break 91 | } 92 | 93 | tblName, err := tableCatalogTableScan.GetString(tableNameField) 94 | if err != nil { 95 | return err 96 | } 97 | 98 | layout, err := sm.tableManager.GetLayout(tblName, transaction) 99 | if err != nil { 100 | return err 101 | } 102 | 103 | statInfo, err := sm.calcTableStats(tblName, layout, transaction) 104 | if err != nil { 105 | return err 106 | } 107 | sm.tableStats[tblName] = statInfo 108 | } 109 | 110 | return nil 111 | } 112 | 113 | // calcTableStats calculates the number of records, blocks, and distinct values for a specific table. 
114 | func (sm *StatManager) calcTableStats(tableName string, layout *record.Layout, transaction *tx.Transaction) (*StatInfo, error) { 115 | numRecords := 0 116 | numBlocks := 0 117 | distinctValues := make(map[string]map[any]interface{}) // field name -> distinct values 118 | 119 | for _, field := range layout.Schema().Fields() { 120 | distinctValues[field] = make(map[any]interface{}) 121 | } 122 | 123 | ts, err := table.NewTableScan(transaction, tableName, layout) 124 | if err != nil { 125 | return nil, err 126 | } 127 | defer ts.Close() 128 | 129 | for { 130 | hasNext, err := ts.Next() 131 | if err != nil { 132 | return nil, err 133 | } 134 | if !hasNext { 135 | break 136 | } 137 | 138 | numRecords++ 139 | rid := ts.GetRecordID() 140 | if rid.BlockNumber() >= numBlocks { 141 | numBlocks = rid.BlockNumber() + 1 142 | } 143 | 144 | // Track distinct values for each field 145 | for _, field := range layout.Schema().Fields() { 146 | val, err := ts.GetVal(field) 147 | if err != nil { 148 | return nil, err 149 | } 150 | distinctValues[field][val] = struct{}{} 151 | } 152 | } 153 | 154 | distinctCounts := make(map[string]int) 155 | for field, values := range distinctValues { 156 | distinctCounts[field] = len(values) 157 | } 158 | 159 | return NewStatInfo(numBlocks, numRecords, distinctCounts), nil 160 | } 161 | -------------------------------------------------------------------------------- /metadata/stat_manager_test.go: -------------------------------------------------------------------------------- 1 | package metadata 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/JyotinderSingh/dropdb/record" 7 | "github.com/JyotinderSingh/dropdb/table" 8 | "github.com/JyotinderSingh/dropdb/tx" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | // setupStatMgr initializes a StatManager for testing. 
func setupStatMgr(t *testing.T, refreshLimit int) (*StatManager, *TableManager, *tx.Transaction, func()) {
	tm, txn, cleanup := setupTestMetadata(400, t)
	statMgr, err := NewStatManager(tm, txn, refreshLimit)
	require.NoError(t, err)
	return statMgr, tm, txn, cleanup
}

// TestStatMgr_GetStatInfo validates record/block counts and per-field
// distinct-value counts computed over a freshly populated table.
func TestStatMgr_GetStatInfo(t *testing.T) {
	statMgr, tableManager, txn, cleanup := setupStatMgr(t, 100)
	defer cleanup()

	// Create a schema and a table
	schema := record.NewSchema()
	schema.AddIntField("id")
	schema.AddStringField("name", 20)
	err := tableManager.CreateTable("test_table", schema, txn)
	require.NoError(t, err)

	// Insert some data
	layout, err := tableManager.GetLayout("test_table", txn)
	require.NoError(t, err)
	ts, err := table.NewTableScan(txn, "test_table", layout)
	require.NoError(t, err)
	defer ts.Close()

	for i := 1; i <= 10; i++ {
		require.NoError(t, ts.Insert())
		require.NoError(t, ts.SetInt("id", i))
		require.NoError(t, ts.SetString("name", "name"+string(rune(i))))
	}

	// Retrieve statistics
	stats, err := statMgr.GetStatInfo("test_table", layout, txn)
	require.NoError(t, err)

	// Validate statistics
	assert.Equal(t, 10, stats.RecordsOutput(), "Number of records mismatch")
	assert.Equal(t, 4, stats.BlocksAccessed(), "Number of blocks mismatch")
	assert.Equal(t, 10, stats.DistinctValues("id"), "Distinct values for 'id' mismatch")
	assert.Equal(t, 10, stats.DistinctValues("name"), "Distinct values for 'name' mismatch")
}

// TestStatMgr_RefreshStatistics drives GetStatInfo past a small refreshLimit
// (2) to exercise the automatic refresh path.
func TestStatMgr_RefreshStatistics(t *testing.T) {
	statMgr, tableManager, txn, cleanup := setupStatMgr(t, 2)
	defer cleanup()

	// Create a schema and a table
	schema := record.NewSchema()
	schema.AddIntField("id")
	schema.AddStringField("name", 20)
	err := tableManager.CreateTable("test_table", schema, txn)
	require.NoError(t, err)

	// Insert some data
	layout, err := tableManager.GetLayout("test_table", txn)
	require.NoError(t, err)
	ts, err := table.NewTableScan(txn, "test_table", layout)
	require.NoError(t, err)
	defer ts.Close()

	for i := 1; i <= 5; i++ {
		require.NoError(t, ts.Insert())
		require.NoError(t, ts.SetInt("id", i))
		require.NoError(t, ts.SetString("name", "name"+string(rune(i))))
	}

	// Call GetStatInfo twice to trigger a refresh
	for i := 0; i < 3; i++ {
		_, err := statMgr.GetStatInfo("test_table", layout, txn)
		require.NoError(t, err)
	}

	// Confirm that statistics are refreshed
	stats, err := statMgr.GetStatInfo("test_table", layout, txn)
	require.NoError(t, err)
	assert.Equal(t, 5, stats.RecordsOutput(), "Number of records mismatch after refresh")
}
--------------------------------------------------------------------------------
/metadata/view_manager.go:
--------------------------------------------------------------------------------
package metadata

import (
	"github.com/JyotinderSingh/dropdb/record"
	"github.com/JyotinderSingh/dropdb/table"
	"github.com/JyotinderSingh/dropdb/tx"
)

const (
	maxViewDefinitionLength = 100
	viewNameField           = "view_name"
	viewDefinitionField     = "view_definition"
	viewCatalogTable        = "view_catalog"
)

// ViewManager stores and retrieves view definitions via the view catalog table.
type ViewManager struct {
	tableManager *TableManager
}

// NewViewManager creates a new ViewManager.
21 | func NewViewManager(isNew bool, tableManager *TableManager, tx *tx.Transaction) (*ViewManager, error) { 22 | vm := &ViewManager{tableManager: tableManager} 23 | 24 | if isNew { 25 | schema := record.NewSchema() 26 | schema.AddStringField(viewNameField, maxNameLength) 27 | schema.AddStringField(viewDefinitionField, maxViewDefinitionLength) 28 | if err := vm.tableManager.CreateTable(viewCatalogTable, schema, tx); err != nil { 29 | return nil, err 30 | } 31 | } 32 | 33 | return vm, nil 34 | } 35 | 36 | // CreateView creates a view. 37 | func (vm *ViewManager) CreateView(viewName, viewDefinition string, tx *tx.Transaction) error { 38 | layout, err := vm.tableManager.GetLayout(viewCatalogTable, tx) 39 | if err != nil { 40 | return err 41 | } 42 | 43 | viewCatalogTableScan, err := table.NewTableScan(tx, viewCatalogTable, layout) 44 | if err != nil { 45 | return err 46 | } 47 | defer viewCatalogTableScan.Close() 48 | 49 | if err := viewCatalogTableScan.Insert(); err != nil { 50 | return err 51 | } 52 | if err := viewCatalogTableScan.SetString(viewNameField, viewName); err != nil { 53 | return err 54 | } 55 | return viewCatalogTableScan.SetString(viewDefinitionField, viewDefinition) 56 | } 57 | 58 | // GetViewDefinition returns the definition of the specified view. Returns an empty string if the view does not exist. 
59 | func (vm *ViewManager) GetViewDefinition(viewName string, tx *tx.Transaction) (string, error) { 60 | layout, err := vm.tableManager.GetLayout(viewCatalogTable, tx) 61 | if err != nil { 62 | return "", err 63 | } 64 | 65 | viewCatalogTableScan, err := table.NewTableScan(tx, viewCatalogTable, layout) 66 | if err != nil { 67 | return "", err 68 | } 69 | defer viewCatalogTableScan.Close() 70 | 71 | for { 72 | hasNext, err := viewCatalogTableScan.Next() 73 | if err != nil { 74 | return "", err 75 | } 76 | if !hasNext { 77 | break 78 | } 79 | 80 | name, err := viewCatalogTableScan.GetString(viewNameField) 81 | if err != nil { 82 | return "", err 83 | } 84 | 85 | if name == viewName { 86 | definition, err := viewCatalogTableScan.GetString(viewDefinitionField) 87 | if err != nil { 88 | return "", err 89 | } 90 | 91 | return definition, nil 92 | } 93 | } 94 | 95 | return "", nil 96 | } 97 | -------------------------------------------------------------------------------- /metadata/view_manager_test.go: -------------------------------------------------------------------------------- 1 | package metadata 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/JyotinderSingh/dropdb/table" 7 | "github.com/JyotinderSingh/dropdb/tx" 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func setupTestViewManager(t *testing.T) (*ViewManager, *tx.Transaction, func()) { 13 | tm, txn, cleanup := setupTestMetadata(800, t) // Assume setupTestMetadata initializes a TableManager and Transaction 14 | viewManager, err := NewViewManager(true, tm, txn) 15 | require.NoError(t, err) 16 | return viewManager, txn, cleanup 17 | } 18 | 19 | func TestViewManager_CreateView(t *testing.T) { 20 | vm, txn, cleanup := setupTestViewManager(t) 21 | defer cleanup() 22 | 23 | // Create a view 24 | viewName := "test_view" 25 | viewDefinition := "SELECT * FROM test_table" 26 | err := vm.CreateView(viewName, viewDefinition, txn) 27 | require.NoError(t, err) 28 | 29 | 
	// Validate the view exists in the view catalog
	layout, err := vm.tableManager.GetLayout(viewCatalogTable, txn)
	require.NoError(t, err)

	viewCatalogScan, err := table.NewTableScan(txn, viewCatalogTable, layout)
	require.NoError(t, err)
	defer viewCatalogScan.Close()

	err = viewCatalogScan.BeforeFirst()
	require.NoError(t, err)

	found := false
	for {
		hasNext, err := viewCatalogScan.Next()
		require.NoError(t, err)
		if !hasNext {
			break
		}

		name, err := viewCatalogScan.GetString(viewNameField)
		require.NoError(t, err)

		if name == viewName {
			definition, err := viewCatalogScan.GetString(viewDefinitionField)
			require.NoError(t, err)
			assert.Equal(t, viewDefinition, definition, "View definition mismatch")
			found = true
			break
		}
	}

	assert.True(t, found, "View not found in view catalog")
}

// TestViewManager_GetViewDefinition creates a view and reads its definition
// back through the public accessor.
func TestViewManager_GetViewDefinition(t *testing.T) {
	vm, txn, cleanup := setupTestViewManager(t)
	defer cleanup()

	// Create a view
	viewName := "test_view"
	viewDefinition := "SELECT * FROM test_table"
	err := vm.CreateView(viewName, viewDefinition, txn)
	require.NoError(t, err)

	// Retrieve the view definition
	retrievedDefinition, err := vm.GetViewDefinition(viewName, txn)
	require.NoError(t, err)
	assert.Equal(t, viewDefinition, retrievedDefinition, "Retrieved view definition mismatch")
}
--------------------------------------------------------------------------------
/parse/create_index_data.go:
--------------------------------------------------------------------------------
package parse

// CreateIndexData holds the data parsed from a CREATE INDEX statement.
type CreateIndexData struct {
	indexName string
	tableName string
	fieldName string
}

// NewCreateIndexData saves the index, table, and field names.
func NewCreateIndexData(indexName, tableName, fieldName string) *CreateIndexData {
	return &CreateIndexData{
		indexName: indexName,
		tableName: tableName,
		fieldName: fieldName,
	}
}

// IndexName returns the name of the index to create.
func (cid *CreateIndexData) IndexName() string {
	return cid.indexName
}

// TableName returns the name of the table being indexed.
func (cid *CreateIndexData) TableName() string {
	return cid.tableName
}

// FieldName returns the name of the indexed field.
func (cid *CreateIndexData) FieldName() string {
	return cid.fieldName
}
--------------------------------------------------------------------------------
/parse/create_table_data.go:
--------------------------------------------------------------------------------
package parse

import "github.com/JyotinderSingh/dropdb/record"

// CreateTableData holds the data parsed from a CREATE TABLE statement.
type CreateTableData struct {
	tableName string
	schema    *record.Schema
}

// NewCreateTableData saves the table name and its schema.
func NewCreateTableData(tableName string, sch *record.Schema) *CreateTableData {
	return &CreateTableData{
		tableName: tableName,
		schema:    sch,
	}
}

// TableName returns the name of the new table.
func (ctd *CreateTableData) TableName() string {
	return ctd.tableName
}

// NewSchema returns the schema of the new table.
func (ctd *CreateTableData) NewSchema() *record.Schema {
	return ctd.schema
}
--------------------------------------------------------------------------------
/parse/create_view_data.go:
--------------------------------------------------------------------------------
package parse

// CreateViewData holds the data parsed from a CREATE VIEW statement.
type CreateViewData struct {
	viewName  string
	queryData *QueryData
}

// NewCreateViewData saves the view name and its defining query.
func NewCreateViewData(viewName string, queryData *QueryData) *CreateViewData {
	return &CreateViewData{
		viewName:  viewName,
		queryData: queryData,
	}
}

// ViewName returns the name of the new view.
func (cvd *CreateViewData) ViewName() string {
	return cvd.viewName
}

// ViewDefinition returns the view's defining query, rendered as a string.
func (cvd *CreateViewData) ViewDefinition() string {
	return cvd.queryData.String()
}
--------------------------------------------------------------------------------
/parse/delete_data.go:
--------------------------------------------------------------------------------
package parse

import "github.com/JyotinderSingh/dropdb/query"

type DeleteData struct { 6 | tableName string 7 | predicate *query.Predicate 8 | } 9 | 10 | func NewDeleteData(tableName string, predicate *query.Predicate) *DeleteData { 11 | return &DeleteData{ 12 | tableName: tableName, 13 | predicate: predicate, 14 | } 15 | } 16 | 17 | func (dd *DeleteData) TableName() string { 18 | return dd.tableName 19 | } 20 | 21 | func (dd *DeleteData) Predicate() *query.Predicate { 22 | return dd.predicate 23 | } 24 | -------------------------------------------------------------------------------- /parse/insert_data.go: -------------------------------------------------------------------------------- 1 | package parse 2 | 3 | type InsertData struct { 4 | tableName string 5 | fields []string 6 | values []any 7 | } 8 | 9 | func NewInsertData(tableName string, fields []string, values []any) *InsertData { 10 | return &InsertData{ 11 | tableName: tableName, 12 | fields: fields, 13 | values: values, 14 | } 15 | } 16 | 17 | func (id *InsertData) TableName() string { 18 | return id.tableName 19 | } 20 | 21 | func (id *InsertData) Fields() []string { 22 | return id.fields 23 | } 24 | 25 | func (id *InsertData) Values() []any { 26 | return id.values 27 | } 28 | -------------------------------------------------------------------------------- /parse/modify_data.go: -------------------------------------------------------------------------------- 1 | package parse 2 | 3 | import "github.com/JyotinderSingh/dropdb/query" 4 | 5 | type ModifyData struct { 6 | tableName string 7 | fieldName string 8 | newValue *query.Expression 9 | predicate *query.Predicate 10 | } 11 | 12 | func NewModifyData(tableName, fieldName string, newVal *query.Expression, pred *query.Predicate) *ModifyData { 13 | return &ModifyData{ 14 | tableName: tableName, 15 | fieldName: fieldName, 16 | newValue: newVal, 17 | predicate: pred, 18 | } 19 | } 20 | 21 | func (md *ModifyData) TableName() string { 22 | return md.tableName 23 | } 24 | 25 | func (md *ModifyData) TargetField() string { 
26 | return md.fieldName 27 | } 28 | 29 | func (md *ModifyData) NewValue() *query.Expression { 30 | return md.newValue 31 | } 32 | 33 | func (md *ModifyData) Predicate() *query.Predicate { 34 | return md.predicate 35 | } 36 | -------------------------------------------------------------------------------- /parse/predicate_parser.go: -------------------------------------------------------------------------------- 1 | package parse 2 | 3 | type PredParser struct { 4 | lex *Lexer 5 | } 6 | 7 | func NewPredParser(s string) *PredParser { 8 | return &PredParser{lex: NewLexer(s)} 9 | } 10 | 11 | func (pp *PredParser) field() (string, error) { 12 | return pp.lex.EatId() 13 | } 14 | 15 | func (pp *PredParser) constant() error { 16 | if pp.lex.MatchStringConstant() { 17 | _, err := pp.lex.EatStringConstant() 18 | return err 19 | } else { 20 | _, err := pp.lex.EatIntConstant() 21 | return err 22 | } 23 | } 24 | 25 | func (pp *PredParser) expression() error { 26 | if pp.lex.MatchId() { 27 | _, err := pp.field() 28 | return err 29 | } else { 30 | return pp.constant() 31 | } 32 | } 33 | 34 | func (pp *PredParser) term() error { 35 | // expression 36 | if err := pp.expression(); err != nil { 37 | return err 38 | } 39 | // eat '=' 40 | if err := pp.lex.EatDelim('='); err != nil { 41 | return err 42 | } 43 | // next expression 44 | if err := pp.expression(); err != nil { 45 | return err 46 | } 47 | return nil 48 | } 49 | 50 | func (pp *PredParser) predicate() error { 51 | if err := pp.term(); err != nil { 52 | return err 53 | } 54 | if pp.lex.MatchKeyword("and") { 55 | // eat "and" 56 | if err := pp.lex.EatKeyword("and"); err != nil { 57 | return err 58 | } 59 | return pp.predicate() 60 | } 61 | return nil 62 | } 63 | -------------------------------------------------------------------------------- /parse/query_data.go: -------------------------------------------------------------------------------- 1 | package parse 2 | 3 | import ( 4 | "github.com/JyotinderSingh/dropdb/query" 5 | 
"github.com/JyotinderSingh/dropdb/query/functions" 6 | ) 7 | 8 | type OrderByItem struct { 9 | field string 10 | descending bool 11 | } 12 | 13 | func (obi *OrderByItem) Field() string { 14 | return obi.field 15 | } 16 | 17 | type QueryData struct { 18 | fields []string 19 | tables []string 20 | predicate *query.Predicate 21 | groupBy []string // Fields to group by 22 | having *query.Predicate // Having clause predicate 23 | orderBy []OrderByItem // Order by clause items 24 | aggregates []functions.AggregationFunction // Aggregate functions in use 25 | } 26 | 27 | func NewQueryData(fields, tables []string, predicate *query.Predicate) *QueryData { 28 | return &QueryData{ 29 | fields: fields, 30 | tables: tables, 31 | predicate: predicate, 32 | } 33 | } 34 | 35 | func (qd *QueryData) Fields() []string { 36 | return qd.fields 37 | } 38 | 39 | func (qd *QueryData) Tables() []string { 40 | return qd.tables 41 | } 42 | 43 | func (qd *QueryData) Pred() *query.Predicate { 44 | return qd.predicate 45 | } 46 | 47 | func (qd *QueryData) GroupBy() []string { 48 | return qd.groupBy 49 | } 50 | 51 | func (qd *QueryData) Having() *query.Predicate { 52 | return qd.having 53 | } 54 | 55 | func (qd *QueryData) OrderBy() []OrderByItem { 56 | return qd.orderBy 57 | } 58 | 59 | func (qd *QueryData) Aggregates() []functions.AggregationFunction { 60 | return qd.aggregates 61 | } 62 | 63 | func (qd *QueryData) String() string { 64 | if len(qd.fields) == 0 || len(qd.tables) == 0 { 65 | return "" 66 | } 67 | result := "select " 68 | for _, fieldName := range qd.fields { 69 | result += fieldName + ", " 70 | } 71 | // remove final comma/space 72 | if len(qd.fields) > 0 { 73 | result = result[:len(result)-2] 74 | } 75 | result += " from " 76 | for _, tableName := range qd.tables { 77 | result += tableName + ", " 78 | } 79 | if len(qd.tables) > 0 { 80 | result = result[:len(result)-2] 81 | } 82 | predicateString := qd.predicate.String() 83 | if predicateString != "" { 84 | result += " where " + 
predicateString 85 | } 86 | return result 87 | } 88 | -------------------------------------------------------------------------------- /plan/plan.go: -------------------------------------------------------------------------------- 1 | package plan 2 | 3 | import ( 4 | "github.com/JyotinderSingh/dropdb/record" 5 | "github.com/JyotinderSingh/dropdb/scan" 6 | ) 7 | 8 | type Plan interface { 9 | // Open opens a scan corresponding to this plan. 10 | // The scan will be positioned before its first record. 11 | Open() (scan.Scan, error) 12 | 13 | // BlocksAccessed returns the estimated number of 14 | // block accesses that will occur when the scan is read to completion. 15 | BlocksAccessed() int 16 | 17 | // RecordsOutput returns the estimated number of records 18 | // in the query's output table. 19 | RecordsOutput() int 20 | 21 | // DistinctValues returns the estimated number of distinct values 22 | // for the specified field in the query's output table. 23 | DistinctValues(fieldName string) int 24 | 25 | // Schema returns the schema of the query's output table. 26 | Schema() *record.Schema 27 | } 28 | -------------------------------------------------------------------------------- /plan_impl/basic_query_planner.go: -------------------------------------------------------------------------------- 1 | package plan_impl 2 | 3 | import ( 4 | "github.com/JyotinderSingh/dropdb/metadata" 5 | "github.com/JyotinderSingh/dropdb/parse" 6 | "github.com/JyotinderSingh/dropdb/plan" 7 | "github.com/JyotinderSingh/dropdb/tx" 8 | ) 9 | 10 | var _ QueryPlanner = &BasicQueryPlanner{} 11 | 12 | type BasicQueryPlanner struct { 13 | metadataManager *metadata.Manager 14 | } 15 | 16 | // NewBasicQueryPlanner creates a new BasicQueryPlanner 17 | func NewBasicQueryPlanner(metadataManager *metadata.Manager) *BasicQueryPlanner { 18 | return &BasicQueryPlanner{metadataManager: metadataManager} 19 | } 20 | 21 | // CreatePlan creates a query plan as follows: 22 | // 1. 
Takes the product of all tables and views 23 | // 2. Applies predicate selection 24 | // 3. Applies grouping and having if specified 25 | // 4. Projects on the field list 26 | // 5. Applies ordering if specified 27 | func (qp *BasicQueryPlanner) CreatePlan(queryData *parse.QueryData, transaction *tx.Transaction) (plan.Plan, error) { 28 | // 1. Create a plan for each mentioned table or view 29 | plans := make([]plan.Plan, len(queryData.Tables())) 30 | for idx, tableName := range queryData.Tables() { 31 | viewDefinition, err := qp.metadataManager.GetViewDefinition(tableName, transaction) 32 | if err != nil { 33 | return nil, err 34 | } 35 | 36 | if viewDefinition == "" { 37 | tablePlan, err := NewTablePlan(transaction, tableName, qp.metadataManager) 38 | if err != nil { 39 | return nil, err 40 | } 41 | plans[idx] = tablePlan 42 | } else { 43 | parser := parse.NewParser(viewDefinition) 44 | viewData, err := parser.Query() 45 | if err != nil { 46 | return nil, err 47 | } 48 | 49 | viewPlan, err := qp.CreatePlan(viewData, transaction) 50 | if err != nil { 51 | return nil, err 52 | } 53 | plans[idx] = viewPlan 54 | } 55 | } 56 | 57 | // 2. Create the product of all table plans 58 | var err error 59 | currentPlan := plans[0] 60 | plans = plans[1:] 61 | 62 | for _, nextPlan := range plans { 63 | planChoice1, err := NewProductPlan(currentPlan, nextPlan) 64 | if err != nil { 65 | return nil, err 66 | } 67 | 68 | planChoice2, err := NewProductPlan(nextPlan, currentPlan) 69 | if err != nil { 70 | return nil, err 71 | } 72 | 73 | if planChoice1.BlocksAccessed() < planChoice2.BlocksAccessed() { 74 | currentPlan = planChoice1 75 | } else { 76 | currentPlan = planChoice2 77 | } 78 | } 79 | 80 | // 3. Add a selection plan for the predicate 81 | currentPlan = NewSelectPlan(currentPlan, queryData.Pred()) 82 | 83 | projectionFields := queryData.Fields() 84 | // 4. 
Add grouping if specified 85 | if len(queryData.GroupBy()) > 0 { 86 | currentPlan = NewGroupByPlan(transaction, currentPlan, queryData.GroupBy(), queryData.Aggregates()) 87 | 88 | // Apply having clause if present 89 | if queryData.Having() != nil { 90 | currentPlan = NewSelectPlan(currentPlan, queryData.Having()) 91 | } 92 | 93 | for _, AggFunc := range queryData.Aggregates() { 94 | projectionFields = append(projectionFields, AggFunc.FieldName()) 95 | } 96 | } 97 | 98 | // 5. Add a projection plan for the field list 99 | currentPlan, err = NewProjectPlan(currentPlan, projectionFields) 100 | if err != nil { 101 | return nil, err 102 | } 103 | 104 | // 6. Add ordering if specified 105 | if len(queryData.OrderBy()) > 0 { 106 | sortFields := make([]string, len(queryData.OrderBy())) 107 | for i, item := range queryData.OrderBy() { 108 | // Note: Currently the SortPlan doesn't support descending order 109 | sortFields[i] = item.Field() 110 | } 111 | currentPlan = NewSortPlan(transaction, currentPlan, sortFields) 112 | } 113 | 114 | return currentPlan, nil 115 | } 116 | -------------------------------------------------------------------------------- /plan_impl/basic_update_planner.go: -------------------------------------------------------------------------------- 1 | package plan_impl 2 | 3 | import ( 4 | "github.com/JyotinderSingh/dropdb/metadata" 5 | "github.com/JyotinderSingh/dropdb/parse" 6 | "github.com/JyotinderSingh/dropdb/plan" 7 | "github.com/JyotinderSingh/dropdb/scan" 8 | "github.com/JyotinderSingh/dropdb/tx" 9 | ) 10 | 11 | var _ UpdatePlanner = &BasicUpdatePlanner{} 12 | 13 | type BasicUpdatePlanner struct { 14 | metadataManager *metadata.Manager 15 | } 16 | 17 | // NewBasicUpdatePlanner creates a new BasicUpdatePlanner. 
18 | func NewBasicUpdatePlanner(metadataManager *metadata.Manager) UpdatePlanner { 19 | return &BasicUpdatePlanner{metadataManager: metadataManager} 20 | } 21 | 22 | func (up *BasicUpdatePlanner) ExecuteDelete(data *parse.DeleteData, transaction *tx.Transaction) (int, error) { 23 | var p plan.Plan 24 | p, err := NewTablePlan(transaction, data.TableName(), up.metadataManager) 25 | if err != nil { 26 | return 0, err 27 | } 28 | 29 | p = NewSelectPlan(p, data.Predicate()) 30 | s, err := p.Open() 31 | if err != nil { 32 | return 0, err 33 | } 34 | updateScan := s.(scan.UpdateScan) 35 | defer updateScan.Close() 36 | 37 | count := 0 38 | for { 39 | hasNext, err := updateScan.Next() 40 | if err != nil || !hasNext { 41 | return count, err 42 | } 43 | 44 | if err := updateScan.Delete(); err != nil { 45 | return count, err 46 | } 47 | count++ 48 | } 49 | } 50 | 51 | func (up *BasicUpdatePlanner) ExecuteModify(data *parse.ModifyData, transaction *tx.Transaction) (int, error) { 52 | var p plan.Plan 53 | p, err := NewTablePlan(transaction, data.TableName(), up.metadataManager) 54 | if err != nil { 55 | return 0, err 56 | } 57 | 58 | p = NewSelectPlan(p, data.Predicate()) 59 | s, err := p.Open() 60 | if err != nil { 61 | return 0, err 62 | } 63 | updateScan := s.(scan.UpdateScan) 64 | defer updateScan.Close() 65 | 66 | count := 0 67 | for { 68 | hasNext, err := updateScan.Next() 69 | if err != nil || !hasNext { 70 | return count, err 71 | } 72 | 73 | val, err := data.NewValue().Evaluate(updateScan) 74 | if err != nil { 75 | return count, err 76 | } 77 | if err := updateScan.SetVal(data.TargetField(), val); err != nil { 78 | return count, err 79 | } 80 | count++ 81 | } 82 | } 83 | 84 | func (up *BasicUpdatePlanner) ExecuteInsert(data *parse.InsertData, transaction *tx.Transaction) (int, error) { 85 | p, err := NewTablePlan(transaction, data.TableName(), up.metadataManager) 86 | if err != nil { 87 | return 0, err 88 | } 89 | 90 | s, err := p.Open() 91 | if err != nil { 92 | return 
0, err 93 | } 94 | updateScan := s.(scan.UpdateScan) 95 | defer updateScan.Close() 96 | 97 | if err := updateScan.Insert(); err != nil { 98 | return 0, err 99 | } 100 | 101 | vals := data.Values() 102 | for idx, field := range data.Fields() { 103 | val := vals[idx] 104 | if err := updateScan.SetVal(field, val); err != nil { 105 | return 0, err 106 | } 107 | } 108 | 109 | return 1, nil 110 | } 111 | 112 | func (up *BasicUpdatePlanner) ExecuteCreateTable(data *parse.CreateTableData, transaction *tx.Transaction) (int, error) { 113 | err := up.metadataManager.CreateTable(data.TableName(), data.NewSchema(), transaction) 114 | return 0, err 115 | } 116 | 117 | func (up *BasicUpdatePlanner) ExecuteCreateView(data *parse.CreateViewData, transaction *tx.Transaction) (int, error) { 118 | err := up.metadataManager.CreateView(data.ViewName(), data.ViewDefinition(), transaction) 119 | return 0, err 120 | } 121 | 122 | func (up *BasicUpdatePlanner) ExecuteCreateIndex(data *parse.CreateIndexData, transaction *tx.Transaction) (int, error) { 123 | err := up.metadataManager.CreateIndex(data.IndexName(), data.TableName(), data.FieldName(), transaction) 124 | return 0, err 125 | } 126 | -------------------------------------------------------------------------------- /plan_impl/group_by_plan.go: -------------------------------------------------------------------------------- 1 | package plan_impl 2 | 3 | import ( 4 | "github.com/JyotinderSingh/dropdb/plan" 5 | "github.com/JyotinderSingh/dropdb/query" 6 | "github.com/JyotinderSingh/dropdb/query/functions" 7 | "github.com/JyotinderSingh/dropdb/record" 8 | "github.com/JyotinderSingh/dropdb/scan" 9 | "github.com/JyotinderSingh/dropdb/tx" 10 | ) 11 | 12 | var _ plan.Plan = &GroupByPlan{} 13 | 14 | type GroupByPlan struct { 15 | inputPlan plan.Plan 16 | groupFields []string 17 | aggregationFunctions []functions.AggregationFunction 18 | schema *record.Schema 19 | } 20 | 21 | // NewGroupByPlan creates a grorupbyy plan for the underlying 22 | 
// query. The grouping is determined by the specified collection 23 | // of group fields, and the aggregation is computed by the specified 24 | // aggregation functions. 25 | func NewGroupByPlan(transaction *tx.Transaction, inputPlan plan.Plan, groupFields []string, aggregationFunctions []functions.AggregationFunction) *GroupByPlan { 26 | gbp := &GroupByPlan{ 27 | inputPlan: NewSortPlan(transaction, inputPlan, groupFields), 28 | groupFields: groupFields, 29 | aggregationFunctions: aggregationFunctions, 30 | schema: record.NewSchema(), 31 | } 32 | 33 | for _, field := range groupFields { 34 | gbp.schema.Add(field, gbp.inputPlan.Schema()) 35 | } 36 | 37 | for _, f := range aggregationFunctions { 38 | gbp.schema.AddIntField(f.FieldName()) 39 | } 40 | 41 | return gbp 42 | } 43 | 44 | // Open opens a sort plan for the specified plan. 45 | // The sort plan ensures that the underlying records 46 | // will be appropriately grouped. 47 | func (p *GroupByPlan) Open() (scan.Scan, error) { 48 | sortScan, err := p.inputPlan.Open() 49 | if err != nil { 50 | return nil, err 51 | } 52 | 53 | groupByScan, err := query.NewGroupByScan(sortScan, p.groupFields, p.aggregationFunctions) 54 | if err != nil { 55 | return nil, err 56 | } 57 | 58 | return groupByScan, nil 59 | } 60 | 61 | // BlocksAccessed returns the estimated number of block accesses 62 | // required to compute the aggregation, 63 | // which is one pass through the sorted table. 64 | // It does not include the one-time cost of materializing and sorting the records. 65 | func (p *GroupByPlan) BlocksAccessed() int { 66 | return p.inputPlan.BlocksAccessed() 67 | } 68 | 69 | // RecordsOutput returns the number of groups. Assuming equal distribution, 70 | // this is the product of the distinct values of each grouping field. 
71 | func (p *GroupByPlan) RecordsOutput() int { 72 | numGroups := 1 73 | for _, field := range p.groupFields { 74 | numGroups *= p.inputPlan.DistinctValues(field) 75 | } 76 | return numGroups 77 | } 78 | 79 | // DistinctValues are the number of distinct values for the specified field. 80 | // If the field is a grouping field, then the number of distinct values is the 81 | // same as in the underlying query. 82 | // If the field is an aggregation field, then we assume that all the values are distinct. 83 | func (p *GroupByPlan) DistinctValues(fieldName string) int { 84 | if p.schema.HasField(fieldName) { 85 | return p.inputPlan.DistinctValues(fieldName) 86 | } 87 | return p.RecordsOutput() 88 | } 89 | 90 | // Schema returns the schema of the output table. 91 | // The schema consists of the grouping fields and the aggregation fields. 92 | func (p *GroupByPlan) Schema() *record.Schema { 93 | return p.schema 94 | } 95 | -------------------------------------------------------------------------------- /plan_impl/index_join_plan.go: -------------------------------------------------------------------------------- 1 | package plan_impl 2 | 3 | import ( 4 | "fmt" 5 | "github.com/JyotinderSingh/dropdb/metadata" 6 | "github.com/JyotinderSingh/dropdb/plan" 7 | "github.com/JyotinderSingh/dropdb/query" 8 | "github.com/JyotinderSingh/dropdb/record" 9 | "github.com/JyotinderSingh/dropdb/scan" 10 | "github.com/JyotinderSingh/dropdb/table" 11 | ) 12 | 13 | var _ plan.Plan = &IndexJoinPlan{} 14 | 15 | // IndexJoinPlan is a plan that corresponds to an index join operation. 
16 | type IndexJoinPlan struct { 17 | plan1 plan.Plan 18 | plan2 plan.Plan 19 | indexInfo metadata.IndexInfo 20 | joinField string 21 | schema *record.Schema 22 | } 23 | 24 | // NewIndexJoinPlan creates a new IndexJoinPlan with the given plans and index info 25 | func NewIndexJoinPlan(plan1, plan2 plan.Plan, indexInfo metadata.IndexInfo, joinField string) *IndexJoinPlan { 26 | ijp := &IndexJoinPlan{ 27 | plan1: plan1, 28 | plan2: plan2, 29 | indexInfo: indexInfo, 30 | joinField: joinField, 31 | schema: record.NewSchema(), 32 | } 33 | 34 | ijp.schema.AddAll(plan1.Schema()) 35 | ijp.schema.AddAll(plan2.Schema()) 36 | 37 | return ijp 38 | } 39 | 40 | // Open opens an index join scan for this query. 41 | func (ijp *IndexJoinPlan) Open() (scan.Scan, error) { 42 | s1, err := ijp.plan1.Open() 43 | if err != nil { 44 | return nil, err 45 | } 46 | 47 | s2, err := ijp.plan2.Open() 48 | if err != nil { 49 | return nil, err 50 | } 51 | tableScan, ok := s2.(*table.Scan) 52 | if !ok { 53 | return nil, fmt.Errorf("first plan is not a table scan") 54 | } 55 | 56 | idx := ijp.indexInfo.Open() 57 | 58 | return query.NewIndexJoinScan(s1, tableScan, ijp.joinField, idx) 59 | } 60 | 61 | // BlocksAccessed estimates the number of block access to compute the join. 62 | // The formula is 63 | // blocks(indexjoin(p1, p2, idx)) = blocks(p1) + Rows(p1)*blocks(idx) + rows(indexjoin(p1, p2, idx)) 64 | func (ijp *IndexJoinPlan) BlocksAccessed() int { 65 | return ijp.plan1.BlocksAccessed() + (ijp.plan1.RecordsOutput() * ijp.indexInfo.BlocksAccessed()) + ijp.RecordsOutput() 66 | } 67 | 68 | // RecordsOutput estimates the number of output records after performing the join. 69 | // The formula is 70 | // rows(indexjoin(p1, p2, idx)) = rows(p1) * rows(idx) 71 | func (ijp *IndexJoinPlan) RecordsOutput() int { 72 | return ijp.plan1.RecordsOutput() * ijp.indexInfo.RecordsOutput() 73 | } 74 | 75 | // DistinctValues estimates the number of distinct values for the specified field. 
76 | func (ijp *IndexJoinPlan) DistinctValues(fieldName string) int { 77 | if ijp.plan1.Schema().HasField(fieldName) { 78 | return ijp.plan1.DistinctValues(fieldName) 79 | } 80 | return ijp.plan2.DistinctValues(fieldName) 81 | } 82 | 83 | // Schema returns the schema for the index join plan. 84 | func (ijp *IndexJoinPlan) Schema() *record.Schema { 85 | return ijp.schema 86 | } 87 | -------------------------------------------------------------------------------- /plan_impl/index_select_plan.go: -------------------------------------------------------------------------------- 1 | package plan_impl 2 | 3 | import ( 4 | "fmt" 5 | "github.com/JyotinderSingh/dropdb/metadata" 6 | "github.com/JyotinderSingh/dropdb/plan" 7 | "github.com/JyotinderSingh/dropdb/query" 8 | "github.com/JyotinderSingh/dropdb/record" 9 | "github.com/JyotinderSingh/dropdb/scan" 10 | "github.com/JyotinderSingh/dropdb/table" 11 | ) 12 | 13 | var _ plan.Plan = &IndexSelectPlan{} 14 | 15 | type IndexSelectPlan struct { 16 | inputPlan plan.Plan 17 | indexInfo *metadata.IndexInfo 18 | value any 19 | } 20 | 21 | // NewIndexSelectPlan creates a new indexselect node in the query tree 22 | // for the specified index and selection constant. 23 | func NewIndexSelectPlan(inputPlan plan.Plan, indexInfo *metadata.IndexInfo, value any) *IndexSelectPlan { 24 | return &IndexSelectPlan{ 25 | inputPlan: inputPlan, 26 | indexInfo: indexInfo, 27 | value: value, 28 | } 29 | } 30 | 31 | // Open creates a new indexselect scan for this query. 
32 | func (isp *IndexSelectPlan) Open() (scan.Scan, error) { 33 | inputScan, err := isp.inputPlan.Open() 34 | if err != nil { 35 | return nil, err 36 | } 37 | tableScan, ok := inputScan.(*table.Scan) 38 | if !ok { 39 | return nil, fmt.Errorf("IndexSelectPlan requires a tablescan") 40 | } 41 | idx := isp.indexInfo.Open() 42 | return query.NewIndexSelectScan(tableScan, idx, isp.value) 43 | } 44 | 45 | // BlocksAccessed returns the estimated number of block accesses 46 | // to compute the index selection, which is the same as the index 47 | // traversal cost plus the number of matching data records. 48 | func (isp *IndexSelectPlan) BlocksAccessed() int { 49 | return isp.indexInfo.BlocksAccessed() + isp.RecordsOutput() 50 | } 51 | 52 | // RecordsOutput returns the estimated number of records in the 53 | // index selection, which is the same as the number of search 54 | // key values for the index. 55 | func (isp *IndexSelectPlan) RecordsOutput() int { 56 | return isp.indexInfo.RecordsOutput() 57 | } 58 | 59 | // DistinctValues returns the estimated number of distinct values 60 | // as defined by the index. 61 | func (isp *IndexSelectPlan) DistinctValues(fieldName string) int { 62 | return isp.indexInfo.DistinctValues(fieldName) 63 | } 64 | 65 | // Schema returns the schema of the data table. 66 | func (isp *IndexSelectPlan) Schema() *record.Schema { 67 | return isp.inputPlan.Schema() 68 | } 69 | -------------------------------------------------------------------------------- /plan_impl/materialize_plan.go: -------------------------------------------------------------------------------- 1 | package plan_impl 2 | 3 | import ( 4 | "github.com/JyotinderSingh/dropdb/materialize" 5 | "github.com/JyotinderSingh/dropdb/plan" 6 | "github.com/JyotinderSingh/dropdb/record" 7 | "github.com/JyotinderSingh/dropdb/scan" 8 | "github.com/JyotinderSingh/dropdb/tx" 9 | "math" 10 | ) 11 | 12 | // MaterializePlan represents the Plan for the materialize operator. 
13 | type MaterializePlan struct { 14 | srcPlan plan.Plan 15 | tx *tx.Transaction 16 | } 17 | 18 | // NewMaterializePlan creates a materialize plan for the specified query. 19 | func NewMaterializePlan(tx *tx.Transaction, srcPlan plan.Plan) *MaterializePlan { 20 | return &MaterializePlan{ 21 | srcPlan: srcPlan, 22 | tx: tx, 23 | } 24 | } 25 | 26 | // Open loops through the underlying query, copying its output records into a temporary table. 27 | // It then returns a table scan for that table. 28 | func (mp *MaterializePlan) Open() (scan.Scan, error) { 29 | schema := mp.srcPlan.Schema() 30 | tempTable := materialize.NewTempTable(mp.tx, schema) 31 | srcScan, err := mp.srcPlan.Open() 32 | if err != nil { 33 | return nil, err 34 | } 35 | defer srcScan.Close() 36 | 37 | destinationScan, err := tempTable.Open() 38 | if err != nil { 39 | return nil, err 40 | } 41 | 42 | for { 43 | hasNext, err := srcScan.Next() 44 | if err != nil { 45 | return nil, err 46 | } 47 | if !hasNext { 48 | break 49 | } 50 | 51 | if err := destinationScan.Insert(); err != nil { 52 | return nil, err 53 | } 54 | for _, fieldName := range schema.Fields() { 55 | val, err := srcScan.GetVal(fieldName) 56 | if err != nil { 57 | return nil, err 58 | } 59 | if err := destinationScan.SetVal(fieldName, val); err != nil { 60 | return nil, err 61 | } 62 | } 63 | } 64 | 65 | if err := destinationScan.BeforeFirst(); err != nil { 66 | return nil, err 67 | } 68 | return destinationScan, nil 69 | } 70 | 71 | // BlocksAccessed returns the estimated number of blocks in the materialized table. 
72 | func (mp *MaterializePlan) BlocksAccessed() int { 73 | // create a fake layout to calculate the record size 74 | layout := record.NewLayout(mp.srcPlan.Schema()) 75 | recordLength := layout.SlotSize() 76 | recordsPerBlock := float64(mp.tx.BlockSize()) / float64(recordLength) 77 | return int(math.Ceil(float64(mp.srcPlan.RecordsOutput()) / recordsPerBlock)) 78 | } 79 | 80 | // RecordsOutput returns the number of records in the materialized table. 81 | func (mp *MaterializePlan) RecordsOutput() int { 82 | return mp.srcPlan.RecordsOutput() 83 | } 84 | 85 | // DistinctValues returns the number of distinct field values, which is the same as the underlying plan. 86 | func (mp *MaterializePlan) DistinctValues(fieldName string) int { 87 | return mp.srcPlan.DistinctValues(fieldName) 88 | } 89 | 90 | // Schema returns the schema of the materialized table, which is the same as the underlying plan. 91 | func (mp *MaterializePlan) Schema() *record.Schema { 92 | return mp.srcPlan.Schema() 93 | } 94 | -------------------------------------------------------------------------------- /plan_impl/planner.go: -------------------------------------------------------------------------------- 1 | package plan_impl 2 | 3 | import ( 4 | "fmt" 5 | "github.com/JyotinderSingh/dropdb/parse" 6 | "github.com/JyotinderSingh/dropdb/plan" 7 | "github.com/JyotinderSingh/dropdb/tx" 8 | ) 9 | 10 | type Planner struct { 11 | queryPlanner QueryPlanner 12 | updatePlanner UpdatePlanner 13 | } 14 | 15 | func NewPlanner(queryPlanner QueryPlanner, updatePlanner UpdatePlanner) *Planner { 16 | return &Planner{ 17 | queryPlanner: queryPlanner, 18 | updatePlanner: updatePlanner, 19 | } 20 | } 21 | 22 | // CreateQueryPlan creates a plan for a SQL select statement, using the supplied planner. 
23 | func (planner *Planner) CreateQueryPlan(sql string, transaction *tx.Transaction) (plan.Plan, error) { 24 | parser := parse.NewParser(sql) 25 | data, err := parser.Query() 26 | if err != nil { 27 | return nil, err 28 | } 29 | if err := verifyQuery(data); err != nil { 30 | return nil, err 31 | } 32 | return planner.queryPlanner.CreatePlan(data, transaction) 33 | } 34 | 35 | // ExecuteUpdate executes a SQL insert, delete, modify, or create statement. 36 | // The method dispatches to the appropriate method of the supplied update planner, 37 | // depending on what the parser returns. 38 | func (planner *Planner) ExecuteUpdate(sql string, transaction *tx.Transaction) (int, error) { 39 | parser := parse.NewParser(sql) 40 | data, err := parser.UpdateCmd() 41 | if err != nil { 42 | return 0, err 43 | } 44 | 45 | if err := verifyUpdate(data); err != nil { 46 | return 0, err 47 | } 48 | 49 | switch data.(type) { 50 | case *parse.InsertData: 51 | return planner.updatePlanner.ExecuteInsert(data.(*parse.InsertData), transaction) 52 | case *parse.DeleteData: 53 | return planner.updatePlanner.ExecuteDelete(data.(*parse.DeleteData), transaction) 54 | case *parse.ModifyData: 55 | return planner.updatePlanner.ExecuteModify(data.(*parse.ModifyData), transaction) 56 | case *parse.CreateTableData: 57 | return planner.updatePlanner.ExecuteCreateTable(data.(*parse.CreateTableData), transaction) 58 | case *parse.CreateViewData: 59 | return planner.updatePlanner.ExecuteCreateView(data.(*parse.CreateViewData), transaction) 60 | case *parse.CreateIndexData: 61 | return planner.updatePlanner.ExecuteCreateIndex(data.(*parse.CreateIndexData), transaction) 62 | default: 63 | return 0, fmt.Errorf("unexpected type %T", data) 64 | } 65 | } 66 | 67 | func verifyQuery(data *parse.QueryData) error { 68 | // TODO: Implement this 69 | return nil 70 | } 71 | 72 | func verifyUpdate(data any) error { 73 | // TODO: Implement this 74 | return nil 75 | } 76 | 
-------------------------------------------------------------------------------- /plan_impl/product_plan.go: -------------------------------------------------------------------------------- 1 | package plan_impl 2 | 3 | import ( 4 | "github.com/JyotinderSingh/dropdb/plan" 5 | "github.com/JyotinderSingh/dropdb/query" 6 | "github.com/JyotinderSingh/dropdb/record" 7 | "github.com/JyotinderSingh/dropdb/scan" 8 | ) 9 | 10 | var _ plan.Plan = &ProductPlan{} 11 | 12 | type ProductPlan struct { 13 | plan1 plan.Plan 14 | plan2 plan.Plan 15 | schema *record.Schema 16 | } 17 | 18 | // NewProductPlan creates a new product node in the query tree, 19 | // having the specified subqueries. 20 | func NewProductPlan(plan1 plan.Plan, plan2 plan.Plan) (*ProductPlan, error) { 21 | pp := &ProductPlan{plan1: plan1, plan2: plan2, schema: record.NewSchema()} 22 | pp.schema.AddAll(plan1.Schema()) 23 | pp.schema.AddAll(plan2.Schema()) 24 | return pp, nil 25 | } 26 | 27 | // Open creates a product scan for this query. 28 | func (pp *ProductPlan) Open() (scan.Scan, error) { 29 | s1, err := pp.plan1.Open() 30 | if err != nil { 31 | return nil, err 32 | } 33 | s2, err := pp.plan2.Open() 34 | if err != nil { 35 | return nil, err 36 | } 37 | return query.NewProductScan(s1, s2), nil 38 | } 39 | 40 | // BlocksAccessed estimates the number of block accesses in the product, 41 | // The formula is: blocks(plan1) + records(plan1) * blocks(plan2). 42 | func (pp *ProductPlan) BlocksAccessed() int { 43 | return pp.plan1.BlocksAccessed() + pp.plan1.RecordsOutput()*pp.plan2.BlocksAccessed() 44 | } 45 | 46 | // RecordsOutput estimates the number of records in the product. 47 | // The formula is: records(plan1) * records(plan2). 48 | func (pp *ProductPlan) RecordsOutput() int { 49 | return pp.plan1.RecordsOutput() * pp.plan2.RecordsOutput() 50 | } 51 | 52 | // DistinctValues estimates the number of distinct field values in the product. 
53 | // Since the product does not increase or decrease field valuese, 54 | // the estimate is the same as in the appropriate subplan. 55 | func (pp *ProductPlan) DistinctValues(fieldName string) int { 56 | if pp.plan1.Schema().HasField(fieldName) { 57 | return pp.plan1.DistinctValues(fieldName) 58 | } 59 | return pp.plan2.DistinctValues(fieldName) 60 | } 61 | 62 | // Schema returns the schema of the product, 63 | // which is the concatenation subplans' schemas. 64 | func (pp *ProductPlan) Schema() *record.Schema { 65 | return pp.schema 66 | } 67 | -------------------------------------------------------------------------------- /plan_impl/project_plan.go: -------------------------------------------------------------------------------- 1 | package plan_impl 2 | 3 | import ( 4 | "github.com/JyotinderSingh/dropdb/plan" 5 | "github.com/JyotinderSingh/dropdb/query" 6 | "github.com/JyotinderSingh/dropdb/record" 7 | "github.com/JyotinderSingh/dropdb/scan" 8 | ) 9 | 10 | var _ plan.Plan = &ProjectPlan{} 11 | 12 | type ProjectPlan struct { 13 | inputPlan plan.Plan 14 | schema *record.Schema 15 | } 16 | 17 | // NewProjectPlan creates a new project node in the query tree, 18 | // having the specified subquery and field list. 19 | func NewProjectPlan(inputPlan plan.Plan, fieldList []string) (*ProjectPlan, error) { 20 | pp := &ProjectPlan{inputPlan: inputPlan, schema: record.NewSchema()} 21 | 22 | for _, fieldName := range fieldList { 23 | pp.schema.Add(fieldName, inputPlan.Schema()) 24 | } 25 | 26 | return pp, nil 27 | } 28 | 29 | // Open creates a project scan for this query. 30 | func (pp *ProjectPlan) Open() (scan.Scan, error) { 31 | inputScan, err := pp.inputPlan.Open() 32 | if err != nil { 33 | return nil, err 34 | } 35 | return query.NewProjectScan(inputScan, pp.schema.Fields()) 36 | } 37 | 38 | // BlocksAccessed estimates the number of block accesses in the projection, 39 | // which is the same as in the underlying query. 
// DistinctValues estimates the number of distinct values in the projection,
// which is the same as in the underlying query: projecting away other
// fields never changes the value set of a surviving field.
func (pp *ProjectPlan) DistinctValues(fieldName string) int {
	return pp.inputPlan.DistinctValues(fieldName)
}
39 | {"id": 3, "name": "Carol", "active": true}, 40 | } 41 | insertRecords(t, us, records) 42 | 43 | // Re-instantiate TablePlan after insertion to refresh stats 44 | tp, err = NewTablePlan(txn, "users", mdm) 45 | require.NoError(t, err) 46 | 47 | // 4) Create a ProjectPlan selecting only ["id", "name"] 48 | projectedFields := []string{"id", "name"} 49 | pp, err := NewProjectPlan(tp, projectedFields) 50 | require.NoError(t, err) 51 | 52 | // 5) Open the ProjectPlan and verify records 53 | projectScan, err := pp.Open() 54 | require.NoError(t, err) 55 | defer projectScan.Close() 56 | 57 | require.NoError(t, projectScan.BeforeFirst()) 58 | 59 | readCount := 0 60 | for { 61 | hasNext, err := projectScan.Next() 62 | require.NoError(t, err) 63 | if !hasNext { 64 | break 65 | } 66 | readCount++ 67 | 68 | // We should be able to read 'id' and 'name' 69 | userID, err := projectScan.GetInt("id") 70 | require.NoError(t, err) 71 | userName, err := projectScan.GetString("name") 72 | require.NoError(t, err) 73 | 74 | // But 'active' is not in the projection, so it should fail: 75 | // either it returns an error or you must not call GetBool("active"). 76 | // We'll just check the schema instead: 77 | hasActive := pp.Schema().HasField("active") 78 | assert.False(t, hasActive, "Schema should NOT include 'active' in projection") 79 | 80 | // Ensure returned values match one of the inserted rows 81 | // (id, name) among our test records. 82 | var found bool 83 | for _, rec := range records { 84 | if rec["id"] == userID && rec["name"] == userName { 85 | found = true 86 | break 87 | } 88 | } 89 | assert.True(t, found, "Projected (id,name) should match an inserted record") 90 | } 91 | assert.Equal(t, len(records), readCount, "Projected scan should return all rows, but only 2 fields") 92 | 93 | // 6) Validate plan-level stats 94 | // ProjectPlan does not change #blocks accessed or #records; it only hides fields. 
95 | assert.Equal(t, tp.BlocksAccessed(), pp.BlocksAccessed()) 96 | assert.Equal(t, tp.RecordsOutput(), pp.RecordsOutput()) 97 | 98 | // Distinct values for "id" is the same as the underlying plan's estimate 99 | distinctID := pp.DistinctValues("id") 100 | assert.Equal(t, tp.DistinctValues("id"), distinctID) 101 | 102 | // 7) Validate the projected schema 103 | schema := pp.Schema() 104 | require.NotNil(t, schema) 105 | assert.True(t, schema.HasField("id")) 106 | assert.True(t, schema.HasField("name")) 107 | assert.False(t, schema.HasField("active")) 108 | assert.Len(t, schema.Fields(), 2, "Schema should only have 2 fields in the projection") 109 | } 110 | -------------------------------------------------------------------------------- /plan_impl/query_planner.go: -------------------------------------------------------------------------------- 1 | package plan_impl 2 | 3 | import ( 4 | "github.com/JyotinderSingh/dropdb/parse" 5 | "github.com/JyotinderSingh/dropdb/plan" 6 | "github.com/JyotinderSingh/dropdb/tx" 7 | ) 8 | 9 | // QueryPlanner is an interface implemented by planners for the SQL select statement. 10 | type QueryPlanner interface { 11 | // CreatePlan creates a query plan for the specified query data. 
12 | CreatePlan(queryData *parse.QueryData, transaction *tx.Transaction) (plan.Plan, error) 13 | } 14 | -------------------------------------------------------------------------------- /plan_impl/select_plan.go: -------------------------------------------------------------------------------- 1 | package plan_impl 2 | 3 | import ( 4 | "github.com/JyotinderSingh/dropdb/plan" 5 | "github.com/JyotinderSingh/dropdb/query" 6 | "github.com/JyotinderSingh/dropdb/record" 7 | "github.com/JyotinderSingh/dropdb/scan" 8 | "github.com/JyotinderSingh/dropdb/types" 9 | ) 10 | 11 | var _ plan.Plan = &SelectPlan{} 12 | 13 | type SelectPlan struct { 14 | inputPlan plan.Plan 15 | predicate *query.Predicate 16 | } 17 | 18 | // NewSelectPlan creates a new select node in the query tree, 19 | // having the specified subquery and predicate. 20 | func NewSelectPlan(inputPlan plan.Plan, predicate *query.Predicate) *SelectPlan { 21 | return &SelectPlan{ 22 | inputPlan: inputPlan, 23 | predicate: predicate, 24 | } 25 | } 26 | 27 | // Open creates a select scan for this query. 28 | func (sp *SelectPlan) Open() (scan.Scan, error) { 29 | inputScan, err := sp.inputPlan.Open() 30 | if err != nil { 31 | return nil, err 32 | } 33 | return query.NewSelectScan(inputScan, sp.predicate) 34 | } 35 | 36 | // BlocksAccessed estimates the number of block accesses in the selection, 37 | // which is the same as in the underlying query. 38 | func (sp *SelectPlan) BlocksAccessed() int { 39 | return sp.inputPlan.BlocksAccessed() 40 | } 41 | 42 | // RecordsOutput estimates the number of records in the selection, 43 | // which is determined by the reduction factor of the predicate. 44 | func (sp *SelectPlan) RecordsOutput() int { 45 | return sp.inputPlan.RecordsOutput() / sp.predicate.ReductionFactor(sp.inputPlan) 46 | } 47 | 48 | // DistinctValues estimates the number of distinct values in the projection. 49 | // This is a heuristic estimate based on the predicate. It's not always accurate. 
// DistinctValues estimates the number of distinct values of the given
// field after the selection predicate is applied.
// This is a heuristic estimate based solely on the predicate's shape;
// it is not always accurate and could be improved by consulting the
// actual data.
func (sp *SelectPlan) DistinctValues(fieldName string) int {
	// 1) If there's an equality check for fieldName = constant, it's 1 distinct value.
	if sp.predicate.EquatesWithConstant(fieldName) != nil {
		return 1
	}

	// 2) If there's an equality check for fieldName = someOtherField,
	// the result is bounded by the smaller of the two sides' estimates.
	fieldName2 := sp.predicate.EquatesWithField(fieldName)
	if fieldName2 != "" {
		return min(
			sp.inputPlan.DistinctValues(fieldName),
			sp.inputPlan.DistinctValues(fieldName2),
		)
	}

	// 3) Check for range comparisons (fieldName < c, > c, <= c, >= c, <> c, etc.)
	op, _ := sp.predicate.ComparesWithConstant(fieldName)
	switch op {
	case types.LT, types.LE, types.GT, types.GE:
		// A naive heuristic: cut the number of distinct values in half
		// because we're restricting to a range.
		return max(1, sp.inputPlan.DistinctValues(fieldName)/2)

	case types.NE:
		// "not equal" typically leaves most of the domain, but at least
		// it excludes 1 distinct value if we know which constant is being excluded.
		distinct := sp.inputPlan.DistinctValues(fieldName)
		if distinct > 1 {
			return distinct - 1
		}
		return 1 // if there's only 1 or 0 possible distinct values, clamp to 1

	default:
		// If there's no relevant range comparison or none is recognized,
		// fall back to the underlying plan's estimate.
		return sp.inputPlan.DistinctValues(fieldName)
	}
}
92 | func (sp *SelectPlan) Schema() *record.Schema { 93 | return sp.inputPlan.Schema() 94 | } 95 | -------------------------------------------------------------------------------- /plan_impl/table_plan.go: -------------------------------------------------------------------------------- 1 | package plan_impl 2 | 3 | import ( 4 | "github.com/JyotinderSingh/dropdb/metadata" 5 | "github.com/JyotinderSingh/dropdb/plan" 6 | "github.com/JyotinderSingh/dropdb/record" 7 | "github.com/JyotinderSingh/dropdb/scan" 8 | "github.com/JyotinderSingh/dropdb/table" 9 | "github.com/JyotinderSingh/dropdb/tx" 10 | ) 11 | 12 | var _ plan.Plan = &TablePlan{} 13 | 14 | type TablePlan struct { 15 | tableName string 16 | transaction *tx.Transaction 17 | layout *record.Layout 18 | statInfo *metadata.StatInfo 19 | } 20 | 21 | // NewTablePlan creates a leaf node in the query tree 22 | // corresponding to the specified table. 23 | func NewTablePlan(transaction *tx.Transaction, tableName string, metadataManager *metadata.Manager) (*TablePlan, error) { 24 | tp := &TablePlan{ 25 | tableName: tableName, 26 | transaction: transaction, 27 | } 28 | 29 | var err error 30 | if tp.layout, err = metadataManager.GetLayout(tableName, transaction); err != nil { 31 | return nil, err 32 | } 33 | if tp.statInfo, err = metadataManager.GetStatInfo(tableName, tp.layout, transaction); err != nil { 34 | return nil, err 35 | } 36 | return tp, nil 37 | } 38 | 39 | // Open creates a table scan for this query 40 | func (tp *TablePlan) Open() (scan.Scan, error) { 41 | return table.NewTableScan(tp.transaction, tp.tableName, tp.layout) 42 | } 43 | 44 | // BlocksAccessed estimates the number of block accesses for the table, 45 | // which is obtainable from the statistics manager. 46 | func (tp *TablePlan) BlocksAccessed() int { 47 | return tp.statInfo.BlocksAccessed() 48 | } 49 | 50 | // RecordsOutput estimates the number of records in the table, 51 | // which is obtainable from the statistics manager. 
// DistinctValues estimates the number of distinct values for the specified
// field in the table, as reported by the statistics manager.
func (tp *TablePlan) DistinctValues(fieldName string) int {
	return tp.statInfo.DistinctValues(fieldName)
}
// asConstant returns the constant value if the expression is a constant
// expression, or nil if the expression does not denote a constant.
func (e *Expression) asConstant() any {
	return e.value
}
45 | func (e *Expression) asFieldName() string { 46 | return e.fieldName 47 | } 48 | 49 | // AppliesTo determines if all the fields mentioned in this expression are contained in the specified schema. 50 | func (e *Expression) AppliesTo(schema *record.Schema) bool { 51 | return e.value != nil || schema.HasField(e.fieldName) 52 | } 53 | 54 | func (e *Expression) String() string { 55 | if e.value != nil { 56 | return fmt.Sprintf("%v", e.value) 57 | } 58 | return e.fieldName 59 | } 60 | -------------------------------------------------------------------------------- /query/functions/aggregation_function.go: -------------------------------------------------------------------------------- 1 | package functions 2 | 3 | import "github.com/JyotinderSingh/dropdb/scan" 4 | 5 | type AggregationFunction interface { 6 | // ProcessFirst uses the current record of the 7 | // specified scan to be the first record in the group. 8 | ProcessFirst(s scan.Scan) error 9 | 10 | // ProcessNext uses the current record of the 11 | // specified scan to be the next record in the group. 12 | ProcessNext(s scan.Scan) error 13 | 14 | // FieldName returns the name of the new aggregation 15 | // field. 16 | FieldName() string 17 | 18 | // Value returns the computed aggregation value. 19 | Value() any 20 | } 21 | -------------------------------------------------------------------------------- /query/functions/avg_function.go: -------------------------------------------------------------------------------- 1 | package functions 2 | 3 | import ( 4 | "github.com/JyotinderSingh/dropdb/scan" 5 | ) 6 | 7 | var _ AggregationFunction = &AvgFunction{} 8 | 9 | const avgFunctionPrefix = "avgOf" 10 | 11 | type AvgFunction struct { 12 | fieldName string 13 | sum int 14 | count int 15 | } 16 | 17 | // NewAvgFunction creates a new avg aggregation function for the specified field. 
18 | func NewAvgFunction(fieldName string) *AvgFunction { 19 | return &AvgFunction{ 20 | fieldName: fieldName, 21 | } 22 | } 23 | 24 | // ProcessFirst sets the initial sum and count. 25 | func (f *AvgFunction) ProcessFirst(s scan.Scan) error { 26 | val, err := s.GetVal(f.fieldName) 27 | if err != nil { 28 | return err 29 | } 30 | numVal, err := toInt(val) 31 | if err != nil { 32 | return err 33 | } 34 | f.sum = numVal 35 | f.count = 1 36 | return nil 37 | } 38 | 39 | // ProcessNext adds the field value to the sum and increments the count. 40 | func (f *AvgFunction) ProcessNext(s scan.Scan) error { 41 | val, err := s.GetVal(f.fieldName) 42 | if err != nil { 43 | return err 44 | } 45 | numVal, err := toInt(val) 46 | if err != nil { 47 | return err 48 | } 49 | f.sum += numVal 50 | f.count++ 51 | return nil 52 | } 53 | 54 | // FieldName returns the field's name, prepended by avgFunctionPrefix. 55 | func (f *AvgFunction) FieldName() string { 56 | return avgFunctionPrefix + f.fieldName 57 | } 58 | 59 | // Value returns the current average as a float64 (or int, depending on your needs). 60 | // TODO: Casts value to int for now since our database doesnt support floats yet.. 61 | func (f *AvgFunction) Value() any { 62 | if f.count == 0 { 63 | return 0 // or error if no rows 64 | } 65 | return int(f.sum / f.count) 66 | } 67 | -------------------------------------------------------------------------------- /query/functions/count_function.go: -------------------------------------------------------------------------------- 1 | package functions 2 | 3 | import ( 4 | "github.com/JyotinderSingh/dropdb/scan" 5 | ) 6 | 7 | var _ AggregationFunction = &CountFunction{} 8 | 9 | const countFunctionPrefix = "countOf" 10 | 11 | type CountFunction struct { 12 | fieldName string 13 | count int64 14 | } 15 | 16 | // NewCountFunction creates a new count aggregation function for the specified field. 17 | // Some implementations ignore the fieldName if they want to count *all* rows. 
// FieldName returns the name of the aggregation output field:
// countFunctionPrefix ("countOf") followed by the counted field's name.
func (f *CountFunction) FieldName() string {
	return countFunctionPrefix + f.fieldName
}
// ProcessNext replaces the current maximum with the field
// value in the current record if that value is greater.
func (f *MaxFunction) ProcessNext(s scan.Scan) error {
	newValue, err := s.GetVal(f.fieldName)
	if err != nil {
		return err
	}

	// NOTE(review): assumes CompareSupportedTypes returns false (rather than
	// panicking) for mismatched types — confirm against types.CompareSupportedTypes.
	if types.CompareSupportedTypes(newValue, f.value, types.GT) {
		f.value = newValue
	}

	return nil
}
// ProcessNext replaces the current minimum with the field value in the
// current record if that value is smaller.
func (f *MinFunction) ProcessNext(s scan.Scan) error {
	newVal, err := s.GetVal(f.fieldName)
	if err != nil {
		return err
	}

	// NOTE(review): assumes CompareSupportedTypes returns false (rather than
	// panicking) for mismatched types — confirm against types.CompareSupportedTypes.
	if types.CompareSupportedTypes(newVal, f.value, types.LT) {
		f.value = newVal
	}
	return nil
}
// toInt normalizes the supported signed integer types (int, int8, int16,
// int32, int64) to int so that aggregation math can be done in one type.
// Using int keeps type handling simple across the db, but may truncate
// 64-bit values on 32-bit architectures.
func toInt(v any) (int, error) {
	switch num := v.(type) {
	case int:
		return num, nil
	case int8:
		return int(num), nil
	case int16:
		return int(num), nil
	case int32:
		return int(num), nil
	case int64:
		return int(num), nil
	default:
		// The old message claimed a conversion "to int64", which was wrong,
		// and mentioned only sum, although avg uses this helper too.
		return 0, fmt.Errorf("cannot convert %T to int for aggregation", v)
	}
}
15 | func NewGroupValue(s scan.Scan, fields []string) (*GroupValue, error) { 16 | values := make(map[string]any) 17 | for _, field := range fields { 18 | value, err := s.GetVal(field) 19 | if err != nil { 20 | return nil, err 21 | } 22 | values[field] = value 23 | } 24 | return &GroupValue{values: values}, nil 25 | } 26 | 27 | // GetVal returns the value of the specified field in the group. 28 | func (g *GroupValue) GetVal(field string) any { 29 | if val, ok := g.values[field]; !ok { 30 | return nil 31 | } else { 32 | return val 33 | } 34 | } 35 | 36 | // Equals compares the specified group value with this one. Two group 37 | // values are equal if they have the same values for their grouping fields. 38 | func (g *GroupValue) Equals(other any) bool { 39 | otherGroup, ok := other.(*GroupValue) 40 | if !ok { 41 | return false 42 | } 43 | 44 | for field, value := range g.values { 45 | value2 := otherGroup.GetVal(field) 46 | if types.CompareSupportedTypes(value, value2, types.NE) { 47 | return false 48 | } 49 | } 50 | 51 | return true 52 | } 53 | 54 | // Hash returns a hash value for the group value. 55 | // The hash of a GroupValue is the sum of hashes of its field values. 56 | func (g *GroupValue) Hash() int { 57 | hash := 0 58 | for _, value := range g.values { 59 | hash += types.Hash(value) 60 | } 61 | return hash 62 | } 63 | -------------------------------------------------------------------------------- /query/index_join_scan.go: -------------------------------------------------------------------------------- 1 | package query 2 | 3 | import ( 4 | "github.com/JyotinderSingh/dropdb/index" 5 | "github.com/JyotinderSingh/dropdb/scan" 6 | "github.com/JyotinderSingh/dropdb/table" 7 | "time" 8 | ) 9 | 10 | var _ scan.Scan = (*IndexJoinScan)(nil) 11 | 12 | // IndexJoinScan is a scan that joins two scans using an index. 13 | // It uses the index to look up the right-hand side of the join for each row of the left-hand side. 
// BeforeFirst resets the scan and positions it before the first record.
// That is, the LHS scan will be positioned at its first record, and
// the RHS index will be positioned at the first entry for that record's
// join value.
func (ijs *IndexJoinScan) BeforeFirst() error {
	if err := ijs.lhs.BeforeFirst(); err != nil {
		return err
	}

	// NOTE(review): the boolean result of Next() is ignored — if the LHS is
	// empty, resetIndex will read the join field from an unpositioned scan.
	// Confirm whether an empty LHS can reach this code and handle it if so.
	if _, err := ijs.lhs.Next(); err != nil {
		return err
	}

	return ijs.resetIndex()
}
89 | func (ijs *IndexJoinScan) GetInt(fieldName string) (int, error) { 90 | if ijs.rhs.HasField(fieldName) { 91 | return ijs.rhs.GetInt(fieldName) 92 | } 93 | return ijs.lhs.GetInt(fieldName) 94 | } 95 | 96 | // GetLong returns the long value of the specified field in the current record. 97 | func (ijs *IndexJoinScan) GetLong(fieldName string) (int64, error) { 98 | if ijs.rhs.HasField(fieldName) { 99 | return ijs.rhs.GetLong(fieldName) 100 | } 101 | return ijs.lhs.GetLong(fieldName) 102 | } 103 | 104 | // GetShort returns the short value of the specified field in the current record. 105 | func (ijs *IndexJoinScan) GetShort(fieldName string) (int16, error) { 106 | if ijs.rhs.HasField(fieldName) { 107 | return ijs.rhs.GetShort(fieldName) 108 | } 109 | return ijs.lhs.GetShort(fieldName) 110 | } 111 | 112 | // GetString returns the string value of the specified field in the current record. 113 | func (ijs *IndexJoinScan) GetString(fieldName string) (string, error) { 114 | if ijs.rhs.HasField(fieldName) { 115 | return ijs.rhs.GetString(fieldName) 116 | } 117 | return ijs.lhs.GetString(fieldName) 118 | } 119 | 120 | // GetBool returns the boolean value of the specified field in the current record. 121 | func (ijs *IndexJoinScan) GetBool(fieldName string) (bool, error) { 122 | if ijs.rhs.HasField(fieldName) { 123 | return ijs.rhs.GetBool(fieldName) 124 | } 125 | return ijs.lhs.GetBool(fieldName) 126 | } 127 | 128 | // GetDate returns the date value of the specified field in the current record. 129 | func (ijs *IndexJoinScan) GetDate(fieldName string) (time.Time, error) { 130 | if ijs.rhs.HasField(fieldName) { 131 | return ijs.rhs.GetDate(fieldName) 132 | } 133 | return ijs.lhs.GetDate(fieldName) 134 | } 135 | 136 | // GetVal returns the value of the specified field in the current record. 
137 | func (ijs *IndexJoinScan) GetVal(fieldName string) (any, error) { 138 | if ijs.rhs.HasField(fieldName) { 139 | return ijs.rhs.GetVal(fieldName) 140 | } 141 | return ijs.lhs.GetVal(fieldName) 142 | 143 | } 144 | 145 | // HasField returns true if the field is in the schema. 146 | func (ijs *IndexJoinScan) HasField(fieldName string) bool { 147 | return ijs.lhs.HasField(fieldName) || ijs.rhs.HasField(fieldName) 148 | } 149 | 150 | // Close closes the scan and its subscans. 151 | func (ijs *IndexJoinScan) Close() { 152 | ijs.lhs.Close() 153 | ijs.rhs.Close() 154 | ijs.idx.Close() 155 | } 156 | 157 | func (ijs *IndexJoinScan) resetIndex() error { 158 | searchKey, err := ijs.lhs.GetVal(ijs.joinField) 159 | if err != nil { 160 | return err 161 | } 162 | 163 | return ijs.idx.BeforeFirst(searchKey) 164 | } 165 | -------------------------------------------------------------------------------- /query/index_select_scan.go: -------------------------------------------------------------------------------- 1 | package query 2 | 3 | import ( 4 | "github.com/JyotinderSingh/dropdb/index" 5 | "github.com/JyotinderSingh/dropdb/scan" 6 | "github.com/JyotinderSingh/dropdb/table" 7 | "time" 8 | ) 9 | 10 | var _ scan.Scan = (*IndexSelectScan)(nil) 11 | 12 | // IndexSelectScan is a scan that combines an index scan with a table scan. 13 | // It is used to scan the data records of a table that satisfy a selection 14 | // constant on an index. 15 | type IndexSelectScan struct { 16 | tableScan *table.Scan 17 | idx index.Index 18 | value any 19 | } 20 | 21 | // NewIndexSelectScan creates an index select scan for the specified index 22 | // and selection constant. 
// The returned scan is already positioned before the first matching record.
func NewIndexSelectScan(tableScan *table.Scan, idx index.Index, value any) (*IndexSelectScan, error) {
	iss := &IndexSelectScan{
		tableScan: tableScan,
		idx:       idx,
		value:     value,
	}
	if err := iss.BeforeFirst(); err != nil {
		return nil, err
	}
	return iss, nil
}

// BeforeFirst positions the scan before the first record,
// which in this case means positioning the index before
// the first instance of the selection constant.
func (iss *IndexSelectScan) BeforeFirst() error {
	return iss.idx.BeforeFirst(iss.value)
}

// Next moves to the next record, which in this case means
// moving the index to the next record satisfying the
// selection constant, and returning false if there are no
// more such index records.
// If there is a next record, the method moves the tablescan
// to the corresponding data record.
func (iss *IndexSelectScan) Next() (bool, error) {
	next, err := iss.idx.Next()
	// Bail out on either exhaustion or error; both are reported to the caller.
	if !next || err != nil {
		return next, err
	}
	dataRID, err := iss.idx.GetDataRecordID()
	if err != nil {
		return false, err
	}
	// Position the table scan at the data record the index entry points to.
	return next, iss.tableScan.MoveToRecordID(dataRID)
}

// GetInt returns the integer value of the specified field in the current record.
func (iss *IndexSelectScan) GetInt(fieldName string) (int, error) {
	return iss.tableScan.GetInt(fieldName)
}

// GetLong returns the long value of the specified field in the current record.
func (iss *IndexSelectScan) GetLong(fieldName string) (int64, error) {
	return iss.tableScan.GetLong(fieldName)
}

// GetShort returns the short value of the specified field in the current record.
func (iss *IndexSelectScan) GetShort(fieldName string) (int16, error) {
	return iss.tableScan.GetShort(fieldName)
}

// GetString returns the string value of the specified field in the current record.
func (iss *IndexSelectScan) GetString(fieldName string) (string, error) {
	return iss.tableScan.GetString(fieldName)
}

// GetBool returns the boolean value of the specified field in the current record.
func (iss *IndexSelectScan) GetBool(fieldName string) (bool, error) {
	return iss.tableScan.GetBool(fieldName)
}

// GetDate returns the date value of the specified field in the current record.
func (iss *IndexSelectScan) GetDate(fieldName string) (time.Time, error) {
	return iss.tableScan.GetDate(fieldName)
}

// GetVal returns the value of the specified field in the current record.
func (iss *IndexSelectScan) GetVal(fieldName string) (any, error) {
	return iss.tableScan.GetVal(fieldName)
}

// HasField returns true if the underlying scan has the specified field.
func (iss *IndexSelectScan) HasField(fieldName string) bool {
	return iss.tableScan.HasField(fieldName)
}

// Close closes the scan by closing the index and the tablescan.
func (iss *IndexSelectScan) Close() {
	iss.idx.Close()
	iss.tableScan.Close()
}
-------------------------------------------------------------------------------- /query/predicate.go: --------------------------------------------------------------------------------
package query

import (
	"github.com/JyotinderSingh/dropdb/plan"
	"github.com/JyotinderSingh/dropdb/record"
	"github.com/JyotinderSingh/dropdb/scan"
	"github.com/JyotinderSingh/dropdb/types"
)

// Predicate is a boolean combination of terms, interpreted as their conjunction.
type Predicate struct {
	terms []*Term
}

// NewPredicate creates an empty predicate, corresponding to TRUE.
func NewPredicate() *Predicate {
	return &Predicate{terms: []*Term{}}
}

// NewPredicateFromTerm creates a new predicate from the specified term.
func NewPredicateFromTerm(term *Term) *Predicate {
	return &Predicate{terms: []*Term{term}}
}

// ConjoinWith modifies the predicate to be the conjunction of itself and the specified predicate.
func (p *Predicate) ConjoinWith(other *Predicate) {
	p.terms = append(p.terms, other.terms...)
}

// IsSatisfied returns true if the predicate evaluates to true with respect to the specified inputScan.
// An empty predicate (no terms) is trivially satisfied.
func (p *Predicate) IsSatisfied(inputScan scan.Scan) bool {
	for _, term := range p.terms {
		if !term.IsSatisfied(inputScan) {
			return false
		}
	}
	return true
}

// ReductionFactor calculates the extent to which selecting on the
// predicate reduces the number of records output by a query.
// For example, if the reduction factor is 2, then the predicate
// cuts the size of the output in half.
// If the reduction factor is 1, then the predicate has no effect.
func (p *Predicate) ReductionFactor(queryPlan plan.Plan) int {
	// Terms are assumed independent, so their factors multiply.
	factor := 1
	for _, term := range p.terms {
		factor *= term.ReductionFactor(queryPlan)
	}
	return factor
}

// SelectSubPredicate returns the sub-predicate that applies to the specified schema.
// Returns nil when no term applies.
func (p *Predicate) SelectSubPredicate(schema *record.Schema) *Predicate {
	result := NewPredicate()
	for _, term := range p.terms {
		if term.AppliesTo(schema) {
			result.terms = append(result.terms, term)
		}
	}

	if len(result.terms) == 0 {
		return nil
	}

	return result
}

// JoinSubPredicate returns the sub-predicate consisting of terms
// that apply to the union of the two specified schemas,
// but not to either schema separately.
// Returns nil when no such term exists.
func (p *Predicate) JoinSubPredicate(schema1, schema2 *record.Schema) *Predicate {
	result := NewPredicate()
	unionSchema := record.NewSchema()

	unionSchema.AddAll(schema1)
	unionSchema.AddAll(schema2)

	// Keep only terms that need fields from BOTH schemas: they must not
	// apply to either schema alone, but must apply to the union.
	for _, term := range p.terms {
		if !term.AppliesTo(schema1) && !term.AppliesTo(schema2) && term.AppliesTo(unionSchema) {
			result.terms = append(result.terms, term)
		}
	}
	if len(result.terms) == 0 {
		return nil
	}
	return result
}

// EquatesWithConstant determines if there is a term of the form "F=c"
// where F is the specified field and c is some constant.
// If so, the constant is returned; otherwise, nil is returned.
func (p *Predicate) EquatesWithConstant(fieldName string) any {
	for _, term := range p.terms {
		if c := term.EquatesWithConstant(fieldName); c != nil {
			return c
		}
	}
	return nil
}

// ComparesWithConstant determines if there is a term of the form "F1>c"
// (or another non-equality comparison against a constant) for the
// specified field. Returns the operator and constant, or (types.NONE, nil).
func (p *Predicate) ComparesWithConstant(fieldName string) (types.Operator, any) {
	for _, term := range p.terms {
		if op, c := term.ComparesWithConstant(fieldName); op != types.NONE {
			return op, c
		}
	}
	return types.NONE, nil
}

// EquatesWithField determines if there is a term of the form "F1=F2"
// where F1 is the specified field and F2 is another field.
// If so, the name of the other field is returned; otherwise, an empty string is returned.
func (p *Predicate) EquatesWithField(fieldName string) string {
	for _, term := range p.terms {
		if f := term.EquatesWithField(fieldName); f != "" {
			return f
		}
	}
	return ""
}

// String returns a string representation of the predicate.
// Terms are joined with " and "; an empty predicate yields "".
func (p *Predicate) String() string {
	if len(p.terms) == 0 {
		return ""
	}

	result := p.terms[0].String()
	for _, term := range p.terms[1:] {
		result += " and " + term.String()
	}

	return result
}
-------------------------------------------------------------------------------- /query/record_comparator.go: --------------------------------------------------------------------------------
package query

import (
	"github.com/JyotinderSingh/dropdb/scan"
	"github.com/JyotinderSingh/dropdb/types"
)

// RecordComparator is a comparator for scans based on a list of field names.
type RecordComparator struct {
	fields []string // compared in order; earlier fields take precedence
}

// NewRecordComparator creates a new comparator using the specified fields.
func NewRecordComparator(fields []string) *RecordComparator {
	return &RecordComparator{fields: fields}
}

// Compare compares the current records of two scans based on the specified fields. Expects supported types.
// Returns -1, 0, or 1 following the usual comparator convention.
// NOTE(review): field-read errors panic rather than being returned, since
// the comparator signature has no error channel — confirm callers accept this.
func (rc *RecordComparator) Compare(s1, s2 scan.Scan) int {
	for _, fieldName := range rc.fields {
		// Get values for the current field
		val1, err1 := s1.GetVal(fieldName)
		val2, err2 := s2.GetVal(fieldName)

		if err1 != nil || err2 != nil {
			panic("Error retrieving field values for comparison")
		}

		// Compare using CompareSupportedTypes with equality and ordering operators
		if types.CompareSupportedTypes(val1, val2, types.LT) {
			return -1 // val1 < val2
		} else if types.CompareSupportedTypes(val1, val2, types.GT) {
			return 1 // val1 > val2
		}
		// If neither LT nor GT, the values must be equal for this field; continue to next field.
	}
	return 0 // All fields are equal
}
-------------------------------------------------------------------------------- /record/alignment.go: --------------------------------------------------------------------------------
package record

import (
	"github.com/JyotinderSingh/dropdb/types"
)

// Data type alignments in bytes (platform-independent where possible)
const (
	LongAlignment    = 8
	ShortAlignment   = 2
	BooleanAlignment = 1
	DateAlignment    = 8
	VarcharAlignment = 1 // No alignment for strings, packed tightly
)

// alignmentRequirement returns the alignment size for a given field type.
func alignmentRequirement(fieldType types.SchemaType) int {
	switch fieldType {
	case types.Integer:
		return types.IntSize
	case types.Long:
		return LongAlignment
	case types.Short:
		return ShortAlignment
	case types.Boolean:
		return BooleanAlignment
	case types.Date:
		return DateAlignment
	case types.Varchar:
		return VarcharAlignment
	default:
		return 1 // Default to no alignment for unknown types
	}
}

// Helper function to find the maximum alignment from the map
func maxAlignment(fieldAlignments map[string]int) int {
	maxAlign := 1
	for _, align := range fieldAlignments {
		if align > maxAlign {
			maxAlign = align
		}
	}
	return maxAlign
}
-------------------------------------------------------------------------------- /record/id.go: --------------------------------------------------------------------------------
package record

import "fmt"

// ID identifies a record by the block it lives in and its slot within that block.
type ID struct {
	blockNumber int
	slot        int
}

// NewID creates a new ID having the specified location in the specified block.
func NewID(blockNumber, slot int) *ID {
	return &ID{blockNumber, slot}
}

// BlockNumber returns the block number of this ID.
16 | func (id *ID) BlockNumber() int { 17 | return id.blockNumber 18 | } 19 | 20 | // Slot returns the slot number of this ID. 21 | func (id *ID) Slot() int { 22 | return id.slot 23 | } 24 | 25 | func (id *ID) Equals(other *ID) bool { 26 | return id.blockNumber == other.blockNumber && id.slot == other.slot 27 | } 28 | 29 | func (id *ID) String() string { 30 | return fmt.Sprintf("[%d, %d]", id.blockNumber, id.slot) 31 | } 32 | -------------------------------------------------------------------------------- /record/layout.go: -------------------------------------------------------------------------------- 1 | package record 2 | 3 | import ( 4 | "fmt" 5 | "github.com/JyotinderSingh/dropdb/file" 6 | "github.com/JyotinderSingh/dropdb/types" 7 | "sort" 8 | ) 9 | 10 | // Layout describes the structure of a record. 11 | // It contains the name, type, length, and offset of 12 | // each field of a given table. 13 | type Layout struct { 14 | schema *Schema 15 | offsets map[string]int 16 | slotSize int 17 | } 18 | 19 | // NewLayout creates a new layout for a given schema. 20 | // The layout introduces padding between fields to ensure that each field is aligned 21 | // correctly to their respective alignment requirements. Certain types require specific 22 | // alignment sizes (e.g., longs are 8 bytes aligned). 23 | // The layout is optimized for space efficiency by placing fields with larger alignment 24 | // requirements first, which minimizes padding between fields. 25 | func NewLayout(schema *Schema) *Layout { 26 | layout := &Layout{ 27 | schema: schema, 28 | offsets: make(map[string]int), 29 | } 30 | 31 | // Determine the alignment and sizes of fields 32 | fieldAlignments := make(map[string]int) 33 | for _, field := range schema.Fields() { 34 | fieldAlignments[field] = alignmentRequirement(schema.Type(field)) 35 | } 36 | 37 | // Sort fields by their alignment requirements in descending order. 
38 | // This ensures that fields with larger alignment requirements are placed first, which 39 | // minimizes padding between fields and reduces the overall size of the record. 40 | fields := schema.Fields() 41 | sort.Slice(fields, func(i, j int) bool { 42 | return fieldAlignments[fields[i]] > fieldAlignments[fields[j]] 43 | }) 44 | 45 | pos := types.IntSize // Reserve space for the empty/in-use field. 46 | for _, field := range fields { 47 | align := fieldAlignments[field] 48 | 49 | // Ensure alignment for the current field 50 | if pos%align != 0 { 51 | pos += align - (pos % align) 52 | } 53 | 54 | // Set the offset for the field 55 | layout.offsets[field] = pos 56 | 57 | // Move the position by the field's size 58 | pos += layout.lengthInBytes(field) 59 | } 60 | 61 | // Align the total slot size to the largest alignment requirement 62 | largestAlignment := maxAlignment(fieldAlignments) 63 | if pos%largestAlignment != 0 { 64 | pos += largestAlignment - (pos % largestAlignment) 65 | } 66 | 67 | layout.slotSize = pos 68 | return layout 69 | } 70 | 71 | // NewLayoutFromMetadata creates a new layout from the specified metadata. 72 | // This method is used when the metadata is retrieved from the catalog. 73 | func NewLayoutFromMetadata(schema *Schema, offsets map[string]int, slotSize int) *Layout { 74 | return &Layout{ 75 | schema: schema, 76 | offsets: offsets, 77 | slotSize: slotSize, 78 | } 79 | } 80 | 81 | // Schema returns the schema of the table's records. 82 | func (l *Layout) Schema() *Schema { 83 | return l.schema 84 | } 85 | 86 | // Offset returns the offset of the specified field within a record based on the layout. 87 | func (l *Layout) Offset(fieldName string) int { 88 | return l.offsets[fieldName] 89 | } 90 | 91 | // SlotSize returns the size of a record slot in bytes. 92 | func (l *Layout) SlotSize() int { 93 | return l.slotSize 94 | } 95 | 96 | // lengthInBytes returns the length of a field in bytes. 
func (l *Layout) lengthInBytes(fieldName string) int {
	fieldType := l.schema.Type(fieldName)

	switch fieldType {
	case types.Integer:
		return types.IntSize
	case types.Long:
		return 8 // 8 bytes for long
	case types.Short:
		return 2 // 2 bytes for short
	case types.Boolean:
		return 1 // 1 byte for boolean
	case types.Date:
		return 8 // 8 bytes for date (64 bit Unix timestamp)
	case types.Varchar:
		return file.MaxLength(l.schema.Length(fieldName))
	default:
		panic(fmt.Sprintf("Unknown field type: %d", fieldType))
	}
}
-------------------------------------------------------------------------------- /record/schema.go: --------------------------------------------------------------------------------
package record

import "github.com/JyotinderSingh/dropdb/types"

// Schema represents record schema of a table.
// A schema contains the name and type of each
// field of the table, as well as the length of
// each varchar field.
type Schema struct {
	fields []string                   // field names, in insertion order
	info   map[string]types.FieldInfo // per-field type and length
}

// NewSchema creates a new schema.
func NewSchema() *Schema {
	return &Schema{
		fields: make([]string, 0),
		info:   make(map[string]types.FieldInfo),
	}
}

// AddField adds a field to the schema having a specified
// name, type, and length.
// If the field type is not a character type, the length
// value is irrelevant.
func (s *Schema) AddField(fieldName string, fieldType types.SchemaType, length int) {
	s.fields = append(s.fields, fieldName)
	s.info[fieldName] = types.FieldInfo{Type: fieldType, Length: length}
}

// AddIntField adds an integer field to the schema.
func (s *Schema) AddIntField(fieldName string) {
	s.AddField(fieldName, types.Integer, 0)
}

// AddStringField adds a string field to the schema.
func (s *Schema) AddStringField(fieldName string, length int) {
	s.AddField(fieldName, types.Varchar, length)
}

// AddBoolField adds a boolean field to the schema.
func (s *Schema) AddBoolField(fieldName string) {
	s.AddField(fieldName, types.Boolean, 0)
}

// AddLongField adds a long field to the schema.
func (s *Schema) AddLongField(fieldName string) {
	s.AddField(fieldName, types.Long, 0)
}

// AddShortField adds a short field to the schema.
func (s *Schema) AddShortField(fieldName string) {
	s.AddField(fieldName, types.Short, 0)
}

// AddDateField adds a date field to the schema.
func (s *Schema) AddDateField(fieldName string) {
	s.AddField(fieldName, types.Date, 0)
}

// Add adds a field to the schema having the same
// type and length as the corresponding field in
// the specified schema.
func (s *Schema) Add(fieldName string, other *Schema) {
	info := other.info[fieldName]
	s.AddField(fieldName, info.Type, info.Length)
}

// AddAll adds all the fields in the specified schema to the current schema,
// preserving their order.
func (s *Schema) AddAll(other *Schema) {
	for _, field := range other.fields {
		s.Add(field, other)
	}
}

// Fields returns the names of all the fields in the schema.
// Note: this returns the schema's internal slice; callers must not mutate it.
func (s *Schema) Fields() []string {
	return s.fields
}

// HasField returns true if the schema contains a field with the specified name.
func (s *Schema) HasField(fieldName string) bool {
	_, ok := s.info[fieldName]
	return ok
}

// Type returns the type of the field with the specified name.
func (s *Schema) Type(fieldName string) types.SchemaType {
	return s.info[fieldName].Type
}

// Length returns the length of the field with the specified name.
func (s *Schema) Length(fieldName string) int {
	return s.info[fieldName].Length
}
-------------------------------------------------------------------------------- /record/schema_test.go: --------------------------------------------------------------------------------
package record

import (
	"github.com/JyotinderSingh/dropdb/types"
	"github.com/stretchr/testify/assert"
	"testing"
)

// TestAddField verifies that AddField records the correct type and length
// for each field kind.
func TestAddField(t *testing.T) {
	s := NewSchema()

	tests := []struct {
		name   string
		field  string
		typ    types.SchemaType
		length int
	}{
		{"integer field", "age", types.Integer, 0},
		{"varchar field", "name", types.Varchar, 20},
		{"boolean field", "active", types.Boolean, 0},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			s.AddField(tt.field, tt.typ, tt.length)

			info, ok := s.info[tt.field]
			assert.True(t, ok, "Field %s not found in info map", tt.field)
			assert.Equal(t, tt.typ, info.Type, "Field type mismatch")
			assert.Equal(t, tt.length, info.Length, "Field length mismatch")
		})
	}
}

// TestTypeSpecificAdders exercises each convenience Add*Field helper.
func TestTypeSpecificAdders(t *testing.T) {
	s := NewSchema()

	tests := []struct {
		name     string
		adder    func()
		field    string
		expected types.SchemaType
		length   int
	}{
		{
			"AddIntField",
			func() { s.AddIntField("age") },
			"age",
			types.Integer,
			0,
		},
		{
			"AddStringField",
			func() { s.AddStringField("name", 30) },
			"name",
			types.Varchar,
			30,
		},
		{
			"AddBoolField",
			func() { s.AddBoolField("active") },
			"active",
			types.Boolean,
			0,
		},
		{
			"AddLongField",
			func() { s.AddLongField("id") },
			"id",
			types.Long,
			0,
		},
		{
			"AddShortField",
			func() { s.AddShortField("count") },
			"count",
			types.Short,
			0,
		},
		{
			"AddDateField",
			func() { s.AddDateField("created") },
			"created",
			types.Date,
			0,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.adder()

			assert.True(t, s.HasField(tt.field), "Field %s not found", tt.field)
			assert.Equal(t, tt.expected, s.Type(tt.field), "Field type mismatch")
			assert.Equal(t, tt.length, s.Length(tt.field), "Field length mismatch")
		})
	}
}

// TestAdd verifies copying individual fields from another schema.
func TestAdd(t *testing.T) {
	source := &Schema{
		fields: []string{"id", "name"},
		info: map[string]types.FieldInfo{
			"id":   {types.Integer, 0},
			"name": {types.Varchar, 25},
		},
	}

	dest := NewSchema()

	// Add individual fields
	dest.Add("id", source)
	dest.Add("name", source)

	assert.Equal(t, 2, len(dest.fields), "Expected 2 fields")

	// Check id field
	idInfo, ok := dest.info["id"]
	assert.True(t, ok, "id field not found")
	assert.Equal(t, types.Integer, idInfo.Type, "id field type mismatch")
	assert.Equal(t, 0, idInfo.Length, "id field length mismatch")

	// Check name field
	nameInfo, ok := dest.info["name"]
	assert.True(t, ok, "name field not found")
	assert.Equal(t, types.Varchar, nameInfo.Type, "name field type mismatch")
	assert.Equal(t, 25, nameInfo.Length, "name field length mismatch")
}

// TestAddAll verifies bulk copy preserves both field info and field order.
func TestAddAll(t *testing.T) {
	source := &Schema{
		fields: []string{"id", "name", "active"},
		info: map[string]types.FieldInfo{
			"id":     {types.Integer, 0},
			"name":   {types.Varchar, 25},
			"active": {types.Boolean, 0},
		},
	}

	dest := NewSchema()

	dest.AddAll(source)

	assert.Equal(t, len(source.fields), len(dest.fields), "Field count mismatch")

	// Verify field order is preserved
	for i, field := range source.fields {
		assert.Equal(t, field, dest.fields[i], "Field order mismatch at index %d", i)
	}

	// Verify all field info was copied correctly
	for field, sourceInfo := range source.info {
		destInfo, ok := dest.info[field]
		assert.True(t, ok, "Field %s not found in destination schema", field)
		assert.Equal(t, sourceInfo, destInfo, "Field info mismatch for %s", field)
	}
}
-------------------------------------------------------------------------------- /scan/scan.go: --------------------------------------------------------------------------------
package scan

import "time"

// Scan interface will be implemented by each query scan.
// There is a Scan class for each relational algebra Operator.
type Scan interface {
	// BeforeFirst positions the scan before the first record. A subsequent call to Next will move to the first record.
	BeforeFirst() error

	// Next moves to the next record in the scan. It returns false if there are no more records to scan.
	Next() (bool, error)

	// GetInt returns the integer value of the specified field in the current record.
	GetInt(fieldName string) (int, error)

	// GetLong returns the long value of the specified field in the current record.
	GetLong(fieldName string) (int64, error)

	// GetShort returns the short value of the specified field in the current record.
	GetShort(fieldName string) (int16, error)

	// GetString returns the string value of the specified field in the current record.
	GetString(fieldName string) (string, error)

	// GetBool returns the boolean value of the specified field in the current record.
	GetBool(fieldName string) (bool, error)

	// GetDate returns the date value of the specified field in the current record.
	GetDate(fieldName string) (time.Time, error)

	// HasField returns true if the current record has the specified field.
	HasField(fieldName string) bool

	// GetVal returns the value of the specified field in the current record.
	GetVal(fieldName string) (any, error)

	// Close closes the scan and its subscans, if any.
	Close()
}
-------------------------------------------------------------------------------- /scan/update_scan.go: --------------------------------------------------------------------------------
package scan

import (
	"github.com/JyotinderSingh/dropdb/record"
	"time"
)

// UpdateScan extends Scan with mutation operations for scans whose
// underlying records can be modified.
type UpdateScan interface {
	Scan

	// SetVal sets the value of the specified field in the current record.
	SetVal(fieldName string, val any) error

	// SetInt sets the integer value of the specified field in the current record.
	SetInt(fieldName string, val int) error

	// SetLong sets the long value of the specified field in the current record.
	SetLong(fieldName string, val int64) error

	// SetShort sets the short value of the specified field in the current record.
	SetShort(fieldName string, val int16) error

	// SetString sets the string value of the specified field in the current record.
	SetString(fieldName string, val string) error

	// SetBool sets the boolean value of the specified field in the current record.
	SetBool(fieldName string, val bool) error

	// SetDate sets the date value of the specified field in the current record.
	SetDate(fieldName string, val time.Time) error

	// Insert inserts a new record somewhere in the scan.
	Insert() error

	// Delete deletes the current record from the scan.
	Delete() error

	// GetRecordID returns the record ID of the current record.
	GetRecordID() *record.ID

	// MoveToRecordID moves the scan to the record with the specified record ID.
	MoveToRecordID(rid *record.ID) error
}
-------------------------------------------------------------------------------- /server/dropdb.go: --------------------------------------------------------------------------------
package server

import (
	"fmt"
	"github.com/JyotinderSingh/dropdb/buffer"
	"github.com/JyotinderSingh/dropdb/file"
	"github.com/JyotinderSingh/dropdb/log"
	"github.com/JyotinderSingh/dropdb/metadata"
	"github.com/JyotinderSingh/dropdb/plan_impl"
	"github.com/JyotinderSingh/dropdb/tx"
	"github.com/JyotinderSingh/dropdb/tx/concurrency"
)

// Default engine configuration used by the production constructor.
const (
	blockSize  = 800
	bufferSize = 64
	logFile    = "dropdb.log"
)

// DropDB wires together the storage, logging, buffering, metadata, and
// planning subsystems that make up a database instance.
type DropDB struct {
	fileManager     *file.Manager
	bufferManager   *buffer.Manager
	logManager      *log.Manager
	metadataManager *metadata.Manager
	lockTable       *concurrency.LockTable
	queryPlanner    plan_impl.QueryPlanner
	updatePlanner   plan_impl.UpdatePlanner
	planner         *plan_impl.Planner
}

// NewDropDBWithOptions is a constructor that is mostly useful for debugging purposes.
// It initializes only the low-level subsystems (file, log, buffer, locks);
// metadata and planners are set up by NewDropDB.
func NewDropDBWithOptions(dirName string, blockSize, bufferSize int) (*DropDB, error) {
	db := &DropDB{}
	var err error

	if db.fileManager, err = file.NewManager(dirName, blockSize); err != nil {
		return nil, err
	}
	if db.logManager, err = log.NewManager(db.fileManager, logFile); err != nil {
		return nil, err
	}
	db.bufferManager = buffer.NewManager(db.fileManager, db.logManager, bufferSize)
	db.lockTable = concurrency.NewLockTable()

	return db, nil
}

// NewDropDB creates a new DropDB instance. Use this constructor for production code.
// For an existing database directory it runs crash recovery before any other work.
func NewDropDB(dirName string) (*DropDB, error) {
	db, err := NewDropDBWithOptions(dirName, blockSize, bufferSize)
	if err != nil {
		return nil, err
	}

	transaction := db.NewTx()
	isNew := db.fileManager.IsNew()

	if isNew {
		fmt.Printf("creating new database\n")
	} else {
		fmt.Printf("recovering existing database\n")
		if err := transaction.Recover(); err != nil {
			return nil, err
		}
	}

	if db.metadataManager, err = metadata.NewManager(isNew, transaction); err != nil {
		return nil, err
	}

	db.queryPlanner = plan_impl.NewBasicQueryPlanner(db.metadataManager)
	db.updatePlanner = plan_impl.NewBasicUpdatePlanner(db.metadataManager)
	db.planner = plan_impl.NewPlanner(db.queryPlanner, db.updatePlanner)

	err = transaction.Commit()
	return db, err
}

// NewTx starts a new transaction against this database instance.
func (db *DropDB) NewTx() *tx.Transaction {
	return tx.NewTransaction(db.fileManager, db.logManager, db.bufferManager, db.lockTable)
}

// MetadataManager returns the metadata (catalog) manager.
func (db *DropDB) MetadataManager() *metadata.Manager {
	return db.metadataManager
}

// Planner returns the query/update planner.
func (db *DropDB) Planner() *plan_impl.Planner {
	return db.planner
}

// FileManager returns the file manager.
func (db *DropDB) FileManager() *file.Manager {
	return db.fileManager
}

// LogManager returns the log manager.
func (db *DropDB) LogManager() *log.Manager {
	return db.logManager
}

// BufferManager returns the buffer manager.
func (db *DropDB) BufferManager() *buffer.Manager {
	return db.bufferManager
}
-------------------------------------------------------------------------------- /tx/buffer_list.go: --------------------------------------------------------------------------------
package tx

import (
	"github.com/JyotinderSingh/dropdb/buffer"
	"github.com/JyotinderSingh/dropdb/file"
)

// pinnedBuffer tracks the underlying buffer + how many times this transaction pinned it.
9 | type pinnedBuffer struct { 10 | buffer *buffer.Buffer 11 | refCount int 12 | } 13 | 14 | // BufferList manages a transaction's currently pinned buffers with reference counts. 15 | type BufferList struct { 16 | buffers map[file.BlockId]*pinnedBuffer 17 | bufferManager *buffer.Manager 18 | } 19 | 20 | // NewBufferList creates a new BufferList. 21 | func NewBufferList(bufferManager *buffer.Manager) *BufferList { 22 | return &BufferList{ 23 | buffers: make(map[file.BlockId]*pinnedBuffer), 24 | bufferManager: bufferManager, 25 | } 26 | } 27 | 28 | // GetBuffer returns the buffer pinned to the specified block. 29 | // The method returns nil if the transaction has not pinned the block. 30 | func (bl *BufferList) GetBuffer(block *file.BlockId) *buffer.Buffer { 31 | pinnedBuf, ok := bl.buffers[*block] 32 | if !ok { 33 | return nil 34 | } 35 | return pinnedBuf.buffer 36 | } 37 | 38 | // Pin pins the block. If the block is already pinned by this transaction, 39 | // simply increment the reference count. Otherwise, pin it via bufferManager. 40 | func (bl *BufferList) Pin(block *file.BlockId) error { 41 | if pinnedBuf, ok := bl.buffers[*block]; ok { 42 | // Already pinned by this transaction; just increase refCount 43 | pinnedBuf.refCount++ 44 | return nil 45 | } 46 | 47 | // Not pinned yet; ask bufferManager for a fresh pin 48 | buff, err := bl.bufferManager.Pin(block) 49 | if err != nil { 50 | return err 51 | } 52 | bl.buffers[*block] = &pinnedBuffer{ 53 | buffer: buff, 54 | refCount: 1, 55 | } 56 | return nil 57 | } 58 | 59 | // Unpin decrements the refCount. Only call bufferManager.Unpin when the last pin is released. 60 | func (bl *BufferList) Unpin(block *file.BlockId) { 61 | pinnedBuf, ok := bl.buffers[*block] 62 | if !ok { 63 | // This block isn't pinned or was already unpinned. 64 | // In production, you might log a warning or return silently. 
65 | return 66 | } 67 | pinnedBuf.refCount-- 68 | if pinnedBuf.refCount <= 0 { 69 | // Now fully unpin from buffer manager and remove from our map 70 | bl.bufferManager.Unpin(pinnedBuf.buffer) 71 | delete(bl.buffers, *block) 72 | } 73 | } 74 | 75 | // UnpinAll unpins all blocks pinned by this transaction. 76 | // We decrement each block's refCount down to zero, unpinning once for each pin. 77 | func (bl *BufferList) UnpinAll() { 78 | for _, pinnedBuf := range bl.buffers { 79 | // We pinned this 'pinnedBuf.refCount' times; unpin that many times 80 | for pinnedBuf.refCount > 0 { 81 | pinnedBuf.refCount-- 82 | bl.bufferManager.Unpin(pinnedBuf.buffer) 83 | } 84 | // Alternatively: 85 | // for i := 0; i < pinnedBuf.refCount; i++ { 86 | // bl.bufferManager.Unpin(pinnedBuf.buffer) 87 | // } 88 | // pinnedBuf.refCount = 0 89 | } 90 | // Clear our map 91 | bl.buffers = make(map[file.BlockId]*pinnedBuffer) 92 | } 93 | -------------------------------------------------------------------------------- /tx/checkpoint.go: -------------------------------------------------------------------------------- 1 | package tx 2 | 3 | import ( 4 | "github.com/JyotinderSingh/dropdb/file" 5 | "github.com/JyotinderSingh/dropdb/log" 6 | ) 7 | 8 | type CheckpointRecord struct { 9 | LogRecord 10 | } 11 | 12 | // NewCheckpointRecord creates a new CheckpointRecord from a Page. 13 | func NewCheckpointRecord() (*CheckpointRecord, error) { 14 | return &CheckpointRecord{}, nil 15 | } 16 | 17 | // Op returns the type of the log record. 18 | func (r *CheckpointRecord) Op() LogRecordType { 19 | return Checkpoint 20 | } 21 | 22 | // TxNumber returns the transaction number stored in the log record. CheckpointRecord does not have a transaction 23 | // number, so it returns a "dummy", negative txId. 24 | func (r *CheckpointRecord) TxNumber() int { 25 | return -1 26 | } 27 | 28 | // Undo does nothing. CheckpointRecord does not change any data. 
29 | func (r *CheckpointRecord) Undo(_ *Transaction) error { 30 | return nil 31 | } 32 | 33 | // String returns a string representation of the log record. 34 | func (r *CheckpointRecord) String() string { 35 | return "" 36 | } 37 | 38 | // WriteCheckpointToLog writes a checkpoint record to the log. This log record contains the Checkpoint operator and 39 | // nothing else. 40 | // The method returns the LSN of the new log record. 41 | func WriteCheckpointToLog(logManager *log.Manager) (int, error) { 42 | record := make([]byte, 4) 43 | 44 | page := file.NewPageFromBytes(record) 45 | page.SetInt(0, int(Checkpoint)) 46 | 47 | return logManager.Append(record) 48 | } 49 | -------------------------------------------------------------------------------- /tx/commit.go: -------------------------------------------------------------------------------- 1 | package tx 2 | 3 | import ( 4 | "fmt" 5 | "github.com/JyotinderSingh/dropdb/file" 6 | "github.com/JyotinderSingh/dropdb/log" 7 | "github.com/JyotinderSingh/dropdb/types" 8 | ) 9 | 10 | type CommitRecord struct { 11 | LogRecord 12 | txNum int 13 | } 14 | 15 | // NewCommitRecord creates a new CommitRecord from a Page. 16 | func NewCommitRecord(page *file.Page) (*CommitRecord, error) { 17 | operationPos := 0 18 | txNumPos := operationPos + types.IntSize 19 | txNum := page.GetInt(txNumPos) 20 | 21 | return &CommitRecord{txNum: txNum}, nil 22 | } 23 | 24 | // Op returns the type of the log record. 25 | func (r *CommitRecord) Op() LogRecordType { 26 | return Commit 27 | } 28 | 29 | // TxNumber returns the transaction number stored in the log record. 30 | func (r *CommitRecord) TxNumber() int { 31 | return r.txNum 32 | } 33 | 34 | // Undo does nothing. CommitRecord does not change any data. 35 | func (r *CommitRecord) Undo(_ *Transaction) error { 36 | return nil 37 | } 38 | 39 | // String returns a string representation of the log record. 
40 | func (r *CommitRecord) String() string { 41 | return fmt.Sprintf("", r.txNum) 42 | } 43 | 44 | // WriteCommitToLog writes a commit record to the log. This log record contains the Commit operator, 45 | // followed by the transaction id. 46 | // The method returns the LSN of the new log record. 47 | func WriteCommitToLog(logManager *log.Manager, txNum int) (int, error) { 48 | record := make([]byte, 2*types.IntSize) 49 | 50 | page := file.NewPageFromBytes(record) 51 | page.SetInt(0, int(Commit)) 52 | page.SetInt(types.IntSize, txNum) 53 | 54 | return logManager.Append(record) 55 | } 56 | -------------------------------------------------------------------------------- /tx/concurrency/lock_table.go: -------------------------------------------------------------------------------- 1 | package concurrency 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "github.com/JyotinderSingh/dropdb/file" 8 | "sync" 9 | "time" 10 | ) 11 | 12 | const maxWaitTime = 10 * time.Second 13 | 14 | // LockTable provides methods to lock and Unlock blocks. 15 | // If a transaction requests a lock that causes a conflict with an existing lock, 16 | // then that transaction is placed on a wait list. 17 | // There is only one wait list for all blocks. 18 | // When the last lock on a block is unlocked, 19 | // then all transactions are removed from the wait list and rescheduled. 20 | // If one of those transactions discovers that the lock it is waiting for is still locked, 21 | // it will place itself back on the wait list. 22 | type LockTable struct { 23 | locks map[file.BlockId]int 24 | mu sync.Mutex 25 | cond *sync.Cond 26 | } 27 | 28 | // NewLockTable creates a new LockTable. 29 | func NewLockTable() *LockTable { 30 | lt := &LockTable{ 31 | locks: make(map[file.BlockId]int), 32 | } 33 | lt.cond = sync.NewCond(<.mu) 34 | return lt 35 | } 36 | 37 | // SLock grants a shared lock on the specified block. 
38 | // If an XLock exists when the method is called, 39 | // then the calling thread will be placed on a wait list 40 | // until the lock is released. 41 | // If the thread remains on the wait list for too long (10 seconds for now), 42 | // then the method will return an error. 43 | func (lt *LockTable) SLock(block *file.BlockId) error { 44 | lt.mu.Lock() 45 | defer lt.mu.Unlock() 46 | 47 | ctx, cancel := context.WithTimeout(context.Background(), maxWaitTime) 48 | defer cancel() 49 | 50 | // This function will run after the context expires. 51 | stop := context.AfterFunc(ctx, func() { 52 | lt.cond.L.Lock() 53 | lt.cond.Broadcast() 54 | lt.cond.L.Unlock() 55 | }) 56 | 57 | defer stop() 58 | 59 | for { 60 | // If there's no exclusive lock, we can proceed 61 | if !lt.hasXLock(block) { 62 | // Get the number of shared locks. 63 | val := lt.getLockVal(block) 64 | // Grant the shared lock. 65 | lt.locks[*block] = val + 1 66 | return nil 67 | } 68 | 69 | // Wait until notified or context is done. 70 | lt.cond.Wait() 71 | 72 | if ctx.Err() != nil { 73 | if errors.Is(ctx.Err(), context.DeadlineExceeded) { 74 | return fmt.Errorf("lock abort exception: could not acquire shared lock on block %v: %v", block, ctx.Err()) 75 | } 76 | return ctx.Err() 77 | } 78 | } 79 | } 80 | 81 | // XLock grants an exclusive lock on the specified block. 82 | // Assumes that the calling thread already has a shared lock on the block. 83 | // If a lock of any type (by some other transaction) exists when the method is called, 84 | // then the calling thread will be placed on a wait list until the locks are released. 85 | // If the thread remains on the wait list for too long (10 seconds for now), 86 | // then the method will return an error. 
87 | func (lt *LockTable) XLock(block *file.BlockId) error { 88 | lt.mu.Lock() 89 | defer lt.mu.Unlock() 90 | 91 | ctx, cancel := context.WithTimeout(context.Background(), maxWaitTime) 92 | defer cancel() 93 | 94 | stop := context.AfterFunc(ctx, func() { 95 | lt.cond.L.Lock() 96 | lt.cond.Broadcast() 97 | lt.cond.L.Unlock() 98 | }) 99 | 100 | defer stop() 101 | 102 | for { 103 | // Assume that the calling thread already has a shared lock. If any other shared locks exist, we can't proceed. 104 | if !lt.hasOtherSLocks(block) { 105 | lt.locks[*block] = -1 106 | return nil 107 | } 108 | 109 | lt.cond.Wait() 110 | 111 | if ctx.Err() != nil { 112 | if errors.Is(ctx.Err(), context.DeadlineExceeded) { 113 | return fmt.Errorf("lock abort exception: could not acquire exclusive lock on block %v: %v", block, ctx.Err()) 114 | } 115 | return ctx.Err() 116 | } 117 | } 118 | } 119 | 120 | // Unlock releases the lock on the specified block. 121 | // If this lock is the last lock on that block, 122 | // then the waiting transactions are notified. 123 | func (lt *LockTable) Unlock(block *file.BlockId) { 124 | lt.mu.Lock() 125 | defer lt.mu.Unlock() 126 | 127 | val := lt.getLockVal(block) 128 | if val > 1 { 129 | lt.locks[*block] = val - 1 130 | } else { 131 | delete(lt.locks, *block) 132 | lt.cond.Broadcast() 133 | } 134 | } 135 | 136 | // hasXLock returns true if there is an exclusive lock on the block. 137 | func (lt *LockTable) hasXLock(block *file.BlockId) bool { 138 | return lt.getLockVal(block) < 0 139 | } 140 | 141 | // hasOtherSLocks returns true if there is more than one shared locks on the block. 
142 | func (lt *LockTable) hasOtherSLocks(block *file.BlockId) bool { 143 | return lt.getLockVal(block) > 1 144 | } 145 | 146 | func (lt *LockTable) getLockVal(block *file.BlockId) int { 147 | return lt.locks[*block] 148 | } 149 | -------------------------------------------------------------------------------- /tx/concurrency/manager.go: -------------------------------------------------------------------------------- 1 | package concurrency 2 | 3 | import ( 4 | "github.com/JyotinderSingh/dropdb/file" 5 | ) 6 | 7 | type Manager struct { 8 | lockTable *LockTable // pointer to the global lock table. 9 | locks map[file.BlockId]string 10 | } 11 | 12 | // NewManager creates a new Manager. 13 | func NewManager(lockTable *LockTable) *Manager { 14 | return &Manager{lockTable: lockTable, locks: make(map[file.BlockId]string)} 15 | } 16 | 17 | // SLock obtains a shared lock on the block, if necessary. 18 | // The method will ask the lock table for an SLock if the transaction currently has no locks on the block. 19 | func (m *Manager) SLock(block *file.BlockId) error { 20 | // if the lock doesn't exist in the locks map, acquire it from the lock table. 21 | if _, ok := m.locks[*block]; !ok { 22 | if err := m.lockTable.SLock(block); err != nil { 23 | return err 24 | } 25 | m.locks[*block] = "s" 26 | } 27 | return nil 28 | } 29 | 30 | // XLock obtains an exclusive lock on the block, if necessary. 31 | // If the transaction does not have an exclusive lock on the block, 32 | // the method first gets a shared lock on that block (if necessary), and then upgrades it to an exclusive lock. 33 | func (m *Manager) XLock(block *file.BlockId) error { 34 | if !m.hasXLock(block) { 35 | if err := m.SLock(block); err != nil { 36 | return err 37 | } 38 | if err := m.lockTable.XLock(block); err != nil { 39 | return err 40 | } 41 | m.locks[*block] = "x" 42 | } 43 | return nil 44 | } 45 | 46 | // Release releases all the locks by asking the lock table to Unlock each one. 
47 | func (m *Manager) Release() { 48 | for block := range m.locks { 49 | m.lockTable.Unlock(&block) 50 | } 51 | m.locks = make(map[file.BlockId]string) 52 | } 53 | 54 | // hasXLock returns true if the transaction has an exclusive lock on the block. 55 | func (m *Manager) hasXLock(block *file.BlockId) bool { 56 | lock, ok := m.locks[*block] 57 | return ok && lock == "x" 58 | } 59 | -------------------------------------------------------------------------------- /tx/log_record.go: -------------------------------------------------------------------------------- 1 | package tx 2 | 3 | import ( 4 | "errors" 5 | "github.com/JyotinderSingh/dropdb/file" 6 | ) 7 | 8 | // LogRecordType is the type of log record. 9 | type LogRecordType int 10 | 11 | const ( 12 | Checkpoint LogRecordType = iota 13 | Start 14 | Commit 15 | Rollback 16 | SetInt 17 | SetString 18 | SetBool 19 | SetLong 20 | SetShort 21 | SetDate 22 | ) 23 | 24 | func (t LogRecordType) String() string { 25 | switch t { 26 | case Checkpoint: 27 | return "Checkpoint" 28 | case Start: 29 | return "Start" 30 | case Commit: 31 | return "Commit" 32 | case Rollback: 33 | return "Rollback" 34 | case SetInt: 35 | return "SetInt" 36 | case SetString: 37 | return "SetString" 38 | case SetBool: 39 | return "SetBool" 40 | case SetLong: 41 | return "SetLong" 42 | case SetShort: 43 | return "SetShort" 44 | case SetDate: 45 | return "SetDate" 46 | default: 47 | return "Unknown" 48 | } 49 | } 50 | 51 | func FromCode(code int) (LogRecordType, error) { 52 | switch code { 53 | case 0: 54 | return Checkpoint, nil 55 | case 1: 56 | return Start, nil 57 | case 2: 58 | return Commit, nil 59 | case 3: 60 | return Rollback, nil 61 | case 4: 62 | return SetInt, nil 63 | case 5: 64 | return SetString, nil 65 | case 6: 66 | return SetBool, nil 67 | case 7: 68 | return SetLong, nil 69 | case 8: 70 | return SetShort, nil 71 | case 9: 72 | return SetDate, nil 73 | default: 74 | return -1, errors.New("unknown LogRecordType code") 75 | } 76 | } 
77 | 78 | // LogRecord interface for log records. 79 | type LogRecord interface { 80 | // Op returns the log record type. 81 | Op() LogRecordType 82 | 83 | // TxNumber returns the transaction ID stored with the log record. 84 | TxNumber() int 85 | 86 | // Undo undoes the operation encoded by this log record. 87 | // Undoes the operation encoded by this log record. 88 | // The only log record types for which this method does anything interesting are SETINT and SETSTRING. 89 | Undo(tx *Transaction) error 90 | 91 | // String returns a string representation of the log record. 92 | String() string 93 | } 94 | 95 | // CreateLogRecord interprets the bytes to create the appropriate log record. This method assumes that the first 4 bytes 96 | // of the byte array represent the log record type. 97 | func CreateLogRecord(bytes []byte) (LogRecord, error) { 98 | p := file.NewPageFromBytes(bytes) 99 | code := p.GetInt(0) 100 | recordType, err := FromCode(int(code)) 101 | if err != nil { 102 | return nil, err 103 | } 104 | 105 | switch recordType { 106 | case Checkpoint: 107 | return NewCheckpointRecord() 108 | case Start: 109 | return NewStartRecord(p) 110 | case Commit: 111 | return NewCommitRecord(p) 112 | case Rollback: 113 | return NewRollbackRecord(p) 114 | case SetInt: 115 | return NewSetIntRecord(p) 116 | case SetString: 117 | return NewSetStringRecord(p) 118 | case SetBool: 119 | return NewSetBoolRecord(p) 120 | case SetLong: 121 | return NewSetLongRecord(p) 122 | case SetShort: 123 | return NewSetShortRecord(p) 124 | case SetDate: 125 | return NewSetDateRecord(p) 126 | default: 127 | return nil, errors.New("unexpected LogRecordType") 128 | } 129 | } 130 | -------------------------------------------------------------------------------- /tx/rollback.go: -------------------------------------------------------------------------------- 1 | package tx 2 | 3 | import ( 4 | "fmt" 5 | "github.com/JyotinderSingh/dropdb/file" 6 | "github.com/JyotinderSingh/dropdb/log" 7 | 
"github.com/JyotinderSingh/dropdb/types" 8 | ) 9 | 10 | type RollbackRecord struct { 11 | LogRecord 12 | txNum int 13 | } 14 | 15 | // NewRollbackRecord creates a new RollbackRecord from a Page. 16 | func NewRollbackRecord(page *file.Page) (*RollbackRecord, error) { 17 | operationPos := 0 18 | txNumPos := operationPos + types.IntSize 19 | txNum := int(page.GetInt(txNumPos)) 20 | 21 | return &RollbackRecord{txNum: txNum}, nil 22 | } 23 | 24 | // Op returns the type of the log record. 25 | func (r *RollbackRecord) Op() LogRecordType { 26 | return Rollback 27 | } 28 | 29 | // TxNumber returns the transaction number stored in the log record. 30 | func (r *RollbackRecord) TxNumber() int { 31 | return r.txNum 32 | } 33 | 34 | // Undo does nothing. RollbackRecord does not change any data. 35 | func (r *RollbackRecord) Undo(_ *Transaction) error { 36 | return nil 37 | } 38 | 39 | // String returns a string representation of the log record. 40 | func (r *RollbackRecord) String() string { 41 | return fmt.Sprintf("", r.txNum) 42 | } 43 | 44 | // WriteRollbackToLog writes a rollback record to the log. This log record contains the Rollback operator, 45 | // followed by the transaction id. 46 | // The method returns the LSN of the new log record. 
47 | func WriteRollbackToLog(logManager *log.Manager, txNum int) (int, error) { 48 | record := make([]byte, 2*types.IntSize) 49 | 50 | page := file.NewPageFromBytes(record) 51 | page.SetInt(0, int(Rollback)) 52 | page.SetInt(4, txNum) 53 | 54 | return logManager.Append(record) 55 | } 56 | -------------------------------------------------------------------------------- /tx/set_bool.go: -------------------------------------------------------------------------------- 1 | package tx 2 | 3 | import ( 4 | "fmt" 5 | "github.com/JyotinderSingh/dropdb/file" 6 | "github.com/JyotinderSingh/dropdb/log" 7 | "github.com/JyotinderSingh/dropdb/types" 8 | ) 9 | 10 | type SetBoolRecord struct { 11 | LogRecord 12 | txNum int 13 | offset int 14 | value bool 15 | block *file.BlockId 16 | } 17 | 18 | func NewSetBoolRecord(page *file.Page) (*SetBoolRecord, error) { 19 | operationPos := 0 20 | txNumPos := operationPos + types.IntSize 21 | txNum := page.GetInt(txNumPos) 22 | 23 | fileNamePos := txNumPos + types.IntSize 24 | fileName, err := page.GetString(fileNamePos) 25 | if err != nil { 26 | return nil, err 27 | } 28 | 29 | blockNumPos := fileNamePos + file.MaxLength(len(fileName)) 30 | blockNum := page.GetInt(blockNumPos) 31 | block := &file.BlockId{File: fileName, BlockNumber: int(blockNum)} 32 | 33 | offsetPos := blockNumPos + types.IntSize 34 | offset := page.GetInt(offsetPos) 35 | 36 | valuePos := offsetPos + types.IntSize 37 | val := page.GetBool(valuePos) 38 | 39 | return &SetBoolRecord{txNum: txNum, offset: offset, value: val, block: block}, nil 40 | } 41 | 42 | func (r *SetBoolRecord) Op() LogRecordType { 43 | return SetBool 44 | } 45 | 46 | func (r *SetBoolRecord) TxNumber() int { 47 | return r.txNum 48 | } 49 | 50 | func (r *SetBoolRecord) String() string { 51 | return fmt.Sprintf("", r.txNum, r.block, r.offset, r.value) 52 | } 53 | 54 | func (r *SetBoolRecord) Undo(tx *Transaction) error { 55 | if err := tx.Pin(r.block); err != nil { 56 | return err 57 | } 58 | defer 
tx.Unpin(r.block) 59 | return tx.SetBool(r.block, r.offset, r.value, false) 60 | } 61 | 62 | func WriteSetBoolToLog(logManager *log.Manager, txNum int, block *file.BlockId, offset int, val bool) (int, error) { 63 | operationPos := 0 64 | txNumPos := operationPos + types.IntSize 65 | fileNamePos := txNumPos + types.IntSize 66 | fileName := block.Filename() 67 | 68 | blockNumPos := fileNamePos + file.MaxLength(len(fileName)) 69 | blockNum := block.Number() 70 | 71 | offsetPos := blockNumPos + types.IntSize 72 | valuePos := offsetPos + types.IntSize 73 | 74 | // 1 byte for bool 75 | recordLen := valuePos + 1 76 | 77 | recordBytes := make([]byte, recordLen) 78 | page := file.NewPageFromBytes(recordBytes) 79 | 80 | page.SetInt(operationPos, int(SetBool)) 81 | page.SetInt(txNumPos, txNum) 82 | if err := page.SetString(fileNamePos, fileName); err != nil { 83 | return -1, err 84 | } 85 | page.SetInt(blockNumPos, blockNum) 86 | page.SetInt(offsetPos, offset) 87 | page.SetBool(valuePos, val) 88 | 89 | return logManager.Append(recordBytes) 90 | } 91 | -------------------------------------------------------------------------------- /tx/set_date.go: -------------------------------------------------------------------------------- 1 | package tx 2 | 3 | import ( 4 | "fmt" 5 | "github.com/JyotinderSingh/dropdb/types" 6 | "time" 7 | 8 | "github.com/JyotinderSingh/dropdb/file" 9 | "github.com/JyotinderSingh/dropdb/log" 10 | ) 11 | 12 | type SetDateRecord struct { 13 | LogRecord 14 | txNum int 15 | offset int 16 | value time.Time 17 | block *file.BlockId 18 | } 19 | 20 | func NewSetDateRecord(page *file.Page) (*SetDateRecord, error) { 21 | operationPos := 0 22 | txNumPos := operationPos + types.IntSize 23 | txNum := page.GetInt(txNumPos) 24 | 25 | fileNamePos := txNumPos + types.IntSize 26 | fileName, err := page.GetString(fileNamePos) 27 | if err != nil { 28 | return nil, err 29 | } 30 | 31 | blockNumPos := fileNamePos + file.MaxLength(len(fileName)) 32 | blockNum := 
page.GetInt(blockNumPos) 33 | block := &file.BlockId{File: fileName, BlockNumber: int(blockNum)} 34 | 35 | offsetPos := blockNumPos + types.IntSize 36 | offset := page.GetInt(offsetPos) 37 | 38 | valuePos := offsetPos + types.IntSize 39 | val := page.GetDate(valuePos) 40 | 41 | return &SetDateRecord{txNum: txNum, offset: offset, value: val, block: block}, nil 42 | } 43 | 44 | func (r *SetDateRecord) Op() LogRecordType { 45 | return SetDate 46 | } 47 | 48 | func (r *SetDateRecord) TxNumber() int { 49 | return r.txNum 50 | } 51 | 52 | func (r *SetDateRecord) String() string { 53 | return fmt.Sprintf("", r.txNum, r.block, r.offset, r.value.String()) 54 | } 55 | 56 | func (r *SetDateRecord) Undo(tx *Transaction) error { 57 | if err := tx.Pin(r.block); err != nil { 58 | return err 59 | } 60 | defer tx.Unpin(r.block) 61 | return tx.SetDate(r.block, r.offset, r.value, false) 62 | } 63 | 64 | func WriteSetDateToLog(logManager *log.Manager, txNum int, block *file.BlockId, offset int, val time.Time) (int, error) { 65 | operationPos := 0 66 | txNumPos := operationPos + types.IntSize 67 | fileNamePos := txNumPos + types.IntSize 68 | fileName := block.Filename() 69 | 70 | blockNumPos := fileNamePos + file.MaxLength(len(fileName)) 71 | blockNum := block.Number() 72 | 73 | offsetPos := blockNumPos + types.IntSize 74 | valuePos := offsetPos + types.IntSize 75 | // time.Time stored as int64 (8 bytes) 76 | recordLen := valuePos + 8 77 | 78 | recordBytes := make([]byte, recordLen) 79 | page := file.NewPageFromBytes(recordBytes) 80 | 81 | page.SetInt(operationPos, int(SetDate)) 82 | page.SetInt(txNumPos, txNum) 83 | if err := page.SetString(fileNamePos, fileName); err != nil { 84 | return -1, err 85 | } 86 | page.SetInt(blockNumPos, blockNum) 87 | page.SetInt(offsetPos, offset) 88 | page.SetDate(valuePos, val) 89 | 90 | return logManager.Append(recordBytes) 91 | } 92 | -------------------------------------------------------------------------------- /tx/set_int.go: 
-------------------------------------------------------------------------------- 1 | package tx 2 | 3 | import ( 4 | "fmt" 5 | "github.com/JyotinderSingh/dropdb/file" 6 | "github.com/JyotinderSingh/dropdb/log" 7 | "github.com/JyotinderSingh/dropdb/types" 8 | ) 9 | 10 | type SetIntRecord struct { 11 | LogRecord 12 | txNum int 13 | offset int 14 | value int 15 | block *file.BlockId 16 | } 17 | 18 | // NewSetIntRecord creates a new SetIntRecord from a Page. 19 | func NewSetIntRecord(page *file.Page) (*SetIntRecord, error) { 20 | operationPos := 0 21 | txNumPos := operationPos + types.IntSize 22 | txNum := page.GetInt(txNumPos) 23 | 24 | fileNamePos := txNumPos + types.IntSize 25 | fileName, err := page.GetString(fileNamePos) 26 | if err != nil { 27 | return nil, err 28 | } 29 | 30 | blockNumPos := fileNamePos + file.MaxLength(len(fileName)) 31 | blockNum := page.GetInt(blockNumPos) 32 | block := &file.BlockId{File: fileName, BlockNumber: int(blockNum)} 33 | 34 | offsetPos := blockNumPos + types.IntSize 35 | offset := page.GetInt(offsetPos) 36 | 37 | valuePos := offsetPos + types.IntSize 38 | value := page.GetInt(valuePos) 39 | 40 | return &SetIntRecord{txNum: txNum, offset: offset, value: value, block: block}, nil 41 | } 42 | 43 | // Op returns the type of the log record. 44 | func (r *SetIntRecord) Op() LogRecordType { 45 | return SetInt 46 | } 47 | 48 | // TxNumber returns the transaction number stored in the log record. 49 | func (r *SetIntRecord) TxNumber() int { 50 | return r.txNum 51 | } 52 | 53 | // String returns a string representation of the log record. 54 | func (r *SetIntRecord) String() string { 55 | return fmt.Sprintf("", r.txNum, r.block, r.offset, r.value) 56 | } 57 | 58 | // Undo replaces the specified data value with the value saved in the log record. 59 | // The method pins a buffer to the specified block, 60 | // calls setInt to restore the saved value, 61 | // and unpins the buffer. 
62 | func (r *SetIntRecord) Undo(tx *Transaction) error { 63 | if err := tx.Pin(r.block); err != nil { 64 | return err 65 | } 66 | defer tx.Unpin(r.block) 67 | return tx.SetInt(r.block, r.offset, r.value, false) 68 | } 69 | 70 | // WriteSetIntToLog writes a SetInt record to the log. The record contains the specified transaction number, the 71 | // filename and block number of the block containing the int, the offset of the int in the block, and the new value 72 | // of the int. 73 | // The method returns the LSN of the new log record. 74 | func WriteSetIntToLog(logManager *log.Manager, txNum int, block *file.BlockId, offset, val int) (int, error) { 75 | operationPos := 0 76 | txNumPos := operationPos + types.IntSize 77 | fileNamePos := txNumPos + types.IntSize 78 | fileName := block.Filename() 79 | 80 | blockNumPos := fileNamePos + file.MaxLength(len(block.File)) 81 | blockNum := block.Number() 82 | 83 | offsetPos := blockNumPos + types.IntSize 84 | valuePos := offsetPos + types.IntSize 85 | recordLen := valuePos + types.IntSize 86 | 87 | recordBytes := make([]byte, recordLen) 88 | page := file.NewPageFromBytes(recordBytes) 89 | 90 | page.SetInt(operationPos, int(SetInt)) 91 | page.SetInt(txNumPos, txNum) 92 | if err := page.SetString(fileNamePos, fileName); err != nil { 93 | return -1, err 94 | } 95 | page.SetInt(blockNumPos, blockNum) 96 | page.SetInt(offsetPos, offset) 97 | page.SetInt(valuePos, val) 98 | 99 | return logManager.Append(recordBytes) 100 | } 101 | -------------------------------------------------------------------------------- /tx/set_long.go: -------------------------------------------------------------------------------- 1 | package tx 2 | 3 | import ( 4 | "fmt" 5 | "github.com/JyotinderSingh/dropdb/file" 6 | "github.com/JyotinderSingh/dropdb/log" 7 | "github.com/JyotinderSingh/dropdb/types" 8 | ) 9 | 10 | type SetLongRecord struct { 11 | LogRecord 12 | txNum int 13 | offset int 14 | value int64 15 | block *file.BlockId 16 | } 17 | 18 | func 
NewSetLongRecord(page *file.Page) (*SetLongRecord, error) { 19 | operationPos := 0 20 | txNumPos := operationPos + types.IntSize 21 | txNum := page.GetInt(txNumPos) 22 | 23 | fileNamePos := txNumPos + types.IntSize 24 | fileName, err := page.GetString(fileNamePos) 25 | if err != nil { 26 | return nil, err 27 | } 28 | 29 | blockNumPos := fileNamePos + file.MaxLength(len(fileName)) 30 | blockNum := page.GetInt(blockNumPos) 31 | block := &file.BlockId{File: fileName, BlockNumber: int(blockNum)} 32 | 33 | offsetPos := blockNumPos + types.IntSize 34 | offset := page.GetInt(offsetPos) 35 | 36 | valuePos := offsetPos + types.IntSize 37 | val := page.GetLong(valuePos) // 8 bytes long 38 | 39 | return &SetLongRecord{txNum: txNum, offset: offset, value: val, block: block}, nil 40 | } 41 | 42 | func (r *SetLongRecord) Op() LogRecordType { 43 | return SetLong 44 | } 45 | 46 | func (r *SetLongRecord) TxNumber() int { 47 | return r.txNum 48 | } 49 | 50 | func (r *SetLongRecord) String() string { 51 | return fmt.Sprintf("", r.txNum, r.block, r.offset, r.value) 52 | } 53 | 54 | func (r *SetLongRecord) Undo(tx *Transaction) error { 55 | if err := tx.Pin(r.block); err != nil { 56 | return err 57 | } 58 | defer tx.Unpin(r.block) 59 | return tx.SetLong(r.block, r.offset, r.value, false) 60 | } 61 | 62 | func WriteSetLongToLog(logManager *log.Manager, txNum int, block *file.BlockId, offset int, val int64) (int, error) { 63 | operationPos := 0 64 | txNumPos := operationPos + types.IntSize 65 | fileNamePos := txNumPos + types.IntSize 66 | fileName := block.Filename() 67 | 68 | blockNumPos := fileNamePos + file.MaxLength(len(fileName)) 69 | blockNum := block.Number() 70 | 71 | offsetPos := blockNumPos + types.IntSize 72 | valuePos := offsetPos + types.IntSize 73 | // int64 is 8 bytes 74 | recordLen := valuePos + 8 75 | 76 | recordBytes := make([]byte, recordLen) 77 | page := file.NewPageFromBytes(recordBytes) 78 | 79 | page.SetInt(operationPos, int(SetLong)) 80 | page.SetInt(txNumPos, 
txNum) 81 | if err := page.SetString(fileNamePos, fileName); err != nil { 82 | return -1, err 83 | } 84 | page.SetInt(blockNumPos, blockNum) 85 | page.SetInt(offsetPos, offset) 86 | page.SetLong(valuePos, val) 87 | 88 | return logManager.Append(recordBytes) 89 | } 90 | -------------------------------------------------------------------------------- /tx/set_short.go: -------------------------------------------------------------------------------- 1 | package tx 2 | 3 | import ( 4 | "fmt" 5 | "github.com/JyotinderSingh/dropdb/file" 6 | "github.com/JyotinderSingh/dropdb/log" 7 | "github.com/JyotinderSingh/dropdb/types" 8 | ) 9 | 10 | type SetShortRecord struct { 11 | LogRecord 12 | txNum int 13 | offset int 14 | value int16 15 | block *file.BlockId 16 | } 17 | 18 | func NewSetShortRecord(page *file.Page) (*SetShortRecord, error) { 19 | operationPos := 0 20 | txNumPos := operationPos + types.IntSize 21 | txNum := page.GetInt(txNumPos) 22 | 23 | fileNamePos := txNumPos + types.IntSize 24 | fileName, err := page.GetString(fileNamePos) 25 | if err != nil { 26 | return nil, err 27 | } 28 | 29 | blockNumPos := fileNamePos + file.MaxLength(len(fileName)) 30 | blockNum := page.GetInt(blockNumPos) 31 | block := &file.BlockId{File: fileName, BlockNumber: int(blockNum)} 32 | 33 | offsetPos := blockNumPos + types.IntSize 34 | offset := page.GetInt(offsetPos) 35 | 36 | valuePos := offsetPos + types.IntSize 37 | val := page.GetShort(valuePos) 38 | 39 | return &SetShortRecord{txNum: txNum, offset: offset, value: val, block: block}, nil 40 | } 41 | 42 | func (r *SetShortRecord) Op() LogRecordType { 43 | return SetShort 44 | } 45 | 46 | func (r *SetShortRecord) TxNumber() int { 47 | return r.txNum 48 | } 49 | 50 | func (r *SetShortRecord) String() string { 51 | return fmt.Sprintf("", r.txNum, r.block, r.offset, r.value) 52 | } 53 | 54 | func (r *SetShortRecord) Undo(tx *Transaction) error { 55 | if err := tx.Pin(r.block); err != nil { 56 | return err 57 | } 58 | defer 
tx.Unpin(r.block) 59 | return tx.SetShort(r.block, r.offset, r.value, false) 60 | } 61 | 62 | func WriteSetShortToLog(logManager *log.Manager, txNum int, block *file.BlockId, offset int, val int16) (int, error) { 63 | operationPos := 0 64 | txNumPos := operationPos + types.IntSize 65 | fileNamePos := txNumPos + types.IntSize 66 | fileName := block.Filename() 67 | 68 | blockNumPos := fileNamePos + file.MaxLength(len(fileName)) 69 | blockNum := block.Number() 70 | 71 | offsetPos := blockNumPos + types.IntSize 72 | valuePos := offsetPos + types.IntSize 73 | // int16 is 2 bytes 74 | recordLen := valuePos + 2 75 | 76 | recordBytes := make([]byte, recordLen) 77 | page := file.NewPageFromBytes(recordBytes) 78 | 79 | page.SetInt(operationPos, int(SetShort)) 80 | page.SetInt(txNumPos, txNum) 81 | if err := page.SetString(fileNamePos, fileName); err != nil { 82 | return -1, err 83 | } 84 | page.SetInt(blockNumPos, blockNum) 85 | page.SetInt(offsetPos, offset) 86 | page.SetShort(valuePos, val) 87 | 88 | return logManager.Append(recordBytes) 89 | } 90 | -------------------------------------------------------------------------------- /tx/set_string.go: -------------------------------------------------------------------------------- 1 | package tx 2 | 3 | import ( 4 | "fmt" 5 | "github.com/JyotinderSingh/dropdb/file" 6 | "github.com/JyotinderSingh/dropdb/log" 7 | "github.com/JyotinderSingh/dropdb/types" 8 | ) 9 | 10 | type SetStringRecord struct { 11 | LogRecord 12 | txNum int 13 | offset int 14 | value string 15 | block *file.BlockId 16 | } 17 | 18 | // NewSetStringRecord creates a new SetStringRecord from a Page. 
19 | func NewSetStringRecord(page *file.Page) (*SetStringRecord, error) { 20 | operationPos := 0 21 | txNumPos := operationPos + types.IntSize 22 | txNum := page.GetInt(txNumPos) 23 | 24 | fileNamePos := txNumPos + types.IntSize 25 | fileName, err := page.GetString(fileNamePos) 26 | if err != nil { 27 | return nil, err 28 | } 29 | 30 | blockNumPos := fileNamePos + file.MaxLength(len(fileName)) 31 | blockNum := page.GetInt(blockNumPos) 32 | block := &file.BlockId{File: fileName, BlockNumber: int(blockNum)} 33 | 34 | offsetPos := blockNumPos + types.IntSize 35 | offset := page.GetInt(offsetPos) 36 | 37 | valuePos := offsetPos + types.IntSize 38 | value, err := page.GetString(valuePos) 39 | if err != nil { 40 | return nil, err 41 | } 42 | 43 | return &SetStringRecord{txNum: txNum, offset: offset, value: value, block: block}, nil 44 | } 45 | 46 | // Op returns the type of the log record. 47 | func (r *SetStringRecord) Op() LogRecordType { 48 | return SetString 49 | } 50 | 51 | // TxNumber returns the transaction number stored in the log record. 52 | func (r *SetStringRecord) TxNumber() int { 53 | return r.txNum 54 | } 55 | 56 | // String returns a string representation of the log record. 57 | func (r *SetStringRecord) String() string { 58 | return fmt.Sprintf("", r.txNum, r.block, r.offset, r.value) 59 | } 60 | 61 | // Undo replaces the specified data value with the value saved in the log record. 62 | // The method pins a buffer to the specified block, 63 | // calls the buffer's setString method to restore the saved value, and unpins the buffer. 64 | func (r *SetStringRecord) Undo(tx *Transaction) error { 65 | if err := tx.Pin(r.block); err != nil { 66 | return err 67 | } 68 | defer tx.Unpin(r.block) 69 | return tx.SetString(r.block, r.offset, r.value, false) // Don't log the undo 70 | } 71 | 72 | // WriteSetStringToLog writes a set string record to the log. 
The record contains the specified transaction number, the 73 | // filename and block number of the block containing the string, the offset of the string in the block, and the new value 74 | // of the string. 75 | // The method returns the LSN of the new log record. 76 | func WriteSetStringToLog(logManager *log.Manager, txNum int, block *file.BlockId, offset int, value string) (int, error) { 77 | operationPos := 0 78 | txNumPos := operationPos + types.IntSize 79 | fileNamePos := txNumPos + types.IntSize 80 | fileName := block.Filename() 81 | 82 | blockNumPos := fileNamePos + file.MaxLength(len(fileName)) 83 | blockNum := block.Number() 84 | 85 | offsetPos := blockNumPos + types.IntSize 86 | valuePos := offsetPos + types.IntSize 87 | recordLen := valuePos + file.MaxLength(len(value)) 88 | 89 | recordBytes := make([]byte, recordLen) 90 | page := file.NewPageFromBytes(recordBytes) 91 | 92 | page.SetInt(operationPos, int(SetString)) 93 | page.SetInt(txNumPos, txNum) 94 | if err := page.SetString(fileNamePos, fileName); err != nil { 95 | return -1, err 96 | } 97 | page.SetInt(blockNumPos, blockNum) 98 | page.SetInt(offsetPos, offset) 99 | if err := page.SetString(valuePos, value); err != nil { 100 | return -1, err 101 | } 102 | 103 | return logManager.Append(recordBytes) 104 | } 105 | -------------------------------------------------------------------------------- /tx/start.go: -------------------------------------------------------------------------------- 1 | package tx 2 | 3 | import ( 4 | "fmt" 5 | "github.com/JyotinderSingh/dropdb/file" 6 | "github.com/JyotinderSingh/dropdb/log" 7 | "github.com/JyotinderSingh/dropdb/types" 8 | ) 9 | 10 | type StartRecord struct { 11 | LogRecord 12 | txNum int 13 | } 14 | 15 | // NewStartRecord creates a new StartRecord from a Page. 
16 | func NewStartRecord(page *file.Page) (*StartRecord, error) { 17 | operationPos := 0 18 | txNumPos := operationPos + types.IntSize 19 | txNum := page.GetInt(txNumPos) 20 | 21 | return &StartRecord{txNum: txNum}, nil 22 | } 23 | 24 | // Op returns the type of the log record. 25 | func (r *StartRecord) Op() LogRecordType { 26 | return Start 27 | } 28 | 29 | // TxNumber returns the transaction number stored in the log record. 30 | func (r *StartRecord) TxNumber() int { 31 | return r.txNum 32 | } 33 | 34 | // Undo does nothing. StartRecord does not change any data. 35 | func (r *StartRecord) Undo(_ *Transaction) error { 36 | return nil 37 | } 38 | 39 | // String returns a string representation of the log record. 40 | func (r *StartRecord) String() string { 41 | return fmt.Sprintf("", r.txNum) 42 | } 43 | 44 | // WriteStartToLog writes a start record to the log. This log record contains the Start operator, 45 | // followed by the transaction id. 46 | // The method returns the LSN of the new log record. 47 | func WriteStartToLog(logManager *log.Manager, txNum int) (int, error) { 48 | record := make([]byte, 2*types.IntSize) 49 | 50 | page := file.NewPageFromBytes(record) 51 | page.SetInt(0, int(Start)) 52 | page.SetInt(4, txNum) 53 | 54 | return logManager.Append(record) 55 | } 56 | -------------------------------------------------------------------------------- /types/comparisons.go: -------------------------------------------------------------------------------- 1 | package types 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | ) 7 | 8 | // CompareSupportedTypes handles comparison for supported types. 9 | func CompareSupportedTypes(lhs, rhs any, op Operator) bool { 10 | // Handle nil values explicitly 11 | if lhs == nil || rhs == nil { 12 | return false // Null comparisons always return false in SQL semantics 13 | } 14 | 15 | // First try to unify integer types: 16 | // Using int to make it simpler to handle types across the db. 
// toInt attempts to convert an interface value to int.
// It returns (convertedValue, true) on success and (0, false) for any
// value that is not an int, int64, or int16.
// Note: int64 -> int narrows on 32-bit architectures.
func toInt(i any) (int, bool) {
	if v, ok := i.(int); ok {
		return v, true
	}
	if v, ok := i.(int64); ok {
		return int(v), true
	}
	if v, ok := i.(int16); ok {
		return int(v), true
	}
	return 0, false
}
87 | func compareInt64s(lhs, rhs int64, op Operator) bool { 88 | switch op { 89 | case NE: 90 | return lhs != rhs 91 | case EQ: 92 | return lhs == rhs 93 | case LT: 94 | return lhs < rhs 95 | case LE: 96 | return lhs <= rhs 97 | case GT: 98 | return lhs > rhs 99 | case GE: 100 | return lhs >= rhs 101 | default: 102 | fmt.Printf("unsupported operator: %v\n", op) 103 | return false 104 | } 105 | } 106 | 107 | // compareInt16s compares two int16 values. 108 | func compareInt16s(lhs, rhs int16, op Operator) bool { 109 | switch op { 110 | case NE: 111 | return lhs != rhs 112 | case EQ: 113 | return lhs == rhs 114 | case LT: 115 | return lhs < rhs 116 | case LE: 117 | return lhs <= rhs 118 | case GT: 119 | return lhs > rhs 120 | case GE: 121 | return lhs >= rhs 122 | default: 123 | fmt.Printf("unsupported operator: %v\n", op) 124 | return false 125 | } 126 | } 127 | 128 | // compareStrings compares two strings. 129 | func compareStrings(lhs, rhs string, op Operator) bool { 130 | switch op { 131 | case NE: 132 | return lhs != rhs 133 | case EQ: 134 | return lhs == rhs 135 | case LT: 136 | return lhs < rhs 137 | case LE: 138 | return lhs <= rhs 139 | case GT: 140 | return lhs > rhs 141 | case GE: 142 | return lhs >= rhs 143 | default: 144 | fmt.Printf("unsupported operator: %v\n", op) 145 | return false 146 | } 147 | } 148 | 149 | // compareBools compares two booleans (only equality comparisons make sense). 150 | func compareBools(lhs, rhs bool, op Operator) bool { 151 | switch op { 152 | case EQ: 153 | return lhs == rhs 154 | case NE: 155 | return lhs != rhs 156 | default: 157 | fmt.Printf("unsupported operator: %v\n", op) 158 | return false // Invalid for comparison operators like <, > 159 | } 160 | } 161 | 162 | // compareTimes compares two time.Time values. 
// Hash maps a supported value to an int suitable for bucketing.
// Supported types: int, int64, int16, string, bool, time.Time.
// nil and any unsupported type hash to 0.
func Hash(value any) int {
	switch v := value.(type) {
	case int:
		return v
	case int64:
		return int(v)
	case int16:
		return int(v)
	case string:
		// Sum of the code points (same weak additive scheme as before;
		// changing it would change on-disk bucket assignments).
		sum := 0
		for _, r := range v {
			sum += int(r)
		}
		return sum
	case bool:
		if v {
			return 1
		}
		return 0
	case time.Time:
		return int(v.Unix())
	}
	// nil falls through to here as well: a nil value matches no case in
	// a type switch.
	return 0
}
// IntSize is the size of int in bytes on this architecture.
// It defaults to 8 and is corrected to 4 at startup on 32-bit targets.
var IntSize = 8

func init() {
	// Fixed: the original check only listed "386" and "arm", but the gc
	// toolchain also targets the 32-bit architectures "mips" and
	// "mipsle", on which int is 4 bytes as well.
	switch runtime.GOARCH {
	case "386", "arm", "mips", "mipsle":
		IntSize = 4
	}
}
45 | func OperatorFromString(op string) (Operator, error) { 46 | switch op { 47 | case "=": 48 | return EQ, nil 49 | case "<>", "!=": 50 | return NE, nil 51 | case "<": 52 | return LT, nil 53 | case "<=": 54 | return LE, nil 55 | case ">": 56 | return GT, nil 57 | case ">=": 58 | return GE, nil 59 | default: 60 | return -1, fmt.Errorf("invalid operator: %s", op) 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /utils/hash_value.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "hash/fnv" 7 | "time" 8 | ) 9 | 10 | // HashValue hashes a variety of types using fnv 11 | func HashValue(val interface{}) (uint32, error) { 12 | h := fnv.New32a() // Create a 32-bit FNV-1a hash 13 | 14 | switch v := val.(type) { 15 | case int16: 16 | _, err := fmt.Fprintf(h, "%d", v) 17 | if err != nil { 18 | return 0, fmt.Errorf("failed to hash int16: %w", err) 19 | } 20 | case int: 21 | _, err := fmt.Fprintf(h, "%d", v) 22 | if err != nil { 23 | return 0, fmt.Errorf("failed to hash int: %w", err) 24 | } 25 | case int64: 26 | _, err := fmt.Fprintf(h, "%d", v) 27 | if err != nil { 28 | return 0, fmt.Errorf("failed to hash int64: %w", err) 29 | } 30 | case string: 31 | _, err := h.Write([]byte(v)) 32 | if err != nil { 33 | return 0, fmt.Errorf("failed to hash string: %w", err) 34 | } 35 | case bool: 36 | _, err := fmt.Fprintf(h, "%t", v) 37 | if err != nil { 38 | return 0, fmt.Errorf("failed to hash bool: %w", err) 39 | } 40 | case time.Time: 41 | _, err := h.Write([]byte(v.String())) 42 | if err != nil { 43 | return 0, fmt.Errorf("failed to hash time.Time: %w", err) 44 | } 45 | case nil: 46 | return 0, errors.New("cannot hash nil value") 47 | default: 48 | return 0, fmt.Errorf("unsupported type: %T", v) 49 | } 50 | 51 | return h.Sum32(), nil 52 | } 53 | -------------------------------------------------------------------------------- 
/utils/hash_value_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "github.com/stretchr/testify/require" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | func TestHashValue(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | input interface{} 14 | wantErr bool 15 | }{ 16 | {"int16", int16(123), false}, 17 | {"int", 456, false}, 18 | {"int64", int64(789), false}, 19 | {"string", "test string", false}, 20 | {"bool true", true, false}, 21 | {"bool false", false, false}, 22 | {"time.Time", time.Now(), false}, 23 | {"nil", nil, true}, 24 | {"unsupported type", struct{}{}, true}, 25 | } 26 | 27 | for _, tt := range tests { 28 | t.Run(tt.name, func(t *testing.T) { 29 | got, err := HashValue(tt.input) 30 | 31 | if tt.wantErr { 32 | assert.Error(t, err, "Expected an error for input: %v", tt.input) 33 | } else { 34 | require.NoError(t, err, "Did not expect an error for input: %v", tt.input) 35 | assert.NotEqual(t, uint32(0), got, "Hash value should not be 0 for input: %v", tt.input) 36 | } 37 | }) 38 | } 39 | } 40 | --------------------------------------------------------------------------------