├── .gitignore ├── LICENSE ├── README.md ├── go.mod ├── go.sum ├── orm ├── create_struct.go ├── create_table.go ├── db.go ├── err.go ├── json_field.go ├── json_int.go ├── json_time.go ├── log.go ├── query.go ├── query_delete.go ├── query_execute.go ├── query_get.go ├── query_insert.go ├── query_join.go ├── query_raw.go ├── query_result.go ├── query_select.go ├── query_select_gen.go ├── query_table.go ├── query_table_cache.go ├── query_transaction.go ├── query_union.go ├── query_update.go ├── query_where.go ├── query_window.go ├── query_with.go ├── reflect.go ├── slice.go ├── string.go ├── subquery.go ├── table.go ├── update_column.go └── where.go └── orm_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | .DS_Store 3 | vendor/ 4 | local_test.go 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 magacy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Based on [go-sql-driver/mysql](https://github.com/go-sql-driver/mysql). 2 | 3 | ## Get Started 4 | ```go 5 | import ( 6 | "database/sql" 7 | "github.com/folospace/go-mysql-orm/orm" 8 | ) 9 | 10 | //connect mysql db 11 | var db, _ = orm.OpenMysql("user:password@tcp(127.0.0.1:3306)/mydb?parseTime=true&charset=utf8mb4&loc=Asia%2FShanghai") 12 | 13 | //user table 14 | var UserTable = new(User) 15 | 16 | type User struct { 17 | Id int `json:"id"` 18 | Email string `json:"email" orm:"email,unique"` 19 | Name string `json:"name" default:"jack"` 20 | Avatar string `json:"avatar" comment:"head image"` 21 | CreatedAt time.Time `json:"created_at"` 22 | UpdatedAt time.Time `json:"updated_at"` 23 | } 24 | 25 | func (*User) Connections() []*sql.DB { 26 | return []*sql.DB{db} 27 | } 28 | func (*User) DatabaseName() string { 29 | return "mydb" 30 | } 31 | func (*User) TableName() string { 32 | return "user" 33 | } 34 | func (u *User) Query() *orm.Query[*User] { 35 | return orm.NewQuery(UserTable).WherePrimaryIfNotZero(u.Id) 36 | } 37 | 38 | func main() { 39 | //create db table, add new columns if table already exists.
40 | UserTable.Query().CreateTable() 41 | 42 | //create struct from db table 43 | UserTable.Query().CreateStruct() 44 | } 45 | ``` 46 | 47 | ## select 48 | 49 | ```go 50 | //select * from user where id = 1 //to struct 51 | user, _ := UserTable.Query().Get(1) 52 | fmt.Println(user) //User{Id:1} 53 | 54 | //select * from user where name='john' //to struct slice 55 | users, _ := UserTable.Query().Where(&UserTable.Name, "john").Gets() 56 | fmt.Println(users) //User{Id:1}, User{Id:2}, ... 57 | 58 | //select email from user //to slice 59 | emails, _ := UserTable.Query().Select(&UserTable.Email).Limit(10).GetSliceString() 60 | fmt.Println(emails) //a**@gmail.com, b**@gmail.com, ... 61 | 62 | //select user info to slice, group by id 63 | var userInfoMap map[int][]string 64 | UserTable.Query().Select(&UserTable.Id, &UserTable.Email, &UserTable.Name).Limit(10).GetTo(&userInfoMap) 65 | fmt.Println(userInfoMap) //{1:[a**@gmail.com, a**], 2:[b**@gmail.com, b**], ...} 66 | 67 | //select user id to slice, group by name 68 | var sameNameUsers map[string][]int 69 | UserTable.Query().Select(&UserTable.Name, &UserTable.Id).Limit(10).GetTo(&sameNameUsers) 70 | fmt.Println(sameNameUsers) //{a**:[1,3], b**:[2,4], ...} 71 | 72 | ``` 73 | 74 | ## update | delete | insert 75 | 76 | ```go 77 | //update user set name="john 2" where id = 1 78 | UserTable.Query().WherePrimary(1).Update(&UserTable.Name, "john 2") 79 | 80 | //delete 81 | UserTable.Query().Delete(1, 2, 3) 82 | 83 | //insert 84 | UserTable.Query().Insert(&User{Name: "han"}) 85 | 86 | //update users with different names 87 | _ = UserTable.Query().OnConflictUpdate(&UserTable.Name, &UserTable.Name). 
88 | Insert(&User{Id: 1, Name: "han"}, &User{Id: 2, Name: "join"}) 89 | ``` 90 | 91 | ### join 92 | 93 | ```go 94 | //query join 95 | UserTable.Query().Join(OrderTable, func (join *orm.Query[*User]) *orm.Query[*User] { 96 | return join.Where(&UserTable.Id, &OrderTable.UserId) 97 | }).Select(UserTable).Gets() 98 | ``` 99 | 100 | ## transaction 101 | 102 | ```go 103 | //transaction 104 | _ = UserTable.Query().Transaction(func (query *orm.Query[*User]) error { 105 | newId := query.Insert(&User{Name: "john"}).LastInsertId //insert 106 | //newId := orm.NewQuery(UserTable).UseTx(query.Tx()).Insert(&User{Name: "john"}).LastInsertId 107 | fmt.Println(newId) 108 | return errors.New("I want rollback") //rollback 109 | }) 110 | ``` 111 | 112 | ## subquery 113 | 114 | ```go 115 | //subquery 116 | subquery := UserTable.Query().WherePrimary(1).Select(&UserTable.Id).SubQuery() 117 | 118 | //where in subquery 119 | UserTable.Query().Where(&UserTable.Id, orm.WhereIn, subquery).Gets() 120 | 121 | //insert subquery 122 | UserTable.Query().Select(&UserTable.Id).InsertSubquery(subquery) 123 | 124 | //join subquery 125 | UserTable.Query().Join(subquery, func (query *orm.Query[*User]) *orm.Query[*User] { 126 | return query.Where(&UserTable.Id, orm.Raw("sub.id")) 127 | }).Gets() 128 | 129 | ``` -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/folospace/go-mysql-orm 2 | 3 | go 1.18 4 | 5 | require ( 6 | github.com/go-sql-driver/mysql v1.6.0 7 | github.com/gobeam/stringy v0.0.5 8 | github.com/mcuadros/go-defaults v1.2.0 9 | ) 10 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= 2 | github.com/go-sql-driver/mysql v1.6.0/go.mod
h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= 3 | github.com/gobeam/stringy v0.0.5 h1:TvxQGSAqr/qF0SBVxa8Q67WWIo7bCWS0bM101WOd52g= 4 | github.com/gobeam/stringy v0.0.5/go.mod h1:W3620X9dJHf2FSZF5fRnWekHcHQjwmCz8ZQ2d1qloqE= 5 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 6 | github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= 7 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 8 | github.com/mcuadros/go-defaults v1.2.0 h1:FODb8WSf0uGaY8elWJAkoLL0Ri6AlZ1bFlenk56oZtc= 9 | github.com/mcuadros/go-defaults v1.2.0/go.mod h1:WEZtHEVIGYVDqkKSWBdWKUVdRyKlMfulPaGDWIVeCWY= 10 | github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= 11 | github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= 12 | gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= 13 | gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 14 | -------------------------------------------------------------------------------- /orm/create_struct.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "github.com/gobeam/stringy" 7 | "io/ioutil" 8 | "reflect" 9 | "regexp" 10 | "runtime" 11 | "strconv" 12 | "strings" 13 | "time" 14 | ) 15 | 16 | var findCommentRegex = regexp.MustCompile("(.+) COMMENT '(.+)'") 17 | var findDefaultRegex = regexp.MustCompile("(.+) DEFAULT (.+)") 18 | var findAutoIncrementRegex = regexp.MustCompile("(.+) AUTO_INCREMENT") 19 | var findNotNullRegex = regexp.MustCompile("(.+) NOT NULL") 20 | var findNullRegex = regexp.MustCompile("(.+) NULL") 21 | 22 | func (q *Query[T]) CreateStruct(file ...string) error { 23 | table := q.tableInterface() 24 | dbColumns, err := getTableDbColumns(q) 25 | if err != 
nil { 26 | return err 27 | } 28 | 29 | var structLines []string 30 | for _, v := range dbColumns { 31 | structFieldName := stringy.New(v.Name).CamelCase() 32 | sturctFieldType := getStructFieldTypeStringByDBType(v.Type) 33 | if v.Null { 34 | sturctFieldType = "*" + sturctFieldType 35 | } 36 | 37 | var structFieldTags []string 38 | structFieldTags = append(structFieldTags, fmt.Sprintf("json:\"%s\"", v.Name)) 39 | var ormTags []string 40 | ormTags = append(ormTags, v.Name) 41 | ormTags = append(ormTags, v.Type) 42 | if v.Null { 43 | ormTags = append(ormTags, nullPrefix) 44 | } 45 | if v.AutoIncrement { 46 | ormTags = append(ormTags, autoIncrementPrefix) 47 | } 48 | if v.Primary { 49 | ormTags = append(ormTags, primaryKeyPrefix) 50 | } 51 | if v.Unique { 52 | ormTags = append(ormTags, uniqueKeyPrefix) 53 | } 54 | if v.Index { 55 | ormTags = append(ormTags, keyPrefix) 56 | } 57 | if len(v.Uniques) > 0 { 58 | ormTags = append(ormTags, v.Uniques...) 59 | } 60 | if len(v.Indexs) > 0 { 61 | ormTags = append(ormTags, v.Indexs...) 
62 | } 63 | 64 | structFieldTags = append(structFieldTags, fmt.Sprintf("orm:\"%s\"", strings.Join(ormTags, ","))) 65 | 66 | if v.Default != "" { 67 | structFieldTags = append(structFieldTags, fmt.Sprintf("default:\"%s\"", v.Default)) 68 | } 69 | if v.Comment != "" { 70 | structFieldTags = append(structFieldTags, fmt.Sprintf("comment:\"%s\"", v.Comment)) 71 | } 72 | 73 | line := structFieldName + " " + sturctFieldType + " " + "`" + strings.Join(structFieldTags, " ") + "`" 74 | 75 | structLines = append(structLines, line) 76 | } 77 | 78 | var structFile = "" 79 | if len(file) > 0 { 80 | structFile = file[0] 81 | } else { 82 | _, fs, _, _ := runtime.Caller(1) 83 | structFile = fs 84 | } 85 | 86 | fileBytes, err := ioutil.ReadFile(structFile) 87 | if err != nil { 88 | return err 89 | } 90 | 91 | fileContent := string(fileBytes) 92 | 93 | structNameSrc := strings.Split(reflect.TypeOf(table).Elem().String(), ".") 94 | structName := structNameSrc[len(structNameSrc)-1] 95 | 96 | search := "type " + structName + " struct {" 97 | oldStructRename := "type " + structName + "_" + time.Now().Format("2006_01_02_15_04_05") + " struct {" 98 | 99 | fileParts := strings.SplitN(fileContent, search, 2) 100 | 101 | finalFileContent := fileParts[0] + search + "\n" + strings.Join(structLines, "\n") + "\n}\n" 102 | 103 | if len(fileParts) > 1 { 104 | finalFileContent += oldStructRename + fileParts[1] 105 | } 106 | 107 | return ioutil.WriteFile(structFile, []byte(finalFileContent), 0644) 108 | } 109 | 110 | func getStructFieldTypeStringByDBType(dbType string) string { 111 | if strings.Contains(dbType, "char") || strings.Contains(dbType, "text") { 112 | return "string" 113 | } 114 | if strings.Contains(dbType, "int") { 115 | if strings.Contains(dbType, "unsigned") { 116 | if strings.HasPrefix(dbType, "tiny") { 117 | return "uint8" 118 | } else if strings.HasPrefix(dbType, "big") { 119 | return "uint64" 120 | } else { 121 | return "uint" 122 | } 123 | } else { 124 | if 
strings.HasPrefix(dbType, "tiny") { 125 | return "int8" 126 | } else if strings.HasPrefix(dbType, "big") { 127 | return "int64" 128 | } else { 129 | return "int" 130 | } 131 | } 132 | } else if strings.Contains(dbType, "float") || strings.Contains(dbType, "double") || strings.Contains(dbType, "decimal") { 133 | return "float64" 134 | } else if strings.Contains(dbType, "time") || strings.Contains(dbType, "date") { 135 | return "time.Time" 136 | } 137 | return "string" 138 | } 139 | 140 | func getSqlSegments[T Table](query *Query[T]) ([]string, error) { 141 | table := query.tableInterface() 142 | var res map[string]string 143 | 144 | err := query.Raw("show create table " + "`" + table.TableName() + "`").GetTo(&res).Err 145 | if err != nil { 146 | return nil, err 147 | } 148 | 149 | createTableSql := res[table.TableName()] 150 | if createTableSql == "" { 151 | return nil, ErrTableNotExisted 152 | } 153 | 154 | sqlSegments := strings.Split(createTableSql, "\n") 155 | 156 | if len(sqlSegments) <= 2 { 157 | return nil, errors.New(createTableSql) 158 | } 159 | sqlSegments = sqlSegments[1 : len(sqlSegments)-1] 160 | return sqlSegments, nil 161 | } 162 | 163 | func getTableDbColumns[T Table](query *Query[T]) ([]dBColumn, error) { 164 | sqlSegments, err := getSqlSegments(query) 165 | if err != nil { 166 | return nil, err 167 | } 168 | 169 | ret := make([]dBColumn, 0) 170 | existColumn := make(map[string]int) 171 | 172 | for k, v := range sqlSegments { 173 | v = strings.TrimLeft(v, " ") 174 | v = strings.TrimRight(v, ",") 175 | 176 | if strings.HasPrefix(v, "PRIMARY KEY ") { 177 | v = strings.TrimPrefix(v, "PRIMARY KEY ") 178 | keyNameAndCols := strings.Trim(v, "()") 179 | keyNameAndCols = strings.Trim(keyNameAndCols, "`") 180 | ret[existColumn[keyNameAndCols]].Primary = true 181 | } else if strings.HasPrefix(v, "UNIQUE KEY ") { 182 | v = strings.TrimPrefix(v, "UNIQUE KEY ") 183 | keyNameAndCols := strings.Split(v, " ") 184 | if len(keyNameAndCols) != 2 { 185 | continue 186 | 
} 187 | 188 | keyName := strings.Trim(keyNameAndCols[0], "`") 189 | cols := strings.Split(strings.Trim(keyNameAndCols[1], "()"), ",") 190 | 191 | if len(cols) == 1 && cols[0] == keyNameAndCols[0] { 192 | keyName = uniqueKeyPrefix 193 | } else { 194 | keyName = uniqueKeyPrefix + ":" + keyName 195 | } 196 | for k2, v2 := range cols { 197 | colName := strings.Trim(v2, "`") 198 | if len(cols) > 1 { 199 | ret[existColumn[colName]].Uniques = append(ret[existColumn[colName]].Uniques, keyName+":"+strconv.Itoa(k2)) 200 | } else { 201 | ret[existColumn[colName]].Uniques = append(ret[existColumn[colName]].Uniques, keyName) 202 | } 203 | } 204 | } else if strings.HasPrefix(v, "KEY ") { 205 | v = strings.TrimPrefix(v, "KEY ") 206 | keyNameAndCols := strings.Split(v, " ") 207 | if len(keyNameAndCols) != 2 { 208 | continue 209 | } 210 | 211 | keyName := strings.Trim(keyNameAndCols[0], "`") 212 | cols := strings.Split(strings.Trim(keyNameAndCols[1], "()"), ",") 213 | 214 | if len(cols) == 1 && cols[0] == keyNameAndCols[0] { 215 | keyName = keyPrefix 216 | } else { 217 | keyName = keyPrefix + ":" + keyName 218 | } 219 | for k2, v2 := range cols { 220 | colName := strings.Trim(v2, "`") 221 | 222 | if len(cols) > 1 { 223 | ret[existColumn[colName]].Indexs = append(ret[existColumn[colName]].Indexs, keyName+":"+strconv.Itoa(k2)) 224 | } else { 225 | ret[existColumn[colName]].Indexs = append(ret[existColumn[colName]].Indexs, keyName) 226 | } 227 | } 228 | } else if strings.HasPrefix(v, "`") { 229 | var col dBColumn 230 | col.Null = true 231 | temp := findCommentRegex.FindStringSubmatch(v) 232 | if len(temp) >= 3 { 233 | v = temp[1] 234 | col.Comment = temp[2] 235 | } 236 | 237 | temp = findDefaultRegex.FindStringSubmatch(v) 238 | if len(temp) >= 3 { 239 | v = temp[1] 240 | col.Default = strings.Trim(temp[2], "'") 241 | } 242 | 243 | temp = findAutoIncrementRegex.FindStringSubmatch(v) 244 | if len(temp) >= 2 { 245 | v = temp[1] 246 | col.AutoIncrement = true 247 | } 248 | 249 | temp = 
findNotNullRegex.FindStringSubmatch(v) 250 | if len(temp) >= 2 { 251 | v = temp[1] 252 | col.Null = false 253 | } 254 | 255 | temp = findNullRegex.FindStringSubmatch(v) 256 | if len(temp) >= 2 { 257 | v = temp[1] 258 | } 259 | 260 | nameAndTypeStrs := strings.SplitN(v, " ", 2) 261 | if len(nameAndTypeStrs) != 2 { 262 | continue 263 | } 264 | 265 | col.Type = nameAndTypeStrs[1] 266 | col.Name = strings.Trim(nameAndTypeStrs[0], "`") 267 | existColumn[col.Name] = k 268 | ret = append(ret, col) 269 | } 270 | } 271 | 272 | return ret, nil 273 | } 274 | -------------------------------------------------------------------------------- /orm/create_table.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "strconv" 7 | "strings" 8 | "time" 9 | ) 10 | 11 | const primaryKeyPrefix = "primary" 12 | const uniqueKeyPrefix = "unique" 13 | const keyPrefix = "index" 14 | const nullPrefix = "null" 15 | const notNullPrefix = "not null" 16 | const autoIncrementPrefix = "auto_increment" 17 | const createdAtColumn = "created_at" 18 | const updatedAtColumn = "updated_at" 19 | const deletedAtColumn = "deleted_at" 20 | 21 | var definedDefault = []string{"null", "current_timestamp", "current_timestamp on update current_timestamp"} 22 | 23 | //var tagSplitRegex = regexp.MustCompile(`[^\s"]+|"([^"]*)"`) 24 | 25 | //a := r.FindAllString(s, -1) 26 | 27 | type dBColumn struct { 28 | Name string // `id` 29 | Type string //bigint //varchar(255) 30 | Null bool //null //not null 31 | AutoIncrement bool //auto_increment 32 | Primary bool 33 | Unique bool 34 | Index bool 35 | 36 | Default string //default '' 37 | Comment string //comment '' 38 | Indexs []string //composite index names 39 | Uniques []string //composite unique index names 40 | } 41 | 42 | func (q *Query[T]) CreateTable() (string, error) { 43 | originQuery := q 44 | db := originQuery.DB() 45 | if db == nil { 46 | return "", ErrDbNotSelected 47 | } 
48 | 49 | if len(originQuery.tables) == 0 || len(originQuery.tables[0].ormFields) == 0 || 50 | originQuery.tables[0].table == nil || originQuery.tables[0].table.TableName() == "" { 51 | return "", ErrTableNotSelected 52 | } 53 | 54 | dbColums := getMigrateColumns(originQuery.tables[0]) 55 | if len(dbColums) == 0 { 56 | return "", ErrColumnNotSelected 57 | } 58 | 59 | dbColumnStrs := generateColumnStrings(dbColums) 60 | 61 | originColumnStrs, _ := getSqlSegments(q) 62 | 63 | if len(originColumnStrs) > 0 { 64 | extraStrs := getTableNewColumns(originColumnStrs, dbColumnStrs) 65 | retSql := "" 66 | var err error 67 | 68 | for _, v := range extraStrs { 69 | tempSql := "ALTER TABLE " + "`" + q.tableInterface().TableName() + "`" + " ADD " + v 70 | 71 | retSql += tempSql + "\n" 72 | _, err = db.Exec(tempSql) 73 | if err != nil { 74 | break 75 | } 76 | } 77 | return retSql, err 78 | } else { 79 | createTableSql := fmt.Sprintf("create table IF NOT EXISTS `%s` (%s)", 80 | originQuery.tables[0].table.TableName(), 81 | strings.Join(dbColumnStrs, ",")) 82 | 83 | _, err := db.Exec(createTableSql) 84 | return createTableSql, err 85 | 86 | } 87 | } 88 | 89 | func getTableNewColumns(origin, current []string) []string { 90 | var exist = make(map[string]bool) 91 | for _, v := range origin { 92 | v := strings.Trim(v, " ") 93 | if strings.HasPrefix(v, "`") { 94 | tempStrs := strings.SplitN(v, " ", 2) 95 | exist[strings.ToLower(tempStrs[0])] = true 96 | } else { 97 | tempStrs := strings.SplitN(v, " (", 2) 98 | exist[strings.ToLower(tempStrs[0])] = true 99 | } 100 | } 101 | 102 | var ret []string 103 | var preCol string 104 | for _, v := range current { 105 | if strings.HasPrefix(v, "`") { 106 | tempStrs := strings.SplitN(v, " ", 2) 107 | if exist[strings.ToLower(tempStrs[0])] == false { 108 | if preCol != "" { 109 | ret = append(ret, v+" after "+preCol) 110 | } else { 111 | ret = append(ret, v) 112 | } 113 | } 114 | preCol = tempStrs[0] 115 | } else { 116 | tempStrs := strings.SplitN(v, 
" (", 2) 117 | if exist[strings.ToLower(tempStrs[0])] == false { 118 | ret = append(ret, v) 119 | } 120 | } 121 | } 122 | 123 | return ret 124 | } 125 | 126 | func generateColumnStrings(dbColums []dBColumn) []string { 127 | var ret []string 128 | var primaryStr string 129 | var uniqueColumns []string 130 | var indexColumns []string 131 | var uniqueComps = make(map[string][]string) 132 | var indexComps = make(map[string][]string) 133 | 134 | for _, v := range dbColums { 135 | var words []string 136 | //add column name 137 | words = append(words, "`"+v.Name+"`") 138 | //add type 139 | words = append(words, v.Type) 140 | 141 | //add null 142 | if v.Null { 143 | words = append(words, "null") 144 | } else { 145 | words = append(words, "not null") 146 | } 147 | //add default 148 | if v.AutoIncrement { 149 | words = append(words, "auto_increment") 150 | } else if v.Default != "" { 151 | words = append(words, "default "+v.Default) 152 | } 153 | 154 | //add comment 155 | if v.Comment != "" { 156 | words = append(words, "comment "+"'"+v.Comment+"'") 157 | } 158 | 159 | if v.Primary { 160 | primaryStr = fmt.Sprintf("primary key (%s)", "`"+v.Name+"`") 161 | } else if v.Unique { 162 | uniqueColumns = append(uniqueColumns, fmt.Sprintf("unique key `%s` (`%s`)", v.Name, v.Name)) 163 | } else if v.Index { 164 | indexColumns = append(indexColumns, fmt.Sprintf("key `%s` (`%s`)", v.Name, v.Name)) 165 | } 166 | 167 | if len(v.Uniques) > 0 { 168 | for _, v2 := range v.Uniques { 169 | li := strings.LastIndex(v2, ":") 170 | if li > 0 { 171 | numStr := v2[li+1:] 172 | num, numErr := strconv.Atoi(numStr) 173 | if numErr == nil { 174 | if uniqueComps[v2[:li]] == nil { 175 | uniqueComps[v2[:li]] = make([]string, 16) 176 | } 177 | uniqueComps[v2[:li]][num] = v.Name 178 | continue 179 | } 180 | } 181 | uniqueComps[v2] = append(uniqueComps[v2], v.Name) 182 | 183 | } 184 | } 185 | 186 | if len(v.Indexs) > 0 { 187 | for _, v2 := range v.Indexs { 188 | li := strings.LastIndex(v2, ":") 189 | if li > 
0 { 190 | numStr := v2[li+1:] 191 | num, numErr := strconv.Atoi(numStr) 192 | if numErr == nil { 193 | if indexComps[v2[:li]] == nil { 194 | indexComps[v2[:li]] = make([]string, 16) 195 | } 196 | indexComps[v2[:li]][num] = v.Name 197 | continue 198 | } 199 | } 200 | indexComps[v2] = append(indexComps[v2], v.Name) 201 | } 202 | } 203 | ret = append(ret, strings.Join(words, " ")) 204 | } 205 | if primaryStr != "" { 206 | ret = append(ret, primaryStr) 207 | } 208 | for _, v := range uniqueColumns { 209 | ret = append(ret, v) 210 | } 211 | 212 | for _, v := range indexColumns { 213 | ret = append(ret, v) 214 | } 215 | for k, v := range uniqueComps { 216 | var newv []string 217 | for _, v2 := range v { 218 | if v2 != "" { 219 | newv = append(newv, v2) 220 | } 221 | } 222 | ret = append(ret, fmt.Sprintf("unique key `%s` (%s)", k, "`"+strings.Join(newv, "`,`")+"`")) 223 | } 224 | for k, v := range indexComps { 225 | var newv []string 226 | for _, v2 := range v { 227 | if v2 != "" { 228 | newv = append(newv, v2) 229 | } 230 | } 231 | ret = append(ret, fmt.Sprintf("key `%s` (%s)", k, "`"+strings.Join(newv, "`,`")+"`")) 232 | } 233 | return ret 234 | } 235 | 236 | func getMigrateColumns(table *queryTable) []dBColumn { 237 | var ret []dBColumn 238 | for i := 0; i < table.tableStruct.NumField(); i++ { 239 | varField := table.tableStruct.Field(i) 240 | 241 | if varField.CanSet() == false { 242 | continue 243 | } 244 | 245 | column := dBColumn{} 246 | 247 | ormTags := stringSplitEscapeParentheses(table.getTag(i, "orm"), ",") 248 | if ormTags[0] != "" { 249 | column.Name = ormTags[0] 250 | } else { 251 | column.Name = table.getTags(i, "json")[0] 252 | } 253 | 254 | if column.Name == "" || column.Name == "-" { 255 | continue 256 | } 257 | 258 | kind := varField.Kind() 259 | if varField.Kind() == reflect.Ptr { 260 | kind = varField.Elem().Kind() 261 | if varField.Elem().Kind() == reflect.Ptr { 262 | continue 263 | } 264 | column.Null = true 265 | } 266 | 267 | column.Type, 
column.Default = getTypeAndDefault(varField) 268 | 269 | if i == 0 { 270 | column.Primary = true 271 | if column.Default == "0" { 272 | column.AutoIncrement = true 273 | } 274 | } 275 | 276 | if column.Name == createdAtColumn { 277 | column.Null = false 278 | column.Type = "datetime" 279 | column.Default = "CURRENT_TIMESTAMP" 280 | } else if column.Name == updatedAtColumn { 281 | column.Null = false 282 | column.Type = "datetime" 283 | column.Default = "CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP" 284 | } else if column.Name == deletedAtColumn { 285 | column.Null = true 286 | column.Type = "datetime" 287 | column.Default = "Null" 288 | } 289 | 290 | column.Comment = table.getTag(i, "comment") 291 | customDefault := table.getTag(i, "default") 292 | if customDefault != "" { 293 | column.Default = customDefault 294 | if kind == reflect.Bool { 295 | if strings.ToLower(customDefault) == "true" { 296 | column.Default = "1" 297 | } else if strings.ToLower(customDefault) == "false" { 298 | column.Default = "0" 299 | } 300 | } 301 | } 302 | if strings.ToLower(column.Default) == "null" { 303 | column.Null = true 304 | } 305 | 306 | if ormTags[0] != "" { 307 | overideColumn := dBColumn{} 308 | 309 | overrideNull := false 310 | for k, v := range ormTags { 311 | if k == 0 { 312 | continue 313 | } 314 | if v == nullPrefix { 315 | overrideNull = true 316 | overideColumn.Null = true 317 | } else if v == notNullPrefix { 318 | overrideNull = true 319 | overideColumn.Null = false 320 | } else if v == autoIncrementPrefix { 321 | overideColumn.AutoIncrement = true 322 | } else if strings.HasPrefix(v, primaryKeyPrefix) { 323 | overideColumn.Primary = true 324 | } else if strings.HasPrefix(v, uniqueKeyPrefix) { 325 | if v == uniqueKeyPrefix { 326 | column.Unique = true 327 | } else { 328 | column.Uniques = append(column.Uniques, strings.TrimPrefix(v, uniqueKeyPrefix+":")) 329 | } 330 | } else if strings.HasPrefix(v, keyPrefix) { 331 | if v == keyPrefix { 332 | column.Index = true 333 | 
} else { 334 | column.Indexs = append(column.Indexs, strings.TrimPrefix(v, keyPrefix+":")) 335 | } 336 | } else { 337 | overideColumn.Type = v 338 | } 339 | } 340 | 341 | if overrideNull { 342 | column.Null = overideColumn.Null 343 | } 344 | column.AutoIncrement = overideColumn.AutoIncrement 345 | column.Primary = overideColumn.Primary 346 | if overideColumn.Type != "" { 347 | column.Type = overideColumn.Type 348 | } 349 | } 350 | 351 | if column.Null { 352 | if customDefault == "" { 353 | column.Default = "null" 354 | } 355 | } 356 | 357 | if column.Default == "" || sliceContain(definedDefault, strings.ToLower(column.Default)) == false { 358 | column.Default = "'" + column.Default + "'" 359 | } 360 | 361 | if column.Null == false && strings.ToLower(customDefault) == "null" { 362 | column.Default = "" 363 | } 364 | 365 | ret = append(ret, column) 366 | } 367 | 368 | return ret 369 | } 370 | 371 | func getTypeAndDefault(val reflect.Value) (string, string) { 372 | var types, defaults string 373 | typ := val.Type() 374 | kind := val.Kind() 375 | if kind == reflect.Ptr { 376 | kind = val.Type().Elem().Kind() 377 | typ = typ.Elem() 378 | } 379 | switch kind { 380 | case reflect.Bool, reflect.Int8: 381 | types = "tinyint" 382 | defaults = "0" 383 | case reflect.Int16: 384 | types = "smallint" 385 | defaults = "0" 386 | case reflect.Int, reflect.Int32: 387 | types = "int" 388 | defaults = "0" 389 | case reflect.Int64: 390 | types = "bigint" 391 | defaults = "0" 392 | case reflect.Uint8: 393 | types = "tinyint unsigned" 394 | defaults = "0" 395 | case reflect.Uint16: 396 | types = "smallint unsigned" 397 | defaults = "0" 398 | case reflect.Uint, reflect.Uint32: 399 | types = "int unsigned" 400 | defaults = "0" 401 | case reflect.Uint64: 402 | types = "bigint unsigned" 403 | defaults = "0" 404 | case reflect.Float32: 405 | types = "float" 406 | defaults = "0" 407 | case reflect.Float64: 408 | types = "double" 409 | defaults = "0" 410 | case reflect.String: 411 | types = 
"varchar(255)" 412 | default: 413 | if _, ok := val.Interface().(*time.Time); ok { 414 | types = "datetime" 415 | } else if _, ok := val.Interface().(time.Time); ok { 416 | types = "datetime" 417 | } else if kind == reflect.Struct { 418 | realVal := reflect.New(typ) 419 | if _, ok := realVal.Elem().Field(0).Interface().(*time.Time); ok { 420 | types = "datetime" 421 | } else if _, ok := realVal.Elem().Field(0).Interface().(time.Time); ok { 422 | types = "datetime" 423 | } else { 424 | types = "varchar(255)" 425 | } 426 | } else { 427 | types = "varchar(255)" 428 | } 429 | if types == "datetime" { 430 | defaults = "current_timestamp" 431 | } 432 | } 433 | return types, defaults 434 | } 435 | 436 | func stringSplitEscapeParentheses(s string, separator string) []string { 437 | var splits []string 438 | var start = "(" 439 | var end = ")" 440 | 441 | var openP int 442 | var before string 443 | for i, v := range s { 444 | temp := string(v) 445 | if temp == separator && (openP == 0 || strings.Contains(s[i:], end) == false) { 446 | if before != "" { 447 | splits = append(splits, before) 448 | } 449 | before = "" 450 | } else { 451 | if temp == start { 452 | openP += 1 453 | } else if temp == end { 454 | if openP > 0 { 455 | openP -= 1 456 | } 457 | } 458 | before += temp 459 | } 460 | } 461 | if before != "" { 462 | splits = append(splits, before) 463 | } 464 | if len(splits) == 0 { 465 | splits = append(splits, "") 466 | } 467 | return splits 468 | } 469 | -------------------------------------------------------------------------------- /orm/db.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import ( 4 | "database/sql" 5 | sqldriver "database/sql/driver" 6 | _ "github.com/go-sql-driver/mysql" 7 | ) 8 | 9 | func OpenMysql(dataSourceName string) (*sql.DB, error) { 10 | return sql.Open("mysql", dataSourceName) 11 | } 12 | 13 | func Open(driverName, dataSourceName string) (*sql.DB, error) { 14 | return 
sql.Open(driverName, dataSourceName) 15 | } 16 | 17 | func OpenDB(driver sqldriver.Connector) *sql.DB { 18 | return sql.OpenDB(driver) 19 | } 20 | 21 | func Register(name string, drvier sqldriver.Driver) { 22 | sql.Register(name, drvier) 23 | } 24 | -------------------------------------------------------------------------------- /orm/err.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import "errors" 4 | 5 | var ( 6 | ErrDbNotSelected = errors.New("db not selecteed") 7 | ErrTableNotExisted = errors.New("table not existed") 8 | ErrTableNotSelected = errors.New("table not selected") 9 | ErrColumnNotSelected = errors.New("column not selected") 10 | ErrColumnNotExisted = errors.New("column not existed") 11 | ErrRawSqlRequired = errors.New("raw sql required") 12 | ErrParamMustBePtr = errors.New("param must be ptr") 13 | ErrParamElemKindMustBeStruct = errors.New("param elem kind must be struct") 14 | ErrColumnShouldBeStringOrPtr = errors.New("select|where column should be string or ptr of Table.T.field") 15 | ErrDestOfGetToMustBePtr = errors.New("dest of Get-to must be ptr") 16 | ErrDestOfGetToSliceElemMustNotBePtr = errors.New("dest of Get-to slice elem kind must not be ptr") 17 | ErrDestOfGetToMapElemMustNotBePtr = errors.New("dest of Get-to map elem kind must not be ptr") 18 | ErrInsertPtrNotAllowed = errors.New("insert ptr data not allowed") 19 | ErrUpdateWithoutCondition = errors.New("update without condition not allowed") 20 | ErrDeleteWithoutCondition = errors.New("delete without condition not allowed") 21 | ) 22 | -------------------------------------------------------------------------------- /orm/json_field.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import ( 4 | "database/sql/driver" 5 | "encoding/json" 6 | "unsafe" 7 | ) 8 | 9 | //json ojbect/slice <=> go struct/slice <=> db json string 10 | type JsonField[T any] struct { 11 | Data T 
12 | } 13 | 14 | func NewJsonField[T any](data T) JsonField[T] { 15 | return JsonField[T]{Data: data} 16 | } 17 | 18 | func (t JsonField[T]) MarshalJSON() ([]byte, error) { 19 | return json.Marshal(t.Data) 20 | } 21 | 22 | func (t *JsonField[T]) UnmarshalJSON(data []byte) error { 23 | return json.Unmarshal(data, &t.Data) 24 | } 25 | 26 | func (t JsonField[T]) Value() (driver.Value, error) { 27 | data, err := json.Marshal(t.Data) 28 | return *(*string)(unsafe.Pointer(&data)), err 29 | } 30 | 31 | func (t *JsonField[T]) Scan(raw any) error { 32 | rawData, ok := raw.([]byte) 33 | if !ok { 34 | return nil 35 | } 36 | 37 | if len(rawData) == 0 { 38 | return nil 39 | } 40 | 41 | var i T 42 | err := json.Unmarshal(rawData, &i) 43 | if err != nil { 44 | return err 45 | } 46 | t.Data = i 47 | return nil 48 | } 49 | -------------------------------------------------------------------------------- /orm/json_int.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import ( 4 | "database/sql/driver" 5 | "fmt" 6 | "strconv" 7 | "strings" 8 | ) 9 | 10 | //json int or long string int <=> go int64 11 | type JsonInt int64 12 | 13 | func (t JsonInt) MarshalJSON() ([]byte, error) { 14 | return []byte("\"" + strconv.FormatInt(int64(t), 10) + "\""), nil 15 | } 16 | 17 | func (t *JsonInt) UnmarshalJSON(data []byte) error { 18 | val, err := strconv.ParseInt(strings.Trim(string(data), "\""), 10, 64) 19 | if err != nil { 20 | return err 21 | } 22 | *t = JsonInt(val) 23 | return nil 24 | } 25 | 26 | func (t JsonInt) Value() (driver.Value, error) { 27 | return int64(t), nil 28 | } 29 | 30 | func (t JsonInt) ToString() string { 31 | return strconv.FormatInt(int64(t), 10) 32 | } 33 | 34 | func (t *JsonInt) Scan(v any) error { 35 | val, ok := v.(int64) 36 | if ok { 37 | *t = JsonInt(val) 38 | return nil 39 | } else { 40 | val, ok := v.([]uint8) 41 | if ok { 42 | v, _ := strconv.ParseInt(string(val), 10, 64) 43 | *t = JsonInt(v) 44 | return 
nil 45 | } 46 | } 47 | return fmt.Errorf("can not convert %v to json int", v) 48 | } 49 | -------------------------------------------------------------------------------- /orm/json_time.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import ( 4 | "database/sql/driver" 5 | "fmt" 6 | "strconv" 7 | "time" 8 | ) 9 | 10 | //json int(13) <=> go time.time 11 | type JsonTime struct { 12 | time.Time 13 | } 14 | 15 | // MarshalJSON on JsonTime format Time field with %Y-%m-%d %H:%M:%S 16 | func (t JsonTime) MarshalJSON() ([]byte, error) { 17 | //formatted := fmt.Sprintf("\"%s\"", t.Format("2006-01-02 15:04:05")) 18 | //return []byte(formatted), nil 19 | ts := t.UnixNano() / 1e6 20 | if ts < 0 { 21 | ts = 0 22 | } 23 | return []byte(strconv.FormatInt(ts, 10)), nil 24 | } 25 | 26 | // MarshalJSON on JsonTime format Time field with %Y-%m-%d %H:%M:%S 27 | func (t *JsonTime) UnmarshalJSON(data []byte) error { 28 | //formatted := fmt.Sprintf("\"%s\"", t.Format("2006-01-02 15:04:05")) 29 | //return []byte(formatted), nil 30 | val, err := strconv.ParseInt(string(data), 10, 64) 31 | if err != nil { 32 | return err 33 | } 34 | if val == 0 { 35 | t.Time = time.Time{} 36 | return nil 37 | } 38 | t.Time = time.Unix(0, val*1e6) 39 | return nil 40 | } 41 | 42 | // Value insert timestamp into mysql need this function. 
43 | func (t JsonTime) Value() (driver.Value, error) { 44 | var zeroTime time.Time 45 | if t.Time.UnixNano() == zeroTime.UnixNano() { 46 | return nil, nil 47 | } 48 | return t.Time, nil 49 | } 50 | 51 | // Scan valueof time.Time 52 | func (t *JsonTime) Scan(v any) error { 53 | value, ok := v.(time.Time) 54 | if ok && t != nil { 55 | *t = JsonTime{Time: value} 56 | return nil 57 | } 58 | return fmt.Errorf("can not convert %v to timestamp", v) 59 | } 60 | -------------------------------------------------------------------------------- /orm/log.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | type InfoLogger interface { 4 | Info(args ...any) 5 | } 6 | 7 | type ErrorLogger interface { 8 | Error(args ...any) 9 | } 10 | 11 | var infoLogger InfoLogger 12 | var errorLogger ErrorLogger 13 | 14 | func SetInfoLogger(l InfoLogger) { 15 | infoLogger = l 16 | infoLogger.Info("set info logger") 17 | } 18 | 19 | func SetErrorLogger(l ErrorLogger) { 20 | errorLogger = l 21 | errorLogger.Error("set error logger") 22 | } 23 | -------------------------------------------------------------------------------- /orm/query.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "github.com/mcuadros/go-defaults" 7 | "math/rand" 8 | "reflect" 9 | "strconv" 10 | "strings" 11 | ) 12 | 13 | type Raw string 14 | 15 | type Query[T Table] struct { 16 | writeAndReadDbs []*sql.DB //first element as write db, rest as read dbs 17 | tx *sql.Tx 18 | ctx *context.Context 19 | tables []*queryTable 20 | wheres []where 21 | result QueryResult 22 | limit int 23 | offset int 24 | partitionbys []string 25 | orderbys []string 26 | forUpdate SelectForUpdateType 27 | T T 28 | columns []any 29 | insertIgnore bool 30 | conflictUpdates []updateColumn 31 | prepareSql string 32 | bindings []any 33 | groupBy []any 34 | having []where 35 | unions []*SubQuery 36 | 
withCtes []*SubQuery 37 | windows []*SubQuery 38 | self *Query[*SubQuery] 39 | selectTimeout string 40 | notlog bool 41 | } 42 | 43 | //query table[struct] generics 44 | func NewQuery[T Table](t T, writeAndReadDbs ...*sql.DB) *Query[T] { 45 | q := Query[T]{T: t, writeAndReadDbs: writeAndReadDbs} 46 | return q.fromTable(q.T) 47 | } 48 | 49 | func (q *Query[T]) NewDefault() T { 50 | ret, _ := reflect.New(q.tables[0].tableStructType).Interface().(T) 51 | defaults.SetDefaults(ret) 52 | return ret 53 | } 54 | 55 | func (q *Query[T]) Clone() *Query[T] { 56 | var clone = *q 57 | return &clone 58 | } 59 | 60 | func (q *Query[T]) UseDB(db ...*sql.DB) *Query[T] { 61 | q.writeAndReadDbs = db 62 | return q 63 | } 64 | 65 | func (q *Query[T]) UseFirstDB() *Query[T] { 66 | q.writeAndReadDbs = q.tables[0].table.Connections()[:1] 67 | return q 68 | } 69 | 70 | func (q *Query[T]) UseTx(tx *sql.Tx) *Query[T] { 71 | q.tx = tx 72 | return q 73 | } 74 | 75 | func (q *Query[T]) DB() *sql.DB { 76 | return q.writeDB() 77 | } 78 | 79 | func (q *Query[T]) DBs() []*sql.DB { 80 | if len(q.writeAndReadDbs) == 0 && len(q.tables) > 0 { 81 | q.writeAndReadDbs = q.tables[0].table.Connections() 82 | } 83 | return q.writeAndReadDbs 84 | } 85 | 86 | func (q *Query[T]) Tx() *sql.Tx { 87 | return q.tx 88 | } 89 | 90 | func (q *Query[T]) Alias(alias string) *Query[T] { 91 | q.tables[0].alias = alias 92 | return q 93 | } 94 | func (q *Query[T]) TableName(override string) *Query[T] { 95 | q.tables[0].overrideTableName = override 96 | return q 97 | } 98 | 99 | //query raw, tablename can be empty 100 | func newQueryRaw(tableName string, writeAndReadDbs ...*sql.DB) *Query[*SubQuery] { 101 | sq := &SubQuery{} 102 | if tableName != "" { 103 | sq.tableName = tableName 104 | } 105 | return NewQuery(sq, writeAndReadDbs...) 
106 | } 107 | 108 | func (q *Query[T]) writeDB() *sql.DB { 109 | dbs := q.DBs() 110 | if len(dbs) > 0 { 111 | return dbs[0] 112 | } 113 | return nil 114 | } 115 | func (q *Query[T]) readDB() *sql.DB { 116 | dbs := q.DBs() 117 | if len(dbs) > 1 { 118 | return dbs[rand.Intn(len(dbs)-1)+1] //rand get db 119 | } else { 120 | return q.writeDB() 121 | } 122 | } 123 | 124 | func (q *Query[T]) tableInterface() Table { 125 | return any(q.T).(Table) 126 | } 127 | 128 | func (q *Query[T]) allCols() string { 129 | return q.tables[0].getAliasOrTableName() + ".*" 130 | } 131 | 132 | func (q *Query[T]) fromTable(table Table) *Query[T] { 133 | newTable, err := q.parseTable(table) 134 | if err != nil { 135 | return q.setErr(err) 136 | } 137 | 138 | if newTable.rawSql != "" { 139 | newTable.alias = subqueryDefaultName 140 | } 141 | q.tables = append(q.tables, newTable) 142 | return q 143 | } 144 | 145 | func (q *Query[T]) parseTable(table Table) (*queryTable, error) { 146 | var newTable *queryTable 147 | 148 | if temp, ok := table.(SubQuery); ok { 149 | newTable = &queryTable{ 150 | table: table, 151 | rawSql: temp.raw, 152 | bindings: temp.bindings, 153 | } 154 | return newTable, nil 155 | } else if temp, ok := table.(*SubQuery); ok { 156 | newTable = &queryTable{ 157 | table: table, 158 | rawSql: temp.raw, 159 | bindings: temp.bindings, 160 | } 161 | return newTable, nil 162 | } else { 163 | cached := getTableFromCache(table) 164 | if cached != nil { 165 | tmp := *cached 166 | return &tmp, nil 167 | } 168 | tableStructAddr := reflect.ValueOf(table) 169 | if tableStructAddr.Kind() != reflect.Ptr { 170 | return nil, ErrParamMustBePtr 171 | } 172 | //reset query vars 173 | tableStruct := tableStructAddr.Elem() 174 | if tableStruct.Kind() != reflect.Struct { 175 | return nil, ErrParamElemKindMustBeStruct 176 | } 177 | 178 | tableStructType := reflect.TypeOf(table).Elem() 179 | ormFields := make(map[any]string) 180 | 181 | for i := 0; i < tableStruct.NumField(); i++ { 182 | valueField 
:= tableStruct.Field(i) 183 | 184 | ormTag := strings.Split(tableStructType.Field(i).Tag.Get("orm"), ",")[0] 185 | if ormTag == "-" { 186 | continue 187 | } 188 | if ormTag != "" { 189 | ormFields[valueField.Addr().Interface()] = ormTag 190 | continue 191 | } 192 | 193 | name := strings.Split(tableStructType.Field(i).Tag.Get("json"), ",")[0] 194 | if name == "-" { 195 | continue 196 | } 197 | if name != "" { 198 | ormFields[valueField.Addr().Interface()] = name 199 | } 200 | } 201 | newTable = &queryTable{ 202 | table: table, 203 | tableStruct: tableStruct, 204 | tableStructType: reflect.TypeOf(table).Elem(), 205 | ormFields: ormFields, 206 | } 207 | cacheTable(table, newTable) 208 | 209 | tmp := *newTable 210 | return &tmp, nil 211 | } 212 | } 213 | 214 | func (q *Query[T]) isRaw(v any) (string, bool) { 215 | val, ok := v.(Raw) 216 | return string(val), ok 217 | } 218 | 219 | func (q *Query[T]) isOperator(v any) (string, bool) { 220 | val, ok := v.(WhereOperator) 221 | return string(val), ok 222 | } 223 | 224 | func (q *Query[T]) isStringOrRaw(v any) (string, bool) { 225 | val := reflect.ValueOf(v) 226 | 227 | if val.Kind() == reflect.String { 228 | return val.String(), true 229 | } else { 230 | return "", false 231 | } 232 | } 233 | 234 | func (q *Query[T]) parseColumn(v any) (string, error) { 235 | columnVar := reflect.ValueOf(v) 236 | if columnVar.Kind() == reflect.String { 237 | ret := columnVar.String() 238 | if ret == "*" && len(q.tables) > 0 { 239 | prefix := q.tables[0].getAliasOrTableName() 240 | if prefix != "" { 241 | prefix += "." 
242 | } 243 | return prefix + ret, nil 244 | } else if ret == "" { 245 | return "", ErrColumnShouldBeStringOrPtr 246 | } else { 247 | return ret, nil 248 | } 249 | } else if columnVar.Kind() == reflect.Ptr && columnVar.Elem().CanAddr() { 250 | table, column := q.getTableColumn(columnVar) 251 | if table == nil { 252 | return "", ErrColumnNotExisted 253 | } 254 | prefix := table.getAliasOrTableName() 255 | if prefix != "" { 256 | prefix += "." 257 | } 258 | if column == "" { 259 | return prefix + "*", nil 260 | } 261 | 262 | return prefix + "`" + column + "`", nil 263 | } else { 264 | return "", ErrColumnShouldBeStringOrPtr 265 | } 266 | } 267 | 268 | func (q *Query[T]) getTableColumn(i reflect.Value) (*queryTable, string) { 269 | for _, t := range q.tables { 270 | if i.Interface() == t.table || (i.Elem().CanInterface() && i.Elem().Interface() == t.table) { 271 | return t, "" 272 | } 273 | if s, exist := t.ormFields[i.Elem().Addr().Interface()]; exist { 274 | return t, s 275 | } 276 | } 277 | return nil, "" 278 | } 279 | 280 | func (q *Query[T]) setErr(err error) *Query[T] { 281 | if err != nil { 282 | q.result.Err = err 283 | } 284 | return q 285 | } 286 | 287 | func (q *Query[T]) Limit(limit int) *Query[T] { 288 | q.limit = limit 289 | return q 290 | } 291 | 292 | func (q *Query[T]) Offset(offset int) *Query[T] { 293 | q.offset = offset 294 | return q 295 | } 296 | 297 | //should not use group by after order by 298 | func (q *Query[T]) GroupBy(columns ...any) *Query[T] { 299 | q.groupBy = append(q.groupBy, columns...) 300 | return q 301 | } 302 | 303 | func (q *Query[T]) Having(column any, vals ...any) *Query[T] { 304 | oldWheres := q.wheres 305 | 306 | newQuery := q.where(false, column, vals...) 307 | 308 | newWheres := newQuery.wheres[len(oldWheres):] 309 | if len(newWheres) > 0 { 310 | newQuery.having = append(newQuery.having, newWheres...) 
311 | newQuery.wheres = oldWheres 312 | } 313 | return newQuery 314 | } 315 | 316 | func (q *Query[T]) OrHaving(column any, vals ...any) *Query[T] { 317 | oldWheres := q.wheres 318 | 319 | newQuery := q.where(true, column, vals...) 320 | 321 | newWheres := newQuery.wheres[len(oldWheres):] 322 | if len(newWheres) > 0 { 323 | newQuery.having = append(newQuery.having, newWheres...) 324 | newQuery.wheres = oldWheres 325 | } 326 | return newQuery 327 | } 328 | func (q *Query[T]) Partition(p string) *Query[T] { 329 | q.tables[0].partition = p 330 | return q 331 | } 332 | 333 | func (q *Query[T]) PartitionBy(column any) *Query[T] { 334 | val, err := q.parseColumn(column) 335 | if err != nil { 336 | return q.setErr(err) 337 | } 338 | q.partitionbys = append(q.partitionbys, val) 339 | return q 340 | } 341 | func (q *Query[T]) OrderBy(column any) *Query[T] { 342 | val, err := q.parseColumn(column) 343 | if err != nil { 344 | return q.setErr(err) 345 | } 346 | q.orderbys = append(q.orderbys, val) 347 | return q 348 | } 349 | func (q *Query[T]) OrderByDesc(column any) *Query[T] { 350 | val, err := q.parseColumn(column) 351 | if err != nil { 352 | return q.setErr(err) 353 | } 354 | q.orderbys = append(q.orderbys, val+" desc") 355 | return q 356 | } 357 | 358 | func (q *Query[T]) getOrderAndLimitSqlStr() string { 359 | var ret []string 360 | if len(q.orderbys) > 0 { 361 | orderStr := "order by " + strings.Join(q.orderbys, ",") 362 | ret = append(ret, orderStr) 363 | } 364 | if q.limit > 0 { 365 | limitStr := "limit " + strconv.Itoa(q.limit) 366 | ret = append(ret, limitStr) 367 | } 368 | if q.offset > 0 { 369 | offsetStr := "offset " + strconv.Itoa(q.offset) 370 | ret = append(ret, offsetStr) 371 | } 372 | 373 | return strings.Join(ret, " ") 374 | } 375 | 376 | func (q *Query[T]) NotLog() *Query[T] { 377 | q.notlog = true 378 | return q 379 | } 380 | -------------------------------------------------------------------------------- /orm/query_delete.go: 
-------------------------------------------------------------------------------- 1 | package orm 2 | 3 | func (q *Query[T]) Delete(primaryIds ...any) QueryResult { 4 | if len(q.tables) == 0 { 5 | q.setErr(ErrTableNotSelected) 6 | return q.result 7 | } 8 | 9 | if len(primaryIds) == 1 { 10 | return q.WherePrimary(primaryIds[0]).delete() 11 | } else if len(primaryIds) > 0 { 12 | return q.WherePrimary(primaryIds).delete() 13 | } else { 14 | return q.delete() 15 | } 16 | } 17 | 18 | func (q *Query[T]) delete() QueryResult { 19 | bindings := make([]any, 0) 20 | 21 | if len(q.wheres) == 0 && len(q.tables) <= 1 && q.limit == 0 { 22 | q.setErr(ErrDeleteWithoutCondition) 23 | } 24 | 25 | tableStr := q.generateTableAndJoinStr(q.tables, &bindings) 26 | 27 | whereStr := q.generateWhereStr(q.wheres, &bindings) 28 | 29 | orderLimitOffsetStr := q.getOrderAndLimitSqlStr() 30 | 31 | rawSql := "delete" 32 | if orderLimitOffsetStr == "" { 33 | rawSql += " " + q.tables[0].getTableName() 34 | } 35 | rawSql += " from " + tableStr 36 | 37 | if whereStr != "" { 38 | rawSql += " where " + whereStr 39 | } 40 | if orderLimitOffsetStr != "" { 41 | rawSql += " " + orderLimitOffsetStr 42 | } 43 | 44 | q.prepareSql = rawSql 45 | q.bindings = bindings 46 | 47 | return q.Execute() 48 | } 49 | -------------------------------------------------------------------------------- /orm/query_execute.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import ( 4 | "database/sql" 5 | ) 6 | 7 | //excute raw 8 | func (q *Query[T]) Execute() QueryResult { 9 | if q.prepareSql == "" { 10 | q.setErr(ErrRawSqlRequired) 11 | } 12 | 13 | q.result.PrepareSql = q.prepareSql 14 | q.result.Bindings = q.bindings 15 | 16 | if q.result.Err != nil { 17 | if errorLogger != nil { 18 | errorLogger.Error(q.result.Sql(), q.result.Error()) 19 | } 20 | return q.result 21 | } else if infoLogger != nil { 22 | infoLogger.Info(q.result.Sql(), q.result.Error()) 23 | } 24 | 25 | var 
res sql.Result 26 | var err error 27 | if q.Tx() != nil { 28 | if q.ctx != nil { 29 | res, err = q.Tx().ExecContext(*q.ctx, q.prepareSql, q.bindings...) 30 | } else { 31 | res, err = q.Tx().Exec(q.prepareSql, q.bindings...) 32 | } 33 | } else { 34 | if q.ctx != nil { 35 | res, err = q.DB().ExecContext(*q.ctx, q.prepareSql, q.bindings...) 36 | } else { 37 | res, err = q.DB().Exec(q.prepareSql, q.bindings...) 38 | } 39 | } 40 | 41 | if err != nil { 42 | q.result.Err = err 43 | if errorLogger != nil && q.notlog == false { 44 | errorLogger.Error(q.result.Sql(), q.result.Error()) 45 | } 46 | } else if res != nil { 47 | q.result.LastInsertId, q.result.Err = res.LastInsertId() 48 | q.result.RowsAffected, q.result.Err = res.RowsAffected() 49 | } 50 | return q.result 51 | } 52 | -------------------------------------------------------------------------------- /orm/query_get.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import ( 4 | "database/sql" 5 | "reflect" 6 | "regexp" 7 | "strconv" 8 | "strings" 9 | ) 10 | 11 | var scanErrPrefix = "sql: Scan error on column index" 12 | var matchScanErrIndex = regexp.MustCompile("\\d+") 13 | 14 | //get first T 15 | func (q *Query[T]) Get(primaryIds ...any) (T, QueryResult) { 16 | ret := reflect.New(q.tables[0].tableStructType).Interface() 17 | var res QueryResult 18 | 19 | if len(primaryIds) == 1 { 20 | res = q.WherePrimary(primaryIds[0]).Limit(1).GetTo(ret) 21 | } else if len(primaryIds) > 0 { 22 | res = q.WherePrimary(primaryIds).Limit(1).GetTo(ret) 23 | } else { 24 | res = q.Limit(1).GetTo(ret) 25 | } 26 | 27 | return ret.(T), res 28 | } 29 | 30 | //get slice T 31 | func (q *Query[T]) Gets(primaryIds ...any) ([]T, QueryResult) { 32 | var ret []T 33 | var res QueryResult 34 | if len(primaryIds) == 1 { 35 | res = q.WherePrimary(primaryIds[0]).GetTo(&ret) 36 | } else if len(primaryIds) > 0 { 37 | res = q.WherePrimary(primaryIds).GetTo(&ret) 38 | } else { 39 | res = 
q.GetTo(&ret) 40 | } 41 | return ret, res 42 | } 43 | 44 | //get first bool 45 | func (q *Query[T]) GetBool() (bool, QueryResult) { 46 | var ret bool 47 | res := q.Limit(1).GetTo(&ret) 48 | return ret, res 49 | } 50 | 51 | //get first int 52 | func (q *Query[T]) GetInt() (int64, QueryResult) { 53 | var ret int64 54 | res := q.Limit(1).GetTo(&ret) 55 | return ret, res 56 | } 57 | func (q *Query[T]) GetUint() (uint64, QueryResult) { 58 | var ret uint64 59 | res := q.Limit(1).GetTo(&ret) 60 | return ret, res 61 | } 62 | 63 | //get first string 64 | func (q *Query[T]) GetString() (string, QueryResult) { 65 | var ret string 66 | res := q.Limit(1).GetTo(&ret) 67 | return ret, res 68 | } 69 | 70 | // ↓↓ more Get examples ↓↓ 71 | func (q *Query[T]) GetSliceInt() ([]int64, QueryResult) { 72 | var ret []int64 73 | res := q.GetTo(&ret) 74 | return ret, res 75 | } 76 | func (q *Query[T]) GetSliceUint() ([]uint64, QueryResult) { 77 | var ret []uint64 78 | res := q.GetTo(&ret) 79 | return ret, res 80 | } 81 | func (q *Query[T]) GetSliceString() ([]string, QueryResult) { 82 | var ret []string 83 | res := q.GetTo(&ret) 84 | return ret, res 85 | } 86 | func (q *Query[T]) GetMapString() (map[string]string, QueryResult) { 87 | var ret map[string]string 88 | res := q.GetTo(&ret) 89 | return ret, res 90 | } 91 | func (q *Query[T]) GetMapSliceString() (map[string][]string, QueryResult) { 92 | var ret map[string][]string 93 | res := q.GetTo(&ret) 94 | return ret, res 95 | } 96 | func (q *Query[T]) GetMapStringInt() (map[string]int64, QueryResult) { 97 | var ret map[string]int64 98 | res := q.GetTo(&ret) 99 | return ret, res 100 | } 101 | func (q *Query[T]) GetMapIntString() (map[int64]string, QueryResult) { 102 | var ret map[int64]string 103 | res := q.GetTo(&ret) 104 | return ret, res 105 | } 106 | func (q *Query[T]) GetMapStringUint() (map[string]uint64, QueryResult) { 107 | var ret map[string]uint64 108 | res := q.GetTo(&ret) 109 | return ret, res 110 | } 111 | func (q *Query[T]) 
GetMapUintString() (map[uint64]string, QueryResult) { 112 | var ret map[uint64]string 113 | res := q.GetTo(&ret) 114 | return ret, res 115 | } 116 | func (q *Query[T]) GetMapInt() (map[int64]int64, QueryResult) { 117 | var ret map[int64]int64 118 | res := q.GetTo(&ret) 119 | return ret, res 120 | } 121 | func (q *Query[T]) GetMapUint() (map[uint64]uint64, QueryResult) { 122 | var ret map[uint64]uint64 123 | res := q.GetTo(&ret) 124 | return ret, res 125 | } 126 | 127 | //get count T 128 | func (q *Query[T]) GetCount() (int64, QueryResult) { 129 | if len(q.groupBy) == 0 { 130 | if len(q.columns) == 0 { 131 | return q.Select("count(*)").GetInt() 132 | } else { 133 | c, err := q.parseColumn(q.columns[0]) 134 | q.columns = nil 135 | if err == nil { 136 | cl := strings.ToLower(c) 137 | if strings.HasPrefix(cl, "count(") == false || strings.Contains(cl, ")") == false { 138 | c = "count(" + c + ")" 139 | } 140 | } 141 | return q.setErr(err).Select(c).GetInt() 142 | } 143 | } else { 144 | tempTable := q.SubQuery() 145 | 146 | newQuery := NewQuery(tempTable, tempTable.dbs...) 147 | 148 | return newQuery.setErr(tempTable.err).Select("count(*)").GetInt() 149 | } 150 | } 151 | 152 | /* 153 | destPtr, pointer of any value like: 154 | value, []value, map[key]value, map[key][]value 155 | */ 156 | func (q *Query[T]) GetTo(destPtr any) QueryResult { 157 | tempTable := q.SubQuery() 158 | 159 | q.result.PrepareSql = tempTable.raw 160 | q.result.Bindings = tempTable.bindings 161 | if tempTable.err != nil { 162 | q.result.Err = tempTable.err 163 | } 164 | 165 | if q.result.Err != nil { 166 | if errorLogger != nil { 167 | errorLogger.Error(q.result.Sql(), q.result.Error()) 168 | } 169 | return q.result 170 | } else if infoLogger != nil { 171 | infoLogger.Info(q.result.Sql(), q.result.Error()) 172 | } 173 | 174 | var rows *sql.Rows 175 | var err error 176 | if q.Tx() != nil { 177 | if q.ctx != nil { 178 | rows, err = q.Tx().QueryContext(*q.ctx, tempTable.raw, tempTable.bindings...) 
179 | } else { 180 | rows, err = q.Tx().Query(tempTable.raw, tempTable.bindings...) 181 | } 182 | } else { 183 | if q.ctx != nil { 184 | rows, err = q.readDB().QueryContext(*q.ctx, tempTable.raw, tempTable.bindings...) 185 | } else { 186 | rows, err = q.readDB().Query(tempTable.raw, tempTable.bindings...) 187 | } 188 | } 189 | 190 | defer func() { 191 | if rows != nil { 192 | _ = rows.Close() 193 | } 194 | }() 195 | 196 | if err != nil { 197 | q.result.Err = err 198 | if errorLogger != nil { 199 | errorLogger.Error(q.result.Sql(), q.result.Error()) 200 | } 201 | return q.result 202 | } 203 | 204 | q.result.Err = q.scanRows(destPtr, rows) 205 | return q.result 206 | } 207 | 208 | func (q *Query[T]) scanRows(dest any, rows *sql.Rows) error { 209 | rowColumns, gerr := rows.Columns() 210 | if gerr != nil { 211 | return gerr 212 | } 213 | destValue := reflect.ValueOf(dest) 214 | if destValue.Kind() != reflect.Ptr { 215 | return ErrDestOfGetToMustBePtr 216 | } 217 | destValueValue := destValue.Elem() 218 | if destValueValue.Kind() == reflect.Ptr { 219 | return ErrDestOfGetToMustBePtr 220 | } 221 | 222 | if reflectValueIsOrmField(destValue) == true { 223 | var basePtrs = make([]any, len(rowColumns)) 224 | for k := 0; k < len(rowColumns); k++ { 225 | if k == 0 { 226 | basePtrs[k] = dest 227 | } else { 228 | var temp any 229 | basePtrs[k] = &temp 230 | } 231 | } 232 | gerr = q.scanValues(basePtrs, rowColumns, rows, nil, true) 233 | } else { 234 | switch destValueValue.Kind() { 235 | case reflect.Map: 236 | reflectMap := reflect.TypeOf(dest).Elem() 237 | 238 | mapKeyType := reflectMap.Key() 239 | mapValueType := reflectMap.Elem() 240 | if mapValueType.Kind() == reflect.Ptr && mapValueType.Elem() != q.tables[0].tableStructType { 241 | return ErrDestOfGetToMapElemMustNotBePtr 242 | } 243 | newVal := reflect.MakeMap(reflectMap) 244 | switch mapValueType.Kind() { 245 | case reflect.Ptr: 246 | structAddr := reflect.New(q.tables[0].tableStructType).Interface() 247 | 248 | 
structAddrMap, err := getStructFieldAddrMap(structAddr) 249 | if err != nil { 250 | return err 251 | } 252 | var basePtrs = make([]any, len(rowColumns)) 253 | 254 | keyType := reflectMap.Key() 255 | keyAddr := reflect.New(keyType).Interface() 256 | 257 | structVal := reflect.ValueOf(structAddr).Elem() 258 | 259 | for k, v := range rowColumns { 260 | basePtrs[k] = structAddrMap[v] 261 | if basePtrs[k] == nil { 262 | if k == 0 { 263 | basePtrs[k] = keyAddr 264 | } else { 265 | var temp any 266 | basePtrs[k] = &temp 267 | } 268 | } 269 | } 270 | 271 | gerr = q.scanValues(basePtrs, rowColumns, rows, func() { 272 | tmp := reflect.New(q.tables[0].tableStructType) 273 | tmp.Elem().Set(structVal) 274 | newVal.SetMapIndex(reflect.ValueOf(basePtrs[0]).Elem(), tmp) 275 | }, false) 276 | destValue.Elem().Set(newVal) 277 | case reflect.Slice: //group by results 278 | switch mapValueType.Elem().Kind() { 279 | case reflect.Ptr: 280 | if mapValueType.Elem().Elem() != q.tables[0].tableStructType { 281 | return ErrDestOfGetToSliceElemMustNotBePtr 282 | } 283 | keyAddr := reflect.New(reflectMap.Key()).Interface() 284 | 285 | structAddr := reflect.New(q.tables[0].tableStructType).Interface() 286 | 287 | structAddrMap, err := getStructFieldAddrMap(structAddr) 288 | if err != nil { 289 | return err 290 | } 291 | var basePtrs = make([]any, len(rowColumns)) 292 | 293 | structVal := reflect.ValueOf(structAddr).Elem() 294 | 295 | for k, v := range rowColumns { 296 | basePtrs[k] = structAddrMap[v] 297 | if basePtrs[k] == nil { 298 | var temp any 299 | basePtrs[k] = &temp 300 | } 301 | } 302 | basePtrs[0] = keyAddr 303 | 304 | gerr = q.scanValues(basePtrs, rowColumns, rows, func() { 305 | tmp := reflect.New(q.tables[0].tableStructType) 306 | tmp.Elem().Set(structVal) 307 | 308 | index := reflect.ValueOf(basePtrs[0]).Elem() 309 | tempSlice := newVal.MapIndex(index) 310 | if tempSlice.IsValid() == false { 311 | tempSlice = reflect.MakeSlice(mapValueType, 0, 0) 312 | } 313 | 314 | 
newVal.SetMapIndex(index, reflect.Append(tempSlice, tmp)) 315 | }, false) 316 | 317 | destValue.Elem().Set(newVal) 318 | case reflect.Struct: 319 | newStructVal := reflect.New(mapValueType.Elem()) 320 | if reflectValueIsOrmField(newStructVal) == false { 321 | keyAddr := reflect.New(reflectMap.Key()).Interface() 322 | structAddr := newStructVal.Interface() 323 | structAddrMap, err := getStructFieldAddrMap(structAddr) 324 | if err != nil { 325 | return err 326 | } 327 | var basePtrs = make([]any, len(rowColumns)) 328 | structVal := newStructVal.Elem() 329 | 330 | for k, v := range rowColumns { 331 | basePtrs[k] = structAddrMap[v] 332 | if basePtrs[k] == nil { 333 | var temp any 334 | basePtrs[k] = &temp 335 | } 336 | } 337 | basePtrs[0] = keyAddr 338 | gerr = q.scanValues(basePtrs, rowColumns, rows, func() { 339 | index := reflect.ValueOf(basePtrs[0]).Elem() 340 | tempSlice := newVal.MapIndex(index) 341 | if tempSlice.IsValid() == false { 342 | tempSlice = reflect.MakeSlice(mapValueType, 0, 0) 343 | } 344 | newVal.SetMapIndex(index, reflect.Append(tempSlice, structVal)) 345 | }, false) 346 | destValue.Elem().Set(newVal) 347 | break 348 | } 349 | fallthrough 350 | default: 351 | newKeyVal := reflect.New(reflectMap.Key()) 352 | var basePtrs = make([]any, len(rowColumns)) 353 | for k := range basePtrs { 354 | if k == 0 { 355 | basePtrs[0] = newKeyVal.Interface() 356 | } else { 357 | basePtrs[k] = reflect.New(mapValueType.Elem()).Interface() 358 | } 359 | } 360 | gerr = q.scanValues(basePtrs, rowColumns, rows, func() { 361 | index := newKeyVal.Elem() 362 | tempSlice := newVal.MapIndex(index) 363 | if tempSlice.IsValid() == false { 364 | tempSlice = reflect.MakeSlice(mapValueType, 0, 0) 365 | } 366 | for k := range basePtrs { 367 | if k > 0 { 368 | tempSlice = reflect.Append(tempSlice, reflect.ValueOf(basePtrs[k]).Elem()) 369 | } 370 | } 371 | newVal.SetMapIndex(index, tempSlice) 372 | }, false) 373 | destValue.Elem().Set(newVal) 374 | } 375 | case reflect.Struct: 376 | 
newValue := reflect.New(mapValueType) 377 | if reflectValueIsOrmField(newValue) == false { 378 | structAddr := newValue.Interface() 379 | structAddrMap, err := getStructFieldAddrMap(structAddr) 380 | if err != nil { 381 | return err 382 | } 383 | var basePtrs = make([]any, len(rowColumns)) 384 | 385 | keyType := reflectMap.Key() 386 | keyAddr := reflect.New(keyType).Interface() 387 | 388 | structVal := newValue.Elem() 389 | 390 | for k, v := range rowColumns { 391 | basePtrs[k] = structAddrMap[v] 392 | if k == 0 { 393 | basePtrs[k] = keyAddr 394 | } else if basePtrs[k] == nil { 395 | var temp any 396 | basePtrs[k] = &temp 397 | } 398 | } 399 | gerr = q.scanValues(basePtrs, rowColumns, rows, func() { 400 | newVal.SetMapIndex(reflect.ValueOf(basePtrs[0]).Elem(), structVal) 401 | }, false) 402 | destValue.Elem().Set(newVal) 403 | break 404 | } 405 | fallthrough 406 | default: 407 | newKeyValue := reflect.New(mapKeyType) 408 | newValValue := reflect.New(mapValueType) 409 | 410 | var basePtrs = make([]any, len(rowColumns)) 411 | 412 | for k := 0; k < len(rowColumns); k++ { 413 | if k == 0 { 414 | basePtrs[k] = newKeyValue.Interface() 415 | } else if k == 1 { 416 | basePtrs[k] = newValValue.Interface() 417 | } else { 418 | var temp any 419 | basePtrs[k] = &temp 420 | } 421 | } 422 | gerr = q.scanValues(basePtrs, rowColumns, rows, func() { 423 | newVal.SetMapIndex(newKeyValue.Elem(), newValValue.Elem()) 424 | }, false) 425 | 426 | destValue.Elem().Set(newVal) 427 | } 428 | case reflect.Slice: 429 | ele := reflect.TypeOf(dest).Elem().Elem() 430 | if ele.Kind() == reflect.Ptr && ele.Elem() != q.tables[0].tableStructType { 431 | return ErrDestOfGetToSliceElemMustNotBePtr 432 | } 433 | 434 | switch ele.Kind() { 435 | case reflect.Ptr: 436 | structAddr := reflect.New(q.tables[0].tableStructType).Interface() 437 | 438 | structAddrMap, err := getStructFieldAddrMap(structAddr) 439 | if err != nil { 440 | return err 441 | } 442 | var basePtrs = make([]any, len(rowColumns)) 443 | 
444 | structVal := reflect.ValueOf(structAddr).Elem() 445 | 446 | for k, v := range rowColumns { 447 | basePtrs[k] = structAddrMap[v] 448 | if basePtrs[k] == nil { 449 | var temp any 450 | basePtrs[k] = &temp 451 | } 452 | } 453 | 454 | gerr = q.scanValues(basePtrs, rowColumns, rows, func() { 455 | tmp := reflect.New(q.tables[0].tableStructType) 456 | tmp.Elem().Set(structVal) 457 | destValueValue = reflect.Append(destValueValue, tmp) 458 | }, false) 459 | 460 | destValue.Elem().Set(destValueValue) 461 | case reflect.Struct: 462 | eleNew := reflect.New(ele) 463 | if reflectValueIsOrmField(eleNew) == false { 464 | structAddr := eleNew.Interface() 465 | structVal := eleNew.Elem() 466 | structAddrMap, err := getStructFieldAddrMap(structAddr) 467 | if err != nil { 468 | return err 469 | } 470 | var basePtrs = make([]any, len(rowColumns)) 471 | 472 | for k, v := range rowColumns { 473 | basePtrs[k] = structAddrMap[v] 474 | if basePtrs[k] == nil { 475 | var temp any 476 | basePtrs[k] = &temp 477 | } 478 | } 479 | 480 | gerr = q.scanValues(basePtrs, rowColumns, rows, func() { 481 | destValueValue = reflect.Append(destValueValue, structVal) 482 | }, false) 483 | 484 | destValue.Elem().Set(destValueValue) 485 | break 486 | } 487 | fallthrough 488 | default: 489 | var basePtrs = make([]any, len(rowColumns)) 490 | 491 | for k := 0; k < len(rowColumns); k++ { 492 | basePtrs[k] = reflect.New(ele).Interface() 493 | } 494 | 495 | gerr = q.scanValues(basePtrs, rowColumns, rows, func() { 496 | for _, v := range basePtrs { 497 | destValueValue = reflect.Append(destValueValue, reflect.ValueOf(v).Elem()) 498 | } 499 | }, false) 500 | 501 | destValue.Elem().Set(destValueValue) 502 | } 503 | case reflect.Struct: 504 | if reflectValueIsOrmField(destValue) == false { 505 | structAddr := dest 506 | structAddrMap, err := getStructFieldAddrMap(structAddr) 507 | if err != nil { 508 | return err 509 | } 510 | var basePtrs = make([]any, len(rowColumns)) 511 | 512 | for k, v := range rowColumns 
{ 513 | basePtrs[k] = structAddrMap[v] 514 | if basePtrs[k] == nil { 515 | var temp any 516 | basePtrs[k] = &temp 517 | } 518 | } 519 | gerr = q.scanValues(basePtrs, rowColumns, rows, nil, true) 520 | break 521 | } 522 | fallthrough 523 | default: 524 | var basePtrs = make([]any, len(rowColumns)) 525 | for k := 0; k < len(rowColumns); k++ { 526 | if k == 0 { 527 | basePtrs[k] = dest 528 | } else { 529 | var temp any 530 | basePtrs[k] = &temp 531 | } 532 | } 533 | gerr = q.scanValues(basePtrs, rowColumns, rows, nil, true) 534 | } 535 | } 536 | return gerr 537 | } 538 | 539 | func (q *Query[T]) scanValues(basePtrs []any, rowColumns []string, rows *sql.Rows, setVal func(), tryOnce bool) error { 540 | var err error 541 | var tempPtrs = make([]any, len(rowColumns)) 542 | for k := range rowColumns { 543 | var temp any 544 | tempPtrs[k] = &temp 545 | } 546 | 547 | finalPtrs := make([]any, len(rowColumns)) 548 | 549 | for rows.Next() { 550 | err = rows.Scan(tempPtrs...) 551 | if err != nil { 552 | return err 553 | } 554 | 555 | for k, v := range tempPtrs { 556 | if *v.(*any) == nil { 557 | felement := reflect.ValueOf(basePtrs[k]).Elem() 558 | felement.Set(reflect.Zero(felement.Type())) 559 | finalPtrs[k] = v 560 | } else { 561 | finalPtrs[k] = basePtrs[k] 562 | } 563 | } 564 | 565 | err = rows.Scan(finalPtrs...) 
566 | //set zero value after err column index 567 | if err != nil && strings.HasPrefix(err.Error(), scanErrPrefix) { 568 | indexs := matchScanErrIndex.FindAllString(err.Error(), 1) 569 | if len(indexs) > 0 { 570 | i, _ := strconv.Atoi(indexs[0]) 571 | for k := range basePtrs { 572 | if k >= i { 573 | felement := reflect.ValueOf(basePtrs[k]).Elem() 574 | felement.Set(reflect.Zero(felement.Type())) 575 | } 576 | } 577 | } 578 | } 579 | 580 | q.result.RowsAffected += 1 581 | 582 | if setVal != nil { 583 | setVal() 584 | } 585 | if tryOnce { 586 | break 587 | } 588 | } 589 | if err == nil { 590 | err = rows.Err() 591 | } 592 | return err 593 | } 594 | -------------------------------------------------------------------------------- /orm/query_insert.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import ( 4 | "errors" 5 | "reflect" 6 | "sort" 7 | "strings" 8 | ) 9 | 10 | //insert and set primary for first T 11 | func (q *Query[T]) Insert(data ...T) QueryResult { 12 | return q.insert(data) 13 | } 14 | 15 | func (q *Query[T]) InsertSubquery(data *SubQuery) QueryResult { 16 | return q.insert(data) 17 | } 18 | 19 | func (q *Query[T]) OnConflictUpdate(column any, val any, columnVars ...any) *Query[T] { 20 | q.insertIgnore = true 21 | q.conflictUpdates = []updateColumn{{col: column, val: val}} 22 | if len(columnVars) > 0 { 23 | for k := range columnVars { 24 | if k%2 != 0 { 25 | q.conflictUpdates = append(q.conflictUpdates, updateColumn{col: columnVars[k-1], val: columnVars[k]}) 26 | } 27 | } 28 | } 29 | return q 30 | } 31 | 32 | func (q *Query[T]) OnConflictUpdates(columnVars map[any]any) *Query[T] { 33 | q.insertIgnore = true 34 | 35 | if len(columnVars) > 0 { 36 | q.conflictUpdates = make([]updateColumn, len(columnVars)) 37 | var i = 0 38 | for k := range columnVars { 39 | q.conflictUpdates[i] = updateColumn{col: k, val: columnVars[k]} 40 | i++ 41 | } 42 | } 43 | 44 | return q 45 | } 46 | 47 | func (q *Query[T]) 
gennerateInsertSql(InsertColumns []string, rowCount int) string { 48 | columnRawStr := "" 49 | valRawStr := "" 50 | 51 | if len(InsertColumns) > 0 { 52 | for _, v := range InsertColumns { 53 | columnRawStr += "`" + v + "`," 54 | valRawStr += "?," 55 | } 56 | columnRawStr = "(" + strings.TrimRight(columnRawStr, ",") + ")" 57 | valRawStr = "(" + strings.TrimRight(valRawStr, ",") + ")" 58 | } 59 | 60 | if rowCount > 0 { 61 | rows := make([]string, rowCount) 62 | for k := range rows { 63 | rows[k] = valRawStr 64 | } 65 | return columnRawStr + " values " + strings.Join(rows, ",") 66 | } else { 67 | return columnRawStr 68 | } 69 | } 70 | 71 | func (q *Query[T]) getInsertBindings(val reflect.Value, validFieldIndex map[int]struct{}, defaults map[int]any) []any { 72 | var bindings []any 73 | for i := 0; i < val.Len(); i++ { 74 | for k := 0; k < val.Index(i).Elem().NumField(); k++ { 75 | if _, ok := validFieldIndex[k]; ok { 76 | if defaults[k] != nil && val.Index(i).Elem().Field(k).IsZero() { 77 | bindings = append(bindings, defaults[k]) 78 | } else { 79 | bindings = append(bindings, val.Index(i).Elem().Field(k).Interface()) 80 | } 81 | } 82 | } 83 | } 84 | return bindings 85 | } 86 | 87 | func (q *Query[T]) insert(data any) QueryResult { 88 | var err error 89 | var acceptFields = q.columns 90 | 91 | var updates = q.conflictUpdates 92 | var ignore = q.insertIgnore 93 | 94 | val := reflect.ValueOf(data) 95 | isSubQuery := false 96 | var subq *SubQuery 97 | rowCount := 1 98 | var structFields []string 99 | var structDefaults map[int]any 100 | if val.Kind() == reflect.Slice { 101 | rowCount = val.Len() 102 | if rowCount == 0 { 103 | q.setErr(errors.New("slice is empty")) 104 | return q.result 105 | } 106 | if val.Index(0).Type().Elem() != q.tables[0].tableStructType { 107 | q.setErr(errors.New("slice elem must be T")) 108 | } else { 109 | structFields, err = getStructFieldNameSlice(val.Index(0).Elem().Interface()) 110 | q.setErr(err) 111 | structDefaults, err = 
getStructFieldWithDefaultTime(val.Index(0).Elem().Interface()) 112 | q.setErr(err) 113 | } 114 | } else if val.Kind() == reflect.Ptr { 115 | sub, ok := data.(*SubQuery) 116 | if ok { 117 | isSubQuery = true 118 | subq = sub 119 | } else { 120 | q.setErr(ErrInsertPtrNotAllowed) 121 | } 122 | } else { 123 | q.setErr(errors.New("data must be subquery or slice of T")) 124 | } 125 | 126 | if q.result.Err != nil { 127 | if errorLogger != nil { 128 | errorLogger.Error(q.result.Sql(), q.result.Error()) 129 | } 130 | return q.result 131 | } else if infoLogger != nil { 132 | infoLogger.Info(q.result.Sql(), q.result.Error()) 133 | } 134 | 135 | var validFieldNameMap = make(map[string]int) 136 | var validFieldNames = make([]string, 0) 137 | var validFieldIndex = make(map[int]struct{}) 138 | var InsertColumns []string //actually insert columns 139 | var allowFields = make(map[any]int) 140 | 141 | for k, v := range acceptFields { 142 | allowFields[v] = k 143 | } 144 | for k, v := range q.tables[0].ormFields { 145 | pos, ok := allowFields[k] 146 | if ok || (len(allowFields) == 0 && isSubQuery == false) { 147 | validFieldNameMap[v] = pos 148 | validFieldNames = append(validFieldNames, v) 149 | } 150 | } 151 | 152 | var insertSql, updateStr string 153 | var bindings []any 154 | if isSubQuery { 155 | sort.SliceStable(validFieldNames, func(i, j int) bool { 156 | return validFieldNameMap[validFieldNames[i]] < validFieldNameMap[validFieldNames[j]] 157 | }) 158 | insertSql = q.gennerateInsertSql(validFieldNames, 0) 159 | if insertSql != "" { 160 | insertSql += " " 161 | } 162 | insertSql += subq.raw 163 | bindings = subq.bindings 164 | } else { 165 | for k, v := range structFields { 166 | if _, ok := validFieldNameMap[v]; ok { 167 | validFieldIndex[k] = struct{}{} 168 | InsertColumns = append(InsertColumns, v) 169 | } 170 | } 171 | insertSql = q.gennerateInsertSql(InsertColumns, rowCount) 172 | bindings = q.getInsertBindings(val, validFieldIndex, structDefaults) 173 | } 174 | 175 | 
updateStr = q.generateUpdateStr(updates, &bindings) 176 | 177 | rawSql := "insert" 178 | if ignore { 179 | rawSql += " ignore" 180 | } 181 | 182 | rawSql += " into " + q.tables[0].getTableNamePartition() + " " + insertSql 183 | 184 | if updateStr != "" { 185 | rawSql += " on duplicate key update " + updateStr 186 | } 187 | 188 | rawSql += ";" 189 | 190 | q.prepareSql = rawSql 191 | q.bindings = bindings 192 | 193 | res := q.Execute() 194 | 195 | //set first element's first field on condition 196 | if isSubQuery == false && res.Err == nil && res.LastInsertId > 0 && (val.Len() == 1 || q.insertIgnore == false) { 197 | val.Index(0).Elem().Field(0).Set(reflect.ValueOf(res.LastInsertId).Convert(val.Index(0).Elem().Field(0).Type())) 198 | } 199 | return res 200 | } 201 | -------------------------------------------------------------------------------- /orm/query_join.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import ( 4 | "strings" 5 | ) 6 | 7 | type JoinType string 8 | 9 | const ( 10 | JoinTypeInner JoinType = "inner join" 11 | JoinTypeLeft JoinType = "left join" 12 | JoinTypeRight JoinType = "right join" 13 | JoinTypeOuter JoinType = "outer join" 14 | ) 15 | 16 | func (q *Query[T]) Join(table Table, where func(join *Query[T]) *Query[T], alias ...string) *Query[T] { 17 | return q.join(JoinTypeInner, table, where, alias...) 18 | } 19 | func (q *Query[T]) LeftJoin(table Table, where func(join *Query[T]) *Query[T], alias ...string) *Query[T] { 20 | return q.join(JoinTypeLeft, table, where, alias...) 21 | } 22 | func (q *Query[T]) RightJoin(table Table, where func(join *Query[T]) *Query[T], alias ...string) *Query[T] { 23 | return q.join(JoinTypeRight, table, where, alias...) 24 | } 25 | func (q *Query[T]) OuterJoin(table Table, where func(join *Query[T]) *Query[T], alias ...string) *Query[T] { 26 | return q.join(JoinTypeOuter, table, where, alias...) 
27 | } 28 | 29 | func (q *Query[T]) join(joinType JoinType, table Table, wheref func(where *Query[T]) *Query[T], alias ...string) *Query[T] { 30 | newTable, err := q.parseTable(table) 31 | if err != nil { 32 | return q.setErr(err) 33 | } 34 | 35 | //join self 36 | for k := range q.tables { 37 | if q.tables[k] == newTable { 38 | var tmp = *newTable 39 | newTable = &(tmp) 40 | break 41 | } 42 | } 43 | 44 | if len(alias) > 0 { 45 | newTable.alias = alias[0] 46 | } else if newTable.rawSql != "" { 47 | newTable.alias = subqueryDefaultName 48 | } 49 | 50 | newTable.joinType = joinType 51 | q.tables = append(q.tables, newTable) 52 | q.tables[len(q.tables)-1].joinCondition, err = q.generateWhereGroup(wheref) 53 | return q.setErr(err) 54 | } 55 | 56 | func (q *Query[T]) generateTableAndJoinStr(tables []*queryTable, bindings *[]any) string { 57 | if len(tables) == 0 { 58 | return "" 59 | } 60 | var tableStrs []string 61 | for k, v := range tables { 62 | tempStr := "" 63 | if v.rawSql == "" { 64 | if k == 0 { 65 | tempStr = v.getTableNameAndAlias() 66 | } else { 67 | tempStr = string(v.joinType) 68 | tempStr += " " + v.getTableNameAndAlias() 69 | if len(v.joinCondition.SubWheres) > 0 { 70 | whereStr := q.generateWhereStr(v.joinCondition.SubWheres, bindings) 71 | tempStr += " on " + whereStr 72 | } 73 | } 74 | } else { 75 | if k == 0 { 76 | tempStr = "(" + v.rawSql + ")" 77 | if v.getAlias() != "" { 78 | tempStr += " " + v.getAlias() 79 | } 80 | *bindings = append(*bindings, v.bindings...) 81 | } else { 82 | tempStr = string(v.joinType) 83 | tempStr += " (" + v.rawSql + ")" 84 | if v.getAlias() != "" { 85 | tempStr += " " + v.getAlias() 86 | } 87 | *bindings = append(*bindings, v.bindings...) 
88 | if len(v.joinCondition.SubWheres) > 0 { 89 | whereStr := q.generateWhereStr(v.joinCondition.SubWheres, bindings) 90 | tempStr += " on " + whereStr 91 | } 92 | } 93 | } 94 | 95 | tableStrs = append(tableStrs, tempStr) 96 | } 97 | 98 | return strings.Join(tableStrs, " ") 99 | } 100 | -------------------------------------------------------------------------------- /orm/query_raw.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | func (q *Query[T]) Raw(prepareSql string, bindings ...any) *Query[T] { 4 | q.prepareSql = prepareSql 5 | q.bindings = bindings 6 | return q 7 | } 8 | -------------------------------------------------------------------------------- /orm/query_result.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import ( 4 | "strings" 5 | ) 6 | 7 | type QueryResult struct { 8 | PrepareSql string 9 | Bindings []any 10 | LastInsertId int64 11 | RowsAffected int64 12 | Err error 13 | } 14 | 15 | func (q QueryResult) Error() error { 16 | return q.Err 17 | } 18 | 19 | func (q QueryResult) Sql() string { 20 | params := make([]string, len(q.Bindings)) 21 | for k, v := range q.Bindings { 22 | params[k] = varToString(v) 23 | } 24 | 25 | var sql strings.Builder 26 | var index = 0 27 | 28 | for _, v := range []byte(q.PrepareSql) { 29 | if v == '?' 
{ 30 | if len(params) > index { 31 | sql.WriteString(params[index]) 32 | index++ 33 | } else { 34 | break 35 | } 36 | } else { 37 | sql.WriteByte(v) 38 | } 39 | } 40 | 41 | return sql.String() 42 | } 43 | -------------------------------------------------------------------------------- /orm/query_select.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import ( 4 | "strconv" 5 | "strings" 6 | "time" 7 | ) 8 | 9 | type SelectForUpdateType string 10 | 11 | const ( 12 | SelectForUpdateTypeDefault SelectForUpdateType = "for update" 13 | SelectForUpdateTypeNowait SelectForUpdateType = "for update nowait" 14 | SelectForUpdateTypeSkipLocked SelectForUpdateType = "for update skip locked" 15 | ) 16 | 17 | func (q *Query[T]) SelectRank(column any, alias string) *Query[T] { 18 | return q.SelectOver("rank()", func(query *Query[T]) *Query[T] { 19 | return query.OrderBy(column) 20 | }, alias) 21 | } 22 | 23 | func (q *Query[T]) SelectRankDesc(column any, alias string) *Query[T] { 24 | return q.SelectOver("rank()", func(query *Query[T]) *Query[T] { 25 | return query.OrderByDesc(column) 26 | }, alias) 27 | } 28 | 29 | func (q *Query[T]) SelectRowNumber(column any, alias string) *Query[T] { 30 | return q.SelectOver("row_number()", func(query *Query[T]) *Query[T] { 31 | return query.OrderBy(column) 32 | }, alias) 33 | } 34 | 35 | func (q *Query[T]) SelectRowNumberDesc(column any, alias string) *Query[T] { 36 | return q.SelectOver("row_number()", func(query *Query[T]) *Query[T] { 37 | return query.OrderByDesc(column) 38 | }, alias) 39 | } 40 | 41 | func (q *Query[T]) SelectOver(windowFunc string, f func(query *Query[T]) *Query[T], alias string) *Query[T] { 42 | partitionStart := len(q.partitionbys) 43 | orderStart := len(q.orderbys) 44 | nq := *q 45 | f(&nq) 46 | partitions := nq.partitionbys[partitionStart:] 47 | orders := nq.orderbys[orderStart:] 48 | 49 | q.setErr(nq.result.Err) 50 | 51 | newSelect := windowFunc + " over (" 
52 | if len(partitions) > 0 { 53 | newSelect += "partition by " + strings.Join(partitions, ",") + " " 54 | } 55 | if len(orders) > 0 { 56 | newSelect += "order by " + strings.Join(orders, ",") 57 | } 58 | newSelect += ")" 59 | 60 | newSelect += " as " + alias 61 | 62 | q.columns = append(q.columns, newSelect) 63 | return q 64 | } 65 | 66 | func (q *Query[T]) SelectOverRaw(windowFunc string, windowName string, alias string) *Query[T] { 67 | newSelect := windowFunc + " over " + windowName + " as " + alias 68 | q.columns = append(q.columns, newSelect) 69 | return q 70 | } 71 | 72 | func (q *Query[T]) Select(columns ...any) *Query[T] { 73 | q.columns = append(q.columns, columns...) 74 | return q 75 | } 76 | 77 | func (q *Query[T]) SelectAs(column any, as string) *Query[T] { 78 | colStr, err := q.generateSelectColumns(column) 79 | q.setErr(err) 80 | q.columns = append(q.columns, colStr+" as "+as) 81 | return q 82 | } 83 | 84 | func (q *Query[T]) SelectExclude(exceptColumns ...any) *Query[T] { 85 | q.columns = nil 86 | 87 | for i := 0; i < q.tables[0].tableStruct.NumField(); i++ { 88 | field := q.tables[0].tableStruct.Field(i) 89 | if field.CanAddr() == false || field.Addr().CanInterface() == false { 90 | continue 91 | } 92 | addr := field.Addr().Interface() 93 | 94 | if v, ok := q.tables[0].ormFields[addr]; ok { 95 | except := false 96 | for _, v2 := range exceptColumns { 97 | if v2 == v || v2 == addr { 98 | except = true 99 | break 100 | } 101 | } 102 | if except == false { 103 | q.columns = append(q.columns, addr) 104 | } 105 | } 106 | } 107 | 108 | return q 109 | } 110 | 111 | func (q *Query[T]) ForUpdate(forUpdateType ...SelectForUpdateType) *Query[T] { 112 | if len(forUpdateType) == 0 { 113 | q.forUpdate = SelectForUpdateTypeDefault 114 | } else { 115 | q.forUpdate = forUpdateType[0] 116 | } 117 | return q 118 | } 119 | 120 | func (q *Query[T]) SelectWithTimeout(duration time.Duration) *Query[T] { 121 | ms := duration.Milliseconds() 122 | if ms > 0 { 123 | 
q.selectTimeout = "/*+ MAX_EXECUTION_TIME(" + strconv.FormatInt(ms, 10) + ") */" 124 | } 125 | return q 126 | } 127 | -------------------------------------------------------------------------------- /orm/query_select_gen.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import ( 4 | "strings" 5 | ) 6 | 7 | func (q *Query[T]) SubQuery() *SubQuery { 8 | if q.self != nil { 9 | cte := q.self 10 | q.self = nil 11 | 12 | mt := cte.WithRecursiveCte(q.SubQuery(), cte.T.TableName()) 13 | tempTable := mt.generateSelectQuery(mt.columns...) 14 | 15 | tempTable.dbs = mt.DBs() 16 | tempTable.tx = mt.tx 17 | tempTable.dbName = mt.tables[0].table.DatabaseName() 18 | if mt.result.Err != nil { 19 | tempTable.err = mt.result.Err 20 | } 21 | 22 | return tempTable 23 | } else { 24 | mt := q 25 | tempTable := mt.generateSelectQuery(mt.columns...) 26 | 27 | tempTable.dbs = mt.DBs() 28 | tempTable.tx = mt.tx 29 | tempTable.dbName = mt.tables[0].table.DatabaseName() 30 | if mt.result.Err != nil { 31 | tempTable.err = mt.result.Err 32 | } 33 | 34 | return tempTable 35 | } 36 | } 37 | 38 | func (q *Query[T]) generateSelectQuery(columns ...any) *SubQuery { 39 | var ret SubQuery 40 | if q.prepareSql != "" { 41 | ret.raw = q.prepareSql 42 | ret.bindings = q.bindings 43 | } else { 44 | var rawSql string 45 | bindings := make([]any, 0) 46 | 47 | if len(q.withCtes) > 0 { 48 | var raws []string 49 | for _, v := range q.withCtes { 50 | var raw string 51 | if v.recursive { 52 | raw += "recursive " 53 | } 54 | raw += v.tableName 55 | if len(v.columns) > 0 { 56 | raw += "(" + strings.Join(v.columns, ",") + ")" 57 | } 58 | raw += " as (" 59 | raw += v.raw 60 | raw += ")" 61 | raws = append(raws, raw) 62 | bindings = append(bindings, v.bindings...) 63 | } 64 | rawSql += "with " + strings.Join(raws, ",\n") + "\n" 65 | } 66 | 67 | selectStr, err := q.generateSelectColumns(columns...) 
68 | if err != nil { 69 | ret.err = err 70 | } 71 | 72 | tableStr := q.generateTableAndJoinStr(q.tables, &bindings) 73 | 74 | whereStr := q.generateWhereStr(q.wheres, &bindings) 75 | 76 | var groupBy string 77 | if len(q.groupBy) > 0 { 78 | groupBy, err = q.generateSelectColumns(q.groupBy...) 79 | if err != nil { 80 | ret.err = err 81 | } 82 | } 83 | var having string 84 | if len(q.having) > 0 { 85 | having = q.generateWhereStr(q.having, &bindings) 86 | } 87 | 88 | orderLimitOffsetStr := q.getOrderAndLimitSqlStr() 89 | 90 | var selectKeyword = "select" 91 | if q.selectTimeout != "" { 92 | selectKeyword += " " + q.selectTimeout 93 | } 94 | 95 | rawSql += selectKeyword + " " + selectStr 96 | 97 | if tableStr != "" { 98 | rawSql += " from " + tableStr 99 | if whereStr != "" { 100 | rawSql += " where " + whereStr 101 | } 102 | } 103 | 104 | if groupBy != "" && groupBy != "*" { 105 | rawSql += " group by " + groupBy 106 | if having != "" { 107 | rawSql += " having " + having 108 | } 109 | } 110 | 111 | if orderLimitOffsetStr != "" { 112 | rawSql += " " + orderLimitOffsetStr 113 | } 114 | 115 | if q.forUpdate != "" { 116 | rawSql += " " + string(q.forUpdate) 117 | } 118 | 119 | if len(q.unions) > 0 { 120 | for _, v := range q.unions { 121 | prefix := "\nunion" 122 | if v.unionAll { 123 | prefix += " all" 124 | } 125 | prefix += " \n" + v.raw 126 | rawSql += prefix 127 | bindings = append(bindings, v.bindings...) 
128 | } 129 | } 130 | 131 | if len(q.windows) > 0 { 132 | var raws []string 133 | for _, v := range q.windows { 134 | var raw = v.tableName + " as (" + v.raw + ")" 135 | raws = append(raws, raw) 136 | } 137 | rawSql += "\nwindow " + strings.Join(raws, ",\n") 138 | } 139 | 140 | ret.raw = rawSql 141 | ret.bindings = bindings 142 | } 143 | return &ret 144 | } 145 | 146 | func (q *Query[T]) generateSelectColumns(columns ...any) (string, error) { 147 | var outColumns []string 148 | for _, v := range columns { 149 | column, err := q.parseColumn(v) 150 | 151 | if err != nil { 152 | return "", err 153 | } 154 | outColumns = append(outColumns, column) //column string name 155 | } 156 | 157 | if len(outColumns) == 0 { 158 | return "*", nil 159 | } else { 160 | return strings.Join(outColumns, ","), nil 161 | } 162 | } 163 | -------------------------------------------------------------------------------- /orm/query_table.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import ( 4 | "reflect" 5 | "strings" 6 | ) 7 | 8 | type queryTable struct { 9 | table Table 10 | tableStruct reflect.Value 11 | tableStructType reflect.Type 12 | ormFields map[any]string 13 | joinType JoinType //(left|right) join 14 | joinCondition where 15 | alias string 16 | overrideTableName string //override table name 17 | partition string 18 | rawSql string 19 | bindings []any 20 | } 21 | 22 | func (q queryTable) getAlias() string { 23 | return q.alias 24 | } 25 | 26 | func (q queryTable) getAliasOrTableName() string { 27 | if q.alias != "" { 28 | return q.alias 29 | } 30 | return q.getTableName() 31 | } 32 | 33 | func (q queryTable) getTableNameAndAlias() string { 34 | var strs []string 35 | temp := q.getTableNamePartition() 36 | if temp != "" { 37 | strs = append(strs, temp) 38 | } 39 | temp = q.getAlias() 40 | if temp != "" { 41 | strs = append(strs, temp) 42 | } 43 | return strings.Join(strs, " ") 44 | } 45 | 46 | func (q queryTable) 
getTableName() string { 47 | if q.overrideTableName != "" { 48 | return q.overrideTableName 49 | } 50 | if q.table.TableName() != "" { 51 | if q.table.DatabaseName() != "" { 52 | return q.table.DatabaseName() + "." + q.table.TableName() 53 | } else { 54 | return q.table.TableName() 55 | } 56 | } 57 | return "" 58 | } 59 | 60 | func (q queryTable) getTableNamePartition() string { 61 | if q.overrideTableName != "" { 62 | return q.overrideTableName 63 | } 64 | if q.table.TableName() != "" { 65 | if q.table.DatabaseName() != "" { 66 | return q.table.DatabaseName() + "." + q.table.TableName() + q.getPartition() 67 | } else { 68 | return q.table.TableName() + q.getPartition() 69 | } 70 | } 71 | return "" 72 | } 73 | 74 | func (q queryTable) getPartition() string { 75 | if q.partition == "" { 76 | return "" 77 | } 78 | return " partition(" + q.partition + ")" 79 | } 80 | 81 | func (q queryTable) getTags(index int, tagName string) []string { 82 | tags := strings.Split(q.tableStructType.Field(index).Tag.Get(tagName), ",") 83 | return tags 84 | } 85 | func (q queryTable) getTag(index int, tagName string) string { 86 | return q.tableStructType.Field(index).Tag.Get(tagName) 87 | } 88 | -------------------------------------------------------------------------------- /orm/query_table_cache.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import ( 4 | "sync" 5 | ) 6 | 7 | var tableCache sync.Map 8 | 9 | func getTableFromCache(key any) *queryTable { 10 | res, ok := tableCache.Load(key) 11 | if ok { 12 | ret, ok := res.(*queryTable) 13 | if ok { 14 | return ret 15 | } 16 | } 17 | return nil 18 | } 19 | 20 | func cacheTable(key any, val *queryTable) { 21 | tableCache.Store(key, val) 22 | } 23 | -------------------------------------------------------------------------------- /orm/query_transaction.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import ( 4 | 
"context" 5 | ) 6 | 7 | func (q *Query[T]) Transaction(f func(query *Query[T]) error) error { 8 | tx, err := q.DB().BeginTx(context.Background(), nil) 9 | if err != nil { 10 | return err 11 | } 12 | q.tx = tx 13 | 14 | err = f(q) 15 | 16 | if err != nil { 17 | _ = tx.Rollback() 18 | return err 19 | } 20 | return tx.Commit() 21 | } 22 | -------------------------------------------------------------------------------- /orm/query_union.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | func (q *Query[T]) Union(subquery *SubQuery) *Query[T] { 4 | return q.union(false, subquery) 5 | } 6 | func (q *Query[T]) UnionAll(subquery *SubQuery) *Query[T] { 7 | return q.union(true, subquery) 8 | } 9 | func (q *Query[T]) union(isAll bool, subquery *SubQuery) *Query[T] { 10 | subquery.unionAll = isAll 11 | q.setErr(subquery.err) 12 | q.unions = append(q.unions, subquery) 13 | return q 14 | } 15 | -------------------------------------------------------------------------------- /orm/query_update.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import ( 4 | "errors" 5 | "reflect" 6 | "strings" 7 | ) 8 | 9 | func (q *Query[T]) Update(column any, val any, columnVars ...any) QueryResult { 10 | 11 | var updates = []updateColumn{{col: column, val: val}} 12 | 13 | if len(columnVars) > 0 { 14 | for k := range columnVars { 15 | if k%2 != 0 { 16 | updates = append(updates, updateColumn{col: columnVars[k-1], val: columnVars[k]}) 17 | } 18 | } 19 | } 20 | 21 | return q.updates(updates...) 22 | } 23 | 24 | func (q *Query[T]) Updates(columnVars map[any]any) QueryResult { 25 | var updates = make([]updateColumn, len(columnVars)) 26 | var i = 0 27 | for k := range columnVars { 28 | updates[i] = updateColumn{col: k, val: columnVars[k]} 29 | i++ 30 | } 31 | return q.updates(updates...) 
32 | } 33 | 34 | func (q *Query[T]) genUpdates(vals ...any) ([]updateColumn, error) { 35 | if len(vals)%2 != 0 { 36 | return nil, errors.New("update column and val should be pairs") 37 | } 38 | 39 | var updates []updateColumn 40 | 41 | for k := range vals { 42 | if k%2 != 0 { 43 | continue 44 | } 45 | updates = append(updates, updateColumn{col: vals[k], val: vals[k+1]}) 46 | } 47 | 48 | return updates, nil 49 | } 50 | 51 | func (q *Query[T]) updates(updates ...updateColumn) QueryResult { 52 | bindings := make([]any, 0) 53 | 54 | if len(q.wheres) == 0 && len(q.tables) <= 1 && q.limit == 0 { 55 | q.setErr(ErrUpdateWithoutCondition) 56 | } 57 | 58 | tableStr := q.generateTableAndJoinStr(q.tables, &bindings) 59 | 60 | updateStr := q.generateUpdateStr(updates, &bindings) 61 | 62 | whereStr := q.generateWhereStr(q.wheres, &bindings) 63 | 64 | orderAndLimitStr := q.getOrderAndLimitSqlStr() 65 | 66 | rawSql := "update " + tableStr 67 | if updateStr != "" { 68 | rawSql += " set " + updateStr 69 | if whereStr != "" { 70 | rawSql += " where " + whereStr 71 | } 72 | } 73 | 74 | if orderAndLimitStr != "" { 75 | rawSql += " " + orderAndLimitStr 76 | } 77 | 78 | q.prepareSql = rawSql 79 | q.bindings = bindings 80 | 81 | return q.Execute() 82 | } 83 | 84 | func (q *Query[T]) generateUpdateStr(updates []updateColumn, bindings *[]any) string { 85 | var updateStrs []string 86 | for _, v := range updates { 87 | var temp string 88 | column, err := q.parseColumn(v.col) 89 | if err != nil { 90 | q.setErr(err) 91 | return "" 92 | } 93 | 94 | val, ok := q.isRaw(v.val) 95 | if ok { 96 | temp = column + " = " + val 97 | } else if reflect.ValueOf(v.val).Kind() == reflect.Ptr { 98 | if v.val == v.col { 99 | dotIndex := strings.LastIndex(column, ".") 100 | temp = column + " = values(`" + strings.Trim(column[dotIndex+1:], "`") + "`)" 101 | } else { 102 | targetColumn, err := q.parseColumn(v.val) 103 | if err == nil { 104 | temp = column + " = " + targetColumn 105 | } else { 106 | //q.setErr(err) 
107 | //return "" 108 | temp = column + " = ?" 109 | *bindings = append(*bindings, v.val) 110 | } 111 | } 112 | } else { 113 | temp = column + " = ?" 114 | *bindings = append(*bindings, v.val) 115 | } 116 | updateStrs = append(updateStrs, temp) 117 | } 118 | return strings.Join(updateStrs, ",") 119 | } 120 | -------------------------------------------------------------------------------- /orm/query_where.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import ( 4 | "errors" 5 | "reflect" 6 | "strings" 7 | ) 8 | 9 | type WhereOperator Raw 10 | 11 | const ( 12 | WhereEqual WhereOperator = "=" 13 | WhereNotEqual WhereOperator = "!=" 14 | WhereGreatThan WhereOperator = ">" 15 | WhereGreaterOrEqual WhereOperator = ">=" 16 | WhereLessThan WhereOperator = "<" 17 | WhereLessOrEqual WhereOperator = "<=" 18 | WhereIn WhereOperator = "in" 19 | WhereNotIn WhereOperator = "not in" 20 | WhereLike WhereOperator = "like" 21 | WhereNotLike WhereOperator = "not like" 22 | WhereRlike WhereOperator = "rlike" 23 | WhereNotRlike WhereOperator = "not rlike" 24 | WhereIsNull WhereOperator = "is null" 25 | WhereIsNotNull WhereOperator = "is not null" 26 | ) 27 | 28 | //"id=1" 29 | //&obj.id, 1 30 | //&obj.id, "=", 1 31 | func (q *Query[T]) Where(column any, vals ...any) *Query[T] { 32 | t := q.where(false, column, vals...) 33 | return t 34 | } 35 | 36 | //"id=1" 37 | //&obj.id, 1 38 | //&obj.id, "=", 1 39 | func (q *Query[T]) OrWhere(column any, vals ...any) *Query[T] { 40 | return q.where(true, column, vals...) 41 | } 42 | 43 | //short for Where(primaryKey, vals...) 
44 | func (q *Query[T]) WherePrimary(operator any, vals ...any) *Query[T] { 45 | //operator as vals 46 | if len(vals) == 0 { 47 | vals = []any{operator} 48 | reflectVar := reflect.ValueOf(operator) 49 | if reflectVar.Kind() == reflect.Slice { 50 | operator = WhereIn 51 | } else { 52 | operator = WhereEqual 53 | } 54 | } 55 | 56 | return q.where(false, q.tables[0].tableStruct.Field(0).Addr().Interface(), operator, vals[0]) 57 | } 58 | 59 | func (q *Query[T]) WherePrimaryIfNotZero(val any) *Query[T] { 60 | if reflect.ValueOf(val).IsZero() { 61 | return q 62 | } else { 63 | return q.WherePrimary(val) 64 | } 65 | } 66 | 67 | //short for OrWhere(primaryKey, vals...) 68 | func (q *Query[T]) OrWherePrimary(operator any, vals ...any) *Query[T] { 69 | //operator as vals 70 | if len(vals) == 0 { 71 | vals = []any{operator} 72 | reflectVar := reflect.ValueOf(operator) 73 | if reflectVar.Kind() == reflect.Slice { 74 | operator = WhereIn 75 | } else { 76 | operator = WhereEqual 77 | } 78 | } 79 | 80 | return q.where(true, q.tables[0].tableStruct.Field(0).Addr().Interface(), operator, vals[0]) 81 | } 82 | 83 | func (q *Query[T]) WhereBetween(column any, valLess, valGreat any) *Query[T] { 84 | return q.WhereFunc(func(query *Query[T]) *Query[T] { 85 | return query.where(false, column, WhereGreaterOrEqual, valLess).where(false, column, WhereLessOrEqual, valGreat) 86 | }) 87 | } 88 | 89 | func (q *Query[T]) OrWhereBetween(column any, valLess, valGreat any) *Query[T] { 90 | return q.OrWhereFunc(func(query *Query[T]) *Query[T] { 91 | return query.where(false, column, WhereGreaterOrEqual, valLess).where(false, column, WhereLessOrEqual, valGreat) 92 | }) 93 | } 94 | 95 | //"id=1" 96 | //&obj.id, 1 97 | //&obj.id, "=", 1 98 | func (q *Query[T]) WhereFunc(f func(where *Query[T]) *Query[T]) *Query[T] { 99 | return q.whereGroup(false, f) 100 | } 101 | 102 | //"id=1" 103 | //&obj.id, 1 104 | //&obj.id, "=", 1 105 | func (q *Query[T]) OrWhereFunc(f func(where *Query[T]) *Query[T]) *Query[T] { 
106 | return q.whereGroup(true, f) 107 | } 108 | /* whereGroup runs f on a copy of q via generateWhereGroup and passes any error it reports to q.setErr. If f added at least one condition, those conditions are appended to q.wheres as one grouped entry, which generateWhereStr renders inside parentheses, joined to the previous condition with or when isOr is true and with and otherwise. */109 | func (q *Query[T]) whereGroup(isOr bool, f func(where *Query[T]) *Query[T]) *Query[T] { 110 | temp, err := q.generateWhereGroup(f) 111 | q.setErr(err) 112 | if len(temp.SubWheres) > 0 { 113 | temp.IsOr = isOr 114 | q.wheres = append(q.wheres, temp) 115 | } 116 | return q 117 | } 118 | /* generateWhereGroup calls f on a struct copy of q (nq) and returns the conditions f appended after the existing ones, copied into a fresh slice and wrapped as the SubWheres of a single where value, together with nq.result.Err. Nothing else is carried back: any other state f changes on the copy, such as its bindings, is discarded. */119 | func (q *Query[T]) generateWhereGroup(f func(where *Query[T]) *Query[T]) (where, error) { 120 | start := len(q.wheres) 121 | nq := *q 122 | f(&nq) 123 | newWheres := nq.wheres[start:] 124 | 125 | if len(newWheres) > 0 { 126 | return where{SubWheres: append([]where{}, newWheres...)}, nq.result.Err 127 | } 128 | return where{}, nq.result.Err 129 | } 130 | /* where appends one condition to q.wheres, joined with or when isOr is true. Accepted argument shapes: (column) where column is accepted by isStringOrRaw (presumably a string or Raw) and non-empty, used as raw SQL; (column, value) with the = operator, where a nil value is replaced by WhereIsNull and a value accepted by isOperator must be WhereIsNull or WhereIsNotNull and is emitted as raw SQL after the column with no binding; and (column, operator, value). The value is inlined when isRaw accepts it, inlined together with its bindings when it is a *SubQuery, expanded to one placeholder per element when it is a non-empty slice, resolved as a column through parseColumn when it is any other pointer, and otherwise left for generateWhereStr to bind as a single placeholder. Errors are reported through q.setErr. */131 | func (q *Query[T]) where(isOr bool, column any, vals ...any) *Query[T] { 132 | if len(vals) > 2 { 133 | return q.setErr(errors.New("two many where-params")) /* NOTE(review): the message presumably means too many; it is runtime text, so it is left as is */134 | } 135 | 136 | if len(vals) == 0 { 137 | c, ok := q.isStringOrRaw(column) 138 | if ok == false { 139 | return q.setErr(errors.New("where-param should be string while only 1 param exist")) 140 | } 141 | if c != "" { 142 | q.wheres = append(q.wheres, where{Raw: c, IsOr: isOr}) 143 | } else { 144 | return q.setErr(errors.New("where-param should not be empty string")) 145 | } 146 | } else { 147 | c, err := q.parseColumn(column) 148 | if err != nil { 149 | return q.setErr(err) 150 | } 151 | operator := "=" 152 | var val any 153 | if len(vals) == 2 { 154 | operator2, ok := q.isStringOrRaw(vals[0]) 155 | if ok == false { 156 | return q.setErr(errors.New("the second where-param should be operator as string")) 157 | } 158 | operator = operator2 159 | val = vals[1] 160 | } else { 161 | if vals[0] == nil { /* NOTE(review): this writes into the variadic slice, so a caller that passed its own slice with the ... syntax sees that element replaced */162 | vals[0] = WhereIsNull 163 | } 164 | tempVal, ok := q.isOperator(vals[0]) 165 | if ok { 166 | if tempVal != string(WhereIsNull) && tempVal != string(WhereIsNotNull) { 167 | return q.setErr(errors.New("operator \"" + tempVal + "\" must have params")) 168 | } 169 | operator = "" 170 | val = Raw(tempVal) 171 | } else
{ 172 | val = vals[0] 173 | } 174 | } 175 | 176 | value, ok := q.isRaw(val) 177 | raw := "" 178 | var rawBindings []any 179 | if ok { 180 | if operator != "" { 181 | operator += " " 182 | } 183 | raw = c + " " + operator + value 184 | } else { 185 | tempTable, ok := val.(*SubQuery) 186 | if ok { 187 | if operator != "" { 188 | operator += " " 189 | } 190 | raw = c + " " + operator + "(" + tempTable.raw + ")" 191 | rawBindings = append(rawBindings, tempTable.bindings...) 192 | } else { 193 | temp := reflect.ValueOf(val) 194 | if temp.Kind() == reflect.Slice && temp.Len() > 0 { /* NOTE(review): every non-empty slice is expanded here, including a []byte value, which becomes one placeholder per byte. An empty slice skips this branch and ends up bound as a single value. With the default = operator a multi-element list renders as column = (...), which MySQL likely rejects, so callers presumably pass an IN style operator. Confirm both cases. */195 | rawBindings = make([]any, temp.Len()) 196 | rawCells := make([]string, temp.Len()) 197 | 198 | for i := 0; i < temp.Len(); i++ { 199 | rawCells[i] = "?" 200 | rawBindings[i] = temp.Index(i).Interface() 201 | } 202 | 203 | raw = c + " " + operator + " " + "(" + strings.Join(rawCells, ",") + ")" 204 | } else if temp.Kind() == reflect.Ptr { /* any other pointer is treated strictly as a column reference: one that parseColumn rejects produces an error instead of being bound as a value */205 | rawColumn, err := q.parseColumn(val) 206 | if err == nil { 207 | raw = c + " " + operator + " " + rawColumn 208 | } else { 209 | return q.setErr(errors.New("Error where " + c + " " + operator + " ? val is invalid")) 210 | } 211 | } 212 | } 213 | } 214 | q.wheres = append(q.wheres, where{Raw: raw, Column: c, Val: val, Operator: operator, IsOr: isOr, RawBindings: rawBindings}) /* for plain values raw is still empty here, and generateWhereStr renders Column Operator ? and binds Val */215 | } 216 | return q 217 | } 218 | /* generateWhereStr renders wheres as one condition string. Every entry after the first is prefixed with or or and according to IsOr, grouped entries are rendered recursively inside parentheses, raw entries are emitted as is, and plain entries become Column Operator ?. Placeholder values are appended to bindings in the order they appear in the text. */219 | func (q *Query[T]) generateWhereStr(wheres []where, bindings *[]any) string { 220 | var whereStr []string 221 | for k, v := range wheres { 222 | tempStr := "" 223 | if k > 0 { 224 | if v.IsOr { 225 | tempStr = "or " 226 | } else { 227 | tempStr = "and " 228 | } 229 | } 230 | if len(v.SubWheres) == 0 { 231 | if v.Raw != "" { 232 | tempStr += v.Raw 233 | if len(v.RawBindings) > 0 { 234 | *bindings = append(*bindings, v.RawBindings...) 235 | } 236 | } else { 237 | tempStr += v.Column + " " + v.Operator + " ?"
238 | *bindings = append(*bindings, v.Val) 239 | } 240 | } else { 241 | tempStr += "(" + q.generateWhereStr(v.SubWheres, bindings) + ")" 242 | } 243 | whereStr = append(whereStr, tempStr) 244 | } 245 | return strings.Join(whereStr, " ") 246 | } 247 | /* WhereExists adds an EXISTS (subquery) condition through q.Where. The subquery bindings are appended to q.bindings immediately, rather than stored as RawBindings on the where entry the way the *SubQuery branch of where does. NOTE(review): generateWhereGroup discards its query copy except for wheres and result.Err, so calling this inside a where-group callback appears to lose these bindings; their order relative to other where bindings also depends on how q.bindings is merged at execution time, which is not visible here. sq.err is not forwarded to q either, unlike in withCte. */248 | func (q *Query[T]) WhereExists(sq *SubQuery) *Query[T] { 249 | q.bindings = append(q.bindings, sq.bindings...) 250 | return q.Where("EXISTS (" + sq.raw + ")") 251 | } 252 | /* OrWhereExists adds an EXISTS (subquery) condition through q.OrWhere; bindings are handled as in WhereExists. */253 | func (q *Query[T]) OrWhereExists(sq *SubQuery) *Query[T] { 254 | q.bindings = append(q.bindings, sq.bindings...) 255 | return q.OrWhere("EXISTS (" + sq.raw + ")") 256 | } 257 | /* WhereNotExists adds a NOT EXISTS (subquery) condition through q.Where; bindings are handled as in WhereExists. */258 | func (q *Query[T]) WhereNotExists(sq *SubQuery) *Query[T] { 259 | q.bindings = append(q.bindings, sq.bindings...) 260 | return q.Where("NOT EXISTS (" + sq.raw + ")") 261 | } 262 | /* OrWhereNotExists adds a NOT EXISTS (subquery) condition through q.OrWhere; bindings are handled as in WhereExists. */263 | func (q *Query[T]) OrWhereNotExists(sq *SubQuery) *Query[T] { 264 | q.bindings = append(q.bindings, sq.bindings...) 265 | return q.OrWhere("NOT EXISTS (" + sq.raw + ")") 266 | } 267 | -------------------------------------------------------------------------------- /orm/query_window.go: -------------------------------------------------------------------------------- 1 | package orm 2 | /* WithWindow appends subquery to q.windows under the name windowName. It sets subquery.tableName, so the SubQuery passed in is modified. */3 | func (q *Query[T]) WithWindow(subquery *SubQuery, windowName string) *Query[T] { 4 | subquery.tableName = windowName 5 | q.windows = append(q.windows, subquery) 6 | return q 7 | } 8 | -------------------------------------------------------------------------------- /orm/query_with.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | ) 7 | /* WithParentsOnColumn makes q the anchor part of a CTE named after the table with a _cte suffix. It builds a second query over the same table, joined to that CTE where the first struct field equals the CTE column named by pidColumn (the part after its last dot, with backquote characters trimmed), selects the same columns as q or all columns, points q.self at the CTE and appends the second query with UnionAll. The return value of appendQuery.Select is ignored, which assumes Select modifies appendQuery in place. Presumably the first struct field is the primary key and the recursive form is rendered elsewhere; confirm. */8 | func (q *Query[T]) WithParentsOnColumn(pidColumn any) *Query[T] { 9 | tempName := q.tableInterface().TableName() + "_cte" 10 | 11 | col, err := q.parseColumn(pidColumn) 12 | if err != nil { 13 | return q.setErr(err) 14 | } 15 | coln := strings.Split(col, ".") 16 | newcol := strings.Trim(coln[len(coln)-1], "`") 17 | 18 | cte := newQueryRaw(tempName,
q.DBs()...) 19 | 20 | appendQuery := NewQuery(q.T, q.DBs()...) 21 | appendQuery = appendQuery.Join(cte.T, func(query *Query[T]) *Query[T] { 22 | return query.Where(appendQuery.tables[0].tableStruct.Field(0).Addr().Interface(), Raw(tempName+"."+newcol)) 23 | }) 24 | 25 | if len(q.columns) > 0 { 26 | appendQuery.Select(q.columns...) 27 | } else { 28 | appendQuery.Select(appendQuery.allCols()) 29 | } 30 | 31 | q.self = cte 32 | return q.UnionAll(appendQuery.SubQuery()) 33 | } 34 | /* WithChildrenOnColumn mirrors WithParentsOnColumn with the join reversed: the second query matches rows whose pidColumn equals the CTE column of the first struct field. pidColumn is prefixed with the table name when it contains no dot. */35 | func (q *Query[T]) WithChildrenOnColumn(pidColumn any) *Query[T] { 36 | tempName := q.tableInterface().TableName() + "_cte" 37 | 38 | pcol, err := q.parseColumn(pidColumn) 39 | if err != nil { 40 | return q.setErr(err) 41 | } 42 | if strings.Contains(pcol, ".") == false { 43 | pcol = q.tableInterface().TableName() + "." + pcol 44 | } 45 | col, err := q.parseColumn(q.tables[0].tableStruct.Field(0).Addr().Interface()) 46 | if err != nil { 47 | return q.setErr(err) 48 | } 49 | coln := strings.Split(col, ".") 50 | newcol := strings.Trim(coln[len(coln)-1], "`") 51 | 52 | cte := newQueryRaw(tempName, q.DBs()...) 53 | 54 | appendQuery := NewQuery(q.T, q.DBs()...) 55 | appendQuery = appendQuery.Join(cte.T, func(query *Query[T]) *Query[T] { 56 | return query.Where(pcol, Raw(tempName+"."+newcol)) 57 | }) 58 | 59 | if len(q.columns) > 0 { 60 | appendQuery.Select(q.columns...) 61 | } else { 62 | appendQuery.Select(appendQuery.allCols()) 63 | } 64 | 65 | q.self = cte 66 | return q.UnionAll(appendQuery.SubQuery()) 67 | } 68 | /* WithCte registers subquery as a non-recursive CTE named cteName, with optional column names. */69 | func (q *Query[T]) WithCte(subquery *SubQuery, cteName string, columns ...string) *Query[T] { 70 | return q.withCte(subquery, cteName, false, columns...) 71 | } 72 | /* WithRecursiveCte is WithCte with the recursive flag set. */73 | func (q *Query[T]) WithRecursiveCte(subquery *SubQuery, cteName string, columns ...string) *Query[T] { 74 | return q.withCte(subquery, cteName, true, columns...)
75 | } 76 | /* withCte sets the name, recursive flag and column list on subquery, modifying the value the caller passed, forwards subquery.err to q.setErr and appends subquery to q.withCtes. */77 | func (q *Query[T]) withCte(subquery *SubQuery, cteName string, recursive bool, columns ...string) *Query[T] { 78 | subquery.tableName = cteName 79 | subquery.recursive = recursive 80 | subquery.columns = columns 81 | q.setErr(subquery.err) 82 | q.withCtes = append(q.withCtes, subquery) 83 | return q 84 | } 85 | /* WithContext stores a pointer to ctx on the query; how that context is used is not visible in this file. */86 | func (q *Query[T]) WithContext(ctx context.Context) *Query[T] { 87 | q.ctx = &ctx 88 | return q 89 | } 90 | -------------------------------------------------------------------------------- /orm/reflect.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import ( 4 | "database/sql" 5 | "errors" 6 | "reflect" 7 | "strconv" 8 | "strings" 9 | "sync" 10 | "time" 11 | ) 12 | /* reflectValueIsOrmField reports whether v holds a *time.Time or **time.Time, or whether v or the value it points to implements sql.Scanner. A v that cannot be interfaced yields false. Presumably callers pass the address of a struct field; confirm. */13 | func reflectValueIsOrmField(v reflect.Value) bool { 14 | if v.CanInterface() == false { 15 | return false 16 | } 17 | 18 | if _, ok := v.Interface().(*time.Time); ok { 19 | return true 20 | } 21 | if _, ok := v.Interface().(**time.Time); ok { 22 | return true 23 | } 24 | if _, ok := v.Interface().(sql.Scanner); ok { 25 | return true 26 | } 27 | 28 | vv := reflect.Indirect(v) /* NOTE(review): if v is a nil pointer, reflect.Indirect returns the zero Value and CanInterface panics on it */29 | 30 | if vv.CanInterface() { 31 | if _, ok := vv.Interface().(sql.Scanner); ok { 32 | return true 33 | } 34 | } 35 | 36 | return false 37 | } 38 | /* structFieldsCache maps a struct type name to its per-field column names as computed by getStructFieldNameSlice. NOTE(review): keys come from reflect.Type.String, which the reflect documentation says is not guaranteed to be unique, so same-named types in same-named packages would share an entry; keying by the reflect.Type value would avoid that. */39 | var structFieldsCache sync.Map 40 | /* getFieldsCache returns the cached names for key, or nil when there is no entry. */41 | func getFieldsCache(key string) []string { 42 | intef, ok := structFieldsCache.Load(key) 43 | if ok { 44 | ret, ok := intef.([]string) 45 | if ok { 46 | return ret 47 | } 48 | } 49 | return nil 50 | } 51 | /* setFieldsCache stores val as the cached names for key. */func setFieldsCache(key string, val []string) { 52 | structFieldsCache.Store(key, val) 53 | } 54 | /* castFieldsToStrSlice returns, for each pointer in tableColumnPtrs, the json tag name (first comma component) of the field of *tableObjAddr that it points at; fields whose json name is empty or a dash contribute nothing. It returns nil, nil when no pointers are given, ErrParamMustBePtr or ErrParamElemKindMustBeStruct for bad arguments, and the columns gathered so far plus an error naming the argument position (index + 2) when a pointer matches no field. A struct with no fields matches nothing and reports no error. NOTE(review): unlike getStructFieldNameSlice this ignores the orm tag; confirm which name callers expect. */55 | func castFieldsToStrSlice(tableObjAddr any, tableColumnPtrs ...any) ([]string, error) { 56 | if len(tableColumnPtrs) == 0 { 57 | return nil, nil 58 | } 59 | 60 | tableStructAddr := reflect.ValueOf(tableObjAddr) 61 | if tableStructAddr.Kind() != reflect.Ptr { 62 | return nil, ErrParamMustBePtr 63 | } 64 | 65 | tableStruct :=
tableStructAddr.Elem() 66 | 67 | if tableStruct.Kind() != reflect.Struct { 68 | return nil, ErrParamElemKindMustBeStruct 69 | } 70 | 71 | tableStructType := reflect.TypeOf(tableObjAddr).Elem() 72 | 73 | var columns []string 74 | for k, v := range tableColumnPtrs { 75 | columnVar := reflect.ValueOf(v) 76 | if columnVar.Kind() != reflect.Ptr { 77 | return nil, ErrParamMustBePtr 78 | } 79 | 80 | for i := 0; i < tableStruct.NumField(); i++ { 81 | valueField := tableStruct.Field(i) 82 | if valueField.Addr().Interface() == columnVar.Elem().Addr().Interface() { /* NOTE(review): Interface panics on values obtained through unexported struct fields, so an unexported field ahead of the target panics here; comparing address and type without Interface would avoid that */83 | name := strings.Split(tableStructType.Field(i).Tag.Get("json"), ",")[0] 84 | if name != "" && name != "-" { 85 | columns = append(columns, name) 86 | } 87 | break 88 | } else if i == tableStruct.NumField()-1 { 89 | return columns, errors.New("param " + strconv.Itoa(k+2) + " is not a field of first obj") 90 | } 91 | } 92 | } 93 | 94 | return columns, nil 95 | } 96 | /* getStructFieldAddrMap maps the column names of *objAddr, as given by getStructFieldNameSlice, to the addresses of the matching fields. Embedded struct fields are flattened by recursion, and when names collide the entry written last wins. Fields without a name are skipped. */97 | func getStructFieldAddrMap(objAddr any) (map[string]any, error) { 98 | tableStructAddr := reflect.ValueOf(objAddr) 99 | if tableStructAddr.Kind() != reflect.Ptr { 100 | return nil, ErrParamMustBePtr 101 | } 102 | 103 | tableStruct := tableStructAddr.Elem() 104 | if tableStruct.Kind() != reflect.Struct { 105 | return nil, ErrParamElemKindMustBeStruct 106 | } 107 | 108 | tableStructType := reflect.TypeOf(objAddr).Elem() 109 | 110 | ret := make(map[string]any) 111 | 112 | fields, err := getStructFieldNameSlice(tableStruct.Interface()) 113 | if err != nil { 114 | return nil, err 115 | } 116 | 117 | for i := 0; i < tableStruct.NumField(); i++ { 118 | if tableStruct.Field(i).Kind() == reflect.Struct && tableStructType.Field(i).Anonymous { 119 | innerMap, err := getStructFieldAddrMap(tableStruct.Field(i).Addr().Interface()) 120 | if err != nil { 121 | return ret, err 122 | } 123 | for k, v := range innerMap { 124 | ret[k] = v 125 | } 126 | } else { 127 | valueField := tableStruct.Field(i) 128 | 129 | name := fields[i] 130 | if name != "" {
131 | ret[name] = valueField.Addr().Interface() 132 | } 133 | 134 | } 135 | } 136 | 137 | return ret, nil 138 | } 139 | /* getStructAddrFieldMap maps the address of each top-level field of *objAddr to its json tag name, skipping empty and dash names. It neither reads the orm tag nor descends into embedded structs, so it can disagree with getStructFieldAddrMap; confirm that is intended. */140 | func getStructAddrFieldMap(objAddr any) (map[any]string, error) { 141 | tableStructAddr := reflect.ValueOf(objAddr) 142 | if tableStructAddr.Kind() != reflect.Ptr { 143 | return nil, ErrParamMustBePtr 144 | } 145 | 146 | tableStruct := tableStructAddr.Elem() 147 | if tableStruct.Kind() != reflect.Struct { 148 | return nil, ErrParamElemKindMustBeStruct 149 | } 150 | 151 | tableStructType := reflect.TypeOf(objAddr).Elem() 152 | 153 | ret := make(map[any]string) 154 | 155 | for i := 0; i < tableStruct.NumField(); i++ { 156 | valueField := tableStruct.Field(i) 157 | 158 | name := strings.Split(tableStructType.Field(i).Tag.Get("json"), ",")[0] 159 | if name == "-" { 160 | name = "" 161 | } 162 | if name != "" { 163 | ret[valueField.Addr().Interface()] = name 164 | } 165 | } 166 | return ret, nil 167 | } 168 | /* getStructFieldNameSlice returns one name per top-level field of obj, indexed by field position: the first component of the orm tag, else the first component of the json tag, else empty. A dash orm tag only clears the orm name, so the json name still applies to that field. Results are cached by type name and the cached slice is returned to every caller, so callers must not modify it. A nil obj panics because reflect.TypeOf returns nil for it. */169 | func getStructFieldNameSlice(obj any) ([]string, error) { 170 | tableStructType := reflect.TypeOf(obj) 171 | 172 | fieldsCache := getFieldsCache(tableStructType.String()) 173 | if fieldsCache != nil { 174 | return fieldsCache, nil 175 | } 176 | 177 | tableStruct := reflect.ValueOf(obj) 178 | if tableStruct.Kind() != reflect.Struct { 179 | return nil, ErrParamElemKindMustBeStruct 180 | } 181 | var ret = make([]string, tableStruct.NumField()) 182 | 183 | for i := 0; i < tableStruct.NumField(); i++ { 184 | //if tableStruct.Field(i).Kind() == reflect.Struct && tableStructType.Field(i).Anonymous { 185 | // innerFields, err := getStructFieldNameSlice(tableStruct.Field(i).Interface()) 186 | // if err != nil { 187 | // return ret, err 188 | // } 189 | // ret = append(ret, innerFields...)
190 | //} 191 | ormTag := strings.Split(tableStructType.Field(i).Tag.Get("orm"), ",")[0] 192 | if ormTag == "-" { 193 | ormTag = "" 194 | } 195 | 196 | if ormTag != "" { 197 | ret[i] = ormTag 198 | continue 199 | } 200 | 201 | ormTag = strings.Split(tableStructType.Field(i).Tag.Get("json"), ",")[0] 202 | if ormTag == "-" { 203 | ormTag = "" 204 | } 205 | if ormTag != "" { 206 | ret[i] = ormTag 207 | } 208 | } 209 | 210 | setFieldsCache(tableStructType.String(), ret) 211 | return ret, nil 212 | } 213 | /* getStructFieldWithDefaultTime returns, keyed by field index, a value for each interfaceable time field of obj (a time.Time, or a struct type whose first field is a time.Time). The value is time.Now when the default tag contains current_timestamp, or when there is no default unless the field name is deletedAtColumn. There is no entry when the default is null, which a null option in the orm tag also implies. Otherwise the value is the raw default string. All entries share one time.Now value. NOTE(review): without an orm name, name is the whole json tag, options included, so a tag like deleted_at,omitempty would not match deletedAtColumn. */214 | func getStructFieldWithDefaultTime(obj any) (map[int]any, error) { 215 | tableStructType := reflect.TypeOf(obj) 216 | 217 | tableStruct := reflect.ValueOf(obj) 218 | if tableStruct.Kind() != reflect.Struct { 219 | return nil, ErrParamElemKindMustBeStruct 220 | } 221 | ret := make(map[int]any) 222 | now := time.Now() 223 | for i := 0; i < tableStruct.NumField(); i++ { 224 | v := tableStruct.Field(i) 225 | if v.CanInterface() { 226 | defaultVar := tableStructType.Field(i).Tag.Get("default") 227 | name := tableStructType.Field(i).Tag.Get("json") 228 | ormName := tableStructType.Field(i).Tag.Get("orm") 229 | ormNames := strings.Split(ormName, ",") 230 | ormName = ormNames[0] 231 | if ormName != "" { 232 | name = ormName 233 | } 234 | if defaultVar == "" { 235 | for _, v := range ormNames { 236 | if strings.ToLower(v) == "null" { 237 | defaultVar = "null" 238 | break 239 | } 240 | } 241 | } 242 | 243 | if _, ok := v.Interface().(time.Time); ok { 244 | lowerDefault := strings.ToLower(defaultVar) 245 | if strings.Contains(lowerDefault, "current_timestamp") { 246 | ret[i] = now 247 | } else if lowerDefault == "" { 248 | if name != deletedAtColumn { 249 | ret[i] = now 250 | } 251 | } else if lowerDefault != "null" { 252 | ret[i] = defaultVar 253 | } 254 | } else if v.Kind() == reflect.Struct { /* NOTE(review): Field(0) panics for a struct type with no fields, and Interface panics when that first field is unexported; checking NumField and comparing Field(0).Type() with the time.Time type would avoid both */255 | realVal := reflect.New(v.Type()) 256 | if _, ok := realVal.Elem().Field(0).Interface().(time.Time); ok { 257 | lowerDefault := strings.ToLower(defaultVar) 258 | if
strings.Contains(lowerDefault, "current_timestamp") { 259 | ret[i] = now 260 | } else if lowerDefault == "" { 261 | if name != deletedAtColumn { 262 | ret[i] = now 263 | } 264 | } else if lowerDefault != "null" { 265 | ret[i] = defaultVar 266 | } 267 | } 268 | } 269 | } 270 | } 271 | return ret, nil 272 | } 273 | -------------------------------------------------------------------------------- /orm/slice.go: -------------------------------------------------------------------------------- 1 | package orm 2 | /* sliceContain reports whether b is an element of a. */3 | func sliceContain[T comparable](a []T, b T) bool { 4 | return sliceContainIndex(a, b) > -1 5 | } 6 | /* sliceContainIndex returns the index of the first element of a equal to b, or -1 when there is none. */7 | func sliceContainIndex[T comparable](a []T, b T) int { 8 | for k, v := range a { 9 | if v == b { 10 | return k 11 | } 12 | } 13 | return -1 14 | } 15 | -------------------------------------------------------------------------------- /orm/string.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import ( 4 | "database/sql/driver" 5 | "fmt" 6 | "reflect" 7 | "strconv" 8 | "strings" 9 | "time" 10 | "unicode" 11 | ) 12 | /* escape is the quote character varToString wraps text literals in, and nullStr is what it prints for nil values. */13 | var ( 14 | escape = `'` 15 | nullStr = "NULL" 16 | ) 17 | /* varToString renders one bound value as SQL literal text; SubQuery.Sql uses it to inline bindings. Integers print in decimal, floats with six decimals, bools as true or false. Strings and printable byte slices go in single quotes with embedded quotes backslash-escaped. Times print as quoted date-time text, with zero times as all zeros. Nil values and nil pointers give NULL. Valuers are rendered through their Value result with its error ignored, Stringers through their underlying kind, and anything else by dereferencing, converting to time.Time, bool or []byte, or finally fmt.Sprint. A non-printable byte slice renders as an empty quoted string. */18 | func varToString(i any) string { 19 | switch v := i.(type) { 20 | case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: 21 | return fmt.Sprintf("%d", v) 22 | case float64, float32: 23 | return fmt.Sprintf("%.6f", v) 24 | case bool: 25 | return strconv.FormatBool(v) 26 | case string: 27 | return escape + strings.ReplaceAll(v, escape, "\\"+escape) + escape /* NOTE(review): backslashes are not escaped, so a string ending in a backslash, or with a backslash before a quote, yields a malformed literal */28 | case []byte: 29 | if s := string(v); stringIsPrintable(s) { 30 | return escape + strings.ReplaceAll(s, escape, "\\"+escape) + escape 31 | } else { 32 | return escape + "" + escape 33 | } 34 | case time.Time: 35 | if v.IsZero() { 36 | return escape + "0000-00-00 00:00:00" + escape 37 | } else { 38 | return escape + v.Format("2006-01-02 15:04:05.999") + escape 39 | } 40 | case *time.Time: 41 | if v != nil { 42 | if v.IsZero() { 43 | return escape
+ "0000-00-00 00:00:00" + escape 44 | } else { 45 | return escape + v.Format("2006-01-02 15:04:05.999") + escape 46 | } 47 | } else { 48 | return nullStr 49 | } 50 | case driver.Valuer: 51 | reflectValue := reflect.ValueOf(v) 52 | if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { 53 | r, _ := v.Value() 54 | return varToString(r) 55 | } else { 56 | return nullStr 57 | } 58 | case fmt.Stringer: 59 | reflectValue := reflect.ValueOf(v) 60 | switch reflectValue.Kind() { 61 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 62 | return fmt.Sprintf("%d", reflectValue.Interface()) 63 | case reflect.Float32, reflect.Float64: 64 | return fmt.Sprintf("%.6f", reflectValue.Interface()) 65 | case reflect.Bool: 66 | return fmt.Sprintf("%t", reflectValue.Interface()) 67 | case reflect.String: 68 | return escape + strings.ReplaceAll(fmt.Sprintf("%v", v), escape, "\\"+escape) + escape 69 | default: 70 | if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { 71 | return escape + strings.ReplaceAll(fmt.Sprintf("%v", v), escape, "\\"+escape) + escape 72 | } else { 73 | return nullStr 74 | } 75 | } 76 | default: 77 | rv := reflect.ValueOf(v) 78 | if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() { 79 | return nullStr 80 | } else if valuer, ok := v.(driver.Valuer); ok { /* NOTE(review): unreachable, because every driver.Valuer already matched the case above */81 | v, _ = valuer.Value() 82 | return varToString(v) 83 | } else if rv.Kind() == reflect.Ptr && !rv.IsZero() { 84 | return varToString(reflect.Indirect(rv).Interface()) 85 | } else { 86 | for _, t := range []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} { 87 | if rv.Type().ConvertibleTo(t) { 88 | return varToString(rv.Convert(t).Interface()) 89 | } 90 | } 91 | return
escape + strings.ReplaceAll(fmt.Sprint(v), escape, "\\"+escape) + escape 92 | } 93 | } 94 | } 95 | /* stringIsPrintable reports whether every rune of s satisfies unicode.IsPrint; the empty string counts as printable. */96 | func stringIsPrintable(s string) bool { 97 | for _, r := range s { 98 | if !unicode.IsPrint(r) { 99 | return false 100 | } 101 | } 102 | return true 103 | } 104 | -------------------------------------------------------------------------------- /orm/subquery.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import ( 4 | "database/sql" 5 | "strings" 6 | ) 7 | /* subqueryDefaultName is the name TableName reports for an unnamed subquery that has SQL text. */8 | const subqueryDefaultName = "sub" 9 | /* SubQuery holds raw SQL and its bindings together with naming, connection, transaction and CTE settings, so it can be embedded in another query or used as a table. */10 | type SubQuery struct { 11 | raw string 12 | bindings []any 13 | recursive bool 14 | dbName string 15 | tableName string 16 | columns []string 17 | dbs []*sql.DB 18 | tx *sql.Tx 19 | err error 20 | unionAll bool 21 | } 22 | /* Connections returns the stored connections; together with DatabaseName and TableName it satisfies the Table interface. */23 | func (m SubQuery) Connections() []*sql.DB { 24 | return m.dbs 25 | } 26 | /* DatabaseName returns the stored database name. */27 | func (m SubQuery) DatabaseName() string { 28 | return m.dbName 29 | } 30 | /* TableName returns the explicit name when set, subqueryDefaultName when there is raw SQL, and an empty string otherwise. */31 | func (m SubQuery) TableName() string { 32 | if m.tableName != "" { 33 | return m.tableName 34 | } 35 | if m.raw != "" { 36 | return subqueryDefaultName 37 | } 38 | return "" 39 | } 40 | /* Error returns the error stored on the subquery. */41 | func (m SubQuery) Error() error { 42 | return m.err 43 | } 44 | /* Sql returns raw with each question mark replaced, in order, by varToString of the next binding. NOTE(review): every question mark counts, including ones inside quoted text, and once the bindings run out the loop breaks, so the rest of the SQL is missing from the result. */45 | func (m SubQuery) Sql() string { 46 | params := make([]string, len(m.bindings)) 47 | for k, v := range m.bindings { 48 | params[k] = varToString(v) 49 | } 50 | 51 | var sqlb strings.Builder 52 | var index = 0 53 | 54 | for _, v := range []byte(m.raw) { 55 | if v == '?'
{ 56 | if len(params) > index { 57 | sqlb.WriteString(params[index]) 58 | index++ 59 | } else { 60 | break /* NOTE(review): truncates the output at the first placeholder without a binding */61 | } 62 | } else { 63 | sqlb.WriteByte(v) 64 | } 65 | } 66 | 67 | return sqlb.String() 68 | } 69 | -------------------------------------------------------------------------------- /orm/table.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import "database/sql" 4 | /* Table is what a queryable type provides: its connections, database name and table name. */5 | type Table interface { 6 | Connections() []*sql.DB 7 | DatabaseName() string 8 | TableName() string 9 | } 10 | -------------------------------------------------------------------------------- /orm/update_column.go: -------------------------------------------------------------------------------- 1 | package orm 2 | /* updateColumn pairs a column reference (col) with a value (val); presumably consumed by update queries elsewhere in the package. */3 | type updateColumn struct { 4 | col any 5 | val any 6 | } 7 | -------------------------------------------------------------------------------- /orm/where.go: -------------------------------------------------------------------------------- 1 | package orm 2 | /* where is one WHERE condition: raw SQL with RawBindings, a Column Operator Val comparison bound through one placeholder, or a group of SubWheres rendered in parentheses. IsOr joins it to the previous condition with or instead of and. */3 | type where struct { 4 | IsOr bool 5 | Column string 6 | Operator string 7 | Val any 8 | Raw string 9 | RawBindings []any 10 | SubWheres []where 11 | } 12 | -------------------------------------------------------------------------------- /orm_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "database/sql" 5 | "github.com/folospace/go-mysql-orm/orm" 6 | "testing" 7 | "time" 8 | ) 9 | /* tdb connects to a remote MySQL host, so these tests need network access; the connection error is discarded. */10 | var tdb, _ = orm.OpenMysql("rfamro@tcp(mysql-rfam-public.ebi.ac.uk:4497)/Rfam?parseTime=true&charset=utf8mb4&loc=Asia%2FShanghai") 11 | /* FamilyTable is the shared *Family value the tests build queries from. */12 | var FamilyTable = new(Family) 13 | /* Family mirrors the family table of the Rfam database; orm tags carry the column definitions and default tags the column defaults. */14 | type Family struct { 15 | RfamAcc string `json:"rfam_acc" orm:"rfam_acc,varchar(7),primary,unique"` 16 | RfamId string `json:"rfam_id" orm:"rfam_id,varchar(40),index"` 17 | AutoWiki uint `json:"auto_wiki" orm:"auto_wiki,int(10) unsigned,index"` 18 | Description *string `json:"description" orm:"description,varchar(75),null"
default:"NULL"` 19 | Author *string `json:"author" orm:"author,tinytext,null"` 20 | SeedSource *string `json:"seed_source" orm:"seed_source,tinytext,null"` 21 | GatheringCutoff *float64 `json:"gathering_cutoff" orm:"gathering_cutoff,double(5,2),null" default:"NULL"` 22 | TrustedCutoff *float64 `json:"trusted_cutoff" orm:"trusted_cutoff,double(5,2),null" default:"NULL"` 23 | NoiseCutoff *float64 `json:"noise_cutoff" orm:"noise_cutoff,double(5,2),null" default:"NULL"` 24 | Comment *string `json:"comment" orm:"comment,longtext,null"` 25 | PreviousId *string `json:"previous_id" orm:"previous_id,tinytext,null"` 26 | Cmbuild *string `json:"cmbuild" orm:"cmbuild,tinytext,null"` 27 | Cmcalibrate *string `json:"cmcalibrate" orm:"cmcalibrate,tinytext,null"` 28 | Cmsearch *string `json:"cmsearch" orm:"cmsearch,tinytext,null"` 29 | NumSeed *int64 `json:"num_seed" orm:"num_seed,bigint(20),null" default:"NULL"` 30 | NumFull *int64 `json:"num_full" orm:"num_full,bigint(20),null" default:"NULL"` 31 | NumGenomeSeq *int64 `json:"num_genome_seq" orm:"num_genome_seq,bigint(20),null" default:"NULL"` 32 | NumRefseq *int64 `json:"num_refseq" orm:"num_refseq,bigint(20),null" default:"NULL"` 33 | Type *string `json:"type" orm:"type,varchar(50),null" default:"NULL"` 34 | StructureSource *string `json:"structure_source" orm:"structure_source,tinytext,null"` 35 | NumberOfSpecies *int64 `json:"number_of_species" orm:"number_of_species,bigint(20),null" default:"NULL"` 36 | Number3dStructures *int `json:"number_3d_structures" orm:"number_3d_structures,int(11),null" default:"NULL"` 37 | NumPseudonokts *int `json:"num_pseudonokts" orm:"num_pseudonokts,int(11),null" default:"NULL"` 38 | TaxSeed *string `json:"tax_seed" orm:"tax_seed,mediumtext,null"` 39 | EcmliLambda *float64 `json:"ecmli_lambda" orm:"ecmli_lambda,double(10,5),null" default:"NULL"` 40 | EcmliMu *float64 `json:"ecmli_mu" orm:"ecmli_mu,double(10,5),null" default:"NULL"` 41 | EcmliCalDb *int `json:"ecmli_cal_db"
orm:"ecmli_cal_db,mediumint(9),null" default:"0"` 42 | EcmliCalHits *int `json:"ecmli_cal_hits" orm:"ecmli_cal_hits,mediumint(9),null" default:"0"` 43 | Maxl *int `json:"maxl" orm:"maxl,mediumint(9),null" default:"0"` 44 | Clen *int `json:"clen" orm:"clen,mediumint(9),null" default:"0"` 45 | MatchPairNode *int8 `json:"match_pair_node" orm:"match_pair_node,tinyint(1),null" default:"0"` 46 | HmmTau *float64 `json:"hmm_tau" orm:"hmm_tau,double(10,5),null" default:"NULL"` 47 | HmmLambda *float64 `json:"hmm_lambda" orm:"hmm_lambda,double(10,5),null" default:"NULL"` 48 | Created time.Time `json:"created" orm:"created,datetime"` 49 | Updated time.Time `json:"updated" orm:"updated,timestamp" default:"CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"` 50 | } 51 | /* Connections returns the test connection. */52 | func (f *Family) Connections() []*sql.DB { 53 | return []*sql.DB{tdb} 54 | } 55 | /* TableName returns the table name family. */56 | func (*Family) TableName() string { 57 | return "family" 58 | } 59 | /* DatabaseName returns the database name Rfam. */60 | func (*Family) DatabaseName() string { 61 | return "Rfam" 62 | } 63 | /* Query starts a new query with f as the table value. */func (f *Family) Query() *orm.Query[*Family] { 64 | return orm.NewQuery(f) 65 | } 66 | /* TestOrm runs smoke checks against the remote table: struct and DDL generation, raw queries, counts and map scanning. Subtests only log results and errors, so they never fail. */67 | func TestOrm(t *testing.T) { 68 | t.Run("create_struct", func(t *testing.T) { 69 | err := FamilyTable.Query().CreateStruct() 70 | 71 | t.Log(err) 72 | }) 73 | t.Run("create_table", func(t *testing.T) { 74 | sqlStr, err := FamilyTable.Query().CreateTable() 75 | t.Log(sqlStr) 76 | t.Log(err) 77 | }) 78 | 79 | t.Run("mysql_version", func(t *testing.T) { 80 | data, query := FamilyTable.Query().Raw("select version()").GetString() 81 | 82 | t.Log(data) 83 | t.Log(query.Sql()) 84 | t.Log(query.Error()) 85 | }) 86 | t.Run("query_timeout", func(t *testing.T) { 87 | var data map[string]int 88 | query := FamilyTable.Query().Raw("show variables like '%timeout%'").GetTo(&data) 89 | 90 | t.Log(data) 91 | t.Log(query.Sql()) 92 | t.Log(query.Error()) 93 | }) 94 | t.Run("query_table_sql", func(t *testing.T) { 95 | var data map[string]string 96 | query := FamilyTable.Query().Raw("show create table " +
FamilyTable.Query().T.TableName()).GetTo(&data) 97 | 98 | t.Log(data) 99 | t.Log(query.Sql()) 100 | t.Log(query.Error()) 101 | }) 102 | t.Run("count_total", func(t *testing.T) { 103 | data, query := FamilyTable.Query().GetCount() 104 | t.Log(data) 105 | 106 | t.Log(query.Sql()) 107 | t.Log(query.Error()) 108 | }) 109 | t.Run("count_distinct_total", func(t *testing.T) { 110 | //data, query := FamilyTable.Query().Select("count(distinct(type))").GetInt() 111 | //t.Log(data) 112 | 113 | data, query := FamilyTable.Query().GroupBy(&FamilyTable.Type).GetCount() 114 | t.Log(data) 115 | 116 | t.Log(query.Sql()) 117 | t.Log(query.Error()) 118 | }) 119 | t.Run("result_to_map", func(t *testing.T) { 120 | var data map[string][]string 121 | query := FamilyTable.Query().Select(&FamilyTable.Query().T.Type, &FamilyTable.Query().T.RfamAcc).Limit(20).GetTo(&data) 122 | t.Log(data) 123 | 124 | t.Log(query.Sql()) 125 | t.Log(query.Error()) 126 | }) 127 | t.Run("where_primary", func(t *testing.T) { 128 | data, query := FamilyTable.Query().WherePrimary("RF00006").Select(&FamilyTable).Get() 129 | 130 | t.Log(data) 131 | 132 | t.Log(query.Sql()) 133 | t.Log(query.Error()) 134 | }) 135 | } 136 | --------------------------------------------------------------------------------