├── .circleci └── config.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── bind.go ├── bind_test.go ├── doc.go ├── go.mod ├── go.sum ├── named.go ├── named_context.go ├── named_context_test.go ├── named_test.go ├── reflectx ├── README.md ├── reflect.go └── reflect_test.go ├── sqlx.go ├── sqlx_context.go ├── sqlx_context_test.go ├── sqlx_test.go └── types ├── README.md ├── doc.go ├── types.go └── types_test.go /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | 3 | "-": &go-versions 4 | [ "1.18.10", "1.19.13", "1.20.14", "1.21.9", "1.22.2" ] 5 | 6 | executors: 7 | go_executor: 8 | parameters: 9 | version: 10 | type: string 11 | docker: 12 | - image: cimg/go:<< parameters.version >> 13 | 14 | jobs: 15 | test: 16 | parameters: 17 | go_version: 18 | type: string 19 | executor: 20 | name: go_executor 21 | version: << parameters.go_version >> 22 | steps: 23 | - checkout 24 | - restore_cache: 25 | keys: 26 | - go-mod-v4-{{ checksum "go.sum" }} 27 | - run: 28 | name: Install Dependencies 29 | command: go mod download 30 | - save_cache: 31 | key: go-mod-v4-{{ checksum "go.sum" }} 32 | paths: 33 | - "/go/pkg/mod" 34 | - run: 35 | name: Run tests 36 | command: | 37 | mkdir -p /tmp/test-reports 38 | gotestsum --junitfile /tmp/test-reports/unit-tests.xml 39 | - store_test_results: 40 | path: /tmp/test-reports 41 | test-race: 42 | parameters: 43 | go_version: 44 | type: string 45 | executor: 46 | name: go_executor 47 | version: << parameters.go_version >> 48 | steps: 49 | - checkout 50 | - restore_cache: 51 | keys: 52 | - go-mod-v4-{{ checksum "go.sum" }} 53 | - run: 54 | name: Install Dependencies 55 | command: go mod download 56 | - save_cache: 57 | key: go-mod-v4-{{ checksum "go.sum" }} 58 | paths: 59 | - "/go/pkg/mod" 60 | - run: 61 | name: Run tests with race detector 62 | command: make test-race 63 | lint: 64 | parameters: 65 | go_version: 66 | type: string 67 | executor: 68 | name: 
go_executor 69 | version: << parameters.go_version >> 70 | steps: 71 | - checkout 72 | - restore_cache: 73 | keys: 74 | - go-mod-v4-{{ checksum "go.sum" }} 75 | - run: 76 | name: Install Dependencies 77 | command: go mod download 78 | - run: 79 | name: Install tooling 80 | command: | 81 | make tooling 82 | - save_cache: 83 | key: go-mod-v4-{{ checksum "go.sum" }} 84 | paths: 85 | - "/go/pkg/mod" 86 | - run: 87 | name: Linting 88 | command: make lint 89 | - run: 90 | name: Running vulncheck 91 | command: make vuln-check 92 | fmt: 93 | parameters: 94 | go_version: 95 | type: string 96 | executor: 97 | name: go_executor 98 | version: << parameters.go_version >> 99 | steps: 100 | - checkout 101 | - restore_cache: 102 | keys: 103 | - go-mod-v4-{{ checksum "go.sum" }} 104 | - run: 105 | name: Install Dependencies 106 | command: go mod download 107 | - run: 108 | name: Install tooling 109 | command: | 110 | make tooling 111 | - save_cache: 112 | key: go-mod-v4-{{ checksum "go.sum" }} 113 | paths: 114 | - "/go/pkg/mod" 115 | - run: 116 | name: Running formatting 117 | command: | 118 | make fmt 119 | make has-changes 120 | 121 | workflows: 122 | version: 2 123 | build-and-test: 124 | jobs: 125 | - test: 126 | matrix: 127 | parameters: 128 | go_version: *go-versions 129 | - test-race: 130 | matrix: 131 | parameters: 132 | go_version: *go-versions 133 | - lint: 134 | matrix: 135 | parameters: 136 | go_version: *go-versions 137 | - fmt: 138 | matrix: 139 | parameters: 140 | go_version: *go-versions 141 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | .idea 10 | 11 | # Architecture specific extensions/prefixes 12 | *.[568vq] 13 | [568vq].out 14 | 15 | *.cgo1.go 16 | *.cgo2.c 17 | _cgo_defun.c 18 | _cgo_gotypes.go 19 | 
_cgo_export.* 20 | 21 | _testmain.go 22 | 23 | *.exe 24 | tags 25 | environ 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2013, Jason Moiron 2 | 3 | Permission is hereby granted, free of charge, to any person 4 | obtaining a copy of this software and associated documentation 5 | files (the "Software"), to deal in the Software without 6 | restriction, including without limitation the rights to use, 7 | copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | copies of the Software, and to permit persons to whom the 9 | Software is furnished to do so, subject to the following 10 | conditions: 11 | 12 | The above copyright notice and this permission notice shall be 13 | included in all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 16 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 17 | OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 18 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 19 | HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 20 | WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 23 | 24 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .ONESHELL: 2 | SHELL = /bin/sh 3 | .SHELLFLAGS = -ec 4 | 5 | BASE_PACKAGE := github.com/jmoiron/sqlx 6 | 7 | tooling: 8 | go install honnef.co/go/tools/cmd/staticcheck@v0.4.7 9 | go install golang.org/x/vuln/cmd/govulncheck@v1.0.4 10 | go install golang.org/x/tools/cmd/goimports@v0.20.0 11 | 12 | has-changes: 13 | git diff --exit-code --quiet HEAD -- 14 | 15 | lint: 16 | go vet ./... 
17 | staticcheck -checks=all ./... 18 | 19 | fmt: 20 | go list -f '{{.Dir}}' ./... | xargs -I {} goimports -local $(BASE_PACKAGE) -w {} 21 | 22 | vuln-check: 23 | govulncheck ./... 24 | 25 | test-race: 26 | go test -v -race -count=1 ./... 27 | 28 | update-dependencies: 29 | go get -u -t -v ./... 30 | go mod tidy 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sqlx 2 | 3 | [![CircleCI](https://dl.circleci.com/status-badge/img/gh/jmoiron/sqlx/tree/master.svg?style=shield)](https://dl.circleci.com/status-badge/redirect/gh/jmoiron/sqlx/tree/master) [![Coverage Status](https://coveralls.io/repos/github/jmoiron/sqlx/badge.svg?branch=master)](https://coveralls.io/github/jmoiron/sqlx?branch=master) [![Godoc](http://img.shields.io/badge/godoc-reference-blue.svg?style=flat)](https://godoc.org/github.com/jmoiron/sqlx) [![license](http://img.shields.io/badge/license-MIT-red.svg?style=flat)](https://raw.githubusercontent.com/jmoiron/sqlx/master/LICENSE) 4 | 5 | sqlx is a library which provides a set of extensions on Go's standard 6 | `database/sql` library. The sqlx versions of `sql.DB`, `sql.Tx`, `sql.Stmt`, 7 | et al. all leave the underlying interfaces untouched, so that their interfaces 8 | are a superset of the standard ones. This makes it relatively painless to 9 | integrate existing codebases using database/sql with sqlx. 10 | 11 | Major additional concepts are: 12 | 13 | * Marshal rows into structs (with embedded struct support), maps, and slices 14 | * Named parameter support including prepared statements 15 | * `Get` and `Select` to go quickly from query to struct/slice 16 | 17 | In addition to the [godoc API documentation](http://godoc.org/github.com/jmoiron/sqlx), 18 | there is also some [user documentation](http://jmoiron.github.io/sqlx/) that 19 | explains how to use `database/sql` along with sqlx.
20 | 21 | ## Recent Changes 22 | 23 | 1.3.0: 24 | 25 | * `sqlx.DB.Connx(context.Context) *sqlx.Conn` 26 | * `sqlx.BindDriver(driverName, bindType)` 27 | * support for `[]map[string]interface{}` to do "batch" insertions 28 | * allocation & perf improvements for `sqlx.In` 29 | 30 | DB.Connx returns an `sqlx.Conn`, which is an `sql.Conn`-alike consistent with 31 | sqlx's wrapping of other types. 32 | 33 | `BindDriver` allows users to control the bindvars that sqlx will use for drivers, 34 | and add new drivers at runtime. This results in a very slight performance hit 35 | when resolving the driver into a bind type (~40ns per call), but it allows users 36 | to specify what bindtype their driver uses even when sqlx has not been updated 37 | to know about it by default. 38 | 39 | ### Backwards Compatibility 40 | 41 | Compatibility with the most recent two versions of Go is a requirement for any 42 | new changes. Compatibility beyond that is not guaranteed. 43 | 44 | Versioning is done with Go modules. Breaking changes (e.g. removing a deprecated API) 45 | will get major version number bumps. 46 | 47 | ## install 48 | 49 | go get github.com/jmoiron/sqlx 50 | 51 | ## issues 52 | 53 | Row headers can be ambiguous (`SELECT 1 AS a, 2 AS a`), and the result of 54 | `Columns()` does not fully qualify column names in queries like: 55 | 56 | ```sql 57 | SELECT a.id, a.name, b.id, b.name FROM foos AS a JOIN foos AS b ON a.parent = b.id; 58 | ``` 59 | 60 | making a struct or map destination ambiguous. Use `AS` in your queries 61 | to give columns distinct names, `rows.Scan` to scan them manually, or 62 | `SliceScan` to get a slice of results. 63 | 64 | ## usage 65 | 66 | Below is an example which shows some common use cases for sqlx. Check 67 | [sqlx_test.go](https://github.com/jmoiron/sqlx/blob/master/sqlx_test.go) for more 68 | usage.
69 | 70 | 71 | ```go 72 | package main 73 | 74 | import ( 75 | "database/sql" 76 | "fmt" 77 | "log" 78 | 79 | _ "github.com/lib/pq" 80 | "github.com/jmoiron/sqlx" 81 | ) 82 | 83 | var schema = ` 84 | CREATE TABLE person ( 85 | first_name text, 86 | last_name text, 87 | email text 88 | ); 89 | 90 | CREATE TABLE place ( 91 | country text, 92 | city text NULL, 93 | telcode integer 94 | )` 95 | 96 | type Person struct { 97 | FirstName string `db:"first_name"` 98 | LastName string `db:"last_name"` 99 | Email string 100 | } 101 | 102 | type Place struct { 103 | Country string 104 | City sql.NullString 105 | TelCode int 106 | } 107 | 108 | func main() { 109 | // this Pings the database trying to connect 110 | // use sqlx.Open() for sql.Open() semantics 111 | db, err := sqlx.Connect("postgres", "user=foo dbname=bar sslmode=disable") 112 | if err != nil { 113 | log.Fatalln(err) 114 | } 115 | 116 | // exec the schema or fail; multi-statement Exec behavior varies between 117 | // database drivers; pq will exec them all, sqlite3 won't, ymmv 118 | db.MustExec(schema) 119 | 120 | tx := db.MustBegin() 121 | tx.MustExec("INSERT INTO person (first_name, last_name, email) VALUES ($1, $2, $3)", "Jason", "Moiron", "jmoiron@jmoiron.net") 122 | tx.MustExec("INSERT INTO person (first_name, last_name, email) VALUES ($1, $2, $3)", "John", "Doe", "johndoeDNE@gmail.net") 123 | tx.MustExec("INSERT INTO place (country, city, telcode) VALUES ($1, $2, $3)", "United States", "New York", "1") 124 | tx.MustExec("INSERT INTO place (country, telcode) VALUES ($1, $2)", "Hong Kong", "852") 125 | tx.MustExec("INSERT INTO place (country, telcode) VALUES ($1, $2)", "Singapore", "65") 126 | // Named queries can use structs, so if you have an existing struct (i.e. 
person := &Person{}) that you have populated, you can pass it in as &person 127 | tx.NamedExec("INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)", &Person{"Jane", "Citizen", "jane.citzen@example.com"}) 128 | tx.Commit() 129 | 130 | // Query the database, storing results in a []Person (wrapped in []interface{}) 131 | people := []Person{} 132 | db.Select(&people, "SELECT * FROM person ORDER BY first_name ASC") 133 | jason, john := people[0], people[1] 134 | 135 | fmt.Printf("%#v\n%#v", jason, john) 136 | // Person{FirstName:"Jason", LastName:"Moiron", Email:"jmoiron@jmoiron.net"} 137 | // Person{FirstName:"John", LastName:"Doe", Email:"johndoeDNE@gmail.net"} 138 | 139 | // You can also get a single result, a la QueryRow 140 | jason = Person{} 141 | err = db.Get(&jason, "SELECT * FROM person WHERE first_name=$1", "Jason") 142 | fmt.Printf("%#v\n", jason) 143 | // Person{FirstName:"Jason", LastName:"Moiron", Email:"jmoiron@jmoiron.net"} 144 | 145 | // if you have null fields and use SELECT *, you must use sql.Null* in your struct 146 | places := []Place{} 147 | err = db.Select(&places, "SELECT * FROM place ORDER BY telcode ASC") 148 | if err != nil { 149 | fmt.Println(err) 150 | return 151 | } 152 | usa, singsing, honkers := places[0], places[1], places[2] 153 | 154 | fmt.Printf("%#v\n%#v\n%#v\n", usa, singsing, honkers) 155 | // Place{Country:"United States", City:sql.NullString{String:"New York", Valid:true}, TelCode:1} 156 | // Place{Country:"Singapore", City:sql.NullString{String:"", Valid:false}, TelCode:65} 157 | // Place{Country:"Hong Kong", City:sql.NullString{String:"", Valid:false}, TelCode:852} 158 | 159 | // Loop through rows using only one struct 160 | place := Place{} 161 | rows, err := db.Queryx("SELECT * FROM place") 162 | for rows.Next() { 163 | err := rows.StructScan(&place) 164 | if err != nil { 165 | log.Fatalln(err) 166 | } 167 | fmt.Printf("%#v\n", place) 168 | } 169 | // Place{Country:"United States", 
City:sql.NullString{String:"New York", Valid:true}, TelCode:1} 170 | // Place{Country:"Hong Kong", City:sql.NullString{String:"", Valid:false}, TelCode:852} 171 | // Place{Country:"Singapore", City:sql.NullString{String:"", Valid:false}, TelCode:65} 172 | 173 | // Named queries, using `:name` as the bindvar. Automatic bindvar support 174 | // which takes into account the dbtype based on the driverName on sqlx.Open/Connect 175 | _, err = db.NamedExec(`INSERT INTO person (first_name,last_name,email) VALUES (:first,:last,:email)`, 176 | map[string]interface{}{ 177 | "first": "Bin", 178 | "last": "Smuth", 179 | "email": "bensmith@allblacks.nz", 180 | }) 181 | 182 | // Selects Mr. Smith from the database 183 | rows, err = db.NamedQuery(`SELECT * FROM person WHERE first_name=:fn`, map[string]interface{}{"fn": "Bin"}) 184 | 185 | // Named queries can also use structs. Their bind names follow the same rules 186 | // as the name -> db mapping, so struct fields are lowercased and the `db` tag 187 | // is taken into consideration. 
188 | rows, err = db.NamedQuery(`SELECT * FROM person WHERE first_name=:first_name`, jason) 189 | 190 | 191 | // batch insert 192 | 193 | // batch insert with structs 194 | personStructs := []Person{ 195 | {FirstName: "Ardie", LastName: "Savea", Email: "asavea@ab.co.nz"}, 196 | {FirstName: "Sonny Bill", LastName: "Williams", Email: "sbw@ab.co.nz"}, 197 | {FirstName: "Ngani", LastName: "Laumape", Email: "nlaumape@ab.co.nz"}, 198 | } 199 | 200 | _, err = db.NamedExec(`INSERT INTO person (first_name, last_name, email) 201 | VALUES (:first_name, :last_name, :email)`, personStructs) 202 | 203 | // batch insert with maps 204 | personMaps := []map[string]interface{}{ 205 | {"first_name": "Ardie", "last_name": "Savea", "email": "asavea@ab.co.nz"}, 206 | {"first_name": "Sonny Bill", "last_name": "Williams", "email": "sbw@ab.co.nz"}, 207 | {"first_name": "Ngani", "last_name": "Laumape", "email": "nlaumape@ab.co.nz"}, 208 | } 209 | 210 | _, err = db.NamedExec(`INSERT INTO person (first_name, last_name, email) 211 | VALUES (:first_name, :last_name, :email)`, personMaps) 212 | } 213 | ``` 214 | -------------------------------------------------------------------------------- /bind.go: -------------------------------------------------------------------------------- 1 | package sqlx 2 | 3 | import ( 4 | "bytes" 5 | "database/sql/driver" 6 | "errors" 7 | "reflect" 8 | "strconv" 9 | "strings" 10 | "sync" 11 | 12 | "github.com/jmoiron/sqlx/reflectx" 13 | ) 14 | 15 | // Bindvar types supported by Rebind, BindMap and BindStruct. 
16 | const ( 17 | UNKNOWN = iota 18 | QUESTION 19 | DOLLAR 20 | NAMED 21 | AT 22 | ) 23 | 24 | var defaultBinds = map[int][]string{ 25 | DOLLAR: {"postgres", "pgx", "pq-timeouts", "cloudsqlpostgres", "ql", "nrpostgres", "cockroach"}, 26 | QUESTION: {"mysql", "sqlite3", "nrmysql", "nrsqlite3"}, 27 | NAMED: {"oci8", "ora", "goracle", "godror"}, 28 | AT: {"sqlserver", "azuresql"}, 29 | } 30 | 31 | var binds sync.Map 32 | 33 | func init() { 34 | for bind, drivers := range defaultBinds { 35 | for _, driver := range drivers { 36 | BindDriver(driver, bind) 37 | } 38 | } 39 | 40 | } 41 | 42 | // BindType returns the bindtype for a given database given a drivername. 43 | func BindType(driverName string) int { 44 | itype, ok := binds.Load(driverName) 45 | if !ok { 46 | return UNKNOWN 47 | } 48 | return itype.(int) 49 | } 50 | 51 | // BindDriver sets the BindType for driverName to bindType. 52 | func BindDriver(driverName string, bindType int) { 53 | binds.Store(driverName, bindType) 54 | } 55 | 56 | // FIXME: this should be able to be tolerant of escaped ?'s in queries without 57 | // losing much speed, and should be to avoid confusion. 58 | 59 | // Rebind a query from the default bindtype (QUESTION) to the target bindtype. 60 | func Rebind(bindType int, query string) string { 61 | switch bindType { 62 | case QUESTION, UNKNOWN: 63 | return query 64 | } 65 | 66 | // Add space enough for 10 params before we have to allocate 67 | rqb := make([]byte, 0, len(query)+10) 68 | 69 | var i, j int 70 | 71 | for i = strings.Index(query, "?"); i != -1; i = strings.Index(query, "?") { 72 | rqb = append(rqb, query[:i]...) 
73 | 74 | switch bindType { 75 | case DOLLAR: 76 | rqb = append(rqb, '$') 77 | case NAMED: 78 | rqb = append(rqb, ':', 'a', 'r', 'g') 79 | case AT: 80 | rqb = append(rqb, '@', 'p') 81 | } 82 | 83 | j++ 84 | rqb = strconv.AppendInt(rqb, int64(j), 10) 85 | 86 | query = query[i+1:] 87 | } 88 | 89 | return string(append(rqb, query...)) 90 | } 91 | 92 | // Experimental implementation of Rebind which uses a bytes.Buffer. The code is 93 | // much simpler and should be more resistant to odd unicode, but it is twice as 94 | // slow. Kept here for benchmarking purposes and to possibly replace Rebind if 95 | // problems arise with its somewhat naive handling of unicode. 96 | func rebindBuff(bindType int, query string) string { 97 | if bindType != DOLLAR { 98 | return query 99 | } 100 | 101 | b := make([]byte, 0, len(query)) 102 | rqb := bytes.NewBuffer(b) 103 | j := 1 104 | for _, r := range query { 105 | if r == '?' { 106 | rqb.WriteRune('$') 107 | rqb.WriteString(strconv.Itoa(j)) 108 | j++ 109 | } else { 110 | rqb.WriteRune(r) 111 | } 112 | } 113 | 114 | return rqb.String() 115 | } 116 | 117 | func asSliceForIn(i interface{}) (v reflect.Value, ok bool) { 118 | if i == nil { 119 | return reflect.Value{}, false 120 | } 121 | 122 | v = reflect.ValueOf(i) 123 | t := reflectx.Deref(v.Type()) 124 | 125 | // Only expand slices 126 | if t.Kind() != reflect.Slice { 127 | return reflect.Value{}, false 128 | } 129 | 130 | // []byte is a driver.Value type so it should not be expanded 131 | if t == reflect.TypeOf([]byte{}) { 132 | return reflect.Value{}, false 133 | 134 | } 135 | 136 | return v, true 137 | } 138 | 139 | // In expands slice values in args, returning the modified query string 140 | // and a new arg list that can be executed by a database. The `query` should 141 | // use the `?` bindVar. The return value uses the `?` bindVar. 
142 | func In(query string, args ...interface{}) (string, []interface{}, error) { 143 | // argMeta stores reflect.Value and length for slices and 144 | // the value itself for non-slice arguments 145 | type argMeta struct { 146 | v reflect.Value 147 | i interface{} 148 | length int 149 | } 150 | 151 | var flatArgsCount int 152 | var anySlices bool 153 | 154 | var stackMeta [32]argMeta 155 | 156 | var meta []argMeta 157 | if len(args) <= len(stackMeta) { 158 | meta = stackMeta[:len(args)] 159 | } else { 160 | meta = make([]argMeta, len(args)) 161 | } 162 | 163 | for i, arg := range args { 164 | if a, ok := arg.(driver.Valuer); ok { 165 | var err error 166 | arg, err = a.Value() 167 | if err != nil { 168 | return "", nil, err 169 | } 170 | } 171 | 172 | if v, ok := asSliceForIn(arg); ok { 173 | meta[i].length = v.Len() 174 | meta[i].v = v 175 | 176 | anySlices = true 177 | flatArgsCount += meta[i].length 178 | 179 | if meta[i].length == 0 { 180 | return "", nil, errors.New("empty slice passed to 'in' query") 181 | } 182 | } else { 183 | meta[i].i = arg 184 | flatArgsCount++ 185 | } 186 | } 187 | 188 | // don't do any parsing if there aren't any slices; note that this means 189 | // some errors that we might have caught below will not be returned. 190 | if !anySlices { 191 | return query, args, nil 192 | } 193 | 194 | newArgs := make([]interface{}, 0, flatArgsCount) 195 | 196 | var buf strings.Builder 197 | buf.Grow(len(query) + len(", ?")*flatArgsCount) 198 | 199 | var arg, offset int 200 | 201 | for i := strings.IndexByte(query[offset:], '?'); i != -1; i = strings.IndexByte(query[offset:], '?') { 202 | if arg >= len(meta) { 203 | // if an argument wasn't passed, lets return an error; this is 204 | // not actually how database/sql Exec/Query works, but since we are 205 | // creating an argument list programmatically, we want to be able 206 | // to catch these programmer errors earlier. 
207 | return "", nil, errors.New("number of bindVars exceeds arguments") 208 | } 209 | 210 | argMeta := meta[arg] 211 | arg++ 212 | 213 | // not a slice, continue. 214 | // our questionmark will either be written before the next expansion 215 | // of a slice or after the loop when writing the rest of the query 216 | if argMeta.length == 0 { 217 | offset = offset + i + 1 218 | newArgs = append(newArgs, argMeta.i) 219 | continue 220 | } 221 | 222 | // write everything up to and including our ? character 223 | buf.WriteString(query[:offset+i+1]) 224 | 225 | for si := 1; si < argMeta.length; si++ { 226 | buf.WriteString(", ?") 227 | } 228 | 229 | newArgs = appendReflectSlice(newArgs, argMeta.v, argMeta.length) 230 | 231 | // slice the query and reset the offset. this avoids some bookkeeping for 232 | // the write after the loop 233 | query = query[offset+i+1:] 234 | offset = 0 235 | } 236 | 237 | buf.WriteString(query) 238 | 239 | if arg < len(meta) { 240 | return "", nil, errors.New("number of bindVars less than number arguments") 241 | } 242 | 243 | return buf.String(), newArgs, nil 244 | } 245 | 246 | func appendReflectSlice(args []interface{}, v reflect.Value, vlen int) []interface{} { 247 | switch val := v.Interface().(type) { 248 | case []interface{}: 249 | args = append(args, val...) 
250 | case []int: 251 | for i := range val { 252 | args = append(args, val[i]) 253 | } 254 | case []string: 255 | for i := range val { 256 | args = append(args, val[i]) 257 | } 258 | default: 259 | for si := 0; si < vlen; si++ { 260 | args = append(args, v.Index(si).Interface()) 261 | } 262 | } 263 | 264 | return args 265 | } 266 | -------------------------------------------------------------------------------- /bind_test.go: -------------------------------------------------------------------------------- 1 | package sqlx 2 | 3 | import ( 4 | "math/rand" 5 | "testing" 6 | ) 7 | 8 | func oldBindType(driverName string) int { 9 | switch driverName { 10 | case "postgres", "pgx", "pq-timeouts", "cloudsqlpostgres", "ql": 11 | return DOLLAR 12 | case "mysql": 13 | return QUESTION 14 | case "sqlite3": 15 | return QUESTION 16 | case "oci8", "ora", "goracle", "godror": 17 | return NAMED 18 | case "sqlserver": 19 | return AT 20 | } 21 | return UNKNOWN 22 | } 23 | 24 | /* 25 | sync.Map implementation: 26 | 27 | goos: linux 28 | goarch: amd64 29 | pkg: github.com/jmoiron/sqlx 30 | BenchmarkBindSpeed/old-4 100000000 11.0 ns/op 31 | BenchmarkBindSpeed/new-4 24575726 50.8 ns/op 32 | 33 | 34 | async.Value map implementation: 35 | 36 | goos: linux 37 | goarch: amd64 38 | pkg: github.com/jmoiron/sqlx 39 | BenchmarkBindSpeed/old-4 100000000 11.0 ns/op 40 | BenchmarkBindSpeed/new-4 42535839 27.5 ns/op 41 | */ 42 | 43 | func BenchmarkBindSpeed(b *testing.B) { 44 | testDrivers := []string{ 45 | "postgres", "pgx", "mysql", "sqlite3", "ora", "sqlserver", 46 | } 47 | 48 | b.Run("old", func(b *testing.B) { 49 | b.StopTimer() 50 | var seq []int 51 | for i := 0; i < b.N; i++ { 52 | seq = append(seq, rand.Intn(len(testDrivers))) 53 | } 54 | b.StartTimer() 55 | for i := 0; i < b.N; i++ { 56 | s := oldBindType(testDrivers[seq[i]]) 57 | if s == UNKNOWN { 58 | b.Error("unknown driver") 59 | } 60 | } 61 | 62 | }) 63 | 64 | b.Run("new", func(b *testing.B) { 65 | b.StopTimer() 66 | var seq []int 67 | 
for i := 0; i < b.N; i++ { 68 | seq = append(seq, rand.Intn(len(testDrivers))) 69 | } 70 | b.StartTimer() 71 | for i := 0; i < b.N; i++ { 72 | s := BindType(testDrivers[seq[i]]) 73 | if s == UNKNOWN { 74 | b.Error("unknown driver") 75 | } 76 | } 77 | 78 | }) 79 | } 80 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // Package sqlx provides general purpose extensions to database/sql. 2 | // 3 | // It is intended to seamlessly wrap database/sql and provide convenience 4 | // methods which are useful in the development of database driven applications. 5 | // None of the underlying database/sql methods are changed. Instead all extended 6 | // behavior is implemented through new methods defined on wrapper types. 7 | // 8 | // Additions include scanning into structs, named query support, rebinding 9 | // queries for different drivers, convenient shorthands for common error handling 10 | // and more. 
11 | package sqlx 12 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/jmoiron/sqlx 2 | 3 | go 1.10 4 | 5 | require ( 6 | github.com/go-sql-driver/mysql v1.8.1 7 | github.com/lib/pq v1.10.9 8 | github.com/mattn/go-sqlite3 v1.14.22 9 | ) 10 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= 2 | filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= 3 | github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= 4 | github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= 5 | github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= 6 | github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= 7 | github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= 8 | github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= 9 | -------------------------------------------------------------------------------- /named.go: -------------------------------------------------------------------------------- 1 | package sqlx 2 | 3 | // Named Query Support 4 | // 5 | // * BindMap - bind query bindvars to map/struct args 6 | // * NamedExec, NamedQuery - named query w/ struct or map 7 | // * NamedStmt - a pre-compiled named query which is a prepared statement 8 | // 9 | // Internal Interfaces: 10 | // 11 | // * compileNamedQuery - rebind a named query, returning a query and list of names 12 | // * bindArgs, bindMapArgs, bindAnyArgs - given a list of names, return an arglist 13 | // 14 | import ( 15 | "bytes" 16 | "database/sql" 17 | "errors" 18 | "fmt" 19 | 
"reflect" 20 | "regexp" 21 | "strconv" 22 | "unicode" 23 | 24 | "github.com/jmoiron/sqlx/reflectx" 25 | ) 26 | 27 | // NamedStmt is a prepared statement that executes named queries. Prepare it 28 | // how you would execute a NamedQuery, but pass in a struct or map when executing. 29 | type NamedStmt struct { 30 | Params []string 31 | QueryString string 32 | Stmt *Stmt 33 | } 34 | 35 | // Close closes the named statement. 36 | func (n *NamedStmt) Close() error { 37 | return n.Stmt.Close() 38 | } 39 | 40 | // Exec executes a named statement using the struct passed. 41 | // Any named placeholder parameters are replaced with fields from arg. 42 | func (n *NamedStmt) Exec(arg interface{}) (sql.Result, error) { 43 | args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) 44 | if err != nil { 45 | return *new(sql.Result), err 46 | } 47 | return n.Stmt.Exec(args...) 48 | } 49 | 50 | // Query executes a named statement using the struct argument, returning rows. 51 | // Any named placeholder parameters are replaced with fields from arg. 52 | func (n *NamedStmt) Query(arg interface{}) (*sql.Rows, error) { 53 | args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) 54 | if err != nil { 55 | return nil, err 56 | } 57 | return n.Stmt.Query(args...) 58 | } 59 | 60 | // QueryRow executes a named statement against the database. Because sqlx cannot 61 | // create a *sql.Row with an error condition pre-set for binding errors, sqlx 62 | // returns a *sqlx.Row instead. 63 | // Any named placeholder parameters are replaced with fields from arg. 64 | func (n *NamedStmt) QueryRow(arg interface{}) *Row { 65 | args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) 66 | if err != nil { 67 | return &Row{err: err} 68 | } 69 | return n.Stmt.QueryRowx(args...) 70 | } 71 | 72 | // MustExec execs a NamedStmt, panicing on error 73 | // Any named placeholder parameters are replaced with fields from arg. 
74 | func (n *NamedStmt) MustExec(arg interface{}) sql.Result { 75 | res, err := n.Exec(arg) 76 | if err != nil { 77 | panic(err) 78 | } 79 | return res 80 | } 81 | 82 | // Queryx using this NamedStmt 83 | // Any named placeholder parameters are replaced with fields from arg. 84 | func (n *NamedStmt) Queryx(arg interface{}) (*Rows, error) { 85 | r, err := n.Query(arg) 86 | if err != nil { 87 | return nil, err 88 | } 89 | return &Rows{Rows: r, Mapper: n.Stmt.Mapper, unsafe: isUnsafe(n)}, err 90 | } 91 | 92 | // QueryRowx this NamedStmt. Because of limitations with QueryRow, this is 93 | // an alias for QueryRow. 94 | // Any named placeholder parameters are replaced with fields from arg. 95 | func (n *NamedStmt) QueryRowx(arg interface{}) *Row { 96 | return n.QueryRow(arg) 97 | } 98 | 99 | // Select using this NamedStmt 100 | // Any named placeholder parameters are replaced with fields from arg. 101 | func (n *NamedStmt) Select(dest interface{}, arg interface{}) error { 102 | rows, err := n.Queryx(arg) 103 | if err != nil { 104 | return err 105 | } 106 | // if something happens here, we want to make sure the rows are Closed 107 | defer rows.Close() 108 | return scanAll(rows, dest, false) 109 | } 110 | 111 | // Get using this NamedStmt 112 | // Any named placeholder parameters are replaced with fields from arg. 113 | func (n *NamedStmt) Get(dest interface{}, arg interface{}) error { 114 | r := n.QueryRowx(arg) 115 | return r.scanAny(dest, false) 116 | } 117 | 118 | // Unsafe creates an unsafe version of the NamedStmt 119 | func (n *NamedStmt) Unsafe() *NamedStmt { 120 | r := &NamedStmt{Params: n.Params, Stmt: n.Stmt, QueryString: n.QueryString} 121 | r.Stmt.unsafe = true 122 | return r 123 | } 124 | 125 | // A union interface of preparer and binder, required to be able to prepare 126 | // named statements (as the bindtype must be determined). 
127 | type namedPreparer interface { 128 | Preparer 129 | binder 130 | } 131 | 132 | func prepareNamed(p namedPreparer, query string) (*NamedStmt, error) { 133 | bindType := BindType(p.DriverName()) 134 | q, args, err := compileNamedQuery([]byte(query), bindType) 135 | if err != nil { 136 | return nil, err 137 | } 138 | stmt, err := Preparex(p, q) 139 | if err != nil { 140 | return nil, err 141 | } 142 | return &NamedStmt{ 143 | QueryString: q, 144 | Params: args, 145 | Stmt: stmt, 146 | }, nil 147 | } 148 | 149 | // convertMapStringInterface attempts to convert v to map[string]interface{}. 150 | // Unlike v.(map[string]interface{}), this function works on named types that 151 | // are convertible to map[string]interface{} as well. 152 | func convertMapStringInterface(v interface{}) (map[string]interface{}, bool) { 153 | var m map[string]interface{} 154 | mtype := reflect.TypeOf(m) 155 | t := reflect.TypeOf(v) 156 | if !t.ConvertibleTo(mtype) { 157 | return nil, false 158 | } 159 | return reflect.ValueOf(v).Convert(mtype).Interface().(map[string]interface{}), true 160 | 161 | } 162 | 163 | func bindAnyArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) { 164 | if maparg, ok := convertMapStringInterface(arg); ok { 165 | return bindMapArgs(names, maparg) 166 | } 167 | return bindArgs(names, arg, m) 168 | } 169 | 170 | // private interface to generate a list of interfaces from a given struct 171 | // type, given a list of names to pull out of the struct. Used by public 172 | // BindStruct interface. 
173 | func bindArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) { 174 | arglist := make([]interface{}, 0, len(names)) 175 | 176 | // grab the indirected value of arg 177 | var v reflect.Value 178 | for v = reflect.ValueOf(arg); v.Kind() == reflect.Ptr; { 179 | v = v.Elem() 180 | } 181 | 182 | err := m.TraversalsByNameFunc(v.Type(), names, func(i int, t []int) error { 183 | if len(t) == 0 { 184 | return fmt.Errorf("could not find name %s in %#v", names[i], arg) 185 | } 186 | 187 | val := reflectx.FieldByIndexesReadOnly(v, t) 188 | arglist = append(arglist, val.Interface()) 189 | 190 | return nil 191 | }) 192 | 193 | return arglist, err 194 | } 195 | 196 | // like bindArgs, but for maps. 197 | func bindMapArgs(names []string, arg map[string]interface{}) ([]interface{}, error) { 198 | arglist := make([]interface{}, 0, len(names)) 199 | 200 | for _, name := range names { 201 | val, ok := arg[name] 202 | if !ok { 203 | return arglist, fmt.Errorf("could not find name %s in %#v", name, arg) 204 | } 205 | arglist = append(arglist, val) 206 | } 207 | return arglist, nil 208 | } 209 | 210 | // bindStruct binds a named parameter query with fields from a struct argument. 211 | // The rules for binding field names to parameter names follow the same 212 | // conventions as for StructScan, including obeying the `db` struct tags. 
213 | func bindStruct(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) { 214 | bound, names, err := compileNamedQuery([]byte(query), bindType) 215 | if err != nil { 216 | return "", []interface{}{}, err 217 | } 218 | 219 | arglist, err := bindAnyArgs(names, arg, m) 220 | if err != nil { 221 | return "", []interface{}{}, err 222 | } 223 | 224 | return bound, arglist, nil 225 | } 226 | 227 | var valuesReg = regexp.MustCompile(`\)\s*(?i)VALUES\s*\(`) 228 | 229 | func findMatchingClosingBracketIndex(s string) int { 230 | count := 0 231 | for i, ch := range s { 232 | if ch == '(' { 233 | count++ 234 | } 235 | if ch == ')' { 236 | count-- 237 | if count == 0 { 238 | return i 239 | } 240 | } 241 | } 242 | return 0 243 | } 244 | 245 | func fixBound(bound string, loop int) string { 246 | loc := valuesReg.FindStringIndex(bound) 247 | // defensive guard when "VALUES (...)" not found 248 | if len(loc) < 2 { 249 | return bound 250 | } 251 | 252 | openingBracketIndex := loc[1] - 1 253 | index := findMatchingClosingBracketIndex(bound[openingBracketIndex:]) 254 | // defensive guard. must have closing bracket 255 | if index == 0 { 256 | return bound 257 | } 258 | closingBracketIndex := openingBracketIndex + index + 1 259 | 260 | var buffer bytes.Buffer 261 | 262 | buffer.WriteString(bound[0:closingBracketIndex]) 263 | for i := 0; i < loop-1; i++ { 264 | buffer.WriteString(",") 265 | buffer.WriteString(bound[openingBracketIndex:closingBracketIndex]) 266 | } 267 | buffer.WriteString(bound[closingBracketIndex:]) 268 | return buffer.String() 269 | } 270 | 271 | // bindArray binds a named parameter query with fields from an array or slice of 272 | // structs argument. 273 | func bindArray(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) { 274 | // do the initial binding with QUESTION; if bindType is not question, 275 | // we can rebind it at the end. 
276 | bound, names, err := compileNamedQuery([]byte(query), QUESTION) 277 | if err != nil { 278 | return "", []interface{}{}, err 279 | } 280 | arrayValue := reflect.ValueOf(arg) 281 | arrayLen := arrayValue.Len() 282 | if arrayLen == 0 { 283 | return "", []interface{}{}, fmt.Errorf("length of array is 0: %#v", arg) 284 | } 285 | var arglist = make([]interface{}, 0, len(names)*arrayLen) 286 | for i := 0; i < arrayLen; i++ { 287 | elemArglist, err := bindAnyArgs(names, arrayValue.Index(i).Interface(), m) 288 | if err != nil { 289 | return "", []interface{}{}, err 290 | } 291 | arglist = append(arglist, elemArglist...) 292 | } 293 | if arrayLen > 1 { 294 | bound = fixBound(bound, arrayLen) 295 | } 296 | // adjust binding type if we weren't on question 297 | if bindType != QUESTION { 298 | bound = Rebind(bindType, bound) 299 | } 300 | return bound, arglist, nil 301 | } 302 | 303 | // bindMap binds a named parameter query with a map of arguments. 304 | func bindMap(bindType int, query string, args map[string]interface{}) (string, []interface{}, error) { 305 | bound, names, err := compileNamedQuery([]byte(query), bindType) 306 | if err != nil { 307 | return "", []interface{}{}, err 308 | } 309 | 310 | arglist, err := bindMapArgs(names, args) 311 | return bound, arglist, err 312 | } 313 | 314 | // -- Compilation of Named Queries 315 | 316 | // Allow digits and letters in bind params; additionally runes are 317 | // checked against underscores, meaning that bind params can have be 318 | // alphanumeric with underscores. Mind the difference between unicode 319 | // digits and numbers, where '5' is a digit but '五' is not. 320 | var allowedBindRunes = []*unicode.RangeTable{unicode.Letter, unicode.Digit} 321 | 322 | // FIXME: this function isn't safe for unicode named params, as a failing test 323 | // can testify. This is not a regression but a failure of the original code 324 | // as well. 
It should be modified to range over runes in a string rather than 325 | // bytes, even though this is less convenient and slower. Hopefully the 326 | // addition of the prepared NamedStmt (which will only do this once) will make 327 | // up for the slightly slower ad-hoc NamedExec/NamedQuery. 328 | 329 | // compile a NamedQuery into an unbound query (using the '?' bindvar) and 330 | // a list of names. 331 | func compileNamedQuery(qs []byte, bindType int) (query string, names []string, err error) { 332 | names = make([]string, 0, 10) 333 | rebound := make([]byte, 0, len(qs)) 334 | 335 | inName := false 336 | last := len(qs) - 1 337 | currentVar := 1 338 | name := make([]byte, 0, 10) 339 | 340 | for i, b := range qs { 341 | // a ':' while we're in a name is an error 342 | if b == ':' { 343 | // if this is the second ':' in a '::' escape sequence, append a ':' 344 | if inName && i > 0 && qs[i-1] == ':' { 345 | rebound = append(rebound, ':') 346 | inName = false 347 | continue 348 | } else if inName { 349 | err = errors.New("unexpected `:` while reading named param at " + strconv.Itoa(i)) 350 | return query, names, err 351 | } 352 | inName = true 353 | name = []byte{} 354 | } else if inName && i > 0 && b == '=' && len(name) == 0 { 355 | rebound = append(rebound, ':', '=') 356 | inName = false 357 | continue 358 | // if we're in a name, and this is an allowed character, continue 359 | } else if inName && (unicode.IsOneOf(allowedBindRunes, rune(b)) || b == '_' || b == '.') && i != last { 360 | // append the byte to the name if we are in a name and not on the last byte 361 | name = append(name, b) 362 | // if we're in a name and it's not an allowed character, the name is done 363 | } else if inName { 364 | inName = false 365 | // if this is the final byte of the string and it is part of the name, then 366 | // make sure to add it to the name 367 | if i == last && unicode.IsOneOf(allowedBindRunes, rune(b)) { 368 | name = append(name, b) 369 | } 370 | // add the string 
representation to the names list 371 | names = append(names, string(name)) 372 | // add a proper bindvar for the bindType 373 | switch bindType { 374 | // oracle only supports named type bind vars even for positional 375 | case NAMED: 376 | rebound = append(rebound, ':') 377 | rebound = append(rebound, name...) 378 | case QUESTION, UNKNOWN: 379 | rebound = append(rebound, '?') 380 | case DOLLAR: 381 | rebound = append(rebound, '$') 382 | for _, b := range strconv.Itoa(currentVar) { 383 | rebound = append(rebound, byte(b)) 384 | } 385 | currentVar++ 386 | case AT: 387 | rebound = append(rebound, '@', 'p') 388 | for _, b := range strconv.Itoa(currentVar) { 389 | rebound = append(rebound, byte(b)) 390 | } 391 | currentVar++ 392 | } 393 | // add this byte to string unless it was not part of the name 394 | if i != last { 395 | rebound = append(rebound, b) 396 | } else if !unicode.IsOneOf(allowedBindRunes, rune(b)) { 397 | rebound = append(rebound, b) 398 | } 399 | } else { 400 | // this is a normal byte and should just go onto the rebound query 401 | rebound = append(rebound, b) 402 | } 403 | } 404 | 405 | return string(rebound), names, err 406 | } 407 | 408 | // BindNamed binds a struct or a map to a query with named parameters. 409 | // DEPRECATED: use sqlx.Named` instead of this, it may be removed in future. 410 | func BindNamed(bindType int, query string, arg interface{}) (string, []interface{}, error) { 411 | return bindNamedMapper(bindType, query, arg, mapper()) 412 | } 413 | 414 | // Named takes a query using named parameters and an argument and 415 | // returns a new query with a list of args that can be executed by 416 | // a database. The return value uses the `?` bindvar. 
417 | func Named(query string, arg interface{}) (string, []interface{}, error) { 418 | return bindNamedMapper(QUESTION, query, arg, mapper()) 419 | } 420 | 421 | func bindNamedMapper(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) { 422 | t := reflect.TypeOf(arg) 423 | k := t.Kind() 424 | switch { 425 | case k == reflect.Map && t.Key().Kind() == reflect.String: 426 | m, ok := convertMapStringInterface(arg) 427 | if !ok { 428 | return "", nil, fmt.Errorf("sqlx.bindNamedMapper: unsupported map type: %T", arg) 429 | } 430 | return bindMap(bindType, query, m) 431 | case k == reflect.Array || k == reflect.Slice: 432 | return bindArray(bindType, query, arg, m) 433 | default: 434 | return bindStruct(bindType, query, arg, m) 435 | } 436 | } 437 | 438 | // NamedQuery binds a named query and then runs Query on the result using the 439 | // provided Ext (sqlx.Tx, sqlx.Db). It works with both structs and with 440 | // map[string]interface{} types. 441 | func NamedQuery(e Ext, query string, arg interface{}) (*Rows, error) { 442 | q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) 443 | if err != nil { 444 | return nil, err 445 | } 446 | return e.Queryx(q, args...) 447 | } 448 | 449 | // NamedExec uses BindStruct to get a query executable by the driver and 450 | // then runs Exec on the result. Returns an error from the binding 451 | // or the query execution itself. 452 | func NamedExec(e Ext, query string, arg interface{}) (sql.Result, error) { 453 | q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) 454 | if err != nil { 455 | return nil, err 456 | } 457 | return e.Exec(q, args...) 
458 | } 459 | -------------------------------------------------------------------------------- /named_context.go: -------------------------------------------------------------------------------- 1 | //go:build go1.8 2 | // +build go1.8 3 | 4 | package sqlx 5 | 6 | import ( 7 | "context" 8 | "database/sql" 9 | ) 10 | 11 | // A union interface of contextPreparer and binder, required to be able to 12 | // prepare named statements with context (as the bindtype must be determined). 13 | type namedPreparerContext interface { 14 | PreparerContext 15 | binder 16 | } 17 | 18 | func prepareNamedContext(ctx context.Context, p namedPreparerContext, query string) (*NamedStmt, error) { 19 | bindType := BindType(p.DriverName()) 20 | q, args, err := compileNamedQuery([]byte(query), bindType) 21 | if err != nil { 22 | return nil, err 23 | } 24 | stmt, err := PreparexContext(ctx, p, q) 25 | if err != nil { 26 | return nil, err 27 | } 28 | return &NamedStmt{ 29 | QueryString: q, 30 | Params: args, 31 | Stmt: stmt, 32 | }, nil 33 | } 34 | 35 | // ExecContext executes a named statement using the struct passed. 36 | // Any named placeholder parameters are replaced with fields from arg. 37 | func (n *NamedStmt) ExecContext(ctx context.Context, arg interface{}) (sql.Result, error) { 38 | args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) 39 | if err != nil { 40 | return *new(sql.Result), err 41 | } 42 | return n.Stmt.ExecContext(ctx, args...) 43 | } 44 | 45 | // QueryContext executes a named statement using the struct argument, returning rows. 46 | // Any named placeholder parameters are replaced with fields from arg. 47 | func (n *NamedStmt) QueryContext(ctx context.Context, arg interface{}) (*sql.Rows, error) { 48 | args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) 49 | if err != nil { 50 | return nil, err 51 | } 52 | return n.Stmt.QueryContext(ctx, args...) 53 | } 54 | 55 | // QueryRowContext executes a named statement against the database. 
Because sqlx cannot 56 | // create a *sql.Row with an error condition pre-set for binding errors, sqlx 57 | // returns a *sqlx.Row instead. 58 | // Any named placeholder parameters are replaced with fields from arg. 59 | func (n *NamedStmt) QueryRowContext(ctx context.Context, arg interface{}) *Row { 60 | args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) 61 | if err != nil { 62 | return &Row{err: err} 63 | } 64 | return n.Stmt.QueryRowxContext(ctx, args...) 65 | } 66 | 67 | // MustExecContext execs a NamedStmt, panicing on error 68 | // Any named placeholder parameters are replaced with fields from arg. 69 | func (n *NamedStmt) MustExecContext(ctx context.Context, arg interface{}) sql.Result { 70 | res, err := n.ExecContext(ctx, arg) 71 | if err != nil { 72 | panic(err) 73 | } 74 | return res 75 | } 76 | 77 | // QueryxContext using this NamedStmt 78 | // Any named placeholder parameters are replaced with fields from arg. 79 | func (n *NamedStmt) QueryxContext(ctx context.Context, arg interface{}) (*Rows, error) { 80 | r, err := n.QueryContext(ctx, arg) 81 | if err != nil { 82 | return nil, err 83 | } 84 | return &Rows{Rows: r, Mapper: n.Stmt.Mapper, unsafe: isUnsafe(n)}, err 85 | } 86 | 87 | // QueryRowxContext this NamedStmt. Because of limitations with QueryRow, this is 88 | // an alias for QueryRow. 89 | // Any named placeholder parameters are replaced with fields from arg. 90 | func (n *NamedStmt) QueryRowxContext(ctx context.Context, arg interface{}) *Row { 91 | return n.QueryRowContext(ctx, arg) 92 | } 93 | 94 | // SelectContext using this NamedStmt 95 | // Any named placeholder parameters are replaced with fields from arg. 
96 | func (n *NamedStmt) SelectContext(ctx context.Context, dest interface{}, arg interface{}) error { 97 | rows, err := n.QueryxContext(ctx, arg) 98 | if err != nil { 99 | return err 100 | } 101 | // if something happens here, we want to make sure the rows are Closed 102 | defer rows.Close() 103 | return scanAll(rows, dest, false) 104 | } 105 | 106 | // GetContext using this NamedStmt 107 | // Any named placeholder parameters are replaced with fields from arg. 108 | func (n *NamedStmt) GetContext(ctx context.Context, dest interface{}, arg interface{}) error { 109 | r := n.QueryRowxContext(ctx, arg) 110 | return r.scanAny(dest, false) 111 | } 112 | 113 | // NamedQueryContext binds a named query and then runs Query on the result using the 114 | // provided Ext (sqlx.Tx, sqlx.Db). It works with both structs and with 115 | // map[string]interface{} types. 116 | func NamedQueryContext(ctx context.Context, e ExtContext, query string, arg interface{}) (*Rows, error) { 117 | q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) 118 | if err != nil { 119 | return nil, err 120 | } 121 | return e.QueryxContext(ctx, q, args...) 122 | } 123 | 124 | // NamedExecContext uses BindStruct to get a query executable by the driver and 125 | // then runs Exec on the result. Returns an error from the binding 126 | // or the query execution itself. 127 | func NamedExecContext(ctx context.Context, e ExtContext, query string, arg interface{}) (sql.Result, error) { 128 | q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) 129 | if err != nil { 130 | return nil, err 131 | } 132 | return e.ExecContext(ctx, q, args...) 
133 | } 134 | -------------------------------------------------------------------------------- /named_context_test.go: -------------------------------------------------------------------------------- 1 | //go:build go1.8 2 | // +build go1.8 3 | 4 | package sqlx 5 | 6 | import ( 7 | "context" 8 | "database/sql" 9 | "testing" 10 | ) 11 | 12 | func TestNamedContextQueries(t *testing.T) { 13 | RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) { 14 | loadDefaultFixture(db, t) 15 | test := Test{t} 16 | var ns *NamedStmt 17 | var err error 18 | 19 | ctx := context.Background() 20 | 21 | // Check that invalid preparations fail 22 | _, err = db.PrepareNamedContext(ctx, "SELECT * FROM person WHERE first_name=:first:name") 23 | if err == nil { 24 | t.Error("Expected an error with invalid prepared statement.") 25 | } 26 | 27 | _, err = db.PrepareNamedContext(ctx, "invalid sql") 28 | if err == nil { 29 | t.Error("Expected an error with invalid prepared statement.") 30 | } 31 | 32 | // Check closing works as anticipated 33 | ns, err = db.PrepareNamedContext(ctx, "SELECT * FROM person WHERE first_name=:first_name") 34 | test.Error(err) 35 | err = ns.Close() 36 | test.Error(err) 37 | 38 | ns, err = db.PrepareNamedContext(ctx, ` 39 | SELECT first_name, last_name, email 40 | FROM person WHERE first_name=:first_name AND email=:email`) 41 | test.Error(err) 42 | 43 | // test Queryx w/ uses Query 44 | p := Person{FirstName: "Jason", LastName: "Moiron", Email: "jmoiron@jmoiron.net"} 45 | 46 | rows, err := ns.QueryxContext(ctx, p) 47 | test.Error(err) 48 | for rows.Next() { 49 | var p2 Person 50 | rows.StructScan(&p2) 51 | if p.FirstName != p2.FirstName { 52 | t.Errorf("got %s, expected %s", p.FirstName, p2.FirstName) 53 | } 54 | if p.LastName != p2.LastName { 55 | t.Errorf("got %s, expected %s", p.LastName, p2.LastName) 56 | } 57 | if p.Email != p2.Email { 58 | t.Errorf("got %s, expected %s", p.Email, p2.Email) 59 | } 60 | } 61 | 62 | // test Select 63 | people := 
make([]Person, 0, 5) 64 | err = ns.SelectContext(ctx, &people, p) 65 | test.Error(err) 66 | 67 | if len(people) != 1 { 68 | t.Errorf("got %d results, expected %d", len(people), 1) 69 | } 70 | if p.FirstName != people[0].FirstName { 71 | t.Errorf("got %s, expected %s", p.FirstName, people[0].FirstName) 72 | } 73 | if p.LastName != people[0].LastName { 74 | t.Errorf("got %s, expected %s", p.LastName, people[0].LastName) 75 | } 76 | if p.Email != people[0].Email { 77 | t.Errorf("got %s, expected %s", p.Email, people[0].Email) 78 | } 79 | 80 | // test Exec 81 | ns, err = db.PrepareNamedContext(ctx, ` 82 | INSERT INTO person (first_name, last_name, email) 83 | VALUES (:first_name, :last_name, :email)`) 84 | test.Error(err) 85 | 86 | js := Person{ 87 | FirstName: "Julien", 88 | LastName: "Savea", 89 | Email: "jsavea@ab.co.nz", 90 | } 91 | _, err = ns.ExecContext(ctx, js) 92 | test.Error(err) 93 | 94 | // Make sure we can pull him out again 95 | p2 := Person{} 96 | db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), js.Email) 97 | if p2.Email != js.Email { 98 | t.Errorf("expected %s, got %s", js.Email, p2.Email) 99 | } 100 | 101 | // test Txn NamedStmts 102 | tx := db.MustBeginTx(ctx, nil) 103 | txns := tx.NamedStmtContext(ctx, ns) 104 | 105 | // We're going to add Steven in this txn 106 | sl := Person{ 107 | FirstName: "Steven", 108 | LastName: "Luatua", 109 | Email: "sluatua@ab.co.nz", 110 | } 111 | 112 | _, err = txns.ExecContext(ctx, sl) 113 | test.Error(err) 114 | // then rollback... 
115 | tx.Rollback() 116 | // looking for Steven after a rollback should fail 117 | err = db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email) 118 | if err != sql.ErrNoRows { 119 | t.Errorf("expected no rows error, got %v", err) 120 | } 121 | 122 | // now do the same, but commit 123 | tx = db.MustBeginTx(ctx, nil) 124 | txns = tx.NamedStmtContext(ctx, ns) 125 | _, err = txns.ExecContext(ctx, sl) 126 | test.Error(err) 127 | tx.Commit() 128 | 129 | // looking for Steven after a Commit should succeed 130 | err = db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email) 131 | test.Error(err) 132 | if p2.Email != sl.Email { 133 | t.Errorf("expected %s, got %s", sl.Email, p2.Email) 134 | } 135 | 136 | }) 137 | } 138 | -------------------------------------------------------------------------------- /named_test.go: -------------------------------------------------------------------------------- 1 | package sqlx 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "testing" 7 | ) 8 | 9 | func TestCompileQuery(t *testing.T) { 10 | table := []struct { 11 | Q, R, D, T, N string 12 | V []string 13 | }{ 14 | // basic test for named parameters, invalid char ',' terminating 15 | { 16 | Q: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`, 17 | R: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`, 18 | D: `INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)`, 19 | T: `INSERT INTO foo (a,b,c,d) VALUES (@p1, @p2, @p3, @p4)`, 20 | N: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`, 21 | V: []string{"name", "age", "first", "last"}, 22 | }, 23 | // This query tests a named parameter ending the string as well as numbers 24 | { 25 | Q: `SELECT * FROM a WHERE first_name=:name1 AND last_name=:name2`, 26 | R: `SELECT * FROM a WHERE first_name=? 
AND last_name=?`, 27 | D: `SELECT * FROM a WHERE first_name=$1 AND last_name=$2`, 28 | T: `SELECT * FROM a WHERE first_name=@p1 AND last_name=@p2`, 29 | N: `SELECT * FROM a WHERE first_name=:name1 AND last_name=:name2`, 30 | V: []string{"name1", "name2"}, 31 | }, 32 | { 33 | Q: `SELECT "::foo" FROM a WHERE first_name=:name1 AND last_name=:name2`, 34 | R: `SELECT ":foo" FROM a WHERE first_name=? AND last_name=?`, 35 | D: `SELECT ":foo" FROM a WHERE first_name=$1 AND last_name=$2`, 36 | T: `SELECT ":foo" FROM a WHERE first_name=@p1 AND last_name=@p2`, 37 | N: `SELECT ":foo" FROM a WHERE first_name=:name1 AND last_name=:name2`, 38 | V: []string{"name1", "name2"}, 39 | }, 40 | { 41 | Q: `SELECT 'a::b::c' || first_name, '::::ABC::_::' FROM person WHERE first_name=:first_name AND last_name=:last_name`, 42 | R: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=? AND last_name=?`, 43 | D: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=$1 AND last_name=$2`, 44 | T: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=@p1 AND last_name=@p2`, 45 | N: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=:first_name AND last_name=:last_name`, 46 | V: []string{"first_name", "last_name"}, 47 | }, 48 | { 49 | Q: `SELECT @name := "name", :age, :first, :last`, 50 | R: `SELECT @name := "name", ?, ?, ?`, 51 | D: `SELECT @name := "name", $1, $2, $3`, 52 | N: `SELECT @name := "name", :age, :first, :last`, 53 | T: `SELECT @name := "name", @p1, @p2, @p3`, 54 | V: []string{"age", "first", "last"}, 55 | }, 56 | /* This unicode awareness test sadly fails, because of our byte-wise worldview. 
57 | * We could certainly iterate by Rune instead, though it's a great deal slower, 58 | * it's probably the RightWay(tm) 59 | { 60 | Q: `INSERT INTO foo (a,b,c,d) VALUES (:あ, :b, :キコ, :名前)`, 61 | R: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`, 62 | D: `INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)`, 63 | N: []string{"name", "age", "first", "last"}, 64 | }, 65 | */ 66 | } 67 | 68 | for _, test := range table { 69 | qr, names, err := compileNamedQuery([]byte(test.Q), QUESTION) 70 | if err != nil { 71 | t.Error(err) 72 | } 73 | if qr != test.R { 74 | t.Errorf("expected %s, got %s", test.R, qr) 75 | } 76 | if len(names) != len(test.V) { 77 | t.Errorf("expected %#v, got %#v", test.V, names) 78 | } else { 79 | for i, name := range names { 80 | if name != test.V[i] { 81 | t.Errorf("expected %dth name to be %s, got %s", i+1, test.V[i], name) 82 | } 83 | } 84 | } 85 | qd, _, _ := compileNamedQuery([]byte(test.Q), DOLLAR) 86 | if qd != test.D { 87 | t.Errorf("\nexpected: `%s`\ngot: `%s`", test.D, qd) 88 | } 89 | 90 | qt, _, _ := compileNamedQuery([]byte(test.Q), AT) 91 | if qt != test.T { 92 | t.Errorf("\nexpected: `%s`\ngot: `%s`", test.T, qt) 93 | } 94 | 95 | qq, _, _ := compileNamedQuery([]byte(test.Q), NAMED) 96 | if qq != test.N { 97 | t.Errorf("\nexpected: `%s`\ngot: `%s`\n(len: %d vs %d)", test.N, qq, len(test.N), len(qq)) 98 | } 99 | } 100 | } 101 | 102 | type Test struct { 103 | t *testing.T 104 | } 105 | 106 | func (t Test) Error(err error, msg ...interface{}) { 107 | t.t.Helper() 108 | if err != nil { 109 | if len(msg) == 0 { 110 | t.t.Error(err) 111 | } else { 112 | t.t.Error(msg...) 113 | } 114 | } 115 | } 116 | 117 | func (t Test) Errorf(err error, format string, args ...interface{}) { 118 | t.t.Helper() 119 | if err != nil { 120 | t.t.Errorf(format, args...) 
121 | } 122 | } 123 | 124 | func TestEscapedColons(t *testing.T) { 125 | t.Skip("not sure it is possible to support this in general case without an SQL parser") 126 | var qs = `SELECT * FROM testtable WHERE timeposted BETWEEN (now() AT TIME ZONE 'utc') AND 127 | (now() AT TIME ZONE 'utc') - interval '01:30:00') AND name = '\'this is a test\'' and id = :id` 128 | _, _, err := compileNamedQuery([]byte(qs), DOLLAR) 129 | if err != nil { 130 | t.Error("Didn't handle colons correctly when inside a string") 131 | } 132 | } 133 | 134 | func TestNamedQueries(t *testing.T) { 135 | RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) { 136 | loadDefaultFixture(db, t) 137 | test := Test{t} 138 | var ns *NamedStmt 139 | var err error 140 | 141 | // Check that invalid preparations fail 142 | _, err = db.PrepareNamed("SELECT * FROM person WHERE first_name=:first:name") 143 | if err == nil { 144 | t.Error("Expected an error with invalid prepared statement.") 145 | } 146 | 147 | _, err = db.PrepareNamed("invalid sql") 148 | if err == nil { 149 | t.Error("Expected an error with invalid prepared statement.") 150 | } 151 | 152 | // Check closing works as anticipated 153 | ns, err = db.PrepareNamed("SELECT * FROM person WHERE first_name=:first_name") 154 | test.Error(err) 155 | err = ns.Close() 156 | test.Error(err) 157 | 158 | ns, err = db.PrepareNamed(` 159 | SELECT first_name, last_name, email 160 | FROM person WHERE first_name=:first_name AND email=:email`) 161 | test.Error(err) 162 | 163 | // test Queryx w/ uses Query 164 | p := Person{FirstName: "Jason", LastName: "Moiron", Email: "jmoiron@jmoiron.net"} 165 | 166 | rows, err := ns.Queryx(p) 167 | test.Error(err) 168 | for rows.Next() { 169 | var p2 Person 170 | rows.StructScan(&p2) 171 | if p.FirstName != p2.FirstName { 172 | t.Errorf("got %s, expected %s", p.FirstName, p2.FirstName) 173 | } 174 | if p.LastName != p2.LastName { 175 | t.Errorf("got %s, expected %s", p.LastName, p2.LastName) 176 | } 177 | if 
p.Email != p2.Email { 178 | t.Errorf("got %s, expected %s", p.Email, p2.Email) 179 | } 180 | } 181 | 182 | // test Select 183 | people := make([]Person, 0, 5) 184 | err = ns.Select(&people, p) 185 | test.Error(err) 186 | 187 | if len(people) != 1 { 188 | t.Errorf("got %d results, expected %d", len(people), 1) 189 | } 190 | if p.FirstName != people[0].FirstName { 191 | t.Errorf("got %s, expected %s", p.FirstName, people[0].FirstName) 192 | } 193 | if p.LastName != people[0].LastName { 194 | t.Errorf("got %s, expected %s", p.LastName, people[0].LastName) 195 | } 196 | if p.Email != people[0].Email { 197 | t.Errorf("got %s, expected %s", p.Email, people[0].Email) 198 | } 199 | 200 | // test struct batch inserts 201 | sls := []Person{ 202 | {FirstName: "Ardie", LastName: "Savea", Email: "asavea@ab.co.nz"}, 203 | {FirstName: "Sonny Bill", LastName: "Williams", Email: "sbw@ab.co.nz"}, 204 | {FirstName: "Ngani", LastName: "Laumape", Email: "nlaumape@ab.co.nz"}, 205 | } 206 | 207 | insert := fmt.Sprintf( 208 | "INSERT INTO person (first_name, last_name, email, added_at) VALUES (:first_name, :last_name, :email, %v)\n", 209 | now, 210 | ) 211 | _, err = db.NamedExec(insert, sls) 212 | test.Error(err) 213 | 214 | // test map batch inserts 215 | slsMap := []map[string]interface{}{ 216 | {"first_name": "Ardie", "last_name": "Savea", "email": "asavea@ab.co.nz"}, 217 | {"first_name": "Sonny Bill", "last_name": "Williams", "email": "sbw@ab.co.nz"}, 218 | {"first_name": "Ngani", "last_name": "Laumape", "email": "nlaumape@ab.co.nz"}, 219 | } 220 | 221 | _, err = db.NamedExec(`INSERT INTO person (first_name, last_name, email) 222 | VALUES (:first_name, :last_name, :email) ;--`, slsMap) 223 | test.Error(err) 224 | 225 | type A map[string]interface{} 226 | 227 | typedMap := []A{ 228 | {"first_name": "Ardie", "last_name": "Savea", "email": "asavea@ab.co.nz"}, 229 | {"first_name": "Sonny Bill", "last_name": "Williams", "email": "sbw@ab.co.nz"}, 230 | {"first_name": "Ngani", "last_name": 
"Laumape", "email": "nlaumape@ab.co.nz"}, 231 | } 232 | 233 | _, err = db.NamedExec(`INSERT INTO person (first_name, last_name, email) 234 | VALUES (:first_name, :last_name, :email) ;--`, typedMap) 235 | test.Error(err) 236 | 237 | for _, p := range sls { 238 | dest := Person{} 239 | err = db.Get(&dest, db.Rebind("SELECT * FROM person WHERE email=?"), p.Email) 240 | test.Error(err) 241 | if dest.Email != p.Email { 242 | t.Errorf("expected %s, got %s", p.Email, dest.Email) 243 | } 244 | } 245 | 246 | // test Exec 247 | ns, err = db.PrepareNamed(` 248 | INSERT INTO person (first_name, last_name, email) 249 | VALUES (:first_name, :last_name, :email)`) 250 | test.Error(err) 251 | 252 | js := Person{ 253 | FirstName: "Julien", 254 | LastName: "Savea", 255 | Email: "jsavea@ab.co.nz", 256 | } 257 | _, err = ns.Exec(js) 258 | test.Error(err) 259 | 260 | // Make sure we can pull him out again 261 | p2 := Person{} 262 | db.Get(&p2, db.Rebind("SELECT * FROM person WHERE email=?"), js.Email) 263 | if p2.Email != js.Email { 264 | t.Errorf("expected %s, got %s", js.Email, p2.Email) 265 | } 266 | 267 | // test Txn NamedStmts 268 | tx := db.MustBegin() 269 | txns := tx.NamedStmt(ns) 270 | 271 | // We're going to add Steven in this txn 272 | sl := Person{ 273 | FirstName: "Steven", 274 | LastName: "Luatua", 275 | Email: "sluatua@ab.co.nz", 276 | } 277 | 278 | _, err = txns.Exec(sl) 279 | test.Error(err) 280 | // then rollback... 
281 | tx.Rollback() 282 | // looking for Steven after a rollback should fail 283 | err = db.Get(&p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email) 284 | if err != sql.ErrNoRows { 285 | t.Errorf("expected no rows error, got %v", err) 286 | } 287 | 288 | // now do the same, but commit 289 | tx = db.MustBegin() 290 | txns = tx.NamedStmt(ns) 291 | _, err = txns.Exec(sl) 292 | test.Error(err) 293 | tx.Commit() 294 | 295 | // looking for Steven after a Commit should succeed 296 | err = db.Get(&p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email) 297 | test.Error(err) 298 | if p2.Email != sl.Email { 299 | t.Errorf("expected %s, got %s", sl.Email, p2.Email) 300 | } 301 | 302 | }) 303 | } 304 | 305 | func TestFixBounds(t *testing.T) { 306 | table := []struct { 307 | name, query, expect string 308 | loop int 309 | }{ 310 | { 311 | name: `named syntax`, 312 | query: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`, 313 | expect: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last),(:name, :age, :first, :last)`, 314 | loop: 2, 315 | }, 316 | { 317 | name: `mysql syntax`, 318 | query: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`, 319 | expect: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?),(?, ?, ?, ?)`, 320 | loop: 2, 321 | }, 322 | { 323 | name: `named syntax w/ trailer`, 324 | query: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last) ;--`, 325 | expect: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last),(:name, :age, :first, :last) ;--`, 326 | loop: 2, 327 | }, 328 | { 329 | name: `mysql syntax w/ trailer`, 330 | query: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?) ;--`, 331 | expect: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?),(?, ?, ?, ?) 
;--`, 332 | loop: 2, 333 | }, 334 | { 335 | name: `not found test`, 336 | query: `INSERT INTO foo (a,b,c,d) (:name, :age, :first, :last)`, 337 | expect: `INSERT INTO foo (a,b,c,d) (:name, :age, :first, :last)`, 338 | loop: 2, 339 | }, 340 | { 341 | name: `found twice test`, 342 | query: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last) VALUES (:name, :age, :first, :last)`, 343 | expect: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last),(:name, :age, :first, :last) VALUES (:name, :age, :first, :last)`, 344 | loop: 2, 345 | }, 346 | { 347 | name: `nospace`, 348 | query: `INSERT INTO foo (a,b) VALUES(:a, :b)`, 349 | expect: `INSERT INTO foo (a,b) VALUES(:a, :b),(:a, :b)`, 350 | loop: 2, 351 | }, 352 | { 353 | name: `lowercase`, 354 | query: `INSERT INTO foo (a,b) values(:a, :b)`, 355 | expect: `INSERT INTO foo (a,b) values(:a, :b),(:a, :b)`, 356 | loop: 2, 357 | }, 358 | { 359 | name: `on duplicate key using VALUES`, 360 | query: `INSERT INTO foo (a,b) VALUES (:a, :b) ON DUPLICATE KEY UPDATE a=VALUES(a)`, 361 | expect: `INSERT INTO foo (a,b) VALUES (:a, :b),(:a, :b) ON DUPLICATE KEY UPDATE a=VALUES(a)`, 362 | loop: 2, 363 | }, 364 | { 365 | name: `single column`, 366 | query: `INSERT INTO foo (a) VALUES (:a)`, 367 | expect: `INSERT INTO foo (a) VALUES (:a),(:a)`, 368 | loop: 2, 369 | }, 370 | { 371 | name: `call now`, 372 | query: `INSERT INTO foo (a, b) VALUES (:a, NOW())`, 373 | expect: `INSERT INTO foo (a, b) VALUES (:a, NOW()),(:a, NOW())`, 374 | loop: 2, 375 | }, 376 | { 377 | name: `two level depth function call`, 378 | query: `INSERT INTO foo (a, b) VALUES (:a, YEAR(NOW()))`, 379 | expect: `INSERT INTO foo (a, b) VALUES (:a, YEAR(NOW())),(:a, YEAR(NOW()))`, 380 | loop: 2, 381 | }, 382 | { 383 | name: `missing closing bracket`, 384 | query: `INSERT INTO foo (a, b) VALUES (:a, YEAR(NOW())`, 385 | expect: `INSERT INTO foo (a, b) VALUES (:a, YEAR(NOW())`, 386 | loop: 2, 387 | }, 388 | { 389 | name: `table with "values" at the end`, 390 | 
query: `INSERT INTO table_values (a, b) VALUES (:a, :b)`, 391 | expect: `INSERT INTO table_values (a, b) VALUES (:a, :b),(:a, :b)`, 392 | loop: 2, 393 | }, 394 | { 395 | name: `multiline indented query`, 396 | query: `INSERT INTO foo ( 397 | a, 398 | b, 399 | c, 400 | d 401 | ) VALUES ( 402 | :name, 403 | :age, 404 | :first, 405 | :last 406 | )`, 407 | expect: `INSERT INTO foo ( 408 | a, 409 | b, 410 | c, 411 | d 412 | ) VALUES ( 413 | :name, 414 | :age, 415 | :first, 416 | :last 417 | ),( 418 | :name, 419 | :age, 420 | :first, 421 | :last 422 | )`, 423 | loop: 2, 424 | }, 425 | } 426 | 427 | for _, tc := range table { 428 | t.Run(tc.name, func(t *testing.T) { 429 | res := fixBound(tc.query, tc.loop) 430 | if res != tc.expect { 431 | t.Errorf("mismatched results") 432 | } 433 | }) 434 | } 435 | } 436 | -------------------------------------------------------------------------------- /reflectx/README.md: -------------------------------------------------------------------------------- 1 | # reflectx 2 | 3 | The sqlx package has special reflect needs. In particular, it needs to: 4 | 5 | * be able to map a name to a field 6 | * understand embedded structs 7 | * understand mapping names to fields by a particular tag 8 | * user specified name -> field mapping functions 9 | 10 | These behaviors mimic the behaviors by the standard library marshallers and also the 11 | behavior of standard Go accessors. 12 | 13 | The first two are amply taken care of by `Reflect.Value.FieldByName`, and the third is 14 | addressed by `Reflect.Value.FieldByNameFunc`, but these don't quite understand struct 15 | tags in the ways that are vital to most marshallers, and they are slow. 16 | 17 | This reflectx package extends reflect to achieve these goals. 
18 | -------------------------------------------------------------------------------- /reflectx/reflect.go: -------------------------------------------------------------------------------- 1 | // Package reflectx implements extensions to the standard reflect lib suitable 2 | // for implementing marshalling and unmarshalling packages. The main Mapper type 3 | // allows for Go-compatible named attribute access, including accessing embedded 4 | // struct attributes and the ability to use functions and struct tags to 5 | // customize field names. 6 | package reflectx 7 | 8 | import ( 9 | "reflect" 10 | "runtime" 11 | "strings" 12 | "sync" 13 | ) 14 | 15 | // A FieldInfo is metadata for a struct field. 16 | type FieldInfo struct { 17 | Index []int 18 | Path string 19 | Field reflect.StructField 20 | Zero reflect.Value 21 | Name string 22 | Options map[string]string 23 | Embedded bool 24 | Children []*FieldInfo 25 | Parent *FieldInfo 26 | } 27 | 28 | // A StructMap is an index of field metadata for a struct. 29 | type StructMap struct { 30 | Tree *FieldInfo 31 | Index []*FieldInfo 32 | Paths map[string]*FieldInfo 33 | Names map[string]*FieldInfo 34 | } 35 | 36 | // GetByPath returns a *FieldInfo for a given string path. 37 | func (f StructMap) GetByPath(path string) *FieldInfo { 38 | return f.Paths[path] 39 | } 40 | 41 | // GetByTraversal returns a *FieldInfo for a given integer path. It is 42 | // analogous to reflect.FieldByIndex, but using the cached traversal 43 | // rather than re-executing the reflect machinery each time. 44 | func (f StructMap) GetByTraversal(index []int) *FieldInfo { 45 | if len(index) == 0 { 46 | return nil 47 | } 48 | 49 | tree := f.Tree 50 | for _, i := range index { 51 | if i >= len(tree.Children) || tree.Children[i] == nil { 52 | return nil 53 | } 54 | tree = tree.Children[i] 55 | } 56 | return tree 57 | } 58 | 59 | // Mapper is a general purpose mapper of names to struct fields. 
A Mapper 60 | // behaves like most marshallers in the standard library, obeying a field tag 61 | // for name mapping but also providing a basic transform function. 62 | type Mapper struct { 63 | cache map[reflect.Type]*StructMap 64 | tagName string 65 | tagMapFunc func(string) string 66 | mapFunc func(string) string 67 | mutex sync.Mutex 68 | } 69 | 70 | // NewMapper returns a new mapper using the tagName as its struct field tag. 71 | // If tagName is the empty string, it is ignored. 72 | func NewMapper(tagName string) *Mapper { 73 | return &Mapper{ 74 | cache: make(map[reflect.Type]*StructMap), 75 | tagName: tagName, 76 | } 77 | } 78 | 79 | // NewMapperTagFunc returns a new mapper which contains a mapper for field names 80 | // AND a mapper for tag values. This is useful for tags like json which can 81 | // have values like "name,omitempty". 82 | func NewMapperTagFunc(tagName string, mapFunc, tagMapFunc func(string) string) *Mapper { 83 | return &Mapper{ 84 | cache: make(map[reflect.Type]*StructMap), 85 | tagName: tagName, 86 | mapFunc: mapFunc, 87 | tagMapFunc: tagMapFunc, 88 | } 89 | } 90 | 91 | // NewMapperFunc returns a new mapper which optionally obeys a field tag and 92 | // a struct field name mapper func given by f. Tags will take precedence, but 93 | // for any other field, the mapped name will be f(field.Name) 94 | func NewMapperFunc(tagName string, f func(string) string) *Mapper { 95 | return &Mapper{ 96 | cache: make(map[reflect.Type]*StructMap), 97 | tagName: tagName, 98 | mapFunc: f, 99 | } 100 | } 101 | 102 | // TypeMap returns a mapping of field strings to int slices representing 103 | // the traversal down the struct to reach the field. 
104 | func (m *Mapper) TypeMap(t reflect.Type) *StructMap { 105 | m.mutex.Lock() 106 | mapping, ok := m.cache[t] 107 | if !ok { 108 | mapping = getMapping(t, m.tagName, m.mapFunc, m.tagMapFunc) 109 | m.cache[t] = mapping 110 | } 111 | m.mutex.Unlock() 112 | return mapping 113 | } 114 | 115 | // FieldMap returns the mapper's mapping of field names to reflect values. Panics 116 | // if v's Kind is not Struct, or v is not Indirectable to a struct kind. 117 | func (m *Mapper) FieldMap(v reflect.Value) map[string]reflect.Value { 118 | v = reflect.Indirect(v) 119 | mustBe(v, reflect.Struct) 120 | 121 | r := map[string]reflect.Value{} 122 | tm := m.TypeMap(v.Type()) 123 | for tagName, fi := range tm.Names { 124 | r[tagName] = FieldByIndexes(v, fi.Index) 125 | } 126 | return r 127 | } 128 | 129 | // FieldByName returns a field by its mapped name as a reflect.Value. 130 | // Panics if v's Kind is not Struct or v is not Indirectable to a struct Kind. 131 | // Returns zero Value if the name is not found. 132 | func (m *Mapper) FieldByName(v reflect.Value, name string) reflect.Value { 133 | v = reflect.Indirect(v) 134 | mustBe(v, reflect.Struct) 135 | 136 | tm := m.TypeMap(v.Type()) 137 | fi, ok := tm.Names[name] 138 | if !ok { 139 | return v 140 | } 141 | return FieldByIndexes(v, fi.Index) 142 | } 143 | 144 | // FieldsByName returns a slice of values corresponding to the slice of names 145 | // for the value. Panics if v's Kind is not Struct or v is not Indirectable 146 | // to a struct Kind. Returns zero Value for each name not found. 
147 | func (m *Mapper) FieldsByName(v reflect.Value, names []string) []reflect.Value { 148 | v = reflect.Indirect(v) 149 | mustBe(v, reflect.Struct) 150 | 151 | tm := m.TypeMap(v.Type()) 152 | vals := make([]reflect.Value, 0, len(names)) 153 | for _, name := range names { 154 | fi, ok := tm.Names[name] 155 | if !ok { 156 | vals = append(vals, *new(reflect.Value)) 157 | } else { 158 | vals = append(vals, FieldByIndexes(v, fi.Index)) 159 | } 160 | } 161 | return vals 162 | } 163 | 164 | // TraversalsByName returns a slice of int slices which represent the struct 165 | // traversals for each mapped name. Panics if t is not a struct or Indirectable 166 | // to a struct. Returns empty int slice for each name not found. 167 | func (m *Mapper) TraversalsByName(t reflect.Type, names []string) [][]int { 168 | r := make([][]int, 0, len(names)) 169 | m.TraversalsByNameFunc(t, names, func(_ int, i []int) error { 170 | if i == nil { 171 | r = append(r, []int{}) 172 | } else { 173 | r = append(r, i) 174 | } 175 | 176 | return nil 177 | }) 178 | return r 179 | } 180 | 181 | // TraversalsByNameFunc traverses the mapped names and calls fn with the index of 182 | // each name and the struct traversal represented by that name. Panics if t is not 183 | // a struct or Indirectable to a struct. Returns the first error returned by fn or nil. 184 | func (m *Mapper) TraversalsByNameFunc(t reflect.Type, names []string, fn func(int, []int) error) error { 185 | t = Deref(t) 186 | mustBe(t, reflect.Struct) 187 | tm := m.TypeMap(t) 188 | for i, name := range names { 189 | fi, ok := tm.Names[name] 190 | if !ok { 191 | if err := fn(i, nil); err != nil { 192 | return err 193 | } 194 | } else { 195 | if err := fn(i, fi.Index); err != nil { 196 | return err 197 | } 198 | } 199 | } 200 | return nil 201 | } 202 | 203 | // FieldByIndexes returns a value for the field given by the struct traversal 204 | // for the given value. 
205 | func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value { 206 | for _, i := range indexes { 207 | v = reflect.Indirect(v).Field(i) 208 | // if this is a pointer and it's nil, allocate a new value and set it 209 | if v.Kind() == reflect.Ptr && v.IsNil() { 210 | alloc := reflect.New(Deref(v.Type())) 211 | v.Set(alloc) 212 | } 213 | if v.Kind() == reflect.Map && v.IsNil() { 214 | v.Set(reflect.MakeMap(v.Type())) 215 | } 216 | } 217 | return v 218 | } 219 | 220 | // FieldByIndexesReadOnly returns a value for a particular struct traversal, 221 | // but is not concerned with allocating nil pointers because the value is 222 | // going to be used for reading and not setting. 223 | func FieldByIndexesReadOnly(v reflect.Value, indexes []int) reflect.Value { 224 | for _, i := range indexes { 225 | v = reflect.Indirect(v).Field(i) 226 | } 227 | return v 228 | } 229 | 230 | // Deref is Indirect for reflect.Types 231 | func Deref(t reflect.Type) reflect.Type { 232 | if t.Kind() == reflect.Ptr { 233 | t = t.Elem() 234 | } 235 | return t 236 | } 237 | 238 | // -- helpers & utilities -- 239 | 240 | type kinder interface { 241 | Kind() reflect.Kind 242 | } 243 | 244 | // mustBe checks a value against a kind, panicing with a reflect.ValueError 245 | // if the kind isn't that which is required. 246 | func mustBe(v kinder, expected reflect.Kind) { 247 | if k := v.Kind(); k != expected { 248 | panic(&reflect.ValueError{Method: methodName(), Kind: k}) 249 | } 250 | } 251 | 252 | // methodName returns the caller of the function calling methodName 253 | func methodName() string { 254 | pc, _, _, _ := runtime.Caller(2) 255 | f := runtime.FuncForPC(pc) 256 | if f == nil { 257 | return "unknown method" 258 | } 259 | return f.Name() 260 | } 261 | 262 | type typeQueue struct { 263 | t reflect.Type 264 | fi *FieldInfo 265 | pp string // Parent path 266 | } 267 | 268 | // A copying append that creates a new slice each time. 
func apnd(is []int, i int) []int {
	// Always copy so sibling traversals never alias each other's storage.
	x := make([]int, len(is)+1)
	copy(x, is)
	x[len(x)-1] = i
	return x
}

// mapf maps a Go field name or a raw tag value to a target name.
type mapf func(string) string

// parseName parses the tag and the target name for the given field using
// the tagName (eg 'json' for `json:"foo"` tags), mapFunc for mapping the
// field's name to a target name, and tagMapFunc for mapping the tag to
// a target name.
//
// When tagName is empty or the field has no tagName key, tag is "" and the
// name is the Go field name passed through mapFunc. Otherwise tag is the tag
// value (passed through tagMapFunc) and the name is its part before the first
// comma, which may be "" (e.g. `db:""`) or "-".
func parseName(field reflect.StructField, tagName string, mapFunc, tagMapFunc mapf) (tag, fieldName string) {
	// first, set the fieldName to the field's name
	fieldName = field.Name
	// if a mapFunc is set, use that to override the fieldName
	if mapFunc != nil {
		fieldName = mapFunc(fieldName)
	}

	// if there's no tag to look for, return the field name
	if tagName == "" {
		return "", fieldName
	}

	// Ask the tag parser whether the key is present instead of searching for
	// tagName+":" as a substring. The substring test also matched other keys
	// ending in tagName (e.g. `mydb:"x"` for "db") or text inside another
	// tag's value. In those cases Get returned "" and the field silently lost
	// its name.
	var ok bool
	tag, ok = field.Tag.Lookup(tagName)
	if !ok {
		return "", fieldName
	}

	// if we have a mapper function, call it on the whole tag
	// XXX: this is a change from the old version, which pulled out the name
	// before the tagMapFunc could be run, but I think this is the right way
	if tagMapFunc != nil {
		tag = tagMapFunc(tag)
	}

	// finally, split the options from the name: the name is everything before
	// the first comma
	fieldName = tag
	if i := strings.IndexByte(tag, ','); i >= 0 {
		fieldName = tag[:i]
	}

	return tag, fieldName
}

// parseOptions parses options out of a tag string, skipping the name (the part
// before the first comma). Each remaining comma-separated option becomes a
// key. "key=value" options are split on the first '=' only, so a value may
// itself contain '=' (e.g. "default=a=b" yields "default" -> "a=b"). Options
// without '=' map to "".
func parseOptions(tag string) map[string]string {
	// strings.Split with a non-empty separator always returns at least one
	// element, so parts[1:] is safe even for an empty tag.
	parts := strings.Split(tag, ",")
	options := make(map[string]string, len(parts))
	for _, opt := range parts[1:] {
		if i := strings.IndexByte(opt, '='); i >= 0 {
			options[opt[:i]] = opt[i+1:]
			continue
		}
		options[opt] = ""
	}
	return options
}

// getMapping returns a mapping for the t type, using the tagName, mapFunc and
// tagMapFunc to determine the canonical names of fields.
342 | func getMapping(t reflect.Type, tagName string, mapFunc, tagMapFunc mapf) *StructMap { 343 | m := []*FieldInfo{} 344 | 345 | root := &FieldInfo{} 346 | queue := []typeQueue{} 347 | queue = append(queue, typeQueue{Deref(t), root, ""}) 348 | 349 | QueueLoop: 350 | for len(queue) != 0 { 351 | // pop the first item off of the queue 352 | tq := queue[0] 353 | queue = queue[1:] 354 | 355 | // ignore recursive field 356 | for p := tq.fi.Parent; p != nil; p = p.Parent { 357 | if tq.fi.Field.Type == p.Field.Type { 358 | continue QueueLoop 359 | } 360 | } 361 | 362 | nChildren := 0 363 | if tq.t.Kind() == reflect.Struct { 364 | nChildren = tq.t.NumField() 365 | } 366 | tq.fi.Children = make([]*FieldInfo, nChildren) 367 | 368 | // iterate through all of its fields 369 | for fieldPos := 0; fieldPos < nChildren; fieldPos++ { 370 | 371 | f := tq.t.Field(fieldPos) 372 | 373 | // parse the tag and the target name using the mapping options for this field 374 | tag, name := parseName(f, tagName, mapFunc, tagMapFunc) 375 | 376 | // if the name is "-", disabled via a tag, skip it 377 | if name == "-" { 378 | continue 379 | } 380 | 381 | fi := FieldInfo{ 382 | Field: f, 383 | Name: name, 384 | Zero: reflect.New(f.Type).Elem(), 385 | Options: parseOptions(tag), 386 | } 387 | 388 | // if the path is empty this path is just the name 389 | if tq.pp == "" { 390 | fi.Path = fi.Name 391 | } else { 392 | fi.Path = tq.pp + "." 
+ fi.Name 393 | } 394 | 395 | // skip unexported fields 396 | if len(f.PkgPath) != 0 && !f.Anonymous { 397 | continue 398 | } 399 | 400 | // bfs search of anonymous embedded structs 401 | if f.Anonymous { 402 | pp := tq.pp 403 | if tag != "" { 404 | pp = fi.Path 405 | } 406 | 407 | fi.Embedded = true 408 | fi.Index = apnd(tq.fi.Index, fieldPos) 409 | nChildren := 0 410 | ft := Deref(f.Type) 411 | if ft.Kind() == reflect.Struct { 412 | nChildren = ft.NumField() 413 | } 414 | fi.Children = make([]*FieldInfo, nChildren) 415 | queue = append(queue, typeQueue{Deref(f.Type), &fi, pp}) 416 | } else if fi.Zero.Kind() == reflect.Struct || (fi.Zero.Kind() == reflect.Ptr && fi.Zero.Type().Elem().Kind() == reflect.Struct) { 417 | fi.Index = apnd(tq.fi.Index, fieldPos) 418 | fi.Children = make([]*FieldInfo, Deref(f.Type).NumField()) 419 | queue = append(queue, typeQueue{Deref(f.Type), &fi, fi.Path}) 420 | } 421 | 422 | fi.Index = apnd(tq.fi.Index, fieldPos) 423 | fi.Parent = tq.fi 424 | tq.fi.Children[fieldPos] = &fi 425 | m = append(m, &fi) 426 | } 427 | } 428 | 429 | flds := &StructMap{Index: m, Tree: root, Paths: map[string]*FieldInfo{}, Names: map[string]*FieldInfo{}} 430 | for _, fi := range flds.Index { 431 | // check if nothing has already been pushed with the same path 432 | // sometimes you can choose to override a type using embedded struct 433 | fld, ok := flds.Paths[fi.Path] 434 | if !ok || fld.Embedded { 435 | flds.Paths[fi.Path] = fi 436 | if fi.Name != "" && !fi.Embedded { 437 | flds.Names[fi.Path] = fi 438 | } 439 | } 440 | } 441 | 442 | return flds 443 | } 444 | -------------------------------------------------------------------------------- /reflectx/reflect_test.go: -------------------------------------------------------------------------------- 1 | package reflectx 2 | 3 | import ( 4 | "reflect" 5 | "strings" 6 | "testing" 7 | ) 8 | 9 | func ival(v reflect.Value) int { 10 | return v.Interface().(int) 11 | } 12 | 13 | func TestBasic(t *testing.T) { 14 | type 
Foo struct { 15 | A int 16 | B int 17 | C int 18 | } 19 | 20 | f := Foo{1, 2, 3} 21 | fv := reflect.ValueOf(f) 22 | m := NewMapperFunc("", func(s string) string { return s }) 23 | 24 | v := m.FieldByName(fv, "A") 25 | if ival(v) != f.A { 26 | t.Errorf("Expecting %d, got %d", ival(v), f.A) 27 | } 28 | v = m.FieldByName(fv, "B") 29 | if ival(v) != f.B { 30 | t.Errorf("Expecting %d, got %d", f.B, ival(v)) 31 | } 32 | v = m.FieldByName(fv, "C") 33 | if ival(v) != f.C { 34 | t.Errorf("Expecting %d, got %d", f.C, ival(v)) 35 | } 36 | } 37 | 38 | func TestBasicEmbedded(t *testing.T) { 39 | type Foo struct { 40 | A int 41 | } 42 | 43 | type Bar struct { 44 | Foo // `db:""` is implied for an embedded struct 45 | B int 46 | C int `db:"-"` 47 | } 48 | 49 | type Baz struct { 50 | A int 51 | Bar `db:"Bar"` 52 | } 53 | 54 | m := NewMapperFunc("db", func(s string) string { return s }) 55 | 56 | z := Baz{} 57 | z.A = 1 58 | z.B = 2 59 | z.C = 4 60 | z.Bar.Foo.A = 3 61 | 62 | zv := reflect.ValueOf(z) 63 | fields := m.TypeMap(reflect.TypeOf(z)) 64 | 65 | if len(fields.Index) != 5 { 66 | t.Errorf("Expecting 5 fields") 67 | } 68 | 69 | // for _, fi := range fields.Index { 70 | // log.Println(fi) 71 | // } 72 | 73 | v := m.FieldByName(zv, "A") 74 | if ival(v) != z.A { 75 | t.Errorf("Expecting %d, got %d", z.A, ival(v)) 76 | } 77 | v = m.FieldByName(zv, "Bar.B") 78 | if ival(v) != z.Bar.B { 79 | t.Errorf("Expecting %d, got %d", z.Bar.B, ival(v)) 80 | } 81 | v = m.FieldByName(zv, "Bar.A") 82 | if ival(v) != z.Bar.Foo.A { 83 | t.Errorf("Expecting %d, got %d", z.Bar.Foo.A, ival(v)) 84 | } 85 | v = m.FieldByName(zv, "Bar.C") 86 | if _, ok := v.Interface().(int); ok { 87 | t.Errorf("Expecting Bar.C to not exist") 88 | } 89 | 90 | fi := fields.GetByPath("Bar.C") 91 | if fi != nil { 92 | t.Errorf("Bar.C should not exist") 93 | } 94 | } 95 | 96 | func TestEmbeddedSimple(t *testing.T) { 97 | type UUID [16]byte 98 | type MyID struct { 99 | UUID 100 | } 101 | type Item struct { 102 | ID MyID 103 | 
} 104 | z := Item{} 105 | 106 | m := NewMapper("db") 107 | m.TypeMap(reflect.TypeOf(z)) 108 | } 109 | 110 | func TestBasicEmbeddedWithTags(t *testing.T) { 111 | type Foo struct { 112 | A int `db:"a"` 113 | } 114 | 115 | type Bar struct { 116 | Foo // `db:""` is implied for an embedded struct 117 | B int `db:"b"` 118 | } 119 | 120 | type Baz struct { 121 | A int `db:"a"` 122 | Bar // `db:""` is implied for an embedded struct 123 | } 124 | 125 | m := NewMapper("db") 126 | 127 | z := Baz{} 128 | z.A = 1 129 | z.B = 2 130 | z.Bar.Foo.A = 3 131 | 132 | zv := reflect.ValueOf(z) 133 | fields := m.TypeMap(reflect.TypeOf(z)) 134 | 135 | if len(fields.Index) != 5 { 136 | t.Errorf("Expecting 5 fields") 137 | } 138 | 139 | // for _, fi := range fields.index { 140 | // log.Println(fi) 141 | // } 142 | 143 | v := m.FieldByName(zv, "a") 144 | if ival(v) != z.A { // the dominant field 145 | t.Errorf("Expecting %d, got %d", z.A, ival(v)) 146 | } 147 | v = m.FieldByName(zv, "b") 148 | if ival(v) != z.B { 149 | t.Errorf("Expecting %d, got %d", z.B, ival(v)) 150 | } 151 | } 152 | 153 | func TestBasicEmbeddedWithSameName(t *testing.T) { 154 | type Foo struct { 155 | A int `db:"a"` 156 | Foo int `db:"Foo"` // Same name as the embedded struct 157 | } 158 | 159 | type FooExt struct { 160 | Foo 161 | B int `db:"b"` 162 | } 163 | 164 | m := NewMapper("db") 165 | 166 | z := FooExt{} 167 | z.A = 1 168 | z.B = 2 169 | z.Foo.Foo = 3 170 | 171 | zv := reflect.ValueOf(z) 172 | fields := m.TypeMap(reflect.TypeOf(z)) 173 | 174 | if len(fields.Index) != 4 { 175 | t.Errorf("Expecting 3 fields, found %d", len(fields.Index)) 176 | } 177 | 178 | v := m.FieldByName(zv, "a") 179 | if ival(v) != z.A { // the dominant field 180 | t.Errorf("Expecting %d, got %d", z.A, ival(v)) 181 | } 182 | v = m.FieldByName(zv, "b") 183 | if ival(v) != z.B { 184 | t.Errorf("Expecting %d, got %d", z.B, ival(v)) 185 | } 186 | v = m.FieldByName(zv, "Foo") 187 | if ival(v) != z.Foo.Foo { 188 | t.Errorf("Expecting %d, got %d", 
z.Foo.Foo, ival(v)) 189 | } 190 | } 191 | 192 | func TestFlatTags(t *testing.T) { 193 | m := NewMapper("db") 194 | 195 | type Asset struct { 196 | Title string `db:"title"` 197 | } 198 | type Post struct { 199 | Author string `db:"author,required"` 200 | Asset Asset `db:""` 201 | } 202 | // Post columns: (author title) 203 | 204 | post := Post{Author: "Joe", Asset: Asset{Title: "Hello"}} 205 | pv := reflect.ValueOf(post) 206 | 207 | v := m.FieldByName(pv, "author") 208 | if v.Interface().(string) != post.Author { 209 | t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) 210 | } 211 | v = m.FieldByName(pv, "title") 212 | if v.Interface().(string) != post.Asset.Title { 213 | t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string)) 214 | } 215 | } 216 | 217 | func TestNestedStruct(t *testing.T) { 218 | m := NewMapper("db") 219 | 220 | type Details struct { 221 | Active bool `db:"active"` 222 | } 223 | type Asset struct { 224 | Title string `db:"title"` 225 | Details Details `db:"details"` 226 | } 227 | type Post struct { 228 | Author string `db:"author,required"` 229 | Asset `db:"asset"` 230 | } 231 | // Post columns: (author asset.title asset.details.active) 232 | 233 | post := Post{ 234 | Author: "Joe", 235 | Asset: Asset{Title: "Hello", Details: Details{Active: true}}, 236 | } 237 | pv := reflect.ValueOf(post) 238 | 239 | v := m.FieldByName(pv, "author") 240 | if v.Interface().(string) != post.Author { 241 | t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) 242 | } 243 | v = m.FieldByName(pv, "title") 244 | if _, ok := v.Interface().(string); ok { 245 | t.Errorf("Expecting field to not exist") 246 | } 247 | v = m.FieldByName(pv, "asset.title") 248 | if v.Interface().(string) != post.Asset.Title { 249 | t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string)) 250 | } 251 | v = m.FieldByName(pv, "asset.details.active") 252 | if v.Interface().(bool) != post.Asset.Details.Active { 253 | 
t.Errorf("Expecting %v, got %v", post.Asset.Details.Active, v.Interface().(bool)) 254 | } 255 | } 256 | 257 | func TestInlineStruct(t *testing.T) { 258 | m := NewMapperTagFunc("db", strings.ToLower, nil) 259 | 260 | type Employee struct { 261 | Name string 262 | ID int 263 | } 264 | type Boss Employee 265 | type person struct { 266 | Employee `db:"employee"` 267 | Boss `db:"boss"` 268 | } 269 | // employees columns: (employee.name employee.id boss.name boss.id) 270 | 271 | em := person{Employee: Employee{Name: "Joe", ID: 2}, Boss: Boss{Name: "Dick", ID: 1}} 272 | ev := reflect.ValueOf(em) 273 | 274 | fields := m.TypeMap(reflect.TypeOf(em)) 275 | if len(fields.Index) != 6 { 276 | t.Errorf("Expecting 6 fields") 277 | } 278 | 279 | v := m.FieldByName(ev, "employee.name") 280 | if v.Interface().(string) != em.Employee.Name { 281 | t.Errorf("Expecting %s, got %s", em.Employee.Name, v.Interface().(string)) 282 | } 283 | v = m.FieldByName(ev, "boss.id") 284 | if ival(v) != em.Boss.ID { 285 | t.Errorf("Expecting %v, got %v", em.Boss.ID, ival(v)) 286 | } 287 | } 288 | 289 | func TestRecursiveStruct(t *testing.T) { 290 | type Person struct { 291 | Parent *Person 292 | } 293 | m := NewMapperFunc("db", strings.ToLower) 294 | var p *Person 295 | m.TypeMap(reflect.TypeOf(p)) 296 | } 297 | 298 | func TestFieldsEmbedded(t *testing.T) { 299 | m := NewMapper("db") 300 | 301 | type Person struct { 302 | Name string `db:"name,size=64"` 303 | } 304 | type Place struct { 305 | Name string `db:"name"` 306 | } 307 | type Article struct { 308 | Title string `db:"title"` 309 | } 310 | type PP struct { 311 | Person `db:"person,required"` 312 | Place `db:",someflag"` 313 | Article `db:",required"` 314 | } 315 | // PP columns: (person.name name title) 316 | 317 | pp := PP{} 318 | pp.Person.Name = "Peter" 319 | pp.Place.Name = "Toronto" 320 | pp.Article.Title = "Best city ever" 321 | 322 | fields := m.TypeMap(reflect.TypeOf(pp)) 323 | // for i, f := range fields { 324 | // log.Println(i, f) 325 
| // } 326 | 327 | ppv := reflect.ValueOf(pp) 328 | 329 | v := m.FieldByName(ppv, "person.name") 330 | if v.Interface().(string) != pp.Person.Name { 331 | t.Errorf("Expecting %s, got %s", pp.Person.Name, v.Interface().(string)) 332 | } 333 | 334 | v = m.FieldByName(ppv, "name") 335 | if v.Interface().(string) != pp.Place.Name { 336 | t.Errorf("Expecting %s, got %s", pp.Place.Name, v.Interface().(string)) 337 | } 338 | 339 | v = m.FieldByName(ppv, "title") 340 | if v.Interface().(string) != pp.Article.Title { 341 | t.Errorf("Expecting %s, got %s", pp.Article.Title, v.Interface().(string)) 342 | } 343 | 344 | fi := fields.GetByPath("person") 345 | if _, ok := fi.Options["required"]; !ok { 346 | t.Errorf("Expecting required option to be set") 347 | } 348 | if !fi.Embedded { 349 | t.Errorf("Expecting field to be embedded") 350 | } 351 | if len(fi.Index) != 1 || fi.Index[0] != 0 { 352 | t.Errorf("Expecting index to be [0]") 353 | } 354 | 355 | fi = fields.GetByPath("person.name") 356 | if fi == nil { 357 | t.Fatal("Expecting person.name to exist") 358 | } 359 | if fi.Path != "person.name" { 360 | t.Errorf("Expecting %s, got %s", "person.name", fi.Path) 361 | } 362 | if fi.Options["size"] != "64" { 363 | t.Errorf("Expecting %s, got %s", "64", fi.Options["size"]) 364 | } 365 | 366 | fi = fields.GetByTraversal([]int{1, 0}) 367 | if fi == nil { 368 | t.Fatal("Expecting traversal to exist") 369 | } 370 | if fi.Path != "name" { 371 | t.Errorf("Expecting %s, got %s", "name", fi.Path) 372 | } 373 | 374 | fi = fields.GetByTraversal([]int{2}) 375 | if fi == nil { 376 | t.Fatal("Expecting traversal to exist") 377 | } 378 | if _, ok := fi.Options["required"]; !ok { 379 | t.Errorf("Expecting required option to be set") 380 | } 381 | 382 | trs := m.TraversalsByName(reflect.TypeOf(pp), []string{"person.name", "name", "title"}) 383 | if !reflect.DeepEqual(trs, [][]int{{0, 0}, {1, 0}, {2, 0}}) { 384 | t.Errorf("Expecting traversal: %v", trs) 385 | } 386 | } 387 | 388 | func 
TestPtrFields(t *testing.T) { 389 | m := NewMapperTagFunc("db", strings.ToLower, nil) 390 | type Asset struct { 391 | Title string 392 | } 393 | type Post struct { 394 | *Asset `db:"asset"` 395 | Author string 396 | } 397 | 398 | post := &Post{Author: "Joe", Asset: &Asset{Title: "Hiyo"}} 399 | pv := reflect.ValueOf(post) 400 | 401 | fields := m.TypeMap(reflect.TypeOf(post)) 402 | if len(fields.Index) != 3 { 403 | t.Errorf("Expecting 3 fields") 404 | } 405 | 406 | v := m.FieldByName(pv, "asset.title") 407 | if v.Interface().(string) != post.Asset.Title { 408 | t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string)) 409 | } 410 | v = m.FieldByName(pv, "author") 411 | if v.Interface().(string) != post.Author { 412 | t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) 413 | } 414 | } 415 | 416 | func TestNamedPtrFields(t *testing.T) { 417 | m := NewMapperTagFunc("db", strings.ToLower, nil) 418 | 419 | type User struct { 420 | Name string 421 | } 422 | 423 | type Asset struct { 424 | Title string 425 | 426 | Owner *User `db:"owner"` 427 | } 428 | type Post struct { 429 | Author string 430 | 431 | Asset1 *Asset `db:"asset1"` 432 | Asset2 *Asset `db:"asset2"` 433 | } 434 | 435 | post := &Post{Author: "Joe", Asset1: &Asset{Title: "Hiyo", Owner: &User{"Username"}}} // Let Asset2 be nil 436 | pv := reflect.ValueOf(post) 437 | 438 | fields := m.TypeMap(reflect.TypeOf(post)) 439 | if len(fields.Index) != 9 { 440 | t.Errorf("Expecting 9 fields") 441 | } 442 | 443 | v := m.FieldByName(pv, "asset1.title") 444 | if v.Interface().(string) != post.Asset1.Title { 445 | t.Errorf("Expecting %s, got %s", post.Asset1.Title, v.Interface().(string)) 446 | } 447 | v = m.FieldByName(pv, "asset1.owner.name") 448 | if v.Interface().(string) != post.Asset1.Owner.Name { 449 | t.Errorf("Expecting %s, got %s", post.Asset1.Owner.Name, v.Interface().(string)) 450 | } 451 | v = m.FieldByName(pv, "asset2.title") 452 | if v.Interface().(string) != post.Asset2.Title 
{ 453 | t.Errorf("Expecting %s, got %s", post.Asset2.Title, v.Interface().(string)) 454 | } 455 | v = m.FieldByName(pv, "asset2.owner.name") 456 | if v.Interface().(string) != post.Asset2.Owner.Name { 457 | t.Errorf("Expecting %s, got %s", post.Asset2.Owner.Name, v.Interface().(string)) 458 | } 459 | v = m.FieldByName(pv, "author") 460 | if v.Interface().(string) != post.Author { 461 | t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) 462 | } 463 | } 464 | 465 | func TestFieldMap(t *testing.T) { 466 | type Foo struct { 467 | A int 468 | B int 469 | C int 470 | } 471 | 472 | f := Foo{1, 2, 3} 473 | m := NewMapperFunc("db", strings.ToLower) 474 | 475 | fm := m.FieldMap(reflect.ValueOf(f)) 476 | 477 | if len(fm) != 3 { 478 | t.Errorf("Expecting %d keys, got %d", 3, len(fm)) 479 | } 480 | if fm["a"].Interface().(int) != 1 { 481 | t.Errorf("Expecting %d, got %d", 1, ival(fm["a"])) 482 | } 483 | if fm["b"].Interface().(int) != 2 { 484 | t.Errorf("Expecting %d, got %d", 2, ival(fm["b"])) 485 | } 486 | if fm["c"].Interface().(int) != 3 { 487 | t.Errorf("Expecting %d, got %d", 3, ival(fm["c"])) 488 | } 489 | } 490 | 491 | func TestTagNameMapping(t *testing.T) { 492 | type Strategy struct { 493 | StrategyID string `protobuf:"bytes,1,opt,name=strategy_id" json:"strategy_id,omitempty"` 494 | StrategyName string 495 | } 496 | 497 | m := NewMapperTagFunc("json", strings.ToUpper, func(value string) string { 498 | if strings.Contains(value, ",") { 499 | return strings.Split(value, ",")[0] 500 | } 501 | return value 502 | }) 503 | strategy := Strategy{"1", "Alpah"} 504 | mapping := m.TypeMap(reflect.TypeOf(strategy)) 505 | 506 | for _, key := range []string{"strategy_id", "STRATEGYNAME"} { 507 | if fi := mapping.GetByPath(key); fi == nil { 508 | t.Errorf("Expecting to find key %s in mapping but did not.", key) 509 | } 510 | } 511 | } 512 | 513 | func TestMapping(t *testing.T) { 514 | type Person struct { 515 | ID int 516 | Name string 517 | WearsGlasses bool 
`db:"wears_glasses"` 518 | } 519 | 520 | m := NewMapperFunc("db", strings.ToLower) 521 | p := Person{1, "Jason", true} 522 | mapping := m.TypeMap(reflect.TypeOf(p)) 523 | 524 | for _, key := range []string{"id", "name", "wears_glasses"} { 525 | if fi := mapping.GetByPath(key); fi == nil { 526 | t.Errorf("Expecting to find key %s in mapping but did not.", key) 527 | } 528 | } 529 | 530 | type SportsPerson struct { 531 | Weight int 532 | Age int 533 | Person 534 | } 535 | s := SportsPerson{Weight: 100, Age: 30, Person: p} 536 | mapping = m.TypeMap(reflect.TypeOf(s)) 537 | for _, key := range []string{"id", "name", "wears_glasses", "weight", "age"} { 538 | if fi := mapping.GetByPath(key); fi == nil { 539 | t.Errorf("Expecting to find key %s in mapping but did not.", key) 540 | } 541 | } 542 | 543 | type RugbyPlayer struct { 544 | Position int 545 | IsIntense bool `db:"is_intense"` 546 | IsAllBlack bool `db:"-"` 547 | SportsPerson 548 | } 549 | r := RugbyPlayer{12, true, false, s} 550 | mapping = m.TypeMap(reflect.TypeOf(r)) 551 | for _, key := range []string{"id", "name", "wears_glasses", "weight", "age", "position", "is_intense"} { 552 | if fi := mapping.GetByPath(key); fi == nil { 553 | t.Errorf("Expecting to find key %s in mapping but did not.", key) 554 | } 555 | } 556 | 557 | if fi := mapping.GetByPath("isallblack"); fi != nil { 558 | t.Errorf("Expecting to ignore `IsAllBlack` field") 559 | } 560 | } 561 | 562 | func TestGetByTraversal(t *testing.T) { 563 | type C struct { 564 | C0 int 565 | C1 int 566 | } 567 | type B struct { 568 | B0 string 569 | B1 *C 570 | } 571 | type A struct { 572 | A0 int 573 | A1 B 574 | } 575 | 576 | testCases := []struct { 577 | Index []int 578 | ExpectedName string 579 | ExpectNil bool 580 | }{ 581 | { 582 | Index: []int{0}, 583 | ExpectedName: "A0", 584 | }, 585 | { 586 | Index: []int{1, 0}, 587 | ExpectedName: "B0", 588 | }, 589 | { 590 | Index: []int{1, 1, 1}, 591 | ExpectedName: "C1", 592 | }, 593 | { 594 | Index: []int{3, 4, 5}, 
595 | ExpectNil: true, 596 | }, 597 | { 598 | Index: []int{}, 599 | ExpectNil: true, 600 | }, 601 | { 602 | Index: nil, 603 | ExpectNil: true, 604 | }, 605 | } 606 | 607 | m := NewMapperFunc("db", func(n string) string { return n }) 608 | tm := m.TypeMap(reflect.TypeOf(A{})) 609 | 610 | for i, tc := range testCases { 611 | fi := tm.GetByTraversal(tc.Index) 612 | if tc.ExpectNil { 613 | if fi != nil { 614 | t.Errorf("%d: expected nil, got %v", i, fi) 615 | } 616 | continue 617 | } 618 | 619 | if fi == nil { 620 | t.Errorf("%d: expected %s, got nil", i, tc.ExpectedName) 621 | continue 622 | } 623 | 624 | if fi.Name != tc.ExpectedName { 625 | t.Errorf("%d: expected %s, got %s", i, tc.ExpectedName, fi.Name) 626 | } 627 | } 628 | } 629 | 630 | // TestMapperMethodsByName tests Mapper methods FieldByName and TraversalsByName 631 | func TestMapperMethodsByName(t *testing.T) { 632 | type C struct { 633 | C0 string 634 | C1 int 635 | } 636 | type B struct { 637 | B0 *C `db:"B0"` 638 | B1 C `db:"B1"` 639 | B2 string `db:"B2"` 640 | } 641 | type A struct { 642 | A0 *B `db:"A0"` 643 | B `db:"A1"` 644 | A2 int 645 | } 646 | 647 | val := &A{ 648 | A0: &B{ 649 | B0: &C{C0: "0", C1: 1}, 650 | B1: C{C0: "2", C1: 3}, 651 | B2: "4", 652 | }, 653 | B: B{ 654 | B0: nil, 655 | B1: C{C0: "5", C1: 6}, 656 | B2: "7", 657 | }, 658 | A2: 8, 659 | } 660 | 661 | testCases := []struct { 662 | Name string 663 | ExpectInvalid bool 664 | ExpectedValue interface{} 665 | ExpectedIndexes []int 666 | }{ 667 | { 668 | Name: "A0.B0.C0", 669 | ExpectedValue: "0", 670 | ExpectedIndexes: []int{0, 0, 0}, 671 | }, 672 | { 673 | Name: "A0.B0.C1", 674 | ExpectedValue: 1, 675 | ExpectedIndexes: []int{0, 0, 1}, 676 | }, 677 | { 678 | Name: "A0.B1.C0", 679 | ExpectedValue: "2", 680 | ExpectedIndexes: []int{0, 1, 0}, 681 | }, 682 | { 683 | Name: "A0.B1.C1", 684 | ExpectedValue: 3, 685 | ExpectedIndexes: []int{0, 1, 1}, 686 | }, 687 | { 688 | Name: "A0.B2", 689 | ExpectedValue: "4", 690 | ExpectedIndexes: []int{0, 
2}, 691 | }, 692 | { 693 | Name: "A1.B0.C0", 694 | ExpectedValue: "", 695 | ExpectedIndexes: []int{1, 0, 0}, 696 | }, 697 | { 698 | Name: "A1.B0.C1", 699 | ExpectedValue: 0, 700 | ExpectedIndexes: []int{1, 0, 1}, 701 | }, 702 | { 703 | Name: "A1.B1.C0", 704 | ExpectedValue: "5", 705 | ExpectedIndexes: []int{1, 1, 0}, 706 | }, 707 | { 708 | Name: "A1.B1.C1", 709 | ExpectedValue: 6, 710 | ExpectedIndexes: []int{1, 1, 1}, 711 | }, 712 | { 713 | Name: "A1.B2", 714 | ExpectedValue: "7", 715 | ExpectedIndexes: []int{1, 2}, 716 | }, 717 | { 718 | Name: "A2", 719 | ExpectedValue: 8, 720 | ExpectedIndexes: []int{2}, 721 | }, 722 | { 723 | Name: "XYZ", 724 | ExpectInvalid: true, 725 | ExpectedIndexes: []int{}, 726 | }, 727 | { 728 | Name: "a3", 729 | ExpectInvalid: true, 730 | ExpectedIndexes: []int{}, 731 | }, 732 | } 733 | 734 | // build the names array from the test cases 735 | names := make([]string, len(testCases)) 736 | for i, tc := range testCases { 737 | names[i] = tc.Name 738 | } 739 | m := NewMapperFunc("db", func(n string) string { return n }) 740 | v := reflect.ValueOf(val) 741 | values := m.FieldsByName(v, names) 742 | if len(values) != len(testCases) { 743 | t.Errorf("expected %d values, got %d", len(testCases), len(values)) 744 | t.FailNow() 745 | } 746 | indexes := m.TraversalsByName(v.Type(), names) 747 | if len(indexes) != len(testCases) { 748 | t.Errorf("expected %d traversals, got %d", len(testCases), len(indexes)) 749 | t.FailNow() 750 | } 751 | for i, val := range values { 752 | tc := testCases[i] 753 | traversal := indexes[i] 754 | if !reflect.DeepEqual(tc.ExpectedIndexes, traversal) { 755 | t.Errorf("expected %v, got %v", tc.ExpectedIndexes, traversal) 756 | t.FailNow() 757 | } 758 | val = reflect.Indirect(val) 759 | if tc.ExpectInvalid { 760 | if val.IsValid() { 761 | t.Errorf("%d: expected zero value, got %v", i, val) 762 | } 763 | continue 764 | } 765 | if !val.IsValid() { 766 | t.Errorf("%d: expected valid value, got %v", i, val) 767 | continue 
768 | } 769 | actualValue := reflect.Indirect(val).Interface() 770 | if !reflect.DeepEqual(tc.ExpectedValue, actualValue) { 771 | t.Errorf("%d: expected %v, got %v", i, tc.ExpectedValue, actualValue) 772 | } 773 | } 774 | } 775 | 776 | func TestFieldByIndexes(t *testing.T) { 777 | type C struct { 778 | C0 bool 779 | C1 string 780 | C2 int 781 | C3 map[string]int 782 | } 783 | type B struct { 784 | B1 C 785 | B2 *C 786 | } 787 | type A struct { 788 | A1 B 789 | A2 *B 790 | } 791 | testCases := []struct { 792 | value interface{} 793 | indexes []int 794 | expectedValue interface{} 795 | readOnly bool 796 | }{ 797 | { 798 | value: A{ 799 | A1: B{B1: C{C0: true}}, 800 | }, 801 | indexes: []int{0, 0, 0}, 802 | expectedValue: true, 803 | readOnly: true, 804 | }, 805 | { 806 | value: A{ 807 | A2: &B{B2: &C{C1: "answer"}}, 808 | }, 809 | indexes: []int{1, 1, 1}, 810 | expectedValue: "answer", 811 | readOnly: true, 812 | }, 813 | { 814 | value: &A{}, 815 | indexes: []int{1, 1, 3}, 816 | expectedValue: map[string]int{}, 817 | }, 818 | } 819 | 820 | for i, tc := range testCases { 821 | checkResults := func(v reflect.Value) { 822 | if tc.expectedValue == nil { 823 | if !v.IsNil() { 824 | t.Errorf("%d: expected nil, actual %v", i, v.Interface()) 825 | } 826 | } else { 827 | if !reflect.DeepEqual(tc.expectedValue, v.Interface()) { 828 | t.Errorf("%d: expected %v, actual %v", i, tc.expectedValue, v.Interface()) 829 | } 830 | } 831 | } 832 | 833 | checkResults(FieldByIndexes(reflect.ValueOf(tc.value), tc.indexes)) 834 | if tc.readOnly { 835 | checkResults(FieldByIndexesReadOnly(reflect.ValueOf(tc.value), tc.indexes)) 836 | } 837 | } 838 | } 839 | 840 | func TestMustBe(t *testing.T) { 841 | typ := reflect.TypeOf(E1{}) 842 | mustBe(typ, reflect.Struct) 843 | 844 | defer func() { 845 | if r := recover(); r != nil { 846 | valueErr, ok := r.(*reflect.ValueError) 847 | if !ok { 848 | t.Errorf("unexpected Method: %s", valueErr.Method) 849 | t.Fatal("expected panic with 
*reflect.ValueError") 850 | } 851 | if valueErr.Method != "github.com/jmoiron/sqlx/reflectx.TestMustBe" { 852 | t.Fatalf("unexpected Method: %s", valueErr.Method) 853 | } 854 | if valueErr.Kind != reflect.String { 855 | t.Fatalf("unexpected Kind: %s", valueErr.Kind) 856 | } 857 | } else { 858 | t.Fatal("expected panic") 859 | } 860 | }() 861 | 862 | typ = reflect.TypeOf("string") 863 | mustBe(typ, reflect.Struct) 864 | t.Fatal("got here, didn't expect to") 865 | } 866 | 867 | type E1 struct { 868 | A int 869 | } 870 | type E2 struct { 871 | E1 872 | B int 873 | } 874 | type E3 struct { 875 | E2 876 | C int 877 | } 878 | type E4 struct { 879 | E3 880 | D int 881 | } 882 | 883 | func BenchmarkFieldNameL1(b *testing.B) { 884 | e4 := E4{D: 1} 885 | for i := 0; i < b.N; i++ { 886 | v := reflect.ValueOf(e4) 887 | f := v.FieldByName("D") 888 | if f.Interface().(int) != 1 { 889 | b.Fatal("Wrong value.") 890 | } 891 | } 892 | } 893 | 894 | func BenchmarkFieldNameL4(b *testing.B) { 895 | e4 := E4{} 896 | e4.A = 1 897 | for i := 0; i < b.N; i++ { 898 | v := reflect.ValueOf(e4) 899 | f := v.FieldByName("A") 900 | if f.Interface().(int) != 1 { 901 | b.Fatal("Wrong value.") 902 | } 903 | } 904 | } 905 | 906 | func BenchmarkFieldPosL1(b *testing.B) { 907 | e4 := E4{D: 1} 908 | for i := 0; i < b.N; i++ { 909 | v := reflect.ValueOf(e4) 910 | f := v.Field(1) 911 | if f.Interface().(int) != 1 { 912 | b.Fatal("Wrong value.") 913 | } 914 | } 915 | } 916 | 917 | func BenchmarkFieldPosL4(b *testing.B) { 918 | e4 := E4{} 919 | e4.A = 1 920 | for i := 0; i < b.N; i++ { 921 | v := reflect.ValueOf(e4) 922 | f := v.Field(0) 923 | f = f.Field(0) 924 | f = f.Field(0) 925 | f = f.Field(0) 926 | if f.Interface().(int) != 1 { 927 | b.Fatal("Wrong value.") 928 | } 929 | } 930 | } 931 | 932 | func BenchmarkFieldByIndexL4(b *testing.B) { 933 | e4 := E4{} 934 | e4.A = 1 935 | idx := []int{0, 0, 0, 0} 936 | for i := 0; i < b.N; i++ { 937 | v := reflect.ValueOf(e4) 938 | f := FieldByIndexes(v, idx) 939 
| if f.Interface().(int) != 1 { 940 | b.Fatal("Wrong value.") 941 | } 942 | } 943 | } 944 | 945 | func BenchmarkTraversalsByName(b *testing.B) { 946 | type A struct { 947 | Value int 948 | } 949 | 950 | type B struct { 951 | A A 952 | } 953 | 954 | type C struct { 955 | B B 956 | } 957 | 958 | type D struct { 959 | C C 960 | } 961 | 962 | m := NewMapper("") 963 | t := reflect.TypeOf(D{}) 964 | names := []string{"C", "B", "A", "Value"} 965 | 966 | b.ResetTimer() 967 | 968 | for i := 0; i < b.N; i++ { 969 | if l := len(m.TraversalsByName(t, names)); l != len(names) { 970 | b.Errorf("expected %d values, got %d", len(names), l) 971 | } 972 | } 973 | } 974 | 975 | func BenchmarkTraversalsByNameFunc(b *testing.B) { 976 | type A struct { 977 | Z int 978 | } 979 | 980 | type B struct { 981 | A A 982 | } 983 | 984 | type C struct { 985 | B B 986 | } 987 | 988 | type D struct { 989 | C C 990 | } 991 | 992 | m := NewMapper("") 993 | t := reflect.TypeOf(D{}) 994 | names := []string{"C", "B", "A", "Z", "Y"} 995 | 996 | b.ResetTimer() 997 | 998 | for i := 0; i < b.N; i++ { 999 | var l int 1000 | 1001 | if err := m.TraversalsByNameFunc(t, names, func(_ int, _ []int) error { 1002 | l++ 1003 | return nil 1004 | }); err != nil { 1005 | b.Errorf("unexpected error %s", err) 1006 | } 1007 | 1008 | if l != len(names) { 1009 | b.Errorf("expected %d values, got %d", len(names), l) 1010 | } 1011 | } 1012 | } 1013 | -------------------------------------------------------------------------------- /sqlx.go: -------------------------------------------------------------------------------- 1 | package sqlx 2 | 3 | import ( 4 | "database/sql" 5 | "database/sql/driver" 6 | "errors" 7 | "fmt" 8 | "io/ioutil" 9 | "path/filepath" 10 | "reflect" 11 | "strings" 12 | "sync" 13 | 14 | "github.com/jmoiron/sqlx/reflectx" 15 | ) 16 | 17 | // Although the NameMapper is convenient, in practice it should not 18 | // be relied on except for application code. 
If you are writing a library 19 | // that uses sqlx, you should be aware that the name mappings you expect 20 | // can be overridden by your user's application. 21 | 22 | // NameMapper is used to map column names to struct field names. By default, 23 | // it uses strings.ToLower to lowercase struct field names. It can be set 24 | // to whatever you want, but it is encouraged to be set before sqlx is used 25 | // as name-to-field mappings are cached after first use on a type. 26 | var NameMapper = strings.ToLower 27 | var origMapper = reflect.ValueOf(NameMapper) 28 | 29 | // Rather than creating on init, this is created when necessary so that 30 | // importers have time to customize the NameMapper. 31 | var mpr *reflectx.Mapper 32 | 33 | // mprMu protects mpr. 34 | var mprMu sync.Mutex 35 | 36 | // mapper returns a valid mapper using the configured NameMapper func. 37 | func mapper() *reflectx.Mapper { 38 | mprMu.Lock() 39 | defer mprMu.Unlock() 40 | 41 | if mpr == nil { 42 | mpr = reflectx.NewMapperFunc("db", NameMapper) 43 | } else if origMapper != reflect.ValueOf(NameMapper) { 44 | // if NameMapper has changed, create a new mapper 45 | mpr = reflectx.NewMapperFunc("db", NameMapper) 46 | origMapper = reflect.ValueOf(NameMapper) 47 | } 48 | return mpr 49 | } 50 | 51 | // isScannable takes the reflect.Type and the actual dest value and returns 52 | // whether or not it's Scannable. 
Something is scannable if: 53 | // - it is not a struct 54 | // - it implements sql.Scanner 55 | // - it has no exported fields 56 | func isScannable(t reflect.Type) bool { 57 | if reflect.PtrTo(t).Implements(_scannerInterface) { 58 | return true 59 | } 60 | if t.Kind() != reflect.Struct { 61 | return true 62 | } 63 | 64 | // it's not important that we use the right mapper for this particular object, 65 | // we're only concerned on how many exported fields this struct has 66 | return len(mapper().TypeMap(t).Index) == 0 67 | } 68 | 69 | // ColScanner is an interface used by MapScan and SliceScan 70 | type ColScanner interface { 71 | Columns() ([]string, error) 72 | Scan(dest ...interface{}) error 73 | Err() error 74 | } 75 | 76 | // Queryer is an interface used by Get and Select 77 | type Queryer interface { 78 | Query(query string, args ...interface{}) (*sql.Rows, error) 79 | Queryx(query string, args ...interface{}) (*Rows, error) 80 | QueryRowx(query string, args ...interface{}) *Row 81 | } 82 | 83 | // Execer is an interface used by MustExec and LoadFile 84 | type Execer interface { 85 | Exec(query string, args ...interface{}) (sql.Result, error) 86 | } 87 | 88 | // Binder is an interface for something which can bind queries (Tx, DB) 89 | type binder interface { 90 | DriverName() string 91 | Rebind(string) string 92 | BindNamed(string, interface{}) (string, []interface{}, error) 93 | } 94 | 95 | // Ext is a union interface which can bind, query, and exec, used by 96 | // NamedQuery and NamedExec. 97 | type Ext interface { 98 | binder 99 | Queryer 100 | Execer 101 | } 102 | 103 | // Preparer is an interface used by Preparex. 
104 | type Preparer interface { 105 | Prepare(query string) (*sql.Stmt, error) 106 | } 107 | 108 | // determine if any of our extensions are unsafe 109 | func isUnsafe(i interface{}) bool { 110 | switch v := i.(type) { 111 | case Row: 112 | return v.unsafe 113 | case *Row: 114 | return v.unsafe 115 | case Rows: 116 | return v.unsafe 117 | case *Rows: 118 | return v.unsafe 119 | case NamedStmt: 120 | return v.Stmt.unsafe 121 | case *NamedStmt: 122 | return v.Stmt.unsafe 123 | case Stmt: 124 | return v.unsafe 125 | case *Stmt: 126 | return v.unsafe 127 | case qStmt: 128 | return v.unsafe 129 | case *qStmt: 130 | return v.unsafe 131 | case DB: 132 | return v.unsafe 133 | case *DB: 134 | return v.unsafe 135 | case Tx: 136 | return v.unsafe 137 | case *Tx: 138 | return v.unsafe 139 | case sql.Rows, *sql.Rows: 140 | return false 141 | default: 142 | return false 143 | } 144 | } 145 | 146 | func mapperFor(i interface{}) *reflectx.Mapper { 147 | switch i := i.(type) { 148 | case DB: 149 | return i.Mapper 150 | case *DB: 151 | return i.Mapper 152 | case Tx: 153 | return i.Mapper 154 | case *Tx: 155 | return i.Mapper 156 | default: 157 | return mapper() 158 | } 159 | } 160 | 161 | var _scannerInterface = reflect.TypeOf((*sql.Scanner)(nil)).Elem() 162 | 163 | //lint:ignore U1000 ignoring this for now 164 | var _valuerInterface = reflect.TypeOf((*driver.Valuer)(nil)).Elem() 165 | 166 | // Row is a reimplementation of sql.Row in order to gain access to the underlying 167 | // sql.Rows.Columns() data, necessary for StructScan. 168 | type Row struct { 169 | err error 170 | unsafe bool 171 | rows *sql.Rows 172 | Mapper *reflectx.Mapper 173 | } 174 | 175 | // Scan is a fixed implementation of sql.Row.Scan, which does not discard the 176 | // underlying error from the internal rows object if it exists. 
177 | func (r *Row) Scan(dest ...interface{}) error { 178 | if r.err != nil { 179 | return r.err 180 | } 181 | 182 | // TODO(bradfitz): for now we need to defensively clone all 183 | // []byte that the driver returned (not permitting 184 | // *RawBytes in Rows.Scan), since we're about to close 185 | // the Rows in our defer, when we return from this function. 186 | // the contract with the driver.Next(...) interface is that it 187 | // can return slices into read-only temporary memory that's 188 | // only valid until the next Scan/Close. But the TODO is that 189 | // for a lot of drivers, this copy will be unnecessary. We 190 | // should provide an optional interface for drivers to 191 | // implement to say, "don't worry, the []bytes that I return 192 | // from Next will not be modified again." (for instance, if 193 | // they were obtained from the network anyway) But for now we 194 | // don't care. 195 | defer r.rows.Close() 196 | for _, dp := range dest { 197 | if _, ok := dp.(*sql.RawBytes); ok { 198 | return errors.New("sql: RawBytes isn't allowed on Row.Scan") 199 | } 200 | } 201 | 202 | if !r.rows.Next() { 203 | if err := r.rows.Err(); err != nil { 204 | return err 205 | } 206 | return sql.ErrNoRows 207 | } 208 | err := r.rows.Scan(dest...) 209 | if err != nil { 210 | return err 211 | } 212 | // Make sure the query can be processed to completion with no errors. 
213 | if err := r.rows.Close(); err != nil { 214 | return err 215 | } 216 | return nil 217 | } 218 | 219 | // Columns returns the underlying sql.Rows.Columns(), or the deferred error usually 220 | // returned by Row.Scan() 221 | func (r *Row) Columns() ([]string, error) { 222 | if r.err != nil { 223 | return []string{}, r.err 224 | } 225 | return r.rows.Columns() 226 | } 227 | 228 | // ColumnTypes returns the underlying sql.Rows.ColumnTypes(), or the deferred error 229 | func (r *Row) ColumnTypes() ([]*sql.ColumnType, error) { 230 | if r.err != nil { 231 | return []*sql.ColumnType{}, r.err 232 | } 233 | return r.rows.ColumnTypes() 234 | } 235 | 236 | // Err returns the error encountered while scanning. 237 | func (r *Row) Err() error { 238 | return r.err 239 | } 240 | 241 | // DB is a wrapper around sql.DB which keeps track of the driverName upon Open, 242 | // used mostly to automatically bind named queries using the right bindvars. 243 | type DB struct { 244 | *sql.DB 245 | driverName string 246 | unsafe bool 247 | Mapper *reflectx.Mapper 248 | } 249 | 250 | // NewDb returns a new sqlx DB wrapper for a pre-existing *sql.DB. The 251 | // driverName of the original database is required for named query support. 252 | // 253 | //lint:ignore ST1003 changing this would break the package interface. 254 | func NewDb(db *sql.DB, driverName string) *DB { 255 | return &DB{DB: db, driverName: driverName, Mapper: mapper()} 256 | } 257 | 258 | // DriverName returns the driverName passed to the Open function for this DB. 259 | func (db *DB) DriverName() string { 260 | return db.driverName 261 | } 262 | 263 | // Open is the same as sql.Open, but returns an *sqlx.DB instead. 
264 | func Open(driverName, dataSourceName string) (*DB, error) { 265 | db, err := sql.Open(driverName, dataSourceName) 266 | if err != nil { 267 | return nil, err 268 | } 269 | return &DB{DB: db, driverName: driverName, Mapper: mapper()}, err 270 | } 271 | 272 | // MustOpen is the same as sql.Open, but returns an *sqlx.DB instead and panics on error. 273 | func MustOpen(driverName, dataSourceName string) *DB { 274 | db, err := Open(driverName, dataSourceName) 275 | if err != nil { 276 | panic(err) 277 | } 278 | return db 279 | } 280 | 281 | // MapperFunc sets a new mapper for this db using the default sqlx struct tag 282 | // and the provided mapper function. 283 | func (db *DB) MapperFunc(mf func(string) string) { 284 | db.Mapper = reflectx.NewMapperFunc("db", mf) 285 | } 286 | 287 | // Rebind transforms a query from QUESTION to the DB driver's bindvar type. 288 | func (db *DB) Rebind(query string) string { 289 | return Rebind(BindType(db.driverName), query) 290 | } 291 | 292 | // Unsafe returns a version of DB which will silently succeed to scan when 293 | // columns in the SQL result have no fields in the destination struct. 294 | // sqlx.Stmt and sqlx.Tx which are created from this DB will inherit its 295 | // safety behavior. 296 | func (db *DB) Unsafe() *DB { 297 | return &DB{DB: db.DB, driverName: db.driverName, unsafe: true, Mapper: db.Mapper} 298 | } 299 | 300 | // BindNamed binds a query using the DB driver's bindvar type. 301 | func (db *DB) BindNamed(query string, arg interface{}) (string, []interface{}, error) { 302 | return bindNamedMapper(BindType(db.driverName), query, arg, db.Mapper) 303 | } 304 | 305 | // NamedQuery using this DB. 306 | // Any named placeholder parameters are replaced with fields from arg. 307 | func (db *DB) NamedQuery(query string, arg interface{}) (*Rows, error) { 308 | return NamedQuery(db, query, arg) 309 | } 310 | 311 | // NamedExec using this DB. 312 | // Any named placeholder parameters are replaced with fields from arg. 
313 | func (db *DB) NamedExec(query string, arg interface{}) (sql.Result, error) { 314 | return NamedExec(db, query, arg) 315 | } 316 | 317 | // Select using this DB. 318 | // Any placeholder parameters are replaced with supplied args. 319 | func (db *DB) Select(dest interface{}, query string, args ...interface{}) error { 320 | return Select(db, dest, query, args...) 321 | } 322 | 323 | // Get using this DB. 324 | // Any placeholder parameters are replaced with supplied args. 325 | // An error is returned if the result set is empty. 326 | func (db *DB) Get(dest interface{}, query string, args ...interface{}) error { 327 | return Get(db, dest, query, args...) 328 | } 329 | 330 | // MustBegin starts a transaction, and panics on error. Returns an *sqlx.Tx instead 331 | // of an *sql.Tx. 332 | func (db *DB) MustBegin() *Tx { 333 | tx, err := db.Beginx() 334 | if err != nil { 335 | panic(err) 336 | } 337 | return tx 338 | } 339 | 340 | // Beginx begins a transaction and returns an *sqlx.Tx instead of an *sql.Tx. 341 | func (db *DB) Beginx() (*Tx, error) { 342 | tx, err := db.DB.Begin() 343 | if err != nil { 344 | return nil, err 345 | } 346 | return &Tx{Tx: tx, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, err 347 | } 348 | 349 | // Queryx queries the database and returns an *sqlx.Rows. 350 | // Any placeholder parameters are replaced with supplied args. 351 | func (db *DB) Queryx(query string, args ...interface{}) (*Rows, error) { 352 | r, err := db.DB.Query(query, args...) 353 | if err != nil { 354 | return nil, err 355 | } 356 | return &Rows{Rows: r, unsafe: db.unsafe, Mapper: db.Mapper}, err 357 | } 358 | 359 | // QueryRowx queries the database and returns an *sqlx.Row. 360 | // Any placeholder parameters are replaced with supplied args. 361 | func (db *DB) QueryRowx(query string, args ...interface{}) *Row { 362 | rows, err := db.DB.Query(query, args...) 
363 | return &Row{rows: rows, err: err, unsafe: db.unsafe, Mapper: db.Mapper} 364 | } 365 | 366 | // MustExec (panic) runs MustExec using this database. 367 | // Any placeholder parameters are replaced with supplied args. 368 | func (db *DB) MustExec(query string, args ...interface{}) sql.Result { 369 | return MustExec(db, query, args...) 370 | } 371 | 372 | // Preparex returns an sqlx.Stmt instead of a sql.Stmt 373 | func (db *DB) Preparex(query string) (*Stmt, error) { 374 | return Preparex(db, query) 375 | } 376 | 377 | // PrepareNamed returns an sqlx.NamedStmt 378 | func (db *DB) PrepareNamed(query string) (*NamedStmt, error) { 379 | return prepareNamed(db, query) 380 | } 381 | 382 | // Conn is a wrapper around sql.Conn with extra functionality 383 | type Conn struct { 384 | *sql.Conn 385 | driverName string 386 | unsafe bool 387 | Mapper *reflectx.Mapper 388 | } 389 | 390 | // Tx is an sqlx wrapper around sql.Tx with extra functionality 391 | type Tx struct { 392 | *sql.Tx 393 | driverName string 394 | unsafe bool 395 | Mapper *reflectx.Mapper 396 | } 397 | 398 | // DriverName returns the driverName used by the DB which began this transaction. 399 | func (tx *Tx) DriverName() string { 400 | return tx.driverName 401 | } 402 | 403 | // Rebind a query within a transaction's bindvar type. 404 | func (tx *Tx) Rebind(query string) string { 405 | return Rebind(BindType(tx.driverName), query) 406 | } 407 | 408 | // Unsafe returns a version of Tx which will silently succeed to scan when 409 | // columns in the SQL result have no fields in the destination struct. 410 | func (tx *Tx) Unsafe() *Tx { 411 | return &Tx{Tx: tx.Tx, driverName: tx.driverName, unsafe: true, Mapper: tx.Mapper} 412 | } 413 | 414 | // BindNamed binds a query within a transaction's bindvar type. 
415 | func (tx *Tx) BindNamed(query string, arg interface{}) (string, []interface{}, error) { 416 | return bindNamedMapper(BindType(tx.driverName), query, arg, tx.Mapper) 417 | } 418 | 419 | // NamedQuery within a transaction. 420 | // Any named placeholder parameters are replaced with fields from arg. 421 | func (tx *Tx) NamedQuery(query string, arg interface{}) (*Rows, error) { 422 | return NamedQuery(tx, query, arg) 423 | } 424 | 425 | // NamedExec a named query within a transaction. 426 | // Any named placeholder parameters are replaced with fields from arg. 427 | func (tx *Tx) NamedExec(query string, arg interface{}) (sql.Result, error) { 428 | return NamedExec(tx, query, arg) 429 | } 430 | 431 | // Select within a transaction. 432 | // Any placeholder parameters are replaced with supplied args. 433 | func (tx *Tx) Select(dest interface{}, query string, args ...interface{}) error { 434 | return Select(tx, dest, query, args...) 435 | } 436 | 437 | // Queryx within a transaction. 438 | // Any placeholder parameters are replaced with supplied args. 439 | func (tx *Tx) Queryx(query string, args ...interface{}) (*Rows, error) { 440 | r, err := tx.Tx.Query(query, args...) 441 | if err != nil { 442 | return nil, err 443 | } 444 | return &Rows{Rows: r, unsafe: tx.unsafe, Mapper: tx.Mapper}, err 445 | } 446 | 447 | // QueryRowx within a transaction. 448 | // Any placeholder parameters are replaced with supplied args. 449 | func (tx *Tx) QueryRowx(query string, args ...interface{}) *Row { 450 | rows, err := tx.Tx.Query(query, args...) 451 | return &Row{rows: rows, err: err, unsafe: tx.unsafe, Mapper: tx.Mapper} 452 | } 453 | 454 | // Get within a transaction. 455 | // Any placeholder parameters are replaced with supplied args. 456 | // An error is returned if the result set is empty. 457 | func (tx *Tx) Get(dest interface{}, query string, args ...interface{}) error { 458 | return Get(tx, dest, query, args...) 
459 | } 460 | 461 | // MustExec runs MustExec within a transaction. 462 | // Any placeholder parameters are replaced with supplied args. 463 | func (tx *Tx) MustExec(query string, args ...interface{}) sql.Result { 464 | return MustExec(tx, query, args...) 465 | } 466 | 467 | // Preparex a statement within a transaction. 468 | func (tx *Tx) Preparex(query string) (*Stmt, error) { 469 | return Preparex(tx, query) 470 | } 471 | 472 | // Stmtx returns a version of the prepared statement which runs within a transaction. Provided 473 | // stmt can be either *sql.Stmt or *sqlx.Stmt. 474 | func (tx *Tx) Stmtx(stmt interface{}) *Stmt { 475 | var s *sql.Stmt 476 | switch v := stmt.(type) { 477 | case Stmt: 478 | s = v.Stmt 479 | case *Stmt: 480 | s = v.Stmt 481 | case *sql.Stmt: 482 | s = v 483 | default: 484 | panic(fmt.Sprintf("non-statement type %v passed to Stmtx", reflect.ValueOf(stmt).Type())) 485 | } 486 | return &Stmt{Stmt: tx.Stmt(s), Mapper: tx.Mapper} 487 | } 488 | 489 | // NamedStmt returns a version of the prepared statement which runs within a transaction. 490 | func (tx *Tx) NamedStmt(stmt *NamedStmt) *NamedStmt { 491 | return &NamedStmt{ 492 | QueryString: stmt.QueryString, 493 | Params: stmt.Params, 494 | Stmt: tx.Stmtx(stmt.Stmt), 495 | } 496 | } 497 | 498 | // PrepareNamed returns an sqlx.NamedStmt 499 | func (tx *Tx) PrepareNamed(query string) (*NamedStmt, error) { 500 | return prepareNamed(tx, query) 501 | } 502 | 503 | // Stmt is an sqlx wrapper around sql.Stmt with extra functionality 504 | type Stmt struct { 505 | *sql.Stmt 506 | unsafe bool 507 | Mapper *reflectx.Mapper 508 | } 509 | 510 | // Unsafe returns a version of Stmt which will silently succeed to scan when 511 | // columns in the SQL result have no fields in the destination struct. 512 | func (s *Stmt) Unsafe() *Stmt { 513 | return &Stmt{Stmt: s.Stmt, unsafe: true, Mapper: s.Mapper} 514 | } 515 | 516 | // Select using the prepared statement. 
517 | // Any placeholder parameters are replaced with supplied args. 518 | func (s *Stmt) Select(dest interface{}, args ...interface{}) error { 519 | return Select(&qStmt{s}, dest, "", args...) 520 | } 521 | 522 | // Get using the prepared statement. 523 | // Any placeholder parameters are replaced with supplied args. 524 | // An error is returned if the result set is empty. 525 | func (s *Stmt) Get(dest interface{}, args ...interface{}) error { 526 | return Get(&qStmt{s}, dest, "", args...) 527 | } 528 | 529 | // MustExec (panic) using this statement. Note that the query portion of the error 530 | // output will be blank, as Stmt does not expose its query. 531 | // Any placeholder parameters are replaced with supplied args. 532 | func (s *Stmt) MustExec(args ...interface{}) sql.Result { 533 | return MustExec(&qStmt{s}, "", args...) 534 | } 535 | 536 | // QueryRowx using this statement. 537 | // Any placeholder parameters are replaced with supplied args. 538 | func (s *Stmt) QueryRowx(args ...interface{}) *Row { 539 | qs := &qStmt{s} 540 | return qs.QueryRowx("", args...) 541 | } 542 | 543 | // Queryx using this statement. 544 | // Any placeholder parameters are replaced with supplied args. 545 | func (s *Stmt) Queryx(args ...interface{}) (*Rows, error) { 546 | qs := &qStmt{s} 547 | return qs.Queryx("", args...) 548 | } 549 | 550 | // qStmt is an unexposed wrapper which lets you use a Stmt as a Queryer & Execer by 551 | // implementing those interfaces and ignoring the `query` argument. 552 | type qStmt struct{ *Stmt } 553 | 554 | func (q *qStmt) Query(query string, args ...interface{}) (*sql.Rows, error) { 555 | return q.Stmt.Query(args...) 556 | } 557 | 558 | func (q *qStmt) Queryx(query string, args ...interface{}) (*Rows, error) { 559 | r, err := q.Stmt.Query(args...) 
560 | if err != nil { 561 | return nil, err 562 | } 563 | return &Rows{Rows: r, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}, err 564 | } 565 | 566 | func (q *qStmt) QueryRowx(query string, args ...interface{}) *Row { 567 | rows, err := q.Stmt.Query(args...) 568 | return &Row{rows: rows, err: err, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper} 569 | } 570 | 571 | func (q *qStmt) Exec(query string, args ...interface{}) (sql.Result, error) { 572 | return q.Stmt.Exec(args...) 573 | } 574 | 575 | // Rows is a wrapper around sql.Rows which caches costly reflect operations 576 | // during a looped StructScan 577 | type Rows struct { 578 | *sql.Rows 579 | unsafe bool 580 | Mapper *reflectx.Mapper 581 | // these fields cache memory use for a rows during iteration w/ structScan 582 | started bool 583 | fields [][]int 584 | values []interface{} 585 | } 586 | 587 | // SliceScan using this Rows. 588 | func (r *Rows) SliceScan() ([]interface{}, error) { 589 | return SliceScan(r) 590 | } 591 | 592 | // MapScan using this Rows. 593 | func (r *Rows) MapScan(dest map[string]interface{}) error { 594 | return MapScan(r, dest) 595 | } 596 | 597 | // StructScan is like sql.Rows.Scan, but scans a single Row into a single Struct. 598 | // Use this and iterate over Rows manually when the memory load of Select() might be 599 | // prohibitive. *Rows.StructScan caches the reflect work of matching up column 600 | // positions to fields to avoid that overhead per scan, which means it is not safe 601 | // to run StructScan on the same Rows instance with different struct types. 
602 | func (r *Rows) StructScan(dest interface{}) error { 603 | v := reflect.ValueOf(dest) 604 | 605 | if v.Kind() != reflect.Ptr { 606 | return errors.New("must pass a pointer, not a value, to StructScan destination") 607 | } 608 | 609 | v = v.Elem() 610 | 611 | if !r.started { 612 | columns, err := r.Columns() 613 | if err != nil { 614 | return err 615 | } 616 | m := r.Mapper 617 | 618 | r.fields = m.TraversalsByName(v.Type(), columns) 619 | // if we are not unsafe and are missing fields, return an error 620 | if f, err := missingFields(r.fields); err != nil && !r.unsafe { 621 | return fmt.Errorf("missing destination name %s in %T", columns[f], dest) 622 | } 623 | r.values = make([]interface{}, len(columns)) 624 | r.started = true 625 | } 626 | 627 | err := fieldsByTraversal(v, r.fields, r.values, true) 628 | if err != nil { 629 | return err 630 | } 631 | // scan into the struct field pointers and append to our results 632 | err = r.Scan(r.values...) 633 | if err != nil { 634 | return err 635 | } 636 | return r.Err() 637 | } 638 | 639 | // Connect to a database and verify with a ping. 640 | func Connect(driverName, dataSourceName string) (*DB, error) { 641 | db, err := Open(driverName, dataSourceName) 642 | if err != nil { 643 | return nil, err 644 | } 645 | err = db.Ping() 646 | if err != nil { 647 | db.Close() 648 | return nil, err 649 | } 650 | return db, nil 651 | } 652 | 653 | // MustConnect connects to a database and panics on error. 654 | func MustConnect(driverName, dataSourceName string) *DB { 655 | db, err := Connect(driverName, dataSourceName) 656 | if err != nil { 657 | panic(err) 658 | } 659 | return db 660 | } 661 | 662 | // Preparex prepares a statement. 
663 | func Preparex(p Preparer, query string) (*Stmt, error) { 664 | s, err := p.Prepare(query) 665 | if err != nil { 666 | return nil, err 667 | } 668 | return &Stmt{Stmt: s, unsafe: isUnsafe(p), Mapper: mapperFor(p)}, err 669 | } 670 | 671 | // Select executes a query using the provided Queryer, and StructScans each row 672 | // into dest, which must be a slice. If the slice elements are scannable, then 673 | // the result set must have only one column. Otherwise, StructScan is used. 674 | // The *sql.Rows are closed automatically. 675 | // Any placeholder parameters are replaced with supplied args. 676 | func Select(q Queryer, dest interface{}, query string, args ...interface{}) error { 677 | rows, err := q.Queryx(query, args...) 678 | if err != nil { 679 | return err 680 | } 681 | // if something happens here, we want to make sure the rows are Closed 682 | defer rows.Close() 683 | return scanAll(rows, dest, false) 684 | } 685 | 686 | // Get does a QueryRow using the provided Queryer, and scans the resulting row 687 | // to dest. If dest is scannable, the result must only have one column. Otherwise, 688 | // StructScan is used. Get will return sql.ErrNoRows like row.Scan would. 689 | // Any placeholder parameters are replaced with supplied args. 690 | // An error is returned if the result set is empty. 691 | func Get(q Queryer, dest interface{}, query string, args ...interface{}) error { 692 | r := q.QueryRowx(query, args...) 693 | return r.scanAny(dest, false) 694 | } 695 | 696 | // LoadFile exec's every statement in a file (as a single call to Exec). 697 | // LoadFile may return a nil *sql.Result if errors are encountered locating or 698 | // reading the file at path. LoadFile reads the entire file into memory, so it 699 | // is not suitable for loading large data dumps, but can be useful for initializing 700 | // schemas or loading indexes. 
701 | // 702 | // FIXME: this does not really work with multi-statement files for mattn/go-sqlite3 703 | // or the go-mysql-driver/mysql drivers; pq seems to be an exception here. Detecting 704 | // this by requiring something with DriverName() and then attempting to split the 705 | // queries will be difficult to get right, and its current driver-specific behavior 706 | // is deemed at least not complex in its incorrectness. 707 | func LoadFile(e Execer, path string) (*sql.Result, error) { 708 | realpath, err := filepath.Abs(path) 709 | if err != nil { 710 | return nil, err 711 | } 712 | contents, err := ioutil.ReadFile(realpath) 713 | if err != nil { 714 | return nil, err 715 | } 716 | res, err := e.Exec(string(contents)) 717 | return &res, err 718 | } 719 | 720 | // MustExec execs the query using e and panics if there was an error. 721 | // Any placeholder parameters are replaced with supplied args. 722 | func MustExec(e Execer, query string, args ...interface{}) sql.Result { 723 | res, err := e.Exec(query, args...) 724 | if err != nil { 725 | panic(err) 726 | } 727 | return res 728 | } 729 | 730 | // SliceScan using this Rows. 731 | func (r *Row) SliceScan() ([]interface{}, error) { 732 | return SliceScan(r) 733 | } 734 | 735 | // MapScan using this Rows. 
736 | func (r *Row) MapScan(dest map[string]interface{}) error { 737 | return MapScan(r, dest) 738 | } 739 | 740 | func (r *Row) scanAny(dest interface{}, structOnly bool) error { 741 | if r.err != nil { 742 | return r.err 743 | } 744 | if r.rows == nil { 745 | r.err = sql.ErrNoRows 746 | return r.err 747 | } 748 | defer r.rows.Close() 749 | 750 | v := reflect.ValueOf(dest) 751 | if v.Kind() != reflect.Ptr { 752 | return errors.New("must pass a pointer, not a value, to StructScan destination") 753 | } 754 | if v.IsNil() { 755 | return errors.New("nil pointer passed to StructScan destination") 756 | } 757 | 758 | base := reflectx.Deref(v.Type()) 759 | scannable := isScannable(base) 760 | 761 | if structOnly && scannable { 762 | return structOnlyError(base) 763 | } 764 | 765 | columns, err := r.Columns() 766 | if err != nil { 767 | return err 768 | } 769 | 770 | if scannable && len(columns) > 1 { 771 | return fmt.Errorf("scannable dest type %s with >1 columns (%d) in result", base.Kind(), len(columns)) 772 | } 773 | 774 | if scannable { 775 | return r.Scan(dest) 776 | } 777 | 778 | m := r.Mapper 779 | 780 | fields := m.TraversalsByName(v.Type(), columns) 781 | // if we are not unsafe and are missing fields, return an error 782 | if f, err := missingFields(fields); err != nil && !r.unsafe { 783 | return fmt.Errorf("missing destination name %s in %T", columns[f], dest) 784 | } 785 | values := make([]interface{}, len(columns)) 786 | 787 | err = fieldsByTraversal(v, fields, values, true) 788 | if err != nil { 789 | return err 790 | } 791 | // scan into the struct field pointers and append to our results 792 | return r.Scan(values...) 793 | } 794 | 795 | // StructScan a single Row into dest. 796 | func (r *Row) StructScan(dest interface{}) error { 797 | return r.scanAny(dest, true) 798 | } 799 | 800 | // SliceScan a row, returning a []interface{} with values similar to MapScan. 
801 | // This function is primarily intended for use where the number of columns 802 | // is not known. Because you can pass an []interface{} directly to Scan, 803 | // it's recommended that you do that as it will not have to allocate new 804 | // slices per row. 805 | func SliceScan(r ColScanner) ([]interface{}, error) { 806 | // ignore r.started, since we needn't use reflect for anything. 807 | columns, err := r.Columns() 808 | if err != nil { 809 | return []interface{}{}, err 810 | } 811 | 812 | values := make([]interface{}, len(columns)) 813 | for i := range values { 814 | values[i] = new(interface{}) 815 | } 816 | 817 | err = r.Scan(values...) 818 | 819 | if err != nil { 820 | return values, err 821 | } 822 | 823 | for i := range columns { 824 | values[i] = *(values[i].(*interface{})) 825 | } 826 | 827 | return values, r.Err() 828 | } 829 | 830 | // MapScan scans a single Row into the dest map[string]interface{}. 831 | // Use this to get results for SQL that might not be under your control 832 | // (for instance, if you're building an interface for an SQL server that 833 | // executes SQL from input). Please do not use this as a primary interface! 834 | // This will modify the map sent to it in place, so reuse the same map with 835 | // care. Columns which occur more than once in the result will overwrite 836 | // each other! 837 | func MapScan(r ColScanner, dest map[string]interface{}) error { 838 | // ignore r.started, since we needn't use reflect for anything. 839 | columns, err := r.Columns() 840 | if err != nil { 841 | return err 842 | } 843 | 844 | values := make([]interface{}, len(columns)) 845 | for i := range values { 846 | values[i] = new(interface{}) 847 | } 848 | 849 | err = r.Scan(values...) 
850 | if err != nil { 851 | return err 852 | } 853 | 854 | for i, column := range columns { 855 | dest[column] = *(values[i].(*interface{})) 856 | } 857 | 858 | return r.Err() 859 | } 860 | 861 | type rowsi interface { 862 | Close() error 863 | Columns() ([]string, error) 864 | Err() error 865 | Next() bool 866 | Scan(...interface{}) error 867 | } 868 | 869 | // structOnlyError returns an error appropriate for type when a non-scannable 870 | // struct is expected but something else is given 871 | func structOnlyError(t reflect.Type) error { 872 | isStruct := t.Kind() == reflect.Struct 873 | isScanner := reflect.PtrTo(t).Implements(_scannerInterface) 874 | if !isStruct { 875 | return fmt.Errorf("expected %s but got %s", reflect.Struct, t.Kind()) 876 | } 877 | if isScanner { 878 | return fmt.Errorf("structscan expects a struct dest but the provided struct type %s implements scanner", t.Name()) 879 | } 880 | return fmt.Errorf("expected a struct, but struct %s has no exported fields", t.Name()) 881 | } 882 | 883 | // scanAll scans all rows into a destination, which must be a slice of any 884 | // type. It resets the slice length to zero before appending each element to 885 | // the slice. If the destination slice type is a Struct, then StructScan will 886 | // be used on each row. If the destination is some other kind of base type, 887 | // then each row must only have one column which can scan into that type. This 888 | // allows you to do something like: 889 | // 890 | // rows, _ := db.Query("select id from people;") 891 | // var ids []int 892 | // scanAll(rows, &ids, false) 893 | // 894 | // and ids will be a list of the id results. I realize that this is a desirable 895 | // interface to expose to users, but for now it will only be exposed via changes 896 | // to `Get` and `Select`. The reason that this has been implemented like this is 897 | // this is the only way to not duplicate reflect work in the new API while 898 | // maintaining backwards compatibility. 
899 | func scanAll(rows rowsi, dest interface{}, structOnly bool) error { 900 | var v, vp reflect.Value 901 | 902 | value := reflect.ValueOf(dest) 903 | 904 | // json.Unmarshal returns errors for these 905 | if value.Kind() != reflect.Ptr { 906 | return errors.New("must pass a pointer, not a value, to StructScan destination") 907 | } 908 | if value.IsNil() { 909 | return errors.New("nil pointer passed to StructScan destination") 910 | } 911 | direct := reflect.Indirect(value) 912 | 913 | slice, err := baseType(value.Type(), reflect.Slice) 914 | if err != nil { 915 | return err 916 | } 917 | direct.SetLen(0) 918 | 919 | isPtr := slice.Elem().Kind() == reflect.Ptr 920 | base := reflectx.Deref(slice.Elem()) 921 | scannable := isScannable(base) 922 | 923 | if structOnly && scannable { 924 | return structOnlyError(base) 925 | } 926 | 927 | columns, err := rows.Columns() 928 | if err != nil { 929 | return err 930 | } 931 | 932 | // if it's a base type make sure it only has 1 column; if not return an error 933 | if scannable && len(columns) > 1 { 934 | return fmt.Errorf("non-struct dest type %s with >1 columns (%d)", base.Kind(), len(columns)) 935 | } 936 | 937 | if !scannable { 938 | var values []interface{} 939 | var m *reflectx.Mapper 940 | 941 | switch rows := rows.(type) { 942 | case *Rows: 943 | m = rows.Mapper 944 | default: 945 | m = mapper() 946 | } 947 | 948 | fields := m.TraversalsByName(base, columns) 949 | // if we are not unsafe and are missing fields, return an error 950 | if f, err := missingFields(fields); err != nil && !isUnsafe(rows) { 951 | return fmt.Errorf("missing destination name %s in %T", columns[f], dest) 952 | } 953 | values = make([]interface{}, len(columns)) 954 | 955 | for rows.Next() { 956 | // create a new struct type (which returns PtrTo) and indirect it 957 | vp = reflect.New(base) 958 | v = reflect.Indirect(vp) 959 | 960 | err = fieldsByTraversal(v, fields, values, true) 961 | if err != nil { 962 | return err 963 | } 964 | 965 | // scan 
into the struct field pointers and append to our results 966 | err = rows.Scan(values...) 967 | if err != nil { 968 | return err 969 | } 970 | 971 | if isPtr { 972 | direct.Set(reflect.Append(direct, vp)) 973 | } else { 974 | direct.Set(reflect.Append(direct, v)) 975 | } 976 | } 977 | } else { 978 | for rows.Next() { 979 | vp = reflect.New(base) 980 | err = rows.Scan(vp.Interface()) 981 | if err != nil { 982 | return err 983 | } 984 | // append 985 | if isPtr { 986 | direct.Set(reflect.Append(direct, vp)) 987 | } else { 988 | direct.Set(reflect.Append(direct, reflect.Indirect(vp))) 989 | } 990 | } 991 | } 992 | 993 | return rows.Err() 994 | } 995 | 996 | // FIXME: StructScan was the very first bit of API in sqlx, and now unfortunately 997 | // it doesn't really feel like it's named properly. There is an incongruency 998 | // between this and the way that StructScan (which might better be ScanStruct 999 | // anyway) works on a rows object. 1000 | 1001 | // StructScan all rows from an sql.Rows or an sqlx.Rows into the dest slice. 1002 | // StructScan will scan in the entire rows result, so if you do not want to 1003 | // allocate structs for the entire result, use Queryx and see sqlx.Rows.StructScan. 1004 | // If rows is sqlx.Rows, it will use its mapper, otherwise it will use the default. 1005 | func StructScan(rows rowsi, dest interface{}) error { 1006 | return scanAll(rows, dest, true) 1007 | 1008 | } 1009 | 1010 | // reflect helpers 1011 | 1012 | func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) { 1013 | t = reflectx.Deref(t) 1014 | if t.Kind() != expected { 1015 | return nil, fmt.Errorf("expected %s but got %s", expected, t.Kind()) 1016 | } 1017 | return t, nil 1018 | } 1019 | 1020 | // fieldsByName fills a values interface with fields from the passed value based 1021 | // on the traversals in int. If ptrs is true, return addresses instead of values. 
1022 | // We write this instead of using FieldsByName to save allocations and map lookups 1023 | // when iterating over many rows. Empty traversals will get an interface pointer. 1024 | // Because of the necessity of requesting ptrs or values, it's considered a bit too 1025 | // specialized for inclusion in reflectx itself. 1026 | func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error { 1027 | v = reflect.Indirect(v) 1028 | if v.Kind() != reflect.Struct { 1029 | return errors.New("argument not a struct") 1030 | } 1031 | 1032 | for i, traversal := range traversals { 1033 | if len(traversal) == 0 { 1034 | values[i] = new(interface{}) 1035 | continue 1036 | } 1037 | f := reflectx.FieldByIndexes(v, traversal) 1038 | if ptrs { 1039 | values[i] = f.Addr().Interface() 1040 | } else { 1041 | values[i] = f.Interface() 1042 | } 1043 | } 1044 | return nil 1045 | } 1046 | 1047 | func missingFields(transversals [][]int) (field int, err error) { 1048 | for i, t := range transversals { 1049 | if len(t) == 0 { 1050 | return i, errors.New("missing field") 1051 | } 1052 | } 1053 | return 0, nil 1054 | } 1055 | -------------------------------------------------------------------------------- /sqlx_context.go: -------------------------------------------------------------------------------- 1 | //go:build go1.8 2 | // +build go1.8 3 | 4 | package sqlx 5 | 6 | import ( 7 | "context" 8 | "database/sql" 9 | "fmt" 10 | "io/ioutil" 11 | "path/filepath" 12 | "reflect" 13 | ) 14 | 15 | // ConnectContext to a database and verify with a ping. 
16 | func ConnectContext(ctx context.Context, driverName, dataSourceName string) (*DB, error) { 17 | db, err := Open(driverName, dataSourceName) 18 | if err != nil { 19 | return db, err 20 | } 21 | err = db.PingContext(ctx) 22 | return db, err 23 | } 24 | 25 | // QueryerContext is an interface used by GetContext and SelectContext 26 | type QueryerContext interface { 27 | QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) 28 | QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) 29 | QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row 30 | } 31 | 32 | // PreparerContext is an interface used by PreparexContext. 33 | type PreparerContext interface { 34 | PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) 35 | } 36 | 37 | // ExecerContext is an interface used by MustExecContext and LoadFileContext 38 | type ExecerContext interface { 39 | ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) 40 | } 41 | 42 | // ExtContext is a union interface which can bind, query, and exec, with Context 43 | // used by NamedQueryContext and NamedExecContext. 44 | type ExtContext interface { 45 | binder 46 | QueryerContext 47 | ExecerContext 48 | } 49 | 50 | // SelectContext executes a query using the provided Queryer, and StructScans 51 | // each row into dest, which must be a slice. If the slice elements are 52 | // scannable, then the result set must have only one column. Otherwise, 53 | // StructScan is used. The *sql.Rows are closed automatically. 54 | // Any placeholder parameters are replaced with supplied args. 55 | func SelectContext(ctx context.Context, q QueryerContext, dest interface{}, query string, args ...interface{}) error { 56 | rows, err := q.QueryxContext(ctx, query, args...) 
57 | if err != nil { 58 | return err 59 | } 60 | // if something happens here, we want to make sure the rows are Closed 61 | defer rows.Close() 62 | return scanAll(rows, dest, false) 63 | } 64 | 65 | // PreparexContext prepares a statement. 66 | // 67 | // The provided context is used for the preparation of the statement, not for 68 | // the execution of the statement. 69 | func PreparexContext(ctx context.Context, p PreparerContext, query string) (*Stmt, error) { 70 | s, err := p.PrepareContext(ctx, query) 71 | if err != nil { 72 | return nil, err 73 | } 74 | return &Stmt{Stmt: s, unsafe: isUnsafe(p), Mapper: mapperFor(p)}, err 75 | } 76 | 77 | // GetContext does a QueryRow using the provided Queryer, and scans the 78 | // resulting row to dest. If dest is scannable, the result must only have one 79 | // column. Otherwise, StructScan is used. Get will return sql.ErrNoRows like 80 | // row.Scan would. Any placeholder parameters are replaced with supplied args. 81 | // An error is returned if the result set is empty. 82 | func GetContext(ctx context.Context, q QueryerContext, dest interface{}, query string, args ...interface{}) error { 83 | r := q.QueryRowxContext(ctx, query, args...) 84 | return r.scanAny(dest, false) 85 | } 86 | 87 | // LoadFileContext exec's every statement in a file (as a single call to Exec). 88 | // LoadFileContext may return a nil *sql.Result if errors are encountered 89 | // locating or reading the file at path. LoadFile reads the entire file into 90 | // memory, so it is not suitable for loading large data dumps, but can be useful 91 | // for initializing schemas or loading indexes. 92 | // 93 | // FIXME: this does not really work with multi-statement files for mattn/go-sqlite3 94 | // or the go-mysql-driver/mysql drivers; pq seems to be an exception here. 
Detecting 95 | // this by requiring something with DriverName() and then attempting to split the 96 | // queries will be difficult to get right, and its current driver-specific behavior 97 | // is deemed at least not complex in its incorrectness. 98 | func LoadFileContext(ctx context.Context, e ExecerContext, path string) (*sql.Result, error) { 99 | realpath, err := filepath.Abs(path) 100 | if err != nil { 101 | return nil, err 102 | } 103 | contents, err := ioutil.ReadFile(realpath) 104 | if err != nil { 105 | return nil, err 106 | } 107 | res, err := e.ExecContext(ctx, string(contents)) 108 | return &res, err 109 | } 110 | 111 | // MustExecContext execs the query using e and panics if there was an error. 112 | // Any placeholder parameters are replaced with supplied args. 113 | func MustExecContext(ctx context.Context, e ExecerContext, query string, args ...interface{}) sql.Result { 114 | res, err := e.ExecContext(ctx, query, args...) 115 | if err != nil { 116 | panic(err) 117 | } 118 | return res 119 | } 120 | 121 | // PrepareNamedContext returns an sqlx.NamedStmt 122 | func (db *DB) PrepareNamedContext(ctx context.Context, query string) (*NamedStmt, error) { 123 | return prepareNamedContext(ctx, db, query) 124 | } 125 | 126 | // NamedQueryContext using this DB. 127 | // Any named placeholder parameters are replaced with fields from arg. 128 | func (db *DB) NamedQueryContext(ctx context.Context, query string, arg interface{}) (*Rows, error) { 129 | return NamedQueryContext(ctx, db, query, arg) 130 | } 131 | 132 | // NamedExecContext using this DB. 133 | // Any named placeholder parameters are replaced with fields from arg. 134 | func (db *DB) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) { 135 | return NamedExecContext(ctx, db, query, arg) 136 | } 137 | 138 | // SelectContext using this DB. 139 | // Any placeholder parameters are replaced with supplied args. 
140 | func (db *DB) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { 141 | return SelectContext(ctx, db, dest, query, args...) 142 | } 143 | 144 | // GetContext using this DB. 145 | // Any placeholder parameters are replaced with supplied args. 146 | // An error is returned if the result set is empty. 147 | func (db *DB) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { 148 | return GetContext(ctx, db, dest, query, args...) 149 | } 150 | 151 | // PreparexContext returns an sqlx.Stmt instead of a sql.Stmt. 152 | // 153 | // The provided context is used for the preparation of the statement, not for 154 | // the execution of the statement. 155 | func (db *DB) PreparexContext(ctx context.Context, query string) (*Stmt, error) { 156 | return PreparexContext(ctx, db, query) 157 | } 158 | 159 | // QueryxContext queries the database and returns an *sqlx.Rows. 160 | // Any placeholder parameters are replaced with supplied args. 161 | func (db *DB) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { 162 | r, err := db.DB.QueryContext(ctx, query, args...) 163 | if err != nil { 164 | return nil, err 165 | } 166 | return &Rows{Rows: r, unsafe: db.unsafe, Mapper: db.Mapper}, err 167 | } 168 | 169 | // QueryRowxContext queries the database and returns an *sqlx.Row. 170 | // Any placeholder parameters are replaced with supplied args. 171 | func (db *DB) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { 172 | rows, err := db.DB.QueryContext(ctx, query, args...) 173 | return &Row{rows: rows, err: err, unsafe: db.unsafe, Mapper: db.Mapper} 174 | } 175 | 176 | // MustBeginTx starts a transaction, and panics on error. Returns an *sqlx.Tx instead 177 | // of an *sql.Tx. 178 | // 179 | // The provided context is used until the transaction is committed or rolled 180 | // back. 
If the context is canceled, the sql package will roll back the 181 | // transaction. Tx.Commit will return an error if the context provided to 182 | // MustBeginContext is canceled. 183 | func (db *DB) MustBeginTx(ctx context.Context, opts *sql.TxOptions) *Tx { 184 | tx, err := db.BeginTxx(ctx, opts) 185 | if err != nil { 186 | panic(err) 187 | } 188 | return tx 189 | } 190 | 191 | // MustExecContext (panic) runs MustExec using this database. 192 | // Any placeholder parameters are replaced with supplied args. 193 | func (db *DB) MustExecContext(ctx context.Context, query string, args ...interface{}) sql.Result { 194 | return MustExecContext(ctx, db, query, args...) 195 | } 196 | 197 | // BeginTxx begins a transaction and returns an *sqlx.Tx instead of an 198 | // *sql.Tx. 199 | // 200 | // The provided context is used until the transaction is committed or rolled 201 | // back. If the context is canceled, the sql package will roll back the 202 | // transaction. Tx.Commit will return an error if the context provided to 203 | // BeginxContext is canceled. 204 | func (db *DB) BeginTxx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { 205 | tx, err := db.DB.BeginTx(ctx, opts) 206 | if err != nil { 207 | return nil, err 208 | } 209 | return &Tx{Tx: tx, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, err 210 | } 211 | 212 | // Connx returns an *sqlx.Conn instead of an *sql.Conn. 213 | func (db *DB) Connx(ctx context.Context) (*Conn, error) { 214 | conn, err := db.DB.Conn(ctx) 215 | if err != nil { 216 | return nil, err 217 | } 218 | 219 | return &Conn{Conn: conn, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, nil 220 | } 221 | 222 | // BeginTxx begins a transaction and returns an *sqlx.Tx instead of an 223 | // *sql.Tx. 224 | // 225 | // The provided context is used until the transaction is committed or rolled 226 | // back. If the context is canceled, the sql package will roll back the 227 | // transaction. 
Tx.Commit will return an error if the context provided to 228 | // BeginxContext is canceled. 229 | func (c *Conn) BeginTxx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { 230 | tx, err := c.Conn.BeginTx(ctx, opts) 231 | if err != nil { 232 | return nil, err 233 | } 234 | return &Tx{Tx: tx, driverName: c.driverName, unsafe: c.unsafe, Mapper: c.Mapper}, err 235 | } 236 | 237 | // SelectContext using this Conn. 238 | // Any placeholder parameters are replaced with supplied args. 239 | func (c *Conn) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { 240 | return SelectContext(ctx, c, dest, query, args...) 241 | } 242 | 243 | // GetContext using this Conn. 244 | // Any placeholder parameters are replaced with supplied args. 245 | // An error is returned if the result set is empty. 246 | func (c *Conn) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { 247 | return GetContext(ctx, c, dest, query, args...) 248 | } 249 | 250 | // PreparexContext returns an sqlx.Stmt instead of a sql.Stmt. 251 | // 252 | // The provided context is used for the preparation of the statement, not for 253 | // the execution of the statement. 254 | func (c *Conn) PreparexContext(ctx context.Context, query string) (*Stmt, error) { 255 | return PreparexContext(ctx, c, query) 256 | } 257 | 258 | // QueryxContext queries the database and returns an *sqlx.Rows. 259 | // Any placeholder parameters are replaced with supplied args. 260 | func (c *Conn) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { 261 | r, err := c.Conn.QueryContext(ctx, query, args...) 262 | if err != nil { 263 | return nil, err 264 | } 265 | return &Rows{Rows: r, unsafe: c.unsafe, Mapper: c.Mapper}, err 266 | } 267 | 268 | // QueryRowxContext queries the database and returns an *sqlx.Row. 269 | // Any placeholder parameters are replaced with supplied args. 
270 | func (c *Conn) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { 271 | rows, err := c.Conn.QueryContext(ctx, query, args...) 272 | return &Row{rows: rows, err: err, unsafe: c.unsafe, Mapper: c.Mapper} 273 | } 274 | 275 | // Rebind a query within a Conn's bindvar type. 276 | func (c *Conn) Rebind(query string) string { 277 | return Rebind(BindType(c.driverName), query) 278 | } 279 | 280 | // StmtxContext returns a version of the prepared statement which runs within a 281 | // transaction. Provided stmt can be either *sql.Stmt or *sqlx.Stmt. 282 | func (tx *Tx) StmtxContext(ctx context.Context, stmt interface{}) *Stmt { 283 | var s *sql.Stmt 284 | switch v := stmt.(type) { 285 | case Stmt: 286 | s = v.Stmt 287 | case *Stmt: 288 | s = v.Stmt 289 | case *sql.Stmt: 290 | s = v 291 | default: 292 | panic(fmt.Sprintf("non-statement type %v passed to Stmtx", reflect.ValueOf(stmt).Type())) 293 | } 294 | return &Stmt{Stmt: tx.StmtContext(ctx, s), Mapper: tx.Mapper} 295 | } 296 | 297 | // NamedStmtContext returns a version of the prepared statement which runs 298 | // within a transaction. 299 | func (tx *Tx) NamedStmtContext(ctx context.Context, stmt *NamedStmt) *NamedStmt { 300 | return &NamedStmt{ 301 | QueryString: stmt.QueryString, 302 | Params: stmt.Params, 303 | Stmt: tx.StmtxContext(ctx, stmt.Stmt), 304 | } 305 | } 306 | 307 | // PreparexContext returns an sqlx.Stmt instead of a sql.Stmt. 308 | // 309 | // The provided context is used for the preparation of the statement, not for 310 | // the execution of the statement. 
311 | func (tx *Tx) PreparexContext(ctx context.Context, query string) (*Stmt, error) { 312 | return PreparexContext(ctx, tx, query) 313 | } 314 | 315 | // PrepareNamedContext returns an sqlx.NamedStmt 316 | func (tx *Tx) PrepareNamedContext(ctx context.Context, query string) (*NamedStmt, error) { 317 | return prepareNamedContext(ctx, tx, query) 318 | } 319 | 320 | // MustExecContext runs MustExecContext within a transaction. 321 | // Any placeholder parameters are replaced with supplied args. 322 | func (tx *Tx) MustExecContext(ctx context.Context, query string, args ...interface{}) sql.Result { 323 | return MustExecContext(ctx, tx, query, args...) 324 | } 325 | 326 | // QueryxContext within a transaction and context. 327 | // Any placeholder parameters are replaced with supplied args. 328 | func (tx *Tx) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { 329 | r, err := tx.Tx.QueryContext(ctx, query, args...) 330 | if err != nil { 331 | return nil, err 332 | } 333 | return &Rows{Rows: r, unsafe: tx.unsafe, Mapper: tx.Mapper}, err 334 | } 335 | 336 | // SelectContext within a transaction and context. 337 | // Any placeholder parameters are replaced with supplied args. 338 | func (tx *Tx) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { 339 | return SelectContext(ctx, tx, dest, query, args...) 340 | } 341 | 342 | // GetContext within a transaction and context. 343 | // Any placeholder parameters are replaced with supplied args. 344 | // An error is returned if the result set is empty. 345 | func (tx *Tx) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { 346 | return GetContext(ctx, tx, dest, query, args...) 347 | } 348 | 349 | // QueryRowxContext within a transaction and context. 350 | // Any placeholder parameters are replaced with supplied args. 
351 | func (tx *Tx) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { 352 | rows, err := tx.Tx.QueryContext(ctx, query, args...) 353 | return &Row{rows: rows, err: err, unsafe: tx.unsafe, Mapper: tx.Mapper} 354 | } 355 | 356 | // NamedExecContext using this Tx. 357 | // Any named placeholder parameters are replaced with fields from arg. 358 | func (tx *Tx) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) { 359 | return NamedExecContext(ctx, tx, query, arg) 360 | } 361 | 362 | // SelectContext using the prepared statement. 363 | // Any placeholder parameters are replaced with supplied args. 364 | func (s *Stmt) SelectContext(ctx context.Context, dest interface{}, args ...interface{}) error { 365 | return SelectContext(ctx, &qStmt{s}, dest, "", args...) 366 | } 367 | 368 | // GetContext using the prepared statement. 369 | // Any placeholder parameters are replaced with supplied args. 370 | // An error is returned if the result set is empty. 371 | func (s *Stmt) GetContext(ctx context.Context, dest interface{}, args ...interface{}) error { 372 | return GetContext(ctx, &qStmt{s}, dest, "", args...) 373 | } 374 | 375 | // MustExecContext (panic) using this statement. Note that the query portion of 376 | // the error output will be blank, as Stmt does not expose its query. 377 | // Any placeholder parameters are replaced with supplied args. 378 | func (s *Stmt) MustExecContext(ctx context.Context, args ...interface{}) sql.Result { 379 | return MustExecContext(ctx, &qStmt{s}, "", args...) 380 | } 381 | 382 | // QueryRowxContext using this statement. 383 | // Any placeholder parameters are replaced with supplied args. 384 | func (s *Stmt) QueryRowxContext(ctx context.Context, args ...interface{}) *Row { 385 | qs := &qStmt{s} 386 | return qs.QueryRowxContext(ctx, "", args...) 387 | } 388 | 389 | // QueryxContext using this statement. 390 | // Any placeholder parameters are replaced with supplied args. 
391 | func (s *Stmt) QueryxContext(ctx context.Context, args ...interface{}) (*Rows, error) { 392 | qs := &qStmt{s} 393 | return qs.QueryxContext(ctx, "", args...) 394 | } 395 | 396 | func (q *qStmt) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { 397 | return q.Stmt.QueryContext(ctx, args...) 398 | } 399 | 400 | func (q *qStmt) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { 401 | r, err := q.Stmt.QueryContext(ctx, args...) 402 | if err != nil { 403 | return nil, err 404 | } 405 | return &Rows{Rows: r, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}, err 406 | } 407 | 408 | func (q *qStmt) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { 409 | rows, err := q.Stmt.QueryContext(ctx, args...) 410 | return &Row{rows: rows, err: err, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper} 411 | } 412 | 413 | func (q *qStmt) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { 414 | return q.Stmt.ExecContext(ctx, args...) 415 | } 416 | -------------------------------------------------------------------------------- /sqlx_context_test.go: -------------------------------------------------------------------------------- 1 | //go:build go1.8 2 | // +build go1.8 3 | 4 | // The following environment variables, if set, will be used: 5 | // 6 | // - SQLX_SQLITE_DSN 7 | // - SQLX_POSTGRES_DSN 8 | // - SQLX_MYSQL_DSN 9 | // 10 | // Set any of these variables to 'skip' to skip them. Note that for MySQL, 11 | // the string '?parseTime=True' will be appended to the DSN if it's not there 12 | // already. 
13 | package sqlx 14 | 15 | import ( 16 | "context" 17 | "database/sql" 18 | "encoding/json" 19 | "fmt" 20 | "log" 21 | "strings" 22 | "testing" 23 | "time" 24 | 25 | _ "github.com/go-sql-driver/mysql" 26 | _ "github.com/lib/pq" 27 | _ "github.com/mattn/go-sqlite3" 28 | 29 | "github.com/jmoiron/sqlx/reflectx" 30 | ) 31 | 32 | func MultiExecContext(ctx context.Context, e ExecerContext, query string) { 33 | stmts := strings.Split(query, ";\n") 34 | if len(strings.Trim(stmts[len(stmts)-1], " \n\t\r")) == 0 { 35 | stmts = stmts[:len(stmts)-1] 36 | } 37 | for _, s := range stmts { 38 | _, err := e.ExecContext(ctx, s) 39 | if err != nil { 40 | fmt.Println(err, s) 41 | } 42 | } 43 | } 44 | 45 | func RunWithSchemaContext(ctx context.Context, schema Schema, t *testing.T, test func(ctx context.Context, db *DB, t *testing.T)) { 46 | runner := func(ctx context.Context, db *DB, t *testing.T, create, drop, now string) { 47 | defer func() { 48 | MultiExecContext(ctx, db, drop) 49 | }() 50 | 51 | MultiExecContext(ctx, db, create) 52 | test(ctx, db, t) 53 | } 54 | 55 | if TestPostgres { 56 | create, drop, now := schema.Postgres() 57 | runner(ctx, pgdb, t, create, drop, now) 58 | } 59 | if TestSqlite { 60 | create, drop, now := schema.Sqlite3() 61 | runner(ctx, sldb, t, create, drop, now) 62 | } 63 | if TestMysql { 64 | create, drop, now := schema.MySQL() 65 | runner(ctx, mysqldb, t, create, drop, now) 66 | } 67 | } 68 | 69 | func loadDefaultFixtureContext(ctx context.Context, db *DB, t *testing.T) { 70 | tx := db.MustBeginTx(ctx, nil) 71 | tx.MustExecContext(ctx, tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "Jason", "Moiron", "jmoiron@jmoiron.net") 72 | tx.MustExecContext(ctx, tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "John", "Doe", "johndoeDNE@gmail.net") 73 | tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, city, telcode) VALUES (?, ?, ?)"), "United States", "New York", "1") 74 | 
tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Hong Kong", "852") 75 | tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Singapore", "65") 76 | if db.DriverName() == "mysql" { 77 | tx.MustExecContext(ctx, tx.Rebind("INSERT INTO capplace (`COUNTRY`, `TELCODE`) VALUES (?, ?)"), "Sarf Efrica", "27") 78 | } else { 79 | tx.MustExecContext(ctx, tx.Rebind("INSERT INTO capplace (\"COUNTRY\", \"TELCODE\") VALUES (?, ?)"), "Sarf Efrica", "27") 80 | } 81 | tx.MustExecContext(ctx, tx.Rebind("INSERT INTO employees (name, id) VALUES (?, ?)"), "Peter", "4444") 82 | tx.MustExecContext(ctx, tx.Rebind("INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)"), "Joe", "1", "4444") 83 | tx.MustExecContext(ctx, tx.Rebind("INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)"), "Martin", "2", "4444") 84 | tx.Commit() 85 | } 86 | 87 | // Test a new backwards compatible feature, that missing scan destinations 88 | // will silently scan into sql.RawText rather than failing/panicing 89 | func TestMissingNamesContextContext(t *testing.T) { 90 | RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { 91 | loadDefaultFixtureContext(ctx, db, t) 92 | type PersonPlus struct { 93 | FirstName string `db:"first_name"` 94 | LastName string `db:"last_name"` 95 | Email string 96 | // AddedAt time.Time `db:"added_at"` 97 | } 98 | 99 | // test Select first 100 | pps := []PersonPlus{} 101 | // pps lacks added_at destination 102 | err := db.SelectContext(ctx, &pps, "SELECT * FROM person") 103 | if err == nil { 104 | t.Error("Expected missing name from Select to fail, but it did not.") 105 | } 106 | 107 | // test Get 108 | pp := PersonPlus{} 109 | err = db.GetContext(ctx, &pp, "SELECT * FROM person LIMIT 1") 110 | if err == nil { 111 | t.Error("Expected missing name Get to fail, but it did not.") 112 | } 113 | 114 | // test naked StructScan 115 | pps = 
[]PersonPlus{} 116 | rows, err := db.QueryContext(ctx, "SELECT * FROM person LIMIT 1") 117 | if err != nil { 118 | t.Fatal(err) 119 | } 120 | rows.Next() 121 | err = StructScan(rows, &pps) 122 | if err == nil { 123 | t.Error("Expected missing name in StructScan to fail, but it did not.") 124 | } 125 | rows.Close() 126 | 127 | // now try various things with unsafe set. 128 | db = db.Unsafe() 129 | pps = []PersonPlus{} 130 | err = db.SelectContext(ctx, &pps, "SELECT * FROM person") 131 | if err != nil { 132 | t.Error(err) 133 | } 134 | 135 | // test Get 136 | pp = PersonPlus{} 137 | err = db.GetContext(ctx, &pp, "SELECT * FROM person LIMIT 1") 138 | if err != nil { 139 | t.Error(err) 140 | } 141 | 142 | // test naked StructScan 143 | pps = []PersonPlus{} 144 | rowsx, err := db.QueryxContext(ctx, "SELECT * FROM person LIMIT 1") 145 | if err != nil { 146 | t.Fatal(err) 147 | } 148 | rowsx.Next() 149 | err = StructScan(rowsx, &pps) 150 | if err != nil { 151 | t.Error(err) 152 | } 153 | rowsx.Close() 154 | 155 | // test Named stmt 156 | if !isUnsafe(db) { 157 | t.Error("Expected db to be unsafe, but it isn't") 158 | } 159 | nstmt, err := db.PrepareNamedContext(ctx, `SELECT * FROM person WHERE first_name != :name`) 160 | if err != nil { 161 | t.Fatal(err) 162 | } 163 | // its internal stmt should be marked unsafe 164 | if !nstmt.Stmt.unsafe { 165 | t.Error("expected NamedStmt to be unsafe but its underlying stmt did not inherit safety") 166 | } 167 | pps = []PersonPlus{} 168 | err = nstmt.SelectContext(ctx, &pps, map[string]interface{}{"name": "Jason"}) 169 | if err != nil { 170 | t.Fatal(err) 171 | } 172 | if len(pps) != 1 { 173 | t.Errorf("Expected 1 person back, got %d", len(pps)) 174 | } 175 | 176 | // test it with a safe db 177 | db.unsafe = false 178 | if isUnsafe(db) { 179 | t.Error("expected db to be safe but it isn't") 180 | } 181 | nstmt, err = db.PrepareNamedContext(ctx, `SELECT * FROM person WHERE first_name != :name`) 182 | if err != nil { 183 | t.Fatal(err) 
184 | } 185 | // it should be safe 186 | if isUnsafe(nstmt) { 187 | t.Error("NamedStmt did not inherit safety") 188 | } 189 | nstmt.Unsafe() 190 | if !isUnsafe(nstmt) { 191 | t.Error("expected newly unsafed NamedStmt to be unsafe") 192 | } 193 | pps = []PersonPlus{} 194 | err = nstmt.SelectContext(ctx, &pps, map[string]interface{}{"name": "Jason"}) 195 | if err != nil { 196 | t.Fatal(err) 197 | } 198 | if len(pps) != 1 { 199 | t.Errorf("Expected 1 person back, got %d", len(pps)) 200 | } 201 | 202 | }) 203 | } 204 | 205 | func TestEmbeddedStructsContextContext(t *testing.T) { 206 | type Loop1 struct{ Person } 207 | type Loop2 struct{ Loop1 } 208 | type Loop3 struct{ Loop2 } 209 | 210 | RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { 211 | loadDefaultFixtureContext(ctx, db, t) 212 | peopleAndPlaces := []PersonPlace{} 213 | err := db.SelectContext( 214 | ctx, 215 | &peopleAndPlaces, 216 | `SELECT person.*, place.* FROM 217 | person natural join place`) 218 | if err != nil { 219 | t.Fatal(err) 220 | } 221 | for _, pp := range peopleAndPlaces { 222 | if len(pp.Person.FirstName) == 0 { 223 | t.Errorf("Expected non zero lengthed first name.") 224 | } 225 | if len(pp.Place.Country) == 0 { 226 | t.Errorf("Expected non zero lengthed country.") 227 | } 228 | } 229 | 230 | // test embedded structs with StructScan 231 | rows, err := db.QueryxContext( 232 | ctx, 233 | `SELECT person.*, place.* FROM 234 | person natural join place`) 235 | if err != nil { 236 | t.Error(err) 237 | } 238 | 239 | perp := PersonPlace{} 240 | rows.Next() 241 | err = rows.StructScan(&perp) 242 | if err != nil { 243 | t.Error(err) 244 | } 245 | 246 | if len(perp.Person.FirstName) == 0 { 247 | t.Errorf("Expected non zero lengthed first name.") 248 | } 249 | if len(perp.Place.Country) == 0 { 250 | t.Errorf("Expected non zero lengthed country.") 251 | } 252 | 253 | rows.Close() 254 | 255 | // test the same for embedded pointer structs 256 | 
peopleAndPlacesPtrs := []PersonPlacePtr{} 257 | err = db.SelectContext( 258 | ctx, 259 | &peopleAndPlacesPtrs, 260 | `SELECT person.*, place.* FROM 261 | person natural join place`) 262 | if err != nil { 263 | t.Fatal(err) 264 | } 265 | for _, pp := range peopleAndPlacesPtrs { 266 | if len(pp.Person.FirstName) == 0 { 267 | t.Errorf("Expected non zero lengthed first name.") 268 | } 269 | if len(pp.Place.Country) == 0 { 270 | t.Errorf("Expected non zero lengthed country.") 271 | } 272 | } 273 | 274 | // test "deep nesting" 275 | l3s := []Loop3{} 276 | err = db.SelectContext(ctx, &l3s, `select * from person`) 277 | if err != nil { 278 | t.Fatal(err) 279 | } 280 | for _, l3 := range l3s { 281 | if len(l3.Loop2.Loop1.Person.FirstName) == 0 { 282 | t.Errorf("Expected non zero lengthed first name.") 283 | } 284 | } 285 | 286 | // test "embed conflicts" 287 | ec := []EmbedConflict{} 288 | err = db.SelectContext(ctx, &ec, `select * from person`) 289 | // I'm torn between erroring here or having some kind of working behavior 290 | // in order to allow for more flexibility in destination structs 291 | if err != nil { 292 | t.Errorf("Was not expecting an error on embed conflicts.") 293 | } 294 | }) 295 | } 296 | 297 | func TestJoinQueryContext(t *testing.T) { 298 | type Employee struct { 299 | Name string 300 | ID int64 301 | // BossID is an id into the employee table 302 | BossID sql.NullInt64 `db:"boss_id"` 303 | } 304 | type Boss Employee 305 | 306 | RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { 307 | loadDefaultFixtureContext(ctx, db, t) 308 | 309 | var employees []struct { 310 | Employee 311 | Boss `db:"boss"` 312 | } 313 | 314 | err := db.SelectContext(ctx, 315 | &employees, 316 | `SELECT employees.*, boss.id "boss.id", boss.name "boss.name" FROM employees 317 | JOIN employees AS boss ON employees.boss_id = boss.id`) 318 | if err != nil { 319 | t.Fatal(err) 320 | } 321 | 322 | for _, em := range employees { 
323 | if len(em.Employee.Name) == 0 { 324 | t.Errorf("Expected non zero lengthed name.") 325 | } 326 | if em.Employee.BossID.Int64 != em.Boss.ID { 327 | t.Errorf("Expected boss ids to match") 328 | } 329 | } 330 | }) 331 | } 332 | 333 | func TestJoinQueryNamedPointerStructsContext(t *testing.T) { 334 | type Employee struct { 335 | Name string 336 | ID int64 337 | // BossID is an id into the employee table 338 | BossID sql.NullInt64 `db:"boss_id"` 339 | } 340 | type Boss Employee 341 | 342 | RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { 343 | loadDefaultFixtureContext(ctx, db, t) 344 | 345 | var employees []struct { 346 | Emp1 *Employee `db:"emp1"` 347 | Emp2 *Employee `db:"emp2"` 348 | *Boss `db:"boss"` 349 | } 350 | 351 | err := db.SelectContext(ctx, 352 | &employees, 353 | `SELECT emp.name "emp1.name", emp.id "emp1.id", emp.boss_id "emp1.boss_id", 354 | emp.name "emp2.name", emp.id "emp2.id", emp.boss_id "emp2.boss_id", 355 | boss.id "boss.id", boss.name "boss.name" FROM employees AS emp 356 | JOIN employees AS boss ON emp.boss_id = boss.id 357 | `) 358 | if err != nil { 359 | t.Fatal(err) 360 | } 361 | 362 | for _, em := range employees { 363 | if len(em.Emp1.Name) == 0 || len(em.Emp2.Name) == 0 { 364 | t.Errorf("Expected non zero lengthed name.") 365 | } 366 | if em.Emp1.BossID.Int64 != em.Boss.ID || em.Emp2.BossID.Int64 != em.Boss.ID { 367 | t.Errorf("Expected boss ids to match") 368 | } 369 | } 370 | }) 371 | } 372 | 373 | func TestSelectSliceMapTimeContext(t *testing.T) { 374 | RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { 375 | loadDefaultFixtureContext(ctx, db, t) 376 | rows, err := db.QueryxContext(ctx, "SELECT * FROM person") 377 | if err != nil { 378 | t.Fatal(err) 379 | } 380 | for rows.Next() { 381 | _, err := rows.SliceScan() 382 | if err != nil { 383 | t.Error(err) 384 | } 385 | } 386 | 387 | rows, err = 
db.QueryxContext(ctx, "SELECT * FROM person") 388 | if err != nil { 389 | t.Fatal(err) 390 | } 391 | for rows.Next() { 392 | m := map[string]interface{}{} 393 | err := rows.MapScan(m) 394 | if err != nil { 395 | t.Error(err) 396 | } 397 | } 398 | 399 | }) 400 | } 401 | 402 | func TestNilReceiverContext(t *testing.T) { 403 | RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { 404 | loadDefaultFixtureContext(ctx, db, t) 405 | var p *Person 406 | err := db.GetContext(ctx, p, "SELECT * FROM person LIMIT 1") 407 | if err == nil { 408 | t.Error("Expected error when getting into nil struct ptr.") 409 | } 410 | var pp *[]Person 411 | err = db.SelectContext(ctx, pp, "SELECT * FROM person") 412 | if err == nil { 413 | t.Error("Expected an error when selecting into nil slice ptr.") 414 | } 415 | }) 416 | } 417 | 418 | func TestNamedQueryContext(t *testing.T) { 419 | var schema = Schema{ 420 | create: ` 421 | CREATE TABLE place ( 422 | id integer PRIMARY KEY, 423 | name text NULL 424 | ); 425 | CREATE TABLE person ( 426 | first_name text NULL, 427 | last_name text NULL, 428 | email text NULL 429 | ); 430 | CREATE TABLE placeperson ( 431 | first_name text NULL, 432 | last_name text NULL, 433 | email text NULL, 434 | place_id integer NULL 435 | ); 436 | CREATE TABLE jsperson ( 437 | "FIRST" text NULL, 438 | last_name text NULL, 439 | "EMAIL" text NULL 440 | );`, 441 | drop: ` 442 | drop table person; 443 | drop table jsperson; 444 | drop table place; 445 | drop table placeperson; 446 | `, 447 | } 448 | 449 | RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) { 450 | type Person struct { 451 | FirstName sql.NullString `db:"first_name"` 452 | LastName sql.NullString `db:"last_name"` 453 | Email sql.NullString 454 | } 455 | 456 | p := Person{ 457 | FirstName: sql.NullString{String: "ben", Valid: true}, 458 | LastName: sql.NullString{String: "doe", Valid: true}, 459 | Email: 
sql.NullString{String: "ben@doe.com", Valid: true}, 460 | } 461 | 462 | q1 := `INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)` 463 | _, err := db.NamedExecContext(ctx, q1, p) 464 | if err != nil { 465 | log.Fatal(err) 466 | } 467 | 468 | p2 := &Person{} 469 | rows, err := db.NamedQueryContext(ctx, "SELECT * FROM person WHERE first_name=:first_name", p) 470 | if err != nil { 471 | log.Fatal(err) 472 | } 473 | for rows.Next() { 474 | err = rows.StructScan(p2) 475 | if err != nil { 476 | t.Error(err) 477 | } 478 | if p2.FirstName.String != "ben" { 479 | t.Error("Expected first name of `ben`, got " + p2.FirstName.String) 480 | } 481 | if p2.LastName.String != "doe" { 482 | t.Error("Expected first name of `doe`, got " + p2.LastName.String) 483 | } 484 | } 485 | 486 | // these are tests for #73; they verify that named queries work if you've 487 | // changed the db mapper. This code checks both NamedQuery "ad-hoc" style 488 | // queries and NamedStmt queries, which use different code paths internally. 489 | old := (*db).Mapper 490 | 491 | type JSONPerson struct { 492 | FirstName sql.NullString `json:"FIRST"` 493 | LastName sql.NullString `json:"last_name"` 494 | Email sql.NullString 495 | } 496 | 497 | jp := JSONPerson{ 498 | FirstName: sql.NullString{String: "ben", Valid: true}, 499 | LastName: sql.NullString{String: "smith", Valid: true}, 500 | Email: sql.NullString{String: "ben@smith.com", Valid: true}, 501 | } 502 | 503 | db.Mapper = reflectx.NewMapperFunc("json", strings.ToUpper) 504 | 505 | // prepare queries for case sensitivity to test our ToUpper function. 
506 | // postgres and sqlite accept "", but mysql uses ``; since Go's multi-line 507 | // strings are `` we use "" by default and swap out for MySQL 508 | pdb := func(s string, db *DB) string { 509 | if db.DriverName() == "mysql" { 510 | return strings.Replace(s, `"`, "`", -1) 511 | } 512 | return s 513 | } 514 | 515 | q1 = `INSERT INTO jsperson ("FIRST", last_name, "EMAIL") VALUES (:FIRST, :last_name, :EMAIL)` 516 | _, err = db.NamedExecContext(ctx, pdb(q1, db), jp) 517 | if err != nil { 518 | t.Fatal(err, db.DriverName()) 519 | } 520 | 521 | // Checks that a person pulled out of the db matches the one we put in 522 | check := func(t *testing.T, rows *Rows) { 523 | jp = JSONPerson{} 524 | for rows.Next() { 525 | err = rows.StructScan(&jp) 526 | if err != nil { 527 | t.Error(err) 528 | } 529 | if jp.FirstName.String != "ben" { 530 | t.Errorf("Expected first name of `ben`, got `%s` (%s) ", jp.FirstName.String, db.DriverName()) 531 | } 532 | if jp.LastName.String != "smith" { 533 | t.Errorf("Expected LastName of `smith`, got `%s` (%s)", jp.LastName.String, db.DriverName()) 534 | } 535 | if jp.Email.String != "ben@smith.com" { 536 | t.Errorf("Expected first name of `doe`, got `%s` (%s)", jp.Email.String, db.DriverName()) 537 | } 538 | } 539 | } 540 | 541 | ns, err := db.PrepareNamed(pdb(` 542 | SELECT * FROM jsperson 543 | WHERE 544 | "FIRST"=:FIRST AND 545 | last_name=:last_name AND 546 | "EMAIL"=:EMAIL 547 | `, db)) 548 | 549 | if err != nil { 550 | t.Fatal(err) 551 | } 552 | rows, err = ns.QueryxContext(ctx, jp) 553 | if err != nil { 554 | t.Fatal(err) 555 | } 556 | 557 | check(t, rows) 558 | 559 | // Check exactly the same thing, but with db.NamedQuery, which does not go 560 | // through the PrepareNamed/NamedStmt path. 
561 | rows, err = db.NamedQueryContext(ctx, pdb(` 562 | SELECT * FROM jsperson 563 | WHERE 564 | "FIRST"=:FIRST AND 565 | last_name=:last_name AND 566 | "EMAIL"=:EMAIL 567 | `, db), jp) 568 | if err != nil { 569 | t.Fatal(err) 570 | } 571 | 572 | check(t, rows) 573 | 574 | db.Mapper = old 575 | 576 | // Test nested structs 577 | type Place struct { 578 | ID int `db:"id"` 579 | Name sql.NullString `db:"name"` 580 | } 581 | type PlacePerson struct { 582 | FirstName sql.NullString `db:"first_name"` 583 | LastName sql.NullString `db:"last_name"` 584 | Email sql.NullString 585 | Place Place `db:"place"` 586 | } 587 | 588 | pl := Place{ 589 | Name: sql.NullString{String: "myplace", Valid: true}, 590 | } 591 | 592 | pp := PlacePerson{ 593 | FirstName: sql.NullString{String: "ben", Valid: true}, 594 | LastName: sql.NullString{String: "doe", Valid: true}, 595 | Email: sql.NullString{String: "ben@doe.com", Valid: true}, 596 | } 597 | 598 | q2 := `INSERT INTO place (id, name) VALUES (1, :name)` 599 | _, err = db.NamedExecContext(ctx, q2, pl) 600 | if err != nil { 601 | log.Fatal(err) 602 | } 603 | 604 | id := 1 605 | pp.Place.ID = id 606 | 607 | q3 := `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)` 608 | _, err = db.NamedExecContext(ctx, q3, pp) 609 | if err != nil { 610 | log.Fatal(err) 611 | } 612 | 613 | pp2 := &PlacePerson{} 614 | rows, err = db.NamedQueryContext(ctx, ` 615 | SELECT 616 | first_name, 617 | last_name, 618 | email, 619 | place.id AS "place.id", 620 | place.name AS "place.name" 621 | FROM placeperson 622 | INNER JOIN place ON place.id = placeperson.place_id 623 | WHERE 624 | place.id=:place.id`, pp) 625 | if err != nil { 626 | log.Fatal(err) 627 | } 628 | for rows.Next() { 629 | err = rows.StructScan(pp2) 630 | if err != nil { 631 | t.Error(err) 632 | } 633 | if pp2.FirstName.String != "ben" { 634 | t.Error("Expected first name of `ben`, got " + pp2.FirstName.String) 635 | } 636 | if 
pp2.LastName.String != "doe" { 637 | t.Error("Expected first name of `doe`, got " + pp2.LastName.String) 638 | } 639 | if pp2.Place.Name.String != "myplace" { 640 | t.Error("Expected place name of `myplace`, got " + pp2.Place.Name.String) 641 | } 642 | if pp2.Place.ID != pp.Place.ID { 643 | t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp2.Place.ID) 644 | } 645 | } 646 | }) 647 | } 648 | 649 | func TestNilInsertsContext(t *testing.T) { 650 | var schema = Schema{ 651 | create: ` 652 | CREATE TABLE tt ( 653 | id integer, 654 | value text NULL DEFAULT NULL 655 | );`, 656 | drop: "drop table tt;", 657 | } 658 | 659 | RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) { 660 | type TT struct { 661 | ID int 662 | Value *string 663 | } 664 | var v, v2 TT 665 | r := db.Rebind 666 | 667 | db.MustExecContext(ctx, r(`INSERT INTO tt (id) VALUES (1)`)) 668 | db.GetContext(ctx, &v, r(`SELECT * FROM tt`)) 669 | if v.ID != 1 { 670 | t.Errorf("Expecting id of 1, got %v", v.ID) 671 | } 672 | if v.Value != nil { 673 | t.Errorf("Expecting NULL to map to nil, got %s", *v.Value) 674 | } 675 | 676 | v.ID = 2 677 | // NOTE: this incidentally uncovered a bug which was that named queries with 678 | // pointer destinations would not work if the passed value here was not addressable, 679 | // as reflectx.FieldByIndexes attempts to allocate nil pointer receivers for 680 | // writing. This was fixed by creating & using the reflectx.FieldByIndexesReadOnly 681 | // function. This next line is important as it provides the only coverage for this. 
682 | db.NamedExecContext(ctx, `INSERT INTO tt (id, value) VALUES (:id, :value)`, v) 683 | 684 | db.GetContext(ctx, &v2, r(`SELECT * FROM tt WHERE id=2`)) 685 | if v.ID != v2.ID { 686 | t.Errorf("%v != %v", v.ID, v2.ID) 687 | } 688 | if v2.Value != nil { 689 | t.Errorf("Expecting NULL to map to nil, got %s", *v.Value) 690 | } 691 | }) 692 | } 693 | 694 | func TestScanErrorContext(t *testing.T) { 695 | var schema = Schema{ 696 | create: ` 697 | CREATE TABLE kv ( 698 | k text, 699 | v integer 700 | );`, 701 | drop: `drop table kv;`, 702 | } 703 | 704 | RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) { 705 | type WrongTypes struct { 706 | K int 707 | V string 708 | } 709 | _, err := db.Exec(db.Rebind("INSERT INTO kv (k, v) VALUES (?, ?)"), "hi", 1) 710 | if err != nil { 711 | t.Error(err) 712 | } 713 | 714 | rows, err := db.QueryxContext(ctx, "SELECT * FROM kv") 715 | if err != nil { 716 | t.Error(err) 717 | } 718 | for rows.Next() { 719 | var wt WrongTypes 720 | err := rows.StructScan(&wt) 721 | if err == nil { 722 | t.Errorf("%s: Scanning wrong types into keys should have errored.", db.DriverName()) 723 | } 724 | } 725 | }) 726 | } 727 | 728 | // FIXME: this function is kinda big but it slows things down to be constantly 729 | // loading and reloading the schema.. 
730 | 731 | func TestUsageContext(t *testing.T) { 732 | RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { 733 | loadDefaultFixtureContext(ctx, db, t) 734 | slicemembers := []SliceMember{} 735 | err := db.SelectContext(ctx, &slicemembers, "SELECT * FROM place ORDER BY telcode ASC") 736 | if err != nil { 737 | t.Fatal(err) 738 | } 739 | 740 | people := []Person{} 741 | 742 | err = db.SelectContext(ctx, &people, "SELECT * FROM person ORDER BY first_name ASC") 743 | if err != nil { 744 | t.Fatal(err) 745 | } 746 | 747 | jason, john := people[0], people[1] 748 | if jason.FirstName != "Jason" { 749 | t.Errorf("Expecting FirstName of Jason, got %s", jason.FirstName) 750 | } 751 | if jason.LastName != "Moiron" { 752 | t.Errorf("Expecting LastName of Moiron, got %s", jason.LastName) 753 | } 754 | if jason.Email != "jmoiron@jmoiron.net" { 755 | t.Errorf("Expecting Email of jmoiron@jmoiron.net, got %s", jason.Email) 756 | } 757 | if john.FirstName != "John" || john.LastName != "Doe" || john.Email != "johndoeDNE@gmail.net" { 758 | t.Errorf("John Doe's person record not what expected: Got %v\n", john) 759 | } 760 | 761 | jason = Person{} 762 | err = db.GetContext(ctx, &jason, db.Rebind("SELECT * FROM person WHERE first_name=?"), "Jason") 763 | 764 | if err != nil { 765 | t.Fatal(err) 766 | } 767 | if jason.FirstName != "Jason" { 768 | t.Errorf("Expecting to get back Jason, but got %v\n", jason.FirstName) 769 | } 770 | 771 | err = db.GetContext(ctx, &jason, db.Rebind("SELECT * FROM person WHERE first_name=?"), "Foobar") 772 | if err == nil { 773 | t.Errorf("Expecting an error, got nil\n") 774 | } 775 | if err != sql.ErrNoRows { 776 | t.Errorf("Expected sql.ErrNoRows, got %v\n", err) 777 | } 778 | 779 | // The following tests check statement reuse, which was actually a problem 780 | // due to copying being done when creating Stmt's which was eventually removed 781 | stmt1, err := db.PreparexContext(ctx, 
db.Rebind("SELECT * FROM person WHERE first_name=?")) 782 | if err != nil { 783 | t.Fatal(err) 784 | } 785 | jason = Person{} 786 | 787 | row := stmt1.QueryRowx("DoesNotExist") 788 | row.Scan(&jason) 789 | row = stmt1.QueryRowx("DoesNotExist") 790 | row.Scan(&jason) 791 | 792 | err = stmt1.GetContext(ctx, &jason, "DoesNotExist User") 793 | if err == nil { 794 | t.Error("Expected an error") 795 | } 796 | err = stmt1.GetContext(ctx, &jason, "DoesNotExist User 2") 797 | if err == nil { 798 | t.Fatal(err) 799 | } 800 | 801 | stmt2, err := db.PreparexContext(ctx, db.Rebind("SELECT * FROM person WHERE first_name=?")) 802 | if err != nil { 803 | t.Fatal(err) 804 | } 805 | jason = Person{} 806 | tx, err := db.Beginx() 807 | if err != nil { 808 | t.Fatal(err) 809 | } 810 | tstmt2 := tx.Stmtx(stmt2) 811 | row2 := tstmt2.QueryRowx("Jason") 812 | err = row2.StructScan(&jason) 813 | if err != nil { 814 | t.Error(err) 815 | } 816 | tx.Commit() 817 | 818 | places := []*Place{} 819 | err = db.SelectContext(ctx, &places, "SELECT telcode FROM place ORDER BY telcode ASC") 820 | if err != nil { 821 | t.Fatal(err) 822 | } 823 | 824 | usa, singsing, honkers := places[0], places[1], places[2] 825 | 826 | if usa.TelCode != 1 || honkers.TelCode != 852 || singsing.TelCode != 65 { 827 | t.Errorf("Expected integer telcodes to work, got %#v", places) 828 | } 829 | 830 | placesptr := []PlacePtr{} 831 | err = db.SelectContext(ctx, &placesptr, "SELECT * FROM place ORDER BY telcode ASC") 832 | if err != nil { 833 | t.Error(err) 834 | } 835 | // fmt.Printf("%#v\n%#v\n%#v\n", placesptr[0], placesptr[1], placesptr[2]) 836 | 837 | // if you have null fields and use SELECT *, you must use sql.Null* in your struct 838 | // this test also verifies that you can use either a []Struct{} or a []*Struct{} 839 | places2 := []Place{} 840 | err = db.SelectContext(ctx, &places2, "SELECT * FROM place ORDER BY telcode ASC") 841 | if err != nil { 842 | t.Fatal(err) 843 | } 844 | 845 | usa, singsing, honkers = 
&places2[0], &places2[1], &places2[2] 846 | 847 | // this should return a type error that &p is not a pointer to a struct slice 848 | p := Place{} 849 | err = db.SelectContext(ctx, &p, "SELECT * FROM place ORDER BY telcode ASC") 850 | if err == nil { 851 | t.Errorf("Expected an error, argument to select should be a pointer to a struct slice") 852 | } 853 | 854 | // this should be an error 855 | pl := []Place{} 856 | err = db.SelectContext(ctx, pl, "SELECT * FROM place ORDER BY telcode ASC") 857 | if err == nil { 858 | t.Errorf("Expected an error, argument to select should be a pointer to a struct slice, not a slice.") 859 | } 860 | 861 | if usa.TelCode != 1 || honkers.TelCode != 852 || singsing.TelCode != 65 { 862 | t.Errorf("Expected integer telcodes to work, got %#v", places) 863 | } 864 | 865 | stmt, err := db.PreparexContext(ctx, db.Rebind("SELECT country, telcode FROM place WHERE telcode > ? ORDER BY telcode ASC")) 866 | if err != nil { 867 | t.Error(err) 868 | } 869 | 870 | places = []*Place{} 871 | err = stmt.SelectContext(ctx, &places, 10) 872 | if len(places) != 2 { 873 | t.Error("Expected 2 places, got 0.") 874 | } 875 | if err != nil { 876 | t.Fatal(err) 877 | } 878 | singsing, honkers = places[0], places[1] 879 | if singsing.TelCode != 65 || honkers.TelCode != 852 { 880 | t.Errorf("Expected the right telcodes, got %#v", places) 881 | } 882 | 883 | rows, err := db.QueryxContext(ctx, "SELECT * FROM place") 884 | if err != nil { 885 | t.Fatal(err) 886 | } 887 | place := Place{} 888 | for rows.Next() { 889 | err = rows.StructScan(&place) 890 | if err != nil { 891 | t.Fatal(err) 892 | } 893 | } 894 | 895 | rows, err = db.QueryxContext(ctx, "SELECT * FROM place") 896 | if err != nil { 897 | t.Fatal(err) 898 | } 899 | m := map[string]interface{}{} 900 | for rows.Next() { 901 | err = rows.MapScan(m) 902 | if err != nil { 903 | t.Fatal(err) 904 | } 905 | _, ok := m["country"] 906 | if !ok { 907 | t.Errorf("Expected key `country` in map but could not find it 
(%#v)\n", m) 908 | } 909 | } 910 | 911 | rows, err = db.QueryxContext(ctx, "SELECT * FROM place") 912 | if err != nil { 913 | t.Fatal(err) 914 | } 915 | for rows.Next() { 916 | s, err := rows.SliceScan() 917 | if err != nil { 918 | t.Error(err) 919 | } 920 | if len(s) != 3 { 921 | t.Errorf("Expected 3 columns in result, got %d\n", len(s)) 922 | } 923 | } 924 | 925 | // test advanced querying 926 | // test that NamedExec works with a map as well as a struct 927 | _, err = db.NamedExecContext(ctx, "INSERT INTO person (first_name, last_name, email) VALUES (:first, :last, :email)", map[string]interface{}{ 928 | "first": "Bin", 929 | "last": "Smuth", 930 | "email": "bensmith@allblacks.nz", 931 | }) 932 | if err != nil { 933 | t.Fatal(err) 934 | } 935 | 936 | // ensure that if the named param happens right at the end it still works 937 | // ensure that NamedQuery works with a map[string]interface{} 938 | rows, err = db.NamedQueryContext(ctx, "SELECT * FROM person WHERE first_name=:first", map[string]interface{}{"first": "Bin"}) 939 | if err != nil { 940 | t.Fatal(err) 941 | } 942 | 943 | ben := &Person{} 944 | for rows.Next() { 945 | err = rows.StructScan(ben) 946 | if err != nil { 947 | t.Fatal(err) 948 | } 949 | if ben.FirstName != "Bin" { 950 | t.Fatal("Expected first name of `Bin`, got " + ben.FirstName) 951 | } 952 | if ben.LastName != "Smuth" { 953 | t.Fatal("Expected first name of `Smuth`, got " + ben.LastName) 954 | } 955 | } 956 | 957 | ben.FirstName = "Ben" 958 | ben.LastName = "Smith" 959 | ben.Email = "binsmuth@allblacks.nz" 960 | 961 | // Insert via a named query using the struct 962 | _, err = db.NamedExecContext(ctx, "INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)", ben) 963 | 964 | if err != nil { 965 | t.Fatal(err) 966 | } 967 | 968 | rows, err = db.NamedQueryContext(ctx, "SELECT * FROM person WHERE first_name=:first_name", ben) 969 | if err != nil { 970 | t.Fatal(err) 971 | } 972 | for rows.Next() { 973 | err 
= rows.StructScan(ben) 974 | if err != nil { 975 | t.Fatal(err) 976 | } 977 | if ben.FirstName != "Ben" { 978 | t.Fatal("Expected first name of `Ben`, got " + ben.FirstName) 979 | } 980 | if ben.LastName != "Smith" { 981 | t.Fatal("Expected first name of `Smith`, got " + ben.LastName) 982 | } 983 | } 984 | // ensure that Get does not panic on emppty result set 985 | person := &Person{} 986 | err = db.GetContext(ctx, person, "SELECT * FROM person WHERE first_name=$1", "does-not-exist") 987 | if err == nil { 988 | t.Fatal("Should have got an error for Get on non-existent row.") 989 | } 990 | 991 | // lets test prepared statements some more 992 | 993 | stmt, err = db.PreparexContext(ctx, db.Rebind("SELECT * FROM person WHERE first_name=?")) 994 | if err != nil { 995 | t.Fatal(err) 996 | } 997 | rows, err = stmt.QueryxContext(ctx, "Ben") 998 | if err != nil { 999 | t.Fatal(err) 1000 | } 1001 | for rows.Next() { 1002 | err = rows.StructScan(ben) 1003 | if err != nil { 1004 | t.Fatal(err) 1005 | } 1006 | if ben.FirstName != "Ben" { 1007 | t.Fatal("Expected first name of `Ben`, got " + ben.FirstName) 1008 | } 1009 | if ben.LastName != "Smith" { 1010 | t.Fatal("Expected first name of `Smith`, got " + ben.LastName) 1011 | } 1012 | } 1013 | 1014 | john = Person{} 1015 | stmt, err = db.PreparexContext(ctx, db.Rebind("SELECT * FROM person WHERE first_name=?")) 1016 | if err != nil { 1017 | t.Error(err) 1018 | } 1019 | err = stmt.GetContext(ctx, &john, "John") 1020 | if err != nil { 1021 | t.Error(err) 1022 | } 1023 | 1024 | // test name mapping 1025 | // THIS USED TO WORK BUT WILL NO LONGER WORK. 1026 | db.MapperFunc(strings.ToUpper) 1027 | rsa := CPlace{} 1028 | err = db.GetContext(ctx, &rsa, "SELECT * FROM capplace;") 1029 | if err != nil { 1030 | t.Error(err, "in db:", db.DriverName()) 1031 | } 1032 | db.MapperFunc(strings.ToLower) 1033 | 1034 | // create a copy and change the mapper, then verify the copy behaves 1035 | // differently from the original. 
1036 | dbCopy := NewDb(db.DB, db.DriverName()) 1037 | dbCopy.MapperFunc(strings.ToUpper) 1038 | err = dbCopy.GetContext(ctx, &rsa, "SELECT * FROM capplace;") 1039 | if err != nil { 1040 | fmt.Println(db.DriverName()) 1041 | t.Error(err) 1042 | } 1043 | 1044 | err = db.GetContext(ctx, &rsa, "SELECT * FROM cappplace;") 1045 | if err == nil { 1046 | t.Error("Expected no error, got ", err) 1047 | } 1048 | 1049 | // test base type slices 1050 | var sdest []string 1051 | rows, err = db.QueryxContext(ctx, "SELECT email FROM person ORDER BY email ASC;") 1052 | if err != nil { 1053 | t.Error(err) 1054 | } 1055 | err = scanAll(rows, &sdest, false) 1056 | if err != nil { 1057 | t.Error(err) 1058 | } 1059 | 1060 | // test Get with base types 1061 | var count int 1062 | err = db.GetContext(ctx, &count, "SELECT count(*) FROM person;") 1063 | if err != nil { 1064 | t.Error(err) 1065 | } 1066 | if count != len(sdest) { 1067 | t.Errorf("Expected %d == %d (count(*) vs len(SELECT ..)", count, len(sdest)) 1068 | } 1069 | 1070 | // test Get and Select with time.Time, #84 1071 | var addedAt time.Time 1072 | err = db.GetContext(ctx, &addedAt, "SELECT added_at FROM person LIMIT 1;") 1073 | if err != nil { 1074 | t.Error(err) 1075 | } 1076 | 1077 | var addedAts []time.Time 1078 | err = db.SelectContext(ctx, &addedAts, "SELECT added_at FROM person;") 1079 | if err != nil { 1080 | t.Error(err) 1081 | } 1082 | 1083 | // test it on a double pointer 1084 | var pcount *int 1085 | err = db.GetContext(ctx, &pcount, "SELECT count(*) FROM person;") 1086 | if err != nil { 1087 | t.Error(err) 1088 | } 1089 | if *pcount != count { 1090 | t.Errorf("expected %d = %d", *pcount, count) 1091 | } 1092 | 1093 | // test Select... 
1094 | sdest = []string{} 1095 | err = db.SelectContext(ctx, &sdest, "SELECT first_name FROM person ORDER BY first_name ASC;") 1096 | if err != nil { 1097 | t.Error(err) 1098 | } 1099 | expected := []string{"Ben", "Bin", "Jason", "John"} 1100 | for i, got := range sdest { 1101 | if got != expected[i] { 1102 | t.Errorf("Expected %d result to be %s, but got %s", i, expected[i], got) 1103 | } 1104 | } 1105 | 1106 | var nsdest []sql.NullString 1107 | err = db.SelectContext(ctx, &nsdest, "SELECT city FROM place ORDER BY city ASC") 1108 | if err != nil { 1109 | t.Error(err) 1110 | } 1111 | for _, val := range nsdest { 1112 | if val.Valid && val.String != "New York" { 1113 | t.Errorf("expected single valid result to be `New York`, but got %s", val.String) 1114 | } 1115 | } 1116 | }) 1117 | } 1118 | 1119 | // tests that sqlx will not panic when the wrong driver is passed because 1120 | // of an automatic nil dereference in sqlx.Open(), which was fixed. 1121 | func TestDoNotPanicOnConnectContext(t *testing.T) { 1122 | _, err := ConnectContext(context.Background(), "bogus", "hehe") 1123 | if err == nil { 1124 | t.Errorf("Should return error when using bogus driverName") 1125 | } 1126 | } 1127 | 1128 | func TestEmbeddedMapsContext(t *testing.T) { 1129 | var schema = Schema{ 1130 | create: ` 1131 | CREATE TABLE message ( 1132 | string text, 1133 | properties text 1134 | );`, 1135 | drop: `drop table message;`, 1136 | } 1137 | 1138 | RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) { 1139 | messages := []Message{ 1140 | {"Hello, World", PropertyMap{"one": "1", "two": "2"}}, 1141 | {"Thanks, Joy", PropertyMap{"pull": "request"}}, 1142 | } 1143 | q1 := `INSERT INTO message (string, properties) VALUES (:string, :properties);` 1144 | for _, m := range messages { 1145 | _, err := db.NamedExecContext(ctx, q1, m) 1146 | if err != nil { 1147 | t.Fatal(err) 1148 | } 1149 | } 1150 | var count int 1151 | err := db.GetContext(ctx, 
&count, "SELECT count(*) FROM message") 1152 | if err != nil { 1153 | t.Fatal(err) 1154 | } 1155 | if count != len(messages) { 1156 | t.Fatalf("Expected %d messages in DB, found %d", len(messages), count) 1157 | } 1158 | 1159 | var m Message 1160 | err = db.GetContext(ctx, &m, "SELECT * FROM message LIMIT 1;") 1161 | if err != nil { 1162 | t.Fatal(err) 1163 | } 1164 | if m.Properties == nil { 1165 | t.Fatal("Expected m.Properties to not be nil, but it was.") 1166 | } 1167 | }) 1168 | } 1169 | 1170 | func TestIssue197Context(t *testing.T) { 1171 | // this test actually tests for a bug in database/sql: 1172 | // https://github.com/golang/go/issues/13905 1173 | // this potentially makes _any_ named type that is an alias for []byte 1174 | // unsafe to use in a lot of different ways (basically, unsafe to hold 1175 | // onto after loading from the database). 1176 | t.Skip() 1177 | 1178 | type mybyte []byte 1179 | type Var struct{ Raw json.RawMessage } 1180 | type Var2 struct{ Raw []byte } 1181 | type Var3 struct{ Raw mybyte } 1182 | RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { 1183 | var err error 1184 | var v, q Var 1185 | if err = db.GetContext(ctx, &v, `SELECT '{"a": "b"}' AS raw`); err != nil { 1186 | t.Fatal(err) 1187 | } 1188 | if err = db.GetContext(ctx, &q, `SELECT 'null' AS raw`); err != nil { 1189 | t.Fatal(err) 1190 | } 1191 | 1192 | var v2, q2 Var2 1193 | if err = db.GetContext(ctx, &v2, `SELECT '{"a": "b"}' AS raw`); err != nil { 1194 | t.Fatal(err) 1195 | } 1196 | if err = db.GetContext(ctx, &q2, `SELECT 'null' AS raw`); err != nil { 1197 | t.Fatal(err) 1198 | } 1199 | 1200 | var v3, q3 Var3 1201 | if err = db.QueryRowContext(ctx, `SELECT '{"a": "b"}' AS raw`).Scan(&v3.Raw); err != nil { 1202 | t.Fatal(err) 1203 | } 1204 | if err = db.QueryRowContext(ctx, `SELECT '{"c": "d"}' AS raw`).Scan(&q3.Raw); err != nil { 1205 | t.Fatal(err) 1206 | } 1207 | t.Fail() 1208 | }) 1209 | } 1210 | 1211 | 
func TestInContext(t *testing.T) { 1212 | // some quite normal situations 1213 | type tr struct { 1214 | q string 1215 | args []interface{} 1216 | c int 1217 | } 1218 | tests := []tr{ 1219 | {"SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?", 1220 | []interface{}{"foo", []int{0, 5, 7, 2, 9}, "bar"}, 1221 | 7}, 1222 | {"SELECT * FROM foo WHERE x in (?)", 1223 | []interface{}{[]int{1, 2, 3, 4, 5, 6, 7, 8}}, 1224 | 8}, 1225 | } 1226 | for _, test := range tests { 1227 | q, a, err := In(test.q, test.args...) 1228 | if err != nil { 1229 | t.Error(err) 1230 | } 1231 | if len(a) != test.c { 1232 | t.Errorf("Expected %d args, but got %d (%+v)", test.c, len(a), a) 1233 | } 1234 | if strings.Count(q, "?") != test.c { 1235 | t.Errorf("Expected %d bindVars, got %d", test.c, strings.Count(q, "?")) 1236 | } 1237 | } 1238 | 1239 | // too many bindVars, but no slices, so short circuits parsing 1240 | // i'm not sure if this is the right behavior; this query/arg combo 1241 | // might not work, but we shouldn't parse if we don't need to 1242 | { 1243 | orig := "SELECT * FROM foo WHERE x = ? AND y = ?" 1244 | q, a, err := In(orig, "foo", "bar", "baz") 1245 | if err != nil { 1246 | t.Error(err) 1247 | } 1248 | if len(a) != 3 { 1249 | t.Errorf("Expected 3 args, but got %d (%+v)", len(a), a) 1250 | } 1251 | if q != orig { 1252 | t.Error("Expected unchanged query.") 1253 | } 1254 | } 1255 | 1256 | tests = []tr{ 1257 | // too many bindvars; slice present so should return error during parse 1258 | {"SELECT * FROM foo WHERE x = ? and y = ?", 1259 | []interface{}{"foo", []int{1, 2, 3}, "bar"}, 1260 | 0}, 1261 | // empty slice, should return error before parse 1262 | {"SELECT * FROM foo WHERE x = ?", 1263 | []interface{}{[]int{}}, 1264 | 0}, 1265 | // too *few* bindvars, should return an error 1266 | {"SELECT * FROM foo WHERE x = ? AND y in (?)", 1267 | []interface{}{[]int{1, 2, 3}}, 1268 | 0}, 1269 | } 1270 | for _, test := range tests { 1271 | _, _, err := In(test.q, test.args...) 
1272 | if err == nil { 1273 | t.Error("Expected an error, but got nil.") 1274 | } 1275 | } 1276 | RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { 1277 | loadDefaultFixtureContext(ctx, db, t) 1278 | // tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, city, telcode) VALUES (?, ?, ?)"), "United States", "New York", "1") 1279 | // tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Hong Kong", "852") 1280 | // tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Singapore", "65") 1281 | telcodes := []int{852, 65} 1282 | q := "SELECT * FROM place WHERE telcode IN(?) ORDER BY telcode" 1283 | query, args, err := In(q, telcodes) 1284 | if err != nil { 1285 | t.Error(err) 1286 | } 1287 | query = db.Rebind(query) 1288 | places := []Place{} 1289 | err = db.SelectContext(ctx, &places, query, args...) 1290 | if err != nil { 1291 | t.Error(err) 1292 | } 1293 | if len(places) != 2 { 1294 | t.Fatalf("Expecting 2 results, got %d", len(places)) 1295 | } 1296 | if places[0].TelCode != 65 { 1297 | t.Errorf("Expecting singapore first, but got %#v", places[0]) 1298 | } 1299 | if places[1].TelCode != 852 { 1300 | t.Errorf("Expecting hong kong second, but got %#v", places[1]) 1301 | } 1302 | }) 1303 | } 1304 | 1305 | func TestEmbeddedLiteralsContext(t *testing.T) { 1306 | var schema = Schema{ 1307 | create: ` 1308 | CREATE TABLE x ( 1309 | k text 1310 | );`, 1311 | drop: `drop table x;`, 1312 | } 1313 | 1314 | RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) { 1315 | type t1 struct { 1316 | K *string 1317 | } 1318 | type t2 struct { 1319 | Inline struct { 1320 | F string 1321 | } 1322 | K *string 1323 | } 1324 | 1325 | db.MustExecContext(ctx, db.Rebind("INSERT INTO x (k) VALUES (?), (?), (?);"), "one", "two", "three") 1326 | 1327 | target := t1{} 1328 | err := db.GetContext(ctx, &target, 
db.Rebind("SELECT * FROM x WHERE k=?"), "one") 1329 | if err != nil { 1330 | t.Error(err) 1331 | } 1332 | if *target.K != "one" { 1333 | t.Error("Expected target.K to be `one`, got ", target.K) 1334 | } 1335 | 1336 | target2 := t2{} 1337 | err = db.GetContext(ctx, &target2, db.Rebind("SELECT * FROM x WHERE k=?"), "one") 1338 | if err != nil { 1339 | t.Error(err) 1340 | } 1341 | if *target2.K != "one" { 1342 | t.Errorf("Expected target2.K to be `one`, got `%v`", target2.K) 1343 | } 1344 | }) 1345 | } 1346 | 1347 | func TestConn(t *testing.T) { 1348 | var schema = Schema{ 1349 | create: ` 1350 | CREATE TABLE tt_conn ( 1351 | id integer, 1352 | value text NULL DEFAULT NULL 1353 | );`, 1354 | drop: "drop table tt_conn;", 1355 | } 1356 | 1357 | RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) { 1358 | conn, err := db.Connx(ctx) 1359 | defer conn.Close() //lint:ignore SA5001 it's OK to ignore this here. 1360 | if err != nil { 1361 | t.Fatal(err) 1362 | } 1363 | 1364 | _, err = conn.ExecContext(ctx, conn.Rebind(`INSERT INTO tt_conn (id, value) VALUES (?, ?), (?, ?)`), 1, "a", 2, "b") 1365 | if err != nil { 1366 | t.Fatal(err) 1367 | } 1368 | 1369 | type s struct { 1370 | ID int `db:"id"` 1371 | Value string `db:"value"` 1372 | } 1373 | 1374 | v := []s{} 1375 | 1376 | err = conn.SelectContext(ctx, &v, "SELECT * FROM tt_conn ORDER BY id ASC") 1377 | if err != nil { 1378 | t.Fatal(err) 1379 | } 1380 | 1381 | if v[0].ID != 1 { 1382 | t.Errorf("Expecting ID of 1, got %d", v[0].ID) 1383 | } 1384 | 1385 | v1 := s{} 1386 | err = conn.GetContext(ctx, &v1, conn.Rebind("SELECT * FROM tt_conn WHERE id=?"), 1) 1387 | 1388 | if err != nil { 1389 | t.Fatal(err) 1390 | } 1391 | if v1.ID != 1 { 1392 | t.Errorf("Expecting to get back 1, but got %v\n", v1.ID) 1393 | } 1394 | 1395 | stmt, err := conn.PreparexContext(ctx, conn.Rebind("SELECT * FROM tt_conn WHERE id=?")) 1396 | if err != nil { 1397 | t.Fatal(err) 1398 | } 1399 | v1 = s{} 
1400 | tx, err := conn.BeginTxx(ctx, nil) 1401 | if err != nil { 1402 | t.Fatal(err) 1403 | } 1404 | tstmt := tx.Stmtx(stmt) 1405 | row := tstmt.QueryRowx(1) 1406 | err = row.StructScan(&v1) 1407 | if err != nil { 1408 | t.Error(err) 1409 | } 1410 | tx.Commit() 1411 | if v1.ID != 1 { 1412 | t.Errorf("Expecting to get back 1, but got %v\n", v1.ID) 1413 | } 1414 | 1415 | rows, err := conn.QueryxContext(ctx, "SELECT * FROM tt_conn") 1416 | if err != nil { 1417 | t.Fatal(err) 1418 | } 1419 | 1420 | for rows.Next() { 1421 | err = rows.StructScan(&v1) 1422 | if err != nil { 1423 | t.Fatal(err) 1424 | } 1425 | } 1426 | }) 1427 | } 1428 | -------------------------------------------------------------------------------- /types/README.md: -------------------------------------------------------------------------------- 1 | # types 2 | 3 | The types package provides some useful types which implement the `sql.Scanner` 4 | and `driver.Valuer` interfaces, suitable for use as scan and value targets with 5 | database/sql. 6 | -------------------------------------------------------------------------------- /types/doc.go: -------------------------------------------------------------------------------- 1 | // Package types provides some useful types which implement the `sql.Scanner` 2 | // and `driver.Valuer` interfaces, suitable for use as scan and value targets with 3 | // database/sql. 4 | package types 5 | -------------------------------------------------------------------------------- /types/types.go: -------------------------------------------------------------------------------- 1 | package types 2 | 3 | import ( 4 | "bytes" 5 | "compress/gzip" 6 | "database/sql/driver" 7 | "encoding/json" 8 | "errors" 9 | "io/ioutil" 10 | ) 11 | 12 | // GzippedText is a []byte which transparently gzips data being submitted to 13 | // a database and ungzips data being Scanned from a database. 
type GzippedText []byte

// Value implements the driver.Valuer interface, gzipping the raw value of
// this GzippedText. Errors from the gzip writer are returned instead of
// silently yielding a possibly incomplete stream.
func (g GzippedText) Value() (driver.Value, error) {
	buf := bytes.NewBuffer(make([]byte, 0, len(g)))
	w := gzip.NewWriter(buf)
	if _, err := w.Write(g); err != nil {
		return nil, err
	}
	// Close flushes pending compressed data and writes the gzip footer.
	if err := w.Close(); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}

// Scan implements the sql.Scanner interface, ungzipping the value coming off
// the wire and storing the raw result in the GzippedText. src must be a
// string or []byte; any other type (including nil) is an error.
func (g *GzippedText) Scan(src interface{}) error {
	var source []byte
	switch src := src.(type) {
	case string:
		source = []byte(src)
	case []byte:
		source = src
	default:
		//lint:ignore ST1005 changing this could break consumers of this package
		return errors.New("Incompatible type for GzippedText")
	}
	reader, err := gzip.NewReader(bytes.NewReader(source))
	if err != nil {
		return err
	}
	defer reader.Close()
	b, err := ioutil.ReadAll(reader)
	if err != nil {
		return err
	}
	*g = GzippedText(b)
	return nil
}

// JSONText is a json.RawMessage, which is a []byte underneath.
// Value() validates the json format in the source, and returns an error if
// the json is not valid. Scan does no validation. JSONText additionally
// implements `Unmarshal`, which unmarshals the json within to an interface{}
type JSONText json.RawMessage

// emptyJSON is the canonical empty JSON document. It is package-level state,
// so it must never be stored in a JSONText or returned to a caller directly:
// a later append into (*j)[0:0] would overwrite its shared backing array for
// the whole process. Use newEmptyJSON to obtain a private copy.
var emptyJSON = JSONText("{}")

// newEmptyJSON returns a freshly allocated copy of emptyJSON.
func newEmptyJSON() JSONText {
	return append(JSONText(nil), emptyJSON...)
}

// MarshalJSON returns j as the JSON encoding of j, or a fresh `{}` when j is
// empty.
func (j JSONText) MarshalJSON() ([]byte, error) {
	if len(j) == 0 {
		return newEmptyJSON(), nil
	}
	return j, nil
}

// UnmarshalJSON sets *j to a copy of data, reusing *j's existing buffer.
func (j *JSONText) UnmarshalJSON(data []byte) error {
	if j == nil {
		return errors.New("JSONText: UnmarshalJSON on nil pointer")
	}
	*j = append((*j)[0:0], data...)
	return nil
}

// Value returns j as a value. This does a validating unmarshal into another
// RawMessage. If j is invalid json, it returns an error. An empty j yields
// `{}`, because Unmarshal replaces the (value-receiver) copy before it is
// returned.
func (j JSONText) Value() (driver.Value, error) {
	var m json.RawMessage
	var err = j.Unmarshal(&m)
	if err != nil {
		return []byte{}, err
	}
	return []byte(j), nil
}

// Scan stores a copy of src in *j. No validation is done. A nil src (SQL
// NULL) and an empty []byte are both stored as `{}`; an empty string is
// stored as an empty JSONText.
func (j *JSONText) Scan(src interface{}) error {
	var source []byte
	switch t := src.(type) {
	case string:
		source = []byte(t)
	case []byte:
		if len(t) == 0 {
			source = emptyJSON
		} else {
			source = t
		}
	case nil:
		// Route NULL through source: assigning emptyJSON to *j here was
		// overwritten by the append below and left *j aliasing emptyJSON.
		source = emptyJSON
	default:
		//lint:ignore ST1005 changing this could break consumers of this package
		return errors.New("Incompatible type for JSONText")
	}
	// Copy rather than retain src: drivers may reuse the []byte they hand to
	// Scan, and emptyJSON must never become the destination of an append.
	*j = append((*j)[0:0], source...)
	return nil
}

// Unmarshal unmarshal's the json in j to v, as in json.Unmarshal. An empty j
// is first replaced with a private `{}`.
func (j *JSONText) Unmarshal(v interface{}) error {
	if len(*j) == 0 {
		*j = newEmptyJSON()
	}
	return json.Unmarshal([]byte(*j), v)
}

// String supports pretty printing for JSONText types.
func (j JSONText) String() string {
	return string(j)
}

// NullJSONText represents a JSONText that may be null.
// NullJSONText implements the scanner interface so
// it can be used as a scan destination, similar to NullString.
type NullJSONText struct {
	JSONText
	Valid bool // Valid is true if JSONText is not NULL
}

// Scan implements the Scanner interface. A nil value marks n as not Valid and
// stores a private `{}` (never the shared emptyJSON) in n.JSONText.
func (n *NullJSONText) Scan(value interface{}) error {
	if value == nil {
		n.JSONText, n.Valid = newEmptyJSON(), false
		return nil
	}
	n.Valid = true
	return n.JSONText.Scan(value)
}

// Value implements the driver Valuer interface.
144 | func (n NullJSONText) Value() (driver.Value, error) { 145 | if !n.Valid { 146 | return nil, nil 147 | } 148 | return n.JSONText.Value() 149 | } 150 | 151 | // BitBool is an implementation of a bool for the MySQL type BIT(1). 152 | // This type allows you to avoid wasting an entire byte for MySQL's boolean type TINYINT. 153 | type BitBool bool 154 | 155 | // Value implements the driver.Valuer interface, 156 | // and turns the BitBool into a bitfield (BIT(1)) for MySQL storage. 157 | func (b BitBool) Value() (driver.Value, error) { 158 | if b { 159 | return []byte{1}, nil 160 | } 161 | return []byte{0}, nil 162 | } 163 | 164 | // Scan implements the sql.Scanner interface, 165 | // and turns the bitfield incoming from MySQL into a BitBool 166 | func (b *BitBool) Scan(src interface{}) error { 167 | v, ok := src.([]byte) 168 | if !ok { 169 | return errors.New("bad []byte type assertion") 170 | } 171 | *b = v[0] == 1 172 | return nil 173 | } 174 | -------------------------------------------------------------------------------- /types/types_test.go: -------------------------------------------------------------------------------- 1 | package types 2 | 3 | import "testing" 4 | 5 | func TestGzipText(t *testing.T) { 6 | g := GzippedText("Hello, world") 7 | v, err := g.Value() 8 | if err != nil { 9 | t.Errorf("Was not expecting an error") 10 | } 11 | err = (&g).Scan(v) 12 | if err != nil { 13 | t.Errorf("Was not expecting an error") 14 | } 15 | if string(g) != "Hello, world" { 16 | t.Errorf("Was expecting the string we sent in (Hello World), got %s", string(g)) 17 | } 18 | } 19 | 20 | func TestJSONText(t *testing.T) { 21 | j := JSONText(`{"foo": 1, "bar": 2}`) 22 | v, err := j.Value() 23 | if err != nil { 24 | t.Errorf("Was not expecting an error") 25 | } 26 | err = (&j).Scan(v) 27 | if err != nil { 28 | t.Errorf("Was not expecting an error") 29 | } 30 | m := map[string]interface{}{} 31 | j.Unmarshal(&m) 32 | 33 | if m["foo"].(float64) != 1 || m["bar"].(float64) != 2 { 
34 | t.Errorf("Expected valid json but got some garbage instead? %#v", m) 35 | } 36 | 37 | j = JSONText(`{"foo": 1, invalid, false}`) 38 | _, err = j.Value() 39 | if err == nil { 40 | t.Errorf("Was expecting invalid json to fail!") 41 | } 42 | 43 | j = JSONText("") 44 | v, err = j.Value() 45 | if err != nil { 46 | t.Errorf("Was not expecting an error") 47 | } 48 | 49 | err = (&j).Scan(v) 50 | if err != nil { 51 | t.Errorf("Was not expecting an error") 52 | } 53 | 54 | j = JSONText(nil) 55 | v, err = j.Value() 56 | if err != nil { 57 | t.Errorf("Was not expecting an error") 58 | } 59 | 60 | err = (&j).Scan(v) 61 | if err != nil { 62 | t.Errorf("Was not expecting an error") 63 | } 64 | } 65 | 66 | func TestNullJSONText(t *testing.T) { 67 | j := NullJSONText{} 68 | err := j.Scan(`{"foo": 1, "bar": 2}`) 69 | if err != nil { 70 | t.Errorf("Was not expecting an error") 71 | } 72 | v, err := j.Value() 73 | if err != nil { 74 | t.Errorf("Was not expecting an error") 75 | } 76 | err = (&j).Scan(v) 77 | if err != nil { 78 | t.Errorf("Was not expecting an error") 79 | } 80 | m := map[string]interface{}{} 81 | j.Unmarshal(&m) 82 | 83 | if m["foo"].(float64) != 1 || m["bar"].(float64) != 2 { 84 | t.Errorf("Expected valid json but got some garbage instead? 
%#v", m) 85 | } 86 | 87 | j = NullJSONText{} 88 | err = j.Scan(nil) 89 | if err != nil { 90 | t.Errorf("Was not expecting an error") 91 | } 92 | if j.Valid != false { 93 | t.Errorf("Expected valid to be false, but got true") 94 | } 95 | } 96 | 97 | func TestBitBool(t *testing.T) { 98 | // Test true value 99 | var b BitBool = true 100 | 101 | v, err := b.Value() 102 | if err != nil { 103 | t.Errorf("Cannot return error") 104 | } 105 | err = (&b).Scan(v) 106 | if err != nil { 107 | t.Errorf("Was not expecting an error") 108 | } 109 | if !b { 110 | t.Errorf("Was expecting the bool we sent in (true), got %v", b) 111 | } 112 | 113 | // Test false value 114 | b = false 115 | 116 | v, err = b.Value() 117 | if err != nil { 118 | t.Errorf("Cannot return error") 119 | } 120 | err = (&b).Scan(v) 121 | if err != nil { 122 | t.Errorf("Was not expecting an error") 123 | } 124 | if b { 125 | t.Errorf("Was expecting the bool we sent in (false), got %v", b) 126 | } 127 | } 128 | --------------------------------------------------------------------------------