├── .gitignore ├── .travis.yml ├── CHANGELOG.md ├── LICENSE ├── README.md ├── buffer.go ├── buffer_test.go ├── builder.go ├── condition.go ├── condition_test.go ├── dbr.go ├── dbr_go18.go ├── dbr_go18_test.go ├── dbr_test.go ├── delete.go ├── delete_builder.go ├── delete_test.go ├── dialect.go ├── dialect ├── clickhouse.go ├── dialect.go ├── dialect_test.go ├── mysql.go ├── postgresql.go └── sqlite3.go ├── errors.go ├── event.go ├── expr.go ├── go.mod ├── go.sum ├── ident.go ├── insert.go ├── insert_builder.go ├── insert_test.go ├── interpolate.go ├── interpolate_test.go ├── join.go ├── load.go ├── load_bench_test.go ├── load_test.go ├── now.go ├── order.go ├── select.go ├── select_builder.go ├── select_builder_test.go ├── select_return.go ├── select_test.go ├── transaction.go ├── transaction_test.go ├── types.go ├── types_test.go ├── union.go ├── update.go ├── update_builder.go ├── update_test.go ├── util.go └── util_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/* 2 | /.devcontainer 3 | .DS_Store -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | notifications: 4 | email: false 5 | 6 | go: 7 | - "1.11" 8 | - "1.12" 9 | 10 | env: 11 | - GO111MODULE=on 12 | 13 | dist: trusty 14 | sudo: required 15 | 16 | addons: 17 | postgresql: "9.5" 18 | 19 | services: 20 | - mysql 21 | - postgresql 22 | - docker 23 | 24 | before_install: 25 | - travis_retry docker pull yandex/clickhouse-server 26 | - docker run -d -p 127.0.0.1:8123:8123 --name dbr-clickhouse-server yandex/clickhouse-server 27 | 28 | install: 29 | - travis_retry go get github.com/mattn/goveralls 30 | - travis_retry go get golang.org/x/lint/golint 31 | 32 | - psql -c 'create database dbr_ci_test;' -U postgres 33 | - mysql -e 'CREATE DATABASE dbr_ci_test;' -uroot 34 | 35 | before_script: 36 | - export 
DBR_TEST_MYSQL_DSN="root:@tcp(127.0.0.1:3306)/dbr_ci_test?charset=utf8" 37 | - export DBR_TEST_POSTGRES_DSN="postgres://postgres:@127.0.0.1:5432/dbr_ci_test?sslmode=disable" 38 | - export DBR_TEST_CLICKHOUSE_DSN="http://localhost:8123/default" 39 | 40 | script: 41 | - $HOME/gopath/bin/golint ./... 42 | - go vet ./... 43 | - test -z "$(gofmt -d -s . | tee /dev/stderr)" 44 | - go test -v -covermode=count -coverprofile=coverage.out 45 | - $HOME/gopath/bin/goveralls -coverprofile=coverage.out -service=travis-ci 46 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Change log 2 | 3 | Not all minor changes may be noted here, but all large and/or breaking changes 4 | should be. 5 | 6 | ## v2.0 - 2015-10-09 7 | 8 | ### Added 9 | - PostgreSQL support! 10 | - `Open(driver, dsn string, log EventReceiver)` creates an underlying connection for you based on a supplied driver and dsn string 11 | - All builders are now available without a `Session` facilitating much more complex queries 12 | - More common SQL support: Subqueries, Unions, Joins, Aliases 13 | - More complex condition building support: And/Or/Eq/Neq/Gt/Gte/Lt/Lte 14 | 15 | ### Deprecated 16 | - `NewConnection` is deprecated. It assumes MySQL driver. Please use `Open` instead 17 | 18 | ### Changed 19 | - `NullTime` no longer relies on the mysql package. E.g. instead of `NullTime{mysql.NullTime{Time: t, Valid: true}}` it's now simply `NullTime{Time: t, Valid: true}` 20 | - All `*Builder` structs now embed a corresponding `*Stmt` struct (E.g. `SelectBuilder` embeds `SelectStmt`). 
All non-`Session` specific properties have been moved to the `*Stmt` structs 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014 Jonathan Novak, Tyler Smith, Tai-Lin Chu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | dbr (fork of gocraft/dbr) provides additions to Go's database/sql for super fast performance and convenience. 
2 | 3 | [![Build Status](https://travis-ci.org/mailru/dbr.svg?branch=master)](https://travis-ci.org/mailru/dbr) 4 | [![Go Report Card](https://goreportcard.com/badge/github.com/mailru/dbr)](https://goreportcard.com/report/github.com/mailru/dbr) 5 | [![Coverage Status](https://coveralls.io/repos/github/mailru/dbr/badge.svg?branch=develop)](https://coveralls.io/github/mailru/dbr?branch=develop) 6 | 7 | ## Getting Started 8 | 9 | ```go 10 | // create a connection (e.g. "postgres", "mysql", or "sqlite3") 11 | conn, _ := dbr.Open("postgres", "...") 12 | 13 | // create a session for each business unit of execution (e.g. a web request or goworkers job) 14 | sess := conn.NewSession(nil) 15 | 16 | // get a record 17 | var suggestion Suggestion 18 | sess.Select("id", "title").From("suggestions").Where("id = ?", 1).Load(&suggestion) 19 | 20 | // JSON-ready, with dbr.Null* types serialized like you want 21 | json.Marshal(&suggestion) 22 | ``` 23 | 24 | ## Feature highlights 25 | 26 | ### Use a Sweet Query Builder or use Plain SQL 27 | 28 | mailru/dbr supports both. 29 | 30 | Sweet Query Builder: 31 | ```go 32 | stmt := dbr.Select("title", "body"). 33 | From("suggestions"). 34 | OrderBy("id"). 35 | Limit(10) 36 | ``` 37 | 38 | Plain SQL: 39 | 40 | ```go 41 | builder := dbr.SelectBySql("SELECT `title`, `body` FROM `suggestions` ORDER BY `id` ASC LIMIT 10") 42 | ``` 43 | 44 | ### Amazing instrumentation with session 45 | 46 | All queries in mailru/dbr are made in the context of a session. This is because when instrumenting your app, it's important to understand which business action the query took place in. 47 | 48 | Writing instrumented code is a first-class concern for mailru/dbr. We instrument each query to emit to an EventReceiver interface. 49 | 50 | ### Faster performance than using database/sql directly 51 | Every time you call database/sql's db.Query("SELECT ...") method, under the hood, the mysql driver will create a prepared statement, execute it, and then throw it away. 
This has a big performance cost. 52 | 53 | mailru/dbr doesn't use prepared statements. We ported mysql's query escape functionality directly into our package, which means we interpolate all of those question marks with their arguments before they get to MySQL. The result of this is that it's way faster, and just as secure. 54 | 55 | Check out these [benchmarks](https://github.com/tyler-smith/golang-sql-benchmark). 56 | 57 | ### IN queries that aren't horrible 58 | Traditionally, database/sql uses prepared statements, which means each argument in an IN clause needs its own question mark. mailru/dbr, on the other hand, handles interpolation itself so that you can easily use a single question mark paired with a dynamically sized slice. 59 | ```go 60 | ids := []int64{1, 2, 3, 4, 5} 61 | builder.Where("id IN ?", ids) // `id` IN ? 62 | ``` 63 | A map can be used for IN queries as well. 64 | Note: interpolating a map is slower than interpolating a slice, so prefer a slice when possible. 65 | ```go 66 | ids := map[int64]string{1: "one", 2: "two"} 67 | builder.Where("id IN ?", ids) // `id` IN ? 68 | ``` 69 | 70 | ### JSON Friendly 71 | Ever try to JSON-encode a sql.NullString? You get: 72 | ```json 73 | { 74 | "str1": { 75 | "Valid": true, 76 | "String": "Hi!" 77 | }, 78 | "str2": { 79 | "Valid": false, 80 | "String": "" 81 | } 82 | } 83 | ``` 84 | 85 | Not quite what you want. mailru/dbr has dbr.NullString (and the rest of the Null* types) that encode correctly, giving you: 86 | 87 | ```json 88 | { 89 | "str1": "Hi!", 90 | "str2": null 91 | } 92 | ``` 93 | 94 | ### Inserting multiple records 95 | 96 | ```go 97 | sess.InsertInto("suggestions").Columns("title", "body"). 98 | Record(suggestion1). 
99 | Record(suggestion2) 100 | ``` 101 | 102 | ### Updating records on conflict 103 | 104 | ```go 105 | stmt := sess.InsertInto("suggestions").Columns("title", "body").Record(suggestion1) 106 | stmt.OnConflict("suggestions_pkey").Action("body", dbr.Proposed("body")) 107 | ``` 108 | 109 | 110 | ### Updating records 111 | 112 | ```go 113 | sess.Update("suggestions"). 114 | Set("title", "Gopher"). 115 | Set("body", "I love go."). 116 | Where("id = ?", 1) 117 | ``` 118 | 119 | ### Transactions 120 | 121 | ```go 122 | tx, err := sess.Begin() 123 | if err != nil { 124 | return err 125 | } 126 | defer tx.RollbackUnlessCommitted() 127 | 128 | // do stuff... 129 | 130 | return tx.Commit() 131 | ``` 132 | 133 | ### Load database values to variables 134 | 135 | Querying is the heart of mailru/dbr. 136 | 137 | * Load(&any): load everything! 138 | * LoadStruct(&oneStruct): load struct 139 | * LoadStructs(&manyStructs): load a slice of structs 140 | * LoadValue(&oneValue): load basic type 141 | * LoadValues(&manyValues): load a slice of basic types 142 | 143 | ```go 144 | // columns are mapped by tag then by field 145 | type Suggestion struct { 146 | ID int64 // id, will be autoloaded by last insert id 147 | Title string // title 148 | Url string `db:"-"` // ignored 149 | secret string // ignored 150 | Body dbr.NullString `db:"content"` // content 151 | User User 152 | } 153 | 154 | // By default dbr converts CamelCase property names to snake_case column_names 155 | // You can override this with struct tags, just like with JSON tags 156 | // This is especially helpful while migrating from legacy systems 157 | type Suggestion struct { 158 | Id int64 159 | Title dbr.NullString `db:"subject"` // subjects are called titles now 160 | CreatedAt dbr.NullTime 161 | } 162 | 163 | var suggestions []Suggestion 164 | sess.Select("*").From("suggestions").Load(&suggestions) 165 | ``` 166 | 167 | ### Join multiple tables 168 | 169 | dbr supports many join types: 170 | 171 | ```go 172 | 
sess.Select("*").From("suggestions"). 173 | Join("subdomains", "suggestions.subdomain_id = subdomains.id") 174 | 175 | sess.Select("*").From("suggestions"). 176 | LeftJoin("subdomains", "suggestions.subdomain_id = subdomains.id") 177 | 178 | sess.Select("*").From("suggestions"). 179 | RightJoin("subdomains", "suggestions.subdomain_id = subdomains.id") 180 | 181 | sess.Select("*").From("suggestions"). 182 | FullJoin("subdomains", "suggestions.subdomain_id = subdomains.id") 183 | ``` 184 | 185 | You can join on multiple tables: 186 | 187 | ```go 188 | sess.Select("*").From("suggestions"). 189 | Join("subdomains", "suggestions.subdomain_id = subdomains.id"). 190 | Join("accounts", "subdomains.accounts_id = accounts.id") 191 | ``` 192 | 193 | ### Quoting/escaping identifiers (e.g. table and column names) 194 | 195 | ```go 196 | dbr.I("suggestions.id") // `suggestions`.`id` 197 | ``` 198 | 199 | ### Subquery 200 | 201 | ```go 202 | sess.Select("count(id)").From( 203 | dbr.Select("*").From("suggestions").As("count"), 204 | ) 205 | ``` 206 | 207 | ### Union 208 | 209 | ```go 210 | dbr.Union( 211 | dbr.Select("*"), 212 | dbr.Select("*"), 213 | ) 214 | 215 | dbr.UnionAll( 216 | dbr.Select("*"), 217 | dbr.Select("*"), 218 | ) 219 | ``` 220 | 221 | Union can be used in subquery. 222 | 223 | ### Alias/AS 224 | 225 | * SelectStmt 226 | 227 | ```go 228 | dbr.Select("*").From("suggestions").As("count") 229 | ``` 230 | 231 | * Identity 232 | 233 | ```go 234 | dbr.I("suggestions").As("s") 235 | ``` 236 | 237 | * Union 238 | 239 | ```go 240 | dbr.Union( 241 | dbr.Select("*"), 242 | dbr.Select("*"), 243 | ).As("u1") 244 | 245 | dbr.UnionAll( 246 | dbr.Select("*"), 247 | dbr.Select("*"), 248 | ).As("u2") 249 | ``` 250 | 251 | ### Building arbitrary condition 252 | 253 | One common reason to use this is to prevent string concatenation in a loop. 
254 | 255 | * And 256 | * Or 257 | * Eq 258 | * Neq 259 | * Gt 260 | * Gte 261 | * Lt 262 | * Lte 263 | 264 | ```go 265 | dbr.And( 266 | dbr.Or( 267 | dbr.Gt("created_at", "2015-09-10"), 268 | dbr.Lte("created_at", "2015-09-11"), 269 | ), 270 | dbr.Eq("title", "hello world"), 271 | ) 272 | ``` 273 | 274 | ### Built with extensibility 275 | 276 | The core of dbr is interpolation, which can expand `?` with arbitrary SQL. If you need a feature that is not currently supported, 277 | you can build it on your own (or use `dbr.Expr`). 278 | 279 | To do that, the value that you wish to be expanded with `?` needs to implement `dbr.Builder`. 280 | 281 | ```go 282 | type Builder interface { 283 | Build(Dialect, Buffer) error 284 | } 285 | ``` 286 | 287 | ## Driver support 288 | 289 | * MySQL 290 | * PostgreSQL 291 | * SQLite3 292 | * ClickHouse 293 | 294 | These packages were developed by the [engineering team](https://eng.uservoice.com) at [UserVoice](https://www.uservoice.com) and currently power much of its infrastructure and tech stack. 295 | 296 | ## Thanks & Authors 297 | Inspiration from these excellent libraries: 298 | * [sqlx](https://github.com/jmoiron/sqlx) - various useful tools and utils for interacting with database/sql. 299 | * [Squirrel](https://github.com/lann/squirrel) - simple fluent query builder. 
300 | 301 | Authors: 302 | * Jonathan Novak -- [https://github.com/cypriss](https://github.com/cypriss) 303 | * Tyler Smith -- [https://github.com/tyler-smith](https://github.com/tyler-smith) 304 | * Tai-Lin Chu -- [https://github.com/taylorchu](https://github.com/taylorchu) 305 | * Sponsored by [UserVoice](https://eng.uservoice.com) 306 | 307 | Contributors: 308 | * Paul Bergeron -- [https://github.com/dinedal](https://github.com/dinedal) - SQLite dialect 309 | * Bulat Gaifullin -- [https://github.com/bgaifullin](https://github.com/bgaifullin) - ClickHouse dialect 310 | -------------------------------------------------------------------------------- /buffer.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | import "bytes" 4 | 5 | // Buffer is an interface used by Builder to store intermediate results 6 | type Buffer interface { 7 | WriteString(s string) (n int, err error) 8 | String() string 9 | 10 | WriteValue(v ...interface{}) (err error) 11 | Value() []interface{} 12 | } 13 | 14 | // NewBuffer creates buffer 15 | func NewBuffer() Buffer { 16 | return &buffer{} 17 | } 18 | 19 | type buffer struct { 20 | bytes.Buffer 21 | v []interface{} 22 | } 23 | 24 | func (b *buffer) WriteValue(v ...interface{}) error { 25 | b.v = append(b.v, v...) 26 | return nil 27 | } 28 | 29 | func (b *buffer) Value() []interface{} { 30 | return b.v 31 | } 32 | -------------------------------------------------------------------------------- /buffer_test.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | -------------------------------------------------------------------------------- /builder.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | // Builder builds sql in one dialect like MySQL/PostgreSQL 4 | // e.g. 
XxxBuilder 5 | type Builder interface { 6 | Build(Dialect, Buffer) error 7 | } 8 | 9 | // BuildFunc is an adapter to allow the use of ordinary functions as Builder 10 | type BuildFunc func(Dialect, Buffer) error 11 | 12 | // Build implements Builder interface 13 | func (b BuildFunc) Build(d Dialect, buf Buffer) error { 14 | return b(d, buf) 15 | } 16 | -------------------------------------------------------------------------------- /condition.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | import "reflect" 4 | 5 | func buildCond(d Dialect, buf Buffer, pred string, cond ...Builder) error { 6 | for i, c := range cond { 7 | if i > 0 { 8 | buf.WriteString(" ") 9 | buf.WriteString(pred) 10 | buf.WriteString(" ") 11 | } 12 | buf.WriteString("(") 13 | err := c.Build(d, buf) 14 | if err != nil { 15 | return err 16 | } 17 | buf.WriteString(")") 18 | } 19 | return nil 20 | } 21 | 22 | // And creates AND from a list of conditions 23 | func And(cond ...Builder) Builder { 24 | return BuildFunc(func(d Dialect, buf Buffer) error { 25 | return buildCond(d, buf, "AND", cond...) 26 | }) 27 | } 28 | 29 | // Or creates OR from a list of conditions 30 | func Or(cond ...Builder) Builder { 31 | return BuildFunc(func(d Dialect, buf Buffer) error { 32 | return buildCond(d, buf, "OR", cond...) 33 | }) 34 | } 35 | 36 | func buildCmp(d Dialect, buf Buffer, pred, column string, value interface{}) error { 37 | buf.WriteString(d.QuoteIdent(column)) 38 | buf.WriteString(" ") 39 | buf.WriteString(pred) 40 | buf.WriteString(" ") 41 | buf.WriteString(placeholder) 42 | 43 | buf.WriteValue(value) 44 | return nil 45 | } 46 | 47 | // Eq is `=`. 48 | // When value is nil, it will be translated to `IS NULL`. 49 | // When value is a slice, it will be translated to `IN`. 50 | // Otherwise it will be translated to `=`. 
51 | func Eq(column string, value interface{}) Builder { 52 | return BuildFunc(func(d Dialect, buf Buffer) error { 53 | if value == nil { 54 | buf.WriteString(d.QuoteIdent(column)) 55 | buf.WriteString(" IS NULL") 56 | return nil 57 | } 58 | v := reflect.ValueOf(value) 59 | if v.Kind() == reflect.Slice || v.Kind() == reflect.Map { 60 | if v.Len() == 0 { 61 | buf.WriteString(d.EncodeBool(false)) 62 | return nil 63 | } 64 | return buildCmp(d, buf, "IN", column, value) 65 | } 66 | return buildCmp(d, buf, "=", column, value) 67 | }) 68 | } 69 | 70 | // Neq is `!=`. 71 | // When value is nil, it will be translated to `IS NOT NULL`. 72 | // When value is a slice, it will be translated to `NOT IN`. 73 | // Otherwise it will be translated to `!=`. 74 | func Neq(column string, value interface{}) Builder { 75 | return BuildFunc(func(d Dialect, buf Buffer) error { 76 | if value == nil { 77 | buf.WriteString(d.QuoteIdent(column)) 78 | buf.WriteString(" IS NOT NULL") 79 | return nil 80 | } 81 | v := reflect.ValueOf(value) 82 | if v.Kind() == reflect.Slice || v.Kind() == reflect.Map { 83 | if v.Len() == 0 { 84 | buf.WriteString(d.EncodeBool(true)) 85 | return nil 86 | } 87 | return buildCmp(d, buf, "NOT IN", column, value) 88 | } 89 | return buildCmp(d, buf, "!=", column, value) 90 | }) 91 | } 92 | 93 | // Gt is `>`. 94 | func Gt(column string, value interface{}) Builder { 95 | return BuildFunc(func(d Dialect, buf Buffer) error { 96 | return buildCmp(d, buf, ">", column, value) 97 | }) 98 | } 99 | 100 | // Gte is '>='. 101 | func Gte(column string, value interface{}) Builder { 102 | return BuildFunc(func(d Dialect, buf Buffer) error { 103 | return buildCmp(d, buf, ">=", column, value) 104 | }) 105 | } 106 | 107 | // Lt is '<'. 108 | func Lt(column string, value interface{}) Builder { 109 | return BuildFunc(func(d Dialect, buf Buffer) error { 110 | return buildCmp(d, buf, "<", column, value) 111 | }) 112 | } 113 | 114 | // Lte is `<=`. 
115 | func Lte(column string, value interface{}) Builder { 116 | return BuildFunc(func(d Dialect, buf Buffer) error { 117 | return buildCmp(d, buf, "<=", column, value) 118 | }) 119 | } 120 | -------------------------------------------------------------------------------- /condition_test.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/mailru/dbr/dialect" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestCondition(t *testing.T) { 11 | for _, test := range []struct { 12 | cond Builder 13 | query string 14 | value []interface{} 15 | }{ 16 | { 17 | cond: Eq("col", 1), 18 | query: "`col` = ?", 19 | value: []interface{}{1}, 20 | }, 21 | { 22 | cond: Eq("col", nil), 23 | query: "`col` IS NULL", 24 | value: nil, 25 | }, 26 | { 27 | cond: Eq("col", []int{}), 28 | query: "0", 29 | value: nil, 30 | }, 31 | { 32 | cond: Eq("col", map[int]int{}), 33 | query: "0", 34 | value: nil, 35 | }, 36 | { 37 | cond: Eq("col", []int{1}), 38 | query: "`col` IN ?", 39 | value: []interface{}{[]int{1}}, 40 | }, 41 | { 42 | cond: Eq("col", map[int]int{1: 2}), 43 | query: "`col` IN ?", 44 | value: []interface{}{map[int]int{1: 2}}, 45 | }, 46 | { 47 | cond: Neq("col", 1), 48 | query: "`col` != ?", 49 | value: []interface{}{1}, 50 | }, 51 | { 52 | cond: Neq("col", nil), 53 | query: "`col` IS NOT NULL", 54 | value: nil, 55 | }, 56 | { 57 | cond: Neq("col", []int{}), 58 | query: "1", 59 | value: nil, 60 | }, 61 | { 62 | cond: Neq("col", []int{1}), 63 | query: "`col` NOT IN ?", 64 | value: []interface{}{[]int{1}}, 65 | }, 66 | { 67 | cond: Neq("col", map[int]int{1: 2}), 68 | query: "`col` NOT IN ?", 69 | value: []interface{}{map[int]int{1: 2}}, 70 | }, 71 | { 72 | cond: Gt("col", 1), 73 | query: "`col` > ?", 74 | value: []interface{}{1}, 75 | }, 76 | { 77 | cond: Gte("col", 1), 78 | query: "`col` >= ?", 79 | value: []interface{}{1}, 80 | }, 81 | { 82 | cond: Lt("col", 1), 83 | query: 
"`col` < ?", 84 | value: []interface{}{1}, 85 | }, 86 | { 87 | cond: Lte("col", 1), 88 | query: "`col` <= ?", 89 | value: []interface{}{1}, 90 | }, 91 | { 92 | cond: And(Lt("a", 1), Or(Gt("b", 2), Neq("c", 3))), 93 | query: "(`a` < ?) AND ((`b` > ?) OR (`c` != ?))", 94 | value: []interface{}{1, 2, 3}, 95 | }, 96 | } { 97 | buf := NewBuffer() 98 | err := test.cond.Build(dialect.MySQL, buf) 99 | assert.NoError(t, err) 100 | assert.Equal(t, test.query, buf.String()) 101 | assert.Equal(t, test.value, buf.Value()) 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /dbr.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "time" 8 | 9 | "github.com/mailru/dbr/dialect" 10 | ) 11 | 12 | // Open instantiates a Connection for a given database/sql connection 13 | // and event receiver 14 | func Open(driver, dsn string, log EventReceiver) (*Connection, error) { 15 | if log == nil { 16 | log = nullReceiver 17 | } 18 | conn, err := sql.Open(driver, dsn) 19 | if err != nil { 20 | return nil, err 21 | } 22 | var d Dialect 23 | switch driver { 24 | case "mysql": 25 | d = dialect.MySQL 26 | case "postgres": 27 | d = dialect.PostgreSQL 28 | case "sqlite3": 29 | d = dialect.SQLite3 30 | case "clickhouse", "chhttp": 31 | d = dialect.ClickHouse 32 | default: 33 | return nil, ErrNotSupported 34 | } 35 | return &Connection{DBConn: conn, EventReceiver: log, Dialect: d}, nil 36 | } 37 | 38 | const ( 39 | placeholder = "?" 
40 | ) 41 | 42 | // Connection is a connection to the database with an EventReceiver 43 | // to send events, errors, and timings to 44 | type Connection struct { 45 | DBConn 46 | Dialect Dialect 47 | EventReceiver 48 | } 49 | 50 | // Session represents a business unit of execution for some connection 51 | type Session struct { 52 | *Connection 53 | EventReceiver 54 | ctx context.Context 55 | } 56 | 57 | // NewSession instantiates a Session for the Connection 58 | func (conn *Connection) NewSession(log EventReceiver) *Session { 59 | return conn.NewSessionContext(context.Background(), log) 60 | } 61 | 62 | // NewSessionContext instantiates a Session with context for the Connection 63 | func (conn *Connection) NewSessionContext(ctx context.Context, log EventReceiver) *Session { 64 | if log == nil { 65 | log = conn.EventReceiver // Use parent instrumentation 66 | } 67 | return &Session{Connection: conn, EventReceiver: log, ctx: ctx} 68 | } 69 | 70 | // NewSession forks current session 71 | func (sess *Session) NewSession(log EventReceiver) *Session { 72 | if log == nil { 73 | log = sess.EventReceiver 74 | } 75 | return &Session{Connection: sess.Connection, EventReceiver: log, ctx: sess.ctx} 76 | } 77 | 78 | // beginTx starts a transaction with context. 79 | func (conn *Connection) beginTx() (*sql.Tx, error) { 80 | return conn.Begin() 81 | } 82 | 83 | // SessionRunner can do anything that a Session can except start a transaction. 
84 | type SessionRunner interface { 85 | Select(column ...string) SelectBuilder 86 | SelectBySql(query string, value ...interface{}) SelectBuilder 87 | 88 | InsertInto(table string) InsertBuilder 89 | InsertBySql(query string, value ...interface{}) InsertBuilder 90 | 91 | Update(table string) UpdateBuilder 92 | UpdateBySql(query string, value ...interface{}) UpdateBuilder 93 | 94 | DeleteFrom(table string) DeleteBuilder 95 | DeleteBySql(query string, value ...interface{}) DeleteBuilder 96 | } 97 | 98 | // DBConn interface for sql.DB 99 | type DBConn interface { 100 | runner 101 | 102 | BeginTx(context.Context, *sql.TxOptions) (*sql.Tx, error) 103 | Begin() (*sql.Tx, error) 104 | 105 | PingContext(ctx context.Context) error 106 | Ping() error 107 | 108 | Stats() sql.DBStats 109 | Close() error 110 | } 111 | 112 | type runner interface { 113 | Exec(query string, args ...interface{}) (sql.Result, error) 114 | ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) 115 | 116 | Query(query string, args ...interface{}) (*sql.Rows, error) 117 | QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) 118 | 119 | QueryRow(query string, args ...interface{}) *sql.Row 120 | QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row 121 | } 122 | 123 | // Executer can execute requests to database 124 | type Executer interface { 125 | Exec() (sql.Result, error) 126 | ExecContext(ctx context.Context) (sql.Result, error) 127 | } 128 | 129 | type loader interface { 130 | Load(value interface{}) (int, error) 131 | LoadStruct(value interface{}) error 132 | LoadStructs(value interface{}) (int, error) 133 | LoadValue(value interface{}) error 134 | LoadValues(value interface{}) (int, error) 135 | LoadContext(ctx context.Context, value interface{}) (int, error) 136 | LoadStructContext(ctx context.Context, value interface{}) error 137 | LoadStructsContext(ctx context.Context, value interface{}) (int, 
error) 138 | LoadValueContext(ctx context.Context, value interface{}) error 139 | LoadValuesContext(ctx context.Context, value interface{}) (int, error) 140 | } 141 | // exec interpolates builder for dialect d, runs the resulting statement on runner under ctx, and reports timings, trace spans and errors to log. 142 | func exec(ctx context.Context, runner runner, log EventReceiver, builder Builder, d Dialect) (sql.Result, error) { 143 | i := interpolator{ 144 | Buffer: NewBuffer(), 145 | Dialect: d, 146 | IgnoreBinary: true, 147 | } 148 | err := i.interpolate(placeholder, []interface{}{builder}) 149 | query, value := i.String(), i.Value() 150 | if err != nil { 151 | return nil, log.EventErrKv("dbr.exec.interpolate", err, kvs{ 152 | "sql": query, 153 | "args": fmt.Sprint(value), 154 | }) 155 | } 156 | 157 | startTime := time.Now() 158 | defer func() { 159 | log.TimingKv("dbr.exec", time.Since(startTime).Nanoseconds(), kvs{ 160 | "sql": query, 161 | }) 162 | }() 163 | 164 | traceImpl, hasTracingImpl := log.(TracingEventReceiver) 165 | if hasTracingImpl { 166 | ctx = traceImpl.SpanStart(ctx, "dbr.exec", query) 167 | defer traceImpl.SpanFinish(ctx) 168 | } 169 | // Execute under ctx (which carries the trace span when tracing is on), matching queryRows' use of QueryContext. 170 | result, err := runner.ExecContext(ctx, query, value...) 
171 | if err != nil { 172 | if hasTracingImpl { 173 | traceImpl.SpanError(ctx, err) 174 | } 175 | 176 | return result, log.EventErrKv("dbr.exec.exec", err, kvs{ 177 | "sql": query, 178 | }) 179 | } 180 | return result, nil 181 | } 182 | 183 | func queryRows(ctx context.Context, runner runner, log EventReceiver, builder Builder, d Dialect) (*sql.Rows, string, error) { 184 | i := interpolator{ 185 | Buffer: NewBuffer(), 186 | Dialect: d, 187 | IgnoreBinary: true, 188 | } 189 | err := i.interpolate(placeholder, []interface{}{builder}) 190 | query, value := i.String(), i.Value() 191 | if err != nil { 192 | return nil, "", log.EventErrKv("dbr.select.interpolate", err, kvs{ 193 | "sql": query, 194 | "args": fmt.Sprint(value), 195 | }) 196 | } 197 | 198 | startTime := time.Now() 199 | defer func() { 200 | log.TimingKv("dbr.select", time.Since(startTime).Nanoseconds(), kvs{ 201 | "sql": query, 202 | }) 203 | }() 204 | 205 | traceImpl, hasTracingImpl := log.(TracingEventReceiver) 206 | if hasTracingImpl { 207 | ctx = traceImpl.SpanStart(ctx, "dbr.select", query) 208 | defer traceImpl.SpanFinish(ctx) 209 | } 210 | 211 | rows, err := runner.QueryContext(ctx, query, value...) 
212 | if err != nil { 213 | if hasTracingImpl { 214 | traceImpl.SpanError(ctx, err) 215 | } 216 | 217 | return nil, query, log.EventErrKv("dbr.select.load.query", err, kvs{ 218 | "sql": query, 219 | }) 220 | } 221 | 222 | return rows, query, nil 223 | } 224 | 225 | func query(ctx context.Context, runner runner, log EventReceiver, builder Builder, d Dialect, dest interface{}) (int, error) { 226 | rows, query, err := queryRows(ctx, runner, log, builder, d) 227 | if err != nil { 228 | return 0, err 229 | } 230 | 231 | count, err := Load(rows, dest) 232 | if err != nil { 233 | return 0, log.EventErrKv("dbr.select.load.scan", err, kvs{ 234 | "sql": query, 235 | }) 236 | } 237 | 238 | return count, nil 239 | } 240 | -------------------------------------------------------------------------------- /dbr_go18.go: -------------------------------------------------------------------------------- 1 | // +build go1.8 2 | 3 | package dbr 4 | 5 | import ( 6 | "database/sql" 7 | ) 8 | 9 | // Exec executes a query without returning any rows. 10 | // The args are for any placeholder parameters in the query. 11 | func (sess *Session) Exec(query string, args ...interface{}) (sql.Result, error) { 12 | return sess.ExecContext(sess.ctx, query, args...) 13 | } 14 | 15 | // Query executes a query that returns rows, typically a SELECT. 16 | // The args are for any placeholder parameters in the query. 17 | func (sess *Session) Query(query string, args ...interface{}) (*sql.Rows, error) { 18 | return sess.QueryContext(sess.ctx, query, args...) 19 | } 20 | 21 | // Exec executes a query without returning any rows. 22 | // The args are for any placeholder parameters in the query. 23 | func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) { 24 | return tx.ExecContext(tx.ctx, query, args...) 25 | } 26 | 27 | // Query executes a query that returns rows, typically a SELECT. 28 | // The args are for any placeholder parameters in the query. 
29 | func (tx *Tx) Query(query string, args ...interface{}) (*sql.Rows, error) { 30 | return tx.QueryContext(tx.ctx, query, args...) 31 | } 32 | 33 | // beginTx starts a transaction with context. 34 | func (sess *Session) beginTx(opts *sql.TxOptions) (*sql.Tx, error) { 35 | return sess.BeginTx(sess.ctx, opts) 36 | } 37 | -------------------------------------------------------------------------------- /dbr_go18_test.go: -------------------------------------------------------------------------------- 1 | // +build go1.8 2 | 3 | package dbr 4 | 5 | import ( 6 | "context" 7 | "database/sql" 8 | "testing" 9 | 10 | "github.com/mailru/dbr/dialect" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestContextCancel(t *testing.T) { 15 | // exercise context cancellation on every test session except SQLite3; ClickHouse is excluded from the transaction checks 16 | for _, sess := range testSession { 17 | if sess.Dialect == dialect.SQLite3 { 18 | continue 19 | } 20 | checkSessionContext(t, sess.Connection) 21 | if sess.Dialect != dialect.ClickHouse { 22 | checkTxQueryContext(t, sess.Connection) 23 | checkTxExecContext(t, sess.Connection) 24 | } 25 | } 26 | } 27 | 28 | func checkSessionContext(t *testing.T, conn *Connection) { 29 | ctx, cancel := context.WithCancel(context.Background()) 30 | cancel() 31 | sess := conn.NewSessionContext(ctx, nil) 32 | _, err := sess.SelectBySql("SELECT 1").ReturnInt64() 33 | assert.EqualError(t, err, "context canceled") 34 | _, err = sess.Update("dbr_people").Where(Eq("id", 1)).Set("name", "jonathan1").Exec() 35 | assert.EqualError(t, err, "context canceled") 36 | _, err = sess.Begin() 37 | assert.EqualError(t, err, "context canceled") 38 | } 39 | 40 | func checkTxQueryContext(t *testing.T, conn *Connection) { 41 | ctx, cancel := context.WithCancel(context.Background()) 42 | sess := conn.NewSessionContext(ctx, nil) 43 | tx, err := sess.Begin() 44 | if !assert.NoError(t, err) { 45 | cancel() 46 | return 47 | } 48 | cancel() 49 | _, err = tx.SelectBySql("SELECT 
1").ReturnInt64() 50 | assert.EqualError(t, err, "context canceled") 51 | err = tx.Rollback() 52 | // context cancel may cause transaction rollback automatically 53 | assert.True(t, err == nil || err == sql.ErrTxDone) 54 | } 55 | 56 | func checkTxExecContext(t *testing.T, conn *Connection) { 57 | ctx, cancel := context.WithCancel(context.Background()) 58 | sess := conn.NewSessionContext(ctx, nil) 59 | tx, err := sess.Begin() 60 | if !assert.NoError(t, err) { 61 | cancel() 62 | return 63 | } 64 | _, err = tx.Update("dbr_people").Where(Eq("id", 1)).Set("name", "jonathan1").Exec() 65 | assert.NoError(t, err) 66 | cancel() 67 | assert.EqualError(t, tx.Commit(), "context canceled") 68 | } 69 | -------------------------------------------------------------------------------- /dbr_test.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | import ( 4 | "bytes" 5 | "log" 6 | "os" 7 | "testing" 8 | 9 | "github.com/mailru/dbr/dialect" 10 | "github.com/stretchr/testify/assert" 11 | 12 | _ "github.com/go-sql-driver/mysql" 13 | _ "github.com/lib/pq" 14 | _ "github.com/mailru/go-clickhouse" 15 | _ "github.com/mattn/go-sqlite3" 16 | ) 17 | 18 | // Ensure that tx and session are session runner 19 | var ( 20 | _ SessionRunner = (*Tx)(nil) 21 | _ SessionRunner = (*Session)(nil) 22 | ) 23 | 24 | var ( 25 | currID int64 = 256 26 | ) 27 | 28 | // nextID returns next pseudo unique id 29 | func nextID() int64 { 30 | currID++ 31 | return currID 32 | } 33 | 34 | const ( 35 | mysqlDSN = "root@unix(/tmp/mysql.sock)/dbr_test?charset=utf8" 36 | postgresDSN = "postgres://postgres@localhost:5432/dbr_test?sslmode=disable" 37 | sqlite3DSN = ":memory:" 38 | clickhouseDSN = "http://localhost:8123/dbr_test" 39 | ) 40 | 41 | func createSession(driver, dsn string) *Session { 42 | var testDSN string 43 | switch driver { 44 | case "mysql": 45 | testDSN = os.Getenv("DBR_TEST_MYSQL_DSN") 46 | case "postgres": 47 | testDSN = 
os.Getenv("DBR_TEST_POSTGRES_DSN") 48 | case "sqlite3": 49 | testDSN = os.Getenv("DBR_TEST_SQLITE3_DSN") 50 | case "clickhouse": 51 | testDSN = os.Getenv("DBR_TEST_CLICKHOUSE_DSN") 52 | } 53 | if testDSN != "" { 54 | dsn = testDSN 55 | } 56 | conn, err := Open(driver, dsn, nil) 57 | if err != nil { 58 | log.Fatal(err) 59 | } 60 | sess := conn.NewSession(nil) 61 | reset(sess) 62 | return sess 63 | } 64 | 65 | var ( 66 | mysqlSession = createSession("mysql", mysqlDSN) 67 | postgresSession = createSession("postgres", postgresDSN) 68 | postgresBinarySession = createSession("postgres", postgresDSN+"&binary_parameters=yes") 69 | sqlite3Session = createSession("sqlite3", sqlite3DSN) 70 | clickhouseSession = createSession("clickhouse", clickhouseDSN) 71 | 72 | // all test sessions should be here 73 | testSession = []*Session{mysqlSession, postgresSession, sqlite3Session, clickhouseSession} 74 | ) 75 | 76 | type person struct { 77 | ID int64 78 | Name string 79 | Email string 80 | } 81 | 82 | type nullTypedRecord struct { 83 | ID int64 84 | StringVal NullString 85 | Int64Val NullInt64 86 | Float64Val NullFloat64 87 | TimeVal NullTime 88 | BoolVal NullBool 89 | } 90 | 91 | func reset(sess *Session) { 92 | var stmts []string 93 | switch sess.Dialect { 94 | case dialect.MySQL: 95 | stmts = []string{ 96 | `DROP TABLE IF EXISTS dbr_people`, 97 | `CREATE TABLE dbr_people(id SERIAL PRIMARY KEY, name varchar(255) NOT NULL, email varchar(255))`, 98 | `DROP TABLE IF EXISTS null_types`, 99 | `CREATE TABLE null_types( 100 | id SERIAL PRIMARY KEY, 101 | string_val varchar(255) NULL, 102 | int64_val integer NULL, 103 | float64_val float NULL, 104 | time_val timestamp NULL, 105 | bool_val bool NULL 106 | )`, 107 | `DROP TABLE IF EXISTS dbr_keys`, 108 | `CREATE TABLE dbr_keys (key_value varchar(255) PRIMARY KEY, val_value varchar(255))`, 109 | } 110 | case dialect.PostgreSQL: 111 | stmts = []string{ 112 | "DROP TABLE IF EXISTS dbr_people", 113 | "CREATE TABLE dbr_people(id SERIAL PRIMARY 
KEY, name varchar(255) NOT NULL, email varchar(255))", 114 | `DROP TABLE IF EXISTS null_types`, 115 | `CREATE TABLE null_types( 116 | id SERIAL PRIMARY KEY, 117 | string_val varchar(255) NULL, 118 | int64_val integer NULL, 119 | float64_val float NULL, 120 | time_val timestamp NULL, 121 | bool_val bool NULL 122 | )`, 123 | `DROP TABLE IF EXISTS dbr_keys`, 124 | `CREATE TABLE dbr_keys (key_value varchar(255) PRIMARY KEY, val_value varchar(255))`, 125 | } 126 | case dialect.SQLite3: 127 | stmts = []string{ 128 | "DROP TABLE IF EXISTS dbr_people", 129 | "CREATE TABLE dbr_people(id INTEGER PRIMARY KEY, name varchar(255) NOT NULL, email varchar(255))", 130 | `DROP TABLE IF EXISTS null_types`, 131 | `CREATE TABLE null_types( 132 | id INTEGER PRIMARY KEY, 133 | string_val varchar(255) NULL, 134 | int64_val integer NULL, 135 | float64_val float NULL, 136 | time_val timestamp NULL, 137 | bool_val bool NULL 138 | )`, 139 | `DROP TABLE IF EXISTS dbr_keys`, 140 | `CREATE TABLE dbr_keys (key_value varchar(255) PRIMARY KEY, val_value varchar(255))`, 141 | } 142 | case dialect.ClickHouse: 143 | stmts = []string{ 144 | "DROP TABLE IF EXISTS dbr_people", 145 | "CREATE TABLE dbr_people(id Int32, name String, email String) Engine=Memory", 146 | `DROP TABLE IF EXISTS dbr_keys`, 147 | `CREATE TABLE dbr_keys (key_value String, val_value String) Engine=Memory`, 148 | } 149 | } 150 | for _, v := range stmts { 151 | _, err := sess.Exec(v) 152 | if err != nil { 153 | log.Fatalf("Failed to execute statement: %s, Got error: %s", v, err) 154 | } 155 | } 156 | } 157 | 158 | func BenchmarkByteaNoBinaryEncode(b *testing.B) { 159 | benchmarkBytea(b, postgresSession) 160 | } 161 | 162 | func BenchmarkByteaBinaryEncode(b *testing.B) { 163 | benchmarkBytea(b, postgresBinarySession) 164 | } 165 | 166 | func benchmarkBytea(b *testing.B, sess *Session) { 167 | data := bytes.Repeat([]byte("0123456789"), 1000) 168 | for _, v := range []string{ 169 | `DROP TABLE IF EXISTS bytea_table`, 170 | `CREATE TABLE 
bytea_table ( 171 | val bytea 172 | )`, 173 | } { 174 | _, err := sess.Exec(v) 175 | assert.NoError(b, err) 176 | } 177 | b.ResetTimer() 178 | 179 | for i := 0; i < b.N; i++ { 180 | _, err := sess.InsertInto("bytea_table").Pair("val", data).Exec() 181 | assert.NoError(b, err) 182 | } 183 | } 184 | 185 | func TestBasicCRUD(t *testing.T) { 186 | for _, sess := range testSession { 187 | jonathan := person{ 188 | Name: "jonathan", 189 | Email: "jonathan@uservoice.com", 190 | } 191 | insertColumns := []string{"name", "email"} 192 | if sess.Dialect == dialect.PostgreSQL || sess.Dialect == dialect.ClickHouse { 193 | jonathan.ID = nextID() 194 | insertColumns = []string{"id", "name", "email"} 195 | } 196 | // insert 197 | result, err := sess.InsertInto("dbr_people").Columns(insertColumns...).Record(&jonathan).Exec() 198 | assert.NoError(t, err) 199 | 200 | rowsAffected, err := result.RowsAffected() 201 | if err == nil { 202 | assert.EqualValues(t, 1, rowsAffected) 203 | } 204 | 205 | assert.True(t, jonathan.ID > 0) 206 | // select 207 | var people []person 208 | count, err := sess.Select("*").From("dbr_people").Where(Eq("id", jonathan.ID)).LoadStructs(&people) 209 | assert.NoError(t, err) 210 | if assert.Equal(t, 1, count) { 211 | assert.Equal(t, jonathan.ID, people[0].ID) 212 | assert.Equal(t, jonathan.Name, people[0].Name) 213 | assert.Equal(t, jonathan.Email, people[0].Email) 214 | } 215 | 216 | // select id 217 | ids, err := sess.Select("id").From("dbr_people").ReturnInt64s() 218 | assert.NoError(t, err) 219 | assert.Equal(t, 1, len(ids)) 220 | 221 | // select id limit 222 | ids, err = sess.Select("id").From("dbr_people").Limit(1).ReturnInt64s() 223 | assert.NoError(t, err) 224 | assert.Equal(t, 1, len(ids)) 225 | 226 | if sess.Dialect == dialect.ClickHouse { 227 | // clickhouse does not support update/delete 228 | continue 229 | } 230 | // update 231 | result, err = sess.Update("dbr_people").Where(Eq("id", jonathan.ID)).Set("name", "jonathan1").Exec() 232 | 
assert.NoError(t, err) 233 | 234 | rowsAffected, err = result.RowsAffected() 235 | assert.NoError(t, err) 236 | assert.EqualValues(t, 1, rowsAffected) 237 | 238 | var n NullInt64 239 | sess.Select("count(*)").From("dbr_people").Where("name = ?", "jonathan1").LoadValue(&n) 240 | assert.EqualValues(t, 1, n.Int64) 241 | 242 | // delete 243 | result, err = sess.DeleteFrom("dbr_people").Where(Eq("id", jonathan.ID)).Exec() 244 | assert.NoError(t, err) 245 | 246 | rowsAffected, err = result.RowsAffected() 247 | assert.NoError(t, err) 248 | assert.EqualValues(t, 1, rowsAffected) 249 | 250 | // select id 251 | ids, err = sess.Select("id").From("dbr_people").ReturnInt64s() 252 | assert.NoError(t, err) 253 | assert.Equal(t, 0, len(ids)) 254 | } 255 | } 256 | 257 | func TestOnConflict(t *testing.T) { 258 | for _, sess := range testSession { 259 | if sess.Dialect == dialect.SQLite3 || sess.Dialect == dialect.ClickHouse { 260 | continue 261 | } 262 | for i := 0; i < 2; i++ { 263 | b := sess.InsertInto("dbr_keys").Columns("key_value", "val_value").Values("key", "value") 264 | b.OnConflict("dbr_keys_pkey").Action("val_value", Expr("CONCAT(?, 2)", Proposed("val_value"))) 265 | _, err := b.Exec() 266 | assert.NoError(t, err) 267 | } 268 | var value string 269 | _, err := sess.SelectBySql("SELECT val_value FROM dbr_keys WHERE key_value=?", "key").Load(&value) 270 | assert.NoError(t, err) 271 | assert.Equal(t, "value2", value) 272 | } 273 | } 274 | 275 | func TestForkSession(t *testing.T) { 276 | sess := testSession[0] 277 | sess2 := sess.NewSession(nil) 278 | assert.True(t, sess.ctx == sess2.ctx) 279 | assert.True(t, sess.EventReceiver == sess2.EventReceiver) 280 | recv := new(NullEventReceiver) 281 | sess3 := sess.NewSession(recv) 282 | assert.True(t, sess3.EventReceiver != sess.EventReceiver) 283 | } 284 | -------------------------------------------------------------------------------- /delete.go: -------------------------------------------------------------------------------- 1 | 
package dbr 2 | 3 | // DeleteStmt builds `DELETE ...` 4 | type DeleteStmt interface { 5 | Builder 6 | Where(query interface{}, value ...interface{}) DeleteStmt 7 | } 8 | 9 | type deleteStmt struct { 10 | raw 11 | 12 | Table string 13 | WhereCond []Builder 14 | } 15 | 16 | // Build builds `DELETE ...` in dialect 17 | func (b *deleteStmt) Build(d Dialect, buf Buffer) error { 18 | if b.raw.Query != "" { 19 | return b.raw.Build(d, buf) 20 | } 21 | 22 | if b.Table == "" { 23 | return ErrTableNotSpecified 24 | } 25 | 26 | buf.WriteString("DELETE FROM ") 27 | buf.WriteString(d.QuoteIdent(b.Table)) 28 | 29 | if len(b.WhereCond) > 0 { 30 | buf.WriteString(" WHERE ") 31 | err := And(b.WhereCond...).Build(d, buf) 32 | if err != nil { 33 | return err 34 | } 35 | } 36 | return nil 37 | } 38 | 39 | // DeleteFrom creates a DeleteStmt 40 | func DeleteFrom(table string) DeleteStmt { 41 | return createDeleteStmt(table) 42 | } 43 | 44 | func createDeleteStmt(table string) *deleteStmt { 45 | return &deleteStmt{ 46 | Table: table, 47 | } 48 | } 49 | 50 | // DeleteBySql creates a DeleteStmt from raw query 51 | func DeleteBySql(query string, value ...interface{}) DeleteStmt { 52 | return createDeleteStmtBySQL(query, value) 53 | } 54 | 55 | func createDeleteStmtBySQL(query string, value []interface{}) *deleteStmt { 56 | return &deleteStmt{ 57 | raw: raw{ 58 | Query: query, 59 | Value: value, 60 | }, 61 | } 62 | } 63 | 64 | // Where adds a where condition 65 | func (b *deleteStmt) Where(query interface{}, value ...interface{}) DeleteStmt { 66 | switch query := query.(type) { 67 | case string: 68 | b.WhereCond = append(b.WhereCond, Expr(query, value...)) 69 | case Builder: 70 | b.WhereCond = append(b.WhereCond, query) 71 | } 72 | return b 73 | } 74 | -------------------------------------------------------------------------------- /delete_builder.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | import ( 4 | "context" 5 | 
"database/sql" 6 | "fmt" 7 | ) 8 | 9 | // DeleteBuilder builds "DELETE ..." stmt 10 | type DeleteBuilder interface { 11 | Builder 12 | EventReceiver 13 | Executer 14 | 15 | Where(query interface{}, value ...interface{}) DeleteBuilder 16 | Limit(n uint64) DeleteBuilder 17 | } 18 | 19 | type deleteBuilder struct { 20 | runner 21 | EventReceiver 22 | 23 | Dialect Dialect 24 | deleteStmt *deleteStmt 25 | LimitCount int64 26 | ctx context.Context 27 | } 28 | 29 | // DeleteFrom creates a DeleteBuilder 30 | func (sess *Session) DeleteFrom(table string) DeleteBuilder { 31 | return &deleteBuilder{ 32 | runner: sess, 33 | EventReceiver: sess.EventReceiver, 34 | Dialect: sess.Dialect, 35 | deleteStmt: createDeleteStmt(table), 36 | LimitCount: -1, 37 | ctx: sess.ctx, 38 | } 39 | } 40 | 41 | // DeleteFrom creates a DeleteBuilder 42 | func (tx *Tx) DeleteFrom(table string) DeleteBuilder { 43 | return &deleteBuilder{ 44 | runner: tx, 45 | EventReceiver: tx.EventReceiver, 46 | Dialect: tx.Dialect, 47 | deleteStmt: createDeleteStmt(table), 48 | LimitCount: -1, 49 | ctx: tx.ctx, 50 | } 51 | } 52 | 53 | // DeleteBySql creates a DeleteBuilder from raw query 54 | func (sess *Session) DeleteBySql(query string, value ...interface{}) DeleteBuilder { 55 | return &deleteBuilder{ 56 | runner: sess, 57 | EventReceiver: sess.EventReceiver, 58 | Dialect: sess.Dialect, 59 | deleteStmt: createDeleteStmtBySQL(query, value), 60 | LimitCount: -1, 61 | ctx: sess.ctx, 62 | } 63 | } 64 | 65 | // DeleteBySql creates a DeleteBuilder from raw query 66 | func (tx *Tx) DeleteBySql(query string, value ...interface{}) DeleteBuilder { 67 | return &deleteBuilder{ 68 | runner: tx, 69 | EventReceiver: tx.EventReceiver, 70 | Dialect: tx.Dialect, 71 | deleteStmt: createDeleteStmtBySQL(query, value), 72 | LimitCount: -1, 73 | ctx: tx.ctx, 74 | } 75 | } 76 | 77 | // Exec executes the stmt with background context 78 | func (b *deleteBuilder) Exec() (sql.Result, error) { 79 | return b.ExecContext(b.ctx) 80 | } 81 | 82 
| // ExecContext executes the stmt 83 | func (b *deleteBuilder) ExecContext(ctx context.Context) (sql.Result, error) { 84 | return exec(ctx, b.runner, b.EventReceiver, b, b.Dialect) 85 | } 86 | 87 | // Where adds condition to the stmt 88 | func (b *deleteBuilder) Where(query interface{}, value ...interface{}) DeleteBuilder { 89 | b.deleteStmt.Where(query, value...) 90 | return b 91 | } 92 | 93 | // Limit adds LIMIT 94 | func (b *deleteBuilder) Limit(n uint64) DeleteBuilder { 95 | b.LimitCount = int64(n) 96 | return b 97 | } 98 | 99 | // Build builds `DELETE ...` in dialect 100 | func (b *deleteBuilder) Build(d Dialect, buf Buffer) error { 101 | err := b.deleteStmt.Build(b.Dialect, buf) 102 | if err != nil { 103 | return err 104 | } 105 | if b.LimitCount >= 0 { 106 | buf.WriteString(" LIMIT ") 107 | buf.WriteString(fmt.Sprint(b.LimitCount)) 108 | } 109 | return nil 110 | } 111 | -------------------------------------------------------------------------------- /delete_test.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/mailru/dbr/dialect" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestDeleteStmt(t *testing.T) { 11 | buf := NewBuffer() 12 | builder := DeleteFrom("table").Where(Eq("a", 1)) 13 | err := builder.Build(dialect.MySQL, buf) 14 | assert.NoError(t, err) 15 | assert.Equal(t, "DELETE FROM `table` WHERE (`a` = ?)", buf.String()) 16 | assert.Equal(t, []interface{}{1}, buf.Value()) 17 | } 18 | 19 | func BenchmarkDeleteSQL(b *testing.B) { 20 | buf := NewBuffer() 21 | for i := 0; i < b.N; i++ { 22 | DeleteFrom("table").Where(Eq("a", 1)).Build(dialect.MySQL, buf) 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /dialect.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | import "time" 4 | 5 | // Dialect abstracts database differences 6 | type 
Dialect interface { 7 | QuoteIdent(id string) string 8 | 9 | EncodeString(s string) string 10 | EncodeBool(b bool) string 11 | EncodeTime(t time.Time) string 12 | EncodeBytes(b []byte) string 13 | Placeholder(n int) string 14 | OnConflict(constraint string) string 15 | Proposed(column string) string 16 | Limit(offset, limit int64) string 17 | Prewhere() string 18 | } 19 | -------------------------------------------------------------------------------- /dialect/clickhouse.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "time" 7 | ) 8 | 9 | const ( 10 | clickhouseTimeFormat = "2006-01-02 15:04:05" 11 | ) 12 | 13 | type clickhouse struct{} 14 | 15 | func (d clickhouse) QuoteIdent(s string) string { 16 | return quoteIdent(s, "`") 17 | } 18 | 19 | func (d clickhouse) EncodeString(s string) string { 20 | buf := new(bytes.Buffer) 21 | 22 | buf.WriteRune('\'') 23 | for i := 0; i < len(s); i++ { 24 | switch s[i] { 25 | case 0: 26 | buf.WriteString(`\0`) 27 | case '\'': 28 | buf.WriteString(`\'`) 29 | case '"': 30 | buf.WriteString(`\"`) 31 | case '\b': 32 | buf.WriteString(`\b`) 33 | case '\n': 34 | buf.WriteString(`\n`) 35 | case '\r': 36 | buf.WriteString(`\r`) 37 | case '\t': 38 | buf.WriteString(`\t`) 39 | case 26: 40 | buf.WriteString(`\Z`) 41 | case '\\': 42 | buf.WriteString(`\\`) 43 | default: 44 | buf.WriteByte(s[i]) 45 | } 46 | } 47 | 48 | buf.WriteRune('\'') 49 | return buf.String() 50 | } 51 | 52 | func (d clickhouse) EncodeBool(b bool) string { 53 | if b { 54 | return "1" 55 | } 56 | return "0" 57 | } 58 | 59 | func (d clickhouse) EncodeTime(t time.Time) string { 60 | return `'` + t.UTC().Format(clickhouseTimeFormat) + `'` 61 | } 62 | 63 | func (d clickhouse) EncodeBytes(b []byte) string { 64 | return fmt.Sprintf(`0x%x`, b) 65 | } 66 | 67 | func (d clickhouse) Placeholder(_ int) string { 68 | return "?" 
// quoteIdent wraps every dot-separated segment of s in quote and joins the
// segments back together with dots, so "table.col" with a backtick becomes
// `table`.`col`. A string without dots is quoted as a whole.
func quoteIdent(s, quote string) string {
	segments := strings.Split(s, ".")
	for i, segment := range segments {
		segments[i] = quote + segment + quote
	}
	return strings.Join(segments, ".")
}
case '\'': 25 | buf.WriteString(`\'`) 26 | case '"': 27 | buf.WriteString(`\"`) 28 | case '\b': 29 | buf.WriteString(`\b`) 30 | case '\n': 31 | buf.WriteString(`\n`) 32 | case '\r': 33 | buf.WriteString(`\r`) 34 | case '\t': 35 | buf.WriteString(`\t`) 36 | case 26: 37 | buf.WriteString(`\Z`) 38 | case '\\': 39 | buf.WriteString(`\\`) 40 | default: 41 | buf.WriteByte(s[i]) 42 | } 43 | } 44 | 45 | buf.WriteRune('\'') 46 | return buf.String() 47 | } 48 | 49 | func (d mysql) EncodeBool(b bool) string { 50 | if b { 51 | return "1" 52 | } 53 | return "0" 54 | } 55 | 56 | func (d mysql) EncodeTime(t time.Time) string { 57 | return `'` + t.UTC().Format(timeFormat) + `'` 58 | } 59 | 60 | func (d mysql) EncodeBytes(b []byte) string { 61 | return fmt.Sprintf(`0x%x`, b) 62 | } 63 | 64 | func (d mysql) Placeholder(_ int) string { 65 | return "?" 66 | } 67 | 68 | func (d mysql) OnConflict(_ string) string { 69 | return "ON DUPLICATE KEY UPDATE" 70 | } 71 | 72 | func (d mysql) Proposed(column string) string { 73 | return fmt.Sprintf("VALUES(%s)", d.QuoteIdent(column)) 74 | } 75 | 76 | func (d mysql) Limit(offset, limit int64) string { 77 | if offset < 0 { 78 | return fmt.Sprintf("LIMIT %d", limit) 79 | } 80 | return fmt.Sprintf("LIMIT %d,%d", offset, limit) 81 | } 82 | 83 | func (d mysql) Prewhere() string { 84 | return "" 85 | } 86 | -------------------------------------------------------------------------------- /dialect/postgresql.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "time" 7 | ) 8 | 9 | type postgreSQL struct{} 10 | 11 | func (d postgreSQL) QuoteIdent(s string) string { 12 | return quoteIdent(s, `"`) 13 | } 14 | 15 | func (d postgreSQL) EncodeString(s string) string { 16 | // http://www.postgresql.org/docs/9.2/static/sql-syntax-lexical.html 17 | return `'` + strings.Replace(s, `'`, `''`, -1) + `'` 18 | } 19 | 20 | func (d postgreSQL) EncodeBool(b bool) string { 21 
| if b { 22 | return "TRUE" 23 | } 24 | return "FALSE" 25 | } 26 | 27 | func (d postgreSQL) EncodeTime(t time.Time) string { 28 | return MySQL.EncodeTime(t) 29 | } 30 | 31 | func (d postgreSQL) EncodeBytes(b []byte) string { 32 | return fmt.Sprintf(`E'\\x%x'`, b) 33 | } 34 | 35 | func (d postgreSQL) Placeholder(n int) string { 36 | return fmt.Sprintf("$%d", n+1) 37 | } 38 | 39 | func (d postgreSQL) OnConflict(constraint string) string { 40 | return fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO UPDATE SET", d.QuoteIdent(constraint)) 41 | } 42 | 43 | func (d postgreSQL) Proposed(column string) string { 44 | return fmt.Sprintf("EXCLUDED.%s", d.QuoteIdent(column)) 45 | } 46 | 47 | func (d postgreSQL) Limit(offset, limit int64) string { 48 | if offset < 0 { 49 | return fmt.Sprintf("LIMIT %d", limit) 50 | } 51 | return fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset) 52 | } 53 | 54 | func (d postgreSQL) Prewhere() string { 55 | return "" 56 | } 57 | -------------------------------------------------------------------------------- /dialect/sqlite3.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "time" 7 | ) 8 | 9 | type sqlite3 struct{} 10 | 11 | func (d sqlite3) QuoteIdent(s string) string { 12 | return quoteIdent(s, `"`) 13 | } 14 | 15 | func (d sqlite3) EncodeString(s string) string { 16 | // https://www.sqlite.org/faq.html 17 | return `'` + strings.Replace(s, `'`, `''`, -1) + `'` 18 | } 19 | 20 | func (d sqlite3) EncodeBool(b bool) string { 21 | // https://www.sqlite.org/lang_expr.html 22 | if b { 23 | return "1" 24 | } 25 | return "0" 26 | } 27 | 28 | func (d sqlite3) EncodeTime(t time.Time) string { 29 | // https://www.sqlite.org/lang_datefunc.html 30 | return MySQL.EncodeTime(t) 31 | } 32 | 33 | func (d sqlite3) EncodeBytes(b []byte) string { 34 | // https://www.sqlite.org/lang_expr.html 35 | return fmt.Sprintf(`X'%x'`, b) 36 | } 37 | 38 | func (d sqlite3) 
Placeholder(_ int) string { 39 | return "?" 40 | } 41 | 42 | func (d sqlite3) OnConflict(_ string) string { 43 | return "" 44 | } 45 | 46 | func (d sqlite3) Proposed(_ string) string { 47 | return "" 48 | } 49 | 50 | func (d sqlite3) Limit(offset, limit int64) string { 51 | if offset < 0 { 52 | return fmt.Sprintf("LIMIT %d", limit) 53 | } 54 | return fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset) 55 | } 56 | 57 | func (d sqlite3) Prewhere() string { 58 | return "" 59 | } 60 | -------------------------------------------------------------------------------- /errors.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | import "errors" 4 | 5 | // package errors 6 | var ( 7 | ErrNotFound = errors.New("dbr: not found") 8 | ErrNotSupported = errors.New("dbr: not supported") 9 | ErrTableNotSpecified = errors.New("dbr: table not specified") 10 | ErrColumnNotSpecified = errors.New("dbr: column not specified") 11 | ErrInvalidPointer = errors.New("dbr: attempt to load into an invalid pointer") 12 | ErrPlaceholderCount = errors.New("dbr: wrong placeholder count") 13 | ErrInvalidSliceLength = errors.New("dbr: length of slice is 0. 
// NullEventReceiver is an EventReceiver that records nothing: its
// notification methods are no-ops and its error methods return the error
// they were given. Use it when the caller doesn't supply a receiver.
type NullEventReceiver struct{}

// Event receives a simple notification when various events occur.
// This implementation ignores it.
func (n *NullEventReceiver) Event(eventName string) {}

// EventKv receives a notification when various events occur along with
// optional key/value data. This implementation ignores it.
func (n *NullEventReceiver) EventKv(eventName string, kvs map[string]string) {}

// EventErr receives a notification of an error if one occurs.
// This implementation returns err unchanged.
func (n *NullEventReceiver) EventErr(eventName string, err error) error { return err }

// EventErrKv receives a notification of an error if one occurs along with
// optional key/value data. This implementation returns err unchanged.
func (n *NullEventReceiver) EventErrKv(eventName string, err error, kvs map[string]string) error {
	return err
}

// Timing receives the time an event took to happen.
// This implementation ignores it.
func (n *NullEventReceiver) Timing(eventName string, nanoseconds int64) {}

// TimingKv receives the time an event took to happen along with optional
// key/value data. This implementation ignores it.
func (n *NullEventReceiver) TimingKv(eventName string, nanoseconds int64, kvs map[string]string) {}
16 | return nil 17 | } 18 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/mailru/dbr 2 | 3 | go 1.11 4 | 5 | require ( 6 | github.com/DATA-DOG/go-sqlmock v1.3.0 7 | github.com/go-sql-driver/mysql v1.4.1 8 | github.com/lib/pq v1.2.0 9 | github.com/mailru/go-clickhouse v1.1.0 10 | github.com/mattn/go-sqlite3 v1.11.0 11 | github.com/stretchr/testify v1.4.0 12 | google.golang.org/appengine v1.6.2 // indirect 13 | ) 14 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/DATA-DOG/go-sqlmock v1.3.0 h1:ljjRxlddjfChBJdFKJs5LuCwCWPLaC1UZLwAo3PBBMk= 2 | github.com/DATA-DOG/go-sqlmock v1.3.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= 3 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 4 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 5 | github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= 6 | github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= 7 | github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 8 | github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0= 9 | github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= 10 | github.com/mailru/go-clickhouse v1.1.0 h1:o23GiQ1CHyb/FnDizEOuKIq5l7HJFepCgLR8BV8v/I8= 11 | github.com/mailru/go-clickhouse v1.1.0/go.mod h1:nJ671Q14775Y+SpWW28Km2gPSfIgLluZb5F1bUqX6PQ= 12 | github.com/mattn/go-sqlite3 v1.11.0 h1:LDdKkqtYlom37fkvqs8rMPFKAMe8+SgjbwZ6ex1/A/Q= 13 | github.com/mattn/go-sqlite3 v1.11.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= 14 | github.com/pmezard/go-difflib v1.0.0 
h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 15 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 16 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 17 | github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= 18 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= 19 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 20 | golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= 21 | golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 22 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 23 | golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= 24 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 25 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 26 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 27 | golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 28 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 29 | golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= 30 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 31 | golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= 32 | google.golang.org/appengine v1.6.2/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= 33 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod 
h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 34 | gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= 35 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 36 | -------------------------------------------------------------------------------- /ident.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | // I is an identifier, which will always be quoted 4 | type I string 5 | 6 | // Build escapes identifier in Dialect 7 | func (i I) Build(d Dialect, buf Buffer) error { 8 | buf.WriteString(d.QuoteIdent(string(i))) 9 | return nil 10 | } 11 | 12 | // As creates an alias for expr. e.g. SELECT `a1` AS `a2` 13 | func (i I) As(alias string) Builder { 14 | return as(i, alias) 15 | } 16 | 17 | func as(expr interface{}, alias string) Builder { 18 | return BuildFunc(func(d Dialect, buf Buffer) error { 19 | buf.WriteString(placeholder) 20 | buf.WriteValue(expr) 21 | buf.WriteString(" AS ") 22 | buf.WriteString(d.QuoteIdent(alias)) 23 | return nil 24 | }) 25 | } 26 | -------------------------------------------------------------------------------- /insert.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "reflect" 7 | "sort" 8 | ) 9 | 10 | // ConflictStmt is ` ON CONFLICT ...` part of InsertStmt 11 | type ConflictStmt interface { 12 | Action(column string, action interface{}) ConflictStmt 13 | } 14 | 15 | type conflictStmt struct { 16 | constraint string 17 | actions map[string]interface{} 18 | } 19 | 20 | // Action sets the action to apply to column if a conflict happens 21 | func (b *conflictStmt) Action(column string, action interface{}) ConflictStmt { 22 | b.actions[column] = action 23 | return b 24 | } 25 | 26 | // InsertStmt builds `INSERT INTO ...` 27 | type InsertStmt interface { 28 | Builder 29 | Columns(column ...string) InsertStmt 30 | Values(value ...interface{})
InsertStmt 31 | Record(structValue interface{}) InsertStmt 32 | OnConflictMap(constraint string, actions map[string]interface{}) InsertStmt 33 | OnConflict(constraint string) ConflictStmt 34 | } 35 | 36 | type insertStmt struct { 37 | raw 38 | 39 | Table string 40 | Column []string 41 | Value [][]interface{} 42 | Conflict *conflictStmt 43 | } 44 | 45 | // Proposed is reference to proposed value in on conflict clause 46 | func Proposed(column string) Builder { 47 | return BuildFunc(func(d Dialect, b Buffer) error { 48 | _, err := b.WriteString(d.Proposed(column)) 49 | return err 50 | }) 51 | } 52 | 53 | // Build builds `INSERT INTO ...` in dialect 54 | func (b *insertStmt) Build(d Dialect, buf Buffer) error { 55 | if b.raw.Query != "" { 56 | return b.raw.Build(d, buf) 57 | } 58 | 59 | if b.Table == "" { 60 | return ErrTableNotSpecified 61 | } 62 | 63 | if len(b.Column) == 0 { 64 | return ErrColumnNotSpecified 65 | } 66 | 67 | buf.WriteString("INSERT INTO ") 68 | buf.WriteString(d.QuoteIdent(b.Table)) 69 | 70 | placeholderBuf := new(bytes.Buffer) 71 | placeholderBuf.WriteString("(") 72 | buf.WriteString(" (") 73 | for i, col := range b.Column { 74 | if i > 0 { 75 | buf.WriteString(",") 76 | placeholderBuf.WriteString(",") 77 | } 78 | buf.WriteString(d.QuoteIdent(col)) 79 | placeholderBuf.WriteString(placeholder) 80 | } 81 | buf.WriteString(") VALUES ") 82 | placeholderBuf.WriteString(")") 83 | placeholderStr := placeholderBuf.String() 84 | 85 | for i, tuple := range b.Value { 86 | if i > 0 { 87 | buf.WriteString(", ") 88 | } 89 | buf.WriteString(placeholderStr) 90 | 91 | buf.WriteValue(tuple...) 
92 | } 93 | if b.Conflict != nil && len(b.Conflict.actions) > 0 { 94 | keyword := d.OnConflict(b.Conflict.constraint) 95 | if len(keyword) == 0 { 96 | return fmt.Errorf("Dialect %s does not support OnConflict", d) 97 | } 98 | buf.WriteString(" ") 99 | buf.WriteString(keyword) 100 | buf.WriteString(" ") 101 | needComma := false 102 | for _, column := range b.Column { 103 | if v, ok := b.Conflict.actions[column]; ok { 104 | if needComma { 105 | buf.WriteString(",") 106 | } 107 | buf.WriteString(d.QuoteIdent(column)) 108 | buf.WriteString("=") 109 | buf.WriteString(placeholder) 110 | buf.WriteValue(v) 111 | needComma = true 112 | } 113 | } 114 | } 115 | 116 | return nil 117 | } 118 | 119 | // InsertInto creates an InsertStmt 120 | func InsertInto(table string) InsertStmt { 121 | return createInsertStmt(table) 122 | } 123 | 124 | func createInsertStmt(table string) *insertStmt { 125 | return &insertStmt{ 126 | Table: table, 127 | } 128 | } 129 | 130 | // InsertBySql creates an InsertStmt from raw query 131 | func InsertBySql(query string, value ...interface{}) InsertStmt { 132 | return createInsertStmtBySQL(query, value) 133 | } 134 | 135 | func createInsertStmtBySQL(query string, value []interface{}) *insertStmt { 136 | return &insertStmt{ 137 | raw: raw{ 138 | Query: query, 139 | Value: value, 140 | }, 141 | } 142 | } 143 | 144 | // Columns adds columns 145 | func (b *insertStmt) Columns(column ...string) InsertStmt { 146 | b.Column = append(b.Column, column...) 147 | return b 148 | } 149 | 150 | // Values adds a tuple for columns 151 | func (b *insertStmt) Values(value ...interface{}) InsertStmt { 152 | b.Value = append(b.Value, value) 153 | return b 154 | } 155 | 156 | // Record adds a tuple for columns from a struct; if no columns were 157 | // specified yet for this insert, the record fields will be used to populate the columns.
158 | func (b *insertStmt) Record(structValue interface{}) InsertStmt { 159 | v := reflect.Indirect(reflect.ValueOf(structValue)) 160 | 161 | if v.Kind() == reflect.Struct { 162 | var value []interface{} 163 | m := structMap(v.Type()) 164 | 165 | // populate columns from available record fields 166 | // if no columns were specified up to this point 167 | if len(b.Column) == 0 { 168 | b.Column = make([]string, 0, len(m)) 169 | for key := range m { 170 | b.Column = append(b.Column, key) 171 | } 172 | 173 | // ensure that the column ordering is deterministic 174 | sort.Strings(b.Column) 175 | } 176 | 177 | for _, key := range b.Column { 178 | if index, ok := m[key]; ok { 179 | value = append(value, v.FieldByIndex(index).Interface()) 180 | } else { 181 | value = append(value, nil) 182 | } 183 | } 184 | b.Values(value...) 185 | } 186 | 187 | return b 188 | } 189 | 190 | // OnConflictMap adds actions to take on a constraint violation, e.g. UPSERT 191 | func (b *insertStmt) OnConflictMap(constraint string, actions map[string]interface{}) InsertStmt { 192 | b.Conflict = &conflictStmt{constraint: constraint, actions: actions} 193 | return b 194 | } 195 | 196 | // OnConflict creates an empty OnConflict section for the insert statement, e.g. UPSERT 197 | func (b *insertStmt) OnConflict(constraint string) ConflictStmt { 198 | b.Conflict = &conflictStmt{constraint: constraint, actions: make(map[string]interface{})} 199 | return b.Conflict 200 | } 201 | -------------------------------------------------------------------------------- /insert_builder.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "reflect" 7 | ) 8 | 9 | // InsertBuilder builds "INSERT ..."
stmt 10 | type InsertBuilder interface { 11 | Builder 12 | EventReceiver 13 | Executer 14 | Columns(column ...string) InsertBuilder 15 | Values(value ...interface{}) InsertBuilder 16 | Record(structValue interface{}) InsertBuilder 17 | OnConflictMap(constraint string, actions map[string]interface{}) InsertBuilder 18 | OnConflict(constraint string) ConflictStmt 19 | Pair(column string, value interface{}) InsertBuilder 20 | } 21 | 22 | // insertBuilder builds "INSERT ..." stmt 23 | type insertBuilder struct { 24 | EventReceiver 25 | runner 26 | 27 | Dialect Dialect 28 | RecordID reflect.Value 29 | insertStmt *insertStmt 30 | ctx context.Context 31 | } 32 | 33 | // InsertInto creates an InsertBuilder 34 | func (sess *Session) InsertInto(table string) InsertBuilder { 35 | return &insertBuilder{ 36 | runner: sess, 37 | EventReceiver: sess.EventReceiver, 38 | Dialect: sess.Dialect, 39 | insertStmt: createInsertStmt(table), 40 | ctx: sess.ctx, 41 | } 42 | } 43 | 44 | // InsertInto creates an InsertBuilder 45 | func (tx *Tx) InsertInto(table string) InsertBuilder { 46 | return &insertBuilder{ 47 | runner: tx, 48 | EventReceiver: tx.EventReceiver, 49 | Dialect: tx.Dialect, 50 | insertStmt: createInsertStmt(table), 51 | ctx: tx.ctx, 52 | } 53 | } 54 | 55 | // InsertBySql creates an InsertBuilder from raw query 56 | func (sess *Session) InsertBySql(query string, value ...interface{}) InsertBuilder { 57 | return &insertBuilder{ 58 | runner: sess, 59 | EventReceiver: sess.EventReceiver, 60 | Dialect: sess.Dialect, 61 | insertStmt: createInsertStmtBySQL(query, value), 62 | ctx: sess.ctx, 63 | } 64 | } 65 | 66 | // InsertBySql creates an InsertBuilder from raw query 67 | func (tx *Tx) InsertBySql(query string, value ...interface{}) InsertBuilder { 68 | return &insertBuilder{ 69 | runner: tx, 70 | EventReceiver: tx.EventReceiver, 71 | Dialect: tx.Dialect, 72 | insertStmt: createInsertStmtBySQL(query, value), 73 | ctx: tx.ctx, 74 | } 75 | } 76 | 77 | func (b *insertBuilder) Build(d
Dialect, buf Buffer) error { 78 | return b.insertStmt.Build(d, buf) 79 | } 80 | 81 | // Pair adds a new column value pair 82 | func (b *insertBuilder) Pair(column string, value interface{}) InsertBuilder { 83 | b.Columns(column) 84 | switch len(b.insertStmt.Value) { 85 | case 0: 86 | b.insertStmt.Values(value) 87 | case 1: 88 | b.insertStmt.Value[0] = append(b.insertStmt.Value[0], value) 89 | default: 90 | panic("pair only allows one record to insert") 91 | } 92 | return b 93 | } 94 | 95 | // Exec executes the stmt with background context 96 | func (b *insertBuilder) Exec() (sql.Result, error) { 97 | return b.ExecContext(b.ctx) 98 | } 99 | 100 | // ExecContext executes the stmt 101 | func (b *insertBuilder) ExecContext(ctx context.Context) (sql.Result, error) { 102 | result, err := exec(ctx, b.runner, b.EventReceiver, b, b.Dialect) 103 | if err != nil { 104 | return nil, err 105 | } 106 | 107 | if b.RecordID.IsValid() { 108 | if id, err := result.LastInsertId(); err == nil { 109 | b.RecordID.SetInt(id) 110 | } 111 | } 112 | 113 | return result, nil 114 | } 115 | 116 | // Columns adds columns 117 | func (b *insertBuilder) Columns(column ...string) InsertBuilder { 118 | b.insertStmt.Columns(column...) 119 | return b 120 | } 121 | 122 | // Values adds a tuple for columns 123 | func (b *insertBuilder) Values(value ...interface{}) InsertBuilder { 124 | b.insertStmt.Values(value...) 
125 | return b 126 | } 127 | 128 | // Record adds a tuple for columns from a struct 129 | func (b *insertBuilder) Record(structValue interface{}) InsertBuilder { 130 | v := reflect.Indirect(reflect.ValueOf(structValue)) 131 | if v.Kind() == reflect.Struct && v.CanSet() { 132 | // ID is recommended by golint here 133 | for _, name := range []string{"Id", "ID"} { 134 | field := v.FieldByName(name) 135 | if field.IsValid() && field.Kind() == reflect.Int64 { 136 | b.RecordID = field 137 | break 138 | } 139 | } 140 | } 141 | 142 | b.insertStmt.Record(structValue) 143 | return b 144 | } 145 | 146 | // OnConflictMap adds actions to take on a constraint violation, e.g. UPSERT 147 | func (b *insertBuilder) OnConflictMap(constraint string, actions map[string]interface{}) InsertBuilder { 148 | b.insertStmt.OnConflictMap(constraint, actions) 149 | return b 150 | } 151 | 152 | // OnConflict creates an empty OnConflict section for the insert statement, e.g. UPSERT 153 | func (b *insertBuilder) OnConflict(constraint string) ConflictStmt { 154 | return b.insertStmt.OnConflict(constraint) 155 | } 156 | -------------------------------------------------------------------------------- /insert_test.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/mailru/dbr/dialect" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | type insertTest struct { 11 | v string 12 | A int 13 | C string `db:"b"` 14 | } 15 | 16 | func TestInsertStmt(t *testing.T) { 17 | buf := NewBuffer() 18 | builder := InsertInto("table").Columns("a", "b").Values(1, "one").Record(&insertTest{ 19 | A: 2, 20 | C: "two", 21 | }) 22 | err := builder.Build(dialect.MySQL, buf) 23 | assert.NoError(t, err) 24 | assert.Equal(t, "INSERT INTO `table` (`a`,`b`) VALUES (?,?), (?,?)", buf.String()) 25 | assert.Equal(t, []interface{}{1, "one", 2, "two"}, buf.Value()) 26 | } 27 | 28 | func TestInsertRecordNoColumns(t *testing.T) { 29 | buf
:= NewBuffer() 30 | builder := InsertInto("table").Record(&insertTest{ 31 | A: 2, 32 | C: "two", 33 | }).Values(1, "one") 34 | err := builder.Build(dialect.MySQL, buf) 35 | assert.NoError(t, err) 36 | assert.Equal(t, "INSERT INTO `table` (`a`,`b`) VALUES (?,?), (?,?)", buf.String()) 37 | assert.Equal(t, []interface{}{2, "two", 1, "one"}, buf.Value()) 38 | } 39 | 40 | func TestInsertOnConflictStmt(t *testing.T) { 41 | buf := NewBuffer() 42 | exp := Expr("a + ?", 1) 43 | builder := InsertInto("table").Columns("a", "b").Values(1, "one") 44 | builder.OnConflict("").Action("a", exp).Action("b", "one") 45 | err := builder.Build(dialect.MySQL, buf) 46 | assert.NoError(t, err) 47 | assert.Equal(t, "INSERT INTO `table` (`a`,`b`) VALUES (?,?) ON DUPLICATE KEY UPDATE `a`=?,`b`=?", buf.String()) 48 | assert.Equal(t, []interface{}{1, "one", exp, "one"}, buf.Value()) 49 | } 50 | 51 | func TestInsertOnConflictMapStmt(t *testing.T) { 52 | buf := NewBuffer() 53 | exp := Expr("a + ?", 1) 54 | builder := InsertInto("table").Columns("a", "b").Values(1, "one") 55 | err := builder.OnConflictMap("", map[string]interface{}{"a": exp, "b": "one"}).Build(dialect.MySQL, buf) 56 | assert.NoError(t, err) 57 | assert.Equal(t, "INSERT INTO `table` (`a`,`b`) VALUES (?,?) 
ON DUPLICATE KEY UPDATE `a`=?,`b`=?", buf.String()) 58 | assert.Equal(t, []interface{}{1, "one", exp, "one"}, buf.Value()) 59 | } 60 | 61 | func BenchmarkInsertValuesSQL(b *testing.B) { 62 | buf := NewBuffer() 63 | for i := 0; i < b.N; i++ { 64 | InsertInto("table").Columns("a", "b").Values(1, "one").Build(dialect.MySQL, buf) 65 | } 66 | } 67 | 68 | func BenchmarkInsertRecordSQL(b *testing.B) { 69 | buf := NewBuffer() 70 | for i := 0; i < b.N; i++ { 71 | InsertInto("table").Columns("a", "b").Record(&insertTest{ 72 | A: 2, 73 | C: "two", 74 | }).Build(dialect.MySQL, buf) 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /interpolate.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | import ( 4 | "database/sql/driver" 5 | "reflect" 6 | "sort" 7 | "strconv" 8 | "strings" 9 | "time" 10 | _ "unsafe" // needs for reflect.UnsafeAddr 11 | ) 12 | 13 | type interpolator struct { 14 | Buffer 15 | Dialect 16 | IgnoreBinary bool 17 | N int 18 | } 19 | 20 | // InterpolateForDialect replaces placeholder in query with corresponding value in dialect 21 | func InterpolateForDialect(query string, value []interface{}, d Dialect) (string, error) { 22 | i := interpolator{ 23 | Buffer: NewBuffer(), 24 | Dialect: d, 25 | } 26 | err := i.interpolate(query, value) 27 | if err != nil { 28 | return "", err 29 | } 30 | return i.String(), nil 31 | } 32 | 33 | func (i *interpolator) interpolate(query string, value []interface{}) error { 34 | if strings.Count(query, placeholder) != len(value) { 35 | return ErrPlaceholderCount 36 | } 37 | 38 | valueIndex := 0 39 | 40 | for { 41 | index := strings.Index(query, placeholder) 42 | if index == -1 { 43 | break 44 | } 45 | 46 | i.WriteString(query[:index]) 47 | if _, ok := value[valueIndex].([]byte); ok && i.IgnoreBinary { 48 | i.WriteString(i.Placeholder(i.N)) 49 | i.N++ 50 | i.WriteValue(value[valueIndex]) 51 | } else { 52 | err := 
i.encodePlaceholder(value[valueIndex]) 53 | if err != nil { 54 | return err 55 | } 56 | } 57 | query = query[index+len(placeholder):] 58 | valueIndex++ 59 | } 60 | 61 | // placeholder not found; write remaining query 62 | i.WriteString(query) 63 | 64 | return nil 65 | } 66 | 67 | func (i *interpolator) encodePlaceholder(value interface{}) error { 68 | if builder, ok := value.(Builder); ok { 69 | pbuf := NewBuffer() 70 | err := builder.Build(i.Dialect, pbuf) 71 | if err != nil { 72 | return err 73 | } 74 | paren := true 75 | switch value.(type) { 76 | case SelectStmt: 77 | case *union: 78 | default: 79 | paren = false 80 | } 81 | if paren { 82 | i.WriteString("(") 83 | } 84 | err = i.interpolate(pbuf.String(), pbuf.Value()) 85 | if err != nil { 86 | return err 87 | } 88 | if paren { 89 | i.WriteString(")") 90 | } 91 | return nil 92 | } 93 | 94 | if valuer, ok := value.(driver.Valuer); ok { 95 | // get driver.Valuer's data 96 | var err error 97 | value, err = valuer.Value() 98 | if err != nil { 99 | return err 100 | } 101 | } 102 | 103 | if value == nil { 104 | i.WriteString("NULL") 105 | return nil 106 | } 107 | v := reflect.ValueOf(value) 108 | switch v.Kind() { 109 | case reflect.String: 110 | i.WriteString(i.EncodeString(v.String())) 111 | return nil 112 | case reflect.Bool: 113 | i.WriteString(i.EncodeBool(v.Bool())) 114 | return nil 115 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 116 | i.WriteString(strconv.FormatInt(v.Int(), 10)) 117 | return nil 118 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 119 | i.WriteString(strconv.FormatUint(v.Uint(), 10)) 120 | return nil 121 | case reflect.Float32, reflect.Float64: 122 | i.WriteString(strconv.FormatFloat(v.Float(), 'f', -1, 64)) 123 | return nil 124 | case reflect.Struct: 125 | if v.Type() == reflect.TypeOf(time.Time{}) { 126 | i.WriteString(i.EncodeTime(v.Interface().(time.Time))) 127 | return nil 128 | } 129 | case reflect.Slice: 130 | if 
v.Type().Elem().Kind() == reflect.Uint8 { 131 | // []byte 132 | i.WriteString(i.EncodeBytes(v.Bytes())) 133 | return nil 134 | } 135 | if v.Len() == 0 { 136 | // FIXME: support zero-length slice 137 | return ErrInvalidSliceLength 138 | } 139 | i.WriteString("(") 140 | for n := 0; n < v.Len(); n++ { 141 | if n > 0 { 142 | i.WriteString(",") 143 | } 144 | err := i.encodePlaceholder(v.Index(n).Interface()) 145 | if err != nil { 146 | return err 147 | } 148 | } 149 | i.WriteString(")") 150 | return nil 151 | case reflect.Map: 152 | if v.Len() == 0 { 153 | // FIXME: support zero-length map 154 | return ErrInvalidSliceLength 155 | } 156 | i.WriteString("(") 157 | // sort the keys so that the same values always produce the same query, 158 | // which gives a better chance of a database cache hit 159 | // and makes up for the extra cost of sorting 160 | keys := mapKeys(v.MapKeys()) 161 | sort.Sort(keys) 162 | for n := 0; n < len(keys); n++ { 163 | if n > 0 { 164 | i.WriteString(",") 165 | } 166 | err := i.encodePlaceholder(keys[n].Interface()) 167 | if err != nil { 168 | return err 169 | } 170 | } 171 | i.WriteString(")") 172 | return nil 173 | case reflect.Ptr: 174 | if v.IsNil() { 175 | i.WriteString("NULL") 176 | return nil 177 | } 178 | return i.encodePlaceholder(v.Elem().Interface()) 179 | } 180 | return ErrNotSupported 181 | } 182 | 183 | type mapKeys []reflect.Value 184 | 185 | func (k mapKeys) Len() int { 186 | return len(k) 187 | } 188 | 189 | func (k mapKeys) Less(i, j int) bool { 190 | vi, vj := k[i], k[j] 191 | switch vi.Kind() { 192 | case reflect.Bool: 193 | return !vi.Bool() && vj.Bool() 194 | case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int: 195 | return vi.Int() < vj.Int() 196 | case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, reflect.Uintptr: 197 | return vi.Uint() < vj.Uint() 198 | case reflect.Float32, reflect.Float64: 199 | return vi.Float() < vj.Float() 200 | case reflect.String: 201 |
return strings.Compare(vi.String(), vj.String()) < 0 202 | default: 203 | return vi.UnsafeAddr() < vj.UnsafeAddr() 204 | } 205 | } 206 | 207 | func (k mapKeys) Swap(i, j int) { 208 | k[i], k[j] = k[j], k[i] 209 | } 210 | -------------------------------------------------------------------------------- /interpolate_test.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | "time" 7 | 8 | "github.com/mailru/dbr/dialect" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestInterpolateIgnoreBinary(t *testing.T) { 13 | for _, test := range []struct { 14 | query string 15 | value []interface{} 16 | wantQuery string 17 | wantValue []interface{} 18 | }{ 19 | { 20 | query: "?", 21 | value: []interface{}{1}, 22 | wantQuery: "1", 23 | wantValue: nil, 24 | }, 25 | { 26 | query: "?", 27 | value: []interface{}{[]byte{1, 2, 3}}, 28 | wantQuery: "?", 29 | wantValue: []interface{}{[]byte{1, 2, 3}}, 30 | }, 31 | { 32 | query: "? ?", 33 | value: []interface{}{[]byte{1}, []byte{2}}, 34 | wantQuery: "? ?", 35 | wantValue: []interface{}{[]byte{1}, []byte{2}}, 36 | }, 37 | { 38 | query: "? 
?", 39 | value: []interface{}{Expr("|?| ?", []byte{1}, Expr("|?|", []byte{2})), []byte{3}}, 40 | wantQuery: "|?| |?| ?", 41 | wantValue: []interface{}{[]byte{1}, []byte{2}, []byte{3}}, 42 | }, 43 | } { 44 | i := interpolator{ 45 | Buffer: NewBuffer(), 46 | Dialect: dialect.MySQL, 47 | IgnoreBinary: true, 48 | } 49 | 50 | err := i.interpolate(test.query, test.value) 51 | assert.NoError(t, err) 52 | 53 | assert.Equal(t, test.wantQuery, i.String()) 54 | assert.Equal(t, test.wantValue, i.Value()) 55 | } 56 | } 57 | 58 | func TestInterpolateForDialect(t *testing.T) { 59 | for _, test := range []struct { 60 | query string 61 | value []interface{} 62 | want string 63 | }{ 64 | { 65 | query: "?", 66 | value: []interface{}{nil}, 67 | want: "NULL", 68 | }, 69 | { 70 | query: "?", 71 | value: []interface{}{`'"'"`}, 72 | want: "'\\'\\\"\\'\\\"'", 73 | }, 74 | { 75 | query: "? ?", 76 | value: []interface{}{true, false}, 77 | want: "1 0", 78 | }, 79 | { 80 | query: "? ?", 81 | value: []interface{}{1, 1.23}, 82 | want: "1 1.23", 83 | }, 84 | { 85 | query: "?", 86 | value: []interface{}{time.Date(2008, 9, 17, 20, 4, 26, 123456000, time.UTC)}, 87 | want: "'2008-09-17 20:04:26.123456'", 88 | }, 89 | { 90 | query: "?", 91 | value: []interface{}{[]string{"one", "two"}}, 92 | want: "('one','two')", 93 | }, 94 | { 95 | query: "?", 96 | value: []interface{}{map[string]bool{"one": true, "two": false}}, 97 | want: "('one','two')", 98 | }, 99 | { 100 | query: "?", 101 | value: []interface{}{[]byte{0x1, 0x2, 0x3}}, 102 | want: "0x010203", 103 | }, 104 | { 105 | query: "start?end", 106 | value: []interface{}{new(int)}, 107 | want: "start0end", 108 | }, 109 | { 110 | query: "?", 111 | value: []interface{}{Select("a").From("table")}, 112 | want: "(SELECT a FROM table)", 113 | }, 114 | { 115 | query: "?", 116 | value: []interface{}{I("a1").As("a2")}, 117 | want: "`a1` AS `a2`", 118 | }, 119 | { 120 | query: "?", 121 | value: []interface{}{Select("a").From("table").As("a1")}, 122 | want: "(SELECT 
a FROM table) AS `a1`", 123 | }, 124 | { 125 | query: "?", 126 | value: []interface{}{ 127 | UnionAll( 128 | Select("a").From("table1"), 129 | Select("b").From("table2"), 130 | ).As("t"), 131 | }, 132 | want: "((SELECT a FROM table1) UNION ALL (SELECT b FROM table2)) AS `t`", 133 | }, 134 | { 135 | query: "?", 136 | value: []interface{}{time.Month(7)}, 137 | want: "7", 138 | }, 139 | { 140 | query: "?", 141 | value: []interface{}{(*int64)(nil)}, 142 | want: "NULL", 143 | }, 144 | } { 145 | s, err := InterpolateForDialect(test.query, test.value, dialect.MySQL) 146 | assert.NoError(t, err) 147 | assert.Equal(t, test.want, s) 148 | } 149 | } 150 | 151 | // Attempts to test common SQL injection strings. See `InjectionAttempts` for 152 | // more information on the source and the strings themselves. 153 | func TestCommonSQLInjections(t *testing.T) { 154 | for _, sess := range testSession { 155 | for _, injectionAttempt := range strings.Split(injectionAttempts, "\n") { 156 | // Create a user with the attempted injection as the email address 157 | id := nextID() 158 | _, err := sess.InsertInto("dbr_people"). 159 | Pair("name", injectionAttempt). 160 | Pair("id", id). 
161 | Exec() 162 | assert.NoError(t, err) 163 | 164 | // SELECT the name back and ensure it's equal to the injection attempt 165 | var name string 166 | err = sess.Select("name").From("dbr_people").Where(Eq("id", id)).LoadValue(&name) 167 | assert.Equal(t, injectionAttempt, name) 168 | } 169 | } 170 | } 171 | 172 | // InjectionAttempts is a newline separated list of common SQL injection exploits 173 | // taken from https://wfuzz.googlecode.com/svn/trunk/wordlist/Injections/SQL.txt 174 | 175 | const injectionAttempts = ` 176 | ' 177 | " 178 | # 179 | - 180 | -- 181 | '%20-- 182 | --'; 183 | '%20; 184 | =%20' 185 | =%20; 186 | =%20-- 187 | \x23 188 | \x27 189 | \x3D%20\x3B' 190 | \x3D%20\x27 191 | \x27\x4F\x52 SELECT * 192 | \x27\x6F\x72 SELECT * 193 | 'or%20select * 194 | admin'-- 195 | <>"'%;)(&+ 196 | '%20or%20''=' 197 | '%20or%20'x'='x 198 | "%20or%20"x"="x 199 | ')%20or%20('x'='x 200 | 0 or 1=1 201 | ' or 0=0 -- 202 | " or 0=0 -- 203 | or 0=0 -- 204 | ' or 0=0 # 205 | " or 0=0 # 206 | or 0=0 # 207 | ' or 1=1-- 208 | " or 1=1-- 209 | ' or '1'='1'-- 210 | "' or 1 --'" 211 | or 1=1-- 212 | or%201=1 213 | or%201=1 -- 214 | ' or 1=1 or ''=' 215 | " or 1=1 or ""=" 216 | ' or a=a-- 217 | " or "a"="a 218 | ') or ('a'='a 219 | ") or ("a"="a 220 | hi" or "a"="a 221 | hi" or 1=1 -- 222 | hi' or 1=1 -- 223 | hi' or 'a'='a 224 | hi') or ('a'='a 225 | hi") or ("a"="a 226 | 'hi' or 'x'='x'; 227 | @variable 228 | ,@variable 229 | PRINT 230 | PRINT @@variable 231 | select 232 | insert 233 | as 234 | or 235 | procedure 236 | limit 237 | order by 238 | asc 239 | desc 240 | delete 241 | update 242 | distinct 243 | having 244 | truncate 245 | replace 246 | like 247 | handler 248 | bfilename 249 | ' or username like '% 250 | ' or uname like '% 251 | ' or userid like '% 252 | ' or uid like '% 253 | ' or user like '% 254 | exec xp 255 | exec sp 256 | '; exec master..xp_cmdshell 257 | '; exec xp_regread 258 | t'exec master..xp_cmdshell 'nslookup www.google.com'-- 259 | --sp_password 260 
| \x27UNION SELECT 261 | ' UNION SELECT 262 | ' UNION ALL SELECT 263 | ' or (EXISTS) 264 | ' (select top 1 265 | '||UTL_HTTP.REQUEST 266 | 1;SELECT%20* 267 | to_timestamp_tz 268 | tz_offset 269 | <>"'%;)(&+ 270 | '%20or%201=1 271 | %27%20or%201=1 272 | %20$(sleep%2050) 273 | %20'sleep%2050' 274 | char%4039%41%2b%40SELECT 275 | '%20OR 276 | 'sqlattempt1 277 | (sqlattempt2) 278 | | 279 | %7C 280 | *| 281 | %2A%7C 282 | *(|(mail=*)) 283 | %2A%28%7C%28mail%3D%2A%29%29 284 | *(|(objectclass=*)) 285 | %2A%28%7C%28objectclass%3D%2A%29%29 286 | ( 287 | %28 288 | ) 289 | %29 290 | & 291 | %26 292 | ! 293 | %21 294 | ' or 1=1 or ''=' 295 | ' or ''=' 296 | x' or 1=1 or 'x'='y 297 | / 298 | // 299 | //* 300 | */* 301 | ` 302 | -------------------------------------------------------------------------------- /join.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | type joinType uint8 4 | 5 | const ( 6 | inner joinType = iota 7 | left 8 | right 9 | full 10 | ) 11 | 12 | func join(t joinType, table, on interface{}) Builder { 13 | return BuildFunc(func(d Dialect, buf Buffer) error { 14 | buf.WriteString(" ") 15 | switch t { 16 | case left: 17 | buf.WriteString("LEFT ") 18 | case right: 19 | buf.WriteString("RIGHT ") 20 | case full: 21 | buf.WriteString("FULL ") 22 | } 23 | buf.WriteString("JOIN ") 24 | switch table := table.(type) { 25 | case string: 26 | buf.WriteString(d.QuoteIdent(table)) 27 | default: 28 | buf.WriteString(placeholder) 29 | buf.WriteValue(table) 30 | } 31 | buf.WriteString(" ON ") 32 | switch on := on.(type) { 33 | case string: 34 | buf.WriteString(on) 35 | case Builder: 36 | buf.WriteString(placeholder) 37 | buf.WriteValue(on) 38 | } 39 | return nil 40 | }) 41 | } 42 | -------------------------------------------------------------------------------- /load.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | 
"reflect" 7 | ) 8 | 9 | // Load loads any value from sql.Rows 10 | func Load(rows *sql.Rows, value interface{}) (int, error) { 11 | defer rows.Close() 12 | 13 | column, err := rows.Columns() 14 | if err != nil { 15 | return 0, err 16 | } 17 | 18 | v := reflect.ValueOf(value) 19 | if v.Kind() != reflect.Ptr || v.IsNil() { 20 | return 0, ErrInvalidPointer 21 | } 22 | 23 | v = v.Elem() 24 | isSlice := v.Kind() == reflect.Slice && v.Type().Elem().Kind() != reflect.Uint8 25 | elem := v 26 | elemType := elem.Type() 27 | 28 | if isSlice { 29 | elemType = elemType.Elem() 30 | elem = reflect.New(elemType).Elem() 31 | } 32 | 33 | extractor, err := findExtractor(elemType) 34 | if err != nil { 35 | return 0, err 36 | } 37 | 38 | ptrs := extractor(column, elem) 39 | count := 0 40 | 41 | for rows.Next() { 42 | if err = rows.Scan(ptrs...); err != nil { 43 | return count, err 44 | } 45 | 46 | count++ 47 | 48 | if !isSlice { 49 | break 50 | } 51 | 52 | if elemType.Kind() == reflect.Ptr { 53 | elemCopy := reflect.New(elemType.Elem()).Elem() 54 | elemCopy.Set(elem.Elem()) 55 | v.Set(reflect.Append(v, elemCopy.Addr())) 56 | } else { 57 | v.Set(reflect.Append(v, elem)) 58 | } 59 | } 60 | 61 | return count, rows.Err() 62 | } 63 | 64 | type dummyScanner struct{} 65 | 66 | func (dummyScanner) Scan(interface{}) error { 67 | return nil 68 | } 69 | 70 | type keyValueMap map[string]interface{} 71 | 72 | type kvScanner struct { 73 | column string 74 | m keyValueMap 75 | } 76 | 77 | func (kv *kvScanner) Scan(v interface{}) error { 78 | if b, ok := v.([]byte); ok { 79 | tmp := make([]byte, len(b)) 80 | copy(tmp, b) 81 | kv.m[kv.column] = tmp 82 | } else { 83 | // int64, float64, bool, string, time.Time, nil 84 | kv.m[kv.column] = v 85 | } 86 | return nil 87 | } 88 | 89 | type pointersExtractor func(columns []string, value reflect.Value) []interface{} 90 | 91 | var ( 92 | dummyDest sql.Scanner = dummyScanner{} 93 | typeScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem() 94 | typeKeyValueMap = 
reflect.TypeOf(keyValueMap(nil)) 95 | ) 96 | 97 | func getStructFieldsExtractor(t reflect.Type) pointersExtractor { 98 | mapping := structMap(t) 99 | return func(columns []string, value reflect.Value) []interface{} { 100 | var ptr []interface{} 101 | for _, key := range columns { 102 | if index, ok := mapping[key]; ok { 103 | ptr = append(ptr, value.FieldByIndex(index).Addr().Interface()) 104 | } else { 105 | ptr = append(ptr, dummyDest) 106 | } 107 | } 108 | return ptr 109 | } 110 | } 111 | 112 | func getIndirectExtractor(extractor pointersExtractor) pointersExtractor { 113 | return func(columns []string, value reflect.Value) []interface{} { 114 | if value.IsNil() { 115 | value.Set(reflect.New(value.Type().Elem())) 116 | } 117 | return extractor(columns, value.Elem()) 118 | } 119 | } 120 | 121 | func mapExtractor(columns []string, value reflect.Value) []interface{} { 122 | if value.IsNil() { 123 | value.Set(reflect.MakeMap(value.Type())) 124 | } 125 | m := value.Convert(typeKeyValueMap).Interface().(keyValueMap) 126 | var ptr = make([]interface{}, 0, len(columns)) 127 | for _, c := range columns { 128 | ptr = append(ptr, &kvScanner{column: c, m: m}) 129 | } 130 | return ptr 131 | } 132 | 133 | func dummyExtractor(columns []string, value reflect.Value) []interface{} { 134 | return []interface{}{value.Addr().Interface()} 135 | } 136 | 137 | func findExtractor(t reflect.Type) (pointersExtractor, error) { 138 | if reflect.PtrTo(t).Implements(typeScanner) { 139 | return dummyExtractor, nil 140 | } 141 | 142 | switch t.Kind() { 143 | case reflect.Map: 144 | if !t.ConvertibleTo(typeKeyValueMap) { 145 | return nil, fmt.Errorf("expected %v, got %v", typeKeyValueMap, t) 146 | } 147 | return mapExtractor, nil 148 | case reflect.Ptr: 149 | inner, err := findExtractor(t.Elem()) 150 | if err != nil { 151 | return nil, err 152 | } 153 | return getIndirectExtractor(inner), nil 154 | case reflect.Struct: 155 | return getStructFieldsExtractor(t), nil 156 | } 157 | return 
dummyExtractor, nil 158 | } 159 | -------------------------------------------------------------------------------- /load_bench_test.go: -------------------------------------------------------------------------------- 1 | package dbr_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/DATA-DOG/go-sqlmock" 8 | "github.com/mailru/dbr" 9 | "github.com/mailru/dbr/dialect" 10 | ) 11 | 12 | var rawData = getDataSlice(10000) 13 | 14 | func Benchmark_SQLScan(b *testing.B) { 15 | for i := 0; i < b.N; i++ { 16 | benchRawSQL(b, rawData, []benchItem{}) 17 | } 18 | } 19 | 20 | func Benchmark_SQLScanWithCap(b *testing.B) { 21 | for i := 0; i < b.N; i++ { 22 | benchRawSQL(b, rawData, make([]benchItem, 0, len(rawData))) 23 | } 24 | } 25 | 26 | func Benchmark_DBRLoad(b *testing.B) { 27 | for i := 0; i < b.N; i++ { 28 | benchDBR(b, rawData, []benchItem{}) 29 | } 30 | } 31 | 32 | func Benchmark_DBRLoadPtrs(b *testing.B) { 33 | for i := 0; i < b.N; i++ { 34 | benchDBRPtrs(b, rawData, []*benchItem{}) 35 | } 36 | } 37 | 38 | func Benchmark_DBRLoadWithCap(b *testing.B) { 39 | for i := 0; i < b.N; i++ { 40 | benchDBR(b, rawData, make([]benchItem, 0, len(rawData))) 41 | } 42 | } 43 | 44 | func Benchmark_DBRLoadPtrsWithCap(b *testing.B) { 45 | for i := 0; i < b.N; i++ { 46 | benchDBR(b, rawData, make([]benchItem, 0, len(rawData))) 47 | } 48 | } 49 | 50 | type benchItem struct { 51 | Field1 string 52 | Field2 int 53 | } 54 | 55 | func getDataSlice(itemsCnt int) []benchItem { 56 | res := make([]benchItem, 0, itemsCnt) 57 | for num := 0; len(res) < cap(res); num++ { 58 | res = append(res, benchItem{Field1: "str" + fmt.Sprint(num), Field2: num}) 59 | } 60 | return res 61 | } 62 | 63 | func getRowsMocked(b *testing.B, data []benchItem) *sqlmock.Rows { 64 | rows := sqlmock.NewRows([]string{"field1", "field2"}) 65 | for _, item := range data { 66 | rows.AddRow(item.Field1, item.Field2) 67 | } 68 | return rows 69 | } 70 | 71 | func benchRawSQL(b *testing.B, data []benchItem, res 
[]benchItem) { 72 | b.StopTimer() 73 | db, mock, err := sqlmock.New() 74 | if err != nil { 75 | b.Error(err) 76 | } 77 | mock.ExpectQuery("select").WillReturnRows(getRowsMocked(b, data)) 78 | b.StartTimer() 79 | 80 | rows, err := db.Query("select") 81 | if err != nil { 82 | b.Error(err) 83 | } 84 | 85 | var item benchItem 86 | for rows.Next() { 87 | if err := rows.Scan(&item.Field1, &item.Field2); err != nil { 88 | b.Error(err) 89 | } 90 | res = append(res, item) 91 | } 92 | } 93 | 94 | func benchDBR(b *testing.B, data []benchItem, res []benchItem) { 95 | b.StopTimer() 96 | sess, dbmock := getDBRMock(b, dialect.MySQL) 97 | dbmock.ExpectQuery("SELECT field1, field2 FROM sometable").WillReturnRows(getRowsMocked(b, data)) 98 | rows := sess.Select("field1", "field2").From("sometable") 99 | b.StartTimer() 100 | 101 | if _, err := rows.LoadStructs(&res); err != nil { 102 | b.Error(err) 103 | } 104 | } 105 | 106 | func benchDBRPtrs(b *testing.B, data []benchItem, res []*benchItem) { 107 | b.StopTimer() 108 | sess, dbmock := getDBRMock(b, dialect.MySQL) 109 | dbmock.ExpectQuery("SELECT field1, field2 FROM sometable").WillReturnRows(getRowsMocked(b, data)) 110 | rows := sess.Select("field1", "field2").From("sometable") 111 | b.StartTimer() 112 | 113 | if _, err := rows.LoadStructs(&res); err != nil { 114 | b.Error(err) 115 | } 116 | } 117 | 118 | func getDBRMock(b *testing.B, dialect dbr.Dialect) (*dbr.Session, sqlmock.Sqlmock) { 119 | db, dbmock, err := sqlmock.New() 120 | if err != nil { 121 | b.Error(err) 122 | } 123 | 124 | conn := dbr.Connection{DBConn: db, Dialect: dialect, EventReceiver: &dbr.NullEventReceiver{}} 125 | 126 | return conn.NewSession(conn.EventReceiver), dbmock 127 | } 128 | -------------------------------------------------------------------------------- /load_test.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | import ( 4 | "database/sql" 5 | "database/sql/driver" 6 | "fmt" 7 | "reflect" 8 | 
"testing" 9 | 10 | "github.com/DATA-DOG/go-sqlmock" 11 | "github.com/mailru/dbr/dialect" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | func TestLoad(t *testing.T) { 16 | type testStruct struct { 17 | A string 18 | } 19 | 20 | testcases := []struct { 21 | columns []string 22 | expected interface{} 23 | }{ 24 | {[]string{"a"}, "a"}, 25 | {[]string{"a"}, []string{"a"}}, 26 | {[]string{"a"}, testStruct{"a"}}, 27 | {[]string{"a"}, &testStruct{"a"}}, 28 | {[]string{"a"}, []testStruct{{"a"}}}, 29 | {[]string{"a"}, []*testStruct{{"a"}}}, 30 | {[]string{"a", "b"}, map[string]interface{}{"a": "a", "b": "b"}}, 31 | {[]string{"a", "b"}, &map[string]interface{}{"a": "a", "b": "b"}}, 32 | {[]string{"a", "b"}, []map[string]interface{}{{"a": "a", "b": "b"}}}, 33 | } 34 | 35 | for _, tc := range testcases { 36 | var values []driver.Value 37 | session, dbmock := newSessionMock() 38 | for _, c := range tc.columns { 39 | values = append(values, c) 40 | } 41 | rows := sqlmock.NewRows(tc.columns).AddRow(values...) 42 | dbmock.ExpectQuery("SELECT .+").WillReturnRows(rows) 43 | v := reflect.New(reflect.TypeOf(tc.expected)).Elem().Addr().Interface() 44 | session.Select(tc.columns...).From("table").Load(v) 45 | assert.Equal(t, tc.expected, reflect.Indirect(reflect.ValueOf(v)).Interface()) 46 | } 47 | } 48 | 49 | func TestLoadWithBytesValue(t *testing.T) { 50 | var values []driver.Value 51 | columns := []string{"fieldname"} 52 | value := []byte("fieldvalue") 53 | session, dbmock := newSessionMock() 54 | values = append(values, value) 55 | rows := sqlmock.NewRows(columns).AddRow(values...) 
56 | dbmock.ExpectQuery("SELECT .+").WillReturnRows(rows) 57 | v := reflect.New(reflect.TypeOf(map[string]interface{}(nil))).Elem().Addr().Interface() 58 | session.Select(columns...).From("table").Load(v) 59 | value[0] = byte('a') 60 | assert.Equal(t, map[string]interface{}{"fieldname": []byte("fieldvalue")}, 61 | reflect.Indirect(reflect.ValueOf(v)).Interface()) 62 | } 63 | 64 | func BenchmarkLoad(b *testing.B) { 65 | session, dbmock := newSessionMock() 66 | rows := sqlmock.NewRows([]string{"a", "b", "c"}) 67 | for i := 0; i < 100; i++ { 68 | rows = rows.AddRow(1, 2, 3) 69 | } 70 | dbmock.ExpectQuery("SELECT a, b, c FROM table").WillReturnRows(rows) 71 | b.ResetTimer() 72 | for i := 0; i < b.N; i++ { 73 | r := make([]struct { 74 | A int `db:"a"` 75 | B int `db:"b"` 76 | C int 77 | D int `db:"-"` 78 | e int 79 | F int `db:"f"` 80 | G int `db:"g"` 81 | H int `db:"h"` 82 | i int 83 | j int 84 | }, 0, 100) 85 | session.Select("a", "b", "c").From("table").LoadStructs(&r) 86 | } 87 | } 88 | 89 | func newSessionMock() (SessionRunner, sqlmock.Sqlmock) { 90 | db, m, err := sqlmock.New() 91 | if err != nil { 92 | panic(err) 93 | } 94 | conn := Connection{DBConn: db, Dialect: dialect.MySQL, EventReceiver: nullReceiver} 95 | return conn.NewSession(nil), m 96 | } 97 | 98 | func Test_Load_Scalar(t *testing.T) { 99 | t.Parallel() 100 | var res int 101 | _, err := Load(sqlRows(t, sqlmock.NewRows([]string{"cnt"}).AddRow(123)), &res) 102 | assert.NoError(t, err) 103 | assert.EqualValues(t, 123, res) 104 | } 105 | 106 | func Test_Load_ScalarPtr(t *testing.T) { 107 | t.Parallel() 108 | var res *int 109 | _, err := Load(sqlRows(t, sqlmock.NewRows([]string{"cnt"}).AddRow(123)), &res) 110 | assert.NoError(t, err) 111 | expected := new(int) 112 | *expected = 123 113 | assert.EqualValues(t, expected, res) 114 | } 115 | 116 | func Test_Load_ScalarSlice(t *testing.T) { 117 | t.Parallel() 118 | var res []int 119 | _, err := Load(sqlRows(t, 
sqlmock.NewRows([]string{"cnt"}).AddRow(111).AddRow(222).AddRow(333)), &res) 120 | assert.NoError(t, err) 121 | assert.EqualValues(t, []int{111, 222, 333}, res) 122 | } 123 | 124 | func Test_Load_ScalarSlicePtr(t *testing.T) { 125 | t.Parallel() 126 | var expected, actual []*int 127 | _, err := Load(sqlRows(t, sqlmock.NewRows([]string{"cnt"}).AddRow(0).AddRow(1).AddRow(2)), &actual) 128 | assert.NoError(t, err) 129 | for k := range make([]int, 3) { 130 | k := k 131 | expected = append(expected, &k) 132 | } 133 | assert.EqualValues(t, expected, actual) 134 | } 135 | 136 | type testObj struct { 137 | Field1 string 138 | Field2 int 139 | } 140 | 141 | func Test_Load_Struct(t *testing.T) { 142 | t.Parallel() 143 | var res testObj 144 | _, err := Load(sqlRows(t, sqlmock.NewRows([]string{"field1", "field2"}).AddRow("111", 222)), &res) 145 | assert.NoError(t, err) 146 | assert.EqualValues(t, testObj{"111", 222}, res) 147 | } 148 | 149 | func Test_Load_StructPtr(t *testing.T) { 150 | t.Parallel() 151 | res := &testObj{} 152 | _, err := Load(sqlRows(t, sqlmock.NewRows([]string{"field1", "field2"}).AddRow("111", 222)), &res) 153 | assert.NoError(t, err) 154 | assert.EqualValues(t, &testObj{"111", 222}, res) 155 | } 156 | 157 | func Test_Load_StructSlice(t *testing.T) { 158 | t.Parallel() 159 | var res []testObj 160 | _, err := Load(sqlRows(t, sqlmock.NewRows([]string{"field1", "field2"}).AddRow("111", 222).AddRow("222", 333)), &res) 161 | assert.NoError(t, err) 162 | assert.EqualValues(t, []testObj{{"111", 222}, {"222", 333}}, res) 163 | } 164 | 165 | func Test_Load_StructSlicePtr(t *testing.T) { 166 | t.Parallel() 167 | var expected, actual []*testObj 168 | _, err := Load(sqlRows(t, sqlmock.NewRows([]string{"field1", "field2"}).AddRow("0", 0).AddRow("1", 1)), &actual) 169 | assert.NoError(t, err) 170 | for k := range make([]int, 2) { 171 | k := k 172 | expected = append(expected, &testObj{fmt.Sprint(k), k}) 173 | } 174 | assert.EqualValues(t, expected, actual) 175 | } 176 | 
// direction is the sort direction of an ORDER BY term
type direction bool

// orderby directions
// most databases by default use asc
const (
	asc direction = false
	// the type is repeated on purpose: a constant with its own value does not
	// inherit the type of the previous one and would be an untyped bool
	desc direction = true
)
SelectStmt 8 | Distinct() SelectStmt 9 | Prewhere(query interface{}, value ...interface{}) SelectStmt 10 | Where(query interface{}, value ...interface{}) SelectStmt 11 | Having(query interface{}, value ...interface{}) SelectStmt 12 | GroupBy(col ...string) SelectStmt 13 | OrderAsc(col string) SelectStmt 14 | OrderDesc(col string) SelectStmt 15 | Limit(n uint64) SelectStmt 16 | Offset(n uint64) SelectStmt 17 | ForUpdate() SelectStmt 18 | SkipLocked() SelectStmt 19 | Join(table, on interface{}) SelectStmt 20 | LeftJoin(table, on interface{}) SelectStmt 21 | RightJoin(table, on interface{}) SelectStmt 22 | FullJoin(table, on interface{}) SelectStmt 23 | AddComment(text string) SelectStmt 24 | As(alias string) Builder 25 | } 26 | 27 | type selectStmt struct { 28 | raw 29 | 30 | IsDistinct bool 31 | 32 | Column []interface{} 33 | Table interface{} 34 | JoinTable []Builder 35 | 36 | Comment []Builder 37 | PrewhereCond []Builder 38 | WhereCond []Builder 39 | Group []Builder 40 | HavingCond []Builder 41 | Order []Builder 42 | 43 | LimitCount int64 44 | OffsetCount int64 45 | IsForUpdate bool 46 | IsSkipLocked bool 47 | } 48 | 49 | // Build builds `SELECT ...` in dialect 50 | func (b *selectStmt) Build(d Dialect, buf Buffer) error { 51 | if b.raw.Query != "" { 52 | return b.raw.Build(d, buf) 53 | } 54 | 55 | if len(b.Column) == 0 { 56 | return ErrColumnNotSpecified 57 | } 58 | 59 | if len(b.Comment) > 0 { 60 | for _, comm := range b.Comment { 61 | buf.WriteString("/* ") 62 | err := comm.Build(d, buf) 63 | if err != nil { 64 | return err 65 | } 66 | buf.WriteString(" */") 67 | } 68 | } 69 | 70 | buf.WriteString("SELECT ") 71 | 72 | if b.IsDistinct { 73 | buf.WriteString("DISTINCT ") 74 | } 75 | 76 | for i, col := range b.Column { 77 | if i > 0 { 78 | buf.WriteString(", ") 79 | } 80 | switch col := col.(type) { 81 | case string: 82 | buf.WriteString(col) 83 | default: 84 | buf.WriteString(placeholder) 85 | buf.WriteValue(col) 86 | } 87 | } 88 | 89 | if b.Table != nil { 90 | 
buf.WriteString(" FROM ") 91 | switch table := b.Table.(type) { 92 | case string: 93 | buf.WriteString(table) 94 | default: 95 | buf.WriteString(placeholder) 96 | buf.WriteValue(table) 97 | } 98 | if len(b.JoinTable) > 0 { 99 | for _, join := range b.JoinTable { 100 | err := join.Build(d, buf) 101 | if err != nil { 102 | return err 103 | } 104 | } 105 | } 106 | } 107 | 108 | if len(b.PrewhereCond) > 0 { 109 | keyword := d.Prewhere() 110 | if len(keyword) == 0 { 111 | return ErrPrewhereNotSupported 112 | } 113 | 114 | buf.WriteString(" ") 115 | buf.WriteString(keyword) 116 | buf.WriteString(" ") 117 | err := And(b.PrewhereCond...).Build(d, buf) 118 | if err != nil { 119 | return err 120 | } 121 | } 122 | 123 | if len(b.WhereCond) > 0 { 124 | buf.WriteString(" WHERE ") 125 | err := And(b.WhereCond...).Build(d, buf) 126 | if err != nil { 127 | return err 128 | } 129 | } 130 | 131 | if len(b.Group) > 0 { 132 | buf.WriteString(" GROUP BY ") 133 | for i, group := range b.Group { 134 | if i > 0 { 135 | buf.WriteString(", ") 136 | } 137 | err := group.Build(d, buf) 138 | if err != nil { 139 | return err 140 | } 141 | } 142 | } 143 | 144 | if len(b.HavingCond) > 0 { 145 | buf.WriteString(" HAVING ") 146 | err := And(b.HavingCond...).Build(d, buf) 147 | if err != nil { 148 | return err 149 | } 150 | } 151 | 152 | if len(b.Order) > 0 { 153 | buf.WriteString(" ORDER BY ") 154 | for i, order := range b.Order { 155 | if i > 0 { 156 | buf.WriteString(", ") 157 | } 158 | err := order.Build(d, buf) 159 | if err != nil { 160 | return err 161 | } 162 | } 163 | } 164 | 165 | if b.LimitCount >= 0 { 166 | buf.WriteString(" ") 167 | buf.WriteString(d.Limit(b.OffsetCount, b.LimitCount)) 168 | } 169 | 170 | if b.IsForUpdate { 171 | buf.WriteString(" FOR UPDATE") 172 | } 173 | 174 | if b.IsSkipLocked { 175 | buf.WriteString(" SKIP LOCKED") 176 | } 177 | 178 | return nil 179 | } 180 | 181 | // Select creates a SelectStmt 182 | func Select(column ...interface{}) SelectStmt { 183 | return 
createSelectStmt(column) 184 | } 185 | 186 | func createSelectStmt(column []interface{}) *selectStmt { 187 | return &selectStmt{ 188 | Column: column, 189 | LimitCount: -1, 190 | OffsetCount: -1, 191 | } 192 | } 193 | 194 | // From specifies table 195 | func (b *selectStmt) From(table interface{}) SelectStmt { 196 | b.Table = table 197 | return b 198 | } 199 | 200 | // SelectBySql creates a SelectStmt from raw query 201 | func SelectBySql(query string, value ...interface{}) SelectStmt { 202 | return createSelectStmtBySQL(query, value) 203 | } 204 | 205 | func createSelectStmtBySQL(query string, value []interface{}) *selectStmt { 206 | return &selectStmt{ 207 | raw: raw{ 208 | Query: query, 209 | Value: value, 210 | }, 211 | LimitCount: -1, 212 | OffsetCount: -1, 213 | } 214 | } 215 | 216 | // Distinct adds `DISTINCT` 217 | func (b *selectStmt) Distinct() SelectStmt { 218 | b.IsDistinct = true 219 | return b 220 | } 221 | 222 | // Prewhere adds a prewhere condition 223 | // For example clickhouse PREWHERE: 224 | // https://clickhouse.yandex/docs/en/query_language/select/#prewhere-clause 225 | func (b *selectStmt) Prewhere(query interface{}, value ...interface{}) SelectStmt { 226 | switch query := query.(type) { 227 | case string: 228 | b.PrewhereCond = append(b.PrewhereCond, Expr(query, value...)) 229 | case Builder: 230 | b.PrewhereCond = append(b.PrewhereCond, query) 231 | } 232 | return b 233 | } 234 | 235 | // Where adds a where condition 236 | func (b *selectStmt) Where(query interface{}, value ...interface{}) SelectStmt { 237 | switch query := query.(type) { 238 | case string: 239 | b.WhereCond = append(b.WhereCond, Expr(query, value...)) 240 | case Builder: 241 | b.WhereCond = append(b.WhereCond, query) 242 | } 243 | return b 244 | } 245 | 246 | // Having adds a having condition 247 | func (b *selectStmt) Having(query interface{}, value ...interface{}) SelectStmt { 248 | switch query := query.(type) { 249 | case string: 250 | b.HavingCond = 
append(b.HavingCond, Expr(query, value...)) 251 | case Builder: 252 | b.HavingCond = append(b.HavingCond, query) 253 | } 254 | return b 255 | } 256 | 257 | // GroupBy specifies columns for grouping 258 | func (b *selectStmt) GroupBy(col ...string) SelectStmt { 259 | for _, group := range col { 260 | b.Group = append(b.Group, Expr(group)) 261 | } 262 | return b 263 | } 264 | 265 | // OrderAsc specifies columns for ordering in asc direction 266 | func (b *selectStmt) OrderAsc(col string) SelectStmt { 267 | b.Order = append(b.Order, order(col, asc)) 268 | return b 269 | } 270 | 271 | // OrderDesc specifies columns for ordering in desc direction 272 | func (b *selectStmt) OrderDesc(col string) SelectStmt { 273 | b.Order = append(b.Order, order(col, desc)) 274 | return b 275 | } 276 | 277 | // Limit adds LIMIT 278 | func (b *selectStmt) Limit(n uint64) SelectStmt { 279 | b.LimitCount = int64(n) 280 | return b 281 | } 282 | 283 | // Offset adds OFFSET, works only if LIMIT is set 284 | func (b *selectStmt) Offset(n uint64) SelectStmt { 285 | b.OffsetCount = int64(n) 286 | return b 287 | } 288 | 289 | // ForUpdate adds `FOR UPDATE` 290 | func (b *selectStmt) ForUpdate() SelectStmt { 291 | b.IsForUpdate = true 292 | return b 293 | } 294 | 295 | // SkipLocked adds `SKIP LOCKED` 296 | func (b *selectStmt) SkipLocked() SelectStmt { 297 | b.IsSkipLocked = true 298 | return b 299 | } 300 | 301 | // Join joins table on condition 302 | func (b *selectStmt) Join(table, on interface{}) SelectStmt { 303 | b.JoinTable = append(b.JoinTable, join(inner, table, on)) 304 | return b 305 | } 306 | 307 | // LeftJoin joins table on condition via LEFT JOIN 308 | func (b *selectStmt) LeftJoin(table, on interface{}) SelectStmt { 309 | b.JoinTable = append(b.JoinTable, join(left, table, on)) 310 | return b 311 | } 312 | 313 | // RightJoin joins table on condition via RIGHT JOIN 314 | func (b *selectStmt) RightJoin(table, on interface{}) SelectStmt { 315 | b.JoinTable = append(b.JoinTable, 
join(right, table, on)) 316 | return b 317 | } 318 | 319 | // FullJoin joins table on condition via FULL JOIN 320 | func (b *selectStmt) FullJoin(table, on interface{}) SelectStmt { 321 | b.JoinTable = append(b.JoinTable, join(full, table, on)) 322 | return b 323 | } 324 | 325 | // AddComment adds a comment at the beginning of the query 326 | func (b *selectStmt) AddComment(comment string) SelectStmt { 327 | b.Comment = append(b.Comment, Expr(comment)) 328 | return b 329 | } 330 | 331 | // As creates alias for select statement 332 | func (b *selectStmt) As(alias string) Builder { 333 | return as(b, alias) 334 | } 335 | -------------------------------------------------------------------------------- /select_builder.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "reflect" 7 | "time" 8 | ) 9 | 10 | // SelectBuilder build "SELECT" stmt 11 | type SelectBuilder interface { 12 | Builder 13 | EventReceiver 14 | loader 15 | typesLoader 16 | 17 | As(alias string) Builder 18 | Comment(text string) SelectBuilder 19 | Distinct() SelectBuilder 20 | ForUpdate() SelectBuilder 21 | From(table interface{}) SelectBuilder 22 | FullJoin(table, on interface{}) SelectBuilder 23 | GroupBy(col ...string) SelectBuilder 24 | Having(query interface{}, value ...interface{}) SelectBuilder 25 | InTimezone(loc *time.Location) SelectBuilder 26 | Join(table, on interface{}) SelectBuilder 27 | LeftJoin(table, on interface{}) SelectBuilder 28 | Limit(n uint64) SelectBuilder 29 | Offset(n uint64) SelectBuilder 30 | OrderAsc(col string) SelectBuilder 31 | OrderBy(col string) SelectBuilder 32 | OrderDesc(col string) SelectBuilder 33 | OrderDir(col string, isAsc bool) SelectBuilder 34 | Paginate(page, perPage uint64) SelectBuilder 35 | Prewhere(query interface{}, value ...interface{}) SelectBuilder 36 | RightJoin(table, on interface{}) SelectBuilder 37 | SkipLocked() SelectBuilder 38 | Where(query 
// prepareSelect converts column names to the []interface{} form expected by
// createSelectStmt. The result is never nil and keeps the order of a.
func prepareSelect(a []string) []interface{} {
	columns := make([]interface{}, 0, len(a))
	for _, name := range a {
		columns = append(columns, name)
	}
	return columns
}
:= 0; i < v.Len(); i++ { 110 | b.changeTimezone(v.Index(i)) 111 | } 112 | case reflect.Map: 113 | // TODO: add timezone changing for map keys 114 | for _, k := range v.MapKeys() { 115 | b.changeTimezone(v.MapIndex(k)) 116 | } 117 | case reflect.Struct: 118 | if v.Type() == reflect.TypeOf(time.Time{}) { 119 | v.Set(reflect.ValueOf(v.Interface().(time.Time).In(b.timezone))) 120 | return 121 | } 122 | 123 | for i := 0; i < v.NumField(); i++ { 124 | b.changeTimezone(v.Field(i)) 125 | } 126 | } 127 | } 128 | 129 | func (b *selectBuilder) Build(d Dialect, buf Buffer) error { 130 | return b.selectStmt.Build(d, buf) 131 | } 132 | 133 | // Load loads any value from query result with background context 134 | func (b *selectBuilder) Load(value interface{}) (int, error) { 135 | return b.LoadContext(b.ctx, value) 136 | } 137 | 138 | // LoadContext loads any value from query result 139 | func (b *selectBuilder) LoadContext(ctx context.Context, value interface{}) (int, error) { 140 | c, err := query(ctx, b.runner, b.EventReceiver, b, b.Dialect, value) 141 | if err == nil && b.timezone != nil { 142 | b.changeTimezone(reflect.ValueOf(value)) 143 | } 144 | return c, err 145 | } 146 | 147 | // LoadStruct loads struct from query result with background context, returns ErrNotFound if there is no result 148 | func (b *selectBuilder) LoadStruct(value interface{}) error { 149 | return b.LoadStructContext(b.ctx, value) 150 | } 151 | 152 | // LoadStructContext loads struct from query result, returns ErrNotFound if there is no result 153 | func (b *selectBuilder) LoadStructContext(ctx context.Context, value interface{}) error { 154 | count, err := query(ctx, b.runner, b.EventReceiver, b, b.Dialect, value) 155 | if err != nil { 156 | return err 157 | } 158 | if count == 0 { 159 | return ErrNotFound 160 | } 161 | if b.timezone != nil { 162 | b.changeTimezone(reflect.ValueOf(value)) 163 | } 164 | return nil 165 | } 166 | 167 | // LoadStructs loads structures from query result with background 
context 168 | func (b *selectBuilder) LoadStructs(value interface{}) (int, error) { 169 | return b.LoadStructsContext(b.ctx, value) 170 | } 171 | 172 | // LoadStructsContext loads structures from query result 173 | func (b *selectBuilder) LoadStructsContext(ctx context.Context, value interface{}) (int, error) { 174 | c, err := query(ctx, b.runner, b.EventReceiver, b, b.Dialect, value) 175 | if err == nil && b.timezone != nil { 176 | b.changeTimezone(reflect.ValueOf(value)) 177 | } 178 | return c, err 179 | } 180 | 181 | // GetRows returns sql.Rows from query result. 182 | func (b *selectBuilder) GetRows() (*sql.Rows, error) { 183 | return b.GetRowsContext(b.ctx) 184 | } 185 | 186 | // GetRowsContext returns sql.Rows from query result. 187 | func (b *selectBuilder) GetRowsContext(ctx context.Context) (*sql.Rows, error) { 188 | rows, _, err := queryRows(ctx, b.runner, b.EventReceiver, b, b.Dialect) 189 | 190 | return rows, err 191 | } 192 | 193 | // LoadValue loads any value from query result with background context, returns ErrNotFound if there is no result 194 | func (b *selectBuilder) LoadValue(value interface{}) error { 195 | return b.LoadValueContext(b.ctx, value) 196 | } 197 | 198 | // LoadValueContext loads any value from query result, returns ErrNotFound if there is no result 199 | func (b *selectBuilder) LoadValueContext(ctx context.Context, value interface{}) error { 200 | count, err := query(ctx, b.runner, b.EventReceiver, b, b.Dialect, value) 201 | if err != nil { 202 | return err 203 | } 204 | if count == 0 { 205 | return ErrNotFound 206 | } 207 | if b.timezone != nil { 208 | b.changeTimezone(reflect.ValueOf(value)) 209 | } 210 | return nil 211 | } 212 | 213 | // LoadValues loads any values from query result with background context 214 | func (b *selectBuilder) LoadValues(value interface{}) (int, error) { 215 | return b.LoadValuesContext(b.ctx, value) 216 | } 217 | 218 | // LoadValuesContext loads any values from query result 219 | func (b 
*selectBuilder) LoadValuesContext(ctx context.Context, value interface{}) (int, error) { 220 | c, err := query(ctx, b.runner, b.EventReceiver, b, b.Dialect, value) 221 | if err == nil && b.timezone != nil { 222 | b.changeTimezone(reflect.ValueOf(value)) 223 | } 224 | return c, err 225 | } 226 | 227 | // Join joins table on condition 228 | func (b *selectBuilder) Join(table, on interface{}) SelectBuilder { 229 | b.selectStmt.Join(table, on) 230 | return b 231 | } 232 | 233 | // LeftJoin joins table on condition via LEFT JOIN 234 | func (b *selectBuilder) LeftJoin(table, on interface{}) SelectBuilder { 235 | b.selectStmt.LeftJoin(table, on) 236 | return b 237 | } 238 | 239 | // RightJoin joins table on condition via RIGHT JOIN 240 | func (b *selectBuilder) RightJoin(table, on interface{}) SelectBuilder { 241 | b.selectStmt.RightJoin(table, on) 242 | return b 243 | } 244 | 245 | // FullJoin joins table on condition via FULL JOIN 246 | func (b *selectBuilder) FullJoin(table, on interface{}) SelectBuilder { 247 | b.selectStmt.FullJoin(table, on) 248 | return b 249 | } 250 | 251 | // Distinct adds `DISTINCT` 252 | func (b *selectBuilder) Distinct() SelectBuilder { 253 | b.selectStmt.Distinct() 254 | return b 255 | } 256 | 257 | // From specifies table 258 | func (b *selectBuilder) From(table interface{}) SelectBuilder { 259 | b.selectStmt.From(table) 260 | return b 261 | } 262 | 263 | // GroupBy specifies columns for grouping 264 | func (b *selectBuilder) GroupBy(col ...string) SelectBuilder { 265 | b.selectStmt.GroupBy(col...) 266 | return b 267 | } 268 | 269 | // Having adds a having condition 270 | func (b *selectBuilder) Having(query interface{}, value ...interface{}) SelectBuilder { 271 | b.selectStmt.Having(query, value...) 
272 | return b 273 | } 274 | 275 | // Limit adds LIMIT 276 | func (b *selectBuilder) Limit(n uint64) SelectBuilder { 277 | b.selectStmt.Limit(n) 278 | return b 279 | } 280 | 281 | // Offset adds OFFSET, works only if LIMIT is set 282 | func (b *selectBuilder) Offset(n uint64) SelectBuilder { 283 | b.selectStmt.Offset(n) 284 | return b 285 | } 286 | 287 | // OrderDir specifies columns for ordering in direction 288 | func (b *selectBuilder) OrderDir(col string, isAsc bool) SelectBuilder { 289 | if isAsc { 290 | b.selectStmt.OrderAsc(col) 291 | } else { 292 | b.selectStmt.OrderDesc(col) 293 | } 294 | return b 295 | } 296 | 297 | // Paginate adds LIMIT and OFFSET 298 | func (b *selectBuilder) Paginate(page, perPage uint64) SelectBuilder { 299 | b.Limit(perPage) 300 | b.Offset((page - 1) * perPage) 301 | return b 302 | } 303 | 304 | // OrderBy specifies column for ordering 305 | func (b *selectBuilder) OrderBy(col string) SelectBuilder { 306 | b.selectStmt.Order = append(b.selectStmt.Order, Expr(col)) 307 | return b 308 | } 309 | 310 | // Where adds a where condition 311 | func (b *selectBuilder) Prewhere(query interface{}, value ...interface{}) SelectBuilder { 312 | b.selectStmt.Prewhere(query, value...) 313 | return b 314 | } 315 | 316 | // Where adds a where condition 317 | func (b *selectBuilder) Where(query interface{}, value ...interface{}) SelectBuilder { 318 | b.selectStmt.Where(query, value...) 319 | return b 320 | } 321 | 322 | // ForUpdate adds lock via FOR UPDATE 323 | func (b *selectBuilder) ForUpdate() SelectBuilder { 324 | b.selectStmt.ForUpdate() 325 | return b 326 | } 327 | 328 | // SkipLocked skips locked rows via SKIP LOCKED 329 | func (b *selectBuilder) SkipLocked() SelectBuilder { 330 | b.selectStmt.SkipLocked() 331 | return b 332 | } 333 | 334 | // InTimezone all time.Time fields in the result will be returned with the specified location. 
335 | func (b *selectBuilder) InTimezone(loc *time.Location) SelectBuilder { 336 | b.timezone = loc 337 | return b 338 | } 339 | 340 | func (b *selectBuilder) OrderAsc(col string) SelectBuilder { 341 | b.selectStmt.OrderAsc(col) 342 | return b 343 | } 344 | 345 | func (b *selectBuilder) OrderDesc(col string) SelectBuilder { 346 | b.selectStmt.OrderDesc(col) 347 | return b 348 | } 349 | 350 | // As creates alias for select statement 351 | func (b *selectBuilder) As(alias string) Builder { 352 | return b.selectStmt.As(alias) 353 | } 354 | 355 | // Comment adds a comment at the beginning of the query 356 | func (b *selectBuilder) Comment(text string) SelectBuilder { 357 | b.selectStmt.AddComment(text) 358 | return b 359 | } 360 | -------------------------------------------------------------------------------- /select_builder_test.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | "time" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | type innerTestStruct struct { 12 | InnerTime time.Time 13 | InnerNonTime int64 14 | } 15 | 16 | type testStruct struct { 17 | innerTestStruct 18 | Time time.Time 19 | Inner innerTestStruct 20 | InnerPtr *innerTestStruct 21 | InnerSlice []innerTestStruct 22 | InnerSlicePtr []*innerTestStruct 23 | InnerMap map[int]*innerTestStruct 24 | } 25 | 26 | func TestChangeTimezone(t *testing.T) { 27 | location := "America/New_York" 28 | 29 | v := testStruct{ 30 | innerTestStruct: innerTestStruct{ 31 | InnerTime: time.Date(2020, 1, 20, 8, 0, 0, 0, time.UTC), 32 | }, 33 | Time: time.Date(2020, 1, 21, 9, 0, 0, 0, time.UTC), 34 | Inner: innerTestStruct{ 35 | InnerTime: time.Date(2020, 1, 22, 10, 0, 0, 0, time.UTC), 36 | }, 37 | InnerPtr: &innerTestStruct{ 38 | InnerTime: time.Date(2020, 1, 23, 11, 0, 0, 0, time.UTC), 39 | }, 40 | InnerSlice: []innerTestStruct{ 41 | {InnerTime: time.Date(2020, 1, 24, 12, 0, 0, 0, time.UTC)}, 42 | {InnerTime: 
time.Date(2020, 1, 25, 13, 0, 0, 0, time.UTC)}, 43 | }, 44 | InnerSlicePtr: []*innerTestStruct{ 45 | {InnerTime: time.Date(2020, 1, 26, 14, 0, 0, 0, time.UTC)}, 46 | {InnerTime: time.Date(2020, 1, 27, 15, 0, 0, 0, time.UTC)}, 47 | }, 48 | InnerMap: map[int]*innerTestStruct{ 49 | 1: {InnerTime: time.Date(2020, 1, 28, 16, 0, 0, 0, time.UTC)}, 50 | 2: {InnerTime: time.Date(2020, 1, 28, 16, 0, 0, 0, time.UTC)}, 51 | }, 52 | } 53 | 54 | b := &selectBuilder{} 55 | l, _ := time.LoadLocation(location) 56 | b.InTimezone(l) 57 | b.changeTimezone(reflect.ValueOf(&v)) 58 | 59 | assert.Equal(t, "America/New_York", v.InnerTime.Location().String()) 60 | assert.Equal(t, "America/New_York", v.Time.Location().String()) 61 | assert.Equal(t, "America/New_York", v.Inner.InnerTime.Location().String()) 62 | assert.Equal(t, "America/New_York", v.InnerPtr.InnerTime.Location().String()) 63 | for _, tt := range v.InnerSlice { 64 | assert.Equal(t, "America/New_York", tt.InnerTime.Location().String()) 65 | } 66 | 67 | for _, tt := range v.InnerSlicePtr { 68 | assert.Equal(t, "America/New_York", tt.InnerTime.Location().String()) 69 | } 70 | 71 | for _, tt := range v.InnerMap { 72 | assert.Equal(t, "America/New_York", tt.InnerTime.Location().String()) 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /select_return.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | // 4 | // These are a set of helpers that just call LoadValue and return the value. 5 | // They return (_, ErrNotFound) if nothing was found. 
6 | // 7 | 8 | // The inclusion of these helpers in the package is not an obvious choice: 9 | // Benefits: 10 | // - slight increase in code clarity/conciseness b/c you can use ":=" to define the variable 11 | // 12 | // count, err := d.Select("COUNT(*)").From("users").Where("x = ?", x).ReturnInt64() 13 | // 14 | // vs 15 | // 16 | // var count int64 17 | // err := d.Select("COUNT(*)").From("users").Where("x = ?", x).LoadValue(&count) 18 | // 19 | // Downsides: 20 | // - very small increase in code cost, although it's not complex code 21 | // - increase in conceptual model / API footprint when presenting the package to new users 22 | // - no functionality that you can't achieve calling .LoadValue directly. 23 | // - There's a lot of possible types. Do we want to include ALL of them? u?int{8,16,32,64}?, strings, null varieties, etc. 24 | // - Let's just do the common, non-null varieties. 25 | 26 | type typesLoader interface { 27 | ReturnInt64() (int64, error) 28 | ReturnInt64s() ([]int64, error) 29 | ReturnUint64() (uint64, error) 30 | ReturnUint64s() ([]uint64, error) 31 | ReturnString() (string, error) 32 | ReturnStrings() ([]string, error) 33 | } 34 | 35 | // ReturnInt64 executes the SelectStmt and returns the value as an int64 36 | func (b *selectBuilder) ReturnInt64() (int64, error) { 37 | var v int64 38 | err := b.LoadValue(&v) 39 | return v, err 40 | } 41 | 42 | // ReturnInt64s executes the SelectStmt and returns the value as a slice of int64s 43 | func (b *selectBuilder) ReturnInt64s() ([]int64, error) { 44 | var v []int64 45 | _, err := b.LoadValues(&v) 46 | return v, err 47 | } 48 | 49 | // ReturnUint64 executes the SelectStmt and returns the value as an uint64 50 | func (b *selectBuilder) ReturnUint64() (uint64, error) { 51 | var v uint64 52 | err := b.LoadValue(&v) 53 | return v, err 54 | } 55 | 56 | // ReturnUint64s executes the SelectStmt and returns the value as a slice of uint64s 57 | func (b *selectBuilder) ReturnUint64s() ([]uint64, error) { 58 | 
var v []uint64 59 | _, err := b.LoadValues(&v) 60 | return v, err 61 | } 62 | 63 | // ReturnString executes the SelectStmt and returns the value as a string 64 | func (b *selectBuilder) ReturnString() (string, error) { 65 | var v string 66 | err := b.LoadValue(&v) 67 | return v, err 68 | } 69 | 70 | // ReturnStrings executes the SelectStmt and returns the value as a slice of strings 71 | func (b *selectBuilder) ReturnStrings() ([]string, error) { 72 | var v []string 73 | _, err := b.LoadValues(&v) 74 | return v, err 75 | } 76 | -------------------------------------------------------------------------------- /select_test.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/mailru/dbr/dialect" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestSelectStmt(t *testing.T) { 11 | bufClickHouse := NewBuffer() 12 | bufMySQL := NewBuffer() 13 | builder := Select("a", "b"). 14 | AddComment("zzz"). 15 | From(Select("a").From("table")). 16 | LeftJoin("table2", "table.a1 = table.a2"). 17 | Distinct(). 18 | Prewhere(Eq("c1", 15)). 19 | Where(Eq("c2", 1)). 20 | GroupBy("d"). 21 | Having(Eq("e", 2)). 22 | OrderAsc("f"). 23 | Limit(3). 24 | Offset(4). 25 | ForUpdate(). 26 | SkipLocked() 27 | 28 | err := builder.Build(dialect.ClickHouse, bufClickHouse) // because this lib is clickhouse first. 29 | assert.NoError(t, err) 30 | assert.Equal(t, "/* zzz */SELECT DISTINCT a, b FROM ? LEFT JOIN `table2` ON table.a1 = table.a2 PREWHERE (`c1` = ?) WHERE (`c2` = ?) GROUP BY d HAVING (`e` = ?) 
ORDER BY f ASC LIMIT 4,3 FOR UPDATE SKIP LOCKED", bufClickHouse.String()) 31 | assert.Equal(t, 4, len(bufClickHouse.Value())) 32 | 33 | err = builder.Build(dialect.MySQL, bufMySQL) 34 | assert.EqualError(t, err, ErrPrewhereNotSupported.Error()) // handle PREWHERE statement error 35 | } 36 | 37 | func BenchmarkSelectSQL(b *testing.B) { 38 | buf := NewBuffer() 39 | for i := 0; i < b.N; i++ { 40 | Select("a", "b").From("table").Where(Eq("c", 1)).OrderAsc("d").Build(dialect.MySQL, buf) 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /transaction.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | ) 7 | 8 | // Tx is a transaction for the given Session 9 | type Tx struct { 10 | EventReceiver 11 | Dialect Dialect 12 | *sql.Tx 13 | ctx context.Context 14 | } 15 | 16 | // Begin creates a transaction for the given session 17 | func (sess *Session) Begin() (*Tx, error) { 18 | return sess.BeginWithOpts(&sql.TxOptions{}) 19 | } 20 | 21 | // BeginWithOpts creates a transaction for the given section with ability to set TxOpts 22 | func (sess *Session) BeginWithOpts(opts *sql.TxOptions) (*Tx, error) { 23 | tx, err := sess.beginTx(opts) 24 | if err != nil { 25 | return nil, sess.EventErr("dbr.begin.error", err) 26 | } 27 | sess.Event("dbr.begin") 28 | 29 | return &Tx{ 30 | EventReceiver: sess.EventReceiver, 31 | Dialect: sess.Dialect, 32 | Tx: tx, 33 | ctx: sess.ctx, 34 | }, nil 35 | } 36 | 37 | // Commit finishes the transaction 38 | func (tx *Tx) Commit() error { 39 | err := tx.Tx.Commit() 40 | if err != nil { 41 | return tx.EventErr("dbr.commit.error", err) 42 | } 43 | tx.Event("dbr.commit") 44 | return nil 45 | } 46 | 47 | // Rollback cancels the transaction 48 | func (tx *Tx) Rollback() error { 49 | err := tx.Tx.Rollback() 50 | if err != nil { 51 | return tx.EventErr("dbr.rollback", err) 52 | } 53 | tx.Event("dbr.rollback") 
54 | return nil 55 | } 56 | 57 | // RollbackUnlessCommitted rollsback the transaction unless it has already been committed or rolled back. 58 | // Useful to defer tx.RollbackUnlessCommitted() -- so you don't have to handle N failure cases 59 | // Keep in mind the only way to detect an error on the rollback is via the event log. 60 | func (tx *Tx) RollbackUnlessCommitted() { 61 | err := tx.Tx.Rollback() 62 | if err == sql.ErrTxDone { 63 | // ok 64 | } else if err != nil { 65 | tx.EventErr("dbr.rollback_unless_committed", err) 66 | } else { 67 | tx.Event("dbr.rollback") 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /transaction_test.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | import ( 4 | "github.com/mailru/dbr/dialect" 5 | 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestTransactionCommit(t *testing.T) { 12 | for _, sess := range testSession { 13 | tx, err := sess.Begin() 14 | assert.NoError(t, err) 15 | defer tx.RollbackUnlessCommitted() 16 | 17 | id := nextID() 18 | 19 | result, err := tx.InsertInto("dbr_people").Columns("id", "name", "email").Values(id, "Barack", "obama@whitehouse.gov").Exec() 20 | assert.NoError(t, err) 21 | 22 | rowsAffected, err := result.RowsAffected() 23 | // not all drivers supports RowsAffected 24 | if err == nil { 25 | assert.EqualValues(t, 1, rowsAffected) 26 | } 27 | 28 | err = tx.Commit() 29 | assert.NoError(t, err) 30 | 31 | var person person 32 | err = tx.Select("*").From("dbr_people").Where(Eq("id", id)).LoadStruct(&person) 33 | assert.Error(t, err) 34 | } 35 | } 36 | 37 | func TestTransactionRollback(t *testing.T) { 38 | for _, sess := range testSession { 39 | if sess.Dialect == dialect.ClickHouse { 40 | // clickhouse does not support transactions 41 | continue 42 | } 43 | tx, err := sess.Begin() 44 | assert.NoError(t, err) 45 | defer tx.RollbackUnlessCommitted() 46 | 47 | id := 
nextID() 48 | 49 | result, err := tx.InsertInto("dbr_people").Columns("id", "name", "email").Values(id, "Barack", "obama@whitehouse.gov").Exec() 50 | assert.NoError(t, err) 51 | 52 | rowsAffected, err := result.RowsAffected() 53 | assert.NoError(t, err) 54 | assert.EqualValues(t, 1, rowsAffected) 55 | 56 | err = tx.Rollback() 57 | assert.NoError(t, err) 58 | 59 | var person person 60 | err = tx.Select("*").From("dbr_people").Where(Eq("id", id)).LoadStruct(&person) 61 | assert.Error(t, err) 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /types.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | import ( 4 | "bytes" 5 | "database/sql" 6 | "database/sql/driver" 7 | "encoding/json" 8 | "time" 9 | ) 10 | 11 | // 12 | // Your app can use these Null types instead of the defaults. The sole benefit you get is a MarshalJSON method that is not retarded. 13 | // 14 | 15 | // NullString is a type that can be null or a string 16 | type NullString struct { 17 | sql.NullString 18 | } 19 | 20 | // NullFloat64 is a type that can be null or a float64 21 | type NullFloat64 struct { 22 | sql.NullFloat64 23 | } 24 | 25 | // NullInt64 is a type that can be null or an int 26 | type NullInt64 struct { 27 | sql.NullInt64 28 | } 29 | 30 | // NullTime is a type that can be null or a time 31 | type NullTime struct { 32 | Time time.Time 33 | Valid bool // Valid is true if Time is not NULL 34 | } 35 | 36 | // Value implements the driver Valuer interface. 
37 | func (n NullTime) Value() (driver.Value, error) { 38 | if !n.Valid { 39 | return nil, nil 40 | } 41 | return n.Time, nil 42 | } 43 | 44 | // NullBool is a type that can be null or a bool 45 | type NullBool struct { 46 | sql.NullBool 47 | } 48 | 49 | var nullString = []byte("null") 50 | 51 | // MarshalJSON correctly serializes a NullString to JSON 52 | func (n NullString) MarshalJSON() ([]byte, error) { 53 | if n.Valid { 54 | return json.Marshal(n.String) 55 | } 56 | return nullString, nil 57 | } 58 | 59 | // MarshalJSON correctly serializes a NullInt64 to JSON 60 | func (n NullInt64) MarshalJSON() ([]byte, error) { 61 | if n.Valid { 62 | return json.Marshal(n.Int64) 63 | } 64 | return nullString, nil 65 | } 66 | 67 | // MarshalJSON correctly serializes a NullFloat64 to JSON 68 | func (n NullFloat64) MarshalJSON() ([]byte, error) { 69 | if n.Valid { 70 | return json.Marshal(n.Float64) 71 | } 72 | return nullString, nil 73 | } 74 | 75 | // MarshalJSON correctly serializes a NullTime to JSON 76 | func (n NullTime) MarshalJSON() ([]byte, error) { 77 | if n.Valid { 78 | return json.Marshal(n.Time) 79 | } 80 | return nullString, nil 81 | } 82 | 83 | // MarshalJSON correctly serializes a NullBool to JSON 84 | func (n NullBool) MarshalJSON() ([]byte, error) { 85 | if n.Valid { 86 | return json.Marshal(n.Bool) 87 | } 88 | return nullString, nil 89 | } 90 | 91 | // UnmarshalJSON correctly deserializes a NullString from JSON 92 | func (n *NullString) UnmarshalJSON(b []byte) error { 93 | var s interface{} 94 | if err := json.Unmarshal(b, &s); err != nil { 95 | return err 96 | } 97 | return n.Scan(s) 98 | } 99 | 100 | // UnmarshalJSON correctly deserializes a NullInt64 from JSON 101 | func (n *NullInt64) UnmarshalJSON(b []byte) error { 102 | var s interface{} 103 | if err := json.Unmarshal(b, &s); err != nil { 104 | return err 105 | } 106 | return n.Scan(s) 107 | } 108 | 109 | // UnmarshalJSON correctly deserializes a NullFloat64 from JSON 110 | func (n *NullFloat64) 
UnmarshalJSON(b []byte) error { 111 | var s interface{} 112 | if err := json.Unmarshal(b, &s); err != nil { 113 | return err 114 | } 115 | return n.Scan(s) 116 | } 117 | 118 | // UnmarshalJSON correctly deserializes a NullTime from JSON 119 | func (n *NullTime) UnmarshalJSON(b []byte) error { 120 | // scan for null 121 | if bytes.Equal(b, nullString) { 122 | return n.Scan(nil) 123 | } 124 | // scan for JSON timestamp 125 | var t time.Time 126 | if err := json.Unmarshal(b, &t); err != nil { 127 | return err 128 | } 129 | return n.Scan(t) 130 | } 131 | 132 | // UnmarshalJSON correctly deserializes a NullBool from JSON 133 | func (n *NullBool) UnmarshalJSON(b []byte) error { 134 | var s interface{} 135 | if err := json.Unmarshal(b, &s); err != nil { 136 | return err 137 | } 138 | return n.Scan(s) 139 | } 140 | 141 | // NewNullInt64 create a NullInt64 from v 142 | func NewNullInt64(v interface{}) (n NullInt64) { 143 | n.Scan(v) 144 | return 145 | } 146 | 147 | // NewNullFloat64 create a NullFloat64 from v 148 | func NewNullFloat64(v interface{}) (n NullFloat64) { 149 | n.Scan(v) 150 | return 151 | } 152 | 153 | // NewNullString create a NullString from v 154 | func NewNullString(v interface{}) (n NullString) { 155 | n.Scan(v) 156 | return 157 | } 158 | 159 | // NewNullTime create a NullTime from v 160 | func NewNullTime(v interface{}) (n NullTime) { 161 | n.Scan(v) 162 | return 163 | } 164 | 165 | // NewNullBool create a NullBool from v 166 | func NewNullBool(v interface{}) (n NullBool) { 167 | n.Scan(v) 168 | return 169 | } 170 | 171 | // The `(*NullTime) Scan(interface{})` and `parseDateTime(string, *time.Location)` 172 | // functions are slightly modified versions of code from the github.com/go-sql-driver/mysql 173 | // package. They work with Postgres and MySQL databases. Potential future 174 | // drivers should ensure these will work for them, or come up with an alternative. 
175 | // 176 | // Conforming with its licensing terms the copyright notice and link to the licence 177 | // are available below. 178 | // 179 | // Source: https://github.com/go-sql-driver/mysql/blob/527bcd55aab2e53314f1a150922560174b493034/utils.go#L452-L508 180 | 181 | // Copyright notice from original developers: 182 | // 183 | // Go MySQL Driver - A MySQL-Driver for Go's database/sql package 184 | // 185 | // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. 186 | // 187 | // This Source Code Form is subject to the terms of the Mozilla Public 188 | // License, v. 2.0. If a copy of the MPL was not distributed with this file, 189 | // You can obtain one at http://mozilla.org/MPL/2.0/ 190 | 191 | // Scan implements the Scanner interface. 192 | // The value type must be time.Time or string / []byte (formatted time-string), 193 | // otherwise Scan fails. 194 | func (n *NullTime) Scan(value interface{}) error { 195 | var err error 196 | 197 | if value == nil { 198 | n.Time, n.Valid = time.Time{}, false 199 | return nil 200 | } 201 | 202 | switch v := value.(type) { 203 | case time.Time: 204 | n.Time, n.Valid = v, true 205 | return nil 206 | case []byte: 207 | n.Time, err = parseDateTime(string(v), time.UTC) 208 | n.Valid = err == nil 209 | return err 210 | case string: 211 | n.Time, err = parseDateTime(v, time.UTC) 212 | n.Valid = err == nil 213 | return err 214 | } 215 | 216 | n.Valid = false 217 | return nil 218 | } 219 | 220 | func parseDateTime(str string, loc *time.Location) (time.Time, error) { 221 | var t time.Time 222 | var err error 223 | 224 | base := "0000-00-00 00:00:00.0000000" 225 | switch len(str) { 226 | case 10, 19, 21, 22, 23, 24, 25, 26: 227 | if str == base[:len(str)] { 228 | return t, err 229 | } 230 | t, err = time.Parse(timeFormat[:len(str)], str) 231 | default: 232 | err = ErrInvalidTimestring 233 | return t, err 234 | } 235 | 236 | // Adjust location 237 | if err == nil && loc != time.UTC { 238 | y, mo, d := t.Date() 239 | h, 
mi, s := t.Clock() 240 | t, err = time.Date(y, mo, d, h, mi, s, t.Nanosecond(), loc), nil 241 | } 242 | 243 | return t, err 244 | } 245 | -------------------------------------------------------------------------------- /types_test.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | import ( 4 | "encoding/json" 5 | "testing" 6 | "time" 7 | 8 | "github.com/mailru/dbr/dialect" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | var ( 13 | filledRecord = nullTypedRecord{ 14 | StringVal: NewNullString("wow"), 15 | Int64Val: NewNullInt64(42), 16 | Float64Val: NewNullFloat64(1.618), 17 | TimeVal: NewNullTime(time.Date(2009, 1, 3, 18, 15, 5, 0, time.UTC)), 18 | BoolVal: NewNullBool(true), 19 | } 20 | ) 21 | 22 | func TestNullTypesScanning(t *testing.T) { 23 | for _, test := range []struct { 24 | in nullTypedRecord 25 | }{ 26 | {}, 27 | { 28 | in: filledRecord, 29 | }, 30 | } { 31 | for _, sess := range testSession { 32 | if sess.Dialect == dialect.ClickHouse { 33 | // clickhouse does not support null type 34 | continue 35 | } 36 | test.in.ID = nextID() 37 | _, err := sess.InsertInto("null_types").Columns("id", "string_val", "int64_val", "float64_val", "time_val", "bool_val").Record(test.in).Exec() 38 | assert.NoError(t, err) 39 | 40 | var record nullTypedRecord 41 | err = sess.Select("*").From("null_types").Where(Eq("id", test.in.ID)).LoadStruct(&record) 42 | assert.NoError(t, err) 43 | if sess.Dialect == dialect.PostgreSQL { 44 | // TODO: https://github.com/lib/pq/issues/329 45 | if !record.TimeVal.Time.IsZero() { 46 | record.TimeVal.Time = record.TimeVal.Time.UTC() 47 | } 48 | } 49 | assert.Equal(t, test.in, record) 50 | } 51 | } 52 | } 53 | 54 | func TestNullTypesJSON(t *testing.T) { 55 | for _, test := range []struct { 56 | in interface{} 57 | in2 interface{} 58 | out interface{} 59 | want string 60 | }{ 61 | { 62 | in: &filledRecord.BoolVal, 63 | in2: filledRecord.BoolVal, 64 | out: new(NullBool), 65 | 
want: "true", 66 | }, 67 | { 68 | in: &filledRecord.Float64Val, 69 | in2: filledRecord.Float64Val, 70 | out: new(NullFloat64), 71 | want: "1.618", 72 | }, 73 | { 74 | in: &filledRecord.Int64Val, 75 | in2: filledRecord.Int64Val, 76 | out: new(NullInt64), 77 | want: "42", 78 | }, 79 | { 80 | in: &filledRecord.StringVal, 81 | in2: filledRecord.StringVal, 82 | out: new(NullString), 83 | want: `"wow"`, 84 | }, 85 | { 86 | in: &filledRecord.TimeVal, 87 | in2: filledRecord.TimeVal, 88 | out: new(NullTime), 89 | want: `"2009-01-03T18:15:05Z"`, 90 | }, 91 | } { 92 | // marshal ptr 93 | b, err := json.Marshal(test.in) 94 | assert.NoError(t, err) 95 | assert.Equal(t, test.want, string(b)) 96 | 97 | // marshal value 98 | b, err = json.Marshal(test.in2) 99 | assert.NoError(t, err) 100 | assert.Equal(t, test.want, string(b)) 101 | 102 | // unmarshal 103 | err = json.Unmarshal(b, test.out) 104 | assert.NoError(t, err) 105 | assert.Equal(t, test.in, test.out) 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /union.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | type union struct { 4 | builder []Builder 5 | all bool 6 | } 7 | 8 | // Union builds "UNION ..." 9 | func Union(builder ...Builder) interface { 10 | Builder 11 | As(string) Builder 12 | } { 13 | return &union{ 14 | builder: builder, 15 | } 16 | } 17 | 18 | // UnionAll builds "UNION ALL ..." 
19 | func UnionAll(builder ...Builder) interface { 20 | Builder 21 | As(string) Builder 22 | } { 23 | return &union{ 24 | builder: builder, 25 | all: true, 26 | } 27 | } 28 | 29 | func (u *union) Build(d Dialect, buf Buffer) error { 30 | for i, b := range u.builder { 31 | if i > 0 { 32 | buf.WriteString(" UNION ") 33 | if u.all { 34 | buf.WriteString("ALL ") 35 | } 36 | } 37 | buf.WriteString(placeholder) 38 | buf.WriteValue(b) 39 | } 40 | return nil 41 | } 42 | 43 | func (u *union) As(alias string) Builder { 44 | return as(u, alias) 45 | } 46 | -------------------------------------------------------------------------------- /update.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | import "reflect" 4 | 5 | // UpdateStmt builds `UPDATE ...` 6 | type UpdateStmt interface { 7 | Builder 8 | 9 | Where(query interface{}, value ...interface{}) UpdateStmt 10 | Set(column string, value interface{}) UpdateStmt 11 | SetMap(m map[string]interface{}) UpdateStmt 12 | SetRecord(structValue interface{}) UpdateStmt 13 | } 14 | 15 | type updateStmt struct { 16 | raw 17 | 18 | Table string 19 | Value map[string]interface{} 20 | WhereCond []Builder 21 | } 22 | 23 | // Build builds `UPDATE ...` in dialect 24 | func (b *updateStmt) Build(d Dialect, buf Buffer) error { 25 | if b.raw.Query != "" { 26 | return b.raw.Build(d, buf) 27 | } 28 | 29 | if b.Table == "" { 30 | return ErrTableNotSpecified 31 | } 32 | 33 | if len(b.Value) == 0 { 34 | return ErrColumnNotSpecified 35 | } 36 | 37 | buf.WriteString("UPDATE ") 38 | buf.WriteString(d.QuoteIdent(b.Table)) 39 | buf.WriteString(" SET ") 40 | 41 | i := 0 42 | for col, v := range b.Value { 43 | if i > 0 { 44 | buf.WriteString(", ") 45 | } 46 | buf.WriteString(d.QuoteIdent(col)) 47 | buf.WriteString(" = ") 48 | buf.WriteString(placeholder) 49 | 50 | buf.WriteValue(v) 51 | i++ 52 | } 53 | 54 | if len(b.WhereCond) > 0 { 55 | buf.WriteString(" WHERE ") 56 | err := 
And(b.WhereCond...).Build(d, buf) 57 | if err != nil { 58 | return err 59 | } 60 | } 61 | return nil 62 | } 63 | 64 | // Update creates an UpdateStmt 65 | func Update(table string) UpdateStmt { 66 | return createUpdateStmt(table) 67 | } 68 | 69 | func createUpdateStmt(table string) *updateStmt { 70 | return &updateStmt{ 71 | Table: table, 72 | Value: make(map[string]interface{}), 73 | } 74 | } 75 | 76 | // UpdateBySql creates an UpdateStmt with raw query 77 | func UpdateBySql(query string, value ...interface{}) UpdateStmt { 78 | return createUpdateStmtBySQL(query, value) 79 | } 80 | 81 | func createUpdateStmtBySQL(query string, value []interface{}) *updateStmt { 82 | return &updateStmt{ 83 | raw: raw{ 84 | Query: query, 85 | Value: value, 86 | }, 87 | Value: make(map[string]interface{}), 88 | } 89 | } 90 | 91 | // Where adds a where condition 92 | func (b *updateStmt) Where(query interface{}, value ...interface{}) UpdateStmt { 93 | switch query := query.(type) { 94 | case string: 95 | b.WhereCond = append(b.WhereCond, Expr(query, value...)) 96 | case Builder: 97 | b.WhereCond = append(b.WhereCond, query) 98 | } 99 | return b 100 | } 101 | 102 | // Set specifies a key-value pair 103 | func (b *updateStmt) Set(column string, value interface{}) UpdateStmt { 104 | b.Value[column] = value 105 | return b 106 | } 107 | 108 | // SetMap specifies a list of key-value pair 109 | func (b *updateStmt) SetMap(m map[string]interface{}) UpdateStmt { 110 | for col, val := range m { 111 | b.Set(col, val) 112 | } 113 | return b 114 | } 115 | 116 | // SetRecord specifies a record with field and values to set 117 | func (b *updateStmt) SetRecord(structValue interface{}) UpdateStmt { 118 | v := reflect.Indirect(reflect.ValueOf(structValue)) 119 | 120 | if v.Kind() == reflect.Struct { 121 | sm := structMap(v.Type()) 122 | 123 | for col, index := range sm { 124 | b.Set(col, v.FieldByIndex(index).Interface()) 125 | } 126 | } 127 | 128 | return b 129 | } 130 | 
-------------------------------------------------------------------------------- /update_builder.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | ) 8 | 9 | // UpdateBuilder builds `UPDATE ...` 10 | type UpdateBuilder interface { 11 | Builder 12 | EventReceiver 13 | Executer 14 | 15 | Where(query interface{}, value ...interface{}) UpdateBuilder 16 | Set(column string, value interface{}) UpdateBuilder 17 | SetMap(m map[string]interface{}) UpdateBuilder 18 | Limit(n uint64) UpdateBuilder 19 | } 20 | 21 | type updateBuilder struct { 22 | EventReceiver 23 | runner 24 | 25 | Dialect Dialect 26 | updateStmt *updateStmt 27 | LimitCount int64 28 | ctx context.Context 29 | } 30 | 31 | // Update creates a UpdateBuilder 32 | func (sess *Session) Update(table string) UpdateBuilder { 33 | return &updateBuilder{ 34 | runner: sess, 35 | EventReceiver: sess.EventReceiver, 36 | Dialect: sess.Dialect, 37 | updateStmt: createUpdateStmt(table), 38 | LimitCount: -1, 39 | ctx: sess.ctx, 40 | } 41 | } 42 | 43 | // Update creates a UpdateBuilder 44 | func (tx *Tx) Update(table string) UpdateBuilder { 45 | return &updateBuilder{ 46 | runner: tx, 47 | EventReceiver: tx.EventReceiver, 48 | Dialect: tx.Dialect, 49 | updateStmt: createUpdateStmt(table), 50 | LimitCount: -1, 51 | ctx: tx.ctx, 52 | } 53 | } 54 | 55 | // UpdateBySql creates a UpdateBuilder from raw query 56 | func (sess *Session) UpdateBySql(query string, value ...interface{}) UpdateBuilder { 57 | return &updateBuilder{ 58 | runner: sess, 59 | EventReceiver: sess.EventReceiver, 60 | Dialect: sess.Dialect, 61 | updateStmt: createUpdateStmtBySQL(query, value), 62 | LimitCount: -1, 63 | ctx: sess.ctx, 64 | } 65 | } 66 | 67 | // UpdateBySql creates a UpdateBuilder from raw query 68 | func (tx *Tx) UpdateBySql(query string, value ...interface{}) UpdateBuilder { 69 | return &updateBuilder{ 70 | runner: tx, 71 | 
EventReceiver: tx.EventReceiver, 72 | Dialect: tx.Dialect, 73 | updateStmt: createUpdateStmtBySQL(query, value), 74 | LimitCount: -1, 75 | ctx: tx.ctx, 76 | } 77 | } 78 | 79 | // Exec executes the stmt with background context 80 | func (b *updateBuilder) Exec() (sql.Result, error) { 81 | return b.ExecContext(b.ctx) 82 | } 83 | 84 | // ExecContext executes the stmt 85 | func (b *updateBuilder) ExecContext(ctx context.Context) (sql.Result, error) { 86 | return exec(ctx, b.runner, b.EventReceiver, b, b.Dialect) 87 | } 88 | 89 | // Set adds "SET column=value" 90 | func (b *updateBuilder) Set(column string, value interface{}) UpdateBuilder { 91 | b.updateStmt.Set(column, value) 92 | return b 93 | } 94 | 95 | // SetMap adds "SET column=value" for each key value pair in m 96 | func (b *updateBuilder) SetMap(m map[string]interface{}) UpdateBuilder { 97 | b.updateStmt.SetMap(m) 98 | return b 99 | } 100 | 101 | // Where adds condition to the stmt 102 | func (b *updateBuilder) Where(query interface{}, value ...interface{}) UpdateBuilder { 103 | b.updateStmt.Where(query, value...) 
104 | return b 105 | } 106 | 107 | // Limit adds LIMIT 108 | func (b *updateBuilder) Limit(n uint64) UpdateBuilder { 109 | b.LimitCount = int64(n) 110 | return b 111 | } 112 | 113 | // Build builds `UPDATE ...` in dialect 114 | func (b *updateBuilder) Build(d Dialect, buf Buffer) error { 115 | err := b.updateStmt.Build(b.Dialect, buf) 116 | if err != nil { 117 | return err 118 | } 119 | if b.LimitCount >= 0 { 120 | buf.WriteString(" LIMIT ") 121 | buf.WriteString(fmt.Sprint(b.LimitCount)) 122 | } 123 | return nil 124 | } 125 | -------------------------------------------------------------------------------- /update_test.go: -------------------------------------------------------------------------------- 1 | package dbr 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/mailru/dbr/dialect" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestUpdateStmt(t *testing.T) { 11 | buf := NewBuffer() 12 | builder := Update("table").Set("a", 1).Where(Eq("b", 2)) 13 | err := builder.Build(dialect.MySQL, buf) 14 | assert.NoError(t, err) 15 | 16 | assert.Equal(t, "UPDATE `table` SET `a` = ? WHERE (`b` = ?)", buf.String()) 17 | assert.Equal(t, []interface{}{1, 2}, buf.Value()) 18 | } 19 | 20 | func TestUpdateStmtSetRecord(t *testing.T) { 21 | record := struct{ A int }{A: 1} 22 | buf := NewBuffer() 23 | builder := Update("table").SetRecord(&record).Where(Eq("b", 2)) 24 | err := builder.Build(dialect.MySQL, buf) 25 | assert.NoError(t, err) 26 | 27 | assert.Equal(t, "UPDATE `table` SET `a` = ? 
// camelCaseToSnakeCase converts a Go identifier such as "Float64Val" or
// "XMLName" to its snake_case form ("float64_val", "xml_name").
func camelCaseToSnakeCase(name string) string {
	buf := new(bytes.Buffer)

	runes := []rune(name)

	for i := 0; i < len(runes); i++ {
		buf.WriteRune(unicode.ToLower(runes[i]))
		// Insert '_' before an upper-case rune when the current rune is lower
		// case or a digit, or when the upper-case rune starts a new word after
		// an acronym (it is followed by a lower-case rune).
		if i != len(runes)-1 && unicode.IsUpper(runes[i+1]) &&
			(unicode.IsLower(runes[i]) || unicode.IsDigit(runes[i]) ||
				(i != len(runes)-2 && unicode.IsLower(runes[i+2]))) {
			buf.WriteRune('_')
		}
	}

	return buf.String()
}

// structMap builds index to fast lookup fields in struct.
// Keys are column names (the `db` tag, or the snake_case field name); values
// are index paths usable with reflect.Value.FieldByIndex.
func structMap(t reflect.Type) map[string][]int {
	m := make(map[string][]int)
	structTraverse(m, t, nil)
	return m
}

var (
	typeValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
)

// structTraverse records in m the index path of every reachable field of t,
// descending into nested structs and through pointers. head is the index path
// of t itself relative to the root type.
//
// Types implementing driver.Valuer are treated as leaves and not descended
// into. Unexported non-embedded fields and fields tagged `db:"-"` are skipped.
// When two fields map to the same column name, the first one visited wins.
//
// NOTE(review): a self-referential type (e.g. a struct holding a pointer to
// its own type) would recurse without bound here; no cycle detection is done.
func structTraverse(m map[string][]int, t reflect.Type, head []int) {
	if t.Implements(typeValuer) {
		return
	}
	switch t.Kind() {
	case reflect.Ptr:
		structTraverse(m, t.Elem(), head)
	case reflect.Struct:
		for i := 0; i < t.NumField(); i++ {
			field := t.Field(i)
			if field.PkgPath != "" && !field.Anonymous {
				// unexported
				continue
			}
			tag := field.Tag.Get("db")
			if tag == "-" {
				// ignore
				continue
			}
			if tag == "" {
				// no tag, but we can record the field name
				tag = camelCaseToSnakeCase(field.Name)
			}

			// Build a fresh index path for this field. append(head, i) must
			// not be used: once head has spare capacity, every sibling field
			// would write its index into the same backing array and clobber
			// the paths already stored in m for earlier siblings.
			index := make([]int, len(head)+1)
			copy(index, head)
			index[len(head)] = i

			if _, ok := m[tag]; !ok {
				m[tag] = index
			}
			structTraverse(m, field.Type, index)
		}
	}
}

// extractOriginal removes all ptr and interface wrappers.
// It returns the innermost value and its kind; a nil pointer or nil interface
// is returned as is, with kind Ptr or Interface respectively.
func extractOriginal(v reflect.Value) (reflect.Value, reflect.Kind) {
	switch v.Kind() {
	case reflect.Ptr:
		if v.IsNil() {
			return v, reflect.Ptr
		}
		return extractOriginal(v.Elem())
	case reflect.Interface:
		if v.IsNil() {
			return v, reflect.Interface
		}
		return extractOriginal(v.Elem())
	default:
		return v, v.Kind()
	}
}
int 67 | }{}, 68 | expected: map[string][]int{}, 69 | }, 70 | { 71 | in: struct { 72 | IntVal int `db:"test"` 73 | }{}, 74 | expected: map[string][]int{"test": {0}}, 75 | }, 76 | { 77 | in: struct { 78 | IntVal int `db:"-"` 79 | }{}, 80 | expected: map[string][]int{}, 81 | }, 82 | { 83 | in: struct { 84 | Test1 struct { 85 | Test2 int 86 | } 87 | }{}, 88 | expected: map[string][]int{"test1": {0}, "test2": {0, 0}}, 89 | }, 90 | } { 91 | m := structMap(reflect.ValueOf(test.in).Type()) 92 | assert.Equal(t, test.expected, m) 93 | } 94 | } 95 | --------------------------------------------------------------------------------