├── .gitignore ├── stmt.go ├── result.go ├── tx.go ├── rows.go ├── tx_test.go ├── LICENSE ├── conn.go ├── testdb.go ├── examples_test.go ├── README.md └── testdb_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | -------------------------------------------------------------------------------- /stmt.go: -------------------------------------------------------------------------------- 1 | package testdb 2 | 3 | import ( 4 | "database/sql/driver" 5 | ) 6 | 7 | type stmt struct { 8 | queryFunc func(args []driver.Value) (driver.Rows, error) 9 | execFunc func(args []driver.Value) (driver.Result, error) 10 | } 11 | 12 | func (s *stmt) Close() error { 13 | return nil 14 | } 15 | 16 | func (s *stmt) NumInput() int { 17 | // This prevents the sql package from validating the number of inputs 18 | return -1 19 | } 20 | 21 | func (s *stmt) Exec(args []driver.Value) (driver.Result, error) { 22 | return s.execFunc(args) 23 | } 24 | 25 | func (s *stmt) Query(args []driver.Value) (driver.Rows, error) { 26 | return s.queryFunc(args) 27 | } 28 | -------------------------------------------------------------------------------- /result.go: -------------------------------------------------------------------------------- 1 | package testdb 2 | 3 | type Result struct { 4 | lastInsertId int64 5 | lastInsertIdError error 6 | rowsAffected int64 7 | rowsAffectedError error 8 | } 9 | 10 | func NewResult(lastId int64, lastIdError error, rowsAffected int64, rowsAffectedError error) (res *Result) { 11 | return &Result{ 12 | lastInsertId: lastId, 13 | lastInsertIdError: lastIdError, 14 | rowsAffected: 
rowsAffected, 15 | rowsAffectedError: rowsAffectedError, 16 | } 17 | } 18 | 19 | func (res *Result) LastInsertId() (int64, error) { 20 | return res.lastInsertId, res.lastInsertIdError 21 | } 22 | 23 | func (res *Result) RowsAffected() (int64, error) { 24 | return res.rowsAffected, res.rowsAffectedError 25 | } 26 | -------------------------------------------------------------------------------- /tx.go: -------------------------------------------------------------------------------- 1 | package testdb 2 | 3 | type Tx struct { 4 | commitFunc func() error 5 | rollbackFunc func() error 6 | } 7 | 8 | func (t *Tx) Commit() error { 9 | if t.commitFunc != nil { 10 | return t.commitFunc() 11 | } 12 | return nil 13 | } 14 | 15 | func (t *Tx) Rollback() error { 16 | if t.rollbackFunc != nil { 17 | return t.rollbackFunc() 18 | } 19 | return nil 20 | } 21 | 22 | func (t *Tx) SetCommitFunc(f func() error) { 23 | t.commitFunc = f 24 | } 25 | 26 | func (t *Tx) StubCommitError(err error) { 27 | t.SetCommitFunc(func() error { 28 | return err 29 | }) 30 | } 31 | 32 | func (t *Tx) SetRollbackFunc(f func() error) { 33 | t.rollbackFunc = f 34 | } 35 | 36 | func (t *Tx) StubRollbackError(err error) { 37 | t.SetRollbackFunc(func() error { 38 | return err 39 | }) 40 | } 41 | -------------------------------------------------------------------------------- /rows.go: -------------------------------------------------------------------------------- 1 | package testdb 2 | 3 | import ( 4 | "database/sql/driver" 5 | "io" 6 | ) 7 | 8 | type rows struct { 9 | closed bool 10 | columns []string 11 | rows [][]driver.Value 12 | pos int 13 | } 14 | 15 | func (rs *rows) clone() *rows { 16 | if rs == nil { 17 | return nil 18 | } 19 | 20 | return &rows{closed: false, columns: rs.columns, rows: rs.rows, pos: 0} 21 | } 22 | 23 | func (rs *rows) Next(dest []driver.Value) error { 24 | rs.pos++ 25 | if rs.pos > len(rs.rows) { 26 | rs.closed = true 27 | 28 | return io.EOF // per interface spec 29 | } 30 | 31 | for 
i, col := range rs.rows[rs.pos-1] { 32 | dest[i] = col 33 | } 34 | 35 | return nil 36 | } 37 | 38 | func (rs *rows) Err() error { 39 | return nil 40 | } 41 | 42 | func (rs *rows) Columns() []string { 43 | return rs.columns 44 | } 45 | 46 | func (rs *rows) Close() error { 47 | return nil 48 | } 49 | -------------------------------------------------------------------------------- /tx_test.go: -------------------------------------------------------------------------------- 1 | package testdb 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | ) 7 | 8 | func TestTxSetCommitFunc(t *testing.T) { 9 | tx := &Tx{} 10 | 11 | tx.SetCommitFunc(func() error { 12 | return errors.New("commit failed") 13 | }) 14 | 15 | err := tx.Commit() 16 | 17 | if err == nil || err.Error() != "commit failed" { 18 | t.Fatal("stubbed commit did not return expected error") 19 | } 20 | } 21 | 22 | func TestTxStubCommitError(t *testing.T) { 23 | tx := &Tx{} 24 | 25 | tx.StubCommitError(errors.New("commit failed")) 26 | 27 | err := tx.Commit() 28 | 29 | if err == nil || err.Error() != "commit failed" { 30 | t.Fatal("stubbed commit did not return expected error") 31 | } 32 | } 33 | 34 | func TestTxSetRollbackFunc(t *testing.T) { 35 | tx := &Tx{} 36 | 37 | tx.SetRollbackFunc(func() error { 38 | return errors.New("rollback failed") 39 | }) 40 | 41 | err := tx.Rollback() 42 | 43 | if err == nil || err.Error() != "rollback failed" { 44 | t.Fatal("stubbed rollback did not return expected error") 45 | } 46 | } 47 | 48 | func TestTxStubRollbackError(t *testing.T) { 49 | tx := &Tx{} 50 | 51 | tx.StubRollbackError(errors.New("rollback failed")) 52 | 53 | err := tx.Rollback() 54 | 55 | if err == nil || err.Error() != "rollback failed" { 56 | t.Fatal("stubbed rollback did not return expected error") 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2013, Erik 
St. Martin 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
24 |
--------------------------------------------------------------------------------
/conn.go:
--------------------------------------------------------------------------------
1 | package testdb
2 |
3 | import (
4 | 	"database/sql/driver"
5 | 	"errors"
6 | )
7 |
8 | // conn is the stub driver.Conn that testDriver.Open hands out (unless
9 | // SetOpenFunc replaces Open). It holds the stubbed queries and any callbacks
10 | // installed with the Set*/Stub* helpers.
11 | type conn struct {
12 | 	queries      map[string]query
13 | 	queryFunc    func(query string, args []driver.Value) (driver.Rows, error)
14 | 	execFunc     func(query string, args []driver.Value) (driver.Result, error)
15 | 	beginFunc    func() (driver.Tx, error)
16 | 	commitFunc   func() error
17 | 	rollbackFunc func() error
18 | }
19 |
20 | // newConn returns a conn with an empty stub table and no callbacks.
21 | func newConn() *conn {
22 | 	return &conn{
23 | 		queries: make(map[string]query),
24 | 	}
25 | }
26 |
27 | // stubbedRows returns the rows and error registered in q. Rows built by this
28 | // package are cloned so every caller starts reading from the first row.
29 | func stubbedRows(q query) (driver.Rows, error) {
30 | 	if r, ok := q.rows.(*rows); ok {
31 | 		return r.clone(), q.err
32 | 	}
33 | 	return q.rows, q.err
34 | }
35 |
36 | // Prepare builds a stmt for query. Callbacks set with SetQueryFunc /
37 | // SetExecFunc (and their WithArgs variants) take precedence; otherwise the
38 | // statement replays what StubQuery, StubQueryError, StubExec or
39 | // StubExecError registered. An error is returned only when neither a query
40 | // nor an exec path could be wired up.
41 | //
42 | // NOTE: like Query and Exec, stubs are looked up on the global d.conn (which
43 | // Reset replaces), whereas the callbacks are read from c.
44 | func (c *conn) Prepare(query string) (driver.Stmt, error) {
45 | 	s := new(stmt)
46 |
47 | 	if c.queryFunc != nil {
48 | 		s.queryFunc = func(args []driver.Value) (driver.Rows, error) {
49 | 			return c.queryFunc(query, args)
50 | 		}
51 | 	}
52 |
53 | 	if c.execFunc != nil {
54 | 		s.execFunc = func(args []driver.Value) (driver.Result, error) {
55 | 			return c.execFunc(query, args)
56 | 		}
57 | 	}
58 |
59 | 	if q, ok := d.conn.queries[getQueryHash(query)]; ok {
60 | 		// Error-only stubs (StubQueryError / StubExecError) must be wired up
61 | 		// as well; otherwise the caller would get "Query not stubbed" instead
62 | 		// of the error it stubbed.
63 | 		if s.queryFunc == nil && (q.rows != nil || q.err != nil) {
64 | 			s.queryFunc = func(args []driver.Value) (driver.Rows, error) {
65 | 				return stubbedRows(q)
66 | 			}
67 | 		}
68 |
69 | 		if s.execFunc == nil && (q.result != nil || q.err != nil) {
70 | 			s.execFunc = func(args []driver.Value) (driver.Result, error) {
71 | 				if q.result != nil {
72 | 					return q.result, nil
73 | 				}
74 | 				return nil, q.err
75 | 			}
76 | 		}
77 | 	}
78 |
79 | 	if s.queryFunc == nil && s.execFunc == nil {
80 | 		return nil, errors.New("Query not stubbed: " + query)
81 | 	}
82 |
83 | 	return s, nil
84 | }
85 |
86 | // Close is a no-op: Open keeps handing out the same conn afterwards.
87 | func (*conn) Close() error {
88 | 	return nil
89 | }
90 |
91 | // Begin returns the result of SetBeginFunc/StubBegin if one is set;
92 | // otherwise a fresh Tx wired to any SetCommitFunc/SetRollbackFunc callbacks.
93 | func (c *conn) Begin() (driver.Tx, error) {
94 | 	if c.beginFunc != nil {
95 | 		return c.beginFunc()
96 | 	}
97 |
98 | 	t := &Tx{}
99 | 	if c.commitFunc != nil {
100 | 		t.SetCommitFunc(c.commitFunc)
101 | 	}
102 | 	if c.rollbackFunc != nil {
103 | 		t.SetRollbackFunc(c.rollbackFunc)
104 | 	}
105 |
106 | 	return t, nil
107 | }
108 |
109 | // Query runs query without a prepared statement. A SetQueryFunc callback
110 | // wins; otherwise the stub registered for query is replayed. A stub with
111 | // neither rows nor an error (e.g. one set only by StubExec) counts as not
112 | // stubbed, because returning nil rows with a nil error would hand the sql
113 | // package a nil driver.Rows to iterate.
114 | func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
115 | 	if c.queryFunc != nil {
116 | 		return c.queryFunc(query, args)
117 | 	}
118 | 	if q, ok := d.conn.queries[getQueryHash(query)]; ok && (q.rows != nil || q.err != nil) {
119 | 		return stubbedRows(q)
120 | 	}
121 | 	return nil, errors.New("Query not stubbed: " + query)
122 | }
123 |
124 | // Exec runs query without a prepared statement. A SetExecFunc callback wins;
125 | // otherwise the Result (or, failing that, the error) stubbed for query is
126 | // returned.
127 | func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) {
128 | 	if c.execFunc != nil {
129 | 		return c.execFunc(query, args)
130 | 	}
131 |
132 | 	if q, ok := d.conn.queries[getQueryHash(query)]; ok {
133 | 		switch {
134 | 		case q.result != nil:
135 | 			return q.result, nil
136 | 		case q.err != nil:
137 | 			return nil, q.err
138 | 		}
139 | 	}
140 |
141 | 	return nil, errors.New("Exec call not stubbed: " + query)
142 | }
143 |
--------------------------------------------------------------------------------
/testdb.go:
--------------------------------------------------------------------------------
1 | package testdb
2 |
3 | import (
4 | 	"crypto/sha1"
5 | 	"database/sql"
6 | 	"database/sql/driver"
7 | 	"encoding/csv"
8 | 	"io"
9 | 	"regexp"
10 | 	"strings"
11 | 	"time"
12 | )
13 |
14 | // d is the driver instance registered under the name "testdb". The
15 | // package-level Set*/Stub* functions below configure it.
16 | var d *testDriver
17 |
18 | func init() {
19 | 	d = newDriver()
20 | 	sql.Register("testdb", d)
21 | }
22 |
23 | type testDriver struct {
24 | 	openFunc          func(dsn string) (driver.Conn, error)
25 | 	conn              *conn
26 | 	enableTimeParsing bool
27 | }
28 |
29 | // query is a single stubbed response, stored in conn.queries under the key
30 | // produced by getQueryHash.
31 | type query struct {
32 | 	rows   driver.Rows
33 | 	result *Result
34 | 	err    error
35 | }
36 |
37 | func newDriver() *testDriver {
38 | 	return &testDriver{
39 | 		conn: newConn(),
40 | 	}
41 | }
42 |
43 | // EnableTimeParsing controls whether RowsFromCSVString turns fields that parse
44 | // as RFC 3339 timestamps into time.Time values. Reset does not change it.
45 | func EnableTimeParsing(flag bool) {
46 | 	d.enableTimeParsing = flag
47 | }
48 |
49 | // Open returns whatever the SetOpenFunc callback returns if one is set;
50 | // otherwise the shared stub conn, created on demand.
51 | func (d *testDriver) Open(dsn string) (driver.Conn, error) {
52 | 	if d.openFunc != nil {
53 | 		return d.openFunc(dsn)
54 | 	}
55 |
56 | 	if d.conn == nil {
57 | 		d.conn = newConn()
58 | 	}
59 |
60 | 	return d.conn, nil
61 | }
62 |
63 | var whitespaceRegexp = regexp.MustCompile(`\s`)
64 |
65 | // getQueryHash returns the map key for query: the raw SHA-1 digest of the
66 | // query with every whitespace character removed and letters lowercased, so
67 | // stubs match regardless of case or formatting. (Because whitespace is
68 | // removed rather than collapsed, "select a" and "selecta" share a key.)
69 | func getQueryHash(query string) string {
70 | 	query = strings.ToLower(whitespaceRegexp.ReplaceAllString(query, ""))
71 |
72 | 	h := sha1.New()
73 | 	io.WriteString(h, query) // writes to a hash.Hash never return an error
74 |
75 | 	return string(h.Sum(nil))
76 | }
77 |
78 | // SetQueryFunc sets the function run whenever db.Query() is called. As with
79 | // StubQuery(), RowsFromCSVString() can build the driver.Rows, or you can
80 | // return your own.
81 | func SetQueryFunc(f func(query string) (result driver.Rows, err error)) {
82 | 	SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) {
83 | 		return f(query)
84 | 	})
85 | }
86 |
87 | // SetQueryWithArgsFunc is like SetQueryFunc, but f also receives the query's
88 | // arguments.
89 | func SetQueryWithArgsFunc(f func(query string, args []driver.Value) (result driver.Rows, err error)) {
90 | 	d.conn.queryFunc = f
91 | }
92 |
93 | // StubQuery makes db.Query() return rows for q. Matching ignores case and
94 | // whitespace; a later stub for the same query replaces this one.
95 | func StubQuery(q string, rows driver.Rows) {
96 | 	d.conn.queries[getQueryHash(q)] = query{
97 | 		rows: rows,
98 | 	}
99 | }
100 |
101 | // StubQueryError makes db.Query() return err for q. Matching ignores case and
102 | // whitespace; a later stub for the same query replaces this one.
103 | func StubQueryError(q string, err error) {
104 | 	d.conn.queries[getQueryHash(q)] = query{
105 | 		err: err,
106 | 	}
107 | }
108 |
109 | // SetOpenFunc sets the function run when the driver's Open is called. It can
110 | // hand back a valid connection or an error; Conn() returns the global conn
111 | // that holds the stubbed queries.
112 | func SetOpenFunc(f func(dsn string) (driver.Conn, error)) {
113 | 	d.openFunc = f
114 | }
115 |
116 | // SetExecFunc sets the function run whenever db.Exec() is called. It can
117 | // return an error or a Result carrying LastInsertId and RowsAffected.
118 | func SetExecFunc(f func(query string) (driver.Result, error)) {
119 | 	SetExecWithArgsFunc(func(query string, args []driver.Value) (driver.Result, error) {
120 | 		return f(query)
121 | 	})
122 | }
123 |
124 | // SetExecWithArgsFunc is like SetExecFunc, but f also receives the
125 | // statement's arguments.
126 | func SetExecWithArgsFunc(f func(query string, args []driver.Value) (driver.Result, error)) {
127 | 	d.conn.execFunc = f
128 | }
129 |
130 | // StubExec makes db.Exec() return r for q. Matching ignores case and
131 | // whitespace; a later stub for the same query replaces this one.
132 | func StubExec(q string, r *Result) {
133 | 	d.conn.queries[getQueryHash(q)] = query{
134 | 		result: r,
135 | 	}
136 | }
137 |
138 | // StubExecError makes db.Exec() return err for q. It registers the same stub
139 | // as StubQueryError, so db.Query() on q returns err too.
140 | func StubExecError(q string, err error) {
141 | 	StubQueryError(q, err)
142 | }
143 |
144 | // SetBeginFunc sets the function run when db.Begin() is called. It can hand
145 | // back a transaction or an error.
146 | func SetBeginFunc(f func() (driver.Tx, error)) {
147 | 	d.conn.beginFunc = f
148 | }
149 |
150 | // StubBegin makes db.Begin() return tx and err.
151 | func StubBegin(tx driver.Tx, err error) {
152 | 	SetBeginFunc(func() (driver.Tx, error) {
153 | 		return tx, err
154 | 	})
155 | }
156 |
157 | // SetCommitFunc sets the function run by Commit on the default transaction,
158 | // i.e. the Tx that Begin creates when no SetBeginFunc/StubBegin is in effect.
159 | func SetCommitFunc(f func() error) {
160 | 	d.conn.commitFunc = f
161 | }
162 |
163 | // StubCommitError makes Commit on the default transaction return err.
164 | func StubCommitError(err error) {
165 | 	SetCommitFunc(func() error {
166 | 		return err
167 | 	})
168 | }
169 |
170 | // SetRollbackFunc sets the function run by Rollback on the default
171 | // transaction (see SetCommitFunc).
172 | func SetRollbackFunc(f func() error) {
173 | 	d.conn.rollbackFunc = f
174 | }
175 |
176 | // StubRollbackError makes Rollback on the default transaction return err.
177 | func StubRollbackError(err error) {
178 | 	SetRollbackFunc(func() error {
179 | 		return err
180 | 	})
181 | }
182 |
183 | // Reset discards every stubbed query and every function installed with the
184 | // Set*/Stub* helpers by replacing the global conn, and clears SetOpenFunc.
185 | // The EnableTimeParsing setting is left unchanged.
186 | func Reset() {
187 | 	d.conn = newConn()
188 | 	d.openFunc = nil
189 | }
190 |
191 | // Conn returns the global conn used by this driver (the one holding the
192 | // stubbed queries).
193 | func Conn() driver.Conn {
194 | 	return d.conn
195 | }
196 |
197 | // RowsFromCSVString builds driver.Rows with the given column names from the
198 | // CSV text s. An optional rune c[0] replaces ',' as the field separator.
199 | // Surrounding whitespace is trimmed from s and from every field; with
200 | // EnableTimeParsing(true), fields that parse as RFC 3339 become time.Time.
201 | //
202 | // Reading stops silently at the first CSV error (for example a record whose
203 | // field count differs from the first record's); rows read so far are kept.
204 | // Each record must have at most len(columns) fields.
205 | func RowsFromCSVString(columns []string, s string, c ...rune) driver.Rows {
206 | 	csvReader := csv.NewReader(strings.NewReader(strings.TrimSpace(s)))
207 | 	if len(c) > 0 {
208 | 		csvReader.Comma = c[0]
209 | 	}
210 |
211 | 	data := [][]driver.Value{}
212 | 	for {
213 | 		record, err := csvReader.Read()
214 | 		if err != nil || record == nil {
215 | 			break
216 | 		}
217 |
218 | 		row := make([]driver.Value, len(columns))
219 | 		for i, field := range record {
220 | 			v := strings.TrimSpace(field)
221 | 			row[i] = v
222 |
223 | 			if d.enableTimeParsing {
224 | 				if t, err := time.Parse(time.RFC3339, v); err == nil {
225 | 					row[i] = t
226 | 				}
227 | 			}
228 | 		}
229 |
230 | 		data = append(data, row)
231 | 	}
232 |
233 | 	return RowsFromSlice(columns, data)
234 | }
235 |
236 | // RowsFromSlice builds driver.Rows with the given column names over data.
237 | // data is used as-is (not copied); each row should hold len(columns) values.
238 | func RowsFromSlice(columns []string, data [][]driver.Value) driver.Rows {
239 | 	return &rows{
240 | 		closed:  false,
241 | 		columns: columns,
242 | 		rows:    data,
243 | 		pos:     0,
244 | 	}
245 | }
246 |
--------------------------------------------------------------------------------
/examples_test.go:
--------------------------------------------------------------------------------
1 | package testdb
2 |
3 | import (
4 | 	"database/sql"
5 | 	"database/sql/driver"
6 | 	"errors"
7 | 	"fmt"
8 | 	"strconv"
9 | )
10 |
11 | type user struct {
12 | 	id      int64
13 | 	name    string
14 | 	age     int64
15 | 	created string
16 | 	data    string
17 | }
18 |
19 | func ExampleSetOpenFunc() {
20 | 	defer Reset()
21 |
22 | 	SetOpenFunc(func(dsn string) (driver.Conn, error) {
23 | 		// Conn() will return the same internal driver.Conn being used by the driver
24 | 		return Conn(), errors.New("test error")
25 | 	})
26 |
27 | 	// err only returns from this if it's an unknown driver, we are stubbing opening a connection
28 | 	db, _ := sql.Open("testdb", "foo")
29 | 	_, err := db.Driver().Open("foo")
30 |
31 | 	if err != nil {
32 | 		fmt.Println("Stubbed error returned as expected: " + err.Error())
33 | 	}
34 |
35 | 	// Output:
36 | 	// Stubbed error returned as expected: test error
37 | }
38 |
39 | func ExampleRowsFromCSVString() {
40 | 	columns := []string{"id", "name", "age", "created"}
41 | 	result := `
42 | 1,tim,20,2012-10-01 01:00:01
43 | 2,joe,25,2012-10-02 02:00:02
44 | 3,bob,30,2012-10-03 03:00:03
45 | `
46 | 	rows := RowsFromCSVString(columns, result)
47 |
48 | 	fmt.Println(rows.Columns())
49 |
50 | 	// Output:
51 | 	// [id name age created]
52 | }
53 |
54 | func ExampleStubQuery() {
55 | 	defer Reset()
56 |
57 | 	db, _ := sql.Open("testdb", "")
58 |
59 | 	sql := "select id, name, age from users"
60 | 	columns := []string{"id", "name", "age", "created"}
61 | 	result := `
62 | 1,tim,20,2012-10-01 01:00:01
63 | 2,joe,25,2012-10-02 02:00:02
64 | 3,bob,30,2012-10-03 03:00:03
65 | `
66 | 	StubQuery(sql, RowsFromCSVString(columns, result))
67 |
68 | 	res, _ := db.Query(sql)
69 |
70 | 	for res.Next() {
71 | 		var u = new(user)
72 | 		res.Scan(&u.id, &u.name, &u.age, &u.created)
73 |
74 | 		fmt.Println(u.name + " - " + strconv.FormatInt(u.age, 10))
75 | 	}
76 |
77 | 	// Output:
78 | 	// tim - 20
79 | 	// joe - 25
80 | 	// bob - 30
81 | }
82 |
83 | func ExampleStubQuery_queryRow() {
84 | 	defer Reset()
85 |
86 | 	db, _ := sql.Open("testdb", "")
87 |
88 | 	sql := "select id, name, age from users"
89 | 	columns := []string{"id", "name", "age", "created"}
90 | 	result := `
91 | 1,tim,20,2012-10-01 01:00:01
92 | `
93 | 	StubQuery(sql, RowsFromCSVString(columns, result))
94 |
95 | 	row := db.QueryRow(sql)
96 |
97 | 	u := new(user)
98 | 	row.Scan(&u.id, &u.name, &u.age, &u.created)
99 |
100 | 	fmt.Println(u.name + " - " + strconv.FormatInt(u.age, 10))
101 |
102 | 	// Output:
103 | 	// tim - 20
104 | }
105 |
106 | func ExampleStubQueryError() {
107 | 	defer Reset()
108 |
109 | 	db, _ := sql.Open("testdb", "")
110 |
111 | 	sql := "select count(*) from error"
112 |
113 | 	StubQueryError(sql, errors.New("test error"))
114 |
115 | 	_, err := db.Query(sql)
116 |
117 | 	if err != nil {
118 | 		fmt.Println("Error returned: " + err.Error())
119 | 	}
120 |
121 | 	// Output:
122 | 	// Error returned: test error
123 | }
124 |
125 | func ExampleSetQueryFunc() {
126 | 	defer Reset()
127 |
128 | 	columns := []string{"id", "name", "age", "created"}
129 | 	rows := "1,tim,20,2012-10-01 01:00:01\n2,joe,25,2012-10-02 02:00:02\n3,bob,30,2012-10-03 03:00:03"
130 |
131 | 	SetQueryFunc(func(query string) (result driver.Rows, err error) {
132 | 		return RowsFromCSVString(columns, rows), nil
133 | 	})
134 |
135 | 	db, _ := sql.Open("testdb", "")
136 |
137 | 	res, _ := db.Query("SELECT foo FROM bar")
138 |
139 | 	for res.Next() {
140 | 		var u = new(user)
141 | 		res.Scan(&u.id, &u.name, &u.age, &u.created)
142 |
143 | 		fmt.Println(u.name + " - " + strconv.FormatInt(u.age, 10))
144 | 	}
145 |
146 | 	// Output:
147 | 	// tim - 20
148 | 	// joe - 25
149 | 	// bob - 30
150 | }
151 |
152 | func ExampleSetQueryFunc_queryRow() {
153 | 	defer Reset()
154 |
155 | 	columns := []string{"id", "name", "age", "created"}
156 | 	rows := "1,tim,20,2012-10-01 01:00:01"
157 |
158 | 	SetQueryFunc(func(query string) (result driver.Rows, err error) {
159 | 		return RowsFromCSVString(columns, rows), nil
160 | 	})
161 |
162 | 	db, _ := sql.Open("testdb", "")
163 |
164 | 	row := db.QueryRow("SELECT foo FROM bar")
165 |
166 | 	var u = new(user)
167 | 	row.Scan(&u.id, &u.name, &u.age, &u.created)
168 |
169 | 	fmt.Println(u.name + " - " + strconv.FormatInt(u.age, 10))
170 |
171 | 	// Output:
172 | 	// tim - 20
173 | }
174 |
175 | func ExampleSetQueryWithArgsFunc() {
176 | 	defer Reset()
177 |
178 | 	SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) {
179 | 		columns := []string{"id", "name", "age", "created"}
180 |
181 | 		rows := ""
182 | 		if args[0] == "joe" {
183 | 			rows = "2,joe,25,2012-10-02 02:00:02"
184 | 		}
185 | 		return RowsFromCSVString(columns, rows), nil
186 | 	})
187 |
188 | 	db, _ := sql.Open("testdb", "")
189 |
190 | 	res, _ := db.Query("SELECT foo FROM bar WHERE name = $1", "joe")
191 |
192 | 	for res.Next() {
193 | 		var u = new(user)
194 | 		res.Scan(&u.id, &u.name, &u.age, &u.created)
195 |
196 | 		fmt.Println(u.name + " - " + strconv.FormatInt(u.age, 10))
197 | 	}
198 |
199 | 	// Output:
200 | 	// joe - 25
201 | }
202 |
203 | type testResult struct {
204 | 	lastId       int64
205 | 	affectedRows int64
206 | }
207 |
208 | func (r testResult) LastInsertId() (int64, error) {
209 | 	return r.lastId, nil
210 | }
211 |
212 | func (r testResult) RowsAffected() (int64, error) {
213 | 	return r.affectedRows, nil
214 | }
215 |
216 | func ExampleSetExecWithArgsFunc() {
217 | 	defer Reset()
218 |
219 | 	SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) {
220 | 		if args[0] == "joe" {
221 | 			return testResult{1, 1}, nil
222 | 		}
223 | 		return testResult{1, 0}, nil
224 | 	})
225 |
226 | 	db, _ := sql.Open("testdb", "")
227 |
228 | 	res, _ := db.Exec("UPDATE bar SET name = 'foo' WHERE name = ?", "joe")
229 |
230 | 	rowsAffected, _ := res.RowsAffected()
231 | 	fmt.Println("RowsAffected =", rowsAffected)
232 |
233 | 	// Output:
234 | 	// RowsAffected = 1
235 | }
236 |
237 | func ExampleSetBeginFunc() {
238 | 	defer Reset()
239 |
240 | 	commitCalled := false
241 | 	rollbackCalled := false
242 | 	SetBeginFunc(func() (txn driver.Tx, err error) {
243 | 		t := &Tx{}
244 | 		t.SetCommitFunc(func() error {
245 | 			commitCalled = true
246 | 			return nil
247 | 		})
248 | 		t.SetRollbackFunc(func() error {
249 | 			rollbackCalled = true
250 | 			return nil
251 | 		})
252 | 		return t, nil
253 | 	})
254 |
255 | 	db, _ := sql.Open("testdb", "")
256 | 	tx, _ := db.Begin()
257 | 	tx.Commit()
258 |
259 | 	fmt.Println("CommitCalled =", commitCalled)
260 | 	fmt.Println("RollbackCalled =", rollbackCalled)
261 |
262 | 	// Output:
263 | 	// CommitCalled = true
264 | 	// RollbackCalled = false
265 | }
266 |
267 | func ExampleSetCommitFunc() {
268 | 	defer Reset()
269 |
270 | 	SetCommitFunc(func() error {
271 | 		return errors.New("commit failed")
272 | 	})
273 |
274 | 	db, _ := sql.Open("testdb", "")
275 | 	tx, _ := db.Begin()
276 |
277 | 	fmt.Println("CommitResult =", tx.Commit())
278 |
279 | 	// Output:
280 | 	// CommitResult = commit failed
281 | }
282 |
283 | func ExampleSetRollbackFunc() {
284 | 	defer Reset()
285 |
286 | 	SetRollbackFunc(func() error {
287 | 		return errors.New("rollback failed")
288 | 	})
289 |
290 | 	db, _ := sql.Open("testdb", "")
291 | 	tx, _ := db.Begin()
292 |
293 | 	fmt.Println("RollbackResult =", tx.Rollback())
294 |
295 | 	// Output:
296 | 	// RollbackResult = rollback failed
297 | }
298 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | go-testdb
2 | =========
3 |
4 | Framework for stubbing responses from Go's driver.Driver interface.
5 |
6 | This can be used to sit in place of your sql.DB so that you can stub responses for SQL calls, and remove database dependencies for your test suite.
7 |
8 | This project is in its infancy, but has worked well for all the use cases I've had so far, and continues to evolve as new scenarios are uncovered. Please feel free to send pull requests, or submit feature requests if you have scenarios that are not accounted for yet.
9 |
10 | ## Setup
11 | The only thing needed for setup is to import the go-testdb package (a blank import is enough); then you can create your db connection specifying "testdb" as your driver.
12 |
13 | import (
14 | "database/sql"
15 | _ "github.com/erikstmartin/go-testdb"
16 | )
17 |
18 | db, _ := sql.Open("testdb", "")
19 |
20 |
21 | ## Stubbing connection failure
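22 | You can set your own function to run when the driver's Open is called, i.e. whenever the sql package opens a new connection (sql.Open may not connect by itself; db.Driver().Open calls it directly).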
23 |
24 | testdb.SetOpenFunc(func(dsn string) (driver.Conn, error) {
25 | return testdb.Conn(), errors.New("failed to connect")
26 | })
27 |
28 |
29 | ## Stubbing queries
30 | You can stub responses to known queries. Unknown queries return a "Query not stubbed" error, so you can see when a query that was not stubbed is executed.
31 |
32 | Differences in whitespace and case are ignored.
33 |
34 | For convenience, RowsFromCSVString turns a CSV string into a driver.Rows result object.
35 |
36 |
37 | db, _ := sql.Open("testdb", "")
38 |
39 | sql := "select id, name, age from users"
40 | columns := []string{"id", "name", "age", "created"}
41 | result := `
42 | 1,tim,20,2012-10-01 01:00:01
43 | 2,joe,25,2012-10-02 02:00:02
44 | 3,bob,30,2012-10-03 03:00:03
45 | `
46 | testdb.StubQuery(sql, testdb.RowsFromCSVString(columns, result))
47 |
48 | res, err := db.Query(sql)
49 |
50 |
51 | If you need a different field separator, pass the rune to use as the CSV `Comma` character as the third argument to RowsFromCSVString:
52 |
53 |
54 | db, _ := sql.Open("testdb", "")
55 |
56 | sql := "select id, name, age, data from users"
57 | columns := []string{"id", "name", "age", "data", "created"}
58 | result := `
59 | 1|tim|20|part_1,part_2,part_3|2014-10-16 15:01:00
60 | 2|joe|25|part_4,part_5,part_6|2014-10-17 15:01:01
61 | 3|bob|30|part_7,part_8,part_9|2014-10-18 15:01:02
62 | `
63 | testdb.StubQuery(sql, testdb.RowsFromCSVString(columns, result, '|'))
64 |
65 | res, err := db.Query(sql)
66 |
67 |
68 | ## Stubbing Query function
69 | Sometimes you need more control over the query being run; for example, you may need to assert whether or not a particular query is run.
70 |
71 | Your function can return either a driver.Rows (your own, or one built with RowsFromCSVString) or an error.
72 |
73 | testdb.SetQueryFunc(func(query string) (result driver.Rows, err error) {
74 | columns := []string{"id", "name", "age", "created"}
75 | rows := `
76 | 1,tim,20,2012-10-01 01:00:01
77 | 2,joe,25,2012-10-02 02:00:02
78 | 3,bob,30,2012-10-03 03:00:03`
79 |
80 | // inspect query to ensure it matches a pattern, or anything else you want to do first
81 | return testdb.RowsFromCSVString(columns, rows), nil
82 | })
83 |
84 | db, _ := sql.Open("testdb", "")
85 |
86 | res, err := db.Query("SELECT foo FROM bar")
87 |
88 |
89 | ## Stubbing Parameterized Query
90 | Sometimes you need control over the results of a parameterized query.
91 |
92 |
93 | testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) {
94 | columns := []string{"id", "name", "age", "created"}
95 |
96 | rows := ""
97 | if args[0] == "joe" {
98 | rows = "2,joe,25,2012-10-02 02:00:02"
99 | }
100 | return testdb.RowsFromCSVString(columns, rows), nil
101 | })
102 |
103 | db, _ := sql.Open("testdb", "")
104 |
105 | res, _ := db.Query("SELECT foo FROM bar WHERE name = $1", "joe")
106 |
107 |
108 | ## Stubbing errors returned from queries
109 | You can also stub the error returned from a query, to make sure your code handles it properly:
110 |
111 |
112 | db, _ := sql.Open("testdb", "")
113 |
114 | sql := "select count(*) from error"
115 | testdb.StubQueryError(sql, errors.New("test error"))
116 |
117 | res, err := db.Query(sql)
118 |
119 |
120 | ## Stubbing Parameterized Exec query
121 | Sometimes you need control over the handling of a parameterized query that does not return any rows.
122 |
123 |
124 | type testResult struct{
125 | lastId int64
126 | affectedRows int64
127 | }
128 |
129 | func (r testResult) LastInsertId() (int64, error){
130 | return r.lastId, nil
131 | }
132 |
133 | func (r testResult) RowsAffected() (int64, error) {
134 | return r.affectedRows, nil
135 | }
136 | testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) {
137 | if args[0] == "joe" {
138 | return testResult{1, 1}, nil
139 | }
140 | return testResult{1, 0}, nil
141 | })
142 |
143 | db, _ := sql.Open("testdb", "")
144 |
145 | res, _ := db.Exec("UPDATE bar SET name = 'foo' WHERE name = ?", "joe")
146 |
147 |
148 | ## Stubbing Prepared Statements
149 | `SetQueryFunc` and `SetQueryWithArgsFunc` also apply to prepared statements:
150 |
151 |
152 | testdb.SetQueryFunc(func(query string) (result driver.Rows, err error) {
153 | columns := []string{"id", "name", "age", "created"}
154 | rows := `
155 | 1,tim,20,2012-10-01 01:00:01
156 | 2,joe,25,2012-10-02 02:00:02
157 | 3,bob,30,2012-10-03 03:00:03`
158 |
159 | // inspect query to ensure it matches a pattern, or anything else you want to do first
160 | return testdb.RowsFromCSVString(columns, rows), nil
161 | })
162 |
163 | db, _ := sql.Open("testdb", "")
164 |
165 | stmt, _ := db.Prepare("SELECT foo FROM bar")
166 | res, err := stmt.Query()
167 |
168 |
169 | ## Reset
170 | At any point in your test, or in a defer, you can remove all stubbed queries and errors, along with any custom Query, Exec, Begin or Open functions, by calling testdb.Reset().
171 |
172 |
173 | func TestMyDatabase(t *testing.T){
174 | defer testdb.Reset()
175 | }
176 |
177 |
178 | #### TODO
179 | Feel free to contribute and send pull requests
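180 | - Richer transaction support (Begin, Commit and Rollback can already be stubbed with SetBeginFunc, StubBegin, SetCommitFunc and SetRollbackFunc)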
181 |
182 | #### License
183 | Copyright (c) 2013, Erik St. Martin
184 | All rights reserved.
185 |
186 | Redistribution and use in source and binary forms, with or without
187 | modification, are permitted provided that the following conditions are met:
188 |
189 | * Redistributions of source code must retain the above copyright notice, this
190 | list of conditions and the following disclaimer.
191 |
192 | * Redistributions in binary form must reproduce the above copyright notice,
193 | this list of conditions and the following disclaimer in the documentation
194 | and/or other materials provided with the distribution.
195 |
196 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
197 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
198 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
199 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
200 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
201 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
202 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
203 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
204 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
205 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
206 |
--------------------------------------------------------------------------------
/testdb_test.go:
--------------------------------------------------------------------------------
1 | package testdb
2 |
3 | import (
4 | "database/sql"
5 | "database/sql/driver"
6 | "errors"
7 | "reflect"
8 | "testing"
9 | )
10 |
11 | func TestSetOpenFunc(t *testing.T) {
12 | defer Reset()
13 |
14 | SetOpenFunc(func(dsn string) (driver.Conn, error) {
15 | return Conn(), errors.New("test error")
16 | })
17 |
18 | // err only returns from this if it's an unknown driver, we are stubbing opening a connection
19 | db, _ := sql.Open("testdb", "foo")
20 | conn, err := db.Driver().Open("foo")
21 |
22 | if db == nil {
23 | t.Fatal("driver.Open not properly set: db was nil")
24 | }
25 |
26 | if conn == nil {
27 | t.Fatal("driver.Open not properly set: didn't connection")
28 | }
29 |
30 | if err.Error() != "test error" {
31 | t.Fatal("driver.Open not properly set: err was not returned properly")
32 | }
33 | }
34 |
35 | func TestStubQuery(t *testing.T) {
36 | defer Reset()
37 |
38 | db, _ := sql.Open("testdb", "")
39 |
40 | sql := "select count(*) from foo"
41 | columns := []string{"count"}
42 | result := `
43 | 5
44 | `
45 | StubQuery(sql, RowsFromCSVString(columns, result))
46 |
47 | res, err := db.Query(sql)
48 |
49 | if err != nil {
50 | t.Fatal("stubbed query should not return error")
51 | }
52 |
53 | if res.Next() {
54 | var count int64
55 | err = res.Scan(&count)
56 |
57 | if err != nil {
58 | t.Fatal(err)
59 | }
60 |
61 | if count != 5 {
62 | t.Fatal("failed to return count")
63 | }
64 | }
65 | }
66 |
67 | func TestStubQueryAdditionalWhitespace(t *testing.T) {
68 | defer Reset()
69 |
70 | db, _ := sql.Open("testdb", "")
71 |
72 | sqlWhitespace := "select count(*) from foo"
73 | sql := "select count(*) from foo"
74 | columns := []string{"count"}
75 | result := `
76 | 5
77 | `
78 | StubQuery(sqlWhitespace, RowsFromCSVString(columns, result))
79 |
80 | res, err := db.Query(sql)
81 |
82 | if err != nil {
83 | t.Fatal("stubbed query should not return error")
84 | }
85 |
86 | if res.Next() {
87 | var count int64
88 | err = res.Scan(&count)
89 |
90 | if err != nil {
91 | t.Fatal(err)
92 | }
93 |
94 | if count != 5 {
95 | t.Fatal("failed to return count")
96 | }
97 | }
98 | }
99 |
100 | func TestStubQueryChangeCase(t *testing.T) {
101 | defer Reset()
102 |
103 | db, _ := sql.Open("testdb", "")
104 |
105 | sqlCase := "SELECT COUNT(*) FROM foo"
106 | sql := "select count(*) from foo"
107 | columns := []string{"count"}
108 | result := `
109 | 5
110 | `
111 | StubQuery(sqlCase, RowsFromCSVString(columns, result))
112 |
113 | res, err := db.Query(sql)
114 |
115 | if err != nil {
116 | t.Fatal("stubbed query should not return error")
117 | }
118 |
119 | if res.Next() {
120 | var count int64
121 | err = res.Scan(&count)
122 |
123 | if err != nil {
124 | t.Fatal(err)
125 | }
126 |
127 | if count != 5 {
128 | t.Fatal("failed to return count")
129 | }
130 | }
131 | }
132 |
133 | func TestUnknownQuery(t *testing.T) {
134 | defer Reset()
135 |
136 | db, _ := sql.Open("testdb", "")
137 |
138 | sql := "select count(*) from foobar"
139 | _, err := db.Query(sql)
140 |
141 | if err == nil {
142 | t.Fatal("Unknown queries should fail")
143 | }
144 | }
145 |
146 | func TestStubQueryError(t *testing.T) {
147 | defer Reset()
148 |
149 | db, _ := sql.Open("testdb", "")
150 |
151 | sql := "select count(*) from error"
152 |
153 | StubQueryError(sql, errors.New("test error"))
154 |
155 | res, err := db.Query(sql)
156 |
157 | if err == nil {
158 | t.Fatal("failed to return error from stubbed query")
159 | }
160 |
161 | if res != nil {
162 | t.Fatal("result should be nil on error")
163 | }
164 | }
165 |
166 | func TestStubQueryRowError(t *testing.T) {
167 | defer Reset()
168 |
169 | db, _ := sql.Open("testdb", "")
170 |
171 | sql := "select count(*) from error"
172 |
173 | StubQueryError(sql, errors.New("test error"))
174 |
175 | row := db.QueryRow(sql)
176 | var count int64
177 | err := row.Scan(&count)
178 |
179 | if err == nil {
180 | t.Fatal("error not returned")
181 | }
182 | }
183 |
184 | func TestStubQueryMultipleResult(t *testing.T) {
185 | defer Reset()
186 |
187 | db, _ := sql.Open("testdb", "")
188 |
189 | sql := "select id, name, age from users"
190 | columns := []string{"id", "name", "age", "created"}
191 | result := `
192 | 1,tim,20,2012-10-01 01:00:01
193 | 2,joe,25,2012-10-02 02:00:02
194 | 3,bob,30,2012-10-03 03:00:03
195 | `
196 | StubQuery(sql, RowsFromCSVString(columns, result))
197 |
198 | res, err := db.Query(sql)
199 |
200 | if err != nil {
201 | t.Fatal("stubbed query should not return error")
202 | }
203 |
204 | i := 0
205 |
206 | for res.Next() {
207 | var u = user{}
208 | err = res.Scan(&u.id, &u.name, &u.age, &u.created)
209 |
210 | if err != nil {
211 | t.Fatal(err)
212 | }
213 |
214 | if u.id == 0 || u.name == "" || u.age == 0 || u.created == "" {
215 | t.Fatal("failed to populate object with result")
216 | }
217 | i++
218 | }
219 |
220 | if i != 3 {
221 | t.Fatal("failed to return proper number of results")
222 | }
223 | }
224 |
225 | func TestStubQueryMultipleResultWithCustomComma(t *testing.T) {
226 | defer Reset()
227 |
228 | db, _ := sql.Open("testdb", "")
229 |
230 | sql := "select id, name, age from users"
231 | columns := []string{"id", "name", "age", "data", "created"}
232 | result := `
233 | 1|tim|20|part_1,part_2,part_3|2014-10-16 15:01:00
234 | 2|joe|25|part_4,part_5,part_6|2014-10-17 15:01:01
235 | 3|bob|30|part_7,part_8,part_9|2014-10-18 15:01:02
236 | `
237 | StubQuery(sql, RowsFromCSVString(columns, result, '|'))
238 |
239 | res, err := db.Query(sql)
240 |
241 | if err != nil {
242 | t.Fatal("stubbed query should not return error")
243 | }
244 |
245 | i := 0
246 |
247 | for res.Next() {
248 | var u = user{}
249 | err = res.Scan(&u.id, &u.name, &u.age, &u.data, &u.created)
250 |
251 | if err != nil {
252 | t.Fatal(err)
253 | }
254 |
255 | if u.id == 0 || u.name == "" || u.age == 0 || u.data == "" || u.created == "" {
256 | t.Fatal("failed to populate object with result")
257 | }
258 | i++
259 | }
260 |
261 | if i != 3 {
262 | t.Fatal("failed to return proper number of results")
263 | }
264 | }
265 |
266 | func TestStubQueryMultipleResultNewline(t *testing.T) {
267 | defer Reset()
268 |
269 | db, _ := sql.Open("testdb", "")
270 |
271 | sql := "select id, name, age from users"
272 | columns := []string{"id", "name", "age", "created"}
273 | result := "1,tim,20,2012-10-01 01:00:01\n2,joe,25,2012-10-02 02:00:02\n3,bob,30,2012-10-03 03:00:03"
274 |
275 | StubQuery(sql, RowsFromCSVString(columns, result))
276 |
277 | res, err := db.Query(sql)
278 |
279 | if err != nil {
280 | t.Fatal("stubbed query should not return error")
281 | }
282 |
283 | i := 0
284 |
285 | for res.Next() {
286 | var u = user{}
287 | err = res.Scan(&u.id, &u.name, &u.age, &u.created)
288 |
289 | if err != nil {
290 | t.Fatal(err)
291 | }
292 |
293 | if u.id == 0 || u.name == "" || u.age == 0 || u.created == "" {
294 | t.Fatal("failed to populate object with result")
295 | }
296 | i++
297 | }
298 |
299 | if i != 3 {
300 | t.Fatal("failed to return proper number of results")
301 | }
302 | }
303 |
304 | func TestSetQueryFunc(t *testing.T) {
305 | defer Reset()
306 |
307 | columns := []string{"id", "name", "age", "created"}
308 | rows := "1,tim,20,2012-10-01 01:00:01\n2,joe,25,2012-10-02 02:00:02\n3,bob,30,2012-10-03 03:00:03"
309 |
310 | SetQueryFunc(func(query string) (result driver.Rows, err error) {
311 | return RowsFromCSVString(columns, rows), nil
312 | })
313 |
314 | if Conn().(*conn).queryFunc == nil {
315 | t.Fatal("query function not stubbed")
316 | }
317 |
318 | db, _ := sql.Open("testdb", "")
319 |
320 | res, err := db.Query("SELECT foo FROM bar")
321 |
322 | if err != nil {
323 | t.Fatal(err)
324 | }
325 |
326 | i := 0
327 |
328 | for res.Next() {
329 | var u = user{}
330 | err = res.Scan(&u.id, &u.name, &u.age, &u.created)
331 |
332 | if err != nil {
333 | t.Fatal(err)
334 | }
335 |
336 | if u.id == 0 || u.name == "" || u.age == 0 || u.created == "" {
337 | t.Fatal("failed to populate object with result")
338 | }
339 | i++
340 | }
341 |
342 | if i != 3 {
343 | t.Fatal("failed to return proper number of results")
344 | }
345 | }
346 |
347 | func TestSetQueryFuncError(t *testing.T) {
348 | defer Reset()
349 |
350 | SetQueryFunc(func(query string) (result driver.Rows, err error) {
351 | return nil, errors.New("stubbed error")
352 | })
353 |
354 | db, _ := sql.Open("testdb", "")
355 |
356 | _, err := db.Query("SELECT foo FROM bar")
357 |
358 | if err == nil {
359 | t.Fatal("failed to return error from QueryFunc")
360 | }
361 | }
362 |
363 | func TestReset(t *testing.T) {
364 | sql.Open("testdb", "")
365 |
366 | sql := "select count(*) from error"
367 | StubQueryError(sql, errors.New("test error"))
368 |
369 | Reset()
370 |
371 | if len(d.conn.queries) > 0 {
372 | t.Fatal("failed to reset connection")
373 | }
374 | }
375 |
376 | func TestStubQueryRow(t *testing.T) {
377 | defer Reset()
378 |
379 | db, _ := sql.Open("testdb", "")
380 |
381 | sql := "select count(*) from foo"
382 | columns := []string{"count"}
383 | result := `
384 | 5
385 | `
386 | StubQuery(sql, RowsFromCSVString(columns, result))
387 |
388 | row := db.QueryRow(sql)
389 |
390 | if row == nil {
391 | t.Fatal("stub query should have returned row")
392 | }
393 |
394 | var count int64
395 | err := row.Scan(&count)
396 |
397 | if err != nil {
398 | t.Fatal(err)
399 | }
400 |
401 | if count != 5 {
402 | t.Fatal("failed to return count")
403 | }
404 | }
405 |
406 | func TestStubQueryRowReuse(t *testing.T) {
407 | defer Reset()
408 |
409 | db, _ := sql.Open("testdb", "")
410 |
411 | sql := "select count(*) from foo"
412 | columns := []string{"count"}
413 | result := `
414 | 5
415 | `
416 | StubQuery(sql, RowsFromCSVString(columns, result))
417 |
418 | i := 0
419 | rows, _ := db.Query(sql)
420 | for rows.Next() {
421 | i++
422 | }
423 | if i != 1 {
424 | t.Fatal("stub query should have returned one row")
425 | }
426 |
427 | j := i
428 | moreRows, _ := db.Query(sql)
429 | for moreRows.Next() {
430 | j++
431 | }
432 |
433 | if i == j {
434 | t.Fatal("stub query did not return another set of rows")
435 | }
436 | }
437 |
438 | func TestSetQueryFuncRow(t *testing.T) {
439 | defer Reset()
440 |
441 | columns := []string{"id", "name", "age", "created"}
442 | rows := "1,tim,20,2012-10-01 01:00:01"
443 |
444 | SetQueryFunc(func(query string) (result driver.Rows, err error) {
445 | return RowsFromCSVString(columns, rows), nil
446 | })
447 |
448 | if Conn().(*conn).queryFunc == nil {
449 | t.Fatal("query function not stubbed")
450 | }
451 |
452 | db, _ := sql.Open("testdb", "")
453 |
454 | row := db.QueryRow("SELECT foo FROM bar")
455 |
456 | var u = user{}
457 | err := row.Scan(&u.id, &u.name, &u.age, &u.created)
458 |
459 | if err != nil {
460 | t.Fatal(err)
461 | }
462 |
463 | if u.id == 0 || u.name == "" || u.age == 0 || u.created == "" {
464 | t.Fatal("failed to populate object with result")
465 | }
466 | }
467 |
468 | func TestSetQueryFuncRowError(t *testing.T) {
469 | defer Reset()
470 |
471 | SetQueryFunc(func(query string) (result driver.Rows, err error) {
472 | return nil, errors.New("Stubbed error")
473 | })
474 |
475 | if Conn().(*conn).queryFunc == nil {
476 | t.Fatal("query function not stubbed")
477 | }
478 |
479 | db, _ := sql.Open("testdb", "")
480 |
481 | row := db.QueryRow("SELECT foo FROM bar")
482 |
483 | var u = user{}
484 | err := row.Scan(&u.id, &u.name, &u.age, &u.created)
485 |
486 | if err == nil {
487 | t.Fatal("Did not return error")
488 | }
489 | }
490 |
491 | func TestStubExec(t *testing.T) {
492 | defer Reset()
493 |
494 | db, _ := sql.Open("testdb", "")
495 |
496 | sql := "INSERT INTO foo SET (foo) VALUES (bar)"
497 | StubExec(sql, NewResult(5, errors.New("last insert error"), 3, errors.New("rows affected error")))
498 |
499 | res, err := db.Exec(sql)
500 |
501 | if err != nil {
502 | t.Fatal("stubbed exec call returned unexpected error")
503 | }
504 |
505 | var insertId int64
506 | insertId, err = res.LastInsertId()
507 | if insertId != 5 || err.Error() != "last insert error" {
508 | t.Fatal("stubbed exec did not return expected result")
509 | }
510 |
511 | var affected int64
512 | affected, err = res.RowsAffected()
513 |
514 | if affected != 3 || err.Error() != "rows affected error" {
515 | t.Fatal("stubbed exec did not return expected result")
516 | }
517 | }
518 |
519 | func TestStubExecError(t *testing.T) {
520 | defer Reset()
521 |
522 | db, _ := sql.Open("testdb", "")
523 |
524 | query := "INSERT INTO foo SET (foo) VALUES (bar)"
525 | StubExecError(query, errors.New("request failed"))
526 |
527 | res, err := db.Exec(query)
528 |
529 | if reflect.Indirect(reflect.ValueOf(res)).CanAddr() {
530 | t.Fatal("stubbed exec returned unexpected result")
531 | }
532 |
533 | if err == nil || err.Error() != "request failed" {
534 | t.Fatal("stubbed exec call did not return expected error")
535 | }
536 | }
537 |
538 | func TestStubExecFunc(t *testing.T) {
539 | defer Reset()
540 |
541 | db, _ := sql.Open("testdb", "")
542 |
543 | query := "INSERT INTO foo SET (foo) VALUES (bar)"
544 | result := NewResult(5, errors.New("last insert error"), 3, errors.New("rows affected error"))
545 |
546 | SetExecFunc(func(query string) (driver.Result, error) {
547 | return result, nil
548 | })
549 |
550 | res, err := db.Exec(query)
551 |
552 | if err != nil {
553 | t.Fatal("stubbed exec returned unexpected error")
554 | }
555 |
556 | var insertId int64
557 | insertId, err = res.LastInsertId()
558 | if insertId != 5 || err.Error() != "last insert error" {
559 | t.Fatal("stubbed exec did not return expected result")
560 | }
561 |
562 | var affected int64
563 | affected, err = res.RowsAffected()
564 |
565 | if affected != 3 || err.Error() != "rows affected error" {
566 | t.Fatal("stubbed exec did not return expected result")
567 | }
568 | }
569 |
570 | func TestStubExecFuncError(t *testing.T) {
571 | defer Reset()
572 |
573 | db, _ := sql.Open("testdb", "")
574 |
575 | query := "INSERT INTO foo SET (foo) VALUES (bar)"
576 |
577 | SetExecFunc(func(query string) (driver.Result, error) {
578 | return nil, errors.New("request failed")
579 | })
580 |
581 | res, err := db.Exec(query)
582 |
583 | if res != nil {
584 | t.Fatal("stubbed exec unexpected result")
585 | }
586 |
587 | if err == nil || err.Error() != "request failed" {
588 | t.Fatal("stubbed exec did not return expected error")
589 | }
590 | }
591 |
592 | func TestSetBeginFunc(t *testing.T) {
593 | defer Reset()
594 |
595 | db, _ := sql.Open("testdb", "")
596 |
597 | SetBeginFunc(func() (driver.Tx, error) {
598 | return nil, errors.New("begin failed")
599 | })
600 |
601 | res, err := db.Begin()
602 |
603 | if res != nil {
604 | t.Fatal("stubbed begin unexpected result")
605 | }
606 |
607 | if err == nil || err.Error() != "begin failed" {
608 | t.Fatal("stubbed begin did not return expected error")
609 | }
610 | }
611 |
612 | func TestStubBegin(t *testing.T) {
613 | defer Reset()
614 |
615 | db, _ := sql.Open("testdb", "")
616 |
617 | StubBegin(nil, errors.New("begin failed"))
618 | res, err := db.Begin()
619 |
620 | if res != nil {
621 | t.Fatal("stubbed begin unexpected result")
622 | }
623 |
624 | if err == nil || err.Error() != "begin failed" {
625 | t.Fatal("stubbed begin did not return expected error")
626 | }
627 | }
628 |
629 | func TestSetCommitFunc(t *testing.T) {
630 | defer Reset()
631 |
632 | db, _ := sql.Open("testdb", "")
633 |
634 | SetCommitFunc(func() error {
635 | return errors.New("commit failed")
636 | })
637 |
638 | tx, err := db.Begin()
639 |
640 | if tx == nil {
641 | t.Fatal("begin expected result")
642 | }
643 |
644 | if err != nil {
645 | t.Fatal("begin returned unexpected error")
646 | }
647 |
648 | err = tx.Commit()
649 |
650 | if err == nil || err.Error() != "commit failed" {
651 | t.Fatal("stubbed commit did not return expected error")
652 | }
653 | }
654 |
655 | func TestStubCommitError(t *testing.T) {
656 | defer Reset()
657 |
658 | db, _ := sql.Open("testdb", "")
659 |
660 | StubCommitError(errors.New("commit failed"))
661 |
662 | tx, err := db.Begin()
663 |
664 | if tx == nil {
665 | t.Fatal("begin expected result")
666 | }
667 |
668 | if err != nil {
669 | t.Fatal("begin returned unexpected error")
670 | }
671 |
672 | err = tx.Commit()
673 |
674 | if err == nil || err.Error() != "commit failed" {
675 | t.Fatal("stubbed commit did not return expected error")
676 | }
677 | }
678 |
679 | func TestSetRollbackFunc(t *testing.T) {
680 | defer Reset()
681 |
682 | db, _ := sql.Open("testdb", "")
683 |
684 | SetRollbackFunc(func() error {
685 | return errors.New("rollback failed")
686 | })
687 |
688 | tx, err := db.Begin()
689 |
690 | if tx == nil {
691 | t.Fatal("begin expected result")
692 | }
693 |
694 | if err != nil {
695 | t.Fatal("begin returned unexpected error")
696 | }
697 |
698 | err = tx.Rollback()
699 |
700 | if err == nil || err.Error() != "rollback failed" {
701 | t.Fatal("stubbed rollback did not return expected error")
702 | }
703 | }
704 |
705 | func TestStubRollbackError(t *testing.T) {
706 | defer Reset()
707 |
708 | db, _ := sql.Open("testdb", "")
709 |
710 | StubRollbackError(errors.New("rollback failed"))
711 |
712 | tx, err := db.Begin()
713 |
714 | if tx == nil {
715 | t.Fatal("begin expected result")
716 | }
717 |
718 | if err != nil {
719 | t.Fatal("begin returned unexpected error")
720 | }
721 |
722 | err = tx.Rollback()
723 |
724 | if err == nil || err.Error() != "rollback failed" {
725 | t.Fatal("stubbed rollback did not return expected error")
726 | }
727 | }
728 |
--------------------------------------------------------------------------------