├── .gitignore ├── LICENSE ├── README.md ├── README_zh.md ├── aorm.go ├── base ├── Db.go ├── Link.go └── Tx.go ├── builder ├── aggregation.go ├── builder.go ├── cache.go ├── crud.go ├── exists.go ├── handle.go ├── having.go ├── increment.go ├── join.go ├── order.go ├── select.go ├── value.go └── where.go ├── driver └── driver.go ├── go.mod ├── migrate_mssql └── migrate.go ├── migrate_mysql └── migrate.go ├── migrate_postgres └── migrate.go ├── migrate_sqlite3 └── migrate.go ├── migrator └── migrator.go ├── null └── null.go ├── test └── aorm_test.go ├── utils └── str.go └── wechat.jpg /.gitignore: -------------------------------------------------------------------------------- 1 | /go.sum 2 | /test/test.db 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 tangpanqing 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tangpanqing/aorm/527e68060c454f369c63483683679dd14328636d/README.md
--------------------------------------------------------------------------------
/README_zh.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tangpanqing/aorm/527e68060c454f369c63483683679dd14328636d/README_zh.md
--------------------------------------------------------------------------------
/aorm.go:
--------------------------------------------------------------------------------
1 | package aorm
2 |
3 | import (
4 | 	"database/sql" // import only the database drivers you actually need
5 |
6 | 	"github.com/tangpanqing/aorm/base"
7 | 	"github.com/tangpanqing/aorm/builder"
8 | 	"github.com/tangpanqing/aorm/migrator"
9 | )
10 |
11 | // Open opens a database handle for driverName and dataSourceName and checks
12 | // it with Ping. On failure an empty *base.Db is returned together with the error.
13 | func Open(driverName string, dataSourceName string) (*base.Db, error) {
14 | 	sqlDB, err := sql.Open(driverName, dataSourceName)
15 | 	if err != nil {
16 | 		return &base.Db{}, err
17 | 	}
18 |
19 | 	// sql.Open may only validate its arguments; Ping checks that a
20 | 	// connection can actually be established.
21 | 	if err := sqlDB.Ping(); err != nil {
22 | 		// Release the pool created by sql.Open so a failed Open does not leak it.
23 | 		_ = sqlDB.Close()
24 | 		return &base.Db{}, err
25 | 	}
26 |
27 | 	return &base.Db{
28 | 		Driver: driverName,
29 | 		SqlDB:  sqlDB,
30 | 	}, nil
31 | }
32 |
33 | // Store registers model pointers with the builder cache so that pointers to
34 | // those models and to their fields can be resolved to table and column names.
35 | func Store(destList ...interface{}) {
36 | 	builder.Store(destList...)
37 | }
38 |
39 | // Db starts a new query builder on link (for example a *base.Db or *base.Tx),
40 | // inheriting the link's debug mode.
41 | func Db(link base.Link) *builder.Builder {
42 | 	b := &builder.Builder{}
43 |
44 | 	b.Link = link
45 | 	b.Debug(link.GetDebugMode())
46 |
47 | 	return b
48 | }
49 |
50 | // Migrator starts a new schema migrator on linkCommon.
51 | func Migrator(linkCommon base.Link) *migrator.Migrator {
52 | 	mi := &migrator.Migrator{
53 | 		Link: linkCommon,
54 | 	}
55 | 	return mi
56 | }
57 |
--------------------------------------------------------------------------------
/base/Db.go:
--------------------------------------------------------------------------------
1 | package base
2 |
3 | import (
4 | 	"database/sql"
5 | 	"time"
6 | )
7 |
8 | // Db wraps a *sql.DB together with the driver name and the debug flag that
9 | // builders and transactions created from it inherit.
10 | type Db struct {
11 | 	Driver    string
12 | 	DebugMode bool
13 | 	SqlDB     *sql.DB
14 | }
15 |
16 | // Close closes the underlying *sql.DB.
17 | func (db *Db) Close() error {
18 | 	return db.SqlDB.Close()
19 | }
20 |
21 | // Begin starts a transaction that inherits this Db's driver name and debug mode.
22 | //
23 | // NOTE(review): the error from sql.DB.Begin is discarded. If it fails, the
24 | // returned Tx wraps a nil *sql.Tx and later calls on it will panic.
25 | func (db *Db) Begin() *Tx {
26 | 	sqlTx, _ := db.SqlDB.Begin()
27 |
28 | 	return &Tx{
29 | 		driver:    db.Driver,
30 | 		debugMode: db.DebugMode,
31 |
32 | 		sqlTx: sqlTx,
33 | 	}
34 | }
35 |
36 | // SetDebugMode enables or disables debug mode (SQL printing) for builders
37 | // created from this Db.
38 | func (db *Db) SetDebugMode(debugMode bool) {
39 | 	db.DebugMode = debugMode
40 | }
41 |
42 | // SetConnMaxLifetime delegates to sql.DB.SetConnMaxLifetime.
43 | func (db *Db) SetConnMaxLifetime(d time.Duration) {
44 | 	db.SqlDB.SetConnMaxLifetime(d)
45 | }
46 |
47 | // SetConnMaxIdleTime delegates to sql.DB.SetConnMaxIdleTime.
48 | func (db *Db) SetConnMaxIdleTime(d time.Duration) {
49 | 	db.SqlDB.SetConnMaxIdleTime(d)
50 | }
51 |
52 | // SetMaxIdleConns delegates to sql.DB.SetMaxIdleConns.
53 | func (db *Db) SetMaxIdleConns(n int) {
54 | 	db.SqlDB.SetMaxIdleConns(n)
55 | }
56 |
57 | // SetMaxOpenConns delegates to sql.DB.SetMaxOpenConns.
58 | func (db *Db) SetMaxOpenConns(n int) {
59 | 	db.SqlDB.SetMaxOpenConns(n)
60 | }
61 |
62 | // Stats returns the connection pool statistics of the underlying *sql.DB.
63 | func (db *Db) Stats() sql.DBStats {
64 | 	return db.SqlDB.Stats()
65 | }
66 |
67 | // GetDebugMode reports whether debug mode is enabled.
68 | func (db *Db) GetDebugMode() bool {
69 | 	return db.DebugMode
70 | }
71 |
72 | // DriverName returns the database driver name stored in Driver.
73 | func (db *Db) DriverName() string {
74 | 	return db.Driver
75 | }
76 |
77 | // Exec runs query on the underlying *sql.DB without returning rows.
78 | func (db *Db) Exec(query string, args ...interface{}) (sql.Result, error) {
79 | 	return db.SqlDB.Exec(query, args...)
80 | }
81 |
82 | // Prepare creates a prepared statement on the underlying *sql.DB.
83 | func (db *Db) Prepare(query string) (*sql.Stmt, error) {
84 | 	return db.SqlDB.Prepare(query)
85 | }
86 |
87 | // Query runs query on the underlying *sql.DB and returns the resulting rows.
88 | func (db *Db) Query(query string, args ...interface{}) (*sql.Rows, error) {
89 | 	return db.SqlDB.Query(query, args...)
90 | }
91 |
92 | // QueryRow runs query on the underlying *sql.DB, expecting at most one row.
93 | func (db *Db) QueryRow(query string, args ...interface{}) *sql.Row {
94 | 	return db.SqlDB.QueryRow(query, args...)
95 | }
96 |
--------------------------------------------------------------------------------
/base/Link.go:
--------------------------------------------------------------------------------
1 | package base
2 |
3 | import "database/sql"
4 |
5 | // Link is the database handle the builder and migrator operate on.
6 | // Both *Db and *Tx in this package implement it.
7 | type Link interface {
8 | 	// GetDebugMode reports whether executed SQL should be printed.
9 | 	GetDebugMode() bool
10 | 	// DriverName returns the database driver name, used to pick dialect-specific SQL.
11 | 	DriverName() string
12 | 	Exec(query string, args ...interface{}) (sql.Result, error)
13 | 	Prepare(query string) (*sql.Stmt, error)
14 | 	Query(query string, args ...interface{}) (*sql.Rows, error)
15 | 	QueryRow(query string, args ...interface{}) *sql.Row
16 | }
17 |
--------------------------------------------------------------------------------
/base/Tx.go:
--------------------------------------------------------------------------------
1 | package base
2 |
3 | import "database/sql"
4 |
5 | // Tx wraps a *sql.Tx together with the driver name and debug flag inherited
6 | // from the Db that started it.
7 | type Tx struct {
8 | 	driver    string
9 | 	debugMode bool
10 | 	sqlTx     *sql.Tx
11 | }
12 |
13 | // GetDebugMode reports whether debug mode is enabled for this transaction.
14 | func (tx *Tx) GetDebugMode() bool {
15 | 	return tx.debugMode
16 | }
17 |
18 | // DriverName returns the database driver name.
19 | func (tx *Tx) DriverName() string {
20 | 	return tx.driver
21 | }
22 |
23 | // Exec runs query inside the transaction without returning rows.
24 | func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
25 | 	return tx.sqlTx.Exec(query, args...)
26 | }
27 |
28 | // Prepare creates a prepared statement bound to the transaction.
29 | func (tx *Tx) Prepare(query string) (*sql.Stmt, error) {
30 | 	return tx.sqlTx.Prepare(query)
31 | }
32 |
33 | // Query runs query inside the transaction and returns the resulting rows.
34 | func (tx *Tx) Query(query string, args ...interface{}) (*sql.Rows, error) {
35 | 	return tx.sqlTx.Query(query, args...)
36 | }
37 |
38 | // QueryRow runs query inside the transaction, expecting at most one row.
39 | func (tx *Tx) QueryRow(query string, args ...interface{}) *sql.Row {
40 | 	return tx.sqlTx.QueryRow(query, args...)
41 | }
42 |
43 | // Rollback aborts the transaction.
44 | func (tx *Tx) Rollback() error {
45 | 	return tx.sqlTx.Rollback()
46 | }
47 |
48 | // Commit commits the transaction.
49 | func (tx *Tx) Commit() error {
50 | 	return tx.sqlTx.Commit()
51 | }
52 |
--------------------------------------------------------------------------------
/builder/aggregation.go:
--------------------------------------------------------------------------------
1 | package builder
2 |
3 | import "github.com/tangpanqing/aorm/null"
4 |
5 | // IntStruct receives an integer aggregate result selected under the alias "c".
6 | type IntStruct struct {
7 | 	C null.Int
8 | }
9 |
10 | // FloatStruct receives a floating-point aggregate result selected under the alias "c".
11 | type FloatStruct struct {
12 | 	C null.Float
13 | }
14 |
15 | // Count returns the count aggregate of fieldName for the current query. It
16 | // returns 0 when the query yields no rows, e.g. a grouped query matching nothing.
17 | func (b *Builder) Count(fieldName interface{}) (int64, error) {
18 | 	var obj []IntStruct
19 | 	if err := b.SelectCount(fieldName, "c", "").GetMany(&obj); err != nil {
20 | 		return 0, err
21 | 	}
22 | 	// Without this guard obj[0] panics on an empty result set.
23 | 	if len(obj) == 0 {
24 | 		return 0, nil
25 | 	}
26 |
27 | 	return obj[0].C.Int64, nil
28 | }
29 |
30 | // Sum returns the sum aggregate of fieldName, or 0 when no row is returned.
31 | func (b *Builder) Sum(fieldName interface{}) (float64, error) {
32 | 	var obj []FloatStruct
33 | 	if err := b.SelectSum(fieldName, "c").GetMany(&obj); err != nil {
34 | 		return 0, err
35 | 	}
36 | 	if len(obj) == 0 {
37 | 		return 0, nil
38 | 	}
39 |
40 | 	return obj[0].C.Float64, nil
41 | }
42 |
43 | // Avg returns the average aggregate of fieldName, or 0 when no row is returned.
44 | func (b *Builder) Avg(fieldName interface{}) (float64, error) {
45 | 	var obj []FloatStruct
46 | 	if err := b.SelectAvg(fieldName, "c").GetMany(&obj); err != nil {
47 | 		return 0, err
48 | 	}
49 | 	if len(obj) == 0 {
50 | 		return 0, nil
51 | 	}
52 |
53 | 	return obj[0].C.Float64, nil
54 | }
55 |
56 | // Max returns the maximum aggregate of fieldName, or 0 when no row is returned.
57 | func (b *Builder) Max(fieldName interface{}) (float64, error) {
58 | 	var obj []FloatStruct
59 | 	if err := b.SelectMax(fieldName, "c").GetMany(&obj); err != nil {
60 | 		return 0, err
61 | 	}
62 | 	if len(obj) == 0 {
63 | 		return 0, nil
64 | 	}
65 |
66 | 	return obj[0].C.Float64, nil
67 | }
68 |
69 | // Min returns the minimum aggregate of fieldName, or 0 when no row is returned.
70 | func (b *Builder) Min(fieldName interface{}) (float64, error) {
71 | 	var obj []FloatStruct
72 | 	if err := b.SelectMin(fieldName, "c").GetMany(&obj); err != nil {
73 | 		return 0, err
74 | 	}
75 | 	if len(obj) == 0 {
76 | 		return 0, nil
77 | 	}
78 |
79 | 	return obj[0].C.Float64, nil
80 | }
81 |
--------------------------------------------------------------------------------
/builder/builder.go:
--------------------------------------------------------------------------------
1 | package builder
2 |
3 | import (
4 | 	"fmt"
5 | 	"reflect"
6 | 	"strings"
7 |
8 | 	"github.com/tangpanqing/aorm/utils"
9 | )
10 |
11 | // GroupItem is one GROUP BY entry.
12 | type GroupItem struct {
13 | 	Prefix []string
14 | 	Field  interface{}
15 | }
16 |
17 | // WhereItem is one WHERE or HAVING condition (Field Opt Val); Prefix, when
18 | // given, overrides the table qualifier of Field.
19 | type WhereItem struct {
20 | 	Prefix []string
21 | 	Field  interface{}
22 | 	Opt    string
23 | 	Val    interface{}
24 | }
25 |
26 | // SelectItem is one entry of the select list.
27 | type SelectItem struct {
28 | 	FuncName string
29 | 	Prefix   []string
30 | 	Field    interface{}
31 | 	FieldNew interface{}
32 | }
33 |
34 | // SelectExpItem is a select-list entry whose value comes from a sub-query builder.
35 | type SelectExpItem struct {
36 | 	Builder   **Builder
37 | 	FieldName interface{}
38 | }
39 |
40 | // OrderItem is one ORDER BY entry.
41 | type OrderItem struct {
42 | 	Prefix    []string
43 | 	Field     interface{}
44 | 	OrderType string
45 | }
46 |
47 | // LimitItem holds the offset and page size set by Limit or Page.
48 | type LimitItem struct {
49 | 	offset   int
50 | 	pageSize int
51 | }
52 |
53 | // JoinItem describes one JOIN clause.
54 | type JoinItem struct {
55 | 	joinType   string
56 | 	table      interface{}
57 | 	tableAlias []string
58 | 	condition  []JoinCondition
59 | }
60 |
61 | // JoinCondition is one ON condition of a JOIN clause.
62 | type JoinCondition struct {
63 | 	FieldOfCurrentTable interface{}
64 | 	Opt                 string
65 | 	FieldOfOtherTable   interface{}
66 | 	AliasOfOtherTable   []string
67 | }
68 |
69 | // GenWhereItem builds a WhereItem for a WHERE condition list.
70 | func GenWhereItem(field interface{}, opt string, val interface{}, prefix ...string) WhereItem {
71 | 	return WhereItem{prefix, field, opt, val}
72 | }
73 |
74 | // GenHavingItem builds a WhereItem without a prefix for a HAVING condition list.
75 | func GenHavingItem(field interface{}, opt string, val interface{}) WhereItem {
76 | 	return WhereItem{[]string{}, field, opt, val}
77 | }
78 |
79 | // GenJoinCondition builds a JoinCondition for a JOIN's ON clause.
80 | func GenJoinCondition(fieldOfCurrentTable interface{}, opt string, fieldOfOtherTable interface{}, aliasOfOtherTable ...string) JoinCondition {
81 | 	return JoinCondition{
82 | 		FieldOfCurrentTable: fieldOfCurrentTable,
83 | 		Opt:                 opt,
84 | 		FieldOfOtherTable:   fieldOfOtherTable,
85 | 		AliasOfOtherTable:   aliasOfOtherTable,
86 | 	}
87 | }
88 |
89 | // getPrefixByField returns the table qualifier for a field reference. An
90 | // explicit prefix[0] wins; otherwise, if valueOf is a pointer, the table name
91 | // cached for it via Store is used (only the part after the last ".", passed
92 | // through utils.UnderLine). Any other value yields "".
93 | func getPrefixByField(valueOf reflect.Value, prefix ...string) string {
94 | 	if len(prefix) > 0 {
95 | 		return prefix[0]
96 | 	}
97 | 	if valueOf.Kind() != reflect.Ptr {
98 | 		return ""
99 | 	}
100 |
101 | 	tablePointer := getFieldMap(valueOf.Pointer()).TablePointer
102 | 	tableName := getTableMap(tablePointer)
103 | 	strArr := strings.Split(tableName, ".")
104 | 	return utils.UnderLine(strArr[len(strArr)-1])
105 | }
106 |
107 | // getTableNameByTable resolves a table argument: a pointer registered via
108 | // Store maps to its cached table name, anything else is formatted with %v.
109 | func getTableNameByTable(table interface{}) string {
110 | 	valueOf := reflect.ValueOf(table)
111 | 	if valueOf.Kind() == reflect.Ptr {
112 | 		return getTableMap(valueOf.Pointer())
113 | 	}
114 | 	return fmt.Sprintf("%v", table)
115 | }
116 |
117 | // getTableNameByReflect derives a model's table name: the result of its
118 | // TableName method when it has one, otherwise utils.UnderLine of the last
119 | // dot-separated segment of the type name.
120 | func getTableNameByReflect(typeOf reflect.Type, valueOf reflect.Value) string {
121 | 	if method, ok := typeOf.MethodByName("TableName"); ok {
122 | 		res := method.Func.Call([]reflect.Value{valueOf})
123 | 		return res[0].String()
124 | 	}
125 |
126 | 	arr := strings.Split(typeOf.String(), ".")
127 | 	return utils.UnderLine(arr[len(arr)-1])
128 | }
129 |
130 | // getFieldNameByField returns the column name for a field reference.
131 | func getFieldNameByField(field interface{}) string {
132 | 	return getFieldNameByReflectValue(reflect.ValueOf(field))
133 | }
134 |
135 | // getFieldNameByReflectValue returns the cached column name when the value is
136 | // a pointer (a field registered via Store); otherwise the value is formatted
137 | // with %v, so a plain string is used verbatim as the column name.
138 | func getFieldNameByReflectValue(valueOfField reflect.Value) string {
139 | 	if valueOfField.Kind() == reflect.Ptr {
140 | 		return getFieldMap(valueOfField.Pointer()).Name
141 | 	}
142 | 	return fmt.Sprintf("%v", valueOfField)
143 | }
144 |
145 | // getFieldNameByStructField returns the column name of a struct field and its
146 | // parsed `aorm` tag. The name is utils.UnderLine of the Go field name unless
147 | // the tag's "column" entry overrides it.
148 | func getFieldNameByStructField(field reflect.StructField) (string, map[string]string) {
149 | 	key := utils.UnderLine(field.Name)
150 | 	tagMap := getTagMap(field.Tag.Get("aorm"))
151 | 	if column, ok := tagMap["column"]; ok {
152 | 		key = column
153 | 	}
154 | 	return key, tagMap
155 | }
156 |
157 | // getFieldMapByReflect maps the Go field names of destType to their index path
158 | // for reflect.Value.Field. Struct-typed fields are flattened one level (their
159 | // own fields get a two-element path), except struct types named Int, Float,
160 | // Time, String or Bool (value wrappers such as null.Int), which map directly.
161 | func getFieldMapByReflect(destType reflect.Type) map[string][]int {
162 | 	fieldNameMap := make(map[string][]int)
163 | 	for i := 0; i < destType.NumField(); i++ {
164 | 		field := destType.Field(i)
165 |
166 | 		isMultiLevel := false
167 | 		if field.Type.Kind() == reflect.Struct {
168 | 			switch field.Type.Name() {
169 | 			case "Int", "Float", "Time", "String", "Bool":
170 | 			default:
171 | 				isMultiLevel = true
172 | 			}
173 | 		}
174 |
175 | 		if isMultiLevel {
176 | 			for j := 0; j < field.Type.NumField(); j++ {
177 | 				fieldNameMap[field.Type.Field(j).Name] = []int{i, j}
178 | 			}
179 | 		} else {
180 | 			fieldNameMap[field.Name] = []int{i}
181 | 		}
182 | 	}
183 | 	return fieldNameMap
184 | }
185 |
186 | // getScansAddr builds the rows.Scan destinations for one row: for each column
187 | // (lower-cased, then converted with utils.CamelString) the address of the
188 | // matching field of destValue, or a throwaway interface{} when no field matches.
189 | func getScansAddr(columnNameList []string, fieldNameMap map[string][]int, destValue reflect.Value) []interface{} {
190 | 	scans := make([]interface{}, 0, len(columnNameList))
191 | 	for _, columnName := range columnNameList {
192 | 		index, ok := fieldNameMap[utils.CamelString(strings.ToLower(columnName))]
193 | 		if !ok {
194 | 			// Unknown column: scan into a discarded value so Scan still succeeds.
195 | 			var emptyVal interface{}
196 | 			scans = append(scans, &emptyVal)
197 | 			continue
198 | 		}
199 |
200 | 		t := destValue
201 | 		for _, idx := range index {
202 | 			t = t.Field(idx)
203 | 		}
204 | 		scans = append(scans, t.Addr().Interface())
205 | 	}
206 | 	return scans
207 | }
208 |
209 | // genJoinConditionStr renders JOIN ON conditions joined with " AND " and
210 | // returns their bound parameters. RawEq compares two columns; Eq binds
211 | // getFieldNameByField(FieldOfOtherTable) as a parameter. If aliasOfCurrentTable
212 | // is empty, it is derived from the conditions' current-table field.
213 | func genJoinConditionStr(aliasOfCurrentTable string, joinCondition []JoinCondition) (string, []interface{}) {
214 | 	var paramList []interface{}
215 | 	var sqlList []string
216 | 	for i := 0; i < len(joinCondition); i++ {
217 | 		fieldNameOfCurrentTable := getFieldNameByField(joinCondition[i].FieldOfCurrentTable)
218 |
219 | 		if aliasOfCurrentTable == "" {
220 | 			aliasOfCurrentTable = getPrefixByField(reflect.ValueOf(joinCondition[i].FieldOfCurrentTable))
221 | 		}
222 |
223 | 		fieldNameOfOtherTable := getFieldNameByField(joinCondition[i].FieldOfOtherTable)
224 |
225 | 		if joinCondition[i].Opt == RawEq {
226 | 			aliasOfOtherTable := getPrefixByField(reflect.ValueOf(joinCondition[i].FieldOfOtherTable), joinCondition[i].AliasOfOtherTable...)
227 | 			if aliasOfOtherTable != "" {
228 | 				aliasOfOtherTable += "."
229 | 			}
230 |
231 | 			sqlList = append(sqlList, aliasOfCurrentTable+"."+fieldNameOfCurrentTable+"="+aliasOfOtherTable+fieldNameOfOtherTable)
232 | 		}
233 |
234 | 		if joinCondition[i].Opt == Eq {
235 | 			sqlList = append(sqlList, aliasOfCurrentTable+"."+fieldNameOfCurrentTable+"=?")
236 | 			paramList = append(paramList, fieldNameOfOtherTable)
237 | 		}
238 | 	}
239 |
240 | 	return strings.Join(sqlList, " AND "), paramList
241 | }
242 |
243 | // toAnyArr flattens a slice or array of any element type (e.g. []int,
244 | // []string, []any, []int32) into a []any, element by element, so IN, BETWEEN
245 | // and LIKE values can be expanded into bound parameters. A value that is not
246 | // a slice or array (including nil) yields nil.
247 | func toAnyArr(val any) []any {
248 | 	rv := reflect.ValueOf(val)
249 | 	if rv.Kind() != reflect.Slice && rv.Kind() != reflect.Array {
250 | 		return nil
251 | 	}
252 |
253 | 	values := make([]any, 0, rv.Len())
254 | 	for i := 0; i < rv.Len(); i++ {
255 | 		values = append(values, rv.Index(i).Interface())
256 | 	}
257 | 	return values
258 | }
259 |
--------------------------------------------------------------------------------
/builder/cache.go:
--------------------------------------------------------------------------------
1 | package
builder 2 | 3 | import ( 4 | "reflect" 5 | ) 6 | 7 | type FieldInfo struct { 8 | TablePointer uintptr 9 | Name string 10 | TagMap map[string]string 11 | } 12 | 13 | var TableMap = make(map[uintptr]string) 14 | var FieldMap = make(map[uintptr]FieldInfo) 15 | 16 | //Store 保存到缓存 17 | func Store(destList ...interface{}) { 18 | for i := 0; i < len(destList); i++ { 19 | dest := destList[i] 20 | valueOf := reflect.ValueOf(dest) 21 | typeof := reflect.TypeOf(dest) 22 | 23 | tablePointer := valueOf.Pointer() 24 | tableName := getTableNameByReflect(typeof, valueOf) 25 | setTableMap(tablePointer, tableName) 26 | 27 | for j := 0; j < valueOf.Elem().NumField(); j++ { 28 | fieldPointer := valueOf.Elem().Field(j).Addr().Pointer() 29 | key, _ := getFieldNameByStructField(typeof.Elem().Field(j)) 30 | tag := typeof.Elem().Field(j).Tag.Get("aorm") 31 | 32 | setFieldMap(fieldPointer, FieldInfo{ 33 | TablePointer: tablePointer, 34 | Name: key, 35 | TagMap: getTagMap(tag), 36 | }) 37 | } 38 | } 39 | } 40 | 41 | func Comment(field interface{}) string { 42 | fieldPointer := reflect.ValueOf(field).Pointer() 43 | tagMap := getFieldMap(fieldPointer).TagMap 44 | val, ok := tagMap["comment"] 45 | if ok { 46 | return val 47 | } else { 48 | return "" 49 | } 50 | } 51 | 52 | func setTableMap(tablePointer uintptr, name string) { 53 | TableMap[tablePointer] = name 54 | } 55 | 56 | func getTableMap(tablePointer uintptr) string { 57 | return TableMap[tablePointer] 58 | } 59 | 60 | func setFieldMap(fieldPointer uintptr, fieldInfo FieldInfo) { 61 | FieldMap[fieldPointer] = fieldInfo 62 | } 63 | 64 | func getFieldMap(fieldPointer uintptr) FieldInfo { 65 | return FieldMap[fieldPointer] 66 | } 67 | -------------------------------------------------------------------------------- /builder/crud.go: -------------------------------------------------------------------------------- 1 | package builder 2 | 3 | import ( 4 | "database/sql" 5 | "errors" 6 | "fmt" 7 | "github.com/tangpanqing/aorm/base" 8 | 
"github.com/tangpanqing/aorm/driver" 9 | "reflect" 10 | "strconv" 11 | "strings" 12 | ) 13 | 14 | const Desc = "DESC" 15 | const Asc = "ASC" 16 | 17 | const Eq = "=" 18 | const Ne = "!=" 19 | const Gt = ">" 20 | const Ge = ">=" 21 | const Lt = "<" 22 | const Le = "<=" 23 | 24 | const In = "IN" 25 | const NotIn = "NOT IN" 26 | const Like = "LIKE" 27 | const NotLike = "NOT LIKE" 28 | const Between = "BETWEEN" 29 | const NotBetween = "NOT BETWEEN" 30 | 31 | const Raw = "Raw" 32 | const FindInSet = "FindInSet" 33 | const RawEq = "RawEq" 34 | 35 | // Builder 查询记录所需要的条件 36 | type Builder struct { 37 | Link base.Link 38 | 39 | table interface{} 40 | tableAlias string 41 | 42 | selectList []SelectItem 43 | selectExpList []*SelectExpItem 44 | groupList []GroupItem 45 | whereList []WhereItem 46 | joinList []JoinItem 47 | havingList []WhereItem 48 | orderList []OrderItem 49 | limitItem LimitItem 50 | 51 | distinct bool 52 | isDebug bool 53 | isLockForUpdate bool 54 | 55 | //sql与参数 56 | query string 57 | args []interface{} 58 | } 59 | 60 | // Debug 链式操作-是否开启调试,打印sql 61 | func (b *Builder) Debug(isDebug bool) *Builder { 62 | b.isDebug = isDebug 63 | return b 64 | } 65 | 66 | // Distinct 过滤重复记录 67 | func (b *Builder) Distinct(distinct bool) *Builder { 68 | b.distinct = distinct 69 | return b 70 | } 71 | 72 | // Table 链式操作-从哪个表查询,允许直接写别名,例如 person p 73 | func (b *Builder) Table(table interface{}, alias ...string) *Builder { 74 | b.table = table 75 | if len(alias) > 0 { 76 | b.tableAlias = alias[0] 77 | } 78 | return b 79 | } 80 | 81 | // Insert 增加记录 82 | func (b *Builder) Insert(dest interface{}) (int64, error) { 83 | typeOf := reflect.TypeOf(dest) 84 | valueOf := reflect.ValueOf(dest) 85 | 86 | //主键名字 87 | var primaryKey = "" 88 | 89 | var keys []string 90 | var args []any 91 | var place []string 92 | for i := 0; i < typeOf.Elem().NumField(); i++ { 93 | key, tagMap := getFieldNameByStructField(typeOf.Elem().Field(i)) 94 | 95 | //如果是Postgres数据库,寻找主键 96 | if b.Link.DriverName() == 
driver.Postgres { 97 | if _, ok := tagMap["primary"]; ok { 98 | primaryKey = key 99 | } 100 | } 101 | 102 | isNotNull := valueOf.Elem().Field(i).Field(0).Field(1).Bool() 103 | if isNotNull { 104 | val := valueOf.Elem().Field(i).Field(0).Field(0).Interface() 105 | 106 | keys = append(keys, key) 107 | args = append(args, val) 108 | place = append(place, "?") 109 | } 110 | } 111 | 112 | query := "INSERT INTO " + b.getTableNameCommon(typeOf, valueOf) + " (" + strings.Join(keys, ",") + ") VALUES (" + strings.Join(place, ",") + ")" 113 | 114 | if b.Link.DriverName() == driver.Mssql { 115 | return b.insertForMssqlOrPostgres(query+"; SELECT SCOPE_IDENTITY()", args...) 116 | } else if b.Link.DriverName() == driver.Postgres { 117 | return b.insertForMssqlOrPostgres(convertToPostgresSql(query)+" RETURNING "+primaryKey, args...) 118 | } else { 119 | return b.insertForCommon(query, args...) 120 | } 121 | } 122 | 123 | //对于Mssql,Postgres类型数据库,为了获取最后插入的id,需要改写入为查询 124 | func (b *Builder) insertForMssqlOrPostgres(query string, args ...any) (int64, error) { 125 | if b.isDebug { 126 | fmt.Println(query) 127 | fmt.Println(args...) 128 | } 129 | 130 | rows, err := b.Link.Query(query, args...) 
131 | if err != nil { 132 | return 0, err 133 | } 134 | defer rows.Close() 135 | var lastInsertId1 int64 136 | for rows.Next() { 137 | rows.Scan(&lastInsertId1) 138 | } 139 | return lastInsertId1, nil 140 | } 141 | 142 | //对于非Mssql,Postgres类型数据库,可以直接获取最后插入的id 143 | func (b *Builder) insertForCommon(query string, args ...any) (int64, error) { 144 | res, err := b.RawSql(query, args...).Exec() 145 | if err != nil { 146 | return 0, err 147 | } 148 | 149 | lastId, err := res.LastInsertId() 150 | if err != nil { 151 | return 0, err 152 | } 153 | 154 | return lastId, nil 155 | } 156 | 157 | // InsertBatch 批量增加记录 158 | func (b *Builder) InsertBatch(values interface{}) (int64, error) { 159 | var keys []string 160 | var args []any 161 | var place []string 162 | 163 | valueOf := reflect.ValueOf(values).Elem() 164 | 165 | if valueOf.Len() == 0 { 166 | return 0, errors.New("the data list for insert batch not found") 167 | } 168 | typeOf := reflect.TypeOf(values).Elem().Elem() 169 | 170 | for j := 0; j < valueOf.Len(); j++ { 171 | var placeItem []string 172 | 173 | for i := 0; i < valueOf.Index(j).Elem().NumField(); i++ { 174 | isNotNull := valueOf.Index(j).Elem().Field(i).Field(0).Field(1).Bool() 175 | if isNotNull { 176 | if j == 0 { 177 | key, _ := getFieldNameByStructField(typeOf.Elem().Field(i)) 178 | keys = append(keys, key) 179 | } 180 | 181 | val := valueOf.Index(j).Elem().Field(i).Field(0).Field(0).Interface() 182 | args = append(args, val) 183 | placeItem = append(placeItem, "?") 184 | } 185 | } 186 | 187 | place = append(place, "("+strings.Join(placeItem, ",")+")") 188 | } 189 | 190 | query := "INSERT INTO " + b.getTableNameCommon(typeOf, valueOf.Index(0)) + " (" + strings.Join(keys, ",") + ") VALUES " + strings.Join(place, ",") 191 | 192 | if b.Link.DriverName() == driver.Postgres { 193 | query = convertToPostgresSql(query) 194 | } 195 | 196 | res, err := b.RawSql(query, args...).Exec() 197 | if err != nil { 198 | return 0, err 199 | } 200 | 201 | count, err := 
res.RowsAffected() 202 | if err != nil { 203 | return 0, err 204 | } 205 | 206 | return count, nil 207 | } 208 | 209 | // GetMany 查询记录(新) 210 | func (b *Builder) GetMany(values interface{}) error { 211 | stmt, rows, errRows := b.GetRows() 212 | if errRows != nil { 213 | return errRows 214 | } 215 | defer stmt.Close() 216 | defer rows.Close() 217 | 218 | destSlice := reflect.Indirect(reflect.ValueOf(values)) 219 | destType := destSlice.Type().Elem() 220 | destValue := reflect.New(destType).Elem() 221 | 222 | //从数据库中读出来的字段名字 223 | columnNameList, errColumns := rows.Columns() 224 | if errColumns != nil { 225 | return errColumns 226 | } 227 | 228 | //从结构体反射出来的属性名 229 | fieldNameMap := getFieldMapByReflect(destType) 230 | 231 | for rows.Next() { 232 | scans := getScansAddr(columnNameList, fieldNameMap, destValue) 233 | 234 | errScan := rows.Scan(scans...) 235 | if errScan != nil { 236 | return errScan 237 | } 238 | 239 | destSlice.Set(reflect.Append(destSlice, destValue)) 240 | } 241 | 242 | return nil 243 | } 244 | 245 | // GetOne 查询某一条记录 246 | func (b *Builder) GetOne(obj interface{}) error { 247 | b.Limit(0, 1) 248 | 249 | stmt, rows, errRows := b.GetRows() 250 | if errRows != nil { 251 | return errRows 252 | } 253 | defer stmt.Close() 254 | defer rows.Close() 255 | 256 | if rows.Next() { 257 | destType := reflect.TypeOf(obj).Elem() 258 | destValue := reflect.ValueOf(obj).Elem() 259 | 260 | //从数据库中读出来的字段名字 261 | columnNameList, errColumns := rows.Columns() 262 | if errColumns != nil { 263 | return errColumns 264 | } 265 | 266 | //从结构体反射出来的属性名 267 | fieldNameMap := getFieldMapByReflect(destType) 268 | 269 | scans := getScansAddr(columnNameList, fieldNameMap, destValue) 270 | err := rows.Scan(scans...) 
271 | if err != nil { 272 | return err 273 | } 274 | 275 | return nil 276 | } else { 277 | return errors.New("NOT FOUND") 278 | } 279 | } 280 | 281 | // Update 更新记录 282 | func (b *Builder) Update(dest interface{}) (int64, error) { 283 | typeOf := reflect.TypeOf(dest) 284 | valueOf := reflect.ValueOf(dest) 285 | 286 | var args []any 287 | setStr, args := b.handleSet(typeOf, valueOf, args) 288 | whereStr, args, err := b.handleWhere(args, false) 289 | if err != nil { 290 | return 0, err 291 | } 292 | query := "UPDATE " + b.getTableNameCommon(typeOf, valueOf) + setStr + whereStr 293 | 294 | return b.execAffected(query, args...) 295 | } 296 | 297 | // Delete 删除记录 298 | func (b *Builder) Delete(destList ...interface{}) (int64, error) { 299 | tableName := "" 300 | 301 | if len(destList) > 0 { 302 | b.Where(destList[0]) 303 | 304 | typeOf := reflect.TypeOf(destList[0]) 305 | valueOf := reflect.ValueOf(destList[0]) 306 | tableName = b.getTableNameCommon(typeOf, valueOf) 307 | } 308 | 309 | if tableName == "" { 310 | if b.table == nil { 311 | return 0, errors.New("表名不能为空") 312 | } 313 | tableName = getTableNameByTable(b.table) 314 | } 315 | 316 | var args []any 317 | whereStr, args, err := b.handleWhere(args, false) 318 | if err != nil { 319 | return 0, err 320 | } 321 | query := "DELETE FROM " + tableName + whereStr 322 | 323 | return b.execAffected(query, args...) 
324 | } 325 | 326 | // GroupBy 链式操作,以某字段进行分组 327 | func (b *Builder) GroupBy(field interface{}, prefix ...string) *Builder { 328 | b.groupList = append(b.groupList, GroupItem{ 329 | Prefix: prefix, 330 | Field: field, 331 | }) 332 | return b 333 | } 334 | 335 | // Limit 链式操作,分页 336 | func (b *Builder) Limit(offset int, pageSize int) *Builder { 337 | b.limitItem = LimitItem{ 338 | offset: offset, 339 | pageSize: pageSize, 340 | } 341 | return b 342 | } 343 | 344 | // Page 链式操作,分页 345 | func (b *Builder) Page(pageNum int, pageSize int) *Builder { 346 | b.limitItem = LimitItem{ 347 | offset: (pageNum - 1) * pageSize, 348 | pageSize: pageSize, 349 | } 350 | return b 351 | } 352 | 353 | // LockForUpdate 加锁, sqlite3不支持此操作 354 | func (b *Builder) LockForUpdate(isLockForUpdate bool) *Builder { 355 | b.isLockForUpdate = isLockForUpdate 356 | return b 357 | } 358 | 359 | // Truncate 清空记录 360 | func (b *Builder) Truncate() (int64, error) { 361 | if b.table == nil { 362 | return 0, errors.New("表名不能为空") 363 | } 364 | 365 | query := "" 366 | if b.Link.DriverName() == driver.Sqlite3 { 367 | query = "DELETE FROM " + getTableNameByTable(b.table) 368 | } else { 369 | query = "TRUNCATE TABLE " + getTableNameByTable(b.table) 370 | } 371 | 372 | return b.execAffected(query) 373 | } 374 | 375 | // RawSql 执行原始的sql语句 376 | func (b *Builder) RawSql(query string, args ...interface{}) *Builder { 377 | b.query = query 378 | b.args = args 379 | return b 380 | } 381 | 382 | // GetRows 获取行操作 383 | func (b *Builder) GetRows() (*sql.Stmt, *sql.Rows, error) { 384 | query, args, err := b.GetSqlAndParams() 385 | if err != nil { 386 | return nil, nil, err 387 | } 388 | 389 | if b.Link.DriverName() == driver.Postgres { 390 | query = convertToPostgresSql(query) 391 | } 392 | 393 | if b.isDebug { 394 | fmt.Println(query) 395 | fmt.Println(args...) 
396 | } 397 | 398 | smt, errSmt := b.Link.Prepare(query) 399 | if errSmt != nil { 400 | return nil, nil, errSmt 401 | } 402 | 403 | rows, errRows := smt.Query(args...) 404 | if errRows != nil { 405 | return nil, nil, errRows 406 | } 407 | 408 | return smt, rows, nil 409 | } 410 | 411 | // Exec 通用执行-新增,更新,删除 412 | func (b *Builder) Exec() (sql.Result, error) { 413 | if b.Link.DriverName() == driver.Postgres { 414 | b.query = convertToPostgresSql(b.query) 415 | } 416 | 417 | if b.isDebug { 418 | fmt.Println(b.query) 419 | fmt.Println(b.args...) 420 | } 421 | 422 | smt, err1 := b.Link.Prepare(b.query) 423 | if err1 != nil { 424 | return nil, err1 425 | } 426 | defer smt.Close() 427 | 428 | res, err2 := smt.Exec(b.args...) 429 | if err2 != nil { 430 | return nil, err2 431 | } 432 | 433 | //b.clear() 434 | return res, nil 435 | } 436 | 437 | //拼接SQL,查询与筛选通用操作 438 | func (b *Builder) whereAndHaving(where []WhereItem, args []any, isFromHaving bool, needPrefix bool) ([]string, []any, error) { 439 | var whereList []string 440 | for i := 0; i < len(where); i++ { 441 | valueOfField := reflect.ValueOf(where[i].Field) 442 | 443 | allFieldName := "" 444 | if needPrefix { 445 | prefix := getPrefixByField(valueOfField, where[i].Prefix...) 446 | if prefix != "" { 447 | allFieldName += prefix + "." 
448 | } 449 | } 450 | 451 | //如果是mssql或者Postgres,并且来自having的话,需要特殊处理 452 | if (b.Link.DriverName() == driver.Mssql || b.Link.DriverName() == driver.Postgres) && isFromHaving { 453 | fieldNameCurrent := getFieldNameByReflectValue(valueOfField) 454 | for m := 0; m < len(b.selectList); m++ { 455 | if fieldNameCurrent == getFieldNameByField(b.selectList[m].FieldNew) { 456 | allFieldName += handleSelectWith(b.selectList[m]) 457 | } 458 | } 459 | } else { 460 | allFieldName += getFieldNameByReflectValue(valueOfField) 461 | } 462 | 463 | if "**builder.Builder" == reflect.TypeOf(where[i].Val).String() { 464 | subBuilder := *(**Builder)(reflect.ValueOf(where[i].Val).UnsafePointer()) 465 | subSql, subParams, err := subBuilder.GetSqlAndParams() 466 | if err != nil { 467 | return whereList, args, err 468 | } 469 | 470 | if where[i].Opt != Raw { 471 | whereList = append(whereList, allFieldName+" "+where[i].Opt+" "+"("+subSql+")") 472 | args = append(args, subParams...) 473 | } 474 | } else { 475 | if where[i].Opt == Eq || where[i].Opt == Ne || where[i].Opt == Gt || where[i].Opt == Ge || where[i].Opt == Lt || where[i].Opt == Le { 476 | if b.Link.DriverName() == driver.Sqlite3 { 477 | whereList = append(whereList, allFieldName+" "+where[i].Opt+" "+"?") 478 | } else { 479 | switch where[i].Val.(type) { 480 | case float32: 481 | whereList = append(whereList, b.getConcatForFloat(allFieldName, "''")+" "+where[i].Opt+" "+"?") 482 | case float64: 483 | whereList = append(whereList, b.getConcatForFloat(allFieldName, "''")+" "+where[i].Opt+" "+"?") 484 | default: 485 | whereList = append(whereList, allFieldName+" "+where[i].Opt+" "+"?") 486 | } 487 | } 488 | 489 | args = append(args, fmt.Sprintf("%v", where[i].Val)) 490 | } 491 | 492 | if where[i].Opt == Between || where[i].Opt == NotBetween { 493 | values := toAnyArr(where[i].Val) 494 | whereList = append(whereList, allFieldName+" "+where[i].Opt+" "+"(?) AND (?)") 495 | args = append(args, values...) 
496 | } 497 | 498 | if where[i].Opt == Like || where[i].Opt == NotLike { 499 | values := toAnyArr(where[i].Val) 500 | var valueStr []string 501 | for j := 0; j < len(values); j++ { 502 | str := fmt.Sprintf("%v", values[j]) 503 | 504 | if "%" != str { 505 | args = append(args, str) 506 | valueStr = append(valueStr, "?") 507 | } else { 508 | valueStr = append(valueStr, "'"+str+"'") 509 | } 510 | } 511 | 512 | whereList = append(whereList, allFieldName+" "+where[i].Opt+" "+b.getConcatForLike(valueStr...)) 513 | } 514 | 515 | if where[i].Opt == In || where[i].Opt == NotIn { 516 | values := toAnyArr(where[i].Val) 517 | var placeholder []string 518 | for j := 0; j < len(values); j++ { 519 | placeholder = append(placeholder, "?") 520 | } 521 | 522 | whereList = append(whereList, allFieldName+" "+where[i].Opt+" "+"("+strings.Join(placeholder, ",")+")") 523 | args = append(args, values...) 524 | } 525 | 526 | if where[i].Opt == FindInSet { 527 | whereList = append(whereList, "FIND_IN_SET(?,"+getPrefixByField(reflect.ValueOf(where[i].Field), where[i].Prefix...)+"."+getFieldNameByField(where[i].Field)+")") 528 | args = append(args, where[i].Val) 529 | } 530 | 531 | if where[i].Opt == Raw { 532 | whereList = append(whereList, allFieldName+" "+fmt.Sprintf("%v", where[i].Val)) 533 | } 534 | 535 | if where[i].Opt == RawEq { 536 | whereList = append(whereList, allFieldName+Eq+getPrefixByField(reflect.ValueOf(where[i].Val))+"."+getFieldNameByField(where[i].Val)) 537 | } 538 | } 539 | } 540 | return whereList, args, nil 541 | } 542 | 543 | func (b *Builder) getConcatForFloat(vars ...string) string { 544 | if b.Link.DriverName() == driver.Sqlite3 { 545 | return strings.Join(vars, "||") 546 | } else if b.Link.DriverName() == driver.Postgres { 547 | return vars[0] 548 | } else { 549 | return "CONCAT(" + strings.Join(vars, ",") + ")" 550 | } 551 | } 552 | 553 | func (b *Builder) getConcatForLike(vars ...string) string { 554 | if b.Link.DriverName() == driver.Sqlite3 || 
b.Link.DriverName() == driver.Postgres { 555 | return strings.Join(vars, "||") 556 | } else { 557 | return "CONCAT(" + strings.Join(vars, ",") + ")" 558 | } 559 | } 560 | 561 | func (b *Builder) getTableNameCommon(typeOf reflect.Type, valueOf reflect.Value) string { 562 | if b.table != nil { 563 | return getTableNameByTable(b.table) 564 | } 565 | 566 | return getTableNameByReflect(typeOf, valueOf) 567 | } 568 | 569 | func (b *Builder) GetSqlAndParams() (string, []interface{}, error) { 570 | if b.query != "" { 571 | return b.query, b.args, nil 572 | } 573 | 574 | var args []interface{} 575 | selectStr, args, err := b.handleSelect(args) 576 | if err != nil { 577 | return "", args, err 578 | } 579 | 580 | tableStr, args, err := b.handleTable(args) 581 | if err != nil { 582 | return "", args, err 583 | } 584 | joinStr, args := b.handleJoin(args) 585 | whereStr, args, err := b.handleWhere(args, true) 586 | if err != nil { 587 | return "", args, err 588 | } 589 | 590 | groupStr, args := b.handleGroup(args) 591 | havingStr, args, err := b.handleHaving(args) 592 | if err != nil { 593 | return "", args, err 594 | } 595 | 596 | orderStr, args := b.handleOrder(args) 597 | limitStr, args := b.handleLimit(args) 598 | lockStr := b.handleLockForUpdate() 599 | 600 | //效率低 601 | //query := selectStr + tableStr + joinStr + whereStr + groupStr + havingStr + orderStr + limitStr + lockStr 602 | //return query, args, nil 603 | 604 | var bd strings.Builder 605 | bd.WriteString(selectStr) 606 | bd.WriteString(tableStr) 607 | bd.WriteString(joinStr) 608 | bd.WriteString(whereStr) 609 | bd.WriteString(groupStr) 610 | bd.WriteString(havingStr) 611 | bd.WriteString(orderStr) 612 | bd.WriteString(limitStr) 613 | bd.WriteString(lockStr) 614 | 615 | return bd.String(), args, nil 616 | } 617 | 618 | // execAffected 通用执行-更新,删除 619 | func (b *Builder) execAffected(query string, args ...interface{}) (int64, error) { 620 | if b.Link.DriverName() == driver.Postgres { 621 | query = 
convertToPostgresSql(query) 622 | } 623 | 624 | res, err := b.RawSql(query, args...).Exec() 625 | if err != nil { 626 | return 0, err 627 | } 628 | 629 | count, err := res.RowsAffected() 630 | if err != nil { 631 | return 0, err 632 | } 633 | 634 | return count, nil 635 | } 636 | 637 | func getTagMap(fieldTag string) map[string]string { 638 | var fieldMap = make(map[string]string) 639 | if "" != fieldTag { 640 | tagArr := strings.Split(fieldTag, ";") 641 | for j := 0; j < len(tagArr); j++ { 642 | tagArrArr := strings.Split(tagArr[j], ":") 643 | fieldMap[tagArrArr[0]] = "" 644 | if len(tagArrArr) > 1 { 645 | fieldMap[tagArrArr[0]] = tagArrArr[1] 646 | } 647 | } 648 | } 649 | return fieldMap 650 | } 651 | 652 | //对于Postgres数据库,不支持?占位符,支持$1,$2类型,需要做转换 653 | func convertToPostgresSql(query string) string { 654 | t := 1 655 | for { 656 | if strings.Index(query, "?") == -1 { 657 | break 658 | } 659 | query = strings.Replace(query, "?", "$"+strconv.Itoa(t), 1) 660 | t += 1 661 | } 662 | 663 | return query 664 | } 665 | -------------------------------------------------------------------------------- /builder/exists.go: -------------------------------------------------------------------------------- 1 | package builder 2 | 3 | // Exists 存在某记录 4 | func (b *Builder) Exists() (bool, error) { 5 | stmt, rows, err := b.selectCommon("", "1", nil, "").Limit(0, 1).GetRows() 6 | if err != nil { 7 | return false, err 8 | } 9 | defer stmt.Close() 10 | defer rows.Close() 11 | 12 | if rows.Next() { 13 | return true, nil 14 | } else { 15 | return false, nil 16 | } 17 | } 18 | 19 | // DoesntExist 不存在某记录 20 | func (b *Builder) DoesntExist() (bool, error) { 21 | isE, err := b.Exists() 22 | return !isE, err 23 | } 24 | -------------------------------------------------------------------------------- /builder/handle.go: -------------------------------------------------------------------------------- 1 | package builder 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | 
"github.com/tangpanqing/aorm/driver" 7 | "reflect" 8 | "strings" 9 | ) 10 | 11 | func handleSelectWith(selectItem SelectItem) string { 12 | str := "" 13 | if selectItem.FuncName != "" { 14 | str += selectItem.FuncName 15 | str += "(" 16 | } 17 | 18 | valueOfField := reflect.ValueOf(selectItem.Field) 19 | prefix := getPrefixByField(valueOfField, selectItem.Prefix...) 20 | if prefix != "" { 21 | str += prefix + "." 22 | } 23 | 24 | str += getFieldNameByReflectValue(valueOfField) 25 | 26 | if selectItem.FuncName != "" { 27 | str += ")" 28 | } 29 | 30 | return str 31 | } 32 | 33 | //拼接SQL,字段相关 34 | func (b *Builder) handleSelect(paramList []any) (string, []any, error) { 35 | fieldStr := "" 36 | if b.distinct { 37 | fieldStr += "DISTINCT " 38 | } 39 | 40 | if len(b.selectList) == 0 && len(b.selectExpList) == 0 { 41 | fieldStr += "*" 42 | return "SELECT " + fieldStr, paramList, nil 43 | } 44 | 45 | var strList []string 46 | 47 | //处理一般的参数 48 | for i := 0; i < len(b.selectList); i++ { 49 | selectItem := b.selectList[i] 50 | 51 | str := handleSelectWith(selectItem) 52 | 53 | if selectItem.FieldNew != nil { 54 | str += " AS " 55 | str += getFieldNameByField(selectItem.FieldNew) 56 | } 57 | 58 | strList = append(strList, str) 59 | } 60 | 61 | //处理子语句 62 | for i := 0; i < len(b.selectExpList); i++ { 63 | subBuilder := *(b.selectExpList[i].Builder) 64 | subSql, subParamList, err := subBuilder.GetSqlAndParams() 65 | if err != nil { 66 | return "", paramList, err 67 | } 68 | strList = append(strList, "("+subSql+") AS "+getFieldNameByField(b.selectExpList[i].FieldName)) 69 | paramList = append(paramList, subParamList...) 
70 | } 71 | 72 | fieldStr += strings.Join(strList, ",") 73 | return "SELECT " + fieldStr, paramList, nil 74 | } 75 | 76 | func (b *Builder) handleTable(paramList []any) (string, []any, error) { 77 | if b.table == nil { 78 | return "", paramList, errors.New("表不能为空") 79 | } 80 | 81 | var tableName string 82 | 83 | valueOf := reflect.ValueOf(b.table) 84 | if reflect.Ptr == valueOf.Kind() { 85 | 86 | if "**builder.Builder" != valueOf.Type().String() { 87 | tableName = getTableMap(valueOf.Pointer()) 88 | } else { 89 | if b.tableAlias == "" { 90 | return "", paramList, errors.New("别名不能为空") 91 | } 92 | 93 | subBuilder := *(**Builder)(valueOf.UnsafePointer()) 94 | subSql, subParamList, err := subBuilder.GetSqlAndParams() 95 | if err != nil { 96 | return "", paramList, err 97 | } 98 | 99 | tableName = "(" + subSql + ")" 100 | paramList = append(paramList, subParamList...) 101 | } 102 | } else { 103 | tableName = fmt.Sprintf("%v", b.table) 104 | } 105 | 106 | return " FROM " + tableName + " " + b.tableAlias, paramList, nil 107 | } 108 | 109 | //拼接SQL,查询条件 110 | func (b *Builder) handleWhere(paramList []any, needPrefix bool) (string, []any, error) { 111 | if len(b.whereList) == 0 { 112 | return "", paramList, nil 113 | } 114 | 115 | strList, paramList, err := b.whereAndHaving(b.whereList, paramList, false, needPrefix) 116 | if err != nil { 117 | return "", paramList, nil 118 | } 119 | 120 | return " WHERE " + strings.Join(strList, " AND "), paramList, nil 121 | } 122 | 123 | //拼接SQL,更新信息 124 | func (b *Builder) handleSet(typeOf reflect.Type, valueOf reflect.Value, paramList []any) (string, []any) { 125 | 126 | //如果没有设置表名 127 | if b.table == nil { 128 | b.table = getTableNameByReflect(typeOf, valueOf) 129 | } 130 | 131 | var keys []string 132 | for i := 0; i < typeOf.Elem().NumField(); i++ { 133 | isNotNull := valueOf.Elem().Field(i).Field(0).Field(1).Bool() 134 | if isNotNull { 135 | key, _ := getFieldNameByStructField(typeOf.Elem().Field(i)) 136 | 137 | val := 
valueOf.Elem().Field(i).Field(0).Field(0).Interface() 138 | 139 | keys = append(keys, key+"=?") 140 | paramList = append(paramList, val) 141 | } 142 | } 143 | 144 | return " SET " + strings.Join(keys, ","), paramList 145 | } 146 | 147 | //拼接SQL,关联查询 148 | func (b *Builder) handleJoin(paramList []interface{}) (string, []interface{}) { 149 | if len(b.joinList) == 0 { 150 | return "", paramList 151 | } 152 | 153 | var sqlList []string 154 | for i := 0; i < len(b.joinList); i++ { 155 | joinItem := b.joinList[i] 156 | 157 | tableAlias := "" 158 | if len(joinItem.tableAlias) > 0 { 159 | tableAlias = joinItem.tableAlias[0] 160 | } 161 | 162 | str, paramList2 := genJoinConditionStr(tableAlias, joinItem.condition) 163 | paramList = append(paramList, paramList2...) 164 | 165 | sqlItem := joinItem.joinType + " " + getTableNameByTable(joinItem.table) + " " + tableAlias + " ON " + str 166 | sqlList = append(sqlList, sqlItem) 167 | } 168 | 169 | return " " + strings.Join(sqlList, " "), paramList 170 | } 171 | 172 | //拼接SQL,结果分组 173 | func (b *Builder) handleGroup(paramList []any) (string, []any) { 174 | if len(b.groupList) == 0 { 175 | return "", paramList 176 | } 177 | 178 | var groupList []string 179 | for i := 0; i < len(b.groupList); i++ { 180 | valueOfField := reflect.ValueOf(b.groupList[i].Field) 181 | prefix := getPrefixByField(valueOfField, b.groupList[i].Prefix...) 182 | if prefix != "" { 183 | prefix += "." 
184 | } 185 | field := getFieldNameByReflectValue(valueOfField) 186 | groupList = append(groupList, prefix+field) 187 | } 188 | 189 | return " GROUP BY " + strings.Join(groupList, ","), paramList 190 | } 191 | 192 | //拼接SQL,结果筛选 193 | func (b *Builder) handleHaving(paramList []any) (string, []any, error) { 194 | if len(b.havingList) == 0 { 195 | return "", paramList, nil 196 | } 197 | 198 | strList, paramList, err := b.whereAndHaving(b.havingList, paramList, true, true) 199 | if err != nil { 200 | return "", paramList, err 201 | } 202 | 203 | return " Having " + strings.Join(strList, " AND "), paramList, nil 204 | } 205 | 206 | //拼接SQL,结果排序 207 | func (b *Builder) handleOrder(paramList []any) (string, []any) { 208 | if len(b.orderList) == 0 { 209 | return "", paramList 210 | } 211 | 212 | var orderList []string 213 | for i := 0; i < len(b.orderList); i++ { 214 | valueOfField := reflect.ValueOf(b.orderList[i].Field) 215 | prefix := getPrefixByField(valueOfField, b.orderList[i].Prefix...) 216 | field := getFieldNameByReflectValue(valueOfField) 217 | orderList = append(orderList, prefix+"."+field+" "+b.orderList[i].OrderType) 218 | } 219 | 220 | return " ORDER BY " + strings.Join(orderList, ","), paramList 221 | } 222 | 223 | //拼接SQL,分页相关 Postgres数据库分页数量在前偏移在后,其他数据库偏移量在前分页数量在后,另外Mssql数据库的关键词是offset...next 224 | func (b *Builder) handleLimit(paramList []any) (string, []any) { 225 | if 0 == b.limitItem.pageSize { 226 | return "", paramList 227 | } 228 | 229 | str := "" 230 | if b.Link.DriverName() == driver.Postgres { 231 | paramList = append(paramList, b.limitItem.pageSize) 232 | paramList = append(paramList, b.limitItem.offset) 233 | 234 | str = " Limit ? offset ? " 235 | } else { 236 | paramList = append(paramList, b.limitItem.offset) 237 | paramList = append(paramList, b.limitItem.pageSize) 238 | 239 | str = " Limit ?,? " 240 | if b.Link.DriverName() == driver.Mssql { 241 | str = " offset ? rows fetch next ? 
rows only " 242 | } 243 | } 244 | 245 | return str, paramList 246 | } 247 | 248 | //拼接SQL,锁 249 | func (b *Builder) handleLockForUpdate() string { 250 | if b.isLockForUpdate { 251 | return " FOR UPDATE" 252 | } 253 | 254 | return "" 255 | } 256 | -------------------------------------------------------------------------------- /builder/having.go: -------------------------------------------------------------------------------- 1 | package builder 2 | 3 | import ( 4 | "github.com/tangpanqing/aorm/utils" 5 | "reflect" 6 | ) 7 | 8 | // Having 链式操作,以对象作为筛选条件 9 | func (b *Builder) Having(dest interface{}) *Builder { 10 | typeOf := reflect.TypeOf(dest) 11 | valueOf := reflect.ValueOf(dest) 12 | 13 | //如果没有设置表名 14 | if b.table == nil { 15 | b.table = getTableNameByReflect(typeOf, valueOf) 16 | } 17 | 18 | for i := 0; i < typeOf.Elem().NumField(); i++ { 19 | isNotNull := valueOf.Elem().Field(i).Field(0).Field(1).Bool() 20 | if isNotNull { 21 | key := utils.UnderLine(typeOf.Elem().Field(i).Name) 22 | val := valueOf.Elem().Field(i).Field(0).Field(0).Interface() 23 | b.havingList = append(b.havingList, WhereItem{Field: key, Opt: Eq, Val: val}) 24 | } 25 | } 26 | 27 | return b 28 | } 29 | 30 | // HavingArr 链式操作,以数组作为筛选条件 31 | func (b *Builder) HavingArr(havingList []WhereItem) *Builder { 32 | b.havingList = append(b.havingList, havingList...) 
33 | return b 34 | } 35 | 36 | func (b *Builder) HavingEq(field interface{}, val interface{}) *Builder { 37 | return b.havingItemAppend(field, Eq, val) 38 | } 39 | 40 | func (b *Builder) HavingNe(field interface{}, val interface{}) *Builder { 41 | return b.havingItemAppend(field, Ne, val) 42 | } 43 | 44 | func (b *Builder) HavingGt(field interface{}, val interface{}) *Builder { 45 | return b.havingItemAppend(field, Gt, val) 46 | } 47 | 48 | func (b *Builder) HavingGe(field interface{}, val interface{}) *Builder { 49 | return b.havingItemAppend(field, Ge, val) 50 | } 51 | 52 | func (b *Builder) HavingLt(field interface{}, val interface{}) *Builder { 53 | return b.havingItemAppend(field, Lt, val) 54 | } 55 | 56 | func (b *Builder) HavingLe(field interface{}, val interface{}) *Builder { 57 | return b.havingItemAppend(field, Le, val) 58 | } 59 | 60 | func (b *Builder) HavingIn(field interface{}, val interface{}) *Builder { 61 | return b.havingItemAppend(field, In, val) 62 | } 63 | 64 | func (b *Builder) HavingNotIn(field interface{}, val interface{}) *Builder { 65 | return b.havingItemAppend(field, NotIn, val) 66 | } 67 | 68 | func (b *Builder) HavingBetween(field interface{}, val interface{}) *Builder { 69 | return b.havingItemAppend(field, Between, val) 70 | } 71 | 72 | func (b *Builder) HavingNotBetween(field interface{}, val interface{}) *Builder { 73 | return b.havingItemAppend(field, NotBetween, val) 74 | } 75 | 76 | func (b *Builder) HavingLike(field interface{}, val interface{}) *Builder { 77 | return b.havingItemAppend(field, Like, val) 78 | } 79 | 80 | func (b *Builder) HavingNotLike(field interface{}, val interface{}) *Builder { 81 | return b.havingItemAppend(field, NotLike, val) 82 | } 83 | 84 | func (b *Builder) HavingRaw(val interface{}) *Builder { 85 | return b.havingItemAppend("", Raw, val) 86 | } 87 | 88 | func (b *Builder) havingItemAppend(field interface{}, opt string, val interface{}) *Builder { 89 | b.havingList = append(b.havingList, 
WhereItem{[]string{""}, field, opt, val}) 90 | return b 91 | } 92 | -------------------------------------------------------------------------------- /builder/increment.go: -------------------------------------------------------------------------------- 1 | package builder 2 | 3 | import "errors" 4 | 5 | // Increment 某字段自增 6 | func (b *Builder) Increment(field interface{}, step int) (int64, error) { 7 | var vars []any 8 | vars = append(vars, step) 9 | whereStr, vars, err := b.handleWhere(vars, false) 10 | if err != nil { 11 | return 0, err 12 | } 13 | 14 | if b.table == nil { 15 | return 0, errors.New("表名不能为空") 16 | } 17 | query := "UPDATE " + getTableNameByTable(b.table) + " SET " + getFieldNameByField(field) + "=" + getFieldNameByField(field) + "+?" + whereStr 18 | return b.execAffected(query, vars...) 19 | } 20 | 21 | // Decrement 某字段自减 22 | func (b *Builder) Decrement(field interface{}, step int) (int64, error) { 23 | var vars []any 24 | vars = append(vars, step) 25 | whereStr, vars, err := b.handleWhere(vars, false) 26 | if err != nil { 27 | return 0, err 28 | } 29 | 30 | if b.table == nil { 31 | return 0, errors.New("表名不能为空") 32 | } 33 | query := "UPDATE " + getTableNameByTable(b.table) + " SET " + getFieldNameByField(field) + "=" + getFieldNameByField(field) + "-?" + whereStr 34 | return b.execAffected(query, vars...) 35 | } 36 | -------------------------------------------------------------------------------- /builder/join.go: -------------------------------------------------------------------------------- 1 | package builder 2 | 3 | // LeftJoin 链式操作,左联查询,例如 LeftJoin("project p", "p.project_id=o.project_id") 4 | func (b *Builder) LeftJoin(table interface{}, condition []JoinCondition, alias ...string) *Builder { 5 | return b.join("LEFT JOIN", table, condition, alias...) 
6 | } 7 | 8 | // RightJoin 链式操作,右联查询,例如 RightJoin("project p", "p.project_id=o.project_id") 9 | func (b *Builder) RightJoin(table interface{}, condition []JoinCondition, alias ...string) *Builder { 10 | return b.join("RIGHT JOIN", table, condition, alias...) 11 | } 12 | 13 | // Join 链式操作,内联查询,例如 Join("project p", "p.project_id=o.project_id") 14 | func (b *Builder) Join(table interface{}, condition []JoinCondition, alias ...string) *Builder { 15 | return b.join("INNER JOIN", table, condition, alias...) 16 | } 17 | 18 | func (b *Builder) join(joinType string, table interface{}, condition []JoinCondition, alias ...string) *Builder { 19 | b.joinList = append(b.joinList, JoinItem{joinType, table, alias, condition}) 20 | return b 21 | } 22 | -------------------------------------------------------------------------------- /builder/order.go: -------------------------------------------------------------------------------- 1 | package builder 2 | 3 | func (b *Builder) OrderDescBy(field interface{}, prefix ...string) *Builder { 4 | return b.OrderBy(field, Desc, prefix...) 5 | } 6 | 7 | func (b *Builder) OrderAscBy(field interface{}, prefix ...string) *Builder { 8 | return b.OrderBy(field, Asc, prefix...) 9 | } 10 | 11 | // OrderBy 链式操作,以某字段进行排序 12 | func (b *Builder) OrderBy(field interface{}, orderType string, prefix ...string) *Builder { 13 | b.orderList = append(b.orderList, OrderItem{prefix, field, orderType}) 14 | return b 15 | } 16 | -------------------------------------------------------------------------------- /builder/select.go: -------------------------------------------------------------------------------- 1 | package builder 2 | 3 | func (b *Builder) SelectAll(table interface{}) *Builder { 4 | return b.selectCommon("", "*", nil, getTableNameByTable(table)) 5 | } 6 | 7 | // Select 链式操作-查询哪些字段,默认 * 8 | func (b *Builder) Select(field interface{}, prefix ...string) *Builder { 9 | return b.selectCommon("", field, nil, prefix...) 
10 | } 11 | 12 | func (b *Builder) SelectAs(field interface{}, fieldNew interface{}, prefix ...string) *Builder { 13 | return b.selectCommon("", field, fieldNew, prefix...) 14 | } 15 | 16 | // SelectCount 链式操作-count(field) as field_new 17 | func (b *Builder) SelectCount(field interface{}, fieldNew interface{}, prefix ...string) *Builder { 18 | return b.selectCommon("Count", field, fieldNew, prefix...) 19 | } 20 | 21 | // SelectSum 链式操作-sum(field) as field_new 22 | func (b *Builder) SelectSum(field interface{}, fieldNew interface{}, prefix ...string) *Builder { 23 | return b.selectCommon("Sum", field, fieldNew, prefix...) 24 | } 25 | 26 | // SelectMin 链式操作-min(field) as field_new 27 | func (b *Builder) SelectMin(field interface{}, fieldNew interface{}, prefix ...string) *Builder { 28 | return b.selectCommon("Min", field, fieldNew, prefix...) 29 | } 30 | 31 | // SelectMax 链式操作-max(field) as field_new 32 | func (b *Builder) SelectMax(field interface{}, fieldNew interface{}, prefix ...string) *Builder { 33 | return b.selectCommon("Max", field, fieldNew, prefix...) 34 | } 35 | 36 | // SelectAvg 链式操作-avg(field) as field_new 37 | func (b *Builder) SelectAvg(field interface{}, fieldNew interface{}, prefix ...string) *Builder { 38 | return b.selectCommon("Avg", field, fieldNew, prefix...) 39 | } 40 | 41 | // SelectConcat 链式操作-concat(field) as field_new 42 | func (b *Builder) SelectConcat(field interface{}, fieldNew interface{}, prefix ...string) *Builder { 43 | return b.selectCommon("concat", field, fieldNew, prefix...) 44 | } 45 | 46 | // SelectGroupConcat 链式操作-group_concat(field) as field_new 47 | func (b *Builder) SelectGroupConcat(field interface{}, fieldNew interface{}, prefix ...string) *Builder { 48 | return b.selectCommon("group_concat", field, fieldNew, prefix...) 
49 | } 50 | 51 | func (b *Builder) selectCommon(funcName string, field interface{}, fieldNew interface{}, prefix ...string) *Builder { 52 | b.selectList = append(b.selectList, SelectItem{funcName, prefix, field, fieldNew}) 53 | return b 54 | } 55 | 56 | // SelectExp 链式操作-表达式 57 | func (b *Builder) SelectExp(dbSub **Builder, fieldName interface{}) *Builder { 58 | b.selectExpList = append(b.selectExpList, &SelectExpItem{ 59 | Builder: dbSub, 60 | FieldName: fieldName, 61 | }) 62 | return b 63 | } 64 | -------------------------------------------------------------------------------- /builder/value.go: -------------------------------------------------------------------------------- 1 | package builder 2 | 3 | import "reflect" 4 | 5 | // Value 字段值 6 | func (b *Builder) Value(field interface{}, dest interface{}) error { 7 | b.Select(field).Limit(0, 1) 8 | 9 | fieldName := getFieldNameByField(field) 10 | 11 | stmt, rows, errRows := b.GetRows() 12 | if errRows != nil { 13 | return errRows 14 | } 15 | defer stmt.Close() 16 | defer rows.Close() 17 | 18 | destValue := reflect.ValueOf(dest).Elem() 19 | 20 | //从数据库中读出来的字段名字 21 | columnNameList, errColumns := rows.Columns() 22 | if errColumns != nil { 23 | return errColumns 24 | } 25 | 26 | for rows.Next() { 27 | var scans []interface{} 28 | for _, columnName := range columnNameList { 29 | if fieldName == columnName { 30 | scans = append(scans, destValue.Addr().Interface()) 31 | } else { 32 | var emptyVal interface{} 33 | scans = append(scans, &emptyVal) 34 | } 35 | } 36 | 37 | err := rows.Scan(scans...) 
38 | if err != nil { 39 | return err 40 | } 41 | } 42 | 43 | return nil 44 | } 45 | 46 | // Pluck 获取某一列的值 47 | func (b *Builder) Pluck(field interface{}, values interface{}) error { 48 | b.Select(field) 49 | fieldName := getFieldNameByField(field) 50 | 51 | stmt, rows, errRows := b.GetRows() 52 | if errRows != nil { 53 | return errRows 54 | } 55 | defer stmt.Close() 56 | defer rows.Close() 57 | 58 | destSlice := reflect.Indirect(reflect.ValueOf(values)) 59 | destType := destSlice.Type().Elem() 60 | destValue := reflect.New(destType).Elem() 61 | 62 | //从数据库中读出来的字段名字 63 | columnNameList, errColumns := rows.Columns() 64 | if errColumns != nil { 65 | return errColumns 66 | } 67 | 68 | for rows.Next() { 69 | var scans []interface{} 70 | for _, columnName := range columnNameList { 71 | if fieldName == columnName { 72 | scans = append(scans, destValue.Addr().Interface()) 73 | } else { 74 | var emptyVal interface{} 75 | scans = append(scans, &emptyVal) 76 | } 77 | } 78 | 79 | err := rows.Scan(scans...) 
80 | if err != nil { 81 | return err 82 | } 83 | 84 | destSlice.Set(reflect.Append(destSlice, destValue)) 85 | } 86 | 87 | return nil 88 | } 89 | -------------------------------------------------------------------------------- /builder/where.go: -------------------------------------------------------------------------------- 1 | package builder 2 | 3 | import ( 4 | "github.com/tangpanqing/aorm/utils" 5 | "reflect" 6 | ) 7 | 8 | // Where 链式操作,以对象作为查询条件 9 | func (b *Builder) Where(dest interface{}) *Builder { 10 | typeOf := reflect.TypeOf(dest) 11 | valueOf := reflect.ValueOf(dest) 12 | 13 | //如果没有设置表名 14 | if b.table == nil { 15 | b.table = getTableNameByReflect(typeOf, valueOf) 16 | } 17 | 18 | for i := 0; i < typeOf.Elem().NumField(); i++ { 19 | isNotNull := valueOf.Elem().Field(i).Field(0).Field(1).Bool() 20 | if isNotNull { 21 | key := utils.UnderLine(typeOf.Elem().Field(i).Name) 22 | val := valueOf.Elem().Field(i).Field(0).Field(0).Interface() 23 | b.whereList = append(b.whereList, WhereItem{Field: key, Opt: Eq, Val: val}) 24 | } 25 | } 26 | 27 | return b 28 | } 29 | 30 | // WhereArr 链式操作,以数组作为查询条件 31 | func (b *Builder) WhereArr(whereList []WhereItem) *Builder { 32 | b.whereList = append(b.whereList, whereList...) 33 | return b 34 | } 35 | 36 | func (b *Builder) WhereEq(field interface{}, val interface{}, prefix ...string) *Builder { 37 | return b.whereItemAppend(field, Eq, val, prefix...) 38 | } 39 | 40 | func (b *Builder) WhereNe(field interface{}, val interface{}, prefix ...string) *Builder { 41 | return b.whereItemAppend(field, Ne, val, prefix...) 42 | } 43 | 44 | func (b *Builder) WhereGt(field interface{}, val interface{}, prefix ...string) *Builder { 45 | return b.whereItemAppend(field, Gt, val, prefix...) 46 | } 47 | 48 | func (b *Builder) WhereGe(field interface{}, val interface{}, prefix ...string) *Builder { 49 | return b.whereItemAppend(field, Ge, val, prefix...) 
50 | } 51 | 52 | func (b *Builder) WhereLt(field interface{}, val interface{}, prefix ...string) *Builder { 53 | return b.whereItemAppend(field, Lt, val, prefix...) 54 | } 55 | 56 | func (b *Builder) WhereLe(field interface{}, val interface{}, prefix ...string) *Builder { 57 | return b.whereItemAppend(field, Le, val, prefix...) 58 | } 59 | 60 | func (b *Builder) WhereIn(field interface{}, val interface{}, prefix ...string) *Builder { 61 | return b.whereItemAppend(field, In, val, prefix...) 62 | } 63 | 64 | func (b *Builder) WhereNotIn(field interface{}, val interface{}, prefix ...string) *Builder { 65 | return b.whereItemAppend(field, NotIn, val, prefix...) 66 | } 67 | 68 | func (b *Builder) WhereBetween(field interface{}, val interface{}, prefix ...string) *Builder { 69 | return b.whereItemAppend(field, Between, val, prefix...) 70 | } 71 | 72 | func (b *Builder) WhereNotBetween(field interface{}, val interface{}, prefix ...string) *Builder { 73 | return b.whereItemAppend(field, NotBetween, val, prefix...) 74 | } 75 | 76 | func (b *Builder) WhereLike(field interface{}, val interface{}, prefix ...string) *Builder { 77 | return b.whereItemAppend(field, Like, val, prefix...) 78 | } 79 | 80 | func (b *Builder) WhereNotLike(field interface{}, val interface{}, prefix ...string) *Builder { 81 | return b.whereItemAppend(field, NotLike, val, prefix...) 82 | } 83 | 84 | func (b *Builder) WhereRaw(val interface{}) *Builder { 85 | return b.whereItemAppend("", Raw, val) 86 | } 87 | 88 | func (b *Builder) WhereFindInSet(val interface{}, field interface{}, prefix ...string) *Builder { 89 | return b.whereItemAppend(field, FindInSet, val, prefix...) 90 | } 91 | 92 | func (b *Builder) WhereIsNull(field interface{}, prefix ...string) *Builder { 93 | return b.whereItemAppend(field, Raw, "IS NULL", prefix...) 94 | } 95 | 96 | func (b *Builder) WhereIsNOTNull(field interface{}, prefix ...string) *Builder { 97 | return b.whereItemAppend(field, Raw, "IS NOT NULL", prefix...) 
98 | } 99 | 100 | func (b *Builder) WhereRawEq(field interface{}, val interface{}, prefix ...string) *Builder { 101 | return b.whereItemAppend(field, RawEq, val, prefix...) 102 | } 103 | 104 | func (b *Builder) whereItemAppend(field interface{}, opt string, val interface{}, prefix ...string) *Builder { 105 | b.whereList = append(b.whereList, WhereItem{prefix, field, opt, val}) 106 | return b 107 | } 108 | -------------------------------------------------------------------------------- /driver/driver.go: -------------------------------------------------------------------------------- 1 | package driver 2 | 3 | const Mysql = "mysql" 4 | const Mssql = "mssql" 5 | const Postgres = "postgres" 6 | const Sqlite3 = "sqlite3" 7 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/tangpanqing/aorm 2 | 3 | go 1.18 4 | 5 | require ( 6 | github.com/denisenkom/go-mssqldb v0.12.3 7 | github.com/go-sql-driver/mysql v1.7.0 8 | github.com/lib/pq v1.10.7 9 | github.com/mattn/go-sqlite3 v1.14.16 10 | ) 11 | 12 | require ( 13 | github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect 14 | github.com/golang-sql/sqlexp v0.1.0 // indirect 15 | golang.org/x/crypto v0.4.0 // indirect 16 | ) 17 | -------------------------------------------------------------------------------- /migrate_mssql/migrate.go: -------------------------------------------------------------------------------- 1 | package migrate_mssql 2 | 3 | import ( 4 | "fmt" 5 | "github.com/tangpanqing/aorm/builder" 6 | "github.com/tangpanqing/aorm/null" 7 | "github.com/tangpanqing/aorm/utils" 8 | "reflect" 9 | "strconv" 10 | "strings" 11 | ) 12 | 13 | type Table struct { 14 | TableName null.String 15 | } 16 | 17 | type Column struct { 18 | ColumnName null.String 19 | ColumnDefault null.String 20 | IsNullable null.String 21 | DataType null.String //数据类型 varchar,bigint,int 22 
| MaxLength null.Int //数据最大长度 20 23 | ColumnComment null.String 24 | Extra null.String //扩展信息 auto_increment 25 | } 26 | 27 | type Index struct { 28 | NonUnique null.Int 29 | ColumnName null.String 30 | KeyName null.String 31 | } 32 | 33 | //MigrateExecutor 定义结构 34 | type MigrateExecutor struct { 35 | //执行者 36 | Builder *builder.Builder 37 | } 38 | 39 | //ShowCreateTable 查看创建表的ddl 40 | func (mm *MigrateExecutor) ShowCreateTable(tableName string) string { 41 | var str string 42 | mm.Builder.RawSql("show create table "+tableName).Value("Create Table", &str) 43 | return str 44 | } 45 | 46 | //MigrateCommon 迁移的主要过程 47 | func (mm *MigrateExecutor) MigrateCommon(tableName string, typeOf reflect.Type) error { 48 | tableFromCode := mm.getTableFromCode(tableName) 49 | columnsFromCode := mm.getColumnsFromCode(typeOf) 50 | indexesFromCode := mm.getIndexesFromCode(typeOf, tableFromCode) 51 | 52 | dbName, dbErr := mm.getDbName() 53 | if dbErr != nil { 54 | return dbErr 55 | } 56 | 57 | tablesFromDb := mm.getTableFromDb(dbName, tableName) 58 | if len(tablesFromDb) != 0 { 59 | tableFromDb := tablesFromDb[0] 60 | columnsFromDb := mm.getColumnsFromDb(dbName, tableName) 61 | indexesFromDb := mm.getIndexesFromDb(tableName) 62 | 63 | mm.modifyTable(tableFromCode, columnsFromCode, indexesFromCode, tableFromDb, columnsFromDb, indexesFromDb) 64 | } else { 65 | mm.createTable(tableFromCode, columnsFromCode, indexesFromCode) 66 | } 67 | 68 | return nil 69 | } 70 | 71 | func (mm *MigrateExecutor) getTableFromCode(tableName string) Table { 72 | var tableFromCode Table 73 | tableFromCode.TableName = null.StringFrom(tableName) 74 | 75 | return tableFromCode 76 | } 77 | 78 | func (mm *MigrateExecutor) getColumnsFromCode(typeOf reflect.Type) []Column { 79 | var columnsFromCode []Column 80 | for i := 0; i < typeOf.Elem().NumField(); i++ { 81 | fieldName := utils.UnderLine(typeOf.Elem().Field(i).Name) 82 | fieldType := typeOf.Elem().Field(i).Type.Name() 83 | fieldMap := 
getTagMap(typeOf.Elem().Field(i).Tag.Get("aorm")) 84 | columnsFromCode = append(columnsFromCode, Column{ 85 | ColumnName: null.StringFrom(fieldName), 86 | DataType: null.StringFrom(getDataType(fieldType, fieldMap)), 87 | MaxLength: null.IntFrom(int64(getMaxLength(getDataType(fieldType, fieldMap), fieldMap))), 88 | IsNullable: null.StringFrom(getNullAble(fieldMap)), 89 | ColumnComment: null.StringFrom(getComment(fieldMap)), 90 | Extra: null.StringFrom(getExtra(fieldMap)), 91 | ColumnDefault: null.StringFrom(getDefaultVal(fieldMap)), 92 | }) 93 | } 94 | 95 | return columnsFromCode 96 | } 97 | 98 | func (mm *MigrateExecutor) getIndexesFromCode(typeOf reflect.Type, tableFromCode Table) []Index { 99 | var indexesFromCode []Index 100 | for i := 0; i < typeOf.Elem().NumField(); i++ { 101 | fieldName := utils.UnderLine(typeOf.Elem().Field(i).Name) 102 | fieldMap := getTagMap(typeOf.Elem().Field(i).Tag.Get("aorm")) 103 | 104 | _, primaryIs := fieldMap["primary"] 105 | if primaryIs { 106 | indexesFromCode = append(indexesFromCode, Index{ 107 | NonUnique: null.IntFrom(0), 108 | ColumnName: null.StringFrom(fieldName), 109 | KeyName: null.StringFrom("PRIMARY"), 110 | }) 111 | } 112 | 113 | _, uniqueIndexIs := fieldMap["unique"] 114 | if uniqueIndexIs { 115 | indexesFromCode = append(indexesFromCode, Index{ 116 | NonUnique: null.IntFrom(0), 117 | ColumnName: null.StringFrom(fieldName), 118 | KeyName: null.StringFrom("idx_" + tableFromCode.TableName.String + "_" + fieldName), 119 | }) 120 | } 121 | 122 | _, indexIs := fieldMap["index"] 123 | if indexIs { 124 | indexesFromCode = append(indexesFromCode, Index{ 125 | NonUnique: null.IntFrom(1), 126 | ColumnName: null.StringFrom(fieldName), 127 | KeyName: null.StringFrom("idx_" + tableFromCode.TableName.String + "_" + fieldName), 128 | }) 129 | } 130 | } 131 | 132 | return indexesFromCode 133 | } 134 | 135 | func (mm *MigrateExecutor) getDbName() (string, error) { 136 | //获取数据库名称 137 | var dbName string 138 | err := 
mm.Builder.RawSql("Select Name as db_name From Master..SysDataBases Where DbId=(Select Dbid From Master..SysProcesses Where Spid = @@spid)").Value("db_name", &dbName) 139 | if err != nil { 140 | return "", err 141 | } 142 | 143 | return dbName, nil 144 | } 145 | 146 | func (mm *MigrateExecutor) getTableFromDb(dbName string, tableName string) []Table { 147 | sql := "SELECT Name as TABLE_NAME FROM SysObjects Where XType='U' and Name =" + "'" + tableName + "'" 148 | var dataList []Table 149 | mm.Builder.RawSql(sql).GetMany(&dataList) 150 | 151 | return dataList 152 | } 153 | 154 | func (mm *MigrateExecutor) getColumnsFromDb(dbName string, tableName string) []Column { 155 | var columnsFromDb []Column 156 | sqlColumn := "SELECT " + 157 | // "table_name = Case When A.colorder=1 Then D.name Else '' End," + 158 | // "table_comment = Case When A.colorder=1 Then isnull(F.value,'') Else '' End," + 159 | "column_name = A.name," + 160 | "column_comment = isnull(G.[value],'')," + 161 | "data_type = B.name," + 162 | "max_length = COLUMNPROPERTY(A.id,A.name,'PRECISION')," + 163 | "is_nullable = Case When A.isnullable=1 Then 'YES'Else 'NO' End," + 164 | "column_default = isnull(E.Text,'') " + 165 | "FROM syscolumns A " + 166 | "Left Join systypes B On A.xusertype=B.xusertype " + 167 | "Inner Join sysobjects D On A.id=D.id and D.xtype='U' and D.name<>'dtproperties' " + 168 | "Left Join syscomments E ON A.cdefault=E.id " + 169 | "Left Join sys.extended_properties G On A.id=G.major_id and A.colid=G.minor_id " + 170 | "Left Join sys.extended_properties F On D.id=F.major_id and F.minor_id=0 " + 171 | "Order By A.id,A.colorder" 172 | 173 | mm.Builder.RawSql(sqlColumn).GetMany(&columnsFromDb) 174 | 175 | return columnsFromDb 176 | } 177 | 178 | func (mm *MigrateExecutor) getIndexesFromDb(tableName string) []Index { 179 | sqlIndex := "SELECT " + 180 | "i.[name] AS 'key_name'," + 181 | "SUBSTRING(column_names, 1, LEN(column_names) - 1) AS 'column_name'," + 182 | "CASE WHEN i.is_unique = 1 
THEN 0 ELSE 1 END AS 'non_unique' " + 183 | "FROM sys.objects t " + 184 | "INNER JOIN sys.indexes i ON t.object_id = i.object_id " + 185 | "CROSS APPLY " + 186 | "(SELECT col.[name] + ', ' " + 187 | "FROM sys.index_columns ic " + 188 | "INNER JOIN sys.columns col ON ic.object_id = col.object_id AND ic.column_id = col.column_id " + 189 | "WHERE ic.object_id = t.object_id " + 190 | "AND ic.index_id = i.index_id " + 191 | "ORDER BY col.column_id " + 192 | "FOR XML PATH('') " + 193 | ") D(column_names) " + 194 | "WHERE t.is_ms_shipped <> 1 " + 195 | "AND index_id > 0 " + 196 | "AND t.name = '" + tableName + "'" 197 | 198 | var indexesFromDb []Index 199 | mm.Builder.RawSql(sqlIndex).GetMany(&indexesFromDb) 200 | return indexesFromDb 201 | } 202 | 203 | func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Column, indexesFromCode []Index, tableFromDb Table, columnsFromDb []Column, indexesFromDb []Index) { 204 | for i := 0; i < len(columnsFromCode); i++ { 205 | isFind := 0 206 | columnCode := columnsFromCode[i] 207 | 208 | for j := 0; j < len(columnsFromDb); j++ { 209 | columnDb := columnsFromDb[j] 210 | if columnCode.ColumnName == columnDb.ColumnName { 211 | isFind = 1 212 | if columnCode.DataType.String != columnDb.DataType.String { 213 | sql := "ALTER TABLE " + tableFromCode.TableName.String + " MODIFY " + getColumnStr(columnCode) 214 | fmt.Println(sql) 215 | 216 | _, err := mm.Builder.RawSql(sql).Exec() 217 | if err != nil { 218 | fmt.Println(err) 219 | } else { 220 | fmt.Println("修改属性:" + sql) 221 | } 222 | } 223 | } 224 | } 225 | 226 | if isFind == 0 { 227 | sql := "ALTER TABLE " + tableFromCode.TableName.String + " ADD " + getColumnStr(columnCode) 228 | _, err := mm.Builder.RawSql(sql).Exec() 229 | if err != nil { 230 | fmt.Println(err) 231 | } else { 232 | fmt.Println("增加属性:" + sql) 233 | } 234 | } 235 | } 236 | 237 | for i := 0; i < len(indexesFromCode); i++ { 238 | isFind := 0 239 | indexCode := indexesFromCode[i] 240 | 241 | for j := 0; 
j < len(indexesFromDb); j++ { 242 | indexDb := indexesFromDb[j] 243 | if indexCode.ColumnName == indexDb.ColumnName { 244 | isFind = 1 245 | 246 | keyMatch := false 247 | if "PRIMARY" == indexCode.KeyName.String && strings.Index(indexDb.KeyName.String, "PK__") != -1 { 248 | keyMatch = true 249 | } 250 | 251 | if "PRIMARY" != indexCode.KeyName.String && indexCode.KeyName.String == indexDb.KeyName.String { 252 | keyMatch = true 253 | } 254 | 255 | if !keyMatch || indexCode.NonUnique.Int64 != indexDb.NonUnique.Int64 { 256 | sql := "ALTER TABLE " + tableFromCode.TableName.String + " MODIFY " + getIndexStr(indexCode) 257 | _, err := mm.Builder.RawSql(sql).Exec() 258 | if err != nil { 259 | fmt.Println(err) 260 | } else { 261 | fmt.Println("修改索引:" + sql) 262 | } 263 | } 264 | } 265 | } 266 | 267 | if isFind == 0 { 268 | sql := "ALTER TABLE " + tableFromCode.TableName.String + " ADD " + getIndexStr(indexCode) 269 | _, err := mm.Builder.RawSql(sql).Exec() 270 | if err != nil { 271 | fmt.Println(err) 272 | } else { 273 | fmt.Println("增加索引:" + sql) 274 | } 275 | } 276 | } 277 | } 278 | 279 | func (mm *MigrateExecutor) createTable(tableFromCode Table, columnsFromCode []Column, indexesFromCode []Index) { 280 | var fieldArr []string 281 | 282 | for i := 0; i < len(columnsFromCode); i++ { 283 | column := columnsFromCode[i] 284 | fieldArr = append(fieldArr, getColumnStr(column)) 285 | } 286 | 287 | for i := 0; i < len(indexesFromCode); i++ { 288 | index := indexesFromCode[i] 289 | fieldArr = append(fieldArr, getIndexStr(index)) 290 | } 291 | 292 | sqlStr := "CREATE TABLE " + tableFromCode.TableName.String + " (\n" + strings.Join(fieldArr, ",\n") + "\n) " + ";" 293 | 294 | _, err := mm.Builder.RawSql(sqlStr).Exec() 295 | if err != nil { 296 | fmt.Println(err) 297 | } else { 298 | fmt.Println("创建表:" + tableFromCode.TableName.String) 299 | } 300 | } 301 | 302 | func getTagMap(fieldTag string) map[string]string { 303 | var fieldMap = make(map[string]string) 304 | if "" != fieldTag { 
305 | tagArr := strings.Split(fieldTag, ";") 306 | for j := 0; j < len(tagArr); j++ { 307 | tagArrArr := strings.Split(tagArr[j], ":") 308 | fieldMap[tagArrArr[0]] = "" 309 | if len(tagArrArr) > 1 { 310 | fieldMap[tagArrArr[0]] = tagArrArr[1] 311 | } 312 | } 313 | } 314 | return fieldMap 315 | } 316 | 317 | func getColumnStr(column Column) string { 318 | var strArr []string 319 | strArr = append(strArr, column.ColumnName.String) 320 | if column.MaxLength.Int64 == 0 { 321 | if column.DataType.String == "varchar" { 322 | strArr = append(strArr, column.DataType.String+"(255)") 323 | } else { 324 | strArr = append(strArr, column.DataType.String) 325 | } 326 | } else { 327 | strArr = append(strArr, column.DataType.String+"("+strconv.Itoa(int(column.MaxLength.Int64))+")") 328 | } 329 | 330 | if column.ColumnDefault.String != "" { 331 | strArr = append(strArr, "DEFAULT '"+column.ColumnDefault.String+"'") 332 | } 333 | 334 | if column.IsNullable.String == "NO" { 335 | strArr = append(strArr, "NOT NULL") 336 | } 337 | 338 | if column.ColumnComment.String != "" { 339 | //strArr = append(strArr, "COMMENT '"+column.ColumnComment.String+"'") 340 | } 341 | 342 | if column.Extra.String != "" { 343 | if column.Extra.String == "auto_increment" { 344 | column.Extra.String = "identity(1,1)" 345 | } 346 | strArr = append(strArr, column.Extra.String) 347 | } 348 | 349 | return strings.Join(strArr, " ") 350 | } 351 | 352 | func getIndexStr(index Index) string { 353 | var strArr []string 354 | 355 | if "PRIMARY" == index.KeyName.String { 356 | strArr = append(strArr, index.KeyName.String) 357 | strArr = append(strArr, "KEY") 358 | strArr = append(strArr, "("+index.ColumnName.String+")") 359 | } else { 360 | if 0 == index.NonUnique.Int64 { 361 | strArr = append(strArr, "Unique") 362 | strArr = append(strArr, index.KeyName.String) 363 | strArr = append(strArr, "("+index.ColumnName.String+")") 364 | } else { 365 | strArr = append(strArr, "Index") 366 | strArr = append(strArr, 
index.KeyName.String) 367 | strArr = append(strArr, "("+index.ColumnName.String+")") 368 | } 369 | } 370 | 371 | return strings.Join(strArr, " ") 372 | } 373 | 374 | func getDataType(fieldType string, fieldMap map[string]string) string { 375 | var DataType string 376 | 377 | dataTypeVal, dataTypeOk := fieldMap["type"] 378 | if dataTypeOk { 379 | DataType = dataTypeVal 380 | if DataType == "double" { 381 | DataType = "float" 382 | } 383 | } else { 384 | if "Int" == fieldType { 385 | DataType = "int" 386 | } 387 | if "String" == fieldType { 388 | DataType = "varchar" 389 | } 390 | if "Bool" == fieldType { 391 | DataType = "tinyint" 392 | } 393 | if "Time" == fieldType { 394 | DataType = "datetime" 395 | } 396 | if "Float" == fieldType { 397 | DataType = "float" 398 | } 399 | } 400 | 401 | return DataType 402 | } 403 | 404 | func getMaxLength(DataType string, fieldMap map[string]string) int { 405 | var MaxLength int 406 | 407 | maxLengthVal, maxLengthOk := fieldMap["size"] 408 | if maxLengthOk { 409 | num, _ := strconv.Atoi(maxLengthVal) 410 | MaxLength = num 411 | } else { 412 | MaxLength = 0 413 | if "varchar" == DataType { 414 | MaxLength = 255 415 | } 416 | } 417 | 418 | return MaxLength 419 | } 420 | 421 | func getNullAble(fieldMap map[string]string) string { 422 | var IsNullable string 423 | 424 | _, primaryOk := fieldMap["primary"] 425 | if primaryOk { 426 | IsNullable = "NO" 427 | } else { 428 | _, ok := fieldMap["not null"] 429 | if ok { 430 | IsNullable = "NO" 431 | } else { 432 | IsNullable = "YES" 433 | } 434 | } 435 | 436 | return IsNullable 437 | } 438 | 439 | func getComment(fieldMap map[string]string) string { 440 | commentVal, commentIs := fieldMap["comment"] 441 | if commentIs { 442 | return commentVal 443 | } 444 | 445 | return "" 446 | } 447 | 448 | func getExtra(fieldMap map[string]string) string { 449 | _, commentIs := fieldMap["auto_increment"] 450 | if commentIs { 451 | return "auto_increment" 452 | } 453 | 454 | return "" 455 | } 456 | 457 | 
func getDefaultVal(fieldMap map[string]string) string { 458 | defaultVal, defaultIs := fieldMap["default"] 459 | if defaultIs { 460 | return defaultVal 461 | } 462 | 463 | return "" 464 | } 465 | -------------------------------------------------------------------------------- /migrate_mysql/migrate.go: -------------------------------------------------------------------------------- 1 | package migrate_mysql 2 | 3 | import ( 4 | "fmt" 5 | "github.com/tangpanqing/aorm/builder" 6 | "github.com/tangpanqing/aorm/null" 7 | "github.com/tangpanqing/aorm/utils" 8 | "reflect" 9 | "strconv" 10 | "strings" 11 | ) 12 | 13 | type Table struct { 14 | TableName null.String 15 | Engine null.String 16 | TableComment null.String 17 | } 18 | 19 | type Column struct { 20 | ColumnName null.String 21 | ColumnDefault null.String 22 | IsNullable null.String 23 | DataType null.String //数据类型 varchar,bigint,int 24 | MaxLength null.Int //数据最大长度 20 25 | ColumnComment null.String 26 | Extra null.String //扩展信息 auto_increment 27 | } 28 | 29 | type Index struct { 30 | NonUnique null.Int 31 | ColumnName null.String 32 | KeyName null.String 33 | } 34 | 35 | //MigrateExecutor 定义结构 36 | type MigrateExecutor struct { 37 | //执行者 38 | Builder *builder.Builder 39 | } 40 | 41 | //ShowCreateTable 查看创建表的ddl 42 | func (mm *MigrateExecutor) ShowCreateTable(tableName string) string { 43 | var str string 44 | mm.Builder.RawSql("show create table "+tableName).Value("Create Table", &str) 45 | return str 46 | } 47 | 48 | //MigrateCommon 迁移的主要过程 49 | func (mm *MigrateExecutor) MigrateCommon(tableName string, typeOf reflect.Type, valueOf reflect.Value) error { 50 | tableFromCode := mm.getTableFromCode(tableName, typeOf, valueOf) 51 | columnsFromCode := mm.getColumnsFromCode(typeOf) 52 | indexesFromCode := mm.getIndexesFromCode(typeOf, tableFromCode) 53 | 54 | dbName, dbErr := mm.getDbName() 55 | if dbErr != nil { 56 | return dbErr 57 | } 58 | 59 | tablesFromDb := mm.getTableFromDb(dbName, tableName) 60 | if 
len(tablesFromDb) != 0 { 61 | tableFromDb := tablesFromDb[0] 62 | columnsFromDb := mm.getColumnsFromDb(dbName, tableName) 63 | indexesFromDb := mm.getIndexesFromDb(tableName) 64 | 65 | mm.modifyTable(tableFromCode, columnsFromCode, indexesFromCode, tableFromDb, columnsFromDb, indexesFromDb) 66 | } else { 67 | mm.createTable(tableFromCode, columnsFromCode, indexesFromCode) 68 | } 69 | 70 | return nil 71 | } 72 | 73 | func (mm *MigrateExecutor) getTableFromCode(tableName string, typeOf reflect.Type, valueOf reflect.Value) Table { 74 | table := Table{ 75 | TableName: null.StringFrom(tableName), 76 | Engine: null.StringFrom("MyISAM"), 77 | TableComment: null.StringFrom("''"), 78 | } 79 | 80 | method, isSet := typeOf.MethodByName("TableOpinion") 81 | if isSet { 82 | var paramList []reflect.Value 83 | paramList = append(paramList, valueOf) 84 | valueList := method.Func.Call(paramList) 85 | i := valueList[0].Interface() 86 | m := i.(map[string]string) 87 | 88 | m["COMMENT"] = "'" + m["COMMENT"] + "'" 89 | table.Engine = null.StringFrom(m["ENGINE"]) 90 | table.TableComment = null.StringFrom(m["COMMENT"]) 91 | } 92 | 93 | return table 94 | } 95 | 96 | func (mm *MigrateExecutor) getColumnsFromCode(typeOf reflect.Type) []Column { 97 | var columnsFromCode []Column 98 | for i := 0; i < typeOf.Elem().NumField(); i++ { 99 | fieldName := utils.UnderLine(typeOf.Elem().Field(i).Name) 100 | fieldType := typeOf.Elem().Field(i).Type.Name() 101 | fieldMap := getTagMap(typeOf.Elem().Field(i).Tag.Get("aorm")) 102 | 103 | //如果tag里重新设置了字段名 104 | if column, ok := fieldMap["column"]; ok { 105 | fieldName = column 106 | } 107 | 108 | columnsFromCode = append(columnsFromCode, Column{ 109 | ColumnName: null.StringFrom(fieldName), 110 | DataType: null.StringFrom(getDataType(fieldType, fieldMap)), 111 | MaxLength: null.IntFrom(int64(getMaxLength(getDataType(fieldType, fieldMap), fieldMap))), 112 | IsNullable: null.StringFrom(getNullAble(fieldMap)), 113 | ColumnComment: 
null.StringFrom(getComment(fieldMap)), 114 | Extra: null.StringFrom(getExtra(fieldMap)), 115 | ColumnDefault: null.StringFrom(getDefaultVal(fieldMap)), 116 | }) 117 | } 118 | 119 | return columnsFromCode 120 | } 121 | 122 | func (mm *MigrateExecutor) getIndexesFromCode(typeOf reflect.Type, tableFromCode Table) []Index { 123 | var indexesFromCode []Index 124 | for i := 0; i < typeOf.Elem().NumField(); i++ { 125 | fieldName := utils.UnderLine(typeOf.Elem().Field(i).Name) 126 | fieldMap := getTagMap(typeOf.Elem().Field(i).Tag.Get("aorm")) 127 | 128 | _, primaryIs := fieldMap["primary"] 129 | if primaryIs { 130 | indexesFromCode = append(indexesFromCode, Index{ 131 | NonUnique: null.IntFrom(0), 132 | ColumnName: null.StringFrom(fieldName), 133 | KeyName: null.StringFrom("PRIMARY"), 134 | }) 135 | } 136 | 137 | _, uniqueIndexIs := fieldMap["unique"] 138 | if uniqueIndexIs { 139 | indexesFromCode = append(indexesFromCode, Index{ 140 | NonUnique: null.IntFrom(0), 141 | ColumnName: null.StringFrom(fieldName), 142 | KeyName: null.StringFrom("idx_" + tableFromCode.TableName.String + "_" + fieldName), 143 | }) 144 | } 145 | 146 | _, indexIs := fieldMap["index"] 147 | if indexIs { 148 | indexesFromCode = append(indexesFromCode, Index{ 149 | NonUnique: null.IntFrom(1), 150 | ColumnName: null.StringFrom(fieldName), 151 | KeyName: null.StringFrom("idx_" + tableFromCode.TableName.String + "_" + fieldName), 152 | }) 153 | } 154 | } 155 | 156 | return indexesFromCode 157 | } 158 | 159 | func (mm *MigrateExecutor) getDbName() (string, error) { 160 | //获取数据库名称 161 | var dbName string 162 | err := mm.Builder.RawSql("SELECT DATABASE()").Value("DATABASE()", &dbName) 163 | if err != nil { 164 | return "", err 165 | } 166 | 167 | return dbName, nil 168 | } 169 | 170 | func (mm *MigrateExecutor) getTableFromDb(dbName string, tableName string) []Table { 171 | sql := "SELECT TABLE_NAME,ENGINE,TABLE_COMMENT FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA =" + "'" + dbName + "' AND TABLE_NAME 
=" + "'" + tableName + "'" 172 | var dataList []Table 173 | mm.Builder.RawSql(sql).GetMany(&dataList) 174 | for i := 0; i < len(dataList); i++ { 175 | dataList[i].TableComment = null.StringFrom("'" + dataList[i].TableComment.String + "'") 176 | } 177 | 178 | return dataList 179 | } 180 | 181 | func (mm *MigrateExecutor) getColumnsFromDb(dbName string, tableName string) []Column { 182 | var columnsFromDb []Column 183 | 184 | sqlColumn := "SELECT COLUMN_NAME,DATA_TYPE,CHARACTER_MAXIMUM_LENGTH as Max_Length,COLUMN_DEFAULT,COLUMN_COMMENT,EXTRA,IS_NULLABLE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA =" + "'" + dbName + "' AND TABLE_NAME =" + "'" + tableName + "'" 185 | mm.Builder.RawSql(sqlColumn).GetMany(&columnsFromDb) 186 | 187 | for j := 0; j < len(columnsFromDb); j++ { 188 | if columnsFromDb[j].DataType.String == "text" || columnsFromDb[j].DataType.String == "tinytext" || columnsFromDb[j].DataType.String == "longtext" || columnsFromDb[j].DataType.String == "mediumtext" { 189 | columnsFromDb[j].MaxLength = null.IntFrom(0) 190 | } 191 | } 192 | 193 | return columnsFromDb 194 | } 195 | 196 | func (mm *MigrateExecutor) getIndexesFromDb(tableName string) []Index { 197 | sqlIndex := "SHOW INDEXES FROM " + tableName 198 | 199 | var indexsFromDb []Index 200 | mm.Builder.RawSql(sqlIndex).GetMany(&indexsFromDb) 201 | 202 | return indexsFromDb 203 | } 204 | 205 | func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Column, indexesFromCode []Index, tableFromDb Table, columnsFromDb []Column, indexesFromDb []Index) { 206 | if tableFromCode.Engine != tableFromDb.Engine { 207 | mm.modifyTableEngine(tableFromCode) 208 | } 209 | 210 | if tableFromCode.TableComment != tableFromDb.TableComment { 211 | mm.modifyTableComment(tableFromCode) 212 | } 213 | 214 | for i := 0; i < len(columnsFromCode); i++ { 215 | isFind := 0 216 | columnCode := columnsFromCode[i] 217 | 218 | for j := 0; j < len(columnsFromDb); j++ { 219 | columnDb := columnsFromDb[j] 220 | if 
columnCode.ColumnName == columnDb.ColumnName { 221 | isFind = 1 222 | if columnCode.DataType.String != columnDb.DataType.String || 223 | columnCode.MaxLength.Int64 != columnDb.MaxLength.Int64 || 224 | columnCode.ColumnComment.String != columnDb.ColumnComment.String || 225 | columnCode.Extra.String != columnDb.Extra.String || 226 | columnCode.ColumnDefault.String != columnDb.ColumnDefault.String { 227 | sql := "ALTER TABLE " + tableFromCode.TableName.String + " MODIFY " + getColumnStr(columnCode) 228 | _, err := mm.Builder.RawSql(sql).Exec() 229 | if err != nil { 230 | fmt.Println(err) 231 | } else { 232 | fmt.Println("修改属性:" + sql) 233 | } 234 | } 235 | } 236 | } 237 | 238 | if isFind == 0 { 239 | sql := "ALTER TABLE " + tableFromCode.TableName.String + " ADD " + getColumnStr(columnCode) 240 | _, err := mm.Builder.RawSql(sql).Exec() 241 | if err != nil { 242 | fmt.Println(err) 243 | } else { 244 | fmt.Println("增加属性:" + sql) 245 | } 246 | } 247 | } 248 | 249 | for i := 0; i < len(indexesFromCode); i++ { 250 | isFind := 0 251 | indexCode := indexesFromCode[i] 252 | 253 | for j := 0; j < len(indexesFromDb); j++ { 254 | indexDb := indexesFromDb[j] 255 | if indexCode.ColumnName == indexDb.ColumnName { 256 | isFind = 1 257 | if indexCode.KeyName != indexDb.KeyName || indexCode.NonUnique != indexDb.NonUnique { 258 | sql := "ALTER TABLE " + tableFromCode.TableName.String + " MODIFY " + getIndexStr(indexCode) 259 | _, err := mm.Builder.RawSql(sql).Exec() 260 | if err != nil { 261 | fmt.Println(err) 262 | } else { 263 | fmt.Println("修改索引:" + sql) 264 | } 265 | } 266 | } 267 | } 268 | 269 | if isFind == 0 { 270 | sql := "ALTER TABLE " + tableFromCode.TableName.String + " ADD " + getIndexStr(indexCode) 271 | _, err := mm.Builder.RawSql(sql).Exec() 272 | if err != nil { 273 | fmt.Println(err) 274 | } else { 275 | fmt.Println("增加索引:" + sql) 276 | } 277 | } 278 | } 279 | } 280 | 281 | func (mm *MigrateExecutor) modifyTableEngine(tableFromCode Table) { 282 | sql := "ALTER TABLE " 
+ tableFromCode.TableName.String + " Engine " + tableFromCode.Engine.String 283 | _, err := mm.Builder.RawSql(sql).Exec() 284 | if err != nil { 285 | fmt.Println(err) 286 | } else { 287 | fmt.Println("修改表:" + sql) 288 | } 289 | } 290 | 291 | func (mm *MigrateExecutor) modifyTableComment(tableFromCode Table) { 292 | sql := "ALTER TABLE " + tableFromCode.TableName.String + " Comment " + tableFromCode.TableComment.String 293 | _, err := mm.Builder.RawSql(sql).Exec() 294 | if err != nil { 295 | fmt.Println(err) 296 | } else { 297 | fmt.Println("修改表:" + sql) 298 | } 299 | } 300 | 301 | func (mm *MigrateExecutor) createTable(tableFromCode Table, columnsFromCode []Column, indexesFromCode []Index) { 302 | var fieldArr []string 303 | 304 | for i := 0; i < len(columnsFromCode); i++ { 305 | column := columnsFromCode[i] 306 | fieldArr = append(fieldArr, getColumnStr(column)) 307 | } 308 | 309 | for i := 0; i < len(indexesFromCode); i++ { 310 | index := indexesFromCode[i] 311 | fieldArr = append(fieldArr, getIndexStr(index)) 312 | } 313 | 314 | sql := "CREATE TABLE `" + tableFromCode.TableName.String + "` (\n" + strings.Join(fieldArr, ",\n") + "\n) " + " ENGINE " + tableFromCode.Engine.String + " COMMENT " + tableFromCode.TableComment.String + ";" 315 | _, err := mm.Builder.RawSql(sql).Exec() 316 | if err != nil { 317 | fmt.Println(err) 318 | } else { 319 | fmt.Println("创建表:" + tableFromCode.TableName.String) 320 | } 321 | } 322 | 323 | func getTagMap(fieldTag string) map[string]string { 324 | var fieldMap = make(map[string]string) 325 | if "" != fieldTag { 326 | tagArr := strings.Split(fieldTag, ";") 327 | for j := 0; j < len(tagArr); j++ { 328 | tagArrArr := strings.Split(tagArr[j], ":") 329 | fieldMap[tagArrArr[0]] = "" 330 | if len(tagArrArr) > 1 { 331 | fieldMap[tagArrArr[0]] = tagArrArr[1] 332 | } 333 | } 334 | } 335 | return fieldMap 336 | } 337 | 338 | func getColumnStr(column Column) string { 339 | var strArr []string 340 | strArr = append(strArr, 
column.ColumnName.String) 341 | if column.MaxLength.Int64 == 0 { 342 | if column.DataType.String == "varchar" { 343 | strArr = append(strArr, column.DataType.String+"(255)") 344 | } else { 345 | strArr = append(strArr, column.DataType.String) 346 | } 347 | } else { 348 | strArr = append(strArr, column.DataType.String+"("+strconv.Itoa(int(column.MaxLength.Int64))+")") 349 | } 350 | 351 | if column.ColumnDefault.String != "" { 352 | strArr = append(strArr, "DEFAULT '"+column.ColumnDefault.String+"'") 353 | } 354 | 355 | if column.IsNullable.String == "NO" { 356 | strArr = append(strArr, "NOT NULL") 357 | } 358 | 359 | if column.ColumnComment.String != "" { 360 | strArr = append(strArr, "COMMENT '"+column.ColumnComment.String+"'") 361 | } 362 | 363 | if column.Extra.String != "" { 364 | strArr = append(strArr, column.Extra.String) 365 | } 366 | 367 | return strings.Join(strArr, " ") 368 | } 369 | 370 | func getIndexStr(index Index) string { 371 | var strArr []string 372 | 373 | if "PRIMARY" == index.KeyName.String { 374 | strArr = append(strArr, index.KeyName.String) 375 | strArr = append(strArr, "KEY") 376 | strArr = append(strArr, "(`"+index.ColumnName.String+"`)") 377 | } else { 378 | if 0 == index.NonUnique.Int64 { 379 | strArr = append(strArr, "Unique") 380 | strArr = append(strArr, index.KeyName.String) 381 | strArr = append(strArr, "(`"+index.ColumnName.String+"`)") 382 | } else { 383 | strArr = append(strArr, "Index") 384 | strArr = append(strArr, index.KeyName.String) 385 | strArr = append(strArr, "(`"+index.ColumnName.String+"`)") 386 | } 387 | } 388 | 389 | return strings.Join(strArr, " ") 390 | } 391 | 392 | func getDataType(fieldType string, fieldMap map[string]string) string { 393 | var DataType string 394 | 395 | dataTypeVal, dataTypeOk := fieldMap["type"] 396 | if dataTypeOk { 397 | DataType = dataTypeVal 398 | } else { 399 | if "Int" == fieldType { 400 | DataType = "int" 401 | } 402 | if "String" == fieldType { 403 | DataType = "varchar" 404 | } 405 | 
if "Bool" == fieldType { 406 | DataType = "tinyint" 407 | } 408 | if "Time" == fieldType { 409 | DataType = "datetime" 410 | } 411 | if "Float" == fieldType { 412 | DataType = "float" 413 | } 414 | } 415 | 416 | return DataType 417 | } 418 | 419 | func getMaxLength(DataType string, fieldMap map[string]string) int { 420 | var MaxLength int 421 | 422 | maxLengthVal, maxLengthOk := fieldMap["size"] 423 | if maxLengthOk { 424 | num, _ := strconv.Atoi(maxLengthVal) 425 | MaxLength = num 426 | } else { 427 | MaxLength = 0 428 | if "varchar" == DataType { 429 | MaxLength = 255 430 | } 431 | } 432 | 433 | return MaxLength 434 | } 435 | 436 | func getNullAble(fieldMap map[string]string) string { 437 | var IsNullable string 438 | 439 | _, primaryOk := fieldMap["primary"] 440 | if primaryOk { 441 | IsNullable = "NO" 442 | } else { 443 | _, ok := fieldMap["not null"] 444 | if ok { 445 | IsNullable = "NO" 446 | } else { 447 | IsNullable = "YES" 448 | } 449 | } 450 | 451 | return IsNullable 452 | } 453 | 454 | func getComment(fieldMap map[string]string) string { 455 | commentVal, commentIs := fieldMap["comment"] 456 | if commentIs { 457 | return commentVal 458 | } 459 | 460 | return "" 461 | } 462 | 463 | func getExtra(fieldMap map[string]string) string { 464 | _, commentIs := fieldMap["auto_increment"] 465 | if commentIs { 466 | return "auto_increment" 467 | } 468 | 469 | return "" 470 | } 471 | 472 | func getDefaultVal(fieldMap map[string]string) string { 473 | defaultVal, defaultIs := fieldMap["default"] 474 | if defaultIs { 475 | return defaultVal 476 | } 477 | 478 | return "" 479 | } 480 | -------------------------------------------------------------------------------- /migrate_postgres/migrate.go: -------------------------------------------------------------------------------- 1 | package migrate_postgres 2 | 3 | import ( 4 | "fmt" 5 | "github.com/tangpanqing/aorm/builder" 6 | "github.com/tangpanqing/aorm/null" 7 | "github.com/tangpanqing/aorm/utils" 8 | "reflect" 9 | 
"regexp" 10 | "strconv" 11 | "strings" 12 | ) 13 | 14 | type PgIndexes struct { 15 | Schemaname null.String 16 | Tablename null.String 17 | Indexname null.String 18 | Tablespace null.String 19 | Indexdef null.String 20 | } 21 | 22 | type Table struct { 23 | TableName null.String 24 | TableComment null.String 25 | } 26 | 27 | type Column struct { 28 | ColumnName null.String 29 | ColumnDefault null.String 30 | IsNullable null.String 31 | DataType null.String //数据类型 varchar,bigint,int 32 | MaxLength null.Int //数据最大长度 20 33 | ColumnComment null.String 34 | Extra null.String //扩展信息 auto_increment 35 | } 36 | 37 | type Index struct { 38 | NonUnique null.Int 39 | ColumnName null.String 40 | KeyName null.String 41 | } 42 | 43 | //MigrateExecutor 定义结构 44 | type MigrateExecutor struct { 45 | //执行者 46 | Builder *builder.Builder 47 | } 48 | 49 | //ShowCreateTable 查看创建表的ddl 50 | func (mm *MigrateExecutor) ShowCreateTable(tableName string) string { 51 | var str string 52 | mm.Builder.RawSql("show create table "+tableName).Value("Create Table", &str) 53 | return str 54 | } 55 | 56 | //MigrateCommon 迁移的主要过程 57 | func (mm *MigrateExecutor) MigrateCommon(tableName string, typeOf reflect.Type, valueOf reflect.Value) error { 58 | tableFromCode := mm.getTableFromCode(tableName, typeOf, valueOf) 59 | columnsFromCode := mm.getColumnsFromCode(typeOf) 60 | indexesFromCode := mm.getIndexesFromCode(typeOf, tableFromCode) 61 | 62 | dbName, dbErr := mm.getDbName() 63 | if dbErr != nil { 64 | return dbErr 65 | } 66 | 67 | tablesFromDb := mm.getTableFromDb(dbName, tableName) 68 | if len(tablesFromDb) != 0 { 69 | tableFromDb := tablesFromDb[0] 70 | columnsFromDb := mm.getColumnsFromDb(dbName, tableName) 71 | indexesFromDb := mm.getIndexesFromDb(tableName) 72 | 73 | mm.modifyTable(tableFromCode, columnsFromCode, indexesFromCode, tableFromDb, columnsFromDb, indexesFromDb) 74 | } else { 75 | mm.createTable(tableFromCode, columnsFromCode, indexesFromCode) 76 | } 77 | 78 | return nil 79 | } 80 | 81 | 
func (mm *MigrateExecutor) getTableFromCode(tableName string, typeOf reflect.Type, valueOf reflect.Value) Table { 82 | table := Table{ 83 | TableName: null.StringFrom(tableName), 84 | TableComment: null.StringFrom("''"), 85 | } 86 | 87 | method, isSet := typeOf.MethodByName("TableOpinion") 88 | if isSet { 89 | var paramList []reflect.Value 90 | paramList = append(paramList, valueOf) 91 | valueList := method.Func.Call(paramList) 92 | i := valueList[0].Interface() 93 | m := i.(map[string]string) 94 | 95 | m["COMMENT"] = "'" + m["COMMENT"] + "'" 96 | table.TableComment = null.StringFrom(m["COMMENT"]) 97 | } 98 | 99 | return table 100 | } 101 | 102 | func (mm *MigrateExecutor) getColumnsFromCode(typeOf reflect.Type) []Column { 103 | var columnsFromCode []Column 104 | for i := 0; i < typeOf.Elem().NumField(); i++ { 105 | fieldName := utils.UnderLine(typeOf.Elem().Field(i).Name) 106 | fieldType := typeOf.Elem().Field(i).Type.Name() 107 | fieldMap := getTagMap(typeOf.Elem().Field(i).Tag.Get("aorm")) 108 | columnsFromCode = append(columnsFromCode, Column{ 109 | ColumnName: null.StringFrom(fieldName), 110 | DataType: null.StringFrom(getDataType(fieldType, fieldMap)), 111 | MaxLength: null.IntFrom(int64(getMaxLength(getDataType(fieldType, fieldMap), fieldMap))), 112 | IsNullable: null.StringFrom(getNullAble(fieldMap)), 113 | ColumnComment: null.StringFrom(getComment(fieldMap)), 114 | Extra: null.StringFrom(getExtra(fieldMap)), 115 | ColumnDefault: null.StringFrom(getDefaultVal(fieldMap)), 116 | }) 117 | } 118 | 119 | return columnsFromCode 120 | } 121 | 122 | func (mm *MigrateExecutor) getIndexesFromCode(typeOf reflect.Type, tableFromCode Table) []Index { 123 | var indexesFromCode []Index 124 | for i := 0; i < typeOf.Elem().NumField(); i++ { 125 | fieldName := utils.UnderLine(typeOf.Elem().Field(i).Name) 126 | fieldMap := getTagMap(typeOf.Elem().Field(i).Tag.Get("aorm")) 127 | 128 | _, primaryIs := fieldMap["primary"] 129 | if primaryIs { 130 | indexesFromCode = 
append(indexesFromCode, Index{ 131 | NonUnique: null.IntFrom(0), 132 | ColumnName: null.StringFrom(fieldName), 133 | KeyName: null.StringFrom("PRIMARY"), 134 | }) 135 | } 136 | 137 | _, uniqueIndexIs := fieldMap["unique"] 138 | if uniqueIndexIs { 139 | indexesFromCode = append(indexesFromCode, Index{ 140 | NonUnique: null.IntFrom(0), 141 | ColumnName: null.StringFrom(fieldName), 142 | KeyName: null.StringFrom("idx_" + tableFromCode.TableName.String + "_" + fieldName), 143 | }) 144 | } 145 | 146 | _, indexIs := fieldMap["index"] 147 | if indexIs { 148 | indexesFromCode = append(indexesFromCode, Index{ 149 | NonUnique: null.IntFrom(1), 150 | ColumnName: null.StringFrom(fieldName), 151 | KeyName: null.StringFrom("idx_" + tableFromCode.TableName.String + "_" + fieldName), 152 | }) 153 | } 154 | } 155 | 156 | return indexesFromCode 157 | } 158 | 159 | func (mm *MigrateExecutor) getDbName() (string, error) { 160 | //获取数据库名称 161 | var dbName string 162 | err := mm.Builder.RawSql("select current_database()").Value("current_database", &dbName) 163 | if err != nil { 164 | return "", err 165 | } 166 | 167 | return dbName, nil 168 | } 169 | 170 | func (mm *MigrateExecutor) getTableFromDb(dbName string, tableName string) []Table { 171 | sql := "select a.relname as TABLE_NAME, b.description as TABLE_COMMENT from pg_class a left join (select * from pg_description where objsubid =0) b on a.oid = b.objoid where a.relname in (select tablename from pg_tables where schemaname = 'public' and tablename = " + "'" + tableName + "') order by a.relname asc" 172 | var dataList []Table 173 | mm.Builder.RawSql(sql).GetMany(&dataList) 174 | for i := 0; i < len(dataList); i++ { 175 | dataList[i].TableComment = null.StringFrom("'" + dataList[i].TableComment.String + "'") 176 | } 177 | 178 | return dataList 179 | } 180 | 181 | func (mm *MigrateExecutor) getColumnsFromDb(dbName string, tableName string) []Column { 182 | var columnsFromDb []Column 183 | 184 | sqlColumn := "select 
column_name,data_type,character_maximum_length as max_length,column_default,'' as COLUMN_COMMENT, is_nullable from information_schema.columns where table_schema='public' and table_name=" + "'" + tableName + "'" 185 | 186 | mm.Builder.RawSql(sqlColumn).GetMany(&columnsFromDb) 187 | 188 | for j := 0; j < len(columnsFromDb); j++ { 189 | if columnsFromDb[j].DataType.String == "character varying" { 190 | columnsFromDb[j].DataType = null.StringFrom("varchar") 191 | } 192 | 193 | if columnsFromDb[j].DataType.String == "double precision" { 194 | columnsFromDb[j].DataType = null.StringFrom("float") 195 | } 196 | 197 | if columnsFromDb[j].DataType.String == "timestamp without time zone" { 198 | columnsFromDb[j].DataType = null.StringFrom("timestamp") 199 | } 200 | } 201 | 202 | return columnsFromDb 203 | } 204 | 205 | func (mm *MigrateExecutor) getIndexesFromDb(tableName string) []Index { 206 | sqlIndex := "select * from pg_indexes where tablename=" + "'" + tableName + "'" 207 | var sqliteMasterList []PgIndexes 208 | mm.Builder.RawSql(sqlIndex).GetMany(&sqliteMasterList) 209 | 210 | var indexesFromDb []Index 211 | for i := 0; i < len(sqliteMasterList); i++ { 212 | indexName := sqliteMasterList[i].Indexname.String 213 | sql := sqliteMasterList[i].Indexdef.String 214 | 215 | t := 1 216 | if strings.Index(sql, "UNIQUE") != -1 { 217 | t = 0 218 | } 219 | 220 | compileRegex := regexp.MustCompile("INDEX\\s(.*?)\\sON.*?\\((.*?)\\)") 221 | matchArr := compileRegex.FindAllStringSubmatch(sql, -1) 222 | 223 | //主键索引 224 | if indexName == tableName+"_pkey" { 225 | indexesFromDb = append(indexesFromDb, Index{ 226 | NonUnique: null.IntFrom(int64(t)), 227 | ColumnName: null.StringFrom(matchArr[0][2]), 228 | KeyName: null.StringFrom("PRIMARY"), 229 | }) 230 | } else { 231 | indexesFromDb = append(indexesFromDb, Index{ 232 | NonUnique: null.IntFrom(int64(t)), 233 | ColumnName: null.StringFrom(matchArr[0][2]), 234 | KeyName: null.StringFrom(matchArr[0][1]), 235 | }) 236 | } 237 | } 238 | 239 
| return indexesFromDb 240 | } 241 | 242 | func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Column, indexesFromCode []Index, tableFromDb Table, columnsFromDb []Column, indexesFromDb []Index) { 243 | //if tableFromCode.TableComment != tableFromDb.TableComment { 244 | // mm.modifyTableComment(tableFromCode) 245 | //} 246 | 247 | for i := 0; i < len(columnsFromCode); i++ { 248 | isFind := 0 249 | columnCode := columnsFromCode[i] 250 | 251 | for j := 0; j < len(columnsFromDb); j++ { 252 | columnDb := columnsFromDb[j] 253 | if columnCode.ColumnName.String == columnDb.ColumnName.String { 254 | isFind = 1 255 | if columnCode.DataType.String != columnDb.DataType.String { 256 | fmt.Println(columnCode.ColumnName.String, columnCode.DataType.String, columnDb.DataType.String) 257 | 258 | sql := "ALTER TABLE " + tableFromCode.TableName.String + " alter COLUMN " + getColumnStr(columnCode, "driver") 259 | //fmt.Println(base) 260 | 261 | _, err := mm.Builder.RawSql(sql).Exec() 262 | if err != nil { 263 | fmt.Println(err) 264 | } else { 265 | fmt.Println("修改属性:" + sql) 266 | } 267 | } 268 | } 269 | } 270 | 271 | if isFind == 0 { 272 | sql := "ALTER TABLE " + tableFromCode.TableName.String + " ADD " + getColumnStr(columnCode, "") 273 | _, err := mm.Builder.RawSql(sql).Exec() 274 | if err != nil { 275 | fmt.Println(err) 276 | } else { 277 | fmt.Println("增加属性:" + sql) 278 | } 279 | } 280 | } 281 | 282 | for i := 0; i < len(indexesFromCode); i++ { 283 | isFind := 0 284 | indexCode := indexesFromCode[i] 285 | 286 | for j := 0; j < len(indexesFromDb); j++ { 287 | indexDb := indexesFromDb[j] 288 | if indexCode.ColumnName == indexDb.ColumnName { 289 | isFind = 1 290 | if indexCode.KeyName != indexDb.KeyName || indexCode.NonUnique != indexDb.NonUnique { 291 | sql := "ALTER TABLE " + tableFromCode.TableName.String + " MODIFY " + getIndexStr(indexCode) 292 | _, err := mm.Builder.RawSql(sql).Exec() 293 | if err != nil { 294 | fmt.Println(err) 295 | } else { 296 | 
fmt.Println("修改索引:" + sql) 297 | } 298 | } 299 | } 300 | } 301 | 302 | if isFind == 0 { 303 | mm.createIndex(tableFromCode.TableName.String, indexCode) 304 | } 305 | } 306 | } 307 | 308 | func (mm *MigrateExecutor) modifyTableComment(tableFromCode Table) { 309 | sql := "ALTER TABLE " + tableFromCode.TableName.String + " Comment " + tableFromCode.TableComment.String 310 | _, err := mm.Builder.RawSql(sql).Exec() 311 | if err != nil { 312 | fmt.Println(err) 313 | } else { 314 | fmt.Println("修改表:" + sql) 315 | } 316 | } 317 | 318 | func (mm *MigrateExecutor) createTable(tableFromCode Table, columnsFromCode []Column, indexesFromCode []Index) { 319 | var fieldArr []string 320 | 321 | for i := 0; i < len(columnsFromCode); i++ { 322 | column := columnsFromCode[i] 323 | fieldArr = append(fieldArr, getColumnStr(column, "")) 324 | } 325 | 326 | for i := 0; i < len(indexesFromCode); i++ { 327 | index := indexesFromCode[i] 328 | if index.KeyName.String == "PRIMARY" { 329 | fieldArr = append(fieldArr, "PRIMARY KEY ("+index.ColumnName.String+")") 330 | } 331 | } 332 | 333 | sql := "CREATE TABLE " + tableFromCode.TableName.String + " (\n" + strings.Join(fieldArr, ",\n") + "\n) " + ";" 334 | 335 | _, err := mm.Builder.RawSql(sql).Exec() 336 | if err != nil { 337 | fmt.Println(err) 338 | } else { 339 | fmt.Println("创建表:" + tableFromCode.TableName.String) 340 | } 341 | 342 | //创建其他索引 343 | for i := 0; i < len(indexesFromCode); i++ { 344 | index := indexesFromCode[i] 345 | if index.KeyName.String != "PRIMARY" { 346 | mm.createIndex(tableFromCode.TableName.String, index) 347 | } 348 | } 349 | } 350 | 351 | func (mm *MigrateExecutor) createIndex(tableName string, index Index) { 352 | keyType := "" 353 | if index.NonUnique.Int64 == 0 { 354 | keyType = "UNIQUE" 355 | } 356 | 357 | sql := "CREATE " + keyType + " INDEX " + index.KeyName.String + " on " + tableName + " (" + index.ColumnName.String + ")" 358 | _, err := mm.Builder.RawSql(sql).Exec() 359 | if err != nil { 360 | 
fmt.Println(err) 361 | } else { 362 | fmt.Println("增加索引:" + sql) 363 | } 364 | } 365 | 366 | func getTagMap(fieldTag string) map[string]string { 367 | var fieldMap = make(map[string]string) 368 | if "" != fieldTag { 369 | tagArr := strings.Split(fieldTag, ";") 370 | for j := 0; j < len(tagArr); j++ { 371 | tagArrArr := strings.Split(tagArr[j], ":") 372 | fieldMap[tagArrArr[0]] = "" 373 | if len(tagArrArr) > 1 { 374 | fieldMap[tagArrArr[0]] = tagArrArr[1] 375 | } 376 | } 377 | } 378 | return fieldMap 379 | } 380 | 381 | func getColumnStr(column Column, f string) string { 382 | var strArr []string 383 | strArr = append(strArr, column.ColumnName.String) 384 | 385 | //类型 386 | if column.Extra.String == "auto_increment" { 387 | strArr = append(strArr, "serial") 388 | } else { 389 | if column.MaxLength.Int64 == 0 { 390 | if column.DataType.String == "varchar" { 391 | strArr = append(strArr, column.DataType.String+"(255)") 392 | } else { 393 | strArr = append(strArr, f+" "+column.DataType.String) 394 | } 395 | } else { 396 | strArr = append(strArr, column.DataType.String+"("+strconv.Itoa(int(column.MaxLength.Int64))+")") 397 | } 398 | } 399 | 400 | if column.ColumnDefault.String != "" { 401 | strArr = append(strArr, "DEFAULT '"+column.ColumnDefault.String+"'") 402 | } 403 | 404 | if column.IsNullable.String == "NO" { 405 | //strArr = append(strArr, "NOT NULL") 406 | } 407 | 408 | if column.ColumnComment.String != "" { 409 | //strArr = append(strArr, "COMMENT '"+column.ColumnComment.String+"'") 410 | } 411 | 412 | if column.Extra.String != "" { 413 | //strArr = append(strArr, column.Extra.String) 414 | } 415 | 416 | return strings.Join(strArr, " ") 417 | } 418 | 419 | func getIndexStr(index Index) string { 420 | var strArr []string 421 | 422 | if "PRIMARY" == index.KeyName.String { 423 | strArr = append(strArr, index.KeyName.String) 424 | strArr = append(strArr, "KEY") 425 | strArr = append(strArr, "("+index.ColumnName.String+")") 426 | } else { 427 | if 0 == 
index.NonUnique.Int64 { 428 | strArr = append(strArr, "Unique") 429 | strArr = append(strArr, index.KeyName.String) 430 | strArr = append(strArr, "("+index.ColumnName.String+")") 431 | } else { 432 | strArr = append(strArr, "Index") 433 | strArr = append(strArr, index.KeyName.String) 434 | strArr = append(strArr, "("+index.ColumnName.String+")") 435 | } 436 | } 437 | 438 | return strings.Join(strArr, " ") 439 | } 440 | 441 | func getDataType(fieldType string, fieldMap map[string]string) string { 442 | var DataType string 443 | 444 | dataTypeVal, dataTypeOk := fieldMap["type"] 445 | if dataTypeOk { 446 | DataType = dataTypeVal 447 | if "tinyint" == DataType { 448 | DataType = "integer" 449 | } 450 | if "double" == DataType { 451 | DataType = "float" 452 | } 453 | } else { 454 | if "Int" == fieldType { 455 | DataType = "integer" 456 | } 457 | if "String" == fieldType { 458 | DataType = "varchar" 459 | } 460 | if "Bool" == fieldType { 461 | //DataType = "tinyint" 462 | DataType = "boolean" 463 | } 464 | if "Time" == fieldType { 465 | DataType = "date" 466 | DataType = "timestamp" 467 | } 468 | if "Float" == fieldType { 469 | DataType = "float" 470 | } 471 | } 472 | 473 | return DataType 474 | } 475 | 476 | func getMaxLength(DataType string, fieldMap map[string]string) int { 477 | var MaxLength int 478 | 479 | maxLengthVal, maxLengthOk := fieldMap["size"] 480 | if maxLengthOk { 481 | num, _ := strconv.Atoi(maxLengthVal) 482 | MaxLength = num 483 | } else { 484 | MaxLength = 0 485 | if "varchar" == DataType { 486 | MaxLength = 255 487 | } 488 | } 489 | 490 | return MaxLength 491 | } 492 | 493 | func getNullAble(fieldMap map[string]string) string { 494 | var IsNullable string 495 | 496 | _, primaryOk := fieldMap["primary"] 497 | if primaryOk { 498 | IsNullable = "NO" 499 | } else { 500 | _, ok := fieldMap["not null"] 501 | if ok { 502 | IsNullable = "NO" 503 | } else { 504 | IsNullable = "YES" 505 | } 506 | } 507 | 508 | return IsNullable 509 | } 510 | 511 | func 
getComment(fieldMap map[string]string) string { 512 | commentVal, commentIs := fieldMap["comment"] 513 | if commentIs { 514 | return commentVal 515 | } 516 | 517 | return "" 518 | } 519 | 520 | func getExtra(fieldMap map[string]string) string { 521 | _, commentIs := fieldMap["auto_increment"] 522 | if commentIs { 523 | return "auto_increment" 524 | } 525 | 526 | return "" 527 | } 528 | 529 | func getDefaultVal(fieldMap map[string]string) string { 530 | defaultVal, defaultIs := fieldMap["default"] 531 | if defaultIs { 532 | return defaultVal 533 | } 534 | 535 | return "" 536 | } 537 | -------------------------------------------------------------------------------- /migrate_sqlite3/migrate.go: -------------------------------------------------------------------------------- 1 | package migrate_sqlite3 2 | 3 | import ( 4 | "fmt" 5 | "github.com/tangpanqing/aorm/builder" 6 | "github.com/tangpanqing/aorm/null" 7 | "github.com/tangpanqing/aorm/utils" 8 | "reflect" 9 | "regexp" 10 | "strconv" 11 | "strings" 12 | ) 13 | 14 | type SqliteMaster struct { 15 | Type null.String 16 | Name null.String 17 | TblName null.String 18 | Rootpage null.String 19 | Sql null.String 20 | } 21 | 22 | type Table struct { 23 | TableName null.String 24 | } 25 | 26 | type Column struct { 27 | ColumnName null.String 28 | ColumnDefault null.String 29 | IsNullable null.String 30 | DataType null.String //数据类型 varchar,bigint,int 31 | MaxLength null.Int //数据最大长度 20 32 | Extra null.String //扩展信息 auto_increment 33 | } 34 | 35 | type Index struct { 36 | NonUnique null.Int 37 | ColumnName null.String 38 | KeyName null.String 39 | } 40 | 41 | //MigrateExecutor 定义结构 42 | type MigrateExecutor struct { 43 | //执行者 44 | Builder *builder.Builder 45 | } 46 | 47 | //ShowCreateTable 查看创建表的ddl 48 | func (mm *MigrateExecutor) ShowCreateTable(tableName string) string { 49 | var str string 50 | mm.Builder.RawSql("show create table "+tableName).Value("Create Table", &str) 51 | return str 52 | } 53 | 54 | //MigrateCommon 
迁移的主要过程 55 | func (mm *MigrateExecutor) MigrateCommon(tableName string, typeOf reflect.Type) error { 56 | tableFromCode := mm.getTableFromCode(tableName) 57 | columnsFromCode := mm.getColumnsFromCode(typeOf) 58 | indexesFromCode := mm.getIndexesFromCode(typeOf, tableFromCode) 59 | 60 | dbName, dbErr := mm.getDbName() 61 | if dbErr != nil { 62 | return dbErr 63 | } 64 | 65 | tablesFromDb := mm.getTableFromDb(dbName, tableName) 66 | if len(tablesFromDb) != 0 { 67 | tableFromDb := tablesFromDb[0] 68 | columnsFromDb := mm.getColumnsFromDb(dbName, tableName) 69 | indexesFromDb := mm.getIndexesFromDb(tableName) 70 | 71 | mm.modifyTable(tableFromCode, columnsFromCode, indexesFromCode, tableFromDb, columnsFromDb, indexesFromDb) 72 | } else { 73 | mm.createTable(tableFromCode, columnsFromCode, indexesFromCode) 74 | } 75 | 76 | return nil 77 | } 78 | 79 | func (mm *MigrateExecutor) getTableFromCode(tableName string) Table { 80 | var tableFromCode Table 81 | tableFromCode.TableName = null.StringFrom(tableName) 82 | 83 | return tableFromCode 84 | } 85 | 86 | func (mm *MigrateExecutor) getColumnsFromCode(typeOf reflect.Type) []Column { 87 | var columnsFromCode []Column 88 | for i := 0; i < typeOf.Elem().NumField(); i++ { 89 | fieldName := utils.UnderLine(typeOf.Elem().Field(i).Name) 90 | fieldType := typeOf.Elem().Field(i).Type.Name() 91 | fieldMap := getTagMap(typeOf.Elem().Field(i).Tag.Get("aorm")) 92 | columnsFromCode = append(columnsFromCode, Column{ 93 | ColumnName: null.StringFrom(fieldName), 94 | DataType: null.StringFrom(getDataType(fieldType, fieldMap)), 95 | MaxLength: null.IntFrom(int64(getMaxLength(getDataType(fieldType, fieldMap), fieldMap))), 96 | IsNullable: null.StringFrom(getNullAble(fieldMap)), 97 | Extra: null.StringFrom(getExtra(fieldMap)), 98 | ColumnDefault: null.StringFrom(getDefaultVal(fieldMap)), 99 | }) 100 | } 101 | 102 | return columnsFromCode 103 | } 104 | 105 | func (mm *MigrateExecutor) getIndexesFromCode(typeOf reflect.Type, tableFromCode Table) 
[]Index { 106 | var indexesFromCode []Index 107 | for i := 0; i < typeOf.Elem().NumField(); i++ { 108 | fieldName := utils.UnderLine(typeOf.Elem().Field(i).Name) 109 | fieldMap := getTagMap(typeOf.Elem().Field(i).Tag.Get("aorm")) 110 | 111 | _, primaryIs := fieldMap["primary"] 112 | if primaryIs { 113 | indexesFromCode = append(indexesFromCode, Index{ 114 | NonUnique: null.IntFrom(0), 115 | ColumnName: null.StringFrom(fieldName), 116 | KeyName: null.StringFrom("PRIMARY"), 117 | }) 118 | } 119 | 120 | _, uniqueIndexIs := fieldMap["unique"] 121 | if uniqueIndexIs { 122 | indexesFromCode = append(indexesFromCode, Index{ 123 | NonUnique: null.IntFrom(0), 124 | ColumnName: null.StringFrom(fieldName), 125 | KeyName: null.StringFrom("idx_" + tableFromCode.TableName.String + "_" + fieldName), 126 | }) 127 | } 128 | 129 | _, indexIs := fieldMap["index"] 130 | if indexIs { 131 | indexesFromCode = append(indexesFromCode, Index{ 132 | NonUnique: null.IntFrom(1), 133 | ColumnName: null.StringFrom(fieldName), 134 | KeyName: null.StringFrom("idx_" + tableFromCode.TableName.String + "_" + fieldName), 135 | }) 136 | } 137 | } 138 | 139 | return indexesFromCode 140 | } 141 | 142 | func (mm *MigrateExecutor) getDbName() (string, error) { 143 | return "main", nil 144 | } 145 | 146 | func (mm *MigrateExecutor) getTableFromDb(dbName string, tableName string) []Table { 147 | query := "select * from sqlite_master where type='table' and tbl_name=" + "'" + tableName + "'" 148 | var sqliteMasterList []SqliteMaster 149 | mm.Builder.RawSql(query).GetMany(&sqliteMasterList) 150 | 151 | var dataList []Table 152 | for i := 0; i < len(sqliteMasterList); i++ { 153 | dataList = append(dataList, Table{ 154 | TableName: null.StringFrom(sqliteMasterList[i].TblName.String), 155 | }) 156 | } 157 | 158 | return dataList 159 | } 160 | 161 | func (mm *MigrateExecutor) getColumnsFromDb(dbName string, tableName string) []Column { 162 | var columnsFromDb []Column 163 | 164 | var sqliteMaster SqliteMaster 165 | 
sqlColumn1 := "select * from sqlite_master where type='table' and tbl_name = " + "'" + tableName + "'" 166 | mm.Builder.RawSql(sqlColumn1).GetOne(&sqliteMaster) 167 | 168 | str := sqliteMaster.Sql.String 169 | str = strings.ReplaceAll(str, "\n", "") 170 | compileRegex := regexp.MustCompile("\\(.*\\)") 171 | matchArr := compileRegex.FindAllString(str, -1) 172 | matchArr[0] = strings.TrimLeft(matchArr[0], "(") 173 | matchArr[0] = strings.TrimRight(matchArr[0], ")") 174 | 175 | strArr := strings.Split(matchArr[0], ",") 176 | for i := 0; i < len(strArr); i++ { 177 | columnStr := strings.TrimSpace(strArr[i]) 178 | columnStr = strings.Replace(columnStr, "NOT NULL", "NOT_NULL", -1) 179 | columnArr := strings.Split(columnStr, " ") 180 | 181 | columnName := columnArr[0] 182 | dataType := columnArr[1] 183 | IsNullable := "YES" 184 | if len(columnArr) >= 3 { 185 | IsNullable = "NO" 186 | } 187 | 188 | columnsFromDb = append(columnsFromDb, Column{ 189 | ColumnName: null.StringFrom(columnName), 190 | DataType: null.StringFrom(dataType), 191 | IsNullable: null.StringFrom(IsNullable), 192 | }) 193 | } 194 | 195 | return columnsFromDb 196 | } 197 | 198 | func (mm *MigrateExecutor) getIndexesFromDb(tableName string) []Index { 199 | sqlIndex := "select * from sqlite_master where type = 'index' and name not like '%sqlite_autoindex%' and tbl_name=" + "'" + tableName + "'" 200 | var sqliteMasterList []SqliteMaster 201 | mm.Builder.RawSql(sqlIndex).GetMany(&sqliteMasterList) 202 | 203 | var indexesFromDb []Index 204 | for i := 0; i < len(sqliteMasterList); i++ { 205 | sql := sqliteMasterList[i].Sql.String 206 | 207 | t := 1 208 | if strings.Index(sql, "UNIQUE") != -1 { 209 | t = 0 210 | } 211 | 212 | compileRegex := regexp.MustCompile("INDEX\\s(.*?)\\son.*?\\((.*?)\\)") 213 | matchArr := compileRegex.FindAllStringSubmatch(sql, -1) 214 | 215 | indexesFromDb = append(indexesFromDb, Index{ 216 | NonUnique: null.IntFrom(int64(t)), 217 | ColumnName: null.StringFrom(matchArr[0][2]), 218 | 
KeyName: null.StringFrom(matchArr[0][1]), 219 | }) 220 | } 221 | 222 | //查询是否有主键索引 223 | sql := "select * from sqlite_master where type='table' and tbl_name=" + "'" + tableName + "'" 224 | var sqliteMaster SqliteMaster 225 | mm.Builder.RawSql(sql).GetOne(&sqliteMaster) 226 | 227 | compileRegex := regexp.MustCompile("PRIMARY\\sKEY\\s\\((.*?)\\)") 228 | matchArr2 := compileRegex.FindAllStringSubmatch(sqliteMaster.Sql.String, -1) 229 | if len(matchArr2) > 0 { 230 | indexesFromDb = append(indexesFromDb, Index{ 231 | NonUnique: null.IntFrom(0), 232 | ColumnName: null.StringFrom(matchArr2[0][1]), 233 | KeyName: null.StringFrom("PRIMARY"), 234 | }) 235 | } 236 | 237 | return indexesFromDb 238 | } 239 | 240 | func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Column, indexesFromCode []Index, tableFromDb Table, columnsFromDb []Column, indexesFromDb []Index) { 241 | for i := 0; i < len(columnsFromCode); i++ { 242 | isFind := 0 243 | columnCode := columnsFromCode[i] 244 | 245 | for j := 0; j < len(columnsFromDb); j++ { 246 | columnDb := columnsFromDb[j] 247 | if columnCode.ColumnName == columnDb.ColumnName { 248 | isFind = 1 249 | 250 | if columnCode.DataType.String != columnDb.DataType.String || 251 | columnCode.ColumnDefault.String != columnDb.ColumnDefault.String { 252 | 253 | query := "ALTER TABLE " + tableFromCode.TableName.String + " MODIFY " + getColumnStr(columnCode) 254 | _, err := mm.Builder.RawSql(query).Exec() 255 | if err != nil { 256 | fmt.Println(query) 257 | fmt.Println(err) 258 | } else { 259 | fmt.Println("修改属性:" + query) 260 | } 261 | } 262 | } 263 | } 264 | 265 | if isFind == 0 { 266 | query := "ALTER TABLE " + tableFromCode.TableName.String + " ADD " + getColumnStr(columnCode) 267 | _, err := mm.Builder.RawSql(query).Exec() 268 | if err != nil { 269 | fmt.Println(query) 270 | fmt.Println(err) 271 | } else { 272 | fmt.Println("增加属性:" + query) 273 | } 274 | } 275 | } 276 | 277 | for i := 0; i < len(indexesFromCode); i++ { 278 | 
isFind := 0 279 | indexCode := indexesFromCode[i] 280 | 281 | for j := 0; j < len(indexesFromDb); j++ { 282 | indexDb := indexesFromDb[j] 283 | if indexCode.ColumnName == indexDb.ColumnName { 284 | isFind = 1 285 | if indexCode.KeyName != indexDb.KeyName || indexCode.NonUnique != indexDb.NonUnique { 286 | query := "ALTER TABLE " + tableFromCode.TableName.String + " MODIFY " + getIndexStr(indexCode) 287 | _, err := mm.Builder.RawSql(query).Exec() 288 | if err != nil { 289 | fmt.Println(err) 290 | } else { 291 | fmt.Println("修改索引:" + query) 292 | } 293 | } 294 | } 295 | } 296 | 297 | if isFind == 0 { 298 | mm.createIndex(tableFromCode.TableName.String, indexCode) 299 | } 300 | } 301 | } 302 | 303 | func (mm *MigrateExecutor) createTable(tableFromCode Table, columnsFromCode []Column, indexesFromCode []Index) { 304 | var fieldArr []string 305 | 306 | for i := 0; i < len(columnsFromCode); i++ { 307 | column := columnsFromCode[i] 308 | fieldArr = append(fieldArr, getColumnStr(column)) 309 | } 310 | 311 | for i := 0; i < len(indexesFromCode); i++ { 312 | index := indexesFromCode[i] 313 | if index.KeyName.String == "PRIMARY" { 314 | fieldArr = append(fieldArr, "PRIMARY KEY ("+index.ColumnName.String+")") 315 | } 316 | } 317 | 318 | //创建表结构与主键索引 319 | sql := "CREATE TABLE `" + tableFromCode.TableName.String + "` (\n" + strings.Join(fieldArr, ",\n") + "\n) " + ";" 320 | _, err := mm.Builder.RawSql(sql).Exec() 321 | if err != nil { 322 | fmt.Println(err) 323 | } else { 324 | fmt.Println("创建表:" + tableFromCode.TableName.String) 325 | } 326 | 327 | //创建其他索引 328 | for i := 0; i < len(indexesFromCode); i++ { 329 | index := indexesFromCode[i] 330 | if index.KeyName.String != "PRIMARY" { 331 | mm.createIndex(tableFromCode.TableName.String, index) 332 | } 333 | } 334 | } 335 | 336 | func (mm *MigrateExecutor) createIndex(tableName string, index Index) { 337 | keyType := "" 338 | if index.NonUnique.Int64 == 0 { 339 | keyType = "UNIQUE" 340 | } 341 | 342 | sql := "CREATE " + keyType + 
" INDEX " + index.KeyName.String + " on " + tableName + " (" + index.ColumnName.String + ")" 343 | _, err := mm.Builder.RawSql(sql).Exec() 344 | if err != nil { 345 | fmt.Println(err) 346 | } else { 347 | fmt.Println("增加索引:" + sql) 348 | } 349 | } 350 | 351 | func getTagMap(fieldTag string) map[string]string { 352 | var fieldMap = make(map[string]string) 353 | if "" != fieldTag { 354 | tagArr := strings.Split(fieldTag, ";") 355 | for j := 0; j < len(tagArr); j++ { 356 | tagArrArr := strings.Split(tagArr[j], ":") 357 | fieldMap[tagArrArr[0]] = "" 358 | if len(tagArrArr) > 1 { 359 | fieldMap[tagArrArr[0]] = tagArrArr[1] 360 | } 361 | } 362 | } 363 | return fieldMap 364 | } 365 | 366 | func getColumnStr(column Column) string { 367 | var strArr []string 368 | strArr = append(strArr, column.ColumnName.String) 369 | 370 | //类型 371 | if column.MaxLength.Int64 == 0 { 372 | if column.DataType.String == "varchar" { 373 | strArr = append(strArr, column.DataType.String+"(255)") 374 | } else { 375 | strArr = append(strArr, column.DataType.String) 376 | } 377 | } else { 378 | strArr = append(strArr, column.DataType.String) 379 | } 380 | 381 | if column.ColumnDefault.String != "" { 382 | strArr = append(strArr, "DEFAULT '"+column.ColumnDefault.String+"'") 383 | } 384 | 385 | if column.IsNullable.String == "NO" { 386 | strArr = append(strArr, "NOT NULL") 387 | } 388 | 389 | if column.Extra.String != "" { 390 | if column.Extra.String == "auto_increment" { 391 | column.Extra.String = "AUTOINCREMENT" 392 | } 393 | 394 | //strArr = append(strArr, column.Extra.String) 395 | } 396 | 397 | return strings.Join(strArr, " ") 398 | } 399 | 400 | func getIndexStr(index Index) string { 401 | var strArr []string 402 | 403 | if "PRIMARY" == index.KeyName.String { 404 | strArr = append(strArr, index.KeyName.String) 405 | strArr = append(strArr, "KEY") 406 | strArr = append(strArr, "(`"+index.ColumnName.String+"`)") 407 | } else { 408 | if 0 == index.NonUnique.Int64 { 409 | strArr = append(strArr, 
"Unique") 410 | strArr = append(strArr, index.KeyName.String) 411 | strArr = append(strArr, "(`"+index.ColumnName.String+"`)") 412 | } else { 413 | strArr = append(strArr, "Index") 414 | strArr = append(strArr, index.KeyName.String) 415 | strArr = append(strArr, "(`"+index.ColumnName.String+"`)") 416 | } 417 | } 418 | 419 | return strings.Join(strArr, " ") 420 | } 421 | 422 | func getDataType(fieldType string, fieldMap map[string]string) string { 423 | var DataType string 424 | 425 | dataTypeVal, dataTypeOk := fieldMap["type"] 426 | if dataTypeOk { 427 | DataType = dataTypeVal 428 | } else { 429 | if "Int" == fieldType { 430 | DataType = "INTEGER" 431 | } 432 | if "String" == fieldType { 433 | DataType = "TEXT" 434 | } 435 | if "Bool" == fieldType { 436 | DataType = "INTEGER" 437 | } 438 | if "Time" == fieldType { 439 | DataType = "datetime" 440 | } 441 | if "Float" == fieldType { 442 | DataType = "REAL" 443 | } 444 | } 445 | 446 | return DataType 447 | } 448 | 449 | func getMaxLength(DataType string, fieldMap map[string]string) int { 450 | var MaxLength int 451 | 452 | maxLengthVal, maxLengthOk := fieldMap["size"] 453 | if maxLengthOk { 454 | num, _ := strconv.Atoi(maxLengthVal) 455 | MaxLength = num 456 | } else { 457 | MaxLength = 0 458 | if "varchar" == DataType { 459 | MaxLength = 255 460 | } 461 | } 462 | 463 | return MaxLength 464 | } 465 | 466 | func getNullAble(fieldMap map[string]string) string { 467 | var IsNullable string 468 | 469 | _, primaryOk := fieldMap["primary"] 470 | if primaryOk { 471 | IsNullable = "NO" 472 | } else { 473 | _, ok := fieldMap["not null"] 474 | if ok { 475 | IsNullable = "NO" 476 | } else { 477 | IsNullable = "YES" 478 | } 479 | } 480 | 481 | return IsNullable 482 | } 483 | 484 | func getExtra(fieldMap map[string]string) string { 485 | _, commentIs := fieldMap["auto_increment"] 486 | if commentIs { 487 | return "auto_increment" 488 | } 489 | 490 | return "" 491 | } 492 | 493 | func getDefaultVal(fieldMap map[string]string) 
string { 494 | defaultVal, defaultIs := fieldMap["default"] 495 | if defaultIs { 496 | return defaultVal 497 | } 498 | 499 | return "" 500 | } 501 | -------------------------------------------------------------------------------- /migrator/migrator.go: -------------------------------------------------------------------------------- 1 | package migrator 2 | 3 | import ( 4 | "github.com/tangpanqing/aorm/base" 5 | "github.com/tangpanqing/aorm/builder" 6 | "github.com/tangpanqing/aorm/driver" 7 | "github.com/tangpanqing/aorm/migrate_mssql" 8 | "github.com/tangpanqing/aorm/migrate_mysql" 9 | "github.com/tangpanqing/aorm/migrate_postgres" 10 | "github.com/tangpanqing/aorm/migrate_sqlite3" 11 | "github.com/tangpanqing/aorm/utils" 12 | "reflect" 13 | "strings" 14 | ) 15 | 16 | type Migrator struct { 17 | //数据库操作连接 18 | Link base.Link 19 | } 20 | 21 | //ShowCreateTable 获取创建表的ddl 22 | func (mi *Migrator) ShowCreateTable(tableName string) string { 23 | if mi.Link.DriverName() == driver.Mysql { 24 | me := migrate_mysql.MigrateExecutor{ 25 | Builder: &builder.Builder{ 26 | Link: mi.Link, 27 | }, 28 | } 29 | return me.ShowCreateTable(tableName) 30 | } 31 | return "" 32 | } 33 | 34 | // AutoMigrate 迁移数据库结构,需要输入数据库名,表名自动获取 35 | func (mi *Migrator) AutoMigrate(destList ...interface{}) { 36 | for i := 0; i < len(destList); i++ { 37 | dest := destList[i] 38 | typeOf := reflect.TypeOf(dest) 39 | valueOf := reflect.ValueOf(dest) 40 | tableName := getTableNameByReflect(typeOf, valueOf) 41 | mi.migrateCommon(tableName, typeOf, valueOf) 42 | } 43 | } 44 | 45 | // Migrate 自动迁移数据库结构,需要输入数据库名,表名 46 | func (mi *Migrator) Migrate(tableName string, dest interface{}) { 47 | typeOf := reflect.TypeOf(dest) 48 | valueOf := reflect.ValueOf(dest) 49 | mi.migrateCommon(tableName, typeOf, valueOf) 50 | } 51 | 52 | func (mi *Migrator) migrateCommon(tableName string, typeOf reflect.Type, valueOf reflect.Value) { 53 | if mi.Link.DriverName() == driver.Mssql { 54 | me := migrate_mssql.MigrateExecutor{ 55 | 
Builder: &builder.Builder{ 56 | Link: mi.Link, 57 | }, 58 | } 59 | me.MigrateCommon(tableName, typeOf) 60 | } 61 | 62 | if mi.Link.DriverName() == driver.Mysql { 63 | me := migrate_mysql.MigrateExecutor{ 64 | Builder: &builder.Builder{ 65 | Link: mi.Link, 66 | }, 67 | } 68 | me.MigrateCommon(tableName, typeOf, valueOf) 69 | } 70 | 71 | if mi.Link.DriverName() == driver.Sqlite3 { 72 | me := migrate_sqlite3.MigrateExecutor{ 73 | Builder: &builder.Builder{ 74 | Link: mi.Link, 75 | }, 76 | } 77 | me.MigrateCommon(tableName, typeOf) 78 | } 79 | 80 | if mi.Link.DriverName() == driver.Postgres { 81 | me := migrate_postgres.MigrateExecutor{ 82 | Builder: &builder.Builder{ 83 | Link: mi.Link, 84 | }, 85 | } 86 | me.MigrateCommon(tableName, typeOf, valueOf) 87 | } 88 | } 89 | 90 | //反射表名,优先从方法获取,没有方法则从名字获取 91 | func getTableNameByReflect(typeOf reflect.Type, valueOf reflect.Value) string { 92 | method, isSet := typeOf.MethodByName("TableName") 93 | if isSet { 94 | var paramList []reflect.Value 95 | paramList = append(paramList, valueOf) 96 | res := method.Func.Call(paramList) 97 | return res[0].String() 98 | } else { 99 | arr := strings.Split(typeOf.String(), ".") 100 | return utils.UnderLine(arr[len(arr)-1]) 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /null/null.go: -------------------------------------------------------------------------------- 1 | package null 2 | 3 | import ( 4 | "bytes" 5 | "database/sql" 6 | "encoding/json" 7 | "errors" 8 | "fmt" 9 | "math" 10 | "reflect" 11 | "strconv" 12 | "time" 13 | ) 14 | 15 | // Int 整数 16 | type Int struct { 17 | sql.NullInt64 18 | } 19 | 20 | // IntFrom 创建整数 21 | func IntFrom(i int64) Int { 22 | return Int{ 23 | NullInt64: sql.NullInt64{ 24 | Int64: i, 25 | Valid: true, 26 | }, 27 | } 28 | } 29 | 30 | // String 字符串 31 | type String struct { 32 | sql.NullString 33 | } 34 | 35 | // StringFrom 创建字符串 36 | func StringFrom(s string) String { 37 | return String{ 38 | 
NullString: sql.NullString{ 39 | String: s, 40 | Valid: true, 41 | }, 42 | } 43 | } 44 | 45 | // Float 浮点数 46 | type Float struct { 47 | sql.NullFloat64 48 | } 49 | 50 | // FloatFrom 创建浮点数 51 | func FloatFrom(f float64) Float { 52 | return Float{ 53 | NullFloat64: sql.NullFloat64{ 54 | Float64: f, 55 | Valid: true, 56 | }, 57 | } 58 | } 59 | 60 | // Bool 布尔值 61 | type Bool struct { 62 | sql.NullBool 63 | } 64 | 65 | // BoolFrom 创建布尔值 66 | func BoolFrom(b bool) Bool { 67 | return Bool{ 68 | NullBool: sql.NullBool{ 69 | Bool: b, 70 | Valid: true, 71 | }, 72 | } 73 | } 74 | 75 | // Time 时间 76 | type Time struct { 77 | sql.NullTime 78 | } 79 | 80 | // TimeFrom 创建时间 81 | func TimeFrom(t time.Time) Time { 82 | return Time{ 83 | NullTime: sql.NullTime{ 84 | Time: t, 85 | Valid: true, 86 | }, 87 | } 88 | } 89 | 90 | var nullBytes = []byte("null") 91 | 92 | // UnmarshalJSON 反序列化浮点数 93 | func (f *Float) UnmarshalJSON(data []byte) error { 94 | if bytes.Equal(data, nullBytes) { 95 | f.Valid = false 96 | return nil 97 | } 98 | 99 | if err := json.Unmarshal(data, &f.Float64); err != nil { 100 | var typeError *json.UnmarshalTypeError 101 | if errors.As(err, &typeError) { 102 | // special case: accept string input 103 | if typeError.Value != "string" { 104 | return fmt.Errorf("null: JSON input is invalid driver (need float or string): %w", err) 105 | } 106 | var str string 107 | if err := json.Unmarshal(data, &str); err != nil { 108 | return fmt.Errorf("null: couldn't unmarshal number string: %w", err) 109 | } 110 | n, err := strconv.ParseFloat(str, 64) 111 | if err != nil { 112 | return fmt.Errorf("null: couldn't convert string to float: %w", err) 113 | } 114 | f.Float64 = n 115 | f.Valid = true 116 | return nil 117 | } 118 | return fmt.Errorf("null: couldn't unmarshal JSON: %w", err) 119 | } 120 | 121 | f.Valid = true 122 | return nil 123 | } 124 | 125 | // MarshalJSON 序列化浮点数 126 | func (f Float) MarshalJSON() ([]byte, error) { 127 | if !f.Valid { 128 | return []byte("null"), 
nil 129 | } 130 | if math.IsInf(f.Float64, 0) || math.IsNaN(f.Float64) { 131 | return nil, &json.UnsupportedValueError{ 132 | Value: reflect.ValueOf(f.Float64), 133 | Str: strconv.FormatFloat(f.Float64, 'g', -1, 64), 134 | } 135 | } 136 | return []byte(strconv.FormatFloat(f.Float64, 'f', -1, 64)), nil 137 | } 138 | 139 | // UnmarshalJSON 反序列化布尔值 140 | func (b *Bool) UnmarshalJSON(data []byte) error { 141 | if bytes.Equal(data, nullBytes) { 142 | b.Valid = false 143 | return nil 144 | } 145 | 146 | if err := json.Unmarshal(data, &b.Bool); err != nil { 147 | return fmt.Errorf("null: couldn't unmarshal JSON: %w", err) 148 | } 149 | 150 | b.Valid = true 151 | return nil 152 | } 153 | 154 | // MarshalJSON 序列化布尔值 155 | func (b Bool) MarshalJSON() ([]byte, error) { 156 | if !b.Valid { 157 | return []byte("null"), nil 158 | } 159 | if !b.Bool { 160 | return []byte("false"), nil 161 | } 162 | return []byte("true"), nil 163 | } 164 | 165 | // UnmarshalJSON 反序列化时间 166 | func (t *Time) UnmarshalJSON(data []byte) error { 167 | if bytes.Equal(data, nullBytes) { 168 | t.Valid = false 169 | return nil 170 | } 171 | 172 | if err := json.Unmarshal(data, &t.Time); err != nil { 173 | return fmt.Errorf("null: couldn't unmarshal JSON: %w", err) 174 | } 175 | 176 | t.Valid = true 177 | return nil 178 | } 179 | 180 | // MarshalJSON 序列化时间 181 | func (t Time) MarshalJSON() ([]byte, error) { 182 | if !t.Valid { 183 | return []byte("null"), nil 184 | } 185 | return t.Time.MarshalJSON() 186 | } 187 | 188 | // UnmarshalJSON 反序列化整数 189 | func (i *Int) UnmarshalJSON(data []byte) error { 190 | if bytes.Equal(data, nullBytes) { 191 | i.Valid = false 192 | return nil 193 | } 194 | 195 | if err := json.Unmarshal(data, &i.Int64); err != nil { 196 | var typeError *json.UnmarshalTypeError 197 | if errors.As(err, &typeError) { 198 | // special case: accept string input 199 | if typeError.Value != "string" { 200 | return fmt.Errorf("null: JSON input is invalid driver (need int or string): %w", err) 201 | 
} 202 | var str string 203 | if err := json.Unmarshal(data, &str); err != nil { 204 | return fmt.Errorf("null: couldn't unmarshal number string: %w", err) 205 | } 206 | n, err := strconv.ParseInt(str, 10, 64) 207 | if err != nil { 208 | return fmt.Errorf("null: couldn't convert string to int: %w", err) 209 | } 210 | i.Int64 = n 211 | i.Valid = true 212 | return nil 213 | } 214 | return fmt.Errorf("null: couldn't unmarshal JSON: %w", err) 215 | } 216 | 217 | i.Valid = true 218 | return nil 219 | } 220 | 221 | // MarshalJSON 序列化整数 222 | func (i Int) MarshalJSON() ([]byte, error) { 223 | if !i.Valid { 224 | return []byte("null"), nil 225 | } 226 | return []byte(strconv.FormatInt(i.Int64, 10)), nil 227 | } 228 | 229 | // UnmarshalJSON 反序列化字符串 230 | func (s *String) UnmarshalJSON(data []byte) error { 231 | if bytes.Equal(data, nullBytes) { 232 | s.Valid = false 233 | return nil 234 | } 235 | 236 | if err := json.Unmarshal(data, &s.String); err != nil { 237 | return fmt.Errorf("null: couldn't unmarshal JSON: %w", err) 238 | } 239 | 240 | s.Valid = true 241 | return nil 242 | } 243 | 244 | // MarshalJSON 序列化字符串 245 | func (s String) MarshalJSON() ([]byte, error) { 246 | if !s.Valid { 247 | return []byte("null"), nil 248 | } 249 | return json.Marshal(s.String) 250 | } 251 | -------------------------------------------------------------------------------- /test/aorm_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "fmt" 5 | _ "github.com/denisenkom/go-mssqldb" 6 | _ "github.com/go-sql-driver/mysql" 7 | _ "github.com/lib/pq" 8 | _ "github.com/mattn/go-sqlite3" 9 | "github.com/tangpanqing/aorm" 10 | "github.com/tangpanqing/aorm/base" 11 | "github.com/tangpanqing/aorm/builder" 12 | "github.com/tangpanqing/aorm/driver" 13 | "github.com/tangpanqing/aorm/null" 14 | "testing" 15 | "time" 16 | ) 17 | 18 | type Student struct { 19 | StudentId null.Int `aorm:"primary;auto_increment" json:"studentId"` 20 | 
Name null.String `aorm:"column:student_name;size:100;not null;comment:名字" json:"name"` 21 | } 22 | 23 | func (s *Student) TableOpinion() map[string]string { 24 | return map[string]string{ 25 | "ENGINE": "InnoDB", 26 | "COMMENT": "学生表", 27 | } 28 | } 29 | 30 | type Article struct { 31 | Id null.Int `aorm:"primary;auto_increment" json:"id"` 32 | Type null.Int `aorm:"index;comment:类型" json:"driver"` 33 | PersonId null.Int `aorm:"comment:人员Id" json:"personId"` 34 | ArticleBody null.String `aorm:"driver:text;comment:文章内容" json:"articleBody"` 35 | } 36 | 37 | func (a *Article) TableOpinion() map[string]string { 38 | return map[string]string{ 39 | "ENGINE": "InnoDB", 40 | "COMMENT": "文章表", 41 | } 42 | } 43 | 44 | type ArticleVO struct { 45 | Id null.Int `aorm:"primary;auto_increment" json:"id"` 46 | Type null.Int `aorm:"index;comment:类型" json:"driver"` 47 | PersonId null.Int `aorm:"comment:人员Id" json:"personId"` 48 | PersonName null.String `aorm:"comment:人员名称" json:"personName"` 49 | ArticleBody null.String `aorm:"driver:text;comment:文章内容" json:"articleBody"` 50 | } 51 | 52 | type Person struct { 53 | Id null.Int `aorm:"primary;auto_increment" json:"id"` 54 | Name null.String `aorm:"size:100;not null;comment:名字" json:"name"` 55 | Sex null.Bool `aorm:"index;comment:性别" json:"sex"` 56 | Age null.Int `aorm:"index;comment:年龄" json:"age"` 57 | Type null.Int `aorm:"index;comment:类型" json:"driver"` 58 | CreateTime null.Time `aorm:"comment:创建时间" json:"createTime"` 59 | Money null.Float `aorm:"comment:金额" json:"money"` 60 | Test null.Float `aorm:"driver:double;comment:测试" json:"test"` 61 | } 62 | 63 | func (p *Person) TableOpinion() map[string]string { 64 | return map[string]string{ 65 | "ENGINE": "InnoDB", 66 | "COMMENT": "人员表", 67 | } 68 | } 69 | 70 | type PersonAge struct { 71 | Age null.Int 72 | AgeCount null.Int 73 | } 74 | 75 | type PersonWithArticleCount struct { 76 | Id null.Int `aorm:"primary;auto_increment" json:"id"` 77 | Name null.String `aorm:"size:100;not 
null;comment:名字" json:"name"` 78 | Sex null.Bool `aorm:"index;comment:性别" json:"sex"` 79 | Age null.Int `aorm:"index;comment:年龄" json:"age"` 80 | Type null.Int `aorm:"index;comment:类型" json:"driver"` 81 | CreateTime null.Time `aorm:"comment:创建时间" json:"createTime"` 82 | Money null.Float `aorm:"comment:金额" json:"money"` 83 | Test null.Float `aorm:"driver:double;comment:测试" json:"test"` 84 | ArticleCount null.Int `aorm:"comment:文章数量" json:"articleCount"` 85 | } 86 | 87 | var student = Student{} 88 | var person = Person{} 89 | var article = Article{} 90 | var articleVO = ArticleVO{} 91 | var personAge = PersonAge{} 92 | var personWithArticleCount = PersonWithArticleCount{} 93 | 94 | func TestAll(t *testing.T) { 95 | aorm.Store(&person, &article, &student) 96 | aorm.Store(&articleVO) 97 | aorm.Store(&personAge, &personWithArticleCount) 98 | 99 | var dbList = []*base.Db{ 100 | testMysqlConnect(), 101 | testSqlite3Connect(), 102 | testPostgresConnect(), 103 | testMssqlConnect(), 104 | } 105 | defer closeAll(dbList) 106 | 107 | for i := 0; i < len(dbList); i++ { 108 | dbItem := dbList[i] 109 | 110 | testMigrate(dbItem) 111 | testShowCreateTable(dbItem) 112 | 113 | id := testInsert(dbItem) 114 | testInsertBatch(dbItem) 115 | testGetOne(dbItem, id) 116 | testGetMany(dbItem) 117 | testUpdate(dbItem, id) 118 | 119 | testNull(dbItem, id) 120 | 121 | isExists := testExists(dbItem, id) 122 | if isExists != true { 123 | panic("应该存在,但是数据库不存在") 124 | } 125 | 126 | testDelete(dbItem, id) 127 | isExists2 := testExists(dbItem, id) 128 | if isExists2 == true { 129 | panic("应该不存在,但是数据库存在") 130 | } 131 | 132 | id2 := testInsert(dbItem) 133 | testTable(dbItem) 134 | testSelect(dbItem) 135 | testSelectWithSub(dbItem) 136 | testWhereWithSub(dbItem) 137 | testWhere(dbItem) 138 | testJoin(dbItem) 139 | testJoinWithAlias(dbItem) 140 | 141 | testGroupBy(dbItem) 142 | testHaving(dbItem) 143 | testOrderBy(dbItem) 144 | testLimit(dbItem) 145 | testLock(dbItem, id2) 146 | 147 | 
testIncrement(dbItem, id2) 148 | testDecrement(dbItem, id2) 149 | 150 | testValue(dbItem, id2) 151 | testPluck(dbItem) 152 | 153 | testCount(dbItem) 154 | testSum(dbItem) 155 | testAvg(dbItem) 156 | testMin(dbItem) 157 | testMax(dbItem) 158 | 159 | testDistinct(dbItem) 160 | testRawSql(dbItem, id2) 161 | 162 | testTransaction(dbItem) 163 | testTruncate(dbItem) 164 | 165 | } 166 | 167 | testPreview() 168 | testDbContent() 169 | } 170 | 171 | func testSqlite3Connect() *base.Db { 172 | sqlite3Content, sqlite3Err := aorm.Open(driver.Sqlite3, "test.db") 173 | if sqlite3Err != nil { 174 | panic(sqlite3Err) 175 | } 176 | 177 | sqlite3Content.SetDebugMode(false) 178 | return sqlite3Content 179 | } 180 | 181 | func testMysqlConnect() *base.Db { 182 | username := "root" 183 | password := "root" 184 | hostname := "localhost" 185 | port := "3306" 186 | dbname := "database_name" 187 | 188 | mysqlContent, mysqlErr := aorm.Open(driver.Mysql, username+":"+password+"@tcp("+hostname+":"+port+")/"+dbname+"?charset=utf8mb4&parseTime=True&loc=Local") 189 | if mysqlErr != nil { 190 | panic(mysqlErr) 191 | } 192 | 193 | mysqlContent.SetDebugMode(false) 194 | return mysqlContent 195 | } 196 | 197 | func testPostgresConnect() *base.Db { 198 | psqlInfo := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", "localhost", 5432, "postgres", "root", "postgres") 199 | 200 | postgresContent, postgresErr := aorm.Open(driver.Postgres, psqlInfo) 201 | if postgresErr != nil { 202 | panic(postgresErr) 203 | } 204 | 205 | postgresContent.SetDebugMode(false) 206 | 207 | return postgresContent 208 | } 209 | 210 | func testMssqlConnect() *base.Db { 211 | mssqlInfo := fmt.Sprintf("server=%s;database=%s;user id=%s;password=%s;port=%d;encrypt=disable", "localhost", "database_name", "sa", "root", 1433) 212 | 213 | mssqlContent, mssqlErr := aorm.Open(driver.Mssql, mssqlInfo) 214 | if mssqlErr != nil { 215 | panic(mssqlErr) 216 | } 217 | 218 | mssqlContent.SetDebugMode(false) 219 | 
return mssqlContent 220 | } 221 | 222 | func testMigrate(db *base.Db) { 223 | aorm.Migrator(db).AutoMigrate(&person, &article, &student) 224 | 225 | aorm.Migrator(db).Migrate("person_1", &person) 226 | } 227 | 228 | func testShowCreateTable(db *base.Db) { 229 | aorm.Migrator(db).ShowCreateTable("person") 230 | } 231 | 232 | func testInsert(db *base.Db) int64 { 233 | obj := Person{ 234 | Name: null.StringFrom("Alice"), 235 | Sex: null.BoolFrom(true), 236 | Age: null.IntFrom(18), 237 | Type: null.IntFrom(0), 238 | CreateTime: null.TimeFrom(time.Now()), 239 | Money: null.FloatFrom(1), 240 | Test: null.FloatFrom(2), 241 | } 242 | 243 | id, errInsert := aorm.Db(db).Insert(&obj) 244 | if errInsert != nil { 245 | panic(db.DriverName() + " testInsert " + "found err: " + errInsert.Error()) 246 | } 247 | aorm.Db(db).Insert(&Article{ 248 | Type: null.IntFrom(0), 249 | PersonId: null.IntFrom(id), 250 | ArticleBody: null.StringFrom("文章内容"), 251 | }) 252 | 253 | var personItem Person 254 | err := aorm.Db(db).Table(&person).WhereEq(&person.Id, id).OrderBy(&person.Id, builder.Desc).GetOne(&personItem) 255 | if err != nil { 256 | fmt.Println(err.Error()) 257 | } 258 | 259 | if obj.Name.String != personItem.Name.String { 260 | fmt.Println(db.DriverName() + ",Name not match, expected: " + obj.Name.String + " ,but real is : " + personItem.Name.String) 261 | } 262 | 263 | if obj.Sex.Bool != personItem.Sex.Bool { 264 | fmt.Println(db.DriverName() + ",Sex not match, expected: " + fmt.Sprintf("%v", obj.Sex.Bool) + " ,but real is : " + fmt.Sprintf("%v", personItem.Sex.Bool)) 265 | } 266 | 267 | if obj.Age.Int64 != personItem.Age.Int64 { 268 | fmt.Println(db.DriverName() + ",Age not match, expected: " + fmt.Sprintf("%v", obj.Age.Int64) + " ,but real is : " + fmt.Sprintf("%v", personItem.Age.Int64)) 269 | } 270 | 271 | if obj.Type.Int64 != personItem.Type.Int64 { 272 | fmt.Println(db.DriverName() + ",Type not match, expected: " + fmt.Sprintf("%v", obj.Type.Int64) + " ,but real is : " + 
fmt.Sprintf("%v", personItem.Type.Int64)) 273 | } 274 | 275 | if obj.Money.Float64 != personItem.Money.Float64 { 276 | fmt.Println(db.DriverName() + ",Money not match, expected: " + fmt.Sprintf("%v", obj.Money.Float64) + " ,but real is : " + fmt.Sprintf("%v", personItem.Money.Float64)) 277 | } 278 | 279 | if obj.Test.Float64 != personItem.Test.Float64 { 280 | fmt.Println(db.DriverName() + ",Test not match, expected: " + fmt.Sprintf("%v", obj.Test.Float64) + " ,but real is : " + fmt.Sprintf("%v", personItem.Test.Float64)) 281 | } 282 | 283 | //测试非id主键 284 | aorm.Db(db).Insert(&Student{ 285 | Name: null.StringFrom("new student"), 286 | }) 287 | 288 | return id 289 | } 290 | 291 | func testInsertBatch(db *base.Db) int64 { 292 | var batch []*Person 293 | batch = append(batch, &Person{ 294 | Name: null.StringFrom("Alice"), 295 | Sex: null.BoolFrom(false), 296 | Age: null.IntFrom(18), 297 | Type: null.IntFrom(0), 298 | CreateTime: null.TimeFrom(time.Now()), 299 | Money: null.FloatFrom(100.15), 300 | Test: null.FloatFrom(200.15987654321987654321), 301 | }) 302 | 303 | batch = append(batch, &Person{ 304 | Name: null.StringFrom("Bob"), 305 | Sex: null.BoolFrom(true), 306 | Age: null.IntFrom(18), 307 | Type: null.IntFrom(0), 308 | CreateTime: null.TimeFrom(time.Now()), 309 | Money: null.FloatFrom(100.15), 310 | Test: null.FloatFrom(200.15987654321987654321), 311 | }) 312 | 313 | count, err := aorm.Db(db).InsertBatch(&batch) 314 | if err != nil { 315 | panic(db.DriverName() + " testInsertBatch " + "found err:" + err.Error()) 316 | } 317 | 318 | return count 319 | } 320 | 321 | func testGetOne(db *base.Db, id int64) { 322 | var personItem Person 323 | errFind := aorm.Db(db).Table(&person).OrderBy(&person.Id, builder.Desc).WhereEq(&person.Id, id).GetOne(&personItem) 324 | if errFind != nil { 325 | panic(db.DriverName() + "testGetOne" + "found err") 326 | } 327 | } 328 | 329 | func testGetMany(db *base.Db) { 330 | var list []Person 331 | errSelect := 
aorm.Db(db).Table(&person).WhereEq(&person.Type, 0).GetMany(&list) 332 | if errSelect != nil { 333 | panic(db.DriverName() + " testGetMany " + "found err:" + errSelect.Error()) 334 | } 335 | } 336 | 337 | func testUpdate(db *base.Db, id int64) { 338 | _, errUpdate := aorm.Db(db).WhereEq(&person.Id, id).Update(&Person{Name: null.StringFrom("Bob")}) 339 | if errUpdate != nil { 340 | panic(db.DriverName() + "testUpdate" + "found err") 341 | } 342 | } 343 | 344 | func testDelete(db *base.Db, id int64) { 345 | _, errDelete := aorm.Db(db).Table(&person).WhereEq(&person.Id, id).Delete() 346 | if errDelete != nil { 347 | panic(db.DriverName() + "testDelete" + "found err") 348 | } 349 | 350 | _, errDelete2 := aorm.Db(db).Delete(&Person{ 351 | Id: null.IntFrom(id), 352 | }) 353 | if errDelete2 != nil { 354 | panic(db.DriverName() + "testDelete" + "found err") 355 | } 356 | } 357 | 358 | func testExists(db *base.Db, id int64) bool { 359 | exists, err := aorm.Db(db).Table(&person).WhereEq(&person.Id, id).OrderBy(&person.Id, builder.Desc).Exists() 360 | if err != nil { 361 | panic(db.DriverName() + " testExists " + "found err:" + err.Error()) 362 | } 363 | return exists 364 | } 365 | 366 | func testNull(db *base.Db, id int64) { 367 | var p Person 368 | err := aorm.Db(db).Table(&person).WhereIsNOTNull(&person.Id).WhereEq(&person.Id, id).OrderBy(&person.Id, builder.Desc).Debug(false).GetOne(&p) 369 | if err != nil { 370 | panic(db.DriverName() + " test WhereIsNOTNull " + "found err:" + err.Error()) 371 | } 372 | 373 | _, err = aorm.Db(db).Table(&person).WhereIsNull(&person.Id).Debug(false).Count("*") 374 | if err != nil { 375 | panic(db.DriverName() + " test WhereIsNull " + "found err:" + err.Error()) 376 | } 377 | } 378 | 379 | func testTable(db *base.Db) { 380 | _, err := aorm.Db(db).Table("person_1").Insert(&Person{Name: null.StringFrom("Cherry")}) 381 | if err != nil { 382 | panic(db.DriverName() + " testTable " + "found err:" + err.Error()) 383 | } 384 | 385 | _, err2 := 
aorm.Db(db).Table(&person).Insert(&Person{Name: null.StringFrom("Cherry")}) 386 | if err2 != nil { 387 | panic(db.DriverName() + " testTable " + "found err:" + err2.Error()) 388 | } 389 | 390 | var personList2 []Person 391 | subTable := aorm.Db(db).Table(&person) 392 | err3 := aorm.Db(db).Table(&subTable, "o").Debug(false).GetMany(&personList2) 393 | if err3 != nil { 394 | panic(db.DriverName() + " testTable " + "found err:" + err3.Error()) 395 | } 396 | } 397 | 398 | func testSelect(db *base.Db) { 399 | var listByFiled []Person 400 | err := aorm.Db(db).Table(&person).Select(&person.Name).Select(&person.Age).WhereEq(&person.Age, 18).GetMany(&listByFiled) 401 | if err != nil { 402 | panic(db.DriverName() + " testSelect " + "found err:" + err.Error()) 403 | } 404 | } 405 | 406 | func testSelectWithSub(db *base.Db) { 407 | var listByFiled []PersonWithArticleCount 408 | 409 | sub := aorm.Db(db).Table(&article).SelectCount(&article.Id, "article_count_tem").WhereRawEq(&article.PersonId, &person.Id) 410 | err := aorm.Db(db). 411 | SelectExp(&sub, &personWithArticleCount.ArticleCount). 412 | SelectAll(&person). 413 | Table(&person). 414 | WhereEq(&person.Age, 18). 415 | GetMany(&listByFiled) 416 | 417 | if err != nil { 418 | panic(db.DriverName() + " testSelectWithSub " + "found err:" + err.Error()) 419 | } 420 | } 421 | 422 | func testWhereWithSub(db *base.Db) { 423 | var listByFiled []Person 424 | sub := aorm.Db(db).Table(&article).SelectCount(&article.PersonId, "count_person_id").GroupBy(&article.PersonId).HavingGt("count_person_id", 0) 425 | err := aorm.Db(db). 426 | Table(&person). 427 | WhereIn(&person.Id, &sub). 
428 | GetMany(&listByFiled) 429 | if err != nil { 430 | panic(db.DriverName() + " testWhereWithSub " + "found err:" + err.Error()) 431 | } 432 | } 433 | 434 | func testWhere(db *base.Db) { 435 | var listByWhere []Person 436 | err := aorm.Db(db).Table(&person).WhereArr([]builder.WhereItem{ 437 | builder.GenWhereItem(&person.Type, builder.Eq, 0), 438 | builder.GenWhereItem(&person.Age, builder.In, []int{18, 20}), 439 | builder.GenWhereItem(&person.Money, builder.Between, []float64{100.1, 200.9}), 440 | builder.GenWhereItem(&person.Money, builder.Eq, 100.15), 441 | builder.GenWhereItem(&person.Name, builder.Like, []string{"%", "li", "%"}), 442 | }).GetMany(&listByWhere) 443 | if err != nil { 444 | panic(db.DriverName() + "testWhere" + "found err") 445 | } 446 | } 447 | 448 | func testJoin(db *base.Db) { 449 | var list2 []ArticleVO 450 | err := aorm.Db(db). 451 | Table(&article). 452 | LeftJoin( 453 | &person, 454 | []builder.JoinCondition{ 455 | builder.GenJoinCondition(&person.Id, builder.RawEq, &article.PersonId), 456 | }, 457 | ). 458 | SelectAll(&article). 459 | SelectAs(&person.Name, &articleVO.PersonName). 460 | WhereEq(&article.Type, 0). 461 | WhereIn(&person.Age, []int{18, 20}). 462 | GetMany(&list2) 463 | if err != nil { 464 | panic(db.DriverName() + " testWhere " + "found err " + err.Error()) 465 | } 466 | } 467 | 468 | func testJoinWithAlias(db *base.Db) { 469 | var list2 []ArticleVO 470 | err := aorm.Db(db). 471 | Table(&article, "o"). 472 | LeftJoin( 473 | &person, 474 | []builder.JoinCondition{ 475 | builder.GenJoinCondition(&person.Id, builder.RawEq, &article.PersonId, "o"), 476 | }, 477 | "p", 478 | ). 479 | Select("*", "o"). 480 | SelectAs(&person.Name, &articleVO.PersonName, "p"). 481 | WhereEq(&article.Type, 0, "o"). 482 | WhereIn(&person.Age, []int{18, 20}, "p"). 
483 | GetMany(&list2) 484 | if err != nil { 485 | panic(db.DriverName() + " testWhere " + "found err " + err.Error()) 486 | } 487 | } 488 | 489 | func testGroupBy(db *base.Db) { 490 | var personAgeItem PersonAge 491 | err := aorm.Db(db). 492 | Table(&person). 493 | Select(&person.Age). 494 | SelectCount(&person.Age, &personAge.AgeCount). 495 | GroupBy(&person.Age). 496 | WhereEq(&person.Type, 0). 497 | OrderBy(&person.Age, builder.Desc). 498 | GetOne(&personAgeItem) 499 | if err != nil { 500 | panic(db.DriverName() + "testGroupBy" + "found err") 501 | } 502 | } 503 | 504 | func testHaving(db *base.Db) { 505 | var listByHaving []PersonAge 506 | 507 | err := aorm.Db(db). 508 | Table(&person). 509 | Select(&person.Age). 510 | SelectCount(&person.Age, &personAge.AgeCount). 511 | GroupBy(&person.Age). 512 | WhereEq(&person.Type, 0). 513 | OrderBy(&person.Age, builder.Desc). 514 | HavingGt(&personAge.AgeCount, 4). 515 | GetMany(&listByHaving) 516 | if err != nil { 517 | panic(db.DriverName() + " testHaving " + "found err") 518 | } 519 | } 520 | 521 | func testOrderBy(db *base.Db) { 522 | var listByOrder []Person 523 | err := aorm.Db(db). 524 | Table(&person). 525 | WhereEq(&person.Type, 0). 526 | OrderBy(&person.Age, builder.Desc). 527 | GetMany(&listByOrder) 528 | if err != nil { 529 | panic(db.DriverName() + "testOrderBy" + "found err") 530 | } 531 | 532 | var listByOrder2 []Person 533 | err2 := aorm.Db(db). 534 | Table(&person, "o"). 535 | WhereEq(&person.Type, 0, "o"). 536 | OrderBy(&person.Age, builder.Desc, "o"). 537 | GetMany(&listByOrder2) 538 | if err2 != nil { 539 | panic(db.DriverName() + "testOrderBy" + "found err") 540 | } 541 | } 542 | 543 | func testLimit(db *base.Db) { 544 | var list3 []Person 545 | err1 := aorm.Db(db). 546 | Table(&person). 547 | WhereEq(&person.Type, 0). 548 | Limit(50, 10). 549 | OrderBy(&person.Id, builder.Desc). 
550 | GetMany(&list3) 551 | if err1 != nil { 552 | panic(db.DriverName() + "testLimit" + "found err") 553 | } 554 | 555 | var list4 []Person 556 | err := aorm.Db(db). 557 | Table(&person). 558 | WhereEq(&person.Type, 0). 559 | Page(3, 10). 560 | OrderBy(&person.Id, builder.Desc). 561 | GetMany(&list4) 562 | if err != nil { 563 | panic(db.DriverName() + "testPage" + "found err") 564 | } 565 | } 566 | 567 | func testLock(db *base.Db, id int64) { 568 | if db.DriverName() == driver.Sqlite3 || db.DriverName() == driver.Mssql { 569 | return 570 | } 571 | 572 | var itemByLock Person 573 | err := aorm.Db(db). 574 | LockForUpdate(true). 575 | Table(&person). 576 | WhereEq(&person.Id, id). 577 | OrderBy(&person.Id, builder.Desc). 578 | GetOne(&itemByLock) 579 | if err != nil { 580 | panic(db.DriverName() + "testLock" + "found err") 581 | } 582 | } 583 | 584 | func testIncrement(db *base.Db, id int64) { 585 | _, err := aorm.Db(db).Table(&person).WhereEq(&person.Id, id).Increment(&person.Age, 1) 586 | if err != nil { 587 | panic(db.DriverName() + " testIncrement " + "found err:" + err.Error()) 588 | } 589 | } 590 | 591 | func testDecrement(db *base.Db, id int64) { 592 | _, err := aorm.Db(db).Table(&person).WhereEq(&person.Id, id).Decrement(&person.Age, 2) 593 | if err != nil { 594 | panic(db.DriverName() + "testDecrement" + "found err") 595 | } 596 | } 597 | 598 | func testValue(db *base.Db, id int64) { 599 | 600 | var name string 601 | errName := aorm.Db(db).Table(&person).OrderBy(&person.Id, builder.Desc).WhereEq(&person.Id, id).Value(&person.Name, &name) 602 | if errName != nil { 603 | panic(db.DriverName() + "testValue" + "found err") 604 | } 605 | 606 | var age int64 607 | errAge := aorm.Db(db).Table(&person).OrderBy(&person.Id, builder.Desc).WhereEq(&person.Id, id).Value(&person.Age, &age) 608 | if errAge != nil { 609 | panic(db.DriverName() + "testValue" + "found err") 610 | } 611 | 612 | var money float32 613 | errMoney := aorm.Db(db).Table(&person).OrderBy(&person.Id, 
builder.Desc).WhereEq(&person.Id, id).Value(&person.Money, &money) 614 | if errMoney != nil { 615 | panic(db.DriverName() + "testValue" + "found err") 616 | } 617 | 618 | var test float64 619 | errTest := aorm.Db(db).Table(&person).OrderBy(&person.Id, builder.Desc).WhereEq(&person.Id, id).Value(&person.Test, &test) 620 | if errTest != nil { 621 | panic(db.DriverName() + "testValue" + "found err") 622 | } 623 | } 624 | 625 | func testPluck(db *base.Db) { 626 | var nameList []string 627 | errNameList := aorm.Db(db).Table(&person).OrderBy(&person.Id, builder.Desc).WhereEq(&person.Type, 0).Limit(0, 3).Pluck(&person.Name, &nameList) 628 | if errNameList != nil { 629 | panic(db.DriverName() + "testPluck" + "found err") 630 | } 631 | 632 | var ageList []int64 633 | errAgeList := aorm.Db(db).Table(&person).OrderBy(&person.Id, builder.Desc).WhereEq(&person.Type, 0).Limit(0, 3).Pluck(&person.Age, &ageList) 634 | if errAgeList != nil { 635 | panic(db.DriverName() + "testPluck" + "found err:" + errAgeList.Error()) 636 | } 637 | 638 | var moneyList []float32 639 | errMoneyList := aorm.Db(db).Table(&person).OrderBy(&person.Id, builder.Desc).WhereEq(&person.Type, 0).Limit(0, 3).Pluck(&person.Money, &moneyList) 640 | if errMoneyList != nil { 641 | panic(db.DriverName() + "testPluck" + "found err") 642 | } 643 | 644 | var testList []float64 645 | errTestList := aorm.Db(db).Table(&person).OrderBy(&person.Id, builder.Desc).WhereEq(&person.Type, 0).Limit(0, 3).Pluck(&person.Test, &testList) 646 | if errTestList != nil { 647 | panic(db.DriverName() + "testPluck" + "found err") 648 | } 649 | } 650 | 651 | func testCount(db *base.Db) { 652 | _, err := aorm.Db(db).Table(&person).WhereEq(&person.Age, 18).Count("*") 653 | if err != nil { 654 | panic(db.DriverName() + "testCount" + "found err") 655 | } 656 | } 657 | 658 | func testSum(db *base.Db) { 659 | _, err := aorm.Db(db).Table(&person).WhereEq(&person.Age, 18).Sum(&person.Age) 660 | if err != nil { 661 | panic(db.DriverName() + 
"testSum" + "found err") 662 | } 663 | } 664 | 665 | func testAvg(db *base.Db) { 666 | _, err := aorm.Db(db).Table(&person).WhereEq(&person.Age, 18).Avg(&person.Age) 667 | if err != nil { 668 | panic(db.DriverName() + "testAvg" + "found err") 669 | } 670 | } 671 | 672 | func testMin(db *base.Db) { 673 | _, err := aorm.Db(db).Table(&person).WhereEq(&person.Age, 18).Min(&person.Age) 674 | if err != nil { 675 | panic(db.DriverName() + "testMin" + "found err") 676 | } 677 | } 678 | 679 | func testMax(db *base.Db) { 680 | _, err := aorm.Db(db).Table(&person).WhereEq(&person.Age, 18).Max(&person.Age) 681 | if err != nil { 682 | panic(db.DriverName() + "testMax" + "found err") 683 | } 684 | } 685 | 686 | func testDistinct(db *base.Db) { 687 | var listByFiled []Person 688 | err := aorm.Db(db).Distinct(true).Table(&person).Select(&person.Name).Select(&person.Age).WhereEq(&person.Age, 18).GetMany(&listByFiled) 689 | if err != nil { 690 | panic(db.DriverName() + " testSelect " + "found err:" + err.Error()) 691 | } 692 | } 693 | 694 | func testRawSql(db *base.Db, id2 int64) { 695 | var list []Person 696 | err1 := aorm.Db(db).RawSql("SELECT * FROM person WHERE id=?", id2).GetMany(&list) 697 | if err1 != nil { 698 | panic(err1) 699 | } 700 | 701 | _, err := aorm.Db(db).RawSql("UPDATE person SET name = ? 
WHERE id=?", "Bob2", id2).Exec() 702 | if err != nil { 703 | panic(db.DriverName() + "testRawSql" + "found err") 704 | } 705 | } 706 | 707 | func testTransaction(db *base.Db) { 708 | tx := db.Begin() 709 | 710 | id, errInsert := aorm.Db(tx).Insert(&Person{ 711 | Name: null.StringFrom("Alice"), 712 | }) 713 | 714 | if errInsert != nil { 715 | tx.Rollback() 716 | panic(db.DriverName() + " testTransaction " + "found err:" + errInsert.Error()) 717 | return 718 | } 719 | 720 | _, errCount := aorm.Db(tx).Table(&person).WhereEq(&person.Id, id).Count("*") 721 | if errCount != nil { 722 | tx.Rollback() 723 | panic(db.DriverName() + "testTransaction" + "found err") 724 | return 725 | } 726 | 727 | var personItem Person 728 | errPerson := aorm.Db(tx).Table(&person).WhereEq(&person.Id, id).OrderBy(&person.Id, builder.Desc).GetOne(&personItem) 729 | if errPerson != nil { 730 | tx.Rollback() 731 | panic(db.DriverName() + "testTransaction" + "found err") 732 | return 733 | } 734 | 735 | _, errUpdate := aorm.Db(tx).Where(&Person{ 736 | Id: null.IntFrom(id), 737 | }).Update(&Person{ 738 | Name: null.StringFrom("Bob"), 739 | }) 740 | 741 | if errUpdate != nil { 742 | tx.Rollback() 743 | panic(db.DriverName() + "testTransaction" + "found err") 744 | return 745 | } 746 | 747 | tx.Commit() 748 | } 749 | 750 | func testTruncate(db *base.Db) { 751 | _, err := aorm.Db(db).Table(&person).Truncate() 752 | if err != nil { 753 | panic(db.DriverName() + " testTruncate " + "found err") 754 | } 755 | } 756 | 757 | func testPreview() { 758 | 759 | //Content Mysql 760 | db, _ := aorm.Open(driver.Mysql, "root:root@tcp(localhost:3306)/database_name?charset=utf8mb4&parseTime=True&loc=Local") 761 | defer db.Close() 762 | 763 | //Insert a Person 764 | personId, _ := aorm.Db(db).Insert(&Person{ 765 | Name: null.StringFrom("Alice"), 766 | Sex: null.BoolFrom(true), 767 | Age: null.IntFrom(18), 768 | Type: null.IntFrom(0), 769 | CreateTime: null.TimeFrom(time.Now()), 770 | Money: null.FloatFrom(1), 771 | 
Test: null.FloatFrom(2), 772 | }) 773 | 774 | //Insert a Article 775 | articleId, _ := aorm.Db(db).Insert(&Article{ 776 | Type: null.IntFrom(0), 777 | PersonId: null.IntFrom(personId), 778 | ArticleBody: null.StringFrom("文章内容"), 779 | }) 780 | 781 | //GetOne 782 | var personItem Person 783 | err := aorm.Db(db).Table(&person).WhereEq(&person.Id, personId).OrderBy(&person.Id, builder.Desc).GetOne(&personItem) 784 | if err != nil { 785 | fmt.Println(err.Error()) 786 | } 787 | 788 | //Join 789 | var list2 []ArticleVO 790 | aorm. 791 | Db(db). 792 | Table(&article). 793 | LeftJoin(&person, []builder.JoinCondition{ 794 | builder.GenJoinCondition(&person.Id, builder.RawEq, &article.PersonId), 795 | }). 796 | SelectAll(&article).SelectAs(&person.Name, &articleVO.PersonName). 797 | WhereEq(&article.Id, articleId). 798 | GetMany(&list2) 799 | 800 | //Join With Alias 801 | var list3 []ArticleVO 802 | aorm. 803 | Db(db). 804 | Table(&article, "o"). 805 | LeftJoin(&person, []builder.JoinCondition{ 806 | builder.GenJoinCondition(&person.Id, builder.RawEq, &article.PersonId, "o"), 807 | }, "p"). 808 | Select("*", "o").SelectAs(&person.Name, &articleVO.PersonName, "p"). 809 | WhereEq(&article.Id, articleId, "o"). 
810 | GetMany(&list3) 811 | } 812 | 813 | func testDbContent() { 814 | db, err := aorm.Open(driver.Mysql, "root:root@tcp(localhost:3306)/database_name?charset=utf8mb4&parseTime=True&loc=Local") 815 | if err != nil { 816 | panic(err) 817 | } 818 | db.SetMaxOpenConns(5) 819 | db.SetDebugMode(false) 820 | defer db.Close() 821 | 822 | aorm.Db(db).Insert(&Person{ 823 | Name: null.StringFrom("test name"), 824 | }) 825 | 826 | tx := db.Begin() 827 | aorm.Db(tx).Insert(&Person{ 828 | Name: null.StringFrom("test name"), 829 | }) 830 | 831 | tx.Commit() 832 | } 833 | 834 | func closeAll(dbList []*base.Db) { 835 | for i := 0; i < len(dbList); i++ { 836 | dbList[i].Close() 837 | } 838 | } 839 | -------------------------------------------------------------------------------- /utils/str.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "unicode" 5 | ) 6 | 7 | // CamelString 将某字符串转成驼峰写法 8 | func CamelString(s string) string { 9 | data := make([]byte, 0, len(s)) 10 | j := false 11 | k := false 12 | num := len(s) - 1 13 | for i := 0; i <= num; i++ { 14 | d := s[i] 15 | if k == false && d >= 'A' && d <= 'Z' { 16 | k = true 17 | } 18 | if d >= 'a' && d <= 'z' && (j || k == false) { 19 | d = d - 32 20 | j = false 21 | k = true 22 | } 23 | if k && d == '_' && num > i && s[i+1] >= 'a' && s[i+1] <= 'z' { 24 | j = true 25 | continue 26 | } 27 | data = append(data, d) 28 | } 29 | return string(data[:]) 30 | } 31 | 32 | // Cs 将某字符串转成驼峰写法 33 | func Cs(s string) string { 34 | return CamelString(s) 35 | } 36 | 37 | // UnderLine 将某字符串转成下划线写法 38 | func UnderLine(s string) string { 39 | var output []rune 40 | for i, r := range s { 41 | if i == 0 { 42 | output = append(output, unicode.ToLower(r)) 43 | continue 44 | } 45 | if unicode.IsUpper(r) { 46 | output = append(output, '_') 47 | } 48 | output = append(output, unicode.ToLower(r)) 49 | } 50 | return string(output) 51 | } 52 | 53 | // Ul 将某字符串转成下划线写法 54 | func Ul(s string) 
string { 55 | return UnderLine(s) 56 | } 57 | -------------------------------------------------------------------------------- /wechat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tangpanqing/aorm/527e68060c454f369c63483683679dd14328636d/wechat.jpg --------------------------------------------------------------------------------