├── .gitignore ├── .travis.yml ├── field.go ├── example_test.go ├── LICENSE ├── util.go ├── util_test.go ├── field_test.go ├── type.go ├── log.go ├── type_test.go ├── example_select_test.go ├── dialect.go ├── README.md ├── genmai.go └── dialect_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | *.test 2 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | go: 3 | - 1.7 4 | - 1.8 5 | install: 6 | - go get -v github.com/mattn/go-sqlite3 7 | - go get -v github.com/go-sql-driver/mysql 8 | - go get -v github.com/lib/pq 9 | - go get -v github.com/naoina/genmai 10 | env: 11 | - DB=sqlite3 12 | - DB=mysql 13 | - DB=postgres 14 | before_script: 15 | - sh -c "if [ '$DB' = 'postgres' ]; then psql -c 'DROP DATABASE IF EXISTS genmai_test;' -U postgres; fi" 16 | - sh -c "if [ '$DB' = 'postgres' ]; then psql -c 'CREATE DATABASE genmai_test;' -U postgres; fi" 17 | - sh -c "if [ '$DB' = 'mysql' ]; then mysql -e 'CREATE DATABASE IF NOT EXISTS genmai_test;'; fi" 18 | script: 19 | - go test ./... 20 | -------------------------------------------------------------------------------- /field.go: -------------------------------------------------------------------------------- 1 | package genmai 2 | 3 | import "time" 4 | 5 | // TimeStamp is fields for timestamps that commonly used. 6 | type TimeStamp struct { 7 | // Time of creation. This field will be set automatically by BeforeInsert. 8 | CreatedAt time.Time `json:"created_at"` 9 | 10 | // Time of update. This field will be set by BeforeInsert or BeforeUpdate. 11 | UpdatedAt time.Time `json:"updated_at"` 12 | } 13 | 14 | // BeforeInsert sets current time to CreatedAt and UpdatedAt field. 15 | // It always returns nil. 
16 | func (ts *TimeStamp) BeforeInsert() error { 17 | n := now() 18 | ts.CreatedAt = n 19 | ts.UpdatedAt = n 20 | return nil 21 | } 22 | 23 | // BeforeUpdate sets current time to UpdatedAt field. 24 | // It always returns nil. 25 | func (ts *TimeStamp) BeforeUpdate() error { 26 | ts.UpdatedAt = now() 27 | return nil 28 | } 29 | -------------------------------------------------------------------------------- /example_test.go: -------------------------------------------------------------------------------- 1 | package genmai_test 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | 7 | _ "github.com/mattn/go-sqlite3" 8 | "github.com/naoina/genmai" 9 | ) 10 | 11 | type TestModel struct { 12 | Id int64 13 | Name string 14 | Addr string 15 | } 16 | 17 | func Example() { 18 | db, err := genmai.New(&genmai.SQLite3Dialect{}, ":memory:") 19 | if err != nil { 20 | log.Fatal(err) 21 | } 22 | defer db.Close() 23 | for _, query := range []string{ 24 | `CREATE TABLE test_model ( 25 | id INTEGER NOT NULL PRIMARY KEY, 26 | name TEXT NOT NULL, 27 | addr TEXT NOT NULL 28 | )`, 29 | `INSERT INTO test_model VALUES (1, 'test1', 'addr1')`, 30 | `INSERT INTO test_model VALUES (2, 'test2', 'addr2')`, 31 | `INSERT INTO test_model VALUES (3, 'test3', 'addr3')`, 32 | } { 33 | if _, err := db.DB().Exec(query); err != nil { 34 | log.Fatal(err) 35 | } 36 | } 37 | var results []TestModel 38 | // SELECT * FROM "test_model"; 39 | if err := db.Select(&results); err != nil { 40 | log.Fatal(err) 41 | } 42 | fmt.Println(results) 43 | // Output: [{1 test1 addr1} {2 test2 addr2} {3 test3 addr3}] 44 | } 45 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2014 Naoya Inada 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, 
including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /util.go: -------------------------------------------------------------------------------- 1 | package genmai 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "time" 7 | "unicode" 8 | ) 9 | 10 | var now = time.Now // for test. 11 | 12 | // ToInterfaceSlice convert to []interface{} from []string. 13 | func ToInterfaceSlice(slice []string) []interface{} { 14 | result := make([]interface{}, len(slice)) 15 | for i, v := range slice { 16 | result[i] = v 17 | } 18 | return result 19 | } 20 | 21 | // columnName returns the column name that added the table name with quoted if needed. 22 | func ColumnName(d Dialect, tname, cname string) string { 23 | if cname != "*" { 24 | cname = d.Quote(cname) 25 | } 26 | if tname == "" { 27 | return cname 28 | } 29 | return fmt.Sprintf("%s.%s", d.Quote(tname), cname) 30 | } 31 | 32 | // IsUnexportedField returns whether the field is unexported. 33 | // This function is to avoid the bug in versions older than Go1.3. 
34 | // See following links: 35 | // https://code.google.com/p/go/issues/detail?id=7247 36 | // http://golang.org/ref/spec#Exported_identifiers 37 | func IsUnexportedField(field reflect.StructField) bool { 38 | return !(field.PkgPath == "" && unicode.IsUpper(rune(field.Name[0]))) 39 | } 40 | 41 | func flatten(args []interface{}) []interface{} { 42 | result := make([]interface{}, 0, len(args)) 43 | for _, v := range args { 44 | switch rv := reflect.ValueOf(v); rv.Kind() { 45 | case reflect.Slice, reflect.Array: 46 | for i := 0; i < rv.Len(); i++ { 47 | result = append(result, rv.Index(i).Interface()) 48 | } 49 | default: 50 | result = append(result, v) 51 | } 52 | } 53 | return result 54 | } 55 | -------------------------------------------------------------------------------- /util_test.go: -------------------------------------------------------------------------------- 1 | package genmai 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | func Test_ToInterfaceSlice(t *testing.T) { 9 | actual := ToInterfaceSlice([]string{"1", "hoge", "foo"}) 10 | expected := []interface{}{"1", "hoge", "foo"} 11 | if !reflect.DeepEqual(actual, expected) { 12 | t.Errorf("Expect %[1]q(type %[1]T), but %[2]q(type %[2]T)", expected, actual) 13 | } 14 | } 15 | 16 | func TestColumnName(t *testing.T) { 17 | for _, v := range []struct { 18 | tableName string 19 | columnName string 20 | expected string 21 | }{ 22 | {`test_table`, `*`, `"test_table".*`}, 23 | {`test_table`, `test_column`, `"test_table"."test_column"`}, 24 | {``, `test_column`, `"test_column"`}, 25 | {``, `*`, `*`}, 26 | } { 27 | actual := ColumnName(&SQLite3Dialect{}, v.tableName, v.columnName) 28 | expected := v.expected 29 | if !reflect.DeepEqual(actual, expected) { 30 | t.Errorf("Expect %q, but %q", expected, actual) 31 | } 32 | } 33 | } 34 | 35 | func TestIsUnexportedField(t *testing.T) { 36 | // test for bug case less than Go1.3. 
37 | func() { 38 | type b struct{} 39 | type C struct { 40 | b 41 | } 42 | v := reflect.TypeOf(C{}).Field(0) 43 | actual := IsUnexportedField(v) 44 | expected := true 45 | if !reflect.DeepEqual(actual, expected) { 46 | t.Errorf("IsUnexportedField(%q) => %v, want %v", v, actual, expected) 47 | } 48 | }() 49 | 50 | // test for correct case. 51 | func() { 52 | type B struct{} 53 | type C struct { 54 | B 55 | } 56 | v := reflect.TypeOf(C{}).Field(0) 57 | actual := IsUnexportedField(v) 58 | expected := false 59 | if !reflect.DeepEqual(actual, expected) { 60 | t.Errorf("IsUnexportedField(%q) => %v, want %v", v, actual, expected) 61 | } 62 | }() 63 | } 64 | -------------------------------------------------------------------------------- /field_test.go: -------------------------------------------------------------------------------- 1 | package genmai 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func TestTimeStamp_BeforeInsert(t *testing.T) { 10 | createdAt, err := time.Parse("2006-01-02 15:04:05", "2014-02-24 22:36:56") 11 | if err != nil { 12 | t.Fatal(err) 13 | } 14 | updatedAt, err := time.Parse("2006-01-02 15:04:05", "2014-02-24 23:51:26") 15 | if err != nil { 16 | t.Fatal(err) 17 | } 18 | n, err := time.Parse("2006-01-02 15:04:05", "2000-02-02 16:57:38") 19 | if err != nil { 20 | t.Fatal(err) 21 | } 22 | baknow := now 23 | now = func() time.Time { 24 | return n 25 | } 26 | defer func() { 27 | now = baknow 28 | }() 29 | tm := &TimeStamp{ 30 | CreatedAt: createdAt, 31 | UpdatedAt: updatedAt, 32 | } 33 | if err := tm.BeforeInsert(); err != nil { 34 | t.Fatal(err) 35 | } 36 | actual := tm.CreatedAt 37 | expected := n 38 | if !reflect.DeepEqual(actual, expected) { 39 | t.Errorf("Expect %q, but %q", expected, actual) 40 | } 41 | actual = tm.UpdatedAt 42 | expected = n 43 | if !reflect.DeepEqual(actual, expected) { 44 | t.Errorf("Expect %q, but %q", expected, actual) 45 | } 46 | } 47 | 48 | func TestTimeStamp_BeforeUpdate(t *testing.T) { 49 | 
createdAt, err := time.Parse("2006-01-02 15:04:05", "2014-02-24 22:36:56") 50 | if err != nil { 51 | t.Fatal(err) 52 | } 53 | updatedAt, err := time.Parse("2006-01-02 15:04:05", "2014-02-24 23:51:26") 54 | if err != nil { 55 | t.Fatal(err) 56 | } 57 | n, err := time.Parse("2006-01-02 15:04:05", "2000-02-02 16:57:38") 58 | if err != nil { 59 | t.Fatal(err) 60 | } 61 | baknow := now 62 | now = func() time.Time { 63 | return n 64 | } 65 | defer func() { 66 | now = baknow 67 | }() 68 | tm := &TimeStamp{ 69 | CreatedAt: createdAt, 70 | UpdatedAt: updatedAt, 71 | } 72 | if err := tm.BeforeUpdate(); err != nil { 73 | t.Fatal(err) 74 | } 75 | actual := tm.CreatedAt 76 | expected := createdAt 77 | if !reflect.DeepEqual(actual, expected) { 78 | t.Errorf("Expect %q, but %q", expected, actual) 79 | } 80 | actual = tm.UpdatedAt 81 | expected = n 82 | if !reflect.DeepEqual(actual, expected) { 83 | t.Errorf("Expect %q, but %q", expected, actual) 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /type.go: -------------------------------------------------------------------------------- 1 | package genmai 2 | 3 | import ( 4 | "database/sql/driver" 5 | "fmt" 6 | "math/big" 7 | ) 8 | 9 | type ( 10 | Float32 float32 11 | Float64 float64 12 | ) 13 | 14 | // Rat is an wrapper of the Rat of math/big. 15 | // However, Rat implements the sql Scanner interface. 16 | type Rat struct { 17 | *big.Rat 18 | } 19 | 20 | // NewRat returns a new Rat. 21 | // This is the similar to NewRat of math/big. 22 | func NewRat(a, b int64) *Rat { 23 | return &Rat{ 24 | Rat: big.NewRat(a, b), 25 | } 26 | } 27 | 28 | // Scan implements the database/sql Scanner interface. 
29 | func (rat *Rat) Scan(src interface{}) (err error) { 30 | rat.Rat = new(big.Rat) 31 | switch t := src.(type) { 32 | case string: 33 | _, err = fmt.Sscan(t, rat.Rat) 34 | case []byte: 35 | _, err = fmt.Sscan(string(t), rat.Rat) 36 | case float64: 37 | rat.Rat.SetFloat64(t) 38 | default: 39 | _, err = fmt.Sscan(fmt.Sprint(t), rat.Rat) 40 | } 41 | return err 42 | } 43 | 44 | // Value implements the database/sql/driver Valuer interface. 45 | func (rat Rat) Value() (driver.Value, error) { 46 | return rat.FloatString(decimalScale), nil 47 | } 48 | 49 | // Scan implements the database/sql Scanner interface. 50 | func (f *Float32) Scan(src interface{}) (err error) { 51 | switch t := src.(type) { 52 | case string: 53 | _, err = fmt.Sscan(t, f) 54 | case []byte: 55 | _, err = fmt.Sscan(string(t), f) 56 | case float64: 57 | *f = Float32(t) 58 | case int64: 59 | *f = Float32(t) 60 | default: 61 | _, err = fmt.Sscan(fmt.Sprint(t), f) 62 | } 63 | return err 64 | } 65 | 66 | // Value implements the database/sql/driver Valuer interface. 67 | func (f Float32) Value() (driver.Value, error) { 68 | return float64(f), nil 69 | } 70 | 71 | // Scan implements the database/sql Scanner interface. 72 | func (f *Float64) Scan(src interface{}) (err error) { 73 | switch t := src.(type) { 74 | case string: 75 | _, err = fmt.Sscan(t, f) 76 | case []byte: 77 | _, err = fmt.Sscan(string(t), f) 78 | case float64: 79 | *f = Float64(t) 80 | case int64: 81 | *f = Float64(t) 82 | default: 83 | _, err = fmt.Sscan(fmt.Sprint(t), f) 84 | } 85 | return err 86 | } 87 | 88 | // Value implements the database/sql/driver Valuer interface. 
89 | func (f Float64) Value() (driver.Value, error) { 90 | return float64(f), nil 91 | } 92 | -------------------------------------------------------------------------------- /log.go: -------------------------------------------------------------------------------- 1 | package genmai 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "strings" 8 | "sync" 9 | "text/template" 10 | "time" 11 | ) 12 | 13 | const defaultLoggingFormat = `[{{.time.Format "2006-01-02 15:04:05"}}] [{{.duration}}] {{.query}}` 14 | 15 | var ( 16 | defaultLoggerTemplate = template.Must(template.New("genmai").Parse(defaultLoggingFormat)) 17 | defaultLogger = &nullLogger{} 18 | ) 19 | 20 | // logger is the interface that query logger. 21 | type logger interface { 22 | // Print outputs query log. 23 | Print(start time.Time, query string, args ...interface{}) error 24 | 25 | // SetFormat sets the format for logging. 26 | SetFormat(format string) error 27 | } 28 | 29 | // templateLogger is a logger that Go's template to be used as a format. 30 | // It implements the logger interface. 31 | type templateLogger struct { 32 | w io.Writer 33 | t *template.Template 34 | m sync.Mutex 35 | } 36 | 37 | // SetFormat sets the format for logging. 38 | func (l *templateLogger) SetFormat(format string) error { 39 | l.m.Lock() 40 | defer l.m.Unlock() 41 | t, err := template.New("genmai").Parse(format) 42 | if err != nil { 43 | return err 44 | } 45 | l.t = t 46 | return nil 47 | } 48 | 49 | // Print outputs query log using format template. 50 | // All arguments will be used to formatting. 
51 | func (l *templateLogger) Print(start time.Time, query string, args ...interface{}) error { 52 | if len(args) > 0 { 53 | values := make([]string, len(args)) 54 | for i, arg := range args { 55 | values[i] = fmt.Sprintf("%#v", arg) 56 | } 57 | query = fmt.Sprintf("%v; [%v]", query, strings.Join(values, ", ")) 58 | } else { 59 | query = fmt.Sprintf("%s;", query) 60 | } 61 | data := map[string]interface{}{ 62 | "time": start, 63 | "duration": fmt.Sprintf("%.2fms", now().Sub(start).Seconds()*float64(time.Microsecond)), 64 | "query": query, 65 | } 66 | var buf bytes.Buffer 67 | if err := l.t.Execute(&buf, data); err != nil { 68 | return err 69 | } 70 | l.m.Lock() 71 | defer l.m.Unlock() 72 | if _, err := fmt.Fprintln(l.w, strings.TrimSuffix(buf.String(), "\n")); err != nil { 73 | return err 74 | } 75 | return nil 76 | } 77 | 78 | // nullLogger is a null logger. 79 | // It implements the logger interface. 80 | type nullLogger struct{} 81 | 82 | // SetFormat is a dummy method. 83 | func (l *nullLogger) SetFormat(format string) error { 84 | return nil 85 | } 86 | 87 | // Print is a dummy method. 
88 | func (l *nullLogger) Print(start time.Time, query string, args ...interface{}) error { 89 | return nil 90 | } 91 | -------------------------------------------------------------------------------- /type_test.go: -------------------------------------------------------------------------------- 1 | package genmai 2 | 3 | import ( 4 | "database/sql" 5 | "database/sql/driver" 6 | "math/big" 7 | "reflect" 8 | "testing" 9 | 10 | _ "github.com/mattn/go-sqlite3" 11 | ) 12 | 13 | func TestNewRat(t *testing.T) { 14 | actual := NewRat(1, 3) 15 | expected := &Rat{Rat: big.NewRat(1, 3)} 16 | if !reflect.DeepEqual(actual, expected) { 17 | t.Errorf("Expect %q, but %q", expected, actual) 18 | } 19 | } 20 | 21 | func TestRat_Scan(t *testing.T) { 22 | db, err := sql.Open("sqlite3", ":memory:") 23 | if err != nil { 24 | t.Fatal(err) 25 | } 26 | for _, query := range []string{ 27 | `CREATE TABLE test_table ( 28 | id integer, 29 | rstr numeric, 30 | rreal real 31 | );`, 32 | `INSERT INTO test_table (id, rstr, rreal) VALUES (1, '0.3', '0.4')`, 33 | } { 34 | if _, err := db.Exec(query); err != nil { 35 | t.Fatal(err) 36 | } 37 | } 38 | rstr := new(Rat) 39 | rreal := new(Rat) 40 | row := db.QueryRow(`SELECT rstr, rreal FROM test_table`) 41 | if err := row.Scan(rstr, rreal); err != nil { 42 | t.Fatal(err) 43 | } 44 | for _, v := range []struct { 45 | r *Rat 46 | float float64 47 | }{{rstr, 0.3}, {rreal, 0.4}} { 48 | actual := v.r 49 | expected := &Rat{Rat: new(big.Rat).SetFloat64(v.float)} 50 | if !reflect.DeepEqual(actual, expected) { 51 | t.Errorf("%v expects %q, but %q", v.float, expected, actual) 52 | } 53 | } 54 | } 55 | 56 | func TestRat_Value(t *testing.T) { 57 | db, err := sql.Open("sqlite3", ":memory:") 58 | if err != nil { 59 | t.Fatal(err) 60 | } 61 | for _, query := range []string{ 62 | `CREATE TABLE test_table ( 63 | id integer, 64 | r numeric 65 | );`, 66 | } { 67 | if _, err := db.Exec(query); err != nil { 68 | t.Fatal(err) 69 | } 70 | } 71 | r := &Rat{Rat: big.NewRat(3, 
10)} 72 | if _, err := db.Exec(`INSERT INTO test_table (id, r) VALUES (1, ?);`, r); err != nil { 73 | t.Fatal(err) 74 | } 75 | row := db.QueryRow(`SELECT r FROM test_table`) 76 | var s string 77 | if err := row.Scan(&s); err != nil { 78 | t.Fatal(err) 79 | } 80 | actual := s 81 | expected := "0.3" 82 | if !reflect.DeepEqual(actual, expected) { 83 | t.Errorf("Expect %q, but %q", expected, actual) 84 | } 85 | } 86 | 87 | func TestFloat64_Scan(t *testing.T) { 88 | type testcase struct { 89 | value interface{} 90 | expect Float64 91 | } 92 | testcases := []testcase{ 93 | {"1.5", Float64(1.5)}, 94 | {[]byte("2.8"), Float64(2.8)}, 95 | {float64(10.5), Float64(10.5)}, 96 | {int64(10), Float64(10.0)}, 97 | {float32(15.5), Float64(15.5)}, // for "default" case, it's not normal. 98 | } 99 | for _, c := range testcases { 100 | var f Float64 101 | err := f.Scan(c.value) 102 | if err != nil { 103 | t.Errorf("Unexpected error") 104 | } 105 | if f != c.expect { 106 | t.Errorf("Expect %f, but %f", c.expect, f) 107 | } 108 | } 109 | } 110 | 111 | func TestFloat64_Value(t *testing.T) { 112 | expect, val := Float64(10.5), driver.Value(10.5) 113 | var f Float64 114 | f.Scan(val) 115 | if f != expect { 116 | t.Errorf("Expect %f, but %f", expect, f) 117 | } 118 | } 119 | 120 | func TestFloat32_Scan(t *testing.T) { 121 | type testcase struct { 122 | value interface{} 123 | expect Float32 124 | } 125 | testcases := []testcase{ 126 | {"1.5", Float32(1.5)}, 127 | {[]byte("2.8"), Float32(2.8)}, 128 | {float64(10.5), Float32(10.5)}, 129 | {int64(10), Float32(10.0)}, 130 | {float32(15.5), Float32(15.5)}, // for "default" case, it's not normal. 
131 | } 132 | for _, c := range testcases { 133 | var f Float32 134 | err := f.Scan(c.value) 135 | if err != nil { 136 | t.Errorf("Unexpected error") 137 | } 138 | if f != c.expect { 139 | t.Errorf("Expect %f, but %f", c.expect, f) 140 | } 141 | } 142 | } 143 | 144 | func TestFloat32_Value(t *testing.T) { 145 | expect, val := Float32(10.5), driver.Value(10.5) 146 | var f Float32 147 | f.Scan(val) 148 | if f != expect { 149 | t.Errorf("Expect %f, but %f", expect, f) 150 | } 151 | } 152 | -------------------------------------------------------------------------------- /example_select_test.go: -------------------------------------------------------------------------------- 1 | package genmai_test 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | 7 | "github.com/naoina/genmai" 8 | ) 9 | 10 | type M2 struct { 11 | Id int64 12 | Body string 13 | } 14 | 15 | func ExampleDB_Select_all() { 16 | db, err := genmai.New(&genmai.SQLite3Dialect{}, ":memory:") 17 | if err != nil { 18 | log.Fatal(err) 19 | } 20 | var results []TestModel 21 | // SELECT "test_model".* FROM "test_model"; 22 | if err := db.Select(&results); err != nil { 23 | log.Fatal(err) 24 | } 25 | fmt.Println(results) 26 | } 27 | 28 | func ExampleDB_Select_where() { 29 | db, err := genmai.New(&genmai.SQLite3Dialect{}, ":memory:") 30 | if err != nil { 31 | log.Fatal(err) 32 | } 33 | var results []TestModel 34 | // SELECT "test_model".* FROM "test_model" WHERE "id" = 1; 35 | if err := db.Select(&results, db.Where("id", "=", 1)); err != nil { 36 | log.Fatal(err) 37 | } 38 | fmt.Println(results) 39 | } 40 | 41 | func ExampleDB_Select_whereAnd() { 42 | db, err := genmai.New(&genmai.SQLite3Dialect{}, ":memory:") 43 | if err != nil { 44 | log.Fatal(err) 45 | } 46 | var results []TestModel 47 | // SELECT "test_model".* FROM "test_model" WHERE "id" = 1 AND "name" = "alice"; 48 | if err := db.Select(&results, db.Where("id", "=", 1).And("name", "=", "alice")); err != nil { 49 | log.Fatal(err) 50 | } 51 | fmt.Println(results) 52 | } 53 
| 54 | func ExampleDB_Select_whereNested() { 55 | db, err := genmai.New(&genmai.SQLite3Dialect{}, ":memory:") 56 | if err != nil { 57 | log.Fatal(err) 58 | } 59 | var results []TestModel 60 | // SELECT "test_model".* FROM "test_model" WHERE "id" = 1 OR ("name" = "alice" AND "addr" != "Tokyo"); 61 | if err := db.Select(&results, db.Where("id", "=", 1).Or(db.Where("name", "=", "alice").And("addr", "!=", "Tokyo"))); err != nil { 62 | log.Fatal(err) 63 | } 64 | fmt.Println(results) 65 | } 66 | 67 | func ExampleDB_Select_in() { 68 | db, err := genmai.New(&genmai.SQLite3Dialect{}, ":memory:") 69 | if err != nil { 70 | log.Fatal(err) 71 | } 72 | var results []TestModel 73 | // SELECT "test_model".* FROM "test_model" WHERE "id" IN (1, 3, 5); 74 | if err := db.Select(&results, db.Where("id").In(1, 3, 5)); err != nil { 75 | log.Fatal(err) 76 | } 77 | fmt.Println(results) 78 | } 79 | 80 | func ExampleDB_Select_inWithSlice() { 81 | db, err := genmai.New(&genmai.SQLite3Dialect{}, ":memory:") 82 | if err != nil { 83 | log.Fatal(err) 84 | } 85 | var results []TestModel 86 | values := []int64{1, 3, 5} 87 | // SELECT "test_model".* FROM "test_model" WHERE "id" IN (1, 3, 5); 88 | if err := db.Select(&results, db.Where("id").In(values)); err != nil { 89 | log.Fatal(err) 90 | } 91 | fmt.Println(results) 92 | } 93 | 94 | func ExampleDB_Select_like() { 95 | db, err := genmai.New(&genmai.SQLite3Dialect{}, ":memory:") 96 | if err != nil { 97 | log.Fatal(err) 98 | } 99 | var results []TestModel 100 | // SELECT "test_model".* FROM "test_model" WHERE "name" LIKE "alice%"; 101 | if err := db.Select(&results, db.Where("name").Like("alice%")); err != nil { 102 | log.Fatal(err) 103 | } 104 | fmt.Println(results) 105 | } 106 | 107 | func ExampleDB_Select_between() { 108 | db, err := genmai.New(&genmai.SQLite3Dialect{}, ":memory:") 109 | if err != nil { 110 | log.Fatal(err) 111 | } 112 | var results []TestModel 113 | // SELECT "test_model".* FROM "test_model" WHERE "id" BETWEEN 3 AND 5; 114 | if 
err := db.Select(&results, db.Where("id").Between(3, 5)); err != nil { 115 | log.Fatal(err) 116 | } 117 | fmt.Println(results) 118 | } 119 | 120 | func ExampleDB_Select_orderBy() { 121 | db, err := genmai.New(&genmai.SQLite3Dialect{}, ":memory:") 122 | if err != nil { 123 | log.Fatal(err) 124 | } 125 | var results []TestModel 126 | // SELECT "test_model".* FROM "test_model" ORDER BY "name" DESC; 127 | if err := db.Select(&results, db.OrderBy("name", genmai.DESC)); err != nil { 128 | log.Fatal(err) 129 | } 130 | fmt.Println(results) 131 | } 132 | 133 | func ExampleDB_Select_orderByWithSpecificTable() { 134 | db, err := genmai.New(&genmai.SQLite3Dialect{}, ":memory:") 135 | if err != nil { 136 | log.Fatal(err) 137 | } 138 | var results []TestModel 139 | // SELECT "test_model".* FROM "test_model" ORDER BY "test_model"."name" DESC; 140 | if err := db.Select(&results, db.OrderBy(TestModel{}, "name", genmai.DESC)); err != nil { 141 | log.Fatal(err) 142 | } 143 | fmt.Println(results) 144 | } 145 | 146 | func ExampleDB_Select_orderByMultiple() { 147 | db, err := genmai.New(&genmai.SQLite3Dialect{}, ":memory:") 148 | if err != nil { 149 | log.Fatal(err) 150 | } 151 | var results []TestModel 152 | // SELECT "test_model".* FROM "test_model" ORDER BY "name" DESC, "addr" DESC; 153 | if err := db.Select(&results, db.OrderBy("name", genmai.DESC, "addr", genmai.DESC)); err != nil { 154 | log.Fatal(err) 155 | } 156 | fmt.Println(results) 157 | } 158 | 159 | func ExampleDB_Select_limit() { 160 | db, err := genmai.New(&genmai.SQLite3Dialect{}, ":memory:") 161 | if err != nil { 162 | log.Fatal(err) 163 | } 164 | var results []TestModel 165 | // SELECT "test_model".* FROM "test_model" LIMIT 3; 166 | if err := db.Select(&results, db.Limit(3)); err != nil { 167 | log.Fatal(err) 168 | } 169 | fmt.Println(results) 170 | } 171 | 172 | func ExampleDB_Select_offset() { 173 | db, err := genmai.New(&genmai.SQLite3Dialect{}, ":memory:") 174 | if err != nil { 175 | log.Fatal(err) 176 | } 177 | 
var results []TestModel 178 | // SELECT "test_model".* FROM "test_model" OFFSET 10; 179 | if err := db.Select(&results, db.Offset(10)); err != nil { 180 | log.Fatal(err) 181 | } 182 | fmt.Println(results) 183 | } 184 | 185 | func ExampleDB_Select_distinct() { 186 | db, err := genmai.New(&genmai.SQLite3Dialect{}, ":memory:") 187 | if err != nil { 188 | log.Fatal(err) 189 | } 190 | var results []TestModel 191 | // SELECT DISTINCT "test_model"."name" FROM "test_model"; 192 | if err := db.Select(&results, db.Distinct("name")); err != nil { 193 | log.Fatal(err) 194 | } 195 | fmt.Println(results) 196 | } 197 | 198 | func ExampleDB_Select_count() { 199 | db, err := genmai.New(&genmai.SQLite3Dialect{}, ":memory:") 200 | if err != nil { 201 | log.Fatal(err) 202 | } 203 | var result int64 204 | // SELECT COUNT(*) FROM "test_model"; 205 | if err := db.Select(&result, db.Count(), db.From(TestModel{})); err != nil { 206 | log.Fatal(err) 207 | } 208 | fmt.Println(result) 209 | } 210 | 211 | func ExampleDB_Select_countDistinct() { 212 | db, err := genmai.New(&genmai.SQLite3Dialect{}, ":memory:") 213 | if err != nil { 214 | log.Fatal(err) 215 | } 216 | var result int64 217 | // SELECT COUNT(DISTINCT "test_model"."name") FROM "test_model"; 218 | if err := db.Select(&result, db.Count(db.Distinct("name")), db.From(TestModel{})); err != nil { 219 | log.Fatal(err) 220 | } 221 | fmt.Println(result) 222 | } 223 | 224 | func ExampleDB_Select_columns() { 225 | db, err := genmai.New(&genmai.SQLite3Dialect{}, ":memory:") 226 | if err != nil { 227 | log.Fatal(err) 228 | } 229 | var results []TestModel 230 | // SELECT "test_model"."id", "test_model"."name" FROM "test_model"; 231 | if err := db.Select(&results, []string{"id", "name"}); err != nil { 232 | log.Fatal(err) 233 | } 234 | fmt.Println(results) 235 | } 236 | 237 | func ExampleDB_Select_complex() { 238 | db, err := genmai.New(&genmai.SQLite3Dialect{}, ":memory:") 239 | if err != nil { 240 | log.Fatal(err) 241 | } 242 | var results 
[]TestModel 243 | // SELECT "test_model"."name" FROM "test_model" 244 | // WHERE "name" LIKE "%alice%" OR ("id" > 100 AND "id" < 200) OR ("id" BETWEEN 700 AND 1000) 245 | // ORDER BY "id" ASC LIMIT 2 OFFSET 5 246 | if err := db.Select(&results, "name", db.Where("name"). 247 | Like("%alice%"). 248 | Or(db.Where("id", ">", 100).And("id", "<", 200)). 249 | Or(db.Where("id").Between(700, 1000)). 250 | Limit(2).Offset(5).OrderBy("id", genmai.ASC), 251 | ); err != nil { 252 | log.Fatal(err) 253 | } 254 | fmt.Println(results) 255 | } 256 | 257 | func ExampleDB_Select_join() { 258 | db, err := genmai.New(&genmai.SQLite3Dialect{}, ":memory:") 259 | if err != nil { 260 | log.Fatal(err) 261 | } 262 | type M2 struct { 263 | Id int64 264 | Body string 265 | } 266 | var results []TestModel 267 | // SELECT "test_model".* FROM "test_model" JOIN "m2" ON "test_model"."id" = "m2"."id"; 268 | if err := db.Select(&results, "name", db.Join(&M2{}).On("id")); err != nil { 269 | log.Fatal(err) 270 | } 271 | fmt.Println(results) 272 | } 273 | 274 | func ExampleDB_Select_joinWithSpecificTable() { 275 | db, err := genmai.New(&genmai.SQLite3Dialect{}, ":memory:") 276 | if err != nil { 277 | log.Fatal(err) 278 | } 279 | type M2 struct { 280 | Id int64 281 | Body string 282 | } 283 | type M3 struct { 284 | TestModelId int64 285 | M2Id int64 286 | } 287 | var results []TestModel 288 | // SELECT "test_model".* FROM "test_model" JOIN "m3" ON "test_model"."id" = "m3"."test_model_id" JOIN "m2" ON "m3"."m2_id" = "m2"."id"; 289 | if err := db.Select(&results, "name", db.Join(&M3{}).On("id", "=", "test_model_id"), db.Join(&M2{}).On(&M3{}, "m2_id", "=", "id")); err != nil { 290 | log.Fatal(err) 291 | } 292 | fmt.Println(results) 293 | } 294 | 295 | func ExampleDB_Select_leftJoin() { 296 | db, err := genmai.New(&genmai.SQLite3Dialect{}, ":memory:") 297 | if err != nil { 298 | log.Fatal(err) 299 | } 300 | var results []TestModel 301 | // SELECT "test_model".* FROM "test_model" LEFT JOIN "m2" ON 
"test_model"."name" = "m2"."body" WHERE "m2"."body" IS NULL; 302 | if err := db.Select(&results, "name", db.LeftJoin(&M2{}).On("name", "=", "body").Where(&M2{}, "body").IsNull()); err != nil { 303 | log.Fatal(err) 304 | } 305 | fmt.Println(results) 306 | } 307 | -------------------------------------------------------------------------------- /dialect.go: -------------------------------------------------------------------------------- 1 | package genmai 2 | 3 | import ( 4 | "database/sql" 5 | "errors" 6 | "fmt" 7 | "strings" 8 | "time" 9 | ) 10 | 11 | // Dialect is an interface that the dialect of the database. 12 | type Dialect interface { 13 | // Name returns a name of the dialect. 14 | // Return value must be same as the driver name. 15 | Name() string 16 | 17 | // Quote returns a quoted s. 18 | // It is for a column name, not a value. 19 | Quote(s string) string 20 | 21 | // PlaceHolder returns the placeholder character of the database. 22 | // A current number of placeholder will passed to i. 23 | PlaceHolder(i int) string 24 | 25 | // SQLType returns the SQL type of the v. 26 | // autoIncrement is whether the field is auto increment. 27 | // If "size" tag specified to struct field, it will passed to size 28 | // argument. If it doesn't specify, size is 0. 29 | SQLType(v interface{}, autoIncrement bool, size uint64) (name string, allowNull bool) 30 | 31 | // AutoIncrement returns the keyword of auto increment. 32 | AutoIncrement() string 33 | 34 | // FormatBool returns boolean value as string according to the value of b. 35 | FormatBool(b bool) string 36 | 37 | // LastInsertId returns an SQL to get the last inserted id. 
38 | LastInsertId() string 39 | } 40 | 41 | var ( 42 | ErrUsingFloatType = errors.New("float types have a rounding error problem.\n" + 43 | "Please use `genmai.Rat` if you want an exact value.\n" + 44 | "However, if you still want a float types, please use `genmai.Float32` and `Float64`.") 45 | ) 46 | 47 | const ( 48 | // Precision of the fixed-point number. 49 | // Digits of precision before the decimal point. 50 | decimalPrecision = 65 51 | 52 | // Scale of the fixed-point number. 53 | // Digits of precision after the decimal point. 54 | decimalScale = 30 55 | ) 56 | 57 | // SQLite3Dialect represents a dialect of the SQLite3. 58 | // It implements the Dialect interface. 59 | type SQLite3Dialect struct{} 60 | 61 | // Name returns name of the dialect. 62 | func (d *SQLite3Dialect) Name() string { 63 | return "sqlite3" 64 | } 65 | 66 | // Quote returns a quoted s for a column name. 67 | func (d *SQLite3Dialect) Quote(s string) string { 68 | return fmt.Sprintf(`"%s"`, strings.Replace(s, `"`, `""`, -1)) 69 | } 70 | 71 | // PlaceHolder returns the placeholder character of the SQLite3. 72 | func (d *SQLite3Dialect) PlaceHolder(i int) string { 73 | return "?" 74 | } 75 | 76 | // SQLType returns the SQL type of the v for SQLite3. 
77 | func (d *SQLite3Dialect) SQLType(v interface{}, autoIncrement bool, size uint64) (name string, allowNull bool) { 78 | switch v.(type) { 79 | case bool: 80 | return "boolean", false 81 | case *bool, sql.NullBool: 82 | return "boolean", true 83 | case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: 84 | return "integer", false 85 | case *int, *int8, *int16, *int32, *int64, *uint, *uint8, *uint16, *uint32, *uint64, sql.NullInt64: 86 | return "integer", true 87 | case string: 88 | return "text", false 89 | case *string, sql.NullString: 90 | return "text", true 91 | case []byte: 92 | return "blob", true 93 | case time.Time: 94 | return "datetime", false 95 | case *time.Time: 96 | return "datetime", true 97 | case Float32, Float64: 98 | return "real", false 99 | case *Float32, *Float64: 100 | return "real", true 101 | case Rat: 102 | return "numeric", false 103 | case *Rat: 104 | return "numeric", true 105 | case float32, *float32, float64, *float64, sql.NullFloat64: 106 | panic(ErrUsingFloatType) 107 | } 108 | panic(fmt.Errorf("SQLite3Dialect: unsupported SQL type: %T", v)) 109 | } 110 | 111 | func (d *SQLite3Dialect) AutoIncrement() string { 112 | return "AUTOINCREMENT" 113 | } 114 | 115 | // FormatBool returns "1" or "0" according to the value of b as boolean for SQLite3. 116 | func (d *SQLite3Dialect) FormatBool(b bool) string { 117 | if b { 118 | return "1" 119 | } else { 120 | return "0" 121 | } 122 | } 123 | 124 | func (d *SQLite3Dialect) LastInsertId() string { 125 | return `SELECT last_insert_rowid()` 126 | } 127 | 128 | // MySQLDialect represents a dialect of the MySQL. 129 | // It implements the Dialect interface. 130 | type MySQLDialect struct{} 131 | 132 | // Name returns name of the MySQLDialect. 133 | func (d *MySQLDialect) Name() string { 134 | return "mysql" 135 | } 136 | 137 | // Quote returns a quoted s for a column name. 
138 | func (d *MySQLDialect) Quote(s string) string { 139 | return fmt.Sprintf("`%s`", strings.Replace(s, "`", "``", -1)) 140 | } 141 | 142 | // PlaceHolder returns the placeholder character of the MySQL. 143 | func (d *MySQLDialect) PlaceHolder(i int) string { 144 | return "?" 145 | } 146 | 147 | // SQLType returns the SQL type of the v for MySQL. 148 | func (d *MySQLDialect) SQLType(v interface{}, autoIncrement bool, size uint64) (name string, allowNull bool) { 149 | switch v.(type) { 150 | case bool: 151 | return "BOOLEAN", false 152 | case *bool, sql.NullBool: 153 | return "BOOLEAN", true 154 | case int8, int16, uint8, uint16: 155 | return "SMALLINT", false 156 | case *int8, *int16, *uint8, *uint16: 157 | return "SMALLINT", true 158 | case int, int32, uint, uint32: 159 | return "INT", false 160 | case *int, *int32, *uint, *uint32: 161 | return "INT", true 162 | case int64, uint64: 163 | return "BIGINT", false 164 | case *int64, *uint64, sql.NullInt64: 165 | return "BIGINT", true 166 | case string: 167 | return d.varchar(size), false 168 | case *string, sql.NullString: 169 | return d.varchar(size), true 170 | case []byte: 171 | switch { 172 | case size == 0: 173 | return "VARBINARY(255)", true // default. 174 | case size < (1<<16)-1-2: // approximate 64KB. 175 | // 65533 ((2^16) - 1) - (length of prefix) 176 | // See http://dev.mysql.com/doc/refman/5.5/en/string-type-overview.html#idm47703458759504 177 | return fmt.Sprintf("VARBINARY(%d)", size), true 178 | case size < 1<<24: // 16MB. 
179 | return "MEDIUMBLOB", true 180 | } 181 | return "LONGBLOB", true 182 | case time.Time: 183 | return "DATETIME", false 184 | case *time.Time: 185 | return "DATETIME", true 186 | case Rat: 187 | return fmt.Sprintf("DECIMAL(%d, %d)", decimalPrecision, decimalScale), false 188 | case *Rat: 189 | return fmt.Sprintf("DECIMAL(%d, %d)", decimalPrecision, decimalScale), true 190 | case Float32, Float64: 191 | return "DOUBLE", false 192 | case *Float32, *Float64: 193 | return "DOUBLE", true 194 | case float32, *float32, float64, *float64, sql.NullFloat64: 195 | panic(ErrUsingFloatType) 196 | } 197 | panic(fmt.Errorf("MySQLDialect: unsupported SQL type: %T", v)) 198 | } 199 | 200 | func (d *MySQLDialect) AutoIncrement() string { 201 | return "AUTO_INCREMENT" 202 | } 203 | 204 | // FormatBool returns "TRUE" or "FALSE" according to the value of b as boolean for MySQL. 205 | func (d *MySQLDialect) FormatBool(b bool) string { 206 | if b { 207 | return "TRUE" 208 | } else { 209 | return "FALSE" 210 | } 211 | } 212 | 213 | func (d *MySQLDialect) LastInsertId() string { 214 | return `SELECT LAST_INSERT_ID()` 215 | } 216 | 217 | func (d *MySQLDialect) varchar(size uint64) string { 218 | switch { 219 | case size == 0: 220 | return "VARCHAR(255)" // default. 221 | case size < (1<<16)-1-2: // approximate 64KB. 222 | // 65533 ((2^16) - 1) - (length of prefix) 223 | // See http://dev.mysql.com/doc/refman/5.5/en/string-type-overview.html#idm47703458792704 224 | return fmt.Sprintf("VARCHAR(%d)", size) 225 | case size < 1<<24: // 16MB. 226 | return "MEDIUMTEXT" 227 | } 228 | return "LONGTEXT" 229 | } 230 | 231 | // PostgresDialect represents a dialect of the PostgreSQL. 232 | // It implements the Dialect interface. 233 | type PostgresDialect struct{} 234 | 235 | // Name returns name of the PostgresDialect. 236 | func (d *PostgresDialect) Name() string { 237 | return "postgres" 238 | } 239 | 240 | // Quote returns a quoted s for a column name. 
241 | func (d *PostgresDialect) Quote(s string) string { 242 | return fmt.Sprintf(`"%s"`, strings.Replace(s, `"`, `""`, -1)) 243 | } 244 | 245 | // PlaceHolder returns the placeholder character of the PostgreSQL. 246 | func (d *PostgresDialect) PlaceHolder(i int) string { 247 | return fmt.Sprintf("$%d", i+1) 248 | } 249 | 250 | // SQLType returns the SQL type of the v for PostgreSQL. 251 | func (d *PostgresDialect) SQLType(v interface{}, autoIncrement bool, size uint64) (name string, allowNull bool) { 252 | switch v.(type) { 253 | case bool: 254 | return "boolean", false 255 | case *bool, sql.NullBool: 256 | return "boolean", true 257 | case int8, int16, uint8, uint16: 258 | return d.smallint(autoIncrement), false 259 | case *int8, *int16, *uint8, *uint16: 260 | return d.smallint(autoIncrement), true 261 | case int, int32, uint, uint32: 262 | return d.integer(autoIncrement), false 263 | case *int, *int32, *uint, *uint32: 264 | return d.integer(autoIncrement), true 265 | case int64, uint64: 266 | return d.bigint(autoIncrement), false 267 | case *int64, *uint64, sql.NullInt64: 268 | return d.bigint(autoIncrement), true 269 | case string: 270 | return d.varchar(size), false 271 | case *string, sql.NullString: 272 | return d.varchar(size), true 273 | case []byte: 274 | return "bytea", true 275 | case time.Time: 276 | return "timestamp with time zone", false 277 | case *time.Time: 278 | return "timestamp with time zone", true 279 | case Rat: 280 | return fmt.Sprintf("numeric(%d, %d)", decimalPrecision, decimalScale), false 281 | case *Rat: 282 | return fmt.Sprintf("numeric(%d, %d)", decimalPrecision, decimalScale), true 283 | case Float32, Float64: 284 | return "double precision", false 285 | case *Float32, *Float64: 286 | return "double precision", true 287 | case float32, *float32, float64, *float64, sql.NullFloat64: 288 | panic(ErrUsingFloatType) 289 | } 290 | panic(fmt.Errorf("PostgresDialect: unsupported SQL type: %T", v)) 291 | } 292 | 293 | func (d 
*PostgresDialect) AutoIncrement() string { 294 | return "" 295 | } 296 | 297 | // FormatBool returns "TRUE" or "FALSE" according to the value of b as boolean for PostgreSQL. 298 | func (d *PostgresDialect) FormatBool(b bool) string { 299 | if b { 300 | return "TRUE" 301 | } else { 302 | return "FALSE" 303 | } 304 | } 305 | 306 | func (d *PostgresDialect) LastInsertId() string { 307 | return `SELECT lastval()` 308 | } 309 | 310 | func (d *PostgresDialect) smallint(autoIncrement bool) string { 311 | if autoIncrement { 312 | return "smallserial" 313 | } 314 | return "smallint" 315 | } 316 | 317 | func (d *PostgresDialect) integer(autoIncrement bool) string { 318 | if autoIncrement { 319 | return "serial" 320 | } 321 | return "integer" 322 | } 323 | 324 | func (d *PostgresDialect) bigint(autoIncrement bool) string { 325 | if autoIncrement { 326 | return "bigserial" 327 | } 328 | return "bigint" 329 | } 330 | 331 | func (d *PostgresDialect) varchar(size uint64) string { 332 | switch { 333 | case size == 0: 334 | return "varchar(255)" // default. 335 | case size < (1<<16)-1-2: // approximate 64KB. 336 | // This isn't required in PostgreSQL, but defined in order to match to the MySQLDialect. 337 | return fmt.Sprintf("varchar(%d)", size) 338 | } 339 | return "text" 340 | } 341 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Genmai [![Build Status](https://travis-ci.org/naoina/genmai.png?branch=master)](https://travis-ci.org/naoina/genmai) 2 | 3 | Simple, better and easy-to-use ORM library for [Golang](http://golang.org/). 
4 | 5 | ## Overview 6 | 7 | * flexibility with SQL-like API 8 | * Transaction support 9 | * Database dialect interface 10 | * Query logging 11 | * Update/Insert/Delete hooks 12 | * Embedded struct 13 | 14 | Database dialect currently supported are: 15 | 16 | * MySQL 17 | * PostgreSQL 18 | * SQLite3 19 | 20 | ## Installation 21 | 22 | go get -u github.com/naoina/genmai 23 | 24 | ## Schema 25 | 26 | Schema of the table will be defined as a struct. 27 | 28 | ```go 29 | // The struct "User" is the table name "user". 30 | // The field name will be converted to lowercase/snakecase, and used as a column name in table. 31 | // e.g. If field name is "CreatedAt", column name is "created_at". 32 | type User struct { 33 | // PRIMARY KEY. and column name will use "tbl_id" instead of "id". 34 | Id int64 `db:"pk" column:"tbl_id"` 35 | 36 | // NOT NULL and Default is "me". 37 | Name string `default:"me"` 38 | 39 | // Nullable column must use a pointer type, or sql.Null* types. 40 | CreatedAt *time.Time 41 | 42 | // UNIQUE column if specify the db:"unique" tag. 43 | ScreenName string `db:"unique"` 44 | 45 | // Ignore column if specify the db:"-" tag. 46 | Active bool `db:"-"` 47 | } 48 | ``` 49 | 50 | ## Query API 51 | 52 | ### Create table 53 | 54 | ```go 55 | package main 56 | 57 | import ( 58 | "time" 59 | 60 | _ "github.com/mattn/go-sqlite3" 61 | // _ "github.com/go-sql-driver/mysql" 62 | // _ "github.com/lib/pq" 63 | "github.com/naoina/genmai" 64 | ) 65 | 66 | // define a table schema. 
67 | type TestTable struct { 68 | Id int64 `db:"pk" column:"tbl_id"` 69 | Name string `default:"me"` 70 | CreatedAt *time.Time 71 | UserName string `db:"unique" size:"255"` 72 | Active bool `db:"-"` 73 | } 74 | 75 | func main() { 76 | db, err := genmai.New(&genmai.SQLite3Dialect{}, ":memory:") 77 | // db, err := genmai.New(&genmai.MySQLDialect{}, "dsn") 78 | // db, err := genmai.New(&genmai.PostgresDialect{}, "dsn") 79 | if err != nil { 80 | panic(err) 81 | } 82 | defer db.Close() 83 | if err := db.CreateTable(&TestTable{}); err != nil { 84 | panic(err) 85 | } 86 | } 87 | ``` 88 | 89 | ### Insert 90 | 91 | A single insert: 92 | 93 | ```go 94 | obj := &TestTable{ 95 | Name: "alice", 96 | Active: true, 97 | } 98 | n, err := db.Insert(obj) 99 | if err != nil { 100 | panic(err) 101 | } 102 | fmt.Printf("inserted rows: %d\n", n) 103 | ``` 104 | 105 | Or bulk-insert: 106 | 107 | ```go 108 | objs := []TestTable{ 109 | {Name: "alice", Active: true}, 110 | {Name: "bob", Active: true}, 111 | } 112 | n, err := db.Insert(objs) 113 | if err != nil { 114 | panic(err) 115 | } 116 | fmt.Printf("inserted rows: %d\n", n) 117 | ``` 118 | 119 | ### Select 120 | 121 | ```go 122 | var results []TestTable 123 | if err := db.Select(&results); err != nil { 124 | panic(err) 125 | } 126 | fmt.Printf("%v\n", results) 127 | ``` 128 | 129 | ### Where 130 | 131 | ```go 132 | var results []TestTable 133 | if err := db.Select(&results, db.Where("tbl_id", "=", 1)); err != nil { 134 | panic(err) 135 | } 136 | fmt.Printf("%v\n", results) 137 | ``` 138 | 139 | ### And/Or 140 | 141 | ```go 142 | var results []TestTable 143 | if err := db.Select(&results, db.Where("tbl_id", "=", 1).And(db.Where("name", "=", "alice").Or("name", "=", "bob"))); err != nil { 144 | panic(err) 145 | } 146 | fmt.Printf("%v\n", results) 147 | ``` 148 | 149 | ### In 150 | 151 | ```go 152 | var results []TestTable 153 | if err := db.Select(&results, db.Where("tbl_id").In(1, 2, 4)); err != nil { 154 | panic(err) 155 | } 156 | 
fmt.Printf("%v\n", results) 157 | ``` 158 | 159 | ### Like 160 | 161 | ```go 162 | var results []TestTable 163 | if err := db.Select(&results, db.Where("name").Like("%li%")); err != nil { 164 | panic(err) 165 | } 166 | fmt.Printf("%v\n", results) 167 | ``` 168 | 169 | ### Between 170 | 171 | ```go 172 | var results []TestTable 173 | if err := db.Select(&results, db.Where("name").Between(1, 3)); err != nil { 174 | panic(err) 175 | } 176 | fmt.Printf("%v\n", results) 177 | ``` 178 | 179 | ### Is Null/Is Not Null 180 | 181 | ```go 182 | var results []TestTable 183 | if err := db.Select(&results, db.Where("created_at").IsNull()); err != nil { 184 | panic(err) 185 | } 186 | fmt.Printf("%v\n", results) 187 | results = []TestTable{} 188 | if err := db.Select(&results, db.Where("created_at").IsNotNull()); err != nil { 189 | panic(err) 190 | } 191 | fmt.Printf("%v\n", results) 192 | ``` 193 | 194 | ### Order by/Offset/Limit 195 | 196 | ```go 197 | var results []TestTable 198 | if err := db.Select(&results, db.OrderBy("id", genmai.ASC).Offset(2).Limit(10)); err != nil { 199 | panic(err) 200 | } 201 | fmt.Printf("%v\n", results) 202 | ``` 203 | 204 | ### Distinct 205 | 206 | ```go 207 | var results []TestTable 208 | if err := db.Select(&results, db.Distinct("name"), db.Where("name").Like("%")); err != nil { 209 | panic(err) 210 | } 211 | fmt.Printf("%v\n", results) 212 | ``` 213 | 214 | ### Count 215 | 216 | ```go 217 | var n int64 218 | if err := db.Select(&n, db.Count(), db.From(&TestTable{})); err != nil { 219 | panic(err) 220 | } 221 | fmt.Printf("%v\n", n) 222 | ``` 223 | 224 | With condition: 225 | 226 | ```go 227 | var n int64 228 | if err := db.Select(&n, db.Count(), db.From(&TestTable{}), db.Where("id", ">", 100)); err != nil { 229 | panic(err) 230 | } 231 | fmt.Printf("%v\n", n) 232 | ``` 233 | 234 | ### Join 235 | 236 | Inner Join: 237 | 238 | ```go 239 | package main 240 | 241 | import ( 242 | "database/sql" 243 | "fmt" 244 | "time" 245 | 246 | _ 
"github.com/mattn/go-sqlite3" 247 | // _ "github.com/go-sql-driver/mysql" 248 | // _ "github.com/lib/pq" 249 | "github.com/naoina/genmai" 250 | ) 251 | 252 | type TestTable struct { 253 | Id int64 `db:"pk" column:"tbl_id"` 254 | Name string `default:"me"` 255 | CreatedAt *time.Time 256 | Active bool `db:"-"` // column to ignore. 257 | } 258 | 259 | type Table2 struct { 260 | Id int64 `db:"pk" column:"tbl_id"` 261 | Body sql.NullString 262 | } 263 | 264 | func main() { 265 | db, err := genmai.New(&genmai.SQLite3Dialect{}, ":memory:") 266 | // db, err := genmai.New(&genmai.MySQLDialect{}, "dsn") 267 | // db, err := genmai.New(&genmai.PostgresDialect{}, "dsn") 268 | if err != nil { 269 | panic(err) 270 | } 271 | defer db.Close() 272 | if err := db.CreateTable(&TestTable{}); err != nil { 273 | panic(err) 274 | } 275 | if err := db.CreateTable(&Table2{}); err != nil { 276 | panic(err) 277 | } 278 | objs1 := []TestTable{ 279 | {Name: "alice", Active: true}, 280 | {Name: "bob", Active: true}, 281 | } 282 | objs2 := []Table2{ 283 | {Body: sql.NullString{String: "something"}}, 284 | } 285 | if _, err = db.Insert(objs1); err != nil { 286 | panic(err) 287 | } 288 | if _, err := db.Insert(objs2); err != nil { 289 | panic(err) 290 | } 291 | // fmt.Printf("inserted rows: %d\n", n) 292 | var results []TestTable 293 | if err := db.Select(&results, db.Join(&Table2{}).On("tbl_id")); err != nil { 294 | panic(err) 295 | } 296 | fmt.Printf("%v\n", results) 297 | } 298 | ``` 299 | 300 | Left Join: 301 | 302 | ```go 303 | var results []TestTable 304 | if err := db.Select(&results, db.LeftJoin(&Table2{}).On("name", "=", "body")); err != nil { 305 | panic(err) 306 | } 307 | fmt.Printf("%v\n", results) 308 | ``` 309 | 310 | `RIGHT OUTER JOIN` and `FULL OUTER JOIN` are still unsupported. 
311 | 312 | ### Update 313 | 314 | ```go 315 | var results []TestTable 316 | if err := db.Select(&results); err != nil { 317 | panic(err) 318 | } 319 | obj := results[0] 320 | obj.Name = "nico" 321 | if _, err := db.Update(&obj); err != nil { 322 | panic(err) 323 | } 324 | ``` 325 | 326 | ### Delete 327 | 328 | A single delete: 329 | 330 | ```go 331 | obj := TestTable{Id: 1} 332 | if _, err := db.Delete(&obj); err != nil { 333 | panic(err) 334 | } 335 | ``` 336 | 337 | Or bulk-delete: 338 | 339 | ```go 340 | objs := []TestTable{ 341 | {Id: 1}, {Id: 3}, 342 | } 343 | if _, err := db.Delete(objs); err != nil { 344 | panic(err) 345 | } 346 | ``` 347 | 348 | ### Transaction 349 | 350 | ```go 351 | defer func() { 352 | if err := recover(); err != nil { 353 | db.Rollback() 354 | } else { 355 | db.Commit() 356 | } 357 | }() 358 | if err := db.Begin(); err != nil { 359 | panic(err) 360 | } 361 | // do something. 362 | ``` 363 | 364 | ### Using any table name 365 | 366 | You can implement [TableNamer](https://godoc.org/github.com/naoina/genmai#TableNamer) interface to use any table name. 367 | 368 | ```go 369 | type UserTable struct { 370 | Id int64 `db:"pk"` 371 | } 372 | 373 | func (u *UserTable) TableName() string { 374 | return "user" 375 | } 376 | ``` 377 | 378 | In the above example, the table name `user` is used instead of `user_table`. 379 | 380 | ### Using raw database/sql interface 381 | 382 | ```go 383 | rawDB := db.DB() 384 | // do something with using raw database/sql interface... 385 | ``` 386 | 387 | ## Query logging 388 | 389 | By default, query logging is disabled. 390 | You can enable Query logging as follows. 391 | 392 | ```go 393 | db.SetLogOutput(os.Stdout) // Or any io.Writer can be passed. 394 | ``` 395 | 396 | Also you can change the format of output as follows. 397 | 398 | ```go 399 | db.SetLogFormat("format string") 400 | ``` 401 | 402 | Format syntax uses Go's template. And you can use the following data object in that template. 
403 | 404 | ``` 405 | - .time time.Time object of the current time. 406 | - .duration Processing time of SQL. It will be formatted as "%.2fms". 407 | - .query string of SQL query. If it uses placeholders, 408 | the placeholder parameters will be appended to the end of the query. 409 | ``` 410 | 411 | The default format is: 412 | 413 | [{{.time.Format "2006-01-02 15:04:05"}}] [{{.duration}}] {{.query}} 414 | 415 | In production, it is recommended to disable this feature because it somewhat affects performance. 416 | 417 | ```go 418 | db.SetLogOutput(nil) // To disable logging by nil. 419 | ``` 420 | 421 | ## Update/Insert/Delete hooks 422 | 423 | Genmai calls the `Before`/`After` hook methods if they are defined on the model struct. 424 | 425 | ```go 426 | func (t *TestTable) BeforeInsert() error { 427 | t.CreatedAt = time.Now() 428 | return nil 429 | } 430 | ``` 431 | 432 | If a `Before`-prefixed hook returns an error, the query won't run. 433 | 434 | All hooks are: 435 | 436 | * `BeforeInsert/AfterInsert` 437 | * `BeforeUpdate/AfterUpdate` 438 | * `BeforeDelete/AfterDelete` 439 | 440 | When using bulk-insert or bulk-delete, the hook methods run for each object. 441 | 442 | ## Embedded struct 443 | 444 | Common fields can be defined on a struct, which can then be embedded. 445 | 446 | ```go 447 | package main 448 | 449 | import "time" 450 | 451 | type TimeStamp struct { 452 | CreatedAt time.Time 453 | UpdatedAt time.Time 454 | } 455 | 456 | type User struct { 457 | Id int64 458 | Name string 459 | 460 | TimeStamp 461 | } 462 | ``` 463 | 464 | Genmai also defines a `TimeStamp` struct for commonly used fields. 465 | 466 | ```go 467 | type User struct { 468 | Id int64 469 | 470 | genmai.TimeStamp 471 | } 472 | ``` 473 | 474 | See the Godoc of [TimeStamp](http://godoc.org/github.com/naoina/genmai#TimeStamp) for more information. 475 | 476 | If you override a hook method defined in an embedded struct, you should call that hook in the overriding method.
477 | For example in above struct case: 478 | 479 | ```go 480 | func (u *User) BeforeInsert() error { 481 | if err := u.TimeStamp.BeforeInsert(); err != nil { 482 | return err 483 | } 484 | // do something. 485 | return nil 486 | } 487 | ``` 488 | 489 | ## Documentation 490 | 491 | API document and more examples are available here: 492 | 493 | http://godoc.org/github.com/naoina/genmai 494 | 495 | ## TODO 496 | 497 | * Benchmark 498 | * More SQL support 499 | * Migration 500 | 501 | ## License 502 | 503 | Genmai is licensed under the MIT 504 | -------------------------------------------------------------------------------- /genmai.go: -------------------------------------------------------------------------------- 1 | // Copyright 2014 Naoya Inada. All rights reserved. 2 | // Use of this source code is governed by the MIT 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package genmai provides simple, better and easy-to-use Object-Relational Mapper. 6 | package genmai 7 | 8 | import ( 9 | "database/sql" 10 | "errors" 11 | "fmt" 12 | "io" 13 | "reflect" 14 | "runtime" 15 | "sort" 16 | "strconv" 17 | "strings" 18 | "sync" 19 | 20 | "github.com/naoina/go-stringutil" 21 | ) 22 | 23 | var ErrTxDone = errors.New("genmai: transaction hasn't been started or already committed or rolled back") 24 | 25 | // DB represents a database object. 26 | type DB struct { 27 | db *sql.DB 28 | dialect Dialect 29 | tx *sql.Tx 30 | m sync.Mutex 31 | logger logger 32 | } 33 | 34 | // New returns a new DB. 35 | // If any error occurs, it returns nil and error. 36 | func New(dialect Dialect, dsn string) (*DB, error) { 37 | db, err := sql.Open(dialect.Name(), dsn) 38 | if err != nil { 39 | return nil, err 40 | } 41 | return &DB{db: db, dialect: dialect, logger: defaultLogger}, nil 42 | } 43 | 44 | // Select fetch data into the output from the database. 45 | // output argument must be pointer to a slice of struct. If not a pointer or not a slice of struct, It returns error. 
46 | // The table name of the database will be determined from name of struct. e.g. If *[]ATableName passed to output argument, table name will be "a_table_name". 47 | // If args are not given, fetch the all data like "SELECT * FROM table" SQL. 48 | func (db *DB) Select(output interface{}, args ...interface{}) (err error) { 49 | defer func() { 50 | if e := recover(); e != nil { 51 | buf := make([]byte, 4096) 52 | n := runtime.Stack(buf, false) 53 | err = fmt.Errorf("%v\n%v", e, string(buf[:n])) 54 | } 55 | }() 56 | rv := reflect.ValueOf(output) 57 | if rv.Kind() != reflect.Ptr { 58 | return fmt.Errorf("Select: first argument must be a pointer") 59 | } 60 | for rv.Kind() == reflect.Ptr { 61 | rv = rv.Elem() 62 | } 63 | var tableName string 64 | for _, arg := range args { 65 | if f, ok := arg.(*From); ok { 66 | if tableName != "" { 67 | return fmt.Errorf("Select: From statement specified more than once") 68 | } 69 | tableName = f.TableName 70 | } 71 | } 72 | var selectFunc selectFunc 73 | ptrN := 0 74 | switch rv.Kind() { 75 | case reflect.Slice: 76 | t := rv.Type().Elem() 77 | for ; t.Kind() == reflect.Ptr; ptrN++ { 78 | t = t.Elem() 79 | } 80 | if t.Kind() != reflect.Struct { 81 | return fmt.Errorf("Select: argument of slice must be slice of struct, but %v", rv.Type()) 82 | } 83 | if tableName == "" { 84 | tableName = db.tableName(t) 85 | } 86 | selectFunc = db.selectToSlice 87 | case reflect.Invalid: 88 | return fmt.Errorf("Select: nil pointer dereference") 89 | default: 90 | if tableName == "" { 91 | return fmt.Errorf("Select: From statement must be given if any Function is given") 92 | } 93 | selectFunc = db.selectToValue 94 | } 95 | col, from, conditions, err := db.classify(tableName, args) 96 | if err != nil { 97 | return err 98 | } 99 | queries := []string{`SELECT`, col, `FROM`, db.dialect.Quote(from)} 100 | var values []interface{} 101 | for _, cond := range conditions { 102 | q, a := cond.build(0, false) 103 | queries = append(queries, q...) 
104 | values = append(values, a...) 105 | } 106 | query := strings.Join(queries, " ") 107 | stmt, err := db.prepare(query, values...) 108 | if err != nil { 109 | return err 110 | } 111 | defer stmt.Close() 112 | rows, err := stmt.Query(values...) 113 | if err != nil { 114 | return err 115 | } 116 | defer rows.Close() 117 | value, err := selectFunc(rows, rv.Type()) 118 | if err != nil { 119 | return err 120 | } 121 | rv.Set(value) 122 | return nil 123 | } 124 | 125 | // From returns a "FROM" statement. 126 | // A table name will be determined from name of struct of arg. 127 | // If arg argument is not struct type, it panics. 128 | func (db *DB) From(arg interface{}) *From { 129 | t := reflect.Indirect(reflect.ValueOf(arg)).Type() 130 | if t.Kind() != reflect.Struct { 131 | panic(fmt.Errorf("From: argument must be struct (or that pointer) type, got %v", t)) 132 | } 133 | return &From{TableName: db.tableName(t)} 134 | } 135 | 136 | // Where returns a new Condition of "WHERE" clause. 137 | func (db *DB) Where(cond interface{}, args ...interface{}) *Condition { 138 | return newCondition(db).Where(cond, args...) 139 | } 140 | 141 | // OrderBy returns a new Condition of "ORDER BY" clause. 142 | func (db *DB) OrderBy(table interface{}, column interface{}, order ...interface{}) *Condition { 143 | return newCondition(db).OrderBy(table, column, order...) 144 | } 145 | 146 | // Limit returns a new Condition of "LIMIT" clause. 147 | func (db *DB) Limit(lim int) *Condition { 148 | return newCondition(db).Limit(lim) 149 | } 150 | 151 | // Offset returns a new Condition of "OFFSET" clause. 152 | func (db *DB) Offset(offset int) *Condition { 153 | return newCondition(db).Offset(offset) 154 | } 155 | 156 | // Distinct returns a representation object of "DISTINCT" statement. 157 | func (db *DB) Distinct(columns ...string) *Distinct { 158 | return &Distinct{columns: columns} 159 | } 160 | 161 | // Join returns a new JoinCondition of "JOIN" clause. 
162 | func (db *DB) Join(table interface{}) *JoinCondition { 163 | return (&JoinCondition{db: db}).Join(table) 164 | } 165 | 166 | func (db *DB) LeftJoin(table interface{}) *JoinCondition { 167 | return (&JoinCondition{db: db}).LeftJoin(table) 168 | } 169 | 170 | // Count returns "COUNT" function. 171 | func (db *DB) Count(column ...interface{}) *Function { 172 | switch len(column) { 173 | case 0, 1: 174 | // do nothing. 175 | default: 176 | panic(fmt.Errorf("Count: a number of argument must be 0 or 1, got %v", len(column))) 177 | } 178 | return &Function{ 179 | Name: "COUNT", 180 | Args: column, 181 | } 182 | } 183 | 184 | const ( 185 | dbTag = "db" 186 | dbColumnTag = "column" 187 | dbDefaultTag = "default" 188 | dbSizeTag = "size" 189 | skipTag = "-" 190 | ) 191 | 192 | // CreateTable creates the table into database. 193 | // If table isn't direct/indirect struct, it returns error. 194 | func (db *DB) CreateTable(table interface{}) error { 195 | return db.createTable(table, false) 196 | } 197 | 198 | // CreateTableIfNotExists creates the table into database if table isn't exists. 199 | // If table isn't direct/indirect struct, it returns error. 
200 | func (db *DB) CreateTableIfNotExists(table interface{}) error { 201 | return db.createTable(table, true) 202 | } 203 | 204 | func (db *DB) createTable(table interface{}, ifNotExists bool) error { 205 | _, t, tableName, err := db.tableValueOf("CreateTable", table) 206 | if err != nil { 207 | return err 208 | } 209 | fields, err := db.collectTableFields(t) 210 | if err != nil { 211 | return err 212 | } 213 | var query string 214 | if ifNotExists { 215 | query = "CREATE TABLE IF NOT EXISTS %s (%s)" 216 | } else { 217 | query = "CREATE TABLE %s (%s)" 218 | } 219 | query = fmt.Sprintf(query, db.dialect.Quote(tableName), strings.Join(fields, ", ")) 220 | stmt, err := db.prepare(query) 221 | if err != nil { 222 | return err 223 | } 224 | defer stmt.Close() 225 | if _, err := stmt.Exec(); err != nil { 226 | return err 227 | } 228 | return nil 229 | } 230 | 231 | // DropTable removes the table from database. 232 | // If table isn't direct/indirect struct, it returns error. 233 | func (db *DB) DropTable(table interface{}) error { 234 | _, _, tableName, err := db.tableValueOf("DropTable", table) 235 | if err != nil { 236 | return err 237 | } 238 | query := fmt.Sprintf("DROP TABLE %s", db.dialect.Quote(tableName)) 239 | stmt, err := db.prepare(query) 240 | if err != nil { 241 | return err 242 | } 243 | defer stmt.Close() 244 | if _, err = stmt.Exec(); err != nil { 245 | return err 246 | } 247 | return nil 248 | } 249 | 250 | // CreateIndex creates the index into database. 251 | // If table isn't direct/indirect struct, it returns error. 252 | func (db *DB) CreateIndex(table interface{}, name string, names ...string) error { 253 | return db.createIndex(table, false, name, names...) 254 | } 255 | 256 | // CreateUniqueIndex creates the unique index into database. 257 | // If table isn't direct/indirect struct, it returns error. 
258 | func (db *DB) CreateUniqueIndex(table interface{}, name string, names ...string) error { 259 | return db.createIndex(table, true, name, names...) 260 | } 261 | 262 | func (db *DB) createIndex(table interface{}, unique bool, name string, names ...string) error { 263 | _, _, tableName, err := db.tableValueOf("CreateIndex", table) 264 | if err != nil { 265 | return err 266 | } 267 | names = append([]string{name}, names...) 268 | indexes := make([]string, len(names)) 269 | for i, name := range names { 270 | indexes[i] = db.dialect.Quote(name) 271 | } 272 | indexName := strings.Join(append([]string{"index", tableName}, names...), "_") 273 | var query string 274 | if unique { 275 | query = "CREATE UNIQUE INDEX %s ON %s (%s)" 276 | } else { 277 | query = "CREATE INDEX %s ON %s (%s)" 278 | } 279 | query = fmt.Sprintf(query, 280 | db.dialect.Quote(indexName), 281 | db.dialect.Quote(tableName), 282 | strings.Join(indexes, ", ")) 283 | stmt, err := db.prepare(query) 284 | if err != nil { 285 | return err 286 | } 287 | defer stmt.Close() 288 | if _, err := stmt.Exec(); err != nil { 289 | return err 290 | } 291 | return nil 292 | } 293 | 294 | // Update updates the one record. 295 | // The obj must be struct, and must have field that specified "pk" struct tag. 296 | // Update will try to update record which searched by value of primary key in obj. 297 | // Update returns the number of rows affected by an update. 
298 | func (db *DB) Update(obj interface{}) (affected int64, err error) { 299 | rv, rtype, tableName, err := db.tableValueOf("Update", obj) 300 | if err != nil { 301 | return -1, err 302 | } 303 | if hook, ok := obj.(BeforeUpdater); ok { 304 | if err := hook.BeforeUpdate(); err != nil { 305 | return -1, err 306 | } 307 | } 308 | fieldIndexes := db.collectFieldIndexes(rtype, nil) 309 | pkIdx := db.findPKIndex(rtype, nil) 310 | if len(pkIdx) < 1 { 311 | return -1, fmt.Errorf(`Update: fields of struct doesn't have primary key: "pk" struct tag must be specified for update`) 312 | } 313 | sets := make([]string, len(fieldIndexes)) 314 | var args []interface{} 315 | for i, index := range fieldIndexes { 316 | col := db.columnFromTag(rtype.FieldByIndex(index)) 317 | sets[i] = fmt.Sprintf("%s = %s", db.dialect.Quote(col), db.dialect.PlaceHolder(i)) 318 | args = append(args, rv.FieldByIndex(index).Interface()) 319 | } 320 | query := fmt.Sprintf("UPDATE %s SET %s WHERE %s = %s", 321 | db.dialect.Quote(tableName), 322 | strings.Join(sets, ", "), 323 | db.dialect.Quote(db.columnFromTag(rtype.FieldByIndex(pkIdx))), 324 | db.dialect.PlaceHolder(len(fieldIndexes))) 325 | args = append(args, rv.FieldByIndex(pkIdx).Interface()) 326 | stmt, err := db.prepare(query, args...) 327 | if err != nil { 328 | return -1, err 329 | } 330 | defer stmt.Close() 331 | result, err := stmt.Exec(args...) 332 | if err != nil { 333 | return -1, err 334 | } 335 | affected, _ = result.RowsAffected() 336 | if hook, ok := obj.(AfterUpdater); ok { 337 | if err := hook.AfterUpdate(); err != nil { 338 | return affected, err 339 | } 340 | } 341 | return affected, nil 342 | } 343 | 344 | // Insert inserts one or more records to the database table. 345 | // The obj must be pointer to struct or slice of struct. If a struct have a 346 | // field which specified "pk" struct tag on type of autoincrementable, it 347 | // won't be used to as an insert value. 
// Insert sets the last inserted id to the primary key of the instance of the given obj if obj is single.
// Insert returns the number of rows affected by insert.
//
// On failure before or during execution of the statement it returns -1 and
// the error. If fetching the last inserted id or an AfterInsert hook fails,
// it returns the affected count together with that error.
func (db *DB) Insert(obj interface{}) (affected int64, err error) {
	objs, rtype, tableName, err := db.tableObjs("Insert", obj)
	if err != nil {
		return -1, err
	}
	// An empty slice is not an error: there is simply nothing to insert.
	if len(objs) < 1 {
		return 0, nil
	}
	// Run every BeforeInsert hook first; any failure cancels the whole insert.
	for _, obj := range objs {
		if hook, ok := obj.(BeforeInserter); ok {
			if err := hook.BeforeInsert(); err != nil {
				return -1, err
			}
		}
	}
	// Quoted column list, shared by all rows.
	fieldIndexes := db.collectFieldIndexes(rtype, nil)
	cols := make([]string, len(fieldIndexes))
	for i, index := range fieldIndexes {
		cols[i] = db.dialect.Quote(db.columnFromTag(rtype.FieldByIndex(index)))
	}
	// Flatten the field values row by row, in column order.
	var args []interface{}
	for _, obj := range objs {
		rv := reflect.Indirect(reflect.ValueOf(obj))
		for _, index := range fieldIndexes {
			args = append(args, rv.FieldByIndex(index).Interface())
		}
	}
	// Build one "(?, ?, ...)" group per row. numHolders keeps counting across
	// rows so that every placeholder gets its own number. holders is reused
	// between rows; that is safe because strings.Join copies its contents.
	numHolders := 0
	values := make([]string, len(objs))
	holders := make([]string, len(cols))
	for i := 0; i < len(values); i++ {
		for j := 0; j < len(holders); j++ {
			holders[j] = db.dialect.PlaceHolder(numHolders)
			numHolders++
		}
		values[i] = fmt.Sprintf("(%s)", strings.Join(holders, ", "))
	}
	query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s",
		db.dialect.Quote(tableName),
		strings.Join(cols, ", "),
		strings.Join(values, ", "),
	)
	stmt, err := db.prepare(query, args...)
	if err != nil {
		return -1, err
	}
	defer stmt.Close()
	result, err := stmt.Exec(args...)
	if err != nil {
		return -1, err
	}
	// The error of RowsAffected is ignored here.
	affected, _ = result.RowsAffected()
	// Only for a single record: write the generated id back into the
	// auto-incrementable primary key field of that record.
	if len(objs) == 1 {
		if pkIdx := db.findPKIndex(rtype, nil); len(pkIdx) > 0 {
			field := rtype.FieldByIndex(pkIdx)
			if db.isAutoIncrementable(&field) {
				id, err := db.LastInsertId()
				if err != nil {
					return affected, err
				}
				rv := reflect.Indirect(reflect.ValueOf(objs[0])).FieldByIndex(pkIdx)
				for rv.Kind() == reflect.Ptr {
					rv = rv.Elem()
				}
				// id is an int64; convert it to the concrete integer type of the field.
				rv.Set(reflect.ValueOf(id).Convert(rv.Type()))
			}
		}
	}
	// AfterInsert hooks run after the rows were written; the first failure
	// is reported together with the affected count.
	for _, obj := range objs {
		if hook, ok := obj.(AfterInserter); ok {
			if err := hook.AfterInsert(); err != nil {
				return affected, err
			}
		}
	}
	return affected, nil
}

// Delete deletes the records from database table.
// The obj must be pointer to struct or slice of struct, and must have field that specified "pk" struct tag.
// Delete will try to delete record which searched by value of primary key in obj.
// Delete returns the number of rows affected by a delete.
432 | func (db *DB) Delete(obj interface{}) (affected int64, err error) { 433 | objs, rtype, tableName, err := db.tableObjs("Delete", obj) 434 | if err != nil { 435 | return -1, err 436 | } 437 | if len(objs) < 1 { 438 | return 0, nil 439 | } 440 | for _, obj := range objs { 441 | if hook, ok := obj.(BeforeDeleter); ok { 442 | if err := hook.BeforeDelete(); err != nil { 443 | return -1, err 444 | } 445 | } 446 | } 447 | pkIdx := db.findPKIndex(rtype, nil) 448 | if len(pkIdx) < 1 { 449 | return -1, fmt.Errorf(`Delete: fields of struct doesn't have primary key: "pk" struct tag must be specified for delete`) 450 | } 451 | var args []interface{} 452 | for _, obj := range objs { 453 | rv := reflect.Indirect(reflect.ValueOf(obj)) 454 | args = append(args, rv.FieldByIndex(pkIdx).Interface()) 455 | } 456 | holders := make([]string, len(args)) 457 | for i := 0; i < len(holders); i++ { 458 | holders[i] = db.dialect.PlaceHolder(i) 459 | } 460 | query := fmt.Sprintf("DELETE FROM %s WHERE %s IN (%s)", 461 | db.dialect.Quote(tableName), 462 | db.dialect.Quote(db.columnFromTag(rtype.FieldByIndex(pkIdx))), 463 | strings.Join(holders, ", ")) 464 | stmt, err := db.prepare(query, args...) 465 | if err != nil { 466 | return -1, err 467 | } 468 | defer stmt.Close() 469 | result, err := stmt.Exec(args...) 470 | if err != nil { 471 | return -1, err 472 | } 473 | affected, _ = result.RowsAffected() 474 | for _, obj := range objs { 475 | if hook, ok := obj.(AfterDeleter); ok { 476 | if err := hook.AfterDelete(); err != nil { 477 | return affected, err 478 | } 479 | } 480 | } 481 | return affected, nil 482 | } 483 | 484 | // Begin starts a transaction. 485 | func (db *DB) Begin() error { 486 | tx, err := db.db.Begin() 487 | if err != nil { 488 | return err 489 | } 490 | db.m.Lock() 491 | defer db.m.Unlock() 492 | db.tx = tx 493 | return nil 494 | } 495 | 496 | // Commit commits the transaction. 
497 | // If Begin still not called, or Commit or Rollback already called, Commit returns ErrTxDone. 498 | func (db *DB) Commit() error { 499 | db.m.Lock() 500 | defer db.m.Unlock() 501 | if db.tx == nil { 502 | return ErrTxDone 503 | } 504 | err := db.tx.Commit() 505 | db.tx = nil 506 | return err 507 | } 508 | 509 | // Rollback rollbacks the transaction. 510 | // If Begin still not called, or Commit or Rollback already called, Rollback returns ErrTxDone. 511 | func (db *DB) Rollback() error { 512 | db.m.Lock() 513 | defer db.m.Unlock() 514 | if db.tx == nil { 515 | return ErrTxDone 516 | } 517 | err := db.tx.Rollback() 518 | db.tx = nil 519 | return err 520 | } 521 | 522 | func (db *DB) LastInsertId() (int64, error) { 523 | stmt, err := db.prepare(db.dialect.LastInsertId()) 524 | if err != nil { 525 | return 0, err 526 | } 527 | defer stmt.Close() 528 | var id int64 529 | return id, stmt.QueryRow().Scan(&id) 530 | } 531 | 532 | // Raw returns a value that is wrapped with Raw. 533 | func (db *DB) Raw(v interface{}) Raw { 534 | return Raw(&v) 535 | } 536 | 537 | // Close closes the database. 538 | func (db *DB) Close() error { 539 | return db.db.Close() 540 | } 541 | 542 | // Quote returns a quoted s. 543 | // It is for a column name, not a value. 544 | func (db *DB) Quote(s string) string { 545 | return db.dialect.Quote(s) 546 | } 547 | 548 | // DB returns a *sql.DB that is associated to DB. 549 | func (db *DB) DB() *sql.DB { 550 | return db.db 551 | } 552 | 553 | // SetLogOutput sets output destination for logging. 554 | // If w is nil, it unsets output of logging. 555 | func (db *DB) SetLogOutput(w io.Writer) { 556 | if w == nil { 557 | db.logger = defaultLogger 558 | } else { 559 | db.logger = &templateLogger{w: w, t: defaultLoggerTemplate} 560 | } 561 | } 562 | 563 | // SetLogFormat sets format for logging. 564 | // 565 | // Format syntax uses Go's template. And you can use the following data object in that template. 
566 | // 567 | // - .time time.Time object in current time. 568 | // - .duration Processing time of SQL. It will format to "%.2fms". 569 | // - .query string of SQL query. If it using placeholder, 570 | // placeholder parameters will append to the end of query. 571 | // 572 | // The default format is: 573 | // 574 | // [{{.time.Format "2006-01-02 15:04:05"}}] [{{.duration}}] {{.query}} 575 | func (db *DB) SetLogFormat(format string) error { 576 | return db.logger.SetFormat(format) 577 | } 578 | 579 | // selectToSlice returns a slice value fetched from rows. 580 | func (db *DB) selectToSlice(rows *sql.Rows, t reflect.Type) (reflect.Value, error) { 581 | columns, err := rows.Columns() 582 | if err != nil { 583 | return reflect.Value{}, err 584 | } 585 | t = t.Elem() 586 | ptrN := 0 587 | for ; t.Kind() == reflect.Ptr; ptrN++ { 588 | t = t.Elem() 589 | } 590 | fieldIndexes := make([][]int, len(columns)) 591 | for i, column := range columns { 592 | index := db.fieldIndexByName(t, column, nil) 593 | if len(index) < 1 { 594 | return reflect.Value{}, fmt.Errorf("`%v` field isn't defined in %v or embedded struct", stringutil.ToUpperCamelCase(column), t) 595 | } 596 | fieldIndexes[i] = index 597 | } 598 | dest := make([]interface{}, len(columns)) 599 | var result []reflect.Value 600 | for rows.Next() { 601 | v := reflect.New(t).Elem() 602 | for i, index := range fieldIndexes { 603 | field := v.FieldByIndex(index) 604 | dest[i] = field.Addr().Interface() 605 | } 606 | if err := rows.Scan(dest...); err != nil { 607 | return reflect.Value{}, err 608 | } 609 | result = append(result, v) 610 | } 611 | if err := rows.Err(); err != nil { 612 | return reflect.Value{}, err 613 | } 614 | for i := 0; i < ptrN; i++ { 615 | t = reflect.PtrTo(t) 616 | } 617 | slice := reflect.MakeSlice(reflect.SliceOf(t), len(result), len(result)) 618 | for i, v := range result { 619 | for j := 0; j < ptrN; j++ { 620 | v = v.Addr() 621 | } 622 | slice.Index(i).Set(v) 623 | } 624 | return slice, nil 625 | 
} 626 | 627 | // selectToValue returns a single value fetched from rows. 628 | func (db *DB) selectToValue(rows *sql.Rows, t reflect.Type) (reflect.Value, error) { 629 | ptrN := 0 630 | for ; t.Kind() == reflect.Ptr; ptrN++ { 631 | t = t.Elem() 632 | } 633 | dest := reflect.New(t).Elem() 634 | if rows.Next() { 635 | if err := rows.Scan(dest.Addr().Interface()); err != nil { 636 | return reflect.Value{}, err 637 | } 638 | } 639 | for i := 0; i < ptrN; i++ { 640 | dest = dest.Addr() 641 | } 642 | return dest, nil 643 | } 644 | 645 | // fieldIndexByName returns the nested field corresponding to the index sequence. 646 | func (db *DB) fieldIndexByName(t reflect.Type, name string, index []int) []int { 647 | for i := 0; i < t.NumField(); i++ { 648 | field := t.Field(i) 649 | if candidate := db.columnFromTag(field); candidate == name { 650 | return append(index, i) 651 | } 652 | if field.Anonymous { 653 | if idx := db.fieldIndexByName(field.Type, name, append(index, i)); len(idx) > 0 { 654 | return append(index, idx...) 
655 | } 656 | } 657 | } 658 | return nil 659 | } 660 | 661 | func (db *DB) classify(tableName string, args []interface{}) (column, from string, conditions []*Condition, err error) { 662 | if len(args) == 0 { 663 | return ColumnName(db.dialect, tableName, "*"), tableName, nil, nil 664 | } 665 | offset := 1 666 | switch t := args[0].(type) { 667 | case string: 668 | if t != "" { 669 | column = ColumnName(db.dialect, tableName, t) 670 | } 671 | case []string: 672 | column = db.columns(tableName, ToInterfaceSlice(t)) 673 | case *Distinct: 674 | column = fmt.Sprintf("DISTINCT %s", db.columns(tableName, ToInterfaceSlice(t.columns))) 675 | case *Function: 676 | var col string 677 | if len(t.Args) == 0 { 678 | col = "*" 679 | } else { 680 | col = db.columns(tableName, t.Args) 681 | } 682 | column = fmt.Sprintf("%s(%s)", t.Name, col) 683 | default: 684 | offset-- 685 | } 686 | for i := offset; i < len(args); i++ { 687 | switch t := args[i].(type) { 688 | case *Condition: 689 | t.tableName = tableName 690 | conditions = append(conditions, t) 691 | case string, []string: 692 | return "", "", nil, fmt.Errorf("argument of %T type must be before the *Condition arguments", t) 693 | case *From: 694 | // ignore. 695 | case *Function: 696 | return "", "", nil, fmt.Errorf("%s function must be specified to the first argument", t.Name) 697 | default: 698 | return "", "", nil, fmt.Errorf("unsupported argument type: %T", t) 699 | } 700 | } 701 | if column == "" { 702 | column = ColumnName(db.dialect, tableName, "*") 703 | } 704 | return column, tableName, conditions, nil 705 | } 706 | 707 | // columns returns the comma-separated column name with quoted. 
708 | func (db *DB) columns(tableName string, columns []interface{}) string { 709 | if len(columns) == 0 { 710 | return ColumnName(db.dialect, tableName, "*") 711 | } 712 | names := make([]string, len(columns)) 713 | for i, col := range columns { 714 | switch c := col.(type) { 715 | case Raw: 716 | names[i] = fmt.Sprint(*c) 717 | case string: 718 | names[i] = ColumnName(db.dialect, tableName, c) 719 | case *Distinct: 720 | names[i] = fmt.Sprintf("DISTINCT %s", db.columns(tableName, ToInterfaceSlice(c.columns))) 721 | default: 722 | panic(fmt.Errorf("column name must be string, Raw or *Distinct, got %T", c)) 723 | } 724 | } 725 | return strings.Join(names, ", ") 726 | } 727 | 728 | func (db *DB) collectTableFields(t reflect.Type) (fields []string, err error) { 729 | for i := 0; i < t.NumField(); i++ { 730 | field := t.Field(i) 731 | if IsUnexportedField(field) { 732 | continue 733 | } 734 | if db.hasSkipTag(&field) { 735 | continue 736 | } 737 | if field.Anonymous { 738 | fs, err := db.collectTableFields(field.Type) 739 | if err != nil { 740 | return nil, err 741 | } 742 | fields = append(fields, fs...) 
743 | continue 744 | } 745 | var options []string 746 | autoIncrement := false 747 | for _, tag := range db.tagsFromField(&field) { 748 | switch tag { 749 | case "pk": 750 | options = append(options, "PRIMARY KEY") 751 | if db.isAutoIncrementable(&field) { 752 | options = append(options, db.dialect.AutoIncrement()) 753 | autoIncrement = true 754 | } 755 | case "unique": 756 | options = append(options, "UNIQUE") 757 | default: 758 | return nil, fmt.Errorf(`CreateTable: unsupported field tag: "%v"`, tag) 759 | } 760 | } 761 | size, err := db.sizeFromTag(&field) 762 | if err != nil { 763 | return nil, err 764 | } 765 | typName, allowNull := db.dialect.SQLType(reflect.Zero(field.Type).Interface(), autoIncrement, size) 766 | if !allowNull { 767 | options = append(options, "NOT NULL") 768 | } 769 | line := append([]string{db.dialect.Quote(db.columnFromTag(field)), typName}, options...) 770 | def, err := db.defaultFromTag(&field) 771 | if err != nil { 772 | return nil, err 773 | } 774 | if def != "" { 775 | line = append(line, def) 776 | } 777 | fields = append(fields, strings.Join(line, " ")) 778 | } 779 | return fields, nil 780 | } 781 | 782 | // tagsFromField returns a slice of option strings. 783 | func (db *DB) tagsFromField(field *reflect.StructField) (options []string) { 784 | if db.hasSkipTag(field) { 785 | return nil 786 | } 787 | for _, tag := range strings.Split(field.Tag.Get(dbTag), ",") { 788 | if t := strings.ToLower(strings.TrimSpace(tag)); t != "" { 789 | options = append(options, t) 790 | } 791 | } 792 | return options 793 | } 794 | 795 | // hasSkipTag returns whether the struct field has the "-" tag. 796 | func (db *DB) hasSkipTag(field *reflect.StructField) bool { 797 | if field.Tag.Get(dbTag) == skipTag { 798 | return true 799 | } 800 | return false 801 | } 802 | 803 | // hasPKTag returns whether the struct field has the "pk" tag. 
804 | func (db *DB) hasPKTag(field *reflect.StructField) bool { 805 | for _, tag := range db.tagsFromField(field) { 806 | if tag == "pk" { 807 | return true 808 | } 809 | } 810 | return false 811 | } 812 | 813 | // isAutoIncrementable returns whether the struct field is integer. 814 | func (db *DB) isAutoIncrementable(field *reflect.StructField) bool { 815 | switch field.Type.Kind() { 816 | case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64, 817 | reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: 818 | return true 819 | } 820 | return false 821 | } 822 | 823 | // collectFieldIndexes returns the indexes of field which doesn't have skip tag and pk tag. 824 | func (db *DB) collectFieldIndexes(typ reflect.Type, index []int) (indexes [][]int) { 825 | for i := 0; i < typ.NumField(); i++ { 826 | field := typ.Field(i) 827 | if IsUnexportedField(field) { 828 | continue 829 | } 830 | if !(db.hasSkipTag(&field) || (db.hasPKTag(&field) && db.isAutoIncrementable(&field))) { 831 | tmp := make([]int, len(index)+1) 832 | copy(tmp, index) 833 | tmp[len(tmp)-1] = i 834 | if field.Anonymous { 835 | indexes = append(indexes, db.collectFieldIndexes(field.Type, tmp)...) 836 | } else { 837 | indexes = append(indexes, tmp) 838 | } 839 | } 840 | } 841 | return indexes 842 | } 843 | 844 | // findPKIndex returns the nested field corresponding to the index sequence of field of primary key. 845 | func (db *DB) findPKIndex(typ reflect.Type, index []int) []int { 846 | for i := 0; i < typ.NumField(); i++ { 847 | field := typ.Field(i) 848 | if IsUnexportedField(field) { 849 | continue 850 | } 851 | if field.Anonymous { 852 | if idx := db.findPKIndex(field.Type, append(index, i)); idx != nil { 853 | return append(index, idx...) 854 | } 855 | continue 856 | } 857 | if db.hasPKTag(&field) { 858 | return append(index, i) 859 | } 860 | } 861 | return nil 862 | } 863 | 864 | // sizeFromTag returns a size from tag. 
865 | // If "size" tag specified to struct field, it will converted to uint64 and returns it. 866 | // If it doesn't specify, it returns 0. 867 | // If value of "size" tag cannot convert to uint64, it returns 0 and error. 868 | func (db *DB) sizeFromTag(field *reflect.StructField) (size uint64, err error) { 869 | if s := field.Tag.Get(dbSizeTag); s != "" { 870 | size, err = strconv.ParseUint(s, 10, 64) 871 | } 872 | return size, err 873 | } 874 | 875 | func (db *DB) tableName(t reflect.Type) string { 876 | if table, ok := reflect.New(t).Interface().(TableNamer); ok { 877 | return table.TableName() 878 | } 879 | return stringutil.ToSnakeCase(t.Name()) 880 | } 881 | 882 | // columnFromTag returns the column name. 883 | // If "column" tag specified to struct field, returns it. 884 | // Otherwise, it returns snake-cased field name as column name. 885 | func (db *DB) columnFromTag(field reflect.StructField) string { 886 | col := field.Tag.Get(dbColumnTag) 887 | if col == "" { 888 | return stringutil.ToSnakeCase(field.Name) 889 | } 890 | return col 891 | } 892 | 893 | // defaultFromTag returns a "DEFAULT ..." keyword. 894 | // If "default" tag specified to struct field, it use as the default value. 895 | // If it doesn't specify, it returns empty string. 
896 | func (db *DB) defaultFromTag(field *reflect.StructField) (string, error) { 897 | def := field.Tag.Get(dbDefaultTag) 898 | if def == "" { 899 | return "", nil 900 | } 901 | switch field.Type.Kind() { 902 | case reflect.Bool: 903 | b, err := strconv.ParseBool(def) 904 | if err != nil { 905 | return "", err 906 | } 907 | return fmt.Sprintf("DEFAULT %v", db.dialect.FormatBool(b)), nil 908 | } 909 | return fmt.Sprintf("DEFAULT %v", def), nil 910 | } 911 | 912 | func (db *DB) tableObjs(name string, obj interface{}) (objs []interface{}, rtype reflect.Type, tableName string, err error) { 913 | switch v := reflect.Indirect(reflect.ValueOf(obj)); v.Kind() { 914 | case reflect.Slice: 915 | if v.Len() < 1 { 916 | return objs, nil, "", nil 917 | } 918 | for i := 0; i < v.Len(); i++ { 919 | sv := v.Index(i) 920 | for sv.Kind() == reflect.Ptr { 921 | sv = sv.Elem() 922 | } 923 | if sv.Kind() == reflect.Interface { 924 | svk := reflect.Indirect(reflect.ValueOf(sv)).Kind() 925 | if svk != reflect.Struct { 926 | goto Error 927 | } 928 | objs = append(objs, sv.Interface()) 929 | } else { 930 | if sv.Kind() != reflect.Struct { 931 | goto Error 932 | } 933 | objs = append(objs, sv.Addr().Interface()) 934 | } 935 | } 936 | case reflect.Struct: 937 | if !v.CanAddr() { 938 | goto Error 939 | } 940 | objs = append(objs, v.Addr().Interface()) 941 | } 942 | _, rtype, tableName, err = db.tableValueOf(name, objs[0]) 943 | return objs, rtype, tableName, err 944 | Error: 945 | return nil, nil, "", fmt.Errorf("%s: argument must be pointer to struct or slice of struct, got %T", name, obj) 946 | } 947 | 948 | func (db *DB) tableValueOf(name string, table interface{}) (rv reflect.Value, rt reflect.Type, tableName string, err error) { 949 | rv = reflect.Indirect(reflect.ValueOf(table)) 950 | rt = rv.Type() 951 | if rt.Kind() != reflect.Struct { 952 | return rv, rt, "", fmt.Errorf("%s: a table must be struct type, got %v", name, rt) 953 | } 954 | tableName = db.tableName(rt) 955 | if tableName 
== "" { 956 | return rv, rt, "", fmt.Errorf("%s: a table isn't named", name) 957 | } 958 | return rv, rt, tableName, nil 959 | } 960 | 961 | func (db *DB) prepare(query string, args ...interface{}) (*sql.Stmt, error) { 962 | defer db.logger.Print(now(), query, args...) 963 | db.m.Lock() 964 | defer db.m.Unlock() 965 | if db.tx == nil { 966 | return db.db.Prepare(query) 967 | } else { 968 | return db.tx.Prepare(query) 969 | } 970 | } 971 | 972 | type selectFunc func(*sql.Rows, reflect.Type) (reflect.Value, error) 973 | 974 | // TableNamer is an interface that is used to use a different table name. 975 | type TableNamer interface { 976 | // TableName returns the table name on DB. 977 | TableName() string 978 | } 979 | 980 | // BeforeUpdater is an interface that hook for before Update. 981 | type BeforeUpdater interface { 982 | // BeforeUpdate called before an update by DB.Update. 983 | // If it returns error, the update will be cancelled. 984 | BeforeUpdate() error 985 | } 986 | 987 | // AfterUpdater is an interface that hook for after Update. 988 | type AfterUpdater interface { 989 | // AfterUpdate called after an update by DB.Update. 990 | AfterUpdate() error 991 | } 992 | 993 | // BeforeInserter is an interface that hook for before Insert. 994 | type BeforeInserter interface { 995 | // BeforeInsert called before an insert by DB.Insert. 996 | // If it returns error, the insert will be cancelled. 997 | BeforeInsert() error 998 | } 999 | 1000 | // AfterInserter is an interface that hook for after Insert. 1001 | type AfterInserter interface { 1002 | // AfterInsert called after an insert by DB.Insert. 1003 | AfterInsert() error 1004 | } 1005 | 1006 | // BeforeDeleter is an interface that hook for before Delete. 1007 | type BeforeDeleter interface { 1008 | // BeforeDelete called before a delete by DB.Delete. 1009 | // If it returns error, the delete will be cancelled. 
1010 | BeforeDelete() error 1011 | } 1012 | 1013 | // AfterDeleter is an interface that hook for after Delete. 1014 | type AfterDeleter interface { 1015 | // AfterDelete called after a delete by DB.Delete. 1016 | AfterDelete() error 1017 | } 1018 | 1019 | // Raw represents a raw value. 1020 | // Raw value won't quoted. 1021 | type Raw *interface{} 1022 | 1023 | // From represents a "FROM" statement. 1024 | type From struct { 1025 | TableName string 1026 | } 1027 | 1028 | // Distinct represents a "DISTINCT" statement. 1029 | type Distinct struct { 1030 | columns []string 1031 | } 1032 | 1033 | // Function represents a function of SQL. 1034 | type Function struct { 1035 | // A function name. 1036 | Name string 1037 | 1038 | // function arguments (optional). 1039 | Args []interface{} 1040 | } 1041 | 1042 | // Order represents a keyword for the "ORDER" clause of SQL. 1043 | type Order string 1044 | 1045 | const ( 1046 | ASC Order = "ASC" 1047 | DESC Order = "DESC" 1048 | ) 1049 | 1050 | func (o Order) String() string { 1051 | return string(o) 1052 | } 1053 | 1054 | // Clause represents a clause of SQL. 1055 | type Clause uint 1056 | 1057 | const ( 1058 | Where Clause = iota 1059 | And 1060 | Or 1061 | OrderBy 1062 | Limit 1063 | Offset 1064 | In 1065 | Like 1066 | Between 1067 | Join 1068 | LeftJoin 1069 | IsNull 1070 | IsNotNull 1071 | ) 1072 | 1073 | func (c Clause) String() string { 1074 | if int(c) >= len(clauseStrings) { 1075 | panic(fmt.Errorf("Clause %v is not defined", uint(c))) 1076 | } 1077 | return clauseStrings[c] 1078 | } 1079 | 1080 | var clauseStrings = []string{ 1081 | Where: "WHERE", 1082 | And: "AND", 1083 | Or: "OR", 1084 | OrderBy: "ORDER BY", 1085 | Limit: "LIMIT", 1086 | Offset: "OFFSET", 1087 | In: "IN", 1088 | Like: "LIKE", 1089 | Between: "BETWEEN", 1090 | Join: "JOIN", 1091 | LeftJoin: "LEFT JOIN", 1092 | IsNull: "IS NULL", 1093 | IsNotNull: "IS NOT NULL", 1094 | } 1095 | 1096 | // column represents a column name in query. 
1097 | type column struct { 1098 | table string // table name (optional). 1099 | name string // column name. 1100 | } 1101 | 1102 | // expr represents a expression in query. 1103 | type expr struct { 1104 | op string // operator. 1105 | column *column // column name. 1106 | value interface{} // value. 1107 | } 1108 | 1109 | // orderBy represents a "ORDER BY" query. 1110 | type orderBy struct { 1111 | column column // column name. 1112 | order Order // direction. 1113 | } 1114 | 1115 | // between represents a "BETWEEN" query. 1116 | type between struct { 1117 | from interface{} 1118 | to interface{} 1119 | } 1120 | 1121 | // Condition represents a condition for query. 1122 | type Condition struct { 1123 | db *DB 1124 | parts parts // parts of the query. 1125 | tableName string // table name (optional). 1126 | } 1127 | 1128 | // newCondition returns a new Condition with Dialect. 1129 | func newCondition(db *DB) *Condition { 1130 | return &Condition{db: db} 1131 | } 1132 | 1133 | // Where adds "WHERE" clause to the Condition and returns it for method chain. 1134 | func (c *Condition) Where(cond interface{}, args ...interface{}) *Condition { 1135 | return c.appendQueryByCondOrExpr("Where", 0, Where, cond, args...) 1136 | } 1137 | 1138 | // And adds "AND" operator to the Condition and returns it for method chain. 1139 | func (c *Condition) And(cond interface{}, args ...interface{}) *Condition { 1140 | return c.appendQueryByCondOrExpr("And", 100, And, cond, args...) 1141 | } 1142 | 1143 | // Or adds "OR" operator to the Condition and returns it for method chain. 1144 | func (c *Condition) Or(cond interface{}, args ...interface{}) *Condition { 1145 | return c.appendQueryByCondOrExpr("Or", 100, Or, cond, args...) 1146 | } 1147 | 1148 | // In adds "IN" clause to the Condition and returns it for method chain. 
1149 | func (c *Condition) In(args ...interface{}) *Condition { 1150 | return c.appendQuery(100, In, args) 1151 | } 1152 | 1153 | // Like adds "LIKE" clause to the Condition and returns it for method chain. 1154 | func (c *Condition) Like(arg string) *Condition { 1155 | return c.appendQuery(100, Like, arg) 1156 | } 1157 | 1158 | // Between adds "BETWEEN ... AND ..." clause to the Condition and returns it for method chain. 1159 | func (c *Condition) Between(from, to interface{}) *Condition { 1160 | return c.appendQuery(100, Between, &between{from, to}) 1161 | } 1162 | 1163 | // IsNull adds "IS NULL" clause to the Condition and returns it for method chain. 1164 | func (c *Condition) IsNull() *Condition { 1165 | return c.appendQuery(100, IsNull, nil) 1166 | } 1167 | 1168 | // IsNotNull adds "IS NOT NULL" clause to the Condition and returns it for method chain. 1169 | func (c *Condition) IsNotNull() *Condition { 1170 | return c.appendQuery(100, IsNotNull, nil) 1171 | } 1172 | 1173 | // OrderBy adds "ORDER BY" clause to the Condition and returns it for method chain. 1174 | func (c *Condition) OrderBy(table, col interface{}, order ...interface{}) *Condition { 1175 | order = append([]interface{}{table, col}, order...) 
1176 | orderbys := make([]orderBy, 0, 1) 1177 | for len(order) > 0 { 1178 | o, rest := order[0], order[1:] 1179 | if _, ok := o.(string); ok { 1180 | if len(rest) < 1 { 1181 | panic(fmt.Errorf("OrderBy: few arguments")) 1182 | } 1183 | // OrderBy("column", genmai.DESC) 1184 | orderbys = append(orderbys, c.orderBy(nil, o, rest[0])) 1185 | order = rest[1:] 1186 | continue 1187 | } 1188 | if len(rest) < 2 { 1189 | panic(fmt.Errorf("OrderBy: few arguments")) 1190 | } 1191 | // OrderBy(tbl{}, "column", genmai.DESC) 1192 | orderbys = append(orderbys, c.orderBy(o, rest[0], rest[1])) 1193 | order = rest[2:] 1194 | } 1195 | return c.appendQuery(300, OrderBy, orderbys) 1196 | } 1197 | 1198 | // Limit adds "LIMIT" clause to the Condition and returns it for method chain. 1199 | func (c *Condition) Limit(lim int) *Condition { 1200 | return c.appendQuery(500, Limit, lim) 1201 | } 1202 | 1203 | // Offset adds "OFFSET" clause to the Condition and returns it for method chain. 1204 | func (c *Condition) Offset(offset int) *Condition { 1205 | return c.appendQuery(700, Offset, offset) 1206 | } 1207 | 1208 | func (c *Condition) appendQuery(priority int, clause Clause, expr interface{}, args ...interface{}) *Condition { 1209 | c.parts = append(c.parts, part{ 1210 | clause: clause, 1211 | expr: expr, 1212 | priority: priority, 1213 | }) 1214 | return c 1215 | } 1216 | 1217 | func (c *Condition) appendQueryByCondOrExpr(name string, order int, clause Clause, cond interface{}, args ...interface{}) *Condition { 1218 | switch t := cond.(type) { 1219 | case string, *Condition: 1220 | args = append([]interface{}{t}, args...) 1221 | default: 1222 | v := reflect.Indirect(reflect.ValueOf(t)) 1223 | if v.Kind() != reflect.Struct { 1224 | panic(fmt.Errorf("%s: first argument must be string or struct, got %T", name, t)) 1225 | } 1226 | args = append([]interface{}{c.db.tableName(v.Type())}, args...) 
1227 | } 1228 | switch len(args) { 1229 | case 1: // Where(Where("id", "=", 1)) 1230 | switch t := args[0].(type) { 1231 | case *Condition: 1232 | cond = t 1233 | case string: 1234 | cond = &column{name: t} 1235 | default: 1236 | panic(fmt.Errorf("%s: first argument must be string or *Condition if args not given, got %T", name, t)) 1237 | } 1238 | case 2: // Where(&Table{}, "id") 1239 | cond = &column{ 1240 | table: fmt.Sprint(args[0]), 1241 | name: fmt.Sprint(args[1]), 1242 | } 1243 | case 3: // Where("id", "=", 1) 1244 | cond = &expr{ 1245 | op: fmt.Sprint(args[1]), 1246 | column: &column{ 1247 | name: fmt.Sprint(args[0]), 1248 | }, 1249 | value: args[2], 1250 | } 1251 | case 4: // Where(&Table{}, "id", "=", 1) 1252 | cond = &expr{ 1253 | op: fmt.Sprint(args[2]), 1254 | column: &column{ 1255 | table: fmt.Sprint(args[0]), 1256 | name: fmt.Sprint(args[1]), 1257 | }, 1258 | value: args[3], 1259 | } 1260 | default: 1261 | panic(fmt.Errorf("%s: arguments expect between 1 and 4, got %v", name, len(args))) 1262 | } 1263 | return c.appendQuery(order, clause, cond) 1264 | } 1265 | 1266 | func (c *Condition) orderBy(table, col, order interface{}) orderBy { 1267 | o := orderBy{ 1268 | column: column{ 1269 | name: fmt.Sprint(col), 1270 | }, 1271 | order: Order(fmt.Sprint(order)), 1272 | } 1273 | if table != nil { 1274 | rt := reflect.TypeOf(table) 1275 | for rt.Kind() == reflect.Ptr { 1276 | rt = rt.Elem() 1277 | } 1278 | o.column.table = c.db.tableName(rt) 1279 | } 1280 | return o 1281 | } 1282 | 1283 | func (c *Condition) build(numHolders int, inner bool) (queries []string, args []interface{}) { 1284 | sort.Sort(c.parts) 1285 | for _, p := range c.parts { 1286 | if !(inner && p.clause == Where) { 1287 | queries = append(queries, p.clause.String()) 1288 | } 1289 | switch e := p.expr.(type) { 1290 | case *expr: 1291 | col := ColumnName(c.db.dialect, e.column.table, e.column.name) 1292 | queries = append(queries, col, e.op, c.db.dialect.PlaceHolder(numHolders)) 1293 | args = 
append(args, e.value) 1294 | numHolders++ 1295 | case []orderBy: 1296 | o := e[0] 1297 | queries = append(queries, ColumnName(c.db.dialect, o.column.table, o.column.name), o.order.String()) 1298 | if len(e) > 1 { 1299 | for _, o := range e[1:] { 1300 | queries = append(queries, ",", ColumnName(c.db.dialect, o.column.table, o.column.name), o.order.String()) 1301 | } 1302 | } 1303 | case *column: 1304 | col := ColumnName(c.db.dialect, e.table, e.name) 1305 | queries = append(queries, col) 1306 | case []interface{}: 1307 | e = flatten(e) 1308 | holders := make([]string, len(e)) 1309 | for i := 0; i < len(e); i++ { 1310 | holders[i] = c.db.dialect.PlaceHolder(numHolders) 1311 | numHolders++ 1312 | } 1313 | queries = append(queries, "(", strings.Join(holders, ", "), ")") 1314 | args = append(args, e...) 1315 | case *between: 1316 | queries = append(queries, c.db.dialect.PlaceHolder(numHolders), "AND", c.db.dialect.PlaceHolder(numHolders+1)) 1317 | args = append(args, e.from, e.to) 1318 | numHolders += 2 1319 | case *Condition: 1320 | q, a := e.build(numHolders, true) 1321 | queries = append(append(append(queries, "("), q...), ")") 1322 | args = append(args, a...) 1323 | case *JoinCondition: 1324 | var leftTableName string 1325 | if e.leftTableName == "" { 1326 | leftTableName = c.tableName 1327 | } else { 1328 | leftTableName = e.leftTableName 1329 | } 1330 | queries = append(queries, 1331 | c.db.dialect.Quote(e.tableName), "ON", 1332 | ColumnName(c.db.dialect, leftTableName, e.left), e.op, ColumnName(c.db.dialect, e.tableName, e.right)) 1333 | case nil: 1334 | // ignore. 1335 | default: 1336 | queries = append(queries, c.db.dialect.PlaceHolder(numHolders)) 1337 | args = append(args, e) 1338 | numHolders++ 1339 | } 1340 | } 1341 | return queries, args 1342 | } 1343 | 1344 | // JoinCondition represents a condition of "JOIN" query. 1345 | type JoinCondition struct { 1346 | db *DB 1347 | leftTableName string // A table name of 'to be joined'. 
1348 | tableName string // A table name of 'to join'. 1349 | op string // A operator of expression in "ON" clause. 1350 | left string // A left column name of operator. 1351 | right string // A right column name of operator. 1352 | clause Clause // A type of join clause ("JOIN" or "LEFT JOIN") 1353 | } 1354 | 1355 | // Join adds table name to the JoinCondition of "JOIN". 1356 | // If table isn't direct/indirect struct type, it panics. 1357 | func (jc *JoinCondition) Join(table interface{}) *JoinCondition { 1358 | return jc.join(Join, table) 1359 | } 1360 | 1361 | // LeftJoin adds table name to the JoinCondition of "LEFT JOIN". 1362 | // If table isn't direct/indirect struct type, it panics. 1363 | func (jc *JoinCondition) LeftJoin(table interface{}) *JoinCondition { 1364 | return jc.join(LeftJoin, table) 1365 | } 1366 | 1367 | // On adds "[LEFT] JOIN ... ON" clause to the Condition and returns it for method chain. 1368 | func (jc *JoinCondition) On(larg interface{}, args ...string) *Condition { 1369 | var lcolumn string 1370 | switch rv := reflect.ValueOf(larg); rv.Kind() { 1371 | case reflect.String: 1372 | lcolumn = rv.String() 1373 | default: 1374 | for rv.Kind() == reflect.Ptr { 1375 | rv = rv.Elem() 1376 | } 1377 | if rv.Kind() != reflect.Struct { 1378 | panic(fmt.Errorf("On: first argument must be string or struct, got %v", rv.Type())) 1379 | } 1380 | jc.leftTableName = jc.db.tableName(rv.Type()) 1381 | lcolumn, args = args[0], args[1:] 1382 | } 1383 | switch len(args) { 1384 | case 0: 1385 | jc.left, jc.op, jc.right = lcolumn, "=", lcolumn 1386 | case 2: 1387 | jc.left, jc.op, jc.right = lcolumn, args[0], args[1] 1388 | default: 1389 | panic(fmt.Errorf("On: arguments expect 1 or 3, got %v", len(args)+1)) 1390 | } 1391 | c := newCondition(jc.db) 1392 | c.parts = append(c.parts, part{ 1393 | clause: jc.clause, 1394 | expr: jc, 1395 | priority: -100, 1396 | }) 1397 | return c 1398 | } 1399 | 1400 | func (jc *JoinCondition) join(joinClause Clause, table 
interface{}) *JoinCondition { 1401 | t := reflect.Indirect(reflect.ValueOf(table)).Type() 1402 | if t.Kind() != reflect.Struct { 1403 | panic(fmt.Errorf("%v: a table must be struct type, got %v", joinClause, t)) 1404 | } 1405 | jc.tableName = jc.db.tableName(t) 1406 | jc.clause = joinClause 1407 | return jc 1408 | } 1409 | 1410 | // part represents a part of query. 1411 | type part struct { 1412 | clause Clause 1413 | expr interface{} 1414 | 1415 | // a order for sort. A lower value is a high-priority. 1416 | priority int 1417 | } 1418 | 1419 | // parts is for sort.Interface. 1420 | type parts []part 1421 | 1422 | func (ps parts) Len() int { 1423 | return len(ps) 1424 | } 1425 | 1426 | func (ps parts) Less(i, j int) bool { 1427 | return ps[i].priority < ps[j].priority 1428 | } 1429 | 1430 | func (ps parts) Swap(i, j int) { 1431 | ps[i], ps[j] = ps[j], ps[i] 1432 | } 1433 | -------------------------------------------------------------------------------- /dialect_test.go: -------------------------------------------------------------------------------- 1 | package genmai 2 | 3 | import ( 4 | "database/sql" 5 | "reflect" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | func Test_SQLite3Dialect_Name(t *testing.T) { 11 | d := &SQLite3Dialect{} 12 | actual := d.Name() 13 | expected := "sqlite3" 14 | if !reflect.DeepEqual(actual, expected) { 15 | t.Errorf("Expect %q, but %q", expected, actual) 16 | } 17 | } 18 | 19 | func Test_SQLite3Dialect_Quote(t *testing.T) { 20 | d := &SQLite3Dialect{} 21 | for _, v := range []struct { 22 | s, expected string 23 | }{ 24 | {``, `""`}, 25 | {`test`, `"test"`}, 26 | {`"test"`, `"""test"""`}, 27 | {`test"bar"baz`, `"test""bar""baz"`}, 28 | } { 29 | actual := d.Quote(v.s) 30 | expected := v.expected 31 | if !reflect.DeepEqual(actual, expected) { 32 | t.Errorf("Input %q expects %q, but %q", v.s, expected, actual) 33 | } 34 | } 35 | } 36 | 37 | func Test_SQLite3Dialect_PlaceHolder(t *testing.T) { 38 | d := &SQLite3Dialect{} 39 | actual := 
d.PlaceHolder(0) 40 | expected := "?" 41 | if !reflect.DeepEqual(actual, expected) { 42 | t.Errorf("Expect %q, but %q", expected, actual) 43 | } 44 | } 45 | 46 | func TestSQLite3Dialect_SQLType_boolDirect(t *testing.T) { 47 | d := &SQLite3Dialect{} 48 | sets := []interface{}{true, false} 49 | 50 | // autoIncrement is false. 51 | for _, v := range sets { 52 | name, null := d.SQLType(v, false, 0) 53 | actual := []interface{}{name, null} 54 | expected := []interface{}{"boolean", false} 55 | if !reflect.DeepEqual(actual, expected) { 56 | t.Errorf("%T expects %q, but %q", v, expected, actual) 57 | } 58 | } 59 | 60 | // autoIncrement is true. 61 | for _, v := range sets { 62 | name, null := d.SQLType(v, true, 0) 63 | actual := []interface{}{name, null} 64 | expected := []interface{}{"boolean", false} 65 | if !reflect.DeepEqual(actual, expected) { 66 | t.Errorf("%T expects %q, but %q", v, expected, actual) 67 | } 68 | } 69 | } 70 | 71 | func TestSQLite3Dialect_SQLType_boolIndirect(t *testing.T) { 72 | d := &SQLite3Dialect{} 73 | sets := []interface{}{new(bool), sql.NullBool{}} 74 | 75 | // autoIncrement is false. 76 | for _, v := range sets { 77 | name, null := d.SQLType(v, false, 0) 78 | actual := []interface{}{name, null} 79 | expected := []interface{}{"boolean", true} 80 | if !reflect.DeepEqual(actual, expected) { 81 | t.Errorf("%T expects %q, but %q", v, expected, actual) 82 | } 83 | } 84 | 85 | // autoIncrement is true. 86 | for _, v := range sets { 87 | name, null := d.SQLType(v, true, 0) 88 | actual := []interface{}{name, null} 89 | expected := []interface{}{"boolean", true} 90 | if !reflect.DeepEqual(actual, expected) { 91 | t.Errorf("%T expects %q, but %q", v, expected, actual) 92 | } 93 | } 94 | } 95 | 96 | func TestSQLite3Dialect_SQLType_primitiveFloat(t *testing.T) { 97 | d := &SQLite3Dialect{} 98 | sets := []interface{}{float32(.1), new(float32), float64(.1), new(float64), sql.NullFloat64{}} 99 | 100 | // autoIncrement is false. 
101 | for _, v := range sets { 102 | func(v interface{}) { 103 | defer func() { 104 | if err := recover(); err == nil { 105 | t.Errorf("panic hasn't been occurred by %T", v) 106 | } 107 | }() 108 | d.SQLType(v, false, 0) 109 | }(v) 110 | } 111 | 112 | // autoIncrement is true. 113 | for _, v := range sets { 114 | func(v interface{}) { 115 | defer func() { 116 | if err := recover(); err == nil { 117 | t.Errorf("panic hasn't been occurred by %T", v) 118 | } 119 | }() 120 | d.SQLType(v, false, 0) 121 | }(v) 122 | } 123 | } 124 | 125 | func TestSQLite3Dialect_SQLType_intDirect(t *testing.T) { 126 | d := &SQLite3Dialect{} 127 | sets := []interface{}{ 128 | int(1), int8(1), int16(1), int32(1), int64(1), uint(1), uint8(1), 129 | uint16(1), uint32(1), uint64(1), 130 | } 131 | 132 | // autoIncrement is false. 133 | for _, v := range sets { 134 | name, null := d.SQLType(v, false, 0) 135 | actual := []interface{}{name, null} 136 | expected := []interface{}{"integer", false} 137 | if !reflect.DeepEqual(actual, expected) { 138 | t.Errorf("%T expects %q, but %q", v, expected, actual) 139 | } 140 | } 141 | 142 | // autoIncrement is true. 143 | for _, v := range sets { 144 | name, null := d.SQLType(v, true, 0) 145 | actual := []interface{}{name, null} 146 | expected := []interface{}{"integer", false} 147 | if !reflect.DeepEqual(actual, expected) { 148 | t.Errorf("%T expects %q, but %q", v, expected, actual) 149 | } 150 | } 151 | } 152 | 153 | func TestSQLite3Dialect_SQLType_intIndirect(t *testing.T) { 154 | d := &SQLite3Dialect{} 155 | sets := []interface{}{ 156 | new(int), new(int8), new(int16), new(int32), new(int64), new(uint), 157 | new(uint8), new(uint16), new(uint32), new(uint64), sql.NullInt64{}} 158 | 159 | // autoIncrement is false. 
160 | for _, v := range sets { 161 | name, null := d.SQLType(v, false, 0) 162 | actual := []interface{}{name, null} 163 | expected := []interface{}{"integer", true} 164 | if !reflect.DeepEqual(actual, expected) { 165 | t.Errorf("%T expects %q, but %q", v, expected, actual) 166 | } 167 | } 168 | 169 | // autoIncrement is true. 170 | for _, v := range sets { 171 | name, null := d.SQLType(v, true, 0) 172 | actual := []interface{}{name, null} 173 | expected := []interface{}{"integer", true} 174 | if !reflect.DeepEqual(actual, expected) { 175 | t.Errorf("%T expects %q, but %q", v, expected, actual) 176 | } 177 | } 178 | } 179 | 180 | func TestSQLite3Dialect_SQLType_stringDirect(t *testing.T) { 181 | d := &SQLite3Dialect{} 182 | sets := []interface{}{""} 183 | 184 | // autoIncrement is false. 185 | for _, v := range sets { 186 | name, null := d.SQLType(v, false, 0) 187 | actual := []interface{}{name, null} 188 | expected := []interface{}{"text", false} 189 | if !reflect.DeepEqual(actual, expected) { 190 | t.Errorf("%T expects %q, but %q", v, expected, actual) 191 | } 192 | } 193 | 194 | // autoIncrement is true. 195 | for _, v := range sets { 196 | name, null := d.SQLType(v, true, 0) 197 | actual := []interface{}{name, null} 198 | expected := []interface{}{"text", false} 199 | if !reflect.DeepEqual(actual, expected) { 200 | t.Errorf("%T expects %q, but %q", v, expected, actual) 201 | } 202 | } 203 | } 204 | 205 | func TestSQLite3Dialect_SQLType_stringIndirect(t *testing.T) { 206 | d := &SQLite3Dialect{} 207 | sets := []interface{}{new(string), sql.NullString{}} 208 | 209 | // autoIncrement is false. 210 | for _, v := range sets { 211 | name, null := d.SQLType(v, false, 0) 212 | actual := []interface{}{name, null} 213 | expected := []interface{}{"text", true} 214 | if !reflect.DeepEqual(actual, expected) { 215 | t.Errorf("%T expects %q, but %q", v, expected, actual) 216 | } 217 | } 218 | 219 | // autoIncrement is true. 
220 | for _, v := range sets { 221 | name, null := d.SQLType(v, true, 0) 222 | actual := []interface{}{name, null} 223 | expected := []interface{}{"text", true} 224 | if !reflect.DeepEqual(actual, expected) { 225 | t.Errorf("%T expects %q, but %q", v, expected, actual) 226 | } 227 | } 228 | } 229 | 230 | func TestSQLite3Dialect_SQLType_byteSlice(t *testing.T) { 231 | d := &SQLite3Dialect{} 232 | sets := []interface{}{[]byte("")} 233 | 234 | // autoIncrement is false. 235 | for _, v := range sets { 236 | name, null := d.SQLType(v, false, 0) 237 | actual := []interface{}{name, null} 238 | expected := []interface{}{"blob", true} 239 | if !reflect.DeepEqual(actual, expected) { 240 | t.Errorf("%T expects %q, but %q", v, expected, actual) 241 | } 242 | } 243 | 244 | // autoIncrement is true. 245 | for _, v := range sets { 246 | name, null := d.SQLType(v, true, 0) 247 | actual := []interface{}{name, null} 248 | expected := []interface{}{"blob", true} 249 | if !reflect.DeepEqual(actual, expected) { 250 | t.Errorf("%T expects %q, but %q", v, expected, actual) 251 | } 252 | } 253 | } 254 | 255 | func TestSQLite3Dialect_SQLType_timeDirect(t *testing.T) { 256 | d := &SQLite3Dialect{} 257 | sets := []interface{}{time.Time{}} 258 | 259 | // autoIncrement is false. 260 | for _, v := range sets { 261 | name, null := d.SQLType(v, false, 0) 262 | actual := []interface{}{name, null} 263 | expected := []interface{}{"datetime", false} 264 | if !reflect.DeepEqual(actual, expected) { 265 | t.Errorf("%T expects %q, but %q", v, expected, actual) 266 | } 267 | } 268 | 269 | // autoIncrement is true. 
270 | for _, v := range sets { 271 | name, null := d.SQLType(v, true, 0) 272 | actual := []interface{}{name, null} 273 | expected := []interface{}{"datetime", false} 274 | if !reflect.DeepEqual(actual, expected) { 275 | t.Errorf("%T expects %q, but %q", v, expected, actual) 276 | } 277 | } 278 | } 279 | 280 | func TestSQLite3Dialect_SQLType_timeIndirect(t *testing.T) { 281 | d := &SQLite3Dialect{} 282 | sets := []interface{}{&time.Time{}} 283 | 284 | // autoIncrement is false. 285 | for _, v := range sets { 286 | name, null := d.SQLType(v, false, 0) 287 | actual := []interface{}{name, null} 288 | expected := []interface{}{"datetime", true} 289 | if !reflect.DeepEqual(actual, expected) { 290 | t.Errorf("%T expects %q, but %q", v, expected, actual) 291 | } 292 | } 293 | 294 | // autoIncrement is true. 295 | for _, v := range sets { 296 | name, null := d.SQLType(v, true, 0) 297 | actual := []interface{}{name, null} 298 | expected := []interface{}{"datetime", true} 299 | if !reflect.DeepEqual(actual, expected) { 300 | t.Errorf("%T expects %q, but %q", v, expected, actual) 301 | } 302 | } 303 | } 304 | 305 | func TestSQLite3Dialect_SQLType_floatDirect(t *testing.T) { 306 | d := &SQLite3Dialect{} 307 | sets := []interface{}{Float32(.1), Float64(.1)} 308 | 309 | // autoIncrement is false. 310 | for _, v := range sets { 311 | name, null := d.SQLType(v, false, 0) 312 | actual := []interface{}{name, null} 313 | expected := []interface{}{"real", false} 314 | if !reflect.DeepEqual(actual, expected) { 315 | t.Errorf("%T expects %q, but %q", v, expected, actual) 316 | } 317 | } 318 | 319 | // autoIncrement is true. 
320 | for _, v := range sets { 321 | name, null := d.SQLType(v, true, 0) 322 | actual := []interface{}{name, null} 323 | expected := []interface{}{"real", false} 324 | if !reflect.DeepEqual(actual, expected) { 325 | t.Errorf("%T expects %q, but %q", v, expected, actual) 326 | } 327 | } 328 | } 329 | 330 | func TestSQLite3Dialect_SQLType_floatIndirect(t *testing.T) { 331 | d := &SQLite3Dialect{} 332 | sets := []interface{}{new(Float32), new(Float64)} 333 | 334 | // autoIncrement is false. 335 | for _, v := range sets { 336 | name, null := d.SQLType(v, false, 0) 337 | actual := []interface{}{name, null} 338 | expected := []interface{}{"real", true} 339 | if !reflect.DeepEqual(actual, expected) { 340 | t.Errorf("%T expects %q, but %q", v, expected, actual) 341 | } 342 | } 343 | 344 | // autoIncrement is true. 345 | for _, v := range sets { 346 | name, null := d.SQLType(v, true, 0) 347 | actual := []interface{}{name, null} 348 | expected := []interface{}{"real", true} 349 | if !reflect.DeepEqual(actual, expected) { 350 | t.Errorf("%T expects %q, but %q", v, expected, actual) 351 | } 352 | } 353 | } 354 | 355 | func TestSQLite3Dialect_SQLType_ratDirect(t *testing.T) { 356 | d := &SQLite3Dialect{} 357 | sets := []interface{}{Rat{}} 358 | 359 | // autoIncrement is false. 360 | for _, v := range sets { 361 | name, null := d.SQLType(v, false, 0) 362 | actual := []interface{}{name, null} 363 | expected := []interface{}{"numeric", false} 364 | if !reflect.DeepEqual(actual, expected) { 365 | t.Errorf("%T expects %q, but %q", v, expected, actual) 366 | } 367 | } 368 | 369 | // autoIncrement is true. 
370 | for _, v := range sets { 371 | name, null := d.SQLType(v, true, 0) 372 | actual := []interface{}{name, null} 373 | expected := []interface{}{"numeric", false} 374 | if !reflect.DeepEqual(actual, expected) { 375 | t.Errorf("%T expects %q, but %q", v, expected, actual) 376 | } 377 | } 378 | } 379 | 380 | func TestSQLite3Dialect_SQLType_ratIndirect(t *testing.T) { 381 | d := &SQLite3Dialect{} 382 | sets := []interface{}{new(Rat)} 383 | 384 | // autoIncrement is false. 385 | for _, v := range sets { 386 | name, null := d.SQLType(v, false, 0) 387 | actual := []interface{}{name, null} 388 | expected := []interface{}{"numeric", true} 389 | if !reflect.DeepEqual(actual, expected) { 390 | t.Errorf("%T expects %q, but %q", v, expected, actual) 391 | } 392 | } 393 | 394 | // autoIncrement is true. 395 | for _, v := range sets { 396 | name, null := d.SQLType(v, true, 0) 397 | actual := []interface{}{name, null} 398 | expected := []interface{}{"numeric", true} 399 | if !reflect.DeepEqual(actual, expected) { 400 | t.Errorf("%T expects %q, but %q", v, expected, actual) 401 | } 402 | } 403 | } 404 | 405 | func TestSQLite3Dialect_AutoIncrement(t *testing.T) { 406 | d := &SQLite3Dialect{} 407 | actual := d.AutoIncrement() 408 | expected := "AUTOINCREMENT" 409 | if !reflect.DeepEqual(actual, expected) { 410 | t.Errorf("Expect %q, but %q", expected, actual) 411 | } 412 | } 413 | 414 | func TestSQLite3Dialect_FormatBool(t *testing.T) { 415 | d := &SQLite3Dialect{} 416 | actual := d.FormatBool(true) 417 | expected := "1" 418 | if !reflect.DeepEqual(actual, expected) { 419 | t.Errorf("Expect %q, but %q", expected, actual) 420 | } 421 | 422 | actual = d.FormatBool(false) 423 | expected = "0" 424 | if !reflect.DeepEqual(actual, expected) { 425 | t.Errorf("Expect %q, but %q", expected, actual) 426 | } 427 | } 428 | 429 | func TestSQLite3Dialect_LastInsertID(t *testing.T) { 430 | d := &SQLite3Dialect{} 431 | actual := d.LastInsertId() 432 | expect := "SELECT last_insert_rowid()" 433 | 
if !reflect.DeepEqual(actual, expect) { 434 | t.Errorf(`SQLite3Dialect.LastInsertId() => %#v; want %#v`, actual, expect) 435 | } 436 | } 437 | 438 | func Test_MySQLDialect_Name(t *testing.T) { 439 | d := &MySQLDialect{} 440 | actual := d.Name() 441 | expected := "mysql" 442 | if !reflect.DeepEqual(actual, expected) { 443 | t.Errorf("Expect %q, but %q", expected, actual) 444 | } 445 | } 446 | 447 | func Test_MySQLDialect_Quote(t *testing.T) { 448 | d := &MySQLDialect{} 449 | for _, v := range []struct { 450 | s, expected string 451 | }{ 452 | {"", "``"}, 453 | {"test", "`test`"}, 454 | {"`test`", "```test```"}, 455 | {"test`bar`baz", "`test``bar``baz`"}, 456 | } { 457 | actual := d.Quote(v.s) 458 | expected := v.expected 459 | if !reflect.DeepEqual(actual, expected) { 460 | t.Errorf("Input %q expects %q, but %q", v.s, expected, actual) 461 | } 462 | } 463 | } 464 | 465 | func Test_MySQLDialect_PlaceHolder(t *testing.T) { 466 | d := &MySQLDialect{} 467 | actual := d.PlaceHolder(0) 468 | expected := "?" 469 | if !reflect.DeepEqual(actual, expected) { 470 | t.Errorf("Expect %q, but %q", expected, actual) 471 | } 472 | } 473 | 474 | func TestMySQLDialect_SQLType_boolDirect(t *testing.T) { 475 | d := &MySQLDialect{} 476 | sets := []interface{}{true, false} 477 | 478 | // autoIncrement is false. 479 | for _, v := range sets { 480 | name, null := d.SQLType(v, false, 0) 481 | actual := []interface{}{name, null} 482 | expected := []interface{}{"BOOLEAN", false} 483 | if !reflect.DeepEqual(actual, expected) { 484 | t.Errorf("%T expects %q, but %q", v, expected, actual) 485 | } 486 | } 487 | 488 | // autoIncrement is true. 
489 | for _, v := range sets { 490 | name, null := d.SQLType(v, true, 0) 491 | actual := []interface{}{name, null} 492 | expected := []interface{}{"BOOLEAN", false} 493 | if !reflect.DeepEqual(actual, expected) { 494 | t.Errorf("%T expects %q, but %q", v, expected, actual) 495 | } 496 | } 497 | } 498 | 499 | func TestMySQLDialect_SQLType_boolIndirect(t *testing.T) { 500 | d := &MySQLDialect{} 501 | sets := []interface{}{new(bool), sql.NullBool{}} 502 | 503 | // autoIncrement is false. 504 | for _, v := range sets { 505 | name, null := d.SQLType(v, false, 0) 506 | actual := []interface{}{name, null} 507 | expected := []interface{}{"BOOLEAN", true} 508 | if !reflect.DeepEqual(actual, expected) { 509 | t.Errorf("%T expects %q, but %q", v, expected, actual) 510 | } 511 | } 512 | 513 | // autoIncrement is true. 514 | for _, v := range sets { 515 | name, null := d.SQLType(v, true, 0) 516 | actual := []interface{}{name, null} 517 | expected := []interface{}{"BOOLEAN", true} 518 | if !reflect.DeepEqual(actual, expected) { 519 | t.Errorf("%T expects %q, but %q", v, expected, actual) 520 | } 521 | } 522 | } 523 | 524 | func TestMySQLDialect_SQLType_primitiveFloat(t *testing.T) { 525 | d := &MySQLDialect{} 526 | sets := []interface{}{float32(.1), new(float32), float64(.1), new(float64), sql.NullFloat64{}} 527 | 528 | // autoIncrement is false. 529 | for _, v := range sets { 530 | func(v interface{}) { 531 | defer func() { 532 | if err := recover(); err == nil { 533 | t.Errorf("panic hasn't been occurred by %T", v) 534 | } 535 | }() 536 | d.SQLType(v, false, 0) 537 | }(v) 538 | } 539 | 540 | // autoIncrement is true. 
541 | for _, v := range sets { 542 | func(v interface{}) { 543 | defer func() { 544 | if err := recover(); err == nil { 545 | t.Errorf("panic hasn't been occurred by %T", v) 546 | } 547 | }() 548 | d.SQLType(v, false, 0) 549 | }(v) 550 | } 551 | } 552 | 553 | func TestMySQLDialect_SQLType_underInt16Direct(t *testing.T) { 554 | d := &MySQLDialect{} 555 | sets := []interface{}{int8(1), int16(1), uint8(1), uint16(1)} 556 | 557 | // autoIncrement is false. 558 | for _, v := range sets { 559 | name, null := d.SQLType(v, false, 0) 560 | actual := []interface{}{name, null} 561 | expected := []interface{}{"SMALLINT", false} 562 | if !reflect.DeepEqual(actual, expected) { 563 | t.Errorf("%T expects %q, but %q", v, expected, actual) 564 | } 565 | } 566 | 567 | // autoIncrement is true. 568 | for _, v := range sets { 569 | name, null := d.SQLType(v, true, 0) 570 | actual := []interface{}{name, null} 571 | expected := []interface{}{"SMALLINT", false} 572 | if !reflect.DeepEqual(actual, expected) { 573 | t.Errorf("%T expects %q, but %q", v, expected, actual) 574 | } 575 | } 576 | } 577 | 578 | func TestMySQLDialect_SQLType_underInt16Indirect(t *testing.T) { 579 | d := &MySQLDialect{} 580 | sets := []interface{}{new(int8), new(int16), new(uint8), new(uint16)} 581 | 582 | // autoIncrement is false. 583 | for _, v := range sets { 584 | name, null := d.SQLType(v, false, 0) 585 | actual := []interface{}{name, null} 586 | expected := []interface{}{"SMALLINT", true} 587 | if !reflect.DeepEqual(actual, expected) { 588 | t.Errorf("%T expects %q, but %q", v, expected, actual) 589 | } 590 | } 591 | 592 | // autoIncrement is true. 
593 | for _, v := range sets { 594 | name, null := d.SQLType(v, true, 0) 595 | actual := []interface{}{name, null} 596 | expected := []interface{}{"SMALLINT", true} 597 | if !reflect.DeepEqual(actual, expected) { 598 | t.Errorf("%T expects %q, but %q", v, expected, actual) 599 | } 600 | } 601 | } 602 | 603 | func TestMySQLDialect_SQLType_intDirect(t *testing.T) { 604 | d := &MySQLDialect{} 605 | sets := []interface{}{int(1), int32(1), uint(1), uint32(1)} 606 | 607 | // autoIncrement is false. 608 | for _, v := range sets { 609 | name, null := d.SQLType(v, false, 0) 610 | actual := []interface{}{name, null} 611 | expected := []interface{}{"INT", false} 612 | if !reflect.DeepEqual(actual, expected) { 613 | t.Errorf("%T expects %q, but %q", v, expected, actual) 614 | } 615 | } 616 | 617 | // autoIncrement is true. 618 | for _, v := range sets { 619 | name, null := d.SQLType(v, true, 0) 620 | actual := []interface{}{name, null} 621 | expected := []interface{}{"INT", false} 622 | if !reflect.DeepEqual(actual, expected) { 623 | t.Errorf("%T expects %q, but %q", v, expected, actual) 624 | } 625 | } 626 | } 627 | 628 | func TestMySQLDialect_SQLType_intIndirect(t *testing.T) { 629 | d := &MySQLDialect{} 630 | sets := []interface{}{new(int), new(int32), new(uint), new(uint32)} 631 | 632 | // autoIncrement is false. 633 | for _, v := range sets { 634 | name, null := d.SQLType(v, false, 0) 635 | actual := []interface{}{name, null} 636 | expected := []interface{}{"INT", true} 637 | if !reflect.DeepEqual(actual, expected) { 638 | t.Errorf("%T expects %q, but %q", v, expected, actual) 639 | } 640 | } 641 | 642 | // autoIncrement is true. 
643 | for _, v := range sets { 644 | name, null := d.SQLType(v, true, 0) 645 | actual := []interface{}{name, null} 646 | expected := []interface{}{"INT", true} 647 | if !reflect.DeepEqual(actual, expected) { 648 | t.Errorf("%T expects %q, but %q", v, expected, actual) 649 | } 650 | } 651 | } 652 | 653 | func TestMySQLDialect_SQLType_int64Direct(t *testing.T) { 654 | d := &MySQLDialect{} 655 | sets := []interface{}{int64(1), uint64(1)} 656 | 657 | // autoIncrement is false. 658 | for _, v := range sets { 659 | name, null := d.SQLType(v, false, 0) 660 | actual := []interface{}{name, null} 661 | expected := []interface{}{"BIGINT", false} 662 | if !reflect.DeepEqual(actual, expected) { 663 | t.Errorf("%T expects %q, but %q", v, expected, actual) 664 | } 665 | } 666 | 667 | // autoIncrement is true. 668 | for _, v := range sets { 669 | name, null := d.SQLType(v, true, 0) 670 | actual := []interface{}{name, null} 671 | expected := []interface{}{"BIGINT", false} 672 | if !reflect.DeepEqual(actual, expected) { 673 | t.Errorf("%T expects %q, but %q", v, expected, actual) 674 | } 675 | } 676 | } 677 | 678 | func TestMySQLDialect_SQLType_int64Indirect(t *testing.T) { 679 | d := &MySQLDialect{} 680 | sets := []interface{}{new(int64), new(uint64), sql.NullInt64{}} 681 | 682 | // autoIncrement is false. 683 | for _, v := range sets { 684 | name, null := d.SQLType(v, false, 0) 685 | actual := []interface{}{name, null} 686 | expected := []interface{}{"BIGINT", true} 687 | if !reflect.DeepEqual(actual, expected) { 688 | t.Errorf("%T expects %q, but %q", v, expected, actual) 689 | } 690 | } 691 | 692 | // autoIncrement is true. 
693 | for _, v := range sets { 694 | name, null := d.SQLType(v, true, 0) 695 | actual := []interface{}{name, null} 696 | expected := []interface{}{"BIGINT", true} 697 | if !reflect.DeepEqual(actual, expected) { 698 | t.Errorf("%T expects %q, but %q", v, expected, actual) 699 | } 700 | } 701 | } 702 | 703 | func TestMySQLDialect_SQLType_stringDirect(t *testing.T) { 704 | d := &MySQLDialect{} 705 | sets := []interface{}{""} 706 | 707 | func() { 708 | // autoIncrement is false. 709 | for _, v := range sets { 710 | name, null := d.SQLType(v, false, 0) 711 | actual := []interface{}{name, null} 712 | expected := []interface{}{"VARCHAR(255)", false} 713 | if !reflect.DeepEqual(actual, expected) { 714 | t.Errorf("%T expects %q, but %q", v, expected, actual) 715 | } 716 | } 717 | 718 | // autoIncrement is true. 719 | for _, v := range sets { 720 | name, null := d.SQLType(v, true, 0) 721 | actual := []interface{}{name, null} 722 | expected := []interface{}{"VARCHAR(255)", false} 723 | if !reflect.DeepEqual(actual, expected) { 724 | t.Errorf("%T expects %q, but %q", v, expected, actual) 725 | } 726 | } 727 | }() 728 | 729 | func() { 730 | for _, v := range sets { 731 | name, null := d.SQLType(v, false, 1) 732 | actual := []interface{}{name, null} 733 | expected := []interface{}{"VARCHAR(1)", false} 734 | if !reflect.DeepEqual(actual, expected) { 735 | t.Errorf("Expect %q, but %q", expected, actual) 736 | } 737 | } 738 | }() 739 | 740 | func() { 741 | for _, v := range sets { 742 | name, null := d.SQLType(v, false, 2) 743 | actual := []interface{}{name, null} 744 | expected := []interface{}{"VARCHAR(2)", false} 745 | if !reflect.DeepEqual(actual, expected) { 746 | t.Errorf("Expect %q, but %q", expected, actual) 747 | } 748 | } 749 | }() 750 | 751 | func() { 752 | for _, v := range sets { 753 | name, null := d.SQLType(v, false, 65532) 754 | actual := []interface{}{name, null} 755 | expected := []interface{}{"VARCHAR(65532)", false} 756 | if !reflect.DeepEqual(actual, expected) 
{ 757 | t.Errorf("Expect %q, but %q", expected, actual) 758 | } 759 | } 760 | }() 761 | 762 | func() { 763 | for _, v := range sets { 764 | name, null := d.SQLType(v, false, 65533) 765 | actual := []interface{}{name, null} 766 | expected := []interface{}{"MEDIUMTEXT", false} 767 | if !reflect.DeepEqual(actual, expected) { 768 | t.Errorf("Expect %q, but %q", expected, actual) 769 | } 770 | } 771 | }() 772 | 773 | func() { 774 | for _, v := range sets { 775 | name, null := d.SQLType(v, false, 16777215) 776 | actual := []interface{}{name, null} 777 | expected := []interface{}{"MEDIUMTEXT", false} 778 | if !reflect.DeepEqual(actual, expected) { 779 | t.Errorf("Expect %q, but %q", expected, actual) 780 | } 781 | } 782 | }() 783 | 784 | func() { 785 | for _, v := range sets { 786 | name, null := d.SQLType(v, false, 16777216) 787 | actual := []interface{}{name, null} 788 | expected := []interface{}{"LONGTEXT", false} 789 | if !reflect.DeepEqual(actual, expected) { 790 | t.Errorf("Expect %q, but %q", expected, actual) 791 | } 792 | } 793 | }() 794 | } 795 | 796 | func TestMySQLDialect_SQLType_stringIndirect(t *testing.T) { 797 | d := &MySQLDialect{} 798 | sets := []interface{}{new(string), sql.NullString{}} 799 | 800 | func() { 801 | // autoIncrement is false. 802 | for _, v := range sets { 803 | name, null := d.SQLType(v, false, 0) 804 | actual := []interface{}{name, null} 805 | expected := []interface{}{"VARCHAR(255)", true} 806 | if !reflect.DeepEqual(actual, expected) { 807 | t.Errorf("%T expects %q, but %q", v, expected, actual) 808 | } 809 | } 810 | 811 | // autoIncrement is true. 
812 | for _, v := range sets { 813 | name, null := d.SQLType(v, true, 0) 814 | actual := []interface{}{name, null} 815 | expected := []interface{}{"VARCHAR(255)", true} 816 | if !reflect.DeepEqual(actual, expected) { 817 | t.Errorf("%T expects %q, but %q", v, expected, actual) 818 | } 819 | } 820 | }() 821 | 822 | func() { 823 | for _, v := range sets { 824 | name, null := d.SQLType(v, false, 1) 825 | actual := []interface{}{name, null} 826 | expected := []interface{}{"VARCHAR(1)", true} 827 | if !reflect.DeepEqual(actual, expected) { 828 | t.Errorf("Expect %q, but %q", expected, actual) 829 | } 830 | } 831 | }() 832 | 833 | func() { 834 | for _, v := range sets { 835 | name, null := d.SQLType(v, false, 2) 836 | actual := []interface{}{name, null} 837 | expected := []interface{}{"VARCHAR(2)", true} 838 | if !reflect.DeepEqual(actual, expected) { 839 | t.Errorf("Expect %q, but %q", expected, actual) 840 | } 841 | } 842 | }() 843 | 844 | func() { 845 | for _, v := range sets { 846 | name, null := d.SQLType(v, false, 65532) 847 | actual := []interface{}{name, null} 848 | expected := []interface{}{"VARCHAR(65532)", true} 849 | if !reflect.DeepEqual(actual, expected) { 850 | t.Errorf("Expect %q, but %q", expected, actual) 851 | } 852 | } 853 | }() 854 | 855 | func() { 856 | for _, v := range sets { 857 | name, null := d.SQLType(v, false, 65533) 858 | actual := []interface{}{name, null} 859 | expected := []interface{}{"MEDIUMTEXT", true} 860 | if !reflect.DeepEqual(actual, expected) { 861 | t.Errorf("Expect %q, but %q", expected, actual) 862 | } 863 | } 864 | }() 865 | 866 | func() { 867 | for _, v := range sets { 868 | name, null := d.SQLType(v, false, 16777215) 869 | actual := []interface{}{name, null} 870 | expected := []interface{}{"MEDIUMTEXT", true} 871 | if !reflect.DeepEqual(actual, expected) { 872 | t.Errorf("Expect %q, but %q", expected, actual) 873 | } 874 | } 875 | }() 876 | 877 | func() { 878 | for _, v := range sets { 879 | name, null := d.SQLType(v, false, 
16777216) 880 | actual := []interface{}{name, null} 881 | expected := []interface{}{"LONGTEXT", true} 882 | if !reflect.DeepEqual(actual, expected) { 883 | t.Errorf("Expect %q, but %q", expected, actual) 884 | } 885 | } 886 | }() 887 | } 888 | 889 | func TestMySQLDialect_SQLType_byteSlice(t *testing.T) { 890 | d := &MySQLDialect{} 891 | sets := []interface{}{[]byte("")} 892 | 893 | func() { 894 | // autoIncrement is false. 895 | for _, v := range sets { 896 | name, null := d.SQLType(v, false, 0) 897 | actual := []interface{}{name, null} 898 | expected := []interface{}{"VARBINARY(255)", true} 899 | if !reflect.DeepEqual(actual, expected) { 900 | t.Errorf("%T expects %q, but %q", v, expected, actual) 901 | } 902 | } 903 | 904 | // autoIncrement is true. 905 | for _, v := range sets { 906 | name, null := d.SQLType(v, true, 0) 907 | actual := []interface{}{name, null} 908 | expected := []interface{}{"VARBINARY(255)", true} 909 | if !reflect.DeepEqual(actual, expected) { 910 | t.Errorf("%T expects %q, but %q", v, expected, actual) 911 | } 912 | } 913 | }() 914 | 915 | func() { 916 | for _, v := range sets { 917 | name, null := d.SQLType(v, false, 1) 918 | actual := []interface{}{name, null} 919 | expected := []interface{}{"VARBINARY(1)", true} 920 | if !reflect.DeepEqual(actual, expected) { 921 | t.Errorf("Expect %q, but %q", expected, actual) 922 | } 923 | } 924 | }() 925 | 926 | func() { 927 | for _, v := range sets { 928 | name, null := d.SQLType(v, false, 2) 929 | actual := []interface{}{name, null} 930 | expected := []interface{}{"VARBINARY(2)", true} 931 | if !reflect.DeepEqual(actual, expected) { 932 | t.Errorf("Expect %q, but %q", expected, actual) 933 | } 934 | } 935 | }() 936 | 937 | func() { 938 | for _, v := range sets { 939 | name, null := d.SQLType(v, false, 65532) 940 | actual := []interface{}{name, null} 941 | expected := []interface{}{"VARBINARY(65532)", true} 942 | if !reflect.DeepEqual(actual, expected) { 943 | t.Errorf("Expect %q, but %q", expected, 
actual) 944 | } 945 | } 946 | }() 947 | 948 | func() { 949 | for _, v := range sets { 950 | name, null := d.SQLType(v, false, 65533) 951 | actual := []interface{}{name, null} 952 | expected := []interface{}{"MEDIUMBLOB", true} 953 | if !reflect.DeepEqual(actual, expected) { 954 | t.Errorf("Expect %q, but %q", expected, actual) 955 | } 956 | } 957 | }() 958 | 959 | func() { 960 | for _, v := range sets { 961 | name, null := d.SQLType(v, false, 16777215) 962 | actual := []interface{}{name, null} 963 | expected := []interface{}{"MEDIUMBLOB", true} 964 | if !reflect.DeepEqual(actual, expected) { 965 | t.Errorf("Expect %q, but %q", expected, actual) 966 | } 967 | } 968 | }() 969 | 970 | func() { 971 | for _, v := range sets { 972 | name, null := d.SQLType(v, false, 16777216) 973 | actual := []interface{}{name, null} 974 | expected := []interface{}{"LONGBLOB", true} 975 | if !reflect.DeepEqual(actual, expected) { 976 | t.Errorf("Expect %q, but %q", expected, actual) 977 | } 978 | } 979 | }() 980 | } 981 | 982 | func TestMySQLDialect_SQLType_timeDirect(t *testing.T) { 983 | d := &MySQLDialect{} 984 | sets := []interface{}{time.Time{}} 985 | 986 | // autoIncrement is false. 987 | for _, v := range sets { 988 | name, null := d.SQLType(v, false, 0) 989 | actual := []interface{}{name, null} 990 | expected := []interface{}{"DATETIME", false} 991 | if !reflect.DeepEqual(actual, expected) { 992 | t.Errorf("%T expects %q, but %q", v, expected, actual) 993 | } 994 | } 995 | 996 | // autoIncrement is true. 
997 | for _, v := range sets { 998 | name, null := d.SQLType(v, true, 0) 999 | actual := []interface{}{name, null} 1000 | expected := []interface{}{"DATETIME", false} 1001 | if !reflect.DeepEqual(actual, expected) { 1002 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1003 | } 1004 | } 1005 | } 1006 | 1007 | func TestMySQLDialect_SQLType_timeIndirect(t *testing.T) { 1008 | d := &MySQLDialect{} 1009 | sets := []interface{}{&time.Time{}} 1010 | 1011 | // autoIncrement is false. 1012 | for _, v := range sets { 1013 | name, null := d.SQLType(v, false, 0) 1014 | actual := []interface{}{name, null} 1015 | expected := []interface{}{"DATETIME", true} 1016 | if !reflect.DeepEqual(actual, expected) { 1017 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1018 | } 1019 | } 1020 | 1021 | // autoIncrement is true. 1022 | for _, v := range sets { 1023 | name, null := d.SQLType(v, true, 0) 1024 | actual := []interface{}{name, null} 1025 | expected := []interface{}{"DATETIME", true} 1026 | if !reflect.DeepEqual(actual, expected) { 1027 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1028 | } 1029 | } 1030 | } 1031 | 1032 | func TestMySQLDialect_SQLType_floatDirect(t *testing.T) { 1033 | d := &MySQLDialect{} 1034 | sets := []interface{}{Float32(.1), Float64(.1)} 1035 | 1036 | // autoIncrement is false. 1037 | for _, v := range sets { 1038 | name, null := d.SQLType(v, false, 0) 1039 | actual := []interface{}{name, null} 1040 | expected := []interface{}{"DOUBLE", false} 1041 | if !reflect.DeepEqual(actual, expected) { 1042 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1043 | } 1044 | } 1045 | 1046 | // autoIncrement is true. 
1047 | for _, v := range sets { 1048 | name, null := d.SQLType(v, true, 0) 1049 | actual := []interface{}{name, null} 1050 | expected := []interface{}{"DOUBLE", false} 1051 | if !reflect.DeepEqual(actual, expected) { 1052 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1053 | } 1054 | } 1055 | } 1056 | 1057 | func TestMySQLDialect_SQLType_floatIndirect(t *testing.T) { 1058 | d := &MySQLDialect{} 1059 | sets := []interface{}{new(Float32), new(Float64)} 1060 | 1061 | // autoIncrement is false. 1062 | for _, v := range sets { 1063 | name, null := d.SQLType(v, false, 0) 1064 | actual := []interface{}{name, null} 1065 | expected := []interface{}{"DOUBLE", true} 1066 | if !reflect.DeepEqual(actual, expected) { 1067 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1068 | } 1069 | } 1070 | 1071 | // autoIncrement is true. 1072 | for _, v := range sets { 1073 | name, null := d.SQLType(v, true, 0) 1074 | actual := []interface{}{name, null} 1075 | expected := []interface{}{"DOUBLE", true} 1076 | if !reflect.DeepEqual(actual, expected) { 1077 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1078 | } 1079 | } 1080 | } 1081 | 1082 | func TestMySQLDialect_SQLType_ratDirect(t *testing.T) { 1083 | d := &MySQLDialect{} 1084 | sets := []interface{}{Rat{}} 1085 | 1086 | // autoIncrement is false. 1087 | for _, v := range sets { 1088 | name, null := d.SQLType(v, false, 0) 1089 | actual := []interface{}{name, null} 1090 | expected := []interface{}{"DECIMAL(65, 30)", false} 1091 | if !reflect.DeepEqual(actual, expected) { 1092 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1093 | } 1094 | } 1095 | 1096 | // autoIncrement is true. 
1097 | for _, v := range sets { 1098 | name, null := d.SQLType(v, true, 0) 1099 | actual := []interface{}{name, null} 1100 | expected := []interface{}{"DECIMAL(65, 30)", false} 1101 | if !reflect.DeepEqual(actual, expected) { 1102 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1103 | } 1104 | } 1105 | } 1106 | 1107 | func TestMySQLDialect_SQLType_ratIndirect(t *testing.T) { 1108 | d := &MySQLDialect{} 1109 | sets := []interface{}{new(Rat)} 1110 | 1111 | // autoIncrement is false. 1112 | for _, v := range sets { 1113 | name, null := d.SQLType(v, false, 0) 1114 | actual := []interface{}{name, null} 1115 | expected := []interface{}{"DECIMAL(65, 30)", true} 1116 | if !reflect.DeepEqual(actual, expected) { 1117 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1118 | } 1119 | } 1120 | 1121 | // autoIncrement is true. 1122 | for _, v := range sets { 1123 | name, null := d.SQLType(v, true, 0) 1124 | actual := []interface{}{name, null} 1125 | expected := []interface{}{"DECIMAL(65, 30)", true} 1126 | if !reflect.DeepEqual(actual, expected) { 1127 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1128 | } 1129 | } 1130 | } 1131 | 1132 | func TestMySQLDialect_AutoIncrement(t *testing.T) { 1133 | d := &MySQLDialect{} 1134 | actual := d.AutoIncrement() 1135 | expected := "AUTO_INCREMENT" 1136 | if !reflect.DeepEqual(actual, expected) { 1137 | t.Errorf("Expect %q, but %q", expected, actual) 1138 | } 1139 | } 1140 | 1141 | func TestMySQLDialect_FormatBool(t *testing.T) { 1142 | d := &MySQLDialect{} 1143 | actual := d.FormatBool(true) 1144 | expected := "TRUE" 1145 | if !reflect.DeepEqual(actual, expected) { 1146 | t.Errorf("Expect %q, but %q", expected, actual) 1147 | } 1148 | 1149 | actual = d.FormatBool(false) 1150 | expected = "FALSE" 1151 | if !reflect.DeepEqual(actual, expected) { 1152 | t.Errorf("Expect %q, but %q", expected, actual) 1153 | } 1154 | } 1155 | 1156 | func TestMySQLDialect_LastInsertID(t *testing.T) { 1157 | d := &MySQLDialect{} 1158 | 
actual := d.LastInsertId() 1159 | expect := "SELECT LAST_INSERT_ID()" 1160 | if !reflect.DeepEqual(actual, expect) { 1161 | t.Errorf(`MySQLDialect.LastInsertId() => %#v; want %#v`, actual, expect) 1162 | } 1163 | } 1164 | 1165 | func Test_PostgresDialect_Name(t *testing.T) { 1166 | d := &PostgresDialect{} 1167 | actual := d.Name() 1168 | expected := "postgres" 1169 | if !reflect.DeepEqual(actual, expected) { 1170 | t.Errorf("Expect %q, but %q", expected, actual) 1171 | } 1172 | } 1173 | 1174 | func Test_PostgresDialect_Quote(t *testing.T) { 1175 | d := &PostgresDialect{} 1176 | for _, v := range []struct { 1177 | s, expected string 1178 | }{ 1179 | {``, `""`}, 1180 | {`test`, `"test"`}, 1181 | {`"test"`, `"""test"""`}, 1182 | {`test"bar"baz`, `"test""bar""baz"`}, 1183 | } { 1184 | actual := d.Quote(v.s) 1185 | expected := v.expected 1186 | if !reflect.DeepEqual(actual, expected) { 1187 | t.Errorf("Input %q expects %q, but %q", v.s, expected, actual) 1188 | } 1189 | } 1190 | } 1191 | 1192 | func Test_PostgresDialect_PlaceHolder(t *testing.T) { 1193 | d := &PostgresDialect{} 1194 | actual := d.PlaceHolder(0) 1195 | expected := "$1" 1196 | if !reflect.DeepEqual(actual, expected) { 1197 | t.Errorf("Expect %q, but %q", expected, actual) 1198 | } 1199 | 1200 | actual = d.PlaceHolder(1) 1201 | expected = "$2" 1202 | if !reflect.DeepEqual(actual, expected) { 1203 | t.Errorf("Expect %q, but %q", expected, actual) 1204 | } 1205 | } 1206 | 1207 | func TestPostgresDialect_SQLType_boolDirect(t *testing.T) { 1208 | d := &PostgresDialect{} 1209 | sets := []interface{}{true, false} 1210 | 1211 | // autoIncrement is false. 1212 | for _, v := range sets { 1213 | name, null := d.SQLType(v, false, 0) 1214 | actual := []interface{}{name, null} 1215 | expected := []interface{}{"boolean", false} 1216 | if !reflect.DeepEqual(actual, expected) { 1217 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1218 | } 1219 | } 1220 | 1221 | // autoIncrement is true. 
1222 | for _, v := range sets { 1223 | name, null := d.SQLType(v, true, 0) 1224 | actual := []interface{}{name, null} 1225 | expected := []interface{}{"boolean", false} 1226 | if !reflect.DeepEqual(actual, expected) { 1227 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1228 | } 1229 | } 1230 | } 1231 | 1232 | func TestPostgresDialect_SQLType_boolIndirect(t *testing.T) { 1233 | d := &PostgresDialect{} 1234 | sets := []interface{}{new(bool), sql.NullBool{}} 1235 | 1236 | // autoIncrement is false. 1237 | for _, v := range sets { 1238 | name, null := d.SQLType(v, false, 0) 1239 | actual := []interface{}{name, null} 1240 | expected := []interface{}{"boolean", true} 1241 | if !reflect.DeepEqual(actual, expected) { 1242 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1243 | } 1244 | } 1245 | 1246 | // autoIncrement is true. 1247 | for _, v := range sets { 1248 | name, null := d.SQLType(v, true, 0) 1249 | actual := []interface{}{name, null} 1250 | expected := []interface{}{"boolean", true} 1251 | if !reflect.DeepEqual(actual, expected) { 1252 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1253 | } 1254 | } 1255 | } 1256 | 1257 | func TestPostgresDialect_SQLType_primitiveFloat(t *testing.T) { 1258 | d := &PostgresDialect{} 1259 | sets := []interface{}{float32(.1), new(float32), float64(.1), new(float64), sql.NullFloat64{}} 1260 | 1261 | // autoIncrement is false. 1262 | for _, v := range sets { 1263 | func(v interface{}) { 1264 | defer func() { 1265 | if err := recover(); err == nil { 1266 | t.Errorf("panic hasn't been occurred by %T", v) 1267 | } 1268 | }() 1269 | d.SQLType(v, false, 0) 1270 | }(v) 1271 | } 1272 | 1273 | // autoIncrement is true. 
1274 | for _, v := range sets { 1275 | func(v interface{}) { 1276 | defer func() { 1277 | if err := recover(); err == nil { 1278 | t.Errorf("panic hasn't been occurred by %T", v) 1279 | } 1280 | }() 1281 | d.SQLType(v, true, 0) 1282 | }(v) 1283 | } 1284 | } 1285 | 1286 | func TestPostgresDialect_SQLType_underInt16Direct(t *testing.T) { 1287 | d := &PostgresDialect{} 1288 | sets := []interface{}{int8(1), int16(1), uint8(1), uint16(1)} 1289 | 1290 | // autoIncrement is false. 1291 | for _, v := range sets { 1292 | name, null := d.SQLType(v, false, 0) 1293 | actual := []interface{}{name, null} 1294 | expected := []interface{}{"smallint", false} 1295 | if !reflect.DeepEqual(actual, expected) { 1296 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1297 | } 1298 | } 1299 | 1300 | // autoIncrement is true. 1301 | for _, v := range sets { 1302 | name, null := d.SQLType(v, true, 0) 1303 | actual := []interface{}{name, null} 1304 | expected := []interface{}{"smallserial", false} 1305 | if !reflect.DeepEqual(actual, expected) { 1306 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1307 | } 1308 | } 1309 | } 1310 | 1311 | func TestPostgresDialect_SQLType_underInt16Indirect(t *testing.T) { 1312 | d := &PostgresDialect{} 1313 | sets := []interface{}{new(int8), new(int16), new(uint8), new(uint16)} 1314 | 1315 | // autoIncrement is false. 1316 | for _, v := range sets { 1317 | name, null := d.SQLType(v, false, 0) 1318 | actual := []interface{}{name, null} 1319 | expected := []interface{}{"smallint", true} 1320 | if !reflect.DeepEqual(actual, expected) { 1321 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1322 | } 1323 | } 1324 | 1325 | // autoIncrement is true. 
1326 | for _, v := range sets { 1327 | name, null := d.SQLType(v, true, 0) 1328 | actual := []interface{}{name, null} 1329 | expected := []interface{}{"smallserial", true} 1330 | if !reflect.DeepEqual(actual, expected) { 1331 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1332 | } 1333 | } 1334 | } 1335 | 1336 | func TestPostgresDialect_SQLType_intDirect(t *testing.T) { 1337 | d := &PostgresDialect{} 1338 | sets := []interface{}{int(1), int32(1), uint(1), uint32(1)} 1339 | 1340 | // autoIncrement is false. 1341 | for _, v := range sets { 1342 | name, null := d.SQLType(v, false, 0) 1343 | actual := []interface{}{name, null} 1344 | expected := []interface{}{"integer", false} 1345 | if !reflect.DeepEqual(actual, expected) { 1346 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1347 | } 1348 | } 1349 | 1350 | // autoIncrement is true. 1351 | for _, v := range sets { 1352 | name, null := d.SQLType(v, true, 0) 1353 | actual := []interface{}{name, null} 1354 | expected := []interface{}{"serial", false} 1355 | if !reflect.DeepEqual(actual, expected) { 1356 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1357 | } 1358 | } 1359 | } 1360 | 1361 | func TestPostgresDialect_SQLType_intIndirect(t *testing.T) { 1362 | d := &PostgresDialect{} 1363 | sets := []interface{}{new(int), new(int32), new(uint), new(uint32)} 1364 | 1365 | // autoIncrement is false. 1366 | for _, v := range sets { 1367 | name, null := d.SQLType(v, false, 0) 1368 | actual := []interface{}{name, null} 1369 | expected := []interface{}{"integer", true} 1370 | if !reflect.DeepEqual(actual, expected) { 1371 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1372 | } 1373 | } 1374 | 1375 | // autoIncrement is true. 
1376 | for _, v := range sets { 1377 | name, null := d.SQLType(v, true, 0) 1378 | actual := []interface{}{name, null} 1379 | expected := []interface{}{"serial", true} 1380 | if !reflect.DeepEqual(actual, expected) { 1381 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1382 | } 1383 | } 1384 | } 1385 | 1386 | func TestPostgresDialect_SQLType_int64Direct(t *testing.T) { 1387 | d := &PostgresDialect{} 1388 | sets := []interface{}{int64(1), uint64(1)} 1389 | 1390 | // autoIncrement is false. 1391 | for _, v := range sets { 1392 | name, null := d.SQLType(v, false, 0) 1393 | actual := []interface{}{name, null} 1394 | expected := []interface{}{"bigint", false} 1395 | if !reflect.DeepEqual(actual, expected) { 1396 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1397 | } 1398 | } 1399 | 1400 | // autoIncrement is true. 1401 | for _, v := range sets { 1402 | name, null := d.SQLType(v, true, 0) 1403 | actual := []interface{}{name, null} 1404 | expected := []interface{}{"bigserial", false} 1405 | if !reflect.DeepEqual(actual, expected) { 1406 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1407 | } 1408 | } 1409 | } 1410 | 1411 | func TestPostgresDialect_SQLType_int64Indirect(t *testing.T) { 1412 | d := &PostgresDialect{} 1413 | sets := []interface{}{new(int64), new(uint64), sql.NullInt64{}} 1414 | 1415 | // autoIncrement is false. 1416 | for _, v := range sets { 1417 | name, null := d.SQLType(v, false, 0) 1418 | actual := []interface{}{name, null} 1419 | expected := []interface{}{"bigint", true} 1420 | if !reflect.DeepEqual(actual, expected) { 1421 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1422 | } 1423 | } 1424 | 1425 | // autoIncrement is true. 
1426 | for _, v := range sets { 1427 | name, null := d.SQLType(v, true, 0) 1428 | actual := []interface{}{name, null} 1429 | expected := []interface{}{"bigserial", true} 1430 | if !reflect.DeepEqual(actual, expected) { 1431 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1432 | } 1433 | } 1434 | } 1435 | 1436 | func TestPostgresDialect_SQLType_stringDirect(t *testing.T) { 1437 | d := &PostgresDialect{} 1438 | sets := []interface{}{""} 1439 | 1440 | func() { 1441 | // autoIncrement is false. 1442 | for _, v := range sets { 1443 | name, null := d.SQLType(v, false, 0) 1444 | actual := []interface{}{name, null} 1445 | expected := []interface{}{"varchar(255)", false} 1446 | if !reflect.DeepEqual(actual, expected) { 1447 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1448 | } 1449 | } 1450 | 1451 | // autoIncrement is true. 1452 | for _, v := range sets { 1453 | name, null := d.SQLType(v, true, 0) 1454 | actual := []interface{}{name, null} 1455 | expected := []interface{}{"varchar(255)", false} 1456 | if !reflect.DeepEqual(actual, expected) { 1457 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1458 | } 1459 | } 1460 | }() 1461 | 1462 | func() { 1463 | for _, v := range sets { 1464 | name, null := d.SQLType(v, false, 1) 1465 | actual := []interface{}{name, null} 1466 | expected := []interface{}{"varchar(1)", false} 1467 | if !reflect.DeepEqual(actual, expected) { 1468 | t.Errorf("Expect %q, but %q", expected, actual) 1469 | } 1470 | } 1471 | }() 1472 | 1473 | func() { 1474 | for _, v := range sets { 1475 | name, null := d.SQLType(v, false, 2) 1476 | actual := []interface{}{name, null} 1477 | expected := []interface{}{"varchar(2)", false} 1478 | if !reflect.DeepEqual(actual, expected) { 1479 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1480 | } 1481 | } 1482 | }() 1483 | 1484 | func() { 1485 | for _, v := range sets { 1486 | name, null := d.SQLType(v, false, 65532) 1487 | actual := []interface{}{name, null} 1488 | expected := 
[]interface{}{"varchar(65532)", false} 1489 | if !reflect.DeepEqual(actual, expected) { 1490 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1491 | } 1492 | } 1493 | }() 1494 | 1495 | func() { 1496 | for _, v := range sets { 1497 | name, null := d.SQLType(v, false, 65533) 1498 | actual := []interface{}{name, null} 1499 | expected := []interface{}{"text", false} 1500 | if !reflect.DeepEqual(actual, expected) { 1501 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1502 | } 1503 | } 1504 | }() 1505 | 1506 | func() { 1507 | for _, v := range sets { 1508 | name, null := d.SQLType(v, false, 16777215) 1509 | actual := []interface{}{name, null} 1510 | expected := []interface{}{"text", false} 1511 | if !reflect.DeepEqual(actual, expected) { 1512 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1513 | } 1514 | } 1515 | }() 1516 | 1517 | func() { 1518 | for _, v := range sets { 1519 | name, null := d.SQLType(v, false, 16777216) 1520 | actual := []interface{}{name, null} 1521 | expected := []interface{}{"text", false} 1522 | if !reflect.DeepEqual(actual, expected) { 1523 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1524 | } 1525 | } 1526 | }() 1527 | } 1528 | 1529 | func TestPostgresDialect_SQLType_stringIndirect(t *testing.T) { 1530 | d := &PostgresDialect{} 1531 | sets := []interface{}{new(string), sql.NullString{}} 1532 | 1533 | func() { 1534 | // autoIncrement is false. 1535 | for _, v := range sets { 1536 | name, null := d.SQLType(v, false, 0) 1537 | actual := []interface{}{name, null} 1538 | expected := []interface{}{"varchar(255)", true} 1539 | if !reflect.DeepEqual(actual, expected) { 1540 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1541 | } 1542 | } 1543 | 1544 | // autoIncrement is true. 
1545 | for _, v := range sets { 1546 | name, null := d.SQLType(v, true, 0) 1547 | actual := []interface{}{name, null} 1548 | expected := []interface{}{"varchar(255)", true} 1549 | if !reflect.DeepEqual(actual, expected) { 1550 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1551 | } 1552 | } 1553 | }() 1554 | 1555 | func() { 1556 | for _, v := range sets { 1557 | name, null := d.SQLType(v, false, 1) 1558 | actual := []interface{}{name, null} 1559 | expected := []interface{}{"varchar(1)", true} 1560 | if !reflect.DeepEqual(actual, expected) { 1561 | t.Errorf("Expect %q, but %q", expected, actual) 1562 | } 1563 | } 1564 | }() 1565 | 1566 | func() { 1567 | for _, v := range sets { 1568 | name, null := d.SQLType(v, false, 2) 1569 | actual := []interface{}{name, null} 1570 | expected := []interface{}{"varchar(2)", true} 1571 | if !reflect.DeepEqual(actual, expected) { 1572 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1573 | } 1574 | } 1575 | }() 1576 | 1577 | func() { 1578 | for _, v := range sets { 1579 | name, null := d.SQLType(v, false, 65532) 1580 | actual := []interface{}{name, null} 1581 | expected := []interface{}{"varchar(65532)", true} 1582 | if !reflect.DeepEqual(actual, expected) { 1583 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1584 | } 1585 | } 1586 | }() 1587 | 1588 | func() { 1589 | for _, v := range sets { 1590 | name, null := d.SQLType(v, false, 65533) 1591 | actual := []interface{}{name, null} 1592 | expected := []interface{}{"text", true} 1593 | if !reflect.DeepEqual(actual, expected) { 1594 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1595 | } 1596 | } 1597 | }() 1598 | 1599 | func() { 1600 | for _, v := range sets { 1601 | name, null := d.SQLType(v, false, 16777215) 1602 | actual := []interface{}{name, null} 1603 | expected := []interface{}{"text", true} 1604 | if !reflect.DeepEqual(actual, expected) { 1605 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1606 | } 1607 | } 1608 | }() 1609 | 1610 
| func() { 1611 | for _, v := range sets { 1612 | name, null := d.SQLType(v, false, 16777216) 1613 | actual := []interface{}{name, null} 1614 | expected := []interface{}{"text", true} 1615 | if !reflect.DeepEqual(actual, expected) { 1616 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1617 | } 1618 | } 1619 | }() 1620 | } 1621 | 1622 | func TestPostgresDialect_SQLType_byteSlice(t *testing.T) { 1623 | d := &PostgresDialect{} 1624 | sets := []interface{}{[]byte("")} 1625 | 1626 | // autoIncrement is false. 1627 | for _, v := range sets { 1628 | name, null := d.SQLType(v, false, 0) 1629 | actual := []interface{}{name, null} 1630 | expected := []interface{}{"bytea", true} 1631 | if !reflect.DeepEqual(actual, expected) { 1632 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1633 | } 1634 | } 1635 | 1636 | // autoIncrement is true. 1637 | for _, v := range sets { 1638 | name, null := d.SQLType(v, true, 0) 1639 | actual := []interface{}{name, null} 1640 | expected := []interface{}{"bytea", true} 1641 | if !reflect.DeepEqual(actual, expected) { 1642 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1643 | } 1644 | } 1645 | } 1646 | 1647 | func TestPostgresDialect_SQLType_timeDirect(t *testing.T) { 1648 | d := &PostgresDialect{} 1649 | sets := []interface{}{time.Time{}} 1650 | 1651 | // autoIncrement is false. 1652 | for _, v := range sets { 1653 | name, null := d.SQLType(v, false, 0) 1654 | actual := []interface{}{name, null} 1655 | expected := []interface{}{"timestamp with time zone", false} 1656 | if !reflect.DeepEqual(actual, expected) { 1657 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1658 | } 1659 | } 1660 | 1661 | // autoIncrement is true. 
1662 | for _, v := range sets { 1663 | name, null := d.SQLType(v, true, 0) 1664 | actual := []interface{}{name, null} 1665 | expected := []interface{}{"timestamp with time zone", false} 1666 | if !reflect.DeepEqual(actual, expected) { 1667 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1668 | } 1669 | } 1670 | } 1671 | 1672 | func TestPostgresDialect_SQLType_timeIndirect(t *testing.T) { 1673 | d := &PostgresDialect{} 1674 | sets := []interface{}{&time.Time{}} 1675 | 1676 | // autoIncrement is false. 1677 | for _, v := range sets { 1678 | name, null := d.SQLType(v, false, 0) 1679 | actual := []interface{}{name, null} 1680 | expected := []interface{}{"timestamp with time zone", true} 1681 | if !reflect.DeepEqual(actual, expected) { 1682 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1683 | } 1684 | } 1685 | 1686 | // autoIncrement is true. 1687 | for _, v := range sets { 1688 | name, null := d.SQLType(v, true, 0) 1689 | actual := []interface{}{name, null} 1690 | expected := []interface{}{"timestamp with time zone", true} 1691 | if !reflect.DeepEqual(actual, expected) { 1692 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1693 | } 1694 | } 1695 | } 1696 | 1697 | func TestPostgresDialect_SQLType_floatDirect(t *testing.T) { 1698 | d := &PostgresDialect{} 1699 | sets := []interface{}{Float32(.1), Float64(.1)} 1700 | 1701 | // autoIncrement is false. 1702 | for _, v := range sets { 1703 | name, null := d.SQLType(v, false, 0) 1704 | actual := []interface{}{name, null} 1705 | expected := []interface{}{"double precision", false} 1706 | if !reflect.DeepEqual(actual, expected) { 1707 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1708 | } 1709 | } 1710 | 1711 | // autoIncrement is true. 
1712 | for _, v := range sets { 1713 | name, null := d.SQLType(v, true, 0) 1714 | actual := []interface{}{name, null} 1715 | expected := []interface{}{"double precision", false} 1716 | if !reflect.DeepEqual(actual, expected) { 1717 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1718 | } 1719 | } 1720 | } 1721 | 1722 | func TestPostgresDialect_SQLType_floatIndirect(t *testing.T) { 1723 | d := &PostgresDialect{} 1724 | sets := []interface{}{new(Float32), new(Float64)} 1725 | 1726 | // autoIncrement is false. 1727 | for _, v := range sets { 1728 | name, null := d.SQLType(v, false, 0) 1729 | actual := []interface{}{name, null} 1730 | expected := []interface{}{"double precision", true} 1731 | if !reflect.DeepEqual(actual, expected) { 1732 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1733 | } 1734 | } 1735 | 1736 | // autoIncrement is true. 1737 | for _, v := range sets { 1738 | name, null := d.SQLType(v, true, 0) 1739 | actual := []interface{}{name, null} 1740 | expected := []interface{}{"double precision", true} 1741 | if !reflect.DeepEqual(actual, expected) { 1742 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1743 | } 1744 | } 1745 | } 1746 | 1747 | func TestPostgresDialect_SQLType_ratDirect(t *testing.T) { 1748 | d := &PostgresDialect{} 1749 | sets := []interface{}{Rat{}} 1750 | 1751 | // autoIncrement is false. 1752 | for _, v := range sets { 1753 | name, null := d.SQLType(v, false, 0) 1754 | actual := []interface{}{name, null} 1755 | expected := []interface{}{"numeric(65, 30)", false} 1756 | if !reflect.DeepEqual(actual, expected) { 1757 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1758 | } 1759 | } 1760 | 1761 | // autoIncrement is true. 
1762 | for _, v := range sets { 1763 | name, null := d.SQLType(v, true, 0) 1764 | actual := []interface{}{name, null} 1765 | expected := []interface{}{"numeric(65, 30)", false} 1766 | if !reflect.DeepEqual(actual, expected) { 1767 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1768 | } 1769 | } 1770 | } 1771 | 1772 | func TestPostgresDialect_SQLType_ratIndirect(t *testing.T) { 1773 | d := &PostgresDialect{} 1774 | sets := []interface{}{new(Rat)} 1775 | 1776 | // autoIncrement is false. 1777 | for _, v := range sets { 1778 | name, null := d.SQLType(v, false, 0) 1779 | actual := []interface{}{name, null} 1780 | expected := []interface{}{"numeric(65, 30)", true} 1781 | if !reflect.DeepEqual(actual, expected) { 1782 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1783 | } 1784 | } 1785 | 1786 | // autoIncrement is true. 1787 | for _, v := range sets { 1788 | name, null := d.SQLType(v, true, 0) 1789 | actual := []interface{}{name, null} 1790 | expected := []interface{}{"numeric(65, 30)", true} 1791 | if !reflect.DeepEqual(actual, expected) { 1792 | t.Errorf("%T expects %q, but %q", v, expected, actual) 1793 | } 1794 | } 1795 | } 1796 | 1797 | func TestPostgresDialect_AutoIncrement(t *testing.T) { 1798 | d := &PostgresDialect{} 1799 | actual := d.AutoIncrement() 1800 | expected := "" 1801 | if !reflect.DeepEqual(actual, expected) { 1802 | t.Errorf("Expect %q, but %q", expected, actual) 1803 | } 1804 | } 1805 | 1806 | func TestPostgresDialect_FormatBool(t *testing.T) { 1807 | d := &PostgresDialect{} 1808 | actual := d.FormatBool(true) 1809 | expected := "TRUE" 1810 | if !reflect.DeepEqual(actual, expected) { 1811 | t.Errorf("Expect %q, but %q", expected, actual) 1812 | } 1813 | 1814 | actual = d.FormatBool(false) 1815 | expected = "FALSE" 1816 | if !reflect.DeepEqual(actual, expected) { 1817 | t.Errorf("Expect %q, but %q", expected, actual) 1818 | } 1819 | } 1820 | 1821 | func TestPostgresDialect_LastInsertID(t *testing.T) { 1822 | d := 
&PostgresDialect{} 1823 | actual := d.LastInsertId() 1824 | expect := "SELECT lastval()" 1825 | if !reflect.DeepEqual(actual, expect) { 1826 | t.Errorf(`PostgresDialect.LastInsertId() => %#v; want %#v`, actual, expect) 1827 | } 1828 | } 1829 | --------------------------------------------------------------------------------