├── .circleci └── config.yml ├── .gitignore ├── LICENSE.md ├── Makefile ├── README.md ├── ast.go ├── ast_test.go ├── backend.go ├── cmd ├── indextest │ └── main.go ├── libraryexample │ └── main.go ├── main.go └── sqlexample │ └── main.go ├── driver.go ├── error.go ├── go.mod ├── go.sum ├── lexer.go ├── lexer_test.go ├── memory.go ├── memory_test.go ├── parser.go ├── parser_test.go └── repl.go /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | jobs: 3 | build: 4 | docker: 5 | - image: circleci/golang:1.14 6 | 7 | working_directory: /go/src/github.com/eatonphil/gosql 8 | steps: 9 | - checkout 10 | 11 | # Install golangci-lint 12 | - run: curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v1.25.0 13 | - run: go get -v -t -d ./... 14 | - run: make test 15 | - run: make lint 16 | # Fail if there's a gofmt diff 17 | - run: bash -c '[[ $(gofmt -l .) ]] && exit 1 || exit 0' 18 | - run: go build cmd/main.go -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | main 2 | coverage.out -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright 2020 Phil Eaton 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or 
substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | fmt: 2 | gofmt -w -s . 3 | 4 | test: 5 | go test -race -cover -coverprofile=coverage.out . 6 | 7 | cover: 8 | go tool cover -func=coverage.out 9 | 10 | lint: 11 | go vet . 12 | golangci-lint run 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # gosql 2 | 3 | An early PostgreSQL implementation in Go. 4 | 5 | [![gosql](https://circleci.com/gh/eatonphil/gosql.svg?style=svg)](https://circleci.com/gh/eatonphil/gosql) 6 | 7 | ## Example 8 | 9 | ```bash 10 | $ git clone git@github.com:eatonphil/gosql 11 | $ cd gosql 12 | $ go run cmd/main.go 13 | Welcome to gosql. 
14 | # CREATE TABLE users (id INT PRIMARY KEY, name TEXT, age INT); 15 | ok 16 | # \d users 17 | Table "users" 18 | Column | Type | Nullable 19 | ---------+---------+----------- 20 | id | integer | not null 21 | name | text | 22 | age | integer | 23 | Indexes: 24 | "users_pkey" PRIMARY KEY, rbtree ("id") 25 | 26 | # INSERT INTO users VALUES (1, 'Corey', 34); 27 | ok 28 | # INSERT INTO users VALUES (1, 'Max', 29); 29 | Error inserting values: Duplicate key value violates unique constraint 30 | # INSERT INTO users VALUES (2, 'Max', 29); 31 | ok 32 | # SELECT * FROM users WHERE id = 2; 33 | id | name | age 34 | -----+------+------ 35 | 2 | Max | 29 36 | (1 result) 37 | ok 38 | # SELECT id, name, age + 3 FROM users WHERE id = 2 OR id = 1; 39 | id | name | ?column? 40 | -----+-------+----------- 41 | 1 | Corey | 37 42 | 2 | Max | 32 43 | (2 results) 44 | ok 45 | ``` 46 | 47 | ## Using the database/sql driver 48 | 49 | See cmd/sqlexample/main.go: 50 | 51 | ```go 52 | package main 53 | 54 | import ( 55 | "database/sql" 56 | "fmt" 57 | 58 | _ "github.com/eatonphil/gosql" 59 | ) 60 | 61 | func main() { 62 | db, err := sql.Open("postgres", "") 63 | if err != nil { 64 | panic(err) 65 | } 66 | defer db.Close() 67 | 68 | _, err = db.Query("CREATE TABLE users (name TEXT, age INT);") 69 | if err != nil { 70 | panic(err) 71 | } 72 | 73 | _, err = db.Query("INSERT INTO users VALUES ('Terry', 45);") 74 | if err != nil { 75 | panic(err) 76 | } 77 | 78 | _, err = db.Query("INSERT INTO users VALUES ('Anette', 57);") 79 | if err != nil { 80 | panic(err) 81 | } 82 | 83 | rows, err := db.Query("SELECT name, age FROM users;") 84 | if err != nil { 85 | panic(err) 86 | } 87 | 88 | var name string 89 | var age uint64 90 | defer rows.Close() 91 | for rows.Next() { 92 | err := rows.Scan(&name, &age) 93 | if err != nil { 94 | panic(err) 95 | } 96 | 97 | fmt.Printf("Name: %s, Age: %d\n", name, age) 98 | } 99 | 100 | if err = rows.Err(); err != nil { 101 | panic(err) 102 | } 103 | } 104 | ``` 105 
| 106 | Parameterization is not currently supported. 107 | 108 | ## Architecture 109 | 110 | * [cmd/main.go](./cmd/main.go) 111 | * Contains the REPL and high-level interface to the project 112 | * Dataflow is: user input -> lexer -> parser -> in-memory backend 113 | * [lexer.go](./lexer.go) 114 | * Handles breaking user input into tokens for the parser 115 | * [parser.go](./parser.go) 116 | * Parses a list of tokens into an AST or fails if the user input is not a valid program 117 | * [memory.go](./memory.go) 118 | * An example in-memory backend supporting the Backend interface (defined in backend.go) 119 | 120 | ## Contributing 121 | 122 | * Add a new operator (such as `-`, `*`, etc.) 123 | * Add a new data type (such as `VARCHAR(n)`) 124 | 125 | In each case, you'll probably have to add support in the lexer, 126 | parser, and in-memory backend. I recommend going in that order. 127 | 128 | In all cases, make sure the code is formatted (`make fmt`), linted 129 | (`make lint`) and passes tests (`make test`). New code should have 130 | tests. 131 | 132 | ## Blog series 133 | 134 | * [Writing a SQL database from scratch in Go](https://notes.eatonphil.com/database-basics.html) 135 | * [Binary expressions and WHERE filters](https://notes.eatonphil.com/database-basics-expressions-and-where.html) 136 | * [Indexes](https://notes.eatonphil.com/database-basics-indexes.html) 137 | * [A database/sql driver](https://notes.eatonphil.com/database-basics-a-database-sql-driver.html) 138 | 139 | ## Further reading 140 | 141 | Here are some similar projects written in Go. 142 | 143 | * [go-mysql-server](https://github.com/src-d/go-mysql-server) 144 | * This is a MySQL frontend (with an in-memory backend for testing only). 145 | * [ramsql](https://github.com/proullon/ramsql) 146 | * This is a WIP PostgreSQL-compatible in-memory database. 147 | * [CockroachDB](https://github.com/cockroachdb/cockroach) 148 | * This is a production-ready PostgreSQL-compatible database.
149 | -------------------------------------------------------------------------------- /ast.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | type ExpressionKind uint 9 | 10 | const ( 11 | LiteralKind ExpressionKind = iota 12 | BinaryKind 13 | ) 14 | 15 | type BinaryExpression struct { 16 | A Expression 17 | B Expression 18 | Op Token 19 | } 20 | 21 | func (be BinaryExpression) GenerateCode() string { 22 | return fmt.Sprintf("(%s %s %s)", be.A.GenerateCode(), be.Op.Value, be.B.GenerateCode()) 23 | } 24 | 25 | type Expression struct { 26 | Literal *Token 27 | Binary *BinaryExpression 28 | Kind ExpressionKind 29 | } 30 | 31 | func (e Expression) GenerateCode() string { 32 | switch e.Kind { 33 | case LiteralKind: 34 | switch e.Literal.Kind { 35 | case IdentifierKind: 36 | return fmt.Sprintf("\"%s\"", e.Literal.Value) 37 | case StringKind: 38 | return fmt.Sprintf("'%s'", e.Literal.Value) 39 | default: 40 | return fmt.Sprintf(e.Literal.Value) 41 | } 42 | 43 | case BinaryKind: 44 | return e.Binary.GenerateCode() 45 | } 46 | 47 | return "" 48 | } 49 | 50 | type SelectItem struct { 51 | Exp *Expression 52 | Asterisk bool // for * 53 | As *Token 54 | } 55 | 56 | type SelectStatement struct { 57 | Item *[]*SelectItem 58 | From *Token 59 | Where *Expression 60 | Limit *Expression 61 | Offset *Expression 62 | } 63 | 64 | func (ss SelectStatement) GenerateCode() string { 65 | item := []string{} 66 | for _, i := range *ss.Item { 67 | s := "\t*" 68 | if !i.Asterisk { 69 | s = "\t" + i.Exp.GenerateCode() 70 | 71 | if i.As != nil { 72 | s = fmt.Sprintf("\t%s AS \"%s\"", s, i.As.Value) 73 | } 74 | } 75 | item = append(item, s) 76 | } 77 | 78 | code := "SELECT\n" + strings.Join(item, ",\n") 79 | if ss.From != nil { 80 | code += fmt.Sprintf("\nFROM\n\t\"%s\"", ss.From.Value) 81 | } 82 | 83 | if ss.Where != nil { 84 | code += "\nWHERE\n\t" + ss.Where.GenerateCode() 85 | } 86 | 87 
| if ss.Limit != nil { 88 | code += "\nLIMIT\n\t" + ss.Limit.GenerateCode() 89 | } 90 | 91 | if ss.Offset != nil { 92 | code += "\nOFFSET\n\t" + ss.Limit.GenerateCode() 93 | } 94 | 95 | return code + ";" 96 | } 97 | 98 | type ColumnDefinition struct { 99 | Name Token 100 | Datatype Token 101 | PrimaryKey bool 102 | } 103 | 104 | type CreateTableStatement struct { 105 | Name Token 106 | Cols *[]*ColumnDefinition 107 | } 108 | 109 | func (cts CreateTableStatement) GenerateCode() string { 110 | cols := []string{} 111 | for _, col := range *cts.Cols { 112 | modifiers := "" 113 | if col.PrimaryKey { 114 | modifiers += " " + "PRIMARY KEY" 115 | } 116 | spec := fmt.Sprintf("\t\"%s\" %s%s", col.Name.Value, strings.ToUpper(col.Datatype.Value), modifiers) 117 | cols = append(cols, spec) 118 | } 119 | return fmt.Sprintf("CREATE TABLE \"%s\" (\n%s\n);", cts.Name.Value, strings.Join(cols, ",\n")) 120 | } 121 | 122 | type CreateIndexStatement struct { 123 | Name Token 124 | Unique bool 125 | PrimaryKey bool 126 | Table Token 127 | Exp Expression 128 | } 129 | 130 | func (cis CreateIndexStatement) GenerateCode() string { 131 | unique := "" 132 | if cis.Unique { 133 | unique = " UNIQUE" 134 | } 135 | return fmt.Sprintf("CREATE%s INDEX \"%s\" ON \"%s\" (%s);", unique, cis.Name.Value, cis.Table.Value, cis.Exp.GenerateCode()) 136 | } 137 | 138 | type DropTableStatement struct { 139 | Name Token 140 | } 141 | 142 | func (dts DropTableStatement) GenerateCode() string { 143 | return fmt.Sprintf("DROP TABLE \"%s\";", dts.Name.Value) 144 | } 145 | 146 | type InsertStatement struct { 147 | Table Token 148 | Values *[]*Expression 149 | } 150 | 151 | func (is InsertStatement) GenerateCode() string { 152 | values := []string{} 153 | for _, exp := range *is.Values { 154 | values = append(values, exp.GenerateCode()) 155 | } 156 | return fmt.Sprintf("INSERT INTO \"%s\" VALUES (%s);", is.Table.Value, strings.Join(values, ", ")) 157 | } 158 | 159 | type AstKind uint 160 | 161 | const ( 162 | 
SelectKind AstKind = iota 163 | CreateTableKind 164 | CreateIndexKind 165 | DropTableKind 166 | InsertKind 167 | ) 168 | 169 | type Statement struct { 170 | SelectStatement *SelectStatement 171 | CreateTableStatement *CreateTableStatement 172 | CreateIndexStatement *CreateIndexStatement 173 | DropTableStatement *DropTableStatement 174 | InsertStatement *InsertStatement 175 | Kind AstKind 176 | } 177 | 178 | func (s Statement) GenerateCode() string { 179 | switch s.Kind { 180 | case SelectKind: 181 | return s.SelectStatement.GenerateCode() 182 | case CreateTableKind: 183 | return s.CreateTableStatement.GenerateCode() 184 | case CreateIndexKind: 185 | return s.CreateIndexStatement.GenerateCode() 186 | case DropTableKind: 187 | return s.DropTableStatement.GenerateCode() 188 | case InsertKind: 189 | return s.InsertStatement.GenerateCode() 190 | } 191 | 192 | return "?unknown?" 193 | } 194 | 195 | type Ast struct { 196 | Statements []*Statement 197 | } 198 | -------------------------------------------------------------------------------- /ast_test.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestStatement_GenerateCode(t *testing.T) { 10 | tests := []struct { 11 | result string 12 | stmt Statement 13 | }{ 14 | { 15 | `DROP TABLE "foo";`, 16 | Statement{ 17 | DropTableStatement: &DropTableStatement{ 18 | Name: Token{Value: "foo"}, 19 | }, 20 | Kind: DropTableKind, 21 | }, 22 | }, 23 | { 24 | `CREATE TABLE "users" ( 25 | "id" INT PRIMARY KEY, 26 | "name" TEXT 27 | );`, 28 | Statement{ 29 | CreateTableStatement: &CreateTableStatement{ 30 | Name: Token{Value: "users"}, 31 | Cols: &[]*ColumnDefinition{ 32 | { 33 | Name: Token{Value: "id"}, 34 | Datatype: Token{Value: "int"}, 35 | PrimaryKey: true, 36 | }, 37 | { 38 | Name: Token{Value: "name"}, 39 | Datatype: Token{Value: "text"}, 40 | }, 41 | }, 42 | }, 43 | Kind: 
CreateTableKind, 44 | }, 45 | }, 46 | { 47 | `CREATE UNIQUE INDEX "age_idx" ON "users" ("age");`, 48 | Statement{ 49 | CreateIndexStatement: &CreateIndexStatement{ 50 | Name: Token{Value: "age_idx"}, 51 | Unique: true, 52 | Table: Token{Value: "users"}, 53 | Exp: Expression{Literal: &Token{Value: "age", Kind: IdentifierKind}, Kind: LiteralKind}, 54 | }, 55 | Kind: CreateIndexKind, 56 | }, 57 | }, 58 | { 59 | `INSERT INTO "foo" VALUES (1, 'flubberty', true);`, 60 | Statement{ 61 | InsertStatement: &InsertStatement{ 62 | Table: Token{Value: "foo"}, 63 | Values: &[]*Expression{ 64 | {Literal: &Token{Value: "1", Kind: NumericKind}, Kind: LiteralKind}, 65 | {Literal: &Token{Value: "flubberty", Kind: StringKind}, Kind: LiteralKind}, 66 | {Literal: &Token{Value: "true", Kind: BoolKind}, Kind: LiteralKind}, 67 | }, 68 | }, 69 | Kind: InsertKind, 70 | }, 71 | }, 72 | { 73 | `SELECT 74 | "id", 75 | "name" 76 | FROM 77 | "users" 78 | WHERE 79 | ("id" = 2);`, 80 | Statement{ 81 | SelectStatement: &SelectStatement{ 82 | Item: &[]*SelectItem{ 83 | {Exp: &Expression{Literal: &Token{Value: "id", Kind: IdentifierKind}, Kind: LiteralKind}}, 84 | {Exp: &Expression{Literal: &Token{Value: "name", Kind: IdentifierKind}, Kind: LiteralKind}}, 85 | }, 86 | From: &Token{Value: "users"}, 87 | Where: &Expression{ 88 | Binary: &BinaryExpression{ 89 | A: Expression{Literal: &Token{Value: "id", Kind: IdentifierKind}, Kind: LiteralKind}, 90 | B: Expression{Literal: &Token{Value: "2", Kind: NumericKind}, Kind: LiteralKind}, 91 | Op: Token{Value: "=", Kind: SymbolKind}, 92 | }, 93 | Kind: BinaryKind, 94 | }, 95 | }, 96 | Kind: SelectKind, 97 | }, 98 | }, 99 | } 100 | 101 | for _, test := range tests { 102 | assert.Equal(t, test.result, test.stmt.GenerateCode()) 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /backend.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import "errors" 4 
| 5 | type ColumnType uint 6 | 7 | const ( 8 | TextType ColumnType = iota 9 | IntType 10 | BoolType 11 | ) 12 | 13 | func (c ColumnType) String() string { 14 | switch c { 15 | case TextType: 16 | return "TextType" 17 | case IntType: 18 | return "IntType" 19 | case BoolType: 20 | return "BoolType" 21 | default: 22 | return "Error" 23 | } 24 | } 25 | 26 | type Cell interface { 27 | AsText() *string 28 | AsInt() *int32 29 | AsBool() *bool 30 | } 31 | 32 | type Results struct { 33 | Columns []ResultColumn 34 | Rows [][]Cell 35 | } 36 | 37 | type ResultColumn struct { 38 | Type ColumnType 39 | Name string 40 | NotNull bool 41 | } 42 | 43 | type Index struct { 44 | Name string 45 | Exp string 46 | Type string 47 | Unique bool 48 | PrimaryKey bool 49 | } 50 | 51 | type TableMetadata struct { 52 | Name string 53 | Columns []ResultColumn 54 | Indexes []Index 55 | } 56 | 57 | type Backend interface { 58 | CreateTable(*CreateTableStatement) error 59 | DropTable(*DropTableStatement) error 60 | CreateIndex(*CreateIndexStatement) error 61 | Insert(*InsertStatement) error 62 | Select(*SelectStatement) (*Results, error) 63 | GetTables() []TableMetadata 64 | } 65 | 66 | // Useful to embed when prototyping new backends 67 | type EmptyBackend struct{} 68 | 69 | func (eb EmptyBackend) CreateTable(_ *CreateTableStatement) error { 70 | return errors.New("Create not supported") 71 | } 72 | 73 | func (eb EmptyBackend) DropTable(_ *DropTableStatement) error { 74 | return errors.New("Drop not supported") 75 | } 76 | 77 | func (eb EmptyBackend) CreateIndex(_ *CreateIndexStatement) error { 78 | return errors.New("Create index not supported") 79 | } 80 | 81 | func (eb EmptyBackend) Insert(_ *InsertStatement) error { 82 | return errors.New("Insert not supported") 83 | } 84 | 85 | func (eb EmptyBackend) Select(_ *SelectStatement) (*Results, error) { 86 | return nil, errors.New("Select not supported") 87 | } 88 | 89 | func (eb EmptyBackend) GetTables() []TableMetadata { 90 | return nil 91 | } 92 
| -------------------------------------------------------------------------------- /cmd/indextest/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "runtime" 7 | "strconv" 8 | "time" 9 | 10 | "github.com/eatonphil/gosql" 11 | ) 12 | 13 | var inserts = 0 14 | var lastId = 0 15 | var firstId = 0 16 | 17 | func doInsert(mb gosql.Backend) { 18 | parser := gosql.Parser{} 19 | for i := 0; i < inserts; i++ { 20 | lastId = i 21 | if i == 0 { 22 | firstId = lastId 23 | } 24 | ast, err := parser.Parse(fmt.Sprintf("INSERT INTO users VALUES (%d)", lastId)) 25 | if err != nil { 26 | panic(err) 27 | } 28 | 29 | err = mb.Insert(ast.Statements[0].InsertStatement) 30 | if err != nil { 31 | panic(err) 32 | } 33 | } 34 | } 35 | 36 | func doSelect(mb gosql.Backend) { 37 | parser := gosql.Parser{} 38 | ast, err := parser.Parse(fmt.Sprintf("SELECT id FROM users WHERE id = %d", lastId)) 39 | if err != nil { 40 | panic(err) 41 | } 42 | 43 | r, err := mb.Select(ast.Statements[0].SelectStatement) 44 | if err != nil { 45 | panic(err) 46 | } 47 | 48 | if len(r.Rows) != 1 { 49 | panic("Expected 1 row") 50 | } 51 | 52 | if int(*r.Rows[0][0].AsInt()) != inserts-1 { 53 | panic(fmt.Sprintf("Bad row, got: %d", r.Rows[0][1].AsInt())) 54 | } 55 | 56 | ast, err = parser.Parse(fmt.Sprintf("SELECT id FROM users WHERE id = %d", firstId)) 57 | if err != nil { 58 | panic(err) 59 | } 60 | 61 | r, err = mb.Select(ast.Statements[0].SelectStatement) 62 | if err != nil { 63 | panic(err) 64 | } 65 | 66 | if len(r.Rows) != 1 { 67 | panic("Expected 1 row") 68 | } 69 | 70 | if int(*r.Rows[0][0].AsInt()) != 0 { 71 | panic(fmt.Sprintf("Bad row, got: %d", r.Rows[0][1].AsInt())) 72 | } 73 | } 74 | 75 | func perf(name string, b gosql.Backend, cb func(b gosql.Backend)) { 76 | start := time.Now() 77 | fmt.Println("Starting", name) 78 | cb(b) 79 | fmt.Printf("Finished %s: %f seconds\n", name, time.Since(start).Seconds()) 80 | 
81 | var m runtime.MemStats 82 | runtime.ReadMemStats(&m) 83 | fmt.Printf("Alloc = %d MiB\n\n", m.Alloc/1024/1024) 84 | } 85 | 86 | func main() { 87 | mb := gosql.NewMemoryBackend() 88 | 89 | index := false 90 | for i, arg := range os.Args { 91 | if arg == "--with-index" { 92 | index = true 93 | } 94 | 95 | if arg == "--inserts" { 96 | inserts, _ = strconv.Atoi(os.Args[i+1]) 97 | } 98 | } 99 | 100 | primaryKey := "" 101 | if index { 102 | primaryKey = " PRIMARY KEY" 103 | } 104 | 105 | parser := gosql.Parser{} 106 | ast, err := parser.Parse(fmt.Sprintf("CREATE TABLE users (id INT%s)", primaryKey)) 107 | if err != nil { 108 | panic(err) 109 | } 110 | 111 | err = mb.CreateTable(ast.Statements[0].CreateTableStatement) 112 | if err != nil { 113 | panic(err) 114 | } 115 | 116 | indexingString := " with indexing enabled" 117 | if !index { 118 | indexingString = "" 119 | } 120 | fmt.Printf("Inserting %d rows%s\n", inserts, indexingString) 121 | 122 | perf("INSERT", mb, doInsert) 123 | 124 | perf("SELECT", mb, doSelect) 125 | } 126 | -------------------------------------------------------------------------------- /cmd/libraryexample/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/eatonphil/gosql" 7 | ) 8 | 9 | func main() { 10 | mb := gosql.NewMemoryBackend() 11 | 12 | parser := gosql.Parser{} 13 | ast, err := parser.Parse("CREATE TABLE users (id INT, name TEXT); INSERT INTO users VALUES (1, 'Admin'); SELECT id, name FROM users") 14 | if err != nil { 15 | panic(err) 16 | } 17 | 18 | err = mb.CreateTable(ast.Statements[0].CreateTableStatement) 19 | if err != nil { 20 | panic(err) 21 | } 22 | 23 | err = mb.Insert(ast.Statements[1].InsertStatement) 24 | if err != nil { 25 | panic(err) 26 | } 27 | 28 | results, err := mb.Select(ast.Statements[2].SelectStatement) 29 | if err != nil { 30 | panic(err) 31 | } 32 | 33 | for _, col := range results.Columns { 34 | fmt.Printf("| 
%s ", col.Name) 35 | } 36 | fmt.Println("|") 37 | 38 | for i := 0; i < 20; i++ { 39 | fmt.Printf("=") 40 | } 41 | fmt.Println() 42 | 43 | for _, result := range results.Rows { 44 | fmt.Printf("|") 45 | 46 | for i, cell := range result { 47 | typ := results.Columns[i].Type 48 | s := "" 49 | switch typ { 50 | case gosql.IntType: 51 | s = fmt.Sprintf("%d", *cell.AsInt()) 52 | case gosql.TextType: 53 | s = *cell.AsText() 54 | } 55 | 56 | fmt.Printf(" %s | ", s) 57 | } 58 | 59 | fmt.Println() 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /cmd/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/eatonphil/gosql" 5 | ) 6 | 7 | func main() { 8 | mb := gosql.NewMemoryBackend() 9 | 10 | gosql.RunRepl(mb) 11 | } 12 | -------------------------------------------------------------------------------- /cmd/sqlexample/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | 7 | _ "github.com/eatonphil/gosql" 8 | ) 9 | 10 | func main() { 11 | db, err := sql.Open("postgres", "") 12 | if err != nil { 13 | panic(err) 14 | } 15 | defer db.Close() 16 | 17 | _, err = db.Query("CREATE TABLE users (name TEXT, age INT);") 18 | if err != nil { 19 | panic(err) 20 | } 21 | 22 | _, err = db.Query("INSERT INTO users VALUES ('Terry', 45);") 23 | if err != nil { 24 | panic(err) 25 | } 26 | 27 | _, err = db.Query("INSERT INTO users VALUES ('Anette', 57);") 28 | if err != nil { 29 | panic(err) 30 | } 31 | 32 | rows, err := db.Query("SELECT name, age FROM users;") 33 | if err != nil { 34 | panic(err) 35 | } 36 | 37 | var name string 38 | var age uint64 39 | defer rows.Close() 40 | for rows.Next() { 41 | err := rows.Scan(&name, &age) 42 | if err != nil { 43 | panic(err) 44 | } 45 | 46 | fmt.Printf("Name: %s, Age: %d\n", name, age) 47 | } 48 | 49 | if err = 
rows.Err(); err != nil { 50 | panic(err) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /driver.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "database/sql" 5 | "database/sql/driver" 6 | "fmt" 7 | "io" 8 | ) 9 | 10 | type Rows struct { 11 | columns []ResultColumn 12 | index uint64 13 | rows [][]Cell 14 | } 15 | 16 | func (r *Rows) Columns() []string { 17 | columns := []string{} 18 | for _, c := range r.columns { 19 | columns = append(columns, c.Name) 20 | } 21 | 22 | return columns 23 | } 24 | 25 | func (r *Rows) Close() error { 26 | r.index = uint64(len(r.rows)) 27 | return nil 28 | } 29 | 30 | func (r *Rows) Next(dest []driver.Value) error { 31 | if r.index >= uint64(len(r.rows)) { 32 | return io.EOF 33 | } 34 | 35 | row := r.rows[r.index] 36 | 37 | for idx, cell := range row { 38 | typ := r.columns[idx].Type 39 | switch typ { 40 | case IntType: 41 | i := cell.AsInt() 42 | if i == nil { 43 | dest[idx] = i 44 | } else { 45 | dest[idx] = *i 46 | } 47 | case TextType: 48 | s := cell.AsText() 49 | if s == nil { 50 | dest[idx] = s 51 | } else { 52 | dest[idx] = *s 53 | } 54 | case BoolType: 55 | b := cell.AsBool() 56 | if b == nil { 57 | dest[idx] = b 58 | } else { 59 | dest[idx] = b 60 | } 61 | } 62 | } 63 | 64 | r.index++ 65 | return nil 66 | } 67 | 68 | type Conn struct { 69 | bkd Backend 70 | } 71 | 72 | func (dc *Conn) Query(query string, args []driver.Value) (driver.Rows, error) { 73 | if len(args) > 0 { 74 | // TODO: support parameterization 75 | panic("Parameterization not supported") 76 | } 77 | 78 | parser := Parser{} 79 | ast, err := parser.Parse(query) 80 | if err != nil { 81 | return nil, fmt.Errorf("Error while parsing: %s", err) 82 | } 83 | 84 | // NOTE: ignorning all but the first statement 85 | stmt := ast.Statements[0] 86 | switch stmt.Kind { 87 | case CreateIndexKind: 88 | err = 
dc.bkd.CreateIndex(stmt.CreateIndexStatement) 89 | if err != nil { 90 | return nil, fmt.Errorf("Error adding index on table: %s", err) 91 | } 92 | case CreateTableKind: 93 | err = dc.bkd.CreateTable(stmt.CreateTableStatement) 94 | if err != nil { 95 | return nil, fmt.Errorf("Error creating table: %s", err) 96 | } 97 | case DropTableKind: 98 | err = dc.bkd.DropTable(stmt.DropTableStatement) 99 | if err != nil { 100 | return nil, fmt.Errorf("Error dropping table: %s", err) 101 | } 102 | case InsertKind: 103 | err = dc.bkd.Insert(stmt.InsertStatement) 104 | if err != nil { 105 | return nil, fmt.Errorf("Error inserting values: %s", err) 106 | } 107 | case SelectKind: 108 | results, err := dc.bkd.Select(stmt.SelectStatement) 109 | if err != nil { 110 | return nil, err 111 | } 112 | 113 | return &Rows{ 114 | rows: results.Rows, 115 | columns: results.Columns, 116 | index: 0, 117 | }, nil 118 | } 119 | 120 | return nil, nil 121 | } 122 | 123 | func (dc *Conn) Prepare(query string) (driver.Stmt, error) { 124 | panic("Prepare not implemented") 125 | } 126 | 127 | func (dc *Conn) Begin() (driver.Tx, error) { 128 | panic("Begin not implemented") 129 | } 130 | 131 | func (dc *Conn) Close() error { 132 | return nil 133 | } 134 | 135 | type Driver struct { 136 | bkd Backend 137 | } 138 | 139 | func (d *Driver) Open(name string) (driver.Conn, error) { 140 | return &Conn{d.bkd}, nil 141 | } 142 | 143 | func init() { 144 | sql.Register("postgres", &Driver{NewMemoryBackend()}) 145 | } 146 | -------------------------------------------------------------------------------- /error.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import "errors" 4 | 5 | var ( 6 | ErrTableDoesNotExist = errors.New("Table does not exist") 7 | ErrTableAlreadyExists = errors.New("Table already exists") 8 | ErrIndexAlreadyExists = errors.New("Index already exists") 9 | ErrViolatesUniqueConstraint = errors.New("Duplicate key value violates unique 
constraint") 10 | ErrViolatesNotNullConstraint = errors.New("Value violates not null constraint") 11 | ErrColumnDoesNotExist = errors.New("Column does not exist") 12 | ErrInvalidSelectItem = errors.New("Select item is not valid") 13 | ErrInvalidDatatype = errors.New("Invalid datatype") 14 | ErrMissingValues = errors.New("Missing values") 15 | ErrInvalidCell = errors.New("Cell is invalid") 16 | ErrInvalidOperands = errors.New("Operands are invalid") 17 | ErrPrimaryKeyAlreadyExists = errors.New("Primary key already exists") 18 | ) 19 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/eatonphil/gosql 2 | 3 | go 1.13 4 | 5 | require ( 6 | github.com/chzyer/logex v1.1.10 // indirect 7 | github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e 8 | github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1 // indirect 9 | github.com/olekukonko/tablewriter v0.0.4 10 | github.com/petar/GoLLRB v0.0.0-20190514000832-33fb24c13b99 11 | github.com/stretchr/testify v1.5.1 12 | ) 13 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/chzyer/logex v1.1.10 h1:Swpa1K6QvQznwJRcfTfQJmTE72DqScAa40E+fbHEXEE= 2 | github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= 3 | github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e h1:fY5BOSpyZCqRo5OhCuC+XN+r/bBCmeuuJtjz+bCNIf8= 4 | github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= 5 | github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1 h1:q763qf9huN11kDQavWsoZXJNW3xEE4JJyHa5Q25/sd8= 6 | github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= 7 | github.com/davecgh/go-spew v1.1.0 
h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 8 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 9 | github.com/mattn/go-runewidth v0.0.7 h1:Ei8KR0497xHyKJPAv59M1dkC+rOZCMBJ+t3fZ+twI54= 10 | github.com/mattn/go-runewidth v0.0.7/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= 11 | github.com/olekukonko/tablewriter v0.0.4 h1:vHD/YYe1Wolo78koG299f7V/VAS08c6IpCLn+Ejf/w8= 12 | github.com/olekukonko/tablewriter v0.0.4/go.mod h1:zq6QwlOf5SlnkVbMSr5EoBv3636FWnp+qbPhuoO21uA= 13 | github.com/petar/GoLLRB v0.0.0-20190514000832-33fb24c13b99 h1:KcEvVBAvyHkUdFAygKAzwB6LAcZ6LS32WHmRD2VyXMI= 14 | github.com/petar/GoLLRB v0.0.0-20190514000832-33fb24c13b99/go.mod h1:HUpKUBZnpzkdx0kD/+Yfuft+uD3zHGtXF/XJB14TUr4= 15 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 16 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 17 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 18 | github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= 19 | github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= 20 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 21 | gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= 22 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 23 | -------------------------------------------------------------------------------- /lexer.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | // location of the token in source code 9 | type Location struct { 10 | Line uint 11 | Col uint 12 | } 13 | 14 | // for storing SQL reserved Keywords 15 | type Keyword string 16 | 17 | const ( 18 | SelectKeyword Keyword = "select" 19 | FromKeyword Keyword = 
"from" 20 | AsKeyword Keyword = "as" 21 | TableKeyword Keyword = "table" 22 | CreateKeyword Keyword = "create" 23 | DropKeyword Keyword = "drop" 24 | InsertKeyword Keyword = "insert" 25 | IntoKeyword Keyword = "into" 26 | ValuesKeyword Keyword = "values" 27 | IntKeyword Keyword = "int" 28 | TextKeyword Keyword = "text" 29 | BoolKeyword Keyword = "boolean" 30 | WhereKeyword Keyword = "where" 31 | AndKeyword Keyword = "and" 32 | OrKeyword Keyword = "or" 33 | TrueKeyword Keyword = "true" 34 | FalseKeyword Keyword = "false" 35 | UniqueKeyword Keyword = "unique" 36 | IndexKeyword Keyword = "index" 37 | OnKeyword Keyword = "on" 38 | PrimarykeyKeyword Keyword = "primary key" 39 | NullKeyword Keyword = "null" 40 | LimitKeyword Keyword = "limit" 41 | OffsetKeyword Keyword = "offset" 42 | ) 43 | 44 | // for storing SQL syntax 45 | type Symbol string 46 | 47 | const ( 48 | SemicolonSymbol Symbol = ";" 49 | AsteriskSymbol Symbol = "*" 50 | CommaSymbol Symbol = "," 51 | LeftParenSymbol Symbol = "(" 52 | RightParenSymbol Symbol = ")" 53 | EqSymbol Symbol = "=" 54 | NeqSymbol Symbol = "<>" 55 | NeqSymbol2 Symbol = "!=" 56 | ConcatSymbol Symbol = "||" 57 | PlusSymbol Symbol = "+" 58 | LtSymbol Symbol = "<" 59 | LteSymbol Symbol = "<=" 60 | GtSymbol Symbol = ">" 61 | GteSymbol Symbol = ">=" 62 | ) 63 | 64 | type TokenKind uint 65 | 66 | const ( 67 | KeywordKind TokenKind = iota 68 | SymbolKind 69 | IdentifierKind 70 | StringKind 71 | NumericKind 72 | BoolKind 73 | NullKind 74 | ) 75 | 76 | type Token struct { 77 | Value string 78 | Kind TokenKind 79 | Loc Location 80 | } 81 | 82 | func (t Token) bindingPower() uint { 83 | switch t.Kind { 84 | case KeywordKind: 85 | switch Keyword(t.Value) { 86 | case AndKeyword: 87 | fallthrough 88 | case OrKeyword: 89 | return 1 90 | } 91 | case SymbolKind: 92 | switch Symbol(t.Value) { 93 | case EqSymbol: 94 | fallthrough 95 | case NeqSymbol: 96 | return 2 97 | 98 | case LtSymbol: 99 | fallthrough 100 | case GtSymbol: 101 | return 3 102 | 103 | 
// For some reason these are grouped separately 104 | case LteSymbol: 105 | fallthrough 106 | case GteSymbol: 107 | return 4 108 | 109 | case ConcatSymbol: 110 | fallthrough 111 | case PlusSymbol: 112 | return 5 113 | } 114 | } 115 | 116 | return 0 117 | } 118 | 119 | func (t *Token) equals(other *Token) bool { 120 | return t.Value == other.Value && t.Kind == other.Kind 121 | } 122 | 123 | // cursor indicates the current position of the lexer 124 | type cursor struct { 125 | pointer uint 126 | loc Location 127 | } 128 | 129 | // longestMatch iterates through a source string starting at the given 130 | // cursor to find the longest matching substring among the provided 131 | // options 132 | func longestMatch(source string, ic cursor, options []string) string { 133 | var value []byte 134 | var skipList []int 135 | var match string 136 | 137 | cur := ic 138 | 139 | for cur.pointer < uint(len(source)) { 140 | 141 | value = append(value, strings.ToLower(string(source[cur.pointer]))...) 142 | cur.pointer++ 143 | 144 | match: 145 | for i, option := range options { 146 | for _, skip := range skipList { 147 | if i == skip { 148 | continue match 149 | } 150 | } 151 | 152 | // Deal with cases like INT vs INTO 153 | if option == string(value) { 154 | skipList = append(skipList, i) 155 | if len(option) > len(match) { 156 | match = option 157 | } 158 | 159 | continue 160 | } 161 | 162 | sharesPrefix := string(value) == option[:cur.pointer-ic.pointer] 163 | tooLong := len(value) > len(option) 164 | if tooLong || !sharesPrefix { 165 | skipList = append(skipList, i) 166 | } 167 | } 168 | 169 | if len(skipList) == len(options) { 170 | break 171 | } 172 | } 173 | 174 | return match 175 | } 176 | 177 | func lexSymbol(source string, ic cursor) (*Token, cursor, bool) { 178 | c := source[ic.pointer] 179 | cur := ic 180 | // Will get overwritten later if not an ignored syntax 181 | cur.pointer++ 182 | cur.loc.Col++ 183 | 184 | switch c { 185 | // Syntax that should be thrown away 186 | 
case '\n': 187 | cur.loc.Line++ 188 | cur.loc.Col = 0 189 | fallthrough 190 | case '\t': 191 | fallthrough 192 | case ' ': 193 | return nil, cur, true 194 | } 195 | 196 | // Syntax that should be kept 197 | Symbols := []Symbol{ 198 | EqSymbol, 199 | NeqSymbol, 200 | NeqSymbol2, 201 | LtSymbol, 202 | LteSymbol, 203 | GtSymbol, 204 | GteSymbol, 205 | ConcatSymbol, 206 | PlusSymbol, 207 | CommaSymbol, 208 | LeftParenSymbol, 209 | RightParenSymbol, 210 | SemicolonSymbol, 211 | AsteriskSymbol, 212 | } 213 | 214 | var options []string 215 | for _, s := range Symbols { 216 | options = append(options, string(s)) 217 | } 218 | 219 | // Use `ic`, not `cur` 220 | match := longestMatch(source, ic, options) 221 | // Unknown character 222 | if match == "" { 223 | return nil, ic, false 224 | } 225 | 226 | cur.pointer = ic.pointer + uint(len(match)) 227 | cur.loc.Col = ic.loc.Col + uint(len(match)) 228 | 229 | // != is rewritten as <>: https://www.postgresql.org/docs/9.5/functions-comparison.html 230 | if match == string(NeqSymbol2) { 231 | match = string(NeqSymbol) 232 | } 233 | 234 | return &Token{ 235 | Value: match, 236 | Loc: ic.loc, 237 | Kind: SymbolKind, 238 | }, cur, true 239 | } 240 | 241 | func lexKeyword(source string, ic cursor) (*Token, cursor, bool) { 242 | cur := ic 243 | Keywords := []Keyword{ 244 | SelectKeyword, 245 | InsertKeyword, 246 | ValuesKeyword, 247 | TableKeyword, 248 | CreateKeyword, 249 | DropKeyword, 250 | WhereKeyword, 251 | FromKeyword, 252 | IntoKeyword, 253 | TextKeyword, 254 | BoolKeyword, 255 | IntKeyword, 256 | AndKeyword, 257 | OrKeyword, 258 | AsKeyword, 259 | TrueKeyword, 260 | FalseKeyword, 261 | UniqueKeyword, 262 | IndexKeyword, 263 | OnKeyword, 264 | PrimarykeyKeyword, 265 | NullKeyword, 266 | LimitKeyword, 267 | OffsetKeyword, 268 | } 269 | 270 | var options []string 271 | for _, k := range Keywords { 272 | options = append(options, string(k)) 273 | } 274 | 275 | match := longestMatch(source, ic, options) 276 | if match == "" { 277 | 
return nil, ic, false 278 | } 279 | 280 | cur.pointer = ic.pointer + uint(len(match)) 281 | cur.loc.Col = ic.loc.Col + uint(len(match)) 282 | 283 | Kind := KeywordKind 284 | if match == string(TrueKeyword) || match == string(FalseKeyword) { 285 | Kind = BoolKind 286 | } 287 | 288 | if match == string(NullKeyword) { 289 | Kind = NullKind 290 | } 291 | 292 | return &Token{ 293 | Value: match, 294 | Kind: Kind, 295 | Loc: ic.loc, 296 | }, cur, true 297 | } 298 | 299 | func lexNumeric(source string, ic cursor) (*Token, cursor, bool) { 300 | cur := ic 301 | 302 | periodFound := false 303 | expMarkerFound := false 304 | 305 | for ; cur.pointer < uint(len(source)); cur.pointer++ { 306 | c := source[cur.pointer] 307 | cur.loc.Col++ 308 | 309 | isDigit := c >= '0' && c <= '9' 310 | isPeriod := c == '.' 311 | isExpMarker := c == 'e' 312 | 313 | // Must start with a digit or period 314 | if cur.pointer == ic.pointer { 315 | if !isDigit && !isPeriod { 316 | return nil, ic, false 317 | } 318 | 319 | periodFound = isPeriod 320 | continue 321 | } 322 | 323 | if isPeriod { 324 | if periodFound { 325 | return nil, ic, false 326 | } 327 | 328 | periodFound = true 329 | continue 330 | } 331 | 332 | if isExpMarker { 333 | if expMarkerFound { 334 | return nil, ic, false 335 | } 336 | 337 | // No periods allowed after expMarker 338 | periodFound = true 339 | expMarkerFound = true 340 | 341 | // expMarker must be followed by digits 342 | if cur.pointer == uint(len(source)-1) { 343 | return nil, ic, false 344 | } 345 | 346 | cNext := source[cur.pointer+1] 347 | if cNext == '-' || cNext == '+' { 348 | cur.pointer++ 349 | cur.loc.Col++ 350 | } 351 | continue 352 | } 353 | 354 | if !isDigit { 355 | break 356 | } 357 | } 358 | 359 | // No characters accumulated 360 | if cur.pointer == ic.pointer { 361 | return nil, ic, false 362 | } 363 | 364 | return &Token{ 365 | Value: source[ic.pointer:cur.pointer], 366 | Loc: ic.loc, 367 | Kind: NumericKind, 368 | }, cur, true 369 | } 370 | 371 | // 
lexCharacterDelimited looks through a source string starting at the 372 | // given cursor to find a start- and end- delimiter. The delimiter can 373 | // be escaped be preceeding the delimiter with itself. 374 | func lexCharacterDelimited(source string, ic cursor, delimiter byte) (*Token, cursor, bool) { 375 | cur := ic 376 | 377 | if len(source[cur.pointer:]) == 0 { 378 | return nil, ic, false 379 | } 380 | 381 | if source[cur.pointer] != delimiter { 382 | return nil, ic, false 383 | } 384 | 385 | cur.loc.Col++ 386 | cur.pointer++ 387 | 388 | var value []byte 389 | for ; cur.pointer < uint(len(source)); cur.pointer++ { 390 | c := source[cur.pointer] 391 | 392 | if c == delimiter { 393 | // SQL escapes are via double characters, not backslash. 394 | if cur.pointer+1 >= uint(len(source)) || source[cur.pointer+1] != delimiter { 395 | cur.pointer++ 396 | cur.loc.Col++ 397 | return &Token{ 398 | Value: string(value), 399 | Loc: ic.loc, 400 | Kind: StringKind, 401 | }, cur, true 402 | } 403 | value = append(value, delimiter) 404 | cur.pointer++ 405 | cur.loc.Col++ 406 | } 407 | 408 | value = append(value, c) 409 | cur.loc.Col++ 410 | } 411 | 412 | return nil, ic, false 413 | } 414 | 415 | func lexIdentifier(source string, ic cursor) (*Token, cursor, bool) { 416 | // Handle separately if is a double-quoted identifier 417 | if token, newCursor, ok := lexCharacterDelimited(source, ic, '"'); ok { 418 | // Overwrite from stringkind to identifierkind 419 | token.Kind = IdentifierKind 420 | return token, newCursor, true 421 | } 422 | 423 | cur := ic 424 | 425 | c := source[cur.pointer] 426 | // Other characters count too, big ignoring non-ascii for now 427 | isAlphabetical := (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') 428 | if !isAlphabetical { 429 | return nil, ic, false 430 | } 431 | cur.pointer++ 432 | cur.loc.Col++ 433 | 434 | value := []byte{c} 435 | for ; cur.pointer < uint(len(source)); cur.pointer++ { 436 | c = source[cur.pointer] 437 | 438 | // Other characters 
count too, big ignoring non-ascii for now 439 | isAlphabetical := (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') 440 | isNumeric := c >= '0' && c <= '9' 441 | if isAlphabetical || isNumeric || c == '$' || c == '_' { 442 | value = append(value, c) 443 | cur.loc.Col++ 444 | continue 445 | } 446 | 447 | break 448 | } 449 | 450 | return &Token{ 451 | // Unquoted identifiers are case-insensitive 452 | Value: strings.ToLower(string(value)), 453 | Loc: ic.loc, 454 | Kind: IdentifierKind, 455 | }, cur, true 456 | } 457 | 458 | func lexString(source string, ic cursor) (*Token, cursor, bool) { 459 | return lexCharacterDelimited(source, ic, '\'') 460 | } 461 | 462 | type lexer func(string, cursor) (*Token, cursor, bool) 463 | 464 | // lex splits an input string into a list of Tokens. This process 465 | // can be divided into following tasks: 466 | // 467 | // 1. Instantiating a cursor with pointing to the start of the string 468 | // 469 | // 2. Execute all the lexers in series. 470 | // 471 | // 3. If any of the lexer generate a Token then add the Token to the 472 | // Token slice, update the cursor and restart the process from the new 473 | // cursor location. 
474 | func lex(source string) ([]*Token, error) { 475 | var tokens []*Token 476 | cur := cursor{} 477 | 478 | lex: 479 | for cur.pointer < uint(len(source)) { 480 | lexers := []lexer{lexKeyword, lexSymbol, lexString, lexNumeric, lexIdentifier} 481 | for _, l := range lexers { 482 | if token, newCursor, ok := l(source, cur); ok { 483 | cur = newCursor 484 | 485 | // Omit nil tokens for valid, but empty syntax like newlines 486 | if token != nil { 487 | tokens = append(tokens, token) 488 | } 489 | 490 | continue lex 491 | } 492 | } 493 | 494 | hint := "" 495 | if len(tokens) > 0 { 496 | hint = " after " + tokens[len(tokens)-1].Value 497 | } 498 | for _, t := range tokens { 499 | fmt.Println(t.Value) 500 | } 501 | return nil, fmt.Errorf("Unable to lex token%s, at %d:%d", hint, cur.loc.Line, cur.loc.Col) 502 | } 503 | 504 | return tokens, nil 505 | } 506 | -------------------------------------------------------------------------------- /lexer_test.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestToken_lexNumeric(t *testing.T) { 11 | tests := []struct { 12 | number bool 13 | value string 14 | }{ 15 | { 16 | number: true, 17 | value: "105", 18 | }, 19 | { 20 | number: true, 21 | value: "105 ", 22 | }, 23 | { 24 | number: true, 25 | value: "123.", 26 | }, 27 | { 28 | number: true, 29 | value: "123.145", 30 | }, 31 | { 32 | number: true, 33 | value: "1e5", 34 | }, 35 | { 36 | number: true, 37 | value: "1.e21", 38 | }, 39 | { 40 | number: true, 41 | value: "1.1e2", 42 | }, 43 | { 44 | number: true, 45 | value: "1.1e-2", 46 | }, 47 | { 48 | number: true, 49 | value: "1.1e+2", 50 | }, 51 | { 52 | number: true, 53 | value: "1e-1", 54 | }, 55 | { 56 | number: true, 57 | value: ".1", 58 | }, 59 | { 60 | number: true, 61 | value: "4.", 62 | }, 63 | // false tests 64 | { 65 | number: false, 66 | value: "e4", 67 | }, 
68 | { 69 | number: false, 70 | value: "1..", 71 | }, 72 | { 73 | number: false, 74 | value: "1ee4", 75 | }, 76 | { 77 | number: false, 78 | value: " 1", 79 | }, 80 | } 81 | 82 | for _, test := range tests { 83 | tok, _, ok := lexNumeric(test.value, cursor{}) 84 | assert.Equal(t, test.number, ok, test.value) 85 | if ok { 86 | assert.Equal(t, strings.TrimSpace(test.value), tok.Value, test.value) 87 | } 88 | } 89 | } 90 | 91 | func TestToken_lexString(t *testing.T) { 92 | tests := []struct { 93 | string bool 94 | value string 95 | }{ 96 | { 97 | string: false, 98 | value: "a", 99 | }, 100 | { 101 | string: true, 102 | value: "'abc'", 103 | }, 104 | { 105 | string: true, 106 | value: "'a b'", 107 | }, 108 | { 109 | string: true, 110 | value: "'a' ", 111 | }, 112 | { 113 | string: true, 114 | value: "'a '' b'", 115 | }, 116 | // false tests 117 | { 118 | string: false, 119 | value: "'", 120 | }, 121 | { 122 | string: false, 123 | value: "", 124 | }, 125 | { 126 | string: false, 127 | value: " 'foo'", 128 | }, 129 | } 130 | 131 | for _, test := range tests { 132 | tok, _, ok := lexString(test.value, cursor{}) 133 | assert.Equal(t, test.string, ok, test.value) 134 | if ok { 135 | test.value = strings.TrimSpace(test.value) 136 | assert.Equal(t, test.value[1:len(test.value)-1], tok.Value, test.value) 137 | } 138 | } 139 | } 140 | 141 | func TestToken_lexSymbol(t *testing.T) { 142 | tests := []struct { 143 | symbol bool 144 | value string 145 | }{ 146 | { 147 | symbol: true, 148 | value: "= ", 149 | }, 150 | { 151 | symbol: true, 152 | value: "||", 153 | }, 154 | } 155 | 156 | for _, test := range tests { 157 | tok, _, ok := lexSymbol(test.value, cursor{}) 158 | assert.Equal(t, test.symbol, ok, test.value) 159 | if ok { 160 | test.value = strings.TrimSpace(test.value) 161 | assert.Equal(t, test.value, tok.Value, test.value) 162 | } 163 | } 164 | } 165 | 166 | func TestToken_lexIdentifier(t *testing.T) { 167 | tests := []struct { 168 | Identifier bool 169 | input string 170 
| value string 171 | }{ 172 | { 173 | Identifier: true, 174 | input: "a", 175 | value: "a", 176 | }, 177 | { 178 | Identifier: true, 179 | input: "abc", 180 | value: "abc", 181 | }, 182 | { 183 | Identifier: true, 184 | input: "abc ", 185 | value: "abc", 186 | }, 187 | { 188 | Identifier: true, 189 | input: `" abc "`, 190 | value: ` abc `, 191 | }, 192 | { 193 | Identifier: true, 194 | input: "a9$", 195 | value: "a9$", 196 | }, 197 | { 198 | Identifier: true, 199 | input: "userName", 200 | value: "username", 201 | }, 202 | { 203 | Identifier: true, 204 | input: `"userName"`, 205 | value: "userName", 206 | }, 207 | // false tests 208 | { 209 | Identifier: false, 210 | input: `"`, 211 | }, 212 | { 213 | Identifier: false, 214 | input: "_sadsfa", 215 | }, 216 | { 217 | Identifier: false, 218 | input: "9sadsfa", 219 | }, 220 | { 221 | Identifier: false, 222 | input: " abc", 223 | }, 224 | } 225 | 226 | for _, test := range tests { 227 | tok, _, ok := lexIdentifier(test.input, cursor{}) 228 | assert.Equal(t, test.Identifier, ok, test.input) 229 | if ok { 230 | assert.Equal(t, test.value, tok.Value, test.input) 231 | } 232 | } 233 | } 234 | 235 | func TestToken_lexKeyword(t *testing.T) { 236 | tests := []struct { 237 | keyword bool 238 | value string 239 | }{ 240 | { 241 | keyword: true, 242 | value: "select ", 243 | }, 244 | { 245 | keyword: true, 246 | value: "from", 247 | }, 248 | { 249 | keyword: true, 250 | value: "as", 251 | }, 252 | { 253 | keyword: true, 254 | value: "SELECT", 255 | }, 256 | { 257 | keyword: true, 258 | value: "into", 259 | }, 260 | // false tests 261 | { 262 | keyword: false, 263 | value: " into", 264 | }, 265 | { 266 | keyword: false, 267 | value: "flubbrety", 268 | }, 269 | } 270 | 271 | for _, test := range tests { 272 | tok, _, ok := lexKeyword(test.value, cursor{}) 273 | assert.Equal(t, test.keyword, ok, test.value) 274 | if ok { 275 | test.value = strings.TrimSpace(test.value) 276 | assert.Equal(t, strings.ToLower(test.value), tok.Value, 
test.value) 277 | } 278 | } 279 | } 280 | 281 | func TestLex(t *testing.T) { 282 | tests := []struct { 283 | input string 284 | Tokens []Token 285 | err error 286 | }{ 287 | { 288 | input: "select a", 289 | Tokens: []Token{ 290 | { 291 | Loc: Location{Col: 0, Line: 0}, 292 | Value: string(SelectKeyword), 293 | Kind: KeywordKind, 294 | }, 295 | { 296 | Loc: Location{Col: 7, Line: 0}, 297 | Value: "a", 298 | Kind: IdentifierKind, 299 | }, 300 | }, 301 | }, 302 | { 303 | input: "select true", 304 | Tokens: []Token{ 305 | { 306 | Loc: Location{Col: 0, Line: 0}, 307 | Value: string(SelectKeyword), 308 | Kind: KeywordKind, 309 | }, 310 | { 311 | Loc: Location{Col: 7, Line: 0}, 312 | Value: "true", 313 | Kind: BoolKind, 314 | }, 315 | }, 316 | }, 317 | { 318 | input: "select 1", 319 | Tokens: []Token{ 320 | { 321 | Loc: Location{Col: 0, Line: 0}, 322 | Value: string(SelectKeyword), 323 | Kind: KeywordKind, 324 | }, 325 | { 326 | Loc: Location{Col: 7, Line: 0}, 327 | Value: "1", 328 | Kind: NumericKind, 329 | }, 330 | }, 331 | err: nil, 332 | }, 333 | { 334 | input: "select 'foo' || 'bar';", 335 | Tokens: []Token{ 336 | { 337 | Loc: Location{Col: 0, Line: 0}, 338 | Value: string(SelectKeyword), 339 | Kind: KeywordKind, 340 | }, 341 | { 342 | Loc: Location{Col: 7, Line: 0}, 343 | Value: "foo", 344 | Kind: StringKind, 345 | }, 346 | { 347 | Loc: Location{Col: 13, Line: 0}, 348 | Value: string(ConcatSymbol), 349 | Kind: SymbolKind, 350 | }, 351 | { 352 | Loc: Location{Col: 16, Line: 0}, 353 | Value: "bar", 354 | Kind: StringKind, 355 | }, 356 | { 357 | Loc: Location{Col: 21, Line: 0}, 358 | Value: string(SemicolonSymbol), 359 | Kind: SymbolKind, 360 | }, 361 | }, 362 | err: nil, 363 | }, 364 | { 365 | input: "CREATE TABLE u (id INT, name TEXT)", 366 | Tokens: []Token{ 367 | { 368 | Loc: Location{Col: 0, Line: 0}, 369 | Value: string(CreateKeyword), 370 | Kind: KeywordKind, 371 | }, 372 | { 373 | Loc: Location{Col: 7, Line: 0}, 374 | Value: string(TableKeyword), 375 | Kind: 
KeywordKind, 376 | }, 377 | { 378 | Loc: Location{Col: 13, Line: 0}, 379 | Value: "u", 380 | Kind: IdentifierKind, 381 | }, 382 | { 383 | Loc: Location{Col: 15, Line: 0}, 384 | Value: "(", 385 | Kind: SymbolKind, 386 | }, 387 | { 388 | Loc: Location{Col: 16, Line: 0}, 389 | Value: "id", 390 | Kind: IdentifierKind, 391 | }, 392 | { 393 | Loc: Location{Col: 19, Line: 0}, 394 | Value: "int", 395 | Kind: KeywordKind, 396 | }, 397 | { 398 | Loc: Location{Col: 22, Line: 0}, 399 | Value: ",", 400 | Kind: SymbolKind, 401 | }, 402 | { 403 | Loc: Location{Col: 24, Line: 0}, 404 | Value: "name", 405 | Kind: IdentifierKind, 406 | }, 407 | { 408 | Loc: Location{Col: 29, Line: 0}, 409 | Value: "text", 410 | Kind: KeywordKind, 411 | }, 412 | { 413 | Loc: Location{Col: 33, Line: 0}, 414 | Value: ")", 415 | Kind: SymbolKind, 416 | }, 417 | }, 418 | }, 419 | { 420 | input: "insert into users Values (105, 233)", 421 | Tokens: []Token{ 422 | { 423 | Loc: Location{Col: 0, Line: 0}, 424 | Value: string(InsertKeyword), 425 | Kind: KeywordKind, 426 | }, 427 | { 428 | Loc: Location{Col: 7, Line: 0}, 429 | Value: string(IntoKeyword), 430 | Kind: KeywordKind, 431 | }, 432 | { 433 | Loc: Location{Col: 12, Line: 0}, 434 | Value: "users", 435 | Kind: IdentifierKind, 436 | }, 437 | { 438 | Loc: Location{Col: 18, Line: 0}, 439 | Value: string(ValuesKeyword), 440 | Kind: KeywordKind, 441 | }, 442 | { 443 | Loc: Location{Col: 25, Line: 0}, 444 | Value: "(", 445 | Kind: SymbolKind, 446 | }, 447 | { 448 | Loc: Location{Col: 26, Line: 0}, 449 | Value: "105", 450 | Kind: NumericKind, 451 | }, 452 | { 453 | Loc: Location{Col: 30, Line: 0}, 454 | Value: ",", 455 | Kind: SymbolKind, 456 | }, 457 | { 458 | Loc: Location{Col: 32, Line: 0}, 459 | Value: "233", 460 | Kind: NumericKind, 461 | }, 462 | { 463 | Loc: Location{Col: 36, Line: 0}, 464 | Value: ")", 465 | Kind: SymbolKind, 466 | }, 467 | }, 468 | err: nil, 469 | }, 470 | { 471 | input: "SELECT id FROM users;", 472 | Tokens: []Token{ 473 | { 474 | 
Loc: Location{Col: 0, Line: 0}, 475 | Value: string(SelectKeyword), 476 | Kind: KeywordKind, 477 | }, 478 | { 479 | Loc: Location{Col: 7, Line: 0}, 480 | Value: "id", 481 | Kind: IdentifierKind, 482 | }, 483 | { 484 | Loc: Location{Col: 10, Line: 0}, 485 | Value: string(FromKeyword), 486 | Kind: KeywordKind, 487 | }, 488 | { 489 | Loc: Location{Col: 15, Line: 0}, 490 | Value: "users", 491 | Kind: IdentifierKind, 492 | }, 493 | { 494 | Loc: Location{Col: 20, Line: 0}, 495 | Value: ";", 496 | Kind: SymbolKind, 497 | }, 498 | }, 499 | err: nil, 500 | }, 501 | } 502 | 503 | for _, test := range tests { 504 | tokens, err := lex(test.input) 505 | assert.Equal(t, test.err, err, test.input) 506 | assert.Equal(t, len(test.Tokens), len(tokens), test.input) 507 | 508 | for i, tok := range tokens { 509 | assert.Equal(t, &test.Tokens[i], tok, test.input) 510 | } 511 | } 512 | } 513 | -------------------------------------------------------------------------------- /memory.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "fmt" 7 | "strconv" 8 | 9 | "github.com/petar/GoLLRB/llrb" 10 | ) 11 | 12 | // memoryCell is the underlying storage for the in-memory backend 13 | // implementation. Each supported datatype can be mapped to and from 14 | // this byte array. 
15 | type memoryCell []byte 16 | 17 | func (mc memoryCell) AsInt() *int32 { 18 | if len(mc) == 0 { 19 | return nil 20 | } 21 | 22 | var i int32 23 | err := binary.Read(bytes.NewBuffer(mc), binary.BigEndian, &i) 24 | if err != nil { 25 | fmt.Printf("Corrupted data [%s]: %s\n", mc, err) 26 | return nil 27 | } 28 | 29 | return &i 30 | } 31 | 32 | func (mc memoryCell) AsText() *string { 33 | if len(mc) == 0 { 34 | return nil 35 | } 36 | 37 | s := string(mc) 38 | return &s 39 | } 40 | 41 | func (mc memoryCell) AsBool() *bool { 42 | if len(mc) == 0 { 43 | return nil 44 | } 45 | 46 | b := mc[0] == 1 47 | return &b 48 | } 49 | 50 | func (mc memoryCell) equals(b memoryCell) bool { 51 | // Seems verbose but need to make sure if one is nil, the 52 | // comparison still fails quickly 53 | if mc == nil || b == nil { 54 | return mc == nil && b == nil 55 | } 56 | 57 | return bytes.Equal(mc, b) 58 | } 59 | 60 | func literalToMemoryCell(t *Token) memoryCell { 61 | if t.Kind == NumericKind { 62 | buf := new(bytes.Buffer) 63 | i, err := strconv.Atoi(t.Value) 64 | if err != nil { 65 | fmt.Printf("Corrupted data [%s]: %s\n", t.Value, err) 66 | return nil 67 | } 68 | 69 | // TODO: handle bigint 70 | err = binary.Write(buf, binary.BigEndian, int32(i)) 71 | if err != nil { 72 | fmt.Printf("Corrupted data [%s]: %s\n", buf.String(), err) 73 | return nil 74 | } 75 | return buf.Bytes() 76 | } 77 | 78 | if t.Kind == StringKind { 79 | return memoryCell(t.Value) 80 | } 81 | 82 | if t.Kind == BoolKind { 83 | if t.Value == "true" { 84 | return []byte{1} 85 | } 86 | 87 | return []byte{0} 88 | } 89 | 90 | return nil 91 | } 92 | 93 | var ( 94 | trueToken = Token{Kind: BoolKind, Value: "true"} 95 | falseToken = Token{Kind: BoolKind, Value: "false"} 96 | 97 | trueMemoryCell = literalToMemoryCell(&trueToken) 98 | falseMemoryCell = literalToMemoryCell(&falseToken) 99 | nullMemoryCell = literalToMemoryCell(&Token{Kind: NullKind}) 100 | ) 101 | 102 | type treeItem struct { 103 | value memoryCell 104 | 
index uint 105 | } 106 | 107 | func (te treeItem) Less(than llrb.Item) bool { 108 | return bytes.Compare(te.value, than.(treeItem).value) < 0 109 | } 110 | 111 | type index struct { 112 | name string 113 | exp Expression 114 | unique bool 115 | primaryKey bool 116 | tree *llrb.LLRB 117 | typ string 118 | } 119 | 120 | func (i *index) addRow(t *table, rowIndex uint) error { 121 | indexValue, _, _, err := t.evaluateCell(rowIndex, i.exp) 122 | if err != nil { 123 | return err 124 | } 125 | 126 | if indexValue == nil { 127 | return ErrViolatesNotNullConstraint 128 | } 129 | 130 | if i.unique && i.tree.Has(treeItem{value: indexValue}) { 131 | return ErrViolatesUniqueConstraint 132 | } 133 | 134 | i.tree.InsertNoReplace(treeItem{ 135 | value: indexValue, 136 | index: rowIndex, 137 | }) 138 | return nil 139 | } 140 | 141 | func (i *index) applicableValue(exp Expression) *Expression { 142 | if exp.Kind != BinaryKind { 143 | return nil 144 | } 145 | 146 | be := exp.Binary 147 | // Find the column and the value in the binary Expression 148 | columnExp := be.A 149 | valueExp := be.B 150 | if columnExp.GenerateCode() != i.exp.GenerateCode() { 151 | columnExp = be.B 152 | valueExp = be.A 153 | } 154 | 155 | // Neither side is applicable, return nil 156 | if columnExp.GenerateCode() != i.exp.GenerateCode() { 157 | return nil 158 | } 159 | 160 | supportedChecks := []Symbol{EqSymbol, NeqSymbol, GtSymbol, GteSymbol, LtSymbol, LteSymbol} 161 | supported := false 162 | for _, sym := range supportedChecks { 163 | if string(sym) == be.Op.Value { 164 | supported = true 165 | break 166 | } 167 | } 168 | if !supported { 169 | return nil 170 | } 171 | 172 | if valueExp.Kind != LiteralKind { 173 | fmt.Println("Only index checks on literals supported") 174 | return nil 175 | } 176 | 177 | return &valueExp 178 | } 179 | 180 | func (i *index) newTableFromSubset(t *table, exp Expression) *table { 181 | valueExp := i.applicableValue(exp) 182 | if valueExp == nil { 183 | return t 184 | } 185 | 
186 | value, _, _, err := createTable().evaluateCell(0, *valueExp) 187 | if err != nil { 188 | fmt.Println(err) 189 | return t 190 | } 191 | 192 | tiValue := treeItem{value: value} 193 | 194 | indexes := []uint{} 195 | switch Symbol(exp.Binary.Op.Value) { 196 | case EqSymbol: 197 | i.tree.AscendGreaterOrEqual(tiValue, func(i llrb.Item) bool { 198 | ti := i.(treeItem) 199 | 200 | if !bytes.Equal(ti.value, value) { 201 | return false 202 | } 203 | 204 | indexes = append(indexes, ti.index) 205 | return true 206 | }) 207 | case NeqSymbol: 208 | i.tree.AscendGreaterOrEqual(llrb.Inf(-1), func(i llrb.Item) bool { 209 | ti := i.(treeItem) 210 | if bytes.Equal(ti.value, value) { 211 | indexes = append(indexes, ti.index) 212 | } 213 | 214 | return true 215 | }) 216 | case LtSymbol: 217 | i.tree.DescendLessOrEqual(tiValue, func(i llrb.Item) bool { 218 | ti := i.(treeItem) 219 | if bytes.Compare(ti.value, value) < 0 { 220 | indexes = append(indexes, ti.index) 221 | } 222 | 223 | return true 224 | }) 225 | case LteSymbol: 226 | i.tree.DescendLessOrEqual(tiValue, func(i llrb.Item) bool { 227 | ti := i.(treeItem) 228 | if bytes.Compare(ti.value, value) <= 0 { 229 | indexes = append(indexes, ti.index) 230 | } 231 | 232 | return true 233 | }) 234 | case GtSymbol: 235 | i.tree.AscendGreaterOrEqual(tiValue, func(i llrb.Item) bool { 236 | ti := i.(treeItem) 237 | if bytes.Compare(ti.value, value) > 0 { 238 | indexes = append(indexes, ti.index) 239 | } 240 | 241 | return true 242 | }) 243 | case GteSymbol: 244 | i.tree.AscendGreaterOrEqual(tiValue, func(i llrb.Item) bool { 245 | ti := i.(treeItem) 246 | if bytes.Compare(ti.value, value) >= 0 { 247 | indexes = append(indexes, ti.index) 248 | } 249 | 250 | return true 251 | }) 252 | } 253 | 254 | newT := createTable() 255 | newT.columns = t.columns 256 | newT.columnTypes = t.columnTypes 257 | newT.indexes = t.indexes 258 | newT.rows = [][]memoryCell{} 259 | 260 | for _, index := range indexes { 261 | newT.rows = append(newT.rows, 
t.rows[index]) 262 | } 263 | 264 | return newT 265 | } 266 | 267 | type table struct { 268 | name string 269 | columns []string 270 | columnTypes []ColumnType 271 | rows [][]memoryCell 272 | indexes []*index 273 | } 274 | 275 | func createTable() *table { 276 | return &table{ 277 | name: "?tmp?", 278 | columns: nil, 279 | columnTypes: nil, 280 | rows: nil, 281 | indexes: []*index{}, 282 | } 283 | } 284 | 285 | func (t *table) evaluateLiteralCell(rowIndex uint, exp Expression) (memoryCell, string, ColumnType, error) { 286 | if exp.Kind != LiteralKind { 287 | return nil, "", 0, ErrInvalidCell 288 | } 289 | 290 | lit := exp.Literal 291 | if lit.Kind == IdentifierKind { 292 | for i, tableCol := range t.columns { 293 | if tableCol == lit.Value { 294 | return t.rows[rowIndex][i], tableCol, t.columnTypes[i], nil 295 | } 296 | } 297 | 298 | return nil, "", 0, ErrColumnDoesNotExist 299 | } 300 | 301 | columnType := IntType 302 | if lit.Kind == StringKind { 303 | columnType = TextType 304 | } else if lit.Kind == BoolKind { 305 | columnType = BoolType 306 | } 307 | 308 | return literalToMemoryCell(lit), "?column?", columnType, nil 309 | } 310 | 311 | func (t *table) evaluateBinaryCell(rowIndex uint, exp Expression) (memoryCell, string, ColumnType, error) { 312 | if exp.Kind != BinaryKind { 313 | return nil, "", 0, ErrInvalidCell 314 | } 315 | 316 | bexp := exp.Binary 317 | 318 | l, _, lt, err := t.evaluateCell(rowIndex, bexp.A) 319 | if err != nil { 320 | return nil, "", 0, err 321 | } 322 | 323 | r, _, rt, err := t.evaluateCell(rowIndex, bexp.B) 324 | if err != nil { 325 | return nil, "", 0, err 326 | } 327 | 328 | switch bexp.Op.Kind { 329 | case SymbolKind: 330 | switch Symbol(bexp.Op.Value) { 331 | case EqSymbol: 332 | if len(l) == 0 || len(r) == 0 { 333 | return nullMemoryCell, "?column?", BoolType, nil 334 | } 335 | 336 | eq := l.equals(r) 337 | if lt == TextType && rt == TextType && eq { 338 | return trueMemoryCell, "?column?", BoolType, nil 339 | } 340 | 341 | if lt 
== IntType && rt == IntType && eq { 342 | return trueMemoryCell, "?column?", BoolType, nil 343 | } 344 | 345 | if lt == BoolType && rt == BoolType && eq { 346 | return trueMemoryCell, "?column?", BoolType, nil 347 | } 348 | 349 | return falseMemoryCell, "?column?", BoolType, nil 350 | case NeqSymbol: 351 | if len(l) == 0 || len(r) == 0 { 352 | return nullMemoryCell, "?column?", BoolType, nil 353 | } 354 | 355 | if lt != rt || !l.equals(r) { 356 | return trueMemoryCell, "?column?", BoolType, nil 357 | } 358 | 359 | return falseMemoryCell, "?column?", BoolType, nil 360 | case ConcatSymbol: 361 | if len(l) == 0 || len(r) == 0 { 362 | return nullMemoryCell, "?column?", TextType, nil 363 | } 364 | 365 | if lt != TextType || rt != TextType { 366 | return nil, "", 0, ErrInvalidOperands 367 | } 368 | 369 | return literalToMemoryCell(&Token{Kind: StringKind, Value: *l.AsText() + *r.AsText()}), "?column?", TextType, nil 370 | case PlusSymbol: 371 | if len(l) == 0 || len(r) == 0 { 372 | return nullMemoryCell, "?column?", IntType, nil 373 | } 374 | 375 | if lt != IntType || rt != IntType { 376 | return nil, "", 0, ErrInvalidOperands 377 | } 378 | 379 | iValue := int(*l.AsInt() + *r.AsInt()) 380 | return literalToMemoryCell(&Token{Kind: NumericKind, Value: strconv.Itoa(iValue)}), "?column?", IntType, nil 381 | case LtSymbol: 382 | if len(l) == 0 || len(r) == 0 { 383 | return nullMemoryCell, "?column?", BoolType, nil 384 | } 385 | 386 | if lt != IntType || rt != IntType { 387 | return nil, "", 0, ErrInvalidOperands 388 | } 389 | 390 | if *l.AsInt() < *r.AsInt() { 391 | return trueMemoryCell, "?column?", BoolType, nil 392 | } 393 | 394 | return falseMemoryCell, "?column?", BoolType, nil 395 | case LteSymbol: 396 | if len(l) == 0 || len(r) == 0 { 397 | return nullMemoryCell, "?column?", BoolType, nil 398 | } 399 | 400 | if lt != IntType || rt != IntType { 401 | return nil, "", 0, ErrInvalidOperands 402 | } 403 | 404 | if *l.AsInt() <= *r.AsInt() { 405 | return trueMemoryCell, 
"?column?", BoolType, nil 406 | } 407 | 408 | return falseMemoryCell, "?column?", BoolType, nil 409 | case GtSymbol: 410 | if len(l) == 0 || len(r) == 0 { 411 | return nullMemoryCell, "?column?", BoolType, nil 412 | } 413 | 414 | if lt != IntType || rt != IntType { 415 | return nil, "", 0, ErrInvalidOperands 416 | } 417 | 418 | if *l.AsInt() > *r.AsInt() { 419 | return trueMemoryCell, "?column?", BoolType, nil 420 | } 421 | 422 | return falseMemoryCell, "?column?", BoolType, nil 423 | case GteSymbol: 424 | if len(l) == 0 || len(r) == 0 { 425 | return nullMemoryCell, "?column?", BoolType, nil 426 | } 427 | 428 | if lt != IntType || rt != IntType { 429 | return nil, "", 0, ErrInvalidOperands 430 | } 431 | 432 | if *l.AsInt() >= *r.AsInt() { 433 | return trueMemoryCell, "?column?", BoolType, nil 434 | } 435 | 436 | return falseMemoryCell, "?column?", BoolType, nil 437 | default: 438 | // TODO 439 | break 440 | } 441 | case KeywordKind: 442 | switch Keyword(bexp.Op.Value) { 443 | case AndKeyword: 444 | res := falseMemoryCell 445 | if lt != BoolType || rt != BoolType { 446 | return nil, "", 0, ErrInvalidOperands 447 | } 448 | 449 | if len(l) == 0 || len(r) == 0 { 450 | res = nullMemoryCell 451 | } else if *l.AsBool() && *r.AsBool() { 452 | res = trueMemoryCell 453 | } 454 | 455 | return res, "?column?", BoolType, nil 456 | case OrKeyword: 457 | res := falseMemoryCell 458 | if lt != BoolType || rt != BoolType { 459 | return nil, "", 0, ErrInvalidOperands 460 | } 461 | 462 | if len(l) == 0 || len(r) == 0 { 463 | res = nullMemoryCell 464 | } else if *l.AsBool() || *r.AsBool() { 465 | res = trueMemoryCell 466 | } 467 | 468 | return res, "?column?", BoolType, nil 469 | default: 470 | // TODO 471 | break 472 | } 473 | } 474 | 475 | return nil, "", 0, ErrInvalidCell 476 | } 477 | 478 | func (t *table) evaluateCell(rowIndex uint, exp Expression) (memoryCell, string, ColumnType, error) { 479 | switch exp.Kind { 480 | case LiteralKind: 481 | return t.evaluateLiteralCell(rowIndex, 
exp) 482 | case BinaryKind: 483 | return t.evaluateBinaryCell(rowIndex, exp) 484 | default: 485 | return nil, "", 0, ErrInvalidCell 486 | } 487 | } 488 | 489 | type indexAndExpression struct { 490 | i *index 491 | e Expression 492 | } 493 | 494 | func (t *table) getApplicableIndexes(where *Expression) []indexAndExpression { 495 | var linearizeExpressions func(where *Expression, exps []Expression) []Expression 496 | linearizeExpressions = func(where *Expression, exps []Expression) []Expression { 497 | if where == nil || where.Kind != BinaryKind { 498 | return exps 499 | } 500 | 501 | if where.Binary.Op.Value == string(OrKeyword) { 502 | return exps 503 | } 504 | 505 | if where.Binary.Op.Value == string(AndKeyword) { 506 | exps := linearizeExpressions(&where.Binary.A, exps) 507 | return linearizeExpressions(&where.Binary.B, exps) 508 | } 509 | 510 | return append(exps, *where) 511 | } 512 | 513 | exps := linearizeExpressions(where, []Expression{}) 514 | 515 | iAndE := []indexAndExpression{} 516 | for _, exp := range exps { 517 | for _, index := range t.indexes { 518 | if index.applicableValue(exp) != nil { 519 | iAndE = append(iAndE, indexAndExpression{ 520 | i: index, 521 | e: exp, 522 | }) 523 | } 524 | } 525 | } 526 | 527 | return iAndE 528 | } 529 | 530 | type MemoryBackend struct { 531 | tables map[string]*table 532 | } 533 | 534 | func (mb *MemoryBackend) Select(slct *SelectStatement) (*Results, error) { 535 | t := createTable() 536 | 537 | if slct.From != nil { 538 | var ok bool 539 | t, ok = mb.tables[slct.From.Value] 540 | if !ok { 541 | return nil, ErrTableDoesNotExist 542 | } 543 | } 544 | 545 | if slct.Item == nil || len(*slct.Item) == 0 { 546 | return &Results{}, nil 547 | } 548 | 549 | results := [][]Cell{} 550 | columns := []ResultColumn{} 551 | 552 | if slct.From == nil { 553 | t = createTable() 554 | t.rows = [][]memoryCell{{}} 555 | } 556 | 557 | for _, iAndE := range t.getApplicableIndexes(slct.Where) { 558 | index := iAndE.i 559 | exp := iAndE.e 
560 | t = index.newTableFromSubset(t, exp) 561 | } 562 | 563 | // Expand SELECT * at the AST level into a SELECT on all columns 564 | finalItems := []*SelectItem{} 565 | for _, item := range *slct.Item { 566 | if item.Asterisk { 567 | newItems := []*SelectItem{} 568 | for j := 0; j < len(t.columns); j++ { 569 | newSelectItem := &SelectItem{ 570 | Exp: &Expression{ 571 | Literal: &Token{ 572 | Value: t.columns[j], 573 | Kind: IdentifierKind, 574 | Loc: Location{0, uint(len("SELECT") + 1)}, 575 | }, 576 | Binary: nil, 577 | Kind: LiteralKind, 578 | }, 579 | Asterisk: false, 580 | As: nil, 581 | } 582 | newItems = append(newItems, newSelectItem) 583 | } 584 | finalItems = append(finalItems, newItems...) 585 | } else { 586 | finalItems = append(finalItems, item) 587 | } 588 | } 589 | 590 | limit := len(t.rows) 591 | if slct.Limit != nil { 592 | v, _, _, err := t.evaluateCell(0, *slct.Limit) 593 | if err != nil { 594 | return nil, err 595 | } 596 | 597 | limit = int(*v.AsInt()) 598 | } 599 | if limit < 0 { 600 | return nil, fmt.Errorf("Invalid, negative limit") 601 | } 602 | 603 | offset := 0 604 | if slct.Offset != nil { 605 | v, _, _, err := t.evaluateCell(0, *slct.Offset) 606 | if err != nil { 607 | return nil, err 608 | } 609 | 610 | offset = int(*v.AsInt()) 611 | } 612 | if offset < 0 { 613 | return nil, fmt.Errorf("Invalid, negative limit") 614 | } 615 | 616 | rowIndex := -1 617 | for i := range t.rows { 618 | result := []Cell{} 619 | isFirstRow := len(results) == 0 620 | 621 | if slct.Where != nil { 622 | val, _, _, err := t.evaluateCell(uint(i), *slct.Where) 623 | if err != nil { 624 | return nil, err 625 | } 626 | 627 | if !*val.AsBool() { 628 | continue 629 | } 630 | } 631 | 632 | rowIndex++ 633 | if rowIndex < offset { 634 | continue 635 | } else if rowIndex > offset+limit-1 { 636 | break 637 | } 638 | 639 | for _, col := range finalItems { 640 | value, columnName, columnType, err := t.evaluateCell(uint(i), *col.Exp) 641 | if err != nil { 642 | return nil, 
err 643 | } 644 | 645 | if isFirstRow { 646 | columns = append(columns, ResultColumn{ 647 | Type: columnType, 648 | Name: columnName, 649 | }) 650 | } 651 | 652 | result = append(result, value) 653 | } 654 | 655 | results = append(results, result) 656 | } 657 | 658 | return &Results{ 659 | Columns: columns, 660 | Rows: results, 661 | }, nil 662 | } 663 | 664 | func (mb *MemoryBackend) Insert(inst *InsertStatement) error { 665 | t, ok := mb.tables[inst.Table.Value] 666 | if !ok { 667 | return ErrTableDoesNotExist 668 | } 669 | 670 | if inst.Values == nil { 671 | return nil 672 | } 673 | 674 | if len(*inst.Values) != len(t.columns) { 675 | return ErrMissingValues 676 | } 677 | 678 | row := []memoryCell{} 679 | for _, valueNode := range *inst.Values { 680 | if valueNode.Kind != LiteralKind { 681 | fmt.Println("Skipping non-literal.") 682 | continue 683 | } 684 | 685 | emptyTable := createTable() 686 | value, _, _, err := emptyTable.evaluateCell(0, *valueNode) 687 | if err != nil { 688 | return err 689 | } 690 | 691 | row = append(row, value) 692 | } 693 | 694 | t.rows = append(t.rows, row) 695 | 696 | for _, index := range t.indexes { 697 | err := index.addRow(t, uint(len(t.rows)-1)) 698 | if err != nil { 699 | // Drop the row on failure 700 | t.rows = t.rows[:len(t.rows)-1] 701 | return err 702 | } 703 | } 704 | 705 | return nil 706 | } 707 | 708 | func (mb *MemoryBackend) CreateTable(crt *CreateTableStatement) error { 709 | if _, ok := mb.tables[crt.Name.Value]; ok { 710 | return ErrTableAlreadyExists 711 | } 712 | 713 | t := createTable() 714 | t.name = crt.Name.Value 715 | mb.tables[t.name] = t 716 | if crt.Cols == nil { 717 | return nil 718 | } 719 | 720 | var primaryKey *Expression = nil 721 | for _, col := range *crt.Cols { 722 | t.columns = append(t.columns, col.Name.Value) 723 | 724 | var dt ColumnType 725 | switch col.Datatype.Value { 726 | case "int": 727 | dt = IntType 728 | case "text": 729 | dt = TextType 730 | case "boolean": 731 | dt = BoolType 732 | 
default: 733 | delete(mb.tables, t.name) 734 | return ErrInvalidDatatype 735 | } 736 | 737 | if col.PrimaryKey { 738 | if primaryKey != nil { 739 | delete(mb.tables, t.name) 740 | return ErrPrimaryKeyAlreadyExists 741 | } 742 | 743 | primaryKey = &Expression{ 744 | Literal: &col.Name, 745 | Kind: LiteralKind, 746 | } 747 | } 748 | 749 | t.columnTypes = append(t.columnTypes, dt) 750 | } 751 | 752 | if primaryKey != nil { 753 | err := mb.CreateIndex(&CreateIndexStatement{ 754 | Table: crt.Name, 755 | Name: Token{Value: t.name + "_pkey"}, 756 | Unique: true, 757 | PrimaryKey: true, 758 | Exp: *primaryKey, 759 | }) 760 | if err != nil { 761 | delete(mb.tables, t.name) 762 | return err 763 | } 764 | } 765 | 766 | return nil 767 | } 768 | 769 | func (mb *MemoryBackend) CreateIndex(ci *CreateIndexStatement) error { 770 | table, ok := mb.tables[ci.Table.Value] 771 | if !ok { 772 | return ErrTableDoesNotExist 773 | } 774 | 775 | for _, index := range table.indexes { 776 | if index.name == ci.Name.Value { 777 | return ErrIndexAlreadyExists 778 | } 779 | } 780 | 781 | index := &index{ 782 | exp: ci.Exp, 783 | unique: ci.Unique, 784 | primaryKey: ci.PrimaryKey, 785 | name: ci.Name.Value, 786 | tree: llrb.New(), 787 | typ: "rbtree", 788 | } 789 | table.indexes = append(table.indexes, index) 790 | 791 | for i := range table.rows { 792 | err := index.addRow(table, uint(i)) 793 | if err != nil { 794 | return err 795 | } 796 | } 797 | 798 | return nil 799 | } 800 | 801 | func (mb *MemoryBackend) DropTable(dt *DropTableStatement) error { 802 | if _, ok := mb.tables[dt.Name.Value]; ok { 803 | delete(mb.tables, dt.Name.Value) 804 | return nil 805 | } 806 | return ErrTableDoesNotExist 807 | } 808 | 809 | func (mb *MemoryBackend) GetTables() []TableMetadata { 810 | tms := []TableMetadata{} 811 | for name, t := range mb.tables { 812 | tm := TableMetadata{} 813 | tm.Name = name 814 | 815 | pkeyColumn := "" 816 | for _, i := range t.indexes { 817 | if i.primaryKey { 818 | pkeyColumn = 
i.exp.GenerateCode() 819 | } 820 | 821 | tm.Indexes = append(tm.Indexes, Index{ 822 | Name: i.name, 823 | Type: i.typ, 824 | Unique: i.unique, 825 | PrimaryKey: i.primaryKey, 826 | Exp: i.exp.GenerateCode(), 827 | }) 828 | } 829 | 830 | for i, column := range t.columns { 831 | tm.Columns = append(tm.Columns, ResultColumn{ 832 | Type: t.columnTypes[i], 833 | Name: column, 834 | NotNull: pkeyColumn == `"`+column+`"`, 835 | }) 836 | } 837 | 838 | tms = append(tms, tm) 839 | } 840 | 841 | return tms 842 | } 843 | 844 | func NewMemoryBackend() *MemoryBackend { 845 | return &MemoryBackend{ 846 | tables: map[string]*table{}, 847 | } 848 | } 849 | -------------------------------------------------------------------------------- /memory_test.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "fmt" 5 | "github.com/stretchr/testify/assert" 6 | "testing" 7 | ) 8 | 9 | var mb *MemoryBackend 10 | 11 | func TestSelect(t *testing.T) { 12 | mb = NewMemoryBackend() 13 | 14 | parser := Parser{HelpMessagesDisabled: true} 15 | ast, err := parser.Parse("SELECT * FROM test") 16 | assert.Nil(t, err) 17 | assert.NotEqual(t, ast, nil) 18 | _, err = mb.Select(ast.Statements[0].SelectStatement) 19 | assert.Equal(t, err, ErrTableDoesNotExist) 20 | 21 | ast, err = parser.Parse("CREATE TABLE test(x INT, y INT, z BOOLEAN);") 22 | assert.Nil(t, err) 23 | assert.NotEqual(t, ast, nil) 24 | err = mb.CreateTable(ast.Statements[0].CreateTableStatement) 25 | assert.Nil(t, err) 26 | 27 | ast, err = parser.Parse("INSERT INTO test VALUES(100, 200, true)") 28 | assert.Nil(t, err) 29 | assert.NotEqual(t, ast, nil) 30 | err = mb.Insert(ast.Statements[0].InsertStatement) 31 | assert.Nil(t, err) 32 | 33 | Value100 := literalToMemoryCell(&Token{"100", NumericKind, Location{}}) 34 | Value200 := literalToMemoryCell(&Token{"200", NumericKind, Location{}}) 35 | xCol := ResultColumn{IntType, "x", false} 36 | yCol := ResultColumn{IntType, "y", 
false}
	zCol := ResultColumn{BoolType, "z", false}

	tests := []struct {
		query   string
		results Results
	}{
		{
			"SELECT * FROM test",
			Results{
				[]ResultColumn{xCol, yCol, zCol},
				[][]Cell{{Value100, Value200, trueMemoryCell}},
			},
		},
		{
			"SELECT x FROM test",
			Results{
				[]ResultColumn{xCol},
				[][]Cell{{Value100}},
			},
		},
		{
			"SELECT x, y FROM test",
			Results{
				[]ResultColumn{xCol, yCol},
				[][]Cell{{Value100, Value200}},
			},
		},
		{
			"SELECT x, y, z FROM test",
			Results{
				[]ResultColumn{xCol, yCol, zCol},
				[][]Cell{{Value100, Value200, trueMemoryCell}},
			},
		},
		{
			"SELECT *, x FROM test",
			Results{
				[]ResultColumn{xCol, yCol, zCol, xCol},
				[][]Cell{{Value100, Value200, trueMemoryCell, Value100}},
			},
		},
		{
			"SELECT *, x, y FROM test",
			Results{
				[]ResultColumn{xCol, yCol, zCol, xCol, yCol},
				[][]Cell{{Value100, Value200, trueMemoryCell, Value100, Value200}},
			},
		},
		{
			"SELECT *, x, y, z FROM test",
			Results{
				[]ResultColumn{xCol, yCol, zCol, xCol, yCol, zCol},
				[][]Cell{{Value100, Value200, trueMemoryCell, Value100, Value200, trueMemoryCell}},
			},
		},
		{
			"SELECT x, *, z FROM test",
			Results{
				[]ResultColumn{xCol, xCol, yCol, zCol, zCol},
				[][]Cell{{Value100, Value100, Value200, trueMemoryCell, trueMemoryCell}},
			},
		},
	}

	for _, test := range tests {
		ast, err = parser.Parse(test.query)
		assert.Nil(t, err)
		assert.NotNil(t, ast)

		var res *Results
		res, err = mb.Select(ast.Statements[0].SelectStatement)
		assert.Nil(t, err)
		assert.Equal(t, *res, test.results)
	}
}

// TestInsert checks that inserting into a missing table fails and that a
// well-formed insert into an existing table succeeds.
//
// assert.NotNil is used for the parsed AST: assert.NotEqual(t, ast, nil)
// compares a typed *Ast against an untyped nil interface and therefore can
// never fail, even when ast is a nil pointer.
func TestInsert(t *testing.T) {
	mb = NewMemoryBackend()

	parser := Parser{HelpMessagesDisabled: true}
	ast, err := parser.Parse("INSERT INTO test VALUES(100, 200, 300)")
	assert.Nil(t, err)
	assert.NotNil(t, ast)
	err = mb.Insert(ast.Statements[0].InsertStatement)
	assert.Equal(t, ErrTableDoesNotExist, err)

	ast, err = parser.Parse("CREATE TABLE test(x INT, y INT, z INT);")
	assert.Nil(t, err)
	assert.NotNil(t, ast)
	err = mb.CreateTable(ast.Statements[0].CreateTableStatement)
	assert.Nil(t, err)

	ast, err = parser.Parse("INSERT INTO test VALUES(100, 200, 300)")
	assert.Nil(t, err)
	assert.NotNil(t, ast)
	err = mb.Insert(ast.Statements[0].InsertStatement)
	assert.Nil(t, err)
}

// TestCreateTable checks table creation metadata and the duplicate-name error.
func TestCreateTable(t *testing.T) {
	mb = NewMemoryBackend()

	parser := Parser{HelpMessagesDisabled: true}
	ast, err := parser.Parse("CREATE TABLE test(x INT, y INT, z INT)")
	assert.Nil(t, err)
	err = mb.CreateTable(ast.Statements[0].CreateTableStatement)
	assert.Nil(t, err)
	assert.Equal(t, "test", mb.tables["test"].name)
	assert.Equal(t, []string{"x", "y", "z"}, mb.tables["test"].columns)

	// Second time, already exists
	err = mb.CreateTable(ast.Statements[0].CreateTableStatement)
	assert.Equal(t, ErrTableAlreadyExists, err)
}

// TestCreateIndex checks index creation metadata and the duplicate-name error.
func TestCreateIndex(t *testing.T) {
	mb = NewMemoryBackend()

	parser := Parser{HelpMessagesDisabled: true}
	ast, err := parser.Parse("CREATE TABLE test(x INT, y INT, z INT)")
	assert.Nil(t, err)
	err = mb.CreateTable(ast.Statements[0].CreateTableStatement)
	assert.Nil(t, err)

	ast, err = parser.Parse("CREATE INDEX foo ON test (x);")
	assert.Nil(t, err)
	err = mb.CreateIndex(ast.Statements[0].CreateIndexStatement)
	assert.Nil(t, err)
	assert.Equal(t, "foo", mb.tables["test"].indexes[0].name)
	assert.Equal(t, `"x"`, mb.tables["test"].indexes[0].exp.GenerateCode())

	// Second time, already exists
	err = mb.CreateIndex(ast.Statements[0].CreateIndexStatement)
	assert.Equal(t, ErrIndexAlreadyExists, err)
}

// TestDropTable checks that dropping a missing table fails and that dropping
// an existing table succeeds and removes it from the backend.
func TestDropTable(t *testing.T) {
	mb = NewMemoryBackend()

	parser := Parser{HelpMessagesDisabled: true}
	ast, err := parser.Parse("DROP TABLE test;")
	assert.Nil(t, err)
	assert.NotNil(t, ast)
	err = mb.DropTable(ast.Statements[0].DropTableStatement)
	assert.Equal(t, ErrTableDoesNotExist, err)

	ast, err = parser.Parse("CREATE TABLE test(x INT, y INT, z INT);")
	assert.Nil(t, err)
	assert.NotNil(t, ast)
	err = mb.CreateTable(ast.Statements[0].CreateTableStatement)
	assert.Nil(t, err)

	ast, err = parser.Parse("DROP TABLE test;")
	assert.Nil(t, err)
	assert.NotNil(t, ast)
	err = mb.DropTable(ast.Statements[0].DropTableStatement)
	assert.Nil(t, err)

	_, stillThere := mb.tables["test"]
	assert.False(t, stillThere)
}

func TestTable_GetApplicableIndexes(t *testing.T) {
	mb := NewMemoryBackend()

	parser := Parser{HelpMessagesDisabled: true}
	ast, err := parser.Parse("CREATE TABLE test (x INT, y INT);")
	assert.Nil(t, err)
	err = mb.CreateTable(ast.Statements[0].CreateTableStatement)
	assert.Nil(t, err)

	ast, err = parser.Parse("CREATE INDEX x_idx ON test (x);")
	assert.Nil(t, err)
	err = mb.CreateIndex(ast.Statements[0].CreateIndexStatement)
	assert.Nil(t, err)

	tests := []struct {
		where   string
		indexes []string
	}{
		{
			"x = 2 OR y = 3",
			[]string{},
		},
		{
			"x = 2",
			[]string{`"x"`},
		},
		{
			"x = 2 AND y = 3",
			[]string{`"x"`},
		},
		{
			"x = 2 AND (y = 3 OR y = 5)",
			[]string{`"x"`},
		},
	}

	for _, test := range tests {
		ast, err = parser.Parse(fmt.Sprintf("SELECT * FROM test WHERE %s", test.where))
		assert.Nil(t, err, test.where)
		where :=
ast.Statements[0].SelectStatement.Where 236 | indexes := []string{} 237 | for _, i := range mb.tables["test"].getApplicableIndexes(where) { 238 | indexes = append(indexes, i.i.exp.GenerateCode()) 239 | } 240 | assert.Equal(t, test.indexes, indexes, test.where) 241 | } 242 | } 243 | 244 | func TestLiteralToMemoryCell(t *testing.T) { 245 | var i *int32 246 | assert.Equal(t, i, literalToMemoryCell(&Token{Value: "null", Kind: NullKind}).AsInt()) 247 | assert.Equal(t, i, literalToMemoryCell(&Token{Value: "not an int", Kind: NumericKind}).AsInt()) 248 | assert.Equal(t, int32(2), *literalToMemoryCell(&Token{Value: "2", Kind: NumericKind}).AsInt()) 249 | 250 | var s *string 251 | assert.Equal(t, s, literalToMemoryCell(&Token{Value: "null", Kind: NullKind}).AsText()) 252 | assert.Equal(t, "foo", *literalToMemoryCell(&Token{Value: "foo", Kind: StringKind}).AsText()) 253 | 254 | var b *bool 255 | assert.Equal(t, b, literalToMemoryCell(&Token{Value: "null", Kind: NullKind}).AsBool()) 256 | assert.Equal(t, true, *literalToMemoryCell(&Token{Value: "true", Kind: BoolKind}).AsBool()) 257 | assert.Equal(t, false, *literalToMemoryCell(&Token{Value: "false", Kind: BoolKind}).AsBool()) 258 | } 259 | -------------------------------------------------------------------------------- /parser.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | ) 7 | 8 | func tokenFromKeyword(k Keyword) Token { 9 | return Token{ 10 | Kind: KeywordKind, 11 | Value: string(k), 12 | } 13 | } 14 | 15 | func tokenFromSymbol(s Symbol) Token { 16 | return Token{ 17 | Kind: SymbolKind, 18 | Value: string(s), 19 | } 20 | } 21 | 22 | type Parser struct { 23 | HelpMessagesDisabled bool 24 | } 25 | 26 | // helpMessage prints errors found while parsing 27 | func (p Parser) helpMessage(tokens []*Token, cursor uint, msg string) { 28 | if p.HelpMessagesDisabled { 29 | return 30 | } 31 | 32 | var c *Token 33 | if cursor+1 < 
uint(len(tokens)) { 34 | c = tokens[cursor+1] 35 | } else { 36 | c = tokens[cursor] 37 | } 38 | 39 | fmt.Printf("[%d,%d]: %s, near: %s\n", c.Loc.Line, c.Loc.Col, msg, c.Value) 40 | } 41 | 42 | // parseTokenKind looks for a token of the given kind 43 | func (p Parser) parseTokenKind(tokens []*Token, initialCursor uint, kind TokenKind) (*Token, uint, bool) { 44 | cursor := initialCursor 45 | 46 | if cursor >= uint(len(tokens)) { 47 | return nil, initialCursor, false 48 | } 49 | 50 | current := tokens[cursor] 51 | if current.Kind == kind { 52 | return current, cursor + 1, true 53 | } 54 | 55 | return nil, initialCursor, false 56 | } 57 | 58 | // parseToken looks for a Token the same as passed in (ignoring Token 59 | // location) 60 | func (p Parser) parseToken(tokens []*Token, initialCursor uint, t Token) (*Token, uint, bool) { 61 | cursor := initialCursor 62 | 63 | if cursor >= uint(len(tokens)) { 64 | return nil, initialCursor, false 65 | } 66 | 67 | if p := tokens[cursor]; t.equals(p) { 68 | return p, cursor + 1, true 69 | } 70 | 71 | return nil, initialCursor, false 72 | } 73 | 74 | func (p Parser) parseLiteralExpression(tokens []*Token, initialCursor uint) (*Expression, uint, bool) { 75 | cursor := initialCursor 76 | 77 | kinds := []TokenKind{IdentifierKind, NumericKind, StringKind, BoolKind, NullKind} 78 | for _, kind := range kinds { 79 | t, newCursor, ok := p.parseTokenKind(tokens, cursor, kind) 80 | if ok { 81 | return &Expression{ 82 | Literal: t, 83 | Kind: LiteralKind, 84 | }, newCursor, true 85 | } 86 | } 87 | 88 | return nil, initialCursor, false 89 | } 90 | 91 | func (p Parser) parseExpression(tokens []*Token, initialCursor uint, delimiters []Token, minBp uint) (*Expression, uint, bool) { 92 | cursor := initialCursor 93 | 94 | var exp *Expression 95 | _, newCursor, ok := p.parseToken(tokens, cursor, tokenFromSymbol(LeftParenSymbol)) 96 | if ok { 97 | cursor = newCursor 98 | RightParenToken := tokenFromSymbol(RightParenSymbol) 99 | 100 | exp, cursor, ok 
= p.parseExpression(tokens, cursor, append(delimiters, RightParenToken), minBp) 101 | if !ok { 102 | p.helpMessage(tokens, cursor, "Expected expression after opening paren") 103 | return nil, initialCursor, false 104 | } 105 | 106 | _, cursor, ok = p.parseToken(tokens, cursor, RightParenToken) 107 | if !ok { 108 | p.helpMessage(tokens, cursor, "Expected closing paren") 109 | return nil, initialCursor, false 110 | } 111 | } else { 112 | exp, cursor, ok = p.parseLiteralExpression(tokens, cursor) 113 | if !ok { 114 | return nil, initialCursor, false 115 | } 116 | } 117 | 118 | lastCursor := cursor 119 | outer: 120 | for cursor < uint(len(tokens)) { 121 | for _, d := range delimiters { 122 | _, _, ok = p.parseToken(tokens, cursor, d) 123 | if ok { 124 | break outer 125 | } 126 | } 127 | 128 | binOps := []Token{ 129 | tokenFromKeyword(AndKeyword), 130 | tokenFromKeyword(OrKeyword), 131 | tokenFromSymbol(EqSymbol), 132 | tokenFromSymbol(NeqSymbol), 133 | tokenFromSymbol(LtSymbol), 134 | tokenFromSymbol(LteSymbol), 135 | tokenFromSymbol(GtSymbol), 136 | tokenFromSymbol(GteSymbol), 137 | tokenFromSymbol(ConcatSymbol), 138 | tokenFromSymbol(PlusSymbol), 139 | } 140 | 141 | var op *Token 142 | for _, bo := range binOps { 143 | var t *Token 144 | t, cursor, ok = p.parseToken(tokens, cursor, bo) 145 | if ok { 146 | op = t 147 | break 148 | } 149 | } 150 | 151 | if op == nil { 152 | p.helpMessage(tokens, cursor, "Expected binary operator") 153 | return nil, initialCursor, false 154 | } 155 | 156 | bp := op.bindingPower() 157 | if bp < minBp { 158 | cursor = lastCursor 159 | break 160 | } 161 | 162 | b, newCursor, ok := p.parseExpression(tokens, cursor, delimiters, bp) 163 | if !ok { 164 | p.helpMessage(tokens, cursor, "Expected right operand") 165 | return nil, initialCursor, false 166 | } 167 | exp = &Expression{ 168 | Binary: &BinaryExpression{ 169 | *exp, 170 | *b, 171 | *op, 172 | }, 173 | Kind: BinaryKind, 174 | } 175 | cursor = newCursor 176 | lastCursor = cursor 177 | } 
178 | 179 | return exp, cursor, true 180 | } 181 | 182 | func (p Parser) parseSelectItem(tokens []*Token, initialCursor uint, delimiters []Token) (*[]*SelectItem, uint, bool) { 183 | cursor := initialCursor 184 | 185 | var s []*SelectItem 186 | outer: 187 | for { 188 | if cursor >= uint(len(tokens)) { 189 | return nil, initialCursor, false 190 | } 191 | 192 | current := tokens[cursor] 193 | for _, delimiter := range delimiters { 194 | if delimiter.equals(current) { 195 | break outer 196 | } 197 | } 198 | 199 | var ok bool 200 | if len(s) > 0 { 201 | _, cursor, ok = p.parseToken(tokens, cursor, tokenFromSymbol(CommaSymbol)) 202 | if !ok { 203 | p.helpMessage(tokens, cursor, "Expected comma") 204 | return nil, initialCursor, false 205 | } 206 | } 207 | 208 | var si SelectItem 209 | _, cursor, ok = p.parseToken(tokens, cursor, tokenFromSymbol(AsteriskSymbol)) 210 | if ok { 211 | si = SelectItem{Asterisk: true} 212 | } else { 213 | asToken := tokenFromKeyword(AsKeyword) 214 | delimiters := append(delimiters, tokenFromSymbol(CommaSymbol), asToken) 215 | exp, newCursor, ok := p.parseExpression(tokens, cursor, delimiters, 0) 216 | if !ok { 217 | p.helpMessage(tokens, cursor, "Expected expression") 218 | return nil, initialCursor, false 219 | } 220 | 221 | cursor = newCursor 222 | si.Exp = exp 223 | 224 | _, cursor, ok = p.parseToken(tokens, cursor, asToken) 225 | if ok { 226 | id, newCursor, ok := p.parseTokenKind(tokens, cursor, IdentifierKind) 227 | if !ok { 228 | p.helpMessage(tokens, cursor, "Expected identifier after AS") 229 | return nil, initialCursor, false 230 | } 231 | 232 | cursor = newCursor 233 | si.As = id 234 | } 235 | } 236 | 237 | s = append(s, &si) 238 | } 239 | 240 | return &s, cursor, true 241 | } 242 | 243 | func (p Parser) parseSelectStatement(tokens []*Token, initialCursor uint, delimiter Token) (*SelectStatement, uint, bool) { 244 | var ok bool 245 | cursor := initialCursor 246 | _, cursor, ok = p.parseToken(tokens, cursor, 
tokenFromKeyword(SelectKeyword))
	if !ok {
		return nil, initialCursor, false
	}

	slct := SelectStatement{}

	fromToken := tokenFromKeyword(FromKeyword)
	item, newCursor, ok := p.parseSelectItem(tokens, cursor, []Token{fromToken, delimiter})
	if !ok {
		return nil, initialCursor, false
	}

	slct.Item = item
	cursor = newCursor

	whereToken := tokenFromKeyword(WhereKeyword)

	_, cursor, ok = p.parseToken(tokens, cursor, fromToken)
	if ok {
		from, newCursor, ok := p.parseTokenKind(tokens, cursor, IdentifierKind)
		if !ok {
			p.helpMessage(tokens, cursor, "Expected FROM item")
			return nil, initialCursor, false
		}

		slct.From = from
		cursor = newCursor
	}

	limitToken := tokenFromKeyword(LimitKeyword)
	offsetToken := tokenFromKeyword(OffsetKeyword)

	_, cursor, ok = p.parseToken(tokens, cursor, whereToken)
	if ok {
		where, newCursor, ok := p.parseExpression(tokens, cursor, []Token{limitToken, offsetToken, delimiter}, 0)
		if !ok {
			p.helpMessage(tokens, cursor, "Expected WHERE conditionals")
			return nil, initialCursor, false
		}

		slct.Where = where
		cursor = newCursor
	}

	_, cursor, ok = p.parseToken(tokens, cursor, limitToken)
	if ok {
		limit, newCursor, ok := p.parseExpression(tokens, cursor, []Token{offsetToken, delimiter}, 0)
		if !ok {
			p.helpMessage(tokens, cursor, "Expected LIMIT value")
			return nil, initialCursor, false
		}

		slct.Limit = limit
		cursor = newCursor
	}

	_, cursor, ok = p.parseToken(tokens, cursor, offsetToken)
	if ok {
		offset, newCursor, ok := p.parseExpression(tokens, cursor, []Token{delimiter}, 0)
		if !ok {
			p.helpMessage(tokens, cursor, "Expected OFFSET value")
			return nil, initialCursor, false
		}

		slct.Offset = offset
		cursor = newCursor
	}

	return &slct, cursor, true
}

// parseExpressions parses a comma-separated list of expressions that ends in
// front of delimiter (which is left unconsumed). Each expression itself stops
// at a comma or a right paren. It fails if input ends before delimiter.
func (p Parser) parseExpressions(tokens []*Token, initialCursor uint, delimiter Token) (*[]*Expression, uint, bool) {
	cursor := initialCursor

	var exps []*Expression
	for {
		if cursor >= uint(len(tokens)) {
			return nil, initialCursor, false
		}

		current := tokens[cursor]
		if delimiter.equals(current) {
			break
		}

		if len(exps) > 0 {
			var ok bool
			_, cursor, ok = p.parseToken(tokens, cursor, tokenFromSymbol(CommaSymbol))
			if !ok {
				p.helpMessage(tokens, cursor, "Expected comma")
				return nil, initialCursor, false
			}
		}

		exp, newCursor, ok := p.parseExpression(tokens, cursor, []Token{tokenFromSymbol(CommaSymbol), tokenFromSymbol(RightParenSymbol)}, 0)
		if !ok {
			p.helpMessage(tokens, cursor, "Expected expression")
			return nil, initialCursor, false
		}
		cursor = newCursor

		exps = append(exps, exp)
	}

	return &exps, cursor, true
}

// parseInsertStatement parses `INSERT INTO table VALUES (expr, ...)`.
func (p Parser) parseInsertStatement(tokens []*Token, initialCursor uint, _ Token) (*InsertStatement, uint, bool) {
	cursor := initialCursor
	var ok bool

	_, cursor, ok = p.parseToken(tokens, cursor, tokenFromKeyword(InsertKeyword))
	if !ok {
		return nil, initialCursor, false
	}

	_, cursor, ok = p.parseToken(tokens, cursor, tokenFromKeyword(IntoKeyword))
	if !ok {
		p.helpMessage(tokens, cursor, "Expected into")
		return nil, initialCursor, false
	}

	table, newCursor, ok := p.parseTokenKind(tokens, cursor, IdentifierKind)
	if !ok {
		p.helpMessage(tokens, cursor, "Expected table name")
		return nil, initialCursor, false
	}
	cursor = newCursor

	_, cursor, ok = p.parseToken(tokens, cursor, tokenFromKeyword(ValuesKeyword))
	if !ok {
		p.helpMessage(tokens, cursor, "Expected VALUES")
		return nil, initialCursor, false
	}

	_, cursor, ok = p.parseToken(tokens, cursor, tokenFromSymbol(LeftParenSymbol))
	if !ok {
		p.helpMessage(tokens, cursor, "Expected left paren")
		return nil, initialCursor, false
	}

	values, newCursor, ok := p.parseExpressions(tokens, cursor, tokenFromSymbol(RightParenSymbol))
	if !ok {
		p.helpMessage(tokens, cursor, "Expected expressions")
		return nil, initialCursor, false
	}
	cursor = newCursor

	_, cursor, ok = p.parseToken(tokens, cursor, tokenFromSymbol(RightParenSymbol))
	if !ok {
		p.helpMessage(tokens, cursor, "Expected right paren")
		return nil, initialCursor, false
	}

	return &InsertStatement{
		Table:  *table,
		Values: values,
	}, cursor, true
}

// parseColumnDefinitions parses `name type [PRIMARY KEY], ...` up to (not
// including) delimiter. It fails if input ends before delimiter.
func (p Parser) parseColumnDefinitions(tokens []*Token, initialCursor uint, delimiter Token) (*[]*ColumnDefinition, uint, bool) {
	cursor := initialCursor

	var cds []*ColumnDefinition
	for {
		if cursor >= uint(len(tokens)) {
			return nil, initialCursor, false
		}

		current := tokens[cursor]
		if delimiter.equals(current) {
			break
		}

		if len(cds) > 0 {
			var ok bool
			_, cursor, ok = p.parseToken(tokens, cursor, tokenFromSymbol(CommaSymbol))
			if !ok {
				p.helpMessage(tokens, cursor, "Expected comma")
				return nil, initialCursor, false
			}
		}

		id, newCursor, ok := p.parseTokenKind(tokens, cursor, IdentifierKind)
		if !ok {
			p.helpMessage(tokens, cursor, "Expected column name")
			return nil, initialCursor, false
		}
		cursor = newCursor

		// Column types are lexed as keywords (e.g. int, text, boolean).
		ty, newCursor, ok := p.parseTokenKind(tokens, cursor, KeywordKind)
		if !ok {
			p.helpMessage(tokens, cursor, "Expected column type")
			return nil, initialCursor, false
		}
		cursor = newCursor

		primaryKey := false
		_, cursor, ok = p.parseToken(tokens, cursor, tokenFromKeyword(PrimarykeyKeyword))
		if ok {
			primaryKey = true
		}

		cds = append(cds, &ColumnDefinition{
			Name:       *id,
			Datatype:   *ty,
			PrimaryKey: primaryKey,
		})
	}

	return &cds, cursor, true
}

// parseCreateTableStatement parses `CREATE TABLE name (column definitions)`.
func (p Parser) parseCreateTableStatement(tokens []*Token, initialCursor uint, _ Token) (*CreateTableStatement, uint, bool) {
	cursor := initialCursor
	var ok bool

	_, cursor, ok = p.parseToken(tokens, cursor, tokenFromKeyword(CreateKeyword))
	if !ok {
		return nil, initialCursor, false
	}

	_, cursor, ok = p.parseToken(tokens, cursor, tokenFromKeyword(TableKeyword))
	if !ok {
		return nil, initialCursor, false
	}

	name, newCursor, ok := p.parseTokenKind(tokens, cursor, IdentifierKind)
	if !ok {
		p.helpMessage(tokens, cursor, "Expected table name")
		return nil, initialCursor, false
	}
	cursor = newCursor

	_, cursor, ok = p.parseToken(tokens, cursor, tokenFromSymbol(LeftParenSymbol))
	if !ok {
		p.helpMessage(tokens, cursor, "Expected left parenthesis")
		return nil, initialCursor, false
	}

	cols, newCursor, ok := p.parseColumnDefinitions(tokens, cursor, tokenFromSymbol(RightParenSymbol))
	if !ok {
		return nil, initialCursor, false
	}
	cursor = newCursor

	_, cursor, ok = p.parseToken(tokens, cursor, tokenFromSymbol(RightParenSymbol))
	if !ok {
		p.helpMessage(tokens, cursor, "Expected right parenthesis")
		return nil, initialCursor, false
	}

	return &CreateTableStatement{
		Name: *name,
		Cols: cols,
	}, cursor, true
}

// parseDropTableStatement parses `DROP TABLE name`.
func (p Parser) parseDropTableStatement(tokens []*Token, initialCursor uint, _ Token) (*DropTableStatement, uint, bool) {
	cursor := initialCursor
	var ok bool

	_, cursor, ok = p.parseToken(tokens, cursor, tokenFromKeyword(DropKeyword))
	if !ok {
		return nil, initialCursor, false
	}

	_, cursor, ok = p.parseToken(tokens, cursor, tokenFromKeyword(TableKeyword))
	if !ok {
		return nil, initialCursor, false
	}

	name, newCursor, ok := p.parseTokenKind(tokens, cursor, IdentifierKind)
	if !ok {
		p.helpMessage(tokens, cursor, "Expected table name")
		return nil, initialCursor, false
	}
	cursor = newCursor

	return &DropTableStatement{
		Name: *name,
	}, cursor, true
}

// parseStatement tries each statement parser in turn (SELECT, INSERT,
// CREATE TABLE, CREATE INDEX, DROP TABLE) and returns the first success.
// Every parser is given a semicolon as its statement delimiter.
func (p Parser) parseStatement(tokens []*Token, initialCursor uint, _ Token) (*Statement, uint, bool) {
	cursor := initialCursor

	semicolonToken := tokenFromSymbol(SemicolonSymbol)
	slct, newCursor, ok := p.parseSelectStatement(tokens, cursor, semicolonToken)
	if ok {
		return &Statement{
			Kind:            SelectKind,
			SelectStatement: slct,
		}, newCursor, true
	}

	inst, newCursor, ok := p.parseInsertStatement(tokens, cursor, semicolonToken)
	if ok {
		return &Statement{
			Kind:            InsertKind,
			InsertStatement: inst,
		}, newCursor, true
	}

	crtTbl, newCursor, ok := p.parseCreateTableStatement(tokens, cursor, semicolonToken)
	if ok {
		return &Statement{
			Kind:                 CreateTableKind,
			CreateTableStatement: crtTbl,
		}, newCursor, true
	}

	crtIdx, newCursor, ok := p.parseCreateIndexStatement(tokens, cursor, semicolonToken)
	if ok {
		return &Statement{
			Kind:                 CreateIndexKind,
			CreateIndexStatement: crtIdx,
		}, newCursor, true
	}

	dpTbl, newCursor, ok := p.parseDropTableStatement(tokens, cursor, semicolonToken)
	if ok {
		return &Statement{
			Kind:               DropTableKind,
			DropTableStatement: dpTbl,
		}, newCursor, true
	}

	return nil, initialCursor, false
}

// parseCreateIndexStatement parses
// `CREATE [UNIQUE] INDEX name ON table expression`, where the expression runs
// up to delimiter.
func (p Parser) parseCreateIndexStatement(tokens []*Token, initialCursor uint, delimiter Token) (*CreateIndexStatement, uint, bool) {
	cursor := initialCursor
	var ok bool

	_, cursor, ok = p.parseToken(tokens, cursor, tokenFromKeyword(CreateKeyword))
	if !ok {
		return nil, initialCursor, false
	}

	unique := false
	_, cursor, ok = p.parseToken(tokens, cursor, tokenFromKeyword(UniqueKeyword))
	if ok {
		unique = true
	}

	_, cursor, ok = p.parseToken(tokens, cursor, tokenFromKeyword(IndexKeyword))
	if !ok {
		return nil, initialCursor, false
	}

	name, newCursor, ok := p.parseTokenKind(tokens, cursor, IdentifierKind)
	if !ok {
		p.helpMessage(tokens, cursor, "Expected index name")
		return nil, initialCursor, false
	}
	cursor = newCursor

	_, cursor, ok = p.parseToken(tokens, cursor, tokenFromKeyword(OnKeyword))
	if !ok {
		p.helpMessage(tokens, cursor, "Expected ON Keyword")
		return nil, initialCursor, false
	}

	table, newCursor, ok := p.parseTokenKind(tokens, cursor, IdentifierKind)
	if !ok {
		p.helpMessage(tokens, cursor, "Expected table name")
		return nil, initialCursor, false
	}
	cursor = newCursor

	e, newCursor, ok := p.parseExpression(tokens, cursor, []Token{delimiter}, 0)
	if !ok {
		// The table name was already consumed; what failed is the indexed
		// expression.
		p.helpMessage(tokens, cursor, "Expected index expression")
		return nil, initialCursor, false
	}
	cursor = newCursor

	return &CreateIndexStatement{
		Name:   *name,
		Unique: unique,
		Table:  *table,
		Exp:    *e,
	}, cursor, true
}

// Parse lexes source and parses it into a sequence of statements, appending
// a trailing semicolon when the input does not already end with one.
func (p Parser) Parse(source string) (*Ast, error) {
	tokens, err := lex(source)
	if err != nil {
		return nil, err
	}

	semicolonToken := tokenFromSymbol(SemicolonSymbol)
	if len(tokens) > 0 && !tokens[len(tokens)-1].equals(&semicolonToken) {
		tokens = append(tokens,
&semicolonToken) 642 | } 643 | 644 | a := Ast{} 645 | cursor := uint(0) 646 | for cursor < uint(len(tokens)) { 647 | stmt, newCursor, ok := p.parseStatement(tokens, cursor, tokenFromSymbol(SemicolonSymbol)) 648 | if !ok { 649 | p.helpMessage(tokens, cursor, "Expected statement") 650 | return nil, errors.New("Failed to parse, expected statement") 651 | } 652 | cursor = newCursor 653 | 654 | a.Statements = append(a.Statements, stmt) 655 | 656 | atLeastOneSemicolon := false 657 | for { 658 | _, cursor, ok = p.parseToken(tokens, cursor, tokenFromSymbol(SemicolonSymbol)) 659 | if ok { 660 | atLeastOneSemicolon = true 661 | } else { 662 | break 663 | } 664 | } 665 | 666 | if !atLeastOneSemicolon { 667 | p.helpMessage(tokens, cursor, "Expected semi-colon delimiter between statements") 668 | return nil, errors.New("Missing semi-colon between statements") 669 | } 670 | } 671 | 672 | return &a, nil 673 | } 674 | -------------------------------------------------------------------------------- /parser_test.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestParseExpression(t *testing.T) { 11 | tests := []struct { 12 | source string 13 | ast *Expression 14 | }{ 15 | { 16 | source: "2 = 3 AND 4 = 5", 17 | ast: &Expression{ 18 | Binary: &BinaryExpression{ 19 | A: Expression{ 20 | Binary: &BinaryExpression{ 21 | A: Expression{ 22 | Literal: &Token{"2", NumericKind, Location{0, 0}}, 23 | Kind: LiteralKind, 24 | }, 25 | B: Expression{ 26 | Literal: &Token{"3", NumericKind, Location{0, 5}}, 27 | Kind: LiteralKind, 28 | }, 29 | Op: Token{"=", SymbolKind, Location{0, 3}}, 30 | }, 31 | Kind: BinaryKind, 32 | }, 33 | B: Expression{ 34 | Binary: &BinaryExpression{ 35 | A: Expression{ 36 | Literal: &Token{"4", NumericKind, Location{0, 12}}, 37 | Kind: LiteralKind, 38 | }, 39 | B: Expression{ 40 | Literal: &Token{"5", NumericKind, 
Location{0, 17}}, 41 | Kind: LiteralKind, 42 | }, 43 | Op: Token{"=", SymbolKind, Location{0, 15}}, 44 | }, 45 | Kind: BinaryKind, 46 | }, 47 | Op: Token{"and", KeywordKind, Location{0, 8}}, 48 | }, 49 | Kind: BinaryKind, 50 | }, 51 | }, 52 | } 53 | 54 | for _, test := range tests { 55 | fmt.Println("(Parser) Testing: ", test.source) 56 | tokens, err := lex(test.source) 57 | assert.Nil(t, err) 58 | 59 | parser := Parser{} 60 | ast, cursor, ok := parser.parseExpression(tokens, 0, []Token{}, 0) 61 | assert.True(t, ok, err, test.source) 62 | assert.Equal(t, cursor, uint(len(tokens))) 63 | assert.Equal(t, ast, test.ast, test.source) 64 | } 65 | } 66 | 67 | func TestParse(t *testing.T) { 68 | tests := []struct { 69 | source string 70 | ast *Ast 71 | }{ 72 | { 73 | source: "INSERT INTO users VALUES (105, 233 + 42)", 74 | ast: &Ast{ 75 | Statements: []*Statement{ 76 | { 77 | Kind: InsertKind, 78 | InsertStatement: &InsertStatement{ 79 | Table: Token{ 80 | Loc: Location{Col: 12, Line: 0}, 81 | Kind: IdentifierKind, 82 | Value: "users", 83 | }, 84 | Values: &[]*Expression{ 85 | { 86 | Literal: &Token{ 87 | Loc: Location{Col: 26, Line: 0}, 88 | Kind: NumericKind, 89 | Value: "105", 90 | }, 91 | Kind: LiteralKind, 92 | }, 93 | { 94 | Binary: &BinaryExpression{ 95 | A: Expression{ 96 | Literal: &Token{ 97 | Loc: Location{Col: 32, Line: 0}, 98 | Kind: NumericKind, 99 | Value: "233", 100 | }, 101 | Kind: LiteralKind, 102 | }, 103 | B: Expression{ 104 | Literal: &Token{ 105 | Loc: Location{Col: 39, Line: 0}, 106 | Kind: NumericKind, 107 | Value: "42", 108 | }, 109 | Kind: LiteralKind, 110 | }, 111 | Op: Token{ 112 | Loc: Location{Col: 37, Line: 0}, 113 | Kind: SymbolKind, 114 | Value: string(PlusSymbol), 115 | }, 116 | }, 117 | Kind: BinaryKind, 118 | }, 119 | }, 120 | }, 121 | }, 122 | }, 123 | }, 124 | }, 125 | { 126 | source: "CREATE TABLE users (id INT, name TEXT)", 127 | ast: &Ast{ 128 | Statements: []*Statement{ 129 | { 130 | Kind: CreateTableKind, 131 | 
CreateTableStatement: &CreateTableStatement{ 132 | Name: Token{ 133 | Loc: Location{Col: 13, Line: 0}, 134 | Kind: IdentifierKind, 135 | Value: "users", 136 | }, 137 | Cols: &[]*ColumnDefinition{ 138 | { 139 | Name: Token{ 140 | Loc: Location{Col: 20, Line: 0}, 141 | Kind: IdentifierKind, 142 | Value: "id", 143 | }, 144 | Datatype: Token{ 145 | Loc: Location{Col: 23, Line: 0}, 146 | Kind: KeywordKind, 147 | Value: "int", 148 | }, 149 | }, 150 | { 151 | Name: Token{ 152 | Loc: Location{Col: 28, Line: 0}, 153 | Kind: IdentifierKind, 154 | Value: "name", 155 | }, 156 | Datatype: Token{ 157 | Loc: Location{Col: 33, Line: 0}, 158 | Kind: KeywordKind, 159 | Value: "text", 160 | }, 161 | }, 162 | }, 163 | }, 164 | }, 165 | }, 166 | }, 167 | }, 168 | { 169 | source: "SELECT *, exclusive", 170 | ast: &Ast{ 171 | Statements: []*Statement{ 172 | { 173 | Kind: SelectKind, 174 | SelectStatement: &SelectStatement{ 175 | Item: &[]*SelectItem{ 176 | { 177 | Asterisk: true, 178 | }, 179 | { 180 | Exp: &Expression{ 181 | Kind: LiteralKind, 182 | Literal: &Token{ 183 | Loc: Location{Col: 10, Line: 0}, 184 | Kind: IdentifierKind, 185 | Value: "exclusive", 186 | }, 187 | }, 188 | }, 189 | }, 190 | }, 191 | }, 192 | }, 193 | }, 194 | }, 195 | { 196 | source: `SELECT id, name AS fullname FROM "sketchy name" LIMIT 10 OFFSET 12`, 197 | ast: &Ast{ 198 | Statements: []*Statement{ 199 | { 200 | Kind: SelectKind, 201 | SelectStatement: &SelectStatement{ 202 | Item: &[]*SelectItem{ 203 | { 204 | Exp: &Expression{ 205 | Kind: LiteralKind, 206 | Literal: &Token{ 207 | Loc: Location{Col: 7, Line: 0}, 208 | Kind: IdentifierKind, 209 | Value: "id", 210 | }, 211 | }, 212 | }, 213 | { 214 | Exp: &Expression{ 215 | Kind: LiteralKind, 216 | Literal: &Token{ 217 | Loc: Location{Col: 11, Line: 0}, 218 | Kind: IdentifierKind, 219 | Value: "name", 220 | }, 221 | }, 222 | As: &Token{ 223 | Loc: Location{Col: 19, Line: 0}, 224 | Kind: IdentifierKind, 225 | Value: "fullname", 226 | }, 227 | }, 228 | }, 229 | 
From: &Token{ 230 | Loc: Location{Col: 33, Line: 0}, 231 | Kind: IdentifierKind, 232 | Value: "sketchy name", 233 | }, 234 | Limit: &Expression{ 235 | Kind: LiteralKind, 236 | Literal: &Token{ 237 | Loc: Location{Col: 54, Line: 0}, 238 | Kind: NumericKind, 239 | Value: "10", 240 | }, 241 | }, 242 | Offset: &Expression{ 243 | Kind: LiteralKind, 244 | Literal: &Token{ 245 | Loc: Location{Col: 65, Line: 0}, 246 | Kind: NumericKind, 247 | Value: "12", 248 | }, 249 | }, 250 | }, 251 | }, 252 | }, 253 | }, 254 | }, 255 | } 256 | 257 | for _, test := range tests { 258 | fmt.Println("(Parser) Testing: ", test.source) 259 | parser := Parser{} 260 | ast, err := parser.Parse(test.source) 261 | assert.Nil(t, err, test.source) 262 | assert.Equal(t, test.ast, ast, test.source) 263 | } 264 | } 265 | -------------------------------------------------------------------------------- /repl.go: -------------------------------------------------------------------------------- 1 | package gosql 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "os" 7 | "strings" 8 | 9 | "github.com/chzyer/readline" 10 | "github.com/olekukonko/tablewriter" 11 | ) 12 | 13 | func doSelect(mb Backend, slct *SelectStatement) error { 14 | results, err := mb.Select(slct) 15 | if err != nil { 16 | return err 17 | } 18 | 19 | if len(results.Rows) == 0 { 20 | fmt.Println("(no results)") 21 | return nil 22 | } 23 | 24 | table := tablewriter.NewWriter(os.Stdout) 25 | header := []string{} 26 | for _, col := range results.Columns { 27 | header = append(header, col.Name) 28 | } 29 | table.SetHeader(header) 30 | table.SetAutoFormatHeaders(false) 31 | 32 | rows := [][]string{} 33 | for _, result := range results.Rows { 34 | row := []string{} 35 | for i, cell := range result { 36 | typ := results.Columns[i].Type 37 | r := "" 38 | switch typ { 39 | case IntType: 40 | i := cell.AsInt() 41 | if i != nil { 42 | r = fmt.Sprintf("%d", *i) 43 | } 44 | case TextType: 45 | s := cell.AsText() 46 | if s != nil { 47 | r = *s 48 | } 49 | case 
BoolType: 50 | b := cell.AsBool() 51 | if b != nil { 52 | r = "t" 53 | if !*b { 54 | r = "f" 55 | } 56 | } 57 | } 58 | 59 | row = append(row, r) 60 | } 61 | 62 | rows = append(rows, row) 63 | } 64 | 65 | table.SetBorder(false) 66 | table.AppendBulk(rows) 67 | table.Render() 68 | 69 | if len(rows) == 1 { 70 | fmt.Println("(1 result)") 71 | } else { 72 | fmt.Printf("(%d results)\n", len(rows)) 73 | } 74 | 75 | return nil 76 | } 77 | 78 | func debugTable(b Backend, name string) { 79 | // psql behavior is to display all if no name is specified. 80 | if name == "" { 81 | debugTables(b) 82 | return 83 | } 84 | 85 | var tm *TableMetadata = nil 86 | for _, t := range b.GetTables() { 87 | if t.Name == name { 88 | tm = &t 89 | } 90 | } 91 | 92 | if tm == nil { 93 | fmt.Printf(`Did not find any relation named "%s".\n`, name) 94 | return 95 | } 96 | 97 | fmt.Printf("Table \"%s\"\n", name) 98 | 99 | table := tablewriter.NewWriter(os.Stdout) 100 | table.SetHeader([]string{"Column", "Type", "Nullable"}) 101 | table.SetAutoFormatHeaders(false) 102 | table.SetBorder(false) 103 | 104 | rows := [][]string{} 105 | for _, c := range tm.Columns { 106 | typeString := "integer" 107 | switch c.Type { 108 | case TextType: 109 | typeString = "text" 110 | case BoolType: 111 | typeString = "boolean" 112 | } 113 | nullable := "" 114 | if c.NotNull { 115 | nullable = "not null" 116 | } 117 | rows = append(rows, []string{c.Name, typeString, nullable}) 118 | } 119 | 120 | table.AppendBulk(rows) 121 | table.Render() 122 | 123 | if len(tm.Indexes) > 0 { 124 | fmt.Println("Indexes:") 125 | } 126 | 127 | for _, index := range tm.Indexes { 128 | attributes := []string{} 129 | if index.PrimaryKey { 130 | attributes = append(attributes, "PRIMARY KEY") 131 | } else if index.Unique { 132 | attributes = append(attributes, "UNIQUE") 133 | } 134 | attributes = append(attributes, index.Type) 135 | 136 | fmt.Printf("\t\"%s\" %s (%s)\n", index.Name, strings.Join(attributes, ", "), index.Exp) 137 | } 138 | 139 | 
fmt.Println("") 140 | } 141 | 142 | func debugTables(b Backend) { 143 | tables := b.GetTables() 144 | if len(tables) == 0 { 145 | fmt.Println("Did not find any relations.") 146 | return 147 | } 148 | 149 | fmt.Println("List of relations") 150 | 151 | table := tablewriter.NewWriter(os.Stdout) 152 | table.SetHeader([]string{"Name", "Type"}) 153 | table.SetAutoFormatHeaders(false) 154 | table.SetBorder(false) 155 | 156 | rows := [][]string{} 157 | for _, t := range tables { 158 | rows = append(rows, []string{t.Name, "table"}) 159 | } 160 | 161 | table.AppendBulk(rows) 162 | table.Render() 163 | 164 | fmt.Println("") 165 | } 166 | 167 | func RunRepl(b Backend) { 168 | l, err := readline.NewEx(&readline.Config{ 169 | Prompt: "# ", 170 | HistoryFile: "/tmp/tmp", 171 | InterruptPrompt: "^C", 172 | EOFPrompt: "exit", 173 | }) 174 | if err != nil { 175 | panic(err) 176 | } 177 | defer l.Close() 178 | 179 | fmt.Println("Welcome to gosql.") 180 | repl: 181 | for { 182 | fmt.Print("# ") 183 | line, err := l.Readline() 184 | if err == readline.ErrInterrupt { 185 | if len(line) == 0 { 186 | break 187 | } else { 188 | continue repl 189 | } 190 | } else if err == io.EOF { 191 | break 192 | } 193 | if err != nil { 194 | fmt.Println("Error while reading line:", err) 195 | continue repl 196 | } 197 | 198 | parser := Parser{} 199 | 200 | trimmed := strings.TrimSpace(line) 201 | if trimmed == "quit" || trimmed == "exit" || trimmed == "\\q" { 202 | break 203 | } 204 | 205 | if trimmed == "\\dt" { 206 | debugTables(b) 207 | continue 208 | } 209 | 210 | if strings.HasPrefix(trimmed, "\\d") { 211 | name := strings.TrimSpace(trimmed[len("\\d"):]) 212 | debugTable(b, name) 213 | continue 214 | } 215 | 216 | parseOnly := false 217 | if strings.HasPrefix(trimmed, "\\p") { 218 | line = strings.TrimSpace(trimmed[len("\\p"):]) 219 | parseOnly = true 220 | } 221 | 222 | ast, err := parser.Parse(line) 223 | if err != nil { 224 | fmt.Println("Error while parsing:", err) 225 | continue repl 226 | } 
227 | 228 | for _, stmt := range ast.Statements { 229 | if parseOnly { 230 | fmt.Println(stmt.GenerateCode()) 231 | continue 232 | } 233 | 234 | switch stmt.Kind { 235 | case CreateIndexKind: 236 | err = b.CreateIndex(ast.Statements[0].CreateIndexStatement) 237 | if err != nil { 238 | fmt.Println("Error adding index on table:", err) 239 | continue repl 240 | } 241 | case CreateTableKind: 242 | err = b.CreateTable(ast.Statements[0].CreateTableStatement) 243 | if err != nil { 244 | fmt.Println("Error creating table:", err) 245 | continue repl 246 | } 247 | case DropTableKind: 248 | err = b.DropTable(ast.Statements[0].DropTableStatement) 249 | if err != nil { 250 | fmt.Println("Error dropping table:", err) 251 | continue repl 252 | } 253 | case InsertKind: 254 | err = b.Insert(stmt.InsertStatement) 255 | if err != nil { 256 | fmt.Println("Error inserting values:", err) 257 | continue repl 258 | } 259 | case SelectKind: 260 | err := doSelect(b, stmt.SelectStatement) 261 | if err != nil { 262 | fmt.Println("Error selecting values:", err) 263 | continue repl 264 | } 265 | } 266 | } 267 | 268 | fmt.Println("ok") 269 | } 270 | } 271 | --------------------------------------------------------------------------------