├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── changelog.txt ├── examples ├── mock.go └── pkg_doc.go ├── go.mod ├── go.sum ├── grammar ├── errors.go ├── grammar.go ├── grammar_postgres_test.go ├── grammar_sqlite_test.go ├── pkg_doc.go ├── postgres.go └── sqlite.go ├── hobbled └── pkg.go ├── interfaces.go ├── internal_test.go ├── model ├── errors.go ├── examples │ ├── mock.go │ ├── model.go │ └── pkg_doc.go ├── examples_test.go ├── model.go ├── models.go ├── models_test.go ├── pkg_doc.go ├── query_binding.go ├── query_binding_test.go ├── save_mode.go ├── statements │ ├── pkg_doc.go │ ├── query.go │ └── table.go └── tablename.go ├── pkg_doc.go ├── scanner.go ├── scanner_examples_test.go ├── scanner_test.go ├── schema ├── column.go ├── index.go ├── pkg_doc.go └── table.go ├── transact.go └── transact_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | 3 | # Binaries for programs and plugins 4 | *.exe 5 | *.exe~ 6 | *.dll 7 | *.so 8 | *.dylib 9 | 10 | # Test binary, build with `go test -c` 11 | *.test 12 | 13 | # Output of the go coverage tool, specifically when used with LiteIDE 14 | *.out 15 | 16 | coverage.txt 17 | cpu.prof 18 | mem.prof 19 | profile*.pdf 20 | profile*.png 21 | profile*.txt 22 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | branches: 4 | only: 5 | - master 6 | 7 | go: 8 | - '1.16.x' 9 | - '1.17.x' 10 | - '1.18.x' 11 | 12 | before_install: 13 | - go get -t -v ./... 14 | 15 | script: 16 | - go test ./... 17 | - go test -run abcxyz -benchmem -bench . 18 | - go test ./... 
-race -coverprofile=coverage.txt -covermode=atomic 19 | 20 | after_success: 21 | - bash <(curl -s https://codecov.io/bash) 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 nofeaturesonlybugs 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Go Reference](https://pkg.go.dev/badge/github.com/nofeaturesonlybugs/sqlh.svg)](https://pkg.go.dev/github.com/nofeaturesonlybugs/sqlh) 2 | [![Go Report Card](https://goreportcard.com/badge/github.com/nofeaturesonlybugs/sqlh)](https://goreportcard.com/report/github.com/nofeaturesonlybugs/sqlh) 3 | [![Build Status](https://app.travis-ci.com/nofeaturesonlybugs/sqlh.svg?branch=master)](https://app.travis-ci.com/nofeaturesonlybugs/sqlh) 4 | [![codecov](https://codecov.io/gh/nofeaturesonlybugs/sqlh/branch/master/graph/badge.svg)](https://codecov.io/gh/nofeaturesonlybugs/sqlh) 5 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 6 | 7 | `sqlh` aka `SQL Helper`. 8 | 9 | ## `sqlh.Scanner` 10 | 11 | `sqlh.Scanner` is a powerful database result set scanner. 12 | 13 | - Similar to `jmoiron/sqlx` but supports nested Go `structs`. 14 | - _Should work with any `database/sql` compatible driver._ 15 | 16 | ## `model.Models` 17 | 18 | `model.Models` supports `INSERT|UPDATE` on Go `structs` registered as database _models_, where a _model_ is a language type mapped to a database table. 19 | 20 | - Supports Postgres. 21 | - Supports grammars that use `?` for parameters **and** have a `RETURNING` clause. 22 | - Benchmarked with Sqlite 3.35 -- your mileage may vary. 23 | 24 | ## `sqlh` Design Philosophy 25 | 26 | ``` 27 | Hand Crafted | | Can I Have 28 | Artisanal | ======================================= | My Database 29 | SQL | ^ | Back, Please? 30 | | 31 | +-- sqlh is here. 32 | ``` 33 | 34 | `sqlh` is easy to use because it lives very close to `database/sql`. The primary goal of `sqlh` is to work with and facilitate using `database/sql` without replacing or hijacking it. 
When using `sqlh` you manage your `*sql.DB` or create `*sql.Tx` as you normally would and pass those as arguments to functions in `sqlh` when scanning or persisting models; `sqlh` then works within the confines of what you gave it. 35 | 36 | When accepting arguments that work directly with the database (`*sql.DB` or `*sql.Tx`) `sqlh` accepts them as interfaces. This means `sqlh` may work with other database packages that define their own types as long as they keep a method set similar to `database/sql`. 37 | 38 | The implementation for `sqlh` is fairly straightforward. Primarily this is because all the heavy `reflect` work is offloaded to `set`, which is another of my packages @ https://www.github.com/nofeaturesonlybugs/set 39 | 40 | `set` exports a flexible `set.Mapper` for mapping Go `structs` to string names such as database columns. A lot of the power and flexibility exposed by `sqlh` is really derived from `set`. I think this gives `sqlh` an advantage over similar database packages because it's very configurable, performs well, and keeps `sqlh` from getting bogged down in the complexities of `reflect`. 41 | 42 | Here are some `sqlh.Scanner` examples: 43 | 44 | ```go 45 | type MyStruct struct { 46 | Message string 47 | Number int 48 | } 49 | // 50 | db, err := examples.Connect(examples.ExSimpleMapper) 51 | if err != nil { 52 | fmt.Println(err.Error()) 53 | } 54 | // 55 | scanner := &sqlh.Scanner{ 56 | // Mapper is pure defaults. Uses exported struct names as column names. 
57 | Mapper: &set.Mapper{}, 58 | } 59 | var rv []MyStruct // []*MyStruct also acceptable 60 | err = scanner.Select(db, &rv, "select * from mytable") 61 | if err != nil { 62 | fmt.Println(err.Error()) 63 | } 64 | ``` 65 | 66 | ```go 67 | type Common struct { 68 | Id int `json:"id"` 69 | Created time.Time `json:"created"` 70 | Modified time.Time `json:"modified"` 71 | } 72 | type Person struct { 73 | Common 74 | First string `json:"first"` 75 | Last string `json:"last"` 76 | } 77 | // Note here the natural mapping of SQL columns to nested structs. 78 | type Sale struct { 79 | Common 80 | // customer_first and customer_last map to Customer. 81 | Customer Person `json:"customer"` 82 | // contact_first and contact_last map to Contact. 83 | Contact Person `json:"contact"` 84 | } 85 | db, err := examples.Connect(examples.ExNestedTwice) 86 | if err != nil { 87 | fmt.Println(err.Error()) 88 | } 89 | // 90 | scanner := &sqlh.Scanner{ 91 | Mapper: &set.Mapper{ 92 | // Mapper elevates Common to same level as other fields. 93 | Elevated: set.NewTypeList(Common{}), 94 | // Nested struct fields joined with _ 95 | Join: "_", 96 | // Mapper uses struct tag db or json, db higher priority. 
97 | Tags: []string{"db", "json"}, 98 | }, 99 | } 100 | var rv []Sale // []*Sale also acceptable 101 | query := ` 102 | select 103 | s.id, s.created, s.modified, 104 | s.customer_id, c.first as customer_first, c.last as customer_last, 105 | s.vendor_id as contact_id, v.first as contact_first, v.last as contact_last 106 | from sales s 107 | inner join customers c on s.customer_id = c.id 108 | inner join vendors v on s.vendor_id = v.id 109 | ` 110 | err = scanner.Select(db, &rv, query) 111 | if err != nil { 112 | fmt.Println(err.Error()) 113 | } 114 | ``` 115 | 116 | ## Roadmap 117 | 118 | The development of `sqlh` is essentially following my specific pain points when using `database/sql`: 119 | 120 | - ✓ Row scanning provided by sqlh.Scanner 121 | - ✓ High level Save() method provided by model.Models 122 | - ✓ Specific Insert(), Update(), and Upsert() logic provided by model.Models 123 | - Upsert() currently supports conflict from primary key; conflicts on arbitrary unique indexes not supported. 124 | - ⭴ `DELETE` CRUD statements : to be covered by `model.Models`. 125 | - ⭴ `UPSERT` type operations using index information : to be covered by `model.Models`. 126 | - ⭴ `Find()` or `Filter()` for advanced `WHERE` clauses and model selection. 127 | - ⭴ Performance enhancements if possible. 128 | - ⭴ Relationship management -- maybe. 129 | 130 | Personally I find `SELECT|INSERT|UPDATE` to be the most painful and tedious with large queries or tables so those are the features I've addressed first. 
131 | 132 | ## `set.Mapper` Tips 133 | 134 | When you want `set.Mapper` to treat a nested struct as a single field rather than a struct itself add it to the `TreatAsScalar` member: 135 | 136 | - `TreatAsScalar : set.NewTypeList( sql.NullBool{}, sql.NullString{} )` 137 | 138 | When you use a common nested struct to represent fields present in many of your types consider using the `Elevated` member: 139 | 140 | ```go 141 | type CommonDB struct { 142 | Id int 143 | CreatedAt time.Time 144 | ModifiedAt time.Time 145 | } 146 | type Something struct { 147 | CommonDB 148 | Name string 149 | } 150 | ``` 151 | 152 | Without `Elevated` the `set.Mapper` will generate names like: 153 | 154 | ``` 155 | CommonDBId 156 | CommonDBCreatedAt 157 | CommonDBModifiedAt 158 | Name 159 | ``` 160 | 161 | To prevent `CommonDB` from being part of the name add `CommonDB{}` to the `Elevated` member of the mapper, which elevates the nested fields as if they were defined directly in the parent struct: 162 | 163 | ```go 164 | Elevated : set.NewTypeList( CommonDB{} ) 165 | ``` 166 | 167 | Then the generated names will be: 168 | 169 | ``` 170 | Id 171 | CreatedAt 172 | ModifiedAt 173 | Name 174 | ``` 175 | 176 | You can further customize generated names with struct tags: 177 | 178 | ```go 179 | type CommonDB struct { 180 | Id int `json:"id"` 181 | CreatedAt time.Time `json:"created"` 182 | ModifiedAt time.Time `json:"modified"` 183 | } 184 | type Something struct { 185 | CommonDB // No tag necessary since this field is Elevated. 
186 | Name string `json:"name"` 187 | } 188 | ``` 189 | 190 | Specify the tag name to use in the `Tags` member, which is a `[]string`: 191 | 192 | ```go 193 | Tags : []string{"json"} 194 | ``` 195 | 196 | Now generated names will be: 197 | 198 | ``` 199 | id 200 | created 201 | modified 202 | name 203 | ``` 204 | 205 | If you want to use different names for some fields in your database versus your JSON encoding you can specify multiple `Tags`, with tags listed first taking higher priority: 206 | 207 | ```go 208 | Tags : []string{"db", "json"} // Uses either db or json, db has higher priority. 209 | ``` 210 | 211 | With the above `Tags`, if `CommonDB` is defined as the following: 212 | 213 | ```go 214 | type CommonDB struct { 215 | Id int `json:"id" db:"pk"` 216 | CreatedAt time.Time `json:"created" db:"created_tmz"` 217 | ModifiedAt time.Time `json:"modified" db:"modified_tmz"` 218 | } 219 | ``` 220 | 221 | Then the mapped names are: 222 | 223 | ``` 224 | pk 225 | created_tmz 226 | modified_tmz 227 | name 228 | ``` 229 | 230 | ## Benchmarks 231 | 232 | See my sibling package `sqlhbenchmarks` for my methodology, goals, and interpretation of results. 233 | 234 | ## API Consistency and Breaking Changes 235 | 236 | I am making a very concerted effort to break the API as little as possible while adding features or fixing bugs. However this software is currently in a pre-1.0.0 version and breaking changes _are_ allowed under standard semver. As the API approaches a stable 1.0.0 release I will list any such breaking changes here and they will always be signaled by a bump in _minor_ version. 237 | 238 | - 0.4.0 ⭢ 0.5.0 239 | - `model.Models` methods allow `[]T` or `[]*T` when performing `INSERT|UPDATE|UPSERT` on slices of models. 240 | - `model.QueryBinding` is no longer an interface. 
241 | - `model.Model` pruned: 242 | - Removed fields `V`, `VSlice` and `BoundMapping` 243 | - Removed methods `NewInstance` and `NewSlice` 244 | - `BindQuery()` signature changed to require a `*set.Mapper` 245 | - Upgrade `set` dependency to v0.5.1 for performance enhancements. 246 | - 0.3.0 ⭢ 0.4.0 247 | - `Transact(fn)` was correctly rolling the transaction back if `fn` returned `err != nil`; however 248 | the error from `fn` and any potential error from the rollback were not returned from `Transact()`. 249 | This is fixed in `0.4.0` and while technically a bug fix it _also_ changes the behavior of `Transact()` 250 | to (correctly) return errors as it should have been doing. As this is a potentially breaking change 251 | in behavior I have bumped the minor version for this patch. 252 | - 0.2.0 ⭢ 0.3.0 253 | - `grammar.Default` renamed to `grammar.Sqlite` -- generated SQL is same as previous version. 254 | - `grammar.Grammar` is now an interface where methods now return `(*statements.Query, error)` 255 | where previously only `(*statements.Query)` was returned. 256 | - Package grammar no longer has any panics; errors are returned instead (see previous note). 257 | - Prior to this release `model.Models` only ran queries that had followup targets 258 | for Scan() and panicked when such targets did not exist. This release allows for queries 259 | that do not have any Scan() targets and will switch to calling Exec() instead of Query() or 260 | QueryRow() when necessary. An implication of this change is that `Models.Insert()` and 261 | `Models.Update()` no longer panic in the absence of Scan() targets. 262 | -------------------------------------------------------------------------------- /changelog.txt: -------------------------------------------------------------------------------- 1 | /develop 2 | 3 | 0.5.1 4 | + Package maintenance. 5 | + Update dependencies. 6 | + Update badges. 
7 | 8 | 0.5.0 9 | + Update dependency github.com/nofeaturesonlybugs/set to v0.5.1 10 | + This release offers performance enhancements that benefit sqlh. 11 | 12 | model 13 | + Models do not need to be registered via pointer. 14 | 15 | + Models can be registered via reflect.Type. 16 | 17 | + When calling methods on Models (Insert,Update,etc) slices of models can 18 | be passed as []T or []*T 19 | 20 | + Add Models.Save method. Save accepts *T, []T, or []*T and will delegate 21 | to the appropriate method (Insert,Update,Upsert) depending on the model 22 | and the current value of its key field(s). If []T or []*T is passed then 23 | the first element is inspected to determine the delegated method and this 24 | method is then applied to all elements. 25 | 26 | + Altered Model struct definition. 27 | + Removed fields V, VSlice, and BoundMapping 28 | + Removed methods NewInstance and NewSlice 29 | + BindQuery method signature altered 30 | 31 | 0.4.0 32 | + Transact(fn) was not correctly returning the error from the call to fn(); transactions 33 | themselves were correctly rolled back due to non-nil error returned from fn() but the 34 | error itself was not returning to the caller. 35 | 36 | 0.3.0 37 | + Breaking change migration (impact=low). 38 | + grammar.Default renamed to grammar.Sqlite -- generated SQL is same as previous version. 39 | + grammar.Grammar is now an interface where methods now return (*statements.Query, error) 40 | where previously only (*statements.Query) was returned. 41 | + Package grammar no longer has any panics; errors are returned instead (see previous note). 42 | + model.Models. Prior to this release Models only ran queries that had followup targets 43 | for Scan() and panicked when such targets did not exist. This release allows for queries 44 | that do not have any Scan() targets and will switch to calling Exec() instead of Query() or 45 | QueryRow() when necessary. 
An implication of this change is that Models.Insert() and 46 | Models.Update() no longer panic in the absence of Scan() targets. 47 | 48 | grammar 49 | + Grammar is now an interface. 50 | + Global Default renamed to Sqlite. 51 | + Add PostgresGrammar. 52 | + Add SqliteGrammar. 53 | + Add global error vars: ErrTableRequired, ErrColumnsRequired, & ErrKeysRequired. 54 | Grammar functions now return a variation of these errors indicating errors when building 55 | SQL statements. 56 | + Add Grammar.Upsert() to support INSERT...ON CONFLICT(...) DO UPDATE queries. 57 | 58 | hobbled 59 | + Package hobbled exposes facilities for hobbling database types such as *sql.DB or *sql.Tx by removing 60 | methods such as Begin() or Prepare(). This is useful when testing the dynamic nature of sqlh and its 61 | subpackages that can switch logic depending on the facilities of the database type given to it when 62 | performing work. 63 | 64 | model 65 | + Add global error ErrUnsupported. Models functions may return a variation of ErrUnsupported if 66 | an operation is called on a type for which the SQL execution can not be performed. 67 | 68 | + QueryBinding.QueryOne() does not return an error if database/sql returns sql.ErrNoRows 69 | and Query.Expect is equal to statements.ExpectRowOrNone. 70 | 71 | + QueryBinding.QuerySlice() does not return an error if database/sql returns sql.ErrNoRows 72 | and Query.Expect is equal to statements.ExpectRowOrNone. 73 | 74 | + Add Models.Upsert() for models that do not have "key,auto" primary keys. Upsert() currently 75 | only supports primary keys and does not support UNIQUE indexes. 76 | 77 | + Removed possible panics from Models.Insert() and Models.Update(). 78 | 79 | + Behavior change. Prior to this release Models only ran queries that had followup targets 80 | for Scan() and panicked when such targets did not exist. 
This release allows for queries 81 | that do not have any Scan() targets and will switch to calling Exec() instead of Query() or 82 | QueryRow() when necessary. 83 | 84 | model/statements 85 | + Add ExpectRowOrNone for queries that could return 0 or 1 rows and 0 does not indicate 86 | an error. 87 | 88 | sqlh 89 | + Add functions sqlh.Transact() and sqlh.TransactRollback(). Both functions ease the use 90 | of database transactions by wrapping a provided function argument inside a transaction. 91 | However sqlh.TransactRollback always calls tx.Rollback() to unwind the transaction, which is 92 | useful for writing test cases. 93 | 94 | 0.2.0 95 | 96 | + sqlh.Scanner.Select handles "no rows" for dest *T where T is a struct by setting 97 | the *T to nil. 98 | 99 | 0.1.0 100 | 101 | + Breaking change migration (impact=low). 102 | Interface sqlh.IRows renamed to sqlh.IIterates. 103 | 104 | + Add packages examples, grammar, model, & schema. 105 | + examples exports some utility to facilitate our example code. 106 | + grammar, model, and schema work together to provide a simple `model` 107 | layer to cut down on boiler plate for `INSERT|UPDATE` operations 108 | on Go types representing database tables. 109 | 110 | + sqlh.Scanner.Select can now scan the following types of results: 111 | + T where T is a scalar or primitive value. 112 | + []T where T is a scalar or primitive value. 113 | + T where T is a single struct. 114 | + []T where T is a struct slice. 115 | 116 | + Add interfaces sqlh.IBegins and sqlh.IPrepares. 117 | 118 | 0.0.2 119 | + No API change; clean up some documentation. 120 | 121 | 0.0.1 122 | + Add interfaces: 123 | + IQuery - A type that can run a database query. 124 | + IRows - A type that can iterate a database result set. 125 | + Add type: 126 | + Scanner - The powerhouse of the package. 
127 | -------------------------------------------------------------------------------- /examples/mock.go: -------------------------------------------------------------------------------- 1 | package examples 2 | 3 | import ( 4 | "database/sql" 5 | "time" 6 | 7 | "github.com/DATA-DOG/go-sqlmock" 8 | ) 9 | 10 | // SentinalTime is a set time value used to generate times. 11 | var SentinalTime time.Time = time.Date(2006, 1, 2, 3, 4, 5, 7, time.Local) 12 | 13 | // TimeGenerator uses SentinalTime to return deterministic time values. 14 | type TimeGenerator struct { 15 | n int 16 | } 17 | 18 | // Next returns the next time.Time. 19 | func (tg *TimeGenerator) Next() time.Time { 20 | rv := SentinalTime.Add(time.Duration(tg.n) * time.Hour) 21 | tg.n++ 22 | return rv 23 | } 24 | 25 | // Example is a specific example. 26 | type Example int 27 | 28 | const ( 29 | ExSimpleMapper Example = iota 30 | ExTags 31 | ExNestedStruct 32 | ExNestedTwice 33 | ExScalar 34 | ExScalarSlice 35 | ExStruct 36 | ExStructNotFound 37 | ) 38 | 39 | // Connect creates a sqlmock DB and configures it for the example. 40 | func Connect(e Example) (DB *sql.DB, err error) { 41 | var mock sqlmock.Sqlmock 42 | DB, mock, err = sqlmock.New() 43 | // 44 | switch e { 45 | case ExSimpleMapper: 46 | mock.ExpectQuery("select +"). 47 | WillReturnRows( 48 | sqlmock.NewRows([]string{"Message", "Number"}). 49 | AddRow("Hello, World!", 42). 50 | AddRow("So long!", 100)). 51 | RowsWillBeClosed() 52 | 53 | case ExTags: 54 | mock.ExpectQuery("select +"). 55 | WillReturnRows( 56 | sqlmock.NewRows([]string{"message", "num"}). 57 | AddRow("Hello, World!", 42). 58 | AddRow("So long!", 100)). 59 | RowsWillBeClosed() 60 | 61 | case ExNestedStruct: 62 | tg := TimeGenerator{} 63 | mock.ExpectQuery("select +"). 64 | WillReturnRows( 65 | sqlmock.NewRows([]string{"id", "created", "modified", "message", "num"}). 66 | AddRow(1, tg.Next(), tg.Next(), "Hello, World!", 42). 67 | AddRow(2, tg.Next(), tg.Next(), "So long!", 100)). 
68 | RowsWillBeClosed() 69 | 70 | case ExNestedTwice: 71 | tg := TimeGenerator{} 72 | mock.ExpectQuery("select +"). 73 | WillReturnRows( 74 | sqlmock.NewRows([]string{ 75 | "id", "created", "modified", 76 | "customer_id", "customer_first", "customer_last", 77 | "contact_id", "contact_first", "contact_last"}). 78 | AddRow(1, tg.Next(), tg.Next(), 10, "Bob", "Smith", 100, "Sally", "Johnson"). 79 | AddRow(2, tg.Next(), tg.Next(), 20, "Fred", "Jones", 200, "Betty", "Walker")). 80 | RowsWillBeClosed() 81 | 82 | case ExScalar: 83 | mock.ExpectQuery("select +"). 84 | WillReturnRows(sqlmock.NewRows([]string{"n"}).AddRow(64)).RowsWillBeClosed() 85 | 86 | case ExScalarSlice: 87 | mock.ExpectQuery("select +"). 88 | WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1).AddRow(2).AddRow(3)).RowsWillBeClosed() 89 | 90 | case ExStruct: 91 | mock.ExpectQuery("select +"). 92 | WillReturnRows(sqlmock.NewRows([]string{"min", "max"}). 93 | AddRow( 94 | // Don't use TimeGenerator here; these values are hard coded into the // Output: block of an example. 95 | time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), 96 | time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC)), 97 | ). 98 | RowsWillBeClosed() 99 | 100 | case ExStructNotFound: 101 | mock.ExpectQuery("select +"). 102 | WillReturnRows( 103 | sqlmock.NewRows([]string{ 104 | "id", "created", "modified", 105 | "customer_id", "customer_first", "customer_last", 106 | "contact_id", "contact_first", "contact_last"})). 107 | RowsWillBeClosed() 108 | 109 | } 110 | 111 | return DB, err 112 | } 113 | -------------------------------------------------------------------------------- /examples/pkg_doc.go: -------------------------------------------------------------------------------- 1 | // Package examples provides some common code for sqlh examples and tests. 
2 | package examples 3 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/nofeaturesonlybugs/sqlh 2 | 3 | go 1.16 4 | 5 | require ( 6 | github.com/DATA-DOG/go-sqlmock v1.5.0 7 | github.com/nofeaturesonlybugs/errors v1.1.1 8 | github.com/nofeaturesonlybugs/set v0.5.2 9 | github.com/stretchr/testify v1.7.2 10 | ) 11 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= 2 | github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= 3 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 5 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/nofeaturesonlybugs/errors v1.1.1 h1:hB+r5k391E9p/RwXYfLbkzmyzM+U4QLLe+Kt4SXa3uQ= 7 | github.com/nofeaturesonlybugs/errors v1.1.1/go.mod h1:nduBPri47dRYPSfaHxuboF0/Z2835DHp4qw9mtz8WeI= 8 | github.com/nofeaturesonlybugs/set v0.5.2 h1:jYpcJ80zMN9SeUiCv7DeJBHmnjagfZPd6Cxv7Wamas4= 9 | github.com/nofeaturesonlybugs/set v0.5.2/go.mod h1:PJIUfp+9nsOTBYANBs4H0wfLz9iYpasx0pGIZRcJCoU= 10 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 11 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 12 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 13 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 14 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 15 | github.com/stretchr/testify v1.7.2 
h1:4jaiDzPyXQvSd7D0EjG45355tLlV3VOECpq10pLC+8s= 16 | github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= 17 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 18 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 19 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 20 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 21 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 22 | -------------------------------------------------------------------------------- /grammar/errors.go: -------------------------------------------------------------------------------- 1 | package grammar 2 | 3 | import "errors" 4 | 5 | var ( 6 | ErrTableRequired error = errors.New("table name is required") 7 | ErrColumnsRequired error = errors.New("columns are required") 8 | ErrKeysRequired error = errors.New("keys are required") 9 | ) 10 | -------------------------------------------------------------------------------- /grammar/grammar.go: -------------------------------------------------------------------------------- 1 | package grammar 2 | 3 | import ( 4 | "github.com/nofeaturesonlybugs/sqlh/model/statements" 5 | ) 6 | 7 | // Grammar creates SQL queries for a specific database engine. 8 | type Grammar interface { 9 | // Delete returns the query type for deleting from the table. 10 | Delete(table string, keys []string) (*statements.Query, error) 11 | // Insert returns the query type for inserting into table. 12 | Insert(table string, columns []string, auto []string) (*statements.Query, error) 13 | // Update returns the query type for updating a record in a table. 
14 | Update(table string, columns []string, keys []string, auto []string) (*statements.Query, error) 15 | // Upsert returns the query type for upserting (INSERT|UPDATE) a record in a table. 16 | Upsert(table string, columns []string, keys []string, auto []string) (*statements.Query, error) 17 | } 18 | 19 | // TODO Implement Driver. 20 | // // Driver describes features of the underlying database/sql driver. 21 | // type Driver struct { 22 | // // Set to true if the driver supports LastInsertID. Currently drivers that do not support LastInsertID 23 | // // are expected to use a RETURNING clause in queries. 24 | // LastInsertID bool 25 | // } 26 | -------------------------------------------------------------------------------- /grammar/grammar_postgres_test.go: -------------------------------------------------------------------------------- 1 | package grammar_test 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | 9 | "github.com/nofeaturesonlybugs/sqlh/grammar" 10 | "github.com/nofeaturesonlybugs/sqlh/model/statements" 11 | ) 12 | 13 | func TestPostgresGrammar(t *testing.T) { 14 | chk := assert.New(t) 15 | // 16 | g := grammar.Postgres 17 | // 18 | { 19 | // inserts 20 | columns := []string{"a", "b", "c"} 21 | auto := []string{} 22 | query, err := g.Insert("foo", columns, auto) 23 | chk.NoError(err) 24 | chk.NotEmpty(query.SQL) 25 | chk.NotEmpty(query.Arguments) 26 | expect := "INSERT INTO foo\n\t\t( a, b, c )\n\tVALUES\n\t\t( $1, $2, $3 )" 27 | chk.Equal(expect, query.SQL) 28 | chk.Equal([]string{"a", "b", "c"}, query.Arguments) 29 | chk.Empty(query.Scan) 30 | // insert with returning 31 | columns = []string{"a", "b", "c"} 32 | auto = []string{"x", "y", "z"} 33 | query, err = g.Insert("foo", columns, auto) 34 | chk.NoError(err) 35 | chk.NotEmpty(query.SQL) 36 | chk.NotEmpty(query.Arguments) 37 | expect = "INSERT INTO foo\n\t\t( a, b, c )\n\tVALUES\n\t\t( $1, $2, $3 )\n\tRETURNING x, y, z" 38 | chk.Equal(expect, query.SQL) 
39 | chk.Equal([]string{"a", "b", "c"}, query.Arguments) 40 | chk.Equal([]string{"x", "y", "z"}, query.Scan) 41 | } 42 | // 43 | { 44 | // updates 45 | columns := []string{"a", "b", "c"} 46 | keys := []string{"x"} 47 | auto := []string(nil) 48 | query, err := g.Update("foo", columns, keys, auto) 49 | chk.NoError(err) 50 | chk.NotEmpty(query.SQL) 51 | chk.NotEmpty(query.Arguments) 52 | expect := "UPDATE foo SET\n\t\ta = $1,\n\t\tb = $2,\n\t\tc = $3\n\tWHERE\n\t\tx = $4" 53 | chk.Equal(expect, query.SQL) 54 | chk.Equal(append(append([]string{}, columns...), keys...), query.Arguments) 55 | chk.Empty(query.Scan) 56 | // update with returning 57 | columns = []string{"a", "b", "c"} 58 | keys = []string{"x"} 59 | auto = []string{"y", "z"} 60 | query, err = g.Update("foo", columns, keys, auto) 61 | chk.NoError(err) 62 | chk.NotEmpty(query.SQL) 63 | chk.NotEmpty(query.Arguments) 64 | expect = "UPDATE foo SET\n\t\ta = $1,\n\t\tb = $2,\n\t\tc = $3\n\tWHERE\n\t\tx = $4\n\tRETURNING y, z" 65 | chk.Equal(expect, query.SQL) 66 | chk.Equal(append(append([]string{}, columns...), keys...), query.Arguments) 67 | chk.Equal([]string{"y", "z"}, query.Scan) 68 | } 69 | // 70 | { 71 | // deletes 72 | keys := []string{"x"} 73 | query, err := g.Delete("foo", keys) 74 | chk.NoError(err) 75 | chk.NotEmpty(query.SQL) 76 | chk.NotEmpty(query.Arguments) 77 | expect := "DELETE FROM foo\n\tWHERE\n\t\tx = $1" 78 | chk.Equal(expect, query.SQL) 79 | chk.Equal(append([]string{}, keys...), query.Arguments) 80 | chk.Empty(query.Scan) 81 | // 82 | // composite key 83 | keys = []string{"x", "y", "z"} 84 | query, err = g.Delete("foo", keys) 85 | chk.NoError(err) 86 | chk.NotEmpty(query.SQL) 87 | chk.NotEmpty(query.Arguments) 88 | expect = "DELETE FROM foo\n\tWHERE\n\t\tx = $1 AND y = $2 AND z = $3" 89 | chk.Equal(expect, query.SQL) 90 | chk.Equal(append([]string{}, keys...), query.Arguments) 91 | chk.Empty(query.Scan) 92 | } 93 | } 94 | 95 | func TestPostgresReturnsErrors(t *testing.T) { 96 | chk := 
assert.New(t) 97 | // 98 | var err error 99 | table, columns, keys := "mytable", []string{"a", "b", "c"}, []string{"x", "y"} 100 | g := grammar.Postgres 101 | // Missing table name. 102 | _, err = g.Delete("", keys) 103 | chk.Error(err) 104 | _, err = g.Insert("", columns, nil) 105 | chk.Error(err) 106 | _, err = g.Update("", columns, keys, nil) 107 | chk.Error(err) 108 | _, err = g.Upsert("", columns, keys, nil) 109 | chk.Error(err) 110 | // Missing keys. 111 | _, err = g.Delete(table, nil) 112 | chk.Error(err) 113 | _, err = g.Update(table, columns, nil, nil) 114 | chk.Error(err) 115 | _, err = g.Upsert(table, columns, nil, nil) 116 | chk.Error(err) 117 | // Missing columns. 118 | _, err = g.Insert(table, nil, nil) 119 | chk.Error(err) 120 | _, err = g.Update(table, nil, keys, nil) 121 | chk.Error(err) 122 | _, err = g.Upsert(table, nil, keys, nil) 123 | chk.Error(err) 124 | } 125 | 126 | func TestPostgresGrammarUpsert(t *testing.T) { 127 | chk := assert.New(t) 128 | // 129 | g := grammar.Postgres 130 | { // single key, no auto 131 | columns := []string{"a", "b", "c"} 132 | keys := []string{"key"} 133 | query, err := g.Upsert("foo", columns, keys, nil) 134 | chk.NoError(err) 135 | chk.NotNil(query) 136 | chk.NotEmpty(query.SQL) 137 | chk.NotEmpty(query.Arguments) 138 | parts := []string{ 139 | "INSERT INTO foo AS dest\n\t\t( key, a, b, c )\n\tVALUES\n\t\t( $1, $2, $3, $4 )", 140 | "\tON CONFLICT( key ) DO UPDATE SET", 141 | "\t\ta = EXCLUDED.a, b = EXCLUDED.b, c = EXCLUDED.c", 142 | "\t\tWHERE (\n\t\t\tdest.a <> EXCLUDED.a OR dest.b <> EXCLUDED.b OR dest.c <> EXCLUDED.c\n\t\t)", 143 | } 144 | expect := strings.Join(parts, "\n") 145 | chk.Equal(expect, query.SQL) 146 | args := append([]string{}, keys...) 147 | args = append(args, columns...) 
148 | chk.Equal(args, query.Arguments) 149 | } 150 | { // composite key, no auto 151 | columns := []string{"a", "b", "c"} 152 | keys := []string{"key1", "key2", "key3"} 153 | query, err := g.Upsert("foo", columns, keys, nil) 154 | chk.NoError(err) 155 | chk.NotNil(query) 156 | chk.NotEmpty(query.SQL) 157 | chk.NotEmpty(query.Arguments) 158 | parts := []string{ 159 | "INSERT INTO foo AS dest\n\t\t( key1, key2, key3, a, b, c )\n\tVALUES\n\t\t( $1, $2, $3, $4, $5, $6 )", 160 | "\tON CONFLICT( key1, key2, key3 ) DO UPDATE SET", 161 | "\t\ta = EXCLUDED.a, b = EXCLUDED.b, c = EXCLUDED.c", 162 | "\t\tWHERE (\n\t\t\tdest.a <> EXCLUDED.a OR dest.b <> EXCLUDED.b OR dest.c <> EXCLUDED.c\n\t\t)", 163 | } 164 | expect := strings.Join(parts, "\n") 165 | chk.Equal(expect, query.SQL) 166 | args := append([]string{}, keys...) 167 | args = append(args, columns...) 168 | chk.Equal(args, query.Arguments) 169 | } 170 | { // composite key, has auto 171 | columns := []string{"a", "b", "c"} 172 | keys := []string{"key1", "key2", "key3"} 173 | auto := []string{"created", "modified"} 174 | query, err := g.Upsert("foo", columns, keys, auto) 175 | chk.NoError(err) 176 | chk.NotNil(query) 177 | chk.NotEmpty(query.SQL) 178 | chk.NotEmpty(query.Arguments) 179 | parts := []string{ 180 | "INSERT INTO foo AS dest\n\t\t( key1, key2, key3, a, b, c )\n\tVALUES\n\t\t( $1, $2, $3, $4, $5, $6 )", 181 | "\tON CONFLICT( key1, key2, key3 ) DO UPDATE SET", 182 | "\t\ta = EXCLUDED.a, b = EXCLUDED.b, c = EXCLUDED.c", 183 | "\t\tWHERE (\n\t\t\tdest.a <> EXCLUDED.a OR dest.b <> EXCLUDED.b OR dest.c <> EXCLUDED.c\n\t\t)", 184 | "\tRETURNING created, modified", 185 | } 186 | expect := strings.Join(parts, "\n") 187 | chk.Equal(expect, query.SQL) 188 | args := append([]string{}, keys...) 189 | args = append(args, columns...) 190 | chk.Equal(args, query.Arguments) 191 | } 192 | { 193 | // various errors... 
194 | var query *statements.Query 195 | var err error 196 | // 197 | // empty table 198 | query, err = g.Upsert("", []string{"a", "b", "c"}, []string{"key"}, nil) 199 | chk.Error(err) 200 | chk.Nil(query) 201 | // empty columns 202 | query, err = g.Upsert("foo", nil, []string{"key"}, nil) 203 | chk.Error(err) 204 | chk.Nil(query) 205 | // empty keys 206 | query, err = g.Upsert("foo", []string{"a", "b", "c"}, nil, nil) 207 | chk.Error(err) 208 | chk.Nil(query) 209 | } 210 | } 211 | -------------------------------------------------------------------------------- /grammar/grammar_sqlite_test.go: -------------------------------------------------------------------------------- 1 | package grammar_test 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | 9 | "github.com/nofeaturesonlybugs/sqlh/grammar" 10 | "github.com/nofeaturesonlybugs/sqlh/model/statements" 11 | ) 12 | 13 | func TestDefaultGrammar(t *testing.T) { 14 | chk := assert.New(t) 15 | // 16 | g := grammar.Sqlite 17 | // 18 | { 19 | // inserts 20 | columns := []string{"a", "b", "c"} 21 | auto := []string(nil) 22 | query, err := g.Insert("foo", columns, auto) 23 | chk.NoError(err) 24 | chk.NotEmpty(query.SQL) 25 | chk.NotEmpty(query.Arguments) 26 | expect := "INSERT INTO foo\n\t\t( a, b, c )\n\tVALUES\n\t\t( ?, ?, ? )" 27 | chk.Equal(expect, query.SQL) 28 | chk.Equal([]string{"a", "b", "c"}, query.Arguments) 29 | chk.Empty(query.Scan) 30 | // insert with returning 31 | columns = []string{"a", "b", "c"} 32 | auto = []string{"x", "y", "z"} 33 | query, err = g.Insert("foo", columns, auto) 34 | chk.NoError(err) 35 | chk.NotEmpty(query.SQL) 36 | chk.NotEmpty(query.Arguments) 37 | expect = "INSERT INTO foo\n\t\t( a, b, c )\n\tVALUES\n\t\t( ?, ?, ? 
)\n\tRETURNING x, y, z" 38 | chk.Equal(expect, query.SQL) 39 | chk.Equal([]string{"a", "b", "c"}, query.Arguments) 40 | chk.Equal([]string{"x", "y", "z"}, query.Scan) 41 | } 42 | // 43 | { 44 | // updates 45 | columns := []string{"a", "b", "c"} 46 | keys := []string{"x"} 47 | auto := []string(nil) 48 | query, err := g.Update("foo", columns, keys, auto) 49 | chk.NoError(err) 50 | chk.NotEmpty(query.SQL) 51 | chk.NotEmpty(query.Arguments) 52 | expect := "UPDATE foo SET\n\t\ta = ?,\n\t\tb = ?,\n\t\tc = ?\n\tWHERE\n\t\tx = ?" 53 | chk.Equal(expect, query.SQL) 54 | chk.Equal(append(append([]string{}, columns...), keys...), query.Arguments) 55 | chk.Empty(query.Scan) 56 | // update with returning 57 | columns = []string{"a", "b", "c"} 58 | keys = []string{"x"} 59 | auto = []string{"y", "z"} 60 | query, err = g.Update("foo", columns, keys, auto) 61 | chk.NoError(err) 62 | chk.NotEmpty(query.SQL) 63 | chk.NotEmpty(query.Arguments) 64 | expect = "UPDATE foo SET\n\t\ta = ?,\n\t\tb = ?,\n\t\tc = ?\n\tWHERE\n\t\tx = ?\n\tRETURNING y, z" 65 | chk.Equal(expect, query.SQL) 66 | chk.Equal(append(append([]string{}, columns...), keys...), query.Arguments) 67 | chk.Equal([]string{"y", "z"}, query.Scan) 68 | } 69 | // 70 | { 71 | // deletes 72 | keys := []string{"x"} 73 | query, err := g.Delete("foo", keys) 74 | chk.NoError(err) 75 | chk.NotEmpty(query.SQL) 76 | chk.NotEmpty(query.Arguments) 77 | expect := "DELETE FROM foo\n\tWHERE\n\t\tx = ?" 78 | chk.Equal(expect, query.SQL) 79 | chk.Equal(append([]string{}, keys...), query.Arguments) 80 | chk.Empty(query.Scan) 81 | // 82 | // composite key 83 | keys = []string{"x", "y", "z"} 84 | query, err = g.Delete("foo", keys) 85 | chk.NoError(err) 86 | chk.NotEmpty(query.SQL) 87 | chk.NotEmpty(query.Arguments) 88 | expect = "DELETE FROM foo\n\tWHERE\n\t\tx = ? AND y = ? AND z = ?" 
89 | chk.Equal(expect, query.SQL) 90 | chk.Equal(append([]string{}, keys...), query.Arguments) 91 | chk.Empty(query.Scan) 92 | } 93 | } 94 | 95 | func TestDefaultReturnsErrors(t *testing.T) { 96 | chk := assert.New(t) 97 | // 98 | var err error 99 | table, columns, keys := "mytable", []string{"a", "b", "c"}, []string{"x", "y"} 100 | g := grammar.Sqlite 101 | // Missing table name. 102 | _, err = g.Delete("", keys) 103 | chk.Error(err) 104 | _, err = g.Insert("", columns, nil) 105 | chk.Error(err) 106 | _, err = g.Update("", columns, keys, nil) 107 | chk.Error(err) 108 | _, err = g.Upsert("", columns, keys, nil) 109 | chk.Error(err) 110 | // Missing keys. 111 | _, err = g.Delete(table, nil) 112 | chk.Error(err) 113 | _, err = g.Update(table, columns, nil, nil) 114 | chk.Error(err) 115 | _, err = g.Upsert(table, columns, nil, nil) 116 | chk.Error(err) 117 | // Missing columns. 118 | _, err = g.Insert(table, nil, nil) 119 | chk.Error(err) 120 | _, err = g.Update(table, nil, keys, nil) 121 | chk.Error(err) 122 | _, err = g.Upsert(table, nil, keys, nil) 123 | chk.Error(err) 124 | } 125 | 126 | func TestDefaultGrammarUpsert(t *testing.T) { 127 | chk := assert.New(t) 128 | // 129 | g := grammar.Sqlite 130 | { // single key, no auto 131 | columns := []string{"a", "b", "c"} 132 | keys := []string{"key"} 133 | query, err := g.Upsert("foo", columns, keys, nil) 134 | chk.NoError(err) 135 | chk.NotNil(query) 136 | chk.NotEmpty(query.SQL) 137 | chk.NotEmpty(query.Arguments) 138 | parts := []string{ 139 | "INSERT INTO foo\n\t\t( key, a, b, c )\n\tVALUES\n\t\t( ?, ?, ?, ? )", 140 | "\tON CONFLICT( key ) DO UPDATE SET", 141 | "\t\tfoo.a = EXCLUDED.a, foo.b = EXCLUDED.b, foo.c = EXCLUDED.c", 142 | "\t\tWHERE (\n\t\t\tfoo.a <> EXCLUDED.a OR foo.b <> EXCLUDED.b OR foo.c <> EXCLUDED.c\n\t\t)", 143 | } 144 | expect := strings.Join(parts, "\n") 145 | chk.Equal(expect, query.SQL) 146 | args := append([]string{}, keys...) 147 | args = append(args, columns...) 
148 | chk.Equal(args, query.Arguments) 149 | } 150 | { // composite key, no auto 151 | columns := []string{"a", "b", "c"} 152 | keys := []string{"key1", "key2", "key3"} 153 | query, err := g.Upsert("foo", columns, keys, nil) 154 | chk.NoError(err) 155 | chk.NotNil(query) 156 | chk.NotEmpty(query.SQL) 157 | chk.NotEmpty(query.Arguments) 158 | parts := []string{ 159 | "INSERT INTO foo\n\t\t( key1, key2, key3, a, b, c )\n\tVALUES\n\t\t( ?, ?, ?, ?, ?, ? )", 160 | "\tON CONFLICT( key1, key2, key3 ) DO UPDATE SET", 161 | "\t\tfoo.a = EXCLUDED.a, foo.b = EXCLUDED.b, foo.c = EXCLUDED.c", 162 | "\t\tWHERE (\n\t\t\tfoo.a <> EXCLUDED.a OR foo.b <> EXCLUDED.b OR foo.c <> EXCLUDED.c\n\t\t)", 163 | } 164 | expect := strings.Join(parts, "\n") 165 | chk.Equal(expect, query.SQL) 166 | args := append([]string{}, keys...) 167 | args = append(args, columns...) 168 | chk.Equal(args, query.Arguments) 169 | } 170 | { // composite key, has auto 171 | columns := []string{"a", "b", "c"} 172 | keys := []string{"key1", "key2", "key3"} 173 | auto := []string{"created", "modified"} 174 | query, err := g.Upsert("foo", columns, keys, auto) 175 | chk.NoError(err) 176 | chk.NotNil(query) 177 | chk.NotEmpty(query.SQL) 178 | chk.NotEmpty(query.Arguments) 179 | parts := []string{ 180 | "INSERT INTO foo\n\t\t( key1, key2, key3, a, b, c )\n\tVALUES\n\t\t( ?, ?, ?, ?, ?, ? )", 181 | "\tON CONFLICT( key1, key2, key3 ) DO UPDATE SET", 182 | "\t\tfoo.a = EXCLUDED.a, foo.b = EXCLUDED.b, foo.c = EXCLUDED.c", 183 | "\t\tWHERE (\n\t\t\tfoo.a <> EXCLUDED.a OR foo.b <> EXCLUDED.b OR foo.c <> EXCLUDED.c\n\t\t)", 184 | "\tRETURNING created, modified", 185 | } 186 | expect := strings.Join(parts, "\n") 187 | chk.Equal(expect, query.SQL) 188 | args := append([]string{}, keys...) 189 | args = append(args, columns...) 190 | chk.Equal(args, query.Arguments) 191 | } 192 | { 193 | // various errors... 
194 | var query *statements.Query 195 | var err error 196 | // 197 | // empty table 198 | query, err = g.Upsert("", []string{"a", "b", "c"}, []string{"key"}, nil) 199 | chk.Error(err) 200 | chk.Nil(query) 201 | // empty columns 202 | query, err = g.Upsert("foo", nil, []string{"key"}, nil) 203 | chk.Error(err) 204 | chk.Nil(query) 205 | // empty keys 206 | query, err = g.Upsert("foo", []string{"a", "b", "c"}, nil, nil) 207 | chk.Error(err) 208 | chk.Nil(query) 209 | } 210 | } 211 | -------------------------------------------------------------------------------- /grammar/pkg_doc.go: -------------------------------------------------------------------------------- 1 | // Package grammar generates SQL statements described by a configured grammar. 2 | package grammar 3 | -------------------------------------------------------------------------------- /grammar/postgres.go: -------------------------------------------------------------------------------- 1 | package grammar 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/nofeaturesonlybugs/errors" 8 | "github.com/nofeaturesonlybugs/sqlh/model/statements" 9 | ) 10 | 11 | // Postgres is an instantiated grammar for PostgreSQL. 12 | var Postgres Grammar = &PostgresGrammar{} 13 | 14 | // PostgresGrammar defines a grammar for Postgres. 15 | type PostgresGrammar struct { 16 | } 17 | 18 | // ParamN returns the string for parameter N where N is zero-based and return value is one-based. 19 | func (me *PostgresGrammar) ParamN(n int) string { 20 | return fmt.Sprintf("$%v", n+1) 21 | } 22 | 23 | // Delete returns the query type for deleting from the table. 
24 | func (me *PostgresGrammar) Delete(table string, keys []string) (*statements.Query, error) { 25 | var keySize int 26 | if table == "" { 27 | return nil, errors.Go(ErrTableRequired) 28 | } else if keySize = len(keys); keySize == 0 { 29 | return nil, errors.Go(ErrKeysRequired).Tag("table", table).Tag("SQL", "DELETE") 30 | } 31 | rv := &statements.Query{ 32 | Arguments: make([]string, keySize), 33 | } 34 | wheres := make([]string, keySize) 35 | for k, key := range keys { 36 | wheres[k] = key + " = " + me.ParamN(k) 37 | rv.Arguments[k] = key 38 | } 39 | // 40 | parts := []string{ 41 | "DELETE FROM " + table, 42 | "\tWHERE", 43 | "\t\t" + strings.Join(wheres, " AND "), 44 | } 45 | rv.SQL = strings.Join(parts, "\n") 46 | return rv, nil 47 | } 48 | 49 | // Insert returns the query type for inserting into table. 50 | func (me *PostgresGrammar) Insert(table string, columns []string, auto []string) (*statements.Query, error) { 51 | var colSize int 52 | if table == "" { 53 | return nil, errors.Go(ErrTableRequired) 54 | } else if colSize = len(columns); colSize == 0 { 55 | return nil, errors.Go(ErrColumnsRequired).Tag("table", table).Tag("SQL", "INSERT") 56 | } 57 | rv := &statements.Query{ 58 | Arguments: make([]string, colSize), 59 | } 60 | values := make([]string, colSize) 61 | for k, column := range columns { 62 | values[k] = me.ParamN(k) 63 | rv.Arguments[k] = column 64 | } 65 | // 66 | parts := []string{ 67 | "INSERT INTO " + table, 68 | "\t\t( " + strings.Join(columns, ", ") + " )", 69 | "\tVALUES", 70 | "\t\t( " + strings.Join(values, ", ") + " )", 71 | } 72 | if len(auto) > 0 { 73 | parts = append(parts, "\tRETURNING "+strings.Join(auto, ", ")) 74 | rv.Scan = append([]string{}, auto...) 75 | rv.Expect = statements.ExpectRow 76 | } 77 | rv.SQL = strings.Join(parts, "\n") 78 | return rv, nil 79 | } 80 | 81 | // Update returns the query type for updating a record in a table. 
82 | func (me *PostgresGrammar) Update(table string, columns []string, keys []string, auto []string) (*statements.Query, error) { 83 | var colSize, keySize int 84 | if table == "" { 85 | return nil, errors.Go(ErrTableRequired) 86 | } else if colSize = len(columns); colSize == 0 { 87 | return nil, errors.Go(ErrColumnsRequired).Tag("table", table).Tag("SQL", "UPDATE") 88 | } else if keySize = len(keys); keySize == 0 { 89 | return nil, errors.Go(ErrKeysRequired).Tag("table", table).Tag("SQL", "UPDATE") 90 | } 91 | rv := &statements.Query{ 92 | Arguments: make([]string, colSize+keySize), 93 | } 94 | // 95 | sets, wheres := make([]string, colSize), make([]string, keySize) 96 | for k, column := range columns { 97 | sets[k] = column + " = " + me.ParamN(k) 98 | rv.Arguments[k] = column 99 | } 100 | for k, key := range keys { 101 | total := colSize + k 102 | wheres[k] = key + " = " + me.ParamN(total) 103 | rv.Arguments[total] = key 104 | } 105 | // 106 | parts := []string{ 107 | "UPDATE " + table + " SET", 108 | "\t\t" + strings.Join(sets, ",\n\t\t"), 109 | "\tWHERE", 110 | "\t\t" + strings.Join(wheres, " AND "), 111 | } 112 | if len(auto) > 0 { 113 | parts = append(parts, "\tRETURNING "+strings.Join(auto, ", ")) 114 | rv.Scan = append([]string{}, auto...) 115 | rv.Expect = statements.ExpectRowOrNone 116 | } 117 | rv.SQL = strings.Join(parts, "\n") 118 | return rv, nil 119 | } 120 | 121 | // Upsert returns the query type for upserting (INSERT|UPDATE) a record in a table. 
122 | func (me *PostgresGrammar) Upsert(table string, columns []string, keys []string, auto []string) (*statements.Query, error) { 123 | var colSize, keySize int 124 | if table == "" { 125 | return nil, errors.Go(ErrTableRequired) 126 | } else if colSize = len(columns); colSize == 0 { 127 | return nil, errors.Go(ErrColumnsRequired).Tag("table", table).Tag("SQL", "UPDATE") 128 | } else if keySize = len(keys); keySize == 0 { 129 | return nil, errors.Go(ErrKeysRequired).Tag("table", table).Tag("SQL", "UPDATE") 130 | } 131 | // Both keys + columns are combined for the INSERT portion of the query. 132 | sizeInsert := colSize + keySize 133 | rv := &statements.Query{ 134 | Arguments: make([]string, sizeInsert), 135 | } 136 | copy(rv.Arguments[0:], keys) 137 | copy(rv.Arguments[keySize:], columns) 138 | // INSERT...VALUES portion. 139 | values := make([]string, sizeInsert) 140 | for k := range rv.Arguments { 141 | values[k] = me.ParamN(k) 142 | } 143 | // Create an AS alias for the target table. 144 | alias := "dest" 145 | // Only columns are used for the DO UPDATE portion of the query. 146 | updateColumns := make([]string, colSize) 147 | whereColumns := make([]string, colSize) 148 | for k, column := range columns { 149 | updateColumns[k] = column + " = EXCLUDED." + column 150 | whereColumns[k] = alias + "." + column + " <> EXCLUDED." + column 151 | } 152 | // 153 | parts := []string{ 154 | "INSERT INTO " + table + " AS " + alias, 155 | "\t\t( " + strings.Join(rv.Arguments, ", ") + " )", 156 | "\tVALUES", 157 | "\t\t( " + strings.Join(values, ", ") + " )", 158 | "\tON CONFLICT( " + strings.Join(keys, ", ") + " ) DO UPDATE SET", 159 | "\t\t" + strings.Join(updateColumns, ", "), 160 | "\t\tWHERE (", 161 | "\t\t\t" + strings.Join(whereColumns, " OR "), 162 | "\t\t)", 163 | } 164 | if len(auto) > 0 { 165 | parts = append(parts, "\tRETURNING "+strings.Join(auto, ", ")) 166 | rv.Scan = append([]string{}, auto...) 
167 | rv.Expect = statements.ExpectRowOrNone 168 | } 169 | rv.SQL = strings.Join(parts, "\n") 170 | return rv, nil 171 | } 172 | -------------------------------------------------------------------------------- /grammar/sqlite.go: -------------------------------------------------------------------------------- 1 | package grammar 2 | 3 | import ( 4 | "strings" 5 | 6 | "github.com/nofeaturesonlybugs/errors" 7 | "github.com/nofeaturesonlybugs/sqlh/model/statements" 8 | ) 9 | 10 | // Sqlite is an instantiated grammar for SQLite. 11 | var Sqlite Grammar = &SqliteGrammar{} 12 | 13 | // SqliteGrammar defines a grammar for SQLite v2.35+. 14 | type SqliteGrammar struct { 15 | } 16 | 17 | // Delete returns the query type for deleting from the table. 18 | func (me *SqliteGrammar) Delete(table string, keys []string) (*statements.Query, error) { 19 | var keySize int 20 | if table == "" { 21 | return nil, errors.Go(ErrTableRequired) 22 | } else if keySize = len(keys); keySize == 0 { 23 | return nil, errors.Go(ErrKeysRequired).Tag("table", table).Tag("SQL", "DELETE") 24 | } 25 | rv := &statements.Query{ 26 | Arguments: make([]string, keySize), 27 | } 28 | // 29 | wheres := make([]string, keySize) 30 | for k, key := range keys { 31 | wheres[k] = key + " = ?" 32 | rv.Arguments[k] = key 33 | } 34 | // 35 | parts := []string{ 36 | "DELETE FROM " + table, 37 | "\tWHERE", 38 | "\t\t" + strings.Join(wheres, " AND "), 39 | } 40 | rv.SQL = strings.Join(parts, "\n") 41 | return rv, nil 42 | } 43 | 44 | // Insert returns the query type for inserting into table. 
45 | func (me *SqliteGrammar) Insert(table string, columns []string, auto []string) (*statements.Query, error) { 46 | var colSize int 47 | if table == "" { 48 | return nil, errors.Go(ErrTableRequired) 49 | } else if colSize = len(columns); colSize == 0 { 50 | return nil, errors.Go(ErrColumnsRequired).Tag("table", table).Tag("SQL", "INSERT") 51 | } 52 | rv := &statements.Query{ 53 | Arguments: make([]string, colSize), 54 | } 55 | copy(rv.Arguments[0:], columns) 56 | values := "?" + strings.Repeat(", ?", colSize-1) 57 | // 58 | parts := []string{ 59 | "INSERT INTO " + table, 60 | "\t\t( " + strings.Join(columns, ", ") + " )", 61 | "\tVALUES", 62 | "\t\t( " + values + " )", 63 | } 64 | if len(auto) > 0 { 65 | parts = append(parts, "\tRETURNING "+strings.Join(auto, ", ")) 66 | rv.Scan = append([]string{}, auto...) 67 | rv.Expect = statements.ExpectRow 68 | } 69 | rv.SQL = strings.Join(parts, "\n") 70 | return rv, nil 71 | } 72 | 73 | // Update returns the query type for updating a record in a table. 74 | func (me *SqliteGrammar) Update(table string, columns []string, keys []string, auto []string) (*statements.Query, error) { 75 | var colSize, keySize int 76 | if table == "" { 77 | return nil, errors.Go(ErrTableRequired) 78 | } else if colSize = len(columns); colSize == 0 { 79 | return nil, errors.Go(ErrColumnsRequired).Tag("table", table).Tag("SQL", "UPDATE") 80 | } else if keySize = len(keys); keySize == 0 { 81 | return nil, errors.Go(ErrKeysRequired).Tag("table", table).Tag("SQL", "UPDATE") 82 | } 83 | rv := &statements.Query{ 84 | Arguments: make([]string, colSize+keySize), 85 | } 86 | sets, wheres := make([]string, colSize), make([]string, keySize) 87 | for k, column := range columns { 88 | sets[k] = column + " = ?" 89 | rv.Arguments[k] = column 90 | } 91 | for k, key := range keys { 92 | wheres[k] = key + " = ?" 
93 | rv.Arguments[colSize+k] = key 94 | } 95 | // 96 | parts := []string{ 97 | "UPDATE " + table + " SET", 98 | "\t\t" + strings.Join(sets, ",\n\t\t"), 99 | "\tWHERE", 100 | "\t\t" + strings.Join(wheres, " AND "), 101 | } 102 | if len(auto) > 0 { 103 | parts = append(parts, "\tRETURNING "+strings.Join(auto, ", ")) 104 | rv.Scan = append([]string{}, auto...) 105 | rv.Expect = statements.ExpectRow 106 | } 107 | rv.SQL = strings.Join(parts, "\n") 108 | return rv, nil 109 | } 110 | 111 | // Upsert returns the query type for upserting (INSERT|UPDATE) a record in a table. 112 | func (me *SqliteGrammar) Upsert(table string, columns []string, keys []string, auto []string) (*statements.Query, error) { 113 | var colSize, keySize int 114 | if table == "" { 115 | return nil, errors.Go(ErrTableRequired) 116 | } else if colSize = len(columns); colSize == 0 { 117 | return nil, errors.Go(ErrColumnsRequired).Tag("table", table).Tag("SQL", "UPDATE") 118 | } else if keySize = len(keys); keySize == 0 { 119 | return nil, errors.Go(ErrKeysRequired).Tag("table", table).Tag("SQL", "UPDATE") 120 | } 121 | // Both keys + columns are combined for the INSERT portion of the query. 122 | sizeInsert := colSize + keySize 123 | rv := &statements.Query{ 124 | Arguments: make([]string, sizeInsert), 125 | } 126 | copy(rv.Arguments[0:], keys) 127 | copy(rv.Arguments[keySize:], columns) 128 | // Only columns are used for the DO UPDATE portion of the query. 129 | updateColumns := make([]string, colSize) 130 | whereColumns := make([]string, colSize) 131 | for k, column := range columns { 132 | updateColumns[k] = table + "." + column + " = EXCLUDED." + column 133 | whereColumns[k] = table + "." + column + " <> EXCLUDED." + column 134 | } 135 | // 136 | // The INSERT...VALUES portion of the query 137 | values := "?" 
+ strings.Repeat(", ?", sizeInsert-1) 138 | // 139 | parts := []string{ 140 | "INSERT INTO " + table, 141 | "\t\t( " + strings.Join(rv.Arguments, ", ") + " )", 142 | "\tVALUES", 143 | "\t\t( " + values + " )", 144 | "\tON CONFLICT( " + strings.Join(keys, ", ") + " ) DO UPDATE SET", 145 | "\t\t" + strings.Join(updateColumns, ", "), 146 | "\t\tWHERE (", 147 | "\t\t\t" + strings.Join(whereColumns, " OR "), 148 | "\t\t)", 149 | } 150 | if len(auto) > 0 { 151 | parts = append(parts, "\tRETURNING "+strings.Join(auto, ", ")) 152 | rv.Scan = append([]string{}, auto...) 153 | rv.Expect = statements.ExpectRowOrNone 154 | } 155 | rv.SQL = strings.Join(parts, "\n") 156 | return rv, nil 157 | } 158 | -------------------------------------------------------------------------------- /hobbled/pkg.go: -------------------------------------------------------------------------------- 1 | // Package hobbled allows creation of hobbled or deficient database types to facilitate testing within sqlh. 2 | package hobbled 3 | 4 | import ( 5 | "context" 6 | "database/sql" 7 | "fmt" 8 | 9 | "github.com/nofeaturesonlybugs/sqlh" 10 | ) 11 | 12 | // Wrapper is a type that wraps around a DB in order to hobble it. 13 | type Wrapper int 14 | 15 | const ( 16 | // Passthru does not modify database types given to it. 17 | Passthru Wrapper = iota 18 | // NoBegin removes Begin() from database types given to it. 19 | NoBegin 20 | // NoBeginNoPrepare removes Begin() and Prepare() from types given to it. 21 | NoBeginNoPrepare 22 | ) 23 | 24 | // String describes the wrapper type. 25 | func (me Wrapper) String() string { 26 | return [...]string{"DB", "DB w/o begin", "DB w/o begin+prepare"}[me] 27 | } 28 | 29 | // WrapDB wraps the given DB and returns a sqlh.IQueries instance. 
30 | func (me Wrapper) WrapDB(db *sql.DB) sqlh.IQueries { 31 | switch me { 32 | case Passthru: 33 | return db 34 | case NoBegin: 35 | return NewWithoutBegin(db) 36 | case NoBeginNoPrepare: 37 | return NewWithoutBeginWithoutPrepare(db) 38 | } 39 | panic(fmt.Sprintf("unknown %T %v", me, int(me))) 40 | } 41 | 42 | // WithoutBegin has no Begin call. 43 | type WithoutBegin interface { 44 | Exec(query string, args ...interface{}) (sql.Result, error) 45 | ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) 46 | Prepare(query string) (*sql.Stmt, error) 47 | PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) 48 | Query(query string, args ...interface{}) (*sql.Rows, error) 49 | QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) 50 | QueryRow(query string, args ...interface{}) *sql.Row 51 | QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row 52 | } 53 | 54 | // NewWithoutBegin returns a WithoutBegin type. 55 | func NewWithoutBegin(db interface{}) WithoutBegin { 56 | type T struct { 57 | *canQuery 58 | *canPrepare 59 | } 60 | switch tt := db.(type) { 61 | case *sql.DB: 62 | return &T{ 63 | canQuery: &canQuery{db: tt}, 64 | canPrepare: &canPrepare{db: tt}, 65 | } 66 | case *sql.Tx: 67 | return &T{ 68 | canQuery: &canQuery{tx: tt}, 69 | canPrepare: &canPrepare{tx: tt}, 70 | } 71 | } 72 | panic("db is not a *sql.DB or *sql.Tx") 73 | } 74 | 75 | // WithoutPrepare has no Prepare calls. 
76 | type WithoutPrepare interface { 77 | Begin() (*sql.Tx, error) 78 | Exec(query string, args ...interface{}) (sql.Result, error) 79 | ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) 80 | Query(query string, args ...interface{}) (*sql.Rows, error) 81 | QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) 82 | QueryRow(query string, args ...interface{}) *sql.Row 83 | QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row 84 | } 85 | 86 | // NewWithoutPrepare returns a WithoutPrepare type. 87 | func NewWithoutPrepare(db interface{}) WithoutPrepare { 88 | type T struct { 89 | *canBegin 90 | *canQuery 91 | } 92 | switch tt := db.(type) { 93 | case *sql.DB: 94 | return &T{ 95 | canBegin: &canBegin{db: tt}, 96 | canQuery: &canQuery{db: tt}, 97 | } 98 | case *sql.Tx: 99 | panic("*sql.Tx can not be a NoPrepare because NoPrepare has a Begin() method.") 100 | } 101 | panic("db is not a *sql.DB or *sql.Tx") 102 | } 103 | 104 | // WithoutBeginWithoutPrepare has no Begin or Prepare calls. 105 | type WithoutBeginWithoutPrepare interface { 106 | Exec(query string, args ...interface{}) (sql.Result, error) 107 | ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) 108 | Query(query string, args ...interface{}) (*sql.Rows, error) 109 | QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) 110 | QueryRow(query string, args ...interface{}) *sql.Row 111 | QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row 112 | } 113 | 114 | // NewWithoutBeginWithoutPrepare returns a WithoutBeginWithoutPrepare type. 
115 | func NewWithoutBeginWithoutPrepare(db interface{}) WithoutBeginWithoutPrepare { 116 | switch tt := db.(type) { 117 | case *sql.DB: 118 | return &canQuery{db: tt} 119 | case *sql.Tx: 120 | return &canQuery{tx: tt} 121 | } 122 | panic("db is not a *sql.DB or *sql.Tx") 123 | } 124 | 125 | // canBegin allows the begin function. 126 | type canBegin struct { 127 | db *sql.DB 128 | } 129 | 130 | func (me *canBegin) Begin() (*sql.Tx, error) { 131 | return me.db.Begin() 132 | } 133 | 134 | // canQuery allows the query functions. 135 | type canQuery struct { 136 | db *sql.DB 137 | tx *sql.Tx 138 | } 139 | 140 | func (me *canQuery) Exec(query string, args ...interface{}) (sql.Result, error) { 141 | if me.tx != nil { 142 | return me.tx.Exec(query, args...) 143 | } 144 | return me.db.Exec(query, args...) 145 | } 146 | func (me *canQuery) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { 147 | if me.tx != nil { 148 | return me.tx.ExecContext(ctx, query, args...) 149 | } 150 | return me.db.ExecContext(ctx, query, args...) 151 | } 152 | func (me *canQuery) Query(query string, args ...interface{}) (*sql.Rows, error) { 153 | if me.tx != nil { 154 | return me.tx.Query(query, args...) 155 | } 156 | return me.db.Query(query, args...) 157 | } 158 | func (me *canQuery) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { 159 | if me.tx != nil { 160 | return me.tx.QueryContext(ctx, query, args...) 161 | } 162 | return me.db.QueryContext(ctx, query, args...) 163 | } 164 | func (me *canQuery) QueryRow(query string, args ...interface{}) *sql.Row { 165 | if me.tx != nil { 166 | return me.tx.QueryRow(query, args...) 167 | } 168 | return me.db.QueryRow(query, args...) 169 | } 170 | func (me *canQuery) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { 171 | if me.tx != nil { 172 | return me.tx.QueryRowContext(ctx, query, args...) 
173 | } 174 | return me.db.QueryRowContext(ctx, query, args...) 175 | } 176 | 177 | // canPrepare allows the prepare functions. 178 | type canPrepare struct { 179 | db *sql.DB 180 | tx *sql.Tx 181 | } 182 | 183 | func (me *canPrepare) Prepare(query string) (*sql.Stmt, error) { 184 | if me.tx != nil { 185 | return me.tx.Prepare(query) 186 | } 187 | return me.db.Prepare(query) 188 | } 189 | 190 | func (me *canPrepare) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { 191 | if me.tx != nil { 192 | return me.tx.PrepareContext(ctx, query) 193 | } 194 | return me.db.PrepareContext(ctx, query) 195 | } 196 | -------------------------------------------------------------------------------- /interfaces.go: -------------------------------------------------------------------------------- 1 | package sqlh 2 | 3 | import ( 4 | "database/sql" 5 | ) 6 | 7 | // IQueries defines the methods common to types that can run queries. 8 | type IQueries interface { 9 | Exec(query string, args ...interface{}) (sql.Result, error) 10 | Query(query string, args ...interface{}) (*sql.Rows, error) 11 | QueryRow(query string, args ...interface{}) *sql.Row 12 | } 13 | 14 | // IPrepares defines the methods required to run prepared statements. 15 | type IPrepares interface { 16 | Prepare(query string) (*sql.Stmt, error) 17 | } 18 | 19 | // IIterates defines the methods required for iterating a query result set. 20 | type IIterates interface { 21 | Close() error 22 | Columns() ([]string, error) 23 | Err() error 24 | Next() bool 25 | Scan(dest ...interface{}) error 26 | } 27 | 28 | // IBegins defines the method(s) required to open a transaction. 
29 | type IBegins interface { 30 | Begin() (*sql.Tx, error) 31 | } 32 | -------------------------------------------------------------------------------- /internal_test.go: -------------------------------------------------------------------------------- 1 | package sqlh 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestScannerDestType(t *testing.T) { 10 | chk := assert.New(t) 11 | // 12 | chk.Equal("Invalid", scannerDestType(0).String()) 13 | } 14 | -------------------------------------------------------------------------------- /model/errors.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import "errors" 4 | 5 | var ErrUnsupported error = errors.New("unsupported") 6 | -------------------------------------------------------------------------------- /model/examples/mock.go: -------------------------------------------------------------------------------- 1 | package examples 2 | 3 | import ( 4 | "database/sql" 5 | "database/sql/driver" 6 | "math/rand" 7 | "strings" 8 | "time" 9 | 10 | "github.com/DATA-DOG/go-sqlmock" 11 | ) 12 | 13 | // SentinalTime is a set time value used to generate times. 14 | var SentinalTime time.Time = time.Date(2006, 1, 2, 3, 4, 5, 7, time.Local) 15 | 16 | // TimeGenerator uses SentinalTime to return deterministic time values. 17 | type TimeGenerator struct { 18 | n int 19 | } 20 | 21 | // Next returns the next time.Time. 22 | func (tg *TimeGenerator) Next() time.Time { 23 | rv := SentinalTime.Add(time.Duration(tg.n) * time.Hour) 24 | tg.n++ 25 | return rv 26 | } 27 | 28 | // Example is a specific example. 
29 | type Example int 30 | 31 | const ( 32 | ExNone Example = iota 33 | ExAddressInsert 34 | ExAddressInsertSlice 35 | ExAddressUpdate 36 | ExAddressUpdateSlice 37 | ExAddressSave 38 | ExAddressSaveSlice 39 | ExLogEntrySave 40 | ExRelationshipInsert 41 | ExRelationshipInsertSlice 42 | ExRelationshipUpdate 43 | ExRelationshipUpdateSlice 44 | ExRelationshipUpsert 45 | ExRelationshipUpsertSlice 46 | ExRelationshipSave 47 | ExUpsert 48 | ExUpsertSlice 49 | ) 50 | 51 | // ReturnArgs creates enough return args for the specified number of models in n. 52 | // Columns can be: pk (int), created (time.Time), modified (time.Time) 53 | func ReturnArgs(n int, columns ...string) []driver.Value { 54 | var rv []driver.Value 55 | var created time.Time 56 | for k := 0; k < n; k++ { 57 | for _, column := range columns { 58 | switch column { 59 | case "pk": 60 | rv = append(rv, rand.Int()) 61 | case "created": 62 | created = time.Now().Add(-1 * time.Duration((rand.Int() % 3600)) * time.Hour) 63 | rv = append(rv, created) 64 | case "modified": 65 | rv = append(rv, created.Add(time.Duration((rand.Int()%3600))*time.Hour)) 66 | } 67 | } 68 | } 69 | return rv 70 | } 71 | 72 | // Connect creates a sqlmock DB and configures it for the example. 73 | func Connect(e Example) (DB *sql.DB, err error) { 74 | var mock sqlmock.Sqlmock 75 | DB, mock, err = sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) 76 | // 77 | switch e { 78 | case ExAddressInsert: 79 | parts := []string{ 80 | "INSERT INTO addresses", 81 | "\t\t( street, city, state, zip )", 82 | "\tVALUES", 83 | "\t\t( $1, $2, $3, $4 )", 84 | "\tRETURNING pk, created_tmz, modified_tmz", 85 | } 86 | allRows := []*sqlmock.Rows{ 87 | sqlmock.NewRows([]string{"pk", "created_tmz", "modified_tmz"}). 88 | AddRow(ReturnArgs(1, "pk", "created", "modified")...), 89 | sqlmock.NewRows([]string{"pk", "created_tmz", "modified_tmz"}). 
90 | AddRow(ReturnArgs(1, "pk", "created", "modified")...), 91 | } 92 | mock.ExpectQuery(strings.Join(parts, "\n")). 93 | WithArgs("1234 The Street", "Small City", "ST", "98765"). 94 | WillReturnRows(allRows[0]). 95 | RowsWillBeClosed() 96 | mock.ExpectQuery(strings.Join(parts, "\n")). 97 | WithArgs("4321 The Street", "Big City", "TS", "56789"). 98 | WillReturnRows(allRows[1]). 99 | RowsWillBeClosed() 100 | 101 | case ExAddressInsertSlice: 102 | parts := []string{ 103 | "INSERT INTO addresses", 104 | "\t\t( street, city, state, zip )", 105 | "\tVALUES", 106 | "\t\t( $1, $2, $3, $4 )", 107 | "\tRETURNING pk, created_tmz, modified_tmz", 108 | } 109 | for k := 0; k < 2; k++ { 110 | mock.ExpectBegin() 111 | prepared := mock.ExpectPrepare(strings.Join(parts, "\n")) 112 | rows := sqlmock.NewRows([]string{"pk", "created_tmz", "modified_tmz"}). 113 | AddRow(ReturnArgs(1, "pk", "created", "modified")...) 114 | prepared.ExpectQuery(). 115 | WithArgs("1234 The Street", "Small City", "ST", "98765"). 116 | WillReturnRows(rows). 117 | RowsWillBeClosed() 118 | rows = sqlmock.NewRows([]string{"pk", "created_tmz", "modified_tmz"}). 119 | AddRow(ReturnArgs(1, "pk", "created", "modified")...) 120 | prepared.ExpectQuery(). 121 | WithArgs("55 Here We Are", "Big City", "TS", "56789"). 122 | WillReturnRows(rows). 123 | RowsWillBeClosed() 124 | prepared.WillBeClosed() 125 | mock.ExpectCommit() 126 | } 127 | 128 | case ExAddressUpdate: 129 | parts := []string{ 130 | "UPDATE addresses SET", 131 | "\t\tstreet = $1,", 132 | "\t\tcity = $2,", 133 | "\t\tstate = $3,", 134 | "\t\tzip = $4", 135 | "\tWHERE", 136 | "\t\tpk = $5", 137 | "\tRETURNING modified_tmz", 138 | } 139 | allRows := []*sqlmock.Rows{ 140 | sqlmock.NewRows([]string{"modified_tmz"}). 141 | AddRow(ReturnArgs(1, "modified")...), 142 | sqlmock.NewRows([]string{"modified_tmz"}). 143 | AddRow(ReturnArgs(1, "modified")...), 144 | } 145 | mock.ExpectQuery(strings.Join(parts, "\n")). 
146 | WithArgs("1234 The Street", "Small City", "ST", "98765", 42). 147 | WillReturnRows(allRows[0]). 148 | RowsWillBeClosed() 149 | mock.ExpectQuery(strings.Join(parts, "\n")). 150 | WithArgs("4321 The Street", "Big City", "TS", "56789", 42). 151 | WillReturnRows(allRows[1]). 152 | RowsWillBeClosed() 153 | 154 | case ExAddressUpdateSlice: 155 | parts := []string{ 156 | "UPDATE addresses SET", 157 | "\t\tstreet = $1,", 158 | "\t\tcity = $2,", 159 | "\t\tstate = $3,", 160 | "\t\tzip = $4", 161 | "\tWHERE", 162 | "\t\tpk = $5", 163 | "\tRETURNING modified_tmz", 164 | } 165 | for k := 0; k < 2; k++ { 166 | mock.ExpectBegin() 167 | prepared := mock.ExpectPrepare(strings.Join(parts, "\n")) 168 | rows := sqlmock.NewRows([]string{"modified_tmz"}). 169 | AddRow(ReturnArgs(1, "modified")...) 170 | prepared.ExpectQuery(). 171 | WithArgs("1234 The Street", "Small City", "ST", "98765", 42). 172 | WillReturnRows(rows). 173 | RowsWillBeClosed() 174 | rows = sqlmock.NewRows([]string{"modified_tmz"}). 175 | AddRow(ReturnArgs(1, "modified")...) 176 | prepared.ExpectQuery(). 177 | WithArgs("55 Here We Are", "Big City", "TS", "56789", 62). 178 | WillReturnRows(rows). 179 | RowsWillBeClosed() 180 | prepared.WillBeClosed() 181 | mock.ExpectCommit() 182 | } 183 | 184 | case ExAddressSave: 185 | var tg TimeGenerator 186 | var times []time.Time = []time.Time{ 187 | tg.Next(), tg.Next(), 188 | } 189 | 190 | // The INSERT portion 191 | parts := []string{ 192 | "INSERT INTO addresses", 193 | "\t\t( street, city, state, zip )", 194 | "\tVALUES", 195 | "\t\t( $1, $2, $3, $4 )", 196 | "\tRETURNING pk, created_tmz, modified_tmz", 197 | } 198 | allRows := []*sqlmock.Rows{ 199 | sqlmock.NewRows([]string{"pk", "created_tmz", "modified_tmz"}). 200 | AddRow(1, times[0], times[0]), 201 | sqlmock.NewRows([]string{"pk", "created_tmz", "modified_tmz"}). 202 | AddRow(2, times[1], times[1]), 203 | } 204 | mock.ExpectQuery(strings.Join(parts, "\n")). 
205 | WithArgs("1234 The Street", "Small City", "ST", "98765"). 206 | WillReturnRows(allRows[0]). 207 | RowsWillBeClosed() 208 | mock.ExpectQuery(strings.Join(parts, "\n")). 209 | WithArgs("55 Here We Are", "Big City", "TS", "56789"). 210 | WillReturnRows(allRows[1]). 211 | RowsWillBeClosed() 212 | 213 | // The UPDATE portion 214 | parts = []string{ 215 | "UPDATE addresses SET", 216 | "\t\tstreet = $1,", 217 | "\t\tcity = $2,", 218 | "\t\tstate = $3,", 219 | "\t\tzip = $4", 220 | "\tWHERE", 221 | "\t\tpk = $5", 222 | "\tRETURNING modified_tmz", 223 | } 224 | allRows = []*sqlmock.Rows{ 225 | sqlmock.NewRows([]string{"modified_tmz"}). 226 | AddRow(times[0].Add(time.Hour)), 227 | sqlmock.NewRows([]string{"modified_tmz"}). 228 | AddRow(times[1].Add(time.Hour)), 229 | } 230 | mock.ExpectQuery(strings.Join(parts, "\n")). 231 | WithArgs("1 New Street", "Small City", "ST", "99111", 1). 232 | WillReturnRows(allRows[0]). 233 | RowsWillBeClosed() 234 | mock.ExpectQuery(strings.Join(parts, "\n")). 235 | WithArgs("2 New Street", "Big City", "TS", "99222", 2). 236 | WillReturnRows(allRows[1]). 237 | RowsWillBeClosed() 238 | 239 | case ExAddressSaveSlice: 240 | var tg TimeGenerator 241 | var times []time.Time = []time.Time{ 242 | tg.Next(), tg.Next(), 243 | } 244 | 245 | // The INSERT portion 246 | parts := []string{ 247 | "INSERT INTO addresses", 248 | "\t\t( street, city, state, zip )", 249 | "\tVALUES", 250 | "\t\t( $1, $2, $3, $4 )", 251 | "\tRETURNING pk, created_tmz, modified_tmz", 252 | } 253 | allRows := []*sqlmock.Rows{ 254 | sqlmock.NewRows([]string{"pk", "created_tmz", "modified_tmz"}). 255 | AddRow(1, times[0], times[0]), 256 | sqlmock.NewRows([]string{"pk", "created_tmz", "modified_tmz"}). 257 | AddRow(2, times[1], times[1]), 258 | } 259 | mock.ExpectBegin() 260 | prepared := mock.ExpectPrepare(strings.Join(parts, "\n")) 261 | prepared.ExpectQuery(). 262 | WithArgs("1234 The Street", "Small City", "ST", "98765"). 263 | WillReturnRows(allRows[0]). 
264 | RowsWillBeClosed() 265 | prepared.ExpectQuery(). 266 | WithArgs("55 Here We Are", "Big City", "TS", "56789"). 267 | WillReturnRows(allRows[1]). 268 | RowsWillBeClosed() 269 | prepared.WillBeClosed() 270 | mock.ExpectCommit() 271 | 272 | // The UPDATE portion 273 | parts = []string{ 274 | "UPDATE addresses SET", 275 | "\t\tstreet = $1,", 276 | "\t\tcity = $2,", 277 | "\t\tstate = $3,", 278 | "\t\tzip = $4", 279 | "\tWHERE", 280 | "\t\tpk = $5", 281 | "\tRETURNING modified_tmz", 282 | } 283 | allRows = []*sqlmock.Rows{ 284 | sqlmock.NewRows([]string{"modified_tmz"}). 285 | AddRow(times[0].Add(time.Hour)), 286 | sqlmock.NewRows([]string{"modified_tmz"}). 287 | AddRow(times[1].Add(time.Hour)), 288 | } 289 | mock.ExpectBegin() 290 | prepared = mock.ExpectPrepare(strings.Join(parts, "\n")) 291 | prepared.ExpectQuery(). 292 | WithArgs("1 New Street", "Small City", "ST", "99111", 1). 293 | WillReturnRows(allRows[0]). 294 | RowsWillBeClosed() 295 | prepared.ExpectQuery(). 296 | WithArgs("2 New Street", "Big City", "TS", "99222", 2). 297 | WillReturnRows(allRows[1]). 298 | RowsWillBeClosed() 299 | prepared.WillBeClosed() 300 | mock.ExpectCommit() 301 | 302 | case ExLogEntrySave: 303 | parts := []string{ 304 | "INSERT INTO log", 305 | "\t\t( message )", 306 | "\tVALUES", 307 | "\t\t( $1 )", 308 | } 309 | mock.ExpectBegin() 310 | prepared := mock.ExpectPrepare(strings.Join(parts, "\n")) 311 | prepared.ExpectExec(). 312 | WithArgs("Hello, World!"). 313 | WillReturnResult(sqlmock.NewResult(0, 1)) 314 | prepared.ExpectExec(). 315 | WithArgs("Foo, Bar!"). 316 | WillReturnResult(sqlmock.NewResult(0, 1)) 317 | prepared.ExpectExec(). 318 | WithArgs("The llamas are escaping!"). 
319 | WillReturnResult(sqlmock.NewResult(0, 1)) 320 | prepared.WillBeClosed() 321 | mock.ExpectCommit() 322 | 323 | case ExRelationshipInsert: 324 | parts := []string{ 325 | "INSERT INTO relationship", 326 | "\t\t( left_fk, right_fk, toggle )", 327 | "\tVALUES", 328 | "\t\t( $1, $2, $3 )", 329 | } 330 | mock.ExpectExec(strings.Join(parts, "\n")).WithArgs(1, 10, false).WillReturnResult(sqlmock.NewResult(0, 1)) 331 | 332 | case ExRelationshipInsertSlice: 333 | parts := []string{ 334 | "INSERT INTO relationship", 335 | "\t\t( left_fk, right_fk, toggle )", 336 | "\tVALUES", 337 | "\t\t( $1, $2, $3 )", 338 | } 339 | mock.ExpectBegin() 340 | prepared := mock.ExpectPrepare(strings.Join(parts, "\n")) 341 | prepared.ExpectExec().WithArgs(1, 10, false).WillReturnResult(sqlmock.NewResult(0, 1)) 342 | prepared.ExpectExec().WithArgs(2, 20, true).WillReturnResult(sqlmock.NewResult(0, 1)) 343 | prepared.ExpectExec().WithArgs(3, 30, false).WillReturnResult(sqlmock.NewResult(0, 1)) 344 | prepared.WillBeClosed() 345 | mock.ExpectCommit() 346 | 347 | case ExRelationshipUpdate: 348 | parts := []string{ 349 | "UPDATE relationship SET", 350 | "\t\ttoggle = $1", 351 | "\tWHERE", 352 | "\t\tleft_fk = $2 AND right_fk = $3", 353 | } 354 | mock.ExpectExec(strings.Join(parts, "\n")).WithArgs(true, 1, 10).WillReturnResult(sqlmock.NewResult(0, 1)) 355 | 356 | case ExRelationshipUpdateSlice: 357 | parts := []string{ 358 | "UPDATE relationship SET", 359 | "\t\ttoggle = $1", 360 | "\tWHERE", 361 | "\t\tleft_fk = $2 AND right_fk = $3", 362 | } 363 | mock.ExpectBegin() 364 | prepared := mock.ExpectPrepare(strings.Join(parts, "\n")) 365 | prepared.ExpectExec().WithArgs(true, 1, 10).WillReturnResult(sqlmock.NewResult(0, 1)) 366 | prepared.ExpectExec().WithArgs(false, 2, 20).WillReturnResult(sqlmock.NewResult(0, 1)) 367 | prepared.ExpectExec().WithArgs(true, 3, 30).WillReturnResult(sqlmock.NewResult(0, 1)) 368 | prepared.WillBeClosed() 369 | mock.ExpectCommit() 370 | 371 | case ExRelationshipUpsert: 
372 | parts := []string{ 373 | "INSERT INTO relationship AS dest", 374 | "\t\t( left_fk, right_fk, toggle )", 375 | "\tVALUES", 376 | "\t\t( $1, $2, $3 )", 377 | "\tON CONFLICT( left_fk, right_fk ) DO UPDATE SET", 378 | "\t\ttoggle = EXCLUDED.toggle", 379 | "\t\tWHERE (", 380 | "\t\t\tdest.toggle <> EXCLUDED.toggle", 381 | "\t\t)", 382 | } 383 | mock.ExpectExec(strings.Join(parts, "\n")).WithArgs(1, 10, false).WillReturnResult(sqlmock.NewResult(0, 1)) 384 | 385 | case ExRelationshipUpsertSlice: 386 | parts := []string{ 387 | "INSERT INTO relationship AS dest", 388 | "\t\t( left_fk, right_fk, toggle )", 389 | "\tVALUES", 390 | "\t\t( $1, $2, $3 )", 391 | "\tON CONFLICT( left_fk, right_fk ) DO UPDATE SET", 392 | "\t\ttoggle = EXCLUDED.toggle", 393 | "\t\tWHERE (", 394 | "\t\t\tdest.toggle <> EXCLUDED.toggle", 395 | "\t\t)", 396 | } 397 | mock.ExpectBegin() 398 | prepared := mock.ExpectPrepare(strings.Join(parts, "\n")) 399 | prepared.ExpectExec().WithArgs(1, 10, false).WillReturnResult(sqlmock.NewResult(0, 1)) 400 | prepared.ExpectExec().WithArgs(2, 20, true).WillReturnResult(sqlmock.NewResult(0, 1)) 401 | prepared.ExpectExec().WithArgs(3, 30, false).WillReturnResult(sqlmock.NewResult(0, 1)) 402 | prepared.WillBeClosed() 403 | mock.ExpectCommit() 404 | 405 | case ExRelationshipSave: 406 | parts := []string{ 407 | "INSERT INTO relationship AS dest", 408 | "\t\t( left_fk, right_fk, toggle )", 409 | "\tVALUES", 410 | "\t\t( $1, $2, $3 )", 411 | "\tON CONFLICT( left_fk, right_fk ) DO UPDATE SET", 412 | "\t\ttoggle = EXCLUDED.toggle", 413 | "\t\tWHERE (", 414 | "\t\t\tdest.toggle <> EXCLUDED.toggle", 415 | "\t\t)", 416 | } 417 | mock.ExpectBegin() 418 | prepared := mock.ExpectPrepare(strings.Join(parts, "\n")) 419 | prepared.ExpectExec().WithArgs(1, 2, false).WillReturnResult(sqlmock.NewResult(0, 1)) 420 | prepared.ExpectExec().WithArgs(10, 20, false).WillReturnResult(sqlmock.NewResult(0, 1)) 421 | prepared.WillBeClosed() 422 | mock.ExpectCommit() 423 | mock.ExpectBegin() 
424 | prepared = mock.ExpectPrepare(strings.Join(parts, "\n")) 425 | prepared.ExpectExec().WithArgs(1, 2, true).WillReturnResult(sqlmock.NewResult(0, 1)) 426 | prepared.ExpectExec().WithArgs(10, 20, true).WillReturnResult(sqlmock.NewResult(0, 1)) 427 | prepared.WillBeClosed() 428 | mock.ExpectCommit() 429 | 430 | case ExUpsert: 431 | parts := []string{ 432 | "INSERT INTO upsertable AS dest", 433 | "\t\t( pk, string, number )", 434 | "\tVALUES", 435 | "\t\t( $1, $2, $3 )", 436 | "\tON CONFLICT( pk ) DO UPDATE SET", 437 | "\t\tstring = EXCLUDED.string, number = EXCLUDED.number", 438 | "\t\tWHERE (", 439 | "\t\t\tdest.string <> EXCLUDED.string OR dest.number <> EXCLUDED.number", 440 | "\t\t)", 441 | "\tRETURNING created_tmz, modified_tmz", 442 | } 443 | allRows := []*sqlmock.Rows{ 444 | sqlmock.NewRows([]string{"created_tmz", "modified_tmz"}). 445 | AddRow(ReturnArgs(1, "created", "modified")...), 446 | sqlmock.NewRows([]string{"created_tmz", "modified_tmz"}). 447 | AddRow(ReturnArgs(1, "created", "modified")...), 448 | } 449 | qu := mock.ExpectQuery(strings.Join(parts, "\n")) 450 | qu.WithArgs("some-unique-string", "Hello, World!", 42) 451 | qu.WillReturnRows(allRows[0]) 452 | qu.RowsWillBeClosed() 453 | qu = mock.ExpectQuery(strings.Join(parts, "\n")) 454 | qu.WithArgs("other-unique-string", "Foo, Bar!", 100) 455 | qu.WillReturnRows(allRows[1]) 456 | qu.RowsWillBeClosed() 457 | 458 | case ExUpsertSlice: 459 | parts := []string{ 460 | "INSERT INTO upsertable AS dest", 461 | "\t\t( pk, string, number )", 462 | "\tVALUES", 463 | "\t\t( $1, $2, $3 )", 464 | "\tON CONFLICT( pk ) DO UPDATE SET", 465 | "\t\tstring = EXCLUDED.string, number = EXCLUDED.number", 466 | "\t\tWHERE (", 467 | "\t\t\tdest.string <> EXCLUDED.string OR dest.number <> EXCLUDED.number", 468 | "\t\t)", 469 | "\tRETURNING created_tmz, modified_tmz", 470 | } 471 | for k := 0; k < 2; k++ { 472 | mock.ExpectBegin() 473 | prepared := mock.ExpectPrepare(strings.Join(parts, "\n")) 474 | rows := 
var (
	// Models is a sample model.Models instance; the example code in the model package
	// runs its statements through it.
	//
	// Most of the magic occurs in the Mapper (set.Mapper); if your models follow a consistent logic
	// for struct-tags to column-names then you may only need a single model.Models for your application.
	//
	// However creating a model.Models is relatively easy and there's nothing to stop you from creating
	// multiple of them with different Mapper (set.Mapper) to handle inconsistency in your models.
	Models = &model.Models{
		// Mapper defines how struct fields map to friendly column names (i.e. database names).
		// The db tag is listed before json, so db supplies the column name where both are present.
		Mapper: &set.Mapper{
			Join: "_",
			Tags: []string{"db", "json"},
		},
		// This instance uses a Postgres grammar.
		Grammar: grammar.Postgres,
	}
)
31 | func NewModels() *model.Models { 32 | rv := &model.Models{ 33 | Mapper: &set.Mapper{ 34 | Join: "_", 35 | Tags: []string{"db", "json"}, 36 | }, 37 | Grammar: grammar.Postgres, 38 | } 39 | rv.Register(Address{}) 40 | rv.Register(LogEntry{}) 41 | rv.Register(Person{}) 42 | rv.Register(PersonAddress{}) 43 | rv.Register(Relationship{}) 44 | rv.Register(Upsertable{}) 45 | return rv 46 | } 47 | 48 | func init() { 49 | // Somewhere in your application you need to register all types to be used as models. 50 | Models.Register(Address{}) 51 | Models.Register(LogEntry{}) 52 | Models.Register(Person{}) 53 | Models.Register(PersonAddress{}) 54 | Models.Register(Relationship{}) 55 | Models.Register(Upsertable{}) 56 | } 57 | 58 | // Address is a simple model representing an address. 59 | type Address struct { 60 | // This member tells the model.Database the name of the table for this model; think of it 61 | // like xml.Name when using encoding/xml. 62 | model.TableName `json:"-" model:"addresses"` 63 | // 64 | // The struct fields and column names. The example Mapper allows the db and json 65 | // to share names where they are the same; however db is higher in precedence than 66 | // json so where present that is the database column name. 67 | Id int `json:"id" db:"pk" model:"key,auto"` 68 | CreatedTime time.Time `json:"created_time" db:"created_tmz" model:"inserted"` 69 | ModifiedTime time.Time `json:"modified_time" db:"modified_tmz" model:"inserted,updated"` 70 | Street string `json:"street"` 71 | City string `json:"city"` 72 | State string `json:"state"` 73 | Zip string `json:"zip"` 74 | } 75 | 76 | // Person is a simple model representing a person. 
// Person is a simple model representing a person.
//
// Column names follow the example Mapper: the db tag supplies the column name where
// present, otherwise the json tag does.
type Person struct {
	// TableName tells the models instance which table this model lives in: "people".
	model.TableName `json:"-" model:"people"`
	//
	// Id maps to column pk; "key,auto" marks it as an auto incrementing key.
	Id int `json:"id" db:"pk" model:"key,auto"`
	// SpouseId maps to column spouse_fk.
	// NOTE(review): model:"foreign" presumably flags a foreign key; how it is handled is not visible here — confirm.
	SpouseId int `json:"spouse_id" db:"spouse_fk" model:"foreign"`
	// CreatedTime maps to column created_tmz; "inserted" means the column updates on INSERT.
	CreatedTime time.Time `json:"created_time" db:"created_tmz" model:"inserted"`
	// ModifiedTime maps to column modified_tmz; "inserted,updated" means the column updates on INSERT and UPDATE.
	ModifiedTime time.Time `json:"modified_time" db:"modified_tmz" model:"inserted,updated"`
	// Plain columns; with no db tag the json name is the column name.
	First string `json:"first"`
	Last  string `json:"last"`
	Age   int    `json:"age"`
	// SSN maps to column ssn.
	// NOTE(review): model:"unique" presumably marks a unique column; its effect on generated SQL is not visible here — confirm.
	SSN string `json:"ssn" model:"unique"`
}
// Package examples provides types and functions to facilitate the examples and test code in the model package.
//
// Important takeaways from this example package are the instantiation of the global Models variable and registering
// type(s) with it.
//
// Model registration is performed via an init() function:
//	func init() {
//		// Somewhere in your application you need to register all types to be used as models.
//		Models.Register(Address{})
//		Models.Register(LogEntry{})
//		Models.Register(Person{})
//		Models.Register(PersonAddress{})
//		Models.Register(Relationship{})
//		Models.Register(Upsertable{})
//	}
//
// Also important is the struct definition for the type Address; the struct definition combined with the Models global
// defines how SQL statements are generated and the expected column names in the generated SQL.
package examples
16 | 17 | var Models *model.Models = &model.Models{ 18 | // Mapper and its fields control how Go structs are traversed and mapped to 19 | // database column names. 20 | Mapper: &set.Mapper{ 21 | Join: "_", 22 | Tags: []string{"db", "json"}, 23 | }, 24 | Grammar: grammar.Postgres, 25 | 26 | // StructTag defines the tag name to use when inspecting models. 27 | // StructTag: "", // Blank defaults to "model" 28 | } 29 | 30 | // 31 | // Simple model examples 32 | // 33 | 34 | // tablename=strings 35 | // strings.pk ↣ auto incrementing key 36 | // strings.value 37 | type StringModel struct { 38 | // This field specifies the table name in the database. 39 | // json:"-" tells encoding/json to ignore this field when marshalling 40 | // model:"strings" means the table name is "strings" in the database. 41 | model.TableName `json:"-" model:"strings"` 42 | 43 | // An auto incrementing primary key field. 44 | // 45 | // The mapper is configured to use `db` tag before `json` tag; 46 | // therefore this maps to strings.pk in the database but json 47 | // marshals as id 48 | // 49 | // `json:"id" db:"pk" model:"key,auto"` 50 | // ^-- auto incrementing 51 | // ^-- field is the key or part of composite key 52 | // ^-- maps to strings.pk column 53 | // ^-- json marshals to id 54 | Id int `json:"id" db:"pk" model:"key,auto"` 55 | 56 | // json marshals as value 57 | // maps to database column strings.value 58 | Value string `json:"value"` 59 | } 60 | 61 | // tablename=numbers 62 | // numbers.pk ↣ auto incrementing key 63 | // numbers.value 64 | type NumberModel struct { 65 | // This model does not include the model.TableName embed; the table name 66 | // must be specified during registration (see below). 
67 | // model.TableName `json:"-" model:"numbers"` 68 | 69 | Id int `json:"id" db:"pk" model:"key,auto"` 70 | Value int `json:"value"` 71 | } 72 | 73 | // tablename=companies 74 | // companies.pk ↣ auto incrementing key 75 | // companies.created ↣ updates on INSERT 76 | // companies.modified ↣ updates on INSERT and UPDATE 77 | // companies.name 78 | type CompanyModel struct { 79 | Id int `json:"id" db:"pk" model:"key,auto"` 80 | 81 | // Models can have fields that update during INSERT or UPDATE statements. 82 | // `json:"created" model:"inserted"` 83 | // ^-- this column updates on insert 84 | // `json:"modified" model:"inserted,updated"` 85 | // ^-- this column updates on insert and updates 86 | CreatedTime time.Time `json:"created" model:"inserted"` 87 | ModifiedTime time.Time `json:"modified" model:"inserted,updated"` 88 | 89 | Name int `json:"name"` 90 | } 91 | 92 | // 93 | // Model registration 94 | // + Models that embed model.TableName do not need to specify the tablename during registration. 95 | Models.Register(StringModel{}) 96 | Models.Register(NumberModel{}, model.TableName("numbers")) 97 | Models.Register(CompanyModel{}, model.TableName("companies")) 98 | 99 | fmt.Println("all done") 100 | 101 | // Output: all done 102 | } 103 | 104 | func ExampleModels_Insert() { 105 | var zero time.Time 106 | // 107 | // Create a mock database. 108 | db, err := examples.Connect(examples.ExAddressInsert) 109 | if err != nil { 110 | fmt.Println("err", err.Error()) 111 | return 112 | } 113 | WasInserted := func(id int, created time.Time, modified time.Time) error { 114 | if id == 0 || zero.Equal(created) || zero.Equal(modified) { 115 | return fmt.Errorf("Record not inserted.") 116 | } 117 | return nil 118 | } 119 | // A "value" record. 120 | byVal := examples.Address{ 121 | // Id, CreatedTime, ModifiedTime are updated by the database. 122 | Street: "1234 The Street", 123 | City: "Small City", 124 | State: "ST", 125 | Zip: "98765", 126 | } 127 | // A pointer record. 
128 | byPtr := &examples.Address{ 129 | // Id, CreatedTime, ModifiedTime are updated by the database. 130 | Street: "4321 The Street", 131 | City: "Big City", 132 | State: "TS", 133 | Zip: "56789", 134 | } 135 | 136 | // Pass the address of the "value" record. 137 | if err := examples.Models.Insert(db, &byVal); err != nil { 138 | fmt.Println("err", err.Error()) 139 | return 140 | } 141 | if err := WasInserted(byVal.Id, byVal.CreatedTime, byVal.ModifiedTime); err != nil { 142 | fmt.Println("err", err.Error()) 143 | return 144 | } 145 | // The pointer record can be passed directly. 146 | if err := examples.Models.Insert(db, byPtr); err != nil { 147 | fmt.Println("err", err.Error()) 148 | return 149 | } 150 | if err := WasInserted(byPtr.Id, byPtr.CreatedTime, byPtr.ModifiedTime); err != nil { 151 | fmt.Println("err", err.Error()) 152 | return 153 | } 154 | fmt.Println("Models inserted.") 155 | 156 | // Output: Models inserted. 157 | } 158 | 159 | func ExampleModels_Insert_slice() { 160 | var zero time.Time 161 | // 162 | // Create a mock database. 163 | db, err := examples.Connect(examples.ExAddressInsertSlice) 164 | if err != nil { 165 | fmt.Println("err", err.Error()) 166 | return 167 | } 168 | WasInserted := func(id int, created time.Time, modified time.Time) error { 169 | if id == 0 || zero.Equal(created) || zero.Equal(modified) { 170 | return fmt.Errorf("Record not inserted.") 171 | } 172 | return nil 173 | } 174 | // A slice of values. 175 | values := []examples.Address{ 176 | // Id, CreatedTime, ModifiedTime are updated by the database. 177 | { 178 | Street: "1234 The Street", 179 | City: "Small City", 180 | State: "ST", 181 | Zip: "98765", 182 | }, 183 | { 184 | Street: "55 Here We Are", 185 | City: "Big City", 186 | State: "TS", 187 | Zip: "56789", 188 | }, 189 | } 190 | // A slice of pointers. 191 | pointers := []*examples.Address{ 192 | // Id, CreatedTime, ModifiedTime are updated by the database. 
193 | { 194 | Street: "1234 The Street", 195 | City: "Small City", 196 | State: "ST", 197 | Zip: "98765", 198 | }, 199 | { 200 | Street: "55 Here We Are", 201 | City: "Big City", 202 | State: "TS", 203 | Zip: "56789", 204 | }, 205 | } 206 | 207 | // Slices of values can be passed directly. 208 | if err := examples.Models.Insert(db, values); err != nil { 209 | fmt.Println("err", err.Error()) 210 | return 211 | } 212 | for _, model := range values { 213 | if err := WasInserted(model.Id, model.CreatedTime, model.ModifiedTime); err != nil { 214 | fmt.Println("err", err.Error()) 215 | return 216 | } 217 | } 218 | // Slices of pointers can be passed directly. 219 | if err := examples.Models.Insert(db, pointers); err != nil { 220 | fmt.Println("err", err.Error()) 221 | return 222 | } 223 | for _, model := range pointers { 224 | if err := WasInserted(model.Id, model.CreatedTime, model.ModifiedTime); err != nil { 225 | fmt.Println("err", err.Error()) 226 | return 227 | } 228 | } 229 | 230 | fmt.Println("Models inserted.") 231 | 232 | // Output: Models inserted. 233 | } 234 | 235 | func ExampleModels_Save() { 236 | // This example demonstrates using Models.Save when the models have only "key,auto" fields. 237 | // This means when models are first created and passed to Save an INSERT is performed. 238 | // Subsequent calls to Save with the model instances results in UPDATE queries. 239 | 240 | // Similar to other examples this example uses a "value" model and a pointer model. Note 241 | // the "value" model needs to be passed by address. 242 | 243 | var zero time.Time 244 | // 245 | // Create a mock database. 
246 | db, err := examples.Connect(examples.ExAddressSave) 247 | if err != nil { 248 | fmt.Println("err", err.Error()) 249 | return 250 | } 251 | WasInserted := func(id int, created time.Time, modified time.Time) error { 252 | if id == 0 || zero.Equal(created) || !created.Equal(modified) { 253 | return fmt.Errorf("Record not inserted.") 254 | } 255 | return nil 256 | } 257 | WasUpdated := func(created time.Time, modified time.Time) error { 258 | if created.Equal(modified) { 259 | return fmt.Errorf("Record not updated.") 260 | } 261 | return nil 262 | } 263 | // A "value" instance. 264 | byVal := examples.Address{ 265 | Street: "1234 The Street", 266 | City: "Small City", 267 | State: "ST", 268 | Zip: "98765", 269 | } 270 | // A pointer instance. 271 | byPtr := &examples.Address{ 272 | Street: "55 Here We Are", 273 | City: "Big City", 274 | State: "TS", 275 | Zip: "56789", 276 | } 277 | // 278 | 279 | // Save the models; since these models only have "key,auto" fields they will first INSERT. 280 | if err := examples.Models.Save(db, &byVal); err != nil { 281 | fmt.Println("err", err.Error()) 282 | return 283 | } 284 | if err := WasInserted(byVal.Id, byVal.CreatedTime, byVal.ModifiedTime); err != nil { 285 | fmt.Println("err", err.Error()) 286 | return 287 | } 288 | if err := examples.Models.Save(db, byPtr); err != nil { 289 | fmt.Println("err", err.Error()) 290 | return 291 | } 292 | if err := WasInserted(byPtr.Id, byPtr.CreatedTime, byPtr.ModifiedTime); err != nil { 293 | fmt.Println("err", err.Error()) 294 | return 295 | } 296 | 297 | // Edit the model fields. 298 | byVal.Street = "1 New Street" 299 | byVal.Zip = "99111" 300 | 301 | byPtr.Street = "2 New Street" 302 | byPtr.Zip = "99222" 303 | 304 | // Save the models; since the key fields are no longer zero values they will UPDATE. 
305 | if err := examples.Models.Save(db, &byVal); err != nil { 306 | fmt.Println("err", err.Error()) 307 | return 308 | } 309 | if err := WasUpdated(byVal.CreatedTime, byVal.ModifiedTime); err != nil { 310 | fmt.Println("err", err.Error()) 311 | return 312 | } 313 | if err := examples.Models.Save(db, byPtr); err != nil { 314 | fmt.Println("err", err.Error()) 315 | return 316 | } 317 | if err := WasUpdated(byPtr.CreatedTime, byPtr.ModifiedTime); err != nil { 318 | fmt.Println("err", err.Error()) 319 | return 320 | } 321 | 322 | fmt.Println("Models saved.") 323 | 324 | // Output: Models saved. 325 | } 326 | 327 | func ExampleModels_Save_slice() { 328 | // This example demonstrates using Models.Save when the models have only "key,auto" fields. 329 | // This means when models are first created and passed to Save an INSERT is performed. 330 | // Subsequent calls to Save with the model instances results in UPDATE queries. 331 | 332 | var zero time.Time 333 | // 334 | // Create a mock database. 335 | db, err := examples.Connect(examples.ExAddressSaveSlice) 336 | if err != nil { 337 | fmt.Println("err", err.Error()) 338 | return 339 | } 340 | WasInserted := func(id int, created time.Time, modified time.Time) error { 341 | if id == 0 || zero.Equal(created) || !created.Equal(modified) { 342 | return fmt.Errorf("Record not inserted.") 343 | } 344 | return nil 345 | } 346 | WasUpdated := func(created time.Time, modified time.Time) error { 347 | if created.Equal(modified) { 348 | return fmt.Errorf("Record not updated.") 349 | } 350 | return nil 351 | } 352 | values := []examples.Address{ 353 | { 354 | Street: "1234 The Street", 355 | City: "Small City", 356 | State: "ST", 357 | Zip: "98765", 358 | }, 359 | { 360 | Street: "55 Here We Are", 361 | City: "Big City", 362 | State: "TS", 363 | Zip: "56789", 364 | }, 365 | } 366 | 367 | // Save the models; since these models only have "key,auto" fields they will first INSERT. 
368 | if err := examples.Models.Save(db, values); err != nil { 369 | fmt.Println("err", err.Error()) 370 | return 371 | } 372 | for _, value := range values { 373 | if err := WasInserted(value.Id, value.CreatedTime, value.ModifiedTime); err != nil { 374 | fmt.Println("err", err.Error()) 375 | return 376 | } 377 | } 378 | 379 | // Edit the model fields. 380 | values[0].Street = "1 New Street" 381 | values[0].Zip = "99111" 382 | 383 | values[1].Street = "2 New Street" 384 | values[1].Zip = "99222" 385 | 386 | // Save the models; since the key fields are no longer zero values they will UPDATE. 387 | if err := examples.Models.Save(db, values); err != nil { 388 | fmt.Println("err", err.Error()) 389 | return 390 | } 391 | for _, value := range values { 392 | if err := WasUpdated(value.CreatedTime, value.ModifiedTime); err != nil { 393 | fmt.Println("err", err.Error()) 394 | return 395 | } 396 | } 397 | 398 | fmt.Println("Models saved.") 399 | 400 | // Output: Models saved. 401 | } 402 | 403 | func ExampleModels_Save_compositeKeyUpserts() { 404 | // This example demonstrates using Models.Save when the models have only "key" key fields 405 | // and zero "key,auto" fields. Such models are saved with UPSERT. 406 | 407 | // 408 | // Create a mock database. 409 | db, err := examples.Connect(examples.ExRelationshipSave) 410 | if err != nil { 411 | fmt.Println("err", err.Error()) 412 | return 413 | } 414 | values := []examples.Relationship{ 415 | { 416 | LeftId: 1, // LeftId and RightId are the composite key 417 | RightId: 2, 418 | Toggle: false, 419 | }, 420 | { 421 | LeftId: 10, 422 | RightId: 20, 423 | Toggle: false, 424 | }, 425 | } 426 | 427 | // Save the models; since these models only have "key,auto" fields they will first INSERT. 428 | if err := examples.Models.Save(db, values); err != nil { 429 | fmt.Println("err", err.Error()) 430 | return 431 | } 432 | 433 | // Edit the model fields. 
434 | values[0].Toggle = true 435 | values[1].Toggle = true 436 | 437 | // Save the models again; models with only "key" fields always use UPSERT, so this Save is also an UPSERT. 438 | if err := examples.Models.Save(db, values); err != nil { 439 | fmt.Println("err", err.Error()) 440 | return 441 | } 442 | 443 | fmt.Println("Models saved.") 444 | 445 | // Output: Models saved. 446 | } 447 | 448 | func ExampleModels_Save_noKeyFieldsInserts() { 449 | // This example demonstrates using Models.Save when the models do not have any 450 | // "key" or "key,auto" fields. Such models must INSERT. 451 | // 452 | // Note also that such a model could be a partial model of the actual database table 453 | // that does have key fields defined in the table schema. 454 | 455 | // 456 | // Create a mock database. 457 | db, err := examples.Connect(examples.ExLogEntrySave) 458 | if err != nil { 459 | fmt.Println("err", err.Error()) 460 | return 461 | } 462 | values := []examples.LogEntry{ 463 | {Message: "Hello, World!"}, 464 | {Message: "Foo, Bar!"}, 465 | {Message: "The llamas are escaping!"}, 466 | } 467 | 468 | // Save the models; since these models have no "key" or "key,auto" fields use INSERT. 469 | if err := examples.Models.Save(db, values); err != nil { 470 | fmt.Println("err", err.Error()) 471 | return 472 | } 473 | 474 | fmt.Println("Models saved.") 475 | 476 | // Output: Models saved. 477 | } 478 | 479 | func ExampleModels_Update() { 480 | var zero time.Time 481 | // 482 | // Create a mock database. 483 | db, err := examples.Connect(examples.ExAddressUpdate) 484 | if err != nil { 485 | fmt.Println("err", err.Error()) 486 | return 487 | } 488 | WasUpdated := func(modified time.Time) error { 489 | if zero.Equal(modified) { 490 | return fmt.Errorf("Record not updated.") 491 | } 492 | return nil 493 | } 494 | // A "value" record. 495 | byVal := examples.Address{ 496 | Id: 42, 497 | CreatedTime: time.Now().Add(-1 * time.Hour), 498 | // ModifiedTime is zero value; will be updated by database. 
499 | Street: "1234 The Street", 500 | City: "Small City", 501 | State: "ST", 502 | Zip: "98765", 503 | } 504 | // A pointer record. 505 | byPtr := &examples.Address{ 506 | Id: 42, 507 | CreatedTime: time.Now().Add(-1 * time.Hour), 508 | // ModifiedTime is zero value; will be updated by database. 509 | Street: "4321 The Street", 510 | City: "Big City", 511 | State: "TS", 512 | Zip: "56789", 513 | } 514 | 515 | // Pass "value" record by address. 516 | if err := examples.Models.Update(db, &byVal); err != nil { 517 | fmt.Println("err", err.Error()) 518 | return 519 | } 520 | if err := WasUpdated(byVal.ModifiedTime); err != nil { 521 | fmt.Println("err", err.Error()) 522 | return 523 | } 524 | // Pass pointer record directly. 525 | if err := examples.Models.Update(db, byPtr); err != nil { 526 | fmt.Println("err", err.Error()) 527 | return 528 | } 529 | if err := WasUpdated(byPtr.ModifiedTime); err != nil { 530 | fmt.Println("err", err.Error()) 531 | return 532 | } 533 | 534 | fmt.Printf("Models updated.") 535 | 536 | // Output: Models updated. 537 | } 538 | 539 | func ExampleModels_Update_slice() { 540 | var zero time.Time 541 | // 542 | // Create a mock database. 543 | db, err := examples.Connect(examples.ExAddressUpdateSlice) 544 | if err != nil { 545 | fmt.Println("err", err.Error()) 546 | return 547 | } 548 | WasUpdated := func(modified time.Time) error { 549 | if zero.Equal(modified) { 550 | return fmt.Errorf("Record not updated.") 551 | } 552 | return nil 553 | } 554 | // Slice of values. 555 | values := []examples.Address{ 556 | // ModifiedTime is not set and updated by the database. 
557 | { 558 | Id: 42, 559 | CreatedTime: time.Now().Add(-2 * time.Hour), 560 | Street: "1234 The Street", 561 | City: "Small City", 562 | State: "ST", 563 | Zip: "98765", 564 | }, 565 | { 566 | Id: 62, 567 | CreatedTime: time.Now().Add(-1 * time.Hour), 568 | Street: "55 Here We Are", 569 | City: "Big City", 570 | State: "TS", 571 | Zip: "56789", 572 | }, 573 | } 574 | // Slice of pointers. 575 | pointers := []*examples.Address{ 576 | // ModifiedTime is not set and updated by the database. 577 | { 578 | Id: 42, 579 | CreatedTime: time.Now().Add(-2 * time.Hour), 580 | Street: "1234 The Street", 581 | City: "Small City", 582 | State: "ST", 583 | Zip: "98765", 584 | }, 585 | { 586 | Id: 62, 587 | CreatedTime: time.Now().Add(-1 * time.Hour), 588 | Street: "55 Here We Are", 589 | City: "Big City", 590 | State: "TS", 591 | Zip: "56789", 592 | }, 593 | } 594 | 595 | // Slice of values can be passed directly. 596 | if err := examples.Models.Update(db, values); err != nil { 597 | fmt.Println("err", err.Error()) 598 | return 599 | } 600 | for _, model := range values { 601 | if err := WasUpdated(model.ModifiedTime); err != nil { 602 | fmt.Println("err", err.Error()) 603 | return 604 | } 605 | } 606 | // Slice of pointers can be passed directly. 607 | if err := examples.Models.Update(db, pointers); err != nil { 608 | fmt.Println("err", err.Error()) 609 | return 610 | } 611 | for _, model := range pointers { 612 | if err := WasUpdated(model.ModifiedTime); err != nil { 613 | fmt.Println("err", err.Error()) 614 | return 615 | } 616 | } 617 | 618 | fmt.Println("Models updated.") 619 | 620 | // Output: Models updated. 621 | } 622 | 623 | func ExampleModels_Upsert() { 624 | var zero time.Time 625 | // 626 | // Create a mock database. 
627 | db, err := examples.Connect(examples.ExUpsert) 628 | if err != nil { 629 | fmt.Println("err", err.Error()) 630 | return 631 | } 632 | WasUpserted := func(created time.Time, modified time.Time) error { 633 | if zero.Equal(created) || zero.Equal(modified) { 634 | return fmt.Errorf("Record not upserted.") 635 | } 636 | return nil 637 | } 638 | // A "value" record. 639 | byVal := examples.Upsertable{ 640 | Id: "some-unique-string", 641 | String: "Hello, World!", 642 | Number: 42, 643 | } 644 | // A pointer record. 645 | byPtr := &examples.Upsertable{ 646 | Id: "other-unique-string", 647 | String: "Foo, Bar!", 648 | Number: 100, 649 | } 650 | 651 | // Pass "value" record by address. 652 | if err := examples.Models.Upsert(db, &byVal); err != nil { 653 | fmt.Println("err", err.Error()) 654 | return 655 | } 656 | if err := WasUpserted(byVal.CreatedTime, byVal.ModifiedTime); err != nil { 657 | fmt.Println("err", err.Error()) 658 | return 659 | } 660 | // Pass pointer record directly. 661 | if err := examples.Models.Upsert(db, byPtr); err != nil { 662 | fmt.Println("err", err.Error()) 663 | return 664 | } 665 | if err := WasUpserted(byPtr.CreatedTime, byPtr.ModifiedTime); err != nil { 666 | fmt.Println("err", err.Error()) 667 | return 668 | } 669 | 670 | fmt.Printf("Models upserted.") 671 | 672 | // Output: Models upserted. 673 | } 674 | 675 | func ExampleModels_Upsert_slice() { 676 | var zero time.Time 677 | // 678 | // Create a mock database. 679 | db, err := examples.Connect(examples.ExUpsertSlice) 680 | if err != nil { 681 | fmt.Println("err", err.Error()) 682 | return 683 | } 684 | WasUpserted := func(created time.Time, modified time.Time) error { 685 | if zero.Equal(created) || zero.Equal(modified) { 686 | return fmt.Errorf("Record not upserted.") 687 | } 688 | return nil 689 | } 690 | // Slice of values. 
691 | values := []examples.Upsertable{ 692 | { 693 | Id: "some-unique-string", 694 | String: "Hello, World!", 695 | Number: 42, 696 | }, 697 | { 698 | Id: "other-unique-string", 699 | String: "Goodbye, World!", 700 | Number: 10, 701 | }, 702 | } 703 | // Slice of pointers. 704 | pointers := []*examples.Upsertable{ 705 | { 706 | Id: "some-unique-string", 707 | String: "Hello, World!", 708 | Number: 42, 709 | }, 710 | { 711 | Id: "other-unique-string", 712 | String: "Goodbye, World!", 713 | Number: 10, 714 | }, 715 | } 716 | 717 | // Pass "values" directly. 718 | if err := examples.Models.Upsert(db, values); err != nil { 719 | fmt.Println("err", err.Error()) 720 | return 721 | } 722 | for _, model := range values { 723 | if err := WasUpserted(model.CreatedTime, model.ModifiedTime); err != nil { 724 | fmt.Println("err", err.Error()) 725 | return 726 | } 727 | } 728 | // Pass pointers directly. 729 | if err := examples.Models.Upsert(db, pointers); err != nil { 730 | fmt.Println("err", err.Error()) 731 | return 732 | } 733 | for _, model := range pointers { 734 | if err := WasUpserted(model.CreatedTime, model.ModifiedTime); err != nil { 735 | fmt.Println("err", err.Error()) 736 | return 737 | } 738 | } 739 | 740 | fmt.Println("Models upserted.") 741 | 742 | // Output: Models upserted. 743 | } 744 | 745 | func ExampleModels_relationship() { 746 | // This single example shows the common cases of INSERT|UPDATE|UPSERT as distinct code blocks. 747 | // examples.Relationship has a composite key and no auto-updating fields. 748 | { 749 | // Demonstrates INSERT of a single model. 750 | db, err := examples.Connect(examples.ExRelationshipInsert) 751 | if err != nil { 752 | fmt.Println(err.Error()) 753 | return 754 | } 755 | // A "value" model. 756 | value := examples.Relationship{ 757 | LeftId: 1, 758 | RightId: 10, 759 | Toggle: false, 760 | } 761 | // Pass "value" model by address. 
762 | if err = examples.Models.Insert(db, &value); err != nil { 763 | fmt.Println(err) 764 | return 765 | } 766 | fmt.Println("Insert success.") 767 | } 768 | { 769 | // Demonstrates UPDATE of a single model. 770 | db, err := examples.Connect(examples.ExRelationshipUpdate) 771 | if err != nil { 772 | fmt.Println(err.Error()) 773 | return 774 | } 775 | // A pointer model. 776 | relate := &examples.Relationship{ 777 | LeftId: 1, 778 | RightId: 10, 779 | Toggle: true, 780 | } 781 | // Pass pointer model directly. 782 | if err = examples.Models.Update(db, relate); err != nil { 783 | fmt.Println(err) 784 | return 785 | } 786 | fmt.Println("Update success.") 787 | } 788 | { 789 | // Demonstrates UPSERT of a single model. 790 | db, err := examples.Connect(examples.ExRelationshipUpsert) 791 | if err != nil { 792 | fmt.Println(err.Error()) 793 | return 794 | } 795 | // 796 | relate := &examples.Relationship{ 797 | LeftId: 1, 798 | RightId: 10, 799 | Toggle: false, 800 | } 801 | if err = examples.Models.Upsert(db, relate); err != nil { 802 | fmt.Println(err) 803 | return 804 | } 805 | fmt.Println("Upsert success.") 806 | } 807 | 808 | // Output: Insert success. 809 | // Update success. 810 | // Upsert success. 811 | } 812 | 813 | func ExampleModels_relationshipSlice() { 814 | // This single example shows the common cases of INSERT|UPDATE|UPSERT as distinct code blocks. 815 | // examples.Relationship has a composite key and no auto-updating fields. 816 | { 817 | // Demonstrates INSERT of a slice of models. 818 | db, err := examples.Connect(examples.ExRelationshipInsertSlice) 819 | if err != nil { 820 | fmt.Println(err.Error()) 821 | return 822 | } 823 | // Slice of "values". 
824 | relate := []examples.Relationship{ 825 | { 826 | LeftId: 1, 827 | RightId: 10, 828 | Toggle: false, 829 | }, 830 | { 831 | LeftId: 2, 832 | RightId: 20, 833 | Toggle: true, 834 | }, 835 | { 836 | LeftId: 3, 837 | RightId: 30, 838 | Toggle: false, 839 | }, 840 | } 841 | // Pass slice of "values" directly. 842 | if err = examples.Models.Insert(db, relate); err != nil { 843 | fmt.Println(err) 844 | return 845 | } 846 | fmt.Println("Insert success.") 847 | } 848 | { 849 | // Demonstrates UPDATE of a slice of models. 850 | db, err := examples.Connect(examples.ExRelationshipUpdateSlice) 851 | if err != nil { 852 | fmt.Println(err.Error()) 853 | return 854 | } 855 | // Slice of pointers. 856 | relate := []*examples.Relationship{ 857 | { 858 | LeftId: 1, 859 | RightId: 10, 860 | Toggle: true, 861 | }, 862 | { 863 | LeftId: 2, 864 | RightId: 20, 865 | Toggle: false, 866 | }, 867 | { 868 | LeftId: 3, 869 | RightId: 30, 870 | Toggle: true, 871 | }, 872 | } 873 | // Pass slice of pointers directly. 874 | if err = examples.Models.Update(db, relate); err != nil { 875 | fmt.Println(err) 876 | return 877 | } 878 | fmt.Println("Update success.") 879 | } 880 | { 881 | // Demonstrates UPSERT of a slice of models. 882 | db, err := examples.Connect(examples.ExRelationshipUpsertSlice) 883 | if err != nil { 884 | fmt.Println(err.Error()) 885 | return 886 | } 887 | // 888 | relate := []*examples.Relationship{ 889 | { 890 | LeftId: 1, 891 | RightId: 10, 892 | Toggle: false, 893 | }, 894 | { 895 | LeftId: 2, 896 | RightId: 20, 897 | Toggle: true, 898 | }, 899 | { 900 | LeftId: 3, 901 | RightId: 30, 902 | Toggle: false, 903 | }, 904 | } 905 | if err = examples.Models.Upsert(db, relate); err != nil { 906 | fmt.Println(err) 907 | return 908 | } 909 | fmt.Println("Upsert success.") 910 | } 911 | 912 | // Output: Insert success. 913 | // Update success. 914 | // Upsert success. 
915 | } 916 | -------------------------------------------------------------------------------- /model/model.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "github.com/nofeaturesonlybugs/set" 5 | "github.com/nofeaturesonlybugs/set/path" 6 | 7 | "github.com/nofeaturesonlybugs/sqlh/model/statements" 8 | "github.com/nofeaturesonlybugs/sqlh/schema" 9 | ) 10 | 11 | // Model relates a Go type to its Table. 12 | type Model struct { 13 | // Table is the related database table. 14 | Table schema.Table 15 | 16 | // Statements are the SQL database statements. 17 | Statements statements.Table 18 | 19 | // SaveMode is set during model registration and inspected during Models.Save 20 | // to determine which of Insert, Update, or Upsert operations to use. 21 | // 22 | // SaveMode=InsertOrUpdate means InsertUpdatePaths is a non-empty slice of 23 | // key field traversal information. The key fields are examined and if 24 | // any are non-zero values then Update will be used otherwise Insert. 25 | SaveMode SaveMode 26 | InsertUpdatePaths []path.ReflectPath 27 | 28 | // Mapping is the column to struct field mapping. 29 | Mapping set.Mapping 30 | } 31 | 32 | // BindQuery returns a QueryBinding that facilitates running queries against 33 | // instances of the model. 
34 | func (me *Model) BindQuery(mapper *set.Mapper, query *statements.Query) QueryBinding { 35 | return QueryBinding{ 36 | mapper: mapper, 37 | model: me, 38 | query: query, 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /model/models.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "strings" 7 | 8 | "github.com/nofeaturesonlybugs/errors" 9 | "github.com/nofeaturesonlybugs/set" 10 | "github.com/nofeaturesonlybugs/set/path" 11 | "github.com/nofeaturesonlybugs/sqlh" 12 | "github.com/nofeaturesonlybugs/sqlh/grammar" 13 | "github.com/nofeaturesonlybugs/sqlh/model/statements" 14 | "github.com/nofeaturesonlybugs/sqlh/schema" 15 | ) 16 | 17 | // Models is a registry of Models and methods to manipulate them. 18 | type Models struct { 19 | // 20 | // Mapper defines how SQL column names map to fields in Go structs. 21 | Mapper *set.Mapper 22 | // 23 | // Grammar defines the SQL grammar to use for SQL generation. 24 | Grammar grammar.Grammar 25 | // 26 | // Models is a map of Go types to Model instances. This member is automatically 27 | // instantiated during calls to Register(). 28 | Models map[reflect.Type]*Model 29 | // 30 | // StructTag specifies the struct tag name to use when inspecting types 31 | // during register. If not set will default to "model". 32 | StructTag string 33 | } 34 | 35 | // Register adds a Go type to the Models instance. 36 | // 37 | // Register is not goroutine safe; implement locking in the store or application level if required. 38 | // 39 | // When Register is called with a type T the following registrations are made: 40 | // T, *T, []T, & []*T 41 | // 42 | // As a convenience register can be called with a reflect.Type as the value. 
43 | func (me *Models) Register(value interface{}, opts ...interface{}) { 44 | tagName := me.StructTag 45 | if tagName == "" { 46 | tagName = "model" 47 | } 48 | // 49 | if me.Models == nil { 50 | me.Models = make(map[reflect.Type]*Model) 51 | } 52 | // 53 | var typ reflect.Type 54 | var typInfo set.TypeInfo 55 | switch sw := value.(type) { 56 | case reflect.Type: 57 | typ = sw 58 | typInfo = set.TypeCache.StatType(typ) 59 | default: 60 | typ = reflect.TypeOf(value) 61 | typInfo = set.TypeCache.Stat(value) // Consider creating a local type cache to the Models type. 62 | } 63 | if _, ok := me.Models[typ]; ok { 64 | return // Already registered. 65 | } 66 | // 67 | // Get the table name from embedded TableName field. 68 | var tableName string 69 | for _, opt := range opts { 70 | if tn, ok := opt.(TableName); ok { 71 | tableName = string(tn) 72 | } 73 | } 74 | if tableName == "" { 75 | for _, field := range typInfo.StructFields { 76 | if field.Type == typeTableName { 77 | tableName = field.Tag.Get(tagName) 78 | break 79 | } 80 | } 81 | } 82 | if tableName == "" { 83 | panic("no table name; call Register with a TableName value or embed TableName into your struct") 84 | } 85 | // 86 | // Now map the columns. 87 | mapping := me.Mapper.Map(value) 88 | // 89 | // key is the Columns for the table's primary key. 90 | // unique is the slice of unique indexes on the table. 91 | // columns are the non-primary key columns and includes columns in unique. 92 | key, unique, columns := []schema.Column{}, []schema.Index{}, []schema.Column{} 93 | // 94 | // The following slices keep track of column names in the database. 95 | // autoKeyNames, keyNames 96 | // + Primary key column names. 97 | // + Composite primary keys (keys using multiple fields) are supported. 98 | // + autoKeyNames are columns automatically populated by the database such as auto incrementing integer keys. 
99 | // autoInsertNames, autoUpdateNames, autoInsertUpdateNames 100 | // + Columns automatically populated by the database such as created or modified timestamps. 101 | // + autoInsertUpdateNames is UNIQUE( UNION( autoInsertNames, autoUpdateNames ) ). 102 | // columnNames 103 | // + All other column names that need to be explicitly set during insert/update operations. 104 | // 105 | // NB: auto* columns are not currently limited to any specific type. 106 | autoKeyNames, autoInsertNames, autoUpdateNames, autoInsertUpdateNames, keyNames, columnNames := []string{}, []string{}, []string{}, []string{}, []string{}, []string{} 107 | for _, name := range mapping.Keys { 108 | field := mapping.StructFields[name] 109 | if field.Type == typeTableName { 110 | // Leave as empty case to ensure embedded TableName is not used for column information. 111 | } else { 112 | // Create the Column type. 113 | column := schema.Column{ 114 | Name: name, 115 | GoType: reflect.Zero(field.Type).Interface(), 116 | // TODO SqlType 117 | } 118 | // Get the struct field tag and then classify the column accordingly. 119 | tag := field.Tag.Get(tagName) 120 | if tag == "key" || strings.HasPrefix(tag, "key,") { 121 | // tag=key or tag=key,auto is a primary key field. 122 | key = append(key, column) 123 | if strings.Contains(tag, ",auto") { 124 | autoKeyNames = append(autoKeyNames, name) 125 | } else { 126 | keyNames = append(keyNames, name) 127 | } 128 | } else if insert, update := strings.Contains(tag, "inserted"), strings.Contains(tag, "updated"); insert || update { 129 | // inserted or updated signals the column is populated on insert or update statements respectively. 
130 | if insert { 131 | autoInsertNames = append(autoInsertNames, name) 132 | } 133 | if update { 134 | autoUpdateNames = append(autoUpdateNames, name) 135 | } 136 | if insert || update { 137 | autoInsertUpdateNames = append(autoInsertUpdateNames, name) 138 | } 139 | } else { 140 | // All other columns are explicitly set during queries. 141 | columns = append(columns, column) 142 | columnNames = append(columnNames, name) 143 | } 144 | if strings.Contains(tag, "unique") { 145 | // unique signals the column is part of a unique index. 146 | // TODO Currently only single column unique indexes are supported; should also support multi-column. 147 | // TODO The above comment is a lie -- indexes aren't supported at all yet. 148 | index := schema.Index{ 149 | Name: "", // TODO Index name. 150 | Columns: []schema.Column{column}, 151 | IsPrimary: false, 152 | IsUnique: true, 153 | } 154 | unique = append(unique, index) 155 | } 156 | } 157 | } 158 | // 159 | // Determine the model's save mode. 160 | var saveMode SaveMode 161 | switch { 162 | case len(keyNames) > 0: 163 | saveMode = Upsert 164 | case len(autoKeyNames) == 0: 165 | saveMode = Insert 166 | default: 167 | saveMode = InsertOrUpdate 168 | } 169 | var insertUpdatePaths []path.ReflectPath 170 | for _, keyName := range autoKeyNames { 171 | insertUpdatePaths = append(insertUpdatePaths, mapping.ReflectPaths[keyName]) 172 | } 173 | // 174 | // Merge autoKeyNames into autoInsertNames as those keys are generated during insert statements. 175 | autoInsertNames = append(autoKeyNames, autoInsertNames...) 176 | // Create table struct. 177 | table := schema.Table{ 178 | Name: tableName, 179 | PrimaryKey: schema.Index{ 180 | Name: "", // TODO Index name. 181 | Columns: key, 182 | IsPrimary: true, 183 | IsUnique: true, 184 | }, 185 | Unique: unique, 186 | Columns: columns, 187 | } 188 | // Create model struct. 
189 | model := &Model{ 190 | Table: table, 191 | Statements: statements.Table{}, 192 | SaveMode: saveMode, 193 | InsertUpdatePaths: insertUpdatePaths, 194 | Mapping: mapping, 195 | } 196 | // Fill in query statements. 197 | // NB: Ignore errors here as we'll handle when a query is nil for a model in our other functions. 198 | model.Statements.Insert, _ = me.Grammar.Insert(tableName, append(keyNames, columnNames...), autoInsertNames) 199 | model.Statements.Update, _ = me.Grammar.Update(tableName, columnNames, append(autoKeyNames, keyNames...), autoUpdateNames) 200 | model.Statements.Delete, _ = me.Grammar.Delete(tableName, append(autoKeyNames, keyNames...)) 201 | model.Statements.Upsert, _ = me.Grammar.Upsert(tableName, columnNames, keyNames, autoInsertUpdateNames) 202 | // 203 | // We want to be able to look up the model by the original type T passed to this function 204 | // as well as []T. 205 | me.Models[typ] = model 206 | me.Models[reflect.PtrTo(typ)] = model 207 | me.Models[reflect.SliceOf(typ)] = model 208 | me.Models[reflect.SliceOf(reflect.PtrTo(typ))] = model 209 | } 210 | 211 | // Lookup returns the model associated with the value. 212 | func (me *Models) Lookup(value interface{}) (m *Model, err error) { 213 | if me == nil { 214 | err = errors.NilReceiver() 215 | return 216 | } 217 | var ok bool 218 | t := reflect.TypeOf(value) 219 | if m, ok = me.Models[t]; ok { 220 | return 221 | } 222 | err = errors.Errorf("%T not registered", value) 223 | return 224 | } 225 | 226 | // Insert attempts to persist values via INSERTs. 
227 | func (me *Models) Insert(Q sqlh.IQueries, value interface{}) error { 228 | var model *Model 229 | var query *statements.Query 230 | var binding QueryBinding 231 | var err error 232 | if model, err = me.Lookup(value); err != nil { 233 | return errors.Go(err) 234 | } else if query = model.Statements.Insert; query == nil { 235 | return errors.Go(ErrUnsupported).Tag("INSERT", fmt.Sprintf("%T", value)) 236 | } 237 | // 238 | binding = model.BindQuery(me.Mapper, query) 239 | if err = binding.Query(Q, value); err != nil { 240 | return errors.Go(err) 241 | } 242 | // 243 | return nil 244 | } 245 | 246 | // Update attempts to persist values via UPDATESs. 247 | func (me *Models) Update(Q sqlh.IQueries, value interface{}) error { 248 | var model *Model 249 | var query *statements.Query 250 | var binding QueryBinding 251 | var err error 252 | if model, err = me.Lookup(value); err != nil { 253 | return errors.Go(err) 254 | } else if query = model.Statements.Update; query == nil { 255 | return errors.Go(ErrUnsupported).Tag("UPDATE", fmt.Sprintf("%T", value)) 256 | } 257 | // 258 | binding = model.BindQuery(me.Mapper, query) 259 | if err = binding.Query(Q, value); err != nil { 260 | return errors.Go(err) 261 | } 262 | // 263 | return nil 264 | } 265 | 266 | // Save inspects the incoming model and delegates to Insert, Update, or Upsert method 267 | // according to the model's SaveMode value, which is determined during registration. 268 | // 269 | // Models with at least one key field defined as "key" (i.e. not "key,auto") use Upsert. 270 | // 271 | // Models with zero "key" and "key,auto" fields use Insert. 272 | // 273 | // Otherwise the model has only "key,auto" fields and will use Update if any such field 274 | // is a non-zero value and Insert otherwise. 275 | // 276 | // If value is a slice []M then the first element is inspected to determine which of 277 | // Insert, Update, or Upsert is applied to the entire slice. 
278 | func (me *Models) Save(Q sqlh.IQueries, value interface{}) error { 279 | model, err := me.Lookup(value) 280 | if err != nil { 281 | return errors.Go(err) 282 | } 283 | switch model.SaveMode { 284 | case Insert: 285 | return me.Insert(Q, value) 286 | case Upsert: 287 | return me.Upsert(Q, value) 288 | case InsertOrUpdate: 289 | v := reflect.ValueOf(value) 290 | switch v.Kind() { 291 | case reflect.Slice: 292 | if v.Len() == 0 { 293 | return nil 294 | } 295 | for v = v.Index(0); v.Kind() == reflect.Ptr; v = v.Elem() { 296 | if v.IsNil() { 297 | return errors.Go(ErrUnsupported).Tag("nil pointer", fmt.Sprintf("%v %v", v.Type(), v.Interface())) 298 | } 299 | } 300 | case reflect.Ptr: 301 | for ; v.Kind() == reflect.Ptr; v = v.Elem() { 302 | if v.IsNil() { 303 | return errors.Go(ErrUnsupported).Tag("nil pointer", fmt.Sprintf("%v %v", v.Type(), v.Interface())) 304 | } 305 | } 306 | } 307 | // TODO Possibly add support for an InsertUpdater interface 308 | // 309 | var keyValue reflect.Value 310 | for _, path := range model.InsertUpdatePaths { 311 | keyValue = path.Value(v) 312 | if !keyValue.IsZero() { 313 | // A non-zero field value means update. 314 | return me.Update(Q, value) 315 | } 316 | } 317 | return me.Insert(Q, value) 318 | } 319 | // Currently it _should_be_ impossible for this to occur. The first thing this method 320 | // does is find the associated model and -- if not found -- returns error. Any model that 321 | // is found is (currently at least) guaranteed to have a SaveMode corresponding to one 322 | // of the above values. 323 | return errors.Go(ErrUnsupported).Tag("SAVE", fmt.Sprintf("%T", value)) 324 | } 325 | 326 | // Upsert attempts to persist values via UPSERTs. 327 | // 328 | // Upsert only works on primary keys that are defined as "key"; in other words columns tagged with "key,auto" 329 | // are not used in the generated query. 
330 | // 331 | // Upsert only supports primary keys; currently there is no support for upsert on UNIQUE indexes that are 332 | // not primary keys. 333 | func (me *Models) Upsert(Q sqlh.IQueries, value interface{}) error { 334 | var model *Model 335 | var query *statements.Query 336 | var binding QueryBinding 337 | var err error 338 | if model, err = me.Lookup(value); err != nil { 339 | return errors.Go(err) 340 | } else if query = model.Statements.Upsert; query == nil { 341 | return errors.Go(ErrUnsupported).Tag("UPSERT", fmt.Sprintf("%T", value)) 342 | } 343 | // 344 | binding = model.BindQuery(me.Mapper, query) 345 | if err = binding.Query(Q, value); err != nil { 346 | return errors.Go(err) 347 | } 348 | // 349 | return nil 350 | } 351 | -------------------------------------------------------------------------------- /model/models_test.go: -------------------------------------------------------------------------------- 1 | package model_test 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/DATA-DOG/go-sqlmock" 10 | "github.com/nofeaturesonlybugs/errors" 11 | "github.com/nofeaturesonlybugs/set" 12 | "github.com/nofeaturesonlybugs/sqlh" 13 | "github.com/nofeaturesonlybugs/sqlh/grammar" 14 | "github.com/nofeaturesonlybugs/sqlh/hobbled" 15 | "github.com/nofeaturesonlybugs/sqlh/model" 16 | "github.com/nofeaturesonlybugs/sqlh/model/examples" 17 | "github.com/stretchr/testify/assert" 18 | ) 19 | 20 | // Test is a test function and descriptive name. 21 | type Test struct { 22 | Name string 23 | Test func(t *testing.T) 24 | } 25 | 26 | // ModelQueryTest describes each test and allows us to compose our tests. 27 | type ModelQueryTest struct { 28 | Name string 29 | DBWrapper hobbled.Wrapper 30 | MockFn func(mock sqlmock.Sqlmock) 31 | ModelsFn func(Q sqlh.IQueries, Data interface{}) error 32 | Data interface{} 33 | ExpectError bool 34 | } 35 | 36 | // ModelQueryTestSlice is a slice of Meta objects. 
37 | type ModelQueryTestSlice []ModelQueryTest 38 | 39 | // Tests returns a []Test from a ModelQueryTestSlice. 40 | func (me ModelQueryTestSlice) Tests() []Test { 41 | tests := []Test{} 42 | for _, queryTest := range me { 43 | test := Test{ 44 | Name: fmt.Sprintf("%v: %v", queryTest.DBWrapper.String(), queryTest.Name), 45 | Test: func(qt ModelQueryTest) func(t *testing.T) { 46 | return func(t *testing.T) { 47 | chk := assert.New(t) 48 | dbm, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) 49 | chk.NoError(err) 50 | // 51 | db := qt.DBWrapper.WrapDB(dbm) 52 | qt.MockFn(mock) 53 | // 54 | err = qt.ModelsFn(db, qt.Data) 55 | if qt.ExpectError { 56 | chk.Error(err) 57 | } else { 58 | chk.NoError(err) 59 | } 60 | chk.NoError(mock.ExpectationsWereMet()) 61 | } 62 | }(queryTest), 63 | } 64 | tests = append(tests, test) 65 | } 66 | return tests 67 | } 68 | 69 | func TestModels_Register(t *testing.T) { 70 | t.Run("double register", func(t *testing.T) { 71 | // Double registering a model should hit an early return in Register() 72 | mdb := examples.NewModels() 73 | mdb.Register(&examples.Address{}) 74 | 75 | }) 76 | t.Run("no tablename panics", func(t *testing.T) { 77 | // Models must have a table name when registering. 
78 | chk := assert.New(t) 79 | // 80 | type T struct{} 81 | recovered := false 82 | mdb := examples.NewModels() 83 | func() { 84 | defer func() { 85 | if r := recover(); r != nil { 86 | recovered = true 87 | } 88 | }() 89 | mdb.Register(&T{}) 90 | }() 91 | chk.True(recovered) 92 | }) 93 | t.Run("tablename option", func(t *testing.T) { 94 | // We can provide table name either by embedding in struct or passing as argument to Register() 95 | chk := assert.New(t) 96 | // 97 | type A struct { 98 | model.TableName `model:"table_a"` 99 | Name string `db:"name"` 100 | } 101 | type B struct { 102 | Name string `db:"name"` 103 | } 104 | // 105 | mdb := examples.NewModels() 106 | mdb.Register(&A{}) 107 | mdb.Register(&B{}, model.TableName("table_b")) 108 | // 109 | mdl, err := mdb.Lookup(&A{}) 110 | chk.NoError(err) 111 | chk.NotNil(mdl) 112 | chk.Equal("table_a", mdl.Table.Name) 113 | // 114 | mdl, err = mdb.Lookup(&B{}) 115 | chk.NoError(err) 116 | chk.NotNil(mdl) 117 | chk.Equal("table_b", mdl.Table.Name) 118 | }) 119 | t.Run("reflect type", func(t *testing.T) { 120 | // We can register types via reflect.Type 121 | chk := assert.New(t) 122 | // 123 | type A struct { 124 | model.TableName `model:"table_a"` 125 | Name string `db:"name"` 126 | } 127 | // 128 | mdb := examples.NewModels() 129 | mdb.Register(reflect.TypeOf((*A)(nil)).Elem()) 130 | // 131 | mdl, err := mdb.Lookup(&A{}) 132 | chk.NoError(err) 133 | chk.NotNil(mdl) 134 | chk.Equal("table_a", mdl.Table.Name) 135 | }) 136 | } 137 | 138 | func TestModels_Save(t *testing.T) { 139 | mdb := examples.NewModels() 140 | 141 | t.Run("nil ptr", func(t *testing.T) { 142 | // Code coverage for (*T)(nil) 143 | chk := assert.New(t) 144 | address := (*examples.Address)(nil) 145 | err := mdb.Save(nil, address) 146 | chk.Error(err) 147 | }) 148 | t.Run("slice empty", func(t *testing.T) { 149 | // Code coverage for []T or []*T when slice is empty 150 | chk := assert.New(t) 151 | 152 | addresses := []*examples.Address(nil) 153 | err := 
mdb.Save(nil, addresses) 154 | chk.NoError(err) 155 | 156 | addresses = []*examples.Address{} 157 | err = mdb.Save(nil, addresses) 158 | chk.NoError(err) 159 | }) 160 | t.Run("slice nil entry", func(t *testing.T) { 161 | // Code coverage for []*T when first element is nil 162 | chk := assert.New(t) 163 | addresses := []*examples.Address{nil} 164 | err := mdb.Save(nil, addresses) 165 | chk.Error(err) 166 | }) 167 | t.Run("unsupported type", func(t *testing.T) { 168 | // Code coverage for unsupported type 169 | chk := assert.New(t) 170 | var unsupported *[]examples.Address 171 | err := mdb.Save(nil, unsupported) 172 | chk.Error(err) 173 | }) 174 | } 175 | 176 | func TestModels_TableNameTypeChecking(t *testing.T) { 177 | // model.TableName is a type string; want to make sure we can differentiate it from other strings. 178 | // 179 | tn, str := model.TableName(""), "Hello, World!" 180 | // 181 | i := interface{}(tn) 182 | switch i.(type) { 183 | case string: 184 | t.Fatalf("tn hits string case") 185 | case model.TableName: 186 | } 187 | // 188 | i = interface{}(str) 189 | switch i.(type) { 190 | case model.TableName: 191 | t.Fatalf("tn hits mode.TableName case") 192 | case string: 193 | } 194 | } 195 | 196 | func TestModels_NilReceiver(t *testing.T) { 197 | // Some functions have nil receiver checks. 198 | chk := assert.New(t) 199 | // 200 | var mdb *model.Models 201 | m, err := mdb.Lookup(nil) 202 | chk.Error(err) 203 | chk.Nil(m) 204 | } 205 | 206 | func TestModels_NilArguments(t *testing.T) { 207 | // Passing nil to some functions hits early return statements. 
208 | chk := assert.New(t) 209 | // 210 | db, _, _ := sqlmock.New() 211 | mdb := examples.NewModels() 212 | m, err := mdb.Lookup(nil) 213 | chk.Error(err) 214 | chk.Nil(m) 215 | // 216 | err = mdb.Insert(db, nil) 217 | chk.Error(err) 218 | err = mdb.Update(db, nil) 219 | chk.Error(err) 220 | err = mdb.Upsert(db, nil) 221 | chk.Error(err) 222 | } 223 | 224 | func TestModelsUnsupported(t *testing.T) { 225 | chk := assert.New(t) 226 | // 227 | type T struct{} 228 | m := &model.Models{ 229 | Grammar: grammar.Sqlite, 230 | Mapper: &set.Mapper{}, 231 | } 232 | m.Register(&T{}, model.TableName("panic_table_T")) 233 | // 234 | var err error 235 | db, _, _ := sqlmock.New() 236 | err = m.Insert(db, &T{}) 237 | chk.Error(err) 238 | chk.Equal(model.ErrUnsupported, errors.Original(err)) 239 | err = m.Update(db, &T{}) 240 | chk.Error(err) 241 | chk.Equal(model.ErrUnsupported, errors.Original(err)) 242 | err = m.Upsert(db, &T{}) 243 | chk.Error(err) 244 | chk.Equal(model.ErrUnsupported, errors.Original(err)) 245 | } 246 | 247 | func TestModelsQueriesError(t *testing.T) { 248 | chk := assert.New(t) 249 | // 250 | db, mock, err := sqlmock.New() 251 | chk.NotNil(db) 252 | chk.NotNil(mock) 253 | chk.NoError(err) 254 | // 255 | mock.ExpectQuery("INSERT+").WillReturnError(errors.Errorf("some error")) 256 | err = examples.Models.Insert(db, &examples.Address{}) 257 | chk.Error(err) 258 | // 259 | mock.ExpectQuery("INSERT+").WillReturnError(errors.Errorf("some error")) 260 | err = examples.Models.Insert(db, []*examples.Address{{}, {}}) 261 | chk.Error(err) 262 | // 263 | mock.ExpectQuery("UPDATE+").WillReturnError(errors.Errorf("some error")) 264 | err = examples.Models.Update(db, &examples.Address{}) 265 | chk.Error(err) 266 | // 267 | mock.ExpectQuery("UPDATE+").WillReturnError(errors.Errorf("some error")) 268 | err = examples.Models.Update(db, []*examples.Address{{}, {}}) 269 | chk.Error(err) 270 | // 271 | mock.ExpectQuery("UPSERT+").WillReturnError(errors.Errorf("some error")) 272 | 
err = examples.Models.Upsert(db, &examples.Upsertable{}) 273 | chk.Error(err) 274 | // 275 | mock.ExpectQuery("UPSERT+").WillReturnError(errors.Errorf("some error")) 276 | err = examples.Models.Upsert(db, []*examples.Upsertable{{}, {}}) 277 | chk.Error(err) 278 | } 279 | 280 | // MakeModelQueryTestsForCompositeKeyNoAuto builds a slice of Test types to test a model with 281 | // composite primary key and no auto-updating fields. 282 | func MakeModelQueryTestsForCompositeKeyNoAuto() []Test { 283 | // Relationship is a model with a composite primary key and no fields that auto update. 284 | // Such a model might exist for relationship tables. 285 | type Relationship struct { 286 | model.TableName `json:"-" model:"relationship"` 287 | // 288 | LeftId int `json:"left_id" db:"left_fk" model:"key"` 289 | RightId int `json:"right_id" db:"right_fk" model:"key"` 290 | // Such a table might have other columns. 291 | Toggle bool `json:"toggle"` 292 | } 293 | // 294 | models := &model.Models{ 295 | Mapper: &set.Mapper{ 296 | Join: "_", 297 | Tags: []string{"db", "json"}, 298 | }, 299 | Grammar: grammar.Postgres, 300 | } 301 | models.Register(&Relationship{}) 302 | // 303 | relate := &Relationship{ 304 | LeftId: -1, 305 | RightId: -10, 306 | Toggle: false, 307 | } 308 | relateSlice := []*Relationship{ 309 | { 310 | LeftId: 1, 311 | RightId: 10, 312 | Toggle: false, 313 | }, 314 | { 315 | LeftId: 2, 316 | RightId: 20, 317 | Toggle: true, 318 | }, 319 | { 320 | LeftId: -3, 321 | RightId: -30, 322 | Toggle: false, 323 | }, 324 | } // 325 | SQLInsert := strings.Join([]string{ 326 | "INSERT INTO relationship", 327 | "\t\t( left_fk, right_fk, toggle )", 328 | "\tVALUES", 329 | "\t\t( $1, $2, $3 )", 330 | }, "\n") 331 | SQLUpdate := strings.Join([]string{ 332 | "UPDATE relationship SET", 333 | "\t\ttoggle = $1", 334 | "\tWHERE", 335 | "\t\tleft_fk = $2 AND right_fk = $3", 336 | }, "\n") 337 | SQLUpsert := strings.Join([]string{ 338 | "INSERT INTO relationship AS dest", 339 | "\t\t( 
left_fk, right_fk, toggle )", 340 | "\tVALUES", 341 | "\t\t( $1, $2, $3 )", 342 | "\tON CONFLICT( left_fk, right_fk ) DO UPDATE SET", 343 | "\t\ttoggle = EXCLUDED.toggle", 344 | "\t\tWHERE (", 345 | "\t\t\tdest.toggle <> EXCLUDED.toggle", 346 | "\t\t)", 347 | }, "\n") 348 | 349 | // 350 | meta := []ModelQueryTest{ 351 | // 352 | // INSERTS 353 | { 354 | Name: "insert single with error", 355 | DBWrapper: hobbled.Passthru, 356 | MockFn: func(mock sqlmock.Sqlmock) { 357 | mock.ExpectExec(SQLInsert).WithArgs(-1, -10, false).WillReturnError(fmt.Errorf("relationship error")) 358 | }, 359 | ExpectError: true, 360 | ModelsFn: models.Insert, 361 | Data: relate, 362 | }, 363 | { 364 | Name: "insert slice with error", 365 | DBWrapper: hobbled.Passthru, 366 | MockFn: func(mock sqlmock.Sqlmock) { 367 | mock.ExpectBegin() 368 | prepare := mock.ExpectPrepare(SQLInsert) 369 | prepare.ExpectExec().WithArgs(1, 10, false).WillReturnResult(sqlmock.NewResult(0, 1)) 370 | prepare.ExpectExec().WithArgs(2, 20, true).WillReturnResult(sqlmock.NewResult(0, 1)) 371 | prepare.ExpectExec().WithArgs(-3, -30, false).WillReturnError(fmt.Errorf("relationship slice error")) 372 | mock.ExpectRollback() 373 | }, 374 | ExpectError: true, 375 | ModelsFn: models.Insert, 376 | Data: relateSlice, 377 | }, 378 | { 379 | Name: "insert slice with error", 380 | DBWrapper: hobbled.NoBegin, 381 | MockFn: func(mock sqlmock.Sqlmock) { 382 | prepare := mock.ExpectPrepare(SQLInsert) 383 | prepare.ExpectExec().WithArgs(1, 10, false).WillReturnResult(sqlmock.NewResult(0, 1)) 384 | prepare.ExpectExec().WithArgs(2, 20, true).WillReturnResult(sqlmock.NewResult(0, 1)) 385 | prepare.ExpectExec().WithArgs(-3, -30, false).WillReturnError(fmt.Errorf("relationship slice error")) 386 | }, 387 | ExpectError: true, 388 | ModelsFn: models.Insert, 389 | Data: relateSlice, 390 | }, 391 | { 392 | Name: "insert slice with error", 393 | DBWrapper: hobbled.NoBeginNoPrepare, 394 | MockFn: func(mock sqlmock.Sqlmock) { 395 | 
mock.ExpectExec(SQLInsert).WithArgs(1, 10, false).WillReturnResult(sqlmock.NewResult(0, 1)) 396 | mock.ExpectExec(SQLInsert).WithArgs(2, 20, true).WillReturnResult(sqlmock.NewResult(0, 1)) 397 | mock.ExpectExec(SQLInsert).WithArgs(-3, -30, false).WillReturnError(fmt.Errorf("relationship slice error")) 398 | }, 399 | ExpectError: true, 400 | ModelsFn: models.Insert, 401 | Data: relateSlice, 402 | }, 403 | // 404 | // UPDATES 405 | { 406 | Name: "update single with error", 407 | DBWrapper: hobbled.Passthru, 408 | MockFn: func(mock sqlmock.Sqlmock) { 409 | mock.ExpectExec(SQLUpdate).WithArgs(false, -1, -10).WillReturnError(fmt.Errorf("relationship error")) 410 | }, 411 | ExpectError: true, 412 | ModelsFn: models.Update, 413 | Data: relate, 414 | }, 415 | { 416 | Name: "update slice with error", 417 | DBWrapper: hobbled.Passthru, 418 | MockFn: func(mock sqlmock.Sqlmock) { 419 | mock.ExpectBegin() 420 | prepare := mock.ExpectPrepare(SQLUpdate) 421 | prepare.ExpectExec().WithArgs(false, 1, 10).WillReturnResult(sqlmock.NewResult(0, 1)) 422 | prepare.ExpectExec().WithArgs(true, 2, 20).WillReturnResult(sqlmock.NewResult(0, 1)) 423 | prepare.ExpectExec().WithArgs(false, -3, -30).WillReturnError(fmt.Errorf("relationship slice error")) 424 | mock.ExpectRollback() 425 | }, 426 | ExpectError: true, 427 | ModelsFn: models.Update, 428 | Data: relateSlice, 429 | }, 430 | { 431 | Name: "update slice with error", 432 | DBWrapper: hobbled.NoBegin, 433 | MockFn: func(mock sqlmock.Sqlmock) { 434 | prepare := mock.ExpectPrepare(SQLUpdate) 435 | prepare.ExpectExec().WithArgs(false, 1, 10).WillReturnResult(sqlmock.NewResult(0, 1)) 436 | prepare.ExpectExec().WithArgs(true, 2, 20).WillReturnResult(sqlmock.NewResult(0, 1)) 437 | prepare.ExpectExec().WithArgs(false, -3, -30).WillReturnError(fmt.Errorf("relationship slice error")) 438 | }, 439 | ExpectError: true, 440 | ModelsFn: models.Update, 441 | Data: relateSlice, 442 | }, 443 | { 444 | Name: "update slice with error", 445 | DBWrapper: 
hobbled.NoBeginNoPrepare, 446 | MockFn: func(mock sqlmock.Sqlmock) { 447 | mock.ExpectExec(SQLUpdate).WithArgs(false, 1, 10).WillReturnResult(sqlmock.NewResult(0, 1)) 448 | mock.ExpectExec(SQLUpdate).WithArgs(true, 2, 20).WillReturnResult(sqlmock.NewResult(0, 1)) 449 | mock.ExpectExec(SQLUpdate).WithArgs(false, -3, -30).WillReturnError(fmt.Errorf("relationship slice error")) 450 | }, 451 | ExpectError: true, 452 | ModelsFn: models.Update, 453 | Data: relateSlice, 454 | }, 455 | // 456 | // UPSERT 457 | { 458 | Name: "upsert single with error", 459 | DBWrapper: hobbled.Passthru, 460 | MockFn: func(mock sqlmock.Sqlmock) { 461 | mock.ExpectExec(SQLUpsert).WithArgs(-1, -10, false).WillReturnError(fmt.Errorf("relationship error")) 462 | }, 463 | ExpectError: true, 464 | ModelsFn: models.Upsert, 465 | Data: relate, 466 | }, 467 | { 468 | Name: "upsert slice with error", 469 | DBWrapper: hobbled.Passthru, 470 | MockFn: func(mock sqlmock.Sqlmock) { 471 | mock.ExpectBegin() 472 | prepare := mock.ExpectPrepare(SQLUpsert) 473 | prepare.ExpectExec().WithArgs(1, 10, false).WillReturnResult(sqlmock.NewResult(0, 1)) 474 | prepare.ExpectExec().WithArgs(2, 20, true).WillReturnResult(sqlmock.NewResult(0, 1)) 475 | prepare.ExpectExec().WithArgs(-3, -30, false).WillReturnError(fmt.Errorf("relationship slice error")) 476 | mock.ExpectRollback() 477 | }, 478 | ExpectError: true, 479 | ModelsFn: models.Upsert, 480 | Data: relateSlice, 481 | }, 482 | { 483 | Name: "upsert slice with error", 484 | DBWrapper: hobbled.NoBegin, 485 | MockFn: func(mock sqlmock.Sqlmock) { 486 | prepare := mock.ExpectPrepare(SQLUpsert) 487 | prepare.ExpectExec().WithArgs(1, 10, false).WillReturnResult(sqlmock.NewResult(0, 1)) 488 | prepare.ExpectExec().WithArgs(2, 20, true).WillReturnResult(sqlmock.NewResult(0, 1)) 489 | prepare.ExpectExec().WithArgs(-3, -30, false).WillReturnError(fmt.Errorf("relationship slice error")) 490 | }, 491 | ExpectError: true, 492 | ModelsFn: models.Upsert, 493 | Data: relateSlice, 
494 | }, 495 | { 496 | Name: "upsert slice with error", 497 | DBWrapper: hobbled.NoBeginNoPrepare, 498 | MockFn: func(mock sqlmock.Sqlmock) { 499 | mock.ExpectExec(SQLUpsert).WithArgs(1, 10, false).WillReturnResult(sqlmock.NewResult(0, 1)) 500 | mock.ExpectExec(SQLUpsert).WithArgs(2, 20, true).WillReturnResult(sqlmock.NewResult(0, 1)) 501 | mock.ExpectExec(SQLUpsert).WithArgs(-3, -30, false).WillReturnError(fmt.Errorf("relationship slice error")) 502 | }, 503 | ExpectError: true, 504 | ModelsFn: models.Upsert, 505 | Data: relateSlice, 506 | }, 507 | } 508 | return ModelQueryTestSlice(meta).Tests() 509 | } 510 | 511 | func TestModels_Suite(t *testing.T) { 512 | for _, test := range MakeModelQueryTestsForCompositeKeyNoAuto() { 513 | t.Run(test.Name, test.Test) 514 | } 515 | } 516 | -------------------------------------------------------------------------------- /model/pkg_doc.go: -------------------------------------------------------------------------------- 1 | // Package model allows Go structs to behave as database models. 2 | // 3 | // While this package exports several types the only one you currently need 4 | // to be concerned with is type Models. All of the examples in this package 5 | // use a global instance of Models defined in the examples subpackage; you may 6 | // refer to that global instance for an instantiation example. 7 | // 8 | // Note that in the examples for this package when you see examples.Models 9 | // or examples.Connect() it is referring to the examples subdirectory for 10 | // this package and NOT the subdirectory for sqlh (i.e. both sqlh and sqlh/model 11 | // have an examples subdirectory.) 
12 | package model 13 | -------------------------------------------------------------------------------- /model/query_binding.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "reflect" 7 | 8 | "github.com/nofeaturesonlybugs/set" 9 | 10 | "github.com/nofeaturesonlybugs/sqlh" 11 | "github.com/nofeaturesonlybugs/sqlh/model/statements" 12 | ) 13 | 14 | // QueryBinding binds a model together with a specific query. 15 | type QueryBinding struct { 16 | mapper *set.Mapper 17 | model *Model 18 | query *statements.Query 19 | } 20 | 21 | // Query accepts either a single model M or a slice of models []M. It then 22 | // runs and returns the result of QueryOne or QuerySlice. 23 | func (me QueryBinding) Query(q sqlh.IQueries, value interface{}) error { 24 | if reflect.Slice == reflect.TypeOf(value).Kind() { 25 | if err := me.QuerySlice(q, value); err != nil { 26 | return err 27 | } 28 | } else if err := me.QueryOne(q, value); err != nil { 29 | return err 30 | } 31 | return nil 32 | } 33 | 34 | // QueryOne runs the query against a single instance of the model. 35 | // 36 | // As a special case value can be an instance of reflect.Value. 37 | func (me QueryBinding) QueryOne(q sqlh.IQueries, value interface{}) error { 38 | args, scans := make([]interface{}, len(me.query.Arguments)), make([]interface{}, len(me.query.Scan)) 39 | // 40 | // Create our prepared mapping. Note that if the calls to Plan() succeed then we do 41 | // not need to check errors on the following statement for that plan. 42 | prepared, err := me.mapper.Prepare(value) 43 | if err != nil { 44 | return err // TODO sentinal or wrap? 
45 | } 46 | if err := prepared.Plan(me.query.Arguments...); err != nil { 47 | return err 48 | } 49 | _, _ = prepared.Fields(args) 50 | if err := prepared.Plan(me.query.Scan...); err != nil { 51 | return err 52 | } 53 | _, _ = prepared.Assignables(scans) 54 | // 55 | // If no scans then use Exec(). 56 | if len(me.query.Scan) == 0 { 57 | if _, err := q.Exec(me.query.SQL, args...); err != nil { 58 | return err 59 | } 60 | return nil 61 | } 62 | // 63 | row := q.QueryRow(me.query.SQL, args...) 64 | // NB: The error conditions are separated for code coverage purposes. 65 | if err := row.Scan(scans...); err != nil { 66 | if err != sql.ErrNoRows { 67 | return err 68 | } else if err == sql.ErrNoRows && me.query.Expect != statements.ExpectRowOrNone { 69 | return err 70 | } 71 | } 72 | return nil 73 | } 74 | 75 | // QuerySlice runs the query against a slice of model instances. 76 | func (me QueryBinding) QuerySlice(q sqlh.IQueries, values interface{}) error { 77 | v := reflect.ValueOf(values) 78 | if v.Kind() != reflect.Slice { 79 | return fmt.Errorf("values expects a slice; got %T", values) // TODO Sentinal error 80 | } 81 | // Size of slice will be helpful here. 82 | size := v.Len() 83 | if size == 0 { 84 | return nil 85 | } else if size == 1 { 86 | return me.QueryOne(q, v.Index(0)) 87 | } 88 | // 89 | var tx *sql.Tx 90 | var stmt *sql.Stmt 91 | var row *sql.Row 92 | var err error 93 | // 94 | // If the calls to Plan succeed then further calls to Fields or Assignables will not error. 95 | preparedArgs, err := me.mapper.Prepare(v.Index(0)) 96 | if err != nil { 97 | return err // TODO sentinal or wrap? 
98 | } 99 | preparedScans := preparedArgs.Copy() 100 | if err = preparedArgs.Plan(me.query.Arguments...); err != nil { 101 | return err 102 | } 103 | if err = preparedScans.Plan(me.query.Scan...); err != nil { 104 | return err 105 | } 106 | args, scans := make([]interface{}, len(me.query.Arguments)), make([]interface{}, len(me.query.Scan)) 107 | // 108 | // If original parameter supports transactions... 109 | if txer, ok := q.(sqlh.IBegins); ok { 110 | if tx, err = txer.Begin(); err != nil { 111 | return err 112 | } 113 | defer tx.Rollback() 114 | q = tx 115 | } 116 | // 117 | // QueryRowFunc normalizes the query row call so the same logic can be used with or without prepared statements. 118 | type ExecFunc func(args ...interface{}) (sql.Result, error) 119 | type QueryRowFunc func(args ...interface{}) *sql.Row 120 | var QueryRow QueryRowFunc 121 | var Exec ExecFunc 122 | // 123 | // Use prepared statement if possible. 124 | if pper, ok := q.(sqlh.IPrepares); ok { 125 | if stmt, err = pper.Prepare(me.query.SQL); err != nil { 126 | return err 127 | } 128 | defer stmt.Close() 129 | Exec = stmt.Exec 130 | QueryRow = stmt.QueryRow 131 | } else { 132 | Exec = func(args ...interface{}) (sql.Result, error) { 133 | return q.Exec(me.query.SQL, args...) 134 | } 135 | QueryRow = func(args ...interface{}) *sql.Row { 136 | return q.QueryRow(me.query.SQL, args...) 137 | } 138 | } 139 | // 140 | // There's a little bit of copy+paste between both conditions. Tread carefully when editing the similar portions. 
141 | if len(me.query.Scan) == 0 { 142 | for k := 0; k < size; k++ { 143 | elem := v.Index(k) 144 | preparedArgs.Rebind(elem) 145 | _, _ = preparedArgs.Fields(args) 146 | // 147 | if _, err = Exec(args...); err != nil { 148 | return err 149 | } 150 | } 151 | } else { 152 | for k := 0; k < size; k++ { 153 | elem := v.Index(k) 154 | preparedArgs.Rebind(elem) 155 | _, _ = preparedArgs.Fields(args) 156 | preparedScans.Rebind(elem) 157 | _, _ = preparedScans.Assignables(scans) 158 | // 159 | row = QueryRow(args...) 160 | if err = row.Scan(scans...); err != nil { 161 | if err != sql.ErrNoRows { 162 | return err 163 | } else if err == sql.ErrNoRows && me.query.Expect != statements.ExpectRowOrNone { 164 | return err 165 | } 166 | } 167 | } 168 | } 169 | 170 | // 171 | // If we opened a transaction then attempt to commit. 172 | if tx != nil { 173 | if err = tx.Commit(); err != nil { 174 | return err 175 | } 176 | } 177 | return nil 178 | } 179 | -------------------------------------------------------------------------------- /model/query_binding_test.go: -------------------------------------------------------------------------------- 1 | package model_test 2 | 3 | import ( 4 | "database/sql" 5 | "testing" 6 | 7 | "github.com/DATA-DOG/go-sqlmock" 8 | "github.com/nofeaturesonlybugs/errors" 9 | "github.com/nofeaturesonlybugs/set" 10 | "github.com/stretchr/testify/assert" 11 | 12 | "github.com/nofeaturesonlybugs/sqlh/grammar" 13 | "github.com/nofeaturesonlybugs/sqlh/model" 14 | "github.com/nofeaturesonlybugs/sqlh/model/examples" 15 | "github.com/nofeaturesonlybugs/sqlh/model/statements" 16 | ) 17 | 18 | // no_prepare_db is an IQuery interface that does not support IPrepare. 19 | type no_prepare_db struct { 20 | db *sql.DB 21 | } 22 | 23 | func (me *no_prepare_db) Exec(query string, args ...interface{}) (sql.Result, error) { 24 | return me.db.Exec(query, args...) 
25 | } 26 | func (me *no_prepare_db) Query(query string, args ...interface{}) (*sql.Rows, error) { 27 | return me.db.Query(query, args...) 28 | } 29 | func (me *no_prepare_db) QueryRow(query string, args ...interface{}) *sql.Row { 30 | return me.db.QueryRow(query, args...) 31 | } 32 | 33 | func TestQueryBinding(t *testing.T) { 34 | chk := assert.New(t) 35 | // 36 | db, mock, err := sqlmock.New() 37 | chk.NotNil(db) 38 | chk.NotNil(mock) 39 | chk.NoError(err) 40 | // 41 | mdb := examples.NewModels() 42 | // 43 | modelptr, err := mdb.Lookup(&examples.Person{}) 44 | chk.NoError(err) 45 | chk.NotNil(modelptr) 46 | // 47 | { 48 | // Check early return conditions for slices. 49 | // Test qu.Arguments causes the error. 50 | qu := &statements.Query{ 51 | SQL: "INSERT", 52 | Arguments: []string{"first", "last"}, 53 | Scan: []string{"pk"}, 54 | } 55 | bound := modelptr.BindQuery(mdb.Mapper, qu) 56 | // Early return when not a slice. 57 | err = bound.QuerySlice(db, int(0)) 58 | chk.Error(err) 59 | // Early return when nil slice. 60 | err = bound.QuerySlice(db, []*examples.Person(nil)) 61 | chk.NoError(err) 62 | // Early return when empty slice. 63 | err = bound.QuerySlice(db, []*examples.Person{}) 64 | chk.NoError(err) 65 | // Early return when single element. 66 | mock.ExpectQuery("INSERT+").WithArgs("", "").WillReturnRows(sqlmock.NewRows([]string{"pk"}).AddRow(10)) 67 | err = bound.QuerySlice(db, []*examples.Person{{}}) 68 | chk.NoError(err) 69 | } 70 | { 71 | // Check the flow path of Query, QueryOne, and QuerySlice 72 | qu := &statements.Query{ 73 | SQL: "INSERT", 74 | Arguments: []string{"first", "last"}, 75 | Scan: []string{"pk"}, 76 | } 77 | bound := modelptr.BindQuery(mdb.Mapper, qu) 78 | // If begin fails 79 | mock.ExpectBegin().WillReturnError(errors.Errorf("begin fail")) 80 | err = bound.QuerySlice(db, []*examples.Person{{}, {}}) 81 | chk.Error(err) 82 | // If prepare errors. 
83 | mock.ExpectBegin() 84 | mock.ExpectPrepare("INSERT+").WillReturnError(errors.Errorf("prepare failed")) 85 | mock.ExpectRollback() 86 | err = bound.QuerySlice(db, []*examples.Person{{}, {}}) 87 | chk.Error(err) 88 | // If query errors. 89 | mock.ExpectBegin() 90 | prepare := mock.ExpectPrepare("INSERT+") 91 | prepare.ExpectQuery().WillReturnError(errors.Errorf("query failed")) 92 | mock.ExpectRollback() 93 | err = bound.QuerySlice(db, []*examples.Person{{}, {}}) 94 | chk.Error(err) 95 | // If commit errors. 96 | mock.ExpectBegin() 97 | prepare = mock.ExpectPrepare("INSERT+") 98 | prepare.ExpectQuery().WillReturnRows(sqlmock.NewRows([]string{"pk"}).AddRow(10)) 99 | prepare.ExpectQuery().WillReturnRows(sqlmock.NewRows([]string{"pk"}).AddRow(20)) 100 | mock.ExpectCommit().WillReturnError(errors.Errorf("commit failed")) 101 | err = bound.QuerySlice(db, []*examples.Person{{}, {}}) 102 | chk.Error(err) 103 | } 104 | } 105 | 106 | func TestQueryBinding_NoPrepares(t *testing.T) { 107 | chk := assert.New(t) 108 | // 109 | db, mock, err := sqlmock.New() 110 | chk.NotNil(db) 111 | chk.NotNil(mock) 112 | chk.NoError(err) 113 | // 114 | type Person struct { 115 | model.TableName `model:"people"` 116 | Id int `model:"key,auto"` 117 | First string 118 | Last string 119 | } 120 | models := model.Models{ 121 | Grammar: grammar.Postgres, 122 | Mapper: &set.Mapper{}, 123 | } 124 | models.Register(&Person{}) 125 | // 126 | modelptr, err := models.Lookup(&Person{}) 127 | chk.NoError(err) 128 | // 129 | qu := &statements.Query{ 130 | SQL: "INSERT", 131 | Arguments: []string{"First", "Last"}, 132 | Scan: []string{"Id"}, 133 | } 134 | { 135 | // Check flow when queryer does not support prepared statements. 136 | // db that can't prepare statements. 137 | db := &no_prepare_db{db} 138 | // 139 | bound := modelptr.BindQuery(models.Mapper, qu) 140 | // If query errors. 
141 | mock.ExpectBegin() 142 | mock.ExpectQuery("INSERT+").WillReturnRows(sqlmock.NewRows([]string{"Id"}).AddRow(10)) 143 | mock.ExpectQuery("INSERT+").WillReturnError(errors.Errorf("query failed")) 144 | mock.ExpectRollback() 145 | err = bound.QuerySlice(db, []*Person{{}, {}}) 146 | chk.Error(err) 147 | // If commit errors. 148 | mock.ExpectBegin() 149 | mock.ExpectQuery("INSERT+").WillReturnRows(sqlmock.NewRows([]string{"Id"}).AddRow(10)) 150 | mock.ExpectQuery("INSERT+").WillReturnRows(sqlmock.NewRows([]string{"Id"}).AddRow(20)) 151 | mock.ExpectCommit().WillReturnError(errors.Errorf("commit failed")) 152 | err = bound.QuerySlice(db, []*Person{{}, {}}) 153 | chk.Error(err) 154 | } 155 | } 156 | 157 | func TestQueryBinding_QueryOne(t *testing.T) { 158 | chk := assert.New(t) 159 | // 160 | // 161 | db, mock, err := sqlmock.New() 162 | chk.NotNil(db) 163 | chk.NotNil(mock) 164 | chk.NoError(err) 165 | // 166 | type Person struct { 167 | model.TableName `model:"people"` 168 | Id int `model:"key,auto"` 169 | First string 170 | Last string 171 | } 172 | models := model.Models{ 173 | Grammar: grammar.Postgres, 174 | Mapper: &set.Mapper{}, 175 | } 176 | models.Register(&Person{}) 177 | modelptr, err := models.Lookup(&Person{}) 178 | chk.NoError(err) 179 | // 180 | qu := &statements.Query{ 181 | SQL: "INSERT", 182 | Arguments: []string{"First", "Last"}, 183 | Scan: []string{"Id"}, 184 | Expect: statements.ExpectRow, 185 | } 186 | { 187 | // Query expects one row and gets one row. 188 | bound := modelptr.BindQuery(models.Mapper, qu) 189 | // If query errors. 190 | mock.ExpectQuery("INSERT+").WillReturnRows(sqlmock.NewRows([]string{"Id"}).AddRow(10)) 191 | person := &Person{} 192 | err = bound.QueryOne(db, person) 193 | chk.NoError(err) 194 | chk.Equal(10, person.Id) 195 | } 196 | { 197 | // Query expects one row and gets no rows. 198 | bound := modelptr.BindQuery(models.Mapper, qu) 199 | // If query errors. 
200 | mock.ExpectQuery("INSERT+").WillReturnRows(sqlmock.NewRows([]string{"Id"})) 201 | person := &Person{} 202 | err = bound.QueryOne(db, person) 203 | chk.Error(err) 204 | chk.Equal(0, person.Id) 205 | } 206 | { 207 | // Query expects one row or none and gets no rows. 208 | qu.Expect = statements.ExpectRowOrNone 209 | bound := modelptr.BindQuery(models.Mapper, qu) 210 | // If query errors. 211 | mock.ExpectQuery("INSERT+").WillReturnRows(sqlmock.NewRows([]string{"Id"})) 212 | person := &Person{Id: 10} 213 | err = bound.QueryOne(db, person) 214 | chk.NoError(err) 215 | chk.Equal(10, person.Id) 216 | } 217 | } 218 | 219 | func TestQueryBinding_QuerySlice_WithPrepare_ExpectRow_GetNone(t *testing.T) { 220 | chk := assert.New(t) 221 | // 222 | // 223 | db, mock, err := sqlmock.New() 224 | chk.NotNil(db) 225 | chk.NotNil(mock) 226 | chk.NoError(err) 227 | // 228 | type Person struct { 229 | model.TableName `model:"people"` 230 | Id int `model:"key,auto"` 231 | First string 232 | Last string 233 | } 234 | models := model.Models{ 235 | Grammar: grammar.Postgres, 236 | Mapper: &set.Mapper{}, 237 | } 238 | models.Register(&Person{}) 239 | modelptr, err := models.Lookup([]*Person{}) 240 | chk.NoError(err) 241 | // 242 | qu := &statements.Query{ 243 | SQL: "INSERT", 244 | Arguments: []string{"First", "Last"}, 245 | Scan: []string{"Id"}, 246 | Expect: statements.ExpectRow, 247 | } 248 | { 249 | // Query expects one row and gets no rows. 250 | bound := modelptr.BindQuery(models.Mapper, qu) 251 | // If query errors. 
252 | mock.ExpectBegin() 253 | stmt := mock.ExpectPrepare("INSERT+") 254 | stmt.ExpectQuery().WillReturnRows(sqlmock.NewRows([]string{"Id"}).RowError(0, sql.ErrNoRows)) 255 | mock.ExpectRollback() 256 | people := []*Person{{}, {}} 257 | err = bound.QuerySlice(db, people) 258 | chk.Error(err) 259 | chk.Equal(0, people[0].Id) 260 | chk.Equal(0, people[1].Id) 261 | chk.NoError(mock.ExpectationsWereMet()) 262 | } 263 | } 264 | 265 | func TestQueryBinding_QuerySlice_WithPrepare_ExpectRowOrNone_GetNone(t *testing.T) { 266 | chk := assert.New(t) 267 | // 268 | // 269 | db, mock, err := sqlmock.New() 270 | chk.NotNil(db) 271 | chk.NotNil(mock) 272 | chk.NoError(err) 273 | // 274 | type Person struct { 275 | model.TableName `model:"people"` 276 | Id int `model:"key,auto"` 277 | First string 278 | Last string 279 | } 280 | models := model.Models{ 281 | Grammar: grammar.Postgres, 282 | Mapper: &set.Mapper{}, 283 | } 284 | models.Register(&Person{}) 285 | modelptr, err := models.Lookup([]*Person{}) 286 | chk.NoError(err) 287 | // 288 | qu := &statements.Query{ 289 | SQL: "INSERT", 290 | Arguments: []string{"First", "Last"}, 291 | Scan: []string{"Id"}, 292 | Expect: statements.ExpectRowOrNone, 293 | } 294 | { 295 | // Query expects one row and gets no rows. 296 | bound := modelptr.BindQuery(models.Mapper, qu) 297 | // If query errors. 
298 | mock.ExpectBegin() 299 | stmt := mock.ExpectPrepare("INSERT+") 300 | stmt.ExpectQuery().WillReturnRows(sqlmock.NewRows([]string{"Id"}).RowError(0, sql.ErrNoRows)) 301 | stmt.ExpectQuery().WillReturnRows(sqlmock.NewRows([]string{"Id"}).RowError(0, sql.ErrNoRows)) 302 | mock.ExpectCommit() 303 | people := []*Person{{}, {}} 304 | err = bound.QuerySlice(db, people) 305 | chk.NoError(err) 306 | chk.Equal(0, people[0].Id) 307 | chk.Equal(0, people[1].Id) 308 | chk.NoError(mock.ExpectationsWereMet()) 309 | } 310 | } 311 | 312 | func TestQueryBinding_QuerySlice_NoPrepare_ExpectRow_GetNone(t *testing.T) { 313 | chk := assert.New(t) 314 | // 315 | // 316 | db, mock, err := sqlmock.New() 317 | chk.NotNil(db) 318 | chk.NotNil(mock) 319 | chk.NoError(err) 320 | noprepare := &no_prepare_db{db} 321 | // 322 | type Person struct { 323 | model.TableName `model:"people"` 324 | Id int `model:"key,auto"` 325 | First string 326 | Last string 327 | } 328 | models := model.Models{ 329 | Grammar: grammar.Postgres, 330 | Mapper: &set.Mapper{}, 331 | } 332 | models.Register(&Person{}) 333 | modelptr, err := models.Lookup([]*Person{}) 334 | chk.NoError(err) 335 | // 336 | qu := &statements.Query{ 337 | SQL: "INSERT", 338 | Arguments: []string{"First", "Last"}, 339 | Scan: []string{"Id"}, 340 | Expect: statements.ExpectRow, 341 | } 342 | { 343 | // Query expects one row and gets no rows. 344 | bound := modelptr.BindQuery(models.Mapper, qu) 345 | // If query errors. 
346 | mock.ExpectQuery("INSERT+").WillReturnRows(sqlmock.NewRows([]string{"Id"}).RowError(0, sql.ErrNoRows)) 347 | people := []*Person{{}, {}} 348 | err = bound.QuerySlice(noprepare, people) 349 | chk.Error(err) 350 | chk.Equal(0, people[0].Id) 351 | chk.Equal(0, people[1].Id) 352 | chk.NoError(mock.ExpectationsWereMet()) 353 | } 354 | } 355 | 356 | func TestQueryBinding_QuerySlice_NoPrepare_ExpectRowOrNone_GetNone(t *testing.T) { 357 | chk := assert.New(t) 358 | // 359 | // 360 | db, mock, err := sqlmock.New() 361 | chk.NotNil(db) 362 | chk.NotNil(mock) 363 | chk.NoError(err) 364 | noprepare := &no_prepare_db{db} 365 | // 366 | type Person struct { 367 | model.TableName `model:"people"` 368 | Id int `model:"key,auto"` 369 | First string 370 | Last string 371 | } 372 | models := model.Models{ 373 | Grammar: grammar.Postgres, 374 | Mapper: &set.Mapper{}, 375 | } 376 | models.Register(&Person{}) 377 | modelptr, err := models.Lookup([]*Person{}) 378 | chk.NoError(err) 379 | // 380 | qu := &statements.Query{ 381 | SQL: "INSERT", 382 | Arguments: []string{"First", "Last"}, 383 | Scan: []string{"Id"}, 384 | Expect: statements.ExpectRowOrNone, 385 | } 386 | { 387 | // Query expects one row and gets no rows. 388 | bound := modelptr.BindQuery(models.Mapper, qu) 389 | // If query errors. 
390 | mock.ExpectQuery("INSERT+").WillReturnRows(sqlmock.NewRows([]string{"Id"}).RowError(0, sql.ErrNoRows)) 391 | mock.ExpectQuery("INSERT+").WillReturnRows(sqlmock.NewRows([]string{"Id"}).RowError(0, sql.ErrNoRows)) 392 | people := []*Person{{}, {}} 393 | err = bound.QuerySlice(noprepare, people) 394 | chk.NoError(err) 395 | chk.Equal(0, people[0].Id) 396 | chk.Equal(0, people[1].Id) 397 | chk.NoError(mock.ExpectationsWereMet()) 398 | } 399 | } 400 | 401 | func TestQueryBinding_PreparedMappingErrors(t *testing.T) { 402 | db, mock, _ := sqlmock.New() 403 | // 404 | type Person struct { 405 | model.TableName `model:"people"` 406 | Id int `model:"key,auto"` 407 | First string 408 | Last string 409 | } 410 | var people []Person = []Person{{}, {}} 411 | var person Person 412 | // 413 | models := model.Models{ 414 | Grammar: grammar.Postgres, 415 | Mapper: &set.Mapper{}, 416 | } 417 | models.Register(person) 418 | modelptr, _ := models.Lookup(person) 419 | // 420 | t.Run("not writable", func(t *testing.T) { 421 | // Model is not writable (can not be prepared by set) 422 | chk := assert.New(t) 423 | qu := &statements.Query{ 424 | SQL: "INSERT", 425 | Arguments: []string{"First", "Last"}, 426 | Scan: []string{"Id"}, 427 | Expect: statements.ExpectRowOrNone, 428 | } 429 | bound := modelptr.BindQuery(models.Mapper, qu) 430 | // 431 | err := bound.QueryOne(db, person) 432 | chk.ErrorIs(err, set.ErrReadOnly) 433 | chk.NoError(mock.ExpectationsWereMet()) 434 | }) 435 | t.Run("missing arguments", func(t *testing.T) { 436 | // Query has arguments not in the struct. 
437 | chk := assert.New(t) 438 | qu := &statements.Query{ 439 | SQL: "INSERT", 440 | Arguments: []string{"Fields", "Not", "Found"}, 441 | Scan: []string{"Id"}, 442 | Expect: statements.ExpectRowOrNone, 443 | } 444 | bound := modelptr.BindQuery(models.Mapper, qu) 445 | // 446 | err := bound.QueryOne(db, &person) 447 | chk.ErrorIs(err, set.ErrUnknownField) 448 | chk.NoError(mock.ExpectationsWereMet()) 449 | }) 450 | t.Run("missing arguments slice", func(t *testing.T) { 451 | // Query has arguments not in the struct. 452 | chk := assert.New(t) 453 | qu := &statements.Query{ 454 | SQL: "INSERT", 455 | Arguments: []string{"Fields", "Not", "Found"}, 456 | Scan: []string{"Id"}, 457 | Expect: statements.ExpectRowOrNone, 458 | } 459 | bound := modelptr.BindQuery(models.Mapper, qu) 460 | // 461 | err := bound.QuerySlice(db, people) 462 | chk.ErrorIs(err, set.ErrUnknownField) 463 | chk.NoError(mock.ExpectationsWereMet()) 464 | }) 465 | t.Run("missing scan", func(t *testing.T) { 466 | // Query has scan not in the struct 467 | chk := assert.New(t) 468 | qu := &statements.Query{ 469 | SQL: "INSERT", 470 | Arguments: []string{"First", "Last"}, 471 | Scan: []string{"NotFound"}, 472 | Expect: statements.ExpectRowOrNone, 473 | } 474 | bound := modelptr.BindQuery(models.Mapper, qu) 475 | // 476 | err := bound.QueryOne(db, &person) 477 | chk.ErrorIs(err, set.ErrUnknownField) 478 | chk.NoError(mock.ExpectationsWereMet()) 479 | }) 480 | t.Run("missing scan slice", func(t *testing.T) { 481 | // Query has scan not in the struct 482 | chk := assert.New(t) 483 | qu := &statements.Query{ 484 | SQL: "INSERT", 485 | Arguments: []string{"First", "Last"}, 486 | Scan: []string{"NotFound"}, 487 | Expect: statements.ExpectRowOrNone, 488 | } 489 | bound := modelptr.BindQuery(models.Mapper, qu) 490 | // 491 | err := bound.QuerySlice(db, people) 492 | chk.ErrorIs(err, set.ErrUnknownField) 493 | chk.NoError(mock.ExpectationsWereMet()) 494 | }) 495 | } 496 | 
-------------------------------------------------------------------------------- /model/save_mode.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | // SaveMode describes how a model should be saved when passed to Models.Save method. 4 | type SaveMode int 5 | 6 | const ( 7 | _ SaveMode = iota // Skip zero value and make it unusable. 8 | 9 | // Models with zero key fields can only be inserted. 10 | Insert 11 | 12 | // Models with only key,auto fields use insert or update depending 13 | // on the current values of the key fields. If any key,auto field 14 | // is not the zero value then update otherwise insert. 15 | InsertOrUpdate 16 | 17 | // Models with at least one key field that is not auto must use upsert. 18 | Upsert 19 | ) 20 | -------------------------------------------------------------------------------- /model/statements/pkg_doc.go: -------------------------------------------------------------------------------- 1 | // Package statements builds uses a grammar to build SQL statements scoped to entities within the database. 2 | package statements 3 | -------------------------------------------------------------------------------- /model/statements/query.go: -------------------------------------------------------------------------------- 1 | package statements 2 | 3 | import "fmt" 4 | 5 | // Expect is an enum describing what to expect when running a query. 6 | type Expect int 7 | 8 | const ( 9 | ExpectNone Expect = iota 10 | ExpectRow 11 | ExpectRowOrNone 12 | ExpectRows 13 | ) 14 | 15 | // String returns the Expect value as a string. 16 | func (me Expect) String() string { 17 | return [...]string{"None", "One Row", "One Row or None", "Multiple Rows"}[me] 18 | } 19 | 20 | // Query describes a SQL query. 21 | type Query struct { 22 | // SQL is the query statement. 
23 | SQL string 24 | // If the query requires arguments then Arguments are the column name arguments in 25 | // the order expected by the SQL statement. 26 | Arguments []string 27 | // If the query returns columns then Scan are the column names in the order 28 | // to be scanned. 29 | Scan []string 30 | // Expect is a hint that indicates if the query returns no rows, one row, or many rows. 31 | Expect Expect 32 | } 33 | 34 | // String describes the Query as a string. 35 | func (me *Query) String() string { 36 | if me == nil { 37 | return "Nil Query." 38 | } else if me.SQL == "" { 39 | return "Empty Query." 40 | } 41 | // 42 | rv := me.SQL 43 | if len(me.Arguments) > 0 { 44 | rv = rv + fmt.Sprintf("\n\tArguments: %v", me.Arguments) 45 | } 46 | if len(me.Scan) > 0 { 47 | rv = rv + fmt.Sprintf("\n\tScan: %v", me.Scan) 48 | } 49 | rv = rv + "\n\tExpect: " + me.Expect.String() 50 | // 51 | return rv 52 | } 53 | -------------------------------------------------------------------------------- /model/statements/table.go: -------------------------------------------------------------------------------- 1 | package statements 2 | 3 | import "strings" 4 | 5 | // Table is the collection of Query types to perform CRUD against a table. 6 | type Table struct { 7 | Delete *Query 8 | Insert *Query 9 | Update *Query 10 | Upsert *Query 11 | } 12 | 13 | // String returns the table statements as a friendly string. 14 | func (me Table) String() string { 15 | parts := []string{ 16 | "INSERT: " + me.Insert.String(), 17 | "UPDATE: " + me.Update.String(), 18 | "UPSERT: " + me.Upsert.String(), 19 | "DELETE: " + me.Delete.String(), 20 | } 21 | return strings.Join(parts, "\n") 22 | } 23 | -------------------------------------------------------------------------------- /model/tablename.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import "reflect" 4 | 5 | // TableName represents a database table name. 
Embed the TableName type into a struct and set 6 | // the appropriate struct tag to configure the table name. 7 | type TableName string 8 | 9 | var ( 10 | typeTableName = reflect.TypeOf(TableName("")) 11 | ) 12 | -------------------------------------------------------------------------------- /pkg_doc.go: -------------------------------------------------------------------------------- 1 | // Package sqlh provides some simple utility for database/sql. 2 | // 3 | // Refer to examples under Scanner.Select for row scanning. 4 | // 5 | // Refer to subdirectory model for the model abstraction layer. 6 | // 7 | // sqlh and associated packages use reflect but nearly all of the heavy lifting is offloaded 8 | // to set @ https://www.github.com/nofeaturesonlybugs/set 9 | // 10 | // Both sqlh.Scanner and model.Models use set.Mapper; the examples typically demonstrate 11 | // instantiating a set.Mapper. If you design your Go destinations and models well 12 | // then ideally your app will only need a single set.Mapper, a single sqlh.Scanner, 13 | // and a single model.Models. Since all of these can be instantiated without a database 14 | // connection you may wish to define them as globals and register models as part 15 | // of an init() function. 16 | // 17 | // set.Mapper supports some additional flexibility not shown in the examples for this package; 18 | // if your project has extra convoluted Go structs in your database layer you may want to consult 19 | // the package documentation for package set. 20 | package sqlh 21 | -------------------------------------------------------------------------------- /scanner.go: -------------------------------------------------------------------------------- 1 | package sqlh 2 | 3 | import ( 4 | "database/sql" 5 | "reflect" 6 | "time" 7 | 8 | "github.com/nofeaturesonlybugs/errors" 9 | "github.com/nofeaturesonlybugs/set" 10 | ) 11 | 12 | // scannerDestType describes a destination value from the caller. 
13 | type scannerDestType int 14 | 15 | const ( 16 | destInvalid scannerDestType = iota 17 | destScalar 18 | destScalarSlice 19 | destStruct 20 | destStructSlice 21 | ) 22 | 23 | // String returns the Expect value as a string. 24 | func (me scannerDestType) String() string { 25 | return [...]string{"Invalid", "Scalar", "[]Scalar", "Struct", "[]Struct"}[me] 26 | } 27 | 28 | // Scanner facilitates scanning query results into destinations. 29 | type Scanner struct { 30 | *set.Mapper 31 | } 32 | 33 | // inspectValue inspects a query destination and determines if it can be used. 34 | func (me *Scanner) inspectValue(dest interface{}) (V set.Value, T scannerDestType, err error) { 35 | T = destInvalid 36 | if dest == nil { 37 | err = errors.Errorf("dest is nil") 38 | return 39 | } else if V = set.V(dest); !V.CanWrite { 40 | err = errors.Errorf("dest is not writable") 41 | return 42 | } 43 | if V.IsSlice { 44 | switch dest.(type) { 45 | case *[]time.Time: 46 | T = destScalarSlice 47 | 48 | default: 49 | if V.ElemTypeInfo.IsStruct { 50 | T = destStructSlice 51 | } else if V.ElemTypeInfo.IsScalar { 52 | T = destScalarSlice 53 | } 54 | } 55 | } else { 56 | switch dest.(type) { 57 | case *time.Time: 58 | T = destScalar 59 | 60 | default: 61 | if V.IsStruct { 62 | T = destStruct 63 | } else if V.IsScalar { 64 | T = destScalar 65 | } 66 | } 67 | } 68 | if T == destInvalid { 69 | err = errors.Errorf("unsupported dest %T", dest) 70 | } 71 | return 72 | } 73 | 74 | // Select uses Q to run the query string with args and scans results into dest. 75 | func (me *Scanner) Select(Q IQueries, dest interface{}, query string, args ...interface{}) error { 76 | V, T, err := me.inspectValue(dest) 77 | if err != nil { 78 | return errors.Go(err) 79 | } 80 | switch T { 81 | case destScalar: 82 | row := Q.QueryRow(query, args...) 
83 | if err := row.Scan(dest); err != nil { 84 | return errors.Go(err) 85 | } 86 | 87 | case destStruct: 88 | var rows *sql.Rows 89 | var prepared set.PreparedMapping 90 | var columns []string 91 | var err error 92 | // Why not QueryRow()? Because *sql.Row does not allow us to get the list of columns which we 93 | // need for our dynamic Scan(). 94 | if rows, err = Q.Query(query, args...); err != nil { 95 | return errors.Go(err) 96 | } 97 | defer rows.Close() 98 | if columns, err = rows.Columns(); err != nil { 99 | return errors.Go(err) 100 | } 101 | // 102 | // Get a prepared mapping and prepare the access plan. Note that if prepared.Plan() 103 | // succeeds we know future calls to prepared.Assignables() will succeed and do not 104 | // need to check those errors. 105 | if prepared, err = me.Mapper.Prepare(dest); err != nil { 106 | return errors.Go(err) 107 | } else if err = prepared.Plan(columns...); err != nil { 108 | return errors.Go(err) 109 | } 110 | assignables := make([]interface{}, len(columns)) 111 | if rows.Next() { 112 | _, _ = prepared.Assignables(assignables) 113 | if err = rows.Scan(assignables...); err != nil { 114 | return errors.Go(err) 115 | } 116 | } else { 117 | // When no rows are returned set dest to the zero value of its type. Since dest should be a pointer 118 | // we need to Indirect(ValueOf(dest)) and set TypeOf(dest).Elem(). 119 | reflect.Indirect(reflect.ValueOf(dest)).Set(reflect.Zero(reflect.TypeOf(dest).Elem())) 120 | } 121 | if err = rows.Err(); err != nil { 122 | return errors.Go(err) 123 | } 124 | 125 | case destScalarSlice: 126 | fallthrough 127 | case destStructSlice: 128 | rows, err := Q.Query(query, args...) 129 | if err != nil { 130 | return errors.Go(err) 131 | } 132 | defer rows.Close() 133 | if err = me.scanRows(rows, dest, V, T); err != nil { 134 | return errors.Go(err) 135 | } 136 | 137 | } 138 | 139 | return nil 140 | } 141 | 142 | // scanRows scans rows is the internal scanRows that assumes dest is safe. 
143 | func (me *Scanner) scanRows(R IIterates, dest interface{}, V set.Value, T scannerDestType) error { 144 | if R != nil { 145 | defer R.Close() 146 | } 147 | var prepared set.PreparedMapping 148 | var columns []string 149 | var err error 150 | // 151 | switch T { 152 | case destScalarSlice: 153 | e := reflect.New(V.ElemType).Interface() 154 | E := set.V(e) 155 | if R.Next() { 156 | if err = R.Scan(e); err != nil { 157 | return errors.Go(err) 158 | } 159 | // While this *can* panic it *should never* panic. Second famous last words. 160 | set.Panics.Append(V, E) 161 | } 162 | for R.Next() { 163 | // Create new element E; ignore error because we already know the call succeeds. 164 | e = reflect.New(V.ElemType).Interface() 165 | E.Rebind(e) 166 | if err = R.Scan(e); err != nil { 167 | return errors.Go(err) 168 | } 169 | // While this *can* panic it *should never* panic. Second famous last words. 170 | set.Panics.Append(V, E) 171 | } 172 | if err = R.Err(); err != nil { 173 | return errors.Go(err) 174 | } 175 | 176 | case destStructSlice: 177 | if columns, err = R.Columns(); err != nil { 178 | return errors.Go(err) 179 | } 180 | // 181 | assignables := make([]interface{}, len(columns)) 182 | // V is a slice; E is then an element instance that can be appended to V. 183 | e := reflect.New(V.ElemType) 184 | // Create new empty slice to reflect.Append(slice, elem) to. Note on the way 185 | // out of the function we must assign this slice to V.WriteValue. 186 | slice := reflect.New(V.Type).Elem() 187 | // 188 | // Get a prepared mapping and prepare the access plan. Note that if prepared.Plan() 189 | // succeeds we know future calls to prepared.Assignables() will succeed and do not 190 | // need to check those errors. 
191 | if prepared, err = me.Mapper.Prepare(e); err != nil { 192 | return errors.Go(err) 193 | } else if err = prepared.Plan(columns...); err != nil { 194 | return errors.Go(err) 195 | } 196 | // 197 | // Want to use our existing bound element; otherwise we're creating and discarding one. 198 | if R.Next() { 199 | _, _ = prepared.Assignables(assignables) 200 | if err = R.Scan(assignables...); err != nil { 201 | return errors.Go(err) 202 | } 203 | slice = reflect.Append(slice, e.Elem()) 204 | } 205 | for R.Next() { 206 | // Create new element E; ignore error because we already know the call succeeds. 207 | e = reflect.New(V.ElemType) 208 | prepared.Rebind(e) 209 | // Get the assignable values; again we already know the call succeeds. 210 | _, _ = prepared.Assignables(assignables) 211 | if err = R.Scan(assignables...); err != nil { 212 | return errors.Go(err) 213 | } 214 | slice = reflect.Append(slice, e.Elem()) 215 | } 216 | if err = R.Err(); err != nil { 217 | return errors.Go(err) 218 | } 219 | V.WriteValue.Set(slice) // Don't forget to assign the slice we made to caller's dest. 220 | } 221 | 222 | return nil 223 | } 224 | 225 | // ScanRows scans rows from R into dest. 
226 | func (me *Scanner) ScanRows(R IIterates, dest interface{}) error { 227 | if R != nil { 228 | defer R.Close() 229 | } 230 | V, T, err := me.inspectValue(dest) 231 | if err != nil { 232 | return errors.Go(err) 233 | } else if T != destScalarSlice && T != destStructSlice { 234 | return errors.Errorf("%T.ScanRows expects dest to be address of slice; got %T", me, dest) 235 | } 236 | return me.scanRows(R, dest, V, T) 237 | } 238 | -------------------------------------------------------------------------------- /scanner_examples_test.go: -------------------------------------------------------------------------------- 1 | package sqlh_test 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/nofeaturesonlybugs/set" 8 | "github.com/nofeaturesonlybugs/sqlh" 9 | "github.com/nofeaturesonlybugs/sqlh/examples" 10 | ) 11 | 12 | func ExampleScanner_Select_structSlice() { 13 | type MyStruct struct { 14 | Message string 15 | Number int 16 | } 17 | db, err := examples.Connect(examples.ExSimpleMapper) 18 | if err != nil { 19 | fmt.Println(err.Error()) 20 | } 21 | // 22 | scanner := &sqlh.Scanner{ 23 | // Mapper is pure defaults. Uses exported struct names as column names. 24 | Mapper: &set.Mapper{}, 25 | } 26 | 27 | fmt.Println("Dest is slice struct:") 28 | var rv []MyStruct 29 | err = scanner.Select(db, &rv, "select * from mytable") 30 | if err != nil { 31 | fmt.Println(err.Error()) 32 | } 33 | for _, row := range rv { 34 | fmt.Printf("%v %v\n", row.Message, row.Number) 35 | } 36 | 37 | // Output: Dest is slice struct: 38 | // Hello, World! 42 39 | // So long! 100 40 | } 41 | 42 | func ExampleScanner_Select_structSliceOfPointers() { 43 | type MyStruct struct { 44 | Message string 45 | Number int 46 | } 47 | db, err := examples.Connect(examples.ExSimpleMapper) 48 | if err != nil { 49 | fmt.Println(err.Error()) 50 | } 51 | // 52 | scanner := &sqlh.Scanner{ 53 | // Mapper is pure defaults. Uses exported struct names as column names. 
54 | Mapper: &set.Mapper{}, 55 | } 56 | 57 | fmt.Println("Dest is slice of pointers:") 58 | var rv []*MyStruct 59 | err = scanner.Select(db, &rv, "select * from mytable") 60 | if err != nil { 61 | fmt.Println(err.Error()) 62 | } 63 | for _, row := range rv { 64 | fmt.Printf("%v %v\n", row.Message, row.Number) 65 | } 66 | 67 | // Output: Dest is slice of pointers: 68 | // Hello, World! 42 69 | // So long! 100 70 | } 71 | 72 | func ExampleScanner_Select_tags() { 73 | type MyStruct struct { 74 | Message string `json:"message"` 75 | Number int `json:"value" db:"num"` 76 | } 77 | db, err := examples.Connect(examples.ExTags) 78 | if err != nil { 79 | fmt.Println(err.Error()) 80 | } 81 | // 82 | scanner := &sqlh.Scanner{ 83 | // Mapper uses struct tag db or json, db higher priority 84 | Mapper: &set.Mapper{ 85 | Tags: []string{"db", "json"}, 86 | }, 87 | } 88 | var rv []*MyStruct 89 | err = scanner.Select(db, &rv, "select message, num from mytable") 90 | if err != nil { 91 | fmt.Println(err.Error()) 92 | } 93 | for _, row := range rv { 94 | fmt.Printf("%v %v\n", row.Message, row.Number) 95 | } 96 | 97 | // Output: Hello, World! 42 98 | // So long! 100 99 | } 100 | 101 | func ExampleScanner_Select_nestedStruct() { 102 | type Common struct { 103 | Id int `json:"id"` 104 | Created time.Time `json:"created"` 105 | Modified time.Time `json:"modified"` 106 | } 107 | type MyStruct struct { 108 | // Structs can share the structure in Common. 109 | // This would work just as well if the embed was not a pointer. 110 | // Note how set.Mapper.Elevated is set! 
111 | *Common 112 | Message string `json:"message"` 113 | Number int `json:"value" db:"num"` 114 | } 115 | db, err := examples.Connect(examples.ExNestedStruct) 116 | if err != nil { 117 | fmt.Println(err.Error()) 118 | } 119 | // 120 | scanner := &sqlh.Scanner{ 121 | // 122 | Mapper: &set.Mapper{ 123 | Elevated: set.NewTypeList(Common{}), 124 | Tags: []string{"db", "json"}, 125 | }, 126 | } 127 | var rv []*MyStruct 128 | err = scanner.Select(db, &rv, "select id, created, modified, message, num from mytable") 129 | if err != nil { 130 | fmt.Println(err.Error()) 131 | } 132 | for _, row := range rv { 133 | fmt.Printf("%v %v %v\n", row.Id, row.Message, row.Number) 134 | } 135 | 136 | // Output: 1 Hello, World! 42 137 | // 2 So long! 100 138 | } 139 | 140 | func ExampleScanner_Select_nestedTwice() { 141 | type Common struct { 142 | Id int `json:"id"` 143 | Created time.Time `json:"created"` 144 | Modified time.Time `json:"modified"` 145 | } 146 | type Person struct { 147 | Common 148 | First string `json:"first"` 149 | Last string `json:"last"` 150 | } 151 | // Note here the natural mapping of SQL columns to nested structs. 152 | type Sale struct { 153 | Common 154 | // customer_first and customer_last map to Customer. 155 | Customer Person `json:"customer"` 156 | // contact_first and contact_last map to Contact. 157 | Contact Person `json:"contact"` 158 | } 159 | db, err := examples.Connect(examples.ExNestedTwice) 160 | if err != nil { 161 | fmt.Println(err.Error()) 162 | } 163 | // 164 | scanner := &sqlh.Scanner{ 165 | // Mapper uses struct tag db or json, db higher priority. 166 | // Mapper elevates Common to same level as other fields. 
167 | Mapper: &set.Mapper{ 168 | Elevated: set.NewTypeList(Common{}), 169 | Join: "_", 170 | Tags: []string{"db", "json"}, 171 | }, 172 | } 173 | var rv []*Sale 174 | query := ` 175 | select 176 | s.id, s.created, s.modified, 177 | s.customer_id, c.first as customer_first, c.last as customer_last, 178 | s.vendor_id as contact_id, v.first as contact_first, v.last as contact_last 179 | from sales s 180 | inner join customers c on s.customer_id = c.id 181 | inner join vendors v on s.vendor_id = v.id 182 | ` 183 | err = scanner.Select(db, &rv, query) 184 | if err != nil { 185 | fmt.Println(err.Error()) 186 | } 187 | for _, row := range rv { 188 | fmt.Printf("%v %v.%v %v %v.%v %v\n", row.Id, row.Customer.Id, row.Customer.First, row.Customer.Last, row.Contact.Id, row.Contact.First, row.Contact.Last) 189 | } 190 | 191 | // Output: 1 10.Bob Smith 100.Sally Johnson 192 | // 2 20.Fred Jones 200.Betty Walker 193 | } 194 | 195 | func ExampleScanner_Select_scalar() { 196 | db, err := examples.Connect(examples.ExScalar) 197 | if err != nil { 198 | fmt.Println(err.Error()) 199 | } 200 | scanner := sqlh.Scanner{ 201 | Mapper: &set.Mapper{}, 202 | } 203 | var n int 204 | err = scanner.Select(db, &n, "select count(*) as n from thetable") 205 | if err != nil { 206 | fmt.Println(err.Error()) 207 | } 208 | fmt.Println(n) 209 | // Output: 64 210 | } 211 | 212 | func ExampleScanner_Select_scalarSlice() { 213 | db, err := examples.Connect(examples.ExScalarSlice) 214 | if err != nil { 215 | fmt.Println(err.Error()) 216 | } 217 | scanner := sqlh.Scanner{ 218 | Mapper: &set.Mapper{}, 219 | } 220 | fmt.Println("Dest is slice of scalar:") 221 | var ids []int 222 | err = scanner.Select(db, &ids, "select id from thetable where col = ?", "some value") 223 | if err != nil { 224 | fmt.Println(err.Error()) 225 | } 226 | fmt.Println(ids) 227 | // Output: Dest is slice of scalar: 228 | // [1 2 3] 229 | } 230 | 231 | func ExampleScanner_Select_scalarSliceOfPointers() { 232 | db, err := 
examples.Connect(examples.ExScalarSlice) 233 | if err != nil { 234 | fmt.Println(err.Error()) 235 | } 236 | scanner := sqlh.Scanner{ 237 | Mapper: &set.Mapper{}, 238 | } 239 | fmt.Println("Dest is slice of pointer-to-scalar:") 240 | var ptrs []*int 241 | err = scanner.Select(db, &ptrs, "select id from thetable where col = ?", "some value") 242 | if err != nil { 243 | fmt.Println(err.Error()) 244 | } 245 | var ids []int 246 | for _, ptr := range ptrs { 247 | ids = append(ids, *ptr) 248 | } 249 | fmt.Println(ids) 250 | // Output: Dest is slice of pointer-to-scalar: 251 | // [1 2 3] 252 | } 253 | 254 | func ExampleScanner_Select_struct() { 255 | db, err := examples.Connect(examples.ExStruct) 256 | if err != nil { 257 | fmt.Println(err.Error()) 258 | } 259 | scanner := sqlh.Scanner{ 260 | Mapper: &set.Mapper{ 261 | Tags: []string{"db", "json"}, 262 | }, 263 | } 264 | type Temp struct { 265 | Min time.Time `json:"min"` 266 | Max time.Time `json:"max"` 267 | } 268 | var dest *Temp 269 | err = scanner.Select(db, &dest, "select min(col) as min, max(col) as max from thetable") 270 | if err != nil { 271 | fmt.Println(err.Error()) 272 | } 273 | fmt.Println(dest.Min.Format(time.RFC3339), dest.Max.Format(time.RFC3339)) 274 | // Output: 1970-01-01T00:00:00Z 2012-01-01T00:00:00Z 275 | } 276 | 277 | func ExampleScanner_Select_structNotFound() { 278 | type Common struct { 279 | Id int `json:"id"` 280 | Created time.Time `json:"created"` 281 | Modified time.Time `json:"modified"` 282 | } 283 | type Person struct { 284 | Common 285 | First string `json:"first"` 286 | Last string `json:"last"` 287 | } 288 | // Note here the natural mapping of SQL columns to nested structs. 289 | type Sale struct { 290 | Common 291 | // customer_first and customer_last map to Customer. 292 | Customer Person `json:"customer"` 293 | // contact_first and contact_last map to Contact. 
294 | Contact Person `json:"contact"` 295 | } 296 | // 297 | db, err := examples.Connect(examples.ExStructNotFound) 298 | if err != nil { 299 | fmt.Println(err.Error()) 300 | } 301 | scanner := sqlh.Scanner{ 302 | Mapper: &set.Mapper{ 303 | Elevated: set.NewTypeList(Common{}), 304 | Tags: []string{"db", "json"}, 305 | Join: "_", 306 | }, 307 | } 308 | query := ` 309 | select 310 | s.id, s.created, s.modified, 311 | s.customer_id, c.first as customer_first, c.last as customer_last, 312 | s.vendor_id as contact_id, v.first as contact_first, v.last as contact_last 313 | from sales s 314 | inner join customers c on s.customer_id = c.id 315 | inner join vendors v on s.vendor_id = v.id 316 | ` 317 | // When destination is a pointer to struct and no rows are found then the dest pointer 318 | // remains nil and no error is returned. 319 | var dest *Sale 320 | err = scanner.Select(db, &dest, query) 321 | if err != nil { 322 | fmt.Println(err.Error()) 323 | } 324 | fmt.Printf("Is nil: %v\n", dest == nil) 325 | 326 | // Output: Is nil: true 327 | } 328 | -------------------------------------------------------------------------------- /scanner_test.go: -------------------------------------------------------------------------------- 1 | package sqlh_test 2 | 3 | import ( 4 | "encoding/json" 5 | "reflect" 6 | "testing" 7 | "time" 8 | 9 | "github.com/DATA-DOG/go-sqlmock" 10 | "github.com/stretchr/testify/assert" 11 | 12 | "github.com/nofeaturesonlybugs/errors" 13 | "github.com/nofeaturesonlybugs/set" 14 | "github.com/nofeaturesonlybugs/sqlh" 15 | "github.com/nofeaturesonlybugs/sqlh/examples" 16 | ) 17 | 18 | func TestScanner_StructSliceQueryError(t *testing.T) { 19 | // Tests Query(...) returns non-nil error when dest is []struct. 
20 | // 21 | chk := assert.New(t) 22 | // 23 | db, mock, err := sqlmock.New() 24 | chk.NoError(err) 25 | chk.NotNil(mock) 26 | chk.NotNil(db) 27 | { // When dest is []struct and query returns error 28 | mock.ExpectQuery("select +").WillReturnError(errors.Errorf("[]struct query error")) 29 | type Dest struct { 30 | A string 31 | } 32 | scanner := &sqlh.Scanner{ 33 | Mapper: &set.Mapper{}, 34 | } 35 | var d []Dest 36 | err = scanner.Select(db, &d, "select * from table") 37 | chk.Error(err) 38 | } 39 | } 40 | 41 | func TestScanner_StructSliceRowsErrNonNil(t *testing.T) { 42 | // Tests that *sql.Rows.Err() != nil after the for...*sql.Rows.Next() {} loop when dest is a []struct. 43 | // 44 | chk := assert.New(t) 45 | // 46 | db, mock, err := sqlmock.New() 47 | chk.NoError(err) 48 | chk.NotNil(mock) 49 | chk.NotNil(db) 50 | { 51 | // When dest is []struct and *sql.Rows.Err() is non-nil 52 | rows := sqlmock.NewRows([]string{"A"}). 53 | AddRow("a").AddRow("b").AddRow("c"). 54 | RowError(2, errors.Errorf("[]struct *sql.Rows.Err() is non-nil")) 55 | mock.ExpectQuery("select +").WillReturnRows(rows) 56 | type Dest struct { 57 | A string 58 | } 59 | scanner := &sqlh.Scanner{ 60 | Mapper: &set.Mapper{}, 61 | } 62 | var d []Dest 63 | err = scanner.Select(db, &d, "select * from table") 64 | chk.Error(err) 65 | } 66 | } 67 | 68 | func TestScanner_SelectScalarTime(t *testing.T) { 69 | // Tests selecting into a dest of time.Time. 
70 | // 71 | chk := assert.New(t) 72 | // 73 | scanner := &sqlh.Scanner{ 74 | Mapper: &set.Mapper{ 75 | Tags: []string{"json"}, 76 | }, 77 | } 78 | // 79 | db, mock, err := sqlmock.New() 80 | chk.NoError(err) 81 | chk.NotNil(db) 82 | chk.NotNil(mock) 83 | { 84 | // scalar time 85 | rv := time.Now() 86 | dataRows := sqlmock.NewRows([]string{"tm"}) 87 | dataRows.AddRow(rv) 88 | mock.ExpectQuery("select (.+)").WillReturnRows(dataRows) 89 | // 90 | var n time.Time 91 | err = scanner.Select(db, &n, "select max(*) as tm from foo") 92 | chk.NoError(err) 93 | chk.True(rv.Equal(n)) 94 | } 95 | } 96 | 97 | func TestScanner_SelectScalarSlice(t *testing.T) { 98 | // Tests selecting into a dest of []time.Time. 99 | chk := assert.New(t) 100 | // 101 | scanner := &sqlh.Scanner{ 102 | Mapper: &set.Mapper{ 103 | Tags: []string{"json"}, 104 | }, 105 | } 106 | // 107 | db, mock, err := sqlmock.New() 108 | chk.NoError(err) 109 | chk.NotNil(db) 110 | chk.NotNil(mock) 111 | { 112 | // scalar time 113 | row1 := time.Now() 114 | row2 := row1.Add(-1 * time.Hour) 115 | dataRows := sqlmock.NewRows([]string{"tm"}) 116 | dataRows.AddRow(row1) 117 | dataRows.AddRow(row2) 118 | mock.ExpectQuery("select (.+)").WillReturnRows(dataRows) 119 | // 120 | var n []time.Time 121 | err = scanner.Select(db, &n, "select max(*) as tm from foo") 122 | chk.NoError(err) 123 | chk.True(row1.Equal(n[0])) 124 | chk.True(row2.Equal(n[1])) 125 | } 126 | } 127 | 128 | func TestScanner_SingleStruct(t *testing.T) { 129 | // Tests various errors when scanning into a single struct. 
130 | chk := assert.New(t) 131 | // 132 | type Dest struct { 133 | A string `json:"a"` 134 | B int `json:"b"` 135 | } 136 | // 137 | scanner := &sqlh.Scanner{ 138 | Mapper: &set.Mapper{ 139 | Tags: []string{"json"}, 140 | }, 141 | } 142 | // 143 | db, mock, err := sqlmock.New() 144 | chk.NoError(err) 145 | chk.NotNil(db) 146 | chk.NotNil(mock) 147 | { 148 | // query returns error 149 | mock.ExpectQuery("select (.+)").WillReturnError(errors.Errorf("oops")) 150 | // 151 | var d Dest 152 | err = scanner.Select(db, &d, "select a, b from foo") 153 | chk.Error(err) 154 | } 155 | { // columns error 156 | dataRows := sqlmock.NewRows([]string{"a", "b", "c"}) 157 | dataRows.AddRow("Hello", 42, "not found") 158 | mock.ExpectQuery("select (.+)").WillReturnRows(dataRows) 159 | // 160 | var d Dest 161 | err = scanner.Select(db, &d, "select a, b, c from foo") 162 | chk.Error(err) 163 | } 164 | { // scan error 165 | dataRows := sqlmock.NewRows([]string{"a", "b"}) 166 | dataRows.AddRow("Hello", "asdf") 167 | mock.ExpectQuery("select (.+)").WillReturnRows(dataRows) 168 | // 169 | var d Dest 170 | err = scanner.Select(db, &d, "select a, b, c from foo") 171 | chk.Error(err) 172 | } 173 | } 174 | 175 | func TestScanner_Select_Errors(t *testing.T) { 176 | db, mock, _ := sqlmock.New() 177 | type Dest struct { 178 | A int 179 | B int 180 | } 181 | scanner := &sqlh.Scanner{ 182 | Mapper: &set.Mapper{}, 183 | } 184 | 185 | t.Run("nil dest", func(t *testing.T) { 186 | chk := assert.New(t) 187 | err := scanner.Select(db, nil, "select * from test") 188 | chk.Error(err) 189 | }) 190 | t.Run("invalid dest", func(t *testing.T) { 191 | chk := assert.New(t) 192 | var d map[string]interface{} 193 | err := scanner.Select(db, &d, "select * from test") 194 | chk.Error(err) 195 | }) 196 | t.Run("readonly dest", func(t *testing.T) { 197 | chk := assert.New(t) 198 | var d Dest 199 | err := scanner.Select(db, d, "select * from test") 200 | chk.Error(err) 201 | }) 202 | t.Run("struct column mismatch", 
func(t *testing.T) { 203 | chk := assert.New(t) 204 | 205 | dataRows := sqlmock.NewRows([]string{"X", "Y"}) 206 | dataRows.AddRow(1, 2) 207 | mock.ExpectQuery("select (.+)").WillReturnRows(dataRows).RowsWillBeClosed() 208 | 209 | var d Dest 210 | err := scanner.Select(db, &d, "select * from test") 211 | chk.Error(err) 212 | chk.NoError(mock.ExpectationsWereMet()) 213 | }) 214 | t.Run("struct column mismatch", func(t *testing.T) { 215 | chk := assert.New(t) 216 | 217 | dataRows := sqlmock.NewRows([]string{"X", "Y"}) 218 | dataRows.AddRow(1, 2) 219 | dataRows.AddRow(3, 4) 220 | mock.ExpectQuery("select (.+)").WillReturnRows(dataRows).RowsWillBeClosed() 221 | 222 | var d []Dest 223 | err := scanner.Select(db, &d, "select * from test") 224 | chk.Error(err) 225 | chk.NoError(mock.ExpectationsWereMet()) 226 | }) 227 | t.Run("struct rows first scan fails", func(t *testing.T) { 228 | chk := assert.New(t) 229 | 230 | dataRows := sqlmock.NewRows([]string{"A", "B"}) 231 | dataRows.AddRow("a", "b") 232 | dataRows.AddRow(3, 4) 233 | mock.ExpectQuery("select (.+)").WillReturnRows(dataRows).RowsWillBeClosed() 234 | 235 | var d []Dest 236 | err := scanner.Select(db, &d, "select * from test") 237 | chk.Error(err) 238 | chk.NoError(mock.ExpectationsWereMet()) 239 | }) 240 | t.Run("struct rows second scan fails", func(t *testing.T) { 241 | chk := assert.New(t) 242 | 243 | dataRows := sqlmock.NewRows([]string{"A", "B"}) 244 | dataRows.AddRow(3, 4) 245 | dataRows.AddRow("a", "b") 246 | mock.ExpectQuery("select (.+)").WillReturnRows(dataRows).RowsWillBeClosed() 247 | 248 | var d []Dest 249 | err := scanner.Select(db, &d, "select * from test") 250 | chk.Error(err) 251 | chk.NoError(mock.ExpectationsWereMet()) 252 | }) 253 | t.Run("scalar rows first scan fails", func(t *testing.T) { 254 | chk := assert.New(t) 255 | 256 | dataRows := sqlmock.NewRows([]string{"n"}) 257 | dataRows.AddRow("abc") 258 | dataRows.AddRow(4) 259 | mock.ExpectQuery("select 
(.+)").WillReturnRows(dataRows).RowsWillBeClosed() 260 | 261 | var d []int 262 | err := scanner.Select(db, &d, "select * from test") 263 | chk.Error(err) 264 | chk.NoError(mock.ExpectationsWereMet()) 265 | }) 266 | t.Run("scalar rows second scan fails", func(t *testing.T) { 267 | chk := assert.New(t) 268 | 269 | dataRows := sqlmock.NewRows([]string{"n"}) 270 | dataRows.AddRow(3) 271 | dataRows.AddRow("abc") 272 | mock.ExpectQuery("select (.+)").WillReturnRows(dataRows).RowsWillBeClosed() 273 | 274 | var d []int 275 | err := scanner.Select(db, &d, "select * from test") 276 | chk.Error(err) 277 | chk.NoError(mock.ExpectationsWereMet()) 278 | }) 279 | t.Run("scalar rows error", func(t *testing.T) { 280 | chk := assert.New(t) 281 | 282 | dataRows := sqlmock.NewRows([]string{"n"}) 283 | dataRows.AddRow(3) 284 | dataRows.AddRow(4) 285 | dataRows.RowError(0, errors.Errorf("oops")) 286 | 287 | var d []int 288 | err := scanner.Select(db, &d, "select * from test") 289 | chk.Error(err) 290 | chk.NoError(mock.ExpectationsWereMet()) 291 | }) 292 | } 293 | 294 | func TestScanner_ScanRows_Errors(t *testing.T) { 295 | type Dest struct { 296 | A int 297 | B int 298 | } 299 | // 300 | db, mock, _ := sqlmock.New() 301 | // 302 | scanner := &sqlh.Scanner{ 303 | Mapper: &set.Mapper{}, 304 | } 305 | 306 | t.Run("nil dest", func(t *testing.T) { 307 | chk := assert.New(t) 308 | 309 | mock.ExpectQuery("select +"). 310 | WillReturnRows(sqlmock.NewRows([]string{"a"}).AddRow(10)). 311 | RowsWillBeClosed() 312 | 313 | rows, err := db.Query("select * from foo") 314 | chk.NoError(err) 315 | err = scanner.ScanRows(rows, nil) 316 | chk.Error(err) 317 | chk.NoError(mock.ExpectationsWereMet()) 318 | }) 319 | t.Run("invalid dest", func(t *testing.T) { 320 | chk := assert.New(t) 321 | 322 | var n int 323 | mock.ExpectQuery("select +"). 324 | WillReturnRows(sqlmock.NewRows([]string{"a"}).AddRow(10)). 
325 | RowsWillBeClosed() 326 | 327 | rows, err := db.Query("select * from foo") 328 | chk.NoError(err) 329 | err = scanner.ScanRows(rows, &n) 330 | chk.Error(err) 331 | chk.NoError(mock.ExpectationsWereMet()) 332 | }) 333 | t.Run("readonly dest", func(t *testing.T) { 334 | chk := assert.New(t) 335 | var d Dest 336 | err := scanner.Select(db, d, "select * from test") 337 | chk.Error(err) 338 | }) 339 | } 340 | 341 | func TestScanner_DestWrongType(t *testing.T) { 342 | chk := assert.New(t) 343 | // 344 | db, mock, err := sqlmock.New() 345 | chk.NoError(err) 346 | chk.NotNil(db) 347 | chk.NotNil(mock) 348 | // 349 | dataRows := sqlmock.NewRows([]string{"A", "B"}) 350 | dataRows.AddRow(1, 2) 351 | dataRows.AddRow(3, 4) 352 | // 353 | scanner := &sqlh.Scanner{ 354 | Mapper: &set.Mapper{ 355 | Tags: []string{"json"}, 356 | }, 357 | } 358 | // dest is not slice 359 | var di int 360 | mock.ExpectQuery("select (.+)").WillReturnRows(dataRows) 361 | err = scanner.Select(db, &di, "select * from test") 362 | chk.Error(err) 363 | err = mock.ExpectationsWereMet() 364 | chk.NoError(err) 365 | // dest is not slice of struct 366 | var dslice []int 367 | mock.ExpectQuery("select (.+)").WillReturnRows(dataRows) 368 | err = scanner.Select(db, &dslice, "select * from test") 369 | chk.Error(err) 370 | err = mock.ExpectationsWereMet() 371 | chk.NoError(err) 372 | } 373 | 374 | func TestScanner_Select(t *testing.T) { 375 | type SimpleStruct struct { 376 | Message string 377 | Number int 378 | } 379 | // 380 | type NestedInnerStruct struct { 381 | Id int `json:"id"` 382 | Created time.Time `json:"created"` 383 | Modified time.Time `json:"modified"` 384 | } 385 | type NestedOuterStruct struct { 386 | NestedInnerStruct 387 | Message string `json:"message"` 388 | Number int `json:"value" db:"num"` 389 | } 390 | type PointerOuterStruct struct { 391 | *NestedInnerStruct 392 | Message string `json:"message"` 393 | Number int `json:"value" db:"num"` 394 | } 395 | 396 | // 397 | // Some examples 
return times from time generator. 398 | tg := examples.TimeGenerator{} 399 | times := []time.Time{tg.Next(), tg.Next(), tg.Next(), tg.Next()} 400 | // 401 | type SelectTest struct { 402 | Name string 403 | Example examples.Example 404 | Dest reflect.Type 405 | DestPointers reflect.Type 406 | Expect interface{} 407 | Scanner *sqlh.Scanner 408 | } 409 | tests := []SelectTest{ 410 | { 411 | Name: "slice-struct", 412 | Example: examples.ExSimpleMapper, 413 | Dest: reflect.TypeOf([]SimpleStruct(nil)), 414 | DestPointers: reflect.TypeOf([]*SimpleStruct(nil)), 415 | Expect: []SimpleStruct{ 416 | {Message: "Hello, World!", Number: 42}, 417 | {Message: "So long!", Number: 100}, 418 | }, 419 | Scanner: &sqlh.Scanner{ 420 | Mapper: &set.Mapper{}, 421 | }, 422 | }, 423 | { 424 | Name: "slice-struct-with-nesting", 425 | Example: examples.ExNestedStruct, 426 | Dest: reflect.TypeOf([]NestedOuterStruct(nil)), 427 | DestPointers: reflect.TypeOf([]*NestedOuterStruct(nil)), 428 | Expect: []NestedOuterStruct{ 429 | { 430 | NestedInnerStruct: NestedInnerStruct{Id: 1, Created: times[0], Modified: times[1]}, 431 | Message: "Hello, World!", 432 | Number: 42, 433 | }, 434 | { 435 | NestedInnerStruct: NestedInnerStruct{Id: 2, Created: times[2], Modified: times[3]}, 436 | Message: "So long!", 437 | Number: 100, 438 | }, 439 | }, 440 | Scanner: &sqlh.Scanner{ 441 | Mapper: &set.Mapper{ 442 | Elevated: set.NewTypeList(NestedInnerStruct{}), 443 | Tags: []string{"db", "json"}, 444 | }, 445 | }, 446 | }, 447 | { 448 | Name: "slice-struct-with-pointer-nesting", 449 | Example: examples.ExNestedStruct, 450 | Dest: reflect.TypeOf([]PointerOuterStruct(nil)), 451 | DestPointers: reflect.TypeOf([]*PointerOuterStruct(nil)), 452 | Expect: []NestedOuterStruct{ 453 | { 454 | NestedInnerStruct: NestedInnerStruct{Id: 1, Created: times[0], Modified: times[1]}, 455 | Message: "Hello, World!", 456 | Number: 42, 457 | }, 458 | { 459 | NestedInnerStruct: NestedInnerStruct{Id: 2, Created: times[2], Modified: 
times[3]}, 460 | Message: "So long!", 461 | Number: 100, 462 | }, 463 | }, 464 | Scanner: &sqlh.Scanner{ 465 | Mapper: &set.Mapper{ 466 | Elevated: set.NewTypeList(NestedInnerStruct{}), 467 | Tags: []string{"db", "json"}, 468 | }, 469 | }, 470 | }, 471 | } 472 | for _, test := range tests { 473 | test := test 474 | t.Run(test.Name, func(t *testing.T) { 475 | t.Parallel() 476 | 477 | chk := assert.New(t) 478 | db, err := examples.Connect(test.Example) 479 | chk.NoError(err) 480 | 481 | dest := reflect.New(test.Dest).Interface() 482 | 483 | err = test.Scanner.Select(db, dest, "select * from mytable") 484 | chk.NoError(err) 485 | 486 | // There's different ways we can check for equality here but we'll just see if 487 | // what we have encodes the same as what we expect. 488 | expect, err := json.Marshal(test.Expect) 489 | chk.NoError(err) 490 | actual, err := json.Marshal(dest) 491 | chk.NoError(err) 492 | chk.Equal(expect, actual) 493 | }) 494 | t.Run(test.Name+"-pointers", func(t *testing.T) { 495 | t.Parallel() 496 | 497 | chk := assert.New(t) 498 | db, err := examples.Connect(test.Example) 499 | chk.NoError(err) 500 | 501 | dest := reflect.New(test.DestPointers).Interface() 502 | 503 | err = test.Scanner.Select(db, dest, "select * from mytable") 504 | chk.NoError(err) 505 | 506 | // There's different ways we can check for equality here but we'll just see if 507 | // what we have encodes the same as what we expect. 
508 | expect, err := json.Marshal(test.Expect) 509 | chk.NoError(err) 510 | actual, err := json.Marshal(dest) 511 | chk.NoError(err) 512 | chk.Equal(expect, actual) 513 | }) 514 | t.Run(test.Name+"-scan-rows", func(t *testing.T) { 515 | t.Parallel() 516 | 517 | chk := assert.New(t) 518 | db, err := examples.Connect(test.Example) 519 | chk.NoError(err) 520 | 521 | dest := reflect.New(test.Dest).Interface() 522 | 523 | rows, err := db.Query("select * from mytable") 524 | chk.NoError(err) 525 | defer rows.Close() 526 | 527 | err = test.Scanner.ScanRows(rows, dest) 528 | chk.NoError(err) 529 | 530 | // There's different ways we can check for equality here but we'll just see if 531 | // what we have encodes the same as what we expect. 532 | expect, err := json.Marshal(test.Expect) 533 | chk.NoError(err) 534 | actual, err := json.Marshal(dest) 535 | chk.NoError(err) 536 | chk.Equal(expect, actual) 537 | }) 538 | 539 | } 540 | } 541 | -------------------------------------------------------------------------------- /schema/column.go: -------------------------------------------------------------------------------- 1 | package schema 2 | 3 | import "fmt" 4 | 5 | // Column describes a database column. 6 | type Column struct { 7 | // Name specifies the column name. 8 | Name string 9 | // GoType specifies the type in Go and should be set to a specific Go type. 10 | GoType interface{} 11 | // SqlType specifies the type in SQL and can be set to a string describing the SQL type. 12 | SqlType string 13 | } 14 | 15 | // String describes the column as a string. 
16 | func (me Column) String() string { 17 | sqlType := me.SqlType 18 | if sqlType == "" { 19 | sqlType = "-" 20 | } 21 | return fmt.Sprintf("%v go(%T) sql(%v)", me.Name, me.GoType, sqlType) 22 | } 23 | -------------------------------------------------------------------------------- /schema/index.go: -------------------------------------------------------------------------------- 1 | package schema 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | // Index describes a database index. 9 | type Index struct { 10 | // Name specifies the index name. 11 | Name string 12 | // Columns contains the columns in the index. 13 | Columns []Column 14 | // IsPrimary is true if the index represents a primary key; if IsPrimary is true then 15 | // IsUnique is also true. 16 | IsPrimary bool 17 | // IsUnique is true if the index represents a unique key. 18 | IsUnique bool 19 | } 20 | 21 | // String describes the index as a string. 22 | func (me Index) String() string { 23 | if me.Name == "" && len(me.Columns) == 0 && me.IsPrimary == false && me.IsUnique == false { 24 | return "" 25 | } 26 | // 27 | describe := "unique" 28 | if me.IsPrimary { 29 | if len(me.Columns) > 1 { 30 | describe = "primary composite key" 31 | } else { 32 | describe = "primary key" 33 | } 34 | } 35 | name := me.Name 36 | if name == "" { 37 | name = "-" 38 | } 39 | fields, gotypes, sqltypes := []string{}, []string{}, []string{} 40 | for _, column := range me.Columns { 41 | fields = append(fields, column.Name) 42 | gotypes = append(gotypes, fmt.Sprintf("%T", column.GoType)) 43 | sqltypes = append(sqltypes, column.SqlType) 44 | } 45 | return fmt.Sprintf("%v name=%v (%v) go(%v) sql(%v)", describe, name, strings.Join(fields, ","), strings.Join(gotypes, ","), strings.Join(sqltypes, ",")) 46 | } 47 | -------------------------------------------------------------------------------- /schema/pkg_doc.go: -------------------------------------------------------------------------------- 1 | // Package schema contains 
types to describe a database schema. 2 | package schema 3 | -------------------------------------------------------------------------------- /schema/table.go: -------------------------------------------------------------------------------- 1 | package schema 2 | 3 | // Table describes a database table. 4 | type Table struct { 5 | // Name specifies the database table name. 6 | Name string 7 | // Columns represents the table columns. 8 | Columns []Column 9 | // PrimaryKey is the index describing the table's primary key if it has one. 10 | // A primary key - by definition - is a unique index; however it is not also 11 | // stored in the Unique field. 12 | PrimaryKey Index 13 | // Unique is a slice of unique indexes on the table. 14 | Unique []Index 15 | } 16 | 17 | // String describes the table as a string. 18 | func (me Table) String() string { 19 | rv := "" 20 | // 21 | name := me.Name 22 | if name == "" { 23 | name = "- (table, name unknown)" 24 | } 25 | rv = name 26 | // 27 | // primary key 28 | primary := me.PrimaryKey.String() 29 | if primary != "" { 30 | rv = rv + "\n\t" + primary 31 | } 32 | // 33 | // columns 34 | if len(me.Columns) > 0 { 35 | rv = rv + "\n\tcolumns" 36 | for _, column := range me.Columns { 37 | rv = rv + "\n\t\t" + column.String() 38 | } 39 | } 40 | // 41 | // unique indexes 42 | if len(me.Unique) > 0 { 43 | rv = rv + "\n\tunique indexes" 44 | for _, index := range me.Unique { 45 | rv = rv + "\n\t\t" + index.String() 46 | } 47 | } 48 | // 49 | return rv 50 | } 51 | -------------------------------------------------------------------------------- /transact.go: -------------------------------------------------------------------------------- 1 | package sqlh 2 | 3 | import ( 4 | "database/sql" 5 | 6 | "github.com/nofeaturesonlybugs/errors" 7 | ) 8 | 9 | // Transact runs fn inside a transaction if Q supports transactions; otherwise it just calls fn(Q). 
If a transaction 10 | // is started and fn returns a non-nil error then the transaction is rolled back. 11 | func Transact(Q IQueries, fn func(Q IQueries) error) error { 12 | var B IBegins 13 | var T *sql.Tx 14 | var ok bool 15 | var err, txnErr error 16 | if B, ok = Q.(IBegins); !ok { 17 | return fn(Q) 18 | } else if T, err = B.Begin(); err != nil { 19 | return errors.Go(err) 20 | } else if err = fn(T); err != nil { 21 | err = errors.Go(err) 22 | if txnErr = T.Rollback(); txnErr != nil { 23 | err.(errors.Error).Tag("transaction-rollback", txnErr.Error()) 24 | } 25 | return err 26 | } else if err = T.Commit(); err != nil { 27 | return errors.Go(err) 28 | } 29 | return nil 30 | } 31 | 32 | // TransactRollback is similar to Transact except the created transaction will always be rolled back; consider using this 33 | // during tests when you do not want to persist changes to the database. 34 | // 35 | // Unlike Transact the database object passed to this function must be of type IBegins so the caller is guaranteed 36 | // fn occurs under a transaction that will be rolled back. 
37 | func TransactRollback(B IBegins, fn func(Q IQueries) error) error { 38 | var T *sql.Tx 39 | var err error 40 | if T, err = B.Begin(); err != nil { 41 | return errors.Go(err) 42 | } else if err = fn(T); err != nil { 43 | err = errors.Go(err) 44 | } 45 | if e2 := T.Rollback(); e2 != nil { 46 | if err == nil { 47 | err = errors.Go(e2) 48 | } else { 49 | err.(errors.Error).Tag("transaction-rollback", e2.Error()) 50 | } 51 | } 52 | return err 53 | } 54 | -------------------------------------------------------------------------------- /transact_test.go: -------------------------------------------------------------------------------- 1 | package sqlh_test 2 | 3 | import ( 4 | "database/sql/driver" 5 | "testing" 6 | 7 | "github.com/DATA-DOG/go-sqlmock" 8 | "github.com/nofeaturesonlybugs/errors" 9 | "github.com/stretchr/testify/assert" 10 | 11 | "github.com/nofeaturesonlybugs/sqlh" 12 | ) 13 | 14 | func TestTransact(t *testing.T) { 15 | chk := assert.New(t) 16 | // 17 | db, mock, err := sqlmock.New() 18 | chk.NoError(err) 19 | chk.NotNil(db) 20 | chk.NotNil(mock) 21 | // 22 | // Several of the tests use a insert...insert function. 
23 | Insert2xFunc := func(Q sqlh.IQueries) error { 24 | var err error 25 | if _, err = Q.Exec("insert into my table", "a", "b", "c"); err != nil { 26 | return err 27 | } else if _, err = Q.Exec("insert into my table", "1", "2", "3"); err != nil { 28 | return err 29 | } 30 | return nil 31 | } 32 | // 33 | type Test struct { 34 | Name string 35 | MockFn func(sqlmock.Sqlmock) 36 | TransactFn func(sqlh.IQueries) error 37 | ExpectError bool 38 | } 39 | tests := []Test{ 40 | { 41 | Name: "no errors", 42 | MockFn: func(mock sqlmock.Sqlmock) { 43 | mock.ExpectBegin() 44 | mock.ExpectExec("insert+").WithArgs("a", "b", "c").WillReturnResult(driver.ResultNoRows) 45 | mock.ExpectExec("insert+").WithArgs("1", "2", "3").WillReturnResult(driver.ResultNoRows) 46 | mock.ExpectCommit() 47 | }, 48 | TransactFn: Insert2xFunc, 49 | ExpectError: false, 50 | }, 51 | { 52 | Name: "commit error", 53 | MockFn: func(sqlmock.Sqlmock) { 54 | mock.ExpectBegin() 55 | mock.ExpectExec("insert+").WithArgs("a", "b", "c").WillReturnResult(driver.ResultNoRows) 56 | mock.ExpectExec("insert+").WithArgs("1", "2", "3").WillReturnResult(driver.ResultNoRows) 57 | mock.ExpectCommit().WillReturnError(errors.Errorf("commit error")) 58 | }, 59 | TransactFn: Insert2xFunc, 60 | ExpectError: true, 61 | }, 62 | { 63 | // Calling sqlh.Transact() inside sqlh.Transaction should result in no additional calls 64 | // to Begin(). 
65 | Name: "nested", 66 | MockFn: func(sqlmock.Sqlmock) { 67 | mock.ExpectBegin() 68 | mock.ExpectExec("insert+").WithArgs("a", "b", "c").WillReturnResult(driver.ResultNoRows) 69 | mock.ExpectExec("insert+").WithArgs("1", "2", "3").WillReturnResult(driver.ResultNoRows) 70 | mock.ExpectExec("update+").WithArgs("x", "y", "z").WillReturnResult(driver.ResultNoRows) 71 | mock.ExpectCommit() 72 | }, 73 | TransactFn: func(Q sqlh.IQueries) error { 74 | var err error 75 | if _, err = Q.Exec("insert into my table", "a", "b", "c"); err != nil { 76 | return err 77 | } else if _, err = Q.Exec("insert into my table", "1", "2", "3"); err != nil { 78 | return err 79 | } else if err = sqlh.Transact(Q, func(Q sqlh.IQueries) error { 80 | var err error 81 | if _, err = Q.Exec("update my table set", "x", "y", "z"); err != nil { 82 | return err 83 | } 84 | return nil 85 | }); err != nil { 86 | return err 87 | } 88 | return nil 89 | }, 90 | ExpectError: false, 91 | }, 92 | { 93 | // Fn returns error so expect a rollback and error to filter upwards. 94 | Name: "rollback", 95 | MockFn: func(sqlmock.Sqlmock) { 96 | mock.ExpectBegin() 97 | mock.ExpectExec("insert+").WithArgs("a", "b", "c").WillReturnResult(driver.ResultNoRows) 98 | mock.ExpectExec("insert+").WithArgs("1", "2", "3").WillReturnError(errors.Errorf("insert error")) 99 | mock.ExpectRollback() 100 | }, 101 | TransactFn: Insert2xFunc, 102 | ExpectError: true, 103 | }, 104 | { 105 | // Fn returns error so expect rollback but now rollback also errors. 
106 | Name: "rollback error", 107 | MockFn: func(sqlmock.Sqlmock) { 108 | mock.ExpectBegin() 109 | mock.ExpectExec("insert+").WithArgs("a", "b", "c").WillReturnResult(driver.ResultNoRows) 110 | mock.ExpectExec("insert+").WithArgs("1", "2", "3").WillReturnError(errors.Errorf("insert error")) 111 | mock.ExpectRollback().WillReturnError(errors.Errorf("rollback error")) 112 | }, 113 | TransactFn: Insert2xFunc, 114 | ExpectError: true, 115 | }, 116 | { 117 | Name: "begin error", 118 | MockFn: func(sqlmock.Sqlmock) { 119 | mock.ExpectBegin().WillReturnError(errors.Errorf("begin error")) 120 | }, 121 | TransactFn: Insert2xFunc, 122 | ExpectError: true, 123 | }, 124 | } 125 | for _, test := range tests { 126 | t.Run(test.Name, func(t *testing.T) { 127 | chk := assert.New(t) 128 | test.MockFn(mock) 129 | err := sqlh.Transact(db, test.TransactFn) 130 | if test.ExpectError { 131 | chk.Error(err) 132 | } else { 133 | chk.NoError(err) 134 | } 135 | err = mock.ExpectationsWereMet() 136 | chk.NoError(err) 137 | }) 138 | } 139 | } 140 | 141 | func TestTransactWithRollback(t *testing.T) { 142 | chk := assert.New(t) 143 | // 144 | db, mock, err := sqlmock.New() 145 | chk.NoError(err) 146 | chk.NotNil(db) 147 | chk.NotNil(mock) 148 | // 149 | { // begin, insert, insert, rollback 150 | mock.ExpectBegin() 151 | mock.ExpectExec("insert+").WithArgs("a", "b", "c").WillReturnResult(driver.ResultNoRows) 152 | mock.ExpectExec("insert+").WithArgs("1", "2", "3").WillReturnResult(driver.ResultNoRows) 153 | mock.ExpectRollback() 154 | // 155 | err = sqlh.TransactRollback(db, func(Q sqlh.IQueries) error { 156 | var err error 157 | if _, err = Q.Exec("insert into my table", "a", "b", "c"); err != nil { 158 | return err 159 | } else if _, err = Q.Exec("insert into my table", "1", "2", "3"); err != nil { 160 | return err 161 | } 162 | return nil 163 | }) 164 | chk.NoError(err) 165 | err = mock.ExpectationsWereMet() 166 | chk.NoError(err) 167 | } 168 | { // begin (with begin error) 169 | 
mock.ExpectBegin().WillReturnError(errors.Errorf("begin error")) 170 | // 171 | err = sqlh.TransactRollback(db, func(Q sqlh.IQueries) error { 172 | var err error 173 | if _, err = Q.Exec("insert into my table", "a", "b", "c"); err != nil { 174 | return err 175 | } else if _, err = Q.Exec("insert into my table", "1", "2", "3"); err != nil { 176 | return err 177 | } 178 | return nil 179 | }) 180 | chk.Error(err) 181 | err = mock.ExpectationsWereMet() 182 | chk.NoError(err) 183 | } 184 | { // begin, insert, insert, rollback (with rollback error) 185 | mock.ExpectBegin() 186 | mock.ExpectExec("insert+").WithArgs("a", "b", "c").WillReturnResult(driver.ResultNoRows) 187 | mock.ExpectExec("insert+").WithArgs("1", "2", "3").WillReturnResult(driver.ResultNoRows) 188 | mock.ExpectRollback().WillReturnError(errors.Errorf("rollback error")) 189 | // 190 | err = sqlh.TransactRollback(db, func(Q sqlh.IQueries) error { 191 | var err error 192 | if _, err = Q.Exec("insert into my table", "a", "b", "c"); err != nil { 193 | return err 194 | } else if _, err = Q.Exec("insert into my table", "1", "2", "3"); err != nil { 195 | return err 196 | } 197 | return nil 198 | }) 199 | chk.Error(err) 200 | err = mock.ExpectationsWereMet() 201 | chk.NoError(err) 202 | } 203 | { // begin, insert, insert, fn returns error, rollback (with rollback error) 204 | mock.ExpectBegin() 205 | mock.ExpectExec("insert+").WithArgs("a", "b", "c").WillReturnResult(driver.ResultNoRows) 206 | mock.ExpectExec("insert+").WithArgs("1", "2", "3").WillReturnResult(driver.ResultNoRows) 207 | mock.ExpectRollback().WillReturnError(errors.Errorf("rollback error")) 208 | // 209 | err = sqlh.TransactRollback(db, func(Q sqlh.IQueries) error { 210 | var err error 211 | if _, err = Q.Exec("insert into my table", "a", "b", "c"); err != nil { 212 | return err 213 | } else if _, err = Q.Exec("insert into my table", "1", "2", "3"); err != nil { 214 | return err 215 | } 216 | return errors.Errorf("force error") 217 | }) 218 | 
chk.Error(err) 219 | err = mock.ExpectationsWereMet() 220 | chk.NoError(err) 221 | } 222 | } 223 | --------------------------------------------------------------------------------