├── types.go ├── utils.go ├── orm.go ├── README.md └── repo.go /types.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | // G is shortcut 4 | type G map[string]interface{} 5 | -------------------------------------------------------------------------------- /utils.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import "bytes" 4 | 5 | func snakeToUpperCamel(s string) string { 6 | buf := new(bytes.Buffer) 7 | first := true 8 | for i := 0; i < len(s); i++ { 9 | c := s[i] 10 | if c >= 'a' && c <= 'z' && first { 11 | buf.WriteByte(c - 32) 12 | first = false 13 | } else if c == '_' { 14 | first = true 15 | continue 16 | } else { 17 | buf.WriteByte(c) 18 | } 19 | } 20 | return buf.String() 21 | } 22 | -------------------------------------------------------------------------------- /orm.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import ( 4 | "database/sql" 5 | "reflect" 6 | ) 7 | 8 | // Orm 一个程序创建一个全局的Orm对象即可 9 | type Orm struct { 10 | dbs map[string]*sql.DB 11 | mappings map[string]map[string]string 12 | ShowSQL bool 13 | } 14 | 15 | // New 创建全局的Orm对象 16 | func New() *Orm { 17 | return &Orm{ 18 | dbs: make(map[string]*sql.DB), 19 | mappings: make(map[string]map[string]string), 20 | ShowSQL: true, 21 | } 22 | } 23 | 24 | // Add 增加一个DataSource 25 | func (o *Orm) Add(name, addr string, idle, max int) error { 26 | db, err := sql.Open("mysql", addr) 27 | if err != nil { 28 | return err 29 | } 30 | 31 | db.SetMaxIdleConns(idle) 32 | db.SetMaxOpenConns(max) 33 | 34 | o.dbs[name] = db 35 | return nil 36 | } 37 | 38 | // Register 注册Struct,程序启动的时候先进行Register 39 | // e.g. orm.New().Register(new(User), new(Topic)) 40 | func (o *Orm) Register(vs ...interface{}) { 41 | l := len(vs) 42 | for i := 0; i < l; i++ { 43 | typ := reflect.TypeOf(vs[i]) 44 | ele := typ.Elem() 45 | num := ele.NumField() 46 | fields := make(map[string]string) 47 | for j := 0; j < num; j++ { 48 | field := ele.Field(j) 49 | tag := field.Tag.Get("orm") 50 | if tag != "" { 51 | fields[tag] = field.Name 52 | } 53 | } 54 | o.mappings[typ.String()] = fields 55 | } 56 | } 57 | 58 | // NewRepo 创建一个Repo,每做一次SQL操作都要新new一个Repo 59 | func (o *Orm) NewRepo(tbl string) *Repo { 60 | return &Repo{ 61 | o: o, 62 | tbl: tbl, 63 | showSQL: o.ShowSQL, 64 | } 65 | } 66 | 67 | // Use 使用哪个数据库 68 | func (o *Orm) Use(name string) *sql.DB { 69 | db, has := o.dbs[name] 70 | if !has { 71 | panic("no such database: " + name) 72 | } 73 | return db 74 | } 75 | 76 | // Tag2field 通过tag查字段名称 77 | func (o *Orm) Tag2field(typ reflect.Type, key string) string { 78 | m, has := o.mappings[typ.String()] 79 | if !has { 80 | return snakeToUpperCamel(key) 81 | } 82 | 83 | val, has := m[key] 84 | if !has { 85 | return snakeToUpperCamel(key) 86 | } 87 | 88 | return val 89 | } 90 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | *最简单的orm小框架,只支持mysql,下面是基本用法范例* 2 | 3 | ```go 4 | package main 5 | 6 | import ( 7 | "log" 8 | 9 | "github.com/ulricqin/orm" 10 | 11 | _ "github.com/go-sql-driver/mysql" 12 | ) 13 | 14 | // User 表在default库,即:minos_portal库 15 | type User struct { 16 | ID int64 `orm:"id"` 17 | Username string 18 | Nickname string 19 | } 20 | 21 | // UserRepo 这是每次DB操作的入口函数 22 | // user表在default库,不需要使用Use来特别指定 23 | func UserRepo() *orm.Repo { 24 | return Orm.NewRepo("user") 25 | } 26 | 27 | // Judge 在naming库,即:minos_naming库 28 | type Judge struct { 29 | ID int64 `orm:"id"` 30 | Address string `orm:"address"` 31 | } 32 | 33 | // JudgeRepo 这是每次DB操作的入口函数 34 | // judge表不在默认的default库,故而需要执行Use 35 | func JudgeRepo() *orm.Repo { 36 | return Orm.NewRepo("judge").Use("naming") 37 | } 38 | 39 | // DBConfig 数据库配置,支持配置多个库 40 | // 至少有个default库 41 | type DBConfig struct { 42 | Addr map[string]string 43 | Idle int 44 | Max int 45 | } 46 | 47 | var configs = DBConfig{ 48 | Addr: map[string]string{ 49 | "default": "root@tcp(127.0.0.1:3306)/minos_portal?charset=utf8&&loc=Asia%2FShanghai", 50 | "naming": "root@tcp(127.0.0.1:3306)/minos_naming?charset=utf8&&loc=Asia%2FShanghai", 51 | }, 52 | Idle: 2, 53 | Max: 10, 54 | } 55 | 56 | // Orm 全局操作入口 57 | var Orm *orm.Orm 58 | 59 | func main() { 60 | 61 | // Orm 对象可以放在程序全局,程序启动的时候初始化好 62 | Orm = orm.New() 63 | 64 | // 配置Orm的DataSource 65 | for k, v := range configs.Addr { 66 | if err := Orm.Add(k, v, configs.Idle, configs.Max); err != nil { 67 | // 程序启动的时候如果发现数据库连接不上,直接报错退出 68 | log.Fatalln(err) 69 | } 70 | } 71 | 72 | // 将各个model注册给Orm,这样才能识别Struct中各个字段的orm tag 73 | Orm.Register(new(User), new(Judge)) 74 | 75 | // 插入一条记录 76 | lastid, err := UserRepo().Insert(orm.G{ 77 | "username": "UlricQin", 78 | "nickname": "秦晓辉", 79 | }) 80 | dangerous(err) 81 | 82 | log.Println("insert user success, lastid:", lastid) 83 | 84 | // 查一条记录出来 85 | var user User 86 | has, err := UserRepo().Where("id=?", lastid).Find(&user) 87 | dangerous(err) 88 | 89 | if !has { 90 | log.Fatalln("no such user") 91 | } 92 | 93 | log.Println("Find user:", user) 94 | 95 | // 更新一条记录,如果调用了Quiet,将不打印sql语句 96 | num, err := UserRepo().Quiet().Where("id=?", lastid).Update(orm.G{ 97 | "username": "Ulric2", 98 | "nickname": "晓辉", 99 | }) 100 | dangerous(err) 101 | 102 | log.Println("update affected rows:", num) 103 | 104 | // 再插入一条记录,做个列表查询 105 | _, err = UserRepo().Insert(orm.G{ 106 | "username": "Ulric1", 107 | "nickname": "Flame", 108 | }) 109 | dangerous(err) 110 | 111 | // 计数 112 | count, err := UserRepo().Where("id>=?", lastid).Count() 113 | dangerous(err) 114 | 115 | log.Printf("user count of id>=%d is %d", lastid, count) 116 | 117 | // 只查询一列 118 | usernames, err := UserRepo().Where("id>=?", lastid).OrderBy("username").Limit(1, 1).StrCol("username") 119 | dangerous(err) 120 | 121 | log.Println("usernames, should only has Ulric2 => ", usernames) 122 | 123 | // 查询列表 124 | var users []*User 125 | err = UserRepo().Where("id>=?", lastid).Finds(&users) 126 | dangerous(err) 127 | 128 | log.Println("Find users:") 129 | for i := 0; i < len(users); i++ { 130 | log.Println(users[i]) 131 | } 132 | 133 | // 删除操作 134 | num, err = UserRepo().Limit(2).Where("id>=?", lastid).Delete() 135 | dangerous(err) 136 | 137 | log.Println("delete user affected:", num) 138 | 139 | log.Println("------------------") 140 | 141 | // 以上封装的方法都是针对单表的,这个简易orm框架也就只做这些事情 142 | // 复杂的sql操作可以直接使用内部的*sql.DB,比如 143 | 144 | ret, err := Orm.Use("naming").Exec("insert into judge(address, last_update) values(?, now())", "127.0.0.1:7788") 145 | dangerous(err) 146 | 147 | lastid, err = ret.LastInsertId() 148 | dangerous(err) 149 | 150 | log.Println("insert address success, lastid:", lastid) 151 | 152 | row := Orm.Use("naming").QueryRow("select address from judge where id = ?", lastid) 153 | var address string 154 | err = row.Scan(&address) 155 | dangerous(err) 156 | log.Println("query row address:", address) 157 | 158 | _, err = Orm.Use("naming").Exec("delete from judge where id=?", lastid) 159 | dangerous(err) 160 | } 161 | 162 | func dangerous(err error) { 163 | if err != nil { 164 | log.Fatalln(err) 165 | } 166 | } 167 | 168 | 169 | ``` -------------------------------------------------------------------------------- /repo.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import ( 4 | "bytes" 5 | "database/sql" 6 | "fmt" 7 | "log" 8 | "reflect" 9 | "strconv" 10 | "strings" 11 | "time" 12 | ) 13 | 14 | // Repo 封装一次查询 15 | type Repo struct { 16 | o *Orm 17 | tbl string 18 | db *sql.DB 19 | where string 20 | args []interface{} 21 | orderBy string 22 | limit int 23 | offset int 24 | showSQL bool 25 | cols string 26 | sql string 27 | } 28 | 29 | // Use 使用哪个数据库实例 30 | func (r *Repo) Use(name string) *Repo { 31 | r.db = r.o.Use(name) 32 | return r 33 | } 34 | 35 | // Where 查询条件 36 | func (r *Repo) Where(where string, args ...interface{}) *Repo { 37 | r.where = where 38 | r.args = args 39 | return r 40 | } 41 | 42 | // OrderBy e.g. name => order by name 43 | func (r *Repo) OrderBy(by string) *Repo { 44 | r.orderBy = by 45 | return r 46 | } 47 | 48 | // Limit 设置limit和offset 49 | func (r *Repo) Limit(limit int, offset ...int) *Repo { 50 | r.limit = limit 51 | if len(offset) > 0 { 52 | r.offset = offset[0] 53 | } 54 | return r 55 | } 56 | 57 | // Quiet 不传参数则设置showSQL为false 58 | func (r *Repo) Quiet(showSQL ...bool) *Repo { 59 | if len(showSQL) > 0 { 60 | r.showSQL = showSQL[0] 61 | } else { 62 | r.showSQL = false 63 | } 64 | return r 65 | } 66 | 67 | // Cols 设置要查询的column 68 | func (r *Repo) Cols(cols string) *Repo { 69 | r.cols = cols 70 | return r 71 | } 72 | 73 | func (r *Repo) insure() { 74 | if r.db == nil { 75 | r.Use("default") 76 | } 77 | } 78 | 79 | func (r *Repo) p(query string, args []interface{}) { 80 | if r.showSQL { 81 | log.Println("[orm]", query, "params:", args) 82 | } 83 | } 84 | 85 | func (r *Repo) exec(query string, args ...interface{}) (sql.Result, error) { 86 | r.insure() 87 | r.p(query, args) 88 | return r.db.Exec(query, args...) 89 | } 90 | 91 | func (r *Repo) queryRow(query string, args ...interface{}) *sql.Row { 92 | r.insure() 93 | r.p(query, args) 94 | return r.db.QueryRow(query, args...) 95 | } 96 | 97 | func (r *Repo) query(query string, args ...interface{}) (*sql.Rows, error) { 98 | r.insure() 99 | r.p(query, args) 100 | return r.db.Query(query, args...) 101 | } 102 | 103 | func (r *Repo) buildSQL() { 104 | if r.cols == "" { 105 | r.cols = "*" 106 | } 107 | 108 | buf := new(bytes.Buffer) 109 | buf.WriteString("SELECT ") 110 | buf.WriteString(r.cols) 111 | buf.WriteString(" FROM `") 112 | buf.WriteString(r.tbl) 113 | buf.WriteString("`") 114 | 115 | if r.where != "" { 116 | buf.WriteString(" WHERE ") 117 | buf.WriteString(r.where) 118 | } 119 | 120 | if r.orderBy != "" { 121 | buf.WriteString(" ORDER BY ") 122 | buf.WriteString(r.orderBy) 123 | } 124 | 125 | if r.limit > 0 { 126 | buf.WriteString(" LIMIT ?") 127 | r.args = append(r.args, r.limit) 128 | } 129 | 130 | if r.offset > 0 { 131 | buf.WriteString(" OFFSET ?") 132 | r.args = append(r.args, r.offset) 133 | } 134 | 135 | r.sql = buf.String() 136 | } 137 | 138 | // Count 统计数目 139 | func (r *Repo) Count() (count int, err error) { 140 | r.cols = "count(*) as count" 141 | r.buildSQL() 142 | err = r.queryRow(r.sql, r.args...).Scan(&count) 143 | return 144 | } 145 | 146 | // Insert 保存一条数据,返回lastid 147 | func (r *Repo) Insert(attrs G) (int64, error) { 148 | ln := len(attrs) 149 | keys := make([]string, 0, ln) 150 | qms := make([]string, 0, ln) 151 | vals := make([]interface{}, 0, ln) 152 | for k, v := range attrs { 153 | keys = append(keys, fmt.Sprintf("`%s`", k)) 154 | qms = append(qms, "?") 155 | vals = append(vals, v) 156 | } 157 | 158 | s := fmt.Sprintf( 159 | "INSERT INTO `%s`(%s) VALUES(%s)", 160 | r.tbl, 161 | strings.Join(keys, ","), 162 | strings.Join(qms, ","), 163 | ) 164 | 165 | ret, err := r.exec(s, vals...) 166 | if err != nil { 167 | return 0, err 168 | } 169 | 170 | return ret.LastInsertId() 171 | } 172 | 173 | // Delete 根据where条件做删除,返回被影响的行数 174 | func (r *Repo) Delete() (int64, error) { 175 | s := fmt.Sprintf("DELETE FROM `%s`", r.tbl) 176 | if r.where != "" { 177 | s += " WHERE " + r.where 178 | } 179 | 180 | if r.limit > 0 { 181 | s += " LIMIT ?" 182 | r.args = append(r.args, r.limit) 183 | } 184 | 185 | ret, err := r.exec(s, r.args...) 186 | if err != nil { 187 | return 0, err 188 | } 189 | 190 | return ret.RowsAffected() 191 | } 192 | 193 | // Update 更新记录 194 | func (r *Repo) Update(attrs G) (int64, error) { 195 | ln := len(attrs) 196 | keys := make([]string, 0, ln) 197 | vals := make([]interface{}, 0, ln) 198 | for k, v := range attrs { 199 | keys = append(keys, fmt.Sprintf("`%s`=?", k)) 200 | vals = append(vals, v) 201 | } 202 | 203 | s := fmt.Sprintf("UPDATE `%s` SET %s", r.tbl, strings.Join(keys, ",")) 204 | if r.where != "" { 205 | s += " WHERE " + r.where 206 | vals = append(vals, r.args...) 207 | } 208 | 209 | if r.limit > 0 { 210 | vals = append(vals, r.limit) 211 | } 212 | 213 | ret, err := r.exec(s, vals...) 214 | if err != nil { 215 | return 0, err 216 | } 217 | 218 | return ret.RowsAffected() 219 | } 220 | 221 | // I64Col 获取一列数据,数据类型是int64 222 | func (r *Repo) I64Col(col string) ([]int64, error) { 223 | cols := []int64{} 224 | rs, err := r.col(col) 225 | if err != nil { 226 | return cols, err 227 | } 228 | 229 | defer rs.Close() 230 | 231 | for rs.Next() { 232 | var item int64 233 | err = rs.Scan(&item) 234 | if err != nil { 235 | return cols, err 236 | } 237 | 238 | cols = append(cols, item) 239 | } 240 | 241 | return cols, err 242 | } 243 | 244 | // StrCol 获取一列数据,数据类型是string 245 | func (r *Repo) StrCol(col string) ([]string, error) { 246 | cols := []string{} 247 | rs, err := r.col(col) 248 | if err != nil { 249 | return cols, err 250 | } 251 | 252 | defer rs.Close() 253 | 254 | for rs.Next() { 255 | var item string 256 | err = rs.Scan(&item) 257 | if err != nil { 258 | return cols, err 259 | } 260 | 261 | cols = append(cols, item) 262 | } 263 | 264 | return cols, err 265 | } 266 | 267 | func (r *Repo) col(col string) (*sql.Rows, error) { 268 | r.cols = col 269 | r.buildSQL() 270 | return r.query(r.sql, r.args...) 271 | } 272 | 273 | // U64s 将uint64类型的slice拼接成逗号分隔的string 274 | func U64s(ids []uint64) string { 275 | count := len(ids) 276 | strs := make([]string, count) 277 | for i := 0; i < count; i++ { 278 | strs[i] = fmt.Sprint(ids[i]) 279 | } 280 | return strings.Join(strs, ",") 281 | } 282 | 283 | // I64s 将int64类型的slice拼接成逗号分隔的string 284 | func I64s(ids []int64) string { 285 | count := len(ids) 286 | strs := make([]string, count) 287 | for i := 0; i < count; i++ { 288 | strs[i] = fmt.Sprint(ids[i]) 289 | } 290 | return strings.Join(strs, ",") 291 | } 292 | 293 | // I64Arr 将逗号分隔的字符串ID转换成[]int64 294 | func I64Arr(ids string) []int64 { 295 | if ids == "" { 296 | return []int64{} 297 | } 298 | 299 | arr := strings.Split(ids, ",") 300 | count := len(arr) 301 | ret := make([]int64, 0, count) 302 | for i := 0; i < count; i++ { 303 | if arr[i] == "" { 304 | continue 305 | } 306 | id, err := strconv.ParseInt(arr[i], 10, 64) 307 | if err != nil { 308 | continue 309 | } 310 | ret = append(ret, id) 311 | } 312 | return ret 313 | } 314 | 315 | // Rows 查询多行记录 316 | func (r *Repo) Rows() (*sql.Rows, error) { 317 | r.insure() 318 | r.buildSQL() 319 | r.p(r.sql, r.args) 320 | stmt, err := r.db.Prepare(r.sql) 321 | if err != nil { 322 | return nil, err 323 | } 324 | 325 | defer stmt.Close() 326 | return stmt.Query(r.args...) 327 | } 328 | 329 | // Row 查询一条记录 330 | func (r *Repo) Row() *sql.Row { 331 | r.buildSQL() 332 | return r.queryRow(r.sql, r.args...) 333 | } 334 | 335 | // Find 查找一个struct,传入的第一个参数是struct的指针 336 | func (r *Repo) Find(ptr interface{}) (bool, error) { 337 | rows, err := r.Rows() 338 | if err != nil { 339 | return false, err 340 | } 341 | 342 | val := reflect.ValueOf(ptr) 343 | 344 | defer rows.Close() 345 | 346 | if rows.Next() { 347 | err = r.scanRows(val, rows) 348 | if err != nil { 349 | return false, err 350 | } 351 | } else { 352 | return false, nil 353 | } 354 | 355 | return true, nil 356 | } 357 | 358 | // Finds 查询一个列表,ptr e.g. var user []*User -> &user 359 | func (r *Repo) Finds(ptr interface{}) error { 360 | rows, err := r.Rows() 361 | if err != nil { 362 | return err 363 | } 364 | 365 | sliceValue := reflect.Indirect(reflect.ValueOf(ptr)) 366 | structType := sliceValue.Type().Elem().Elem() 367 | 368 | defer rows.Close() 369 | 370 | for rows.Next() { 371 | rowValue := reflect.New(structType) 372 | err = r.scanRows(rowValue, rows) 373 | if err != nil { 374 | return err 375 | } 376 | sliceValue.Set(reflect.Append(sliceValue, rowValue)) 377 | } 378 | 379 | return nil 380 | } 381 | 382 | func (r *Repo) scanRows(val reflect.Value, rows *sql.Rows) (err error) { 383 | cols, _ := rows.Columns() 384 | 385 | containers := make([]interface{}, 0, len(cols)) 386 | for i := 0; i < cap(containers); i++ { 387 | var v interface{} 388 | containers = append(containers, &v) 389 | } 390 | 391 | err = rows.Scan(containers...) 392 | if err != nil { 393 | return 394 | } 395 | 396 | typ := val.Type() 397 | 398 | for i, v := range containers { 399 | value := reflect.Indirect(reflect.ValueOf(v)) 400 | if !value.Elem().IsValid() { 401 | continue 402 | } 403 | 404 | key := cols[i] 405 | 406 | field := val.Elem().FieldByName(r.o.Tag2field(typ, key)) 407 | if field.IsValid() { 408 | // value -> field 409 | err = setModelValue(value, field) 410 | if err != nil { 411 | return 412 | } 413 | } 414 | } 415 | 416 | return 417 | } 418 | 419 | func parseBool(value reflect.Value) bool { 420 | return value.Bool() 421 | } 422 | 423 | func setPtrValue(driverValue, fieldValue reflect.Value) { 424 | t := fieldValue.Type().Elem() 425 | v := reflect.New(t) 426 | fieldValue.Set(v) 427 | switch t.Kind() { 428 | case reflect.String: 429 | v.Elem().SetString(string(driverValue.Interface().([]uint8))) 430 | case reflect.Int64: 431 | v.Elem().SetInt(driverValue.Interface().(int64)) 432 | case reflect.Float64: 433 | v.Elem().SetFloat(driverValue.Interface().(float64)) 434 | case reflect.Bool: 435 | v.Elem().SetBool(driverValue.Interface().(bool)) 436 | } 437 | } 438 | 439 | func setModelValue(driverValue, fieldValue reflect.Value) error { 440 | switch fieldValue.Type().Kind() { 441 | case reflect.Bool: 442 | fieldValue.SetBool(parseBool(driverValue.Elem())) 443 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 444 | fieldValue.SetInt(driverValue.Elem().Int()) 445 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 446 | // reading uint from int value causes panic 447 | switch driverValue.Elem().Kind() { 448 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 449 | fieldValue.SetUint(uint64(driverValue.Elem().Int())) 450 | default: 451 | fieldValue.SetUint(driverValue.Elem().Uint()) 452 | } 453 | case reflect.Float32, reflect.Float64: 454 | fieldValue.SetFloat(driverValue.Elem().Float()) 455 | case reflect.String: 456 | fieldValue.SetString(string(driverValue.Elem().Bytes())) 457 | case reflect.Slice: 458 | if reflect.TypeOf(driverValue.Interface()).Elem().Kind() == reflect.Uint8 { 459 | fieldValue.SetBytes(driverValue.Elem().Bytes()) 460 | } 461 | case reflect.Ptr: 462 | setPtrValue(driverValue, fieldValue) 463 | case reflect.Struct: 464 | switch fieldValue.Interface().(type) { 465 | case time.Time: 466 | fieldValue.Set(driverValue.Elem()) 467 | default: 468 | if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { 469 | return scanner.Scan(driverValue.Interface()) 470 | } 471 | } 472 | } 473 | return nil 474 | } 475 | --------------------------------------------------------------------------------