├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── builder.go ├── builder_mssql.go ├── builder_mssql_test.go ├── builder_mysql.go ├── builder_mysql_test.go ├── builder_oci.go ├── builder_oci_test.go ├── builder_pgsql.go ├── builder_pgsql_test.go ├── builder_sqlite.go ├── builder_sqlite_test.go ├── builder_standard.go ├── builder_standard_test.go ├── db.go ├── db_test.go ├── example_test.go ├── expression.go ├── expression_test.go ├── go.mod ├── go.sum ├── model_query.go ├── model_query_test.go ├── query.go ├── query_builder.go ├── query_builder_test.go ├── query_test.go ├── rows.go ├── select.go ├── select_test.go ├── struct.go ├── struct_test.go ├── testdata └── mysql.sql └── tx.go /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | 3 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 4 | *.o 5 | *.a 6 | *.so 7 | 8 | # Folders 9 | _obj 10 | _test 11 | 12 | # Architecture specific extensions/prefixes 13 | *.[568vq] 14 | [568vq].out 15 | 16 | *.cgo1.go 17 | *.cgo2.c 18 | _cgo_defun.c 19 | _cgo_gotypes.go 20 | _cgo_export.* 21 | 22 | _testmain.go 23 | 24 | *.exe 25 | *.test 26 | *.prof 27 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | dist: bionic 2 | 3 | language: go 4 | 5 | go: 6 | - 1.13.x 7 | 8 | services: 9 | - mysql 10 | 11 | install: 12 | - go get golang.org/x/tools/cmd/cover 13 | - go get github.com/mattn/goveralls 14 | - go get golang.org/x/lint/golint 15 | 16 | before_script: 17 | - mysql -e 'CREATE DATABASE pocketbase_dbx_test;'; 18 | 19 | script: 20 | - test -z "`gofmt -l -d .`" 21 | - go test -v -covermode=count -coverprofile=coverage.out 22 | - $HOME/gopath/bin/goveralls -coverprofile=coverage.out -service=travis-ci 23 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | Copyright (c) 2016, Qiang Xue 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software 5 | and associated documentation files (the "Software"), to deal in the Software without restriction, 6 | including without limitation the rights to use, copy, modify, merge, publish, distribute, 7 | sublicense, and/or sell copies of the Software, and to permit persons to whom the Software 8 | is furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all copies or 11 | substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING 14 | BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 15 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 16 | DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 17 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 18 | -------------------------------------------------------------------------------- /builder.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package dbx 6 | 7 | import ( 8 | "errors" 9 | "fmt" 10 | "sort" 11 | "strings" 12 | ) 13 | 14 | // Builder supports building SQL statements in a DB-agnostic way. 15 | // Builder mainly provides two sets of query building methods: those building SELECT statements 16 | // and those manipulating DB data or schema (e.g. INSERT statements, CREATE TABLE statements). 
17 | type Builder interface { 18 | // NewQuery creates a new Query object with the given SQL statement. 19 | // The SQL statement may contain parameter placeholders which can be bound with actual parameter 20 | // values before the statement is executed. 21 | NewQuery(string) *Query 22 | // Select returns a new SelectQuery object that can be used to build a SELECT statement. 23 | // The parameters to this method should be the list column names to be selected. 24 | // A column name may have an optional alias name. For example, Select("id", "my_name AS name"). 25 | Select(...string) *SelectQuery 26 | // ModelQuery returns a new ModelQuery object that can be used to perform model insertion, update, and deletion. 27 | // The parameter to this method should be a pointer to the model struct that needs to be inserted, updated, or deleted. 28 | Model(interface{}) *ModelQuery 29 | 30 | // GeneratePlaceholder generates an anonymous parameter placeholder with the given parameter ID. 31 | GeneratePlaceholder(int) string 32 | 33 | // Quote quotes a string so that it can be embedded in a SQL statement as a string value. 34 | Quote(string) string 35 | // QuoteSimpleTableName quotes a simple table name. 36 | // A simple table name does not contain any schema prefix. 37 | QuoteSimpleTableName(string) string 38 | // QuoteSimpleColumnName quotes a simple column name. 39 | // A simple column name does not contain any table prefix. 40 | QuoteSimpleColumnName(string) string 41 | 42 | // QueryBuilder returns the query builder supporting the current DB. 43 | QueryBuilder() QueryBuilder 44 | 45 | // Insert creates a Query that represents an INSERT SQL statement. 46 | // The keys of cols are the column names, while the values of cols are the corresponding column 47 | // values to be inserted. 48 | Insert(table string, cols Params) *Query 49 | // Upsert creates a Query that represents an UPSERT SQL statement. 
50 | // Upsert inserts a row into the table if the primary key or unique index is not found. 51 | // Otherwise it will update the row with the new values. 52 | // The keys of cols are the column names, while the values of cols are the corresponding column 53 | // values to be inserted. 54 | Upsert(table string, cols Params, constraints ...string) *Query 55 | // Update creates a Query that represents an UPDATE SQL statement. 56 | // The keys of cols are the column names, while the values of cols are the corresponding new column 57 | // values. If the "where" expression is nil, the UPDATE SQL statement will have no WHERE clause 58 | // (be careful in this case as the SQL statement will update ALL rows in the table). 59 | Update(table string, cols Params, where Expression) *Query 60 | // Delete creates a Query that represents a DELETE SQL statement. 61 | // If the "where" expression is nil, the DELETE SQL statement will have no WHERE clause 62 | // (be careful in this case as the SQL statement will delete ALL rows in the table). 63 | Delete(table string, where Expression) *Query 64 | 65 | // CreateTable creates a Query that represents a CREATE TABLE SQL statement. 66 | // The keys of cols are the column names, while the values of cols are the corresponding column types. 67 | // The optional "options" parameters will be appended to the generated SQL statement. 68 | CreateTable(table string, cols map[string]string, options ...string) *Query 69 | // RenameTable creates a Query that can be used to rename a table. 70 | RenameTable(oldName, newName string) *Query 71 | // DropTable creates a Query that can be used to drop a table. 72 | DropTable(table string) *Query 73 | // TruncateTable creates a Query that can be used to truncate a table. 74 | TruncateTable(table string) *Query 75 | 76 | // AddColumn creates a Query that can be used to add a column to a table. 
77 | AddColumn(table, col, typ string) *Query 78 | // DropColumn creates a Query that can be used to drop a column from a table. 79 | DropColumn(table, col string) *Query 80 | // RenameColumn creates a Query that can be used to rename a column in a table. 81 | RenameColumn(table, oldName, newName string) *Query 82 | // AlterColumn creates a Query that can be used to change the definition of a table column. 83 | AlterColumn(table, col, typ string) *Query 84 | 85 | // AddPrimaryKey creates a Query that can be used to specify primary key(s) for a table. 86 | // The "name" parameter specifies the name of the primary key constraint. 87 | AddPrimaryKey(table, name string, cols ...string) *Query 88 | // DropPrimaryKey creates a Query that can be used to remove the named primary key constraint from a table. 89 | DropPrimaryKey(table, name string) *Query 90 | 91 | // AddForeignKey creates a Query that can be used to add a foreign key constraint to a table. 92 | // The length of cols and refCols must be the same as they refer to the primary and referential columns. 93 | // The optional "options" parameters will be appended to the SQL statement. They can be used to 94 | // specify options such as "ON DELETE CASCADE". 95 | AddForeignKey(table, name string, cols, refCols []string, refTable string, options ...string) *Query 96 | // DropForeignKey creates a Query that can be used to remove the named foreign key constraint from a table. 97 | DropForeignKey(table, name string) *Query 98 | 99 | // CreateIndex creates a Query that can be used to create an index for a table. 100 | CreateIndex(table, name string, cols ...string) *Query 101 | // CreateUniqueIndex creates a Query that can be used to create a unique index for a table. 102 | CreateUniqueIndex(table, name string, cols ...string) *Query 103 | // DropIndex creates a Query that can be used to remove the named index from a table. 
104 | DropIndex(table, name string) *Query 105 | } 106 | 107 | // BaseBuilder provides a basic implementation of the Builder interface. 108 | type BaseBuilder struct { 109 | db *DB 110 | executor Executor 111 | } 112 | 113 | // NewBaseBuilder creates a new BaseBuilder instance. 114 | func NewBaseBuilder(db *DB, executor Executor) *BaseBuilder { 115 | return &BaseBuilder{db, executor} 116 | } 117 | 118 | // DB returns the DB instance that this builder is associated with. 119 | func (b *BaseBuilder) DB() *DB { 120 | return b.db 121 | } 122 | 123 | // Executor returns the executor object (a DB instance or a transaction) for executing SQL statements. 124 | func (b *BaseBuilder) Executor() Executor { 125 | return b.executor 126 | } 127 | 128 | // NewQuery creates a new Query object with the given SQL statement. 129 | // The SQL statement may contain parameter placeholders which can be bound with actual parameter 130 | // values before the statement is executed. 131 | func (b *BaseBuilder) NewQuery(sql string) *Query { 132 | return NewQuery(b.db, b.executor, sql) 133 | } 134 | 135 | // GeneratePlaceholder generates an anonymous parameter placeholder with the given parameter ID. 136 | func (b *BaseBuilder) GeneratePlaceholder(int) string { 137 | return "?" 138 | } 139 | 140 | // Quote quotes a string so that it can be embedded in a SQL statement as a string value. 141 | func (b *BaseBuilder) Quote(s string) string { 142 | return "'" + strings.Replace(s, "'", "''", -1) + "'" 143 | } 144 | 145 | // QuoteSimpleTableName quotes a simple table name. 146 | // A simple table name does not contain any schema prefix. 147 | func (b *BaseBuilder) QuoteSimpleTableName(s string) string { 148 | if strings.Contains(s, `"`) { 149 | return s 150 | } 151 | return `"` + s + `"` 152 | } 153 | 154 | // QuoteSimpleColumnName quotes a simple column name. 155 | // A simple column name does not contain any table prefix. 
156 | func (b *BaseBuilder) QuoteSimpleColumnName(s string) string { 157 | if strings.Contains(s, `"`) || s == "*" { 158 | return s 159 | } 160 | return `"` + s + `"` 161 | } 162 | 163 | // Insert creates a Query that represents an INSERT SQL statement. 164 | // The keys of cols are the column names, while the values of cols are the corresponding column 165 | // values to be inserted. 166 | func (b *BaseBuilder) Insert(table string, cols Params) *Query { 167 | names := make([]string, 0, len(cols)) 168 | for name := range cols { 169 | names = append(names, name) 170 | } 171 | sort.Strings(names) 172 | 173 | params := Params{} 174 | columns := make([]string, 0, len(names)) 175 | values := make([]string, 0, len(names)) 176 | for _, name := range names { 177 | columns = append(columns, b.db.QuoteColumnName(name)) 178 | value := cols[name] 179 | if e, ok := value.(Expression); ok { 180 | values = append(values, e.Build(b.db, params)) 181 | } else { 182 | values = append(values, fmt.Sprintf("{:p%v}", len(params))) 183 | params[fmt.Sprintf("p%v", len(params))] = value 184 | } 185 | } 186 | 187 | var sql string 188 | if len(names) == 0 { 189 | sql = fmt.Sprintf("INSERT INTO %v DEFAULT VALUES", b.db.QuoteTableName(table)) 190 | } else { 191 | sql = fmt.Sprintf("INSERT INTO %v (%v) VALUES (%v)", 192 | b.db.QuoteTableName(table), 193 | strings.Join(columns, ", "), 194 | strings.Join(values, ", "), 195 | ) 196 | } 197 | 198 | return b.NewQuery(sql).Bind(params) 199 | } 200 | 201 | // Upsert creates a Query that represents an UPSERT SQL statement. 202 | // Upsert inserts a row into the table if the primary key or unique index is not found. 203 | // Otherwise it will update the row with the new values. 204 | // The keys of cols are the column names, while the values of cols are the corresponding column 205 | // values to be inserted. 
206 | func (b *BaseBuilder) Upsert(table string, cols Params, constraints ...string) *Query { 207 | q := b.NewQuery("") 208 | q.LastError = errors.New("Upsert is not supported") 209 | return q 210 | } 211 | 212 | // Update creates a Query that represents an UPDATE SQL statement. 213 | // The keys of cols are the column names, while the values of cols are the corresponding new column 214 | // values. If the "where" expression is nil, the UPDATE SQL statement will have no WHERE clause 215 | // (be careful in this case as the SQL statement will update ALL rows in the table). 216 | func (b *BaseBuilder) Update(table string, cols Params, where Expression) *Query { 217 | names := make([]string, 0, len(cols)) 218 | for name := range cols { 219 | names = append(names, name) 220 | } 221 | sort.Strings(names) 222 | 223 | params := Params{} 224 | lines := make([]string, 0, len(names)) 225 | for _, name := range names { 226 | value := cols[name] 227 | name = b.db.QuoteColumnName(name) 228 | if e, ok := value.(Expression); ok { 229 | lines = append(lines, name+"="+e.Build(b.db, params)) 230 | } else { 231 | lines = append(lines, fmt.Sprintf("%v={:p%v}", name, len(params))) 232 | params[fmt.Sprintf("p%v", len(params))] = value 233 | } 234 | } 235 | 236 | sql := fmt.Sprintf("UPDATE %v SET %v", b.db.QuoteTableName(table), strings.Join(lines, ", ")) 237 | if where != nil { 238 | w := where.Build(b.db, params) 239 | if w != "" { 240 | sql += " WHERE " + w 241 | } 242 | } 243 | 244 | return b.NewQuery(sql).Bind(params) 245 | } 246 | 247 | // Delete creates a Query that represents a DELETE SQL statement. 248 | // If the "where" expression is nil, the DELETE SQL statement will have no WHERE clause 249 | // (be careful in this case as the SQL statement will delete ALL rows in the table). 
250 | func (b *BaseBuilder) Delete(table string, where Expression) *Query { 251 | sql := "DELETE FROM " + b.db.QuoteTableName(table) 252 | params := Params{} 253 | if where != nil { 254 | w := where.Build(b.db, params) 255 | if w != "" { 256 | sql += " WHERE " + w 257 | } 258 | } 259 | return b.NewQuery(sql).Bind(params) 260 | } 261 | 262 | // CreateTable creates a Query that represents a CREATE TABLE SQL statement. 263 | // The keys of cols are the column names, while the values of cols are the corresponding column types. 264 | // The optional "options" parameters will be appended to the generated SQL statement. 265 | func (b *BaseBuilder) CreateTable(table string, cols map[string]string, options ...string) *Query { 266 | names := []string{} 267 | for name := range cols { 268 | names = append(names, name) 269 | } 270 | sort.Strings(names) 271 | 272 | columns := []string{} 273 | for _, name := range names { 274 | columns = append(columns, b.db.QuoteColumnName(name)+" "+cols[name]) 275 | } 276 | 277 | sql := fmt.Sprintf("CREATE TABLE %v (%v)", b.db.QuoteTableName(table), strings.Join(columns, ", ")) 278 | for _, opt := range options { 279 | sql += " " + opt 280 | } 281 | 282 | return b.NewQuery(sql) 283 | } 284 | 285 | // RenameTable creates a Query that can be used to rename a table. 286 | func (b *BaseBuilder) RenameTable(oldName, newName string) *Query { 287 | sql := fmt.Sprintf("RENAME TABLE %v TO %v", b.db.QuoteTableName(oldName), b.db.QuoteTableName(newName)) 288 | return b.NewQuery(sql) 289 | } 290 | 291 | // DropTable creates a Query that can be used to drop a table. 292 | func (b *BaseBuilder) DropTable(table string) *Query { 293 | sql := "DROP TABLE " + b.db.QuoteTableName(table) 294 | return b.NewQuery(sql) 295 | } 296 | 297 | // TruncateTable creates a Query that can be used to truncate a table. 
298 | func (b *BaseBuilder) TruncateTable(table string) *Query { 299 | sql := "TRUNCATE TABLE " + b.db.QuoteTableName(table) 300 | return b.NewQuery(sql) 301 | } 302 | 303 | // AddColumn creates a Query that can be used to add a column to a table. 304 | func (b *BaseBuilder) AddColumn(table, col, typ string) *Query { 305 | sql := fmt.Sprintf("ALTER TABLE %v ADD %v %v", b.db.QuoteTableName(table), b.db.QuoteColumnName(col), typ) 306 | return b.NewQuery(sql) 307 | } 308 | 309 | // DropColumn creates a Query that can be used to drop a column from a table. 310 | func (b *BaseBuilder) DropColumn(table, col string) *Query { 311 | sql := fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", b.db.QuoteTableName(table), b.db.QuoteColumnName(col)) 312 | return b.NewQuery(sql) 313 | } 314 | 315 | // RenameColumn creates a Query that can be used to rename a column in a table. 316 | func (b *BaseBuilder) RenameColumn(table, oldName, newName string) *Query { 317 | sql := fmt.Sprintf("ALTER TABLE %v RENAME COLUMN %v TO %v", b.db.QuoteTableName(table), b.db.QuoteColumnName(oldName), b.db.QuoteColumnName(newName)) 318 | return b.NewQuery(sql) 319 | } 320 | 321 | // AlterColumn creates a Query that can be used to change the definition of a table column. 322 | func (b *BaseBuilder) AlterColumn(table, col, typ string) *Query { 323 | col = b.db.QuoteColumnName(col) 324 | sql := fmt.Sprintf("ALTER TABLE %v CHANGE %v %v %v", b.db.QuoteTableName(table), col, col, typ) 325 | return b.NewQuery(sql) 326 | } 327 | 328 | // AddPrimaryKey creates a Query that can be used to specify primary key(s) for a table. 329 | // The "name" parameter specifies the name of the primary key constraint. 
330 | func (b *BaseBuilder) AddPrimaryKey(table, name string, cols ...string) *Query { 331 | sql := fmt.Sprintf("ALTER TABLE %v ADD CONSTRAINT %v PRIMARY KEY (%v)", 332 | b.db.QuoteTableName(table), 333 | b.db.QuoteColumnName(name), 334 | b.quoteColumns(cols)) 335 | return b.NewQuery(sql) 336 | } 337 | 338 | // DropPrimaryKey creates a Query that can be used to remove the named primary key constraint from a table. 339 | func (b *BaseBuilder) DropPrimaryKey(table, name string) *Query { 340 | sql := fmt.Sprintf("ALTER TABLE %v DROP CONSTRAINT %v", b.db.QuoteTableName(table), b.db.QuoteColumnName(name)) 341 | return b.NewQuery(sql) 342 | } 343 | 344 | // AddForeignKey creates a Query that can be used to add a foreign key constraint to a table. 345 | // The length of cols and refCols must be the same as they refer to the primary and referential columns. 346 | // The optional "options" parameters will be appended to the SQL statement. They can be used to 347 | // specify options such as "ON DELETE CASCADE". 348 | func (b *BaseBuilder) AddForeignKey(table, name string, cols, refCols []string, refTable string, options ...string) *Query { 349 | sql := fmt.Sprintf("ALTER TABLE %v ADD CONSTRAINT %v FOREIGN KEY (%v) REFERENCES %v (%v)", 350 | b.db.QuoteTableName(table), 351 | b.db.QuoteColumnName(name), 352 | b.quoteColumns(cols), 353 | b.db.QuoteTableName(refTable), 354 | b.quoteColumns(refCols)) 355 | for _, opt := range options { 356 | sql += " " + opt 357 | } 358 | return b.NewQuery(sql) 359 | } 360 | 361 | // DropForeignKey creates a Query that can be used to remove the named foreign key constraint from a table. 362 | func (b *BaseBuilder) DropForeignKey(table, name string) *Query { 363 | sql := fmt.Sprintf("ALTER TABLE %v DROP CONSTRAINT %v", b.db.QuoteTableName(table), b.db.QuoteColumnName(name)) 364 | return b.NewQuery(sql) 365 | } 366 | 367 | // CreateIndex creates a Query that can be used to create an index for a table. 
368 | func (b *BaseBuilder) CreateIndex(table, name string, cols ...string) *Query { 369 | sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v)", 370 | b.db.QuoteColumnName(name), 371 | b.db.QuoteTableName(table), 372 | b.quoteColumns(cols)) 373 | return b.NewQuery(sql) 374 | } 375 | 376 | // CreateUniqueIndex creates a Query that can be used to create a unique index for a table. 377 | func (b *BaseBuilder) CreateUniqueIndex(table, name string, cols ...string) *Query { 378 | sql := fmt.Sprintf("CREATE UNIQUE INDEX %v ON %v (%v)", 379 | b.db.QuoteColumnName(name), 380 | b.db.QuoteTableName(table), 381 | b.quoteColumns(cols)) 382 | return b.NewQuery(sql) 383 | } 384 | 385 | // DropIndex creates a Query that can be used to remove the named index from a table. 386 | func (b *BaseBuilder) DropIndex(table, name string) *Query { 387 | sql := fmt.Sprintf("DROP INDEX %v ON %v", b.db.QuoteColumnName(name), b.db.QuoteTableName(table)) 388 | return b.NewQuery(sql) 389 | } 390 | 391 | // quoteColumns quotes a list of columns and concatenates them with commas. 392 | func (b *BaseBuilder) quoteColumns(cols []string) string { 393 | s := "" 394 | for i, col := range cols { 395 | if i == 0 { 396 | s = b.db.QuoteColumnName(col) 397 | } else { 398 | s += ", " + b.db.QuoteColumnName(col) 399 | } 400 | } 401 | return s 402 | } 403 | -------------------------------------------------------------------------------- /builder_mssql.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package dbx 6 | 7 | import ( 8 | "fmt" 9 | "strings" 10 | ) 11 | 12 | // MssqlBuilder is the builder for SQL Server databases. 
13 | type MssqlBuilder struct { 14 | *BaseBuilder 15 | qb *MssqlQueryBuilder 16 | } 17 | 18 | var _ Builder = &MssqlBuilder{} 19 | 20 | // MssqlQueryBuilder is the query builder for SQL Server databases. 21 | type MssqlQueryBuilder struct { 22 | *BaseQueryBuilder 23 | } 24 | 25 | // NewMssqlBuilder creates a new MssqlBuilder instance. 26 | func NewMssqlBuilder(db *DB, executor Executor) Builder { 27 | return &MssqlBuilder{ 28 | NewBaseBuilder(db, executor), 29 | &MssqlQueryBuilder{NewBaseQueryBuilder(db)}, 30 | } 31 | } 32 | 33 | // QueryBuilder returns the query builder supporting the current DB. 34 | func (b *MssqlBuilder) QueryBuilder() QueryBuilder { 35 | return b.qb 36 | } 37 | 38 | // Select returns a new SelectQuery object that can be used to build a SELECT statement. 39 | // The parameters to this method should be the list column names to be selected. 40 | // A column name may have an optional alias name. For example, Select("id", "my_name AS name"). 41 | func (b *MssqlBuilder) Select(cols ...string) *SelectQuery { 42 | return NewSelectQuery(b, b.db).Select(cols...) 43 | } 44 | 45 | // Model returns a new ModelQuery object that can be used to perform model-based DB operations. 46 | // The model passed to this method should be a pointer to a model struct. 47 | func (b *MssqlBuilder) Model(model interface{}) *ModelQuery { 48 | return NewModelQuery(model, b.db.FieldMapper, b.db, b) 49 | } 50 | 51 | // QuoteSimpleTableName quotes a simple table name. 52 | // A simple table name does not contain any schema prefix. 53 | func (b *MssqlBuilder) QuoteSimpleTableName(s string) string { 54 | if strings.Contains(s, `[`) { 55 | return s 56 | } 57 | return `[` + s + `]` 58 | } 59 | 60 | // QuoteSimpleColumnName quotes a simple column name. 61 | // A simple column name does not contain any table prefix. 
62 | func (b *MssqlBuilder) QuoteSimpleColumnName(s string) string { 63 | if strings.Contains(s, `[`) || s == "*" { 64 | return s 65 | } 66 | return `[` + s + `]` 67 | } 68 | 69 | // RenameTable creates a Query that can be used to rename a table. 70 | func (b *MssqlBuilder) RenameTable(oldName, newName string) *Query { 71 | sql := fmt.Sprintf("sp_name '%v', '%v'", oldName, newName) 72 | return b.NewQuery(sql) 73 | } 74 | 75 | // RenameColumn creates a Query that can be used to rename a column in a table. 76 | func (b *MssqlBuilder) RenameColumn(table, oldName, newName string) *Query { 77 | sql := fmt.Sprintf("sp_name '%v.%v', '%v', 'COLUMN'", table, oldName, newName) 78 | return b.NewQuery(sql) 79 | } 80 | 81 | // AlterColumn creates a Query that can be used to change the definition of a table column. 82 | func (b *MssqlBuilder) AlterColumn(table, col, typ string) *Query { 83 | col = b.db.QuoteColumnName(col) 84 | sql := fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", b.db.QuoteTableName(table), col, typ) 85 | return b.NewQuery(sql) 86 | } 87 | 88 | // BuildOrderByAndLimit generates the ORDER BY and LIMIT clauses. 
89 | func (q *MssqlQueryBuilder) BuildOrderByAndLimit(sql string, cols []string, limit int64, offset int64) string { 90 | orderBy := q.BuildOrderBy(cols) 91 | if limit < 0 && offset < 0 { 92 | if orderBy == "" { 93 | return sql 94 | } 95 | return sql + "\n" + orderBy 96 | } 97 | 98 | // only SQL SERVER 2012 or newer are supported by this method 99 | 100 | if orderBy == "" { 101 | // ORDER BY clause is required when FETCH and OFFSET are in the SQL 102 | orderBy = "ORDER BY (SELECT NULL)" 103 | } 104 | sql += "\n" + orderBy 105 | 106 | // http://technet.microsoft.com/en-us/library/gg699618.aspx 107 | if offset < 0 { 108 | offset = 0 109 | } 110 | sql += "\n" + fmt.Sprintf("OFFSET %v ROWS", offset) 111 | if limit >= 0 { 112 | sql += "\n" + fmt.Sprintf("FETCH NEXT %v ROWS ONLY", limit) 113 | } 114 | return sql 115 | } 116 | -------------------------------------------------------------------------------- /builder_mssql_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package dbx 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestMssqlBuilder_QuoteSimpleTableName(t *testing.T) { 14 | b := getMssqlBuilder() 15 | assert.Equal(t, b.QuoteSimpleTableName(`abc`), "[abc]", "t1") 16 | assert.Equal(t, b.QuoteSimpleTableName("[abc]"), "[abc]", "t2") 17 | assert.Equal(t, b.QuoteSimpleTableName(`{{abc}}`), "[{{abc}}]", "t3") 18 | assert.Equal(t, b.QuoteSimpleTableName(`a.bc`), "[a.bc]", "t4") 19 | } 20 | 21 | func TestMssqlBuilder_QuoteSimpleColumnName(t *testing.T) { 22 | b := getMssqlBuilder() 23 | assert.Equal(t, b.QuoteSimpleColumnName(`abc`), "[abc]", "t1") 24 | assert.Equal(t, b.QuoteSimpleColumnName("[abc]"), "[abc]", "t2") 25 | assert.Equal(t, b.QuoteSimpleColumnName(`{{abc}}`), "[{{abc}}]", "t3") 26 | assert.Equal(t, b.QuoteSimpleColumnName(`a.bc`), "[a.bc]", "t4") 27 | assert.Equal(t, b.QuoteSimpleColumnName(`*`), `*`, "t5") 28 | } 29 | 30 | func TestMssqlBuilder_RenameTable(t *testing.T) { 31 | b := getMssqlBuilder() 32 | q := b.RenameTable("users", "user") 33 | assert.Equal(t, q.SQL(), `sp_name 'users', 'user'`, "t1") 34 | } 35 | 36 | func TestMssqlBuilder_RenameColumn(t *testing.T) { 37 | b := getMssqlBuilder() 38 | q := b.RenameColumn("users", "name", "username") 39 | assert.Equal(t, q.SQL(), `sp_name 'users.name', 'username', 'COLUMN'`, "t1") 40 | } 41 | 42 | func TestMssqlBuilder_AlterColumn(t *testing.T) { 43 | b := getMssqlBuilder() 44 | q := b.AlterColumn("users", "name", "int") 45 | assert.Equal(t, q.SQL(), `ALTER TABLE [users] ALTER COLUMN [name] int`, "t1") 46 | } 47 | 48 | func TestMssqlQueryBuilder_BuildOrderByAndLimit(t *testing.T) { 49 | qb := getMssqlBuilder().QueryBuilder() 50 | 51 | sql := qb.BuildOrderByAndLimit("SELECT *", []string{"name"}, 10, 2) 52 | expected := "SELECT *\nORDER BY [name]\nOFFSET 2 ROWS\nFETCH NEXT 10 ROWS ONLY" 53 | assert.Equal(t, sql, expected, "t1") 54 | 55 | sql = qb.BuildOrderByAndLimit("SELECT *", nil, -1, -1) 56 | expected = 
"SELECT *" 57 | assert.Equal(t, sql, expected, "t2") 58 | 59 | sql = qb.BuildOrderByAndLimit("SELECT *", []string{"name"}, -1, -1) 60 | expected = "SELECT *\nORDER BY [name]" 61 | assert.Equal(t, sql, expected, "t3") 62 | 63 | sql = qb.BuildOrderByAndLimit("SELECT *", nil, 10, -1) 64 | expected = "SELECT *\nORDER BY (SELECT NULL)\nOFFSET 0 ROWS\nFETCH NEXT 10 ROWS ONLY" 65 | assert.Equal(t, sql, expected, "t4") 66 | } 67 | 68 | func getMssqlBuilder() Builder { 69 | db := getDB() 70 | b := NewMssqlBuilder(db, db.sqlDB) 71 | db.Builder = b 72 | return b 73 | } 74 | -------------------------------------------------------------------------------- /builder_mysql.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package dbx 6 | 7 | import ( 8 | "fmt" 9 | "regexp" 10 | "sort" 11 | "strings" 12 | ) 13 | 14 | // MysqlBuilder is the builder for MySQL databases. 15 | type MysqlBuilder struct { 16 | *BaseBuilder 17 | qb *BaseQueryBuilder 18 | } 19 | 20 | var _ Builder = &MysqlBuilder{} 21 | 22 | // NewMysqlBuilder creates a new MysqlBuilder instance. 23 | func NewMysqlBuilder(db *DB, executor Executor) Builder { 24 | return &MysqlBuilder{ 25 | NewBaseBuilder(db, executor), 26 | NewBaseQueryBuilder(db), 27 | } 28 | } 29 | 30 | // QueryBuilder returns the query builder supporting the current DB. 31 | func (b *MysqlBuilder) QueryBuilder() QueryBuilder { 32 | return b.qb 33 | } 34 | 35 | // Select returns a new SelectQuery object that can be used to build a SELECT statement. 36 | // The parameters to this method should be the list column names to be selected. 37 | // A column name may have an optional alias name. For example, Select("id", "my_name AS name"). 38 | func (b *MysqlBuilder) Select(cols ...string) *SelectQuery { 39 | return NewSelectQuery(b, b.db).Select(cols...) 
40 | } 41 | 42 | // Model returns a new ModelQuery object that can be used to perform model-based DB operations. 43 | // The model passed to this method should be a pointer to a model struct. 44 | func (b *MysqlBuilder) Model(model interface{}) *ModelQuery { 45 | return NewModelQuery(model, b.db.FieldMapper, b.db, b) 46 | } 47 | 48 | // QuoteSimpleTableName quotes a simple table name. 49 | // A simple table name does not contain any schema prefix. 50 | func (b *MysqlBuilder) QuoteSimpleTableName(s string) string { 51 | if strings.ContainsAny(s, "`") { 52 | return s 53 | } 54 | return "`" + s + "`" 55 | } 56 | 57 | // QuoteSimpleColumnName quotes a simple column name. 58 | // A simple column name does not contain any table prefix. 59 | func (b *MysqlBuilder) QuoteSimpleColumnName(s string) string { 60 | if strings.Contains(s, "`") || s == "*" { 61 | return s 62 | } 63 | return "`" + s + "`" 64 | } 65 | 66 | // Upsert creates a Query that represents an UPSERT SQL statement. 67 | // Upsert inserts a row into the table if the primary key or unique index is not found. 68 | // Otherwise it will update the row with the new values. 69 | // The keys of cols are the column names, while the values of cols are the corresponding column 70 | // values to be inserted. 
71 | func (b *MysqlBuilder) Upsert(table string, cols Params, constraints ...string) *Query { 72 | q := b.Insert(table, cols) 73 | 74 | names := []string{} 75 | for name := range cols { 76 | names = append(names, name) 77 | } 78 | sort.Strings(names) 79 | 80 | lines := []string{} 81 | for _, name := range names { 82 | value := cols[name] 83 | name = b.db.QuoteColumnName(name) 84 | if e, ok := value.(Expression); ok { 85 | lines = append(lines, name+"="+e.Build(b.db, q.params)) 86 | } else { 87 | lines = append(lines, fmt.Sprintf("%v={:p%v}", name, len(q.params))) 88 | q.params[fmt.Sprintf("p%v", len(q.params))] = value 89 | } 90 | } 91 | 92 | q.sql += " ON DUPLICATE KEY UPDATE " + strings.Join(lines, ", ") 93 | 94 | return q 95 | } 96 | 97 | var mysqlColumnRegexp = regexp.MustCompile("(?m)^\\s*[`\"](.*?)[`\"]\\s+(.*?),?$") 98 | 99 | // RenameColumn creates a Query that can be used to rename a column in a table. 100 | func (b *MysqlBuilder) RenameColumn(table, oldName, newName string) *Query { 101 | qt := b.db.QuoteTableName(table) 102 | sql := fmt.Sprintf("ALTER TABLE %v CHANGE %v %v", qt, b.db.QuoteColumnName(oldName), b.db.QuoteColumnName(newName)) 103 | 104 | var info struct { 105 | SQL string `db:"Create Table"` 106 | } 107 | if err := b.db.NewQuery("SHOW CREATE TABLE " + qt).One(&info); err != nil { 108 | return b.db.NewQuery(sql) 109 | } 110 | 111 | if matches := mysqlColumnRegexp.FindAllStringSubmatch(info.SQL, -1); matches != nil { 112 | for _, match := range matches { 113 | if match[1] == oldName { 114 | sql += " " + match[2] 115 | break 116 | } 117 | } 118 | } 119 | 120 | return b.db.NewQuery(sql) 121 | } 122 | 123 | // DropPrimaryKey creates a Query that can be used to remove the named primary key constraint from a table. 
124 | func (b *MysqlBuilder) DropPrimaryKey(table, name string) *Query { 125 | sql := fmt.Sprintf("ALTER TABLE %v DROP PRIMARY KEY", b.db.QuoteTableName(table)) 126 | return b.db.NewQuery(sql) 127 | } 128 | 129 | // DropForeignKey creates a Query that can be used to remove the named foreign key constraint from a table. 130 | func (b *MysqlBuilder) DropForeignKey(table, name string) *Query { 131 | sql := fmt.Sprintf("ALTER TABLE %v DROP FOREIGN KEY %v", b.db.QuoteTableName(table), b.db.QuoteColumnName(name)) 132 | return b.db.NewQuery(sql) 133 | } 134 | -------------------------------------------------------------------------------- /builder_mysql_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package dbx 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestMysqlBuilder_QuoteSimpleTableName(t *testing.T) { 14 | b := getMysqlBuilder() 15 | assert.Equal(t, b.QuoteSimpleTableName(`abc`), "`abc`", "t1") 16 | assert.Equal(t, b.QuoteSimpleTableName("`abc`"), "`abc`", "t2") 17 | assert.Equal(t, b.QuoteSimpleTableName(`{{abc}}`), "`{{abc}}`", "t3") 18 | assert.Equal(t, b.QuoteSimpleTableName(`a.bc`), "`a.bc`", "t4") 19 | } 20 | 21 | func TestMysqlBuilder_QuoteSimpleColumnName(t *testing.T) { 22 | b := getMysqlBuilder() 23 | assert.Equal(t, b.QuoteSimpleColumnName(`abc`), "`abc`", "t1") 24 | assert.Equal(t, b.QuoteSimpleColumnName("`abc`"), "`abc`", "t2") 25 | assert.Equal(t, b.QuoteSimpleColumnName(`{{abc}}`), "`{{abc}}`", "t3") 26 | assert.Equal(t, b.QuoteSimpleColumnName(`a.bc`), "`a.bc`", "t4") 27 | assert.Equal(t, b.QuoteSimpleColumnName(`*`), `*`, "t5") 28 | } 29 | 30 | func TestMysqlBuilder_Upsert(t *testing.T) { 31 | getPreparedDB() 32 | b := getMysqlBuilder() 33 | q := b.Upsert("users", Params{ 34 | 
"name": "James", 35 | "age": 30, 36 | }) 37 | assert.Equal(t, q.SQL(), "INSERT INTO `users` (`age`, `name`) VALUES ({:p0}, {:p1}) ON DUPLICATE KEY UPDATE `age`={:p2}, `name`={:p3}", "t1") 38 | assert.Equal(t, q.Params()["p0"], 30, "t2") 39 | assert.Equal(t, q.Params()["p1"], "James", "t3") 40 | assert.Equal(t, q.Params()["p2"], 30, "t2") 41 | assert.Equal(t, q.Params()["p3"], "James", "t3") 42 | } 43 | 44 | func TestMysqlBuilder_RenameColumn(t *testing.T) { 45 | b := getMysqlBuilder() 46 | q := b.RenameColumn("users", "name", "username") 47 | assert.Equal(t, q.SQL(), "ALTER TABLE `users` CHANGE `name` `username`") 48 | q = b.RenameColumn("customer", "email", "e") 49 | assert.Equal(t, q.SQL(), "ALTER TABLE `customer` CHANGE `email` `e` varchar(128) NOT NULL") 50 | } 51 | 52 | func TestMysqlBuilder_DropPrimaryKey(t *testing.T) { 53 | b := getMysqlBuilder() 54 | q := b.DropPrimaryKey("users", "pk") 55 | assert.Equal(t, q.SQL(), "ALTER TABLE `users` DROP PRIMARY KEY", "t1") 56 | } 57 | 58 | func TestMysqlBuilder_DropForeignKey(t *testing.T) { 59 | b := getMysqlBuilder() 60 | q := b.DropForeignKey("users", "fk") 61 | assert.Equal(t, q.SQL(), "ALTER TABLE `users` DROP FOREIGN KEY `fk`", "t1") 62 | } 63 | 64 | func getMysqlBuilder() Builder { 65 | db := getDB() 66 | b := NewMysqlBuilder(db, db.sqlDB) 67 | db.Builder = b 68 | return b 69 | } 70 | -------------------------------------------------------------------------------- /builder_oci.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package dbx 6 | 7 | import ( 8 | "fmt" 9 | ) 10 | 11 | // OciBuilder is the builder for Oracle databases. 
12 | type OciBuilder struct { 13 | *BaseBuilder 14 | qb *OciQueryBuilder 15 | } 16 | 17 | var _ Builder = &OciBuilder{} 18 | 19 | // OciQueryBuilder is the query builder for Oracle databases. 20 | type OciQueryBuilder struct { 21 | *BaseQueryBuilder 22 | } 23 | 24 | // NewOciBuilder creates a new OciBuilder instance. 25 | func NewOciBuilder(db *DB, executor Executor) Builder { 26 | return &OciBuilder{ 27 | NewBaseBuilder(db, executor), 28 | &OciQueryBuilder{NewBaseQueryBuilder(db)}, 29 | } 30 | } 31 | 32 | // Select returns a new SelectQuery object that can be used to build a SELECT statement. 33 | // The parameters to this method should be the list column names to be selected. 34 | // A column name may have an optional alias name. For example, Select("id", "my_name AS name"). 35 | func (b *OciBuilder) Select(cols ...string) *SelectQuery { 36 | return NewSelectQuery(b, b.db).Select(cols...) 37 | } 38 | 39 | // Model returns a new ModelQuery object that can be used to perform model-based DB operations. 40 | // The model passed to this method should be a pointer to a model struct. 41 | func (b *OciBuilder) Model(model interface{}) *ModelQuery { 42 | return NewModelQuery(model, b.db.FieldMapper, b.db, b) 43 | } 44 | 45 | // GeneratePlaceholder generates an anonymous parameter placeholder with the given parameter ID. 46 | func (b *OciBuilder) GeneratePlaceholder(i int) string { 47 | return fmt.Sprintf(":p%v", i) 48 | } 49 | 50 | // QueryBuilder returns the query builder supporting the current DB. 51 | func (b *OciBuilder) QueryBuilder() QueryBuilder { 52 | return b.qb 53 | } 54 | 55 | // DropIndex creates a Query that can be used to remove the named index from a table. 56 | func (b *OciBuilder) DropIndex(table, name string) *Query { 57 | sql := fmt.Sprintf("DROP INDEX %v", b.db.QuoteColumnName(name)) 58 | return b.NewQuery(sql) 59 | } 60 | 61 | // RenameTable creates a Query that can be used to rename a table. 
62 | func (b *OciBuilder) RenameTable(oldName, newName string) *Query { 63 | sql := fmt.Sprintf("ALTER TABLE %v RENAME TO %v", b.db.QuoteTableName(oldName), b.db.QuoteTableName(newName)) 64 | return b.NewQuery(sql) 65 | } 66 | 67 | // AlterColumn creates a Query that can be used to change the definition of a table column. 68 | func (b *OciBuilder) AlterColumn(table, col, typ string) *Query { 69 | col = b.db.QuoteColumnName(col) 70 | sql := fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", b.db.QuoteTableName(table), col, typ) 71 | return b.NewQuery(sql) 72 | } 73 | 74 | // BuildOrderByAndLimit generates the ORDER BY and LIMIT clauses. 75 | func (q *OciQueryBuilder) BuildOrderByAndLimit(sql string, cols []string, limit int64, offset int64) string { 76 | if orderBy := q.BuildOrderBy(cols); orderBy != "" { 77 | sql += "\n" + orderBy 78 | } 79 | 80 | c := "" 81 | if offset > 0 { 82 | c = fmt.Sprintf("rowNumId > %v", offset) 83 | } 84 | if limit >= 0 { 85 | if c != "" { 86 | c += " AND " 87 | } 88 | c += fmt.Sprintf("rowNum <= %v", limit) 89 | } 90 | 91 | if c == "" { 92 | return sql 93 | } 94 | 95 | return `WITH USER_SQL AS (` + sql + `), 96 | PAGINATION AS (SELECT USER_SQL.*, rownum as rowNumId FROM USER_SQL) 97 | SELECT * FROM PAGINATION WHERE ` + c 98 | } 99 | -------------------------------------------------------------------------------- /builder_oci_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 

package dbx

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestOciBuilder_DropIndex(t *testing.T) {
	b := getOciBuilder()
	q := b.DropIndex("users", "idx")
	assert.Equal(t, q.SQL(), `DROP INDEX "idx"`, "t1")
}

func TestOciBuilder_RenameTable(t *testing.T) {
	b := getOciBuilder()
	q := b.RenameTable("users", "user")
	assert.Equal(t, q.SQL(), `ALTER TABLE "users" RENAME TO "user"`, "t1")
}

func TestOciBuilder_AlterColumn(t *testing.T) {
	b := getOciBuilder()
	q := b.AlterColumn("users", "name", "int")
	assert.Equal(t, q.SQL(), `ALTER TABLE "users" MODIFY "name" int`, "t1")
}

// TestOciQueryBuilder_BuildOrderByAndLimit covers: ORDER BY with offset and limit (t1),
// no ordering/limit/offset (t2), ORDER BY only (t3), and limit only (t4).
func TestOciQueryBuilder_BuildOrderByAndLimit(t *testing.T) {
	qb := getOciBuilder().QueryBuilder()

	sql := qb.BuildOrderByAndLimit("SELECT *", []string{"name"}, 10, 2)
	expected := "WITH USER_SQL AS (SELECT *\nORDER BY \"name\"),\n\tPAGINATION AS (SELECT USER_SQL.*, rownum as rowNumId FROM USER_SQL)\nSELECT * FROM PAGINATION WHERE rowNumId > 2 AND rowNum <= 10"
	assert.Equal(t, sql, expected, "t1")

	sql = qb.BuildOrderByAndLimit("SELECT *", nil, -1, -1)
	expected = "SELECT *"
	assert.Equal(t, sql, expected, "t2")

	sql = qb.BuildOrderByAndLimit("SELECT *", []string{"name"}, -1, -1)
	expected = "SELECT *\nORDER BY \"name\""
	assert.Equal(t, sql, expected, "t3")

	sql = qb.BuildOrderByAndLimit("SELECT *", nil, 10, -1)
	expected = "WITH USER_SQL AS (SELECT *),\n\tPAGINATION AS (SELECT USER_SQL.*, rownum as rowNumId FROM USER_SQL)\nSELECT * FROM PAGINATION WHERE rowNum <= 10"
	assert.Equal(t, sql, expected, "t4")
}

// getOciBuilder returns an OciBuilder on the shared test DB and installs it as that DB's Builder.
func getOciBuilder() Builder {
	db := getDB()
	b := NewOciBuilder(db, db.sqlDB)
	db.Builder = b
	return b
}

--------------------------------------------------------------------------------
/builder_pgsql.go:
-------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package dbx 6 | 7 | import ( 8 | "fmt" 9 | "sort" 10 | "strings" 11 | ) 12 | 13 | // PgsqlBuilder is the builder for PostgreSQL databases. 14 | type PgsqlBuilder struct { 15 | *BaseBuilder 16 | qb *BaseQueryBuilder 17 | } 18 | 19 | var _ Builder = &PgsqlBuilder{} 20 | 21 | // NewPgsqlBuilder creates a new PgsqlBuilder instance. 22 | func NewPgsqlBuilder(db *DB, executor Executor) Builder { 23 | return &PgsqlBuilder{ 24 | NewBaseBuilder(db, executor), 25 | NewBaseQueryBuilder(db), 26 | } 27 | } 28 | 29 | // Select returns a new SelectQuery object that can be used to build a SELECT statement. 30 | // The parameters to this method should be the list column names to be selected. 31 | // A column name may have an optional alias name. For example, Select("id", "my_name AS name"). 32 | func (b *PgsqlBuilder) Select(cols ...string) *SelectQuery { 33 | return NewSelectQuery(b, b.db).Select(cols...) 34 | } 35 | 36 | // Model returns a new ModelQuery object that can be used to perform model-based DB operations. 37 | // The model passed to this method should be a pointer to a model struct. 38 | func (b *PgsqlBuilder) Model(model interface{}) *ModelQuery { 39 | return NewModelQuery(model, b.db.FieldMapper, b.db, b) 40 | } 41 | 42 | // GeneratePlaceholder generates an anonymous parameter placeholder with the given parameter ID. 43 | func (b *PgsqlBuilder) GeneratePlaceholder(i int) string { 44 | return fmt.Sprintf("$%v", i) 45 | } 46 | 47 | // QueryBuilder returns the query builder supporting the current DB. 48 | func (b *PgsqlBuilder) QueryBuilder() QueryBuilder { 49 | return b.qb 50 | } 51 | 52 | // Upsert creates a Query that represents an UPSERT SQL statement. 
53 | // Upsert inserts a row into the table if the primary key or unique index is not found. 54 | // Otherwise it will update the row with the new values. 55 | // The keys of cols are the column names, while the values of cols are the corresponding column 56 | // values to be inserted. 57 | func (b *PgsqlBuilder) Upsert(table string, cols Params, constraints ...string) *Query { 58 | q := b.Insert(table, cols) 59 | 60 | names := []string{} 61 | for name := range cols { 62 | names = append(names, name) 63 | } 64 | sort.Strings(names) 65 | 66 | lines := []string{} 67 | for _, name := range names { 68 | value := cols[name] 69 | name = b.db.QuoteColumnName(name) 70 | if e, ok := value.(Expression); ok { 71 | lines = append(lines, name+"="+e.Build(b.db, q.params)) 72 | } else { 73 | lines = append(lines, fmt.Sprintf("%v={:p%v}", name, len(q.params))) 74 | q.params[fmt.Sprintf("p%v", len(q.params))] = value 75 | } 76 | } 77 | 78 | if len(constraints) > 0 { 79 | c := b.quoteColumns(constraints) 80 | q.sql += " ON CONFLICT (" + c + ") DO UPDATE SET " + strings.Join(lines, ", ") 81 | } else { 82 | q.sql += " ON CONFLICT DO UPDATE SET " + strings.Join(lines, ", ") 83 | } 84 | 85 | return b.NewQuery(q.sql).Bind(q.params) 86 | } 87 | 88 | // DropIndex creates a Query that can be used to remove the named index from a table. 89 | func (b *PgsqlBuilder) DropIndex(table, name string) *Query { 90 | sql := fmt.Sprintf("DROP INDEX %v", b.db.QuoteColumnName(name)) 91 | return b.NewQuery(sql) 92 | } 93 | 94 | // RenameTable creates a Query that can be used to rename a table. 95 | func (b *PgsqlBuilder) RenameTable(oldName, newName string) *Query { 96 | sql := fmt.Sprintf("ALTER TABLE %v RENAME TO %v", b.db.QuoteTableName(oldName), b.db.QuoteTableName(newName)) 97 | return b.NewQuery(sql) 98 | } 99 | 100 | // AlterColumn creates a Query that can be used to change the definition of a table column. 
101 | func (b *PgsqlBuilder) AlterColumn(table, col, typ string) *Query { 102 | col = b.db.QuoteColumnName(col) 103 | sql := fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", b.db.QuoteTableName(table), col, typ) 104 | return b.NewQuery(sql) 105 | } 106 | -------------------------------------------------------------------------------- /builder_pgsql_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package dbx 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestPgsqlBuilder_Upsert(t *testing.T) { 14 | b := getPgsqlBuilder() 15 | q := b.Upsert("users", Params{ 16 | "name": "James", 17 | "age": 30, 18 | }, "id") 19 | assert.Equal(t, q.sql, `INSERT INTO "users" ("age", "name") VALUES ({:p0}, {:p1}) ON CONFLICT ("id") DO UPDATE SET "age"={:p2}, "name"={:p3}`, "t1") 20 | assert.Equal(t, q.rawSQL, `INSERT INTO "users" ("age", "name") VALUES ($1, $2) ON CONFLICT ("id") DO UPDATE SET "age"=$3, "name"=$4`, "t2") 21 | assert.Equal(t, q.Params()["p0"], 30, "t3") 22 | assert.Equal(t, q.Params()["p1"], "James", "t4") 23 | assert.Equal(t, q.Params()["p2"], 30, "t5") 24 | assert.Equal(t, q.Params()["p3"], "James", "t6") 25 | } 26 | func TestPgsqlBuilder_DropIndex(t *testing.T) { 27 | b := getPgsqlBuilder() 28 | q := b.DropIndex("users", "idx") 29 | assert.Equal(t, q.SQL(), `DROP INDEX "idx"`, "t1") 30 | } 31 | 32 | func TestPgsqlBuilder_RenameTable(t *testing.T) { 33 | b := getPgsqlBuilder() 34 | q := b.RenameTable("users", "user") 35 | assert.Equal(t, q.SQL(), `ALTER TABLE "users" RENAME TO "user"`, "t1") 36 | } 37 | 38 | func TestPgsqlBuilder_AlterColumn(t *testing.T) { 39 | b := getPgsqlBuilder() 40 | q := b.AlterColumn("users", "name", "int") 41 | assert.Equal(t, q.SQL(), `ALTER TABLE "users" ALTER COLUMN "name" 
TYPE int`, "t1") 42 | } 43 | 44 | func getPgsqlBuilder() Builder { 45 | db := getDB() 46 | b := NewPgsqlBuilder(db, db.sqlDB) 47 | db.Builder = b 48 | return b 49 | } 50 | -------------------------------------------------------------------------------- /builder_sqlite.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package dbx 6 | 7 | import ( 8 | "errors" 9 | "fmt" 10 | "strings" 11 | ) 12 | 13 | // SqliteBuilder is the builder for SQLite databases. 14 | type SqliteBuilder struct { 15 | *BaseBuilder 16 | qb *BaseQueryBuilder 17 | } 18 | 19 | var _ Builder = &SqliteBuilder{} 20 | 21 | // NewSqliteBuilder creates a new SqliteBuilder instance. 22 | func NewSqliteBuilder(db *DB, executor Executor) Builder { 23 | return &SqliteBuilder{ 24 | NewBaseBuilder(db, executor), 25 | NewBaseQueryBuilder(db), 26 | } 27 | } 28 | 29 | // QueryBuilder returns the query builder supporting the current DB. 30 | func (b *SqliteBuilder) QueryBuilder() QueryBuilder { 31 | return b.qb 32 | } 33 | 34 | // Select returns a new SelectQuery object that can be used to build a SELECT statement. 35 | // The parameters to this method should be the list column names to be selected. 36 | // A column name may have an optional alias name. For example, Select("id", "my_name AS name"). 37 | func (b *SqliteBuilder) Select(cols ...string) *SelectQuery { 38 | return NewSelectQuery(b, b.db).Select(cols...) 39 | } 40 | 41 | // Model returns a new ModelQuery object that can be used to perform model-based DB operations. 42 | // The model passed to this method should be a pointer to a model struct. 43 | func (b *SqliteBuilder) Model(model interface{}) *ModelQuery { 44 | return NewModelQuery(model, b.db.FieldMapper, b.db, b) 45 | } 46 | 47 | // QuoteSimpleTableName quotes a simple table name. 
48 | // A simple table name does not contain any schema prefix. 49 | func (b *SqliteBuilder) QuoteSimpleTableName(s string) string { 50 | if strings.ContainsAny(s, "`") { 51 | return s 52 | } 53 | return "`" + s + "`" 54 | } 55 | 56 | // QuoteSimpleColumnName quotes a simple column name. 57 | // A simple column name does not contain any table prefix. 58 | func (b *SqliteBuilder) QuoteSimpleColumnName(s string) string { 59 | if strings.Contains(s, "`") || s == "*" { 60 | return s 61 | } 62 | return "`" + s + "`" 63 | } 64 | 65 | // DropIndex creates a Query that can be used to remove the named index from a table. 66 | func (b *SqliteBuilder) DropIndex(table, name string) *Query { 67 | sql := fmt.Sprintf("DROP INDEX %v", b.db.QuoteColumnName(name)) 68 | return b.NewQuery(sql) 69 | } 70 | 71 | // TruncateTable creates a Query that can be used to truncate a table. 72 | func (b *SqliteBuilder) TruncateTable(table string) *Query { 73 | sql := "DELETE FROM " + b.db.QuoteTableName(table) 74 | return b.NewQuery(sql) 75 | } 76 | 77 | // RenameTable creates a Query that can be used to rename a table. 78 | func (b *SqliteBuilder) RenameTable(oldName, newName string) *Query { 79 | sql := fmt.Sprintf("ALTER TABLE %v RENAME TO %v", b.db.QuoteTableName(oldName), b.db.QuoteTableName(newName)) 80 | return b.NewQuery(sql) 81 | } 82 | 83 | // AlterColumn creates a Query that can be used to change the definition of a table column. 84 | func (b *SqliteBuilder) AlterColumn(table, col, typ string) *Query { 85 | q := b.NewQuery("") 86 | q.LastError = errors.New("SQLite does not support altering column") 87 | return q 88 | } 89 | 90 | // AddPrimaryKey creates a Query that can be used to specify primary key(s) for a table. 91 | // The "name" parameter specifies the name of the primary key constraint. 
92 | func (b *SqliteBuilder) AddPrimaryKey(table, name string, cols ...string) *Query { 93 | q := b.NewQuery("") 94 | q.LastError = errors.New("SQLite does not support adding primary key") 95 | return q 96 | } 97 | 98 | // DropPrimaryKey creates a Query that can be used to remove the named primary key constraint from a table. 99 | func (b *SqliteBuilder) DropPrimaryKey(table, name string) *Query { 100 | q := b.NewQuery("") 101 | q.LastError = errors.New("SQLite does not support dropping primary key") 102 | return q 103 | } 104 | 105 | // AddForeignKey creates a Query that can be used to add a foreign key constraint to a table. 106 | // The length of cols and refCols must be the same as they refer to the primary and referential columns. 107 | // The optional "options" parameters will be appended to the SQL statement. They can be used to 108 | // specify options such as "ON DELETE CASCADE". 109 | func (b *SqliteBuilder) AddForeignKey(table, name string, cols, refCols []string, refTable string, options ...string) *Query { 110 | q := b.NewQuery("") 111 | q.LastError = errors.New("SQLite does not support adding foreign keys") 112 | return q 113 | } 114 | 115 | // DropForeignKey creates a Query that can be used to remove the named foreign key constraint from a table. 116 | func (b *SqliteBuilder) DropForeignKey(table, name string) *Query { 117 | q := b.NewQuery("") 118 | q.LastError = errors.New("SQLite does not support dropping foreign keys") 119 | return q 120 | } 121 | -------------------------------------------------------------------------------- /builder_sqlite_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 

package dbx

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestSqliteBuilder_QuoteSimpleTableName(t *testing.T) {
	b := getSqliteBuilder()
	assert.Equal(t, b.QuoteSimpleTableName(`abc`), "`abc`", "t1")
	assert.Equal(t, b.QuoteSimpleTableName("`abc`"), "`abc`", "t2")
	assert.Equal(t, b.QuoteSimpleTableName(`{{abc}}`), "`{{abc}}`", "t3")
	assert.Equal(t, b.QuoteSimpleTableName(`a.bc`), "`a.bc`", "t4")
}

func TestSqliteBuilder_QuoteSimpleColumnName(t *testing.T) {
	b := getSqliteBuilder()
	assert.Equal(t, b.QuoteSimpleColumnName(`abc`), "`abc`", "t1")
	assert.Equal(t, b.QuoteSimpleColumnName("`abc`"), "`abc`", "t2")
	assert.Equal(t, b.QuoteSimpleColumnName(`{{abc}}`), "`{{abc}}`", "t3")
	assert.Equal(t, b.QuoteSimpleColumnName(`a.bc`), "`a.bc`", "t4")
	assert.Equal(t, b.QuoteSimpleColumnName(`*`), `*`, "t5")
}

func TestSqliteBuilder_DropIndex(t *testing.T) {
	b := getSqliteBuilder()
	q := b.DropIndex("users", "idx")
	assert.Equal(t, q.SQL(), "DROP INDEX `idx`", "t1")
}

func TestSqliteBuilder_TruncateTable(t *testing.T) {
	b := getSqliteBuilder()
	q := b.TruncateTable("users")
	assert.Equal(t, q.SQL(), "DELETE FROM `users`", "t1")
}

func TestSqliteBuilder_RenameTable(t *testing.T) {
	b := getSqliteBuilder()
	q := b.RenameTable("usersOld", "usersNew")
	assert.Equal(t, q.SQL(), "ALTER TABLE `usersOld` RENAME TO `usersNew`", "t1")
}

// The following operations are unsupported by the SQLite builder and must surface
// an error through LastError instead of SQL.
func TestSqliteBuilder_AlterColumn(t *testing.T) {
	b := getSqliteBuilder()
	q := b.AlterColumn("users", "name", "int")
	assert.NotEqual(t, q.LastError, nil, "t1")
}

func TestSqliteBuilder_AddPrimaryKey(t *testing.T) {
	b := getSqliteBuilder()
	q := b.AddPrimaryKey("users", "pk", "id1", "id2")
	assert.NotEqual(t, q.LastError, nil, "t1")
}

func TestSqliteBuilder_DropPrimaryKey(t *testing.T) {
	b := getSqliteBuilder()
	q := b.DropPrimaryKey("users", "pk")
	assert.NotEqual(t, q.LastError, nil, "t1")
}

func TestSqliteBuilder_AddForeignKey(t *testing.T) {
	b := getSqliteBuilder()
	q := b.AddForeignKey("users", "fk", []string{"p1", "p2"}, []string{"f1", "f2"}, "profile", "opt")
	assert.NotEqual(t, q.LastError, nil, "t1")
}

func TestSqliteBuilder_DropForeignKey(t *testing.T) {
	b := getSqliteBuilder()
	q := b.DropForeignKey("users", "fk")
	assert.NotEqual(t, q.LastError, nil, "t1")
}

// getSqliteBuilder returns a SqliteBuilder on the shared test DB and installs it as that DB's Builder.
func getSqliteBuilder() Builder {
	db := getDB()
	b := NewSqliteBuilder(db, db.sqlDB)
	db.Builder = b
	return b
}

--------------------------------------------------------------------------------
/builder_standard.go:
--------------------------------------------------------------------------------
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.

package dbx

// StandardBuilder is the builder that is used by DB for an unknown driver.
// It adds no dialect-specific overrides to BaseBuilder beyond the methods below.
type StandardBuilder struct {
	*BaseBuilder
	qb *BaseQueryBuilder
}

// Compile-time check that StandardBuilder implements Builder.
var _ Builder = &StandardBuilder{}

// NewStandardBuilder creates a new StandardBuilder instance.
func NewStandardBuilder(db *DB, executor Executor) Builder {
	return &StandardBuilder{
		NewBaseBuilder(db, executor),
		NewBaseQueryBuilder(db),
	}
}

// QueryBuilder returns the query builder supporting the current DB.
func (b *StandardBuilder) QueryBuilder() QueryBuilder {
	return b.qb
}

// Select returns a new SelectQuery object that can be used to build a SELECT statement.
// The parameters to this method should be the list column names to be selected.
// A column name may have an optional alias name. For example, Select("id", "my_name AS name").
func (b *StandardBuilder) Select(cols ...string) *SelectQuery {
	return NewSelectQuery(b, b.db).Select(cols...)
}

// Model returns a new ModelQuery object that can be used to perform model-based DB operations.
// The model passed to this method should be a pointer to a model struct.
func (b *StandardBuilder) Model(model interface{}) *ModelQuery {
	return NewModelQuery(model, b.db.FieldMapper, b.db, b)
}

--------------------------------------------------------------------------------
/builder_standard_test.go:
--------------------------------------------------------------------------------
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.

package dbx

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestStandardBuilder_Quote(t *testing.T) {
	b := getStandardBuilder()
	assert.Equal(t, b.Quote(`abc`), `'abc'`, "t1")
	assert.Equal(t, b.Quote(`I'm`), `'I''m'`, "t2")
	assert.Equal(t, b.Quote(``), `''`, "t3")
}

func TestStandardBuilder_QuoteSimpleTableName(t *testing.T) {
	b := getStandardBuilder()
	assert.Equal(t, b.QuoteSimpleTableName(`abc`), `"abc"`, "t1")
	assert.Equal(t, b.QuoteSimpleTableName(`"abc"`), `"abc"`, "t2")
	assert.Equal(t, b.QuoteSimpleTableName(`{{abc}}`), `"{{abc}}"`, "t3")
	assert.Equal(t, b.QuoteSimpleTableName(`a.bc`), `"a.bc"`, "t4")
}

func TestStandardBuilder_QuoteSimpleColumnName(t *testing.T) {
	b := getStandardBuilder()
	assert.Equal(t, b.QuoteSimpleColumnName(`abc`), `"abc"`, "t1")
	assert.Equal(t, b.QuoteSimpleColumnName(`"abc"`), `"abc"`, "t2")
	assert.Equal(t, b.QuoteSimpleColumnName(`{{abc}}`), `"{{abc}}"`, "t3")
	assert.Equal(t, b.QuoteSimpleColumnName(`a.bc`), `"a.bc"`, "t4")
	assert.Equal(t, b.QuoteSimpleColumnName(`*`), `*`, "t5")
}

func TestStandardBuilder_Insert(t *testing.T) {
	b := getStandardBuilder()
	q := b.Insert("users", Params{
		"name": "James",
		"age":  30,
	})
	assert.Equal(t, q.SQL(), `INSERT INTO "users" ("age", "name") VALUES ({:p0}, {:p1})`, "t1")
	assert.Equal(t, q.Params()["p0"], 30, "t2")
	assert.Equal(t, q.Params()["p1"], "James", "t3")

	// An empty column set falls back to DEFAULT VALUES.
	q = b.Insert("users", Params{})
	assert.Equal(t, q.SQL(), `INSERT INTO "users" DEFAULT VALUES`, "t2")
}

// The standard builder has no upsert syntax, so Upsert must report an error.
func TestStandardBuilder_Upsert(t *testing.T) {
	b := getStandardBuilder()
	q := b.Upsert("users", Params{
		"name": "James",
		"age":  30,
	})
	assert.NotEqual(t, q.LastError, nil, "t1")
}

func TestStandardBuilder_Update(t *testing.T) {
	b := getStandardBuilder()
	q := b.Update("users", Params{
		"name": "James",
		"age":  30,
	}, NewExp("id=10"))
	assert.Equal(t, q.SQL(), `UPDATE "users" SET "age"={:p0}, "name"={:p1} WHERE id=10`, "t1")
	assert.Equal(t, q.Params()["p0"], 30, "t2")
	assert.Equal(t, q.Params()["p1"], "James", "t3")

	// A nil condition produces no WHERE clause.
	q = b.Update("users", Params{
		"name": "James",
		"age":  30,
	}, nil)
	assert.Equal(t, q.SQL(), `UPDATE "users" SET "age"={:p0}, "name"={:p1}`, "t2")
}

func TestStandardBuilder_Delete(t *testing.T) {
	b := getStandardBuilder()
	q := b.Delete("users", NewExp("id=10"))
	assert.Equal(t, q.SQL(), `DELETE FROM "users" WHERE id=10`, "t1")
	q = b.Delete("users", nil)
	assert.Equal(t, q.SQL(), `DELETE FROM "users"`, "t2")
}

func TestStandardBuilder_CreateTable(t *testing.T) {
	b := getStandardBuilder()
	q := b.CreateTable("users", map[string]string{
		"id":   "int primary key",
		"name": "varchar(255)",
	}, "ON DELETE CASCADE")
	assert.Equal(t, q.SQL(), "CREATE TABLE \"users\" (\"id\" int primary key, \"name\" varchar(255)) ON DELETE CASCADE", "t1")
}

func TestStandardBuilder_RenameTable(t *testing.T) {
	b := getStandardBuilder()
	q := b.RenameTable("users", "user")
	assert.Equal(t, q.SQL(), `RENAME TABLE "users" TO "user"`, "t1")
}

func TestStandardBuilder_DropTable(t *testing.T) {
	b := getStandardBuilder()
	q := b.DropTable("users")
	assert.Equal(t, q.SQL(), `DROP TABLE "users"`, "t1")
}

func TestStandardBuilder_TruncateTable(t *testing.T) {
	b := getStandardBuilder()
	q := b.TruncateTable("users")
	assert.Equal(t, q.SQL(), `TRUNCATE TABLE "users"`, "t1")
}

func TestStandardBuilder_AddColumn(t *testing.T) {
	b := getStandardBuilder()
	q := b.AddColumn("users", "age", "int")
	assert.Equal(t, q.SQL(), `ALTER TABLE "users" ADD "age" int`, "t1")
}

func TestStandardBuilder_DropColumn(t *testing.T) {
	b := getStandardBuilder()
	q := b.DropColumn("users", "age")
	assert.Equal(t, q.SQL(), `ALTER TABLE "users" DROP COLUMN "age"`, "t1")
}

func TestStandardBuilder_RenameColumn(t *testing.T) {
	b := getStandardBuilder()
	q := b.RenameColumn("users", "name", "username")
	assert.Equal(t, q.SQL(), `ALTER TABLE "users" RENAME COLUMN "name" TO "username"`, "t1")
}

func TestStandardBuilder_AlterColumn(t *testing.T) {
	b := getStandardBuilder()
	q := b.AlterColumn("users", "name", "int")
	assert.Equal(t, q.SQL(), `ALTER TABLE "users" CHANGE "name" "name" int`, "t1")
}

func TestStandardBuilder_AddPrimaryKey(t *testing.T) {
	b := getStandardBuilder()
	q := b.AddPrimaryKey("users", "pk", "id1", "id2")
	assert.Equal(t, q.SQL(), `ALTER TABLE "users" ADD CONSTRAINT "pk" PRIMARY KEY ("id1", "id2")`, "t1")
}

func TestStandardBuilder_DropPrimaryKey(t *testing.T) {
	b := getStandardBuilder()
	q := b.DropPrimaryKey("users", "pk")
	assert.Equal(t, q.SQL(), `ALTER TABLE "users" DROP CONSTRAINT "pk"`, "t1")
}

func TestStandardBuilder_AddForeignKey(t *testing.T) {
	b := getStandardBuilder()
	q := b.AddForeignKey("users", "fk", []string{"p1", "p2"}, []string{"f1", "f2"}, "profile", "opt")
	assert.Equal(t, q.SQL(), `ALTER TABLE "users" ADD CONSTRAINT "fk" FOREIGN KEY ("p1", "p2") REFERENCES "profile" ("f1", "f2") opt`, "t1")
}

func TestStandardBuilder_DropForeignKey(t *testing.T) {
	b := getStandardBuilder()
	q := b.DropForeignKey("users", "fk")
	assert.Equal(t, q.SQL(), `ALTER TABLE "users" DROP CONSTRAINT "fk"`, "t1")
}

func TestStandardBuilder_CreateIndex(t *testing.T) {
	b := getStandardBuilder()
	q := b.CreateIndex("users", "idx", "id1", "id2")
	assert.Equal(t, q.SQL(), `CREATE INDEX "idx" ON "users" ("id1", "id2")`, "t1")
}

func TestStandardBuilder_CreateUniqueIndex(t *testing.T) {
	b := getStandardBuilder()
	q := b.CreateUniqueIndex("users", "idx", "id1", "id2")
	assert.Equal(t, q.SQL(), `CREATE UNIQUE INDEX "idx" ON "users" ("id1", "id2")`, "t1")
}

func TestStandardBuilder_DropIndex(t *testing.T) {
	b := getStandardBuilder()
	q := b.DropIndex("users", "idx")
	assert.Equal(t, q.SQL(), `DROP INDEX "idx" ON "users"`, "t1")
}

// getStandardBuilder returns a StandardBuilder on the shared test DB and installs it as that DB's Builder.
func getStandardBuilder() Builder {
	db := getDB()
	b := NewStandardBuilder(db, db.sqlDB)
	db.Builder = b
	return b
}

--------------------------------------------------------------------------------
/db.go:
--------------------------------------------------------------------------------
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
4 | 5 | // Package dbx provides a set of DB-agnostic and easy-to-use query building methods for relational databases. 6 | package dbx 7 | 8 | import ( 9 | "bytes" 10 | "context" 11 | "database/sql" 12 | "regexp" 13 | "strings" 14 | "time" 15 | ) 16 | 17 | type ( 18 | // LogFunc logs a message for each SQL statement being executed. 19 | // This method takes one or multiple parameters. If a single parameter 20 | // is provided, it will be treated as the log message. If multiple parameters 21 | // are provided, they will be passed to fmt.Sprintf() to generate the log message. 22 | LogFunc func(format string, a ...interface{}) 23 | 24 | // PerfFunc is called when a query finishes execution. 25 | // The query execution time is passed to this function so that the DB performance 26 | // can be profiled. The "ns" parameter gives the number of nanoseconds that the 27 | // SQL statement takes to execute, while the "execute" parameter indicates whether 28 | // the SQL statement is executed or queried (usually SELECT statements). 29 | PerfFunc func(ns int64, sql string, execute bool) 30 | 31 | // QueryLogFunc is called each time when performing a SQL query. 32 | // The "t" parameter gives the time that the SQL statement takes to execute, 33 | // while rows and err are the result of the query. 34 | QueryLogFunc func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) 35 | 36 | // ExecLogFunc is called each time when a SQL statement is executed. 37 | // The "t" parameter gives the time that the SQL statement takes to execute, 38 | // while result and err refer to the result of the execution. 39 | ExecLogFunc func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) 40 | 41 | // BuilderFunc creates a Builder instance using the given DB instance and Executor. 42 | BuilderFunc func(*DB, Executor) Builder 43 | 44 | // DB enhances sql.DB by providing a set of DB-agnostic query building methods. 
45 | // DB allows easier query building and population of data into Go variables. 46 | DB struct { 47 | Builder 48 | 49 | // FieldMapper maps struct fields to DB columns. Defaults to DefaultFieldMapFunc. 50 | FieldMapper FieldMapFunc 51 | // TableMapper maps structs to table names. Defaults to GetTableName. 52 | TableMapper TableMapFunc 53 | // LogFunc logs the SQL statements being executed. Defaults to nil, meaning no logging. 54 | LogFunc LogFunc 55 | // PerfFunc logs the SQL execution time. Defaults to nil, meaning no performance profiling. 56 | // Deprecated: Please use QueryLogFunc and ExecLogFunc instead. 57 | PerfFunc PerfFunc 58 | // QueryLogFunc is called each time when performing a SQL query that returns data. 59 | QueryLogFunc QueryLogFunc 60 | // ExecLogFunc is called each time when a SQL statement is executed. 61 | ExecLogFunc ExecLogFunc 62 | 63 | sqlDB *sql.DB 64 | driverName string 65 | ctx context.Context 66 | } 67 | 68 | // Errors represents a list of errors. 69 | Errors []error 70 | ) 71 | 72 | // BuilderFuncMap lists supported BuilderFunc according to DB driver names. 73 | // You may modify this variable to add the builder support for a new DB driver. 74 | // If a DB driver is not listed here, the StandardBuilder will be used. 75 | var BuilderFuncMap = map[string]BuilderFunc{ 76 | "sqlite": NewSqliteBuilder, 77 | "sqlite3": NewSqliteBuilder, 78 | "mysql": NewMysqlBuilder, 79 | "postgres": NewPgsqlBuilder, 80 | "pgx": NewPgsqlBuilder, 81 | "mssql": NewMssqlBuilder, 82 | "oci8": NewOciBuilder, 83 | } 84 | 85 | // NewFromDB encapsulates an existing database connection. 86 | func NewFromDB(sqlDB *sql.DB, driverName string) *DB { 87 | db := &DB{ 88 | driverName: driverName, 89 | sqlDB: sqlDB, 90 | FieldMapper: DefaultFieldMapFunc, 91 | TableMapper: GetTableName, 92 | } 93 | db.Builder = db.newBuilder(db.sqlDB) 94 | return db 95 | } 96 | 97 | // Open opens a database specified by a driver name and data source name (DSN). 
// Note that Open does not check if DSN is specified correctly. It doesn't try to establish a DB connection either.
// Please refer to sql.Open() for more information.
func Open(driverName, dsn string) (*DB, error) {
	sqlDB, err := sql.Open(driverName, dsn)
	if err != nil {
		return nil, err
	}

	return NewFromDB(sqlDB, driverName), nil
}

// MustOpen opens a database and establishes a connection to it by pinging it.
//
// Despite the "Must" prefix, MustOpen does not panic: an error from opening
// or pinging the database is returned to the caller. When the ping fails, the
// freshly opened handle is closed before returning.
// Please refer to sql.Open() and sql.Ping() for more information.
func MustOpen(driverName, dsn string) (*DB, error) {
	db, err := Open(driverName, dsn)
	if err != nil {
		return nil, err
	}
	if err := db.sqlDB.Ping(); err != nil {
		// Best-effort cleanup: the ping error is the one worth reporting, so a
		// secondary Close error is deliberately discarded.
		_ = db.Close()
		return nil, err
	}
	return db, nil
}

// Clone makes a shallow copy of DB.
// The copy shares the underlying *sql.DB and copies the driver name, the
// mappers and the logging hooks, but it gets a newly created Builder and does
// not inherit the context attached via WithContext.
func (db *DB) Clone() *DB {
	db2 := &DB{
		driverName:   db.driverName,
		sqlDB:        db.sqlDB,
		FieldMapper:  db.FieldMapper,
		TableMapper:  db.TableMapper,
		PerfFunc:     db.PerfFunc,
		LogFunc:      db.LogFunc,
		QueryLogFunc: db.QueryLogFunc,
		ExecLogFunc:  db.ExecLogFunc,
	}
	db2.Builder = db2.newBuilder(db.sqlDB)
	return db2
}

// WithContext returns a new instance of DB associated with the given context.
// The receiver itself is left unchanged.
func (db *DB) WithContext(ctx context.Context) *DB {
	db2 := db.Clone()
	db2.ctx = ctx
	return db2
}

// Context returns the context associated with the DB instance.
// It returns nil if no context is associated.
func (db *DB) Context() context.Context {
	return db.ctx
}

// DB returns the sql.DB instance encapsulated by dbx.DB.
func (db *DB) DB() *sql.DB {
	return db.sqlDB
}

// Close closes the database, releasing any open resources.
158 | // It is rare to Close a DB, as the DB handle is meant to be 159 | // long-lived and shared between many goroutines. 160 | func (db *DB) Close() error { 161 | return db.sqlDB.Close() 162 | } 163 | 164 | // Begin starts a transaction. 165 | func (db *DB) Begin() (*Tx, error) { 166 | var tx *sql.Tx 167 | var err error 168 | if db.ctx != nil { 169 | tx, err = db.sqlDB.BeginTx(db.ctx, nil) 170 | } else { 171 | tx, err = db.sqlDB.Begin() 172 | } 173 | if err != nil { 174 | return nil, err 175 | } 176 | return &Tx{db.newBuilder(tx), tx}, nil 177 | } 178 | 179 | // BeginTx starts a transaction with the given context and transaction options. 180 | func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { 181 | tx, err := db.sqlDB.BeginTx(ctx, opts) 182 | if err != nil { 183 | return nil, err 184 | } 185 | return &Tx{db.newBuilder(tx), tx}, nil 186 | } 187 | 188 | // Wrap encapsulates an existing transaction. 189 | func (db *DB) Wrap(sqlTx *sql.Tx) *Tx { 190 | return &Tx{db.newBuilder(sqlTx), sqlTx} 191 | } 192 | 193 | // Transactional starts a transaction and executes the given function. 194 | // If the function returns an error, the transaction will be rolled back. 195 | // Otherwise, the transaction will be committed. 196 | func (db *DB) Transactional(f func(*Tx) error) (err error) { 197 | tx, err := db.Begin() 198 | if err != nil { 199 | return err 200 | } 201 | 202 | defer func() { 203 | if p := recover(); p != nil { 204 | tx.Rollback() 205 | panic(p) 206 | } else if err != nil { 207 | if err2 := tx.Rollback(); err2 != nil { 208 | if err2 == sql.ErrTxDone { 209 | return 210 | } 211 | err = Errors{err, err2} 212 | } 213 | } else { 214 | if err = tx.Commit(); err == sql.ErrTxDone { 215 | err = nil 216 | } 217 | } 218 | }() 219 | 220 | err = f(tx) 221 | 222 | return err 223 | } 224 | 225 | // TransactionalContext starts a transaction and executes the given function with the given context and transaction options. 
// If the function returns an error, the transaction will be rolled back.
// Otherwise, the transaction will be committed.
//
// A panic raised by f rolls the transaction back and is then re-raised. If f
// already ended the transaction itself, the resulting sql.ErrTxDone from the
// final Rollback/Commit is swallowed rather than reported.
func (db *DB) TransactionalContext(ctx context.Context, opts *sql.TxOptions, f func(*Tx) error) (err error) {
	tx, err := db.BeginTx(ctx, opts)
	if err != nil {
		return err
	}

	defer func() {
		if p := recover(); p != nil {
			// Best-effort rollback: the panic takes precedence over any rollback error.
			tx.Rollback()
			panic(p)
		} else if err != nil {
			if err2 := tx.Rollback(); err2 != nil {
				if err2 == sql.ErrTxDone {
					return
				}
				err = Errors{err, err2}
			}
		} else {
			if err = tx.Commit(); err == sql.ErrTxDone {
				err = nil
			}
		}
	}()

	err = f(tx)

	return err
}

// DriverName returns the name of the DB driver.
func (db *DB) DriverName() string {
	return db.driverName
}

// QuoteTableName quotes the given table name appropriately.
// Each dot-separated part (for example the schema in "public.users") is quoted
// separately with QuoteSimpleTableName and the parts are re-joined with ".".
// The name is returned unchanged if it contains "(" (e.g. a sub-query) or "{{"
// (the table-name quoting syntax handled by processSQL).
func (db *DB) QuoteTableName(s string) string {
	if strings.Contains(s, "(") || strings.Contains(s, "{{") {
		return s
	}
	if !strings.Contains(s, ".") {
		return db.QuoteSimpleTableName(s)
	}
	parts := strings.Split(s, ".")
	for i, part := range parts {
		parts[i] = db.QuoteSimpleTableName(part)
	}
	return strings.Join(parts, ".")
}

// QuoteColumnName quotes the given column name appropriately.
// If the column name contains a table name prefix (everything before the last "."), the prefix is quoted with QuoteTableName.
// The name is returned unchanged if it contains "(", "{{" or "[[".
func (db *DB) QuoteColumnName(s string) string {
	switch {
	case strings.Contains(s, "("), strings.Contains(s, "{{"), strings.Contains(s, "[["):
		return s
	}

	var prefix string
	if dot := strings.LastIndex(s, "."); dot >= 0 {
		prefix = db.QuoteTableName(s[:dot]) + "."
		s = s[dot+1:]
	}
	return prefix + db.QuoteSimpleColumnName(s)
}

// Patterns recognized by processSQL:
//   - plRegex matches named parameter placeholders such as {:name};
//   - quoteRegex matches {{table}} and [[column]] names that must be quoted.
var (
	plRegex    = regexp.MustCompile(`\{:\w+\}`)
	quoteRegex = regexp.MustCompile(`(\{\{[\w\-\. ]+\}\}|\[\[[\w\-\. ]+\]\])`)
)

// processSQL rewrites s for execution by the current driver.
// Every named placeholder {:name} is replaced, left to right, by the driver's
// anonymous placeholder for its 1-based position (see GeneratePlaceholder), and
// names wrapped in {{...}} or [[...]] are quoted as table or column names
// respectively. It returns the rewritten SQL together with the parameter names
// in placeholder order; a name repeated in s appears once per occurrence.
func (db *DB) processSQL(s string) (string, []string) {
	var names []string
	s = plRegex.ReplaceAllStringFunc(s, func(match string) string {
		names = append(names, match[2:len(match)-1])
		return db.GeneratePlaceholder(len(names))
	})
	s = quoteRegex.ReplaceAllStringFunc(s, func(match string) string {
		inner := match[2 : len(match)-2]
		if match[0] == '{' {
			return db.QuoteTableName(inner)
		}
		return db.QuoteColumnName(inner)
	})
	return s, names
}

// newBuilder returns the Builder registered in BuilderFuncMap for the DB's
// driver name, falling back to NewStandardBuilder for unknown drivers.
func (db *DB) newBuilder(executor Executor) Builder {
	if build, found := BuilderFuncMap[db.driverName]; found {
		return build(db, executor)
	}
	return NewStandardBuilder(db, executor)
}

// Error implements the error interface: it returns the messages of all the
// contained errors joined by newlines.
329 | func (errs Errors) Error() string { 330 | var b bytes.Buffer 331 | for i, e := range errs { 332 | if i > 0 { 333 | b.WriteRune('\n') 334 | } 335 | b.WriteString(e.Error()) 336 | } 337 | return b.String() 338 | } 339 | -------------------------------------------------------------------------------- /db_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package dbx 6 | 7 | import ( 8 | "context" 9 | "database/sql" 10 | "errors" 11 | "io/ioutil" 12 | "strings" 13 | "testing" 14 | 15 | // @todo change to sqlite 16 | _ "github.com/go-sql-driver/mysql" 17 | "github.com/stretchr/testify/assert" 18 | ) 19 | 20 | const ( 21 | TestDSN = "travis:@/pocketbase_dbx_test?parseTime=true" 22 | FixtureFile = "testdata/mysql.sql" 23 | ) 24 | 25 | func TestDB_NewFromDB(t *testing.T) { 26 | sqlDB, err := sql.Open("mysql", TestDSN) 27 | if assert.Nil(t, err) { 28 | db := NewFromDB(sqlDB, "mysql") 29 | assert.NotNil(t, db.sqlDB) 30 | assert.NotNil(t, db.FieldMapper) 31 | } 32 | } 33 | 34 | func TestDB_Open(t *testing.T) { 35 | db, err := Open("mysql", TestDSN) 36 | assert.Nil(t, err) 37 | if assert.NotNil(t, db) { 38 | assert.NotNil(t, db.sqlDB) 39 | assert.NotNil(t, db.FieldMapper) 40 | db2 := db.Clone() 41 | assert.NotEqual(t, db, db2) 42 | assert.Equal(t, db.driverName, db2.driverName) 43 | ctx := context.Background() 44 | db3 := db.WithContext(ctx) 45 | assert.Equal(t, ctx, db3.ctx) 46 | assert.Equal(t, ctx, db3.Context()) 47 | assert.NotEqual(t, db, db3) 48 | } 49 | 50 | _, err = Open("xyz", TestDSN) 51 | assert.NotNil(t, err) 52 | } 53 | 54 | func TestDB_MustOpen(t *testing.T) { 55 | _, err := MustOpen("mysql", TestDSN) 56 | assert.Nil(t, err) 57 | 58 | _, err = MustOpen("mysql", "unknown:x@/test") 59 | assert.NotNil(t, err) 60 | } 61 | 62 | func TestDB_Close(t 
*testing.T) { 63 | db := getDB() 64 | assert.Nil(t, db.Close()) 65 | } 66 | 67 | func TestDB_DriverName(t *testing.T) { 68 | db := getDB() 69 | assert.Equal(t, "mysql", db.DriverName()) 70 | } 71 | 72 | func TestDB_QuoteTableName(t *testing.T) { 73 | tests := []struct { 74 | input, output string 75 | }{ 76 | {"users", "`users`"}, 77 | {"`users`", "`users`"}, 78 | {"(select)", "(select)"}, 79 | {"{{users}}", "{{users}}"}, 80 | {"public.db1.users", "`public`.`db1`.`users`"}, 81 | } 82 | db := getDB() 83 | for _, test := range tests { 84 | result := db.QuoteTableName(test.input) 85 | assert.Equal(t, test.output, result, test.input) 86 | } 87 | } 88 | 89 | func TestDB_QuoteColumnName(t *testing.T) { 90 | tests := []struct { 91 | input, output string 92 | }{ 93 | {"*", "*"}, 94 | {"users.*", "`users`.*"}, 95 | {"name", "`name`"}, 96 | {"`name`", "`name`"}, 97 | {"(select)", "(select)"}, 98 | {"{{name}}", "{{name}}"}, 99 | {"[[name]]", "[[name]]"}, 100 | {"public.db1.users", "`public`.`db1`.`users`"}, 101 | } 102 | db := getDB() 103 | for _, test := range tests { 104 | result := db.QuoteColumnName(test.input) 105 | assert.Equal(t, test.output, result, test.input) 106 | } 107 | } 108 | 109 | func TestDB_ProcessSQL(t *testing.T) { 110 | tests := []struct { 111 | tag string 112 | sql string // original SQL 113 | mysql string // expected MySQL version 114 | postgres string // expected PostgreSQL version 115 | oci8 string // expected OCI version 116 | params []string // expected params 117 | }{ 118 | { 119 | "normal case", 120 | `INSERT INTO employee (id, name, age) VALUES ({:id}, {:name}, {:age})`, 121 | `INSERT INTO employee (id, name, age) VALUES (?, ?, ?)`, 122 | `INSERT INTO employee (id, name, age) VALUES ($1, $2, $3)`, 123 | `INSERT INTO employee (id, name, age) VALUES (:p1, :p2, :p3)`, 124 | []string{"id", "name", "age"}, 125 | }, 126 | { 127 | "the same placeholder is used twice", 128 | `SELECT * FROM employee WHERE first_name LIKE {:keyword} OR last_name LIKE 
{:keyword}`, 129 | `SELECT * FROM employee WHERE first_name LIKE ? OR last_name LIKE ?`, 130 | `SELECT * FROM employee WHERE first_name LIKE $1 OR last_name LIKE $2`, 131 | `SELECT * FROM employee WHERE first_name LIKE :p1 OR last_name LIKE :p2`, 132 | []string{"keyword", "keyword"}, 133 | }, 134 | { 135 | "non-matching placeholder", 136 | `SELECT * FROM employee WHERE first_name LIKE "{:key?word}" OR last_name LIKE {:keyword}`, 137 | `SELECT * FROM employee WHERE first_name LIKE "{:key?word}" OR last_name LIKE ?`, 138 | `SELECT * FROM employee WHERE first_name LIKE "{:key?word}" OR last_name LIKE $1`, 139 | `SELECT * FROM employee WHERE first_name LIKE "{:key?word}" OR last_name LIKE :p1`, 140 | []string{"keyword"}, 141 | }, 142 | { 143 | "quote table/column names", 144 | `SELECT * FROM {{public.user}} WHERE [[user.id]]=1`, 145 | "SELECT * FROM `public`.`user` WHERE `user`.`id`=1", 146 | "SELECT * FROM \"public\".\"user\" WHERE \"user\".\"id\"=1", 147 | "SELECT * FROM \"public\".\"user\" WHERE \"user\".\"id\"=1", 148 | nil, 149 | }, 150 | } 151 | 152 | mysqlDB := getDB() 153 | mysqlDB.Builder = NewMysqlBuilder(nil, nil) 154 | pgsqlDB := getDB() 155 | pgsqlDB.Builder = NewPgsqlBuilder(nil, nil) 156 | ociDB := getDB() 157 | ociDB.Builder = NewOciBuilder(nil, nil) 158 | 159 | for _, test := range tests { 160 | s1, names := mysqlDB.processSQL(test.sql) 161 | assert.Equal(t, test.mysql, s1, test.tag) 162 | s2, _ := pgsqlDB.processSQL(test.sql) 163 | assert.Equal(t, test.postgres, s2, test.tag) 164 | s3, _ := ociDB.processSQL(test.sql) 165 | assert.Equal(t, test.oci8, s3, test.tag) 166 | 167 | assert.Equal(t, test.params, names, test.tag) 168 | } 169 | } 170 | 171 | func TestDB_Begin(t *testing.T) { 172 | tests := []struct { 173 | makeTx func(db *DB) *Tx 174 | desc string 175 | }{ 176 | { 177 | makeTx: func(db *DB) *Tx { 178 | tx, _ := db.Begin() 179 | return tx 180 | }, 181 | desc: "Begin", 182 | }, 183 | { 184 | makeTx: func(db *DB) *Tx { 185 | sqlTx, _ := 
db.DB().Begin() 186 | return db.Wrap(sqlTx) 187 | }, 188 | desc: "Wrap", 189 | }, 190 | { 191 | makeTx: func(db *DB) *Tx { 192 | tx, _ := db.BeginTx(context.Background(), nil) 193 | return tx 194 | }, 195 | desc: "BeginTx", 196 | }, 197 | } 198 | 199 | db := getPreparedDB() 200 | 201 | var ( 202 | lastID int 203 | name string 204 | tx *Tx 205 | ) 206 | db.NewQuery("SELECT MAX(id) FROM item").Row(&lastID) 207 | 208 | for _, test := range tests { 209 | t.Log(test.desc) 210 | 211 | tx = test.makeTx(db) 212 | _, err1 := tx.Insert("item", Params{ 213 | "name": "name1", 214 | }).Execute() 215 | _, err2 := tx.Insert("item", Params{ 216 | "name": "name2", 217 | }).Execute() 218 | if err1 == nil && err2 == nil { 219 | tx.Commit() 220 | } else { 221 | t.Errorf("Unexpected TX rollback: %v, %v", err1, err2) 222 | tx.Rollback() 223 | } 224 | 225 | q := db.NewQuery("SELECT name FROM item WHERE id={:id}") 226 | q.Bind(Params{"id": lastID + 1}).Row(&name) 227 | assert.Equal(t, "name1", name) 228 | q.Bind(Params{"id": lastID + 2}).Row(&name) 229 | assert.Equal(t, "name2", name) 230 | 231 | tx = test.makeTx(db) 232 | _, err3 := tx.NewQuery("DELETE FROM item WHERE id=7").Execute() 233 | _, err4 := tx.NewQuery("DELETE FROM items WHERE id=7").Execute() 234 | if err3 == nil && err4 == nil { 235 | t.Error("Unexpected TX commit") 236 | tx.Commit() 237 | } else { 238 | tx.Rollback() 239 | } 240 | } 241 | } 242 | 243 | func TestDB_Transactional(t *testing.T) { 244 | db := getPreparedDB() 245 | 246 | var ( 247 | lastID int 248 | name string 249 | ) 250 | db.NewQuery("SELECT MAX(id) FROM item").Row(&lastID) 251 | 252 | err := db.Transactional(func(tx *Tx) error { 253 | _, err := tx.Insert("item", Params{ 254 | "name": "name1", 255 | }).Execute() 256 | if err != nil { 257 | return err 258 | } 259 | _, err = tx.Insert("item", Params{ 260 | "name": "name2", 261 | }).Execute() 262 | if err != nil { 263 | return err 264 | } 265 | return nil 266 | }) 267 | 268 | if assert.Nil(t, err) { 269 | q := 
db.NewQuery("SELECT name FROM item WHERE id={:id}") 270 | q.Bind(Params{"id": lastID + 1}).Row(&name) 271 | assert.Equal(t, "name1", name) 272 | q.Bind(Params{"id": lastID + 2}).Row(&name) 273 | assert.Equal(t, "name2", name) 274 | } 275 | 276 | err = db.Transactional(func(tx *Tx) error { 277 | _, err := tx.NewQuery("DELETE FROM item WHERE id=2").Execute() 278 | if err != nil { 279 | return err 280 | } 281 | _, err = tx.NewQuery("DELETE FROM items WHERE id=2").Execute() 282 | if err != nil { 283 | return err 284 | } 285 | return nil 286 | }) 287 | if assert.NotNil(t, err) { 288 | db.NewQuery("SELECT name FROM item WHERE id=2").Row(&name) 289 | assert.Equal(t, "Go in Action", name) 290 | } 291 | 292 | // Rollback called within Transactional and return error 293 | err = db.Transactional(func(tx *Tx) error { 294 | _, err := tx.NewQuery("DELETE FROM item WHERE id=2").Execute() 295 | if err != nil { 296 | return err 297 | } 298 | _, err = tx.NewQuery("DELETE FROM items WHERE id=2").Execute() 299 | if err != nil { 300 | tx.Rollback() 301 | return err 302 | } 303 | return nil 304 | }) 305 | if assert.NotNil(t, err) { 306 | db.NewQuery("SELECT name FROM item WHERE id=2").Row(&name) 307 | assert.Equal(t, "Go in Action", name) 308 | } 309 | 310 | // Rollback called within Transactional without returning error 311 | err = db.Transactional(func(tx *Tx) error { 312 | _, err := tx.NewQuery("DELETE FROM item WHERE id=2").Execute() 313 | if err != nil { 314 | return err 315 | } 316 | _, err = tx.NewQuery("DELETE FROM items WHERE id=2").Execute() 317 | if err != nil { 318 | tx.Rollback() 319 | return nil 320 | } 321 | return nil 322 | }) 323 | if assert.Nil(t, err) { 324 | db.NewQuery("SELECT name FROM item WHERE id=2").Row(&name) 325 | assert.Equal(t, "Go in Action", name) 326 | } 327 | } 328 | 329 | func TestErrors_Error(t *testing.T) { 330 | errs := Errors{} 331 | assert.Equal(t, "", errs.Error()) 332 | errs = Errors{errors.New("a")} 333 | assert.Equal(t, "a", errs.Error()) 334 
| errs = Errors{errors.New("a"), errors.New("b")} 335 | assert.Equal(t, "a\nb", errs.Error()) 336 | } 337 | 338 | func getDB() *DB { 339 | db, err := Open("mysql", TestDSN) 340 | if err != nil { 341 | panic(err) 342 | } 343 | return db 344 | } 345 | 346 | func getPreparedDB() *DB { 347 | db := getDB() 348 | s, err := ioutil.ReadFile(FixtureFile) 349 | if err != nil { 350 | panic(err) 351 | } 352 | lines := strings.Split(string(s), ";") 353 | for _, line := range lines { 354 | if strings.TrimSpace(line) == "" { 355 | continue 356 | } 357 | if _, err := db.NewQuery(line).Execute(); err != nil { 358 | panic(err) 359 | } 360 | } 361 | return db 362 | } 363 | 364 | // Naming according to issue 49 ( https://github.com/pocketbase/dbx/issues/49 ) 365 | 366 | type ArtistDAO struct { 367 | nickname string 368 | } 369 | 370 | func (ArtistDAO) TableName() string { 371 | return "artists" 372 | } 373 | 374 | func Test_TableNameWithPrefix(t *testing.T) { 375 | db := NewFromDB(nil, "mysql") 376 | db.TableMapper = func(a interface{}) string { 377 | return "tbl_" + GetTableName(a) 378 | } 379 | assert.Equal(t, "tbl_artists", db.TableMapper(ArtistDAO{})) 380 | } 381 | -------------------------------------------------------------------------------- /example_test.go: -------------------------------------------------------------------------------- 1 | package dbx_test 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/pocketbase/dbx" 7 | ) 8 | 9 | // This example shows how to populate DB data in different ways. 
10 | func Example_dbQueries() { 11 | db, _ := dbx.Open("mysql", "user:pass@/example") 12 | 13 | // create a new query 14 | q := db.NewQuery("SELECT id, name FROM users LIMIT 10") 15 | 16 | // fetch all rows into a struct array 17 | var users []struct { 18 | ID, Name string 19 | } 20 | q.All(&users) 21 | 22 | // fetch a single row into a struct 23 | var user struct { 24 | ID, Name string 25 | } 26 | q.One(&user) 27 | 28 | // fetch a single row into a string map 29 | data := dbx.NullStringMap{} 30 | q.One(data) 31 | 32 | // fetch row by row 33 | rows2, _ := q.Rows() 34 | for rows2.Next() { 35 | rows2.ScanStruct(&user) 36 | // rows.ScanMap(data) 37 | // rows.Scan(&id, &name) 38 | } 39 | } 40 | 41 | // This example shows how to use query builder to build DB queries. 42 | func Example_queryBuilder() { 43 | db, _ := dbx.Open("mysql", "user:pass@/example") 44 | 45 | // build a SELECT query 46 | // SELECT `id`, `name` FROM `users` WHERE `name` LIKE '%Charles%' ORDER BY `id` 47 | q := db.Select("id", "name"). 48 | From("users"). 49 | Where(dbx.Like("name", "Charles")). 50 | OrderBy("id") 51 | 52 | // fetch all rows into a struct array 53 | var users []struct { 54 | ID, Name string 55 | } 56 | q.All(&users) 57 | 58 | // build an INSERT query 59 | // INSERT INTO `users` (name) VALUES ('James') 60 | db.Insert("users", dbx.Params{ 61 | "name": "James", 62 | }).Execute() 63 | } 64 | 65 | // This example shows how to use query builder in transactions. 66 | func Example_transactions() { 67 | db, _ := dbx.Open("mysql", "user:pass@/example") 68 | 69 | db.Transactional(func(tx *dbx.Tx) error { 70 | _, err := tx.Insert("user", dbx.Params{ 71 | "name": "user1", 72 | }).Execute() 73 | if err != nil { 74 | return err 75 | } 76 | _, err = tx.Insert("user", dbx.Params{ 77 | "name": "user2", 78 | }).Execute() 79 | return err 80 | }) 81 | } 82 | 83 | type Customer struct { 84 | ID string 85 | Name string 86 | } 87 | 88 | // This example shows how to do CRUD operations. 
89 | func Example_crudOperations() { 90 | db, _ := dbx.Open("mysql", "user:pass@/example") 91 | 92 | var customer Customer 93 | 94 | // read a customer: SELECT * FROM customer WHERE id=100 95 | db.Select().Model(100, &customer) 96 | 97 | // create a customer: INSERT INTO customer (name) VALUES ('test') 98 | db.Model(&customer).Insert() 99 | 100 | // update a customer: UPDATE customer SET name='test' WHERE id=100 101 | db.Model(&customer).Update() 102 | 103 | // delete a customer: DELETE FROM customer WHERE id=100 104 | db.Model(&customer).Delete() 105 | } 106 | 107 | func ExampleSchemaBuilder() { 108 | db, _ := dbx.Open("mysql", "user:pass@/example") 109 | 110 | db.Insert("users", dbx.Params{ 111 | "name": "James", 112 | "age": 30, 113 | }).Execute() 114 | } 115 | 116 | func ExampleRows_ScanMap() { 117 | db, _ := dbx.Open("mysql", "user:pass@/example") 118 | 119 | user := dbx.NullStringMap{} 120 | 121 | sql := "SELECT id, name FROM users LIMIT 10" 122 | rows, _ := db.NewQuery(sql).Rows() 123 | for rows.Next() { 124 | rows.ScanMap(user) 125 | // ... 126 | } 127 | } 128 | 129 | func ExampleRows_ScanStruct() { 130 | db, _ := dbx.Open("mysql", "user:pass@/example") 131 | 132 | var user struct { 133 | ID, Name string 134 | } 135 | 136 | sql := "SELECT id, name FROM users LIMIT 10" 137 | rows, _ := db.NewQuery(sql).Rows() 138 | for rows.Next() { 139 | rows.ScanStruct(&user) 140 | // ... 
141 | } 142 | } 143 | 144 | func ExampleQuery_All() { 145 | db, _ := dbx.Open("mysql", "user:pass@/example") 146 | sql := "SELECT id, name FROM users LIMIT 10" 147 | 148 | // fetches data into a slice of struct 149 | var users []struct { 150 | ID, Name string 151 | } 152 | db.NewQuery(sql).All(&users) 153 | 154 | // fetches data into a slice of NullStringMap 155 | var users2 []dbx.NullStringMap 156 | db.NewQuery(sql).All(&users2) 157 | for _, user := range users2 { 158 | fmt.Println(user["id"].String, user["name"].String) 159 | } 160 | } 161 | 162 | func ExampleQuery_One() { 163 | db, _ := dbx.Open("mysql", "user:pass@/example") 164 | sql := "SELECT id, name FROM users LIMIT 10" 165 | 166 | // fetches data into a struct 167 | var user struct { 168 | ID, Name string 169 | } 170 | db.NewQuery(sql).One(&user) 171 | 172 | // fetches data into a NullStringMap 173 | var user2 dbx.NullStringMap 174 | db.NewQuery(sql).All(user2) 175 | fmt.Println(user2["id"].String, user2["name"].String) 176 | } 177 | 178 | func ExampleQuery_Row() { 179 | db, _ := dbx.Open("mysql", "user:pass@/example") 180 | sql := "SELECT id, name FROM users LIMIT 10" 181 | 182 | // fetches data into a struct 183 | var ( 184 | id int 185 | name string 186 | ) 187 | db.NewQuery(sql).Row(&id, &name) 188 | } 189 | 190 | func ExampleQuery_Rows() { 191 | var user struct { 192 | ID, Name string 193 | } 194 | 195 | db, _ := dbx.Open("mysql", "user:pass@/example") 196 | sql := "SELECT id, name FROM users LIMIT 10" 197 | 198 | rows, _ := db.NewQuery(sql).Rows() 199 | for rows.Next() { 200 | rows.ScanStruct(&user) 201 | // ... 
202 | } 203 | } 204 | 205 | func ExampleQuery_Bind() { 206 | var user struct { 207 | ID, Name string 208 | } 209 | 210 | db, _ := dbx.Open("mysql", "user:pass@/example") 211 | sql := "SELECT id, name FROM users WHERE age>{:age} AND status={:status}" 212 | 213 | q := db.NewQuery(sql) 214 | q.Bind(dbx.Params{"age": 30, "status": 1}).One(&user) 215 | } 216 | 217 | func ExampleQuery_Prepare() { 218 | var users1, users2, users3 []struct { 219 | ID, Name string 220 | } 221 | 222 | db, _ := dbx.Open("mysql", "user:pass@/example") 223 | sql := "SELECT id, name FROM users WHERE age>{:age} AND status={:status}" 224 | 225 | q := db.NewQuery(sql).Prepare() 226 | defer q.Close() 227 | 228 | q.Bind(dbx.Params{"age": 30, "status": 1}).All(&users1) 229 | q.Bind(dbx.Params{"age": 20, "status": 1}).All(&users2) 230 | q.Bind(dbx.Params{"age": 10, "status": 1}).All(&users3) 231 | } 232 | 233 | func ExampleDB() { 234 | db, _ := dbx.Open("mysql", "user:pass@/example") 235 | 236 | // queries data through a plain SQL 237 | var users []struct { 238 | ID, Name string 239 | } 240 | db.NewQuery("SELECT id, name FROM users WHERE age=30").All(&users) 241 | 242 | // queries data using query builder 243 | db.Select("id", "name").From("users").Where(dbx.HashExp{"age": 30}).All(&users) 244 | 245 | // executes a plain SQL 246 | db.NewQuery("INSERT INTO users (name) SET ({:name})").Bind(dbx.Params{"name": "James"}).Execute() 247 | 248 | // executes a SQL using query builder 249 | db.Insert("users", dbx.Params{"name": "James"}).Execute() 250 | } 251 | 252 | func ExampleDB_Open() { 253 | db, err := dbx.Open("mysql", "user:pass@/example") 254 | if err != nil { 255 | panic(err) 256 | } 257 | 258 | var users []dbx.NullStringMap 259 | if err := db.NewQuery("SELECT * FROM users LIMIT 10").All(&users); err != nil { 260 | panic(err) 261 | } 262 | } 263 | 264 | func ExampleDB_Begin() { 265 | db, _ := dbx.Open("mysql", "user:pass@/example") 266 | 267 | tx, _ := db.Begin() 268 | 269 | _, err1 := 
tx.Insert("user", dbx.Params{ 270 | "name": "user1", 271 | }).Execute() 272 | _, err2 := tx.Insert("user", dbx.Params{ 273 | "name": "user2", 274 | }).Execute() 275 | 276 | if err1 == nil && err2 == nil { 277 | tx.Commit() 278 | } else { 279 | tx.Rollback() 280 | } 281 | } 282 | -------------------------------------------------------------------------------- /expression.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package dbx 6 | 7 | import ( 8 | "fmt" 9 | "sort" 10 | "strings" 11 | ) 12 | 13 | // Expression represents a DB expression that can be embedded in a SQL statement. 14 | type Expression interface { 15 | // Build converts an expression into a SQL fragment. 16 | // If the expression contains binding parameters, they will be added to the given Params. 17 | Build(*DB, Params) string 18 | } 19 | 20 | // HashExp represents a hash expression. 21 | // 22 | // A hash expression is a map whose keys are DB column names which need to be filtered according 23 | // to the corresponding values. For example, HashExp{"level": 2, "dept": 10} will generate 24 | // the SQL: "level"=2 AND "dept"=10. 25 | // 26 | // HashExp also handles nil values and slice values. For example, HashExp{"level": []interface{}{1, 2}, "dept": nil} 27 | // will generate: "level" IN (1, 2) AND "dept" IS NULL. 28 | type HashExp map[string]interface{} 29 | 30 | // NewExp generates an expression with the specified SQL fragment and the optional binding parameters. 31 | func NewExp(e string, params ...Params) Expression { 32 | if len(params) > 0 { 33 | return &Exp{e, params[0]} 34 | } 35 | return &Exp{e, nil} 36 | } 37 | 38 | // Not generates a NOT expression which prefixes "NOT" to the specified expression. 
39 | func Not(e Expression) Expression { 40 | return &NotExp{e} 41 | } 42 | 43 | // And generates an AND expression which concatenates the given expressions with "AND". 44 | func And(exps ...Expression) Expression { 45 | return &AndOrExp{exps, "AND"} 46 | } 47 | 48 | // Or generates an OR expression which concatenates the given expressions with "OR". 49 | func Or(exps ...Expression) Expression { 50 | return &AndOrExp{exps, "OR"} 51 | } 52 | 53 | // In generates an IN expression for the specified column and the list of allowed values. 54 | // If values is empty, a SQL "0=1" will be generated which represents a false expression. 55 | func In(col string, values ...interface{}) Expression { 56 | return &InExp{col, values, false} 57 | } 58 | 59 | // NotIn generates an NOT IN expression for the specified column and the list of disallowed values. 60 | // If values is empty, an empty string will be returned indicating a true expression. 61 | func NotIn(col string, values ...interface{}) Expression { 62 | return &InExp{col, values, true} 63 | } 64 | 65 | // DefaultLikeEscape specifies the default special character escaping for LIKE expressions 66 | // The strings at 2i positions are the special characters to be escaped while those at 2i+1 positions 67 | // are the corresponding escaped versions. 68 | var DefaultLikeEscape = []string{"\\", "\\\\", "%", "\\%", "_", "\\_"} 69 | 70 | // Like generates a LIKE expression for the specified column and the possible strings that the column should be like. 71 | // If multiple values are present, the column should be like *all* of them. For example, Like("name", "key", "word") 72 | // will generate a SQL expression: "name" LIKE "%key%" AND "name" LIKE "%word%". 73 | // 74 | // By default, each value will be surrounded by "%" to enable partial matching. If a value contains special characters 75 | // such as "%", "\", "_", they will also be properly escaped. 
76 | // 77 | // You may call Escape() and/or Match() to change the default behavior. For example, Like("name", "key").Match(false, true) 78 | // generates "name" LIKE "key%". 79 | func Like(col string, values ...string) *LikeExp { 80 | return &LikeExp{ 81 | left: true, 82 | right: true, 83 | col: col, 84 | values: values, 85 | escape: DefaultLikeEscape, 86 | Like: "LIKE", 87 | } 88 | } 89 | 90 | // NotLike generates a NOT LIKE expression. 91 | // For example, NotLike("name", "key", "word") will generate a SQL expression: 92 | // "name" NOT LIKE "%key%" AND "name" NOT LIKE "%word%". Please see Like() for more details. 93 | func NotLike(col string, values ...string) *LikeExp { 94 | return &LikeExp{ 95 | left: true, 96 | right: true, 97 | col: col, 98 | values: values, 99 | escape: DefaultLikeEscape, 100 | Like: "NOT LIKE", 101 | } 102 | } 103 | 104 | // OrLike generates an OR LIKE expression. 105 | // This is similar to Like() except that the column should be like one of the possible values. 106 | // For example, OrLike("name", "key", "word") will generate a SQL expression: 107 | // "name" LIKE "%key%" OR "name" LIKE "%word%". Please see Like() for more details. 108 | func OrLike(col string, values ...string) *LikeExp { 109 | return &LikeExp{ 110 | or: true, 111 | left: true, 112 | right: true, 113 | col: col, 114 | values: values, 115 | escape: DefaultLikeEscape, 116 | Like: "LIKE", 117 | } 118 | } 119 | 120 | // OrNotLike generates an OR NOT LIKE expression. 121 | // For example, OrNotLike("name", "key", "word") will generate a SQL expression: 122 | // "name" NOT LIKE "%key%" OR "name" NOT LIKE "%word%". Please see Like() for more details. 
123 | func OrNotLike(col string, values ...string) *LikeExp { 124 | return &LikeExp{ 125 | or: true, 126 | left: true, 127 | right: true, 128 | col: col, 129 | values: values, 130 | escape: DefaultLikeEscape, 131 | Like: "NOT LIKE", 132 | } 133 | } 134 | 135 | // Exists generates an EXISTS expression by prefixing "EXISTS" to the given expression. 136 | func Exists(exp Expression) Expression { 137 | return &ExistsExp{exp, false} 138 | } 139 | 140 | // NotExists generates an EXISTS expression by prefixing "NOT EXISTS" to the given expression. 141 | func NotExists(exp Expression) Expression { 142 | return &ExistsExp{exp, true} 143 | } 144 | 145 | // Between generates a BETWEEN expression. 146 | // For example, Between("age", 10, 30) generates: "age" BETWEEN 10 AND 30 147 | func Between(col string, from, to interface{}) Expression { 148 | return &BetweenExp{col, from, to, false} 149 | } 150 | 151 | // NotBetween generates a NOT BETWEEN expression. 152 | // For example, NotBetween("age", 10, 30) generates: "age" NOT BETWEEN 10 AND 30 153 | func NotBetween(col string, from, to interface{}) Expression { 154 | return &BetweenExp{col, from, to, true} 155 | } 156 | 157 | // Exp represents an expression with a SQL fragment and a list of optional binding parameters. 158 | type Exp struct { 159 | e string 160 | params Params 161 | } 162 | 163 | // Build converts an expression into a SQL fragment. 164 | func (e *Exp) Build(db *DB, params Params) string { 165 | if len(e.params) == 0 { 166 | return e.e 167 | } 168 | for k, v := range e.params { 169 | params[k] = v 170 | } 171 | return e.e 172 | } 173 | 174 | // Build converts an expression into a SQL fragment. 
175 | func (e HashExp) Build(db *DB, params Params) string { 176 | if len(e) == 0 { 177 | return "" 178 | } 179 | 180 | // ensure the hash exp generates the same SQL for different runs 181 | names := []string{} 182 | for name := range e { 183 | names = append(names, name) 184 | } 185 | sort.Strings(names) 186 | 187 | var parts []string 188 | for _, name := range names { 189 | value := e[name] 190 | switch value.(type) { 191 | case nil: 192 | name = db.QuoteColumnName(name) 193 | parts = append(parts, name+" IS NULL") 194 | case Expression: 195 | if sql := value.(Expression).Build(db, params); sql != "" { 196 | parts = append(parts, "("+sql+")") 197 | } 198 | case []interface{}: 199 | in := In(name, value.([]interface{})...) 200 | if sql := in.Build(db, params); sql != "" { 201 | parts = append(parts, sql) 202 | } 203 | default: 204 | pn := fmt.Sprintf("p%v", len(params)) 205 | name = db.QuoteColumnName(name) 206 | parts = append(parts, name+"={:"+pn+"}") 207 | params[pn] = value 208 | } 209 | } 210 | if len(parts) == 1 { 211 | return parts[0] 212 | } 213 | return strings.Join(parts, " AND ") 214 | } 215 | 216 | // NotExp represents an expression that should prefix "NOT" to a specified expression. 217 | type NotExp struct { 218 | e Expression 219 | } 220 | 221 | // Build converts an expression into a SQL fragment. 222 | func (e *NotExp) Build(db *DB, params Params) string { 223 | if sql := e.e.Build(db, params); sql != "" { 224 | return "NOT (" + sql + ")" 225 | } 226 | return "" 227 | } 228 | 229 | // AndOrExp represents an expression that concatenates multiple expressions using either "AND" or "OR". 230 | type AndOrExp struct { 231 | exps []Expression 232 | op string 233 | } 234 | 235 | // Build converts an expression into a SQL fragment. 
236 | func (e *AndOrExp) Build(db *DB, params Params) string { 237 | if len(e.exps) == 0 { 238 | return "" 239 | } 240 | 241 | var parts []string 242 | for _, a := range e.exps { 243 | if a == nil { 244 | continue 245 | } 246 | if sql := a.Build(db, params); sql != "" { 247 | parts = append(parts, sql) 248 | } 249 | } 250 | if len(parts) == 1 { 251 | return parts[0] 252 | } 253 | return "(" + strings.Join(parts, ") "+e.op+" (") + ")" 254 | } 255 | 256 | // InExp represents an "IN" or "NOT IN" expression. 257 | type InExp struct { 258 | col string 259 | values []interface{} 260 | not bool 261 | } 262 | 263 | // Build converts an expression into a SQL fragment. 264 | func (e *InExp) Build(db *DB, params Params) string { 265 | if len(e.values) == 0 { 266 | if e.not { 267 | return "" 268 | } 269 | return "0=1" 270 | } 271 | 272 | var values []string 273 | for _, value := range e.values { 274 | switch value.(type) { 275 | case nil: 276 | values = append(values, "NULL") 277 | case Expression: 278 | sql := value.(Expression).Build(db, params) 279 | values = append(values, sql) 280 | default: 281 | name := fmt.Sprintf("p%v", len(params)) 282 | params[name] = value 283 | values = append(values, "{:"+name+"}") 284 | } 285 | } 286 | col := db.QuoteColumnName(e.col) 287 | if len(values) == 1 { 288 | if e.not { 289 | return col + "<>" + values[0] 290 | } 291 | return col + "=" + values[0] 292 | } 293 | in := "IN" 294 | if e.not { 295 | in = "NOT IN" 296 | } 297 | return fmt.Sprintf("%v %v (%v)", col, in, strings.Join(values, ", ")) 298 | } 299 | 300 | // LikeExp represents a variant of LIKE expressions. 301 | type LikeExp struct { 302 | or bool 303 | left, right bool 304 | col string 305 | values []string 306 | escape []string 307 | 308 | // Like stores the LIKE operator. It can be "LIKE", "NOT LIKE". 309 | // It may also be customized as something like "ILIKE". 310 | Like string 311 | } 312 | 313 | // Escape specifies how a LIKE expression should be escaped. 
314 | // Each string at position 2i represents a special character and the string at position 2i+1 is 315 | // the corresponding escaped version. 316 | func (e *LikeExp) Escape(chars ...string) *LikeExp { 317 | e.escape = chars 318 | return e 319 | } 320 | 321 | // Match specifies whether to do wildcard matching on the left and/or right of given strings. 322 | func (e *LikeExp) Match(left, right bool) *LikeExp { 323 | e.left, e.right = left, right 324 | return e 325 | } 326 | 327 | // Build converts an expression into a SQL fragment. 328 | func (e *LikeExp) Build(db *DB, params Params) string { 329 | if len(e.values) == 0 { 330 | return "" 331 | } 332 | 333 | if len(e.escape)%2 != 0 { 334 | panic("LikeExp.Escape must be a slice of even number of strings") 335 | } 336 | 337 | var parts []string 338 | col := db.QuoteColumnName(e.col) 339 | for _, value := range e.values { 340 | name := fmt.Sprintf("p%v", len(params)) 341 | for i := 0; i < len(e.escape); i += 2 { 342 | value = strings.Replace(value, e.escape[i], e.escape[i+1], -1) 343 | } 344 | if e.left { 345 | value = "%" + value 346 | } 347 | if e.right { 348 | value += "%" 349 | } 350 | params[name] = value 351 | parts = append(parts, fmt.Sprintf("%v %v {:%v}", col, e.Like, name)) 352 | } 353 | 354 | if e.or { 355 | return strings.Join(parts, " OR ") 356 | } 357 | return strings.Join(parts, " AND ") 358 | } 359 | 360 | // ExistsExp represents an EXISTS or NOT EXISTS expression. 361 | type ExistsExp struct { 362 | exp Expression 363 | not bool 364 | } 365 | 366 | // Build converts an expression into a SQL fragment. 367 | func (e *ExistsExp) Build(db *DB, params Params) string { 368 | sql := e.exp.Build(db, params) 369 | if sql == "" { 370 | if e.not { 371 | return "" 372 | } 373 | return "0=1" 374 | } 375 | if e.not { 376 | return "NOT EXISTS (" + sql + ")" 377 | } 378 | return "EXISTS (" + sql + ")" 379 | } 380 | 381 | // BetweenExp represents a BETWEEN or a NOT BETWEEN expression. 
382 | type BetweenExp struct { 383 | col string 384 | from, to interface{} 385 | not bool 386 | } 387 | 388 | // Build converts an expression into a SQL fragment. 389 | func (e *BetweenExp) Build(db *DB, params Params) string { 390 | between := "BETWEEN" 391 | if e.not { 392 | between = "NOT BETWEEN" 393 | } 394 | name1 := fmt.Sprintf("p%v", len(params)) 395 | name2 := fmt.Sprintf("p%v", len(params)+1) 396 | params[name1] = e.from 397 | params[name2] = e.to 398 | col := db.QuoteColumnName(e.col) 399 | return fmt.Sprintf("%v %v {:%v} AND {:%v}", col, between, name1, name2) 400 | } 401 | 402 | // Enclose surrounds the provided nonempty expression with parenthesis "()". 403 | func Enclose(exp Expression) Expression { 404 | return &EncloseExp{exp} 405 | } 406 | 407 | // EncloseExp represents a parenthesis enclosed expression. 408 | type EncloseExp struct { 409 | exp Expression 410 | } 411 | 412 | // Build converts an expression into a SQL fragment. 413 | func (e *EncloseExp) Build(db *DB, params Params) string { 414 | str := e.exp.Build(db, params) 415 | 416 | if str == "" { 417 | return "" 418 | } 419 | 420 | return "(" + str + ")" 421 | } 422 | -------------------------------------------------------------------------------- /expression_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package dbx 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestExp(t *testing.T) { 14 | params := Params{"k2": "v2"} 15 | 16 | e1 := NewExp("s1").(*Exp) 17 | assert.Equal(t, e1.Build(nil, params), "s1", "e1.Build()") 18 | assert.Equal(t, len(params), 1, `len(params)@1`) 19 | 20 | e2 := NewExp("s2", Params{"k1": "v1"}).(*Exp) 21 | assert.Equal(t, e2.Build(nil, params), "s2", "e2.Build()") 22 | assert.Equal(t, len(params), 2, `len(params)@2`) 23 | } 24 | 25 | func TestHashExp(t *testing.T) { 26 | e1 := HashExp{} 27 | assert.Equal(t, e1.Build(nil, nil), "", `e1.Build()`) 28 | 29 | e2 := HashExp{ 30 | "k1": nil, 31 | "k2": NewExp("s1", Params{"ka": "va"}), 32 | "k3": 1.1, 33 | "k4": "abc", 34 | "k5": []interface{}{1, 2}, 35 | } 36 | db := getDB() 37 | params := Params{"k0": "v0"} 38 | expected := "`k1` IS NULL AND (s1) AND `k3`={:p2} AND `k4`={:p3} AND `k5` IN ({:p4}, {:p5})" 39 | 40 | assert.Equal(t, e2.Build(db, params), expected, `e2.Build()`) 41 | assert.Equal(t, len(params), 6, `len(params)`) 42 | assert.Equal(t, params["p5"].(int), 2, `params["p5"]`) 43 | } 44 | 45 | func TestNotExp(t *testing.T) { 46 | e1 := Not(NewExp("s1")) 47 | assert.Equal(t, e1.Build(nil, nil), "NOT (s1)", `e1.Build()`) 48 | 49 | e2 := Not(NewExp("")) 50 | assert.Equal(t, e2.Build(nil, nil), "", `e2.Build()`) 51 | } 52 | 53 | func TestAndOrExp(t *testing.T) { 54 | e1 := And(NewExp("s1", Params{"k1": "v1"}), NewExp(""), NewExp("s2", Params{"k2": "v2"})) 55 | params := Params{} 56 | assert.Equal(t, e1.Build(nil, params), "(s1) AND (s2)", `e1.Build()`) 57 | assert.Equal(t, len(params), 2, `len(params)`) 58 | 59 | e2 := Or(NewExp("s1"), NewExp("s2")) 60 | assert.Equal(t, e2.Build(nil, nil), "(s1) OR (s2)", `e2.Build()`) 61 | 62 | e3 := And() 63 | assert.Equal(t, e3.Build(nil, nil), "", `e3.Build()`) 64 | 65 | e4 := And(NewExp("s1")) 66 | assert.Equal(t, e4.Build(nil, nil), "s1", `e4.Build()`) 67 | 68 | e5 := And(NewExp("s1"), nil) 69 | 
assert.Equal(t, e5.Build(nil, nil), "s1", `e5.Build()`) 70 | } 71 | 72 | func TestInExp(t *testing.T) { 73 | db := getDB() 74 | 75 | e1 := In("age", 1, 2, 3) 76 | params := Params{} 77 | assert.Equal(t, e1.Build(db, params), "`age` IN ({:p0}, {:p1}, {:p2})", `e1.Build()`) 78 | assert.Equal(t, len(params), 3, `len(params)@1`) 79 | 80 | e2 := In("age", 1) 81 | params = Params{} 82 | assert.Equal(t, e2.Build(db, params), "`age`={:p0}", `e2.Build()`) 83 | assert.Equal(t, len(params), 1, `len(params)@2`) 84 | 85 | e3 := NotIn("age", 1, 2, 3) 86 | params = Params{} 87 | assert.Equal(t, e3.Build(db, params), "`age` NOT IN ({:p0}, {:p1}, {:p2})", `e3.Build()`) 88 | assert.Equal(t, len(params), 3, `len(params)@3`) 89 | 90 | e4 := NotIn("age", 1) 91 | params = Params{} 92 | assert.Equal(t, e4.Build(db, params), "`age`<>{:p0}", `e4.Build()`) 93 | assert.Equal(t, len(params), 1, `len(params)@4`) 94 | 95 | e5 := In("age") 96 | assert.Equal(t, e5.Build(db, nil), "0=1", `e5.Build()`) 97 | 98 | e6 := NotIn("age") 99 | assert.Equal(t, e6.Build(db, nil), "", `e6.Build()`) 100 | } 101 | 102 | func TestLikeExp(t *testing.T) { 103 | db := getDB() 104 | 105 | e1 := Like("name", "a", "b", "c") 106 | params := Params{} 107 | assert.Equal(t, e1.Build(db, params), "`name` LIKE {:p0} AND `name` LIKE {:p1} AND `name` LIKE {:p2}", `e1.Build()`) 108 | assert.Equal(t, len(params), 3, `len(params)@1`) 109 | 110 | e2 := Like("name", "a") 111 | params = Params{} 112 | assert.Equal(t, e2.Build(db, params), "`name` LIKE {:p0}", `e2.Build()`) 113 | assert.Equal(t, len(params), 1, `len(params)@2`) 114 | 115 | e3 := Like("name") 116 | assert.Equal(t, e3.Build(db, nil), "", `e3.Build()`) 117 | 118 | e4 := NotLike("name", "a", "b", "c") 119 | params = Params{} 120 | assert.Equal(t, e4.Build(db, params), "`name` NOT LIKE {:p0} AND `name` NOT LIKE {:p1} AND `name` NOT LIKE {:p2}", `e4.Build()`) 121 | assert.Equal(t, len(params), 3, `len(params)@4`) 122 | 123 | e5 := OrLike("name", "a", "b", "c") 124 | 
params = Params{} 125 | assert.Equal(t, e5.Build(db, params), "`name` LIKE {:p0} OR `name` LIKE {:p1} OR `name` LIKE {:p2}", `e5.Build()`) 126 | assert.Equal(t, len(params), 3, `len(params)@5`) 127 | 128 | e6 := OrNotLike("name", "a", "b", "c") 129 | params = Params{} 130 | assert.Equal(t, e6.Build(db, params), "`name` NOT LIKE {:p0} OR `name` NOT LIKE {:p1} OR `name` NOT LIKE {:p2}", `e6.Build()`) 131 | assert.Equal(t, len(params), 3, `len(params)@6`) 132 | 133 | e7 := Like("name", "a_\\%") 134 | params = Params{} 135 | e7.Build(db, params) 136 | assert.Equal(t, params["p0"], "%a\\_\\\\\\%%", `params["p0"]@1`) 137 | 138 | e8 := Like("name", "a").Match(false, true) 139 | params = Params{} 140 | e8.Build(db, params) 141 | assert.Equal(t, params["p0"], "a%", `params["p0"]@2`) 142 | 143 | e9 := Like("name", "a").Match(true, false) 144 | params = Params{} 145 | e9.Build(db, params) 146 | assert.Equal(t, params["p0"], "%a", `params["p0"]@3`) 147 | 148 | e10 := Like("name", "a").Match(false, false) 149 | params = Params{} 150 | e10.Build(db, params) 151 | assert.Equal(t, params["p0"], "a", `params["p0"]@4`) 152 | 153 | e11 := Like("name", "%a").Match(false, false).Escape() 154 | params = Params{} 155 | e11.Build(db, params) 156 | assert.Equal(t, params["p0"], "%a", `params["p0"]@5`) 157 | } 158 | 159 | func TestBetweenExp(t *testing.T) { 160 | db := getDB() 161 | 162 | e1 := Between("age", 30, 40) 163 | params := Params{} 164 | assert.Equal(t, e1.Build(db, params), "`age` BETWEEN {:p0} AND {:p1}", `e1.Build()`) 165 | assert.Equal(t, len(params), 2, `len(params)@1`) 166 | 167 | e2 := NotBetween("age", 30, 40) 168 | params = Params{} 169 | assert.Equal(t, e2.Build(db, params), "`age` NOT BETWEEN {:p0} AND {:p1}", `e2.Build()`) 170 | assert.Equal(t, len(params), 2, `len(params)@2`) 171 | } 172 | 173 | func TestExistsExp(t *testing.T) { 174 | e1 := Exists(NewExp("s1")) 175 | assert.Equal(t, e1.Build(nil, nil), "EXISTS (s1)", `e1.Build()`) 176 | 177 | e2 := 
NotExists(NewExp("s1")) 178 | assert.Equal(t, e2.Build(nil, nil), "NOT EXISTS (s1)", `e2.Build()`) 179 | 180 | e3 := Exists(NewExp("")) 181 | assert.Equal(t, e3.Build(nil, nil), "0=1", `e3.Build()`) 182 | 183 | e4 := NotExists(NewExp("")) 184 | assert.Equal(t, e4.Build(nil, nil), "", `e4.Build()`) 185 | } 186 | 187 | func TestEncloseExp(t *testing.T) { 188 | e1 := Enclose(NewExp("")) 189 | assert.Equal(t, e1.Build(nil, nil), "", `e1.Build()`) 190 | 191 | e2 := Enclose(NewExp("s1")) 192 | assert.Equal(t, e2.Build(nil, nil), "(s1)", `e2.Build()`) 193 | 194 | e3 := Enclose(NewExp("(s1)")) 195 | assert.Equal(t, e3.Build(nil, nil), "((s1))", `e3.Build()`) 196 | } 197 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/pocketbase/dbx 2 | 3 | go 1.13 4 | 5 | require ( 6 | github.com/go-sql-driver/mysql v1.4.1 7 | github.com/stretchr/testify v1.4.0 8 | google.golang.org/appengine v1.6.5 // indirect 9 | ) 10 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 2 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= 4 | github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= 5 | github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 6 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 7 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 8 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 9 | github.com/stretchr/testify 
v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= 10 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= 11 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 12 | golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= 13 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 14 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 15 | golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= 16 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 17 | google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM= 18 | google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= 19 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 20 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 21 | gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= 22 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 23 | -------------------------------------------------------------------------------- /model_query.go: -------------------------------------------------------------------------------- 1 | package dbx 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "reflect" 8 | ) 9 | 10 | type ( 11 | // TableModel is the interface that should be implemented by models which have unconventional table names. 12 | TableModel interface { 13 | TableName() string 14 | } 15 | 16 | // ModelQuery represents a query associated with a struct model. 
17 | ModelQuery struct { 18 | db *DB 19 | ctx context.Context 20 | builder Builder 21 | model *structValue 22 | exclude []string 23 | lastError error 24 | } 25 | ) 26 | 27 | var ( 28 | MissingPKError = errors.New("missing primary key declaration") 29 | CompositePKError = errors.New("composite primary key is not supported") 30 | ) 31 | 32 | func NewModelQuery(model interface{}, fieldMapFunc FieldMapFunc, db *DB, builder Builder) *ModelQuery { 33 | q := &ModelQuery{ 34 | db: db, 35 | ctx: db.ctx, 36 | builder: builder, 37 | model: newStructValue(model, fieldMapFunc, db.TableMapper), 38 | } 39 | if q.model == nil { 40 | q.lastError = VarTypeError("must be a pointer to a struct representing the model") 41 | } 42 | return q 43 | } 44 | 45 | // Context returns the context associated with the query. 46 | func (q *ModelQuery) Context() context.Context { 47 | return q.ctx 48 | } 49 | 50 | // WithContext associates a context with the query. 51 | func (q *ModelQuery) WithContext(ctx context.Context) *ModelQuery { 52 | q.ctx = ctx 53 | return q 54 | } 55 | 56 | // Exclude excludes the specified struct fields from being inserted/updated into the DB table. 57 | func (q *ModelQuery) Exclude(attrs ...string) *ModelQuery { 58 | q.exclude = attrs 59 | return q 60 | } 61 | 62 | // Insert inserts a row in the table using the struct model associated with this query. 63 | // 64 | // By default, it inserts *all* public fields into the table, including those nil or empty ones. 65 | // You may pass a list of the fields to this method to indicate that only those fields should be inserted. 66 | // You may also call Exclude to exclude some fields from being inserted. 67 | // 68 | // If a model has an empty primary key, it is considered auto-incremental and the corresponding struct 69 | // field will be filled with the generated primary key value after a successful insertion. 
70 | func (q *ModelQuery) Insert(attrs ...string) error { 71 | if q.lastError != nil { 72 | return q.lastError 73 | } 74 | cols := q.model.columns(attrs, q.exclude) 75 | pkName := "" 76 | for name, value := range q.model.pk() { 77 | if isAutoInc(value) { 78 | delete(cols, name) 79 | pkName = name 80 | break 81 | } 82 | } 83 | 84 | if pkName == "" { 85 | _, err := q.builder.Insert(q.model.tableName, Params(cols)).WithContext(q.ctx).Execute() 86 | return err 87 | } 88 | 89 | // handle auto-incremental PK 90 | query := q.builder.Insert(q.model.tableName, Params(cols)).WithContext(q.ctx) 91 | pkValue, err := insertAndReturnPK(q.db, query, pkName) 92 | if err != nil { 93 | return err 94 | } 95 | 96 | pkField := indirect(q.model.dbNameMap[pkName].getField(q.model.value)) 97 | switch pkField.Kind() { 98 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 99 | pkField.SetUint(uint64(pkValue)) 100 | default: 101 | pkField.SetInt(pkValue) 102 | } 103 | 104 | return nil 105 | } 106 | 107 | func insertAndReturnPK(db *DB, query *Query, pkName string) (int64, error) { 108 | if db.DriverName() != "postgres" && db.DriverName() != "pgx" { 109 | result, err := query.Execute() 110 | if err != nil { 111 | return 0, err 112 | } 113 | return result.LastInsertId() 114 | } 115 | 116 | // specially handle postgres (lib/pq) as it doesn't support LastInsertId 117 | returning := fmt.Sprintf(" RETURNING %s", db.QuoteColumnName(pkName)) 118 | query.sql += returning 119 | query.rawSQL += returning 120 | var pkValue int64 121 | err := query.Row(&pkValue) 122 | return pkValue, err 123 | } 124 | 125 | func isAutoInc(value interface{}) bool { 126 | v := reflect.ValueOf(value) 127 | switch v.Kind() { 128 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 129 | return v.Int() == 0 130 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: 131 | return v.Uint() == 0 132 | case reflect.Ptr: 133 | 
return v.IsNil() || isAutoInc(v.Elem()) 134 | case reflect.Invalid: 135 | return true 136 | } 137 | return false 138 | } 139 | 140 | // Update updates a row in the table using the struct model associated with this query. 141 | // The row being updated has the same primary key as specified by the model. 142 | // 143 | // By default, it updates *all* public fields in the table, including those nil or empty ones. 144 | // You may pass a list of the fields to this method to indicate that only those fields should be updated. 145 | // You may also call Exclude to exclude some fields from being updated. 146 | func (q *ModelQuery) Update(attrs ...string) error { 147 | if q.lastError != nil { 148 | return q.lastError 149 | } 150 | pk := q.model.pk() 151 | if len(pk) == 0 { 152 | return MissingPKError 153 | } 154 | 155 | cols := q.model.columns(attrs, q.exclude) 156 | for name := range pk { 157 | delete(cols, name) 158 | } 159 | _, err := q.builder.Update(q.model.tableName, Params(cols), HashExp(pk)).WithContext(q.ctx).Execute() 160 | return err 161 | } 162 | 163 | // Delete deletes a row in the table using the primary key specified by the struct model associated with this query. 
164 | func (q *ModelQuery) Delete() error { 165 | if q.lastError != nil { 166 | return q.lastError 167 | } 168 | pk := q.model.pk() 169 | if len(pk) == 0 { 170 | return MissingPKError 171 | } 172 | _, err := q.builder.Delete(q.model.tableName, HashExp(pk)).WithContext(q.ctx).Execute() 173 | return err 174 | } 175 | -------------------------------------------------------------------------------- /model_query_test.go: -------------------------------------------------------------------------------- 1 | package dbx 2 | 3 | import ( 4 | "database/sql" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | type Item struct { 11 | ID2 int 12 | Name string 13 | } 14 | 15 | func TestModelQuery_Insert(t *testing.T) { 16 | db := getPreparedDB() 17 | defer db.Close() 18 | 19 | name := "test" 20 | email := "test@example.com" 21 | 22 | { 23 | // inserting normally 24 | customer := Customer{ 25 | Name: name, 26 | Email: email, 27 | } 28 | err := db.Model(&customer).Insert() 29 | if assert.Nil(t, err) { 30 | assert.Equal(t, 4, customer.ID) 31 | var c Customer 32 | db.Select().From("customer").Where(HashExp{"ID": 4}).One(&c) 33 | assert.Equal(t, name, c.Name) 34 | assert.Equal(t, email, c.Email) 35 | assert.Equal(t, 0, c.Status) 36 | assert.False(t, c.Address.Valid) 37 | } 38 | } 39 | 40 | { 41 | // inserting with pointer-typed fields 42 | customer := CustomerPtr{ 43 | Name: name, 44 | Email: &email, 45 | } 46 | err := db.Model(&customer).Insert() 47 | if assert.Nil(t, err) && assert.NotNil(t, customer.ID) { 48 | assert.Equal(t, 5, *customer.ID) 49 | var c CustomerPtr 50 | db.Select().From("customer").Where(HashExp{"ID": 4}).One(&c) 51 | assert.Equal(t, name, c.Name) 52 | if assert.NotNil(t, c.Email) { 53 | assert.Equal(t, email, *c.Email) 54 | } 55 | if assert.NotNil(t, c.Status) { 56 | assert.Equal(t, 0, *c.Status) 57 | } 58 | assert.Nil(t, c.Address) 59 | } 60 | } 61 | 62 | { 63 | // inserting with null-typed fields 64 | customer := CustomerNull{ 65 | Name: 
name, 66 | Email: sql.NullString{email, true}, 67 | } 68 | err := db.Model(&customer).Insert() 69 | if assert.Nil(t, err) { 70 | // potential todo: need to check if the field implements sql.Scanner 71 | // assert.Equal(t, int64(6), customer.ID.Int64) 72 | var c CustomerNull 73 | db.Select().From("customer").Where(HashExp{"ID": 4}).One(&c) 74 | assert.Equal(t, name, c.Name) 75 | assert.Equal(t, email, c.Email.String) 76 | if assert.NotNil(t, c.Status) { 77 | assert.Equal(t, int64(0), c.Status.Int64) 78 | } 79 | assert.False(t, c.Address.Valid) 80 | } 81 | } 82 | 83 | { 84 | // inserting with embedded structures 85 | customer := CustomerEmbedded{ 86 | Id: 100, 87 | Email: &email, 88 | InnerCustomer: InnerCustomer{ 89 | Name: &name, 90 | Status: sql.NullInt64{1, true}, 91 | }, 92 | } 93 | err := db.Model(&customer).Insert() 94 | if assert.Nil(t, err) { 95 | assert.Equal(t, 100, customer.Id) 96 | var c CustomerEmbedded 97 | db.Select().From("customer").Where(HashExp{"ID": 100}).One(&c) 98 | assert.Equal(t, name, *c.Name) 99 | assert.Equal(t, email, *c.Email) 100 | if assert.NotNil(t, c.Status) { 101 | assert.Equal(t, int64(1), c.Status.Int64) 102 | } 103 | assert.False(t, c.Address.Valid) 104 | } 105 | } 106 | 107 | { 108 | // inserting with include/exclude fields 109 | customer := Customer{ 110 | Name: name, 111 | Email: email, 112 | Status: 1, 113 | } 114 | err := db.Model(&customer).Exclude("Name").Insert("Name", "Email") 115 | if assert.Nil(t, err) { 116 | assert.Equal(t, 101, customer.ID) 117 | var c Customer 118 | db.Select().From("customer").Where(HashExp{"ID": 101}).One(&c) 119 | assert.Equal(t, "", c.Name) 120 | assert.Equal(t, email, c.Email) 121 | assert.Equal(t, 0, c.Status) 122 | assert.False(t, c.Address.Valid) 123 | } 124 | } 125 | 126 | var a int 127 | assert.NotNil(t, db.Model(&a).Insert()) 128 | } 129 | 130 | func TestModelQuery_Update(t *testing.T) { 131 | db := getPreparedDB() 132 | defer db.Close() 133 | 134 | id := 2 135 | name := "test" 136 | 
email := "test@example.com" 137 | { 138 | // updating normally 139 | customer := Customer{ 140 | ID: id, 141 | Name: name, 142 | Email: email, 143 | } 144 | err := db.Model(&customer).Update() 145 | if assert.Nil(t, err) { 146 | var c Customer 147 | db.Select().From("customer").Where(HashExp{"ID": id}).One(&c) 148 | assert.Equal(t, name, c.Name) 149 | assert.Equal(t, email, c.Email) 150 | assert.Equal(t, 0, c.Status) 151 | } 152 | } 153 | 154 | { 155 | // updating without primary keys 156 | item2 := Item{ 157 | Name: name, 158 | } 159 | err := db.Model(&item2).Update() 160 | assert.Equal(t, MissingPKError, err) 161 | } 162 | 163 | { 164 | // updating all fields 165 | customer := CustomerPtr{ 166 | ID: &id, 167 | Name: name, 168 | Email: &email, 169 | } 170 | err := db.Model(&customer).Update() 171 | if assert.Nil(t, err) { 172 | assert.Equal(t, id, *customer.ID) 173 | var c CustomerPtr 174 | db.Select().From("customer").Where(HashExp{"ID": id}).One(&c) 175 | assert.Equal(t, name, c.Name) 176 | if assert.NotNil(t, c.Email) { 177 | assert.Equal(t, email, *c.Email) 178 | } 179 | assert.Nil(t, c.Status) 180 | } 181 | } 182 | 183 | { 184 | // updating selected fields only 185 | id = 3 186 | customer := CustomerPtr{ 187 | ID: &id, 188 | Name: name, 189 | Email: &email, 190 | } 191 | err := db.Model(&customer).Update("Name", "Email") 192 | if assert.Nil(t, err) { 193 | assert.Equal(t, id, *customer.ID) 194 | var c CustomerPtr 195 | db.Select().From("customer").Where(HashExp{"ID": id}).One(&c) 196 | assert.Equal(t, name, c.Name) 197 | if assert.NotNil(t, c.Email) { 198 | assert.Equal(t, email, *c.Email) 199 | } 200 | if assert.NotNil(t, c.Status) { 201 | assert.Equal(t, 2, *c.Status) 202 | } 203 | } 204 | } 205 | 206 | { 207 | // updating non-struct 208 | var a int 209 | assert.NotNil(t, db.Model(&a).Update()) 210 | } 211 | } 212 | 213 | func TestModelQuery_Delete(t *testing.T) { 214 | db := getPreparedDB() 215 | defer db.Close() 216 | 217 | customer := Customer{ 218 | ID: 
2, 219 | } 220 | err := db.Model(&customer).Delete() 221 | if assert.Nil(t, err) { 222 | var m Customer 223 | err := db.Select().From("customer").Where(HashExp{"ID": 2}).One(&m) 224 | assert.NotNil(t, err) 225 | } 226 | 227 | { 228 | // deleting without primary keys 229 | item2 := Item{ 230 | Name: "", 231 | } 232 | err := db.Model(&item2).Delete() 233 | assert.Equal(t, MissingPKError, err) 234 | } 235 | 236 | var a int 237 | assert.NotNil(t, db.Model(&a).Delete()) 238 | } 239 | -------------------------------------------------------------------------------- /query.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package dbx 6 | 7 | import ( 8 | "context" 9 | "database/sql" 10 | "database/sql/driver" 11 | "encoding/hex" 12 | "errors" 13 | "fmt" 14 | "strings" 15 | "time" 16 | ) 17 | 18 | // ExecHookFunc executes before op allowing custom handling like auto fail/retry. 19 | type ExecHookFunc func(q *Query, op func() error) error 20 | 21 | // OneHookFunc executes right before the query populate the row result from One() call (aka. op). 22 | type OneHookFunc func(q *Query, a interface{}, op func(b interface{}) error) error 23 | 24 | // AllHookFunc executes right before the query populate the row result from All() call (aka. op). 25 | type AllHookFunc func(q *Query, sliceA interface{}, op func(sliceB interface{}) error) error 26 | 27 | // Params represents a list of parameter values to be bound to a SQL statement. 28 | // The map keys are the parameter names while the map values are the corresponding parameter values. 29 | type Params map[string]interface{} 30 | 31 | // Executor prepares, executes, or queries a SQL statement. 
32 | type Executor interface { 33 | // Exec executes a SQL statement 34 | Exec(query string, args ...interface{}) (sql.Result, error) 35 | // ExecContext executes a SQL statement with the given context 36 | ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) 37 | // Query queries a SQL statement 38 | Query(query string, args ...interface{}) (*sql.Rows, error) 39 | // QueryContext queries a SQL statement with the given context 40 | QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) 41 | // Prepare creates a prepared statement 42 | Prepare(query string) (*sql.Stmt, error) 43 | } 44 | 45 | // Query represents a SQL statement to be executed. 46 | type Query struct { 47 | executor Executor 48 | 49 | sql, rawSQL string 50 | placeholders []string 51 | params Params 52 | 53 | stmt *sql.Stmt 54 | ctx context.Context 55 | 56 | // hooks 57 | execHook ExecHookFunc 58 | oneHook OneHookFunc 59 | allHook AllHookFunc 60 | 61 | // FieldMapper maps struct field names to DB column names. 62 | FieldMapper FieldMapFunc 63 | // LastError contains the last error (if any) of the query. 64 | // LastError is cleared by Execute(), Row(), Rows(), One(), and All(). 65 | LastError error 66 | // LogFunc is used to log the SQL statement being executed. 67 | LogFunc LogFunc 68 | // PerfFunc is used to log the SQL execution time. It is ignored if nil. 69 | // Deprecated: Please use QueryLogFunc and ExecLogFunc instead. 70 | PerfFunc PerfFunc 71 | // QueryLogFunc is called each time when performing a SQL query that returns data. 72 | QueryLogFunc QueryLogFunc 73 | // ExecLogFunc is called each time when a SQL statement is executed. 74 | ExecLogFunc ExecLogFunc 75 | } 76 | 77 | // NewQuery creates a new Query with the given SQL statement. 
78 | func NewQuery(db *DB, executor Executor, sql string) *Query { 79 | rawSQL, placeholders := db.processSQL(sql) 80 | return &Query{ 81 | executor: executor, 82 | sql: sql, 83 | rawSQL: rawSQL, 84 | placeholders: placeholders, 85 | params: Params{}, 86 | ctx: db.ctx, 87 | FieldMapper: db.FieldMapper, 88 | LogFunc: db.LogFunc, 89 | PerfFunc: db.PerfFunc, 90 | QueryLogFunc: db.QueryLogFunc, 91 | ExecLogFunc: db.ExecLogFunc, 92 | } 93 | } 94 | 95 | // SQL returns the original SQL used to create the query. 96 | // The actual SQL (RawSQL) being executed is obtained by replacing the named 97 | // parameter placeholders with anonymous ones. 98 | func (q *Query) SQL() string { 99 | return q.sql 100 | } 101 | 102 | // Context returns the context associated with the query. 103 | func (q *Query) Context() context.Context { 104 | return q.ctx 105 | } 106 | 107 | // WithContext associates a context with the query. 108 | func (q *Query) WithContext(ctx context.Context) *Query { 109 | q.ctx = ctx 110 | return q 111 | } 112 | 113 | // WithExecHook associates the provided exec hook function with the query. 114 | // 115 | // It is called for every Query resolver (Execute(), One(), All(), Row(), Column()), 116 | // allowing you to implement auto fail/retry or any other additional handling. 117 | func (q *Query) WithExecHook(fn ExecHookFunc) *Query { 118 | q.execHook = fn 119 | return q 120 | } 121 | 122 | // WithOneHook associates the provided hook function with the query, 123 | // called on q.One(), allowing you to implement custom struct scan based 124 | // on the One() argument and/or result. 125 | func (q *Query) WithOneHook(fn OneHookFunc) *Query { 126 | q.oneHook = fn 127 | return q 128 | } 129 | 130 | // WithOneHook associates the provided hook function with the query, 131 | // called on q.All(), allowing you to implement custom slice scan based 132 | // on the All() argument and/or result. 
133 | func (q *Query) WithAllHook(fn AllHookFunc) *Query { 134 | q.allHook = fn 135 | return q 136 | } 137 | 138 | // logSQL returns the SQL statement with parameters being replaced with the actual values. 139 | // The result is only for logging purpose and should not be used to execute. 140 | func (q *Query) logSQL() string { 141 | s := q.sql 142 | for k, v := range q.params { 143 | if valuer, ok := v.(driver.Valuer); ok && valuer != nil { 144 | v, _ = valuer.Value() 145 | } 146 | var sv string 147 | if str, ok := v.(string); ok { 148 | sv = "'" + strings.Replace(str, "'", "''", -1) + "'" 149 | } else if bs, ok := v.([]byte); ok { 150 | sv = "0x" + hex.EncodeToString(bs) 151 | } else { 152 | sv = fmt.Sprintf("%v", v) 153 | } 154 | s = strings.Replace(s, "{:"+k+"}", sv, -1) 155 | } 156 | return s 157 | } 158 | 159 | // Params returns the parameters to be bound to the SQL statement represented by this query. 160 | func (q *Query) Params() Params { 161 | return q.params 162 | } 163 | 164 | // Prepare creates a prepared statement for later queries or executions. 165 | // Close() should be called after finishing all queries. 166 | func (q *Query) Prepare() *Query { 167 | stmt, err := q.executor.Prepare(q.rawSQL) 168 | if err != nil { 169 | q.LastError = err 170 | return q 171 | } 172 | q.stmt = stmt 173 | return q 174 | } 175 | 176 | // Close closes the underlying prepared statement. 177 | // Close does nothing if the query has not been prepared before. 178 | func (q *Query) Close() error { 179 | if q.stmt == nil { 180 | return nil 181 | } 182 | 183 | err := q.stmt.Close() 184 | q.stmt = nil 185 | return err 186 | } 187 | 188 | // Bind sets the parameters that should be bound to the SQL statement. 189 | // The parameter placeholders in the SQL statement are in the format of "{:ParamName}". 
190 | func (q *Query) Bind(params Params) *Query { 191 | if len(q.params) == 0 { 192 | q.params = params 193 | } else { 194 | for k, v := range params { 195 | q.params[k] = v 196 | } 197 | } 198 | return q 199 | } 200 | 201 | // Execute executes the SQL statement without retrieving data. 202 | func (q *Query) Execute() (sql.Result, error) { 203 | var result sql.Result 204 | 205 | execErr := q.execWrap(func() error { 206 | var err error 207 | result, err = q.execute() 208 | return err 209 | }) 210 | 211 | return result, execErr 212 | } 213 | 214 | func (q *Query) execute() (result sql.Result, err error) { 215 | err = q.LastError 216 | q.LastError = nil 217 | if err != nil { 218 | return 219 | } 220 | 221 | var params []interface{} 222 | params, err = replacePlaceholders(q.placeholders, q.params) 223 | if err != nil { 224 | return 225 | } 226 | 227 | start := time.Now() 228 | 229 | if q.ctx == nil { 230 | if q.stmt == nil { 231 | result, err = q.executor.Exec(q.rawSQL, params...) 232 | } else { 233 | result, err = q.stmt.Exec(params...) 234 | } 235 | } else { 236 | if q.stmt == nil { 237 | result, err = q.executor.ExecContext(q.ctx, q.rawSQL, params...) 238 | } else { 239 | result, err = q.stmt.ExecContext(q.ctx, params...) 240 | } 241 | } 242 | 243 | if q.ExecLogFunc != nil { 244 | q.ExecLogFunc(q.ctx, time.Now().Sub(start), q.logSQL(), result, err) 245 | } 246 | if q.LogFunc != nil { 247 | q.LogFunc("[%.2fms] Execute SQL: %v", float64(time.Now().Sub(start).Milliseconds()), q.logSQL()) 248 | } 249 | if q.PerfFunc != nil { 250 | q.PerfFunc(time.Now().Sub(start).Nanoseconds(), q.logSQL(), true) 251 | } 252 | return 253 | } 254 | 255 | // One executes the SQL statement and populates the first row of the result into a struct or NullStringMap. 256 | // Refer to Rows.ScanStruct() and Rows.ScanMap() for more details on how to specify 257 | // the variable to be populated. 258 | // Note that when the query has no rows in the result set, an sql.ErrNoRows will be returned. 
259 | func (q *Query) One(a interface{}) error { 260 | return q.execWrap(func() error { 261 | rows, err := q.Rows() 262 | if err != nil { 263 | return err 264 | } 265 | 266 | if q.oneHook != nil { 267 | return q.oneHook(q, a, rows.one) 268 | } 269 | 270 | return rows.one(a) 271 | }) 272 | } 273 | 274 | // All executes the SQL statement and populates all the resulting rows into a slice of struct or NullStringMap. 275 | // The slice must be given as a pointer. Each slice element must be either a struct or a NullStringMap. 276 | // Refer to Rows.ScanStruct() and Rows.ScanMap() for more details on how each slice element can be. 277 | // If the query returns no row, the slice will be an empty slice (not nil). 278 | func (q *Query) All(slice interface{}) error { 279 | return q.execWrap(func() error { 280 | rows, err := q.Rows() 281 | if err != nil { 282 | return err 283 | } 284 | 285 | if q.allHook != nil { 286 | return q.allHook(q, slice, rows.all) 287 | } 288 | 289 | return rows.all(slice) 290 | }) 291 | } 292 | 293 | // Row executes the SQL statement and populates the first row of the result into a list of variables. 294 | // Note that the number of the variables should match to that of the columns in the query result. 295 | // Note that when the query has no rows in the result set, an sql.ErrNoRows will be returned. 296 | func (q *Query) Row(a ...interface{}) error { 297 | return q.execWrap(func() error { 298 | rows, err := q.Rows() 299 | if err != nil { 300 | return err 301 | } 302 | return rows.row(a...) 303 | }) 304 | } 305 | 306 | // Column executes the SQL statement and populates the first column of the result into a slice. 307 | // Note that the parameter must be a pointer to a slice. 
308 | func (q *Query) Column(a interface{}) error { 309 | return q.execWrap(func() error { 310 | rows, err := q.Rows() 311 | if err != nil { 312 | return err 313 | } 314 | return rows.column(a) 315 | }) 316 | } 317 | 318 | // Rows executes the SQL statement and returns a Rows object to allow retrieving data row by row. 319 | func (q *Query) Rows() (rows *Rows, err error) { 320 | err = q.LastError 321 | q.LastError = nil 322 | if err != nil { 323 | return 324 | } 325 | 326 | var params []interface{} 327 | params, err = replacePlaceholders(q.placeholders, q.params) 328 | if err != nil { 329 | return 330 | } 331 | 332 | start := time.Now() 333 | 334 | var rr *sql.Rows 335 | if q.ctx == nil { 336 | if q.stmt == nil { 337 | rr, err = q.executor.Query(q.rawSQL, params...) 338 | } else { 339 | rr, err = q.stmt.Query(params...) 340 | } 341 | } else { 342 | if q.stmt == nil { 343 | rr, err = q.executor.QueryContext(q.ctx, q.rawSQL, params...) 344 | } else { 345 | rr, err = q.stmt.QueryContext(q.ctx, params...) 346 | } 347 | } 348 | rows = &Rows{rr, q.FieldMapper} 349 | 350 | if q.QueryLogFunc != nil { 351 | q.QueryLogFunc(q.ctx, time.Now().Sub(start), q.logSQL(), rr, err) 352 | } 353 | if q.LogFunc != nil { 354 | q.LogFunc("[%.2fms] Query SQL: %v", float64(time.Now().Sub(start).Milliseconds()), q.logSQL()) 355 | } 356 | if q.PerfFunc != nil { 357 | q.PerfFunc(time.Now().Sub(start).Nanoseconds(), q.logSQL(), false) 358 | } 359 | return 360 | } 361 | 362 | func (q *Query) execWrap(op func() error) error { 363 | if q.execHook != nil { 364 | return q.execHook(q, op) 365 | } 366 | return op() 367 | } 368 | 369 | // replacePlaceholders converts a list of named parameters into a list of anonymous parameters. 
370 | func replacePlaceholders(placeholders []string, params Params) ([]interface{}, error) { 371 | if len(placeholders) == 0 { 372 | return nil, nil 373 | } 374 | 375 | var result []interface{} 376 | for _, name := range placeholders { 377 | if value, ok := params[name]; ok { 378 | result = append(result, value) 379 | } else { 380 | return nil, errors.New("Named parameter not found: " + name) 381 | } 382 | } 383 | return result, nil 384 | } 385 | -------------------------------------------------------------------------------- /query_builder.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package dbx 6 | 7 | import ( 8 | "bytes" 9 | "fmt" 10 | "regexp" 11 | "strings" 12 | ) 13 | 14 | // QueryBuilder builds different clauses for a SELECT SQL statement. 15 | type QueryBuilder interface { 16 | // BuildSelect generates a SELECT clause from the given selected column names. 17 | BuildSelect(cols []string, distinct bool, option string) string 18 | // BuildFrom generates a FROM clause from the given tables. 19 | BuildFrom(tables []string) string 20 | // BuildGroupBy generates a GROUP BY clause from the given group-by columns. 21 | BuildGroupBy(cols []string) string 22 | // BuildJoin generates a JOIN clause from the given join information. 23 | BuildJoin([]JoinInfo, Params) string 24 | // BuildWhere generates a WHERE clause from the given expression. 25 | BuildWhere(Expression, Params) string 26 | // BuildHaving generates a HAVING clause from the given expression. 27 | BuildHaving(Expression, Params) string 28 | // BuildOrderByAndLimit generates the ORDER BY and LIMIT clauses. 29 | BuildOrderByAndLimit(string, []string, int64, int64) string 30 | // BuildUnion generates a UNION clause from the given union information. 
31 | BuildUnion([]UnionInfo, Params) string 32 | } 33 | 34 | // BaseQueryBuilder provides a basic implementation of QueryBuilder. 35 | type BaseQueryBuilder struct { 36 | db *DB 37 | } 38 | 39 | var _ QueryBuilder = &BaseQueryBuilder{} 40 | 41 | // NewBaseQueryBuilder creates a new BaseQueryBuilder instance. 42 | func NewBaseQueryBuilder(db *DB) *BaseQueryBuilder { 43 | return &BaseQueryBuilder{db} 44 | } 45 | 46 | // DB returns the DB instance associated with the query builder. 47 | func (q *BaseQueryBuilder) DB() *DB { 48 | return q.db 49 | } 50 | 51 | // the regexp for columns and tables. 52 | var selectRegex = regexp.MustCompile(`(?i:\s+as\s+|\s+)([\w\-_\.]+)$`) 53 | 54 | // BuildSelect generates a SELECT clause from the given selected column names. 55 | func (q *BaseQueryBuilder) BuildSelect(cols []string, distinct bool, option string) string { 56 | var s bytes.Buffer 57 | s.WriteString("SELECT ") 58 | if distinct { 59 | s.WriteString("DISTINCT ") 60 | } 61 | if option != "" { 62 | s.WriteString(option) 63 | s.WriteString(" ") 64 | } 65 | if len(cols) == 0 { 66 | s.WriteString("*") 67 | return s.String() 68 | } 69 | 70 | for i, col := range cols { 71 | if i > 0 { 72 | s.WriteString(", ") 73 | } 74 | matches := selectRegex.FindStringSubmatch(col) 75 | if len(matches) == 0 { 76 | s.WriteString(q.db.QuoteColumnName(col)) 77 | } else { 78 | col := col[:len(col)-len(matches[0])] 79 | alias := matches[1] 80 | s.WriteString(q.db.QuoteColumnName(col) + " AS " + q.db.QuoteSimpleColumnName(alias)) 81 | } 82 | } 83 | 84 | return s.String() 85 | } 86 | 87 | // BuildFrom generates a FROM clause from the given tables. 
88 | func (q *BaseQueryBuilder) BuildFrom(tables []string) string { 89 | if len(tables) == 0 { 90 | return "" 91 | } 92 | s := "" 93 | for _, table := range tables { 94 | table = q.quoteTableNameAndAlias(table) 95 | if s == "" { 96 | s = table 97 | } else { 98 | s += ", " + table 99 | } 100 | } 101 | return "FROM " + s 102 | } 103 | 104 | // BuildJoin generates a JOIN clause from the given join information. 105 | func (q *BaseQueryBuilder) BuildJoin(joins []JoinInfo, params Params) string { 106 | if len(joins) == 0 { 107 | return "" 108 | } 109 | parts := []string{} 110 | for _, join := range joins { 111 | sql := join.Join + " " + q.quoteTableNameAndAlias(join.Table) 112 | on := "" 113 | if join.On != nil { 114 | on = join.On.Build(q.db, params) 115 | } 116 | if on != "" { 117 | sql += " ON " + on 118 | } 119 | parts = append(parts, sql) 120 | } 121 | return strings.Join(parts, " ") 122 | } 123 | 124 | // BuildWhere generates a WHERE clause from the given expression. 125 | func (q *BaseQueryBuilder) BuildWhere(e Expression, params Params) string { 126 | if e != nil { 127 | if c := e.Build(q.db, params); c != "" { 128 | return "WHERE " + c 129 | } 130 | } 131 | return "" 132 | } 133 | 134 | // BuildHaving generates a HAVING clause from the given expression. 135 | func (q *BaseQueryBuilder) BuildHaving(e Expression, params Params) string { 136 | if e != nil { 137 | if c := e.Build(q.db, params); c != "" { 138 | return "HAVING " + c 139 | } 140 | } 141 | return "" 142 | } 143 | 144 | // BuildGroupBy generates a GROUP BY clause from the given group-by columns. 145 | func (q *BaseQueryBuilder) BuildGroupBy(cols []string) string { 146 | if len(cols) == 0 { 147 | return "" 148 | } 149 | s := "" 150 | for i, col := range cols { 151 | if i == 0 { 152 | s = q.db.QuoteColumnName(col) 153 | } else { 154 | s += ", " + q.db.QuoteColumnName(col) 155 | } 156 | } 157 | return "GROUP BY " + s 158 | } 159 | 160 | // BuildOrderByAndLimit generates the ORDER BY and LIMIT clauses. 
161 | func (q *BaseQueryBuilder) BuildOrderByAndLimit(sql string, cols []string, limit int64, offset int64) string { 162 | if orderBy := q.BuildOrderBy(cols); orderBy != "" { 163 | sql += " " + orderBy 164 | } 165 | if limit := q.BuildLimit(limit, offset); limit != "" { 166 | return sql + " " + limit 167 | } 168 | return sql 169 | } 170 | 171 | // BuildUnion generates a UNION clause from the given union information. 172 | func (q *BaseQueryBuilder) BuildUnion(unions []UnionInfo, params Params) string { 173 | if len(unions) == 0 { 174 | return "" 175 | } 176 | sql := "" 177 | for i, union := range unions { 178 | if i > 0 { 179 | sql += " " 180 | } 181 | for k, v := range union.Query.params { 182 | params[k] = v 183 | } 184 | u := "UNION" 185 | if union.All { 186 | u = "UNION ALL" 187 | } 188 | sql += fmt.Sprintf("%v (%v)", u, union.Query.sql) 189 | } 190 | return sql 191 | } 192 | 193 | var orderRegex = regexp.MustCompile(`\s+((?i)ASC|DESC)$`) 194 | 195 | // BuildOrderBy generates the ORDER BY clause. 196 | func (q *BaseQueryBuilder) BuildOrderBy(cols []string) string { 197 | if len(cols) == 0 { 198 | return "" 199 | } 200 | s := "" 201 | for i, col := range cols { 202 | if i > 0 { 203 | s += ", " 204 | } 205 | matches := orderRegex.FindStringSubmatch(col) 206 | if len(matches) == 0 { 207 | s += q.db.QuoteColumnName(col) 208 | } else { 209 | col := col[:len(col)-len(matches[0])] 210 | dir := matches[1] 211 | s += q.db.QuoteColumnName(col) + " " + dir 212 | } 213 | } 214 | return "ORDER BY " + s 215 | } 216 | 217 | // BuildLimit generates the LIMIT clause. 
218 | func (q *BaseQueryBuilder) BuildLimit(limit int64, offset int64) string { 219 | if limit < 0 && offset > 0 { 220 | // most DBMS requires LIMIT when OFFSET is present 221 | limit = 9223372036854775807 // 2^63 - 1 222 | } 223 | 224 | sql := "" 225 | if limit >= 0 { 226 | sql = fmt.Sprintf("LIMIT %v", limit) 227 | } 228 | if offset <= 0 { 229 | return sql 230 | } 231 | if sql != "" { 232 | sql += " " 233 | } 234 | return sql + fmt.Sprintf("OFFSET %v", offset) 235 | } 236 | 237 | func (q *BaseQueryBuilder) quoteTableNameAndAlias(table string) string { 238 | matches := selectRegex.FindStringSubmatch(table) 239 | if len(matches) == 0 { 240 | return q.db.QuoteTableName(table) 241 | } 242 | table = table[:len(table)-len(matches[0])] 243 | return q.db.QuoteTableName(table) + " " + q.db.QuoteSimpleTableName(matches[1]) 244 | } 245 | -------------------------------------------------------------------------------- /query_builder_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package dbx 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestQB_BuildSelect(t *testing.T) { 14 | tests := []struct { 15 | tag string 16 | cols []string 17 | distinct bool 18 | option string 19 | expected string 20 | }{ 21 | {"empty", []string{}, false, "", "SELECT *"}, 22 | {"empty distinct", []string{}, true, "CALC_ROWS", "SELECT DISTINCT CALC_ROWS *"}, 23 | {"multi-columns", []string{"name", "DOB1"}, false, "", "SELECT `name`, `DOB1`"}, 24 | {"aliased columns", []string{"name As Name", "users.last_name", "u.first1 first"}, false, "", "SELECT `name` AS `Name`, `users`.`last_name`, `u`.`first1` AS `first`"}, 25 | } 26 | 27 | db := getDB() 28 | qb := db.QueryBuilder() 29 | for _, test := range tests { 30 | s := qb.BuildSelect(test.cols, test.distinct, test.option) 31 | assert.Equal(t, test.expected, s, test.tag) 32 | } 33 | assert.Equal(t, qb.(*BaseQueryBuilder).DB(), db) 34 | } 35 | 36 | func TestQB_BuildFrom(t *testing.T) { 37 | tests := []struct { 38 | tag string 39 | tables []string 40 | expected string 41 | }{ 42 | {"empty", []string{}, ""}, 43 | {"single table", []string{"users"}, "FROM `users`"}, 44 | {"multiple tables", []string{"users", "posts"}, "FROM `users`, `posts`"}, 45 | {"table alias", []string{"users u", "posts as p"}, "FROM `users` `u`, `posts` `p`"}, 46 | {"table prefix and alias", []string{"pub.users p.u", "posts AS p1"}, "FROM `pub`.`users` `p.u`, `posts` `p1`"}, 47 | } 48 | 49 | qb := getDB().QueryBuilder() 50 | for _, test := range tests { 51 | s := qb.BuildFrom(test.tables) 52 | assert.Equal(t, test.expected, s, test.tag) 53 | } 54 | } 55 | 56 | func TestQB_BuildGroupBy(t *testing.T) { 57 | tests := []struct { 58 | tag string 59 | cols []string 60 | expected string 61 | }{ 62 | {"empty", []string{}, ""}, 63 | {"single column", []string{"name"}, "GROUP BY `name`"}, 64 | {"multiple columns", []string{"name", "age"}, "GROUP BY `name`, `age`"}, 65 | } 66 | 67 | qb := 
getDB().QueryBuilder() 68 | for _, test := range tests { 69 | s := qb.BuildGroupBy(test.cols) 70 | assert.Equal(t, test.expected, s, test.tag) 71 | } 72 | } 73 | 74 | func TestQB_BuildWhere(t *testing.T) { 75 | tests := []struct { 76 | exp Expression 77 | expected string 78 | count int 79 | tag string 80 | }{ 81 | {HashExp{"age": 30, "dept": "marketing"}, "WHERE `age`={:p0} AND `dept`={:p1}", 2, "t1"}, 82 | {nil, "", 0, "t2"}, 83 | {NewExp(""), "", 0, "t3"}, 84 | } 85 | 86 | qb := getDB().QueryBuilder() 87 | for _, test := range tests { 88 | params := Params{} 89 | s := qb.BuildWhere(test.exp, params) 90 | assert.Equal(t, test.expected, s, test.tag) 91 | assert.Equal(t, test.count, len(params), test.tag) 92 | } 93 | } 94 | 95 | func TestQB_BuildHaving(t *testing.T) { 96 | tests := []struct { 97 | exp Expression 98 | expected string 99 | count int 100 | tag string 101 | }{ 102 | {HashExp{"age": 30, "dept": "marketing"}, "HAVING `age`={:p0} AND `dept`={:p1}", 2, "t1"}, 103 | {nil, "", 0, "t2"}, 104 | {NewExp(""), "", 0, "t3"}, 105 | } 106 | 107 | qb := getDB().QueryBuilder() 108 | for _, test := range tests { 109 | params := Params{} 110 | s := qb.BuildHaving(test.exp, params) 111 | assert.Equal(t, test.expected, s, test.tag) 112 | assert.Equal(t, test.count, len(params), test.tag) 113 | } 114 | } 115 | 116 | func TestQB_BuildOrderBy(t *testing.T) { 117 | tests := []struct { 118 | tag string 119 | cols []string 120 | expected string 121 | }{ 122 | {"empty", []string{}, ""}, 123 | {"single column", []string{"name"}, "ORDER BY `name`"}, 124 | {"multiple columns", []string{"name ASC", "age DESC", "id desc"}, "ORDER BY `name` ASC, `age` DESC, `id` desc"}, 125 | } 126 | qb := getDB().QueryBuilder().(*BaseQueryBuilder) 127 | for _, test := range tests { 128 | s := qb.BuildOrderBy(test.cols) 129 | assert.Equal(t, test.expected, s, test.tag) 130 | } 131 | } 132 | 133 | func TestQB_BuildLimit(t *testing.T) { 134 | tests := []struct { 135 | tag string 136 | limit, offset int64 
137 | expected string 138 | }{ 139 | {"t1", 10, -1, "LIMIT 10"}, 140 | {"t2", 10, 0, "LIMIT 10"}, 141 | {"t3", 10, 2, "LIMIT 10 OFFSET 2"}, 142 | {"t4", 0, 2, "LIMIT 0 OFFSET 2"}, 143 | {"t5", -1, 2, "LIMIT 9223372036854775807 OFFSET 2"}, 144 | {"t6", -1, 0, ""}, 145 | } 146 | qb := getDB().QueryBuilder().(*BaseQueryBuilder) 147 | for _, test := range tests { 148 | s := qb.BuildLimit(test.limit, test.offset) 149 | assert.Equal(t, test.expected, s, test.tag) 150 | } 151 | } 152 | 153 | func TestQB_BuildOrderByAndLimit(t *testing.T) { 154 | qb := getDB().QueryBuilder() 155 | 156 | sql := qb.BuildOrderByAndLimit("SELECT *", []string{"name"}, 10, 2) 157 | expected := "SELECT * ORDER BY `name` LIMIT 10 OFFSET 2" 158 | assert.Equal(t, sql, expected, "t1") 159 | 160 | sql = qb.BuildOrderByAndLimit("SELECT *", nil, -1, -1) 161 | expected = "SELECT *" 162 | assert.Equal(t, sql, expected, "t2") 163 | 164 | sql = qb.BuildOrderByAndLimit("SELECT *", []string{"name"}, -1, -1) 165 | expected = "SELECT * ORDER BY `name`" 166 | assert.Equal(t, sql, expected, "t3") 167 | 168 | sql = qb.BuildOrderByAndLimit("SELECT *", nil, 10, -1) 169 | expected = "SELECT * LIMIT 10" 170 | assert.Equal(t, sql, expected, "t4") 171 | } 172 | 173 | func TestQB_BuildJoin(t *testing.T) { 174 | qb := getDB().QueryBuilder() 175 | 176 | params := Params{} 177 | ji := JoinInfo{"LEFT JOIN", "users u", NewExp("id=u.id", Params{"id": 1})} 178 | sql := qb.BuildJoin([]JoinInfo{ji}, params) 179 | expected := "LEFT JOIN `users` `u` ON id=u.id" 180 | assert.Equal(t, sql, expected, "BuildJoin@1") 181 | assert.Equal(t, len(params), 1, "len(params)@1") 182 | 183 | params = Params{} 184 | ji = JoinInfo{"INNER JOIN", "users", nil} 185 | sql = qb.BuildJoin([]JoinInfo{ji}, params) 186 | expected = "INNER JOIN `users`" 187 | assert.Equal(t, sql, expected, "BuildJoin@2") 188 | assert.Equal(t, len(params), 0, "len(params)@2") 189 | 190 | sql = qb.BuildJoin([]JoinInfo{}, nil) 191 | expected = "" 192 | assert.Equal(t, sql, 
expected, "BuildJoin@3") 193 | 194 | ji = JoinInfo{"INNER JOIN", "users", nil} 195 | ji2 := JoinInfo{"LEFT JOIN", "posts", nil} 196 | sql = qb.BuildJoin([]JoinInfo{ji, ji2}, nil) 197 | expected = "INNER JOIN `users` LEFT JOIN `posts`" 198 | assert.Equal(t, sql, expected, "BuildJoin@3") 199 | } 200 | 201 | func TestQB_BuildUnion(t *testing.T) { 202 | db := getDB() 203 | qb := db.QueryBuilder() 204 | 205 | params := Params{} 206 | ui := UnionInfo{false, db.NewQuery("SELECT names").Bind(Params{"id": 1})} 207 | sql := qb.BuildUnion([]UnionInfo{ui}, params) 208 | expected := "UNION (SELECT names)" 209 | assert.Equal(t, sql, expected, "BuildUnion@1") 210 | assert.Equal(t, len(params), 1, "len(params)@1") 211 | 212 | params = Params{} 213 | ui = UnionInfo{true, db.NewQuery("SELECT names")} 214 | sql = qb.BuildUnion([]UnionInfo{ui}, params) 215 | expected = "UNION ALL (SELECT names)" 216 | assert.Equal(t, sql, expected, "BuildUnion@2") 217 | assert.Equal(t, len(params), 0, "len(params)@2") 218 | 219 | sql = qb.BuildUnion([]UnionInfo{}, nil) 220 | expected = "" 221 | assert.Equal(t, sql, expected, "BuildUnion@3") 222 | 223 | ui = UnionInfo{true, db.NewQuery("SELECT names")} 224 | ui2 := UnionInfo{false, db.NewQuery("SELECT ages")} 225 | sql = qb.BuildUnion([]UnionInfo{ui, ui2}, nil) 226 | expected = "UNION ALL (SELECT names) UNION (SELECT ages)" 227 | assert.Equal(t, sql, expected, "BuildUnion@4") 228 | } 229 | -------------------------------------------------------------------------------- /query_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package dbx 6 | 7 | import ( 8 | ss "database/sql" 9 | "encoding/json" 10 | "errors" 11 | "testing" 12 | "time" 13 | 14 | "github.com/stretchr/testify/assert" 15 | ) 16 | 17 | type City struct { 18 | ID int 19 | Name string 20 | } 21 | 22 | func TestNewQuery(t *testing.T) { 23 | db := getDB() 24 | sql := "SELECT * FROM users WHERE id={:id}" 25 | q := NewQuery(db, db.sqlDB, sql) 26 | assert.Equal(t, q.SQL(), sql, "q.SQL()") 27 | assert.Equal(t, q.rawSQL, "SELECT * FROM users WHERE id=?", "q.RawSQL()") 28 | 29 | assert.Equal(t, len(q.Params()), 0, "len(q.Params())@1") 30 | q.Bind(Params{"id": 1}) 31 | assert.Equal(t, len(q.Params()), 1, "len(q.Params())@2") 32 | } 33 | 34 | func TestQuery_Execute(t *testing.T) { 35 | db := getPreparedDB() 36 | defer db.Close() 37 | 38 | result, err := db.NewQuery("INSERT INTO item (name) VALUES ('test')").Execute() 39 | if assert.Nil(t, err) { 40 | rows, _ := result.RowsAffected() 41 | assert.Equal(t, rows, int64(1), "Result.RowsAffected()") 42 | lastID, _ := result.LastInsertId() 43 | assert.Equal(t, lastID, int64(6), "Result.LastInsertId()") 44 | } 45 | } 46 | 47 | type Customer struct { 48 | scanned bool 49 | 50 | ID int 51 | Email string 52 | Status int 53 | Name string 54 | Address ss.NullString 55 | } 56 | 57 | func (m Customer) TableName() string { 58 | return "customer" 59 | } 60 | 61 | func (m *Customer) PostScan() error { 62 | m.scanned = true 63 | return nil 64 | } 65 | 66 | type CustomerPtr struct { 67 | ID *int `db:"pk"` 68 | Email *string 69 | Status *int 70 | Name string 71 | Address *string 72 | } 73 | 74 | func (m CustomerPtr) TableName() string { 75 | return "customer" 76 | } 77 | 78 | type CustomerNull struct { 79 | ID ss.NullInt64 `db:"pk,id"` 80 | Email ss.NullString 81 | Status *ss.NullInt64 82 | Name string 83 | Address ss.NullString 84 | } 85 | 86 | func (m CustomerNull) TableName() string { 87 | return "customer" 88 | } 89 | 90 | type CustomerEmbedded struct { 91 | Id int 92 | Email *string 93 | 
InnerCustomer 94 | } 95 | 96 | func (m CustomerEmbedded) TableName() string { 97 | return "customer" 98 | } 99 | 100 | type CustomerEmbedded2 struct { 101 | ID int 102 | Email *string 103 | Inner InnerCustomer 104 | } 105 | 106 | type InnerCustomer struct { 107 | Status ss.NullInt64 108 | Name *string 109 | Address ss.NullString 110 | } 111 | 112 | func TestQuery_Rows(t *testing.T) { 113 | db := getPreparedDB() 114 | defer db.Close() 115 | 116 | var ( 117 | sql string 118 | err error 119 | ) 120 | 121 | // Query.All() 122 | var customers []Customer 123 | sql = `SELECT * FROM customer ORDER BY id` 124 | err = db.NewQuery(sql).All(&customers) 125 | if assert.Nil(t, err) { 126 | assert.Equal(t, len(customers), 3, "len(customers)") 127 | assert.Equal(t, customers[2].ID, 3, "customers[2].ID") 128 | assert.Equal(t, customers[2].Email, `user3@example.com`, "customers[2].Email") 129 | assert.Equal(t, customers[2].Status, 2, "customers[2].Status") 130 | assert.Equal(t, customers[0].scanned, true, "customers[0].scanned") 131 | assert.Equal(t, customers[1].scanned, true, "customers[1].scanned") 132 | assert.Equal(t, customers[2].scanned, true, "customers[2].scanned") 133 | } 134 | 135 | // Query.All() with slice of pointers 136 | var customersPtrSlice []*Customer 137 | sql = `SELECT * FROM customer ORDER BY id` 138 | err = db.NewQuery(sql).All(&customersPtrSlice) 139 | if assert.Nil(t, err) { 140 | assert.Equal(t, len(customersPtrSlice), 3, "len(customersPtrSlice)") 141 | assert.Equal(t, customersPtrSlice[2].ID, 3, "customersPtrSlice[2].ID") 142 | assert.Equal(t, customersPtrSlice[2].Email, `user3@example.com`, "customersPtrSlice[2].Email") 143 | assert.Equal(t, customersPtrSlice[2].Status, 2, "customersPtrSlice[2].Status") 144 | assert.Equal(t, customersPtrSlice[0].scanned, true, "customersPtrSlice[0].scanned") 145 | assert.Equal(t, customersPtrSlice[1].scanned, true, "customersPtrSlice[1].scanned") 146 | assert.Equal(t, customersPtrSlice[2].scanned, true, 
"customersPtrSlice[2].scanned") 147 | } 148 | 149 | var customers2 []NullStringMap 150 | err = db.NewQuery(sql).All(&customers2) 151 | if assert.Nil(t, err) { 152 | assert.Equal(t, len(customers2), 3, "len(customers2)") 153 | assert.Equal(t, customers2[1]["id"].String, "2", "customers2[1][id]") 154 | assert.Equal(t, customers2[1]["email"].String, `user2@example.com`, "customers2[1][email]") 155 | assert.Equal(t, customers2[1]["status"].String, "1", "customers2[1][status]") 156 | } 157 | err = db.NewQuery(sql).All(customers) 158 | assert.NotNil(t, err) 159 | 160 | var customers3 []string 161 | err = db.NewQuery(sql).All(&customers3) 162 | assert.NotNil(t, err) 163 | 164 | var customers4 string 165 | err = db.NewQuery(sql).All(&customers4) 166 | assert.NotNil(t, err) 167 | 168 | var customers5 []Customer 169 | err = db.NewQuery(`SELECT * FROM customer WHERE id=999`).All(&customers5) 170 | if assert.Nil(t, err) { 171 | assert.NotNil(t, customers5) 172 | assert.Zero(t, len(customers5)) 173 | } 174 | 175 | // One 176 | var customer Customer 177 | sql = `SELECT * FROM customer WHERE id={:id}` 178 | err = db.NewQuery(sql).Bind(Params{"id": 2}).One(&customer) 179 | if assert.Nil(t, err) { 180 | assert.Equal(t, customer.ID, 2, "customer.ID") 181 | assert.Equal(t, customer.Email, `user2@example.com`, "customer.Email") 182 | assert.Equal(t, customer.Status, 1, "customer.Status") 183 | } 184 | 185 | var customerPtr2 CustomerPtr 186 | sql = `SELECT id, email, address FROM customer WHERE id=2` 187 | rows2, err := db.sqlDB.Query(sql) 188 | defer rows2.Close() 189 | assert.Nil(t, err) 190 | rows2.Next() 191 | err = rows2.Scan(&customerPtr2.ID, &customerPtr2.Email, &customerPtr2.Address) 192 | if assert.Nil(t, err) { 193 | assert.Equal(t, *customerPtr2.ID, 2, "customer.ID") 194 | assert.Equal(t, *customerPtr2.Email, `user2@example.com`) 195 | assert.Nil(t, customerPtr2.Address) 196 | } 197 | 198 | // struct fields are pointers 199 | var customerPtr CustomerPtr 200 | sql = `SELECT * 
FROM customer WHERE id={:id}` 201 | err = db.NewQuery(sql).Bind(Params{"id": 2}).One(&customerPtr) 202 | if assert.Nil(t, err) { 203 | assert.Equal(t, *customerPtr.ID, 2, "customer.ID") 204 | assert.Equal(t, *customerPtr.Email, `user2@example.com`, "customer.Email") 205 | assert.Equal(t, *customerPtr.Status, 1, "customer.Status") 206 | } 207 | 208 | // struct fields are null types 209 | var customerNull CustomerNull 210 | sql = `SELECT * FROM customer WHERE id={:id}` 211 | err = db.NewQuery(sql).Bind(Params{"id": 2}).One(&customerNull) 212 | if assert.Nil(t, err) { 213 | assert.Equal(t, customerNull.ID.Int64, int64(2), "customer.ID") 214 | assert.Equal(t, customerNull.Email.String, `user2@example.com`, "customer.Email") 215 | assert.Equal(t, customerNull.Status.Int64, int64(1), "customer.Status") 216 | } 217 | 218 | // embedded with anonymous struct 219 | var customerEmbedded CustomerEmbedded 220 | sql = `SELECT * FROM customer WHERE id={:id}` 221 | err = db.NewQuery(sql).Bind(Params{"id": 2}).One(&customerEmbedded) 222 | if assert.Nil(t, err) { 223 | assert.Equal(t, customerEmbedded.Id, 2, "customer.ID") 224 | assert.Equal(t, *customerEmbedded.Email, `user2@example.com`, "customer.Email") 225 | assert.Equal(t, customerEmbedded.Status.Int64, int64(1), "customer.Status") 226 | } 227 | 228 | // embedded with named struct 229 | var customerEmbedded2 CustomerEmbedded2 230 | sql = `SELECT id, email, status as "inner.status" FROM customer WHERE id={:id}` 231 | err = db.NewQuery(sql).Bind(Params{"id": 2}).One(&customerEmbedded2) 232 | if assert.Nil(t, err) { 233 | assert.Equal(t, customerEmbedded2.ID, 2, "customer.ID") 234 | assert.Equal(t, *customerEmbedded2.Email, `user2@example.com`, "customer.Email") 235 | assert.Equal(t, customerEmbedded2.Inner.Status.Int64, int64(1), "customer.Status") 236 | } 237 | 238 | customer2 := NullStringMap{} 239 | sql = `SELECT * FROM customer WHERE id={:id}` 240 | err = db.NewQuery(sql).Bind(Params{"id": 1}).One(customer2) 241 | if 
assert.Nil(t, err) { 242 | assert.Equal(t, customer2["id"].String, "1", "customer2[id]") 243 | assert.Equal(t, customer2["email"].String, `user1@example.com`, "customer2[email]") 244 | assert.Equal(t, customer2["status"].String, "1", "customer2[status]") 245 | } 246 | 247 | err = db.NewQuery(sql).Bind(Params{"id": 2}).One(customer) 248 | assert.NotNil(t, err) 249 | 250 | var customer3 NullStringMap 251 | err = db.NewQuery(sql).Bind(Params{"id": 2}).One(customer3) 252 | assert.NotNil(t, err) 253 | 254 | err = db.NewQuery(sql).Bind(Params{"id": 1}).One(&customer3) 255 | if assert.Nil(t, err) { 256 | assert.Equal(t, customer3["id"].String, "1", "customer3[id]") 257 | } 258 | 259 | // Rows 260 | sql = `SELECT * FROM customer ORDER BY id DESC` 261 | rows, err := db.NewQuery(sql).Rows() 262 | if assert.Nil(t, err) { 263 | s := "" 264 | for rows.Next() { 265 | rows.ScanStruct(&customer) 266 | s += customer.Email + "," 267 | } 268 | assert.Equal(t, s, "user3@example.com,user2@example.com,user1@example.com,", "Rows().Next()") 269 | } 270 | 271 | // FieldMapper 272 | var a struct { 273 | MyID string `db:"id"` 274 | name string 275 | } 276 | sql = `SELECT * FROM customer WHERE id=2` 277 | err = db.NewQuery(sql).One(&a) 278 | if assert.Nil(t, err) { 279 | assert.Equal(t, a.MyID, "2", "a.MyID") 280 | // unexported field is not populated 281 | assert.Equal(t, a.name, "", "a.name") 282 | } 283 | 284 | // prepared statement 285 | sql = `SELECT * FROM customer WHERE id={:id}` 286 | q := db.NewQuery(sql).Prepare() 287 | q.Bind(Params{"id": 1}).One(&customer) 288 | assert.Equal(t, customer.ID, 1, "prepared@1") 289 | err = q.Bind(Params{"id": 20}).One(&customer) 290 | assert.Equal(t, err, ss.ErrNoRows, "prepared@2") 291 | q.Bind(Params{"id": 3}).One(&customer) 292 | assert.Equal(t, customer.ID, 3, "prepared@3") 293 | 294 | sql = `SELECT name FROM customer WHERE id={:id}` 295 | var name string 296 | q = db.NewQuery(sql).Prepare() 297 | q.Bind(Params{"id": 1}).Row(&name) 298 | 
assert.Equal(t, name, "user1", "prepared2@1") 299 | err = q.Bind(Params{"id": 20}).Row(&name) 300 | assert.Equal(t, err, ss.ErrNoRows, "prepared2@2") 301 | q.Bind(Params{"id": 3}).Row(&name) 302 | assert.Equal(t, name, "user3", "prepared2@3") 303 | 304 | // Query.LastError 305 | sql = `SELECT * FROM a` 306 | q = db.NewQuery(sql).Prepare() 307 | customer.ID = 100 308 | err = q.Bind(Params{"id": 1}).One(&customer) 309 | assert.NotEqual(t, err, nil, "LastError@0") 310 | assert.Equal(t, customer.ID, 100, "LastError@1") 311 | assert.Equal(t, q.LastError, nil, "LastError@2") 312 | 313 | // Query.Column 314 | sql = `SELECT name, id FROM customer ORDER BY id` 315 | var names []string 316 | err = db.NewQuery(sql).Column(&names) 317 | if assert.Nil(t, err) && assert.Equal(t, 3, len(names)) { 318 | assert.Equal(t, "user1", names[0]) 319 | assert.Equal(t, "user2", names[1]) 320 | assert.Equal(t, "user3", names[2]) 321 | } 322 | err = db.NewQuery(sql).Column(names) 323 | assert.NotNil(t, err) 324 | } 325 | 326 | func TestQuery_logSQL(t *testing.T) { 327 | db := getDB() 328 | q := db.NewQuery("SELECT * FROM users WHERE type={:type} AND id={:id} AND bytes={:bytes}").Bind(Params{ 329 | "id": 1, 330 | "type": "a", 331 | "bytes": []byte("test"), 332 | }) 333 | expected := "SELECT * FROM users WHERE type='a' AND id=1 AND bytes=0x74657374" 334 | assert.Equal(t, q.logSQL(), expected, "logSQL()") 335 | } 336 | 337 | func TestReplacePlaceholders(t *testing.T) { 338 | tests := []struct { 339 | ID string 340 | Placeholders []string 341 | Params Params 342 | ExpectedParams string 343 | HasError bool 344 | }{ 345 | {"t1", nil, nil, "null", false}, 346 | {"t2", []string{"id", "name"}, Params{"id": 1, "name": "xyz"}, `[1,"xyz"]`, false}, 347 | {"t3", []string{"id", "name"}, Params{"id": 1}, `null`, true}, 348 | {"t4", []string{"id", "name"}, Params{"id": 1, "name": "xyz", "age": 30}, `[1,"xyz"]`, false}, 349 | } 350 | for _, test := range tests { 351 | params, err := 
replacePlaceholders(test.Placeholders, test.Params) 352 | result, _ := json.Marshal(params) 353 | assert.Equal(t, string(result), test.ExpectedParams, "params@"+test.ID) 354 | assert.Equal(t, err != nil, test.HasError, "error@"+test.ID) 355 | } 356 | } 357 | 358 | func TestIssue6(t *testing.T) { 359 | db := getPreparedDB() 360 | q := db.Select("*").From("customer").Where(HashExp{"id": 1}) 361 | var customer Customer 362 | assert.Equal(t, q.One(&customer), nil) 363 | assert.Equal(t, 1, customer.ID) 364 | } 365 | 366 | type User struct { 367 | ID int64 368 | Email string 369 | Created time.Time 370 | Updated *time.Time 371 | } 372 | 373 | func TestIssue13(t *testing.T) { 374 | db := getPreparedDB() 375 | var user User 376 | err := db.Select().From("user").Where(HashExp{"id": 1}).One(&user) 377 | if assert.Nil(t, err) { 378 | assert.NotZero(t, user.Created) 379 | assert.Nil(t, user.Updated) 380 | } 381 | 382 | now := time.Now() 383 | 384 | user2 := User{ 385 | Email: "now@example.com", 386 | Created: now, 387 | } 388 | err = db.Model(&user2).Insert() 389 | if assert.Nil(t, err) { 390 | assert.NotZero(t, user2.ID) 391 | } 392 | 393 | user3 := User{ 394 | Email: "now@example.com", 395 | Created: now, 396 | Updated: &now, 397 | } 398 | err = db.Model(&user3).Insert() 399 | if assert.Nil(t, err) { 400 | assert.NotZero(t, user2.ID) 401 | } 402 | } 403 | 404 | func TestQueryWithExecHook(t *testing.T) { 405 | db := getPreparedDB() 406 | defer db.Close() 407 | 408 | // error return 409 | { 410 | err := db.NewQuery("select * from user"). 411 | WithExecHook(func(q *Query, op func() error) error { 412 | return errors.New("test") 413 | }). 414 | Row() 415 | 416 | assert.Error(t, err) 417 | } 418 | 419 | // Row() 420 | { 421 | calls := 0 422 | err := db.NewQuery("select * from user"). 423 | WithExecHook(func(q *Query, op func() error) error { 424 | calls++ 425 | return nil 426 | }). 
427 | Row() 428 | assert.Nil(t, err) 429 | assert.Equal(t, 1, calls, "Row()") 430 | } 431 | 432 | // One() 433 | { 434 | calls := 0 435 | err := db.NewQuery("select * from user"). 436 | WithExecHook(func(q *Query, op func() error) error { 437 | calls++ 438 | return nil 439 | }). 440 | One(nil) 441 | assert.Nil(t, err) 442 | assert.Equal(t, 1, calls, "One()") 443 | } 444 | 445 | // All() 446 | { 447 | calls := 0 448 | err := db.NewQuery("select * from user"). 449 | WithExecHook(func(q *Query, op func() error) error { 450 | calls++ 451 | return nil 452 | }). 453 | All(nil) 454 | assert.Nil(t, err) 455 | assert.Equal(t, 1, calls, "All()") 456 | } 457 | 458 | // Column() 459 | { 460 | calls := 0 461 | err := db.NewQuery("select * from user"). 462 | WithExecHook(func(q *Query, op func() error) error { 463 | calls++ 464 | return nil 465 | }). 466 | Column(nil) 467 | assert.Nil(t, err) 468 | assert.Equal(t, 1, calls, "Column()") 469 | } 470 | 471 | // Execute() 472 | { 473 | calls := 0 474 | _, err := db.NewQuery("select * from user"). 475 | WithExecHook(func(q *Query, op func() error) error { 476 | calls++ 477 | return nil 478 | }). 479 | Execute() 480 | assert.Nil(t, err) 481 | assert.Equal(t, 1, calls, "Execute()") 482 | } 483 | 484 | // op call 485 | { 486 | calls := 0 487 | var id int 488 | err := db.NewQuery("select id from user where id = 2"). 489 | WithExecHook(func(q *Query, op func() error) error { 490 | calls++ 491 | return op() 492 | }). 493 | Row(&id) 494 | assert.Nil(t, err) 495 | assert.Equal(t, 1, calls, "op hook calls") 496 | assert.Equal(t, 2, id, "id mismatch") 497 | } 498 | } 499 | 500 | func TestQueryWithOneHook(t *testing.T) { 501 | db := getPreparedDB() 502 | defer db.Close() 503 | 504 | // error return 505 | { 506 | err := db.NewQuery("select * from user"). 507 | WithOneHook(func(q *Query, a interface{}, op func(b interface{}) error) error { 508 | return errors.New("test") 509 | }). 
510 | One(nil) 511 | 512 | assert.Error(t, err) 513 | } 514 | 515 | // hooks call order 516 | { 517 | hookCalls := []string{} 518 | err := db.NewQuery("select * from user"). 519 | WithExecHook(func(q *Query, op func() error) error { 520 | hookCalls = append(hookCalls, "exec") 521 | return op() 522 | }). 523 | WithOneHook(func(q *Query, a interface{}, op func(b interface{}) error) error { 524 | hookCalls = append(hookCalls, "one") 525 | return nil 526 | }). 527 | One(nil) 528 | 529 | assert.Nil(t, err) 530 | assert.Equal(t, hookCalls, []string{"exec", "one"}) 531 | } 532 | 533 | // op call 534 | { 535 | calls := 0 536 | other := User{} 537 | err := db.NewQuery("select id from user where id = 2"). 538 | WithOneHook(func(q *Query, a interface{}, op func(b interface{}) error) error { 539 | calls++ 540 | return op(&other) 541 | }). 542 | One(nil) 543 | 544 | assert.Nil(t, err) 545 | assert.Equal(t, 1, calls, "hook calls") 546 | assert.Equal(t, int64(2), other.ID, "replaced scan struct") 547 | } 548 | } 549 | 550 | func TestQueryWithAllHook(t *testing.T) { 551 | db := getPreparedDB() 552 | defer db.Close() 553 | 554 | // error return 555 | { 556 | err := db.NewQuery("select * from user"). 557 | WithAllHook(func(q *Query, a interface{}, op func(b interface{}) error) error { 558 | return errors.New("test") 559 | }). 560 | All(nil) 561 | 562 | assert.Error(t, err) 563 | } 564 | 565 | // hooks call order 566 | { 567 | hookCalls := []string{} 568 | err := db.NewQuery("select * from user"). 569 | WithExecHook(func(q *Query, op func() error) error { 570 | hookCalls = append(hookCalls, "exec") 571 | return op() 572 | }). 573 | WithAllHook(func(q *Query, a interface{}, op func(b interface{}) error) error { 574 | hookCalls = append(hookCalls, "all") 575 | return nil 576 | }). 
577 | All(nil) 578 | 579 | assert.Nil(t, err) 580 | assert.Equal(t, hookCalls, []string{"exec", "all"}) 581 | } 582 | 583 | // op call 584 | { 585 | calls := 0 586 | other := []User{} 587 | err := db.NewQuery("select id from user order by id asc"). 588 | WithAllHook(func(q *Query, a interface{}, op func(b interface{}) error) error { 589 | calls++ 590 | return op(&other) 591 | }). 592 | All(nil) 593 | 594 | assert.Nil(t, err) 595 | assert.Equal(t, 1, calls, "hook calls") 596 | assert.Equal(t, 2, len(other), "users length") 597 | assert.Equal(t, int64(1), other[0].ID, "user 1 id check") 598 | assert.Equal(t, int64(2), other[1].ID, "user 2 id check") 599 | } 600 | } 601 | -------------------------------------------------------------------------------- /rows.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package dbx 6 | 7 | import ( 8 | "database/sql" 9 | "reflect" 10 | ) 11 | 12 | // VarTypeError indicates a variable type error when trying to populating a variable with DB result. 13 | type VarTypeError string 14 | 15 | // Error returns the error message. 16 | func (s VarTypeError) Error() string { 17 | return "Invalid variable type: " + string(s) 18 | } 19 | 20 | // NullStringMap is a map of sql.NullString that can be used to hold DB query result. 21 | // The map keys correspond to the DB column names, while the map values are their corresponding column values. 22 | type NullStringMap map[string]sql.NullString 23 | 24 | // Rows enhances sql.Rows by providing additional data query methods. 25 | // Rows can be obtained by calling Query.Rows(). It is mainly used to populate data row by row. 26 | type Rows struct { 27 | *sql.Rows 28 | fieldMapFunc FieldMapFunc 29 | } 30 | 31 | // ScanMap populates the current row of data into a NullStringMap. 
32 | // Note that the NullStringMap must not be nil, or it will panic. 33 | // The NullStringMap will be populated using column names as keys and their values as 34 | // the corresponding element values. 35 | func (r *Rows) ScanMap(a NullStringMap) error { 36 | cols, _ := r.Columns() 37 | var refs []interface{} 38 | for i := 0; i < len(cols); i++ { 39 | var t sql.NullString 40 | refs = append(refs, &t) 41 | } 42 | if err := r.Scan(refs...); err != nil { 43 | return err 44 | } 45 | 46 | for i, col := range cols { 47 | a[col] = *refs[i].(*sql.NullString) 48 | } 49 | 50 | return nil 51 | } 52 | 53 | // ScanStruct populates the current row of data into a struct. 54 | // The struct must be given as a pointer. 55 | // 56 | // ScanStruct associates struct fields with DB table columns through a field mapping function. 57 | // It populates a struct field with the data of its associated column. 58 | // Note that only exported struct fields will be populated. 59 | // 60 | // By default, DefaultFieldMapFunc() is used to map struct fields to table columns. 61 | // This function separates each word in a field name with a underscore and turns every letter into lower case. 62 | // For example, "LastName" is mapped to "last_name", "MyID" is mapped to "my_id", and so on. 63 | // To change the default behavior, set DB.FieldMapper with your custom mapping function. 64 | // You may also set Query.FieldMapper to change the behavior for particular queries. 
65 | func (r *Rows) ScanStruct(a interface{}) error { 66 | rv := reflect.ValueOf(a) 67 | if rv.Kind() != reflect.Ptr || rv.IsNil() { 68 | return VarTypeError("must be a pointer") 69 | } 70 | rv = indirect(rv) 71 | if rv.Kind() != reflect.Struct { 72 | return VarTypeError("must be a pointer to a struct") 73 | } 74 | 75 | si := getStructInfo(rv.Type(), r.fieldMapFunc) 76 | 77 | cols, _ := r.Columns() 78 | refs := make([]interface{}, len(cols)) 79 | 80 | for i, col := range cols { 81 | if fi, ok := si.dbNameMap[col]; ok { 82 | refs[i] = fi.getField(rv).Addr().Interface() 83 | } else { 84 | refs[i] = &sql.NullString{} 85 | } 86 | } 87 | 88 | if err := r.Scan(refs...); err != nil { 89 | return err 90 | } 91 | 92 | // check for PostScanner 93 | if rv.CanAddr() { 94 | addr := rv.Addr() 95 | if addr.CanInterface() { 96 | if ps, ok := addr.Interface().(PostScanner); ok { 97 | if err := ps.PostScan(); err != nil { 98 | return err 99 | } 100 | } 101 | } 102 | } 103 | 104 | return nil 105 | } 106 | 107 | // all populates all rows of query result into a slice of struct or NullStringMap. 108 | // Note that the slice must be given as a pointer. 
109 | func (r *Rows) all(slice interface{}) error { 110 | defer r.Close() 111 | 112 | v := reflect.ValueOf(slice) 113 | if v.Kind() != reflect.Ptr || v.IsNil() { 114 | return VarTypeError("must be a pointer") 115 | } 116 | v = indirect(v) 117 | 118 | if v.Kind() != reflect.Slice { 119 | return VarTypeError("must be a slice of struct or NullStringMap") 120 | } 121 | 122 | if v.IsNil() { 123 | // create an empty slice 124 | v.Set(reflect.MakeSlice(v.Type(), 0, 0)) 125 | } 126 | 127 | et := v.Type().Elem() 128 | 129 | if et.Kind() == reflect.Map { 130 | for r.Next() { 131 | ev, ok := reflect.MakeMap(et).Interface().(NullStringMap) 132 | if !ok { 133 | return VarTypeError("must be a slice of struct or NullStringMap") 134 | } 135 | if err := r.ScanMap(ev); err != nil { 136 | return err 137 | } 138 | v.Set(reflect.Append(v, reflect.ValueOf(ev))) 139 | } 140 | return r.Close() 141 | } 142 | 143 | var isSliceOfPointers bool 144 | if et.Kind() == reflect.Ptr { 145 | isSliceOfPointers = true 146 | et = et.Elem() 147 | } 148 | 149 | if et.Kind() != reflect.Struct { 150 | return VarTypeError("must be a slice of struct or NullStringMap") 151 | } 152 | 153 | etPtr := reflect.PtrTo(et) 154 | implementsPostScanner := etPtr.Implements(postScannerType) 155 | 156 | si := getStructInfo(et, r.fieldMapFunc) 157 | 158 | cols, _ := r.Columns() 159 | for r.Next() { 160 | ev := reflect.New(et).Elem() 161 | refs := make([]interface{}, len(cols)) 162 | for i, col := range cols { 163 | if fi, ok := si.dbNameMap[col]; ok { 164 | refs[i] = fi.getField(ev).Addr().Interface() 165 | } else { 166 | refs[i] = &sql.NullString{} 167 | } 168 | } 169 | if err := r.Scan(refs...); err != nil { 170 | return err 171 | } 172 | 173 | if isSliceOfPointers { 174 | ev = ev.Addr() 175 | } 176 | 177 | // check for PostScanner 178 | if implementsPostScanner { 179 | evAddr := ev 180 | if ev.CanAddr() { 181 | evAddr = ev.Addr() 182 | } 183 | if evAddr.CanInterface() { 184 | if ps, ok := 
evAddr.Interface().(PostScanner); ok { 185 | if err := ps.PostScan(); err != nil { 186 | return err 187 | } 188 | } 189 | } 190 | } 191 | 192 | v.Set(reflect.Append(v, ev)) 193 | } 194 | 195 | return r.Close() 196 | } 197 | 198 | // column populates the given slice with the first column of the query result. 199 | // Note that the slice must be given as a pointer. 200 | func (r *Rows) column(slice interface{}) error { 201 | defer r.Close() 202 | 203 | v := reflect.ValueOf(slice) 204 | if v.Kind() != reflect.Ptr || v.IsNil() { 205 | return VarTypeError("must be a pointer to a slice") 206 | } 207 | v = indirect(v) 208 | 209 | if v.Kind() != reflect.Slice { 210 | return VarTypeError("must be a pointer to a slice") 211 | } 212 | 213 | et := v.Type().Elem() 214 | 215 | cols, _ := r.Columns() 216 | for r.Next() { 217 | ev := reflect.New(et) 218 | refs := make([]interface{}, len(cols)) 219 | for i := range cols { 220 | if i == 0 { 221 | refs[i] = ev.Interface() 222 | } else { 223 | refs[i] = &sql.NullString{} 224 | } 225 | } 226 | if err := r.Scan(refs...); err != nil { 227 | return err 228 | } 229 | v.Set(reflect.Append(v, ev.Elem())) 230 | } 231 | 232 | return r.Close() 233 | } 234 | 235 | // one populates a single row of query result into a struct or a NullStringMap. 236 | // Note that if a struct is given, it should be a pointer. 
237 | func (r *Rows) one(a interface{}) error { 238 | defer r.Close() 239 | 240 | if !r.Next() { 241 | if err := r.Err(); err != nil { 242 | return err 243 | } 244 | return sql.ErrNoRows 245 | } 246 | 247 | var err error 248 | 249 | rt := reflect.TypeOf(a) 250 | if rt.Kind() == reflect.Ptr && rt.Elem().Kind() == reflect.Map { 251 | // pointer to map 252 | v := indirect(reflect.ValueOf(a)) 253 | if v.IsNil() { 254 | v.Set(reflect.MakeMap(v.Type())) 255 | } 256 | a = v.Interface() 257 | rt = reflect.TypeOf(a) 258 | } 259 | 260 | if rt.Kind() == reflect.Map { 261 | v, ok := a.(NullStringMap) 262 | if !ok { 263 | return VarTypeError("must be a NullStringMap") 264 | } 265 | if v == nil { 266 | return VarTypeError("NullStringMap is nil") 267 | } 268 | err = r.ScanMap(v) 269 | } else { 270 | err = r.ScanStruct(a) 271 | } 272 | 273 | if err != nil { 274 | return err 275 | } 276 | 277 | return r.Close() 278 | } 279 | 280 | // row populates a single row of query result into a list of variables. 281 | func (r *Rows) row(a ...interface{}) error { 282 | defer r.Close() 283 | 284 | for _, dp := range a { 285 | if _, ok := dp.(*sql.RawBytes); ok { 286 | return VarTypeError("RawBytes isn't allowed on Row()") 287 | } 288 | } 289 | 290 | if !r.Next() { 291 | if err := r.Err(); err != nil { 292 | return err 293 | } 294 | return sql.ErrNoRows 295 | } 296 | if err := r.Scan(a...); err != nil { 297 | return err 298 | } 299 | 300 | return r.Close() 301 | } 302 | -------------------------------------------------------------------------------- /select.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package dbx 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "reflect" 11 | ) 12 | 13 | // BuildHookFunc defines a callback function that is executed on Query creation. 
14 | type BuildHookFunc func(q *Query) 15 | 16 | // SelectQuery represents a DB-agnostic SELECT query. 17 | // It can be built into a DB-specific query by calling the Build() method. 18 | type SelectQuery struct { 19 | // FieldMapper maps struct field names to DB column names. 20 | FieldMapper FieldMapFunc 21 | // TableMapper maps structs to DB table names. 22 | TableMapper TableMapFunc 23 | 24 | builder Builder 25 | ctx context.Context 26 | buildHook BuildHookFunc 27 | 28 | preFragment string 29 | postFragment string 30 | selects []string 31 | distinct bool 32 | selectOption string 33 | from []string 34 | where Expression 35 | join []JoinInfo 36 | orderBy []string 37 | groupBy []string 38 | having Expression 39 | union []UnionInfo 40 | limit int64 41 | offset int64 42 | params Params 43 | } 44 | 45 | // JoinInfo contains the specification for a JOIN clause. 46 | type JoinInfo struct { 47 | Join string 48 | Table string 49 | On Expression 50 | } 51 | 52 | // UnionInfo contains the specification for a UNION clause. 53 | type UnionInfo struct { 54 | All bool 55 | Query *Query 56 | } 57 | 58 | // NewSelectQuery creates a new SelectQuery instance. 59 | func NewSelectQuery(builder Builder, db *DB) *SelectQuery { 60 | return &SelectQuery{ 61 | builder: builder, 62 | selects: []string{}, 63 | from: []string{}, 64 | join: []JoinInfo{}, 65 | orderBy: []string{}, 66 | groupBy: []string{}, 67 | union: []UnionInfo{}, 68 | limit: -1, 69 | params: Params{}, 70 | ctx: db.ctx, 71 | FieldMapper: db.FieldMapper, 72 | TableMapper: db.TableMapper, 73 | } 74 | } 75 | 76 | // WithBuildHook runs the provided hook function with the query created on Build(). 77 | func (q *SelectQuery) WithBuildHook(fn BuildHookFunc) *SelectQuery { 78 | q.buildHook = fn 79 | return q 80 | } 81 | 82 | // Context returns the context associated with the query. 83 | func (q *SelectQuery) Context() context.Context { 84 | return q.ctx 85 | } 86 | 87 | // WithContext associates a context with the query. 
88 | func (q *SelectQuery) WithContext(ctx context.Context) *SelectQuery { 89 | q.ctx = ctx 90 | return q 91 | } 92 | 93 | // PreFragment sets SQL fragment that should be prepended before the select query (e.g. WITH clause). 94 | func (s *SelectQuery) PreFragment(fragment string) *SelectQuery { 95 | s.preFragment = fragment 96 | return s 97 | } 98 | 99 | // PostFragment sets SQL fragment that should be appended at the end of the select query. 100 | func (s *SelectQuery) PostFragment(fragment string) *SelectQuery { 101 | s.postFragment = fragment 102 | return s 103 | } 104 | 105 | // Select specifies the columns to be selected. 106 | // Column names will be automatically quoted. 107 | func (s *SelectQuery) Select(cols ...string) *SelectQuery { 108 | s.selects = cols 109 | return s 110 | } 111 | 112 | // AndSelect adds additional columns to be selected. 113 | // Column names will be automatically quoted. 114 | func (s *SelectQuery) AndSelect(cols ...string) *SelectQuery { 115 | s.selects = append(s.selects, cols...) 116 | return s 117 | } 118 | 119 | // Distinct specifies whether to select columns distinctively. 120 | // By default, distinct is false. 121 | func (s *SelectQuery) Distinct(v bool) *SelectQuery { 122 | s.distinct = v 123 | return s 124 | } 125 | 126 | // SelectOption specifies additional option that should be append to "SELECT". 127 | func (s *SelectQuery) SelectOption(option string) *SelectQuery { 128 | s.selectOption = option 129 | return s 130 | } 131 | 132 | // From specifies which tables to select from. 133 | // Table names will be automatically quoted. 134 | func (s *SelectQuery) From(tables ...string) *SelectQuery { 135 | s.from = tables 136 | return s 137 | } 138 | 139 | // Where specifies the WHERE condition. 140 | func (s *SelectQuery) Where(e Expression) *SelectQuery { 141 | s.where = e 142 | return s 143 | } 144 | 145 | // AndWhere concatenates a new WHERE condition with the existing one (if any) using "AND". 
146 | func (s *SelectQuery) AndWhere(e Expression) *SelectQuery { 147 | s.where = And(s.where, e) 148 | return s 149 | } 150 | 151 | // OrWhere concatenates a new WHERE condition with the existing one (if any) using "OR". 152 | func (s *SelectQuery) OrWhere(e Expression) *SelectQuery { 153 | s.where = Or(s.where, e) 154 | return s 155 | } 156 | 157 | // Join specifies a JOIN clause. 158 | // The "typ" parameter specifies the JOIN type (e.g. "INNER JOIN", "LEFT JOIN"). 159 | func (s *SelectQuery) Join(typ string, table string, on Expression) *SelectQuery { 160 | s.join = append(s.join, JoinInfo{typ, table, on}) 161 | return s 162 | } 163 | 164 | // InnerJoin specifies an INNER JOIN clause. 165 | // This is a shortcut method for Join. 166 | func (s *SelectQuery) InnerJoin(table string, on Expression) *SelectQuery { 167 | return s.Join("INNER JOIN", table, on) 168 | } 169 | 170 | // LeftJoin specifies a LEFT JOIN clause. 171 | // This is a shortcut method for Join. 172 | func (s *SelectQuery) LeftJoin(table string, on Expression) *SelectQuery { 173 | return s.Join("LEFT JOIN", table, on) 174 | } 175 | 176 | // RightJoin specifies a RIGHT JOIN clause. 177 | // This is a shortcut method for Join. 178 | func (s *SelectQuery) RightJoin(table string, on Expression) *SelectQuery { 179 | return s.Join("RIGHT JOIN", table, on) 180 | } 181 | 182 | // OrderBy specifies the ORDER BY clause. 183 | // Column names will be properly quoted. A column name can contain "ASC" or "DESC" to indicate its ordering direction. 184 | func (s *SelectQuery) OrderBy(cols ...string) *SelectQuery { 185 | s.orderBy = cols 186 | return s 187 | } 188 | 189 | // AndOrderBy appends additional columns to the existing ORDER BY clause. 190 | // Column names will be properly quoted. A column name can contain "ASC" or "DESC" to indicate its ordering direction. 191 | func (s *SelectQuery) AndOrderBy(cols ...string) *SelectQuery { 192 | s.orderBy = append(s.orderBy, cols...) 
193 | return s 194 | } 195 | 196 | // GroupBy specifies the GROUP BY clause. 197 | // Column names will be properly quoted. 198 | func (s *SelectQuery) GroupBy(cols ...string) *SelectQuery { 199 | s.groupBy = cols 200 | return s 201 | } 202 | 203 | // AndGroupBy appends additional columns to the existing GROUP BY clause. 204 | // Column names will be properly quoted. 205 | func (s *SelectQuery) AndGroupBy(cols ...string) *SelectQuery { 206 | s.groupBy = append(s.groupBy, cols...) 207 | return s 208 | } 209 | 210 | // Having specifies the HAVING clause. 211 | func (s *SelectQuery) Having(e Expression) *SelectQuery { 212 | s.having = e 213 | return s 214 | } 215 | 216 | // AndHaving concatenates a new HAVING condition with the existing one (if any) using "AND". 217 | func (s *SelectQuery) AndHaving(e Expression) *SelectQuery { 218 | s.having = And(s.having, e) 219 | return s 220 | } 221 | 222 | // OrHaving concatenates a new HAVING condition with the existing one (if any) using "OR". 223 | func (s *SelectQuery) OrHaving(e Expression) *SelectQuery { 224 | s.having = Or(s.having, e) 225 | return s 226 | } 227 | 228 | // Union specifies a UNION clause. 229 | func (s *SelectQuery) Union(q *Query) *SelectQuery { 230 | s.union = append(s.union, UnionInfo{false, q}) 231 | return s 232 | } 233 | 234 | // UnionAll specifies a UNION ALL clause. 235 | func (s *SelectQuery) UnionAll(q *Query) *SelectQuery { 236 | s.union = append(s.union, UnionInfo{true, q}) 237 | return s 238 | } 239 | 240 | // Limit specifies the LIMIT clause. 241 | // A negative limit means no limit. 242 | func (s *SelectQuery) Limit(limit int64) *SelectQuery { 243 | s.limit = limit 244 | return s 245 | } 246 | 247 | // Offset specifies the OFFSET clause. 248 | // A negative offset means no offset. 249 | func (s *SelectQuery) Offset(offset int64) *SelectQuery { 250 | s.offset = offset 251 | return s 252 | } 253 | 254 | // Bind specifies the parameter values to be bound to the query. 
255 | func (s *SelectQuery) Bind(params Params) *SelectQuery { 256 | s.params = params 257 | return s 258 | } 259 | 260 | // AndBind appends additional parameters to be bound to the query. 261 | func (s *SelectQuery) AndBind(params Params) *SelectQuery { 262 | if len(s.params) == 0 { 263 | s.params = params 264 | } else { 265 | for k, v := range params { 266 | s.params[k] = v 267 | } 268 | } 269 | return s 270 | } 271 | 272 | // Build builds the SELECT query and returns an executable Query object. 273 | func (s *SelectQuery) Build() *Query { 274 | params := Params{} 275 | for k, v := range s.params { 276 | params[k] = v 277 | } 278 | 279 | qb := s.builder.QueryBuilder() 280 | 281 | clauses := []string{ 282 | s.preFragment, 283 | qb.BuildSelect(s.selects, s.distinct, s.selectOption), 284 | qb.BuildFrom(s.from), 285 | qb.BuildJoin(s.join, params), 286 | qb.BuildWhere(s.where, params), 287 | qb.BuildGroupBy(s.groupBy), 288 | qb.BuildHaving(s.having, params), 289 | } 290 | 291 | sql := "" 292 | for _, clause := range clauses { 293 | if clause != "" { 294 | if sql == "" { 295 | sql = clause 296 | } else { 297 | sql += " " + clause 298 | } 299 | } 300 | } 301 | 302 | sql = qb.BuildOrderByAndLimit(sql, s.orderBy, s.limit, s.offset) 303 | 304 | if s.postFragment != "" { 305 | sql += " " + s.postFragment 306 | } 307 | 308 | if union := qb.BuildUnion(s.union, params); union != "" { 309 | sql = fmt.Sprintf("(%v) %v", sql, union) 310 | } 311 | 312 | query := s.builder.NewQuery(sql).Bind(params).WithContext(s.ctx) 313 | 314 | if s.buildHook != nil { 315 | s.buildHook(query) 316 | } 317 | 318 | return query 319 | } 320 | 321 | // One executes the SELECT query and populates the first row of the result into the specified variable. 
322 | // 323 | // If the query does not specify a "from" clause, the method will try to infer the name of the table 324 | // to be selected from by calling getTableName() which will return either the variable type name 325 | // or the TableName() method if the variable implements the TableModel interface. 326 | // 327 | // Note that when the query has no rows in the result set, an sql.ErrNoRows will be returned. 328 | func (s *SelectQuery) One(a interface{}) error { 329 | if len(s.from) == 0 { 330 | if tableName := s.TableMapper(a); tableName != "" { 331 | s.from = []string{tableName} 332 | } 333 | } 334 | 335 | return s.Build().One(a) 336 | } 337 | 338 | // Model selects the row with the specified primary key and populates the model with the row data. 339 | // 340 | // The model variable should be a pointer to a struct. If the query does not specify a "from" clause, 341 | // it will use the model struct to determine which table to select data from. It will also use the model 342 | // to infer the name of the primary key column. Only simple primary key is supported. For composite primary keys, 343 | // please use Where() to specify the filtering condition. 344 | func (s *SelectQuery) Model(pk, model interface{}) error { 345 | t := reflect.TypeOf(model) 346 | if t.Kind() == reflect.Ptr { 347 | t = t.Elem() 348 | } 349 | if t.Kind() != reflect.Struct { 350 | return VarTypeError("must be a pointer to a struct") 351 | } 352 | 353 | si := getStructInfo(t, s.FieldMapper) 354 | if len(si.pkNames) == 1 { 355 | return s.AndWhere(HashExp{si.nameMap[si.pkNames[0]].dbName: pk}).One(model) 356 | } 357 | if len(si.pkNames) == 0 { 358 | return MissingPKError 359 | } 360 | 361 | return CompositePKError 362 | } 363 | 364 | // All executes the SELECT query and populates all rows of the result into a slice. 365 | // 366 | // Note that the slice must be passed in as a pointer. 
367 | // 368 | // If the query does not specify a "from" clause, the method will try to infer the name of the table 369 | // to be selected from by calling getTableName() which will return either the type name of the slice elements 370 | // or the TableName() method if the slice element implements the TableModel interface. 371 | func (s *SelectQuery) All(slice interface{}) error { 372 | if len(s.from) == 0 { 373 | if tableName := s.TableMapper(slice); tableName != "" { 374 | s.from = []string{tableName} 375 | } 376 | } 377 | 378 | return s.Build().All(slice) 379 | } 380 | 381 | // Rows builds and executes the SELECT query and returns a Rows object for data retrieval purpose. 382 | // This is a shortcut to SelectQuery.Build().Rows() 383 | func (s *SelectQuery) Rows() (*Rows, error) { 384 | return s.Build().Rows() 385 | } 386 | 387 | // Row builds and executes the SELECT query and populates the first row of the result into the specified variables. 388 | // This is a shortcut to SelectQuery.Build().Row() 389 | func (s *SelectQuery) Row(a ...interface{}) error { 390 | return s.Build().Row(a...) 391 | } 392 | 393 | // Column builds and executes the SELECT statement and populates the first column of the result into a slice. 394 | // Note that the parameter must be a pointer to a slice. 395 | // This is a shortcut to SelectQuery.Build().Column() 396 | func (s *SelectQuery) Column(a interface{}) error { 397 | return s.Build().Column(a) 398 | } 399 | 400 | // QueryInfo represents a debug/info struct with exported SelectQuery fields. 
// Info copies every SelectQuery field into a new QueryInfo, but the slices (Selects, From, Join,
// OrderBy, GroupBy, Union) and the Params map are shared with the query rather than deep-copied,
// so modifying their elements (or Params entries) through a QueryInfo also affects the query.
401 | type QueryInfo struct { 402 | PreFragment string 403 | PostFragment string 404 | Builder Builder 405 | Selects []string 406 | Distinct bool 407 | SelectOption string 408 | From []string 409 | Where Expression 410 | Join []JoinInfo 411 | OrderBy []string 412 | GroupBy []string 413 | Having Expression 414 | Union []UnionInfo 415 | Limit int64 416 | Offset int64 417 | Params Params 418 | Context context.Context 419 | BuildHook BuildHookFunc 420 | } 421 | 422 | // Info exports common SelectQuery fields allowing to inspect the 423 | // current select query options. 424 | func (s *SelectQuery) Info() *QueryInfo { 425 | return &QueryInfo{ 426 | Builder: s.builder, 427 | PreFragment: s.preFragment, 428 | PostFragment: s.postFragment, 429 | Selects: s.selects, 430 | Distinct: s.distinct, 431 | SelectOption: s.selectOption, 432 | From: s.from, 433 | Where: s.where, 434 | Join: s.join, 435 | OrderBy: s.orderBy, 436 | GroupBy: s.groupBy, 437 | Having: s.having, 438 | Union: s.union, 439 | Limit: s.limit, 440 | Offset: s.offset, 441 | Params: s.params, 442 | Context: s.ctx, 443 | BuildHook: s.buildHook, 444 | } 445 | } 446 | -------------------------------------------------------------------------------- /select_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package dbx 6 | 7 | import ( 8 | "testing" 9 | 10 | "database/sql" 11 | 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | func TestSelectQuery(t *testing.T) { 16 | db := getDB() 17 | 18 | // minimal select query 19 | q := db.Select().From("users").Build() 20 | expected := "SELECT * FROM `users`" 21 | assert.Equal(t, q.SQL(), expected, "t1") 22 | assert.Equal(t, len(q.Params()), 0, "t2") 23 | 24 | // a full select query 25 | q = db.Select("id", "name"). 26 | PreFragment("pre").
27 | PostFragment("post"). 28 | AndSelect("age"). 29 | Distinct(true). 30 | SelectOption("CALC"). 31 | From("users"). 32 | Where(NewExp("age>30")). 33 | AndWhere(NewExp("status=1")). 34 | OrWhere(NewExp("type=2")). 35 | InnerJoin("profile", NewExp("user.id=profile.id")). 36 | LeftJoin("team", nil). 37 | RightJoin("dept", nil). 38 | OrderBy("age DESC", "type"). 39 | AndOrderBy("id"). 40 | GroupBy("id"). 41 | AndGroupBy("age"). 42 | Having(NewExp("id>10")). 43 | AndHaving(NewExp("id<20")). 44 | OrHaving(NewExp("type=3")). 45 | Limit(10). 46 | Offset(20). 47 | Bind(Params{"id": 1}). 48 | AndBind(Params{"age": 30}). 49 | Build() 50 | 51 | expected = "pre SELECT DISTINCT CALC `id`, `name`, `age` FROM `users` INNER JOIN `profile` ON user.id=profile.id LEFT JOIN `team` RIGHT JOIN `dept` WHERE ((age>30) AND (status=1)) OR (type=2) GROUP BY `id`, `age` HAVING ((id>10) AND (id<20)) OR (type=3) ORDER BY `age` DESC, `type`, `id` LIMIT 10 OFFSET 20 post" 52 | assert.Equal(t, q.SQL(), expected, "t3") 53 | assert.Equal(t, len(q.Params()), 2, "t4") 54 | 55 | q3 := db.Select().AndBind(Params{"id": 1}).Build() 56 | assert.Equal(t, len(q3.Params()), 1) 57 | 58 | // union 59 | q1 := db.Select().From("users").PreFragment("pre_q1").Build() 60 | q2 := db.Select().From("posts").PostFragment("post_q2").Build() 61 | q = db.Select().From("profiles").Union(q1).UnionAll(q2).PreFragment("pre").PostFragment("post").Build() 62 | expected = "(pre SELECT * FROM `profiles` post) UNION (pre_q1 SELECT * FROM `users`) UNION ALL (SELECT * FROM `posts` post_q2)" 63 | assert.Equal(t, q.SQL(), expected, "t5") 64 | } 65 | 66 | func TestSelectQuery_Data(t *testing.T) { 67 | db := getPreparedDB() 68 | defer db.Close() 69 | 70 | q := db.Select("id", "email").From("customer").OrderBy("id") 71 | 72 | var customer Customer 73 | q.One(&customer) 74 | assert.Equal(t, customer.Email, "user1@example.com", "customer.Email") 75 | 76 | var customers []Customer 77 | q.All(&customers) 78 | assert.Equal(t, len(customers),
3, "len(customers)") 79 | 80 | rows, _ := q.Rows() 81 | customer.Email = "" 82 | rows.one(&customer) 83 | assert.Equal(t, customer.Email, "user1@example.com", "customer.Email") 84 | 85 | var id, email string 86 | q.Row(&id, &email) 87 | assert.Equal(t, id, "1", "id") 88 | assert.Equal(t, email, "user1@example.com", "email") 89 | 90 | var emails []string 91 | err := db.Select("email").From("customer").Column(&emails) 92 | if assert.Nil(t, err) { 93 | assert.Equal(t, 3, len(emails)) 94 | } 95 | 96 | var e int 97 | err = db.Select().From("customer").One(&e) 98 | assert.NotNil(t, err) 99 | err = db.Select().From("customer").All(&e) 100 | assert.NotNil(t, err) 101 | } 102 | 103 | func TestSelectQuery_Model(t *testing.T) { 104 | db := getPreparedDB() 105 | defer db.Close() 106 | 107 | { 108 | // One without specifying FROM 109 | var customer CustomerPtr 110 | err := db.Select().OrderBy("id").One(&customer) 111 | if assert.Nil(t, err) { 112 | assert.Equal(t, "user1@example.com", *customer.Email) 113 | } 114 | } 115 | 116 | { 117 | // All without specifying FROM 118 | var customers []CustomerPtr 119 | err := db.Select().OrderBy("id").All(&customers) 120 | if assert.Nil(t, err) { 121 | assert.Equal(t, 3, len(customers)) 122 | } 123 | } 124 | 125 | { 126 | // Model without specifying FROM 127 | var customer CustomerPtr 128 | err := db.Select().Model(2, &customer) 129 | if assert.Nil(t, err) { 130 | assert.Equal(t, "user2@example.com", *customer.Email) 131 | } 132 | } 133 | 134 | { 135 | // Model with WHERE 136 | var customer CustomerPtr 137 | err := db.Select().Where(HashExp{"id": 1}).Model(2, &customer) 138 | assert.Equal(t, sql.ErrNoRows, err) 139 | 140 | err = db.Select().Where(HashExp{"id": 2}).Model(2, &customer) 141 | assert.Nil(t, err) 142 | } 143 | 144 | { 145 | // errors 146 | var i int 147 | err := db.Select().Model(1, &i) 148 | assert.Equal(t, VarTypeError("must be a pointer to a struct"), err) 149 | 150 | var a struct { 151 | Name string 152 | } 153 | 154 | err =
db.Select().Model(1, &a) 155 | assert.Equal(t, MissingPKError, err) 156 | var b struct { 157 | ID1 string `db:"pk"` 158 | ID2 string `db:"pk"` 159 | } 160 | err = db.Select().Model(1, &b) 161 | assert.Equal(t, CompositePKError, err) 162 | } 163 | } 164 | 165 | func TestSelectWithBuildHook(t *testing.T) { 166 | db := getPreparedDB() 167 | defer db.Close() 168 | 169 | var buildSQL string 170 | 171 | db.Select("id"). 172 | From("user"). 173 | WithBuildHook(func(q *Query) { 174 | buildSQL = q.SQL() 175 | }). 176 | Build() 177 | 178 | assert.Equal(t, "SELECT `id` FROM `user`", buildSQL) 179 | } 180 | -------------------------------------------------------------------------------- /struct.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package dbx 6 | 7 | import ( 8 | "database/sql" 9 | "reflect" 10 | "regexp" 11 | "strings" 12 | "sync" 13 | ) 14 | 15 | type ( 16 | // FieldMapFunc converts a struct field name into a DB column name. 17 | FieldMapFunc func(string) string 18 | 19 | // TableMapFunc converts a sample struct into a DB table name.
20 | TableMapFunc func(a interface{}) string 21 | 22 | structInfo struct { 23 | nameMap map[string]*fieldInfo // mapping from struct field names to field infos 24 | dbNameMap map[string]*fieldInfo // mapping from db column names to field infos 25 | pkNames []string // struct field names representing PKs 26 | } 27 | 28 | structValue struct { 29 | *structInfo 30 | value reflect.Value // the struct value 31 | tableName string // the db table name for the struct 32 | } 33 | 34 | fieldInfo struct { 35 | name string // field name 36 | dbName string // db column name 37 | path []int // index path to the struct field reflection 38 | } 39 | 40 | structInfoMapKey struct { 41 | t reflect.Type 42 | m reflect.Value 43 | } 44 | ) 45 | 46 | var ( 47 | // DbTag is the name of the struct tag used to specify the column name for the associated struct field 48 | DbTag = "db" 49 | 50 | fieldRegex = regexp.MustCompile(`([^A-Z_])([A-Z])`) 51 | scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() 52 | postScannerType = reflect.TypeOf((*PostScanner)(nil)).Elem() 53 | structInfoMap = make(map[structInfoMapKey]*structInfo) 54 | muStructInfoMap sync.Mutex 55 | ) 56 | 57 | // PostScanner is an optional interface used by ScanStruct. 58 | type PostScanner interface { 59 | // PostScan executes right after the struct has been populated 60 | // with the DB values, allowing you to further normalize or validate 61 | // the loaded data. 62 | PostScan() error 63 | } 64 | 65 | // DefaultFieldMapFunc maps a field name to a DB column name. 66 | // The mapping rule set by this method is that words in a field name will be separated by underscores 67 | // and the name will be turned into lower case. For example, "FirstName" maps to "first_name", and "MyID" becomes "my_id". 68 | // See DB.FieldMapper for more details.
// Consecutive capital letters are not separated from each other, so "URLPath" maps to
// "urlpath" and "MyURLPath" to "my_urlpath" (see fieldRegex).
69 | func DefaultFieldMapFunc(f string) string { 70 | return strings.ToLower(fieldRegex.ReplaceAllString(f, "${1}_$2")) 71 | } 72 | 73 | func getStructInfo(a reflect.Type, mapper FieldMapFunc) *structInfo { 74 | muStructInfoMap.Lock() 75 | defer muStructInfoMap.Unlock() 76 | 77 | key := structInfoMapKey{a, reflect.ValueOf(mapper)} 78 | if si, ok := structInfoMap[key]; ok { 79 | return si 80 | } 81 | 82 | si := &structInfo{ 83 | nameMap: map[string]*fieldInfo{}, 84 | dbNameMap: map[string]*fieldInfo{}, 85 | } 86 | si.build(a, make([]int, 0), "", "", mapper) 87 | structInfoMap[key] = si 88 | 89 | return si 90 | } 91 | 92 | func newStructValue(model interface{}, fieldMapFunc FieldMapFunc, tableMapFunc TableMapFunc) *structValue { 93 | value := reflect.ValueOf(model) 94 | if value.Kind() != reflect.Ptr || value.Elem().Kind() != reflect.Struct || value.IsNil() { 95 | return nil 96 | } 97 | 98 | return &structValue{ 99 | structInfo: getStructInfo(reflect.TypeOf(model).Elem(), fieldMapFunc), 100 | value: value.Elem(), 101 | tableName: tableMapFunc(model), 102 | } 103 | } 104 | 105 | // pk returns the primary key values indexed by the corresponding primary key column names. 106 | func (s *structValue) pk() map[string]interface{} { 107 | if len(s.pkNames) == 0 { 108 | return nil 109 | } 110 | return s.columns(s.pkNames, nil) 111 | } 112 | 113 | // columns returns the struct field values indexed by their corresponding DB column names. 
114 | func (s *structValue) columns(include, exclude []string) map[string]interface{} { 115 | v := make(map[string]interface{}, len(s.nameMap)) 116 | if len(include) == 0 { 117 | for _, fi := range s.nameMap { 118 | v[fi.dbName] = fi.getValue(s.value) 119 | } 120 | } else { 121 | for _, attr := range include { 122 | if fi, ok := s.nameMap[attr]; ok { 123 | v[fi.dbName] = fi.getValue(s.value) 124 | } 125 | } 126 | } 127 | if len(exclude) > 0 { 128 | for _, name := range exclude { 129 | if fi, ok := s.nameMap[name]; ok { 130 | delete(v, fi.dbName) 131 | } 132 | } 133 | } 134 | return v 135 | } 136 | 137 | // getValue returns the field value for the given struct value. 138 | func (fi *fieldInfo) getValue(a reflect.Value) interface{} { 139 | for _, i := range fi.path { 140 | a = a.Field(i) 141 | if a.Kind() == reflect.Ptr { 142 | if a.IsNil() { 143 | return nil 144 | } 145 | a = a.Elem() 146 | } 147 | } 148 | return a.Interface() 149 | } 150 | 151 | // getField returns the reflection value of the field for the given struct value. 
152 | func (fi *fieldInfo) getField(a reflect.Value) reflect.Value { 153 | i := 0 154 | for ; i < len(fi.path)-1; i++ { 155 | a = indirect(a.Field(fi.path[i])) 156 | } 157 | return a.Field(fi.path[i]) 158 | } 159 | 160 | func (si *structInfo) build(a reflect.Type, path []int, namePrefix, dbNamePrefix string, mapper FieldMapFunc) { 161 | n := a.NumField() 162 | for i := 0; i < n; i++ { 163 | field := a.Field(i) 164 | tag := field.Tag.Get(DbTag) 165 | 166 | // only handle anonymous or exported fields 167 | if !field.Anonymous && field.PkgPath != "" || tag == "-" { 168 | continue 169 | } 170 | 171 | path2 := make([]int, len(path), len(path)+1) 172 | copy(path2, path) 173 | path2 = append(path2, i) 174 | 175 | ft := field.Type 176 | if ft.Kind() == reflect.Ptr { 177 | ft = ft.Elem() 178 | } 179 | 180 | name := field.Name 181 | dbName, isPK := parseTag(tag) 182 | if dbName == "" && !field.Anonymous { 183 | if mapper != nil { 184 | dbName = mapper(field.Name) 185 | } else { 186 | dbName = field.Name 187 | } 188 | } 189 | if field.Anonymous { 190 | name = "" 191 | } 192 | 193 | if isNestedStruct(ft) { 194 | // dive into non-scanner struct 195 | si.build(ft, path2, concat(namePrefix, name), concat(dbNamePrefix, dbName), mapper) 196 | } else if dbName != "" { 197 | // non-anonymous scanner or struct field 198 | fi := &fieldInfo{ 199 | name: concat(namePrefix, name), 200 | dbName: concat(dbNamePrefix, dbName), 201 | path: path2, 202 | } 203 | // a field in an anonymous struct may be shadowed 204 | if _, ok := si.nameMap[fi.name]; !ok || len(path2) < len(si.nameMap[fi.name].path) { 205 | si.nameMap[fi.name] = fi 206 | si.dbNameMap[fi.dbName] = fi 207 | if isPK { 208 | si.pkNames = append(si.pkNames, fi.name) 209 | } 210 | } 211 | } 212 | } 213 | if len(si.pkNames) == 0 { 214 | if _, ok := si.nameMap["ID"]; ok { 215 | si.pkNames = append(si.pkNames, "ID") 216 | } else if _, ok := si.nameMap["Id"]; ok { 217 | si.pkNames = append(si.pkNames, "Id") 218 | } 219 | } 220 | } 221 | 
222 | func isNestedStruct(t reflect.Type) bool { 223 | if t.PkgPath() == "time" && t.Name() == "Time" { 224 | return false 225 | } 226 | return t.Kind() == reflect.Struct && !reflect.PtrTo(t).Implements(scannerType) 227 | } 228 | 229 | func parseTag(tag string) (string, bool) { 230 | if tag == "pk" { 231 | return "", true 232 | } 233 | if strings.HasPrefix(tag, "pk,") { 234 | return tag[3:], true 235 | } 236 | return tag, false 237 | } 238 | 239 | func concat(s1, s2 string) string { 240 | if s1 == "" { 241 | return s2 242 | } else if s2 == "" { 243 | return s1 244 | } else { 245 | return s1 + "." + s2 246 | } 247 | } 248 | 249 | // indirect dereferences pointers and returns the actual value it points to. 250 | // If a pointer is nil, it will be initialized with a new value. 251 | func indirect(v reflect.Value) reflect.Value { 252 | for v.Kind() == reflect.Ptr { 253 | if v.IsNil() { 254 | v.Set(reflect.New(v.Type().Elem())) 255 | } 256 | v = v.Elem() 257 | } 258 | return v 259 | } 260 | 261 | // GetTableName implements the default way of determining the table name corresponding to the given model struct 262 | // or slice of structs. To get the actual table name for a model, you should use DB.TableMapFunc() instead. 263 | // Do not call this method in a model's TableName() method because it will cause infinite loop. 
//
// For a nil pointer whose type implements TableModel, TableName() is called on a newly allocated
// zero value instead of on the nil pointer. A slice, or a pointer to a slice, resolves to the table
// name of its element type. Any other type falls back to DefaultFieldMapFunc applied to the type's
// name after stripping one pointer level (an unnamed type such as *int yields "").
// NOTE(review): a nil interface value makes reflect.TypeOf return nil, so t.Kind() below panics for
// GetTableName(nil). This is also reached for a slice of interface{}, whose zero element is nil.
// Consider returning "" in that case.
264 | func GetTableName(a interface{}) string { 265 | if tm, ok := a.(TableModel); ok { 266 | v := reflect.ValueOf(a) 267 | if v.Kind() == reflect.Ptr && v.IsNil() { 268 | a = reflect.New(v.Type().Elem()).Interface() 269 | return a.(TableModel).TableName() 270 | } 271 | return tm.TableName() 272 | } 273 | t := reflect.TypeOf(a) 274 | if t.Kind() == reflect.Ptr { 275 | t = t.Elem() 276 | } 277 | if t.Kind() == reflect.Slice { 278 | return GetTableName(reflect.Zero(t.Elem()).Interface()) 279 | } 280 | return DefaultFieldMapFunc(t.Name()) 281 | } 282 | -------------------------------------------------------------------------------- /struct_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package dbx 6 | 7 | import ( 8 | "database/sql" 9 | "reflect" 10 | "testing" 11 | 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | func TestDefaultFieldMapFunc(t *testing.T) { 16 | tests := []struct { 17 | input, output string 18 | }{ 19 | {"Name", "name"}, 20 | {"FirstName", "first_name"}, 21 | {"Name0", "name0"}, 22 | {"ID", "id"}, 23 | {"UserID", "user_id"}, 24 | {"User0ID", "user0_id"}, 25 | {"MyURL", "my_url"}, 26 | {"URLPath", "urlpath"}, 27 | {"MyURLPath", "my_urlpath"}, 28 | {"First_Name", "first_name"}, 29 | {"first_name", "first_name"}, 30 | {"_FirstName", "_first_name"}, 31 | {"_First_Name", "_first_name"}, 32 | } 33 | for _, test := range tests { 34 | r := DefaultFieldMapFunc(test.input) 35 | assert.Equal(t, test.output, r, test.input) 36 | } 37 | } 38 | 39 | func Test_concat(t *testing.T) { 40 | assert.Equal(t, "a.b", concat("a", "b")) 41 | assert.Equal(t, "a", concat("a", "")) 42 | assert.Equal(t, "b", concat("", "b")) 43 | } 44 | 45 | func Test_parseTag(t *testing.T) { 46 | name, pk := parseTag("abc") 47 | assert.Equal(t, "abc", name) 48 |
assert.False(t, pk) 49 | 50 | name, pk = parseTag("pk,abc") 51 | assert.Equal(t, "abc", name) 52 | assert.True(t, pk) 53 | 54 | name, pk = parseTag("pk") 55 | assert.Equal(t, "", name) 56 | assert.True(t, pk) 57 | } 58 | 59 | func Test_indirect(t *testing.T) { 60 | var a int 61 | assert.Equal(t, reflect.ValueOf(a).Kind(), indirect(reflect.ValueOf(a)).Kind()) 62 | var b *int 63 | bi := indirect(reflect.ValueOf(&b)) 64 | assert.Equal(t, reflect.ValueOf(a).Kind(), bi.Kind()) 65 | if assert.NotNil(t, b) { 66 | assert.Equal(t, 0, *b) 67 | } 68 | } 69 | 70 | func Test_structValue_columns(t *testing.T) { 71 | customer := Customer{ 72 | ID: 1, 73 | Name: "abc", 74 | Status: 2, 75 | Email: "abc@example.com", 76 | } 77 | sv := newStructValue(&customer, DefaultFieldMapFunc, GetTableName) 78 | cols := sv.columns(nil, nil) 79 | assert.Equal(t, map[string]interface{}{"id": 1, "name": "abc", "status": 2, "email": "abc@example.com", "address": sql.NullString{}}, cols) 80 | 81 | cols = sv.columns([]string{"ID", "name"}, nil) 82 | assert.Equal(t, map[string]interface{}{"id": 1}, cols) 83 | 84 | cols = sv.columns([]string{"ID", "Name"}, []string{"ID"}) 85 | assert.Equal(t, map[string]interface{}{"name": "abc"}, cols) 86 | 87 | cols = sv.columns(nil, []string{"ID", "Address"}) 88 | assert.Equal(t, map[string]interface{}{"name": "abc", "status": 2, "email": "abc@example.com"}, cols) 89 | 90 | sv = newStructValue(&customer, nil, GetTableName) 91 | cols = sv.columns([]string{"ID", "Name"}, []string{"ID"}) 92 | assert.Equal(t, map[string]interface{}{"Name": "abc"}, cols) 93 | } 94 | 95 | func TestIssue37(t *testing.T) { 96 | customer := Customer{ 97 | ID: 1, 98 | Name: "abc", 99 | Status: 2, 100 | Email: "abc@example.com", 101 | } 102 | ev := struct { 103 | Customer 104 | Status string 105 | }{customer, "20"} 106 | sv := newStructValue(&ev, nil, GetTableName) 107 | cols := sv.columns([]string{"ID", "Status"}, nil) 108 | assert.Equal(t, map[string]interface{}{"ID": 1, "Status": "20"},
cols) 109 | 110 | ev2 := struct { 111 | Status string 112 | Customer 113 | }{"20", customer} 114 | sv = newStructValue(&ev2, nil, GetTableName) 115 | cols = sv.columns([]string{"ID", "Status"}, nil) 116 | assert.Equal(t, map[string]interface{}{"ID": 1, "Status": "20"}, cols) 117 | } 118 | 119 | type MyCustomer struct{} 120 | 121 | func TestGetTableName(t *testing.T) { 122 | var c1 Customer 123 | assert.Equal(t, "customer", GetTableName(c1)) 124 | 125 | var c2 *Customer 126 | assert.Equal(t, "customer", GetTableName(c2)) 127 | 128 | var c3 MyCustomer 129 | assert.Equal(t, "my_customer", GetTableName(c3)) 130 | 131 | var c4 []Customer 132 | assert.Equal(t, "customer", GetTableName(c4)) 133 | 134 | var c5 *[]Customer 135 | assert.Equal(t, "customer", GetTableName(c5)) 136 | 137 | var c6 []MyCustomer 138 | assert.Equal(t, "my_customer", GetTableName(c6)) 139 | 140 | var c7 []CustomerPtr 141 | assert.Equal(t, "customer", GetTableName(c7)) 142 | 143 | var c8 **int 144 | assert.Equal(t, "", GetTableName(c8)) 145 | } 146 | 147 | type FA struct { 148 | A1 string 149 | A2 int 150 | } 151 | 152 | type FB struct { 153 | B1 string 154 | } 155 | -------------------------------------------------------------------------------- /testdata/mysql.sql: -------------------------------------------------------------------------------- 1 | /** 2 | * This is the database schema for testing MySQL support of ozzo-dbx.
3 | * The following database setup is required in order to run the test: 4 | * - host: 127.0.0.1 5 | * - user: travis 6 | * - pass: 7 | * - database: pocketbase_dbx_test 8 | */ 9 | 10 | DROP TABLE IF EXISTS `order_item` CASCADE; 11 | DROP TABLE IF EXISTS `item` CASCADE; 12 | DROP TABLE IF EXISTS `order` CASCADE; 13 | DROP TABLE IF EXISTS `customer` CASCADE; 14 | DROP TABLE IF EXISTS `user` CASCADE; 15 | 16 | CREATE TABLE `customer` ( 17 | `id` int(11) NOT NULL AUTO_INCREMENT, 18 | `email` varchar(128) NOT NULL, 19 | `name` varchar(128), 20 | `address` text, 21 | `status` int (11) DEFAULT 0, 22 | PRIMARY KEY (`id`) 23 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8; 24 | 25 | CREATE TABLE `user` ( 26 | `id` int(11) NOT NULL AUTO_INCREMENT, 27 | `email` varchar(128) NOT NULL, 28 | `created` date, 29 | `updated` date, 30 | PRIMARY KEY (`id`) 31 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8; 32 | 33 | CREATE TABLE `item` ( 34 | `id` int(11) NOT NULL AUTO_INCREMENT, 35 | `name` varchar(128) NOT NULL, 36 | PRIMARY KEY (`id`) 37 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8; 38 | 39 | CREATE TABLE `order` ( 40 | `id` int(11) NOT NULL AUTO_INCREMENT, 41 | `customer_id` int(11) NOT NULL, 42 | `created_at` int(11) NOT NULL, 43 | `total` decimal(10,0) NOT NULL, 44 | PRIMARY KEY (`id`), 45 | CONSTRAINT `FK_order_customer_id` FOREIGN KEY (`customer_id`) REFERENCES `customer` (`id`) ON DELETE CASCADE 46 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8; 47 | 48 | CREATE TABLE `order_item` ( 49 | `order_id` int(11) NOT NULL, 50 | `item_id` int(11) NOT NULL, 51 | `quantity` int(11) NOT NULL, 52 | `subtotal` decimal(10,0) NOT NULL, 53 | PRIMARY KEY (`order_id`,`item_id`), 54 | KEY `FK_order_item_item_id` (`item_id`), 55 | CONSTRAINT `FK_order_item_order_id` FOREIGN KEY (`order_id`) REFERENCES `order` (`id`) ON DELETE CASCADE, 56 | CONSTRAINT `FK_order_item_item_id` FOREIGN KEY (`item_id`) REFERENCES `item` (`id`) ON DELETE CASCADE 57 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8; 58 | 59 | INSERT INTO `customer` (email,
name, address, status) VALUES ('user1@example.com', 'user1', 'address1', 1); 60 | INSERT INTO `customer` (email, name, address, status) VALUES ('user2@example.com', 'user2', NULL, 1); 61 | INSERT INTO `customer` (email, name, address, status) VALUES ('user3@example.com', 'user3', 'address3', 2); 62 | 63 | INSERT INTO `user` (email, created) VALUES ('user1@example.com', '2015-01-02'); 64 | INSERT INTO `user` (email, created) VALUES ('user2@example.com', now()); 65 | 66 | INSERT INTO `item` (name) VALUES ('The Go Programming Language'); 67 | INSERT INTO `item` (name) VALUES ('Go in Action'); 68 | INSERT INTO `item` (name) VALUES ('Go Programming Blueprints'); 69 | INSERT INTO `item` (name) VALUES ('Building Microservices'); 70 | INSERT INTO `item` (name) VALUES ('Go Web Programming'); 71 | 72 | INSERT INTO `order` (customer_id, created_at, total) VALUES (1, 1325282384, 110.0); 73 | INSERT INTO `order` (customer_id, created_at, total) VALUES (2, 1325334482, 33.0); 74 | INSERT INTO `order` (customer_id, created_at, total) VALUES (2, 1325502201, 40.0); 75 | 76 | INSERT INTO `order_item` (order_id, item_id, quantity, subtotal) VALUES (1, 1, 1, 30.0); 77 | INSERT INTO `order_item` (order_id, item_id, quantity, subtotal) VALUES (1, 2, 2, 40.0); 78 | INSERT INTO `order_item` (order_id, item_id, quantity, subtotal) VALUES (2, 4, 1, 10.0); 79 | INSERT INTO `order_item` (order_id, item_id, quantity, subtotal) VALUES (2, 5, 1, 15.0); 80 | INSERT INTO `order_item` (order_id, item_id, quantity, subtotal) VALUES (2, 3, 1, 8.0); 81 | INSERT INTO `order_item` (order_id, item_id, quantity, subtotal) VALUES (3, 2, 1, 40.0); 82 | -------------------------------------------------------------------------------- /tx.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file.
4 | 5 | package dbx 6 | 7 | import "database/sql" 8 | 9 | // Tx enhances sql.Tx with additional querying methods. 10 | type Tx struct { 11 | Builder 12 | tx *sql.Tx 13 | } 14 | 15 | // Commit commits the transaction. 16 | func (t *Tx) Commit() error { 17 | return t.tx.Commit() 18 | } 19 | 20 | // Rollback aborts the transaction. 21 | func (t *Tx) Rollback() error { 22 | return t.tx.Rollback() 23 | } 24 | --------------------------------------------------------------------------------