├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── doc.go ├── go.mod ├── go.sum ├── loadsave.go ├── loadsave_test.go ├── mapper.go ├── mapper_test.go ├── meddler.go ├── meddler_test.go ├── scan.go └── scan_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | # Travis CI (http://travis-ci.org/) is a continuous integration service for 2 | # open source projects. This file configures it to run unit tests for 3 | # meddler. 4 | 5 | language: go 6 | 7 | go: 8 | - "1.9.x" 9 | - "1.10.x" 10 | - "1.11.x" 11 | - "1.12.x" 12 | - "1.13.x" 13 | 14 | install: 15 | - go get -d -t -v ./... 16 | - go build -v ./... 17 | 18 | script: 19 | - go test -v ./... 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2013 Russ Ross 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Meddler [![Build Status](https://travis-ci.org/russross/meddler.svg?branch=master)](https://travis-ci.org/russross/meddler) [![GoDoc](https://godoc.org/github.com/russross/meddler?status.svg)](https://godoc.org/github.com/russross/meddler) [![Go Report Card](https://goreportcard.com/badge/github.com/russross/meddler)](https://goreportcard.com/report/github.com/russross/meddler) 2 | ======= 3 | 4 | Meddler is a small toolkit to take some of the tedium out of moving data 5 | back and forth between SQL queries and structs. 6 | 7 | It is not a complete ORM. Meddler is intended to be a lightweight way to add some 8 | of the convenience of an ORM while leaving more control in the hands of the 9 | programmer. 10 | 11 | Package docs are available at: 12 | 13 | * http://godoc.org/github.com/russross/meddler 14 | 15 | The package is housed on GitHub, and the README there has more info: 16 | 17 | * http://github.com/russross/meddler 18 | 19 | Meddler is currently configured for SQLite, MySQL, and PostgreSQL, but it 20 | can be configured for use with other databases. If you use it 21 | successfully with a different database, please contact me and I will 22 | add it to the list of pre-configured databases. 
23 | 24 | ### DANGER 25 | 26 | Meddler is still a work in progress, and additional 27 | backward-incompatible changes to the API are likely. The most recent 28 | change added support for multiple database types and made it easier 29 | to switch between them. This is most likely to affect the way you 30 | initialize the library to work with your database (see the install 31 | section below). 32 | 33 | Another recent update is the change to int64 for primary keys. This 34 | matches the convention used in database/sql, and is more portable, 35 | but it may require some minor changes to existing code. 36 | 37 | 38 | Install 39 | ------- 40 | 41 | The usual `go get` command will put it in your `$GOPATH`: 42 | 43 | go get github.com/russross/meddler 44 | 45 | If you are only using one type of database, you should set Default 46 | to match your database type, e.g.: 47 | 48 | meddler.Default = meddler.PostgreSQL 49 | 50 | The default database is MySQL, so you should change it for anything 51 | else. To use multiple databases within a single project, or to use a 52 | database other than MySQL, PostgreSQL, or SQLite, see below. 53 | 54 | Note: If you are using MySQL with the `github.com/go-sql-driver/mysql` 55 | driver, you must set "parseTime=true" in the sql.Open call or the 56 | time conversion meddlers will not work. 57 | 58 | 59 | Why? 60 | ---- 61 | 62 | These are the features that set meddler apart from similar 63 | libraries: 64 | 65 | * It uses standard database/sql types, and does not require 66 | special fields in your structs. This lets you use meddler 67 | selectively, without having to alter other database code already 68 | in your project. After creating meddler, I incorporated it into 69 | an existing project, and I was able to convert the code one 70 | struct and one query at a time. 71 | * It leaves query writing to you. 
It has convenience functions for 72 | simple INSERT/UPDATE/SELECT queries by integer primary key, but 73 | beyond that it stays out of query writing. 74 | * It supports on-the-fly data transformations. If you have a map 75 | or a slice in your struct, you can instruct meddler to 76 | encode/decode using JSON or Gob automatically. If you have time 77 | fields, you can have meddler automatically write them into the 78 | database as UTC, and convert them to the local time zone on 79 | reads. These processors are called “meddlers”, because they 80 | meddle with the data instead of passing it through directly. 81 | * NULL fields in the database can be read as zero values in the 82 | struct, and zero values in the struct can be written as NULL 83 | values. This is not always the right thing to do, but it is 84 | often good enough and is much simpler than most alternatives. 85 | * It exposes low-level hooks for more complex situations. If you 86 | are writing a query that does not map well to the main helper 87 | functions, you can still get some help by using the lower-level 88 | functions to build your own helpers. 89 | 90 | 91 | High-level functions 92 | -------------------- 93 | 94 | Meddler does not create or alter tables. It just provides a little 95 | glue to make it easier to read and write structs as SQL rows. Start 96 | by annotating a struct: 97 | 98 | ``` go 99 | type Person struct { 100 | ID int `meddler:"id,pk"` 101 | Name string `meddler:"name"` 102 | Age int 103 | salary int 104 | Created time.Time `meddler:"created,localtime"` 105 | Closed time.Time `meddler:",localtimez"` 106 | } 107 | ``` 108 | 109 | Notes about this example: 110 | 111 | * If the optional tag is provided, the first field is the database 112 | column name. Note that "Closed" does not provide a column name, 113 | so it will default to "Closed". Likewise, if there is no tag, 114 | the field name will be used. 115 | * ID is marked as the primary key. 
Currently only integer primary 116 | keys are supported. This is only relevant to Load, Save, Insert, 117 | and Update, a few of the higher-level functions that need to 118 | understand primary keys. Meddler assumes that pk fields have an 119 | autoincrement mechanism set in the database. 120 | * Age has a column name of "Age". A tag is only necessary when the 121 | column name is not the same as the field name, or when you need 122 | to select other options. 123 | * salary is not an exported field, so meddler does not see it. It 124 | will be ignored. 125 | * Created is marked with "localtime". This means that it will be 126 | converted to UTC when being saved, and back to the local time 127 | zone when being loaded. 128 | * Closed has a column name of "Closed", since the tag did not 129 | specify anything different. Closed is marked as "localtimez". 130 | This has the same properties as "localtime", except that the 131 | zero time will be saved in the database as a null column (and 132 | null values will be loaded as the zero time value). 133 | * You can set a default column name mapping by setting 134 | `meddler.Mapper` to a `func(s string) string` function. For 135 | example, `meddler.Mapper = meddler.SnakeCase` will convert field 136 | names to snake_case unless an explicit column name is specified. 137 | 138 | Meddler provides a few high-level functions (note: DB is an 139 | interface that works with a *sql.DB or a *sql.Tx): 140 | 141 | * Load(db DB, table string, dst interface{}, pk int64) error 142 | 143 | This loads a single record by its primary key. For example: 144 | 145 | ```go 146 | elt := new(Person) 147 | err := meddler.Load(db, "person", elt, 15) 148 | ``` 149 | 150 | db can be a *sql.DB or a *sql.Tx. The table is the name of the 151 | table, pk is the primary key value, and dst is a pointer to the 152 | struct where it should be stored. 153 | 154 | Note: this call requires that the struct have an integer primary 155 | key field marked.
156 | 157 | * Insert(db DB, table string, src interface{}) error 158 | 159 | This inserts a new row into the database. If the struct value 160 | has a primary key field, it must be zero (and will be omitted 161 | from the insert statement, prompting a default autoincrement 162 | value). 163 | 164 | ```go 165 | elt := &Person{ 166 | Name: "Alice", 167 | Age: 22, 168 | // ... 169 | } 170 | err := meddler.Insert(db, "person", elt) 171 | // elt.ID is updated to the value assigned by the database 172 | ``` 173 | 174 | * Update(db DB, table string, src interface{}) error 175 | 176 | This updates an existing row. It must have a primary key, which 177 | must be non-zero. 178 | 179 | Note: this call requires that the struct have an integer primary 180 | key field marked. 181 | 182 | * Save(db DB, table string, src interface{}) error 183 | 184 | Pick Insert or Update automatically. If there is a non-zero 185 | primary key present, it uses Update, otherwise it uses Insert. 186 | 187 | Note: this call requires that the struct have an integer primary 188 | key field marked. 189 | 190 | * QueryRow(db DB, dst interface{}, query string, args ...interface{}) error 191 | 192 | Perform the given query, and scan the single-row result into 193 | dst, which must be a pointer to a struct. 194 | 195 | For example: 196 | 197 | ```go 198 | elt := new(Person) 199 | err := meddler.QueryRow(db, elt, "select * from person where name = ?", "bob") 200 | ``` 201 | 202 | * QueryAll(db DB, dst interface{}, query string, args ...interface{}) error 203 | 204 | Perform the given query, and scan the results into dst, which 205 | must be a pointer to a slice of struct pointers. 206 | 207 | For example: 208 | 209 | ```go 210 | var people []*Person 211 | err := meddler.QueryAll(db, &people, "select * from person") 212 | ``` 213 | 214 | * Scan(rows *sql.Rows, dst interface{}) error 215 | 216 | Scans a single row of data into a struct, complete with 217 | meddling.
Can be called repeatedly to walk through all of the 218 | rows in a result set. Returns sql.ErrNoRows when there is no 219 | more data. 220 | 221 | * ScanRow(rows *sql.Rows, dst interface{}) error 222 | 223 | Similar to Scan, but guarantees that the rows object 224 | is closed when it returns. Also returns sql.ErrNoRows if there 225 | was no row. 226 | 227 | * ScanAll(rows *sql.Rows, dst interface{}) error 228 | 229 | Expects a pointer to a slice of structs/pointers to structs, and 230 | appends as many elements as it finds in the row set. Closes the 231 | row set when it is finished. Does not return sql.ErrNoRows on an 232 | empty set; instead it just does not add anything to the slice. 233 | 234 | Note: all of these functions can also be used as methods on Database 235 | objects. When used as package functions, they use the Default 236 | Database object, which is MySQL unless you change it. 237 | 238 | 239 | Meddlers 240 | -------- 241 | 242 | A meddler is a handler that gets to meddle with a field before it is 243 | saved, or when it is loaded. "localtime" and "localtimez" are 244 | examples of built-in meddlers. The full list of built-in meddlers 245 | includes: 246 | 247 | * identity: the default meddler, which does not do anything 248 | 249 | * localtime: for time.Time and *time.Time fields. Converts the 250 | value to UTC on save, and back to the local time zone on loads. 251 | To set your local time zone, make sure the TZ environment 252 | variable is set when your program is launched, or use something 253 | like: 254 | 255 | ```go 256 | os.Setenv("TZ", "America/Denver") 257 | ``` 258 | 259 | in your initial setup, before you start using time functions. 260 | 261 | * localtimez: same, but only for time.Time, and treats the zero 262 | time as a null field (converts both ways) 263 | 264 | * utctime: similar to localtime, but keeps the value in UTC on 265 | loads. 
This ensures that the time is always converted to UTC on 266 | save, which is the sane way to save time values in a database. 267 | 268 | * utctimez: same, but zero time means null. 269 | 270 | * zeroisnull: for other types where a zero value should be 271 | inserted as null, and null values should be read as zero values. 272 | Works for integer, unsigned integer, float, complex number, and 273 | string types. Note: not for pointer types. 274 | 275 | * json: marshals the field value into JSON when saving, and 276 | unmarshals on load. 277 | 278 | * jsongzip: same, but compresses using gzip on save, and 279 | uncompresses on load. 280 | 281 | * gob: encodes the field value using Gob when saving, and 282 | decodes on load. 283 | 284 | * gobgzip: same, but compresses using gzip on save, and 285 | uncompresses on load. 286 | 287 | You can implement custom meddlers as well by implementing the 288 | Meddler interface. See the existing implementations in meddler.go for 289 | examples. 290 | 291 | 292 | Working with different database types 293 | ------------------------------------- 294 | 295 | Meddler can work with multiple database types simultaneously. 296 | Database-specific parameters are stored in a Database struct, and 297 | structs are pre-defined for MySQL, PostgreSQL, and SQLite. 298 | 299 | Instead of relying on the package-level functions, use the method 300 | form on the appropriate database type, e.g.: 301 | 302 | ```go 303 | err = meddler.PostgreSQL.Load(...) 304 | ``` 305 | 306 | instead of 307 | 308 | ```go 309 | err = meddler.Load(...) 310 | ``` 311 | 312 | Or to save typing, define your own abbreviated name for each 313 | database: 314 | 315 | ```go 316 | ms := meddler.MySQL 317 | pg := meddler.PostgreSQL 318 | err = ms.Load(...) 319 | err = pg.QueryAll(...) 320 | ``` 321 | 322 | If you need a different database, create your own Database instance 323 | with the appropriate parameters set.
If everything works okay, 324 | please contact me with the parameters you used so I can add the new 325 | database to the pre-defined list. 326 | 327 | 328 | Lower-level functions 329 | --------------------- 330 | 331 | If you are using more complex queries and just want to reduce the 332 | tedium of reading and writing values, there are some lower-level 333 | helper functions as well. See the package docs for details, and 334 | see the implementations of the higher-level functions to see how 335 | they are used. 336 | 337 | 338 | License 339 | ------- 340 | 341 | Meddler is distributed under the BSD 2-Clause License. If this 342 | license prevents you from using Meddler in your project, please 343 | contact me and I will consider adding an additional license that is 344 | better suited to your needs. 345 | 346 | > Copyright © 2013 Russ Ross. 347 | > All rights reserved. 348 | > 349 | > Redistribution and use in source and binary forms, with or without 350 | > modification, are permitted provided that the following conditions 351 | > are met: 352 | > 353 | > 1. Redistributions of source code must retain the above copyright 354 | > notice, this list of conditions and the following disclaimer. 355 | > 356 | > 2. Redistributions in binary form must reproduce the above 357 | > copyright notice, this list of conditions and the following 358 | > disclaimer in the documentation and/or other materials provided with 359 | > the distribution. 360 | > 361 | > THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 362 | > "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 363 | > LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 364 | > FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE 365 | > COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 366 | > INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 367 | > BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 368 | > LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 369 | > CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 370 | > LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 371 | > ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 372 | > POSSIBILITY OF SUCH DAMAGE. 373 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // Package meddler is a small toolkit to take some of the tedium out of moving data 2 | // back and forth between SQL queries and structs. 3 | // 4 | // It is not a complete ORM. It is intended to be a lightweight way to add some 5 | // of the convenience of an ORM while leaving more control in the hands of the 6 | // programmer.
7 | // 8 | // Package docs are available at: 9 | // 10 | // http://godoc.org/github.com/russross/meddler 11 | // 12 | // The package is housed on GitHub, and the README there has more info: 13 | // 14 | // http://github.com/russross/meddler 15 | package meddler 16 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/russross/meddler 2 | 3 | go 1.13 4 | 5 | require github.com/mattn/go-sqlite3 v1.14.7 6 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/mattn/go-sqlite3 v1.14.7 h1:fxWBnXkxfM6sRiuH3bqJ4CfzZojMOLVc0UTsTglEghA= 2 | github.com/mattn/go-sqlite3 v1.14.7/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= 3 | -------------------------------------------------------------------------------- /loadsave.go: -------------------------------------------------------------------------------- 1 | package meddler 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "strings" 7 | ) 8 | 9 | type dbErr struct { 10 | msg string 11 | err error 12 | } 13 | 14 | func (err *dbErr) Error() string { 15 | return fmt.Sprintf("%s: %v", err.msg, err.err) 16 | } 17 | 18 | // DriverErr returns the original error as returned by the database driver 19 | // if the error comes from the driver, with the second value set to true. 20 | // Otherwise, it returns err itself with false as second value.
21 | func DriverErr(err error) (error, bool) { 22 | if dbe, ok := err.(*dbErr); ok { 23 | return dbe.err, true 24 | } 25 | return err, false 26 | } 27 | 28 | // DB is a generic database interface, matching both *sql.DB and *sql.Tx. 29 | type DB interface { 30 | Exec(query string, args ...interface{}) (sql.Result, error) 31 | Query(query string, args ...interface{}) (*sql.Rows, error) 32 | QueryRow(query string, args ...interface{}) *sql.Row 33 | } 34 | 35 | // Load loads a record using a query for the primary key field. 36 | // Returns sql.ErrNoRows if not found. 37 | func (d *Database) Load(db DB, table string, dst interface{}, pk int64) error { 38 | columns, err := d.ColumnsQuoted(dst, true) 39 | if err != nil { 40 | return err 41 | } 42 | 43 | // make sure we have a primary key field 44 | pkName, _, err := d.PrimaryKey(dst) 45 | if err != nil { 46 | return err 47 | } 48 | if pkName == "" { 49 | return fmt.Errorf("meddler.Load: no primary key field found") 50 | } 51 | 52 | // run the query 53 | q := fmt.Sprintf("SELECT %s FROM %s WHERE %s = %s", columns, d.quoted(table), d.quoted(pkName), d.Placeholder) 54 | 55 | rows, err := db.Query(q, pk) 56 | if err != nil { 57 | return &dbErr{msg: "meddler.Load: DB error in Query", err: err} 58 | } 59 | 60 | // scan the row 61 | return d.ScanRow(rows, dst) 62 | } 63 | 64 | // Load using the Default Database type 65 | func Load(db DB, table string, dst interface{}, pk int64) error { 66 | return Default.Load(db, table, dst, pk) 67 | } 68 | 69 | // Insert performs an INSERT query for the given record. 70 | // If the record has a primary key flagged, it must be zero, and it 71 | // will be set to the newly-allocated primary key value from the database 72 | // as returned by LastInsertId (or by a RETURNING clause when d.UseReturningToGetID is set).
73 | func (d *Database) Insert(db DB, table string, src interface{}) error { 74 | pkName, pkValue, err := d.PrimaryKey(src) 75 | if err != nil { 76 | return err 77 | } 78 | if pkName != "" && pkValue != 0 { 79 | return fmt.Errorf("meddler.Insert: primary key must be zero") 80 | } 81 | 82 | // gather the query parts 83 | namesPart, err := d.ColumnsQuoted(src, false) 84 | if err != nil { 85 | return err 86 | } 87 | valuesPart, err := d.PlaceholdersString(src, false) 88 | if err != nil { 89 | return err 90 | } 91 | values, err := d.Values(src, false) 92 | if err != nil { 93 | return err 94 | } 95 | 96 | // run the query 97 | q := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", d.quoted(table), namesPart, valuesPart) 98 | if d.UseReturningToGetID && pkName != "" { 99 | q += " RETURNING " + d.quoted(pkName) 100 | var newPk int64 101 | err := db.QueryRow(q, values...).Scan(&newPk) 102 | if err != nil { 103 | return &dbErr{msg: "meddler.Insert: DB error in QueryRow", err: err} 104 | } 105 | if err = d.SetPrimaryKey(src, newPk); err != nil { 106 | return fmt.Errorf("meddler.Insert: Error saving updated pk: %v", err) 107 | } 108 | } else if pkName != "" { 109 | result, err := db.Exec(q, values...) 110 | if err != nil { 111 | return &dbErr{msg: "meddler.Insert: DB error in Exec", err: err} 112 | } 113 | 114 | // save the new primary key 115 | newPk, err := result.LastInsertId() 116 | if err != nil { 117 | return &dbErr{msg: "meddler.Insert: DB error getting new primary key value", err: err} 118 | } 119 | if err = d.SetPrimaryKey(src, newPk); err != nil { 120 | return fmt.Errorf("meddler.Insert: Error saving updated pk: %v", err) 121 | } 122 | } else { 123 | // no primary key, so no need to lookup new value 124 | _, err := db.Exec(q, values...) 
125 | if err != nil { 126 | return &dbErr{msg: "meddler.Insert: DB error in Exec", err: err} 127 | } 128 | } 129 | 130 | return nil 131 | } 132 | 133 | // Insert using the Default Database type 134 | func Insert(db DB, table string, src interface{}) error { 135 | return Default.Insert(db, table, src) 136 | } 137 | 138 | // Update performs an UPDATE query for the given record. 139 | // The record must have an integer primary key field that is positive (> 0), 140 | // and it will be used to select the database row that gets updated. 141 | func (d *Database) Update(db DB, table string, src interface{}) error { 142 | // gather the query parts 143 | names, err := d.Columns(src, false) 144 | if err != nil { 145 | return err 146 | } 147 | placeholders, err := d.Placeholders(src, false) 148 | if err != nil { 149 | return err 150 | } 151 | values, err := d.Values(src, false) 152 | if err != nil { 153 | return err 154 | } 155 | 156 | // form the column=placeholder pairs 157 | var pairs []string 158 | for i := 0; i < len(names) && i < len(placeholders); i++ { 159 | pair := fmt.Sprintf("%s=%s", d.quoted(names[i]), placeholders[i]) 160 | pairs = append(pairs, pair) 161 | } 162 | 163 | pkName, pkValue, err := d.PrimaryKey(src) 164 | if err != nil { 165 | return err 166 | } 167 | if pkName == "" { 168 | return fmt.Errorf("meddler.Update: no primary key field") 169 | } 170 | if pkValue < 1 { 171 | return fmt.Errorf("meddler.Update: primary key must be an integer > 0") 172 | } 173 | ph := d.placeholder(len(placeholders) + 1) 174 | 175 | // run the query 176 | q := fmt.Sprintf("UPDATE %s SET %s WHERE %s=%s", d.quoted(table), 177 | strings.Join(pairs, ","), 178 | d.quoted(pkName), ph) 179 | values = append(values, pkValue) 180 | 181 | if _, err := db.Exec(q, values...); err != nil { 182 | return &dbErr{msg: "meddler.Update: DB error in Exec", err: err} 183 | } 184 | 185 | return nil 186 | } 187 | 188 | // Update using the Default Database type 189 | func Update(db DB, table string, src
interface{}) error { 190 | return Default.Update(db, table, src) 191 | } 192 | 193 | // Save performs an INSERT or an UPDATE, depending on whether or not 194 | // a primary key exists and is non-zero. 195 | func (d *Database) Save(db DB, table string, src interface{}) error { 196 | pkName, pkValue, err := d.PrimaryKey(src) 197 | if err != nil { 198 | return err 199 | } 200 | if pkName != "" && pkValue != 0 { 201 | return d.Update(db, table, src) 202 | } 203 | 204 | return d.Insert(db, table, src) 205 | } 206 | 207 | // Save using the Default Database type 208 | func Save(db DB, table string, src interface{}) error { 209 | return Default.Save(db, table, src) 210 | } 211 | 212 | // QueryRow performs the given query with the given arguments, scanning a 213 | // single row of results into dst. Returns sql.ErrNoRows if there was no 214 | // result row. 215 | func (d *Database) QueryRow(db DB, dst interface{}, query string, args ...interface{}) error { 216 | // perform the query 217 | rows, err := db.Query(query, args...) 218 | if err != nil { 219 | return err 220 | } 221 | 222 | // gather the result 223 | return d.ScanRow(rows, dst) 224 | } 225 | 226 | // QueryRow using the Default Database type 227 | func QueryRow(db DB, dst interface{}, query string, args ...interface{}) error { 228 | return Default.QueryRow(db, dst, query, args...) 229 | } 230 | 231 | // QueryAll performs the given query with the given arguments, scanning 232 | // all result rows into dst. 233 | func (d *Database) QueryAll(db DB, dst interface{}, query string, args ...interface{}) error { 234 | // perform the query 235 | rows, err := db.Query(query, args...) 236 | if err != nil { 237 | return err 238 | } 239 | 240 | // gather the results 241 | return d.ScanAll(rows, dst) 242 | } 243 | 244 | // QueryAll using the Default Database type 245 | func QueryAll(db DB, dst interface{}, query string, args ...interface{}) error { 246 | return Default.QueryAll(db, dst, query, args...)
247 | } 248 | -------------------------------------------------------------------------------- /loadsave_test.go: -------------------------------------------------------------------------------- 1 | package meddler 2 | 3 | import ( 4 | "io" 5 | "testing" 6 | "time" 7 | 8 | "github.com/mattn/go-sqlite3" 9 | ) 10 | 11 | func TestLoad(t *testing.T) { 12 | once.Do(setup) 13 | insertAliceBob(t) 14 | 15 | elt := new(Person) 16 | elt.Age = 50 17 | elt.Closed = time.Now() 18 | if err := Load(db, "person", elt, 2); err != nil { 19 | t.Errorf("Load error on Bob: %v", err) 20 | return 21 | } 22 | bob.ID = 2 23 | personEqual(t, elt, bob) 24 | db.Exec("delete from person") 25 | 26 | // test for err on invalid table 27 | if err := Load(db, "invalid_table_name", elt, 2); err == nil { 28 | t.Errorf("Load on invalid table, expected err, got nil") 29 | } 30 | } 31 | 32 | func TestNoPrimaryKey(t *testing.T) { 33 | once.Do(setup) 34 | 35 | // test without primary key 36 | type personWithoutPK struct { 37 | Name string 38 | } 39 | elt2 := new(personWithoutPK) 40 | if err := Load(db, "person", elt2, 2); err == nil { 41 | t.Error("Load on struct without PK: expected err, got nil") 42 | } 43 | if err := Update(db, "person", elt2); err == nil { 44 | t.Error("Update on struct without PK: expected err, got nil") 45 | } 46 | if err := Save(db, "person", elt2); err == nil { 47 | t.Error("Save on struct without PK: expected err, got nil") 48 | } 49 | } 50 | 51 | func TestLoadUint(t *testing.T) { 52 | once.Do(setup) 53 | insertAliceBob(t) 54 | 55 | elt := new(UintPerson) 56 | elt.Age = 50 57 | elt.Closed = time.Now() 58 | if err := Load(db, "person", elt, 2); err != nil { 59 | t.Errorf("Load error on Bob: %v", err) 60 | return 61 | } 62 | bob.ID = 2 63 | db.Exec("delete from person") 64 | } 65 | 66 | func TestQueryAll(t *testing.T) { 67 | once.Do(setup) 68 | insertAliceBob(t) 69 | var people []*Person 70 | if err := QueryAll(db, &people, "SELECT * FROM person", ""); err != nil { 71 | 
t.Errorf("QueryAll error: %v", err) 72 | } 73 | 74 | if len(people) != 2 { 75 | t.Errorf("QueryAll(): expected %d results, got %d", 2, len(people)) 76 | } 77 | 78 | db.Exec("delete from person") 79 | 80 | // test on unexisting table 81 | if err := QueryAll(db, &people, "SELECT * FROM invalid_table_name"); err == nil { 82 | t.Errorf("QueryAll on invalid table, expected err, got nil") 83 | } 84 | } 85 | 86 | func TestSave(t *testing.T) { 87 | once.Do(setup) 88 | insertAliceBob(t) 89 | 90 | h := 73 91 | chris := &Person{ 92 | ID: 0, 93 | Name: "Chris", 94 | Email: "chris@chris.com", 95 | Ephemeral: 19, 96 | Age: 23, 97 | Opened: when.Local(), 98 | Closed: when, 99 | Updated: nil, 100 | Height: &h, 101 | } 102 | 103 | tx, err := db.Begin() 104 | if err != nil { 105 | t.Errorf("DB error on begin: %v", err) 106 | } 107 | // test invalid table for err return value 108 | if err := Save(tx, "invalid_table_name", chris); err == nil { 109 | t.Error("Save with invalid table, expected err, got nil") 110 | } 111 | // save correctly 112 | if err = Save(tx, "person", chris); err != nil { 113 | t.Errorf("DB error on Save: %v", err) 114 | } 115 | 116 | id := chris.ID 117 | if id != 3 { 118 | t.Errorf("DB error on Save: expected ID of 3 but got %d", id) 119 | } 120 | 121 | chris.Email = "chris@chrischris.com" 122 | chris.Age = 27 123 | 124 | if err = Save(tx, "person", chris); err != nil { 125 | t.Errorf("DB error on Save: %v", err) 126 | } 127 | if chris.ID != id { 128 | t.Errorf("ID mismatch: found %d when %d expected", chris.ID, id) 129 | } 130 | if err = tx.Commit(); err != nil { 131 | t.Errorf("Commit error: %v", err) 132 | } 133 | 134 | // now test if the data looks right 135 | rows, err := db.Query("select * from person where id = ?", id) 136 | if err != nil { 137 | t.Errorf("DB error on query: %v", err) 138 | return 139 | } 140 | 141 | p := new(Person) 142 | if err = Default.ScanRow(rows, p); err != nil { 143 | t.Errorf("ScanRow error on Chris: %v", err) 144 | return 145 | } 
146 | 147 | personEqual(t, p, &Person{3, "Chris", 0, "chris@chrischris.com", 0, 27, when, when, nil, &h}) 148 | 149 | // delete this record so we don't confuse other tests 150 | if _, err = db.Exec("delete from person where id = ?", id); err != nil { 151 | t.Errorf("DB error on delete: %v", err) 152 | } 153 | db.Exec("delete from person") 154 | } 155 | 156 | func TestDriverErr(t *testing.T) { 157 | err, ok := DriverErr(io.EOF) 158 | if ok { 159 | t.Errorf("io.EOF: want driver error = false, got true") 160 | } 161 | if err != io.EOF { 162 | t.Errorf("io.EOF: want itself as returned error, got %v", err) 163 | } 164 | 165 | once.Do(setup) 166 | // insert into an invalid table 167 | alice.ID = 0 168 | err = Insert(db, "invalid", alice) 169 | if err == nil { 170 | t.Fatal("insert into invalid table, want error, got none") 171 | } 172 | err, ok = DriverErr(err) 173 | if !ok { 174 | t.Errorf("DriverErr: want ok to be true, got false") 175 | } 176 | if _, ok := err.(sqlite3.Error); !ok { 177 | t.Errorf("DriverErr: want sqlite3 error, got %T", err) 178 | } 179 | 180 | // insert with primary key set 181 | alice.ID = 1 182 | err = Insert(db, "person", alice) 183 | if err == nil { 184 | t.Errorf("insert with primary key already set. want error, got none") 185 | } 186 | 187 | // update with primary key not set 188 | alice.ID = 0 189 | err = Update(db, "person", alice) 190 | if err == nil { 191 | t.Errorf("update with primary key 0. want error, got none") 192 | } 193 | } 194 | -------------------------------------------------------------------------------- /mapper.go: -------------------------------------------------------------------------------- 1 | package meddler 2 | 3 | import ( 4 | "strings" 5 | "unicode" 6 | ) 7 | 8 | // MapperFunc signature. Argument is field name, return value is database column. 9 | type MapperFunc func(in string) string 10 | 11 | // Mapper defines the function to transform struct field names into database columns. 
12 | // Default is strings.TrimSpace, basically a no-op 13 | var Mapper MapperFunc = strings.TrimSpace 14 | 15 | // LowerCase returns a lowercased version of the input string 16 | func LowerCase(in string) string { 17 | return strings.ToLower(in) 18 | } 19 | 20 | // SnakeCase returns a snake_cased version of the input string 21 | func SnakeCase(in string) string { 22 | runes := []rune(in) 23 | 24 | var out []rune 25 | for i := 0; i < len(runes); i++ { 26 | if i > 0 && (unicode.IsUpper(runes[i]) || unicode.IsNumber(runes[i])) && ((i+1 < len(runes) && unicode.IsLower(runes[i+1])) || unicode.IsLower(runes[i-1])) { 27 | out = append(out, '_') 28 | } 29 | out = append(out, unicode.ToLower(runes[i])) 30 | } 31 | 32 | return string(out) 33 | } 34 | -------------------------------------------------------------------------------- /mapper_test.go: -------------------------------------------------------------------------------- 1 | package meddler 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestDefaultMapper(t *testing.T) { 8 | // default mapper should be no-op 9 | var tests = map[string]string{ 10 | "": "", 11 | "foo": "foo", 12 | "foo_bar": "foo_bar", 13 | "FooBar": "FooBar", 14 | "FOOBAR": "FOOBAR", 15 | } 16 | 17 | for i, e := range tests { 18 | if v := Mapper(i); v != e { 19 | t.Errorf("Mapper(\"%s\"): expected %s, got %s", i, e, v) 20 | } 21 | } 22 | 23 | } 24 | 25 | func TestSnakeCase(t *testing.T) { 26 | var tests = map[string]string{ 27 | "": "", 28 | "ID": "id", 29 | "ColumnName": "column_name", 30 | "COLUMN_NAME": "column_name", 31 | "column_name": "column_name", 32 | "UserID": "user_id", 33 | "UserNameRaw": "user_name_raw", 34 | } 35 | 36 | for i, e := range tests { 37 | if v := SnakeCase(i); v != e { 38 | t.Errorf("SnakeCase(\"%s\"): expected %s, got %s", i, e, v) 39 | } 40 | } 41 | } 42 | 43 | func TestLowerCase(t *testing.T) { 44 | var tests = map[string]string{ 45 | "": "", 46 | "ID": "id", 47 | "ColumnName": "columnname", 48 | "COLUMN_NAME": 
"column_name", 49 | "column_name": "column_name", 50 | "UserID": "userid", 51 | "UserNameRaw": "usernameraw", 52 | } 53 | 54 | for i, e := range tests { 55 | if v := LowerCase(i); v != e { 56 | t.Errorf("LowerCase(\"%s\"): expected %s, got %s", i, e, v) 57 | } 58 | } 59 | 60 | } 61 | -------------------------------------------------------------------------------- /meddler.go: -------------------------------------------------------------------------------- 1 | package meddler 2 | 3 | import ( 4 | "bytes" 5 | "compress/gzip" 6 | "encoding/gob" 7 | "encoding/json" 8 | "fmt" 9 | "reflect" 10 | "time" 11 | ) 12 | 13 | // Meddler is the interface for a field meddler. Implementations can be 14 | // registered to convert struct fields being loaded and saved in the database. 15 | type Meddler interface { 16 | // PreRead is called before a Scan operation. It is given a pointer to 17 | // the raw struct field, and returns the value that will be given to 18 | // the database driver. 19 | PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) 20 | 21 | // PostRead is called after a Scan operation. It is given the value returned 22 | // by PreRead and a pointer to the raw struct field. It is expected to fill 23 | // in the struct field if the two are different. 24 | PostRead(fieldAddr interface{}, scanTarget interface{}) error 25 | 26 | // PreWrite is called before an Insert or Update operation. It is given 27 | // a pointer to the raw struct field, and returns the value that will be 28 | // given to the database driver. 29 | PreWrite(field interface{}) (saveValue interface{}, err error) 30 | } 31 | 32 | // Register sets up a meddler type. Meddlers get a chance to meddle with the 33 | // data being loaded or saved when a field is annotated with the name of the meddler. 34 | // The registry is global. 
35 | func Register(name string, m Meddler) { 36 | if name == "pk" { 37 | panic("meddler.Register: pk cannot be used as a meddler name") 38 | } 39 | registry[name] = m 40 | } 41 | 42 | var registry = make(map[string]Meddler) 43 | 44 | func init() { 45 | Register("identity", IdentityMeddler(false)) 46 | Register("localtime", TimeMeddler{ZeroIsNull: false, Local: true}) 47 | Register("localtimez", TimeMeddler{ZeroIsNull: true, Local: true}) 48 | Register("utctime", TimeMeddler{ZeroIsNull: false, Local: false}) 49 | Register("utctimez", TimeMeddler{ZeroIsNull: true, Local: false}) 50 | Register("zeroisnull", ZeroIsNullMeddler(false)) 51 | Register("json", JSONMeddler(false)) 52 | Register("jsongzip", JSONMeddler(true)) 53 | Register("gob", GobMeddler(false)) 54 | Register("gobgzip", GobMeddler(true)) 55 | } 56 | 57 | // IdentityMeddler is the default meddler, and it passes the original value through with 58 | // no changes. 59 | type IdentityMeddler bool 60 | 61 | // PreRead is called before a Scan operation for fields that have the IdentityMeddler 62 | func (elt IdentityMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) { 63 | return fieldAddr, nil 64 | } 65 | 66 | // PostRead is called after a Scan operation for fields that have the IdentityMeddler 67 | func (elt IdentityMeddler) PostRead(fieldAddr, scanTarget interface{}) error { 68 | return nil 69 | } 70 | 71 | // PreWrite is called before an Insert or Update operation for fields that have the IdentityMeddler 72 | func (elt IdentityMeddler) PreWrite(field interface{}) (saveValue interface{}, err error) { 73 | return field, nil 74 | } 75 | 76 | // TimeMeddler provides useful operations on time.Time fields. It can convert the zero time 77 | // to and from a null column, and it can convert the time zone to UTC on save and to Local on load. 
78 | type TimeMeddler struct { 79 | ZeroIsNull bool 80 | Local bool 81 | } 82 | 83 | // PreRead is called before a Scan operation for fields that have a TimeMeddler 84 | func (elt TimeMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) { 85 | switch tgt := fieldAddr.(type) { 86 | case *time.Time: 87 | if elt.ZeroIsNull { 88 | return &tgt, nil 89 | } 90 | return fieldAddr, nil 91 | case **time.Time: 92 | if elt.ZeroIsNull { 93 | return nil, fmt.Errorf("meddler.TimeMeddler cannot be used on a *time.Time field, only time.Time") 94 | } 95 | return fieldAddr, nil 96 | default: 97 | return nil, fmt.Errorf("meddler.TimeMeddler.PreRead: unknown struct field type: %T", fieldAddr) 98 | } 99 | } 100 | 101 | // PostRead is called after a Scan operation for fields that have a TimeMeddler 102 | func (elt TimeMeddler) PostRead(fieldAddr, scanTarget interface{}) error { 103 | switch tgt := fieldAddr.(type) { 104 | case *time.Time: 105 | if elt.ZeroIsNull { 106 | src := scanTarget.(**time.Time) 107 | if *src == nil { 108 | *tgt = time.Time{} 109 | } else if elt.Local { 110 | *tgt = (*src).Local() 111 | } else { 112 | *tgt = (*src).UTC() 113 | } 114 | return nil 115 | } 116 | 117 | src := scanTarget.(*time.Time) 118 | if elt.Local { 119 | *tgt = src.Local() 120 | } else { 121 | *tgt = src.UTC() 122 | } 123 | 124 | return nil 125 | 126 | case **time.Time: 127 | if elt.ZeroIsNull { 128 | return fmt.Errorf("meddler TimeMeddler cannot be used on a *time.Time field, only time.Time") 129 | } 130 | src := scanTarget.(**time.Time) 131 | if *src == nil { 132 | *tgt = nil 133 | } else if elt.Local { 134 | **src = (*src).Local() 135 | *tgt = *src 136 | } else { 137 | **src = (*src).UTC() 138 | *tgt = *src 139 | } 140 | 141 | return nil 142 | 143 | default: 144 | return fmt.Errorf("meddler.TimeMeddler.PostRead: unknown struct field type: %T", fieldAddr) 145 | } 146 | } 147 | 148 | // PreWrite is called before an Insert or Update operation for fields that have a 
TimeMeddler 149 | func (elt TimeMeddler) PreWrite(field interface{}) (saveValue interface{}, err error) { 150 | switch tgt := field.(type) { 151 | case time.Time: 152 | if elt.ZeroIsNull && tgt.IsZero() { 153 | return nil, nil 154 | } 155 | return tgt.UTC(), nil 156 | 157 | case *time.Time: 158 | if tgt == nil || elt.ZeroIsNull && tgt.IsZero() { 159 | return nil, nil 160 | } 161 | return tgt.UTC(), nil 162 | 163 | default: 164 | return nil, fmt.Errorf("meddler.TimeMeddler.PreWrite: unknown struct field type: %T", field) 165 | } 166 | } 167 | 168 | // ZeroIsNullMeddler converts zero value fields (integers both signed and unsigned, floats, complex numbers, 169 | // and strings) to and from null database columns. 170 | type ZeroIsNullMeddler bool 171 | 172 | // PreRead is called before a Scan operation for fields that have the ZeroIsNullMeddler 173 | func (elt ZeroIsNullMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) { 174 | // create a pointer to this element 175 | // the database driver will set it to nil if the column value is null 176 | return reflect.New(reflect.TypeOf(fieldAddr)).Interface(), nil 177 | } 178 | 179 | // PostRead is called after a Scan operation for fields that have the ZeroIsNullMeddler 180 | func (elt ZeroIsNullMeddler) PostRead(fieldAddr, scanTarget interface{}) error { 181 | sv := reflect.ValueOf(scanTarget) 182 | fv := reflect.ValueOf(fieldAddr) 183 | if sv.Elem().IsNil() { 184 | // null column, so set target to be zero value 185 | fv.Elem().Set(reflect.Zero(fv.Elem().Type())) 186 | } else { 187 | // copy the value that scan found 188 | fv.Elem().Set(sv.Elem().Elem()) 189 | } 190 | return nil 191 | } 192 | 193 | // PreWrite is called before an Insert or Update operation for fields that have the ZeroIsNullMeddler 194 | func (elt ZeroIsNullMeddler) PreWrite(field interface{}) (saveValue interface{}, err error) { 195 | val := reflect.ValueOf(field) 196 | switch val.Kind() { 197 | case reflect.Int, reflect.Int8, 
reflect.Int16, reflect.Int32, reflect.Int64: 198 | if val.Int() == 0 { 199 | return nil, nil 200 | } 201 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 202 | if val.Uint() == 0 { 203 | return nil, nil 204 | } 205 | case reflect.Float32, reflect.Float64: 206 | if val.Float() == 0 { 207 | return nil, nil 208 | } 209 | case reflect.Complex64, reflect.Complex128: 210 | if val.Complex() == 0 { 211 | return nil, nil 212 | } 213 | case reflect.String: 214 | if val.String() == "" { 215 | return nil, nil 216 | } 217 | case reflect.Bool: 218 | if !val.Bool() { 219 | return nil, nil 220 | } 221 | default: 222 | return nil, fmt.Errorf("ZeroIsNullMeddler.PreWrite: unknown struct field type: %T", field) 223 | } 224 | 225 | return field, nil 226 | } 227 | 228 | // JSONMeddler encodes or decodes the field value to or from JSON 229 | type JSONMeddler bool 230 | 231 | // PreRead is called before a Scan operation for fields that have the JSONMeddler 232 | func (zip JSONMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) { 233 | // give a pointer to a byte buffer to grab the raw data 234 | return new([]byte), nil 235 | } 236 | 237 | // PostRead is called after a Scan operation for fields that have the JSONMeddler 238 | func (zip JSONMeddler) PostRead(fieldAddr, scanTarget interface{}) error { 239 | ptr := scanTarget.(*[]byte) 240 | if ptr == nil { 241 | return fmt.Errorf("JSONMeddler.PostRead: nil pointer") 242 | } 243 | raw := *ptr 244 | 245 | if zip { 246 | // un-gzip and decode json 247 | gzipReader, err := gzip.NewReader(bytes.NewReader(raw)) 248 | if err != nil { 249 | return fmt.Errorf("Error creating gzip Reader: %v", err) 250 | } 251 | defer gzipReader.Close() 252 | jsonDecoder := json.NewDecoder(gzipReader) 253 | if err := jsonDecoder.Decode(fieldAddr); err != nil { 254 | return fmt.Errorf("JSON decoder/gzip error: %v", err) 255 | } 256 | if err := gzipReader.Close(); err != nil { 257 | return fmt.Errorf("Closing 
gzip reader: %v", err) 258 | } 259 | 260 | return nil 261 | } 262 | 263 | // decode json 264 | jsonDecoder := json.NewDecoder(bytes.NewReader(raw)) 265 | if err := jsonDecoder.Decode(fieldAddr); err != nil { 266 | return fmt.Errorf("JSON decode error: %v", err) 267 | } 268 | 269 | return nil 270 | } 271 | 272 | // PreWrite is called before an Insert or Update operation for fields that have the JSONMeddler 273 | func (zip JSONMeddler) PreWrite(field interface{}) (saveValue interface{}, err error) { 274 | buffer := new(bytes.Buffer) 275 | 276 | if zip { 277 | // json encode and gzip 278 | gzipWriter := gzip.NewWriter(buffer) 279 | defer gzipWriter.Close() 280 | jsonEncoder := json.NewEncoder(gzipWriter) 281 | if err := jsonEncoder.Encode(field); err != nil { 282 | return nil, fmt.Errorf("JSON encoding/gzip error: %v", err) 283 | } 284 | if err := gzipWriter.Close(); err != nil { 285 | return nil, fmt.Errorf("Closing gzip writer: %v", err) 286 | } 287 | 288 | return buffer.Bytes(), nil 289 | } 290 | 291 | // json encode 292 | jsonEncoder := json.NewEncoder(buffer) 293 | if err := jsonEncoder.Encode(field); err != nil { 294 | return nil, fmt.Errorf("JSON encoding error: %v", err) 295 | } 296 | return buffer.Bytes(), nil 297 | } 298 | 299 | // GobMeddler encodes or decodes the field value to or from gob 300 | type GobMeddler bool 301 | 302 | // PreRead is called before a Scan operation for fields that have the GobMeddler 303 | func (zip GobMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) { 304 | // give a pointer to a byte buffer to grab the raw data 305 | return new([]byte), nil 306 | } 307 | 308 | // PostRead is called after a Scan operation for fields that have the GobMeddler 309 | func (zip GobMeddler) PostRead(fieldAddr, scanTarget interface{}) error { 310 | ptr := scanTarget.(*[]byte) 311 | if ptr == nil { 312 | return fmt.Errorf("GobMeddler.PostRead: nil pointer") 313 | } 314 | raw := *ptr 315 | 316 | if zip { 317 | // un-gzip and 
decode gob 318 | gzipReader, err := gzip.NewReader(bytes.NewReader(raw)) 319 | if err != nil { 320 | return fmt.Errorf("Error creating gzip Reader: %v", err) 321 | } 322 | defer gzipReader.Close() 323 | gobDecoder := gob.NewDecoder(gzipReader) 324 | if err := gobDecoder.Decode(fieldAddr); err != nil { 325 | return fmt.Errorf("Gob decoder/gzip error: %v", err) 326 | } 327 | if err := gzipReader.Close(); err != nil { 328 | return fmt.Errorf("Closing gzip reader: %v", err) 329 | } 330 | 331 | return nil 332 | } 333 | 334 | // decode gob 335 | gobDecoder := gob.NewDecoder(bytes.NewReader(raw)) 336 | if err := gobDecoder.Decode(fieldAddr); err != nil { 337 | return fmt.Errorf("Gob decode error: %v", err) 338 | } 339 | 340 | return nil 341 | } 342 | 343 | // PreWrite is called before an Insert or Update operation for fields that have the GobMeddler 344 | func (zip GobMeddler) PreWrite(field interface{}) (saveValue interface{}, err error) { 345 | buffer := new(bytes.Buffer) 346 | 347 | if zip { 348 | // gob encode and gzip 349 | gzipWriter := gzip.NewWriter(buffer) 350 | defer gzipWriter.Close() 351 | gobEncoder := gob.NewEncoder(gzipWriter) 352 | if err := gobEncoder.Encode(field); err != nil { 353 | return nil, fmt.Errorf("Gob encoding/gzip error: %v", err) 354 | } 355 | if err := gzipWriter.Close(); err != nil { 356 | return nil, fmt.Errorf("Closing gzip writer: %v", err) 357 | } 358 | 359 | return buffer.Bytes(), nil 360 | } 361 | 362 | // gob encode 363 | gobEncoder := gob.NewEncoder(buffer) 364 | if err := gobEncoder.Encode(field); err != nil { 365 | return nil, fmt.Errorf("Gob encoding error: %v", err) 366 | } 367 | return buffer.Bytes(), nil 368 | } 369 | -------------------------------------------------------------------------------- /meddler_test.go: -------------------------------------------------------------------------------- 1 | package meddler 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | type ItemJson struct { 8 | ID int64 `meddler:"id,pk"` 9 | Stuff 
map[string]bool `meddler:"stuff,json"` 10 | StuffZ map[string]bool `meddler:"stuffz,jsongzip"` 11 | } 12 | 13 | type ItemGob struct { 14 | ID int64 `meddler:"id,pk"` 15 | Stuff map[string]bool `meddler:"stuff,gob"` 16 | StuffZ map[string]bool `meddler:"stuffz,gobgzip"` 17 | } 18 | 19 | type ItemZeroes struct { 20 | ID int64 `meddler:"id,pk"` 21 | Int int `meddler:"nullint,zeroisnull"` 22 | Float float64 `meddler:"nullfloat,zeroisnull"` 23 | Complex complex128 `meddler:"nullcomplex,zeroisnull"` 24 | String string `meddler:"nullstring,zeroisnull"` 25 | Bool bool `meddler:"nullbool,zeroisnull"` 26 | } 27 | 28 | func TestZeroIsNullMeddler(t *testing.T) { 29 | once.Do(setup) 30 | 31 | before := &ItemZeroes{} 32 | if err := Save(db, "null_item", before); err != nil { 33 | t.Errorf("Save error: %v", err) 34 | } 35 | id := before.ID 36 | 37 | after := new(ItemZeroes) 38 | if err := Load(db, "null_item", after, id); err != nil { 39 | t.Errorf("Load error: %v", err) 40 | } 41 | 42 | if before.String != after.String { 43 | t.Errorf("before.String: expected %s, got %s", before.String, after.String) 44 | } 45 | if before.Int != after.Int { 46 | t.Errorf("before.Int: expected %d, got %d", before.Int, after.Int) 47 | } 48 | if before.Float != after.Float { 49 | t.Errorf("before.Float: expected %#v, got %#v", before.Float, after.Float) 50 | } 51 | if before.Bool != after.Bool { 52 | t.Errorf("before.Bool: expected %#v, got %#v", before.Bool, after.Bool) 53 | } 54 | if before.Complex != after.Complex { 55 | t.Errorf("before.Complex: expected %#v, got %#v", before.Complex, after.Complex) 56 | } 57 | } 58 | 59 | func TestJsonMeddler(t *testing.T) { 60 | once.Do(setup) 61 | 62 | // save a value 63 | elt := &ItemJson{ 64 | ID: 0, 65 | Stuff: map[string]bool{ 66 | "hello": true, 67 | "world": true, 68 | }, 69 | StuffZ: map[string]bool{ 70 | "goodbye": true, 71 | "cruel": true, 72 | "world": true, 73 | }, 74 | } 75 | 76 | if err := Save(db, "item", elt); err != nil { 77 | t.Errorf("Save 
error: %v", err) 78 | } 79 | id := elt.ID 80 | 81 | // load it again 82 | elt = new(ItemJson) 83 | if err := Load(db, "item", elt, id); err != nil { 84 | t.Errorf("Load error: %v", err) 85 | } 86 | 87 | if elt.ID != id { 88 | t.Errorf("expected id of %d, found %d", id, elt.ID) 89 | } 90 | if len(elt.Stuff) != 2 { 91 | t.Errorf("expected %d items in Stuff, found %d", 2, len(elt.Stuff)) 92 | } 93 | if !elt.Stuff["hello"] || !elt.Stuff["world"] { 94 | t.Errorf("contents of stuff wrong: %v", elt.Stuff) 95 | } 96 | if len(elt.StuffZ) != 3 { 97 | t.Errorf("expected %d items in StuffZ, found %d", 3, len(elt.StuffZ)) 98 | } 99 | if !elt.StuffZ["goodbye"] || !elt.StuffZ["cruel"] || !elt.StuffZ["world"] { 100 | t.Errorf("contents of stuffz wrong: %v", elt.StuffZ) 101 | } 102 | if _, err := db.Exec("delete from `item`"); err != nil { 103 | t.Errorf("error wiping item table: %v", err) 104 | } 105 | } 106 | 107 | func TestGobMeddler(t *testing.T) { 108 | once.Do(setup) 109 | 110 | // save a value 111 | elt := &ItemGob{ 112 | ID: 0, 113 | Stuff: map[string]bool{ 114 | "hello": true, 115 | "world": true, 116 | }, 117 | StuffZ: map[string]bool{ 118 | "goodbye": true, 119 | "cruel": true, 120 | "world": true, 121 | }, 122 | } 123 | 124 | if err := Save(db, "item", elt); err != nil { 125 | t.Errorf("Save error: %v", err) 126 | } 127 | id := elt.ID 128 | 129 | // load it again 130 | elt = new(ItemGob) 131 | if err := Load(db, "item", elt, id); err != nil { 132 | t.Errorf("Load error: %v", err) 133 | } 134 | 135 | if elt.ID != id { 136 | t.Errorf("expected id of %d, found %d", id, elt.ID) 137 | } 138 | if len(elt.Stuff) != 2 { 139 | t.Errorf("expected %d items in Stuff, found %d", 2, len(elt.Stuff)) 140 | } 141 | if !elt.Stuff["hello"] || !elt.Stuff["world"] { 142 | t.Errorf("contents of stuff wrong: %v", elt.Stuff) 143 | } 144 | if len(elt.StuffZ) != 3 { 145 | t.Errorf("expected %d items in StuffZ, found %d", 3, len(elt.StuffZ)) 146 | } 147 | if !elt.StuffZ["goodbye"] || 
!elt.StuffZ["cruel"] || !elt.StuffZ["world"] { 148 | t.Errorf("contents of stuffz wrong: %v", elt.StuffZ) 149 | } 150 | if _, err := db.Exec("delete from `item`"); err != nil { 151 | t.Errorf("error wiping item table: %v", err) 152 | } 153 | } 154 | -------------------------------------------------------------------------------- /scan.go: -------------------------------------------------------------------------------- 1 | package meddler 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "log" 7 | "reflect" 8 | "strconv" 9 | "strings" 10 | "sync" 11 | ) 12 | 13 | // the name of our struct tag 14 | const tagName = "meddler" 15 | 16 | // Database contains database-specific options. 17 | // MySQL, PostgreSQL, and SQLite are provided for convenience. 18 | // Setting Default to any of these lets you use the package-level convenience functions. 19 | type Database struct { 20 | Quote string // the quote character for table and column names 21 | Placeholder string // the placeholder style to use in generated queries 22 | UseReturningToGetID bool // use PostgreSQL-style RETURNING "ID" instead of calling sql.Result.LastInsertID 23 | } 24 | 25 | // MySQL contains database specific options for executing queries in a MySQL database 26 | var MySQL = &Database{ 27 | Quote: "`", 28 | Placeholder: "?", 29 | UseReturningToGetID: false, 30 | } 31 | 32 | // PostgreSQL contains database specific options for executing queries in a PostgreSQL database 33 | var PostgreSQL = &Database{ 34 | Quote: `"`, 35 | Placeholder: "$1", 36 | UseReturningToGetID: true, 37 | } 38 | 39 | // SQLite contains database specific options for executing queries in a SQLite database 40 | var SQLite = &Database{ 41 | Quote: `"`, 42 | Placeholder: "?", 43 | UseReturningToGetID: false, 44 | } 45 | 46 | // Default contains the default database options (which defaults to MySQL) 47 | var Default = MySQL 48 | 49 | func (d *Database) quoted(s string) string { 50 | return d.Quote + s + d.Quote 51 | } 52 | 53 | func (d 
*Database) placeholder(n int) string { 54 | return strings.Replace(d.Placeholder, "1", strconv.FormatInt(int64(n), 10), 1) 55 | } 56 | 57 | // Debug enables debug mode, where unused columns and struct fields will be logged 58 | var Debug = true 59 | 60 | type structField struct { 61 | column string 62 | index int 63 | primaryKey bool 64 | meddler Meddler 65 | } 66 | 67 | type structData struct { 68 | columns []string 69 | fields map[string]*structField 70 | pk string 71 | } 72 | 73 | // cache reflection data 74 | var fieldsCache = make(map[reflect.Type]*structData) 75 | var fieldsCacheMutex sync.Mutex 76 | 77 | // getFields gathers the list of columns from a struct using reflection. 78 | func getFields(dstType reflect.Type) (*structData, error) { 79 | fieldsCacheMutex.Lock() 80 | defer fieldsCacheMutex.Unlock() 81 | 82 | if result, present := fieldsCache[dstType]; present { 83 | return result, nil 84 | } 85 | 86 | // make sure dst is a non-nil pointer to a struct 87 | if dstType.Kind() != reflect.Ptr { 88 | return nil, fmt.Errorf("meddler called with non-pointer destination %v", dstType) 89 | } 90 | structType := dstType.Elem() 91 | if structType.Kind() != reflect.Struct { 92 | return nil, fmt.Errorf("meddler called with pointer to non-struct %v", dstType) 93 | } 94 | 95 | // gather the list of fields in the struct 96 | data := new(structData) 97 | data.fields = make(map[string]*structField) 98 | 99 | for i := 0; i < structType.NumField(); i++ { 100 | f := structType.Field(i) 101 | 102 | // skip non-exported fields 103 | if f.PkgPath != "" { 104 | continue 105 | } 106 | 107 | // examine the tag for metadata 108 | tag := strings.Split(f.Tag.Get(tagName), ",") 109 | 110 | // was this field marked for skipping? 
111 | if len(tag) > 0 && tag[0] == "-" { 112 | continue 113 | } 114 | 115 | // default to the field name 116 | name := f.Name 117 | 118 | // the tag can override the field name 119 | if len(tag) > 0 && tag[0] != "" { 120 | name = tag[0] 121 | } else { 122 | // use mapper func if field has no explicit tag 123 | name = Mapper(f.Name) 124 | } 125 | 126 | // check for a meddler 127 | var meddler Meddler = registry["identity"] 128 | for j := 1; j < len(tag); j++ { 129 | if tag[j] == "pk" { 130 | if f.Type.Kind() == reflect.Ptr { 131 | return nil, fmt.Errorf("meddler found field %s which is marked as the primary key but is a pointer", f.Name) 132 | } 133 | 134 | // make sure it is an int of some kind 135 | switch f.Type.Kind() { 136 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 137 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 138 | default: 139 | return nil, fmt.Errorf("meddler found field %s which is marked as the primary key, but is not an integer type", f.Name) 140 | } 141 | 142 | if data.pk != "" { 143 | return nil, fmt.Errorf("meddler found field %s which is marked as the primary key, but a primary key field was already found", f.Name) 144 | } 145 | data.pk = name 146 | } else if m, present := registry[tag[j]]; present { 147 | meddler = m 148 | } else { 149 | return nil, fmt.Errorf("meddler found field %s with meddler %s, but that meddler is not registered", f.Name, tag[j]) 150 | } 151 | } 152 | 153 | if _, present := data.fields[name]; present { 154 | return nil, fmt.Errorf("meddler found multiple fields for column %s", name) 155 | } 156 | data.fields[name] = &structField{ 157 | column: name, 158 | primaryKey: name == data.pk, 159 | index: i, 160 | meddler: meddler, 161 | } 162 | data.columns = append(data.columns, name) 163 | } 164 | 165 | fieldsCache[dstType] = data 166 | return data, nil 167 | } 168 | 169 | // Columns returns a list of column names for its input struct. 
170 | func (d *Database) Columns(src interface{}, includePk bool) ([]string, error) { 171 | data, err := getFields(reflect.TypeOf(src)) 172 | if err != nil { 173 | return nil, err 174 | } 175 | 176 | var names []string 177 | for _, elt := range data.columns { 178 | if !includePk && elt == data.pk { 179 | continue 180 | } 181 | names = append(names, elt) 182 | } 183 | 184 | return names, nil 185 | } 186 | 187 | // Columns using the Default Database type 188 | func Columns(src interface{}, includePk bool) ([]string, error) { 189 | return Default.Columns(src, includePk) 190 | } 191 | 192 | // ColumnsQuoted is similar to Columns, but it return the list of columns in the form: 193 | // `column1`,`column2`,... 194 | // using Quote as the quote character. 195 | func (d *Database) ColumnsQuoted(src interface{}, includePk bool) (string, error) { 196 | unquoted, err := d.Columns(src, includePk) 197 | if err != nil { 198 | return "", err 199 | } 200 | 201 | var parts []string 202 | for _, elt := range unquoted { 203 | parts = append(parts, d.quoted(elt)) 204 | } 205 | 206 | return strings.Join(parts, ","), nil 207 | } 208 | 209 | // ColumnsQuoted using the Default Database type 210 | func ColumnsQuoted(src interface{}, includePk bool) (string, error) { 211 | return Default.ColumnsQuoted(src, includePk) 212 | } 213 | 214 | // PrimaryKey returns the name and value of the primary key field. The name 215 | // is the empty string if there is not primary key field marked. 
216 | func (d *Database) PrimaryKey(src interface{}) (name string, pk int64, err error) { 217 | data, err := getFields(reflect.TypeOf(src)) 218 | if err != nil { 219 | return "", 0, err 220 | } 221 | 222 | if data.pk == "" { 223 | return "", 0, nil 224 | } 225 | 226 | name = data.pk 227 | field := reflect.ValueOf(src).Elem().Field(data.fields[name].index) 228 | switch field.Type().Kind() { 229 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 230 | pk = field.Int() 231 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 232 | pk = int64(field.Uint()) 233 | default: 234 | return "", 0, fmt.Errorf("meddler found field %s which is marked as the primary key, but is not an integer type", name) 235 | } 236 | 237 | return name, pk, nil 238 | } 239 | 240 | // PrimaryKey using the Default Database type 241 | func PrimaryKey(src interface{}) (name string, pk int64, err error) { 242 | return Default.PrimaryKey(src) 243 | } 244 | 245 | // SetPrimaryKey sets the primary key field to the given int value. 
246 | func (d *Database) SetPrimaryKey(src interface{}, pk int64) error { 247 | data, err := getFields(reflect.TypeOf(src)) 248 | if err != nil { 249 | return err 250 | } 251 | 252 | if data.pk == "" { 253 | return fmt.Errorf("meddler.SetPrimaryKey: no primary key field found") 254 | } 255 | 256 | field := reflect.ValueOf(src).Elem().Field(data.fields[data.pk].index) 257 | switch field.Type().Kind() { 258 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 259 | field.SetInt(pk) 260 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 261 | field.SetUint(uint64(pk)) 262 | default: 263 | return fmt.Errorf("meddler found field %s which is marked as the primary key, but is not an integer type", data.pk) 264 | } 265 | 266 | return nil 267 | } 268 | 269 | // SetPrimaryKey using the Default Database type 270 | func SetPrimaryKey(src interface{}, pk int64) error { 271 | return Default.SetPrimaryKey(src, pk) 272 | } 273 | 274 | // Values returns a list of PreWrite processed values suitable for 275 | // use in an INSERT or UPDATE query. If includePk is false, the primary 276 | // key field is omitted. The columns used are the same ones (in the same 277 | // order) as returned by Columns. 278 | func (d *Database) Values(src interface{}, includePk bool) ([]interface{}, error) { 279 | columns, err := d.Columns(src, includePk) 280 | if err != nil { 281 | return nil, err 282 | } 283 | return d.SomeValues(src, columns) 284 | } 285 | 286 | // Values using the Default Database type 287 | func Values(src interface{}, includePk bool) ([]interface{}, error) { 288 | return Default.Values(src, includePk) 289 | } 290 | 291 | // SomeValues returns a list of PreWrite processed values suitable for 292 | // use in an INSERT or UPDATE query. The columns used are the same ones (in 293 | // the same order) as specified in the columns argument. 
294 | func (d *Database) SomeValues(src interface{}, columns []string) ([]interface{}, error) { 295 | data, err := getFields(reflect.TypeOf(src)) 296 | if err != nil { 297 | return nil, err 298 | } 299 | structVal := reflect.ValueOf(src).Elem() 300 | 301 | var values []interface{} 302 | for _, name := range columns { 303 | field, present := data.fields[name] 304 | if !present { 305 | // write null to the database 306 | values = append(values, nil) 307 | 308 | if Debug { 309 | log.Printf("meddler.SomeValues: column [%s] not found in struct", name) 310 | } 311 | continue 312 | } 313 | 314 | saveVal, err := field.meddler.PreWrite(structVal.Field(field.index).Interface()) 315 | if err != nil { 316 | return nil, fmt.Errorf("meddler.SomeValues: PreWrite error on column [%s]: %v", name, err) 317 | } 318 | values = append(values, saveVal) 319 | } 320 | 321 | return values, nil 322 | } 323 | 324 | // SomeValues using the Default Database type 325 | func SomeValues(src interface{}, columns []string) ([]interface{}, error) { 326 | return Default.SomeValues(src, columns) 327 | } 328 | 329 | // Placeholders returns a list of placeholders suitable for an INSERT or UPDATE query. 330 | // If includePk is false, the primary key field is omitted. 
331 | func (d *Database) Placeholders(src interface{}, includePk bool) ([]string, error) { 332 | data, err := getFields(reflect.TypeOf(src)) 333 | if err != nil { 334 | return nil, err 335 | } 336 | 337 | var placeholders []string 338 | for _, name := range data.columns { 339 | if !includePk && name == data.pk { 340 | continue 341 | } 342 | ph := d.placeholder(len(placeholders) + 1) 343 | placeholders = append(placeholders, ph) 344 | } 345 | 346 | return placeholders, nil 347 | } 348 | 349 | // Placeholders using the Default Database type 350 | func Placeholders(src interface{}, includePk bool) ([]string, error) { 351 | return Default.Placeholders(src, includePk) 352 | } 353 | 354 | // PlaceholdersString returns a list of placeholders suitable for an INSERT 355 | // or UPDATE query in string form, e.g.: 356 | // ?,?,?,? 357 | // if includePk is false, the primary key field is omitted. 358 | func (d *Database) PlaceholdersString(src interface{}, includePk bool) (string, error) { 359 | lst, err := d.Placeholders(src, includePk) 360 | if err != nil { 361 | return "", err 362 | } 363 | return strings.Join(lst, ","), nil 364 | } 365 | 366 | // PlaceholdersString using the Default Database type 367 | func PlaceholdersString(src interface{}, includePk bool) (string, error) { 368 | return Default.PlaceholdersString(src, includePk) 369 | } 370 | 371 | // scan a single row of data into a struct. 
372 | func (d *Database) scanRow(data *structData, rows *sql.Rows, dst interface{}, columns []string) error { 373 | // check if there is data waiting 374 | if !rows.Next() { 375 | if err := rows.Err(); err != nil { 376 | return err 377 | } 378 | return sql.ErrNoRows 379 | } 380 | 381 | // get a list of targets 382 | targets, err := d.Targets(dst, columns) 383 | if err != nil { 384 | return err 385 | } 386 | 387 | // perform the scan 388 | if err := rows.Scan(targets...); err != nil { 389 | return err 390 | } 391 | 392 | // post-process and copy the target values into the struct 393 | if err := d.WriteTargets(dst, columns, targets); err != nil { 394 | return err 395 | } 396 | 397 | return rows.Err() 398 | } 399 | 400 | // Targets returns a list of values suitable for handing to a 401 | // Scan function in the sql package, complete with meddling. After 402 | // the Scan is performed, the same values should be handed to 403 | // WriteTargets to finalize the values and record them in the struct. 
404 | func (d *Database) Targets(dst interface{}, columns []string) ([]interface{}, error) { 405 | data, err := getFields(reflect.TypeOf(dst)) 406 | if err != nil { 407 | return nil, err 408 | } 409 | 410 | structVal := reflect.ValueOf(dst).Elem() 411 | 412 | var targets []interface{} 413 | for _, name := range columns { 414 | if field, present := data.fields[name]; present { 415 | fieldAddr := structVal.Field(field.index).Addr().Interface() 416 | scanTarget, err := field.meddler.PreRead(fieldAddr) 417 | if err != nil { 418 | return nil, fmt.Errorf("meddler.Targets: PreRead error on column %s: %v", name, err) 419 | } 420 | targets = append(targets, scanTarget) 421 | } else { 422 | // no destination, so throw this away 423 | targets = append(targets, new(interface{})) 424 | 425 | if Debug { 426 | log.Printf("meddler.Targets: column [%s] not found in struct", name) 427 | } 428 | } 429 | } 430 | 431 | return targets, nil 432 | } 433 | 434 | // Targets using the Default Database type 435 | func Targets(dst interface{}, columns []string) ([]interface{}, error) { 436 | return Default.Targets(dst, columns) 437 | } 438 | 439 | // WriteTargets post-processes values with meddlers after a Scan from the 440 | // sql package has been performed. The list of targets is normally produced 441 | // by Targets. 
442 | func (d *Database) WriteTargets(dst interface{}, columns []string, targets []interface{}) error { 443 | if len(columns) != len(targets) { 444 | return fmt.Errorf("meddler.WriteTargets: mismatch in number of columns (%d) and targets (%d)", 445 | len(columns), len(targets)) 446 | } 447 | 448 | data, err := getFields(reflect.TypeOf(dst)) 449 | if err != nil { 450 | return err 451 | } 452 | structVal := reflect.ValueOf(dst).Elem() 453 | 454 | for i, name := range columns { 455 | if field, present := data.fields[name]; present { 456 | fieldAddr := structVal.Field(field.index).Addr().Interface() 457 | err := field.meddler.PostRead(fieldAddr, targets[i]) 458 | if err != nil { 459 | return fmt.Errorf("meddler.WriteTargets: PostRead error on column [%s]: %v", name, err) 460 | } 461 | } else { 462 | // not destination, so throw this away 463 | if Debug { 464 | log.Printf("meddler.WriteTargets: column [%s] not found in struct", name) 465 | } 466 | } 467 | } 468 | 469 | return nil 470 | } 471 | 472 | // WriteTargets using the Default Database type 473 | func WriteTargets(dst interface{}, columns []string, targets []interface{}) error { 474 | return Default.WriteTargets(dst, columns, targets) 475 | } 476 | 477 | // Scan scans a single sql result row into a struct. 478 | // It leaves rows ready to be scanned again for the next row. 479 | // Returns sql.ErrNoRows if there is no data to read. 
480 | func (d *Database) Scan(rows *sql.Rows, dst interface{}) error { 481 | // get the list of struct fields 482 | data, err := getFields(reflect.TypeOf(dst)) 483 | if err != nil { 484 | return err 485 | } 486 | 487 | // get the sql columns 488 | columns, err := rows.Columns() 489 | if err != nil { 490 | return err 491 | } 492 | 493 | return d.scanRow(data, rows, dst, columns) 494 | } 495 | 496 | // Scan using the Default Database type 497 | func Scan(rows *sql.Rows, dst interface{}) error { 498 | return Default.Scan(rows, dst) 499 | } 500 | 501 | // ScanRow scans a single sql result row into a struct. 502 | // It reads exactly one result row and closes rows when finished. 503 | // Returns sql.ErrNoRows if there is no result row. 504 | func (d *Database) ScanRow(rows *sql.Rows, dst interface{}) error { 505 | // make sure we always close rows, even if there is a scan error 506 | defer rows.Close() 507 | 508 | if err := d.Scan(rows, dst); err != nil { 509 | return err 510 | } 511 | 512 | return rows.Close() 513 | } 514 | 515 | // ScanRow using the Default Database type 516 | func ScanRow(rows *sql.Rows, dst interface{}) error { 517 | return Default.ScanRow(rows, dst) 518 | } 519 | 520 | // ScanAll scans all sql result rows into a slice of structs. 521 | // It reads all rows and closes rows when finished. 522 | // dst should be a pointer to a slice of the appropriate type. 523 | // The new results will be appended to any existing data in dst. 
524 | func (d *Database) ScanAll(rows *sql.Rows, dst interface{}) error { 525 | // make sure we always close rows 526 | defer rows.Close() 527 | 528 | // make sure dst is an appropriate type 529 | dstVal := reflect.ValueOf(dst) 530 | if dstVal.Kind() != reflect.Ptr || dstVal.IsNil() { 531 | return fmt.Errorf("ScanAll called with non-pointer destination: %T", dst) 532 | } 533 | sliceVal := dstVal.Elem() 534 | if sliceVal.Kind() != reflect.Slice { 535 | return fmt.Errorf("ScanAll called with pointer to non-slice: %T", dst) 536 | } 537 | ptrType := sliceVal.Type().Elem() 538 | if ptrType.Kind() != reflect.Ptr { 539 | return fmt.Errorf("ScanAll expects element to be pointers, found %T", dst) 540 | } 541 | eltType := ptrType.Elem() 542 | if eltType.Kind() != reflect.Struct { 543 | return fmt.Errorf("ScanAll expects element to be pointers to structs, found %T", dst) 544 | } 545 | 546 | // get the list of struct fields 547 | data, err := getFields(ptrType) 548 | if err != nil { 549 | return err 550 | } 551 | 552 | // get the sql columns 553 | columns, err := rows.Columns() 554 | if err != nil { 555 | return err 556 | } 557 | 558 | // gather the results 559 | for { 560 | // create a new element 561 | eltVal := reflect.New(eltType) 562 | elt := eltVal.Interface() 563 | 564 | // scan it 565 | if err := d.scanRow(data, rows, elt, columns); err != nil { 566 | if err == sql.ErrNoRows { 567 | return nil 568 | } 569 | return err 570 | } 571 | 572 | // add to the result slice 573 | sliceVal.Set(reflect.Append(sliceVal, eltVal)) 574 | } 575 | } 576 | 577 | // ScanAll using the Default Database type 578 | func ScanAll(rows *sql.Rows, dst interface{}) error { 579 | return Default.ScanAll(rows, dst) 580 | } 581 | -------------------------------------------------------------------------------- /scan_test.go: -------------------------------------------------------------------------------- 1 | package meddler 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "reflect" 7 | "sort" 8 | 
"strings" 9 | "sync" 10 | "testing" 11 | "time" 12 | 13 | _ "github.com/mattn/go-sqlite3" 14 | ) 15 | 16 | var once sync.Once 17 | var db *sql.DB 18 | var when = time.Date(2013, 6, 23, 15, 30, 12, 0, time.UTC) 19 | 20 | type Person struct { 21 | ID int64 `meddler:"id,pk"` 22 | Name string `meddler:"name"` 23 | private int 24 | Email string 25 | Ephemeral int `meddler:"-"` 26 | Age int `meddler:",zeroisnull"` 27 | Opened time.Time `meddler:"opened,utctime"` 28 | Closed time.Time `meddler:"closed,utctimez"` 29 | Updated *time.Time `meddler:"updated,localtime"` 30 | Height *int `meddler:"height"` 31 | } 32 | 33 | type HalfPerson struct { 34 | ID int64 `meddler:"id,pk"` 35 | private int 36 | Ephemeral int `meddler:"-"` 37 | Age int `meddler:",zeroisnull"` 38 | Closed time.Time `meddler:"closed,utctimez"` 39 | Updated *time.Time `meddler:"updated,localtime"` 40 | } 41 | 42 | type UintPerson struct { 43 | ID uint64 `meddler:"id,pk"` 44 | Name string `meddler:"name"` 45 | private int 46 | Email string 47 | Ephemeral int `meddler:"-"` 48 | Age int `meddler:",zeroisnull"` 49 | Opened time.Time `meddler:"opened,utctime"` 50 | Closed time.Time `meddler:"closed,utctimez"` 51 | Updated *time.Time `meddler:"updated,localtime"` 52 | Height *int `meddler:"height"` 53 | } 54 | 55 | const schema1 = `create table person ( 56 | id integer primary key, 57 | name text not null, 58 | Email text not null, 59 | Age integer, 60 | opened datetime not null, 61 | closed datetime, 62 | updated datetime, 63 | height integer 64 | )` 65 | 66 | const schema2 = `create table item ( 67 | id integer primary key, 68 | stuff text not null, 69 | stuffz blob not null 70 | )` 71 | 72 | const schema3 = `create table null_item ( 73 | id integer primary key, 74 | nullint integer null, 75 | nullstring text null, 76 | nullcomplex real null, 77 | nullfloat real null, 78 | nullbool integer null 79 | )` 80 | 81 | var aliceHeight int = 65 82 | var alice = &Person{ 83 | Name: "Alice", 84 | Email: "alice@alice.com", 
85 | Ephemeral: 12, 86 | Age: 32, 87 | Opened: when.Local(), 88 | Closed: when, 89 | Updated: &when, 90 | Height: &aliceHeight, 91 | } 92 | 93 | var bob = &Person{ 94 | Name: "Bob", 95 | Email: "bob@bob.com", 96 | Opened: when, 97 | } 98 | 99 | func setup() { 100 | var err error 101 | 102 | // create the database 103 | db, err = sql.Open("sqlite3", ":memory:") 104 | if err != nil { 105 | panic("error creating test database: " + err.Error()) 106 | } 107 | 108 | // create the tables 109 | if _, err = db.Exec(schema1); err != nil { 110 | panic("error creating person table: " + err.Error()) 111 | } 112 | if _, err = db.Exec(schema2); err != nil { 113 | panic("error creating item table: " + err.Error()) 114 | } 115 | if _, err = db.Exec(schema3); err != nil { 116 | panic("error creating null_item table: " + err.Error()) 117 | } 118 | 119 | } 120 | 121 | func structFieldEqual(t *testing.T, elt *structField, ref *structField) { 122 | if elt == nil { 123 | t.Errorf("Missing field for %s", ref.column) 124 | return 125 | } 126 | if elt.column != ref.column { 127 | t.Errorf("Column %s column found as %v", ref.column, elt.column) 128 | } 129 | if elt.primaryKey != ref.primaryKey { 130 | t.Errorf("Column %s primaryKey found as %v", ref.column, elt.primaryKey) 131 | } 132 | if elt.index != ref.index { 133 | t.Errorf("Column %s index found as %v", ref.column, elt.index) 134 | } 135 | if elt.meddler != ref.meddler { 136 | t.Errorf("Column %s meddler mismatch", ref.column) 137 | } 138 | } 139 | 140 | func TestGetFields(t *testing.T) { 141 | data, err := getFields(reflect.TypeOf((*Person)(nil))) 142 | if err != nil { 143 | t.Errorf("Error in getFields: %v", err) 144 | return 145 | } 146 | 147 | // see if everything checks out 148 | if len(data.fields) != 8 || len(data.columns) != 8 { 149 | t.Errorf("Found %d/%d fields, expected 8", len(data.fields), len(data.columns)) 150 | } 151 | structFieldEqual(t, data.fields[data.columns[0]], &structField{"id", 0, true, registry["identity"]}) 
152 | structFieldEqual(t, data.fields[data.columns[1]], &structField{"name", 1, false, registry["identity"]}) 153 | structFieldEqual(t, data.fields[data.columns[2]], &structField{"Email", 3, false, registry["identity"]}) 154 | structFieldEqual(t, data.fields[data.columns[3]], &structField{"Age", 5, false, registry["zeroisnull"]}) 155 | structFieldEqual(t, data.fields[data.columns[4]], &structField{"opened", 6, false, registry["utctime"]}) 156 | structFieldEqual(t, data.fields[data.columns[5]], &structField{"closed", 7, false, registry["utctimez"]}) 157 | structFieldEqual(t, data.fields[data.columns[6]], &structField{"updated", 8, false, registry["localtime"]}) 158 | structFieldEqual(t, data.fields[data.columns[7]], &structField{"height", 9, false, registry["identity"]}) 159 | 160 | // test with non-pointer 161 | if _, err := getFields(reflect.TypeOf(*alice)); err == nil { 162 | t.Errorf("calling getFields with non-pointer type should return err, got nil") 163 | } 164 | 165 | // test with pointer to non-struct 166 | s := "foo" 167 | if _, err := getFields(reflect.TypeOf(&s)); err == nil { 168 | t.Errorf("calling getFields with pointer to non-struct should return err, got nil") 169 | } 170 | 171 | // test with pointer as PK 172 | type personPointerPK struct { 173 | ID *int `meddler:",pk"` 174 | } 175 | if _, err := getFields(reflect.TypeOf((*personPointerPK)(nil))); err == nil { 176 | t.Errorf("calling getFields with pointer as primary key should return err, got nil") 177 | } 178 | 179 | // test with struct as PK 180 | type personStructPK struct { 181 | ID Person `meddler:",pk"` 182 | } 183 | if _, err := getFields(reflect.TypeOf((*personStructPK)(nil))); err == nil { 184 | t.Errorf("calling getFields with struct as primary key should return err, got nil") 185 | } 186 | 187 | // test with duplicate column name 188 | type personDuplicateColumn struct { 189 | ID int `meddler:"id,pk"` 190 | Foo1 string `meddler:"foo"` 191 | Foo2 string `meddler:"foo"` 192 | } 193 | if 
_, err := getFields(reflect.TypeOf((*personDuplicateColumn)(nil))); err == nil { 194 | t.Errorf("calling getFields with duplicated column name should return err, got nil") 195 | } 196 | 197 | // test with unexisting meddler 198 | type personUnexistingMeddler struct { 199 | ID int `meddler:"id,pk"` 200 | Foo string `meddler:"foo,bar"` 201 | } 202 | if _, err := getFields(reflect.TypeOf((*personUnexistingMeddler)(nil))); err == nil { 203 | t.Errorf("calling getFields with unexisting meddler should return err, got nil") 204 | } 205 | 206 | } 207 | 208 | func personEqual(t *testing.T, elt *Person, ref *Person) { 209 | if elt == nil { 210 | t.Errorf("Person %s is nil", ref.Name) 211 | return 212 | } 213 | if elt.ID != ref.ID { 214 | t.Errorf("Person %s ID is %v", ref.Name, elt.ID) 215 | } 216 | if elt.Name != ref.Name { 217 | t.Errorf("Person %s Name is %v", ref.Name, elt.Name) 218 | } 219 | if elt.private != ref.private { 220 | t.Errorf("Person %s private is %v", ref.Name, elt.private) 221 | } 222 | if elt.Email != ref.Email { 223 | t.Errorf("Person %s Email is %v", ref.Name, elt.Email) 224 | } 225 | if elt.Ephemeral != ref.Ephemeral { 226 | t.Errorf("Person %d Ephemeral is %d", ref.Ephemeral, elt.Ephemeral) 227 | } 228 | if elt.Age != ref.Age { 229 | t.Errorf("Person %s Age is %v", ref.Name, elt.Age) 230 | } 231 | if !elt.Opened.Equal(ref.Opened) { 232 | t.Errorf("Person %s Opened is %v", ref.Name, elt.Opened) 233 | } 234 | if !elt.Closed.Equal(ref.Closed) { 235 | t.Errorf("Person %s Closed is %v", ref.Name, elt.Closed) 236 | } 237 | if (elt.Updated == nil) != (ref.Updated == nil) { 238 | t.Errorf("Person %s Updated == nil is %v", ref.Name, elt.Updated == nil) 239 | } else if elt.Updated != nil && !elt.Updated.Equal(*ref.Updated) { 240 | t.Errorf("Person %s Updated is %v", ref.Name, *elt.Updated) 241 | } 242 | if elt.Updated != nil { 243 | zone, _ := elt.Updated.Zone() 244 | local, _ := when.Local().Zone() 245 | if zone != local { 246 | t.Errorf("Person %s Updated in 
time zone %v, expected %v", ref.Name, zone, local) 247 | } 248 | } 249 | if (elt.Height == nil) != (ref.Height == nil) { 250 | t.Errorf("Person %s Height == nil is %v", ref.Name, elt.Height == nil) 251 | } else if elt.Height != nil && *elt.Height != *ref.Height { 252 | t.Errorf("Person %s Height is %v", ref.Name, *elt.Height) 253 | } 254 | } 255 | 256 | func insertAliceBob(t *testing.T) { 257 | // insert Alice as row #1 258 | alice.ID = 0 259 | if err := Insert(db, "person", alice); err != nil { 260 | t.Errorf("Error inserting Alice: %v", err) 261 | } 262 | if alice.ID != 1 { 263 | t.Errorf("Alice ID is %d, expecting 1", alice.ID) 264 | } 265 | 266 | // insert Bob as row #2 267 | bob.ID = 0 268 | if err := Insert(db, "person", bob); err != nil { 269 | t.Errorf("Error inserting Bob: %v", err) 270 | } 271 | if bob.ID != 2 { 272 | t.Errorf("Bob ID is %d, expecting 2", bob.ID) 273 | } 274 | } 275 | 276 | func TestColumns(t *testing.T) { 277 | once.Do(setup) 278 | 279 | p := new(Person) 280 | names, err := Columns(p, true) 281 | if err != nil { 282 | t.Errorf("Error getting Columns: %v", err) 283 | } 284 | 285 | expected := []string{"id", "name", "Email", "Age", "opened", "closed", "updated", "height"} 286 | sort.Strings(expected) 287 | 288 | if len(names) != len(expected) { 289 | t.Errorf("Expected %d columns, got %d", len(expected), len(names)) 290 | } 291 | sort.Strings(names) 292 | for i := 0; i < len(expected); i++ { 293 | if expected[i] != names[i] { 294 | t.Errorf("Expected %s at position %d, got %s", expected[i], i, names[i]) 295 | } 296 | } 297 | 298 | } 299 | 300 | func TestColumnsQuoted(t *testing.T) { 301 | once.Do(setup) 302 | 303 | p := new(Person) 304 | names, err := ColumnsQuoted(p, true) 305 | if err != nil { 306 | t.Errorf("Error getting ColumnsQuoted: %v", err) 307 | } 308 | 309 | lst := []string{"id", "name", "Email", "Age", "opened", "closed", "updated", "height"} 310 | sort.Strings(lst) 311 | for i, orig := range lst { 312 | lst[i] = 
Default.quoted(orig) 313 | } 314 | expected := strings.Join(lst, ",") 315 | 316 | if len(names) != len(expected) { 317 | t.Errorf("Length mismatch: expected %d, got %d", len(expected), len(names)) 318 | } 319 | 320 | fields := strings.Split(names, ",") 321 | sort.Strings(fields) 322 | names = strings.Join(fields, ",") 323 | 324 | if expected != names { 325 | t.Errorf("Mismatch: expected %s, got %s", expected, names) 326 | } 327 | } 328 | 329 | func TestPrimaryKey(t *testing.T) { 330 | p := new(Person) 331 | p.ID = 56 332 | name, val, err := PrimaryKey(p) 333 | if err != nil { 334 | t.Errorf("Error getting PrimaryKey: %v", err) 335 | } 336 | if name != "id" { 337 | t.Errorf("Expected pk name to be id, found %s", name) 338 | } 339 | if val != 56 { 340 | t.Errorf("Expected pk value to be 56, found %d", val) 341 | } 342 | 343 | p2 := new(UintPerson) 344 | p2.ID = 56 345 | name, val, err = PrimaryKey(p2) 346 | if err != nil { 347 | t.Errorf("Error getting PrimaryKey: %v", err) 348 | } 349 | if name != "id" { 350 | t.Errorf("Expected pk name to be id, found %s", name) 351 | } 352 | if val != 56 { 353 | t.Errorf("Expected pk value to be 56, found %d", val) 354 | } 355 | } 356 | 357 | func TestSetPrimaryKey(t *testing.T) { 358 | p := new(Person) 359 | err := SetPrimaryKey(p, 14) 360 | if err != nil { 361 | t.Errorf("Error in SetPrimaryKey: %v", err) 362 | } 363 | if p.ID != 14 { 364 | t.Errorf("Expected id to be 14, found %d", p.ID) 365 | } 366 | 367 | p2 := new(Person) 368 | err = SetPrimaryKey(p2, 14) 369 | if err != nil { 370 | t.Errorf("Error in SetPrimaryKey: %v", err) 371 | } 372 | if p2.ID != 14 { 373 | t.Errorf("Expected id to be 14, found %d", p2.ID) 374 | } 375 | } 376 | 377 | func TestValues(t *testing.T) { 378 | alice.ID = 15 379 | lst, err := Values(alice, true) 380 | if err != nil { 381 | t.Errorf("Values error: %v", err) 382 | } 383 | 384 | if lst[0] != int64(15) { 385 | t.Errorf("expected 15, got %v", lst[0]) 386 | } 387 | if lst[1] != "Alice" { 388 | 
t.Errorf("Expected Alice, got %v", lst[1]) 389 | } 390 | if lst[2] != "alice@alice.com" { 391 | t.Errorf("Expected alice@alice.com, got %v", lst[2]) 392 | } 393 | if lst[3] != 32 { 394 | t.Errorf("Expected 32, got %v", lst[3]) 395 | } 396 | if lst[4] != when.UTC() { 397 | t.Errorf("Expected %v, got %v", when.UTC(), lst[4]) 398 | } 399 | if lst[5] != when.UTC() { 400 | t.Errorf("Expected %v, got %v", when.UTC(), lst[5]) 401 | } 402 | if lst[6] != when.UTC() { 403 | t.Errorf("Expected %v, got %v", when.UTC(), lst[6]) 404 | } 405 | if *(lst[7].(*int)) != aliceHeight { 406 | t.Errorf("Expected %d, got %v", aliceHeight, lst[7]) 407 | } 408 | 409 | lst, err = Values(alice, false) 410 | if err != nil { 411 | t.Errorf("Values error: %v", err) 412 | } 413 | if lst[0] != "Alice" { 414 | t.Errorf("Expected Alice, got %v", lst[0]) 415 | } 416 | } 417 | 418 | func TestPlaceholders(t *testing.T) { 419 | lst, err := MySQL.Placeholders(alice, true) 420 | if err != nil { 421 | t.Errorf("Error in Placeholders: %v", err) 422 | } 423 | if len(lst) != 8 { 424 | t.Errorf("expected 8 items, found %d", len(lst)) 425 | } 426 | for _, elt := range lst { 427 | if elt != MySQL.Placeholder { 428 | t.Errorf("expected %s, found %s", MySQL.Placeholder, elt) 429 | } 430 | } 431 | 432 | lst, err = PostgreSQL.Placeholders(alice, false) 433 | if err != nil { 434 | t.Errorf("Error in Placeholders: %v", err) 435 | } 436 | if len(lst) != 7 { 437 | t.Errorf("expected 7 items, found %d", len(lst)) 438 | } 439 | for i, elt := range lst { 440 | expected := fmt.Sprintf("$%d", i+1) 441 | if expected != elt { 442 | t.Errorf("expected %s, found %s", expected, elt) 443 | } 444 | } 445 | } 446 | 447 | func TestPlaceholdersString(t *testing.T) { 448 | s, err := SQLite.PlaceholdersString(alice, false) 449 | if err != nil { 450 | t.Errorf("Error in PlaceholdersString: %v", err) 451 | } 452 | expected := "?,?,?,?,?,?,?" 
453 | if s != expected { 454 | t.Errorf("expected %s, found %s", expected, s) 455 | } 456 | 457 | s, err = PostgreSQL.PlaceholdersString(alice, true) 458 | if err != nil { 459 | t.Errorf("Error in PlaceholdersString: %v", err) 460 | } 461 | expected = "$1,$2,$3,$4,$5,$6,$7,$8" 462 | if s != expected { 463 | t.Errorf("expected %s, found %s", expected, s) 464 | } 465 | } 466 | 467 | func TestScanRow(t *testing.T) { 468 | once.Do(setup) 469 | insertAliceBob(t) 470 | 471 | rows, err := db.Query("select * from person where id in (1,2) order by id") 472 | if err != nil { 473 | t.Errorf("DB error on query: %v", err) 474 | return 475 | } 476 | 477 | alice := new(Person) 478 | if err = Scan(rows, alice); err != nil { 479 | t.Errorf("Scan error on Alice: %v", err) 480 | return 481 | } 482 | 483 | bob := new(Person) 484 | bob.Age = 50 485 | bob.Closed = time.Now() 486 | bob.private = 14 487 | bob.Ephemeral = 16 488 | if err = ScanRow(rows, bob); err != nil { 489 | t.Errorf("ScanRow error on Bob: %v", err) 490 | return 491 | } 492 | 493 | height := 65 494 | personEqual(t, alice, &Person{1, "Alice", 0, "alice@alice.com", 0, 32, when, when, &when, &height}) 495 | personEqual(t, bob, &Person{2, "Bob", 14, "bob@bob.com", 16, 0, when, time.Time{}, nil, nil}) 496 | db.Exec("delete from person") 497 | } 498 | 499 | func TestScanAll(t *testing.T) { 500 | once.Do(setup) 501 | insertAliceBob(t) 502 | 503 | rows, err := db.Query("select * from person order by id") 504 | if err != nil { 505 | t.Errorf("DB error on query: %v", err) 506 | return 507 | } 508 | 509 | var lst []*Person 510 | if err = ScanAll(rows, &lst); err != nil { 511 | t.Errorf("ScanAll error: %v", err) 512 | return 513 | } 514 | 515 | if len(lst) != 2 { 516 | t.Errorf("ScanAll found %d rows, expected 2", len(lst)) 517 | return 518 | } 519 | 520 | height := 65 521 | personEqual(t, lst[0], &Person{1, "Alice", 0, "alice@alice.com", 0, 32, when, when, &when, &height}) 522 | personEqual(t, lst[1], &Person{2, "Bob", 0, 
"bob@bob.com", 0, 0, when, time.Time{}, nil, nil}) 523 | db.Exec("delete from person") 524 | } 525 | 526 | func TestThrowAway(t *testing.T) { 527 | once.Do(setup) 528 | insertAliceBob(t) 529 | 530 | Debug = false 531 | hp := new(HalfPerson) 532 | err := QueryRow(db, hp, "select * from person where id = 1") 533 | if err != nil { 534 | t.Errorf("QueryRow error: %v", err) 535 | } 536 | Debug = true 537 | db.Exec("delete from person") 538 | } 539 | --------------------------------------------------------------------------------