├── light.jpg ├── .travis.yml ├── null ├── null_test.go ├── time_test.go ├── null.go ├── bool.go ├── timestamp.go ├── string.go ├── time.go ├── clickhousetime.go ├── floats.go └── ints.go ├── go.mod ├── generator ├── generate_test.go ├── generate.go └── template.go ├── sqlparser ├── delete_test.go ├── delete.go ├── update.go ├── create_test.go ├── create.go ├── update_test.go ├── ast.go ├── insert_test.go ├── select_test.go ├── select.go ├── replace.go ├── insert.go ├── token.go ├── scanner.go └── parser.go ├── example ├── model │ └── user.go ├── schema │ └── 1.user.sql ├── conf │ └── conf.go └── store │ ├── connect.go │ ├── user.go │ ├── user_test.go │ └── user.light.go ├── goparser ├── results.go ├── params.go ├── parse_test.go ├── method.go ├── profile.go ├── variable.go └── parse.go ├── light └── execer.go ├── .gitignore ├── CHANGELOG.md ├── main.go ├── README.md └── go.sum /light.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/omigo/light/HEAD/light.jpg -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - "1.8" 5 | - "1.9" 6 | - "1.10" 7 | - "tip" 8 | 9 | script: 10 | - go test ./... 
11 | -------------------------------------------------------------------------------- /null/null_test.go: -------------------------------------------------------------------------------- 1 | package null 2 | 3 | import ( 4 | "testing" 5 | "unsafe" 6 | ) 7 | 8 | func TestIntSize(t *testing.T) { 9 | var a int 10 | t.Log(unsafe.Sizeof(a)) 11 | } 12 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/omigo/light 2 | 3 | go 1.12 4 | 5 | require ( 6 | github.com/DATA-DOG/go-sqlmock v1.3.3 7 | github.com/omigo/log v0.1.1 8 | golang.org/x/tools v0.1.0 9 | ) 10 | -------------------------------------------------------------------------------- /generator/generate_test.go: -------------------------------------------------------------------------------- 1 | package generator 2 | 3 | import "testing" 4 | 5 | func TestGetGomodPath(t *testing.T) { 6 | paths := []string{ 7 | ".", 8 | "/a/b/c", 9 | "/Users/Arstd/Reposits/projects/light/example/store/user.go", 10 | } 11 | 12 | for _, path := range paths { 13 | t.Run(path, func(t *testing.T) { 14 | t.Log(getGomodPath(path)) 15 | }) 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /null/time_test.go: -------------------------------------------------------------------------------- 1 | package null 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | ) 7 | 8 | func TestNullTimeZero(t *testing.T) { 9 | str := `"0000-00-00 00:00:00"` 10 | 11 | var ckDateTime ClickHouseTime 12 | 13 | err := ckDateTime.UnmarshalJSON([]byte(str)) 14 | if err != nil { 15 | t.Error(err) 16 | } 17 | 18 | t.Log(ckDateTime.Time) 19 | 20 | t.Log(time.Unix(0, 0)) 21 | } 22 | -------------------------------------------------------------------------------- /sqlparser/delete_test.go: -------------------------------------------------------------------------------- 1 | package sqlparser 
2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/omigo/log" 8 | ) 9 | 10 | func TestParseDeleteStmt(t *testing.T) { 11 | sql := `DELETE FROM users WHERE id=${id}` 12 | 13 | p := NewParser(bytes.NewBufferString(sql)) 14 | stmt, err := p.Parse() 15 | if err != nil { 16 | t.Fatal(err) 17 | } 18 | log.JsonIndent(stmt) 19 | } 20 | -------------------------------------------------------------------------------- /example/model/user.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/omigo/light/null" 7 | ) 8 | 9 | type Status uint8 10 | 11 | type User struct { 12 | Id uint64 13 | Username string 14 | Phone string `json:"mobile" light:"mobile,nullable"` 15 | Address *string 16 | Status Status `light:"_status"` 17 | BirthDay *time.Time 18 | Created time.Time 19 | Updated null.Timestamp `light:",nullable"` 20 | } 21 | -------------------------------------------------------------------------------- /goparser/results.go: -------------------------------------------------------------------------------- 1 | package goparser 2 | 3 | import ( 4 | "go/types" 5 | ) 6 | 7 | type Results struct { 8 | *Params 9 | 10 | Result *Variable 11 | } 12 | 13 | func NewResults(tuple *types.Tuple) *Results { 14 | rs := &Results{Params: NewParams(tuple)} 15 | switch tuple.Len() { 16 | case 1: 17 | // ddl: a single result (presumably only the error), so Result stays nil 18 | case 2: 19 | rs.Result = rs.List[0] 20 | case 3: 21 | rs.Result = rs.List[1] 22 | default: 23 | panic(len(rs.List)) 24 | } 25 | return rs 26 | } 27 | -------------------------------------------------------------------------------- /sqlparser/delete.go: -------------------------------------------------------------------------------- 1 | package sqlparser 2 | 3 | import "fmt" 4 | 5 | // ParseDelete parses a SQL DELETE statement. 6 | func (p *Parser) ParseDelete() (*Statement, error) { 7 | stmt := Statement{Type: DELETE} 8 | 9 | // First token should be a "DELETE" keyword.
10 | if tok, lit := p.scanIgnoreWhitespace(); tok != DELETE { 11 | return nil, fmt.Errorf("found %q, expected DELETE", lit) 12 | } 13 | p.unscan() 14 | 15 | stmt.Fragments = p.scanFragments() 16 | return &stmt, nil 17 | } 18 | -------------------------------------------------------------------------------- /sqlparser/update.go: -------------------------------------------------------------------------------- 1 | package sqlparser 2 | 3 | import "fmt" 4 | 5 | // ParseUpdate parses a SQL UPDATE statement. 6 | func (p *Parser) ParseUpdate() (*Statement, error) { 7 | stmt := Statement{Type: UPDATE} 8 | 9 | // First token should be an "UPDATE" keyword. 10 | if tok, lit := p.scanIgnoreWhitespace(); tok != UPDATE { 11 | return nil, fmt.Errorf("found %q, expected UPDATE", lit) 12 | } 13 | p.unscan() 14 | 15 | stmt.Fragments = p.scanFragments() 16 | return &stmt, nil 17 | } 18 | -------------------------------------------------------------------------------- /sqlparser/create_test.go: -------------------------------------------------------------------------------- 1 | package sqlparser 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/omigo/log" 8 | ) 9 | 10 | func TestParseCreate(t *testing.T) { 11 | sql := `create table if not exists #{dev.Platform}_#{dev.Cid} ( 12 | cid text, platform text, version text 13 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8 ` 14 | 15 | p := NewParser(bytes.NewBufferString(sql)) 16 | stmt, err := p.Parse() 17 | if err != nil { 18 | t.Fatal(err) 19 | } 20 | log.JsonIndent(stmt) 21 | } 22 | -------------------------------------------------------------------------------- /sqlparser/create.go: -------------------------------------------------------------------------------- 1 | package sqlparser 2 | 3 | import "fmt" 4 | 5 | // ParseCreate parses a SQL CREATE statement. 6 | func (p *Parser) ParseCreate() (*Statement, error) { 7 | stmt := Statement{Type: CREATE} 8 | 9 | // First token should be a "CREATE" keyword.
10 | if tok, lit := p.scanIgnoreWhitespace(); tok != CREATE { 11 | return nil, fmt.Errorf("found %q, expected CREATE", lit) 12 | } 13 | p.unscan() 14 | 15 | stmt.Fragments = p.scanFragments() 16 | 17 | // Return the successfully parsed statement. 18 | return &stmt, nil 19 | } 20 | -------------------------------------------------------------------------------- /example/schema/1.user.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE `users` ( 2 | `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT, 3 | `username` varchar(32) NOT NULL, 4 | `Phone` varchar(32) DEFAULT NULL, 5 | `address` varchar(256) DEFAULT NULL, 6 | `status` tinyint(3) unsigned DEFAULT NULL, 7 | `birth_day` date DEFAULT NULL, 8 | `created` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, 9 | `updated` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, 10 | PRIMARY KEY (`id`), 11 | UNIQUE KEY `username` (`username`) 12 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8; 13 | -------------------------------------------------------------------------------- /light/execer.go: -------------------------------------------------------------------------------- 1 | package light 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | ) 7 | 8 | type Execer interface { 9 | ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) 10 | PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) 11 | QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) 12 | QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row 13 | } 14 | 15 | func GetExec(tx *sql.Tx, db *sql.DB) Execer { 16 | if tx != nil { 17 | return tx 18 | } 19 | return db 20 | } 21 | -------------------------------------------------------------------------------- /sqlparser/update_test.go: -------------------------------------------------------------------------------- 1 | package sqlparser 2 | 3 | import ( 4 | "bytes" 5 | 
"testing" 6 | 7 | "github.com/omigo/log" 8 | ) 9 | 10 | func TestParseUpdateStmt(t *testing.T) { 11 | sql := `UPDATE users 12 | SET [username=${u.Username},] 13 | [phone=${u.Phone},] 14 | [address=${u.Address},] 15 | [status=${u.Status},] 16 | [birthday=${u.Birthday},] 17 | updated=CURRENT_TIMESTAMP 18 | WHERE id=${u.Id}` 19 | 20 | p := NewParser(bytes.NewBufferString(sql)) 21 | stmt, err := p.Parse() 22 | if err != nil { 23 | t.Fatal(err) 24 | } 25 | log.JsonIndent(stmt) 26 | } 27 | -------------------------------------------------------------------------------- /sqlparser/ast.go: -------------------------------------------------------------------------------- 1 | package sqlparser 2 | 3 | type Statement struct { 4 | Type Token 5 | Table string 6 | Fields []string 7 | 8 | Fragments []*Fragment `json:"fragments,omitempty"` 9 | } 10 | 11 | type Fragment struct { 12 | Condition string `json:"cond,omitempty"` 13 | Range string `json:"range,omitempty"` 14 | Open string `json:"open,omitempty"` 15 | Close string `json:"close,omitempty"` 16 | 17 | Statement string `json:"stmt,omitempty"` 18 | Replacers []string `json:"replacers,omitempty"` 19 | Variables []string `json:"variables,omitempty"` 20 | 21 | Fragments []*Fragment `json:"fragments,omitempty"` 22 | } 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | *.prof 25 | 26 | # IDE 27 | .vscode/ 28 | 29 | 30 | # Push imports to repository 31 | /vendor/** 32 | !/vendor/golang.org/x/tools/imports 33 | 
!/vendor/golang.org/x/tools/go/ast/astutil 34 | !/vendor/golang.org/x/tools/internal/fastwalk 35 | 36 | # Generated files 37 | .idea 38 | -------------------------------------------------------------------------------- /sqlparser/insert_test.go: -------------------------------------------------------------------------------- 1 | package sqlparser 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/omigo/log" 8 | ) 9 | 10 | func TestParseInsertStmt(t *testing.T) { 11 | sql := "insert into users(`username`, phone, address, _status, birthday, created, updated)" + ` 12 | values (?,?,?,?,?,CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)on duplicate key update 13 | username=values(?), phone=values(?), address=values(?), 14 | status=values(?), birthday=values(?), update=CURRENT_TIMESTAMP 15 | ` 16 | 17 | p := NewParser(bytes.NewBufferString(sql)) 18 | stmt, err := p.Parse() 19 | if err != nil { 20 | t.Fatal(err) 21 | } 22 | log.JsonIndent(stmt) 23 | } 24 | -------------------------------------------------------------------------------- /example/conf/conf.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import "github.com/omigo/log" 4 | 5 | type conf struct { 6 | Name string 7 | LogLevel string 8 | 9 | DB struct { 10 | Dialect string 11 | Host string 12 | Port int 13 | Username string 14 | Password string 15 | DBName string 16 | Params string 17 | } 18 | } 19 | 20 | var Conf conf 21 | 22 | func init() { 23 | Conf.LogLevel = "debug" 24 | 25 | Conf.DB.Dialect = "mysql" 26 | Conf.DB.Host = "127.0.0.1" 27 | Conf.DB.Port = 3306 28 | Conf.DB.Username = "test" 29 | Conf.DB.Password = "123456" 30 | Conf.DB.DBName = "test" 31 | Conf.DB.Params = "charset=utf8&parseTime=true&loc=Local" 32 | 33 | log.SetLevelString(Conf.LogLevel) 34 | // log.Json(Conf) 35 | } 36 | -------------------------------------------------------------------------------- /sqlparser/select_test.go: 
-------------------------------------------------------------------------------- 1 | package sqlparser 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/omigo/log" 8 | ) 9 | 10 | func TestParseSelectStmt(t *testing.T) { 11 | sql := "select (select id from users where id=1) as id, sum(status) status,`username`, phone as phone, address, birthday, created, updated" + ` 12 | from users` + 13 | "where `from`=${from} id != -1 and username > ''" + 14 | `username like ? 15 | [ 16 | and address = ? 17 | [and phone like ${u.Phone}] 18 | and created > ${u.Created} 19 | ] 20 | and status != ? 21 | [{ range } and status in (#{ss})] 22 | [and updated > ${u.Updated}] 23 | and birthday is not null 24 | order by updated desc 25 | limit ${page*size}, ${size}` 26 | 27 | p := NewParser(bytes.NewBufferString(sql)) 28 | stmt, err := p.Parse() 29 | if err != nil { 30 | t.Fatal(err) 31 | } 32 | log.JsonIndent(stmt) 33 | } 34 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | All notable changes to this project will be documented in this file. 3 | 4 | The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) 5 | and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). 6 | 7 | ## [Unreleased] 8 | ### TODO 9 | - Range for sql in (${u.Cities}). 10 | - more test case, and more covering. 11 | - Expression for select fields. 12 | - Rewrite types 13 | - Rewrite sql parser 14 | - id/page/offset/size null wrap not required 15 | - ${(page-1)*size} Support 16 | ### Added 17 | - Intelligent guess, use ? not ${...} 18 | - Parse go source file by go/types(get signatures) and go/parser(get comments). 19 | - Parse sql (comment) to generate Statement AST. 20 | - Support CREATE statement. 21 | - Replacers #{...} like Variables ${...}. 22 | - Generate implemented file by method signature and sql AST. 
23 | - Literal '...' 24 | - Fields is keyword `...` 25 | -------------------------------------------------------------------------------- /example/store/connect.go: -------------------------------------------------------------------------------- 1 | package store 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | 7 | "github.com/omigo/light/example/conf" 8 | "github.com/omigo/log" 9 | ) 10 | 11 | var db *sql.DB 12 | 13 | // func init() { 14 | // open() 15 | // log.Fataln(Connect()) 16 | // } 17 | 18 | func open() { 19 | // log.Json(conf.Conf) 20 | 21 | dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?%s", 22 | conf.Conf.DB.Username, 23 | conf.Conf.DB.Password, 24 | conf.Conf.DB.Host, 25 | conf.Conf.DB.Port, 26 | conf.Conf.DB.DBName, 27 | conf.Conf.DB.Params, 28 | ) 29 | var err error 30 | db, err = sql.Open(conf.Conf.DB.Dialect, dsn) 31 | log.Fataln(err) 32 | 33 | db.SetMaxIdleConns(0) 34 | db.SetMaxOpenConns(1) 35 | db.SetConnMaxLifetime(0) 36 | } 37 | 38 | func Connect() error { 39 | return db.Ping() 40 | } 41 | 42 | func Close() { 43 | if db != nil { 44 | log.Errorn(db.Close()) 45 | } 46 | } 47 | 48 | func Begin() (*sql.Tx, error) { 49 | return db.Begin() 50 | } 51 | -------------------------------------------------------------------------------- /goparser/params.go: -------------------------------------------------------------------------------- 1 | package goparser 2 | 3 | import ( 4 | "go/types" 5 | "strings" 6 | ) 7 | 8 | func ParamsLast(ps *Params) string { return ps.List[len(ps.List)-1].FullName("") } 9 | func ParamsLastElem(ps *Params) string { 10 | var x = ps.List[len(ps.List)-1] 11 | if x.Slice { 12 | if x.Name[len(x.Name)-1] == 's' { 13 | return x.Name[:len(x.Name)-1] 14 | } 15 | } 16 | return x.FullName("") 17 | } 18 | 19 | type Params struct { 20 | Tuple *types.Tuple 21 | List []*Variable 22 | 23 | Names map[string]*Variable 24 | } 25 | 26 | func NewParams(tuple *types.Tuple) *Params { 27 | ps := &Params{ 28 | Tuple: tuple, 29 | List: make([]*Variable, 
tuple.Len()), 30 | Names: make(map[string]*Variable), 31 | } 32 | 33 | for i := 0; i < tuple.Len(); i++ { 34 | v := tuple.At(i) 35 | ps.List[i] = NewVariable(v) 36 | } 37 | 38 | return ps 39 | } 40 | 41 | func (p *Params) Lookup(name string) *Variable { 42 | name = strings.Trim(name, "`") 43 | name = strings.TrimSpace(name) 44 | return p.Names[name] 45 | } 46 | -------------------------------------------------------------------------------- /goparser/parse_test.go: -------------------------------------------------------------------------------- 1 | package goparser 2 | 3 | import ( 4 | "os" 5 | "strings" 6 | "testing" 7 | 8 | "github.com/omigo/log" 9 | ) 10 | 11 | func TestParse(t *testing.T) { 12 | gopath := strings.TrimSuffix(os.Getenv("PWD"), "/") 13 | 14 | t.Log(os.Getenv("")) 15 | filename := gopath + "/../example/store/user.go" 16 | src := `package store 17 | import ( 18 | // "database/sql" 19 | "github.com/omigo/light/example/model" 20 | ) 21 | var User IUser 22 | type IUser interface { 23 | // insert ignore into users(username, phone, address, status, birth_day, created, updated) 24 | // values (?,?,?,?,?,CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) 25 | // Insert(tx *sql.Tx, u *model.User) (int64, error) 26 | 27 | // UPDATE users 28 | // SET [username=?,] 29 | // [phone=?,] 30 | // [address=?,] 31 | // [status=?,] 32 | // [birth_day=?,] 33 | // updated=CURRENT_TIMESTAMP 34 | // WHERE id=? 35 | // Update(u *model.User) (int64, error) 36 | 37 | // select id, username, phone, address, status, birth_day, created, updated 38 | // FROM users WHERE id=? 
39 | Get(id uint64) (*model.User, error) 40 | } 41 | ` 42 | 43 | itf, err := Parse(filename, src) 44 | if err != nil { 45 | t.Fatal(err) 46 | } 47 | log.JsonIndent(itf) 48 | } 49 | -------------------------------------------------------------------------------- /null/null.go: -------------------------------------------------------------------------------- 1 | package null 2 | 3 | import ( 4 | "database/sql" 5 | "database/sql/driver" 6 | "time" 7 | ) 8 | 9 | func String(v *string) ValueScanner { return &NullString{String_: v} } 10 | func Uint8(v *uint8) ValueScanner { return &NullUint8{Uint8: v} } 11 | func Byte(v *byte) ValueScanner { return &NullUint8{Uint8: v} } 12 | func Int8(v *int8) ValueScanner { return &NullInt8{Int8: v} } 13 | func Uint16(v *uint16) ValueScanner { return &NullUint16{Uint16: v} } 14 | func Int16(v *int16) ValueScanner { return &NullInt16{Int16: v} } 15 | func Uint32(v *uint32) ValueScanner { return &NullUint32{Uint32: v} } 16 | func Int32(v *int32) ValueScanner { return &NullInt32{Int32: v} } 17 | func Rune(v *rune) ValueScanner { return &NullInt32{Int32: v} } 18 | func Int(v *int) ValueScanner { return &NullInt{Int: v} } 19 | func Uint64(v *uint64) ValueScanner { return &NullUint64{Uint64: v} } 20 | func Int64(v *int64) ValueScanner { return &NullInt64{Int64: v} } 21 | func Float32(v *float32) ValueScanner { return &NullFloat32{Float32: v} } 22 | func Float64(v *float64) ValueScanner { return &NullFloat64{Float64: v} } 23 | func Time(v *time.Time) ValueScanner { return &NullTime{Time: v} } 24 | func Bool(v *bool) ValueScanner { return &NullBool{Bool: v} } 25 | 26 | type ValueScanner interface { 27 | driver.Valuer 28 | sql.Scanner 29 | } 30 | -------------------------------------------------------------------------------- /null/bool.go: -------------------------------------------------------------------------------- 1 | package null 2 | 3 | import ( 4 | "database/sql/driver" 5 | "fmt" 6 | "reflect" 7 | ) 8 | 9 | type NullBool struct { 10 | 
Bool *bool 11 | } 12 | 13 | // IsEmpty reports whether n is unset or holds false, the zero value (as the other null types treat their zero values). 14 | func (n *NullBool) IsEmpty() bool { 15 | return n.Bool == nil || !*n.Bool 16 | } 17 | 18 | // MarshalJSON encodes n as JSON true/false, or null when unset. 19 | func (n *NullBool) MarshalJSON() ([]byte, error) { 20 | if n.Bool == nil { 21 | return []byte("null"), nil 22 | } 23 | return []byte(fmt.Sprintf("%t", *n.Bool)), nil 24 | } 25 | 26 | // UnmarshalJSON accepts true, false and null. null (or empty input) leaves n 27 | // unchanged; a nil Bool pointer is allocated before it is written through. 28 | func (n *NullBool) UnmarshalJSON(data []byte) error { 29 | switch s := string(data); s { 30 | case "", "null": 31 | return nil 32 | case "true", "false": 33 | if n.Bool == nil { 34 | n.Bool = new(bool) 35 | } 36 | *n.Bool = s == "true" 37 | return nil 38 | default: 39 | return fmt.Errorf("null: cannot unmarshal %s into NullBool", data) 40 | } 41 | } 42 | 43 | // String returns "nil" when n is unset, otherwise "true" or "false". 44 | func (n *NullBool) String() string { 45 | if n.Bool == nil { 46 | return "nil" 47 | } 48 | if *n.Bool { 49 | return "true" 50 | } 51 | return "false" 52 | } 53 | 54 | // Scan implements the Scanner interface. A NULL column leaves n unchanged; a 55 | // nil Bool pointer is allocated before it is written through. 56 | func (s *NullBool) Scan(value interface{}) error { 57 | if value == nil { 58 | return nil 59 | } 60 | if s.Bool == nil { 61 | s.Bool = new(bool) 62 | } 63 | switch v := value.(type) { 64 | case bool: 65 | *s.Bool = v 66 | case int64: 67 | *s.Bool = v == 1 68 | case *int64: 69 | *s.Bool = *v == 1 70 | default: 71 | panic("unsupported type " + reflect.TypeOf(v).String()) 72 | } 73 | return nil 74 | } 75 | 76 | // Value implements the driver Valuer interface. It returns int64, not int: 77 | // int is not one of the driver.Value types, so database/sql would reject it.
78 | func (s NullBool) Value() (driver.Value, error) { 79 | if s.Bool == nil { 80 | return nil, nil 81 | } 82 | if *s.Bool { 83 | return int64(1), nil 84 | } 85 | return int64(0), nil 86 | } 87 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "io/ioutil" 7 | "os" 8 | "path/filepath" 9 | "strings" 10 | 11 | "github.com/omigo/light/generator" 12 | "github.com/omigo/light/goparser" 13 | "github.com/omigo/log" 14 | "golang.org/x/tools/imports" 15 | ) 16 | 17 | var ( 18 | withLog = flag.Bool("log", false, "Generated file with log") 19 | timeout = flag.Int64("timeout", 10, "Timeout(s) of SQL execution canceled context") 20 | ) 21 | 22 | // main parses the interface declared in the source file, generates its 23 | // implementation and writes it, run through goimports, next to the source as *.light.go. 24 | func main() { 25 | flag.Parse() 26 | 27 | src := getSourceFile() 28 | fmt.Printf("Source file %s\n", src) 29 | // TrimSuffix rather than slicing off 3 bytes: a name without a .go 30 | // extension must not yield a mangled destination path. 31 | dst := strings.TrimSuffix(src, ".go") + ".light.go" 32 | // TODO must remove all *.light.go files 33 | os.Remove(dst) // error ignored: dst may not exist yet 34 | 35 | store, err := goparser.Parse(src, nil) 36 | if err != nil { 37 | log.Fatal(err) 38 | } 39 | // log.JSONIndent(store) 40 | store.Log = *withLog 41 | store.Timeout = *timeout 42 | 43 | content := generator.Generate(store) 44 | 45 | err = ioutil.WriteFile(dst, content, 0666) 46 | log.Fataln(err) 47 | fmt.Printf("Generated file %s\n", dst) 48 | 49 | pretty, err := imports.Process(dst, content, nil) 50 | log.Fataln(err) 51 | err = ioutil.WriteFile(dst, pretty, 0666) 52 | log.Fataln(err) 53 | } 54 | 55 | // getSourceFile returns the absolute path of the file to process: the first 56 | // argument, or $GOFILE (set by go generate). It exits if both are empty. 57 | func getSourceFile() string { 58 | var src string 59 | if len(flag.Args()) > 0 { 60 | src = flag.Arg(0) 61 | } else { 62 | src = os.Getenv("GOFILE") 63 | } 64 | if src == "" { 65 | fmt.Println("source file must not blank") 66 | os.Exit(1) 67 | } 68 | // filepath.IsAbs/Join work on every OS; a leading-'/' check does not. 69 | if !filepath.IsAbs(src) { 70 | wd, err := os.Getwd() 71 | log.Fataln(err) 72 | src = filepath.Join(wd, src) 73 | } 74 | return src 75 | } 76 |
-------------------------------------------------------------------------------- /null/timestamp.go: -------------------------------------------------------------------------------- 1 | package null 2 | 3 | import ( 4 | "database/sql/driver" 5 | "reflect" 6 | "strconv" 7 | "time" 8 | ) 9 | 10 | // Timestamp is a nullable time encoded in JSON as Unix seconds. 11 | type Timestamp struct { 12 | Time *time.Time 13 | } 14 | 15 | // IsEmpty reports whether n is unset or holds the zero time. 16 | func (n *Timestamp) IsEmpty() bool { 17 | return n.Time == nil || n.Time.IsZero() 18 | } 19 | 20 | // MarshalJSON encodes n as Unix seconds; an empty value is encoded as 0. 21 | func (n *Timestamp) MarshalJSON() ([]byte, error) { 22 | if n.Time == nil || n.Time.IsZero() { 23 | return []byte("0"), nil 24 | } 25 | return []byte(strconv.FormatInt(n.Time.Unix(), 10)), nil 26 | } 27 | 28 | // UnmarshalJSON decodes Unix seconds. null and 0 decode to the zero time, so 29 | // the 0 that MarshalJSON writes for an empty value reads back as empty. 30 | func (n *Timestamp) UnmarshalJSON(data []byte) error { 31 | if n.Time == nil { 32 | n.Time = new(time.Time) 33 | } 34 | if string(data) == "null" { 35 | *n.Time = time.Time{} 36 | return nil 37 | } 38 | ts, err := strconv.ParseInt(string(data), 10, 64) 39 | if err != nil { 40 | return err 41 | } 42 | if ts == 0 { 43 | *n.Time = time.Time{} 44 | return nil 45 | } 46 | *n.Time = time.Unix(ts, 0) 47 | return nil 48 | } 49 | 50 | // String returns Unix seconds as decimal text, or "0" when n is empty. 51 | func (n *Timestamp) String() string { 52 | if n.Time == nil || n.Time.IsZero() { 53 | return "0" 54 | } 55 | return strconv.FormatInt(n.Time.Unix(), 10) 56 | } 57 | 58 | // Scan implements the Scanner interface for time.Time values; NULL leaves n unchanged. 59 | func (n *Timestamp) Scan(value interface{}) error { 60 | if value == nil { 61 | return nil 62 | } 63 | 64 | if n.Time == nil { 65 | n.Time = new(time.Time) 66 | } 67 | 68 | switch v := value.(type) { 69 | case time.Time: 70 | *n.Time = v 71 | 72 | case *time.Time: 73 | *n.Time = *v 74 | 75 | default: 76 | panic("unsupported type " + reflect.TypeOf(v).String()) 77 | } 78 | return nil 79 | } 80 | 81 | // Value implements the driver Valuer interface; an empty value is stored as NULL. 82 | func (n Timestamp) Value() (driver.Value, error) { 83 | if n.Time == nil { 84 | return nil, nil 85 | } 86 | 87 | if n.Time.IsZero() { 88 | return nil, nil 89 | } 90 | 91 | return *n.Time, nil 92 | } 93 | -------------------------------------------------------------------------------- /null/string.go: -------------------------------------------------------------------------------- 1 | package null 2 | 3 | import ( 4 | "database/sql/driver" 5 | "encoding/json" 6 | "reflect" 7 | ) 8 | 9 | // NullString represents a string that may be null.
10 | // NullString implements the Scanner interface so 11 | // it can be used as a scan destination: 12 | // 13 | // var plain string 14 | // err := db.QueryRow("SELECT name FROM foo WHERE id=?", id).Scan(null.String(&plain)) 15 | // ... 16 | // use plain; if the database returns NULL, Scan leaves plain untouched (blank if freshly declared) 17 | type NullString struct { 18 | String_ *string 19 | } 20 | 21 | // IsEmpty reports whether n is unset or holds the empty string. 22 | func (n *NullString) IsEmpty() bool { 23 | return n.String_ == nil || *n.String_ == "" 24 | } 25 | 26 | // MarshalJSON encodes n as a properly escaped JSON string, or null when unset. 27 | func (n *NullString) MarshalJSON() ([]byte, error) { 28 | if n.String_ == nil { 29 | return []byte("null"), nil 30 | } 31 | return json.Marshal(*n.String_) 32 | } 33 | 34 | // UnmarshalJSON decodes a JSON string into n, allocating String_ if needed. 35 | // null (or empty input) leaves n unchanged. 36 | func (n *NullString) UnmarshalJSON(data []byte) error { 37 | if len(data) == 0 || string(data) == "null" { 38 | return nil 39 | } 40 | var s string 41 | if err := json.Unmarshal(data, &s); err != nil { 42 | return err 43 | } 44 | if n.String_ == nil { 45 | n.String_ = new(string) 46 | } 47 | *n.String_ = s 48 | return nil 49 | } 50 | 51 | // String returns the value, or "nil" when n is unset or empty. 52 | func (n *NullString) String() string { 53 | if n.String_ == nil || *n.String_ == "" { 54 | return "nil" 55 | } 56 | return *n.String_ 57 | } 58 | 59 | // Scan implements the Scanner interface. NULL leaves n unchanged; a nil 60 | // String_ pointer is allocated before it is written through. 61 | func (s *NullString) Scan(value interface{}) error { 62 | if value == nil { 63 | return nil 64 | } 65 | if s.String_ == nil { 66 | s.String_ = new(string) 67 | } 68 | switch v := value.(type) { 69 | case string: 70 | *s.String_ = v 71 | case []byte: 72 | *s.String_ = string(v) 73 | case *[]byte: 74 | *s.String_ = string(*v) 75 | default: 76 | panic("unsupported type " + reflect.TypeOf(v).String()) 77 | } 78 | return nil 79 | } 80 | 81 | // Value implements the driver Valuer interface; an empty string is stored as NULL.
82 | func (s NullString) Value() (driver.Value, error) { 83 | if s.String_ == nil { 84 | return nil, nil 85 | } 86 | if *s.String_ == "" { 87 | return nil, nil 88 | } 89 | return *s.String_, nil 90 | } 91 | -------------------------------------------------------------------------------- /null/time.go: -------------------------------------------------------------------------------- 1 | package null 2 | 3 | import ( 4 | "bytes" 5 | "database/sql/driver" 6 | "reflect" 7 | "time" 8 | ) 9 | 10 | const ( 11 | formatDate = `"2006-01-02"` 12 | formatDatetime = `"2006-01-02 15:04:05"` 13 | ) 14 | 15 | type NullTime struct { 16 | Time *time.Time 17 | } 18 | 19 | func (n *NullTime) IsEmpty() bool { 20 | return n.Time == nil || n.Time.IsZero() 21 | } 22 | 23 | func (n *NullTime) MarshalJSON() ([]byte, error) { 24 | if n.Time == nil || n.Time.IsZero() { 25 | return []byte("null"), nil 26 | } 27 | if n.Time.Hour() == 0 && n.Time.Minute() == 0 && n.Time.Second() == 0 { 28 | return []byte(n.Time.Format(formatDate)), nil 29 | } 30 | return []byte(n.Time.Format(formatDatetime)), nil 31 | } 32 | 33 | func (n *NullTime) UnmarshalJSON(data []byte) (err error) { 34 | if n.Time == nil { 35 | n.Time = new(time.Time) 36 | } 37 | if bytes.EqualFold(data, []byte("null")) || bytes.HasPrefix(data, []byte(`"0000-00-00`)) { 38 | // null and the all-zero date decode to the zero time, which MarshalJSON writes back as null. 39 | *n.Time = time.Time{} 40 | return nil 41 | } 42 | if len(data) == len(formatDate) { 43 | *n.Time, err = time.ParseInLocation(formatDate, string(data), time.Local) 44 | } else { 45 | *n.Time, err = time.ParseInLocation(formatDatetime, string(data), time.Local) 46 | } 47 | return err 48 | } 49 | 50 | func (n *NullTime) String() string { 51 | if n.Time == nil { 52 | return "nil" 53 | } 54 | if n.Time.IsZero() { 55 | return "nil" 56 | } 57 | 58 | return n.Time.Format("2006-01-02 15:04:05.999") 59 | } 60 | 61 | func (n *NullTime) Scan(value interface{}) error { 62 | if value == nil { 63 | return nil 64 | } 65 | 66 | if n.Time == nil { 67 | n.Time = new(time.Time) 68 | } 69 | 70 | switch v :=
value.(type) { 71 | case time.Time: 72 | *n.Time = v 73 | 74 | case *time.Time: 75 | *n.Time = *v 76 | 77 | default: 78 | panic("unsupported type " + reflect.TypeOf(v).String()) 79 | } 80 | return nil 81 | } 82 | func (n NullTime) Value() (driver.Value, error) { 83 | if n.Time == nil { 84 | return nil, nil 85 | } 86 | 87 | if n.Time.IsZero() { 88 | return nil, nil 89 | } 90 | 91 | return *n.Time, nil 92 | } 93 | -------------------------------------------------------------------------------- /null/clickhousetime.go: -------------------------------------------------------------------------------- 1 | package null 2 | 3 | import ( 4 | "bytes" 5 | "database/sql/driver" 6 | "reflect" 7 | "time" 8 | ) 9 | 10 | var zero = time.Unix(0, 0) 11 | 12 | type ClickHouseTime struct { 13 | Time *time.Time 14 | } 15 | 16 | func (n *ClickHouseTime) IsEmpty() bool { 17 | return n.Time == nil || n.Time.IsZero() 18 | } 19 | 20 | func (n *ClickHouseTime) MarshalJSON() ([]byte, error) { 21 | if n.Time == nil || n.Time.IsZero() { 22 | return []byte(zero.Format(formatDatetime)), nil 23 | } 24 | if n.Time.Hour() == 0 && n.Time.Minute() == 0 && n.Time.Second() == 0 { 25 | return []byte(n.Time.Format(formatDate)), nil 26 | } 27 | return []byte(n.Time.Format(formatDatetime)), nil 28 | } 29 | 30 | func (n *ClickHouseTime) UnmarshalJSON(data []byte) (err error) { 31 | if n.Time == nil { 32 | n.Time = new(time.Time) 33 | } 34 | if bytes.EqualFold(data, []byte("null")) || bytes.HasPrefix(data, []byte(`"0000-00-00`)) { 35 | *n.Time = zero 36 | return nil 37 | } 38 | if len(data) == len(formatDate) { 39 | *n.Time, err = time.ParseInLocation(formatDate, string(data), time.Local) 40 | } else { 41 | *n.Time, err = time.ParseInLocation(formatDatetime, string(data), time.Local) 42 | } 43 | return err 44 | } 45 | 46 | func (n *ClickHouseTime) String() string { 47 | if n.Time == nil { 48 | return "1970-01-01 08:00:00" 49 | } 50 | if n.Time.IsZero() { 51 | return "1970-01-01 08:00:00" 52 | } 53 | 54 | 
return n.Time.Format("2006-01-02 15:04:05") 55 | } 56 | 57 | func (n *ClickHouseTime) Scan(value interface{}) error { 58 | if value == nil { 59 | return nil 60 | } 61 | 62 | if n.Time == nil { 63 | n.Time = new(time.Time) 64 | } 65 | 66 | switch v := value.(type) { 67 | case time.Time: 68 | *n.Time = v 69 | 70 | case *time.Time: 71 | *n.Time = *v 72 | 73 | default: 74 | panic("unsupported type " + reflect.TypeOf(v).String()) 75 | } 76 | return nil 77 | } 78 | func (n ClickHouseTime) Value() (driver.Value, error) { 79 | if n.Time == nil { 80 | return nil, nil 81 | } 82 | 83 | if n.Time.IsZero() { 84 | return nil, nil 85 | } 86 | 87 | return *n.Time, nil 88 | } 89 | -------------------------------------------------------------------------------- /generator/generate.go: -------------------------------------------------------------------------------- 1 | package generator 2 | 3 | import ( 4 | "bytes" 5 | "os" 6 | "path/filepath" 7 | "strings" 8 | "text/template" 9 | 10 | "github.com/omigo/light/goparser" 11 | "github.com/omigo/light/sqlparser" 12 | "github.com/omigo/log" 13 | ) 14 | 15 | func getGomodPath(path string) string { 16 | for { 17 | path = filepath.Dir(path) 18 | if path == "" || path == "/" || path == "." 
{ 19 | return path 20 | } 21 | if fileInfo, err := os.Stat(path + "/go.mod"); err != nil { 22 | if os.IsExist(err) { 23 | return path 24 | } 25 | } else if !fileInfo.IsDir() { 26 | return path 27 | } 28 | } 29 | } 30 | 31 | func Generate(itf *goparser.Interface) []byte { 32 | path := getGomodPath(itf.Source) 33 | 34 | if i := strings.LastIndex(path, string(filepath.Separator)); i > 0 { 35 | path = path[:i] 36 | 37 | if strings.HasPrefix(itf.Source, path) { 38 | itf.Source = itf.Source[len(path)+1:] 39 | } 40 | } 41 | 42 | for k, v := range itf.Imports { 43 | if i := strings.Index(k, "/vendor/"); i > 0 { 44 | delete(itf.Imports, k) 45 | itf.Imports[k[i+8:]] = v 46 | } 47 | } 48 | 49 | for _, m := range itf.Methods { 50 | p := sqlparser.NewParser(bytes.NewBufferString(m.Doc)) 51 | stmt, err := p.Parse() 52 | if err != nil { 53 | panic(err) 54 | } 55 | m.Statement = stmt 56 | // log.JSONIndent(stmt) 57 | 58 | m.GenCondition() 59 | m.SetType() 60 | 61 | m.SetSignature() 62 | } 63 | 64 | var t *template.Template 65 | t = template.New("tpl") 66 | t.Funcs(template.FuncMap{ 67 | "sub": func(a, b int) int { return a - b }, 68 | "aggregate": func(m *goparser.Method, v *sqlparser.Fragment, buf, args string) *Aggregate { 69 | return &Aggregate{Method: m, Fragment: v, Buf: buf, Args: args} 70 | }, 71 | "MethodTx": goparser.MethodTx, 72 | "HasVariable": goparser.HasVariable, 73 | "ParamsLast": goparser.ParamsLast, 74 | "ParamsLastElem": goparser.ParamsLastElem, 75 | "ResultWrap": goparser.ResultWrap, 76 | "ResultTypeName": goparser.ResultTypeName, 77 | "ResultElemTypeName": goparser.ResultElemTypeName, 78 | "LookupScanOfResults": goparser.LookupScanOfResults, 79 | "LookupValueOfParams": goparser.LookupValueOfParams, 80 | }) 81 | log.Fataln(t.Parse(tpl)) 82 | buf := bytes.NewBuffer(make([]byte, 0, 1024*16)) 83 | log.Fataln(t.Execute(buf, itf)) 84 | return buf.Bytes() 85 | } 86 | 87 | type Aggregate struct { 88 | Method *goparser.Method 89 | Fragment *sqlparser.Fragment 90 | Buf 
string 91 | Args string 92 | } 93 | -------------------------------------------------------------------------------- /sqlparser/select.go: -------------------------------------------------------------------------------- 1 | package sqlparser 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "strings" 7 | ) 8 | 9 | // Parse parses a SQL SELECT statement. 10 | func (p *Parser) ParseSelect() (*Statement, error) { 11 | stmt := Statement{Type: SELECT} 12 | first := &Fragment{} 13 | var buf bytes.Buffer 14 | 15 | // First token should be a "SELECT" keyword. 16 | if tok, lit := p.scanIgnoreWhitespace(); tok != SELECT { 17 | return nil, fmt.Errorf("found %q, expected SELECT", lit) 18 | } 19 | buf.WriteString("SELECT ") 20 | 21 | // Next we should loop over all our comma-delimited fields. 22 | for { 23 | // Read a field. 24 | tok, lit, field, f := p.scanSelectField() 25 | if tok != IDENT && tok != ASTERISK { 26 | return nil, fmt.Errorf("found %q, expected field", lit) 27 | } 28 | 29 | stmt.Fields = append(stmt.Fields, field) 30 | first.Replacers = append(first.Replacers, f.Replacers...) 31 | first.Variables = append(first.Variables, f.Variables...) 32 | buf.WriteString(lit) 33 | 34 | // If the next token is not a comma then break the loop. 35 | if tok, _ = p.scanIgnoreWhitespace(); tok != COMMA { 36 | p.unscan() 37 | break 38 | } 39 | 40 | buf.WriteString(", ") 41 | } 42 | 43 | // First token should be a "FROM" keyword. 44 | if tok, lit := p.scanIgnoreWhitespace(); tok != FROM { 45 | return nil, fmt.Errorf("found %q, expected FROM", lit) 46 | } 47 | p.unscan() 48 | 49 | first.Statement = buf.String() 50 | stmt.Fragments = append(stmt.Fragments, first) 51 | stmt.Fragments = append(stmt.Fragments, p.scanFragments()...) 
52 | 53 | return &stmt, nil 54 | } 55 | 56 | func (p *Parser) scanSelectField() (tok Token, lit, field string, f *Fragment) { 57 | var buf bytes.Buffer 58 | p.s.scanSpace() 59 | 60 | f = &Fragment{} 61 | 62 | var deep int 63 | for { 64 | tok, lit = p.scan() 65 | switch tok { 66 | case SPACE: 67 | buf.WriteByte(' ') 68 | 69 | case LPAREN: 70 | deep++ 71 | buf.WriteString(LPAREN.String()) 72 | 73 | case RPAREN: 74 | deep-- 75 | buf.WriteString(RPAREN.String()) 76 | 77 | case COMMA, FROM, EOF: 78 | if deep == 0 { 79 | p.unscan() 80 | return IDENT, strings.TrimSpace(buf.String()), field, f 81 | } 82 | buf.WriteString(lit) 83 | 84 | case REPLACER: 85 | buf.WriteString("%s") 86 | f.Replacers = append(f.Replacers, lit) 87 | 88 | case VARIABLE: 89 | buf.WriteString("?") 90 | f.Variables = append(f.Variables, lit) 91 | 92 | case BACKQUOTE: 93 | p.unscan() 94 | _, lit = p.s.scanBackQuoteIdent() 95 | buf.WriteString(lit) 96 | field = lit 97 | 98 | default: 99 | buf.WriteString(lit) 100 | field = lit 101 | } 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /sqlparser/replace.go: -------------------------------------------------------------------------------- 1 | package sqlparser 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "strings" 7 | ) 8 | 9 | // Parse parses a SQL REPLACE statement. 10 | func (p *Parser) ParseReplace() (*Statement, error) { 11 | stmt := Statement{Type: REPLACE} 12 | f := Fragment{} 13 | stmt.Fragments = append(stmt.Fragments, &f) 14 | 15 | var buf bytes.Buffer 16 | // First token should be a "INSERT" keyword. 
17 | if tok, lit := p.scanIgnoreWhitespace(); tok != REPLACE { 18 | return nil, fmt.Errorf("found %q, expected REPLACE", lit) 19 | } 20 | if tok, lit := p.scanIgnoreWhitespace(); tok != INTO { 21 | return nil, fmt.Errorf("found %q, expected INTO", lit) 22 | } 23 | buf.WriteString("REPLACE INTO ") 24 | 25 | // table name 26 | for { 27 | if tok, lit := p.scanIgnoreWhitespace(); tok == IDENT { 28 | buf.WriteString(lit) 29 | } else if tok == DOT { 30 | buf.WriteString(DOT.String()) 31 | } else if tok == POUND { 32 | p.unscan() 33 | v := p.scanReplacer() 34 | f.Replacers = append(f.Replacers, v) 35 | buf.WriteString("%v") 36 | } else { 37 | return nil, fmt.Errorf("found %q, expected IDENT, at `%s`", lit, buf.String()) 38 | } 39 | if tok, _ := p.scanIgnoreWhitespace(); tok != LPAREN { 40 | p.unscan() 41 | } else { 42 | buf.WriteByte('(') 43 | break 44 | } 45 | } 46 | 47 | for { 48 | if tok, lit := p.scanIgnoreWhitespace(); tok != IDENT { 49 | return nil, fmt.Errorf("found %q, expected IDENT", lit) 50 | } else { 51 | buf.WriteString(lit) 52 | stmt.Fields = append(stmt.Fields, lit) 53 | } 54 | 55 | if tok, lit := p.scanIgnoreWhitespace(); tok == COMMA { 56 | buf.WriteByte(',') 57 | } else if tok == RPAREN { 58 | buf.WriteByte(')') 59 | break 60 | } else { 61 | return nil, fmt.Errorf("found %q, expected `,` or `)`", lit) 62 | } 63 | } 64 | if tok, lit := p.scanIgnoreWhitespace(); tok != VALUES { 65 | return nil, fmt.Errorf("found %q, expected `VALUES`", lit) 66 | } 67 | buf.WriteString(" VALUES ") 68 | if tok, lit := p.scanIgnoreWhitespace(); tok != LPAREN { 69 | return nil, fmt.Errorf("found %q, expected `(`", lit) 70 | } 71 | buf.WriteByte('(') 72 | 73 | // values 74 | for i := 0; ; i++ { 75 | tok, lit := p.scanIgnoreWhitespace() 76 | if tok == QUESTION { 77 | f.Variables = append(f.Variables, stmt.Fields[i]) 78 | buf.WriteByte('?') 79 | } else if tok == DOLLAR { 80 | p.unscan() 81 | v := p.scanVariable() 82 | f.Variables = append(f.Variables, v) 83 | buf.WriteByte('?') 
84 | } else { 85 | buf.WriteString(lit) 86 | } 87 | 88 | if tok, lit := p.scanIgnoreWhitespace(); tok == COMMA { 89 | buf.WriteByte(',') 90 | } else if tok == RPAREN { 91 | buf.WriteByte(')') 92 | break 93 | } else { 94 | return nil, fmt.Errorf("found %q, expected `,` or `)`", lit) 95 | } 96 | } 97 | 98 | f.Statement = strings.TrimSpace(buf.String()) 99 | return &stmt, nil 100 | } 101 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | light [![Build Status](https://api.travis-ci.org/omigo/light.svg?branch=master)](https://api.travis-ci.org/omigo/light.svg?branch=master) 2 | ===== 3 | 4 | 5 | 6 | `light` is a tool for generating database query code from a Go source file whose 7 | interface methods are commented with SQL statements and variables. 8 | 9 | `Interface methods commented with SQL and variables` => `go generate` => `Database query code implementation` 10 | 11 | ![light.jpg](light.jpg) 12 | 13 | ### Usage 14 | 15 | Install the `light` tool. Make sure $GOBIN is in your $PATH environment variable. 16 | 17 | `go get -u -v github.com/omigo/light` 18 | 19 | Run `light -h` to check the installation. 20 | 21 | # light -h 22 | Usage of light: 23 | -log 24 | Generated file with log 25 | 26 | Define an interface, comment its methods with SQL statements and variables, then add the directive `//go:generate light`. 27 | 28 | ```go 29 | //go:generate light -log -timeout 30 30 | 31 | type User interface { 32 | // UPDATE users 33 | // SET [username=?,] 34 | // [phone=?,] 35 | // [address=?,] 36 | // [status=?,] 37 | // [birthday=?,] 38 | // updated=CURRENT_TIMESTAMP 39 | // WHERE id=? 40 | Update(u *model.User) (int64, error) 41 | } 42 | ``` 43 | 44 | After that, run `go generate ./...` to generate the code. 45 | 46 | # go generate ./...
47 | Source file /Users/Arstd/Reposits/src/github.com/omigo/light/example/store/user.go 48 | Generated file /Users/Arstd/Reposits/src/github.com/omigo/light/example/store/user.light.go 49 | 50 | ```go 51 | 52 | type UserStore struct{} 53 | 54 | func (*UserStore) Update(u *model.User) (int64, error) { 55 | var buf bytes.Buffer 56 | var args []interface{} 57 | buf.WriteString(`UPDATE users SET `) 58 | if u.Username != "" { 59 | buf.WriteString(`username=?, `) 60 | args = append(args, u.Username) 61 | } 62 | if u.Phone != "" { 63 | buf.WriteString(`phone=?, `) 64 | args = append(args, null.String(&u.Phone)) 65 | } 66 | if u.Address != nil { 67 | buf.WriteString(`address=?, `) 68 | args = append(args, u.Address) 69 | } 70 | if u.Status != 0 { 71 | buf.WriteString(`status=?, `) 72 | args = append(args, null.Uint8(&u.Status)) 73 | } 74 | if u.Birthday != nil { 75 | buf.WriteString(`birthday=?, `) 76 | args = append(args, u.Birthday) 77 | } 78 | buf.WriteString(`updated=CURRENT_TIMESTAMP WHERE id=? `) 79 | args = append(args, null.Uint64(&u.Id)) 80 | query := buf.String() 81 | log.Debug(query) 82 | log.Debug(args...) 83 | res, err := db.Exec(query, args...) 84 | if err != nil { 85 | log.Error(query) 86 | log.Error(args...) 87 | log.Error(err) 88 | return 0, err 89 | } 90 | return res.RowsAffected() 91 | } 92 | ``` 93 | 94 | ### More 95 | 96 | Complete demo in example. 
97 | 98 | Source interface: [example/store/user.go](example/store/user.go) 99 | 100 | Generated code: [example/store/user.light.go](example/store/user.light.go) 101 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/DATA-DOG/go-sqlmock v1.3.3 h1:CWUqKXe0s8A2z6qCgkP4Kru7wC11YoAnoupUKFDnH08= 2 | github.com/DATA-DOG/go-sqlmock v1.3.3/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= 3 | github.com/omigo/log v0.1.1 h1:8DEDcN/uYCVoSH/wjM1NzgNFfUDsEsb2Vy3eiR+/1Fo= 4 | github.com/omigo/log v0.1.1/go.mod h1:3WjW3nPtshrS76Ws1iQMAgR/q+TP8ZXCjpZkCnHIDDQ= 5 | github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= 6 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 7 | golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= 8 | golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= 9 | golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4= 10 | golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= 11 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 12 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 13 | golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= 14 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 15 | golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 16 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 17 | golang.org/x/sys 
v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 18 | golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 19 | golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4 h1:myAQVi0cGEoqQVR5POX+8RR2mrocKqNN1hmeMqhX27k= 20 | golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 21 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 22 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 23 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 24 | golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 25 | golang.org/x/tools v0.1.0 h1:po9/4sTYwZU9lPhi1tOrb4hCv3qrhiQ77LZfGa2OjwY= 26 | golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= 27 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 28 | golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 29 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= 30 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 31 | -------------------------------------------------------------------------------- /sqlparser/insert.go: -------------------------------------------------------------------------------- 1 | package sqlparser 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "strings" 7 | ) 8 | 9 | // Parse parses a SQL INSERT statement. 10 | func (p *Parser) ParseInsert() (*Statement, error) { 11 | stmt := Statement{Type: INSERT} 12 | 13 | f := Fragment{} 14 | var buf bytes.Buffer 15 | // First token should be a "INSERT" keyword. 
16 | if tok, lit := p.scanIgnoreWhitespace(); tok != INSERT { 17 | return nil, fmt.Errorf("found %q, expect INSERT", lit) 18 | } 19 | buf.WriteString("INSERT ") 20 | tok, lit := p.scanIgnoreWhitespace() 21 | switch tok { 22 | case IGNORE: 23 | buf.WriteString("IGNORE ") 24 | if tok, lit = p.scanIgnoreWhitespace(); tok != INTO { 25 | return nil, fmt.Errorf("found %q, expect INTO", lit) 26 | } 27 | buf.WriteString("INTO ") 28 | 29 | case INTO: 30 | buf.WriteString("INTO ") 31 | 32 | default: 33 | return nil, fmt.Errorf("found %q, expect IGNORE or INTO", lit) 34 | } 35 | 36 | // table name 37 | for { 38 | if tok, lit := p.scanIgnoreWhitespace(); tok == IDENT { 39 | buf.WriteString(lit) 40 | } else if tok == DOT { 41 | buf.WriteString(DOT.String()) 42 | } else if tok == REPLACER { 43 | f.Replacers = append(f.Replacers, lit) 44 | buf.WriteString("%v") 45 | } else { 46 | return nil, fmt.Errorf("found %q, expect IDENT, at `%s`", lit, buf.String()) 47 | } 48 | if tok, _ := p.scanIgnoreWhitespace(); tok != LPAREN { 49 | p.unscan() 50 | } else { 51 | buf.WriteByte('(') 52 | break 53 | } 54 | } 55 | 56 | for { 57 | if tok, lit := p.scanIgnoreWhitespace(); tok != IDENT { 58 | return nil, fmt.Errorf("found %q, expect IDENT", lit) 59 | } else { 60 | buf.WriteString(lit) 61 | stmt.Fields = append(stmt.Fields, lit) 62 | } 63 | 64 | if tok, lit := p.scanIgnoreWhitespace(); tok == COMMA { 65 | buf.WriteByte(',') 66 | } else if tok == RPAREN { 67 | buf.WriteByte(')') 68 | break 69 | } else { 70 | return nil, fmt.Errorf("found %q, expect `,` or `)`", lit) 71 | } 72 | } 73 | if tok, lit := p.scanIgnoreWhitespace(); tok != VALUES { 74 | return nil, fmt.Errorf("found %q, expect `VALUES`", lit) 75 | } 76 | buf.WriteString(" VALUES ") 77 | if tok, lit := p.scanIgnoreWhitespace(); tok != LPAREN { 78 | return nil, fmt.Errorf("found %q, expect `(`", lit) 79 | } 80 | buf.WriteByte('(') 81 | 82 | // values 83 | for i := 0; ; i++ { 84 | tok, lit := p.scanIgnoreWhitespace() 85 | if tok == 
QUESTION { 86 | f.Variables = append(f.Variables, stmt.Fields[i]) 87 | buf.WriteByte('?') 88 | } else if tok == VARIABLE { 89 | f.Variables = append(f.Variables, lit) 90 | buf.WriteByte('?') 91 | } else { 92 | buf.WriteString(lit) 93 | } 94 | 95 | if tok, lit := p.scanIgnoreWhitespace(); tok == COMMA { 96 | buf.WriteByte(',') 97 | } else if tok == RPAREN { 98 | buf.WriteByte(')') 99 | break 100 | } else { 101 | return nil, fmt.Errorf("found %q, expect `,` or `)`", lit) 102 | } 103 | } 104 | 105 | fs := p.scanFragments() 106 | if len(fs) == 0 { 107 | fs = []*Fragment{&f} 108 | } else { 109 | buf.WriteByte(' ') 110 | buf.WriteString(fs[0].Statement) 111 | f.Statement = strings.TrimSpace(buf.String()) 112 | f.Replacers = append(f.Replacers, fs[0].Replacers...) 113 | f.Variables = append(f.Variables, fs[0].Variables...) 114 | f.Condition = fs[0].Condition 115 | f.Fragments = fs[0].Fragments 116 | fs[0] = &f 117 | } 118 | 119 | stmt.Fragments = fs 120 | 121 | return &stmt, nil 122 | } 123 | -------------------------------------------------------------------------------- /null/floats.go: -------------------------------------------------------------------------------- 1 | package null 2 | 3 | import ( 4 | "database/sql/driver" 5 | "log" 6 | "reflect" 7 | "strconv" 8 | ) 9 | 10 | type NullFloat32 struct{ Float32 *float32 } 11 | type NullFloat64 struct{ Float64 *float64 } 12 | 13 | func (n *NullFloat32) IsEmpty() bool { return isEmptyFloat(n.Float32) } 14 | func (n *NullFloat64) IsEmpty() bool { return isEmptyFloat(n.Float64) } 15 | 16 | func (n *NullFloat32) MarshalJSON() ([]byte, error) { return marshalJSONFloat(n.Float32) } 17 | func (n *NullFloat64) MarshalJSON() ([]byte, error) { return marshalJSONFloat(n.Float64) } 18 | 19 | func (n *NullFloat32) UnmarshalJSON(data []byte) error { return unmarshalJSONFloat(n.Float32, data) } 20 | func (n *NullFloat64) UnmarshalJSON(data []byte) error { return unmarshalJSONFloat(n.Float64, data) } 21 | 22 | func (n *NullFloat32) 
String() string { return floatToString(n.Float32) } 23 | func (n *NullFloat64) String() string { return floatToString(n.Float64) } 24 | 25 | func (n *NullFloat32) Scan(value interface{}) error { return scanFloat(n.Float32, value) } 26 | func (n *NullFloat64) Scan(value interface{}) error { return scanFloat(n.Float64, value) } 27 | 28 | func (n NullFloat32) Value() (driver.Value, error) { return valueFloat(n.Float32) } 29 | func (n NullFloat64) Value() (driver.Value, error) { return valueFloat(n.Float64) } 30 | 31 | func isEmptyFloat(ptr interface{}) bool { 32 | if ptr == nil { 33 | return true 34 | } 35 | return toFloat64(ptr) == 0 36 | } 37 | 38 | func marshalJSONFloat(ptr interface{}) ([]byte, error) { 39 | if ptr == nil { 40 | return []byte{'0'}, nil 41 | } 42 | f64 := toFloat64(ptr) 43 | return []byte(strconv.FormatFloat(f64, 'f', '2', 64)), nil 44 | } 45 | 46 | func unmarshalJSONFloat(ptr interface{}, data []byte) error { 47 | if data == nil { 48 | return nil 49 | } 50 | i64, err := strconv.ParseFloat(string(data), 10) 51 | if err != nil { 52 | return err 53 | } 54 | 55 | fromF64(ptr, i64) 56 | return nil 57 | } 58 | 59 | func floatToString(ptr interface{}) string { 60 | if ptr == nil { 61 | return "nil" 62 | } 63 | 64 | f64 := toFloat64(ptr) 65 | 66 | if f64 == 0 { 67 | return "nil" 68 | } 69 | 70 | return strconv.FormatFloat(f64, 'e', 2, 32) 71 | } 72 | 73 | func valueFloat(ptr interface{}) (driver.Value, error) { 74 | if ptr == nil { 75 | return nil, nil 76 | } 77 | 78 | f64 := toFloat64(ptr) 79 | if f64 == 0 { 80 | return nil, nil 81 | } 82 | return f64, nil 83 | } 84 | 85 | func toFloat64(ptr interface{}) (f64 float64) { 86 | switch v := ptr.(type) { 87 | case *float32: 88 | f64 = float64(*v) 89 | case *float64: 90 | f64 = float64(*v) 91 | default: 92 | panic("unsupported type " + reflect.TypeOf(v).String()) 93 | } 94 | return 95 | } 96 | 97 | func scanFloat(ptr, value interface{}) error { 98 | if value == nil { 99 | return nil 100 | } 101 | 102 | var f64 
float64 103 | switch v := value.(type) { 104 | case float64: 105 | f64 = v 106 | case *float64: 107 | f64 = *v 108 | case []byte: 109 | bitSize := 64 110 | if _, ok := ptr.(*float32); ok { 111 | bitSize = 32 112 | } 113 | var err error 114 | f64, err = strconv.ParseFloat(string(v), bitSize) 115 | if err != nil { 116 | log.Print(err) 117 | } 118 | default: 119 | panic("unsupported type " + reflect.TypeOf(v).String()) 120 | } 121 | 122 | fromF64(ptr, f64) 123 | 124 | return nil 125 | } 126 | 127 | func fromF64(ptr interface{}, f64 float64) { 128 | switch v := ptr.(type) { 129 | case *float32: 130 | *v = float32(f64) 131 | case *float64: 132 | *v = float64(f64) 133 | default: 134 | panic("unsupported type " + reflect.TypeOf(v).String()) 135 | } 136 | } 137 | -------------------------------------------------------------------------------- /example/store/user.go: -------------------------------------------------------------------------------- 1 | package store 2 | 3 | import ( 4 | "database/sql" 5 | "time" 6 | 7 | "github.com/omigo/light/example/model" 8 | ) 9 | 10 | //go:generate light -log -timeout 30 11 | 12 | var User IUser 13 | 14 | type IUser interface { 15 | 16 | // CREATE TABLE if NOT EXISTS #{name} ( 17 | // id BIGINT UNSIGNED AUTO_INCREMENT PRIMARY KEY, 18 | // username VARCHAR(32) NOT NULL UNIQUE, 19 | // Phone VARCHAR(32), 20 | // address VARCHAR(256), 21 | // _status TINYINT UNSIGNED, 22 | // birth_day DATE, 23 | // created TIMESTAMP default CURRENT_TIMESTAMP, 24 | // updated TIMESTAMP default CURRENT_TIMESTAMP 25 | // ) ENGINE=InnoDB DEFAULT CHARSET=utf8 26 | Create(name string) error 27 | 28 | // insert ignore into users(`username`, phone, address, _status, birth_day, created, updated) 29 | // values (${u.Username},?,?,?,?,CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) 30 | Insert(tx *sql.Tx, u *model.User) (a int64, b error) 31 | 32 | // insert ignore into users(`username`, phone, address, _status, birth_day, created, updated) 33 | // values 
(${u.Username},?,?,?,?,CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) 34 | Bulky(us []*model.User) (insertedRows int64, ignoreRows int64, err error) 35 | 36 | // insert into users(username, phone, address, _status, birth_day, created, updated) 37 | // values (?,?,?,?,?,CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) 38 | // on duplicate key update 39 | // username=values(username), phone=values(phone), address=values(address), 40 | // _status=values(_status), birth_day=values(birth_day), updated=CURRENT_TIMESTAMP 41 | Upsert(u *model.User, tx *sql.Tx) (int64, error) 42 | 43 | // replace into users(username, phone, address, _status, birth_day, created, updated) 44 | // values (?,?,?,?,?,CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) 45 | Replace(u *model.User) (int64, error) 46 | 47 | // UPDATE users 48 | // SET [username=?,] 49 | // [phone=?,] 50 | // [address=?,] 51 | // [_status=?,] 52 | // [birth_day=?,] 53 | // updated=CURRENT_TIMESTAMP 54 | // WHERE id=? 55 | Update(u *model.User) (int64, error) 56 | 57 | // DELETE FROM users WHERE id=? 58 | Delete(id uint64) (int64, error) 59 | 60 | // select id, username, mobile, address, _status, birth_day, created, updated 61 | // FROM users WHERE id=? 62 | Get(id uint64) (ret *model.User, e error) 63 | 64 | // select count(1) 65 | // from users 66 | // where birth_day < ? 67 | Count(birthDay time.Time) (int64, error) 68 | 69 | // select (select id from users where id=a.id) as id, 70 | // `username`, phone as phone, address, _status, birth_day, created, updated 71 | // from users a 72 | // where id != -1 and username <> 'admin' and username like ? 73 | // [ 74 | // and address = ? 75 | // [and phone like ?] 76 | // and created > ? 77 | // [{(u.BirthDay != nil && !u.BirthDay.IsZero()) || u.Id > 1 } 78 | // [and birth_day > ?] 79 | // [and id > ?] 80 | // ] 81 | // ] 82 | // and _status != ? 83 | // [and updated > ?] 
84 | // and birth_day is not null 85 | // order by updated desc 86 | // limit ${offset}, ${size} 87 | List(u *model.User, offset, size int) (us []*model.User, xxx error) 88 | 89 | // select id, username, if(phone='', '0', phone) phone, address, _status, birth_day, created, updated 90 | // from users 91 | // where username like ? 92 | // [ 93 | // and address = ? 94 | // [and phone like ?] 95 | // and created > ? 96 | // ] 97 | // and birth_day is not null 98 | // and _status != ? 99 | // [{ range } and _status in (#{ss})] 100 | // [and updated > ?] 101 | // order by updated desc 102 | // limit ${offset}, ${size} 103 | Page(u *model.User, ss []model.Status, offset int, size int) (int64, []*model.User, error) 104 | } 105 | -------------------------------------------------------------------------------- /goparser/method.go: -------------------------------------------------------------------------------- 1 | package goparser 2 | 3 | import ( 4 | "bytes" 5 | "strings" 6 | 7 | "github.com/omigo/light/sqlparser" 8 | ) 9 | 10 | type MethodType string 11 | 12 | const ( 13 | MethodTypeDDL = "ddl" 14 | MethodTypeInsert = "insert" 15 | MethodTypeBulky = "bulky" 16 | MethodTypeUpdate = "update" 17 | MethodTypeDelete = "delete" 18 | MethodTypeGet = "get" 19 | MethodTypeList = "list" 20 | MethodTypePage = "page" 21 | MethodTypeAgg = "agg" 22 | ) 23 | 24 | func MethodTx(m *Method) string { 25 | for _, v := range m.Params.List { 26 | if v.Tx { 27 | return v.Name 28 | } 29 | } 30 | return "" 31 | } 32 | 33 | func HasVariable(m *Method) bool { 34 | for _, f := range m.Statement.Fragments { 35 | if len(f.Variables) > 0 || f.Range != "" { 36 | return true 37 | } 38 | } 39 | return false 40 | } 41 | 42 | type Method struct { 43 | Interface *Interface `json:"-"` 44 | 45 | Name string // Insert 46 | Doc string // insert into users ... 
47 | Signature string // Insert(tx *sql.Tx, u *model.User) (int64, error) 48 | 49 | Statement *sqlparser.Statement 50 | Type MethodType 51 | 52 | Params *Params 53 | Results *Results 54 | } 55 | 56 | func NewMethod(itf *Interface, name, doc string) *Method { 57 | return &Method{ 58 | Interface: itf, 59 | Name: name, 60 | Doc: doc, 61 | } 62 | } 63 | 64 | func (m *Method) SetSignature() { 65 | var buf bytes.Buffer 66 | buf.WriteString(m.Name) 67 | 68 | buf.WriteByte('(') 69 | for i, v := range m.Params.List { 70 | if i != 0 { 71 | buf.WriteByte(',') 72 | } 73 | buf.WriteString(v.Define()) 74 | } 75 | buf.WriteByte(')') 76 | 77 | buf.WriteByte('(') 78 | for i, v := range m.Results.List { 79 | if i != 0 { 80 | buf.WriteByte(',') 81 | } 82 | buf.WriteString(v.Define()) 83 | } 84 | buf.WriteByte(')') 85 | 86 | m.Signature = buf.String() 87 | } 88 | 89 | func (m *Method) SetType() { 90 | switch m.Statement.Type { 91 | case sqlparser.SELECT: 92 | switch { 93 | case len(m.Results.List) == 3: 94 | m.Type = MethodTypePage 95 | case m.Results.Result.Slice: 96 | m.Type = MethodTypeList 97 | case !m.Results.Result.Array && !m.Results.Result.Slice && !m.Results.Result.Struct: 98 | m.Type = MethodTypeAgg 99 | default: 100 | m.Type = MethodTypeGet 101 | } 102 | 103 | case sqlparser.INSERT, sqlparser.REPLACE: 104 | if len(m.Results.List) == 3 { 105 | m.Type = MethodTypeBulky 106 | } else { 107 | m.Type = MethodTypeInsert 108 | } 109 | 110 | case sqlparser.UPDATE: 111 | m.Type = MethodTypeUpdate 112 | 113 | case sqlparser.DELETE: 114 | m.Type = MethodTypeDelete 115 | 116 | default: 117 | m.Type = MethodTypeDDL 118 | } 119 | } 120 | 121 | func (m *Method) GenCondition() { 122 | for _, f := range m.Statement.Fragments { 123 | deepGenCondition(f, m) 124 | } 125 | } 126 | 127 | func deepGenCondition(f *sqlparser.Fragment, m *Method) { 128 | if len(f.Fragments) == 0 { 129 | if f.Condition == "-" { 130 | var cs []string 131 | for _, name := range f.Variables { 132 | v := 
m.Params.Names[name] 133 | if v == nil { 134 | panic("method `" + m.Name + "` variable `" + name + "` not found") 135 | } 136 | d := v.NotDefault() 137 | cs = append(cs, "("+d+")") 138 | } 139 | f.Condition = strings.Join(cs, " && ") 140 | } 141 | return 142 | } 143 | 144 | for _, v := range f.Fragments { 145 | deepGenCondition(v, m) 146 | } 147 | 148 | if f.Condition != "-" { 149 | return 150 | } 151 | 152 | var cs []string 153 | for _, v := range f.Fragments { 154 | if v.Condition == "" { 155 | continue 156 | } 157 | cs = append(cs, "("+v.Condition+")") 158 | } 159 | f.Condition = strings.Join(cs, " || ") 160 | } 161 | -------------------------------------------------------------------------------- /goparser/profile.go: -------------------------------------------------------------------------------- 1 | package goparser 2 | 3 | import ( 4 | "go/types" 5 | "strconv" 6 | "strings" 7 | ) 8 | 9 | type Profile struct { 10 | TypeName string 11 | 12 | PkgName string 13 | PkgPath string 14 | Alias string 15 | 16 | BasicKind types.BasicKind 17 | Array bool 18 | Slice bool 19 | Pointer bool 20 | Struct bool 21 | 22 | Tx bool 23 | 24 | Fields []*Variable `json:"-"` 25 | } 26 | 27 | func NewProfile(t types.Type, cache map[string]*Profile, deep bool) *Profile { 28 | p := &Profile{} 29 | 30 | str := t.String() 31 | 32 | switch str { 33 | case "*database/sql.Tx": 34 | p.Tx = true 35 | p.PkgPath = "database/sql" 36 | p.PkgName = "sql" 37 | p.TypeName = "Tx" 38 | p.Pointer = true 39 | 40 | case "error": 41 | p.Tx = true 42 | p.TypeName = "error" 43 | 44 | default: 45 | p.parseType(t, cache, deep) 46 | } 47 | 48 | return p 49 | } 50 | 51 | func (p *Profile) parseType(t types.Type, cache map[string]*Profile, deep bool) { 52 | switch v := t.(type) { 53 | case *types.Basic: 54 | p.TypeName = v.Name() 55 | p.BasicKind = v.Kind() 56 | 57 | case *types.Map: 58 | panic("unsupported type " + v.String()) 59 | 60 | case *types.Named: 61 | if obj := v.Obj(); obj != nil { 62 | p.TypeName = 
obj.Name() 63 | if pkg := obj.Pkg(); pkg != nil { 64 | p.PkgName = pkg.Name() 65 | p.PkgPath = pkg.Path() 66 | if i := strings.Index(p.PkgPath, "/vendor/"); i != -1 { 67 | p.PkgPath = p.PkgPath[i+len("/vendor/"):] 68 | } 69 | } 70 | if s, ok := v.Underlying().(*types.Struct); ok { 71 | p.Struct = true 72 | if deep { 73 | p.parseStruct(s, cache) 74 | } 75 | } else { 76 | p.parseType(v.Underlying(), cache, deep) 77 | tstr := v.Obj().Type().String() 78 | if p.PkgPath != "" && strings.HasPrefix(tstr, p.PkgPath) { 79 | p.Alias = tstr[len(p.PkgPath):] 80 | p.Alias = strings.TrimPrefix(p.Alias, ".") 81 | } 82 | } 83 | } 84 | 85 | case *types.Pointer: 86 | p.Pointer = true 87 | p.parseType(v.Elem(), cache, deep) 88 | 89 | case *types.Array: 90 | p.Array = true 91 | p.parseType(v.Elem(), cache, deep) 92 | 93 | case *types.Slice: 94 | p.Slice = true 95 | p.parseType(v.Elem(), cache, deep) 96 | 97 | case *types.Struct: 98 | p.Struct = true 99 | if deep { 100 | p.parseStruct(v, cache) 101 | } 102 | 103 | case *types.Chan, *types.Interface, *types.Signature, *types.Tuple: 104 | panic("unsupported type " + v.String()) 105 | } 106 | } 107 | 108 | func (p *Profile) parseStruct(s *types.Struct, cache map[string]*Profile) { 109 | for i := 0; i < s.NumFields(); i++ { 110 | alias, cmds := parseTags(s.Tag(i)) 111 | v := NewVariableTag(s.Field(i), alias, cmds) 112 | p.Fields = append(p.Fields, v) 113 | } 114 | } 115 | 116 | func parseTags(tag string) (alias string, cmds []string) { 117 | // Username string `json:"username" light:"uname,nullable"` 118 | 119 | groups := strings.Split(tag, " ") 120 | for _, g := range groups { 121 | kv := strings.Split(g, ":") 122 | if kv[0] != "light" { 123 | continue 124 | } 125 | v, err := strconv.Unquote(kv[1]) 126 | if err != nil { 127 | panic(err) 128 | } 129 | vs := strings.Split(v, ",") 130 | if len(vs) == 0 { 131 | return "", nil 132 | } else if len(vs) == 1 { 133 | return vs[0], nil 134 | } else { 135 | return vs[0], vs[1:] 136 | } 137 | } 138 | 
return "", nil 139 | } 140 | 141 | func (p *Profile) FullTypeName() string { 142 | var name string 143 | if p.Slice { 144 | name += "[]" 145 | } 146 | if p.Pointer { 147 | name += "*" 148 | } 149 | if p.PkgName != "" { 150 | name += p.PkgName + "." 151 | } 152 | 153 | if p.Alias != "" { 154 | return name + p.Alias 155 | } 156 | return name + p.TypeName 157 | } 158 | 159 | func (p *Profile) FullElemTypeName() string { 160 | var name string 161 | if p.PkgName != "" { 162 | name += p.PkgName + "." 163 | } 164 | return name + p.TypeName 165 | } 166 | -------------------------------------------------------------------------------- /goparser/variable.go: -------------------------------------------------------------------------------- 1 | package goparser 2 | 3 | import ( 4 | "fmt" 5 | "go/types" 6 | "strings" 7 | 8 | "github.com/omigo/log" 9 | ) 10 | 11 | func ResultWrap(v *Variable) string { return v.Wrap() } 12 | func ResultTypeName(v *Variable) string { return v.FullTypeName() } 13 | func ResultElemTypeName(v *Variable) string { return v.FullElemTypeName() } 14 | func LookupScanOfResults(m *Method, name string) string { 15 | v := m.Results.Lookup(name) 16 | if v == nil { 17 | // log.Error(fmt.Sprintf("method `%s` result varialbe `%s` not found", m.Name, name)) 18 | return m.Results.Result.Scan(name) 19 | } 20 | 21 | return v.Scan("") 22 | } 23 | func LookupValueOfParams(m *Method, name string) string { 24 | v := m.Params.Lookup(name) 25 | if v == nil { 26 | panic(fmt.Sprintf("method `%s` result varialbe `%s` not found", m.Name, name)) 27 | } 28 | // fmt.Println(name, v.FullName()) 29 | // if v.Slice { 30 | // v.Elem().Value() 31 | // } 32 | return v.Value() 33 | } 34 | 35 | type Variable struct { 36 | Name string 37 | *Profile 38 | 39 | Parent *Variable 40 | 41 | TagAlias string 42 | TagCmds []string 43 | 44 | Var *types.Var 45 | Type types.Type 46 | } 47 | 48 | func NewVariable(v *types.Var) *Variable { 49 | return NewVariableTag(v, "", nil) 50 | } 51 | 52 | func 
NewVariableTag(v *types.Var, tagAlias string, tagCmds []string) *Variable { 53 | variable := &Variable{ 54 | Name: v.Name(), 55 | Profile: new(Profile), 56 | 57 | TagAlias: tagAlias, 58 | TagCmds: tagCmds, 59 | 60 | Var: v, 61 | Type: v.Type(), 62 | } 63 | return variable 64 | } 65 | 66 | func (v *Variable) Nullable() bool { 67 | for _, cmd := range v.TagCmds { 68 | if cmd == "nullable" { 69 | return true 70 | } 71 | } 72 | return false 73 | } 74 | 75 | func (v *Variable) NotDefault() string { 76 | name := v.FullName("") 77 | 78 | switch { 79 | case v.PkgPath == "github.com/omigo/light/null": 80 | return "!" + name + ".IsEmpty()" 81 | 82 | case v.PkgPath == "time" && v.TypeName == "Time": 83 | return "!" + name + ".IsZero()" 84 | 85 | case v.Pointer: 86 | return name + " != nil" 87 | 88 | case v.Struct: 89 | return name + " != nil" 90 | 91 | case v.Array: 92 | return name + " != nil" 93 | 94 | case v.Slice: 95 | return "len(" + name + ") != 0" 96 | 97 | case v.BasicKind == types.String: 98 | return name + ` != ""` 99 | 100 | case v.BasicKind == types.Bool: 101 | return "!" + name 102 | 103 | case v.BasicKind >= types.Int && v.BasicKind <= types.Uint64: 104 | return name + ` != 0` 105 | 106 | default: 107 | log.JsonIndent(v) 108 | panic("unimplement not default for variable " + v.PkgPath + "." + v.TypeName) 109 | } 110 | } 111 | 112 | func (v *Variable) FullName(key string) (name string) { 113 | defer func() { 114 | if key != "" { 115 | name += "." + upperCamelCase(key) 116 | } 117 | }() 118 | 119 | if v.Parent != nil { 120 | if v.Parent.Name == "" { 121 | name += "xu." 122 | } else { 123 | if v.Parent.Slice { 124 | if v.Parent.Name[len(v.Parent.Name)-1] == 's' { 125 | name += v.Parent.Name[:len(v.Parent.Name)-1] + "." 126 | } 127 | } else { 128 | name += v.Parent.Name + "." 
129 | } 130 | } 131 | } 132 | if v.Name == "" { 133 | return name + "xu" 134 | } 135 | 136 | return name + v.Name 137 | } 138 | 139 | func (v *Variable) Scan(name string) string { 140 | name = v.FullName(name) 141 | switch { 142 | case v.PkgPath == "github.com/omigo/light/null": 143 | return "&" + name 144 | case v.Pointer: 145 | return "&" + name 146 | case v.Nullable(): 147 | return fmt.Sprintf("null.%s%s(&%s)", strings.ToUpper(v.TypeName[:1]), v.TypeName[1:], name) 148 | default: 149 | return "&" + name 150 | } 151 | } 152 | 153 | func (v *Variable) Wrap() string { 154 | name := v.FullName("") 155 | if v.PkgPath == "github.com/omigo/light/null" { 156 | return name 157 | } 158 | name = fmt.Sprintf("null.%s%s(&%s)", strings.ToUpper(v.TypeName[:1]), v.TypeName[1:], name) 159 | return name 160 | } 161 | 162 | func (v *Variable) Value() string { 163 | name := v.FullName("") 164 | switch { 165 | case v.PkgPath == "github.com/omigo/light/null": 166 | return name 167 | case v.Pointer: 168 | return name 169 | case v.Nullable(): 170 | return fmt.Sprintf("null.%s%s(&%s)", strings.ToUpper(v.TypeName[:1]), v.TypeName[1:], name) 171 | default: 172 | return name 173 | } 174 | } 175 | 176 | func (v *Variable) Define() string { 177 | return v.Name + " " + v.FullTypeName() 178 | } 179 | 180 | // 181 | // func (v *Variable) Elem() *Variable { 182 | // switch { 183 | // case v.Slice: 184 | // x := *v 185 | // *x.Profile = *(v.Profile) 186 | // x.Slice = false 187 | // if x.Name[len(x.Name)-1] == 's' { 188 | // x.Name = x.Name[:len(x.Name)-1] 189 | // } 190 | // return &x 191 | // } 192 | // return v 193 | // } 194 | -------------------------------------------------------------------------------- /sqlparser/token.go: -------------------------------------------------------------------------------- 1 | // Package token defines constants representing the lexical tokens of the 2 | // light and basic operations on tokens (printing, predicates). 
3 | // 4 | package sqlparser 5 | 6 | import ( 7 | "strconv" 8 | ) 9 | 10 | // Token is the set of lexical tokens of the Go programming language. 11 | type Token int 12 | 13 | // The list of tokens. 14 | const ( 15 | // Special tokens 16 | EOF Token = iota 17 | COMMENT 18 | 19 | literal_beg 20 | // Identifiers and basic type literals 21 | // (these tokens stand for classes of literals) 22 | IDENT // main 23 | INT // 12345 24 | FLOAT // 123.45 25 | STRING // 'abc' 26 | literal_end 27 | 28 | VARIABLE // ${...} 29 | REPLACER // #{...} 30 | 31 | // Special Character 32 | POUND // # 33 | DOLLAR // $ 34 | LBRACKET // [ 35 | RBRACKET // ] 36 | LBRACES // { 37 | RBRACES // } 38 | QUESTION // ? 39 | 40 | operator_beg 41 | // Operator 42 | EQ // = 43 | NE // != 44 | LG // <> 45 | GT // > 46 | GE // >= 47 | LT // < 48 | LE // <= 49 | BETWEEN // between ... and ... 50 | operator_end 51 | 52 | // Misc characters 53 | SPACE // SPACE 54 | EXCLAMATION // ! 55 | DOT // . 56 | ASTERISK // * 57 | COMMA // , 58 | LPAREN // ( 59 | RPAREN // ) 60 | APOSTROPHE // ' 61 | BACKQUOTE // ` 62 | MINUS // - 63 | 64 | keyword_beg 65 | // Keywords 66 | INSERT 67 | IGNORE 68 | REPLACE 69 | INTO 70 | VALUES 71 | UPDATE 72 | SET 73 | DELETE 74 | CREATE 75 | TABLE 76 | SELECT 77 | FROM 78 | WHERE 79 | AND 80 | OR 81 | LIKE 82 | NOT 83 | EXISTS 84 | GROUP 85 | BY 86 | ORDER 87 | HAVING 88 | IS 89 | NULL 90 | ASC 91 | DESC 92 | LIMIT 93 | UNION 94 | ALL 95 | CURRENT_TIMESTAMP 96 | ON 97 | DUPLICATE 98 | KEY 99 | AS 100 | keyword_end 101 | ) 102 | 103 | var tokens = [...]string{ 104 | // Special tokens 105 | EOF: "EOF", 106 | COMMENT: "COMMENT", 107 | 108 | // Literals 109 | IDENT: "IDENT", // fields, table_name 110 | 111 | // Special Character 112 | POUND: "#", 113 | DOLLAR: "$", 114 | LBRACKET: "[", 115 | RBRACKET: "]", 116 | LBRACES: "{", 117 | RBRACES: "}", 118 | QUESTION: "?", 119 | 120 | // Operator 121 | EQ: "=", 122 | NE: "!=", 123 | LG: "<>", 124 | GT: ">", 125 | GE: ">=", 126 | LT: "<", 127 | 
LE: "<=", 128 | BETWEEN: "BETWEEN", // between ... and ... 129 | 130 | // Misc characters 131 | SPACE: " ", // 132 | DOT: ".", 133 | ASTERISK: "*", 134 | COMMA: ",", 135 | LPAREN: "(", 136 | RPAREN: ")", 137 | APOSTROPHE: "'", 138 | BACKQUOTE: "`", 139 | MINUS: "-", 140 | 141 | // Keywords 142 | INSERT: "INSERT", 143 | IGNORE: "IGNORE", 144 | REPLACE: "REPLACE", 145 | INTO: "INTO", 146 | VALUES: "VALUES", 147 | UPDATE: "UPDATE", 148 | SET: "SET", 149 | DELETE: "DELETE", 150 | CREATE: "CREATE", 151 | TABLE: "TABLE", 152 | SELECT: "SELECT", 153 | FROM: "FROM", 154 | WHERE: "WHERE", 155 | AND: "AND", 156 | OR: "OR", 157 | LIKE: "LIKE", 158 | NOT: "NOT", 159 | EXISTS: "EXISTS", 160 | GROUP: "GROUP", 161 | BY: "BY", 162 | ORDER: "ORDER", 163 | HAVING: "HAVING", 164 | IS: "IS", 165 | NULL: "NULL", 166 | ASC: "ASC", 167 | DESC: "DESC", 168 | LIMIT: "LIMIT", 169 | UNION: "UNION", 170 | ALL: "ALL", 171 | CURRENT_TIMESTAMP: "CURRENT_TIMESTAMP", 172 | ON: "ON", 173 | DUPLICATE: "DUPLICATE", 174 | KEY: "KEY", 175 | AS: "AS", 176 | } 177 | 178 | // String returns the string corresponding to the token tok. 179 | // For operators, delimiters, and keywords the string is the actual 180 | // token character sequence (e.g., for the token ADD, the string is 181 | // "+"). For all other tokens the string corresponds to the token 182 | // constant name (e.g. for the token IDENT, the string is "IDENT"). 183 | // 184 | func (tok Token) String() string { 185 | s := "" 186 | if 0 <= tok && tok < Token(len(tokens)) { 187 | s = tokens[tok] 188 | } 189 | if s == "" { 190 | s = "token(" + strconv.Itoa(int(tok)) + ")" 191 | } 192 | return s 193 | } 194 | 195 | var keywords map[string]Token 196 | 197 | func init() { 198 | keywords = make(map[string]Token) 199 | for i := EOF; i < keyword_end; i++ { 200 | keywords[tokens[i]] = Token(i) 201 | } 202 | } 203 | 204 | // Lookup maps an identifier to its keyword token or IDENT (if not a keyword). 
205 | // 206 | func Lookup(ident string) Token { 207 | if tok, is_keyword := keywords[ident]; is_keyword { 208 | return tok 209 | } 210 | return IDENT 211 | } 212 | 213 | // Predicates 214 | 215 | // IsLiteral returns true for tokens corresponding to identifiers 216 | // and basic type literals; it returns false otherwise. 217 | func (tok Token) IsLiteral() bool { 218 | return literal_beg < tok && tok < literal_end 219 | } 220 | 221 | // IsOperator returns true for tokens corresponding to operators and 222 | // delimiters; it returns false otherwise. 223 | func (tok Token) IsOperator() bool { 224 | return operator_beg < tok && tok < operator_end 225 | } 226 | 227 | // IsKeyword returns true for tokens corresponding to keywords; 228 | // it returns false otherwise. 229 | func (tok Token) IsKeyword() bool { 230 | return keyword_beg < tok && tok < keyword_end 231 | } 232 | -------------------------------------------------------------------------------- /sqlparser/scanner.go: -------------------------------------------------------------------------------- 1 | package sqlparser 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "io" 7 | "strings" 8 | ) 9 | 10 | const eof = 0 11 | 12 | func isSpace(ch rune) bool { 13 | return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' 14 | } 15 | func isLetter(ch rune) bool { 16 | return ch == '_' || (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') 17 | } 18 | func isDigit(ch rune) bool { 19 | return ch >= '0' && ch <= '9' 20 | } 21 | 22 | // Scanner represents a lexical scanner. 23 | type Scanner struct { 24 | r *bufio.Reader 25 | } 26 | 27 | // NeSPACEcanner returns a new instance of Scanner. 28 | func NeSPACEcanner(r io.Reader) *Scanner { 29 | return &Scanner{r: bufio.NewReader(r)} 30 | } 31 | 32 | // read reads the next rune from the bufferred reader. 33 | // Returns the rune(0) if an error occurs (or io.EOF is returned). 
34 | func (s *Scanner) read() rune { 35 | ch, _, err := s.r.ReadRune() 36 | if err != nil { 37 | return eof 38 | } 39 | return ch 40 | } 41 | 42 | // unread places the previously read rune back on the reader. 43 | func (s *Scanner) unread() { _ = s.r.UnreadRune() } 44 | 45 | // Scan returns the next token and literal value. 46 | func (s *Scanner) Scan() (tok Token, lit string) { 47 | // Read the next rune. 48 | ch := s.read() 49 | 50 | switch { 51 | case isSpace(ch): 52 | s.unread() 53 | return s.scanSpace() 54 | 55 | case isLetter(ch): 56 | s.unread() 57 | return s.scanIdent() 58 | 59 | case isDigit(ch): 60 | s.unread() 61 | return s.scanDigit() 62 | 63 | default: 64 | } 65 | 66 | switch ch { 67 | case eof: 68 | return EOF, "" 69 | 70 | case '-': 71 | s.unread() 72 | return s.scanDigit() 73 | 74 | case '`': 75 | s.unread() 76 | tok, lit = s.scanBackQuoteIdent() 77 | return tok, lit 78 | 79 | case '\'': 80 | s.unread() 81 | return s.scanApostropheIdent() 82 | 83 | case '[', ']', '{', '}', '?', '=', '.', '*', ',', '(', ')': 84 | return Lookup(string(ch)), string(ch) 85 | 86 | case '!': 87 | next := s.read() 88 | if next == '=' { 89 | return NE, "!=" 90 | } 91 | s.unread() 92 | return EXCLAMATION, "!" 93 | 94 | case '<': 95 | next := s.read() 96 | if next == '=' { 97 | return LE, "<=" 98 | } else if next == '>' { 99 | return NE, "<>" 100 | } 101 | s.unread() 102 | return LT, "<" 103 | 104 | case '>': 105 | next := s.read() 106 | if next == '=' { 107 | return GE, ">=" 108 | } 109 | s.unread() 110 | return GT, ">" 111 | 112 | case '#': 113 | if ch = s.read(); ch == '{' { 114 | return s.scanReplacer() 115 | } 116 | s.unread() 117 | return IDENT, string(ch) 118 | 119 | case '$': 120 | if ch = s.read(); ch == '{' { 121 | return s.scanVariable() 122 | } 123 | s.unread() 124 | return IDENT, string(ch) 125 | 126 | default: 127 | return IDENT, string(ch) 128 | } 129 | } 130 | 131 | // scanWhitespace consumes the current rune and all contiguous whitespace. 
132 | func (s *Scanner) scanSpace() (tok Token, lit string) { 133 | // Create a buffer and read the current character into it. 134 | var buf bytes.Buffer 135 | 136 | // Read every subsequent whitespace character into the buffer. 137 | // Non-whitespace characters and EOF will cause the loop to exit. 138 | for { 139 | if ch := s.read(); ch == eof { 140 | break 141 | } else if !isSpace(ch) { 142 | s.unread() 143 | break 144 | } else { 145 | buf.WriteRune(ch) 146 | } 147 | } 148 | 149 | return SPACE, buf.String() 150 | } 151 | 152 | // scanIdent consumes the current rune and all contiguous ident runes. 153 | func (s *Scanner) scanIdent() (tok Token, lit string) { 154 | // Create a buffer and read the current character into it. 155 | var buf bytes.Buffer 156 | buf.WriteRune(s.read()) 157 | 158 | // Read every subsequent ident character into the buffer. 159 | // Non-ident characters and EOF will cause the loop to exit. 160 | for { 161 | if ch := s.read(); ch == eof { 162 | break 163 | } else if !isLetter(ch) && !isDigit(ch) && ch != '_' { 164 | s.unread() 165 | break 166 | } else { 167 | _, _ = buf.WriteRune(ch) 168 | } 169 | } 170 | 171 | // If the string matches a keyword then return that keyword. 172 | ident := buf.String() 173 | kw := strings.ToUpper(ident) 174 | if tok, ok := keywords[kw]; ok { 175 | return tok, kw 176 | } 177 | 178 | // Otherwise return as a regular identifier. 
179 | return IDENT, ident 180 | } 181 | 182 | func (s *Scanner) scanBackQuoteIdent() (tok Token, lit string) { 183 | var buf bytes.Buffer 184 | buf.WriteRune(s.read()) 185 | for { 186 | if ch := s.read(); ch == eof { 187 | break 188 | } else if ch == '`' { 189 | buf.WriteByte('`') 190 | break 191 | } else { 192 | _, _ = buf.WriteRune(ch) 193 | } 194 | } 195 | return IDENT, buf.String() 196 | } 197 | 198 | func (s *Scanner) scanVariable() (tok Token, lit string) { 199 | var buf bytes.Buffer 200 | for { 201 | if ch := s.read(); ch == eof { 202 | break 203 | } else if ch == '}' { 204 | break 205 | } else { 206 | _, _ = buf.WriteRune(ch) 207 | } 208 | } 209 | return VARIABLE, buf.String() 210 | } 211 | 212 | func (s *Scanner) scanReplacer() (tok Token, lit string) { 213 | var buf bytes.Buffer 214 | for { 215 | if ch := s.read(); ch == eof { 216 | break 217 | } else if ch == '}' { 218 | break 219 | } else { 220 | _, _ = buf.WriteRune(ch) 221 | } 222 | } 223 | return REPLACER, buf.String() 224 | } 225 | 226 | func (s *Scanner) scanApostropheIdent() (tok Token, lit string) { 227 | var buf bytes.Buffer 228 | buf.WriteRune(s.read()) 229 | for { 230 | if ch := s.read(); ch == eof { 231 | break 232 | } else if ch == '\'' { 233 | buf.WriteByte('\'') 234 | if ch = s.read(); ch == eof { 235 | s.unread() 236 | break 237 | } else if ch == '\'' { 238 | // escape 239 | } else { 240 | s.unread() 241 | break 242 | } 243 | } else { 244 | buf.WriteRune(ch) 245 | } 246 | } 247 | 248 | return STRING, buf.String() 249 | } 250 | 251 | func (s *Scanner) scanDigit() (tok Token, lit string) { 252 | var buf bytes.Buffer 253 | buf.WriteRune(s.read()) 254 | tok = INT 255 | for { 256 | if ch := s.read(); ch == eof { 257 | break 258 | } else if isDigit(ch) { 259 | buf.WriteRune(ch) 260 | } else if ch == '.' 
{ 261 | tok = FLOAT 262 | buf.WriteByte('.') 263 | } else { 264 | s.unread() 265 | break 266 | } 267 | } 268 | 269 | return tok, buf.String() 270 | } 271 | -------------------------------------------------------------------------------- /example/store/user_test.go: -------------------------------------------------------------------------------- 1 | package store 2 | 3 | import ( 4 | "database/sql/driver" 5 | "strings" 6 | "testing" 7 | "time" 8 | 9 | sqlmock "github.com/DATA-DOG/go-sqlmock" 10 | "github.com/omigo/light/example/model" 11 | "github.com/omigo/light/null" 12 | "github.com/omigo/log" 13 | ) 14 | 15 | var mock sqlmock.Sqlmock 16 | 17 | func init() { 18 | var err error 19 | db, mock, err = sqlmock.New() 20 | log.Fataln(err) 21 | // defer db.Close() 22 | } 23 | 24 | func TestUserCreate(t *testing.T) { 25 | mock.ExpectExec("CREATE TABLE").WillReturnResult(sqlmock.NewResult(0, 0)) 26 | 27 | err := User.Create("users") 28 | if err != nil { 29 | t.Error(err) 30 | } 31 | } 32 | 33 | func TestUserInsert(t *testing.T) { 34 | mock.ExpectBegin() 35 | mock.ExpectExec("INSERT ").WillReturnResult(sqlmock.NewResult(1, 1)) 36 | mock.ExpectCommit() 37 | // mock.ExpectRollback() 38 | 39 | username := "admin" + time.Now().Format("150405") 40 | u := &model.User{ 41 | Username: username, 42 | Phone: username, 43 | } 44 | tx, err := db.Begin() 45 | if err != nil { 46 | t.Fatal(err) 47 | } 48 | // defer tx.Rollback() 49 | id0, err := User.Insert(tx, u) 50 | if err != nil { 51 | t.Error(err) 52 | } 53 | tx.Commit() 54 | if id0 == 0 { 55 | t.Errorf("expect id > 1, but %d", id0) 56 | } 57 | } 58 | 59 | func TestUserBulky(t *testing.T) { 60 | mock.ExpectBegin() 61 | stmt := mock.ExpectPrepare("INSERT ") 62 | stmt.ExpectExec().WillReturnResult(sqlmock.NewResult(1, 1)) 63 | stmt.ExpectExec().WillReturnResult(sqlmock.NewResult(1, 1)) 64 | mock.ExpectCommit() 65 | // mock.ExpectRollback() 66 | 67 | us := []*model.User{ 68 | { 69 | Username: "admin1" + time.Now().Format("150405"), 70 | 
Phone: "admin2" + time.Now().Format("150405"), 71 | }, 72 | { 73 | Username: "admin1" + time.Now().Format("150405"), 74 | Phone: "admin2" + time.Now().Format("150405"), 75 | }, 76 | } 77 | 78 | affect, _, err := User.Bulky(us) 79 | if err != nil { 80 | t.Error(err) 81 | } 82 | if affect <= 1 { 83 | t.Errorf("expect affect > 1, but %d", affect) 84 | } 85 | } 86 | 87 | func TestUserUpsert(t *testing.T) { 88 | mock.ExpectBegin() 89 | args := []driver.Value{sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()} 90 | mock.ExpectExec("INSERT INTO").WithArgs(args...).WillReturnResult(sqlmock.NewResult(0, 0)) 91 | mock.ExpectCommit() 92 | // mock.ExpectRollback() 93 | 94 | username := "admin" + time.Now().Format("150405") 95 | u := &model.User{ 96 | Username: username, 97 | Phone: username, 98 | } 99 | tx, err := db.Begin() 100 | if err != nil { 101 | t.Error(err) 102 | } 103 | defer tx.Rollback() 104 | id0, err := User.Upsert(u, tx) 105 | if err != nil { 106 | t.Error(err) 107 | } 108 | tx.Commit() 109 | if id0 != 0 { 110 | t.Errorf("expect id = 0, but %d", id0) 111 | } 112 | } 113 | 114 | func TestUserReplace(t *testing.T) { 115 | mock.ExpectExec("REPLACE INTO").WillReturnResult(sqlmock.NewResult(1, 2)) 116 | 117 | u := &model.User{ 118 | Username: "admin" + time.Now().Format("150405"), 119 | } 120 | id0, err := User.Replace(u) 121 | if err != nil { 122 | t.Error(err) 123 | } 124 | if id0 == 0 { 125 | t.Errorf("expect id > 1, but %d", id0) 126 | } 127 | } 128 | 129 | func TestUserUpdate(t *testing.T) { 130 | mock.ExpectExec("UPDATE").WillReturnResult(sqlmock.NewResult(0, 1)) 131 | 132 | addr := "address3" 133 | birth := time.Now() 134 | u := &model.User{ 135 | Id: 1, 136 | Username: "admin3" + time.Now().Format("150405"), 137 | Phone: "phone3", 138 | Address: &addr, 139 | Status: 3, 140 | BirthDay: &birth, 141 | } 142 | a, err := User.Update(u) 143 | if err != nil { 144 | t.Error(err) 145 | } 146 | if a != 1 { 147 | t.Errorf("expect 
affect 1 rows, but %d", a) 148 | } 149 | } 150 | 151 | func TestUserGet(t *testing.T) { 152 | columns := strings.Split("id, username, phone, address, status, birth_day, created, updated", ", ") 153 | returns := []driver.Value{int64(1), []byte("admin"), []byte("13812341234"), 154 | []byte("Pudong"), int64(1), time.Now(), time.Now(), time.Now()} 155 | rows := sqlmock.NewRows(columns).AddRow(returns...) 156 | mock.ExpectQuery("SELECT").WithArgs(1).WillReturnRows(rows) 157 | 158 | if u, err := User.Get(1); err != nil { 159 | t.Error(err) 160 | } else if u == nil { 161 | t.Error("expect get one record, but not") 162 | } else if u.Username != "admin" { 163 | t.Errorf("expect username=admin, but got %s", u.Username) 164 | } 165 | 166 | if err := mock.ExpectationsWereMet(); err != nil { 167 | t.Error(err) 168 | } 169 | } 170 | 171 | func TestUserList(t *testing.T) { 172 | columns := strings.Split("id, username, phone, address, status, birth_day, created, updated", ", ") 173 | returns := []driver.Value{int64(1), []byte("admin"), []byte("13812341234"), 174 | []byte("Pudong"), int64(1), time.Now(), time.Now(), time.Now()} 175 | rows := sqlmock.NewRows(columns).AddRow(returns...) 
176 | mock.ExpectQuery("SELECT").WithArgs(1).WillReturnRows(rows) 177 | 178 | if u, err := User.Get(1); err != nil { 179 | t.Error(err) 180 | } else if u == nil { 181 | t.Error("expect get one record, but not") 182 | } else if u.Username != "admin" { 183 | t.Errorf("expect username=admin, but got %s", u.Username) 184 | } 185 | 186 | if err := mock.ExpectationsWereMet(); err != nil { 187 | t.Error(err) 188 | } 189 | } 190 | 191 | func TestUserPage(t *testing.T) { 192 | count := sqlmock.NewRows([]string{"count"}).AddRow(int64(10)) 193 | mock.ExpectQuery("SELECT").WillReturnRows(count) 194 | 195 | columns := strings.Split("id, username, phone, address, status, birth_day, created, updated", ", ") 196 | returns := []driver.Value{int64(1), []byte("admin"), []byte("13812341234"), 197 | []byte("Pudong"), int64(1), time.Now(), time.Now(), time.Now()} 198 | rows := sqlmock.NewRows(columns).AddRow(returns...).AddRow(returns...) 199 | mock.ExpectQuery("SELECT").WillReturnRows(rows) 200 | 201 | update := time.Now().Add(-time.Hour) 202 | u := &model.User{ 203 | Username: "ad%", 204 | Updated: null.Timestamp{Time: &update}, 205 | Status: 9, 206 | } 207 | total, data, err := User.Page(u, []model.Status{1, 2, 3}, 1, 2) 208 | if err != nil { 209 | log.Error(err) 210 | } 211 | if total == 0 || len(data) == 0 { 212 | t.Error("expect get one or more records, but not") 213 | } 214 | } 215 | 216 | func TestUserDelete(t *testing.T) { 217 | mock.ExpectExec("DELETE").WithArgs(1).WillReturnResult(sqlmock.NewResult(0, 1)) 218 | 219 | a, err := User.Delete(1) 220 | if err != nil { 221 | t.Error(err) 222 | } 223 | if a != 1 { 224 | t.Errorf("expect affect 1 rows, but %d", a) 225 | } 226 | } 227 | -------------------------------------------------------------------------------- /sqlparser/parser.go: -------------------------------------------------------------------------------- 1 | package sqlparser 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "strings" 7 | ) 8 | 9 | func Parse(doc string) (s 
*Statement, err error) { 10 | return NewParser(bytes.NewBufferString(doc)).Parse() 11 | } 12 | 13 | // Parser represents a parser. 14 | type Parser struct { 15 | s *Scanner 16 | buf struct { 17 | tok Token // last read token 18 | lit string // last read literal 19 | n int // buffer size (max=1) 20 | } 21 | } 22 | 23 | // NewParser returns a new instance of Parser. 24 | func NewParser(r io.Reader) *Parser { 25 | return &Parser{s: NeSPACEcanner(r)} 26 | } 27 | 28 | func (p *Parser) Parse() (s *Statement, err error) { 29 | tok, _ := p.scanIgnoreWhitespace() 30 | p.unscan() 31 | switch tok { 32 | case SELECT: 33 | s, err = p.ParseSelect() 34 | 35 | case INSERT: 36 | s, err = p.ParseInsert() 37 | 38 | case REPLACE: 39 | s, err = p.ParseReplace() 40 | 41 | case UPDATE: 42 | s, err = p.ParseUpdate() 43 | 44 | case DELETE: 45 | s, err = p.ParseDelete() 46 | 47 | case CREATE: 48 | s, err = p.ParseCreate() 49 | 50 | default: 51 | panic("sql error, must start with SELECT/INSERT/UPDATE/DELETE") 52 | } 53 | if err != nil { 54 | return nil, err 55 | } 56 | 57 | return s, err 58 | } 59 | 60 | // scan returns the next token from the underlying scanner. 61 | // If a token has been unscanned then read that instead. 62 | func (p *Parser) scan() (tok Token, lit string) { 63 | // If we have a token on the buffer, then return it. 64 | if p.buf.n != 0 { 65 | p.buf.n = 0 66 | return p.buf.tok, p.buf.lit 67 | } 68 | 69 | // Otherwise read the next token from the scanner. 70 | tok, lit = p.s.Scan() 71 | 72 | // Save it to the buffer in case we unscan later. 73 | p.buf.tok, p.buf.lit = tok, lit 74 | 75 | return 76 | } 77 | 78 | // unscan pushes the previously read token back onto the buffer. 79 | func (p *Parser) unscan() { p.buf.n = 1 } 80 | 81 | // scanIgnoreWhitespace scans the next non-whitespace token. 
82 | func (p *Parser) scanIgnoreWhitespace() (tok Token, lit string) { 83 | tok, lit = p.scan() 84 | if tok == SPACE { 85 | tok, lit = p.scan() 86 | } 87 | return 88 | } 89 | 90 | func (p *Parser) scanVariable() (v string) { 91 | tok, _ := p.scanIgnoreWhitespace() 92 | if tok != DOLLAR { 93 | panic("variable must start with $") 94 | } 95 | tok, _ = p.scanIgnoreWhitespace() 96 | if tok != LBRACES { 97 | panic("variable must wraped by ${...}") 98 | } 99 | 100 | var lit string 101 | for { 102 | tok, lit = p.scan() 103 | switch tok { 104 | default: 105 | v += lit 106 | case SPACE: 107 | // ingnore 108 | case RBRACES: 109 | return 110 | case EOF: 111 | panic("expect more words") 112 | } 113 | } 114 | } 115 | 116 | func (p *Parser) scanReplacer() (v string) { 117 | tok, _ := p.scanIgnoreWhitespace() 118 | if tok != POUND { 119 | panic("replacer must start with #") 120 | } 121 | tok, _ = p.scanIgnoreWhitespace() 122 | if tok != LBRACES { 123 | panic("replacer must wraped by #{...}") 124 | } 125 | 126 | var lit string 127 | for { 128 | tok, lit = p.scan() 129 | switch tok { 130 | default: 131 | v += lit 132 | case SPACE: 133 | // ingnore 134 | case RBRACES: 135 | return 136 | case EOF: 137 | panic("expect more words") 138 | } 139 | } 140 | } 141 | 142 | func (p *Parser) scanCondition() (v string) { 143 | tok, _ := p.scan() 144 | if tok != LBRACES { 145 | p.unscan() 146 | return "" 147 | } 148 | 149 | var buf bytes.Buffer 150 | for { 151 | tok, lit := p.scan() 152 | switch tok { 153 | default: 154 | buf.WriteString(lit) 155 | case SPACE: 156 | buf.WriteString(" ") 157 | case RBRACES: 158 | return buf.String() 159 | case EOF: 160 | panic("expect more words") 161 | } 162 | } 163 | } 164 | 165 | func (p *Parser) scanFragments() (fs []*Fragment) { 166 | // scan fragment 167 | for { 168 | f, lastToken := p.parseFragment() 169 | if f != nil { 170 | fs = append(fs, f) 171 | } 172 | if lastToken == EOF { 173 | break 174 | } 175 | } 176 | return fs 177 | } 178 | 179 | func (p 
*Parser) parseFragment() (*Fragment, Token) { 180 | var inner bool 181 | var buf bytes.Buffer 182 | 183 | tok, lit := p.scanIgnoreWhitespace() 184 | if tok == LBRACKET { 185 | inner = true 186 | } else if tok == RBRACKET { 187 | p.unscan() 188 | return nil, EOF 189 | } else if tok == ORDER { 190 | buf.WriteString(strings.ToUpper(lit)) 191 | } else { 192 | p.unscan() 193 | } 194 | 195 | f := Fragment{} 196 | f.Condition = p.scanCondition() 197 | if f.Condition == "" && inner { 198 | f.Condition = "-" 199 | } 200 | 201 | var last string 202 | for { 203 | tok, lit = p.scan() 204 | switch tok { 205 | default: 206 | buf.WriteString(lit) 207 | 208 | case IDENT: 209 | buf.WriteString(lit) 210 | last = lit 211 | 212 | case SPACE: 213 | buf.WriteString(SPACE.String()) 214 | 215 | case QUESTION: 216 | f.Variables = append(f.Variables, last) 217 | buf.WriteString(QUESTION.String()) 218 | 219 | // case DOLLAR: 220 | // p.unscan() 221 | // lit = p.scanVariable() 222 | // f.Variables = append(f.Variables, lit) 223 | // buf.WriteString(QUESTION.String()) 224 | case VARIABLE: 225 | f.Variables = append(f.Variables, lit) 226 | buf.WriteString(QUESTION.String()) 227 | 228 | // case POUND: 229 | // p.unscan() 230 | // lit = p.scanReplacer() 231 | // f.Replacers = append(f.Replacers, lit) 232 | // buf.WriteString("%v") 233 | case REPLACER: 234 | f.Replacers = append(f.Replacers, lit) 235 | buf.WriteString("%v") 236 | 237 | case LBRACKET: 238 | p.unscan() 239 | if inner { 240 | stmt := strings.TrimSpace(buf.String()) 241 | buf.Reset() 242 | if len(stmt) > 0 { 243 | innerFirst := Fragment{Statement: stmt, Variables: f.Variables} 244 | f.Variables = nil 245 | f.Fragments = append(f.Fragments, &innerFirst) 246 | } 247 | f.Fragments = append(f.Fragments, p.scanFragments()...) 
248 | } 249 | goto END 250 | 251 | case ORDER: 252 | if inner { 253 | buf.WriteString(lit) 254 | } else { 255 | p.unscan() 256 | goto END 257 | } 258 | 259 | case RBRACKET, EOF: 260 | p.unscan() 261 | goto END 262 | } 263 | } 264 | 265 | END: 266 | tok, lit = p.scanIgnoreWhitespace() 267 | if inner { 268 | if tok != RBRACKET { 269 | panic("expect ], but got " + lit + ", " + buf.String()) 270 | } 271 | } else { 272 | p.unscan() 273 | if tok == RBRACKET { 274 | tok = EOF 275 | } 276 | } 277 | f.Statement = strings.TrimSpace(buf.String()) 278 | if strings.TrimSpace(f.Condition) == "range" { 279 | f.Condition = "-" 280 | f.Range = f.Replacers[0] 281 | } 282 | return &f, tok 283 | } 284 | -------------------------------------------------------------------------------- /goparser/parse.go: -------------------------------------------------------------------------------- 1 | package goparser 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "go/ast" 7 | "go/importer" 8 | "go/parser" 9 | "go/token" 10 | "go/types" 11 | "os/exec" 12 | "strconv" 13 | "strings" 14 | 15 | "github.com/omigo/log" 16 | ) 17 | 18 | type Interface struct { 19 | Source string 20 | Log bool 21 | Timeout int64 22 | 23 | Package string // itf 24 | Imports map[string]string // database/sql => sql 25 | Name string // IUser 26 | 27 | VarName string 28 | StoreName string 29 | 30 | Methods []*Method 31 | 32 | // full-type-name : type-profile 33 | Cache map[string]*Profile 34 | } 35 | 36 | func Parse(filename string, src interface{}) (*Interface, error) { 37 | fset := token.NewFileSet() 38 | f, err := parser.ParseFile(fset, filename, src, parser.ParseComments) 39 | if err != nil { 40 | log.Panic(err) 41 | } 42 | // ast.Print(fset, f) 43 | 44 | itf := &Interface{ 45 | Source: filename, 46 | Package: f.Name.Name, 47 | Imports: map[string]string{}, 48 | } 49 | 50 | goBuild(filename) 51 | 52 | extractDocs(itf, f, fset) 53 | 54 | extractTypes(itf, f, fset) 55 | 56 | // log.JsonIndent(itf) 57 | 58 | itf.makeCache() 59 | 60 
| return itf, nil 61 | } 62 | 63 | func goBuild(src string) { 64 | cmd := exec.Command("go", "build", src) 65 | out, err := cmd.CombinedOutput() 66 | if bytes.HasSuffix(out, []byte("command-line-arguments\n")) { 67 | fmt.Printf("%s", out[:len(out)-23]) 68 | } else { 69 | fmt.Printf("%s", out) 70 | } 71 | if err != nil { 72 | panic(err) 73 | } 74 | } 75 | 76 | func extractDocs(itf *Interface, f *ast.File, fset *token.FileSet) { 77 | for _, decl := range f.Decls { 78 | if genDecl, ok := decl.(*ast.GenDecl); ok { 79 | switch genDecl.Tok { 80 | case token.IMPORT: 81 | for _, spec := range genDecl.Specs { 82 | if importSpec, ok := spec.(*ast.ImportSpec); ok { 83 | path, err := strconv.Unquote(importSpec.Path.Value) 84 | if err != nil { 85 | panic(importSpec.Path.Value + " " + err.Error()) 86 | } 87 | if importSpec.Name != nil { 88 | itf.Imports[path] = importSpec.Name.Name 89 | } else { 90 | itf.Imports[path] = "" 91 | } 92 | } 93 | } 94 | 95 | case token.TYPE: 96 | for _, spec := range genDecl.Specs { 97 | if typeSpec, ok := spec.(*ast.TypeSpec); ok { 98 | if interfaceType, ok := typeSpec.Type.(*ast.InterfaceType); ok { 99 | if itf.Name != "" { 100 | panic("one file must contains one interface only") 101 | } 102 | 103 | itf.Name = typeSpec.Name.Name 104 | for _, method := range interfaceType.Methods.List { 105 | m := NewMethod(itf, method.Names[0].Name, getDoc(method.Doc)) 106 | itf.Methods = append(itf.Methods, m) 107 | } 108 | } 109 | } 110 | } 111 | } 112 | } 113 | } 114 | } 115 | 116 | func getDoc(cg *ast.CommentGroup) (comment string) { 117 | if cg == nil { 118 | return "" 119 | } 120 | for _, c := range cg.List { 121 | comment += strings.TrimSpace(c.Text[2:]) + " " // remove `//` 122 | } 123 | return strings.TrimSpace(comment) 124 | } 125 | 126 | func extractTypes(itf *Interface, f *ast.File, fset *token.FileSet) { 127 | info := types.Info{Defs: make(map[*ast.Ident]types.Object)} 128 | conf := types.Config{Importer: importer.For("source", nil)} 129 | _, err := 
conf.Check(itf.Package, fset, []*ast.File{f}, &info) 130 | log.Fataln(err) 131 | 132 | for k, obj := range info.Defs { 133 | if k.Obj != nil { 134 | if k.Name == itf.Name { 135 | if k.Obj.Kind == ast.Typ { 136 | // get method name and params/returns 137 | if itfType, ok := obj.Type().Underlying().(*types.Interface); ok { 138 | for i := 0; i < itfType.NumMethods(); i++ { 139 | x := itfType.Method(i) 140 | m := getMethodByName(itf, x.Name()) 141 | y := x.Type().(*types.Signature) 142 | m.Params = NewParams(y.Params()) 143 | m.Results = NewResults(y.Results()) 144 | } 145 | } 146 | } 147 | } else { 148 | if tn, ok := obj.Type().(*types.Named); ok { 149 | if itf.Name == tn.Obj().Name() { 150 | itf.VarName = k.Name 151 | } 152 | } 153 | } 154 | } 155 | } 156 | } 157 | 158 | func getMethodByName(s *Interface, name string) *Method { 159 | for _, a := range s.Methods { 160 | if a.Name == name { 161 | return a 162 | } 163 | } 164 | return nil 165 | } 166 | 167 | func (itf *Interface) makeCache() { 168 | itf.Cache = map[string]*Profile{} 169 | 170 | for _, method := range itf.Methods { 171 | for _, param := range method.Params.List { 172 | key := param.Type.String() 173 | profile, ok := itf.Cache[key] 174 | if !ok { 175 | profile = NewProfile(param.Type, itf.Cache, true) 176 | itf.Cache[key] = profile 177 | } 178 | 179 | for _, f := range profile.Fields { 180 | if f.PkgPath != "" { 181 | itf.Imports[f.PkgPath] = "" 182 | } 183 | 184 | // field 是一个变量,在不同的方法中,名字不一样,所以不能公用 185 | field := new(Variable) 186 | *field = *f 187 | 188 | k := field.Type.String() 189 | p, ok := itf.Cache[k] 190 | if !ok { 191 | p = NewProfile(field.Type, itf.Cache, false) 192 | itf.Cache[k] = p 193 | } 194 | field.Profile = p 195 | field.Parent = param 196 | 197 | method.Params.Names[field.Name] = field 198 | method.Params.Names[underLower(field.Name)] = field 199 | method.Params.Names[param.Name+"."+field.Name] = field 200 | if field.TagAlias != "" { 201 | method.Params.Names[field.TagAlias] = field 
202 | } 203 | if profile.Slice { 204 | if strings.HasSuffix(param.Name, "s") { // HasSuffix instead of indexing param.Name[len(param.Name)-1]: unnamed params have Name == "" and the index would panic 205 | elem := param.Name[:len(param.Name)-1] 206 | method.Params.Names[elem+"."+field.Name] = field 207 | } 208 | } 209 | } 210 | *param.Profile = *profile 211 | method.Params.Names[param.Name] = param 212 | method.Params.Names[underLower(param.Name)] = param 213 | } 214 | for _, result := range method.Results.List { 215 | result.Name = "" 216 | key := result.Type.String() 217 | profile, ok := itf.Cache[key] 218 | if !ok { 219 | profile = NewProfile(result.Type, itf.Cache, true) 220 | itf.Cache[key] = profile 221 | } 222 | 223 | for _, f := range profile.Fields { 224 | if f.PkgPath != "" { 225 | itf.Imports[f.PkgPath] = "" 226 | } 227 | 228 | field := new(Variable) 229 | *field = *f 230 | 231 | k := field.Type.String() 232 | p, ok := itf.Cache[k] 233 | if !ok { 234 | p = NewProfile(field.Type, itf.Cache, false) 235 | itf.Cache[k] = p 236 | } 237 | field.Profile = p 238 | field.Parent = result 239 | if field.Name == "" { 240 | log.JsonIndent(profile) 241 | panic("unreachable code") 242 | } 243 | method.Results.Names[field.Name] = field 244 | method.Results.Names[underLower(field.Name)] = field 245 | if result.Name != "" { 246 | method.Results.Names[result.Name+"."+field.Name] = field 247 | } 248 | if field.TagAlias != "" { 249 | method.Results.Names[field.TagAlias] = field 250 | } 251 | } 252 | result.Profile = profile 253 | if result.Name != "" { 254 | method.Results.Names[result.Name] = result 255 | } 256 | } 257 | } 258 | } 259 | 260 | func underLower(field string) string { 261 | var buf bytes.Buffer 262 | for i, v := range field { 263 | if v >= 'A' && v <= 'Z' { 264 | if i != 0 { 265 | buf.WriteByte('_') 266 | } 267 | buf.WriteRune(v + 32) 268 | } else { 269 | buf.WriteRune(v) 270 | } 271 | } 272 | return buf.String() 273 | } 274 | 275 | func upperCamelCase(field string) string { 276 | var buf bytes.Buffer 277 | var upper bool = true 278 | for _, v := range field { 279 | if v == '_'
{ 280 | upper = true 281 | } else if upper { 282 | buf.WriteString(strings.ToUpper(string(v))) // was WriteRune(v - 32), which corrupted runes that are not lowercase letters ('1' -> 0x11, 'A' -> '!') 283 | upper = false 284 | } else { 285 | buf.WriteRune(v) 286 | } 287 | } 288 | return buf.String() 289 | } 290 | -------------------------------------------------------------------------------- /null/ints.go: -------------------------------------------------------------------------------- 1 | package null 2 | 3 | import ( 4 | "database/sql/driver" 5 | "reflect" 6 | "strconv" 7 | ) 8 | 9 | type NullInt struct{ Int *int } 10 | type NullInt8 struct{ Int8 *int8 } 11 | type NullUint8 struct{ Uint8 *uint8 } 12 | type NullInt16 struct{ Int16 *int16 } 13 | type NullUint16 struct{ Uint16 *uint16 } 14 | type NullInt32 struct{ Int32 *int32 } 15 | type NullUint32 struct{ Uint32 *uint32 } 16 | type NullInt64 struct{ Int64 *int64 } 17 | type NullUint64 struct{ Uint64 *uint64 } 18 | 19 | func (n *NullInt) IsEmpty() bool { return isEmpty(n.Int) } 20 | func (n *NullInt8) IsEmpty() bool { return isEmpty(n.Int8) } 21 | func (n *NullUint8) IsEmpty() bool { return isEmpty(n.Uint8) } 22 | func (n *NullInt16) IsEmpty() bool { return isEmpty(n.Int16) } 23 | func (n *NullUint16) IsEmpty() bool { return isEmpty(n.Uint16) } 24 | func (n *NullInt32) IsEmpty() bool { return isEmpty(n.Int32) } 25 | func (n *NullUint32) IsEmpty() bool { return isEmpty(n.Uint32) } 26 | func (n *NullInt64) IsEmpty() bool { return isEmpty(n.Int64) } 27 | func (n *NullUint64) IsEmpty() bool { return isEmpty(n.Uint64) } 28 | 29 | func (n *NullInt) MarshalJSON() ([]byte, error) { return marshalJSON(n.Int) } 30 | func (n *NullInt8) MarshalJSON() ([]byte, error) { return marshalJSON(n.Int8) } 31 | func (n *NullUint8) MarshalJSON() ([]byte, error) { return marshalJSON(n.Uint8) } 32 | func (n *NullInt16) MarshalJSON() ([]byte, error) { return marshalJSON(n.Int16) } 33 | func (n *NullUint16) MarshalJSON() ([]byte, error) { return marshalJSON(n.Uint16) } 34 | func (n *NullInt32) MarshalJSON() ([]byte, error) { return marshalJSON(n.Int32) } 35 | func
(n *NullUint32) MarshalJSON() ([]byte, error) { return marshalJSON(n.Uint32) } 36 | func (n *NullInt64) MarshalJSON() ([]byte, error) { return marshalJSON(n.Int64) } 37 | func (n *NullUint64) MarshalJSON() ([]byte, error) { return marshalJSON(n.Uint64) } 38 | 39 | func (n *NullInt) UnmarshalJSON(data []byte) error { return unmarshalJSON(n.Int, data) } 40 | func (n *NullInt8) UnmarshalJSON(data []byte) error { return unmarshalJSON(n.Int8, data) } 41 | func (n *NullUint8) UnmarshalJSON(data []byte) error { return unmarshalJSON(n.Uint8, data) } 42 | func (n *NullInt16) UnmarshalJSON(data []byte) error { return unmarshalJSON(n.Int16, data) } 43 | func (n *NullUint16) UnmarshalJSON(data []byte) error { return unmarshalJSON(n.Uint16, data) } 44 | func (n *NullInt32) UnmarshalJSON(data []byte) error { return unmarshalJSON(n.Int32, data) } 45 | func (n *NullUint32) UnmarshalJSON(data []byte) error { return unmarshalJSON(n.Uint32, data) } 46 | func (n *NullInt64) UnmarshalJSON(data []byte) error { return unmarshalJSON(n.Int64, data) } 47 | func (n *NullUint64) UnmarshalJSON(data []byte) error { return unmarshalJSON(n.Uint64, data) } 48 | 49 | func (n *NullInt) String() string { return toString(n.Int) } 50 | func (n *NullInt8) String() string { return toString(n.Int8) } 51 | func (n *NullUint8) String() string { return toString(n.Uint8) } 52 | func (n *NullInt16) String() string { return toString(n.Int16) } 53 | func (n *NullUint16) String() string { return toString(n.Uint16) } 54 | func (n *NullInt32) String() string { return toString(n.Int32) } 55 | func (n *NullUint32) String() string { return toString(n.Uint32) } 56 | func (n *NullInt64) String() string { return toString(n.Int64) } 57 | func (n *NullUint64) String() string { return toString(n.Uint64) } 58 | 59 | func (n *NullInt) Scan(value interface{}) error { return scan(n.Int, value) } 60 | func (n *NullInt8) Scan(value interface{}) error { return scan(n.Int8, value) } 61 | func (n *NullUint8) Scan(value interface{}) 
error { return scan(n.Uint8, value) } 62 | func (n *NullInt16) Scan(value interface{}) error { return scan(n.Int16, value) } 63 | func (n *NullUint16) Scan(value interface{}) error { return scan(n.Uint16, value) } 64 | func (n *NullInt32) Scan(value interface{}) error { return scan(n.Int32, value) } 65 | func (n *NullUint32) Scan(value interface{}) error { return scan(n.Uint32, value) } 66 | func (n *NullInt64) Scan(value interface{}) error { return scan(n.Int64, value) } 67 | func (n *NullUint64) Scan(value interface{}) error { return scan(n.Uint64, value) } 68 | 69 | func (n NullInt) Value() (driver.Value, error) { return value(n.Int) } 70 | func (n NullInt8) Value() (driver.Value, error) { return value(n.Int8) } 71 | func (n NullUint8) Value() (driver.Value, error) { return value(n.Uint8) } 72 | func (n NullInt16) Value() (driver.Value, error) { return value(n.Int16) } 73 | func (n NullUint16) Value() (driver.Value, error) { return value(n.Uint16) } 74 | func (n NullInt32) Value() (driver.Value, error) { return value(n.Int32) } 75 | func (n NullUint32) Value() (driver.Value, error) { return value(n.Uint32) } 76 | func (n NullInt64) Value() (driver.Value, error) { return value(n.Int64) } 77 | func (n NullUint64) Value() (driver.Value, error) { return value(n.Uint64) } 78 | 79 | func toString(ptr interface{}) string { 80 | if ptr == nil { 81 | return "nil" 82 | } 83 | 84 | i64 := toInt64(ptr) 85 | if i64 == 0 { 86 | return "nil" 87 | } 88 | 89 | return strconv.FormatInt(i64, 10) 90 | } 91 | 92 | func value(ptr interface{}) (driver.Value, error) { 93 | if ptr == nil { 94 | return nil, nil 95 | } 96 | 97 | i64 := toInt64(ptr) 98 | if i64 == 0 { 99 | return nil, nil 100 | } 101 | return i64, nil 102 | } 103 | 104 | func toInt64(ptr interface{}) (i64 int64) { if rv := reflect.ValueOf(ptr); !rv.IsValid() || (rv.Kind() == reflect.Ptr && rv.IsNil()) { return 0 } // callers pass typed pointers (n.Int, ...), so their `ptr == nil` checks miss a nil field; treat nil as 0 instead of dereferencing it below 105 | switch v := ptr.(type) { 106 | case *int8: 107 | i64 = int64(*v) 108 | case *uint8: 109 | i64 = int64(*v) 110 | // case *byte: 111 | // i64 = int64(*v) 112 | case *int16: 113 | i64 = int64(*v) 114 | case
*uint16: 115 | i64 = int64(*v) 116 | case *int32: 117 | i64 = int64(*v) 118 | case *uint32: 119 | i64 = int64(*v) 120 | case *int: 121 | i64 = int64(*v) 122 | // case *rune: 123 | // i64 = int64(*v) 124 | case *int64: 125 | i64 = *v 126 | case *uint64: 127 | i64 = int64(*v) 128 | 129 | default: 130 | panic("unsupported type " + reflect.TypeOf(v).String()) 131 | } 132 | return 133 | } 134 | 135 | func scan(ptr, value interface{}) error { 136 | if value == nil { 137 | return nil 138 | } 139 | 140 | var i64 int64 141 | switch v := value.(type) { 142 | case int: 143 | i64 = int64(v) 144 | case int64: 145 | i64 = v 146 | case *int64: 147 | i64 = *v 148 | case uint64: 149 | i64 = int64(v) 150 | case *uint64: 151 | i64 = int64(*v) 152 | case int8: 153 | i64 = int64(v) 154 | case *int8: 155 | i64 = int64(*v) 156 | case uint8: 157 | i64 = int64(v) 158 | case *uint8: 159 | i64 = int64(*v) 160 | case int16: 161 | i64 = int64(v) 162 | case *int16: 163 | i64 = int64(*v) 164 | case uint16: 165 | i64 = int64(v) 166 | case *uint16: 167 | i64 = int64(*v) 168 | case int32: 169 | i64 = int64(v) 170 | case *int32: 171 | i64 = int64(*v) 172 | case uint32: 173 | i64 = int64(v) 174 | case *uint32: 175 | i64 = int64(*v) 176 | case []uint8: 177 | var err error 178 | i64, err = strconv.ParseInt(string(v), 10, 64) 179 | if err != nil { 180 | return err 181 | } 182 | default: 183 | panic("unsupported type " + reflect.TypeOf(v).String()) 184 | } 185 | 186 | fromI64(ptr, i64) 187 | 188 | return nil 189 | } 190 | 191 | func isEmpty(ptr interface{}) bool { 192 | if ptr == nil { 193 | return true 194 | } 195 | return toInt64(ptr) == 0 196 | } 197 | 198 | func marshalJSON(ptr interface{}) ([]byte, error) { 199 | if ptr == nil { 200 | return []byte{'0'}, nil 201 | } 202 | i64 := toInt64(ptr) 203 | return []byte(strconv.FormatInt(i64, 10)), nil 204 | } 205 | 206 | func unmarshalJSON(ptr interface{}, data []byte) error { 207 | if data == nil || string(data) == "null" { // JSON null leaves the target untouched (encoding/json Unmarshaler convention); ParseInt("null") would fail 208 | return nil 209 | } 210 | i64, err :=
strconv.ParseInt(string(data), 10, 64) 211 | if err != nil { 212 | return err 213 | } 214 | 215 | fromI64(ptr, i64) 216 | return nil 217 | } 218 | 219 | func fromI64(ptr interface{}, i64 int64) { 220 | switch v := ptr.(type) { 221 | case *int: 222 | *v = int(i64) 223 | case *int64: 224 | *v = i64 225 | case *uint64: 226 | *v = uint64(i64) 227 | case *int8: 228 | *v = int8(i64) 229 | case *uint8: 230 | *v = uint8(i64) 231 | // case *byte: 232 | // *v = byte(i64) 233 | case *int16: 234 | *v = int16(i64) 235 | case *uint16: 236 | *v = uint16(i64) 237 | case *int32: 238 | *v = int32(i64) 239 | case *uint32: 240 | *v = uint32(i64) 241 | // case *rune: 242 | // *v = rune(i64) 243 | 244 | default: 245 | panic("unsupported type " + reflect.TypeOf(v).String()) 246 | } 247 | } 248 | -------------------------------------------------------------------------------- /generator/template.go: -------------------------------------------------------------------------------- 1 | package generator 2 | 3 | const tpl = ` 4 | {{- /*************** header template *****************/}} 5 | {{define "header" -}} 6 | // !!! DO NOT EDIT THIS FILE. It is generated by 'light' tool. 
7 | // @light: https://github.com/omigo/light 8 | // Generated from source: {{.Source}} 9 | package {{.Package}} 10 | import ( 11 | "bytes" 12 | "fmt" 13 | "github.com/omigo/light/light" 14 | "github.com/omigo/light/null" 15 | {{- if .Log }} 16 | "github.com/omigo/log" 17 | {{- end}} 18 | 19 | {{- range $path, $short := .Imports}} 20 | {{/* $short */}} "{{$path}}" 21 | {{- end}} 22 | ) 23 | 24 | {{if .VarName}} 25 | func init() { {{.VarName}} = new(Store{{.Name}}) } 26 | {{end}} 27 | 28 | type Store{{.Name}} struct{} 29 | {{end}} 30 | 31 | {{- /*************** fragment template *****************/}} 32 | {{define "fragment" -}} 33 | {{- if .Fragment.Condition}} 34 | if {{.Fragment.Condition}} { 35 | {{- end }} 36 | {{- if .Fragment.Statement }} 37 | {{- if .Fragment.Range }} 38 | if len({{.Fragment.Range}}) > 0 { 39 | {{- if .Buf}} 40 | fmt.Fprintf(&{{.Buf}}, "{{.Fragment.Statement}} ", strings.Repeat(",?", len({{.Fragment.Range}}))[1:]) 41 | {{- end}} 42 | {{- if .Args}} 43 | for _, v := range {{.Fragment.Range}} { 44 | {{.Args}} = append({{.Args}}, v) 45 | } 46 | {{- end}} 47 | } 48 | {{- else if .Fragment.Replacers }} 49 | {{- if .Buf}} 50 | fmt.Fprintf(&{{.Buf}}, "{{.Fragment.Statement}} "{{range $elem := .Fragment.Replacers}}, {{$elem}}{{end}}) 51 | {{- end}} 52 | {{- else }} 53 | {{- if .Buf}} 54 | {{.Buf}}.WriteString("{{.Fragment.Statement}} ") 55 | {{- end}} 56 | {{- end }} 57 | {{- if .Fragment.Variables }} 58 | {{- if .Args}} 59 | {{.Args}} = append({{.Args}}{{range $elem := .Fragment.Variables}}, {{LookupValueOfParams $.Method $elem}}{{end}}) 60 | {{- end}} 61 | {{- end }} 62 | {{- else }} 63 | {{- range $fragment := .Fragment.Fragments }} 64 | {{- template "fragment" (aggregate $.Method $fragment $.Buf $.Args)}} 65 | {{- end }} 66 | {{- end }} 67 | {{- if .Fragment.Condition}} 68 | } 69 | {{- end }} 70 | {{- end}} 71 | 72 | 73 | {{- /*************** ddl template *****************/}} 74 | {{define "ddl" -}} 75 | query := buf.String() 76 | {{if 
.Interface.Log -}} 77 | log.Debug(query) 78 | {{if HasVariable $ -}} 79 | log.Debug(args...) 80 | {{end -}} 81 | {{end -}} 82 | ctx, cancel := context.WithTimeout(context.Background(), {{.Interface.Timeout}}*time.Second) 83 | defer cancel() 84 | _, err := exec.ExecContext(ctx, query{{if HasVariable $ }}, args...{{end}}) 85 | {{if .Interface.Log -}} 86 | if err != nil { 87 | log.Error(query) 88 | {{if HasVariable $ -}} 89 | log.Error(args...) 90 | {{end -}} 91 | log.Error(err) 92 | } 93 | {{end -}} 94 | return err 95 | {{end}} 96 | 97 | {{- /*************** update/delete template *****************/}} 98 | {{define "update" -}} 99 | query := buf.String() 100 | {{if .Interface.Log -}} 101 | log.Debug(query) 102 | {{if HasVariable $ -}} 103 | log.Debug(args...) 104 | {{end -}} 105 | {{end -}} 106 | ctx, cancel := context.WithTimeout(context.Background(), {{.Interface.Timeout}}*time.Second) 107 | defer cancel() 108 | res, err := exec.ExecContext(ctx, query{{if HasVariable $ }}, args...{{end}}) 109 | if err != nil { 110 | {{if .Interface.Log -}} 111 | log.Error(query) 112 | {{if HasVariable $ -}} 113 | log.Error(args...) 114 | {{end -}} 115 | log.Error(err) 116 | {{end -}} 117 | return 0, err 118 | } 119 | return res.RowsAffected() 120 | {{end -}} 121 | 122 | {{- /*************** insert template *****************/}} 123 | {{define "insert" -}} 124 | query := buf.String() 125 | {{if .Interface.Log -}} 126 | log.Debug(query) 127 | {{if HasVariable $ -}} 128 | log.Debug(args...) 129 | {{end -}} 130 | {{end -}} 131 | ctx, cancel := context.WithTimeout(context.Background(), {{.Interface.Timeout}}*time.Second) 132 | defer cancel() 133 | res, err := exec.ExecContext(ctx, query{{if HasVariable $ }}, args...{{end}}) 134 | if err != nil { 135 | {{if .Interface.Log -}} 136 | log.Error(query) 137 | {{if HasVariable $ -}} 138 | log.Error(args...) 
139 | {{end -}} 140 | log.Error(err) 141 | {{end -}} 142 | return 0, err 143 | } 144 | return res.LastInsertId() 145 | {{end}} 146 | 147 | 148 | {{- /*************** bulky template *****************/}} 149 | {{define "bulky" -}} 150 | xn := int64(len({{ParamsLast .Params}})) 151 | if xn == 0 { 152 | return 0, 0, nil 153 | } 154 | 155 | var xaffect, xignore int64 156 | var buf bytes.Buffer 157 | 158 | {{- range $i, $fragment := .Statement.Fragments }} 159 | {{template "fragment" (aggregate $ $fragment "buf" "")}} 160 | {{- end }} 161 | 162 | query := buf.String() 163 | {{/* log calls must be guarded: the header imports the log package only when .Log is set */}}{{if .Interface.Log}}log.Debug(query){{end}} 164 | 165 | {{- $tx := MethodTx $ -}} 166 | {{- if $tx}} 167 | {{- if eq $tx "tx"}}{{else}} 168 | var tx = {{$tx}} 169 | {{- end}} 170 | {{- else}} 171 | tx, err := db.Begin() 172 | if err != nil { 173 | {{if .Interface.Log -}} 174 | log.Error(err) 175 | {{end -}} 176 | return 0, xn, err 177 | } 178 | defer tx.Rollback() 179 | {{- end}} 180 | 181 | stmt, err := tx.Prepare(query) 182 | if err != nil { 183 | {{if .Interface.Log -}} 184 | log.Error(query, err) 185 | {{end -}} 186 | return 0, xn, err 187 | } 188 | var args []interface{} 189 | for _, {{ParamsLastElem .Params}} := range {{ParamsLast .Params}} { 190 | args = args[:0] 191 | {{- range $i, $fragment := .Statement.Fragments }} 192 | {{- template "fragment" (aggregate $ $fragment "" "args")}} 193 | {{- end }} 194 | {{if .Interface.Log}}log.Debug(args...){{end}} 195 | if _, err := stmt.Exec(args...); err != nil { 196 | xignore++ 197 | {{if .Interface.Log -}} 198 | log.Error(args...) 199 | log.Error(err) 200 | {{end -}} 201 | } else { 202 | xaffect++ 203 | } 204 | } 205 | {{- if not $tx}} 206 | if err := tx.Commit(); err != nil { 207 | return 0, xn, err 208 | } 209 | {{- end}} 210 | 211 | return xaffect, xignore, nil 212 | {{end}} 213 | 214 | {{- /*************** get template *****************/}} 215 | {{define "get" -}} 216 | query := buf.String() 217 | {{if .Interface.Log -}} 218 | log.Debug(query) 219 | {{if HasVariable $ -}} 220 | log.Debug(args...)
221 | {{end -}} 222 | {{end -}} 223 | ctx, cancel := context.WithTimeout(context.Background(), {{.Interface.Timeout}}*time.Second) 224 | defer cancel() 225 | row := exec.QueryRowContext(ctx, query{{if HasVariable $ }}, args...{{end}}) 226 | xu := new({{ResultElemTypeName .Results.Result}}) 227 | xdst := []interface{}{ 228 | {{- range $i, $field := .Statement.Fields -}} 229 | {{- if $i -}} , {{- end -}} 230 | {{- LookupScanOfResults $ $field -}} 231 | {{- end -}} 232 | } 233 | err := row.Scan(xdst...) 234 | if err != nil { 235 | if err == sql.ErrNoRows { 236 | return nil, nil 237 | } 238 | {{if .Interface.Log -}} 239 | log.Error(query) 240 | {{if HasVariable $ -}} 241 | log.Error(args...) 242 | {{end -}} 243 | log.Error(err) 244 | {{end -}} 245 | return nil, err 246 | } 247 | {{if .Interface.Log -}} 248 | log.Trace(xdst) 249 | {{end -}} 250 | return xu, err 251 | {{end}} 252 | 253 | {{- /*************** list template *****************/}} 254 | {{define "list" -}} 255 | query := buf.String() 256 | {{if .Interface.Log -}} 257 | log.Debug(query) 258 | {{if HasVariable $ -}} 259 | log.Debug(args...) 260 | {{end -}} 261 | {{end -}} 262 | ctx, cancel := context.WithTimeout(context.Background(), {{.Interface.Timeout}}*time.Second) 263 | defer cancel() 264 | rows, err := exec.QueryContext(ctx, query{{if HasVariable $ }}, args...{{end}}) 265 | if err != nil { 266 | {{if .Interface.Log -}} 267 | log.Error(query) 268 | {{if HasVariable $ -}} 269 | log.Error(args...) 270 | {{end -}} 271 | log.Error(err) 272 | {{end -}} 273 | return nil, err 274 | } 275 | defer rows.Close() 276 | var data {{ResultTypeName .Results.Result}} 277 | for rows.Next() { 278 | xu := new({{ ResultElemTypeName .Results.Result }}) 279 | data = append(data, xu) 280 | xdst := []interface{}{ 281 | {{- range $i, $field := .Statement.Fields -}} 282 | {{- if $i -}} , {{- end -}} 283 | {{- LookupScanOfResults $ $field -}} 284 | {{- end -}} 285 | } 286 | err = rows.Scan(xdst...) 
287 | if err != nil { 288 | {{if .Interface.Log -}} 289 | log.Error(query) 290 | {{if HasVariable $ -}} 291 | log.Error(args...) 292 | {{end -}} 293 | log.Error(err) 294 | {{end -}} 295 | return nil, err 296 | } 297 | {{if .Interface.Log -}} 298 | log.Trace(xdst) 299 | {{end -}} 300 | } 301 | if err = rows.Err(); err != nil { 302 | {{if .Interface.Log -}} 303 | log.Error(query) 304 | {{if HasVariable $ -}} 305 | log.Error(args...) 306 | {{end -}} 307 | log.Error(err) 308 | {{end -}} 309 | return nil, err 310 | } 311 | return data, nil 312 | {{end}} 313 | 314 | {{- /*************** page template *****************/}} 315 | {{define "page" -}} 316 | var total int64 317 | totalQuery := "SELECT count(1) "+ buf.String() 318 | {{if .Interface.Log -}} 319 | log.Debug(totalQuery) 320 | {{if HasVariable $ -}} 321 | log.Debug(args...) 322 | {{end -}} 323 | {{end -}} 324 | ctx, cancel := context.WithTimeout(context.Background(), {{.Interface.Timeout}}*time.Second) 325 | defer cancel() 326 | err := exec.QueryRowContext(ctx, totalQuery{{if HasVariable $ }}, args...{{end}}).Scan(&total) 327 | if err != nil { 328 | {{if .Interface.Log -}} 329 | log.Error(totalQuery) 330 | {{if HasVariable $ -}} 331 | log.Error(args...) 332 | {{end -}} 333 | log.Error(err) 334 | {{end -}} 335 | return 0, nil, err 336 | } 337 | {{if .Interface.Log -}} 338 | log.Debug(total) 339 | {{end -}} 340 | 341 | query := xFirstBuf.String() + buf.String() + xLastBuf.String() 342 | args = append(xFirstArgs, args...) 343 | args = append(args, xLastArgs...) 344 | {{if .Interface.Log -}} 345 | log.Debug(query) 346 | {{if HasVariable $ -}} 347 | log.Debug(args...) 
348 | {{end -}} 349 | {{end -}} 350 | ctx, cancel = context.WithTimeout(context.Background(), {{.Interface.Timeout}}*time.Second) 351 | defer cancel() 352 | rows, err := exec.QueryContext(ctx, query{{if HasVariable $ }}, args...{{end}}) 353 | if err != nil { 354 | {{if .Interface.Log -}} 355 | log.Error(query) 356 | {{if HasVariable $ -}} 357 | log.Error(args...) 358 | {{end -}} 359 | log.Error(err) 360 | {{end -}} 361 | return 0, nil, err 362 | } 363 | defer rows.Close() 364 | var data {{ResultTypeName .Results.Result}} 365 | for rows.Next() { 366 | xu := new({{ ResultElemTypeName .Results.Result }}) 367 | data = append(data, xu) 368 | xdst := []interface{}{ 369 | {{- range $i, $field := .Statement.Fields -}} 370 | {{- if $i -}} , {{- end -}} 371 | {{- LookupScanOfResults $ $field -}} 372 | {{- end -}} 373 | } 374 | err = rows.Scan(xdst...) 375 | if err != nil { 376 | {{if .Interface.Log -}} 377 | log.Error(query) 378 | {{if HasVariable $ -}} 379 | log.Error(args...) 380 | {{end -}} 381 | log.Error(err) 382 | {{end -}} 383 | return 0, nil, err 384 | } 385 | {{if .Interface.Log -}} 386 | log.Trace(xdst) 387 | {{end -}} 388 | } 389 | if err = rows.Err(); err != nil { 390 | {{if .Interface.Log -}} 391 | log.Error(query) 392 | {{if HasVariable $ -}} 393 | log.Error(args...) 394 | {{end -}} 395 | log.Error(err) 396 | {{end -}} 397 | return 0, nil, err 398 | } 399 | return total, data, nil 400 | {{end}} 401 | 402 | 403 | {{- /*************** agg template *****************/}} 404 | {{define "agg" -}} 405 | query := buf.String() 406 | {{if .Interface.Log -}} 407 | log.Debug(query) 408 | {{if HasVariable $ -}} 409 | log.Debug(args...) 
410 | {{end -}} 411 | {{end -}} 412 | var xu {{ResultTypeName .Results.Result}} 413 | ctx, cancel := context.WithTimeout(context.Background(), {{.Interface.Timeout}}*time.Second) 414 | defer cancel() 415 | err := exec.QueryRowContext(ctx, query{{if HasVariable $ }}, args...{{end}}).Scan({{ResultWrap .Results.Result}}) 416 | if err != nil { 417 | if err == sql.ErrNoRows { 418 | {{- if .Interface.Log}} 419 | log.Debug(xu) 420 | {{- end}} 421 | return xu, nil 422 | } 423 | {{if .Interface.Log -}} 424 | log.Error(query) 425 | {{if HasVariable $ -}} 426 | log.Error(args...) 427 | {{end -}} 428 | log.Error(err) 429 | {{end -}} 430 | return xu, err 431 | } 432 | {{if .Interface.Log -}} 433 | log.Debug(xu) 434 | {{end -}} 435 | return xu, nil 436 | {{end}} 437 | 438 | {{- /*************** main *****************/ -}} 439 | {{template "header" . -}} 440 | {{range $method := .Methods -}} 441 | func (*Store{{$.Name}}) {{$method.Signature}} { 442 | {{- if eq $method.Type "bulky"}} 443 | {{template "bulky" $method -}} 444 | {{- else}} 445 | {{- $tx := MethodTx $method -}} 446 | var exec = {{if $tx }} light.GetExec({{$tx}}, db) {{else}} db {{end}} 447 | var buf bytes.Buffer 448 | {{if HasVariable $method -}} 449 | var args []interface{} 450 | {{end -}} 451 | 452 | {{- range $i, $fragment := .Statement.Fragments }} 453 | {{/* if type=page, return field statement and ordery by limit statement reserved */}} 454 | {{$last := sub (len $method.Statement.Fragments) 1 }} 455 | {{if and (eq $method.Type "page") (eq $i 0) }} 456 | var xFirstBuf bytes.Buffer 457 | var xFirstArgs []interface{} 458 | {{- template "fragment" (aggregate $method $fragment "xFirstBuf" "xFirstArgs")}} 459 | {{else if and (eq $method.Type "page") (eq $i $last) }} 460 | var xLastBuf bytes.Buffer 461 | var xLastArgs []interface{} 462 | {{- template "fragment" (aggregate $method $fragment "xLastBuf" "xLastArgs")}} 463 | {{else if not (and (eq $method.Type "page") (or (eq $i 0) (eq $i $last)))}} 464 | {{template 
"fragment" (aggregate $method $fragment "buf" "args")}} 465 | {{end}} 466 | {{- end }} 467 | 468 | {{- if eq $method.Type "ddl"}} 469 | {{- template "ddl" $method}} 470 | {{- else if or (eq $method.Type "update") (eq $method.Type "delete")}} 471 | {{- template "update" $method}} 472 | {{- else if eq $method.Type "insert"}} 473 | {{- template "insert" $method}} 474 | {{- else if eq $method.Type "get"}} 475 | {{- template "get" $method}} 476 | {{- else if eq $method.Type "list"}} 477 | {{- template "list" $method}} 478 | {{- else if eq $method.Type "page"}} 479 | {{- template "page" $method}} 480 | {{- else if eq $method.Type "agg"}} 481 | {{- template "agg" $method}} 482 | {{- else}} 483 | panic("unimplemented") 484 | {{- end -}} 485 | {{- end -}} 486 | } 487 | 488 | {{end}} 489 | ` 490 | -------------------------------------------------------------------------------- /example/store/user.light.go: -------------------------------------------------------------------------------- 1 | // !!! DO NOT EDIT THIS FILE. It is generated by 'light' tool. 
2 | // @light: https://github.com/omigo/light 3 | // Generated from source: light/example/store/user.go 4 | package store 5 | 6 | import ( 7 | "bytes" 8 | "context" 9 | "database/sql" 10 | "fmt" 11 | "strings" 12 | "time" 13 | 14 | "github.com/omigo/light/example/model" 15 | "github.com/omigo/light/light" 16 | "github.com/omigo/light/null" 17 | "github.com/omigo/log" 18 | ) 19 | 20 | func init() { User = new(StoreIUser) } 21 | 22 | type StoreIUser struct{} 23 | 24 | func (*StoreIUser) Create(name string) error { 25 | var exec = db 26 | var buf bytes.Buffer 27 | 28 | fmt.Fprintf(&buf, "CREATE TABLE if NOT EXISTS %v ( id BIGINT UNSIGNED AUTO_INCREMENT PRIMARY KEY, username VARCHAR(32) NOT NULL UNIQUE, Phone VARCHAR(32), address VARCHAR(256), _status TINYINT UNSIGNED, birth_day DATE, created TIMESTAMP default CURRENT_TIMESTAMP, updated TIMESTAMP default CURRENT_TIMESTAMP ) ENGINE=InnoDB DEFAULT CHARSET=utf8 ", name) 29 | query := buf.String() 30 | log.Debug(query) 31 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 32 | defer cancel() 33 | _, err := exec.ExecContext(ctx, query) 34 | if err != nil { 35 | log.Error(query) 36 | log.Error(err) 37 | } 38 | return err 39 | } 40 | 41 | func (*StoreIUser) Insert(tx *sql.Tx, u *model.User) (int64, error) { 42 | var exec = light.GetExec(tx, db) 43 | var buf bytes.Buffer 44 | var args []interface{} 45 | 46 | buf.WriteString("INSERT IGNORE INTO users(`username`,phone,address,_status,birth_day,created,updated) VALUES (?,?,?,?,?,CURRENT_TIMESTAMP,CURRENT_TIMESTAMP) ") 47 | args = append(args, u.Username, null.String(&u.Phone), u.Address, u.Status, u.BirthDay) 48 | query := buf.String() 49 | log.Debug(query) 50 | log.Debug(args...) 51 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 52 | defer cancel() 53 | res, err := exec.ExecContext(ctx, query, args...) 54 | if err != nil { 55 | log.Error(query) 56 | log.Error(args...) 
57 | log.Error(err) 58 | return 0, err 59 | } 60 | return res.LastInsertId() 61 | } 62 | 63 | func (*StoreIUser) Bulky(us []*model.User) (int64, int64, error) { 64 | xn := int64(len(us)) 65 | if xn == 0 { 66 | return 0, 0, nil 67 | } 68 | 69 | var xaffect, xignore int64 70 | var buf bytes.Buffer 71 | 72 | buf.WriteString("INSERT IGNORE INTO users(`username`,phone,address,_status,birth_day,created,updated) VALUES (?,?,?,?,?,CURRENT_TIMESTAMP,CURRENT_TIMESTAMP) ") 73 | 74 | query := buf.String() 75 | log.Debug(query) 76 | tx, err := db.Begin() 77 | if err != nil { 78 | log.Error(err) 79 | return 0, xn, err 80 | } 81 | defer tx.Rollback() 82 | 83 | stmt, err := tx.Prepare(query) 84 | if err != nil { 85 | log.Error(query, err) 86 | return 0, xn, err 87 | } 88 | var args []interface{} 89 | for _, u := range us { 90 | args = args[:0] 91 | args = append(args, u.Username, null.String(&u.Phone), u.Address, u.Status, u.BirthDay) 92 | log.Debug(args...) 93 | if _, err := stmt.Exec(args...); err != nil { 94 | xignore++ 95 | log.Error(args...) 96 | log.Error(err) 97 | } else { 98 | xaffect++ 99 | } 100 | } 101 | if err := tx.Commit(); err != nil { 102 | return 0, xn, err 103 | } 104 | 105 | return xaffect, xignore, nil 106 | } 107 | 108 | func (*StoreIUser) Upsert(u *model.User, tx *sql.Tx) (int64, error) { 109 | var exec = light.GetExec(tx, db) 110 | var buf bytes.Buffer 111 | var args []interface{} 112 | 113 | buf.WriteString("INSERT INTO users(username,phone,address,_status,birth_day,created,updated) VALUES (?,?,?,?,?,CURRENT_TIMESTAMP,CURRENT_TIMESTAMP) ON DUPLICATE KEY UPDATE username=VALUES(username), phone=VALUES(phone), address=VALUES(address), _status=VALUES(_status), birth_day=VALUES(birth_day), updated=CURRENT_TIMESTAMP ") 114 | args = append(args, u.Username, null.String(&u.Phone), u.Address, u.Status, u.BirthDay) 115 | query := buf.String() 116 | log.Debug(query) 117 | log.Debug(args...) 
118 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 119 | defer cancel() 120 | res, err := exec.ExecContext(ctx, query, args...) 121 | if err != nil { 122 | log.Error(query) 123 | log.Error(args...) 124 | log.Error(err) 125 | return 0, err 126 | } 127 | return res.LastInsertId() 128 | } 129 | 130 | func (*StoreIUser) Replace(u *model.User) (int64, error) { 131 | var exec = db 132 | var buf bytes.Buffer 133 | var args []interface{} 134 | 135 | buf.WriteString("REPLACE INTO users(username,phone,address,_status,birth_day,created,updated) VALUES (?,?,?,?,?,CURRENT_TIMESTAMP,CURRENT_TIMESTAMP) ") 136 | args = append(args, u.Username, null.String(&u.Phone), u.Address, u.Status, u.BirthDay) 137 | query := buf.String() 138 | log.Debug(query) 139 | log.Debug(args...) 140 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 141 | defer cancel() 142 | res, err := exec.ExecContext(ctx, query, args...) 143 | if err != nil { 144 | log.Error(query) 145 | log.Error(args...) 146 | log.Error(err) 147 | return 0, err 148 | } 149 | return res.LastInsertId() 150 | } 151 | 152 | func (*StoreIUser) Update(u *model.User) (int64, error) { 153 | var exec = db 154 | var buf bytes.Buffer 155 | var args []interface{} 156 | 157 | buf.WriteString("UPDATE users SET ") 158 | 159 | if u.Username != "" { 160 | buf.WriteString("username=?, ") 161 | args = append(args, u.Username) 162 | } 163 | 164 | if u.Phone != "" { 165 | buf.WriteString("phone=?, ") 166 | args = append(args, null.String(&u.Phone)) 167 | } 168 | 169 | if u.Address != nil { 170 | buf.WriteString("address=?, ") 171 | args = append(args, u.Address) 172 | } 173 | 174 | if u.Status != 0 { 175 | buf.WriteString("_status=?, ") 176 | args = append(args, u.Status) 177 | } 178 | 179 | if !u.BirthDay.IsZero() { 180 | buf.WriteString("birth_day=?, ") 181 | args = append(args, u.BirthDay) 182 | } 183 | 184 | buf.WriteString("updated=CURRENT_TIMESTAMP WHERE id=? 
") 185 | args = append(args, u.Id) 186 | query := buf.String() 187 | log.Debug(query) 188 | log.Debug(args...) 189 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 190 | defer cancel() 191 | res, err := exec.ExecContext(ctx, query, args...) 192 | if err != nil { 193 | log.Error(query) 194 | log.Error(args...) 195 | log.Error(err) 196 | return 0, err 197 | } 198 | return res.RowsAffected() 199 | } 200 | 201 | func (*StoreIUser) Delete(id uint64) (int64, error) { 202 | var exec = db 203 | var buf bytes.Buffer 204 | var args []interface{} 205 | 206 | buf.WriteString("DELETE FROM users WHERE id=? ") 207 | args = append(args, id) 208 | query := buf.String() 209 | log.Debug(query) 210 | log.Debug(args...) 211 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 212 | defer cancel() 213 | res, err := exec.ExecContext(ctx, query, args...) 214 | if err != nil { 215 | log.Error(query) 216 | log.Error(args...) 217 | log.Error(err) 218 | return 0, err 219 | } 220 | return res.RowsAffected() 221 | } 222 | 223 | func (*StoreIUser) Get(id uint64) (*model.User, error) { 224 | var exec = db 225 | var buf bytes.Buffer 226 | var args []interface{} 227 | 228 | buf.WriteString("SELECT id, username, mobile, address, _status, birth_day, created, updated ") 229 | 230 | buf.WriteString("FROM users WHERE id=? ") 231 | args = append(args, id) 232 | query := buf.String() 233 | log.Debug(query) 234 | log.Debug(args...) 235 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 236 | defer cancel() 237 | row := exec.QueryRowContext(ctx, query, args...) 238 | xu := new(model.User) 239 | xdst := []interface{}{&xu.Id, &xu.Username, null.String(&xu.Phone), &xu.Address, &xu.Status, &xu.BirthDay, &xu.Created, &xu.Updated} 240 | err := row.Scan(xdst...) 241 | if err != nil { 242 | if err == sql.ErrNoRows { 243 | return nil, nil 244 | } 245 | log.Error(query) 246 | log.Error(args...) 
247 | log.Error(err) 248 | return nil, err 249 | } 250 | log.Trace(xdst) 251 | return xu, err 252 | } 253 | 254 | func (*StoreIUser) Count(birthDay time.Time) (int64, error) { 255 | var exec = db 256 | var buf bytes.Buffer 257 | var args []interface{} 258 | 259 | buf.WriteString("SELECT count(1) ") 260 | 261 | buf.WriteString("FROM users WHERE birth_day < ? ") 262 | args = append(args, birthDay) 263 | query := buf.String() 264 | log.Debug(query) 265 | log.Debug(args...) 266 | var xu int64 267 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 268 | defer cancel() 269 | err := exec.QueryRowContext(ctx, query, args...).Scan(null.Int64(&xu)) 270 | if err != nil { 271 | if err == sql.ErrNoRows { 272 | log.Debug(xu) 273 | return xu, nil 274 | } 275 | log.Error(query) 276 | log.Error(args...) 277 | log.Error(err) 278 | return xu, err 279 | } 280 | log.Debug(xu) 281 | return xu, nil 282 | } 283 | 284 | func (*StoreIUser) List(u *model.User, offset int, size int) ([]*model.User, error) { 285 | var exec = db 286 | var buf bytes.Buffer 287 | var args []interface{} 288 | 289 | buf.WriteString("SELECT (SELECT id FROM users WHERE id=a.id) AS id, `username`, phone AS phone, address, _status, birth_day, created, updated ") 290 | 291 | buf.WriteString("FROM users a WHERE id != -1 AND username <> 'admin' AND username LIKE ? ") 292 | args = append(args, u.Username) 293 | 294 | if (u.Phone != "") || ((u.BirthDay != nil && !u.BirthDay.IsZero()) || u.Id > 1) { 295 | buf.WriteString("AND address = ? ") 296 | args = append(args, u.Address) 297 | if u.Phone != "" { 298 | buf.WriteString("AND phone LIKE ? ") 299 | args = append(args, null.String(&u.Phone)) 300 | } 301 | buf.WriteString("AND created > ? ") 302 | args = append(args, u.Created) 303 | if (u.BirthDay != nil && !u.BirthDay.IsZero()) || u.Id > 1 { 304 | if !u.BirthDay.IsZero() { 305 | buf.WriteString("AND birth_day > ? 
") 306 | args = append(args, u.BirthDay) 307 | } 308 | if u.Id != 0 { 309 | buf.WriteString("AND id > ? ") 310 | args = append(args, u.Id) 311 | } 312 | } 313 | } 314 | 315 | buf.WriteString("AND _status != ? ") 316 | args = append(args, u.Status) 317 | 318 | if !u.Updated.IsEmpty() { 319 | buf.WriteString("AND updated > ? ") 320 | args = append(args, u.Updated) 321 | } 322 | 323 | buf.WriteString("AND birth_day IS NOT NULL ") 324 | 325 | buf.WriteString("ORDER BY updated DESC LIMIT ?, ? ") 326 | args = append(args, offset, size) 327 | query := buf.String() 328 | log.Debug(query) 329 | log.Debug(args...) 330 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 331 | defer cancel() 332 | rows, err := exec.QueryContext(ctx, query, args...) 333 | if err != nil { 334 | log.Error(query) 335 | log.Error(args...) 336 | log.Error(err) 337 | return nil, err 338 | } 339 | defer rows.Close() 340 | var data []*model.User 341 | for rows.Next() { 342 | xu := new(model.User) 343 | data = append(data, xu) 344 | xdst := []interface{}{&xu.Id, &xu.Username, null.String(&xu.Phone), &xu.Address, &xu.Status, &xu.BirthDay, &xu.Created, &xu.Updated} 345 | err = rows.Scan(xdst...) 346 | if err != nil { 347 | log.Error(query) 348 | log.Error(args...) 349 | log.Error(err) 350 | return nil, err 351 | } 352 | log.Trace(xdst) 353 | } 354 | if err = rows.Err(); err != nil { 355 | log.Error(query) 356 | log.Error(args...) 357 | log.Error(err) 358 | return nil, err 359 | } 360 | return data, nil 361 | } 362 | 363 | func (*StoreIUser) Page(u *model.User, ss []model.Status, offset int, size int) (int64, []*model.User, error) { 364 | var exec = db 365 | var buf bytes.Buffer 366 | var args []interface{} 367 | 368 | var xFirstBuf bytes.Buffer 369 | var xFirstArgs []interface{} 370 | xFirstBuf.WriteString("SELECT id, username, if(phone='', '0', phone) phone, address, _status, birth_day, created, updated ") 371 | 372 | buf.WriteString("FROM users WHERE username LIKE ? 
") 373 | args = append(args, u.Username) 374 | 375 | if u.Phone != "" { 376 | buf.WriteString("AND address = ? ") 377 | args = append(args, u.Address) 378 | if u.Phone != "" { 379 | buf.WriteString("AND phone LIKE ? ") 380 | args = append(args, null.String(&u.Phone)) 381 | } 382 | buf.WriteString("AND created > ? ") 383 | args = append(args, u.Created) 384 | } 385 | 386 | buf.WriteString("AND birth_day IS NOT NULL AND _status != ? ") 387 | args = append(args, u.Status) 388 | 389 | if len(ss) > 0 { 390 | fmt.Fprintf(&buf, "AND _status in (%v) ", strings.Repeat(",?", len(ss))[1:]) 391 | for _, v := range ss { 392 | args = append(args, v) 393 | } 394 | } 395 | 396 | if !u.Updated.IsEmpty() { 397 | buf.WriteString("AND updated > ? ") 398 | args = append(args, u.Updated) 399 | } 400 | 401 | var xLastBuf bytes.Buffer 402 | var xLastArgs []interface{} 403 | xLastBuf.WriteString("ORDER BY updated DESC LIMIT ?, ? ") 404 | xLastArgs = append(xLastArgs, offset, size) 405 | var total int64 406 | totalQuery := "SELECT count(1) " + buf.String() 407 | log.Debug(totalQuery) 408 | log.Debug(args...) 409 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 410 | defer cancel() 411 | err := exec.QueryRowContext(ctx, totalQuery, args...).Scan(&total) 412 | if err != nil { 413 | log.Error(totalQuery) 414 | log.Error(args...) 415 | log.Error(err) 416 | return 0, nil, err 417 | } 418 | log.Debug(total) 419 | query := xFirstBuf.String() + buf.String() + xLastBuf.String() 420 | args = append(xFirstArgs, args...) 421 | args = append(args, xLastArgs...) 422 | log.Debug(query) 423 | log.Debug(args...) 424 | ctx, cancel = context.WithTimeout(context.Background(), 30*time.Second) 425 | defer cancel() 426 | rows, err := exec.QueryContext(ctx, query, args...) 427 | if err != nil { 428 | log.Error(query) 429 | log.Error(args...) 
430 | log.Error(err) 431 | return 0, nil, err 432 | } 433 | defer rows.Close() 434 | var data []*model.User 435 | for rows.Next() { 436 | xu := new(model.User) 437 | data = append(data, xu) 438 | xdst := []interface{}{&xu.Id, &xu.Username, null.String(&xu.Phone), &xu.Address, &xu.Status, &xu.BirthDay, &xu.Created, &xu.Updated} 439 | err = rows.Scan(xdst...) 440 | if err != nil { 441 | log.Error(query) 442 | log.Error(args...) 443 | log.Error(err) 444 | return 0, nil, err 445 | } 446 | log.Trace(xdst) 447 | } 448 | if err = rows.Err(); err != nil { 449 | log.Error(query) 450 | log.Error(args...) 451 | log.Error(err) 452 | return 0, nil, err 453 | } 454 | return total, data, nil 455 | } 456 | --------------------------------------------------------------------------------