├── .gitignore ├── .travis.yml ├── LICENSE ├── MAINTAINERS ├── Makefile ├── README.md ├── appveyor.yml ├── batcher.go ├── batcher_test.go ├── benchmarks ├── bench_test.go ├── kallax.go ├── models │ ├── boil_queries.go │ ├── boil_types.go │ ├── people.go │ ├── pets.go │ └── schema_migrations.go ├── models_gorm.go └── models_kallax.go ├── common_test.go ├── doc.go ├── events.go ├── events_test.go ├── generator ├── cli │ └── kallax │ │ ├── cmd.go │ │ └── cmd │ │ ├── gen.go │ │ ├── migrate.go │ │ ├── migrate_test.go │ │ ├── migrate_windows_test.go │ │ └── util.go ├── common_test.go ├── generator.go ├── generator_test.go ├── migration.go ├── migration_test.go ├── processor.go ├── processor_test.go ├── template.go ├── template_test.go ├── templates │ ├── base.tgo │ ├── model.tgo │ ├── query.tgo │ ├── resultset.tgo │ └── schema.tgo ├── types.go └── types_test.go ├── kallax.svg ├── model.go ├── model_test.go ├── operators.go ├── operators_test.go ├── query.go ├── query_test.go ├── resultset.go ├── schema.go ├── schema_test.go ├── store.go ├── store_test.go ├── tests ├── common.go ├── common_test.go ├── connection_test.go ├── events.go ├── events_test.go ├── fixtures │ └── fixtures.go ├── json.go ├── json_test.go ├── kallax.go ├── query.go ├── query_test.go ├── relationships.go ├── relationships_test.go ├── resultset.go ├── resultset_test.go ├── schema.go ├── schema_test.go ├── store.go └── store_test.go ├── timestamps.go ├── timestamps_test.go └── types ├── slices.go ├── slices_test.go ├── types.go └── types_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | *.prof 
25 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - 1.9.x 5 | - 1.10.x 6 | - tip 7 | 8 | matrix: 9 | allow_failures: 10 | - go: tip 11 | fast_finish: true 12 | 13 | addons: 14 | postgresql: "9.4" 15 | 16 | env: 17 | - DBNAME=kallax_test DBUSER=postgres DBPASS='' GOPATH=/tmp/whatever:$GOPATH 18 | 19 | services: 20 | - postgresql 21 | 22 | before_script: 23 | - psql -c 'create database kallax_test;' -U postgres 24 | 25 | install: 26 | - rm -rf $GOPATH/src/gopkg.in/src-d 27 | - mkdir -p $GOPATH/src/gopkg.in/src-d 28 | - mv $PWD $GOPATH/src/gopkg.in/src-d/go-kallax.v1 29 | - cd $GOPATH/src/gopkg.in/src-d/go-kallax.v1 30 | - go get -v -t ./... 31 | 32 | script: 33 | - make test 34 | 35 | after_success: 36 | - bash <(curl -s https://codecov.io/bash) 37 | 38 | notifications: 39 | email: 40 | on_success: change 41 | on_failure: always 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 source{d} 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/MAINTAINERS:
--------------------------------------------------------------------------------
Miguel Molina (@erizocosmico)
Roberto Santalla (@roobre)


--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
COVERAGE_REPORT := coverage.txt
COVERAGE_PROFILE := profile.out
COVERAGE_MODE := atomic

# "test" never produces a file of that name; always run its recipe.
.PHONY: test

# test runs the unit tests of every package except ./tests, fixtures and
# benchmarks, merging their coverage profiles into $(COVERAGE_REPORT). It then
# regenerates tests/kallax.go, aborts if the work tree differs from the index
# afterwards (i.e. the committed generated code is stale), and finally runs the
# integration tests under ./tests.
#
# The staleness check uses `git diff --quiet` (exit status 1 when the work tree
# differs from the index) instead of grepping the human-readable output of
# `git status`, which is localized and never matched under non-English locales.
test:
	@echo "mode: $(COVERAGE_MODE)" > $(COVERAGE_REPORT); \
	if [ -f $(COVERAGE_PROFILE) ]; then \
		tail -n +2 $(COVERAGE_PROFILE) >> $(COVERAGE_REPORT); \
		rm $(COVERAGE_PROFILE); \
	fi; \
	for dir in `go list ./... | grep -v '/tests' | grep -v '/fixtures' | grep -v '/benchmarks'`; do \
		go test -v $$dir -coverprofile=$(COVERAGE_PROFILE) -covermode=$(COVERAGE_MODE); \
		if [ $$? != 0 ]; then \
			exit 2; \
		fi; \
		if [ -f $(COVERAGE_PROFILE) ]; then \
			tail -n +2 $(COVERAGE_PROFILE) >> $(COVERAGE_REPORT); \
			rm $(COVERAGE_PROFILE); \
		fi; \
	done; \
	go install ./generator/...; \
	rm ./tests/kallax.go ; \
	go generate ./tests/...; \
	git diff --no-prefix -U1000; \
	if ! git diff --quiet; then \
		echo 'There are differences between the commited tests/kallax.go and the one generated right now'; \
		exit 2; \
	fi; \
	go test -v ./tests/...;

--------------------------------------------------------------------------------
/appveyor.yml:
--------------------------------------------------------------------------------
version: build-{build}.{branch}
platform: x64

image:
  - Visual Studio 2015

clone_folder: c:\gopath\src\gopkg.in\src-d\go-kallax.v1

shallow_clone: false

environment:
  GOPATH: c:\gopath
  PGPASSWORD: "Password12!"
  PGUSER: "postgres"
  DBUSER: "postgres"
  DBPASS: "Password12!"

services:
  - postgresql96

install:
  - set PATH=C:\Program Files\PostgreSQL\9.6\bin\;C:\MinGW\bin;%GOPATH%\bin;c:\go\bin;%PATH%
  - go version
  - go get -v -t .\...
25 | 26 | build: off 27 | 28 | test_script: 29 | - createdb testing 30 | - mingw32-make test 31 | -------------------------------------------------------------------------------- /batcher.go: -------------------------------------------------------------------------------- 1 | package kallax 2 | 3 | import ( 4 | "database/sql" 5 | "errors" 6 | "fmt" 7 | 8 | "github.com/Masterminds/squirrel" 9 | ) 10 | 11 | type batchQueryRunner struct { 12 | schema Schema 13 | cols []string 14 | q Query 15 | oneToOneRels []Relationship 16 | oneToManyRels []Relationship 17 | db squirrel.BaseRunner 18 | builder squirrel.SelectBuilder 19 | total int 20 | eof bool 21 | // records is the cache of the records in the last batch. 22 | records []Record 23 | } 24 | 25 | var errNoMoreRows = errors.New("kallax: there are no more rows in the result set") 26 | 27 | func newBatchQueryRunner(schema Schema, db squirrel.BaseRunner, q Query) *batchQueryRunner { 28 | cols, builder := q.compile() 29 | var ( 30 | oneToOneRels []Relationship 31 | oneToManyRels []Relationship 32 | ) 33 | 34 | for _, rel := range q.getRelationships() { 35 | switch rel.Type { 36 | case OneToOne: 37 | oneToOneRels = append(oneToOneRels, rel) 38 | case OneToMany: 39 | oneToManyRels = append(oneToManyRels, rel) 40 | } 41 | } 42 | 43 | return &batchQueryRunner{ 44 | schema: schema, 45 | cols: cols, 46 | q: q, 47 | oneToOneRels: oneToOneRels, 48 | oneToManyRels: oneToManyRels, 49 | db: db, 50 | builder: builder, 51 | } 52 | } 53 | 54 | func (r *batchQueryRunner) next() (Record, error) { 55 | if r.eof && len(r.records) == 0 { 56 | return nil, errNoMoreRows 57 | } 58 | 59 | if len(r.records) == 0 { 60 | var ( 61 | records []Record 62 | err error 63 | ) 64 | 65 | limit := r.q.GetLimit() 66 | if limit == 0 || limit > uint64(r.total) { 67 | records, err = r.loadNextBatch() 68 | if err != nil { 69 | return nil, err 70 | } 71 | } 72 | 73 | if len(records) == 0 { 74 | r.eof = true 75 | return nil, errNoMoreRows 76 | } 77 | 78 | batchSize 
:= r.q.GetBatchSize() 79 | if batchSize > 0 && (batchSize < limit || limit == 0) { 80 | if uint64(len(records)) < batchSize { 81 | r.eof = true 82 | } 83 | } else if limit > 0 { 84 | if uint64(len(records)) < limit { 85 | r.eof = true 86 | } 87 | } 88 | 89 | r.total += len(records) 90 | r.records = records[1:] 91 | return records[0], nil 92 | } 93 | 94 | record := r.records[0] 95 | r.records = r.records[1:] 96 | return record, nil 97 | } 98 | 99 | func (r *batchQueryRunner) loadNextBatch() ([]Record, error) { 100 | limit := r.q.GetLimit() - uint64(r.total) 101 | if r.q.GetBatchSize() < limit || limit <= 0 { 102 | limit = r.q.GetBatchSize() 103 | } 104 | 105 | rows, err := r.builder. 106 | Offset(r.q.GetOffset() + uint64(r.total)). 107 | Limit(limit). 108 | RunWith(r.db). 109 | Query() 110 | 111 | if err != nil { 112 | return nil, err 113 | } 114 | 115 | return r.processBatch(rows) 116 | } 117 | 118 | func (r *batchQueryRunner) processBatch(rows *sql.Rows) ([]Record, error) { 119 | batchRs := NewResultSet( 120 | rows, 121 | r.q.isReadOnly(), 122 | r.oneToOneRels, 123 | r.cols..., 124 | ) 125 | 126 | var records []Record 127 | for batchRs.Next() { 128 | var rec = r.schema.New() 129 | if err := batchRs.Scan(rec); err != nil { 130 | return nil, err 131 | } 132 | records = append(records, rec) 133 | } 134 | 135 | if err := batchRs.Close(); err != nil { 136 | return nil, err 137 | } 138 | 139 | var ids = make([]interface{}, len(records)) 140 | for i, r := range records { 141 | ids[i] = r.GetID().Raw() 142 | } 143 | 144 | for _, rel := range r.oneToManyRels { 145 | indexedResults, err := r.getRecordRelationships(ids, rel) 146 | if err != nil { 147 | return nil, err 148 | } 149 | 150 | for _, r := range records { 151 | err := r.SetRelationship(rel.Field, indexedResults[r.GetID().Raw()]) 152 | if err != nil { 153 | return nil, err 154 | } 155 | 156 | // If the relationship is partial, we can not ensure the results 157 | // in the field reflect the truth of the database. 
158 | // In this case, the parent is marked as non-writable. 159 | if rel.Filter != nil { 160 | r.setWritable(false) 161 | } 162 | } 163 | } 164 | 165 | return records, nil 166 | } 167 | 168 | type indexedRecords map[interface{}][]Record 169 | 170 | func (r *batchQueryRunner) getRecordRelationships(ids []interface{}, rel Relationship) (indexedRecords, error) { 171 | fk, ok := r.schema.ForeignKey(rel.Field) 172 | if !ok { 173 | return nil, fmt.Errorf("kallax: cannot find foreign key on field %s for table %s", rel.Field, r.schema.Table()) 174 | } 175 | 176 | filter := In(fk, ids...) 177 | if rel.Filter != nil { 178 | rel.Filter = And(rel.Filter, filter) 179 | } else { 180 | rel.Filter = filter 181 | } 182 | 183 | q := NewBaseQuery(rel.Schema) 184 | q.Where(rel.Filter) 185 | cols, builder := q.compile() 186 | rows, err := builder.RunWith(r.db).Query() 187 | if err != nil { 188 | return nil, err 189 | } 190 | 191 | relRs := NewResultSet(rows, false, nil, cols...) 192 | var indexedResults = make(indexedRecords) 193 | for relRs.Next() { 194 | rec, err := relRs.Get(rel.Schema) 195 | if err != nil { 196 | return nil, err 197 | } 198 | 199 | val, err := rec.Value(fk.String()) 200 | if err != nil { 201 | return nil, err 202 | } 203 | 204 | rec.setPersisted() 205 | rec.setWritable(true) 206 | id := val.(Identifier).Raw() 207 | indexedResults[id] = append(indexedResults[id], rec) 208 | } 209 | 210 | if err := relRs.Close(); err != nil { 211 | return nil, err 212 | } 213 | 214 | return indexedResults, nil 215 | } 216 | -------------------------------------------------------------------------------- /batcher_test.go: -------------------------------------------------------------------------------- 1 | package kallax 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/Masterminds/squirrel" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func TestOneToManyWithFilterNotWritable(t *testing.T) { 12 | r := require.New(t) 13 | db, err := openTestDB() 14 | 
r.NoError(err) 15 | setupTables(t, db) 16 | defer db.Close() 17 | defer teardownTables(t, db) 18 | 19 | store := NewStore(db) 20 | m := newModel("foo", "bar", 1) 21 | r.NoError(store.Insert(ModelSchema, m)) 22 | 23 | for i := 0; i < 4; i++ { 24 | r.NoError(store.Insert(RelSchema, newRel(m.GetID(), fmt.Sprint(i)))) 25 | } 26 | 27 | q := NewBaseQuery(ModelSchema) 28 | r.NoError(q.AddRelation(RelSchema, "rels", OneToMany, Eq(f("foo"), "1"))) 29 | runner := newBatchQueryRunner(ModelSchema, squirrel.NewStmtCacher(db), q) 30 | record, err := runner.next() 31 | r.NoError(err) 32 | r.False(record.IsWritable()) 33 | } 34 | 35 | func TestBatcherLimit(t *testing.T) { 36 | r := require.New(t) 37 | db, err := openTestDB() 38 | r.NoError(err) 39 | setupTables(t, db) 40 | defer db.Close() 41 | defer teardownTables(t, db) 42 | 43 | store := NewStore(db) 44 | for i := 0; i < 10; i++ { 45 | m := newModel("foo", "bar", 1) 46 | r.NoError(store.Insert(ModelSchema, m)) 47 | 48 | for i := 0; i < 4; i++ { 49 | r.NoError(store.Insert(RelSchema, newRel(m.GetID(), fmt.Sprint(i)))) 50 | } 51 | } 52 | 53 | q := NewBaseQuery(ModelSchema) 54 | q.BatchSize(2) 55 | q.Limit(5) 56 | r.NoError(q.AddRelation(RelSchema, "rels", OneToMany, Eq(f("foo"), "1"))) 57 | runner := newBatchQueryRunner(ModelSchema, store.runner, q) 58 | rs := NewBatchingResultSet(runner) 59 | 60 | var count int 61 | for rs.Next() { 62 | _, err := rs.Get(nil) 63 | r.NoError(err) 64 | count++ 65 | } 66 | r.NoError(err) 67 | r.Equal(5, count) 68 | } 69 | 70 | func TestBatcherNoExtraQueryIfLessThanLimit(t *testing.T) { 71 | r := require.New(t) 72 | db, err := openTestDB() 73 | r.NoError(err) 74 | setupTables(t, db) 75 | defer db.Close() 76 | defer teardownTables(t, db) 77 | 78 | store := NewStore(db) 79 | for i := 0; i < 4; i++ { 80 | m := newModel("foo", "bar", 1) 81 | r.NoError(store.Insert(ModelSchema, m)) 82 | 83 | for i := 0; i < 4; i++ { 84 | r.NoError(store.Insert(RelSchema, newRel(m.GetID(), fmt.Sprint(i)))) 85 | } 86 | } 87 
| 88 | q := NewBaseQuery(ModelSchema) 89 | q.Limit(6) 90 | r.NoError(q.AddRelation(RelSchema, "rels", OneToMany, Eq(f("foo"), "1"))) 91 | var queries int 92 | proxy := store.DebugWith(func(_ string, _ ...interface{}) { 93 | queries++ 94 | }).runner 95 | runner := newBatchQueryRunner(ModelSchema, proxy, q) 96 | rs := NewBatchingResultSet(runner) 97 | 98 | var count int 99 | for rs.Next() { 100 | _, err := rs.Get(nil) 101 | r.NoError(err) 102 | count++ 103 | } 104 | r.NoError(err) 105 | r.Equal(4, count) 106 | r.Equal(2, queries) 107 | } 108 | 109 | func TestBatcherNoExtraQueryIfLessThanBatchSize(t *testing.T) { 110 | r := require.New(t) 111 | db, err := openTestDB() 112 | r.NoError(err) 113 | setupTables(t, db) 114 | defer db.Close() 115 | defer teardownTables(t, db) 116 | 117 | store := NewStore(db) 118 | for i := 0; i < 4; i++ { 119 | m := newModel("foo", "bar", 1) 120 | r.NoError(store.Insert(ModelSchema, m)) 121 | 122 | for i := 0; i < 4; i++ { 123 | r.NoError(store.Insert(RelSchema, newRel(m.GetID(), fmt.Sprint(i)))) 124 | } 125 | } 126 | 127 | q := NewBaseQuery(ModelSchema) 128 | q.BatchSize(3) 129 | r.NoError(q.AddRelation(RelSchema, "rels", OneToMany, Eq(f("foo"), "1"))) 130 | var queries int 131 | proxy := store.DebugWith(func(_ string, _ ...interface{}) { 132 | queries++ 133 | }).runner 134 | runner := newBatchQueryRunner(ModelSchema, proxy, q) 135 | rs := NewBatchingResultSet(runner) 136 | 137 | var count int 138 | for rs.Next() { 139 | _, err := rs.Get(nil) 140 | r.NoError(err) 141 | count++ 142 | } 143 | r.NoError(err) 144 | r.Equal(4, count) 145 | r.Equal(4, queries) 146 | } 147 | -------------------------------------------------------------------------------- /benchmarks/bench_test.go: -------------------------------------------------------------------------------- 1 | package benchmark 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "os" 7 | "testing" 8 | 9 | "github.com/jinzhu/gorm" 10 | "github.com/vattle/sqlboiler/queries/qm" 11 | null 
"gopkg.in/nullbio/null.v6" 12 | "gopkg.in/src-d/go-kallax.v1/benchmarks/models" 13 | ) 14 | 15 | func envOrDefault(key string, def string) string { 16 | v := os.Getenv(key) 17 | if v == "" { 18 | v = def 19 | } 20 | return v 21 | } 22 | 23 | func dbURL() string { 24 | return fmt.Sprintf( 25 | "postgres://%s:%s@%s/%s?sslmode=disable", 26 | envOrDefault("DBUSER", "testing"), 27 | envOrDefault("DBPASS", "testing"), 28 | envOrDefault("DBHOST", "0.0.0.0:5432"), 29 | envOrDefault("DBNAME", "testing"), 30 | ) 31 | } 32 | 33 | func openTestDB(b *testing.B) *sql.DB { 34 | db, err := sql.Open("postgres", dbURL()) 35 | if err != nil { 36 | b.Fatalf("error opening db: %s", err) 37 | } 38 | return db 39 | } 40 | 41 | func openGormTestDB(b *testing.B) *gorm.DB { 42 | db, err := gorm.Open("postgres", dbURL()) 43 | if err != nil { 44 | b.Fatalf("error opening db: %s", err) 45 | } 46 | return db 47 | } 48 | 49 | var schemas = []string{ 50 | `CREATE TABLE IF NOT EXISTS people ( 51 | id serial primary key, 52 | name text 53 | )`, 54 | `CREATE TABLE IF NOT EXISTS pets ( 55 | id serial primary key, 56 | name text, 57 | kind text, 58 | person_id integer references people(id) 59 | )`, 60 | } 61 | 62 | var tables = []string{"pets", "people"} 63 | 64 | func setupDB(b *testing.B, db *sql.DB) *sql.DB { 65 | for _, s := range schemas { 66 | _, err := db.Exec(s) 67 | if err != nil { 68 | b.Fatalf("error creating schema: %s", err) 69 | } 70 | } 71 | 72 | return db 73 | } 74 | 75 | func teardownDB(b *testing.B, db *sql.DB) { 76 | for _, t := range tables { 77 | _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", t)) 78 | if err != nil { 79 | b.Fatalf("error dropping table: %s", err) 80 | } 81 | } 82 | 83 | if err := db.Close(); err != nil { 84 | b.Fatalf("error closing db: %s", err) 85 | } 86 | } 87 | 88 | func mkPersonWithRels() *Person { 89 | return &Person{ 90 | Name: "Dolan", 91 | Pets: []*Pet{ 92 | {Name: "Garfield", Kind: Cat}, 93 | {Name: "Oddie", Kind: Dog}, 94 | {Name: "Reptar", 
Kind: Fish}, 95 | }, 96 | } 97 | } 98 | 99 | func mkGormPersonWithRels() *GORMPerson { 100 | return &GORMPerson{ 101 | Name: "Dolan", 102 | Pets: []*GORMPet{ 103 | {Name: "Garfield", Kind: string(Cat)}, 104 | {Name: "Oddie", Kind: string(Dog)}, 105 | {Name: "Reptar", Kind: string(Fish)}, 106 | }, 107 | } 108 | } 109 | 110 | func BenchmarkKallaxInsertWithRelationships(b *testing.B) { 111 | db := setupDB(b, openTestDB(b)) 112 | defer teardownDB(b, db) 113 | 114 | store := NewPersonStore(db) 115 | for i := 0; i < b.N; i++ { 116 | if err := store.Insert(mkPersonWithRels()); err != nil { 117 | b.Fatalf("error inserting: %s", err) 118 | } 119 | } 120 | } 121 | 122 | func BenchmarkKallaxUpdateWithRelationships(b *testing.B) { 123 | db := setupDB(b, openTestDB(b)) 124 | defer teardownDB(b, db) 125 | 126 | store := NewPersonStore(db) 127 | pers := mkPersonWithRels() 128 | if err := store.Insert(pers); err != nil { 129 | b.Fatalf("error inserting: %s", err) 130 | } 131 | 132 | for i := 0; i < b.N; i++ { 133 | if _, err := store.Update(pers); err != nil { 134 | b.Fatalf("error updating: %s", err) 135 | } 136 | } 137 | } 138 | 139 | func BenchmarkSQLBoilerInsertWithRelationships(b *testing.B) { 140 | db := setupDB(b, openTestDB(b)) 141 | defer teardownDB(b, db) 142 | 143 | for i := 0; i < b.N; i++ { 144 | tx, _ := db.Begin() 145 | person := &models.Person{Name: null.StringFrom("Dolan")} 146 | if err := person.Insert(tx); err != nil { 147 | b.Fatalf("error inserting: %s", err) 148 | } 149 | 150 | err := person.SetPets(tx, true, []*models.Pet{ 151 | {Name: null.StringFrom("Garfield"), Kind: null.StringFrom("cat")}, 152 | {Name: null.StringFrom("Oddie"), Kind: null.StringFrom("dog")}, 153 | {Name: null.StringFrom("Reptar"), Kind: null.StringFrom("fish")}, 154 | }...) 
155 | if err != nil { 156 | b.Fatalf("error inserting relationships: %s", err) 157 | } 158 | 159 | tx.Commit() 160 | } 161 | } 162 | 163 | func BenchmarkRawSQLInsertWithRelationships(b *testing.B) { 164 | db := setupDB(b, openTestDB(b)) 165 | defer teardownDB(b, db) 166 | 167 | for i := 0; i < b.N; i++ { 168 | p := mkPersonWithRels() 169 | tx, err := db.Begin() 170 | 171 | err = tx.QueryRow("INSERT INTO people (name) VALUES ($1) RETURNING id", p.Name). 172 | Scan(&p.ID) 173 | if err != nil { 174 | b.Fatalf("error inserting: %s", err) 175 | } 176 | 177 | for _, pet := range p.Pets { 178 | err := tx.QueryRow( 179 | "INSERT INTO pets (name, kind, person_id) VALUES ($1, $2, $3) RETURNING id", 180 | pet.Name, string(pet.Kind), p.ID, 181 | ).Scan(&pet.ID) 182 | if err != nil { 183 | b.Fatalf("error inserting rel: %s", err) 184 | } 185 | } 186 | 187 | if err := tx.Commit(); err != nil { 188 | b.Fatalf("error committing transaction: %s", err) 189 | } 190 | } 191 | } 192 | 193 | func BenchmarkGORMInsertWithRelationships(b *testing.B) { 194 | store := openGormTestDB(b) 195 | setupDB(b, store.DB()) 196 | defer teardownDB(b, store.DB()) 197 | 198 | for i := 0; i < b.N; i++ { 199 | if db := store.Create(mkGormPersonWithRels()); db.Error != nil { 200 | b.Fatalf("error inserting: %s", db.Error) 201 | } 202 | } 203 | } 204 | 205 | func BenchmarkKallaxInsert(b *testing.B) { 206 | db := setupDB(b, openTestDB(b)) 207 | defer teardownDB(b, db) 208 | 209 | store := NewPersonStore(db) 210 | for i := 0; i < b.N; i++ { 211 | if err := store.Insert(&Person{Name: "foo"}); err != nil { 212 | b.Fatalf("error inserting: %s", err) 213 | } 214 | } 215 | } 216 | 217 | func BenchmarkKallaxUpdate(b *testing.B) { 218 | db := setupDB(b, openTestDB(b)) 219 | defer teardownDB(b, db) 220 | 221 | store := NewPersonStore(db) 222 | pers := &Person{Name: "foo"} 223 | if err := store.Insert(pers); err != nil { 224 | b.Fatalf("error inserting: %s", err) 225 | } 226 | 227 | for i := 0; i < b.N; i++ { 228 | if 
_, err := store.Update(pers); err != nil { 229 | b.Fatalf("error updating: %s", err) 230 | } 231 | } 232 | } 233 | 234 | func BenchmarkSQLBoilerInsert(b *testing.B) { 235 | db := setupDB(b, openTestDB(b)) 236 | defer teardownDB(b, db) 237 | 238 | for i := 0; i < b.N; i++ { 239 | if err := (&models.Person{Name: null.StringFrom("foo")}).Insert(db); err != nil { 240 | b.Fatalf("error inserting: %s", err) 241 | } 242 | } 243 | } 244 | 245 | func BenchmarkRawSQLInsert(b *testing.B) { 246 | db := setupDB(b, openTestDB(b)) 247 | defer teardownDB(b, db) 248 | 249 | for i := 0; i < b.N; i++ { 250 | p := &Person{Name: "foo"} 251 | 252 | err := db.QueryRow("INSERT INTO people (name) VALUES ($1) RETURNING id", p.Name). 253 | Scan(&p.ID) 254 | if err != nil { 255 | b.Fatalf("error inserting: %s", err) 256 | } 257 | } 258 | } 259 | 260 | func BenchmarkGORMInsert(b *testing.B) { 261 | store := openGormTestDB(b) 262 | setupDB(b, store.DB()) 263 | defer teardownDB(b, store.DB()) 264 | 265 | for i := 0; i < b.N; i++ { 266 | if db := store.Create(&GORMPerson{Name: "foo"}); db.Error != nil { 267 | b.Fatalf("error inserting: %s", db.Error) 268 | } 269 | } 270 | } 271 | 272 | func BenchmarkKallaxQueryRelationships(b *testing.B) { 273 | db := openTestDB(b) 274 | setupDB(b, db) 275 | defer teardownDB(b, db) 276 | 277 | store := NewPersonStore(db) 278 | for i := 0; i < 200; i++ { 279 | if err := store.Insert(mkPersonWithRels()); err != nil { 280 | b.Fatalf("error inserting: %s", err) 281 | } 282 | } 283 | 284 | b.Run("query", func(b *testing.B) { 285 | for i := 0; i < b.N; i++ { 286 | _, err := store.FindAll(NewPersonQuery().WithPets(nil).Limit(100)) 287 | if err != nil { 288 | b.Fatalf("error retrieving persons: %s", err) 289 | } 290 | } 291 | }) 292 | } 293 | 294 | func BenchmarkSQLBoilerQueryRelationships(b *testing.B) { 295 | db := openTestDB(b) 296 | setupDB(b, db) 297 | defer teardownDB(b, db) 298 | 299 | for i := 0; i < 200; i++ { 300 | person := &models.Person{Name: 
null.StringFrom("Dolan")} 301 | if err := person.Insert(db); err != nil { 302 | b.Fatalf("error inserting: %s", err) 303 | } 304 | 305 | err := person.SetPets(db, true, []*models.Pet{ 306 | {Name: null.StringFrom("Garfield"), Kind: null.StringFrom("cat")}, 307 | {Name: null.StringFrom("Oddie"), Kind: null.StringFrom("dog")}, 308 | {Name: null.StringFrom("Reptar"), Kind: null.StringFrom("fish")}, 309 | }...) 310 | if err != nil { 311 | b.Fatalf("error inserting relationships: %s", err) 312 | } 313 | } 314 | 315 | b.Run("query", func(b *testing.B) { 316 | for i := 0; i < b.N; i++ { 317 | _, err := models.People(db, qm.Load("Pets"), qm.Limit(100)).All() 318 | if err != nil { 319 | b.Fatalf("error retrieving persons: %s", err) 320 | } 321 | } 322 | }) 323 | } 324 | 325 | func BenchmarkRawSQLQueryRelationships(b *testing.B) { 326 | db := openTestDB(b) 327 | setupDB(b, db) 328 | defer teardownDB(b, db) 329 | 330 | store := NewPersonStore(db) 331 | for i := 0; i < 200; i++ { 332 | if err := store.Insert(mkPersonWithRels()); err != nil { 333 | b.Fatalf("error inserting: %s", err) 334 | } 335 | } 336 | 337 | b.Run("query", func(b *testing.B) { 338 | for i := 0; i < b.N; i++ { 339 | rows, err := db.Query("SELECT * FROM people") 340 | if err != nil { 341 | b.Fatalf("error querying: %s", err) 342 | } 343 | 344 | var people []*GORMPerson 345 | for rows.Next() { 346 | var p GORMPerson 347 | if err := rows.Scan(&p.ID, &p.Name); err != nil { 348 | b.Fatalf("error scanning: %s", err) 349 | } 350 | 351 | r, err := db.Query("SELECT * FROM pets WHERE person_id = $1", p.ID) 352 | if err != nil { 353 | b.Fatalf("error querying relationships: %s", err) 354 | } 355 | 356 | for r.Next() { 357 | var pet GORMPet 358 | if err := r.Scan(&pet.ID, &pet.Name, &pet.Kind, &pet.PersonID); err != nil { 359 | b.Fatalf("error scanning relationship: %s", err) 360 | } 361 | p.Pets = append(p.Pets, &pet) 362 | } 363 | 364 | r.Close() 365 | people = append(people, &p) 366 | } 367 | 368 | _ = people 369 | 
rows.Close() 370 | } 371 | }) 372 | } 373 | 374 | func BenchmarkGORMQueryRelationships(b *testing.B) { 375 | store := openGormTestDB(b) 376 | setupDB(b, store.DB()) 377 | defer teardownDB(b, store.DB()) 378 | 379 | for i := 0; i < 300; i++ { 380 | if db := store.Create(mkGormPersonWithRels()); db.Error != nil { 381 | b.Fatalf("error inserting: %s", db.Error) 382 | } 383 | } 384 | 385 | b.Run("query", func(b *testing.B) { 386 | for i := 0; i < b.N; i++ { 387 | var persons []*GORMPerson 388 | db := store.Preload("Pets").Limit(100).Find(&persons) 389 | if db.Error != nil { 390 | b.Fatalf("error retrieving persons: %s", db.Error) 391 | } 392 | } 393 | }) 394 | } 395 | 396 | func BenchmarkKallaxQuery(b *testing.B) { 397 | db := openTestDB(b) 398 | setupDB(b, db) 399 | defer teardownDB(b, db) 400 | 401 | store := NewPersonStore(db) 402 | for i := 0; i < 300; i++ { 403 | if err := store.Insert(&Person{Name: "foo"}); err != nil { 404 | b.Fatalf("error inserting: %s", err) 405 | } 406 | } 407 | 408 | b.Run("query", func(b *testing.B) { 409 | for i := 0; i < b.N; i++ { 410 | _, err := store.FindAll(NewPersonQuery()) 411 | if err != nil { 412 | b.Fatalf("error retrieving persons: %s", err) 413 | } 414 | } 415 | }) 416 | } 417 | 418 | func BenchmarkSQLBoilerQuery(b *testing.B) { 419 | db := openTestDB(b) 420 | setupDB(b, db) 421 | defer teardownDB(b, db) 422 | 423 | for i := 0; i < 300; i++ { 424 | if err := (&models.Person{Name: null.StringFrom("foo")}).Insert(db); err != nil { 425 | b.Fatalf("error inserting: %s", err) 426 | } 427 | } 428 | 429 | b.Run("query", func(b *testing.B) { 430 | for i := 0; i < b.N; i++ { 431 | _, err := models.People(db).All() 432 | if err != nil { 433 | b.Fatalf("error retrieving persons: %s", err) 434 | } 435 | } 436 | }) 437 | } 438 | 439 | func BenchmarkRawSQLQuery(b *testing.B) { 440 | db := openTestDB(b) 441 | setupDB(b, db) 442 | defer teardownDB(b, db) 443 | 444 | store := NewPersonStore(db) 445 | for i := 0; i < 300; i++ { 446 | if err := 
store.Insert(&Person{Name: "foo"}); err != nil { 447 | b.Fatalf("error inserting: %s", err) 448 | } 449 | } 450 | 451 | b.Run("query", func(b *testing.B) { 452 | for i := 0; i < b.N; i++ { 453 | rows, err := db.Query("SELECT * FROM people") 454 | if err != nil { 455 | b.Fatalf("error querying: %s", err) 456 | } 457 | 458 | var people []*Person 459 | for rows.Next() { 460 | var p Person 461 | err := rows.Scan(&p.ID, &p.Name) 462 | if err != nil { 463 | b.Fatalf("error scanning: %s", err) 464 | } 465 | people = append(people, &p) 466 | } 467 | 468 | _ = people 469 | rows.Close() 470 | } 471 | }) 472 | } 473 | 474 | func BenchmarkGORMQuery(b *testing.B) { 475 | store := openGormTestDB(b) 476 | setupDB(b, store.DB()) 477 | defer teardownDB(b, store.DB()) 478 | 479 | for i := 0; i < 200; i++ { 480 | if db := store.Create(&GORMPerson{Name: "foo"}); db.Error != nil { 481 | b.Fatalf("error inserting: %s", db.Error) 482 | } 483 | } 484 | 485 | b.Run("query", func(b *testing.B) { 486 | for i := 0; i < b.N; i++ { 487 | var persons []*GORMPerson 488 | db := store.Find(&persons) 489 | if db.Error != nil { 490 | b.Fatal("error retrieving persons:", db.Error) 491 | } 492 | } 493 | }) 494 | } 495 | -------------------------------------------------------------------------------- /benchmarks/models/boil_queries.go: -------------------------------------------------------------------------------- 1 | // This file is generated by SQLBoiler (https://github.com/vattle/sqlboiler) 2 | // and is meant to be re-generated in place and/or deleted at any time. 
3 | // DO NOT EDIT 4 | 5 | package models 6 | 7 | import ( 8 | "github.com/vattle/sqlboiler/boil" 9 | "github.com/vattle/sqlboiler/queries" 10 | "github.com/vattle/sqlboiler/queries/qm" 11 | ) 12 | 13 | var dialect = queries.Dialect{ 14 | LQ: 0x22, 15 | RQ: 0x22, 16 | IndexPlaceholders: true, 17 | UseTopClause: false, 18 | } 19 | 20 | // NewQueryG initializes a new Query using the passed in QueryMods 21 | func NewQueryG(mods ...qm.QueryMod) *queries.Query { 22 | return NewQuery(boil.GetDB(), mods...) 23 | } 24 | 25 | // NewQuery initializes a new Query using the passed in QueryMods 26 | func NewQuery(exec boil.Executor, mods ...qm.QueryMod) *queries.Query { 27 | q := &queries.Query{} 28 | queries.SetExecutor(q, exec) 29 | queries.SetDialect(q, &dialect) 30 | qm.Apply(q, mods...) 31 | 32 | return q 33 | } 34 | -------------------------------------------------------------------------------- /benchmarks/models/boil_types.go: -------------------------------------------------------------------------------- 1 | // This file is generated by SQLBoiler (https://github.com/vattle/sqlboiler) 2 | // and is meant to be re-generated in place and/or deleted at any time. 3 | // DO NOT EDIT 4 | 5 | package models 6 | 7 | import ( 8 | "github.com/pkg/errors" 9 | "github.com/vattle/sqlboiler/strmangle" 10 | ) 11 | 12 | // M type is for providing columns and column values to UpdateAll. 13 | type M map[string]interface{} 14 | 15 | // ErrSyncFail occurs during insert when the record could not be retrieved in 16 | // order to populate default value information. This usually happens when LastInsertId 17 | // fails or there was a primary key configuration that was not resolvable. 
18 | var ErrSyncFail = errors.New("models: failed to synchronize data after insert") 19 | 20 | type insertCache struct { 21 | query string 22 | retQuery string 23 | valueMapping []uint64 24 | retMapping []uint64 25 | } 26 | 27 | type updateCache struct { 28 | query string 29 | valueMapping []uint64 30 | } 31 | 32 | func makeCacheKey(wl, nzDefaults []string) string { 33 | buf := strmangle.GetBuffer() 34 | 35 | for _, w := range wl { 36 | buf.WriteString(w) 37 | } 38 | if len(nzDefaults) != 0 { 39 | buf.WriteByte('.') 40 | } 41 | for _, nz := range nzDefaults { 42 | buf.WriteString(nz) 43 | } 44 | 45 | str := buf.String() 46 | strmangle.PutBuffer(buf) 47 | return str 48 | } 49 | -------------------------------------------------------------------------------- /benchmarks/models_gorm.go: -------------------------------------------------------------------------------- 1 | package benchmark 2 | 3 | type GORMPerson struct { 4 | ID int64 `gorm:"primary_key"` 5 | Name string 6 | Pets []*GORMPet `gorm:"ForeignKey:PersonID"` 7 | } 8 | 9 | func (GORMPerson) TableName() string { 10 | return "people" 11 | } 12 | 13 | type GORMPet struct { 14 | ID int64 `gorm:"primary_key"` 15 | PersonID int64 16 | Name string 17 | Kind string 18 | } 19 | 20 | func (GORMPet) TableName() string { 21 | return "pets" 22 | } 23 | -------------------------------------------------------------------------------- /benchmarks/models_kallax.go: -------------------------------------------------------------------------------- 1 | package benchmark 2 | 3 | import kallax "gopkg.in/src-d/go-kallax.v1" 4 | 5 | type Person struct { 6 | kallax.Model `table:"people"` 7 | ID int64 `pk:"autoincr"` 8 | Name string 9 | Pets []*Pet 10 | } 11 | 12 | type Pet struct { 13 | kallax.Model `table:"pets"` 14 | ID int64 `pk:"autoincr"` 15 | Name string 16 | Kind PetKind 17 | } 18 | 19 | type PetKind string 20 | 21 | const ( 22 | Cat PetKind = "cat" 23 | Dog PetKind = "dog" 24 | Fish PetKind = "fish" 25 | ) 26 | 
-------------------------------------------------------------------------------- /common_test.go: -------------------------------------------------------------------------------- 1 | package kallax 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "os" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func envOrDefault(key string, def string) string { 13 | v := os.Getenv(key) 14 | if v == "" { 15 | v = def 16 | } 17 | return v 18 | } 19 | 20 | func openTestDB() (*sql.DB, error) { 21 | return sql.Open("postgres", fmt.Sprintf( 22 | "postgres://%s:%s@%s/%s?sslmode=disable", 23 | envOrDefault("DBUSER", "testing"), 24 | envOrDefault("DBPASS", "testing"), 25 | envOrDefault("DBHOST", "0.0.0.0:5432"), 26 | envOrDefault("DBNAME", "testing"), 27 | )) 28 | } 29 | 30 | func setupTables(t *testing.T, db *sql.DB) { 31 | _, err := db.Exec(`CREATE TABLE IF NOT EXISTS model ( 32 | id serial PRIMARY KEY, 33 | name varchar(255) not null, 34 | email varchar(255) not null, 35 | age int not null 36 | )`) 37 | require.NoError(t, err) 38 | 39 | _, err = db.Exec(`CREATE TABLE IF NOT EXISTS rel ( 40 | id serial PRIMARY KEY, 41 | model_id integer, 42 | foo text 43 | )`) 44 | require.NoError(t, err) 45 | } 46 | 47 | func teardownTables(t *testing.T, db *sql.DB) { 48 | _, err := db.Exec("DROP TABLE model") 49 | require.NoError(t, err) 50 | _, err = db.Exec("DROP TABLE rel") 51 | require.NoError(t, err) 52 | } 53 | 54 | type model struct { 55 | Model 56 | ID int64 `pk:"autoincr"` 57 | Name string 58 | Email string 59 | Age int 60 | Rel *rel 61 | Rels []*rel 62 | } 63 | 64 | func newModel(name, email string, age int) *model { 65 | m := &model{Model: NewModel(), Name: name, Email: email, Age: age} 66 | return m 67 | } 68 | 69 | func (m *model) Value(col string) (interface{}, error) { 70 | switch col { 71 | case "id": 72 | return m.ID, nil 73 | case "name": 74 | return m.Name, nil 75 | case "email": 76 | return m.Email, nil 77 | case "age": 78 | return m.Age, nil 79 | } 
80 | return nil, fmt.Errorf("kallax: column does not exist: %s", col) 81 | } 82 | 83 | func (m *model) ColumnAddress(col string) (interface{}, error) { 84 | switch col { 85 | case "id": 86 | return &m.ID, nil 87 | case "name": 88 | return &m.Name, nil 89 | case "email": 90 | return &m.Email, nil 91 | case "age": 92 | return &m.Age, nil 93 | } 94 | return nil, fmt.Errorf("kallax: column does not exist: %s", col) 95 | } 96 | 97 | func (m *model) NewRelationshipRecord(field string) (Record, error) { 98 | switch field { 99 | case "rel": 100 | return new(rel), nil 101 | case "rels": 102 | return new(rel), nil 103 | } 104 | return nil, fmt.Errorf("kallax: no relationship found for field %s", field) 105 | } 106 | 107 | func (m *model) SetRelationship(field string, record interface{}) error { 108 | switch field { 109 | case "rel": 110 | rel, ok := record.(*rel) 111 | if !ok { 112 | return fmt.Errorf("kallax: can't set relationship %s with a record of type %t", field, record) 113 | } 114 | m.Rel = rel 115 | return nil 116 | case "rels": 117 | rels, ok := record.([]Record) 118 | if !ok { 119 | return fmt.Errorf("kallax: can't set relationship %s with value of type %T", field, record) 120 | } 121 | m.Rels = make([]*rel, len(rels)) 122 | for i, r := range rels { 123 | rel, ok := r.(*rel) 124 | if !ok { 125 | return fmt.Errorf("kallax: can't set element of relationship %s with element of type %T", field, r) 126 | } 127 | m.Rels[i] = rel 128 | } 129 | return nil 130 | } 131 | return fmt.Errorf("kallax: no relationship found for field %s", field) 132 | } 133 | 134 | func (m *model) GetID() Identifier { 135 | return (*NumericID)(&m.ID) 136 | } 137 | 138 | type rel struct { 139 | Model 140 | ID int64 `pk:"autoincr"` 141 | Foo string 142 | } 143 | 144 | func newRel(id Identifier, foo string) *rel { 145 | rel := &rel{NewModel(), 0, foo} 146 | rel.AddVirtualColumn("model_id", id) 147 | return rel 148 | } 149 | 150 | func (r *rel) GetID() Identifier { 151 | return (*NumericID)(&r.ID) 
152 | } 153 | 154 | func (m *rel) Value(col string) (interface{}, error) { 155 | switch col { 156 | case "id": 157 | return m.ID, nil 158 | case "model_id": 159 | return m.VirtualColumn(col), nil 160 | case "foo": 161 | return m.Foo, nil 162 | } 163 | return nil, fmt.Errorf("kallax: column does not exist: %s", col) 164 | } 165 | 166 | func (m *rel) ColumnAddress(col string) (interface{}, error) { 167 | switch col { 168 | case "id": 169 | return &m.ID, nil 170 | case "model_id": 171 | return VirtualColumn(col, m, new(NumericID)), nil 172 | case "foo": 173 | return &m.Foo, nil 174 | } 175 | return nil, fmt.Errorf("kallax: column does not exist: %s", col) 176 | } 177 | 178 | func (m *rel) NewRelationshipRecord(field string) (Record, error) { 179 | return nil, fmt.Errorf("kallax: no relationship found for field %s", field) 180 | } 181 | 182 | func (m *rel) SetRelationship(field string, record interface{}) error { 183 | return fmt.Errorf("kallax: no relationship found for field %s", field) 184 | } 185 | 186 | type onlyPkModel struct { 187 | Model 188 | ID int64 `pk:"autoincr"` 189 | } 190 | 191 | func newOnlyPkModel() *onlyPkModel { 192 | m := new(onlyPkModel) 193 | return m 194 | } 195 | 196 | func (m *onlyPkModel) Value(col string) (interface{}, error) { 197 | switch col { 198 | case "id": 199 | return m.ID, nil 200 | } 201 | return nil, fmt.Errorf("kallax: column does not exist: %s", col) 202 | } 203 | 204 | func (m *onlyPkModel) ColumnAddress(col string) (interface{}, error) { 205 | switch col { 206 | case "id": 207 | return &m.ID, nil 208 | } 209 | return nil, fmt.Errorf("kallax: column does not exist: %s", col) 210 | } 211 | 212 | func (m *onlyPkModel) NewRelationshipRecord(field string) (Record, error) { 213 | return nil, fmt.Errorf("kallax: no relationship found for field %s", field) 214 | } 215 | 216 | func (m *onlyPkModel) SetRelationship(field string, record interface{}) error { 217 | return fmt.Errorf("kallax: no relationship found for field %s", field) 218 
| } 219 | 220 | func (m *onlyPkModel) GetID() Identifier { 221 | return (*NumericID)(&m.ID) 222 | } 223 | 224 | var ModelSchema = NewBaseSchema( 225 | "model", 226 | "__model", 227 | f("id"), 228 | ForeignKeys{ 229 | "rel": NewForeignKey("model_id", false), 230 | "rels": NewForeignKey("model_id", false), 231 | "rel_inv": NewForeignKey("model_id", true), 232 | }, 233 | func() Record { 234 | return new(model) 235 | }, 236 | true, 237 | f("id"), 238 | f("name"), 239 | f("email"), 240 | f("age"), 241 | ) 242 | 243 | var RelSchema = NewBaseSchema( 244 | "rel", 245 | "__rel", 246 | f("id"), 247 | ForeignKeys{}, 248 | func() Record { 249 | return new(rel) 250 | }, 251 | true, 252 | f("id"), 253 | f("model_id"), 254 | f("foo"), 255 | ) 256 | 257 | var onlyPkModelSchema = NewBaseSchema( 258 | "model", 259 | "__model", 260 | f("id"), 261 | nil, 262 | func() Record { 263 | return new(onlyPkModel) 264 | }, 265 | true, 266 | f("id"), 267 | ) 268 | 269 | func f(name string) SchemaField { 270 | return NewSchemaField(name) 271 | } 272 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // Kallax is a PostgreSQL typesafe ORM for the Go language. 2 | // 3 | // Kallax aims to provide a way of programmatically write queries and interact 4 | // with a PostgreSQL database without having to write a single line of SQL, 5 | // use strings to refer to columns and use values of any type in queries. 6 | // For that reason, the first priority of kallax is to provide type safety to 7 | // the data access layer. 8 | // Another of the goals of kallax is make sure all models are, first and 9 | // foremost, Go structs without having to use database-specific types such as, 10 | // for example, `sql.NullInt64`. 11 | // Support for arrays of all basic Go types and all JSON and arrays operators is 12 | // provided as well. 
13 | package kallax // import "gopkg.in/src-d/go-kallax.v1" 14 | -------------------------------------------------------------------------------- /events.go: -------------------------------------------------------------------------------- 1 | package kallax 2 | 3 | // BeforeInserter will do some operations before being inserted. 4 | type BeforeInserter interface { 5 | // BeforeInsert will do some operations before being inserted. If an error is 6 | // returned, it will prevent the insert from happening. 7 | BeforeInsert() error 8 | } 9 | 10 | // BeforeUpdater will do some operations before being updated. 11 | type BeforeUpdater interface { 12 | // BeforeUpdate will do some operations before being updated. If an error is 13 | // returned, it will prevent the update from happening. 14 | BeforeUpdate() error 15 | } 16 | 17 | // BeforeSaver will do some operations before being updated or inserted. 18 | type BeforeSaver interface { 19 | // BeforeSave will do some operations before being updated or inserted. If an 20 | // error is returned, it will prevent the update or insert from happening. 21 | BeforeSave() error 22 | } 23 | 24 | // BeforeDeleter will do some operations before being deleted. 25 | type BeforeDeleter interface { 26 | // BeforeDelete will do some operations before being deleted. If an error is 27 | // returned, it will prevent the delete from happening. 28 | BeforeDelete() error 29 | } 30 | 31 | // AfterInserter will do some operations after being inserted. 32 | type AfterInserter interface { 33 | // AfterInsert will do some operations after being inserted. If an error is 34 | // returned, it will cause the insert to be rolled back. 35 | AfterInsert() error 36 | } 37 | 38 | // AfterUpdater will do some operations after being updated. 39 | type AfterUpdater interface { 40 | // AfterUpdate will do some operations after being updated. If an error is 41 | // returned, it will cause the update to be rolled back. 
42 | AfterUpdate() error 43 | } 44 | 45 | // AfterSaver will do some operations after being inserted or updated. 46 | type AfterSaver interface { 47 | // AfterSave will do some operations after being inserted or updated. If an 48 | // error is returned, it will cause the insert or update to be rolled back. 49 | AfterSave() error 50 | } 51 | 52 | // AfterDeleter will do some operations after being deleted. 53 | type AfterDeleter interface { 54 | // AfterDelete will do some operations after being deleted. If an error is 55 | // returned, it will cause the delete to be rolled back. 56 | AfterDelete() error 57 | } 58 | 59 | // ApplyBeforeEvents calls all the update, insert or save before events of the 60 | // record. Save events are always called before the insert or update event. 61 | func ApplyBeforeEvents(r Record) error { 62 | if rec, ok := r.(BeforeSaver); ok { 63 | if err := rec.BeforeSave(); err != nil { 64 | return err 65 | } 66 | } 67 | 68 | if rec, ok := r.(BeforeInserter); ok && !r.IsPersisted() { 69 | if err := rec.BeforeInsert(); err != nil { 70 | return err 71 | } 72 | } 73 | 74 | if rec, ok := r.(BeforeUpdater); ok && r.IsPersisted() { 75 | if err := rec.BeforeUpdate(); err != nil { 76 | return err 77 | } 78 | } 79 | 80 | return nil 81 | } 82 | 83 | // ApplyAfterEvents calls all the update, insert or save after events of the 84 | // record. Save events are always called after the insert or update event. 
85 | func ApplyAfterEvents(r Record, wasPersisted bool) error { 86 | if rec, ok := r.(AfterInserter); ok && !wasPersisted { 87 | if err := rec.AfterInsert(); err != nil { 88 | return err 89 | } 90 | } 91 | 92 | if rec, ok := r.(AfterUpdater); ok && wasPersisted { 93 | if err := rec.AfterUpdate(); err != nil { 94 | return err 95 | } 96 | } 97 | 98 | if rec, ok := r.(AfterSaver); ok { 99 | if err := rec.AfterSave(); err != nil { 100 | return err 101 | } 102 | } 103 | 104 | return nil 105 | } 106 | -------------------------------------------------------------------------------- /events_test.go: -------------------------------------------------------------------------------- 1 | package kallax 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/require" 8 | ) 9 | 10 | type ( 11 | evented struct { 12 | events map[string]int 13 | } 14 | 15 | before struct { 16 | model 17 | evented 18 | errorBeforeSave bool 19 | errorBeforeInsert bool 20 | errorBeforeUpdate bool 21 | } 22 | 23 | after struct { 24 | model 25 | evented 26 | errorAfterSave bool 27 | errorAfterInsert bool 28 | errorAfterUpdate bool 29 | } 30 | ) 31 | 32 | func (e *evented) setup() { 33 | if e.events == nil { 34 | e.events = make(map[string]int) 35 | } 36 | } 37 | 38 | func (b *before) BeforeInsert() error { 39 | b.setup() 40 | b.events["BeforeInsert"]++ 41 | if b.errorBeforeInsert { 42 | return errors.New("foo") 43 | } 44 | return nil 45 | } 46 | 47 | func (b *before) BeforeUpdate() error { 48 | b.setup() 49 | b.events["BeforeUpdate"]++ 50 | if b.errorBeforeUpdate { 51 | return errors.New("foo") 52 | } 53 | return nil 54 | } 55 | 56 | func (b *before) BeforeSave() error { 57 | b.setup() 58 | b.events["BeforeSave"]++ 59 | if b.errorBeforeSave { 60 | return errors.New("foo") 61 | } 62 | return nil 63 | } 64 | 65 | func (b *after) AfterInsert() error { 66 | b.setup() 67 | b.events["AfterInsert"]++ 68 | if b.errorAfterInsert { 69 | return errors.New("foo") 70 | } 71 | return nil 72 | 
} 73 | 74 | func (b *after) AfterUpdate() error { 75 | b.setup() 76 | b.events["AfterUpdate"]++ 77 | if b.errorAfterUpdate { 78 | return errors.New("foo") 79 | } 80 | return nil 81 | } 82 | 83 | func (b *after) AfterSave() error { 84 | b.setup() 85 | b.events["AfterSave"]++ 86 | if b.errorAfterSave { 87 | return errors.New("foo") 88 | } 89 | return nil 90 | } 91 | 92 | func TestApplyBeforeEvents(t *testing.T) { 93 | r := require.New(t) 94 | 95 | var before before 96 | r.Nil(ApplyBeforeEvents(&before)) 97 | before.setPersisted() 98 | r.Nil(ApplyBeforeEvents(&before)) 99 | 100 | r.Equal(1, before.events["BeforeInsert"]) 101 | r.Equal(1, before.events["BeforeUpdate"]) 102 | r.Equal(2, before.events["BeforeSave"]) 103 | 104 | before.errorBeforeUpdate = true 105 | r.NotNil(ApplyBeforeEvents(&before)) 106 | 107 | before.errorBeforeInsert = true 108 | before.errorBeforeUpdate = false 109 | before.persisted = false 110 | r.NotNil(ApplyBeforeEvents(&before)) 111 | 112 | before.errorBeforeSave = true 113 | before.errorBeforeInsert = false 114 | r.NotNil(ApplyBeforeEvents(&before)) 115 | } 116 | 117 | func TestApplyAfterEvents(t *testing.T) { 118 | r := require.New(t) 119 | 120 | var after after 121 | r.Nil(ApplyAfterEvents(&after, false)) 122 | r.Nil(ApplyAfterEvents(&after, true)) 123 | 124 | r.Equal(1, after.events["AfterInsert"]) 125 | r.Equal(1, after.events["AfterUpdate"]) 126 | r.Equal(2, after.events["AfterSave"]) 127 | 128 | after.errorAfterUpdate = true 129 | r.NotNil(ApplyAfterEvents(&after, true)) 130 | 131 | after.errorAfterInsert = true 132 | after.errorAfterUpdate = false 133 | r.NotNil(ApplyAfterEvents(&after, false)) 134 | 135 | after.errorAfterSave = true 136 | after.errorAfterInsert = false 137 | r.NotNil(ApplyAfterEvents(&after, false)) 138 | } 139 | -------------------------------------------------------------------------------- /generator/cli/kallax/cmd.go: -------------------------------------------------------------------------------- 1 | package main 
2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | 7 | "gopkg.in/src-d/go-kallax.v1/generator/cli/kallax/cmd" 8 | 9 | "gopkg.in/urfave/cli.v1" 10 | ) 11 | 12 | const version = "1.3.5" 13 | 14 | func main() { 15 | if err := newApp().Run(os.Args); err != nil { 16 | fmt.Fprintln(os.Stderr, err) 17 | os.Exit(1) 18 | } 19 | } 20 | 21 | func newApp() *cli.App { 22 | app := cli.NewApp() 23 | app.Name = "kallax" 24 | app.Version = version 25 | app.Usage = "generate kallax models" 26 | app.Flags = cmd.Generate.Flags 27 | app.Action = cmd.Generate.Action 28 | app.Commands = cli.Commands{ 29 | cmd.Generate, 30 | cmd.Migrate, 31 | } 32 | 33 | return app 34 | } 35 | -------------------------------------------------------------------------------- /generator/cli/kallax/cmd/gen.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "fmt" 5 | "path/filepath" 6 | 7 | "gopkg.in/src-d/go-kallax.v1/generator" 8 | cli "gopkg.in/urfave/cli.v1" 9 | "os" 10 | ) 11 | 12 | var Generate = cli.Command{ 13 | Name: "gen", 14 | Usage: "Generate kallax models", 15 | Action: generateAction, 16 | Flags: []cli.Flag{ 17 | cli.StringFlag{ 18 | Name: "input", 19 | Value: ".", 20 | Usage: "Input package directory", 21 | }, 22 | cli.StringFlag{ 23 | Name: "output", 24 | Value: "kallax.go", 25 | Usage: "Output file name", 26 | }, 27 | cli.StringSliceFlag{ 28 | Name: "exclude, e", 29 | Usage: "List of excluded files from the package when generating the code for your models. Use this to exclude files in your package that uses the generated code. 
You can use this flag as many times as you want.", 30 | }, 31 | }, 32 | } 33 | 34 | func generateAction(c *cli.Context) error { 35 | input := c.String("input") 36 | output := c.String("output") 37 | excluded := c.StringSlice("exclude") 38 | 39 | ok, err := isDirectory(input) 40 | if err != nil { 41 | return fmt.Errorf("kallax: can't check input directory: %s", err) 42 | } 43 | 44 | if !ok { 45 | return fmt.Errorf("kallax: Input path should be a directory %s", input) 46 | } 47 | 48 | var foundPrevious bool 49 | if _, err = os.Stat(output); err == nil { 50 | foundPrevious = true 51 | fmt.Fprintf(os.Stderr, "NOTE: Previous generated file `%s` found, renaming to `%s`\n", output, output+".old") 52 | err = os.Rename(output, output+".old") 53 | } 54 | 55 | p := generator.NewProcessor(input, excluded) 56 | pkg, err := p.Do() 57 | if err != nil { 58 | return err 59 | } 60 | 61 | gen := generator.NewGenerator(filepath.Join(input, output)) 62 | err = gen.Generate(pkg) 63 | if err != nil { 64 | return err 65 | } 66 | 67 | if foundPrevious { 68 | fmt.Fprintf(os.Stderr, "NOTE: Generation succeded, removing `%s`\n", output+".old") 69 | os.Remove(output + ".old") 70 | } 71 | 72 | return nil 73 | } 74 | -------------------------------------------------------------------------------- /generator/cli/kallax/cmd/migrate.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "fmt" 5 | "path/filepath" 6 | 7 | "github.com/golang-migrate/migrate" 8 | _ "github.com/golang-migrate/migrate/database/postgres" 9 | _ "github.com/golang-migrate/migrate/source/file" 10 | 11 | "gopkg.in/src-d/go-kallax.v1/generator" 12 | cli "gopkg.in/urfave/cli.v1" 13 | ) 14 | 15 | var Migrate = cli.Command{ 16 | Name: "migrate", 17 | Usage: "Generate migrations for current kallax models", 18 | Action: migrateAction, 19 | Flags: []cli.Flag{ 20 | cli.StringFlag{ 21 | Name: "out, o", 22 | Usage: "Output directory of migrations", 23 | }, 24 | 
cli.StringFlag{ 25 | Name: "name, n", 26 | Usage: "Descriptive name for the migration", 27 | Value: "migration", 28 | }, 29 | cli.StringSliceFlag{ 30 | Name: "input, i", 31 | Usage: "List of directories to scan models from. You can use this flag as many times as you want.", 32 | }, 33 | }, 34 | Subcommands: cli.Commands{ 35 | Up, 36 | Down, 37 | }, 38 | } 39 | 40 | var migrationFlags = []cli.Flag{ 41 | cli.StringFlag{ 42 | Name: "dir, d", 43 | Value: "./migrations", 44 | Usage: "Directory where your migrations are stored", 45 | }, 46 | cli.StringFlag{ 47 | Name: "dsn", 48 | Usage: "PostgreSQL data source name. Example: `user:pass@localhost:5432/database?sslmode=enable`", 49 | }, 50 | cli.UintFlag{ 51 | Name: "steps, n", 52 | Usage: "Number of migrations to run", 53 | }, 54 | cli.UintFlag{ 55 | Name: "version, v", 56 | Usage: "Migrate to a specific version. If `steps` and this flag are given, this will be used.", 57 | }, 58 | } 59 | 60 | var Up = cli.Command{ 61 | Name: "up", 62 | Usage: "Executes the migrations from the current version until the specified version.", 63 | Action: runMigrationAction(upAction), 64 | Flags: append(migrationFlags, cli.BoolFlag{ 65 | Name: "all", 66 | Usage: "If this flag is used, the database will be migrated all the way up.", 67 | }), 68 | } 69 | 70 | var Down = cli.Command{ 71 | Name: "down", 72 | Usage: "Downgrades the database a certain number of migrations or until a certain version.", 73 | Action: runMigrationAction(downAction), 74 | Flags: migrationFlags, 75 | } 76 | 77 | func upAction(m *migrate.Migrate, steps, version uint, all bool) error { 78 | if all { 79 | if err := m.Up(); err != nil { 80 | return fmt.Errorf("kallax: unable to upgrade the database all the way up: %s", err) 81 | } 82 | } else if version > 0 { 83 | if err := m.Migrate(version); err != nil { 84 | return fmt.Errorf("kallax: unable to upgrade up to version %d: %s", version, err) 85 | } 86 | } else if steps > 0 { 87 | if err := m.Steps(int(steps)); err != nil { 
88 | return fmt.Errorf("kallax: unable to execute %d migration(s) up: %s", steps, err) 89 | } 90 | } else { 91 | return fmt.Errorf("WARN: No `version` or `steps` provided") 92 | } 93 | reportMigrationSuccess(m) 94 | return nil 95 | } 96 | 97 | func downAction(m *migrate.Migrate, steps, version uint, all bool) error { 98 | if version > 0 { 99 | if err := m.Migrate(version); err != nil { 100 | return fmt.Errorf("kallax: unable to rollback to version %d: %s", version, err) 101 | } 102 | } else if steps > 0 { 103 | if err := m.Steps(-int(steps)); err != nil { 104 | return fmt.Errorf("kallax: unable to execute %d migration(s) down: %s", steps, err) 105 | } 106 | } else { 107 | return fmt.Errorf("kallax: no `version` or `steps` provided. You need to specify one of them.") 108 | } 109 | reportMigrationSuccess(m) 110 | return nil 111 | } 112 | 113 | func reportMigrationSuccess(m *migrate.Migrate) { 114 | fmt.Println("Success! the migration has been run.") 115 | 116 | if v, _, err := m.Version(); err != nil { 117 | fmt.Printf("Unable to check the latest version of the database: %s.\n", err) 118 | } else { 119 | fmt.Printf("Database is now at version %d.\n", v) 120 | } 121 | } 122 | 123 | type runMigrationFunc func(m *migrate.Migrate, steps, version uint, all bool) error 124 | 125 | func runMigrationAction(fn runMigrationFunc) cli.ActionFunc { 126 | return func(c *cli.Context) error { 127 | var ( 128 | dir = c.String("dir") 129 | dsn = c.String("dsn") 130 | steps = c.Uint("steps") 131 | version = c.Uint("version") 132 | all = c.Bool("all") 133 | ) 134 | 135 | ok, err := isDirectory(dir) 136 | if err != nil { 137 | return fmt.Errorf("kallax: cannot check if `dir` is a directory: %s", err) 138 | } 139 | 140 | if !ok { 141 | return fmt.Errorf("kallax: argument `dir` must be a valid directory") 142 | } 143 | 144 | dir, err = filepath.Abs(dir) 145 | if err != nil { 146 | return fmt.Errorf("kallax: cannot get absolute path of `dir`: %s", err) 147 | } 148 | 149 | m, err := 
migrate.New(pathToFileURL(dir), fmt.Sprintf("postgres://%s", dsn)) 150 | if err != nil { 151 | return fmt.Errorf("kallax: unable to open a connection with the database: %s", err) 152 | } 153 | 154 | return fn(m, steps, version, all) 155 | } 156 | } 157 | 158 | func pathToFileURL(path string) string { 159 | if !filepath.IsAbs(path) { 160 | var err error 161 | path, err = filepath.Abs(path) 162 | if err != nil { 163 | return "" 164 | } 165 | } 166 | return fmt.Sprintf("file://%s", filepath.ToSlash(path)) 167 | } 168 | 169 | func migrateAction(c *cli.Context) error { 170 | dirs := c.StringSlice("input") 171 | dir := c.String("out") 172 | name := c.String("name") 173 | 174 | var pkgs []*generator.Package 175 | for _, dir := range dirs { 176 | ok, err := isDirectory(dir) 177 | if err != nil { 178 | return fmt.Errorf("kallax: cannot check directory in `input`: %s", err) 179 | } 180 | 181 | if !ok { 182 | return fmt.Errorf("kallax: `input` must be a valid directory") 183 | } 184 | 185 | p := generator.NewProcessor(dir, nil) 186 | p.Silent() 187 | pkg, err := p.Do() 188 | if err != nil { 189 | return err 190 | } 191 | 192 | pkgs = append(pkgs, pkg) 193 | } 194 | 195 | ok, err := isDirectory(dir) 196 | if err != nil { 197 | return fmt.Errorf("kallax: cannot check directory in `out`: %s", err) 198 | } 199 | 200 | if !ok { 201 | return fmt.Errorf("kallax: `out` must be a valid directory") 202 | } 203 | 204 | g := generator.NewMigrationGenerator(name, dir) 205 | migration, err := g.Build(pkgs...) 
206 | if err != nil { 207 | return err 208 | } 209 | 210 | return g.Generate(migration) 211 | } 212 | -------------------------------------------------------------------------------- /generator/cli/kallax/cmd/migrate_test.go: -------------------------------------------------------------------------------- 1 | // +build !windows 2 | 3 | package cmd 4 | 5 | import ( 6 | "os" 7 | "path/filepath" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func TestPathToFileURL(t *testing.T) { 14 | wd, err := os.Getwd() 15 | require.NoError(t, err) 16 | 17 | cases := []struct { 18 | input string 19 | expected string 20 | }{ 21 | {"/foo/bar/baz", "file:///foo/bar/baz"}, 22 | {"foo/bar", "file://" + filepath.Join(wd, "foo/bar")}, 23 | } 24 | 25 | for _, tt := range cases { 26 | require.Equal(t, tt.expected, pathToFileURL(tt.input), tt.input) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /generator/cli/kallax/cmd/migrate_windows_test.go: -------------------------------------------------------------------------------- 1 | // +build windows 2 | 3 | package cmd 4 | 5 | import ( 6 | "os" 7 | "path/filepath" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func TestPathToFileURL(t *testing.T) { 14 | wd, err := os.Getwd() 15 | require.NoError(t, err) 16 | 17 | cases := []struct { 18 | input string 19 | expected string 20 | }{ 21 | {"c:\\foo\\bar\\baz", "file://c:/foo/bar/baz"}, 22 | {"foo\\bar", "file://" + filepath.ToSlash(filepath.Join(wd, "foo", "bar"))}, 23 | } 24 | 25 | for _, tt := range cases { 26 | require.Equal(t, tt.expected, pathToFileURL(tt.input), tt.input) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /generator/cli/kallax/cmd/util.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import "os" 4 | 5 | func isDirectory(name string) (bool, error) { 6 | 
info, err := os.Stat(name) 7 | if err != nil { 8 | return false, err 9 | } 10 | 11 | return info.IsDir(), nil 12 | } 13 | -------------------------------------------------------------------------------- /generator/common_test.go: -------------------------------------------------------------------------------- 1 | package generator 2 | 3 | import ( 4 | "go/ast" 5 | "go/parser" 6 | "go/token" 7 | "go/types" 8 | "reflect" 9 | 10 | parseutil "gopkg.in/src-d/go-parse-utils.v1" 11 | ) 12 | 13 | func mkField(name, typ, tag string, fields ...*Field) *Field { 14 | f := NewField(name, typ, reflect.StructTag(tag)) 15 | f.SetFields(fields) 16 | return f 17 | } 18 | 19 | func withKind(f *Field, kind FieldKind) *Field { 20 | f.Kind = kind 21 | return f 22 | } 23 | 24 | func withPtr(f *Field) *Field { 25 | f.IsPtr = true 26 | return f 27 | } 28 | 29 | func withAlias(f *Field) *Field { 30 | f.IsAlias = true 31 | return f 32 | } 33 | 34 | func withJSON(f *Field) *Field { 35 | f.IsJSON = true 36 | return f 37 | } 38 | 39 | func withParent(f *Field, parent *Field) *Field { 40 | f.Parent = parent 41 | return f 42 | } 43 | 44 | func withNode(f *Field, name string, typ types.Type) *Field { 45 | f.Node = types.NewVar(token.NoPos, nil, name, typ) 46 | return f 47 | } 48 | 49 | func inline(f *Field) *Field { 50 | f.Tag = reflect.StructTag(`kallax:",inline"`) 51 | return f 52 | } 53 | 54 | func processorFixture(source string) (*Processor, error) { 55 | fset := &token.FileSet{} 56 | astFile, err := parser.ParseFile(fset, "fixture.go", source, 0) 57 | if err != nil { 58 | return nil, err 59 | } 60 | 61 | cfg := &types.Config{ 62 | Importer: parseutil.NewImporter(), 63 | } 64 | p, err := cfg.Check("foo", fset, []*ast.File{astFile}, nil) 65 | if err != nil { 66 | return nil, err 67 | } 68 | 69 | prc := NewProcessor("fixture", []string{"foo.go"}) 70 | prc.Package = p 71 | return prc, nil 72 | } 73 | 74 | func processFixture(source string) (*Package, error) { 75 | prc, err := 
processorFixture(source) 76 | if err != nil { 77 | return nil, err 78 | } 79 | 80 | prc.Silent() 81 | return prc.processPackage() 82 | } 83 | -------------------------------------------------------------------------------- /generator/generator.go: -------------------------------------------------------------------------------- 1 | // Package generator implements the processor of source code and generator of 2 | // kallax models based on Go source code. 3 | package generator // import "gopkg.in/src-d/go-kallax.v1/generator" 4 | 5 | import ( 6 | "bytes" 7 | "encoding" 8 | "encoding/json" 9 | "fmt" 10 | "io/ioutil" 11 | "os" 12 | "path/filepath" 13 | "runtime/debug" 14 | "strings" 15 | "time" 16 | 17 | "github.com/fatih/color" 18 | ) 19 | 20 | // Generator is in charge of generating files for packages. 21 | type Generator struct { 22 | filename string 23 | } 24 | 25 | // NewGenerator creates a new generator that can save on the given filename. 26 | func NewGenerator(filename string) *Generator { 27 | return &Generator{filename} 28 | } 29 | 30 | // Generate writes the file with the contents of the given package. 
31 | func (g *Generator) Generate(pkg *Package) error { 32 | return g.writeFile(pkg) 33 | } 34 | 35 | func (g *Generator) writeFile(pkg *Package) (err error) { 36 | file, err := os.Create(g.filename) 37 | if err != nil { 38 | return err 39 | } 40 | 41 | defer func() { 42 | if r := recover(); r != nil { 43 | fmt.Printf("kallax: PANIC during '%s' generation:\n%s\n\n", g.filename, r) 44 | if err == nil { 45 | err = fmt.Errorf(string(debug.Stack())) 46 | } 47 | } 48 | 49 | file.Close() 50 | if err != nil { 51 | if os.Remove(g.filename) == nil { 52 | fmt.Println("kallax: No file generated due to an occurred error:") 53 | } else { 54 | fmt.Printf("kallax: The autogenerated file '%s' could not be completed nor deleted due to an occurred error:\n", g.filename) 55 | } 56 | } 57 | }() 58 | 59 | return Base.Execute(file, pkg) 60 | } 61 | 62 | // Timestamper is a function that returns the current time. 63 | type Timestamper func() time.Time 64 | 65 | // MigrationGenerator is a generator of migrations. 66 | type MigrationGenerator struct { 67 | name string 68 | dir string 69 | now Timestamper 70 | } 71 | 72 | type migrationFileType string 73 | 74 | const ( 75 | migrationUp = migrationFileType("up.sql") 76 | migrationDown = migrationFileType("down.sql") 77 | migrationLock = migrationFileType("lock.json") 78 | ) 79 | 80 | // NewMigrationGenerator returns a new migration generator with the given 81 | // migrations directory. 82 | func NewMigrationGenerator(name, dir string) *MigrationGenerator { 83 | return &MigrationGenerator{slugify(name), dir, time.Now} 84 | } 85 | 86 | // Build creates a new migration from a set of scanned packages. 87 | func (g *MigrationGenerator) Build(pkgs ...*Package) (*Migration, error) { 88 | old, err := g.LoadLock() 89 | if err != nil { 90 | return nil, err 91 | } 92 | 93 | new, err := SchemaFromPackages(pkgs...) 
94 | if err != nil { 95 | return nil, err 96 | } 97 | 98 | return NewMigration(old, new) 99 | } 100 | 101 | // Generate will generate the given migration. 102 | func (g *MigrationGenerator) Generate(migration *Migration) error { 103 | g.printMigrationInfo(migration) 104 | if len(migration.Up) == 0 { 105 | return nil 106 | } 107 | return g.writeMigration(migration) 108 | } 109 | 110 | func (g *MigrationGenerator) printMigrationInfo(migration *Migration) { 111 | if len(migration.Up) == 0 { 112 | fmt.Println("There are no changes since last migration. Nothing will be generated.") 113 | return 114 | } 115 | 116 | fmt.Println("There are changes since last migration.\n\nThese are the proposed changes:") 117 | for _, change := range migration.Up { 118 | c := color.FgGreen 119 | switch change.(type) { 120 | case *DropColumn, *DropTable: 121 | c = color.FgRed 122 | case *ManualChange: 123 | c = color.FgYellow 124 | } 125 | color := color.New(c, color.Bold) 126 | color.Printf(" => ") 127 | fmt.Println(change.String()) 128 | } 129 | } 130 | 131 | // LoadLock loads the lock file. 
132 | func (g *MigrationGenerator) LoadLock() (*DBSchema, error) { 133 | bytes, err := ioutil.ReadFile(filepath.Join(g.dir, string(migrationLock))) 134 | if os.IsNotExist(err) { 135 | return new(DBSchema), nil 136 | } else if err != nil { 137 | return nil, fmt.Errorf("error opening lock file: %s", err) 138 | } 139 | 140 | var schema DBSchema 141 | if err := json.Unmarshal(bytes, &schema); err != nil { 142 | return nil, fmt.Errorf("error unmarshaling lock schema: %s", err) 143 | } 144 | 145 | return &schema, nil 146 | } 147 | 148 | func (g *MigrationGenerator) writeMigration(migration *Migration) error { 149 | t := g.now() 150 | files := []struct { 151 | file string 152 | content encoding.TextMarshaler 153 | }{ 154 | {filepath.Join(g.dir, string(migrationLock)), migration.Lock}, 155 | {g.migrationFile(migrationDown, t), migration.Down}, 156 | {g.migrationFile(migrationUp, t), migration.Up}, 157 | } 158 | 159 | for _, f := range files { 160 | if err := g.createFile(f.file, f.content); err != nil { 161 | return err 162 | } 163 | } 164 | 165 | return nil 166 | } 167 | 168 | func (g *MigrationGenerator) migrationFile(typ migrationFileType, t time.Time) string { 169 | return filepath.Join(g.dir, fmt.Sprintf("%d_%s.%s", t.Unix(), g.name, typ)) 170 | } 171 | 172 | func (g *MigrationGenerator) createFile(filename string, marshaler encoding.TextMarshaler) error { 173 | data, err := marshaler.MarshalText() 174 | if err != nil { 175 | return err 176 | } 177 | 178 | f, err := os.OpenFile(filename, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0755) 179 | if err != nil { 180 | return fmt.Errorf("error opening file: %s: %s", filename, err) 181 | } 182 | 183 | defer f.Close() 184 | if _, err := f.Write(data); err != nil { 185 | return fmt.Errorf("error writing file: %s: %s", filename, err) 186 | } 187 | 188 | return nil 189 | } 190 | 191 | func slugify(str string) string { 192 | var buf bytes.Buffer 193 | for _, r := range strings.ToLower(str) { 194 | if ('a' <= r && r <= 'z') || ('0' <= r 
&& r <= '9') { 195 | buf.WriteRune(r) 196 | } else if r == ' ' || r == '_' || r == '-' { 197 | buf.WriteRune('_') 198 | } 199 | } 200 | return buf.String() 201 | } 202 | -------------------------------------------------------------------------------- /generator/generator_test.go: -------------------------------------------------------------------------------- 1 | package generator 2 | 3 | import ( 4 | "io/ioutil" 5 | "os" 6 | "path/filepath" 7 | "testing" 8 | "time" 9 | 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func TestMigrationGeneratorLoadLock(t *testing.T) { 14 | dir, err := ioutil.TempDir("", "kallax-migration-generator") 15 | require.NoError(t, err) 16 | defer os.RemoveAll(dir) 17 | 18 | g := NewMigrationGenerator("migration", dir) 19 | schema, err := g.LoadLock() 20 | require.NoError(t, err) 21 | require.NotNil(t, schema) 22 | require.Len(t, schema.Tables, 0) 23 | 24 | content, err := mkSchema(mkTable("foo")).MarshalText() 25 | require.NoError(t, err) 26 | 27 | err = ioutil.WriteFile(filepath.Join(dir, string(migrationLock)), content, 0755) 28 | require.NoError(t, err) 29 | 30 | schema, err = g.LoadLock() 31 | require.NoError(t, err) 32 | require.NotNil(t, schema) 33 | require.Len(t, schema.Tables, 1) 34 | } 35 | 36 | func TestMigrationGeneratorBuild(t *testing.T) { 37 | dir, err := ioutil.TempDir("", "kallax-migration-generator") 38 | require.NoError(t, err) 39 | defer os.RemoveAll(dir) 40 | 41 | g := NewMigrationGenerator("migration", dir) 42 | content, err := mkSchema(mkTable("foo")).MarshalText() 43 | require.NoError(t, err) 44 | 45 | err = ioutil.WriteFile(filepath.Join(dir, string(migrationLock)), content, 0755) 46 | require.NoError(t, err) 47 | 48 | migration, err := g.Build() 49 | require.NoError(t, err) 50 | require.NotNil(t, migration) 51 | } 52 | 53 | func TestMigrationGeneratorGenerate(t *testing.T) { 54 | old := mkSchema(table1) 55 | new := mkSchema(table1, table2) 56 | migration, err := NewMigration(old, new) 57 | 
require.NoError(t, err) 58 | 59 | dir, err := ioutil.TempDir("", "kallax-migration-generator") 60 | require.NoError(t, err) 61 | defer os.RemoveAll(dir) 62 | 63 | g := NewMigrationGenerator("migration", dir) 64 | g.now = func() time.Time { 65 | var t time.Time 66 | return t 67 | } 68 | 69 | require.NoError(t, g.Generate(migration)) 70 | 71 | content, err := ioutil.ReadFile(g.migrationFile(migrationUp, g.now())) 72 | require.NoError(t, err) 73 | require.Equal(t, "BEGIN;\n\n"+expectedTable2+"\n\nCOMMIT;\n", string(content)) 74 | 75 | content, err = ioutil.ReadFile(g.migrationFile(migrationDown, g.now())) 76 | require.NoError(t, err) 77 | require.Equal(t, "BEGIN;\n\nDROP TABLE table2;\n\nCOMMIT;\n", string(content)) 78 | 79 | expected, err := migration.Lock.MarshalText() 80 | require.NoError(t, err) 81 | 82 | content, err = ioutil.ReadFile(filepath.Join(dir, string(migrationLock))) 83 | require.NoError(t, err) 84 | require.Equal(t, string(expected), string(content)) 85 | } 86 | 87 | func TestSlugify(t *testing.T) { 88 | cases := []struct { 89 | input string 90 | expected string 91 | }{ 92 | {"the fancy slug", "the_fancy_slug"}, 93 | {"ThE-FaNcYnEss", "the_fancyness"}, 94 | {"this is: a migration", "this_is_a_migration"}, 95 | {"add caché", "add_cach"}, 96 | } 97 | 98 | for _, c := range cases { 99 | require.Equal(t, c.expected, slugify(c.input)) 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /generator/processor.go: -------------------------------------------------------------------------------- 1 | package generator 2 | 3 | import ( 4 | "fmt" 5 | "go/ast" 6 | "go/build" 7 | "go/parser" 8 | "go/token" 9 | "go/types" 10 | "path/filepath" 11 | "reflect" 12 | "strings" 13 | 14 | parseutil "gopkg.in/src-d/go-parse-utils.v1" 15 | ) 16 | 17 | const ( 18 | // BaseModel is the type name of the kallax base model. 19 | BaseModel = "gopkg.in/src-d/go-kallax.v1.Model" 20 | //URL is the type name of the net/url.URL. 
21 | URL = "url.URL" 22 | ) 23 | 24 | // Processor is in charge of processing the package in a patch and 25 | // scan models from it. 26 | type Processor struct { 27 | // Path of the package. 28 | Path string 29 | // Ignore is the list of files to ignore when scanning. 30 | Ignore map[string]struct{} 31 | // Package is the scanned package. 32 | Package *types.Package 33 | silent bool 34 | } 35 | 36 | // NewProcessor creates a new Processor for the given path and ignored files. 37 | func NewProcessor(path string, ignore []string) *Processor { 38 | i := make(map[string]struct{}) 39 | for _, file := range ignore { 40 | i[file] = struct{}{} 41 | } 42 | 43 | return &Processor{ 44 | Path: path, 45 | Ignore: i, 46 | } 47 | } 48 | 49 | // Silent makes the processor not spit any output to stdout. 50 | func (p *Processor) Silent() { 51 | p.silent = true 52 | } 53 | 54 | func (p *Processor) write(msg string, args ...interface{}) { 55 | if !p.silent { 56 | fmt.Println(fmt.Sprintf(msg, args...)) 57 | } 58 | } 59 | 60 | // Do performs all the processing and returns the scanned package. 61 | func (p *Processor) Do() (*Package, error) { 62 | files, err := p.getSourceFiles() 63 | if err != nil { 64 | return nil, err 65 | } 66 | 67 | p.Package, err = p.parseSourceFiles(files) 68 | if err != nil { 69 | return nil, err 70 | } 71 | 72 | return p.processPackage() 73 | } 74 | 75 | func (p *Processor) getSourceFiles() ([]string, error) { 76 | pkg, err := build.Default.ImportDir(p.Path, 0) 77 | if err != nil { 78 | return nil, fmt.Errorf("kallax: cannot process directory %s: %s", p.Path, err) 79 | } 80 | 81 | var files []string 82 | files = append(files, pkg.GoFiles...) 83 | files = append(files, pkg.CgoFiles...) 
84 | 85 | if len(files) == 0 { 86 | return nil, fmt.Errorf("kallax: %s: no buildable Go files", p.Path) 87 | } 88 | 89 | return joinDirectory(p.Path, p.removeIgnoredFiles(files)), nil 90 | } 91 | 92 | func (p *Processor) removeIgnoredFiles(filenames []string) []string { 93 | var output []string 94 | for _, filename := range filenames { 95 | if _, ok := p.Ignore[filename]; ok { 96 | continue 97 | } 98 | 99 | output = append(output, filename) 100 | } 101 | 102 | return output 103 | } 104 | 105 | func (p *Processor) parseSourceFiles(filenames []string) (*types.Package, error) { 106 | var files []*ast.File 107 | fs := token.NewFileSet() 108 | for _, filename := range filenames { 109 | file, err := parser.ParseFile(fs, filename, nil, 0) 110 | if err != nil { 111 | return nil, fmt.Errorf("kallax: parsing package: %s: %s", filename, err) 112 | } 113 | 114 | files = append(files, file) 115 | } 116 | 117 | config := types.Config{ 118 | FakeImportC: true, 119 | Error: func(error) {}, 120 | Importer: parseutil.NewImporter(), 121 | } 122 | 123 | return config.Check(p.Path, fs, files, new(types.Info)) 124 | } 125 | 126 | func (p *Processor) processPackage() (*Package, error) { 127 | pkg := NewPackage(p.Package) 128 | var ctors []*types.Func 129 | 130 | p.write("Package: %s", pkg.Name) 131 | 132 | s := p.Package.Scope() 133 | var models []*Model 134 | for _, name := range s.Names() { 135 | obj := s.Lookup(name) 136 | switch t := obj.Type().(type) { 137 | case *types.Signature: 138 | if strings.HasPrefix(name, "new") { 139 | ctors = append(ctors, obj.(*types.Func)) 140 | } 141 | case *types.Named: 142 | if str, ok := t.Underlying().(*types.Struct); ok { 143 | if m, err := p.processModel(name, str, t); err != nil { 144 | return nil, err 145 | } else if m != nil { 146 | p.write("Model: %s", m) 147 | 148 | if err := m.Validate(); err != nil { 149 | return nil, err 150 | } 151 | 152 | models = append(models, m) 153 | m.Node = t 154 | m.Package = p.Package 155 | } 156 | } 157 | } 158 
| } 159 | 160 | pkg.SetModels(models) 161 | if err := pkg.addMissingRelationships(); err != nil { 162 | return nil, err 163 | } 164 | for _, ctor := range ctors { 165 | p.tryMatchConstructor(pkg, ctor) 166 | } 167 | 168 | return pkg, nil 169 | } 170 | 171 | func (p *Processor) tryMatchConstructor(pkg *Package, fun *types.Func) { 172 | if !strings.HasPrefix(fun.Name(), "new") { 173 | return 174 | } 175 | 176 | if m := pkg.FindModel(fun.Name()[3:]); m != nil { 177 | sig := fun.Type().(*types.Signature) 178 | if sig.Recv() != nil { 179 | return 180 | } 181 | 182 | res := sig.Results() 183 | if res.Len() > 0 { 184 | for i := 0; i < res.Len(); i++ { 185 | if isTypeOrPtrTo(res.At(i).Type(), m.Node) { 186 | m.CtorFunc = fun 187 | return 188 | } 189 | } 190 | } 191 | } 192 | } 193 | 194 | func (p *Processor) processModel(name string, s *types.Struct, t *types.Named) (*Model, error) { 195 | m := NewModel(name) 196 | m.Events = p.findEvents(t) 197 | 198 | var base int 199 | var fields []*Field 200 | if base, fields = p.processFields(s, nil, true); base == -1 { 201 | return nil, nil 202 | } 203 | 204 | p.processBaseField(m, fields[base]) 205 | if err := m.SetFields(fields); err != nil { 206 | return nil, err 207 | } 208 | 209 | return m, nil 210 | } 211 | 212 | var allEvents = Events{ 213 | BeforeInsert, 214 | AfterInsert, 215 | BeforeUpdate, 216 | AfterUpdate, 217 | BeforeSave, 218 | AfterSave, 219 | BeforeDelete, 220 | AfterDelete, 221 | } 222 | 223 | func (p *Processor) findEvents(node *types.Named) []Event { 224 | var events []Event 225 | for _, e := range allEvents { 226 | if p.isEventPresent(node, e) { 227 | events = append(events, e) 228 | } 229 | } 230 | 231 | return events 232 | } 233 | 234 | // isEventPresent checks the given Event is implemented for the given node. 
235 | func (p *Processor) isEventPresent(node *types.Named, e Event) bool { 236 | signature := getMethodSignature(p.Package, types.NewPointer(node), string(e)) 237 | return signatureMatches(signature, nil, typeCheckers{isBuiltinError}) 238 | } 239 | 240 | // processFields returns which field index is an embedded kallax.Model, or -1 if none. 241 | func (p *Processor) processFields(s *types.Struct, done []*types.Struct, root bool) (base int, fields []*Field) { 242 | base = -1 243 | 244 | for i := 0; i < s.NumFields(); i++ { 245 | f := s.Field(i) 246 | if !f.Exported() || isIgnoredField(s, i) { 247 | continue 248 | } 249 | 250 | field := NewField( 251 | f.Name(), 252 | typeName(f.Type().Underlying()), 253 | reflect.StructTag(s.Tag(i)), 254 | ) 255 | field.Node = f 256 | if typeName(f.Type()) == BaseModel { 257 | base = i 258 | field.Type = BaseModel 259 | } 260 | 261 | if f.Anonymous() { 262 | field.IsEmbedded = true 263 | } 264 | 265 | p.processField(field, f.Type(), done, root) 266 | if field.Kind == Invalid { 267 | p.write("WARNING: arrays of relationships are not supported. Field %s will be ignored.", field.Name) 268 | continue 269 | } 270 | 271 | fields = append(fields, field) 272 | } 273 | 274 | return base, fields 275 | } 276 | 277 | // processField processes recursively the field. During the processing several 278 | // field properties might be modified, such as the properties that report if 279 | // the type has to be serialized to json, if it's an alias or if it's a pointer 280 | // and so on. Also, the kind of the field is set here. 281 | // If root is true, models are established as relationships. If not, they are 282 | // just treated as structs. 
283 | // The following types are always set as JSON: 284 | // - Map 285 | // - Slice or Array with non-basic underlying type 286 | // - Interface 287 | // - Struct that is not a model or is not at root level 288 | func (p *Processor) processField(field *Field, typ types.Type, done []*types.Struct, root bool) { 289 | switch typ := typ.(type) { 290 | case *types.Basic: 291 | field.Type = typ.String() 292 | field.Kind = Basic 293 | case *types.Pointer: 294 | field.IsPtr = true 295 | p.processField(field, typ.Elem(), done, root) 296 | case *types.Named: 297 | if field.Type == BaseModel { 298 | p.processField(field, typ.Underlying(), done, root) 299 | return 300 | } 301 | 302 | if isModel(typ, true) && root { 303 | field.Kind = Relationship 304 | field.Type = typ.String() 305 | return 306 | } 307 | 308 | // embedded fields won't be stored, only their fields, so it's irrelevant 309 | // if they implement scanner and valuer 310 | if !field.IsEmbedded && isSQLType(p.Package, types.NewPointer(typ)) { 311 | field.Kind = Interface 312 | return 313 | } 314 | 315 | if t, ok := specialTypes[typeName(typ)]; ok { 316 | field.Type = t 317 | return 318 | } 319 | 320 | p.processField(field, typ.Underlying(), done, root) 321 | field.IsAlias = !field.IsJSON 322 | case *types.Array: 323 | var underlying Field 324 | p.processField(&underlying, typ.Elem(), done, root) 325 | if underlying.Kind == Relationship { 326 | field.Kind = Invalid 327 | return 328 | } 329 | 330 | if underlying.Kind != Basic { 331 | field.IsJSON = true 332 | } 333 | field.Kind = Array 334 | field.SetFields(underlying.Fields) 335 | case *types.Slice: 336 | var underlying Field 337 | p.processField(&underlying, typ.Elem(), done, root) 338 | if underlying.Kind == Relationship { 339 | field.Kind = Relationship 340 | return 341 | } 342 | 343 | if underlying.Kind != Basic { 344 | field.IsJSON = true 345 | } 346 | field.Kind = Slice 347 | field.SetFields(underlying.Fields) 348 | case *types.Map: 349 | field.Kind = Map 350 | 
field.IsJSON = true 351 | case *types.Interface: 352 | field.Kind = Interface 353 | field.IsJSON = true 354 | case *types.Struct: 355 | field.Kind = Struct 356 | field.IsJSON = true 357 | 358 | d := false 359 | for _, v := range done { 360 | if v == typ { 361 | d = true 362 | break 363 | } 364 | } 365 | 366 | if !d { 367 | _, subfs := p.processFields(typ, append(done, typ), false) 368 | field.SetFields(subfs) 369 | } 370 | default: 371 | p.write("WARNING: Ignored field %s of type %s.", field.Name, field.Type) 372 | } 373 | } 374 | 375 | func isSQLType(pkg *types.Package, typ types.Type) bool { 376 | scan := getMethodSignature(pkg, typ, "Scan") 377 | if !signatureMatches(scan, typeCheckers{isEmptyInterface}, typeCheckers{isBuiltinError}) { 378 | return false 379 | } 380 | 381 | value := getMethodSignature(pkg, typ, "Value") 382 | if !signatureMatches(value, nil, typeCheckers{isDriverValue, isBuiltinError}) { 383 | return false 384 | } 385 | 386 | return true 387 | } 388 | 389 | func signatureMatches(s *types.Signature, params typeCheckers, results typeCheckers) bool { 390 | return s != nil && 391 | s.Params().Len() == len(params) && 392 | s.Results().Len() == len(results) && 393 | params.check(s.Params()) && 394 | results.check(s.Results()) 395 | } 396 | 397 | type typeCheckers []typeChecker 398 | 399 | func (c typeCheckers) check(tuple *types.Tuple) bool { 400 | for i, checker := range c { 401 | if !checker(tuple.At(i).Type()) { 402 | return false 403 | } 404 | } 405 | return true 406 | } 407 | 408 | type typeChecker func(types.Type) bool 409 | 410 | func getMethodSignature(pkg *types.Package, typ types.Type, name string) *types.Signature { 411 | ms := types.NewMethodSet(typ) 412 | method := ms.Lookup(pkg, name) 413 | if method == nil { 414 | return nil 415 | } 416 | 417 | return method.Obj().(*types.Func).Type().(*types.Signature) 418 | } 419 | 420 | func isEmptyInterface(typ types.Type) bool { 421 | switch typ := typ.(type) { 422 | case *types.Interface: 423 | 
return typ.NumMethods() == 0 424 | } 425 | return false 426 | } 427 | 428 | func isDriverValue(typ types.Type) bool { 429 | switch typ := typ.(type) { 430 | case *types.Named: 431 | return typ.String() == "database/sql/driver.Value" 432 | } 433 | return false 434 | } 435 | 436 | // isModel checks if the type is a model. If dive is true, it will check also 437 | // the types of the struct if the type is a struct. 438 | func isModel(typ types.Type, dive bool) bool { 439 | switch typ := typ.(type) { 440 | case *types.Named: 441 | if typeName(typ) == BaseModel { 442 | return true 443 | } 444 | return isModel(typ.Underlying(), true && dive) 445 | case *types.Pointer: 446 | return isModel(typ.Elem(), true && dive) 447 | case *types.Struct: 448 | if !dive { 449 | return false 450 | } 451 | 452 | for i := 0; i < typ.NumFields(); i++ { 453 | if isModel(typ.Field(i).Type(), false) { 454 | return true 455 | } 456 | } 457 | } 458 | return false 459 | } 460 | 461 | func (p *Processor) processBaseField(m *Model, f *Field) { 462 | m.Table = f.Tag.Get("table") 463 | if m.Table == "" { 464 | m.Table = toLowerSnakeCase(m.Name) 465 | } 466 | } 467 | 468 | func joinDirectory(directory string, files []string) []string { 469 | result := make([]string, len(files)) 470 | for i, file := range files { 471 | result[i] = filepath.Join(directory, file) 472 | } 473 | 474 | return result 475 | } 476 | 477 | func typeName(typ types.Type) string { 478 | return removeGoPath(typ.String()) 479 | } 480 | 481 | var separator = filepath.Separator 482 | 483 | // toSlash is an identical implementation of filepath.ToSlash. Is only 484 | // implemented so we can change the separator on runtime for testing purposes, 485 | // since filepath.Separator is a constant. 486 | // Parts of the code using filepath.ToSlash that need cross-platform tests 487 | // should use this instead. 
488 | func toSlash(path string) string { 489 | if separator == '/' { 490 | return path 491 | } 492 | return strings.Replace(path, string(separator), "/", -1) 493 | } 494 | 495 | func removeGoPath(path string) string { 496 | var prefix string 497 | if strings.HasPrefix(path, "[]*") { 498 | prefix = "[]*" 499 | path = path[3:] 500 | } else if strings.HasPrefix(path, "[]") { 501 | prefix = "[]" 502 | path = path[2:] 503 | } else if strings.HasPrefix(path, "*") { 504 | prefix = "*" 505 | path = path[1:] 506 | } 507 | 508 | path = toSlash(path) 509 | for _, p := range parseutil.DefaultGoPath { 510 | p = toSlash(p + "/src/") 511 | if strings.HasPrefix(path, p) { 512 | // Directories named "vendor" are only vendor directories 513 | // if they're under $GOPATH/src. 514 | if idx := strings.LastIndex(path, "/vendor/"); idx >= len(p)-1 { 515 | return prefix + path[idx+8:] 516 | } 517 | return prefix + path[len(p):] 518 | } 519 | } 520 | return prefix + path 521 | } 522 | 523 | func isIgnoredField(s *types.Struct, idx int) bool { 524 | tag := reflect.StructTag(s.Tag(idx)) 525 | return strings.Split(tag.Get("kallax"), ",")[0] == "-" 526 | } 527 | -------------------------------------------------------------------------------- /generator/processor_test.go: -------------------------------------------------------------------------------- 1 | package generator 2 | 3 | import ( 4 | "go/types" 5 | "reflect" 6 | "testing" 7 | 8 | "gopkg.in/src-d/go-parse-utils.v1" 9 | 10 | "github.com/stretchr/testify/require" 11 | "github.com/stretchr/testify/suite" 12 | ) 13 | 14 | type ProcessorSuite struct { 15 | suite.Suite 16 | } 17 | 18 | func (s *ProcessorSuite) TestInlineStruct() { 19 | fixtureSrc := ` 20 | package fixture 21 | 22 | import "gopkg.in/src-d/go-kallax.v1" 23 | 24 | type Foo struct {} 25 | 26 | type Bar struct { 27 | kallax.Model 28 | ID int64 ` + "`pk:\"autoincr\"`" + ` 29 | Foo string 30 | R *Foo ` + "`kallax:\",inline\"`" + ` 31 | } 32 | ` 33 | 34 | pkg := 
s.processFixture(fixtureSrc) 35 | s.True(findModel(pkg, "Bar").Fields[3].Inline()) 36 | } 37 | 38 | func (s *ProcessorSuite) TestTags() { 39 | fixtureSrc := ` 40 | package fixture 41 | 42 | import "gopkg.in/src-d/go-kallax.v1" 43 | 44 | type Foo struct { 45 | kallax.Model 46 | ID int64 ` + "`pk:\"autoincr\"`" + ` 47 | Int int "foo" 48 | } 49 | ` 50 | 51 | pkg := s.processFixture(fixtureSrc) 52 | s.Equal(reflect.StructTag("foo"), findModel(pkg, "Foo").Fields[2].Tag) 53 | } 54 | 55 | func (s *ProcessorSuite) TestRecursiveModel() { 56 | fixtureSrc := ` 57 | package fixture 58 | 59 | import "gopkg.in/src-d/go-kallax.v1" 60 | 61 | type Recur struct { 62 | kallax.Model 63 | ID int64 ` + "`pk:\"autoincr\"`" + ` 64 | Foo string 65 | R *Recur 66 | } 67 | ` 68 | 69 | pkg := s.processFixture(fixtureSrc) 70 | m := findModel(pkg, "Recur") 71 | 72 | s.Equal(findField(m, "R").Kind, Relationship) 73 | s.Len(findField(m, "R").Fields, 0) 74 | } 75 | 76 | func (s *ProcessorSuite) TestDeepRecursiveStruct() { 77 | fixtureSrc := ` 78 | package fixture 79 | 80 | import "gopkg.in/src-d/go-kallax.v1" 81 | 82 | type Recur struct { 83 | kallax.Model 84 | ID int64 ` + "`pk:\"autoincr\"`" + ` 85 | Foo string 86 | Rec *Other 87 | } 88 | 89 | type Other struct { 90 | R *Recur 91 | } 92 | ` 93 | 94 | pkg := s.processFixture(fixtureSrc) 95 | m := findModel(pkg, "Recur") 96 | 97 | s.Equal( 98 | m.Fields[3].Fields[0].Fields[3].Node, 99 | m.Fields[3].Node, 100 | "indirect type recursivity not handled correctly.", 101 | ) 102 | s.Len(pkg.Models[0].Fields[3].Fields[0].Fields[3].Fields, 0) 103 | } 104 | 105 | func (s *ProcessorSuite) TestIsEventPresent() { 106 | fixtureSrc := ` 107 | package fixture 108 | 109 | import "gopkg.in/src-d/go-kallax.v1" 110 | 111 | type Foo struct { 112 | kallax.Model 113 | ID int64 ` + "`pk:\"autoincr\"`" + ` 114 | Foo string 115 | } 116 | 117 | func (r *Foo) BeforeUpdate() error { 118 | return nil 119 | } 120 | 121 | func (r *Foo) BeforeInsert() int { 122 | return 0 123 | } 
124 | 125 | func (r *Foo) AfterInsert() int { 126 | return 0 127 | } 128 | 129 | func (r *Foo) AfterUpdate(foo int) { 130 | } 131 | ` 132 | 133 | p := s.processorFixture(fixtureSrc) 134 | pkg, err := p.processPackage() 135 | s.Nil(err) 136 | 137 | m := findModel(pkg, "Foo") 138 | s.True(p.isEventPresent(m.Node, BeforeUpdate)) 139 | s.False(p.isEventPresent(m.Node, BeforeInsert)) 140 | s.False(p.isEventPresent(m.Node, AfterInsert)) 141 | s.False(p.isEventPresent(m.Node, AfterUpdate)) 142 | } 143 | 144 | func (s *ProcessorSuite) TestProcessField() { 145 | fixtureSrc := ` 146 | package fixture 147 | 148 | import "gopkg.in/src-d/go-kallax.v1" 149 | import "database/sql/driver" 150 | 151 | type BasicAlias string 152 | type MapAlias map[string]int 153 | type SliceAlias []string 154 | type ArrayAlias [4]string 155 | 156 | type Related struct { 157 | kallax.Model 158 | ID int64 ` + "`pk:\"autoincr\"`" + ` 159 | Foo string 160 | } 161 | 162 | type JSON struct { 163 | Bar string 164 | } 165 | 166 | type Interface interface { 167 | Foo() 168 | } 169 | 170 | type SQLInterface interface { 171 | Scan(interface{}) error 172 | Value(interface{}) (driver.Value, error) 173 | } 174 | 175 | type Foo struct { 176 | kallax.Model 177 | ID int64 ` + "`pk:\"autoincr\"`" + ` 178 | Basic string 179 | AliasBasic BasicAlias 180 | BasicPtr *string 181 | Relationship Related 182 | RelSlice []Related 183 | RelArray [4]Related 184 | Map map[string]interface{} 185 | MapAlias MapAlias 186 | AliasSlice SliceAlias 187 | BasicSlice []string 188 | ComplexSlice []JSON 189 | JSON JSON 190 | JSONPtr *JSON 191 | AliasArray ArrayAlias 192 | BasicArray [4]string 193 | ComplexArray [4]JSON 194 | InlineArray struct{A int} 195 | Interface Interface 196 | SQLInterface SQLInterface 197 | } 198 | ` 199 | 200 | pkg := s.processFixture(fixtureSrc) 201 | cases := []struct { 202 | name string 203 | kind FieldKind 204 | isJSON bool 205 | isAlias bool 206 | isPtr bool 207 | }{ 208 | {"Basic", Basic, false, false, false}, 
209 | {"AliasBasic", Basic, false, true, false}, 210 | {"BasicPtr", Basic, false, false, true}, 211 | {"Relationship", Relationship, false, false, false}, 212 | {"RelSlice", Relationship, false, false, false}, 213 | {"Map", Map, true, false, false}, 214 | {"MapAlias", Map, true, false, false}, 215 | {"AliasSlice", Slice, false, true, false}, 216 | {"BasicSlice", Slice, false, false, false}, 217 | {"ComplexSlice", Slice, true, false, false}, 218 | {"JSON", Struct, true, false, false}, 219 | {"JSONPtr", Struct, true, false, true}, 220 | {"AliasArray", Array, false, true, false}, 221 | {"BasicArray", Array, false, false, false}, 222 | {"ComplexArray", Array, true, false, false}, 223 | {"InlineArray", Struct, true, false, false}, 224 | {"Interface", Interface, true, false, false}, 225 | {"SQLInterface", Interface, true, false, false}, // TODO false, false, false 226 | } 227 | 228 | m := findModel(pkg, "Foo") 229 | for _, c := range cases { 230 | f := findField(m, c.name) 231 | s.NotNil(f, "%s should not be nil", c.name) 232 | 233 | s.Equal(c.kind, f.Kind, "%s kind", c.name) 234 | s.Equal(c.isJSON, f.IsJSON, "%s is json", c.name) 235 | s.Equal(c.isAlias, f.IsAlias, "%s is alias", c.name) 236 | s.Equal(c.isPtr, f.IsPtr, "%s is ptr", c.name) 237 | } 238 | 239 | s.Nil(findField(m, "RelArray"), "RelArray should not be generated") 240 | } 241 | 242 | func (s *ProcessorSuite) TestCtor() { 243 | fixtureSrc := ` 244 | package fixture 245 | 246 | import "gopkg.in/src-d/go-kallax.v1" 247 | 248 | type Foo struct { 249 | kallax.Model 250 | ID int64 ` + "`pk:\"autoincr\"`" + ` 251 | Foo string 252 | } 253 | 254 | func newFoo() *Foo { 255 | return &Foo{} 256 | } 257 | ` 258 | 259 | pkg := s.processFixture(fixtureSrc) 260 | m := findModel(pkg, "Foo") 261 | 262 | s.NotNil(m.CtorFunc, "Foo should have ctor") 263 | } 264 | 265 | func (s *ProcessorSuite) TestSQLTypeIsInterface() { 266 | fixtureSrc := ` 267 | package fixture 268 | 269 | import "gopkg.in/src-d/go-kallax.v1" 270 | import 
"database/sql/driver" 271 | 272 | type Foo struct { 273 | kallax.Model 274 | ID int64 ` + "`pk:\"autoincr\"`" + ` 275 | Foo Bar 276 | } 277 | 278 | type Bar string 279 | 280 | func (*Bar) Scan(v interface{}) error { 281 | return nil 282 | } 283 | 284 | func (Bar) Value() (driver.Value, error) { 285 | return nil, nil 286 | } 287 | ` 288 | 289 | pkg := s.processFixture(fixtureSrc) 290 | field := findField(findModel(pkg, "Foo"), "Foo") 291 | s.Equal(Interface, field.Kind) 292 | } 293 | 294 | func (s *ProcessorSuite) TestIsSQLType() { 295 | fixtureSrc := ` 296 | package fixture 297 | 298 | import "gopkg.in/src-d/go-kallax.v1" 299 | 300 | type SQLTypeFixture struct { 301 | kallax.Model 302 | ID kallax.ULID ` + "`pk:\"\"`" + ` 303 | Foo string 304 | } 305 | ` 306 | 307 | p := s.processorFixture(fixtureSrc) 308 | pkg, err := p.processPackage() 309 | s.Nil(err) 310 | m := findModel(pkg, "SQLTypeFixture") 311 | 312 | s.True(isSQLType(p.Package, types.NewPointer(m.ID.Node.Type()))) 313 | // model is index 1 because ID is always index 0 314 | s.False(isSQLType(p.Package, types.NewPointer(m.Fields[1].Node.Type()))) 315 | } 316 | 317 | func (s *ProcessorSuite) processorFixture(source string) *Processor { 318 | prc, err := processorFixture(source) 319 | s.Require().NoError(err) 320 | return prc 321 | } 322 | 323 | func (s *ProcessorSuite) processFixture(source string) *Package { 324 | pkg, err := processFixture(source) 325 | s.Require().NoError(err) 326 | return pkg 327 | } 328 | 329 | func (s *ProcessorSuite) TestDo() { 330 | p := NewProcessor(pkgAbsPath, []string{"README.md"}) 331 | pkg, err := p.Do() 332 | s.NotNil(pkg) 333 | s.NoError(err) 334 | } 335 | 336 | func (s *ProcessorSuite) TestIsModel() { 337 | src := ` 338 | package fixture 339 | 340 | import "gopkg.in/src-d/go-kallax.v1" 341 | 342 | type Bar struct { 343 | kallax.Model 344 | ID int64 ` + "`pk:\"autoincr\"`" + ` 345 | Bar string 346 | } 347 | 348 | type Struct struct { 349 | Bar Bar 350 | } 351 | 352 | type Foo 
struct { 353 | kallax.Model 354 | ID int64 ` + "`pk:\"autoincr\"`" + ` 355 | Foo string 356 | Ptr *Bar 357 | NoPtr Bar 358 | Struct Struct 359 | } 360 | ` 361 | pkg := s.processFixture(src) 362 | m := findModel(pkg, "Foo") 363 | cases := []struct { 364 | field string 365 | expected bool 366 | }{ 367 | {"Foo", false}, 368 | {"Ptr", true}, 369 | {"NoPtr", true}, 370 | {"Struct", false}, 371 | } 372 | 373 | for _, c := range cases { 374 | s.Equal(c.expected, isModel(findField(m, c.field).Node.Type(), true), c.field) 375 | } 376 | } 377 | 378 | func (s *ProcessorSuite) TestIsEmbedded() { 379 | src := ` 380 | package fixture 381 | 382 | import "gopkg.in/src-d/go-kallax.v1" 383 | 384 | type Bar struct { 385 | kallax.Model 386 | ID int64 ` + "`pk:\"autoincr\"`" + ` 387 | Baz string 388 | } 389 | 390 | type Struct struct { 391 | Qux Bar 392 | } 393 | 394 | type Struct2 struct { 395 | Mux string 396 | } 397 | 398 | type Foo struct { 399 | kallax.Model 400 | ID int64 ` + "`pk:\"autoincr\"`" + ` 401 | A Bar 402 | B *Bar 403 | Struct2 404 | *Struct 405 | C struct { 406 | D int 407 | } 408 | } 409 | ` 410 | pkg := s.processFixture(src) 411 | m := findModel(pkg, "Foo") 412 | expected := []string{ 413 | "ID", "Model", "A", "B", "Mux", "Qux", "C", 414 | } 415 | 416 | var names []string 417 | for _, f := range m.Fields { 418 | names = append(names, f.Name) 419 | } 420 | 421 | s.Equal(expected, names) 422 | } 423 | 424 | func TestProcessor(t *testing.T) { 425 | suite.Run(t, new(ProcessorSuite)) 426 | } 427 | 428 | func TestRemoveGoPath(t *testing.T) { 429 | oldGoPath := parseutil.DefaultGoPath 430 | oldSep := separator 431 | defer func() { 432 | parseutil.DefaultGoPath = oldGoPath 433 | separator = oldSep 434 | }() 435 | 436 | cases := []struct { 437 | typ string 438 | result string 439 | gopath []string 440 | sep rune 441 | }{ 442 | { 443 | `E:\workspace\gopath\src\gopkg.in\src-d\go-kallax.v1\tests\fixtures.AliasString`, 444 | 
"gopkg.in/src-d/go-kallax.v1/tests/fixtures.AliasString", 445 | []string{ 446 | `E:\workspace\gopath`, 447 | }, 448 | '\\', 449 | }, 450 | { 451 | "/home/workspace/gopath/src/gopkg.in/src-d/go-kallax.v1/tests/fixtures.AliasString", 452 | "gopkg.in/src-d/go-kallax.v1/tests/fixtures.AliasString", 453 | []string{ 454 | "/home/foo/go", 455 | "/home/workspace/gopath", 456 | }, 457 | '/', 458 | }, 459 | { 460 | "/go/src/foo/go/src/fixtures.AliasString", 461 | "foo/go/src/fixtures.AliasString", 462 | []string{ 463 | "/go", 464 | }, 465 | '/', 466 | }, 467 | { 468 | "/home/workspace/gopath/src/foo/bar/vendor/gopkg.in/src-d/go-kallax.v1/tests/fixtures.AliasString", 469 | "gopkg.in/src-d/go-kallax.v1/tests/fixtures.AliasString", 470 | []string{ 471 | "/home/foo/go", 472 | "/home/workspace/gopath", 473 | }, 474 | '/', 475 | }, 476 | { 477 | "/home/vendor/workspace/gopath/src/gopkg.in/src-d/go-kallax.v1/tests/fixtures.AliasString", 478 | "gopkg.in/src-d/go-kallax.v1/tests/fixtures.AliasString", 479 | []string{ 480 | "/home/foo/go", 481 | "/home/vendor/workspace/gopath", 482 | }, 483 | '/', 484 | }, 485 | { 486 | "/home/vendor/workspace/gopath/src/vendor/gopkg.in/src-d/go-kallax.v1/tests/fixtures.AliasString", 487 | "gopkg.in/src-d/go-kallax.v1/tests/fixtures.AliasString", 488 | []string{ 489 | "/home/foo/go", 490 | "/home/vendor/workspace/gopath", 491 | }, 492 | '/', 493 | }, 494 | } 495 | 496 | for _, c := range cases { 497 | parseutil.DefaultGoPath = parseutil.GoPath(c.gopath) 498 | separator = c.sep 499 | require.Equal(t, c.result, removeGoPath(c.typ), c.typ) 500 | } 501 | } 502 | 503 | func findModel(pkg *Package, name string) *Model { 504 | for _, m := range pkg.Models { 505 | if m.Name == name { 506 | return m 507 | } 508 | } 509 | return nil 510 | } 511 | 512 | func findField(m *Model, name string) *Field { 513 | for _, f := range m.Fields { 514 | if f.Name == name { 515 | return f 516 | } 517 | } 518 | return nil 519 | } 520 | 
--------------------------------------------------------------------------------
/generator/templates/base.tgo:
--------------------------------------------------------------------------------
// Code generated by https://github.com/src-d/go-kallax. DO NOT EDIT.
// Please, do not touch the code below, and if you do, do it under your own
// risk. Take into account that all the code you write here will be completely
// erased from earth the next time you generate the kallax models.
package {{.Name}}

import (
	"gopkg.in/src-d/go-kallax.v1"
	"gopkg.in/src-d/go-kallax.v1/types"
	"database/sql"
	"database/sql/driver"
	"fmt"
)

var _ types.SQLType
var _ fmt.Formatter

type modelSaveFunc func(*kallax.Store) error

{{template "model" .}}
{{template "schema" .}}

--------------------------------------------------------------------------------
/generator/templates/query.tgo:
--------------------------------------------------------------------------------

// {{.QueryName}} is the object used to create queries for the {{.Name}}
// entity.
type {{.QueryName}} struct {
	*kallax.BaseQuery
}

// New{{.QueryName}} returns a new instance of {{.QueryName}}.
func New{{.QueryName}}() *{{.QueryName}} {
	return &{{.QueryName}}{
		BaseQuery: kallax.NewBaseQuery(Schema.{{.Name}}.BaseSchema),
	}
}

// Select adds columns to select in the query.
func (q *{{.QueryName}}) Select(columns ...kallax.SchemaField) *{{.QueryName}} {
	if len(columns) == 0 {
		return q
	}
	q.BaseQuery.Select(columns...)
	return q
}

// SelectNot excludes columns from being selected in the query.
func (q *{{.QueryName}}) SelectNot(columns ...kallax.SchemaField) *{{.QueryName}} {
	q.BaseQuery.SelectNot(columns...)
	return q
}

// Copy returns a new identical copy of the query. Remember queries are mutable
// so make a copy any time you need to reuse them.
func (q *{{.QueryName}}) Copy() *{{.QueryName}} {
	return &{{.QueryName}}{
		BaseQuery: q.BaseQuery.Copy(),
	}
}

// Order adds order clauses to the query for the given columns.
func (q *{{.QueryName}}) Order(cols ...kallax.ColumnOrder) *{{.QueryName}} {
	q.BaseQuery.Order(cols...)
	return q
}

// BatchSize sets the number of items to fetch per batch when there are 1:N
// relationships selected in the query.
func (q *{{.QueryName}}) BatchSize(size uint64) *{{.QueryName}} {
	q.BaseQuery.BatchSize(size)
	return q
}

// Limit sets the max number of items to retrieve.
func (q *{{.QueryName}}) Limit(n uint64) *{{.QueryName}} {
	q.BaseQuery.Limit(n)
	return q
}

// Offset sets the number of items to skip from the result set of items.
func (q *{{.QueryName}}) Offset(n uint64) *{{.QueryName}} {
	q.BaseQuery.Offset(n)
	return q
}

// Where adds a condition to the query. All conditions added are concatenated
// using a logical AND.
func (q *{{.QueryName}}) Where(cond kallax.Condition) *{{.QueryName}} {
	q.BaseQuery.Where(cond)
	return q
}

{{range .Relationships}}
{{if not .IsOneToManyRelationship}}
// With{{.Name}} retrieves the {{.Name}} record associated with each
// record.
func (q *{{$.QueryName}}) With{{.Name}}() *{{$.QueryName}} {
	q.AddRelation(Schema.{{.TypeSchemaName}}.BaseSchema, "{{.Name}}", kallax.OneToOne, nil)
	return q
}
{{else}}
// With{{.Name}} retrieves all the {{.Name}} records associated with each
// record. A condition can be passed to filter the associated records.
func (q *{{$.QueryName}}) With{{.Name}}(cond kallax.Condition) *{{$.QueryName}} {
	q.AddRelation(Schema.{{.TypeSchemaName}}.BaseSchema, "{{.Name}}", kallax.OneToMany, cond)
	return q
}
{{end}}
{{end}}

--------------------------------------------------------------------------------
/generator/templates/resultset.tgo:
--------------------------------------------------------------------------------

// {{.ResultSetName}} is the set of results returned by a query to the
// database.
type {{.ResultSetName}} struct {
	ResultSet kallax.ResultSet
	last      *{{.Name}}
	lastErr   error
}

// New{{.ResultSetName}} creates a new result set for rows of the type
// {{.Name}}.
func New{{.ResultSetName}}(rs kallax.ResultSet) *{{.ResultSetName}} {
	return &{{.ResultSetName}}{ResultSet: rs}
}

// Next fetches the next item in the result set and returns true if there is
// a next item.
// The result set is closed automatically when there are no more items.
19 | func (rs *{{.ResultSetName}}) Next() bool { 20 | if !rs.ResultSet.Next() { 21 | rs.lastErr = rs.ResultSet.Close() 22 | rs.last = nil 23 | return false 24 | } 25 | 26 | var record kallax.Record 27 | record, rs.lastErr = rs.ResultSet.Get(Schema.{{.Name}}.BaseSchema) 28 | if rs.lastErr != nil { 29 | rs.last = nil 30 | } else { 31 | var ok bool 32 | rs.last, ok = record.(*{{.Name}}) 33 | if !ok { 34 | rs.lastErr = fmt.Errorf("kallax: unable to convert record to *{{.Name}}") 35 | rs.last = nil 36 | } 37 | } 38 | 39 | return true 40 | } 41 | 42 | // Get retrieves the last fetched item from the result set and the last error. 43 | func (rs *{{.ResultSetName}}) Get() (*{{.Name}}, error) { 44 | return rs.last, rs.lastErr 45 | } 46 | 47 | // ForEach iterates over the complete result set passing every record found to 48 | // the given callback. It is possible to stop the iteration by returning 49 | // `kallax.ErrStop` in the callback. 50 | // Result set is always closed at the end. 51 | func (rs *{{.ResultSetName}}) ForEach(fn func(*{{.Name}}) error) error { 52 | for rs.Next() { 53 | record, err := rs.Get() 54 | if err != nil { 55 | return err 56 | } 57 | 58 | if err := fn(record); err != nil { 59 | if err == kallax.ErrStop { 60 | return rs.Close() 61 | } 62 | 63 | return err 64 | } 65 | } 66 | return nil 67 | } 68 | 69 | // All returns all records on the result set and closes the result set. 70 | func (rs *{{.ResultSetName}}) All() ([]*{{.Name}}, error) { 71 | var result []*{{.Name}} 72 | for rs.Next() { 73 | record, err := rs.Get() 74 | if err != nil { 75 | return nil, err 76 | } 77 | result = append(result, record) 78 | } 79 | return result, nil 80 | } 81 | 82 | // One returns the first record on the result set and closes the result set. 
83 | func (rs *{{.ResultSetName}}) One() (*{{.Name}}, error) { 84 | if !rs.Next() { 85 | return nil, kallax.ErrNotFound 86 | } 87 | 88 | record, err := rs.Get() 89 | if err != nil { 90 | return nil, err 91 | } 92 | 93 | if err := rs.Close(); err != nil { 94 | return nil, err 95 | } 96 | 97 | return record, nil 98 | } 99 | 100 | // Err returns the last error occurred. 101 | func (rs *{{.ResultSetName}}) Err() error { 102 | return rs.lastErr 103 | } 104 | 105 | // Close closes the result set. 106 | func (rs *{{.ResultSetName}}) Close() error { 107 | return rs.ResultSet.Close() 108 | } 109 | -------------------------------------------------------------------------------- /generator/templates/schema.tgo: -------------------------------------------------------------------------------- 1 | 2 | type schema struct { 3 | {{range .Models}}{{.Name}} *schema{{.Name}} 4 | {{end}} 5 | } 6 | 7 | {{range .Models}} 8 | type schema{{.Name}} struct { 9 | *kallax.BaseSchema 10 | {{$.GenModelSchema .}} 11 | } 12 | {{end}} 13 | 14 | {{$.GenSubSchemas}} 15 | 16 | var Schema = &schema{ 17 | {{range .Models}}{{.Name}}: &schema{{.Name}}{ 18 | BaseSchema: kallax.NewBaseSchema( 19 | "{{.Table}}", 20 | "{{.Alias}}", 21 | kallax.NewSchemaField("{{.ID.ColumnName}}"), 22 | kallax.ForeignKeys{ 23 | {{range .Relationships}}"{{.Name}}": kallax.NewForeignKey("{{.ForeignKey}}", {{if .IsInverse}}true{{else}}false{{end}}), 24 | {{end}} 25 | }, 26 | func() kallax.Record { 27 | return new({{.Name}}) 28 | }, 29 | {{if .ID.IsAutoIncrement}}true{{else}}false{{end}}, 30 | {{$.GenModelColumns .}} 31 | ), 32 | {{$.GenSchemaInit .}} 33 | }, 34 | {{end}} 35 | } 36 | -------------------------------------------------------------------------------- /generator/types_test.go: -------------------------------------------------------------------------------- 1 | package generator 2 | 3 | import ( 4 | "go/ast" 5 | "go/importer" 6 | "go/parser" 7 | "go/token" 8 | "go/types" 9 | "reflect" 10 | "testing" 11 | 12 | 
"github.com/stretchr/testify/require" 13 | "github.com/stretchr/testify/suite" 14 | ) 15 | 16 | type FieldSuite struct { 17 | suite.Suite 18 | } 19 | 20 | func TestField(t *testing.T) { 21 | suite.Run(t, new(FieldSuite)) 22 | } 23 | 24 | func (s *FieldSuite) TestInline() { 25 | cases := []struct { 26 | typ string 27 | tag string 28 | inline bool 29 | }{ 30 | {"", "", false}, 31 | {BaseModel, "", true}, 32 | {"", `kallax:"foo"`, false}, 33 | {"", `kallax:"foo,inline"`, true}, 34 | {"", `kallax:"foo,inline,omitempty"`, true}, 35 | {"", `kallax:",inline,omitempty"`, true}, 36 | {"", `kallax:",inline"`, true}, 37 | } 38 | 39 | for _, c := range cases { 40 | s.Equal(c.inline, mkField("", c.typ, c.tag).Inline(), "field with tag: %s", c.tag) 41 | } 42 | } 43 | 44 | func (s *FieldSuite) TestIsPrimaryKey() { 45 | cases := []struct { 46 | tag string 47 | ok bool 48 | }{ 49 | {"", false}, 50 | {`kallax:"pk"`, false}, 51 | {`kallax:"foo,pk"`, false}, 52 | {`pk:""`, true}, 53 | {`pk:"foo"`, true}, 54 | {`pk:"autoincr"`, true}, 55 | } 56 | 57 | for _, c := range cases { 58 | s.Equal(c.ok, mkField("", "", c.tag).IsPrimaryKey(), "field with tag: %s", c.tag) 59 | } 60 | } 61 | 62 | func (s *FieldSuite) TestIsAutoIncrement() { 63 | cases := []struct { 64 | tag string 65 | ok bool 66 | }{ 67 | {"", false}, 68 | {`pk:""`, false}, 69 | {`pk:"ponies"`, false}, 70 | {`pk:"autoincr"`, true}, 71 | } 72 | 73 | for _, c := range cases { 74 | s.Equal(c.ok, mkField("", "", c.tag).IsAutoIncrement(), "field with tag: %s", c.tag) 75 | } 76 | } 77 | 78 | func (s *FieldSuite) TestColumnName() { 79 | cases := []struct { 80 | tag string 81 | name string 82 | expected string 83 | }{ 84 | {"", "Foo", "foo"}, 85 | {"", "FooBar", "foo_bar"}, 86 | {"", "ID", "id"}, 87 | {"", "References", "_references"}, 88 | {`kallax:"foo"`, "Bar", "foo"}, 89 | {`kallax:"References"`, "Bar", "_References"}, 90 | {`kallax:"references"`, "Bar", "_references"}, 91 | } 92 | 93 | for _, c := range cases { 94 | name := 
mkField(c.name, "", c.tag).ColumnName() 95 | s.Equal(c.expected, name, "field with name: %q and tag: %s", c.name, c.tag) 96 | } 97 | } 98 | 99 | func (s *FieldSuite) TestAddress() { 100 | cases := []struct { 101 | kind FieldKind 102 | isJSON bool 103 | isPtr bool 104 | name string 105 | typeStr string 106 | parent *Field 107 | expected string 108 | }{ 109 | { 110 | Struct, true, false, "Foo", "", nil, 111 | "types.JSON(&r.Foo)", 112 | }, 113 | { 114 | Map, true, false, "Foo", "", nil, 115 | "types.JSON(&r.Foo)", 116 | }, 117 | { 118 | Struct, false, false, "Foo", "", nil, 119 | "&r.Foo", 120 | }, 121 | { 122 | Array, false, false, "Foo", "[5]string", nil, 123 | `types.Array(&r.Foo, 5)`, 124 | }, 125 | { 126 | Slice, false, false, "Foo", "[]string", nil, 127 | "types.Slice(&r.Foo)", 128 | }, 129 | { 130 | Basic, false, true, "Foo", "", nil, 131 | "types.Nullable(&r.Foo)", 132 | }, 133 | { 134 | Interface, false, true, "Foo", "", nil, 135 | "types.Nullable(r.Foo)", 136 | }, 137 | { 138 | Basic, false, true, "Foo", "", withParent(mkField("Bar", "", ""), mkField("Baz", "", "")), 139 | "types.Nullable(&r.Baz.Bar.Foo)", 140 | }, 141 | } 142 | 143 | for i, c := range cases { 144 | f := withKind(withParent(mkField(c.name, c.typeStr, ""), c.parent), c.kind) 145 | if c.isJSON { 146 | f = withJSON(f) 147 | } 148 | 149 | if c.isPtr { 150 | f = withPtr(f) 151 | } 152 | 153 | s.Equal(c.expected, f.Address(), "Field %s, i = %d", f.Name, i) 154 | } 155 | } 156 | 157 | func (s *FieldSuite) TestValue() { 158 | cases := []struct { 159 | field *Field 160 | expected string 161 | }{ 162 | { 163 | mkField("Foo", "string", ""), 164 | "r.Foo, nil", 165 | }, 166 | { 167 | withAlias(mkField("Foo", "string", "")), 168 | "(string)(r.Foo), nil", 169 | }, 170 | { 171 | withPtr(withAlias(mkField("Foo", "string", ""))), 172 | "(*string)(r.Foo), nil", 173 | }, 174 | { 175 | withKind(mkField("Foo", "", ""), Slice), 176 | "types.Slice(r.Foo), nil", 177 | }, 178 | { 179 | withKind(mkField("Foo", 
"[5]string", ""), Array), 180 | `types.Array(&r.Foo, 5), nil`, 181 | }, 182 | { 183 | withJSON(withKind(mkField("Foo", "", ""), Map)), 184 | "types.JSON(r.Foo), nil", 185 | }, 186 | { 187 | withKind(mkField("Foo", "", ""), Struct), 188 | "r.Foo, nil", 189 | }, 190 | } 191 | 192 | for i, c := range cases { 193 | s.Equal(c.expected, c.field.Value(), "Field %s, i=%d", c.field.Name, i) 194 | } 195 | } 196 | 197 | type ModelSuite struct { 198 | suite.Suite 199 | model *Model 200 | variadic *Model 201 | } 202 | 203 | const fixturesSource = ` 204 | package fixtures 205 | 206 | import ( 207 | "errors" 208 | "strings" 209 | 210 | kallax "gopkg.in/src-d/go-kallax.v1" 211 | ) 212 | 213 | type User struct { 214 | kallax.Model ` + "`table:\"users\" pk:\"id\"`" + ` 215 | ID kallax.ULID 216 | Username string 217 | Email string 218 | Password Password 219 | Websites []string 220 | Emails []*Email 221 | Settings *Settings 222 | } 223 | 224 | func newUser(id kallax.ULID, username, email string, websites []string) (*User, error) { 225 | if strings.Contains(email, "@spam.org") { 226 | return nil, errors.New("kallax: is spam!") 227 | } 228 | return &User{ID: id, Username: username, Email: email, Websites: websites}, nil 229 | } 230 | 231 | type Email struct { 232 | kallax.Model ` + "`table:\"models\"`" + ` 233 | ID int64 ` + "`pk:\"autoincr\"`" + ` 234 | Address string 235 | Primary bool 236 | } 237 | 238 | func newProfile(address string, primary bool) *Email { 239 | return &Email{Address: address, Primary: primary} 240 | } 241 | 242 | type Password string 243 | 244 | // Kids, don't do this at home 245 | func (p *Password) Set(pwd string) { 246 | *p = Password("such cypher" + pwd + "much secure") 247 | } 248 | 249 | type Settings struct { 250 | NotificationsActive bool 251 | NotifyByEmail bool 252 | } 253 | 254 | type Variadic struct { 255 | kallax.Model 256 | ID int64 ` + "`pk:\"autoincr\"`" + ` 257 | Foo []string 258 | Bar string 259 | } 260 | 261 | func newVariadic(bar string, foo 
...string) *Variadic { 262 | return &Variadic{Foo: foo, Bar: bar} 263 | } 264 | ` 265 | 266 | func (s *ModelSuite) SetupSuite() { 267 | fset := &token.FileSet{} 268 | astFile, err := parser.ParseFile(fset, "fixture.go", fixturesSource, 0) 269 | s.Nil(err) 270 | 271 | cfg := &types.Config{ 272 | Importer: importer.For("gc", nil), 273 | } 274 | p, err := cfg.Check("foo", fset, []*ast.File{astFile}, nil) 275 | s.Nil(err) 276 | 277 | prc := NewProcessor("fixture", []string{"foo.go"}) 278 | prc.Package = p 279 | pkg, err := prc.processPackage() 280 | s.Nil(err) 281 | 282 | s.Len(pkg.Models, 3, "there should exist 3 models") 283 | for _, m := range pkg.Models { 284 | if m.Name == "User" { 285 | s.model = m 286 | } 287 | 288 | if m.Name == "Variadic" { 289 | s.variadic = m 290 | } 291 | } 292 | s.NotNil(s.model, "User struct should be defined") 293 | } 294 | 295 | func (s *ModelSuite) TestModel() { 296 | s.Equal("__user", s.model.Alias()) 297 | s.Equal("users", s.model.Table) 298 | s.Equal("User", s.model.Name) 299 | s.Equal("UserStore", s.model.StoreName) 300 | s.Equal("UserQuery", s.model.QueryName) 301 | s.Equal("UserResultSet", s.model.ResultSetName) 302 | } 303 | 304 | func (s *ModelSuite) TestCtor() { 305 | s.Equal("id kallax.ULID, username string, email string, websites []string", s.model.CtorArgs()) 306 | s.Equal("id, username, email, websites", s.model.CtorArgVars()) 307 | s.Equal("(record *User, err error)", s.model.CtorReturns()) 308 | s.Equal("record, err", s.model.CtorRetVars()) 309 | } 310 | 311 | func (s *ModelSuite) TestCtor_Variadic() { 312 | s.Equal("bar string, foo ...string", s.variadic.CtorArgs()) 313 | s.Equal("bar, foo...", s.variadic.CtorArgVars()) 314 | s.Equal("(record *Variadic)", s.variadic.CtorReturns()) 315 | s.Equal("record", s.variadic.CtorRetVars()) 316 | } 317 | 318 | func (s *ModelSuite) TestModelValidate() { 319 | require := s.Require() 320 | 321 | id := s.model.ID 322 | m := &Model{Name: "Foo", Table: "foo", ID: id} 323 | m.Fields = 
[]*Field{ 324 | mkField("ID", "", ""), 325 | inline(mkField("Nested", "", "", inline( 326 | mkField("Deep", "", "", mkField("ID", "", "")), 327 | ))), 328 | } 329 | require.Error(m.Validate(), "should return error") 330 | 331 | m.Fields = []*Field{ 332 | mkField("ID", "", ""), 333 | inline(mkField("Nested", "", "", mkField("Foo", "", ""))), 334 | } 335 | require.NoError(m.Validate(), "should not return error") 336 | 337 | m.ID = nil 338 | require.Error(m.Validate(), "should return error") 339 | 340 | m.ID = s.model.Fields[2] 341 | require.Error(m.Validate(), "should return error") 342 | 343 | m.ID = id 344 | m.Table = "" 345 | require.Error(m.Validate(), "should return error") 346 | } 347 | 348 | func (s *ModelSuite) TestString() { 349 | s.Equal("\"Variadic\" [3 Field(s)] [Events: []]", s.variadic.String()) 350 | s.Equal("\"User\" [7 Field(s)] [Events: []]", s.model.String()) 351 | } 352 | 353 | func TestFieldForeignKey(t *testing.T) { 354 | r := require.New(t) 355 | m := &Model{Name: "Foo", Table: "bar", Type: "foo.Foo"} 356 | 357 | cases := []struct { 358 | tag string 359 | inverse bool 360 | typ string 361 | expected string 362 | }{ 363 | {`fk:""`, false, "", "foo_id"}, 364 | {`fk:"foo_bar_baz"`, false, "", "foo_bar_baz"}, 365 | {`fk:",inverse"`, true, "Bar", "bar_id"}, 366 | {`fk:"foos,inverse"`, true, "Bar", "foos"}, 367 | {``, false, "", "foo_id"}, 368 | } 369 | 370 | for _, c := range cases { 371 | f := NewField("", "", reflect.StructTag(c.tag)) 372 | f.Kind = Relationship 373 | f.Model = m 374 | f.Type = c.typ 375 | 376 | r.Equal(c.expected, f.ForeignKey(), "foreign key with tag: %s", c.tag) 377 | r.Equal(c.inverse, f.IsInverse(), "is inverse: %s", c.tag) 378 | } 379 | } 380 | 381 | func TestModelSetFields(t *testing.T) { 382 | r := require.New(t) 383 | cases := []struct { 384 | name string 385 | fields []*Field 386 | err bool 387 | id string 388 | }{ 389 | { 390 | "only one primary key", 391 | []*Field{ 392 | mkField("Foo", "", ""), 393 | mkField("ID", "", 
`pk:""`), 394 | }, 395 | false, 396 | "ID", 397 | }, 398 | { 399 | "multiple primary keys", 400 | []*Field{ 401 | mkField("ID", "", `pk:""`), 402 | mkField("FooID", "", `pk:""`), 403 | }, 404 | true, 405 | "", 406 | }, 407 | { 408 | "primary key defined in model but empty", 409 | []*Field{ 410 | mkField("Model", BaseModel, `pk:""`), 411 | }, 412 | true, 413 | "", 414 | }, 415 | { 416 | "primary key defined in model and non existent", 417 | []*Field{ 418 | mkField("Model", BaseModel, `pk:"foo"`), 419 | mkField("Bar", "", ""), 420 | }, 421 | true, 422 | "", 423 | }, 424 | { 425 | "primary key defined in model", 426 | []*Field{ 427 | mkField("Model", BaseModel, `pk:"foo"`), 428 | mkField("Baz", "", ""), 429 | mkField("Foo", "", ""), 430 | mkField("Bar", "", ""), 431 | }, 432 | false, 433 | "Foo", 434 | }, 435 | } 436 | 437 | for _, c := range cases { 438 | m := new(Model) 439 | err := m.SetFields(c.fields) 440 | if c.err { 441 | r.Error(err, c.name) 442 | } else { 443 | r.NoError(err, c.name) 444 | r.Equal(c.id, m.ID.Name) 445 | } 446 | } 447 | } 448 | 449 | func TestModel(t *testing.T) { 450 | suite.Run(t, new(ModelSuite)) 451 | } 452 | 453 | func TestPkProperties(t *testing.T) { 454 | cases := []struct { 455 | tag string 456 | name string 457 | autoincr bool 458 | isPrimaryKey bool 459 | }{ 460 | {`pk:"bar"`, "bar", false, true}, 461 | {`pk:""`, "", false, true}, 462 | {`pk:"autoincr"`, "", true, true}, 463 | {`pk:",autoincr"`, "", true, true}, 464 | {`bar:"baz" pk:"foo"`, "foo", false, true}, 465 | {`pk:"foo,autoincr"`, "foo", true, true}, 466 | } 467 | 468 | require := require.New(t) 469 | for _, tt := range cases { 470 | name, autoincr, isPrimaryKey := pkProperties(reflect.StructTag(tt.tag)) 471 | require.Equal(tt.name, name, tt.tag) 472 | require.Equal(tt.autoincr, autoincr, tt.tag) 473 | require.Equal(tt.isPrimaryKey, isPrimaryKey, tt.tag) 474 | } 475 | } 476 | 477 | func TestIsUnique(t *testing.T) { 478 | cases := []struct { 479 | tag string 480 | unique bool 
481 | }{ 482 | {``, false}, 483 | {`fk:"foo"`, false}, 484 | {`unique:""`, false}, 485 | {`unique:"true"`, true}, 486 | {`fk:"foo" unique:"true"`, true}, 487 | } 488 | 489 | for _, tt := range cases { 490 | t.Run(tt.tag, func(t *testing.T) { 491 | f := NewField("", "", reflect.StructTag(tt.tag)) 492 | require.Equal(t, tt.unique, f.IsUnique()) 493 | }) 494 | } 495 | } 496 | -------------------------------------------------------------------------------- /kallax.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /model_test.go: -------------------------------------------------------------------------------- 1 | package kallax 2 | 3 | import ( 4 | "sync" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/require" 8 | ) 9 | 10 | func TestUULIDIsEmpty(t *testing.T) { 11 | r := require.New(t) 12 | var id ULID 13 | r.True(id.IsEmpty()) 14 | 15 | id = NewULID() 16 | r.False(id.IsEmpty()) 17 | } 18 | 19 | func TestULID_Value(t *testing.T) { 20 | id := NewULID() 21 | v, _ := id.Value() 22 | require.Equal(t, id.String(), v) 23 | } 24 | 25 | func TestUULID_ThreeNewIDsAreDifferent(t *testing.T) { 26 | r := require.New(t) 27 | 28 | goroutines := 100 29 | ids_per_goroutine := 1000 30 | 31 | ids := make(map[ULID]bool, ids_per_goroutine*goroutines) 32 | m := &sync.Mutex{} 33 | 34 | wg := &sync.WaitGroup{} 35 | wg.Add(goroutines) 36 | for i := 0; i < goroutines; i++ { 37 | go func() { 38 | var oids []ULID 39 | for j := 0; j < ids_per_goroutine; j++ { 40 | oids = append(oids, NewULID()) 41 | } 42 | 43 | m.Lock() 44 | for _, id := range oids { 45 | ids[id] = true 46 | } 47 | m.Unlock() 48 | wg.Done() 49 | }() 50 | } 51 | 52 | wg.Wait() 53 | 54 | r.Equal(goroutines*ids_per_goroutine, len(ids)) 55 | } 56 | 57 | func TestULID_ScanValue(t *testing.T) { 58 | r := require.New(t) 59 | 60 | expected := 
NewULID() 61 | v, err := expected.Value() 62 | r.NoError(err) 63 | 64 | var id ULID 65 | r.NoError(id.Scan(v)) 66 | 67 | r.Equal(expected, id) 68 | r.Equal(expected.String(), id.String()) 69 | 70 | r.NoError(id.Scan([]byte("015af13d-2271-fb69-2dcd-fb24a1fd7dcc"))) 71 | } 72 | 73 | func TestVirtualColumn(t *testing.T) { 74 | r := require.New(t) 75 | record := newModel("", "", 0) 76 | record.virtualColumns = nil 77 | r.Equal(nil, record.VirtualColumn("foo")) 78 | 79 | record.virtualColumns = nil 80 | s := VirtualColumn("foo", record, new(ULID)) 81 | 82 | id := NewULID() 83 | v, err := id.Value() 84 | r.NoError(err) 85 | r.NoError(s.Scan(v)) 86 | r.Len(record.virtualColumns, 1) 87 | r.Equal(&id, record.VirtualColumn("foo")) 88 | 89 | r.Error(s.Scan(nil)) 90 | } 91 | -------------------------------------------------------------------------------- /operators_test.go: -------------------------------------------------------------------------------- 1 | package kallax 2 | 3 | import ( 4 | "database/sql" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/suite" 8 | "gopkg.in/src-d/go-kallax.v1/types" 9 | ) 10 | 11 | type OpsSuite struct { 12 | suite.Suite 13 | db *sql.DB 14 | store *Store 15 | } 16 | 17 | func (s *OpsSuite) SetupTest() { 18 | var err error 19 | s.db, err = openTestDB() 20 | s.Nil(err) 21 | s.store = NewStore(s.db) 22 | } 23 | 24 | func (s *OpsSuite) create(sql string) { 25 | _, err := s.db.Exec(sql) 26 | s.NoError(err) 27 | } 28 | 29 | func (s *OpsSuite) remove(table string) { 30 | _, err := s.db.Exec("DROP TABLE IF EXISTS " + table) 31 | s.NoError(err) 32 | } 33 | 34 | func (s *OpsSuite) TestOperators() { 35 | s.create(`CREATE TABLE model ( 36 | id serial PRIMARY KEY, 37 | name varchar(255) not null, 38 | email varchar(255) not null, 39 | age int not null 40 | )`) 41 | defer s.remove("model") 42 | 43 | customGt := NewOperator(":col: > :arg:") 44 | customIn := NewMultiOperator(":col: IN :arg:") 45 | 46 | cases := []struct { 47 | name string 48 | cond 
Condition 49 | count int64 50 | }{ 51 | {"Eq", Eq(f("name"), "Joe"), 1}, 52 | {"Gt", Gt(f("age"), 1), 2}, 53 | {"customGt", customGt(f("age"), 1), 2}, 54 | {"Lt", Lt(f("age"), 2), 1}, 55 | {"Neq", Neq(f("name"), "Joe"), 2}, 56 | {"Like upper", Like(f("name"), "J%"), 2}, 57 | {"Like lower", Like(f("name"), "j%"), 0}, 58 | {"Ilike upper", Ilike(f("name"), "J%"), 2}, 59 | {"Ilike lower", Ilike(f("name"), "j%"), 2}, 60 | {"SimilarTo", SimilarTo(f("name"), "An{2}a"), 1}, 61 | {"NotSimilarTo", NotSimilarTo(f("name"), "An{2}a"), 2}, 62 | {"GtOrEq", GtOrEq(f("age"), 2), 2}, 63 | {"LtOrEq", LtOrEq(f("age"), 3), 3}, 64 | {"Not", Not(Eq(f("name"), "Joe")), 2}, 65 | {"And", And(Neq(f("name"), "Joe"), Gt(f("age"), 1)), 2}, 66 | {"Or", Or(Neq(f("name"), "Joe"), Eq(f("age"), 1)), 3}, 67 | {"In", In(f("name"), "Joe", "Jane"), 2}, 68 | {"customIn", customIn(f("name"), "Joe", "Jane"), 2}, 69 | {"NotIn", NotIn(f("name"), "Joe", "Jane"), 1}, 70 | {"MatchRegexCase upper", MatchRegexCase(f("name"), "J.*"), 2}, 71 | {"MatchRegexCase lower", MatchRegexCase(f("name"), "j.*"), 0}, 72 | {"MatchRegex upper", MatchRegex(f("name"), "J.*"), 2}, 73 | {"MatchRegex lower", MatchRegex(f("name"), "j.*"), 2}, 74 | {"NotMatchRegexCase upper", NotMatchRegexCase(f("name"), "J.*"), 1}, 75 | {"NotMatchRegexCase lower", NotMatchRegexCase(f("name"), "j.*"), 3}, 76 | {"NotMatchRegex upper", NotMatchRegex(f("name"), "J.*"), 1}, 77 | {"NotMatchRegex lower", NotMatchRegex(f("name"), "j.*"), 1}, 78 | } 79 | 80 | s.Nil(s.store.Insert(ModelSchema, newModel("Joe", "", 1))) 81 | s.Nil(s.store.Insert(ModelSchema, newModel("Jane", "", 2))) 82 | s.Nil(s.store.Insert(ModelSchema, newModel("Anna", "", 2))) 83 | 84 | for _, c := range cases { 85 | q := NewBaseQuery(ModelSchema) 86 | q.Where(c.cond) 87 | 88 | s.Equal(c.count, s.store.Debug().MustCount(q), c.name) 89 | } 90 | } 91 | 92 | func (s *OpsSuite) TestArrayOperators() { 93 | s.create(`CREATE TABLE slices ( 94 | id uuid PRIMARY KEY, 95 | elems bigint[] 96 | )`) 97 | 
defer s.remove("slices") 98 | 99 | f := f("elems") 100 | 101 | cases := []struct { 102 | name string 103 | cond Condition 104 | ok bool 105 | }{ 106 | {"ArrayEq", ArrayEq(f, 1, 2, 3), true}, 107 | {"ArrayEq fail", ArrayEq(f, 1, 2, 2), false}, 108 | {"ArrayNotEq", ArrayNotEq(f, 1, 2, 2), true}, 109 | {"ArrayNotEq fail", ArrayNotEq(f, 1, 2, 3), false}, 110 | {"ArrayGt", ArrayGt(f, 1, 2, 2), true}, 111 | {"ArrayGt all eq", ArrayGt(f, 1, 2, 3), false}, 112 | {"ArrayGt some lt", ArrayGt(f, 1, 3, 1), false}, 113 | {"ArrayLt", ArrayLt(f, 1, 2, 4), true}, 114 | {"ArrayLt all eq", ArrayLt(f, 1, 2, 3), false}, 115 | {"ArrayLt some gt", ArrayLt(f, 1, 1, 4), false}, 116 | {"ArrayGtOrEq", ArrayGtOrEq(f, 1, 2, 2), true}, 117 | {"ArrayGtOrEq all eq", ArrayGtOrEq(f, 1, 2, 3), true}, 118 | {"ArrayGtOrEq some lt", ArrayGtOrEq(f, 1, 3, 1), false}, 119 | {"ArrayLtOrEq", ArrayLtOrEq(f, 1, 2, 4), true}, 120 | {"ArrayLtOrEq all eq", ArrayLtOrEq(f, 1, 2, 3), true}, 121 | {"ArrayLtOrEq some gt", ArrayLtOrEq(f, 1, 1, 4), false}, 122 | {"ArrayContains", ArrayContains(f, 1, 2), true}, 123 | {"ArrayContains fail", ArrayContains(f, 5, 6), false}, 124 | {"ArrayContainedBy", ArrayContainedBy(f, 1, 2, 3, 5, 6), true}, 125 | {"ArrayContainedBy fail", ArrayContainedBy(f, 1, 2, 5, 6), false}, 126 | {"ArrayOverlap", ArrayOverlap(f, 5, 1, 7), true}, 127 | {"ArrayOverlap fail", ArrayOverlap(f, 6, 7, 8, 9), false}, 128 | } 129 | 130 | _, err := s.db.Exec("INSERT INTO slices (id,elems) VALUES ($1, $2)", NewULID(), types.Slice([]int64{1, 2, 3})) 131 | s.NoError(err) 132 | 133 | for _, c := range cases { 134 | q := NewBaseQuery(SlicesSchema) 135 | q.Where(c.cond) 136 | cnt, err := s.store.Count(q) 137 | s.NoError(err, c.name) 138 | s.Equal(c.ok, cnt > 0, "success: %s", c.name) 139 | } 140 | } 141 | 142 | type object map[string]interface{} 143 | 144 | type array []interface{} 145 | 146 | func (s *OpsSuite) TestJSONOperators() { 147 | s.create(`CREATE TABLE jsons ( 148 | id uuid primary key, 149 | elem jsonb 
150 | )`) 151 | defer s.remove("jsons") 152 | 153 | f := f("elem") 154 | cases := []struct { 155 | name string 156 | cond Condition 157 | n int64 158 | }{ 159 | {"JSONIsObject", JSONIsObject(f), 2}, 160 | {"JSONIsArray", JSONIsArray(f), 3}, 161 | {"JSONContains", JSONContains(f, object{"a": 1}), 1}, 162 | {"JSONContainedBy", JSONContainedBy(f, object{ 163 | "a": 1, 164 | "b": 2, 165 | "c": 3, 166 | "d": 1, 167 | }), 1}, 168 | {"JSONContainsAnyKey with array match", JSONContainsAnyKey(f, "a", "c"), 3}, 169 | {"JSONContainsAnyKey", JSONContainsAnyKey(f, "b", "e"), 2}, 170 | {"JSONContainsAllKeys with array match", JSONContainsAllKeys(f, "a", "c"), 3}, 171 | {"JSONContainsAllKeys", JSONContainsAllKeys(f, "b", "e"), 0}, 172 | {"JSONContainsAllKeys only objects", JSONContainsAllKeys(f, "a", "b", "c"), 2}, 173 | {"JSONContainsAny", JSONContainsAny(f, 174 | object{"a": 1}, 175 | object{"a": true}, 176 | ), 2}, 177 | } 178 | 179 | var records = []interface{}{ 180 | array{"a", "c", "d"}, 181 | object{ 182 | "a": true, 183 | "b": array{1, 2, 3}, 184 | "c": object{"d": "foo"}, 185 | }, 186 | object{ 187 | "a": 1, 188 | "b": 2, 189 | "c": 3, 190 | }, 191 | array{.5, 1., 1.5}, 192 | array{1, 2, 3}, 193 | } 194 | 195 | for _, r := range records { 196 | _, err := s.db.Exec("INSERT INTO jsons (id,elem) VALUES ($1, $2)", NewULID(), types.JSON(r)) 197 | s.NoError(err) 198 | } 199 | 200 | for _, c := range cases { 201 | q := NewBaseQuery(JsonsSchema) 202 | q.Where(c.cond) 203 | cnt, err := s.store.Count(q) 204 | s.NoError(err, c.name) 205 | s.Equal(c.n, cnt, "should retrieve %d records: %s", c.n, c.name) 206 | } 207 | } 208 | 209 | func TestOperators(t *testing.T) { 210 | suite.Run(t, new(OpsSuite)) 211 | } 212 | 213 | var SlicesSchema = &BaseSchema{ 214 | alias: "_sl", 215 | table: "slices", 216 | id: f("id"), 217 | columns: []SchemaField{ 218 | f("id"), 219 | f("elems"), 220 | }, 221 | } 222 | 223 | var JsonsSchema = &BaseSchema{ 224 | alias: "_js", 225 | table: "jsons", 226 | id: 
f("id"), 227 | columns: []SchemaField{ 228 | f("id"), 229 | f("elem"), 230 | }, 231 | } 232 | -------------------------------------------------------------------------------- /query.go: -------------------------------------------------------------------------------- 1 | package kallax 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | 7 | "github.com/Masterminds/squirrel" 8 | ) 9 | 10 | var ( 11 | // ErrManyToManyNotSupported is returned when a many to many relationship 12 | // is added to a query. 13 | ErrManyToManyNotSupported = errors.New("kallax: many to many relationships are not supported") 14 | ) 15 | 16 | // Query is the common interface all queries must satisfy. The basic abilities 17 | // of a query are compiling themselves to something executable and return 18 | // some query settings. 19 | type Query interface { 20 | compile() ([]string, squirrel.SelectBuilder) 21 | getRelationships() []Relationship 22 | isReadOnly() bool 23 | // Schema returns the schema of the query model. 24 | Schema() Schema 25 | // GetOffset returns the number of skipped rows in the query. 26 | GetOffset() uint64 27 | // GetLimit returns the max number of rows retrieved by the query. 28 | GetLimit() uint64 29 | // GetBatchSize returns the number of rows retrieved by the store per 30 | // batch. This is only used and has effect on queries with 1:N 31 | // relationships. 
32 | GetBatchSize() uint64 33 | } 34 | 35 | type columnSet []SchemaField 36 | 37 | func (cs columnSet) contains(col SchemaField) bool { 38 | for _, c := range cs { 39 | if c.String() == col.String() { 40 | return true 41 | } 42 | } 43 | return false 44 | } 45 | 46 | func (cs *columnSet) add(cols ...SchemaField) { 47 | for _, col := range cols { 48 | cs.addCol(col) 49 | } 50 | } 51 | 52 | func (cs *columnSet) addCol(col SchemaField) { 53 | if !cs.contains(col) { 54 | *cs = append(*cs, col) 55 | } 56 | } 57 | 58 | func (cs *columnSet) remove(cols ...SchemaField) { 59 | var newSet = make(columnSet, 0, len(*cs)) 60 | toRemove := columnSet(cols) 61 | for _, col := range *cs { 62 | if !toRemove.contains(col) { 63 | newSet = append(newSet, col) 64 | } 65 | } 66 | *cs = newSet 67 | } 68 | 69 | func (cs columnSet) copy() []SchemaField { 70 | var result = make(columnSet, len(cs)) 71 | for i, col := range cs { 72 | result[i] = col 73 | } 74 | return result 75 | } 76 | 77 | // BaseQuery is a generic query builder to build queries programmatically. 78 | type BaseQuery struct { 79 | schema Schema 80 | columns columnSet 81 | excludedColumns columnSet 82 | // relationColumns contains the qualified names of the columns selected by the 1:1 83 | // relationships 84 | relationColumns []string 85 | relationships []Relationship 86 | builder squirrel.SelectBuilder 87 | 88 | selectChanged bool 89 | batchSize uint64 90 | offset uint64 91 | limit uint64 92 | } 93 | 94 | // NewBaseQuery creates a new BaseQuery for querying the table of the given schema. 95 | func NewBaseQuery(schema Schema) *BaseQuery { 96 | return &BaseQuery{ 97 | builder: squirrel.StatementBuilder. 98 | PlaceholderFormat(squirrel.Dollar). 99 | Select(). 100 | From(schema.Table() + " " + schema.Alias()), 101 | columns: columnSet(schema.Columns()), 102 | batchSize: 50, 103 | schema: schema, 104 | } 105 | } 106 | 107 | // Schema returns the Schema of the query. 
108 | func (q *BaseQuery) Schema() Schema { 109 | return q.schema 110 | } 111 | 112 | func (q *BaseQuery) isReadOnly() bool { 113 | return q.selectChanged 114 | } 115 | 116 | // Select adds the given columns to the list of selected columns in the query. 117 | func (q *BaseQuery) Select(columns ...SchemaField) { 118 | if !q.selectChanged { 119 | q.columns = columnSet{} 120 | q.selectChanged = true 121 | } 122 | 123 | q.excludedColumns.remove(columns...) 124 | q.columns.add(columns...) 125 | } 126 | 127 | // SelectNot adds the given columns to the list of excluded columns in the query. 128 | func (q *BaseQuery) SelectNot(columns ...SchemaField) { 129 | q.excludedColumns.add(columns...) 130 | } 131 | 132 | // Copy returns an identical copy of the query. BaseQuery is mutable, that is 133 | // why this method is provided. 134 | func (q *BaseQuery) Copy() *BaseQuery { 135 | return &BaseQuery{ 136 | builder: q.builder, 137 | columns: q.columns.copy(), 138 | excludedColumns: q.excludedColumns.copy(), 139 | relationColumns: q.relationColumns[:], 140 | relationships: q.relationships[:], 141 | selectChanged: q.selectChanged, 142 | batchSize: q.GetBatchSize(), 143 | limit: q.GetLimit(), 144 | offset: q.GetOffset(), 145 | schema: q.schema, 146 | } 147 | } 148 | 149 | func (q *BaseQuery) getRelationships() []Relationship { 150 | return q.relationships 151 | } 152 | 153 | func (q *BaseQuery) selectedColumns() []SchemaField { 154 | var result = make([]SchemaField, 0, len(q.columns)) 155 | for _, col := range q.columns { 156 | if !q.excludedColumns.contains(col) { 157 | result = append(result, col) 158 | } 159 | } 160 | return result 161 | } 162 | 163 | // AddRelation adds a relationship if the given to the query, which is present 164 | // in the given field of the query base schema. A condition to filter can also 165 | // be passed in the case of one to many relationships. 
166 | func (q *BaseQuery) AddRelation(schema Schema, field string, typ RelationshipType, filter Condition) error { 167 | if typ == ManyToMany { 168 | return ErrManyToManyNotSupported 169 | } 170 | 171 | fk, ok := q.schema.ForeignKey(field) 172 | if !ok { 173 | return fmt.Errorf( 174 | "kallax: cannot find foreign key to join tables %s and %s", 175 | q.schema.Table(), schema.Table(), 176 | ) 177 | } 178 | schema = schema.WithAlias(field) 179 | 180 | if typ == OneToOne { 181 | q.join(schema, fk) 182 | } 183 | 184 | q.relationships = append(q.relationships, Relationship{typ, field, schema, filter}) 185 | return nil 186 | } 187 | 188 | func (q *BaseQuery) join(schema Schema, fk *ForeignKey) { 189 | fkCol := fk.QualifiedName(schema) 190 | idCol := q.schema.ID().QualifiedName(q.schema) 191 | if fk.Inverse { 192 | fkCol = schema.ID().QualifiedName(schema) 193 | idCol = fk.QualifiedName(q.schema) 194 | } 195 | 196 | q.builder = q.builder.LeftJoin(fmt.Sprintf( 197 | "%s %s ON (%s = %s)", 198 | schema.Table(), 199 | schema.Alias(), 200 | fkCol, 201 | idCol, 202 | )) 203 | 204 | for _, col := range schema.Columns() { 205 | q.relationColumns = append( 206 | q.relationColumns, 207 | col.QualifiedName(schema), 208 | ) 209 | } 210 | } 211 | 212 | // Order adds the given order clauses to the list of columns to order the 213 | // results by. 214 | func (q *BaseQuery) Order(cols ...ColumnOrder) { 215 | var c = make([]string, len(cols)) 216 | for i, v := range cols { 217 | c[i] = v.ToSql(q.schema) 218 | } 219 | q.builder = q.builder.OrderBy(c...) 220 | } 221 | 222 | // BatchSize sets the batch size. 223 | func (q *BaseQuery) BatchSize(size uint64) { 224 | q.batchSize = size 225 | } 226 | 227 | // GetBatchSize returns the number of rows retrieved per batch while retrieving 228 | // 1:N relationships. 229 | func (q *BaseQuery) GetBatchSize() uint64 { 230 | return q.batchSize 231 | } 232 | 233 | // Limit sets the max number of rows to retrieve. 
234 | func (q *BaseQuery) Limit(n uint64) { 235 | q.limit = n 236 | } 237 | 238 | // GetLimit returns the max number of rows to retrieve. 239 | func (q *BaseQuery) GetLimit() uint64 { 240 | return q.limit 241 | } 242 | 243 | // Offset sets the number of rows to skip. 244 | func (q *BaseQuery) Offset(n uint64) { 245 | q.offset = n 246 | } 247 | 248 | // GetOffset returns the number of rows to skip. 249 | func (q *BaseQuery) GetOffset() uint64 { 250 | return q.offset 251 | } 252 | 253 | // Where adds a new condition to filter the query. All conditions added are 254 | // concatenated with "and". 255 | // q.Where(Eq(NameColumn, "foo")) 256 | // q.Where(Gt(AgeColumn, 18)) 257 | // // ... WHERE name = "foo" AND age > 18 258 | func (q *BaseQuery) Where(cond Condition) { 259 | q.builder = q.builder.Where(cond(q.schema)) 260 | } 261 | 262 | // compile returns the selected column names and the select builder. 263 | func (q *BaseQuery) compile() ([]string, squirrel.SelectBuilder) { 264 | columns := q.selectedColumns() 265 | var ( 266 | qualifiedColumns = make([]string, len(columns)) 267 | columnNames = make([]string, len(columns)) 268 | ) 269 | 270 | for i := range columns { 271 | qualifiedColumns[i] = columns[i].QualifiedName(q.schema) 272 | columnNames[i] = columns[i].String() 273 | } 274 | return columnNames, q.builder.Columns( 275 | append(qualifiedColumns, q.relationColumns...)..., 276 | ) 277 | } 278 | 279 | // String returns the SQL generated by the query. If the query is malformed, 280 | // it will return an empty string, as errors compiling the SQL are ignored. 281 | func (q *BaseQuery) String() string { 282 | _, builder := q.compile() 283 | sql, _, _ := builder.ToSql() 284 | return sql 285 | } 286 | 287 | // ToSql returns the SQL generated by the query, the query arguments, and 288 | // any error returned during the compile process. 
289 | func (q *BaseQuery) ToSql() (string, []interface{}, error) { 290 | _, builder := q.compile() 291 | return builder.ToSql() 292 | } 293 | 294 | // ColumnOrder represents a column name with its order. 295 | type ColumnOrder interface { 296 | // ToSql returns the SQL representation of the column with its order. 297 | ToSql(Schema) string 298 | isColumnOrder() 299 | } 300 | 301 | type colOrder struct { 302 | order string 303 | col SchemaField 304 | } 305 | 306 | // ToSql returns the SQL representation of the column with its order. 307 | func (o *colOrder) ToSql(schema Schema) string { 308 | return fmt.Sprintf("%s %s", o.col.QualifiedName(schema), o.order) 309 | } 310 | func (colOrder) isColumnOrder() {} 311 | 312 | const ( 313 | asc = "ASC" 314 | desc = "DESC" 315 | ) 316 | 317 | // Asc returns a column ordered by ascending order. 318 | func Asc(col SchemaField) ColumnOrder { 319 | return &colOrder{asc, col} 320 | } 321 | 322 | // Desc returns a column ordered by descending order. 323 | func Desc(col SchemaField) ColumnOrder { 324 | return &colOrder{desc, col} 325 | } 326 | -------------------------------------------------------------------------------- /query_test.go: -------------------------------------------------------------------------------- 1 | package kallax 2 | 3 | import ( 4 | "testing" 5 | "unsafe" 6 | 7 | "github.com/stretchr/testify/suite" 8 | ) 9 | 10 | func TestBaseQuery(t *testing.T) { 11 | suite.Run(t, new(QuerySuite)) 12 | } 13 | 14 | type QuerySuite struct { 15 | suite.Suite 16 | q *BaseQuery 17 | } 18 | 19 | func (s *QuerySuite) SetupTest() { 20 | s.q = NewBaseQuery(ModelSchema) 21 | } 22 | 23 | func (s *QuerySuite) TestSelect() { 24 | s.q.Select(f("a"), f("b"), f("c")) 25 | s.Equal(columnSet{f("a"), f("b"), f("c")}, s.q.columns) 26 | } 27 | 28 | func (s *QuerySuite) TestSelectNot() { 29 | s.q.SelectNot(f("a"), f("b"), f("c")) 30 | s.Equal(columnSet{f("a"), f("b"), f("c")}, s.q.excludedColumns) 31 | } 32 | 33 | func (s *QuerySuite) 
TestSelectNotSelectSelectNot() { 34 | s.q.SelectNot(f("a"), f("b")) 35 | s.q.Select(f("a"), f("c")) 36 | s.q.SelectNot(f("a")) 37 | s.Equal([]SchemaField{f("c")}, s.q.selectedColumns()) 38 | } 39 | 40 | func (s *QuerySuite) TestSelectSelectNot() { 41 | s.q.Select(f("a"), f("c")) 42 | s.q.SelectNot(f("a")) 43 | s.Equal([]SchemaField{f("c")}, s.q.selectedColumns()) 44 | } 45 | 46 | func (s *QuerySuite) TestCopy() { 47 | s.q.Select(f("a"), f("b"), f("c")) 48 | s.q.SelectNot(f("foo")) 49 | s.q.BatchSize(30) 50 | s.q.Limit(2) 51 | s.q.Offset(30) 52 | copy := s.q.Copy() 53 | 54 | s.Equal(s.q, copy) 55 | s.NotEqual(unsafe.Pointer(s.q), unsafe.Pointer(copy)) 56 | } 57 | 58 | func (s *QuerySuite) TestSelectedColumns() { 59 | s.q.Select(f("a"), f("b"), f("c")) 60 | s.q.SelectNot(f("b")) 61 | s.Equal([]SchemaField{f("a"), f("c")}, s.q.selectedColumns()) 62 | } 63 | 64 | func (s *QuerySuite) TestOrder() { 65 | s.q.Select(f("foo")) 66 | s.q.Order(Asc(f("bar"))) 67 | s.q.Order(Desc(f("baz"))) 68 | 69 | s.assertSql("SELECT __model.foo FROM model __model ORDER BY __model.bar ASC, __model.baz DESC") 70 | } 71 | 72 | func (s *QuerySuite) TestWhere() { 73 | s.q.Select(f("foo")) 74 | s.q.Where(Eq(f("foo"), 5)) 75 | s.q.Where(Eq(f("bar"), "baz")) 76 | 77 | s.assertSql("SELECT __model.foo FROM model __model WHERE __model.foo = $1 AND __model.bar = $2") 78 | } 79 | 80 | func (s *QuerySuite) TestString() { 81 | s.q.Select(f("foo")) 82 | s.Equal("SELECT __model.foo FROM model __model", s.q.String()) 83 | } 84 | 85 | func (s *QuerySuite) TestToSql() { 86 | s.q.Select(f("foo")) 87 | s.q.Where(Eq(f("foo"), 5)) 88 | s.q.Where(Eq(f("bar"), "baz")) 89 | sql, args, err := s.q.ToSql() 90 | s.Equal("SELECT __model.foo FROM model __model WHERE __model.foo = $1 AND __model.bar = $2", sql) 91 | s.Equal([]interface{}{5, "baz"}, args) 92 | s.Equal(err, nil) 93 | } 94 | 95 | func (s *QuerySuite) TestAddRelation() { 96 | s.Nil(s.q.AddRelation(RelSchema, "rel", OneToOne, nil)) 97 | s.Equal("SELECT 
__model.id, __model.name, __model.email, __model.age, __rel_rel.id, __rel_rel.model_id, __rel_rel.foo FROM model __model LEFT JOIN rel __rel_rel ON (__rel_rel.model_id = __model.id)", s.q.String()) 98 | } 99 | 100 | func (s *QuerySuite) TestAddRelation_Inverse() { 101 | s.Nil(s.q.AddRelation(RelSchema, "rel_inv", OneToOne, nil)) 102 | s.Equal("SELECT __model.id, __model.name, __model.email, __model.age, __rel_rel_inv.id, __rel_rel_inv.model_id, __rel_rel_inv.foo FROM model __model LEFT JOIN rel __rel_rel_inv ON (__rel_rel_inv.id = __model.model_id)", s.q.String()) 103 | } 104 | 105 | func (s *QuerySuite) TestAddRelation_ManyToMany() { 106 | err := s.q.AddRelation(RelSchema, "rel", ManyToMany, nil) 107 | s.Equal(ErrManyToManyNotSupported, err) 108 | } 109 | 110 | func (s *QuerySuite) TestAddRelation_FKNotFound() { 111 | s.Error(s.q.AddRelation(RelSchema, "fooo", OneToOne, nil)) 112 | } 113 | 114 | func (s *QuerySuite) assertSql(sql string) { 115 | _, builder := s.q.compile() 116 | result, _, err := builder.ToSql() 117 | s.Nil(err) 118 | s.Equal(sql, result) 119 | } 120 | -------------------------------------------------------------------------------- /resultset.go: -------------------------------------------------------------------------------- 1 | package kallax 2 | 3 | import ( 4 | "database/sql" 5 | "errors" 6 | "io" 7 | 8 | "gopkg.in/src-d/go-kallax.v1/types" 9 | ) 10 | 11 | // ResultSet is the common interface all result sets need to implement. 12 | type ResultSet interface { 13 | // RawScan allows for raw scanning of fields in a result set. 14 | RawScan(...interface{}) error 15 | // Next moves the pointer to the next item in the result set and returns 16 | // if there was any. 17 | Next() bool 18 | // Get returns the next record of the given schema. 
19 | Get(Schema) (Record, error) 20 | io.Closer 21 | } 22 | 23 | // ErrRawScan is an error returned when a the `Scan` method of `ResultSet` 24 | // is called with a `ResultSet` created as a result of a `RawQuery`, which is 25 | // not allowed. 26 | var ErrRawScan = errors.New("kallax: result set comes from raw query, use RawScan instead") 27 | 28 | // ErrRawScanBatching is an error returned when the `RawScan` method is used 29 | // with a batching result set. 30 | var ErrRawScanBatching = errors.New("kallax: cannot perform a raw scan on a batching result set") 31 | 32 | // BaseResultSet is a generic collection of rows. 33 | type BaseResultSet struct { 34 | relationships []Relationship 35 | columns []string 36 | readOnly bool 37 | *sql.Rows 38 | } 39 | 40 | // NewResultSet creates a new result set with the given rows and columns. 41 | // It is mandatory that all column names are in the same order and are exactly 42 | // equal to the ones in the query that produced the rows. 43 | func NewResultSet(rows *sql.Rows, readOnly bool, relationships []Relationship, columns ...string) *BaseResultSet { 44 | return &BaseResultSet{ 45 | relationships, 46 | columns, 47 | readOnly, 48 | rows, 49 | } 50 | } 51 | 52 | // Get returns the next record in the schema. 53 | func (rs *BaseResultSet) Get(schema Schema) (Record, error) { 54 | record := schema.New() 55 | if err := rs.Scan(record); err != nil { 56 | return nil, err 57 | } 58 | return record, nil 59 | } 60 | 61 | // Scan fills the column fields of the given value with the current row. 
62 | func (rs *BaseResultSet) Scan(record Record) error { 63 | if len(rs.columns) == 0 { 64 | return ErrRawScan 65 | } 66 | 67 | var ( 68 | relationships = make([]Record, len(rs.relationships)) 69 | pointers = make([]interface{}, len(rs.columns)) 70 | ) 71 | 72 | for i, col := range rs.columns { 73 | ptr, err := record.ColumnAddress(col) 74 | if err != nil { 75 | return err 76 | } 77 | 78 | pointers[i] = ptr 79 | } 80 | 81 | for i, r := range rs.relationships { 82 | rec, err := record.NewRelationshipRecord(r.Field) 83 | if err != nil { 84 | return err 85 | } 86 | 87 | for _, col := range r.Schema.Columns() { 88 | ptr, err := rec.ColumnAddress(col.String()) 89 | if err != nil { 90 | return err 91 | } 92 | pointers = append(pointers, types.Nullable(ptr)) 93 | } 94 | 95 | relationships[i] = rec 96 | } 97 | 98 | if err := rs.Rows.Scan(pointers...); err != nil { 99 | return err 100 | } 101 | 102 | for i, r := range rs.relationships { 103 | relationships[i].setPersisted() 104 | relationships[i].setWritable(true) 105 | err := record.SetRelationship(r.Field, relationships[i]) 106 | if err != nil { 107 | return err 108 | } 109 | } 110 | 111 | record.setWritable(!rs.readOnly) 112 | record.setPersisted() 113 | return nil 114 | } 115 | 116 | // RowScan copies the columns in the current row into the values pointed at by 117 | // dest. The number of values in dest must be the same as the number of columns 118 | // selected in the query. 119 | func (rs *BaseResultSet) RawScan(dest ...interface{}) error { 120 | return rs.Rows.Scan(dest...) 121 | } 122 | 123 | // NewBatchingResultSet returns a new result set that performs batching 124 | // underneath. 125 | func NewBatchingResultSet(runner *batchQueryRunner) *BatchingResultSet { 126 | return &BatchingResultSet{runner: runner} 127 | } 128 | 129 | // BatchingResultSet is a result set that retrieves all the items up to the 130 | // batch size set in the query. 
131 | // If there are 1:N relationships, it collects all the identifiers of 132 | // those records, retrieves all the rows matching them in the table of the 133 | // the N end, and assigns them to their correspondent to the record they belong 134 | // to. 135 | // It will continue doing this process until no more rows are returned by the 136 | // query. 137 | // This minimizes the number of queries and operations to perform in order to 138 | // retrieve a set of results and their relationships. 139 | type BatchingResultSet struct { 140 | runner *batchQueryRunner 141 | last Record 142 | lastErr error 143 | } 144 | 145 | // Next advances the internal index of the fetched records in one. 146 | // If there are no fetched records, will fetch the next batch. 147 | // It will return false when there are no more rows. 148 | func (rs *BatchingResultSet) Next() bool { 149 | rs.last, rs.lastErr = rs.runner.next() 150 | if rs.lastErr == errNoMoreRows { 151 | return false 152 | } 153 | 154 | return true 155 | } 156 | 157 | // Get returns the next processed record and the last error occurred. 158 | // Even though it accepts a schema, it is ignored, as the result set is 159 | // already aware of it. This is here just to be able to imeplement the 160 | // ResultSet interface. 161 | func (rs *BatchingResultSet) Get(_ Schema) (Record, error) { 162 | return rs.last, rs.lastErr 163 | } 164 | 165 | // Close will do nothing, as the internal result sets used by this are closed 166 | // when the rows at fetched. It will never throw an error. 167 | func (rs *BatchingResultSet) Close() error { 168 | return nil 169 | } 170 | 171 | // RawScan will always throw an error, as this is not a supported operation of 172 | // a batching result set. 
173 | func (rs *BatchingResultSet) RawScan(_ ...interface{}) error { 174 | return ErrRawScanBatching 175 | } 176 | -------------------------------------------------------------------------------- /schema.go: -------------------------------------------------------------------------------- 1 | package kallax 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | // Schema represents a table schema in the database. Contains some information 9 | // like the table name, its columns, its identifier and so on. 10 | type Schema interface { 11 | // Alias returns the name of the alias used in queries for this schema. 12 | Alias() string 13 | // Table returns the table name. 14 | Table() string 15 | // ID returns the name of the identifier of the table. 16 | ID() SchemaField 17 | // Columns returns the list of columns in the schema. 18 | Columns() []SchemaField 19 | // ForeignKey returns the name of the foreign key of the given model field. 20 | ForeignKey(string) (*ForeignKey, bool) 21 | // WithAlias returns a new schema with the given string added to the 22 | // default alias. 23 | // Calling WithAlias on a schema returned by WithAlias not return a 24 | // schema based on the child, but another based on the parent. 25 | WithAlias(string) Schema 26 | // New creates a new record with the given schema. 27 | New() Record 28 | isPrimaryKeyAutoIncrementable() bool 29 | } 30 | 31 | // BaseSchema is the basic implementation of Schema. 32 | type BaseSchema struct { 33 | alias string 34 | table string 35 | foreignKeys ForeignKeys 36 | id SchemaField 37 | columns []SchemaField 38 | constructor RecordConstructor 39 | autoIncr bool 40 | } 41 | 42 | // RecordConstructor is a function that creates a record. 43 | type RecordConstructor func() Record 44 | 45 | // NewBaseSchema creates a new schema with the given table, alias, identifier 46 | // and columns. 
47 | func NewBaseSchema(table, alias string, id SchemaField, fks ForeignKeys, ctor RecordConstructor, autoIncr bool, columns ...SchemaField) *BaseSchema { 48 | return &BaseSchema{ 49 | alias: alias, 50 | table: table, 51 | foreignKeys: fks, 52 | id: id, 53 | columns: columns, 54 | constructor: ctor, 55 | autoIncr: autoIncr, 56 | } 57 | } 58 | 59 | func (s *BaseSchema) Alias() string { return s.alias } 60 | func (s *BaseSchema) Table() string { return s.table } 61 | func (s *BaseSchema) ID() SchemaField { return s.id } 62 | func (s *BaseSchema) Columns() []SchemaField { return s.columns } 63 | func (s *BaseSchema) ForeignKey(field string) (*ForeignKey, bool) { 64 | k, ok := s.foreignKeys[field] 65 | return k, ok 66 | } 67 | func (s *BaseSchema) WithAlias(field string) Schema { 68 | return &aliasSchema{s, field} 69 | } 70 | func (s *BaseSchema) New() Record { 71 | return s.constructor() 72 | } 73 | func (s *BaseSchema) isPrimaryKeyAutoIncrementable() bool { return s.autoIncr } 74 | 75 | type aliasSchema struct { 76 | *BaseSchema 77 | alias string 78 | } 79 | 80 | func (s *aliasSchema) Alias() string { 81 | return fmt.Sprintf("%s_%s", s.BaseSchema.Alias(), s.alias) 82 | } 83 | 84 | // ForeignKeys is a mapping between relationships and their foreign key field. 85 | type ForeignKeys map[string]*ForeignKey 86 | 87 | // SchemaField is a named field in the table schema. 88 | type SchemaField interface { 89 | isSchemaField() 90 | // String returns the string representation of the field. That is, its name. 91 | String() string 92 | // QualifiedString returns the name of the field qualified by the alias of 93 | // the given schema. 94 | QualifiedName(Schema) string 95 | } 96 | 97 | // BaseSchemaField is a basic schema field with name. 98 | type BaseSchemaField struct { 99 | name string 100 | } 101 | 102 | // NewSchemaField creates a new schema field with the given name. 
103 | func NewSchemaField(name string) SchemaField { 104 | return &BaseSchemaField{name} 105 | } 106 | 107 | func (*BaseSchemaField) isSchemaField() {} 108 | 109 | func (f BaseSchemaField) String() string { 110 | return f.name 111 | } 112 | 113 | func (f *BaseSchemaField) QualifiedName(schema Schema) string { 114 | alias := schema.Alias() 115 | if alias != "" { 116 | return fmt.Sprintf("%s.%s", alias, f.name) 117 | } 118 | return f.name 119 | } 120 | 121 | // ForeignKey contains the schema field of the foreign key and if it is an 122 | // inverse foreign key or not. 123 | type ForeignKey struct { 124 | *BaseSchemaField 125 | Inverse bool 126 | } 127 | 128 | // NewForeignKey creates a new Foreign key with the given name. 129 | func NewForeignKey(name string, inverse bool) *ForeignKey { 130 | return &ForeignKey{&BaseSchemaField{name}, inverse} 131 | } 132 | 133 | // JSONSchemaKey is a SchemaField that represents a key in a JSON object. 134 | type JSONSchemaKey struct { 135 | typ JSONKeyType 136 | field string 137 | paths []string 138 | } 139 | 140 | // JSONSchemaArray is a SchemaField that represents a JSON array. 141 | type JSONSchemaArray struct { 142 | key *JSONSchemaKey 143 | } 144 | 145 | // JSONKeyType is the type of an object key in a JSON. 146 | type JSONKeyType string 147 | 148 | const ( 149 | // JSONAny represents a type that can't be casted. 150 | JSONAny JSONKeyType = "" 151 | // JSONText is a text json type. 152 | JSONText JSONKeyType = "text" 153 | // JSONInt is a numeric json type. 154 | JSONInt JSONKeyType = "bigint" 155 | // JSONFloat is a floating point json type. 156 | JSONFloat JSONKeyType = "decimal" 157 | // JSONBool is a boolean json type. 158 | JSONBool JSONKeyType = "bool" 159 | ) 160 | 161 | // ArraySchemaField is an interface that defines if a field is a JSON 162 | // array. 
163 | type ArraySchemaField interface { 164 | SchemaField 165 | isArraySchemaField() 166 | } 167 | 168 | // NewJSONSchemaArray creates a new SchemaField that is a json array. 169 | func NewJSONSchemaArray(field string, paths ...string) *JSONSchemaArray { 170 | return &JSONSchemaArray{NewJSONSchemaKey(JSONAny, field, paths...)} 171 | } 172 | 173 | func (f *JSONSchemaArray) QualifiedName(schema Schema) string { 174 | return f.key.QualifiedName(schema) 175 | } 176 | 177 | func (f *JSONSchemaArray) String() string { 178 | return f.key.String() 179 | } 180 | 181 | // NewJSONSchemaKey creates a new SchemaField that is a json key. 182 | func NewJSONSchemaKey(typ JSONKeyType, field string, paths ...string) *JSONSchemaKey { 183 | return &JSONSchemaKey{typ, field, paths} 184 | } 185 | 186 | func (f *JSONSchemaKey) QualifiedName(schema Schema) string { 187 | op := "#>" 188 | format := "%s%s %s'{%s}'" 189 | if f.typ == JSONText { 190 | op = "#>>" 191 | } else if f.typ != JSONAny { 192 | op = "#>>" 193 | format = "CAST(%s%s %s'{%s}' as " + string(f.typ) + ")" 194 | } 195 | 196 | var alias string 197 | if schema != nil && schema.Alias() != "" { 198 | alias = schema.Alias() + "." 199 | } 200 | 201 | return fmt.Sprintf(format, alias, f.field, op, strings.Join(f.paths, ",")) 202 | } 203 | 204 | func (f *JSONSchemaKey) String() string { 205 | return f.QualifiedName(nil) 206 | } 207 | 208 | func (*JSONSchemaKey) isSchemaField() {} 209 | func (*JSONSchemaArray) isSchemaField() {} 210 | func (*JSONSchemaArray) isArraySchemaField() {} 211 | 212 | // AtJSONPath returns the schema field to query an arbitrary JSON element at 213 | // the given path. 214 | func AtJSONPath(field SchemaField, typ JSONKeyType, path ...string) SchemaField { 215 | return NewJSONSchemaKey(typ, field.String(), path...) 216 | } 217 | 218 | // Relationship is a relationship with its schema and the field of te relation 219 | // in the record. 
220 | type Relationship struct { 221 | // Type is the kind of relationship this is. 222 | Type RelationshipType 223 | // Field is the field in the record where the relationship is. 224 | Field string 225 | // Schema is the schema of the relationship. 226 | Schema Schema 227 | // Filter establishes the filter to be applied when retrieving rows of the 228 | // relationships. 229 | Filter Condition 230 | } 231 | 232 | // RelationshipType describes the type of the relationship. 233 | type RelationshipType byte 234 | 235 | const ( 236 | // OneToOne is a relationship between one record in a table and another in 237 | // another table. 238 | OneToOne RelationshipType = iota 239 | // OneToMany is a relationship between one record in a table and multiple 240 | // in another table. 241 | OneToMany 242 | // ManyToMany is a relationship between many records on both sides of the 243 | // relationship. 244 | // NOTE: It is not supported yet. 245 | ManyToMany 246 | ) 247 | 248 | func containsRelationshipOfType(rels []Relationship, typ RelationshipType) bool { 249 | for _, r := range rels { 250 | if r.Type == typ { 251 | return true 252 | } 253 | } 254 | return false 255 | } 256 | 257 | // ColumnNames returns the names of the given schema fields. 
258 | func ColumnNames(columns []SchemaField) []string { 259 | var names = make([]string, len(columns)) 260 | for i, v := range columns { 261 | names[i] = v.String() 262 | } 263 | return names 264 | } 265 | -------------------------------------------------------------------------------- /schema_test.go: -------------------------------------------------------------------------------- 1 | package kallax 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | ) 8 | 9 | var emptySchema = NewBaseSchema("", "", nil, nil, nil, false) 10 | 11 | func TestBaseSchemaFieldQualifiedName(t *testing.T) { 12 | var cases = []struct { 13 | name string 14 | field SchemaField 15 | schema Schema 16 | expected string 17 | }{ 18 | {"non empty schema alias", f("foo"), ModelSchema, "__model.foo"}, 19 | {"empty schema alias", f("foo"), emptySchema, "foo"}, 20 | } 21 | 22 | r := require.New(t) 23 | for _, c := range cases { 24 | r.Equal(c.expected, c.field.QualifiedName(c.schema), c.name) 25 | } 26 | } 27 | 28 | func TestJSONSchemaKeyQualifiedName(t *testing.T) { 29 | var cases = []struct { 30 | name string 31 | key *JSONSchemaKey 32 | schema Schema 33 | expected string 34 | }{ 35 | { 36 | "json text key", 37 | NewJSONSchemaKey(JSONText, "foo", "bar", "baz"), 38 | ModelSchema, 39 | "__model.foo #>>'{bar,baz}'", 40 | }, 41 | { 42 | "json int key", 43 | NewJSONSchemaKey(JSONInt, "foo", "bar", "baz"), 44 | ModelSchema, 45 | "CAST(__model.foo #>>'{bar,baz}' as bigint)", 46 | }, 47 | { 48 | "json any key", 49 | NewJSONSchemaKey(JSONAny, "foo", "bar", "baz"), 50 | ModelSchema, 51 | "__model.foo #>'{bar,baz}'", 52 | }, 53 | { 54 | "json key with empty schema", 55 | NewJSONSchemaKey(JSONBool, "foo", "bar", "baz"), 56 | nil, 57 | "CAST(foo #>>'{bar,baz}' as bool)", 58 | }, 59 | } 60 | 61 | r := require.New(t) 62 | for _, c := range cases { 63 | r.Equal(c.expected, c.key.QualifiedName(c.schema), c.name) 64 | } 65 | } 66 | 
-------------------------------------------------------------------------------- /tests/common.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | //go:generate kallax gen 4 | -------------------------------------------------------------------------------- /tests/common_test.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "os" 7 | "reflect" 8 | 9 | "github.com/stretchr/testify/suite" 10 | ) 11 | 12 | var ( 13 | connectionString = "postgres://%s:%s@%s/%s?sslmode=disable" 14 | host = envOrDefault("DBHOST", "0.0.0.0:5432") 15 | database = envOrDefault("DBNAME", "testing") 16 | user = envOrDefault("DBUSER", "testing") 17 | password = envOrDefault("DBPASS", "testing") 18 | ) 19 | // BaseTestSuite is a testify suite that owns the database handle and checks each test for leaked connections. 20 | type BaseTestSuite struct { 21 | suite.Suite 22 | db *sql.DB 23 | schemas []string 24 | tables []string 25 | 26 | // openConnectionsBeforeTest is the number of open connections, 27 | // recorded by SetupTest at the start of each test. 28 | openConnectionsBeforeTest int 29 | } 30 | // NewBaseSuite returns a suite that runs schemas before each test and drops tables after it; both steps are skipped when no tables are given. 31 | func NewBaseSuite(schemas []string, tables ...string) BaseTestSuite { 32 | return BaseTestSuite{ 33 | schemas: schemas, 34 | tables: tables, 35 | } 36 | } 37 | // SetupSuite opens the "postgres" database handle shared by all tests in the suite. 38 | func (s *BaseTestSuite) SetupSuite() { 39 | db, err := sql.Open( 40 | "postgres", 41 | fmt.Sprintf(connectionString, user, password, host, database), 42 | ) 43 | 44 | if err != nil { 45 | panic(fmt.Sprintf("It was unable to connect to the DB.\n%s\n", err)) 46 | } 47 | 48 | // Retain no idle connections, so every connection is closed as soon as it is released. 
49 | // This is required to detect leaked connections, 50 | // because database/sql keeps idle connections in the pool by default.
51 | db.SetMaxIdleConns(0) 52 | 53 | s.db = db 54 | } 55 | 56 | func (s *BaseTestSuite) TearDownSuite() { 57 | s.db.Close() 58 | } 59 | // SetupTest records the open-connection count and, when the suite has tables, runs the schema queries. 60 | func (s *BaseTestSuite) SetupTest() { 61 | // Record the open-connection count so TearDownTest can detect connections leaked by the test. 62 | s.openConnectionsBeforeTest = s.db.Stats().OpenConnections 63 | 64 | if len(s.tables) == 0 { 65 | return 66 | } 67 | 68 | s.QuerySucceed(s.schemas...) 69 | } 70 | // TearDownTest fails the test if it leaked connections, then drops the suite's tables. 71 | func (s *BaseTestSuite) TearDownTest() { 72 | openConnections := s.db.Stats().OpenConnections 73 | leakedConnections := openConnections - s.openConnectionsBeforeTest 74 | if leakedConnections > 0 { 75 | s.Fail(fmt.Sprintf("%d database connections were leaked", leakedConnections)) 76 | } 77 | 78 | if len(s.tables) == 0 { 79 | return 80 | } 81 | var queries []string 82 | for _, t := range s.tables { 83 | queries = append(queries, fmt.Sprintf("DROP TABLE IF EXISTS %s", t)) 84 | } 85 | s.QuerySucceed(queries...) 86 | } 87 | // QuerySucceed executes each query, asserts that it succeeds, and reports whether all of them did. 88 | func (s *BaseTestSuite) QuerySucceed(queries ...string) bool { 89 | success := true 90 | for _, query := range queries { 91 | res, err := s.db.Exec(query) 92 | assert1 := s.NotNil(res, "Resulset should not be empty") 93 | assert2 := s.Nil(err, fmt.Sprintf("%s\nshould succeed but it failed.\n%s\n", query, err)) 94 | if !assert1 || !assert2 { 95 | success = false 96 | } 97 | } 98 | 99 | return success 100 | } 101 | // QueryFails executes each query, asserts that it fails, and reports whether all of them did. 
102 | func (s *BaseTestSuite) QueryFails(queries ...string) bool { 103 | success := true 104 | for _, query := range queries { 105 | res, err := s.db.Exec(query) 106 | assert1 := s.Nil(res, "Resulset should be empty but it was not") 107 | assert2 := s.NotNil(err, fmt.Sprintf("%s\nshould fail but it succeed", query)) 108 | if !assert1 || !assert2 { 109 | success = false 110 | } 111 | } 112 | 113 | return success 114 | } 115 | // resultOrError checks that exactly one of res and err is set; res must be a pointer, and a nil pointer counts as unset. 116 | func (s *BaseTestSuite) resultOrError(res interface{}, err error) bool { 117 | if !reflect.ValueOf(res).Elem().IsValid() { 118 | res = nil 119 | } 120 | 121 | if err
== nil && res == nil { 122 | s.Fail("FindOne should return an error or a document, but nothing was returned") 123 | return false 124 | } 125 | 126 | if err != nil && res != nil { 127 | s.Fail("FindOne should return only an error or a document, but it was returned both") 128 | return false 129 | } 130 | 131 | return true 132 | } 133 | // resultsOrError panics unless res holds a slice, then is meant to check that exactly one of res and err is set. 134 | func (s *BaseTestSuite) resultsOrError(res interface{}, err error) bool { 135 | if reflect.ValueOf(res).Kind() != reflect.Slice { 136 | panic("resultsOrError expects res is a slice") 137 | } 138 | // NOTE(review): a slice stored in an interface{} is never nil (even a nil slice), so res == nil below is always false and res != nil always true; confirm whether a nil-slice check was intended. 139 | if err == nil && res == nil { 140 | s.Fail("FindAll should return an error or a documents, but nothing was returned") 141 | return false 142 | } 143 | 144 | if err != nil && res != nil { 145 | s.Fail("FindAll should return only an error or a documents, but it was returned both") 146 | return false 147 | } 148 | 149 | return true 150 | } 151 | // envOrDefault returns the value of the environment variable key, or def when it is unset or empty. 152 | func envOrDefault(key string, def string) string { 153 | v := os.Getenv(key) 154 | if v == "" { 155 | return def 156 | } 157 | 158 | return v 159 | } 160 | -------------------------------------------------------------------------------- /tests/connection_test.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/suite" 7 | ) 8 | 9 | func TestConnectionSuite(t *testing.T) { 10 | suite.Run(t, new(ConnectionSuite)) 11 | } 12 | 13 | type ConnectionSuite struct { 14 | BaseTestSuite 15 | } 16 | 17 | func (s *ConnectionSuite) TestConnection() { 18 | s.QuerySucceed( 19 | `CREATE TABLE testing (id uuid primary key)`, 20 | `DROP 
TABLE testing`, 21 | `DROP TABLE IF EXISTS testing`, 22 | ) 23 | s.QueryFails(`DROP TABLE _THIS_TABLE_DOES_NOT_EXIST`) 24 | } 25 | -------------------------------------------------------------------------------- /tests/events.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import
"gopkg.in/src-d/go-kallax.v1" 4 | // EventsFixture sets a flag in Checks for each insert/update hook that runs; when MustFailBefore or MustFailAfter is set, the before/after hooks return it instead. 5 | type EventsFixture struct { 6 | kallax.Model `table:"event"` 7 | ID kallax.ULID `pk:""` 8 | Checks map[string]bool 9 | MustFailBefore error 10 | MustFailAfter error 11 | } 12 | 13 | func newEventsFixture() *EventsFixture { 14 | return &EventsFixture{ 15 | ID: kallax.NewULID(), 16 | Checks: make(map[string]bool), 17 | } 18 | } 19 | 20 | func (s *EventsFixture) BeforeInsert() error { 21 | if s.MustFailBefore != nil { 22 | return s.MustFailBefore 23 | } 24 | 25 | s.Checks["BeforeInsert"] = true 26 | return nil 27 | } 28 | 29 | func (s *EventsFixture) AfterInsert() error { 30 | if s.MustFailAfter != nil { 31 | return s.MustFailAfter 32 | } 33 | 34 | s.Checks["AfterInsert"] = true 35 | return nil 36 | } 37 | 38 | func (s *EventsFixture) BeforeUpdate() error { 39 | if s.MustFailBefore != nil { 40 | return s.MustFailBefore 41 | } 42 | 43 | s.Checks["BeforeUpdate"] = true 44 | return nil 45 | } 46 | 47 | func (s *EventsFixture) AfterUpdate() error { 48 | if s.MustFailAfter != nil { 49 | return s.MustFailAfter 50 | } 51 | 52 | s.Checks["AfterUpdate"] = true 53 | return nil 54 | } 55 | // EventsSaveFixture is like EventsFixture but implements only the save hooks. 56 | type EventsSaveFixture struct { 57 | kallax.Model `table:"event"` 58 | ID kallax.ULID `pk:""` 59 | Checks map[string]bool 60 | MustFailBefore error 61 | MustFailAfter error 62 | } 63 | 64 | func newEventsSaveFixture() *EventsSaveFixture { 65 | return &EventsSaveFixture{ 66 | ID: kallax.NewULID(), 67 | Checks: make(map[string]bool), 68 | } 69 | } 70 | 71 | func (s *EventsSaveFixture) BeforeSave() error { 72 | if s.MustFailBefore != nil { 73 | return s.MustFailBefore 74 | } 75 | 76 | s.Checks["BeforeSave"] = true 77 | return nil 78 | } 79 | 80 | func (s *EventsSaveFixture) AfterSave() error { 81 | if s.MustFailAfter != nil 
{ 82 | return s.MustFailAfter 83 | } 84 | 85 | s.Checks["AfterSave"] = true 86 | return nil 87 | } 88 | // EventsAllFixture implements the insert, update and save hooks, flagging each one in Checks. 89 | type EventsAllFixture struct { 90 | kallax.Model `table:"event"` 91 | ID kallax.ULID `pk:""` 92 | Checks
map[string]bool 93 | MustFailBefore error 94 | MustFailAfter error 95 | } 96 | 97 | func newEventsAllFixture() *EventsAllFixture { 98 | return &EventsAllFixture{ 99 | ID: kallax.NewULID(), 100 | Checks: make(map[string]bool), 101 | } 102 | } 103 | 104 | func (s *EventsAllFixture) BeforeInsert() error { 105 | if s.MustFailBefore != nil { 106 | return s.MustFailBefore 107 | } 108 | 109 | s.Checks["BeforeInsert"] = true 110 | return nil 111 | } 112 | 113 | func (s *EventsAllFixture) AfterInsert() error { 114 | if s.MustFailAfter != nil { 115 | return s.MustFailAfter 116 | } 117 | 118 | s.Checks["AfterInsert"] = true 119 | return nil 120 | } 121 | 122 | func (s *EventsAllFixture) BeforeUpdate() error { 123 | if s.MustFailBefore != nil { 124 | return s.MustFailBefore 125 | } 126 | 127 | s.Checks["BeforeUpdate"] = true 128 | return nil 129 | } 130 | 131 | func (s *EventsAllFixture) AfterUpdate() error { 132 | if s.MustFailAfter != nil { 133 | return s.MustFailAfter 134 | } 135 | 136 | s.Checks["AfterUpdate"] = true 137 | return nil 138 | } 139 | 140 | func (s *EventsAllFixture) BeforeSave() error { 141 | if s.MustFailBefore != nil { 142 | return s.MustFailBefore 143 | } 144 | 145 | s.Checks["BeforeSave"] = true 146 | return nil 147 | } 148 | 149 | func (s *EventsAllFixture) AfterSave() error { 150 | if s.MustFailAfter != nil { 151 | return s.MustFailAfter 152 | } 153 | 154 | s.Checks["AfterSave"] = true 155 | return nil 156 | } 157 | -------------------------------------------------------------------------------- /tests/events_test.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/suite" 9 | ) 10 | 11 | type EventsSuite struct { 12 | BaseTestSuite 13 | } 14 | 15 | func TestEventsSuite(t *testing.T) { 16 | schema := []string{ 17 | `CREATE TABLE IF NOT EXISTS event ( 18 | id uuid primary key, 19 | checks JSON, 20 |
must_fail_before JSON, 21 | must_fail_after JSON 22 | )`, 23 | } 24 | suite.Run(t, &EventsSuite{NewBaseSuite(schema, "event")}) 25 | } 26 | 27 | type eventsCheck map[string]bool 28 | // assertEventsPassed asserts that received holds exactly the events in expected, with the same values. 29 | func (s *EventsSuite) assertEventsPassed(expected eventsCheck, received eventsCheck) { 30 | for expectedEvent, expectedSign := range expected { 31 | receivedSign, ok := received[expectedEvent] 32 | if s.True(ok, fmt.Sprintf(`Event '%s' was not received`, expectedEvent)) { 33 | s.Equal(expectedSign, receivedSign, expectedEvent) 34 | } 35 | } 36 | 37 | s.Equal(len(expected), len(received), "expected same number of events") 38 | } 39 | 40 | func (s *EventsSuite) TestEventsInsert() { 41 | store := NewEventsFixtureStore(s.db) 42 | 43 | doc := NewEventsFixture() 44 | err := store.Insert(doc) 45 | s.Nil(err) 46 | s.assertEventsPassed(map[string]bool{ 47 | "BeforeInsert": true, 48 | "AfterInsert": true, 49 | }, doc.Checks) 50 | } 51 | 52 | func (s *EventsSuite) TestEventsUpdate() { 53 | store := NewEventsFixtureStore(s.db) 54 | 55 | doc := NewEventsFixture() 56 | err := store.Insert(doc) 57 | s.Nil(err) 58 | 59 | doc.Checks = make(map[string]bool) 60 | updatedRows, err := store.Update(doc) 61 | s.Nil(err) 62 | s.True(updatedRows > 0) 63 | s.assertEventsPassed(map[string]bool{ 64 | "BeforeUpdate": true, 65 | "AfterUpdate": true, 66 | }, doc.Checks) 67 | } 68 | 69 | func (s *EventsSuite) TestEventsUpdateError() { 70 | store := NewEventsFixtureStore(s.db) 71 | 72 | doc := NewEventsFixture() 73 | s.Nil(store.Insert(doc)) 74 | doc.Checks = make(map[string]bool) 75 | 76 | doc.MustFailAfter = errors.New("kallax: after") 77 | updatedRows, err := store.Update(doc) 78 | s.Equal(int64(0), updatedRows) 79 | s.Equal(doc.MustFailAfter, err) 80 | 81 | doc.MustFailBefore = errors.New("kallax: before") 82 | updatedRows, err = store.Update(doc) 83 | s.Equal(int64(0), updatedRows) 84 | s.Equal(doc.MustFailBefore, err) 
85 | } 86 | 87 | func (s *EventsSuite) TestEventsSaveOnInsert() { 88 | store :=
NewEventsFixtureStore(s.db) 89 | 90 | doc := NewEventsFixture() 91 | updated, err := store.Save(doc) 92 | s.Nil(err) 93 | s.False(updated) 94 | s.assertEventsPassed(map[string]bool{ 95 | "BeforeInsert": true, 96 | "AfterInsert": true, 97 | }, doc.Checks) 98 | } 99 | 100 | func (s *EventsSuite) TestEventsSaveOnUpdate() { 101 | store := NewEventsFixtureStore(s.db) 102 | 103 | doc := NewEventsFixture() 104 | s.Nil(store.Insert(doc)) 105 | doc.Checks = make(map[string]bool) 106 | 107 | updated, err := store.Save(doc) 108 | s.Nil(err) 109 | s.True(updated) 110 | s.assertEventsPassed(map[string]bool{ 111 | "BeforeUpdate": true, 112 | "AfterUpdate": true, 113 | }, doc.Checks) 114 | } 115 | 116 | func (s *EventsSuite) TestEventsSaveInsert() { 117 | store := NewEventsSaveFixtureStore(s.db) 118 | 119 | doc := NewEventsSaveFixture() 120 | err := store.Insert(doc) 121 | s.Nil(err) 122 | s.assertEventsPassed(map[string]bool{ 123 | "BeforeSave": true, 124 | "AfterSave": true, 125 | }, doc.Checks) 126 | } 127 | 128 | func (s *EventsSuite) TestEventsSaveUpdate() { 129 | store := NewEventsSaveFixtureStore(s.db) 130 | 131 | doc := NewEventsSaveFixture() 132 | err := store.Insert(doc) 133 | s.Nil(err) 134 | 135 | doc.Checks = make(map[string]bool) 136 | updatedRows, err := store.Update(doc) 137 | s.Nil(err) 138 | s.True(updatedRows > 0) 139 | s.assertEventsPassed(map[string]bool{ 140 | "BeforeSave": true, 141 | "AfterSave": true, 142 | }, doc.Checks) 143 | } 144 | 145 | func (s *EventsSuite) TestEventsSaveSave() { 146 | store := NewEventsSaveFixtureStore(s.db) 147 | 148 | doc := NewEventsSaveFixture() 149 | s.Nil(store.Insert(doc)) 150 | doc.Checks = map[string]bool{"AfterInsert": true} 151 | 152 | updated, err := store.Save(doc) 153 | s.Nil(err) 154 | s.True(updated) 155 | s.assertEventsPassed(map[string]bool{ 156 | "AfterInsert": true, 157 | "BeforeSave": true, 158 | "AfterSave": true, 159 | }, doc.Checks) 160 | } 161 | 162 | func (s *EventsSuite) TestEventsAllInsert() { 163 |
store := NewEventsAllFixtureStore(s.db) 164 | 165 | doc := NewEventsAllFixture() 166 | err := store.Insert(doc) 167 | s.Nil(err) 168 | s.assertEventsPassed(map[string]bool{ 169 | "AfterInsert": true, 170 | "AfterSave": true, 171 | "BeforeSave": true, 172 | "BeforeInsert": true, 173 | }, doc.Checks) 174 | } 175 | 176 | func (s *EventsSuite) TestEventsAllUpdate() { 177 | store := NewEventsAllFixtureStore(s.db) 178 | 179 | doc := NewEventsAllFixture() 180 | err := store.Insert(doc) 181 | s.Nil(err) 182 | 183 | doc.Checks = make(map[string]bool) 184 | updatedRows, err := store.Update(doc) 185 | s.Nil(err) 186 | s.True(updatedRows > 0) 187 | s.assertEventsPassed(map[string]bool{ 188 | "BeforeUpdate": true, 189 | "BeforeSave": true, 190 | "AfterUpdate": true, 191 | "AfterSave": true, 192 | }, doc.Checks) 193 | } 194 | 195 | func (s *EventsSuite) TestEventsAllSave() { 196 | store := NewEventsAllFixtureStore(s.db) 197 | 198 | doc := NewEventsAllFixture() 199 | err := store.Insert(doc) 200 | s.Nil(err) 201 | s.assertEventsPassed(map[string]bool{ 202 | "AfterInsert": true, 203 | "AfterSave": true, 204 | "BeforeSave": true, 205 | "BeforeInsert": true, 206 | }, doc.Checks) 207 | 208 | doc.Checks = make(map[string]bool) 209 | 210 | updated, err := store.Save(doc) 211 | s.Nil(err) 212 | s.True(updated) 213 | s.assertEventsPassed(map[string]bool{ 214 | "BeforeUpdate": true, 215 | "BeforeSave": true, 216 | "AfterUpdate": true, 217 | "AfterSave": true, 218 | }, doc.Checks) 219 | } 220 | -------------------------------------------------------------------------------- /tests/fixtures/fixtures.go: -------------------------------------------------------------------------------- 1 | package fixtures 2 | 3 | import "database/sql/driver" 4 | 5 | type AliasArray [3]string 6 | type AliasSlice []string 7 | type AliasString string 8 | type AliasInt int 9 | type AliasArrAliasSlice []AliasSlice 10 | type AliasArrAliasString []AliasString 11 | type AliasDummyParam QueryDummy 12 | 13 | type
QueryDummy struct { 14 | name string 15 | } 16 | 17 | type InterfaceImplementation struct { 18 | ScannerValuer 19 | Str string 20 | } 21 | 22 | type ScannerValuer struct{} 23 | 24 | func (i ScannerValuer) Value() (driver.Value, error) { 25 | return nil, nil 26 | } 27 | 28 | func (i ScannerValuer) Scan(src interface{}) error { 29 | return nil 30 | } 31 | -------------------------------------------------------------------------------- /tests/json.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import kallax "gopkg.in/src-d/go-kallax.v1" 4 | 5 | type JSONModel struct { 6 | kallax.Model `table:"jsons"` 7 | ID kallax.ULID `pk:""` 8 | Foo string 9 | Bar *Bar 10 | BazSlice []Baz 11 | Baz map[string]interface{} 12 | } 13 | 14 | type Bar struct { 15 | Qux []Qux 16 | Mux string 17 | } 18 | 19 | type Baz struct { 20 | Mux string 21 | } 22 | 23 | type Qux struct { 24 | Schnooga string 25 | Balooga int 26 | Boo float64 27 | } 28 | // newJSONModel returns a JSONModel with a fresh ULID and an empty, non-nil Baz map. 29 | func newJSONModel() *JSONModel { 30 | return &JSONModel{ID: kallax.NewULID(), Baz: make(map[string]interface{})} 31 | } 32 | -------------------------------------------------------------------------------- /tests/json_test.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/suite" 7 | kallax "gopkg.in/src-d/go-kallax.v1" 8 | ) 9 | 10 | type JSONSuite struct { 11 | BaseTestSuite 12 | } 13 | 14 | func TestJSON(t *testing.T) { 15 | schema := []string{ 16 | `CREATE TABLE IF NOT EXISTS jsons ( 17 | id uuid primary key, 18 | foo text, 19 | bar jsonb, 20 | baz jsonb, 21 | baz_slice jsonb 22 | )`, 23 | } 24 | suite.Run(t, &JSONSuite{NewBaseSuite(schema, "jsons")}) 25 | } 26 | 27 | func (s *JSONSuite) TestSearchByField() { 28 | s.insertFixtures() 29 | q := NewJSONModelQuery().Where( 30 | 
kallax.Eq(Schema.JSONModel.Bar.Mux, "mux1"), 31 | ) 32 | s.assertFound(q, "1") 33 | } 34 |
35 | func (s *JSONSuite) TestSearchByCustomField() { 36 | s.insertFixtures() 37 | q := NewJSONModelQuery().Where( 38 | kallax.Eq(kallax.AtJSONPath(Schema.JSONModel.Baz, kallax.JSONInt, "a", "0", "b"), 3), 39 | ) 40 | 41 | s.assertFound(q, "2") 42 | 43 | q = NewJSONModelQuery().Where( 44 | kallax.Eq(kallax.AtJSONPath(Schema.JSONModel.Baz, kallax.JSONBool, "b"), true), 45 | ) 46 | 47 | s.assertFound(q, "1") 48 | } 49 | // assertFound runs q and requires exactly len(foos) results whose Foo values equal foos, in the order the store returns them. 50 | func (s *JSONSuite) assertFound(q *JSONModelQuery, foos ...string) { 51 | require := s.Require() 52 | store := NewJSONModelStore(s.db) 53 | rs, err := store.Find(q) 54 | require.NoError(err) 55 | 56 | models, err := rs.All() 57 | require.NoError(err) 58 | require.Len(models, len(foos)) 59 | for i, f := range foos { 60 | require.Equal(f, models[i].Foo) 61 | } 62 | } 63 | // insertFixtures stores two JSONModel rows, with Foo "1" and "2", whose Bar and Baz JSON values differ. 64 | func (s *JSONSuite) insertFixtures() { 65 | store := NewJSONModelStore(s.db) 66 | 67 | m := NewJSONModel() 68 | m.Foo = "1" 69 | m.Baz = map[string]interface{}{ 70 | "a": []interface{}{ 71 | map[string]interface{}{ 72 | "b": 1, 73 | }, 74 | map[string]interface{}{ 75 | "b": 2, 76 | }, 77 | }, 78 | "b": true, 79 | } 80 | m.Bar = &Bar{ 81 | Qux: []Qux{ 82 | {"schnooga1", 1, .5}, 83 | {"schnooga2", 2, .6}, 84 | }, 85 | Mux: "mux1", 86 | } 87 | 88 | s.NoError(store.Insert(m)) 89 | 90 | m = NewJSONModel() 91 | m.Foo = "2" 92 | m.Baz = map[string]interface{}{ 93 | "a": []interface{}{ 94 | map[string]interface{}{ 95 | "b": 3, 96 | }, 97 | map[string]interface{}{ 98 | "b": 4, 99 | }, 100 | }, 101 | "b": false, 102 | } 103 | m.Bar = &Bar{ 104 | Qux: []Qux{ 105 | {"schnooga3", 3, .7}, 106 | {"schnooga4", 4, .8}, 107 | }, 108 | Mux: "mux2", 109 | } 110 | 111 | s.NoError(store.Insert(m)) 112 | } 113 | 
-------------------------------------------------------------------------------- /tests/query.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "net/url" 5 | "time" 6 | 7 | "gopkg.in/src-d/go-kallax.v1" 8 |
"gopkg.in/src-d/go-kallax.v1/tests/fixtures" 9 | ) 10 | // QueryFixture is a model with fields of many kinds (relationships, embedded and inline structs, maps, arrays, slices, aliases and custom types) used by the query tests. 11 | type QueryFixture struct { 12 | kallax.Model `table:"query"` 13 | ID kallax.ULID `pk:""` 14 | 15 | Relation *QueryRelationFixture `fk:"owner_id"` 16 | Inverse *QueryRelationFixture `fk:"inverse_id,inverse"` 17 | NRelation []*QueryRelationFixture `fk:"owner_id"` 18 | Embedded fixtures.QueryDummy 19 | Ignored fixtures.QueryDummy `kallax:"-"` 20 | Inline struct { 21 | Inline string 22 | } `kallax:",inline"` 23 | MapOfString map[string]string 24 | MapOfInterface map[string]interface{} 25 | MapOfSomeType map[string]fixtures.QueryDummy 26 | Foo string 27 | StringProperty string 28 | Integer int 29 | Integer64 int64 30 | Float32 float32 31 | Boolean bool 32 | ArrayParam [3]string 33 | SliceParam []string 34 | AliasArrayParam fixtures.AliasArray 35 | AliasSliceParam fixtures.AliasSlice 36 | AliasStringParam fixtures.AliasString 37 | AliasIntParam fixtures.AliasInt 38 | DummyParam fixtures.QueryDummy 39 | AliasDummyParam fixtures.AliasDummyParam 40 | SliceDummyParam []fixtures.QueryDummy 41 | IDPropertyParam kallax.ULID 42 | InterfacePropParam fixtures.InterfaceImplementation `sqltype:"jsonb"` 43 | URLParam url.URL 44 | TimeParam time.Time 45 | AliasArrAliasStringParam fixtures.AliasArrAliasString 46 | AliasHereArrayParam AliasHereArray 47 | ArrayAliasHereStringParam []AliasHereString 48 | ScannerValuerParam ScannerValuer `sqltype:"jsonb"` 49 | } 50 | 51 | type AliasHereString string 52 | type AliasHereArray [3]string 53 | type ScannerValuer struct { 54 | fixtures.ScannerValuer 55 | } 56 | // NOTE(review): AliasID is not referenced by any declaration in this file; confirm it is still needed. 57 | type AliasID kallax.ULID 58 | // newQueryFixture returns a QueryFixture with a fresh ULID and Foo set to f. 
59 | func newQueryFixture(f string) *QueryFixture { 60 | return &QueryFixture{ID: kallax.NewULID(), Foo: f} 61 | } 62 | // Eq reports whether q and v have the same ID. 63 | func (q *QueryFixture) Eq(v *QueryFixture) bool { 64 | return q.ID == v.ID 65 | } 66 | // QueryRelationFixture is the model on the other side of QueryFixture's Relation, Inverse and NRelation fields. 67 | type QueryRelationFixture struct { 68 | kallax.Model `table:"query_relation"` 69 | ID kallax.ULID `pk:""` 70 | Name string 71 | Owner *QueryFixture `fk:"owner_id,inverse"` 72 | }
73 | // queryFixtures are the rows QuerySuite.SetupTest inserts before each query test. 74 | var queryFixtures = []*QueryFixture{ 75 | { 76 | ID: kallax.NewULID(), 77 | Foo: "Foo0", 78 | StringProperty: "StringProperty0", 79 | Integer: 0, 80 | Integer64: 0, 81 | Float32: 0, 82 | Boolean: true, 83 | ArrayParam: [3]string{"ArrayParam0One", "ArrayParam0Two", "ArrayParam0Three"}, 84 | SliceParam: []string{"SliceParam0One", "SliceParam0Two", "SliceParam0Three"}, 85 | AliasArrayParam: [3]string{"AliasArray0One", "AliasArray0Two", "AliasArray0Three"}, 86 | AliasSliceParam: []string{"AliasSlice0One", "AliasSlice0Two", "AliasSlice0Three"}, 87 | AliasStringParam: "AliasString0", 88 | AliasIntParam: 0, 89 | }, 90 | { 91 | ID: kallax.NewULID(), 92 | Foo: "Foo1", 93 | StringProperty: "StringProperty1", 94 | Integer: 1, 95 | Integer64: 1, 96 | Float32: 1, 97 | Boolean: false, 98 | ArrayParam: [3]string{"ArrayParm1One", "ArrayParm1Two", "ArrayParm1Three"}, 99 | SliceParam: []string{"SliceParam1One", "SliceParam1Two", "SliceParam1Three"}, 100 | AliasArrayParam: [3]string{"AliasArray1One", "AliasArray1Two", "AliasArray1Three"}, 101 | AliasSliceParam: []string{"AliasSlice1One", "AliasSlice1Two", "AliasSlice1Three"}, 102 | AliasStringParam: "AliasString1", 103 | AliasIntParam: 1, 104 | }, 105 | { 106 | ID: kallax.NewULID(), 107 | Foo: "Foo2", 108 | StringProperty: "StringProperty2", 109 | Integer: 2, 110 | Integer64: 2, 111 | Float32: 2, 112 | Boolean: true, 113 | ArrayParam: [3]string{"ArrayParm2One", "ArrayParm2Two", "ArrayParm2Three"}, 114 | SliceParam: []string{"SliceParam2One", "SliceParam2Two", "SliceParam2Three"}, 115 | AliasArrayParam: [3]string{"AliasArray2One", "AliasArray2Two", "AliasArray2Three"}, 116 | AliasSliceParam: []string{"AliasSlice2One", "AliasSlice2Two", "AliasSlice2Three"}, 117 | AliasStringParam: "AliasString2", 118 | AliasIntParam: 2, 119 | }, 120 | } 121 | // resetQueryFixtures replaces each entry of queryFixtures with a new QueryFixture 
that copies only the fields listed below (slices are shared, not cloned); every other field, including the embedded kallax.Model, is reset to its zero value. 122 | func resetQueryFixtures() { 123 | for i, fixture := range queryFixtures { 124 | queryFixtures[i] = &QueryFixture{ 125 | ID: fixture.ID, 126 | Foo: fixture.Foo, 127 |
StringProperty: fixture.StringProperty, 128 | Integer: fixture.Integer, 129 | Integer64: fixture.Integer64, 130 | Float32: fixture.Float32, 131 | Boolean: fixture.Boolean, 132 | ArrayParam: fixture.ArrayParam, 133 | SliceParam: fixture.SliceParam, 134 | AliasArrayParam: fixture.AliasArrayParam, 135 | AliasSliceParam: fixture.AliasSliceParam, 136 | AliasStringParam: fixture.AliasStringParam, 137 | AliasIntParam: fixture.AliasIntParam, 138 | } 139 | } 140 | } 141 | -------------------------------------------------------------------------------- /tests/query_test.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | "time" 7 | 8 | "github.com/stretchr/testify/suite" 9 | "gopkg.in/src-d/go-kallax.v1" 10 | ) 11 | 12 | type QuerySuite struct { 13 | BaseTestSuite 14 | } 15 | 16 | func TestQuerySuite(t *testing.T) { 17 | schema := []string{ 18 | `CREATE TABLE IF NOT EXISTS query ( 19 | id uuid primary key, 20 | idproperty uuid, 21 | idproperty_ptr uuid, 22 | inverse_id uuid, 23 | foo varchar(256), 24 | embedded jsonb, 25 | inline varchar(256), 26 | map_of_string jsonb, 27 | map_of_interface jsonb, 28 | map_of_some_type jsonb, 29 | string_property varchar(256), 30 | integer int, 31 | integer64 bigint, 32 | float32 float, 33 | boolean boolean, 34 | array_param text[], 35 | slice_param text[], 36 | alias_array_param text[], 37 | alias_slice_param text[], 38 | alias_string_param varchar(256), 39 | alias_int_param int, 40 | dummy_param jsonb, 41 | alias_dummy_param jsonb, 42 | slice_dummy_param jsonb, 43 | idproperty_param uuid, 44 | idproperty_ptr_param uuid, 45 | slice_idpptr_param uuid[], 46 | interface_prop_param text, 47 | urlparam varchar(256), 48 | time_param timestamptz, 49 | alias_arr_alias_string_param text[], 50 | alias_here_array_param text[], 51 | array_alias_here_string_param text[], 52 | scanner_valuer_param text 53 | )`, 54 | `CREATE TABLE IF NOT EXISTS
query_relation ( 55 | id uuid primary key, 56 | name varchar(256), 57 | owner_id uuid references query(id) 58 | )`, 59 | } 60 | suite.Run(t, &QuerySuite{NewBaseSuite(schema, "query_relation", "query")}) 61 | } 62 | // SetupTest runs the base setup, then inserts fresh copies of queryFixtures. 63 | func (s *QuerySuite) SetupTest() { 64 | s.BaseTestSuite.SetupTest() 65 | 66 | resetQueryFixtures() 67 | store := NewQueryFixtureStore(s.db) 68 | for _, fixture := range queryFixtures { 69 | s.Nil(store.Insert(fixture)) 70 | } 71 | } 72 | // NOTE(review): testify already runs QuerySuite.SetupTest (which calls BaseTestSuite.SetupTest) before each test, so the *TruncateTime tests below repeat the base setup and re-record openConnectionsBeforeTest; confirm this is intended. 73 | func (s *QuerySuite) TestInsertTruncateTime() { 74 | s.BaseTestSuite.SetupTest() 75 | f := NewQueryFixture("fixture") 76 | for f.TimeParam.Nanosecond() == 0 { 77 | f.TimeParam = time.Now() 78 | } 79 | 80 | store := NewQueryFixtureStore(s.db) 81 | s.NoError(store.Insert(f)) 82 | 83 | f2, err := store.FindOne(NewQueryFixtureQuery().FindByID(f.ID)) 84 | s.NoError(err) 85 | s.Equal(f.TimeParam, f2.TimeParam.Local()) 86 | } 87 | 88 | func (s *QuerySuite) TestUpdateTruncateTime() { 89 | s.BaseTestSuite.SetupTest() 90 | f := NewQueryFixture("fixture") 91 | store := NewQueryFixtureStore(s.db) 92 | s.NoError(store.Insert(f)) 93 | for f.TimeParam.Nanosecond() == 0 { 94 | f.TimeParam = time.Now() 95 | } 96 | 97 | _, err := store.Update(f) 98 | s.NoError(err) 99 | f2, err := store.FindOne(NewQueryFixtureQuery().FindByID(f.ID)) 100 | s.NoError(err) 101 | s.Equal(f.TimeParam, f2.TimeParam.Local()) 102 | } 103 | 104 | func (s *QuerySuite) TestSaveTruncateTime() { 105 | s.BaseTestSuite.SetupTest() 106 | f := NewQueryFixture("fixture") 107 | for f.TimeParam.Nanosecond() == 0 { 108 | f.TimeParam = time.Now() 109 | } 110 | 111 | store := NewQueryFixtureStore(s.db) 112 | _, err := store.Save(f) 113 | s.NoError(err) 114 | 115 | f2, err := 
store.FindOne(NewQueryFixtureQuery().FindByID(f.ID)) 116 | s.NoError(err) 117 | s.Equal(f.TimeParam, f2.TimeParam.Local()) 118 | } 119 | 120 | func (s *QuerySuite) TestQuery() { 121 | store := NewQueryFixtureStore(s.db) 122 | doc := NewQueryFixture("bar") 123 | s.Nil(store.Insert(doc)) 124 | 125 |
query := NewQueryFixtureQuery() 126 | query.Where(kallax.Eq(Schema.QueryFixture.ID, doc.ID)) 127 | 128 | s.NotPanics(func() { 129 | s.Equal("bar", store.MustFindOne(query).Foo) 130 | }) 131 | 132 | notID := kallax.NewULID() 133 | queryErr := NewQueryFixtureQuery() 134 | queryErr.Where(kallax.Eq(Schema.QueryFixture.ID, notID)) 135 | s.Panics(func() { 136 | s.Equal("bar", store.MustFindOne(queryErr).Foo) 137 | }) 138 | } 139 | 140 | func (s *QuerySuite) TestFindById() { 141 | store := NewQueryFixtureStore(s.db) 142 | 143 | docName := "bar" 144 | doc := NewQueryFixture(docName) 145 | s.Nil(store.Insert(doc)) 146 | 147 | query := NewQueryFixtureQuery() 148 | query.FindByID(doc.ID) 149 | s.NotPanics(func() { 150 | s.Equal(docName, store.MustFindOne(query).Foo) 151 | }) 152 | 153 | queryManyId := NewQueryFixtureQuery() 154 | queryManyId.FindByID(queryFixtures[1].ID, queryFixtures[2].ID) 155 | count, err := store.Count(queryManyId) 156 | s.Equal(2, int(count)) 157 | s.Nil(err) 158 | 159 | notID := kallax.NewULID() 160 | queryErr := NewQueryFixtureQuery() 161 | queryErr.FindByID(notID) 162 | s.Panics(func() { 163 | store.MustFindOne(queryErr) 164 | }) 165 | } 166 | 167 | func (s *QuerySuite) TestFindBy() { 168 | store := NewQueryFixtureStore(s.db) 169 | s.NotPanics(func() { 170 | s.True(store.MustFindOne(NewQueryFixtureQuery().FindByStringProperty("StringProperty1")).Eq(queryFixtures[1])) 171 | }) 172 | s.Panics(func() { 173 | store.MustFindOne(NewQueryFixtureQuery().FindByStringProperty("NOT_FOUND")) 174 | }) 175 | 176 | s.NotPanics(func() { 177 | s.True(store.MustFindOne(NewQueryFixtureQuery().FindByBoolean(false)).Eq(queryFixtures[1])) 178 | }) 179 | s.NotPanics(func() { 180 | count, err := store.Count(NewQueryFixtureQuery().FindByBoolean(true)) 181 | s.Equal(int64(2), count) 182 | s.Nil(err) 183 | }) 184 | 185 | s.NotPanics(func() { 186 | s.True(store.MustFindOne(NewQueryFixtureQuery().FindByInteger(kallax.Eq, 2)).Eq(queryFixtures[2])) 187 | }) 188 | s.Panics(func() {
189 | store.MustFindOne(NewQueryFixtureQuery().FindByInteger(kallax.Eq, 99)) 190 | }) 191 | 192 | s.NotPanics(func() { 193 | count, err := store.Count(NewQueryFixtureQuery().FindByInteger(kallax.GtOrEq, 1)) 194 | s.Equal(int64(2), count) 195 | s.Nil(err) 196 | }) 197 | s.NotPanics(func() { 198 | count, err := store.Count(NewQueryFixtureQuery().FindByInteger(kallax.Lt, 0)) 199 | s.Equal(int64(0), count) 200 | s.Nil(err) 201 | }) 202 | } 203 | // TestGeneration checks which QueryFixture properties get a generated FindBy<Property> method on QueryFixtureQuery. 204 | func (s *QuerySuite) TestGeneration() { 205 | var cases = []struct { 206 | propertyName string 207 | autoGeneratedFindBy bool 208 | }{ 209 | {"ID", true}, 210 | {"SelfRelation", false}, 211 | {"Inverse", true}, 212 | {"SelfNRelation", false}, 213 | {"Embedded", false}, 214 | {"Ignored", false}, 215 | {"Inline", true}, 216 | {"MapOfString", false}, 217 | {"MapOfInterface", false}, 218 | {"MapOfSomeType", false}, 219 | {"Foo", true}, 220 | {"StringProperty", true}, 221 | {"Integer", true}, 222 | {"Integer64", true}, 223 | {"Float32", true}, 224 | {"Boolean", true}, 225 | {"ArrayParam", true}, 226 | {"SliceParam", true}, 227 | {"AliasArrayParam", true}, 228 | {"AliasSliceParam", true}, 229 | {"AliasStringParam", true}, 230 | {"AliasIntParam", true}, 231 | {"DummyParam", false}, 232 | {"AliasDummyParam", false}, 233 | {"SliceDummyParam", false}, 234 | {"IDPropertyParam", true}, 235 | {"InterfacePropParam", true}, 236 | {"URLParam", true}, 237 | {"TimeParam", true}, 238 | {"AliasArrAliasStringParam", true}, 239 | {"AliasArrAliasSliceParam", false}, 240 | {"ArrayArrayParam", false}, 241 | {"AliasHereArrayParam", true}, 242 | {"ScannerValuerParam", true}, 243 | } 244 | // NOTE(review): QueryFixture (tests/query.go) has no SelfRelation, SelfNRelation, AliasArrAliasSliceParam or ArrayArrayParam field, so those false cases presumably pass trivially, while Relation, NRelation and ArrayAliasHereStringParam are not checked. 
245 | q := NewQueryFixtureQuery() 246 | for _, c := range cases { 247 | s.hasFindByMethod(q, c.propertyName, c.autoGeneratedFindBy) 248 | } 249 | } 250 | // hasFindByMethod asserts that q's type has a method named "FindBy"+name when exists is true, and lacks it otherwise. 251 | func (s *QuerySuite) hasFindByMethod(q *QueryFixtureQuery, name string, exists bool) { 252 | queryValue := reflect.TypeOf(q) 253 | _, ok := queryValue.MethodByName("FindBy" + name) 254 | if exists { 255 |
s.True(ok, "'FindBy%s' method should BE generated", name) 256 | } else { 257 | s.False(ok, "'FindBy%s' method should NOT be generated", name) 258 | } 259 | } 260 | -------------------------------------------------------------------------------- /tests/relationships.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import kallax "gopkg.in/src-d/go-kallax.v1" 4 | // Car counts in events how many times each of its save and delete hooks has run. 5 | type Car struct { 6 | kallax.Model `table:"cars"` 7 | ID kallax.ULID `pk:""` 8 | Owner *Person `fk:"owner_id,inverse"` 9 | ModelName string 10 | Brand Brand `fk:"brand_id,inverse"` 11 | events map[string]int 12 | } 13 | 14 | type Brand struct { 15 | kallax.Model `table:"brands"` 16 | ID kallax.ULID `pk:""` 17 | Name string 18 | } 19 | // ensureMapInitialized lazily allocates c.events so the hooks can increment counters on a zero-value Car. 20 | func (c *Car) ensureMapInitialized() { 21 | if c.events == nil { 22 | c.events = make(map[string]int) 23 | } 24 | } 25 | 26 | func (c *Car) BeforeSave() error { 27 | c.ensureMapInitialized() 28 | c.events["BeforeSave"]++ 29 | return nil 30 | } 31 | 32 | func (c *Car) AfterSave() error { 33 | c.ensureMapInitialized() 34 | c.events["AfterSave"]++ 35 | return nil 36 | } 37 | 38 | func (c *Car) BeforeDelete() error { 39 | c.ensureMapInitialized() 40 | c.events["BeforeDelete"]++ 41 | return nil 42 | } 43 | 44 | func (c *Car) AfterDelete() error { 45 | c.ensureMapInitialized() 46 | c.events["AfterDelete"]++ 47 | return nil 48 | } 49 | // Person counts in events how many times each of its save and delete hooks has run. 
50 | type Person struct { 51 | kallax.Model `table:"persons"` 52 | ID int64 `pk:"autoincr"` 53 | Name string 54 | Pets []*Pet `fk:"owner_id"` 55 | PetsArr [2]*Pet `fk:"owner_id"` 56 | Car *Car `fk:"owner_id"` 57 | events map[string]int 58 | } 59 | 60 | func (c *Person) ensureMapInitialized() { 61 | if c.events == nil { 62 | c.events = make(map[string]int) 63 | } 64 | } 65 | 66 | func (c *Person) BeforeSave() error { 67 | c.ensureMapInitialized() 68 | c.events["BeforeSave"]++ 69 | return nil 70 | } 71 | 72 | func (c *Person) AfterSave() error { 73 | c.ensureMapInitialized() 74 |
c.events["AfterSave"]++ 75 | return nil 76 | } 77 | 78 | func (c *Person) BeforeDelete() error { 79 | c.ensureMapInitialized() 80 | c.events["BeforeDelete"]++ 81 | return nil 82 | } 83 | 84 | func (c *Person) AfterDelete() error { 85 | c.ensureMapInitialized() 86 | c.events["AfterDelete"]++ 87 | return nil 88 | } 89 | // Pet counts in events how many times each of its save and delete hooks has run. 90 | type Pet struct { 91 | kallax.Model `table:"pets"` 92 | ID kallax.ULID `pk:""` 93 | Name string 94 | Kind string 95 | Owner *Person `fk:"owner_id,inverse"` 96 | events map[string]int 97 | } 98 | 99 | func (c *Pet) ensureMapInitialized() { 100 | if c.events == nil { 101 | c.events = make(map[string]int) 102 | } 103 | } 104 | 105 | func (c *Pet) BeforeSave() error { 106 | c.ensureMapInitialized() 107 | c.events["BeforeSave"]++ 108 | return nil 109 | } 110 | 111 | func (c *Pet) AfterSave() error { 112 | c.ensureMapInitialized() 113 | c.events["AfterSave"]++ 114 | return nil 115 | } 116 | 117 | func (c *Pet) BeforeDelete() error { 118 | c.ensureMapInitialized() 119 | c.events["BeforeDelete"]++ 120 | return nil 121 | } 122 | 123 | func (c *Pet) AfterDelete() error { 124 | c.ensureMapInitialized() 125 | c.events["AfterDelete"]++ 126 | return nil 127 | } 128 | // newPet returns a Pet with a fresh ULID and appends it to owner.Pets; owner must be non-nil. 129 | func newPet(name, kind string, owner *Person) *Pet { 130 | pet := &Pet{ID: kallax.NewULID(), Name: name, Kind: kind, Owner: owner} 131 | owner.Pets = append(owner.Pets, pet) 132 | return pet 133 | } 134 | 135 | func newPerson(name string) *Person { 136 | return &Person{Name: name} 137 | } 138 | // newCar returns a Car with a fresh ULID and sets it as owner.Car; owner must be non-nil. 139 | func newCar(model string, owner *Person) *Car { 140 | car := &Car{ID: kallax.NewULID(), ModelName: model, Owner: owner} 141 | owner.Car = car 142 | return car 143 | } 144 | // newBrandedCar is like newCar but also sets the car's Brand. 
145 | func newBrandedCar(model string, owner *Person, brand Brand) *Car { 146 | car := &Car{ID: kallax.NewULID(), ModelName: model, Owner: owner, Brand: brand} 147 | owner.Car = car 148 | return car 149 | } 150 | 151 | func newBrand(name string) *Brand { 152 | return &Brand{Name: name, ID: kallax.NewULID()} 153 | } 154 | 155
| func makeBrand(name string) Brand {
156 | 	return Brand{Name: name, ID: kallax.NewULID()}
157 | }
158 |
--------------------------------------------------------------------------------
/tests/relationships_test.go:
--------------------------------------------------------------------------------
1 | package tests
2 |
3 | import (
4 | 	"testing"
5 |
6 | 	"github.com/stretchr/testify/suite"
7 | )
8 |
9 | // RelationshipsSuite exercises saving, loading and removing Person relationships.
10 | type RelationshipsSuite struct {
11 | 	BaseTestSuite
12 | }
13 |
14 | // TestRelationships creates the persons, brands, cars and pets tables and runs RelationshipsSuite.
15 | func TestRelationships(t *testing.T) {
16 | 	schemas := []string{
17 | 		`CREATE TABLE IF NOT EXISTS persons (
18 | 			id serial primary key,
19 | 			name text
20 | 		)`,
21 | 		`CREATE TABLE IF NOT EXISTS brands (
22 | 			id uuid primary key,
23 | 			name text
24 | 		)`,
25 | 		`CREATE TABLE IF NOT EXISTS cars (
26 | 			id uuid primary key,
27 | 			model_name text,
28 | 			owner_id integer references persons(id),
29 | 			brand_id uuid references brands(id)
30 | 		)`,
31 | 		`CREATE TABLE IF NOT EXISTS pets (
32 | 			id uuid primary key,
33 | 			name text,
34 | 			kind text,
35 | 			owner_id integer references persons(id)
36 | 		)`,
37 | 	}
38 | 	suite.Run(t, &RelationshipsSuite{NewBaseSuite(schemas, "cars", "pets", "persons")})
39 | }
40 |
41 | // TestInsertFind checks that inserting a Person also stores its Car and Pets.
42 | func (s *RelationshipsSuite) TestInsertFind() {
43 | 	require := s.Require()
44 | 	p := NewPerson("Dolan")
45 | 	car := NewCar("Tesla Model S", p)
46 | 	cat := NewPet("Garfield", "cat", p)
47 | 	dog := NewPet("Oddie", "dog", p)
48 |
49 | 	store := NewPersonStore(s.db)
50 | 	require.NoError(store.Insert(p))
51 |
52 | 	pers := s.getPerson()
53 | 	s.assertPerson(p.Name, pers, car, cat, dog)
54 | }
55 |
56 | // TestInsertFindRemove checks that RemovePets and RemoveCar detach the related records.
57 | func (s *RelationshipsSuite) TestInsertFindRemove() {
58 | 	p := NewPerson("Dolan")
59 | 	car := NewCar("Tesla Model S", p)
60 | 	cat := NewPet("Garfield", "cat", p)
61 | 	dog := NewPet("Oddie", "dog", p)
62 | 	reptar := NewPet("Reptar", "dinosaur", p)
63 |
64 | 	store := NewPersonStore(s.db)
65 | 	s.NoError(store.Insert(p))
66 |
67 | 	pers := s.getPerson()
68 | 	s.assertPerson(p.Name, pers, car, cat, dog, reptar)
69 |
70 | 	s.NoError(store.RemovePets(pers, dog))
71 | 	pers = s.getPerson()
72 | 	s.assertPerson(p.Name, pers, car, cat, reptar)
73 |
74 | 	// Calling RemovePets without any pets removes all of them.
75 | 	s.NoError(store.RemovePets(pers))
76 | 	s.NoError(store.RemoveCar(pers))
77 | 	pers = s.getPerson()
78 | 	s.assertPerson(p.Name, pers, nil)
79 | }
80 |
81 | // TestInsertFindUpdate checks that saving a loaded Person also inserts a newly added Pet.
82 | func (s *RelationshipsSuite) TestInsertFindUpdate() {
83 | 	p := NewPerson("Dolan")
84 | 	car := NewCar("Tesla Model S", p)
85 | 	cat := NewPet("Garfield", "cat", p)
86 | 	dog := NewPet("Oddie", "dog", p)
87 |
88 | 	store := NewPersonStore(s.db)
89 | 	s.NoError(store.Insert(p))
90 |
91 | 	pers := s.getPerson()
92 | 	s.assertPerson(p.Name, pers, car, cat, dog)
93 |
94 | 	pony := NewPet("Sparkling Twilight", "pony", pers)
95 | 	_, err := store.Save(pers)
96 | 	s.NoError(err)
97 |
98 | 	pers = s.getPerson()
99 | 	s.assertPerson(p.Name, pers, car, cat, dog, pony)
100 | }
101 |
102 | // TestEvents checks that the save and delete hooks of every related record run exactly once.
103 | func (s *RelationshipsSuite) TestEvents() {
104 | 	p := NewPerson("Dolan")
105 | 	car := NewCar("Tesla Model S", p)
106 | 	cat := NewPet("Garfield", "cat", p)
107 | 	dog := NewPet("Oddie", "dog", p)
108 | 	reptar := NewPet("Reptar", "dinosaur", p)
109 |
110 | 	store := NewPersonStore(s.db)
111 | 	s.NoError(store.Insert(p))
112 |
113 | 	s.assertEvents(p.events, "BeforeSave", "AfterSave")
114 | 	s.assertEvents(car.events, "BeforeSave", "AfterSave")
115 | 	s.assertEvents(cat.events, "BeforeSave", "AfterSave")
116 | 	s.assertEvents(dog.events, "BeforeSave", "AfterSave")
117 | 	s.assertEvents(reptar.events, "BeforeSave", "AfterSave")
118 |
119 | 	s.NoError(store.RemovePets(p, dog))
120 | 	s.assertNoEvents(cat.events, "BeforeDelete", "AfterDelete")
121 | 	s.assertNoEvents(reptar.events, "BeforeDelete", "AfterDelete")
122 | 	s.assertEvents(dog.events, "BeforeDelete", "AfterDelete")
123 |
124 | 	s.NoError(store.RemovePets(p))
125 | 	s.assertEvents(reptar.events, "BeforeDelete", "AfterDelete")
126 | 	s.assertEvents(cat.events, "BeforeDelete", "AfterDelete")
127 |
128 | 	s.NoError(store.RemoveCar(p))
129 | 	s.assertEvents(car.events, "BeforeDelete", "AfterDelete")
130 | }
131 |
132 | // TestSaveWithInverse checks that inserting a Car through CarStore also stores its Owner.
133 | func (s *RelationshipsSuite) TestSaveWithInverse() {
134 | 	p := NewPerson("Foo")
135 | 	car := NewCar("Bar", p)
136 |
137 | 	store := NewCarStore(s.db)
138 | 	s.NoError(store.Insert(car))
139 |
140 | 	s.NotNil(s.getPerson())
141 | }
142 |
143 | // TestSaveRelations checks that a Car's Brand is stored and loaded back with WithBrand.
144 | func (s *RelationshipsSuite) TestSaveRelations() {
145 | 	p := NewPerson("Musk")
146 | 	brand := makeBrand("Tesla")
147 | 	car := newBrandedCar("Model S", p, brand)
148 |
149 | 	store := NewCarStore(s.db)
150 | 	_, err := store.Save(car)
151 | 	s.NoError(err)
152 |
153 | 	car, err = store.FindOne(NewCarQuery().FindByID(car.ID).WithBrand())
154 | 	s.NoError(err)
155 | 	s.NotNil(car)
156 | 	// Brand is a value field, so NotNil on it could never fail; compare IDs instead.
157 | 	s.Equal(brand.ID, car.Brand.ID)
158 |
159 | 	pStore := NewPersonStore(s.db)
160 |
161 | 	p.Name = "Elon"
162 | 	_, err = pStore.Save(p)
163 | 	s.NoError(err)
164 | 	s.NotNil(p.Car)
165 | 	s.Equal(brand.ID, p.Car.Brand.ID)
166 |
167 | 	car, err = store.FindOne(NewCarQuery().FindByID(car.ID).WithBrand())
168 | 	s.NoError(err)
169 | 	s.NotNil(car)
170 | 	s.Equal(brand.ID, car.Brand.ID)
171 | }
172 |
173 | // assertEvents asserts that each named event was recorded exactly once in evs.
174 | func (s *RelationshipsSuite) assertEvents(evs map[string]int, events ...string) {
175 | 	for _, e := range events {
176 | 		s.Equal(1, evs[e])
177 | 	}
178 | }
179 |
180 | // assertNoEvents asserts that none of the named events was recorded in evs.
181 | func (s *RelationshipsSuite) assertNoEvents(evs map[string]int, events ...string) {
182 | 	for _, e := range events {
183 | 		s.Equal(0, evs[e])
184 | 	}
185 | }
186 |
187 | // assertPerson checks a loaded Person against the expected name, car (nil for none) and
188 | // pets, in order. It clears pers.events so hook counters do not affect the comparison.
189 | func (s *RelationshipsSuite) assertPerson(name string, pers *Person, car *Car, pets ...*Pet) {
190 | 	require := s.Require()
191 | 	require.False(pers.GetID().IsEmpty(), "ID should not be empty")
192 | 	require.Equal(name, pers.Name)
193 | 	pers.events = nil
194 | 	require.Len(pers.Pets, len(pets))
195 |
196 | 	// Owners are set to nil to be able to deep equal in the tests.
197 | 	// Records coming from relationships don't have their relationships
198 | 	// populated, so they will always be nil.
199 | 	// Same with events.
200 | 	var petList = make([]*Pet, len(pets))
201 | 	for i, pet := range pets {
202 | 		p := *pet
203 | 		require.False(p.GetID().IsEmpty(), "ID should not be empty")
204 | 		p.Owner = nil
205 | 		p.events = nil
206 | 		petList[i] = &p
207 | 	}
208 |
209 | 	var c Car
210 | 	if car == nil {
211 | 		require.Nil(pers.Car)
212 | 	} else {
213 | 		c = *car
214 | 		require.False(c.GetID().IsEmpty(), "ID should not be empty")
215 | 		c.Owner = nil
216 | 		c.events = nil
217 | 		require.Equal(&c, pers.Car)
218 | 	}
219 | 	for i, p := range petList {
220 | 		require.Equal(p, pers.Pets[i])
221 | 	}
222 | }
223 |
224 | // getPerson loads a Person with its Car and Pets, failing the test if none is found.
225 | func (s *RelationshipsSuite) getPerson() *Person {
226 | 	require := s.Require()
227 | 	q := NewPersonQuery().
228 | 		WithCar().
229 | 		WithPets(nil)
230 | 	pers, err := NewPersonStore(s.db).FindOne(q)
231 | 	require.NoError(err)
232 | 	require.NotNil(pers)
233 |
234 | 	return pers
235 | }
236 |
--------------------------------------------------------------------------------
/tests/resultset.go:
--------------------------------------------------------------------------------
1 | package tests
2 |
3 | import "gopkg.in/src-d/go-kallax.v1"
4 |
5 | // ResultSetFixture is a minimal model stored in the "resultset" table.
6 | type ResultSetFixture struct {
7 | 	kallax.Model `table:"resultset"`
8 | 	ID           kallax.ULID `pk:""`
9 | 	Foo          string
10 | }
11 |
12 | // newResultSetFixture builds a ResultSetFixture with a fresh ULID and the given Foo.
13 | func newResultSetFixture(f string) *ResultSetFixture {
14 | 	return &ResultSetFixture{ID: kallax.NewULID(), Foo: f}
15 | }
16 |
--------------------------------------------------------------------------------
/tests/resultset_test.go:
--------------------------------------------------------------------------------
1 | package tests
2 |
3 | import (
4 | 	"errors"
5 | 	"testing"
6 |
7 | 	"github.com/stretchr/testify/suite"
8 | 	"gopkg.in/src-d/go-kallax.v1"
9 | )
10 |
11 | // ResulsetSuite tests reading ResultSetFixture rows through kallax result sets.
12 | type ResulsetSuite struct {
13 | 	BaseTestSuite
14 | }
15 |
16 | // TestResulsetSuite creates the resultset table and runs ResulsetSuite.
17 | func TestResulsetSuite(t *testing.T) {
18 | 	schema := []string{
19 | 		`CREATE TABLE IF NOT EXISTS resultset (
20 | 			id uuid primary key,
21 | 			foo varchar(10)
22 | 		)`,
23 | 	}
24 | 	suite.Run(t, &ResulsetSuite{NewBaseSuite(schema, "resultset")})
25 | }
26 |
27 | func (s *ResulsetSuite) TestResultSetAll() {
28 | 	store := NewResultSetFixtureStore(s.db)
29 | 	s.Nil(store.Insert(NewResultSetFixture("bar")))
30 | 	s.Nil(store.Insert(NewResultSetFixture("foo")))
31 |
32 | 	s.NotPanics(func() {
33 | 		rs := store.MustFind(NewResultSetFixtureQuery())
34 | 		docs, err := rs.All()
35 | 		s.Nil(err)
36 | 		s.Len(docs, 2)
37 | 	})
38 | }
39 |
40 | func (s *ResulsetSuite) TestResultSetOne() {
41 | 	store := NewResultSetFixtureStore(s.db)
42 | 	s.Nil(store.Insert(NewResultSetFixture("bar")))
43 |
44 | 	s.NotPanics(func() {
45 | 		rs := store.MustFind(NewResultSetFixtureQuery())
46 | 		doc, err := rs.One()
47 | 		s.Nil(err)
48 | 		s.Equal("bar", doc.Foo)
49 | 	})
50 | }
51 |
52 | func (s *ResulsetSuite) TestResultSetNextEmpty() {
53 | 	store := NewResultSetFixtureStore(s.db)
54 |
55 | 	s.NotPanics(func() {
56 | 		rs := store.MustFind(NewResultSetFixtureQuery())
57 | 		returned := rs.Next()
58 | 		s.False(returned)
59 |
60 | 		doc, err := rs.Get()
61 | 		s.Nil(err)
62 | 		s.Nil(doc)
63 | 	})
64 | }
65 |
66 | func (s *ResulsetSuite) TestResultSetNext() {
67 | 	store := NewResultSetFixtureStore(s.db)
68 | 	s.Nil(store.Insert(NewResultSetFixture("bar")))
69 |
70 | 	s.NotPanics(func() {
71 | 		rs := store.MustFind(NewResultSetFixtureQuery())
72 | 		returned := rs.Next()
73 | 		s.True(returned)
74 |
75 | 		doc, err := rs.Get()
76 | 		s.Nil(err)
77 | 		s.Equal("bar", doc.Foo)
78 |
79 | 		returned = rs.Next()
80 | 		s.False(returned)
81 |
82 | 		doc, err = rs.Get()
83 | 		s.Nil(err)
84 | 		s.Nil(doc)
85 | 	})
86 | }
87 |
88 | func (s *ResulsetSuite) TestResultSetForEach() {
89 | 	store := NewResultSetFixtureStore(s.db)
90 | 	s.Nil(store.Insert(NewResultSetFixture("bar")))
91 | 	s.Nil(store.Insert(NewResultSetFixture("foo")))
92 |
93 | 	s.NotPanics(func() {
94 | 		count := 0
95 | 		rs := store.MustFind(NewResultSetFixtureQuery())
96 | 		err := rs.ForEach(func(*ResultSetFixture) error {
97 | 			count++
98 | 			return nil
99 | 		})
100 |
101 | 		s.Nil(err)
102 | 		s.Equal(2, count)
103 | 	})
104 | }
105 |
106 | // TestResultSetForEachStop checks that returning kallax.ErrStop ends ForEach early without an error.
107 | func (s *ResulsetSuite) TestResultSetForEachStop() {
108 | 	store := NewResultSetFixtureStore(s.db)
109 | 	s.Nil(store.Insert(NewResultSetFixture("bar")))
110 | 	s.Nil(store.Insert(NewResultSetFixture("foo")))
111 |
112 | 	s.NotPanics(func() {
113 | 		count := 0
114 | 		rs := store.MustFind(NewResultSetFixtureQuery())
115 | 		err := rs.ForEach(func(*ResultSetFixture) error {
116 | 			count++
117 | 			return kallax.ErrStop
118 | 		})
119 |
120 | 		s.Nil(err)
121 | 		s.Equal(1, count)
122 | 	})
123 | }
124 |
125 | // TestResultSetForEachError checks that ForEach returns the error produced by the callback.
126 | func (s *ResulsetSuite) TestResultSetForEachError() {
127 | 	store := NewResultSetFixtureStore(s.db)
128 | 	s.Nil(store.Insert(NewResultSetFixture("bar")))
129 | 	s.Nil(store.Insert(NewResultSetFixture("foo")))
130 |
131 | 	fail := errors.New("kallax: foo")
132 |
133 | 	s.NotPanics(func() {
134 | 		rs := store.MustFind(NewResultSetFixtureQuery())
135 | 		defer rs.Close()
136 | 		err := rs.ForEach(func(*ResultSetFixture) error {
137 | 			return fail
138 | 		})
139 |
140 | 		s.Equal(fail, err)
141 | 	})
142 | }
143 |
144 | // TestForEachAndCount checks that ForEach visits as many rows as Count reports.
145 | func (s *ResulsetSuite) TestForEachAndCount() {
146 | 	store := NewResultSetFixtureStore(s.db)
147 |
148 | 	docInserted1 := NewResultSetFixture("bar")
149 | 	s.Nil(store.Insert(docInserted1))
150 | 	docInserted2 := NewResultSetFixture("baz")
151 | 	s.Nil(store.Insert(docInserted2))
152 |
153 | 	query := NewResultSetFixtureQuery()
154 | 	rs, err := store.Find(query)
155 | 	s.Nil(err)
156 | 	manualCount := 0
157 | 	s.NoError(rs.ForEach(func(doc *ResultSetFixture) error {
158 | 		manualCount++
159 | 		s.NotNil(doc)
160 | 		return nil
161 | 	}))
162 | 	s.Equal(2, manualCount)
163 |
164 | 	queriedCount, err := store.Count(query)
165 | 	s.NoError(err)
166 | 	s.Equal(int64(2), queriedCount)
167 | }
168 |
--------------------------------------------------------------------------------
/tests/schema.go:
--------------------------------------------------------------------------------
1 | package tests
2 |
3 | import "gopkg.in/src-d/go-kallax.v1"
4 |
5 | // SchemaFixture is a "schema" model mixing ignored, inline, map and inverse fields.
6 | type SchemaFixture struct {
7 | 	kallax.Model `table:"schema"`
8 | 	ID           kallax.ULID `pk:""`
9 | 	ShouldIgnore string      `kallax:"-"`
10 | 	String       string
11 | 	Int          int
12 | 	Nested       *SchemaFixture
13 | 	Inline       struct {
14 | 		Inline string
15 | 	} `kallax:",inline"`
16 | 	MapOfString    map[string]string
17 | 	MapOfInterface map[string]interface{}
18 | 	MapOfSomeType  map[string]struct {
19 | 		Foo string
20 | 	}
21 | 	Inverse *SchemaRelationshipFixture `fk:"rel_id,inverse"`
22 | }
23 |
24 | // SchemaRelationshipFixture is the "relationship" model referenced by SchemaFixture.Inverse.
25 | type SchemaRelationshipFixture struct {
26 | 	kallax.Model `table:"relationship"`
27 | 	ID           kallax.ULID `pk:""`
28 | }
29 |
30 | // newSchemaFixture builds a SchemaFixture with a fresh ULID.
31 | func newSchemaFixture() *SchemaFixture {
32 | 	return &SchemaFixture{ID: kallax.NewULID()}
33 | }
34 |
--------------------------------------------------------------------------------
/tests/schema_test.go:
--------------------------------------------------------------------------------
1 | package tests
2 |
3 | import (
4 | 	"reflect"
5 | 	"testing"
6 |
7 | 	"github.com/stretchr/testify/suite"
8 | )
9 |
10 | // SchemaSuite checks the column names exposed by Schema.SchemaFixture.
11 | type SchemaSuite struct {
12 | 	BaseTestSuite
13 | }
14 |
15 | func TestSchemaSuite(t *testing.T) {
16 | 	suite.Run(t, new(SchemaSuite))
17 | }
18 |
19 | func (s *SchemaSuite) TestSchemaID() {
20 | 	s.Equal("id", Schema.SchemaFixture.ID.String())
21 | }
22 |
23 | func (s *SchemaSuite) TestSchemaBasicField() {
24 | 	s.Equal("string", Schema.SchemaFixture.String.String())
25 | }
26 |
27 | func (s *SchemaSuite) TestSchemaRenamedField() {
28 | 	s.Equal("int", Schema.SchemaFixture.Int.String())
29 | }
30 |
31 | func (s *SchemaSuite) TestSchemaInlineField() {
32 | 	s.Equal("inline", Schema.SchemaFixture.Inline.String())
33 | }
34 |
35 | func (s *SchemaSuite) TestSchemaMapsOfString() {
36 | 	s.Equal("map_of_string", Schema.SchemaFixture.MapOfString.String())
37 | }
38 |
39 | func (s *SchemaSuite) TestSchemaMapOfSomeType() {
40 | 	s.Equal("map_of_some_type", Schema.SchemaFixture.MapOfSomeType.String())
41 | }
42 |
43 | func (s *SchemaSuite) TestSchemaMapOfInterface() {
44 | 	s.Equal("map_of_interface", Schema.SchemaFixture.MapOfInterface.String())
45 | }
46 |
47 | // TestSchemaInverse checks that the inverse relationship is exposed as its foreign key column.
48 | func (s *SchemaSuite) TestSchemaInverse() {
49 | 	s.Equal("rel_id", Schema.SchemaFixture.InverseFK.String())
50 | }
51 |
52 | // TestSchemaIgnored checks that the field tagged kallax:"-" has no entry in the schema.
53 | func (s *SchemaSuite) TestSchemaIgnored() {
54 | 	schema := reflect.ValueOf(Schema.SchemaFixture)
55 | 	field := reflect.Indirect(schema).FieldByName("ShouldIgnore")
56 | 	s.False(field.IsValid())
57 | }
58 |
--------------------------------------------------------------------------------
/tests/store.go:
--------------------------------------------------------------------------------
1 | package tests
2 |
3 | import (
4 | 	"time"
5 |
6 | 	"gopkg.in/src-d/go-kallax.v1"
7 | )
8 |
9 | // A, B and C form a chain of one-to-one relationships: A has a B and B has a C,
10 | // with each child pointing back to its parent through an inverse foreign key.
11 | type A struct {
12 | 	kallax.Model `table:"a" pk:"id,autoincr"`
13 | 	ID           int64
14 | 	Name         string
15 | 	B            *B
16 | }
17 |
18 | func newA(name string) *A {
19 | 	return &A{Name: name}
20 | }
21 |
22 | type B struct {
23 | 	kallax.Model `table:"b" pk:"id,autoincr"`
24 | 	ID           int64
25 | 	Name         string
26 | 	A            *A `fk:",inverse"`
27 | 	C            *C
28 | }
29 |
30 | // newB builds a B whose inverse parent is a and sets it as a.B.
31 | func newB(name string, a *A) *B {
32 | 	b := &B{Name: name, A: a}
33 | 	a.B = b
34 | 	return b
35 | }
36 |
37 | type C struct {
38 | 	kallax.Model `table:"c" pk:"id,autoincr"`
39 | 	ID           int64
40 | 	Name         string
41 | 	B            *B `fk:",inverse"`
42 | }
43 |
44 | // newC builds a C whose inverse parent is b and sets it as b.C.
45 | func newC(name string, b *B) *C {
46 | 	c := &C{Name: name, B: b}
47 | 	b.C = c
48 | 	return c
49 | }
50 |
51 | // AliasSliceString is a named []string type, used by StoreFixture.AliasSliceProp.
52 | type AliasSliceString []string
53 |
54 | type StoreFixture struct {
55 | 	kallax.Model   `table:"store" pk:"id"`
56 | 	ID             kallax.ULID
57 | 	Foo            string
58 | 	SliceProp      []string
59 | 	AliasSliceProp AliasSliceString
60 | }
61 |
62 | func newStoreFixture() *StoreFixture {
63 | 	return &StoreFixture{ID: kallax.NewULID()}
64 | }
65 |
66 | type StoreWithConstructFixture struct {
67 | 	kallax.Model `table:"store_construct"`
68 | 	ID           kallax.ULID `pk:""`
69 | 	Foo          string
70 | }
71 |
72 | // newStoreWithConstructFixture returns nil for an empty f, otherwise a fixture with a fresh ULID.
73 | func newStoreWithConstructFixture(f string) *StoreWithConstructFixture {
74 | 	if f == "" {
75 | 		return nil
76 | 	}
77 | 	return &StoreWithConstructFixture{ID: kallax.NewULID(), Foo: f}
78 | }
79 |
80 | type StoreWithNewFixture struct {
81 | 	kallax.Model `table:"store_new"`
82 | 	ID           kallax.ULID `pk:""`
83 | 	Foo          string
84 | 	Bar          string
85 | }
86 |
87 | func newStoreWithNewFixture() *StoreWithNewFixture {
88 | 	return &StoreWithNewFixture{ID: kallax.NewULID()}
89 | }
90 |
91 | // MultiKeySortFixture is a model mapped to the "query" table.
92 | type MultiKeySortFixture struct {
93 | 	kallax.Model `table:"query"`
94 | 	ID           kallax.ULID `pk:""`
95 | 	Name         string
96 | 	Start        time.Time
97 | 	End          time.Time
98 | }
99 |
100 | func newMultiKeySortFixture() *MultiKeySortFixture {
101 | 	return &MultiKeySortFixture{ID: kallax.NewULID()}
102 | }
103 |
104 | type SomeJSON struct {
105 | 	Foo int
106 | }
107 |
108 | // Nullable is a "nullable" model whose non-ID fields are all pointers.
109 | type Nullable struct {
110 | 	kallax.Model `table:"nullable"`
111 | 	ID           int64 `pk:"autoincr"`
112 | 	T            *time.Time
113 | 	SomeJSON     *SomeJSON
114 | 	Scanner      *kallax.ULID
115 | }
116 |
117 | type Parent struct {
118 | 	kallax.Model `table:"parents" pk:"id,autoincr"`
119 | 	ID           int64
120 | 	Name         string
121 | 	Children     []*Child
122 | }
123 |
124 | type Child struct {
125 | 	kallax.Model `table:"children"`
126 | 	ID           int64 `pk:"autoincr"`
127 | 	Name         string
128 | }
129 |
130 | // ParentNoPtr uses the same "parents" table as Parent but holds Children by value.
131 | type ParentNoPtr struct {
132 | 	kallax.Model `table:"parents"`
133 | 	ID           int64 `pk:"autoincr"`
134 | 	Name         string
135 | 	Children     []Child `fk:"parent_id"`
136 | }
137 |
--------------------------------------------------------------------------------
/timestamps.go:
--------------------------------------------------------------------------------
1 | package kallax
2 |
3 | import "time"
4 |
5 | // Timestamps contains the dates when the model was created and last
6 | // updated. Because this is such a common functionality in models, it is
7 | // provided by default by the library. It is intended to be embedded in the
8 | // model.
9 | //
10 | //	type MyModel struct {
11 | //		kallax.Model
12 | //		kallax.Timestamps
13 | //		Foo string
14 | //	}
15 | type Timestamps struct {
16 | 	// CreatedAt is the time when the object was created.
17 | 	CreatedAt time.Time
18 | 	// UpdatedAt is the time when the object was last updated.
19 | 	UpdatedAt time.Time
20 | }
21 |
22 | // BeforeSave sets UpdatedAt every time the model is saved, and sets CreatedAt
23 | // only if the model has no creation date yet. Both fields are stamped with the
24 | // same instant, so a model saved for the first time has CreatedAt == UpdatedAt.
25 | func (t *Timestamps) BeforeSave() error {
26 | 	now := time.Now()
27 | 	if t.CreatedAt.IsZero() {
28 | 		t.CreatedAt = now
29 | 	}
30 |
31 | 	t.UpdatedAt = now
32 | 	return nil
33 | }
34 |
--------------------------------------------------------------------------------
/timestamps_test.go:
--------------------------------------------------------------------------------
1 | package kallax
2 |
3 | import (
4 | 	"testing"
5 | 	"time"
6 |
7 | 	"github.com/stretchr/testify/require"
8 | )
9 |
10 | func TestTimestampsBeforeSave(t *testing.T) {
11 | 	s := require.New(t)
12 |
13 | 	var ts Timestamps
14 | 	s.True(ts.CreatedAt.IsZero())
15 | 	s.True(ts.UpdatedAt.IsZero())
16 |
17 | 	s.NoError(ts.BeforeSave())
18 | 	s.False(ts.CreatedAt.IsZero())
19 | 	s.False(ts.UpdatedAt.IsZero())
20 | 	s.True(ts.CreatedAt.Equal(ts.UpdatedAt), "first save should stamp both fields with the same time")
21 |
22 | 	createdAt := ts.CreatedAt
23 | 	updatedAt := ts.UpdatedAt
24 | 	time.Sleep(50 * time.Millisecond)
25 | 	s.NoError(ts.BeforeSave())
26 | 	s.Equal(createdAt, ts.CreatedAt)
27 | 	s.NotEqual(updatedAt, ts.UpdatedAt)
28 | }
29 |
--------------------------------------------------------------------------------
/types/slices_test.go:
--------------------------------------------------------------------------------
1 | package types
2 |
3 | import (
4 | 	"database/sql"
5 | 	"fmt"
6 | 	"math"
7 | 	"net/url"
8 | 	"os"
9 | 	"reflect"
10 | 	"testing"
11 |
12 | 	"github.com/lib/pq"
13 | 	"github.com/stretchr/testify/require"
14 | )
15 |
16 | func TestSlice(t *testing.T) {
17 | 	cases := []struct {
18 | 		v     interface{}
19 | 		input interface{}
20 | 		dest  interface{}
21 | 	}{
22 | 		{
23 | 			&([]url.URL{mustURL("https://foo.com"), mustURL("http://foo.foo")}),
24 | 			[]string{"https://foo.com", "http://foo.foo"},
25 | 			&([]url.URL{}),
26 | 		},
27 | 		{
28 | 			&([]*url.URL{mustPtrURL("https://foo.com"), mustPtrURL("http://foo.foo")}),
29 | 			[]string{"https://foo.com", "http://foo.foo"},
30 | 			&([]*url.URL{}),
31 | 		},
32 | 		{
33 | 			&([]string{"a", "b"}),
34 | 			[]string{"a", "b"},
35 | 			&([]string{}),
36 | 		},
37 | 		{
38 | 			&([]uint64{123, 321, 333}),
39 | 			[]uint64{123, 321, 333},
40 | 			&([]uint64{}),
41 | 		},
42 | 		{
43 | 			&([]int{123, 321, 333}),
44 | 			[]int{123, 321, 333},
45 | 			&([]int{}),
46 | 		},
47 | 		{
48 | 			&([]uint{123, 321, 333}),
49 | 			[]uint{123, 321, 333},
50 | 			&([]uint{}),
51 | 		},
52 | 		{
53 | 			&([]int32{123, 321, 333}),
54 | 			[]int32{123, 321, 333},
55 | 			&([]int32{}),
56 | 		},
57 | 		{
58 | 			&([]uint32{123, 321, 333}),
59 | 			[]uint32{123, 321, 333},
60 | 			&([]uint32{}),
61 | 		},
62 | 		{
63 | 			&([]int16{123, 321, 333}),
64 | 			[]int16{123, 321, 333},
65 | 			&([]int16{}),
66 | 		},
67 | 		{
68 | 			&([]uint16{123, 321, 333}),
69 | 			[]uint16{123, 321, 333},
70 | 			&([]uint16{}),
71 | 		},
72 | 		{
73 | 			&([]int8{1, 3, 4}),
74 | 			[]int8{1, 3, 4},
75 | 			&([]int8{}),
76 | 		},
77 | 		{
78 | 			&([]float32{1., 3., .4}),
79 | 			[]float32{1., 3., .4},
80 | 			&([]float32{}),
81 | 		},
82 | 	}
83 |
84 | 	for _, c := range cases {
85 | 		t.Run(reflect.TypeOf(c.input).String(), func(t *testing.T) {
86 | 			require := require.New(t)
87 | 			arr := Slice(c.v)
88 | 			val, err := arr.Value()
89 | 			require.NoError(err)
90 |
91 | 			pqArr := pq.Array(c.input)
92 | 			pqVal, err := pqArr.Value()
93 | 			require.NoError(err)
94 |
95 | 			require.Equal(pqVal, val)
96 | 			require.NoError(Slice(c.dest).Scan(val))
97 | 			require.Equal(c.v, c.dest)
98 | 		})
99 | 	}
100 |
101 | 	t.Run("[]byte", func(t *testing.T) {
102 | 		require := require.New(t)
103 | 		arr := Slice([]byte{1, 2, 3})
104 | 		val, err := arr.Value()
105 | 		require.NoError(err)
106 |
107 | 		var b []byte
108 | 		require.NoError(Slice(&b).Scan(val))
109 | 		require.Equal([]byte{1, 2, 3}, b)
110 | 	})
111 | }
112 |
113 | func TestSlice_Integration(t *testing.T) {
114 | 	cases := []struct {
115 | 		name  string
116 | 		typ   string
117 | 		input interface{}
118 | 		dst   interface{}
119 | 	}{
120 | 		{
121 | 			"int8",
122 | 			"smallint[]",
123 | 			[]int8{math.MaxInt8, math.MinInt8},
124 | 			&([]int8{}),
125 | 		},
126 | 		{
127 | 			"byte",
128 | 			"bytea",
129 | 			[]byte{math.MaxUint8, 0},
130 | 			&([]byte{}),
131 | 		},
132 | 		{
133 | 			"int16",
134 | 			"smallint[]",
135 | 			[]int16{math.MaxInt16, math.MinInt16},
136 | 			&([]int16{}),
137 | 		},
138 | 		{
139 | 			"unsigned int16",
140 | 			"integer[]",
141 | 			[]uint16{math.MaxUint16, 0},
142 | 			&([]uint16{}),
143 | 		},
144 | 		{
145 | 			"int32",
146 | 			"integer[]",
147 | 			[]int32{math.MaxInt32, math.MinInt32},
148 | 			&([]int32{}),
149 | 		},
150 | 		{
151 | 			"unsigned int32",
152 | 			"bigint[]",
153 | 			[]uint32{math.MaxUint32, 0},
154 | 			&([]uint32{}),
155 | 		},
156 | 		{
157 | 			"int/int64",
158 | 			"bigint[]",
159 | 			[]int{math.MaxInt64, math.MinInt64},
160 | 			&([]int{}),
161 | 		},
162 | 		{
163 | 			"unsigned int/int64",
164 | 			"numeric(20)[]",
165 | 			[]uint{math.MaxUint64, 0},
166 | 			&([]uint{}),
167 | 		},
168 | 		{
169 | 			"float32",
170 | 			"decimal(10,3)[]",
171 | 			[]float32{.3, .6},
172 | 			&([]float32{.3, .6}),
173 | 		},
174 | 	}
175 |
176 | 	db, err := openTestDB()
177 | 	require.NoError(t, err)
178 |
179 | 	defer func() {
180 | 		_, err = db.Exec("DROP TABLE IF EXISTS foo")
181 | 		require.NoError(t, err)
182 |
183 | 		require.NoError(t, db.Close())
184 | 	}()
185 |
186 | 	for _, c := range cases {
187 | 		t.Run(c.name, func(t *testing.T) {
188 | 			require := require.New(t)
189 |
190 | 			_, err := db.Exec(fmt.Sprintf(`CREATE TABLE foo (
191 | 				testcol %s
192 | 			)`, c.typ))
193 | 			require.NoError(err, c.name)
194 |
195 | 			defer func() {
196 | 				_, err := db.Exec("DROP TABLE foo")
197 | 				require.NoError(err)
198 | 			}()
199 |
200 | 			_, err = db.Exec("INSERT INTO foo (testcol) VALUES ($1)", Slice(c.input))
201 | 			require.NoError(err, c.name)
202 |
203 | 			require.NoError(db.QueryRow("SELECT testcol FROM foo LIMIT 1").Scan(Slice(c.dst)), c.name)
204 | 			slice := reflect.ValueOf(c.dst).Elem().Interface()
205 | 			require.Equal(c.input, slice, c.name)
206 | 		})
207 | 	}
208 | }
209 |
210 | // TestByteArray_ScannerNoBufferReuse checks that ByteArray.Scan copies src instead of aliasing it.
211 | func TestByteArray_ScannerNoBufferReuse(t *testing.T) {
212 | 	require := require.New(t)
213 |
214 | 	var sharedbuf [32]byte
215 |
216 | 	var ba ByteArray
217 |
218 | 	err := ba.Scan(sharedbuf[:])
219 | 	require.NoError(err)
220 |
221 | 	// Modify the "driver" buffer src
222 | 	sharedbuf[0] = 1
223 |
224 | 	require.Equal(uint8(0), ba[0], "ByteArray should not share reference with scanned src")
225 | }
226 |
227 | // envOrDefault returns the value of the environment variable key, or def when it is empty or unset.
228 | func envOrDefault(key string, def string) string {
229 | 	v := os.Getenv(key)
230 | 	if v == "" {
231 | 		v = def
232 | 	}
233 | 	return v
234 | }
235 |
236 | // openTestDB returns a handle to the Postgres database described by DBUSER, DBPASS,
237 | // DBHOST and DBNAME, falling back to the built-in testing defaults for unset variables.
238 | func openTestDB() (*sql.DB, error) {
239 | 	return sql.Open("postgres", fmt.Sprintf(
240 | 		"postgres://%s:%s@%s/%s?sslmode=disable",
241 | 		envOrDefault("DBUSER", "testing"),
242 | 		envOrDefault("DBPASS", "testing"),
243 | 		envOrDefault("DBHOST", "0.0.0.0:5432"),
244 | 		envOrDefault("DBNAME", "testing"),
245 | 	))
246 | }
247 |
--------------------------------------------------------------------------------
/types/types_test.go:
--------------------------------------------------------------------------------
1 | package types
2 |
3 | import (
4 | 	"fmt"
5 | 	"net/url"
6 | 	"reflect"
7 | 	"testing"
8 | 	"time"
9 |
10 | 	"github.com/lib/pq"
11 | 	"github.com/stretchr/testify/require"
12 | )
13 |
14 | func TestURL(t *testing.T) {
15 | 	require := require.New(t)
16 | 	expectedURL := "https://foo.com"
17 |
18 | 	var u URL
19 | 	require.Nil(u.Scan(expectedURL))
20 | 	require.Equal(expectedURL, urlStr(url.URL(u)))
21 |
22 | 	u = URL{}
23 | 	require.Nil(u.Scan([]byte("https://foo.com")))
24 | 	require.Equal(expectedURL, urlStr(url.URL(u)))
25 |
26 | 	val, err := u.Value()
27 | 	require.Nil(err)
28 | 	require.Equal(expectedURL, val)
29 | }
30 |
31 | func urlStr(u url.URL) string {
32 | 	url := &u
33 | 	return url.String()
34 | }
35 |
36 | func mustURL(u string) url.URL {
37 | 	url, _ := url.Parse(u)
38 | 	return *url
39 | }
40 |
41 | func mustPtrURL(u string) *url.URL {
42 | 	url, _ := url.Parse(u)
43 | 	return url
44 | }
45 |
46 | type jsonType struct {
47 | 	Foo string `json:"foo"`
48 | 	Bar int    `json:"bar"`
49 | }
50 |
51 | func TestJSON(t *testing.T) {
52 | 	input := `{"foo":"a","bar":1}`
53 |
54 | 	t.Run("into object", func(t *testing.T) {
55 | 		var dst jsonType
56 | 		expected := jsonType{"a", 1}
57 |
58 | 		json := JSON(&dst)
59 | 		require.Nil(t, json.Scan([]byte(input)))
60 | 		require.Equal(t, expected, dst)
61 |
62 | 		val, err := json.Value()
63 | 		require.Nil(t, err)
64 | 		require.Equal(t, input, string(val.([]byte)))
65 | 	})
66 |
67 | 	t.Run("into map", func(t *testing.T) {
68 | 		var dst = make(map[string]interface{})
69 |
70 | 		json := JSON(&dst)
71 | 		require.Nil(t, json.Scan([]byte(input)))
72 | 		val, ok := dst["foo"]
73 | 		require.True(t, ok)
74 | 		require.Equal(t, "a", val.(string))
75 |
76 | 		val, ok = dst["bar"]
77 | 		require.True(t, ok)
78 | 		require.Equal(t, float64(1), val.(float64))
79 | 	})
80 |
81 | 	t.Run("nil input", func(t *testing.T) {
82 | 		require.NoError(t, JSON(&map[string]interface{}{}).Scan(nil))
83 | 	})
84 | }
85 |
86 | func TestArray(t *testing.T) {
87 | 	require := require.New(t)
88 | 	input, err := pq.Array([]int64{1, 2}).Value()
89 | 	require.Nil(err)
90 |
91 | 	var dst [2]int64
92 |
93 | 	arr := Array(&dst, 2)
94 | 	require.Nil(arr.Scan(input))
95 | 	require.Equal(int64(1), dst[0])
96 | 	require.Equal(int64(2), dst[1])
97 |
98 | 	v, err := arr.Value()
99 | 	require.Nil(err)
100 | 	require.Equal(input, v)
101 | }
102 |
103 | func TestNullable(t *testing.T) {
104 | 	var (
105 | 		Str         string
106 | 		Int8        int8
107 | 		Uint8       uint8
108 | 		Byte        byte
109 | 		Int16       int16
110 | 		Uint16      uint16
111 | 		Int32       int32
112 | 		Uint32      uint32
113 | 		Int         int
114 | 		Uint        uint
115 | 		Int64       int64
116 | 		Uint64      uint64
117 | 		Float32     float32
118 | 		Float64     float64
119 | 		Bool        bool
120 | 		Time        time.Time
121 | 		Duration    time.Duration
122 | 		Url         URL
123 | 		PtrStr      *string
124 | 		PtrInt8     *int8
125 | 		PtrUint8    *uint8
126 | 		PtrByte     *byte
127 | 		PtrInt16    *int16
128 | 		PtrUint16   *uint16
129 | 		PtrInt32    *int32
130 | 		PtrUint32   *uint32
131 | 		PtrInt      *int
132 | 		PtrUint     *uint
133 | 		PtrInt64    *int64
134 | 		PtrUint64   *uint64
135 | 		PtrFloat32  *float32
136 | 		PtrFloat64  *float64
137 | 		PtrBool     *bool
138 | 		PtrTime     *time.Time
139 | 		PtrDuration *time.Duration
140 | 	)
141 | 	tim := time.Now().UTC()
142 | 	tim = time.Date(tim.Year(), tim.Month(), tim.Day(), tim.Hour(),
143 | 		tim.Minute(), tim.Second(), 0, tim.Location())
144 | 	s := require.New(t)
145 | 	url, err := url.Parse("http://foo.me")
146 | 	s.NoError(err)
147 |
148 | 	cases := []struct {
149 | 		name         string
150 | 		typ          string
151 | 		nonNullInput interface{}
152 | 		dst          interface{}
153 | 		isPtr        bool
154 | 	}{
155 | 		{
156 | 			"string",
157 | 			"text",
158 | 			"foo",
159 | 			&Str,
160 | 			false,
161 | 		},
162 | 		{
163 | 			"int8",
164 | 			"bigint",
165 | 			int8(1),
166 | 			&Int8,
167 | 			false,
168 | 		},
169 | 		{
170 | 			"byte",
171 | 			"bigint",
172 | 			byte(1),
173 | 			&Byte,
174 | 			false,
175 | 		},
176 | 		{
177 | 			"int16",
178 | 			"bigint",
179 | 			int16(1),
180 | 			&Int16,
181 | 			false,
182 | 		},
183 | 		{
184 | 			"int32",
185 | 			"bigint",
186 | 			int32(1),
187 | 			&Int32,
188 | 			false,
189 | 		},
190 | 		{
191 | 			"int",
192 | 			"bigint",
193 | 			int(1),
194 | 			&Int,
195 | 			false,
196 | 		},
197 | 		{
198 | 			"int64",
199 | 			"bigint",
200 | 			int64(1),
201 | 			&Int64,
202 | 			false,
203 | 		},
204 | 		{
205 | 			"uint8",
206 | 			"bigint",
207 | 			uint8(1),
208 | 			&Uint8,
209 | 			false,
210 | 		},
211 | 		{
212 | 			"uint16",
213 | 			"bigint",
214 | 			uint16(1),
215 | 			&Uint16,
216 | 			false,
217 | 		},
218 | 		{
219 | 			"uint32",
220 | 			"bigint",
221 | 			uint32(1),
222 | 			&Uint32,
223 | 			false,
224 | 		},
225 | 		{
226 | 			"uint",
227 | 			"bigint",
228 | 			uint(1),
229 | 			&Uint,
230 | 			false,
231 | 		},
232 | 		{
233 | 			"uint64",
234 | 			"bigint",
235 | 			uint64(1),
236 | 			&Uint64,
237 | 			false,
238 | 		},
239 | 		{
240 | 			"float32",
241 | 			"decimal",
242 | 			float32(.5),
243 | 			&Float32,
244 | 			false,
245 | 		},
246 | 		{
247 | 			"float64",
248 | 			"decimal",
249 | 			float64(.5),
250 | 			&Float64,
251 | 			false,
252 | 		},
253 | 		{
254 | 			"bool",
255 | 			"bool",
256 | 			true,
257 | 			&Bool,
258 | 			false,
259 | 		},
260 | 		{
261 | 			"time.Duration",
262 | 			"bigint",
263 | 			3 *
time.Second, 264 | &Duration, 265 | false, 266 | }, 267 | { 268 | "time.Time", 269 | "timestamptz", 270 | tim, 271 | &Time, 272 | false, 273 | }, 274 | { 275 | "URL", 276 | "text", 277 | URL(*url), 278 | &Url, 279 | false, 280 | }, 281 | { 282 | "*string", 283 | "text", 284 | "foo", 285 | &PtrStr, 286 | true, 287 | }, 288 | { 289 | "*int8", 290 | "bigint", 291 | int8(1), 292 | &PtrInt8, 293 | true, 294 | }, 295 | { 296 | "*byte", 297 | "bigint", 298 | byte(1), 299 | &PtrByte, 300 | true, 301 | }, 302 | { 303 | "*int16", 304 | "bigint", 305 | int16(1), 306 | &PtrInt16, 307 | true, 308 | }, 309 | { 310 | "*int32", 311 | "bigint", 312 | int32(1), 313 | &PtrInt32, 314 | true, 315 | }, 316 | { 317 | "*int", 318 | "bigint", 319 | int(1), 320 | &PtrInt, 321 | true, 322 | }, 323 | { 324 | "*int64", 325 | "bigint", 326 | int64(1), 327 | &PtrInt64, 328 | true, 329 | }, 330 | { 331 | "*uint8", 332 | "bigint", 333 | uint8(1), 334 | &PtrUint8, 335 | true, 336 | }, 337 | { 338 | "*uint16", 339 | "bigint", 340 | uint16(1), 341 | &PtrUint16, 342 | true, 343 | }, 344 | { 345 | "*uint32", 346 | "bigint", 347 | uint32(1), 348 | &PtrUint32, 349 | true, 350 | }, 351 | { 352 | "*uint", 353 | "bigint", 354 | uint(1), 355 | &PtrUint, 356 | true, 357 | }, 358 | { 359 | "*uint64", 360 | "bigint", 361 | uint64(1), 362 | &PtrUint64, 363 | true, 364 | }, 365 | { 366 | "*float32", 367 | "decimal", 368 | float32(.5), 369 | &PtrFloat32, 370 | true, 371 | }, 372 | { 373 | "*float64", 374 | "decimal", 375 | float64(.5), 376 | &PtrFloat64, 377 | true, 378 | }, 379 | { 380 | "*bool", 381 | "bool", 382 | true, 383 | &PtrBool, 384 | true, 385 | }, 386 | { 387 | "*time.Duration", 388 | "bigint", 389 | 3 * time.Second, 390 | &PtrDuration, 391 | true, 392 | }, 393 | { 394 | "*time.Time", 395 | "timestamptz", 396 | tim, 397 | &PtrTime, 398 | true, 399 | }, 400 | } 401 | 402 | db, err := openTestDB() 403 | s.Nil(err) 404 | 405 | defer func() { 406 | _, err = db.Exec("DROP TABLE IF EXISTS foo") 407 | 
s.Nil(err) 408 | s.Nil(db.Close()) 409 | }() 410 | 411 | for _, c := range cases { 412 | s.Nil(db.QueryRow("SELECT null").Scan(Nullable(c.dst)), c.name) 413 | elem := reflect.ValueOf(c.dst).Elem() 414 | zero := reflect.Zero(elem.Type()) 415 | s.Equal(zero.Interface(), elem.Interface(), c.name) 416 | 417 | var input = c.nonNullInput 418 | if v, ok := c.nonNullInput.(time.Duration); ok { 419 | input = int64(v) 420 | } 421 | 422 | _, err := db.Exec(fmt.Sprintf(`CREATE TABLE foo ( 423 | testcol %s 424 | )`, c.typ)) 425 | s.Nil(err, c.name) 426 | 427 | _, err = db.Exec("INSERT INTO foo (testcol) VALUES ($1)", input) 428 | s.Nil(err, c.name) 429 | 430 | s.Nil(db.QueryRow("SELECT testcol FROM foo").Scan(Nullable(c.dst)), c.name) 431 | elem = reflect.ValueOf(c.dst).Elem() 432 | if c.isPtr { 433 | elem = elem.Elem() 434 | } 435 | 436 | result := elem.Interface() 437 | switch v := result.(type) { 438 | case time.Time: 439 | result = v.UTC() 440 | } 441 | 442 | s.Equal(c.nonNullInput, result, c.name) 443 | 444 | _, err = db.Exec("DROP TABLE foo") 445 | s.Nil(err, c.name) 446 | } 447 | } 448 | --------------------------------------------------------------------------------