├── .travis.yml ├── LICENCE ├── README.md ├── adopter.go ├── doc.go ├── example_test.go ├── postgres.go ├── postgres_test.go ├── sql.go ├── sql_test.go ├── table.go └── table_test.go /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | go: 3 | - 1.2.2 4 | - 1.3.3 5 | - 1.4.2 6 | - 1.5.1 7 | - release 8 | before_install: 9 | - go get -t -v 10 | - go get github.com/axw/gocov/gocov 11 | - go get github.com/mattn/goveralls 12 | - if ! go get code.google.com/p/go.tools/cmd/cover; then go get golang.org/x/tools/cmd/cover; fi 13 | before_script: 14 | - psql -c 'create database orange_test;' -U postgres 15 | script: 16 | - $HOME/gopath/bin/goveralls -service=travis-ci -repotoken=$COVERALLS 17 | -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016 Geofrey Ernest 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Orange [![GoDoc](https://godoc.org/github.com/gernest/orange?status.svg)](https://godoc.org/github.com/gernest/orange)[![Coverage Status](https://coveralls.io/repos/github/gernest/orange/badge.svg?branch=master)](https://coveralls.io/github/gernest/orange?branch=master) 2 | [![Build Status](https://travis-ci.org/gernest/orange.svg?branch=master)](https://travis-ci.org/gernest/orange) 3 | 4 | Orange is a lightweight, simple Object Relational Mapper for Golang. Orange offers a simple API for building meaningful database queries, and good abstractions on top of the standard database/sql package. 5 | 6 | Orange is inspired by [gorm](https://github.com/jinzhu/gorm) 7 | # Features 8 | * Simple API 9 | * Fast 10 | * Auto migrations 11 | * Multiple database support (currently only postgresql, but mysql and sqlite are 12 | work in progress) 13 | * Zero dependencies (only the standard library) 14 | * Simple SQL query building API 15 | 16 | # Motivation 17 | This is my understanding of Object Relational Mapping with Golang. Instead of 18 | writing a blog post, I took the liberty to implement `orange`. It has almost all 19 | the things you might need to interact with databases with Golang. 20 | 21 | The source code is geared toward people who want to harness the power of Go. 22 | There are a lot of myths around reflection in Go; I have used almost all the 23 | techniques you will need to master reflection. 24 | 25 | THIS IS NOT FOR PRODUCTION USE, unless you know what you are doing, in which case 26 | your contribution is welcome.
27 | 28 | 29 | 30 | # Installation 31 | 32 | ```bash 33 | go get github.com/gernest/orange 34 | ``` 35 | 36 | 37 | # Usage 38 | 39 | The following is a simple example to showcase the power of orange; for the 40 | comprehensive API please check [The Orange Documentation](https://godoc.org/github.com/gernest/orange) 41 | 42 | ```go 43 | package orange_test 44 | 45 | import ( 46 | "fmt" 47 | 48 | "github.com/gernest/orange" 49 | 50 | // Include the driver for your database 51 | _ "github.com/lib/pq" 52 | ) 53 | 54 | type golangster struct { 55 | ID int64 56 | Name string 57 | } 58 | 59 | func Example() { 60 | 61 | // Open a database connection 62 | connectionString := "user=postgres dbname=orange_test sslmode=disable" 63 | db, err := orange.Open("postgres", connectionString) 64 | if err != nil { 65 | panic(err) 66 | } 67 | 68 | // Register the structs that you want to map to 69 | err = db.Register(&golangster{}) 70 | if err != nil { 71 | panic(err) 72 | } 73 | 74 | // Run database migrations (tables will be created if they don't exist) 75 | err = db.Automigrate() 76 | if err != nil { 77 | panic(err) 78 | } 79 | 80 | // Make sure we are connected to the database we want 81 | name := db.CurrentDatabase() 82 | fmt.Println(name) // in my case it is orange_test 83 | 84 | // Insert a new record into the database 85 | err = db.Create(&golangster{Name: "hello"}) 86 | if err != nil { 87 | panic(err) 88 | } 89 | 90 | // Count the golangster records; Count needs a table, given by Select 91 | var count int 92 | db.Select(&golangster{}).Count("*").Bind(&count) 93 | fmt.Println(count) // in my case 1 94 | 95 | // Retrieve a record with name hello 96 | result := golangster{} 97 | db.Find(&result, &golangster{Name: "hello"}) 98 | fmt.Println(result) // in my case {1 hello} 99 | 100 | } 101 | ``` 102 | 103 | # TODO list 104 | These are some of the things I hope to add when I get time: 105 | * Delete record 106 | * Support mysql 107 | * Support sqlite 108 | * More comprehensive tests 109 | * Improve performance 110 | * Talk about orange 111 |
112 | 113 | # Contributing 114 | 115 | Contributions of all kinds are welcome. 116 | 117 | # Author 118 | Geofrey Ernest 119 | [twitter @gernesti](https://twitter.com/gernesti) 120 | 121 | # Licence 122 | MIT, see [LICENCE](LICENCE) 123 | 124 | -------------------------------------------------------------------------------- /adopter.go: -------------------------------------------------------------------------------- 1 | package orange 2 | 3 | // Adopter abstracts the database-specific parts of building SQL queries. 4 | type Adopter interface { 5 | 6 | // Create returns the SQL query that creates the table. 7 | Create(Table) (string, error) 8 | 9 | // Field returns the SQL definition of the field as a table column. The 10 | // returned string is used when creating tables. 11 | Field(Field) (string, error) 12 | 13 | // Drop returns the SQL query that drops the table. 14 | Drop(Table) (string, error) 15 | 16 | // Quote returns the placeholder for a positional argument in SQL queries. 17 | // The argument is the position of the parameter. 18 | // 19 | // For MySQL ? is used, e.g. name=? 20 | // For PostgreSQL $N is used, e.g. name=$1 21 | // Despite its name, Quote does not quote or escape values. 22 | Quote(int) string 23 | 24 | // HasPrepare reports whether the adopter supports prepared statements. 25 | HasPrepare() bool 26 | 27 | // Name returns the adopter name, which newSQL also passes to sql.Open as the driver name. 28 | Name() string 29 | 30 | // Database returns the name of the current database. 31 | Database(*SQL) string 32 | } 33 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // Package orange is a lightweight object relational mapper for Go. 2 | // Orange offers a simple API for building meaningful database queries, and good 3 | // abstractions on top of standard database/sql. 4 | // 5 | // This is intended for educational purposes only. It is not feature complete, and 6 | // lacks edge case testing.
There is a lot of work still to be done to make this 7 | // package stable. 8 | // 9 | // Use this as a way to dive into Golang, a quick easy way to interact with your 10 | // database. Enjoy. 11 | package orange 12 | -------------------------------------------------------------------------------- /example_test.go: -------------------------------------------------------------------------------- 1 | package orange_test 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/gernest/orange" 7 | 8 | // Include the driver for your database 9 | _ "github.com/lib/pq" 10 | ) 11 | 12 | type golangster struct { 13 | ID int64 14 | Name string 15 | } 16 | 17 | func Example() { 18 | 19 | // Open a database connection 20 | connectionString := "user=postgres dbname=orange_test sslmode=disable" 21 | db, err := orange.Open("postgres", connectionString) 22 | if err != nil { 23 | panic(err) 24 | } 25 | 26 | // Register the structs that you want to map to 27 | err = db.Register(&golangster{}) 28 | if err != nil { 29 | panic(err) 30 | } 31 | 32 | // Run database migrations (tables will be created if they don't exist) 33 | err = db.Automigrate() 34 | if err != nil { 35 | panic(err) 36 | } 37 | 38 | // Make sure we are connected to the database we want 39 | name := db.CurrentDatabase() 40 | fmt.Println(name) // in my case it is orange_test 41 | 42 | // Insert a new record into the database 43 | err = db.Create(&golangster{Name: "hello"}) 44 | if err != nil { 45 | panic(err) 46 | } 47 | 48 | // Count the golangster records; Count needs a table, given by Select 49 | var count int 50 | err = db.Select(&golangster{}).Count("*").Bind(&count) 51 | if err != nil { 52 | panic(err) 53 | } 54 | fmt.Println(count) // in my case 1 55 | 56 | // Retrieve a record with name hello 57 | result := golangster{} 58 | err = db.Find(&result, &golangster{Name: "hello"}) 59 | if err != nil { 60 | panic(err) 61 | } 62 | fmt.Println(result) // in my case {1 hello} 63 | 64 | } 65 | -------------------------------------------------------------------------------- /postgres.go:
-------------------------------------------------------------------------------- 1 | package orange 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "reflect" 7 | "strings" 8 | "time" 9 | ) 10 | 11 | type postgresql struct{} 12 | 13 | // Create returns sql query for creating table t if it does not exist 14 | func (p *postgresql) Create(t Table) (string, error) { 15 | buf := &bytes.Buffer{} 16 | _, _ = buf.WriteString("CREATE TABLE IF NOT EXISTS " + t.Name() + " (") 17 | fields, err := t.Fields() 18 | if err != nil { 19 | return "", err 20 | } 21 | size := len(fields) 22 | for k, v := range fields { 23 | column, err := p.Field(v) 24 | if err != nil { 25 | return "", nil 26 | } 27 | if k == size-1 { 28 | _, _ = buf.WriteString(column) 29 | break 30 | } 31 | _, _ = buf.WriteString(column + ",") 32 | } 33 | _, _ = buf.WriteString(");") 34 | return buf.String(), nil 35 | } 36 | 37 | // Drop returns sql query for dropping table t. 38 | func (p *postgresql) Drop(t Table) (string, error) { 39 | query := "DROP TABLE IF EXISTS " + t.Name() 40 | return query, nil 41 | } 42 | 43 | // Field returns sql representation of field f.. 
44 | func (p *postgresql) Field(f Field) (string, error) { 45 | buf := &bytes.Buffer{} 46 | fName := f.ColumnName() 47 | _, _ = buf.WriteString(fName + " ") 48 | var details string 49 | switch f.Type().Kind() { 50 | case reflect.String: 51 | details = "text" 52 | case reflect.Bool: 53 | details = "boolean" 54 | case reflect.Int: 55 | if strings.ToLower(f.Name()) == "id" { 56 | details = "serial" 57 | break 58 | } 59 | details = "integer" 60 | case reflect.Int64: 61 | if strings.ToLower(f.Name()) == "id" { 62 | details = "bigserial" 63 | break 64 | } 65 | details = "bigint" 66 | case reflect.Struct: 67 | if f.Type().AssignableTo(reflect.TypeOf(time.Time{})) { 68 | details = "timestamp with time zone" 69 | } 70 | } 71 | if details == "" { 72 | return "", fmt.Errorf(" unknown type for field %s", f.Type().Kind()) 73 | } 74 | _, _ = buf.WriteString(details) 75 | return buf.String(), nil 76 | } 77 | 78 | func (p *postgresql) Quote(pos int) string { 79 | return fmt.Sprintf("$%d", pos) 80 | } 81 | 82 | // Name returns the name of adopter. 83 | func (p *postgresql) Name() string { 84 | return "postgres" 85 | } 86 | 87 | // Database returns the name of the curent database that the queries are running 88 | // on. 
89 | func (p *postgresql) Database(s *SQL) string { 90 | query := "SELECT current_database();" 91 | var name string 92 | r := s.QueryRow(query) 93 | _ = r.Scan(&name) 94 | return name 95 | } 96 | 97 | func (p *postgresql) HasPrepare() bool { 98 | return true 99 | } 100 | -------------------------------------------------------------------------------- /postgres_test.go: -------------------------------------------------------------------------------- 1 | package orange 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | ) 7 | 8 | type postgresTest struct { 9 | ID int 10 | Body string 11 | CreatedAt time.Time 12 | UpdatedAt time.Time 13 | } 14 | 15 | func TestPostgres_Create(t *testing.T) { 16 | p := &postgresql{} 17 | tab, err := loadTable(&postgresTest{}) 18 | if err != nil { 19 | t.Fatal(err) 20 | } 21 | create, err := p.Create(tab) 22 | if err != nil { 23 | t.Fatal(err) 24 | } 25 | expect := "CREATE TABLE IF NOT EXISTS postgres_test (id serial,body text,created_at timestamp with time zone,updated_at timestamp with time zone);" 26 | if create != expect { 27 | t.Errorf("expected %s got %s", expect, create) 28 | } 29 | } 30 | 31 | func TestPostgres_Drop(t *testing.T) { 32 | p := &postgresql{} 33 | tab, err := loadTable(&postgresTest{}) 34 | if err != nil { 35 | t.Fatal(err) 36 | } 37 | drop, err := p.Drop(tab) 38 | if err != nil { 39 | t.Fatal(err) 40 | } 41 | expect := "DROP TABLE IF EXISTS postgres_test" 42 | if drop != expect { 43 | t.Errorf("expected %s got %s", expect, drop) 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /sql.go: -------------------------------------------------------------------------------- 1 | package orange 2 | 3 | import ( 4 | "bytes" 5 | "database/sql" 6 | "errors" 7 | "fmt" 8 | "html" 9 | "reflect" 10 | "strings" 11 | "sync" 12 | "time" 13 | ) 14 | 15 | // a sql command with the argumenst 16 | type clause struct { 17 | condition string 18 | args []interface{} 19 | } 20 | 21 | //SQL provides 
methods for interacting with Relational databases 22 | type SQL struct { 23 | models map[string]Table 24 | mu sync.RWMutex 25 | adopter Adopter 26 | loader LoadFunc 27 | clause struct { 28 | where, limit, offset, order, count, dbSelect *clause 29 | } 30 | db *sql.DB 31 | verbose bool 32 | isDone bool // true when the current query has already been executed. 33 | } 34 | 35 | func newSQL(dbAdopter Adopter, dbConnection string) (*SQL, error) { 36 | db, err := sql.Open(dbAdopter.Name(), dbConnection) 37 | if err != nil { 38 | return nil, err 39 | } 40 | return &SQL{ 41 | models: make(map[string]Table), 42 | adopter: dbAdopter, 43 | loader: loadTable, 44 | db: db, 45 | }, nil 46 | } 47 | 48 | //Open opens a new database connection for the given adopter 49 | // 50 | // There is only postgres support right now, mysql will come out soon. 51 | // database | adopter name 52 | // ---------------------------- 53 | // postgresql | postgres 54 | func Open(dbAdopter, dbConnection string) (*SQL, error) { 55 | switch dbAdopter { 56 | case "postgres": 57 | return newSQL(&postgresql{}, dbConnection) 58 | } 59 | return nil, errors.New("unsupported databse ") 60 | } 61 | 62 | //DB returns the underlying Database connection. 63 | func (s *SQL) DB() *sql.DB { 64 | return s.db 65 | } 66 | 67 | //LoadFunc sets f as the table loading function. To get Table from models the 68 | //function f will be used. It is up to the user to make sense out of the Table 69 | //implementation when this method is used. 70 | func (s *SQL) LoadFunc(f LoadFunc) *SQL { 71 | s.loader = f 72 | return s 73 | } 74 | 75 | //Register registers model. All models should be registered before calling any 76 | //method from this struct. It is safe to call this method in multiple 77 | //goroutines. 
78 | func (s *SQL) Register(models ...interface{}) error { 79 | s.mu.Lock() 80 | defer s.mu.Unlock() 81 | if len(models) > 0 { 82 | for _, v := range models { 83 | t, err := s.loader(v) 84 | if err != nil { 85 | return err 86 | } 87 | typ := reflect.TypeOf(v) 88 | if typ.Kind() == reflect.Ptr { 89 | typ = typ.Elem() 90 | } 91 | s.models[typ.Name()] = t 92 | } 93 | } 94 | return nil 95 | } 96 | 97 | //DropTable drops the Database table for the model. 98 | func (s *SQL) DropTable(model interface{}) error { 99 | typ := reflect.TypeOf(model) 100 | if typ.Kind() == reflect.Ptr { 101 | typ = typ.Elem() 102 | } 103 | t := s.getModel(typ.Name()) 104 | if t != nil { 105 | query, err := s.adopter.Drop(t) 106 | if err != nil { 107 | return err 108 | } 109 | _, err = s.Exec(query) 110 | if err != nil { 111 | return err 112 | } 113 | return nil 114 | } 115 | return errors.New("the table is not registered yet ") 116 | } 117 | 118 | //getModel returns the table registered by name, returns nil if the table was 119 | //not yet registered. 120 | func (s *SQL) getModel(name string) Table { 121 | s.mu.RLock() 122 | defer s.mu.RUnlock() 123 | return s.models[name] 124 | } 125 | 126 | //Automigrate creates the database tables if they don't exist 127 | func (s *SQL) Automigrate() error { 128 | for _, m := range s.models { 129 | query, err := s.adopter.Create(m) 130 | if err != nil { 131 | return err 132 | } 133 | _, err = s.Exec(query) 134 | if err != nil { 135 | return err 136 | } 137 | } 138 | return nil 139 | } 140 | 141 | //Copy returns a new copy of s. It is used for effective method chaining to 142 | //avoid messing up the scope. 143 | func (s *SQL) Copy() *SQL { 144 | return &SQL{ 145 | db: s.db, 146 | models: s.models, 147 | adopter: s.adopter, 148 | } 149 | } 150 | 151 | //CopyQuery returns a copy of *SQL when the composed query has already been 152 | //executed. 
153 | func (s *SQL) CopyQuery() *SQL { 154 | if s.isDone { 155 | return s.Copy() 156 | } 157 | return s 158 | } 159 | 160 | //Where adds a where query, value can be a query string, a model or a map 161 | func (s *SQL) Where(value interface{}, args ...interface{}) *SQL { 162 | dup := s.CopyQuery() 163 | refVal := reflect.ValueOf(value) 164 | if refVal.Kind() == reflect.Ptr { 165 | refVal = refVal.Elem() 166 | } 167 | 168 | switch refVal.Kind() { 169 | case reflect.String: 170 | c := &clause{condition: value.(string)} 171 | if len(args) > 0 { 172 | c.args = args 173 | } 174 | dup.clause.where = c 175 | return dup 176 | case reflect.Struct: 177 | t, err := loadTable(value) 178 | if err != nil { 179 | //TODO handle? 180 | return dup 181 | } 182 | cols, vals, err := Values(t, value) 183 | if err != nil { 184 | return dup 185 | } 186 | var keyVal string 187 | for k, v := range cols { 188 | keyVal = keyVal + fmt.Sprintf(" %s=%s", v, s.quote(vals[k])) 189 | } 190 | dup.clause.where = &clause{condition: keyVal} 191 | return dup 192 | } 193 | return dup 194 | } 195 | 196 | //Values returns the fields that are present in the table t which have values 197 | //set in model v. 198 | // THis tries to breakdown the mapping of table collum names with their 199 | // corresponding values. 
200 | // 201 | // For instance if you have a model defined as 202 | // type foo struct{ 203 | // ID int 204 | // } 205 | // 206 | // After loading a table representation of foo, you can get the column names 207 | // that have been assigned values like this 208 | // cols,vals,err:=Values(fooTable,&foo{ID: 1}) 209 | // // cols will be []string{"ID"} 210 | // // vals will be []interface{}{1} 211 | func Values(t Table, v interface{}) (cols []string, vals []interface{}, err error) { 212 | f, err := t.Fields() 213 | if err != nil { 214 | return 215 | } 216 | value := reflect.ValueOf(v) 217 | if value.Kind() == reflect.Ptr { 218 | value = value.Elem() 219 | } 220 | for _, field := range f { 221 | fv := value.FieldByName(field.Name()) 222 | if fv.IsValid() { 223 | zero := reflect.Zero(fv.Type()) 224 | if reflect.DeepEqual(zero.Interface(), fv.Interface()) { 225 | continue 226 | } 227 | colName := field.Name() 228 | tags, _ := field.Flags() 229 | if tags != nil { 230 | for _, tag := range tags { 231 | if tag.Name() == "field_name" { 232 | colName = tag.Value() 233 | break 234 | } 235 | } 236 | } 237 | cols = append(cols, colName) 238 | vals = append(vals, fv.Interface()) 239 | } 240 | } 241 | return 242 | } 243 | 244 | //Limit sets up LIMIT clause, condition is the value for limit. Calling this with 245 | //condition, will set a barrier to the actual number of rows that are returned 246 | //after executing the query. 247 | // 248 | // If you set limit to 10, only the first 10 records will be returned and if the 249 | // query result returns less than the 10 rows thn they will be used instead. 250 | func (s *SQL) Limit(condition int) *SQL { 251 | dup := s.CopyQuery() 252 | query := fmt.Sprintf(" LIMIT %d", condition) 253 | dup.clause.limit = &clause{condition: query} 254 | return dup 255 | } 256 | 257 | // Count adds COUNT statement, colum is the column name that you want to 258 | // count. 
It is up to the caller to provide a single value to bind to( in which 259 | // the tal count will be written to. 260 | // 261 | // var total int64 262 | // db.Select(&user{}).Count"id").Bind(&total) 263 | func (s *SQL) Count(column string) *SQL { 264 | dup := s.CopyQuery() 265 | query := fmt.Sprintf(" COUNT (%s) ", column) 266 | dup.clause.count = &clause{condition: query} 267 | return dup 268 | } 269 | 270 | // Offset adds OFFSET clause with the offset value set to condition.This allows 271 | // you to pick just a part of the result of executing the whole query, all the 272 | // rows before condition will be skipped. 273 | // 274 | // For instance if condition is set to 5, then the results will contain rows 275 | // from number 6 276 | func (s *SQL) Offset(condition int) *SQL { 277 | dup := s.CopyQuery() 278 | query := fmt.Sprintf(" LIMIT %d", condition) 279 | dup.clause.offset = &clause{condition: query} 280 | return dup 281 | } 282 | 283 | //Select adds SELECT clause. No query is executed by this method, only the call 284 | //for *SQL.Bind will excute the built query( with exceptions of the wrappers for 285 | //database/sql package) 286 | // 287 | // query can be a model or a string. Only when query is a string will the args 288 | // be used. 
289 | func (s *SQL) Select(query interface{}, args ...interface{}) *SQL { 290 | dup := s.CopyQuery() 291 | val := reflect.ValueOf(query) 292 | switch val.Kind() { 293 | case reflect.String: 294 | c := &clause{condition: query.(string)} 295 | if len(args) > 0 { 296 | c.args = args 297 | } 298 | dup.clause.dbSelect = c 299 | return dup 300 | case reflect.Struct: 301 | t := s.getModel(val.Type().Name()) 302 | if t == nil { 303 | //TODO return an error 304 | return dup 305 | } 306 | q := "* FROM " + t.Name() 307 | c := &clause{condition: q} 308 | dup.clause.dbSelect = c 309 | return dup 310 | case reflect.Ptr: 311 | val = val.Elem() 312 | if val.Kind() == reflect.Struct { 313 | t := s.getModel(val.Type().Name()) 314 | if t == nil { 315 | //TODO return an error 316 | return dup 317 | } 318 | q := "* FROM " + t.Name() 319 | c := &clause{condition: q} 320 | dup.clause.dbSelect = c 321 | return dup 322 | } 323 | } 324 | return dup 325 | } 326 | 327 | //BuildQuery returns the sql query that will be executed 328 | func (s *SQL) BuildQuery() (string, []interface{}, error) { 329 | buf := &bytes.Buffer{} 330 | var args []interface{} 331 | if s.clause.dbSelect != nil { 332 | _, _ = buf.WriteString("SELECT ") 333 | selectCond := s.clause.dbSelect.condition 334 | if s.clause.count != nil { 335 | _, _ = buf.WriteString(s.clause.count.condition) 336 | n := strings.Index(selectCond, "FROM") 337 | if n > 0 { 338 | selectCond = selectCond[n:] 339 | } 340 | } 341 | _, _ = buf.WriteString(selectCond) 342 | if s.clause.dbSelect.args != nil { 343 | args = append(args, s.clause.dbSelect.args) 344 | } 345 | } 346 | if s.clause.where != nil { 347 | _, _ = buf.WriteString(" WHERE" + s.clause.where.condition) 348 | if s.clause.dbSelect != nil && s.clause.dbSelect.args != nil { 349 | args = append(args, s.clause.where.args) 350 | } 351 | } 352 | if s.clause.offset != nil { 353 | _, _ = buf.WriteString("OFFSET " + s.clause.offset.condition) 354 | if s.clause.dbSelect.args != nil { 355 | args = 
append(args, s.clause.offset.args) 356 | } 357 | } 358 | if s.clause.limit != nil { 359 | _, _ = buf.WriteString("LIMIT" + s.clause.limit.condition) 360 | if s.clause.dbSelect.args != nil { 361 | args = append(args, s.clause.limit.args) 362 | } 363 | } 364 | _, _ = buf.WriteString(";") 365 | if s.verbose { 366 | fmt.Println(buf.String()) 367 | } 368 | return buf.String(), cleanArgs(args...), nil 369 | } 370 | 371 | //Find executes the composed query and retunrs a single value if model is not a 372 | //slice, or a slice of models when the model is slice. 373 | func (s *SQL) Find(model interface{}, where ...interface{}) error { 374 | dup := s.CopyQuery() 375 | dup.Select(model) 376 | switch len(where) { 377 | case 0: 378 | break 379 | case 1: 380 | dup.Where(where[0]) 381 | default: 382 | dup.Where(where[0], where[1:]...) 383 | } 384 | return dup.Bind(model) 385 | } 386 | 387 | // cleanArgs escapes all string values in args. It is a sane way to escape user 388 | // supplied inputs. 389 | // 390 | // NEVER EVER TRUST INPUT FROM THE USER 391 | func cleanArgs(args ...interface{}) (rst []interface{}) { 392 | if len(args) > 0 { 393 | for _, v := range args { 394 | if v != nil { 395 | if typ, ok := v.(string); ok { 396 | rst = append(rst, html.EscapeString(typ)) 397 | } 398 | rst = append(rst, v) 399 | } 400 | 401 | } 402 | return args 403 | } 404 | return 405 | } 406 | 407 | type valScanner interface { 408 | Scan(dest ...interface{}) error 409 | } 410 | 411 | func (s *SQL) scanStruct(scanner valScanner, model interface{}) error { 412 | val := reflect.ValueOf(model) 413 | if val.Kind() != reflect.Ptr { 414 | return errors.New("can not assign to model") 415 | } 416 | t, err := s.loader(model) 417 | if err != nil { 418 | return err 419 | } 420 | var result []interface{} 421 | fields, err := t.Fields() 422 | if err != nil { 423 | return err 424 | } 425 | 426 | for _, v := range fields { 427 | result = append(result, reflect.New(v.Type())) 428 | } 429 | err = 
scanner.Scan(result...) 430 | if err != nil { 431 | return err 432 | } 433 | 434 | // we use the actual value now not the address 435 | val = val.Elem() 436 | for k, v := range fields { 437 | f := val.FieldByName(v.Name()) 438 | fVal := reflect.ValueOf(result[k]) 439 | f.Set(fVal.Elem()) 440 | } 441 | return nil 442 | } 443 | 444 | //Query retriews matching rows . This wraps the sql.Query and no further 445 | //no further processing is done. 446 | func (s *SQL) Query(query string, args ...interface{}) (*sql.Rows, error) { 447 | return s.db.Query(query, args...) 448 | } 449 | 450 | //QueryRow QueryRow returnes a single matched row. This wraps sql.QueryRow no 451 | //further processing is done. 452 | func (s *SQL) QueryRow(query string, args ...interface{}) *sql.Row { 453 | return s.db.QueryRow(query, args...) 454 | } 455 | 456 | // Exec executes the query. 457 | func (s *SQL) Exec(query string, args ...interface{}) (sql.Result, error) { 458 | return s.db.Exec(query, args...) 459 | } 460 | 461 | //CurrentDatabase returns the name of the database in which the queries are 462 | //executed. 463 | func (s *SQL) CurrentDatabase() string { 464 | return s.adopter.Database(s) 465 | } 466 | 467 | //Bind executes the query and scans results into value. If there is any error it 468 | //will be returned. 469 | // 470 | // values is a pointer to the golang type into which the resulting query results 471 | // will be assigned. For structs, make sure the strucrs have been registered 472 | // with the Register method. 473 | // 474 | // If you want to assign values from the resulting query you can pass them ass a 475 | // comma separated list of argumens. 476 | // eg db.Bind(&col1,&col2,&col3) 477 | // will assign results from executing the query(only first row) to 478 | // col1,col2,and col3 respectively 479 | // 480 | // value can be a slice of struct , in which case the stuct should be a model 481 | // which has previously been registered with Register method. 
482 | // 483 | // When value is a slice of truct the restult of multiple rows will be assigned 484 | // to the struct and appeded to the slice. So if the result of the query has 10 485 | // rows, the legth of the slice will be 10 and each slice item will be a struct 486 | // containing the row results. 487 | // 488 | // TODO(gernest) Add support for a slice of map[string]interface{} 489 | func (s *SQL) Bind(value interface{}, args ...interface{}) error { 490 | v := reflect.ValueOf(value) 491 | if v.Kind() != reflect.Ptr { 492 | return errors.New("non pointer argument") 493 | } 494 | defer func() { s.isDone = true }() 495 | var scanArgs []interface{} 496 | 497 | // We get the actual value that v points to.`:w 498 | actualVal := v.Elem() 499 | switch actualVal.Kind() { 500 | case reflect.Struct: 501 | t, err := s.loader(value) 502 | if err != nil { 503 | return err 504 | } 505 | fields, err := t.Fields() 506 | for _, v := range fields { 507 | scanArgs = append(scanArgs, reflect.New(v.Type()).Interface()) 508 | } 509 | query, qArgs, err := s.BuildQuery() 510 | if err != nil { 511 | return err 512 | } 513 | var row *sql.Row 514 | switch len(qArgs) { 515 | case 0: 516 | row = s.QueryRow(query) 517 | default: 518 | row = s.QueryRow(query, qArgs...) 519 | } 520 | err = row.Scan(scanArgs...) 521 | if err != nil { 522 | return err 523 | } 524 | for k, v := range scanArgs { 525 | sField, ok := actualVal.Type().FieldByName(fields[k].Name()) 526 | if ok { 527 | aField := actualVal.FieldByName(sField.Name) 528 | aField.Set(reflect.ValueOf(v).Elem()) 529 | } 530 | } 531 | default: 532 | scanArgs = append(scanArgs, value) 533 | if len(args) > 0 { 534 | scanArgs = append(scanArgs, args...) 535 | } 536 | query, qArgs, err := s.BuildQuery() 537 | if err != nil { 538 | return err 539 | } 540 | var row *sql.Row 541 | switch len(qArgs) { 542 | case 0: 543 | row = s.QueryRow(query) 544 | default: 545 | row = s.QueryRow(query, qArgs...) 546 | } 547 | return row.Scan(scanArgs...) 
548 | } 549 | return nil 550 | } 551 | 552 | //Create creates a new record into the database 553 | func (s *SQL) Create(model interface{}) error { 554 | query, err := s.creare(model) 555 | if err != nil { 556 | return err 557 | } 558 | _, err = s.Exec(query) 559 | return err 560 | } 561 | 562 | func (s *SQL) creare(model interface{}) (string, error) { 563 | t, err := s.loader(model) 564 | if err != nil { 565 | return "", err 566 | } 567 | cols, vals, err := createValues(t, model) 568 | if err != nil { 569 | return "", err 570 | } 571 | buf := &bytes.Buffer{} 572 | _, _ = buf.WriteString("INSERT INTO " + t.Name()) 573 | _, _ = buf.WriteString(" (") 574 | for k, v := range cols { 575 | if k == 0 { 576 | _, _ = buf.WriteString(v) 577 | continue 578 | } 579 | _, _ = buf.WriteString(", " + v) 580 | } 581 | _, _ = buf.WriteString(")") 582 | 583 | _, _ = buf.WriteString(" VALUES (") 584 | 585 | for k, v := range vals { 586 | if k == 0 { 587 | _, _ = buf.WriteString(s.quote(v)) 588 | continue 589 | } 590 | _, _ = buf.WriteString(fmt.Sprintf(", %v", s.quote(v))) 591 | } 592 | _, _ = buf.WriteString(");") 593 | return buf.String(), nil 594 | } 595 | 596 | //createValues returns values for creating a new record 597 | func createValues(t Table, v interface{}) (cols []string, vals []interface{}, err error) { 598 | f, err := t.Fields() 599 | if err != nil { 600 | return 601 | } 602 | value := reflect.ValueOf(v) 603 | if value.Kind() == reflect.Ptr { 604 | value = value.Elem() 605 | } 606 | for _, field := range f { 607 | fv := value.FieldByName(field.Name()) 608 | if fv.IsValid() { 609 | zero := reflect.Zero(fv.Type()) 610 | colName := field.ColumnName() 611 | if reflect.DeepEqual(zero.Interface(), fv.Interface()) { 612 | if colName == "created_at" || colName == "updated_at" { 613 | cols = append(cols, colName) 614 | vals = append(vals, time.Now().Format(time.RFC3339)) 615 | } 616 | continue 617 | } 618 | cols = append(cols, colName) 619 | vals = append(vals, fv.Interface()) 620 
| } 621 | } 622 | return 623 | } 624 | 625 | //Update updates a model values into the database 626 | func (s *SQL) Update(model interface{}) error { 627 | query, err := s.update(model) 628 | if err != nil { 629 | return err 630 | } 631 | _, err = s.Exec(query) 632 | return err 633 | } 634 | 635 | func (s *SQL) update(model interface{}) (string, error) { 636 | t, err := s.loader(model) 637 | if err != nil { 638 | return "", err 639 | } 640 | cols, vals, err := Values(t, model) 641 | if err != nil { 642 | return "", err 643 | } 644 | var where string 645 | var up string 646 | for k, v := range cols { 647 | if strings.ToLower(v) == "id" { 648 | where = fmt.Sprintf(" %s=%v", v, s.quote(vals[k])) 649 | continue 650 | } 651 | if up == "" { 652 | up = fmt.Sprintf("%s =%v", v, s.quote(vals[k])) 653 | continue 654 | } 655 | up = up + fmt.Sprintf(",%s =%v", v, s.quote(vals[k])) 656 | } 657 | return fmt.Sprintf("UPDATE %s SET %s WHERE %s", t.Name(), up, where), nil 658 | } 659 | 660 | // quote add single quote to val if val is a string. 
661 | func (s *SQL) quote(val interface{}) string { 662 | typ := reflect.TypeOf(val) 663 | if typ.Kind() == reflect.String { 664 | return fmt.Sprintf("'%v'", val) 665 | } 666 | return fmt.Sprint(val) 667 | } 668 | -------------------------------------------------------------------------------- /sql_test.go: -------------------------------------------------------------------------------- 1 | package orange 2 | 3 | import ( 4 | "bytes" 5 | "os" 6 | "reflect" 7 | "strings" 8 | "testing" 9 | "time" 10 | 11 | _ "github.com/lib/pq" 12 | ) 13 | 14 | var testDB = struct { 15 | ps, mysal, sqlite string 16 | }{} 17 | 18 | func init() { 19 | testDB.ps = os.Getenv("POSTGRES_ORANGE") 20 | if testDB.ps == "" { 21 | //testDB.ps = "postgres:://postgres@localhost/orange_test?sslmode=disable" 22 | testDB.ps = "user=postgres dbname=orange_test sslmode=disable" 23 | } 24 | } 25 | 26 | func TestOpen(t *testing.T) { 27 | _, err := Open("postgres", testDB.ps) 28 | if err != nil { 29 | t.Fatal(err) 30 | } 31 | } 32 | 33 | type golangster struct { 34 | ID int64 35 | Name string 36 | CreatedAt time.Time 37 | UpdatedAt time.Time 38 | } 39 | 40 | func TestSQL_Register(t *testing.T) { 41 | db, err := Open("postgres", testDB.ps) 42 | if err != nil { 43 | t.Fatal(err) 44 | } 45 | err = db.Register(&golangster{}) 46 | if err != nil { 47 | t.Fatal(err) 48 | } 49 | } 50 | 51 | func TestSQL_Automigrate(t *testing.T) { 52 | db, err := Open("postgres", testDB.ps) 53 | if err != nil { 54 | t.Fatal(err) 55 | } 56 | err = db.Register(&golangster{}) 57 | if err != nil { 58 | t.Fatal(err) 59 | } 60 | err = db.Automigrate() 61 | if err != nil { 62 | t.Fatal(err) 63 | } 64 | err = db.DropTable(&golangster{}) 65 | if err != nil { 66 | t.Fatal(err) 67 | } 68 | } 69 | 70 | func TestSQL_CurrentDatabase(t *testing.T) { 71 | db, err := Open("postgres", testDB.ps) 72 | if err != nil { 73 | t.Fatal(err) 74 | } 75 | name := db.CurrentDatabase() 76 | dbName := "orange_test" 77 | if name != dbName { 78 | 
t.Errorf("expected %s got %s", dbName, name) 79 | } 80 | } 81 | 82 | func TestValues(t *testing.T) { 83 | sample := []struct { 84 | id int64 85 | name string 86 | cols []string 87 | vals []interface{} 88 | }{ 89 | {0, "hello", []string{"Name"}, []interface{}{"hello"}}, 90 | } 91 | 92 | model, err := loadTable(&golangster{}) 93 | if err != nil { 94 | t.Fatal(err) 95 | } 96 | for _, v := range sample { 97 | cols, vals, err := Values(model, &golangster{Name: v.name}) 98 | if err != nil { 99 | t.Fatal(err) 100 | } 101 | if !reflect.DeepEqual(v.cols, cols) { 102 | t.Errorf("expected %v to equal %v", cols, v.cols) 103 | } 104 | if !reflect.DeepEqual(v.vals, vals) { 105 | t.Errorf("expected %v to equal %v", vals, v.vals) 106 | } 107 | } 108 | } 109 | 110 | func TestSQL_WHere(t *testing.T) { 111 | db, err := Open("postgres", testDB.ps) 112 | if err != nil { 113 | t.Fatal(err) 114 | } 115 | _ = db.Register(&golangster{}) 116 | 117 | db.Where(&golangster{Name: "hello"}) 118 | query, _, err := db.BuildQuery() 119 | if err != nil { 120 | t.Fatal(err) 121 | } 122 | exp := "WHERE Name='hello';" 123 | if strings.TrimSpace(query) != exp { 124 | t.Errorf("expected %s got %s", exp, query) 125 | } 126 | } 127 | 128 | func TestSQL_Select(t *testing.T) { 129 | db, err := Open("postgres", testDB.ps) 130 | if err != nil { 131 | t.Fatal(err) 132 | } 133 | _ = db.Register(&golangster{}) 134 | db.Select(&golangster{}) 135 | query, _, err := db.BuildQuery() 136 | if err != nil { 137 | t.Fatal(err) 138 | } 139 | expect := "SELECT * FROM golangster;" 140 | if strings.TrimSpace(query) != expect { 141 | t.Errorf("expected %s got %s", expect, query) 142 | } 143 | 144 | // This should work for non pointers too 145 | clone := db.Copy() 146 | clone.Select(golangster{}) 147 | query, _, err = clone.BuildQuery() 148 | if err != nil { 149 | t.Fatal(err) 150 | } 151 | if strings.TrimSpace(query) != expect { 152 | t.Errorf("expected %s got %s", expect, query) 153 | } 154 | 155 | clone = db.Copy() 156 | 
clone.Select("* FROM golangster") 157 | query, _, err = clone.BuildQuery() 158 | if err != nil { 159 | t.Fatal(err) 160 | } 161 | if strings.TrimSpace(query) != expect { 162 | t.Errorf("expected %s got %s", expect, query) 163 | } 164 | 165 | // combine select with where 166 | clone = db.Copy().Where(&golangster{Name: "gernest"}).Select(&golangster{}) 167 | query, _, err = clone.BuildQuery() 168 | if err != nil { 169 | t.Fatal(err) 170 | } 171 | comibeExpect := "SELECT * FROM golangster WHERE Name='gernest';" 172 | if strings.TrimSpace(query) != comibeExpect { 173 | t.Errorf("expected %s got %s", comibeExpect, query) 174 | } 175 | } 176 | 177 | func TestSQL_LoadFunc(t *testing.T) { 178 | buf := &bytes.Buffer{} 179 | db, err := Open("postgres", testDB.ps) 180 | if err != nil { 181 | t.Fatal(err) 182 | } 183 | db.LoadFunc(func(m interface{}) (Table, error) { 184 | tab, err := loadTable(m) 185 | if err != nil { 186 | _, _ = buf.WriteString(err.Error()) 187 | return nil, err 188 | } 189 | _, _ = buf.WriteString(tab.Name()) 190 | return tab, nil 191 | }) 192 | _ = db.Register(&golangster{}) 193 | if buf.String() != "golangster" { 194 | t.Errorf("expect golangster got %s instead", buf) 195 | } 196 | 197 | } 198 | 199 | func TestSQL_Create(t *testing.T) { 200 | db, err := Open("postgres", testDB.ps) 201 | if err != nil { 202 | t.Fatal(err) 203 | } 204 | _ = db.Register(&golangster{}) 205 | _ = db.Automigrate() 206 | 207 | // create an actual entry 208 | err = db.Create(&golangster{Name: "tanzania"}) 209 | if err != nil { 210 | t.Error(err) 211 | } 212 | 213 | } 214 | 215 | func TestSQL_Update(t *testing.T) { 216 | db, err := Open("postgres", testDB.ps) 217 | if err != nil { 218 | t.Fatal(err) 219 | } 220 | defer func() { _ = db.DropTable(&golangster{}) }() 221 | 222 | _ = db.Register(&golangster{}) 223 | query, err := db.update(&golangster{ID: 2, Name: "gernest the golangster"}) 224 | if err != nil { 225 | t.Fatal(err) 226 | } 227 | expect := "UPDATE golangster SET Name 
='gernest the golangster' WHERE ID=2" 228 | if query != expect { 229 | t.Errorf("expected %s got %s", expect, query) 230 | } 231 | 232 | _ = db.Automigrate() 233 | 234 | // create an actual entry 235 | err = db.Create(&golangster{Name: "tanzania"}) 236 | if err != nil { 237 | t.Error(err) 238 | } 239 | err = db.Update(&golangster{ID: 1, Name: "gernest the golangster"}) 240 | if err != nil { 241 | t.Error(err) 242 | } 243 | } 244 | 245 | func TestSQL_Count(t *testing.T) { 246 | db, err := Open("postgres", testDB.ps) 247 | if err != nil { 248 | t.Fatal(err) 249 | } 250 | _ = db.Register(&golangster{}) 251 | db.Select(&golangster{}).Count("*") 252 | query, _, err := db.BuildQuery() 253 | if err != nil { 254 | t.Fatal(err) 255 | } 256 | expect := "SELECT COUNT (*) FROM golangster;" 257 | if strings.TrimSpace(query) != expect { 258 | t.Errorf("expected %s got %s", expect, query) 259 | } 260 | _ = db.Automigrate() 261 | names := []string{"one", "two", "three"} 262 | for _, v := range names { 263 | err = db.Create(&golangster{Name: v}) 264 | if err != nil { 265 | t.Error(err) 266 | } 267 | } 268 | var result int 269 | err = db.Select(&golangster{}).Count("*").Bind(&result) 270 | if err != nil { 271 | t.Error(err) 272 | } 273 | if result != len(names) { 274 | t.Errorf("expected %d got %d", len(names), result) 275 | } 276 | } 277 | 278 | func TestSQL_Find(t *testing.T) { 279 | db, err := Open("postgres", testDB.ps) 280 | if err != nil { 281 | t.Fatal(err) 282 | } 283 | _ = db.Register(&golangster{}) 284 | rst := &golangster{} 285 | err = db.Find(rst, &golangster{ID: 1}) 286 | if err != nil { 287 | t.Error(err) 288 | } 289 | if rst.ID != 1 { 290 | t.Errorf("expected %d got %d", 1, rst.ID) 291 | } 292 | 293 | } 294 | -------------------------------------------------------------------------------- /table.go: -------------------------------------------------------------------------------- 1 | package orange 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "reflect" 7 | "strings" 8 
| "unicode" 9 | ) 10 | 11 | var ( 12 | //ErrNoField is returned when the field is not found 13 | ErrNoField = errors.New("no field found") 14 | 15 | //ErrNoFlag is returned when the flag is not found 16 | ErrNoFlag = errors.New("no flag found") 17 | specialTags = struct { 18 | fieldName, fieldType, relation string 19 | }{ 20 | "name", "type", "relation", 21 | } 22 | ) 23 | 24 | //Table is an interface for an object that can be mapped to a database table. 25 | //This has no one one one correspondance with the actual database table. There 26 | //is no limitation of the implementation of this interface. 27 | // 28 | // Just be aware that, In case you want to use custom implementation make sure 29 | // the loading function is set correctly to your custom table loading function. 30 | // see *SQL.LoadFunc for more details on custom table loading functoons. 31 | type Table interface { 32 | 33 | //Name returns the name of the table. 34 | Name() string 35 | 36 | //Fields returns a collection of fields of the table. They are like an abstract 37 | //representation of database table colums although in some case they might 38 | //not. This means what they are will depend on the implementation details. 39 | Fields() ([]Field, error) 40 | 41 | //Size returns the number of fields present in the table 42 | Size() int 43 | 44 | //Flags is a collection additional information that is tied to the table. They can be 45 | //anything within the scope of your wild imagination. 46 | Flags() ([]Flag, error) 47 | } 48 | 49 | //Field is an interface for a table field. 50 | type Field interface { 51 | Name() string 52 | Type() reflect.Type 53 | Flags() ([]Flag, error) 54 | 55 | //ColumnName is the name that this field is represented in the database table 56 | ColumnName() string 57 | } 58 | 59 | //Flag is an interface for tagging objects. This can hold additional information 60 | //about fields or tables. 
61 | type Flag interface { 62 | Name() string 63 | Key() string 64 | Value() string 65 | } 66 | 67 | type table struct { 68 | name string 69 | fields []*field 70 | tags []*tag 71 | } 72 | 73 | //LoadFunc is an interface for loading tables from models. Models are structs 74 | //that maps to database tables. 75 | type LoadFunc func(model interface{}) (Table, error) 76 | 77 | //loadTable lods the model and returns a Table objec. A model is a golang struct 78 | //whose fields are the database column names. 79 | func loadTable(model interface{}) (Table, error) { 80 | value := reflect.ValueOf(model) 81 | switch value.Kind() { 82 | case reflect.Ptr: 83 | value = value.Elem() 84 | default: 85 | return nil, errors.New("provide a pointer to a model struct") 86 | } 87 | if value.Kind() != reflect.Struct { 88 | return nil, errors.New("modelsshould be structs") 89 | } 90 | t := &table{} 91 | typ := value.Type() 92 | t.name = tabulizeName(typ.Name()) 93 | for k := range make([]struct{}, typ.NumField()) { 94 | fieldTyp := typ.Field(k) 95 | f := &field{} 96 | f.name = fieldTyp.Name 97 | f.typ = fieldTyp.Type 98 | tags := fieldTyp.Tag.Get("sql") 99 | 100 | // do not add ignored fields 101 | if tags == "-" { 102 | continue 103 | } 104 | f.loadTags(tags) 105 | t.fields = append(t.fields, f) 106 | } 107 | return t, nil 108 | } 109 | 110 | // tabulizeName changes name to a good database name. 
This means 111 | // CamelCame will be changed to camel_case 112 | // MIXEDCase will be changed to mixed_case 113 | func tabulizeName(name string) string { 114 | if name == "" { 115 | return "" 116 | } 117 | if strings.ToLower(name) == "id" { 118 | return "id" 119 | } 120 | isFirstLower := false 121 | var capIndex []int 122 | for i, ch := range name { 123 | if i == 0 { 124 | isFirstLower = unicode.IsLower(ch) 125 | } 126 | if unicode.IsUpper(ch) { 127 | capIndex = append(capIndex, i) 128 | } 129 | } 130 | buf := &bytes.Buffer{} 131 | lenCap := len(capIndex) 132 | if lenCap == 0 { 133 | return name 134 | } 135 | i := 0 136 | left := 0 137 | piling := false 138 | END: 139 | for { 140 | switch lenCap { 141 | case 1: 142 | c := capIndex[0] 143 | if c == 0 { 144 | writeSnake(buf, name) 145 | break END 146 | } 147 | writeSnake(buf, name[left:c]) 148 | writeSnake(buf, name[c:]) 149 | break END 150 | default: 151 | if i == lenCap-1 { 152 | c := capIndex[i] 153 | if piling { 154 | writeSnake(buf, name[left:c]) 155 | } 156 | writeSnake(buf, name[c:]) 157 | break END 158 | } 159 | c := capIndex[i] 160 | if i == 0 && isFirstLower { 161 | writeSnake(buf, name[:c]) 162 | } 163 | next := capIndex[i+1] 164 | i++ 165 | if piling && next-c != 1 { 166 | writeSnake(buf, name[left:c]) 167 | writeSnake(buf, name[c:next]) 168 | left = next 169 | piling = false 170 | break 171 | } 172 | if next-c == 1 { 173 | if !piling { 174 | left = c 175 | } 176 | piling = true 177 | break 178 | } 179 | writeSnake(buf, name[left:next]) 180 | left = next 181 | } 182 | } 183 | return buf.String() 184 | } 185 | 186 | //writeSnake writes n in b with a snake case 187 | func writeSnake(b *bytes.Buffer, n string) { 188 | if b.Len() == 0 { 189 | _, _ = b.WriteString(strings.ToLower(n)) 190 | return 191 | } 192 | _, _ = b.WriteString("_") 193 | _, _ = b.WriteString(strings.ToLower(n)) 194 | } 195 | 196 | // Name returns the table name 197 | func (t *table) Name() string { 198 | return t.name 199 | } 200 | 201 | 
// size returns the number of fields present in this table. Note that this does 202 | // not include igored fieds. 203 | func (t *table) Size() int { 204 | return len(t.fields) 205 | } 206 | 207 | // Fields returns the fields of the tablle . 208 | func (t *table) Fields() ([]Field, error) { 209 | if t.fields != nil { 210 | var f []Field 211 | for _, v := range t.fields { 212 | f = append(f, v) 213 | } 214 | return f, nil 215 | } 216 | return nil, ErrNoField 217 | } 218 | 219 | func (t *table) Flags() ([]Flag, error) { 220 | if t.tags != nil { 221 | var f []Flag 222 | for _, v := range t.tags { 223 | f = append(f, v) 224 | } 225 | return f, nil 226 | } 227 | return nil, ErrNoFlag 228 | } 229 | 230 | //field implements the Field interface 231 | type field struct { 232 | name string 233 | typ reflect.Type 234 | tags []*tag 235 | } 236 | 237 | //Name returns the name of the field as the actualname defined in the struct, 238 | //the name specified in the tags will always remain in the tags to avoid 239 | //unnecessary name conversions. 
240 | func (f *field) Name() string { 241 | return f.name 242 | } 243 | 244 | //Type returns the field value's type 245 | func (f *field) Type() reflect.Type { 246 | return f.typ 247 | } 248 | 249 | //Flags returns the tags held by the field 250 | func (f *field) Flags() ([]Flag, error) { 251 | if f.tags != nil { 252 | var t []Flag 253 | for _, v := range f.tags { 254 | t = append(t, v) 255 | } 256 | return t, nil 257 | } 258 | return nil, ErrNoFlag 259 | } 260 | 261 | func (f *field) SetValue(v interface{}) error { 262 | return nil 263 | } 264 | 265 | func (f *field) ColumnName() string { 266 | for _, v := range f.tags { 267 | if v.name == "field_name" { 268 | return v.value 269 | } 270 | } 271 | return tabulizeName(f.name) 272 | } 273 | 274 | func (f *field) loadTags(sqlTags string) { 275 | if sqlTags == "" { 276 | return 277 | } 278 | chunks := strings.Split(sqlTags, ",") 279 | if len(chunks) > 0 { 280 | for _, v := range chunks { 281 | f.tags = append(f.tags, &tag{name: "sql", value: v}) 282 | } 283 | } 284 | } 285 | 286 | type tag struct { 287 | name, key, value string 288 | } 289 | 290 | func (t *tag) Name() string { 291 | return t.name 292 | } 293 | 294 | func (t *tag) Key() string { 295 | return t.key 296 | } 297 | 298 | func (t *tag) Value() string { 299 | return t.value 300 | } 301 | -------------------------------------------------------------------------------- /table_test.go: -------------------------------------------------------------------------------- 1 | package orange 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | ) 7 | 8 | func TestTabulizeName(t *testing.T) { 9 | sample := []struct { 10 | name, expect string 11 | }{ 12 | {"gernest", "gernest"}, 13 | {"Gernest", "gernest"}, 14 | {"OrangeJuice", "orange_juice"}, 15 | {"orangeJuice", "orange_juice"}, 16 | {"OrangeJuiceIsSweet", "orange_juice_is_sweet"}, 17 | {"HTMLOrangeJuice", "html_orange_juice"}, 18 | {"normalPILLINGStuffs", "normal_pilling_stuffs"}, 19 | } 20 | for _, v := range sample { 21 | n 
:= tabulizeName(v.name) 22 | if n != v.expect { 23 | t.Errorf("expected %s got %s", v.expect, n) 24 | } 25 | } 26 | } 27 | 28 | type simpleModel struct { 29 | ID int64 `sql:"id"` 30 | BOdy string 31 | CreatedAt time.Time 32 | UpdatedAT time.Time 33 | } 34 | 35 | func TestLoadTable(t *testing.T) { 36 | tb, err := loadTable(&simpleModel{}) 37 | if err != nil { 38 | t.Fatal(err) 39 | } 40 | name := "simple_model" 41 | if tb.Name() != name { 42 | t.Errorf("expected %s got %s", name, tb.Name()) 43 | } 44 | } 45 | --------------------------------------------------------------------------------