├── gen ├── .gitignore ├── template.go ├── meta_test.go ├── sort_test.go ├── aggregate_test.go ├── predicate_test.go ├── repository_test.go ├── generate_test.go ├── file.go ├── internal │ └── user.go ├── meta.go ├── sort.go ├── generate.go ├── aggregate.go ├── predicate.go └── repository.go ├── test ├── gen │ └── customtypes │ │ ├── .gitignore │ │ ├── customtypes_test.go │ │ └── customtypes.go └── integration │ ├── playerrepo │ ├── meta.go │ ├── sort.go │ ├── repository_test.go │ ├── postgres_test.go │ ├── sqlite_test.go │ ├── aggregate.go │ ├── repository.go │ ├── predicate_test.go │ ├── sqlite.go │ ├── postgres.go │ └── predicate.go │ ├── generate.go │ └── player │ └── player.go ├── sort ├── sort_func.go ├── sort.go ├── direction.go └── direction_test.go ├── aggregate ├── agg_func.go ├── aggregate.go ├── operator.go └── operator_test.go ├── comparison ├── pred_func.go ├── predicate.go ├── operator.go └── operator_test.go ├── logger.go ├── tx.go ├── Makefile ├── .gitignore ├── errors_test.go ├── go.mod ├── pg_template_test.go ├── sqlite_template_test.go ├── errors.go ├── sql.go ├── field_builder_test.go ├── x ├── etc │ ├── fmt_src_test.go │ └── fmt_src.go └── strings │ ├── camel_case.go │ └── camel_case_test.go ├── revive.toml ├── field_builder.go ├── .github └── workflows │ ├── test.yml │ └── codeql-analysis.yml ├── schema_builder_test.go ├── template_test.go ├── schema.go ├── schema_builder.go ├── field.go ├── template.go ├── go.sum ├── README.md ├── LICENSE ├── sqlite_template.go └── pg_template.go /gen/.gitignore: -------------------------------------------------------------------------------- 1 | userrepo -------------------------------------------------------------------------------- /test/gen/customtypes/.gitignore: -------------------------------------------------------------------------------- 1 | customrepo -------------------------------------------------------------------------------- /sort/sort_func.go: 
-------------------------------------------------------------------------------- 1 | package sort 2 | 3 | // SortFunc is a sort list decorator 4 | type SortFunc func([]*Sort) []*Sort 5 | -------------------------------------------------------------------------------- /aggregate/agg_func.go: -------------------------------------------------------------------------------- 1 | package aggregate 2 | 3 | // AggFunc is an aggregate list decorator 4 | type AggFunc func([]*Aggregate) []*Aggregate 5 | -------------------------------------------------------------------------------- /sort/sort.go: -------------------------------------------------------------------------------- 1 | package sort 2 | 3 | // Sort is a sort parameter 4 | type Sort struct { 5 | Field string 6 | Direction Direction 7 | } 8 | -------------------------------------------------------------------------------- /comparison/pred_func.go: -------------------------------------------------------------------------------- 1 | package comparison 2 | 3 | // PredFunc is a predicate list decorator 4 | type PredFunc func([]*Predicate) []*Predicate 5 | -------------------------------------------------------------------------------- /logger.go: -------------------------------------------------------------------------------- 1 | package nero 2 | 3 | // Logger is an interface that wraps the Printf method 4 | type Logger interface { 5 | Printf(string, ...interface{}) 6 | } 7 | -------------------------------------------------------------------------------- /aggregate/aggregate.go: -------------------------------------------------------------------------------- 1 | package aggregate 2 | 3 | // Aggregate is an aggregate parameter 4 | type Aggregate struct { 5 | Field string 6 | Op Operator 7 | } 8 | -------------------------------------------------------------------------------- /tx.go: -------------------------------------------------------------------------------- 1 | package nero 2 | 3 | // Tx is an interface that wraps 
the Commit and Rollback methods 4 | type Tx interface { 5 | Commit() error 6 | Rollback() error 7 | } 8 | -------------------------------------------------------------------------------- /comparison/predicate.go: -------------------------------------------------------------------------------- 1 | package comparison 2 | 3 | // Predicate is a predicate parameter 4 | type Predicate struct { 5 | Field string 6 | Op Operator 7 | Arg interface{} 8 | } 9 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | DOCKER ?= docker 2 | 3 | # pg-test starts a disposable (--rm) postgres container and blocks until it accepts 4 | # connections. No TTY is requested, so the target also works in non-interactive shells (e.g. CI). 5 | .PHONY: pg-test 6 | pg-test: 7 | $(DOCKER) rm -f pg-test || true 8 | $(DOCKER) run --name pg-test -e POSTGRES_PASSWORD=postgres -d --rm -p 5432:5432 postgres:13 9 | $(DOCKER) exec pg-test bash -c 'while ! pg_isready; do sleep 1; done;' -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | # Dependency directories (remove the comment below to include it) 15 | # vendor/ 16 | -------------------------------------------------------------------------------- /errors_test.go: -------------------------------------------------------------------------------- 1 | package nero_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stevenferrer/nero" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestErrRequiredField(t *testing.T) { 11 | err := nero.NewErrRequiredField("Name") 12 | expect := `Name field is required` 13 | assert.Equal(t, expect, err.Error()) 14 | } 15 | -------------------------------------------------------------------------------- 
/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/stevenferrer/nero 2 | 3 | go 1.15 4 | 5 | require ( 6 | github.com/Masterminds/squirrel v1.5.0 7 | github.com/hashicorp/go-multierror v1.1.1 8 | github.com/jinzhu/inflection v1.0.0 9 | github.com/lib/pq v1.10.1 10 | github.com/mattn/go-sqlite3 v1.14.7 11 | github.com/pkg/errors v0.9.1 12 | github.com/stevenferrer/mira v0.3.0 13 | github.com/stretchr/testify v1.7.0 14 | golang.org/x/tools v0.1.0 15 | ) 16 | -------------------------------------------------------------------------------- /pg_template_test.go: -------------------------------------------------------------------------------- 1 | package nero_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | "github.com/stretchr/testify/require" 8 | 9 | "github.com/stevenferrer/nero" 10 | ) 11 | 12 | func TestPostgresTemplate(t *testing.T) { 13 | tmpl := nero.NewPostgresTemplate().WithFilename("pg.go") 14 | assert.Equal(t, "pg.go", tmpl.Filename()) 15 | 16 | _, err := nero.ParseTemplate(tmpl) 17 | require.NoError(t, err) 18 | } 19 | -------------------------------------------------------------------------------- /sqlite_template_test.go: -------------------------------------------------------------------------------- 1 | package nero_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | "github.com/stretchr/testify/require" 8 | 9 | "github.com/stevenferrer/nero" 10 | ) 11 | 12 | func TestSQLiteTemplate(t *testing.T) { 13 | tmpl := nero.NewSQLiteTemplate().WithFilename("sqlite.go") 14 | assert.Equal(t, "sqlite.go", tmpl.Filename()) 15 | 16 | _, err := nero.ParseTemplate(tmpl) 17 | require.NoError(t, err) 18 | } 19 | -------------------------------------------------------------------------------- /gen/template.go: -------------------------------------------------------------------------------- 1 | package gen 2 | 3 | import ( 4 | "bytes" 5 | 6 |
"github.com/pkg/errors" 7 | "github.com/stevenferrer/nero" 8 | ) 9 | 10 | // newTemplate parses template and executes it with schema as the data. 11 | // On failure it returns a nil buffer, because execution may already have 12 | // written partial output into the buffer before the error occurred. 13 | func newTemplate(schema *nero.Schema, template nero.Template) (*bytes.Buffer, error) { 14 | tmpl, err := nero.ParseTemplate(template) 15 | if err != nil { 16 | return nil, errors.Wrap(err, "parse template") 17 | } 18 | 19 | buf := &bytes.Buffer{} 20 | err = tmpl.Execute(buf, schema) 21 | if err != nil { 22 | return nil, errors.Wrap(err, "execute template") 23 | } 24 | 25 | return buf, nil 26 | } 27 | -------------------------------------------------------------------------------- /errors.go: -------------------------------------------------------------------------------- 1 | package nero 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | // ErrRequiredField is a required field error 8 | type ErrRequiredField struct { 9 | field string 10 | } 11 | 12 | // NewErrRequiredField returns an ErrRequiredField error 13 | func NewErrRequiredField(field string) *ErrRequiredField { 14 | return &ErrRequiredField{field: field} 15 | } 16 | 17 | // Error implements the error interface 18 | func (e *ErrRequiredField) Error() string { 19 | return fmt.Sprintf("%s field is required", e.field) 20 | } 21 | -------------------------------------------------------------------------------- /sort/direction.go: -------------------------------------------------------------------------------- 1 | package sort 2 | 3 | // Direction is a sort direction 4 | type Direction int 5 | 6 | // String returns the name of the sort direction 7 | func (d Direction) String() string { 8 | return [...]string{ 9 | "Asc", 10 | "Desc", 11 | }[d] 12 | } 13 | 14 | // Desc returns a human-readable description of the sort direction 15 | func (d Direction) Desc() string { 16 | return [...]string{ 17 | "ascending", 18 | "descending", 19 | }[d] 20 | } 21 | 22 | const ( 23 | // Asc ascending sort direction 24 | Asc Direction = iota 25 | // Desc descending sort direction 26 | Desc 27 | ) 28 | -------------------------------------------------------------------------------- /gen/meta_test.go: -------------------------------------------------------------------------------- 1 | package gen 2 | 3 | import ( 4 | "go/format" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 |
"github.com/stretchr/testify/require" 9 | 10 | "github.com/stevenferrer/nero/gen/internal" 11 | ) 12 | 13 | func Test_newMetaFile(t *testing.T) { 14 | u := internal.User{} 15 | f, err := newMetaFile(u.Schema()) 16 | require.NoError(t, err) 17 | 18 | _, err = format.Source(f.Bytes()) 19 | require.NoError(t, err) 20 | 21 | _, err = newMetaFile(nil) 22 | assert.Error(t, err) 23 | } 24 | -------------------------------------------------------------------------------- /gen/sort_test.go: -------------------------------------------------------------------------------- 1 | package gen 2 | 3 | import ( 4 | "go/format" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | 10 | "github.com/stevenferrer/nero/gen/internal" 11 | ) 12 | 13 | func Test_newSortFile(t *testing.T) { 14 | u := internal.User{} 15 | f, err := newSortFile(u.Schema()) 16 | require.NoError(t, err) 17 | 18 | _, err = format.Source(f.Bytes()) 19 | require.NoError(t, err) 20 | 21 | _, err = newSortFile(nil) 22 | assert.Error(t, err) 23 | } 24 | -------------------------------------------------------------------------------- /gen/aggregate_test.go: -------------------------------------------------------------------------------- 1 | package gen 2 | 3 | import ( 4 | "go/format" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | 10 | "github.com/stevenferrer/nero/gen/internal" 11 | ) 12 | 13 | func Test_newAggregateFile(t *testing.T) { 14 | u := internal.User{} 15 | f, err := newAggregateFile(u.Schema()) 16 | require.NoError(t, err) 17 | 18 | _, err = format.Source(f.Bytes()) 19 | require.NoError(t, err) 20 | 21 | _, err = newAggregateFile(nil) 22 | assert.Error(t, err) 23 | } 24 | -------------------------------------------------------------------------------- /gen/predicate_test.go: -------------------------------------------------------------------------------- 1 | package gen 2 | 3 | import ( 4 | 
"go/format" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | 10 | "github.com/stevenferrer/nero/gen/internal" 11 | ) 12 | 13 | func Test_newPredicateFile(t *testing.T) { 14 | u := internal.User{} 15 | f, err := newPredicateFile(u.Schema()) 16 | require.NoError(t, err) 17 | 18 | _, err = format.Source(f.Bytes()) 19 | require.NoError(t, err) 20 | 21 | _, err = newPredicateFile(nil) 22 | assert.Error(t, err) 23 | } 24 | -------------------------------------------------------------------------------- /gen/repository_test.go: -------------------------------------------------------------------------------- 1 | package gen 2 | 3 | import ( 4 | "go/format" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | 10 | "github.com/stevenferrer/nero/gen/internal" 11 | ) 12 | 13 | func Test_newRepositoryFile(t *testing.T) { 14 | u := internal.User{} 15 | f, err := newRepositoryFile(u.Schema()) 16 | require.NoError(t, err) 17 | 18 | _, err = format.Source(f.Bytes()) 19 | require.NoError(t, err) 20 | 21 | _, err = newRepositoryFile(nil) 22 | assert.Error(t, err) 23 | } 24 | -------------------------------------------------------------------------------- /test/integration/playerrepo/meta.go: -------------------------------------------------------------------------------- 1 | // Code generated by nero, DO NOT EDIT. 
2 | package playerrepo 3 | 4 | // Table is the database table 5 | const Table = "players" 6 | 7 | // Field is a Player field 8 | type Field int 9 | 10 | // String returns the string representation of the field 11 | func (f Field) String() string { 12 | return [...]string{ 13 | "invalid", 14 | "id", 15 | "email", 16 | "name", 17 | "age", 18 | "race", 19 | "updated_at", 20 | "created_at", 21 | }[f] 22 | } 23 | 24 | const ( 25 | FieldID Field = iota + 1 26 | FieldEmail 27 | FieldName 28 | FieldAge 29 | FieldRace 30 | FieldUpdatedAt 31 | FieldCreatedAt 32 | ) 33 | -------------------------------------------------------------------------------- /test/integration/generate.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | "os" 6 | "path" 7 | 8 | "github.com/stevenferrer/nero/gen" 9 | "github.com/stevenferrer/nero/test/integration/player" 10 | ) 11 | 12 | func main() { 13 | // generate 14 | p := player.Player{} 15 | files, err := gen.Generate(p.Schema()) 16 | checkErr(err) 17 | 18 | // create base directory 19 | basePath := path.Join("playerrepo") 20 | err = os.MkdirAll(basePath, os.ModePerm) 21 | checkErr(err) 22 | 23 | for _, file := range files { 24 | err = file.Render(basePath) 25 | checkErr(err) 26 | } 27 | } 28 | 29 | func checkErr(err error) { 30 | if err != nil { 31 | log.Fatal(err) 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /test/integration/playerrepo/sort.go: -------------------------------------------------------------------------------- 1 | // Code generated by nero, DO NOT EDIT. 
2 | package playerrepo 3 | 4 | import ( 5 | "github.com/stevenferrer/nero/sort" 6 | ) 7 | 8 | // Asc ascending sort direction 9 | func Asc(field Field) sort.SortFunc { 10 | return func(sorts []*sort.Sort) []*sort.Sort { 11 | return append(sorts, &sort.Sort{ 12 | Field: field.String(), 13 | Direction: sort.Asc, 14 | }) 15 | } 16 | } 17 | 18 | // Desc descending sort direction 19 | func Desc(field Field) sort.SortFunc { 20 | return func(sorts []*sort.Sort) []*sort.Sort { 21 | return append(sorts, &sort.Sort{ 22 | Field: field.String(), 23 | Direction: sort.Desc, 24 | }) 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /sort/direction_test.go: -------------------------------------------------------------------------------- 1 | package sort_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stevenferrer/nero/sort" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestDirectionStrings(t *testing.T) { 11 | tests := []struct { 12 | direction sort.Direction 13 | wantStr, 14 | wantDesc string 15 | }{ 16 | { 17 | direction: sort.Asc, 18 | wantStr: "Asc", 19 | wantDesc: "ascending", 20 | }, 21 | { 22 | direction: sort.Desc, 23 | wantStr: "Desc", 24 | wantDesc: "descending", 25 | }, 26 | } 27 | 28 | for _, tc := range tests { 29 | assert.Equal(t, tc.wantStr, tc.direction.String()) 30 | assert.Equal(t, tc.wantDesc, tc.direction.Desc()) 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /sql.go: -------------------------------------------------------------------------------- 1 | package nero 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "database/sql/driver" 7 | ) 8 | 9 | // SQLRunner is an interface that wraps the standard sql methods 10 | type SQLRunner interface { 11 | Query(string, ...interface{}) (*sql.Rows, error) 12 | QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) 13 | QueryRow(string, ...interface{}) *sql.Row 14 | 
QueryRowContext(context.Context, string, ...interface{}) *sql.Row 15 | Exec(string, ...interface{}) (sql.Result, error) 16 | ExecContext(context.Context, string, ...interface{}) (sql.Result, error) 17 | } 18 | 19 | // ValueScanner is an interface that wraps the driver.Valuer and sql.Scanner interface 20 | type ValueScanner interface { 21 | driver.Valuer 22 | sql.Scanner 23 | } 24 | -------------------------------------------------------------------------------- /test/gen/customtypes/customtypes_test.go: -------------------------------------------------------------------------------- 1 | package customtypes_test 2 | 3 | import ( 4 | "os" 5 | "path" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/require" 10 | 11 | "github.com/stevenferrer/nero/gen" 12 | "github.com/stevenferrer/nero/test/gen/customtypes" 13 | ) 14 | 15 | func TestCustomTypes(t *testing.T) { 16 | c := customtypes.Custom{} 17 | files, err := gen.Generate(c.Schema()) 18 | require.NoError(t, err) 19 | assert.Len(t, files, 7, "should have 7 generated files") 20 | 21 | // create base directory 22 | basePath := path.Join("customrepo") 23 | err = os.MkdirAll(basePath, os.ModePerm) 24 | require.NoError(t, err) 25 | 26 | for _, f := range files { 27 | err = f.Render(basePath) 28 | require.NoError(t, err) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /aggregate/operator.go: -------------------------------------------------------------------------------- 1 | package aggregate 2 | 3 | // Operator is an aggregate operator 4 | type Operator int 5 | 6 | const ( 7 | // Avg is average operator 8 | Avg Operator = iota 9 | // Count is the count operator 10 | Count 11 | // Max is the max operator 12 | Max 13 | // Min is the min operator 14 | Min 15 | // Sum is the sum operator 16 | Sum 17 | // None is used to include a field in the result 18 | None 19 | ) 20 | 21 | func (o Operator) String() string { 22 | return 
[...]string{ 23 | "Avg", 24 | "Count", 25 | "Max", 26 | "Min", 27 | "Sum", 28 | "None", 29 | }[o] 30 | } 31 | 32 | // Desc returns a human-readable description of the aggregate operator 33 | func (o Operator) Desc() string { 34 | return [...]string{ 35 | "average", 36 | "count", 37 | "max", 38 | "min", 39 | "sum", 40 | "none", 41 | }[o] 42 | } 43 | -------------------------------------------------------------------------------- /test/integration/playerrepo/repository_test.go: -------------------------------------------------------------------------------- 1 | package playerrepo 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | type errTx struct{} // errTx is a stub transaction whose Rollback always fails 11 | 12 | func (et *errTx) Commit() error { 13 | return nil 14 | } 15 | 16 | func (et *errTx) Rollback() error { 17 | return errors.New("tx error") 18 | } 19 | 20 | type okTx struct{} // okTx is a stub transaction whose Commit and Rollback always succeed 21 | 22 | func (ot *okTx) Commit() error { 23 | return nil 24 | } 25 | 26 | func (ot *okTx) Rollback() error { 27 | return nil 28 | } 29 | 30 | func Test_rollback(t *testing.T) { 31 | e := &errTx{} 32 | err := rollback(e, errors.New("an error")) 33 | assert.Equal(t, "rollback error: tx error: an error", err.Error()) 34 | 35 | o := &okTx{} 36 | err = rollback(o, errors.New("an error")) 37 | assert.Equal(t, "an error", err.Error()) 38 | } 39 | -------------------------------------------------------------------------------- /field_builder_test.go: -------------------------------------------------------------------------------- 1 | package nero_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | 8 | "github.com/stevenferrer/nero" 9 | ) 10 | 11 | func TestFieldBuilder(t *testing.T) { 12 | field := nero.NewFieldBuilder("id", int64(0)).Auto().
13 | StructField("ID").Optional().Build() 14 | 15 | assert.True(t, field.IsOptional()) 16 | assert.True(t, field.IsAuto()) 17 | 18 | assert.NotNil(t, field.TypeInfo()) 19 | assert.Equal(t, "id", field.Name()) 20 | assert.Equal(t, "ID", field.StructField()) 21 | assert.Equal(t, "id", field.Identifier()) 22 | assert.Equal(t, "ids", field.IdentifierPlural()) 23 | assert.Equal(t, true, field.IsComparable()) 24 | assert.Equal(t, false, field.IsArray()) 25 | assert.Equal(t, false, field.IsNillable()) 26 | assert.Equal(t, false, field.IsValueScanner()) 27 | } 28 | -------------------------------------------------------------------------------- /gen/generate_test.go: -------------------------------------------------------------------------------- 1 | package gen_test 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | 10 | "github.com/stevenferrer/nero/gen" 11 | "github.com/stevenferrer/nero/gen/internal" 12 | ) 13 | 14 | func TestGenerate(t *testing.T) { 15 | u := internal.User{} 16 | files, err := gen.Generate(u.Schema()) 17 | // fail fast: the checks below are meaningless without generated files 18 | require.NoError(t, err) 19 | assert.Len(t, files, 6) 20 | 21 | for _, file := range files { 22 | require.NotEmpty(t, file.Filename()) 23 | require.NotEmpty(t, file.Bytes()) 24 | } 25 | 26 | // create base directory 27 | basePath := "userrepo" 28 | err = os.MkdirAll(basePath, os.ModePerm) 29 | require.NoError(t, err) 30 | 31 | for _, f := range files { 32 | err = f.Render(basePath) 33 | require.NoError(t, err) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /x/etc/fmt_src_test.go: -------------------------------------------------------------------------------- 1 | package etc_test 2 | 3 | import ( 4 | "io/ioutil" 5 | "os" 6 | "path" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | 12 | "github.com/stevenferrer/nero/x/etc" 13 | ) 14 | 15
| const src = ` 16 | package main 17 | 18 | import ("fmt" 19 | "os" 20 | nero "github.com/stevenferrer/nero" 21 | ) 22 | func main() { 23 | fmt.Println("Hello, world!") 24 | } 25 | ` 26 | 27 | func TestFmtSrc(t *testing.T) { 28 | filename := "temp.go" 29 | filepath := path.Join(os.TempDir(), filename) 30 | err := ioutil.WriteFile(filepath, []byte(src), 0644) 31 | require.NoError(t, err) 32 | 33 | err = etc.FmtSrc(filepath) 34 | assert.NoError(t, err) 35 | 36 | // cleanup 37 | assert.NoError(t, os.Remove(filepath)) 38 | 39 | err = etc.FmtSrc(filepath) 40 | assert.Error(t, err) 41 | } 42 | -------------------------------------------------------------------------------- /gen/file.go: -------------------------------------------------------------------------------- 1 | package gen 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | 7 | "github.com/pkg/errors" 8 | 9 | "github.com/stevenferrer/nero/x/etc" 10 | ) 11 | 12 | // File is a generated file 13 | type File struct { 14 | name string 15 | buf []byte 16 | } 17 | 18 | // Render writes the file into basePath and then formats it with etc.FmtSrc. 19 | // The file is closed before formatting, and a failed close is reported 20 | // rather than silently dropped. 21 | func (f *File) Render(basePath string) error { 22 | filePath := filepath.Join(basePath, f.name) 23 | of, err := os.Create(filePath) 24 | if err != nil { 25 | return errors.Wrap(err, "create file") 26 | } 27 | 28 | if _, err = of.Write(f.buf); err != nil { 29 | _ = of.Close() // the write error is the one worth reporting 30 | return errors.Wrap(err, "write file") 31 | } 32 | 33 | if err = of.Close(); err != nil { 34 | return errors.Wrap(err, "close file") 35 | } 36 | 37 | // errors.Wrap returns nil when FmtSrc succeeds 38 | return errors.Wrap(etc.FmtSrc(filePath), "format source") 39 | } 40 | 41 | // Filename returns the filename 42 | func (f *File) Filename() string { 43 | return f.name 44 | } 45 | 46 | // Bytes returns a copy of the file contents, so callers cannot modify 47 | // the buffer that Render writes out 48 | func (f *File) Bytes() []byte { 49 | return append([]byte(nil), f.buf...) 50 | } 51 | -------------------------------------------------------------------------------- /revive.toml: -------------------------------------------------------------------------------- 1 | ignoreGeneratedHeader = false 2 | severity = "warning" 3 | confidence = 0.8 4 | errorCode = 0 5 | warningCode = 0 6 |
[rule.blank-imports] 8 | [rule.context-as-argument] 9 | [rule.context-keys-type] 10 | [rule.dot-imports] 11 | [rule.error-return] 12 | [rule.error-strings] 13 | [rule.error-naming] 14 | [rule.exported] 15 | [rule.if-return] 16 | [rule.increment-decrement] 17 | [rule.var-naming] 18 | [rule.var-declaration] 19 | [rule.package-comments] 20 | [rule.range] 21 | [rule.receiver-naming] 22 | [rule.time-naming] 23 | [rule.unexported-return] 24 | [rule.indent-error-flow] 25 | [rule.errorf] 26 | [rule.empty-block] 27 | [rule.superfluous-else] 28 | [rule.unused-parameter] 29 | [rule.unreachable-code] 30 | [rule.redefines-builtin-id] 31 | 32 | [rule.argument-limit] 33 | Arguments = [5] 34 | [rule.cyclomatic] 35 | Arguments = [15] 36 | [rule.cognitive-complexity] 37 | Arguments = [15] 38 | [rule.function-result-limit] 39 | Arguments = [3] -------------------------------------------------------------------------------- /gen/internal/user.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/stevenferrer/nero" 7 | ) 8 | 9 | // User is a user model 10 | type User struct { 11 | ID int64 12 | Name string 13 | Department string 14 | UpdatedAt time.Time 15 | CreatedAt time.Time 16 | } 17 | 18 | // Schema returns the schema for user model 19 | func (u User) Schema() *nero.Schema { 20 | return nero.NewSchemaBuilder(&u). 21 | PkgName("userrepo").Table("users"). 22 | Identity( 23 | nero.NewFieldBuilder("id", u.ID). 24 | StructField("ID").Auto().Build(), 25 | ). 26 | Fields( 27 | nero.NewFieldBuilder("name", u.Name). 28 | Build(), 29 | nero.NewFieldBuilder("department", u.Department). 30 | Build(), 31 | nero.NewFieldBuilder("updated_at", u.UpdatedAt). 32 | Optional().Build(), 33 | nero.NewFieldBuilder("created_at", u.CreatedAt). 34 | Auto().Build(), 35 | ). 36 | Templates(nero.NewPostgresTemplate()). 
37 | Build() 38 | } 39 | -------------------------------------------------------------------------------- /aggregate/operator_test.go: -------------------------------------------------------------------------------- 1 | package aggregate_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stevenferrer/nero/aggregate" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestOperatorStrings(t *testing.T) { 11 | // cover every Operator constant, including None, so the String and Desc lookup tables stay in sync 12 | tests := []struct { 13 | op aggregate.Operator 14 | wantStr, 15 | wantDesc string 16 | }{ 17 | { 18 | op: aggregate.Avg, 19 | wantStr: "Avg", 20 | wantDesc: "average", 21 | }, 22 | { 23 | op: aggregate.Count, 24 | wantStr: "Count", 25 | wantDesc: "count", 26 | }, 27 | { 28 | op: aggregate.Max, 29 | wantStr: "Max", 30 | wantDesc: "max", 31 | }, 32 | { 33 | op: aggregate.Min, 34 | wantStr: "Min", 35 | wantDesc: "min", 36 | }, 37 | { 38 | op: aggregate.Sum, 39 | wantStr: "Sum", 40 | wantDesc: "sum", 41 | }, 42 | { 43 | op: aggregate.None, 44 | wantStr: "None", 45 | wantDesc: "none", 46 | }, 47 | } 48 | 49 | for _, tc := range tests { 50 | assert.Equal(t, tc.wantStr, tc.op.String()) 51 | assert.Equal(t, tc.wantDesc, tc.op.Desc()) 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /field_builder.go: -------------------------------------------------------------------------------- 1 | package nero 2 | 3 | import "github.com/stevenferrer/mira" 4 | 5 | // FieldBuilder builds a Field through chained setter calls 6 | type FieldBuilder struct { 7 | f *Field 8 | } 9 | 10 | // NewFieldBuilder takes a field name and a value and returns a FieldBuilder 11 | func NewFieldBuilder(name string, v interface{}) *FieldBuilder { 12 | return &FieldBuilder{&Field{ 13 | name: name, 14 | typeInfo: mira.NewTypeInfo(v), 15 | }} 16 | } 17 | 18 | // Auto sets the auto-populated flag 19 | func (fb *FieldBuilder) Auto() *FieldBuilder { 20 | fb.f.auto = true 21 | return fb 22 | } 23 | 24 | // Optional sets the optional flag 25 | func (fb *FieldBuilder) Optional() *FieldBuilder { 26 | fb.f.optional = true 27 | return fb 28 | } 29 | 30 | //
StructField sets the name of the struct field backing this field 31 | func (fb *FieldBuilder) StructField(structField string) *FieldBuilder { 32 | fb.f.structField = structField 33 | return fb 34 | } 35 | 36 | // Build returns a new Field copied from the builder's current settings 37 | func (fb *FieldBuilder) Build() *Field { 38 | return &Field{ 39 | name: fb.f.name, 40 | typeInfo: fb.f.typeInfo, 41 | auto: fb.f.auto, 42 | optional: fb.f.optional, 43 | structField: fb.f.structField, 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | lint: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v2 14 | - name: Run linter 15 | uses: golangci/golangci-lint-action@v2 16 | with: 17 | version: v1.29 18 | 19 | test: 20 | runs-on: ubuntu-latest 21 | 22 | services: 23 | postgres: 24 | image: postgres:12 25 | env: 26 | POSTGRES_PASSWORD: postgres 27 | options: >- 28 | --health-cmd pg_isready 29 | --health-interval 10s 30 | --health-timeout 5s 31 | --health-retries 5 32 | ports: 33 | - 5432:5432 34 | 35 | steps: 36 | - name: Setup Go 37 | uses: actions/setup-go@v2 38 | with: 39 | go-version: 1.15 40 | - name: Checkout 41 | uses: actions/checkout@v2 42 | - name: Test 43 | run: go test -race -tags integration -coverprofile=profile.cov ./... 44 | - name: Send Coverage 45 | uses: shogo82148/actions-goveralls@v1 46 | with: 47 | path-to-profile: profile.cov -------------------------------------------------------------------------------- /gen/meta.go: -------------------------------------------------------------------------------- 1 | package gen 2 | 3 | import ( 4 | "bytes" 5 | "text/template" 6 | 7 | "github.com/stevenferrer/nero" 8 | ) 9 | 10 | func newMetaFile(schema *nero.Schema) (*File, error) { 11 | tmpl, err := template.New("meta.tmpl").
12 | Funcs(nero.NewFuncMap()).Parse(metaTmpl) 13 | if err != nil { 14 | return nil, err 15 | } 16 | 17 | buf := &bytes.Buffer{} 18 | err = tmpl.Execute(buf, schema) 19 | if err != nil { 20 | return nil, err 21 | } 22 | 23 | return &File{name: "meta.go", buf: buf.Bytes()}, nil 24 | } 25 | 26 | // TODO: wrap all template data into a struct 27 | 28 | const metaTmpl = ` 29 | {{- fileHeaders -}} 30 | 31 | package {{.PkgName}} 32 | 33 | // Table is the database table 34 | const Table = "{{ .Table }}" 35 | 36 | // Field is a {{.TypeInfo.Name}} field 37 | type Field int 38 | 39 | // String returns the string representation of the field 40 | func (f Field) String() string { 41 | return [...]string{ 42 | "invalid", 43 | "{{.Identity.Name}}", 44 | {{range .Fields -}} 45 | "{{.Name}}", 46 | {{end -}} 47 | }[f] 48 | } 49 | 50 | const ( 51 | Field{{.Identity.StructField}} Field = iota + 1 52 | {{range $e := .Fields -}} 53 | Field{{$e.StructField}} 54 | {{end -}} 55 | )` 56 | -------------------------------------------------------------------------------- /gen/sort.go: -------------------------------------------------------------------------------- 1 | package gen 2 | 3 | import ( 4 | "bytes" 5 | "text/template" 6 | 7 | "github.com/stevenferrer/nero" 8 | "github.com/stevenferrer/nero/sort" 9 | ) 10 | 11 | func newSortFile(schema *nero.Schema) (*File, error) { 12 | tmpl, err := template.New("sort.tmpl"). 
13 | Funcs(nero.NewFuncMap()).Parse(sortTmpl) 14 | if err != nil { 15 | return nil, err 16 | } 17 | 18 | data := struct { 19 | Directions []sort.Direction 20 | Schema *nero.Schema 21 | }{ 22 | Directions: []sort.Direction{ 23 | sort.Asc, sort.Desc, 24 | }, 25 | Schema: schema, 26 | } 27 | 28 | buf := &bytes.Buffer{} 29 | err = tmpl.Execute(buf, data) 30 | if err != nil { 31 | return nil, err 32 | } 33 | 34 | return &File{name: "sort.go", buf: buf.Bytes()}, nil 35 | } 36 | 37 | const sortTmpl = ` 38 | {{- fileHeaders -}} 39 | 40 | package {{.Schema.PkgName}} 41 | 42 | import ( 43 | "github.com/stevenferrer/nero/sort" 44 | ) 45 | 46 | {{range $direction := .Directions}} 47 | // {{$direction.String}} {{$direction.Desc}} sort direction 48 | func {{$direction.String}}(field Field) sort.SortFunc { 49 | return func(sorts []*sort.Sort) []*sort.Sort { 50 | return append(sorts, &sort.Sort{ 51 | Field: field.String(), 52 | Direction: sort.{{$direction.String}}, 53 | }) 54 | } 55 | } 56 | {{end}} 57 | ` 58 | -------------------------------------------------------------------------------- /comparison/operator.go: -------------------------------------------------------------------------------- 1 | package comparison 2 | 3 | // Operator is a comparison operator type 4 | type Operator int 5 | 6 | // List of comparison operators 7 | const ( 8 | // Eq is an equal operator 9 | Eq Operator = iota 10 | // NotEq is a not equal operator 11 | NotEq 12 | // Gt is a greater than operator 13 | Gt 14 | // GtOrEq is a greater than or equal operator 15 | GtOrEq 16 | // Lt is a less than operator 17 | Lt 18 | // LtOrEq is a less than or equal operator 19 | LtOrEq 20 | // IsNull is an "is null" operator 21 | IsNull 22 | // IsNotNull is an "is not null" operator 23 | IsNotNull 24 | // In is used to check if a value is in the list 25 | In 26 | // NotIn is used to check if a value is not in the list 27 | NotIn 28 | ) 29 | // String returns the name of the operator, e.g. "NotEq" 30 | func (o Operator) String() string { 31 | return [...]string{ 32 | "Eq", 33 |
"NotEq", 34 | "Gt", 35 | "GtOrEq", 36 | "Lt", 37 | "LtOrEq", 38 | "IsNull", 39 | "IsNotNull", 40 | "In", 41 | "NotIn", 42 | }[o] 43 | } 44 | 45 | // Desc returns the human-readable description of the operator, e.g. "not equal" 46 | func (o Operator) Desc() string { 47 | return [...]string{ 48 | "equal", 49 | "not equal", 50 | "greater than", 51 | "greater than or equal", 52 | "less than", 53 | "less than or equal", 54 | "is null", 55 | "is not null", 56 | "in", 57 | "not in", 58 | }[o] 59 | } 60 | -------------------------------------------------------------------------------- /schema_builder_test.go: -------------------------------------------------------------------------------- 1 | package nero_test 2 | 3 | import ( 4 | "math/big" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | 9 | "github.com/stevenferrer/nero" 10 | ) 11 | 12 | type MyStruct struct { 13 | ID *big.Int 14 | Name string 15 | } 16 | 17 | func TestSchemaBuilder(t *testing.T) { 18 | pkg := "mypkg" 19 | table := "mytable" 20 | ms := &MyStruct{} 21 | schemaBuilder := nero.NewSchemaBuilder(ms). 22 | PkgName(pkg).Table(table). 23 | Identity( 24 | nero.NewFieldBuilder("id", ms.ID). 25 | Auto().StructField("ID").Build(), 26 | ).
27 | Fields(nero.NewFieldBuilder("name", ms.Name).Build()) 28 | 29 | schema := schemaBuilder.Build() 30 | assert.Equal(t, pkg, schema.PkgName()) 31 | assert.Equal(t, table, schema.Table()) 32 | assert.NotNil(t, schema.Identity()) 33 | assert.Len(t, schema.Fields(), 1) 34 | assert.Len(t, schema.Imports(), 2) 35 | assert.Len(t, schema.Templates(), 2) 36 | assert.NotNil(t, schema.TypeInfo()) 37 | assert.Equal(t, "MyStruct", schema.TypeName()) 38 | assert.Equal(t, "MyStructs", schema.TypeNamePlural()) 39 | assert.Equal(t, "myStruct", schema.TypeIdentifier()) 40 | assert.Equal(t, "myStructs", schema.TypeIdentifierPlural()) 41 | 42 | tmpl := nero.NewPostgresTemplate() 43 | schema = schemaBuilder.Templates(tmpl).Build() 44 | assert.Len(t, schema.Templates(), 1) 45 | } 46 | -------------------------------------------------------------------------------- /gen/generate.go: -------------------------------------------------------------------------------- 1 | package gen 2 | 3 | import ( 4 | "github.com/pkg/errors" 5 | "github.com/stevenferrer/nero" 6 | ) 7 | 8 | // Generate renders the meta, predicate, sort, aggregate and repository files for the schema, followed by one file per schema template, and returns them in that order 9 | func Generate(schema *nero.Schema) ([]*File, error) { 10 | files := []*File{} 11 | file, err := newMetaFile(schema) 12 | if err != nil { 13 | return nil, errors.Wrap(err, "meta file") 14 | } 15 | files = append(files, file) 16 | 17 | file, err = newPredicateFile(schema) 18 | if err != nil { 19 | return nil, errors.Wrap(err, "predicate file") 20 | } 21 | files = append(files, file) 22 | 23 | file, err = newSortFile(schema) 24 | if err != nil { 25 | return nil, errors.Wrap(err, "sort file") 26 | } 27 | files = append(files, file) 28 | 29 | file, err = newAggregateFile(schema) 30 | if err != nil { 31 | return nil, errors.Wrap(err, "aggregate file") 32 | } 33 | files = append(files, file) 34 | 35 | file, err = newRepositoryFile(schema) 36 | if err != nil { 37 | return nil, errors.Wrap(err, "repository file") 38 | } 39 | files = append(files, file) 40 | 41 | for _, tmpl := range
schema.Templates() { 42 | buf, err := newTemplate(schema, tmpl) 43 | if err != nil { 44 | return nil, errors.Wrap(err, "template file") 45 | } 46 | 47 | files = append(files, &File{name: tmpl.Filename(), buf: buf.Bytes()}) 48 | } 49 | 50 | return files, nil 51 | } 52 | -------------------------------------------------------------------------------- /test/integration/player/player.go: -------------------------------------------------------------------------------- 1 | package player 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/stevenferrer/nero" 7 | ) 8 | 9 | // Player is a player 10 | type Player struct { 11 | ID string 12 | Email string 13 | Name string 14 | Age int 15 | Race Race 16 | UpdatedAt *time.Time 17 | CreatedAt *time.Time 18 | } 19 | 20 | // Race is the player race 21 | type Race string 22 | 23 | // List of player races 24 | const ( 25 | RaceHuman Race = "human" 26 | RaceCharr Race = "charr" 27 | RaceNorn Race = "norn" 28 | RaceSylvari Race = "sylvari" 29 | RaceTitan Race = "titan" 30 | ) 31 | 32 | // Schema implements nero.Schemaer 33 | func (p Player) Schema() *nero.Schema { 34 | return nero.NewSchemaBuilder(&p). 35 | PkgName("playerrepo"). 36 | Table("players"). 37 | Identity(nero.NewFieldBuilder("id", p.ID). 38 | StructField("ID").Auto().Build()). 39 | Fields( 40 | nero.NewFieldBuilder("email", p.Email).Build(), 41 | nero.NewFieldBuilder("name", p.Name).Build(), 42 | nero.NewFieldBuilder("age", p.Age).Build(), 43 | nero.NewFieldBuilder("race", p.Race).Build(), 44 | nero.NewFieldBuilder("updated_at", p.UpdatedAt). 45 | Optional().Build(), 46 | nero.NewFieldBuilder("created_at", p.CreatedAt). 47 | Auto().Build(), 48 | ). 49 | Templates( 50 | nero.NewPostgresTemplate(), 51 | nero.NewSQLiteTemplate(), 52 | ).
53 | Build() 54 | } 55 | -------------------------------------------------------------------------------- /gen/aggregate.go: -------------------------------------------------------------------------------- 1 | package gen 2 | 3 | import ( 4 | "bytes" 5 | "text/template" 6 | 7 | "github.com/stevenferrer/nero" 8 | "github.com/stevenferrer/nero/aggregate" 9 | ) 10 | 11 | func newAggregateFile(schema *nero.Schema) (*File, error) { 12 | tmpl, err := template.New("aggregates.tmpl"). 13 | Funcs(nero.NewFuncMap()).Parse(aggregatesTmpl) 14 | if err != nil { 15 | return nil, err 16 | } 17 | 18 | data := struct { 19 | Operators []aggregate.Operator 20 | Schema *nero.Schema 21 | }{ 22 | Operators: []aggregate.Operator{ 23 | aggregate.Avg, aggregate.Count, 24 | aggregate.Max, aggregate.Min, 25 | aggregate.Sum, aggregate.None, 26 | }, 27 | Schema: schema, 28 | } 29 | 30 | buf := &bytes.Buffer{} 31 | err = tmpl.Execute(buf, data) 32 | if err != nil { 33 | return nil, err 34 | } 35 | 36 | return &File{name: "aggregate.go", buf: buf.Bytes()}, nil 37 | } 38 | 39 | const aggregatesTmpl = ` 40 | {{- fileHeaders -}} 41 | 42 | package {{.Schema.PkgName}} 43 | 44 | import ( 45 | "github.com/stevenferrer/nero/aggregate" 46 | ) 47 | 48 | {{range $op := .Operators}} 49 | // {{$op.String}} is the {{$op.Desc}} aggregate operator 50 | func {{$op.String}}(field Field) aggregate.AggFunc { 51 | return func(aggs []*aggregate.Aggregate) []*aggregate.Aggregate { 52 | return append(aggs, &aggregate.Aggregate{ 53 | Field: field.String(), 54 | Op: aggregate.{{$op.String}}, 55 | }) 56 | } 57 | } 58 | {{end}} 59 | ` 60 | -------------------------------------------------------------------------------- /x/strings/camel_case.go: -------------------------------------------------------------------------------- 1 | package strings 2 | 3 | import ( 4 | "strings" 5 | "unicode" 6 | "unicode/utf8" 7 | ) 8 | 9 | // ToCamel converts a string to upper camel case, e.g. "http_server" becomes "HttpServer" 10 | func ToCamel(s string) string { 11 | camel := toCamel(s)
12 | r, w := utf8.DecodeRuneInString(camel) 13 | if w == 0 || unicode.IsUpper(r) { 14 | return camel // empty input, or first letter already upper case 15 | } 16 | return string(unicode.ToUpper(r)) + camel[w:] 17 | } 18 | 19 | // ToLowerCamel converts a string to camel case 20 | // where the first letter is always lowercase 21 | func ToLowerCamel(s string) string { 22 | camel := toCamel(s) 23 | r, w := utf8.DecodeRuneInString(camel) 24 | if w == 0 || unicode.IsLower(r) { 25 | return camel // empty input, or first letter already lower case 26 | } 27 | 28 | return string(unicode.ToLower(r)) + camel[w:] 29 | } 30 | // toCamel keeps only the letters and numbers of s, upper-cases every letter that follows a non-letter and lower-cases the result when it holds a single word 31 | func toCamel(s string) string { 32 | var ( 33 | sb = &strings.Builder{} 34 | last rune 35 | wc int 36 | ) 37 | 38 | for len(s) > 0 { 39 | r, w := utf8.DecodeRuneInString(s) 40 | if unicode.IsLetter(r) || unicode.IsNumber(r) { 41 | // last is lower and current is upper 42 | // e.g. Htt[pS]erver 43 | if unicode.IsLower(last) && unicode.IsUpper(r) { 44 | wc++ 45 | } 46 | 47 | if !unicode.IsLetter(last) && 48 | !unicode.IsNumber(r) { 49 | r = unicode.ToUpper(r) 50 | wc++ 51 | } 52 | 53 | sb.WriteRune(r) 54 | } 55 | 56 | s, last = s[w:], r 57 | } 58 | 59 | // convert to lowercase when only one word 60 | if wc == 1 { 61 | return strings.ToLower(sb.String()) 62 | } 63 | 64 | return sb.String() 65 | } 66 | -------------------------------------------------------------------------------- /comparison/operator_test.go: -------------------------------------------------------------------------------- 1 | package comparison_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stevenferrer/nero/comparison" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestOperatorStrings(t *testing.T) { 11 | tests := []struct { 12 | op comparison.Operator 13 | wantStr, 14 | wantDesc string 15 | }{ 16 | { 17 | op: comparison.Eq, 18 | wantStr: "Eq", 19 | wantDesc: "equal", 20 | }, 21 | { 22 | op: comparison.NotEq, 23 | wantStr: "NotEq", 24 | wantDesc: "not equal", 25 | }, 26 | { 27 | op: comparison.Gt, 28 | wantStr: "Gt", 29 | wantDesc: "greater than", 30 | }, 31 | { 32 | op: comparison.GtOrEq, 33 |
wantStr: "GtOrEq", 34 | wantDesc: "greater than or equal", 35 | }, 36 | { 37 | op: comparison.Lt, 38 | wantStr: "Lt", 39 | wantDesc: "less than", 40 | }, 41 | { 42 | op: comparison.LtOrEq, 43 | wantStr: "LtOrEq", 44 | wantDesc: "less than or equal", 45 | }, 46 | { 47 | op: comparison.IsNull, 48 | wantStr: "IsNull", 49 | wantDesc: "is null", 50 | }, 51 | { 52 | op: comparison.IsNotNull, 53 | wantStr: "IsNotNull", 54 | wantDesc: "is not null", 55 | }, 56 | { 57 | op: comparison.In, 58 | wantStr: "In", 59 | wantDesc: "in", 60 | }, 61 | { 62 | op: comparison.NotIn, 63 | wantStr: "NotIn", 64 | wantDesc: "not in", 65 | }, 66 | } 67 | 68 | for _, tc := range tests { 69 | assert.Equal(t, tc.wantStr, tc.op.String()) 70 | assert.Equal(t, tc.wantDesc, tc.op.Desc()) 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /test/integration/playerrepo/postgres_test.go: -------------------------------------------------------------------------------- 1 | //go:build integration 2 | // +build integration 3 | 4 | package playerrepo_test 5 | 6 | import ( 7 | "bytes" 8 | "database/sql" 9 | "log" 10 | "testing" 11 | 12 | _ "github.com/lib/pq" 13 | "github.com/stevenferrer/nero/test/integration/playerrepo" 14 | "github.com/stretchr/testify/require" 15 | ) 16 | 17 | func TestPostgresRepository(t *testing.T) { 18 | t.Parallel() 19 | const dsn = "postgres://postgres:postgres@localhost:5432?sslmode=disable" 20 | 21 | // regular methods 22 | db, err := sql.Open("postgres", dsn) 23 | require.NoError(t, err) 24 | defer db.Close() // registered before Ping so the pool is closed even when Ping fails 25 | require.NoError(t, db.Ping()) 26 | 27 | // create table 28 | require.NoError(t, createPgTable(db)) 29 | 30 | // initialize a new repo 31 | repo := playerrepo.NewPostgresRepository(db).Debug(). 32 | WithLogger(log.New(&bytes.Buffer{}, "", 0)) 33 | newRepoTestRunner(repo)(t) 34 | require.NoError(t, dropTable(db)) 35 | 36 | // tx methods 37 | require.NoError(t, createPgTable(db)) 38 | repo = playerrepo.NewPostgresRepository(db).Debug().
39 | WithLogger(log.New(&bytes.Buffer{}, "", 0)) 40 | newRepoTestRunnerTx(repo)(t) 41 | require.NoError(t, dropTable(db)) 42 | } 43 | 44 | func createPgTable(db *sql.DB) error { 45 | _, err := db.Exec(`CREATE TABLE players ( 46 | id bigint GENERATED always AS IDENTITY PRIMARY KEY, 47 | email VARCHAR(255) UNIQUE NOT NULL, 48 | "name" VARCHAR(50) NOT NULL, 49 | age INTEGER NOT NULL, 50 | "race" VARCHAR(20) NOT NULL, 51 | updated_at TIMESTAMP, 52 | created_at TIMESTAMP DEFAULT now() 53 | )`) 54 | return err 55 | } 56 | 57 | func dropTable(db *sql.DB) error { 58 | _, err := db.Exec(`drop table players`) 59 | return err 60 | } 61 | -------------------------------------------------------------------------------- /test/integration/playerrepo/sqlite_test.go: -------------------------------------------------------------------------------- 1 | //go:build integration 2 | // +build integration 3 | 4 | package playerrepo_test 5 | 6 | import ( 7 | "bytes" 8 | "database/sql" 9 | "log" 10 | "testing" 11 | 12 | _ "github.com/mattn/go-sqlite3" 13 | "github.com/stevenferrer/nero/test/integration/playerrepo" 14 | "github.com/stretchr/testify/assert" 15 | "github.com/stretchr/testify/require" 16 | ) 17 | 18 | func TestSQLiteRepository(t *testing.T) { 19 | t.Parallel() 20 | 21 | const dsn = "file:test.db?mode=memory&cache=shared" 22 | db, err := sql.Open("sqlite3", dsn) 23 | require.NoError(t, err) 24 | defer db.Close() // registered before Ping so the handle is closed even when Ping fails 25 | require.NoError(t, db.Ping()) 26 | 27 | // create table 28 | err = createSqliteTable(db) 29 | assert.NoError(t, err) 30 | 31 | // initialize a new repo 32 | repo := playerrepo.NewSQLiteRepository(db).Debug(). 33 | WithLogger(log.New(&bytes.Buffer{}, "", 0)) 34 | newRepoTestRunner(repo)(t) 35 | // cleanup 36 | require.NoError(t, dropTable(db)) 37 | 38 | // Tx methods 39 | // re-create table 40 | err = createSqliteTable(db) 41 | assert.NoError(t, err) 42 | 43 | // initialize a new repo 44 | repo = playerrepo.NewSQLiteRepository(db).Debug().
45 | WithLogger(log.New(&bytes.Buffer{}, "", 0)) 46 | newRepoTestRunnerTx(repo)(t) 47 | require.NoError(t, dropTable(db)) 48 | } 49 | 50 | func createSqliteTable(db *sql.DB) error { 51 | _, err := db.Exec(` 52 | CREATE TABLE players ( 53 | id INTEGER PRIMARY KEY, 54 | email TEXT NOT NULL UNIQUE, 55 | "name" TEXT NOT NULL, 56 | age INTEGER NOT NULL, 57 | race TEXT NOT NULL, 58 | updated_at DATETIME NULL, 59 | created_at DATETIME DEFAULT CURRENT_TIMESTAMP 60 | )`) 61 | 62 | return err 63 | } 64 | -------------------------------------------------------------------------------- /template_test.go: -------------------------------------------------------------------------------- 1 | package nero 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | type myType struct{} // fixture for checking package-qualified type names such as "nero.myType" 11 | 12 | func TestFuncs(t *testing.T) { 13 | t.Run("typeFunc", func(t *testing.T) { 14 | got := typeFunc(1) 15 | expect := "int" 16 | assert.Equal(t, expect, got) 17 | 18 | got = typeFunc(&myType{}) 19 | expect = "nero.myType" 20 | assert.Equal(t, expect, got) 21 | }) 22 | 23 | t.Run("rawTypeFunc", func(t *testing.T) { 24 | got := rawTypeFunc(1) 25 | expect := "int" 26 | assert.Equal(t, expect, got) 27 | 28 | got = rawTypeFunc(&myType{}) 29 | expect = "*nero.myType" 30 | assert.Equal(t, expect, got) 31 | }) 32 | 33 | t.Run("zeroFunc", func(t *testing.T) { 34 | got := zeroValueFunc(0) 35 | expect := "0" 36 | assert.Equal(t, expect, got) 37 | 38 | got = zeroValueFunc([]string{}) 39 | expect = "nil" 40 | assert.Equal(t, expect, got) 41 | 42 | got = zeroValueFunc(true) 43 | expect = "false" 44 | assert.Equal(t, expect, got) 45 | 46 | got = zeroValueFunc(myType{}) 47 | expect = "(nero.myType{})" 48 | assert.Equal(t, expect, got) 49 | 50 | got = zeroValueFunc([1]int{1}) 51 | expect = "([1]int{})" 52 | assert.Equal(t, expect, got) 53 | 54 | got = zeroValueFunc([1][1]*myType{}) 55 | expect = "([1][1]*nero.myType{})" 56 | assert.Equal(t, expect, got) 57 | 58 | got =
zeroValueFunc([0]interface{}{}) 59 | expect = "([0]interface {}{})" 60 | assert.Equal(t, expect, got) 61 | 62 | got = zeroValueFunc("") 63 | expect = `""` 64 | assert.Equal(t, expect, got) 65 | }) 66 | 67 | t.Run("isType", func(t *testing.T) { 68 | now := time.Now() 69 | assert.True(t, isTypeFunc(now, "time.Time")) 70 | assert.True(t, isTypeFunc(&now, "time.Time")) 71 | }) 72 | assert.Len(t, prependToFields(&Field{}, []*Field{}), 1) 73 | assert.NotEmpty(t, fileHeadersFunc()) 74 | } 75 | -------------------------------------------------------------------------------- /test/integration/playerrepo/aggregate.go: -------------------------------------------------------------------------------- 1 | // Code generated by nero, DO NOT EDIT. 2 | package playerrepo 3 | 4 | import ( 5 | "github.com/stevenferrer/nero/aggregate" 6 | ) 7 | 8 | // Avg is the average aggregate operator 9 | func Avg(field Field) aggregate.AggFunc { 10 | return func(aggs []*aggregate.Aggregate) []*aggregate.Aggregate { 11 | return append(aggs, &aggregate.Aggregate{ 12 | Field: field.String(), 13 | Op: aggregate.Avg, 14 | }) 15 | } 16 | } 17 | 18 | // Count is the count aggregate operator 19 | func Count(field Field) aggregate.AggFunc { 20 | return func(aggs []*aggregate.Aggregate) []*aggregate.Aggregate { 21 | return append(aggs, &aggregate.Aggregate{ 22 | Field: field.String(), 23 | Op: aggregate.Count, 24 | }) 25 | } 26 | } 27 | 28 | // Max is the max aggregate operator 29 | func Max(field Field) aggregate.AggFunc { 30 | return func(aggs []*aggregate.Aggregate) []*aggregate.Aggregate { 31 | return append(aggs, &aggregate.Aggregate{ 32 | Field: field.String(), 33 | Op: aggregate.Max, 34 | }) 35 | } 36 | } 37 | 38 | // Min is the min aggregate operator 39 | func Min(field Field) aggregate.AggFunc { 40 | return func(aggs []*aggregate.Aggregate) []*aggregate.Aggregate { 41 | return append(aggs, &aggregate.Aggregate{ 42 | Field: field.String(), 43 | Op: aggregate.Min, 44 | }) 45 | } 46 | } 47 | 48 | //
Sum is the sum aggregate operator 49 | func Sum(field Field) aggregate.AggFunc { 50 | return func(aggs []*aggregate.Aggregate) []*aggregate.Aggregate { 51 | return append(aggs, &aggregate.Aggregate{ 52 | Field: field.String(), 53 | Op: aggregate.Sum, 54 | }) 55 | } 56 | } 57 | 58 | // None is the none aggregate operator 59 | func None(field Field) aggregate.AggFunc { 60 | return func(aggs []*aggregate.Aggregate) []*aggregate.Aggregate { 61 | return append(aggs, &aggregate.Aggregate{ 62 | Field: field.String(), 63 | Op: aggregate.None, 64 | }) 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /test/gen/customtypes/customtypes.go: -------------------------------------------------------------------------------- 1 | package customtypes 2 | 3 | import ( 4 | "github.com/stevenferrer/nero" 5 | ) 6 | 7 | // UUID is a uuid type 8 | type UUID [16]byte 9 | 10 | // Custom demonstrates the use of many different field types 11 | type Custom struct { 12 | ID int64 13 | UUID UUID 14 | Str string 15 | MapStrStr map[string]string 16 | MapStrPtrStr map[string]*string 17 | MapInt64Str map[int64]string 18 | MapInt64PtrStr map[int64]*string 19 | MapStrItem map[string]Item 20 | MapStrPtrItem map[string]*Item 21 | Item Item 22 | PtrItem *Item 23 | Items []Item 24 | PtrItems []*Item 25 | NullColumn *string 26 | } 27 | 28 | // Item is an example struct used as the type of several Custom fields 29 | // 30 | // Note: Custom types like these must implement ValueScanner 31 | type Item struct { 32 | Name string 33 | } 34 | 35 | // Schema implements nero.Schemaer 36 | func (c Custom) Schema() *nero.Schema { 37 | return nero.NewSchemaBuilder(&c). 38 | PkgName("user").Table("users"). 39 | Identity( 40 | nero.NewFieldBuilder("id", c.ID). 41 | StructField("ID").Auto().Build(), 42 | ).
43 | Fields( 44 | nero.NewFieldBuilder("uuid", c.UUID).StructField("UUID").Build(), 45 | nero.NewFieldBuilder("str", c.Str).Build(), 46 | nero.NewFieldBuilder("map_str_str", c.MapStrStr).Build(), 47 | nero.NewFieldBuilder("map_str_ptr_str", c.MapStrPtrStr).Build(), 48 | nero.NewFieldBuilder("map_int64_str", c.MapInt64Str).Build(), 49 | nero.NewFieldBuilder("map_int64_ptr_str", c.MapInt64PtrStr).Build(), 50 | nero.NewFieldBuilder("map_str_item", c.MapStrItem).Build(), 51 | nero.NewFieldBuilder("map_str_ptr_item", c.MapStrPtrItem).Build(), 52 | nero.NewFieldBuilder("item", c.Item).Build(), 53 | nero.NewFieldBuilder("ptr_item", c.PtrItem).Build(), 54 | nero.NewFieldBuilder("items", c.Items).Build(), 55 | nero.NewFieldBuilder("ptr_items", c.PtrItems).Build(), 56 | nero.NewFieldBuilder("null_column", c.NullColumn).Build(), 57 | ).Build() 58 | } 59 | -------------------------------------------------------------------------------- /x/etc/fmt_src.go: -------------------------------------------------------------------------------- 1 | package etc 2 | 3 | import ( 4 | "bytes" 5 | "go/ast" 6 | "go/format" 7 | "go/parser" 8 | "go/scanner" 9 | "go/token" 10 | "io/ioutil" 11 | "os" 12 | "strings" 13 | 14 | "github.com/pkg/errors" 15 | "golang.org/x/tools/go/ast/astutil" 16 | "golang.org/x/tools/imports" 17 | ) 18 | 19 | // FmtSrc removes unused imports from the Go source file at path, formats it with gofmt and goimports (format-only) and rewrites the file in place. 20 | // 21 | // See https://github.com/goadesign/goa/blob/v3/codegen/file.go#L136 22 | func FmtSrc(path string) error { 23 | // Make sure file parses and print content if it does not.
24 | fileSet := token.NewFileSet() 25 | astFile, err := parser.ParseFile(fileSet, path, nil, parser.ParseComments) 26 | if err != nil { 27 | content, _ := ioutil.ReadFile(path) // best effort: the content only enriches the parse error 28 | buf := &bytes.Buffer{} 29 | scanner.PrintError(buf, err) 30 | return errors.Errorf("%s\n========\nContent:\n%s", buf.String(), content) 31 | } 32 | 33 | // Clean unused imports 34 | impss := astutil.Imports(fileSet, astFile) 35 | for _, imps := range impss { 36 | for _, imp := range imps { 37 | impPath := strings.Trim(imp.Path.Value, `"`) 38 | if !astutil.UsesImport(astFile, impPath) { 39 | if imp.Name != nil { 40 | astutil.DeleteNamedImport(fileSet, astFile, imp.Name.Name, impPath) 41 | } else { 42 | astutil.DeleteImport(fileSet, astFile, impPath) 43 | } 44 | } 45 | } 46 | } 47 | ast.SortImports(fileSet, astFile) 48 | w, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm) 49 | if err != nil { 50 | return err 51 | } 52 | if err := format.Node(w, fileSet, astFile); err != nil { 53 | _ = w.Close() // best effort: the formatting error is the one worth reporting 54 | return err 55 | } 56 | if err := w.Close(); err != nil { 57 | return err 58 | } 59 | 60 | // Format code using the goimports standard 61 | b, err := ioutil.ReadFile(path) 62 | if err != nil { 63 | return err 64 | } 65 | b, err = imports.Process(path, b, &imports.Options{ 66 | Comments: true, 67 | FormatOnly: true, 68 | }) 69 | if err != nil { 70 | return err 71 | } 72 | 73 | return ioutil.WriteFile(path, b, os.ModePerm) 74 | } 75 | -------------------------------------------------------------------------------- /schema.go: -------------------------------------------------------------------------------- 1 | package nero 2 | 3 | import ( 4 | "github.com/jinzhu/inflection" 5 | "github.com/stevenferrer/mira" 6 | stringsx "github.com/stevenferrer/nero/x/strings" 7 | ) 8 | 9 | // Schema is a schema used for generating the repository 10 | type Schema struct { 11 | // pkgName is the package name of the generated files 12 | pkgName string 13 | // table is the database table name 14 | table string 15 | //
typeInfo is the type info of the schema model 16 | typeInfo *mira.TypeInfo 17 | // identity is the identity field 18 | identity *Field 19 | // fields is the list of fields 20 | fields []*Field 21 | // imports is the list of package imports 22 | imports []string 23 | // templates is the list of custom repository templates 24 | templates []Template 25 | } 26 | 27 | // PkgName returns the pkg name 28 | func (s *Schema) PkgName() string { 29 | return s.pkgName 30 | } 31 | 32 | // Table returns the database table name 33 | func (s *Schema) Table() string { 34 | return s.table 35 | } 36 | 37 | // Identity returns the identity field 38 | func (s *Schema) Identity() *Field { 39 | return s.identity 40 | } 41 | 42 | // Fields returns a copy of the fields; appending to or reassigning elements of the result does not affect the schema 43 | func (s *Schema) Fields() []*Field { 44 | return append([]*Field{}, s.fields...) 45 | } 46 | 47 | // Imports returns a copy of the pkg imports 48 | func (s *Schema) Imports() []string { 49 | return append([]string{}, s.imports...) 50 | } 51 | 52 | // Templates returns a copy of the templates 53 | func (s *Schema) Templates() []Template { 54 | return append([]Template{}, s.templates...) 55 | } 56 | 57 | // TypeInfo returns the type info 58 | func (s *Schema) TypeInfo() *mira.TypeInfo { 59 | return s.typeInfo 60 | } 61 | 62 | // TypeName returns the type name 63 | func (s *Schema) TypeName() string { 64 | return s.typeInfo.Name() 65 | } 66 | 67 | // TypeNamePlural returns the plural form of the type name 68 | func (s *Schema) TypeNamePlural() string { 69 | return inflection.Plural(s.TypeName()) 70 | } 71 | 72 | // TypeIdentifier returns the type identifier 73 | func (s *Schema) TypeIdentifier() string { 74 | return stringsx.ToLowerCamel(s.TypeName()) 75 | } 76 | 77 | // TypeIdentifierPlural returns the plural form of type identifier 78 | func (s *Schema) TypeIdentifierPlural() string { 79 | return inflection.Plural(s.TypeIdentifier()) 80 | } 81 | -------------------------------------------------------------------------------- /schema_builder.go:
-------------------------------------------------------------------------------- 1 | package nero 2 | 3 | import ( 4 | "github.com/stevenferrer/mira" 5 | ) 6 | 7 | // SchemaBuilder is used for building a schema 8 | type SchemaBuilder struct { 9 | sc *Schema 10 | } 11 | 12 | // NewSchemaBuilder takes a struct value (every caller in this repo passes a pointer to it) and returns a SchemaBuilder 13 | func NewSchemaBuilder(v interface{}) *SchemaBuilder { 14 | return &SchemaBuilder{sc: &Schema{ 15 | typeInfo: mira.NewTypeInfo(v), 16 | fields: []*Field{}, 17 | templates: []Template{}, 18 | }} 19 | } 20 | 21 | // PkgName sets the package name 22 | func (sb *SchemaBuilder) PkgName(pkgName string) *SchemaBuilder { 23 | sb.sc.pkgName = pkgName 24 | return sb 25 | } 26 | 27 | // Table sets the database table/collection name 28 | func (sb *SchemaBuilder) Table(table string) *SchemaBuilder { 29 | sb.sc.table = table 30 | return sb 31 | } 32 | 33 | // Identity sets the identity field 34 | func (sb *SchemaBuilder) Identity(field *Field) *SchemaBuilder { 35 | sb.sc.identity = field 36 | return sb 37 | } 38 | 39 | // Fields appends the given fields to the schema fields 40 | func (sb *SchemaBuilder) Fields(fields ...*Field) *SchemaBuilder { 41 | sb.sc.fields = append(sb.sc.fields, fields...) 42 | return sb 43 | } 44 | 45 | // Templates appends the given templates; Build uses the postgres and sqlite templates when none are set 46 | func (sb *SchemaBuilder) Templates(templates ...Template) *SchemaBuilder { 47 | sb.sc.templates = append(sb.sc.templates, templates...)
48 | return sb 49 | } 50 | 51 | // Build builds the schema; when no templates were set it falls back to the postgres and sqlite templates 52 | func (sb *SchemaBuilder) Build() *Schema { 53 | templates := sb.sc.templates 54 | 55 | // use default template set 56 | if len(templates) == 0 { 57 | templates = []Template{ 58 | NewPostgresTemplate(), 59 | NewSQLiteTemplate(), 60 | } 61 | } 62 | 63 | // collect the unique pkg imports: the schema type's pkg first, then the 64 | // field pkgs in field order, so the list is duplicate-free and deterministic 65 | typePkg := sb.sc.typeInfo.PkgPath() 66 | imports := []string{typePkg} 67 | seen := map[string]struct{}{typePkg: {}} 68 | 69 | // iterate over a fresh slice so the builder's backing array is never written to 70 | fields := make([]*Field, 0, len(sb.sc.fields)+1) 71 | fields = append(fields, sb.sc.fields...) 72 | fields = append(fields, sb.sc.identity) 73 | for _, fld := range fields { 74 | // identity is nil when Identity was never called 75 | if fld == nil { 76 | continue 77 | } 78 | 79 | pkgPath := fld.typeInfo.PkgPath() 80 | if pkgPath == "" { 81 | continue 82 | } 83 | 84 | if _, ok := seen[pkgPath]; !ok { 85 | seen[pkgPath] = struct{}{} 86 | imports = append(imports, pkgPath) 87 | } 88 | } 89 | 90 | return &Schema{ 91 | typeInfo: sb.sc.typeInfo, 92 | pkgName: sb.sc.pkgName, 93 | table: sb.sc.table, 94 | identity: sb.sc.identity, 95 | fields: sb.sc.fields, 96 | imports: imports, 97 | templates: templates, 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /field.go: -------------------------------------------------------------------------------- 1 | package nero 2 | 3 | import ( 4 | "reflect" 5 | 6 | "github.com/jinzhu/inflection" 7 | "github.com/stevenferrer/mira" 8 | stringsx "github.com/stevenferrer/nero/x/strings" 9 | ) 10 | 11 | // Field describes a single schema field 12 | type Field struct { 13 | // name is the field name 14 | name string 15 | // typeInfo is the field type info 16 | typeInfo *mira.TypeInfo 17 | // structField overrides the struct field name derived from name 18 | structField string 19 | // auto is the auto-filled flag 20 | auto, 21 | // optional is the optional flag 22 | optional bool 23 | } 24 | 25 | // TypeInfo returns the type info 26 | func (f *Field) TypeInfo() *mira.TypeInfo { 27 | return f.typeInfo 28 | } 29 | 30 | // Name returns the field name 31 | func (f *Field) Name() string { 32 | return f.name 33 | } 34 | 35 | // StructField returns the struct field name: the override if set, otherwise the camel-cased field name 36 | func (f *Field) StructField() string { 37 | structField := stringsx.ToCamel(f.name) 38 | if
len(f.structField) > 0 { 39 | structField = f.structField 40 | } 41 | 42 | return structField 43 | } 44 | 45 | // Identifier returns the lower-camelized struct field 46 | func (f *Field) Identifier() string { 47 | return stringsx.ToLowerCamel(f.StructField()) 48 | } 49 | 50 | // IdentifierPlural returns the plural form of identifier 51 | func (f *Field) IdentifierPlural() string { 52 | return inflection.Plural(f.Identifier()) 53 | } 54 | 55 | // IsArray returns true if field is an array or a slice 56 | func (f *Field) IsArray() bool { 57 | kind := f.typeInfo.T().Kind() 58 | return kind == reflect.Array || 59 | kind == reflect.Slice 60 | } 61 | 62 | // IsNillable returns true if the field is nillable 63 | func (f *Field) IsNillable() bool { 64 | return f.typeInfo.IsNillable() 65 | } 66 | 67 | // IsValueScanner returns true if the field type, or a pointer to it, implements ValueScanner 68 | func (f *Field) IsValueScanner() bool { 69 | t := reflect.TypeOf(f.typeInfo.V()) 70 | if t.Kind() != reflect.Ptr { 71 | t = reflect.New(t).Type() 72 | } 73 | 74 | return t.Implements(reflect.TypeOf(new(ValueScanner)).Elem()) 75 | } 76 | 77 | // IsAuto returns the auto flag 78 | func (f *Field) IsAuto() bool { 79 | return f.auto 80 | } 81 | 82 | // IsOptional returns the optional flag 83 | func (f *Field) IsOptional() bool { 84 | return f.optional 85 | } 86 | 87 | // IsComparable returns true if the field can be used with comparison operators, i.e. it is neither a map nor a slice 88 | func (f *Field) IsComparable() bool { 89 | kind := f.typeInfo.T().Kind() 90 | return !(kind == reflect.Map || 91 | kind == reflect.Slice) 92 | } 93 | -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository.
3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ main ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ main ] 20 | schedule: 21 | - cron: '29 16 * * 1' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | 28 | strategy: 29 | fail-fast: false 30 | matrix: 31 | language: [ 'go' ] 32 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python' ] 33 | # Learn more: 34 | # https://docs.github.com/en/free-pro-team@latest/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#changing-the-languages-that-are-analyzed 35 | 36 | steps: 37 | - name: Checkout repository 38 | uses: actions/checkout@v2 39 | 40 | # Initializes the CodeQL tools for scanning. 41 | - name: Initialize CodeQL 42 | uses: github/codeql-action/init@v1 43 | with: 44 | languages: ${{ matrix.language }} 45 | # If you wish to specify custom queries, you can do so here or in a config file. 46 | # By default, queries listed here will override any specified in a config file. 47 | # Prefix the list here with "+" to use these queries and those in the config file. 48 | # queries: ./path/to/local/query, your-org/your-repo/queries@main 49 | 50 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 51 | # If this step fails, then you should remove it and run the build manually (see below) 52 | - name: Autobuild 53 | uses: github/codeql-action/autobuild@v1 54 | 55 | # ℹ️ Command-line programs to run using the OS shell.
56 | # 📚 https://git.io/JvXDl 57 | 58 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines 59 | # and modify them (or add more) to build your code if your project 60 | # uses a compiled language 61 | 62 | #- run: | 63 | # make bootstrap 64 | # make release 65 | 66 | - name: Perform CodeQL Analysis 67 | uses: github/codeql-action/analyze@v1 68 | -------------------------------------------------------------------------------- /template.go: -------------------------------------------------------------------------------- 1 | package nero 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "text/template" 7 | 8 | "github.com/stevenferrer/mira" 9 | ) 10 | 11 | // Template is an interface that wraps the Filename and Content method 12 | type Template interface { 13 | // Filename is the filename of the generated file 14 | Filename() string 15 | // Content returns the template content 16 | Content() string 17 | } 18 | 19 | // ParseTemplate parses the repository template 20 | func ParseTemplate(t Template) (*template.Template, error) { 21 | return template.New(t.Filename() + ".tmpl").
22 | Funcs(NewFuncMap()).Parse(t.Content()) 23 | } 24 | 25 | // NewFuncMap returns a template func map 26 | func NewFuncMap() template.FuncMap { 27 | return template.FuncMap{ 28 | "type": typeFunc, 29 | "rawType": rawTypeFunc, 30 | "zeroValue": zeroValueFunc, 31 | "prependToFields": prependToFields, 32 | "fileHeaders": fileHeadersFunc, 33 | "isType": isTypeFunc, 34 | } 35 | } 36 | 37 | // typeFunc returns the type name of v, dereferencing all pointer levels first 38 | func typeFunc(v interface{}) string { 39 | t := reflect.TypeOf(v) 40 | if t.Kind() != reflect.Ptr { 41 | return fmt.Sprintf("%T", v) 42 | } 43 | 44 | ev := reflect.New(resolveType(t)).Elem().Interface() 45 | return fmt.Sprintf("%T", ev) 46 | } 47 | 48 | // rawTypeFunc returns the type name of v as-is, including any pointer 49 | func rawTypeFunc(v interface{}) string { 50 | return fmt.Sprintf("%T", v) 51 | } 52 | 53 | // resolveType strips every pointer indirection from t and returns the base type 54 | func resolveType(t reflect.Type) reflect.Type { 55 | switch t.Kind() { 56 | case reflect.Ptr: 57 | return resolveType(t.Elem()) 58 | } 59 | return t 60 | } 61 | 62 | // zeroValueFunc returns the Go source literal for the zero value of v's type 63 | func zeroValueFunc(v interface{}) string { 64 | ti := mira.NewTypeInfo(v) 65 | 66 | if ti.IsNillable() { 67 | return "nil" 68 | } 69 | 70 | if ti.IsNumeric() { 71 | return "0" 72 | } 73 | 74 | switch ti.T().Kind() { 75 | case reflect.Bool: 76 | return "false" 77 | case reflect.Struct, 78 | reflect.Array: 79 | return fmt.Sprintf("(%T{})", v) 80 | } 81 | 82 | return "\"\"" 83 | 84 | } 85 | 86 | // prependToFields prepends a field to the list of fields 87 | func prependToFields(field *Field, fields []*Field) []*Field { 88 | return append([]*Field{field}, fields...) 89 | } 90 | 91 | const fileHeaders = ` 92 | // Code generated by nero, DO NOT EDIT.
93 | ` 94 | 95 | // fileHeadersFunc returns the standard file headers 96 | func fileHeadersFunc() string { 97 | return fileHeaders 98 | } 99 | // isTypeFunc returns true if the pointer-dereferenced type name of v equals typeStr 100 | func isTypeFunc(v interface{}, typeStr string) bool { 101 | return typeFunc(v) == typeStr 102 | } 103 | -------------------------------------------------------------------------------- /x/strings/camel_case_test.go: -------------------------------------------------------------------------------- 1 | package strings 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestToCamel(t *testing.T) { 11 | tests := []struct { 12 | input, 13 | want string 14 | }{ 15 | { 16 | input: "Hello camel", 17 | want: "HelloCamel", 18 | }, 19 | { 20 | input: "Hello-camel", 21 | want: "HelloCamel", 22 | }, 23 | { 24 | input: "Hello_under_camel", 25 | want: "HelloUnderCamel", 26 | }, 27 | { 28 | input: "Asd__-- _sep _every_wheres asdf___", 29 | want: "AsdSepEveryWheresAsdf", 30 | }, 31 | { 32 | input: "hello", 33 | want: "Hello", 34 | }, 35 | { 36 | input: "id", 37 | want: "Id", 38 | }, 39 | { 40 | input: "HttpServer", 41 | want: "HttpServer", 42 | }, 43 | { 44 | input: "Id", 45 | want: "Id", 46 | }, 47 | } 48 | for _, tt := range tests { 49 | t.Run(tt.input, func(t *testing.T) { 50 | got := ToCamel(tt.input) 51 | assert.Equal(t, tt.want, got, fmt.Sprintf("input: %q, expect: %q actual: %q", tt.input, tt.want, got)) 52 | }) 53 | } 54 | } 55 | 56 | var resultCamel string 57 | 58 | func BenchmarkToCamel(b *testing.B) { 59 | var s string 60 | for n := 0; n < b.N; n++ { 61 | s = ToCamel("_--34-asd__-- _sep _every_wherea asdf___") 62 | } 63 | resultCamel = s 64 | } 65 | 66 | func TestToLowerCamel(t *testing.T) { 67 | tests := []struct { 68 | input, 69 | want string 70 | }{ 71 | { 72 | input: "Hello world of 123camel", 73 | want: "helloWorldOf123Camel", 74 | }, 75 | { 76 | input: "Hello-camel", 77 | want: "helloCamel", 78 | }, 79 | { 80 | input: "Hello_under_camel", 81 | want: 
"helloUnderCamel", 82 | }, 83 | { 84 | input: "Asd__-- _sep_3 _every_wherea asdf___", 85 | want: "asdSep3EveryWhereaAsdf", 86 | }, 87 | { 88 | input: "Hello", 89 | want: "hello", 90 | }, 91 | { 92 | input: "Http-Server", 93 | want: "httpServer", 94 | }, 95 | { 96 | input: "Http_Test", 97 | want: "httpTest", 98 | }, 99 | { 100 | input: "id", 101 | want: "id", 102 | }, 103 | { 104 | input: "mY Id", 105 | want: "mYId", 106 | }, 107 | { 108 | input: "ID", 109 | want: "id", 110 | }, 111 | { 112 | input: "HTTP", 113 | want: "http", 114 | }, 115 | } 116 | for _, tt := range tests { 117 | t.Run(tt.input, func(t *testing.T) { 118 | got := ToLowerCamel(tt.input) 119 | assert.Equal(t, tt.want, got) 120 | }) 121 | } 122 | 123 | } 124 | 125 | var resultLowerCamel string 126 | 127 | func BenchmarkToLowerCamel(b *testing.B) { 128 | var s string 129 | for n := 0; n < b.N; n++ { 130 | s = ToLowerCamel("_--34-asd__-- _sep _every_wherea asdf___") 131 | } 132 | resultLowerCamel = s 133 | } 134 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/Masterminds/squirrel v1.5.0 h1:JukIZisrUXadA9pl3rMkjhiamxiB0cXiu+HGp/Y8cY8= 2 | github.com/Masterminds/squirrel v1.5.0/go.mod h1:NNaOrjSoIDfDA40n7sr2tPNZRfjzjA400rg+riTZj10= 3 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 5 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= 7 | github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= 8 | github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= 9 | github.com/hashicorp/go-multierror v1.1.1/go.mod
h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= 10 | github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= 11 | github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= 12 | github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 h1:SOEGU9fKiNWd/HOJuq6+3iTQz8KNCLtVX6idSoTLdUw= 13 | github.com/lann/builder v0.0.0-20180802200727-47ae307949d0/go.mod h1:dXGbAdH5GtBTC4WfIxhKZfyBF/HBFgRZSWwZ9g/He9o= 14 | github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 h1:P6pPBnrTSX3DEVR4fDembhRWSsG5rVo6hYhAB/ADZrk= 15 | github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0/go.mod h1:vmVJ0l/dxyfGW6FmdpVm2joNMFikkuWg0EoCKLGUMNw= 16 | github.com/lib/pq v1.10.1 h1:6VXZrLU0jHBYyAqrSPa+MgPfnSvTPuMgK+k0o5kVFWo= 17 | github.com/lib/pq v1.10.1/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= 18 | github.com/mattn/go-sqlite3 v1.14.7 h1:fxWBnXkxfM6sRiuH3bqJ4CfzZojMOLVc0UTsTglEghA= 19 | github.com/mattn/go-sqlite3 v1.14.7/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= 20 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 21 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 22 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 23 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 24 | github.com/stevenferrer/mira v0.3.0 h1:lofPeLbmVJVD1jitDwnrKzVMb4hyodRh/Lg6Q7hH8yQ= 25 | github.com/stevenferrer/mira v0.3.0/go.mod h1:sEnr07x0mIbYHk2zo92I+biz91nJ7B8W956eoiujY+I= 26 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 27 | github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 28 | github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= 29 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 30 | github.com/yuin/goldmark v1.2.1/go.mod 
h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= 31 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 32 | golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= 33 | golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= 34 | golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4= 35 | golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= 36 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 37 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 38 | golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= 39 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 40 | golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 41 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 42 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 43 | golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 44 | golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4 h1:myAQVi0cGEoqQVR5POX+8RR2mrocKqNN1hmeMqhX27k= 45 | golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 46 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 47 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 48 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 49 | golang.org/x/tools 
v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 50 | golang.org/x/tools v0.1.0 h1:po9/4sTYwZU9lPhi1tOrb4hCv3qrhiQ77LZfGa2OjwY= 51 | golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= 52 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 53 | golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 54 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= 55 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 56 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 57 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 58 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= 59 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 60 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![GoDoc Reference](https://pkg.go.dev/badge/github.com/stevenferrer/nero)](https://pkg.go.dev/github.com/stevenferrer/nero) 2 | ![Github Actions](https://github.com/stevenferrer/nero/workflows/test/badge.svg) 3 | [![Coverage Status](https://coveralls.io/repos/github/stevenferrer/nero/badge.svg?branch=main)](https://coveralls.io/github/stevenferrer/nero?branch=main) 4 | [![Go Report Card](https://goreportcard.com/badge/github.com/stevenferrer/nero)](https://goreportcard.com/report/github.com/stevenferrer/nero) 5 | 6 | # Nero 7 | 8 | A library for generating the repository pattern. 
9 | 10 | ## Motivation 11 | 12 | We heavily use the _[repository pattern](https://threedots.tech/post/repository-pattern-in-go/)_ in our codebases and we often [write our queries manually](https://golang.org/pkg/database/sql/#example_DB_QueryContext). It becomes tedious and repetitive as we have more tables/models to maintain. So, we decided to experiment with creating this library to generate our repositories automatically. 13 | 14 | ## Goals 15 | 16 | - Decouple implementation from the `Repository` interface 17 | - Easy integration with existing codebases 18 | - Minimal API 19 | 20 | ## Installation 21 | 22 | ```console 23 | $ go get github.com/stevenferrer/nero 24 | ``` 25 | 26 | ## Example 27 | 28 | See the [official example](https://github.com/stevenferrer/nero-example) and [integration test](./test/integration/playerrepo) for a more complete demo. 29 | 30 | ```go 31 | import ( 32 | "database/sql" 33 | 34 | // import the generated package 35 | "github.com/stevenferrer/nero-example/productrepo" 36 | ) 37 | 38 | func main() { 39 | dsn := "postgres://postgres:postgres@localhost:5432/postgres?sslmode=disable" 40 | db, err := sql.Open("postgres", dsn) 41 | ... 42 | 43 | ctx := context.Background() 44 | 45 | 46 | // initialize the repository (and optionally enable debug) 47 | productRepo := productrepo.NewPostgresRepository(db).Debug() 48 | 49 | // create 50 | creator := productrepo.NewCreator().Name("Product 1") 51 | product1ID, err := productRepo.Create(ctx, creator) 52 | ... 53 | 54 | // query 55 | queryer := productrepo.NewQueryer().Where(productrepo.IDEq(product1ID)) 56 | product, err := productRepo.QueryOne(ctx, queryer) 57 | ... 58 | 59 | // update 60 | now := time.Now() 61 | updater := productrepo.NewUpdater().Name("Updated Product 1"). 62 | UpdatedAt(&now).Where(productrepo.IDEq(product1ID)) 63 | _, err = productRepo.Update(ctx, updater) 64 | ...
65 | 66 | // delete 67 | deleter := productrepo.NewDeleter().Where(productrepo.IDEq(product1ID)) 68 | _, err = productRepo.Delete(ctx, deleter) 69 | ... 70 | } 71 | ``` 72 | 73 | ## Supported back-ends 74 | 75 | Below is the list of supported back-ends. 76 | 77 | | Back-end | Library | 78 | | -------------------- | ------------------------------------------------------------- | 79 | | PostgreSQL | [lib/pq](https://github.com/lib/pq) | 80 | | SQLite | [mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) | 81 | | MySQL/MariaDB (soon) | [go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) | 82 | 83 | If your back-end is not yet supported, you can implement your own [custom back-end](#custom-back-ends). 84 | 85 | ## Custom back-ends 86 | 87 | Implementing a custom back-end is very easy. In fact, you don't have to use the official back-ends at all. You can implement custom back-ends (Oracle, MSSQL, BoltDB, MongoDB, etc.) by implementing the [_Template_](./template.go) interface. 88 | 89 | See the official [Postgres template](./pg_template.go) for reference. 90 | 91 | ## Limitations 92 | 93 | Currently, we only support basic CRUD and aggregate operations (e.g. count, sum). If you have more complex requirements, we suggest that you write your repositories manually, at least for now. 94 | 95 | We're still in the process of brainstorming how to elegantly support other operations such as joins. If you have any ideas, we'd love to hear from you!
96 | 97 | ## Standing on the shoulders of giants 98 | 99 | This project wouldn't have been possible without the amazing open-source projects it was built upon: 100 | 101 | - [Masterminds/squirrel](https://github.com/Masterminds/squirrel) - Fluent SQL generation in golang 102 | - [lib/pq](https://github.com/lib/pq) - Pure Go Postgres driver for database/sql 103 | - [mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) - sqlite3 driver conforming to the built-in database/sql interface 104 | - [pkg/errors](https://github.com/pkg/errors) - Simple error handling primitives 105 | - [hashicorp/go-multierror](https://github.com/hashicorp/go-multierror) - A Go (golang) package for representing a list of errors as a single error. 106 | - [jinzhu/inflection](https://github.com/jinzhu/inflection) - Pluralizes and singularizes English nouns 107 | - [stretchr/testify](https://github.com/stretchr/testify) - A toolkit with common assertions and mocks that plays nicely with the standard library 108 | 109 | Also, the following projects have had a huge influence on this project and deserve much of the credit: 110 | 111 | - [ent](https://github.com/facebook/ent) - An entity framework for Go. Simple, yet powerful ORM for modeling and querying data. 112 | - [SQLBoiler](https://github.com/volatiletech/sqlboiler) - Generate a Go ORM tailored to your database schema. 113 | 114 | ## What's in the name? 115 | 116 | The project name is inspired by an [anti-bird](https://blackclover.fandom.com/wiki/Anti-bird) in an anime called [Black Clover](https://blackclover.fandom.com/wiki/Black_Clover_Wiki). The anti-bird, which the [Black Bulls](https://blackclover.fandom.com/wiki/Black_Bull) squad calls _Nero_, is apparently a human named [Secre Swallowtail](https://blackclover.fandom.com/wiki/Secre_Swallowtail). It's a really cool anime with lots of magic!
117 | 118 | ## Contributing 119 | 120 | Any suggestions and ideas are very much welcome, feel free to [open an issue](https://github.com/stevenferrer/nero/issues) or [make a pull request](https://github.com/stevenferrer/nero/pulls)! 121 | 122 | ## License 123 | 124 | [Apache-2.0](LICENSE) 125 | -------------------------------------------------------------------------------- /gen/predicate.go: -------------------------------------------------------------------------------- 1 | package gen 2 | 3 | import ( 4 | "bytes" 5 | "text/template" 6 | 7 | "github.com/stevenferrer/nero" 8 | "github.com/stevenferrer/nero/comparison" 9 | ) 10 | // newPredicateFile renders predicatesTmpl with the schema and the comparison operator groups (Eq, LtGt, Null, In) and returns the output as predicate.go 11 | func newPredicateFile(schema *nero.Schema) (*File, error) { 12 | tmpl, err := template.New("predicates.tmpl"). 13 | Funcs(nero.NewFuncMap()).Parse(predicatesTmpl) 14 | if err != nil { 15 | return nil, err 16 | } 17 | 18 | data := struct { 19 | EqOps []comparison.Operator 20 | LtGtOps []comparison.Operator 21 | NullOps []comparison.Operator 22 | InOps []comparison.Operator 23 | Schema *nero.Schema 24 | }{ 25 | EqOps: []comparison.Operator{ 26 | comparison.Eq, 27 | comparison.NotEq, 28 | }, 29 | LtGtOps: []comparison.Operator{ 30 | comparison.Gt, 31 | comparison.GtOrEq, 32 | comparison.Lt, 33 | comparison.LtOrEq, 34 | }, 35 | NullOps: []comparison.Operator{ 36 | comparison.IsNull, 37 | comparison.IsNotNull, 38 | }, 39 | InOps: []comparison.Operator{ 40 | comparison.In, 41 | comparison.NotIn, 42 | }, 43 | Schema: schema, 44 | } 45 | 46 | buf := &bytes.Buffer{} 47 | err = tmpl.Execute(buf, data) 48 | if err != nil { 49 | return nil, err 50 | } 51 | 52 | return &File{name: "predicate.go", buf: buf.Bytes()}, nil 53 | } 54 | // predicatesTmpl generates one predicate function per field and operator, plus the FieldX<Op>FieldY helpers that compare two columns 55 | const predicatesTmpl = ` 56 | {{- fileHeaders -}} 57 | 58 | package {{.Schema.PkgName}} 59 | 60 | import ( 61 | "github.com/lib/pq" 62 | "github.com/stevenferrer/nero/comparison" 63 | {{range $import := .Schema.Imports -}} 64 | "{{$import}}" 65 | {{end -}} 66 | ) 67 | 68 | {{ $fields := prependToFields
.Schema.Identity .Schema.Fields }} 69 | 70 | {{range $field := $fields -}} 71 | {{if $field.IsComparable -}} 72 | {{ range $op := $.EqOps }} 73 | // {{$field.StructField}}{{$op.String}} {{$op.Desc}} operator on {{$field.StructField}} field 74 | func {{$field.StructField}}{{$op.String}} ({{$field.Identifier}} {{rawType $field.TypeInfo.V}}) comparison.PredFunc { 75 | return func(preds []*comparison.Predicate) []*comparison.Predicate { 76 | return append(preds, &comparison.Predicate{ 77 | Field: "{{$field.Name}}", 78 | Op: comparison.{{$op.String}}, 79 | {{if and ($field.IsArray) (ne $field.IsValueScanner true) -}} 80 | Arg: pq.Array({{$field.Identifier}}), 81 | {{else -}} 82 | Arg: {{$field.Identifier}}, 83 | {{end -}} 84 | }) 85 | } 86 | } 87 | {{end}} 88 | 89 | {{ range $op := $.LtGtOps }} 90 | {{if or $field.TypeInfo.IsNumeric (isType $field.TypeInfo.V "time.Time")}} 91 | // {{$field.StructField}}{{$op.String}} {{$op.Desc}} operator on {{$field.StructField}} field 92 | func {{$field.StructField}}{{$op.String}} ({{$field.Identifier}} {{rawType $field.TypeInfo.V}}) comparison.PredFunc { 93 | return func(preds []*comparison.Predicate) []*comparison.Predicate { 94 | return append(preds, &comparison.Predicate{ 95 | Field: "{{$field.Name}}", 96 | Op: comparison.{{$op.String}}, 97 | {{if and ($field.IsArray) (ne $field.IsValueScanner true) -}} 98 | Arg: pq.Array({{$field.Identifier}}), 99 | {{else -}} 100 | Arg: {{$field.Identifier}}, 101 | {{end -}} 102 | }) 103 | } 104 | } 105 | {{end}} 106 | {{end }} 107 | 108 | {{ range $op := $.NullOps }} 109 | {{if $field.IsNillable}} 110 | // {{$field.StructField}}{{$op.String}} {{$op.Desc}} operator on {{$field.StructField}} field 111 | func {{$field.StructField}}{{$op.String}} () comparison.PredFunc { 112 | return func(preds []*comparison.Predicate) []*comparison.Predicate { 113 | return append(preds, &comparison.Predicate{ 114 | Field: "{{$field.Name}}", 115 | Op: comparison.{{$op.String}}, 116 | }) 117 | } 118 | } 119 |
{{end}} 120 | {{end}} 121 | 122 | {{ range $op := $.InOps }} 123 | // {{$field.StructField}}{{$op.String}} {{$op.Desc}} operator on {{$field.StructField}} field 124 | func {{$field.StructField}}{{$op.String}} ({{$field.IdentifierPlural}} ...{{rawType $field.TypeInfo.V}}) comparison.PredFunc { 125 | args := []interface{}{} 126 | for _, v := range {{$field.IdentifierPlural}} { 127 | args = append(args, v) 128 | } 129 | 130 | return func(preds []*comparison.Predicate) []*comparison.Predicate { 131 | return append(preds, &comparison.Predicate{ 132 | Field: "{{$field.Name}}", 133 | Op: comparison.{{$op.String}}, 134 | Arg: args, 135 | }) 136 | } 137 | } 138 | {{end}} 139 | {{end}} 140 | {{end -}} 141 | 142 | {{ range $op := $.EqOps }} 143 | // FieldX{{$op.String}}FieldY fieldX {{$op.Desc}} fieldY 144 | // 145 | // fieldX and fieldY must be of the same type 146 | func FieldX{{$op.String}}FieldY (fieldX, fieldY Field) comparison.PredFunc { 147 | return func(preds []*comparison.Predicate) []*comparison.Predicate { 148 | return append(preds, &comparison.Predicate{ 149 | Field: fieldX.String(), 150 | Op: comparison.{{$op.String}}, 151 | Arg: fieldY, 152 | }) 153 | } 154 | } 155 | {{end}} 156 | 157 | {{ range $op := $.LtGtOps }} 158 | // FieldX{{$op.String}}FieldY fieldX {{$op.Desc}} fieldY 159 | // 160 | // fieldX and fieldY must be of the same type 161 | func FieldX{{$op.String}}FieldY (fieldX, fieldY Field) comparison.PredFunc { 162 | return func(preds []*comparison.Predicate) []*comparison.Predicate { 163 | return append(preds, &comparison.Predicate{ 164 | Field: fieldX.String(), 165 | Op: comparison.{{$op.String}}, 166 | Arg: fieldY, 167 | }) 168 | } 169 | } 170 | {{end}} 171 | ` 172 | -------------------------------------------------------------------------------- /test/integration/playerrepo/repository.go: -------------------------------------------------------------------------------- 1 | // Code generated by nero, DO NOT EDIT.
2 | package playerrepo 3 | 4 | import ( 5 | "context" 6 | "reflect" 7 | "time" 8 | 9 | multierror "github.com/hashicorp/go-multierror" 10 | "github.com/pkg/errors" 11 | "github.com/stevenferrer/nero" 12 | "github.com/stevenferrer/nero/aggregate" 13 | "github.com/stevenferrer/nero/comparison" 14 | "github.com/stevenferrer/nero/sort" 15 | "github.com/stevenferrer/nero/test/integration/player" 16 | ) 17 | 18 | // Repository is an interface that provides the methods 19 | // for interacting with a Player repository 20 | type Repository interface { 21 | // BeginTx starts a transaction 22 | BeginTx(context.Context) (nero.Tx, error) 23 | // Create creates a Player 24 | Create(context.Context, *Creator) (id string, err error) 25 | // CreateInTx creates a Player in a transaction 26 | CreateInTx(context.Context, nero.Tx, *Creator) (id string, err error) 27 | // CreateMany batch creates Players 28 | CreateMany(context.Context, ...*Creator) error 29 | // CreateManyInTx batch creates Players in a transaction 30 | CreateManyInTx(context.Context, nero.Tx, ...*Creator) error 31 | // Query queries Players 32 | Query(context.Context, *Queryer) ([]*player.Player, error) 33 | // QueryTx queries Players in a transaction 34 | QueryInTx(context.Context, nero.Tx, *Queryer) ([]*player.Player, error) 35 | // QueryOne queries a Player 36 | QueryOne(context.Context, *Queryer) (*player.Player, error) 37 | // QueryOneTx queries a Player in a transaction 38 | QueryOneInTx(context.Context, nero.Tx, *Queryer) (*player.Player, error) 39 | // Update updates a Player or many Players 40 | Update(context.Context, *Updater) (rowsAffected int64, err error) 41 | // UpdateTx updates a Player many Players in a transaction 42 | UpdateInTx(context.Context, nero.Tx, *Updater) (rowsAffected int64, err error) 43 | // Delete deletes a Player or many Players 44 | Delete(context.Context, *Deleter) (rowsAffected int64, err error) 45 | // Delete deletes a Player or many Players in a transaction 46 | 
DeleteInTx(context.Context, nero.Tx, *Deleter) (rowsAffected int64, err error) 47 | // Aggregate performs an aggregate query 48 | Aggregate(context.Context, *Aggregator) error 49 | // Aggregate performs an aggregate query in a transaction 50 | AggregateInTx(context.Context, nero.Tx, *Aggregator) error 51 | } 52 | 53 | // Creator is a create builder 54 | type Creator struct { 55 | email string 56 | name string 57 | age int 58 | race player.Race 59 | updatedAt *time.Time 60 | } 61 | 62 | // NewCreator returns a Creator 63 | func NewCreator() *Creator { 64 | return &Creator{} 65 | } 66 | 67 | // Email sets the Email field 68 | func (c *Creator) Email(email string) *Creator { 69 | c.email = email 70 | return c 71 | } 72 | 73 | // Name sets the Name field 74 | func (c *Creator) Name(name string) *Creator { 75 | c.name = name 76 | return c 77 | } 78 | 79 | // Age sets the Age field 80 | func (c *Creator) Age(age int) *Creator { 81 | c.age = age 82 | return c 83 | } 84 | 85 | // Race sets the Race field 86 | func (c *Creator) Race(race player.Race) *Creator { 87 | c.race = race 88 | return c 89 | } 90 | 91 | // UpdatedAt sets the UpdatedAt field 92 | func (c *Creator) UpdatedAt(updatedAt *time.Time) *Creator { 93 | c.updatedAt = updatedAt 94 | return c 95 | } 96 | 97 | // Validate validates the fields 98 | func (c *Creator) Validate() error { 99 | var err error 100 | if isZero(c.email) { 101 | err = multierror.Append(err, nero.NewErrRequiredField("email")) 102 | } 103 | 104 | if isZero(c.name) { 105 | err = multierror.Append(err, nero.NewErrRequiredField("name")) 106 | } 107 | 108 | if isZero(c.age) { 109 | err = multierror.Append(err, nero.NewErrRequiredField("age")) 110 | } 111 | 112 | if isZero(c.race) { 113 | err = multierror.Append(err, nero.NewErrRequiredField("race")) 114 | } 115 | 116 | return err 117 | } 118 | 119 | // Queryer is a query builder 120 | type Queryer struct { 121 | limit uint 122 | offset uint 123 | predFuncs []comparison.PredFunc 124 | sortFuncs 
[]sort.SortFunc 125 | } 126 | 127 | // NewQueryer returns a Queryer 128 | func NewQueryer() *Queryer { 129 | return &Queryer{} 130 | } 131 | 132 | // Where applies predicates 133 | func (q *Queryer) Where(predFuncs ...comparison.PredFunc) *Queryer { 134 | q.predFuncs = append(q.predFuncs, predFuncs...) 135 | return q 136 | } 137 | 138 | // Sort applies sorting expressions 139 | func (q *Queryer) Sort(sortFuncs ...sort.SortFunc) *Queryer { 140 | q.sortFuncs = append(q.sortFuncs, sortFuncs...) 141 | return q 142 | } 143 | 144 | // Limit applies limit 145 | func (q *Queryer) Limit(limit uint) *Queryer { 146 | q.limit = limit 147 | return q 148 | } 149 | 150 | // Offset applies offset 151 | func (q *Queryer) Offset(offset uint) *Queryer { 152 | q.offset = offset 153 | return q 154 | } 155 | 156 | // Updater is an update builder 157 | type Updater struct { 158 | email string 159 | name string 160 | age int 161 | race player.Race 162 | updatedAt *time.Time 163 | predFuncs []comparison.PredFunc 164 | } 165 | 166 | // NewUpdater returns an Updater 167 | func NewUpdater() *Updater { 168 | return &Updater{} 169 | } 170 | 171 | // Email sets the Email field 172 | func (c *Updater) Email(email string) *Updater { 173 | c.email = email 174 | return c 175 | } 176 | 177 | // Name sets the Name field 178 | func (c *Updater) Name(name string) *Updater { 179 | c.name = name 180 | return c 181 | } 182 | 183 | // Age sets the Age field 184 | func (c *Updater) Age(age int) *Updater { 185 | c.age = age 186 | return c 187 | } 188 | 189 | // Race sets the Race field 190 | func (c *Updater) Race(race player.Race) *Updater { 191 | c.race = race 192 | return c 193 | } 194 | 195 | // UpdatedAt sets the UpdatedAt field 196 | func (c *Updater) UpdatedAt(updatedAt *time.Time) *Updater { 197 | c.updatedAt = updatedAt 198 | return c 199 | } 200 | 201 | // Where applies predicates 202 | func (u *Updater) Where(predFuncs ...comparison.PredFunc) *Updater { 203 | u.predFuncs = append(u.predFuncs, 
predFuncs...) 204 | return u 205 | } 206 | 207 | // Deleter is a delete builder 208 | type Deleter struct { 209 | predFuncs []comparison.PredFunc 210 | } 211 | 212 | // NewDeleter returns a Deleter 213 | func NewDeleter() *Deleter { 214 | return &Deleter{} 215 | } 216 | 217 | // Where applies predicates 218 | func (d *Deleter) Where(predFuncs ...comparison.PredFunc) *Deleter { 219 | d.predFuncs = append(d.predFuncs, predFuncs...) 220 | return d 221 | } 222 | 223 | // Aggregator is an aggregate query builder 224 | type Aggregator struct { 225 | v interface{} 226 | aggFuncs []aggregate.AggFunc 227 | predFuncs []comparison.PredFunc 228 | sortFuncs []sort.SortFunc 229 | groupBys []Field 230 | } 231 | 232 | // NewAggregator expects a v and returns an Aggregator 233 | // where 'v' argument must be an array of struct 234 | func NewAggregator(v interface{}) *Aggregator { 235 | return &Aggregator{v: v} 236 | } 237 | 238 | // Aggregate applies aggregate functions 239 | func (a *Aggregator) Aggregate(aggFuncs ...aggregate.AggFunc) *Aggregator { 240 | a.aggFuncs = append(a.aggFuncs, aggFuncs...) 241 | return a 242 | } 243 | 244 | // Where applies predicates 245 | func (a *Aggregator) Where(predFuncs ...comparison.PredFunc) *Aggregator { 246 | a.predFuncs = append(a.predFuncs, predFuncs...) 247 | return a 248 | } 249 | 250 | // Sort applies sorting expressions 251 | func (a *Aggregator) Sort(sortFuncs ...sort.SortFunc) *Aggregator { 252 | a.sortFuncs = append(a.sortFuncs, sortFuncs...) 253 | return a 254 | } 255 | 256 | // Group applies group clauses 257 | func (a *Aggregator) GroupBy(fields ...Field) *Aggregator { 258 | a.groupBys = append(a.groupBys, fields...) 
259 | return a 260 | } 261 | 262 | // rollback performs a rollback 263 | func rollback(tx nero.Tx, err error) error { 264 | rerr := tx.Rollback() 265 | if rerr != nil { 266 | err = errors.Wrapf(err, "rollback error: %v", rerr) 267 | } 268 | return err 269 | } 270 | 271 | // isZero checks if v is a zero-value 272 | func isZero(v interface{}) bool { 273 | return reflect.ValueOf(v).IsZero() 274 | } 275 | -------------------------------------------------------------------------------- /gen/repository.go: -------------------------------------------------------------------------------- 1 | package gen 2 | 3 | import ( 4 | "bytes" 5 | "text/template" 6 | 7 | "github.com/stevenferrer/nero" 8 | ) 9 | 10 | func newRepositoryFile(schema *nero.Schema) (*File, error) { 11 | tmpl, err := template.New("repository.tmpl"). 12 | Funcs(nero.NewFuncMap()).Parse(repositoryTmpl) 13 | if err != nil { 14 | return nil, err 15 | } 16 | 17 | buf := &bytes.Buffer{} 18 | err = tmpl.Execute(buf, schema) 19 | if err != nil { 20 | return nil, err 21 | } 22 | 23 | return &File{name: "repository.go", buf: buf.Bytes()}, nil 24 | } 25 | 26 | const repositoryTmpl = ` 27 | {{- fileHeaders -}} 28 | 29 | package {{.PkgName}} 30 | 31 | import ( 32 | "context" 33 | "reflect" 34 | "github.com/pkg/errors" 35 | "github.com/stevenferrer/nero" 36 | "github.com/stevenferrer/nero/comparison" 37 | "github.com/stevenferrer/nero/sort" 38 | "github.com/stevenferrer/nero/aggregate" 39 | multierror "github.com/hashicorp/go-multierror" 40 | {{range $import := .Imports -}} 41 | "{{$import}}" 42 | {{end -}} 43 | ) 44 | 45 | // Repository is an interface that provides the methods 46 | // for interacting with a {{.TypeInfo.Name}} repository 47 | type Repository interface { 48 | // BeginTx starts a transaction 49 | BeginTx(context.Context) (nero.Tx, error) 50 | // Create creates a {{.TypeName}} 51 | Create(context.Context, *Creator) (id {{rawType .Identity.TypeInfo.V}}, err error) 52 | // CreateInTx creates a {{.TypeName}} 
in a transaction 53 | CreateInTx(context.Context, nero.Tx, *Creator) (id {{rawType .Identity.TypeInfo.V}}, err error) 54 | // CreateMany batch creates {{.TypeNamePlural}} 55 | CreateMany(context.Context, ...*Creator) error 56 | // CreateManyInTx batch creates {{.TypeNamePlural}} in a transaction 57 | CreateManyInTx(context.Context, nero.Tx, ...*Creator) error 58 | // Query queries {{.TypeNamePlural}} 59 | Query(context.Context, *Queryer) ([]{{rawType .TypeInfo.V}}, error) 60 | // QueryInTx queries {{.TypeNamePlural}} in a transaction 61 | QueryInTx(context.Context, nero.Tx, *Queryer) ([]{{rawType .TypeInfo.V}}, error) 62 | // QueryOne queries a {{.TypeName}} 63 | QueryOne(context.Context, *Queryer) ({{rawType .TypeInfo.V}}, error) 64 | // QueryOneInTx queries a {{.TypeName}} in a transaction 65 | QueryOneInTx(context.Context, nero.Tx, *Queryer) ({{rawType .TypeInfo.V}}, error) 66 | // Update updates a {{.TypeName}} or many {{.TypeNamePlural}} 67 | Update(context.Context, *Updater) (rowsAffected int64, err error) 68 | // UpdateInTx updates a {{.TypeName}} or many {{.TypeNamePlural}} in a transaction 69 | UpdateInTx(context.Context, nero.Tx, *Updater) (rowsAffected int64, err error) 70 | // Delete deletes a {{.TypeName}} or many {{.TypeNamePlural}} 71 | Delete(context.Context, *Deleter) (rowsAffected int64, err error) 72 | // DeleteInTx deletes a {{.TypeName}} or many {{.TypeNamePlural}} in a transaction 73 | DeleteInTx(context.Context, nero.Tx, *Deleter) (rowsAffected int64, err error) 74 | // Aggregate performs an aggregate query 75 | Aggregate(context.Context, *Aggregator) error 76 | // AggregateInTx performs an aggregate query in a transaction 77 | AggregateInTx(context.Context, nero.Tx, *Aggregator) error 78 | } 79 | 80 | 81 | {{ $fields := prependToFields .Identity .Fields }} 82 | 83 | // Creator is a create builder 84 | type Creator struct { 85 | {{range $field := $fields -}} 86 | {{if ne $field.IsAuto true -}} 87 | {{$field.Identifier}} {{rawType $field.TypeInfo.V}} 88 |
{{end -}} 89 | {{end -}} 90 | } 91 | 92 | // NewCreator returns a Creator 93 | func NewCreator() *Creator { 94 | return &Creator{} 95 | } 96 | 97 | {{range $field := $fields }} 98 | {{if ne $field.IsAuto true -}} 99 | // {{$field.StructField}} sets the {{$field.StructField}} field 100 | func (c *Creator) {{$field.StructField}}({{$field.Identifier}} {{rawType $field.TypeInfo.V}}) *Creator { 101 | c.{{$field.Identifier}} = {{$field.Identifier}} 102 | return c 103 | } 104 | {{end -}} 105 | {{end -}} 106 | 107 | // Validate validates the fields 108 | func (c *Creator) Validate() error { 109 | var err error 110 | {{range $field := .Fields -}} 111 | {{if and (ne $field.IsOptional true) (ne $field.IsAuto true) -}} 112 | if isZero(c.{{$field.Identifier}}) { 113 | err = multierror.Append(err, nero.NewErrRequiredField("{{$field.Name}}")) 114 | } 115 | {{end}} 116 | {{end}} 117 | 118 | return err 119 | } 120 | 121 | // Queryer is a query builder 122 | type Queryer struct { 123 | limit uint 124 | offset uint 125 | predFuncs []comparison.PredFunc 126 | sortFuncs []sort.SortFunc 127 | } 128 | 129 | // NewQueryer returns a Queryer 130 | func NewQueryer() *Queryer { 131 | return &Queryer{} 132 | } 133 | 134 | // Where applies predicates 135 | func (q *Queryer) Where(predFuncs ...comparison.PredFunc) *Queryer { 136 | q.predFuncs = append(q.predFuncs, predFuncs...) 137 | return q 138 | } 139 | 140 | // Sort applies sorting expressions 141 | func (q *Queryer) Sort(sortFuncs ...sort.SortFunc) *Queryer { 142 | q.sortFuncs = append(q.sortFuncs, sortFuncs...) 
143 | return q 144 | } 145 | 146 | // Limit applies limit 147 | func (q *Queryer) Limit(limit uint) *Queryer { 148 | q.limit = limit 149 | return q 150 | } 151 | 152 | // Offset applies offset 153 | func (q *Queryer) Offset(offset uint) *Queryer { 154 | q.offset = offset 155 | return q 156 | } 157 | 158 | // Updater is an update builder 159 | type Updater struct { 160 | {{range $field := .Fields -}} 161 | {{if ne $field.IsAuto true -}} 162 | {{$field.Identifier}} {{rawType $field.TypeInfo.V}} 163 | {{end -}} 164 | {{end -}} 165 | predFuncs []comparison.PredFunc 166 | } 167 | 168 | // NewUpdater returns an Updater 169 | func NewUpdater() *Updater { 170 | return &Updater{} 171 | } 172 | 173 | {{range $field := .Fields}} 174 | {{if ne $field.IsAuto true -}} 175 | // {{$field.StructField}} sets the {{$field.StructField}} field 176 | func (c *Updater) {{$field.StructField}}({{$field.Identifier}} {{rawType $field.TypeInfo.V}}) *Updater { 177 | c.{{$field.Identifier}} = {{$field.Identifier}} 178 | return c 179 | } 180 | {{end -}} 181 | {{end -}} 182 | 183 | // Where applies predicates 184 | func (u *Updater) Where(predFuncs ...comparison.PredFunc) *Updater { 185 | u.predFuncs = append(u.predFuncs, predFuncs...) 186 | return u 187 | } 188 | 189 | // Deleter is a delete builder 190 | type Deleter struct { 191 | predFuncs []comparison.PredFunc 192 | } 193 | 194 | // NewDeleter returns a Deleter 195 | func NewDeleter() *Deleter { 196 | return &Deleter{} 197 | } 198 | 199 | // Where applies predicates 200 | func (d *Deleter) Where(predFuncs ...comparison.PredFunc) *Deleter { 201 | d.predFuncs = append(d.predFuncs, predFuncs...) 
202 | return d 203 | } 204 | 205 | // Aggregator is an aggregate query builder 206 | type Aggregator struct { 207 | v interface{} 208 | aggFuncs []aggregate.AggFunc 209 | predFuncs []comparison.PredFunc 210 | sortFuncs []sort.SortFunc 211 | groupBys []Field 212 | } 213 | 214 | // NewAggregator expects a v and returns an Aggregator 215 | // where 'v' argument must be an array of struct 216 | func NewAggregator(v interface{}) *Aggregator { 217 | return &Aggregator{v: v} 218 | } 219 | 220 | // Aggregate applies aggregate functions 221 | func (a *Aggregator) Aggregate(aggFuncs ...aggregate.AggFunc) *Aggregator { 222 | a.aggFuncs = append(a.aggFuncs, aggFuncs...) 223 | return a 224 | } 225 | 226 | // Where applies predicates 227 | func (a *Aggregator) Where(predFuncs ...comparison.PredFunc) *Aggregator { 228 | a.predFuncs = append(a.predFuncs, predFuncs...) 229 | return a 230 | } 231 | 232 | // Sort applies sorting expressions 233 | func (a *Aggregator) Sort(sortFuncs ...sort.SortFunc) *Aggregator { 234 | a.sortFuncs = append(a.sortFuncs, sortFuncs...) 235 | return a 236 | } 237 | 238 | // GroupBy appends fields to the group by clause 239 | func (a *Aggregator) GroupBy(fields ...Field) *Aggregator { 240 | a.groupBys = append(a.groupBys, fields...)
241 | return a 242 | } 243 | 244 | // rollback performs a rollback 245 | func rollback(tx nero.Tx, err error) error { 246 | rerr := tx.Rollback() 247 | if rerr != nil { 248 | err = errors.Wrapf(err, "rollback error: %v", rerr) 249 | } 250 | return err 251 | } 252 | 253 | // isZero checks if v is a zero-value 254 | func isZero(v interface{}) bool { 255 | return reflect.ValueOf(v).IsZero() 256 | } 257 | ` 258 | -------------------------------------------------------------------------------- /test/integration/playerrepo/predicate_test.go: -------------------------------------------------------------------------------- 1 | package playerrepo_test 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/stevenferrer/nero/comparison" 8 | "github.com/stevenferrer/nero/test/integration/player" 9 | "github.com/stevenferrer/nero/test/integration/playerrepo" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestPredicate(t *testing.T) { 14 | now := time.Now() 15 | tests := []struct { 16 | predFunc comparison.PredFunc 17 | want *comparison.Predicate 18 | }{ 19 | // id 20 | { 21 | predFunc: playerrepo.IDEq("1"), 22 | want: &comparison.Predicate{ 23 | Field: playerrepo.FieldID.String(), 24 | Op: comparison.Eq, 25 | Arg: "1", 26 | }, 27 | }, 28 | { 29 | predFunc: playerrepo.IDNotEq("1"), 30 | want: &comparison.Predicate{ 31 | Field: playerrepo.FieldID.String(), 32 | Op: comparison.NotEq, 33 | Arg: "1", 34 | }, 35 | }, 36 | { 37 | predFunc: playerrepo.IDIn("1"), 38 | want: &comparison.Predicate{ 39 | Field: playerrepo.FieldID.String(), 40 | Op: comparison.In, 41 | Arg: []interface{}{"1"}, 42 | }, 43 | }, 44 | { 45 | predFunc: playerrepo.IDNotIn("1"), 46 | want: &comparison.Predicate{ 47 | Field: playerrepo.FieldID.String(), 48 | Op: comparison.NotIn, 49 | Arg: []interface{}{"1"}, 50 | }, 51 | }, 52 | 53 | // email 54 | { 55 | predFunc: playerrepo.EmailEq("me@me.io"), 56 | want: &comparison.Predicate{ 57 | Field: playerrepo.FieldEmail.String(), 58 | Op: 
comparison.Eq, 59 | Arg: "me@me.io", 60 | }, 61 | }, 62 | { 63 | predFunc: playerrepo.EmailNotEq("me@me.io"), 64 | want: &comparison.Predicate{ 65 | Field: playerrepo.FieldEmail.String(), 66 | Op: comparison.NotEq, 67 | Arg: "me@me.io", 68 | }, 69 | }, 70 | { 71 | predFunc: playerrepo.EmailIn("me@me.io"), 72 | want: &comparison.Predicate{ 73 | Field: playerrepo.FieldEmail.String(), 74 | Op: comparison.In, 75 | Arg: []interface{}{"me@me.io"}, 76 | }, 77 | }, 78 | { 79 | predFunc: playerrepo.EmailNotIn("me@me.io"), 80 | want: &comparison.Predicate{ 81 | Field: playerrepo.FieldEmail.String(), 82 | Op: comparison.NotIn, 83 | Arg: []interface{}{"me@me.io"}, 84 | }, 85 | }, 86 | // name 87 | { 88 | predFunc: playerrepo.NameEq("me"), 89 | want: &comparison.Predicate{ 90 | Field: playerrepo.FieldName.String(), 91 | Op: comparison.Eq, 92 | Arg: "me", 93 | }, 94 | }, 95 | { 96 | predFunc: playerrepo.NameNotEq("me"), 97 | want: &comparison.Predicate{ 98 | Field: playerrepo.FieldName.String(), 99 | Op: comparison.NotEq, 100 | Arg: "me", 101 | }, 102 | }, 103 | { 104 | predFunc: playerrepo.NameIn("me"), 105 | want: &comparison.Predicate{ 106 | Field: playerrepo.FieldName.String(), 107 | Op: comparison.In, 108 | Arg: []interface{}{"me"}, 109 | }, 110 | }, 111 | { 112 | predFunc: playerrepo.NameNotIn("me"), 113 | want: &comparison.Predicate{ 114 | Field: playerrepo.FieldName.String(), 115 | Op: comparison.NotIn, 116 | Arg: []interface{}{"me"}, 117 | }, 118 | }, 119 | 120 | // age 121 | { 122 | predFunc: playerrepo.AgeEq(18), 123 | want: &comparison.Predicate{ 124 | Field: playerrepo.FieldAge.String(), 125 | Op: comparison.Eq, 126 | Arg: 18, 127 | }, 128 | }, 129 | { 130 | predFunc: playerrepo.AgeNotEq(18), 131 | want: &comparison.Predicate{ 132 | Field: playerrepo.FieldAge.String(), 133 | Op: comparison.NotEq, 134 | Arg: 18, 135 | }, 136 | }, 137 | { 138 | predFunc: playerrepo.AgeGt(18), 139 | want: &comparison.Predicate{ 140 | Field: playerrepo.FieldAge.String(), 141 | Op: 
comparison.Gt, 142 | Arg: 18, 143 | }, 144 | }, 145 | { 146 | predFunc: playerrepo.AgeGtOrEq(18), 147 | want: &comparison.Predicate{ 148 | Field: playerrepo.FieldAge.String(), 149 | Op: comparison.GtOrEq, 150 | Arg: 18, 151 | }, 152 | }, 153 | { 154 | predFunc: playerrepo.AgeLt(18), 155 | want: &comparison.Predicate{ 156 | Field: playerrepo.FieldAge.String(), 157 | Op: comparison.Lt, 158 | Arg: 18, 159 | }, 160 | }, 161 | { 162 | predFunc: playerrepo.AgeLtOrEq(18), 163 | want: &comparison.Predicate{ 164 | Field: playerrepo.FieldAge.String(), 165 | Op: comparison.LtOrEq, 166 | Arg: 18, 167 | }, 168 | }, 169 | { 170 | predFunc: playerrepo.AgeNotEq(18), 171 | want: &comparison.Predicate{ 172 | Field: playerrepo.FieldAge.String(), 173 | Op: comparison.NotEq, 174 | Arg: 18, 175 | }, 176 | }, 177 | 178 | { 179 | predFunc: playerrepo.AgeIn(18), 180 | want: &comparison.Predicate{ 181 | Field: playerrepo.FieldAge.String(), 182 | Op: comparison.In, 183 | Arg: []interface{}{18}, 184 | }, 185 | }, 186 | { 187 | predFunc: playerrepo.AgeNotIn(18), 188 | want: &comparison.Predicate{ 189 | Field: playerrepo.FieldAge.String(), 190 | Op: comparison.NotIn, 191 | Arg: []interface{}{18}, 192 | }, 193 | }, 194 | 195 | // race 196 | { 197 | predFunc: playerrepo.RaceEq(player.RaceHuman), 198 | want: &comparison.Predicate{ 199 | Field: playerrepo.FieldRace.String(), 200 | Op: comparison.Eq, 201 | Arg: player.RaceHuman, 202 | }, 203 | }, 204 | { 205 | predFunc: playerrepo.RaceNotEq(player.RaceHuman), 206 | want: &comparison.Predicate{ 207 | Field: playerrepo.FieldRace.String(), 208 | Op: comparison.NotEq, 209 | Arg: player.RaceHuman, 210 | }, 211 | }, 212 | { 213 | predFunc: playerrepo.RaceIn(player.RaceHuman), 214 | want: &comparison.Predicate{ 215 | Field: playerrepo.FieldRace.String(), 216 | Op: comparison.In, 217 | Arg: []interface{}{player.RaceHuman}, 218 | }, 219 | }, 220 | { 221 | predFunc: playerrepo.RaceNotIn(player.RaceHuman), 222 | want: &comparison.Predicate{ 223 | Field: 
playerrepo.FieldRace.String(), 224 | Op: comparison.NotIn, 225 | Arg: []interface{}{player.RaceHuman}, 226 | }, 227 | }, 228 | 229 | // updated at 230 | { 231 | predFunc: playerrepo.UpdatedAtEq(&now), 232 | want: &comparison.Predicate{ 233 | Field: playerrepo.FieldUpdatedAt.String(), 234 | Op: comparison.Eq, 235 | Arg: &now, 236 | }, 237 | }, 238 | { 239 | predFunc: playerrepo.UpdatedAtNotEq(&now), 240 | want: &comparison.Predicate{ 241 | Field: playerrepo.FieldUpdatedAt.String(), 242 | Op: comparison.NotEq, 243 | Arg: &now, 244 | }, 245 | }, 246 | { 247 | predFunc: playerrepo.UpdatedAtIsNull(), 248 | want: &comparison.Predicate{ 249 | Field: playerrepo.FieldUpdatedAt.String(), 250 | Op: comparison.IsNull, 251 | }, 252 | }, 253 | { 254 | predFunc: playerrepo.UpdatedAtIsNotNull(), 255 | want: &comparison.Predicate{ 256 | Field: playerrepo.FieldUpdatedAt.String(), 257 | Op: comparison.IsNotNull, 258 | }, 259 | }, 260 | { 261 | predFunc: playerrepo.UpdatedAtIn(&now), 262 | want: &comparison.Predicate{ 263 | Field: playerrepo.FieldUpdatedAt.String(), 264 | Op: comparison.In, 265 | Arg: []interface{}{&now}, 266 | }, 267 | }, 268 | { 269 | predFunc: playerrepo.UpdatedAtNotIn(&now), 270 | want: &comparison.Predicate{ 271 | Field: playerrepo.FieldUpdatedAt.String(), 272 | Op: comparison.NotIn, 273 | Arg: []interface{}{&now}, 274 | }, 275 | }, 276 | 277 | // created at 278 | { 279 | predFunc: playerrepo.CreatedAtEq(&now), 280 | want: &comparison.Predicate{ 281 | Field: playerrepo.FieldCreatedAt.String(), 282 | Op: comparison.Eq, 283 | Arg: &now, 284 | }, 285 | }, 286 | { 287 | predFunc: playerrepo.CreatedAtNotEq(&now), 288 | want: &comparison.Predicate{ 289 | Field: playerrepo.FieldCreatedAt.String(), 290 | Op: comparison.NotEq, 291 | Arg: &now, 292 | }, 293 | }, 294 | { 295 | predFunc: playerrepo.CreatedAtIsNull(), 296 | want: &comparison.Predicate{ 297 | Field: playerrepo.FieldCreatedAt.String(), 298 | Op: comparison.IsNull, 299 | }, 300 | }, 301 | { 302 | predFunc: 
playerrepo.CreatedAtIsNotNull(), 303 | want: &comparison.Predicate{ 304 | Field: playerrepo.FieldCreatedAt.String(), 305 | Op: comparison.IsNotNull, 306 | }, 307 | }, 308 | { 309 | predFunc: playerrepo.CreatedAtIn(&now), 310 | want: &comparison.Predicate{ 311 | Field: playerrepo.FieldCreatedAt.String(), 312 | Op: comparison.In, 313 | Arg: []interface{}{&now}, 314 | }, 315 | }, 316 | { 317 | predFunc: playerrepo.CreatedAtNotIn(&now), 318 | want: &comparison.Predicate{ 319 | Field: playerrepo.FieldCreatedAt.String(), 320 | Op: comparison.NotIn, 321 | Arg: []interface{}{&now}, 322 | }, 323 | }, 324 | 325 | // field to field comparison 326 | { 327 | predFunc: playerrepo.FieldXEqFieldY( 328 | playerrepo.FieldUpdatedAt, 329 | playerrepo.FieldCreatedAt, 330 | ), 331 | want: &comparison.Predicate{ 332 | Field: playerrepo.FieldUpdatedAt.String(), 333 | Op: comparison.Eq, 334 | Arg: playerrepo.FieldCreatedAt, 335 | }, 336 | }, 337 | { 338 | predFunc: playerrepo.FieldXNotEqFieldY( 339 | playerrepo.FieldUpdatedAt, 340 | playerrepo.FieldCreatedAt, 341 | ), 342 | want: &comparison.Predicate{ 343 | Field: playerrepo.FieldUpdatedAt.String(), 344 | Op: comparison.NotEq, 345 | Arg: playerrepo.FieldCreatedAt, 346 | }, 347 | }, 348 | 349 | { 350 | predFunc: playerrepo.FieldXGtFieldY( 351 | playerrepo.FieldUpdatedAt, 352 | playerrepo.FieldCreatedAt, 353 | ), 354 | want: &comparison.Predicate{ 355 | Field: playerrepo.FieldUpdatedAt.String(), 356 | Op: comparison.Gt, 357 | Arg: playerrepo.FieldCreatedAt, 358 | }, 359 | }, 360 | { 361 | predFunc: playerrepo.FieldXGtOrEqFieldY( 362 | playerrepo.FieldUpdatedAt, 363 | playerrepo.FieldCreatedAt, 364 | ), 365 | want: &comparison.Predicate{ 366 | Field: playerrepo.FieldUpdatedAt.String(), 367 | Op: comparison.GtOrEq, 368 | Arg: playerrepo.FieldCreatedAt, 369 | }, 370 | }, 371 | 372 | { 373 | predFunc: playerrepo.FieldXLtFieldY( 374 | playerrepo.FieldUpdatedAt, 375 | playerrepo.FieldCreatedAt, 376 | ), 377 | want: &comparison.Predicate{ 378 | 
Field: playerrepo.FieldUpdatedAt.String(), 379 | Op: comparison.Lt, 380 | Arg: playerrepo.FieldCreatedAt, 381 | }, 382 | }, 383 | { 384 | predFunc: playerrepo.FieldXLtOrEqFieldY( 385 | playerrepo.FieldUpdatedAt, 386 | playerrepo.FieldCreatedAt, 387 | ), 388 | 389 | want: &comparison.Predicate{ 390 | Field: playerrepo.FieldUpdatedAt.String(), 391 | Op: comparison.LtOrEq, 392 | Arg: playerrepo.FieldCreatedAt, 393 | }, 394 | }, 395 | } 396 | 397 | for _, tc := range tests { 398 | got := tc.predFunc([]*comparison.Predicate{})[0] 399 | assert.Equal(t, tc.want.Field, got.Field) 400 | assert.Equal(t, tc.want.Arg, got.Arg) 401 | assert.Equal(t, tc.want.Op, got.Op) 402 | } 403 | } 404 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2020 Steven Ferrer 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /test/integration/playerrepo/sqlite.go: -------------------------------------------------------------------------------- 1 | // Code generated by nero, DO NOT EDIT. 
2 | package playerrepo 3 | 4 | import ( 5 | "context" 6 | "database/sql" 7 | "fmt" 8 | "log" 9 | "os" 10 | "reflect" 11 | "strings" 12 | 13 | "github.com/Masterminds/squirrel" 14 | _ "github.com/mattn/go-sqlite3" 15 | "github.com/pkg/errors" 16 | "github.com/stevenferrer/nero" 17 | "github.com/stevenferrer/nero/aggregate" 18 | "github.com/stevenferrer/nero/comparison" 19 | "github.com/stevenferrer/nero/sort" 20 | "github.com/stevenferrer/nero/test/integration/player" 21 | ) 22 | 23 | // SQLiteRepository is a repository that uses SQLite3 as data store 24 | type SQLiteRepository struct { 25 | db *sql.DB 26 | logger nero.Logger 27 | debug bool 28 | } 29 | 30 | var _ Repository = (*SQLiteRepository)(nil) 31 | 32 | // NewSQLiteRepository returns a new SQLiteRepository 33 | func NewSQLiteRepository(db *sql.DB) *SQLiteRepository { 34 | return &SQLiteRepository{db: db} 35 | } 36 | 37 | // Debug enables debug mode 38 | func (repo *SQLiteRepository) Debug() *SQLiteRepository { 39 | l := log.New(os.Stdout, "[nero] ", log.LstdFlags|log.Lmicroseconds|log.Lmsgprefix) 40 | return &SQLiteRepository{ 41 | db: repo.db, 42 | debug: true, 43 | logger: l, 44 | } 45 | } 46 | 47 | // WithLogger overrides the default logger 48 | func (repo *SQLiteRepository) WithLogger(logger nero.Logger) *SQLiteRepository { 49 | repo.logger = logger 50 | return repo 51 | } 52 | 53 | // BeginTx starts a transaction 54 | func (repo *SQLiteRepository) BeginTx(ctx context.Context) (nero.Tx, error) { 55 | return repo.db.BeginTx(ctx, nil) 56 | } 57 | 58 | // Create creates a Player 59 | func (repo *SQLiteRepository) Create(ctx context.Context, c *Creator) (string, error) { 60 | return repo.create(ctx, repo.db, c) 61 | } 62 | 63 | // CreateInTx creates a Player in a transaction 64 | func (repo *SQLiteRepository) CreateInTx(ctx context.Context, tx nero.Tx, c *Creator) (string, error) { 65 | txx, ok := tx.(*sql.Tx) 66 | if !ok { 67 | return "", errors.New("expecting tx to be *sql.Tx") 68 | } 69 | 70 | return 
repo.create(ctx, txx, c) 71 | } 72 | 73 | func (repo *SQLiteRepository) create(ctx context.Context, runner nero.SQLRunner, c *Creator) (string, error) { 74 | if err := c.Validate(); err != nil { 75 | return "", err 76 | } 77 | 78 | columns := []string{ 79 | "\"email\"", 80 | "\"name\"", 81 | "\"age\"", 82 | "\"race\"", 83 | } 84 | 85 | values := []interface{}{ 86 | c.email, 87 | c.name, 88 | c.age, 89 | c.race, 90 | } 91 | 92 | if !isZero(c.updatedAt) { 93 | columns = append(columns, "\"updated_at\"") 94 | values = append(values, c.updatedAt) 95 | } 96 | 97 | qb := squirrel.Insert("\"players\"").Columns(columns...). 98 | Values(values...).RunWith(runner) 99 | if repo.debug && repo.logger != nil { 100 | sql, args, err := qb.ToSql() 101 | repo.logger.Printf("method: Create, stmt: %q, args: %v, error: %v", sql, args, err) 102 | } 103 | 104 | res, err := qb.ExecContext(ctx) 105 | if err != nil { 106 | return "", err 107 | } 108 | 109 | // last_insert_rowid() is per-connection, so querying it via repo.db can hit another pooled connection (or miss the tx); read the rowid from this insert's own result instead. NOTE(review): the sqlite template that generates this file presumably needs the same change. 110 | lastID, err := res.LastInsertId() 111 | if err != nil { 112 | return "", err 113 | } 114 | 115 | return fmt.Sprint(lastID), nil 116 | } 117 | 118 | // CreateMany batch creates Players 119 | func (repo *SQLiteRepository) CreateMany(ctx context.Context, cs ...*Creator) error { 120 | return repo.createMany(ctx, repo.db, cs...) 121 | } 122 | 123 | // CreateManyInTx batch creates Players in a transaction 124 | func (repo *SQLiteRepository) CreateManyInTx(ctx context.Context, tx nero.Tx, cs ...*Creator) error { 125 | txx, ok := tx.(*sql.Tx) 126 | if !ok { 127 | return errors.New("expecting tx to be *sql.Tx") 128 | } 129 | 130 | return repo.createMany(ctx, txx, cs...)
131 | } 132 | 133 | func (repo *SQLiteRepository) createMany(ctx context.Context, runner nero.SQLRunner, cs ...*Creator) error { 134 | if len(cs) == 0 { 135 | return nil 136 | } 137 | 138 | columns := []string{ 139 | "\"email\"", 140 | "\"name\"", 141 | "\"age\"", 142 | "\"race\"", 143 | "\"updated_at\"", 144 | } 145 | qb := squirrel.Insert("\"players\"").Columns(columns...) 146 | for _, c := range cs { 147 | if err := c.Validate(); err != nil { 148 | return err 149 | } 150 | 151 | qb = qb.Values( 152 | c.email, 153 | c.name, 154 | c.age, 155 | c.race, 156 | c.updatedAt, 157 | ) 158 | } 159 | 160 | if repo.debug && repo.logger != nil { 161 | sql, args, err := qb.ToSql() 162 | repo.logger.Printf("method: CreateMany, stmt: %q, args: %v, error: %v", sql, args, err) 163 | } 164 | 165 | _, err := qb.RunWith(runner).ExecContext(ctx) 166 | if err != nil { 167 | return err 168 | } 169 | 170 | return nil 171 | } 172 | 173 | // Query queries Players 174 | func (repo *SQLiteRepository) Query(ctx context.Context, q *Queryer) ([]*player.Player, error) { 175 | return repo.query(ctx, repo.db, q) 176 | } 177 | 178 | // QueryInTx queries Players in a transaction 179 | func (repo *SQLiteRepository) QueryInTx(ctx context.Context, tx nero.Tx, q *Queryer) ([]*player.Player, error) { 180 | txx, ok := tx.(*sql.Tx) 181 | if !ok { 182 | return nil, errors.New("expecting tx to be *sql.Tx") 183 | } 184 | 185 | return repo.query(ctx, txx, q) 186 | } 187 | 188 | func (repo *SQLiteRepository) query(ctx context.Context, runner nero.SQLRunner, q *Queryer) ([]*player.Player, error) { 189 | qb := repo.buildSelect(q) 190 | if repo.debug && repo.logger != nil { 191 | sql, args, err := qb.ToSql() 192 | repo.logger.Printf("method: Query, stmt: %q, args: %v, error: %v", sql, args, err) 193 | } 194 | 195 | rows, err := qb.RunWith(runner).QueryContext(ctx) 196 | if err != nil { 197 | return nil, err 198 | } 199 | defer rows.Close() 200 | 201 | players := []*player.Player{} 202 | for rows.Next() { 203 | 
var player player.Player 204 | err = rows.Scan( 205 | &player.ID, 206 | &player.Email, 207 | &player.Name, 208 | &player.Age, 209 | &player.Race, 210 | &player.UpdatedAt, 211 | &player.CreatedAt, 212 | ) 213 | if err != nil { 214 | return nil, err 215 | } 216 | 217 | players = append(players, &player) 218 | } 219 | 220 | return players, nil 221 | } 222 | 223 | // QueryOne queries a Player 224 | func (repo *SQLiteRepository) QueryOne(ctx context.Context, q *Queryer) (*player.Player, error) { 225 | return repo.queryOne(ctx, repo.db, q) 226 | } 227 | 228 | // QueryOneInTx queries a Player in a transaction 229 | func (repo *SQLiteRepository) QueryOneInTx(ctx context.Context, tx nero.Tx, q *Queryer) (*player.Player, error) { 230 | txx, ok := tx.(*sql.Tx) 231 | if !ok { 232 | return nil, errors.New("expecting tx to be *sql.Tx") 233 | } 234 | 235 | return repo.queryOne(ctx, txx, q) 236 | } 237 | 238 | func (repo *SQLiteRepository) queryOne(ctx context.Context, runner nero.SQLRunner, q *Queryer) (*player.Player, error) { 239 | qb := repo.buildSelect(q) 240 | if repo.debug && repo.logger != nil { 241 | sql, args, err := qb.ToSql() 242 | repo.logger.Printf("method: QueryOne, stmt: %q, args: %v, error: %v", sql, args, err) 243 | } 244 | 245 | var player player.Player 246 | err := qb.RunWith(runner). 247 | QueryRowContext(ctx). 
248 | Scan( 249 | &player.ID, 250 | &player.Email, 251 | &player.Name, 252 | &player.Age, 253 | &player.Race, 254 | &player.UpdatedAt, 255 | &player.CreatedAt, 256 | ) 257 | if err != nil { 258 | return nil, err 259 | } 260 | 261 | return &player, nil 262 | } 263 | 264 | func (repo *SQLiteRepository) buildSelect(q *Queryer) squirrel.SelectBuilder { 265 | columns := []string{ 266 | "\"id\"", 267 | "\"email\"", 268 | "\"name\"", 269 | "\"age\"", 270 | "\"race\"", 271 | "\"updated_at\"", 272 | "\"created_at\"", 273 | } 274 | qb := squirrel.Select(columns...).From("\"players\"") 275 | 276 | preds := []*comparison.Predicate{} 277 | for _, predFunc := range q.predFuncs { 278 | preds = predFunc(preds) 279 | } 280 | qb = squirrel.SelectBuilder(repo.buildPreds(squirrel.StatementBuilderType(qb), preds)) 281 | 282 | sorts := []*sort.Sort{} 283 | for _, sortFunc := range q.sortFuncs { 284 | sorts = sortFunc(sorts) 285 | } 286 | qb = repo.buildSort(qb, sorts) 287 | 288 | if q.limit > 0 { 289 | qb = qb.Limit(uint64(q.limit)) 290 | } 291 | 292 | if q.offset > 0 { 293 | qb = qb.Offset(uint64(q.offset)) 294 | } 295 | 296 | return qb 297 | } 298 | 299 | func (repo *SQLiteRepository) buildPreds(sb squirrel.StatementBuilderType, preds []*comparison.Predicate) squirrel.StatementBuilderType { 300 | for _, pred := range preds { 301 | ph := "?" 302 | fieldX, arg := pred.Field, pred.Arg 303 | 304 | args := []interface{}{} 305 | if fieldY, ok := arg.(Field); ok { // a field 306 | ph = fmt.Sprintf("%q", fieldY) 307 | } else if vals, ok := arg.([]interface{}); ok { // array of values 308 | args = append(args, vals...) 309 | } else { // single value 310 | args = append(args, arg) 311 | } 312 | 313 | switch pred.Op { 314 | case comparison.Eq: 315 | sb = sb.Where(fmt.Sprintf("%q = "+ph, fieldX), args...) 316 | case comparison.NotEq: 317 | sb = sb.Where(fmt.Sprintf("%q <> "+ph, fieldX), args...) 318 | case comparison.Gt: 319 | sb = sb.Where(fmt.Sprintf("%q > "+ph, fieldX), args...) 
320 | case comparison.GtOrEq: 321 | sb = sb.Where(fmt.Sprintf("%q >= "+ph, fieldX), args...) 322 | case comparison.Lt: 323 | sb = sb.Where(fmt.Sprintf("%q < "+ph, fieldX), args...) 324 | case comparison.LtOrEq: 325 | sb = sb.Where(fmt.Sprintf("%q <= "+ph, fieldX), args...) 326 | case comparison.IsNull, comparison.IsNotNull: 327 | fmtStr := "%q IS NULL" 328 | if pred.Op == comparison.IsNotNull { 329 | fmtStr = "%q IS NOT NULL" 330 | } 331 | sb = sb.Where(fmt.Sprintf(fmtStr, fieldX)) 332 | case comparison.In, comparison.NotIn: 333 | fmtStr := "%q IN (%s)" 334 | if pred.Op == comparison.NotIn { 335 | fmtStr = "%q NOT IN (%s)" 336 | } 337 | 338 | phs := []string{} 339 | for range args { 340 | phs = append(phs, "?") 341 | } 342 | 343 | sb = sb.Where(fmt.Sprintf(fmtStr, fieldX, strings.Join(phs, ",")), args...) 344 | } 345 | } 346 | 347 | return sb 348 | } 349 | 350 | func (repo *SQLiteRepository) buildSort(qb squirrel.SelectBuilder, sorts []*sort.Sort) squirrel.SelectBuilder { 351 | for _, s := range sorts { 352 | field := fmt.Sprintf("%q", s.Field) 353 | switch s.Direction { 354 | case sort.Asc: 355 | qb = qb.OrderBy(field + " ASC") 356 | case sort.Desc: 357 | qb = qb.OrderBy(field + " DESC") 358 | } 359 | } 360 | 361 | return qb 362 | } 363 | 364 | // Update updates a Player or many Players 365 | func (repo *SQLiteRepository) Update(ctx context.Context, u *Updater) (int64, error) { 366 | return repo.update(ctx, repo.db, u) 367 | } 368 | 369 | // UpdateInTx updates a Player many Players in a transaction 370 | func (repo *SQLiteRepository) UpdateInTx(ctx context.Context, tx nero.Tx, u *Updater) (int64, error) { 371 | txx, ok := tx.(*sql.Tx) 372 | if !ok { 373 | return 0, errors.New("expecting tx to be *sql.Tx") 374 | } 375 | 376 | return repo.update(ctx, txx, u) 377 | } 378 | 379 | func (repo *SQLiteRepository) update(ctx context.Context, runner nero.SQLRunner, u *Updater) (int64, error) { 380 | qb := squirrel.Update("\"players\"") 381 | 382 | cnt := 0 383 | 384 | if 
!isZero(u.email) { 385 | qb = qb.Set("\"email\"", u.email) 386 | cnt++ 387 | } 388 | 389 | if !isZero(u.name) { 390 | qb = qb.Set("\"name\"", u.name) 391 | cnt++ 392 | } 393 | 394 | if !isZero(u.age) { 395 | qb = qb.Set("\"age\"", u.age) 396 | cnt++ 397 | } 398 | 399 | if !isZero(u.race) { 400 | qb = qb.Set("\"race\"", u.race) 401 | cnt++ 402 | } 403 | 404 | if !isZero(u.updatedAt) { 405 | qb = qb.Set("\"updated_at\"", u.updatedAt) 406 | cnt++ 407 | } 408 | 409 | if cnt == 0 { 410 | return 0, nil 411 | } 412 | 413 | preds := []*comparison.Predicate{} 414 | for _, predFunc := range u.predFuncs { 415 | preds = predFunc(preds) 416 | } 417 | qb = squirrel.UpdateBuilder(repo.buildPreds(squirrel.StatementBuilderType(qb), preds)) 418 | 419 | if repo.debug && repo.logger != nil { 420 | sql, args, err := qb.ToSql() 421 | repo.logger.Printf("method: Update, stmt: %q, args: %v, error: %v", sql, args, err) 422 | } 423 | 424 | res, err := qb.RunWith(runner).ExecContext(ctx) 425 | if err != nil { 426 | return 0, err 427 | } 428 | 429 | rowsAffected, err := res.RowsAffected() 430 | if err != nil { 431 | return 0, err 432 | } 433 | 434 | return rowsAffected, nil 435 | } 436 | 437 | // Delete deletes a Player or many Players 438 | func (repo *SQLiteRepository) Delete(ctx context.Context, d *Deleter) (int64, error) { 439 | return repo.delete(ctx, repo.db, d) 440 | } 441 | 442 | // DeleteInTx deletes a Player or many Players in a transaction 443 | func (repo *SQLiteRepository) DeleteInTx(ctx context.Context, tx nero.Tx, d *Deleter) (int64, error) { 444 | txx, ok := tx.(*sql.Tx) 445 | if !ok { 446 | return 0, errors.New("expecting tx to be *sql.Tx") 447 | } 448 | 449 | return repo.delete(ctx, txx, d) 450 | } 451 | 452 | func (repo *SQLiteRepository) delete(ctx context.Context, runner nero.SQLRunner, d *Deleter) (int64, error) { 453 | qb := squirrel.Delete("\"players\"") 454 | 455 | preds := []*comparison.Predicate{} 456 | for _, predFunc := range d.predFuncs { 457 | preds = 
predFunc(preds) 458 | } 459 | qb = squirrel.DeleteBuilder(repo.buildPreds(squirrel.StatementBuilderType(qb), preds)) 460 | 461 | if repo.debug && repo.logger != nil { 462 | sql, args, err := qb.ToSql() 463 | repo.logger.Printf("method: Delete, stmt: %q, args: %v, error: %v", sql, args, err) 464 | } 465 | 466 | res, err := qb.RunWith(runner).ExecContext(ctx) 467 | if err != nil { 468 | return 0, err 469 | } 470 | 471 | rowsAffected, err := res.RowsAffected() 472 | if err != nil { 473 | return 0, err 474 | } 475 | 476 | return rowsAffected, nil 477 | } 478 | 479 | // Aggregate runs an aggregate query 480 | func (repo *SQLiteRepository) Aggregate(ctx context.Context, a *Aggregator) error { 481 | return repo.aggregate(ctx, repo.db, a) 482 | } 483 | 484 | // AggregateInTx runs an aggregate query in a transaction 485 | func (repo *SQLiteRepository) AggregateInTx(ctx context.Context, tx nero.Tx, a *Aggregator) error { 486 | txx, ok := tx.(*sql.Tx) 487 | if !ok { 488 | return errors.New("expecting tx to be *sql.Tx") 489 | } 490 | 491 | return repo.aggregate(ctx, txx, a) 492 | } 493 | 494 | func (repo *SQLiteRepository) aggregate(ctx context.Context, runner nero.SQLRunner, a *Aggregator) error { 495 | aggs := []*aggregate.Aggregate{} 496 | for _, aggFunc := range a.aggFuncs { 497 | aggs = aggFunc(aggs) 498 | } 499 | columns := []string{} 500 | for _, agg := range aggs { 501 | field := agg.Field 502 | qf := fmt.Sprintf("%q", field) 503 | switch agg.Op { 504 | case aggregate.Avg: 505 | columns = append(columns, "AVG("+qf+") avg_"+field) 506 | case aggregate.Count: 507 | columns = append(columns, "COUNT("+qf+") count_"+field) 508 | case aggregate.Max: 509 | columns = append(columns, "MAX("+qf+") max_"+field) 510 | case aggregate.Min: 511 | columns = append(columns, "MIN("+qf+") min_"+field) 512 | case aggregate.Sum: 513 | columns = append(columns, "SUM("+qf+") sum_"+field) 514 | case aggregate.None: 515 | columns = append(columns, qf) 516 | } 517 | } 518 | 519 | qb := 
squirrel.Select(columns...).From("\"players\"") 520 | 521 | groupBys := []string{} 522 | for _, groupBy := range a.groupBys { 523 | groupBys = append(groupBys, fmt.Sprintf("%q", groupBy.String())) 524 | } 525 | qb = qb.GroupBy(groupBys...) 526 | 527 | preds := []*comparison.Predicate{} 528 | for _, predFunc := range a.predFuncs { 529 | preds = predFunc(preds) 530 | } 531 | qb = squirrel.SelectBuilder(repo.buildPreds(squirrel.StatementBuilderType(qb), preds)) 532 | 533 | sorts := []*sort.Sort{} 534 | for _, sortFunc := range a.sortFuncs { 535 | sorts = sortFunc(sorts) 536 | } 537 | qb = repo.buildSort(qb, sorts) 538 | 539 | if repo.debug && repo.logger != nil { 540 | sql, args, err := qb.ToSql() 541 | repo.logger.Printf("method: Aggregate, stmt: %q, args: %v, error: %v", sql, args, err) 542 | } 543 | 544 | rows, err := qb.RunWith(runner).QueryContext(ctx) 545 | if err != nil { 546 | return err 547 | } 548 | defer rows.Close() 549 | 550 | v := reflect.ValueOf(a.v).Elem() 551 | t := reflect.TypeOf(v.Interface()).Elem() 552 | if len(columns) != t.NumField() { 553 | return errors.Errorf("column count (%v) and destination struct field count (%v) doesn't match", len(columns), t.NumField()) 554 | } 555 | 556 | for rows.Next() { 557 | ve := reflect.New(t).Elem() 558 | dest := make([]interface{}, ve.NumField()) 559 | for i := 0; i < ve.NumField(); i++ { 560 | dest[i] = ve.Field(i).Addr().Interface() 561 | } 562 | 563 | err = rows.Scan(dest...) 564 | if err != nil { 565 | return err 566 | } 567 | 568 | v.Set(reflect.Append(v, ve)) 569 | } 570 | 571 | return nil 572 | } 573 | -------------------------------------------------------------------------------- /test/integration/playerrepo/postgres.go: -------------------------------------------------------------------------------- 1 | // Code generated by nero, DO NOT EDIT. 
2 | package playerrepo 3 | 4 | import ( 5 | "context" 6 | "database/sql" 7 | "fmt" 8 | "log" 9 | "os" 10 | "reflect" 11 | "strings" 12 | 13 | "github.com/Masterminds/squirrel" 14 | "github.com/pkg/errors" 15 | "github.com/stevenferrer/nero" 16 | "github.com/stevenferrer/nero/aggregate" 17 | "github.com/stevenferrer/nero/comparison" 18 | "github.com/stevenferrer/nero/sort" 19 | "github.com/stevenferrer/nero/test/integration/player" 20 | ) 21 | 22 | // PostgresRepository is a repository that uses PostgreSQL as data store 23 | type PostgresRepository struct { 24 | db *sql.DB 25 | logger nero.Logger 26 | debug bool 27 | } 28 | 29 | var _ Repository = (*PostgresRepository)(nil) 30 | 31 | // NewPostgresRepository returns a PostgresRepository 32 | func NewPostgresRepository(db *sql.DB) *PostgresRepository { 33 | return &PostgresRepository{db: db} 34 | } 35 | 36 | // Debug enables debug mode 37 | func (repo *PostgresRepository) Debug() *PostgresRepository { 38 | l := log.New(os.Stdout, "[nero] ", log.LstdFlags|log.Lmicroseconds|log.Lmsgprefix) 39 | return &PostgresRepository{ 40 | db: repo.db, 41 | debug: true, 42 | logger: l, 43 | } 44 | } 45 | 46 | // WithLogger overrides the default logger 47 | func (repo *PostgresRepository) WithLogger(logger nero.Logger) *PostgresRepository { 48 | repo.logger = logger 49 | return repo 50 | } 51 | 52 | // BeginTx starts a transaction 53 | func (repo *PostgresRepository) BeginTx(ctx context.Context) (nero.Tx, error) { 54 | return repo.db.BeginTx(ctx, nil) 55 | } 56 | 57 | // Create creates a Player 58 | func (repo *PostgresRepository) Create(ctx context.Context, c *Creator) (string, error) { 59 | return repo.create(ctx, repo.db, c) 60 | } 61 | 62 | // CreateInTx creates a Player in a transaction 63 | func (repo *PostgresRepository) CreateInTx(ctx context.Context, tx nero.Tx, c *Creator) (string, error) { 64 | txx, ok := tx.(*sql.Tx) 65 | if !ok { 66 | return "", errors.New("expecting tx to be *sql.Tx") 67 | } 68 | 69 | return 
repo.create(ctx, txx, c) 70 | } 71 | 72 | func (repo *PostgresRepository) create(ctx context.Context, runner nero.SQLRunner, c *Creator) (string, error) { 73 | if err := c.Validate(); err != nil { 74 | return "", err 75 | } 76 | 77 | columns := []string{ 78 | "\"email\"", 79 | "\"name\"", 80 | "\"age\"", 81 | "\"race\"", 82 | } 83 | 84 | values := []interface{}{ 85 | c.email, 86 | c.name, 87 | c.age, 88 | c.race, 89 | } 90 | 91 | if !isZero(c.updatedAt) { 92 | columns = append(columns, "updated_at") 93 | values = append(values, c.updatedAt) 94 | } 95 | 96 | qb := squirrel.Insert("\"players\""). 97 | Columns(columns...). 98 | Values(values...). 99 | Suffix("RETURNING \"id\""). 100 | PlaceholderFormat(squirrel.Dollar). 101 | RunWith(runner) 102 | if repo.debug && repo.logger != nil { 103 | sql, args, err := qb.ToSql() 104 | repo.logger.Printf("method: Create, stmt: %q, args: %v, error: %v", sql, args, err) 105 | } 106 | 107 | var id string 108 | err := qb.QueryRowContext(ctx).Scan(&id) 109 | if err != nil { 110 | return "", err 111 | } 112 | 113 | return id, nil 114 | } 115 | 116 | // CreateMany batch creates Players 117 | func (repo *PostgresRepository) CreateMany(ctx context.Context, cs ...*Creator) error { 118 | return repo.createMany(ctx, repo.db, cs...) 119 | } 120 | 121 | // CreateManyInTx batch creates Players in a transaction 122 | func (repo *PostgresRepository) CreateManyInTx(ctx context.Context, tx nero.Tx, cs ...*Creator) error { 123 | txx, ok := tx.(*sql.Tx) 124 | if !ok { 125 | return errors.New("expecting tx to be *sql.Tx") 126 | } 127 | 128 | return repo.createMany(ctx, txx, cs...) 
129 | } 130 | 131 | func (repo *PostgresRepository) createMany(ctx context.Context, runner nero.SQLRunner, cs ...*Creator) error { 132 | if len(cs) == 0 { 133 | return nil 134 | } 135 | 136 | columns := []string{ 137 | "\"email\"", 138 | "\"name\"", 139 | "\"age\"", 140 | "\"race\"", 141 | "\"updated_at\"", 142 | } 143 | 144 | qb := squirrel.Insert("\"players\"").Columns(columns...) 145 | for _, c := range cs { 146 | if err := c.Validate(); err != nil { 147 | return err 148 | } 149 | 150 | qb = qb.Values( 151 | c.email, 152 | c.name, 153 | c.age, 154 | c.race, 155 | c.updatedAt, 156 | ) 157 | } 158 | 159 | qb = qb.Suffix("RETURNING \"id\""). 160 | PlaceholderFormat(squirrel.Dollar) 161 | if repo.debug && repo.logger != nil { 162 | sql, args, err := qb.ToSql() 163 | repo.logger.Printf("method: CreateMany, stmt: %q, args: %v, error: %v", sql, args, err) 164 | } 165 | 166 | _, err := qb.RunWith(runner).ExecContext(ctx) 167 | if err != nil { 168 | return err 169 | } 170 | 171 | return nil 172 | } 173 | 174 | // Query queries Players 175 | func (repo *PostgresRepository) Query(ctx context.Context, q *Queryer) ([]*player.Player, error) { 176 | return repo.query(ctx, repo.db, q) 177 | } 178 | 179 | // QueryInTx queries Players in a transaction 180 | func (repo *PostgresRepository) QueryInTx(ctx context.Context, tx nero.Tx, q *Queryer) ([]*player.Player, error) { 181 | txx, ok := tx.(*sql.Tx) 182 | if !ok { 183 | return nil, errors.New("expecting tx to be *sql.Tx") 184 | } 185 | 186 | return repo.query(ctx, txx, q) 187 | } 188 | 189 | func (repo *PostgresRepository) query(ctx context.Context, runner nero.SQLRunner, q *Queryer) ([]*player.Player, error) { 190 | qb := repo.buildSelect(q) 191 | if repo.debug && repo.logger != nil { 192 | sql, args, err := qb.ToSql() 193 | repo.logger.Printf("method: Query, stmt: %q, args: %v, error: %v", sql, args, err) 194 | } 195 | 196 | rows, err := qb.RunWith(runner).QueryContext(ctx) 197 | if err != nil { 198 | return nil, err 199 | } 
200 | defer rows.Close() 201 | 202 | players := []*player.Player{} 203 | for rows.Next() { 204 | var player player.Player 205 | err = rows.Scan( 206 | &player.ID, 207 | &player.Email, 208 | &player.Name, 209 | &player.Age, 210 | &player.Race, 211 | &player.UpdatedAt, 212 | &player.CreatedAt, 213 | ) 214 | if err != nil { 215 | return nil, err 216 | } 217 | 218 | players = append(players, &player) 219 | } 220 | 221 | return players, nil 222 | } 223 | 224 | // QueryOne queries a Player 225 | func (repo *PostgresRepository) QueryOne(ctx context.Context, q *Queryer) (*player.Player, error) { 226 | return repo.queryOne(ctx, repo.db, q) 227 | } 228 | 229 | // QueryOneInTx queries a Player in a transaction 230 | func (repo *PostgresRepository) QueryOneInTx(ctx context.Context, tx nero.Tx, q *Queryer) (*player.Player, error) { 231 | txx, ok := tx.(*sql.Tx) 232 | if !ok { 233 | return nil, errors.New("expecting tx to be *sql.Tx") 234 | } 235 | 236 | return repo.queryOne(ctx, txx, q) 237 | } 238 | 239 | func (repo *PostgresRepository) queryOne(ctx context.Context, runner nero.SQLRunner, q *Queryer) (*player.Player, error) { 240 | qb := repo.buildSelect(q) 241 | if repo.debug && repo.logger != nil { 242 | sql, args, err := qb.ToSql() 243 | repo.logger.Printf("method: QueryOne, stmt: %q, args: %v, error: %v", sql, args, err) 244 | } 245 | 246 | var player player.Player 247 | err := qb.RunWith(runner). 248 | QueryRowContext(ctx). 
249 | Scan( 250 | &player.ID, 251 | &player.Email, 252 | &player.Name, 253 | &player.Age, 254 | &player.Race, 255 | &player.UpdatedAt, 256 | &player.CreatedAt, 257 | ) 258 | if err != nil { 259 | return nil, err 260 | } 261 | 262 | return &player, nil 263 | } 264 | 265 | func (repo *PostgresRepository) buildSelect(q *Queryer) squirrel.SelectBuilder { 266 | columns := []string{ 267 | "\"id\"", 268 | "\"email\"", 269 | "\"name\"", 270 | "\"age\"", 271 | "\"race\"", 272 | "\"updated_at\"", 273 | "\"created_at\"", 274 | } 275 | qb := squirrel.Select(columns...). 276 | From("\"players\""). 277 | PlaceholderFormat(squirrel.Dollar) 278 | 279 | preds := []*comparison.Predicate{} 280 | for _, predFunc := range q.predFuncs { 281 | preds = predFunc(preds) 282 | } 283 | qb = squirrel.SelectBuilder(repo.buildPreds(squirrel.StatementBuilderType(qb), preds)) 284 | 285 | sorts := []*sort.Sort{} 286 | for _, sortFunc := range q.sortFuncs { 287 | sorts = sortFunc(sorts) 288 | } 289 | qb = repo.buildSort(qb, sorts) 290 | 291 | if q.limit > 0 { 292 | qb = qb.Limit(uint64(q.limit)) 293 | } 294 | 295 | if q.offset > 0 { 296 | qb = qb.Offset(uint64(q.offset)) 297 | } 298 | 299 | return qb 300 | } 301 | 302 | func (repo *PostgresRepository) buildPreds(sb squirrel.StatementBuilderType, preds []*comparison.Predicate) squirrel.StatementBuilderType { 303 | for _, pred := range preds { 304 | ph := "?" 305 | fieldX, arg := pred.Field, pred.Arg 306 | 307 | args := []interface{}{} 308 | if fieldY, ok := arg.(Field); ok { // a field 309 | ph = fmt.Sprintf("%q", fieldY) 310 | } else if vals, ok := arg.([]interface{}); ok { // array of values 311 | args = append(args, vals...) 312 | } else { // single value 313 | args = append(args, arg) 314 | } 315 | 316 | switch pred.Op { 317 | case comparison.Eq: 318 | sb = sb.Where(fmt.Sprintf("%q = "+ph, fieldX), args...) 319 | case comparison.NotEq: 320 | sb = sb.Where(fmt.Sprintf("%q <> "+ph, fieldX), args...) 
321 | case comparison.Gt: 322 | sb = sb.Where(fmt.Sprintf("%q > "+ph, fieldX), args...) 323 | case comparison.GtOrEq: 324 | sb = sb.Where(fmt.Sprintf("%q >= "+ph, fieldX), args...) 325 | case comparison.Lt: 326 | sb = sb.Where(fmt.Sprintf("%q < "+ph, fieldX), args...) 327 | case comparison.LtOrEq: 328 | sb = sb.Where(fmt.Sprintf("%q <= "+ph, fieldX), args...) 329 | case comparison.IsNull, comparison.IsNotNull: 330 | fmtStr := "%q IS NULL" 331 | if pred.Op == comparison.IsNotNull { 332 | fmtStr = "%q IS NOT NULL" 333 | } 334 | sb = sb.Where(fmt.Sprintf(fmtStr, fieldX)) 335 | case comparison.In, comparison.NotIn: 336 | fmtStr := "%q IN (%s)" 337 | if pred.Op == comparison.NotIn { 338 | fmtStr = "%q NOT IN (%s)" 339 | } 340 | 341 | phs := []string{} 342 | for range args { 343 | phs = append(phs, "?") 344 | } 345 | 346 | sb = sb.Where(fmt.Sprintf(fmtStr, fieldX, strings.Join(phs, ",")), args...) 347 | } 348 | } 349 | 350 | return sb 351 | } 352 | 353 | func (repo *PostgresRepository) buildSort(qb squirrel.SelectBuilder, sorts []*sort.Sort) squirrel.SelectBuilder { 354 | for _, s := range sorts { 355 | field := fmt.Sprintf("%q", s.Field) 356 | switch s.Direction { 357 | case sort.Asc: 358 | qb = qb.OrderBy(field + " ASC") 359 | case sort.Desc: 360 | qb = qb.OrderBy(field + " DESC") 361 | } 362 | } 363 | 364 | return qb 365 | } 366 | 367 | // Update updates a Player or many Players 368 | func (repo *PostgresRepository) Update(ctx context.Context, u *Updater) (int64, error) { 369 | return repo.update(ctx, repo.db, u) 370 | } 371 | 372 | // UpdateInTx updates a Player many Players in a transaction 373 | func (repo *PostgresRepository) UpdateInTx(ctx context.Context, tx nero.Tx, u *Updater) (int64, error) { 374 | txx, ok := tx.(*sql.Tx) 375 | if !ok { 376 | return 0, errors.New("expecting tx to be *sql.Tx") 377 | } 378 | 379 | return repo.update(ctx, txx, u) 380 | } 381 | 382 | func (repo *PostgresRepository) update(ctx context.Context, runner nero.SQLRunner, u *Updater) 
(int64, error) { 383 | qb := squirrel.Update("\"players\""). 384 | PlaceholderFormat(squirrel.Dollar) 385 | 386 | cnt := 0 387 | 388 | if !isZero(u.email) { 389 | qb = qb.Set("\"email\"", u.email) 390 | cnt++ 391 | } 392 | 393 | if !isZero(u.name) { 394 | qb = qb.Set("\"name\"", u.name) 395 | cnt++ 396 | } 397 | 398 | if !isZero(u.age) { 399 | qb = qb.Set("\"age\"", u.age) 400 | cnt++ 401 | } 402 | 403 | if !isZero(u.race) { 404 | qb = qb.Set("\"race\"", u.race) 405 | cnt++ 406 | } 407 | 408 | if !isZero(u.updatedAt) { 409 | qb = qb.Set("\"updated_at\"", u.updatedAt) 410 | cnt++ 411 | } 412 | 413 | if cnt == 0 { 414 | return 0, nil 415 | } 416 | 417 | preds := []*comparison.Predicate{} 418 | for _, predFunc := range u.predFuncs { 419 | preds = predFunc(preds) 420 | } 421 | qb = squirrel.UpdateBuilder(repo.buildPreds(squirrel.StatementBuilderType(qb), preds)) 422 | 423 | if repo.debug && repo.logger != nil { 424 | sql, args, err := qb.ToSql() 425 | repo.logger.Printf("method: Update, stmt: %q, args: %v, error: %v", sql, args, err) 426 | } 427 | 428 | res, err := qb.RunWith(runner).ExecContext(ctx) 429 | if err != nil { 430 | return 0, err 431 | } 432 | 433 | rowsAffected, err := res.RowsAffected() 434 | if err != nil { 435 | return 0, err 436 | } 437 | 438 | return rowsAffected, nil 439 | } 440 | 441 | // Delete deletes a Player or many Players 442 | func (repo *PostgresRepository) Delete(ctx context.Context, d *Deleter) (int64, error) { 443 | return repo.delete(ctx, repo.db, d) 444 | } 445 | 446 | // DeleteInTx deletes a Player or many Players in a transaction 447 | func (repo *PostgresRepository) DeleteInTx(ctx context.Context, tx nero.Tx, d *Deleter) (int64, error) { 448 | txx, ok := tx.(*sql.Tx) 449 | if !ok { 450 | return 0, errors.New("expecting tx to be *sql.Tx") 451 | } 452 | 453 | return repo.delete(ctx, txx, d) 454 | } 455 | 456 | func (repo *PostgresRepository) delete(ctx context.Context, runner nero.SQLRunner, d *Deleter) (int64, error) { 457 | qb := 
squirrel.Delete("\"players\""). 458 | PlaceholderFormat(squirrel.Dollar) 459 | 460 | preds := []*comparison.Predicate{} 461 | for _, predFunc := range d.predFuncs { 462 | preds = predFunc(preds) 463 | } 464 | qb = squirrel.DeleteBuilder(repo.buildPreds(squirrel.StatementBuilderType(qb), preds)) 465 | 466 | if repo.debug && repo.logger != nil { 467 | sql, args, err := qb.ToSql() 468 | repo.logger.Printf("method: Delete, stmt: %q, args: %v, error: %v", sql, args, err) 469 | } 470 | 471 | res, err := qb.RunWith(runner).ExecContext(ctx) 472 | if err != nil { 473 | return 0, err 474 | } 475 | 476 | rowsAffected, err := res.RowsAffected() 477 | if err != nil { 478 | return 0, err 479 | } 480 | 481 | return rowsAffected, nil 482 | } 483 | 484 | // Aggregate performs an aggregate query 485 | func (repo *PostgresRepository) Aggregate(ctx context.Context, a *Aggregator) error { 486 | return repo.aggregate(ctx, repo.db, a) 487 | } 488 | 489 | // AggregateInTx performs an aggregate query in a transaction 490 | func (repo *PostgresRepository) AggregateInTx(ctx context.Context, tx nero.Tx, a *Aggregator) error { 491 | txx, ok := tx.(*sql.Tx) 492 | if !ok { 493 | return errors.New("expecting tx to be *sql.Tx") 494 | } 495 | 496 | return repo.aggregate(ctx, txx, a) 497 | } 498 | 499 | func (repo *PostgresRepository) aggregate(ctx context.Context, runner nero.SQLRunner, a *Aggregator) error { 500 | aggs := []*aggregate.Aggregate{} 501 | for _, aggFunc := range a.aggFuncs { 502 | aggs = aggFunc(aggs) 503 | } 504 | columns := []string{} 505 | for _, agg := range aggs { 506 | field := agg.Field 507 | qf := fmt.Sprintf("%q", field) 508 | switch agg.Op { 509 | case aggregate.Avg: 510 | columns = append(columns, "AVG("+qf+") avg_"+field) 511 | case aggregate.Count: 512 | columns = append(columns, "COUNT("+qf+") count_"+field) 513 | case aggregate.Max: 514 | columns = append(columns, "MAX("+qf+") max_"+field) 515 | case aggregate.Min: 516 | columns = append(columns, "MIN("+qf+") 
min_"+field) 517 | case aggregate.Sum: 518 | columns = append(columns, "SUM("+qf+") sum_"+field) 519 | case aggregate.None: 520 | columns = append(columns, qf) 521 | } 522 | } 523 | 524 | qb := squirrel.Select(columns...).From("\"players\""). 525 | PlaceholderFormat(squirrel.Dollar) 526 | 527 | groupBys := []string{} 528 | for _, groupBy := range a.groupBys { 529 | groupBys = append(groupBys, fmt.Sprintf("%q", groupBy.String())) 530 | } 531 | qb = qb.GroupBy(groupBys...) 532 | 533 | preds := []*comparison.Predicate{} 534 | for _, predFunc := range a.predFuncs { 535 | preds = predFunc(preds) 536 | } 537 | qb = squirrel.SelectBuilder(repo.buildPreds(squirrel.StatementBuilderType(qb), preds)) 538 | 539 | sorts := []*sort.Sort{} 540 | for _, sortFunc := range a.sortFuncs { 541 | sorts = sortFunc(sorts) 542 | } 543 | qb = repo.buildSort(qb, sorts) 544 | 545 | if repo.debug && repo.logger != nil { 546 | sql, args, err := qb.ToSql() 547 | repo.logger.Printf("method: Aggregate, stmt: %q, args: %v, error: %v", sql, args, err) 548 | } 549 | 550 | rows, err := qb.RunWith(runner).QueryContext(ctx) 551 | if err != nil { 552 | return err 553 | } 554 | defer rows.Close() 555 | 556 | v := reflect.ValueOf(a.v).Elem() 557 | t := reflect.TypeOf(v.Interface()).Elem() 558 | if len(columns) != t.NumField() { 559 | return errors.Errorf("column count (%v) and destination struct field count (%v) doesn't match", len(columns), t.NumField()) 560 | } 561 | 562 | for rows.Next() { 563 | ve := reflect.New(t).Elem() 564 | dest := make([]interface{}, ve.NumField()) 565 | for i := 0; i < ve.NumField(); i++ { 566 | dest[i] = ve.Field(i).Addr().Interface() 567 | } 568 | 569 | err = rows.Scan(dest...) 
570 | if err != nil { 571 | return err 572 | } 573 | 574 | v.Set(reflect.Append(v, ve)) 575 | } 576 | 577 | return nil 578 | } 579 | -------------------------------------------------------------------------------- /sqlite_template.go: -------------------------------------------------------------------------------- 1 | package nero 2 | 3 | // SQLiteTemplate is a template for generating an sqlite repository 4 | type SQLiteTemplate struct { 5 | filename string 6 | } 7 | 8 | var _ Template = (*SQLiteTemplate)(nil) 9 | 10 | // NewSQLiteTemplate returns a new SQLiteTemplate 11 | func NewSQLiteTemplate() *SQLiteTemplate { 12 | return &SQLiteTemplate{filename: "sqlite.go"} 13 | } 14 | 15 | // WithFilename overrides the default filename 16 | func (t *SQLiteTemplate) WithFilename(filename string) *SQLiteTemplate { 17 | t.filename = filename 18 | return t 19 | } 20 | 21 | // Filename returns the filename 22 | func (t *SQLiteTemplate) Filename() string { 23 | return t.filename 24 | } 25 | 26 | // Content returns the template content 27 | func (t *SQLiteTemplate) Content() string { 28 | return sqliteTmpl 29 | } 30 | 31 | const sqliteTmpl = ` 32 | {{- fileHeaders -}} 33 | 34 | package {{.PkgName}} 35 | 36 | import ( 37 | "context" 38 | "database/sql" 39 | "fmt" 40 | "reflect" 41 | "io" 42 | "strings" 43 | "log" 44 | "os" 45 | "github.com/Masterminds/squirrel" 46 | _ "github.com/mattn/go-sqlite3" 47 | "github.com/pkg/errors" 48 | "github.com/stevenferrer/nero" 49 | "github.com/stevenferrer/nero/aggregate" 50 | "github.com/stevenferrer/nero/comparison" 51 | "github.com/stevenferrer/nero/sort" 52 | {{range $import := .Imports -}} 53 | "{{$import}}" 54 | {{end -}} 55 | ) 56 | 57 | {{ $fields := prependToFields .Identity .Fields }} 58 | 59 | // SQLiteRepository is a repository that uses SQLite3 as data store 60 | type SQLiteRepository struct { 61 | db *sql.DB 62 | logger nero.Logger 63 | debug bool 64 | } 65 | 66 | var _ Repository = (*SQLiteRepository)(nil) 67 | 68 | // 
NewSQLiteRepository returns a new SQLiteRepository 69 | func NewSQLiteRepository(db *sql.DB) *SQLiteRepository { 70 | return &SQLiteRepository{db: db} 71 | } 72 | 73 | // Debug enables debug mode 74 | func (repo *SQLiteRepository) Debug() *SQLiteRepository { 75 | l := log.New(os.Stdout, "[nero] ", log.LstdFlags | log.Lmicroseconds | log.Lmsgprefix) 76 | return &SQLiteRepository{ 77 | db: repo.db, 78 | debug: true, 79 | logger: l, 80 | } 81 | } 82 | 83 | // WithLogger overrides the default logger 84 | func (repo *SQLiteRepository) WithLogger(logger nero.Logger) *SQLiteRepository { 85 | repo.logger = logger 86 | return repo 87 | } 88 | 89 | // BeginTx starts a transaction 90 | func (repo *SQLiteRepository) BeginTx(ctx context.Context) (nero.Tx, error) { 91 | return repo.db.BeginTx(ctx, nil) 92 | } 93 | 94 | // Create creates a {{.TypeName}} 95 | func (repo *SQLiteRepository) Create(ctx context.Context, c *Creator) ({{rawType .Identity.TypeInfo.V}}, error) { 96 | return repo.create(ctx, repo.db, c) 97 | } 98 | 99 | // CreateInTx creates a {{.TypeName}} in a transaction 100 | func (repo *SQLiteRepository) CreateInTx(ctx context.Context, tx nero.Tx, c *Creator) ({{rawType .Identity.TypeInfo.V}}, error) { 101 | txx, ok := tx.(*sql.Tx) 102 | if !ok { 103 | return {{zeroValue .Identity.TypeInfo.V}}, errors.New("expecting tx to be *sql.Tx") 104 | } 105 | 106 | return repo.create(ctx, txx, c) 107 | } 108 | 109 | func (repo *SQLiteRepository) create(ctx context.Context, runner nero.SQLRunner, c *Creator) ({{rawType .Identity.TypeInfo.V}}, error) { 110 | if err := c.Validate(); err != nil { 111 | return {{zeroValue .Identity.TypeInfo.V}}, err 112 | } 113 | 114 | columns := []string{ 115 | {{range $field := $fields -}} 116 | {{if and (ne $field.IsOptional true) (ne $field.IsAuto true) -}} 117 | "\"{{$field.Name}}\"", 118 | {{end -}} 119 | {{end -}} 120 | } 121 | 122 | values := []interface{}{ 123 | {{range $field := $fields -}} 124 | {{if and (ne $field.IsOptional true) (ne 
$field.IsAuto true) -}} 125 | c.{{$field.Identifier}}, 126 | {{end -}} 127 | {{end -}} 128 | } 129 | 130 | {{range $field := $fields -}} 131 | {{if and ($field.IsOptional) (ne $field.IsAuto true) -}} 132 | if !isZero(c.{{$field.Identifier}}) { 133 | columns = append(columns, "{{$field.Name}}") 134 | values = append(values, c.{{$field.Identifier}}) 135 | } 136 | {{end -}} 137 | {{end}} 138 | 139 | qb := squirrel.Insert("\"{{.Table}}\"").Columns(columns...). 140 | Values(values...).RunWith(runner) 141 | if repo.debug && repo.logger != nil { 142 | sql, args, err := qb.ToSql() 143 | repo.logger.Printf("method: Create, stmt: %q, args: %v, error: %v", sql, args, err) 144 | } 145 | 146 | _, err := qb.ExecContext(ctx) 147 | if err != nil { 148 | return {{zeroValue .Identity.TypeInfo.V}}, err 149 | } 150 | 151 | var {{.Identity.Identifier}} {{rawType .Identity.TypeInfo.V}} 152 | err = repo.db.QueryRowContext(ctx, "select last_insert_rowid()").Scan(&{{.Identity.Identifier}}) 153 | if err != nil { 154 | return {{zeroValue .Identity.TypeInfo.V}}, err 155 | } 156 | 157 | return {{.Identity.Identifier}}, nil 158 | } 159 | 160 | // CreateMany batch creates {{.TypeNamePlural}} 161 | func (repo *SQLiteRepository) CreateMany(ctx context.Context, cs ...*Creator) error { 162 | return repo.createMany(ctx, repo.db, cs...) 163 | } 164 | 165 | // CreateManyInTx batch creates {{.TypeNamePlural}} in a transaction 166 | func (repo *SQLiteRepository) CreateManyInTx(ctx context.Context, tx nero.Tx, cs ...*Creator) error { 167 | txx, ok := tx.(*sql.Tx) 168 | if !ok { 169 | return errors.New("expecting tx to be *sql.Tx") 170 | } 171 | 172 | return repo.createMany(ctx, txx, cs...) 
173 | } 174 | 175 | func (repo *SQLiteRepository) createMany(ctx context.Context, runner nero.SQLRunner, cs ...*Creator) error { 176 | if len(cs) == 0 { 177 | return nil 178 | } 179 | 180 | columns := []string{ 181 | {{range $field := $fields -}} 182 | {{if ne $field.IsAuto true -}} 183 | "\"{{$field.Name}}\"", 184 | {{end -}} 185 | {{end -}} 186 | } 187 | qb := squirrel.Insert("\"{{.Table}}\"").Columns(columns...) 188 | for _, c := range cs { 189 | if err := c.Validate(); err != nil { 190 | return err 191 | } 192 | 193 | qb = qb.Values( 194 | {{range $field := $fields -}} 195 | {{if ne $field.IsAuto true -}} 196 | c.{{$field.Identifier}}, 197 | {{end -}} 198 | {{end -}} 199 | ) 200 | } 201 | 202 | if repo.debug && repo.logger != nil { 203 | sql, args, err := qb.ToSql() 204 | repo.logger.Printf("method: CreateMany, stmt: %q, args: %v, error: %v", sql, args, err) 205 | } 206 | 207 | _, err := qb.RunWith(runner).ExecContext(ctx) 208 | if err != nil { 209 | return err 210 | } 211 | 212 | return nil 213 | } 214 | 215 | // Query queries {{.TypeNamePlural}} 216 | func (repo *SQLiteRepository) Query(ctx context.Context, q *Queryer) ([]{{rawType .TypeInfo.V}}, error) { 217 | return repo.query(ctx, repo.db, q) 218 | } 219 | 220 | // QueryInTx queries {{.TypeNamePlural}} in a transaction 221 | func (repo *SQLiteRepository) QueryInTx(ctx context.Context, tx nero.Tx, q *Queryer) ([]{{rawType .TypeInfo.V}}, error) { 222 | txx, ok := tx.(*sql.Tx) 223 | if !ok { 224 | return nil, errors.New("expecting tx to be *sql.Tx") 225 | } 226 | 227 | return repo.query(ctx, txx, q) 228 | } 229 | 230 | func (repo *SQLiteRepository) query(ctx context.Context, runner nero.SQLRunner, q *Queryer) ([]{{rawType .TypeInfo.V}}, error) { 231 | qb := repo.buildSelect(q) 232 | if repo.debug && repo.logger != nil { 233 | sql, args, err := qb.ToSql() 234 | repo.logger.Printf("method: Query, stmt: %q, args: %v, error: %v", sql, args, err) 235 | } 236 | 237 | rows, err := 
qb.RunWith(runner).QueryContext(ctx) 238 | if err != nil { 239 | return nil, err 240 | } 241 | defer rows.Close() 242 | 243 | {{.TypeIdentifierPlural}} := []{{rawType .TypeInfo.V}}{} 244 | for rows.Next() { 245 | var {{.TypeIdentifier}} {{type .TypeInfo.V}} 246 | err = rows.Scan( 247 | {{range $field := $fields -}} 248 | &{{$.TypeIdentifier}}.{{$field.StructField}}, 249 | {{end -}} 250 | ) 251 | if err != nil { 252 | return nil, err 253 | } 254 | 255 | {{.TypeIdentifierPlural}} = append({{.TypeIdentifierPlural}}, &{{.TypeIdentifier}}) 256 | } 257 | 258 | return {{.TypeIdentifierPlural}}, nil 259 | } 260 | 261 | // QueryOne queries a {{.TypeName}} 262 | func (repo *SQLiteRepository) QueryOne(ctx context.Context, q *Queryer) ({{rawType .TypeInfo.V}}, error) { 263 | return repo.queryOne(ctx, repo.db, q) 264 | } 265 | 266 | // QueryOneInTx queries a {{.TypeName}} in a transaction 267 | func (repo *SQLiteRepository) QueryOneInTx(ctx context.Context, tx nero.Tx, q *Queryer) ({{rawType .TypeInfo.V}}, error) { 268 | txx, ok := tx.(*sql.Tx) 269 | if !ok { 270 | return nil, errors.New("expecting tx to be *sql.Tx") 271 | } 272 | 273 | return repo.queryOne(ctx, txx, q) 274 | } 275 | 276 | func (repo *SQLiteRepository) queryOne(ctx context.Context, runner nero.SQLRunner, q *Queryer) ({{rawType .TypeInfo.V}}, error) { 277 | qb := repo.buildSelect(q) 278 | if repo.debug && repo.logger != nil { 279 | sql, args, err := qb.ToSql() 280 | repo.logger.Printf("method: QueryOne, stmt: %q, args: %v, error: %v", sql, args, err) 281 | } 282 | 283 | var {{.TypeIdentifier}} {{type .TypeInfo.V}} 284 | err := qb.RunWith(runner). 285 | QueryRowContext(ctx). 
286 | Scan( 287 | {{range $field := $fields -}} 288 | &{{$.TypeIdentifier}}.{{$field.StructField}}, 289 | {{end -}} 290 | ) 291 | if err != nil { 292 | return {{zeroValue .TypeInfo.V}}, err 293 | } 294 | 295 | return &{{.TypeIdentifier}}, nil 296 | } 297 | 298 | func (repo *SQLiteRepository) buildSelect(q *Queryer) squirrel.SelectBuilder { 299 | columns := []string{ 300 | {{range $field := $fields -}} 301 | "\"{{$field.Name}}\"", 302 | {{end -}} 303 | } 304 | qb := squirrel.Select(columns...).From("\"{{.Table}}\"") 305 | 306 | preds := []*comparison.Predicate{} 307 | for _, predFunc := range q.predFuncs { 308 | preds = predFunc(preds) 309 | } 310 | qb = squirrel.SelectBuilder(repo.buildPreds(squirrel.StatementBuilderType(qb), preds)) 311 | 312 | sorts := []*sort.Sort{} 313 | for _, sortFunc := range q.sortFuncs { 314 | sorts = sortFunc(sorts) 315 | } 316 | qb = repo.buildSort(qb, sorts) 317 | 318 | if q.limit > 0 { 319 | qb = qb.Limit(uint64(q.limit)) 320 | } 321 | 322 | if q.offset > 0 { 323 | qb = qb.Offset(uint64(q.offset)) 324 | } 325 | 326 | return qb 327 | } 328 | 329 | func (repo *SQLiteRepository) buildPreds(sb squirrel.StatementBuilderType, preds []*comparison.Predicate) squirrel.StatementBuilderType { 330 | for _, pred := range preds { 331 | ph := "?" 332 | fieldX, arg := pred.Field, pred.Arg 333 | 334 | args := []interface{}{} 335 | if fieldY, ok := arg.(Field); ok { // a field 336 | ph = fmt.Sprintf("%q", fieldY) 337 | } else if vals, ok := arg.([]interface{}); ok { // array of values 338 | args = append(args, vals...) 339 | } else { // single value 340 | args = append(args, arg) 341 | } 342 | 343 | switch pred.Op { 344 | case comparison.Eq: 345 | sb = sb.Where(fmt.Sprintf("%q = "+ph, fieldX), args...) 346 | case comparison.NotEq: 347 | sb = sb.Where(fmt.Sprintf("%q <> "+ph, fieldX), args...) 348 | case comparison.Gt: 349 | sb = sb.Where(fmt.Sprintf("%q > "+ph, fieldX), args...) 
350 | case comparison.GtOrEq: 351 | sb = sb.Where(fmt.Sprintf("%q >= "+ph, fieldX), args...) 352 | case comparison.Lt: 353 | sb = sb.Where(fmt.Sprintf("%q < "+ph, fieldX), args...) 354 | case comparison.LtOrEq: 355 | sb = sb.Where(fmt.Sprintf("%q <= "+ph, fieldX), args...) 356 | case comparison.IsNull, comparison.IsNotNull: 357 | fmtStr := "%q IS NULL" 358 | if pred.Op == comparison.IsNotNull { 359 | fmtStr = "%q IS NOT NULL" 360 | } 361 | sb = sb.Where(fmt.Sprintf(fmtStr, fieldX)) 362 | case comparison.In, comparison.NotIn: 363 | fmtStr := "%q IN (%s)" 364 | if pred.Op == comparison.NotIn { 365 | fmtStr = "%q NOT IN (%s)" 366 | } 367 | 368 | phs := []string{} 369 | for range args { 370 | phs = append(phs, "?") 371 | } 372 | 373 | sb = sb.Where(fmt.Sprintf(fmtStr, fieldX, strings.Join(phs, ",")), args...) 374 | } 375 | } 376 | 377 | return sb 378 | } 379 | 380 | func (repo *SQLiteRepository) buildSort(qb squirrel.SelectBuilder, sorts []*sort.Sort) squirrel.SelectBuilder { 381 | for _, s := range sorts { 382 | field := fmt.Sprintf("%q", s.Field) 383 | switch s.Direction { 384 | case sort.Asc: 385 | qb = qb.OrderBy(field + " ASC") 386 | case sort.Desc: 387 | qb = qb.OrderBy(field + " DESC") 388 | } 389 | } 390 | 391 | return qb 392 | } 393 | 394 | // Update updates a {{.TypeName}} or many {{.TypeNamePlural}} 395 | func (repo *SQLiteRepository) Update(ctx context.Context, u *Updater) (int64, error) { 396 | return repo.update(ctx, repo.db, u) 397 | } 398 | 399 | // UpdateInTx updates a {{.TypeName}} many {{.TypeNamePlural}} in a transaction 400 | func (repo *SQLiteRepository) UpdateInTx(ctx context.Context, tx nero.Tx, u *Updater) (int64, error) { 401 | txx, ok := tx.(*sql.Tx) 402 | if !ok { 403 | return 0, errors.New("expecting tx to be *sql.Tx") 404 | } 405 | 406 | return repo.update(ctx, txx, u) 407 | } 408 | 409 | func (repo *SQLiteRepository) update(ctx context.Context, runner nero.SQLRunner, u *Updater) (int64, error) { 410 | qb := 
squirrel.Update("\"{{.Table}}\"") 411 | 412 | cnt := 0 413 | {{range $field := .Fields }} 414 | {{if ne $field.IsAuto true}} 415 | if !isZero(u.{{$field.Identifier}}) { 416 | qb = qb.Set("\"{{$field.Name}}\"", u.{{$field.Identifier}}) 417 | cnt++ 418 | } 419 | {{end}} 420 | {{end}} 421 | 422 | if cnt == 0 { 423 | return 0, nil 424 | } 425 | 426 | preds := []*comparison.Predicate{} 427 | for _, predFunc := range u.predFuncs { 428 | preds = predFunc(preds) 429 | } 430 | qb = squirrel.UpdateBuilder(repo.buildPreds(squirrel.StatementBuilderType(qb), preds)) 431 | 432 | if repo.debug && repo.logger != nil { 433 | sql, args, err := qb.ToSql() 434 | repo.logger.Printf("method: Update, stmt: %q, args: %v, error: %v", sql, args, err) 435 | } 436 | 437 | res, err := qb.RunWith(runner).ExecContext(ctx) 438 | if err != nil { 439 | return 0, err 440 | } 441 | 442 | rowsAffected, err := res.RowsAffected() 443 | if err != nil { 444 | return 0, err 445 | } 446 | 447 | return rowsAffected, nil 448 | } 449 | 450 | // Delete deletes a {{.TypeName}} or many {{.TypeNamePlural}} 451 | func (repo *SQLiteRepository) Delete(ctx context.Context, d *Deleter) (int64, error) { 452 | return repo.delete(ctx, repo.db, d) 453 | } 454 | 455 | // DeleteInTx deletes a {{.TypeName}} or many {{.TypeNamePlural}} in a transaction 456 | func (repo *SQLiteRepository) DeleteInTx(ctx context.Context, tx nero.Tx, d *Deleter) (int64, error) { 457 | txx, ok := tx.(*sql.Tx) 458 | if !ok { 459 | return 0, errors.New("expecting tx to be *sql.Tx") 460 | } 461 | 462 | return repo.delete(ctx, txx, d) 463 | } 464 | 465 | func (repo *SQLiteRepository) delete(ctx context.Context, runner nero.SQLRunner, d *Deleter) (int64, error) { 466 | qb := squirrel.Delete("\"{{.Table}}\"") 467 | 468 | preds := []*comparison.Predicate{} 469 | for _, predFunc := range d.predFuncs { 470 | preds = predFunc(preds) 471 | } 472 | qb = squirrel.DeleteBuilder(repo.buildPreds(squirrel.StatementBuilderType(qb), preds)) 473 | 474 | if repo.debug 
&& repo.logger != nil { 475 | sql, args, err := qb.ToSql() 476 | repo.logger.Printf("method: Delete, stmt: %q, args: %v, error: %v", sql, args, err) 477 | } 478 | 479 | res, err := qb.RunWith(runner).ExecContext(ctx) 480 | if err != nil { 481 | return 0, err 482 | } 483 | 484 | rowsAffected, err := res.RowsAffected() 485 | if err != nil { 486 | return 0, err 487 | } 488 | 489 | return rowsAffected, nil 490 | } 491 | 492 | // Aggregate runs an aggregate query 493 | func (repo *SQLiteRepository) Aggregate(ctx context.Context, a *Aggregator) error { 494 | return repo.aggregate(ctx, repo.db, a) 495 | } 496 | 497 | // AggregateInTx runs an aggregate query in a transaction 498 | func (repo *SQLiteRepository) AggregateInTx(ctx context.Context, tx nero.Tx, a *Aggregator) error { 499 | txx, ok := tx.(*sql.Tx) 500 | if !ok { 501 | return errors.New("expecting tx to be *sql.Tx") 502 | } 503 | 504 | return repo.aggregate(ctx, txx, a) 505 | } 506 | 507 | func (repo *SQLiteRepository) aggregate(ctx context.Context, runner nero.SQLRunner, a *Aggregator) error { 508 | aggs := []*aggregate.Aggregate{} 509 | for _, aggFunc := range a.aggFuncs { 510 | aggs = aggFunc(aggs) 511 | } 512 | columns := []string{} 513 | for _, agg := range aggs { 514 | field := agg.Field 515 | qf := fmt.Sprintf("%q", field) 516 | switch agg.Op { 517 | case aggregate.Avg: 518 | columns = append(columns, "AVG("+qf+") avg_"+field) 519 | case aggregate.Count: 520 | columns = append(columns, "COUNT("+qf+") count_"+field) 521 | case aggregate.Max: 522 | columns = append(columns, "MAX("+qf+") max_"+field) 523 | case aggregate.Min: 524 | columns = append(columns, "MIN("+qf+") min_"+field) 525 | case aggregate.Sum: 526 | columns = append(columns, "SUM("+qf+") sum_"+field) 527 | case aggregate.None: 528 | columns = append(columns, qf) 529 | } 530 | } 531 | 532 | qb := squirrel.Select(columns...).From("\"{{.Table}}\"") 533 | 534 | groupBys := []string{} 535 | for _, groupBy := range a.groupBys { 536 | groupBys = 
append(groupBys, fmt.Sprintf("%q", groupBy.String())) 537 | } 538 | qb = qb.GroupBy(groupBys...) 539 | 540 | preds := []*comparison.Predicate{} 541 | for _, predFunc := range a.predFuncs { 542 | preds = predFunc(preds) 543 | } 544 | qb = squirrel.SelectBuilder(repo.buildPreds(squirrel.StatementBuilderType(qb), preds)) 545 | 546 | sorts := []*sort.Sort{} 547 | for _, sortFunc := range a.sortFuncs { 548 | sorts = sortFunc(sorts) 549 | } 550 | qb = repo.buildSort(qb, sorts) 551 | 552 | if repo.debug && repo.logger != nil { 553 | sql, args, err := qb.ToSql() 554 | repo.logger.Printf("method: Aggregate, stmt: %q, args: %v, error: %v", sql, args, err) 555 | } 556 | 557 | rows, err := qb.RunWith(runner).QueryContext(ctx) 558 | if err != nil { 559 | return err 560 | } 561 | defer rows.Close() 562 | 563 | v := reflect.ValueOf(a.v).Elem() 564 | t := reflect.TypeOf(v.Interface()).Elem() 565 | if len(columns) != t.NumField() { 566 | return errors.Errorf("column count (%v) and destination struct field count (%v) doesn't match", len(columns), t.NumField(),) 567 | } 568 | 569 | for rows.Next() { 570 | ve := reflect.New(t).Elem() 571 | dest := make([]interface{}, ve.NumField()) 572 | for i := 0; i < ve.NumField(); i++ { 573 | dest[i] = ve.Field(i).Addr().Interface() 574 | } 575 | 576 | err = rows.Scan(dest...) 577 | if err != nil { 578 | return err 579 | } 580 | 581 | v.Set(reflect.Append(v, ve)) 582 | } 583 | 584 | return nil 585 | } 586 | ` 587 | -------------------------------------------------------------------------------- /test/integration/playerrepo/predicate.go: -------------------------------------------------------------------------------- 1 | // Code generated by nero, DO NOT EDIT. 
package playerrepo

import (
	"time"

	"github.com/stevenferrer/nero/comparison"
	"github.com/stevenferrer/nero/test/integration/player"
)

// NOTE(review): this file carries the header "Code generated by nero, DO NOT
// EDIT.", so behavioral changes belong in the nero generator, not here.
//
// Every constructor in this file returns a comparison.PredFunc: a closure that
// appends exactly one *comparison.Predicate (column name in Field, operator in
// Op, operand - if any - in Arg) to the slice it receives and returns the
// extended slice.

// IDEq returns a PredFunc that appends an equal-to (Eq) predicate on the
// "id" column.
func IDEq(id string) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "id",
			Op:    comparison.Eq,
			Arg:   id,
		})
	}
}

// IDNotEq returns a PredFunc that appends a not-equal (NotEq) predicate on
// the "id" column.
func IDNotEq(id string) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "id",
			Op:    comparison.NotEq,
			Arg:   id,
		})
	}
}

// IDIn returns a PredFunc that appends an In predicate on the "id" column,
// matching any of ids.
func IDIn(ids ...string) comparison.PredFunc {
	// Box the values up front so Arg holds a []interface{}; the SQLite
	// repository's buildPreds expands such a slice into one "?" per element.
	args := []interface{}{}
	for _, v := range ids {
		args = append(args, v)
	}

	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "id",
			Op:    comparison.In,
			Arg:   args,
		})
	}
}

// IDNotIn returns a PredFunc that appends a NotIn predicate on the "id"
// column, excluding each of ids.
func IDNotIn(ids ...string) comparison.PredFunc {
	// Boxed into []interface{} for the same reason as in IDIn.
	args := []interface{}{}
	for _, v := range ids {
		args = append(args, v)
	}

	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "id",
			Op:    comparison.NotIn,
			Arg:   args,
		})
	}
}

// EmailEq returns a PredFunc that appends an equal-to (Eq) predicate on the
// "email" column.
func EmailEq(email string) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "email",
			Op:    comparison.Eq,
			Arg:   email,
		})
	}
}

// EmailNotEq returns a PredFunc that appends a not-equal (NotEq) predicate on
// the "email" column.
func EmailNotEq(email string) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "email",
			Op:    comparison.NotEq,
			Arg:   email,
		})
	}
}

// EmailIn returns a PredFunc that appends an In predicate on the "email"
// column, matching any of emails.
func EmailIn(emails ...string) comparison.PredFunc {
	// Box the values up front so Arg holds a []interface{}.
	args := []interface{}{}
	for _, v := range emails {
		args = append(args, v)
	}

	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "email",
			Op:    comparison.In,
			Arg:   args,
		})
	}
}

// EmailNotIn returns a PredFunc that appends a NotIn predicate on the
// "email" column, excluding each of emails.
func EmailNotIn(emails ...string) comparison.PredFunc {
	// Box the values up front so Arg holds a []interface{}.
	args := []interface{}{}
	for _, v := range emails {
		args = append(args, v)
	}

	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "email",
			Op:    comparison.NotIn,
			Arg:   args,
		})
	}
}

// NameEq returns a PredFunc that appends an equal-to (Eq) predicate on the
// "name" column.
func NameEq(name string) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "name",
			Op:    comparison.Eq,
			Arg:   name,
		})
	}
}

// NameNotEq returns a PredFunc that appends a not-equal (NotEq) predicate on
// the "name" column.
func NameNotEq(name string) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "name",
			Op:    comparison.NotEq,
			Arg:   name,
		})
	}
}

// NameIn returns a PredFunc that appends an In predicate on the "name"
// column, matching any of names.
func NameIn(names ...string) comparison.PredFunc {
	// Box the values up front so Arg holds a []interface{}.
	args := []interface{}{}
	for _, v := range names {
		args = append(args, v)
	}

	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "name",
			Op:    comparison.In,
			Arg:   args,
		})
	}
}

// NameNotIn returns a PredFunc that appends a NotIn predicate on the "name"
// column, excluding each of names.
func NameNotIn(names ...string) comparison.PredFunc {
	// Box the values up front so Arg holds a []interface{}.
	args := []interface{}{}
	for _, v := range names {
		args = append(args, v)
	}

	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "name",
			Op:    comparison.NotIn,
			Arg:   args,
		})
	}
}

// AgeEq returns a PredFunc that appends an equal-to (Eq) predicate on the
// "age" column.
func AgeEq(age int) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "age",
			Op:    comparison.Eq,
			Arg:   age,
		})
	}
}

// AgeNotEq returns a PredFunc that appends a not-equal (NotEq) predicate on
// the "age" column.
func AgeNotEq(age int) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "age",
			Op:    comparison.NotEq,
			Arg:   age,
		})
	}
}

// AgeGt returns a PredFunc that appends a greater-than (Gt) predicate on the
// "age" column.
func AgeGt(age int) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "age",
			Op:    comparison.Gt,
			Arg:   age,
		})
	}
}

// AgeGtOrEq returns a PredFunc that appends a greater-than-or-equal (GtOrEq)
// predicate on the "age" column.
func AgeGtOrEq(age int) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "age",
			Op:    comparison.GtOrEq,
			Arg:   age,
		})
	}
}

// AgeLt returns a PredFunc that appends a less-than (Lt) predicate on the
// "age" column.
func AgeLt(age int) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "age",
			Op:    comparison.Lt,
			Arg:   age,
		})
	}
}

// AgeLtOrEq returns a PredFunc that appends a less-than-or-equal (LtOrEq)
// predicate on the "age" column.
func AgeLtOrEq(age int) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "age",
			Op:    comparison.LtOrEq,
			Arg:   age,
		})
	}
}

// AgeIn returns a PredFunc that appends an In predicate on the "age" column,
// matching any of ages.
func AgeIn(ages ...int) comparison.PredFunc {
	// Box the values up front so Arg holds a []interface{}.
	args := []interface{}{}
	for _, v := range ages {
		args = append(args, v)
	}

	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "age",
			Op:    comparison.In,
			Arg:   args,
		})
	}
}

// AgeNotIn returns a PredFunc that appends a NotIn predicate on the "age"
// column, excluding each of ages.
func AgeNotIn(ages ...int) comparison.PredFunc {
	// Box the values up front so Arg holds a []interface{}.
	args := []interface{}{}
	for _, v := range ages {
		args = append(args, v)
	}

	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "age",
			Op:    comparison.NotIn,
			Arg:   args,
		})
	}
}

// RaceEq returns a PredFunc that appends an equal-to (Eq) predicate on the
// "race" column.
func RaceEq(race player.Race) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "race",
			Op:    comparison.Eq,
			Arg:   race,
		})
	}
}

// RaceNotEq returns a PredFunc that appends a not-equal (NotEq) predicate on
// the "race" column.
func RaceNotEq(race player.Race) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "race",
			Op:    comparison.NotEq,
			Arg:   race,
		})
	}
}

// RaceIn returns a PredFunc that appends an In predicate on the "race"
// column, matching any of races.
func RaceIn(races ...player.Race) comparison.PredFunc {
	// Box the values up front so Arg holds a []interface{}; the SQLite
	// repository's buildPreds expands such a slice into one "?" per element.
	args := []interface{}{}
	for _, v := range races {
		args = append(args, v)
	}

	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "race",
			Op:    comparison.In,
			Arg:   args,
		})
	}
}

// RaceNotIn returns a PredFunc that appends a NotIn predicate on the "race"
// column, excluding each of races.
func RaceNotIn(races ...player.Race) comparison.PredFunc {
	// Box the values up front so Arg holds a []interface{}.
	args := []interface{}{}
	for _, v := range races {
		args = append(args, v)
	}

	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "race",
			Op:    comparison.NotIn,
			Arg:   args,
		})
	}
}

// UpdatedAtEq returns a PredFunc that appends an equal-to (Eq) predicate on
// the "updated_at" column.
//
// NOTE(review): updatedAt is stored as-is; a nil pointer presumably binds as
// SQL NULL, and "= NULL" never matches - use UpdatedAtIsNull for that case.
func UpdatedAtEq(updatedAt *time.Time) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "updated_at",
			Op:    comparison.Eq,
			Arg:   updatedAt,
		})
	}
}

// UpdatedAtNotEq returns a PredFunc that appends a not-equal (NotEq)
// predicate on the "updated_at" column.
//
// NOTE(review): a nil updatedAt presumably binds as SQL NULL, and "<> NULL"
// never matches - use UpdatedAtIsNotNull for that case.
func UpdatedAtNotEq(updatedAt *time.Time) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "updated_at",
			Op:    comparison.NotEq,
			Arg:   updatedAt,
		})
	}
}

// UpdatedAtGt returns a PredFunc that appends a greater-than (Gt) predicate
// on the "updated_at" column.
func UpdatedAtGt(updatedAt *time.Time) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "updated_at",
			Op:    comparison.Gt,
			Arg:   updatedAt,
		})
	}
}

// UpdatedAtGtOrEq returns a PredFunc that appends a greater-than-or-equal
// (GtOrEq) predicate on the "updated_at" column.
func UpdatedAtGtOrEq(updatedAt *time.Time) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "updated_at",
			Op:    comparison.GtOrEq,
			Arg:   updatedAt,
		})
	}
}

// UpdatedAtLt returns a PredFunc that appends a less-than (Lt) predicate on
// the "updated_at" column.
func UpdatedAtLt(updatedAt *time.Time) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "updated_at",
			Op:    comparison.Lt,
			Arg:   updatedAt,
		})
	}
}

// UpdatedAtLtOrEq returns a PredFunc that appends a less-than-or-equal
// (LtOrEq) predicate on the "updated_at" column.
func UpdatedAtLtOrEq(updatedAt *time.Time) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "updated_at",
			Op:    comparison.LtOrEq,
			Arg:   updatedAt,
		})
	}
}

// UpdatedAtIsNull returns a PredFunc that appends an IsNull predicate on the
// "updated_at" column. The predicate carries no Arg.
func UpdatedAtIsNull() comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "updated_at",
			Op:    comparison.IsNull,
		})
	}
}

// UpdatedAtIsNotNull returns a PredFunc that appends an IsNotNull predicate
// on the "updated_at" column. The predicate carries no Arg.
func UpdatedAtIsNotNull() comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "updated_at",
			Op:    comparison.IsNotNull,
		})
	}
}

// UpdatedAtIn returns a PredFunc that appends an In predicate on the
// "updated_at" column, matching any of updatedAts.
func UpdatedAtIn(updatedAts ...*time.Time) comparison.PredFunc {
	// Box the values up front so Arg holds a []interface{}.
	args := []interface{}{}
	for _, v := range updatedAts {
		args = append(args, v)
	}

	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "updated_at",
			Op:    comparison.In,
			Arg:   args,
		})
	}
}

// UpdatedAtNotIn returns a PredFunc that appends a NotIn predicate on the
// "updated_at" column, excluding each of updatedAts.
func UpdatedAtNotIn(updatedAts ...*time.Time) comparison.PredFunc {
	// Box the values up front so Arg holds a []interface{}.
	args := []interface{}{}
	for _, v := range updatedAts {
		args = append(args, v)
	}

	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "updated_at",
			Op:    comparison.NotIn,
			Arg:   args,
		})
	}
}

// CreatedAtEq returns a PredFunc that appends an equal-to (Eq) predicate on
// the "created_at" column.
//
// NOTE(review): createdAt is stored as-is; a nil pointer presumably binds as
// SQL NULL, and "= NULL" never matches - use CreatedAtIsNull for that case.
func CreatedAtEq(createdAt *time.Time) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "created_at",
			Op:    comparison.Eq,
			Arg:   createdAt,
		})
	}
}

// CreatedAtNotEq returns a PredFunc that appends a not-equal (NotEq)
// predicate on the "created_at" column.
//
// NOTE(review): a nil createdAt presumably binds as SQL NULL, and "<> NULL"
// never matches - use CreatedAtIsNotNull for that case.
func CreatedAtNotEq(createdAt *time.Time) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "created_at",
			Op:    comparison.NotEq,
			Arg:   createdAt,
		})
	}
}

// CreatedAtGt returns a PredFunc that appends a greater-than (Gt) predicate
// on the "created_at" column.
func CreatedAtGt(createdAt *time.Time) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "created_at",
			Op:    comparison.Gt,
			Arg:   createdAt,
		})
	}
}

// CreatedAtGtOrEq returns a PredFunc that appends a greater-than-or-equal
// (GtOrEq) predicate on the "created_at" column.
func CreatedAtGtOrEq(createdAt *time.Time) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "created_at",
			Op:    comparison.GtOrEq,
			Arg:   createdAt,
		})
	}
}

// CreatedAtLt returns a PredFunc that appends a less-than (Lt) predicate on
// the "created_at" column.
func CreatedAtLt(createdAt *time.Time) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "created_at",
			Op:    comparison.Lt,
			Arg:   createdAt,
		})
	}
}

// CreatedAtLtOrEq returns a PredFunc that appends a less-than-or-equal
// (LtOrEq) predicate on the "created_at" column.
func CreatedAtLtOrEq(createdAt *time.Time) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "created_at",
			Op:    comparison.LtOrEq,
			Arg:   createdAt,
		})
	}
}

// CreatedAtIsNull returns a PredFunc that appends an IsNull predicate on the
// "created_at" column. The predicate carries no Arg.
func CreatedAtIsNull() comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "created_at",
			Op:    comparison.IsNull,
		})
	}
}

// CreatedAtIsNotNull returns a PredFunc that appends an IsNotNull predicate
// on the "created_at" column. The predicate carries no Arg.
func CreatedAtIsNotNull() comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "created_at",
			Op:    comparison.IsNotNull,
		})
	}
}

// CreatedAtIn returns a PredFunc that appends an In predicate on the
// "created_at" column, matching any of createdAts.
func CreatedAtIn(createdAts ...*time.Time) comparison.PredFunc {
	// Box the values up front so Arg holds a []interface{}.
	args := []interface{}{}
	for _, v := range createdAts {
		args = append(args, v)
	}

	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "created_at",
			Op:    comparison.In,
			Arg:   args,
		})
	}
}

// CreatedAtNotIn returns a PredFunc that appends a NotIn predicate on the
// "created_at" column, excluding each of createdAts.
func CreatedAtNotIn(createdAts ...*time.Time) comparison.PredFunc {
	// Box the values up front so Arg holds a []interface{}.
	args := []interface{}{}
	for _, v := range createdAts {
		args = append(args, v)
	}

	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: "created_at",
			Op:    comparison.NotIn,
			Arg:   args,
		})
	}
}

// FieldXEqFieldY returns a PredFunc that appends an equal-to (Eq) predicate
// comparing column fieldX with column fieldY.
//
// fieldX and fieldY must be of the same type.
//
// Arg holds fieldY itself (a Field); the SQLite repository's buildPreds
// detects that and renders it as a quoted column name instead of a bound
// placeholder.
func FieldXEqFieldY(fieldX, fieldY Field) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: fieldX.String(),
			Op:    comparison.Eq,
			Arg:   fieldY,
		})
	}
}

// FieldXNotEqFieldY returns a PredFunc that appends a not-equal (NotEq)
// predicate comparing column fieldX with column fieldY.
//
// fieldX and fieldY must be of the same type.
func FieldXNotEqFieldY(fieldX, fieldY Field) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: fieldX.String(),
			Op:    comparison.NotEq,
			Arg:   fieldY,
		})
	}
}

// FieldXGtFieldY returns a PredFunc that appends a greater-than (Gt)
// predicate comparing column fieldX with column fieldY.
//
// fieldX and fieldY must be of the same type.
func FieldXGtFieldY(fieldX, fieldY Field) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: fieldX.String(),
			Op:    comparison.Gt,
			Arg:   fieldY,
		})
	}
}

// FieldXGtOrEqFieldY returns a PredFunc that appends a greater-than-or-equal
// (GtOrEq) predicate comparing column fieldX with column fieldY.
//
// fieldX and fieldY must be of the same type.
func FieldXGtOrEqFieldY(fieldX, fieldY Field) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: fieldX.String(),
			Op:    comparison.GtOrEq,
			Arg:   fieldY,
		})
	}
}

// FieldXLtFieldY returns a PredFunc that appends a less-than (Lt) predicate
// comparing column fieldX with column fieldY.
//
// fieldX and fieldY must be of the same type.
func FieldXLtFieldY(fieldX, fieldY Field) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: fieldX.String(),
			Op:    comparison.Lt,
			Arg:   fieldY,
		})
	}
}

// FieldXLtOrEqFieldY returns a PredFunc that appends a less-than-or-equal
// (LtOrEq) predicate comparing column fieldX with column fieldY.
//
// fieldX and fieldY must be of the same type.
func FieldXLtOrEqFieldY(fieldX, fieldY Field) comparison.PredFunc {
	return func(preds []*comparison.Predicate) []*comparison.Predicate {
		return append(preds, &comparison.Predicate{
			Field: fieldX.String(),
			Op:    comparison.LtOrEq,
			Arg:   fieldY,
		})
	}
}

--------------------------------------------------------------------------------
/pg_template.go:
--------------------------------------------------------------------------------
package nero

// PostgresTemplate is the Template that generates a PostgreSQL-backed
// repository. Its output filename defaults to "postgres.go" and can be
// changed with WithFilename.
type PostgresTemplate struct {
	// filename is the value reported by Filename.
	filename string
}

// Compile-time check that *PostgresTemplate satisfies Template.
var _ Template = (*PostgresTemplate)(nil)

// NewPostgresTemplate returns a PostgresTemplate whose filename is
// "postgres.go".
func NewPostgresTemplate() *PostgresTemplate {
	return &PostgresTemplate{
		filename: "postgres.go",
	}
}

// WithFilename overrides the default filename. It modifies t in place and
// returns the same pointer so calls can be chained.
func (t *PostgresTemplate) WithFilename(filename string) *PostgresTemplate {
	t.filename = filename
	return t
}

// Filename returns the configured filename ("postgres.go" unless overridden
// with WithFilename).
func (t *PostgresTemplate) Filename() string {
	return t.filename
}

// Content returns the template source for the PostgreSQL repository, i.e.
// the postgresTmpl constant.
func (t *PostgresTemplate) Content() string {
	return postgresTmpl
}

// postgresTmpl is the Go template used to generate the PostgreSQL repository
// implementation (written to "postgres.go" by default, see NewPostgresTemplate).
//
// It is executed against the schema being generated and reads .PkgName,
// .Imports, .Identity, .Fields, .Table, .TypeName, .TypeNamePlural,
// .TypeIdentifier, .TypeIdentifierPlural and .TypeInfo. It calls the template
// functions fileHeaders, prependToFields, rawType, zeroValue and type, which
// are not defined here — presumably registered in the generator's FuncMap.
//
// NOTE(review): the generated import list includes packages (e.g. "io") that
// the generated code does not use, and "pq" is only used when array fields
// exist; this presumably relies on a goimports-style formatting pass after
// execution — confirm before changing the import block.
//
// NOTE(review): the generated code references Repository, Creator, Queryer,
// Updater, Deleter, Aggregator, Field and isZero, which are expected to come
// from sibling templates in the same generated package.
const postgresTmpl = `
{{- fileHeaders -}}

package {{.PkgName}}

import (
	"context"
	"database/sql"
	"fmt"
	"reflect"
	"io"
	"strings"
	"log"
	"os"
	"github.com/Masterminds/squirrel"
	"github.com/lib/pq"
	"github.com/pkg/errors"
	"github.com/stevenferrer/nero"
	"github.com/stevenferrer/nero/aggregate"
	"github.com/stevenferrer/nero/comparison"
	"github.com/stevenferrer/nero/sort"
	{{range $import := .Imports -}}
	"{{$import}}"
	{{end -}}
)

{{ $fields := prependToFields .Identity .Fields }}

// PostgresRepository is a repository that uses PostgreSQL as data store
type PostgresRepository struct {
	db     *sql.DB
	logger nero.Logger
	debug  bool
}

var _ Repository = (*PostgresRepository)(nil)

// NewPostgresRepository returns a PostgresRepository
func NewPostgresRepository(db *sql.DB) *PostgresRepository {
	return &PostgresRepository{db: db}
}

// Debug enables debug mode
func (repo *PostgresRepository) Debug() *PostgresRepository {
	l := log.New(os.Stdout, "[nero] ", log.LstdFlags | log.Lmicroseconds | log.Lmsgprefix)
	return &PostgresRepository{
		db:     repo.db,
		debug:  true,
		logger: l,
	}
}

// WithLogger overrides the default logger
func (repo *PostgresRepository) WithLogger(logger nero.Logger) *PostgresRepository {
	repo.logger = logger
	return repo
}

// BeginTx starts a transaction
func (repo *PostgresRepository) BeginTx(ctx context.Context) (nero.Tx, error) {
	return repo.db.BeginTx(ctx, nil)
}

// Create creates a {{.TypeName}}
func (repo *PostgresRepository) Create(ctx context.Context, c *Creator) ({{rawType .Identity.TypeInfo.V}}, error) {
	return repo.create(ctx, repo.db, c)
}

// CreateInTx creates a {{.TypeName}} in a transaction
func (repo *PostgresRepository) CreateInTx(ctx context.Context, tx nero.Tx, c *Creator) ({{rawType .Identity.TypeInfo.V}}, error) {
	txx, ok := tx.(*sql.Tx)
	if !ok {
		return {{zeroValue .Identity.TypeInfo.V}}, errors.New("expecting tx to be *sql.Tx")
	}

	return repo.create(ctx, txx, c)
}

func (repo *PostgresRepository) create(ctx context.Context, runner nero.SQLRunner, c *Creator) ({{rawType .Identity.TypeInfo.V}}, error) {
	if err := c.Validate(); err != nil {
		return {{zeroValue .Identity.TypeInfo.V}}, err
	}

	columns := []string{
		{{range $field := $fields -}}
		{{if and (ne $field.IsOptional true) (ne $field.IsAuto true) -}}
		"\"{{$field.Name}}\"",
		{{end -}}
		{{end -}}
	}

	values := []interface{}{
		{{range $field := $fields -}}
		{{if and (ne $field.IsOptional true) (ne $field.IsAuto true) -}}
		{{if and ($field.IsArray) (ne $field.IsValueScanner true) -}}
		pq.Array(c.{{$field.Identifier}}),
		{{else -}}
		c.{{$field.Identifier}},
		{{end -}}
		{{end -}}
		{{end -}}
	}

	{{range $field := $fields -}}
	{{if and ($field.IsOptional) (ne $field.IsAuto true) -}}
	if !isZero(c.{{$field.Identifier}}) {
		columns = append(columns, "\"{{$field.Name}}\"")
		{{if and ($field.IsArray) (ne $field.IsValueScanner true) -}}
		values = append(values, pq.Array(c.{{$field.Identifier}}))
		{{else -}}
		values = append(values, c.{{$field.Identifier}})
		{{end -}}
	}
	{{end -}}
	{{end}}

	qb := squirrel.Insert("\"{{.Table}}\"").
		Columns(columns...).
		Values(values...).
		Suffix("RETURNING \"{{.Identity.Name}}\"").
		PlaceholderFormat(squirrel.Dollar).
		RunWith(runner)
	if repo.debug && repo.logger != nil {
		sql, args, err := qb.ToSql()
		repo.logger.Printf("method: Create, stmt: %q, args: %v, error: %v", sql, args, err)
	}

	var {{.Identity.Identifier}} {{rawType .Identity.TypeInfo.V}}
	err := qb.QueryRowContext(ctx).Scan(&{{.Identity.Identifier}})
	if err != nil {
		return {{zeroValue .Identity.TypeInfo.V}}, err
	}

	return {{.Identity.Identifier}}, nil
}

// CreateMany batch creates {{.TypeNamePlural}}
func (repo *PostgresRepository) CreateMany(ctx context.Context, cs ...*Creator) error {
	return repo.createMany(ctx, repo.db, cs...)
}

// CreateManyInTx batch creates {{.TypeNamePlural}} in a transaction
func (repo *PostgresRepository) CreateManyInTx(ctx context.Context, tx nero.Tx, cs ...*Creator) error {
	txx, ok := tx.(*sql.Tx)
	if !ok {
		return errors.New("expecting tx to be *sql.Tx")
	}

	return repo.createMany(ctx, txx, cs...)
}

func (repo *PostgresRepository) createMany(ctx context.Context, runner nero.SQLRunner, cs ...*Creator) error {
	if len(cs) == 0 {
		return nil
	}

	columns := []string{
		{{range $field := $fields -}}
		{{if ne $field.IsAuto true -}}
		"\"{{$field.Name}}\"",
		{{end -}}
		{{end -}}
	}

	qb := squirrel.Insert("\"{{.Table}}\"").Columns(columns...)
	for _, c := range cs {
		if err := c.Validate(); err != nil {
			return err
		}

		qb = qb.Values(
			{{range $field := $fields -}}
			{{if ne $field.IsAuto true -}}
			{{if and ($field.IsArray) (ne $field.IsValueScanner true) -}}
			pq.Array(c.{{$field.Identifier}}),
			{{else -}}
			c.{{$field.Identifier}},
			{{end -}}
			{{end -}}
			{{end -}}
		)
	}

	qb = qb.Suffix("RETURNING \"{{.Identity.Name}}\"").
		PlaceholderFormat(squirrel.Dollar)
	if repo.debug && repo.logger != nil {
		sql, args, err := qb.ToSql()
		repo.logger.Printf("method: CreateMany, stmt: %q, args: %v, error: %v", sql, args, err)
	}

	_, err := qb.RunWith(runner).ExecContext(ctx)
	if err != nil {
		return err
	}

	return nil
}

// Query queries {{.TypeNamePlural}}
func (repo *PostgresRepository) Query(ctx context.Context, q *Queryer) ([]{{rawType .TypeInfo.V}}, error) {
	return repo.query(ctx, repo.db, q)
}

// QueryInTx queries {{.TypeNamePlural}} in a transaction
func (repo *PostgresRepository) QueryInTx(ctx context.Context, tx nero.Tx, q *Queryer) ([]{{rawType .TypeInfo.V}}, error) {
	txx, ok := tx.(*sql.Tx)
	if !ok {
		return nil, errors.New("expecting tx to be *sql.Tx")
	}

	return repo.query(ctx, txx, q)
}

func (repo *PostgresRepository) query(ctx context.Context, runner nero.SQLRunner, q *Queryer) ([]{{rawType .TypeInfo.V}}, error) {
	qb := repo.buildSelect(q)
	if repo.debug && repo.logger != nil {
		sql, args, err := qb.ToSql()
		repo.logger.Printf("method: Query, stmt: %q, args: %v, error: %v", sql, args, err)
	}

	rows, err := qb.RunWith(runner).QueryContext(ctx)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	{{.TypeIdentifierPlural}} := []{{rawType .TypeInfo.V}}{}
	for rows.Next() {
		var {{.TypeIdentifier}} {{type .TypeInfo.V}}
		err = rows.Scan(
			{{range $field := $fields -}}
			{{if and ($field.IsArray) (ne $field.IsValueScanner true) -}}
			pq.Array(&{{$.TypeIdentifier}}.{{$field.StructField}}),
			{{else -}}
			&{{$.TypeIdentifier}}.{{$field.StructField}},
			{{end -}}
			{{end -}}
		)
		if err != nil {
			return nil, err
		}

		{{.TypeIdentifierPlural}} = append({{.TypeIdentifierPlural}}, &{{.TypeIdentifier}})
	}

	// rows.Next also returns false when iteration fails; surface that error
	// instead of returning a silently truncated result.
	if err = rows.Err(); err != nil {
		return nil, err
	}

	return {{.TypeIdentifierPlural}}, nil
}

// QueryOne queries a {{.TypeName}}
func (repo *PostgresRepository) QueryOne(ctx context.Context, q *Queryer) ({{rawType .TypeInfo.V}}, error) {
	return repo.queryOne(ctx, repo.db, q)
}

// QueryOneInTx queries a {{.TypeName}} in a transaction
func (repo *PostgresRepository) QueryOneInTx(ctx context.Context, tx nero.Tx, q *Queryer) ({{rawType .TypeInfo.V}}, error) {
	txx, ok := tx.(*sql.Tx)
	if !ok {
		return nil, errors.New("expecting tx to be *sql.Tx")
	}

	return repo.queryOne(ctx, txx, q)
}

func (repo *PostgresRepository) queryOne(ctx context.Context, runner nero.SQLRunner, q *Queryer) ({{rawType .TypeInfo.V}}, error) {
	qb := repo.buildSelect(q)
	if repo.debug && repo.logger != nil {
		sql, args, err := qb.ToSql()
		repo.logger.Printf("method: QueryOne, stmt: %q, args: %v, error: %v", sql, args, err)
	}

	var {{.TypeIdentifier}} {{type .TypeInfo.V}}
	err := qb.RunWith(runner).
		QueryRowContext(ctx).
		Scan(
			{{range $field := $fields -}}
			{{if and ($field.IsArray) (ne $field.IsValueScanner true) -}}
			pq.Array(&{{$.TypeIdentifier}}.{{$field.StructField}}),
			{{else -}}
			&{{$.TypeIdentifier}}.{{$field.StructField}},
			{{end -}}
			{{end -}}
		)
	if err != nil {
		return {{zeroValue .TypeInfo.V}}, err
	}

	return &{{.TypeIdentifier}}, nil
}

func (repo *PostgresRepository) buildSelect(q *Queryer) squirrel.SelectBuilder {
	columns := []string{
		{{range $field := $fields -}}
		"\"{{$field.Name}}\"",
		{{end -}}
	}
	qb := squirrel.Select(columns...).
		From("\"{{.Table}}\"").
		PlaceholderFormat(squirrel.Dollar)

	preds := []*comparison.Predicate{}
	for _, predFunc := range q.predFuncs {
		preds = predFunc(preds)
	}
	qb = squirrel.SelectBuilder(repo.buildPreds(squirrel.StatementBuilderType(qb), preds))

	sorts := []*sort.Sort{}
	for _, sortFunc := range q.sortFuncs {
		sorts = sortFunc(sorts)
	}
	qb = repo.buildSort(qb, sorts)

	if q.limit > 0 {
		qb = qb.Limit(uint64(q.limit))
	}

	if q.offset > 0 {
		qb = qb.Offset(uint64(q.offset))
	}

	return qb
}

func (repo *PostgresRepository) buildPreds(sb squirrel.StatementBuilderType, preds []*comparison.Predicate) squirrel.StatementBuilderType {
	for _, pred := range preds {
		ph := "?"
		fieldX, arg := pred.Field, pred.Arg

		args := []interface{}{}
		if fieldY, ok := arg.(Field); ok { // a field
			ph = fmt.Sprintf("%q", fieldY)
		} else if vals, ok := arg.([]interface{}); ok { // array of values
			args = append(args, vals...)
		} else { // single value
			args = append(args, arg)
		}

		switch pred.Op {
		case comparison.Eq:
			sb = sb.Where(fmt.Sprintf("%q = "+ph, fieldX), args...)
		case comparison.NotEq:
			sb = sb.Where(fmt.Sprintf("%q <> "+ph, fieldX), args...)
		case comparison.Gt:
			sb = sb.Where(fmt.Sprintf("%q > "+ph, fieldX), args...)
		case comparison.GtOrEq:
			sb = sb.Where(fmt.Sprintf("%q >= "+ph, fieldX), args...)
		case comparison.Lt:
			sb = sb.Where(fmt.Sprintf("%q < "+ph, fieldX), args...)
		case comparison.LtOrEq:
			sb = sb.Where(fmt.Sprintf("%q <= "+ph, fieldX), args...)
		case comparison.IsNull, comparison.IsNotNull:
			fmtStr := "%q IS NULL"
			if pred.Op == comparison.IsNotNull {
				fmtStr = "%q IS NOT NULL"
			}
			sb = sb.Where(fmt.Sprintf(fmtStr, fieldX))
		case comparison.In, comparison.NotIn:
			fmtStr := "%q IN (%s)"
			if pred.Op == comparison.NotIn {
				fmtStr = "%q NOT IN (%s)"
			}

			phs := []string{}
			for range args {
				phs = append(phs, "?")
			}

			sb = sb.Where(fmt.Sprintf(fmtStr, fieldX, strings.Join(phs, ",")), args...)
		}
	}

	return sb
}

func (repo *PostgresRepository) buildSort(qb squirrel.SelectBuilder, sorts []*sort.Sort) squirrel.SelectBuilder {
	for _, s := range sorts {
		field := fmt.Sprintf("%q", s.Field)
		switch s.Direction {
		case sort.Asc:
			qb = qb.OrderBy(field + " ASC")
		case sort.Desc:
			qb = qb.OrderBy(field + " DESC")
		}
	}

	return qb
}

// Update updates a {{.TypeName}} or many {{.TypeNamePlural}}
func (repo *PostgresRepository) Update(ctx context.Context, u *Updater) (int64, error) {
	return repo.update(ctx, repo.db, u)
}

// UpdateInTx updates a {{.TypeName}} or many {{.TypeNamePlural}} in a transaction
func (repo *PostgresRepository) UpdateInTx(ctx context.Context, tx nero.Tx, u *Updater) (int64, error) {
	txx, ok := tx.(*sql.Tx)
	if !ok {
		return 0, errors.New("expecting tx to be *sql.Tx")
	}

	return repo.update(ctx, txx, u)
}

func (repo *PostgresRepository) update(ctx context.Context, runner nero.SQLRunner, u *Updater) (int64, error) {
	qb := squirrel.Update("\"{{.Table}}\"").
		PlaceholderFormat(squirrel.Dollar)

	cnt := 0
	{{range $field := .Fields }}
	{{if ne $field.IsAuto true}}
	if !isZero(u.{{$field.Identifier}}) {
		{{if and ($field.IsArray) (ne $field.IsValueScanner true) -}}
		qb = qb.Set("\"{{$field.Name}}\"", pq.Array(u.{{$field.Identifier}}))
		{{else -}}
		qb = qb.Set("\"{{$field.Name}}\"", u.{{$field.Identifier}})
		{{end -}}
		cnt++
	}
	{{end}}
	{{end}}

	if cnt == 0 {
		return 0, nil
	}

	preds := []*comparison.Predicate{}
	for _, predFunc := range u.predFuncs {
		preds = predFunc(preds)
	}
	qb = squirrel.UpdateBuilder(repo.buildPreds(squirrel.StatementBuilderType(qb), preds))

	if repo.debug && repo.logger != nil {
		sql, args, err := qb.ToSql()
		repo.logger.Printf("method: Update, stmt: %q, args: %v, error: %v", sql, args, err)
	}

	res, err := qb.RunWith(runner).ExecContext(ctx)
	if err != nil {
		return 0, err
	}

	rowsAffected, err := res.RowsAffected()
	if err != nil {
		return 0, err
	}

	return rowsAffected, nil
}

// Delete deletes a {{.TypeName}} or many {{.TypeNamePlural}}
func (repo *PostgresRepository) Delete(ctx context.Context, d *Deleter) (int64, error) {
	return repo.delete(ctx, repo.db, d)
}

// DeleteInTx deletes a {{.TypeName}} or many {{.TypeNamePlural}} in a transaction
func (repo *PostgresRepository) DeleteInTx(ctx context.Context, tx nero.Tx, d *Deleter) (int64, error) {
	txx, ok := tx.(*sql.Tx)
	if !ok {
		return 0, errors.New("expecting tx to be *sql.Tx")
	}

	return repo.delete(ctx, txx, d)
}

func (repo *PostgresRepository) delete(ctx context.Context, runner nero.SQLRunner, d *Deleter) (int64, error) {
	qb := squirrel.Delete("\"{{.Table}}\"").
		PlaceholderFormat(squirrel.Dollar)

	preds := []*comparison.Predicate{}
	for _, predFunc := range d.predFuncs {
		preds = predFunc(preds)
	}
	qb = squirrel.DeleteBuilder(repo.buildPreds(squirrel.StatementBuilderType(qb), preds))

	if repo.debug && repo.logger != nil {
		sql, args, err := qb.ToSql()
		repo.logger.Printf("method: Delete, stmt: %q, args: %v, error: %v", sql, args, err)
	}

	res, err := qb.RunWith(runner).ExecContext(ctx)
	if err != nil {
		return 0, err
	}

	rowsAffected, err := res.RowsAffected()
	if err != nil {
		return 0, err
	}

	return rowsAffected, nil
}

// Aggregate performs an aggregate query
func (repo *PostgresRepository) Aggregate(ctx context.Context, a *Aggregator) error {
	return repo.aggregate(ctx, repo.db, a)
}

// AggregateInTx performs an aggregate query in a transaction
func (repo *PostgresRepository) AggregateInTx(ctx context.Context, tx nero.Tx, a *Aggregator) error {
	txx, ok := tx.(*sql.Tx)
	if !ok {
		return errors.New("expecting tx to be *sql.Tx")
	}

	return repo.aggregate(ctx, txx, a)
}

func (repo *PostgresRepository) aggregate(ctx context.Context, runner nero.SQLRunner, a *Aggregator) error {
	aggs := []*aggregate.Aggregate{}
	for _, aggFunc := range a.aggFuncs {
		aggs = aggFunc(aggs)
	}
	columns := []string{}
	for _, agg := range aggs {
		field := agg.Field
		qf := fmt.Sprintf("%q", field)
		switch agg.Op {
		case aggregate.Avg:
			columns = append(columns, "AVG("+qf+") avg_"+field)
		case aggregate.Count:
			columns = append(columns, "COUNT("+qf+") count_"+field)
		case aggregate.Max:
			columns = append(columns, "MAX("+qf+") max_"+field)
		case aggregate.Min:
			columns = append(columns, "MIN("+qf+") min_"+field)
		case aggregate.Sum:
			columns = append(columns, "SUM("+qf+") sum_"+field)
		case aggregate.None:
			columns = append(columns, qf)
		}
	}

	qb := squirrel.Select(columns...).From("\"{{.Table}}\"").
		PlaceholderFormat(squirrel.Dollar)

	groupBys := []string{}
	for _, groupBy := range a.groupBys {
		groupBys = append(groupBys, fmt.Sprintf("%q", groupBy.String()))
	}
	qb = qb.GroupBy(groupBys...)

	preds := []*comparison.Predicate{}
	for _, predFunc := range a.predFuncs {
		preds = predFunc(preds)
	}
	qb = squirrel.SelectBuilder(repo.buildPreds(squirrel.StatementBuilderType(qb), preds))

	sorts := []*sort.Sort{}
	for _, sortFunc := range a.sortFuncs {
		sorts = sortFunc(sorts)
	}
	qb = repo.buildSort(qb, sorts)

	if repo.debug && repo.logger != nil {
		sql, args, err := qb.ToSql()
		repo.logger.Printf("method: Aggregate, stmt: %q, args: %v, error: %v", sql, args, err)
	}

	rows, err := qb.RunWith(runner).QueryContext(ctx)
	if err != nil {
		return err
	}
	defer rows.Close()

	v := reflect.ValueOf(a.v).Elem()
	t := reflect.TypeOf(v.Interface()).Elem()
	if len(columns) != t.NumField() {
		return errors.Errorf("column count (%v) and destination struct field count (%v) doesn't match", len(columns), t.NumField(),)
	}

	for rows.Next() {
		ve := reflect.New(t).Elem()
		dest := make([]interface{}, ve.NumField())
		for i := 0; i < ve.NumField(); i++ {
			dest[i] = ve.Field(i).Addr().Interface()
		}

		err = rows.Scan(dest...)
		if err != nil {
			return err
		}

		v.Set(reflect.Append(v, ve))
	}

	// rows.Next also returns false when iteration fails; report that error
	// instead of treating a truncated result set as success.
	return rows.Err()
}
`