├── .gitignore ├── .vscode └── launch.json ├── Makefile ├── README.md ├── antlr-4.13.1-complete.jar ├── cmd ├── cli │ └── main.go └── server │ └── main.go ├── go.mod ├── go.sum ├── internal ├── buffer │ ├── buffer.go │ ├── buffer_manager.go │ └── buffer_manager_test.go ├── file │ ├── block_id.go │ ├── byte_buffer.go │ ├── file_manager.go │ ├── file_manager_test.go │ ├── page.go │ └── random_access_file.go ├── index │ ├── hash │ │ └── hash_index.go │ └── index.go ├── log │ ├── log_iterator.go │ ├── log_manager.go │ └── log_manager_test.go ├── metadata │ ├── catalog_test.go │ ├── index_info.go │ ├── index_manager.go │ ├── metadata_manager.go │ ├── metadata_manager_test.go │ ├── stat_info.go │ ├── stats_manager.go │ ├── table_manager.go │ ├── table_manager_test.go │ └── view_manager.go ├── parser │ ├── SimpleSql.g4 │ ├── SimpleSql.interp │ ├── SimpleSql.tokens │ ├── SimpleSqlLexer.interp │ ├── SimpleSqlLexer.tokens │ ├── ast.go │ ├── parser_test.go │ ├── simplesql_base_visitor.go │ ├── simplesql_lexer.go │ ├── simplesql_parser.go │ ├── simplesql_visitor.go │ └── visitor.go ├── plan │ ├── basic_query_planner.go │ ├── basic_update_planner.go │ ├── plan.go │ ├── planner.go │ ├── product_plan.go │ ├── project_plan.go │ ├── query_planner.go │ ├── select_plan.go │ ├── table_plan.go │ └── update_planner.go ├── query │ ├── constant.go │ ├── expression.go │ ├── predicate.go │ ├── product_scan.go │ ├── project_scan.go │ ├── rid.go │ ├── scan.go │ ├── select_scan.go │ ├── term.go │ └── update_scan.go ├── record │ ├── layout.go │ ├── layout_test.go │ ├── record_page.go │ ├── record_page_test.go │ ├── schema.go │ ├── table_scan.go │ └── table_scan_test.go ├── server │ └── simpledb.go ├── tx │ ├── concurency_test.go │ ├── concurrency │ │ ├── concurrency_manager.go │ │ └── lock_table.go │ ├── recovery │ │ ├── buffer_list.go │ │ ├── checkpoint_record.go │ │ ├── commit_record.go │ │ ├── log_record.go │ │ ├── recovery_manager.go │ │ ├── rollback_record.go │ │ ├── set_int_record.go │ │ 
├── set_string_record.go │ │ ├── start_record.go │ │ └── transaction.go │ └── transaction_test.go └── utils │ ├── const.go │ └── utils.go └── screenshot.png /.gitignore: -------------------------------------------------------------------------------- 1 | **/data 2 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | 8 | { 9 | "name": "Launch Package", 10 | "type": "go", 11 | "request": "launch", 12 | "mode": "auto", 13 | "program": "cmd/cli/main.go" 14 | } 15 | ] 16 | } 17 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | 2 | run: 3 | go run ./cmd/cli/main.go 4 | 5 | parser: 6 | cd ./internal/parser && antlr4 -Dlanguage=Go -visitor -no-listener -package parser SimpleSql.g4 7 | 8 | parser-pigeon: 9 | pigeon -o ./internal/parser_pigeon/parser.go ./internal/parser_pigeon/simpledb_sql.peg 10 | 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # simpledb-go 2 | 3 | This project implements SimpleDB in golang. 
SimpleDB is an educational
4 | database management system presented in Edward Sciore's book titled
5 | [Database Design and Implementation, 2nd Edition](https://www.amazon.com/Database-Design-Implementation-Edward-Sciore/dp/0471757160)
6 | 
7 | ![image](./screenshot.png)
8 | 
9 | # Features not implemented yet
10 | - Support for indexing (B+Tree, Hash)
11 | - Support for aggregate queries
12 | 
13 | ## References:
14 | 
15 | - https://howqueryengineswork.com/00-introduction.html
16 | 
17 | 
18 | 
19 | 

--------------------------------------------------------------------------------
/antlr-4.13.1-complete.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/evanxg852000/simpledb-go/9ab9079c4ec66200f8291fe8af4fbddd891a1b9a/antlr-4.13.1-complete.jar

--------------------------------------------------------------------------------
/cmd/cli/main.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"fmt"
5 | 	"os"
6 | 	"path"
7 | 	"strings"
8 | 
9 | 	"github.com/c-bata/go-prompt"
10 | 	"github.com/evanxg852000/simpledb/internal/plan"
11 | 	"github.com/evanxg852000/simpledb/internal/server"
12 | 	"github.com/jedib0t/go-pretty/v6/table"
13 | )
14 | 
15 | // main starts an interactive SQL prompt backed by a SimpleDB database stored
16 | // under ./data/students_db. Every statement runs in its own transaction;
17 | // typing `.exit` quits.
18 | func main() {
19 | 	fmt.Println("Welcome SimpleDB CLI!")
20 | 	workspaceDir := "./data"
21 | 	if err := os.MkdirAll(workspaceDir, 0755); err != nil {
22 | 		// Without the workspace directory the database cannot be opened.
23 | 		fmt.Println("error initializing database: ", err)
24 | 		os.Exit(1)
25 | 	}
26 | 
27 | 	dbDir := path.Join(workspaceDir, "students_db")
28 | 	db := server.NewSimpleDB(dbDir, 4096, 10)
29 | 
30 | 	if db.FileManager().IsNew() {
31 | 		prepareUsersTable(db)
32 | 	}
33 | 
34 | 	for {
35 | 		sqlInput := prompt.Input(">> ", completer,
36 | 			prompt.OptionTitle("sql-prompt"),
37 | 			prompt.OptionHistory([]string{"select id, name from users"}),
38 | 			prompt.OptionPrefixTextColor(prompt.Yellow),
39 | 			prompt.OptionPreviewSuggestionTextColor(prompt.Blue),
40 | 			prompt.OptionSelectedSuggestionBGColor(prompt.LightGray),
41 | 			prompt.OptionSuggestionBGColor(prompt.DarkGray),
42 | 		)
43 | 
44 | 		// Ignore surrounding whitespace so blank lines are skipped and
45 | 		// ".exit " still quits.
46 | 		sqlInput = strings.TrimSpace(sqlInput)
47 | 		if sqlInput == "" {
48 | 			continue
49 | 		}
50 | 		if sqlInput == ".exit" {
51 | 			os.Exit(0)
52 | 		}
53 | 
54 | 		runStatement(db, sqlInput)
55 | 	}
56 | }
57 | 
58 | // runStatement executes one SQL statement in a fresh transaction and prints
59 | // its outcome. A query's rows are fetched BEFORE the transaction commits so
60 | // the scan runs inside a live transaction; scanning after Commit would pin
61 | // buffers and take locks on behalf of a transaction that already finished
62 | // (NOTE(review): relies on the tx package's commit semantics — confirm).
63 | func runStatement(db *server.SimpleDB, sqlInput string) {
64 | 	tx := db.NewTx()
65 | 	stmtResult, err := db.Planner().ExecuteQuery(sqlInput, tx)
66 | 	if err != nil {
67 | 		tx.Rollback()
68 | 		fmt.Println(err)
69 | 		return
70 | 	}
71 | 
72 | 	if affectedRows, ok := stmtResult.(int64); ok {
73 | 		tx.Commit()
74 | 		fmt.Println("Affected rows:", affectedRows)
75 | 		return
76 | 	}
77 | 
78 | 	headers, rows := fetchResult(stmtResult)
79 | 	tx.Commit()
80 | 	renderDataTable(headers, rows)
81 | 	fmt.Println("")
82 | }
83 | 
84 | func completer(in prompt.Document) []prompt.Suggest {
85 | 	s := []prompt.Suggest{}
86 | 	return prompt.FilterHasPrefix(s, in.GetWordBeforeCursor(), true)
87 | }
88 | 
89 | func renderDataTable(headers table.Row, rows []table.Row) {
90 | 	t := table.NewWriter()
91 | 	t.SetOutputMirror(os.Stdout)
92 | 	t.AppendHeader(headers)
93 | 	t.AppendRows(rows)
94 | 	t.Render()
95 | }
96 | 
97 | func prepareUsersTable(db *server.SimpleDB) {
98 | 	t1 := db.NewTx()
99 | 	_, err := db.Planner().ExecuteQuery("create table users(id int, name varchar(10))", t1)
100 | 	if err != nil {
101 | 		t1.Rollback()
102 | 		panic(err)
103 | 	}
104 | 	t1.Commit()
105 | 
106 | 	names := []string{"Evan", "John", "Jane", "Rodriguez", "Samuel", "Mauris"}
107 | 	t2 := db.NewTx()
108 | 	for id, name := range names {
109 | 		sqlStmt := fmt.Sprintf("insert into users(id, name) values(%d, '%s')", id+1, name)
110 | 		_, err = db.Planner().ExecuteQuery(sqlStmt, t2)
111 | 		if err != nil {
112 | 			t2.Rollback()
113 | 			panic(err)
114 | 		}
115 | 	}
116 | 
117 | 	t2.Commit()
118 | }
119 | 
120 | // fetchResult opens the query plan held in result, drains its scan, and
121 | // returns the header row (the plan's field names) plus one row per record,
122 | // with every value rendered through its String method.
123 | func fetchResult(result any) (table.Row, []table.Row) {
124 | 	p := result.(plan.Plan) // named p so the plan package is not shadowed
125 | 	fields := p.Schema().Fields()
126 | 	scan := p.Open()
127 | 	rows := make([]table.Row, 0)
128 | 	for scan.Next() {
129 | 		row := make([]string, 0, len(fields))
130 | 		for _, fieldName := range fields {
131 | 			row = append(row, scan.GetValue(fieldName).String())
132 | 		}
133 | 		rows = append(rows, makeRow(row))
134 | 	}
135 | 	return makeRow(fields), rows
136 | }
137 | 
138 | func makeRow(data []string) []any {
139 | 	rows := make([]any, 0, len(data))
140 | 	for _, item := range data {
141 | 		rows = append(rows, item)
142 | 	}
143 | 	return rows
144 | }
145 | 

--------------------------------------------------------------------------------
/cmd/server/main.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"fmt"
5 | 
6 | 	"github.com/evanxg852000/simpledb/internal/file"
7 | )
8 | 
9 | func main() {
10 | 	fmt.Println("Welcome SimpleDB Server!")
11 | 	_ = file.NewBlockId("users.tbl", 32)
12 | }
13 | 

--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/evanxg852000/simpledb
2 | 
3 | go 1.22.0
4 | 
5 | require (
6 | 	github.com/antlr/antlr4 v0.0.0-20181218183524-be58ebffde8e
7 | 	github.com/c-bata/go-prompt v0.2.6
8 | 	github.com/jedib0t/go-pretty/v6 v6.5.9
9 | 	github.com/stretchr/testify v1.9.0
10 | )
11 | 
12 | require (
13 | 	github.com/davecgh/go-spew v1.1.1 // indirect
14 | 	github.com/mattn/go-colorable v0.1.7 // indirect
15 | 	github.com/mattn/go-isatty v0.0.12 // indirect
16 | 	github.com/mattn/go-runewidth v0.0.15 // indirect
17 | 	github.com/mattn/go-tty v0.0.3 // indirect
18 | 	github.com/pkg/term v1.2.0-beta.2 // indirect
19 | 	github.com/pmezard/go-difflib v1.0.0 // indirect
20 | 	github.com/rivo/uniseg v0.2.0 // indirect
21 | 	golang.org/x/sys v0.17.0 // indirect
22 | 	gopkg.in/yaml.v3 v3.0.1 // indirect
23 | )
24 | 

--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/antlr/antlr4 v0.0.0-20181218183524-be58ebffde8e h1:yxMh4HIdsSh2EqxUESWvzszYMNzOugRyYCeohfwNULM= 2 | github.com/antlr/antlr4 v0.0.0-20181218183524-be58ebffde8e/go.mod h1:T7PbCXFs94rrTttyxjbyT5+/1V8T2TYDejxUfHJjw1Y= 3 | github.com/c-bata/go-prompt v0.2.6 h1:POP+nrHE+DfLYx370bedwNhsqmpCUynWPxuHi0C5vZI= 4 | github.com/c-bata/go-prompt v0.2.6/go.mod h1:/LMAke8wD2FsNu9EXNdHxNLbd9MedkPnCdfpU9wwHfY= 5 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 6 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 7 | github.com/jedib0t/go-pretty/v6 v6.5.9 h1:ACteMBRrrmm1gMsXe9PSTOClQ63IXDUt03H5U+UV8OU= 8 | github.com/jedib0t/go-pretty/v6 v6.5.9/go.mod h1:zbn98qrYlh95FIhwwsbIip0LYpwSG8SUOScs+v9/t0E= 9 | github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= 10 | github.com/mattn/go-colorable v0.1.7 h1:bQGKb3vps/j0E9GfJQ03JyhRuxsvdAanXlT9BTw3mdw= 11 | github.com/mattn/go-colorable v0.1.7/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= 12 | github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= 13 | github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcMEpPG5Rm84= 14 | github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= 15 | github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= 16 | github.com/mattn/go-runewidth v0.0.6/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= 17 | github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= 18 | github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= 19 | github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= 20 | github.com/mattn/go-tty v0.0.3 h1:5OfyWorkyO7xP52Mq7tB36ajHDG5OHrmBGIS/DtakQI= 21 | github.com/mattn/go-tty v0.0.3/go.mod h1:ihxohKRERHTVzN+aSVRwACLCeqIoZAWpoICkkvrWyR0= 22 | 
github.com/pkg/term v1.2.0-beta.2 h1:L3y/h2jkuBVFdWiJvNfYfKmzcCnILw7mJWm2JQuMppw= 23 | github.com/pkg/term v1.2.0-beta.2/go.mod h1:E25nymQcrSllhX42Ok8MRm1+hyBdHY0dCeiKZ9jpNGw= 24 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 25 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 26 | github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= 27 | github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= 28 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= 29 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 30 | golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 31 | golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 32 | golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 33 | golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 34 | golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 35 | golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 36 | golang.org/x/sys v0.0.0-20200909081042-eff7692f9009/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 37 | golang.org/x/sys v0.0.0-20200918174421-af09f7315aff/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 38 | golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= 39 | golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 40 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 41 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 
42 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
43 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
44 | 

--------------------------------------------------------------------------------
/internal/buffer/buffer.go:
--------------------------------------------------------------------------------
1 | package buffer
2 | 
3 | import (
4 | 	"github.com/evanxg852000/simpledb/internal/file"
5 | 	walog "github.com/evanxg852000/simpledb/internal/log"
6 | )
7 | 
8 | // Buffer is a data buffer that wraps a page and stores information about its
9 | // status, such as the associated disk block, the number of times the buffer
10 | // has been pinned, whether its contents have been modified, and if so, the id
11 | // and lsn of the modifying transaction.
12 | type Buffer struct {
13 | 	fileManager *file.FileManager
14 | 	logManager  *walog.LogManager
15 | 	content     file.Page
16 | 	blockId     file.BlockId
17 | 	pins        int64
18 | 	txNum       int64 // -1 while the buffer holds no unflushed modification
19 | 	lsn         int64 // -1 until a modification supplies a log sequence number
20 | }
21 | 
22 | func NewBuffer(fileManager *file.FileManager, logManager *walog.LogManager) Buffer {
23 | 	return Buffer{
24 | 		fileManager: fileManager,
25 | 		logManager:  logManager,
26 | 		content:     file.NewPage(fileManager.BlockSize()),
27 | 		blockId:     file.BlockId{},
28 | 		pins:        0,
29 | 		txNum:       -1,
30 | 		lsn:         -1,
31 | 	}
32 | }
33 | 
34 | func (buf *Buffer) Content() *file.Page {
35 | 	return &buf.content
36 | }
37 | 
38 | func (buf *Buffer) Block() file.BlockId {
39 | 	return buf.blockId
40 | }
41 | 
42 | // Modify marks the buffer as changed by transaction txNum. A negative lsn
43 | // denotes a change without a corresponding log record; in that case the
44 | // previously recorded lsn is kept.
45 | func (buf *Buffer) Modify(txNum int64, lsn int64) {
46 | 	buf.txNum = txNum
47 | 	if lsn >= 0 {
48 | 		buf.lsn = lsn
49 | 	}
50 | }
51 | 
52 | func (buf *Buffer) IsPinned() bool {
53 | 	return buf.pins > 0
54 | }
55 | 
56 | func (buf *Buffer) ModifyingTx() int64 {
57 | 	return buf.txNum
58 | }
59 | 
60 | // AssignToBlock reads the content of the specified block into the buffer.
61 | // If the buffer was dirty, its previous content is first written to disk;
62 | // if that flush fails, the buffer stays assigned to its old block and the
63 | // error is returned, so no modification is silently dropped. If the read
64 | // fails, the buffer is detached from any block so a later lookup cannot
65 | // mistake its stale content for the requested block.
66 | func (buf *Buffer) AssignToBlock(blockId file.BlockId) error {
67 | 	if err := buf.flush(); err != nil {
68 | 		return err
69 | 	}
70 | 	buf.blockId = blockId
71 | 	if err := buf.fileManager.Read(blockId, &buf.content); err != nil {
72 | 		buf.blockId = file.BlockId{}
73 | 		return err
74 | 	}
75 | 	buf.pins = 0
76 | 	return nil
77 | }
78 | 
79 | // flush writes the buffer content to its disk block if it is dirty. The log
80 | // is flushed up to the buffer's lsn before the page is written, so the log
81 | // records describing the change reach disk first. Any transaction number
82 | // >= 0 counts as dirty; -1 is the "clean" sentinel set by NewBuffer.
83 | func (buf *Buffer) flush() error {
84 | 	if buf.txNum < 0 {
85 | 		return nil
86 | 	}
87 | 	if err := buf.logManager.Flush(buf.lsn); err != nil {
88 | 		return err
89 | 	}
90 | 	if err := buf.fileManager.Write(buf.blockId, &buf.content); err != nil {
91 | 		return err
92 | 	}
93 | 	buf.txNum = -1
94 | 	return nil
95 | }
96 | 
97 | func (buf *Buffer) Pin() {
98 | 	buf.pins += 1
99 | }
100 | 
101 | func (buf *Buffer) Unpin() {
102 | 	buf.pins -= 1
103 | }
104 | 

--------------------------------------------------------------------------------
/internal/buffer/buffer_manager.go:
--------------------------------------------------------------------------------
1 | package buffer
2 | 
3 | import (
4 | 	"fmt"
5 | 	"sync"
6 | 	"time"
7 | 
8 | 	"github.com/evanxg852000/simpledb/internal/file"
9 | 	walog "github.com/evanxg852000/simpledb/internal/log"
10 | 	"github.com/evanxg852000/simpledb/internal/utils"
11 | )
12 | 
13 | const (
14 | 	MAX_WAIT_TIME = 5 * time.Second
15 | )
16 | 
17 | // Manages the pinning and unpinning of buffers to blocks.
18 | type BufferManager struct {
19 | 	bufferPool   []Buffer
20 | 	numAvailable int64
21 | 	mu           *sync.Mutex
22 | 	cond         *sync.Cond
23 | }
24 | 
25 | // Creates a buffer manager having the specified number
26 | // of buffer slots.
27 | func NewBufferManager(fileManager *file.FileManager, logManager *walog.LogManager, numSlot int) *BufferManager {
28 | 	bufferPool := make([]Buffer, numSlot)
29 | 	for i := 0; i < numSlot; i++ {
30 | 		bufferPool[i] = NewBuffer(fileManager, logManager)
31 | 	}
32 | 
33 | 	mu := new(sync.Mutex)
34 | 	cond := sync.NewCond(mu)
35 | 	return &BufferManager{
36 | 		bufferPool:   bufferPool,
37 | 		numAvailable: int64(numSlot),
38 | 		mu:           mu,
39 | 		cond:         cond,
40 | 	}
41 | }
42 | 
43 | // Returns the number of available (i.e. unpinned) buffers.
44 | func (bm *BufferManager) Available() int64 {
45 | 	bm.mu.Lock()
46 | 	defer bm.mu.Unlock()
47 | 	return bm.numAvailable
48 | }
49 | 
50 | // FlushAll flushes the dirty buffers modified by the specified transaction.
51 | // It holds the manager lock so it cannot race with Pin reassigning a buffer.
52 | func (bm *BufferManager) FlushAll(txNum int64) error {
53 | 	bm.mu.Lock()
54 | 	defer bm.mu.Unlock()
55 | 	// Index into the pool instead of ranging by value: flushing a copy would
56 | 	// leave the pooled buffer marked dirty, so it would be written again.
57 | 	for i := range bm.bufferPool {
58 | 		buff := &bm.bufferPool[i]
59 | 		if buff.ModifyingTx() != txNum {
60 | 			continue
61 | 		}
62 | 		if err := buff.flush(); err != nil {
63 | 			return err
64 | 		}
65 | 	}
66 | 	return nil
67 | }
68 | 
69 | // Unpins the specified data buffer. If its pin count
70 | // goes to zero, then notify any waiting threads.
71 | func (bm *BufferManager) Unpin(buff *Buffer) {
72 | 	bm.mu.Lock()
73 | 	defer bm.mu.Unlock()
74 | 	buff.Unpin()
75 | 	// is this slot (buffer) free to be used by another blockId?
76 | 	if !buff.IsPinned() {
77 | 		bm.numAvailable += 1
78 | 		bm.cond.Broadcast()
79 | 	}
80 | }
81 | 
82 | // Pins a buffer to the specified block, potentially
83 | // waiting until a slot (buffer) becomes available.
84 | // If no slot becomes available within MAX_WAIT_TIME,
85 | // an error is returned. An error is also returned if
86 | // the block cannot be loaded into the chosen buffer.
87 | func (bm *BufferManager) Pin(blockId file.BlockId) (*Buffer, error) {
88 | 	bm.mu.Lock()
89 | 	defer bm.mu.Unlock()
90 | 
91 | 	start := time.Now()
92 | 	buff, err := bm.tryToPin(blockId)
93 | 	for err == nil && buff == nil && !bm.waitedTooLong(start) {
94 | 		utils.WaitCondWithTimeout(bm.cond, MAX_WAIT_TIME)
95 | 		buff, err = bm.tryToPin(blockId)
96 | 	}
97 | 
98 | 	if err != nil {
99 | 		return nil, err
100 | 	}
101 | 	if buff == nil {
102 | 		return nil, fmt.Errorf("buffer request waited for too long")
103 | 	}
104 | 	return buff, nil
105 | }
106 | 
107 | // waitedTooLong reports whether more than MAX_WAIT_TIME has elapsed since start.
108 | func (bm *BufferManager) waitedTooLong(start time.Time) bool {
109 | 	return time.Since(start) > MAX_WAIT_TIME
110 | }
111 | 
112 | // Tries to pin a buffer to the specified block.
113 | // If there is already a buffer assigned to that block
114 | // then that buffer is used;
115 | // otherwise, an unpinned buffer from the pool is chosen
116 | // and the block is read into it.
117 | // Returns (nil, nil) if there are no available buffers, and
118 | // an error if the block could not be assigned to the chosen buffer.
119 | func (bm *BufferManager) tryToPin(blockId file.BlockId) (*Buffer, error) {
120 | 	buff := bm.findExistingBuffer(blockId)
121 | 	if buff == nil {
122 | 		buff = bm.chooseUnpinnedBuffer()
123 | 		if buff == nil {
124 | 			return nil, nil
125 | 		}
126 | 		if err := buff.AssignToBlock(blockId); err != nil {
127 | 			return nil, err
128 | 		}
129 | 	}
130 | 
131 | 	if !buff.IsPinned() {
132 | 		bm.numAvailable -= 1
133 | 	}
134 | 	buff.Pin()
135 | 	return buff, nil
136 | }
137 | 
138 | func (bm *BufferManager) findExistingBuffer(blockId file.BlockId) *Buffer {
139 | 	for i := 0; i < len(bm.bufferPool); i++ {
140 | 		buff := &bm.bufferPool[i]
141 | 		if blockId.Equals(buff.blockId) {
142 | 			return buff
143 | 		}
144 | 	}
145 | 	return nil
146 | }
147 | 
148 | func (bm *BufferManager) chooseUnpinnedBuffer() *Buffer {
149 | 	for i := 0; i < len(bm.bufferPool); i++ {
150 | 		buff := &bm.bufferPool[i]
151 | 		if !buff.IsPinned() {
152 | 			return buff
153 | 		}
154 | 	}
155 | 	return nil
156 | }
157 | 

--------------------------------------------------------------------------------
/internal/buffer/buffer_manager_test.go:
--------------------------------------------------------------------------------
1 | package buffer_test
2 | 
3 | import (
4 | 	"os"
5 | 	"testing"
6 | 
7 | 	"github.com/evanxg852000/simpledb/internal/buffer"
8 | 	"github.com/evanxg852000/simpledb/internal/file"
9 | 	walog "github.com/evanxg852000/simpledb/internal/log"
10 | 	"github.com/stretchr/testify/assert"
11 | )
12 | 
13 | func TestBufferManager(t *testing.T) {
14 | 	assert := assert.New(t)
15 | 
16 | 	dbDirectory, err := os.MkdirTemp("", "test_buffer_manager_")
17 | 	assert.Nil(err)
18 | 	defer os.RemoveAll(dbDirectory)
19 | 
20 | 	fm, err := file.NewFileManager(dbDirectory, 512)
21 | 	assert.Nil(err)
22 | 
23 | 	lm, err := walog.NewLogManager(fm, "log_file")
24 | 	assert.Nil(err)
25 | 
26 | 	bm := buffer.NewBufferManager(fm, lm, 3)
27 | 	buffers := make([]*buffer.Buffer, 6)
28 | 
29 | 	buffers[0], err = bm.Pin(file.NewBlockId("testfile", 0))
30 | 	assert.Nil(err)
31 | 
32 | 	buffers[1], err = bm.Pin(file.NewBlockId("testfile", 1))
33 | 	assert.Nil(err)
34 | 
35 | 	buffers[2], err = bm.Pin(file.NewBlockId("testfile", 2))
36 | 	assert.Nil(err)
37 | 
38 | 	bm.Unpin(buffers[1])
39 | 	buffers[1] = nil
40 | 
41 | 	buffers[3], err = bm.Pin(file.NewBlockId("testfile", 0)) // block 0 pinned twice
42 | 	assert.Nil(err)
43 | 	assert.Equal(int64(1), bm.Available())
44 | 
45 | 	buffers[4], err = bm.Pin(file.NewBlockId("testfile", 1)) // block 1 repinned
46 | 	assert.Nil(err)
47 | 
48 | 	// Attempting to pin block 3 will not work, no buffer left
49 | 	buffers[5], err = bm.Pin(file.NewBlockId("testfile", 3))
50 | 	assert.NotNil(err)
51 | 
52 | 	bm.Unpin(buffers[2])
53 | 	buffers[2] = nil
54 | 	buffers[5], err = bm.Pin(file.NewBlockId("testfile", 3)) // now this works
55 | 	assert.Nil(err)
56 | 
57 | 	//check final buffers state
58 | 	assert.Equal(int64(0), buffers[0].Block().BlockNum)
59 | 	assert.Empty(buffers[1])
60 | 	assert.Empty(buffers[2])
61 | 	assert.Equal(int64(0), buffers[3].Block().BlockNum)
62 | 	assert.Equal(int64(1), buffers[4].Block().BlockNum)
63 | 	assert.Equal(int64(3), buffers[5].Block().BlockNum)
64 | }
65 | 

--------------------------------------------------------------------------------
/internal/file/block_id.go:
--------------------------------------------------------------------------------
1 | package file
2 | 
3 | import "fmt"
4 | 
5 | type BlockId struct {
6 | 	FileName string
7 | 	BlockNum int64
8 | }
9 | 
10 | func NewBlockId(file_name string, block_num int64) BlockId {
11 | 	return BlockId{FileName: file_name, BlockNum: block_num}
12 | }
13 | 
14 | func (blockId *BlockId) Equals(other BlockId) bool {
15 | 	return blockId.BlockNum == other.BlockNum && blockId.FileName == other.FileName
16 | }
17 | 
18 | func (blockId *BlockId) String() string {
19 | 	return fmt.Sprintf("{file: %s, block: %d}", blockId.FileName, blockId.BlockNum)
20 | }
21 | 

--------------------------------------------------------------------------------
/internal/file/byte_buffer.go:
--------------------------------------------------------------------------------
1 | package file
2 | 
3 | import (
4 | 	"bytes"
5 | 	"encoding/binary"
6 | )
7 | 
8 | type ByteBuffer struct {
9 | 	buffer *bytes.Buffer
10 | }
11 | 
12 | func NewByteBuffer() *ByteBuffer {
13 | 	return &ByteBuffer{
14 | 		buffer: bytes.NewBuffer(nil),
15 | 	}
16 | }
17 | 
18 | func (bb *ByteBuffer) WriteInt(value int64) error {
19 | 	return binary.Write(bb.buffer, binary.LittleEndian, value)
20 | }
21 | 
22 | func (bb *ByteBuffer) WriteFloat(value float64) error {
23 | 	return binary.Write(bb.buffer, binary.LittleEndian, value)
24 | }
25 | 
26 | func (bb *ByteBuffer) WriteString(value string) error {
27 | 	return bb.WriteBytes([]byte(value))
28 | }
29 | 
30 | func (bb *ByteBuffer) WriteBytes(value []byte) error {
31 | 	length := uint64(len(value))
32 | 	err := binary.Write(bb.buffer, binary.LittleEndian, length)
33 | 	if err != nil {
34 | 		return err
35 | 	}
36 | 
37 | 	_, err = bb.buffer.Write(value)
38 | 	if err != nil {
39 | 		return err
40 | 	}
41 | 
42 | 	return nil
43 | }
44 | 
45 | func (bb *ByteBuffer) Data() []byte {
46 | 	return bb.buffer.Bytes()
47 | }
48 | 

--------------------------------------------------------------------------------
/internal/file/file_manager.go:
--------------------------------------------------------------------------------
1 | package file
2 | 
3 | import (
4 | 	"errors"
5 | 	"io"
6 | 	"os"
7 | 	"path/filepath"
8 | 	"sync"
9 | )
10 | 
11 | type FileManager struct {
12 | 	directory   string
13 | 	blockSize   int64
14 | 	openedFiles map[string]RandomAccessFile
15 | 	mu          sync.Mutex
16 | 	isNew       bool
17 | }
18 | 
19 | // NewFileManager opens the database directory, creating it when it does not
20 | // exist (in which case IsNew reports true), and opens every regular file
21 | // already present in it.
22 | func NewFileManager(directory string, block_size int64) (*FileManager, error) {
23 | 	isNew := false
24 | 	if _, err := os.Stat(directory); err != nil {
25 | 		if !os.IsNotExist(err) {
26 | 			return nil, err
27 | 		}
28 | 		if err := os.MkdirAll(directory, 0755); err != nil {
29 | 			return nil, err
30 | 		}
31 | 		isNew = true
32 | 	}
33 | 
34 | 	// os.ReadDir closes the directory handle itself, unlike os.Open+Readdir.
35 | 	entries, err := os.ReadDir(directory)
36 | 	if err != nil {
37 | 		return nil, err
38 | 	}
39 | 
40 | 	openedFiles := make(map[string]RandomAccessFile, len(entries))
41 | 	for _, entry := range entries {
42 | 		if entry.IsDir() {
43 | 			continue
44 | 		}
45 | 
46 | 		rafName := entry.Name()
47 | 		raf, err := NewRandomAccessFile(filepath.Join(directory, rafName))
48 | 		if err != nil {
49 | 			return nil, err
50 | 		}
51 | 		openedFiles[rafName] = raf
52 | 	}
53 | 
54 | 	return &FileManager{
55 | 		directory:   directory,
56 | 		blockSize:   block_size,
57 | 		openedFiles: openedFiles,
58 | 		isNew:       isNew,
59 | 	}, nil
60 | }
61 | 
62 | // Read fills page with the content of block. Reading past the end of the
63 | // file is not an error: the bytes that do not exist on disk are zeroed, so
64 | // the page never keeps content left over from a previously read block.
65 | func (fm *FileManager) Read(block BlockId, page *Page) error {
66 | 	fm.mu.Lock()
67 | 	defer fm.mu.Unlock()
68 | 
69 | 	raf, err := fm.getFile(block.FileName)
70 | 	if err != nil {
71 | 		return err
72 | 	}
73 | 
74 | 	data := page.Data()
75 | 	n, err := raf.ReadAt(data, block.BlockNum*fm.blockSize)
76 | 	if errors.Is(err, io.EOF) {
77 | 		clear(data[n:])
78 | 		return nil
79 | 	}
80 | 	return err
81 | }
82 | 
83 | // Write stores page in block and reports any I/O error.
84 | func (fm *FileManager) Write(block BlockId, page *Page) error {
85 | 	fm.mu.Lock()
86 | 	defer fm.mu.Unlock()
87 | 
88 | 	raf, err := fm.getFile(block.FileName)
89 | 	if err != nil {
90 | 		return err
91 | 	}
92 | 
93 | 	_, err = raf.WriteAt(page.Data(), block.BlockNum*fm.blockSize)
94 | 	return err
95 | }
96 | 
97 | // Append extends fileName by one zero-filled block and returns its id. The
98 | // block number is computed and the block written under a single lock, so
99 | // concurrent appends to the same file cannot claim the same block.
100 | func (fm *FileManager) Append(fileName string) (BlockId, error) {
101 | 	fm.mu.Lock()
102 | 	defer fm.mu.Unlock()
103 | 
104 | 	raf, err := fm.getFile(fileName)
105 | 	if err != nil {
106 | 		return BlockId{}, err
107 | 	}
108 | 
109 | 	size, err := raf.Size()
110 | 	if err != nil {
111 | 		return BlockId{}, err
112 | 	}
113 | 
114 | 	block := NewBlockId(fileName, size/fm.blockSize)
115 | 	data := make([]byte, fm.blockSize)
116 | 	if _, err := raf.WriteAt(data, block.BlockNum*fm.blockSize); err != nil {
117 | 		return BlockId{}, err
118 | 	}
119 | 	return block, nil
120 | }
121 | 
122 | // BlockCount returns the number of whole blocks in fileName, creating the
123 | // file if it does not exist yet.
124 | func (fm *FileManager) BlockCount(fileName string) (int64, error) {
125 | 	fm.mu.Lock()
126 | 	defer fm.mu.Unlock()
127 | 
128 | 	raf, err := fm.getFile(fileName)
129 | 	if err != nil {
130 | 		return 0, err
131 | 	}
132 | 
133 | 	size, err := raf.Size()
134 | 	if err != nil {
135 | 		return 0, err
136 | 	}
137 | 
138 | 	return (size / fm.blockSize), nil
139 | }
140 | 
141 | func (fm *FileManager) BlockSize() int64 {
142 | 	return fm.blockSize
143 | }
144 | 
145 | // getFile returns the open handle for fileName, opening (and creating) the
146 | // file on first use. It mutates openedFiles, so callers must hold fm.mu.
147 | func (fm *FileManager) getFile(fileName string) (RandomAccessFile, error) {
148 | 	raf, exists := fm.openedFiles[fileName]
149 | 	if exists {
150 | 		return raf, nil
151 | 	}
152 | 
153 | 	rafPath := filepath.Join(fm.directory, fileName)
154 | 	raf, err := NewRandomAccessFile(rafPath)
155 | 	if err != nil {
156 | 		return RandomAccessFile{}, err
157 | 	}
158 | 
159 | 	fm.openedFiles[fileName] = raf
160 | 	return raf, nil
161 | }
162 | 
163 | func (fm *FileManager) IsNew() bool {
164 | 	return fm.isNew
165 | }
166 | 

--------------------------------------------------------------------------------
/internal/file/file_manager_test.go:
--------------------------------------------------------------------------------
1 | package file_test
2 | 
3 | import (
4 | 	"os"
5 | 	"testing"
6 | 
7 | 	"github.com/evanxg852000/simpledb/internal/file"
8 | 	"github.com/stretchr/testify/assert"
9 | )
10 | 
11 | func TestFileManager(t *testing.T) {
12 | 	assert := assert.New(t)
13 | 
14 | 	dbDirectory, err := os.MkdirTemp("", "test_file_manager_")
15 | 	assert.Nil(err)
16 | 	defer os.RemoveAll(dbDirectory)
17 | 
18 | 	fm, err := file.NewFileManager(dbDirectory, 800)
19 | 	assert.Nil(err)
20 | 
21 | 	blockId := file.NewBlockId("testfile", 2)
22 | 	pos1 := int64(88)
23 | 
24 | 	p1 := file.NewPage(fm.BlockSize())
25 | 	n, err := p1.WriteString(pos1, "abcdefghijklm")
26 | 	assert.Nil(err)
27 | 	assert.Equal(file.GetEncodingLength(13), n) // 8-byte length prefix + 13 bytes
28 | 	pos2 := pos1 + n
29 | 	p1.WriteInt(pos2, -345)
30 | 	pos3 := pos2 + 8
31 | 	p1.WriteFloat(pos3, -3.14)
32 | 	fm.Write(blockId, &p1)
33 | 
34 | 	p2 := file.NewPage(fm.BlockSize())
35 | 	fm.Read(blockId, &p2)
36 | 
37 | 	s, err := p2.ReadString(pos1)
38 | 	assert.Nil(err)
39 | 	assert.Equal("abcdefghijklm", s)
40 | 
41 | 	v, err := p2.ReadInt(pos2)
42 | 	assert.Nil(err)
43 | 	assert.Equal(int64(-345), v)
44 | 
45 | 	f, err := p2.ReadFloat(pos3)
46 | 	assert.Nil(err)
47 | 	assert.Equal(float64(-3.14), f)
48 | }
49 | 

--------------------------------------------------------------------------------
/internal/file/page.go:
--------------------------------------------------------------------------------
1 | package file
2 | 
3 | import (
4 | 	"encoding/binary"
5 | 	"fmt"
6 | 	"io"
7 | 	"math"
8 | )
9 | 
10 | // Page is a byte slice with helpers that encode and decode little-endian
11 | // int64/float64 values and length-prefixed byte strings.
12 | type Page struct {
13 | 	data []byte
14 | }
15 | 
16 | // For creating data buffers
17 | func NewPage(blockSize int64) Page {
18 | 	return Page{data: make([]byte, blockSize)}
19 | }
20 | 
21 | // For creating log pages
22 | func NewPageWithData(data []byte) Page {
23 | 	return Page{data: data}
24 | }
25 | 
26 | // readSpan returns the size bytes starting at offset. Mirroring
27 | // encoding/binary.Read, it reports io.EOF when no bytes remain at offset and
28 | // io.ErrUnexpectedEOF when only some of them do; an offset outside the page
29 | // is an error rather than a panic.
30 | func (page *Page) readSpan(offset, size int64) ([]byte, error) {
31 | 	if offset < 0 || offset > int64(len(page.data)) {
32 | 		return nil, fmt.Errorf("offset %d out of page bounds", offset)
33 | 	}
34 | 	switch remaining := int64(len(page.data)) - offset; {
35 | 	case remaining == 0:
36 | 		return nil, io.EOF
37 | 	case remaining < size:
38 | 		return nil, io.ErrUnexpectedEOF
39 | 	}
40 | 	return page.data[offset : offset+size], nil
41 | }
42 | 
43 | // writeSpan returns the size bytes starting at offset as an encoding
44 | // destination, or an error when they do not fit inside the page.
45 | func (page *Page) writeSpan(offset, size int64) ([]byte, error) {
46 | 	if offset < 0 || size > int64(len(page.data))-offset {
47 | 		return nil, fmt.Errorf("not enough space to encode data")
48 | 	}
49 | 	return page.data[offset : offset+size], nil
50 | }
51 | 
52 | func (page *Page) ReadInt(offset int64) (int64, error) {
53 | 	buf, err := page.readSpan(offset, 8)
54 | 	if err != nil {
55 | 		return 0, err
56 | 	}
57 | 	return int64(binary.LittleEndian.Uint64(buf)), nil
58 | }
59 | 
60 | func (page *Page) WriteInt(offset int64, value int64) error {
61 | 	buf, err := page.writeSpan(offset, 8)
62 | 	if err != nil {
63 | 		return err
64 | 	}
65 | 	binary.LittleEndian.PutUint64(buf, uint64(value))
66 | 	return nil
67 | }
68 | 
69 | func (page *Page) ReadFloat(offset int64) (float64, error) {
70 | 	buf, err := page.readSpan(offset, 8)
71 | 	if err != nil {
72 | 		return 0, err
73 | 	}
74 | 	return math.Float64frombits(binary.LittleEndian.Uint64(buf)), nil
75 | }
76 | 
77 | func (page *Page) WriteFloat(offset int64, value float64) error {
78 | 	buf, err := page.writeSpan(offset, 8)
79 | 	if err != nil {
80 | 		return err
81 | 	}
82 | 	binary.LittleEndian.PutUint64(buf, math.Float64bits(value))
83 | 	return nil
84 | }
85 | 
86 | func (page *Page) ReadString(offset int64) (string, error) {
87 | 	data, err := page.ReadBytes(offset)
88 | 	if err != nil {
89 | 		return "", err
90 | 	}
91 | 
92 | 	return string(data), nil
93 | }
94 | 
95 | // WriteString stores value with the same layout as WriteBytes and returns
96 | // the number of bytes written.
97 | func (page *Page) WriteString(offset int64, value string) (int64, error) {
98 | 	return page.WriteBytes(offset, []byte(value))
99 | }
100 | 
101 | // ReadBytes decodes a byte string written by WriteBytes: an 8-byte
102 | // little-endian length followed by that many bytes. The result is a copy.
103 | func (page *Page) ReadBytes(offset int64) ([]byte, error) {
104 | 	header, err := page.readSpan(offset, 8)
105 | 	if err != nil {
106 | 		return []byte{}, err
107 | 	}
108 | 	length := binary.LittleEndian.Uint64(header)
109 | 
110 | 	start := offset + 8
111 | 	// Validate before allocating so a corrupt length cannot trigger a huge
112 | 	// allocation; the comparison is done in uint64 to avoid overflow.
113 | 	if length > uint64(int64(len(page.data))-start) {
114 | 		return []byte{}, fmt.Errorf("early EOF")
115 | 	}
116 | 
117 | 	data := make([]byte, length)
118 | 	copy(data, page.data[start:])
119 | 	return data, nil
120 | }
121 | 
122 | // WriteBytes stores value at offset as an 8-byte little-endian length
123 | // followed by the bytes themselves and returns the number of bytes written,
124 | // i.e. GetEncodingLength(len(value)). If the encoding does not fit, nothing
125 | // is written and 0 is returned with the error.
126 | func (page *Page) WriteBytes(offset int64, value []byte) (int64, error) {
127 | 	size := GetEncodingLength(int64(len(value)))
128 | 	buf, err := page.writeSpan(offset, size)
129 | 	if err != nil {
130 | 		return 0, err
131 | 	}
132 | 	binary.LittleEndian.PutUint64(buf, uint64(len(value)))
133 | 	copy(buf[8:], value)
134 | 	return size, nil
135 | }
136 | 
137 | func (page *Page) Data() []byte {
138 | 	return page.data
139 | }
140 | 
141 | // GetEncodingLength returns the number of bytes WriteBytes uses for a value
142 | // of n bytes: an 8-byte length prefix plus the payload.
143 | func GetEncodingLength(n int64) int64 {
144 | 	return 8 + n
145 | }
146 | 

--------------------------------------------------------------------------------
/internal/file/random_access_file.go:
--------------------------------------------------------------------------------
1 | package file
2 | 
3 | import (
4 | 	"os"
5 | 	"sync"
6 | )
7 | 
8 | type RandomAccessFile struct {
9 | 	file *os.File
10 | 	mu   *sync.Mutex
11 | }
12 | 
13 | func NewRandomAccessFile(name string) (RandomAccessFile, error) {
14 | 	file, err := os.OpenFile(name, os.O_RDWR|os.O_CREATE, 0644)
15 | 	if err != nil {
16 | 		return RandomAccessFile{}, err
17 | 	}
18 | 
19 | 	return RandomAccessFile{file: file, mu: new(sync.Mutex)}, nil
20 | }
21 | 
22 | func (raf *RandomAccessFile) WriteAt(data []byte, offset int64) (int, error) {
23 | 	raf.mu.Lock()
24 | 	defer raf.mu.Unlock()
25 | 
26 | 	n, err := raf.file.WriteAt(data, offset)
27 | 	if err != nil {
28 | 		return 0, err
29 | 	}
30 | 
31 | 	err = raf.file.Sync()
32 | 	return n, err
33 | }
34 | 
35 | func (raf *RandomAccessFile) ReadAt(data []byte, offset int64) (int, error) {
36 | 	raf.mu.Lock()
37 | 	defer raf.mu.Unlock()
38 | 	return raf.file.ReadAt(data, offset)
39 | }
40 | 
41 | func (raf *RandomAccessFile) Size() (int64, error) {
42 | 	raf.mu.Lock()
43 | 	defer raf.mu.Unlock()
44 | 	info, err := raf.file.Stat()
45 | 	if err != nil {
46 | 		return 0, err
47
| } 48 | 49 | return info.Size(), nil 50 | } 51 | 52 | func (raf *RandomAccessFile) Close() { 53 | raf.mu.Lock() 54 | defer raf.mu.Unlock() 55 | raf.file.Close() 56 | } 57 | 58 | // func (raf *RandomAccessFile) Seek(offset int64) (int64, error) { 59 | // return raf.file.Seek(offset, io.SeekStart) 60 | // } 61 | -------------------------------------------------------------------------------- /internal/index/hash/hash_index.go: -------------------------------------------------------------------------------- 1 | package hash 2 | 3 | import ( 4 | "github.com/evanxg852000/simpledb/internal/query" 5 | "github.com/evanxg852000/simpledb/internal/record" 6 | "github.com/evanxg852000/simpledb/internal/tx/recovery" 7 | ) 8 | 9 | const ( 10 | NUM_BUCKETS = 100 11 | ) 12 | 13 | // A static hash implementation of the Index interface. 14 | // A fixed number of buckets is allocated (currently, 100), 15 | // and each bucket is implemented as a file of index records. 16 | type HashIndex struct { 17 | } 18 | 19 | func NewHashIndex(tx *recovery.Transaction, indexName string, layout *record.Layout) *HashIndex { 20 | //TODO: 21 | return nil 22 | } 23 | 24 | func (hi *HashIndex) BeforeFirst(searchKey query.Constant) { 25 | 26 | } 27 | 28 | func (hi *HashIndex) Next() bool { 29 | //TODO: 30 | return false 31 | } 32 | 33 | func (hi *HashIndex) GetDataRID() query.RID { 34 | //TODO: 35 | return query.NewRID(0, 1) 36 | } 37 | 38 | func (hi *HashIndex) Insert(dataVal query.Constant, id query.RID) { 39 | //TODO: 40 | } 41 | 42 | func (hi *HashIndex) Delete(dataVal query.Constant, id query.RID) { 43 | //TODO: 44 | } 45 | 46 | func (hi *HashIndex) Close() { 47 | //TODO: 48 | } 49 | 50 | func SearchCost(numBlocks, recordPerBlock int64) int64 { 51 | return numBlocks / NUM_BUCKETS 52 | } 53 | -------------------------------------------------------------------------------- /internal/index/index.go: -------------------------------------------------------------------------------- 1 | package index 2 | 
3 | import "github.com/evanxg852000/simpledb/internal/query" 4 | 5 | type Index interface { 6 | // Positions the index before the first record 7 | // having the specified search key. 8 | BeforeFirst(searchKey query.Constant) 9 | 10 | // Moves the index to the next record having the 11 | // search key specified in the beforeFirst method. 12 | // Returns false if there are no more such index records. 13 | Next() bool 14 | 15 | // Returns the dataRID value stored in the current index record. 16 | GetDataRID() query.RID 17 | 18 | // Inserts an index record having the specified 19 | // dataval and dataRID values. 20 | Insert(dataVal query.Constant, id query.RID) 21 | 22 | // Deletes the index record having the specified 23 | // dataval and dataRID values. 24 | Delete(dataVal query.Constant, id query.RID) 25 | 26 | // Closes the index. 27 | Close() 28 | } 29 | -------------------------------------------------------------------------------- /internal/log/log_iterator.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import "github.com/evanxg852000/simpledb/internal/file" 4 | 5 | type LogIterator struct { 6 | fileManager *file.FileManager 7 | blockId file.BlockId 8 | currentPage file.Page 9 | currentOffset int64 10 | dataSpaceStart int64 11 | } 12 | 13 | func NewLogIterator(fm *file.FileManager, blockId file.BlockId) (LogIterator, error) { 14 | logIterator := LogIterator{ 15 | fileManager: fm, 16 | blockId: blockId, 17 | currentPage: file.NewPageWithData(make([]byte, fm.BlockSize())), 18 | currentOffset: 0, 19 | dataSpaceStart: 0, 20 | } 21 | 22 | err := logIterator.moveToBlock(blockId) 23 | if err != nil { 24 | return LogIterator{}, err 25 | } 26 | return logIterator, nil 27 | } 28 | 29 | func (logIter *LogIterator) HasNext() bool { 30 | return logIter.currentOffset < logIter.fileManager.BlockSize() || logIter.blockId.BlockNum > 0 31 | } 32 | 33 | func (logIter *LogIterator) Next() ([]byte, error) { 34 | if 
logIter.currentOffset == logIter.fileManager.BlockSize() { 35 | logIter.blockId = file.NewBlockId(logIter.blockId.FileName, logIter.blockId.BlockNum-1) 36 | logIter.moveToBlock(logIter.blockId) 37 | } 38 | 39 | data, err := logIter.currentPage.ReadBytes(logIter.currentOffset) 40 | if err != nil { 41 | return []byte{}, err 42 | } 43 | 44 | logIter.currentOffset = logIter.currentOffset + 8 + int64(len(data)) 45 | return data, nil 46 | } 47 | 48 | // Moves a page of the file specified by blockId 49 | // and positions the cursor at the first record in that block 50 | func (logIter *LogIterator) moveToBlock(blockId file.BlockId) error { 51 | logIter.fileManager.Read(blockId, &logIter.currentPage) 52 | dataSpaceStart, err := logIter.currentPage.ReadInt(0) 53 | if err != nil { 54 | return err 55 | } 56 | logIter.dataSpaceStart = dataSpaceStart 57 | logIter.currentOffset = dataSpaceStart 58 | return nil 59 | } 60 | -------------------------------------------------------------------------------- /internal/log/log_manager.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/evanxg852000/simpledb/internal/file" 7 | ) 8 | 9 | type LogManager struct { 10 | logFile string 11 | fileManager *file.FileManager 12 | logPage file.Page 13 | currentBlock file.BlockId 14 | latestLSN int64 15 | lastSavedLSN int64 16 | mu *sync.Mutex 17 | } 18 | 19 | func NewLogManager(fm *file.FileManager, logFile string) (*LogManager, error) { 20 | logManager := &LogManager{ 21 | logFile: logFile, 22 | fileManager: fm, 23 | logPage: file.NewPageWithData(make([]byte, fm.BlockSize())), 24 | currentBlock: file.BlockId{}, 25 | latestLSN: 0, 26 | lastSavedLSN: 0, 27 | mu: new(sync.Mutex), 28 | } 29 | 30 | numBlock, err := fm.BlockCount(logFile) 31 | if err != nil { 32 | return logManager, err 33 | } 34 | 35 | if numBlock == 0 { 36 | logManager.currentBlock, err = logManager.appendNewBlock() 37 | if err != nil { 38 
| return logManager, err 39 | } 40 | return logManager, nil 41 | } 42 | 43 | logManager.currentBlock = file.NewBlockId(logFile, numBlock-1) 44 | fm.Read(logManager.currentBlock, &logManager.logPage) 45 | return logManager, nil 46 | } 47 | 48 | // Ensures that the log record corresponding to the 49 | // specified LSN has been written to disk. 50 | // All earlier log records will also be written to disk. 51 | func (lm *LogManager) Flush(lsn int64) error { 52 | if lsn >= lm.lastSavedLSN { 53 | return lm.flushFile() 54 | } 55 | return nil 56 | } 57 | 58 | func (lm *LogManager) Iterator() (LogIterator, error) { 59 | err := lm.flushFile() 60 | if err != nil { 61 | return LogIterator{}, err 62 | } 63 | return NewLogIterator(lm.fileManager, lm.currentBlock) 64 | } 65 | 66 | // Appends a log record to the log buffer. 67 | // The record consists of an arbitrary array of bytes. 68 | // Log records are written right to left in the buffer. 69 | // The size of the record is written before the bytes. 70 | // The beginning of the buffer contains the location 71 | // of the last-written record (the "boundary"). 72 | // Storing the records backwards makes it easy to read them in reverse order. 73 | func (lm *LogManager) Append(data []byte) (int64, error) { 74 | lm.mu.Lock() 75 | defer lm.mu.Unlock() 76 | 77 | dataSpaceStart, err := lm.logPage.ReadInt(0) 78 | if err != nil { 79 | return 0, err 80 | } 81 | bytesNeeded := int64(8 + len(data)) 82 | 83 | // does the log record fit? 
84 | if dataSpaceStart-bytesNeeded < 8 { 85 | lm.flushFile() 86 | lm.currentBlock, err = lm.appendNewBlock() 87 | if err != nil { 88 | return 0, err 89 | } 90 | 91 | dataSpaceStart, err = lm.logPage.ReadInt(0) 92 | if err != nil { 93 | return 0, err 94 | } 95 | } 96 | 97 | offset := dataSpaceStart - bytesNeeded 98 | _, err = lm.logPage.WriteBytes(offset, data) 99 | if err != nil { 100 | return 0, err 101 | } 102 | err = lm.logPage.WriteInt(0, offset) // update the new data boundary 103 | if err != nil { 104 | return 0, err 105 | } 106 | 107 | lm.latestLSN += 1 108 | return lm.latestLSN, nil 109 | } 110 | 111 | // Appends a new page to the log file 112 | func (lm *LogManager) appendNewBlock() (file.BlockId, error) { 113 | blockId, err := lm.fileManager.Append(lm.logFile) 114 | if err != nil { 115 | return file.BlockId{}, err 116 | } 117 | 118 | // Store the valid data start offset at the start of the page. 119 | // Note that record are insert in log page from the end of the page. 120 | err = lm.logPage.WriteInt(0, lm.fileManager.BlockSize()) 121 | if err != nil { 122 | return file.BlockId{}, err 123 | } 124 | 125 | err = lm.fileManager.Write(blockId, &lm.logPage) 126 | return blockId, err 127 | } 128 | 129 | // flushFile flushes syncs the current page to the file 130 | func (lm *LogManager) flushFile() error { 131 | return lm.fileManager.Write(lm.currentBlock, &lm.logPage) 132 | } 133 | -------------------------------------------------------------------------------- /internal/log/log_manager_test.go: -------------------------------------------------------------------------------- 1 | package log_test 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/evanxg852000/simpledb/internal/file" 10 | walog "github.com/evanxg852000/simpledb/internal/log" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestLogManager(t *testing.T) { 15 | assert := assert.New(t) 16 | 17 | dbDirectory, err := os.MkdirTemp("", "test_log_manager_") 
18 | assert.Nil(err) 19 | defer os.RemoveAll(dbDirectory) 20 | 21 | fm, err := file.NewFileManager(dbDirectory, 512) 22 | assert.Nil(err) 23 | 24 | logManager, err := walog.NewLogManager(fm, "log_file") 25 | assert.Nil(err) 26 | 27 | // create records 28 | expectedItems := []string{} 29 | suffix := strings.Repeat("X", 10) 30 | for i := 1; i <= 1500; i++ { 31 | item := fmt.Sprintf("LOG_%d_%s", i, suffix) 32 | expectedItems = append(expectedItems, item) 33 | lsn, err := logManager.Append([]byte(item)) 34 | assert.Nil(err) 35 | if i%150 == 0 { 36 | err = logManager.Flush(lsn) 37 | assert.Nil(err) 38 | } 39 | } 40 | 41 | // iterate 42 | logIterator, err := logManager.Iterator() 43 | assert.Nil(err) 44 | index := 1500 45 | for logIterator.HasNext() { 46 | index = index - 1 47 | foundItem, err := logIterator.Next() 48 | assert.Nil(err) 49 | assert.Equal(expectedItems[index], string(foundItem)) 50 | } 51 | 52 | } 53 | -------------------------------------------------------------------------------- /internal/metadata/catalog_test.go: -------------------------------------------------------------------------------- 1 | package metadata_test 2 | 3 | import ( 4 | "os" 5 | "path" 6 | "testing" 7 | 8 | "github.com/evanxg852000/simpledb/internal/record" 9 | "github.com/evanxg852000/simpledb/internal/server" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestCatalog(t *testing.T) { 14 | assert := assert.New(t) 15 | workspaceDir, err := os.MkdirTemp("", "test_catalog") 16 | assert.Nil(err) 17 | dbDir := path.Join(workspaceDir, "db") 18 | defer os.RemoveAll(workspaceDir) 19 | 20 | db := server.NewSimpleDB(dbDir, 400, 8) 21 | tx := db.NewTx() 22 | 23 | tblManager := db.MetadataManager().GetTableManager() 24 | // Get all tables and their slot size 25 | layout, err := tblManager.GetLayout("table_catalog", tx) 26 | assert.Nil(err) 27 | tblScan, err := record.NewTableScan(tx, "table_catalog", layout) 28 | assert.Nil(err) 29 | rows := []struct { 30 | tblName string 31 | 
slotSize int64 32 | }{} 33 | for tblScan.Next() { 34 | tblName := tblScan.GetString("table_name") 35 | slotSize := tblScan.GetInt("slot_size") 36 | rows = append(rows, struct { 37 | tblName string 38 | slotSize int64 39 | }{tblName, slotSize}) 40 | } 41 | assert.Equal([]struct { 42 | tblName string 43 | slotSize int64 44 | }{ 45 | {"table_catalog", 56}, 46 | {"field_catalog", 112}, 47 | {"view_catalog", 156}, 48 | {"index_catalog", 128}, 49 | }, rows) 50 | 51 | tblScan.Close() 52 | 53 | // Get all fields for each table and their offsets 54 | layout, err = tblManager.GetLayout("field_catalog", tx) 55 | assert.Nil(err) 56 | tblScan, err = record.NewTableScan(tx, "field_catalog", layout) 57 | assert.Nil(err) 58 | rows2 := []struct { 59 | tblName string 60 | fldName string 61 | offset int64 62 | }{} 63 | for tblScan.Next() { 64 | tblName := tblScan.GetString("table_name") 65 | fldName := tblScan.GetString("field_name") 66 | offset := tblScan.GetInt("offset") 67 | rows2 = append(rows2, struct { 68 | tblName string 69 | fldName string 70 | offset int64 71 | }{tblName, fldName, offset}) 72 | } 73 | assert.Equal(12, len(rows2)) 74 | tblScan.Close() 75 | } 76 | -------------------------------------------------------------------------------- /internal/metadata/index_info.go: -------------------------------------------------------------------------------- 1 | package metadata 2 | 3 | import ( 4 | "github.com/evanxg852000/simpledb/internal/index" 5 | "github.com/evanxg852000/simpledb/internal/index/hash" 6 | "github.com/evanxg852000/simpledb/internal/record" 7 | "github.com/evanxg852000/simpledb/internal/tx/recovery" 8 | ) 9 | 10 | type IndexInfo struct { 11 | IndexName string 12 | FieldName string 13 | tx *recovery.Transaction 14 | schema *record.Schema 15 | layout *record.Layout 16 | statsInfo StatInfo 17 | } 18 | 19 | // Create an IndexInfo object for the specified index. 
20 | func NewIndexInfo(indexName string, fieldName string, 21 | schema *record.Schema, tx *recovery.Transaction, si StatInfo) *IndexInfo { 22 | fldType := schema.FieldType(fieldName) 23 | fldLength := schema.FieldLength(fieldName) 24 | 25 | layout := createIndexLayout(fldType, fldLength) 26 | return &IndexInfo{ 27 | indexName, 28 | fieldName, 29 | tx, 30 | schema, 31 | layout, 32 | si, 33 | } 34 | } 35 | 36 | // Open the index described by this object. 37 | func (ii *IndexInfo) Open() index.Index { 38 | // return new BTreeIndex(ii.tx, ii.indexName, ii.layout) 39 | return hash.NewHashIndex(ii.tx, ii.IndexName, ii.layout) 40 | } 41 | 42 | // Estimate the number of block accesses required to 43 | // find all index records having a particular search key. 44 | // The method uses the table's metadata to estimate the 45 | // size of the index file and the number of index records 46 | // per block. 47 | // It then passes this information to the traversalCost 48 | // method of the appropriate index type, 49 | // which provides the estimate. 50 | func (ii *IndexInfo) BlockAccessed() int64 { 51 | recordPerBlock := ii.tx.BlockSize() / ii.layout.SlotSize() 52 | numBlocks := ii.statsInfo.RecordsOutput() / recordPerBlock 53 | // return btree.SearchCost(numBlocks, recordPerBlock) 54 | return hash.SearchCost(numBlocks, recordPerBlock) 55 | } 56 | 57 | // Return the estimated number of records having a 58 | // search key. This value is the same as doing a select 59 | // query; that is, it is the number of records in the table 60 | // divided by the number of distinct values of the indexed field. 61 | func (ii *IndexInfo) RecordsOutput() int64 { 62 | return ii.statsInfo.RecordsOutput() / ii.statsInfo.DistinctValues(ii.FieldName) 63 | } 64 | 65 | // Return the distinct values for a specified field 66 | // in the underlying table, or 1 for the indexed field. 
67 | func (ii IndexInfo) DistinctValues(fieldName string) int64 { 68 | if ii.FieldName == fieldName { 69 | return 1 70 | } 71 | return ii.statsInfo.DistinctValues(ii.FieldName) 72 | } 73 | 74 | // Return the layout of the index records. 75 | // The schema consists of the dataRID (which is 76 | // represented as two integers, the block number and the 77 | // record ID) and the dataval (which is the indexed field). 78 | // Schema information about the indexed field is obtained 79 | // via the table's schema. 80 | func createIndexLayout(fldType, fldLength int64) *record.Layout { 81 | schema := record.NewSchema() 82 | schema.AddIntField("block") 83 | schema.AddIntField("id") 84 | if int(fldType) == record.INTEGER_TYPE { 85 | schema.AddIntField("data_val") 86 | } else { 87 | schema.AddStringField("data_val", fldLength) 88 | } 89 | return record.NewLayout(schema) 90 | } 91 | -------------------------------------------------------------------------------- /internal/metadata/index_manager.go: -------------------------------------------------------------------------------- 1 | package metadata 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/evanxg852000/simpledb/internal/record" 7 | "github.com/evanxg852000/simpledb/internal/tx/recovery" 8 | ) 9 | 10 | // The index manager. 11 | // The index manager has similar functionality to the table manager. 12 | type IndexManager struct { 13 | layout *record.Layout 14 | tableManager *TableManager 15 | statsManager *StatsManager 16 | } 17 | 18 | // Create the index manager. 19 | // This constructor is called during system startup. 20 | // If the database is new, then the index catalog table is created. 
21 | func NewIndexManager(isNew bool, tableManager *TableManager, statsManager *StatsManager, tx *recovery.Transaction) *IndexManager { 22 | if isNew { 23 | schema := record.NewSchema() 24 | schema.AddStringField("index_name", MAX_NAME_LENGTH) 25 | schema.AddStringField("table_name", MAX_NAME_LENGTH) 26 | schema.AddStringField("field_name", MAX_NAME_LENGTH) 27 | tableManager.CreateTable("index_catalog", schema, tx) 28 | } 29 | 30 | layout, err := tableManager.GetLayout("index_catalog", tx) 31 | if err != nil { 32 | fmt.Println("err: ", err) 33 | } 34 | 35 | return &IndexManager{layout, tableManager, statsManager} 36 | } 37 | 38 | // Create an index of the specified type for the specified field. 39 | // A unique ID is assigned to this index, and its information 40 | // is stored in the idxcat table. 41 | func (indexManager *IndexManager) CreateIndex(idxName, tblName, fldName string, tx *recovery.Transaction) error { 42 | tableScan, err := record.NewTableScan(tx, "index_catalog", indexManager.layout) 43 | if err != nil { 44 | return err 45 | } 46 | tableScan.Insert() 47 | tableScan.SetString("index_name", idxName) 48 | tableScan.SetString("table_name", tblName) 49 | tableScan.SetString("field_name", fldName) 50 | tableScan.Close() 51 | return nil 52 | } 53 | 54 | // Return a map containing the index info for all indexes 55 | // on the specified table. 
56 | func (indexManager *IndexManager) GetIndexInfo(tblName string, tx *recovery.Transaction) (map[string]IndexInfo, error) { 57 | result := map[string]IndexInfo{} 58 | tableScan, err := record.NewTableScan(tx, "index_catalog", indexManager.layout) 59 | if err != nil { 60 | return result, err 61 | } 62 | 63 | for tableScan.Next() { 64 | storedTableName := tableScan.GetString("table_name") 65 | if storedTableName == tblName { 66 | idxName := tableScan.GetString("index_name") 67 | fldName := tableScan.GetString("field_name") 68 | tableLayout, _ := indexManager.tableManager.GetLayout(tblName, tx) 69 | statsInfo := indexManager.statsManager.GetStatInfo(tblName, tableLayout, tx) 70 | idxInfo := NewIndexInfo(idxName, fldName, tableLayout.Schema, tx, statsInfo) 71 | result[fldName] = *idxInfo 72 | } 73 | } 74 | tableScan.Close() 75 | return result, nil 76 | } 77 | -------------------------------------------------------------------------------- /internal/metadata/metadata_manager.go: -------------------------------------------------------------------------------- 1 | package metadata 2 | 3 | import ( 4 | "github.com/evanxg852000/simpledb/internal/record" 5 | "github.com/evanxg852000/simpledb/internal/tx/recovery" 6 | ) 7 | 8 | type MetadataManager struct { 9 | tableManager *TableManager 10 | viewManager *ViewManager 11 | statsManager *StatsManager 12 | indexManager *IndexManager 13 | } 14 | 15 | func NewMetadataManager(isNew bool, tx *recovery.Transaction) *MetadataManager { 16 | tableManager := NewTableManager(isNew, tx) 17 | viewManager := NewViewManager(isNew, tableManager, tx) 18 | statsManager := NewStatsManager(tableManager, tx) 19 | indexManager := NewIndexManager(isNew, tableManager, statsManager, tx) 20 | return &MetadataManager{ 21 | tableManager, 22 | viewManager, 23 | statsManager, 24 | indexManager, 25 | } 26 | } 27 | 28 | func (mdtManager *MetadataManager) CreateTable(tblName string, schema *record.Schema, tx *recovery.Transaction) error { 29 | return 
mdtManager.tableManager.CreateTable(tblName, schema, tx) 30 | } 31 | 32 | func (mdtManager *MetadataManager) GetLayout(tblName string, tx *recovery.Transaction) (*record.Layout, error) { 33 | return mdtManager.tableManager.GetLayout(tblName, tx) 34 | } 35 | 36 | func (mdtManager *MetadataManager) CreateView(viewName, viewDef string, tx *recovery.Transaction) error { 37 | return mdtManager.viewManager.CreateView(viewName, viewDef, tx) 38 | } 39 | 40 | func (mdtManager *MetadataManager) GetViewDef(viewName string, tx *recovery.Transaction) (string, error) { 41 | return mdtManager.viewManager.GetViewDef(viewName, tx) 42 | } 43 | 44 | func (mdtManager *MetadataManager) CreateIndex(idxName, tblName, fldName string, tx *recovery.Transaction) error { 45 | return mdtManager.indexManager.CreateIndex(idxName, tblName, fldName, tx) 46 | } 47 | 48 | func (mdtManager *MetadataManager) GetIndexInfo(tblName string, tx *recovery.Transaction) (map[string]IndexInfo, error) { 49 | return mdtManager.indexManager.GetIndexInfo(tblName, tx) 50 | } 51 | 52 | func (mdtManager *MetadataManager) GetStatInfo(tblName string, layout *record.Layout, tx *recovery.Transaction) StatInfo { 53 | return mdtManager.statsManager.GetStatInfo(tblName, layout, tx) 54 | } 55 | 56 | func (mdtManager *MetadataManager) GetTableManager() *TableManager { 57 | return mdtManager.tableManager 58 | } 59 | 60 | func (mdtManager *MetadataManager) GetViewManager() *ViewManager { 61 | return mdtManager.viewManager 62 | } 63 | 64 | func (mdtManager *MetadataManager) GetStatsManager() *StatsManager { 65 | return mdtManager.statsManager 66 | } 67 | 68 | func (mdtManager *MetadataManager) GetIndexManager() *IndexManager { 69 | return mdtManager.indexManager 70 | } 71 | -------------------------------------------------------------------------------- /internal/metadata/metadata_manager_test.go: -------------------------------------------------------------------------------- 1 | package metadata_test 2 | 3 | import ( 4 | "os" 
5 | "path" 6 | "testing" 7 | 8 | "github.com/evanxg852000/simpledb/internal/record" 9 | "github.com/evanxg852000/simpledb/internal/server" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestMetadataManager(t *testing.T) { 14 | assert := assert.New(t) 15 | workspaceDir, err := os.MkdirTemp("", "test_table_manager") 16 | assert.Nil(err) 17 | dbDir := path.Join(workspaceDir, "db") 18 | defer os.RemoveAll(workspaceDir) 19 | 20 | db := server.NewSimpleDB(dbDir, 400, 8) 21 | tblManager := db.MetadataManager().GetTableManager() 22 | 23 | tx := db.NewTx() 24 | 25 | schema := record.NewSchema() 26 | schema.AddIntField("A") 27 | schema.AddStringField("B", 9) 28 | tblManager.CreateTable("my_table", schema, tx) 29 | 30 | layout, err := tblManager.GetLayout("my_table", tx) 31 | assert.Nil(err) 32 | 33 | assert.Equal(int64(33), layout.SlotSize()) 34 | assert.Equal(2, len(layout.Schema.Fields())) 35 | rows := []struct { 36 | n string 37 | t int64 38 | }{} 39 | for _, fldName := range layout.Schema.Fields() { 40 | rows = append(rows, struct { 41 | n string 42 | t int64 43 | }{fldName, layout.Schema.FieldType(fldName)}) 44 | } 45 | 46 | tx.Commit() 47 | 48 | assert.Equal([]struct { 49 | n string 50 | t int64 51 | }{{"A", record.INTEGER_TYPE}, {"B", record.STRING_TYPE}}, rows) 52 | } 53 | -------------------------------------------------------------------------------- /internal/metadata/stat_info.go: -------------------------------------------------------------------------------- 1 | package metadata 2 | 3 | // A StatInfo object holds three pieces of 4 | // statistical information about a table: 5 | // the number of blocks, the number of records, 6 | // and the number of distinct values for each field. 7 | type StatInfo struct { 8 | numBlocks int64 9 | numRecords int64 10 | } 11 | 12 | // Create a StatInfo object. 13 | // Note that the number of distinct values is not 14 | // passed into the constructor. 15 | // The object fakes this value. 
16 | func NewStatInfo(numBlocks, numRecords int64) StatInfo { 17 | return StatInfo{numBlocks, numRecords} 18 | } 19 | 20 | // Return the estimated number of blocks in the table. 21 | func (si *StatInfo) BlockAccessed() int64 { 22 | return si.numBlocks 23 | } 24 | 25 | // Return the estimated number of records in the table. 26 | func (si *StatInfo) RecordsOutput() int64 { 27 | return si.numRecords 28 | } 29 | 30 | // Return the estimated number of distinct values 31 | // for the specified field. 32 | // This estimate is a complete guess, because doing something 33 | // reasonable is beyond the scope of this system. 34 | func (si *StatInfo) DistinctValues(fieldName string) int64 { 35 | _ = fieldName 36 | return (si.numRecords / 3) + 1 37 | } 38 | -------------------------------------------------------------------------------- /internal/metadata/stats_manager.go: -------------------------------------------------------------------------------- 1 | package metadata 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/evanxg852000/simpledb/internal/record" 7 | "github.com/evanxg852000/simpledb/internal/tx/recovery" 8 | ) 9 | 10 | // The statistics manager is responsible for 11 | // keeping statistical information about each table. 12 | // The manager does not store this information in the database. 13 | // Instead, it calculates this information on system startup, 14 | // and periodically refreshes it. 15 | type StatsManager struct { 16 | tableManager TableManager 17 | tableStats map[string]StatInfo 18 | numCalls int 19 | mu *sync.Mutex 20 | } 21 | 22 | // Create the statistics manager. 23 | // The initial statistics are calculated by 24 | // traversing the entire database. 
25 | func NewStatsManager(tableManager *TableManager, tx *recovery.Transaction) *StatsManager { 26 | statsManager := &StatsManager{ 27 | tableManager: *tableManager, 28 | tableStats: make(map[string]StatInfo), 29 | numCalls: 0, 30 | mu: new(sync.Mutex), 31 | } 32 | statsManager.mu.Lock() 33 | statsManager.refreshStatistics(tx) 34 | statsManager.mu.Unlock() 35 | return statsManager 36 | } 37 | 38 | // Return the statistical information about the specified table. 39 | func (statsManager *StatsManager) GetStatInfo(tblName string, layout *record.Layout, tx *recovery.Transaction) StatInfo { 40 | statsManager.mu.Lock() 41 | defer statsManager.mu.Unlock() 42 | 43 | statsManager.numCalls += 1 44 | if statsManager.numCalls > 100 { 45 | statsManager.refreshStatistics(tx) 46 | } 47 | 48 | si, exist := statsManager.tableStats[tblName] 49 | if !exist { 50 | si, _ = statsManager.calcTableStats(tblName, layout, tx) 51 | statsManager.tableStats[tblName] = si 52 | } 53 | return si 54 | } 55 | 56 | func (statsManager *StatsManager) refreshStatistics(tx *recovery.Transaction) error { 57 | statsManager.tableStats = map[string]StatInfo{} 58 | statsManager.numCalls = 0 59 | layout, err := statsManager.tableManager.GetLayout(TABLE_CATALOG, tx) 60 | if err != nil { 61 | return err 62 | } 63 | 64 | tableScan, err := record.NewTableScan(tx, TABLE_CATALOG, layout) 65 | if err != nil { 66 | return err 67 | } 68 | 69 | for tableScan.Next() { 70 | tblName := tableScan.GetString("table_name") 71 | layout, _ := statsManager.tableManager.GetLayout(tblName, tx) 72 | si, _ := statsManager.calcTableStats(tblName, layout, tx) 73 | statsManager.tableStats[tblName] = si 74 | } 75 | tableScan.Close() 76 | return nil 77 | } 78 | 79 | func (statsManager *StatsManager) calcTableStats(tblName string, layout *record.Layout, tx *recovery.Transaction) (StatInfo, error) { 80 | 81 | numBlocks := int64(0) 82 | numRecords := int64(0) 83 | tableScan, err := record.NewTableScan(tx, tblName, layout) 84 | if err != nil 
{ 85 | return StatInfo{}, err 86 | } 87 | 88 | for tableScan.Next() { 89 | numRecords += 1 90 | numBlocks = tableScan.GetRID().BlockNum + 1 91 | } 92 | tableScan.Close() 93 | return NewStatInfo(numBlocks, numRecords), nil 94 | } 95 | -------------------------------------------------------------------------------- /internal/metadata/table_manager.go: -------------------------------------------------------------------------------- 1 | package metadata 2 | 3 | import ( 4 | "github.com/evanxg852000/simpledb/internal/record" 5 | "github.com/evanxg852000/simpledb/internal/tx/recovery" 6 | ) 7 | 8 | const ( 9 | //The max characters a tablename or fieldname can have. 10 | MAX_NAME_LENGTH = 32 11 | 12 | TABLE_CATALOG = "table_catalog" 13 | FIELD_CATALOG = "field_catalog" 14 | ) 15 | 16 | // The table manager. 17 | // There are methods to create a table, save the metadata 18 | // in the catalog, and obtain the metadata of a 19 | // previously-created table. 20 | type TableManager struct { 21 | tableCatLayout *record.Layout 22 | fieldCatLayout *record.Layout 23 | } 24 | 25 | // Create a new table manager for the database system. 26 | // If the database is new, the two catalog tables (table, field) 27 | // are created. 
28 | func NewTableManager(isNew bool, tx *recovery.Transaction) *TableManager { 29 | tableCatSchema := record.NewSchema() 30 | tableCatSchema.AddStringField("table_name", MAX_NAME_LENGTH) 31 | tableCatSchema.AddIntField("slot_size") 32 | tableCatLayout := record.NewLayout(tableCatSchema) 33 | 34 | fieldCatSchema := record.NewSchema() 35 | fieldCatSchema.AddStringField("table_name", MAX_NAME_LENGTH) 36 | fieldCatSchema.AddStringField("field_name", MAX_NAME_LENGTH) 37 | fieldCatSchema.AddIntField("type") 38 | fieldCatSchema.AddIntField("length") 39 | fieldCatSchema.AddIntField("offset") 40 | fieldCatLayout := record.NewLayout(fieldCatSchema) 41 | 42 | tableManager := &TableManager{tableCatLayout, fieldCatLayout} 43 | if isNew { 44 | tableManager.CreateTable(TABLE_CATALOG, tableCatSchema, tx) 45 | tableManager.CreateTable(FIELD_CATALOG, fieldCatSchema, tx) 46 | } 47 | return tableManager 48 | } 49 | 50 | // Create a new table having the specified name and schema 51 | func (tableManager *TableManager) CreateTable(tblName string, schema *record.Schema, tx *recovery.Transaction) error { 52 | layout := record.NewLayout(schema) 53 | 54 | // insert one record into table_catalog 55 | tableScan, err := record.NewTableScan(tx, TABLE_CATALOG, tableManager.tableCatLayout) 56 | if err != nil { 57 | return err 58 | } 59 | tableScan.Insert() 60 | tableScan.SetString("table_name", tblName) 61 | tableScan.SetInt("slot_size", layout.SlotSize()) 62 | tableScan.Close() 63 | 64 | // insert one record into field_catalog 65 | tableScan, err = record.NewTableScan(tx, FIELD_CATALOG, tableManager.fieldCatLayout) 66 | if err != nil { 67 | return err 68 | } 69 | for _, fldName := range schema.Fields() { 70 | tableScan.Insert() 71 | tableScan.SetString("table_name", tblName) 72 | tableScan.SetString("field_name", fldName) 73 | tableScan.SetInt("type", int64(schema.FieldType(fldName))) 74 | tableScan.SetInt("length", int64(schema.FieldLength(fldName))) 75 | tableScan.SetInt("offset", 
layout.Offset(fldName)) 76 | } 77 | tableScan.Close() 78 | return nil 79 | } 80 | // GetLayout rebuilds the stored layout of tblName: the slot size comes from its table_catalog row, and the fields (type, length, offset) come from its field_catalog rows. 81 | func (tableManager *TableManager) GetLayout(tblName string, tx *recovery.Transaction) (*record.Layout, error) { 82 | size := int64(-1) // NOTE(review): stays -1 if tblName has no table_catalog row; the function still returns a nil error in that case 83 | tableScan, err := record.NewTableScan(tx, TABLE_CATALOG, tableManager.tableCatLayout) 84 | if err != nil { 85 | return nil, err 86 | } 87 | for tableScan.Next() { 88 | storedTblName := tableScan.GetString("table_name") 89 | if storedTblName == tblName { 90 | size = tableScan.GetInt("slot_size") 91 | break 92 | } 93 | } 94 | tableScan.Close() 95 | // second pass: collect every field_catalog row that belongs to tblName 96 | schema := record.NewSchema() 97 | offsets := make(map[string]int64) 98 | tableScan, err = record.NewTableScan(tx, FIELD_CATALOG, tableManager.fieldCatLayout) 99 | if err != nil { 100 | return nil, err 101 | } 102 | 103 | for tableScan.Next() { 104 | storedTblName := tableScan.GetString("table_name") 105 | if storedTblName == tblName { 106 | fldName := tableScan.GetString("field_name") 107 | fldType := tableScan.GetInt("type") 108 | fldLength := tableScan.GetInt("length") 109 | offset := tableScan.GetInt("offset") 110 | offsets[fldName] = offset 111 | schema.AddField(fldName, fldType, fldLength) 112 | } 113 | } 114 | tableScan.Close() 115 | return record.NewLayoutFromMetadata(schema, offsets, size), nil 116 | } 117 | -------------------------------------------------------------------------------- /internal/metadata/table_manager_test.go: -------------------------------------------------------------------------------- 1 | package metadata_test 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "os" 7 | "path" 8 | "testing" 9 | 10 | "github.com/evanxg852000/simpledb/internal/metadata" 11 | "github.com/evanxg852000/simpledb/internal/record" 12 | "github.com/evanxg852000/simpledb/internal/server" 13 | "github.com/stretchr/testify/assert" 14 | ) 15 | 16 | func TestTableManager(t *testing.T) { 17 | assert := assert.New(t) 18 | workspaceDir, err := os.MkdirTemp("", "test_table_manager") 19 | assert.Nil(err) 
20 | dbDir := path.Join(workspaceDir, "db") 21 | defer os.RemoveAll(workspaceDir) 22 | 23 | db := server.NewSimpleDB(dbDir, 400, 8) 24 | mdtManager := db.MetadataManager() 25 | tx := db.NewTx() 26 | 27 | schema := record.NewSchema() 28 | schema.AddIntField("A") 29 | schema.AddStringField("B", 9) 30 | mdtManager.CreateTable("my_table", schema, tx) 31 | 32 | // Statistics metadata 33 | layout, err := mdtManager.GetLayout("my_table", tx) 34 | assert.Nil(err) 35 | tblScan, err := record.NewTableScan(tx, "my_table", layout) 36 | assert.Nil(err) 37 | for i := 0; i < 50; i++ { 38 | tblScan.Insert() 39 | n := int64(rand.Intn(50)) 40 | tblScan.SetInt("A", n) 41 | tblScan.SetString("B", fmt.Sprintf("rec_%d", n)) 42 | } 43 | si := mdtManager.GetStatInfo("my_table", layout, tx) 44 | assert.Equal(metadata.NewStatInfo(5, 50), si) 45 | 46 | // View metadata 47 | viewDef := "select B from MyTable where A = 1" 48 | mdtManager.CreateView("my_view", viewDef, tx) 49 | v, _ := mdtManager.GetViewDef("my_view", tx) 50 | assert.Equal(viewDef, v) 51 | 52 | // Index metadata 53 | mdtManager.CreateIndex("idx_a", "my_table", "A", tx) 54 | mdtManager.CreateIndex("idx_b", "my_table", "B", tx) 55 | idxMap, _ := mdtManager.GetIndexInfo("my_table", tx) 56 | 57 | idxA := idxMap["A"] 58 | assert.Equal("idx_a", idxA.IndexName) 59 | assert.Equal("A", idxA.FieldName) 60 | 61 | idxB := idxMap["B"] 62 | assert.Equal("idx_b", idxB.IndexName) 63 | assert.Equal("B", idxB.FieldName) 64 | tx.Commit() 65 | } 66 | -------------------------------------------------------------------------------- /internal/metadata/view_manager.go: -------------------------------------------------------------------------------- 1 | package metadata 2 | 3 | import ( 4 | "github.com/evanxg852000/simpledb/internal/record" 5 | "github.com/evanxg852000/simpledb/internal/tx/recovery" 6 | ) 7 | 8 | const ( 9 | MAX_VIEW_DEF = 100 // declared length of the view_def string field in the view catalog 10 | 11 | VIEW_CATALOG = "view_catalog" // name of the catalog table that holds view definitions 12 | ) 13 | // ViewManager stores view definitions in the view catalog table and uses tableManager to look up that table's layout. 14 | type ViewManager struct { 15 | tableManager 
*TableManager 16 | } 17 | // NewViewManager wraps tableManager and, when isNew is true, creates the view catalog table with string fields view_name and view_def. 18 | func NewViewManager(isNew bool, tableManager *TableManager, tx *recovery.Transaction) *ViewManager { 19 | viewManager := &ViewManager{tableManager} 20 | if isNew { 21 | schema := record.NewSchema() 22 | schema.AddStringField("view_name", MAX_NAME_LENGTH) 23 | schema.AddStringField("view_def", MAX_VIEW_DEF) 24 | tableManager.CreateTable(VIEW_CATALOG, schema, tx) // NOTE(review): returned error is discarded; this constructor has no error result 25 | } 26 | return viewManager 27 | } 28 | // CreateView appends one (view_name, view_def) row to the view catalog. It returns any error from looking up the catalog layout or opening the scan. 29 | func (vm *ViewManager) CreateView(vName string, viewDef string, tx *recovery.Transaction) error { 30 | layout, err := vm.tableManager.GetLayout(VIEW_CATALOG, tx) 31 | if err != nil { 32 | return err 33 | } 34 | tableScan, err := record.NewTableScan(tx, VIEW_CATALOG, layout) 35 | if err != nil { 36 | return err 37 | } 38 | tableScan.Insert() 39 | tableScan.SetString("view_name", vName) 40 | tableScan.SetString("view_def", viewDef) 41 | tableScan.Close() 42 | return nil 43 | } 44 | // GetViewDef returns the stored definition of view vName. If no such view exists it returns "" with a nil error; errors come only from the catalog layout lookup or from opening the scan. 45 | func (vm *ViewManager) GetViewDef(vName string, tx *recovery.Transaction) (string, error) { 46 | viewDef := "" 47 | layout, err := vm.tableManager.GetLayout(VIEW_CATALOG, tx) 48 | if err != nil { 49 | return viewDef, err 50 | } 51 | tableScan, err := record.NewTableScan(tx, VIEW_CATALOG, layout) 52 | if err != nil { 53 | return viewDef, err 54 | } 55 | 56 | for tableScan.Next() { 57 | storedViewName := tableScan.GetString("view_name") 58 | if storedViewName == vName { 59 | viewDef = tableScan.GetString("view_def") 60 | break 61 | } 62 | } 63 | tableScan.Close() 64 | return viewDef, nil 65 | } 66 | -------------------------------------------------------------------------------- /internal/parser/SimpleSql.g4: -------------------------------------------------------------------------------- 1 | grammar SimpleSql; 2 | 3 | /* antlr4 4.7.2 */ 4 | 5 | parse 6 | : statementList* EOF 7 | ; 8 | 9 | statementList 10 | : statement (SEMI_COLON statement)* 11 | ; 12 | 13 | statement 14 | : create_table_stmt 15 | | insert_stmt 16 | | select_stmt 17 | | 
update_stmt 18 | | delete_stmt 19 | | create_view_stmt 20 | | create_index_stmt 21 | ; 22 | 23 | create_table_stmt: CREATE_ TABLE_ IDENT '(' field_specs ')' ; 24 | field_specs: field_spec (COMMA field_spec)* ; 25 | field_spec: IDENT type_spec ; 26 | type_spec: INT_ | varchar_spec ; 27 | varchar_spec: VAR_CHAR_ '(' INT_LITERAL ')' ; 28 | 29 | insert_stmt: INSERT_ INTO_ IDENT ( '(' ident_list ')' )? VALUES_ '(' constant_list ')' ; 30 | constant_list: literal (COMMA literal)* ; 31 | 32 | select_stmt: SELECT_ (STAR | ident_list) FROM_ ident_list (WHERE_ condition)? ; 33 | ident_list: IDENT (COMMA IDENT)* ; 34 | 35 | update_stmt: UPDATE_ IDENT SET_ update_expr_list (WHERE_ condition)? ; 36 | update_expr_list: update_expr (COMMA update_expr)* ; 37 | update_expr: IDENT '=' expression ; 38 | 39 | delete_stmt: DELETE_ FROM_ IDENT (WHERE_ condition)? ; 40 | 41 | create_view_stmt: CREATE_ VIEW_ IDENT AS_ select_stmt ; 42 | 43 | create_index_stmt: CREATE_ INDEX_ IDENT ON_ IDENT '(' IDENT ')' ; 44 | 45 | 46 | condition: term ( op=(AND_ | OR_) term)?; 47 | term: left=expression operator=(EQUAL|NOT_EQUAL) right=expression ; 48 | expression: IDENT | literal ; 49 | literal: INT_LITERAL | STR_LITERAL ; 50 | 51 | /* keywords */ 52 | 53 | CREATE_: 'create' ; 54 | INSERT_: 'insert' ; 55 | SELECT_: 'select' ; 56 | UPDATE_: 'update' ; 57 | DELETE_: 'delete' ; 58 | FROM_: 'from' ; 59 | SET_: 'set' ; 60 | WHERE_: 'where' ; 61 | INTO_: 'into' ; 62 | VALUES_: 'values' ; 63 | TABLE_: 'table' ; 64 | INDEX_: 'index' ; 65 | VIEW_: 'view' ; 66 | AS_: 'as' ; 67 | ON_: 'on' ; 68 | INT_: 'int'; 69 | VAR_CHAR_: 'varchar' ; 70 | AND_: 'and' ; 71 | OR_: 'or' ; 72 | 73 | STAR: '*' ; 74 | EQUAL: '=' ; 75 | NOT_EQUAL: '!=' ; 76 | COMMA: ','; 77 | SEMI_COLON: ';'; 78 | 79 | IDENT: [a-zA-Z_][a-zA-Z0-9_]* ; 80 | INT_LITERAL: '0'|[-+]?[1-9][0-9]* ; 81 | STR_LITERAL: '\'' ( ~'\'' | '\'\'')* '\'' ; 82 | 83 | SPACES: [ \t\r\n] -> skip ; 84 | 85 | 
-------------------------------------------------------------------------------- /internal/parser/SimpleSql.interp: -------------------------------------------------------------------------------- 1 | token literal names: 2 | null 3 | '(' 4 | ')' 5 | 'create' 6 | 'insert' 7 | 'select' 8 | 'update' 9 | 'delete' 10 | 'from' 11 | 'set' 12 | 'where' 13 | 'into' 14 | 'values' 15 | 'table' 16 | 'index' 17 | 'view' 18 | 'as' 19 | 'on' 20 | 'int' 21 | 'varchar' 22 | 'and' 23 | 'or' 24 | '*' 25 | '=' 26 | '!=' 27 | ',' 28 | ';' 29 | null 30 | null 31 | null 32 | null 33 | 34 | token symbolic names: 35 | null 36 | null 37 | null 38 | CREATE_ 39 | INSERT_ 40 | SELECT_ 41 | UPDATE_ 42 | DELETE_ 43 | FROM_ 44 | SET_ 45 | WHERE_ 46 | INTO_ 47 | VALUES_ 48 | TABLE_ 49 | INDEX_ 50 | VIEW_ 51 | AS_ 52 | ON_ 53 | INT_ 54 | VAR_CHAR_ 55 | AND_ 56 | OR_ 57 | STAR 58 | EQUAL 59 | NOT_EQUAL 60 | COMMA 61 | SEMI_COLON 62 | IDENT 63 | INT_LITERAL 64 | STR_LITERAL 65 | SPACES 66 | 67 | rule names: 68 | parse 69 | statementList 70 | statement 71 | create_table_stmt 72 | field_specs 73 | field_spec 74 | type_spec 75 | varchar_spec 76 | insert_stmt 77 | constant_list 78 | select_stmt 79 | ident_list 80 | update_stmt 81 | update_expr_list 82 | update_expr 83 | delete_stmt 84 | create_view_stmt 85 | create_index_stmt 86 | condition 87 | term 88 | expression 89 | literal 90 | 91 | 92 | atn: 93 | [3, 24715, 42794, 33075, 47597, 16764, 15335, 30598, 22884, 3, 32, 197, 4, 2, 9, 2, 4, 3, 9, 3, 4, 4, 9, 4, 4, 5, 9, 5, 4, 6, 9, 6, 4, 7, 9, 7, 4, 8, 9, 8, 4, 9, 9, 9, 4, 10, 9, 10, 4, 11, 9, 11, 4, 12, 9, 12, 4, 13, 9, 13, 4, 14, 9, 14, 4, 15, 9, 15, 4, 16, 9, 16, 4, 17, 9, 17, 4, 18, 9, 18, 4, 19, 9, 19, 4, 20, 9, 20, 4, 21, 9, 21, 4, 22, 9, 22, 4, 23, 9, 23, 3, 2, 7, 2, 48, 10, 2, 12, 2, 14, 2, 51, 11, 2, 3, 2, 3, 2, 3, 3, 3, 3, 3, 3, 7, 3, 58, 10, 3, 12, 3, 14, 3, 61, 11, 3, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 5, 4, 70, 10, 4, 3, 5, 3, 5, 3, 5, 3, 5, 3, 5, 3, 5, 3, 5, 3, 6, 3, 6, 3, 6, 7, 6, 
82, 10, 6, 12, 6, 14, 6, 85, 11, 6, 3, 7, 3, 7, 3, 7, 3, 8, 3, 8, 5, 8, 92, 10, 8, 3, 9, 3, 9, 3, 9, 3, 9, 3, 9, 3, 10, 3, 10, 3, 10, 3, 10, 3, 10, 3, 10, 3, 10, 5, 10, 106, 10, 10, 3, 10, 3, 10, 3, 10, 3, 10, 3, 10, 3, 11, 3, 11, 3, 11, 7, 11, 116, 10, 11, 12, 11, 14, 11, 119, 11, 11, 3, 12, 3, 12, 3, 12, 5, 12, 124, 10, 12, 3, 12, 3, 12, 3, 12, 3, 12, 5, 12, 130, 10, 12, 3, 13, 3, 13, 3, 13, 7, 13, 135, 10, 13, 12, 13, 14, 13, 138, 11, 13, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 5, 14, 146, 10, 14, 3, 15, 3, 15, 3, 15, 7, 15, 151, 10, 15, 12, 15, 14, 15, 154, 11, 15, 3, 16, 3, 16, 3, 16, 3, 16, 3, 17, 3, 17, 3, 17, 3, 17, 3, 17, 5, 17, 165, 10, 17, 3, 18, 3, 18, 3, 18, 3, 18, 3, 18, 3, 18, 3, 19, 3, 19, 3, 19, 3, 19, 3, 19, 3, 19, 3, 19, 3, 19, 3, 19, 3, 20, 3, 20, 3, 20, 5, 20, 185, 10, 20, 3, 21, 3, 21, 3, 21, 3, 21, 3, 22, 3, 22, 5, 22, 193, 10, 22, 3, 23, 3, 23, 3, 23, 2, 2, 24, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 2, 5, 3, 2, 22, 23, 3, 2, 25, 26, 3, 2, 30, 31, 2, 194, 2, 49, 3, 2, 2, 2, 4, 54, 3, 2, 2, 2, 6, 69, 3, 2, 2, 2, 8, 71, 3, 2, 2, 2, 10, 78, 3, 2, 2, 2, 12, 86, 3, 2, 2, 2, 14, 91, 3, 2, 2, 2, 16, 93, 3, 2, 2, 2, 18, 98, 3, 2, 2, 2, 20, 112, 3, 2, 2, 2, 22, 120, 3, 2, 2, 2, 24, 131, 3, 2, 2, 2, 26, 139, 3, 2, 2, 2, 28, 147, 3, 2, 2, 2, 30, 155, 3, 2, 2, 2, 32, 159, 3, 2, 2, 2, 34, 166, 3, 2, 2, 2, 36, 172, 3, 2, 2, 2, 38, 181, 3, 2, 2, 2, 40, 186, 3, 2, 2, 2, 42, 192, 3, 2, 2, 2, 44, 194, 3, 2, 2, 2, 46, 48, 5, 4, 3, 2, 47, 46, 3, 2, 2, 2, 48, 51, 3, 2, 2, 2, 49, 47, 3, 2, 2, 2, 49, 50, 3, 2, 2, 2, 50, 52, 3, 2, 2, 2, 51, 49, 3, 2, 2, 2, 52, 53, 7, 2, 2, 3, 53, 3, 3, 2, 2, 2, 54, 59, 5, 6, 4, 2, 55, 56, 7, 28, 2, 2, 56, 58, 5, 6, 4, 2, 57, 55, 3, 2, 2, 2, 58, 61, 3, 2, 2, 2, 59, 57, 3, 2, 2, 2, 59, 60, 3, 2, 2, 2, 60, 5, 3, 2, 2, 2, 61, 59, 3, 2, 2, 2, 62, 70, 5, 8, 5, 2, 63, 70, 5, 18, 10, 2, 64, 70, 5, 22, 12, 2, 65, 70, 5, 26, 14, 2, 66, 70, 5, 32, 17, 2, 67, 70, 5, 34, 18, 2, 68, 70, 5, 36, 19, 2, 
69, 62, 3, 2, 2, 2, 69, 63, 3, 2, 2, 2, 69, 64, 3, 2, 2, 2, 69, 65, 3, 2, 2, 2, 69, 66, 3, 2, 2, 2, 69, 67, 3, 2, 2, 2, 69, 68, 3, 2, 2, 2, 70, 7, 3, 2, 2, 2, 71, 72, 7, 5, 2, 2, 72, 73, 7, 15, 2, 2, 73, 74, 7, 29, 2, 2, 74, 75, 7, 3, 2, 2, 75, 76, 5, 10, 6, 2, 76, 77, 7, 4, 2, 2, 77, 9, 3, 2, 2, 2, 78, 83, 5, 12, 7, 2, 79, 80, 7, 27, 2, 2, 80, 82, 5, 12, 7, 2, 81, 79, 3, 2, 2, 2, 82, 85, 3, 2, 2, 2, 83, 81, 3, 2, 2, 2, 83, 84, 3, 2, 2, 2, 84, 11, 3, 2, 2, 2, 85, 83, 3, 2, 2, 2, 86, 87, 7, 29, 2, 2, 87, 88, 5, 14, 8, 2, 88, 13, 3, 2, 2, 2, 89, 92, 7, 20, 2, 2, 90, 92, 5, 16, 9, 2, 91, 89, 3, 2, 2, 2, 91, 90, 3, 2, 2, 2, 92, 15, 3, 2, 2, 2, 93, 94, 7, 21, 2, 2, 94, 95, 7, 3, 2, 2, 95, 96, 7, 30, 2, 2, 96, 97, 7, 4, 2, 2, 97, 17, 3, 2, 2, 2, 98, 99, 7, 6, 2, 2, 99, 100, 7, 13, 2, 2, 100, 105, 7, 29, 2, 2, 101, 102, 7, 3, 2, 2, 102, 103, 5, 24, 13, 2, 103, 104, 7, 4, 2, 2, 104, 106, 3, 2, 2, 2, 105, 101, 3, 2, 2, 2, 105, 106, 3, 2, 2, 2, 106, 107, 3, 2, 2, 2, 107, 108, 7, 14, 2, 2, 108, 109, 7, 3, 2, 2, 109, 110, 5, 20, 11, 2, 110, 111, 7, 4, 2, 2, 111, 19, 3, 2, 2, 2, 112, 117, 5, 44, 23, 2, 113, 114, 7, 27, 2, 2, 114, 116, 5, 44, 23, 2, 115, 113, 3, 2, 2, 2, 116, 119, 3, 2, 2, 2, 117, 115, 3, 2, 2, 2, 117, 118, 3, 2, 2, 2, 118, 21, 3, 2, 2, 2, 119, 117, 3, 2, 2, 2, 120, 123, 7, 7, 2, 2, 121, 124, 7, 24, 2, 2, 122, 124, 5, 24, 13, 2, 123, 121, 3, 2, 2, 2, 123, 122, 3, 2, 2, 2, 124, 125, 3, 2, 2, 2, 125, 126, 7, 10, 2, 2, 126, 129, 5, 24, 13, 2, 127, 128, 7, 12, 2, 2, 128, 130, 5, 38, 20, 2, 129, 127, 3, 2, 2, 2, 129, 130, 3, 2, 2, 2, 130, 23, 3, 2, 2, 2, 131, 136, 7, 29, 2, 2, 132, 133, 7, 27, 2, 2, 133, 135, 7, 29, 2, 2, 134, 132, 3, 2, 2, 2, 135, 138, 3, 2, 2, 2, 136, 134, 3, 2, 2, 2, 136, 137, 3, 2, 2, 2, 137, 25, 3, 2, 2, 2, 138, 136, 3, 2, 2, 2, 139, 140, 7, 8, 2, 2, 140, 141, 7, 29, 2, 2, 141, 142, 7, 11, 2, 2, 142, 145, 5, 28, 15, 2, 143, 144, 7, 12, 2, 2, 144, 146, 5, 38, 20, 2, 145, 143, 3, 2, 2, 2, 145, 146, 3, 2, 2, 2, 146, 27, 3, 2, 2, 2, 147, 152, 5, 30, 
16, 2, 148, 149, 7, 27, 2, 2, 149, 151, 5, 30, 16, 2, 150, 148, 3, 2, 2, 2, 151, 154, 3, 2, 2, 2, 152, 150, 3, 2, 2, 2, 152, 153, 3, 2, 2, 2, 153, 29, 3, 2, 2, 2, 154, 152, 3, 2, 2, 2, 155, 156, 7, 29, 2, 2, 156, 157, 7, 25, 2, 2, 157, 158, 5, 42, 22, 2, 158, 31, 3, 2, 2, 2, 159, 160, 7, 9, 2, 2, 160, 161, 7, 10, 2, 2, 161, 164, 7, 29, 2, 2, 162, 163, 7, 12, 2, 2, 163, 165, 5, 38, 20, 2, 164, 162, 3, 2, 2, 2, 164, 165, 3, 2, 2, 2, 165, 33, 3, 2, 2, 2, 166, 167, 7, 5, 2, 2, 167, 168, 7, 17, 2, 2, 168, 169, 7, 29, 2, 2, 169, 170, 7, 18, 2, 2, 170, 171, 5, 22, 12, 2, 171, 35, 3, 2, 2, 2, 172, 173, 7, 5, 2, 2, 173, 174, 7, 16, 2, 2, 174, 175, 7, 29, 2, 2, 175, 176, 7, 19, 2, 2, 176, 177, 7, 29, 2, 2, 177, 178, 7, 3, 2, 2, 178, 179, 7, 29, 2, 2, 179, 180, 7, 4, 2, 2, 180, 37, 3, 2, 2, 2, 181, 184, 5, 40, 21, 2, 182, 183, 9, 2, 2, 2, 183, 185, 5, 40, 21, 2, 184, 182, 3, 2, 2, 2, 184, 185, 3, 2, 2, 2, 185, 39, 3, 2, 2, 2, 186, 187, 5, 42, 22, 2, 187, 188, 9, 3, 2, 2, 188, 189, 5, 42, 22, 2, 189, 41, 3, 2, 2, 2, 190, 193, 7, 29, 2, 2, 191, 193, 5, 44, 23, 2, 192, 190, 3, 2, 2, 2, 192, 191, 3, 2, 2, 2, 193, 43, 3, 2, 2, 2, 194, 195, 9, 4, 2, 2, 195, 45, 3, 2, 2, 2, 17, 49, 59, 69, 83, 91, 105, 117, 123, 129, 136, 145, 152, 164, 184, 192] -------------------------------------------------------------------------------- /internal/parser/SimpleSql.tokens: -------------------------------------------------------------------------------- 1 | T__0=1 2 | T__1=2 3 | CREATE_=3 4 | INSERT_=4 5 | SELECT_=5 6 | UPDATE_=6 7 | DELETE_=7 8 | FROM_=8 9 | SET_=9 10 | WHERE_=10 11 | INTO_=11 12 | VALUES_=12 13 | TABLE_=13 14 | INDEX_=14 15 | VIEW_=15 16 | AS_=16 17 | ON_=17 18 | INT_=18 19 | VAR_CHAR_=19 20 | AND_=20 21 | OR_=21 22 | STAR=22 23 | EQUAL=23 24 | NOT_EQUAL=24 25 | COMMA=25 26 | SEMI_COLON=26 27 | IDENT=27 28 | INT_LITERAL=28 29 | STR_LITERAL=29 30 | SPACES=30 31 | '('=1 32 | ')'=2 33 | 'create'=3 34 | 'insert'=4 35 | 'select'=5 36 | 'update'=6 37 | 'delete'=7 38 | 'from'=8 39 | 
'set'=9 40 | 'where'=10 41 | 'into'=11 42 | 'values'=12 43 | 'table'=13 44 | 'index'=14 45 | 'view'=15 46 | 'as'=16 47 | 'on'=17 48 | 'int'=18 49 | 'varchar'=19 50 | 'and'=20 51 | 'or'=21 52 | '*'=22 53 | '='=23 54 | '!='=24 55 | ','=25 56 | ';'=26 57 | -------------------------------------------------------------------------------- /internal/parser/SimpleSqlLexer.interp: -------------------------------------------------------------------------------- 1 | token literal names: 2 | null 3 | '(' 4 | ')' 5 | 'create' 6 | 'insert' 7 | 'select' 8 | 'update' 9 | 'delete' 10 | 'from' 11 | 'set' 12 | 'where' 13 | 'into' 14 | 'values' 15 | 'table' 16 | 'index' 17 | 'view' 18 | 'as' 19 | 'on' 20 | 'int' 21 | 'varchar' 22 | 'and' 23 | 'or' 24 | '*' 25 | '=' 26 | '!=' 27 | ',' 28 | ';' 29 | null 30 | null 31 | null 32 | null 33 | 34 | token symbolic names: 35 | null 36 | null 37 | null 38 | CREATE_ 39 | INSERT_ 40 | SELECT_ 41 | UPDATE_ 42 | DELETE_ 43 | FROM_ 44 | SET_ 45 | WHERE_ 46 | INTO_ 47 | VALUES_ 48 | TABLE_ 49 | INDEX_ 50 | VIEW_ 51 | AS_ 52 | ON_ 53 | INT_ 54 | VAR_CHAR_ 55 | AND_ 56 | OR_ 57 | STAR 58 | EQUAL 59 | NOT_EQUAL 60 | COMMA 61 | SEMI_COLON 62 | IDENT 63 | INT_LITERAL 64 | STR_LITERAL 65 | SPACES 66 | 67 | rule names: 68 | T__0 69 | T__1 70 | CREATE_ 71 | INSERT_ 72 | SELECT_ 73 | UPDATE_ 74 | DELETE_ 75 | FROM_ 76 | SET_ 77 | WHERE_ 78 | INTO_ 79 | VALUES_ 80 | TABLE_ 81 | INDEX_ 82 | VIEW_ 83 | AS_ 84 | ON_ 85 | INT_ 86 | VAR_CHAR_ 87 | AND_ 88 | OR_ 89 | STAR 90 | EQUAL 91 | NOT_EQUAL 92 | COMMA 93 | SEMI_COLON 94 | IDENT 95 | INT_LITERAL 96 | STR_LITERAL 97 | SPACES 98 | 99 | channel names: 100 | DEFAULT_TOKEN_CHANNEL 101 | HIDDEN 102 | 103 | mode names: 104 | DEFAULT_MODE 105 | 106 | atn: 107 | [3, 24715, 42794, 33075, 47597, 16764, 15335, 30598, 22884, 2, 32, 217, 8, 1, 4, 2, 9, 2, 4, 3, 9, 3, 4, 4, 9, 4, 4, 5, 9, 5, 4, 6, 9, 6, 4, 7, 9, 7, 4, 8, 9, 8, 4, 9, 9, 9, 4, 10, 9, 10, 4, 11, 9, 11, 4, 12, 9, 12, 4, 13, 9, 13, 4, 14, 9, 14, 4, 15, 9, 15, 4, 
16, 9, 16, 4, 17, 9, 17, 4, 18, 9, 18, 4, 19, 9, 19, 4, 20, 9, 20, 4, 21, 9, 21, 4, 22, 9, 22, 4, 23, 9, 23, 4, 24, 9, 24, 4, 25, 9, 25, 4, 26, 9, 26, 4, 27, 9, 27, 4, 28, 9, 28, 4, 29, 9, 29, 4, 30, 9, 30, 4, 31, 9, 31, 3, 2, 3, 2, 3, 3, 3, 3, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 5, 3, 5, 3, 5, 3, 5, 3, 5, 3, 5, 3, 5, 3, 6, 3, 6, 3, 6, 3, 6, 3, 6, 3, 6, 3, 6, 3, 7, 3, 7, 3, 7, 3, 7, 3, 7, 3, 7, 3, 7, 3, 8, 3, 8, 3, 8, 3, 8, 3, 8, 3, 8, 3, 8, 3, 9, 3, 9, 3, 9, 3, 9, 3, 9, 3, 10, 3, 10, 3, 10, 3, 10, 3, 11, 3, 11, 3, 11, 3, 11, 3, 11, 3, 11, 3, 12, 3, 12, 3, 12, 3, 12, 3, 12, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 15, 3, 15, 3, 15, 3, 15, 3, 15, 3, 15, 3, 16, 3, 16, 3, 16, 3, 16, 3, 16, 3, 17, 3, 17, 3, 17, 3, 18, 3, 18, 3, 18, 3, 19, 3, 19, 3, 19, 3, 19, 3, 20, 3, 20, 3, 20, 3, 20, 3, 20, 3, 20, 3, 20, 3, 20, 3, 21, 3, 21, 3, 21, 3, 21, 3, 22, 3, 22, 3, 22, 3, 23, 3, 23, 3, 24, 3, 24, 3, 25, 3, 25, 3, 25, 3, 26, 3, 26, 3, 27, 3, 27, 3, 28, 3, 28, 7, 28, 185, 10, 28, 12, 28, 14, 28, 188, 11, 28, 3, 29, 3, 29, 5, 29, 192, 10, 29, 3, 29, 3, 29, 7, 29, 196, 10, 29, 12, 29, 14, 29, 199, 11, 29, 5, 29, 201, 10, 29, 3, 30, 3, 30, 3, 30, 3, 30, 7, 30, 207, 10, 30, 12, 30, 14, 30, 210, 11, 30, 3, 30, 3, 30, 3, 31, 3, 31, 3, 31, 3, 31, 2, 2, 32, 3, 3, 5, 4, 7, 5, 9, 6, 11, 7, 13, 8, 15, 9, 17, 10, 19, 11, 21, 12, 23, 13, 25, 14, 27, 15, 29, 16, 31, 17, 33, 18, 35, 19, 37, 20, 39, 21, 41, 22, 43, 23, 45, 24, 47, 25, 49, 26, 51, 27, 53, 28, 55, 29, 57, 30, 59, 31, 61, 32, 3, 2, 9, 5, 2, 67, 92, 97, 97, 99, 124, 6, 2, 50, 59, 67, 92, 97, 97, 99, 124, 4, 2, 45, 45, 47, 47, 3, 2, 51, 59, 3, 2, 50, 59, 3, 2, 41, 41, 5, 2, 11, 12, 15, 15, 34, 34, 2, 222, 2, 3, 3, 2, 2, 2, 2, 5, 3, 2, 2, 2, 2, 7, 3, 2, 2, 2, 2, 9, 3, 2, 2, 2, 2, 11, 3, 2, 2, 2, 2, 13, 3, 2, 2, 2, 2, 15, 3, 2, 2, 2, 2, 17, 3, 2, 2, 2, 2, 19, 3, 2, 2, 2, 2, 21, 3, 2, 2, 2, 2, 23, 3, 2, 2, 2, 2, 25, 3, 2, 2, 2, 2, 27, 3, 2, 2, 2, 2, 29, 3, 2, 2, 2, 2, 31, 
3, 2, 2, 2, 2, 33, 3, 2, 2, 2, 2, 35, 3, 2, 2, 2, 2, 37, 3, 2, 2, 2, 2, 39, 3, 2, 2, 2, 2, 41, 3, 2, 2, 2, 2, 43, 3, 2, 2, 2, 2, 45, 3, 2, 2, 2, 2, 47, 3, 2, 2, 2, 2, 49, 3, 2, 2, 2, 2, 51, 3, 2, 2, 2, 2, 53, 3, 2, 2, 2, 2, 55, 3, 2, 2, 2, 2, 57, 3, 2, 2, 2, 2, 59, 3, 2, 2, 2, 2, 61, 3, 2, 2, 2, 3, 63, 3, 2, 2, 2, 5, 65, 3, 2, 2, 2, 7, 67, 3, 2, 2, 2, 9, 74, 3, 2, 2, 2, 11, 81, 3, 2, 2, 2, 13, 88, 3, 2, 2, 2, 15, 95, 3, 2, 2, 2, 17, 102, 3, 2, 2, 2, 19, 107, 3, 2, 2, 2, 21, 111, 3, 2, 2, 2, 23, 117, 3, 2, 2, 2, 25, 122, 3, 2, 2, 2, 27, 129, 3, 2, 2, 2, 29, 135, 3, 2, 2, 2, 31, 141, 3, 2, 2, 2, 33, 146, 3, 2, 2, 2, 35, 149, 3, 2, 2, 2, 37, 152, 3, 2, 2, 2, 39, 156, 3, 2, 2, 2, 41, 164, 3, 2, 2, 2, 43, 168, 3, 2, 2, 2, 45, 171, 3, 2, 2, 2, 47, 173, 3, 2, 2, 2, 49, 175, 3, 2, 2, 2, 51, 178, 3, 2, 2, 2, 53, 180, 3, 2, 2, 2, 55, 182, 3, 2, 2, 2, 57, 200, 3, 2, 2, 2, 59, 202, 3, 2, 2, 2, 61, 213, 3, 2, 2, 2, 63, 64, 7, 42, 2, 2, 64, 4, 3, 2, 2, 2, 65, 66, 7, 43, 2, 2, 66, 6, 3, 2, 2, 2, 67, 68, 7, 101, 2, 2, 68, 69, 7, 116, 2, 2, 69, 70, 7, 103, 2, 2, 70, 71, 7, 99, 2, 2, 71, 72, 7, 118, 2, 2, 72, 73, 7, 103, 2, 2, 73, 8, 3, 2, 2, 2, 74, 75, 7, 107, 2, 2, 75, 76, 7, 112, 2, 2, 76, 77, 7, 117, 2, 2, 77, 78, 7, 103, 2, 2, 78, 79, 7, 116, 2, 2, 79, 80, 7, 118, 2, 2, 80, 10, 3, 2, 2, 2, 81, 82, 7, 117, 2, 2, 82, 83, 7, 103, 2, 2, 83, 84, 7, 110, 2, 2, 84, 85, 7, 103, 2, 2, 85, 86, 7, 101, 2, 2, 86, 87, 7, 118, 2, 2, 87, 12, 3, 2, 2, 2, 88, 89, 7, 119, 2, 2, 89, 90, 7, 114, 2, 2, 90, 91, 7, 102, 2, 2, 91, 92, 7, 99, 2, 2, 92, 93, 7, 118, 2, 2, 93, 94, 7, 103, 2, 2, 94, 14, 3, 2, 2, 2, 95, 96, 7, 102, 2, 2, 96, 97, 7, 103, 2, 2, 97, 98, 7, 110, 2, 2, 98, 99, 7, 103, 2, 2, 99, 100, 7, 118, 2, 2, 100, 101, 7, 103, 2, 2, 101, 16, 3, 2, 2, 2, 102, 103, 7, 104, 2, 2, 103, 104, 7, 116, 2, 2, 104, 105, 7, 113, 2, 2, 105, 106, 7, 111, 2, 2, 106, 18, 3, 2, 2, 2, 107, 108, 7, 117, 2, 2, 108, 109, 7, 103, 2, 2, 109, 110, 7, 118, 2, 2, 110, 20, 3, 2, 2, 2, 111, 112, 7, 121, 2, 2, 112, 
113, 7, 106, 2, 2, 113, 114, 7, 103, 2, 2, 114, 115, 7, 116, 2, 2, 115, 116, 7, 103, 2, 2, 116, 22, 3, 2, 2, 2, 117, 118, 7, 107, 2, 2, 118, 119, 7, 112, 2, 2, 119, 120, 7, 118, 2, 2, 120, 121, 7, 113, 2, 2, 121, 24, 3, 2, 2, 2, 122, 123, 7, 120, 2, 2, 123, 124, 7, 99, 2, 2, 124, 125, 7, 110, 2, 2, 125, 126, 7, 119, 2, 2, 126, 127, 7, 103, 2, 2, 127, 128, 7, 117, 2, 2, 128, 26, 3, 2, 2, 2, 129, 130, 7, 118, 2, 2, 130, 131, 7, 99, 2, 2, 131, 132, 7, 100, 2, 2, 132, 133, 7, 110, 2, 2, 133, 134, 7, 103, 2, 2, 134, 28, 3, 2, 2, 2, 135, 136, 7, 107, 2, 2, 136, 137, 7, 112, 2, 2, 137, 138, 7, 102, 2, 2, 138, 139, 7, 103, 2, 2, 139, 140, 7, 122, 2, 2, 140, 30, 3, 2, 2, 2, 141, 142, 7, 120, 2, 2, 142, 143, 7, 107, 2, 2, 143, 144, 7, 103, 2, 2, 144, 145, 7, 121, 2, 2, 145, 32, 3, 2, 2, 2, 146, 147, 7, 99, 2, 2, 147, 148, 7, 117, 2, 2, 148, 34, 3, 2, 2, 2, 149, 150, 7, 113, 2, 2, 150, 151, 7, 112, 2, 2, 151, 36, 3, 2, 2, 2, 152, 153, 7, 107, 2, 2, 153, 154, 7, 112, 2, 2, 154, 155, 7, 118, 2, 2, 155, 38, 3, 2, 2, 2, 156, 157, 7, 120, 2, 2, 157, 158, 7, 99, 2, 2, 158, 159, 7, 116, 2, 2, 159, 160, 7, 101, 2, 2, 160, 161, 7, 106, 2, 2, 161, 162, 7, 99, 2, 2, 162, 163, 7, 116, 2, 2, 163, 40, 3, 2, 2, 2, 164, 165, 7, 99, 2, 2, 165, 166, 7, 112, 2, 2, 166, 167, 7, 102, 2, 2, 167, 42, 3, 2, 2, 2, 168, 169, 7, 113, 2, 2, 169, 170, 7, 116, 2, 2, 170, 44, 3, 2, 2, 2, 171, 172, 7, 44, 2, 2, 172, 46, 3, 2, 2, 2, 173, 174, 7, 63, 2, 2, 174, 48, 3, 2, 2, 2, 175, 176, 7, 35, 2, 2, 176, 177, 7, 63, 2, 2, 177, 50, 3, 2, 2, 2, 178, 179, 7, 46, 2, 2, 179, 52, 3, 2, 2, 2, 180, 181, 7, 61, 2, 2, 181, 54, 3, 2, 2, 2, 182, 186, 9, 2, 2, 2, 183, 185, 9, 3, 2, 2, 184, 183, 3, 2, 2, 2, 185, 188, 3, 2, 2, 2, 186, 184, 3, 2, 2, 2, 186, 187, 3, 2, 2, 2, 187, 56, 3, 2, 2, 2, 188, 186, 3, 2, 2, 2, 189, 201, 7, 50, 2, 2, 190, 192, 9, 4, 2, 2, 191, 190, 3, 2, 2, 2, 191, 192, 3, 2, 2, 2, 192, 193, 3, 2, 2, 2, 193, 197, 9, 5, 2, 2, 194, 196, 9, 6, 2, 2, 195, 194, 3, 2, 2, 2, 196, 199, 3, 2, 2, 2, 197, 195, 3, 
2, 2, 2, 197, 198, 3, 2, 2, 2, 198, 201, 3, 2, 2, 2, 199, 197, 3, 2, 2, 2, 200, 189, 3, 2, 2, 2, 200, 191, 3, 2, 2, 2, 201, 58, 3, 2, 2, 2, 202, 208, 7, 41, 2, 2, 203, 207, 10, 7, 2, 2, 204, 205, 7, 41, 2, 2, 205, 207, 7, 41, 2, 2, 206, 203, 3, 2, 2, 2, 206, 204, 3, 2, 2, 2, 207, 210, 3, 2, 2, 2, 208, 206, 3, 2, 2, 2, 208, 209, 3, 2, 2, 2, 209, 211, 3, 2, 2, 2, 210, 208, 3, 2, 2, 2, 211, 212, 7, 41, 2, 2, 212, 60, 3, 2, 2, 2, 213, 214, 9, 8, 2, 2, 214, 215, 3, 2, 2, 2, 215, 216, 8, 31, 2, 2, 216, 62, 3, 2, 2, 2, 9, 2, 186, 191, 197, 200, 206, 208, 3, 8, 2, 2] -------------------------------------------------------------------------------- /internal/parser/SimpleSqlLexer.tokens: -------------------------------------------------------------------------------- 1 | T__0=1 2 | T__1=2 3 | CREATE_=3 4 | INSERT_=4 5 | SELECT_=5 6 | UPDATE_=6 7 | DELETE_=7 8 | FROM_=8 9 | SET_=9 10 | WHERE_=10 11 | INTO_=11 12 | VALUES_=12 13 | TABLE_=13 14 | INDEX_=14 15 | VIEW_=15 16 | AS_=16 17 | ON_=17 18 | INT_=18 19 | VAR_CHAR_=19 20 | AND_=20 21 | OR_=21 22 | STAR=22 23 | EQUAL=23 24 | NOT_EQUAL=24 25 | COMMA=25 26 | SEMI_COLON=26 27 | IDENT=27 28 | INT_LITERAL=28 29 | STR_LITERAL=29 30 | SPACES=30 31 | '('=1 32 | ')'=2 33 | 'create'=3 34 | 'insert'=4 35 | 'select'=5 36 | 'update'=6 37 | 'delete'=7 38 | 'from'=8 39 | 'set'=9 40 | 'where'=10 41 | 'into'=11 42 | 'values'=12 43 | 'table'=13 44 | 'index'=14 45 | 'view'=15 46 | 'as'=16 47 | 'on'=17 48 | 'int'=18 49 | 'varchar'=19 50 | 'and'=20 51 | 'or'=21 52 | '*'=22 53 | '='=23 54 | '!='=24 55 | ','=25 56 | ';'=26 57 | -------------------------------------------------------------------------------- /internal/parser/ast.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | // "github.com/evanxg852000/simpledb/internal/record" 4 | 5 | const ( 6 | _ = iota 7 | INTEGER_TYPE 8 | STRING_TYPE 9 | ) 10 | 11 | type Literal struct { 12 | Value any 13 | } 14 | 15 | func (c *Literal) Type() 
int64 { 16 | if _, ok := c.Value.(string); ok { 17 | return STRING_TYPE 18 | } 19 | return INTEGER_TYPE 20 | } 21 | 22 | // AsInt returns the literal's value as an int. Integer literals produced by the parser hold an int64, which is converted here; a plain int Value is still accepted. It panics if Value holds neither. 23 | func (c *Literal) AsInt() int { 24 | if v, ok := c.Value.(int64); ok { 25 | return int(v) 26 | } 27 | return c.Value.(int) 28 | } 29 | 30 | func (c *Literal) AsString() string { 31 | return c.Value.(string) 32 | } 33 | 34 | // type FieldIdent struct { 35 | // name string 36 | // } 37 | 38 | type Condition struct { 39 | Left Term 40 | Op string 41 | Right Term 42 | } 43 | 44 | type Term struct { 45 | Left Expr 46 | Op string 47 | Right Expr 48 | } 49 | 50 | type Expr struct { 51 | Value any // FieldName or Literal 52 | } 53 | 54 | func (e *Expr) IsFieldName() bool { 55 | if _, ok := e.Value.(string); ok { 56 | return true 57 | } 58 | return false 59 | } 60 | 61 | func (e *Expr) AsFieldExpr() string { 62 | return e.Value.(string) 63 | } 64 | 65 | func (e *Expr) AsLiteralExpr() Literal { 66 | return e.Value.(Literal) 67 | } 68 | 69 | type InsertStmt struct { 70 | Table string 71 | Fields []string 72 | Values []Literal 73 | } 74 | 75 | type SelectStmt struct { 76 | Fields []string 77 | Tables []string 78 | Condition Condition 79 | } 80 | 81 | type UpdateExpr struct { 82 | Field string 83 | Value Expr 84 | } 85 | 86 | type UpdateStmt struct { 87 | Table string 88 | Exprs []UpdateExpr 89 | Condition Condition 90 | } 91 | 92 | type DeleteStmt struct { 93 | Table string 94 | Condition Condition 95 | } 96 | 97 | type TypeSpec struct { 98 | DataType int64 99 | Length int64 100 | } 101 | 102 | type FieldSpec struct { 103 | Name string 104 | Spec TypeSpec 105 | } 106 | 107 | type CreateTableStmt struct { 108 | Table string 109 | Fields []FieldSpec 110 | } 111 | 112 | type CreateViewStmt struct { 113 | Name string 114 | Query SelectStmt 115 | QueryStr string 116 | } 117 | 118 | type CreateIndexStmt struct { 119 | Name string 120 | Table string 121 | Field string 122 | } 123 | -------------------------------------------------------------------------------- /internal/parser/parser_test.go: 
-------------------------------------------------------------------------------- 1 | package parser_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/evanxg852000/simpledb/internal/parser" 7 | "github.com/evanxg852000/simpledb/internal/record" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestParseCreateTableStmt(t *testing.T) { 12 | assert := assert.New(t) 13 | input := "create table foo(a int, b varchar(4), c int)" 14 | ast := parser.ParseQuery(input) 15 | 16 | stmts := ast.([]any) 17 | assert.Equal(len(stmts), 1) 18 | 19 | createStmt := stmts[0].(parser.CreateTableStmt) 20 | assert.Equal(parser.CreateTableStmt{"foo", []parser.FieldSpec{ 21 | {"a", parser.TypeSpec{record.INTEGER_TYPE, 0}}, 22 | {"b", parser.TypeSpec{record.STRING_TYPE, 4}}, 23 | {"c", parser.TypeSpec{record.INTEGER_TYPE, 0}}, 24 | }}, createStmt) 25 | } 26 | 27 | func TestParseInsertStmt(t *testing.T) { 28 | assert := assert.New(t) 29 | input := "insert into foo(a, b) values (2, 'evan')" 30 | ast := parser.ParseQuery(input) 31 | 32 | stmts := ast.([]any) 33 | assert.Equal(len(stmts), 1) 34 | 35 | insertStmt := stmts[0].(parser.InsertStmt) 36 | assert.Equal(parser.InsertStmt{ 37 | "foo", 38 | []string{"a", "b"}, 39 | []parser.Literal{{int64(2)}, {"evan"}}, 40 | }, insertStmt) 41 | } 42 | 43 | func TestParseSelectStmt(t *testing.T) { 44 | assert := assert.New(t) 45 | input := "select a, b from foo where a=1" 46 | ast := parser.ParseQuery(input) 47 | 48 | stmts := ast.([]any) 49 | assert.Equal(len(stmts), 1) 50 | 51 | selectStmt := stmts[0].(parser.SelectStmt) 52 | assert.Equal(parser.SelectStmt{ 53 | []string{"a", "b"}, 54 | []string{"foo"}, 55 | parser.Condition{ 56 | Left: parser.Term{ 57 | parser.Expr{"a"}, 58 | "=", 59 | parser.Expr{parser.Literal{int64(1)}}, 60 | }, 61 | Op: "", 62 | Right: parser.Term{}, 63 | }, 64 | }, selectStmt) 65 | } 66 | 67 | func TestParseUpdateStmt(t *testing.T) { 68 | assert := assert.New(t) 69 | input := "update foo set a=2, b=1 where a=1 or b 
!= 2" 70 | ast := parser.ParseQuery(input) 71 | 72 | stmts := ast.([]any) 73 | assert.Equal(len(stmts), 1) 74 | 75 | updateStmt := stmts[0].(parser.UpdateStmt) 76 | assert.Equal(parser.UpdateStmt{ 77 | "foo", 78 | []parser.UpdateExpr{ 79 | {"a", parser.Expr{parser.Literal{int64(2)}}}, 80 | {"b", parser.Expr{parser.Literal{int64(1)}}}, 81 | }, 82 | parser.Condition{ 83 | Left: parser.Term{ 84 | parser.Expr{"a"}, 85 | "=", 86 | parser.Expr{parser.Literal{int64(1)}}, 87 | }, 88 | Op: "or", 89 | Right: parser.Term{ 90 | parser.Expr{"b"}, 91 | "!=", 92 | parser.Expr{parser.Literal{int64(2)}}, 93 | }, 94 | }, 95 | }, updateStmt) 96 | } 97 | 98 | func TestParseDeleteStmt(t *testing.T) { 99 | assert := assert.New(t) 100 | input := "delete from foo where a=23 and f!=100" 101 | ast := parser.ParseQuery(input) 102 | 103 | stmts := ast.([]any) 104 | assert.Equal(len(stmts), 1) 105 | 106 | deleteStmt := stmts[0].(parser.DeleteStmt) 107 | assert.Equal(parser.DeleteStmt{ 108 | "foo", 109 | parser.Condition{ 110 | Left: parser.Term{ 111 | parser.Expr{"a"}, 112 | "=", 113 | parser.Expr{parser.Literal{int64(23)}}, 114 | }, 115 | Op: "and", 116 | Right: parser.Term{ 117 | parser.Expr{"f"}, 118 | "!=", 119 | parser.Expr{parser.Literal{int64(100)}}, 120 | }, 121 | }, 122 | }, deleteStmt) 123 | } 124 | 125 | func TestParseCreateViewStmt(t *testing.T) { 126 | assert := assert.New(t) 127 | input := "create view view1 as select * from foo where a=23" 128 | ast := parser.ParseQuery(input) 129 | 130 | stmts := ast.([]any) 131 | assert.Equal(len(stmts), 1) 132 | 133 | createViewStmt := stmts[0].(parser.CreateViewStmt) 134 | assert.Equal(parser.CreateViewStmt{ 135 | "view1", 136 | parser.SelectStmt{ 137 | []string{"*"}, 138 | []string{"foo"}, 139 | parser.Condition{ 140 | Left: parser.Term{ 141 | parser.Expr{"a"}, 142 | "=", 143 | parser.Expr{parser.Literal{int64(23)}}, 144 | }, 145 | Op: "", 146 | Right: parser.Term{}, 147 | }, 148 | }, 149 | "select*fromfoowherea=23", 150 | }, 
createViewStmt) 151 | } 152 | 153 | func TestParseCreateIndexStmt(t *testing.T) { 154 | assert := assert.New(t) 155 | input := "create index index_b on foo (b)" 156 | ast := parser.ParseQuery(input) 157 | 158 | stmts := ast.([]any) 159 | assert.Equal(len(stmts), 1) 160 | 161 | createIndexStmt := stmts[0].(parser.CreateIndexStmt) 162 | assert.Equal(parser.CreateIndexStmt{"index_b", "foo", "b"}, createIndexStmt) 163 | } 164 | -------------------------------------------------------------------------------- /internal/parser/simplesql_base_visitor.go: -------------------------------------------------------------------------------- 1 | // Code generated from SimpleSql.g4 by ANTLR 4.7.2. DO NOT EDIT. 2 | 3 | package parser // SimpleSql 4 | import "github.com/antlr/antlr4/runtime/Go/antlr" 5 | 6 | type BaseSimpleSqlVisitor struct { 7 | *antlr.BaseParseTreeVisitor 8 | } 9 | 10 | func (v *BaseSimpleSqlVisitor) VisitParse(ctx *ParseContext) interface{} { 11 | return v.VisitChildren(ctx) 12 | } 13 | 14 | func (v *BaseSimpleSqlVisitor) VisitStatementList(ctx *StatementListContext) interface{} { 15 | return v.VisitChildren(ctx) 16 | } 17 | 18 | func (v *BaseSimpleSqlVisitor) VisitStatement(ctx *StatementContext) interface{} { 19 | return v.VisitChildren(ctx) 20 | } 21 | 22 | func (v *BaseSimpleSqlVisitor) VisitCreate_table_stmt(ctx *Create_table_stmtContext) interface{} { 23 | return v.VisitChildren(ctx) 24 | } 25 | 26 | func (v *BaseSimpleSqlVisitor) VisitField_specs(ctx *Field_specsContext) interface{} { 27 | return v.VisitChildren(ctx) 28 | } 29 | 30 | func (v *BaseSimpleSqlVisitor) VisitField_spec(ctx *Field_specContext) interface{} { 31 | return v.VisitChildren(ctx) 32 | } 33 | 34 | func (v *BaseSimpleSqlVisitor) VisitType_spec(ctx *Type_specContext) interface{} { 35 | return v.VisitChildren(ctx) 36 | } 37 | 38 | func (v *BaseSimpleSqlVisitor) VisitVarchar_spec(ctx *Varchar_specContext) interface{} { 39 | return v.VisitChildren(ctx) 40 | } 41 | 42 | func (v 
*BaseSimpleSqlVisitor) VisitInsert_stmt(ctx *Insert_stmtContext) interface{} { 43 | return v.VisitChildren(ctx) 44 | } 45 | 46 | func (v *BaseSimpleSqlVisitor) VisitConstant_list(ctx *Constant_listContext) interface{} { 47 | return v.VisitChildren(ctx) 48 | } 49 | 50 | func (v *BaseSimpleSqlVisitor) VisitSelect_stmt(ctx *Select_stmtContext) interface{} { 51 | return v.VisitChildren(ctx) 52 | } 53 | 54 | func (v *BaseSimpleSqlVisitor) VisitIdent_list(ctx *Ident_listContext) interface{} { 55 | return v.VisitChildren(ctx) 56 | } 57 | 58 | func (v *BaseSimpleSqlVisitor) VisitUpdate_stmt(ctx *Update_stmtContext) interface{} { 59 | return v.VisitChildren(ctx) 60 | } 61 | 62 | func (v *BaseSimpleSqlVisitor) VisitUpdate_expr_list(ctx *Update_expr_listContext) interface{} { 63 | return v.VisitChildren(ctx) 64 | } 65 | 66 | func (v *BaseSimpleSqlVisitor) VisitUpdate_expr(ctx *Update_exprContext) interface{} { 67 | return v.VisitChildren(ctx) 68 | } 69 | 70 | func (v *BaseSimpleSqlVisitor) VisitDelete_stmt(ctx *Delete_stmtContext) interface{} { 71 | return v.VisitChildren(ctx) 72 | } 73 | 74 | func (v *BaseSimpleSqlVisitor) VisitCreate_view_stmt(ctx *Create_view_stmtContext) interface{} { 75 | return v.VisitChildren(ctx) 76 | } 77 | 78 | func (v *BaseSimpleSqlVisitor) VisitCreate_index_stmt(ctx *Create_index_stmtContext) interface{} { 79 | return v.VisitChildren(ctx) 80 | } 81 | 82 | func (v *BaseSimpleSqlVisitor) VisitCondition(ctx *ConditionContext) interface{} { 83 | return v.VisitChildren(ctx) 84 | } 85 | 86 | func (v *BaseSimpleSqlVisitor) VisitTerm(ctx *TermContext) interface{} { 87 | return v.VisitChildren(ctx) 88 | } 89 | 90 | func (v *BaseSimpleSqlVisitor) VisitExpression(ctx *ExpressionContext) interface{} { 91 | return v.VisitChildren(ctx) 92 | } 93 | 94 | func (v *BaseSimpleSqlVisitor) VisitLiteral(ctx *LiteralContext) interface{} { 95 | return v.VisitChildren(ctx) 96 | } 97 | -------------------------------------------------------------------------------- 
/internal/parser/simplesql_lexer.go: -------------------------------------------------------------------------------- 1 | // Code generated from SimpleSql.g4 by ANTLR 4.7.2. DO NOT EDIT. 2 | 3 | package parser 4 | 5 | import ( 6 | "fmt" 7 | "unicode" 8 | 9 | "github.com/antlr/antlr4/runtime/Go/antlr" 10 | ) 11 | 12 | // Suppress unused import error 13 | var _ = fmt.Printf 14 | var _ = unicode.IsLetter 15 | 16 | var serializedLexerAtn = []uint16{ 17 | 3, 24715, 42794, 33075, 47597, 16764, 15335, 30598, 22884, 2, 32, 217, 18 | 8, 1, 4, 2, 9, 2, 4, 3, 9, 3, 4, 4, 9, 4, 4, 5, 9, 5, 4, 6, 9, 6, 4, 7, 19 | 9, 7, 4, 8, 9, 8, 4, 9, 9, 9, 4, 10, 9, 10, 4, 11, 9, 11, 4, 12, 9, 12, 20 | 4, 13, 9, 13, 4, 14, 9, 14, 4, 15, 9, 15, 4, 16, 9, 16, 4, 17, 9, 17, 4, 21 | 18, 9, 18, 4, 19, 9, 19, 4, 20, 9, 20, 4, 21, 9, 21, 4, 22, 9, 22, 4, 23, 22 | 9, 23, 4, 24, 9, 24, 4, 25, 9, 25, 4, 26, 9, 26, 4, 27, 9, 27, 4, 28, 9, 23 | 28, 4, 29, 9, 29, 4, 30, 9, 30, 4, 31, 9, 31, 3, 2, 3, 2, 3, 3, 3, 3, 3, 24 | 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 5, 3, 5, 3, 5, 3, 5, 3, 5, 3, 25 | 5, 3, 5, 3, 6, 3, 6, 3, 6, 3, 6, 3, 6, 3, 6, 3, 6, 3, 7, 3, 7, 3, 7, 3, 26 | 7, 3, 7, 3, 7, 3, 7, 3, 8, 3, 8, 3, 8, 3, 8, 3, 8, 3, 8, 3, 8, 3, 9, 3, 27 | 9, 3, 9, 3, 9, 3, 9, 3, 10, 3, 10, 3, 10, 3, 10, 3, 11, 3, 11, 3, 11, 3, 28 | 11, 3, 11, 3, 11, 3, 12, 3, 12, 3, 12, 3, 12, 3, 12, 3, 13, 3, 13, 3, 13, 29 | 3, 13, 3, 13, 3, 13, 3, 13, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 30 | 15, 3, 15, 3, 15, 3, 15, 3, 15, 3, 15, 3, 16, 3, 16, 3, 16, 3, 16, 3, 16, 31 | 3, 17, 3, 17, 3, 17, 3, 18, 3, 18, 3, 18, 3, 19, 3, 19, 3, 19, 3, 19, 3, 32 | 20, 3, 20, 3, 20, 3, 20, 3, 20, 3, 20, 3, 20, 3, 20, 3, 21, 3, 21, 3, 21, 33 | 3, 21, 3, 22, 3, 22, 3, 22, 3, 23, 3, 23, 3, 24, 3, 24, 3, 25, 3, 25, 3, 34 | 25, 3, 26, 3, 26, 3, 27, 3, 27, 3, 28, 3, 28, 7, 28, 185, 10, 28, 12, 28, 35 | 14, 28, 188, 11, 28, 3, 29, 3, 29, 5, 29, 192, 10, 29, 3, 29, 3, 29, 7, 36 | 29, 196, 10, 29, 12, 29, 14, 29, 199, 11, 29, 5, 29, 201, 10, 29, 
3, 30, 37 | 3, 30, 3, 30, 3, 30, 7, 30, 207, 10, 30, 12, 30, 14, 30, 210, 11, 30, 3, 38 | 30, 3, 30, 3, 31, 3, 31, 3, 31, 3, 31, 2, 2, 32, 3, 3, 5, 4, 7, 5, 9, 6, 39 | 11, 7, 13, 8, 15, 9, 17, 10, 19, 11, 21, 12, 23, 13, 25, 14, 27, 15, 29, 40 | 16, 31, 17, 33, 18, 35, 19, 37, 20, 39, 21, 41, 22, 43, 23, 45, 24, 47, 41 | 25, 49, 26, 51, 27, 53, 28, 55, 29, 57, 30, 59, 31, 61, 32, 3, 2, 9, 5, 42 | 2, 67, 92, 97, 97, 99, 124, 6, 2, 50, 59, 67, 92, 97, 97, 99, 124, 4, 2, 43 | 45, 45, 47, 47, 3, 2, 51, 59, 3, 2, 50, 59, 3, 2, 41, 41, 5, 2, 11, 12, 44 | 15, 15, 34, 34, 2, 222, 2, 3, 3, 2, 2, 2, 2, 5, 3, 2, 2, 2, 2, 7, 3, 2, 45 | 2, 2, 2, 9, 3, 2, 2, 2, 2, 11, 3, 2, 2, 2, 2, 13, 3, 2, 2, 2, 2, 15, 3, 46 | 2, 2, 2, 2, 17, 3, 2, 2, 2, 2, 19, 3, 2, 2, 2, 2, 21, 3, 2, 2, 2, 2, 23, 47 | 3, 2, 2, 2, 2, 25, 3, 2, 2, 2, 2, 27, 3, 2, 2, 2, 2, 29, 3, 2, 2, 2, 2, 48 | 31, 3, 2, 2, 2, 2, 33, 3, 2, 2, 2, 2, 35, 3, 2, 2, 2, 2, 37, 3, 2, 2, 2, 49 | 2, 39, 3, 2, 2, 2, 2, 41, 3, 2, 2, 2, 2, 43, 3, 2, 2, 2, 2, 45, 3, 2, 2, 50 | 2, 2, 47, 3, 2, 2, 2, 2, 49, 3, 2, 2, 2, 2, 51, 3, 2, 2, 2, 2, 53, 3, 2, 51 | 2, 2, 2, 55, 3, 2, 2, 2, 2, 57, 3, 2, 2, 2, 2, 59, 3, 2, 2, 2, 2, 61, 3, 52 | 2, 2, 2, 3, 63, 3, 2, 2, 2, 5, 65, 3, 2, 2, 2, 7, 67, 3, 2, 2, 2, 9, 74, 53 | 3, 2, 2, 2, 11, 81, 3, 2, 2, 2, 13, 88, 3, 2, 2, 2, 15, 95, 3, 2, 2, 2, 54 | 17, 102, 3, 2, 2, 2, 19, 107, 3, 2, 2, 2, 21, 111, 3, 2, 2, 2, 23, 117, 55 | 3, 2, 2, 2, 25, 122, 3, 2, 2, 2, 27, 129, 3, 2, 2, 2, 29, 135, 3, 2, 2, 56 | 2, 31, 141, 3, 2, 2, 2, 33, 146, 3, 2, 2, 2, 35, 149, 3, 2, 2, 2, 37, 152, 57 | 3, 2, 2, 2, 39, 156, 3, 2, 2, 2, 41, 164, 3, 2, 2, 2, 43, 168, 3, 2, 2, 58 | 2, 45, 171, 3, 2, 2, 2, 47, 173, 3, 2, 2, 2, 49, 175, 3, 2, 2, 2, 51, 178, 59 | 3, 2, 2, 2, 53, 180, 3, 2, 2, 2, 55, 182, 3, 2, 2, 2, 57, 200, 3, 2, 2, 60 | 2, 59, 202, 3, 2, 2, 2, 61, 213, 3, 2, 2, 2, 63, 64, 7, 42, 2, 2, 64, 4, 61 | 3, 2, 2, 2, 65, 66, 7, 43, 2, 2, 66, 6, 3, 2, 2, 2, 67, 68, 7, 101, 2, 62 | 2, 68, 69, 7, 116, 2, 2, 69, 70, 7, 103, 2, 2, 
70, 71, 7, 99, 2, 2, 71, 63 | 72, 7, 118, 2, 2, 72, 73, 7, 103, 2, 2, 73, 8, 3, 2, 2, 2, 74, 75, 7, 107, 64 | 2, 2, 75, 76, 7, 112, 2, 2, 76, 77, 7, 117, 2, 2, 77, 78, 7, 103, 2, 2, 65 | 78, 79, 7, 116, 2, 2, 79, 80, 7, 118, 2, 2, 80, 10, 3, 2, 2, 2, 81, 82, 66 | 7, 117, 2, 2, 82, 83, 7, 103, 2, 2, 83, 84, 7, 110, 2, 2, 84, 85, 7, 103, 67 | 2, 2, 85, 86, 7, 101, 2, 2, 86, 87, 7, 118, 2, 2, 87, 12, 3, 2, 2, 2, 88, 68 | 89, 7, 119, 2, 2, 89, 90, 7, 114, 2, 2, 90, 91, 7, 102, 2, 2, 91, 92, 7, 69 | 99, 2, 2, 92, 93, 7, 118, 2, 2, 93, 94, 7, 103, 2, 2, 94, 14, 3, 2, 2, 70 | 2, 95, 96, 7, 102, 2, 2, 96, 97, 7, 103, 2, 2, 97, 98, 7, 110, 2, 2, 98, 71 | 99, 7, 103, 2, 2, 99, 100, 7, 118, 2, 2, 100, 101, 7, 103, 2, 2, 101, 16, 72 | 3, 2, 2, 2, 102, 103, 7, 104, 2, 2, 103, 104, 7, 116, 2, 2, 104, 105, 7, 73 | 113, 2, 2, 105, 106, 7, 111, 2, 2, 106, 18, 3, 2, 2, 2, 107, 108, 7, 117, 74 | 2, 2, 108, 109, 7, 103, 2, 2, 109, 110, 7, 118, 2, 2, 110, 20, 3, 2, 2, 75 | 2, 111, 112, 7, 121, 2, 2, 112, 113, 7, 106, 2, 2, 113, 114, 7, 103, 2, 76 | 2, 114, 115, 7, 116, 2, 2, 115, 116, 7, 103, 2, 2, 116, 22, 3, 2, 2, 2, 77 | 117, 118, 7, 107, 2, 2, 118, 119, 7, 112, 2, 2, 119, 120, 7, 118, 2, 2, 78 | 120, 121, 7, 113, 2, 2, 121, 24, 3, 2, 2, 2, 122, 123, 7, 120, 2, 2, 123, 79 | 124, 7, 99, 2, 2, 124, 125, 7, 110, 2, 2, 125, 126, 7, 119, 2, 2, 126, 80 | 127, 7, 103, 2, 2, 127, 128, 7, 117, 2, 2, 128, 26, 3, 2, 2, 2, 129, 130, 81 | 7, 118, 2, 2, 130, 131, 7, 99, 2, 2, 131, 132, 7, 100, 2, 2, 132, 133, 82 | 7, 110, 2, 2, 133, 134, 7, 103, 2, 2, 134, 28, 3, 2, 2, 2, 135, 136, 7, 83 | 107, 2, 2, 136, 137, 7, 112, 2, 2, 137, 138, 7, 102, 2, 2, 138, 139, 7, 84 | 103, 2, 2, 139, 140, 7, 122, 2, 2, 140, 30, 3, 2, 2, 2, 141, 142, 7, 120, 85 | 2, 2, 142, 143, 7, 107, 2, 2, 143, 144, 7, 103, 2, 2, 144, 145, 7, 121, 86 | 2, 2, 145, 32, 3, 2, 2, 2, 146, 147, 7, 99, 2, 2, 147, 148, 7, 117, 2, 87 | 2, 148, 34, 3, 2, 2, 2, 149, 150, 7, 113, 2, 2, 150, 151, 7, 112, 2, 2, 88 | 151, 36, 3, 2, 2, 2, 152, 
153, 7, 107, 2, 2, 153, 154, 7, 112, 2, 2, 154, 89 | 155, 7, 118, 2, 2, 155, 38, 3, 2, 2, 2, 156, 157, 7, 120, 2, 2, 157, 158, 90 | 7, 99, 2, 2, 158, 159, 7, 116, 2, 2, 159, 160, 7, 101, 2, 2, 160, 161, 91 | 7, 106, 2, 2, 161, 162, 7, 99, 2, 2, 162, 163, 7, 116, 2, 2, 163, 40, 3, 92 | 2, 2, 2, 164, 165, 7, 99, 2, 2, 165, 166, 7, 112, 2, 2, 166, 167, 7, 102, 93 | 2, 2, 167, 42, 3, 2, 2, 2, 168, 169, 7, 113, 2, 2, 169, 170, 7, 116, 2, 94 | 2, 170, 44, 3, 2, 2, 2, 171, 172, 7, 44, 2, 2, 172, 46, 3, 2, 2, 2, 173, 95 | 174, 7, 63, 2, 2, 174, 48, 3, 2, 2, 2, 175, 176, 7, 35, 2, 2, 176, 177, 96 | 7, 63, 2, 2, 177, 50, 3, 2, 2, 2, 178, 179, 7, 46, 2, 2, 179, 52, 3, 2, 97 | 2, 2, 180, 181, 7, 61, 2, 2, 181, 54, 3, 2, 2, 2, 182, 186, 9, 2, 2, 2, 98 | 183, 185, 9, 3, 2, 2, 184, 183, 3, 2, 2, 2, 185, 188, 3, 2, 2, 2, 186, 99 | 184, 3, 2, 2, 2, 186, 187, 3, 2, 2, 2, 187, 56, 3, 2, 2, 2, 188, 186, 3, 100 | 2, 2, 2, 189, 201, 7, 50, 2, 2, 190, 192, 9, 4, 2, 2, 191, 190, 3, 2, 2, 101 | 2, 191, 192, 3, 2, 2, 2, 192, 193, 3, 2, 2, 2, 193, 197, 9, 5, 2, 2, 194, 102 | 196, 9, 6, 2, 2, 195, 194, 3, 2, 2, 2, 196, 199, 3, 2, 2, 2, 197, 195, 103 | 3, 2, 2, 2, 197, 198, 3, 2, 2, 2, 198, 201, 3, 2, 2, 2, 199, 197, 3, 2, 104 | 2, 2, 200, 189, 3, 2, 2, 2, 200, 191, 3, 2, 2, 2, 201, 58, 3, 2, 2, 2, 105 | 202, 208, 7, 41, 2, 2, 203, 207, 10, 7, 2, 2, 204, 205, 7, 41, 2, 2, 205, 106 | 207, 7, 41, 2, 2, 206, 203, 3, 2, 2, 2, 206, 204, 3, 2, 2, 2, 207, 210, 107 | 3, 2, 2, 2, 208, 206, 3, 2, 2, 2, 208, 209, 3, 2, 2, 2, 209, 211, 3, 2, 108 | 2, 2, 210, 208, 3, 2, 2, 2, 211, 212, 7, 41, 2, 2, 212, 60, 3, 2, 2, 2, 109 | 213, 214, 9, 8, 2, 2, 214, 215, 3, 2, 2, 2, 215, 216, 8, 31, 2, 2, 216, 110 | 62, 3, 2, 2, 2, 9, 2, 186, 191, 197, 200, 206, 208, 3, 8, 2, 2, 111 | } 112 | 113 | var lexerDeserializer = antlr.NewATNDeserializer(nil) 114 | var lexerAtn = lexerDeserializer.DeserializeFromUInt16(serializedLexerAtn) 115 | 116 | var lexerChannelNames = []string{ 117 | "DEFAULT_TOKEN_CHANNEL", "HIDDEN", 118 
| } 119 | 120 | var lexerModeNames = []string{ 121 | "DEFAULT_MODE", 122 | } 123 | 124 | var lexerLiteralNames = []string{ 125 | "", "'('", "')'", "'create'", "'insert'", "'select'", "'update'", "'delete'", 126 | "'from'", "'set'", "'where'", "'into'", "'values'", "'table'", "'index'", 127 | "'view'", "'as'", "'on'", "'int'", "'varchar'", "'and'", "'or'", "'*'", 128 | "'='", "'!='", "','", "';'", 129 | } 130 | 131 | var lexerSymbolicNames = []string{ 132 | "", "", "", "CREATE_", "INSERT_", "SELECT_", "UPDATE_", "DELETE_", "FROM_", 133 | "SET_", "WHERE_", "INTO_", "VALUES_", "TABLE_", "INDEX_", "VIEW_", "AS_", 134 | "ON_", "INT_", "VAR_CHAR_", "AND_", "OR_", "STAR", "EQUAL", "NOT_EQUAL", 135 | "COMMA", "SEMI_COLON", "IDENT", "INT_LITERAL", "STR_LITERAL", "SPACES", 136 | } 137 | 138 | var lexerRuleNames = []string{ 139 | "T__0", "T__1", "CREATE_", "INSERT_", "SELECT_", "UPDATE_", "DELETE_", 140 | "FROM_", "SET_", "WHERE_", "INTO_", "VALUES_", "TABLE_", "INDEX_", "VIEW_", 141 | "AS_", "ON_", "INT_", "VAR_CHAR_", "AND_", "OR_", "STAR", "EQUAL", "NOT_EQUAL", 142 | "COMMA", "SEMI_COLON", "IDENT", "INT_LITERAL", "STR_LITERAL", "SPACES", 143 | } 144 | 145 | type SimpleSqlLexer struct { 146 | *antlr.BaseLexer 147 | channelNames []string 148 | modeNames []string 149 | // TODO: EOF string 150 | } 151 | 152 | var lexerDecisionToDFA = make([]*antlr.DFA, len(lexerAtn.DecisionToState)) 153 | 154 | func init() { 155 | for index, ds := range lexerAtn.DecisionToState { 156 | lexerDecisionToDFA[index] = antlr.NewDFA(ds, index) 157 | } 158 | } 159 | 160 | func NewSimpleSqlLexer(input antlr.CharStream) *SimpleSqlLexer { 161 | 162 | l := new(SimpleSqlLexer) 163 | 164 | l.BaseLexer = antlr.NewBaseLexer(input) 165 | l.Interpreter = antlr.NewLexerATNSimulator(l, lexerAtn, lexerDecisionToDFA, antlr.NewPredictionContextCache()) 166 | 167 | l.channelNames = lexerChannelNames 168 | l.modeNames = lexerModeNames 169 | l.RuleNames = lexerRuleNames 170 | l.LiteralNames = lexerLiteralNames 171 | 
l.SymbolicNames = lexerSymbolicNames 172 | l.GrammarFileName = "SimpleSql.g4" 173 | // TODO: l.EOF = antlr.TokenEOF 174 | 175 | return l 176 | } 177 | 178 | // SimpleSqlLexer tokens. 179 | const ( 180 | SimpleSqlLexerT__0 = 1 181 | SimpleSqlLexerT__1 = 2 182 | SimpleSqlLexerCREATE_ = 3 183 | SimpleSqlLexerINSERT_ = 4 184 | SimpleSqlLexerSELECT_ = 5 185 | SimpleSqlLexerUPDATE_ = 6 186 | SimpleSqlLexerDELETE_ = 7 187 | SimpleSqlLexerFROM_ = 8 188 | SimpleSqlLexerSET_ = 9 189 | SimpleSqlLexerWHERE_ = 10 190 | SimpleSqlLexerINTO_ = 11 191 | SimpleSqlLexerVALUES_ = 12 192 | SimpleSqlLexerTABLE_ = 13 193 | SimpleSqlLexerINDEX_ = 14 194 | SimpleSqlLexerVIEW_ = 15 195 | SimpleSqlLexerAS_ = 16 196 | SimpleSqlLexerON_ = 17 197 | SimpleSqlLexerINT_ = 18 198 | SimpleSqlLexerVAR_CHAR_ = 19 199 | SimpleSqlLexerAND_ = 20 200 | SimpleSqlLexerOR_ = 21 201 | SimpleSqlLexerSTAR = 22 202 | SimpleSqlLexerEQUAL = 23 203 | SimpleSqlLexerNOT_EQUAL = 24 204 | SimpleSqlLexerCOMMA = 25 205 | SimpleSqlLexerSEMI_COLON = 26 206 | SimpleSqlLexerIDENT = 27 207 | SimpleSqlLexerINT_LITERAL = 28 208 | SimpleSqlLexerSTR_LITERAL = 29 209 | SimpleSqlLexerSPACES = 30 210 | ) 211 | -------------------------------------------------------------------------------- /internal/parser/simplesql_visitor.go: -------------------------------------------------------------------------------- 1 | // Code generated from SimpleSql.g4 by ANTLR 4.7.2. DO NOT EDIT. 2 | 3 | package parser // SimpleSql 4 | import "github.com/antlr/antlr4/runtime/Go/antlr" 5 | 6 | // A complete Visitor for a parse tree produced by SimpleSqlParser. 7 | type SimpleSqlVisitor interface { 8 | antlr.ParseTreeVisitor 9 | 10 | // Visit a parse tree produced by SimpleSqlParser#parse. 11 | VisitParse(ctx *ParseContext) interface{} 12 | 13 | // Visit a parse tree produced by SimpleSqlParser#statementList. 14 | VisitStatementList(ctx *StatementListContext) interface{} 15 | 16 | // Visit a parse tree produced by SimpleSqlParser#statement. 
17 | VisitStatement(ctx *StatementContext) interface{} 18 | 19 | // Visit a parse tree produced by SimpleSqlParser#create_table_stmt. 20 | VisitCreate_table_stmt(ctx *Create_table_stmtContext) interface{} 21 | 22 | // Visit a parse tree produced by SimpleSqlParser#field_specs. 23 | VisitField_specs(ctx *Field_specsContext) interface{} 24 | 25 | // Visit a parse tree produced by SimpleSqlParser#field_spec. 26 | VisitField_spec(ctx *Field_specContext) interface{} 27 | 28 | // Visit a parse tree produced by SimpleSqlParser#type_spec. 29 | VisitType_spec(ctx *Type_specContext) interface{} 30 | 31 | // Visit a parse tree produced by SimpleSqlParser#varchar_spec. 32 | VisitVarchar_spec(ctx *Varchar_specContext) interface{} 33 | 34 | // Visit a parse tree produced by SimpleSqlParser#insert_stmt. 35 | VisitInsert_stmt(ctx *Insert_stmtContext) interface{} 36 | 37 | // Visit a parse tree produced by SimpleSqlParser#constant_list. 38 | VisitConstant_list(ctx *Constant_listContext) interface{} 39 | 40 | // Visit a parse tree produced by SimpleSqlParser#select_stmt. 41 | VisitSelect_stmt(ctx *Select_stmtContext) interface{} 42 | 43 | // Visit a parse tree produced by SimpleSqlParser#ident_list. 44 | VisitIdent_list(ctx *Ident_listContext) interface{} 45 | 46 | // Visit a parse tree produced by SimpleSqlParser#update_stmt. 47 | VisitUpdate_stmt(ctx *Update_stmtContext) interface{} 48 | 49 | // Visit a parse tree produced by SimpleSqlParser#update_expr_list. 50 | VisitUpdate_expr_list(ctx *Update_expr_listContext) interface{} 51 | 52 | // Visit a parse tree produced by SimpleSqlParser#update_expr. 53 | VisitUpdate_expr(ctx *Update_exprContext) interface{} 54 | 55 | // Visit a parse tree produced by SimpleSqlParser#delete_stmt. 56 | VisitDelete_stmt(ctx *Delete_stmtContext) interface{} 57 | 58 | // Visit a parse tree produced by SimpleSqlParser#create_view_stmt. 
59 | VisitCreate_view_stmt(ctx *Create_view_stmtContext) interface{} 60 | 61 | // Visit a parse tree produced by SimpleSqlParser#create_index_stmt. 62 | VisitCreate_index_stmt(ctx *Create_index_stmtContext) interface{} 63 | 64 | // Visit a parse tree produced by SimpleSqlParser#condition. 65 | VisitCondition(ctx *ConditionContext) interface{} 66 | 67 | // Visit a parse tree produced by SimpleSqlParser#term. 68 | VisitTerm(ctx *TermContext) interface{} 69 | 70 | // Visit a parse tree produced by SimpleSqlParser#expression. 71 | VisitExpression(ctx *ExpressionContext) interface{} 72 | 73 | // Visit a parse tree produced by SimpleSqlParser#literal. 74 | VisitLiteral(ctx *LiteralContext) interface{} 75 | } 76 | -------------------------------------------------------------------------------- /internal/parser/visitor.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "strconv" 5 | "strings" 6 | 7 | "github.com/antlr/antlr4/runtime/Go/antlr" 8 | ) 9 | // ParseQuery lexes and parses input and returns the AST built by SimpleSqlAstBuilder, i.e. the []any of statement values produced by VisitStatementList (presumably reached through the generated Accept dispatch to VisitParse). NOTE(review): no error listener is attached here, so syntax errors are left to the ANTLR runtime defaults (presumably console reporting plus recovery) and the builder may then meet missing children; confirm and consider failing fast. 10 | func ParseQuery(input string) interface{} { 11 | is := antlr.NewInputStream(input) 12 | lexer := NewSimpleSqlLexer(is) 13 | tokens := antlr.NewCommonTokenStream(lexer, antlr.TokenDefaultChannel) 14 | parser := NewSimpleSqlParser(tokens) 15 | visitor := NewSimpleSqlAstBuilder() 16 | return parser.Parse().Accept(visitor) 17 | } 18 | // SimpleSqlAstBuilder converts a SimpleSql parse tree into this package's AST values (CreateTableStmt, InsertStmt, SelectStmt, ...). 19 | type SimpleSqlAstBuilder struct { 20 | *antlr.BaseParseTreeVisitor 21 | } 22 | // NewSimpleSqlAstBuilder returns a builder whose embedded *antlr.BaseParseTreeVisitor is left nil. 23 | func NewSimpleSqlAstBuilder() *SimpleSqlAstBuilder { 24 | return &SimpleSqlAstBuilder{} 25 | } 26 | // VisitParse builds the AST of the first statement list only (StatementList(0)). 27 | func (v *SimpleSqlAstBuilder) VisitParse(ctx *ParseContext) interface{} { 28 | stmtListCtx := ctx.StatementList(0) 29 | return v.VisitStatementList(stmtListCtx.(*StatementListContext)) 30 | } 31 | // VisitStatementList returns a []any holding the AST of each statement, in source order. 32 | func (v *SimpleSqlAstBuilder) VisitStatementList(ctx *StatementListContext) interface{} { 33 | statements := make([]any, 0) 34 | nodes := ctx.AllStatement() 35 | for _, node := range nodes { 36 | statements = append(statements,
v.VisitStatement(node.(*StatementContext))) 37 | } 38 | return statements 39 | } 40 | 41 | func (v *SimpleSqlAstBuilder) VisitStatement(ctx *StatementContext) interface{} { 42 | if stmt := ctx.Create_table_stmt(); stmt != nil { 43 | return v.VisitCreate_table_stmt(stmt.(*Create_table_stmtContext)) 44 | } 45 | 46 | if stmt := ctx.Insert_stmt(); stmt != nil { 47 | return v.VisitInsert_stmt(stmt.(*Insert_stmtContext)) 48 | } 49 | 50 | if stmt := ctx.Select_stmt(); stmt != nil { 51 | return v.VisitSelect_stmt(stmt.(*Select_stmtContext)) 52 | } 53 | 54 | if stmt := ctx.Update_stmt(); stmt != nil { 55 | return v.VisitUpdate_stmt(stmt.(*Update_stmtContext)) 56 | } 57 | 58 | if stmt := ctx.Delete_stmt(); stmt != nil { 59 | return v.VisitDelete_stmt(stmt.(*Delete_stmtContext)) 60 | } 61 | 62 | if stmt := ctx.Create_view_stmt(); stmt != nil { 63 | return v.VisitCreate_view_stmt(stmt.(*Create_view_stmtContext)) 64 | } 65 | 66 | if stmt := ctx.Create_index_stmt(); stmt != nil { 67 | return v.VisitCreate_index_stmt(stmt.(*Create_index_stmtContext)) 68 | } 69 | 70 | return v.VisitChildren(ctx) 71 | } 72 | 73 | func (v *SimpleSqlAstBuilder) VisitCreate_table_stmt(ctx *Create_table_stmtContext) interface{} { 74 | if ctx.CREATE_() == nil { 75 | return nil 76 | } 77 | 78 | if ctx.TABLE_() == nil { 79 | return nil 80 | } 81 | 82 | tableName := ctx.IDENT().GetText() 83 | fieldSpecs := v.VisitField_specs(ctx.Field_specs().(*Field_specsContext)) 84 | return CreateTableStmt{tableName, fieldSpecs.([]FieldSpec)} 85 | } 86 | 87 | func (v *SimpleSqlAstBuilder) VisitField_specs(ctx *Field_specsContext) interface{} { 88 | listCtxs := ctx.AllField_spec() 89 | fieldSpec := v.VisitField_spec(listCtxs[0].(*Field_specContext)).(FieldSpec) 90 | fieldSpecs := []FieldSpec{fieldSpec} 91 | for i, fieldSpecCtx := range listCtxs[1:] { 92 | if ctx.COMMA(i) == nil { 93 | return nil 94 | } 95 | fieldSpec := v.VisitField_spec(fieldSpecCtx.(*Field_specContext)).(FieldSpec) 96 | fieldSpecs = append(fieldSpecs, 
fieldSpec) 97 | } 98 | return fieldSpecs 99 | } 100 | 101 | func (v *SimpleSqlAstBuilder) VisitField_spec(ctx *Field_specContext) interface{} { 102 | field := ctx.IDENT().GetText() 103 | typeSpec := v.VisitType_spec(ctx.Type_spec().(*Type_specContext)) 104 | return FieldSpec{field, typeSpec.(TypeSpec)} 105 | } 106 | 107 | func (v *SimpleSqlAstBuilder) VisitType_spec(ctx *Type_specContext) interface{} { 108 | if ctx.INT_() != nil { 109 | return TypeSpec{INTEGER_TYPE, 0} 110 | } 111 | 112 | typeSpec := v.VisitVarchar_spec(ctx.Varchar_spec().(*Varchar_specContext)) 113 | return typeSpec 114 | } 115 | 116 | func (v *SimpleSqlAstBuilder) VisitVarchar_spec(ctx *Varchar_specContext) interface{} { 117 | if ctx.VAR_CHAR_() == nil { 118 | return nil 119 | } 120 | length, err := strconv.ParseInt(ctx.INT_LITERAL().GetText(), 10, 64) 121 | if err != nil { 122 | return nil 123 | } 124 | return TypeSpec{STRING_TYPE, length} 125 | } 126 | 127 | func (v *SimpleSqlAstBuilder) VisitInsert_stmt(ctx *Insert_stmtContext) interface{} { 128 | if ctx.INSERT_() == nil { 129 | return nil 130 | } 131 | 132 | if ctx.INTO_() == nil { 133 | return nil 134 | } 135 | 136 | tableName := ctx.IDENT().GetText() 137 | 138 | fields := v.VisitIdent_list(ctx.Ident_list().(*Ident_listContext)).([]string) 139 | fieldList := make([]string, 0, len(fields)) 140 | fieldList = append(fieldList, fields...) 
141 | 142 | if ctx.VALUES_() == nil { 143 | return nil 144 | } 145 | 146 | valueList := v.VisitConstant_list(ctx.Constant_list().(*Constant_listContext)).([]Literal) 147 | return InsertStmt{tableName, fieldList, valueList} 148 | } 149 | 150 | func (v *SimpleSqlAstBuilder) VisitConstant_list(ctx *Constant_listContext) interface{} { 151 | listCtxs := ctx.AllLiteral() 152 | literal := v.VisitLiteral(listCtxs[0].(*LiteralContext)).(Literal) 153 | literalList := []Literal{literal} 154 | for i, litCtx := range listCtxs[1:] { 155 | if ctx.COMMA(i) == nil { 156 | return nil 157 | } 158 | literal := v.VisitLiteral(litCtx.(*LiteralContext)).(Literal) 159 | literalList = append(literalList, literal) 160 | } 161 | return literalList 162 | } 163 | 164 | func (v *SimpleSqlAstBuilder) VisitSelect_stmt(ctx *Select_stmtContext) interface{} { 165 | if ctx.SELECT_() == nil { 166 | return nil 167 | } 168 | 169 | fieldList := make([]string, 0) 170 | identIdx := 0 171 | if startSelect := ctx.STAR(); startSelect != nil { 172 | fieldList = append(fieldList, "*") 173 | } else { 174 | identListCtx := ctx.Ident_list(identIdx) 175 | fields := v.VisitIdent_list(identListCtx.(*Ident_listContext)).([]string) 176 | fieldList = append(fieldList, fields...) 
177 | identIdx += 1 178 | } 179 | 180 | if ctx.FROM_() == nil { 181 | return nil 182 | } 183 | 184 | identListCtx := ctx.Ident_list(identIdx) 185 | tableList := v.VisitIdent_list(identListCtx.(*Ident_listContext)).([]string) 186 | 187 | condition := Condition{} 188 | if ctx.WHERE_() != nil { 189 | condition = v.VisitCondition(ctx.Condition().(*ConditionContext)).(Condition) 190 | } 191 | 192 | return SelectStmt{fieldList, tableList, condition} 193 | } 194 | 195 | func (v *SimpleSqlAstBuilder) VisitIdent_list(ctx *Ident_listContext) interface{} { 196 | list := ctx.AllIDENT() 197 | identList := []string{list[0].GetText()} 198 | for i, item := range list[1:] { 199 | if ctx.COMMA(i) == nil { 200 | return nil 201 | } 202 | identList = append(identList, item.GetText()) 203 | } 204 | return identList 205 | } 206 | 207 | func (v *SimpleSqlAstBuilder) VisitUpdate_stmt(ctx *Update_stmtContext) interface{} { 208 | if ctx.UPDATE_() == nil { 209 | return nil 210 | } 211 | 212 | tableName := ctx.IDENT().GetText() 213 | 214 | if ctx.SET_() == nil { 215 | return nil 216 | } 217 | 218 | updateExprCtx := ctx.Update_expr_list() 219 | updateExprs := v.VisitUpdate_expr_list(updateExprCtx.(*Update_expr_listContext)).([]UpdateExpr) 220 | 221 | condition := Condition{} 222 | if ctx.WHERE_() != nil { 223 | condition = v.VisitCondition(ctx.Condition().(*ConditionContext)).(Condition) 224 | } 225 | return UpdateStmt{tableName, updateExprs, condition} 226 | } 227 | 228 | func (v *SimpleSqlAstBuilder) VisitUpdate_expr_list(ctx *Update_expr_listContext) interface{} { 229 | updateExprCtxs := ctx.AllUpdate_expr() 230 | 231 | expr := v.VisitUpdate_expr(updateExprCtxs[0].(*Update_exprContext)).(UpdateExpr) 232 | updateExprList := []UpdateExpr{expr} 233 | for i, exprCtx := range updateExprCtxs[1:] { 234 | if ctx.COMMA(i) == nil { 235 | return nil 236 | } 237 | expr := v.VisitUpdate_expr(exprCtx.(*Update_exprContext)).(UpdateExpr) 238 | updateExprList = append(updateExprList, expr) 239 | } 240 | 241 
| return updateExprList 242 | } 243 | 244 | func (v *SimpleSqlAstBuilder) VisitUpdate_expr(ctx *Update_exprContext) interface{} { 245 | field := ctx.IDENT().GetText() 246 | if ctx.EQUAL() == nil { 247 | return nil 248 | } 249 | 250 | expr := v.VisitExpression(ctx.Expression().(*ExpressionContext)) 251 | return UpdateExpr{field, expr.(Expr)} 252 | } 253 | 254 | func (v *SimpleSqlAstBuilder) VisitDelete_stmt(ctx *Delete_stmtContext) interface{} { 255 | tableName := ctx.IDENT().GetText() 256 | 257 | condition := Condition{} 258 | if ctx.WHERE_() != nil { 259 | condition = v.VisitCondition(ctx.Condition().(*ConditionContext)).(Condition) 260 | } 261 | return DeleteStmt{tableName, condition} 262 | } 263 | 264 | func (v *SimpleSqlAstBuilder) VisitCreate_view_stmt(ctx *Create_view_stmtContext) interface{} { 265 | if ctx.CREATE_() == nil { 266 | return nil 267 | } 268 | 269 | if ctx.VIEW_() == nil { 270 | return nil 271 | } 272 | 273 | tableName := ctx.IDENT().GetText() 274 | 275 | if ctx.AS_() == nil { 276 | return nil 277 | } 278 | 279 | selectStmt := v.VisitSelect_stmt(ctx.Select_stmt().(*Select_stmtContext)) 280 | selectStmtText := ctx.Select_stmt().GetText() 281 | return CreateViewStmt{tableName, selectStmt.(SelectStmt), selectStmtText} 282 | } 283 | 284 | func (v *SimpleSqlAstBuilder) VisitCreate_index_stmt(ctx *Create_index_stmtContext) interface{} { 285 | if ctx.CREATE_() == nil { 286 | return nil 287 | } 288 | 289 | if ctx.INDEX_() == nil { 290 | return nil 291 | } 292 | 293 | indexName := ctx.IDENT(0).GetText() 294 | 295 | if ctx.ON_() == nil { 296 | return nil 297 | } 298 | 299 | tableName := ctx.IDENT(1).GetText() 300 | field := ctx.IDENT(2).GetText() 301 | return CreateIndexStmt{indexName, tableName, field} 302 | } 303 | 304 | func (v *SimpleSqlAstBuilder) VisitCondition(ctx *ConditionContext) interface{} { 305 | termCtx := ctx.Term(0) 306 | left := v.VisitTerm(termCtx.(*TermContext)) 307 | 308 | if orTerm := ctx.OR_(); orTerm != nil { 309 | termCtx = 
ctx.Term(1) 310 | right := v.VisitTerm(termCtx.(*TermContext)) 311 | return Condition{left.(Term), "or", right.(Term)} 312 | } 313 | 314 | if andTerm := ctx.AND_(); andTerm != nil { 315 | termCtx = ctx.Term(1) 316 | right := v.VisitTerm(termCtx.(*TermContext)) 317 | return Condition{left.(Term), "and", right.(Term)} 318 | } 319 | return Condition{left.(Term), "", Term{}} 320 | } 321 | 322 | func (v *SimpleSqlAstBuilder) VisitTerm(ctx *TermContext) interface{} { 323 | lhs := ctx.left.Accept(v) 324 | op := ctx.operator.GetText() 325 | rhs := ctx.right.Accept(v) 326 | return Term{lhs.(Expr), op, rhs.(Expr)} 327 | } 328 | 329 | func (v *SimpleSqlAstBuilder) VisitExpression(ctx *ExpressionContext) interface{} { 330 | if expr := ctx.IDENT(); expr != nil { 331 | return Expr{expr.GetText()} 332 | } 333 | literal := ctx.Literal().Accept(v) 334 | return Expr{literal} 335 | } 336 | 337 | func (v *SimpleSqlAstBuilder) VisitLiteral(ctx *LiteralContext) interface{} { 338 | if intLit := ctx.INT_LITERAL(); intLit != nil { 339 | intValue, _ := strconv.ParseInt(intLit.GetText(), 10, 64) 340 | return Literal{intValue} 341 | } 342 | strLit := ctx.STR_LITERAL().GetText() 343 | return Literal{strings.Trim(strLit, "'")} 344 | } 345 | -------------------------------------------------------------------------------- /internal/plan/basic_query_planner.go: -------------------------------------------------------------------------------- 1 | package plan 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/evanxg852000/simpledb/internal/metadata" 7 | "github.com/evanxg852000/simpledb/internal/parser" 8 | "github.com/evanxg852000/simpledb/internal/query" 9 | "github.com/evanxg852000/simpledb/internal/tx/recovery" 10 | ) 11 | 12 | // The simplest, most naive query planner possible. 
13 | type BasicQueryPlanner struct { 14 | mdtManager *metadata.MetadataManager 15 | } 16 | 17 | // NewBasicQueryPlanner returns a query planner that resolves tables and views through mdtManager. 18 | func NewBasicQueryPlanner(mdtManager *metadata.MetadataManager) *BasicQueryPlanner { 19 | return &BasicQueryPlanner{mdtManager} 20 | } 21 | 22 | // CreatePlan creates a query plan as follows. It first takes 23 | // the product of all tables and views; it then selects on the predicate; 24 | // and finally it projects on the field list. Views are planned recursively 25 | // by re-parsing their stored definition. It panics if a view definition 26 | // cannot be fetched or if the statement names no table. 27 | func (bqp *BasicQueryPlanner) CreatePlan(selectStmt parser.SelectStmt, tx *recovery.Transaction) Plan { 28 | // Step 1: Create a plan for each mentioned table or view. 29 | plans := make([]Plan, 0, len(selectStmt.Tables)) 30 | for _, tableName := range selectStmt.Tables { 31 | viewDef, err := bqp.mdtManager.GetViewDef(tableName, tx) 32 | if err != nil { 33 | panic(fmt.Sprintf("error fetching view definition for %q: %v", tableName, err)) 34 | } 35 | if viewDef != "" { 36 | // Recursively plan the view. 37 | ast := parser.ParseQuery(viewDef) 38 | stmts := ast.([]any) 39 | viewStmt := stmts[0].(parser.SelectStmt) 40 | plans = append(plans, bqp.CreatePlan(viewStmt, tx)) 41 | } else { 42 | plan := NewTablePlan(tx, tableName, bqp.mdtManager) 43 | plans = append(plans, plan) 44 | } 45 | } 46 | if len(plans) == 0 { 47 | // Fail with a clear message rather than an index-out-of-range panic on plans[0]. 48 | panic("select statement names no table") 49 | } 50 | 51 | // Step 2: Create the product of all table plans 52 | plan := plans[0] 53 | for _, nextPlan := range plans[1:] { 54 | plan = NewProductPlan(plan, nextPlan) 55 | } 56 | 57 | // Step 3: Add a selection plan for the predicate 58 | predicate := query.NewPredicate(selectStmt.Condition) 59 | plan = NewSelectPlan(plan, predicate) 60 | 61 | // Step 4: Project on the field names 62 | plan = NewProjectPlan(plan, selectStmt.Fields) 63 | 64 | return plan 65 | } 66 | -------------------------------------------------------------------------------- /internal/plan/basic_update_planner.go: -------------------------------------------------------------------------------- 1 | package plan 2 | 3 | import ( 4 | "github.com/evanxg852000/simpledb/internal/metadata" 5 |
"github.com/evanxg852000/simpledb/internal/parser" 6 | "github.com/evanxg852000/simpledb/internal/query" 7 | "github.com/evanxg852000/simpledb/internal/record" 8 | "github.com/evanxg852000/simpledb/internal/tx/recovery" 9 | ) 10 | 11 | // The basic planner for SQL update statements. 12 | type BasicUpdatePlanner struct { 13 | mdtManager *metadata.MetadataManager 14 | } 15 | 16 | func NewBasicUpdatePlanner(mdtManager *metadata.MetadataManager) *BasicUpdatePlanner { 17 | return &BasicUpdatePlanner{mdtManager} 18 | } 19 | 20 | func (bup *BasicUpdatePlanner) ExecuteInsert(stmt parser.InsertStmt, tx *recovery.Transaction) int64 { 21 | plan := NewTablePlan(tx, stmt.Table, bup.mdtManager) 22 | updateScan := plan.Open().(query.UpdateScan) 23 | 24 | updateScan.Insert() 25 | for i, fieldName := range stmt.Fields { 26 | value := query.NewConstant(stmt.Values[i].Value) 27 | updateScan.SetValue(fieldName, value) 28 | } 29 | updateScan.Close() 30 | return 1 31 | } 32 | 33 | func (bup *BasicUpdatePlanner) ExecuteDelete(stmt parser.DeleteStmt, tx *recovery.Transaction) int64 { 34 | var plan Plan 35 | plan = NewTablePlan(tx, stmt.Table, bup.mdtManager) 36 | plan = NewSelectPlan(plan, query.NewPredicate(stmt.Condition)) 37 | updateScan := plan.Open().(query.UpdateScan) 38 | count := 0 39 | for updateScan.Next() { 40 | updateScan.Delete() 41 | count += 1 42 | } 43 | updateScan.Close() 44 | return int64(count) 45 | } 46 | 47 | func (bup *BasicUpdatePlanner) ExecuteModify(stmt parser.UpdateStmt, tx *recovery.Transaction) int64 { 48 | var plan Plan 49 | plan = NewTablePlan(tx, stmt.Table, bup.mdtManager) 50 | plan = NewSelectPlan(plan, query.NewPredicate(stmt.Condition)) 51 | updateScan := plan.Open().(query.UpdateScan) 52 | count := 0 53 | for updateScan.Next() { 54 | for _, updateExpr := range stmt.Exprs { 55 | expr := query.NewExpression(updateExpr.Value) 56 | value := expr.Evaluate(updateScan) 57 | updateScan.SetValue(updateExpr.Field, value) 58 | } 59 | count += 1 60 | } 61 | 
updateScan.Close() 62 | return int64(count) 63 | } 64 | 65 | func (bup *BasicUpdatePlanner) ExecuteCreateTable(stmt parser.CreateTableStmt, tx *recovery.Transaction) int64 { 66 | schema := record.NewSchema() 67 | for _, fieldDesc := range stmt.Fields { 68 | fieldSpec := fieldDesc.Spec 69 | schema.AddField(fieldDesc.Name, fieldSpec.DataType, fieldSpec.Length) 70 | } 71 | err := bup.mdtManager.CreateTable(stmt.Table, schema, tx) 72 | if err != nil { 73 | panic(err) 74 | } 75 | return 0 76 | } 77 | 78 | func (bup *BasicUpdatePlanner) ExecuteCreateView(stmt parser.CreateViewStmt, tx *recovery.Transaction) int64 { 79 | err := bup.mdtManager.CreateView(stmt.Name, stmt.QueryStr, tx) 80 | if err != nil { 81 | panic(err) 82 | } 83 | return 0 84 | } 85 | 86 | func (bup *BasicUpdatePlanner) ExecuteCreateIndex(stmt parser.CreateIndexStmt, tx *recovery.Transaction) int64 { 87 | err := bup.mdtManager.CreateIndex(stmt.Name, stmt.Table, stmt.Field, tx) 88 | if err != nil { 89 | panic(err) 90 | } 91 | return 0 92 | } 93 | -------------------------------------------------------------------------------- /internal/plan/plan.go: -------------------------------------------------------------------------------- 1 | package plan 2 | 3 | import ( 4 | "github.com/evanxg852000/simpledb/internal/query" 5 | "github.com/evanxg852000/simpledb/internal/record" 6 | ) 7 | 8 | // The interface implemented by each query plan. 9 | // There is a Plan class for each relational algebra operator. 10 | type Plan interface { 11 | // Opens a scan corresponding to this plan. 12 | // The scan will be positioned before its first record. 13 | Open() query.Scan 14 | 15 | // Returns an estimate of the number of block accesses 16 | // that will occur when the scan is read to completion. 17 | BlockAccessed() int64 18 | 19 | // Returns an estimate of the number of records 20 | // in the query's output table. 
21 | RecordsOutput() int64 22 | 23 | // Returns an estimate of the number of distinct values 24 | // for the specified field in the query's output table. 25 | DistinctValues(string) int64 26 | 27 | // Returns the schema of the query. 28 | Schema() record.Schema 29 | } 30 | -------------------------------------------------------------------------------- /internal/plan/planner.go: -------------------------------------------------------------------------------- 1 | package plan 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/evanxg852000/simpledb/internal/parser" 7 | "github.com/evanxg852000/simpledb/internal/tx/recovery" 8 | ) 9 | 10 | type Planner struct { 11 | queryPlanner QueryPlanner 12 | updatePlanner UpdatePlanner 13 | } 14 | 15 | // The object that executes SQL statements. 16 | func NewPlanner(queryPlanner QueryPlanner, updatePlanner UpdatePlanner) *Planner { 17 | return &Planner{queryPlanner, updatePlanner} 18 | } 19 | 20 | // Parses an SQL statement, 21 | // - returns a plan for data selection queries 22 | // - execute and returns affected rows for modification queries 23 | func (planner *Planner) ExecuteQuery(queryStr string, tx *recovery.Transaction) (any, error) { 24 | defer func() { 25 | err := recover() 26 | if err != nil { 27 | fmt.Println("ERR:", err) 28 | } 29 | }() 30 | 31 | ast := parser.ParseQuery(queryStr) 32 | sqlStmt := ast.([]any)[0] 33 | err := planner.VerifyStatement(sqlStmt) 34 | if err != nil { 35 | return nil, err 36 | } 37 | 38 | switch stmt := sqlStmt.(type) { 39 | case parser.SelectStmt: 40 | return planner.queryPlanner.CreatePlan(stmt, tx), nil 41 | case parser.InsertStmt: 42 | return planner.updatePlanner.ExecuteInsert(stmt, tx), nil 43 | case parser.UpdateStmt: 44 | return planner.updatePlanner.ExecuteModify(stmt, tx), nil 45 | case parser.DeleteStmt: 46 | return planner.updatePlanner.ExecuteDelete(stmt, tx), nil 47 | case parser.CreateTableStmt: 48 | return planner.updatePlanner.ExecuteCreateTable(stmt, tx), nil 49 | case 
parser.CreateViewStmt: 50 | return planner.updatePlanner.ExecuteCreateView(stmt, tx), nil 51 | case parser.CreateIndexStmt: 52 | return planner.updatePlanner.ExecuteCreateIndex(stmt, tx), nil 53 | } 54 | return nil, fmt.Errorf("unknown SQL statement") 55 | } 56 | 57 | func (planner *Planner) VerifyStatement(sqlStmt any) error { 58 | //TODO: verify statement and return error if any 59 | return nil 60 | } 61 | -------------------------------------------------------------------------------- /internal/plan/product_plan.go: -------------------------------------------------------------------------------- 1 | package plan 2 | 3 | import ( 4 | "github.com/evanxg852000/simpledb/internal/query" 5 | "github.com/evanxg852000/simpledb/internal/record" 6 | ) 7 | 8 | type ProductPlan struct { 9 | left Plan 10 | right Plan 11 | schema *record.Schema 12 | } 13 | 14 | // Creates a new product node in the query tree, 15 | // having the two specified sub-queries. 16 | func NewProductPlan(left, right Plan) *ProductPlan { 17 | schema := record.NewSchema() 18 | schema.AddAll(left.Schema()) 19 | schema.AddAll(right.Schema()) 20 | return &ProductPlan{left, right, schema} 21 | } 22 | 23 | // Creates a product scan for this query. 24 | func (pp *ProductPlan) Open() query.Scan { 25 | leftScan := pp.left.Open() 26 | rightScan := pp.right.Open() 27 | return query.NewProductScan(leftScan, rightScan) 28 | } 29 | 30 | // Estimates the number of block accesses in the product. 31 | // The formula is: 32 | // B(product(p1,p2)) = B(p1) + R(p1)*B(p2) 33 | func (pp *ProductPlan) BlockAccessed() int64 { 34 | return pp.left.BlockAccessed() + (pp.left.BlockAccessed() * pp.right.BlockAccessed()) 35 | } 36 | 37 | // Estimates the number of output records in the product. 
38 | // The formula is: 39 | // R(product(p1,p2)) = R(p1)*R(p2) 40 | func (pp *ProductPlan) RecordsOutput() int64 { 41 | return pp.left.RecordsOutput() * pp.right.RecordsOutput() 42 | } 43 | 44 | // Estimates the distinct number of field values in the product. 45 | // Since the product does not increase or decrease field values, 46 | // the estimate is the same as in the appropriate underlying query. 47 | func (pp *ProductPlan) DistinctValues(fieldName string) int64 { 48 | leftSchema := pp.left.Schema() 49 | if leftSchema.HasField(fieldName) { 50 | return pp.left.DistinctValues(fieldName) 51 | } 52 | return pp.right.DistinctValues(fieldName) 53 | } 54 | 55 | // Returns the schema of the product, 56 | // which is the union of the schemas of the underlying queries. 57 | func (pp *ProductPlan) Schema() record.Schema { 58 | return *pp.schema 59 | } 60 | -------------------------------------------------------------------------------- /internal/plan/project_plan.go: -------------------------------------------------------------------------------- 1 | package plan 2 | 3 | import ( 4 | "github.com/evanxg852000/simpledb/internal/query" 5 | "github.com/evanxg852000/simpledb/internal/record" 6 | ) 7 | 8 | type ProjectPlan struct { 9 | plan Plan 10 | schema *record.Schema 11 | } 12 | 13 | // Creates a new project node in the query tree, 14 | // having the specified subquery and field list; a leading "*" keeps every subquery field. 15 | func NewProjectPlan(plan Plan, fields []string) *ProjectPlan { 16 | schema := record.NewSchema() 17 | if len(fields) > 0 && fields[0] == "*" { 18 | schema.AddAll(plan.Schema()) 19 | fields = fields[1:] 20 | } 21 | for _, field := range fields { 22 | schema.Add(field, plan.Schema()) 23 | } 24 | return &ProjectPlan{plan, schema} 25 | } 26 | 27 | // Creates a project scan for this query.
28 | func (pp *ProjectPlan) Open() query.Scan { 29 | scan := pp.plan.Open() 30 | return query.NewProjectScan(scan, pp.schema.Fields()) 31 | } 32 | 33 | // Estimates the number of block accesses in the projection, 34 | // which is the same as in the underlying query. 35 | func (pp *ProjectPlan) BlockAccessed() int64 { 36 | return pp.plan.BlockAccessed() 37 | } 38 | 39 | // Estimates the number of output records in the projection, 40 | // which is the same as in the underlying query. 41 | func (pp *ProjectPlan) RecordsOutput() int64 { 42 | return pp.plan.RecordsOutput() 43 | } 44 | 45 | // Estimates the number of distinct field values 46 | // in the projection, 47 | // which is the same as in the underlying query. 48 | func (pp *ProjectPlan) DistinctValues(fieldName string) int64 { 49 | return pp.plan.DistinctValues(fieldName) 50 | } 51 | 52 | // Returns the schema of the projection, 53 | // which is taken from the field list. 54 | func (pp *ProjectPlan) Schema() record.Schema { 55 | return *pp.schema 56 | } 57 | -------------------------------------------------------------------------------- /internal/plan/query_planner.go: -------------------------------------------------------------------------------- 1 | package plan 2 | 3 | import ( 4 | "github.com/evanxg852000/simpledb/internal/parser" 5 | "github.com/evanxg852000/simpledb/internal/tx/recovery" 6 | ) 7 | 8 | // The interface implemented by planners for 9 | // the SQL select statement. 10 | type QueryPlanner interface { 11 | // Creates a plan for the parsed query.
12 | CreatePlan(selectStmt parser.SelectStmt, tx *recovery.Transaction) Plan 13 | } 14 | -------------------------------------------------------------------------------- /internal/plan/select_plan.go: -------------------------------------------------------------------------------- 1 | package plan 2 | 3 | import ( 4 | "github.com/evanxg852000/simpledb/internal/query" 5 | "github.com/evanxg852000/simpledb/internal/record" 6 | ) 7 | 8 | // The Plan class corresponding to the select 9 | // relational algebra operator. 10 | type SelectPlan struct { 11 | plan Plan 12 | predicate *query.Predicate 13 | } 14 | 15 | // Creates a new select node in the query tree, 16 | // having the specified subquery and predicate. 17 | func NewSelectPlan(plan Plan, predicate *query.Predicate) *SelectPlan { 18 | return &SelectPlan{plan, predicate} 19 | } 20 | 21 | // Creates a select scan for this query. 22 | func (sp *SelectPlan) Open() query.Scan { 23 | scan := sp.plan.Open() 24 | return query.NewSelectScan(scan, sp.predicate) 25 | } 26 | 27 | // Estimates the number of block accesses in the selection, 28 | // which is the same as in the underlying query. 29 | func (sp *SelectPlan) BlockAccessed() int64 { 30 | return sp.plan.BlockAccessed() 31 | } 32 | 33 | // Estimates the number of output records of the selection: 34 | // the underlying plan's estimate divided by ReductionFactor. 35 | func (sp *SelectPlan) RecordsOutput() int64 { 36 | return sp.plan.RecordsOutput() / ReductionFactor(sp.predicate, sp.plan) 37 | } 38 | 39 | // Estimates the number of distinct field values in the selection. 40 | // NOTE: currently a placeholder that always returns 1 (see the TODO in the body). 41 | func (sp *SelectPlan) DistinctValues(fieldName string) int64 { 42 | //TODO: missing optimal query planner 43 | return 1 44 | } 45 | 46 | // Returns the schema of the selection, 47 | // which is the same as in the underlying query.
48 | func (sp *SelectPlan) Schema() record.Schema { 49 | return sp.plan.Schema() 50 | } 51 | // ReductionFactor estimates how much pred shrinks p's output; placeholder returning 1 (no reduction) until an optimizing planner exists. 52 | func ReductionFactor(pred *query.Predicate, p Plan) int64 { 53 | //TODO: missing optimal query planner 54 | return 1 55 | } 56 | -------------------------------------------------------------------------------- /internal/plan/table_plan.go: -------------------------------------------------------------------------------- 1 | package plan 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/evanxg852000/simpledb/internal/metadata" 7 | "github.com/evanxg852000/simpledb/internal/query" 8 | "github.com/evanxg852000/simpledb/internal/record" 9 | "github.com/evanxg852000/simpledb/internal/tx/recovery" 10 | ) 11 | 12 | // The Plan class corresponding to a table. 13 | type TablePlan struct { 14 | tableName string 15 | tx *recovery.Transaction 16 | layout *record.Layout 17 | statInfo metadata.StatInfo 18 | } 19 | 20 | // Creates a leaf node in the query tree corresponding 21 | // to the specified table; panics if the table's layout cannot be loaded. 22 | func NewTablePlan(tx *recovery.Transaction, tableName string, mdtManager *metadata.MetadataManager) *TablePlan { 23 | layout, err := mdtManager.GetLayout(tableName, tx) 24 | if err != nil { 25 | panic(fmt.Sprintf("could not retrieve table `%s` layout: %v", tableName, err)) 26 | } 27 | statInfo := mdtManager.GetStatInfo(tableName, layout, tx) 28 | return &TablePlan{ 29 | tableName, 30 | tx, 31 | layout, 32 | statInfo, 33 | } 34 | } 35 | 36 | // Creates a table scan for this query; panics if the scan cannot be opened. 37 | func (tblPlan *TablePlan) Open() query.Scan { 38 | scan, err := record.NewTableScan(tblPlan.tx, tblPlan.tableName, tblPlan.layout) 39 | if err != nil { 40 | panic(fmt.Sprintf("could not create table scan for `%s`: %v", tblPlan.tableName, err)) 41 | } 42 | return any(scan).(query.Scan) 43 | } 44 | 45 | // Estimates the number of block accesses for the table, 46 | // which is obtainable from the statistics manager.
47 | func (tblPlan *TablePlan) BlockAccessed() int64 { 48 | return tblPlan.statInfo.BlockAccessed() 49 | } 50 | 51 | // Estimates the number of records in the table, 52 | // which is obtainable from the statistics manager. 53 | func (tblPlan *TablePlan) RecordsOutput() int64 { 54 | return tblPlan.statInfo.RecordsOutput() 55 | } 56 | 57 | // Estimates the number of distinct field values in the table, 58 | // which is obtainable from the statistics manager. 59 | func (tblPlan *TablePlan) DistinctValues(fieldName string) int64 { 60 | return tblPlan.statInfo.DistinctValues(fieldName) 61 | } 62 | 63 | // Determines the schema of the table, 64 | // which is obtainable from the catalog manager. 65 | func (tblPlan *TablePlan) Schema() record.Schema { 66 | return *tblPlan.layout.Schema 67 | } 68 | -------------------------------------------------------------------------------- /internal/plan/update_planner.go: -------------------------------------------------------------------------------- 1 | package plan 2 | 3 | import ( 4 | "github.com/evanxg852000/simpledb/internal/parser" 5 | "github.com/evanxg852000/simpledb/internal/tx/recovery" 6 | ) 7 | 8 | // The interface implemented by the planners 9 | // for SQL insert, delete, modify, and create table/view/index statements. 10 | type UpdatePlanner interface { 11 | 12 | // Executes the specified insert statement, and 13 | // returns the number of affected records. 14 | ExecuteInsert(stmt parser.InsertStmt, tx *recovery.Transaction) int64 15 | 16 | // Executes the specified delete statement, and 17 | // returns the number of affected records. 18 | ExecuteDelete(stmt parser.DeleteStmt, tx *recovery.Transaction) int64 19 | 20 | // Executes the specified modify statement, and 21 | // returns the number of affected records. 22 | ExecuteModify(stmt parser.UpdateStmt, tx *recovery.Transaction) int64 23 | 24 | // Executes the specified create table statement, and 25 | // returns the number of affected records.
26 | ExecuteCreateTable(stmt parser.CreateTableStmt, tx *recovery.Transaction) int64 27 | 28 | // Executes the specified create view statement, and 29 | // returns the number of affected records. 30 | ExecuteCreateView(stmt parser.CreateViewStmt, tx *recovery.Transaction) int64 31 | 32 | // Executes the specified create index statement, and 33 | // returns the number of affected records. 34 | ExecuteCreateIndex(stmt parser.CreateIndexStmt, tx *recovery.Transaction) int64 35 | } 36 | -------------------------------------------------------------------------------- /internal/query/constant.go: -------------------------------------------------------------------------------- 1 | package query 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | // Constant wraps a single value; AsInt, AsString and Equals handle int64 and string payloads. 7 | type Constant struct { 8 | value any 9 | } 10 | // NewConstant wraps value without checking its type. 11 | func NewConstant(value any) Constant { 12 | return Constant{value} 13 | } 14 | // AsInt returns the value as int64; the unchecked type assertion panics for any other type. 15 | func (c *Constant) AsInt() int64 { 16 | return c.value.(int64) 17 | } 18 | // AsString returns the value as string; the unchecked type assertion panics for any other type. 19 | func (c *Constant) AsString() string { 20 | return c.value.(string) 21 | } 22 | // String formats the value with fmt.Sprint. NOTE(review): pointer receiver, so fmt only uses it for *Constant, not Constant values. 23 | func (c *Constant) String() string { 24 | return fmt.Sprint(c.value) 25 | } 26 | // Equals reports whether both constants hold the same type (int64 or string) and value; any other type is never equal. 27 | func (c *Constant) Equals(other Constant) bool { 28 | switch c.value.(type) { 29 | case int64: 30 | if other_value, ok := other.value.(int64); ok { 31 | return c.AsInt() == other_value 32 | } 33 | case string: 34 | if other_value, ok := other.value.(string); ok { 35 | return c.AsString() == other_value 36 | } 37 | } 38 | return false 39 | } 40 | -------------------------------------------------------------------------------- /internal/query/expression.go: -------------------------------------------------------------------------------- 1 | package query 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/evanxg852000/simpledb/internal/parser" 7 | ) 8 | // Expression wraps a parsed parser.Expr, which is either a field reference or a literal. 9 | type Expression struct { 10 | inner parser.Expr 11 | } 12 | 13 | func NewExpression(inner parser.Expr) *Expression { 14 | return &Expression{inner} 15 | } 16 | 17 | // Evaluate the expression with respect to the 18
| // current record of the specified scan. 19 | func (expr *Expression) Evaluate(scan Scan) Constant { 20 | if !expr.inner.IsFieldName() { 21 | return NewConstant(expr.inner.AsLiteralExpr().Value) 22 | } 23 | 24 | return scan.GetValue(expr.inner.AsFieldExpr()) 25 | } 26 | 27 | // Return true if the expression is a field reference. 28 | func (expr *Expression) IsFieldName() bool { 29 | return expr.inner.IsFieldName() 30 | } 31 | 32 | // Return the constant corresponding to a literal expression. 33 | // NOTE(review): behavior for a field-reference expression depends on 34 | // parser.Expr.AsLiteralExpr, which is not checked here. 35 | func (expr *Expression) AsConstant() Constant { 36 | return NewConstant(expr.inner.AsLiteralExpr().Value) 37 | } 38 | 39 | // Return the field name corresponding to a field-reference expression. 40 | // NOTE(review): behavior for a literal expression depends on 41 | // parser.Expr.AsFieldExpr, which is not checked here. 42 | func (expr *Expression) AsFieldName() string { 43 | return expr.inner.AsFieldExpr() 44 | } 45 | // toString renders a literal's value with %v, or the field name for a field reference. 46 | func (expr *Expression) toString() string { 47 | if !expr.inner.IsFieldName() { 48 | return fmt.Sprintf("%v", expr.inner.AsLiteralExpr().Value) 49 | } 50 | return expr.inner.AsFieldExpr() 51 | } 52 | -------------------------------------------------------------------------------- /internal/query/predicate.go: -------------------------------------------------------------------------------- 1 | package query 2 | 3 | import "github.com/evanxg852000/simpledb/internal/parser" 4 | 5 | // A predicate is a Boolean combination of terms.
6 | type Predicate struct { 7 | Condition parser.Condition 8 | } 9 | // NewPredicate wraps a parsed condition; a zero-value condition is always satisfied. 10 | func NewPredicate(condition parser.Condition) *Predicate { 11 | return &Predicate{condition} 12 | } 13 | // IsSatisfied evaluates the condition (one term, or two terms joined by "and"/"or") against the scan's current record; any other Op yields false. 14 | func (pred *Predicate) IsSatisfied(scan Scan) bool { 15 | if pred.Condition == (parser.Condition{}) { 16 | return true 17 | } 18 | 19 | left := EvaluateTerm(scan, pred.Condition.Left) 20 | if pred.Condition.Right == (parser.Term{}) { 21 | return left 22 | } 23 | 24 | right := EvaluateTerm(scan, pred.Condition.Right) 25 | 26 | switch pred.Condition.Op { 27 | case "and": 28 | return left && right 29 | case "or": 30 | return left || right 31 | } 32 | return false 33 | } 34 | // EvaluateTerm compares the term's two expressions; only "=" and "!=" are supported, any other Op yields false. 35 | func EvaluateTerm(scan Scan, term parser.Term) bool { 36 | left := EvaluateExpr(scan, term.Left) 37 | right := EvaluateExpr(scan, term.Right) 38 | switch term.Op { 39 | case "=": 40 | return left.Equals(right) 41 | case "!=": 42 | return !left.Equals(right) 43 | } 44 | return false 45 | } 46 | // EvaluateExpr resolves a field reference against the scan's current record, or wraps a literal as a Constant. 47 | func EvaluateExpr(scan Scan, expr parser.Expr) Constant { 48 | if expr.IsFieldName() { 49 | return scan.GetValue(expr.AsFieldExpr()) 50 | } 51 | return NewConstant(expr.AsLiteralExpr().Value) 52 | } 53 | -------------------------------------------------------------------------------- /internal/query/product_scan.go: -------------------------------------------------------------------------------- 1 | package query 2 | 3 | type ProductScan struct { 4 | left, right Scan 5 | hasLeft bool // true while left is positioned on a record; false once the LHS is empty/exhausted 6 | } 7 | 8 | // Create a product scan having the two underlying scans. 9 | func NewProductScan(left Scan, right Scan) *ProductScan { 10 | scan := &ProductScan{left: left, right: right} 11 | scan.BeforeFirst() 12 | return scan 13 | } 14 | 15 | // Position the scan before its first record. 16 | // In particular, the LHS scan is positioned at 17 | // its first record, and the RHS scan 18 | // is positioned before its first record.
19 | func (ps *ProductScan) BeforeFirst() { 20 | ps.left.BeforeFirst() 21 | ps.hasLeft = ps.left.Next() 22 | ps.right.BeforeFirst() 23 | } 24 | 25 | // Move the scan to the next record. 26 | // The method moves to the next RHS record, if possible. 27 | // Otherwise, it moves to the next LHS record and the 28 | // first RHS record. 29 | // If there are no more LHS records (or the LHS was empty), the method returns false. 30 | func (ps *ProductScan) Next() bool { 31 | if ps.hasLeft && ps.right.Next() { 32 | return true 33 | } 34 | ps.right.BeforeFirst() 35 | ps.hasLeft = ps.hasLeft && ps.left.Next() 36 | return ps.hasLeft && ps.right.Next() 37 | } 38 | 39 | // Return the integer value of the specified field. 40 | // The value is obtained from whichever scan 41 | // contains the field. 42 | func (ps *ProductScan) GetInt(fieldName string) int64 { 43 | if ps.left.HasField(fieldName) { 44 | return ps.left.GetInt(fieldName) 45 | } 46 | return ps.right.GetInt(fieldName) 47 | } 48 | 49 | // Returns the string value of the specified field. 50 | // The value is obtained from whichever scan 51 | // contains the field. 52 | func (ps *ProductScan) GetString(fieldName string) string { 53 | if ps.left.HasField(fieldName) { 54 | return ps.left.GetString(fieldName) 55 | } 56 | return ps.right.GetString(fieldName) 57 | } 58 | 59 | // Return the value of the specified field. 60 | // The value is obtained from whichever scan 61 | // contains the field. 62 | func (ps *ProductScan) GetValue(fieldName string) Constant { 63 | if ps.left.HasField(fieldName) { 64 | return ps.left.GetValue(fieldName) 65 | } 66 | return ps.right.GetValue(fieldName) 67 | } 68 | 69 | // Returns true if the specified field is in 70 | // either of the underlying scans.
71 | func (ps *ProductScan) HasField(fieldName string) bool { 72 | return ps.left.HasField(fieldName) || ps.right.HasField(fieldName) 73 | } 74 | // Close closes both underlying scans. 75 | func (ps *ProductScan) Close() { 76 | ps.left.Close() 77 | ps.right.Close() 78 | } 79 | -------------------------------------------------------------------------------- /internal/query/project_scan.go: -------------------------------------------------------------------------------- 1 | package query 2 | 3 | import ( 4 | "fmt" 5 | "slices" 6 | ) 7 | 8 | // The scan class corresponding to the project relational 9 | // algebra operator. 10 | // All methods except HasField delegate their work to the 11 | // underlying scan; the getters panic for fields outside the projection. 12 | type ProjectScan struct { 13 | scan Scan 14 | fields []string 15 | } 16 | 17 | // Create a project scan having the specified 18 | // underlying scan and field list. 19 | func NewProjectScan(scan Scan, fields []string) *ProjectScan { 20 | return &ProjectScan{scan, fields} 21 | } 22 | // BeforeFirst delegates to the underlying scan. 23 | func (ps *ProjectScan) BeforeFirst() { 24 | ps.scan.BeforeFirst() 25 | } 26 | // Next delegates to the underlying scan. 27 | func (ps *ProjectScan) Next() bool { 28 | return ps.scan.Next() 29 | } 30 | // GetInt reads an integer field from the underlying scan; panics if the field is not projected. 31 | func (ps *ProjectScan) GetInt(fieldName string) int64 { 32 | if ps.HasField(fieldName) { 33 | return ps.scan.GetInt(fieldName) 34 | } 35 | panic(fmt.Sprintf("field `%v` not found.", fieldName)) 36 | } 37 | // GetString reads a string field from the underlying scan; panics if the field is not projected. 38 | func (ps *ProjectScan) GetString(fieldName string) string { 39 | if ps.HasField(fieldName) { 40 | return ps.scan.GetString(fieldName) 41 | } 42 | panic(fmt.Sprintf("field `%v` not found.", fieldName)) 43 | } 44 | // GetValue reads a field as a Constant from the underlying scan; panics if the field is not projected. 45 | func (ps *ProjectScan) GetValue(fieldName string) Constant { 46 | if ps.HasField(fieldName) { 47 | return ps.scan.GetValue(fieldName) 48 | } 49 | panic(fmt.Sprintf("field `%v` not found.", fieldName)) 50 | } 51 | // HasField reports whether fieldName is in the projected field list. 52 | func (ps *ProjectScan) HasField(fieldName string) bool { 53 | return slices.Contains(ps.fields, fieldName) 54 | } 55 | // Close closes the underlying scan. 56 | func (ps *ProjectScan) Close() { 57 | ps.scan.Close() 58 | } 59 |
-------------------------------------------------------------------------------- /internal/query/rid.go: -------------------------------------------------------------------------------- 1 | package query 2 | 3 | import "fmt" 4 | // RID pairs a block number with a slot index within that block (presumably locating one record). 5 | type RID struct { 6 | BlockNum int64 7 | Slot int64 8 | } 9 | // NewRID builds an RID from a block number and a slot. 10 | func NewRID(block_num int64, slot int64) RID { 11 | return RID{BlockNum: block_num, Slot: slot} 12 | } 13 | // Equals reports whether both BlockNum and Slot match. 14 | func (rId *RID) Equals(other RID) bool { 15 | return rId.BlockNum == other.BlockNum && rId.Slot == other.Slot 16 | } 17 | // String renders the RID as {block: N, slot: M}. 18 | func (rId *RID) String() string { 19 | return fmt.Sprintf("{block: %d, slot: %d}", rId.BlockNum, rId.Slot) 20 | } 21 | -------------------------------------------------------------------------------- /internal/query/scan.go: -------------------------------------------------------------------------------- 1 | package query 2 | 3 | // The interface will be implemented by each query scan. 4 | // There is a Scan class for each relational 5 | // algebra operator. 6 | type Scan interface { 7 | // Position the scan before its first record. A 8 | // subsequent call to Next() will return the first record. 9 | BeforeFirst() 10 | 11 | // Move the scan to the next record; returns false when there is no next record. 12 | Next() bool 13 | 14 | // Return the value of the specified integer field 15 | // in the current record. 16 | GetInt(fieldName string) int64 17 | 18 | // Return the value of the specified string field 19 | // in the current record. 20 | GetString(fieldName string) string 21 | 22 | // Return the value of the specified field in the current record. 23 | // The value is expressed as a Constant. 24 | GetValue(fieldName string) Constant 25 | 26 | // Return true if the scan has the specified field.
27 | HasField(fieldName string) bool 28 | 29 | // Close the scan and its sub-scans, if any. 30 | Close() 31 | } 32 | -------------------------------------------------------------------------------- /internal/query/select_scan.go: -------------------------------------------------------------------------------- 1 | package query 2 | 3 | // The scan class corresponding to the select relational 4 | // algebra operator. 5 | // All methods except Next delegate their work to the underlying scan 6 | type SelectScan struct { 7 | scan Scan 8 | predicate *Predicate 9 | } 10 | 11 | // Create a select scan having the specified underlying 12 | // scan and predicate. 13 | func NewSelectScan(scan Scan, predicate *Predicate) *SelectScan { 14 | return &SelectScan{scan, predicate} 15 | } 16 | 17 | // Scan methods 18 | 19 | func (ss *SelectScan) BeforeFirst() { 20 | ss.scan.BeforeFirst() 21 | } 22 | // Next advances the underlying scan until a record satisfies the predicate; returns false once the underlying scan is exhausted. 23 | func (ss *SelectScan) Next() bool { 24 | for ss.scan.Next() { 25 | if ss.predicate.IsSatisfied(ss.scan) { 26 | return true 27 | } 28 | } 29 | return false 30 | } 31 | 32 | func (ss *SelectScan) GetInt(fieldName string) int64 { 33 | return ss.scan.GetInt(fieldName) 34 | } 35 | 36 | func (ss *SelectScan) GetString(fieldName string) string { 37 | return ss.scan.GetString(fieldName) 38 | } 39 | 40 | func (ss *SelectScan) GetValue(fieldName string) Constant { 41 | return ss.scan.GetValue(fieldName) 42 | } 43 | 44 | func (ss *SelectScan) HasField(fieldName string) bool { 45 | return ss.scan.HasField(fieldName) 46 | } 47 | 48 | func (ss *SelectScan) Close() { 49 | ss.scan.Close() 50 | } 51 | 52 | // UpdateScan methods: each type-asserts the underlying scan to UpdateScan and panics if it is not one. 53 | 54 | func (ss *SelectScan) SetInt(fieldName string, value int64) { 55 | us := any(ss.scan).(UpdateScan) 56 | us.SetInt(fieldName, value) 57 | } 58 | 59 | func (ss *SelectScan) SetString(fieldName string, value string) { 60 | us := any(ss.scan).(UpdateScan) 61 | us.SetString(fieldName, value) 62 | } 63 | 64 | func (ss *SelectScan) SetValue(fieldName string, value Constant) {
65 | us := any(ss.scan).(UpdateScan) 66 | us.SetValue(fieldName, value) 67 | } 68 | 69 | func (ss *SelectScan) Delete() { 70 | us := any(ss.scan).(UpdateScan) 71 | us.Delete() 72 | } 73 | 74 | func (ss *SelectScan) Insert() { 75 | us := any(ss.scan).(UpdateScan) 76 | us.Insert() 77 | } 78 | 79 | func (ss *SelectScan) GetRID() RID { 80 | us := any(ss.scan).(UpdateScan) 81 | return us.GetRID() 82 | } 83 | 84 | func (ss *SelectScan) MoveToRID(rId RID) { 85 | us := any(ss.scan).(UpdateScan) 86 | us.MoveToRID(rId) 87 | } 88 | -------------------------------------------------------------------------------- /internal/query/term.go: -------------------------------------------------------------------------------- 1 | package query 2 | 3 | // A term is a comparison between two expressions. 4 | // type Term struct { 5 | // left *Expression 6 | // right *Expression 7 | // } 8 | 9 | // // Create a new term that compares two expressions 10 | // // for equality. 11 | // func NewTerm(left *Expression, right *Expression) *Term { 12 | // return &Term{left, right} 13 | // } 14 | 15 | // // Return true if both of the term's expressions 16 | // // evaluate to the same constant, 17 | // // with respect to the specified scan. 18 | // func (t *Term) IsSatisfied(scan Scan) bool { 19 | // left := t.left.Evaluate(scan) 20 | // right := t.right.Evaluate(scan) 21 | // _ = right 22 | // _ = left 23 | // //TODO: 24 | // return false 25 | // } 26 | 27 | // // func (t *Term) ReductionFactor(p plan.Plan) int { 28 | // // //TODO: 29 | // // return 2 30 | // // } 31 | 32 | // // Determine if this term is of the form "F=c" 33 | // // where F is the specified field and c is some constant. 34 | // // If so, the method returns that constant. 35 | // // If not, the method returns null.
36 | // func (t *Term) EquatesWithConstant(fieldName string) *Constant { 37 | // //TODO: 38 | // return nil 39 | // } 40 | 41 | // func (t *Term) ToString(fieldName string) string { 42 | // return fmt.Sprintf("%s = %s ", t.left.toString(), t.right.toString()) 43 | // } 44 | -------------------------------------------------------------------------------- /internal/query/update_scan.go: -------------------------------------------------------------------------------- 1 | package query 2 | 3 | // The interface implemented by all updatable scans. 4 | type UpdateScan interface { 5 | Scan // embeds Scan: every UpdateScan is also a Scan 6 | 7 | // Set the specified field of the current record to a Constant value. 8 | SetValue(fldName string, value Constant) 9 | 10 | // Set the specified integer field of the current record. 11 | SetInt(fldName string, value int64) 12 | 13 | // Set the specified string field of the current record. 14 | SetString(fldName string, value string) 15 | 16 | // Insert a new record somewhere in the scan. 17 | Insert() 18 | 19 | // Delete the current record from the scan. 20 | Delete() 21 | 22 | // Return the id of the current record. 23 | GetRID() RID 24 | 25 | // Position the scan so that the current record has 26 | // the specified id. 27 | MoveToRID(rId RID) 28 | } 29 | -------------------------------------------------------------------------------- /internal/record/layout.go: -------------------------------------------------------------------------------- 1 | package record 2 | 3 | import "github.com/evanxg852000/simpledb/internal/file" 4 | 5 | // Description of the structure of a record. 6 | // It contains the name, type, length and offset of 7 | // each field of the table. 8 | type Layout struct { 9 | Schema *Schema 10 | offsets map[string]int64 11 | slotSize int64 12 | } 13 | 14 | // This constructor creates a Layout object from a schema. 15 | // This constructor is used when a table 16 | // is created. It determines the physical offset of 17 | // each field within the record.
18 | func NewLayout(schema *Schema) *Layout { 19 | pos := int64(8) // leave 8 bytes at the start of each slot for the empty/in-use flag 20 | offsets := map[string]int64{} 21 | for _, fldName := range schema.Fields() { 22 | offsets[fldName] = pos 23 | pos += lengthInBytes(schema, fldName) 24 | } 25 | return &Layout{ 26 | Schema: schema, 27 | offsets: offsets, 28 | slotSize: pos, 29 | } 30 | } 31 | 32 | // Create a Layout object from the specified metadata. 33 | // This constructor is used when the metadata 34 | // is retrieved from the catalog. 35 | func NewLayoutFromMetadata(schema *Schema, offsets map[string]int64, slotSize int64) *Layout { 36 | return &Layout{ 37 | Schema: schema, 38 | offsets: offsets, 39 | slotSize: slotSize, 40 | } 41 | } 42 | 43 | // Return the offset of a specified field within a record (0 if the field is not in the layout). 44 | func (layout *Layout) Offset(fldName string) int64 { 45 | return layout.offsets[fldName] 46 | } 47 | 48 | // Return the size of a slot, in bytes. 49 | func (layout *Layout) SlotSize() int64 { 50 | return layout.slotSize 51 | } 52 | // lengthInBytes returns a field's on-page size: 8 for INTEGER_TYPE, otherwise file.GetEncodingLength of the declared string length. 53 | func lengthInBytes(schema *Schema, fldName string) int64 { 54 | fldType := schema.FieldType(fldName) 55 | if fldType == INTEGER_TYPE { 56 | return 8 57 | } 58 | // STRING_TYPE 59 | return file.GetEncodingLength(schema.FieldLength(fldName)) 60 | } 61 | -------------------------------------------------------------------------------- /internal/record/layout_test.go: -------------------------------------------------------------------------------- 1 | package record_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/evanxg852000/simpledb/internal/record" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestLayout(t *testing.T) { 11 | assert := assert.New(t) 12 | 13 | schema := record.NewSchema() 14 | schema.AddIntField("A") 15 | schema.AddStringField("B", 9) 16 | schema.AddIntField("C") 17 | 18 | layout := record.NewLayout(schema) 19 | expectedOffsets := []int64{8, 16, 33} 20 | for idx, fldName := range layout.Schema.Fields() { 21 |
offset := layout.Offset(fldName) 22 | assert.Equal(expectedOffsets[idx], offset) 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /internal/record/record_page.go: -------------------------------------------------------------------------------- 1 | package record 2 | 3 | import ( 4 | "log" 5 | 6 | "github.com/evanxg852000/simpledb/internal/file" 7 | "github.com/evanxg852000/simpledb/internal/tx/recovery" 8 | ) 9 | // Slot flags written at the start of each slot: EMPTY (0) or USED (1). 10 | const ( 11 | EMPTY = iota 12 | USED 13 | ) 14 | 15 | // Store a record at a given location in a block. 16 | type RecordPage struct { 17 | tx *recovery.Transaction 18 | blockId file.BlockId 19 | layout *Layout 20 | } 21 | // NewRecordPage pins blockId through tx before returning the page. 22 | func NewRecordPage(tx *recovery.Transaction, blockId file.BlockId, layout *Layout) *RecordPage { 23 | tx.Pin(blockId) 24 | return &RecordPage{tx, blockId, layout} 25 | } 26 | 27 | // Return the integer value stored for the 28 | // specified field of a specified slot. 29 | func (rp *RecordPage) GetInt(slot int64, fldName string) (int64, error) { 30 | fldPosition := rp.offset(slot) + rp.layout.Offset(fldName) 31 | return rp.tx.GetInt(rp.blockId, fldPosition) 32 | } 33 | 34 | // Return the string value stored for the 35 | // specified field of the specified slot. 36 | func (rp *RecordPage) GetString(slot int64, fldName string) (string, error) { 37 | fldPosition := rp.offset(slot) + rp.layout.Offset(fldName) 38 | return rp.tx.GetString(rp.blockId, fldPosition) 39 | } 40 | 41 | // Store an integer at the specified field 42 | // of the specified slot. 43 | func (rp *RecordPage) SetInt(slot int64, fldName string, value int64) error { 44 | fldPosition := rp.offset(slot) + rp.layout.Offset(fldName) 45 | return rp.tx.SetInt(rp.blockId, fldPosition, value, true) 46 | } 47 | 48 | // Store a string at the specified field 49 | // of the specified slot.
50 | func (rp *RecordPage) SetString(slot int64, fldName string, value string) error { 51 | fldPosition := rp.offset(slot) + rp.layout.Offset(fldName) 52 | return rp.tx.SetString(rp.blockId, fldPosition, value, true) 53 | } 54 | 55 | func (rp *RecordPage) Delete(slot int64) { 56 | rp.setFlag(slot, EMPTY) 57 | } 58 | 59 | // Use the layout to format a new block of records. 60 | // These values should not be logged 61 | // (because the old values are meaningless). 62 | func (rp *RecordPage) Format() error { 63 | slot := int64(0) 64 | for rp.isValidSlot(slot) { 65 | rp.tx.SetInt(rp.blockId, rp.offset(slot), EMPTY, false) 66 | schema := rp.layout.Schema 67 | for _, fldName := range schema.Fields() { 68 | fldPosition := rp.offset(slot) + rp.layout.Offset(fldName) 69 | if schema.FieldType(fldName) == INTEGER_TYPE { 70 | err := rp.tx.SetInt(rp.blockId, fldPosition, 0, false) 71 | if err != nil { 72 | return err 73 | } 74 | } else { 75 | err := rp.tx.SetString(rp.blockId, fldPosition, "", false) 76 | if err != nil { 77 | return err 78 | } 79 | } 80 | } 81 | slot += 1 82 | } 83 | return nil 84 | } 85 | 86 | func (rp *RecordPage) NextAfter(slot int64) int64 { 87 | return rp.searchAfter(slot, USED) 88 | } 89 | 90 | func (rp *RecordPage) InsertAfter(slot int64) int64 { 91 | newSlot := rp.searchAfter(slot, EMPTY) 92 | if newSlot >= 0 { 93 | rp.setFlag(newSlot, USED) 94 | } 95 | return newSlot 96 | } 97 | 98 | func (rp *RecordPage) BlockId() file.BlockId { 99 | return rp.blockId 100 | } 101 | 102 | func (rp *RecordPage) setFlag(slot int64, flag int64) { 103 | rp.tx.SetInt(rp.blockId, rp.offset(slot), flag, true) 104 | } 105 | 106 | func (rp *RecordPage) searchAfter(slot int64, flag int64) int64 { 107 | slot += 1 108 | for rp.isValidSlot(slot) { 109 | storedFlag, err := rp.tx.GetInt(rp.blockId, rp.offset(slot)) 110 | if err != nil { 111 | log.Fatalf("error check slot status: %v", err) 112 | } 113 | if storedFlag == flag { 114 | return slot 115 | } 116 | slot += 1 117 | } 118 | 
return -1 119 | } 120 | 121 | func (rp *RecordPage) isValidSlot(slot int64) bool { 122 | return rp.offset(slot+1) <= rp.tx.BlockSize() 123 | } 124 | 125 | func (rp *RecordPage) offset(slot int64) int64 { 126 | return slot * rp.layout.slotSize 127 | } 128 | -------------------------------------------------------------------------------- /internal/record/record_page_test.go: -------------------------------------------------------------------------------- 1 | package record_test 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "os" 7 | "path" 8 | "testing" 9 | 10 | "github.com/evanxg852000/simpledb/internal/record" 11 | "github.com/evanxg852000/simpledb/internal/server" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | func TestRecordPage(t *testing.T) { 16 | assert := assert.New(t) 17 | workspaceDir, err := os.MkdirTemp("", "test_record_page") 18 | assert.Nil(err) 19 | dbDir := path.Join(workspaceDir, "db") 20 | defer os.RemoveAll(workspaceDir) 21 | 22 | db := server.NewSimpleDB(dbDir, 400, 8) 23 | tx := db.NewTx() 24 | 25 | schema := record.NewSchema() 26 | schema.AddIntField("A") 27 | schema.AddStringField("B", 9) 28 | 29 | layout := record.NewLayout(schema) 30 | 31 | blockId, err := tx.Append("test_record_page") 32 | assert.Nil(err) 33 | tx.Pin(blockId) 34 | 35 | recordPage := record.NewRecordPage(tx, blockId, layout) 36 | err = recordPage.Format() 37 | assert.Nil(err) 38 | 39 | // Filling the page with random records. 40 | type item struct { 41 | a int64 42 | b string 43 | } 44 | records := make([]int64, 0) 45 | slot := recordPage.InsertAfter(-1) 46 | for slot >= 0 { 47 | n := int64(rand.Intn(50)) 48 | recordPage.SetInt(slot, "A", n) 49 | recordPage.SetString(slot, "B", fmt.Sprintf("rec_%d", n)) 50 | slot = recordPage.InsertAfter(slot) 51 | records = append(records, n) 52 | } 53 | 54 | // Deleting records, whose A-values are less than 25. 
55 | deletedRecords := make([]item, 0) 56 | remainingRecords := make([]item, 0) 57 | slot = recordPage.NextAfter(-1) 58 | for slot >= 0 { 59 | aVal, err := recordPage.GetInt(slot, "A") 60 | assert.Nil(err) 61 | bVal, err := recordPage.GetString(slot, "B") 62 | assert.Nil(err) 63 | rec := item{a: aVal, b: bVal} 64 | if rec.a < 25 { 65 | recordPage.Delete(slot) 66 | deletedRecords = append(deletedRecords, rec) 67 | } else { 68 | remainingRecords = append(remainingRecords, rec) 69 | } 70 | slot = recordPage.NextAfter(slot) 71 | } 72 | tx.Unpin(blockId) 73 | tx.Commit() 74 | 75 | // Checking remaining records. 76 | assert.Equal(len(records), len(deletedRecords)+len(remainingRecords)) 77 | for _, rec := range deletedRecords { 78 | assert.Equal(true, rec.a < 25) 79 | } 80 | for _, rec := range remainingRecords { 81 | assert.Equal(true, rec.a >= 25) 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /internal/record/schema.go: -------------------------------------------------------------------------------- 1 | package record 2 | 3 | import ( 4 | "slices" 5 | 6 | "github.com/evanxg852000/simpledb/internal/query" 7 | ) 8 | 9 | const ( 10 | _ = iota 11 | INTEGER_TYPE 12 | STRING_TYPE 13 | ) 14 | 15 | type FieldInfo struct { 16 | dataType int64 17 | length int64 18 | } 19 | 20 | func NewFieldInfo(fType int64, length int64) FieldInfo { 21 | return FieldInfo{fType, length} 22 | } 23 | 24 | // The record schema of a table. 25 | // A schema contains the name and type of 26 | // each field of the table, as well as the length 27 | // of each varchar field. 28 | type Schema struct { 29 | fields []string 30 | info map[string]FieldInfo 31 | } 32 | 33 | func NewSchema() *Schema { 34 | return &Schema{ 35 | fields: make([]string, 0), 36 | info: make(map[string]FieldInfo), 37 | } 38 | } 39 | 40 | // Add a field to the schema having a specified 41 | // name, type, and length. 
42 | // If the field type is "integer", then the length 43 | // value is irrelevant. 44 | func (schema *Schema) AddField(fldName string, fldType int64, fldLength int64) { 45 | schema.fields = append(schema.fields, fldName) 46 | schema.info[fldName] = NewFieldInfo(fldType, fldLength) 47 | } 48 | 49 | // Add an integer field to the schema. 50 | func (schema *Schema) AddIntField(fldName string) { 51 | schema.AddField(fldName, INTEGER_TYPE, 0) 52 | } 53 | 54 | // Add a string field to the schema. 55 | // The length is the conceptual length of the field. 56 | // For example, if the field is defined as varchar(8), 57 | // then its length is 8. 58 | func (schema *Schema) AddStringField(fldName string, fldLength int64) { 59 | schema.AddField(fldName, STRING_TYPE, fldLength) 60 | } 61 | 62 | // Add a field to the schema having the same 63 | // type and length as the corresponding field 64 | // in another schema. 65 | func (schema *Schema) Add(fldName string, sch Schema) { 66 | fldType := sch.FieldType(fldName) 67 | fldLength := sch.FieldLength(fldName) 68 | schema.AddField(fldName, fldType, fldLength) 69 | } 70 | 71 | // Add all of the fields in the specified schema 72 | // to the current schema. 73 | func (schema *Schema) AddAll(sch Schema) { 74 | for _, fldName := range sch.fields { 75 | schema.Add(fldName, sch) 76 | } 77 | } 78 | 79 | // Return a collection containing the name of 80 | // each field in the schema. 81 | func (schema *Schema) Fields() []string { 82 | return schema.fields 83 | } 84 | 85 | // Return true if the specified field 86 | // is in the schema 87 | func (schema *Schema) HasField(fldName string) bool { 88 | return slices.Contains(schema.fields, fldName) 89 | } 90 | 91 | // Return the type of the specified field. 92 | func (schema *Schema) FieldType(fldName string) int64 { 93 | return schema.info[fldName].dataType 94 | } 95 | 96 | // Return the conceptual length of the specified field. 
97 | // If the field is not a string field, then 98 | // the return value is undefined. 99 | func (schema *Schema) FieldLength(fldName string) int64 { 100 | return schema.info[fldName].length 101 | } 102 | 103 | // Determine if all of the fields mentioned in this expression 104 | // are contained in the specified schema. 105 | func (schema *Schema) AppliesTo(expr *query.Expression) bool { 106 | if !expr.IsFieldName() { 107 | return true 108 | } 109 | return schema.HasField(expr.AsFieldName()) 110 | } 111 | -------------------------------------------------------------------------------- /internal/record/table_scan.go: -------------------------------------------------------------------------------- 1 | package record 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/evanxg852000/simpledb/internal/file" 7 | "github.com/evanxg852000/simpledb/internal/query" 8 | "github.com/evanxg852000/simpledb/internal/tx/recovery" 9 | ) 10 | 11 | type TableScan struct { 12 | tx *recovery.Transaction 13 | layout *Layout 14 | recordPage *RecordPage 15 | fileName string 16 | currentSlot int64 17 | } 18 | 19 | func NewTableScan(tx *recovery.Transaction, tblName string, layout *Layout) (*TableScan, error) { 20 | fileName := fmt.Sprintf("%s.tbl", tblName) 21 | tableScan := &TableScan{ 22 | tx: tx, 23 | layout: layout, 24 | recordPage: nil, 25 | fileName: fileName, 26 | currentSlot: -1, 27 | } 28 | fileSize, err := tx.Size(fileName) 29 | if err != nil { 30 | return nil, err 31 | } 32 | 33 | if fileSize == 0 { 34 | err := tableScan.moveToNewBlock() 35 | if err != nil { 36 | return nil, err 37 | } 38 | } else { 39 | tableScan.moveToBlock(0) 40 | } 41 | return tableScan, nil 42 | } 43 | 44 | // Methods that implement Scan 45 | 46 | func (tblScan *TableScan) BeforeFirst() { 47 | tblScan.moveToBlock(0) 48 | } 49 | 50 | func (tblScan *TableScan) Next() bool { 51 | tblScan.currentSlot = tblScan.recordPage.NextAfter(tblScan.currentSlot) 52 | for tblScan.currentSlot < 0 { 53 | isAtLast, err := 
tblScan.atLastBlock() 54 | if err != nil { 55 | panic(err) 56 | } 57 | 58 | if isAtLast { 59 | return false 60 | } 61 | tblScan.moveToBlock(tblScan.recordPage.BlockId().BlockNum + 1) 62 | tblScan.currentSlot = tblScan.recordPage.NextAfter(tblScan.currentSlot) 63 | } 64 | return true 65 | } 66 | 67 | func (tblScan *TableScan) GetInt(fldName string) int64 { 68 | value, err := tblScan.recordPage.GetInt(tblScan.currentSlot, fldName) 69 | if err != nil { 70 | panic(err) 71 | } 72 | return value 73 | } 74 | 75 | func (tblScan *TableScan) GetString(fldName string) string { 76 | value, err := tblScan.recordPage.GetString(tblScan.currentSlot, fldName) 77 | if err != nil { 78 | panic(err) 79 | } 80 | return value 81 | } 82 | 83 | func (tblScan *TableScan) GetValue(fldName string) query.Constant { 84 | if tblScan.layout.Schema.FieldType(fldName) == INTEGER_TYPE { 85 | iVal := tblScan.GetInt(fldName) 86 | return query.NewConstant(iVal) 87 | } 88 | 89 | sVal := tblScan.GetString(fldName) 90 | return query.NewConstant(sVal) 91 | } 92 | 93 | func (tblScan *TableScan) HasField(fldName string) bool { 94 | return tblScan.layout.Schema.HasField(fldName) 95 | } 96 | 97 | func (tblScan *TableScan) Close() { 98 | if tblScan.recordPage != nil { 99 | tblScan.tx.Unpin(tblScan.recordPage.BlockId()) 100 | } 101 | } 102 | 103 | // Methods that implement UpdateScan 104 | 105 | func (tblScan *TableScan) SetInt(fldName string, value int64) { 106 | err := tblScan.recordPage.SetInt(tblScan.currentSlot, fldName, value) 107 | if err != nil { 108 | panic(err) 109 | } 110 | } 111 | 112 | func (tblScan *TableScan) SetString(fldName string, value string) { 113 | err := tblScan.recordPage.SetString(tblScan.currentSlot, fldName, value) 114 | if err != nil { 115 | panic(err) 116 | } 117 | } 118 | 119 | func (tblScan *TableScan) SetValue(fldName string, value query.Constant) { 120 | if tblScan.layout.Schema.FieldType(fldName) == INTEGER_TYPE { 121 | tblScan.SetInt(fldName, value.AsInt()) 122 | return 123 | 
} 124 | tblScan.SetString(fldName, value.AsString()) 125 | } 126 | 127 | func (tblScan *TableScan) Insert() { 128 | tblScan.currentSlot = tblScan.recordPage.InsertAfter(tblScan.currentSlot) 129 | for tblScan.currentSlot < 0 { 130 | isAtLast, err := tblScan.atLastBlock() 131 | if err != nil { 132 | panic(err) 133 | } 134 | if isAtLast { 135 | err := tblScan.moveToNewBlock() 136 | if err != nil { 137 | panic(err) 138 | } 139 | } else { 140 | tblScan.moveToBlock(tblScan.recordPage.BlockId().BlockNum + 1) 141 | } 142 | 143 | tblScan.currentSlot = tblScan.recordPage.InsertAfter(tblScan.currentSlot) 144 | } 145 | } 146 | 147 | func (tblScan *TableScan) Delete() { 148 | tblScan.recordPage.Delete(tblScan.currentSlot) 149 | } 150 | 151 | func (tblScan *TableScan) MoveToRID(rId query.RID) { 152 | tblScan.Close() 153 | blockId := file.NewBlockId(tblScan.fileName, rId.BlockNum) 154 | tblScan.recordPage = NewRecordPage(tblScan.tx, blockId, tblScan.layout) 155 | tblScan.currentSlot = rId.Slot 156 | } 157 | 158 | func (tblScan *TableScan) GetRID() query.RID { 159 | return query.NewRID(tblScan.recordPage.BlockId().BlockNum, tblScan.currentSlot) 160 | } 161 | 162 | // Private auxiliary methods 163 | 164 | func (tblScan *TableScan) moveToBlock(blockNum int64) { 165 | tblScan.Close() 166 | blockId := file.NewBlockId(tblScan.fileName, blockNum) 167 | tblScan.recordPage = NewRecordPage(tblScan.tx, blockId, tblScan.layout) 168 | tblScan.currentSlot = -1 169 | } 170 | 171 | func (tblScan *TableScan) moveToNewBlock() error { 172 | tblScan.Close() 173 | blockId, err := tblScan.tx.Append(tblScan.fileName) 174 | if err != nil { 175 | return err 176 | } 177 | 178 | tblScan.recordPage = NewRecordPage(tblScan.tx, blockId, tblScan.layout) 179 | tblScan.recordPage.Format() 180 | tblScan.currentSlot = -1 181 | return nil 182 | } 183 | 184 | func (tblScan *TableScan) atLastBlock() (bool, error) { 185 | fileSize, err := tblScan.tx.Size(tblScan.fileName) 186 | if err != nil { 187 | return false, err 
188 | } 189 | return tblScan.recordPage.BlockId().BlockNum == fileSize-1, nil 190 | } 191 | -------------------------------------------------------------------------------- /internal/record/table_scan_test.go: -------------------------------------------------------------------------------- 1 | package record_test 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "os" 7 | "path" 8 | "testing" 9 | 10 | "github.com/evanxg852000/simpledb/internal/record" 11 | "github.com/evanxg852000/simpledb/internal/server" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | func TestTableScan(t *testing.T) { 16 | assert := assert.New(t) 17 | workspaceDir, err := os.MkdirTemp("", "test_table_scan") 18 | assert.Nil(err) 19 | dbDir := path.Join(workspaceDir, "db") 20 | defer os.RemoveAll(workspaceDir) 21 | 22 | db := server.NewSimpleDB(dbDir, 400, 8) 23 | tx := db.NewTx() 24 | 25 | schema := record.NewSchema() 26 | schema.AddIntField("A") 27 | schema.AddStringField("B", 9) 28 | 29 | layout := record.NewLayout(schema) 30 | 31 | // Filling the table with 50 random records. 32 | tblScan, err := record.NewTableScan(tx, "T", layout) 33 | assert.Nil(err) 34 | for i := 0; i < 50; i++ { 35 | tblScan.Insert() 36 | 37 | n := int64(rand.Intn(50)) 38 | tblScan.SetInt("A", n) 39 | tblScan.SetString("B", fmt.Sprintf("rec_%d", n)) 40 | 41 | rId := tblScan.GetRID() 42 | assert.True(rId.BlockNum >= 0) 43 | assert.True(rId.Slot >= 0) 44 | } 45 | 46 | // Deleting records, whose A-values are less than 25. 
47 | deletedRecords := make([]int64, 0) 48 | remainingRecords := make([]int64, 0) 49 | tblScan.BeforeFirst() 50 | for tblScan.Next() { 51 | aVal := tblScan.GetInt("A") 52 | _ = tblScan.GetString("B") 53 | assert.Nil(err) 54 | if aVal < 25 { 55 | tblScan.Delete() 56 | deletedRecords = append(deletedRecords, aVal) 57 | } else { 58 | remainingRecords = append(remainingRecords, aVal) 59 | } 60 | } 61 | 62 | assert.Equal(50, len(deletedRecords)+len(remainingRecords)) 63 | tblScan.Close() 64 | tx.Commit() 65 | } 66 | -------------------------------------------------------------------------------- /internal/server/simpledb.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "log" 5 | 6 | "github.com/evanxg852000/simpledb/internal/buffer" 7 | "github.com/evanxg852000/simpledb/internal/file" 8 | walog "github.com/evanxg852000/simpledb/internal/log" 9 | "github.com/evanxg852000/simpledb/internal/metadata" 10 | "github.com/evanxg852000/simpledb/internal/plan" 11 | "github.com/evanxg852000/simpledb/internal/tx/recovery" 12 | ) 13 | 14 | const ( 15 | BLOCK_SIZE = 400 16 | BUFFER_SIZE = 8 17 | LOG_FILE = "simpledb.log" 18 | ) 19 | 20 | type SimpleDB struct { 21 | fileManager *file.FileManager 22 | bufferManager *buffer.BufferManager 23 | logManager *walog.LogManager 24 | metadataManager *metadata.MetadataManager 25 | planner *plan.Planner 26 | } 27 | 28 | // A constructor useful for debugging. 
29 | func NewSimpleDB(directory string, blockSize int, bufferSize int) *SimpleDB { 30 | fileManager, err := file.NewFileManager(directory, int64(blockSize)) 31 | if err != nil { 32 | log.Fatalf("could not open the database") 33 | } 34 | 35 | logManager, err := walog.NewLogManager(fileManager, LOG_FILE) 36 | if err != nil { 37 | log.Fatalf("could not open the database") 38 | } 39 | 40 | bufferManager := buffer.NewBufferManager(fileManager, logManager, bufferSize) 41 | 42 | tx := recovery.NewTransaction(fileManager, logManager, bufferManager) 43 | isNew := fileManager.IsNew() 44 | if isNew { 45 | log.Println("creating new database") 46 | } else { 47 | log.Println("recovering existing database") 48 | tx.Recover() 49 | } 50 | 51 | metadataManager := metadata.NewMetadataManager(isNew, tx) 52 | planner := plan.NewPlanner( 53 | plan.NewBasicQueryPlanner(metadataManager), 54 | plan.NewBasicUpdatePlanner(metadataManager), 55 | ) 56 | tx.Commit() 57 | 58 | return &SimpleDB{ 59 | fileManager: fileManager, 60 | logManager: logManager, 61 | bufferManager: bufferManager, 62 | metadataManager: metadataManager, 63 | planner: planner, 64 | } 65 | } 66 | 67 | func (sdb *SimpleDB) FileManager() *file.FileManager { 68 | return sdb.fileManager 69 | } 70 | 71 | func (sdb *SimpleDB) LogManager() *walog.LogManager { 72 | return sdb.logManager 73 | } 74 | 75 | func (sdb *SimpleDB) BufferManager() *buffer.BufferManager { 76 | return sdb.bufferManager 77 | } 78 | 79 | func (sdb *SimpleDB) NewTx() *recovery.Transaction { 80 | return recovery.NewTransaction(sdb.fileManager, sdb.logManager, sdb.bufferManager) 81 | } 82 | 83 | func (sdb *SimpleDB) MetadataManager() *metadata.MetadataManager { 84 | return sdb.metadataManager 85 | } 86 | 87 | func (sdb *SimpleDB) Planner() *plan.Planner { 88 | return sdb.planner 89 | } 90 | -------------------------------------------------------------------------------- /internal/tx/concurency_test.go: 
-------------------------------------------------------------------------------- 1 | package tx_test 2 | 3 | import ( 4 | "os" 5 | "path" 6 | "sync" 7 | "testing" 8 | 9 | "github.com/evanxg852000/simpledb/internal/server" 10 | "github.com/evanxg852000/simpledb/internal/tx/recovery" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestConcurrency(t *testing.T) { 15 | assert := assert.New(t) 16 | workspaceDir, err := os.MkdirTemp("", "test_concurrency") 17 | assert.Nil(err) 18 | dbDir := path.Join(workspaceDir, "db") 19 | defer os.RemoveAll(workspaceDir) 20 | 21 | db := server.NewSimpleDB(dbDir, 400, 8) 22 | fm := db.FileManager() 23 | bm := db.BufferManager() 24 | lm := db.LogManager() 25 | 26 | txa := recovery.NewTransaction(fm, lm, bm) 27 | 28 | wg := sync.WaitGroup{} 29 | wg.Add(3) 30 | 31 | go A(assert, txa, &wg) 32 | go B(assert, txa, &wg) 33 | go C(assert, txa, &wg) 34 | 35 | wg.Wait() 36 | } 37 | 38 | func A(assert *assert.Assertions, tx *recovery.Transaction, wg *sync.WaitGroup) { 39 | defer wg.Done() 40 | //TODO: 41 | } 42 | 43 | func B(assert *assert.Assertions, tx *recovery.Transaction, wg *sync.WaitGroup) { 44 | defer wg.Done() 45 | //TODO: 46 | } 47 | 48 | func C(assert *assert.Assertions, tx *recovery.Transaction, wg *sync.WaitGroup) { 49 | defer wg.Done() 50 | //TODO: 51 | } 52 | -------------------------------------------------------------------------------- /internal/tx/concurrency/concurrency_manager.go: -------------------------------------------------------------------------------- 1 | package concurrency 2 | 3 | import "github.com/evanxg852000/simpledb/internal/file" 4 | 5 | // The global lock table. all transactions 6 | // share the same lock table. 7 | var lockTableInstance *LockTable = NewLockTable() 8 | 9 | // The concurrency manager for the transaction. 10 | // Each transaction has its own concurrency manager. 
11 | // The concurrency manager keeps track of which locks the 12 | // transaction currently has, and interacts with the 13 | // global lock table as needed. 14 | type ConcurrencyManager struct { 15 | lockTable *LockTable 16 | locks map[file.BlockId]string 17 | } 18 | 19 | func NewConcurrencyManager() *ConcurrencyManager { 20 | return &ConcurrencyManager{ 21 | lockTable: lockTableInstance, 22 | locks: map[file.BlockId]string{}, 23 | } 24 | } 25 | 26 | // Obtain an SLock on the block, if necessary. 27 | // The method will ask the lock table for an SLock 28 | // if the transaction currently has no locks on that block. 29 | func (cm *ConcurrencyManager) SLock(blockId file.BlockId) { 30 | if _, exists := cm.locks[blockId]; !exists { 31 | cm.lockTable.SLock(blockId) 32 | cm.locks[blockId] = "S" 33 | } 34 | } 35 | 36 | // Obtain an XLock on the block, if necessary. 37 | // If the transaction does not have an XLock on that block, 38 | // then the method first gets an SLock on that block 39 | // (if necessary), and then upgrades it to an XLock. 40 | func (cm *ConcurrencyManager) XLock(blockId file.BlockId) { 41 | if !cm.hasXLock(blockId) { 42 | cm.SLock(blockId) 43 | cm.lockTable.XLock(blockId) 44 | cm.locks[blockId] = "X" 45 | } 46 | } 47 | 48 | // Release all locks by asking the lock table to 49 | // unlock each one. 
50 | func (cm *ConcurrencyManager) Release() { 51 | for blockId := range cm.locks { 52 | cm.lockTable.Unlock(blockId) 53 | } 54 | clear(cm.locks) 55 | } 56 | 57 | func (cm *ConcurrencyManager) hasXLock(blockId file.BlockId) bool { 58 | lockType, exists := cm.locks[blockId] 59 | if !exists { 60 | return false 61 | } 62 | return lockType == "X" 63 | } 64 | -------------------------------------------------------------------------------- /internal/tx/concurrency/lock_table.go: -------------------------------------------------------------------------------- 1 | package concurrency 2 | 3 | import ( 4 | "errors" 5 | "sync" 6 | "time" 7 | 8 | "github.com/evanxg852000/simpledb/internal/file" 9 | "github.com/evanxg852000/simpledb/internal/utils" 10 | ) 11 | 12 | const ( 13 | MAX_WAIT_TIME = 10 * time.Second 14 | ) 15 | 16 | // The lock table, which provides methods to lock and unlock blocks. 17 | // If a transaction requests a lock that causes a conflict with an 18 | // existing lock, then that transaction is placed on a wait list. 19 | // There is only one wait list for all blocks. 20 | // When the last lock on a block is unlocked, then all transactions 21 | // are removed from the wait list and rescheduled. 22 | // If one of those transactions discovers that the lock it is waiting for 23 | // is still locked, it will place itself back on the wait list. 24 | type LockTable struct { 25 | locks map[file.BlockId]int 26 | mu *sync.Mutex 27 | cond *sync.Cond 28 | } 29 | 30 | func NewLockTable() *LockTable { 31 | mu := new(sync.Mutex) 32 | cond := sync.NewCond(mu) 33 | 34 | return &LockTable{ 35 | locks: map[file.BlockId]int{}, 36 | mu: mu, 37 | cond: cond, 38 | } 39 | } 40 | 41 | // Grant an SLock on the specified block. 42 | // If an XLock exists when the method is called, 43 | // then the calling thread will be placed on a wait list 44 | // until the lock is released. 
45 | // If the thread remains on the wait list for a certain 46 | // amount of time (currently 10 seconds), 47 | // then an exception is thrown. 48 | func (lt *LockTable) SLock(blockId file.BlockId) error { 49 | lt.mu.Lock() 50 | defer lt.mu.Unlock() 51 | 52 | startTimestamp := time.Now().UnixMilli() 53 | for lt.hasXLock(blockId) && !lt.waitToLong(startTimestamp) { 54 | utils.WaitCondWithTimeout(lt.cond, MAX_WAIT_TIME) 55 | } 56 | 57 | if lt.hasXLock(blockId) { 58 | return errors.New("LockAbortException") 59 | } 60 | 61 | value := lt.getLockValue(blockId) // will not be negative 62 | lt.locks[blockId] = value + 1 63 | return nil 64 | } 65 | 66 | // Grant an XLock on the specified block. 67 | // If a lock of any type exists when the method is called, 68 | // then the calling thread will be placed on a wait list 69 | // until the locks are released. 70 | // If the thread remains on the wait list for a certain 71 | // amount of time (currently 10 seconds), 72 | // then an exception is thrown. 
73 | func (lt *LockTable) XLock(blockId file.BlockId) error { 74 | lt.mu.Lock() 75 | defer lt.mu.Unlock() 76 | 77 | startTimestamp := time.Now().UnixMilli() 78 | for lt.hasOtherSLocks(blockId) && !lt.waitToLong(startTimestamp) { 79 | utils.WaitCondWithTimeout(lt.cond, MAX_WAIT_TIME) 80 | } 81 | if lt.hasOtherSLocks(blockId) { 82 | return errors.New("LockAbortException") 83 | } 84 | 85 | lt.locks[blockId] = -1 86 | return nil 87 | } 88 | 89 | func (lt *LockTable) Unlock(blockId file.BlockId) { 90 | value := lt.getLockValue(blockId) 91 | if value > 1 { 92 | lt.locks[blockId] = value - 1 93 | } else { 94 | delete(lt.locks, blockId) 95 | lt.cond.Broadcast() 96 | } 97 | } 98 | 99 | func (lt *LockTable) hasXLock(blockId file.BlockId) bool { 100 | return lt.getLockValue(blockId) < 0 101 | } 102 | 103 | func (lt *LockTable) hasOtherSLocks(blockId file.BlockId) bool { 104 | return lt.getLockValue(blockId) > 1 105 | } 106 | 107 | func (lt *LockTable) waitToLong(startTimestamp int64) bool { 108 | return time.Now().UnixMilli()-startTimestamp > MAX_WAIT_TIME.Milliseconds() 109 | } 110 | 111 | func (lt *LockTable) getLockValue(blockId file.BlockId) int { 112 | value, exists := lt.locks[blockId] 113 | if !exists { 114 | return 0 115 | } 116 | return value 117 | } 118 | -------------------------------------------------------------------------------- /internal/tx/recovery/buffer_list.go: -------------------------------------------------------------------------------- 1 | package recovery 2 | 3 | import ( 4 | "slices" 5 | 6 | "github.com/evanxg852000/simpledb/internal/buffer" 7 | "github.com/evanxg852000/simpledb/internal/file" 8 | ) 9 | 10 | type BufferList struct { 11 | buffers map[file.BlockId]*buffer.Buffer 12 | pins []file.BlockId 13 | bufferManager *buffer.BufferManager 14 | } 15 | 16 | func NewBufferList(bufferManager *buffer.BufferManager) *BufferList { 17 | return &BufferList{ 18 | buffers: make(map[file.BlockId]*buffer.Buffer), 19 | pins: make([]file.BlockId, 0), 20 | 
bufferManager: bufferManager, 21 | } 22 | } 23 | 24 | func (bl *BufferList) GetBuffer(blockId file.BlockId) *buffer.Buffer { 25 | return bl.buffers[blockId] 26 | } 27 | 28 | func (bl *BufferList) Pin(blockId file.BlockId) error { 29 | buffer, err := bl.bufferManager.Pin(blockId) 30 | if err != nil { 31 | return err 32 | } 33 | bl.buffers[blockId] = buffer 34 | bl.pins = append(bl.pins, blockId) 35 | return nil 36 | } 37 | 38 | func (bl *BufferList) Unpin(blockId file.BlockId) { 39 | buffer, exist := bl.buffers[blockId] 40 | if !exist { 41 | return 42 | } 43 | bl.bufferManager.Unpin(buffer) 44 | deleted := false 45 | bl.pins = slices.DeleteFunc(bl.pins, func(probe file.BlockId) bool { 46 | if !deleted && probe == blockId { 47 | deleted = true 48 | return true 49 | } 50 | return false 51 | }) 52 | 53 | if !slices.Contains(bl.pins, blockId) { 54 | delete(bl.buffers, blockId) 55 | } 56 | } 57 | 58 | func (bl *BufferList) UnpinAll() { 59 | for _, blockId := range bl.pins { 60 | buffer := bl.buffers[blockId] 61 | bl.bufferManager.Unpin(buffer) 62 | } 63 | clear(bl.buffers) 64 | clear(bl.pins) 65 | } 66 | -------------------------------------------------------------------------------- /internal/tx/recovery/checkpoint_record.go: -------------------------------------------------------------------------------- 1 | package recovery 2 | 3 | import ( 4 | "github.com/evanxg852000/simpledb/internal/file" 5 | walog "github.com/evanxg852000/simpledb/internal/log" 6 | ) 7 | 8 | type CheckpointRecord struct { 9 | } 10 | 11 | func NewCheckpointRecord() (CheckpointRecord, error) { 12 | return CheckpointRecord{}, nil 13 | } 14 | 15 | func (cr CheckpointRecord) Operation() int { 16 | return CHECKPOINT 17 | } 18 | 19 | func (cr CheckpointRecord) TxNumber() int64 { 20 | return -1 21 | } 22 | 23 | func (cr CheckpointRecord) ToString() string { 24 | return "" 25 | } 26 | 27 | // Does nothing, because a checkpoint record 28 | // contains no undo information. 
29 | func (cr CheckpointRecord) Undo(tx *Transaction) {} 30 | 31 | // A static method to write a checkpoint record to the log. 32 | // This log record contains the CHECKPOINT operator, 33 | // and nothing else. 34 | func (cr CheckpointRecord) WriteToLog(lm *walog.LogManager) (int64, error) { 35 | buffer := file.NewByteBuffer() 36 | err := buffer.WriteInt(CHECKPOINT) 37 | if err != nil { 38 | return -1, err 39 | } 40 | 41 | return lm.Append(buffer.Data()) 42 | } 43 | -------------------------------------------------------------------------------- /internal/tx/recovery/commit_record.go: -------------------------------------------------------------------------------- 1 | package recovery 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/evanxg852000/simpledb/internal/file" 7 | walog "github.com/evanxg852000/simpledb/internal/log" 8 | ) 9 | 10 | type CommitRecord struct { 11 | txNum int64 12 | } 13 | 14 | func NewCommitRecord(page file.Page) (CommitRecord, error) { 15 | txNum, err := page.ReadInt(8) 16 | if err != nil { 17 | return CommitRecord{}, err 18 | } 19 | return CommitRecord{txNum}, nil 20 | } 21 | 22 | func (cr CommitRecord) Operation() int { 23 | return COMMIT 24 | } 25 | 26 | func (cr CommitRecord) TxNumber() int64 { 27 | return cr.txNum 28 | } 29 | 30 | func (cr CommitRecord) ToString() string { 31 | return fmt.Sprintf("", cr.txNum) 32 | } 33 | 34 | // Does nothing, because a commit record 35 | // contains no undo information. 36 | func (cr CommitRecord) Undo(tx *Transaction) {} 37 | 38 | // A static method to write a commit record to the log. 39 | // This log record contains the COMMIT operator, 40 | // followed by the transaction id. 
41 | func (cr CommitRecord) WriteToLog(lm *walog.LogManager, txNum int64) (int64, error) { 42 | buffer := file.NewByteBuffer() 43 | err := buffer.WriteInt(COMMIT) 44 | if err != nil { 45 | return -1, err 46 | } 47 | 48 | err = buffer.WriteInt(txNum) 49 | if err != nil { 50 | return -1, err 51 | } 52 | 53 | return lm.Append(buffer.Data()) 54 | } 55 | -------------------------------------------------------------------------------- /internal/tx/recovery/log_record.go: -------------------------------------------------------------------------------- 1 | package recovery 2 | 3 | import ( 4 | "github.com/evanxg852000/simpledb/internal/file" 5 | ) 6 | 7 | const ( 8 | CHECKPOINT = iota 9 | START 10 | COMMIT 11 | ROLLBACK 12 | SETINT 13 | SETSTRING 14 | ) 15 | 16 | // The interface implemented by each type of log record. 17 | type LogRecord interface { 18 | // Returns the log record's type. 19 | Operation() int 20 | 21 | // Returns the transaction id stored with the log record. 22 | TxNumber() int64 23 | 24 | // Undoes the operation encoded by this log record. 25 | // The only log record types for which this method 26 | // does anything interesting are SETINT and SETSTRING. 
27 | Undo(tx *Transaction) 28 | 29 | // Return string representation 30 | ToString() string 31 | } 32 | 33 | func NewLogRecord(data []byte) (LogRecord, error) { 34 | page := file.NewPageWithData(data) 35 | op, err := page.ReadInt(0) 36 | if err != nil { 37 | return nil, err 38 | } 39 | switch op { 40 | case CHECKPOINT: 41 | return NewCheckpointRecord() 42 | case START: 43 | return NewStartRecord(page) 44 | case COMMIT: 45 | return NewCommitRecord(page) 46 | case ROLLBACK: 47 | return NewRollbackRecord(page) 48 | case SETINT: 49 | return NewSetIntRecord(page) 50 | case SETSTRING: 51 | return NewSetStringRecord(page) 52 | } 53 | return nil, nil 54 | } 55 | -------------------------------------------------------------------------------- /internal/tx/recovery/recovery_manager.go: -------------------------------------------------------------------------------- 1 | package recovery 2 | 3 | import ( 4 | "slices" 5 | 6 | "github.com/evanxg852000/simpledb/internal/buffer" 7 | walog "github.com/evanxg852000/simpledb/internal/log" 8 | ) 9 | 10 | // Each transaction has its own recovery manager. 11 | type RecoveryManager struct { 12 | logManager *walog.LogManager 13 | bufferManager *buffer.BufferManager 14 | tx *Transaction 15 | txNum int64 16 | } 17 | 18 | // Create a recovery manager for the specified transaction. 19 | func NewRecoveryManager(tx *Transaction, txNum int64, lm *walog.LogManager, bm *buffer.BufferManager) *RecoveryManager { 20 | StartRecord.WriteToLog(StartRecord{}, lm, txNum) 21 | return &RecoveryManager{ 22 | tx: tx, 23 | txNum: txNum, 24 | logManager: lm, 25 | bufferManager: bm, 26 | } 27 | } 28 | 29 | // Write a commit record to the log, and flushes it to disk. 
30 | func (rm *RecoveryManager) Commit() error { 31 | err := rm.bufferManager.FlushAll(rm.txNum) 32 | if err != nil { 33 | return err 34 | } 35 | lsn, err := CommitRecord.WriteToLog(CommitRecord{}, rm.logManager, rm.txNum) 36 | if err != nil { 37 | return err 38 | } 39 | return rm.logManager.Flush(lsn) 40 | } 41 | 42 | // Write a rollback record to the log and flush it to disk. 43 | func (rm *RecoveryManager) Rollback() error { 44 | err := rm.doRollback() 45 | if err != nil { 46 | return err 47 | } 48 | rm.bufferManager.FlushAll(rm.txNum) 49 | lsn, err := RollbackRecord.WriteToLog(RollbackRecord{}, rm.logManager, rm.txNum) 50 | if err != nil { 51 | return err 52 | } 53 | rm.logManager.Flush(lsn) 54 | return nil 55 | } 56 | 57 | // Recover uncompleted transactions from the log 58 | // and then write a quiescent checkpoint record to the log and flush it. 59 | func (rm *RecoveryManager) Recover() error { 60 | err := rm.doRecover() 61 | if err != nil { 62 | return err 63 | } 64 | rm.bufferManager.FlushAll(rm.txNum) 65 | lsn, err := CheckpointRecord.WriteToLog(CheckpointRecord{}, rm.logManager) 66 | if err != nil { 67 | return err 68 | } 69 | rm.logManager.Flush(lsn) 70 | return nil 71 | } 72 | 73 | // Write a setint record to the log and return its lsn. 74 | func (rm *RecoveryManager) SetInt(buffer *buffer.Buffer, offset int64, value int64) (int64, error) { 75 | oldValue, err := buffer.Content().ReadInt(offset) 76 | if err != nil { 77 | return 0, err 78 | } 79 | blockId := buffer.Block() 80 | return SetIntRecord.WriteToLog(SetIntRecord{}, rm.logManager, rm.txNum, blockId, offset, oldValue) 81 | } 82 | 83 | // Write a setstring record to the log and return its lsn. 
84 | func (rm *RecoveryManager) SetString(buffer *buffer.Buffer, offset int64, value string) (int64, error) { 85 | oldValue, err := buffer.Content().ReadString(offset) 86 | if err != nil { 87 | return 0, err 88 | } 89 | blockId := buffer.Block() 90 | return SetStringRecord.WriteToLog(SetStringRecord{}, rm.logManager, rm.txNum, blockId, offset, oldValue) 91 | } 92 | 93 | // Rollback the transaction, by iterating 94 | // through the log records until it finds 95 | // the transaction's START record, 96 | // calling undo() for each of the transaction's 97 | // log records. 98 | func (rm *RecoveryManager) doRollback() error { 99 | iter, err := rm.logManager.Iterator() 100 | if err != nil { 101 | return err 102 | } 103 | 104 | for iter.HasNext() { 105 | data, err := iter.Next() 106 | if err != nil { 107 | return err 108 | } 109 | record, err := NewLogRecord(data) 110 | if err != nil { 111 | return err 112 | } 113 | 114 | if record.TxNumber() == rm.txNum { 115 | if record.Operation() == START { 116 | return nil 117 | } 118 | record.Undo(rm.tx) 119 | } 120 | 121 | } 122 | return nil 123 | } 124 | 125 | // Do a complete database recovery. 126 | // The method iterates through the log records. 127 | // Whenever it finds a log record for an unfinished 128 | // transaction, it calls undo() on that record. 129 | // The method stops when it encounters a CHECKPOINT record 130 | // or the end of the log. 
131 | func (rm *RecoveryManager) doRecover() error { 132 | finishedTxs := make([]int64, 0) 133 | iter, err := rm.logManager.Iterator() 134 | if err != nil { 135 | return err 136 | } 137 | 138 | for iter.HasNext() { 139 | data, err := iter.Next() 140 | if err != nil { 141 | return err 142 | } 143 | record, err := NewLogRecord(data) 144 | if err != nil { 145 | return err 146 | } 147 | 148 | if record.Operation() == CHECKPOINT { 149 | return nil 150 | } 151 | 152 | if record.Operation() == COMMIT || record.Operation() == ROLLBACK { 153 | finishedTxs = append(finishedTxs, record.TxNumber()) 154 | } 155 | 156 | if !slices.Contains(finishedTxs, record.TxNumber()) { 157 | record.Undo(rm.tx) 158 | } 159 | } 160 | return nil 161 | } 162 | -------------------------------------------------------------------------------- /internal/tx/recovery/rollback_record.go: -------------------------------------------------------------------------------- 1 | package recovery 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/evanxg852000/simpledb/internal/file" 7 | walog "github.com/evanxg852000/simpledb/internal/log" 8 | ) 9 | 10 | type RollbackRecord struct { 11 | txNum int64 12 | } 13 | 14 | func NewRollbackRecord(page file.Page) (RollbackRecord, error) { 15 | txNum, err := page.ReadInt(8) 16 | if err != nil { 17 | return RollbackRecord{}, err 18 | } 19 | return RollbackRecord{txNum}, nil 20 | } 21 | 22 | func (rr RollbackRecord) Operation() int { 23 | return START 24 | } 25 | 26 | func (rr RollbackRecord) TxNumber() int64 { 27 | return rr.txNum 28 | } 29 | 30 | func (rr RollbackRecord) ToString() string { 31 | return fmt.Sprintf("", rr.txNum) 32 | } 33 | 34 | // Does nothing, because a rollback record 35 | // contains no undo information. 36 | func (rr RollbackRecord) Undo(tx *Transaction) {} 37 | 38 | // A static method to write a rollback record to the log. 39 | // This log record contains the ROLLBACK operator, 40 | // followed by the transaction id. 
41 | func (sr RollbackRecord) WriteToLog(lm *walog.LogManager, txNum int64) (int64, error) { 42 | buffer := file.NewByteBuffer() 43 | err := buffer.WriteInt(ROLLBACK) 44 | if err != nil { 45 | return -1, err 46 | } 47 | 48 | err = buffer.WriteInt(txNum) 49 | if err != nil { 50 | return -1, err 51 | } 52 | 53 | return lm.Append(buffer.Data()) 54 | } 55 | -------------------------------------------------------------------------------- /internal/tx/recovery/set_int_record.go: -------------------------------------------------------------------------------- 1 | package recovery 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/evanxg852000/simpledb/internal/file" 7 | walog "github.com/evanxg852000/simpledb/internal/log" 8 | ) 9 | 10 | type SetIntRecord struct { 11 | txNum int64 12 | offset int64 13 | value int64 14 | blockId file.BlockId 15 | } 16 | 17 | func NewSetIntRecord(page file.Page) (SetIntRecord, error) { 18 | pos := int64(8) 19 | txNum, err := page.ReadInt(pos) 20 | if err != nil { 21 | return SetIntRecord{}, err 22 | } 23 | 24 | pos += 8 25 | fileName, err := page.ReadString(pos) 26 | if err != nil { 27 | return SetIntRecord{}, err 28 | } 29 | 30 | pos = pos + file.GetEncodingLength(int64(len(fileName))) 31 | blockNum, err := page.ReadInt(pos) 32 | if err != nil { 33 | return SetIntRecord{}, err 34 | } 35 | blockId := file.NewBlockId(fileName, blockNum) 36 | 37 | pos = pos + 8 38 | offset, err := page.ReadInt(pos) 39 | if err != nil { 40 | return SetIntRecord{}, err 41 | } 42 | 43 | pos = pos + 8 44 | value, err := page.ReadInt(pos) 45 | if err != nil { 46 | return SetIntRecord{}, err 47 | } 48 | 49 | return SetIntRecord{txNum, offset, value, blockId}, nil 50 | } 51 | 52 | func (sir SetIntRecord) Operation() int { 53 | return SETINT 54 | } 55 | 56 | func (sir SetIntRecord) TxNumber() int64 { 57 | return sir.txNum 58 | } 59 | 60 | func (sir SetIntRecord) ToString() string { 61 | return fmt.Sprintf("", sir.txNum, sir.blockId, sir.offset, sir.value) 62 | } 63 | 64 
| // Replace the specified data value with the value saved in the log record. 65 | // The method pins a buffer to the specified block, 66 | // calls setInt to restore the saved value, 67 | // and unpins the buffer. 68 | func (sir SetIntRecord) Undo(tx *Transaction) { 69 | tx.Pin(sir.blockId) 70 | tx.SetInt(sir.blockId, sir.offset, sir.value, false) // don't log the undo 71 | tx.Unpin(sir.blockId) 72 | } 73 | 74 | // A static method to write a setInt record to the log. 75 | // This log record contains the SETINT operator, 76 | // followed by the transaction id, the filename, number, 77 | // and offset of the modified block, and the previous 78 | // integer value at that offset. 79 | func (SetIntRecord) WriteToLog(lm *walog.LogManager, txNum int64, blockId file.BlockId, offset int64, value int64) (int64, error) { 80 | buffer := file.NewByteBuffer() 81 | err := buffer.WriteInt(SETINT) 82 | if err != nil { 83 | return -1, err 84 | } 85 | 86 | err = buffer.WriteInt(txNum) 87 | if err != nil { 88 | return -1, err 89 | } 90 | 91 | err = buffer.WriteString(blockId.FileName) 92 | if err != nil { 93 | return -1, err 94 | } 95 | 96 | err = buffer.WriteInt(blockId.BlockNum) 97 | if err != nil { 98 | return -1, err 99 | } 100 | 101 | err = buffer.WriteInt(offset) 102 | if err != nil { 103 | return -1, err 104 | } 105 | 106 | err = buffer.WriteInt(value) 107 | if err != nil { 108 | return -1, err 109 | } 110 | 111 | return lm.Append(buffer.Data()) 112 | } 113 | -------------------------------------------------------------------------------- /internal/tx/recovery/set_string_record.go: -------------------------------------------------------------------------------- 1 | package recovery 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/evanxg852000/simpledb/internal/file" 7 | walog "github.com/evanxg852000/simpledb/internal/log" 8 | ) 9 | 10 | type SetStringRecord struct { 11 | txNum int64 12 | offset int64 13 | value string 14 | blockId file.BlockId 15 | } 16 | 17 | func 
NewSetStringRecord(page file.Page) (SetStringRecord, error) { 18 | pos := int64(8) 19 | txNum, err := page.ReadInt(pos) 20 | if err != nil { 21 | return SetStringRecord{}, err 22 | } 23 | 24 | pos += 8 25 | fileName, err := page.ReadString(pos) 26 | if err != nil { 27 | return SetStringRecord{}, err 28 | } 29 | 30 | pos = pos + file.GetEncodingLength(int64(len(fileName))) 31 | blockNum, err := page.ReadInt(pos) 32 | if err != nil { 33 | return SetStringRecord{}, err 34 | } 35 | blockId := file.NewBlockId(fileName, blockNum) 36 | 37 | pos = pos + 8 38 | offset, err := page.ReadInt(pos) 39 | if err != nil { 40 | return SetStringRecord{}, err 41 | } 42 | 43 | pos = pos + 8 44 | value, err := page.ReadString(pos) 45 | if err != nil { 46 | return SetStringRecord{}, err 47 | } 48 | 49 | return SetStringRecord{txNum, offset, value, blockId}, nil 50 | } 51 | 52 | func (ssr SetStringRecord) Operation() int { 53 | return SETSTRING 54 | } 55 | 56 | func (ssr SetStringRecord) TxNumber() int64 { 57 | return ssr.txNum 58 | } 59 | 60 | func (ssr SetStringRecord) ToString() string { 61 | return fmt.Sprintf("", ssr.txNum, ssr.blockId, ssr.offset, ssr.value) 62 | } 63 | 64 | // Replace the specified data value with the value saved in the log record. 65 | // The method pins a buffer to the specified block, 66 | // calls setInt to restore the saved value, 67 | // and unpins the buffer. 68 | func (ssr SetStringRecord) Undo(tx *Transaction) { 69 | tx.Pin(ssr.blockId) 70 | tx.SetString(ssr.blockId, ssr.offset, ssr.value, false) // don't log the undo 71 | tx.Unpin(ssr.blockId) 72 | } 73 | 74 | // A static method to write a SetString record to the log. 75 | // This log record contains the SETSTRING operator, 76 | // followed by the transaction id, the filename, number, 77 | // and offset of the modified block, and the previous 78 | // integer value at that offset. 
79 | func (SetStringRecord) WriteToLog(lm *walog.LogManager, txNum int64, blockId file.BlockId, offset int64, value string) (int64, error) { 80 | buffer := file.NewByteBuffer() 81 | err := buffer.WriteInt(SETSTRING) 82 | if err != nil { 83 | return -1, err 84 | } 85 | 86 | err = buffer.WriteInt(txNum) 87 | if err != nil { 88 | return -1, err 89 | } 90 | 91 | err = buffer.WriteString(blockId.FileName) 92 | if err != nil { 93 | return -1, err 94 | } 95 | 96 | err = buffer.WriteInt(blockId.BlockNum) 97 | if err != nil { 98 | return -1, err 99 | } 100 | 101 | err = buffer.WriteInt(offset) 102 | if err != nil { 103 | return -1, err 104 | } 105 | 106 | err = buffer.WriteString(value) 107 | if err != nil { 108 | return -1, err 109 | } 110 | 111 | return lm.Append(buffer.Data()) 112 | } 113 | -------------------------------------------------------------------------------- /internal/tx/recovery/start_record.go: -------------------------------------------------------------------------------- 1 | package recovery 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/evanxg852000/simpledb/internal/file" 7 | walog "github.com/evanxg852000/simpledb/internal/log" 8 | ) 9 | 10 | type StartRecord struct { 11 | txNum int64 12 | } 13 | 14 | func NewStartRecord(page file.Page) (StartRecord, error) { 15 | txNum, err := page.ReadInt(8) 16 | if err != nil { 17 | return StartRecord{}, err 18 | } 19 | return StartRecord{txNum}, nil 20 | } 21 | 22 | func (sr StartRecord) Operation() int { 23 | return START 24 | } 25 | 26 | func (sr StartRecord) TxNumber() int64 { 27 | return sr.txNum 28 | } 29 | 30 | func (sr StartRecord) ToString() string { 31 | return fmt.Sprintf("", sr.txNum) 32 | } 33 | 34 | func (sr StartRecord) Undo(tx *Transaction) {} 35 | 36 | // A static method to write a start record to the log. 37 | // This log record contains the START operator, 38 | // followed by the transaction id. 
39 | func (sr StartRecord) WriteToLog(lm *walog.LogManager, txNum int64) (int64, error) { 40 | buffer := file.NewByteBuffer() 41 | err := buffer.WriteInt(START) 42 | if err != nil { 43 | return -1, err 44 | } 45 | 46 | err = buffer.WriteInt(txNum) 47 | if err != nil { 48 | return -1, err 49 | } 50 | 51 | return lm.Append(buffer.Data()) 52 | } 53 | -------------------------------------------------------------------------------- /internal/tx/recovery/transaction.go: -------------------------------------------------------------------------------- 1 | package recovery 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "sync/atomic" 7 | 8 | "github.com/evanxg852000/simpledb/internal/buffer" 9 | "github.com/evanxg852000/simpledb/internal/file" 10 | walog "github.com/evanxg852000/simpledb/internal/log" 11 | "github.com/evanxg852000/simpledb/internal/tx/concurrency" 12 | // "github.com/evanxg852000/simpledb/internal/tx/recovery" 13 | ) 14 | 15 | const END_OF_FILE int = -1 16 | 17 | var nextTxNum = atomic.Int64{} 18 | 19 | type Transaction struct { 20 | recoveryManager *RecoveryManager 21 | concurrencyManager *concurrency.ConcurrencyManager 22 | bufferManager *buffer.BufferManager 23 | fileManager *file.FileManager 24 | txNum int64 25 | buffers *BufferList 26 | } 27 | 28 | func NewTransaction(fileManager *file.FileManager, logManager *walog.LogManager, bufferManager *buffer.BufferManager) *Transaction { 29 | newTxNum := nextTxNum.Add(1) 30 | tx := &Transaction{ 31 | concurrencyManager: concurrency.NewConcurrencyManager(), 32 | bufferManager: bufferManager, 33 | fileManager: fileManager, 34 | txNum: newTxNum, 35 | buffers: NewBufferList(bufferManager), 36 | } 37 | 38 | tx.recoveryManager = NewRecoveryManager(tx, newTxNum, logManager, bufferManager) 39 | return tx 40 | } 41 | 42 | // Commit the current transaction. 43 | // Flush all modified buffers (and their log records), 44 | // write and flush a commit record to the log, 45 | // release all locks, and unpin any pinned buffers. 
func (tx *Transaction) Commit() {
	// Commit has no error result; report a failure instead of claiming
	// success. Locks and pins are released either way.
	if err := tx.recoveryManager.Commit(); err != nil {
		log.Printf("transaction %d commit failed: %v", tx.txNum, err)
	} else {
		log.Printf("transaction %d committed\n", tx.txNum)
	}
	tx.concurrencyManager.Release()
	tx.buffers.UnpinAll()
}

// Rollback the current transaction.
// Undo any modified values,
// flush those buffers,
// write and flush a rollback record to the log,
// release all locks, and unpin any pinned buffers.
// A failure is logged; locks and pins are released either way.
func (tx *Transaction) Rollback() {
	if err := tx.recoveryManager.Rollback(); err != nil {
		log.Printf("transaction %d rollback failed: %v", tx.txNum, err)
	} else {
		fmt.Printf("transaction %d rolled back\n", tx.txNum)
	}
	tx.concurrencyManager.Release()
	tx.buffers.UnpinAll()
}

// Flush all modified buffers.
// Then go through the log, rolling back all
// uncommitted transactions. Finally,
// write a quiescent checkpoint record to the log.
// This method is called during system startup,
// before user transactions begin. Failures are logged.
func (tx *Transaction) Recover() {
	if err := tx.bufferManager.FlushAll(tx.txNum); err != nil {
		log.Printf("transaction %d: flushing buffers before recovery: %v", tx.txNum, err)
	}
	if err := tx.recoveryManager.Recover(); err != nil {
		log.Printf("transaction %d recovery failed: %v", tx.txNum, err)
	}
}

// Pin the specified block.
// The transaction manages the buffer for the client.
// Pin has no error result, so a failed pin is logged; later reads and
// writes on that block will find no buffer.
func (tx *Transaction) Pin(blockId file.BlockId) {
	if err := tx.buffers.Pin(blockId); err != nil {
		log.Printf("transaction %d: pinning block %v: %v", tx.txNum, blockId, err)
	}
}

// Unpin the specified block.
// The transaction looks up the buffer pinned to this block,
// and unpins it.
func (tx *Transaction) Unpin(blockId file.BlockId) {
	tx.buffers.Unpin(blockId)
}

// Return the integer value stored at the
// specified offset of the specified block.
// The method first obtains an SLock on the block,
// then it calls the buffer to retrieve the value.
93 | // returns the integer stored at that offset 94 | func (tx *Transaction) GetInt(blockId file.BlockId, offset int64) (int64, error) { 95 | tx.concurrencyManager.SLock(blockId) 96 | buffer := tx.buffers.GetBuffer(blockId) 97 | return buffer.Content().ReadInt(offset) 98 | } 99 | 100 | // Return the string value stored at the 101 | // specified offset of the specified block. 102 | // The method first obtains an SLock on the block, 103 | // then it calls the buffer to retrieve the value. 104 | // returns the string stored at that offset 105 | func (tx *Transaction) GetString(blockId file.BlockId, offset int64) (string, error) { 106 | tx.concurrencyManager.SLock(blockId) 107 | buffer := tx.buffers.GetBuffer(blockId) 108 | return buffer.Content().ReadString(offset) 109 | } 110 | 111 | // Store an integer at the specified offset 112 | // of the specified block. 113 | // The method first obtains an XLock on the block. 114 | // It then reads the current value at that offset, 115 | // puts it into an update log record, and 116 | // writes that record to the log. 117 | // Finally, it calls the buffer to store the value, 118 | // passing in the LSN of the log record and the transaction's id. 119 | func (tx *Transaction) SetInt(blockId file.BlockId, offset int64, value int64, okToLog bool) error { 120 | tx.concurrencyManager.XLock(blockId) 121 | buffer := tx.buffers.GetBuffer(blockId) 122 | lsn := int64(-1) 123 | if okToLog { 124 | lsn, err := tx.recoveryManager.SetInt(buffer, offset, value) 125 | _ = lsn 126 | if err != nil { 127 | return err 128 | } 129 | } 130 | err := buffer.Content().WriteInt(offset, value) 131 | if err != nil { 132 | return err 133 | } 134 | buffer.Modify(tx.txNum, lsn) 135 | return nil 136 | } 137 | 138 | // Store a string at the specified offset 139 | // of the specified block. 140 | // The method first obtains an XLock on the block. 
141 | // It then reads the current value at that offset, 142 | // puts it into an update log record, and 143 | // writes that record to the log. 144 | // Finally, it calls the buffer to store the value, 145 | // passing in the LSN of the log record and the transaction's id. 146 | func (tx *Transaction) SetString(blockId file.BlockId, offset int64, value string, okToLog bool) error { 147 | tx.concurrencyManager.XLock(blockId) 148 | buffer := tx.buffers.GetBuffer(blockId) 149 | lsn := int64(-1) 150 | if okToLog { 151 | lsn, err := tx.recoveryManager.SetString(buffer, offset, value) 152 | _ = lsn 153 | if err != nil { 154 | return err 155 | } 156 | } 157 | _, err := buffer.Content().WriteString(offset, value) 158 | if err != nil { 159 | return err 160 | } 161 | buffer.Modify(tx.txNum, lsn) 162 | return nil 163 | } 164 | 165 | // Return the number of blocks in the specified file. 166 | // This method first obtains an SLock on the 167 | // "end of the file", before asking the file manager 168 | // to return the file size. 169 | func (tx *Transaction) Size(fileName string) (int64, error) { 170 | blockId := file.NewBlockId(fileName, int64(END_OF_FILE)) 171 | tx.concurrencyManager.SLock(blockId) 172 | return tx.fileManager.BlockCount(fileName) 173 | } 174 | 175 | // Append a new block to the end of the specified file 176 | // and returns a reference to it. 177 | // This method first obtains an XLock on the 178 | // "end of the file", before performing the append. 
179 | func (tx *Transaction) Append(fileName string) (file.BlockId, error) { 180 | blockId := file.NewBlockId(fileName, int64(END_OF_FILE)) 181 | tx.concurrencyManager.XLock(blockId) 182 | return tx.fileManager.Append(fileName) 183 | } 184 | 185 | func (tx *Transaction) BlockSize() int64 { 186 | return tx.fileManager.BlockSize() 187 | } 188 | 189 | func (tx *Transaction) AvailableBuffers() int64 { 190 | return tx.bufferManager.Available() 191 | } 192 | -------------------------------------------------------------------------------- /internal/tx/transaction_test.go: -------------------------------------------------------------------------------- 1 | package tx_test 2 | 3 | import ( 4 | "os" 5 | "path" 6 | "testing" 7 | 8 | "github.com/evanxg852000/simpledb/internal/file" 9 | "github.com/evanxg852000/simpledb/internal/server" 10 | "github.com/evanxg852000/simpledb/internal/tx/recovery" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestTransaction(t *testing.T) { 15 | assert := assert.New(t) 16 | 17 | workspaceDir, err := os.MkdirTemp("", "test_transaction") 18 | assert.Nil(err) 19 | dbDir := path.Join(workspaceDir, "db") 20 | defer os.RemoveAll(workspaceDir) 21 | 22 | db := server.NewSimpleDB(dbDir, 400, 8) 23 | fm := db.FileManager() 24 | bm := db.BufferManager() 25 | lm := db.LogManager() 26 | 27 | tx1 := recovery.NewTransaction(fm, lm, bm) 28 | blockId := file.NewBlockId("testfile", 1) 29 | tx1.Pin(blockId) 30 | // The block initially contains unknown bytes, 31 | // so don't log those values here. 32 | tx1.SetInt(blockId, 80, 1, false) 33 | tx1.SetString(blockId, 40, "one", false) 34 | tx1.Commit() 35 | 36 | tx2 := recovery.NewTransaction(fm, lm, bm) 37 | tx2.Pin(blockId) 38 | iVal, err := tx2.GetInt(blockId, 80) 39 | assert.Nil(err) 40 | assert.Equal(int64(1), iVal) 41 | 42 | sVal, err := tx2.GetString(blockId, 40) 43 | assert.Nil(err) 44 | assert.Equal("one", sVal) 45 | 46 | iVal = iVal + 1 47 | sVal = sVal + "!" 
48 | tx2.SetInt(blockId, 80, iVal, true) 49 | tx2.SetString(blockId, 40, sVal, true) 50 | tx2.Commit() 51 | 52 | tx3 := recovery.NewTransaction(fm, lm, bm) 53 | tx3.Pin(blockId) 54 | iVal, err = tx3.GetInt(blockId, 80) 55 | assert.Nil(err) 56 | assert.Equal(int64(2), iVal) 57 | 58 | sVal, err = tx3.GetString(blockId, 40) 59 | assert.Nil(err) 60 | assert.Equal("one!", sVal) 61 | 62 | tx3.SetInt(blockId, 80, 9999, true) 63 | iVal, _ = tx3.GetInt(blockId, 80) 64 | assert.Equal(int64(9999), iVal) 65 | tx3.Rollback() 66 | 67 | tx4 := recovery.NewTransaction(fm, lm, bm) 68 | tx4.Pin(blockId) 69 | iVal, _ = tx4.GetInt(blockId, 80) 70 | assert.Equal(int64(2), iVal) 71 | tx4.Commit() 72 | } 73 | -------------------------------------------------------------------------------- /internal/utils/const.go: -------------------------------------------------------------------------------- 1 | package utils 2 | -------------------------------------------------------------------------------- /internal/utils/utils.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | ) 7 | 8 | func WaitCondWithTimeout(cond *sync.Cond, timeout time.Duration) { 9 | go func(cond *sync.Cond, timeout time.Duration) { 10 | time.Sleep(timeout) 11 | cond.Broadcast() 12 | }(cond, timeout) 13 | cond.Wait() 14 | } 15 | -------------------------------------------------------------------------------- /screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evanxg852000/simpledb-go/9ab9079c4ec66200f8291fe8af4fbddd891a1b9a/screenshot.png --------------------------------------------------------------------------------