├── tx.go ├── interface.go ├── go.mod ├── log.go ├── query.go ├── xsqlora.sql ├── go.sum ├── util.go ├── xsql.sql ├── options.go ├── model_executor.go ├── dbora_test.go ├── db.go ├── README.md ├── fetcher.go ├── executor.go └── db_test.go /tx.go: -------------------------------------------------------------------------------- 1 | package xsql 2 | 3 | import "database/sql" 4 | 5 | type Tx struct { 6 | raw *sql.Tx 7 | *DB 8 | } 9 | 10 | func (t *Tx) Commit() error { 11 | return t.raw.Commit() 12 | } 13 | 14 | func (t *Tx) Rollback() error { 15 | return t.raw.Rollback() 16 | } 17 | -------------------------------------------------------------------------------- /interface.go: -------------------------------------------------------------------------------- 1 | package xsql 2 | 3 | import "database/sql" 4 | 5 | type Executor interface { 6 | Exec(query string, args ...interface{}) (sql.Result, error) 7 | } 8 | 9 | type Query interface { 10 | Query(query string, args ...interface{}) (*sql.Rows, error) 11 | } 12 | 13 | type Table interface { 14 | TableName() string 15 | } 16 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/mix-go/xsql 2 | 3 | go 1.24 4 | 5 | require ( 6 | github.com/go-sql-driver/mysql v1.6.0 7 | github.com/sijms/go-ora/v2 v2.9.0 8 | github.com/stretchr/testify v1.7.1 9 | google.golang.org/protobuf v1.36.10 10 | ) 11 | 12 | require ( 13 | github.com/davecgh/go-spew v1.1.0 // indirect 14 | github.com/pmezard/go-difflib v1.0.0 // indirect 15 | gopkg.in/yaml.v3 v3.0.0 // indirect 16 | ) 17 | -------------------------------------------------------------------------------- /log.go: -------------------------------------------------------------------------------- 1 | package xsql 2 | 3 | import ( 4 | "context" 5 | "time" 6 | ) 7 | 8 | type Log struct { 9 | Context context.Context `json:"context"` 10 | Duration 
time.Duration `json:"duration"` 11 | SQL string `json:"sql"` 12 | Bindings []interface{} `json:"bindings"` 13 | RowsAffected int64 `json:"rowsAffected"` 14 | Error error `json:"error"` 15 | } 16 | 17 | type DebugFunc func(l *Log) 18 | -------------------------------------------------------------------------------- /query.go: -------------------------------------------------------------------------------- 1 | package xsql 2 | 3 | import ( 4 | "context" 5 | "time" 6 | ) 7 | 8 | type query struct { 9 | Query 10 | } 11 | 12 | func (t *query) Fetch(ctx context.Context, query string, args []interface{}, opts *sqlOptions) (*Fetcher, error) { 13 | startTime := time.Now() 14 | r, err := t.Query.Query(query, args...) 15 | l := &Log{ 16 | Context: ctx, 17 | Duration: time.Now().Sub(startTime), 18 | SQL: query, 19 | Bindings: args, 20 | RowsAffected: 0, 21 | Error: err, 22 | } 23 | if err != nil { 24 | opts.doDebug(l) 25 | return nil, err 26 | } 27 | 28 | f := &Fetcher{ 29 | r: r, 30 | log: l, 31 | options: opts, 32 | } 33 | return f, err 34 | } 35 | -------------------------------------------------------------------------------- /xsqlora.sql: -------------------------------------------------------------------------------- 1 | -- ---------------------------- 2 | -- Table structure for XSQL 3 | -- ---------------------------- 4 | DROP TABLE "TEST"."XSQL"; 5 | CREATE TABLE "TEST"."XSQL" ( 6 | "ID" NUMBER VISIBLE NOT NULL, 7 | "FOO" VARCHAR2(255 BYTE) VISIBLE, 8 | "BAR" TIMESTAMP(6) VISIBLE 9 | ) 10 | LOGGING 11 | NOCOMPRESS 12 | PCTFREE 10 13 | INITRANS 1 14 | STORAGE ( 15 | INITIAL 1048576 16 | NEXT 1048576 17 | MINEXTENTS 1 18 | MAXEXTENTS 2147483645 19 | BUFFER_POOL DEFAULT 20 | ) 21 | PARALLEL 1 22 | NOCACHE 23 | DISABLE ROW MOVEMENT 24 | ; 25 | 26 | -- ---------------------------- 27 | -- Records of XSQL 28 | -- ---------------------------- 29 | INSERT INTO "TEST"."XSQL" ("ID", "FOO", "BAR") VALUES ('1', 'v', TO_TIMESTAMP('2022-04-14 23:49:48.000000', 'SYYYY-MM-DD 
HH24:MI:SS:FF6')); 30 | INSERT INTO "TEST"."XSQL" ("ID", "FOO", "BAR") VALUES ('2', 'v1', TO_TIMESTAMP('2022-04-14 23:50:00.000000', 'SYYYY-MM-DD HH24:MI:SS:FF6')); 31 | COMMIT; 32 | COMMIT; 33 | 34 | -- ---------------------------- 35 | -- Primary Key structure for table XSQL 36 | -- ---------------------------- 37 | ALTER TABLE "TEST"."XSQL" ADD CONSTRAINT "SYS_C00214001" PRIMARY KEY ("ID"); 38 | 39 | -- ---------------------------- 40 | -- Checks structure for table XSQL 41 | -- ---------------------------- 42 | ALTER TABLE "TEST"."XSQL" ADD CONSTRAINT "SYS_C00214000" CHECK ("ID" IS NOT NULL) NOT DEFERRABLE INITIALLY IMMEDIATE NORELY VALIDATE; 43 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 2 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= 4 | github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= 5 | github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= 6 | github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= 7 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 8 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 9 | github.com/sijms/go-ora/v2 v2.9.0 h1:+iQbUeTeCOFMb5BsOMgUhV8KWyrv9yjKpcK4x7+MFrg= 10 | github.com/sijms/go-ora/v2 v2.9.0/go.mod h1:QgFInVi3ZWyqAiJwzBQA+nbKYKH77tdp1PYoCqhR2dU= 11 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 12 | github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= 13 | github.com/stretchr/testify v1.7.1/go.mod 
h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 14 | google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= 15 | google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= 16 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 17 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 18 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 19 | gopkg.in/yaml.v3 v3.0.0 h1:hjy8E9ON/egN1tAYqKb61G10WtihqetD4sz2H+8nIeA= 20 | gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 21 | -------------------------------------------------------------------------------- /util.go: -------------------------------------------------------------------------------- 1 | package xsql 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | ) 7 | 8 | type TagValues []TagValue 9 | 10 | type TagValue struct { 11 | Key interface{} 12 | Value interface{} 13 | } 14 | 15 | // TagValuesMap takes a tag key, a pointer to a struct, and TagValues. 16 | // It constructs a map where each key is the struct field's tag value, paired with the corresponding value from TagValues. 
17 | func TagValuesMap(tagKey string, ptr interface{}, values TagValues) (map[string]interface{}, error) { 18 | result := make(map[string]interface{}) 19 | structValue := reflect.ValueOf(ptr).Elem() 20 | 21 | if structValue.Kind() != reflect.Struct { 22 | return nil, fmt.Errorf("xsql: ptr must be a pointer to a struct") 23 | } 24 | 25 | fieldsMap := make(map[string]reflect.Value) 26 | populateFieldsMap(tagKey, structValue, fieldsMap) 27 | 28 | for i, tagValue := range values { 29 | fieldPtr, fieldValue := tagValue.Key, reflect.ValueOf(tagValue.Key) 30 | if fieldValue.Kind() != reflect.Ptr || fieldValue.IsNil() { 31 | return nil, fmt.Errorf("xsql: error at item %d in values slice: key is not a non-nil pointer to a struct field", i) 32 | } 33 | 34 | foundFieldName := "" 35 | for tagName, field := range fieldsMap { 36 | if field.Addr().Interface() == fieldPtr { 37 | foundFieldName = tagName 38 | break 39 | } 40 | } 41 | 42 | if foundFieldName == "" { 43 | return nil, fmt.Errorf("xsql: no matching struct field found for item %d in values slice", i) 44 | } 45 | 46 | result[foundFieldName] = tagValue.Value 47 | } 48 | 49 | return result, nil 50 | } 51 | 52 | func populateFieldsMap(tagKey string, v reflect.Value, fieldsMap map[string]reflect.Value) { 53 | for i := 0; i < v.NumField(); i++ { 54 | field := v.Field(i) 55 | fieldType := v.Type().Field(i) 56 | tag := fieldType.Tag.Get(tagKey) 57 | if fieldType.Anonymous && field.Type().Kind() == reflect.Struct { 58 | populateFieldsMap(tagKey, field, fieldsMap) 59 | } else if tag != "" { 60 | fieldsMap[tag] = field 61 | } 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /xsql.sql: -------------------------------------------------------------------------------- 1 | DROP TABLE IF EXISTS `xsql`; 2 | CREATE TABLE `xsql` ( 3 | `id` int unsigned NOT NULL AUTO_INCREMENT, 4 | `foo` varchar(255) DEFAULT NULL, 5 | `bar` datetime DEFAULT NULL, 6 | `bool` int NOT NULL DEFAULT '0', 7 | 
`enum` int NOT NULL DEFAULT '0', 8 | `json` json DEFAULT NULL, 9 | PRIMARY KEY (`id`) 10 | ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci; 11 | INSERT INTO `xsql` (`id`, `foo`, `bar`, `bool`, `enum`, `json`) VALUES (1, 'v', '2022-04-12 23:50:00', 1, 1, '{"foo":"bar"}'); 12 | INSERT INTO `xsql` (`id`, `foo`, `bar`, `bool`, `enum`, `json`) VALUES (2, 'v1', '2022-04-13 23:50:00', 1, 1, '[1,2]'); 13 | INSERT INTO `xsql` (`id`, `foo`, `bar`, `bool`, `enum`, `json`) VALUES (3, 'v2', '2022-04-14 23:50:00', 1, 1, null); 14 | 15 | DROP TABLE IF EXISTS `devices`; 16 | CREATE TABLE `devices` ( 17 | `id` bigint unsigned NOT NULL AUTO_INCREMENT, 18 | `uuid` varchar(255) NOT NULL, 19 | `user_id` bigint unsigned NOT NULL, 20 | `platform` tinyint unsigned NOT NULL, 21 | `info` varchar(255) NOT NULL, 22 | `app` json DEFAULT NULL, 23 | `language_code` varchar(255) NOT NULL, 24 | `status` tinyint unsigned NOT NULL, 25 | `synced_message_id` bigint unsigned NOT NULL DEFAULT '0', 26 | `firebase_token` varchar(255) NOT NULL DEFAULT '', 27 | `last_used_at` timestamp NOT NULL, 28 | `created_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, 29 | `updated_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, 30 | PRIMARY KEY (`id`) 31 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci; 32 | INSERT INTO `devices` (`id`, `uuid`, `user_id`, `platform`, `info`, `app`, `language_code`, `status`, `synced_message_id`, `firebase_token`, `last_used_at`, `created_at`, `updated_at`) VALUES (1, '0c7f60e8-d7a3-48ae-9139-5128e336736e', 100000010, 1, 'postman', '{\"build\": 1, \"version\": \"v1.0.0\"}', '', 1, 0, '', '1970-01-01 00:00:01', '2025-04-07 07:41:23', '2025-04-07 07:41:23'); 33 | INSERT INTO `devices` (`id`, `uuid`, `user_id`, `platform`, `info`, `app`, `language_code`, `status`, `synced_message_id`, `firebase_token`, `last_used_at`, `created_at`, `updated_at`) VALUES (2, '0c7f60e8-d7a3-48ae-9139-5128e336736e', 
100000011, 1, 'postman', '{\"build\": 1, \"version\": \"v1.0.0\"}', '', 1, 0, '', '1970-01-01 00:00:01', '2025-04-07 07:41:25', '2025-04-07 07:41:25'); 34 | INSERT INTO `devices` (`id`, `uuid`, `user_id`, `platform`, `info`, `app`, `language_code`, `status`, `synced_message_id`, `firebase_token`, `last_used_at`, `created_at`, `updated_at`) VALUES (3, '0c7f60e8-d7a3-48ae-9139-5128e336736e', 100000012, 1, 'postman', '{\"build\": 1, \"version\": \"v1.0.0\"}', '', 1, 0, '', '1970-01-01 00:00:01', '2025-04-07 07:41:26', '2025-04-07 07:41:26'); 35 | -------------------------------------------------------------------------------- /options.go: -------------------------------------------------------------------------------- 1 | package xsql 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | ) 7 | 8 | var DefaultOptions = newDefaultOptions() 9 | 10 | func newDefaultOptions() sqlOptions { 11 | return sqlOptions{ 12 | Tag: "xsql", 13 | InsertKey: "INSERT INTO", 14 | TableKey: "${TABLE}", 15 | Placeholder: "?", 16 | ColumnQuotes: "`", 17 | TimeLayout: "2006-01-02 15:04:05", 18 | TimeLocation: time.Local, 19 | TimeFunc: func(placeholder string) string { 20 | return placeholder 21 | }, 22 | DebugFunc: nil, 23 | } 24 | } 25 | 26 | type sqlOptions struct { 27 | // Default: xsql 28 | Tag string 29 | 30 | // Default: INSERT INTO 31 | InsertKey string 32 | 33 | // Default: ${TABLE} 34 | TableKey string 35 | 36 | // Default: ? 
37 | // For oracle, can be configured as :%d 38 | Placeholder string 39 | 40 | // Default: ` 41 | // For oracle, can be configured as " 42 | ColumnQuotes string 43 | 44 | // Default: 2006-01-02 15:04:05 45 | TimeLayout string 46 | 47 | // Default: time.Local 48 | TimeLocation *time.Location 49 | 50 | // Default: func(placeholder string) string { return placeholder } 51 | // For oracle, this closure can be modified to add TO_TIMESTAMP 52 | TimeFunc TimeFunc 53 | 54 | // Global debug SQL 55 | DebugFunc DebugFunc 56 | } 57 | 58 | func mergeOptions(opts []SqlOption) *sqlOptions { 59 | cp := DefaultOptions // copy 60 | for _, o := range opts { 61 | o.apply(&cp) 62 | } 63 | return &cp 64 | } 65 | 66 | type SqlOption interface { 67 | apply(*sqlOptions) 68 | } 69 | 70 | type funcSqlOption struct { 71 | f func(*sqlOptions) 72 | } 73 | 74 | func (fdo *funcSqlOption) apply(do *sqlOptions) { 75 | fdo.f(do) 76 | } 77 | 78 | func WithTag(tag string) SqlOption { 79 | return &funcSqlOption{func(opts *sqlOptions) { 80 | opts.Tag = tag 81 | }} 82 | } 83 | 84 | func WithInsertKey(insertKey string) SqlOption { 85 | return &funcSqlOption{func(opts *sqlOptions) { 86 | opts.InsertKey = insertKey 87 | }} 88 | } 89 | 90 | func WithPlaceholder(placeholder string) SqlOption { 91 | return &funcSqlOption{func(opts *sqlOptions) { 92 | opts.Placeholder = placeholder 93 | }} 94 | } 95 | 96 | func WithColumnQuotes(columnQuotes string) SqlOption { 97 | return &funcSqlOption{func(opts *sqlOptions) { 98 | opts.ColumnQuotes = columnQuotes 99 | }} 100 | } 101 | 102 | func WithTimeLayout(timeLayout string) SqlOption { 103 | return &funcSqlOption{func(opts *sqlOptions) { 104 | opts.TimeLayout = timeLayout 105 | }} 106 | } 107 | 108 | func WithTimeLocation(timeLocation *time.Location) SqlOption { 109 | return &funcSqlOption{func(opts *sqlOptions) { 110 | opts.TimeLocation = timeLocation 111 | }} 112 | } 113 | 114 | func WithTimeFunc(f TimeFunc) SqlOption { 115 | return &funcSqlOption{func(opts 
*sqlOptions) { 116 | opts.TimeFunc = f 117 | }} 118 | } 119 | 120 | func WithDebugFunc(f DebugFunc) SqlOption { 121 | return &funcSqlOption{func(opts *sqlOptions) { 122 | opts.DebugFunc = f 123 | }} 124 | } 125 | 126 | func UseOracle() SqlOption { 127 | return &funcSqlOption{func(opts *sqlOptions) { 128 | opts.Placeholder = `:%d` 129 | opts.ColumnQuotes = `"` 130 | opts.TimeFunc = func(placeholder string) string { 131 | return fmt.Sprintf("TO_TIMESTAMP(%s, 'SYYYY-MM-DD HH24:MI:SS:FF6')", placeholder) 132 | } 133 | }} 134 | } 135 | 136 | func (t *sqlOptions) doDebug(l *Log) { 137 | if t.DebugFunc == nil { 138 | return 139 | } 140 | t.DebugFunc(l) 141 | } 142 | -------------------------------------------------------------------------------- /model_executor.go: -------------------------------------------------------------------------------- 1 | package xsql 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "reflect" 8 | "slices" 9 | "strings" 10 | "time" 11 | ) 12 | 13 | type ModelExecutor struct { 14 | Executor 15 | Options *sqlOptions 16 | 17 | TableName string 18 | 19 | Error error 20 | RowsAffected int64 21 | } 22 | 23 | func (t *ModelExecutor) Update(ctx context.Context, data map[string]interface{}, expr string, args ...interface{}) *ModelExecutor { 24 | return t.getExecResult(t.update(ctx, data, expr, args...)) 25 | } 26 | 27 | func (t *ModelExecutor) getExecResult(r sql.Result, err error) *ModelExecutor { 28 | if err != nil { 29 | t.Error = err 30 | return t 31 | } else { 32 | t.Error = nil 33 | } 34 | 35 | rowsAffected, err := r.RowsAffected() 36 | if err != nil { 37 | t.Error = err 38 | return t 39 | } else { 40 | t.Error = nil 41 | } 42 | t.RowsAffected = rowsAffected 43 | return t 44 | } 45 | 46 | func (t *ModelExecutor) update(ctx context.Context, data map[string]interface{}, expr string, args ...interface{}) (sql.Result, error) { 47 | set := make([]string, 0) 48 | bindArgs := make([]interface{}, 0) 49 | 50 | table := t.TableName 51 | opts := 
t.Options 52 | 53 | n := 0 54 | for key, val := range data { 55 | value := reflect.ValueOf(val) 56 | vTyp := value.Type().String() 57 | isTime := isTime(vTyp) 58 | 59 | var v string 60 | if opts.Placeholder == "?" { 61 | v = opts.Placeholder 62 | } else { 63 | v = fmt.Sprintf(opts.Placeholder, n) 64 | n++ 65 | } 66 | if isTime { 67 | v = opts.TimeFunc(v) 68 | } 69 | 70 | var a interface{} 71 | if isTime { 72 | a = formatTime(vTyp, value.Interface(), opts) 73 | } else { 74 | // 非标量用JSON序列化处理 75 | if slices.Contains([]reflect.Kind{reflect.Ptr, reflect.Struct, reflect.Slice, reflect.Array, reflect.Map}, value.Kind()) { 76 | b, e := marshal(value.Interface()) 77 | if e != nil { 78 | return nil, fmt.Errorf("json unmarshal error %s for field %s", e, key) 79 | } 80 | a = string(b) 81 | } else { 82 | a = value.Interface() 83 | } 84 | } 85 | 86 | set = append(set, fmt.Sprintf("`%s` = %s", key, v)) 87 | bindArgs = append(bindArgs, a) 88 | } 89 | 90 | where := "" 91 | if expr != "" { 92 | where = fmt.Sprintf(` WHERE %s`, expr) 93 | bindArgs = append(bindArgs, args...) 94 | } 95 | 96 | SQL := fmt.Sprintf(`UPDATE %s SET %s%s`, table, strings.Join(set, ", "), where) 97 | 98 | startTime := time.Now() 99 | res, err := t.Executor.Exec(SQL, bindArgs...) 
100 | var rowsAffected int64 101 | if res != nil { 102 | rowsAffected, _ = res.RowsAffected() 103 | } 104 | l := &Log{ 105 | Context: ctx, 106 | Duration: time.Now().Sub(startTime), 107 | SQL: SQL, 108 | Bindings: bindArgs, 109 | RowsAffected: rowsAffected, 110 | Error: err, 111 | } 112 | opts.doDebug(l) 113 | if err != nil { 114 | return nil, err 115 | } 116 | 117 | return res, nil 118 | } 119 | 120 | func (t *ModelExecutor) Delete(ctx context.Context, expr string, args ...interface{}) *ModelExecutor { 121 | return t.getExecResult(t.delete(ctx, expr, args...)) 122 | } 123 | 124 | func (t *ModelExecutor) delete(ctx context.Context, expr string, args ...interface{}) (sql.Result, error) { 125 | bindArgs := make([]interface{}, 0) 126 | 127 | table := t.TableName 128 | opts := t.Options 129 | 130 | where := "" 131 | if expr != "" { 132 | where = fmt.Sprintf(` WHERE %s`, expr) 133 | bindArgs = append(bindArgs, args...) 134 | } 135 | 136 | SQL := fmt.Sprintf(`DELETE FROM %s%s`, table, where) 137 | 138 | startTime := time.Now() 139 | res, err := t.Executor.Exec(SQL, bindArgs...) 
140 | var rowsAffected int64 141 | if res != nil { 142 | rowsAffected, _ = res.RowsAffected() 143 | } 144 | l := &Log{ 145 | Context: ctx, 146 | Duration: time.Now().Sub(startTime), 147 | SQL: SQL, 148 | Bindings: bindArgs, 149 | RowsAffected: rowsAffected, 150 | Error: err, 151 | } 152 | opts.doDebug(l) 153 | if err != nil { 154 | return nil, err 155 | } 156 | 157 | return res, nil 158 | } 159 | -------------------------------------------------------------------------------- /dbora_test.go: -------------------------------------------------------------------------------- 1 | package xsql_test 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "github.com/mix-go/xsql" 8 | "github.com/sijms/go-ora/v2" 9 | "github.com/stretchr/testify/assert" 10 | "log" 11 | "testing" 12 | "time" 13 | ) 14 | 15 | func newOracleDB() *xsql.DB { 16 | db, err := sql.Open("oracle", "oracle://root:123456@127.0.0.1:1521/orcl") 17 | if err != nil { 18 | log.Fatal(err) 19 | } 20 | return xsql.New( 21 | db, 22 | xsql.WithDebugFunc(func(l *xsql.Log) { 23 | log.Println(l) 24 | }), 25 | xsql.WithTimeLocation(time.UTC), 26 | xsql.UseOracle(), 27 | ) 28 | } 29 | 30 | func TestOracleClear(t *testing.T) { 31 | a := assert.New(t) 32 | 33 | DB := newOracleDB() 34 | 35 | err := DB.Exec(context.Background(), "DELETE FROM XSQL WHERE ID > 2").Error 36 | 37 | a.Empty(err) 38 | } 39 | 40 | func TestOracleQuery(t *testing.T) { 41 | a := assert.New(t) 42 | 43 | DB := newOracleDB() 44 | 45 | rows, err := DB.Query(context.Background(), "SELECT * FROM XSQL WHERE ROWNUM <= 2") 46 | if err != nil { 47 | log.Fatal(err) 48 | } 49 | bar := rows[0].Get("BAR").Time().Format(xsql.DefaultOptions.TimeLayout) 50 | log.Println(bar) 51 | 52 | a.Equal(bar, "2022-04-14 23:49:48") 53 | } 54 | 55 | type TestOracle struct { 56 | Id int `xsql:"ID"` 57 | Foo string `xsql:"FOO"` 58 | Bar go_ora.TimeStamp `xsql:"BAR"` 59 | } 60 | 61 | func (t TestOracle) TableName() string { 62 | return "XSQL" 63 | } 64 | 65 | func 
TestOracleInsert(t *testing.T) { 66 | a := assert.New(t) 67 | 68 | DB := newOracleDB() 69 | 70 | test := TestOracle{ 71 | Id: 3, 72 | Foo: "test", 73 | Bar: go_ora.TimeStamp(time.Now()), 74 | } 75 | err := DB.Insert(context.Background(), &test).Error 76 | 77 | a.Empty(err) 78 | } 79 | 80 | // oracle 不支持批量插入 81 | func __TestOracleBatchInsert(t *testing.T) { 82 | a := assert.New(t) 83 | 84 | DB := newOracleDB() 85 | 86 | tests := []TestOracle{ 87 | { 88 | Id: 4, 89 | Foo: "test1", 90 | Bar: go_ora.TimeStamp(time.Now()), 91 | }, 92 | { 93 | Id: 5, 94 | Foo: "test2", 95 | Bar: go_ora.TimeStamp(time.Now()), 96 | }, 97 | } 98 | err := DB.BatchInsert(context.Background(), &tests).Error 99 | 100 | a.Empty(err) 101 | } 102 | 103 | func TestOracleUpdate(t *testing.T) { 104 | a := assert.New(t) 105 | 106 | DB := newOracleDB() 107 | 108 | test := TestOracle{ 109 | Id: 999, 110 | Foo: "test update", 111 | Bar: go_ora.TimeStamp(time.Now()), 112 | } 113 | err := DB.Update(context.Background(), &test, "id = :id", 3).Error 114 | 115 | a.Empty(err) 116 | } 117 | 118 | func TestOracleExec(t *testing.T) { 119 | a := assert.New(t) 120 | 121 | DB := newOracleDB() 122 | 123 | err := DB.Exec(context.Background(), "DELETE FROM XSQL WHERE ID = :id", 999).Error 124 | 125 | a.Empty(err) 126 | } 127 | 128 | func TestOracleFirst(t *testing.T) { 129 | a := assert.New(t) 130 | 131 | DB := newOracleDB() 132 | 133 | var test TestOracle 134 | err := DB.First(context.Background(), &test, "SELECT * FROM XSQL").Error 135 | if err != nil { 136 | log.Fatal(err) 137 | } 138 | 139 | a.Equal(fmt.Sprintf("%+v", test), "{Id:1 Foo:v Bar:{wall:0 ext:63785576988 loc:}}") 140 | } 141 | 142 | func TestOracleFirstPart(t *testing.T) { 143 | a := assert.New(t) 144 | 145 | DB := newOracleDB() 146 | 147 | var test TestOracle 148 | err := DB.First(context.Background(), &test, "SELECT foo FROM XSQL").Error 149 | if err != nil { 150 | log.Fatal(err) 151 | } 152 | 153 | a.Equal(fmt.Sprintf("%+v", test), "{Id:0 Foo:v 
Bar:{wall:0 ext:0 loc:}}") 154 | } 155 | 156 | func TestOracleFind(t *testing.T) { 157 | a := assert.New(t) 158 | 159 | DB := newOracleDB() 160 | 161 | var tests []TestOracle 162 | err := DB.Find(context.Background(), &tests, "SELECT * FROM XSQL WHERE ROWNUM <= 2").Error 163 | if err != nil { 164 | log.Fatal(err) 165 | } 166 | 167 | a.Equal(fmt.Sprintf("%+v", tests), `[{Id:1 Foo:v Bar:{wall:0 ext:63785576988 loc:}} {Id:2 Foo:v1 Bar:{wall:0 ext:63785577000 loc:}}]`) 168 | } 169 | 170 | func TestOracleFindPart(t *testing.T) { 171 | a := assert.New(t) 172 | 173 | DB := newOracleDB() 174 | 175 | var tests []TestOracle 176 | err := DB.Find(context.Background(), &tests, "SELECT foo FROM XSQL WHERE ROWNUM <= 2").Error 177 | if err != nil { 178 | log.Fatal(err) 179 | } 180 | 181 | a.Equal(fmt.Sprintf("%+v", tests), `[{Id:0 Foo:v Bar:{wall:0 ext:0 loc:}} {Id:0 Foo:v1 Bar:{wall:0 ext:0 loc:}}]`) 182 | } 183 | 184 | func TestOracleTxCommit(t *testing.T) { 185 | a := assert.New(t) 186 | 187 | DB := newOracleDB() 188 | 189 | tx, _ := DB.Begin() 190 | 191 | test := TestOracle{ 192 | Id: 998, // oracle not support AUTO_INCREMENT 193 | Foo: "test", 194 | Bar: go_ora.TimeStamp(time.Now()), 195 | } 196 | err := tx.Insert(context.Background(), &test).Error 197 | a.Empty(err) 198 | 199 | err = tx.Commit() 200 | a.Empty(err) 201 | } 202 | 203 | func TestOracleTxRollback(t *testing.T) { 204 | a := assert.New(t) 205 | 206 | DB := newOracleDB() 207 | 208 | tx, _ := DB.Begin() 209 | 210 | test := TestOracle{ 211 | Id: 999, // oracle not support AUTO_INCREMENT 212 | Foo: "test", 213 | Bar: go_ora.TimeStamp(time.Now()), 214 | } 215 | err := tx.Insert(context.Background(), &test).Error 216 | a.Empty(err) 217 | 218 | err = tx.Rollback() 219 | a.Empty(err) 220 | } 221 | -------------------------------------------------------------------------------- /db.go: -------------------------------------------------------------------------------- 1 | package xsql 2 | 3 | import ( 4 | "context" 5 | 
"database/sql" 6 | "encoding/json" 7 | "google.golang.org/protobuf/encoding/protojson" 8 | "google.golang.org/protobuf/proto" 9 | "reflect" 10 | "strings" 11 | ) 12 | 13 | type TimeFunc func(placeholder string) string 14 | 15 | type DB struct { 16 | *sql.DB 17 | Options *sqlOptions 18 | 19 | Error error 20 | RowsAffected int64 21 | LastInsertId int64 22 | 23 | executor executor 24 | query query 25 | } 26 | 27 | func New(db *sql.DB, opts ...SqlOption) *DB { 28 | return &DB{ 29 | DB: db, 30 | Options: mergeOptions(opts), 31 | executor: executor{ 32 | Executor: db, 33 | }, 34 | query: query{ 35 | Query: db, 36 | }, 37 | } 38 | } 39 | 40 | func (t *DB) mergeOptions(opts []SqlOption) *sqlOptions { 41 | cp := *t.Options // copy 42 | for _, o := range opts { 43 | o.apply(&cp) 44 | } 45 | return &cp 46 | } 47 | 48 | func (t *DB) getExecResult(r sql.Result, err error, getLastInsertId bool) *DB { 49 | if err != nil { 50 | t.Error = err 51 | return t 52 | } else { 53 | t.Error = nil 54 | } 55 | 56 | if getLastInsertId { 57 | lastInsertId, err := r.LastInsertId() 58 | if err != nil { 59 | t.Error = err 60 | return t 61 | } else { 62 | t.Error = nil 63 | } 64 | t.RowsAffected = lastInsertId 65 | } else { 66 | rowsAffected, err := r.RowsAffected() 67 | if err != nil { 68 | t.Error = err 69 | return t 70 | } else { 71 | t.Error = nil 72 | } 73 | t.RowsAffected = rowsAffected 74 | } 75 | 76 | return t 77 | } 78 | 79 | func (t *DB) getFetcherResult(err error) *DB { 80 | if err != nil { 81 | t.Error = err 82 | return t 83 | } else { 84 | t.Error = nil 85 | } 86 | return t 87 | } 88 | 89 | func (t *DB) Insert(ctx context.Context, data interface{}, opts ...SqlOption) *DB { 90 | r, err := t.executor.Insert(ctx, data, t.mergeOptions(opts)) 91 | return t.getExecResult(r, err, true) 92 | } 93 | 94 | func (t *DB) BatchInsert(ctx context.Context, data interface{}, opts ...SqlOption) *DB { 95 | r, err := t.executor.BatchInsert(ctx, data, t.mergeOptions(opts)) 96 | return t.getExecResult(r, 
err, false) 97 | } 98 | 99 | func (t *DB) Update(ctx context.Context, data interface{}, expr string, args ...interface{}) *DB { 100 | r, err := t.executor.Update(ctx, data, expr, args, t.Options) 101 | return t.getExecResult(r, err, false) 102 | } 103 | 104 | func (t *DB) Model(s interface{}) *ModelExecutor { 105 | return t.executor.model(s, t.Options) 106 | } 107 | 108 | func (t *DB) Exec(ctx context.Context, query string, args ...interface{}) *DB { 109 | r, err := t.executor.Exec(ctx, query, args, t.Options) 110 | return t.getExecResult(r, err, false) 111 | } 112 | 113 | func (t *DB) Begin() (*Tx, error) { 114 | tx, err := t.DB.Begin() 115 | if err != nil { 116 | return nil, err 117 | } 118 | return &Tx{ 119 | raw: tx, 120 | DB: &DB{ 121 | Options: t.Options, 122 | executor: executor{ 123 | Executor: tx, 124 | }, 125 | query: query{ 126 | Query: tx, 127 | }, 128 | }, 129 | }, nil 130 | } 131 | 132 | func (t *DB) Query(ctx context.Context, query string, args ...interface{}) ([]*Row, error) { 133 | f, err := t.query.Fetch(ctx, query, args, t.Options) 134 | if err != nil { 135 | return nil, err 136 | } 137 | r, err := f.Rows() 138 | if err != nil { 139 | return nil, err 140 | } 141 | return r, nil 142 | } 143 | 144 | func (t *DB) QueryFirst(ctx context.Context, query string, args ...interface{}) (*Row, error) { 145 | rows, err := t.Query(ctx, query, args...) 
146 | if err != nil { 147 | return nil, err 148 | } 149 | if len(rows) == 0 { 150 | return nil, sql.ErrNoRows 151 | } 152 | return rows[0], nil 153 | } 154 | 155 | func (t *DB) Find(ctx context.Context, i interface{}, query string, args ...interface{}) *DB { 156 | query = tableReplace(i, query, t.Options) 157 | f, err := t.query.Fetch(ctx, query, args, t.Options) 158 | if err != nil { 159 | return t.getFetcherResult(err) 160 | } 161 | if err := f.Find(i); err != nil { 162 | return t.getFetcherResult(err) 163 | } 164 | return t.getFetcherResult(nil) 165 | } 166 | 167 | func (t *DB) First(ctx context.Context, i interface{}, query string, args ...interface{}) *DB { 168 | query = tableReplace(i, query, t.Options) 169 | f, err := t.query.Fetch(ctx, query, args, t.Options) 170 | if err != nil { 171 | return t.getFetcherResult(err) 172 | } 173 | if err := f.First(i); err != nil { 174 | return t.getFetcherResult(err) 175 | } 176 | return t.getFetcherResult(nil) 177 | } 178 | 179 | func tableReplace(i interface{}, query string, opts *sqlOptions) string { 180 | var table string 181 | 182 | value := reflect.ValueOf(i) 183 | switch value.Kind() { 184 | case reflect.Ptr: 185 | if value.Elem().IsValid() { 186 | // *Test > Test 187 | if value.Elem().Kind() == reflect.Struct { 188 | // 先尝试*Test能不能找到 189 | if tab, ok := value.Interface().(Table); ok { 190 | table = tab.TableName() 191 | break 192 | } 193 | } 194 | 195 | // **Test > *Test 196 | return tableReplace(value.Elem().Interface(), query, opts) 197 | } 198 | 199 | if tab, ok := value.Interface().(Table); ok { 200 | table = tab.TableName() 201 | break 202 | } 203 | 204 | table = getTypeName(i) 205 | case reflect.Struct: 206 | if tab, ok := value.Interface().(Table); ok { 207 | table = tab.TableName() 208 | break 209 | } 210 | 211 | // 也去尝试*Test能不能找到 212 | valuePtr := reflect.New(value.Type()) 213 | if tab, ok := valuePtr.Interface().(Table); ok { 214 | table = tab.TableName() 215 | break 216 | } 217 | table = getTypeName(i) 
218 | case reflect.Array, reflect.Slice: 219 | elemType := value.Type().Elem() 220 | if elemType.Kind() == reflect.Ptr { 221 | elemType = elemType.Elem() 222 | } 223 | if elemType.Kind() == reflect.Struct { 224 | // 创建该类型的新实例 225 | // Test 226 | elemValue := reflect.New(elemType) 227 | elemInstance := elemValue.Interface() 228 | if tab, ok := elemInstance.(Table); ok { 229 | table = tab.TableName() 230 | break 231 | } 232 | 233 | // Test > *Test 234 | if elemValue.CanAddr() { 235 | elemPtrInstance := elemValue.Addr().Interface() 236 | if tab, ok := elemPtrInstance.(Table); ok { 237 | table = tab.TableName() 238 | break 239 | } 240 | } 241 | 242 | table = getTypeName(elemInstance) 243 | } else { 244 | return query // 如果元素不是结构体或其指针,返回原始查询 245 | } 246 | default: 247 | return query // 如果不是结构体、数组或切片,返回原始查询 248 | } 249 | 250 | return strings.Replace(query, opts.TableKey, table, 1) 251 | } 252 | 253 | func getTypeName(i interface{}) string { 254 | t := reflect.TypeOf(i) 255 | for t.Kind() == reflect.Ptr { 256 | t = t.Elem() 257 | } 258 | return t.Name() 259 | } 260 | 261 | var ProtoMarshalOptions = protojson.MarshalOptions{ 262 | UseProtoNames: true, 263 | UseEnumNumbers: true, 264 | } 265 | 266 | func marshal(v any) ([]byte, error) { 267 | if m, ok := v.(proto.Message); ok { 268 | return ProtoMarshalOptions.Marshal(m) 269 | } else { 270 | return json.Marshal(v) 271 | } 272 | } 273 | 274 | var ProtoUnmarshalOptions = protojson.UnmarshalOptions{ 275 | DiscardUnknown: true, 276 | } 277 | 278 | func unmarshal(b []byte, v any) error { 279 | if m, ok := v.(proto.Message); ok { 280 | return ProtoUnmarshalOptions.Unmarshal(b, m) 281 | } else { 282 | return json.Unmarshal(b, v) 283 | } 284 | } 285 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Mix XSQL 2 | 3 | A lightweight database based on database/sql, feature complete and supports any 
database driver. 4 | 5 | ## Installation 6 | 7 | ``` 8 | go get github.com/mix-go/xsql 9 | ``` 10 | 11 | ## Initialization 12 | 13 | - MySQL initialization, using [go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) driver. 14 | 15 | ```go 16 | import _ "github.com/go-sql-driver/mysql" 17 | 18 | db, err := sql.Open("mysql", "root:123456@tcp(127.0.0.1:3306)/test?charset=utf8") 19 | if err != nil { 20 | log.Fatal(err) 21 | } 22 | 23 | DB := xsql.New(db) 24 | ``` 25 | 26 | - Oracle initialization, using [sijms/go-ora/v2](https://github.com/sijms/go-ora) driver (no need to install instantclient). 27 | 28 | ```go 29 | import _ "github.com/sijms/go-ora/v2" 30 | 31 | db, err := sql.Open("oracle", "oracle://root:123456@127.0.0.1:1521/orcl") 32 | if err != nil { 33 | log.Fatal(err) 34 | } 35 | 36 | DB := xsql.New(db, xsql.UseOracle()) 37 | ``` 38 | 39 | - [xorm#drivers](https://github.com/go-xorm/xorm#drivers-support) These drivers are also supported 40 | 41 | ## Query 42 | 43 | You can use it like a scripting language, not binding the struct, directly and freely get the value of each field. 44 | 45 | > Oracle field, table name needs to be uppercase 46 | 47 | ```go 48 | rows, err := DB.Query(context.Background(), "SELECT * FROM xsql") 49 | if err != nil { 50 | log.Fatal(err) 51 | } 52 | 53 | id := rows[0].Get("id").Int() 54 | foo := rows[0].Get("foo").String() 55 | bar := rows[0].Get("bar").Time() // time.Time 56 | val := rows[0].Get("bar").Value() // interface{} 57 | ``` 58 | 59 | ```go 60 | row, err := DB.QueryFirst(context.Background(), "SELECT * FROM xsql WHERE id = ?", 1) 61 | if err != nil { 62 | log.Fatal(err) 63 | } 64 | 65 | id := row.Get("id").Int() 66 | foo := row.Get("foo").String() 67 | bar := row.Get("bar").Time() // time.Time 68 | val := row.Get("bar").Value() // interface{} 69 | ``` 70 | 71 | ### Mapping 72 | 73 | Of course, you can also map usage like `gorm`, `xorm`. 
74 | 75 | > For Oracle, field and table names must be uppercase 76 | 77 | ```go 78 | type Test struct { 79 | Id int `xsql:"id"` 80 | Foo string `xsql:"foo"` 81 | Bar time.Time `xsql:"bar"` // oracle uses go_ora.TimeStamp 82 | } 83 | 84 | func (t *Test) TableName() string { 85 | return "tableName" 86 | } 87 | ``` 88 | 89 | ### First 90 | 91 | Map the first row to a struct. 92 | 93 | > For Oracle, change the placeholder to :id 94 | 95 | ```go 96 | var test Test 97 | err := DB.First(context.Background(), &test, "SELECT * FROM ${TABLE} WHERE id = ?", 1).Error 98 | if err != nil { 99 | log.Fatal(err) 100 | } 101 | ``` 102 | 103 | ### Find 104 | 105 | Map all rows to a slice. 106 | 107 | ```go 108 | var tests []*Test 109 | err := DB.Find(context.Background(), &tests, "SELECT * FROM ${TABLE}").Error 110 | if err != nil { 111 | log.Fatal(err) 112 | } 113 | ``` 114 | 115 | ## Insert 116 | 117 | ```go 118 | test := Test{ 119 | Id: 0, 120 | Foo: "test", 121 | Bar: time.Now(), 122 | } 123 | err := DB.Insert(context.Background(), &test).Error 124 | if err != nil { 125 | log.Fatal(err) 126 | } 127 | ``` 128 | 129 | ## BatchInsert 130 | 131 | ```go 132 | tests := []Test{ 133 | { 134 | Id: 0, 135 | Foo: "test", 136 | Bar: time.Now(), 137 | }, 138 | { 139 | Id: 0, 140 | Foo: "test", 141 | Bar: time.Now(), 142 | }, 143 | } 144 | err := DB.BatchInsert(context.Background(), &tests).Error 145 | if err != nil { 146 | log.Fatal(err) 147 | } 148 | ``` 149 | 150 | ## Update 151 | 152 | > For Oracle, change the placeholder to :id 153 | 154 | Update all columns 155 | 156 | ```go 157 | test := Test{ 158 | Id: 8, 159 | Foo: "test", 160 | Bar: time.Now(), 161 | } 162 | err := DB.Update(context.Background(), &test, "id = ?", test.Id).Error 163 | if err != nil { 164 | log.Fatal(err) 165 | } 166 | ``` 167 | 168 | Update specific columns with a map 169 | 170 | ```go 171 | data := map[string]interface{}{ 172 | "foo": "test", 173 | } 174 | err := DB.Model(&Test{}).Update(context.Background(), data, "id = ?", 8).Error 175 |
if err != nil { 176 | log.Fatal(err) 177 | } 178 | ``` 179 | 180 | Update specific columns using struct field pointers 181 | 182 | ```go 183 | test := Test{} 184 | data, err := xsql.TagValuesMap(DB.Options.Tag, &test, 185 | xsql.TagValues{ 186 | {&test.Foo, "test"}, 187 | }, 188 | ) 189 | if err != nil { 190 | log.Fatal(err) 191 | } 192 | err = DB.Model(&test).Update(context.Background(), data, "id = ?", 8).Error 193 | if err != nil { 194 | log.Fatal(err) 195 | } 196 | ``` 197 | 198 | ## Delete 199 | 200 | > For Oracle, change the placeholder to :id 201 | 202 | ```go 203 | test := Test{ 204 | Id: 8, 205 | Foo: "test", 206 | Bar: time.Now(), 207 | } 208 | err := DB.Model(&test).Delete(context.Background(), "id = ?", test.Id).Error 209 | if err != nil { 210 | log.Fatal(err) 211 | } 212 | ``` 213 | 214 | ```go 215 | err := DB.Model(&Test{}).Delete(context.Background(), "id = ?", 8).Error 216 | if err != nil { 217 | log.Fatal(err) 218 | } 219 | ``` 220 | 221 | ## Exec 222 | 223 | Use `Exec()` to run a DELETE statement manually; UPDATE statements can be run manually the same way.
224 | 225 | > For Oracle, change the placeholder to :id 226 | 227 | ```go 228 | err := DB.Exec(context.Background(), "DELETE FROM xsql WHERE id = ?", 8).Error 229 | if err != nil { 230 | log.Fatal(err) 231 | } 232 | ``` 233 | 234 | ## Transaction 235 | 236 | ```go 237 | tx, err := DB.Begin() 238 | if err != nil { 239 | log.Fatal(err) 240 | } 241 | test := Test{ 242 | Id: 0, 243 | Foo: "test", 244 | Bar: time.Now(), 245 | } 246 | err = tx.Insert(context.Background(), &test).Error 247 | if err != nil { 248 | tx.Rollback() 249 | log.Fatal(err) 250 | } 251 | tx.Commit() 252 | ``` 253 | 254 | ## Configuration 255 | 256 | The following options can be configured through the `xsql.New()` method: 257 | 258 | - MySQL mode is the default; to switch to Oracle, [modify the configuration](https://github.com/mix-go/mix/blob/master/src/xsql/dbora_test.go#L25) 259 | - `Insert()` and `BatchInsert()` accept options at call time that override the insert-related configuration, for example changing InsertKey to REPLACE INTO 260 | 261 | ```go 262 | type sqlOptions struct { 263 | // Default: xsql 264 | Tag string 265 | 266 | // Default: INSERT INTO 267 | InsertKey string 268 | 269 | // Default: ${TABLE} 270 | TableKey string 271 | 272 | // Default: ? 273 | // For oracle, can be configured as :%d 274 | Placeholder string 275 | 276 | // Default: ` 277 | // For oracle, can be configured as " 278 | ColumnQuotes string 279 | 280 | // Default: 2006-01-02 15:04:05 281 | TimeLayout string 282 | 283 | // Default: time.Local 284 | TimeLocation *time.Location 285 | 286 | // Default: func(placeholder string) string { return placeholder } 287 | // For oracle, this closure can be modified to add TO_TIMESTAMP 288 | TimeFunc TimeFunc 289 | 290 | // Global debug SQL 291 | DebugFunc DebugFunc 292 | } 293 | ``` 294 | 295 | ## Log 296 | 297 | Pass the `DebugFunc` option to `xsql.New()` to print SQL information with any logging library.
298 | 299 | ```go 300 | DB := xsql.New( 301 | db, 302 | xsql.WithDebugFunc(func(l *xsql.Log) { 303 | log.Println(l) 304 | }), 305 | ) 306 | ``` 307 | 308 | The log object contains the following fields 309 | 310 | ```go 311 | type Log struct { 312 | Context context.Context `json:"context"` 313 | Duration time.Duration `json:"duration"` 314 | SQL string `json:"sql"` 315 | Bindings []interface{} `json:"bindings"` 316 | RowsAffected int64 `json:"rowsAffected"` 317 | Error error `json:"error"` 318 | } 319 | ``` 320 | 321 | ## License 322 | 323 | Apache License Version 2.0, http://www.apache.org/licenses/ 324 | -------------------------------------------------------------------------------- /fetcher.go: -------------------------------------------------------------------------------- 1 | package xsql 2 | 3 | import ( 4 | "database/sql" 5 | "errors" 6 | "fmt" 7 | "github.com/sijms/go-ora/v2" 8 | "google.golang.org/protobuf/types/known/timestamppb" 9 | "reflect" 10 | "slices" 11 | "strconv" 12 | "time" 13 | ) 14 | 15 | type Fetcher struct { 16 | r *sql.Rows 17 | log *Log 18 | options *sqlOptions 19 | } 20 | 21 | func (t *Fetcher) First(i interface{}) error { 22 | value := reflect.ValueOf(i) 23 | if value.Kind() != reflect.Ptr { 24 | return errors.New("xsql: argument must be a pointer") 25 | } 26 | 27 | // 检查是否传入了指向指针的指针(如 **Test) 28 | if value.Elem().Kind() == reflect.Ptr { 29 | value = value.Elem() 30 | if value.IsNil() { 31 | // 初始化内部指针 32 | newInst := reflect.New(value.Type().Elem()) 33 | value.Set(newInst) 34 | } 35 | } 36 | 37 | rootValue := value.Elem() 38 | if !rootValue.IsValid() { 39 | return errors.New("xsql: argument must be a pointer") 40 | } 41 | rootType := rootValue.Type() 42 | 43 | rows, err := t.Rows() 44 | if err != nil { 45 | return err 46 | } 47 | if len(rows) == 0 { 48 | return sql.ErrNoRows 49 | } 50 | if err := t.foreach(rows[0], rootValue, rootType); err != nil { 51 | return err 52 | } 53 | 54 | return nil 55 | } 56 | 57 | func (t *Fetcher) Find(i 
interface{}) error { 58 | value := reflect.ValueOf(i) 59 | if value.Kind() != reflect.Ptr { 60 | return errors.New("xsql: argument must be a pointer") 61 | } 62 | root := value.Elem() 63 | if root.Kind() != reflect.Slice { 64 | return errors.New("xsql: argument must be a pointer to a slice") 65 | } 66 | elemType := root.Type().Elem() 67 | 68 | rows, err := t.Rows() 69 | if err != nil { 70 | return err 71 | } 72 | 73 | for r := 0; r < len(rows); r++ { 74 | var elemValue reflect.Value 75 | if elemType.Kind() == reflect.Ptr { 76 | // 元素类型为指针,创建指针指向的类型的实例,并直接获取指向的结构体 77 | elemValue = reflect.New(elemType.Elem()).Elem() 78 | if err := t.foreach(rows[r], elemValue, elemValue.Type()); err != nil { 79 | return err 80 | } 81 | root.Set(reflect.Append(root, elemValue.Addr())) 82 | } else { 83 | // 元素类型为值类型,创建该类型的实例 84 | elemValue = reflect.New(elemType).Elem() 85 | if err := t.foreach(rows[r], elemValue, elemValue.Type()); err != nil { 86 | return err 87 | } 88 | root.Set(reflect.Append(root, elemValue)) 89 | } 90 | } 91 | 92 | return nil 93 | } 94 | 95 | func (t *Fetcher) Rows() ([]*Row, error) { 96 | columns, err := t.r.Columns() 97 | if err != nil { 98 | return nil, err 99 | } 100 | 101 | // Make a slice for the values 102 | values := make([]interface{}, len(columns)) 103 | 104 | // rows.Scan wants '[]interface{}' as an argument, so we must copy the 105 | // references into such a slice 106 | // See http://code.google.com/p/go-wiki/wiki/InterfaceSlice for details 107 | scanArgs := make([]interface{}, len(values)) 108 | for i := range values { 109 | scanArgs[i] = &values[i] 110 | } 111 | 112 | // Fetch rows 113 | var rows []*Row 114 | 115 | for t.r.Next() { 116 | err = t.r.Scan(scanArgs...) 
117 | if err != nil { 118 | return nil, err 119 | } 120 | 121 | rowMap := make(map[string]interface{}) 122 | for i, value := range values { 123 | // Here we can check if the value is nil (NULL value) 124 | if value != nil { 125 | rowMap[columns[i]] = value 126 | } 127 | } 128 | 129 | rows = append(rows, &Row{ 130 | v: rowMap, 131 | options: t.options, 132 | }) 133 | } 134 | 135 | t.log.RowsAffected = int64(len(rows)) 136 | t.options.doDebug(t.log) 137 | 138 | return rows, nil 139 | } 140 | 141 | type Row struct { 142 | v map[string]interface{} 143 | options *sqlOptions 144 | } 145 | 146 | func (t Row) Exist(field string) bool { 147 | _, ok := t.v[field] 148 | return ok 149 | } 150 | 151 | func (t Row) Get(field string) *RowResult { 152 | if v, ok := t.v[field]; ok { 153 | return &RowResult{ 154 | v: v, 155 | options: t.options, 156 | } 157 | } 158 | return &RowResult{ 159 | v: "", 160 | options: t.options, 161 | } 162 | } 163 | 164 | func (t Row) Value() map[string]interface{} { 165 | return t.v 166 | } 167 | 168 | type RowResult struct { 169 | v interface{} 170 | options *sqlOptions 171 | } 172 | 173 | func (t *RowResult) Empty() bool { 174 | if b, ok := t.v.([]uint8); ok { 175 | return len(b) == 0 176 | } 177 | if s, ok := t.v.(string); ok { 178 | return len(s) == 0 179 | } 180 | if t.v == nil { 181 | return true 182 | } 183 | return false 184 | } 185 | 186 | func (t *RowResult) String() string { 187 | switch reflect.ValueOf(t.v).Kind() { 188 | case reflect.Int: 189 | i := t.v.(int) 190 | return strconv.FormatInt(int64(i), 10) 191 | case reflect.Int8: 192 | i := t.v.(int8) 193 | return strconv.FormatInt(int64(i), 10) 194 | case reflect.Int16: 195 | i := t.v.(int16) 196 | return strconv.FormatInt(int64(i), 10) 197 | case reflect.Int32: 198 | i := t.v.(int32) 199 | return strconv.FormatInt(int64(i), 10) 200 | case reflect.Int64: 201 | i := t.v.(int64) 202 | return strconv.FormatInt(i, 10) 203 | case reflect.Uint: 204 | i := t.v.(uint) 205 | return 
strconv.FormatInt(int64(i), 10) 206 | case reflect.Uint8: 207 | i := t.v.(uint8) 208 | return strconv.FormatInt(int64(i), 10) 209 | case reflect.Uint16: 210 | i := t.v.(uint16) 211 | return strconv.FormatInt(int64(i), 10) 212 | case reflect.Uint32: 213 | i := t.v.(uint32) 214 | return strconv.FormatInt(int64(i), 10) 215 | case reflect.Uint64: 216 | i := t.v.(uint64) 217 | return strconv.FormatInt(int64(i), 10) 218 | case reflect.Float32: 219 | i := t.v.(float32) 220 | return strconv.FormatFloat(float64(i), 'g', -1, 32) 221 | case reflect.Float64: 222 | i := t.v.(float64) 223 | return strconv.FormatFloat(i, 'g', -1, 64) 224 | case reflect.String: 225 | return t.v.(string) 226 | default: 227 | switch v := t.v.(type) { 228 | case []uint8: 229 | return string(v) 230 | case time.Time: 231 | return v.Format(t.options.TimeLayout) 232 | } 233 | } 234 | return "" 235 | } 236 | 237 | func (t *RowResult) Int() int64 { 238 | switch reflect.ValueOf(t.v).Kind() { 239 | case reflect.Int: 240 | i := t.v.(int) 241 | return int64(i) 242 | case reflect.Int8: 243 | i := t.v.(int8) 244 | return int64(i) 245 | case reflect.Int16: 246 | i := t.v.(int16) 247 | return int64(i) 248 | case reflect.Int32: 249 | i := t.v.(int32) 250 | return int64(i) 251 | case reflect.Int64: 252 | i := t.v.(int64) 253 | return i 254 | case reflect.Uint: 255 | i := t.v.(uint) 256 | return int64(i) 257 | case reflect.Uint8: 258 | i := t.v.(uint8) 259 | return int64(i) 260 | case reflect.Uint16: 261 | i := t.v.(uint16) 262 | return int64(i) 263 | case reflect.Uint32: 264 | i := t.v.(uint32) 265 | return int64(i) 266 | case reflect.Uint64: 267 | i := t.v.(uint64) 268 | return int64(i) 269 | case reflect.Float32: 270 | i := t.v.(float32) 271 | return int64(i) 272 | case reflect.Float64: 273 | i := t.v.(float64) 274 | return int64(i) 275 | case reflect.String: 276 | s := t.v.(string) 277 | i, err := strconv.ParseInt(s, 10, 64) 278 | if err != nil { 279 | return 0 280 | } 281 | return i 282 | default: 283 | if b, ok 
:= t.v.([]uint8); ok { 284 | s := string(b) 285 | i, err := strconv.ParseInt(s, 10, 64) 286 | if err != nil { 287 | return 0 288 | } 289 | return i 290 | } 291 | } 292 | return 0 293 | } 294 | 295 | func (t *RowResult) Time() time.Time { 296 | typ := t.Type() 297 | if typ == "string" || typ == "[]uint8" { 298 | tt, _ := time.ParseInLocation(t.options.TimeLayout, t.String(), t.options.TimeLocation) 299 | return tt 300 | } 301 | if typ == "time.Time" { 302 | return t.v.(time.Time) 303 | } 304 | if typ == "go_ora.TimeStamp" { 305 | return time.Time(t.v.(go_ora.TimeStamp)) 306 | } 307 | if typ == "*timestamppb.Timestamp" { 308 | return t.v.(*timestamppb.Timestamp).AsTime() 309 | } 310 | return time.Time{} 311 | } 312 | 313 | func (t *RowResult) Value() interface{} { 314 | return t.v 315 | } 316 | 317 | func (t *RowResult) Type() string { 318 | return reflect.TypeOf(t.v).String() 319 | } 320 | 321 | func (t *Fetcher) foreach(row *Row, value reflect.Value, typ reflect.Type) error { 322 | for n := 0; n < typ.NumField(); n++ { 323 | fieldValue := value.Field(n) 324 | fieldStruct := typ.Field(n) 325 | if fieldStruct.Anonymous { 326 | if err := t.foreach(row, fieldValue, fieldValue.Type()); err != nil { 327 | return err 328 | } 329 | continue 330 | } 331 | if !fieldValue.CanSet() { 332 | continue 333 | } 334 | tag := value.Type().Field(n).Tag.Get(t.options.Tag) 335 | if tag == "-" || tag == "_" { 336 | continue 337 | } 338 | if !row.Exist(tag) { 339 | continue 340 | } 341 | if err := t.mapped(row, tag, fieldValue, fieldValue.Type()); err != nil { 342 | return err 343 | } 344 | } 345 | return nil 346 | } 347 | 348 | func (t *Fetcher) mapped(row *Row, tag string, value reflect.Value, typ reflect.Type) (err error) { 349 | res := row.Get(tag) 350 | v := res.Value() 351 | 352 | switch value.Kind() { 353 | case reflect.Int: 354 | v = int(res.Int()) 355 | break 356 | case reflect.Int8: 357 | v = int8(res.Int()) 358 | break 359 | case reflect.Int16: 360 | v = int16(res.Int()) 361 | 
break 362 | case reflect.Int32: 363 | v = int32(res.Int()) 364 | break 365 | case reflect.Int64: 366 | v = res.Int() 367 | break 368 | case reflect.Uint: 369 | v = uint(res.Int()) 370 | break 371 | case reflect.Uint8: 372 | v = uint8(res.Int()) 373 | break 374 | case reflect.Uint16: 375 | v = uint16(res.Int()) 376 | break 377 | case reflect.Uint32: 378 | v = uint32(res.Int()) 379 | break 380 | case reflect.Uint64: 381 | v = uint64(res.Int()) 382 | break 383 | case reflect.String: 384 | v = res.String() 385 | break 386 | case reflect.Bool: 387 | v = res.Int() == 1 388 | break 389 | default: 390 | if !res.Empty() { 391 | vTyp := reflect.ValueOf(v).Type().String() 392 | if typ.String() == "time.Time" { // 如果结构体是time.Time类型,执行转换 393 | if vTyp == "time.Time" { 394 | // parseTime=true 395 | v = res.Value() 396 | } else { 397 | // parseTime=false 398 | if t, e := time.ParseInLocation(t.options.TimeLayout, res.String(), t.options.TimeLocation); e == nil { 399 | v = t 400 | } else { 401 | return fmt.Errorf("time parse fail for field %s: %v", tag, e) 402 | } 403 | } 404 | } else if typ.String() == "*timestamppb.Timestamp" { // 如果结构体是*timestamppb.Timestamp类型,执行转换 405 | if vTyp != "*timestamppb.Timestamp" { 406 | if t, e := time.ParseInLocation(t.options.TimeLayout, res.String(), t.options.TimeLocation); e == nil { 407 | v = timestamppb.New(t) 408 | } else { 409 | return fmt.Errorf("time parse fail for field %s: %v", tag, e) 410 | } 411 | } 412 | } else if slices.Contains([]reflect.Kind{reflect.Ptr, reflect.Struct, reflect.Slice, reflect.Array, reflect.Map}, typ.Kind()) { // 非标量用JSON反序列化处理 413 | jsonString := res.String() 414 | var newInstance reflect.Value 415 | if typ.Kind() == reflect.Ptr { 416 | newInstance = reflect.New(typ.Elem()) // 创建的都是指针 417 | } else { 418 | newInstance = reflect.New(typ) // 创建的都是指针 419 | } 420 | if e := unmarshal([]byte(jsonString), newInstance.Interface()); e != nil { 421 | return fmt.Errorf("json unmarshal error %s for field %s", e, tag) 422 | } 
423 | if typ.Kind() == reflect.Ptr { 424 | v = newInstance.Interface() 425 | } else { 426 | v = newInstance.Elem().Interface() // 获取的是非指针 427 | } 428 | } 429 | } 430 | } 431 | 432 | // 设置值 433 | defer func() { 434 | if e := recover(); e != nil { 435 | err = fmt.Errorf("type mismatch for field %s: %v", tag, e) 436 | } 437 | }() 438 | value.Set(reflect.ValueOf(v).Convert(value.Type())) 439 | 440 | return 441 | } 442 | -------------------------------------------------------------------------------- /executor.go: -------------------------------------------------------------------------------- 1 | package xsql 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "errors" 7 | "fmt" 8 | ora "github.com/sijms/go-ora/v2" 9 | "google.golang.org/protobuf/types/known/timestamppb" 10 | "reflect" 11 | "slices" 12 | "strings" 13 | "time" 14 | ) 15 | 16 | type executor struct { 17 | Executor 18 | } 19 | 20 | func (t *executor) Insert(ctx context.Context, data interface{}, opts *sqlOptions) (sql.Result, error) { 21 | var err error 22 | fields := make([]string, 0) 23 | vars := make([]string, 0) 24 | bindArgs := make([]interface{}, 0) 25 | 26 | value := reflect.ValueOf(data) 27 | typ := reflect.TypeOf(data) 28 | switch value.Kind() { 29 | case reflect.Ptr: 30 | return t.Insert(ctx, value.Elem().Interface(), opts) 31 | case reflect.Struct: 32 | fields, vars, bindArgs, err = t.foreachInsert(value, typ, opts) 33 | if err != nil { 34 | return nil, err 35 | } 36 | break 37 | default: 38 | return nil, errors.New("xsql: only support struct") 39 | } 40 | 41 | SQL := fmt.Sprintf(`%s %s (%s) VALUES (%s)`, opts.InsertKey, opts.TableKey, opts.ColumnQuotes+strings.Join(fields, opts.ColumnQuotes+", "+opts.ColumnQuotes)+opts.ColumnQuotes, strings.Join(vars, `, `)) 42 | SQL = tableReplace(data, SQL, opts) 43 | 44 | startTime := time.Now() 45 | res, err := t.Executor.Exec(SQL, bindArgs...) 
46 | var rowsAffected int64 47 | if res != nil { 48 | rowsAffected, _ = res.RowsAffected() 49 | } 50 | l := &Log{ 51 | Context: ctx, 52 | Duration: time.Now().Sub(startTime), 53 | SQL: SQL, 54 | Bindings: bindArgs, 55 | RowsAffected: rowsAffected, 56 | Error: err, 57 | } 58 | opts.doDebug(l) 59 | if err != nil { 60 | return nil, err 61 | } 62 | 63 | return res, nil 64 | } 65 | 66 | func (t *executor) BatchInsert(ctx context.Context, array interface{}, opts *sqlOptions) (sql.Result, error) { 67 | fields := make([]string, 0) 68 | valueSql := make([]string, 0) 69 | bindArgs := make([]interface{}, 0) 70 | 71 | // check 72 | value := reflect.ValueOf(array) 73 | switch value.Kind() { 74 | case reflect.Ptr: 75 | return t.BatchInsert(ctx, value.Elem().Interface(), opts) 76 | case reflect.Array, reflect.Slice: 77 | break 78 | default: 79 | return nil, errors.New("xsql: only support array, slice") 80 | } 81 | if value.Len() == 0 { 82 | return nil, errors.New("xsql: array, slice length cannot be 0") 83 | } 84 | 85 | // fields 86 | switch value.Index(0).Kind() { 87 | case reflect.Struct: 88 | subValue := value.Index(0) 89 | subType := subValue.Type() 90 | fields = t.foreachBatchInsertFields(subValue, subType, opts) 91 | break 92 | default: 93 | return nil, errors.New("xsql: only support array, slice") 94 | } 95 | 96 | // values 97 | switch value.Kind() { 98 | case reflect.Slice, reflect.Array: 99 | for r := 0; r < value.Len(); r++ { 100 | switch value.Index(r).Kind() { 101 | case reflect.Struct: 102 | subValue := value.Index(r) 103 | vars, b, err := t.foreachBatchInsertValues(0, subValue, subValue.Type(), opts) 104 | if err != nil { 105 | return nil, err 106 | } 107 | bindArgs = append(bindArgs, b...) 
108 | valueSql = append(valueSql, fmt.Sprintf("(%s)", strings.Join(vars, `, `))) 109 | break 110 | default: 111 | return nil, errors.New("xsql: only support array, slice") 112 | } 113 | } 114 | break 115 | default: 116 | return nil, errors.New("xsql: only support array, slice") 117 | } 118 | 119 | SQL := fmt.Sprintf(`%s %s (%s) VALUES %s`, opts.InsertKey, opts.TableKey, opts.ColumnQuotes+strings.Join(fields, opts.ColumnQuotes+", "+opts.ColumnQuotes)+opts.ColumnQuotes, strings.Join(valueSql, ", ")) 120 | SQL = tableReplace(array, SQL, opts) 121 | 122 | startTime := time.Now() 123 | res, err := t.Executor.Exec(SQL, bindArgs...) 124 | var rowsAffected int64 125 | if res != nil { 126 | rowsAffected, _ = res.RowsAffected() 127 | } 128 | l := &Log{ 129 | Context: ctx, 130 | Duration: time.Now().Sub(startTime), 131 | SQL: SQL, 132 | Bindings: bindArgs, 133 | RowsAffected: rowsAffected, 134 | Error: err, 135 | } 136 | opts.doDebug(l) 137 | if err != nil { 138 | return nil, err 139 | } 140 | 141 | return res, nil 142 | } 143 | 144 | func (t *executor) model(s interface{}, opts *sqlOptions) *ModelExecutor { 145 | table := tableReplace(s, opts.TableKey, opts) 146 | return &ModelExecutor{ 147 | Executor: t.Executor, 148 | Options: opts, 149 | TableName: table, 150 | } 151 | } 152 | 153 | func (t *executor) Update(ctx context.Context, data interface{}, expr string, args []interface{}, opts *sqlOptions) (sql.Result, error) { 154 | var err error 155 | set := make([]string, 0) 156 | bindArgs := make([]interface{}, 0) 157 | 158 | value := reflect.ValueOf(data) 159 | typ := reflect.TypeOf(data) 160 | switch value.Kind() { 161 | case reflect.Ptr: 162 | return t.Update(ctx, value.Elem().Interface(), expr, args, opts) 163 | case reflect.Struct: 164 | set, bindArgs, err = t.foreachUpdate(value, typ, opts) 165 | if err != nil { 166 | return nil, err 167 | } 168 | break 169 | default: 170 | return nil, errors.New("xsql: only support struct") 171 | } 172 | 173 | where := "" 174 | if expr 
!= "" { 175 | where = fmt.Sprintf(` WHERE %s`, expr) 176 | bindArgs = append(bindArgs, args...) 177 | } 178 | 179 | SQL := fmt.Sprintf(`UPDATE %s SET %s%s`, opts.TableKey, strings.Join(set, ", "), where) 180 | SQL = tableReplace(data, SQL, opts) 181 | 182 | startTime := time.Now() 183 | res, err := t.Executor.Exec(SQL, bindArgs...) 184 | var rowsAffected int64 185 | if res != nil { 186 | rowsAffected, _ = res.RowsAffected() 187 | } 188 | l := &Log{ 189 | Context: ctx, 190 | Duration: time.Now().Sub(startTime), 191 | SQL: SQL, 192 | Bindings: bindArgs, 193 | RowsAffected: rowsAffected, 194 | Error: err, 195 | } 196 | opts.doDebug(l) 197 | if err != nil { 198 | return nil, err 199 | } 200 | 201 | return res, nil 202 | } 203 | 204 | func (t *executor) Exec(ctx context.Context, query string, args []interface{}, opts *sqlOptions) (sql.Result, error) { 205 | startTime := time.Now() 206 | res, err := t.Executor.Exec(query, args...) 207 | var rowsAffected int64 208 | if res != nil { 209 | rowsAffected, _ = res.RowsAffected() 210 | } 211 | l := &Log{ 212 | Context: ctx, 213 | Duration: time.Now().Sub(startTime), 214 | SQL: query, 215 | Bindings: args, 216 | RowsAffected: rowsAffected, 217 | Error: err, 218 | } 219 | opts.doDebug(l) 220 | if err != nil { 221 | return nil, err 222 | } 223 | 224 | return res, err 225 | } 226 | 227 | func isTime(typ string) bool { 228 | switch typ { 229 | case "time.Time", "go_ora.TimeStamp", "*timestamppb.Timestamp": 230 | return true 231 | default: 232 | return false 233 | } 234 | } 235 | 236 | func formatTime(typ string, v interface{}, opts *sqlOptions) string { 237 | switch typ { 238 | case "time.Time": 239 | return v.(time.Time).Format(opts.TimeLayout) 240 | case "go_ora.TimeStamp": 241 | return time.Time(v.(ora.TimeStamp)).Format(opts.TimeLayout) 242 | case "*timestamppb.Timestamp": 243 | return v.(*timestamppb.Timestamp).AsTime().Format(opts.TimeLayout) 244 | default: 245 | return "" 246 | } 247 | } 248 | 249 | func (t *executor) 
foreachInsert(value reflect.Value, typ reflect.Type, opts *sqlOptions) (fields, vars []string, bindArgs []interface{}, err error) { 250 | for n := 0; n < value.NumField(); n++ { 251 | fieldValue := value.Field(n) 252 | fieldStruct := typ.Field(n) 253 | 254 | // Embedded Structs 255 | if fieldStruct.Anonymous { 256 | f, v, b, e := t.foreachInsert(fieldValue, fieldValue.Type(), opts) 257 | if e != nil { 258 | return nil, nil, nil, e 259 | } 260 | fields = append(fields, f...) 261 | vars = append(vars, v...) 262 | bindArgs = append(bindArgs, b...) 263 | continue 264 | } 265 | 266 | if !fieldValue.CanInterface() { 267 | continue 268 | } 269 | 270 | tag := value.Type().Field(n).Tag.Get(opts.Tag) 271 | if tag == "" || tag == "-" { 272 | continue 273 | } 274 | 275 | fields = append(fields, tag) 276 | 277 | vTyp := fieldValue.Type().String() 278 | isTime := isTime(vTyp) 279 | 280 | var v string 281 | if opts.Placeholder == "?" { 282 | v = opts.Placeholder 283 | } else { 284 | v = fmt.Sprintf(opts.Placeholder, n) 285 | } 286 | if isTime { 287 | v = opts.TimeFunc(v) 288 | } 289 | vars = append(vars, v) 290 | 291 | var a interface{} 292 | if isTime { 293 | a = formatTime(vTyp, fieldValue.Interface(), opts) 294 | } else { 295 | // 非标量用JSON序列化处理 296 | if slices.Contains([]reflect.Kind{reflect.Ptr, reflect.Struct, reflect.Slice, reflect.Array, reflect.Map}, fieldValue.Kind()) { 297 | b, e := marshal(fieldValue.Interface()) 298 | if e != nil { 299 | return nil, nil, nil, fmt.Errorf("json unmarshal error %s for field %s", e, tag) 300 | } 301 | a = string(b) 302 | } else { 303 | a = fieldValue.Interface() 304 | } 305 | } 306 | bindArgs = append(bindArgs, a) 307 | } 308 | return 309 | } 310 | 311 | func (t *executor) foreachBatchInsertFields(value reflect.Value, typ reflect.Type, opts *sqlOptions) (fields []string) { 312 | for n := 0; n < value.NumField(); n++ { 313 | fieldValue := value.Field(n) 314 | fieldStruct := typ.Field(n) 315 | if fieldStruct.Anonymous { 316 | f := 
t.foreachBatchInsertFields(fieldValue, fieldValue.Type(), opts) 317 | fields = append(fields, f...) 318 | continue 319 | } 320 | 321 | if !fieldValue.CanInterface() { 322 | continue 323 | } 324 | 325 | tag := value.Type().Field(n).Tag.Get(opts.Tag) 326 | if tag == "" || tag == "-" { 327 | continue 328 | } 329 | 330 | fields = append(fields, tag) 331 | } 332 | return 333 | } 334 | 335 | func (t *executor) foreachBatchInsertValues(ai int, value reflect.Value, typ reflect.Type, opts *sqlOptions) (vars []string, bindArgs []interface{}, err error) { 336 | for n := 0; n < value.NumField(); n++ { 337 | fieldValue := value.Field(n) 338 | fieldStruct := typ.Field(n) 339 | 340 | // Embedded Structs 341 | if fieldStruct.Anonymous { 342 | v, b, e := t.foreachBatchInsertValues(ai+1000, fieldValue, fieldValue.Type(), opts) 343 | if e != nil { 344 | return nil, nil, e 345 | } 346 | vars = append(vars, v...) 347 | bindArgs = append(bindArgs, b...) 348 | continue 349 | } 350 | 351 | if !fieldValue.CanInterface() { 352 | continue 353 | } 354 | 355 | tag := value.Type().Field(n).Tag.Get(opts.Tag) 356 | if tag == "" || tag == "-" { 357 | continue 358 | } 359 | 360 | vTyp := fieldValue.Type().String() 361 | isTime := isTime(vTyp) 362 | 363 | var v string 364 | if opts.Placeholder == "?" 
{ 365 | v = opts.Placeholder 366 | } else { 367 | v = fmt.Sprintf(opts.Placeholder, ai) 368 | ai += 1 369 | } 370 | if isTime { 371 | v = opts.TimeFunc(v) 372 | } 373 | vars = append(vars, v) 374 | 375 | var a interface{} 376 | if isTime { 377 | a = formatTime(vTyp, fieldValue.Interface(), opts) 378 | } else { 379 | // 非标量用JSON序列化处理 380 | if slices.Contains([]reflect.Kind{reflect.Ptr, reflect.Struct, reflect.Slice, reflect.Array, reflect.Map}, fieldValue.Kind()) { 381 | b, e := marshal(fieldValue.Interface()) 382 | if e != nil { 383 | return nil, nil, fmt.Errorf("json unmarshal error %s for field %s", e, tag) 384 | } 385 | a = string(b) 386 | } else { 387 | a = fieldValue.Interface() 388 | } 389 | } 390 | bindArgs = append(bindArgs, a) 391 | } 392 | return 393 | } 394 | 395 | func (t *executor) foreachUpdate(value reflect.Value, typ reflect.Type, opts *sqlOptions) (set []string, bindArgs []interface{}, err error) { 396 | for n := 0; n < value.NumField(); n++ { 397 | fieldValue := value.Field(n) 398 | fieldStruct := typ.Field(n) 399 | 400 | // Embedded Structs 401 | if fieldStruct.Anonymous { 402 | s, b, e := t.foreachUpdate(fieldValue, fieldValue.Type(), opts) 403 | if e != nil { 404 | return nil, nil, e 405 | } 406 | set = append(set, s...) 407 | bindArgs = append(bindArgs, b...) 408 | continue 409 | } 410 | 411 | if !fieldValue.CanInterface() { 412 | continue 413 | } 414 | 415 | tag := value.Type().Field(n).Tag.Get(opts.Tag) 416 | if tag == "" || tag == "-" { 417 | continue 418 | } 419 | 420 | vTyp := fieldValue.Type().String() 421 | isTime := isTime(vTyp) 422 | 423 | var v string 424 | if opts.Placeholder == "?" 
{ 425 | v = opts.Placeholder 426 | } else { 427 | v = fmt.Sprintf(opts.Placeholder, n) 428 | } 429 | if isTime { 430 | v = opts.TimeFunc(v) 431 | } 432 | set = append(set, fmt.Sprintf("%s = %s", opts.ColumnQuotes+tag+opts.ColumnQuotes, v)) 433 | 434 | var a interface{} 435 | if isTime { 436 | a = formatTime(vTyp, fieldValue.Interface(), opts) 437 | } else { 438 | // 非标量用JSON序列化处理 439 | if slices.Contains([]reflect.Kind{reflect.Ptr, reflect.Struct, reflect.Slice, reflect.Array, reflect.Map}, fieldValue.Kind()) { 440 | b, e := marshal(fieldValue.Interface()) 441 | if e != nil { 442 | return nil, nil, fmt.Errorf("json unmarshal error %s for field %s", e, tag) 443 | } 444 | a = string(b) 445 | } else { 446 | a = fieldValue.Interface() 447 | } 448 | } 449 | bindArgs = append(bindArgs, a) 450 | } 451 | return 452 | } 453 | -------------------------------------------------------------------------------- /db_test.go: -------------------------------------------------------------------------------- 1 | package xsql_test 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "encoding/json" 7 | _ "github.com/go-sql-driver/mysql" 8 | "github.com/mix-go/xsql" 9 | "github.com/mix-go/xsql/testdata" 10 | "github.com/stretchr/testify/assert" 11 | "google.golang.org/protobuf/types/known/timestamppb" 12 | "log" 13 | "os" 14 | "testing" 15 | "time" 16 | ) 17 | 18 | type Enum int32 19 | 20 | type Test struct { 21 | Id int `xsql:"id"` 22 | Foo string `xsql:"foo"` 23 | Bar time.Time `xsql:"bar"` 24 | Bool bool `xsql:"bool" json:"-"` 25 | Enum Enum `xsql:"enum" json:"-"` 26 | } 27 | 28 | type Test1 struct { 29 | Id int `xsql:"id"` 30 | } 31 | 32 | type Test2 struct { 33 | Foo string `xsql:"foo"` 34 | Bar time.Time `xsql:"bar"` 35 | } 36 | 37 | type Test3 struct { 38 | Id int `xsql:"id"` 39 | Foo string `xsql:"foo"` 40 | Bar *timestamppb.Timestamp `xsql:"bar"` 41 | } 42 | 43 | type TestJsonStruct struct { 44 | Test 45 | Json JsonItem `xsql:"json"` 46 | } 47 | 48 | type TestJsonStructPtr 
struct { 49 | Test 50 | Json *JsonItem `xsql:"json"` 51 | } 52 | 53 | type TestJsonSlice struct { 54 | Test 55 | Json []int `xsql:"json"` 56 | } 57 | 58 | type TestJsonSlicePtr struct { 59 | Test 60 | Json []*JsonItem `xsql:"json"` 61 | } 62 | 63 | type JsonItem struct { 64 | Foo string `xsql:"foo"` 65 | } 66 | 67 | type TestEmbedding struct { 68 | Test1 69 | Test2 70 | } 71 | 72 | type TestPbStruct struct { 73 | testdata.Device 74 | } 75 | 76 | func (t *Test) TableName() string { 77 | return "xsql" 78 | } 79 | 80 | func (t *Test1) TableName() string { 81 | return "xsql" 82 | } 83 | 84 | func (t *Test2) TableName() string { 85 | return "xsql" 86 | } 87 | 88 | func (t *Test3) TableName() string { 89 | return "xsql" 90 | } 91 | 92 | func (t *TestEmbedding) TableName() string { 93 | return "xsql" 94 | } 95 | 96 | func (t *TestJsonStruct) TableName() string { 97 | return "xsql" 98 | } 99 | 100 | func (t *TestJsonStructPtr) TableName() string { 101 | return "xsql" 102 | } 103 | 104 | func (t *TestJsonSlice) TableName() string { 105 | return "xsql" 106 | } 107 | 108 | func (t *TestPbStruct) TableName() string { 109 | return "devices" 110 | } 111 | 112 | func newDB() *xsql.DB { 113 | db, err := sql.Open("mysql", "root:123456@tcp(127.0.0.1:3306)/test?charset=utf8&parseTime=true&loc=UTC&multiStatements=true") 114 | if err != nil { 115 | log.Fatal(err) 116 | } 117 | return xsql.New( 118 | db, 119 | xsql.WithDebugFunc(func(l *xsql.Log) { 120 | log.Println(l) 121 | }), 122 | xsql.WithTimeLocation(time.UTC), 123 | ) 124 | } 125 | 126 | func TestCreateTable(t *testing.T) { 127 | a := assert.New(t) 128 | DB := newDB() 129 | 130 | b, err := os.ReadFile("./xsql.sql") 131 | a.Nil(err) 132 | err = DB.Exec(context.Background(), string(b)).Error 133 | a.Nil(err) 134 | } 135 | 136 | func TestClear(t *testing.T) { 137 | a := assert.New(t) 138 | 139 | DB := newDB() 140 | 141 | err := DB.Exec(context.Background(), "DELETE FROM xsql WHERE id > 2").Error 142 | a.Nil(err) 143 | } 144 | 145 | 
func TestDebugFunc(t *testing.T) { 146 | a := assert.New(t) 147 | 148 | DB := newDB() 149 | 150 | test := Test{ 151 | Id: 0, 152 | Foo: "test", 153 | Bar: time.Now(), 154 | } 155 | err := DB.Update(context.Background(), &test, "id = ?", 0).Error 156 | a.Nil(err) 157 | } 158 | 159 | func TestQuery(t *testing.T) { 160 | a := assert.New(t) 161 | 162 | DB := newDB() 163 | 164 | rows, err := DB.Query(context.Background(), "SELECT * FROM xsql") 165 | a.Nil(err) 166 | bar := rows[0].Get("bar").String() 167 | log.Println(bar) 168 | a.Equal(bar, "2022-04-12 23:50:00") 169 | } 170 | 171 | func TestInsert(t *testing.T) { 172 | a := assert.New(t) 173 | 174 | DB := newDB() 175 | 176 | test := Test{ 177 | Id: 0, 178 | Foo: "test", 179 | Bar: time.Now(), 180 | } 181 | err := DB.Insert(context.Background(), &test).Error 182 | a.Nil(err) 183 | } 184 | 185 | func TestEmbeddingInsert(t *testing.T) { 186 | a := assert.New(t) 187 | 188 | DB := newDB() 189 | 190 | test := TestEmbedding{ 191 | Test1: Test1{ 192 | Id: 0, 193 | }, 194 | Test2: Test2{ 195 | Foo: "test", 196 | Bar: time.Now(), 197 | }, 198 | } 199 | err := DB.Insert(context.Background(), &test).Error 200 | a.Nil(err) 201 | } 202 | 203 | func TestBatchInsert(t *testing.T) { 204 | a := assert.New(t) 205 | 206 | DB := newDB() 207 | 208 | tests := []Test{ 209 | { 210 | Id: 0, 211 | Foo: "test1", 212 | Bar: time.Now(), 213 | }, 214 | { 215 | Id: 0, 216 | Foo: "test2", 217 | Bar: time.Now(), 218 | }, 219 | } 220 | err := DB.BatchInsert(context.Background(), &tests).Error 221 | a.Nil(err) 222 | } 223 | 224 | func TestEmbeddingBatchInsert(t *testing.T) { 225 | a := assert.New(t) 226 | 227 | DB := newDB() 228 | 229 | tests := []TestEmbedding{ 230 | { 231 | Test1: Test1{ 232 | Id: 0, 233 | }, 234 | Test2: Test2{ 235 | Foo: "test1", 236 | Bar: time.Now(), 237 | }, 238 | }, 239 | { 240 | Test1: Test1{ 241 | Id: 0, 242 | }, 243 | Test2: Test2{ 244 | Foo: "test2", 245 | Bar: time.Now(), 246 | }, 247 | }, 248 | } 249 | err := 
DB.BatchInsert(context.Background(), &tests).Error 250 | a.Nil(err) 251 | } 252 | 253 | func TestEmbeddingUpdate(t *testing.T) { 254 | a := assert.New(t) 255 | 256 | DB := newDB() 257 | test := TestEmbedding{ 258 | Test1: Test1{ 259 | Id: 999, 260 | }, 261 | Test2: Test2{ 262 | Foo: "test_update", 263 | Bar: time.Now(), 264 | }, 265 | } 266 | err := DB.Update(context.Background(), &test, "id = ?", 8).Error 267 | a.Nil(err) 268 | } 269 | 270 | func TestUpdate(t *testing.T) { 271 | a := assert.New(t) 272 | 273 | DB := newDB() 274 | 275 | test := Test{ 276 | Id: 8, 277 | Foo: "test_update_1", 278 | Bar: time.Now(), 279 | } 280 | err := DB.Update(context.Background(), &test, "id = ?", 999).Error 281 | a.Nil(err) 282 | } 283 | 284 | func TestUpdateColumns(t *testing.T) { 285 | a := assert.New(t) 286 | 287 | DB := newDB() 288 | 289 | data := map[string]interface{}{ 290 | "foo": "test_update_2", 291 | } 292 | err := DB.Model(&Test{}).Update(context.Background(), data, "id = ?", 8).Error 293 | a.Nil(err) 294 | 295 | data = map[string]interface{}{ 296 | "foo": timestamppb.Now(), 297 | } 298 | err = DB.Model(&Test{}).Update(context.Background(), data, "id = ?", 8).Error 299 | 300 | a.Nil(err) 301 | } 302 | 303 | func TestUpdateTagValuesMap(t *testing.T) { 304 | a := assert.New(t) 305 | 306 | DB := newDB() 307 | 308 | test := Test{} 309 | data, err := xsql.TagValuesMap(DB.Options.Tag, &test, 310 | xsql.TagValues{ 311 | {&test.Foo, "test_update_3"}, 312 | }, 313 | ) 314 | a.Nil(err) 315 | 316 | err = DB.Model(&test).Update(context.Background(), data, "id = ?", 8).Error 317 | a.Nil(err) 318 | } 319 | 320 | func TestEmbeddingUpdateTagValuesMap(t *testing.T) { 321 | a := assert.New(t) 322 | 323 | DB := newDB() 324 | 325 | test := TestEmbedding{} 326 | data, err := xsql.TagValuesMap(DB.Options.Tag, &test, 327 | xsql.TagValues{ 328 | {&test.Foo, "test_update_4"}, 329 | }, 330 | ) 331 | a.Nil(err) 332 | 333 | err = DB.Model(&test).Update(context.Background(), data, "id = ?", 
8).Error 334 | a.Nil(err) 335 | } 336 | 337 | func TestDelete(t *testing.T) { 338 | a := assert.New(t) 339 | 340 | DB := newDB() 341 | 342 | test := Test{ 343 | Id: 8, 344 | Foo: "test", 345 | Bar: time.Now(), 346 | } 347 | err := DB.Model(&test).Delete(context.Background(), "id = ?", test.Id).Error 348 | a.Nil(err) 349 | 350 | err = DB.Model(&Test{}).Delete(context.Background(), "id = ?", 8).Error 351 | a.Nil(err) 352 | } 353 | 354 | func TestExec(t *testing.T) { 355 | a := assert.New(t) 356 | 357 | DB := newDB() 358 | 359 | err := DB.Exec(context.Background(), "DELETE FROM xsql WHERE id = ?", 7).Error 360 | 361 | a.Nil(err) 362 | } 363 | 364 | func TestFirst(t *testing.T) { 365 | a := assert.New(t) 366 | 367 | DB := newDB() 368 | 369 | var test Test 370 | err := DB.First(context.Background(), &test, "SELECT * FROM ${TABLE}").Error 371 | a.Nil(err) 372 | 373 | b, _ := json.Marshal(test) 374 | a.Equal(string(b), `{"Id":1,"Foo":"v","Bar":"2022-04-12T23:50:00Z"}`) 375 | // bool 376 | a.Equal(test.Bool, true) 377 | // enum 378 | a.IsType(Enum(0), test.Enum) 379 | a.Equal(Enum(1), test.Enum) 380 | } 381 | 382 | func TestFirstPtr(t *testing.T) { 383 | a := assert.New(t) 384 | 385 | DB := newDB() 386 | 387 | var test *TestJsonStructPtr 388 | err := DB.First(context.Background(), &test, "SELECT * FROM ${TABLE}").Error 389 | a.Nil(err) 390 | 391 | b, _ := json.Marshal(test) 392 | a.Equal(string(b), `{"Id":1,"Foo":"v","Bar":"2022-04-12T23:50:00Z","Json":{"Foo":"bar"}}`) 393 | // bool 394 | a.Equal(test.Bool, true) 395 | // enum 396 | a.IsType(Enum(0), test.Enum) 397 | a.Equal(Enum(1), test.Enum) 398 | } 399 | 400 | func TestFirstEmbedding(t *testing.T) { 401 | a := assert.New(t) 402 | 403 | DB := newDB() 404 | 405 | var test TestEmbedding 406 | err := DB.First(context.Background(), &test, "SELECT * FROM ${TABLE}").Error 407 | a.Nil(err) 408 | 409 | b, _ := json.Marshal(test) 410 | a.Equal(string(b), `{"Id":1,"Foo":"v","Bar":"2022-04-12T23:50:00Z"}`) 411 | } 412 | 413 | func 
TestFirstPart(t *testing.T) { 414 | a := assert.New(t) 415 | 416 | DB := newDB() 417 | 418 | var test Test 419 | err := DB.First(context.Background(), &test, "SELECT foo FROM ${TABLE}").Error 420 | a.Nil(err) 421 | 422 | b, _ := json.Marshal(test) 423 | a.Equal(string(b), `{"Id":0,"Foo":"v","Bar":"0001-01-01T00:00:00Z"}`) 424 | } 425 | 426 | func TestFirstTableKey(t *testing.T) { 427 | a := assert.New(t) 428 | 429 | DB := newDB() 430 | 431 | var test Test 432 | err := DB.First(context.Background(), &test, "SELECT * FROM ${TABLE}").Error 433 | a.Nil(err) 434 | 435 | b, _ := json.Marshal(test) 436 | a.Equal(string(b), `{"Id":1,"Foo":"v","Bar":"2022-04-12T23:50:00Z"}`) 437 | } 438 | 439 | func TestFind(t *testing.T) { 440 | a := assert.New(t) 441 | 442 | DB := newDB() 443 | 444 | var tests []Test 445 | err := DB.Find(context.Background(), &tests, "SELECT * FROM ${TABLE} LIMIT 2").Error 446 | a.Nil(err) 447 | 448 | b, _ := json.Marshal(tests) 449 | a.Equal(string(b), `[{"Id":1,"Foo":"v","Bar":"2022-04-12T23:50:00Z"},{"Id":2,"Foo":"v1","Bar":"2022-04-13T23:50:00Z"}]`) 450 | } 451 | 452 | func TestFindPtr(t *testing.T) { 453 | a := assert.New(t) 454 | 455 | DB := newDB() 456 | 457 | var tests []*TestJsonStructPtr 458 | err := DB.Find(context.Background(), &tests, "SELECT * FROM ${TABLE} LIMIT 1").Error 459 | a.Nil(err) 460 | 461 | b, _ := json.Marshal(tests) 462 | a.Equal(string(b), `[{"Id":1,"Foo":"v","Bar":"2022-04-12T23:50:00Z","Json":{"Foo":"bar"}}]`) 463 | } 464 | 465 | func TestEmbeddingFind(t *testing.T) { 466 | a := assert.New(t) 467 | 468 | DB := newDB() 469 | 470 | var tests []TestEmbedding 471 | err := DB.Find(context.Background(), &tests, "SELECT * FROM ${TABLE} LIMIT 2").Error 472 | a.Nil(err) 473 | 474 | b, _ := json.Marshal(tests) 475 | a.Equal(string(b), `[{"Id":1,"Foo":"v","Bar":"2022-04-12T23:50:00Z"},{"Id":2,"Foo":"v1","Bar":"2022-04-13T23:50:00Z"}]`) 476 | } 477 | 478 | func TestFindPart(t *testing.T) { 479 | a := assert.New(t) 480 | 481 | DB := 
newDB() 482 | 483 | var tests []Test 484 | err := DB.Find(context.Background(), &tests, "SELECT foo FROM ${TABLE} LIMIT 2").Error 485 | a.Nil(err) 486 | 487 | b, _ := json.Marshal(tests) 488 | a.Equal(string(b), `[{"Id":0,"Foo":"v","Bar":"0001-01-01T00:00:00Z"},{"Id":0,"Foo":"v1","Bar":"0001-01-01T00:00:00Z"}]`) 489 | } 490 | 491 | func TestFindTableKey(t *testing.T) { 492 | a := assert.New(t) 493 | 494 | DB := newDB() 495 | 496 | var tests []Test 497 | err := DB.Find(context.Background(), &tests, "SELECT * FROM ${TABLE} LIMIT 2").Error 498 | a.Nil(err) 499 | 500 | b, _ := json.Marshal(tests) 501 | a.Equal(string(b), `[{"Id":1,"Foo":"v","Bar":"2022-04-12T23:50:00Z"},{"Id":2,"Foo":"v1","Bar":"2022-04-13T23:50:00Z"}]`) 502 | } 503 | 504 | func TestTxCommit(t *testing.T) { 505 | a := assert.New(t) 506 | 507 | DB := newDB() 508 | 509 | tx, _ := DB.Begin() 510 | 511 | test := Test{ 512 | Id: 0, 513 | Foo: "test", 514 | Bar: time.Now(), 515 | } 516 | err := tx.Insert(context.Background(), &test).Error 517 | a.Nil(err) 518 | 519 | err = tx.Commit() 520 | a.Nil(err) 521 | } 522 | 523 | func TestTxRollback(t *testing.T) { 524 | a := assert.New(t) 525 | 526 | DB := newDB() 527 | 528 | tx, _ := DB.Begin() 529 | 530 | test := Test{ 531 | Id: 0, 532 | Foo: "test", 533 | Bar: time.Now(), 534 | } 535 | err := tx.Insert(context.Background(), &test).Error 536 | a.Nil(err) 537 | 538 | err = tx.Rollback() 539 | a.Nil(err) 540 | } 541 | 542 | func TestPbTimestamp(t *testing.T) { 543 | a := assert.New(t) 544 | DB := newDB() 545 | 546 | // Insert 547 | now := timestamppb.Now() 548 | log.Println(now.AsTime().Format(time.RFC3339)) 549 | test := Test3{ 550 | Id: 0, 551 | Foo: "test_pb_timestamp", 552 | Bar: now, 553 | } 554 | res := DB.Insert(context.Background(), &test) 555 | err := res.Error 556 | a.Nil(err) 557 | insertId := res.LastInsertId 558 | 559 | // First 560 | var test2 Test3 561 | err = DB.First(context.Background(), &test2, "SELECT * FROM ${TABLE} WHERE id = ?", 
insertId).Error 562 | a.Nil(err) 563 | // Timestamp 564 | a.IsType(×tamppb.Timestamp{}, test2.Bar) 565 | a.Equal(test2.Bar.Seconds, now.Seconds) 566 | } 567 | 568 | func TestInsertPbJson(t *testing.T) { 569 | a := assert.New(t) 570 | DB := newDB() 571 | 572 | test1 := TestJsonStruct{ 573 | Test: Test{ 574 | Id: 0, 575 | Foo: "", 576 | Bar: time.Time{}, 577 | Bool: false, 578 | Enum: 0, 579 | }, 580 | Json: JsonItem{Foo: `bar`}, 581 | } 582 | err := DB.Insert(context.Background(), &test1).Error 583 | a.Nil(err) 584 | 585 | test2 := TestJsonStructPtr{ 586 | Test: Test{ 587 | Id: 0, 588 | Foo: "", 589 | Bar: time.Time{}, 590 | Bool: false, 591 | Enum: 0, 592 | }, 593 | Json: &JsonItem{Foo: `bar`}, 594 | } 595 | err = DB.Insert(context.Background(), &test2).Error 596 | a.Nil(err) 597 | 598 | test3 := TestJsonSlice{ 599 | Test: Test{ 600 | Id: 0, 601 | Foo: "", 602 | Bar: time.Time{}, 603 | Bool: false, 604 | Enum: 0, 605 | }, 606 | Json: []int{1, 2, 3}, 607 | } 608 | err = DB.Insert(context.Background(), &test3).Error 609 | a.Nil(err) 610 | 611 | test4 := TestJsonSlicePtr{ 612 | Test: Test{ 613 | Id: 0, 614 | Foo: "", 615 | Bar: time.Time{}, 616 | Bool: false, 617 | Enum: 0, 618 | }, 619 | Json: []*JsonItem{{Foo: `bar1`}, {Foo: `bar2`}, {Foo: `bar3`}}, 620 | } 621 | err = DB.Insert(context.Background(), &test4).Error 622 | a.Nil(err) 623 | } 624 | 625 | func TestFirstPbJsonField(t *testing.T) { 626 | a := assert.New(t) 627 | DB := newDB() 628 | 629 | var test1 TestJsonStruct 630 | err := DB.First(context.Background(), &test1, "SELECT * FROM ${TABLE} WHERE id = 1").Error 631 | a.Nil(err) 632 | a.NotEmpty(test1.Json) 633 | 634 | var test2 TestJsonStructPtr 635 | err = DB.First(context.Background(), &test2, "SELECT * FROM ${TABLE} WHERE id = 1").Error 636 | a.Nil(err) 637 | a.NotEmpty(test2.Json) 638 | 639 | var test3 TestJsonSlice 640 | err = DB.First(context.Background(), &test3, "SELECT * FROM ${TABLE} WHERE id = 2").Error 641 | a.Nil(err) 642 | a.NotEmpty(test3.Json) 
643 | 644 | var test4 TestJsonSlicePtr 645 | err = DB.First(context.Background(), &test4, "SELECT * FROM ${TABLE} WHERE id = 1006").Error 646 | a.Nil(err) 647 | a.NotEmpty(test4.Json) 648 | } 649 | 650 | func TestFindPbJsonField(t *testing.T) { 651 | a := assert.New(t) 652 | DB := newDB() 653 | 654 | var test1 []*TestJsonStruct 655 | err := DB.Find(context.Background(), &test1, "SELECT * FROM ${TABLE} WHERE id = 1").Error 656 | a.Nil(err) 657 | a.NotEmpty(test1) 658 | 659 | var test2 []*TestJsonStructPtr 660 | err = DB.Find(context.Background(), &test2, "SELECT * FROM ${TABLE} WHERE id = 1").Error 661 | a.Nil(err) 662 | a.NotEmpty(test2) 663 | 664 | var test3 []*TestJsonSlice 665 | err = DB.Find(context.Background(), &test3, "SELECT * FROM ${TABLE} WHERE id = 2").Error 666 | a.Nil(err) 667 | a.NotEmpty(test3) 668 | 669 | var test4 []*TestJsonSlicePtr 670 | err = DB.Find(context.Background(), &test4, "SELECT * FROM ${TABLE} WHERE id = 1006").Error 671 | a.Nil(err) 672 | a.NotEmpty(test4) 673 | } 674 | 675 | func TestFirstPbStruct(t *testing.T) { 676 | a := assert.New(t) 677 | DB := newDB() 678 | 679 | var row TestPbStruct 680 | err := DB.First(context.Background(), &row, "SELECT * FROM ${TABLE} WHERE id = 1").Error 681 | a.Nil(err) 682 | a.NotEmpty(&row) 683 | 684 | var row1 testdata.Device 685 | err = DB.First(context.Background(), &row1, "SELECT * FROM ${TABLE} WHERE id = 1").Error 686 | a.Contains(err.Error(), "doesn't exist") 687 | } 688 | 689 | func TestFindPbStruct(t *testing.T) { 690 | a := assert.New(t) 691 | DB := newDB() 692 | 693 | var rows []*TestPbStruct 694 | err := DB.Find(context.Background(), &rows, "SELECT * FROM ${TABLE} WHERE id < 3").Error 695 | a.Nil(err) 696 | a.Len(rows, 2) 697 | 698 | var rows1 []*testdata.Device 699 | err = DB.Find(context.Background(), &rows1, "SELECT * FROM ${TABLE} WHERE id < 3").Error 700 | a.Contains(err.Error(), "doesn't exist") 701 | } 702 | 
--------------------------------------------------------------------------------