├── .gitignore
├── LICENSE.txt
├── Makefile
├── README.md
├── example
│   ├── fence.go
│   ├── users.go
│   └── users.sql
├── glide.lock
├── glide.yaml
├── schema2struct
│   ├── Makefile
│   ├── README.md
│   └── schema2struct.go
├── sqlite_pod_test.go
├── sqlite_ptr_test.go
├── structable.go
└── structable_test.go

/.gitignore:
--------------------------------------------------------------------------------
1 | vendor
2 | *.test
3 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Structable
2 | The Masterminds
3 | Copyright (C) 2014 Matt Butcher
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in
13 | all copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 | THE SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: test
2 | test:
3 | 	go test -v -tags sqlite .
4 |
5 | .PHONY: test-fast
6 | test-fast:
7 | 	go test -v .
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Structable: Struct-Table Mapping for Go
2 | [![Stability:
3 | Sustained](https://masterminds.github.io/stability/sustained.svg)](https://masterminds.github.io/stability/sustained.html)
4 |
5 | **Warning:** This is the Structable 4 development branch. For a stable
6 | release, use version 3.1.0. Structable development happens very slowly.
7 |
8 | This library provides basic struct-to-table mapping for Go.
9 |
10 | It is based on the [Squirrel](https://github.com/Masterminds/squirrel) library.
11 |
12 | ## What It Does
13 |
14 | Structable maps a struct (`Record`) to a database table via a
15 | `structable.Recorder`. It is intended to be used as a back-end tool for
16 | building systems like Active Record mappers.
17 |
18 | It is designed to satisfy a CRUD-centered record management system,
19 | filling the following contract:
20 |
21 | ```go
22 | type Recorder interface {
23 |   Bind(string, Record) Recorder // link struct to table
24 |   Interface() interface{} // Get the struct that has been linked
25 |   Insert() error // INSERT just one record
26 |   Update() error // UPDATE just one record
27 |   Delete() error // DELETE just one record
28 |   Exists() (bool, error) // Check for just one record
29 |   ExistsWhere(cond interface{}, args ...interface{}) (bool, error)
30 |   Load() error // SELECT just one record
31 |   LoadWhere(cond interface{}, args ...interface{}) error // Alternate Load()
32 | }
33 | ```
34 |
35 | Squirrel already provides the ability to perform more complicated
36 | operations.
37 |
38 | ## How To Install It
39 |
40 | The usual way...
41 |
42 | ```
43 | $ glide get github.com/Masterminds/structable
44 | $ # or...
45 | $ go get github.com/Masterminds/structable
46 | ```
47 |
48 | And import it via:
49 |
50 | ```
51 | import "github.com/Masterminds/structable"
52 | ```
53 |
54 | ## How To Use It
55 |
56 | [![GoDoc](https://godoc.org/github.com/Masterminds/structable?status.png)](https://godoc.org/github.com/Masterminds/structable)
57 |
58 | Structable works by mapping a struct to columns in a database.
59 |
60 | To annotate a struct, you do something like this:
61 |
62 | ```go
63 | type Stool struct {
64 |   Id       int    `stbl:"id, PRIMARY_KEY, AUTO_INCREMENT"`
65 |   Legs     int    `stbl:"number_of_legs"`
66 |   Material string `stbl:"material"`
67 |   Ignored  string // will not be stored. No tag.
68 | }
69 | ```
70 |
71 | To manage instances of this struct, you do something like this:
72 |
73 | ```go
74 | stool := new(Stool)
75 | stool.Material = "Wood"
76 | db := getDb() // Get a sql.DB. You're on the hook to do this part.
77 |
78 | // Create a new structable.Recorder and tell it to
79 | // bind the given struct as a row in the given table.
80 | r := structable.New(db, "mysql").Bind("test_table", stool)
81 |
82 | // This will insert the stool into the test_table.
83 | err := r.Insert()
84 | ```
85 |
86 | And of course you have `Load()`, `Update()`, `Delete()` and so on.
87 |
88 | The target use case for Structable is to use it as a backend for an
89 | Active Record pattern. An example of this can be found in the
90 | `structable_test.go` file.
91 |
92 | Most of Structable focuses on individual objects, but there are helpers
93 | for listing objects:
94 |
95 | ```go
96 | // Get a list of things that have the same type as object.
97 | stool := new(Stool)
98 | items, err := structable.List(stool, offset, limit)
99 |
100 | // Customize a list of things that have the same type as object.
101 | fn := func(object structable.Describer, sql squirrel.SelectBuilder) (squirrel.SelectBuilder, error) {
102 |   return sql.Limit(10), nil
103 | }
104 | items, err = structable.ListWhere(stool, fn)
105 | ```
106 |
107 | For example, here is a function that uses `ListWhere` to get a collection
108 | of definitions from a table described in a struct named `Table`:
109 |
110 | ```go
111 | func (s *SchemaInfo) Tables() ([]*Table, error) {
112 |
113 |   // Bind a new recorder. We use an empty object just to get the field
114 |   // data for that struct.
115 |   t := &Table{}
116 |   st := structable.New(s.Queryer, s.Driver).Bind(t.TableName(), t)
117 |
118 |   // We want to return no more than 10 of these.
119 |   fn := func(d structable.Describer, q squirrel.SelectBuilder) (squirrel.SelectBuilder, error) {
120 |     return q.Limit(10), nil
121 |   }
122 |
123 |   // Fetch a list of Table structs.
124 |   items, err := structable.ListWhere(st, fn)
125 |   if err != nil {
126 |     return []*Table{}, err
127 |   }
128 |
129 |   // Because we get back a []Recorder, we need to get the original data
130 |   // back out. We have to manually convert it back to its real type.
131 |   tables := make([]*Table, len(items))
132 |   for i, item := range items {
133 |     tables[i] = item.Interface().(*Table)
134 |   }
135 |   return tables, nil
136 | }
137 | ```
138 |
139 | ### Tested On
140 |
141 | - MySQL (5.5)
142 | - PostgreSQL (9.3, 9.4, 9.6)
143 | - SQLite 3
144 |
145 | ## What It Does Not Do
146 |
147 | It does not...
148 |
149 | * Create or manage schemas.
150 | * Guess or enforce table or column names. (You have to tell it how to
151 |   map.)
152 | * Provide relational mapping.
153 | * Handle bulk operations (use Squirrel for that).
154 |
155 | ## LICENSE
156 |
157 | This software is licensed under an MIT-style license. See LICENSE.txt
158 |
--------------------------------------------------------------------------------
/example/fence.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"github.com/Masterminds/squirrel"
5 | 	"github.com/Masterminds/structable"
6 | )
7 |
8 | const FenceTable = "fences"
9 |
10 | // Fence represents a Geofence boundary.
11 | //
12 | // This struct is stubbed out to show how an ActiveRecord pattern might look
13 | // when implemented using Squirrel and Structable.
14 | //
15 | // The DDL for the underlying table may look something like this:
16 | //	CREATE TABLE fences (
17 | //		id SERIAL,
18 | //		radius NUMERIC(20, 14),
19 | //		latitude NUMERIC(20, 14),
20 | //		longitude NUMERIC(20, 14),
21 | //		region INTEGER,
22 | //
23 | //		PRIMARY KEY(id)
24 | //	);
25 | type Fence struct {
26 | 	Id        int     `stbl:"id,PRIMARY_KEY,SERIAL"`
27 | 	Region    int     `stbl:"region"`
28 | 	Radius    float64 `stbl:"radius"`
29 | 	Latitude  float64 `stbl:"latitude"`
30 | 	Longitude float64 `stbl:"longitude"`
31 |
32 | 	rec     structable.Recorder
33 | 	builder squirrel.StatementBuilderType
34 | }
35 |
36 | // NewFence creates a new empty fence.
37 | //
38 | // Note that a DBProxyBeginner is Squirrel's interface
39 | // that describes most sql.DB-like things.
40 | //
41 | // Flavor may be one of 'mysql', 'postgres'. Other DBs may
42 | // work, but are untested.
43 | func NewFence(db squirrel.DBProxyBeginner, dbFlavor string) *Fence {
44 | 	f := new(Fence)
45 | 	f.builder = squirrel.StatementBuilder.RunWith(db)
46 |
47 | 	// For Postgres we convert '?' to '$N' placeholders.
48 | 	if dbFlavor == "postgres" {
49 | 		f.builder = f.builder.PlaceholderFormat(squirrel.Dollar)
50 | 	}
51 |
52 | 	f.rec = structable.New(db, dbFlavor).Bind(FenceTable, f)
53 |
54 | 	return f
55 | }
56 |
57 | // Insert creates a new record.
58 | func (r *Fence) Insert() error {
59 | 	return r.rec.Insert()
60 | }
61 |
62 | // Update modifies an existing record.
63 | func (r *Fence) Update() error {
64 | 	return r.rec.Update()
65 | }
66 |
67 | // Delete removes a record.
68 | func (r *Fence) Delete() error {
69 | 	return r.rec.Delete()
70 | }
71 |
72 | // Has returns true if the record exists.
73 | func (r *Fence) Has() (bool, error) {
74 | 	return r.rec.Exists()
75 | }
76 |
77 | // Load populates the struct with data from storage.
78 | // It presumes that the id field is set.
79 | func (r *Fence) Load() error {
80 | 	return r.rec.Load()
81 | }
82 |
83 | // LoadGeopoint loads a fence by its Latitude and Longitude fields.
84 | // It is an example of a custom loader.
85 | //
86 | // Usage:
87 | //	fence := NewFence(myDb, "postgres")
88 | //	fence.Latitude = 1.000001
89 | //	fence.Longitude = 1.000002
90 | //	if err := fence.LoadGeopoint(); err != nil {
91 | //		panic("Something went wrong! " + err.Error())
92 | //	}
93 | //	fmt.Printf("Loaded ID %d\n", fence.Id)
94 | //
95 | func (r *Fence) LoadGeopoint() error {
96 | 	//q := r.rec.Select("id, radius, region").From(FenceTable).
97 | 	//	Where("latitude = ? AND longitude = ?", r.Latitude, r.Longitude)
98 |
99 | 	//return q.Query().Scan(&r.Id, &r.Radius, &r.Region)
100 | 	return r.rec.LoadWhere("latitude = ? AND longitude = ?", r.Latitude, r.Longitude)
101 | }
102 |
--------------------------------------------------------------------------------
/example/users.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"github.com/Masterminds/squirrel"
5 | 	"github.com/Masterminds/structable"
6 | 	_ "github.com/lib/pq"
7 |
8 | 	"database/sql"
9 | 	"fmt"
10 | )
11 |
12 | // For convenience, we declare the table name as a constant.
13 | const UserTable = "users"
14 |
15 | // This is our struct. Notice that it embeds a structable.Recorder.
16 | type User struct {
17 | 	structable.Recorder
18 | 	builder squirrel.StatementBuilderType
19 |
20 | 	Id    int    `stbl:"id,PRIMARY_KEY,SERIAL"`
21 | 	Name  string `stbl:"name"`
22 | 	Email string `stbl:"email"`
23 | }
24 |
25 | // NewUser creates a new Structable wrapper for a user.
26 | //
27 | // Of particular importance, watch how we initialize the Recorder.
28 | func NewUser(db squirrel.DBProxyBeginner, dbFlavor string) *User {
29 | 	u := new(User)
30 | 	u.Recorder = structable.New(db, dbFlavor).Bind(UserTable, u)
31 | 	return u
32 | }
33 |
34 | // LoadByName is a custom loader.
35 | //
36 | // The Load() method on a Recorder loads by ID. This allows us to load by
37 | // a different field -- Name.
38 | func (u *User) LoadByName() error {
39 | 	return u.Recorder.LoadWhere("name = ? order by id desc", u.Name)
40 | }
41 |
42 | func main() {
43 | 	// Boilerplate DB setup: pick the database driver, then open a connection.
44 | 	driver := "postgres"
45 | 	con, err := sql.Open(driver, "dbname=structable_test sslmode=disable")
46 | 	if err != nil {
47 | 		panic(err.Error())
48 | 	}
49 | 	// Wrap the connection in a prepared statement cache for better performance.
50 | 	cache := squirrel.NewStmtCacheProxy(con)
51 |
52 | 	// Create an empty new user and give it some properties.
53 | 	user := NewUser(cache, driver)
54 | 	user.Name = "Matt"
55 | 	user.Email = "matt@example.com"
56 |
57 | 	// Insert this as a new record.
58 | 	if err := user.Insert(); err != nil {
59 | 		panic(err.Error())
60 | 	}
61 | 	fmt.Printf("Initial insert has ID %d, name %s, and email %s\n", user.Id, user.Name, user.Email)
62 |
63 | 	// Now create another empty User and set the user's Name.
64 | 	again := NewUser(cache, driver)
65 | 	again.Name = "Matt"
66 |
67 | 	// Load using our custom loader.
68 | 	if err := again.LoadByName(); err != nil {
69 | 		panic(err.Error())
70 | 	}
71 | 	fmt.Printf("User by name has ID %d and email %s\n", again.Id, again.Email)
72 |
73 | 	again.Email = "Masterminds@example.com"
74 | 	if err := again.Update(); err != nil {
75 | 		panic(err.Error())
76 | 	}
77 | 	fmt.Printf("Updated user has ID %d and email %s\n", again.Id, again.Email)
78 |
79 | 	// Delete using the built-in Deleter. (delete by Id.)
80 | 	if err := again.Delete(); err != nil {
81 | 		panic(err.Error())
82 | 	}
83 | 	fmt.Printf("Deleted user %d\n", again.Id)
84 | }
85 |
--------------------------------------------------------------------------------
/example/users.sql:
--------------------------------------------------------------------------------
1 | CREATE TABLE users (
2 |   id SERIAL,
3 |   name VARCHAR,
4 |   email VARCHAR,
5 |   PRIMARY KEY (id)
6 | );
7 |
--------------------------------------------------------------------------------
/glide.lock:
--------------------------------------------------------------------------------
1 | hash: 33914218da2ad50586a54a84aa0f1d94be0f335ca84e9e8ae2ef7719818a3d02
2 | updated: 2017-01-03T22:16:30.285599458-07:00
3 | imports:
4 | - name: github.com/codegangsta/cli
5 |   version: 0bdeddeeb0f650497d603c4ad7b20cfe685682f6
6 | - name: github.com/lann/builder
7 |   version: f22ce00fd9394014049dad11c244859432bd6820
8 | - name: github.com/lann/ps
9 |   version: 62de8c46ede02a7675c4c79c84883eb164cb71e3
10 | - name: github.com/lib/pq
11 |   version: 8df6253d1317616f36b0c3740eb30c239a7372cb
12 |   subpackages:
13 |   - oid
14 | - name: github.com/Masterminds/squirrel
15 |   version: 20f192218cf52a73397fa2df45bdda720f3e47c8
16 | - name: github.com/mattn/go-sqlite3
17 |   version: 6f2749a3ca9b233ffb8749ef9684f7f4d88cee7a
18 | devImports: []
19 |
--------------------------------------------------------------------------------
/glide.yaml:
--------------------------------------------------------------------------------
1 | package: github.com/Masterminds/structable
2 | import:
3 | - package: github.com/Masterminds/squirrel
4 | #- package: github.com/lann/builder
5 | #- package: github.com/lann/ps
6 | - package: github.com/lib/pq
7 | - package: github.com/mattn/go-sqlite3
8 |
--------------------------------------------------------------------------------
/schema2struct/Makefile:
--------------------------------------------------------------------------------
1 | VERSION := $(shell git describe --tags)
2 | DIST_DIRS := find * -type d -exec
3 |
4 | build:
5 | 	go build -o schema2struct -ldflags "-X main.version=${VERSION}" schema2struct.go
6 |
7 | install: build
8 | 	install -d ${DESTDIR}/usr/local/bin/
9 | 	install -m 755 ./schema2struct ${DESTDIR}/usr/local/bin/schema2struct
10 |
11 | .PHONY: build install
12 |
--------------------------------------------------------------------------------
/schema2struct/README.md:
--------------------------------------------------------------------------------
1 | # schema2struct: Create definitions from the database
2 |
3 | This program is a proof of concept for creating Structable structs by
4 | inspecting a database and generating closely matching structs.
5 |
6 | Currently this only works on Postgres, though there is no reason it
7 | could not be ported to support other databases.
8 |
9 | It works by querying the INFORMATION_SCHEMA tables to learn which
10 | tables are present and which columns they contain. It then attempts to
11 | render structs that point to those tables.
12 |
13 | If you are interested in contributing to moving this beyond proof of
14 | concept, feel free to issue PRs against the codebase.
15 |
16 | ## Usage
17 |
18 | Install using `make install`. This will put `schema2struct` on your
19 | `$PATH`.
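To make the rendering step described above more concrete, here is a minimal, stdlib-only sketch of turning column metadata into a tagged struct with `text/template`. It is a hypothetical illustration, not schema2struct's actual code: the `column` type, the trimmed-down `goType` mapping, and the sample `user_accounts` table are assumptions made for this example, and the real tool reads its columns from `INFORMATION_SCHEMA.COLUMNS` rather than from a hard-coded slice.

```go
// Hypothetical sketch: render a Structable-style struct from column
// metadata. This is NOT schema2struct's real implementation.
package main

import (
	"fmt"
	"os"
	"strings"
	"text/template"
)

// column is an assumed, simplified stand-in for one INFORMATION_SCHEMA.COLUMNS
// row, plus a flag saying whether the column is part of the primary key.
type column struct {
	Name       string
	SQLType    string
	PrimaryKey bool
}

// goName turns snake_case into CamelCase, e.g. "email_address" -> "EmailAddress".
func goName(s string) string {
	parts := strings.Split(s, "_")
	for i, p := range parts {
		if p != "" {
			parts[i] = strings.ToUpper(p[:1]) + p[1:]
		}
	}
	return strings.Join(parts, "")
}

// goType is a trimmed-down SQL-to-Go type map; unknown types become string.
func goType(sqlType string) string {
	switch sqlType {
	case "integer", "serial":
		return "int32"
	case "bigint", "bigserial":
		return "int64"
	case "boolean":
		return "bool"
	}
	return "string"
}

// structTpl renders one struct; each entry in Fields is a complete field line.
var structTpl = template.Must(template.New("struct").Parse(
	"type {{.Name}} struct {\n{{range .Fields}}\t{{.}}\n{{end}}}\n"))

// renderStruct builds one `stbl`-tagged field per column and executes the template.
func renderStruct(table string, cols []column) (string, error) {
	fields := make([]string, 0, len(cols))
	for _, c := range cols {
		tag := c.Name
		if c.PrimaryKey {
			tag += ",PRIMARY_KEY"
		}
		fields = append(fields, fmt.Sprintf("%s %s `stbl:%q`", goName(c.Name), goType(c.SQLType), tag))
	}
	var b strings.Builder
	err := structTpl.Execute(&b, struct {
		Name   string
		Fields []string
	}{goName(table), fields})
	return b.String(), err
}

func main() {
	out, err := renderStruct("user_accounts", []column{
		{Name: "id", SQLType: "integer", PrimaryKey: true},
		{Name: "email_address", SQLType: "character varying"},
	})
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	fmt.Print(out)
}
```

Run with `go run`, this should print a `UserAccounts` struct whose fields carry `stbl` tags in the form Structable reads. The real generator also embeds a `structable.Recorder`, emits helper functions, and marks `SERIAL` or `AUTO_INCREMENT` keys, all of which this sketch leaves out.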
20 |
21 | In the package where you want to create the structs, add an annotation
22 | to one of the Go files:
23 |
24 | ```go
25 | //go:generate schema2struct -f schemata.go
26 | ```
27 |
28 | The above annotation will instruct `go generate` to run `schema2struct`
29 | and generate a file called `schemata.go`.
30 |
31 | Finally, run `go generate` in that package's directory:
32 |
33 | ```
34 | $ cd model
35 | $ go generate
36 | ```
37 |
38 | The result should be a `schemata.go` source file.
39 |
--------------------------------------------------------------------------------
/schema2struct/schema2struct.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"database/sql"
5 | 	"fmt"
6 | 	"io"
7 | 	"os"
8 | 	"strings"
9 | 	"text/template"
10 |
11 | 	"github.com/Masterminds/squirrel"
12 | 	"github.com/codegangsta/cli"
13 |
14 | 	_ "github.com/go-sql-driver/mysql"
15 | 	_ "github.com/lib/pq"
16 | )
17 |
18 | var version = "DEV" // Set at build time with -ldflags "-X main.version=...".
19 |
20 | // Usage is the help text displayed by the command-line interface.
21 | const Usage = `Read a schema and generate Structable structs.
22 |
23 | This utility generates Structable structs by reading your database tables and
24 | generating the appropriate code.
25 | `
26 |
27 | const fileHeader = `package %s
28 |
29 | // This file is automatically generated by schema2struct.
30 |
31 | import (
32 | 	"time"
33 |
34 | 	"github.com/Masterminds/squirrel"
35 | 	"github.com/Masterminds/structable"
36 | 	_ "github.com/go-sql-driver/mysql"
37 | 	_ "github.com/lib/pq"
38 | )
39 |
40 | // QueryFunc modifies a SelectBuilder prior to execution.
41 | //
42 | // The modified SelectBuilder is returned. An error is returned under any
43 | // conditions where the query should not be executed.
44 | type QueryFunc func(q squirrel.SelectBuilder) (squirrel.SelectBuilder, error)
45 |
46 | `
47 |
48 | const structTemplate = `// {{.StructName}} maps to database table {{.TableName}}.
49 | type {{.StructName}} struct {
50 | 	tableName string {{ann "tablename" .TableName}}
51 | 	structable.Recorder
52 | 	builder squirrel.StatementBuilderType
53 | 	{{range .Fields}}{{.}}
54 | 	{{end}}db squirrel.DBProxyBeginner
55 | 	flavor string
56 | }
57 |
58 | // New{{.StructName}} creates a new {{.StructName}} wired to structable.
59 | func New{{.StructName}}(db squirrel.DBProxyBeginner, flavor string) *{{.StructName}} {
60 | 	o := &{{.StructName}}{db: db, flavor: flavor}
61 | 	o.Recorder = structable.New(db, flavor).Bind("{{.TableName}}", o)
62 | 	return o
63 | }
64 |
65 | // List{{.StructName}} returns a list of {{.StructName}} objects.
66 | //
67 | // Limit is the max number of items. Offset is the offset the results will
68 | // begin with.
69 | func List{{.StructName}}(db squirrel.DBProxyBeginner, flavor string, limit, offset uint64) ([]*{{.StructName}}, error) {
70 | 	fn := func(q squirrel.SelectBuilder) (squirrel.SelectBuilder, error) {
71 | 		return q.Limit(limit).Offset(offset), nil
72 | 	}
73 | 	return Query{{.StructName}}(db, flavor, fn)
74 | }
75 |
76 | // Query{{.StructName}} builds a base query, but allows the query to be modified before execution.
77 | //
78 | // This creates a new Select, setting the columns and table name, and then calls QueryFunc with the
79 | // query. The QueryFunc can then add a Where clause, etc. Provided QueryFunc does not return an
80 | // error, Query{{.StructName}} will then execute the query, extract the results into a slice of
81 | // {{.StructName}} structs, and then return.
82 | //
83 | // The QueryFunc should not modify the list of fields returned or the table name,
84 | // as the intent is to construct a complete {{.StructName}} from each result.
85 | // More sophisticated queries should be written directly.
86 | func Query{{.StructName}}(db squirrel.DBProxyBeginner, flavor string, fn QueryFunc) ([]*{{.StructName}}, error) {
87 | 	tn := "{{.TableName}}"
88 |
89 | 	// We need a prototype structable to learn about the table structure.
90 | 	ps := New{{.StructName}}(db, flavor)
91 | 	cols := ps.Columns(true)
92 |
93 | 	q := ps.Builder().Select(cols...).From(tn)
94 | 	var err error
95 | 	if q, err = fn(q); err != nil {
96 | 		return []*{{.StructName}}{}, err
97 | 	}
98 | 	rows, err := q.Query()
99 | 	if err != nil || rows == nil {
100 | 		return []*{{.StructName}}{}, err
101 | 	}
102 | 	defer rows.Close()
103 |
104 | 	buf := []*{{.StructName}}{}
105 | 	for rows.Next() {
106 | 		o := New{{.StructName}}(db, flavor)
107 | 		dest := o.FieldReferences(true)
108 | 		if err := rows.Scan(dest...); err != nil {
109 | 			return buf, err
110 | 		}
111 | 		buf = append(buf, o)
112 | 	}
113 | 	return buf, rows.Err()
114 | }
115 |
116 | // Len{{.StructName}} returns the number of {{.StructName}} objects in the database.
117 | func Len{{.StructName}}(db squirrel.DBProxyBeginner, flavor string) (int, error) {
118 | 	fn := func(q squirrel.SelectBuilder) (squirrel.SelectBuilder, error) { return q, nil }
119 | 	return QueryLen{{.StructName}}(db, flavor, fn)
120 | }
121 |
122 | // QueryLen{{.StructName}} returns the length of a table.
123 | //
124 | // The QueryFunc can be used to modify the query. For a simple length call, you
125 | // may prefer to use Len{{.StructName}}.
126 | func QueryLen{{.StructName}}(db squirrel.DBProxyBeginner, flavor string, fn QueryFunc) (int, error) {
127 | 	tn := "{{.TableName}}"
128 | 	ps := New{{.StructName}}(db, flavor)
129 | 	q := ps.Builder().Select("COUNT(*)").From(tn)
130 | 	var err error
131 | 	if q, err = fn(q); err != nil {
132 | 		return 0, err
133 | 	}
134 | 	var count int
135 | 	err = q.Scan(&count)
136 | 	return count, err
137 | }
138 |
139 | `
140 |
141 | type structDesc struct {
142 | 	StructName string
143 | 	TableName  string
144 | 	Fields     []string
145 | }
146 |
147 | func main() {
148 | 	app := cli.NewApp()
149 | 	app.Name = "schema2struct"
150 | 	app.Version = version
151 | 	app.Usage = Usage
152 | 	app.Action = importTables
153 | 	app.Flags = []cli.Flag{
154 | 		cli.StringFlag{
155 | 			Name:  "driver,d",
156 | 			Value: "postgres",
157 | 			Usage: "The name of the SQL driver to use.",
158 | 		},
159 | 		cli.StringFlag{
160 | 			Name:  "connection,c",
161 | 			Value: "user=$USER dbname=$USER sslmode=disable",
162 | 			Usage: "The database connection string. Environment variables are expanded.",
163 | 		},
164 | 		cli.StringFlag{
165 | 			Name:  "tables,t",
166 | 			Value: "",
167 | 			Usage: "The list of tables to generate, comma-separated. If none are specified, all tables in the public schema are used.",
168 | 		},
169 | 		cli.StringFlag{
170 | 			Name:  "file,f",
171 | 			Value: "",
172 | 			Usage: "The file to write the output to. Defaults to standard output.",
173 | 		},
174 | 		cli.StringFlag{
175 | 			Name:   "package,p",
176 | 			Value:  "main",
177 | 			Usage:  "The name of the destination package.",
178 | 			EnvVar: "GOPACKAGE",
179 | 		},
180 | 	}
181 |
182 | 	app.Run(os.Args)
183 | }
184 |
185 | func driver(c *cli.Context) string {
186 | 	return c.String("driver")
187 | }
188 | func conn(c *cli.Context) string {
189 | 	return os.ExpandEnv(c.String("connection"))
190 | }
191 |
192 | // dest gets the destination output writer.
193 | func dest(c *cli.Context) io.Writer {
194 | 	if out := c.String("file"); out != "" {
195 | 		f, err := os.Create(out)
196 | 		if err != nil {
197 | 			panic(err)
198 | 		}
199 | 		return f
200 | 	}
201 | 	return os.Stdout
202 | }
203 |
204 | func tableList(c *cli.Context) []string {
205 | 	z := c.String("tables")
206 | 	if z != "" {
207 | 		return strings.Split(z, ",")
208 | 	}
209 | 	return []string{}
210 | }
211 |
212 | func cxdie(c *cli.Context, err error) {
213 | 	fmt.Fprintf(os.Stderr, "Failed to connect to %s (type %s): %s\n", conn(c), driver(c), err)
214 | 	os.Exit(1)
215 | }
216 |
217 | var funcMap = map[string]interface{}{
218 | 	"ann": func(tag, val string) string {
219 | 		return fmt.Sprintf("`%s:\"%s\"`", tag, val)
220 | 	},
221 | }
222 |
223 | func importTables(c *cli.Context) {
224 | 	ttt := template.Must(template.New("st").Funcs(funcMap).Parse(structTemplate))
225 | 	cxn, err := sql.Open(driver(c), conn(c))
226 | 	if err != nil {
227 | 		cxdie(c, err)
228 | 	}
229 | 	// Many drivers defer connections until the first statement. We test
230 | 	// that here.
231 | 	if err := cxn.Ping(); err != nil {
232 | 		cxdie(c, err)
233 | 	}
234 | 	defer cxn.Close()
235 |
236 | 	// Set up Squirrel
237 | 	stmts := squirrel.NewStmtCacher(cxn)
238 | 	bldr := squirrel.StatementBuilder.RunWith(stmts)
239 | 	if driver(c) == "postgres" {
240 | 		bldr = bldr.PlaceholderFormat(squirrel.Dollar)
241 | 	}
242 |
243 | 	// Set up destination
244 | 	out := dest(c)
245 | 	fmt.Fprintf(out, fileHeader, c.String("package"))
246 |
247 | 	tables := tableList(c)
248 |
249 | 	if len(tables) == 0 {
250 | 		tables, err = publicTables(bldr)
251 | 		if err != nil {
252 | 			fmt.Fprintf(os.Stderr, "Cannot fetch list of tables: %s\n", err)
253 | 			os.Exit(2)
254 | 		}
255 | 	}
256 |
257 | 	for _, t := range tables {
258 | 		f, err := importTable(t, bldr, driver(c))
259 | 		if err != nil {
260 | 			fmt.Fprintf(os.Stderr, "Failed to import table %s: %s\n", t, err)
261 | 			continue
262 | 		}
263 | 		if err := ttt.Execute(out, f); err != nil {
264 | 			fmt.Fprintf(os.Stderr, "Failed to render table %s: %s\n", t, err)
265 | 		}
266 | 	}
267 | }
268 | type column struct {
269 | 	Name, DataType string
270 | 	Max            int64
271 | }
272 |
273 | func publicTables(b squirrel.StatementBuilderType) ([]string, error) {
274 | 	rows, err := b.Select("table_name").From("INFORMATION_SCHEMA.TABLES").
275 | 		Where("table_schema = 'public'").Query()
276 |
277 | 	res := []string{}
278 | 	if err != nil {
279 | 		return res, err
280 | 	}
281 | 	defer rows.Close()
282 | 	for rows.Next() {
283 | 		var s string
284 | 		rows.Scan(&s)
285 | 		res = append(res, s)
286 | 	}
287 |
288 | 	return res, rows.Err()
289 | }
290 |
291 | // importTable reads a table definition and returns a description of the corresponding struct.
292 | // SELECT table_name, column_name, data_type, character_maximum_length
293 | // FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = 'goose_db_version'
294 | func importTable(tbl string, b squirrel.StatementBuilderType, driver string) (*structDesc, error) {
295 |
296 | 	pks, err := primaryKeyField(tbl, b)
297 | 	if err != nil {
298 | 		fmt.Fprintf(os.Stderr, "Error getting primary keys: %s\n", err)
299 | 	}
300 |
301 | 	q := b.Select("column_name, data_type, character_maximum_length").
302 | 		From("INFORMATION_SCHEMA.COLUMNS").
303 | 		Where("table_name = ?", tbl).
304 | 		OrderBy("ordinal_position")
305 | 	rows, err := q.Query()
306 | 	if err != nil {
307 | 		return nil, err
308 | 	}
309 | 	defer rows.Close()
310 |
311 | 	ff := []string{}
312 | 	for rows.Next() {
313 | 		c := &column{}
314 | 		var length sql.NullInt64
315 | 		if err := rows.Scan(&c.Name, &c.DataType, &length); err != nil {
316 | 			return nil, err
317 | 		}
318 | 		c.Max = length.Int64
319 | 		switch driver {
320 | 		case "mysql":
321 | 			ff = append(ff, structFieldMySQL(c, pks, tbl, b))
322 | 		case "postgres":
323 | 			ff = append(ff, structField(c, pks, tbl, b))
324 | 		}
325 | 	}
326 | 	sd := &structDesc{
327 | 		StructName: goName(tbl),
328 | 		TableName:  tbl,
329 | 		Fields:     ff,
330 | 	}
331 |
332 | 	return sd, rows.Err()
333 | }
334 |
335 | func primaryKeyField(tbl string, b squirrel.StatementBuilderType) ([]string, error) {
336 | 	q := b.Select("column_name").
337 | 		From("INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS c").
338 | 		LeftJoin("INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS t USING(constraint_name)").
339 | 		Where("t.table_name = ? AND t.constraint_type = 'PRIMARY KEY'", tbl).
340 | 		OrderBy("ordinal_position")
341 |
342 | 	rows, err := q.Query()
343 | 	if err != nil {
344 | 		return []string{}, err
345 | 	}
346 | 	defer rows.Close()
347 | 	res := []string{}
348 | 	for rows.Next() {
349 | 		var s string
350 | 		rows.Scan(&s)
351 | 		res = append(res, s)
352 | 	}
353 | 	return res, rows.Err()
354 | }
355 |
356 | func autoincrementKey(tbl, pk string, b squirrel.StatementBuilderType) bool {
357 | 	q := b.Select("COUNT(*)").
358 | 		From("INFORMATION_SCHEMA.COLUMNS").
359 | 		Where("TABLE_NAME = ? AND COLUMN_NAME = ? AND EXTRA = 'auto_increment'", tbl, pk)
360 | 	var num int
361 | 	if err := q.Scan(&num); err != nil {
362 | 		panic(err)
363 | 	}
364 | 	return num > 0
365 | }
366 |
367 | func sequentialKey(tbl, pk string, b squirrel.StatementBuilderType) bool {
368 | 	tlen := 58 // 63-byte identifier limit, minus the "_" and "_seq" affixes
369 |
370 | 	stbl := tbl
371 | 	if len(tbl) > 29 {
372 | 		stbl = tbl[0:29]
373 | 	}
374 |
375 | 	left := tlen - len(stbl)
376 | 	spk := pk
377 | 	if len(pk) > left {
378 | 		spk = pk[0:left]
379 | 	}
380 | 	seq := fmt.Sprintf("%s_%s_seq", stbl, spk)
381 |
382 | 	q := b.Select("COUNT(*)").
383 | 		From("INFORMATION_SCHEMA.SEQUENCES").
384 | 		Where("sequence_name = ?", seq)
385 |
386 | 	var num int
387 | 	if err := q.Scan(&num); err != nil {
388 | 		panic(err)
389 | 	}
390 | 	return num > 0
391 | }
392 |
393 | func structFieldMySQL(c *column, pks []string, tbl string, b squirrel.StatementBuilderType) string {
394 | 	tpl := "%s %s `stbl:\"%s\"`"
395 | 	gn := destutter(goName(c.Name), goName(tbl))
396 | 	tt := goType(c.DataType)
397 |
398 | 	tag := c.Name
399 | 	for _, p := range pks {
400 | 		if c.Name == p {
401 | 			tag += ",PRIMARY_KEY"
402 | 			if autoincrementKey(tbl, c.Name, b) {
403 | 				tag += ",AUTO_INCREMENT"
404 | 			}
405 | 		}
406 | 	}
407 |
408 | 	return fmt.Sprintf(tpl, gn, tt, tag)
409 | }
410 |
411 | func structField(c *column, pks []string, tbl string, b squirrel.StatementBuilderType) string {
412 | 	tpl := "%s %s `stbl:\"%s\"`"
413 | 	gn := destutter(goName(c.Name), goName(tbl))
414 | 	tt := goType(c.DataType)
415 |
416 | 	tag := c.Name
417 | 	for _, p := range pks {
418 | 		if c.Name == p {
419 | 			tag += ",PRIMARY_KEY"
420 | 			if sequentialKey(tbl, c.Name, b) {
421 | 				tag += ",SERIAL"
422 | 			}
423 | 		}
424 | 	}
425 |
426 | 	return fmt.Sprintf(tpl, gn, tt, tag)
427 | }
428 |
429 | // goType takes a SQL type and returns a string containing the name of a Go type.
430 | //
431 | // The goal is not to provide an exact match for every type, but to provide a
432 | // safe Go representation of a SQL type.
433 | //
434 | // For example, decimal types such as money are stored as strings so as not
435 | // to lose precision, without introducing new types.
436 | //
437 | // The default type is string.
438 | func goType(sqlType string) string {
439 | 	switch sqlType {
440 | 	case "smallint", "smallserial":
441 | 		return "int16"
442 | 	case "integer", "serial":
443 | 		return "int32"
444 | 	case "bigint", "bigserial":
445 | 		return "int"
446 | 	case "real":
447 | 		return "float32"
448 | 	case "double precision":
449 | 		return "float64"
450 | 	// Because we need to preserve base-10 precision.
451 | 	case "money":
452 | 		return "string"
453 | 	case "text", "varchar", "char", "character", "character varying", "uuid":
454 | 		return "string"
455 | 	case "bytea":
456 | 		return "[]byte"
457 | 	case "boolean":
458 | 		return "bool"
459 | 	case "timestamp", "timestamptz", "timestamp without time zone", "timestamp with time zone", "date", "time":
460 | 		return "time.Time"
461 | 	case "interval":
462 | 		return "time.Duration"
463 | 	}
464 | 	return "string"
465 | }
466 |
467 | // goName converts a SQL name to a Go name.
468 | func goName(sqlName string) string {
469 | 	// This can definitely be done better.
470 | 	goName := strings.Replace(sqlName, "_", " ", -1)
471 | 	goName = strings.Replace(goName, ".", " ", -1)
472 | 	goName = strings.Title(goName)
473 | 	goName = strings.Replace(goName, " ", "", -1)
474 |
475 | 	return goName
476 | }
477 |
478 | // destutter strips a stutter prefix, such as the table name, from a Go name.
479 | func destutter(str, prefix string) string {
480 | 	return strings.TrimPrefix(str, prefix)
481 | }
482 |
--------------------------------------------------------------------------------
/sqlite_pod_test.go:
--------------------------------------------------------------------------------
1 | // +build sqlite
2 |
3 | package structable
4 |
5 | import (
6 | 	"database/sql"
7 | 	"log"
8 | 	"testing"
9 | 	"time"
10 |
11 | 	"github.com/Masterminds/squirrel"
12 | 	_ "github.com/mattn/go-sqlite3"
13 | )
14 |
15 | type Language struct {
16 | 	Recorder
17 | 	builder squirrel.StatementBuilderType
18 |
19 | 	Id        int64     `stbl:"id,PRIMARY_KEY,AUTO_INCREMENT"`
20 | 	Name      string    `stbl:"name"`
21 | 	Version   string    `stbl:"version"`
22 | 	DtRelease time.Time `stbl:"dt_release"`
23 | }
24 |
25 | func (l *Language) equals(other *Language) bool {
26 | 	return l.Id == other.Id &&
27 | 		l.Name == other.Name &&
28 | 		l.Version == other.Version &&
29 | 		l.DtRelease.Equal(other.DtRelease)
30 | }
31 |
32 | func (l *Language) loadFromSql(Id int64, db *sql.DB) error {
33 |
34 | 	err := db.QueryRow("SELECT name, version, dt_release FROM languages WHERE id=?", Id).Scan(
35 | 		&l.Name,
36 | 		&l.Version,
37 | 		&l.DtRelease)
38 | 	if err == nil {
39 | 		l.Id = Id
40 | 	}
41 | 	return err
42 | }
43 |
44 | func TestPlainStructInsert(t *testing.T) {
45 |
46 | 	db := getLanguagesDb()
47 |
48 | 	l := &Language{
49 | 		Id:        -1,
50 | 		Name:      "Go",
51 | 		Version:   "1.3",
52 | 		DtRelease: time.Date(2014, time.June, 18, 0, 0, 0, 0, time.UTC)}
53 | 	l.Recorder = New(squirrel.NewStmtCacheProxy(db), "mysql").Bind("languages", l)
54 |
55 | 	if err := l.Insert(); err != nil {
56 | 		t.Fatalf("Failed Insert: %s", err)
57 | 	}
58 |
59 | 	lsql := new(Language)
60 | 	lsql.loadFromSql(l.Id, db)
61 | 	if !l.equals(lsql) {
62 | 		t.Fatal("Loaded and inserted objects should be equivalent")
63 | 	}
64 | }
65 |
66 | func TestPlainStructLoad(t *testing.T) {
67 |
68 | 	db := getLanguagesDb()
69 |
70 | 	lsql := &Language{}
71 | 	if res, err := db.Exec("INSERT INTO languages (name, version, dt_release) VALUES ('Scala', '2.11.7', '2015-06-23')"); err != nil {
72 | 		t.Fatalf("Sqlite Exec failed: %s", err)
73 | 	} else if lsql.Id, err = res.LastInsertId(); err != nil {
74 | 		t.Fatalf("Sqlite LastInsertId failed: %s", err)
75 | 	}
76 |
77 | 	l := &Language{Id: lsql.Id}
78 | 	l.Recorder = New(squirrel.NewStmtCacheProxy(db), "mysql").Bind("languages", l)
79 | 	if err := l.Load(); err != nil {
80 | 		t.Fatalf("Failed Load: %s", err)
81 | 	}
82 |
83 | 	lsql.Name = "Scala"
84 | 	lsql.Version = "2.11.7"
85 | 	lsql.DtRelease = time.Date(2015, time.June, 23, 0, 0, 0, 0, time.UTC)
86 | 	if !l.equals(lsql) {
87 | 		t.Fatal("Loaded and inserted objects should be equivalent")
88 | 	}
89 | }
90 |
91 | func TestPlainStructLoadWhere(t *testing.T) {
92 |
93 | 	db := getLanguagesDb()
94 |
95 | 	var lastId int64
96 | 	if res, err := db.Exec("INSERT INTO languages (name, version, dt_release) VALUES ('Scala', '2.11.7', '2015-06-23')"); err != nil {
97 | 		t.Fatalf("Sqlite Exec failed: %s", err)
98 | 	} else if lastId, err = res.LastInsertId(); err != nil {
99 | 		t.Fatalf("Sqlite LastInsertId failed: %s", err)
100 | 	}
101 |
102 | 	lsql := &Language{
103 | 		Id:        -1,
104 | 		Name:      "Scala",
105 | 		Version:   "2.11.7",
106 | 		DtRelease: time.Date(2015, time.June, 23, 0, 0, 0, 0, time.UTC)}
107 |
108 | 	l := &Language{Id: lastId}
109 | 	l.Recorder = New(squirrel.NewStmtCacheProxy(db), "mysql").Bind("languages", l)
110 | 	if err := l.LoadWhere("version = ?", "2.11.7"); err != nil {
111 | 		t.Fatalf("Failed LoadWhere: %s", err)
112 | 	}
113 |
114 | 	lsql.Id = lastId
115 | 	if !l.equals(lsql) {
116 | 		t.Fatal("Loaded and inserted objects should be equivalent")
117 | 	}
118 | }
119 |
120 | func TestPlainStructUpdate(t *testing.T) {
121 |
122 | 	db := getLanguagesDb()
123 |
124 | 	var lastId int64
125 | 	if res, err := db.Exec("INSERT INTO languages (name, version, dt_release) VALUES ('Scala', '2.11.7', '2015-06-23')"); err != nil {
126 | 		t.Fatalf("Sqlite Exec failed: %s", err)
127 | 	} else if lastId, err
= res.LastInsertId(); err != nil { 128 | t.Fatalf("Sqlite LastInsertId failed: %s", err) 129 | } 130 | 131 | l := &Language{ 132 | Id: lastId, 133 | Name: "Go", 134 | Version: "1.4", 135 | DtRelease: time.Date(2014, time.June, 18, 0, 0, 0, 0, time.UTC)} 136 | 137 | l.Recorder = New(squirrel.NewStmtCacheProxy(db), "mysql").Bind("languages", l) 138 | if err := l.Update(); err != nil { 139 | t.Fatalf("Failed Update: %s", err) 140 | } 141 | 142 | lsql := new(Language) 143 | lsql.loadFromSql(lastId, db) 144 | if !l.equals(lsql) { 145 | t.Fatal("Loaded and updated objects should be equivalent") 146 | } 147 | } 148 | 149 | func TestPlainStructDelete(t *testing.T) { 150 | 151 | db := getLanguagesDb() 152 | 153 | var lastId int64 154 | if res, err := db.Exec("INSERT INTO languages (name, version, dt_release) VALUES ('Scala', '2.11.7', '2015-06-23')"); err != nil { 155 | t.Fatalf("Sqlite Exec failed: %s", err) 156 | } else if lastId, err = res.LastInsertId(); err != nil { 157 | t.Fatalf("Sqlite LastInsertId failed: %s", err) 158 | } 159 | 160 | l := &Language{Id: lastId} 161 | l.Recorder = New(squirrel.NewStmtCacheProxy(db), "mysql").Bind("languages", l) 162 | if err := l.Delete(); err != nil { 163 | t.Fatalf("Failed Delete: %s", err) 164 | } 165 | 166 | var count int64 167 | if err := db.QueryRow("SELECT COUNT(*) from languages;").Scan(&count); err != nil { 168 | t.Fatalf("Error executing query: %s", err) 169 | } 170 | if count != 0 { 171 | t.Fatalf("Database should count no rows, instead it has got: %v", count) 172 | } 173 | } 174 | 175 | func TestPlainStructExists(t *testing.T) { 176 | 177 | db := getLanguagesDb() 178 | 179 | l := &Language{Id: 1} 180 | l.Recorder = New(squirrel.NewStmtCacheProxy(db), "mysql").Bind("languages", l) 181 | 182 | if exists, err := l.Exists(); err != nil { 183 | t.Fatalf("Failed Exists: %s", err) 184 | } else if exists { 185 | t.Fatal("Exists should return false") 186 | } 187 | 188 | var lastId int64 189 | if res, err := db.Exec("INSERT INTO 
languages (name, version, dt_release) VALUES ('Scala', '2.11.7', '2015-06-23')"); err != nil { 190 | t.Fatalf("Sqlite Exec failed: %s", err) 191 | } else if lastId, err = res.LastInsertId(); err != nil { 192 | t.Fatalf("Sqlite LastInsertId failed: %s", err) 193 | } 194 | 195 | l.Id = lastId 196 | if exists, err := l.Exists(); err != nil { 197 | t.Fatalf("Failed Exists: %s", err) 198 | } else if !exists { 199 | t.Fatal("Exists should return true") 200 | } 201 | } 202 | 203 | func TestPlainStructExistsWhere(t *testing.T) { 204 | 205 | db := getLanguagesDb() 206 | 207 | if _, err := db.Exec("INSERT INTO languages (name, version, dt_release) VALUES ('Scala', '2.11.7', '2015-06-23')"); err != nil { 208 | t.Fatalf("Sqlite Exec failed: %s", err) 209 | } 210 | 211 | l := &Language{} 212 | l.Recorder = New(squirrel.NewStmtCacheProxy(db), "mysql").Bind("languages", l) 213 | 214 | if exists, err := l.ExistsWhere("Name = ?", "Go"); err != nil { 215 | t.Fatalf("Failed ExistsWhere: %s", err) 216 | } else if exists { 217 | t.Fatal("ExistsWhere should return false") 218 | } 219 | 220 | if exists, err := l.ExistsWhere("Name = ?", "Scala"); err != nil { 221 | t.Fatalf("Failed ExistsWhere: %s", err) 222 | } else if !exists { 223 | t.Fatal("ExistsWhere should return true") 224 | } 225 | } 226 | 227 | func getLanguagesDb() *sql.DB { 228 | 229 | db, err := sql.Open("sqlite3", ":memory:") 230 | if err != nil { 231 | log.Fatalf("Couldn't Open database: %s\n", err) 232 | } 233 | 234 | stmt := ` 235 | CREATE TABLE languages ( 236 | id INTEGER PRIMARY KEY AUTOINCREMENT, 237 | name STRING, 238 | version STRING, 239 | dt_release TIMESTAMP DEFAULT('1789-07-14 12:00:00.000') 240 | ); 241 | DELETE FROM languages; 242 | ` 243 | _, err = db.Exec(stmt) 244 | if err != nil { 245 | log.Fatalf("Couldn't Exec query %q: %s\n", stmt, err) 246 | return nil 247 | } 248 | return db 249 | } 250 | -------------------------------------------------------------------------------- /sqlite_ptr_test.go:
-------------------------------------------------------------------------------- 1 | // +build sqlite 2 | 3 | package structable 4 | 5 | import ( 6 | "database/sql" 7 | "log" 8 | "testing" 9 | 10 | "github.com/Masterminds/squirrel" 11 | _ "github.com/mattn/go-sqlite3" 12 | ) 13 | 14 | type Movie struct { 15 | Recorder 16 | builder squirrel.StatementBuilderType 17 | 18 | Id int64 `stbl:"id,PRIMARY_KEY,AUTO_INCREMENT"` 19 | Title string `stbl:"title"` 20 | Genre *string `stbl:"genre"` 21 | Budget float64 `stbl:"budget"` 22 | } 23 | 24 | func (l *Movie) equals(other *Movie) bool { 25 | return l.Id == other.Id && 26 | l.Title == other.Title && 27 | l.Budget == other.Budget && 28 | CompareStringPtr(l.Genre, other.Genre) 29 | } 30 | 31 | func CompareStringPtr(s1, s2 *string) bool { 32 | switch { 33 | case s1 == nil && s2 == nil: 34 | return true 35 | case s1 != nil && s2 == nil || s1 == nil && s2 != nil: 36 | return false 37 | default: 38 | return *s1 == *s2 39 | } 40 | } 41 | 42 | func stringPtr(s string) *string { 43 | return &s 44 | } 45 | 46 | func (m *Movie) loadFromSql(Id int64, db *sql.DB) error { 47 | 48 | if m.Genre == nil { 49 | m.Genre = new(string) 50 | } 51 | err := db.QueryRow("SELECT title, genre, budget FROM movies WHERE id=?", Id).Scan( 52 | &m.Title, 53 | m.Genre, 54 | &m.Budget) 55 | if err == nil { 56 | m.Id = Id 57 | } 58 | return err 59 | } 60 | 61 | func TestStructWithPointerInsert(t *testing.T) { 62 | 63 | db := getMoviesDb() 64 | 65 | m := &Movie{ 66 | Id: -1, 67 | Title: "2001: A Space Odyssey", 68 | Genre: stringPtr("Science-Fiction"), 69 | Budget: 1500000} 70 | m.Recorder = New(squirrel.NewStmtCacheProxy(db), "mysql").Bind("movies", m) 71 | 72 | if err := m.Insert(); err != nil { 73 | t.Fatalf("Failed Insert: %s", err) 74 | } 75 | 76 | msql := new(Movie) 77 | msql.loadFromSql(m.Id, db) 78 | 79 | if *msql.Genre != "Science-Fiction" { 80 | t.Fatal("Insert should dereference allocated pointers") 81 | } 82 | 83 | m.Genre = nil 84 | if err := 
m.Insert(); err != nil { 85 | t.Fatalf("Failed Insert: %s", err) 86 | } 87 | 88 | msql.loadFromSql(m.Id, db) 89 | 90 | if *msql.Genre != "unclassifiable" { 91 | t.Fatal("Insert should ignore nil pointers") 92 | } 93 | } 94 | 95 | func TestStructWithPointerLoad(t *testing.T) { 96 | 97 | db := getMoviesDb() 98 | 99 | msql := &Movie{} 100 | if res, err := db.Exec("INSERT INTO movies (title, genre, budget) VALUES ('2001: A Space Odyssey', 'Science-Fiction', 1500000)"); err != nil { 101 | t.Fatalf("Sqlite Exec failed: %s", err) 102 | } else if msql.Id, err = res.LastInsertId(); err != nil { 103 | t.Fatalf("Sqlite LastInsertId failed: %s", err) 104 | } 105 | 106 | m := &Movie{Id: msql.Id, Genre: new(string)} 107 | m.Recorder = New(squirrel.NewStmtCacheProxy(db), "mysql").Bind("movies", m) 108 | if err := m.Load(); err != nil { 109 | t.Fatalf("Failed Load: %s", err) 110 | } 111 | 112 | if !CompareStringPtr(m.Genre, stringPtr("Science-Fiction")) { 113 | t.Fatal("Load should load pointer fields") 114 | } 115 | 116 | m.Genre = nil 117 | m.Recorder = New(squirrel.NewStmtCacheProxy(db), "mysql").Bind("movies", m) 118 | if err := m.Load(); err != nil { 119 | t.Fatalf("Failed Load: %s", err) 120 | } 121 | 122 | if !CompareStringPtr(m.Genre, stringPtr("Science-Fiction")) { 123 | t.Fatal("Load should instantiate nil pointers") 124 | } 125 | } 126 | 127 | func TestStructWithPointerLoadWhere(t *testing.T) { 128 | 129 | db := getMoviesDb() 130 | 131 | msql := &Movie{} 132 | if res, err := db.Exec("INSERT INTO movies (title, genre, budget) VALUES ('2001: A Space Odyssey', 'Science-Fiction', 1500000)"); err != nil { 133 | t.Fatalf("Sqlite Exec failed: %s", err) 134 | } else if msql.Id, err = res.LastInsertId(); err != nil { 135 | t.Fatalf("Sqlite LastInsertId failed: %s", err) 136 | } 137 | 138 | m := &Movie{} 139 | m.Recorder = New(squirrel.NewStmtCacheProxy(db), "mysql").Bind("movies", m) 140 | if err := m.LoadWhere("budget = ?", 1500000); err != nil { 141 | t.Fatalf("Failed 
LoadWhere: %s", err) 142 | } 143 | 144 | if !CompareStringPtr(m.Genre, stringPtr("Science-Fiction")) { 145 | t.Fatal("LoadWhere should load pointer fields") 146 | } 147 | 148 | m.Genre = nil 149 | m.Recorder = New(squirrel.NewStmtCacheProxy(db), "mysql").Bind("movies", m) 150 | if err := m.LoadWhere("budget = ?", 1500000); err != nil { 151 | t.Fatalf("Failed LoadWhere: %s", err) 152 | } 153 | 154 | if !CompareStringPtr(m.Genre, stringPtr("Science-Fiction")) { 155 | t.Fatal("LoadWhere should instantiate nil pointers") 156 | } 157 | } 158 | 159 | func TestStructWithPointerUpdate(t *testing.T) { 160 | 161 | db := getMoviesDb() 162 | 163 | var lastId int64 164 | if res, err := db.Exec("INSERT INTO movies (title, genre, budget) VALUES ('2001: A Space Odyssey', 'Science-Fiction', 1500000)"); err != nil { 165 | t.Fatalf("Sqlite Exec failed: %s", err) 166 | } else if lastId, err = res.LastInsertId(); err != nil { 167 | t.Fatalf("Sqlite LastInsertId failed: %s", err) 168 | } 169 | 170 | m := &Movie{ 171 | Id: lastId, 172 | Title: "The Usual Suspects", 173 | Genre: nil, 174 | Budget: 6000000} 175 | 176 | m.Recorder = New(squirrel.NewStmtCacheProxy(db), "mysql").Bind("movies", m) 177 | if err := m.Update(); err != nil { 178 | t.Fatalf("Failed Update: %s", err) 179 | } 180 | 181 | msql := new(Movie) 182 | msql.loadFromSql(lastId, db) 183 | 184 | if !CompareStringPtr(msql.Genre, stringPtr("Science-Fiction")) { 185 | t.Fatal("Update should ignore nil pointers") 186 | } 187 | 188 | m.Genre = stringPtr("Crime Thriller") 189 | m.Recorder = New(squirrel.NewStmtCacheProxy(db), "mysql").Bind("movies", m) 190 | if err := m.Update(); err != nil { 191 | t.Fatalf("Failed Update: %s", err) 192 | } 193 | 194 | msql.loadFromSql(lastId, db) 195 | if !CompareStringPtr(msql.Genre, stringPtr("Crime Thriller")) { 196 | t.Logf("msql.Genre: %v", *msql.Genre) 197 | t.Fatal("Update should store the value of non-nil pointers") 198 | } 199 | } 200 | 201 | func getMoviesDb() *sql.DB { 202 | 203 | db, err :=
sql.Open("sqlite3", ":memory:") 204 | if err != nil { 205 | log.Fatalf("Couldn't Open database: %s\n", err) 206 | } 207 | 208 | stmt := ` 209 | CREATE TABLE movies ( 210 | id INTEGER PRIMARY KEY AUTOINCREMENT, 211 | title STRING, 212 | genre STRING DEFAULT('unclassifiable'), 213 | budget REAL 214 | ); 215 | DELETE FROM movies; 216 | ` 217 | 218 | _, err = db.Exec(stmt) 219 | if err != nil { 220 | log.Fatalf("Couldn't Exec query %q: %s\n", stmt, err) 221 | return nil 222 | } 223 | return db 224 | } 225 | -------------------------------------------------------------------------------- /structable.go: -------------------------------------------------------------------------------- 1 | /* Structable is a struct-to-table mapper for databases. 2 | 3 | Structable makes a loose distinction between a Record (a description of the 4 | data to be stored) and a Recorder (the thing that does the storing). A 5 | Record is a simple annotated struct that describes the properties of an 6 | object. 7 | 8 | Structable provides the Recorder (an interface usually backed by a *DbRecorder). 9 | The Recorder is capable of doing the following: 10 | 11 | - Bind: Attach the Recorder to a Record 12 | - Load: Load a Record from a database 13 | - Insert: Create a new Record 14 | - Update: Change one or more fields on a Record 15 | - Delete: Destroy a record in the database 16 | - Exists: Determine whether a given Record exists in a database 17 | - LoadWhere: Load a record where certain conditions obtain. 18 | 19 | Structable is pragmatic in the sense that it allows ActiveRecord-like extension 20 | of the Record object to allow business logic. A Record does not *have* to be 21 | a simple data-only struct. It can have methods -- even methods that operate 22 | on the database. 23 | 24 | Importantly, Structable does not do any relation management. There is no 25 | magic to convert structs, arrays, or maps to references to other tables. 26 | (If you want that, you may prefer GORM or GORP.)
The preferred method of 27 | handling relations is to attach additional methods to the Record struct. 28 | 29 | Structable uses Squirrel for statement building, and you may also use 30 | Squirrel for working with your data. 31 | 32 | Basic Usage 33 | 34 | The following example is taken from the `example/users.go` file. 35 | 36 | 37 | package main 38 | 39 | import ( 40 | "github.com/Masterminds/squirrel" 41 | "github.com/Masterminds/structable" 42 | _ "github.com/lib/pq" 43 | 44 | "database/sql" 45 | "fmt" 46 | ) 47 | 48 | // This is our struct. Notice that we make this a structable.Recorder. 49 | type User struct { 50 | structable.Recorder 51 | builder squirrel.StatementBuilderType 52 | 53 | Id int `stbl:"id,PRIMARY_KEY,SERIAL"` 54 | Name string `stbl:"name"` 55 | Email string `stbl:"email"` 56 | } 57 | 58 | // NewUser creates a new Structable wrapper for a user. 59 | // 60 | // Of particular importance, watch how we initialize the Recorder. 61 | func NewUser(db squirrel.DBProxyBeginner, dbFlavor string) *User { 62 | u := new(User) 63 | u.Recorder = structable.New(db, dbFlavor).Bind(UserTable, u) 64 | return u 65 | } 66 | 67 | func main() { 68 | 69 | // Boilerplate DB setup. 70 | // First, we need to know the database driver. 71 | driver := "postgres" 72 | // Second, we need a database connection. 73 | con, _ := sql.Open(driver, "dbname=structable_test sslmode=disable") 74 | // Third, we wrap in a prepared statement cache for better performance. 75 | cache := squirrel.NewStmtCacheProxy(con) 76 | 77 | // Create an empty new user and give it some properties. 78 | user := NewUser(cache, driver) 79 | user.Name = "Matt" 80 | user.Email = "matt@example.com" 81 | 82 | // Insert this as a new record. 83 | if err := user.Insert(); err != nil { 84 | panic(err.Error()) 85 | } 86 | fmt.Printf("Initial insert has ID %d, name %s, and email %s\n", user.Id, user.Name, user.Email) 87 | 88 | // Now create another empty User and give it the first user's Id.
89 | again := NewUser(cache, driver) 90 | again.Id = user.Id 91 | 92 | // Load a duplicate copy of our user. This loads by the value of 93 | // again.Id 94 | again.Load() 95 | 96 | again.Email = "technosophos@example.com" 97 | if err := again.Update(); err != nil { 98 | panic(err.Error()) 99 | } 100 | fmt.Printf("Updated user has ID %d and email %s\n", again.Id, again.Email) 101 | 102 | // Delete using the built-in Deleter. (delete by Id.) 103 | if err := again.Delete(); err != nil { 104 | panic(err.Error()) 105 | } 106 | fmt.Printf("Deleted user %d\n", again.Id) 107 | } 108 | 109 | The above pattern closely binds the Recorder to the Record. Essentially, in 110 | this usage Structable works like an ActiveRecord. 111 | 112 | It is also possible to emulate a DAO-type model and use the Recorder as a data 113 | access object and the Record as the data description object. An example of this 114 | method can be found in the `example/fence.go` code. 115 | 116 | The Stbl Tag 117 | 118 | The `stbl` tag is of the form: 119 | 120 | stbl:"field_name [,PRIMARY_KEY[,AUTO_INCREMENT]]" 121 | 122 | The field name is passed verbatim to the database. So `fieldName` will go to the database as `fieldName`. 123 | Structable is not at all opinionated about how you name your tables or fields. Some databases are, though, so 124 | you may need to be careful about your own naming conventions. 125 | 126 | `PRIMARY_KEY` tells Structable that this field is (one of the pieces of) the primary key. Aliases: 'PRIMARY KEY' 127 | 128 | `AUTO_INCREMENT` tells Structable that this field is created by the database, and should never 129 | be assigned during an Insert(). Aliases: SERIAL, AUTO INCREMENT 130 | 131 | Limitations 132 | 133 | Things Structable doesn't do (by design) 134 | 135 | - Guess table or column names. You must specify these. 136 | - Handle relations between tables. 137 | - Manage the schema. 138 | - Transform complex struct fields into simple ones (that is, serialize fields). 
139 | 140 | However, Squirrel can ease many of these tasks. 141 | 142 | */ 143 | package structable 144 | 145 | import ( 146 | "fmt" 147 | "reflect" 148 | "strings" 149 | 150 | "github.com/Masterminds/squirrel" 151 | ) 152 | 153 | // 'stbl' is the main tag used for annotating Structable Records. 154 | const StructableTag = "stbl" 155 | 156 | /* Record describes a struct that can be stored. 157 | 158 | Example: 159 | 160 | type Stool struct { 161 | Id int `stbl:"id,PRIMARY_KEY,AUTO_INCREMENT"` 162 | Legs int `stbl:"number_of_legs"` 163 | Material string `stbl:"material"` 164 | Ignored string // will not be stored. 165 | } 166 | 167 | The above links the Stool record to a database table that has a primary 168 | key (with auto-incrementing values) called 'id', an int field named 169 | 'number_of_legs', and a 'material' field that is a VARCHAR or TEXT (depending 170 | on the database implementation). 171 | 172 | */ 173 | type Record interface{} 174 | 175 | // Internal representation of a field on a database table, and its 176 | // relation to a struct field. 177 | type field struct { 178 | // name = Struct field name 179 | // column = table column name 180 | name, column string 181 | // Is a primary key 182 | isKey bool 183 | // Is an auto increment 184 | isAuto bool 185 | } 186 | 187 | // A Recorder is responsible for managing the persistence of a Record. 188 | // A Recorder is bound to a struct, which it then examines for fields 189 | // that should be stored in the database. From that point on, a recorder 190 | // can manage the persistent lifecycle of the record. 191 | type Recorder interface { 192 | // Bind this Recorder to a table and to a Record. 193 | // 194 | // The table name is used verbatim. DO NOT TRUST USER-SUPPLIED VALUES. 195 | // 196 | // The struct is examined for tags, and those tags are parsed and used to determine 197 | // details about each field.
198 | Bind(string, Record) Recorder 199 | 200 | // Interface provides a way of fetching the record from the Recorder. 201 | // 202 | // A record is bound to a Recorder via Bind, and retrieved from a Recorder 203 | // via Interface(). 204 | // 205 | // This is conceptually similar to reflect.Value.Interface(). 206 | Interface() interface{} 207 | 208 | Loader 209 | Haecceity 210 | Saver 211 | Describer 212 | 213 | // This returns the column names used for the primary key. 214 | //Key() []string 215 | } 216 | 217 | type Loader interface { 218 | // Loads the entire Record using the value of the PRIMARY_KEY(s) 219 | // This will only fetch columns that are mapped on the bound Record. But you can think of it 220 | // as doing something like this: 221 | // 222 | // SELECT * FROM bound_table WHERE id=? LIMIT 1 223 | // 224 | // And then mapping the result to the currently bound Record. 225 | Load() error 226 | // Load by a WHERE-like clause. See Squirrel's Where(pred, args) 227 | LoadWhere(interface{}, ...interface{}) error 228 | } 229 | 230 | type Saver interface { 231 | // Insert inserts the bound Record into the bound table. 232 | Insert() error 233 | 234 | // Update updates all of the fields on the bound Record based on the PRIMARY_KEY fields. 235 | // 236 | // Essentially, it does something like this: 237 | // UPDATE bound_table SET every=?, field=?, but=?, keys=? WHERE primary_key=? 238 | Update() error 239 | 240 | // Deletes a Record based on its PRIMARY_KEY(s). 241 | Delete() error 242 | } 243 | 244 | // Haecceity indicates whether a thing exists. 245 | // 246 | // Actually, it is responsible for testing whether a thing exists, and is 247 | // what we think it is. 248 | type Haecceity interface { 249 | // Exists verifies that a thing exists and is of this type. 250 | // This uses the PRIMARY_KEY to verify that a record exists. 251 | Exists() (bool, error) 252 | // ExistsWhere verifies that a thing exists and is of the expected type. 
253 | // It takes a WHERE clause, and it needs to guarantee that at least one 254 | // record matches. It need not ensure that *only* one item exists. 255 | ExistsWhere(interface{}, ...interface{}) (bool, error) 256 | } 257 | 258 | // Describer is a structable object that can describe its table structure. 259 | type Describer interface { 260 | // Columns gets the columns on this table. 261 | Columns(bool) []string 262 | // FieldReferences gets references to the fields on this object. 263 | FieldReferences(bool) []interface{} 264 | // WhereIds returns a map of ID fields to (current) ID values. 265 | // 266 | // This is useful to quickly generate where clauses. 267 | WhereIds() map[string]interface{} 268 | 269 | // TableName returns the table name. 270 | TableName() string 271 | // Builder returns the builder 272 | Builder() *squirrel.StatementBuilderType 273 | // DB returns a DB-like handle. 274 | DB() squirrel.DBProxyBeginner 275 | 276 | Driver() string 277 | 278 | Init(d squirrel.DBProxyBeginner, flavor string) 279 | } 280 | 281 | // List returns a list of objects of the given kind. 282 | // 283 | // This runs a Select of the given kind, and returns the results. 284 | func List(d Recorder, limit, offset uint64) ([]Recorder, error) { 285 | fn := func(desc Describer, query squirrel.SelectBuilder) (squirrel.SelectBuilder, error) { 286 | return query.Limit(limit).Offset(offset), nil 287 | } 288 | 289 | return ListWhere(d, fn) 290 | } 291 | 292 | // WhereFunc modifies a basic select operation to add conditions. 293 | // 294 | // Technically, conditions are not limited to adding where clauses. It will receive 295 | // a select statement with the 'SELECT ... FROM tablename' portion composed already. 296 | type WhereFunc func(desc Describer, query squirrel.SelectBuilder) (squirrel.SelectBuilder, error) 297 | 298 | // ListWhere takes a Recorder and a query modifying function and executes a query.
299 | // 300 | // The WhereFunc will be given a SELECT d.Columns(true) FROM d.TableName() statement, 301 | // and may modify it. Note that while joining is supported, changing the column 302 | // list will have unpredictable side effects. It is advised that joins be done 303 | // using Squirrel instead. 304 | // 305 | // This will return a list of Recorder objects, where the underlying type 306 | // of each matches the underlying type of the passed-in 'd' Recorder. 307 | func ListWhere(d Recorder, fn WhereFunc) ([]Recorder, error) { 308 | var tn string = d.TableName() 309 | var cols []string = d.Columns(true) 310 | buf := []Recorder{} 311 | 312 | // Base query 313 | q := d.Builder().Select(cols...).From(tn) 314 | 315 | // Allow the fn to modify our query 316 | var err error 317 | q, err = fn(d, q) 318 | if err != nil { 319 | return buf, err 320 | } 321 | 322 | rows, err := q.Query() 323 | if err != nil || rows == nil { 324 | return buf, err 325 | } 326 | defer rows.Close() 327 | 328 | v := reflect.Indirect(reflect.ValueOf(d)) 329 | t := v.Type() 330 | for rows.Next() { 331 | nv := reflect.New(t) 332 | // Bind an empty base object. Basically, we fetch the object out of 333 | // the DbRecorder, and then construct an empty one. 334 | rec := reflect.New(reflect.Indirect(reflect.ValueOf(d.(*DbRecorder).record)).Type()) 335 | nv.Interface().(Recorder).Bind(d.TableName(), rec.Interface()) 336 | s := nv.Interface().(Recorder) 337 | s.Init(d.DB(), d.Driver()) 338 | dest := s.FieldReferences(true) 339 | if err := rows.Scan(dest...); err != nil { 340 | return buf, err 341 | } 342 | buf = append(buf, s) 343 | } 344 | 345 | return buf, rows.Err() 346 | } 347 | 348 | // Implements the Recorder interface, and stores data in a DB.
349 | type DbRecorder struct { 350 | builder *squirrel.StatementBuilderType 351 | db squirrel.DBProxyBeginner 352 | table string 353 | fields []*field 354 | key []*field 355 | record Record 356 | flavor string 357 | } 358 | 359 | func (d *DbRecorder) Interface() interface{} { 360 | return d.record 361 | } 362 | 363 | // New creates a new DbRecorder. 364 | // 365 | // (The squirrel.DBProxy interface defines the functions normal for a database connection 366 | // or a prepared statement cache.) 367 | func New(db squirrel.DBProxyBeginner, flavor string) *DbRecorder { 368 | d := new(DbRecorder) 369 | d.Init(db, flavor) 370 | return d 371 | } 372 | 373 | // Init initializes a DbRecorder 374 | func (d *DbRecorder) Init(db squirrel.DBProxyBeginner, flavor string) { 375 | b := squirrel.StatementBuilder.RunWith(db) 376 | if flavor == "postgres" { 377 | b = b.PlaceholderFormat(squirrel.Dollar) 378 | } 379 | 380 | d.builder = &b 381 | d.db = db 382 | d.flavor = flavor 383 | } 384 | 385 | // TableName returns the table name of this recorder. 386 | func (s *DbRecorder) TableName() string { 387 | return s.table 388 | } 389 | 390 | // DB returns the database (DBProxyBeginner) for this recorder. 391 | func (s *DbRecorder) DB() squirrel.DBProxyBeginner { 392 | return s.db 393 | } 394 | 395 | // Builder returns the statement builder for this recorder. 396 | func (s *DbRecorder) Builder() *squirrel.StatementBuilderType { 397 | return s.builder 398 | } 399 | 400 | // Driver returns the string name of the driver. 401 | func (s *DbRecorder) Driver() string { 402 | return s.flavor 403 | } 404 | 405 | // Bind binds a DbRecorder to a Record. 406 | // 407 | // This takes a given structable.Record and binds it to the recorder. That means 408 | // that the recorder will track all changes to the Record. 409 | // 410 | // The table name tells the recorder which database table to link this record 411 | // to. All storage operations will use that table. 
412 | func (s *DbRecorder) Bind(tableName string, ar Record) Recorder { 413 | 414 | // "To be is to be the value of a bound variable." - W. V. O. Quine 415 | 416 | // Get the table name 417 | s.table = tableName 418 | 419 | // Get the fields 420 | s.scanFields(ar) 421 | 422 | s.record = ar 423 | 424 | return Recorder(s) 425 | } 426 | 427 | // Key gets the string names of the fields used as primary key. 428 | func (s *DbRecorder) Key() []string { 429 | key := make([]string, len(s.key)) 430 | 431 | for i, f := range s.key { 432 | key[i] = f.column 433 | } 434 | 435 | return key 436 | } 437 | 438 | // Load selects the record from the database and loads the values into the bound Record. 439 | // 440 | // Load uses the table's PRIMARY KEY(s) as the sole criterion for matching a 441 | // record. Essentially, it is akin to `SELECT * FROM table WHERE primary_key = ?`. 442 | // 443 | // This modifies the Record in-place. Other than the primary key fields, any 444 | // other field will be overwritten by the value retrieved from the database. 445 | func (s *DbRecorder) Load() error { 446 | whereParts := s.WhereIds() 447 | dest := s.FieldReferences(false) 448 | 449 | q := s.builder.Select(s.colList(false, false)...).From(s.table).Where(whereParts) 450 | err := q.QueryRow().Scan(dest...) 451 | 452 | return err 453 | } 454 | 455 | // LoadWhere loads an object based on a WHERE clause. 456 | // 457 | // This can be used to define alternate loaders: 458 | // 459 | // func (s *MyStructable) LoadUuid(uuid string) error { 460 | // return s.LoadWhere("uuid = ?", uuid) 461 | // } 462 | // 463 | // This functions similarly to Load, but with the notable difference that 464 | // it loads the entire object (it does not skip keys used to do the lookup). 465 | func (s *DbRecorder) LoadWhere(pred interface{}, args ...interface{}) error { 466 | dest := s.FieldReferences(true) 467 | 468 | q := s.builder.Select(s.colList(true, true)...).From(s.table).Where(pred, args...)
469 | err := q.QueryRow().Scan(dest...) 470 | 471 | return err 472 | } 473 | 474 | // Exists returns `true` if and only if there is at least one record that matches the primary keys for this Record. 475 | // 476 | // If the primary key on the Record has no value, this will look for records with no value (or the default 477 | // value). 478 | func (s *DbRecorder) Exists() (bool, error) { 479 | has := false 480 | whereParts := s.WhereIds() 481 | 482 | q := s.builder.Select("COUNT(*) > 0").From(s.table).Where(whereParts) 483 | err := q.QueryRow().Scan(&has) 484 | 485 | return has, err 486 | } 487 | 488 | // ExistsWhere returns `true` if and only if there is at least one record that matches one (or multiple) conditions. 489 | // 490 | // Conditions are expressed in the form of predicates and expected values 491 | // that together build a WHERE clause. See Squirrel's Where(pred, args) 492 | func (s *DbRecorder) ExistsWhere(pred interface{}, args ...interface{}) (bool, error) { 493 | has := false 494 | 495 | q := s.builder.Select("COUNT(*) > 0").From(s.table).Where(pred, args...) 496 | err := q.QueryRow().Scan(&has) 497 | 498 | return has, err 499 | } 500 | 501 | // Delete deletes the record from the underlying table. 502 | // 503 | // The fields on the present record will remain set, but not saved in the database. 504 | func (s *DbRecorder) Delete() error { 505 | wheres := s.WhereIds() 506 | q := s.builder.Delete(s.table).Where(wheres) 507 | _, err := q.Exec() 508 | return err 509 | } 510 | 511 | // Insert puts a new record into the database. 512 | // 513 | // This operation is particularly sensitive to DB differences in cases where AUTO_INCREMENT is set 514 | // on a member of the Record. 515 | func (s *DbRecorder) Insert() error { 516 | switch s.flavor { 517 | case "postgres": 518 | return s.insertPg() 519 | default: 520 | return s.insertStd() 521 | } 522 | } 523 | 524 | // Insert and assume that LastInsertId() returns something. 
525 | func (s *DbRecorder) insertStd() error { 526 | 527 | cols, vals := s.colValLists(true, false) 528 | 529 | q := s.builder.Insert(s.table).Columns(cols...).Values(vals...) 530 | 531 | ret, err := q.Exec() 532 | if err != nil { 533 | return err 534 | } 535 | 536 | for _, f := range s.fields { 537 | if f.isAuto { 538 | ar := reflect.Indirect(reflect.ValueOf(s.record)) 539 | field := ar.FieldByName(f.name) 540 | 541 | id, err := ret.LastInsertId() 542 | if err != nil { 543 | return fmt.Errorf("Could not get last insert ID. Did you set the db flavor? %s", err) 544 | } 545 | 546 | if !field.CanSet() { 547 | return fmt.Errorf("Could not set %s to returned value", f.name) 548 | } 549 | field.SetInt(id) 550 | } 551 | } 552 | 553 | return err 554 | } 555 | 556 | // insertPg runs a postgres-specific INSERT. Unlike the default (MySQL) driver, 557 | // this actually refreshes ALL of the fields on the Record object. We do this 558 | // because it is trivially easy in Postgres. 559 | func (s *DbRecorder) insertPg() error { 560 | cols, vals := s.colValLists(true, false) 561 | dest := s.FieldReferences(true) 562 | q := s.builder.Insert(s.table).Columns(cols...).Values(vals...). 563 | Suffix("RETURNING " + strings.Join(s.colList(true, false), ",")) 564 | 565 | sql, vals, err := q.ToSql() 566 | if err != nil { 567 | return err 568 | } 569 | 570 | return s.db.QueryRow(sql, vals...).Scan(dest...) 571 | } 572 | 573 | // Update updates the values on an existing entry. 574 | // 575 | // This updates records where the Record's primary keys match the record in the 576 | // database. Essentially, it runs `UPDATE table SET names=values WHERE id=?` 577 | // 578 | // If no entry is found, update will NOT create (INSERT) a new record. 
func (s *DbRecorder) Update() error {
	whereParts := s.WhereIds()
	updates := s.updateFields()
	q := s.builder.Update(s.table).SetMap(updates).Where(whereParts)
	_, err := q.Exec()
	return err
}

// Columns returns the names of the columns on this table.
//
// If includeKeys is false, the columns that are marked as keys are omitted
// from the returned list.
func (s *DbRecorder) Columns(includeKeys bool) []string {
	return s.colList(includeKeys, false)
}

// colList gets a list of column names. If withKeys is false, columns that are
// designated as primary keys will not be returned in this list.
// If omitNil is true, a column backed by a pointer field is omitted when that
// pointer is nil in the current record.
func (s *DbRecorder) colList(withKeys bool, omitNil bool) []string {
	names := make([]string, 0, len(s.fields))

	var ar reflect.Value
	if omitNil {
		ar = reflect.Indirect(reflect.ValueOf(s.record))
	}

	for _, field := range s.fields {
		if !withKeys && field.isKey {
			continue
		}
		if omitNil {
			f := ar.FieldByName(field.name)
			if f.Kind() == reflect.Ptr && f.IsNil() {
				continue
			}
		}
		names = append(names, field.column)
	}

	return names
}

// FieldReferences returns a list of references to fields on this object.
//
// If withKeys is true, fields that compose the primary key will also be
// included. Otherwise, only non-primary key fields will be included.
//
// Nil pointer fields are allocated with a new zero value so that every
// returned reference can be passed directly to Scan.
//
// This is used for processing SQL results:
//
//	dest := s.FieldReferences(false)
//	q := s.builder.Select(s.Columns(false)...).From(s.table)
//	err := q.QueryRow().Scan(dest...)
func (s *DbRecorder) FieldReferences(withKeys bool) []interface{} {
	refs := make([]interface{}, 0, len(s.fields))

	ar := reflect.Indirect(reflect.ValueOf(s.record))
	for _, field := range s.fields {
		if !withKeys && field.isKey {
			continue
		}

		fv := ar.FieldByName(field.name)
		var ref reflect.Value
		if fv.Kind() != reflect.Ptr {
			// Take the address of the field so Scan can write into it.
			ref = fv.Addr()
		} else {
			// The field is already a pointer.
			ref = fv
			if fv.IsNil() {
				// Allocate a new element of the pointed-to type.
				fv.Set(reflect.New(fv.Type().Elem()))
			}
		}
		refs = append(refs, ref.Interface())
	}

	return refs
}

// colValLists returns two lists: the column names and their values.
// If withKeys is false, columns and values of fields designated as primary keys
// will not be included in those lists. Also, if withAutos is false, the returned
// lists will not include fields designated as auto-increment. Fields holding a
// nil pointer are always skipped.
func (s *DbRecorder) colValLists(withKeys, withAutos bool) (columns []string, values []interface{}) {
	ar := reflect.Indirect(reflect.ValueOf(s.record))

	for _, field := range s.fields {

		switch {
		case !withKeys && field.isKey:
			continue
		case !withAutos && field.isAuto:
			continue
		}

		// Get the value of the field we are going to store.
		f := ar.FieldByName(field.name)
		var v reflect.Value
		if f.Kind() == reflect.Ptr {
			if f.IsNil() {
				// nothing to store
				continue
			}
			// no indirection: the field is already a reference to its value
			v = f
		} else {
			// not a pointer: reflect.Indirect returns the field value unchanged
			v = reflect.Indirect(f)
		}

		values = append(values, v.Interface())
		columns = append(columns, field.column)
	}

	return
}

// updateFields produces fields to go into SetMap for an update.
// This will NOT update PRIMARY_KEY fields.
func (s *DbRecorder) updateFields() map[string]interface{} {
	update := map[string]interface{}{}
	cols, vals := s.colValLists(false, true)
	for i, col := range cols {
		update[col] = vals[i]
	}
	return update
}

// WhereIds returns a map of column names to current values for all columns
// marked as primary keys.
func (s *DbRecorder) WhereIds() map[string]interface{} {
	clause := make(map[string]interface{}, len(s.key))

	ar := reflect.Indirect(reflect.ValueOf(s.record))

	for _, f := range s.key {
		clause[f.column] = ar.FieldByName(f.name).Interface()
	}

	return clause
}

// scanFields reads the `stbl` tags from all of the fields on a struct and
// records each tagged field's column name and primary-key/auto-increment flags.
func (s *DbRecorder) scanFields(ar Record) {
	v := reflect.Indirect(reflect.ValueOf(ar))
	t := v.Type()
	count := t.NumField()
	keys := make([]*field, 0, 2)

	for i := 0; i < count; i++ {
		f := t.Field(i)
		// Skip fields with no tag.
		if len(f.Tag) == 0 {
			continue
		}
		sqtag := f.Tag.Get("stbl")
		if len(sqtag) == 0 {
			continue
		}

		parts := s.parseTag(f.Name, sqtag)
		field := new(field)
		field.name = f.Name
		field.column = parts[0]
		for _, part := range parts[1:] {
			part = strings.TrimSpace(part)
			switch part {
			case "PRIMARY_KEY", "PRIMARY KEY":
				field.isKey = true
				keys = append(keys, field)
			case "AUTO_INCREMENT", "SERIAL", "AUTO INCREMENT":
				field.isAuto = true
			}
		}
		s.fields = append(s.fields, field)
	}
	s.key = keys
}

// parseTag parses the contents of a stbl tag. The first comma-separated part
// is the column name; if it is empty, the Go field name is used instead.
func (s *DbRecorder) parseTag(fieldName, tag string) []string {
	// strings.Split always returns at least one element.
	parts := strings.Split(tag, ",")
	parts[0] = strings.TrimSpace(parts[0])
	if parts[0] == "" {
		parts[0] = fieldName
	}
	return parts
}

--------------------------------------------------------------------------------
/structable_test.go:
--------------------------------------------------------------------------------
package structable

import (
	"database/sql"
	"errors"
	"fmt"
	"regexp"
	"strings"
	"testing"

	"github.com/Masterminds/squirrel"
)

type Stool struct {
	Id       int     `stbl:"id,PRIMARY_KEY,AUTO_INCREMENT"`
	Id2      int     `stbl:"id_two, PRIMARY_KEY "`
	Legs     int     `stbl:"number_of_legs"`
	Material string  `stbl:"material"`
	Color    *string `stbl:"color"`
	Ignored  string  // will not be stored.
}

func newStool() *Stool {
	stool := new(Stool)

	stool.Id = 1
	stool.Id2 = 2
	stool.Legs = 3
	stool.Material = "Stainless Steel"
	stool.Ignored = "Boo"

	return stool
}

type ActRec struct {
	Id       int    `stbl:"id,SERIAL,PRIMARY_KEY"`
	Name     string `stbl:"name"`
	recorder Recorder
}

func NewActRec(db *DBStub) *ActRec {
	a := new(ActRec)

	a.recorder = New(db, "mysql").Bind("my_table", a)

	return a
}

func (a *ActRec) Exists() bool {

	ok, err := a.recorder.Exists()

	return err == nil && ok
}

func TestBind(t *testing.T) {
	store := new(DbRecorder)

	stool := newStool()
	store.Bind("test_table", stool)

	if store.table != "test_table" {
		t.Errorf("Failed to get table name.")
	}

	if len(store.fields) != 5 {
		t.Errorf("Expected 5 fields, got %d: %+v", len(store.fields), store.fields)
	}

	keyCount := 0
	for _, f := range store.fields {
		if f.isKey {
			keyCount++
		}
	}

	if keyCount != 2 {
		t.Errorf("Expected two keys.")
	}

	if len(store.Key()) != 2 {
		t.Errorf("Wrong number of keys.")
	}
}

func TestLoad(t *testing.T) {
	stool := newStool()
	db := &DBStub{}
	//db, builder := squirrelFixture()

	r := New(db, "mysql").Bind("test_table", stool)

	if err := r.Load(); err != nil {
		t.Errorf("Error running query: %s", err)
	}

	expect := "SELECT number_of_legs, material, color FROM test_table WHERE id = ? AND id_two = ?"
	if db.LastQueryRowSql != expect {
		t.Errorf("Unexpected SQL: %s", db.LastQueryRowSql)
	}

	expectargs := []interface{}{1, 2}
	got := db.LastQueryRowArgs
	for i, exp := range expectargs {
		if exp != got[i] {
			t.Errorf("Surprise! %v doesn't equal %v", exp, got[i])
		}
	}
}

func TestLoadWhere(t *testing.T) {
	stool := newStool()
	db := &DBStub{}

	r := New(db, "mysql").Bind("test_table", stool)

	if err := r.LoadWhere("number_of_legs = ?", 3); err != nil {
		t.Errorf("Error running query: %s", err)
	}

	if len(db.LastQueryRowArgs) != 1 {
		t.Errorf("Expected exactly one where arg.")
	}

	expect := "SELECT .* FROM test_table WHERE number_of_legs = ?"
	if ok, err := regexp.MatchString(expect, db.LastQueryRowSql); err != nil {
		t.Errorf("Failed to run regexp: %s", err)
	} else if !ok {
		t.Errorf("%s did not match pattern %s", db.LastQueryRowSql, expect)
	}

}

func TestList(t *testing.T) {
	stool := newStool()
	db := &DBStub{}
	//db, builder := squirrelFixture()

	r := New(db, "mysql").Bind("test_table", stool)

	if _, err := List(r, 10, 0); err != nil {
		t.Errorf("Error running query: %s", err)
	}

	expect := "SELECT number_of_legs, material, color FROM test_table LIMIT 10 OFFSET 0"
	if db.LastQuerySql != expect {
		t.Errorf("Expected SQL %q\nGot %q", expect, db.LastQuerySql)
	}
}

func TestListWhere_Error(t *testing.T) {
	stool := newStool()
	db := &DBStub{}
	r := New(db, "mysql").Bind("test_table", stool)

	fn := func(d Describer, q squirrel.SelectBuilder) (squirrel.SelectBuilder, error) {
		return q, errors.New("intentional failure")
	}

	if _, err := ListWhere(r, fn); err == nil {
		t.Error("Expected failed WhereFunc to fail query")
	}
}

func TestInsert(t *testing.T) {
	stool := newStool()
	db := new(DBStub)

	rec := New(db, "mysql").Bind("test_table", stool)

	if err := rec.Insert(); err != nil {
		t.Errorf("Failed insert: %s", err)
	}

	expect := "INSERT INTO test_table (id_two,number_of_legs,material) VALUES (?,?,?)"
	if db.LastExecSql != expect {
		t.Errorf("Expected '%s', got '%s'", expect, db.LastExecSql)
	}

	expectargs := []interface{}{stool.Id2, stool.Legs, stool.Material}
	gotargs := db.LastExecArgs

	for i := range expectargs {
		if expectargs[i] != gotargs[i] {
			t.Errorf("Expected %v, got %v", expectargs[i], gotargs[i])
		}
	}
}

func TestUpdate(t *testing.T) {
	stool := newStool()
	db := new(DBStub)

	rec := New(db, "mysql").Bind("test_table", stool)

	// with nil pointer field
	if err := rec.Update(); err != nil {
		t.Errorf("Update error: %s", err)
	}

	if !strings.Contains(db.LastExecSql, "number_of_legs = ") {
		t.Error("Expected 'number_of_legs' in query")
	}
	if !strings.Contains(db.LastExecSql, "material = ") {
		t.Error("Expected 'material' in query")
	}

	eargs := []interface{}{3, "Stainless Steel", 1, 2}
	gotargs := db.LastExecArgs
	for _, exp := range eargs {
		found := false
		for _, arg := range gotargs {
			if arg == exp {
				found = true
				break
			}
		}
		if !found {
			t.Errorf("Could not find %v in %v", exp, gotargs)
		}
	}

	// with allocated pointer
	blue := "Blue"
	stool.Color = &blue

	if err := rec.Update(); err != nil {
		t.Errorf("Update error: %s", err)
	}

	if !strings.Contains(db.LastExecSql, "number_of_legs = ") {
		t.Error("Expected 'number_of_legs' in query")
	}
	if !strings.Contains(db.LastExecSql, "material = ") {
		t.Error("Expected 'material' in query")
	}
	if !strings.Contains(db.LastExecSql, "color = ") {
		t.Error("Expected 'color' in query")
	}

	eargs = []interface{}{3, "Stainless Steel", &blue, 1, 2}
	gotargs = db.LastExecArgs
	for _, exp := range eargs {
		found := false
		for _, arg := range gotargs {
			if arg == exp {
				found = true
				break
			}
		}
		if !found {
			t.Errorf("Could not find %v in %v", exp, gotargs)
		}
	}
}

func TestDelete(t *testing.T) {
	stool := newStool()
	db := &DBStub{}
	r := New(db, "mysql").Bind("test_table", stool)

	if err := r.Delete(); err != nil {
		t.Errorf("Failed to delete: %s", err)
	}

	expect := "DELETE FROM test_table WHERE .* AND .*"
	if ok, _ := regexp.MatchString(expect, db.LastExecSql); !ok {
		t.Errorf("Unexpected query: %s", db.LastExecSql)
	}
	if got := db.LastExecArgs[0].(int); got != 1 {
		t.Errorf("Expected 1, got %d", got)
	}
}

func TestExists(t *testing.T) {
	stool := newStool()
	db := &DBStub{}
	r := New(db, "mysql").Bind("test_table", stool)

	_, err := r.Exists()
	if err != nil {
		t.Errorf("Error calling Exists: %s", err)
	}

	expect := "SELECT COUNT(*) > 0 FROM test_table WHERE id = ? AND id_two = ?"
	if db.LastQueryRowSql != expect {
		t.Errorf("Unexpected SQL: expected %q, got %q", expect, db.LastQueryRowSql)
	}
}

func TestActiveRecord(t *testing.T) {
	db := &DBStub{}
	a := NewActRec(db)
	a.Id = 999

	if a.Exists() {
		t.Errorf("Expected record to be absent.")
	}
}

func squirrelFixture() (*DBStub, squirrel.StatementBuilderType) {

	db := &DBStub{}
	//cache := squirrel.NewStmtCacher(db)
	return db, squirrel.StatementBuilder.RunWith(db)

}

// FIXTURES
type DBStub struct {
	err error

	LastPrepareSql string
	PrepareCount   int

	LastExecSql  string
	LastExecArgs []interface{}

	LastQuerySql  string
	LastQueryArgs []interface{}

	LastQueryRowSql  string
	LastQueryRowArgs []interface{}
}

var StubError = fmt.Errorf("this is a stub; this is only a stub")

func (s *DBStub) Prepare(query string) (*sql.Stmt, error) {
	s.LastPrepareSql = query
	s.PrepareCount++
	return nil, nil
}

func (s *DBStub) Exec(query string, args ...interface{}) (sql.Result, error) {
	s.LastExecSql = query
	s.LastExecArgs = args
	return &ResultStub{id: 1, affectedRows: 1}, nil
}

func (s *DBStub) Query(query string, args ...interface{}) (*sql.Rows, error) {
	s.LastQuerySql = query
	s.LastQueryArgs = args
	return nil, nil
}

func (s *DBStub) QueryRow(query string, args ...interface{}) squirrel.RowScanner {
	s.LastQueryRowSql = query
	s.LastQueryRowArgs = args
	return &squirrel.Row{RowScanner: &RowStub{}}
}

func (s *DBStub) Begin() (*sql.Tx, error) {
	return nil, nil
}

type RowStub struct {
	Scanned bool
}

func (r *RowStub) Scan(_ ...interface{}) error {
	r.Scanned = true
	return nil
}

type ResultStub struct {
	id, affectedRows int64
}

func (r *ResultStub) LastInsertId() (int64, error) {
	return r.id, nil
}
func (r *ResultStub) RowsAffected() (int64, error) {
	return r.affectedRows, nil
}

--------------------------------------------------------------------------------