├── LICENSE ├── README.md ├── cond.go ├── conn.go ├── delete.go ├── delete_test.go ├── driver.go ├── execstub.go ├── input.go ├── insert.go ├── insert_test.go ├── internal ├── proto │ ├── LICENSE │ └── query │ │ └── query.pb.go ├── sqlparser │ ├── LICENSE │ ├── Makefile │ ├── analyzer.go │ ├── ast.go │ ├── ast_test.go │ ├── parse_test.go │ ├── parsed_query.go │ ├── parsed_query_test.go │ ├── precedence_test.go │ ├── sql.go │ ├── sql.y │ ├── token.go │ └── tracked_buffer.go └── sqltypes │ ├── LICENSE │ ├── proto3.go │ ├── proto3_test.go │ ├── result.go │ ├── result_test.go │ ├── type.go │ ├── type_test.go │ ├── value.go │ └── value_test.go ├── magic.go ├── mogi.go ├── mogi_test.go ├── rows.go ├── select.go ├── select_test.go ├── stmt.go ├── stub.go ├── tx.go ├── unify.go ├── update.go ├── update_test.go └── where.go /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2014, Greg Roseberry 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 8 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 9 | 10 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 11 | 12 | 13 | Portions of this software use code from https://github.com/DATA-DOG/go-sqlmock 14 | Copyright (c) 2013, DataDog.lt team 15 | All rights reserved. 16 | 17 | Redistribution and use in source and binary forms, with or without 18 | modification, are permitted provided that the following conditions are met: 19 | 20 | * Redistributions of source code must retain the above copyright notice, this 21 | list of conditions and the following disclaimer. 22 | 23 | * Redistributions in binary form must reproduce the above copyright notice, 24 | this list of conditions and the following disclaimer in the documentation 25 | and/or other materials provided with the distribution. 26 | 27 | * The name DataDog.lt may not be used to endorse or promote products 28 | derived from this software without specific prior written permission. 29 | 30 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 31 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 32 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 33 | DISCLAIMED. 
IN NO EVENT SHALL MICHAEL BOSTOCK BE LIABLE FOR ANY DIRECT, 34 | INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 35 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 36 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 37 | OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 38 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, 39 | EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## mogi [![GoDoc](https://godoc.org/github.com/guregu/mogi?status.svg)](https://godoc.org/github.com/guregu/mogi) [![Coverage](http://gocover.io/_badge/github.com/guregu/mogi)](http://gocover.io/github.com/guregu/mogi) 2 | `import "github.com/guregu/mogi"` 3 | 4 | mogi is a fancy SQL mocking/stubbing library for Go. It uses the [vitess](https://github.com/vitessio/vitess) SQL parser for maximum happiness. 5 | 6 | **Note**: because vitess is a MySQL-flavored parser, other kinds of (non-)standard SQL may break mogi. Mogi isn't finished yet. You can't yet filter stubs based on subqueries and complex bits like ON DUPLICATE KEY UPDATE. 7 | 8 | 9 | ### Usage 10 | 11 | #### Getting started 12 | ```go 13 | import "github.com/guregu/mogi" 14 | db, _ := sql.Open("mogi", "") 15 | ``` 16 | 17 | #### Stubbing SELECT queries 18 | ```go 19 | // Stub any SELECT query 20 | mogi.Select().StubCSV(`1,Yona Yona Ale,Yo-Ho Brewing,5.5`) 21 | rows, err := db.Query("SELECT id, name, brewery, pct FROM beer") 22 | 23 | // Reset to clear all stubs 24 | mogi.Reset() 25 | 26 | // Stub SELECT queries by columns selected 27 | mogi.Select("id", "name", "brewery", "pct").StubCSV(`1,Yona Yona Ale,Yo-Ho Brewing,5.5`) 28 | // Aliased columns should be given as they are aliased. 
29 | // Qualified columns should be given as they are qualified. 30 | // e.g. SELECT beer.name AS n, breweries.founded FROM beer JOIN breweries ON beer.brewery = breweries.name 31 | mogi.Select("n", "breweries.founded").StubCSV(`Stone IPA,1996`) 32 | 33 | // You can stub with driver.Values instead of CSV 34 | mogi.Select("id", "deleted_at").Stub([][]driver.Value{{1, nil}}) 35 | 36 | // Filter by table name 37 | mogi.Select().From("beer").StubCSV(`1,Yona Yona Ale,Yo-Ho Brewing,5.5`) 38 | 39 | // You can supply multiple table names for JOIN queries 40 | // e.g. SELECT beer.name, wine.name FROM beer JOIN wine ON beer.pct = wine.pct 41 | // or SELECT beer.name, wine.name FROM beer, wine WHERE beer.pct = wine.pct 42 | mogi.Select().From("beer", "wine").StubCSV(`Westvleteren XII,Across the Pond Riesling`) 43 | 44 | // Filter by WHERE clause params 45 | mogi.Select().Where("id", 10).StubCSV(`10,Apex,Bear Republic Brewing Co.,8.95`) 46 | mogi.Select().Where("id", 42).StubCSV(`42,Westvleteren XII,Brouwerij Westvleteren,10.2`) 47 | rows, err := db.Query("SELECT id, name, brewery, pct FROM beer WHERE id = ?", 10) 48 | ... 49 | rows, err = db.Query("SELECT id, name, brewery, pct FROM beer WHERE id = ?", 42) 50 | ... 51 | 52 | // Pass multiple arguments to Where() for IN clauses. 
53 | mogi.Select().Where("id", 10, 42).StubCSV("Apex\nWestvleteren XII") 54 | rows, err = db.Query("SELECT name FROM beer WHERE id IN (?, ?)", 10, 42) 55 | 56 | // Stub an error while you're at it 57 | mogi.Select().Where("id", 3).StubError(sql.ErrNoRows) 58 | // FYI, unstubbed queries will return mogi.ErrUnstubbed 59 | 60 | // Filter by args given 61 | mogi.Select().Args(1).StubCSV(`1,Yona Yona Ale,Yo-Ho Brewing,5.5`) 62 | rows, err := db.Query("SELECT id, name, brewery, pct FROM beer WHERE id = ?", 1) 63 | 64 | // Chain filters as much as you'd like 65 | mogi.Select("id", "name", "brewery", "pct").From("beer").Where("id", 1).StubCSV(`1,Yona Yona Ale,Yo-Ho Brewing,5.5`) 66 | ``` 67 | 68 | #### Stubbing INSERT queries 69 | ```go 70 | // Stub any INSERT query 71 | // You can use StubResult to easily stub a driver.Result. 72 | // You can pass -1 to StubResult to have it return an error for that particular bit. 73 | // In this example, we have 1 row affected, but no LastInsertID. 74 | mogi.Insert().StubResult(-1, 1) 75 | // If you have your own driver.Result you want to pass, just use Stub. 76 | // You can also stub an error with StubError. 77 | 78 | // Filter by the columns used in the INSERT query 79 | mogi.Insert("name", "brewery", "pct").StubResult(1, 1) 80 | result, err := db.Exec("INSERT INTO beer (name, brewery, pct) VALUES (?, ?, ?)", "Yona Yona Ale", "Yo-Ho Brewing", 5.5) 81 | 82 | // Filter by the table used in the query 83 | mogi.Insert().Into("beer").StubResult(1, 1) 84 | 85 | // Filter by the args passed to the query (the things replacing the ?s) 86 | mogi.Insert().Args("Yona Yona Ale", "Yo-Ho Brewing", 5.5).StubResult(1, 1) 87 | 88 | // Filter by the values used in the query 89 | mogi.Insert().Value("name", "Yona Yona Ale").Value("brewery", "Yo-Ho Brewing").StubResult(1, 1) 90 | // Use ValueAt when you are inserting multiple rows. The first argument is the row #, starting with 0. 91 | // Parameters are interpolated for you. 92 | mogi.Insert(). 
93 | ValueAt(0, "brewery", "Mikkeller").ValueAt(0, "pct", 4.6). 94 | ValueAt(1, "brewery", "BrewDog").ValueAt(1, "pct", 18.2). 95 | StubResult(4, 2) 96 | result, err = db.Exec(`INSERT INTO beer (name, brewery, pct) VALUES (?, "Mikkeller", 4.6), (?, ?, ?)`, 97 | "Mikkel’s Dream", 98 | "Tokyo*", "BrewDog", 18.2, 99 | ) 100 | ``` 101 | 102 | #### Stubbing UPDATE queries 103 | ```go 104 | // Stub any UPDATE query 105 | // UPDATE stubs work the same as INSERT stubs 106 | // This stubs all UPDATE queries to return 10 rows affected 107 | mogi.Update().StubResult(-1, 10) 108 | // This does the same thing 109 | mogi.Update().StubRowsAffected(10) 110 | 111 | // Filter by the columns used in the SET clause 112 | mogi.Update("name", "brewery", "pct").StubRowsAffected(1) 113 | _, err := db.Exec(`UPDATE beer 114 | SET name = "Mikkel’s Dream", brewery = "Mikkeller", pct = 4.6 115 | WHERE id = ? AND moon = ?`, 3, "full") 116 | 117 | // Filter by values set by the SET clause 118 | mogi.Update().Value("name", "Mikkel’s Dream").Value("brewery", "Mikkeller").StubRowsAffected(1) 119 | 120 | // Filter by args (? placeholder values) 121 | mogi.Update().Args(3, "full").StubRowsAffected(1) 122 | 123 | // Filter by the table being updated 124 | mogi.Update().Table("beer").StubRowsAffected(1) 125 | 126 | // Filter by WHERE clause params 127 | mogi.Update().Where("id", 3).Where("moon", "full").StubRowsAffected(1) 128 | ``` 129 | 130 | #### Stubbing DELETE queries 131 | Works the same as UPDATE, docs later! 132 | 133 | #### Other stuff 134 | 135 | ##### Reset 136 | You can remove all the stubs you've set with `mogi.Reset()`. 137 | 138 | ##### Verbose 139 | `mogi.Verbose(true)` will enable verbose mode, logging unstubbed queries. 140 | 141 | ##### Parse time 142 | Set the time layout with `mogi.ParseTime()`. CSV values matching that layout will be converted to time.Time. 143 | You can also stub time.Time directly using the `Stub()` method. 
144 | ```go 145 | mogi.ParseTime(time.RFC3339) 146 | mogi.Select("release"). 147 | From("beer"). 148 | Where("id", 42). 149 | StubCSV(`2014-06-30T12:00:00Z`) 150 | ``` 151 | 152 | ##### Dump stubs 153 | Dump all the stubs with `mogi.Dump()`. It will print something like this: 154 | ``` 155 | >> Query stubs: (1 total) 156 | ========================= 157 | #1 [3] SELECT (any) [+1] 158 | FROM device_tokens [+1] 159 | WHERE user_id ≈ [42] [+1] 160 | → error: sql: no rows in result set 161 | 162 | >> Exec stubs: (2 total) 163 | ========================= 164 | #1 [3] INSERT (any) [+1] 165 | TABLE device_tokens [+1] 166 | VALUE device_type ≈ gunosy_lite (row 0) [+1] 167 | → result ID: 1337, rows: 1 168 | #2 [2] INSERT (any) [+1] 169 | TABLE device_tokens [+1] 170 | → error: device_type should be overwriten 171 | ``` 172 | This is helpful when you're debugging and need to double-check the priorities and conditions you've stubbed. 173 | The numbers in [brackets] are the priorities. 174 | You can also add `Dump()` to a stub condition chain. It will dump lots of information about the query when matched. 
175 | 176 | ### License 177 | BSD 178 | -------------------------------------------------------------------------------- /cond.go: -------------------------------------------------------------------------------- 1 | package mogi 2 | 3 | import ( 4 | "reflect" 5 | "strings" 6 | // "database/sql" 7 | "database/sql/driver" 8 | "fmt" 9 | 10 | "github.com/davecgh/go-spew/spew" 11 | "github.com/guregu/mogi/internal/sqlparser" 12 | ) 13 | 14 | type cond interface { 15 | matches(in input) bool 16 | priority() int 17 | fmt.Stringer 18 | } 19 | 20 | type condchain []cond 21 | 22 | func (chain condchain) matches(in input) bool { 23 | for _, c := range chain { 24 | if !c.matches(in) { 25 | return false 26 | } 27 | } 28 | return true 29 | } 30 | 31 | func (chain condchain) priority() int { 32 | p := 0 33 | for _, c := range chain { 34 | p += c.priority() 35 | } 36 | return p 37 | } 38 | 39 | func (chain condchain) String() string { 40 | return "Chain..." 41 | } 42 | 43 | type tableCond struct { 44 | table string 45 | } 46 | 47 | func (tc tableCond) matches(in input) bool { 48 | switch x := in.statement.(type) { 49 | case *sqlparser.Insert: 50 | return strings.ToLower(tc.table) == strings.ToLower(string(x.Table.Name)) 51 | case *sqlparser.Update: 52 | return strings.ToLower(tc.table) == strings.ToLower(string(x.Table.Name)) 53 | case *sqlparser.Delete: 54 | return strings.ToLower(tc.table) == strings.ToLower(string(x.Table.Name)) 55 | } 56 | return false 57 | } 58 | 59 | func (tc tableCond) priority() int { 60 | return 1 61 | } 62 | 63 | func (tc tableCond) String() string { 64 | return fmt.Sprintf("TABLE %s", tc.table) 65 | } 66 | 67 | type argsCond struct { 68 | args []driver.Value 69 | } 70 | 71 | func (ac argsCond) matches(in input) bool { 72 | given := unifyValues(ac.args) 73 | return reflect.DeepEqual(given, in.args) 74 | } 75 | 76 | func (ac argsCond) priority() int { 77 | return 1 78 | } 79 | 80 | func (ac argsCond) String() string { 81 | return fmt.Sprintf("WITH ARGS 
%+v", ac.args) 82 | } 83 | 84 | type valueCond struct { 85 | row int 86 | col string 87 | v interface{} 88 | } 89 | 90 | func newValueCond(row int, col string, v interface{}) valueCond { 91 | return valueCond{ 92 | row: row, 93 | col: col, 94 | v: unify(v), 95 | } 96 | } 97 | 98 | func (vc valueCond) matches(in input) bool { 99 | switch in.statement.(type) { 100 | case *sqlparser.Insert: 101 | values := in.rows() 102 | if vc.row > len(values)-1 { 103 | return false 104 | } 105 | v, ok := values[vc.row][vc.col] 106 | if !ok { 107 | return false 108 | } 109 | return equals(v, vc.v) 110 | case *sqlparser.Update: 111 | values := in.values() 112 | v, ok := values[vc.col] 113 | if !ok { 114 | return false 115 | } 116 | return equals(v, vc.v) 117 | } 118 | return false 119 | } 120 | 121 | func (vc valueCond) priority() int { 122 | return 1 123 | } 124 | 125 | func (vc valueCond) String() string { 126 | return fmt.Sprintf("VALUE %s ≈ %v (row %d)", vc.col, vc.v, vc.row) 127 | } 128 | 129 | type priorityCond struct { 130 | p int 131 | } 132 | 133 | func (pc priorityCond) matches(in input) bool { 134 | return true 135 | } 136 | 137 | func (pc priorityCond) priority() int { 138 | return pc.p 139 | } 140 | 141 | func (pc priorityCond) String() string { 142 | return "PRIORITY" 143 | } 144 | 145 | type notifyCond struct { 146 | ch chan<- struct{} 147 | } 148 | 149 | func (nc notifyCond) matches(in input) bool { 150 | go func() { 151 | nc.ch <- struct{}{} 152 | }() 153 | return true 154 | } 155 | 156 | func (nc notifyCond) priority() int { 157 | return 0 158 | } 159 | 160 | func (nc notifyCond) String() string { 161 | return "NOTIFY" 162 | } 163 | 164 | type dumpCond struct{} 165 | 166 | func (dc dumpCond) matches(in input) bool { 167 | fmt.Println(in.query) 168 | spew.Dump(in.args) 169 | switch in.statement.(type) { 170 | case *sqlparser.Select: 171 | spew.Dump(in.cols(), in.where()) 172 | case *sqlparser.Insert: 173 | spew.Dump(in.cols(), in.rows()) 174 | case *sqlparser.Update: 
175 | spew.Dump(in.cols(), in.values(), in.where()) 176 | } 177 | spew.Dump(in.statement) 178 | return true 179 | } 180 | 181 | func (dc dumpCond) priority() int { 182 | return 0 183 | } 184 | 185 | func (dc dumpCond) String() string { 186 | return "DUMP" 187 | } 188 | -------------------------------------------------------------------------------- /conn.go: -------------------------------------------------------------------------------- 1 | package mogi 2 | 3 | import ( 4 | "log" 5 | "sort" 6 | 7 | "database/sql/driver" 8 | ) 9 | 10 | type conn struct { 11 | stubs stubs 12 | execStubs execStubs 13 | } 14 | 15 | func newConn() *conn { 16 | return &conn{} 17 | } 18 | 19 | func addStub(s *Stub) { 20 | drv.conn.stubs = append(drv.conn.stubs, s) 21 | sort.Sort(drv.conn.stubs) 22 | } 23 | 24 | func addExecStub(s *ExecStub) { 25 | drv.conn.execStubs = append(drv.conn.execStubs, s) 26 | sort.Sort(drv.conn.execStubs) 27 | } 28 | 29 | func (c *conn) Prepare(query string) (driver.Stmt, error) { 30 | return &stmt{ 31 | query: query, 32 | }, nil 33 | } 34 | 35 | func (c *conn) Close() error { 36 | return nil 37 | } 38 | 39 | func (c *conn) Begin() (driver.Tx, error) { 40 | return &tx{}, nil 41 | } 42 | 43 | func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) { 44 | in, err := newInput(query, args) 45 | if err != nil { 46 | return nil, err 47 | } 48 | for _, c := range c.stubs { 49 | if c.matches(in) { 50 | return c.rows(in) 51 | } 52 | } 53 | if verbose { 54 | log.Println("Unstubbed query:", query, args) 55 | } 56 | return nil, ErrUnstubbed 57 | } 58 | 59 | func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) { 60 | in, err := newInput(query, args) 61 | if err != nil { 62 | return nil, err 63 | } 64 | for _, c := range c.execStubs { 65 | if c.matches(in) { 66 | return c.results() 67 | } 68 | } 69 | if verbose { 70 | log.Println("Unstubbed query:", query, args) 71 | } 72 | return nil, ErrUnstubbed 73 | } 74 | 
-------------------------------------------------------------------------------- /delete.go: -------------------------------------------------------------------------------- 1 | package mogi 2 | 3 | import ( 4 | "github.com/guregu/mogi/internal/sqlparser" 5 | ) 6 | 7 | type deleteCond struct{} 8 | 9 | func (uc deleteCond) matches(in input) bool { 10 | _, ok := in.statement.(*sqlparser.Delete) 11 | return ok 12 | } 13 | 14 | func (uc deleteCond) priority() int { 15 | return 1 16 | } 17 | 18 | func (uc deleteCond) String() string { 19 | return "DELETE" 20 | } 21 | -------------------------------------------------------------------------------- /delete_test.go: -------------------------------------------------------------------------------- 1 | package mogi_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/guregu/mogi" 7 | ) 8 | 9 | func TestDelete(t *testing.T) { 10 | defer mogi.Reset() 11 | db := openDB() 12 | 13 | mogi.Delete().StubRowsAffected(1) 14 | _, err := db.Exec("DELETE FROM beer WHERE id = ?", 42) 15 | checkNil(t, err) 16 | 17 | mogi.Reset() 18 | mogi.Delete().Table("beer").StubRowsAffected(1) 19 | _, err = db.Exec("DELETE FROM beer WHERE id = ?", 42) 20 | checkNil(t, err) 21 | 22 | mogi.Reset() 23 | mogi.Delete().Table("beer").Where("id", 42).StubRowsAffected(1) 24 | _, err = db.Exec("DELETE FROM beer WHERE id = ?", 42) 25 | checkNil(t, err) 26 | 27 | mogi.Reset() 28 | mogi.Delete().Table("beer").WhereOp("id", "=", 42).StubRowsAffected(1) 29 | _, err = db.Exec("DELETE FROM beer WHERE id = ?", 42) 30 | checkNil(t, err) 31 | 32 | mogi.Reset() 33 | mogi.Delete().Table("beer").WhereOp("id", "=", 42).WhereOp("id", ">", 100).StubRowsAffected(1) 34 | _, err = db.Exec("DELETE FROM beer WHERE id = ? 
OR id > 100", 42) 35 | checkNil(t, err) 36 | 37 | mogi.Reset() 38 | mogi.Delete().Table("beer").Where("id", 50).StubRowsAffected(1) 39 | _, err = db.Exec("DELETE FROM beer WHERE id = ?", 42) 40 | if err != mogi.ErrUnstubbed { 41 | t.Error("err should be ErrUnstubbed but is", err) 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /driver.go: -------------------------------------------------------------------------------- 1 | package mogi 2 | 3 | import ( 4 | "database/sql/driver" 5 | ) 6 | 7 | var drv *mdriver 8 | 9 | type mdriver struct { 10 | *conn 11 | } 12 | 13 | func newDriver() *mdriver { 14 | return &mdriver{ 15 | conn: newConn(), 16 | } 17 | } 18 | 19 | func (d *mdriver) Open(name string) (driver.Conn, error) { 20 | return drv.conn, nil 21 | } 22 | 23 | type execResult struct { 24 | lastInsertID int64 25 | rowsAffected int64 26 | } 27 | 28 | func (r execResult) LastInsertId() (int64, error) { 29 | if r.lastInsertID == -1 { 30 | return 0, errNotSet 31 | } 32 | return r.lastInsertID, nil 33 | } 34 | 35 | // RowsAffected returns the number of rows affected by the 36 | // query. 37 | func (r execResult) RowsAffected() (int64, error) { 38 | if r.rowsAffected == -1 { 39 | return 0, errNotSet 40 | } 41 | return r.rowsAffected, nil 42 | } 43 | -------------------------------------------------------------------------------- /execstub.go: -------------------------------------------------------------------------------- 1 | package mogi 2 | 3 | import ( 4 | "database/sql/driver" 5 | ) 6 | 7 | // ExecStub is a SQL exec stub (for INSERT, UPDATE, DELETE) 8 | type ExecStub struct { 9 | chain condchain 10 | result driver.Result 11 | err error 12 | } 13 | 14 | // Insert starts a new stub for INSERT statements. 15 | // You can filter out which columns to use this stub for. 16 | // If you don't pass any columns, it will stub all INSERT queries. 
17 | func Insert(cols ...string) *ExecStub { 18 | return &ExecStub{ 19 | chain: condchain{insertCond{ 20 | cols: cols, 21 | }}, 22 | } 23 | } 24 | 25 | // Update starts a new stub for UPDATE statements. 26 | // You can filter out which columns (from the SET statement) this stub is for. 27 | // If you don't pass any columns, it will stub all UPDATE queries. 28 | func Update(cols ...string) *ExecStub { 29 | return &ExecStub{ 30 | chain: condchain{updateCond{ 31 | cols: cols, 32 | }}, 33 | } 34 | } 35 | 36 | // Delete starts a new stub for DELETE statements. 37 | func Delete() *ExecStub { 38 | return &ExecStub{ 39 | chain: condchain{deleteCond{}}, 40 | } 41 | } 42 | 43 | // Table further filters this stub, matching the target table in INSERT, UPDATE, or DELETE. 44 | func (s *ExecStub) Table(table string) *ExecStub { 45 | s.chain = append(s.chain, tableCond{ 46 | table: table, 47 | }) 48 | return s 49 | } 50 | 51 | // Into further filters this stub, matching based on the INTO table specified. 52 | func (s *ExecStub) Into(table string) *ExecStub { 53 | return s.Table(table) 54 | } 55 | 56 | // From further filters this stub, matching based on the FROM table specified. 57 | func (s *ExecStub) From(table string) *ExecStub { 58 | return s.Table(table) 59 | } 60 | 61 | // Value further filters this stub, matching based on values supplied to the query 62 | // For INSERTs, it matches the first row of values, so it is a shortcut for ValueAt(0, ...) 63 | // For UPDATEs, it matches on the SET clause. 64 | func (s *ExecStub) Value(col string, v interface{}) *ExecStub { 65 | s.ValueAt(0, col, v) 66 | return s 67 | } 68 | 69 | // ValueAt further filters this stub, matching based on values supplied to the query 70 | func (s *ExecStub) ValueAt(row int, col string, v interface{}) *ExecStub { 71 | s.chain = append(s.chain, newValueCond(row, col, v)) 72 | return s 73 | } 74 | 75 | // Where further filters this stub by values of input in the WHERE clause. 
76 | // You can pass multiple values for IN clause matching. 77 | func (s *ExecStub) Where(col string, v ...interface{}) *ExecStub { 78 | s.chain = append(s.chain, newWhereCond(col, v)) 79 | return s 80 | } 81 | 82 | // WhereOp further filters this stub by values of input and the operator used in the WHERE clause. 83 | func (s *ExecStub) WhereOp(col string, operator string, v ...interface{}) *ExecStub { 84 | s.chain = append(s.chain, newWhereOpCond(col, v, operator)) 85 | return s 86 | } 87 | 88 | // Args further filters this stub, matching based on the args passed to the query 89 | func (s *ExecStub) Args(args ...driver.Value) *ExecStub { 90 | s.chain = append(s.chain, argsCond{args}) 91 | return s 92 | } 93 | 94 | // Priority adds the given priority to this stub, without performing any matching. 95 | func (s *ExecStub) Priority(p int) *ExecStub { 96 | s.chain = append(s.chain, priorityCond{p}) 97 | return s 98 | } 99 | 100 | // Notify will have this stub send to the given channel when matched. 101 | // You should put this as the last part of your stub chain. 102 | func (s *ExecStub) Notify(ch chan<- struct{}) *ExecStub { 103 | s.chain = append(s.chain, notifyCond{ch}) 104 | return s 105 | } 106 | 107 | // Dump outputs debug information, without performing any matching. 108 | func (s *ExecStub) Dump() *ExecStub { 109 | s.chain = append(s.chain, dumpCond{}) 110 | return s 111 | } 112 | 113 | // Stub takes a driver.Result and registers this stub with the driver 114 | func (s *ExecStub) Stub(res driver.Result) { 115 | s.result = res 116 | addExecStub(s) 117 | } 118 | 119 | // StubResult is an easy way to stub a driver.Result. 120 | // Given a value of -1, the result will return an error for that particular part. 
121 | func (s *ExecStub) StubResult(lastInsertID, rowsAffected int64) { 122 | s.result = execResult{ 123 | lastInsertID: lastInsertID, 124 | rowsAffected: rowsAffected, 125 | } 126 | addExecStub(s) 127 | } 128 | 129 | // StubRowsAffected is an easy way to stub a driver.Result when you only need to specify the rows affected. 130 | func (s *ExecStub) StubRowsAffected(rowsAffected int64) { 131 | s.StubResult(-1, rowsAffected) 132 | } 133 | 134 | // StubError takes an error and registers this stub with the driver 135 | func (s *ExecStub) StubError(err error) { 136 | s.err = err 137 | addExecStub(s) 138 | } 139 | 140 | func (s *ExecStub) matches(in input) bool { 141 | return s.chain.matches(in) 142 | } 143 | 144 | func (s *ExecStub) results() (driver.Result, error) { 145 | return s.result, s.err 146 | } 147 | 148 | func (s *ExecStub) priority() int { 149 | return s.chain.priority() 150 | } 151 | 152 | type execStubs []*ExecStub 153 | 154 | func (s execStubs) Len() int { return len(s) } 155 | func (s execStubs) Less(i, j int) bool { return s[i].priority() > s[j].priority() } 156 | func (s execStubs) Swap(i, j int) { s[i], s[j] = s[j], s[i] } 157 | -------------------------------------------------------------------------------- /input.go: -------------------------------------------------------------------------------- 1 | package mogi 2 | 3 | import ( 4 | "database/sql/driver" 5 | "log" 6 | 7 | // "github.com/davecgh/go-spew/spew" 8 | "github.com/guregu/mogi/internal/sqlparser" 9 | ) 10 | 11 | type input struct { 12 | query string 13 | statement sqlparser.Statement 14 | args []driver.Value 15 | 16 | whereVars map[string]interface{} 17 | whereOpVars map[colop]interface{} 18 | } 19 | 20 | func newInput(query string, args []driver.Value) (in input, err error) { 21 | in = input{ 22 | query: query, 23 | args: args, 24 | } 25 | in.statement, err = sqlparser.Parse(query) 26 | return 27 | } 28 | 29 | type arg int 30 | 31 | type opval struct { 32 | op string 33 | v interface{} 34 
| } 35 | 36 | type colop struct { 37 | col string 38 | op string 39 | } 40 | 41 | /* 42 | Column name rules: 43 | SELECT a → a 44 | SELECT a.b → a.b 45 | SELECT a.b AS c → c 46 | */ 47 | func (in input) cols() []string { 48 | var cols []string 49 | 50 | switch x := in.statement.(type) { 51 | case *sqlparser.Select: 52 | for _, sexpr := range x.SelectExprs { 53 | name := stringify(transmogrify(sexpr)) 54 | cols = append(cols, name) 55 | } 56 | case *sqlparser.Insert: 57 | for _, c := range x.Columns { 58 | nse, ok := c.(*sqlparser.NonStarExpr) 59 | if !ok { 60 | log.Println("something other than NonStarExpr", c) 61 | continue 62 | } 63 | name := extractColumnName(nse) 64 | cols = append(cols, name) 65 | } 66 | case *sqlparser.Update: 67 | for _, expr := range x.Exprs { 68 | // TODO qualifiers 69 | name := string(expr.Name.Name) 70 | cols = append(cols, name) 71 | } 72 | } 73 | return cols 74 | } 75 | 76 | // for UPDATEs 77 | func (in input) values() map[string]interface{} { 78 | vals := make(map[string]interface{}) 79 | 80 | switch x := in.statement.(type) { 81 | case *sqlparser.Update: 82 | for _, expr := range x.Exprs { 83 | // TODO qualifiers 84 | colName := string(expr.Name.Name) 85 | v := transmogrify(expr.Expr) 86 | if a, ok := v.(arg); ok { 87 | // replace placeholders 88 | v = unify(in.args[int(a)]) 89 | } 90 | vals[colName] = v 91 | } 92 | } 93 | 94 | return vals 95 | } 96 | 97 | // for INSERTs 98 | func (in input) rows() []map[string]interface{} { 99 | var vals []map[string]interface{} 100 | cols := in.cols() 101 | 102 | switch x := in.statement.(type) { 103 | case *sqlparser.Insert: 104 | insertRows := x.Rows.(sqlparser.Values) 105 | vals = make([]map[string]interface{}, len(insertRows)) 106 | for i, rowTuple := range insertRows { 107 | vals[i] = make(map[string]interface{}) 108 | row := rowTuple.(sqlparser.ValTuple) 109 | for j, val := range row { 110 | colName := cols[j] 111 | v := transmogrify(val) 112 | if a, ok := v.(arg); ok { 113 | // replace 
placeholders 114 | v = unify(in.args[int(a)]) 115 | } 116 | vals[i][colName] = v 117 | } 118 | } 119 | } 120 | return vals 121 | } 122 | 123 | // for SELECT and UPDATE and DELETE 124 | func (in input) where() map[string]interface{} { 125 | if in.whereVars != nil { 126 | return in.whereVars 127 | } 128 | var w *sqlparser.Where 129 | switch x := in.statement.(type) { 130 | case *sqlparser.Select: 131 | w = x.Where 132 | case *sqlparser.Update: 133 | w = x.Where 134 | case *sqlparser.Delete: 135 | w = x.Where 136 | default: 137 | return nil 138 | } 139 | if w == nil { 140 | return map[string]interface{}{} 141 | } 142 | in.whereVars = extractBoolExpr(nil, w.Expr) 143 | // replace placeholders 144 | for k, v := range in.whereVars { 145 | if a, ok := v.(arg); ok { 146 | in.whereVars[k] = unify(in.args[int(a)]) 147 | continue 148 | } 149 | 150 | // arrays 151 | if arr, ok := v.([]interface{}); ok { 152 | for i, v := range arr { 153 | if a, ok := v.(arg); ok { 154 | arr[i] = unify(in.args[int(a)]) 155 | } 156 | } 157 | } 158 | } 159 | return in.whereVars 160 | } 161 | 162 | // for SELECT and UPDATE and DELETE 163 | func (in input) whereOp() map[colop]interface{} { 164 | if in.whereOpVars != nil { 165 | return in.whereOpVars 166 | } 167 | var w *sqlparser.Where 168 | switch x := in.statement.(type) { 169 | case *sqlparser.Select: 170 | w = x.Where 171 | case *sqlparser.Update: 172 | w = x.Where 173 | case *sqlparser.Delete: 174 | w = x.Where 175 | default: 176 | return nil 177 | } 178 | if w == nil { 179 | return map[colop]interface{}{} 180 | } 181 | in.whereOpVars = extractBoolExprWithOps(nil, w.Expr) 182 | // replace placeholders 183 | for k, v := range in.whereOpVars { 184 | if a, ok := v.(arg); ok { 185 | in.whereOpVars[k] = unify(in.args[int(a)]) 186 | continue 187 | } 188 | 189 | // arrays 190 | if arr, ok := v.([]interface{}); ok { 191 | for i, v := range arr { 192 | if a, ok := v.(arg); ok { 193 | arr[i] = unify(in.args[int(a)]) 194 | } 195 | } 196 | } 197 | } 198 | 
return in.whereOpVars 199 | } 200 | -------------------------------------------------------------------------------- /insert.go: -------------------------------------------------------------------------------- 1 | package mogi 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "strings" 7 | 8 | "github.com/guregu/mogi/internal/sqlparser" 9 | ) 10 | 11 | type insertCond struct { 12 | cols []string 13 | } 14 | 15 | func (ic insertCond) matches(in input) bool { 16 | _, ok := in.statement.(*sqlparser.Insert) 17 | if !ok { 18 | return false 19 | } 20 | 21 | // zero parameters means anything 22 | if len(ic.cols) == 0 { 23 | return true 24 | } 25 | 26 | return reflect.DeepEqual(lowercase(ic.cols), lowercase(in.cols())) 27 | } 28 | 29 | func (ic insertCond) priority() int { 30 | if len(ic.cols) > 0 { 31 | return 2 32 | } 33 | return 1 34 | } 35 | 36 | func (ic insertCond) String() string { 37 | cols := "(any)" // TODO support star select 38 | if len(ic.cols) > 0 { 39 | cols = strings.Join(ic.cols, ", ") 40 | } 41 | return fmt.Sprintf("INSERT %s", cols) 42 | } 43 | -------------------------------------------------------------------------------- /insert_test.go: -------------------------------------------------------------------------------- 1 | package mogi_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/guregu/mogi" 7 | ) 8 | 9 | func TestInsert(t *testing.T) { 10 | //mogi.Insert().Into("device_tokens").Expect("device_type", "gunosy_lite").StubResult(1337, 1) 11 | defer mogi.Reset() 12 | db := openDB() 13 | 14 | // naked INSERT stub 15 | mogi.Insert().StubResult(3, 1) 16 | _, err := db.Exec("INSERT INTO beer (name, brewery, pct) VALUES (?, ?, ?)", "Mikkel’s Dream", "Mikkeller", 4.6) 17 | checkNil(t, err) 18 | 19 | // INSERT with columns 20 | mogi.Reset() 21 | mogi.Insert("name", "brewery", "pct").Into("beer").StubResult(3, 1) 22 | _, err = db.Exec("INSERT INTO beer (name, brewery, pct) VALUES (?, ?, ?)", "Mikkel’s Dream", "Mikkeller", 4.6) 23 | checkNil(t, err) 24 | 25 | 
// INSERT with wrong columns 26 | mogi.Reset() 27 | mogi.Insert("犬", "🐱", "かっぱ").Into("beer").StubResult(3, 1) 28 | _, err = db.Exec("INSERT INTO beer (name, brewery, pct) VALUES (?, ?, ?)", "Mikkel’s Dream", "Mikkeller", 4.6) 29 | if err != mogi.ErrUnstubbed { 30 | t.Error("err should be ErrUnstubbed but is", err) 31 | } 32 | } 33 | 34 | func TestInsertArgs(t *testing.T) { 35 | defer mogi.Reset() 36 | db := openDB() 37 | 38 | mogi.Insert().Args("Mikkel’s Dream", "Mikkeller", 4.6).StubResult(3, 1) 39 | _, err := db.Exec("INSERT INTO beer (name, brewery, pct) VALUES (?, ?, ?)", "Mikkel’s Dream", "Mikkeller", 4.6) 40 | checkNil(t, err) 41 | 42 | // wrong args 43 | mogi.Reset() 44 | mogi.Insert().Args("Nodogoshi", "Kirin", 5).StubResult(4, 1) 45 | _, err = db.Exec("INSERT INTO beer (name, brewery, pct) VALUES (?, ?, ?)", "Mikkel’s Dream", "Mikkeller", 4.6) 46 | if err != mogi.ErrUnstubbed { 47 | t.Error("err should be ErrUnstubbed but is", err) 48 | } 49 | } 50 | 51 | func TestInsertInto(t *testing.T) { 52 | defer mogi.Reset() 53 | db := openDB() 54 | 55 | mogi.Insert().Into("beer").StubResult(3, 1) 56 | _, err := db.Exec("INSERT INTO beer (name, brewery, pct) VALUES (?, ?, ?)", "Mikkel’s Dream", "Mikkeller", 4.6) 57 | checkNil(t, err) 58 | // make sure .Into() and .Table() are the same 59 | mogi.Reset() 60 | mogi.Insert().Table("beer").StubResult(3, 1) 61 | _, err = db.Exec("INSERT INTO beer (name, brewery, pct) VALUES (?, ?, ?)", "Mikkel’s Dream", "Mikkeller", 4.6) 62 | checkNil(t, err) 63 | } 64 | 65 | func TestStubResult(t *testing.T) { 66 | defer mogi.Reset() 67 | db := openDB() 68 | 69 | mogi.Insert().StubResult(3, 1) 70 | res, err := db.Exec("INSERT INTO beer (name, brewery, pct) VALUES (?, ?, ?)", "Mikkel’s Dream", "Mikkeller", 4.6) 71 | checkNil(t, err) 72 | lastID, err := res.LastInsertId() 73 | checkNil(t, err) 74 | if lastID != 3 { 75 | t.Error("LastInsertId() should be 3 but is", lastID) 76 | } 77 | affected, err := res.RowsAffected() 78 | checkNil(t, 
err) 79 | if affected != 1 { 80 | t.Error("RowsAffected() should be 1 but is", affected) 81 | } 82 | } 83 | 84 | func TestStubResultWithErrors(t *testing.T) { 85 | defer mogi.Reset() 86 | db := openDB() 87 | 88 | mogi.Insert().StubResult(-1, -1) 89 | res, err := db.Exec("INSERT INTO beer (name, brewery, pct) VALUES (?, ?, ?)", "Mikkel’s Dream", "Mikkeller", 4.6) 90 | checkNil(t, err) 91 | _, err = res.LastInsertId() 92 | if err == nil { 93 | t.Error("error is nil but shouldn't be:", err) 94 | } 95 | _, err = res.RowsAffected() 96 | if err == nil { 97 | t.Error("error is nil but shouldn't be:", err) 98 | } 99 | } 100 | 101 | func TestInsertValues(t *testing.T) { 102 | defer mogi.Reset() 103 | db := openDB() 104 | 105 | // single row 106 | mogi.Insert().Value("brewery", "Mikkeller").Value("pct", 4.6).StubResult(3, 1) 107 | _, err := db.Exec("INSERT INTO beer (name, brewery, pct) VALUES (?, ?, ?)", "Mikkel’s Dream", "Mikkeller", 4.6) 108 | checkNil(t, err) 109 | 110 | // multiple rows 111 | mogi.Reset() 112 | mogi.Insert(). 113 | ValueAt(0, "brewery", "Mikkeller").ValueAt(0, "pct", 4.6). 114 | ValueAt(1, "brewery", "BrewDog").ValueAt(1, "pct", 18.2). 115 | StubResult(4, 2) 116 | _, err = db.Exec(`INSERT INTO beer (name, brewery, pct) VALUES (?, "Mikkeller", 4.6), (?, ?, ?)`, 117 | "Mikkel’s Dream", 118 | "Tokyo*", "BrewDog", 18.2, 119 | ) 120 | checkNil(t, err) 121 | } 122 | -------------------------------------------------------------------------------- /internal/proto/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2012, Google Inc. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are 6 | met: 7 | 8 | * Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 
10 | * Redistributions in binary form must reproduce the above 11 | copyright notice, this list of conditions and the following disclaimer 12 | in the documentation and/or other materials provided with the 13 | distribution. 14 | * Neither the name of Google Inc. nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /internal/sqlparser/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2012, Google Inc. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are 6 | met: 7 | 8 | * Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 
10 | * Redistributions in binary form must reproduce the above 11 | copyright notice, this list of conditions and the following disclaimer 12 | in the documentation and/or other materials provided with the 13 | distribution. 14 | * Neither the name of Google Inc. nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /internal/sqlparser/Makefile: -------------------------------------------------------------------------------- 1 | # Copyright 2012, Google Inc. All rights reserved. 2 | # Use of this source code is governed by a BSD-style license that can 3 | # be found in the LICENSE file. 4 | 5 | MAKEFLAGS = -s 6 | 7 | sql.go: sql.y 8 | go tool yacc -o sql.go sql.y 9 | gofmt -w sql.go 10 | 11 | clean: 12 | rm -f y.output sql.go 13 | -------------------------------------------------------------------------------- /internal/sqlparser/analyzer.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 
2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package sqlparser 6 | 7 | // analyzer.go contains utility analysis functions. 8 | 9 | import ( 10 | "fmt" 11 | 12 | "github.com/guregu/mogi/internal/sqltypes" 13 | ) 14 | 15 | // GetTableName returns the table name from the SimpleTableExpr 16 | // only if it's a simple expression. Otherwise, it returns "". 17 | func GetTableName(node SimpleTableExpr) string { 18 | if n, ok := node.(*TableName); ok && n.Qualifier == "" { 19 | return string(n.Name) 20 | } 21 | // sub-select or '.' expression 22 | return "" 23 | } 24 | 25 | // GetColName returns the column name, only if 26 | // it's a simple expression. Otherwise, it returns "". 27 | func GetColName(node Expr) string { 28 | if n, ok := node.(*ColName); ok { 29 | return string(n.Name) 30 | } 31 | return "" 32 | } 33 | 34 | // IsColName returns true if the ValExpr is a *ColName. 35 | func IsColName(node ValExpr) bool { 36 | _, ok := node.(*ColName) 37 | return ok 38 | } 39 | 40 | // IsValue returns true if the ValExpr is a string, number or value arg. 41 | // NULL is not considered to be a value. 42 | func IsValue(node ValExpr) bool { 43 | switch node.(type) { 44 | case StrVal, NumVal, ValArg: 45 | return true 46 | } 47 | return false 48 | } 49 | 50 | // IsNull returns true if the ValExpr is SQL NULL 51 | func IsNull(node ValExpr) bool { 52 | switch node.(type) { 53 | case *NullVal: 54 | return true 55 | } 56 | return false 57 | } 58 | 59 | // HasINClause returns true if any of the conditions has an IN clause. 60 | func HasINClause(conditions []BoolExpr) bool { 61 | for _, node := range conditions { 62 | if c, ok := node.(*ComparisonExpr); ok && c.Operator == InStr { 63 | return true 64 | } 65 | } 66 | return false 67 | } 68 | 69 | // IsSimpleTuple returns true if the ValExpr is a ValTuple that 70 | // contains simple values or if it's a list arg. 
71 | func IsSimpleTuple(node ValExpr) bool { 72 | switch vals := node.(type) { 73 | case ValTuple: 74 | for _, n := range vals { 75 | if !IsValue(n) { 76 | return false 77 | } 78 | } 79 | return true 80 | case ListArg: 81 | return true 82 | } 83 | // It's a subquery 84 | return false 85 | } 86 | 87 | // AsInterface converts the ValExpr to an interface. It converts 88 | // ValTuple to []interface{}, ValArg to string, StrVal to sqltypes.String, 89 | // NumVal to sqltypes.Numeric, NullVal to nil. 90 | // Otherwise, it returns an error. 91 | func AsInterface(node ValExpr) (interface{}, error) { 92 | switch node := node.(type) { 93 | case ValTuple: 94 | vals := make([]interface{}, 0, len(node)) 95 | for _, val := range node { 96 | v, err := AsInterface(val) 97 | if err != nil { 98 | return nil, err 99 | } 100 | vals = append(vals, v) 101 | } 102 | return vals, nil 103 | case ValArg: 104 | return string(node), nil 105 | case ListArg: 106 | return string(node), nil 107 | case StrVal: 108 | return sqltypes.MakeString(node), nil 109 | case NumVal: 110 | n, err := sqltypes.BuildIntegral(string(node)) 111 | if err != nil { 112 | return nil, fmt.Errorf("type mismatch: %s", err) 113 | } 114 | return n, nil 115 | case *NullVal: 116 | return nil, nil 117 | } 118 | return nil, fmt.Errorf("unexpected node %v", node) 119 | } 120 | 121 | // StringIn is a convenience function that returns 122 | // true if str matches any of the values. 123 | func StringIn(str string, values ...string) bool { 124 | for _, val := range values { 125 | if str == val { 126 | return true 127 | } 128 | } 129 | return false 130 | } 131 | -------------------------------------------------------------------------------- /internal/sqlparser/ast_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2014, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package sqlparser 6 | 7 | import "testing" 8 | 9 | func TestSelect(t *testing.T) { 10 | tree, err := Parse("select * from t where a = 1") 11 | if err != nil { 12 | t.Error(err) 13 | } 14 | expr := tree.(*Select).Where.Expr 15 | 16 | sel := &Select{} 17 | sel.AddWhere(expr) 18 | buf := NewTrackedBuffer(nil) 19 | sel.Where.Format(buf) 20 | want := " where a = 1" 21 | if buf.String() != want { 22 | t.Errorf("where: %q, want %s", buf.String(), want) 23 | } 24 | sel.AddWhere(expr) 25 | buf = NewTrackedBuffer(nil) 26 | sel.Where.Format(buf) 27 | want = " where a = 1 and a = 1" 28 | if buf.String() != want { 29 | t.Errorf("where: %q, want %s", buf.String(), want) 30 | } 31 | sel = &Select{} 32 | sel.AddHaving(expr) 33 | buf = NewTrackedBuffer(nil) 34 | sel.Having.Format(buf) 35 | want = " having a = 1" 36 | if buf.String() != want { 37 | t.Errorf("having: %q, want %s", buf.String(), want) 38 | } 39 | sel.AddHaving(expr) 40 | buf = NewTrackedBuffer(nil) 41 | sel.Having.Format(buf) 42 | want = " having a = 1 and a = 1" 43 | if buf.String() != want { 44 | t.Errorf("having: %q, want %s", buf.String(), want) 45 | } 46 | 47 | // OR clauses must be parenthesized. 
48 | tree, err = Parse("select * from t where a = 1 or b = 1") 49 | if err != nil { 50 | t.Error(err) 51 | } 52 | expr = tree.(*Select).Where.Expr 53 | sel = &Select{} 54 | sel.AddWhere(expr) 55 | buf = NewTrackedBuffer(nil) 56 | sel.Where.Format(buf) 57 | want = " where (a = 1 or b = 1)" 58 | if buf.String() != want { 59 | t.Errorf("where: %q, want %s", buf.String(), want) 60 | } 61 | sel = &Select{} 62 | sel.AddHaving(expr) 63 | buf = NewTrackedBuffer(nil) 64 | sel.Having.Format(buf) 65 | want = " having (a = 1 or b = 1)" 66 | if buf.String() != want { 67 | t.Errorf("having: %q, want %s", buf.String(), want) 68 | } 69 | } 70 | 71 | func TestWhere(t *testing.T) { 72 | var w *Where 73 | buf := NewTrackedBuffer(nil) 74 | w.Format(buf) 75 | if buf.String() != "" { 76 | t.Errorf("w.Format(nil): %q, want \"\"", buf.String()) 77 | } 78 | w = NewWhere(WhereStr, nil) 79 | buf = NewTrackedBuffer(nil) 80 | w.Format(buf) 81 | if buf.String() != "" { 82 | t.Errorf("w.Format(&Where{nil}: %q, want \"\"", buf.String()) 83 | } 84 | } 85 | 86 | func TestLimits(t *testing.T) { 87 | var l *Limit 88 | o, r, err := l.Limits() 89 | if o != nil || r != nil || err != nil { 90 | t.Errorf("got %v, %v, %v, want nils", o, r, err) 91 | } 92 | 93 | l = &Limit{Offset: NumVal([]byte("aa"))} 94 | _, _, err = l.Limits() 95 | wantErr := "strconv.ParseInt: parsing \"aa\": invalid syntax" 96 | if err == nil || err.Error() != wantErr { 97 | t.Errorf("got %v, want %s", err, wantErr) 98 | } 99 | 100 | l = &Limit{Offset: NumVal([]byte("2"))} 101 | _, _, err = l.Limits() 102 | wantErr = "unexpected node for rowcount: " 103 | if err == nil || err.Error() != wantErr { 104 | t.Errorf("got %v, want %s", err, wantErr) 105 | } 106 | 107 | l = &Limit{Offset: StrVal([]byte("2"))} 108 | _, _, err = l.Limits() 109 | wantErr = "unexpected node for offset: [50]" 110 | if err == nil || err.Error() != wantErr { 111 | t.Errorf("got %v, want %s", err, wantErr) 112 | } 113 | 114 | l = &Limit{Offset: NumVal([]byte("2")), 
Rowcount: NumVal([]byte("aa"))} 115 | _, _, err = l.Limits() 116 | wantErr = "strconv.ParseInt: parsing \"aa\": invalid syntax" 117 | if err == nil || err.Error() != wantErr { 118 | t.Errorf("got %v, want %s", err, wantErr) 119 | } 120 | 121 | l = &Limit{Offset: NumVal([]byte("2")), Rowcount: NumVal([]byte("3"))} 122 | o, r, err = l.Limits() 123 | if o.(int64) != 2 || r.(int64) != 3 || err != nil { 124 | t.Errorf("got %v %v %v, want 2, 3, nil", o, r, err) 125 | } 126 | 127 | l = &Limit{Offset: ValArg([]byte(":a")), Rowcount: NumVal([]byte("3"))} 128 | o, r, err = l.Limits() 129 | if o.(string) != ":a" || r.(int64) != 3 || err != nil { 130 | t.Errorf("got %v %v %v, want :a, 3, nil", o, r, err) 131 | } 132 | 133 | l = &Limit{Offset: nil, Rowcount: NumVal([]byte("3"))} 134 | o, r, err = l.Limits() 135 | if o != nil || r.(int64) != 3 || err != nil { 136 | t.Errorf("got %v %v %v, want nil, 3, nil", o, r, err) 137 | } 138 | 139 | l = &Limit{Offset: nil, Rowcount: ValArg([]byte(":a"))} 140 | o, r, err = l.Limits() 141 | if o != nil || r.(string) != ":a" || err != nil { 142 | t.Errorf("got %v %v %v, want nil, :a, nil", o, r, err) 143 | } 144 | 145 | l = &Limit{Offset: NumVal([]byte("-2")), Rowcount: NumVal([]byte("0"))} 146 | _, _, err = l.Limits() 147 | wantErr = "negative offset: -2" 148 | if err == nil || err.Error() != wantErr { 149 | t.Errorf("got %v, want %s", err, wantErr) 150 | } 151 | 152 | l = &Limit{Offset: NumVal([]byte("2")), Rowcount: NumVal([]byte("-2"))} 153 | _, _, err = l.Limits() 154 | wantErr = "negative limit: -2" 155 | if err == nil || err.Error() != wantErr { 156 | t.Errorf("got %v, want %s", err, wantErr) 157 | } 158 | } 159 | 160 | func TestIsAggregate(t *testing.T) { 161 | f := FuncExpr{Name: "avg"} 162 | if !f.IsAggregate() { 163 | t.Error("IsAggregate: false, want true") 164 | } 165 | 166 | f = FuncExpr{Name: "foo"} 167 | if f.IsAggregate() { 168 | t.Error("IsAggregate: true, want false") 169 | } 170 | } 171 | 
-------------------------------------------------------------------------------- /internal/sqlparser/parse_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package sqlparser 6 | 7 | import "testing" 8 | 9 | func TestValid(t *testing.T) { 10 | validSQL := []struct { 11 | input string 12 | output string 13 | }{{ 14 | input: "select 1 from t", 15 | }, { 16 | input: "select .1 from t", 17 | }, { 18 | input: "select 1.2e1 from t", 19 | }, { 20 | input: "select 1.2e+1 from t", 21 | }, { 22 | input: "select 1.2e-1 from t", 23 | }, { 24 | input: "select 08.3 from t", 25 | }, { 26 | input: "select -1 from t where b = -2", 27 | }, { 28 | input: "select - -1 from t", 29 | output: "select 1 from t", 30 | }, { 31 | input: "select 1 from t // aa", 32 | output: "select 1 from t", 33 | }, { 34 | input: "select 1 from t -- aa", 35 | output: "select 1 from t", 36 | }, { 37 | input: "select /* simplest */ 1 from t", 38 | }, { 39 | input: "select /* keyword col */ `By` from t", 40 | output: "select /* keyword col */ `by` from t", 41 | }, { 42 | input: "select /* double star **/ 1 from t", 43 | }, { 44 | input: "select /* double */ /* comment */ 1 from t", 45 | }, { 46 | input: "select /* back-quote */ 1 from `t`", 47 | output: "select /* back-quote */ 1 from t", 48 | }, { 49 | input: "select /* back-quote keyword */ 1 from `By`", 50 | output: "select /* back-quote keyword */ 1 from `By`", 51 | }, { 52 | input: "select /* @ */ @@a from b", 53 | }, { 54 | input: "select /* \\0 */ '\\0' from a", 55 | }, { 56 | input: "select 1 /* drop this comment */ from t", 57 | output: "select 1 from t", 58 | }, { 59 | input: "select /* union */ 1 from t union select 1 from t", 60 | }, { 61 | input: "select /* double union */ 1 from t union select 1 from t union select 1 from t", 62 | }, 
{ 63 | input: "select /* union all */ 1 from t union all select 1 from t", 64 | }, { 65 | input: "select /* minus */ 1 from t minus select 1 from t", 66 | }, { 67 | input: "select /* except */ 1 from t except select 1 from t", 68 | }, { 69 | input: "select /* intersect */ 1 from t intersect select 1 from t", 70 | }, { 71 | input: "select /* distinct */ distinct 1 from t", 72 | }, { 73 | input: "select /* for update */ 1 from t for update", 74 | }, { 75 | input: "select /* lock in share mode */ 1 from t lock in share mode", 76 | }, { 77 | input: "select /* select list */ 1, 2 from t", 78 | }, { 79 | input: "select /* * */ * from t", 80 | }, { 81 | input: "select /* column alias */ a b from t", 82 | output: "select /* column alias */ a as b from t", 83 | }, { 84 | input: "select /* column alias with as */ a as b from t", 85 | }, { 86 | input: "select /* keyword column alias */ a as `By` from t", 87 | output: "select /* keyword column alias */ a as `by` from t", 88 | }, { 89 | input: "select /* a.* */ a.* from t", 90 | }, { 91 | input: "select next value for t", 92 | output: "select next value from t", 93 | }, { 94 | input: "select next value from t", 95 | }, { 96 | input: "select /* `By`.* */ `By`.* from t", 97 | }, { 98 | input: "select /* select with bool expr */ a = b from t", 99 | }, { 100 | input: "select /* case_when */ case when a = b then c end from t", 101 | }, { 102 | input: "select /* case_when_else */ case when a = b then c else d end from t", 103 | }, { 104 | input: "select /* case_when_when_else */ case when a = b then c when b = d then d else d end from t", 105 | }, { 106 | input: "select /* case */ case aa when a = b then c end from t", 107 | }, { 108 | input: "select /* parenthesis */ 1 from (t)", 109 | }, { 110 | input: "select /* parenthesis multi-table */ 1 from (t1, t2)", 111 | }, { 112 | input: "select /* table list */ 1 from t1, t2", 113 | }, { 114 | input: "select /* parenthessis in table list 1 */ 1 from (t1), t2", 115 | }, { 116 | input: 
"select /* parenthessis in table list 2 */ 1 from t1, (t2)", 117 | }, { 118 | input: "select /* use */ 1 from t1 use index (a) where b = 1", 119 | }, { 120 | input: "select /* keyword index */ 1 from t1 use index (`By`) where b = 1", 121 | output: "select /* keyword index */ 1 from t1 use index (`by`) where b = 1", 122 | }, { 123 | input: "select /* ignore */ 1 from t1 as t2 ignore index (a), t3 use index (b) where b = 1", 124 | }, { 125 | input: "select /* use */ 1 from t1 as t2 use index (a), t3 use index (b) where b = 1", 126 | }, { 127 | input: "select /* force */ 1 from t1 as t2 force index (a), t3 force index (b) where b = 1", 128 | }, { 129 | input: "select /* table alias */ 1 from t t1", 130 | output: "select /* table alias */ 1 from t as t1", 131 | }, { 132 | input: "select /* table alias with as */ 1 from t as t1", 133 | }, { 134 | input: "select /* keyword table alias */ 1 from t as `By`", 135 | }, { 136 | input: "select /* join */ 1 from t1 join t2", 137 | }, { 138 | input: "select /* join on */ 1 from t1 join t2 on a = b", 139 | }, { 140 | input: "select /* inner join */ 1 from t1 inner join t2", 141 | output: "select /* inner join */ 1 from t1 join t2", 142 | }, { 143 | input: "select /* cross join */ 1 from t1 cross join t2", 144 | output: "select /* cross join */ 1 from t1 join t2", 145 | }, { 146 | input: "select /* straight_join */ 1 from t1 straight_join t2", 147 | }, { 148 | input: "select /* straight_join on */ 1 from t1 straight_join t2 on a = b", 149 | }, { 150 | input: "select /* left join */ 1 from t1 left join t2 on a = b", 151 | }, { 152 | input: "select /* left outer join */ 1 from t1 left outer join t2 on a = b", 153 | output: "select /* left outer join */ 1 from t1 left join t2 on a = b", 154 | }, { 155 | input: "select /* right join */ 1 from t1 right join t2 on a = b", 156 | }, { 157 | input: "select /* right outer join */ 1 from t1 right outer join t2 on a = b", 158 | output: "select /* right outer join */ 1 from t1 right join t2 on 
a = b", 159 | }, { 160 | input: "select /* natural join */ 1 from t1 natural join t2", 161 | }, { 162 | input: "select /* natural left join */ 1 from t1 natural left join t2", 163 | }, { 164 | input: "select /* natural left outer join */ 1 from t1 natural left join t2", 165 | output: "select /* natural left outer join */ 1 from t1 natural left join t2", 166 | }, { 167 | input: "select /* natural right join */ 1 from t1 natural right join t2", 168 | }, { 169 | input: "select /* natural right outer join */ 1 from t1 natural right join t2", 170 | output: "select /* natural right outer join */ 1 from t1 natural right join t2", 171 | }, { 172 | input: "select /* join on */ 1 from t1 join t2 on a = b", 173 | }, { 174 | input: "select /* s.t */ 1 from s.t", 175 | }, { 176 | input: "select /* keyword schema & table name */ 1 from `By`.`bY`", 177 | }, { 178 | input: "select /* select in from */ 1 from (select 1 from t) as a", 179 | }, { 180 | input: "select /* select in from with no as */ 1 from (select 1 from t) a", 181 | output: "select /* select in from with no as */ 1 from (select 1 from t) as a", 182 | }, { 183 | input: "select /* where */ 1 from t where a = b", 184 | }, { 185 | input: "select /* and */ 1 from t where a = b and a = c", 186 | }, { 187 | input: "select /* or */ 1 from t where a = b or a = c", 188 | }, { 189 | input: "select /* not */ 1 from t where not a = b", 190 | }, { 191 | input: "select /* bool is */ 1 from t where a = b is null", 192 | }, { 193 | input: "select /* bool is not */ 1 from t where a = b is not false", 194 | }, { 195 | input: "select /* true */ 1 from t where true", 196 | }, { 197 | input: "select /* false */ 1 from t where false", 198 | }, { 199 | input: "select /* exists */ 1 from t where exists (select 1 from t)", 200 | }, { 201 | input: "select /* keyrange */ 1 from t where keyrange(1, 2)", 202 | }, { 203 | input: "select /* (boolean) */ 1 from t where not (a = b)", 204 | }, { 205 | input: "select /* in value list */ 1 from t where 
a in (b, c)", 206 | }, { 207 | input: "select /* in select */ 1 from t where a in (select 1 from t)", 208 | }, { 209 | input: "select /* not in */ 1 from t where a not in (b, c)", 210 | }, { 211 | input: "select /* like */ 1 from t where a like b", 212 | }, { 213 | input: "select /* not like */ 1 from t where a not like b", 214 | }, { 215 | input: "select /* regexp */ 1 from t where a regexp b", 216 | }, { 217 | input: "select /* not regexp */ 1 from t where a not regexp b", 218 | }, { 219 | input: "select /* rlike */ 1 from t where a rlike b", 220 | output: "select /* rlike */ 1 from t where a regexp b", 221 | }, { 222 | input: "select /* not rlike */ 1 from t where a not rlike b", 223 | output: "select /* not rlike */ 1 from t where a not regexp b", 224 | }, { 225 | input: "select /* between */ 1 from t where a between b and c", 226 | }, { 227 | input: "select /* not between */ 1 from t where a not between b and c", 228 | }, { 229 | input: "select /* is null */ 1 from t where a is null", 230 | }, { 231 | input: "select /* is not null */ 1 from t where a is not null", 232 | }, { 233 | input: "select /* is true */ 1 from t where a is true", 234 | }, { 235 | input: "select /* is not true */ 1 from t where a is not true", 236 | }, { 237 | input: "select /* is false */ 1 from t where a is false", 238 | }, { 239 | input: "select /* is not false */ 1 from t where a is not false", 240 | }, { 241 | input: "select /* < */ 1 from t where a < b", 242 | }, { 243 | input: "select /* <= */ 1 from t where a <= b", 244 | }, { 245 | input: "select /* >= */ 1 from t where a >= b", 246 | }, { 247 | input: "select /* > */ 1 from t where a > b", 248 | }, { 249 | input: "select /* != */ 1 from t where a != b", 250 | }, { 251 | input: "select /* <> */ 1 from t where a <> b", 252 | output: "select /* <> */ 1 from t where a != b", 253 | }, { 254 | input: "select /* <=> */ 1 from t where a <=> b", 255 | }, { 256 | input: "select /* != */ 1 from t where a != b", 257 | }, { 258 | input: 
"select /* single value expre list */ 1 from t where a in (b)", 259 | }, { 260 | input: "select /* select as a value expression */ 1 from t where a = (select a from t)", 261 | }, { 262 | input: "select /* parenthesised value */ 1 from t where a = (b)", 263 | }, { 264 | input: "select /* over-parenthesize */ ((1)) from t where ((a)) in (((1))) and ((a, b)) in ((((1, 1))), ((2, 2)))", 265 | }, { 266 | input: "select /* dot-parenthesize */ (a.b) from t where (b.c) = 2", 267 | }, { 268 | input: "select /* & */ 1 from t where a = b & c", 269 | }, { 270 | input: "select /* & */ 1 from t where a = b & c", 271 | }, { 272 | input: "select /* | */ 1 from t where a = b | c", 273 | }, { 274 | input: "select /* ^ */ 1 from t where a = b ^ c", 275 | }, { 276 | input: "select /* + */ 1 from t where a = b + c", 277 | }, { 278 | input: "select /* - */ 1 from t where a = b - c", 279 | }, { 280 | input: "select /* * */ 1 from t where a = b * c", 281 | }, { 282 | input: "select /* / */ 1 from t where a = b / c", 283 | }, { 284 | input: "select /* % */ 1 from t where a = b % c", 285 | }, { 286 | input: "select /* << */ 1 from t where a = b << c", 287 | }, { 288 | input: "select /* >> */ 1 from t where a = b >> c", 289 | }, { 290 | input: "select /* % no space */ 1 from t where a = b%c", 291 | output: "select /* % no space */ 1 from t where a = b % c", 292 | }, { 293 | input: "select /* u+ */ 1 from t where a = +b", 294 | }, { 295 | input: "select /* u- */ 1 from t where a = -b", 296 | }, { 297 | input: "select /* u~ */ 1 from t where a = ~b", 298 | }, { 299 | input: "select /* empty function */ 1 from t where a = b()", 300 | }, { 301 | input: "select /* function with 1 param */ 1 from t where a = b(c)", 302 | }, { 303 | input: "select /* function with many params */ 1 from t where a = b(c, d)", 304 | }, { 305 | input: "select /* if as func */ 1 from t where a = if(b)", 306 | }, { 307 | input: "select /* function with distinct */ count(distinct a) from t", 308 | }, { 309 | input: 
"select /* a */ a from t", 310 | }, { 311 | input: "select /* a.b */ a.b from t", 312 | }, { 313 | input: "select /* keyword a.b */ `By`.`bY` from t", 314 | output: "select /* keyword a.b */ `By`.`by` from t", 315 | }, { 316 | input: "select /* string */ 'a' from t", 317 | }, { 318 | input: "select /* double quoted string */ \"a\" from t", 319 | output: "select /* double quoted string */ 'a' from t", 320 | }, { 321 | input: "select /* quote quote in string */ 'a''a' from t", 322 | output: "select /* quote quote in string */ 'a\\'a' from t", 323 | }, { 324 | input: "select /* double quote quote in string */ \"a\"\"a\" from t", 325 | output: "select /* double quote quote in string */ 'a\\\"a' from t", 326 | }, { 327 | input: "select /* quote in double quoted string */ \"a'a\" from t", 328 | output: "select /* quote in double quoted string */ 'a\\'a' from t", 329 | }, { 330 | input: "select /* backslash quote in string */ 'a\\'a' from t", 331 | }, { 332 | input: "select /* literal backslash in string */ 'a\\\\na' from t", 333 | }, { 334 | input: "select /* all escapes */ '\\0\\'\\\"\\b\\n\\r\\t\\Z\\\\' from t", 335 | }, { 336 | input: "select /* non-escape */ '\\x' from t", 337 | output: "select /* non-escape */ 'x' from t", 338 | }, { 339 | input: "select /* unescaped backslash */ '\\n' from t", 340 | }, { 341 | input: "select /* value argument */ :a from t", 342 | }, { 343 | input: "select /* value argument with digit */ :a1 from t", 344 | }, { 345 | input: "select /* value argument with dot */ :a.b from t", 346 | }, { 347 | input: "select /* positional argument */ ? from t", 348 | output: "select /* positional argument */ :v1 from t", 349 | }, { 350 | input: "select /* multiple positional arguments */ ?, ? 
from t", 351 | output: "select /* multiple positional arguments */ :v1, :v2 from t", 352 | }, { 353 | input: "select /* list arg */ * from t where a in ::list", 354 | }, { 355 | input: "select /* list arg not in */ * from t where a not in ::list", 356 | }, { 357 | input: "select /* null */ null from t", 358 | }, { 359 | input: "select /* octal */ 010 from t", 360 | }, { 361 | input: "select /* hex */ 0xf0 from t", 362 | }, { 363 | input: "select /* hex caps */ 0xF0 from t", 364 | }, { 365 | input: "select /* float */ 0.1 from t", 366 | }, { 367 | input: "select /* group by */ 1 from t group by a", 368 | }, { 369 | input: "select /* having */ 1 from t having a = b", 370 | }, { 371 | input: "select /* simple order by */ 1 from t order by a", 372 | output: "select /* simple order by */ 1 from t order by a asc", 373 | }, { 374 | input: "select /* order by asc */ 1 from t order by a asc", 375 | }, { 376 | input: "select /* order by desc */ 1 from t order by a desc", 377 | }, { 378 | input: "select /* limit a */ 1 from t limit a", 379 | }, { 380 | input: "select /* limit a,b */ 1 from t limit a, b", 381 | }, { 382 | input: "select /* binary unary */ a- -b from t", 383 | output: "select /* binary unary */ a - -b from t", 384 | }, { 385 | input: "select /* - - */ - -b from t", 386 | }, { 387 | input: "select /* interval */ adddate('2008-01-02', interval 31 day) from t", 388 | }, { 389 | input: "select /* dual */ 1 from dual", 390 | }, { 391 | input: "select /* Dual */ 1 from Dual", 392 | output: "select /* Dual */ 1 from dual", 393 | }, { 394 | input: "select /* DUAL */ 1 from Dual", 395 | output: "select /* DUAL */ 1 from dual", 396 | }, { 397 | input: "insert /* simple */ into a values (1)", 398 | }, { 399 | input: "insert /* a.b */ into a.b values (1)", 400 | }, { 401 | input: "insert /* multi-value */ into a values (1, 2)", 402 | }, { 403 | input: "insert /* multi-value list */ into a values (1, 2), (3, 4)", 404 | }, { 405 | input: "insert /* set */ into a set a = 1, 
a.b = 2", 406 | output: "insert /* set */ into a(a, a.b) values (1, 2)", 407 | }, { 408 | input: "insert /* value expression list */ into a values (a + 1, 2 * 3)", 409 | }, { 410 | input: "insert /* column list */ into a(a, b) values (1, 2)", 411 | }, { 412 | input: "insert /* qualified column list */ into a(a, a.b) values (1, 2)", 413 | }, { 414 | input: "insert /* select */ into a select b, c from d", 415 | }, { 416 | input: "insert /* on duplicate */ into a values (1, 2) on duplicate key update b = func(a), c = d", 417 | }, { 418 | input: "update /* simple */ a set b = 3", 419 | }, { 420 | input: "update /* a.b */ a.b set b = 3", 421 | }, { 422 | input: "update /* b.c */ a set b.c = 3", 423 | }, { 424 | input: "update /* list */ a set b = 3, c = 4", 425 | }, { 426 | input: "update /* expression */ a set b = 3 + 4", 427 | }, { 428 | input: "update /* where */ a set b = 3 where a = b", 429 | }, { 430 | input: "update /* order */ a set b = 3 order by c desc", 431 | }, { 432 | input: "update /* limit */ a set b = 3 limit c", 433 | }, { 434 | input: "delete /* simple */ from a", 435 | }, { 436 | input: "delete /* a.b */ from a.b", 437 | }, { 438 | input: "delete /* where */ from a where a = b", 439 | }, { 440 | input: "delete /* order */ from a order by b desc", 441 | }, { 442 | input: "delete /* limit */ from a limit b", 443 | }, { 444 | input: "set /* simple */ a = 3", 445 | }, { 446 | input: "set /* list */ a = 3, b = 4", 447 | }, { 448 | input: "alter ignore table a add foo", 449 | output: "alter table a", 450 | }, { 451 | input: "alter table a add foo", 452 | output: "alter table a", 453 | }, { 454 | input: "alter table `By` add foo", 455 | output: "alter table `By`", 456 | }, { 457 | input: "alter table a alter foo", 458 | output: "alter table a", 459 | }, { 460 | input: "alter table a change foo", 461 | output: "alter table a", 462 | }, { 463 | input: "alter table a modify foo", 464 | output: "alter table a", 465 | }, { 466 | input: "alter table a drop foo", 
467 | output: "alter table a", 468 | }, { 469 | input: "alter table a disable foo", 470 | output: "alter table a", 471 | }, { 472 | input: "alter table a enable foo", 473 | output: "alter table a", 474 | }, { 475 | input: "alter table a order foo", 476 | output: "alter table a", 477 | }, { 478 | input: "alter table a default foo", 479 | output: "alter table a", 480 | }, { 481 | input: "alter table a discard foo", 482 | output: "alter table a", 483 | }, { 484 | input: "alter table a import foo", 485 | output: "alter table a", 486 | }, { 487 | input: "alter table a rename b", 488 | output: "rename table a b", 489 | }, { 490 | input: "alter table `By` rename `bY`", 491 | output: "rename table `By` `bY`", 492 | }, { 493 | input: "alter table a rename to b", 494 | output: "rename table a b", 495 | }, { 496 | input: "create table a", 497 | }, { 498 | input: "create table `by`", 499 | }, { 500 | input: "create table if not exists a", 501 | output: "create table a", 502 | }, { 503 | input: "create index a on b", 504 | output: "alter table b", 505 | }, { 506 | input: "create unique index a on b", 507 | output: "alter table b", 508 | }, { 509 | input: "create unique index a using foo on b", 510 | output: "alter table b", 511 | }, { 512 | input: "create view a", 513 | output: "create table a", 514 | }, { 515 | input: "alter view a", 516 | output: "alter table a", 517 | }, { 518 | input: "drop view a", 519 | output: "drop table a", 520 | }, { 521 | input: "drop table a", 522 | }, { 523 | input: "drop table if exists a", 524 | output: "drop table a", 525 | }, { 526 | input: "drop view if exists a", 527 | output: "drop table a", 528 | }, { 529 | input: "drop index b on a", 530 | output: "alter table a", 531 | }, { 532 | input: "analyze table a", 533 | output: "alter table a", 534 | }, { 535 | input: "show foobar", 536 | output: "other", 537 | }, { 538 | input: "describe foobar", 539 | output: "other", 540 | }, { 541 | input: "explain foobar", 542 | output: "other", 543 | }} 544 
| for _, tcase := range validSQL { 545 | if tcase.output == "" { 546 | tcase.output = tcase.input 547 | } 548 | tree, err := Parse(tcase.input) 549 | if err != nil { 550 | t.Errorf("input: %s, err: %v", tcase.input, err) 551 | continue 552 | } 553 | out := String(tree) 554 | if out != tcase.output { 555 | t.Errorf("out: %s, want %s", out, tcase.output) 556 | } 557 | // This test just exercises the tree walking functionality. 558 | // There's no automated way to verify that a node calls 559 | // all its children. But we can examine code coverage and 560 | // ensure that all WalkSubtree functions were called. 561 | Walk(func(node SQLNode) (bool, error) { 562 | return true, nil 563 | }, tree) 564 | } 565 | } 566 | 567 | func TestCaseSensitivity(t *testing.T) { 568 | validSQL := []struct { 569 | input string 570 | output string 571 | }{{ 572 | input: "create table A", 573 | }, { 574 | input: "create index b on A", 575 | output: "alter table A", 576 | }, { 577 | input: "alter table A foo", 578 | output: "alter table A", 579 | }, { 580 | input: "alter table A rename to B", 581 | output: "rename table A B", 582 | }, { 583 | input: "rename table A to B", 584 | output: "rename table A B", 585 | }, { 586 | input: "drop table B", 587 | }, { 588 | input: "drop index b on A", 589 | output: "alter table A", 590 | }, { 591 | input: "select a from B", 592 | }, { 593 | input: "select A as B from C", 594 | output: "select a as b from C", 595 | }, { 596 | input: "select B.* from c", 597 | }, { 598 | input: "select B.A from c", 599 | output: "select B.a from c", 600 | }, { 601 | input: "select * from B as C", 602 | }, { 603 | input: "select * from A.B", 604 | }, { 605 | input: "update A set b = 1", 606 | }, { 607 | input: "update A.B set b = 1", 608 | }, { 609 | input: "select A() from b", 610 | output: "select a() from b", 611 | }, { 612 | input: "select A(B, C) from b", 613 | output: "select a(b, c) from b", 614 | }, { 615 | input: "select A(distinct B, C) from b", 616 | output: 
"select a(distinct b, c) from b", 617 | }, { 618 | input: "select IF(B, C) from b", 619 | output: "select if(b, c) from b", 620 | }, { 621 | input: "select * from b use index (A)", 622 | output: "select * from b use index (a)", 623 | }, { 624 | input: "insert into A(A, B) values (1, 2)", 625 | output: "insert into A(a, b) values (1, 2)", 626 | }, { 627 | input: "CREATE TABLE A", 628 | output: "create table A", 629 | }, { 630 | input: "create view A", 631 | output: "create table a", 632 | }, { 633 | input: "alter view A", 634 | output: "alter table a", 635 | }, { 636 | input: "drop view A", 637 | output: "drop table a", 638 | }} 639 | for _, tcase := range validSQL { 640 | if tcase.output == "" { 641 | tcase.output = tcase.input 642 | } 643 | tree, err := Parse(tcase.input) 644 | if err != nil { 645 | t.Errorf("input: %s, err: %v", tcase.input, err) 646 | continue 647 | } 648 | out := String(tree) 649 | if out != tcase.output { 650 | t.Errorf("out: %s, want %s", out, tcase.output) 651 | } 652 | } 653 | } 654 | 655 | func TestErrors(t *testing.T) { 656 | invalidSQL := []struct { 657 | input string 658 | output string 659 | }{{ 660 | input: "select !8 from t", 661 | output: "syntax error at position 9 near '!'", 662 | }, { 663 | input: "select $ from t", 664 | output: "syntax error at position 9 near '$'", 665 | }, { 666 | input: "select : from t", 667 | output: "syntax error at position 9 near ':'", 668 | }, { 669 | input: "select 078 from t", 670 | output: "syntax error at position 11 near '078'", 671 | }, { 672 | input: "select `1a` from t", 673 | output: "syntax error at position 9 near '1'", 674 | }, { 675 | input: "select `:table` from t", 676 | output: "syntax error at position 9 near ':'", 677 | }, { 678 | input: "select `table:` from t", 679 | output: "syntax error at position 14 near 'table'", 680 | }, { 681 | input: "select 'aa\\", 682 | output: "syntax error at position 12 near 'aa'", 683 | }, { 684 | input: "select 'aa", 685 | output: "syntax error at 
position 12 near 'aa'", 686 | }, { 687 | input: "select * from t where :1 = 2", 688 | output: "syntax error at position 24 near ':'", 689 | }, { 690 | input: "select * from t where :. = 2", 691 | output: "syntax error at position 24 near ':'", 692 | }, { 693 | input: "select * from t where ::1 = 2", 694 | output: "syntax error at position 25 near '::'", 695 | }, { 696 | input: "select * from t where ::. = 2", 697 | output: "syntax error at position 25 near '::'", 698 | }, { 699 | input: "update a set c = values(1)", 700 | output: "syntax error at position 24 near 'values'", 701 | }, { 702 | input: "update a set c = last_insert_id(1)", 703 | output: "syntax error at position 32 near 'last_insert_id'", 704 | }, { 705 | input: "select(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F" + 706 | "(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(" + 707 | "F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F" + 708 | "(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(" + 709 | "F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F" + 710 | "(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(" + 711 | "F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F" + 712 | "(F(F(F(F(F(F(F(F(F(F(F(F(", 713 | output: "max nesting level reached at position 406", 714 | }, { 715 | input: "select(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F" + 716 | "(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(" + 717 | "F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F" + 718 | "(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(" + 719 | "F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F" + 720 | "(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(" + 721 | "F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F" + 722 | "(F(F(F(F(F(F(F(F(F(F(F(", 723 | output: "syntax error at position 405", 724 | }, { 725 | input: "select /* aa", 726 | output: "syntax error at position 13 near '/* aa'", 727 | }, { 728 | // This construct is considered invalid due to a 
grammar conflict. 729 | input: "insert into a select * from b join c on duplicate key update d=e", 730 | output: "syntax error at position 50 near 'duplicate'", 731 | }, { 732 | input: "select * from a left join b", 733 | output: "syntax error at position 29", 734 | }, { 735 | input: "select * from a natural join b on c = d", 736 | output: "syntax error at position 34 near 'on'", 737 | }, { 738 | input: "select next id from a", 739 | output: "expecting value after next at position 23", 740 | }} 741 | for _, tcase := range invalidSQL { 742 | if tcase.output == "" { 743 | tcase.output = tcase.input 744 | } 745 | _, err := Parse(tcase.input) 746 | if err == nil || err.Error() != tcase.output { 747 | t.Errorf("%s: %v, want %s", tcase.input, err, tcase.output) 748 | } 749 | } 750 | } 751 | 752 | func BenchmarkParse1(b *testing.B) { 753 | sql := "select 'abcd', 20, 30.0, eid from a where 1=eid and name='3'" 754 | for i := 0; i < b.N; i++ { 755 | ast, err := Parse(sql) 756 | if err != nil { 757 | b.Fatal(err) 758 | } 759 | _ = String(ast) 760 | } 761 | } 762 | 763 | func BenchmarkParse2(b *testing.B) { 764 | sql := "select aaaa, bbb, ccc, ddd, eeee, ffff, gggg, hhhh, iiii from tttt, ttt1, ttt3 where aaaa = bbbb and bbbb = cccc and dddd+1 = eeee group by fff, gggg having hhhh = iiii and iiii = jjjj order by kkkk, llll limit 3, 4" 765 | for i := 0; i < b.N; i++ { 766 | ast, err := Parse(sql) 767 | if err != nil { 768 | b.Fatal(err) 769 | } 770 | _ = String(ast) 771 | } 772 | } 773 | -------------------------------------------------------------------------------- /internal/sqlparser/parsed_query.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package sqlparser 6 | 7 | import ( 8 | "bytes" 9 | "encoding/json" 10 | "errors" 11 | "fmt" 12 | 13 | "github.com/guregu/mogi/internal/sqltypes" 14 | ) 15 | 16 | type bindLocation struct { 17 | offset, length int 18 | } 19 | 20 | // ParsedQuery represents a parsed query where 21 | // bind locations are precomputed for fast substitutions. 22 | type ParsedQuery struct { 23 | Query string 24 | bindLocations []bindLocation 25 | } 26 | 27 | // GenerateQuery generates a query by substituting the specified 28 | // bindVariables. 29 | func (pq *ParsedQuery) GenerateQuery(bindVariables map[string]interface{}) ([]byte, error) { 30 | if len(pq.bindLocations) == 0 { 31 | return []byte(pq.Query), nil 32 | } 33 | buf := bytes.NewBuffer(make([]byte, 0, len(pq.Query))) 34 | current := 0 35 | for _, loc := range pq.bindLocations { 36 | buf.WriteString(pq.Query[current:loc.offset]) 37 | name := pq.Query[loc.offset : loc.offset+loc.length] 38 | supplied, _, err := FetchBindVar(name, bindVariables) 39 | if err != nil { 40 | return nil, err 41 | } 42 | if err := EncodeValue(buf, supplied); err != nil { 43 | return nil, err 44 | } 45 | current = loc.offset + loc.length 46 | } 47 | buf.WriteString(pq.Query[current:]) 48 | return buf.Bytes(), nil 49 | } 50 | 51 | // MarshalJSON is a custom JSON marshaler for ParsedQuery. 52 | func (pq *ParsedQuery) MarshalJSON() ([]byte, error) { 53 | return json.Marshal(pq.Query) 54 | } 55 | 56 | // EncodeValue encodes one bind variable value into the query.
57 | func EncodeValue(buf *bytes.Buffer, value interface{}) error { 58 | switch bindVal := value.(type) { 59 | case nil: 60 | buf.WriteString("null") 61 | case []sqltypes.Value: 62 | for i := 0; i < len(bindVal); i++ { 63 | if i != 0 { 64 | buf.WriteString(", ") 65 | } 66 | if err := EncodeValue(buf, bindVal[i]); err != nil { 67 | return err 68 | } 69 | } 70 | case [][]sqltypes.Value: 71 | for i := 0; i < len(bindVal); i++ { 72 | if i != 0 { 73 | buf.WriteString(", ") 74 | } 75 | buf.WriteByte('(') 76 | if err := EncodeValue(buf, bindVal[i]); err != nil { 77 | return err 78 | } 79 | buf.WriteByte(')') 80 | } 81 | case []interface{}: 82 | buf.WriteByte('(') 83 | for i, v := range bindVal { 84 | if i != 0 { 85 | buf.WriteString(", ") 86 | } 87 | if err := EncodeValue(buf, v); err != nil { 88 | return err 89 | } 90 | } 91 | buf.WriteByte(')') 92 | case TupleEqualityList: 93 | if err := bindVal.Encode(buf); err != nil { 94 | return err 95 | } 96 | default: 97 | v, err := sqltypes.BuildValue(bindVal) 98 | if err != nil { 99 | return err 100 | } 101 | v.EncodeSQL(buf) 102 | } 103 | return nil 104 | } 105 | 106 | // TupleEqualityList is for generating equality constraints 107 | // for tables that have composite primary keys. 108 | type TupleEqualityList struct { 109 | Columns []string 110 | Rows [][]sqltypes.Value 111 | } 112 | 113 | // Encode generates the where clause constraints for the tuple 114 | // equality. 
115 | func (tpl *TupleEqualityList) Encode(buf *bytes.Buffer) error { 116 | if len(tpl.Rows) == 0 { 117 | return errors.New("cannot encode with 0 rows") 118 | } 119 | if len(tpl.Columns) == 1 { 120 | return tpl.encodeAsIN(buf) 121 | } 122 | return tpl.encodeAsEquality(buf) 123 | } 124 | 125 | func (tpl *TupleEqualityList) encodeAsIN(buf *bytes.Buffer) error { 126 | buf.WriteString(tpl.Columns[0]) 127 | buf.WriteString(" in (") 128 | for i, r := range tpl.Rows { 129 | if len(r) != 1 { 130 | return errors.New("values don't match column count") 131 | } 132 | if i != 0 { 133 | buf.WriteString(", ") 134 | } 135 | if err := EncodeValue(buf, r); err != nil { 136 | return err 137 | } 138 | } 139 | buf.WriteByte(')') 140 | return nil 141 | } 142 | 143 | func (tpl *TupleEqualityList) encodeAsEquality(buf *bytes.Buffer) error { 144 | for i, r := range tpl.Rows { 145 | if i != 0 { 146 | buf.WriteString(" or ") 147 | } 148 | buf.WriteString("(") 149 | for j, c := range tpl.Columns { 150 | if j != 0 { 151 | buf.WriteString(" and ") 152 | } 153 | buf.WriteString(c) 154 | buf.WriteString(" = ") 155 | if err := EncodeValue(buf, r[j]); err != nil { 156 | return err 157 | } 158 | } 159 | buf.WriteByte(')') 160 | } 161 | return nil 162 | } 163 | 164 | // FetchBindVar resolves the bind variable by fetching it from bindVariables. 
//
// name is the placeholder text: its first character is dropped, and if the
// next character is ':' that is dropped too and the placeholder is treated
// as a list placeholder (for example "::vals"). The remainder is the key
// looked up in bindVariables.
//
// A list placeholder must be bound to a non-empty []interface{}, and a
// plain placeholder must not be bound to a []interface{}. On success the
// bound value is returned together with whether it is a list. An error is
// returned when name is shorter than two bytes, when the key is missing,
// or when the bound value does not fit the placeholder kind.
func FetchBindVar(name string, bindVariables map[string]interface{}) (val interface{}, isList bool, err error) {
	// At least the marker plus one more byte is needed; shorter input would
	// otherwise panic in the slicing and indexing below.
	if len(name) < 2 {
		return nil, false, fmt.Errorf("invalid bind var name %q", name)
	}
	name = name[1:]
	if name[0] == ':' {
		name = name[1:]
		isList = true
	}
	supplied, ok := bindVariables[name]
	if !ok {
		return nil, false, fmt.Errorf("missing bind var %s", name)
	}
	list, gotList := supplied.([]interface{})
	if isList {
		if !gotList {
			return nil, false, fmt.Errorf("unexpected list arg type %T for key %s", supplied, name)
		}
		if len(list) == 0 {
			return nil, false, fmt.Errorf("empty list supplied for %s", name)
		}
		return list, true, nil
	}
	if gotList {
		return nil, false, fmt.Errorf("unexpected arg type %T for key %s", supplied, name)
	}
	return supplied, false, nil
}

// --------------------------------------------------------------------------------
// /internal/sqlparser/parsed_query_test.go:
// --------------------------------------------------------------------------------
// Copyright 2012, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4 | 5 | package sqlparser 6 | 7 | import ( 8 | "reflect" 9 | "testing" 10 | 11 | "github.com/guregu/mogi/internal/sqltypes" 12 | ) 13 | 14 | func TestParsedQuery(t *testing.T) { 15 | tcases := []struct { 16 | desc string 17 | query string 18 | bindVars map[string]interface{} 19 | output string 20 | }{ 21 | { 22 | "no subs", 23 | "select * from a where id = 2", 24 | map[string]interface{}{ 25 | "id": 1, 26 | }, 27 | "select * from a where id = 2", 28 | }, { 29 | "simple bindvar sub", 30 | "select * from a where id1 = :id1 and id2 = :id2", 31 | map[string]interface{}{ 32 | "id1": 1, 33 | "id2": nil, 34 | }, 35 | "select * from a where id1 = 1 and id2 = null", 36 | }, { 37 | "missing bind var", 38 | "select * from a where id1 = :id1 and id2 = :id2", 39 | map[string]interface{}{ 40 | "id1": 1, 41 | }, 42 | "missing bind var id2", 43 | }, { 44 | "unencodable bind var", 45 | "select * from a where id1 = :id", 46 | map[string]interface{}{ 47 | "id": make([]int, 1), 48 | }, 49 | "unexpected type []int: [0]", 50 | }, { 51 | "list inside bind vars", 52 | "select * from a where id in (:vals)", 53 | map[string]interface{}{ 54 | "vals": []sqltypes.Value{ 55 | sqltypes.MakeTrusted(sqltypes.Int64, []byte("1")), 56 | sqltypes.MakeString([]byte("aa")), 57 | }, 58 | }, 59 | "select * from a where id in (1, 'aa')", 60 | }, { 61 | "two lists inside bind vars", 62 | "select * from a where id in (:vals)", 63 | map[string]interface{}{ 64 | "vals": [][]sqltypes.Value{ 65 | { 66 | sqltypes.MakeTrusted(sqltypes.Int64, []byte("1")), 67 | sqltypes.MakeString([]byte("aa")), 68 | }, 69 | { 70 | {}, 71 | sqltypes.MakeString([]byte("bb")), 72 | }, 73 | }, 74 | }, 75 | "select * from a where id in ((1, 'aa'), (null, 'bb'))", 76 | }, { 77 | "list bind vars", 78 | "select * from a where id in ::vals", 79 | map[string]interface{}{ 80 | "vals": []interface{}{ 81 | 1, 82 | "aa", 83 | }, 84 | }, 85 | "select * from a where id in (1, 'aa')", 86 | }, { 87 | "list bind vars single argument", 88 | "select * 
from a where id in ::vals", 89 | map[string]interface{}{ 90 | "vals": []interface{}{ 91 | 1, 92 | }, 93 | }, 94 | "select * from a where id in (1)", 95 | }, { 96 | "list bind vars 0 arguments", 97 | "select * from a where id in ::vals", 98 | map[string]interface{}{ 99 | "vals": []interface{}{}, 100 | }, 101 | "empty list supplied for vals", 102 | }, { 103 | "non-list bind var supplied", 104 | "select * from a where id in ::vals", 105 | map[string]interface{}{ 106 | "vals": 1, 107 | }, 108 | "unexpected list arg type int for key vals", 109 | }, { 110 | "list bind var for non-list", 111 | "select * from a where id = :vals", 112 | map[string]interface{}{ 113 | "vals": []interface{}{1}, 114 | }, 115 | "unexpected arg type []interface {} for key vals", 116 | }, { 117 | "single column tuple equality", 118 | // We have to use an incorrect construct to get around the parser. 119 | "select * from a where b = :equality", 120 | map[string]interface{}{ 121 | "equality": TupleEqualityList{ 122 | Columns: []string{"pk"}, 123 | Rows: [][]sqltypes.Value{ 124 | {sqltypes.MakeTrusted(sqltypes.Int64, []byte("1"))}, 125 | {sqltypes.MakeString([]byte("aa"))}, 126 | }, 127 | }, 128 | }, 129 | "select * from a where b = pk in (1, 'aa')", 130 | }, { 131 | "multi column tuple equality", 132 | "select * from a where b = :equality", 133 | map[string]interface{}{ 134 | "equality": TupleEqualityList{ 135 | Columns: []string{"pk1", "pk2"}, 136 | Rows: [][]sqltypes.Value{ 137 | { 138 | sqltypes.MakeTrusted(sqltypes.Int64, []byte("1")), 139 | sqltypes.MakeString([]byte("aa")), 140 | }, 141 | { 142 | sqltypes.MakeTrusted(sqltypes.Int64, []byte("2")), 143 | sqltypes.MakeString([]byte("bb")), 144 | }, 145 | }, 146 | }, 147 | }, 148 | "select * from a where b = (pk1 = 1 and pk2 = 'aa') or (pk1 = 2 and pk2 = 'bb')", 149 | }, { 150 | "0 rows", 151 | "select * from a where b = :equality", 152 | map[string]interface{}{ 153 | "equality": TupleEqualityList{ 154 | Columns: []string{"pk"}, 155 | Rows: 
[][]sqltypes.Value{}, 156 | }, 157 | }, 158 | "cannot encode with 0 rows", 159 | }, { 160 | "values don't match column count", 161 | "select * from a where b = :equality", 162 | map[string]interface{}{ 163 | "equality": TupleEqualityList{ 164 | Columns: []string{"pk"}, 165 | Rows: [][]sqltypes.Value{ 166 | { 167 | sqltypes.MakeTrusted(sqltypes.Int64, []byte("1")), 168 | sqltypes.MakeString([]byte("aa")), 169 | }, 170 | }, 171 | }, 172 | }, 173 | "values don't match column count", 174 | }, 175 | } 176 | 177 | for _, tcase := range tcases { 178 | tree, err := Parse(tcase.query) 179 | if err != nil { 180 | t.Errorf("parse failed for %s: %v", tcase.desc, err) 181 | continue 182 | } 183 | buf := NewTrackedBuffer(nil) 184 | buf.Myprintf("%v", tree) 185 | pq := buf.ParsedQuery() 186 | bytes, err := pq.GenerateQuery(tcase.bindVars) 187 | var got string 188 | if err != nil { 189 | got = err.Error() 190 | } else { 191 | got = string(bytes) 192 | } 193 | if got != tcase.output { 194 | t.Errorf("for test case: %s, got: '%s', want '%s'", tcase.desc, got, tcase.output) 195 | } 196 | } 197 | } 198 | 199 | func TestGenerateParsedQuery(t *testing.T) { 200 | stmt, err := Parse("select * from a where id =:id") 201 | if err != nil { 202 | t.Error(err) 203 | return 204 | } 205 | pq := GenerateParsedQuery(stmt) 206 | want := &ParsedQuery{ 207 | Query: "select * from a where id = :id", 208 | bindLocations: []bindLocation{{offset: 27, length: 3}}, 209 | } 210 | if !reflect.DeepEqual(pq, want) { 211 | t.Errorf("GenerateParsedQuery: %+v, want %+v", pq, want) 212 | } 213 | } 214 | -------------------------------------------------------------------------------- /internal/sqlparser/precedence_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package sqlparser 6 | 7 | import ( 8 | "fmt" 9 | "testing" 10 | ) 11 | 12 | func readable(node Expr) string { 13 | switch node := node.(type) { 14 | case *OrExpr: 15 | return fmt.Sprintf("(%s or %s)", readable(node.Left), readable(node.Right)) 16 | case *AndExpr: 17 | return fmt.Sprintf("(%s and %s)", readable(node.Left), readable(node.Right)) 18 | case *BinaryExpr: 19 | return fmt.Sprintf("(%s %s %s)", readable(node.Left), node.Operator, readable(node.Right)) 20 | case *IsExpr: 21 | return fmt.Sprintf("(%s %s)", readable(node.Expr), node.Operator) 22 | default: 23 | return String(node) 24 | } 25 | } 26 | 27 | func TestAndOrPrecedence(t *testing.T) { 28 | validSQL := []struct { 29 | input string 30 | output string 31 | }{{ 32 | input: "select * from a where a=b and c=d or e=f", 33 | output: "((a = b and c = d) or e = f)", 34 | }, { 35 | input: "select * from a where a=b or c=d and e=f", 36 | output: "(a = b or (c = d and e = f))", 37 | }} 38 | for _, tcase := range validSQL { 39 | tree, err := Parse(tcase.input) 40 | if err != nil { 41 | t.Error(err) 42 | continue 43 | } 44 | expr := readable(tree.(*Select).Where.Expr) 45 | if expr != tcase.output { 46 | t.Errorf("Parse: \n%s, want: \n%s", expr, tcase.output) 47 | } 48 | } 49 | } 50 | 51 | func TestPlusStarPrecedence(t *testing.T) { 52 | validSQL := []struct { 53 | input string 54 | output string 55 | }{{ 56 | input: "select 1+2*3 from a", 57 | output: "(1 + (2 * 3))", 58 | }, { 59 | input: "select 1*2+3 from a", 60 | output: "((1 * 2) + 3)", 61 | }} 62 | for _, tcase := range validSQL { 63 | tree, err := Parse(tcase.input) 64 | if err != nil { 65 | t.Error(err) 66 | continue 67 | } 68 | expr := readable(tree.(*Select).SelectExprs[0].(*NonStarExpr).Expr) 69 | if expr != tcase.output { 70 | t.Errorf("Parse: \n%s, want: \n%s", expr, tcase.output) 71 | } 72 | } 73 | } 74 | 75 | func TestIsPrecedence(t *testing.T) { 76 | validSQL := []struct { 77 | input string 78 | output string 79 | }{{ 80 | input: "select * 
from a where a+b is true", 81 | output: "((a + b) is true)", 82 | }, { 83 | input: "select * from a where a=1 and b=2 is true", 84 | output: "(a = 1 and (b = 2 is true))", 85 | }, { 86 | input: "select * from a where (a=1 and b=2) is true", 87 | output: "((a = 1 and b = 2) is true)", 88 | }} 89 | for _, tcase := range validSQL { 90 | tree, err := Parse(tcase.input) 91 | if err != nil { 92 | t.Error(err) 93 | continue 94 | } 95 | expr := readable(tree.(*Select).Where.Expr) 96 | if expr != tcase.output { 97 | t.Errorf("Parse: \n%s, want: \n%s", expr, tcase.output) 98 | } 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /internal/sqlparser/sql.y: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | %{ 6 | package sqlparser 7 | 8 | import "strings" 9 | 10 | func setParseTree(yylex interface{}, stmt Statement) { 11 | yylex.(*Tokenizer).ParseTree = stmt 12 | } 13 | 14 | func setAllowComments(yylex interface{}, allow bool) { 15 | yylex.(*Tokenizer).AllowComments = allow 16 | } 17 | 18 | func incNesting(yylex interface{}) bool { 19 | yylex.(*Tokenizer).nesting++ 20 | if yylex.(*Tokenizer).nesting == 200 { 21 | return true 22 | } 23 | return false 24 | } 25 | 26 | func decNesting(yylex interface{}) { 27 | yylex.(*Tokenizer).nesting-- 28 | } 29 | 30 | func forceEOF(yylex interface{}) { 31 | yylex.(*Tokenizer).ForceEOF = true 32 | } 33 | 34 | %} 35 | 36 | %union { 37 | empty struct{} 38 | statement Statement 39 | selStmt SelectStatement 40 | byt byte 41 | bytes []byte 42 | bytes2 [][]byte 43 | str string 44 | selectExprs SelectExprs 45 | selectExpr SelectExpr 46 | columns Columns 47 | colName *ColName 48 | tableExprs TableExprs 49 | tableExpr TableExpr 50 | smTableExpr SimpleTableExpr 51 | tableName *TableName 52 | 
indexHints *IndexHints 53 | expr Expr 54 | boolExpr BoolExpr 55 | valExpr ValExpr 56 | colTuple ColTuple 57 | valExprs ValExprs 58 | values Values 59 | rowTuple RowTuple 60 | subquery *Subquery 61 | caseExpr *CaseExpr 62 | whens []*When 63 | when *When 64 | orderBy OrderBy 65 | order *Order 66 | limit *Limit 67 | insRows InsertRows 68 | updateExprs UpdateExprs 69 | updateExpr *UpdateExpr 70 | sqlID SQLName 71 | sqlIDs []SQLName 72 | } 73 | 74 | %token LEX_ERROR 75 | %left UNION MINUS EXCEPT INTERSECT 76 | %token SELECT INSERT UPDATE DELETE FROM WHERE GROUP HAVING ORDER BY LIMIT FOR 77 | %token ALL DISTINCT AS EXISTS ASC DESC INTO DUPLICATE KEY DEFAULT SET LOCK KEYRANGE 78 | %token VALUES LAST_INSERT_ID 79 | %token NEXT VALUE 80 | %left JOIN STRAIGHT_JOIN LEFT RIGHT INNER OUTER CROSS NATURAL USE FORCE 81 | %left ON 82 | %token '(' ',' ')' 83 | %token ID STRING NUMBER VALUE_ARG LIST_ARG COMMENT 84 | %token NULL TRUE FALSE 85 | 86 | // Precedence dictated by mysql. But the vitess grammar is simplified. 87 | // Some of these operators don't conflict in our situation. Nevertheless, 88 | // it's better to have these listed in the correct order. Also, we don't 89 | // support all operators yet. 90 | %left OR 91 | %left AND 92 | %right NOT 93 | %left BETWEEN CASE WHEN THEN ELSE 94 | %left '=' '<' '>' LE GE NE NULL_SAFE_EQUAL IS LIKE REGEXP IN 95 | %left '|' 96 | %left '&' 97 | %left SHIFT_LEFT SHIFT_RIGHT 98 | %left '+' '-' 99 | %left '*' '/' '%' 100 | %left '^' 101 | %right '~' UNARY 102 | %right INTERVAL 103 | %nonassoc '.' 
104 | %left END 105 | 106 | // DDL Tokens 107 | %token CREATE ALTER DROP RENAME ANALYZE 108 | %token TABLE INDEX VIEW TO IGNORE IF UNIQUE USING 109 | %token SHOW DESCRIBE EXPLAIN 110 | 111 | %type command 112 | %type select_statement 113 | %type insert_statement update_statement delete_statement set_statement 114 | %type create_statement alter_statement rename_statement drop_statement 115 | %type analyze_statement other_statement 116 | %type comment_opt comment_list 117 | %type union_op 118 | %type distinct_opt 119 | %type select_expression_list 120 | %type select_expression 121 | %type expression 122 | %type table_references 123 | %type table_reference table_factor join_table 124 | %type inner_join outer_join natural_join 125 | %type simple_table_expression 126 | %type dml_table_expression 127 | %type index_hint_list 128 | %type index_list 129 | %type where_expression_opt 130 | %type boolean_expression condition 131 | %type compare 132 | %type row_list 133 | %type value value_expression 134 | %type is_suffix 135 | %type col_tuple 136 | %type value_expression_list 137 | %type tuple_list 138 | %type row_tuple 139 | %type subquery 140 | %type column_name 141 | %type case_expression 142 | %type when_expression_list 143 | %type when_expression 144 | %type value_expression_opt else_expression_opt 145 | %type group_by_opt 146 | %type having_opt 147 | %type order_by_opt order_list 148 | %type order 149 | %type asc_desc_opt 150 | %type limit_opt 151 | %type lock_opt 152 | %type column_list_opt column_list 153 | %type on_dup_opt 154 | %type update_list 155 | %type update_expression 156 | %type for_from 157 | %type ignore_opt 158 | %type exists_opt not_exists_opt non_rename_operation to_opt constraint_opt using_opt 159 | %type sql_id as_lower_opt 160 | %type table_id as_opt_id 161 | %type as_opt 162 | %type force_eof 163 | 164 | %start any_command 165 | 166 | %% 167 | 168 | any_command: 169 | command 170 | { 171 | setParseTree(yylex, $1) 172 | } 173 | 174 | command: 175 | 
select_statement 176 | { 177 | $$ = $1 178 | } 179 | | insert_statement 180 | | update_statement 181 | | delete_statement 182 | | set_statement 183 | | create_statement 184 | | alter_statement 185 | | rename_statement 186 | | drop_statement 187 | | analyze_statement 188 | | other_statement 189 | 190 | select_statement: 191 | SELECT comment_opt distinct_opt select_expression_list FROM table_references where_expression_opt group_by_opt having_opt order_by_opt limit_opt lock_opt 192 | { 193 | $$ = &Select{Comments: Comments($2), Distinct: $3, SelectExprs: $4, From: $6, Where: NewWhere(WhereStr, $7), GroupBy: GroupBy($8), Having: NewWhere(HavingStr, $9), OrderBy: $10, Limit: $11, Lock: $12} 194 | } 195 | | SELECT comment_opt NEXT sql_id for_from simple_table_expression 196 | { 197 | if $4 != "value" { 198 | yylex.Error("expecting value after next") 199 | return 1 200 | } 201 | $$ = &Select{Comments: Comments($2), SelectExprs: SelectExprs{Nextval{}}, From: TableExprs{&AliasedTableExpr{Expr: $6}}} 202 | } 203 | | select_statement union_op select_statement %prec UNION 204 | { 205 | $$ = &Union{Type: $2, Left: $1, Right: $3} 206 | } 207 | 208 | insert_statement: 209 | INSERT comment_opt ignore_opt INTO dml_table_expression column_list_opt row_list on_dup_opt 210 | { 211 | $$ = &Insert{Comments: Comments($2), Ignore: $3, Table: $5, Columns: $6, Rows: $7, OnDup: OnDup($8)} 212 | } 213 | | INSERT comment_opt ignore_opt INTO dml_table_expression SET update_list on_dup_opt 214 | { 215 | cols := make(Columns, 0, len($7)) 216 | vals := make(ValTuple, 0, len($7)) 217 | for _, col := range $7 { 218 | cols = append(cols, &NonStarExpr{Expr: col.Name}) 219 | vals = append(vals, col.Expr) 220 | } 221 | $$ = &Insert{Comments: Comments($2), Ignore: $3, Table: $5, Columns: cols, Rows: Values{vals}, OnDup: OnDup($8)} 222 | } 223 | 224 | update_statement: 225 | UPDATE comment_opt dml_table_expression SET update_list where_expression_opt order_by_opt limit_opt 226 | { 227 | $$ = 
&Update{Comments: Comments($2), Table: $3, Exprs: $5, Where: NewWhere(WhereStr, $6), OrderBy: $7, Limit: $8} 228 | } 229 | 230 | delete_statement: 231 | DELETE comment_opt FROM dml_table_expression where_expression_opt order_by_opt limit_opt 232 | { 233 | $$ = &Delete{Comments: Comments($2), Table: $4, Where: NewWhere(WhereStr, $5), OrderBy: $6, Limit: $7} 234 | } 235 | 236 | set_statement: 237 | SET comment_opt update_list 238 | { 239 | $$ = &Set{Comments: Comments($2), Exprs: $3} 240 | } 241 | 242 | create_statement: 243 | CREATE TABLE not_exists_opt table_id force_eof 244 | { 245 | $$ = &DDL{Action: CreateStr, NewName: $4} 246 | } 247 | | CREATE constraint_opt INDEX ID using_opt ON table_id force_eof 248 | { 249 | // Change this to an alter statement 250 | $$ = &DDL{Action: AlterStr, Table: $7, NewName: $7} 251 | } 252 | | CREATE VIEW sql_id force_eof 253 | { 254 | $$ = &DDL{Action: CreateStr, NewName: SQLName($3)} 255 | } 256 | 257 | alter_statement: 258 | ALTER ignore_opt TABLE table_id non_rename_operation force_eof 259 | { 260 | $$ = &DDL{Action: AlterStr, Table: $4, NewName: $4} 261 | } 262 | | ALTER ignore_opt TABLE table_id RENAME to_opt table_id 263 | { 264 | // Change this to a rename statement 265 | $$ = &DDL{Action: RenameStr, Table: $4, NewName: $7} 266 | } 267 | | ALTER VIEW sql_id force_eof 268 | { 269 | $$ = &DDL{Action: AlterStr, Table: SQLName($3), NewName: SQLName($3)} 270 | } 271 | 272 | rename_statement: 273 | RENAME TABLE table_id TO table_id 274 | { 275 | $$ = &DDL{Action: RenameStr, Table: $3, NewName: $5} 276 | } 277 | 278 | drop_statement: 279 | DROP TABLE exists_opt table_id 280 | { 281 | $$ = &DDL{Action: DropStr, Table: $4} 282 | } 283 | | DROP INDEX ID ON table_id 284 | { 285 | // Change this to an alter statement 286 | $$ = &DDL{Action: AlterStr, Table: $5, NewName: $5} 287 | } 288 | | DROP VIEW exists_opt sql_id force_eof 289 | { 290 | $$ = &DDL{Action: DropStr, Table: SQLName($4)} 291 | } 292 | 293 | analyze_statement: 294 | 
// select_expression_list collects one or more comma-separated select
// expressions into a SelectExprs slice, preserving source order.
select_expression_list:
  select_expression
  {
    $$ = SelectExprs{$1}
  }
| select_expression_list ',' select_expression
  {
    // Extend the list built so far ($1) explicitly rather than relying on
    // the generated parser pre-loading $$ with $1 before the action runs.
    // This matches how the other list rules in this grammar append.
    $$ = append($1, $3)
  }
// table_references collects one or more comma-separated table references
// into a TableExprs slice, preserving source order.
table_references:
  table_reference
  {
    $$ = TableExprs{$1}
  }
| table_references ',' table_reference
  {
    // Extend the list built so far ($1) explicitly rather than relying on
    // the generated parser pre-loading $$ with $1 before the action runs.
    // This matches how the other list rules in this grammar append.
    $$ = append($1, $3)
  }
445 | join_table: 446 | table_reference inner_join table_factor %prec JOIN 447 | { 448 | $$ = &JoinTableExpr{LeftExpr: $1, Join: $2, RightExpr: $3} 449 | } 450 | | table_reference inner_join table_factor ON boolean_expression 451 | { 452 | $$ = &JoinTableExpr{LeftExpr: $1, Join: $2, RightExpr: $3, On: $5} 453 | } 454 | | table_reference outer_join table_reference ON boolean_expression 455 | { 456 | $$ = &JoinTableExpr{LeftExpr: $1, Join: $2, RightExpr: $3, On: $5} 457 | } 458 | | table_reference natural_join table_factor 459 | { 460 | $$ = &JoinTableExpr{LeftExpr: $1, Join: $2, RightExpr: $3} 461 | } 462 | 463 | as_opt: 464 | { $$ = struct{}{} } 465 | | AS 466 | { $$ = struct{}{} } 467 | 468 | as_opt_id: 469 | { 470 | $$ = "" 471 | } 472 | | table_id 473 | { 474 | $$ = $1 475 | } 476 | | AS table_id 477 | { 478 | $$ = $2 479 | } 480 | 481 | inner_join: 482 | JOIN 483 | { 484 | $$ = JoinStr 485 | } 486 | | INNER JOIN 487 | { 488 | $$ = JoinStr 489 | } 490 | | CROSS JOIN 491 | { 492 | $$ = JoinStr 493 | } 494 | | STRAIGHT_JOIN 495 | { 496 | $$ = StraightJoinStr 497 | } 498 | 499 | outer_join: 500 | LEFT JOIN 501 | { 502 | $$ = LeftJoinStr 503 | } 504 | | LEFT OUTER JOIN 505 | { 506 | $$ = LeftJoinStr 507 | } 508 | | RIGHT JOIN 509 | { 510 | $$ = RightJoinStr 511 | } 512 | | RIGHT OUTER JOIN 513 | { 514 | $$ = RightJoinStr 515 | } 516 | 517 | natural_join: 518 | NATURAL JOIN 519 | { 520 | $$ = NaturalJoinStr 521 | } 522 | | NATURAL outer_join 523 | { 524 | if $2 == LeftJoinStr { 525 | $$ = NaturalLeftJoinStr 526 | } else { 527 | $$ = NaturalRightJoinStr 528 | } 529 | } 530 | 531 | simple_table_expression: 532 | table_id 533 | { 534 | $$ = &TableName{Name: $1} 535 | } 536 | | table_id '.' table_id 537 | { 538 | $$ = &TableName{Qualifier: $1, Name: $3} 539 | } 540 | 541 | dml_table_expression: 542 | table_id 543 | { 544 | $$ = &TableName{Name: $1} 545 | } 546 | | table_id '.' 
table_id 547 | { 548 | $$ = &TableName{Qualifier: $1, Name: $3} 549 | } 550 | 551 | index_hint_list: 552 | { 553 | $$ = nil 554 | } 555 | | USE INDEX openb index_list closeb 556 | { 557 | $$ = &IndexHints{Type: UseStr, Indexes: $4} 558 | } 559 | | IGNORE INDEX openb index_list closeb 560 | { 561 | $$ = &IndexHints{Type: IgnoreStr, Indexes: $4} 562 | } 563 | | FORCE INDEX openb index_list closeb 564 | { 565 | $$ = &IndexHints{Type: ForceStr, Indexes: $4} 566 | } 567 | 568 | index_list: 569 | sql_id 570 | { 571 | $$ = []SQLName{$1} 572 | } 573 | | index_list ',' sql_id 574 | { 575 | $$ = append($1, $3) 576 | } 577 | 578 | where_expression_opt: 579 | { 580 | $$ = nil 581 | } 582 | | WHERE boolean_expression 583 | { 584 | $$ = $2 585 | } 586 | 587 | boolean_expression: 588 | condition 589 | | boolean_expression AND boolean_expression 590 | { 591 | $$ = &AndExpr{Left: $1, Right: $3} 592 | } 593 | | boolean_expression OR boolean_expression 594 | { 595 | $$ = &OrExpr{Left: $1, Right: $3} 596 | } 597 | | NOT boolean_expression 598 | { 599 | $$ = &NotExpr{Expr: $2} 600 | } 601 | | openb boolean_expression closeb 602 | { 603 | $$ = &ParenBoolExpr{Expr: $2} 604 | } 605 | | boolean_expression IS is_suffix 606 | { 607 | $$ = &IsExpr{Operator: $3, Expr: $1} 608 | } 609 | 610 | condition: 611 | TRUE 612 | { 613 | $$ = BoolVal(true) 614 | } 615 | | FALSE 616 | { 617 | $$ = BoolVal(false) 618 | } 619 | | value_expression compare value_expression 620 | { 621 | $$ = &ComparisonExpr{Left: $1, Operator: $2, Right: $3} 622 | } 623 | | value_expression IN col_tuple 624 | { 625 | $$ = &ComparisonExpr{Left: $1, Operator: InStr, Right: $3} 626 | } 627 | | value_expression NOT IN col_tuple 628 | { 629 | $$ = &ComparisonExpr{Left: $1, Operator: NotInStr, Right: $4} 630 | } 631 | | value_expression LIKE value_expression 632 | { 633 | $$ = &ComparisonExpr{Left: $1, Operator: LikeStr, Right: $3} 634 | } 635 | | value_expression NOT LIKE value_expression 636 | { 637 | $$ = &ComparisonExpr{Left: 
$1, Operator: NotLikeStr, Right: $4} 638 | } 639 | | value_expression REGEXP value_expression 640 | { 641 | $$ = &ComparisonExpr{Left: $1, Operator: RegexpStr, Right: $3} 642 | } 643 | | value_expression NOT REGEXP value_expression 644 | { 645 | $$ = &ComparisonExpr{Left: $1, Operator: NotRegexpStr, Right: $4} 646 | } 647 | | value_expression BETWEEN value_expression AND value_expression 648 | { 649 | $$ = &RangeCond{Left: $1, Operator: BetweenStr, From: $3, To: $5} 650 | } 651 | | value_expression NOT BETWEEN value_expression AND value_expression 652 | { 653 | $$ = &RangeCond{Left: $1, Operator: NotBetweenStr, From: $4, To: $6} 654 | } 655 | | value_expression IS is_suffix 656 | { 657 | $$ = &IsExpr{Operator: $3, Expr: $1} 658 | } 659 | | EXISTS subquery 660 | { 661 | $$ = &ExistsExpr{Subquery: $2} 662 | } 663 | | KEYRANGE openb value ',' value closeb 664 | { 665 | $$ = &KeyrangeExpr{Start: $3, End: $5} 666 | } 667 | 668 | is_suffix: 669 | NULL 670 | { 671 | $$ = IsNullStr 672 | } 673 | | NOT NULL 674 | { 675 | $$ = IsNotNullStr 676 | } 677 | | TRUE 678 | { 679 | $$ = IsTrueStr 680 | } 681 | | NOT TRUE 682 | { 683 | $$ = IsNotTrueStr 684 | } 685 | | FALSE 686 | { 687 | $$ = IsFalseStr 688 | } 689 | | NOT FALSE 690 | { 691 | $$ = IsNotFalseStr 692 | } 693 | 694 | compare: 695 | '=' 696 | { 697 | $$ = EqualStr 698 | } 699 | | '<' 700 | { 701 | $$ = LessThanStr 702 | } 703 | | '>' 704 | { 705 | $$ = GreaterThanStr 706 | } 707 | | LE 708 | { 709 | $$ = LessEqualStr 710 | } 711 | | GE 712 | { 713 | $$ = GreaterEqualStr 714 | } 715 | | NE 716 | { 717 | $$ = NotEqualStr 718 | } 719 | | NULL_SAFE_EQUAL 720 | { 721 | $$ = NullSafeEqualStr 722 | } 723 | 724 | col_tuple: 725 | openb value_expression_list closeb 726 | { 727 | $$ = ValTuple($2) 728 | } 729 | | subquery 730 | { 731 | $$ = $1 732 | } 733 | | LIST_ARG 734 | { 735 | $$ = ListArg($1) 736 | } 737 | 738 | subquery: 739 | openb select_statement closeb 740 | { 741 | $$ = &Subquery{$2} 742 | } 743 | 744 | 
value_expression_list: 745 | value_expression 746 | { 747 | $$ = ValExprs{$1} 748 | } 749 | | value_expression_list ',' value_expression 750 | { 751 | $$ = append($1, $3) 752 | } 753 | 754 | value_expression: 755 | value 756 | { 757 | $$ = $1 758 | } 759 | | column_name 760 | { 761 | $$ = $1 762 | } 763 | | row_tuple 764 | { 765 | $$ = $1 766 | } 767 | | value_expression '&' value_expression 768 | { 769 | $$ = &BinaryExpr{Left: $1, Operator: BitAndStr, Right: $3} 770 | } 771 | | value_expression '|' value_expression 772 | { 773 | $$ = &BinaryExpr{Left: $1, Operator: BitOrStr, Right: $3} 774 | } 775 | | value_expression '^' value_expression 776 | { 777 | $$ = &BinaryExpr{Left: $1, Operator: BitXorStr, Right: $3} 778 | } 779 | | value_expression '+' value_expression 780 | { 781 | $$ = &BinaryExpr{Left: $1, Operator: PlusStr, Right: $3} 782 | } 783 | | value_expression '-' value_expression 784 | { 785 | $$ = &BinaryExpr{Left: $1, Operator: MinusStr, Right: $3} 786 | } 787 | | value_expression '*' value_expression 788 | { 789 | $$ = &BinaryExpr{Left: $1, Operator: MultStr, Right: $3} 790 | } 791 | | value_expression '/' value_expression 792 | { 793 | $$ = &BinaryExpr{Left: $1, Operator: DivStr, Right: $3} 794 | } 795 | | value_expression '%' value_expression 796 | { 797 | $$ = &BinaryExpr{Left: $1, Operator: ModStr, Right: $3} 798 | } 799 | | value_expression SHIFT_LEFT value_expression 800 | { 801 | $$ = &BinaryExpr{Left: $1, Operator: ShiftLeftStr, Right: $3} 802 | } 803 | | value_expression SHIFT_RIGHT value_expression 804 | { 805 | $$ = &BinaryExpr{Left: $1, Operator: ShiftRightStr, Right: $3} 806 | } 807 | | '+' value_expression %prec UNARY 808 | { 809 | if num, ok := $2.(NumVal); ok { 810 | $$ = num 811 | } else { 812 | $$ = &UnaryExpr{Operator: UPlusStr, Expr: $2} 813 | } 814 | } 815 | | '-' value_expression %prec UNARY 816 | { 817 | if num, ok := $2.(NumVal); ok { 818 | // Handle double negative 819 | if num[0] == '-' { 820 | $$ = num[1:] 821 | } else { 822 | 
$$ = append(NumVal("-"), num...) 823 | } 824 | } else { 825 | $$ = &UnaryExpr{Operator: UMinusStr, Expr: $2} 826 | } 827 | } 828 | | '~' value_expression 829 | { 830 | $$ = &UnaryExpr{Operator: TildaStr, Expr: $2} 831 | } 832 | | INTERVAL value_expression sql_id 833 | { 834 | // This rule prevents the usage of INTERVAL 835 | // as a function. If support is needed for that, 836 | // we'll need to revisit this. The solution 837 | // will be non-trivial because of grammar conflicts. 838 | $$ = &IntervalExpr{Expr: $2, Unit: $3} 839 | } 840 | | sql_id openb closeb 841 | { 842 | $$ = &FuncExpr{Name: string($1)} 843 | } 844 | | sql_id openb select_expression_list closeb 845 | { 846 | $$ = &FuncExpr{Name: string($1), Exprs: $3} 847 | } 848 | | sql_id openb DISTINCT select_expression_list closeb 849 | { 850 | $$ = &FuncExpr{Name: string($1), Distinct: true, Exprs: $4} 851 | } 852 | | IF openb select_expression_list closeb 853 | { 854 | $$ = &FuncExpr{Name: "if", Exprs: $3} 855 | } 856 | | case_expression 857 | { 858 | $$ = $1 859 | } 860 | 861 | case_expression: 862 | CASE value_expression_opt when_expression_list else_expression_opt END 863 | { 864 | $$ = &CaseExpr{Expr: $2, Whens: $3, Else: $4} 865 | } 866 | 867 | value_expression_opt: 868 | { 869 | $$ = nil 870 | } 871 | | value_expression 872 | { 873 | $$ = $1 874 | } 875 | 876 | when_expression_list: 877 | when_expression 878 | { 879 | $$ = []*When{$1} 880 | } 881 | | when_expression_list when_expression 882 | { 883 | $$ = append($1, $2) 884 | } 885 | 886 | when_expression: 887 | WHEN boolean_expression THEN value_expression 888 | { 889 | $$ = &When{Cond: $2, Val: $4} 890 | } 891 | 892 | else_expression_opt: 893 | { 894 | $$ = nil 895 | } 896 | | ELSE value_expression 897 | { 898 | $$ = $2 899 | } 900 | 901 | column_name: 902 | sql_id 903 | { 904 | $$ = &ColName{Name: $1} 905 | } 906 | | table_id '.' 
// column_list collects one or more comma-separated column names, each
// wrapped in a NonStarExpr, into a Columns slice in source order.
column_list:
  column_name
  {
    $$ = Columns{&NonStarExpr{Expr: $1}}
  }
| column_list ',' column_name
  {
    // Extend the list built so far ($1) explicitly rather than relying on
    // the generated parser pre-loading $$ with $1 before the action runs.
    // This matches how the other list rules in this grammar append.
    $$ = append($1, &NonStarExpr{Expr: $3})
  }
// Tokenizer is the struct used to generate SQL
// tokens for the parser.
type Tokenizer struct {
	// InStream is the reader the SQL text is scanned from.
	InStream *strings.Reader
	// AllowComments makes Lex hand COMMENT tokens to the parser;
	// when false, Lex silently skips them.
	AllowComments bool
	// ForceEOF makes every subsequent Scan report end of input.
	ForceEOF bool
	// lastChar is the current lookahead character. Scan treats 0 as
	// "nothing read yet"; the type is wider than a byte so that the
	// eofChar sentinel (0x100) fits.
	lastChar uint16
	// Position is included in the messages built by Error.
	// NOTE(review): presumably advanced by next(), which is not
	// visible here - confirm.
	Position int
	// lastToken is the value of the token most recently returned by
	// Lex; Error quotes it in its message.
	lastToken []byte
	// LastError holds the message recorded by the latest Error call.
	LastError string
	// posVarIndex counts '?' placeholders so that each one can be
	// renamed to a ":vN" bind variable.
	posVarIndex int
	// ParseTree receives the statement produced by the grammar's
	// start rule.
	ParseTree Statement
	// nesting is the current parenthesis depth, maintained by the
	// grammar's openb and closeb rules.
	nesting int

	// Special hack to recognize
	// SELECT NEXT VALUE for construct.
	// If the token immediately after a SELECT is
	// a NEXT, then it's treated as a keyword. Otherwise,
	// it's a normal identifier. This flag gets set
	// if we see a SELECT, and gets reset for any other token.
	atSelect bool
}
// keywords maps the lowercased spelling of each reserved word to its
// parser token. scanIdentifier lowercases every bare identifier and looks
// it up here; a hit is returned as that keyword token instead of ID.
//
// "next" is deliberately absent: scanIdentifier only treats it as the
// NEXT keyword when it directly follows a SELECT.
var keywords = map[string]int{
	"all":            ALL,
	"alter":          ALTER,
	"analyze":        ANALYZE,
	"and":            AND,
	"as":             AS,
	"asc":            ASC,
	"between":        BETWEEN,
	"by":             BY,
	"case":           CASE,
	"create":         CREATE,
	"cross":          CROSS,
	"default":        DEFAULT,
	"delete":         DELETE,
	"desc":           DESC,
	"describe":       DESCRIBE,
	"distinct":       DISTINCT,
	"drop":           DROP,
	"duplicate":      DUPLICATE,
	"else":           ELSE,
	"end":            END,
	"except":         EXCEPT,
	"exists":         EXISTS,
	"explain":        EXPLAIN,
	"false":          FALSE,
	"for":            FOR,
	"force":          FORCE,
	"from":           FROM,
	"group":          GROUP,
	"having":         HAVING,
	"if":             IF,
	"ignore":         IGNORE,
	"in":             IN,
	"index":          INDEX,
	"inner":          INNER,
	"insert":         INSERT,
	"intersect":      INTERSECT,
	"interval":       INTERVAL,
	"into":           INTO,
	"is":             IS,
	"join":           JOIN,
	"key":            KEY,
	"keyrange":       KEYRANGE,
	"last_insert_id": LAST_INSERT_ID,
	"left":           LEFT,
	"like":           LIKE,
	"limit":          LIMIT,
	"lock":           LOCK,
	"minus":          MINUS,
	"natural":        NATURAL,
	"not":            NOT,
	"null":           NULL,
	"on":             ON,
	"or":             OR,
	"order":          ORDER,
	"outer":          OUTER,
	"rename":         RENAME,
	"regexp":         REGEXP,
	"right":          RIGHT,
	// "rlike" shares the REGEXP token with "regexp".
	"rlike":         REGEXP,
	"select":        SELECT,
	"set":           SET,
	"show":          SHOW,
	"straight_join": STRAIGHT_JOIN,
	"table":         TABLE,
	"then":          THEN,
	"to":            TO,
	"true":          TRUE,
	"union":         UNION,
	"unique":        UNIQUE,
	"update":        UPDATE,
	"use":           USE,
	"using":         USING,
	"values":        VALUES,
	"view":          VIEW,
	"when":          WHEN,
	"where":         WHERE,
}
127 | func (tkn *Tokenizer) Lex(lval *yySymType) int { 128 | typ, val := tkn.Scan() 129 | for typ == COMMENT { 130 | if tkn.AllowComments { 131 | break 132 | } 133 | typ, val = tkn.Scan() 134 | } 135 | switch typ { 136 | case ID, STRING, NUMBER, VALUE_ARG, LIST_ARG, COMMENT: 137 | lval.bytes = val 138 | } 139 | tkn.lastToken = val 140 | if typ == SELECT { 141 | tkn.atSelect = true 142 | } else { 143 | tkn.atSelect = false 144 | } 145 | return typ 146 | } 147 | 148 | // Error is called by go yacc if there's a parsing error. 149 | func (tkn *Tokenizer) Error(err string) { 150 | buf := &bytes.Buffer{} 151 | if tkn.lastToken != nil { 152 | fmt.Fprintf(buf, "%s at position %v near '%s'", err, tkn.Position, tkn.lastToken) 153 | } else { 154 | fmt.Fprintf(buf, "%s at position %v", err, tkn.Position) 155 | } 156 | tkn.LastError = buf.String() 157 | } 158 | 159 | // Scan scans the tokenizer for the next token and returns 160 | // the token type and an optional value. 161 | func (tkn *Tokenizer) Scan() (int, []byte) { 162 | if tkn.ForceEOF { 163 | return 0, nil 164 | } 165 | 166 | if tkn.lastChar == 0 { 167 | tkn.next() 168 | } 169 | tkn.skipBlank() 170 | switch ch := tkn.lastChar; { 171 | case isLetter(ch): 172 | return tkn.scanIdentifier() 173 | case isDigit(ch): 174 | return tkn.scanNumber(false) 175 | case ch == ':': 176 | return tkn.scanBindVar() 177 | default: 178 | tkn.next() 179 | switch ch { 180 | case eofChar: 181 | return 0, nil 182 | case '=', ',', ';', '(', ')', '+', '*', '%', '&', '|', '^', '~': 183 | return int(ch), nil 184 | case '?': 185 | tkn.posVarIndex++ 186 | buf := new(bytes.Buffer) 187 | fmt.Fprintf(buf, ":v%d", tkn.posVarIndex) 188 | return VALUE_ARG, buf.Bytes() 189 | case '.': 190 | if isDigit(tkn.lastChar) { 191 | return tkn.scanNumber(true) 192 | } 193 | return int(ch), nil 194 | case '/': 195 | switch tkn.lastChar { 196 | case '/': 197 | tkn.next() 198 | return tkn.scanCommentType1("//") 199 | case '*': 200 | tkn.next() 201 | return 
tkn.scanCommentType2() 202 | default: 203 | return int(ch), nil 204 | } 205 | case '-': 206 | if tkn.lastChar == '-' { 207 | tkn.next() 208 | return tkn.scanCommentType1("--") 209 | } 210 | return int(ch), nil 211 | case '<': 212 | switch tkn.lastChar { 213 | case '>': 214 | tkn.next() 215 | return NE, nil 216 | case '<': 217 | tkn.next() 218 | return SHIFT_LEFT, nil 219 | case '=': 220 | tkn.next() 221 | switch tkn.lastChar { 222 | case '>': 223 | tkn.next() 224 | return NULL_SAFE_EQUAL, nil 225 | default: 226 | return LE, nil 227 | } 228 | default: 229 | return int(ch), nil 230 | } 231 | case '>': 232 | switch tkn.lastChar { 233 | case '=': 234 | tkn.next() 235 | return GE, nil 236 | case '>': 237 | tkn.next() 238 | return SHIFT_RIGHT, nil 239 | default: 240 | return int(ch), nil 241 | } 242 | case '!': 243 | if tkn.lastChar == '=' { 244 | tkn.next() 245 | return NE, nil 246 | } 247 | return LEX_ERROR, []byte("!") 248 | case '\'', '"': 249 | return tkn.scanString(ch, STRING) 250 | case '`': 251 | return tkn.scanLiteralIdentifier() 252 | default: 253 | return LEX_ERROR, []byte{byte(ch)} 254 | } 255 | } 256 | } 257 | 258 | func (tkn *Tokenizer) skipBlank() { 259 | ch := tkn.lastChar 260 | for ch == ' ' || ch == '\n' || ch == '\r' || ch == '\t' { 261 | tkn.next() 262 | ch = tkn.lastChar 263 | } 264 | } 265 | 266 | func (tkn *Tokenizer) scanIdentifier() (int, []byte) { 267 | buffer := &bytes.Buffer{} 268 | buffer.WriteByte(byte(tkn.lastChar)) 269 | for tkn.next(); isLetter(tkn.lastChar) || isDigit(tkn.lastChar); tkn.next() { 270 | buffer.WriteByte(byte(tkn.lastChar)) 271 | } 272 | lowered := bytes.ToLower(buffer.Bytes()) 273 | loweredStr := string(lowered) 274 | if keywordID, found := keywords[loweredStr]; found { 275 | return keywordID, lowered 276 | } 277 | // If we're at a SELECT, treat NEXT as a keyword. 
278 | if tkn.atSelect && loweredStr == "next" { 279 | return NEXT, lowered 280 | } 281 | // dual must always be case-insensitive 282 | if loweredStr == "dual" { 283 | return ID, lowered 284 | } 285 | return ID, buffer.Bytes() 286 | } 287 | 288 | func (tkn *Tokenizer) scanLiteralIdentifier() (int, []byte) { 289 | buffer := &bytes.Buffer{} 290 | buffer.WriteByte(byte(tkn.lastChar)) 291 | if !isLetter(tkn.lastChar) { 292 | return LEX_ERROR, buffer.Bytes() 293 | } 294 | for tkn.next(); isLetter(tkn.lastChar) || isDigit(tkn.lastChar); tkn.next() { 295 | buffer.WriteByte(byte(tkn.lastChar)) 296 | } 297 | if tkn.lastChar != '`' { 298 | return LEX_ERROR, buffer.Bytes() 299 | } 300 | tkn.next() 301 | return ID, buffer.Bytes() 302 | } 303 | 304 | func (tkn *Tokenizer) scanBindVar() (int, []byte) { 305 | buffer := &bytes.Buffer{} 306 | buffer.WriteByte(byte(tkn.lastChar)) 307 | token := VALUE_ARG 308 | tkn.next() 309 | if tkn.lastChar == ':' { 310 | token = LIST_ARG 311 | buffer.WriteByte(byte(tkn.lastChar)) 312 | tkn.next() 313 | } 314 | if !isLetter(tkn.lastChar) { 315 | return LEX_ERROR, buffer.Bytes() 316 | } 317 | for isLetter(tkn.lastChar) || isDigit(tkn.lastChar) || tkn.lastChar == '.' 
{ 318 | buffer.WriteByte(byte(tkn.lastChar)) 319 | tkn.next() 320 | } 321 | return token, buffer.Bytes() 322 | } 323 | 324 | func (tkn *Tokenizer) scanMantissa(base int, buffer *bytes.Buffer) { 325 | for digitVal(tkn.lastChar) < base { 326 | tkn.consumeNext(buffer) 327 | } 328 | } 329 | 330 | func (tkn *Tokenizer) scanNumber(seenDecimalPoint bool) (int, []byte) { 331 | buffer := &bytes.Buffer{} 332 | if seenDecimalPoint { 333 | buffer.WriteByte('.') 334 | tkn.scanMantissa(10, buffer) 335 | goto exponent 336 | } 337 | 338 | if tkn.lastChar == '0' { 339 | // int or float 340 | tkn.consumeNext(buffer) 341 | if tkn.lastChar == 'x' || tkn.lastChar == 'X' { 342 | // hexadecimal int 343 | tkn.consumeNext(buffer) 344 | tkn.scanMantissa(16, buffer) 345 | } else { 346 | // octal int or float 347 | seenDecimalDigit := false 348 | tkn.scanMantissa(8, buffer) 349 | if tkn.lastChar == '8' || tkn.lastChar == '9' { 350 | // illegal octal int or float 351 | seenDecimalDigit = true 352 | tkn.scanMantissa(10, buffer) 353 | } 354 | if tkn.lastChar == '.' || tkn.lastChar == 'e' || tkn.lastChar == 'E' { 355 | goto fraction 356 | } 357 | // octal int 358 | if seenDecimalDigit { 359 | return LEX_ERROR, buffer.Bytes() 360 | } 361 | } 362 | goto exit 363 | } 364 | 365 | // decimal int or float 366 | tkn.scanMantissa(10, buffer) 367 | 368 | fraction: 369 | if tkn.lastChar == '.' 
{ 370 | tkn.consumeNext(buffer) 371 | tkn.scanMantissa(10, buffer) 372 | } 373 | 374 | exponent: 375 | if tkn.lastChar == 'e' || tkn.lastChar == 'E' { 376 | tkn.consumeNext(buffer) 377 | if tkn.lastChar == '+' || tkn.lastChar == '-' { 378 | tkn.consumeNext(buffer) 379 | } 380 | tkn.scanMantissa(10, buffer) 381 | } 382 | 383 | exit: 384 | return NUMBER, buffer.Bytes() 385 | } 386 | 387 | func (tkn *Tokenizer) scanString(delim uint16, typ int) (int, []byte) { 388 | buffer := &bytes.Buffer{} 389 | for { 390 | ch := tkn.lastChar 391 | tkn.next() 392 | if ch == delim { 393 | if tkn.lastChar == delim { 394 | tkn.next() 395 | } else { 396 | break 397 | } 398 | } else if ch == '\\' { 399 | if tkn.lastChar == eofChar { 400 | return LEX_ERROR, buffer.Bytes() 401 | } 402 | if decodedChar := sqltypes.SQLDecodeMap[byte(tkn.lastChar)]; decodedChar == sqltypes.DontEscape { 403 | ch = tkn.lastChar 404 | } else { 405 | ch = uint16(decodedChar) 406 | } 407 | tkn.next() 408 | } 409 | if ch == eofChar { 410 | return LEX_ERROR, buffer.Bytes() 411 | } 412 | buffer.WriteByte(byte(ch)) 413 | } 414 | return typ, buffer.Bytes() 415 | } 416 | 417 | func (tkn *Tokenizer) scanCommentType1(prefix string) (int, []byte) { 418 | buffer := &bytes.Buffer{} 419 | buffer.WriteString(prefix) 420 | for tkn.lastChar != eofChar { 421 | if tkn.lastChar == '\n' { 422 | tkn.consumeNext(buffer) 423 | break 424 | } 425 | tkn.consumeNext(buffer) 426 | } 427 | return COMMENT, buffer.Bytes() 428 | } 429 | 430 | func (tkn *Tokenizer) scanCommentType2() (int, []byte) { 431 | buffer := &bytes.Buffer{} 432 | buffer.WriteString("/*") 433 | for { 434 | if tkn.lastChar == '*' { 435 | tkn.consumeNext(buffer) 436 | if tkn.lastChar == '/' { 437 | tkn.consumeNext(buffer) 438 | break 439 | } 440 | continue 441 | } 442 | if tkn.lastChar == eofChar { 443 | return LEX_ERROR, buffer.Bytes() 444 | } 445 | tkn.consumeNext(buffer) 446 | } 447 | return COMMENT, buffer.Bytes() 448 | } 449 | 450 | func (tkn *Tokenizer) 
consumeNext(buffer *bytes.Buffer) { 451 | if tkn.lastChar == eofChar { 452 | // This should never happen. 453 | panic("unexpected EOF") 454 | } 455 | buffer.WriteByte(byte(tkn.lastChar)) 456 | tkn.next() 457 | } 458 | 459 | func (tkn *Tokenizer) next() { 460 | if ch, err := tkn.InStream.ReadByte(); err != nil { 461 | // Only EOF is possible. 462 | tkn.lastChar = eofChar 463 | } else { 464 | tkn.lastChar = uint16(ch) 465 | } 466 | tkn.Position++ 467 | } 468 | 469 | func isLetter(ch uint16) bool { 470 | return 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z' || ch == '_' || ch == '@' 471 | } 472 | 473 | func digitVal(ch uint16) int { 474 | switch { 475 | case '0' <= ch && ch <= '9': 476 | return int(ch) - '0' 477 | case 'a' <= ch && ch <= 'f': 478 | return int(ch) - 'a' + 10 479 | case 'A' <= ch && ch <= 'F': 480 | return int(ch) - 'A' + 10 481 | } 482 | return 16 // larger than any legal digit val 483 | } 484 | 485 | func isDigit(ch uint16) bool { 486 | return '0' <= ch && ch <= '9' 487 | } 488 | -------------------------------------------------------------------------------- /internal/sqlparser/tracked_buffer.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package sqlparser 6 | 7 | import ( 8 | "bytes" 9 | "fmt" 10 | ) 11 | 12 | // TrackedBuffer is used to rebuild a query from the ast. 13 | // bindLocations keeps track of locations in the buffer that 14 | // use bind variables for efficient future substitutions. 15 | // nodeFormatter is the formatting function the buffer will 16 | // use to format a node. By default(nil), it's FormatNode. 17 | // But you can supply a different formatting function if you 18 | // want to generate a query that's different from the default. 
19 | type TrackedBuffer struct { 20 | *bytes.Buffer 21 | bindLocations []bindLocation 22 | nodeFormatter func(buf *TrackedBuffer, node SQLNode) 23 | } 24 | 25 | // NewTrackedBuffer creates a new TrackedBuffer. 26 | func NewTrackedBuffer(nodeFormatter func(buf *TrackedBuffer, node SQLNode)) *TrackedBuffer { 27 | buf := &TrackedBuffer{ 28 | Buffer: bytes.NewBuffer(make([]byte, 0, 128)), 29 | bindLocations: make([]bindLocation, 0, 4), 30 | nodeFormatter: nodeFormatter, 31 | } 32 | return buf 33 | } 34 | 35 | // Myprintf mimics fmt.Fprintf(buf, ...), but limited to Node(%v), 36 | // Node.Value(%s) and string(%s). It also allows a %a for a value argument, in 37 | // which case it adds tracking info for future substitutions. 38 | // 39 | // The name must be something other than the usual Printf() to avoid "go vet" 40 | // warnings due to our custom format specifiers. 41 | func (buf *TrackedBuffer) Myprintf(format string, values ...interface{}) { 42 | end := len(format) 43 | fieldnum := 0 44 | for i := 0; i < end; { 45 | lasti := i 46 | for i < end && format[i] != '%' { 47 | i++ 48 | } 49 | if i > lasti { 50 | buf.WriteString(format[lasti:i]) 51 | } 52 | if i >= end { 53 | break 54 | } 55 | i++ // '%' 56 | switch format[i] { 57 | case 'c': 58 | switch v := values[fieldnum].(type) { 59 | case byte: 60 | buf.WriteByte(v) 61 | case rune: 62 | buf.WriteRune(v) 63 | default: 64 | panic(fmt.Sprintf("unexpected type %T", v)) 65 | } 66 | case 's': 67 | switch v := values[fieldnum].(type) { 68 | case []byte: 69 | buf.Write(v) 70 | case string: 71 | buf.WriteString(v) 72 | default: 73 | panic(fmt.Sprintf("unexpected type %T", v)) 74 | } 75 | case 'v': 76 | node := values[fieldnum].(SQLNode) 77 | if buf.nodeFormatter == nil { 78 | node.Format(buf) 79 | } else { 80 | buf.nodeFormatter(buf, node) 81 | } 82 | case 'a': 83 | buf.WriteArg(values[fieldnum].(string)) 84 | default: 85 | panic("unexpected") 86 | } 87 | fieldnum++ 88 | i++ 89 | } 90 | } 91 | 92 | // WriteArg writes a value 
argument into the buffer along with 93 | // tracking information for future substitutions. arg must contain 94 | // the ":" or "::" prefix. 95 | func (buf *TrackedBuffer) WriteArg(arg string) { 96 | buf.bindLocations = append(buf.bindLocations, bindLocation{ 97 | offset: buf.Len(), 98 | length: len(arg), 99 | }) 100 | buf.WriteString(arg) 101 | } 102 | 103 | // ParsedQuery returns a ParsedQuery that contains bind 104 | // locations for easy substitution. 105 | func (buf *TrackedBuffer) ParsedQuery() *ParsedQuery { 106 | return &ParsedQuery{Query: buf.String(), bindLocations: buf.bindLocations} 107 | } 108 | 109 | // HasBindVars returns true if the parsed query uses bind vars. 110 | func (buf *TrackedBuffer) HasBindVars() bool { 111 | return len(buf.bindLocations) != 0 112 | } 113 | -------------------------------------------------------------------------------- /internal/sqltypes/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2012, Google Inc. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are 6 | met: 7 | 8 | * Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | * Redistributions in binary form must reproduce the above 11 | copyright notice, this list of conditions and the following disclaimer 12 | in the documentation and/or other materials provided with the 13 | distribution. 14 | * Neither the name of Google Inc. nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 
17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /internal/sqltypes/proto3.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package sqltypes 6 | 7 | import querypb "github.com/guregu/mogi/internal/proto/query" 8 | 9 | // This file contains the proto3 conversion functions for the structures 10 | // defined here. 11 | 12 | // RowsToProto3 converts [][]Value to proto3. 
13 | func RowsToProto3(rows [][]Value) []*querypb.Row { 14 | if len(rows) == 0 { 15 | return nil 16 | } 17 | 18 | result := make([]*querypb.Row, len(rows)) 19 | for i, r := range rows { 20 | row := &querypb.Row{} 21 | result[i] = row 22 | row.Lengths = make([]int64, 0, len(r)) 23 | total := 0 24 | for _, c := range r { 25 | if c.IsNull() { 26 | row.Lengths = append(row.Lengths, -1) 27 | continue 28 | } 29 | length := c.Len() 30 | row.Lengths = append(row.Lengths, int64(length)) 31 | total += length 32 | } 33 | row.Values = make([]byte, 0, total) 34 | for _, c := range r { 35 | if c.IsNull() { 36 | continue 37 | } 38 | row.Values = append(row.Values, c.Raw()...) 39 | } 40 | } 41 | return result 42 | } 43 | 44 | // proto3ToRows converts a proto3 rows to [][]Value. The function is private 45 | // because it uses the trusted API. 46 | func proto3ToRows(fields []*querypb.Field, rows []*querypb.Row) [][]Value { 47 | if len(rows) == 0 { 48 | // TODO(sougou): This is needed for backward compatibility. 49 | // Remove when it's not needed any more. 50 | return [][]Value{} 51 | } 52 | 53 | result := make([][]Value, len(rows)) 54 | for i, r := range rows { 55 | result[i] = MakeRowTrusted(fields, r) 56 | } 57 | return result 58 | } 59 | 60 | // ResultToProto3 converts Result to proto3. 61 | func ResultToProto3(qr *Result) *querypb.QueryResult { 62 | if qr == nil { 63 | return nil 64 | } 65 | return &querypb.QueryResult{ 66 | Fields: qr.Fields, 67 | RowsAffected: qr.RowsAffected, 68 | InsertId: qr.InsertID, 69 | Rows: RowsToProto3(qr.Rows), 70 | } 71 | } 72 | 73 | // Proto3ToResult converts a proto3 Result to an internal data structure. This function 74 | // should be used only if the field info is populated in qr. 
75 | func Proto3ToResult(qr *querypb.QueryResult) *Result { 76 | if qr == nil { 77 | return nil 78 | } 79 | return &Result{ 80 | Fields: qr.Fields, 81 | RowsAffected: qr.RowsAffected, 82 | InsertID: qr.InsertId, 83 | Rows: proto3ToRows(qr.Fields, qr.Rows), 84 | } 85 | } 86 | 87 | // CustomProto3ToResult converts a proto3 Result to an internal data structure. This function 88 | // takes a separate fields input because not all QueryResults contain the field info. 89 | // In particular, only the first packet of streaming queries contain the field info. 90 | func CustomProto3ToResult(fields []*querypb.Field, qr *querypb.QueryResult) *Result { 91 | if qr == nil { 92 | return nil 93 | } 94 | return &Result{ 95 | Fields: qr.Fields, 96 | RowsAffected: qr.RowsAffected, 97 | InsertID: qr.InsertId, 98 | Rows: proto3ToRows(fields, qr.Rows), 99 | } 100 | } 101 | 102 | // ResultsToProto3 converts []Result to proto3. 103 | func ResultsToProto3(qr []Result) []*querypb.QueryResult { 104 | if len(qr) == 0 { 105 | return nil 106 | } 107 | result := make([]*querypb.QueryResult, len(qr)) 108 | for i, q := range qr { 109 | result[i] = ResultToProto3(&q) 110 | } 111 | return result 112 | } 113 | 114 | // Proto3ToResults converts proto3 results to []Result. 115 | func Proto3ToResults(qr []*querypb.QueryResult) []Result { 116 | if len(qr) == 0 { 117 | return nil 118 | } 119 | result := make([]Result, len(qr)) 120 | for i, q := range qr { 121 | result[i] = *Proto3ToResult(q) 122 | } 123 | return result 124 | } 125 | -------------------------------------------------------------------------------- /internal/sqltypes/proto3_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 

package sqltypes

import (
	"reflect"
	"testing"

	querypb "github.com/guregu/mogi/internal/proto/query"
)

// TestResult checks that a Result converts to its proto3 form and back,
// including NULL columns, and that CustomProto3ToResult types the rows
// using the separately supplied fields.
func TestResult(t *testing.T) {
	fields := []*querypb.Field{{
		Name: "col1",
		Type: VarChar,
	}, {
		Name: "col2",
		Type: Int64,
	}, {
		Name: "col3",
		Type: Float64,
	}}
	sqlResult := &Result{
		Fields:       fields,
		InsertID:     1,
		RowsAffected: 2,
		Rows: [][]Value{{
			testVal(VarChar, "aa"),
			testVal(Int64, "1"),
			testVal(Float64, "2"),
		}, {
			MakeTrusted(VarChar, []byte("bb")),
			NULL,
			NULL,
		}},
	}
	// In the proto3 form a NULL column has length -1 and contributes
	// no bytes to Values.
	p3Result := &querypb.QueryResult{
		Fields:       fields,
		InsertId:     1,
		RowsAffected: 2,
		Rows: []*querypb.Row{{
			Lengths: []int64{2, 1, 1},
			Values:  []byte("aa12"),
		}, {
			Lengths: []int64{2, -1, -1},
			Values:  []byte("bb"),
		}},
	}
	p3converted := ResultToProto3(sqlResult)
	if !reflect.DeepEqual(p3converted, p3Result) {
		t.Errorf("P3:\n%v, want\n%v", p3converted, p3Result)
	}

	reverse := Proto3ToResult(p3Result)
	if !reflect.DeepEqual(reverse, sqlResult) {
		t.Errorf("reverse:\n%#v, want\n%#v", reverse, sqlResult)
	}

	// Test custom fields.
	fields[1].Type = VarBinary
	sqlResult.Rows[0][1] = testVal(VarBinary, "1")
	reverse = CustomProto3ToResult(fields, p3Result)
	if !reflect.DeepEqual(reverse, sqlResult) {
		t.Errorf("reverse:\n%#v, want\n%#v", reverse, sqlResult)
	}
}

// TestResults checks the slice variants of the conversions in both
// directions.
func TestResults(t *testing.T) {
	fields1 := []*querypb.Field{{
		Name: "col1",
		Type: VarChar,
	}, {
		Name: "col2",
		Type: Int64,
	}, {
		Name: "col3",
		Type: Float64,
	}}
	fields2 := []*querypb.Field{{
		Name: "col11",
		Type: VarChar,
	}, {
		Name: "col12",
		Type: Int64,
	}, {
		Name: "col13",
		Type: Float64,
	}}
	sqlResults := []Result{{
		Fields:       fields1,
		InsertID:     1,
		RowsAffected: 2,
		Rows: [][]Value{{
			testVal(VarChar, "aa"),
			testVal(Int64, "1"),
			testVal(Float64, "2"),
		}},
	}, {
		Fields:       fields2,
		InsertID:     3,
		RowsAffected: 4,
		Rows: [][]Value{{
			testVal(VarChar, "bb"),
			testVal(Int64, "3"),
			testVal(Float64, "4"),
		}},
	}}
	p3Results := []*querypb.QueryResult{{
		Fields:       fields1,
		InsertId:     1,
		RowsAffected: 2,
		Rows: []*querypb.Row{{
			Lengths: []int64{2, 1, 1},
			Values:  []byte("aa12"),
		}},
	}, {
		Fields:       fields2,
		InsertId:     3,
		RowsAffected: 4,
		Rows: []*querypb.Row{{
			Lengths: []int64{2, 1, 1},
			Values:  []byte("bb34"),
		}},
	}}
	p3converted := ResultsToProto3(sqlResults)
	if !reflect.DeepEqual(p3converted, p3Results) {
		t.Errorf("P3:\n%v, want\n%v", p3converted, p3Results)
	}

	reverse := Proto3ToResults(p3Results)
	if !reflect.DeepEqual(reverse, sqlResults) {
		t.Errorf("reverse:\n%#v, want\n%#v", reverse, sqlResults)
	}
}

--------------------------------------------------------------------------------
/internal/sqltypes/result.go:
--------------------------------------------------------------------------------
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package sqltypes

import querypb "github.com/guregu/mogi/internal/proto/query"

// Result represents a query result.
type Result struct {
	Fields       []*querypb.Field `json:"fields"`
	RowsAffected uint64           `json:"rows_affected"`
	InsertID     uint64           `json:"insert_id"`
	Rows         [][]Value        `json:"rows"`
}

// Repair fixes the type info in the rows
// to conform to the supplied field types.
// NULL values keep their Null type. Every row is indexed by field
// position, so each row must have at least len(fields) columns.
func (result *Result) Repair(fields []*querypb.Field) {
	// Usage of j is intentional.
	for j, f := range fields {
		for _, r := range result.Rows {
			if r[j].typ != Null {
				r[j].typ = f.Type
			}
		}
	}
}

// Copy creates a deep copy of Result.
// Field structs and value bytes are duplicated, so later changes to the
// receiver do not affect the copy.
func (result *Result) Copy() *Result {
	out := &Result{
		InsertID:     result.InsertID,
		RowsAffected: result.RowsAffected,
	}
	if result.Fields != nil {
		// Copy the Field structs into one backing slice and point
		// the new pointer slice at those copies.
		fieldsp := make([]*querypb.Field, len(result.Fields))
		fields := make([]querypb.Field, len(result.Fields))
		for i, f := range result.Fields {
			fields[i] = *f
			fieldsp[i] = &fields[i]
		}
		out.Fields = fieldsp
	}
	if result.Rows != nil {
		rows := make([][]Value, len(result.Rows))
		for i, r := range result.Rows {
			rows[i] = make([]Value, len(r))
			// Size one byte slice per row up front so that all of the
			// row's values can be packed into it without reallocation.
			totalLen := 0
			for _, c := range r {
				totalLen += len(c.val)
			}
			arena := make([]byte, 0, totalLen)
			for j, c := range r {
				start := len(arena)
				arena = append(arena, c.val...)
				rows[i][j] = MakeTrusted(c.typ, arena[start:start+len(c.val)])
			}
		}
		out.Rows = rows
	}
	return out
}

// MakeRowTrusted converts a *querypb.Row to []Value based on the types
// in fields. It does not sanity check the values against the type.
// Every place this function is called, a comment is needed that explains
// why it's justified.
func MakeRowTrusted(fields []*querypb.Field, row *querypb.Row) []Value {
	sqlRow := make([]Value, len(row.Lengths))
	var offset int64
	for i, length := range row.Lengths {
		// A negative length marks a NULL column: it occupies no bytes
		// in row.Values and the zero Value is left in place.
		if length < 0 {
			continue
		}
		sqlRow[i] = MakeTrusted(fields[i].Type, row.Values[offset:offset+length])
		offset += length
	}
	return sqlRow
}

--------------------------------------------------------------------------------
/internal/sqltypes/result_test.go:
--------------------------------------------------------------------------------
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package sqltypes

import (
	"reflect"
	"testing"

	querypb "github.com/guregu/mogi/internal/proto/query"
)

// TestRepair checks that Repair rewrites each value's type to the type
// of the field at the same position.
func TestRepair(t *testing.T) {
	fields := []*querypb.Field{{
		Type: Int64,
	}, {
		Type: VarChar,
	}}
	in := Result{
		Rows: [][]Value{
			{testVal(VarBinary, "1"), testVal(VarBinary, "aa")},
			{testVal(VarBinary, "2"), testVal(VarBinary, "bb")},
		},
	}
	want := Result{
		Rows: [][]Value{
			{testVal(Int64, "1"), testVal(VarChar, "aa")},
			{testVal(Int64, "2"), testVal(VarChar, "bb")},
		},
	}
	in.Repair(fields)
	if !reflect.DeepEqual(in, want) {
		t.Errorf("Repair:\n%#v, want\n%#v", in, want)
	}
}

// TestCopy checks that Copy returns a result that is unaffected by later
// changes to the original.
func TestCopy(t *testing.T) {
	in := &Result{
		Fields: []*querypb.Field{{
			Type: Int64,
		}, {
			Type: VarChar,
		}},
		InsertID:     1,
		RowsAffected: 2,
		Rows: [][]Value{
			{testVal(Int64, "1"), MakeTrusted(Null, nil)},
			{testVal(Int64, "2"), MakeTrusted(VarChar, nil)},
			{testVal(Int64, "3"), testVal(VarChar, "")},
		},
	}
	// Note the second row: a VarChar built from nil bytes is expected
	// to come back from Copy as an empty, non-nil VarChar.
	want := &Result{
		Fields: []*querypb.Field{{
			Type: Int64,
		}, {
			Type: VarChar,
		}},
		InsertID:     1,
		RowsAffected: 2,
		Rows: [][]Value{
			{testVal(Int64, "1"), MakeTrusted(Null, nil)},
			{testVal(Int64, "2"), testVal(VarChar, "")},
			{testVal(Int64, "3"), testVal(VarChar, "")},
		},
	}
	out := in.Copy()
	// Change in so we're sure out got actually copied
	in.Fields[0].Type = VarChar
	in.Rows[0][0] = testVal(VarChar, "aa")
	if !reflect.DeepEqual(out, want) {
		t.Errorf("Copy:\n%#v, want\n%#v", out, want)
	}
}

--------------------------------------------------------------------------------
/internal/sqltypes/type.go:
--------------------------------------------------------------------------------
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package sqltypes

import (
	"fmt"

	querypb "github.com/guregu/mogi/internal/proto/query"
)

// This file provides wrappers and support
// functions for querypb.Type.

// These bit flags can be used to query on the
// common properties of types.
const (
	flagIsIntegral = int(querypb.Flag_ISINTEGRAL)
	flagIsUnsigned = int(querypb.Flag_ISUNSIGNED)
	flagIsFloat    = int(querypb.Flag_ISFLOAT)
	flagIsQuoted   = int(querypb.Flag_ISQUOTED)
	flagIsText     = int(querypb.Flag_ISTEXT)
	flagIsBinary   = int(querypb.Flag_ISBINARY)
)

// IsIntegral returns true if querypb.Type is an integral
// (signed/unsigned) that can be represented using
// up to 64 binary bits.
func IsIntegral(t querypb.Type) bool {
	return int(t)&flagIsIntegral == flagIsIntegral
}

// IsSigned returns true if querypb.Type is a signed integral.
35 | func IsSigned(t querypb.Type) bool { 36 | return int(t)&(flagIsIntegral|flagIsUnsigned) == flagIsIntegral 37 | } 38 | 39 | // IsUnsigned returns true if querypb.Type is an unsigned integral. 40 | // Caution: this is not the same as !IsSigned. 41 | func IsUnsigned(t querypb.Type) bool { 42 | return int(t)&(flagIsIntegral|flagIsUnsigned) == flagIsIntegral|flagIsUnsigned 43 | } 44 | 45 | // IsFloat returns true is querypb.Type is a floating point. 46 | func IsFloat(t querypb.Type) bool { 47 | return int(t)&flagIsFloat == flagIsFloat 48 | } 49 | 50 | // IsQuoted returns true if querypb.Type is a quoted text or binary. 51 | func IsQuoted(t querypb.Type) bool { 52 | return int(t)&flagIsQuoted == flagIsQuoted 53 | } 54 | 55 | // IsText returns true if querypb.Type is a text. 56 | func IsText(t querypb.Type) bool { 57 | return int(t)&flagIsText == flagIsText 58 | } 59 | 60 | // IsBinary returns true if querypb.Type is a binary. 61 | func IsBinary(t querypb.Type) bool { 62 | return int(t)&flagIsBinary == flagIsBinary 63 | } 64 | 65 | // Vitess data types. These are idiomatically 66 | // named synonyms for the querypb.Type values. 
const (
	Null      = querypb.Type_NULL_TYPE
	Int8      = querypb.Type_INT8
	Uint8     = querypb.Type_UINT8
	Int16     = querypb.Type_INT16
	Uint16    = querypb.Type_UINT16
	Int24     = querypb.Type_INT24
	Uint24    = querypb.Type_UINT24
	Int32     = querypb.Type_INT32
	Uint32    = querypb.Type_UINT32
	Int64     = querypb.Type_INT64
	Uint64    = querypb.Type_UINT64
	Float32   = querypb.Type_FLOAT32
	Float64   = querypb.Type_FLOAT64
	Timestamp = querypb.Type_TIMESTAMP
	Date      = querypb.Type_DATE
	Time      = querypb.Type_TIME
	Datetime  = querypb.Type_DATETIME
	Year      = querypb.Type_YEAR
	Decimal   = querypb.Type_DECIMAL
	Text      = querypb.Type_TEXT
	Blob      = querypb.Type_BLOB
	VarChar   = querypb.Type_VARCHAR
	VarBinary = querypb.Type_VARBINARY
	Char      = querypb.Type_CHAR
	Binary    = querypb.Type_BINARY
	Bit       = querypb.Type_BIT
	Enum      = querypb.Type_ENUM
	Set       = querypb.Type_SET
	Tuple     = querypb.Type_TUPLE
)

// Bit-shift the mysql flags by two bytes so we
// can merge them with the mysql or vitess types.
const (
	mysqlUnsigned = 32 << 16
	mysqlBinary   = 128 << 16
	mysqlEnum     = 256 << 16
	mysqlSet      = 2048 << 16

	// relevantFlags is the union of the shifted flags above.
	relevantFlags = mysqlUnsigned |
		mysqlBinary |
		mysqlEnum |
		mysqlSet
)

// If you add to this map, make sure you add a test case
// in tabletserver/endtoend.
var mysqlToType = map[int64]querypb.Type{
	1:   Int8,
	2:   Int16,
	3:   Int32,
	4:   Float32,
	5:   Float64,
	6:   Null,
	7:   Timestamp,
	8:   Int64,
	9:   Int24,
	10:  Date,
	11:  Time,
	12:  Datetime,
	13:  Year,
	16:  Bit,
	246: Decimal,
	249: Text,
	250: Text,
	251: Text,
	252: Text,
	253: VarChar,
	254: Char,
}

// modifier maps a base type combined with a shifted mysql flag to the
// more specific type that the flag selects. It is consulted by
// MySQLToType after the base type has been looked up.
var modifier = map[int64]querypb.Type{
	int64(Int8) | mysqlUnsigned:  Uint8,
	int64(Int16) | mysqlUnsigned: Uint16,
	int64(Int32) | mysqlUnsigned: Uint32,
	int64(Int64) | mysqlUnsigned: Uint64,
	int64(Int24) | mysqlUnsigned: Uint24,
	int64(Text) | mysqlBinary:    Blob,
	int64(VarChar) | mysqlBinary: VarBinary,
	int64(Char) | mysqlBinary:    Binary,
	int64(Char) | mysqlEnum:      Enum,
	int64(Char) | mysqlSet:       Set,
}

// typeToMySQL is the reverse of mysqlToType.
// The flags are stored in their shifted form; TypeToMySQL shifts them back.
var typeToMySQL = map[querypb.Type]struct {
	typ   int64
	flags int64
}{
	Int8:      {typ: 1},
	Uint8:     {typ: 1, flags: mysqlUnsigned},
	Int16:     {typ: 2},
	Uint16:    {typ: 2, flags: mysqlUnsigned},
	Int32:     {typ: 3},
	Uint32:    {typ: 3, flags: mysqlUnsigned},
	Float32:   {typ: 4},
	Float64:   {typ: 5},
	Null:      {typ: 6, flags: mysqlBinary},
	Timestamp: {typ: 7},
	Int64:     {typ: 8},
	Uint64:    {typ: 8, flags: mysqlUnsigned},
	Int24:     {typ: 9},
	Uint24:    {typ: 9, flags: mysqlUnsigned},
	Date:      {typ: 10, flags: mysqlBinary},
	Time:      {typ: 11, flags: mysqlBinary},
	Datetime:  {typ: 12, flags: mysqlBinary},
	Year:      {typ: 13, flags: mysqlUnsigned},
	Bit:       {typ: 16, flags: mysqlUnsigned},
	Decimal:   {typ: 246},
	Text:      {typ: 252},
	Blob:      {typ: 252, flags: mysqlBinary},
	VarChar:   {typ: 253},
	VarBinary: {typ: 253, flags: mysqlBinary},
	Char:      {typ: 254},
	Binary:    {typ: 254, flags: mysqlBinary},
	Enum:      {typ: 254, flags: mysqlEnum},
	Set:       {typ: 254, flags: mysqlSet},
}

// MySQLToType computes the vitess type from mysql type and flags.
// The function panics if the type is unrecognized.
func MySQLToType(mysqlType, flags int64) querypb.Type {
	result, ok := mysqlToType[mysqlType]
	if !ok {
		panic(fmt.Errorf("Could not map: %d to a vitess type", mysqlType))
	}

	// Shift the flags into the position used by the modifier keys and
	// drop every flag that cannot change the type.
	converted := (flags << 16) & relevantFlags
	modified, ok := modifier[int64(result)|converted]
	if ok {
		return modified
	}
	return result
}

// TypeToMySQL returns the equivalent mysql type and flag for a vitess type.
// A type that is missing from typeToMySQL yields 0, 0.
func TypeToMySQL(typ querypb.Type) (mysqlType, flags int64) {
	val := typeToMySQL[typ]
	return val.typ, val.flags >> 16
}

--------------------------------------------------------------------------------
/internal/sqltypes/type_test.go:
--------------------------------------------------------------------------------
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4 | 5 | package sqltypes 6 | 7 | import ( 8 | "testing" 9 | 10 | querypb "github.com/guregu/mogi/internal/proto/query" 11 | ) 12 | 13 | func TestTypeValues(t *testing.T) { 14 | testcases := []struct { 15 | defined querypb.Type 16 | expected int 17 | }{{ 18 | defined: Null, 19 | expected: 0, 20 | }, { 21 | defined: Int8, 22 | expected: 1 | flagIsIntegral, 23 | }, { 24 | defined: Uint8, 25 | expected: 2 | flagIsIntegral | flagIsUnsigned, 26 | }, { 27 | defined: Int16, 28 | expected: 3 | flagIsIntegral, 29 | }, { 30 | defined: Uint16, 31 | expected: 4 | flagIsIntegral | flagIsUnsigned, 32 | }, { 33 | defined: Int24, 34 | expected: 5 | flagIsIntegral, 35 | }, { 36 | defined: Uint24, 37 | expected: 6 | flagIsIntegral | flagIsUnsigned, 38 | }, { 39 | defined: Int32, 40 | expected: 7 | flagIsIntegral, 41 | }, { 42 | defined: Uint32, 43 | expected: 8 | flagIsIntegral | flagIsUnsigned, 44 | }, { 45 | defined: Int64, 46 | expected: 9 | flagIsIntegral, 47 | }, { 48 | defined: Uint64, 49 | expected: 10 | flagIsIntegral | flagIsUnsigned, 50 | }, { 51 | defined: Float32, 52 | expected: 11 | flagIsFloat, 53 | }, { 54 | defined: Float64, 55 | expected: 12 | flagIsFloat, 56 | }, { 57 | defined: Timestamp, 58 | expected: 13 | flagIsQuoted, 59 | }, { 60 | defined: Date, 61 | expected: 14 | flagIsQuoted, 62 | }, { 63 | defined: Time, 64 | expected: 15 | flagIsQuoted, 65 | }, { 66 | defined: Datetime, 67 | expected: 16 | flagIsQuoted, 68 | }, { 69 | defined: Year, 70 | expected: 17 | flagIsIntegral | flagIsUnsigned, 71 | }, { 72 | defined: Decimal, 73 | expected: 18, 74 | }, { 75 | defined: Text, 76 | expected: 19 | flagIsQuoted | flagIsText, 77 | }, { 78 | defined: Blob, 79 | expected: 20 | flagIsQuoted | flagIsBinary, 80 | }, { 81 | defined: VarChar, 82 | expected: 21 | flagIsQuoted | flagIsText, 83 | }, { 84 | defined: VarBinary, 85 | expected: 22 | flagIsQuoted | flagIsBinary, 86 | }, { 87 | defined: Char, 88 | expected: 23 | flagIsQuoted | flagIsText, 89 | }, { 90 | defined: 
Binary, 91 | expected: 24 | flagIsQuoted | flagIsBinary, 92 | }, { 93 | defined: Bit, 94 | expected: 25 | flagIsQuoted, 95 | }, { 96 | defined: Enum, 97 | expected: 26 | flagIsQuoted, 98 | }, { 99 | defined: Set, 100 | expected: 27 | flagIsQuoted, 101 | }, { 102 | defined: Tuple, 103 | expected: 28, 104 | }} 105 | for _, tcase := range testcases { 106 | if int(tcase.defined) != tcase.expected { 107 | t.Errorf("Type %s: %d, want: %d", tcase.defined, int(tcase.defined), tcase.expected) 108 | } 109 | } 110 | } 111 | 112 | func TestIsFunctions(t *testing.T) { 113 | if IsIntegral(Null) { 114 | t.Error("Null: IsIntegral, must be false") 115 | } 116 | if !IsIntegral(Int64) { 117 | t.Error("Int64: !IsIntegral, must be true") 118 | } 119 | if IsSigned(Uint64) { 120 | t.Error("Uint64: IsSigned, must be false") 121 | } 122 | if !IsSigned(Int64) { 123 | t.Error("Int64: !IsSigned, must be true") 124 | } 125 | if IsUnsigned(Int64) { 126 | t.Error("Int64: IsUnsigned, must be false") 127 | } 128 | if !IsUnsigned(Uint64) { 129 | t.Error("Uint64: !IsUnsigned, must be true") 130 | } 131 | if IsFloat(Int64) { 132 | t.Error("Int64: IsFloat, must be false") 133 | } 134 | if !IsFloat(Float64) { 135 | t.Error("Uint64: !IsFloat, must be true") 136 | } 137 | if IsQuoted(Int64) { 138 | t.Error("Int64: IsQuoted, must be false") 139 | } 140 | if !IsQuoted(Binary) { 141 | t.Error("Binary: !IsQuoted, must be true") 142 | } 143 | if IsText(Int64) { 144 | t.Error("Int64: IsText, must be false") 145 | } 146 | if !IsText(Char) { 147 | t.Error("Char: !IsText, must be true") 148 | } 149 | if IsBinary(Int64) { 150 | t.Error("Int64: IsBinary, must be false") 151 | } 152 | if !IsBinary(Binary) { 153 | t.Error("Char: !IsBinary, must be true") 154 | } 155 | } 156 | 157 | func TestTypeToMySQL(t *testing.T) { 158 | v, f := TypeToMySQL(Bit) 159 | if v != 16 { 160 | t.Errorf("Bit: %d, want 16", v) 161 | } 162 | if f != mysqlUnsigned>>16 { 163 | t.Errorf("Bit flag: %x, want %x", f, mysqlUnsigned>>16) 164 | } 
165 | v, f = TypeToMySQL(Date) 166 | if v != 10 { 167 | t.Errorf("Bit: %d, want 10", v) 168 | } 169 | if f != mysqlBinary>>16 { 170 | t.Errorf("Bit flag: %x, want %x", f, mysqlBinary>>16) 171 | } 172 | } 173 | 174 | func TestTypeFlexibility(t *testing.T) { 175 | v := MySQLToType(1, mysqlBinary>>16) 176 | if v != Int8 { 177 | t.Errorf("conversion: %v, want %v", v, Int8) 178 | } 179 | var typ int64 180 | for typ = 249; typ <= 252; typ++ { 181 | v = MySQLToType(typ, mysqlBinary>>16) 182 | if v != Blob { 183 | t.Errorf("conversion: %v, want %v", v, Blob) 184 | } 185 | } 186 | } 187 | 188 | func TestTypePanic(t *testing.T) { 189 | defer func() { 190 | err := recover().(error) 191 | want := "Could not map: 15 to a vitess type" 192 | if err == nil || err.Error() != want { 193 | t.Errorf("Error: %v, want %v", err, want) 194 | } 195 | }() 196 | _ = MySQLToType(15, 0) 197 | } 198 | -------------------------------------------------------------------------------- /internal/sqltypes/value.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package sqltypes implements interfaces and types that represent SQL values. 6 | package sqltypes 7 | 8 | import ( 9 | "encoding/base64" 10 | "encoding/json" 11 | "errors" 12 | "fmt" 13 | "strconv" 14 | "time" 15 | 16 | querypb "github.com/guregu/mogi/internal/proto/query" 17 | "github.com/youtube/vitess/go/hack" 18 | ) 19 | 20 | var ( 21 | // NULL represents the NULL value. 22 | NULL = Value{} 23 | // DontEscape tells you if a character should not be escaped. 24 | DontEscape = byte(255) 25 | nullstr = []byte("null") 26 | ) 27 | 28 | // BinWriter interface is used for encoding values. 29 | // Types like bytes.Buffer conform to this interface. 30 | // We expect the writer objects to be in-memory buffers. 
31 | // So, we don't expect the write operations to fail. 32 | type BinWriter interface { 33 | Write([]byte) (int, error) 34 | WriteByte(byte) error 35 | } 36 | 37 | // Value can store any SQL value. If the value represents 38 | // an integral type, the bytes are always stored as a cannonical 39 | // representation that matches how MySQL returns such values. 40 | type Value struct { 41 | typ querypb.Type 42 | val []byte 43 | } 44 | 45 | // MakeTrusted makes a new Value based on the type. 46 | // If the value is an integral, then val must be in its cannonical 47 | // form. This function should only be used if you know the value 48 | // and type conform to the rules. Every place this function is 49 | // called, a comment is needed that explains why it's justified. 50 | // Functions within this package are exempt. 51 | func MakeTrusted(typ querypb.Type, val []byte) Value { 52 | if typ == Null { 53 | return NULL 54 | } 55 | return Value{typ: typ, val: val} 56 | } 57 | 58 | // MakeString makes a VarBinary Value. 59 | func MakeString(val []byte) Value { 60 | return MakeTrusted(VarBinary, val) 61 | } 62 | 63 | // BuildValue builds a value from any go type. sqltype.Value is 64 | // also allowed. 65 | func BuildValue(goval interface{}) (v Value, err error) { 66 | // Look for the most common types first. 
67 | switch goval := goval.(type) { 68 | case nil: 69 | // no op 70 | case []byte: 71 | v = MakeTrusted(VarBinary, goval) 72 | case int64: 73 | v = MakeTrusted(Int64, strconv.AppendInt(nil, int64(goval), 10)) 74 | case uint64: 75 | v = MakeTrusted(Uint64, strconv.AppendUint(nil, uint64(goval), 10)) 76 | case float64: 77 | v = MakeTrusted(Float64, strconv.AppendFloat(nil, goval, 'f', -1, 64)) 78 | case int: 79 | v = MakeTrusted(Int64, strconv.AppendInt(nil, int64(goval), 10)) 80 | case int8: 81 | v = MakeTrusted(Int8, strconv.AppendInt(nil, int64(goval), 10)) 82 | case int16: 83 | v = MakeTrusted(Int16, strconv.AppendInt(nil, int64(goval), 10)) 84 | case int32: 85 | v = MakeTrusted(Int32, strconv.AppendInt(nil, int64(goval), 10)) 86 | case uint: 87 | v = MakeTrusted(Uint64, strconv.AppendUint(nil, uint64(goval), 10)) 88 | case uint8: 89 | v = MakeTrusted(Uint8, strconv.AppendUint(nil, uint64(goval), 10)) 90 | case uint16: 91 | v = MakeTrusted(Uint16, strconv.AppendUint(nil, uint64(goval), 10)) 92 | case uint32: 93 | v = MakeTrusted(Uint32, strconv.AppendUint(nil, uint64(goval), 10)) 94 | case float32: 95 | v = MakeTrusted(Float32, strconv.AppendFloat(nil, float64(goval), 'f', -1, 64)) 96 | case string: 97 | v = MakeTrusted(VarBinary, []byte(goval)) 98 | case time.Time: 99 | v = MakeTrusted(Datetime, []byte(goval.Format("2006-01-02 15:04:05"))) 100 | case Value: 101 | v = goval 102 | default: 103 | return v, fmt.Errorf("unexpected type %T: %v", goval, goval) 104 | } 105 | return v, nil 106 | } 107 | 108 | // BuildConverted is like BuildValue except that it tries to 109 | // convert a string or []byte to an integral if the target type 110 | // is an integral. We don't perform other implicit conversions 111 | // because they're unsafe. 
112 | func BuildConverted(typ querypb.Type, goval interface{}) (v Value, err error) { 113 | if IsIntegral(typ) { 114 | switch goval := goval.(type) { 115 | case []byte: 116 | return ValueFromBytes(typ, goval) 117 | case string: 118 | return ValueFromBytes(typ, []byte(goval)) 119 | case Value: 120 | if goval.IsQuoted() { 121 | return ValueFromBytes(typ, goval.Raw()) 122 | } 123 | } 124 | } 125 | return BuildValue(goval) 126 | } 127 | 128 | // ValueFromBytes builds a Value using typ and val. It ensures that val 129 | // matches the requested type. If type is an integral it's converted to 130 | // a cannonical form. Otherwise, the original representation is preserved. 131 | func ValueFromBytes(typ querypb.Type, val []byte) (v Value, err error) { 132 | switch { 133 | case IsSigned(typ): 134 | signed, err := strconv.ParseInt(string(val), 0, 64) 135 | if err != nil { 136 | return NULL, err 137 | } 138 | v = MakeTrusted(typ, strconv.AppendInt(nil, signed, 10)) 139 | case IsUnsigned(typ): 140 | unsigned, err := strconv.ParseUint(string(val), 0, 64) 141 | if err != nil { 142 | return NULL, err 143 | } 144 | v = MakeTrusted(typ, strconv.AppendUint(nil, unsigned, 10)) 145 | case typ == Tuple: 146 | return NULL, errors.New("tuple not allowed for ValueFromBytes") 147 | case IsFloat(typ) || typ == Decimal: 148 | _, err := strconv.ParseFloat(string(val), 64) 149 | if err != nil { 150 | return NULL, err 151 | } 152 | // After verification, we preserve the original representation. 153 | fallthrough 154 | default: 155 | v = MakeTrusted(typ, val) 156 | } 157 | return v, nil 158 | } 159 | 160 | // BuildIntegral builds an integral type from a string representaion. 161 | // The type will be Int64 or Uint64. Int64 will be preferred where possible. 
162 | func BuildIntegral(val string) (n Value, err error) { 163 | signed, err := strconv.ParseInt(val, 0, 64) 164 | if err == nil { 165 | return MakeTrusted(Int64, strconv.AppendInt(nil, signed, 10)), nil 166 | } 167 | unsigned, err := strconv.ParseUint(val, 0, 64) 168 | if err != nil { 169 | return Value{}, err 170 | } 171 | return MakeTrusted(Uint64, strconv.AppendUint(nil, unsigned, 10)), nil 172 | } 173 | 174 | // Type returns the type of Value. 175 | func (v Value) Type() querypb.Type { 176 | return v.typ 177 | } 178 | 179 | // Raw returns the raw bytes. All types are currently implemented as []byte. 180 | // You should avoid using this function. If you do, you should treat the 181 | // bytes as read-only. 182 | func (v Value) Raw() []byte { 183 | return v.val 184 | } 185 | 186 | // Len returns the length. 187 | func (v Value) Len() int { 188 | return len(v.val) 189 | } 190 | 191 | // String returns the raw value as a string. 192 | func (v Value) String() string { 193 | return hack.String(v.val) 194 | } 195 | 196 | // ToNative converts Value to a native go type. 197 | // This does not work for sqltypes.Tuple. The function 198 | // panics if there are inconsistencies. 199 | func (v Value) ToNative() interface{} { 200 | var out interface{} 201 | var err error 202 | switch { 203 | case v.typ == Null: 204 | // no-op 205 | case IsSigned(v.typ): 206 | out, err = v.ParseInt64() 207 | case IsUnsigned(v.typ): 208 | out, err = v.ParseUint64() 209 | case IsFloat(v.typ): 210 | out, err = v.ParseFloat64() 211 | case v.typ == Tuple: 212 | err = errors.New("unexpected tuple") 213 | default: 214 | out = v.val 215 | } 216 | if err != nil { 217 | panic(err) 218 | } 219 | return out 220 | } 221 | 222 | // ParseInt64 will parse a Value into an int64. It does 223 | // not check the type. 224 | func (v Value) ParseInt64() (val int64, err error) { 225 | return strconv.ParseInt(v.String(), 10, 64) 226 | } 227 | 228 | // ParseUint64 will parse a Value into a uint64. 
It does 229 | // not check the type. 230 | func (v Value) ParseUint64() (val uint64, err error) { 231 | return strconv.ParseUint(v.String(), 10, 64) 232 | } 233 | 234 | // ParseFloat64 will parse a Value into an float64. It does 235 | // not check the type. 236 | func (v Value) ParseFloat64() (val float64, err error) { 237 | return strconv.ParseFloat(v.String(), 64) 238 | } 239 | 240 | // EncodeSQL encodes the value into an SQL statement. Can be binary. 241 | func (v Value) EncodeSQL(b BinWriter) { 242 | // ToNative panics if v is invalid. 243 | _ = v.ToNative() 244 | switch { 245 | case v.typ == Null: 246 | writebytes(nullstr, b) 247 | case IsQuoted(v.typ): 248 | encodeBytesSQL(v.val, b) 249 | default: 250 | writebytes(v.val, b) 251 | } 252 | } 253 | 254 | // EncodeASCII encodes the value using 7-bit clean ascii bytes. 255 | func (v Value) EncodeASCII(b BinWriter) { 256 | // ToNative panics if v is invalid. 257 | _ = v.ToNative() 258 | switch { 259 | case v.typ == Null: 260 | writebytes(nullstr, b) 261 | case IsQuoted(v.typ): 262 | encodeBytesASCII(v.val, b) 263 | default: 264 | writebytes(v.val, b) 265 | } 266 | } 267 | 268 | // IsNull returns true if Value is null. 269 | func (v Value) IsNull() bool { 270 | return v.typ == Null 271 | } 272 | 273 | // IsIntegral returns true if Value is an integral. 274 | func (v Value) IsIntegral() bool { 275 | return IsIntegral(v.typ) 276 | } 277 | 278 | // IsSigned returns true if Value is a signed integral. 279 | func (v Value) IsSigned() bool { 280 | return IsSigned(v.typ) 281 | } 282 | 283 | // IsUnsigned returns true if Value is an unsigned integral. 284 | func (v Value) IsUnsigned() bool { 285 | return IsUnsigned(v.typ) 286 | } 287 | 288 | // IsFloat returns true if Value is a float. 289 | func (v Value) IsFloat() bool { 290 | return IsFloat(v.typ) 291 | } 292 | 293 | // IsQuoted returns true if Value must be SQL-quoted. 
294 | func (v Value) IsQuoted() bool { 295 | return IsQuoted(v.typ) 296 | } 297 | 298 | // IsText returns true if Value is a collatable text. 299 | func (v Value) IsText() bool { 300 | return IsText(v.typ) 301 | } 302 | 303 | // IsBinary returns true if Value is binary. 304 | func (v Value) IsBinary() bool { 305 | return IsBinary(v.typ) 306 | } 307 | 308 | // MarshalJSON should only be used for testing. 309 | // It's not a complete implementation. 310 | func (v Value) MarshalJSON() ([]byte, error) { 311 | switch { 312 | case v.IsQuoted(): 313 | return json.Marshal(v.String()) 314 | case v.typ == Null: 315 | return nullstr, nil 316 | } 317 | return v.val, nil 318 | } 319 | 320 | // UnmarshalJSON should only be used for testing. 321 | // It's not a complete implementation. 322 | func (v *Value) UnmarshalJSON(b []byte) error { 323 | if len(b) == 0 { 324 | return fmt.Errorf("error unmarshaling empty bytes") 325 | } 326 | var val interface{} 327 | var err error 328 | switch b[0] { 329 | case '-': 330 | var ival int64 331 | err = json.Unmarshal(b, &ival) 332 | val = ival 333 | case '"': 334 | var bval []byte 335 | err = json.Unmarshal(b, &bval) 336 | val = bval 337 | case 'n': // null 338 | err = json.Unmarshal(b, &val) 339 | default: 340 | var uval uint64 341 | err = json.Unmarshal(b, &uval) 342 | val = uval 343 | } 344 | if err != nil { 345 | return err 346 | } 347 | *v, err = BuildValue(val) 348 | return err 349 | } 350 | 351 | func encodeBytesSQL(val []byte, b BinWriter) { 352 | writebyte('\'', b) 353 | for _, ch := range val { 354 | if encodedChar := SQLEncodeMap[ch]; encodedChar == DontEscape { 355 | writebyte(ch, b) 356 | } else { 357 | writebyte('\\', b) 358 | writebyte(encodedChar, b) 359 | } 360 | } 361 | writebyte('\'', b) 362 | } 363 | 364 | func encodeBytesASCII(val []byte, b BinWriter) { 365 | writebyte('\'', b) 366 | encoder := base64.NewEncoder(base64.StdEncoding, b) 367 | encoder.Write(val) 368 | encoder.Close() 369 | writebyte('\'', b) 370 | } 371 | 372 
| func writebyte(c byte, b BinWriter) { 373 | if err := b.WriteByte(c); err != nil { 374 | panic(err) 375 | } 376 | } 377 | 378 | func writebytes(val []byte, b BinWriter) { 379 | n, err := b.Write(val) 380 | if err != nil { 381 | panic(err) 382 | } 383 | if n != len(val) { 384 | panic(errors.New("short write")) 385 | } 386 | } 387 | 388 | // SQLEncodeMap specifies how to escape binary data with '\'. 389 | // Complies to http://dev.mysql.com/doc/refman/5.1/en/string-syntax.html 390 | var SQLEncodeMap [256]byte 391 | 392 | // SQLDecodeMap is the reverse of SQLEncodeMap 393 | var SQLDecodeMap [256]byte 394 | 395 | var encodeRef = map[byte]byte{ 396 | '\x00': '0', 397 | '\'': '\'', 398 | '"': '"', 399 | '\b': 'b', 400 | '\n': 'n', 401 | '\r': 'r', 402 | '\t': 't', 403 | 26: 'Z', // ctl-Z 404 | '\\': '\\', 405 | } 406 | 407 | func init() { 408 | for i := range SQLEncodeMap { 409 | SQLEncodeMap[i] = DontEscape 410 | SQLDecodeMap[i] = DontEscape 411 | } 412 | for i := range SQLEncodeMap { 413 | if to, ok := encodeRef[byte(i)]; ok { 414 | SQLEncodeMap[byte(i)] = to 415 | SQLDecodeMap[to] = byte(i) 416 | } 417 | } 418 | } 419 | -------------------------------------------------------------------------------- /internal/sqltypes/value_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012, Google Inc. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package sqltypes 6 | 7 | import ( 8 | "bytes" 9 | "reflect" 10 | "strings" 11 | "testing" 12 | "time" 13 | 14 | querypb "github.com/guregu/mogi/internal/proto/query" 15 | ) 16 | 17 | func TestMake(t *testing.T) { 18 | v := MakeTrusted(Null, []byte("abcd")) 19 | if !reflect.DeepEqual(v, NULL) { 20 | t.Errorf("MakeTrusted(Null...) 
= %v, want null", makePretty(v)) 21 | } 22 | v = MakeTrusted(Int64, []byte("1")) 23 | want := testVal(Int64, "1") 24 | if !reflect.DeepEqual(v, want) { 25 | t.Errorf("MakeTrusted(Int64, \"1\") = %v, want %v", makePretty(v), makePretty(want)) 26 | } 27 | v = MakeString([]byte("a")) 28 | want = testVal(VarBinary, "a") 29 | if !reflect.DeepEqual(v, want) { 30 | t.Errorf("MakeString(\"a\") = %v, want %v", makePretty(v), makePretty(want)) 31 | } 32 | } 33 | 34 | func TestBuildValue(t *testing.T) { 35 | testcases := []struct { 36 | in interface{} 37 | out Value 38 | }{{ 39 | in: nil, 40 | out: NULL, 41 | }, { 42 | in: []byte("a"), 43 | out: testVal(VarBinary, "a"), 44 | }, { 45 | in: int64(1), 46 | out: testVal(Int64, "1"), 47 | }, { 48 | in: uint64(1), 49 | out: testVal(Uint64, "1"), 50 | }, { 51 | in: float64(1.2), 52 | out: testVal(Float64, "1.2"), 53 | }, { 54 | in: int(1), 55 | out: testVal(Int64, "1"), 56 | }, { 57 | in: int8(1), 58 | out: testVal(Int8, "1"), 59 | }, { 60 | in: int16(1), 61 | out: testVal(Int16, "1"), 62 | }, { 63 | in: int32(1), 64 | out: testVal(Int32, "1"), 65 | }, { 66 | in: uint(1), 67 | out: testVal(Uint64, "1"), 68 | }, { 69 | in: uint8(1), 70 | out: testVal(Uint8, "1"), 71 | }, { 72 | in: uint16(1), 73 | out: testVal(Uint16, "1"), 74 | }, { 75 | in: uint32(1), 76 | out: testVal(Uint32, "1"), 77 | }, { 78 | in: float32(1), 79 | out: testVal(Float32, "1"), 80 | }, { 81 | in: "a", 82 | out: testVal(VarBinary, "a"), 83 | }, { 84 | in: time.Date(2012, time.February, 24, 23, 19, 43, 10, time.UTC), 85 | out: testVal(Datetime, "2012-02-24 23:19:43"), 86 | }, { 87 | in: testVal(VarBinary, "a"), 88 | out: testVal(VarBinary, "a"), 89 | }} 90 | for _, tcase := range testcases { 91 | v, err := BuildValue(tcase.in) 92 | if err != nil { 93 | t.Errorf("BuildValue(%#v) error: %v", tcase.in, err) 94 | continue 95 | } 96 | if !reflect.DeepEqual(v, tcase.out) { 97 | t.Errorf("BuildValue(%#v) = %v, want %v", tcase.in, makePretty(v), makePretty(tcase.out)) 98 | 
} 99 | } 100 | 101 | _, err := BuildValue(make(chan bool)) 102 | want := "unexpected" 103 | if err == nil || !strings.Contains(err.Error(), want) { 104 | t.Errorf("BuildValue(chan): %v, want %v", err, want) 105 | } 106 | } 107 | 108 | func TestBuildConverted(t *testing.T) { 109 | testcases := []struct { 110 | typ querypb.Type 111 | val interface{} 112 | out Value 113 | }{{ 114 | typ: Int64, 115 | val: 123, 116 | out: testVal(Int64, "123"), 117 | }, { 118 | typ: Int64, 119 | val: "123", 120 | out: testVal(Int64, "123"), 121 | }, { 122 | typ: Uint64, 123 | val: "123", 124 | out: testVal(Uint64, "123"), 125 | }, { 126 | typ: Int64, 127 | val: []byte("123"), 128 | out: testVal(Int64, "123"), 129 | }, { 130 | typ: Int64, 131 | val: testVal(VarBinary, "123"), 132 | out: testVal(Int64, "123"), 133 | }, { 134 | typ: Int64, 135 | val: testVal(Float32, "123"), 136 | out: testVal(Float32, "123"), 137 | }} 138 | for _, tcase := range testcases { 139 | v, err := BuildConverted(tcase.typ, tcase.val) 140 | if err != nil { 141 | t.Errorf("BuildValue(%v, %#v) error: %v", tcase.typ, tcase.val, err) 142 | continue 143 | } 144 | if !reflect.DeepEqual(v, tcase.out) { 145 | t.Errorf("BuildValue(%v, %#v) = %v, want %v", tcase.typ, tcase.val, makePretty(v), makePretty(tcase.out)) 146 | } 147 | } 148 | } 149 | 150 | const ( 151 | InvalidNeg = "-9223372036854775809" 152 | MinNeg = "-9223372036854775808" 153 | MinPos = "18446744073709551615" 154 | InvalidPos = "18446744073709551616" 155 | ) 156 | 157 | func TestValueFromBytes(t *testing.T) { 158 | testcases := []struct { 159 | inType querypb.Type 160 | inVal string 161 | outVal Value 162 | outErr string 163 | }{{ 164 | inType: Null, 165 | inVal: "", 166 | outVal: NULL, 167 | }, { 168 | inType: Int8, 169 | inVal: "1", 170 | outVal: testVal(Int8, "1"), 171 | }, { 172 | inType: Int16, 173 | inVal: "1", 174 | outVal: testVal(Int16, "1"), 175 | }, { 176 | inType: Int24, 177 | inVal: "1", 178 | outVal: testVal(Int24, "1"), 179 | }, { 180 | inType: 
Int32, 181 | inVal: "1", 182 | outVal: testVal(Int32, "1"), 183 | }, { 184 | inType: Int64, 185 | inVal: "1", 186 | outVal: testVal(Int64, "1"), 187 | }, { 188 | inType: Uint8, 189 | inVal: "1", 190 | outVal: testVal(Uint8, "1"), 191 | }, { 192 | inType: Uint16, 193 | inVal: "1", 194 | outVal: testVal(Uint16, "1"), 195 | }, { 196 | inType: Uint24, 197 | inVal: "1", 198 | outVal: testVal(Uint24, "1"), 199 | }, { 200 | inType: Uint32, 201 | inVal: "1", 202 | outVal: testVal(Uint32, "1"), 203 | }, { 204 | inType: Uint64, 205 | inVal: "1", 206 | outVal: testVal(Uint64, "1"), 207 | }, { 208 | inType: Float32, 209 | inVal: "1.00", 210 | outVal: testVal(Float32, "1.00"), 211 | }, { 212 | inType: Float64, 213 | inVal: "1.00", 214 | outVal: testVal(Float64, "1.00"), 215 | }, { 216 | inType: Decimal, 217 | inVal: "1.00", 218 | outVal: testVal(Decimal, "1.00"), 219 | }, { 220 | inType: Timestamp, 221 | inVal: "2012-02-24 23:19:43", 222 | outVal: testVal(Timestamp, "2012-02-24 23:19:43"), 223 | }, { 224 | inType: Date, 225 | inVal: "2012-02-24", 226 | outVal: testVal(Date, "2012-02-24"), 227 | }, { 228 | inType: Time, 229 | inVal: "23:19:43", 230 | outVal: testVal(Time, "23:19:43"), 231 | }, { 232 | inType: Datetime, 233 | inVal: "2012-02-24 23:19:43", 234 | outVal: testVal(Datetime, "2012-02-24 23:19:43"), 235 | }, { 236 | inType: Year, 237 | inVal: "1", 238 | outVal: testVal(Year, "1"), 239 | }, { 240 | inType: Text, 241 | inVal: "a", 242 | outVal: testVal(Text, "a"), 243 | }, { 244 | inType: Blob, 245 | inVal: "a", 246 | outVal: testVal(Blob, "a"), 247 | }, { 248 | inType: VarChar, 249 | inVal: "a", 250 | outVal: testVal(VarChar, "a"), 251 | }, { 252 | inType: Binary, 253 | inVal: "a", 254 | outVal: testVal(Binary, "a"), 255 | }, { 256 | inType: Char, 257 | inVal: "a", 258 | outVal: testVal(Char, "a"), 259 | }, { 260 | inType: Bit, 261 | inVal: "1", 262 | outVal: testVal(Bit, "1"), 263 | }, { 264 | inType: Enum, 265 | inVal: "a", 266 | outVal: testVal(Enum, "a"), 267 | }, { 
268 | inType: Set, 269 | inVal: "a", 270 | outVal: testVal(Set, "a"), 271 | }, { 272 | inType: VarBinary, 273 | inVal: "a", 274 | outVal: testVal(VarBinary, "a"), 275 | }, { 276 | inType: Int64, 277 | inVal: InvalidNeg, 278 | outErr: "out of range", 279 | }, { 280 | inType: Int64, 281 | inVal: InvalidPos, 282 | outErr: "out of range", 283 | }, { 284 | inType: Uint64, 285 | inVal: "-1", 286 | outErr: "invalid syntax", 287 | }, { 288 | inType: Uint64, 289 | inVal: InvalidPos, 290 | outErr: "out of range", 291 | }, { 292 | inType: Float64, 293 | inVal: "a", 294 | outErr: "invalid syntax", 295 | }, { 296 | inType: Tuple, 297 | inVal: "a", 298 | outErr: "not allowed", 299 | }} 300 | for _, tcase := range testcases { 301 | v, err := ValueFromBytes(tcase.inType, []byte(tcase.inVal)) 302 | if tcase.outErr != "" { 303 | if err == nil || !strings.Contains(err.Error(), tcase.outErr) { 304 | t.Errorf("ValueFromBytes(%v, %v) error: %v, must contain %v", tcase.inType, tcase.inVal, err, tcase.outErr) 305 | } 306 | continue 307 | } 308 | if err != nil { 309 | t.Errorf("ValueFromBytes(%v, %v) error: %v", tcase.inType, tcase.inVal, err) 310 | continue 311 | } 312 | if !reflect.DeepEqual(v, tcase.outVal) { 313 | t.Errorf("ValueFromBytes(%v, %v) = %v, want %v", tcase.inType, tcase.inVal, makePretty(v), makePretty(tcase.outVal)) 314 | } 315 | } 316 | } 317 | 318 | func TestBuildIntegral(t *testing.T) { 319 | testcases := []struct { 320 | in string 321 | outVal Value 322 | outErr string 323 | }{{ 324 | in: MinNeg, 325 | outVal: testVal(Int64, MinNeg), 326 | }, { 327 | in: "1", 328 | outVal: testVal(Int64, "1"), 329 | }, { 330 | in: MinPos, 331 | outVal: testVal(Uint64, MinPos), 332 | }, { 333 | in: InvalidPos, 334 | outErr: "out of range", 335 | }} 336 | for _, tcase := range testcases { 337 | v, err := BuildIntegral(tcase.in) 338 | if tcase.outErr != "" { 339 | if err == nil || !strings.Contains(err.Error(), tcase.outErr) { 340 | t.Errorf("BuildIntegral(%v) error: %v, must contain %v", 
tcase.in, err, tcase.outErr) 341 | } 342 | continue 343 | } 344 | if err != nil { 345 | t.Errorf("BuildIntegral(%v) error: %v", tcase.in, err) 346 | continue 347 | } 348 | if !reflect.DeepEqual(v, tcase.outVal) { 349 | t.Errorf("BuildIntegral(%v) = %v, want %v", tcase.in, makePretty(v), makePretty(tcase.outVal)) 350 | } 351 | } 352 | } 353 | 354 | func TestAccessors(t *testing.T) { 355 | v := testVal(Int64, "1") 356 | if v.Type() != Int64 { 357 | t.Errorf("v.Type=%v, want Int64", v.Type()) 358 | } 359 | if !bytes.Equal(v.Raw(), []byte("1")) { 360 | t.Errorf("v.Raw=%s, want 1", v.Raw()) 361 | } 362 | if v.Len() != 1 { 363 | t.Errorf("v.Len=%d, want 1", v.Len()) 364 | } 365 | if v.String() != "1" { 366 | t.Errorf("v.String=%s, want 1", v.String()) 367 | } 368 | if v.IsNull() { 369 | t.Error("v.IsNull: true, want false") 370 | } 371 | if !v.IsIntegral() { 372 | t.Error("v.IsIntegral: false, want true") 373 | } 374 | if !v.IsSigned() { 375 | t.Error("v.IsSigned: false, want true") 376 | } 377 | if v.IsUnsigned() { 378 | t.Error("v.IsUnsigned: true, want false") 379 | } 380 | if v.IsFloat() { 381 | t.Error("v.IsFloat: true, want false") 382 | } 383 | if v.IsQuoted() { 384 | t.Error("v.IsQuoted: true, want false") 385 | } 386 | if v.IsText() { 387 | t.Error("v.IsText: true, want false") 388 | } 389 | if v.IsBinary() { 390 | t.Error("v.IsBinary: true, want false") 391 | } 392 | } 393 | 394 | func TestToNative(t *testing.T) { 395 | testcases := []struct { 396 | in Value 397 | out interface{} 398 | }{{ 399 | in: NULL, 400 | out: nil, 401 | }, { 402 | in: testVal(Int8, "1"), 403 | out: int64(1), 404 | }, { 405 | in: testVal(Int16, "1"), 406 | out: int64(1), 407 | }, { 408 | in: testVal(Int24, "1"), 409 | out: int64(1), 410 | }, { 411 | in: testVal(Int32, "1"), 412 | out: int64(1), 413 | }, { 414 | in: testVal(Int64, "1"), 415 | out: int64(1), 416 | }, { 417 | in: testVal(Uint8, "1"), 418 | out: uint64(1), 419 | }, { 420 | in: testVal(Uint16, "1"), 421 | out: uint64(1), 422 | 
}, { 423 | in: testVal(Uint24, "1"), 424 | out: uint64(1), 425 | }, { 426 | in: testVal(Uint32, "1"), 427 | out: uint64(1), 428 | }, { 429 | in: testVal(Uint64, "1"), 430 | out: uint64(1), 431 | }, { 432 | in: testVal(Float32, "1"), 433 | out: float64(1), 434 | }, { 435 | in: testVal(Float64, "1"), 436 | out: float64(1), 437 | }, { 438 | in: testVal(Timestamp, "2012-02-24 23:19:43"), 439 | out: []byte("2012-02-24 23:19:43"), 440 | }, { 441 | in: testVal(Date, "2012-02-24"), 442 | out: []byte("2012-02-24"), 443 | }, { 444 | in: testVal(Time, "23:19:43"), 445 | out: []byte("23:19:43"), 446 | }, { 447 | in: testVal(Datetime, "2012-02-24 23:19:43"), 448 | out: []byte("2012-02-24 23:19:43"), 449 | }, { 450 | in: testVal(Year, "1"), 451 | out: uint64(1), 452 | }, { 453 | in: testVal(Decimal, "1"), 454 | out: []byte("1"), 455 | }, { 456 | in: testVal(Text, "a"), 457 | out: []byte("a"), 458 | }, { 459 | in: testVal(Blob, "a"), 460 | out: []byte("a"), 461 | }, { 462 | in: testVal(VarChar, "a"), 463 | out: []byte("a"), 464 | }, { 465 | in: testVal(VarBinary, "a"), 466 | out: []byte("a"), 467 | }, { 468 | in: testVal(Char, "a"), 469 | out: []byte("a"), 470 | }, { 471 | in: testVal(Binary, "a"), 472 | out: []byte("a"), 473 | }, { 474 | in: testVal(Bit, "1"), 475 | out: []byte("1"), 476 | }, { 477 | in: testVal(Enum, "a"), 478 | out: []byte("a"), 479 | }, { 480 | in: testVal(Set, "a"), 481 | out: []byte("a"), 482 | }} 483 | for _, tcase := range testcases { 484 | v := tcase.in.ToNative() 485 | if !reflect.DeepEqual(v, tcase.out) { 486 | t.Errorf("%v.ToNative = %#v, want %#v", makePretty(tcase.in), v, tcase.out) 487 | } 488 | } 489 | } 490 | 491 | func TestPanics(t *testing.T) { 492 | testcases := []struct { 493 | in Value 494 | out string 495 | }{{ 496 | in: testVal(Int64, InvalidNeg), 497 | out: "out of range", 498 | }, { 499 | in: testVal(Uint64, InvalidPos), 500 | out: "out of range", 501 | }, { 502 | in: testVal(Uint64, "-1"), 503 | out: "invalid syntax", 504 | }, { 505 | 
in: testVal(Float64, "a"), 506 | out: "invalid syntax", 507 | }, { 508 | in: testVal(Tuple, "a"), 509 | out: "unexpected", 510 | }} 511 | for _, tcase := range testcases { 512 | func() { 513 | defer func() { 514 | x := recover() 515 | if x == nil { 516 | t.Errorf("%v.ToNative did not panic", makePretty(tcase.in)) 517 | } 518 | err, ok := x.(error) 519 | if !ok { 520 | t.Errorf("%v.ToNative did not panic with an error", makePretty(tcase.in)) 521 | } 522 | if !strings.Contains(err.Error(), tcase.out) { 523 | t.Errorf("%v.ToNative error: %v, must contain; %v ", makePretty(tcase.in), err, tcase.out) 524 | } 525 | }() 526 | _ = tcase.in.ToNative() 527 | }() 528 | } 529 | for _, tcase := range testcases { 530 | func() { 531 | defer func() { 532 | x := recover() 533 | if x == nil { 534 | t.Errorf("%v.EncodeSQL did not panic", makePretty(tcase.in)) 535 | } 536 | err, ok := x.(error) 537 | if !ok { 538 | t.Errorf("%v.EncodeSQL did not panic with an error", makePretty(tcase.in)) 539 | } 540 | if !strings.Contains(err.Error(), tcase.out) { 541 | t.Errorf("%v.EncodeSQL error: %v, must contain; %v ", makePretty(tcase.in), err, tcase.out) 542 | } 543 | }() 544 | tcase.in.EncodeSQL(&bytes.Buffer{}) 545 | }() 546 | } 547 | for _, tcase := range testcases { 548 | func() { 549 | defer func() { 550 | x := recover() 551 | if x == nil { 552 | t.Errorf("%v.EncodeASCII did not panic", makePretty(tcase.in)) 553 | } 554 | err, ok := x.(error) 555 | if !ok { 556 | t.Errorf("%v.EncodeASCII did not panic with an error", makePretty(tcase.in)) 557 | } 558 | if !strings.Contains(err.Error(), tcase.out) { 559 | t.Errorf("%v.EncodeASCII error: %v, must contain; %v ", makePretty(tcase.in), err, tcase.out) 560 | } 561 | }() 562 | tcase.in.EncodeASCII(&bytes.Buffer{}) 563 | }() 564 | } 565 | } 566 | 567 | func TestParseNumbers(t *testing.T) { 568 | v := testVal(VarChar, "1") 569 | sval, err := v.ParseInt64() 570 | if err != nil { 571 | t.Error(err) 572 | } 573 | if sval != 1 { 574 | 
t.Errorf("v.ParseInt64 = %d, want 1", sval) 575 | } 576 | uval, err := v.ParseUint64() 577 | if err != nil { 578 | t.Error(err) 579 | } 580 | if uval != 1 { 581 | t.Errorf("v.ParseUint64 = %d, want 1", uval) 582 | } 583 | fval, err := v.ParseFloat64() 584 | if err != nil { 585 | t.Error(err) 586 | } 587 | if fval != 1 { 588 | t.Errorf("v.ParseFloat64 = %f, want 1", fval) 589 | } 590 | } 591 | 592 | func TestEncode(t *testing.T) { 593 | testcases := []struct { 594 | in Value 595 | outSQL string 596 | outASCII string 597 | }{{ 598 | in: NULL, 599 | outSQL: "null", 600 | outASCII: "null", 601 | }, { 602 | in: testVal(Int64, "1"), 603 | outSQL: "1", 604 | outASCII: "1", 605 | }, { 606 | in: testVal(VarChar, "foo"), 607 | outSQL: "'foo'", 608 | outASCII: "'Zm9v'", 609 | }, { 610 | in: testVal(VarChar, "\x00'\"\b\n\r\t\x1A\\"), 611 | outSQL: "'\\0\\'\\\"\\b\\n\\r\\t\\Z\\\\'", 612 | outASCII: "'ACciCAoNCRpc'", 613 | }} 614 | for _, tcase := range testcases { 615 | buf := &bytes.Buffer{} 616 | tcase.in.EncodeSQL(buf) 617 | if tcase.outSQL != buf.String() { 618 | t.Errorf("%v.EncodeSQL = %q, want %q", makePretty(tcase.in), buf.String(), tcase.outSQL) 619 | } 620 | buf = &bytes.Buffer{} 621 | tcase.in.EncodeASCII(buf) 622 | if tcase.outASCII != buf.String() { 623 | t.Errorf("%v.EncodeASCII = %q, want %q", makePretty(tcase.in), buf.String(), tcase.outASCII) 624 | } 625 | } 626 | } 627 | 628 | // TestEncodeMap ensures DontEscape is not escaped 629 | func TestEncodeMap(t *testing.T) { 630 | if SQLEncodeMap[DontEscape] != DontEscape { 631 | t.Errorf("SQLEncodeMap[DontEscape] = %v, want %v", SQLEncodeMap[DontEscape], DontEscape) 632 | } 633 | if SQLDecodeMap[DontEscape] != DontEscape { 634 | t.Errorf("SQLDecodeMap[DontEscape] = %v, want %v", SQLEncodeMap[DontEscape], DontEscape) 635 | } 636 | } 637 | 638 | // testVal makes it easy to build a Value for testing. 
639 | func testVal(typ querypb.Type, val string) Value { 640 | return Value{typ: typ, val: []byte(val)} 641 | } 642 | 643 | type prettyVal struct { 644 | Type querypb.Type 645 | Value string 646 | } 647 | 648 | // makePretty converts Value to a struct that's readable when printed. 649 | func makePretty(v Value) prettyVal { 650 | return prettyVal{v.typ, string(v.val)} 651 | } 652 | -------------------------------------------------------------------------------- /magic.go: -------------------------------------------------------------------------------- 1 | package mogi 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "strconv" 7 | "strings" 8 | 9 | // "github.com/davecgh/go-spew/spew" 10 | "github.com/guregu/mogi/internal/sqlparser" 11 | ) 12 | 13 | // transmogrify takes sqlparser expressions and turns them into useful go values 14 | func transmogrify(v interface{}) interface{} { 15 | switch x := v.(type) { 16 | case *sqlparser.ColName: 17 | name := string(x.Name) 18 | if x.Qualifier != "" { 19 | name = fmt.Sprintf("%s.%s", x.Qualifier, name) 20 | } 21 | return name 22 | case *sqlparser.NonStarExpr: 23 | if x.As != "" { 24 | return string(x.As) 25 | } 26 | return transmogrify(x.Expr) 27 | case *sqlparser.StarExpr: 28 | return "*" 29 | case *sqlparser.FuncExpr: 30 | name := strings.ToUpper(string(x.Name)) 31 | var args []string 32 | for _, expr := range x.Exprs { 33 | args = append(args, stringify(transmogrify(expr))) 34 | } 35 | return fmt.Sprintf("%s(%s)", name, strings.Join(args, ", ")) 36 | case *sqlparser.BinaryExpr: 37 | // TODO: figure out some way to make this work 38 | return transmogrify(x.Left) 39 | case sqlparser.ValArg: 40 | // vitess makes args like :v1 41 | str := string(x) 42 | hdr, num := str[:2], str[2:] 43 | if hdr != ":v" { 44 | log.Panicln("unexpected arg format", str) 45 | } 46 | idx, err := strconv.Atoi(num) 47 | if err != nil { 48 | panic(err) 49 | } 50 | return arg(idx - 1) 51 | case sqlparser.StrVal: 52 | return string(x) 53 | case sqlparser.NumVal: 
54 | s := string(x) 55 | n, err := strconv.ParseInt(s, 10, 64) 56 | if err == nil { 57 | return n 58 | } 59 | f, err := strconv.ParseFloat(s, 64) 60 | if err == nil { 61 | return f 62 | } 63 | case sqlparser.ValTuple: 64 | vals := make([]interface{}, 0, len(x)) 65 | for _, item := range x { 66 | vals = append(vals, transmogrify(item)) 67 | } 68 | return vals 69 | default: 70 | log.Printf("unknown transmogrify: (%T) %v\n", v, v) 71 | //panic(x) 72 | } 73 | return nil 74 | } 75 | 76 | func extractColumnName(nse *sqlparser.NonStarExpr) string { 77 | if nse.As != "" { 78 | return string(nse.As) 79 | } 80 | return stringify(transmogrify(nse.Expr)) 81 | } 82 | 83 | func extractTableNames(tables *[]string, from sqlparser.TableExpr) { 84 | switch x := from.(type) { 85 | case *sqlparser.AliasedTableExpr: 86 | if name, ok := x.Expr.(*sqlparser.TableName); ok { 87 | *tables = append(*tables, string(name.Name)) 88 | } 89 | case *sqlparser.JoinTableExpr: 90 | extractTableNames(tables, x.LeftExpr) 91 | extractTableNames(tables, x.RightExpr) 92 | } 93 | } 94 | 95 | func extractBoolExpr(vals map[string]interface{}, expr sqlparser.BoolExpr) map[string]interface{} { 96 | if vals == nil { 97 | vals = make(map[string]interface{}) 98 | } 99 | switch x := expr.(type) { 100 | case *sqlparser.OrExpr: 101 | extractBoolExpr(vals, x.Left) 102 | extractBoolExpr(vals, x.Right) 103 | case *sqlparser.AndExpr: 104 | extractBoolExpr(vals, x.Left) 105 | extractBoolExpr(vals, x.Right) 106 | case *sqlparser.ComparisonExpr: 107 | column := transmogrify(x.Left).(string) 108 | vals[column] = transmogrify(x.Right) 109 | } 110 | return vals 111 | } 112 | 113 | func extractBoolExprWithOps(vals map[colop]interface{}, expr sqlparser.BoolExpr) map[colop]interface{} { 114 | if vals == nil { 115 | vals = make(map[colop]interface{}) 116 | } 117 | switch x := expr.(type) { 118 | case *sqlparser.OrExpr: 119 | extractBoolExprWithOps(vals, x.Left) 120 | extractBoolExprWithOps(vals, x.Right) 121 | case 
*sqlparser.AndExpr: 122 | extractBoolExprWithOps(vals, x.Left) 123 | extractBoolExprWithOps(vals, x.Right) 124 | case *sqlparser.ComparisonExpr: 125 | column := transmogrify(x.Left).(string) 126 | vals[colop{column, x.Operator}] = transmogrify(x.Right) 127 | } 128 | return vals 129 | } 130 | -------------------------------------------------------------------------------- /mogi.go: -------------------------------------------------------------------------------- 1 | package mogi 2 | 3 | import ( 4 | "database/sql" 5 | "database/sql/driver" 6 | "errors" 7 | "fmt" 8 | "os" 9 | "text/tabwriter" 10 | ) 11 | 12 | var ( 13 | // ErrUnstubbed is returned as the result for unstubbed queries. 14 | ErrUnstubbed = errors.New("mogi: query not stubbed") 15 | // ErrUnresolved is returned as the result of a stub that was matched, 16 | // but whose data could not be resolved. For example, exceeded LIMITs. 17 | ErrUnresolved = errors.New("mogi: query matched but no stub data") 18 | 19 | // errNotSet is used for Exec results stubbed as -1. 20 | errNotSet = errors.New("value set to -1") 21 | ) 22 | 23 | var ( 24 | verbose = false 25 | timeLayout = "" 26 | ) 27 | 28 | func init() { 29 | drv = newDriver() 30 | sql.Register("mogi", drv) 31 | } 32 | 33 | // Reset removes all the stubs that have been set 34 | func Reset() { 35 | drv.conn.stubs = nil 36 | drv.conn.execStubs = nil 37 | } 38 | 39 | // Verbose turns on unstubbed logging when v is true 40 | func Verbose(v bool) { 41 | verbose = v 42 | } 43 | 44 | // ParseTime will configure mogi to convert dates of the given layout 45 | // (e.g. time.RFC3339) to time.Time when using StubCSV. 46 | // Give it an empty string to turn off time parsing. 47 | func ParseTime(layout string) { 48 | timeLayout = layout 49 | } 50 | 51 | // Dump prints all the current stubs, in order of priority. 52 | // Helpful for debugging. 
53 | func Dump() { 54 | w := tabwriter.NewWriter(os.Stdout, 0, 8, 0, '\t', 0) 55 | fmt.Fprintf(w, ">>\t\tQuery stubs: (%d total)\t\n", len(drv.conn.stubs)) 56 | fmt.Fprintf(w, "\t\t=========================\t\n") 57 | for rank, s := range drv.conn.stubs { 58 | for i, c := range s.chain { 59 | if i == 0 { 60 | fmt.Fprintf(w, "#%d\t[%d]\t%s\t[%+d]\n", rank+1, s.priority(), c, c.priority()) 61 | continue 62 | } 63 | fmt.Fprintf(w, "\t\t%s\t[%+d]\n", c, c.priority()) 64 | } 65 | switch { 66 | case s.err != nil: 67 | fmt.Fprintf(w, "\t\t→ error: %v\t\n", s.err) 68 | case s.data != nil, s.resolve != nil: 69 | fmt.Fprintf(w, "\t\t→ data\t\n") 70 | } 71 | } 72 | fmt.Fprintf(w, "\t\t\t\n") 73 | fmt.Fprintf(w, ">>\t\tExec stubs: (%d total)\t\n", len(drv.conn.execStubs)) 74 | fmt.Fprintf(w, "\t\t=========================\t\n") 75 | for rank, s := range drv.conn.execStubs { 76 | for i, c := range s.chain { 77 | if i == 0 { 78 | fmt.Fprintf(w, "#%d\t[%d]\t%s\t[%+d]\n", rank+1, s.priority(), c, c.priority()) 79 | continue 80 | } 81 | fmt.Fprintf(w, "\t\t%s\t[%+d]\n", c, c.priority()) 82 | } 83 | switch { 84 | case s.err != nil: 85 | fmt.Fprintf(w, "\t\t→ error: %v\t\n", s.err) 86 | case s.result != nil: 87 | if r, ok := s.result.(execResult); ok { 88 | fmt.Fprintf(w, "\t\t→ result ID: %d, rows: %d\t\n", r.lastInsertID, r.rowsAffected) 89 | } else { 90 | fmt.Fprintf(w, "\t\t→ result %T\t\n", s.result) 91 | } 92 | } 93 | } 94 | w.Flush() 95 | } 96 | 97 | // func Replace() { 98 | // drv.conn = newConn() 99 | // } 100 | 101 | var _ driver.Stmt = &stmt{} 102 | var _ driver.Conn = &conn{} 103 | var _ driver.Driver = &mdriver{} 104 | -------------------------------------------------------------------------------- /mogi_test.go: -------------------------------------------------------------------------------- 1 | package mogi_test 2 | 3 | import ( 4 | "database/sql" 5 | "database/sql/driver" 6 | "testing" 7 | 8 | "github.com/guregu/mogi" 9 | ) 10 | 11 | func TestMogi(t *testing.T) { 12 | 
// rows is an in-memory driver.Rows implementation that iterates over
// stubbed row data.
type rows struct {
	cols []string         // column names reported by Columns
	data [][]driver.Value // one entry per stubbed row

	cursor int  // number of Next calls made so far
	closed bool // set by Close, or by Next once the data is exhausted
}

// newRows returns a rows iterator positioned before the first row of data.
func newRows(cols []string, data [][]driver.Value) *rows {
	return &rows{
		cols: cols,
		data: data,
	}
}

// Columns returns the column names.
func (r *rows) Columns() []string {
	return r.cols
}

// Close closes the rows iterator.
func (r *rows) Close() error {
	r.closed = true
	return nil
}

// Err always returns nil; iterating over in-memory data cannot fail.
func (r *rows) Err() error {
	return nil
}

// Next copies the next row into dest and advances the cursor.
// It returns io.EOF once every row has been consumed.
//
// A stubbed row may hold a different number of values than dest has slots
// (for example CSV data with three fields stubbed for a single-column
// query). Values beyond len(dest) are dropped and surplus dest slots are
// left untouched, rather than panicking with an index out of range.
func (r *rows) Next(dest []driver.Value) error {
	r.cursor++
	if r.cursor > len(r.data) {
		r.closed = true
		return io.EOF
	}

	// copy transfers min(len(dest), len(row)) values.
	copy(dest, r.data[r.cursor-1])
	return nil
}
0 { 31 | return 2 32 | } 33 | return 1 34 | } 35 | 36 | func (sc selectCond) String() string { 37 | cols := "(any)" // TODO support star select 38 | if len(sc.cols) > 0 { 39 | cols = strings.Join(sc.cols, ", ") 40 | } 41 | return fmt.Sprintf("SELECT %s", cols) 42 | } 43 | 44 | type fromCond struct { 45 | tables []string 46 | } 47 | 48 | func (fc fromCond) matches(in input) bool { 49 | var inTables []string 50 | switch x := in.statement.(type) { 51 | case *sqlparser.Select: 52 | for _, tex := range x.From { 53 | extractTableNames(&inTables, tex) 54 | } 55 | } 56 | return reflect.DeepEqual(lowercase(fc.tables), lowercase(inTables)) 57 | } 58 | 59 | func (fc fromCond) priority() int { 60 | if len(fc.tables) > 0 { 61 | return 1 62 | } 63 | return 0 64 | } 65 | 66 | func (fc fromCond) String() string { 67 | return fmt.Sprintf("FROM %s", strings.Join(fc.tables, ", ")) 68 | } 69 | -------------------------------------------------------------------------------- /select_test.go: -------------------------------------------------------------------------------- 1 | package mogi_test 2 | 3 | import ( 4 | "database/sql" 5 | "reflect" 6 | "testing" 7 | 8 | "github.com/guregu/mogi" 9 | ) 10 | 11 | const ( 12 | beerCSV = `1,Yona Yona Ale,Yo-Ho Brewing,5.5 13 | 2,Punk IPA,BrewDog,5.6` 14 | ) 15 | 16 | var ( 17 | beers = map[int]beer{ 18 | 1: beer{ 19 | id: 1, 20 | name: "Yona Yona Ale", 21 | brewery: "Yo-Ho Brewing", 22 | pct: 5.5, 23 | }, 24 | 2: beer{ 25 | id: 2, 26 | name: "Punk IPA", 27 | brewery: "BrewDog", 28 | pct: 5.6, 29 | }, 30 | } 31 | ) 32 | 33 | type beer struct { 34 | id int64 35 | name string 36 | brewery string 37 | pct float64 38 | } 39 | 40 | func TestSelectTable(t *testing.T) { 41 | defer mogi.Reset() 42 | db := openDB() 43 | 44 | // filter by table 45 | mogi.Select("id", "name", "brewery", "pct").From("beer").StubCSV(beerCSV) 46 | runBeerSelectQuery(t, db) 47 | 48 | // select the wrong table 49 | mogi.Reset() 50 | mogi.Select("id", "name", "brewery", 
"pct").From("酒").StubCSV(beerCSV) 51 | runUnstubbedSelect(t, db) 52 | } 53 | 54 | func TestSelectWhere(t *testing.T) { 55 | defer mogi.Reset() 56 | db := openDB() 57 | 58 | // where 59 | mogi.Select().From("beer").Where("pct", 5).StubCSV(beerCSV) 60 | runBeerSelectQuery(t, db) 61 | 62 | // where with weird type 63 | type 数字 int 64 | 五 := 数字(5) 65 | mogi.Reset() 66 | mogi.Select().From("beer").Where("pct", &五).StubCSV(beerCSV) 67 | runBeerSelectQuery(t, db) 68 | 69 | // wrong where 70 | mogi.Reset() 71 | mogi.Select().From("beer").Where("pct", 98).StubCSV(beerCSV) 72 | runUnstubbedSelect(t, db) 73 | } 74 | 75 | func TestSelectArgs(t *testing.T) { 76 | defer mogi.Reset() 77 | db := openDB() 78 | 79 | // where 80 | mogi.Select().Args(5).StubCSV(beerCSV) 81 | runBeerSelectQuery(t, db) 82 | 83 | // wrong args 84 | mogi.Reset() 85 | mogi.Select().Args("サービス残業").StubCSV(beerCSV) 86 | runUnstubbedSelect(t, db) 87 | } 88 | 89 | func TestStubError(t *testing.T) { 90 | defer mogi.Reset() 91 | db := openDB() 92 | 93 | mogi.Select().StubError(sql.ErrNoRows) 94 | _, err := db.Query("SELECT id, name, brewery, pct FROM beer WHERE pct > ?", 5) 95 | if err != sql.ErrNoRows { 96 | t.Error("after StubError, err should be ErrNoRows but is", err) 97 | } 98 | } 99 | 100 | func TestSelectMultipleTables(t *testing.T) { 101 | defer mogi.Reset() 102 | db := openDB() 103 | 104 | mogi.Select().From("a", "b").StubCSV(`foo,bar`) 105 | _, err := db.Query("SELECT a.thing, b.thing FROM a, b WHERE a.id = b.id") 106 | checkNil(t, err) 107 | _, err = db.Query("SELECT a.thing, b.thing FROM a JOIN b ON a.id = b.id") 108 | checkNil(t, err) 109 | 110 | mogi.Reset() 111 | mogi.Select().From("a", "b", "c").StubCSV(`foo,bar,baz`) 112 | _, err = db.Query("SELECT a.thing, b.thing, c.thing FROM a, b, c WHERE a.id = b.id") 113 | checkNil(t, err) 114 | _, err = db.Query("SELECT a.thing, b.thing, c.thing FROM a JOIN b ON a.id = b.id JOIN c ON a.id = c.id") 115 | checkNil(t, err) 116 | } 117 | 118 | func 
TestSelectColumnNames(t *testing.T) { 119 | defer mogi.Reset() 120 | db := openDB() 121 | 122 | // qualified names 123 | mogi.Select("a.thing", "b.thing", "c.thing").From("qqqq", "b", "c").StubCSV(`foo,bar,baz`) 124 | _, err := db.Query("SELECT a.thing, b.thing, c.thing FROM qqqq as a, b, c WHERE a.id = b.id") 125 | checkNil(t, err) 126 | 127 | // aliased names 128 | mogi.Reset() 129 | mogi.Select("dog", "cat", "hamster").From("a", "b", "c").StubCSV(`foo,bar,baz`) 130 | _, err = db.Query("SELECT a.thing AS dog, b.thing AS cat, c.thing AS hamster FROM a JOIN b ON a.id = b.id JOIN c ON a.id = c.id") 131 | checkNil(t, err) 132 | } 133 | 134 | func TestSelectCount(t *testing.T) { 135 | defer mogi.Reset() 136 | db := openDB() 137 | 138 | mogi.Select("COUNT(abc)", "count(*)").StubCSV("1,5") 139 | _, err := db.Query("SELECT COUNT(abc), COUNT(*) FROM beer") 140 | checkNil(t, err) 141 | } 142 | 143 | func TestSelectWhereIn(t *testing.T) { 144 | defer mogi.Reset() 145 | db := openDB() 146 | 147 | mogi.Select().Where("pct", 5.4, 10.2).StubCSV("2") 148 | _, err := db.Query("SELECT COUNT(*) FROM beer WHERE pct IN (5.4, ?)", 10.2) 149 | checkNil(t, err) 150 | 151 | mogi.Reset() 152 | mogi.Select().WhereOp("pct", "IN", 5.4, 10.2).StubCSV("2") 153 | _, err = db.Query("SELECT COUNT(*) FROM beer WHERE pct IN (5.4, ?)", 10.2) 154 | checkNil(t, err) 155 | } 156 | 157 | func TestSelectStar(t *testing.T) { 158 | defer mogi.Reset() 159 | db := openDB() 160 | 161 | mogi.Select("*").StubCSV("a,b,c") 162 | _, err := db.Query("SELECT * FROM beer") 163 | checkNil(t, err) 164 | } 165 | 166 | func runUnstubbedSelect(t *testing.T, db *sql.DB) { 167 | _, err := db.Query("SELECT id, name, brewery, pct FROM beer WHERE pct > ?", 5) 168 | if err != mogi.ErrUnstubbed { 169 | t.Error("with unmatched query, err should be ErrUnstubbed but is", err) 170 | } 171 | } 172 | 173 | func runBeerSelectQuery(t *testing.T, db *sql.DB) { 174 | expectCols := []string{"id", "name", "brewery", "pct"} 175 | rows, err 
:= db.Query("SELECT id, name, brewery, pct FROM beer WHERE pct > ?", 5) 176 | checkNil(t, err) 177 | cols, err := rows.Columns() 178 | checkNil(t, err) 179 | if !reflect.DeepEqual(cols, expectCols) { 180 | t.Error("bad columns", cols, "≠", expectCols) 181 | } 182 | i := 0 183 | for rows.Next() { 184 | var b beer 185 | rows.Scan(&b.id, &b.name, &b.brewery, &b.pct) 186 | checkBeer(t, b, i+1) 187 | i++ 188 | } 189 | } 190 | 191 | func checkBeer(t *testing.T, b beer, id int) { 192 | cmp, ok := beers[id] 193 | if !ok { 194 | t.Error("unknown beer", id) 195 | return 196 | } 197 | if b != cmp { 198 | t.Error("beers don't match", b, "≠", cmp, id) 199 | } 200 | } 201 | -------------------------------------------------------------------------------- /stmt.go: -------------------------------------------------------------------------------- 1 | package mogi 2 | 3 | import ( 4 | "database/sql/driver" 5 | ) 6 | 7 | type stmt struct { 8 | query string 9 | } 10 | 11 | func (s *stmt) Close() error { 12 | return nil 13 | } 14 | 15 | // NumInput returns the number of placeholder parameters. 16 | // 17 | // If NumInput returns >= 0, the sql package will sanity check 18 | // argument counts from callers and return errors to the caller 19 | // before the statement's Exec or Query methods are called. 20 | // 21 | // NumInput may also return -1, if the driver doesn't know 22 | // its number of placeholders. In that case, the sql package 23 | // will not sanity check Exec or Query argument counts. 24 | func (s *stmt) NumInput() int { 25 | return -1 26 | } 27 | 28 | // Exec executes a query that doesn't return rows, such 29 | // as an INSERT or UPDATE. 30 | func (s *stmt) Exec(args []driver.Value) (driver.Result, error) { 31 | return drv.conn.Exec(s.query, args) 32 | } 33 | 34 | // Query executes a query that may return rows, such as a 35 | // SELECT. 
36 | func (s *stmt) Query(args []driver.Value) (driver.Rows, error) { 37 | return drv.conn.Query(s.query, args) 38 | } 39 | -------------------------------------------------------------------------------- /stub.go: -------------------------------------------------------------------------------- 1 | package mogi 2 | 3 | import ( 4 | "database/sql/driver" 5 | ) 6 | 7 | // Stub is a SQL query stub (for SELECT) 8 | type Stub struct { 9 | chain condchain 10 | data [][]driver.Value 11 | err error 12 | 13 | resolve func(input) 14 | } 15 | 16 | type subquery struct { 17 | chain condchain 18 | } 19 | 20 | // Select starts a new stub for SELECT statements. 21 | // You can filter out which columns to use this stub for. 22 | // If you don't pass any columns, it will stub all SELECT queries. 23 | func Select(cols ...string) *Stub { 24 | return &Stub{ 25 | chain: condchain{selectCond{ 26 | cols: cols, 27 | }}, 28 | } 29 | } 30 | 31 | // From further filters this stub by table names in the FROM and JOIN clauses (in order). 32 | // You need to give it the un-aliased table names. 33 | func (s *Stub) From(tables ...string) *Stub { 34 | s.chain = append(s.chain, fromCond{ 35 | tables: tables, 36 | }) 37 | return s 38 | } 39 | 40 | // Where further filters this stub by values of input in the WHERE clause. 41 | // You can pass multiple values for IN clause matching. 42 | func (s *Stub) Where(col string, v ...interface{}) *Stub { 43 | s.chain = append(s.chain, newWhereCond(col, v)) 44 | return s 45 | } 46 | 47 | // WhereOp further filters this stub by values of input and the operator used in the WHERE clause. 
48 | func (s *Stub) WhereOp(col string, operator string, v ...interface{}) *Stub { 49 | s.chain = append(s.chain, newWhereOpCond(col, v, operator)) 50 | return s 51 | } 52 | 53 | // Args further filters this stub, matching based on the args passed to the query 54 | func (s *Stub) Args(args ...driver.Value) *Stub { 55 | s.chain = append(s.chain, argsCond{args}) 56 | return s 57 | } 58 | 59 | // Priority adds the given priority to this stub, without performing any matching. 60 | func (s *Stub) Priority(p int) *Stub { 61 | s.chain = append(s.chain, priorityCond{p}) 62 | return s 63 | } 64 | 65 | // Notify will have this stub send to the given channel when matched. 66 | // You should put this as the last part of your stub chain. 67 | func (s *Stub) Notify(ch chan<- struct{}) *Stub { 68 | s.chain = append(s.chain, notifyCond{ch}) 69 | return s 70 | } 71 | 72 | // Dump outputs debug information, without performing any matching. 73 | func (s *Stub) Dump() *Stub { 74 | s.chain = append(s.chain, dumpCond{}) 75 | return s 76 | } 77 | 78 | // StubCSV takes CSV data and registers this stub with the driver 79 | func (s *Stub) StubCSV(data string) { 80 | s.resolve = func(in input) { 81 | s.data = csvToValues(in.cols(), data) 82 | } 83 | addStub(s) 84 | } 85 | 86 | // Stub takes row data and registers this stub with the driver 87 | func (s *Stub) Stub(rows [][]driver.Value) { 88 | s.data = rows 89 | addStub(s) 90 | } 91 | 92 | // StubError registers this stub to return the given error 93 | func (s *Stub) StubError(err error) { 94 | s.err = err 95 | addStub(s) 96 | } 97 | 98 | func (s *Stub) Subquery() subquery { 99 | return subquery{chain: s.chain} 100 | } 101 | 102 | func (s *Stub) matches(in input) bool { 103 | return s.chain.matches(in) 104 | } 105 | 106 | func (s *Stub) rows(in input) (*rows, error) { 107 | switch { 108 | case s.err != nil: 109 | return nil, s.err 110 | case s.data == nil && s.resolve != nil: 111 | s.resolve(in) 112 | } 113 | return newRows(in.cols(), 
s.data), nil 114 | } 115 | 116 | func (s *Stub) priority() int { 117 | return s.chain.priority() 118 | } 119 | 120 | // stubs are arranged by how complex they are for now 121 | type stubs []*Stub 122 | 123 | func (s stubs) Len() int { return len(s) } 124 | func (s stubs) Less(i, j int) bool { return s[i].priority() > s[j].priority() } 125 | func (s stubs) Swap(i, j int) { s[i], s[j] = s[j], s[i] } 126 | -------------------------------------------------------------------------------- /tx.go: -------------------------------------------------------------------------------- 1 | package mogi 2 | 3 | type tx struct { 4 | } 5 | 6 | func (t *tx) Commit() error { 7 | return nil 8 | } 9 | 10 | func (t *tx) Rollback() error { 11 | return nil 12 | } 13 | -------------------------------------------------------------------------------- /unify.go: -------------------------------------------------------------------------------- 1 | package mogi 2 | 3 | import ( 4 | "database/sql/driver" 5 | "fmt" 6 | "log" 7 | "reflect" 8 | "strings" 9 | "time" 10 | ) 11 | 12 | // unify converts values to fit driver.Value, 13 | // except []byte which is converted to string. 
// reflectUnify is the reflection-based fallback used for values whose
// concrete type is not one of the built-in types handled directly, such as
// user-defined named types and pointers to them.
//
// Conversions by kind:
//   - pointer: nil for a nil pointer, otherwise the unified pointee
//   - bool: bool
//   - signed integers: int64
//   - uint8, uint16, uint32: int64 (these always fit)
//   - floats: float64
//   - string: string
//   - slice whose element kind is uint8: string
//
// It panics for every other kind.
func reflectUnify(rv reflect.Value) interface{} {
	switch rv.Kind() {
	case reflect.Ptr:
		if rv.IsNil() {
			return nil
		}
		return reflectUnify(rv.Elem())
	case reflect.Bool:
		return rv.Bool()
	case reflect.Int64, reflect.Int, reflect.Int32, reflect.Int16, reflect.Int8:
		return rv.Int()
	case reflect.Uint8, reflect.Uint16, reflect.Uint32:
		return int64(rv.Uint())
	case reflect.Float64, reflect.Float32:
		return rv.Float()
	case reflect.String:
		return rv.String()
	case reflect.Slice:
		// The element kind lives on the slice's type; Value.Elem is only
		// defined for pointers and interfaces and panics on a slice.
		// Value.Bytes requires byte (uint8) elements.
		if rv.Type().Elem().Kind() == reflect.Uint8 {
			return string(rv.Bytes())
		}
	}

	// Type().String() is used because Name() is empty for unnamed types such
	// as []int; an invalid Value has no type at all.
	typ := "<invalid>"
	if rv.IsValid() {
		typ = rv.Type().String()
	}
	panic("couldn't unify value of type " + typ)
}
lowercase(strs []string) []string { 112 | lower := make([]string, 0, len(strs)) 113 | for _, str := range strs { 114 | lower = append(lower, strings.ToLower(str)) 115 | } 116 | return lower 117 | } 118 | 119 | func equals(src interface{}, to interface{}) bool { 120 | switch tox := to.(type) { 121 | case time.Time: 122 | // we need to convert source timestamps to time.Time 123 | if timeLayout == "" { 124 | break 125 | } 126 | var other time.Time 127 | switch srcx := src.(type) { 128 | case string: 129 | var err error 130 | if other, err = time.Parse(timeLayout, srcx); err != nil { 131 | goto deep 132 | } 133 | case []byte: 134 | var err error 135 | if other, err = time.Parse(timeLayout, string(srcx)); err != nil { 136 | goto deep 137 | } 138 | case time.Time: 139 | other = srcx 140 | } 141 | return tox.Format(timeLayout) == other.Format(timeLayout) 142 | case bool: 143 | // some drivers send booleans as 0 and 1 144 | switch srcx := src.(type) { 145 | case int64: 146 | return tox == (srcx != 0) 147 | case bool: 148 | return tox == srcx 149 | case string: 150 | other, ok := str2bool(srcx) 151 | if !ok { 152 | goto deep 153 | } 154 | return tox == other 155 | case []byte: 156 | other, ok := str2bool(string(srcx)) 157 | if !ok { 158 | goto deep 159 | } 160 | return tox == other 161 | } 162 | } 163 | deep: 164 | return reflect.DeepEqual(src, to) 165 | } 166 | 167 | // converts boolean-like strings to a bool 168 | func str2bool(str string) (v bool, ok bool) { 169 | switch str { 170 | case "true", "1": 171 | return true, true 172 | case "false", "0": 173 | return false, true 174 | default: 175 | log.Println("mogi: unknown boolean string:", str) 176 | return false, false 177 | } 178 | } 179 | -------------------------------------------------------------------------------- /update.go: -------------------------------------------------------------------------------- 1 | package mogi 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "strings" 7 | 8 | 
"github.com/guregu/mogi/internal/sqlparser" 9 | ) 10 | 11 | type updateCond struct { 12 | cols []string 13 | } 14 | 15 | func (uc updateCond) matches(in input) bool { 16 | _, ok := in.statement.(*sqlparser.Update) 17 | if !ok { 18 | return false 19 | } 20 | 21 | // zero parameters means anything 22 | if len(uc.cols) == 0 { 23 | return true 24 | } 25 | 26 | return reflect.DeepEqual(lowercase(uc.cols), lowercase(in.cols())) 27 | } 28 | 29 | func (uc updateCond) priority() int { 30 | if len(uc.cols) > 0 { 31 | return 2 32 | } 33 | return 1 34 | } 35 | 36 | func (uc updateCond) String() string { 37 | cols := "(any)" // TODO support star select 38 | if len(uc.cols) > 0 { 39 | cols = strings.Join(uc.cols, ", ") 40 | } 41 | return fmt.Sprintf("UPDATE %s", cols) 42 | } 43 | -------------------------------------------------------------------------------- /update_test.go: -------------------------------------------------------------------------------- 1 | package mogi_test 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/guregu/mogi" 8 | ) 9 | 10 | func TestUpdate(t *testing.T) { 11 | defer mogi.Reset() 12 | db := openDB() 13 | 14 | // naked update 15 | mogi.Update().StubResult(-1, 1) 16 | _, err := db.Exec(`UPDATE beer 17 | SET name = "Mikkel’s Dream", brewery = "Mikkeller", pct = 4.6 18 | WHERE id = 3`) 19 | checkNil(t, err) 20 | 21 | // update with cols 22 | mogi.Reset() 23 | mogi.Update("name", "brewery", "pct").StubResult(-1, 1) 24 | _, err = db.Exec(`UPDATE beer 25 | SET name = "Mikkel’s Dream", brewery = "Mikkeller", pct = 4.6 26 | WHERE id = 3`) 27 | checkNil(t, err) 28 | 29 | // with wrong cols 30 | mogi.Reset() 31 | mogi.Update("犬", "🐱", "かっぱ").StubResult(-1, 1) 32 | _, err = db.Exec(`UPDATE beer 33 | SET name = "Mikkel’s Dream", brewery = "Mikkeller", pct = 4.6 34 | WHERE id = 3`) 35 | if err != mogi.ErrUnstubbed { 36 | t.Error("err should be ErrUnstubbed but is", err) 37 | } 38 | } 39 | 40 | func TestUpdateTable(t *testing.T) { 41 | defer 
mogi.Reset() 42 | db := openDB() 43 | 44 | // table 45 | mogi.Update().Table("beer").StubResult(-1, 1) 46 | _, err := db.Exec(`UPDATE beer 47 | SET name = "Mikkel’s Dream", brewery = "Mikkeller", pct = 4.6 48 | WHERE id = 3`) 49 | checkNil(t, err) 50 | 51 | // with wrong table 52 | mogi.Reset() 53 | mogi.Update().Table("酒").StubResult(-1, 1) 54 | _, err = db.Exec(`UPDATE beer 55 | SET name = "Mikkel’s Dream", brewery = "Mikkeller", pct = 4.6 56 | WHERE id = 3`) 57 | if err != mogi.ErrUnstubbed { 58 | t.Error("err should be ErrUnstubbed but is", err) 59 | } 60 | } 61 | 62 | func TestUpdateValues(t *testing.T) { 63 | defer mogi.Reset() 64 | db := openDB() 65 | 66 | mogi.Update().Value("name", "Mikkel’s Dream").Value("brewery", "Mikkeller").Value("pct", 4.6).StubRowsAffected(1) 67 | _, err := db.Exec(`UPDATE beer 68 | SET name = "Mikkel’s Dream", brewery = "Mikkeller", pct = ? 69 | WHERE id = 3`, 4.6) 70 | checkNil(t, err) 71 | 72 | // time.Time 73 | mogi.Reset() 74 | mogi.ParseTime(time.RFC3339) 75 | now := time.Now() 76 | mogi.Update().Value("updated_at", now).Value("brewery", "Mikkeller").Value("pct", 4.6).StubRowsAffected(1) 77 | _, err = db.Exec(`UPDATE beer 78 | SET updated_at = ?, brewery = "Mikkeller", pct = ? 79 | WHERE id = 3`, now, 4.6) 80 | checkNil(t, err) 81 | 82 | // boolean as 1 vs true 83 | mogi.Reset() 84 | mogi.Update().Value("awesome", true).StubRowsAffected(1) 85 | _, err = db.Exec(`UPDATE beer 86 | SET awesome = 1 87 | WHERE id = 3`) 88 | checkNil(t, err) 89 | 90 | // with wrong values 91 | mogi.Reset() 92 | mogi.Update().Value("name", "7-Premium THE BREW").Value("brewery", "Suntory").Value("pct", 5.0).StubResult(-1, 1) 93 | _, err = db.Exec(`UPDATE beer 94 | SET name = "Mikkel’s Dream", brewery = "Mikkeller", pct = ? 
95 | WHERE id = 3`, 4.6) 96 | if err != mogi.ErrUnstubbed { 97 | t.Error("err should be ErrUnstubbed but is", err) 98 | } 99 | } 100 | 101 | func TestUpdateWhere(t *testing.T) { 102 | defer mogi.Reset() 103 | db := openDB() 104 | 105 | mogi.Update().Where("id", 3).Where("moon", "full").StubResult(-1, 1) 106 | _, err := db.Exec(`UPDATE beer 107 | SET name = "Mikkel’s Dream", brewery = "Mikkeller", pct = ? 108 | WHERE id = ? AND moon = "full"`, 4.6, 3) 109 | checkNil(t, err) 110 | 111 | mogi.Reset() 112 | mogi.Update().Where("foo", 555).Where("bar", "qux").StubResult(-1, 1) 113 | _, err = db.Exec(`UPDATE beer 114 | SET name = "Mikkel’s Dream", brewery = "Mikkeller", pct = ? 115 | WHERE id = 3 AND moon = "full"`, 4.6) 116 | if err != mogi.ErrUnstubbed { 117 | t.Error("err should be ErrUnstubbed but is", err) 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /where.go: -------------------------------------------------------------------------------- 1 | package mogi 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | type whereCond struct { 9 | col string 10 | v []interface{} 11 | } 12 | 13 | func newWhereCond(col string, v []interface{}) whereCond { 14 | return whereCond{ 15 | col: strings.ToLower(col), 16 | v: unifyInterfaces(v), 17 | } 18 | } 19 | 20 | func (wc whereCond) matches(in input) bool { 21 | vals := in.where() 22 | v, ok := vals[wc.col] 23 | if !ok { 24 | return false 25 | } 26 | 27 | // compare slices 28 | if slice, ok := v.([]interface{}); ok { 29 | for i, src := range slice { 30 | if !equals(src, wc.v[i]) { 31 | return false 32 | } 33 | } 34 | return true 35 | } 36 | 37 | // compare single value 38 | return equals(v, wc.v[0]) 39 | } 40 | 41 | func (wc whereCond) priority() int { 42 | return 1 43 | } 44 | 45 | func (wc whereCond) String() string { 46 | return fmt.Sprintf("WHERE %s ≈ %v", wc.col, wc.v) 47 | } 48 | 49 | type whereOpCond struct { 50 | col string 51 | op string 52 | v []interface{} 53 | 
} 54 | 55 | func newWhereOpCond(col string, v []interface{}, op string) whereOpCond { 56 | return whereOpCond{ 57 | col: strings.ToLower(col), 58 | v: unifyInterfaces(v), 59 | op: strings.ToLower(op), 60 | } 61 | } 62 | 63 | func (wc whereOpCond) matches(in input) bool { 64 | vals := in.whereOp() 65 | v, ok := vals[colop{wc.col, wc.op}] 66 | if !ok { 67 | return false 68 | } 69 | 70 | // compare slices 71 | if slice, ok := v.([]interface{}); ok { 72 | for i, src := range slice { 73 | if !equals(src, wc.v[i]) { 74 | return false 75 | } 76 | } 77 | return true 78 | } 79 | 80 | // compare single value 81 | return equals(v, wc.v[0]) 82 | } 83 | 84 | func (wc whereOpCond) priority() int { 85 | return 2 86 | } 87 | 88 | func (wc whereOpCond) String() string { 89 | return fmt.Sprintf("WHERE %s %s %v", wc.col, strings.ToUpper(wc.op), wc.v) 90 | } 91 | --------------------------------------------------------------------------------