├── logo.png ├── sqlingo-gen-sqlite3 └── main.go ├── sqlingo-gen-mysql └── main.go ├── sqlingo-gen-postgres └── main.go ├── order_test.go ├── order.go ├── dialect_test.go ├── geometry_test.go ├── dialect.go ├── table_test.go ├── function_test.go ├── sqlingo-gen └── main.go ├── generator ├── generator_test.go ├── fetcher_sqlite3.go ├── fetcher_postgres.go ├── args.go ├── fetcher_mysql.go └── generator.go ├── .github └── workflows │ └── go.yml ├── case_test.go ├── LICENSE ├── utils_test.go ├── geometry.go ├── delete_test.go ├── interceptor.go ├── table.go ├── interceptor_test.go ├── function.go ├── transaction.go ├── field_test.go ├── transaction_test.go ├── update_test.go ├── case.go ├── common_test.go ├── value.go ├── delete.go ├── database_test.go ├── value_test.go ├── update.go ├── field.go ├── common.go ├── insert_test.go ├── README.md ├── cursor_test.go ├── array.go ├── insert.go ├── cursor.go ├── expression_test.go ├── database.go ├── select_test.go ├── select.go └── expression.go /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lqs/sqlingo/HEAD/logo.png -------------------------------------------------------------------------------- /sqlingo-gen-sqlite3/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "github.com/lqs/sqlingo/generator" 6 | _ "github.com/mattn/go-sqlite3" 7 | ) 8 | 9 | func main() { 10 | code, err := generator.Generate("sqlite3", "./testdb.sqlite3") // driver name must match the go-sqlite3 import above, just as the sibling generators pass "mysql" / "postgres" for their drivers 11 | if err != nil { 12 | panic(err) 13 | } 14 | 15 | fmt.Print(code) 16 | } 17 | -------------------------------------------------------------------------------- /sqlingo-gen-mysql/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | _ "github.com/go-sql-driver/mysql" 6 | "github.com/lqs/sqlingo/generator" 7 | ) 8 | 9 | func main() { 10 | code, err := 
generator.Generate("mysql", "username:password@tcp(hostname:3306)/database") 11 | if err != nil { 12 | panic(err) 13 | } 14 | 15 | fmt.Print(code) 16 | } 17 | -------------------------------------------------------------------------------- /sqlingo-gen-postgres/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | _ "github.com/lib/pq" 6 | "github.com/lqs/sqlingo/generator" 7 | ) 8 | 9 | func main() { 10 | code, err := generator.Generate("postgres", "host=localhost port=5432 user=user password=pass dbname=db sslmode=disable") 11 | if err != nil { 12 | panic(err) 13 | } 14 | 15 | fmt.Print(code) 16 | } 17 | -------------------------------------------------------------------------------- /order_test.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | ) 7 | 8 | func TestOrder(t *testing.T) { 9 | e := expression{sql: "x"} 10 | assertValue(t, orderBy{by: e}, "x") 11 | assertValue(t, orderBy{by: e, desc: true}, "x DESC") 12 | assertError(t, orderBy{by: expression{builder: func(scope scope) (string, error) { 13 | return "", errors.New("error") 14 | }}}) 15 | } 16 | -------------------------------------------------------------------------------- /order.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | // OrderBy indicates the ORDER BY column and the status of descending order. 
4 | type OrderBy interface { 5 | GetSQL(scope scope) (string, error) 6 | } 7 | 8 | type orderBy struct { 9 | by Expression 10 | desc bool 11 | } 12 | 13 | func (o orderBy) GetSQL(scope scope) (string, error) { 14 | sql, err := o.by.GetSQL(scope) 15 | if err != nil { 16 | return "", err 17 | } 18 | if o.desc { 19 | sql += " DESC" 20 | } 21 | return sql, nil 22 | } 23 | -------------------------------------------------------------------------------- /dialect_test.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import "testing" 4 | 5 | func TestDialect(t *testing.T) { 6 | nameToDialect := map[string]dialect{ 7 | "mysql": dialectMySQL, 8 | "sqlite3": dialectSqlite3, 9 | "postgres": dialectPostgres, 10 | "sqlserver": dialectMSSQL, 11 | "mssql": dialectMSSQL, 12 | "somedbidontknow": dialectUnknown, 13 | } 14 | 15 | for name, dialect := range nameToDialect { 16 | if getDialectFromDriverName(name) != dialect { 17 | t.Error() 18 | } 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /geometry_test.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import "testing" 4 | 5 | func TestGeometry(t *testing.T) { 6 | assertValue(t, STGeomFromText("sample wkt"), "ST_GeomFromText('sample wkt')") 7 | assertValue(t, STGeomFromTextf("sample wkt %d", 1), "ST_GeomFromText('sample wkt 1')") 8 | 9 | e := expression{ 10 | builder: func(scope scope) (string, error) { 11 | return "<>", nil 12 | }, 13 | } 14 | assertValue(t, e.STAsText(), "ST_AsText(<>)") 15 | 16 | t1 := NewTable("t1") 17 | field := NewWellKnownBinaryField(t1, "f1") 18 | assertValue(t, field, "`t1`.`f1`") 19 | } 20 | -------------------------------------------------------------------------------- /dialect.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | type dialect int 4 | 5 | const ( 6 | 
dialectUnknown dialect = iota 7 | dialectMySQL 8 | dialectSqlite3 9 | dialectPostgres 10 | dialectMSSQL 11 | 12 | dialectCount 13 | ) 14 | 15 | type dialectArray [dialectCount]string 16 | // getDialectFromDriverName maps a database/sql driver name to its SQL dialect; unrecognized names map to dialectUnknown. 17 | func getDialectFromDriverName(driverName string) dialect { 18 | switch driverName { 19 | case "mysql": 20 | return dialectMySQL 21 | case "sqlite3": 22 | return dialectSqlite3 23 | case "postgres": 24 | return dialectPostgres 25 | case "sqlserver", "mssql": 26 | return dialectMSSQL 27 | default: 28 | return dialectUnknown 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /table_test.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import "testing" 4 | 5 | func TestTable(t *testing.T) { 6 | table := table{} 7 | if table.getOperatorPriority() != 0 { 8 | t.Error() 9 | } 10 | } 11 | 12 | func TestDerivedTable(t *testing.T) { 13 | dummyFields := []Field{NewNumberField(NewTable("table"), "field")} 14 | dt := derivedTable{ 15 | name: "t", 16 | selectStatus: selectStatus{ 17 | base: selectBase{ 18 | fields: dummyFields, 19 | }, 20 | }, 21 | } 22 | if dt.GetName() != "t" { 23 | t.Error() 24 | } 25 | 26 | sql, err := dt.GetFields()[0].GetSQL(dummyMySQLScope) 27 | if err != nil { 28 | t.Error(err) 29 | } 30 | if sql != "`table`.`field`" { 31 | t.Error(sql) 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /function_test.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | ) 7 | 8 | func TestFunction(t *testing.T) { 9 | a1 := expression{sql: "a1"} 10 | a2 := expression{sql: "a2"} 11 | ee := expression{builder: func(scope scope) (string, error) { 12 | return "", errors.New("error") 13 | }} 14 | 15 | assertValue(t, Function("func"), "func()") 16 | 
assertValue(t, Function("func", a1), "func(a1)") 17 | assertValue(t, 
Function("func", a1, a2), "func(a1, a2)") 18 | assertError(t, Function("func", a1, ee)) 19 | 20 | assertValue(t, Concat(a1, a2), "CONCAT(a1, a2)") 21 | assertValue(t, Count(a1), "COUNT(a1)") 22 | assertValue(t, If(a1, 1, 2), "IF(a1, 1, 2)") 23 | assertValue(t, Length(a1), "LENGTH(a1)") 24 | assertValue(t, Sum(a1), "SUM(a1)") 25 | } 26 | -------------------------------------------------------------------------------- /sqlingo-gen/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | _ "github.com/go-sql-driver/mysql" 6 | "github.com/lqs/sqlingo/generator" 7 | "os" 8 | "strings" 9 | ) 10 | // main prints a deprecation notice to stderr, then still runs the legacy MySQL generator. 11 | func main() { 12 | warningLines := []string{ // "go get" stopped building/installing commands in Go 1.18; "go install pkg@version" is the replacement 13 | "\u001b[31mThis command is deprecated. 
Please install the new generator with the corresponding driver:", 14 | "go install github.com/lqs/sqlingo/sqlingo-gen-mysql@latest", 15 | "go install github.com/lqs/sqlingo/sqlingo-gen-sqlite3@latest", 16 | "go install github.com/lqs/sqlingo/sqlingo-gen-postgres@latest", 17 | "\u001b[0m", 18 | } 19 | _, _ = fmt.Fprintln(os.Stderr, strings.Join(warningLines, "\n")) 20 | code, err := generator.Generate("mysql", "username:password@tcp(hostname:3306)/database") 21 | if err != nil { 22 | panic(err) 23 | } 24 | 25 | fmt.Print(code) 26 | } 27 | -------------------------------------------------------------------------------- /generator/generator_test.go: -------------------------------------------------------------------------------- 1 | package generator 2 | 3 | import "testing" 4 | 5 | func TestConvert(t *testing.T) { 6 | m := map[string]string{ 7 | "abc": "Abc", 8 | "name": "Name", 9 | "Name": "Name", 10 | "abc_def": "AbcDef", 11 | "ϢϢϢϢ1": "ϢϢϢϢ1", 12 | "_[[[[[": "E", 13 | "abc/def": "AbcDef", 14 | "中文开头": "E中文开头", 15 | "abc~~中文": "Abc中文", 16 | "abc-def--ghi": "AbcDefGhi", 17 | "user_id": "UserID", 18 | "user_ip_address": "UserIPAddress", 19 | "user_ix": "UserIx", 20 | "volume_db": "VolumedB", 21 | "db_volume": "EdBVolume", 22 | } 23 
| for k, v := range m { 24 | if convertToExportedIdentifier(k, []string{"ID", "IP", "dB"}) != v { 25 | t.Errorf("'%s' should be converted to '%s'", k, v) 26 | } 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | # This workflow will build a golang project 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go 3 | 4 | name: Go 5 | 6 | on: 7 | push: 8 | branches: [ "master" ] 9 | pull_request: 10 | branches: [ "master" ] 11 | 12 | jobs: 13 | 14 | build: 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | go-version: [ '1.22', '1.23' ] 19 | steps: 20 | - uses: actions/checkout@v4 21 | - name: Setup Go ${{ matrix.go-version }} 22 | uses: actions/setup-go@v5 23 | with: 24 | go-version: ${{ matrix.go-version }} 25 | - name: Display Go version 26 | run: go version 27 | - name: Build 28 | run: go mod init github.com/lqs/sqlingo && go mod tidy && go build -v . 29 | - name: Test 30 | run: go test -v . 
31 | -------------------------------------------------------------------------------- /case_test.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | ) 7 | 8 | func TestCase(t *testing.T) { 9 | c1 := expression{ 10 | builder: func(scope scope) (string, error) { 11 | return "c1", nil 12 | }, 13 | } 14 | c2 := expression{ 15 | builder: func(scope scope) (string, error) { 16 | return "c2", nil 17 | }, 18 | } 19 | assertValue(t, Case().WhenThen(c1, 1).WhenThen(c2, 2), 20 | "CASE WHEN c1 THEN 1 WHEN c2 THEN 2 END") 21 | assertValue(t, Case().WhenThen(c1, 1).WhenThen(c2, 2).Else(0), 22 | "CASE WHEN c1 THEN 1 WHEN c2 THEN 2 ELSE 0 END") 23 | assertValue(t, Case().Else(0), 24 | "0") 25 | 26 | ee := expression{ 27 | builder: func(scope scope) (string, error) { 28 | return "", errors.New("error") 29 | }, 30 | } 31 | assertError(t, Case().WhenThen(ee, 2)) 32 | assertError(t, Case().WhenThen(c1, ee)) 33 | assertError(t, Case().WhenThen(c1, 1).Else(ee)) 34 | } 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Liu Qishuai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils_test.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import "testing" 4 | 5 | var dummyMySQLScope = scope{Database: &database{dialect: dialectMySQL}} 6 | 7 | func assertEqual(t *testing.T, actualValue string, expectedValue string) { 8 | t.Helper() 9 | if actualValue != expectedValue { 10 | t.Errorf("actual [%s] expected [%s]", actualValue, expectedValue) 11 | } 12 | } 13 | 14 | func assertValue(t *testing.T, value interface{}, expectedSql string) { 15 | t.Helper() 16 | if generatedSql, _, _ := getSQL(dummyMySQLScope, value); generatedSql != expectedSql { 17 | t.Errorf("value [%v] generated [%s] expected [%s]", value, generatedSql, expectedSql) 18 | } 19 | } 20 | 21 | func assertLastSql(t *testing.T, expectedSql string) { 22 | t.Helper() 23 | if sharedMockConn.lastSql != expectedSql { 24 | t.Errorf("last sql [%s] expected [%s]", sharedMockConn.lastSql, expectedSql) 25 | } 26 | sharedMockConn.lastSql = "" 27 | } 28 | 29 | func assertError(t *testing.T, value interface{}) { 30 | t.Helper() 31 | if generatedSql, _, err := getSQL(dummyMySQLScope, value); err == nil { 32 | t.Errorf("value [%v] generated [%s] expected error", value, generatedSql) 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /geometry.go:
-------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import "fmt" 4 | 5 | // WellKnownBinaryField is the interface of a generated field of binary geometry (WKB) type. 6 | type WellKnownBinaryField interface { 7 | WellKnownBinaryExpression 8 | GetTable() Table 9 | } 10 | 11 | // WellKnownBinaryExpression is the interface of an SQL expression with binary geometry (WKB) value. 12 | type WellKnownBinaryExpression interface { 13 | Expression 14 | STAsText() StringExpression 15 | } 16 | 17 | // WellKnownBinary is the type of geometry well-known binary (WKB) field. 18 | type WellKnownBinary []byte 19 | 20 | // NewWellKnownBinaryField creates a reference to a geometry WKB field. It should only be called from generated code. 21 | func NewWellKnownBinaryField(table Table, fieldName string) WellKnownBinaryField { 22 | return newField(table, fieldName) 23 | } 24 | // STAsText creates an ST_AsText(e) call, which renders the geometry expression e as well-known text (WKT). 25 | func (e expression) STAsText() StringExpression { 26 | return function("ST_AsText", e) 27 | } 28 | // STGeomFromText creates an ST_GeomFromText(text) call that builds a geometry from its well-known text (WKT); text is rendered as an SQL value, so a Go string becomes a quoted literal. 29 | func STGeomFromText(text interface{}) WellKnownBinaryExpression { 30 | return function("ST_GeomFromText", text) 31 | } 32 | // STGeomFromTextf is like STGeomFromText, but first builds the WKT text with fmt.Sprintf(format, a...). 33 | func STGeomFromTextf(format string, a ...interface{}) WellKnownBinaryExpression { 34 | text := fmt.Sprintf(format, a...) 
35 | return STGeomFromText(text) 36 | } 37 | -------------------------------------------------------------------------------- /delete_test.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "testing" 7 | ) 8 | 9 | func TestDelete(t *testing.T) { 10 | errorExpression := expression{ 11 | builder: func(scope scope) (string, error) { 12 | return "", errors.New("error") 13 | }, 14 | } 15 | db := newMockDatabase() 16 | if _, err := db.DeleteFrom(Table1).Where(staticExpression("##", 1, false)).Execute(); err != nil { 17 | t.Error(err) 18 | } 19 | assertLastSql(t, "DELETE FROM `table1` WHERE ##") 20 | 21 | if _, err := db.DeleteFrom(Table1).Where(errorExpression).Execute(); err == nil { 22 | t.Error("should get error here") 23 | } 24 | 25 | if _, err := db.DeleteFrom(Table1).Where(Raw("#1#")).Limit(3).Execute(); err != nil { 26 | t.Error(err) 27 | } 28 | assertLastSql(t, "DELETE FROM `table1` WHERE #1# LIMIT 3") 29 | 30 | if _, err := db.DeleteFrom(Table1).Where(Raw("#1#")).OrderBy(Raw("#2#")).Limit(3).Execute(); err != nil { 31 | t.Error(err) 32 | } 33 | assertLastSql(t, "DELETE FROM `table1` WHERE #1# ORDER BY #2# LIMIT 3") 34 | 35 | if _, err := db.DeleteFrom(Table1).Where(Raw("#1#")).WithContext(context.Background()).Execute(); err != nil { 36 | t.Error(err) 37 | } 38 | assertLastSql(t, "DELETE FROM `table1` WHERE #1#") 39 | } 40 | -------------------------------------------------------------------------------- /interceptor.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import ( 4 | "context" 5 | ) 6 | 7 | // InvokerFunc is the function type of the actual invoker. It should be called in an interceptor. 8 | type InvokerFunc = func(ctx context.Context, sql string) error 9 | 10 | // InterceptorFunc is the function type of an interceptor. An interceptor should implement this function to fulfill its purpose. 
11 | type InterceptorFunc = func(ctx context.Context, sql string, invoker InvokerFunc) error 12 | // noopInterceptor forwards the call unchanged to invoker; ChainInterceptors returns it when given no interceptors. 13 | func noopInterceptor(ctx context.Context, sql string, invoker InvokerFunc) error { 14 | return invoker(ctx, sql) 15 | } 16 | 17 | // ChainInterceptors chains multiple interceptors into one interceptor. The first interceptor is the outermost: each one receives an invoker that runs the remaining interceptors and finally the real invoker. 18 | func ChainInterceptors(interceptors ...InterceptorFunc) InterceptorFunc { 19 | if len(interceptors) == 0 { 20 | return noopInterceptor 21 | } 22 | return func(ctx context.Context, sql string, invoker InvokerFunc) error { 23 | var chain func(int, context.Context, string) error // declared before assignment so the closure below can call itself 24 | chain = func(i int, ctx context.Context, sql string) error { 25 | if i == len(interceptors) { 26 | return invoker(ctx, sql) 27 | } 28 | return interceptors[i](ctx, sql, func(ctx context.Context, sql string) error { 29 | return chain(i+1, ctx, sql) 30 | }) 31 | } 32 | return chain(0, ctx, sql) 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /table.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | // Table is the interface of a generated table. 4 | type Table interface { 5 | GetName() string 6 | GetSQL(scope scope) string 7 | GetFields() []Field 8 | } 9 | // actualTable extends Table with GetFieldsSQL and GetFullFieldsSQL (NOTE(review): presumably implemented by generated table types; confirm). 10 | type actualTable interface { 11 | Table 12 | GetFieldsSQL() string 13 | GetFullFieldsSQL() string 14 | } 15 | // table is the Table value returned by NewTable; sqlDialects holds its quoted name indexed by dialect. 
16 | type table struct { 17 | Table // not set by NewTable: calling an interface method that table does not define itself (e.g. GetFields) panics on the nil embedded value 18 | name string 19 | sqlDialects dialectArray 20 | } 21 | 22 | func (t table) GetName() string { 23 | return t.name 24 | } 25 | // GetSQL returns the table name quoted for the dialect of scope.Database. 26 | func (t table) GetSQL(scope scope) string { 27 | return t.sqlDialects[scope.Database.dialect] 28 | } 29 | 30 | func (t table) getOperatorPriority() int { 31 | return 0 32 | } 33 | 34 | // NewTable creates a reference to a table. It should only be called from generated code. 
35 | func NewTable(name string) Table { 36 | return table{name: name, sqlDialects: quoteIdentifier(name)} 37 | } 38 | // derivedTable is a subquery used as a table; GetSQL renders it as "(SELECT ...) AS name". 39 | type derivedTable struct { 40 | name string 41 | selectStatus selectStatus 42 | } 43 | 44 | func (t derivedTable) GetName() string { 45 | return t.name 46 | } 47 | 48 | func (t derivedTable) GetSQL(scope scope) string { 49 | sql, _ := t.selectStatus.GetSQL() // NOTE(review): the error is dropped because Table.GetSQL has no error result, so a failing subquery is not reported here 50 | return "(" + sql + ") AS " + t.name // the alias is used as-is, not quoted the way NewTable quotes table names 51 | } 52 | // GetFields returns the fields of the subquery's active select base. 53 | func (t derivedTable) GetFields() []Field { 54 | return activeSelectBase(&t.selectStatus).fields 55 | } 56 | -------------------------------------------------------------------------------- /interceptor_test.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | ) 7 | // NOTE(review): the `s += ""` statements below are no-ops as written; they look like tag literals stripped when this dump was produced, so compare with the upstream file. 
8 | func TestChainInterceptors(t *testing.T) { 9 | s := "" 10 | i1 := func(ctx context.Context, sql string, invoker InvokerFunc) error { 11 | s += "" 12 | s += sql 13 | defer func() { 14 | s += "" 15 | }() 16 | return invoker(ctx, sql+"s1") 17 | } 18 | i2 := func(ctx context.Context, sql string, invoker InvokerFunc) error { 19 | s += "" 20 | s += sql 21 | defer func() { 22 | s += "" 23 | }() 24 | return invoker(ctx, sql+"s2") 25 | } 26 | chain := ChainInterceptors(i1, i2) 27 | _ = chain(context.Background(), "sql", func(ctx context.Context, sql string) error { 28 | s += "" 29 | s += sql 30 | defer func() { 31 | s += "" 32 | }() 33 | return nil 34 | }) 35 | if s != "sqlsqls1sqls1s2" { 36 | t.Error(s) 37 | } 38 | } 39 | 40 | func TestEmptyChainInterceptors(t *testing.T) { 41 | s := "" 42 | chain := ChainInterceptors() 43 | _ = chain(context.Background(), "sql", func(ctx context.Context, sql string) error { 44 | s += "" 45 | defer func() { 46 | s += "" 47 | }() 48 | s += sql 49 | return nil 50 | }) 51 | 52 | if s != "sql" { 53 | t.Error(s) 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /function.go: 
-------------------------------------------------------------------------------- 1 | package sqlingo 2 | // function builds the call expression name(arg1, arg2, ...); the arguments are rendered with commaValues only when the expression's builder runs. 3 | func function(name string, args ...interface{}) expression { 4 | return expression{builder: func(scope scope) (string, error) { 5 | valuesSql, err := commaValues(scope, args) 6 | if err != nil { 7 | return "", err 8 | } 9 | return name + "(" + valuesSql + ")", nil 10 | }} 11 | } 12 | 13 | // Function creates an expression of the call to specified function. 14 | func Function(name string, args ...interface{}) UnknownExpression { 15 | return function(name, args...) 16 | } 17 | 18 | // Concat creates an expression of CONCAT function. 19 | func Concat(args ...interface{}) StringExpression { 20 | return function("CONCAT", args...) 21 | } 22 | 23 | // Count creates an expression of COUNT aggregator. 24 | func Count(arg interface{}) NumberExpression { 25 | return function("COUNT", arg) 26 | } 27 | 28 | // If creates an expression of IF function. 29 | func If(predicate Expression, trueValue interface{}, falseValue interface{}) (result UnknownExpression) { 30 | return function("IF", predicate, trueValue, falseValue) 31 | } 32 | 33 | // Length creates an expression of LENGTH function. 34 | func Length(arg interface{}) NumberExpression { 35 | return function("LENGTH", arg) 36 | } 37 | 38 | // Sum creates an expression of SUM aggregator. 39 | func Sum(arg interface{}) NumberExpression { 40 | return function("SUM", arg) 41 | } 42 | -------------------------------------------------------------------------------- /transaction.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | ) 7 | 8 | // Transaction is the interface of a transaction with underlying sql.Tx object. 9 | // It provides methods to execute DDL and TCL operations. 
10 | type Transaction interface { 11 | GetDB() *sql.DB 12 | GetTx() *sql.Tx 13 | Query(sql string) (Cursor, error) 14 | Execute(sql string) (sql.Result, error) 15 | 16 | Select(fields ...interface{}) selectWithFields 17 | SelectDistinct(fields ...interface{}) selectWithFields 18 | SelectFrom(tables ...Table) selectWithTables 19 | InsertInto(table Table) insertWithTable 20 | Update(table Table) updateWithSet 21 | DeleteFrom(table Table) deleteWithTable 22 | } 23 | // GetTx returns the underlying *sql.Tx; BeginTx sets it on the copy of d passed to its callback. 24 | func (d *database) GetTx() *sql.Tx { 25 | return d.tx 26 | } 27 | // BeginTx begins a transaction, calls f (if non-nil) with a copy of d bound to it, and commits when f returns nil. The deferred Rollback runs whenever the commit did not succeed: f returned an error, f panicked, or Commit failed. A nil ctx is treated as context.Background(). 28 | func (d *database) BeginTx(ctx context.Context, opts *sql.TxOptions, f func(tx Transaction) error) error { 29 | if ctx == nil { 30 | ctx = context.Background() 31 | } 32 | tx, err := d.db.BeginTx(ctx, opts) 33 | if err != nil { 34 | return err 35 | } 36 | isCommitted := false 37 | defer func() { 38 | if !isCommitted { 39 | _ = tx.Rollback() 40 | } 41 | }() 42 | 43 | if f != nil { 44 | db := *d 45 | db.tx = tx 46 | err = f(&db) 47 | if err != nil { 48 | return err 49 | } 50 | } 51 | 52 | err = tx.Commit() 53 | if err != nil { 54 | return err 55 | } 56 | isCommitted = true 57 | return nil 58 | } 59 | -------------------------------------------------------------------------------- /generator/fetcher_sqlite3.go: -------------------------------------------------------------------------------- 1 | package generator 2 | 3 | import "database/sql" 4 | 5 | type sqlite3SchemaFetcher struct { 6 | db *sql.DB 7 | } 8 | // GetDatabaseName always reports "main", the schema name SQLite gives the primary database of a connection. 
9 | func (s sqlite3SchemaFetcher) GetDatabaseName() (dbName string, err error) { 10 | dbName = "main" 11 | return 12 | } 13 | // GetTableNames lists user tables. GLOB is used because '_' is a wildcard in LIKE, which would also hide user tables such as "sqliteX"; names starting with "sqlite_" are reserved for SQLite itself. 14 | func (s sqlite3SchemaFetcher) GetTableNames() (tableNames []string, err error) { 15 | rows, err := s.db.Query("SELECT `name` FROM `sqlite_master` WHERE `type` ='table' AND `name` NOT GLOB 'sqlite_*'") 16 | if err != nil { 17 | return 18 | } 19 | defer rows.Close() 20 | for rows.Next() { 21 | var name string 22 | if err = rows.Scan(&name); err != nil { 23 | return 24 | } 25 | tableNames = 
append(tableNames, name) 26 | } 27 | return tableNames, rows.Err() // surface errors that ended the iteration early 28 | } 29 | // GetFieldDescriptors returns the name, declared type and nullability of every column of tableName, read from pragma_table_info. 30 | func (s sqlite3SchemaFetcher) GetFieldDescriptors(tableName string) (result []fieldDescriptor, err error) { 31 | rows, err := s.db.Query("SELECT `name`, `type`, `notnull` FROM pragma_table_info(?)", tableName) // bound parameter: a quote in tableName can no longer break or inject into the query 32 | if err != nil { 33 | return 34 | } 35 | defer rows.Close() 36 | for rows.Next() { 37 | var fieldDescriptor fieldDescriptor 38 | var notNull int 39 | if err = rows.Scan(&fieldDescriptor.Name, &fieldDescriptor.Type, &notNull); err != nil { 40 | return 41 | } 42 | fieldDescriptor.AllowNull = notNull == 0 43 | result = append(result, fieldDescriptor) 44 | } 45 | return result, rows.Err() // surface errors that ended the iteration early 46 | } 47 | // QuoteIdentifier wraps identifier in double quotes (embedded double quotes are not escaped). 48 | func (s sqlite3SchemaFetcher) QuoteIdentifier(identifier string) string { 49 | return "\"" + identifier + "\"" 50 | } 51 | // newSQLite3SchemaFetcher returns a schemaFetcher that reads the schema of the SQLite database db. 52 | func newSQLite3SchemaFetcher(db *sql.DB) schemaFetcher { 53 | return sqlite3SchemaFetcher{db: db} 54 | } 55 | -------------------------------------------------------------------------------- /field_test.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | ) 7 | 8 | type dummyTable struct { 9 | } 10 | 11 | func (d dummyTable) GetName() string { 12 | panic("should not be here") 13 | } 14 | 15 | func (d dummyTable) GetSQL(scope scope) string { 16 | panic("should not be here") 17 | } 18 | 19 | func (d dummyTable) GetFieldByName(name string) Field { 20 | panic("should not be here") 21 | } 22 | 23 | func (d dummyTable) GetFields() []Field { 24 | panic("implement me") 25 | } 26 | 27 | func (d dummyTable) GetFieldsSQL() string { 28 | return "" 29 | } 30 | 31 | func (d dummyTable) GetFullFieldsSQL() string { 32 | return "" 33 | } 34 | 35 | func TestField(t *testing.T) { 36 | t1 := NewTable("t1") 37 | assertValue(t, NewNumberField(t1, "f1").Equals(1), "`t1`.`f1` = 1") 38 | assertValue(t, NewBooleanField(t1, "f1").Equals(true), "`t1`.`f1` = 1") 39 | assertValue(t, NewStringField(t1, 
"f1").Equals("x"), "`t1`.`f1` = 'x'") 40 | 41 | sql, _ := fieldList{}.GetSQL(scope{ 42 | Tables: []Table{ 43 | &dummyTable{}, 44 | }, 45 | }) 46 | if sql != "" { 47 | t.Error(sql) 48 | } 49 | 50 | sql, _ = fieldList{}.GetSQL(scope{ 51 | Tables: []Table{ 52 | &dummyTable{}, 53 | &dummyTable{}, 54 | }, 55 | }) 56 | if sql != ", " { 57 | t.Error(sql) 58 | } 59 | 60 | if _, err := (fieldList{ 61 | expression{builder: func(scope scope) (string, error) { 62 | return "", errors.New("error") 63 | }}, 64 | }.GetSQL(scope{ 65 | Tables: []Table{ 66 | &dummyTable{}, 67 | &dummyTable{}, 68 | }, 69 | })); err == nil { 70 | t.Error("should get error here") 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /generator/fetcher_postgres.go: -------------------------------------------------------------------------------- 1 | package generator 2 | 3 | import "database/sql" 4 | 5 | type postgresSchemaFetcher struct { 6 | db *sql.DB 7 | } 8 | // GetDatabaseName returns the name reported by current_database(). 9 | func (p postgresSchemaFetcher) GetDatabaseName() (dbName string, err error) { 10 | row := p.db.QueryRow("SELECT current_database()") 11 | err = row.Scan(&dbName) 12 | return 13 | } 14 | // GetTableNames lists the tables (and views) of the public schema, sorted by name so the generated output is deterministic. 15 | func (p postgresSchemaFetcher) GetTableNames() (tableNames []string, err error) { 16 | rows, err := p.db.Query("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' ORDER BY table_name") 17 | if err != nil { 18 | return 19 | } 20 | defer rows.Close() 21 | for rows.Next() { 22 | var name string 23 | if err = rows.Scan(&name); err != nil { 24 | return 25 | } 26 | tableNames = append(tableNames, name) 27 | } 28 | return tableNames, rows.Err() // surface errors that ended the iteration early 29 | } 30 | // GetFieldDescriptors returns the columns of public.tableName in declaration order; without ORDER BY the row order of information_schema.columns is unspecified. 31 | func (p postgresSchemaFetcher) GetFieldDescriptors(tableName string) (result []fieldDescriptor, err error) { 32 | rows, err := p.db.Query("SELECT column_name, is_nullable, data_type FROM information_schema.columns WHERE table_schema = 'public' AND table_name = $1 ORDER BY ordinal_position", tableName) 33 | if err != nil { 34 | return 35 | } 36 | defer rows.Close() 37 | for rows.Next() { 38 | var 
fieldDescriptor fieldDescriptor 39 | var isNullable string 40 | if err = rows.Scan(&fieldDescriptor.Name, &isNullable, &fieldDescriptor.Type); err != nil { 41 | return 42 | } 43 | fieldDescriptor.AllowNull = isNullable == "YES" 44 | result = append(result, fieldDescriptor) 45 | } 46 | return result, rows.Err() // surface errors that ended the iteration early 47 | } 48 | // QuoteIdentifier wraps identifier in double quotes (embedded double quotes are not escaped). 49 | func (p postgresSchemaFetcher) QuoteIdentifier(identifier string) string { 50 | return "\"" + identifier + "\"" 51 | } 52 | // newPostgresSchemaFetcher returns a schemaFetcher that reads the public schema of the PostgreSQL database db. 53 | func newPostgresSchemaFetcher(db *sql.DB) schemaFetcher { 54 | return postgresSchemaFetcher{db: db} 55 | } 56 | -------------------------------------------------------------------------------- /generator/args.go: -------------------------------------------------------------------------------- 1 | package generator 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "strings" 7 | ) 8 | // options holds the parsed command-line arguments of a generator command. 9 | type options struct { 10 | dataSourceName string 11 | tableNames []string 12 | forceCases []string 13 | } 14 | // printUsageAndExit prints the usage text to stderr, using exampleDataSourceName in the example line, and exits with status 1. 15 | func printUsageAndExit(exampleDataSourceName string) { 16 | cmd := os.Args[0] 17 | _, _ = fmt.Fprintf(os.Stderr, `Usage: 18 | %s [-t table1,table2,...] [-forcecases ID,IDs,HTML] [-timeAsString] dataSourceName 19 | Example: 20 | %s "%s" 21 | `, cmd, cmd, exampleDataSourceName) 22 | os.Exit(1) 23 | } 24 | // parseArgs parses os.Args: each -t / -forcecases takes the next argument as a comma-separated list, -timeAsString sets the package-level timeAsString flag, and exactly one positional dataSourceName is required; anything else prints usage and exits. 
25 | func parseArgs(exampleDataSourceName string) (options options) { 26 | var args []string 27 | parseTable := false 28 | parseForceCases := false 29 | for _, arg := range os.Args[1:] { 30 | if arg != "" && arg[0] == '-' { 31 | switch arg[1:] { 32 | case "t": 33 | if parseTable { 34 | printUsageAndExit(exampleDataSourceName) 35 | } 36 | parseTable = true 37 | case "forcecases": 38 | if parseForceCases { 39 | printUsageAndExit(exampleDataSourceName) 40 | } 41 | parseForceCases = true 42 | case "timeAsString": 43 | timeAsString = true 44 | default: 45 | printUsageAndExit(exampleDataSourceName) 46 | } 47 | } else { 48 | if parseTable { 49 | options.tableNames = append(options.tableNames, strings.Split(arg, ",")...) 
50 | parseTable = false 51 | } else if parseForceCases { 52 | options.forceCases = append(options.forceCases, strings.Split(arg, ",")...) 53 | parseForceCases = false 54 | } else { 55 | args = append(args, arg) 56 | } 57 | } 58 | } 59 | if parseTable || parseForceCases { 60 | // "-t" or "-forcecases" was the last argument, so its value is missing 61 | printUsageAndExit(exampleDataSourceName) 62 | } 63 | 64 | switch len(args) { 65 | case 1: 66 | options.dataSourceName = args[0] 67 | default: 68 | printUsageAndExit(exampleDataSourceName) 69 | } 70 | 71 | return 72 | } 73 | -------------------------------------------------------------------------------- /transaction_test.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "testing" 7 | ) 8 | 9 | type mockTx struct { 10 | isCommitted bool 11 | isRolledBack bool 12 | commitError error 13 | } 14 | 15 | func (m *mockTx) Commit() error { 16 | if m.commitError != nil { 17 | return m.commitError 18 | } 19 | m.isCommitted = true 20 | return nil 21 | } 22 | 23 | func (m *mockTx) Rollback() error { 24 | m.isRolledBack = true 25 | return nil 26 | } 27 | 28 | func TestTransaction(t *testing.T) { 29 | db := newMockDatabase() 30 | err := db.BeginTx(nil, nil, func(tx Transaction) error { 31 | if tx.GetDB() != db.GetDB() { 32 | t.Error() 33 | } 34 | if tx.GetTx() == nil { 35 | t.Error() 36 | } 37 | 38 | _, err := tx.Execute("") 39 | if err != nil { 40 | t.Error(err) 41 | } 42 | return nil 43 | }) 44 | if err != nil { 45 | t.Error(err) 46 | } 47 | if !sharedMockConn.mockTx.isCommitted { 48 | t.Error() 49 | } 50 | if sharedMockConn.mockTx.isRolledBack { 51 | t.Error() 52 | } 53 | 54 | err = db.BeginTx(context.Background(), nil, func(tx Transaction) error { 55 | return errors.New("error") 56 | }) 57 | if err == nil { 58 | t.Error("should get error here") 59 | } 60 | if sharedMockConn.mockTx.isCommitted { 61 | t.Error() 62 | } 63 | if 
!sharedMockConn.mockTx.isRolledBack { 64 | t.Error() 65 
| } 66 | 67 | sharedMockConn.beginTxError = errors.New("error") 68 | err = db.BeginTx(context.Background(), nil, func(tx Transaction) error { 69 | return nil 70 | }) 71 | if err == nil { 72 | t.Error("should get error here") 73 | } 74 | sharedMockConn.beginTxError = nil 75 | 76 | err = db.BeginTx(context.Background(), nil, func(tx Transaction) error { 77 | sharedMockConn.mockTx.commitError = errors.New("error") 78 | return nil 79 | }) 80 | if err == nil { 81 | t.Error("should get error here") 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /update_test.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "testing" 7 | ) 8 | 9 | func TestUpdate(t *testing.T) { 10 | db := newMockDatabase() 11 | 12 | _, _ = db.Update(Table1).Set(field1, field2).Where(True()).Execute() 13 | assertLastSql(t, "UPDATE `table1` SET `field1` = `field2`") 14 | 15 | _, _ = db.Update(Table1). 16 | Set(field1, 10). 17 | Where(field2.Equals(2)). 18 | OrderBy(field1.Desc()). 19 | Limit(2). 20 | Execute() 21 | assertLastSql(t, "UPDATE `table1` SET `field1` = 10 WHERE `field2` = 2 ORDER BY `field1` DESC LIMIT 2") 22 | 23 | _, _ = db.Update(Table1). 24 | SetIf(true, field1, 10). 25 | SetIf(false, field2, 10). 26 | Where(True()). 27 | Execute() 28 | assertLastSql(t, "UPDATE `table1` SET `field1` = 10") 29 | 30 | _, _ = db.Update(Table1). 31 | SetIf(false, field1, 10). 32 | Where(True()). 33 | Execute() 34 | assertLastSql(t, "/* UPDATE without SET clause */ DO 0") 35 | 36 | _, _ = db.Update(Table1).Limit(3).Execute() 37 | assertLastSql(t, "/* UPDATE without SET clause */ DO 0") 38 | 39 | errExp := &expression{ 40 | builder: func(scope scope) (string, error) { 41 | return "", errors.New("error") 42 | }, 43 | } 44 | 45 | if _, err := db.Update(Table1). 46 | Set(field1, 10). 47 | OrderBy(orderBy{by: errExp}). 
48 | Execute(); err == nil { 49 | t.Error("should get error here") 50 | } 51 | 52 | if _, err := db.Update(Table1). 53 | Set(field1, errExp). 54 | Where(True()). 55 | Execute(); err == nil { 56 | t.Error("should get error here") 57 | } 58 | 59 | if _, err := db.Update(Table1). 60 | Set(field1, 10). 61 | Where(errExp). 62 | Execute(); err == nil { 63 | t.Error("should get error here") 64 | } 65 | 66 | if _, err := db.Update(Table1). 67 | Set(field1, 10). 68 | Where(True()). 69 | WithContext(context.Background()). 70 | Execute(); err != nil { 71 | t.Error(err) 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /case.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import "strings" 4 | 5 | // CaseExpression indicates the status in a CASE statement 6 | type CaseExpression interface { 7 | WhenThen(when BooleanExpression, then interface{}) CaseExpression 8 | Else(value interface{}) CaseExpressionWithElse 9 | End() Expression 10 | } 11 | 12 | // CaseExpressionWithElse indicates the status in CASE ... ELSE ...
statement 13 | type CaseExpressionWithElse interface { 14 | End() Expression 15 | } 16 | 17 | type caseStatus struct { 18 | head, tail *whenThen 19 | elseValue interface{} 20 | } 21 | 22 | type whenThen struct { 23 | next *whenThen 24 | when BooleanExpression 25 | then interface{} 26 | } 27 | 28 | // Case initiates a CASE statement 29 | func Case() CaseExpression { 30 | return caseStatus{} 31 | } 32 | 33 | func (s caseStatus) WhenThen(when BooleanExpression, then interface{}) CaseExpression { 34 | whenThen := &whenThen{when: when, then: then} 35 | if s.head == nil { 36 | s.head = whenThen 37 | } 38 | if s.tail != nil { 39 | s.tail.next = whenThen 40 | } 41 | s.tail = whenThen 42 | return s 43 | } 44 | 45 | func (s caseStatus) Else(value interface{}) CaseExpressionWithElse { 46 | s.elseValue = value 47 | return s 48 | } 49 | 50 | func (s caseStatus) End() Expression { 51 | if s.head == nil { 52 | return expression{ 53 | builder: func(scope scope) (string, error) { 54 | elseSql, _, err := getSQL(scope, s.elseValue) 55 | return elseSql, err 56 | }, 57 | } 58 | } 59 | 60 | return expression{ 61 | builder: func(scope scope) (string, error) { 62 | sb := strings.Builder{} 63 | sb.WriteString("CASE ") 64 | 65 | for whenThen := s.head; whenThen != nil; whenThen = whenThen.next { 66 | whenSql, err := whenThen.when.GetSQL(scope) 67 | if err != nil { 68 | return "", err 69 | } 70 | thenSql, _, err := getSQL(scope, whenThen.then) 71 | if err != nil { 72 | return "", err 73 | } 74 | sb.WriteString("WHEN " + whenSql + " THEN " + thenSql + " ") 75 | } 76 | if s.elseValue != nil { 77 | elseSql, _, err := getSQL(scope, s.elseValue) 78 | if err != nil { 79 | return "", err 80 | } 81 | sb.WriteString("ELSE " + elseSql + " ") 82 | } 83 | sb.WriteString("END") 84 | return sb.String(), nil 85 | }, 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /common_test.go: 
-------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import ( 4 | "errors" 5 | "strings" 6 | "testing" 7 | ) 8 | 9 | func TestCommon(t *testing.T) { 10 | db := newMockDatabase() 11 | 12 | dummyExp1 := expression{sql: ""} 13 | dummyExp2 := expression{sql: ""} 14 | errExp := expression{builder: func(scope scope) (string, error) { 15 | return "", errors.New("error") 16 | }} 17 | assertValue(t, &assignment{ 18 | field: dummyExp1, 19 | value: dummyExp2, 20 | }, " = ") 21 | assertError(t, &assignment{ 22 | field: errExp, 23 | value: dummyExp2, 24 | }) 25 | assertError(t, &assignment{ 26 | field: dummyExp1, 27 | value: errExp, 28 | }) 29 | assertError(t, &assignment{ 30 | field: errExp, 31 | value: errExp, 32 | }) 33 | assertError(t, command("COMMAND", errExp)) 34 | 35 | sql, err := commaExpressions(scope{}, []Expression{dummyExp1, dummyExp2, dummyExp1}) 36 | if err != nil { 37 | t.Error(err) 38 | } 39 | if sql != ", , " { 40 | t.Error() 41 | } 42 | 43 | _, err = commaExpressions(scope{}, []Expression{dummyExp1, dummyExp2, errExp}) 44 | if err == nil { 45 | t.Error("should get error") 46 | } 47 | 48 | sql, err = commaAssignments(scope{}, []assignment{ 49 | {field: dummyExp1, value: dummyExp1}, 50 | {field: dummyExp1, value: dummyExp2}, 51 | {field: dummyExp2, value: dummyExp2}, 52 | }) 53 | if err != nil { 54 | t.Error(err) 55 | } 56 | if sql != " = , = , = " { 57 | t.Error() 58 | } 59 | _, err = commaAssignments(scope{}, []assignment{ 60 | {field: dummyExp1, value: dummyExp1}, 61 | {field: dummyExp1, value: errExp}, 62 | }) 63 | if err == nil { 64 | t.Error("should get error") 65 | } 66 | 67 | sql, err = commaOrderBys(scope{}, []OrderBy{ 68 | orderBy{by: dummyExp1, desc: true}, 69 | orderBy{by: dummyExp2}, 70 | }) 71 | if err != nil { 72 | t.Error(err) 73 | } 74 | if sql != " DESC, " { 75 | t.Error() 76 | } 77 | 78 | _, err = commaOrderBys(scope{}, []OrderBy{ 79 | orderBy{by: errExp}, 80 | }) 81 | if err == nil { 82 
| t.Error("should get error") 83 | } 84 | 85 | db.EnableCallerInfo(true) 86 | if _, err := db.Select(1).FetchFirst(); err != nil { 87 | t.Error(err) 88 | } 89 | if !strings.HasPrefix(sharedMockConn.lastSql, "/* ") { 90 | t.Error() 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /value.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import ( 4 | "math" 5 | "strconv" 6 | ) 7 | 8 | const ( 9 | maxInt = 1<<(strconv.IntSize-1) - 1 10 | minInt = -1 << (strconv.IntSize - 1) 11 | maxUint = 1<= minInt && r <= maxInt { 42 | return int(r) 43 | } 44 | return 0 45 | } 46 | 47 | func (v value) Int8() int8 { 48 | if r := v.Int64(); r >= math.MinInt8 && r <= math.MaxInt8 { 49 | return int8(r) 50 | } 51 | return 0 52 | } 53 | 54 | func (v value) Int16() int16 { 55 | if r := v.Int64(); r >= math.MinInt16 && r <= math.MaxInt16 { 56 | return int16(r) 57 | } 58 | return 0 59 | } 60 | 61 | func (v value) Int32() int32 { 62 | if r := v.Int64(); r >= math.MinInt32 && r <= math.MaxInt32 { 63 | return int32(r) 64 | } 65 | return 0 66 | } 67 | 68 | func (v value) Uint() uint { 69 | if r := v.Uint64(); r <= maxUint { 70 | return uint(r) 71 | } 72 | return 0 73 | } 74 | 75 | func (v value) Uint8() uint8 { 76 | if r := v.Uint64(); r <= math.MaxUint8 { 77 | return uint8(r) 78 | } 79 | return 0 80 | } 81 | 82 | func (v value) Uint16() uint16 { 83 | if r := v.Uint64(); r <= math.MaxUint16 { 84 | return uint16(r) 85 | } 86 | return 0 87 | } 88 | 89 | func (v value) Uint32() uint32 { 90 | if r := v.Uint64(); r <= math.MaxUint32 { 91 | return uint32(r) 92 | } 93 | return 0 94 | } 95 | 96 | func (v value) Bool() bool { 97 | if v.stringValue == nil { 98 | return false 99 | } 100 | switch *v.stringValue { 101 | case "", "0", "\x00": 102 | return false 103 | default: 104 | return true 105 | } 106 | } 107 | 108 | func (v value) String() string { 109 | if v.stringValue == nil { 110 | return 
"" 111 | } 112 | return *v.stringValue 113 | } 114 | 115 | func (v value) IsNull() bool { 116 | return v.stringValue == nil 117 | } 118 | -------------------------------------------------------------------------------- /generator/fetcher_mysql.go: -------------------------------------------------------------------------------- 1 | package generator 2 | 3 | import ( 4 | "database/sql" 5 | "regexp" 6 | "strconv" 7 | ) 8 | 9 | var timeAsString = false 10 | 11 | type mysqlSchemaFetcher struct { 12 | db *sql.DB 13 | } 14 | 15 | func (m mysqlSchemaFetcher) GetDatabaseName() (dbName string, err error) { 16 | row := m.db.QueryRow("SELECT DATABASE()") 17 | err = row.Scan(&dbName) 18 | return 19 | } 20 | 21 | func (m mysqlSchemaFetcher) GetTableNames() (tableNames []string, err error) { 22 | rows, err := m.db.Query("SHOW TABLES") 23 | if err != nil { 24 | return 25 | } 26 | defer rows.Close() 27 | 28 | for rows.Next() { 29 | var name string 30 | err = rows.Scan(&name) 31 | if err != nil { 32 | return 33 | } 34 | tableNames = append(tableNames, name) 35 | } 36 | return 37 | } 38 | 39 | func (m mysqlSchemaFetcher) GetFieldDescriptors(tableName string) ([]fieldDescriptor, error) { 40 | rows, err := m.db.Query("SHOW FULL COLUMNS FROM `" + tableName + "`") 41 | if err != nil { 42 | return nil, err 43 | } 44 | 45 | var result []fieldDescriptor 46 | for rows.Next() { 47 | columns, err := rows.Columns() 48 | if err != nil { 49 | return nil, err 50 | } 51 | var pointers []interface{} 52 | for i := 0; i < len(columns); i++ { 53 | var value *string 54 | pointers = append(pointers, &value) 55 | } 56 | err = rows.Scan(pointers...) 
57 | if err != nil { 58 | return nil, err 59 | } 60 | row := make(map[string]string) 61 | for i, column := range columns { 62 | pointer := *pointers[i].(**string) 63 | if pointer != nil { 64 | row[column] = *pointer 65 | } 66 | } 67 | 68 | r, _ := regexp.Compile("([a-z]+)(\\(([0-9]+)\\))?( ([a-z]+))?") 69 | submatches := r.FindStringSubmatch(row["Type"]) 70 | 71 | fieldType := submatches[1] 72 | fieldSize := 0 73 | if submatches[3] != "" { 74 | fieldSize, err = strconv.Atoi(submatches[3]) 75 | if err != nil { 76 | return nil, err 77 | } 78 | } 79 | unsigned := submatches[5] == "unsigned" 80 | 81 | result = append(result, fieldDescriptor{ 82 | Name: row["Field"], 83 | Type: fieldType, 84 | Size: fieldSize, 85 | Unsigned: unsigned, 86 | AllowNull: row["Null"] == "YES", 87 | Comment: row["Comment"], 88 | }) 89 | } 90 | return result, nil 91 | } 92 | 93 | func (m mysqlSchemaFetcher) QuoteIdentifier(identifier string) string { 94 | return "`" + identifier + "`" 95 | } 96 | 97 | func newMySQLSchemaFetcher(db *sql.DB) schemaFetcher { 98 | return mysqlSchemaFetcher{db: db} 99 | } 100 | -------------------------------------------------------------------------------- /delete.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "strconv" 7 | "strings" 8 | ) 9 | 10 | type deleteStatus struct { 11 | scope scope 12 | where BooleanExpression 13 | orderBys []OrderBy 14 | limit *int 15 | ctx context.Context 16 | } 17 | 18 | type deleteWithTable interface { 19 | Where(conditions ...BooleanExpression) deleteWithWhere 20 | } 21 | 22 | type deleteWithWhere interface { 23 | toDeleteWithContext 24 | toDeleteFinal 25 | OrderBy(orderBys ...OrderBy) deleteWithOrder 26 | Limit(limit int) deleteWithLimit 27 | } 28 | 29 | type deleteWithOrder interface { 30 | toDeleteWithContext 31 | toDeleteFinal 32 | Limit(limit int) deleteWithLimit 33 | } 34 | 35 | type deleteWithLimit interface { 36 | 
toDeleteWithContext 37 | toDeleteFinal 38 | } 39 | 40 | type toDeleteWithContext interface { 41 | WithContext(ctx context.Context) toDeleteFinal 42 | } 43 | 44 | type toDeleteFinal interface { 45 | GetSQL() (string, error) 46 | Execute() (result sql.Result, err error) 47 | } 48 | 49 | func (d *database) DeleteFrom(table Table) deleteWithTable { 50 | return deleteStatus{scope: scope{Database: d, Tables: []Table{table}}} 51 | } 52 | 53 | func (s deleteStatus) Where(conditions ...BooleanExpression) deleteWithWhere { 54 | s.where = And(conditions...) 55 | return s 56 | } 57 | 58 | func (s deleteStatus) OrderBy(orderBys ...OrderBy) deleteWithOrder { 59 | s.orderBys = orderBys 60 | return s 61 | } 62 | 63 | func (s deleteStatus) Limit(limit int) deleteWithLimit { 64 | s.limit = &limit 65 | return s 66 | } 67 | 68 | func (s deleteStatus) GetSQL() (string, error) { 69 | var sb strings.Builder 70 | sb.Grow(128) 71 | 72 | sb.WriteString("DELETE FROM ") 73 | sb.WriteString(s.scope.Tables[0].GetSQL(s.scope)) 74 | 75 | if err := appendWhere(&sb, s.scope, s.where); err != nil { 76 | return "", err 77 | } 78 | 79 | if len(s.orderBys) > 0 { 80 | orderBySql, err := commaOrderBys(s.scope, s.orderBys) 81 | if err != nil { 82 | return "", err 83 | } 84 | sb.WriteString(" ORDER BY ") 85 | sb.WriteString(orderBySql) 86 | } 87 | 88 | if s.limit != nil { 89 | sb.WriteString(" LIMIT ") 90 | sb.WriteString(strconv.Itoa(*s.limit)) 91 | } 92 | 93 | return sb.String(), nil 94 | } 95 | 96 | func (s deleteStatus) WithContext(ctx context.Context) toDeleteFinal { 97 | s.ctx = ctx 98 | return s 99 | } 100 | 101 | func (s deleteStatus) Execute() (sql.Result, error) { 102 | sqlString, err := s.GetSQL() 103 | if err != nil { 104 | return nil, err 105 | } 106 | return s.scope.Database.ExecuteContext(s.ctx, sqlString) 107 | } 108 | -------------------------------------------------------------------------------- /database_test.go: 
-------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "database/sql/driver" 7 | "errors" 8 | "testing" 9 | "time" 10 | ) 11 | 12 | func (m *mockConn) Prepare(query string) (driver.Stmt, error) { 13 | if m.prepareError != nil { 14 | return nil, m.prepareError 15 | } 16 | m.lastSql = query 17 | return &mockStmt{ 18 | columnCount: m.columnCount, 19 | rowCount: m.rowCount, 20 | }, nil 21 | } 22 | 23 | func (m mockConn) Close() error { 24 | return nil 25 | } 26 | 27 | func (m *mockConn) Begin() (driver.Tx, error) { 28 | if m.beginTxError != nil { 29 | return nil, m.beginTxError 30 | } 31 | m.mockTx = &mockTx{} 32 | return m.mockTx, nil 33 | } 34 | 35 | var sharedMockConn = &mockConn{ 36 | columnCount: 11, 37 | rowCount: 10, 38 | } 39 | 40 | func (m mockDriver) Open(name string) (driver.Conn, error) { 41 | return sharedMockConn, nil 42 | } 43 | 44 | func newMockDatabase() Database { 45 | db, err := Open("sqlingo-mock", "dummy") 46 | if err != nil { 47 | panic(err) 48 | } 49 | db.(*database).dialect = dialectMySQL 50 | return db 51 | } 52 | 53 | func init() { 54 | sql.Register("sqlingo-mock", &mockDriver{}) 55 | } 56 | 57 | func TestDatabase(t *testing.T) { 58 | if _, err := Open("unknowndb", "unknown"); err == nil { 59 | t.Error() 60 | } 61 | 62 | db := newMockDatabase() 63 | if db.GetDB() == nil { 64 | t.Error() 65 | } 66 | 67 | interceptorExecutedCount := 0 68 | loggerExecutedCount := 0 69 | db.SetInterceptor(func(ctx context.Context, sql string, invoker InvokerFunc) error { 70 | if sql != "SELECT 1" { 71 | t.Error() 72 | } 73 | interceptorExecutedCount++ 74 | return invoker(ctx, sql) 75 | }) 76 | db.SetLogger(func(sql string, _ time.Duration, _, _ bool) { 77 | if sql != "SELECT 1" { 78 | t.Error(sql) 79 | } 80 | loggerExecutedCount++ 81 | }) 82 | _, _ = db.Query("SELECT 1") 83 | if interceptorExecutedCount != 1 || loggerExecutedCount != 1 { 84 | 
t.Error(interceptorExecutedCount, loggerExecutedCount) 85 | } 86 | 87 | _, _ = db.Execute("SELECT 1") 88 | if loggerExecutedCount != 2 { 89 | t.Error(loggerExecutedCount) 90 | } 91 | 92 | db.SetInterceptor(func(ctx context.Context, sql string, invoker InvokerFunc) error { 93 | return errors.New("error") 94 | }) 95 | if _, err := db.Query("SELECT 1"); err == nil { 96 | t.Error("should get error here") 97 | } 98 | } 99 | 100 | func TestDatabaseRetry(t *testing.T) { 101 | db := newMockDatabase() 102 | retryCount := 0 103 | db.SetRetryPolicy(func(err error) bool { 104 | retryCount++ 105 | return retryCount < 10 106 | }) 107 | 108 | sharedMockConn.prepareError = errors.New("error") 109 | if _, err := db.Query("SELECT 1"); err == nil { 110 | t.Error("should get error here") 111 | } 112 | if retryCount != 10 { 113 | t.Error(retryCount) 114 | } 115 | sharedMockConn.prepareError = nil 116 | } 117 | -------------------------------------------------------------------------------- /value_test.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import "testing" 4 | 5 | func newValue(s string) value { 6 | return value{stringValue: &s} 7 | } 8 | 9 | func TestValue(t *testing.T) { 10 | value := newValue("42") 11 | if value.String() != "42" { 12 | t.Error() 13 | } 14 | if !value.Bool() { 15 | t.Error() 16 | } 17 | if value.Int() != 42 { 18 | t.Error() 19 | } 20 | if value.Int8() != 42 { 21 | t.Error() 22 | } 23 | if value.Int16() != 42 { 24 | t.Error() 25 | } 26 | if value.Int32() != 42 { 27 | t.Error() 28 | } 29 | if value.Int64() != 42 { 30 | t.Error() 31 | } 32 | 33 | if value.Uint() != 42 { 34 | t.Error() 35 | } 36 | if value.Uint8() != 42 { 37 | t.Error() 38 | } 39 | if value.Uint16() != 42 { 40 | t.Error() 41 | } 42 | if value.Uint32() != 42 { 43 | t.Error() 44 | } 45 | if value.Uint64() != 42 { 46 | t.Error() 47 | } 48 | } 49 | 50 | func TestValueOverflow1(t *testing.T) { 51 | value := newValue("3000000000") 52 | 
if value.Int() != 3000000000 { 53 | t.Error() 54 | } 55 | if value.Int8() != 0 { 56 | t.Error() 57 | } 58 | if value.Int16() != 0 { 59 | t.Error() 60 | } 61 | if value.Int32() != 0 { 62 | t.Error() 63 | } 64 | if value.Int64() != 3000000000 { 65 | t.Error() 66 | } 67 | 68 | if value.Uint() != 3000000000 { 69 | t.Error() 70 | } 71 | if value.Uint8() != 0 { 72 | t.Error() 73 | } 74 | if value.Uint16() != 0 { 75 | t.Error() 76 | } 77 | if value.Uint32() != 3000000000 { 78 | t.Error() 79 | } 80 | if value.Uint64() != 3000000000 { 81 | t.Error() 82 | } 83 | } 84 | 85 | func TestValueOverflow2(t *testing.T) { 86 | value := newValue("3000000000000000") 87 | if value.Int() != 3000000000000000 { 88 | t.Error() 89 | } 90 | if value.Int8() != 0 { 91 | t.Error() 92 | } 93 | if value.Int16() != 0 { 94 | t.Error() 95 | } 96 | if value.Int32() != 0 { 97 | t.Error() 98 | } 99 | if value.Int64() != 3000000000000000 { 100 | t.Error() 101 | } 102 | 103 | if value.Uint() != 3000000000000000 { 104 | t.Error() 105 | } 106 | if value.Uint8() != 0 { 107 | t.Error() 108 | } 109 | if value.Uint16() != 0 { 110 | t.Error() 111 | } 112 | if value.Uint32() != 0 { 113 | t.Error() 114 | } 115 | if value.Uint64() != 3000000000000000 { 116 | t.Error() 117 | } 118 | } 119 | 120 | func TestValueOverflow3(t *testing.T) { 121 | value := newValue("10000000000000000000") 122 | if value.Int() != 0 { 123 | t.Error(value.Int()) 124 | } 125 | if value.Uint() != 10000000000000000000 { 126 | t.Error() 127 | } 128 | } 129 | 130 | func TestValueBool(t *testing.T) { 131 | if !newValue("1").Bool() { 132 | t.Error() 133 | } 134 | if newValue("0").Bool() { 135 | t.Error() 136 | } 137 | if newValue("").Bool() { 138 | t.Error() 139 | } 140 | if (value{}).Bool() { 141 | t.Error() 142 | } 143 | } 144 | 145 | func TestValueNull(t *testing.T) { 146 | value := value{} 147 | if value.String() != "" { 148 | t.Error() 149 | } 150 | if value.Int() != 0 { 151 | t.Error() 152 | } 153 | if value.Uint() != 0 { 154 | t.Error() 155 
| } 156 | } 157 | -------------------------------------------------------------------------------- /update.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "strconv" 7 | "strings" 8 | ) 9 | 10 | type updateStatus struct { 11 | scope scope 12 | assignments []assignment 13 | where BooleanExpression 14 | orderBys []OrderBy 15 | limit *int 16 | ctx context.Context 17 | } 18 | 19 | func (d *database) Update(table Table) updateWithSet { 20 | return updateStatus{scope: scope{Database: d, Tables: []Table{table}}} 21 | } 22 | 23 | type updateWithSet interface { 24 | Set(Field Field, value interface{}) updateWithSet 25 | SetIf(prerequisite bool, Field Field, value interface{}) updateWithSet 26 | Where(conditions ...BooleanExpression) updateWithWhere 27 | OrderBy(orderBys ...OrderBy) updateWithOrder 28 | Limit(limit int) updateWithLimit 29 | } 30 | 31 | type updateWithWhere interface { 32 | toUpdateWithContext 33 | toUpdateFinal 34 | OrderBy(orderBys ...OrderBy) updateWithOrder 35 | Limit(limit int) updateWithLimit 36 | } 37 | 38 | type updateWithOrder interface { 39 | toUpdateWithContext 40 | toUpdateFinal 41 | Limit(limit int) updateWithLimit 42 | } 43 | 44 | type updateWithLimit interface { 45 | toUpdateWithContext 46 | toUpdateFinal 47 | } 48 | 49 | type toUpdateWithContext interface { 50 | WithContext(ctx context.Context) toUpdateFinal 51 | } 52 | 53 | type toUpdateFinal interface { 54 | GetSQL() (string, error) 55 | Execute() (sql.Result, error) 56 | } 57 | 58 | func (s updateStatus) Set(field Field, value interface{}) updateWithSet { 59 | s.assignments = append([]assignment{}, s.assignments...) 
60 | s.assignments = append(s.assignments, assignment{ 61 | field: field, 62 | value: value, 63 | }) 64 | return s 65 | } 66 | 67 | func (s updateStatus) SetIf(prerequisite bool, field Field, value interface{}) updateWithSet { 68 | if prerequisite { 69 | return s.Set(field, value) 70 | } 71 | return s 72 | } 73 | 74 | func (s updateStatus) Where(conditions ...BooleanExpression) updateWithWhere { 75 | s.where = And(conditions...) 76 | return s 77 | } 78 | 79 | func (s updateStatus) OrderBy(orderBys ...OrderBy) updateWithOrder { 80 | s.orderBys = orderBys 81 | return s 82 | } 83 | 84 | func (s updateStatus) Limit(limit int) updateWithLimit { 85 | s.limit = &limit 86 | return s 87 | } 88 | 89 | func (s updateStatus) GetSQL() (string, error) { 90 | if len(s.assignments) == 0 { 91 | return "/* UPDATE without SET clause */ DO 0", nil 92 | } 93 | var sb strings.Builder 94 | sb.Grow(128) 95 | 96 | sb.WriteString("UPDATE ") 97 | sb.WriteString(s.scope.Tables[0].GetSQL(s.scope)) 98 | 99 | assignmentsSql, err := commaAssignments(s.scope, s.assignments) 100 | if err != nil { 101 | return "", err 102 | } 103 | sb.WriteString(" SET ") 104 | sb.WriteString(assignmentsSql) 105 | 106 | if err := appendWhere(&sb, s.scope, s.where); err != nil { 107 | return "", err 108 | } 109 | 110 | if len(s.orderBys) > 0 { 111 | orderBySql, err := commaOrderBys(s.scope, s.orderBys) 112 | if err != nil { 113 | return "", err 114 | } 115 | sb.WriteString(" ORDER BY ") 116 | sb.WriteString(orderBySql) 117 | } 118 | 119 | if s.limit != nil { 120 | sb.WriteString(" LIMIT ") 121 | sb.WriteString(strconv.Itoa(*s.limit)) 122 | } 123 | 124 | return sb.String(), nil 125 | } 126 | 127 | func (s updateStatus) WithContext(ctx context.Context) toUpdateFinal { 128 | s.ctx = ctx 129 | return s 130 | } 131 | 132 | func (s updateStatus) Execute() (sql.Result, error) { 133 | sqlString, err := s.GetSQL() 134 | if err != nil { 135 | return nil, err 136 | } 137 | return s.scope.Database.ExecuteContext(s.ctx, 
sqlString) 138 | } 139 | -------------------------------------------------------------------------------- /field.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import "strings" 4 | 5 | // Field is the interface of a generated field. 6 | type Field interface { 7 | Expression 8 | GetTable() Table 9 | } 10 | 11 | // NumberField is the interface of a generated field of number type. 12 | type NumberField interface { 13 | NumberExpression 14 | GetTable() Table 15 | } 16 | 17 | // BooleanField is the interface of a generated field of boolean type. 18 | type BooleanField interface { 19 | BooleanExpression 20 | GetTable() Table 21 | } 22 | 23 | // StringField is the interface of a generated field of string type. 24 | type StringField interface { 25 | StringExpression 26 | GetTable() Table 27 | } 28 | 29 | // ArrayField is the interface of a generated field of array type. 30 | type ArrayField interface { 31 | ArrayExpression 32 | IndexOf(index int) Field 33 | GetTable() Table 34 | } 35 | 36 | type DateField interface { 37 | DateExpression 38 | GetTable() Table 39 | } 40 | 41 | type actualField struct { 42 | expression 43 | table Table 44 | } 45 | 46 | func (f actualField) GetTable() Table { 47 | return f.table 48 | } 49 | 50 | func newField(table Table, fieldName string) actualField { 51 | tableName := table.GetName() 52 | tableNameSqlArray := quoteIdentifier(tableName) 53 | fieldNameSqlArray := quoteIdentifier(fieldName) 54 | 55 | var fullFieldNameSqlArray dialectArray 56 | for dialect := dialect(0); dialect < dialectCount; dialect++ { 57 | fullFieldNameSqlArray[dialect] = tableNameSqlArray[dialect] + "." 
+ fieldNameSqlArray[dialect] 58 | } 59 | 60 | return actualField{ 61 | expression: expression{ 62 | builder: func(scope scope) (string, error) { 63 | dialect := dialectUnknown 64 | if scope.Database != nil { 65 | dialect = scope.Database.dialect 66 | } 67 | if len(scope.Tables) != 1 || scope.lastJoin != nil || scope.Tables[0].GetName() != tableName { 68 | return fullFieldNameSqlArray[dialect], nil 69 | } 70 | return fieldNameSqlArray[dialect], nil 71 | }, 72 | }, 73 | table: table, 74 | } 75 | } 76 | 77 | // NewNumberField creates a reference to a number field. It should only be called from generated code. 78 | func NewNumberField(table Table, fieldName string) NumberField { 79 | return newField(table, fieldName) 80 | } 81 | 82 | // NewBooleanField creates a reference to a boolean field. It should only be called from generated code. 83 | func NewBooleanField(table Table, fieldName string) BooleanField { 84 | return newField(table, fieldName) 85 | } 86 | 87 | // NewStringField creates a reference to a string field. It should only be called from generated code. 88 | func NewStringField(table Table, fieldName string) StringField { 89 | return newField(table, fieldName) 90 | } 91 | 92 | // NewDateField creates a reference to a time.Time field. It should only be called from generated code. 
93 | func NewDateField(table Table, fieldName string) DateField { 94 | return newField(table, fieldName) 95 | } 96 | 97 | type fieldList []Field 98 | 99 | func (fields fieldList) GetSQL(scope scope) (string, error) { 100 | isSingleTable := len(scope.Tables) == 1 && scope.lastJoin == nil 101 | var sb strings.Builder 102 | if len(fields) == 0 { 103 | for i, table := range scope.Tables { 104 | if i > 0 { 105 | sb.WriteString(", ") 106 | } 107 | actualTable, ok := table.(actualTable) 108 | if ok { 109 | if isSingleTable { 110 | sb.WriteString(actualTable.GetFieldsSQL()) 111 | } else { 112 | sb.WriteString(actualTable.GetFullFieldsSQL()) 113 | } 114 | } else { 115 | sb.WriteByte('*') 116 | } 117 | } 118 | } else { 119 | fieldsSql, err := commaFields(scope, fields) 120 | if err != nil { 121 | return "", err 122 | } 123 | sb.WriteString(fieldsSql) 124 | } 125 | return sb.String(), nil 126 | } 127 | -------------------------------------------------------------------------------- /common.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import ( 4 | "fmt" 5 | "runtime" 6 | "strings" 7 | ) 8 | 9 | const ( 10 | // SqlingoRuntimeVersion is the the runtime version of sqlingo 11 | SqlingoRuntimeVersion = 2 12 | ) 13 | 14 | // Model is the interface of generated model struct 15 | type Model interface { 16 | GetTable() Table 17 | GetValues() []interface{} 18 | } 19 | 20 | // Assignment is an assignment statement 21 | type Assignment interface { 22 | GetSQL(scope scope) (string, error) 23 | } 24 | 25 | type assignment struct { 26 | Assignment 27 | field Field 28 | value interface{} 29 | } 30 | 31 | func (a assignment) GetSQL(scope scope) (string, error) { 32 | value, _, err := getSQL(scope, a.value) 33 | if err != nil { 34 | return "", err 35 | } 36 | fieldSql, err := a.field.GetSQL(scope) 37 | if err != nil { 38 | return "", err 39 | } 40 | return fieldSql + " = " + value, nil 41 | } 42 | 43 | func command(name string, 
arg interface{}) expression { 44 | return expression{builder: func(scope scope) (string, error) { 45 | sql, _, err := getSQL(scope, arg) 46 | if err != nil { 47 | return "", err 48 | } 49 | return name + " " + sql, nil 50 | }} 51 | } 52 | 53 | func commaFields(scope scope, fields []Field) (string, error) { 54 | var sqlBuilder strings.Builder 55 | sqlBuilder.Grow(128) 56 | for i, item := range fields { 57 | if i > 0 { 58 | sqlBuilder.WriteString(", ") 59 | } 60 | itemSql, err := item.GetSQL(scope) 61 | if err != nil { 62 | return "", err 63 | } 64 | sqlBuilder.WriteString(itemSql) 65 | } 66 | return sqlBuilder.String(), nil 67 | } 68 | 69 | func commaExpressions(scope scope, expressions []Expression) (string, error) { 70 | var sqlBuilder strings.Builder 71 | sqlBuilder.Grow(128) 72 | for i, item := range expressions { 73 | if i > 0 { 74 | sqlBuilder.WriteString(", ") 75 | } 76 | itemSql, err := item.GetSQL(scope) 77 | if err != nil { 78 | return "", err 79 | } 80 | sqlBuilder.WriteString(itemSql) 81 | } 82 | return sqlBuilder.String(), nil 83 | } 84 | 85 | func commaTables(scope scope, tables []Table) string { 86 | var sqlBuilder strings.Builder 87 | sqlBuilder.Grow(32) 88 | for i, table := range tables { 89 | if i > 0 { 90 | sqlBuilder.WriteString(", ") 91 | } 92 | sqlBuilder.WriteString(table.GetSQL(scope)) 93 | } 94 | return sqlBuilder.String() 95 | } 96 | 97 | func commaValues(scope scope, values []interface{}) (string, error) { 98 | var sqlBuilder strings.Builder 99 | for i, item := range values { 100 | if i > 0 { 101 | sqlBuilder.WriteString(", ") 102 | } 103 | itemSql, _, err := getSQL(scope, item) 104 | if err != nil { 105 | return "", err 106 | } 107 | sqlBuilder.WriteString(itemSql) 108 | } 109 | return sqlBuilder.String(), nil 110 | } 111 | 112 | func commaAssignments(scope scope, assignments []assignment) (string, error) { 113 | var sqlBuilder strings.Builder 114 | for i, item := range assignments { 115 | if i > 0 { 116 | sqlBuilder.WriteString(", ") 117 
| } 118 | itemSql, err := item.GetSQL(scope) 119 | if err != nil { 120 | return "", err 121 | } 122 | sqlBuilder.WriteString(itemSql) 123 | } 124 | return sqlBuilder.String(), nil 125 | } 126 | 127 | func commaOrderBys(scope scope, orderBys []OrderBy) (string, error) { 128 | var sqlBuilder strings.Builder 129 | for i, item := range orderBys { 130 | if i > 0 { 131 | sqlBuilder.WriteString(", ") 132 | } 133 | itemSql, err := item.GetSQL(scope) 134 | if err != nil { 135 | return "", err 136 | } 137 | sqlBuilder.WriteString(itemSql) 138 | } 139 | return sqlBuilder.String(), nil 140 | } 141 | 142 | func getCallerInfo(db database, retry bool) string { 143 | if !db.enableCallerInfo { 144 | return "" 145 | } 146 | extraInfo := "" 147 | if db.tx != nil { 148 | extraInfo += " (tx)" 149 | } 150 | if retry { 151 | extraInfo += " (retry)" 152 | } 153 | for i := 0; true; i++ { 154 | _, file, line, ok := runtime.Caller(i) 155 | if !ok { 156 | break 157 | } 158 | if file == "" || strings.Contains(file, "/sqlingo@v") { 159 | continue 160 | } 161 | segs := strings.Split(file, "/") 162 | name := segs[len(segs)-1] 163 | return fmt.Sprintf("/* %s:%d%s */ ", name, line, extraInfo) 164 | } 165 | return "" 166 | } 167 | -------------------------------------------------------------------------------- /insert_test.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "testing" 7 | ) 8 | 9 | type tTest struct { 10 | Table 11 | 12 | F1 fTestF1 13 | F2 fTestF2 14 | } 15 | 16 | func (t tTest) GetFields() []Field { 17 | return []Field{t.F1, t.F2} 18 | } 19 | 20 | type fTestF1 struct{ NumberField } 21 | type fTestF2 struct{ StringField } 22 | 23 | type TestModel struct { 24 | F1 int64 25 | F2 string 26 | } 27 | 28 | var tTestTable = NewTable("test") 29 | 30 | var Test = tTest{ 31 | Table: NewTable("test"), 32 | F1: fTestF1{NewNumberField(tTestTable, "f1")}, 33 | F2: 
fTestF2{NewStringField(tTestTable, "f2")}, 34 | } 35 | 36 | func (m TestModel) GetTable() Table { 37 | return Test 38 | } 39 | 40 | func (m TestModel) GetValues() []interface{} { 41 | return []interface{}{m.F1, m.F2} 42 | } 43 | 44 | func TestInsert(t *testing.T) { 45 | db := newMockDatabase() 46 | 47 | if _, err := db.InsertInto(Table1).Fields(field1). 48 | Values(1). 49 | Values(2). 50 | OnDuplicateKeyUpdate().Set(field1, 10). 51 | Execute(); err != nil { 52 | t.Error(err) 53 | } 54 | assertLastSql(t, "INSERT INTO `table1` (`field1`)"+ 55 | " VALUES (1), (2)"+ 56 | " ON DUPLICATE KEY UPDATE `field1` = 10") 57 | 58 | if _, err := db.InsertInto(Table1).Fields(field1). 59 | Values(1). 60 | OnDuplicateKeyUpdate(). 61 | SetIf(false, field1, 2). 62 | Execute(); err != nil { 63 | t.Error(err) 64 | } 65 | assertLastSql(t, "INSERT INTO `table1` (`field1`)"+ 66 | " VALUES (1)") 67 | 68 | if _, err := db.InsertInto(Table1).Fields(field1). 69 | Values(0). 70 | OnDuplicateKeyUpdate(). 71 | SetIf(true, field1, 1). 72 | Execute(); err != nil { 73 | t.Error(err) 74 | } 75 | assertLastSql(t, "INSERT INTO `table1` (`field1`)"+ 76 | " VALUES (0)"+ 77 | " ON DUPLICATE KEY UPDATE `field1` = 1") 78 | 79 | if _, err := db.InsertInto(Table1). 80 | Fields(field1, field2). 81 | Values(1, 2). 82 | OnDuplicateKeyUpdate(). 83 | SetIf(false, field1, 10). 84 | SetIf(true, field2, 20). 85 | Execute(); err != nil { 86 | t.Error(err) 87 | } 88 | assertLastSql(t, "INSERT INTO `table1` (`field1`, `field2`)"+ 89 | " VALUES (1, 2)"+ 90 | " ON DUPLICATE KEY UPDATE `field2` = 20") 91 | 92 | if _, err := db.InsertInto(Table1).Fields(field1). 93 | Values(1). 94 | Values(2). 95 | OnDuplicateKeyIgnore(). 
96 | Execute(); err != nil { 97 | t.Error(err) 98 | } 99 | assertLastSql(t, "INSERT INTO `table1` (`field1`)"+ 100 | " VALUES (1), (2)"+ 101 | " ON DUPLICATE KEY UPDATE `field1` = `field1`") 102 | 103 | model := &TestModel{ 104 | F1: 1, 105 | F2: "test", 106 | } 107 | if _, err := db.InsertInto(Test).Values(1, 2).Execute(); err != nil { 108 | t.Error(err) 109 | } 110 | assertLastSql(t, "INSERT INTO `test` (`f1`, `f2`) VALUES (1, 2)") 111 | 112 | if _, err := db.InsertInto(Test).Models(model, &model, []Model{model}).Execute(); err != nil { 113 | t.Error(err) 114 | } 115 | assertLastSql(t, "INSERT INTO `test` (`f1`, `f2`) VALUES (1, 'test'), (1, 'test'), (1, 'test')") 116 | 117 | if _, err := db.InsertInto(Test).Models(model, &model, []interface{}{model, "invalid type"}).Execute(); err == nil { 118 | t.Error("should get error here") 119 | } 120 | 121 | if _, err := db.InsertInto(Table1).Models(model).Execute(); err == nil { 122 | t.Error("should get error here") 123 | } 124 | 125 | if _, err := db.ReplaceInto(Test).Values(1, 2).Execute(); err != nil { 126 | t.Error(err) 127 | } 128 | assertLastSql(t, "REPLACE INTO `test` (`f1`, `f2`) VALUES (1, 2)") 129 | 130 | errExpr := expression{ 131 | builder: func(scope scope) (string, error) { 132 | return "", errors.New("error") 133 | }, 134 | } 135 | if _, err := db.InsertInto(Test).Fields(errExpr).Values(1).Execute(); err == nil { 136 | t.Error("should get error here") 137 | } 138 | if _, err := db.InsertInto(Test).Fields(Test.F1).Values(errExpr).Execute(); err == nil { 139 | t.Error("should get error here") 140 | } 141 | if _, err := db.InsertInto(Test). 142 | Fields(Test.F1).Values(1). 143 | OnDuplicateKeyUpdate().Set(Test.F1, errExpr).Execute(); err == nil { 144 | t.Error("should get error here") 145 | } 146 | 147 | if _, err := db.InsertInto(Test).Fields(Test.F1).Values(1).WithContext(context.Background()).Execute(); err != nil { 148 | t.Error(err) 149 | } 150 | 151 | if _, err := db.InsertInto(Test). 
152 | Fields(Test.F1).Values(1). 153 | OnDuplicateKeyUpdate().Set(Test.F1, errExpr).WithContext(context.Background()).Execute(); err == nil { 154 | t.Error("should get error here") 155 | } 156 | } 157 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | [![Mentioned in Awesome Go](https://awesome.re/mentioned-badge.svg)](https://github.com/avelino/awesome-go) 4 | [![go.dev](https://img.shields.io/badge/go.dev-reference-007d9c)](https://pkg.go.dev/github.com/lqs/sqlingo?tab=doc) 5 | [![Travis CI](https://api.travis-ci.com/lqs/sqlingo.svg?branch=master)](https://app.travis-ci.com/github/lqs/sqlingo) 6 | [![Go Report Card](https://goreportcard.com/badge/github.com/lqs/sqlingo)](https://goreportcard.com/report/github.com/lqs/sqlingo) 7 | [![codecov](https://codecov.io/gh/lqs/sqlingo/branch/master/graph/badge.svg)](https://codecov.io/gh/lqs/sqlingo) 8 | [![MIT license](http://img.shields.io/badge/license-MIT-9d1f14)](http://opensource.org/licenses/MIT) 9 | [![last commit](https://img.shields.io/github/last-commit/lqs/sqlingo.svg)](https://github.com/lqs/sqlingo/commits) 10 | 11 | **sqlingo** is a SQL DSL (a.k.a. SQL Builder or ORM) library in Go. It generates code from the database and lets you write SQL queries in an elegant way. 
12 | 13 | 14 | 15 | ## Features 16 | * Auto-generating DSL objects and model structs from the database so you don't need to manually keep things in sync 17 | * SQL DML (SELECT / INSERT / UPDATE / DELETE) with some advanced SQL query syntax 18 | * Many common errors can be detected at compile time 19 | * You can use the features of your editor / IDE, such as autocompleting fields and queries, or finding the usages of a field or a table 20 | * Context support 21 | * Transaction support 22 | * Interceptor support 23 | * Golang time.Time is supported now, but you can still use the string type by adding `-timeAsString` when generating the model 24 | 25 | ## Database Support Status 26 | | Database | Status | 27 | ------------- | -------------- 28 | | MySQL | stable | 29 | | PostgreSQL | experimental | 30 | | SQLite | experimental | 31 | 32 | ## Tutorial 33 | 34 | ### Install and use the sqlingo code generator 35 | The first step is to generate code from the database. In order to generate code, sqlingo requires that your tables already exist in the database. 36 | 37 | ``` 38 | $ go install github.com/lqs/sqlingo/sqlingo-gen-mysql@latest 39 | $ mkdir -p generated/sqlingo 40 | $ sqlingo-gen-mysql root:123456@/database_name >generated/sqlingo/database_name.dsl.go 41 | ``` 42 | 43 | 44 | ### Write your application 45 | Here's a demonstration of some simple & advanced usage of sqlingo. 46 | ```go 47 | package main 48 | 49 | import ( 50 | "github.com/lqs/sqlingo" 51 | . "./generated/sqlingo" 52 | ) 53 | 54 | func main() { 55 | db, err := sqlingo.Open("mysql", "root:123456@/database_name") 56 | if err != nil { 57 | panic(err) 58 | } 59 | 60 | // a simple query 61 | var customers []*CustomerModel 62 | db.SelectFrom(Customer). 63 | Where(Customer.Id.In(1, 2)). 64 | OrderBy(Customer.Name.Desc()). 65 | FetchAll(&customers) 66 | 67 | // query from multiple tables 68 | var customerId int64 69 | var orderId int64 70 | err = db.Select(Customer.Id, Order.Id).
71 | From(Customer, Order). 72 | Where(Customer.Id.Equals(Order.CustomerId), Order.Id.Equals(1)). 73 | FetchFirst(&customerId, &orderId) 74 | 75 | // subquery and count 76 | count, err := db.SelectFrom(Order). 77 | Where(Order.CustomerId.In(db.Select(Customer.Id). 78 | From(Customer). 79 | Where(Customer.Name.Equals("Customer One")))). 80 | Count() 81 | 82 | // group-by with auto conversion to map 83 | var customerIdToOrderCount map[int64]int64 84 | err = db.Select(Order.CustomerId, f.Count(1)). 85 | From(Order). 86 | GroupBy(Order.CustomerId). 87 | FetchAll(&customerIdToOrderCount) 88 | if err != nil { 89 | println(err) 90 | } 91 | 92 | // insert some rows 93 | customer1 := &CustomerModel{Name: "Customer One"} 94 | customer2 := &CustomerModel{Name: "Customer Two"} 95 | _, err = db.InsertInto(Customer). 96 | Models(customer1, customer2). 97 | Execute() 98 | 99 | // insert with on-duplicate-key-update 100 | _, err = db.InsertInto(Customer). 101 | Fields(Customer.Id, Customer.Name). 102 | Values(42, "Universe"). 103 | OnDuplicateKeyUpdate(). 104 | Set(Customer.Name, Customer.Name.Concat(" 2")).
105 | Execute() 106 | } 107 | ``` 108 | -------------------------------------------------------------------------------- /cursor_test.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import ( 4 | "database/sql" 5 | "database/sql/driver" 6 | "io" 7 | "strconv" 8 | "testing" 9 | "time" 10 | ) 11 | 12 | type mockDriver struct{} 13 | 14 | type mockConn struct { 15 | lastSql string 16 | mockTx *mockTx 17 | beginTxError error 18 | prepareError error 19 | columnCount int 20 | rowCount int 21 | } 22 | 23 | type mockStmt struct { 24 | columnCount int 25 | rowCount int 26 | } 27 | 28 | type mockRows struct { 29 | columnCount int 30 | cursorPosition int 31 | rowCount int 32 | } 33 | 34 | func (m mockRows) Columns() []string { 35 | return []string{"a", "b", "c", "d", "e", "f", "g", "h", "j", "k", "l"}[:m.columnCount] 36 | } 37 | 38 | func (m mockRows) Close() error { 39 | return nil 40 | } 41 | 42 | func (m *mockRows) Next(dest []driver.Value) error { 43 | if m.cursorPosition >= m.rowCount { 44 | return io.EOF 45 | } 46 | m.cursorPosition++ 47 | for i := 0; i < m.columnCount; i++ { 48 | switch i { 49 | case 0: 50 | dest[i] = strconv.Itoa(m.cursorPosition) 51 | case 1: 52 | dest[i] = float32(m.cursorPosition) 53 | case 2: 54 | dest[i] = m.cursorPosition 55 | case 3: 56 | dest[i] = string(rune(m.cursorPosition % 2)) // '\x00' or '\x01' 57 | case 4: 58 | dest[i] = strconv.Itoa(m.cursorPosition % 2) // '0' or '1' 59 | case 5: 60 | dest[i] = dest[0] 61 | case 6: 62 | dest[i] = nil 63 | case 7: 64 | dest[i] = time.Date(2023, 9, 6, 18, 37, 46, 828000000, time.UTC) 65 | case 8: 66 | dest[i] = "2023-09-06 18:37:46.828" 67 | case 9: 68 | dest[i] = "2023-09-06 18:37:46" 69 | case 10: 70 | dest[i] = "2023-09-06 18:37:46" 71 | } 72 | } 73 | return nil 74 | } 75 | 76 | func (m mockStmt) Close() error { 77 | return nil 78 | } 79 | 80 | func (m mockStmt) NumInput() int { 81 | return 0 82 | } 83 | 84 | func (m mockStmt) Exec(args 
[]driver.Value) (driver.Result, error) { 85 | return driver.ResultNoRows, nil 86 | } 87 | 88 | func (m mockStmt) Query(args []driver.Value) (driver.Rows, error) { 89 | return &mockRows{ 90 | columnCount: m.columnCount, 91 | rowCount: m.rowCount, 92 | }, nil 93 | } 94 | 95 | func TestCursor(t *testing.T) { 96 | db := newMockDatabase() 97 | cursor, _ := db.Query("dummy sql") 98 | 99 | var a int 100 | var b string 101 | 102 | var cde struct { 103 | C float32 104 | DE struct { 105 | D, E bool 106 | } 107 | } 108 | var f ****int // deep pointer 109 | var g *int // always null 110 | var h *time.Time 111 | var j time.Time 112 | var k *time.Time 113 | var l time.Time 114 | tmh, _ := time.Parse("2006-01-02 15:04:05.000", "2023-09-06 18:37:46.828") 115 | tmj, _ := time.Parse("2006-01-02 15:04:05.000", "2023-09-06 18:37:46.828") 116 | tmk, _ := time.Parse("2006-01-02 15:04:05", "2023-09-06 18:37:46") 117 | tml, _ := time.Parse("2006-01-02 15:04:05", "2023-09-06 18:37:46") 118 | for i := 1; i <= 10; i++ { 119 | if !cursor.Next() { 120 | t.Error() 121 | } 122 | g = &i 123 | if err := cursor.Scan(&a, &b, &cde, &f, &g, &h, &j, &k, &l); err != nil { 124 | t.Errorf("%v", err) 125 | } 126 | if a != i || 127 | b != strconv.Itoa(i) || 128 | cde.C != float32(i) || 129 | cde.DE.D != (i%2 == 1) || 130 | cde.DE.E != cde.DE.D || 131 | ****f != i || 132 | g != nil || 133 | *h != tmh || 134 | j != tmj || 135 | *k != tmk || 136 | l != tml { 137 | t.Error(a, b, cde.C, cde.DE.D, cde.DE.E, ****f, g) 138 | } 139 | if err := cursor.Scan(); err != nil { 140 | t.Errorf("%v", err) 141 | } 142 | 143 | var s string 144 | var b ****bool 145 | var p *string 146 | var bs []byte 147 | var u string 148 | if err := cursor.Scan(&s, &s, &s, &b, &s, &bs, &p, &u, &u, &u, &u); err != nil { 149 | t.Error(err) 150 | } 151 | if ****b != (i%2 == 1) || 152 | p != nil || 153 | string(bs) != strconv.Itoa(i) { 154 | t.Error(s, ****b, p, string(bs)) 155 | } 156 | } 157 | if cursor.Next() { 158 | t.Errorf("d") 159 | } 160 
| if err := cursor.Close(); err != nil { 161 | t.Error(err) 162 | } 163 | 164 | } 165 | 166 | func TestScanTime(t *testing.T) { 167 | db := newMockDatabase() 168 | cursor, _ := db.Query("dummy sql") 169 | defer cursor.Close() 170 | 171 | var row struct { 172 | A sql.NullString 173 | B []byte 174 | C sql.NullInt32 175 | D sql.NullString 176 | E sql.NullString 177 | F sql.NullString 178 | G sql.NullString 179 | H sql.NullTime 180 | I sql.NullString 181 | J sql.NullString 182 | K sql.NullString 183 | } 184 | if !cursor.Next() { 185 | t.Error() 186 | } 187 | if err := cursor.Scan(&row); err != nil { 188 | t.Error(err) 189 | } 190 | } 191 | 192 | func TestCursorMap(t *testing.T) { 193 | db := newMockDatabase() 194 | cursor, _ := db.Query("dummy sql") 195 | 196 | for i := 1; i <= 10; i++ { 197 | if !cursor.Next() { 198 | t.Error() 199 | } 200 | row, err := cursor.GetMap() 201 | if err != nil { 202 | t.Error(err) 203 | } 204 | if row["a"].Int() != i { 205 | t.Error() 206 | } 207 | } 208 | if cursor.Next() { 209 | t.Error() 210 | } 211 | } 212 | 213 | func TestParseTime(t *testing.T) { 214 | tests := []struct { 215 | input string 216 | output time.Time 217 | }{ 218 | {"2024-09-06 11:22:33", time.Date(2024, 9, 6, 11, 22, 33, 0, time.UTC)}, 219 | {"2024-09-06 11:22:33.444", time.Date(2024, 9, 6, 11, 22, 33, 444000000, time.UTC)}, 220 | {"2024-09-06 11:22:33.444555666", time.Date(2024, 9, 6, 11, 22, 33, 444555666, time.UTC)}, 221 | {"2024-09-06T11:22:33.444555666Z", time.Date(2024, 9, 6, 11, 22, 33, 444555666, time.UTC)}, 222 | {"0000-00-00 00:00:00", time.Time{}}, 223 | } 224 | for _, test := range tests { 225 | tm, err := parseTime(test.input) 226 | if err != nil { 227 | t.Error(err) 228 | continue 229 | } 230 | if tm != test.output { 231 | t.Error(tm, test.output) 232 | } 233 | } 234 | } 235 | -------------------------------------------------------------------------------- /array.go: -------------------------------------------------------------------------------- 1 | 
package sqlingo 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "strconv" 9 | "unicode" 10 | ) 11 | 12 | type ArrayDimension struct { 13 | Length int32 14 | LowerBound int32 15 | } 16 | 17 | type untypedTextArray struct { 18 | Elements []string // elements as string 19 | Quoted []bool // true if source is quoted 20 | Dimensions []ArrayDimension // 21 | } 22 | 23 | // Skip the space prefix 24 | func skipWhitespace(buf *bytes.Buffer) { 25 | var r rune 26 | var err error 27 | for r, _, _ = buf.ReadRune(); unicode.IsSpace(r); r, _, _ = buf.ReadRune() { 28 | } 29 | 30 | if err != io.EOF { 31 | buf.UnreadRune() 32 | } 33 | } 34 | 35 | func parseToUntypedTextArray(src string) (*untypedTextArray, error) { 36 | dst := &untypedTextArray{ 37 | Elements: []string{}, 38 | Quoted: []bool{}, 39 | Dimensions: []ArrayDimension{}, 40 | } 41 | 42 | buf := bytes.NewBufferString(src) 43 | skipWhitespace(buf) 44 | 45 | // read failed 46 | r, _, err := buf.ReadRune() 47 | if err != nil { 48 | return nil, fmt.Errorf("invalid array: %v", err) 49 | } 50 | 51 | var explicitDimensions []ArrayDimension 52 | 53 | // Array has explicit dimensions 54 | if r == '[' { 55 | buf.UnreadRune() 56 | for { 57 | r, _, err = buf.ReadRune() 58 | if err != nil { 59 | return nil, fmt.Errorf("invalid array: %v", err) 60 | } 61 | if r == '=' { 62 | break 63 | } else if r != '[' { 64 | return nil, fmt.Errorf("invalid array, expected '[' or '=' got %v", r) 65 | } 66 | 67 | lower, err := arrayParseInteger(buf) 68 | if err != nil { 69 | return nil, fmt.Errorf("invalid array: %v", err) 70 | } 71 | 72 | r, _, err = buf.ReadRune() 73 | if err != nil { 74 | return nil, fmt.Errorf("invalid array: %v", err) 75 | } 76 | 77 | if r != ':' { 78 | return nil, fmt.Errorf("invalid array, expected ':' got %v", r) 79 | } 80 | 81 | upper, err := arrayParseInteger(buf) 82 | if err != nil { 83 | return nil, fmt.Errorf("invalid array: %v", err) 84 | } 85 | 86 | r, _, err = buf.ReadRune() 87 | if err != nil { 88 | 
return nil, fmt.Errorf("invalid array: %v", err) 89 | } 90 | 91 | if r != ']' { 92 | return nil, fmt.Errorf("invalid array, expected ']' got %v", r) 93 | } 94 | 95 | explicitDimensions = append(explicitDimensions, ArrayDimension{LowerBound: lower, Length: upper - lower + 1}) 96 | } 97 | } 98 | 99 | if r != '{' { 100 | return nil, errors.New("invalid array, expected '{' prefix") 101 | } 102 | implicitDimensions := []ArrayDimension{{LowerBound: 1, Length: 0}} 103 | 104 | // Consume all initial opening brackets. This provides number of dimensions. 105 | for { 106 | r, _, err = buf.ReadRune() 107 | if err != nil { 108 | return nil, fmt.Errorf("invalid array: %v", err) 109 | } 110 | 111 | if r == '{' { 112 | implicitDimensions[len(implicitDimensions)-1].Length = 1 113 | implicitDimensions = append(implicitDimensions, ArrayDimension{LowerBound: 1}) 114 | } else { 115 | buf.UnreadRune() 116 | break 117 | } 118 | } 119 | currentDim := len(implicitDimensions) - 1 120 | counterDim := currentDim 121 | 122 | for { 123 | r, _, err = buf.ReadRune() 124 | if err != nil { 125 | return nil, fmt.Errorf("invalid array: %v", err) 126 | } 127 | 128 | switch r { 129 | case '{': 130 | if currentDim == counterDim { 131 | implicitDimensions[currentDim].Length++ 132 | } 133 | currentDim++ 134 | case ',': 135 | case '}': 136 | currentDim-- 137 | if currentDim < counterDim { 138 | counterDim = currentDim 139 | } 140 | default: 141 | buf.UnreadRune() 142 | value, quoted, err := arrayParseValue(buf) 143 | if err != nil { 144 | return nil, fmt.Errorf("invalid array value: %v", err) 145 | } 146 | if currentDim == counterDim { 147 | implicitDimensions[currentDim].Length++ 148 | } 149 | dst.Quoted = append(dst.Quoted, quoted) 150 | dst.Elements = append(dst.Elements, value) 151 | } 152 | 153 | if currentDim < 0 { 154 | break 155 | } 156 | } 157 | 158 | skipWhitespace(buf) 159 | 160 | if buf.Len() > 0 { 161 | return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) 162 | } 163 | 164 | if 
len(dst.Elements) == 0 { 165 | } else if len(explicitDimensions) > 0 { 166 | dst.Dimensions = explicitDimensions 167 | } else { 168 | dst.Dimensions = implicitDimensions 169 | } 170 | 171 | return dst, nil 172 | } 173 | 174 | func arrayParseInteger(buf *bytes.Buffer) (int32, error) { 175 | s := &bytes.Buffer{} 176 | 177 | for { 178 | r, _, err := buf.ReadRune() 179 | if err != nil { 180 | return 0, err 181 | } 182 | 183 | if ('0' <= r && r <= '9') || r == '-' { 184 | s.WriteRune(r) 185 | } else { 186 | buf.UnreadRune() 187 | n, err := strconv.ParseInt(s.String(), 10, 32) 188 | if err != nil { 189 | return 0, err 190 | } 191 | return int32(n), nil 192 | } 193 | } 194 | } 195 | func arrayParseValue(buf *bytes.Buffer) (string, bool, error) { 196 | r, _, err := buf.ReadRune() 197 | if err != nil { 198 | return "", false, err 199 | } 200 | if r == '"' { 201 | return arrayParseQuotedValue(buf) 202 | } 203 | buf.UnreadRune() 204 | 205 | s := &bytes.Buffer{} 206 | 207 | for { 208 | r, _, err := buf.ReadRune() 209 | if err != nil { 210 | return "", false, err 211 | } 212 | 213 | switch r { 214 | case ',', '}': 215 | buf.UnreadRune() 216 | return s.String(), false, nil 217 | } 218 | 219 | s.WriteRune(r) 220 | } 221 | } 222 | 223 | func arrayParseQuotedValue(buf *bytes.Buffer) (string, bool, error) { 224 | s := &bytes.Buffer{} 225 | 226 | for { 227 | r, _, err := buf.ReadRune() 228 | if err != nil { 229 | return "", false, err 230 | } 231 | 232 | switch r { 233 | case '\\': 234 | r, _, err = buf.ReadRune() 235 | if err != nil { 236 | return "", false, err 237 | } 238 | case '"': 239 | r, _, err = buf.ReadRune() 240 | if err != nil { 241 | return "", false, err 242 | } 243 | buf.UnreadRune() 244 | return s.String(), true, nil 245 | } 246 | s.WriteRune(r) 247 | } 248 | } 249 | -------------------------------------------------------------------------------- /insert.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | 
import ( 4 | "context" 5 | "database/sql" 6 | "errors" 7 | "fmt" 8 | "reflect" 9 | ) 10 | 11 | type insertStatus struct { 12 | method string 13 | scope scope 14 | fields []Field 15 | values []interface{} 16 | models []interface{} 17 | onDuplicateKeyUpdateAssignments []assignment 18 | ctx context.Context 19 | } 20 | 21 | type insertWithTable interface { 22 | Fields(fields ...Field) insertWithValues 23 | Values(values ...interface{}) insertWithValues 24 | Models(models ...interface{}) insertWithModels 25 | } 26 | 27 | type insertWithValues interface { 28 | toInsertWithContext 29 | toInsertFinal 30 | Values(values ...interface{}) insertWithValues 31 | OnDuplicateKeyIgnore() toInsertWithDuplicateKey 32 | OnDuplicateKeyUpdate() insertWithOnDuplicateKeyUpdateBegin 33 | } 34 | 35 | type insertWithModels interface { 36 | toInsertWithContext 37 | toInsertFinal 38 | Models(models ...interface{}) insertWithModels 39 | OnDuplicateKeyIgnore() toInsertWithDuplicateKey 40 | OnDuplicateKeyUpdate() insertWithOnDuplicateKeyUpdateBegin 41 | } 42 | 43 | type insertWithOnDuplicateKeyUpdateBegin interface { 44 | Set(Field Field, value interface{}) insertWithOnDuplicateKeyUpdate 45 | SetIf(condition bool, Field Field, value interface{}) insertWithOnDuplicateKeyUpdate 46 | } 47 | 48 | type insertWithOnDuplicateKeyUpdate interface { 49 | insertWithOnDuplicateKeyUpdateBegin 50 | toInsertWithDuplicateKey 51 | } 52 | 53 | type toInsertWithContext interface { 54 | WithContext(ctx context.Context) toInsertFinal 55 | } 56 | 57 | type toInsertFinal interface { 58 | GetSQL() (string, error) 59 | Execute() (result sql.Result, err error) 60 | } 61 | 62 | type toInsertWithDuplicateKey interface { 63 | toInsertWithContext 64 | toInsertFinal 65 | } 66 | 67 | func (d *database) InsertInto(table Table) insertWithTable { 68 | return insertStatus{method: "INSERT", scope: scope{Database: d, Tables: []Table{table}}} 69 | } 70 | 71 | func (d *database) ReplaceInto(table Table) insertWithTable { 72 | return 
insertStatus{method: "REPLACE", scope: scope{Database: d, Tables: []Table{table}}} 73 | } 74 | 75 | func (s insertStatus) Fields(fields ...Field) insertWithValues { 76 | s.fields = fields 77 | return s 78 | } 79 | 80 | func (s insertStatus) Values(values ...interface{}) insertWithValues { 81 | s.values = append([]interface{}{}, s.values...) 82 | s.values = append(s.values, values) 83 | return s 84 | } 85 | 86 | func addModel(models *[]Model, model interface{}) error { 87 | if model, ok := model.(Model); ok { 88 | *models = append(*models, model) 89 | return nil 90 | } 91 | 92 | value := reflect.ValueOf(model) 93 | switch value.Kind() { 94 | case reflect.Ptr: 95 | value = reflect.Indirect(value) 96 | return addModel(models, value.Interface()) 97 | case reflect.Slice, reflect.Array: 98 | for i := 0; i < value.Len(); i++ { 99 | elem := value.Index(i) 100 | addr := elem.Addr() 101 | inter := addr.Interface() 102 | if err := addModel(models, inter); err != nil { 103 | return err 104 | } 105 | } 106 | return nil 107 | default: 108 | return fmt.Errorf("unknown model type (kind = %d)", value.Kind()) 109 | } 110 | } 111 | 112 | func (s insertStatus) Models(models ...interface{}) insertWithModels { 113 | s.models = models 114 | return s 115 | } 116 | 117 | func (s insertStatus) OnDuplicateKeyUpdate() insertWithOnDuplicateKeyUpdateBegin { 118 | return s 119 | } 120 | 121 | func (s insertStatus) SetIf(condition bool, field Field, value interface{}) insertWithOnDuplicateKeyUpdate { 122 | if condition { 123 | return s.Set(field, value) 124 | } 125 | return s 126 | } 127 | 128 | func (s insertStatus) Set(field Field, value interface{}) insertWithOnDuplicateKeyUpdate { 129 | s.onDuplicateKeyUpdateAssignments = append([]assignment{}, s.onDuplicateKeyUpdateAssignments...) 
130 | s.onDuplicateKeyUpdateAssignments = append(s.onDuplicateKeyUpdateAssignments, assignment{ 131 | field: field, 132 | value: value, 133 | }) 134 | return s 135 | } 136 | 137 | func (s insertStatus) OnDuplicateKeyIgnore() toInsertWithDuplicateKey { 138 | firstField := s.scope.Tables[0].GetFields()[0] 139 | return s.OnDuplicateKeyUpdate().Set(firstField, firstField) 140 | } 141 | 142 | func (s insertStatus) GetSQL() (string, error) { 143 | var fields []Field 144 | var values []interface{} 145 | if len(s.models) > 0 { 146 | models := make([]Model, 0, len(s.models)) 147 | for _, model := range s.models { 148 | if err := addModel(&models, model); err != nil { 149 | return "", err 150 | } 151 | } 152 | 153 | if len(models) > 0 { 154 | fields = models[0].GetTable().GetFields() 155 | for _, model := range models { 156 | if model.GetTable().GetName() != s.scope.Tables[0].GetName() { 157 | return "", errors.New("invalid table from model") 158 | } 159 | values = append(values, model.GetValues()) 160 | } 161 | } 162 | } else { 163 | if len(s.fields) == 0 { 164 | fields = s.scope.Tables[0].GetFields() 165 | } else { 166 | fields = s.fields 167 | } 168 | values = s.values 169 | } 170 | 171 | if len(values) == 0 { 172 | return "/* INSERT without VALUES */ DO 0", nil 173 | } 174 | 175 | tableSql := s.scope.Tables[0].GetSQL(s.scope) 176 | fieldsSql, err := commaFields(s.scope, fields) 177 | if err != nil { 178 | return "", err 179 | } 180 | valuesSql, err := commaValues(s.scope, values) 181 | if err != nil { 182 | return "", err 183 | } 184 | 185 | sqlString := s.method + " INTO " + tableSql + " (" + fieldsSql + ") VALUES " + valuesSql 186 | if len(s.onDuplicateKeyUpdateAssignments) > 0 { 187 | assignmentsSql, err := commaAssignments(s.scope, s.onDuplicateKeyUpdateAssignments) 188 | if err != nil { 189 | return "", err 190 | } 191 | sqlString += " ON DUPLICATE KEY UPDATE " + assignmentsSql 192 | } 193 | 194 | return sqlString, nil 195 | } 196 | 197 | func (s insertStatus) 
WithContext(ctx context.Context) toInsertFinal { 198 | s.ctx = ctx 199 | return s 200 | } 201 | 202 | func (s insertStatus) Execute() (result sql.Result, err error) { 203 | sqlString, err := s.GetSQL() 204 | if err != nil { 205 | return nil, err 206 | } 207 | return s.scope.Database.ExecuteContext(s.ctx, sqlString) 208 | } 209 | -------------------------------------------------------------------------------- /cursor.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "reflect" 7 | "regexp" 8 | "strconv" 9 | "strings" 10 | "time" 11 | ) 12 | 13 | // Scanner is the interface that wraps the Scan method. 14 | type Scanner interface { 15 | Scan(dest ...interface{}) error 16 | } 17 | 18 | // Cursor is the interface of a row cursor. 19 | type Cursor interface { 20 | Next() bool 21 | Scan(dest ...interface{}) error 22 | GetMap() (map[string]value, error) 23 | Close() error 24 | } 25 | 26 | type cursor struct { 27 | rows *sql.Rows 28 | } 29 | 30 | func (c cursor) Next() bool { 31 | return c.rows.Next() 32 | } 33 | 34 | var timeType = reflect.TypeOf(time.Time{}) 35 | 36 | var simpleTimeLayoutRegexp = regexp.MustCompile(`^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}(\.(\d+))?$`) 37 | 38 | func guessTimeLayout(s string) string { 39 | matches := simpleTimeLayoutRegexp.FindStringSubmatch(s) 40 | if len(matches) > 0 { 41 | var sb strings.Builder 42 | sb.Grow(32) 43 | sb.WriteString("2006-01-02 15:04:05") 44 | if matches[1] != "" { 45 | sb.WriteString(".") 46 | for i := 0; i < len(matches[2]); i++ { 47 | sb.WriteByte('0') 48 | } 49 | } 50 | return sb.String() 51 | } 52 | return time.RFC3339Nano 53 | } 54 | 55 | func parseTime(s string) (time.Time, error) { 56 | if strings.HasPrefix(s, "0000-00-00") { 57 | // MySQL zero date 58 | return time.Time{}, nil 59 | } 60 | layout := guessTimeLayout(s) 61 | t, err := time.Parse(layout, s) 62 | if err != nil { 63 | return time.Time{}, 
fmt.Errorf("unknown time format %s: %w", s, err) 64 | } 65 | return t, nil 66 | } 67 | 68 | func isScanner(val reflect.Value) bool { 69 | _, ok := val.Addr().Interface().(sql.Scanner) 70 | return ok 71 | } 72 | 73 | func preparePointers(val reflect.Value, scans *[]interface{}) error { 74 | kind := val.Kind() 75 | switch kind { 76 | case reflect.Bool, 77 | reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, 78 | reflect.Uint, reflect.Uint8, reflect.Uint32, reflect.Uint64, 79 | reflect.Float32, reflect.Float64, 80 | reflect.String: 81 | if addr := val.Addr(); addr.CanInterface() { 82 | *scans = append(*scans, addr.Interface()) 83 | } 84 | case reflect.Struct: 85 | if canScan := val.Type() == timeType || isScanner(val); canScan { 86 | *scans = append(*scans, val.Addr().Interface()) 87 | return nil 88 | } 89 | for j := 0; j < val.NumField(); j++ { 90 | field := val.Field(j) 91 | if field.Kind() == reflect.Interface { 92 | continue 93 | } 94 | if err := preparePointers(field, scans); err != nil { 95 | return err 96 | } 97 | } 98 | case reflect.Ptr: 99 | toType := val.Type().Elem() 100 | switch toType.Kind() { 101 | case reflect.Bool, 102 | reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, 103 | reflect.Uint, reflect.Uint8, reflect.Uint32, reflect.Uint64, 104 | reflect.Float32, reflect.Float64, 105 | reflect.String: 106 | *scans = append(*scans, val.Addr().Interface()) 107 | case reflect.Struct: 108 | if toType == reflect.TypeOf(time.Time{}) { 109 | *scans = append(*scans, val.Addr().Interface()) 110 | } else { 111 | to := reflect.New(toType).Elem() 112 | val.Set(to.Addr()) 113 | err := preparePointers(to, scans) 114 | if err != nil { 115 | return nil 116 | } 117 | } 118 | default: 119 | to := reflect.New(toType).Elem() 120 | val.Set(to.Addr()) 121 | err := preparePointers(to, scans) 122 | if err != nil { 123 | return nil 124 | } 125 | } 126 | case reflect.Slice: 127 | if _, ok := (val.Interface()).([]byte); ok { 128 | *scans = 
append(*scans, val.Addr().Interface()) 129 | } else { 130 | return fmt.Errorf("unknown type []%s", val.Type().Elem().Kind().String()) 131 | } 132 | default: 133 | return fmt.Errorf("unknown type %s", kind.String()) 134 | } 135 | return nil 136 | } 137 | 138 | func parseBool(s []byte) (bool, error) { 139 | if len(s) == 1 { 140 | if s[0] == 0 { 141 | return false, nil 142 | } else if s[0] == 1 { 143 | return true, nil 144 | } 145 | } 146 | return strconv.ParseBool(string(s)) 147 | } 148 | 149 | func (c cursor) Scan(dest ...interface{}) error { 150 | columns, err := c.rows.Columns() 151 | if err != nil { 152 | return err 153 | } 154 | values := make([]interface{}, len(columns)) 155 | pointers := make([]interface{}, len(columns)) 156 | for i := range columns { 157 | pointers[i] = &values[i] 158 | } 159 | if err := c.rows.Scan(pointers...); err != nil { 160 | return err 161 | } 162 | 163 | if len(dest) == 0 { 164 | // dry run 165 | return nil 166 | } 167 | 168 | var scans []interface{} 169 | for i, item := range dest { 170 | if reflect.ValueOf(item).Kind() != reflect.Ptr { 171 | return fmt.Errorf("argument %d is not pointer", i) 172 | } 173 | 174 | val := reflect.Indirect(reflect.ValueOf(item)) 175 | 176 | err := preparePointers(val, &scans) 177 | if err != nil { 178 | return err 179 | } 180 | } 181 | 182 | pbs := make(map[int]*bool) 183 | ppbs := make(map[int]**bool) 184 | pts := make(map[int]*time.Time) 185 | ppts := make(map[int]**time.Time) 186 | 187 | for i, scan := range scans { 188 | switch scan.(type) { 189 | case *bool: 190 | var s []uint8 191 | scans[i] = &s 192 | pbs[i] = scan.(*bool) 193 | case **bool: 194 | var s *[]uint8 195 | scans[i] = &s 196 | ppbs[i] = scan.(**bool) 197 | case *time.Time: 198 | var s string 199 | scans[i] = &s 200 | pts[i] = scan.(*time.Time) 201 | case **time.Time: 202 | var s sql.NullString 203 | scans[i] = &s 204 | ppts[i] = scan.(**time.Time) 205 | } 206 | } 207 | 208 | if err := c.rows.Scan(scans...); err != nil { 209 | return err 
210 | } 211 | 212 | for i, pb := range pbs { 213 | if *(scans[i].(*[]byte)) == nil { 214 | return fmt.Errorf("field %d is null", i) 215 | } 216 | b, err := parseBool(*(scans[i].(*[]byte))) 217 | if err != nil { 218 | return err 219 | } 220 | *pb = b 221 | } 222 | for i, ppb := range ppbs { 223 | if *(scans[i].(**[]uint8)) == nil { 224 | *ppb = nil 225 | } else { 226 | b, err := parseBool(**(scans[i].(**[]byte))) 227 | if err != nil { 228 | return err 229 | } 230 | *ppb = &b 231 | } 232 | } 233 | for i := range pts { 234 | s := scans[i].(*string) 235 | if s == nil { 236 | return fmt.Errorf("field %d is null", i) 237 | } 238 | t, err := parseTime(*s) 239 | if err != nil { 240 | return err 241 | } 242 | *pts[i] = t 243 | 244 | } 245 | for i := range ppts { 246 | nullString := scans[i].(*sql.NullString) 247 | if nullString == nil { 248 | return fmt.Errorf("field %d is null", i) 249 | } 250 | if !nullString.Valid { 251 | *ppts[i] = nil 252 | } else { 253 | t, err := parseTime(nullString.String) 254 | if err != nil { 255 | return err 256 | } 257 | *ppts[i] = &t 258 | } 259 | } 260 | 261 | return err 262 | } 263 | 264 | func (c cursor) GetMap() (result map[string]value, err error) { 265 | columns, err := c.rows.Columns() 266 | if err != nil { 267 | return 268 | } 269 | 270 | columnCount := len(columns) 271 | values := make([]interface{}, columnCount) 272 | for i := 0; i < columnCount; i++ { 273 | var value *string 274 | values[i] = &value 275 | } 276 | if err = c.rows.Scan(values...); err != nil { 277 | return 278 | } 279 | 280 | result = make(map[string]value, columnCount) 281 | for i, column := range columns { 282 | result[column] = value{stringValue: *values[i].(**string)} 283 | } 284 | 285 | return 286 | } 287 | 288 | func (c cursor) Close() error { 289 | return c.rows.Close() 290 | } 291 | -------------------------------------------------------------------------------- /expression_test.go: 
-------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | ) 7 | 8 | type CustomInt int 9 | type CustomBool bool 10 | type CustomFloat float32 11 | type CustomString string 12 | 13 | func TestExpression(t *testing.T) { 14 | assertValue(t, nil, "NULL") 15 | 16 | assertValue(t, false, "0") 17 | assertValue(t, true, "1") 18 | assertValue(t, CustomBool(false), "0") 19 | assertValue(t, CustomBool(true), "1") 20 | 21 | assertValue(t, int8(11), "11") 22 | assertValue(t, int16(11111), "11111") 23 | assertValue(t, int32(1111111111), "1111111111") 24 | assertValue(t, CustomInt(1111111111), "1111111111") 25 | assertValue(t, int(1111111111), "1111111111") 26 | assertValue(t, int64(1111111111111111111), "1111111111111111111") 27 | 28 | assertValue(t, int8(-11), "-11") 29 | assertValue(t, int16(-11111), "-11111") 30 | assertValue(t, int32(-1111111111), "-1111111111") 31 | assertValue(t, CustomInt(-1111111111), "-1111111111") 32 | assertValue(t, int(-1111111111), "-1111111111") 33 | assertValue(t, int64(-1111111111111111111), "-1111111111111111111") 34 | 35 | assertValue(t, uint8(1), "1") 36 | assertValue(t, uint16(55555), "55555") 37 | assertValue(t, uint32(3333333333), "3333333333") 38 | assertValue(t, uint(3333333333), "3333333333") 39 | assertValue(t, uint64(11111111111111111111), "11111111111111111111") 40 | 41 | assertValue(t, float32(2), "2") 42 | assertValue(t, float32(-2), "-2") 43 | assertValue(t, float64(2), "2") 44 | assertValue(t, float64(-2), "-2") 45 | 46 | assertValue(t, "abc", "'abc'") 47 | assertValue(t, "", "''") 48 | assertValue(t, "a' or 'a'='a", "'a\\' or \\'a\\'=\\'a'") 49 | assertValue(t, "\n", "'\\\n'") 50 | assertValue(t, CustomString("abc"), "'abc'") 51 | 52 | x := 3 53 | px := &x 54 | ppx := &px 55 | var deepNil *****int 56 | assertValue(t, &x, "3") 57 | assertValue(t, &px, "3") 58 | assertValue(t, &ppx, "3") 59 | assertValue(t, deepNil, "NULL") 60 | } 61 | 62 | 
func TestFunc(t *testing.T) { 63 | e := expression{ 64 | builder: func(scope scope) (string, error) { 65 | return "<>", nil 66 | }, 67 | } 68 | 69 | assertValue(t, e.Equals(e), "<> = <>") 70 | assertValue(t, e.NotEquals(e), "<> <> <>") 71 | assertValue(t, e.LessThan(e), "<> < <>") 72 | assertValue(t, e.LessThanOrEquals(e), "<> <= <>") 73 | assertValue(t, e.GreaterThan(e), "<> > <>") 74 | assertValue(t, e.GreaterThanOrEquals(e), "<> >= <>") 75 | assertValue(t, e.And(e), "<> AND <>") 76 | assertValue(t, e.Or(e), "<> OR <>") 77 | assertValue(t, e.Xor(e), "<> XOR <>") 78 | assertValue(t, e.Not(), "NOT <>") 79 | 80 | assertValue(t, e.Add(e), "<> + <>") 81 | assertValue(t, e.Sub(e), "<> - <>") 82 | assertValue(t, e.Mul(e), "<> * <>") 83 | assertValue(t, e.Div(e), "<> / <>") 84 | assertValue(t, e.IntDiv(e), "<> DIV <>") 85 | assertValue(t, e.Mod(e), "<> % <>") 86 | assertValue(t, e.Sum(), "SUM(<>)") 87 | assertValue(t, e.Avg(), "AVG(<>)") 88 | assertValue(t, e.Min(), "MIN(<>)") 89 | assertValue(t, e.Max(), "MAX(<>)") 90 | assertValue(t, e.Between(2, 4), "<> BETWEEN 2 AND 4") 91 | assertValue(t, e.NotBetween(2, 4), "<> NOT BETWEEN 2 AND 4") 92 | 93 | assertValue(t, e.In(), "FALSE") 94 | assertValue(t, e.In(1), "<> = 1") 95 | assertValue(t, e.In(1, 2, 3), "<> IN (1, 2, 3)") 96 | assertValue(t, e.In([]int64{}), "FALSE") 97 | assertValue(t, e.In([]int64{1}), "<> = 1") 98 | assertValue(t, e.In([]int64{1, 2, 3}), "<> IN (1, 2, 3)") 99 | assertValue(t, e.In([]byte{1, 2, 3}), "<> IN (1, 2, 3)") 100 | 101 | assertValue(t, e.NotIn(), "TRUE") 102 | assertValue(t, e.NotIn(1), "<> <> 1") 103 | assertValue(t, e.NotIn(1, 2, 3), "<> NOT IN (1, 2, 3)") 104 | assertValue(t, e.NotIn([]int64{}), "TRUE") 105 | assertValue(t, e.NotIn([]int64{1}), "<> <> 1") 106 | assertValue(t, e.NotIn([]int64{1, 2, 3}), "<> NOT IN (1, 2, 3)") 107 | 108 | assertValue(t, e.Like("%A%"), "<> LIKE '%A%'") 109 | assertValue(t, e.Concat("-suffix"), "CONCAT(<>, '-suffix')") 110 | assertValue(t, e.Contains("\n"), 
"LOCATE('\\\n', <>) > 0") 111 | 112 | assertValue(t, []interface{}{1, 2, 3, "d"}, "(1, 2, 3, 'd')") 113 | 114 | assertValue(t, e.IsNull(), "<> IS NULL") 115 | assertValue(t, e.IsNotNull(), "<> IS NOT NULL") 116 | assertValue(t, e.IsTrue(), "<> IS TRUE") 117 | assertValue(t, e.IsNotTrue(), "<> IS NOT TRUE") 118 | assertValue(t, e.IsFalse(), "<> IS FALSE") 119 | assertValue(t, e.IsNotFalse(), "<> IS NOT FALSE") 120 | assertValue(t, e.If(3, 4), "IF(<>, 3, 4)") 121 | assertValue(t, e.IfNull(3), "IFNULL(<>, 3)") 122 | assertValue(t, e.IfEmpty(3), "IF(<> <> '', <>, 3)") 123 | assertValue(t, e.IsEmpty(), "<> = ''") 124 | assertValue(t, e.Lower(), "LOWER(<>)") 125 | assertValue(t, e.Upper(), "UPPER(<>)") 126 | assertValue(t, e.Left(10), "LEFT(<>, 10)") 127 | assertValue(t, e.Right(10), "RIGHT(<>, 10)") 128 | assertValue(t, e.Trim(), "TRIM(<>)") 129 | assertValue(t, e.HasPrefix("abc"), "LEFT(<>, CHAR_LENGTH('abc')) = 'abc'") 130 | assertValue(t, e.HasSuffix("abc"), "RIGHT(<>, CHAR_LENGTH('abc')) = 'abc'") 131 | 132 | e5 := expression{ 133 | builder: func(scope scope) (string, error) { 134 | return "e5", nil 135 | }, 136 | priority: 5, 137 | } 138 | e7 := expression{ 139 | builder: func(scope scope) (string, error) { 140 | return "e7", nil 141 | }, 142 | priority: 7, 143 | } 144 | e9 := expression{ 145 | builder: func(scope scope) (string, error) { 146 | return "e9", nil 147 | }, 148 | priority: 9, 149 | } 150 | 151 | assertValue(t, e7.Add(e7), "e7 + (e7)") 152 | assertValue(t, e5.Add(e7), "e5 + (e7)") 153 | assertValue(t, e7.Add(e5), "e7 + e5") 154 | assertValue(t, e5.Add(e9), "e5 + (e9)") 155 | assertValue(t, e9.Add(e5), "(e9) + e5") 156 | 157 | ee := expression{ 158 | builder: func(scope scope) (string, error) { 159 | return "", errors.New("error") 160 | }, 161 | } 162 | assertError(t, e.Add(ee)) 163 | assertError(t, ee.Add(e)) 164 | assertError(t, ee.IsNull()) 165 | assertError(t, e.In(ee, ee, ee)) 166 | assertError(t, ee.In(e, e, e)) 167 | 168 | assertError(t, 
ee.Between(2, 4)) 169 | assertError(t, e.Between(2, ee)) 170 | assertError(t, e.Between(ee, 4)) 171 | 172 | } 173 | 174 | func TestMisc(t *testing.T) { 175 | assertValue(t, True(), "TRUE") 176 | assertValue(t, False(), "FALSE") 177 | 178 | assertValue(t, command("COMMAND", staticExpression("", 0, false)), "COMMAND ") 179 | 180 | assertValue(t, staticExpression("", 1, false). 181 | prefixSuffixExpression("", "", 1, false), "") 182 | } 183 | 184 | func TestLogicalExpression(t *testing.T) { 185 | a := expression{sql: "a", priority: 1} 186 | b := expression{sql: "b", priority: 1} 187 | c := expression{sql: "c", priority: 1} 188 | d := expression{sql: "d", priority: 1} 189 | 190 | assertValue(t, And(a, b, c, d), "a AND b AND c AND d") 191 | assertValue(t, Or(a, b, c, d), "a OR b OR c OR d") 192 | assertValue(t, a.And(b).Or(c).And(a).Or(b).And(c), "((a AND b OR c) AND a OR b) AND c") 193 | assertValue(t, a.Or(b).And(c.Or(d)), "(a OR b) AND (c OR d)") 194 | assertValue(t, a.Or(b).And(c).Not(), "NOT ((a OR b) AND c)") 195 | 196 | assertValue(t, And(), "TRUE") 197 | assertValue(t, Or(), "FALSE") 198 | } 199 | 200 | func TestLogicalOptimizer(t *testing.T) { 201 | trueValue := True() 202 | falseValue := False() 203 | otherValue := staticExpression("<>", 0, false) 204 | otherBoolValue := staticExpression("<>", 0, true) 205 | 206 | assertValue(t, trueValue.Or(trueValue), "TRUE") 207 | assertValue(t, trueValue.Or(falseValue), "TRUE") 208 | assertValue(t, falseValue.Or(trueValue), "TRUE") 209 | assertValue(t, falseValue.Or(falseValue), "FALSE") 210 | 211 | assertValue(t, trueValue.And(trueValue), "TRUE") 212 | assertValue(t, trueValue.And(falseValue), "FALSE") 213 | assertValue(t, falseValue.And(trueValue), "FALSE") 214 | assertValue(t, falseValue.And(falseValue), "FALSE") 215 | 216 | assertValue(t, falseValue.Not(), "TRUE") 217 | assertValue(t, trueValue.Not(), "FALSE") 218 | 219 | assertValue(t, trueValue.And(otherValue), "TRUE AND <>") 220 | assertValue(t, 
trueValue.Or(otherValue), "TRUE") 221 | assertValue(t, trueValue.And(123), "TRUE AND 123") 222 | assertValue(t, trueValue.Or(123), "TRUE") 223 | assertValue(t, falseValue.And(otherValue), "FALSE") 224 | assertValue(t, falseValue.Or(otherValue), "FALSE OR <>") 225 | assertValue(t, falseValue.And(123), "FALSE") 226 | assertValue(t, falseValue.Or(123), "FALSE OR 123") 227 | 228 | assertValue(t, trueValue.And(otherBoolValue), "<>") 229 | assertValue(t, falseValue.Or(otherBoolValue), "<>") 230 | } 231 | -------------------------------------------------------------------------------- /database.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "os" 8 | "path/filepath" 9 | "runtime" 10 | "strings" 11 | "sync" 12 | "time" 13 | ) 14 | 15 | const ( 16 | // for colorful terminal print 17 | green = "\033[32m" 18 | red = "\033[31m" 19 | blue = "\033[34m" 20 | reset = "\033[0m" 21 | ) 22 | 23 | // Database is the interface of a database with underlying sql.DB object. 24 | type Database interface { 25 | // GetDB returns the underlying sql.DB object of the database 26 | GetDB() *sql.DB 27 | // BeginTx starts a transaction and executes the function f. 28 | BeginTx(ctx context.Context, opts *sql.TxOptions, f func(tx Transaction) error) error 29 | // Query executes a query and returns the cursor 30 | Query(sql string) (Cursor, error) 31 | // QueryContext executes a query with context and returns the cursor 32 | QueryContext(ctx context.Context, sqlString string) (Cursor, error) 33 | // Execute executes a statement 34 | Execute(sql string) (sql.Result, error) 35 | // ExecuteContext executes a statement with context 36 | ExecuteContext(ctx context.Context, sql string) (sql.Result, error) 37 | // SetLogger sets the logger function. 38 | // Deprecated: use SetInterceptor instead 39 | SetLogger(logger LoggerFunc) 40 | // SetRetryPolicy sets the retry policy function. 
41 | // Deprecated: use SetInterceptor instead 42 | SetRetryPolicy(retryPolicy func(err error) bool) 43 | // EnableCallerInfo enable or disable the caller info in the log. 44 | // Deprecated: use SetInterceptor instead 45 | EnableCallerInfo(enableCallerInfo bool) 46 | // SetInterceptor sets an interceptor function 47 | SetInterceptor(interceptor InterceptorFunc) 48 | 49 | // Select initiates a SELECT statement 50 | Select(fields ...interface{}) selectWithFields 51 | // SelectDistinct initiates a SELECT DISTINCT statement 52 | SelectDistinct(fields ...interface{}) selectWithFields 53 | // SelectFrom initiates a SELECT * FROM statement 54 | SelectFrom(tables ...Table) selectWithTables 55 | // InsertInto initiates a INSERT INTO statement 56 | InsertInto(table Table) insertWithTable 57 | // ReplaceInto initiates a REPLACE INTO statement 58 | ReplaceInto(table Table) insertWithTable 59 | // Update initiates a UPDATE statement 60 | Update(table Table) updateWithSet 61 | // DeleteFrom initiates a DELETE FROM statement 62 | DeleteFrom(table Table) deleteWithTable 63 | } 64 | 65 | type txOrDB interface { 66 | QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) 67 | ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) 68 | } 69 | 70 | var ( 71 | once sync.Once 72 | srcPrefix string 73 | ) 74 | 75 | type database struct { 76 | db *sql.DB 77 | tx *sql.Tx 78 | logger LoggerFunc 79 | dialect dialect 80 | retryPolicy func(error) bool 81 | enableCallerInfo bool 82 | interceptor InterceptorFunc 83 | } 84 | 85 | type LoggerFunc func(sql string, duration time.Duration, isTx bool, retry bool) 86 | 87 | func (d *database) SetLogger(loggerFunc LoggerFunc) { 88 | d.logger = loggerFunc 89 | } 90 | 91 | // DefaultLogger is sqlingo default logger, 92 | // which print log to stderr and regard executing time gt 100ms as slow sql. 
93 | func DefaultLogger(sql string, duration time.Duration, isTx bool, retry bool) { 94 | // for finding code position, try once is enough 95 | once.Do(func() { 96 | // $GOPATH/pkg/mod/github.com/lqs/sqlingo@vX.X.X/database.go 97 | _, file, _, _ := runtime.Caller(0) 98 | // $GOPATH/pkg/mod/github.com/lqs/sqlingo@vX.X.X 99 | srcPrefix = filepath.Dir(file) 100 | }) 101 | 102 | var file string 103 | var line int 104 | var ok bool 105 | for i := 0; i < 16; i++ { 106 | _, file, line, ok = runtime.Caller(i) 107 | // `!strings.HasPrefix(file, srcPrefix)` jump out when using sqlingo as dependent package 108 | // `strings.HasSuffix(file, "_test.go")` jump out when executing unit test cases 109 | // `!ok` this is so terrible for something unexpected happened 110 | if !ok || !strings.HasPrefix(file, srcPrefix) || strings.HasSuffix(file, "_test.go") { 111 | break 112 | } 113 | } 114 | 115 | // todo shouldn't append ';' here 116 | if !strings.HasSuffix(sql, ";") { 117 | sql += ";" 118 | } 119 | 120 | sb := strings.Builder{} 121 | sb.Grow(32) 122 | sb.WriteString("|") 123 | sb.WriteString(duration.String()) 124 | if isTx { 125 | sb.WriteString("|transaction") // todo using something traceable 126 | } 127 | if retry { 128 | sb.WriteString("|retry") 129 | } 130 | sb.WriteString("|") 131 | 132 | line1 := strings.Join( 133 | []string{ 134 | "[sqlingo]", 135 | time.Now().Format("2006-01-02 15:04:05"), 136 | sb.String(), 137 | file + ":" + fmt.Sprint(line), 138 | }, 139 | " ") 140 | 141 | // print to stderr 142 | fmt.Fprintln(os.Stderr, blue+line1+reset) 143 | if duration < 100*time.Millisecond { 144 | fmt.Fprintf(os.Stderr, "%s%s%s\n", green, sql, reset) 145 | } else { 146 | fmt.Fprintf(os.Stderr, "%s%s%s\n", red, sql, reset) 147 | } 148 | fmt.Fprintln(os.Stderr) 149 | } 150 | 151 | func (d *database) SetRetryPolicy(retryPolicy func(err error) bool) { 152 | d.retryPolicy = retryPolicy 153 | } 154 | 155 | func (d *database) EnableCallerInfo(enableCallerInfo bool) { 156 | 
d.enableCallerInfo = enableCallerInfo 157 | } 158 | 159 | func (d *database) SetInterceptor(interceptor InterceptorFunc) { 160 | d.interceptor = interceptor 161 | } 162 | 163 | // Open a database, similar to sql.Open. 164 | // `db` using a default logger, which print log to stderr and regard executing time gt 100ms as slow sql. 165 | // To disable the default logger, use `db.SetLogger(nil)`. 166 | func Open(driverName string, dataSourceName string) (db Database, err error) { 167 | var sqlDB *sql.DB 168 | if dataSourceName != "" { 169 | sqlDB, err = sql.Open(driverName, dataSourceName) 170 | if err != nil { 171 | return 172 | } 173 | } 174 | db = Use(driverName, sqlDB) 175 | return 176 | } 177 | 178 | // Use an existing *sql.DB handle 179 | func Use(driverName string, sqlDB *sql.DB) Database { 180 | return &database{ 181 | dialect: getDialectFromDriverName(driverName), 182 | db: sqlDB, 183 | } 184 | } 185 | 186 | func (d database) GetDB() *sql.DB { 187 | return d.db 188 | } 189 | 190 | func (d database) getTxOrDB() txOrDB { 191 | if d.tx != nil { 192 | return d.tx 193 | } 194 | return d.db 195 | } 196 | 197 | func (d database) Query(sqlString string) (Cursor, error) { 198 | return d.QueryContext(context.Background(), sqlString) 199 | } 200 | 201 | func (d database) QueryContext(ctx context.Context, sqlString string) (Cursor, error) { 202 | isRetry := false 203 | for { 204 | sqlStringWithCallerInfo := getCallerInfo(d, isRetry) + sqlString 205 | rows, err := d.queryContextOnce(ctx, sqlStringWithCallerInfo, isRetry) 206 | if err != nil { 207 | isRetry = d.tx == nil && d.retryPolicy != nil && d.retryPolicy(err) 208 | if isRetry { 209 | continue 210 | } 211 | return nil, err 212 | } 213 | return cursor{rows: rows}, nil 214 | } 215 | } 216 | 217 | func (d database) queryContextOnce(ctx context.Context, sqlString string, retry bool) (*sql.Rows, error) { 218 | if ctx == nil { 219 | ctx = context.Background() 220 | } 221 | startTime := time.Now() 222 | defer func() { 223 | 
endTime := time.Now() 224 | if d.logger != nil { 225 | d.logger(sqlString, endTime.Sub(startTime), false, retry) 226 | } 227 | }() 228 | 229 | interceptor := d.interceptor 230 | var rows *sql.Rows 231 | invoker := func(ctx context.Context, sql string) (err error) { 232 | rows, err = d.getTxOrDB().QueryContext(ctx, sql) 233 | return 234 | } 235 | 236 | var err error 237 | if interceptor == nil { 238 | err = invoker(ctx, sqlString) 239 | } else { 240 | err = interceptor(ctx, sqlString, invoker) 241 | } 242 | if err != nil { 243 | return nil, err 244 | } 245 | 246 | return rows, nil 247 | } 248 | 249 | func (d database) Execute(sqlString string) (sql.Result, error) { 250 | return d.ExecuteContext(context.Background(), sqlString) 251 | } 252 | 253 | // ExecuteContext todo Is there need retry? 254 | func (d database) ExecuteContext(ctx context.Context, sqlString string) (sql.Result, error) { 255 | if ctx == nil { 256 | ctx = context.Background() 257 | } 258 | sqlStringWithCallerInfo := getCallerInfo(d, false) + sqlString 259 | startTime := time.Now() 260 | defer func() { 261 | endTime := time.Now() 262 | if d.logger != nil { 263 | d.logger(sqlStringWithCallerInfo, endTime.Sub(startTime), false, false) 264 | } 265 | }() 266 | 267 | var result sql.Result 268 | invoker := func(ctx context.Context, sql string) (err error) { 269 | result, err = d.getTxOrDB().ExecContext(ctx, sql) 270 | return 271 | } 272 | var err error 273 | if d.interceptor == nil { 274 | err = invoker(ctx, sqlStringWithCallerInfo) 275 | } else { 276 | err = d.interceptor(ctx, sqlStringWithCallerInfo, invoker) 277 | } 278 | if err != nil { 279 | return nil, err 280 | } 281 | 282 | return result, err 283 | } 284 | -------------------------------------------------------------------------------- /select_test.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | ) 7 | 8 | type tTable1 struct { 9 | Table 10 | } 
11 | 12 | var Table1 = tTable1{ 13 | NewTable("table1"), 14 | } 15 | 16 | var table1 = NewTable("table1") 17 | var field1 = NewNumberField(table1, "field1") 18 | var field2 = NewNumberField(table1, "field2") 19 | 20 | var table2 = NewTable("table2") 21 | var field3 = NewNumberField(table2, "field3") 22 | 23 | var table3 = NewTable("table3") 24 | var field4 = NewNumberField(table3, "field4") 25 | 26 | func (t tTable1) GetFields() []Field { 27 | return []Field{field1, field2} 28 | } 29 | 30 | func (t tTable1) GetFieldsSQL() string { 31 | return "" 32 | } 33 | 34 | func (t tTable1) GetFullFieldsSQL() string { 35 | return "" 36 | } 37 | 38 | func TestSelect(t *testing.T) { 39 | db := newMockDatabase() 40 | assertValue(t, db.Select(1), "(SELECT 1)") 41 | 42 | db.Select(field1).From(Table1).Where(field1.Equals(42)).Limit(10).GetSQL() 43 | 44 | _, _ = db.Select(field1).From(Table1).Where(field1.Equals(1), field2.Equals(2)).FetchFirst() 45 | assertLastSql(t, "SELECT `field1` FROM `table1` WHERE `field1` = 1 AND `field2` = 2") 46 | 47 | _, _ = db.Select(field1).From(Table1).Where(field1.Equals(1)).Where(field2.Equals(2)).FetchFirst() 48 | assertLastSql(t, "SELECT `field1` FROM `table1` WHERE `field1` = 1 AND `field2` = 2") 49 | 50 | _, _ = db.Select(field1).From(Table1).Where(field1.Equals(1)).WhereIf(true, field2.Equals(2)).FetchFirst() 51 | assertLastSql(t, "SELECT `field1` FROM `table1` WHERE `field1` = 1 AND `field2` = 2") 52 | 53 | _, _ = db.Select(field1).From(Table1).Where(field1.Equals(1)).WhereIf(false, field2.Equals(2)).FetchFirst() 54 | assertLastSql(t, "SELECT `field1` FROM `table1` WHERE `field1` = 1") 55 | 56 | _, _ = db.Select(field1).From(Table1).WhereIf(true, field2.Equals(2)).FetchFirst() 57 | assertLastSql(t, "SELECT `field1` FROM `table1` WHERE `field2` = 2") 58 | 59 | _, _ = db.Select(field1).From(Table1).WhereIf(false, field2.Equals(2)).FetchFirst() 60 | assertLastSql(t, "SELECT `field1` FROM `table1`") 61 | 62 | _, _ = db.Select(field1, field2, 
field3, Count(1).As("count")). 63 | From(Table1, table2). 64 | Where(field1.Equals(field3), field2.In(db.Select(field3).From(table2))). 65 | GroupBy(field2). 66 | Having(Raw("count").GreaterThan(1)). 67 | OrderBy(field1.Desc(), field2). 68 | Limit(10). 69 | Offset(20). 70 | LockInShareMode(). 71 | FetchFirst() 72 | assertLastSql(t, "SELECT `table1`.`field1`, `table1`.`field2`, `table2`.`field3`, COUNT(1) AS count FROM `table1`, `table2` WHERE `table1`.`field1` = `table2`.`field3` AND `table1`.`field2` IN (SELECT `field3` FROM `table2`) GROUP BY `table1`.`field2` HAVING (count) > 1 ORDER BY `table1`.`field1` DESC, `table1`.`field2` LIMIT 10 OFFSET 20 LOCK IN SHARE MODE") 73 | 74 | _, _ = db.SelectDistinct(field2).From(Table1).FetchFirst() 75 | assertLastSql(t, "SELECT DISTINCT `field2` FROM `table1`") 76 | 77 | _, _ = db.Select(field1, field3).From(Table1).Join(table2).On(field1.Equals(field3)).FetchFirst() 78 | assertLastSql(t, "SELECT `table1`.`field1`, `table2`.`field3` FROM `table1` JOIN `table2` ON `table1`.`field1` = `table2`.`field3`") 79 | _, _ = db.Select(field1, field3).From(Table1).LeftJoin(table2).On(field1.Equals(field3)).FetchFirst() 80 | assertLastSql(t, "SELECT `table1`.`field1`, `table2`.`field3` FROM `table1` LEFT JOIN `table2` ON `table1`.`field1` = `table2`.`field3`") 81 | _, _ = db.Select(field1, field3).From(Table1).RightJoin(table2).On(field1.Equals(field3)).FetchFirst() 82 | assertLastSql(t, "SELECT `table1`.`field1`, `table2`.`field3` FROM `table1` RIGHT JOIN `table2` ON `table1`.`field1` = `table2`.`field3`") 83 | 84 | _, _ = db.Select(field1, field3).From(Table1). 85 | LeftJoin(table2).On(field1.Equals(field3)). 
86 | RightJoin(table3).On(field1.Equals(field4)).FetchFirst() 87 | assertLastSql(t, "SELECT `table1`.`field1`, `table2`.`field3` FROM `table1` LEFT JOIN `table2` ON `table1`.`field1` = `table2`.`field3` RIGHT JOIN `table3` ON `table1`.`field1` = `table3`.`field4`") 88 | 89 | db.Select(1).WithContext(context.Background()) 90 | 91 | _, _ = db.SelectFrom(Table1).FetchFirst() 92 | assertLastSql(t, "SELECT FROM `table1`") 93 | 94 | _, _ = db.Select([]Field{field1, field2}).From(Table1).FetchFirst() 95 | assertLastSql(t, "SELECT `field1`, `field2` FROM `table1`") 96 | 97 | _, _ = db.Select([]interface{}{&field1, field2, []int{3, 4}}).From(Table1).FetchFirst() 98 | assertLastSql(t, "SELECT `field1`, `field2`, 3, 4 FROM `table1`") 99 | 100 | _, _ = db.Select(field1, Table1).FetchFirst() 101 | assertLastSql(t, "SELECT `field1`, `field1`, `field2` FROM `table1`") 102 | } 103 | 104 | func TestCount(t *testing.T) { 105 | db := newMockDatabase() 106 | 107 | _, _ = db.SelectFrom(Test).Count() 108 | assertLastSql(t, "SELECT COUNT(1) FROM `test`") 109 | 110 | _, _ = db.SelectDistinct(Test.F1).From(Test).Count() 111 | assertLastSql(t, "SELECT COUNT(DISTINCT `f1`) FROM `test`") 112 | 113 | _, _ = db.Select(Test.F1).From(Test).GroupBy(Test.F2).Count() 114 | assertLastSql(t, "SELECT COUNT(1) FROM (SELECT 1 FROM `test` GROUP BY `f2`) AS t") 115 | 116 | _, _ = db.SelectDistinct(Test.F1).From(Test).GroupBy(Test.F2).Count() 117 | assertLastSql(t, "SELECT COUNT(1) FROM (SELECT DISTINCT `f1` FROM `test` GROUP BY `f2`) AS t") 118 | 119 | _, _ = db.Select(Test.F1).From(Test).Exists() 120 | assertLastSql(t, "SELECT EXISTS (SELECT `f1` FROM `test`)") 121 | 122 | _, _ = db.Select(Test.F1).From(Test).Limit(10).Count() 123 | assertLastSql(t, "SELECT COUNT(1) FROM (SELECT 1 FROM `test` LIMIT 10) AS t") 124 | } 125 | 126 | func TestSelectAutoFrom(t *testing.T) { 127 | db := newMockDatabase() 128 | 129 | _, _ = db.Select(field1, field2, 123).FetchFirst() 130 | assertLastSql(t, "SELECT `field1`, 
`field2`, 123 FROM `table1`") 131 | 132 | _, _ = db.Select(field1, field2, 123, field3).FetchFirst() 133 | assertLastSql(t, "SELECT `table1`.`field1`, `table1`.`field2`, 123, `table2`.`field3` FROM `table1`, `table2`") 134 | } 135 | 136 | func TestFetch(t *testing.T) { 137 | db := newMockDatabase() 138 | defer func() { 139 | sharedMockConn.columnCount = 7 140 | sharedMockConn.rowCount = 10 141 | }() 142 | 143 | sharedMockConn.columnCount = 2 144 | 145 | _ = db 146 | var f1 string 147 | var f2 int 148 | 149 | ok, err := db.Select(field1, field2).From(Table1).FetchFirst(&f1, &f2) 150 | if !ok || err != nil { 151 | t.Error() 152 | } 153 | 154 | if err := db.Select(field1, field2).From(Table1).FetchExactlyOne(&f1, &f2); err == nil { 155 | t.Error("should get error") 156 | } 157 | 158 | sharedMockConn.rowCount = 1 159 | if err := db.Select(field1, field2).From(Table1).FetchExactlyOne(&f1, &f2); err != nil { 160 | t.Error(err) 161 | } 162 | 163 | sharedMockConn.rowCount = 0 164 | if err := db.Select(field1, field2).From(Table1).FetchExactlyOne(&f1, &f2); err == nil { 165 | t.Error("should get error") 166 | } 167 | 168 | } 169 | 170 | func TestFetchAll(t *testing.T) { 171 | db := newMockDatabase() 172 | 173 | sharedMockConn.columnCount = 2 174 | defer func() { 175 | sharedMockConn.columnCount = 7 176 | }() 177 | 178 | // fetch all as slices 179 | var f1s []string 180 | var f2s []int 181 | if _, err := db.Select(field1, field2).From(Table1).FetchAll(&f1s, &f2s); err != nil { 182 | t.Error(err) 183 | } 184 | if len(f1s) != 10 || len(f2s) != 10 { 185 | t.Error(f1s, f2s) 186 | } 187 | 188 | type record struct { 189 | unexportedF1 string 190 | ExportedF1 string 191 | unexportedF2 int 192 | ExportedF2 int 193 | } 194 | var records []record 195 | if _, err := db.Select(field1, field2).From(Table1).FetchAll(&records); err != nil { 196 | t.Error(err) 197 | } 198 | if len(records) != 10 { 199 | t.Error(records) 200 | } 201 | 202 | // fetch all as map 203 | var m map[string]int 204 
| if _, err := db.Select(field1).From(Table1).FetchAll(&m); err != nil { 205 | t.Error(err) 206 | } 207 | 208 | // fetch all as multiple maps is illegal 209 | if _, err := db.Select(field1).From(Table1).FetchAll(&m, &m); err == nil { 210 | t.Error("should get error here") 211 | } 212 | 213 | // fetch all as unsupported type 214 | var unsupported int 215 | if _, err := db.Select(field1).From(Table1).FetchAll(&unsupported); err == nil { 216 | t.Error("should get error here") 217 | } 218 | 219 | // fetch all for non-pointer 220 | if _, err := db.Select(field1).From(Table1).FetchAll(123); err == nil { 221 | t.Error("should get error here") 222 | } 223 | } 224 | 225 | func TestRangeFunc(t *testing.T) { 226 | db := newMockDatabase() 227 | 228 | oldColumnCount := sharedMockConn.columnCount 229 | sharedMockConn.columnCount = 2 230 | defer func() { 231 | sharedMockConn.columnCount = oldColumnCount 232 | }() 233 | 234 | count := 0 235 | seq := db.Select(field1, field2).From(Table1).FetchSeq() 236 | 237 | // for row := range db.Select(field1, field2).From(Table1).FetchSeq() {} 238 | seq(func(row Scanner) bool { 239 | var f1 string 240 | var f2 int 241 | if err := row.Scan(&f1, &f2); err != nil { 242 | t.Error(err) 243 | } 244 | count++ 245 | return true 246 | }) 247 | 248 | if count != 10 { 249 | t.Error(count) 250 | } 251 | } 252 | 253 | func TestLock(t *testing.T) { 254 | db := newMockDatabase() 255 | table1 := NewTable("table1") 256 | _, _ = db.Select(1).From(table1).LockInShareMode().FetchAll() 257 | assertLastSql(t, "SELECT 1 FROM `table1` LOCK IN SHARE MODE") 258 | _, _ = db.Select(1).From(table1).ForUpdate().FetchAll() 259 | assertLastSql(t, "SELECT 1 FROM `table1` FOR UPDATE") 260 | _, _ = db.Select(1).From(table1).ForUpdateNoWait().FetchAll() 261 | assertLastSql(t, "SELECT 1 FROM `table1` FOR UPDATE NOWAIT") 262 | _, _ = db.Select(1).From(table1).ForUpdateSkipLocked().FetchAll() 263 | assertLastSql(t, "SELECT 1 FROM `table1` FOR UPDATE SKIP LOCKED") 264 | } 265 | 266 
| func TestUnion(t *testing.T) { 267 | db := newMockDatabase() 268 | table1 := NewTable("table1") 269 | table2 := NewTable("table2") 270 | 271 | cond1 := Raw("") 272 | cond2 := Raw("") 273 | 274 | _, _ = db.SelectFrom(table1).UnionSelectFrom(table2).Where(cond1).FetchAll() 275 | assertLastSql(t, "SELECT * FROM `table1` UNION SELECT * FROM `table2` WHERE ") 276 | 277 | _, _ = db.SelectFrom(table1).Where(cond1). 278 | UnionSelectFrom(table2).Where(cond2).FetchAll() 279 | assertLastSql(t, "SELECT * FROM `table1` WHERE UNION SELECT * FROM `table2` WHERE ") 280 | 281 | _, _ = db.SelectFrom(table1).Where(Raw("C1")). 282 | UnionSelectFrom(table2).Where(Raw("C2")). 283 | UnionSelect(3).From(table2).Where(Raw("C3")). 284 | UnionSelectDistinct(4).From(table2).Where(Raw("C4")). 285 | UnionAllSelectFrom(table2).Where(Raw("C5")). 286 | UnionAllSelect(6).From(table2).Where(Raw("C6")). 287 | UnionAllSelectDistinct(7).From(table2).Where(Raw("C7")). 288 | FetchAll() 289 | assertLastSql(t, "SELECT * FROM `table1` WHERE C1 "+ 290 | "UNION SELECT * FROM `table2` WHERE C2 "+ 291 | "UNION SELECT 3 FROM `table2` WHERE C3 "+ 292 | "UNION SELECT DISTINCT 4 FROM `table2` WHERE C4 "+ 293 | "UNION ALL SELECT * FROM `table2` WHERE C5 "+ 294 | "UNION ALL SELECT 6 FROM `table2` WHERE C6 "+ 295 | "UNION ALL SELECT DISTINCT 7 FROM `table2` WHERE C7") 296 | 297 | _, _ = db.SelectFrom(table1).Where(Raw("C1")). 298 | UnionSelectFrom(table2).Where(Raw("C2")). 299 | UnionSelect(3).From(table2).Where(Raw("C3")). 300 | UnionSelectDistinct(4).From(table2).Where(Raw("C4")). 301 | UnionAllSelectFrom(table2).Where(Raw("C5")). 302 | UnionAllSelect(6).From(table2).Where(Raw("C6")). 303 | UnionAllSelectDistinct(7).From(table2).Where(Raw("C7")). 
304 | Count() 305 | assertLastSql(t, "SELECT COUNT(1) FROM ("+ 306 | "SELECT 1 FROM `table1` WHERE C1 "+ 307 | "UNION SELECT * FROM `table2` WHERE C2 "+ 308 | "UNION SELECT 3 FROM `table2` WHERE C3 "+ 309 | "UNION SELECT DISTINCT 4 FROM `table2` WHERE C4 "+ 310 | "UNION ALL SELECT * FROM `table2` WHERE C5 "+ 311 | "UNION ALL SELECT 6 FROM `table2` WHERE C6 "+ 312 | "UNION ALL SELECT DISTINCT 7 FROM `table2` WHERE C7"+ 313 | ") AS t") 314 | } 315 | 316 | func Test_selectStatus_NaturalJoin(t *testing.T) { 317 | db := newMockDatabase() 318 | table1 := NewTable("table1") 319 | table2 := NewTable("table2") 320 | table3 := NewTable("table3") 321 | table4 := NewTable("table4") 322 | cond1 := Raw("") 323 | cond2 := Raw("") 324 | cond3 := Raw("") 325 | 326 | _, _ = db.SelectFrom(table1).NaturalJoin(table2).Where(cond1).FetchAll() 327 | assertLastSql(t, "SELECT * FROM `table1` NATURAL JOIN `table2` WHERE ") 328 | _, _ = db.SelectFrom(table1). 329 | Join(table2).On(cond1). 330 | NaturalJoin(table2). 331 | NaturalJoin(table3). 332 | LeftJoin(table4).On(cond3). 
333 | Where(cond2).FetchAll() 334 | assertLastSql(t, "SELECT * FROM `table1` JOIN `table2` ON NATURAL JOIN `table2`"+ 335 | " NATURAL JOIN `table3` LEFT JOIN `table4` ON WHERE ") 336 | 337 | } 338 | -------------------------------------------------------------------------------- /generator/generator.go: -------------------------------------------------------------------------------- 1 | package generator 2 | 3 | import ( 4 | "database/sql" 5 | "errors" 6 | "fmt" 7 | "go/format" 8 | "os" 9 | "regexp" 10 | "strconv" 11 | "strings" 12 | "sync" 13 | "sync/atomic" 14 | "unicode" 15 | ) 16 | 17 | const ( 18 | sqlingoGeneratorVersion = 2 19 | ) 20 | 21 | type schemaFetcher interface { 22 | GetDatabaseName() (dbName string, err error) 23 | GetTableNames() (tableNames []string, err error) 24 | GetFieldDescriptors(tableName string) ([]fieldDescriptor, error) 25 | QuoteIdentifier(identifier string) string 26 | } 27 | 28 | type fieldDescriptor struct { 29 | Name string 30 | Type string 31 | Size int 32 | Unsigned bool 33 | AllowNull bool 34 | Comment string 35 | } 36 | 37 | func convertToExportedIdentifier(s string, forceCases []string) string { 38 | var words []string 39 | nextCharShouldBeUpperCase := true 40 | for _, r := range s { 41 | if unicode.IsLetter(r) || unicode.IsDigit(r) { 42 | if nextCharShouldBeUpperCase { 43 | words = append(words, "") 44 | words[len(words)-1] += string(unicode.ToUpper(r)) 45 | nextCharShouldBeUpperCase = false 46 | } else { 47 | words[len(words)-1] += string(r) 48 | } 49 | } else { 50 | nextCharShouldBeUpperCase = true 51 | } 52 | } 53 | result := "" 54 | for _, word := range words { 55 | for _, caseWord := range forceCases { 56 | if strings.EqualFold(word, caseWord) { 57 | word = caseWord 58 | break 59 | } 60 | } 61 | result += word 62 | } 63 | var firstRune rune 64 | for _, r := range result { 65 | firstRune = r 66 | break 67 | } 68 | if result == "" || !unicode.IsUpper(firstRune) { 69 | result = "E" + result 70 | } 71 | return result 72 | } 
73 | 74 | func getType(fieldDescriptor fieldDescriptor) (goType string, fieldClass string, fieldComment string, err error) { 75 | switch strings.ToLower(fieldDescriptor.Type) { 76 | case "tinyint": 77 | goType = "int8" 78 | fieldClass = "NumberField" 79 | case "smallint": 80 | goType = "int16" 81 | fieldClass = "NumberField" 82 | case "int", "mediumint": 83 | goType = "int32" 84 | fieldClass = "NumberField" 85 | case "bigint", "integer": 86 | goType = "int64" 87 | fieldClass = "NumberField" 88 | case "float", "double", "decimal", "real": 89 | goType = "float64" 90 | fieldClass = "NumberField" 91 | case "char", "varchar", "text", "tinytext", "mediumtext", "longtext", "enum", "date", "time", "json", "numeric", "character varying", "timestamp without time zone", "timestamp with time zone", "jsonb", "uuid": 92 | goType = "string" 93 | fieldClass = "StringField" 94 | case "year": 95 | goType = "int16" 96 | fieldClass = "NumberField" 97 | fieldDescriptor.Unsigned = true 98 | case "binary", "varbinary", "blob", "tinyblob", "mediumblob", "longblob": 99 | // TODO: use []byte ? 100 | goType = "string" 101 | fieldClass = "StringField" 102 | case "array": 103 | // TODO: Switch to specific type instead of interface. 
104 | goType = "[]interface{}" 105 | fieldClass = "ArrayField" 106 | case "timestamp": 107 | if !timeAsString { 108 | goType = "time.Time" 109 | fieldClass = "DateField" 110 | fieldComment = "NOTICE: the range of timestamp is [1970-01-01 08:00:01, 2038-01-19 11:14:07]" 111 | } else { 112 | goType = "string" 113 | fieldClass = "StringField" 114 | } 115 | case "datetime": 116 | if !timeAsString { 117 | goType = "time.Time" 118 | fieldClass = "DateField" 119 | fieldComment = "NOTICE: the range of datetime is [0000-01-01 00:00:00, 9999-12-31 23:59:59]" 120 | } else { 121 | goType = "string" 122 | fieldClass = "StringField" 123 | } 124 | case "geometry", "point", "linestring", "polygon", "multipoint", "multilinestring", "multipolygon", "geometrycollection": 125 | goType = "sqlingo.WellKnownBinary" 126 | fieldClass = "WellKnownBinaryField" 127 | case "bit", "bool", "boolean": 128 | if fieldDescriptor.Size == 1 { 129 | goType = "bool" 130 | fieldClass = "BooleanField" 131 | } else { 132 | goType = "string" 133 | fieldClass = "StringField" 134 | } 135 | default: 136 | err = fmt.Errorf("unknown field type %s", fieldDescriptor.Type) 137 | return 138 | } 139 | if fieldDescriptor.Unsigned && strings.HasPrefix(goType, "int") { 140 | goType = "u" + goType 141 | } 142 | if fieldDescriptor.AllowNull { 143 | goType = "*" + goType 144 | } 145 | return 146 | } 147 | 148 | func getSchemaFetcherFactory(driverName string) func(db *sql.DB) schemaFetcher { 149 | switch driverName { 150 | case "mysql": 151 | return newMySQLSchemaFetcher 152 | case "sqlite3": 153 | return newSQLite3SchemaFetcher 154 | case "postgres": 155 | return newPostgresSchemaFetcher 156 | default: 157 | _, _ = fmt.Fprintln(os.Stderr, "unsupported driver "+driverName) 158 | os.Exit(2) 159 | return nil 160 | } 161 | } 162 | 163 | var nonIdentifierRegexp = regexp.MustCompile(`\W`) 164 | 165 | func ensureIdentifier(name string) string { 166 | result := nonIdentifierRegexp.ReplaceAllString(name, "_") 167 | if result == "" 
|| (result[0] >= '0' && result[0] <= '9') { 168 | result = "_" + result 169 | } 170 | return result 171 | } 172 | 173 | // Generate generates code for the given driverName. 174 | func Generate(driverName string, exampleDataSourceName string) (string, error) { 175 | options := parseArgs(exampleDataSourceName) 176 | 177 | db, err := sql.Open(driverName, options.dataSourceName) 178 | if err != nil { 179 | return "", err 180 | } 181 | db.SetMaxOpenConns(10) 182 | 183 | schemaFetcherFactory := getSchemaFetcherFactory(driverName) 184 | schemaFetcher := schemaFetcherFactory(db) 185 | 186 | dbName, err := schemaFetcher.GetDatabaseName() 187 | if err != nil { 188 | return "", err 189 | } 190 | 191 | if dbName == "" { 192 | return "", errors.New("no database selected") 193 | } 194 | 195 | if len(options.tableNames) == 0 { 196 | options.tableNames, err = schemaFetcher.GetTableNames() 197 | if err != nil { 198 | return "", err 199 | } 200 | } 201 | 202 | needImportTime := false 203 | for _, tableName := range options.tableNames { 204 | fieldDescriptors, err := schemaFetcher.GetFieldDescriptors(tableName) 205 | if err != nil { 206 | return "", err 207 | } 208 | for _, fieldDescriptor := range fieldDescriptors { 209 | if !timeAsString && fieldDescriptor.Type == "datetime" || fieldDescriptor.Type == "timestamp" { 210 | needImportTime = true 211 | break 212 | } 213 | } 214 | } 215 | 216 | code := "// This file is generated by sqlingo (https://github.com/lqs/sqlingo)\n" 217 | code += "// DO NOT EDIT.\n\n" 218 | code += "package " + ensureIdentifier(dbName) + "_dsl\n" 219 | if needImportTime { 220 | code += "import (\n" 221 | code += "\t\"time\"\n" 222 | code += "\t\"github.com/lqs/sqlingo\"\n" 223 | code += ")\n\n" 224 | } else { 225 | code += "import \"github.com/lqs/sqlingo\"\n\n" 226 | } 227 | 228 | code += "type sqlingoRuntimeAndGeneratorVersionsShouldBeTheSame uint32\n\n" 229 | 230 | sqlingoGeneratorVersionString := strconv.Itoa(sqlingoGeneratorVersion) 231 | code += "const _ 
= sqlingoRuntimeAndGeneratorVersionsShouldBeTheSame(sqlingo.SqlingoRuntimeVersion - " + sqlingoGeneratorVersionString + ")\n" 232 | code += "const _ = sqlingoRuntimeAndGeneratorVersionsShouldBeTheSame(" + sqlingoGeneratorVersionString + " - sqlingo.SqlingoRuntimeVersion)\n\n" 233 | 234 | code += "type table interface {\n" 235 | code += "\tsqlingo.Table\n" 236 | code += "}\n\n" 237 | 238 | code += "type numberField interface {\n" 239 | code += "\tsqlingo.NumberField\n" 240 | code += "}\n\n" 241 | 242 | code += "type stringField interface {\n" 243 | code += "\tsqlingo.StringField\n" 244 | code += "}\n\n" 245 | 246 | code += "type booleanField interface {\n" 247 | code += "\tsqlingo.BooleanField\n" 248 | code += "}\n\n" 249 | 250 | code += "type arrayField interface {\n" 251 | code += "\tsqlingo.ArrayField\n" 252 | code += "}\n\n" 253 | 254 | code += "type dateField interface {\n" 255 | code += "\tsqlingo.DateField\n" 256 | code += "}\n\n" 257 | 258 | var wg sync.WaitGroup 259 | 260 | type tableCodeItem struct { 261 | code string 262 | err error 263 | } 264 | tableCodeMap := make(map[string]*tableCodeItem) 265 | fmt.Fprintln(os.Stderr, "Generating code for tables...") 266 | var counter int32 267 | for _, tableName := range options.tableNames { 268 | wg.Add(1) 269 | item := &tableCodeItem{} 270 | tableCodeMap[tableName] = item 271 | go func(tableName string) { 272 | defer wg.Done() 273 | tableCode, err := generateTable(schemaFetcher, tableName, options.forceCases) 274 | if err != nil { 275 | item.err = err 276 | return 277 | } 278 | _, _ = fmt.Fprintf(os.Stderr, "Generated (%d/%d) %s\n", atomic.AddInt32(&counter, 1), len(options.tableNames), tableName) 279 | item.code = tableCode 280 | }(tableName) 281 | } 282 | wg.Wait() 283 | for _, tableName := range options.tableNames { 284 | item := tableCodeMap[tableName] 285 | if item.err != nil { 286 | return "", item.err 287 | } 288 | code += item.code 289 | } 290 | code += generateGetTable(options) 291 | codeOut, err := 
format.Source([]byte(code)) 292 | if err != nil { 293 | return "", err 294 | } 295 | return string(codeOut), nil 296 | } 297 | 298 | func generateGetTable(options options) string { 299 | code := "func GetTable(name string) sqlingo.Table {\n" 300 | code += "\tswitch name {\n" 301 | for _, tableName := range options.tableNames { 302 | code += "\tcase " + strconv.Quote(tableName) + ": return " + convertToExportedIdentifier(tableName, options.forceCases) + "\n" 303 | } 304 | code += "\tdefault: return nil\n" 305 | code += "\t}\n" 306 | code += "}\n\n" 307 | 308 | code += "func GetTables() []sqlingo.Table {\n" 309 | code += "\treturn []sqlingo.Table{\n" 310 | for _, tableName := range options.tableNames { 311 | code += "\t" + convertToExportedIdentifier(tableName, options.forceCases) + ",\n" 312 | } 313 | code += "\t}" 314 | code += "}\n\n" 315 | 316 | return code 317 | } 318 | 319 | func generateTable(schemaFetcher schemaFetcher, tableName string, forceCases []string) (string, error) { 320 | fieldDescriptors, err := schemaFetcher.GetFieldDescriptors(tableName) 321 | if err != nil { 322 | return "", err 323 | } 324 | 325 | className := convertToExportedIdentifier(tableName, forceCases) 326 | tableStructName := "t" + className 327 | tableObjectName := "o" + className 328 | 329 | modelClassName := className + "Model" 330 | 331 | tableLines := "" 332 | modelLines := "" 333 | objectLines := "\ttable: " + tableObjectName + ",\n\n" 334 | fieldCaseLines := "" 335 | classLines := "" 336 | 337 | fields := "" 338 | fieldsSQL := "" 339 | fullFieldsSQL := "" 340 | values := "" 341 | 342 | for _, fieldDescriptor := range fieldDescriptors { 343 | 344 | goName := convertToExportedIdentifier(fieldDescriptor.Name, forceCases) 345 | goType, fieldClass, typeComment, err := getType(fieldDescriptor) 346 | if err != nil { 347 | return "", err 348 | } 349 | 350 | privateFieldClass := string(fieldClass[0]+'a'-'A') + fieldClass[1:] 351 | 352 | commentLine := "" 353 | if fieldDescriptor.Comment 
!= "" { 354 | commentLine = "\t// " + strings.ReplaceAll(fieldDescriptor.Comment, "\n", " ") + "\n" 355 | } 356 | if typeComment != "" { 357 | commentLine = "\t// " + typeComment + "\n" 358 | } 359 | 360 | fieldStructName := strings.ToLower(replaceTypeSpace(fieldDescriptor.Type)) + "_" + className + "_" + goName 361 | 362 | tableLines += commentLine 363 | tableLines += "\t" + goName + " " + fieldStructName + "\n" 364 | 365 | modelLines += commentLine 366 | modelLines += "\t" + goName + " " + goType + "\n" 367 | 368 | objectLines += commentLine 369 | objectLines += "\t" + goName + ": " + fieldStructName + "{" 370 | objectLines += "sqlingo.New" + fieldClass + "(" + tableObjectName + ", " + strconv.Quote(fieldDescriptor.Name) + ")},\n" 371 | 372 | fieldCaseLines += "\tcase " + strconv.Quote(fieldDescriptor.Name) + ": return t." + goName + "\n" 373 | 374 | classLines += "type " + fieldStructName + " struct{ " + privateFieldClass + " }\n" 375 | 376 | fields += "t." + goName + ", " 377 | 378 | if fieldsSQL != "" { 379 | fieldsSQL += ", " 380 | } 381 | fieldsSQL += schemaFetcher.QuoteIdentifier(fieldDescriptor.Name) 382 | 383 | if fullFieldsSQL != "" { 384 | fullFieldsSQL += ", " 385 | } 386 | fullFieldsSQL += schemaFetcher.QuoteIdentifier(tableName) + "." + schemaFetcher.QuoteIdentifier(fieldDescriptor.Name) 387 | 388 | values += "m." 
+ goName + ", " 389 | } 390 | code := "" 391 | code += "type " + tableStructName + " struct {\n\ttable\n\n" 392 | code += tableLines 393 | code += "}\n\n" 394 | 395 | code += classLines 396 | 397 | code += "var " + tableObjectName + " = sqlingo.NewTable(" + strconv.Quote(tableName) + ")\n" 398 | code += "var " + className + " = " + tableStructName + "{\n" 399 | code += objectLines 400 | code += "}\n\n" 401 | 402 | code += "func (t t" + className + ") GetFields() []sqlingo.Field {\n" 403 | code += "\treturn []sqlingo.Field{" + fields + "}\n" 404 | code += "}\n\n" 405 | 406 | code += "func (t t" + className + ") GetFieldByName(name string) sqlingo.Field {\n" 407 | code += "\tswitch name {\n" 408 | code += fieldCaseLines 409 | code += "\tdefault: return nil\n" 410 | code += "\t}\n" 411 | code += "}\n\n" 412 | 413 | code += "func (t t" + className + ") GetFieldsSQL() string {\n" 414 | code += "\treturn " + strconv.Quote(fieldsSQL) + "\n" 415 | code += "}\n\n" 416 | 417 | code += "func (t t" + className + ") GetFullFieldsSQL() string {\n" 418 | code += "\treturn " + strconv.Quote(fullFieldsSQL) + "\n" 419 | code += "}\n\n" 420 | 421 | code += "type " + modelClassName + " struct {\n" 422 | code += modelLines 423 | code += "}\n\n" 424 | 425 | code += "func (m " + modelClassName + ") GetTable() sqlingo.Table {\n" 426 | code += "\treturn " + className + "\n" 427 | code += "}\n\n" 428 | 429 | code += "func (m " + modelClassName + ") GetValues() []interface{} {\n" 430 | code += "\treturn []interface{}{" + values + "}\n" 431 | code += "}\n\n" 432 | return code, nil 433 | } 434 | 435 | // replaceTypeSpace : To compatible some types contains spaces in postgresql 436 | // like [character varying, timestamp without time zone, timestamp with time zone] 437 | func replaceTypeSpace(typename string) string { 438 | return strings.ReplaceAll(typename, " ", "_") 439 | } 440 | -------------------------------------------------------------------------------- /select.go: 
-------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "reflect" 7 | "strconv" 8 | "strings" 9 | ) 10 | 11 | type selectWithFields interface { 12 | toSelectWithContext 13 | toSelectFinal 14 | From(tables ...Table) selectWithTables 15 | } 16 | 17 | type selectWithTables interface { 18 | toSelectJoin 19 | toSelectWhere 20 | toSelectWithLock 21 | toSelectWithContext 22 | toSelectFinal 23 | toUnionSelect 24 | GroupBy(expressions ...Expression) selectWithGroupBy 25 | OrderBy(orderBys ...OrderBy) selectWithOrder 26 | Limit(limit int) selectWithLimit 27 | } 28 | 29 | type toSelectJoin interface { 30 | Join(table Table) selectWithJoin 31 | LeftJoin(table Table) selectWithJoin 32 | RightJoin(table Table) selectWithJoin 33 | NaturalJoin(table Table) selectWithJoinOn 34 | } 35 | 36 | type selectWithJoin interface { 37 | On(condition BooleanExpression) selectWithJoinOn 38 | } 39 | 40 | type selectWithJoinOn interface { 41 | toSelectWhere 42 | toSelectWithLock 43 | toSelectWithContext 44 | toSelectFinal 45 | toUnionSelect 46 | toSelectJoin 47 | GroupBy(expressions ...Expression) selectWithGroupBy 48 | OrderBy(orderBys ...OrderBy) selectWithOrder 49 | Limit(limit int) selectWithLimit 50 | } 51 | 52 | type toSelectWhere interface { 53 | Where(conditions ...BooleanExpression) selectWithWhere 54 | WhereIf(prerequisite bool, conditions ...BooleanExpression) selectWithWhere 55 | } 56 | 57 | type selectWithWhere interface { 58 | toSelectWhere 59 | toSelectWithLock 60 | toSelectWithContext 61 | toSelectFinal 62 | toUnionSelect 63 | GroupBy(expressions ...Expression) selectWithGroupBy 64 | OrderBy(orderBys ...OrderBy) selectWithOrder 65 | Limit(limit int) selectWithLimit 66 | } 67 | 68 | type selectWithGroupBy interface { 69 | toSelectWithLock 70 | toSelectWithContext 71 | toSelectFinal 72 | toUnionSelect 73 | Having(conditions ...BooleanExpression) selectWithGroupByHaving 74 | OrderBy(orderBys 
...OrderBy) selectWithOrder 75 | Limit(limit int) selectWithLimit 76 | } 77 | 78 | type selectWithGroupByHaving interface { 79 | toSelectWithLock 80 | toSelectWithContext 81 | toSelectFinal 82 | toUnionSelect 83 | OrderBy(orderBys ...OrderBy) selectWithOrder 84 | } 85 | 86 | type selectWithOrder interface { 87 | toSelectWithLock 88 | toSelectWithContext 89 | toSelectFinal 90 | Limit(limit int) selectWithLimit 91 | } 92 | 93 | type selectWithLimit interface { 94 | toSelectWithLock 95 | toSelectWithContext 96 | toSelectFinal 97 | Offset(offset int) selectWithOffset 98 | } 99 | 100 | type selectWithOffset interface { 101 | toSelectWithLock 102 | toSelectWithContext 103 | toSelectFinal 104 | } 105 | 106 | type toSelectWithLock interface { 107 | LockInShareMode() selectWithLock 108 | ForUpdate() selectWithLock 109 | ForUpdateNoWait() selectWithLock 110 | ForUpdateSkipLocked() selectWithLock 111 | } 112 | 113 | type selectWithLock interface { 114 | toSelectWithContext 115 | toSelectFinal 116 | } 117 | 118 | type toSelectWithContext interface { 119 | WithContext(ctx context.Context) toSelectFinal 120 | } 121 | 122 | type toUnionSelect interface { 123 | UnionSelect(fields ...interface{}) selectWithFields 124 | UnionSelectFrom(tables ...Table) selectWithTables 125 | UnionSelectDistinct(fields ...interface{}) selectWithFields 126 | UnionAllSelect(fields ...interface{}) selectWithFields 127 | UnionAllSelectFrom(tables ...Table) selectWithTables 128 | UnionAllSelectDistinct(fields ...interface{}) selectWithFields 129 | } 130 | 131 | type toSelectFinal interface { 132 | Exists() (bool, error) 133 | Count() (int, error) 134 | GetSQL() (string, error) 135 | FetchFirst(out ...interface{}) (bool, error) 136 | FetchExactlyOne(out ...interface{}) error 137 | FetchAll(dest ...interface{}) (rows int, err error) 138 | FetchCursor() (Cursor, error) 139 | FetchSeq() func(yield func(row Scanner) bool) // use with "range over function" in Go 1.22 140 | } 141 | 142 | type join struct { 143 | 
previous *join 144 | prefix string 145 | table Table 146 | on BooleanExpression 147 | } 148 | 149 | type selectBase struct { 150 | scope scope 151 | distinct bool 152 | fields fieldList 153 | where BooleanExpression 154 | groupBys []Expression 155 | having BooleanExpression 156 | } 157 | 158 | type selectStatus struct { 159 | base selectBase 160 | orderBys []OrderBy 161 | lastUnion *unionSelectStatus 162 | limit *int 163 | offset int 164 | ctx context.Context 165 | lock string 166 | } 167 | 168 | type errorScanner struct { 169 | err error 170 | } 171 | 172 | func (e errorScanner) Scan(dest ...interface{}) error { 173 | return e.err 174 | } 175 | 176 | func (s selectStatus) FetchSeq() func(yield func(row Scanner) bool) { 177 | return func(yield func(row Scanner) bool) { 178 | cursor, err := s.FetchCursor() 179 | if err != nil { 180 | yield(errorScanner{err}) 181 | return 182 | } 183 | 184 | defer cursor.Close() 185 | for cursor.Next() { 186 | if !yield(cursor) { 187 | break 188 | } 189 | } 190 | } 191 | } 192 | 193 | type unionSelectStatus struct { 194 | base selectBase 195 | all bool 196 | previous *unionSelectStatus 197 | } 198 | 199 | func activeSelectBase(s *selectStatus) *selectBase { 200 | if s.lastUnion != nil { 201 | return &s.lastUnion.base 202 | } 203 | return &s.base 204 | } 205 | 206 | func (s selectStatus) Join(table Table) selectWithJoin { 207 | return s.join("", table) 208 | } 209 | 210 | func (s selectStatus) LeftJoin(table Table) selectWithJoin { 211 | return s.join("LEFT ", table) 212 | } 213 | 214 | func (s selectStatus) RightJoin(table Table) selectWithJoin { 215 | return s.join("RIGHT ", table) 216 | } 217 | 218 | // NaturalJoin joins the table using the NATURAL keyword. 219 | // it automatically matches the columns in the two tables that have the same name. 220 | // it not be needed but be provided for completeness. 
221 | func (s selectStatus) NaturalJoin(table Table) selectWithJoinOn { 222 | base := activeSelectBase(&s) 223 | base.scope.lastJoin = &join{ 224 | previous: base.scope.lastJoin, 225 | prefix: "NATURAL ", 226 | table: table, 227 | } 228 | join := *base.scope.lastJoin 229 | base.scope.lastJoin = &join 230 | return s 231 | } 232 | 233 | func (s selectStatus) join(prefix string, table Table) selectWithJoin { 234 | base := activeSelectBase(&s) 235 | base.scope.lastJoin = &join{ 236 | previous: base.scope.lastJoin, 237 | prefix: prefix, 238 | table: table, 239 | } 240 | return s 241 | } 242 | 243 | func (s selectStatus) On(condition BooleanExpression) selectWithJoinOn { 244 | base := activeSelectBase(&s) 245 | join := *base.scope.lastJoin 246 | join.on = condition 247 | base.scope.lastJoin = &join 248 | return s 249 | } 250 | 251 | func getFields(fields []interface{}) (result []Field) { 252 | fields = expandSliceValues(fields) 253 | result = make([]Field, 0, len(fields)) 254 | for _, field := range fields { 255 | switch field.(type) { 256 | case Field: 257 | result = append(result, field.(Field)) 258 | case Table: 259 | result = append(result, field.(Table).GetFields()...) 
260 | default: 261 | fieldCopy := field 262 | fieldExpression := expression{builder: func(scope scope) (string, error) { 263 | sql, _, err := getSQL(scope, fieldCopy) 264 | if err != nil { 265 | return "", err 266 | } 267 | return sql, nil 268 | }} 269 | result = append(result, fieldExpression) 270 | } 271 | } 272 | return 273 | } 274 | 275 | func (d *database) Select(fields ...interface{}) selectWithFields { 276 | return selectStatus{ 277 | base: selectBase{ 278 | scope: scope{ 279 | Database: d, 280 | }, 281 | fields: getFields(fields), 282 | }, 283 | } 284 | } 285 | 286 | func (s selectStatus) From(tables ...Table) selectWithTables { 287 | activeSelectBase(&s).scope.Tables = tables 288 | return s 289 | } 290 | 291 | func (d *database) SelectFrom(tables ...Table) selectWithTables { 292 | return selectStatus{ 293 | base: selectBase{ 294 | scope: scope{ 295 | Database: d, 296 | Tables: tables, 297 | }, 298 | }, 299 | } 300 | } 301 | 302 | func (d *database) SelectDistinct(fields ...interface{}) selectWithFields { 303 | return selectStatus{ 304 | base: selectBase{ 305 | scope: scope{ 306 | Database: d, 307 | }, 308 | fields: getFields(fields), 309 | distinct: true, 310 | }, 311 | } 312 | } 313 | 314 | func (s selectStatus) Where(conditions ...BooleanExpression) selectWithWhere { 315 | if base := activeSelectBase(&s); base.where == nil { 316 | base.where = And(conditions...) 317 | } else { 318 | base.where = And(base.where, And(conditions...)) 319 | } 320 | return s 321 | } 322 | 323 | func (s selectStatus) WhereIf(prerequisite bool, conditions ...BooleanExpression) selectWithWhere { 324 | if !prerequisite { 325 | return s 326 | } 327 | if base := activeSelectBase(&s); base.where == nil { 328 | base.where = And(conditions...) 
329 | } else { 330 | base.where = And(base.where, And(conditions...)) 331 | } 332 | return s 333 | } 334 | 335 | func (s selectStatus) GroupBy(expressions ...Expression) selectWithGroupBy { 336 | activeSelectBase(&s).groupBys = expressions 337 | return s 338 | } 339 | 340 | func (s selectStatus) Having(conditions ...BooleanExpression) selectWithGroupByHaving { 341 | activeSelectBase(&s).having = And(conditions...) 342 | return s 343 | } 344 | 345 | func (s selectStatus) UnionSelect(fields ...interface{}) selectWithFields { 346 | return s.withUnionSelect(false, false, fields, nil) 347 | } 348 | 349 | func (s selectStatus) UnionSelectFrom(tables ...Table) selectWithTables { 350 | return s.withUnionSelect(false, false, nil, tables) 351 | } 352 | 353 | func (s selectStatus) UnionSelectDistinct(fields ...interface{}) selectWithFields { 354 | return s.withUnionSelect(false, true, fields, nil) 355 | } 356 | 357 | func (s selectStatus) UnionAllSelect(fields ...interface{}) selectWithFields { 358 | return s.withUnionSelect(true, false, fields, nil) 359 | } 360 | 361 | func (s selectStatus) UnionAllSelectFrom(tables ...Table) selectWithTables { 362 | return s.withUnionSelect(true, false, nil, tables) 363 | } 364 | 365 | func (s selectStatus) UnionAllSelectDistinct(fields ...interface{}) selectWithFields { 366 | return s.withUnionSelect(true, true, fields, nil) 367 | } 368 | 369 | func (s selectStatus) withUnionSelect(all bool, distinct bool, fields []interface{}, tables []Table) selectStatus { 370 | s.lastUnion = &unionSelectStatus{ 371 | base: selectBase{ 372 | scope: scope{ 373 | Database: s.base.scope.Database, 374 | Tables: tables, 375 | }, 376 | distinct: distinct, 377 | fields: getFields(fields), 378 | }, 379 | all: all, 380 | previous: s.lastUnion, 381 | } 382 | return s 383 | } 384 | 385 | func (s selectStatus) OrderBy(orderBys ...OrderBy) selectWithOrder { 386 | s.orderBys = orderBys 387 | return s 388 | } 389 | 390 | func (s selectStatus) Limit(limit int) 
selectWithLimit { 391 | s.limit = &limit 392 | return s 393 | } 394 | 395 | func (s selectStatus) Offset(offset int) selectWithOffset { 396 | s.offset = offset 397 | return s 398 | } 399 | 400 | func (s selectStatus) Count() (count int, err error) { 401 | if s.lastUnion == nil && len(s.base.groupBys) == 0 && s.limit == nil { 402 | if s.base.distinct { 403 | fields := s.base.fields 404 | s.base.distinct = false 405 | s.base.fields = []Field{expression{builder: func(scope scope) (string, error) { 406 | fieldsSql, err := fields.GetSQL(scope) 407 | if err != nil { 408 | return "", err 409 | } 410 | return "COUNT(DISTINCT " + fieldsSql + ")", nil 411 | }}} 412 | _, err = s.FetchFirst(&count) 413 | } else { 414 | s.base.fields = []Field{staticExpression("COUNT(1)", 0, false)} 415 | _, err = s.FetchFirst(&count) 416 | } 417 | } else { 418 | if !s.base.distinct { 419 | s.base.fields = []Field{staticExpression("1", 0, false)} 420 | } 421 | _, err = s.base.scope.Database.Select(Function("COUNT", 1)). 422 | From(s.asDerivedTable("t")). 
423 | FetchFirst(&count) 424 | } 425 | 426 | return 427 | } 428 | 429 | func (s selectStatus) LockInShareMode() selectWithLock { 430 | s.lock = " LOCK IN SHARE MODE" 431 | return s 432 | } 433 | 434 | func (s selectStatus) ForUpdate() selectWithLock { 435 | s.lock = " FOR UPDATE" 436 | return s 437 | } 438 | 439 | func (s selectStatus) ForUpdateNoWait() selectWithLock { 440 | s.lock = " FOR UPDATE NOWAIT" 441 | return s 442 | } 443 | 444 | func (s selectStatus) ForUpdateSkipLocked() selectWithLock { 445 | s.lock = " FOR UPDATE SKIP LOCKED" 446 | return s 447 | } 448 | 449 | func (s selectStatus) asDerivedTable(name string) Table { 450 | return derivedTable{ 451 | name: name, 452 | selectStatus: s, 453 | } 454 | } 455 | 456 | func (s selectStatus) Exists() (exists bool, err error) { 457 | _, err = s.base.scope.Database.Select(command("EXISTS", s)).FetchFirst(&exists) 458 | return 459 | } 460 | 461 | func (s selectBase) buildSelectBase(sb *strings.Builder) error { 462 | sb.WriteString("SELECT ") 463 | if s.distinct { 464 | sb.WriteString("DISTINCT ") 465 | } 466 | 467 | // find tables from fields if "From" is not specified 468 | if len(s.scope.Tables) == 0 && len(s.fields) > 0 { 469 | tableNames := make([]string, 0, len(s.fields)) 470 | tableMap := make(map[string]Table) 471 | for _, field := range s.fields { 472 | table := field.GetTable() 473 | if table == nil { 474 | continue 475 | } 476 | tableName := table.GetName() 477 | if _, ok := tableMap[tableName]; !ok { 478 | tableMap[tableName] = table 479 | tableNames = append(tableNames, tableName) 480 | } 481 | } 482 | for _, tableName := range tableNames { 483 | table := tableMap[tableName] 484 | s.scope.Tables = append(s.scope.Tables, table) 485 | } 486 | } 487 | 488 | fieldsSql, err := s.fields.GetSQL(s.scope) 489 | if err != nil { 490 | return err 491 | } 492 | sb.WriteString(fieldsSql) 493 | 494 | if len(s.scope.Tables) > 0 { 495 | fromSql := commaTables(s.scope, s.scope.Tables) 496 | sb.WriteString(" FROM ") 497 
| sb.WriteString(fromSql) 498 | } 499 | 500 | if s.scope.lastJoin != nil { 501 | var joins []*join 502 | for j := s.scope.lastJoin; j != nil; j = j.previous { 503 | joins = append(joins, j) 504 | } 505 | for i := len(joins) - 1; i >= 0; i-- { 506 | join := joins[i] 507 | sb.WriteString(" ") 508 | sb.WriteString(join.prefix) 509 | sb.WriteString("JOIN ") 510 | sb.WriteString(join.table.GetSQL(s.scope)) 511 | // cause on isn't a required part of join when using natural join, 512 | // so move it to if statement 513 | if join.on != nil { 514 | onSql, err := join.on.GetSQL(s.scope) 515 | if err != nil { 516 | return err 517 | } 518 | sb.WriteString(" ON ") 519 | sb.WriteString(onSql) 520 | } 521 | } 522 | } 523 | 524 | if err := appendWhere(sb, s.scope, s.where); err != nil { 525 | return err 526 | } 527 | 528 | if len(s.groupBys) != 0 { 529 | groupBySql, err := commaExpressions(s.scope, s.groupBys) 530 | if err != nil { 531 | return err 532 | } 533 | sb.WriteString(" GROUP BY ") 534 | sb.WriteString(groupBySql) 535 | 536 | if s.having != nil { 537 | havingSql, err := s.having.GetSQL(s.scope) 538 | if err != nil { 539 | return err 540 | } 541 | sb.WriteString(" HAVING ") 542 | sb.WriteString(havingSql) 543 | } 544 | } 545 | 546 | return nil 547 | } 548 | 549 | func (s selectStatus) GetSQL() (string, error) { 550 | var sb strings.Builder 551 | sb.Grow(128) 552 | 553 | if err := s.base.buildSelectBase(&sb); err != nil { 554 | return "", err 555 | } 556 | 557 | var unions []*unionSelectStatus 558 | for union := s.lastUnion; union != nil; union = union.previous { 559 | unions = append(unions, union) 560 | } 561 | for i := len(unions) - 1; i >= 0; i-- { 562 | union := unions[i] 563 | if union.all { 564 | sb.WriteString(" UNION ALL ") 565 | } else { 566 | sb.WriteString(" UNION ") 567 | } 568 | if err := union.base.buildSelectBase(&sb); err != nil { 569 | return "", err 570 | } 571 | } 572 | 573 | if len(s.orderBys) > 0 { 574 | orderBySql, err := commaOrderBys(s.base.scope, 
s.orderBys) 575 | if err != nil { 576 | return "", err 577 | } 578 | sb.WriteString(" ORDER BY ") 579 | sb.WriteString(orderBySql) 580 | } 581 | 582 | if s.limit != nil { 583 | sb.WriteString(" LIMIT ") 584 | sb.WriteString(strconv.Itoa(*s.limit)) 585 | } 586 | 587 | if s.offset != 0 { 588 | sb.WriteString(" OFFSET ") 589 | sb.WriteString(strconv.Itoa(s.offset)) 590 | } 591 | 592 | sb.WriteString(s.lock) 593 | 594 | return sb.String(), nil 595 | } 596 | 597 | func (s selectStatus) WithContext(ctx context.Context) toSelectFinal { 598 | s.ctx = ctx 599 | return s 600 | } 601 | 602 | func (s selectStatus) FetchCursor() (Cursor, error) { 603 | sqlString, err := s.GetSQL() 604 | if err != nil { 605 | return nil, err 606 | } 607 | 608 | cursor, err := s.base.scope.Database.QueryContext(s.ctx, sqlString) 609 | if err != nil { 610 | return nil, err 611 | } 612 | return cursor, nil 613 | } 614 | 615 | func (s selectStatus) FetchFirst(dest ...interface{}) (ok bool, err error) { 616 | cursor, err := s.FetchCursor() 617 | if err != nil { 618 | return 619 | } 620 | defer cursor.Close() 621 | 622 | for cursor.Next() { 623 | err = cursor.Scan(dest...) 624 | if err != nil { 625 | return 626 | } 627 | ok = true 628 | break 629 | } 630 | 631 | return 632 | } 633 | 634 | func (s selectStatus) FetchExactlyOne(dest ...interface{}) (err error) { 635 | cursor, err := s.FetchCursor() 636 | if err != nil { 637 | return 638 | } 639 | defer cursor.Close() 640 | 641 | hasResult := false 642 | for cursor.Next() { 643 | if hasResult { 644 | return errors.New("more than one rows") 645 | } 646 | err = cursor.Scan(dest...) 
647 | if err != nil { 648 | return 649 | } 650 | hasResult = true 651 | } 652 | if !hasResult { 653 | err = errors.New("no rows") 654 | } 655 | return 656 | } 657 | 658 | func (s selectStatus) fetchAllAsMap(cursor Cursor, mapType reflect.Type) (mapValue reflect.Value, err error) { 659 | mapValue = reflect.MakeMap(mapType) 660 | key := reflect.New(mapType.Key()) 661 | elem := reflect.New(mapType.Elem()) 662 | 663 | for cursor.Next() { 664 | err = cursor.Scan(key.Interface(), elem.Interface()) 665 | if err != nil { 666 | return 667 | } 668 | 669 | mapValue.SetMapIndex(reflect.Indirect(key), reflect.Indirect(elem)) 670 | } 671 | return 672 | } 673 | 674 | func (s selectStatus) FetchAll(dest ...interface{}) (rows int, err error) { 675 | cursor, err := s.FetchCursor() 676 | if err != nil { 677 | return 678 | } 679 | defer cursor.Close() 680 | 681 | count := len(dest) 682 | values := make([]reflect.Value, count) 683 | for i, item := range dest { 684 | if reflect.ValueOf(item).Kind() != reflect.Ptr { 685 | err = errors.New("dest should be a pointer") 686 | return 687 | } 688 | val := reflect.Indirect(reflect.ValueOf(item)) 689 | 690 | switch val.Kind() { 691 | case reflect.Slice: 692 | values[i] = val 693 | case reflect.Map: 694 | if len(dest) != 1 { 695 | err = errors.New("dest map should be 1 element") 696 | return 697 | } 698 | var mapValue reflect.Value 699 | mapValue, err = s.fetchAllAsMap(cursor, val.Type()) 700 | if err != nil { 701 | return 702 | } 703 | reflect.ValueOf(item).Elem().Set(mapValue) 704 | return 705 | default: 706 | err = errors.New("dest should be pointed to a slice") 707 | return 708 | } 709 | } 710 | 711 | elements := make([]reflect.Value, count) 712 | pointers := make([]interface{}, count) 713 | for i := 0; i < count; i++ { 714 | elements[i] = reflect.New(values[i].Type().Elem()) 715 | pointers[i] = elements[i].Interface() 716 | } 717 | for cursor.Next() { 718 | err = cursor.Scan(pointers...) 
719 | if err != nil { 720 | return 721 | } 722 | for i := 0; i < count; i++ { 723 | values[i].Set(reflect.Append(values[i], reflect.Indirect(elements[i]))) 724 | } 725 | rows++ 726 | } 727 | return 728 | } 729 | -------------------------------------------------------------------------------- /expression.go: -------------------------------------------------------------------------------- 1 | package sqlingo 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "strconv" 7 | "strings" 8 | "time" 9 | "unsafe" 10 | ) 11 | 12 | type priority uint8 13 | 14 | // Expression is the interface of an SQL expression. 15 | type Expression interface { 16 | // get the SQL string 17 | GetSQL(scope scope) (string, error) 18 | getOperatorPriority() priority 19 | 20 | // <> operator 21 | NotEquals(other interface{}) BooleanExpression 22 | // == operator 23 | Equals(other interface{}) BooleanExpression 24 | // < operator 25 | LessThan(other interface{}) BooleanExpression 26 | // <= operator 27 | LessThanOrEquals(other interface{}) BooleanExpression 28 | // > operator 29 | GreaterThan(other interface{}) BooleanExpression 30 | // >= operator 31 | GreaterThanOrEquals(other interface{}) BooleanExpression 32 | 33 | IsNull() BooleanExpression 34 | IsNotNull() BooleanExpression 35 | IsTrue() BooleanExpression 36 | IsNotTrue() BooleanExpression 37 | IsFalse() BooleanExpression 38 | IsNotFalse() BooleanExpression 39 | In(values ...interface{}) BooleanExpression 40 | NotIn(values ...interface{}) BooleanExpression 41 | Between(min interface{}, max interface{}) BooleanExpression 42 | NotBetween(min interface{}, max interface{}) BooleanExpression 43 | Desc() OrderBy 44 | 45 | As(alias string) Alias 46 | 47 | If(trueValue interface{}, falseValue interface{}) UnknownExpression 48 | IfNull(altValue interface{}) UnknownExpression 49 | } 50 | 51 | // Alias is the interface of an table/column alias. 
52 | type Alias interface { 53 | GetSQL(scope scope) (string, error) 54 | } 55 | 56 | // BooleanExpression is the interface of an SQL expression with boolean value. 57 | type BooleanExpression interface { 58 | Expression 59 | And(other interface{}) BooleanExpression 60 | Or(other interface{}) BooleanExpression 61 | Xor(other interface{}) BooleanExpression 62 | Not() BooleanExpression 63 | } 64 | 65 | // NumberExpression is the interface of an SQL expression with number value. 66 | type NumberExpression interface { 67 | Expression 68 | Add(other interface{}) NumberExpression 69 | Sub(other interface{}) NumberExpression 70 | Mul(other interface{}) NumberExpression 71 | Div(other interface{}) NumberExpression 72 | IntDiv(other interface{}) NumberExpression 73 | Mod(other interface{}) NumberExpression 74 | 75 | Sum() NumberExpression 76 | Avg() NumberExpression 77 | Min() UnknownExpression 78 | Max() UnknownExpression 79 | } 80 | 81 | // StringExpression is the interface of an SQL expression with string value. 82 | type StringExpression interface { 83 | Expression 84 | Min() UnknownExpression 85 | Max() UnknownExpression 86 | Like(other interface{}) BooleanExpression 87 | Contains(substring string) BooleanExpression 88 | Concat(other interface{}) StringExpression 89 | IfEmpty(altValue interface{}) StringExpression 90 | IsEmpty() BooleanExpression 91 | Lower() StringExpression 92 | Upper() StringExpression 93 | Left(count interface{}) StringExpression 94 | Right(count interface{}) StringExpression 95 | Trim() StringExpression 96 | } 97 | 98 | type ArrayExpression interface { 99 | Expression 100 | } 101 | 102 | type DateExpression interface { 103 | Expression 104 | Min() UnknownExpression 105 | Max() UnknownExpression 106 | } 107 | 108 | // UnknownExpression is the interface of an SQL expression with unknown value. 
109 | type UnknownExpression interface { 110 | Expression 111 | And(other interface{}) BooleanExpression 112 | Or(other interface{}) BooleanExpression 113 | Xor(other interface{}) BooleanExpression 114 | Not() BooleanExpression 115 | Add(other interface{}) NumberExpression 116 | Sub(other interface{}) NumberExpression 117 | Mul(other interface{}) NumberExpression 118 | Div(other interface{}) NumberExpression 119 | IntDiv(other interface{}) NumberExpression 120 | Mod(other interface{}) NumberExpression 121 | 122 | Sum() NumberExpression 123 | Avg() NumberExpression 124 | Min() UnknownExpression 125 | Max() UnknownExpression 126 | 127 | Like(other interface{}) BooleanExpression 128 | Contains(substring string) BooleanExpression 129 | Concat(other interface{}) StringExpression 130 | IfEmpty(altValue interface{}) StringExpression 131 | IsEmpty() BooleanExpression 132 | Lower() StringExpression 133 | Upper() StringExpression 134 | Left(count interface{}) StringExpression 135 | Right(count interface{}) StringExpression 136 | Trim() StringExpression 137 | } 138 | 139 | type expression struct { 140 | sql string 141 | builder func(scope scope) (string, error) 142 | priority priority 143 | isTrue bool 144 | isFalse bool 145 | isBool bool 146 | } 147 | 148 | func (e expression) GetTable() Table { 149 | return nil 150 | } 151 | 152 | type scope struct { 153 | Database *database 154 | Tables []Table 155 | lastJoin *join 156 | } 157 | 158 | func staticExpression(sql string, priority priority, isBool bool) expression { 159 | return expression{ 160 | sql: sql, 161 | priority: priority, 162 | isBool: isBool, 163 | } 164 | } 165 | 166 | func True() BooleanExpression { 167 | return expression{ 168 | sql: "TRUE", 169 | isTrue: true, 170 | isBool: true, 171 | } 172 | } 173 | 174 | func False() BooleanExpression { 175 | return expression{ 176 | sql: "FALSE", 177 | isFalse: true, 178 | isBool: true, 179 | } 180 | } 181 | 182 | // Raw create a raw SQL statement 183 | func Raw(sql string) 
UnknownExpression { 184 | return expression{ 185 | sql: sql, 186 | priority: 99, 187 | } 188 | } 189 | 190 | // And creates an expression with AND operator. 191 | func And(expressions ...BooleanExpression) (result BooleanExpression) { 192 | if len(expressions) == 0 { 193 | result = True() 194 | return 195 | } 196 | for _, condition := range expressions { 197 | if result == nil { 198 | result = condition 199 | } else { 200 | result = result.And(condition) 201 | } 202 | } 203 | return 204 | } 205 | 206 | // Or creates an expression with OR operator. 207 | func Or(expressions ...BooleanExpression) (result BooleanExpression) { 208 | if len(expressions) == 0 { 209 | result = False() 210 | return 211 | } 212 | for _, condition := range expressions { 213 | if result == nil { 214 | result = condition 215 | } else { 216 | result = result.Or(condition) 217 | } 218 | } 219 | return 220 | } 221 | 222 | func (e expression) As(name string) Alias { 223 | return expression{builder: func(scope scope) (string, error) { 224 | expressionSql, err := e.GetSQL(scope) 225 | if err != nil { 226 | return "", err 227 | } 228 | return expressionSql + " AS " + name, nil 229 | }} 230 | } 231 | 232 | func (e expression) If(trueValue interface{}, falseValue interface{}) UnknownExpression { 233 | return If(e, trueValue, falseValue) 234 | } 235 | 236 | func (e expression) IfNull(altValue interface{}) UnknownExpression { 237 | return Function("IFNULL", e, altValue) 238 | } 239 | 240 | func (e expression) IfEmpty(altValue interface{}) StringExpression { 241 | return If(e.NotEquals(""), e, altValue) 242 | } 243 | 244 | func (e expression) IsEmpty() BooleanExpression { 245 | return e.Equals("") 246 | } 247 | 248 | func (e expression) Lower() StringExpression { 249 | return function("LOWER", e) 250 | } 251 | 252 | func (e expression) Upper() StringExpression { 253 | return function("UPPER", e) 254 | } 255 | 256 | func (e expression) Left(count interface{}) StringExpression { 257 | return 
function("LEFT", e, count) 258 | } 259 | 260 | func (e expression) Right(count interface{}) StringExpression { 261 | return function("RIGHT", e, count) 262 | } 263 | 264 | func (e expression) Trim() StringExpression { 265 | return function("TRIM", e) 266 | } 267 | 268 | func (e expression) CharLength() NumberExpression { 269 | return function("CHAR_LENGTH", e) 270 | } 271 | 272 | func (e expression) HasPrefix(prefix interface{}) BooleanExpression { 273 | return e.Left(function("CHAR_LENGTH", prefix)).Equals(prefix) 274 | } 275 | 276 | func (e expression) HasSuffix(suffix interface{}) BooleanExpression { 277 | return e.Right(function("CHAR_LENGTH", suffix)).Equals(suffix) 278 | } 279 | 280 | func (e expression) GetSQL(scope scope) (string, error) { 281 | if e.sql != "" { 282 | return e.sql, nil 283 | } 284 | return e.builder(scope) 285 | } 286 | 287 | var needsEscape = [256]int{ 288 | 0: 1, 289 | '\n': 1, 290 | '\r': 1, 291 | '\\': 1, 292 | '\'': 1, 293 | '"': 1, 294 | 0x1a: 1, 295 | } 296 | 297 | func quoteIdentifier(identifier string) (result dialectArray) { 298 | for dialect := dialect(0); dialect < dialectCount; dialect++ { 299 | switch dialect { 300 | case dialectMySQL: 301 | result[dialect] = "`" + identifier + "`" 302 | case dialectMSSQL: 303 | result[dialect] = "[" + identifier + "]" 304 | default: 305 | result[dialect] = "\"" + identifier + "\"" 306 | } 307 | } 308 | return 309 | } 310 | 311 | func quoteString(s string) string { 312 | if s == "" { 313 | return "''" 314 | } 315 | 316 | buf := make([]byte, len(s)*2+2) 317 | buf[0] = '\'' 318 | n := 1 319 | for i := 0; i < len(s); i++ { 320 | b := s[i] 321 | buf[n] = '\\' 322 | n += needsEscape[b] 323 | buf[n] = b 324 | n++ 325 | } 326 | buf[n] = '\'' 327 | n++ 328 | buf = buf[:n] 329 | return *(*string)(unsafe.Pointer(&buf)) 330 | } 331 | 332 | func getSQL(scope scope, value interface{}) (sql string, priority priority, err error) { 333 | const mysqlTimeFormat = "2006-01-02 15:04:05.000000" 334 | if value == 
nil { 335 | sql = "NULL" 336 | return 337 | } 338 | switch value.(type) { 339 | case int: 340 | sql = strconv.Itoa(value.(int)) 341 | case string: 342 | sql = quoteString(value.(string)) 343 | case Expression: 344 | sql, err = value.(Expression).GetSQL(scope) 345 | priority = value.(Expression).getOperatorPriority() 346 | case Assignment: 347 | sql, err = value.(Assignment).GetSQL(scope) 348 | case toSelectFinal: 349 | sql, err = value.(toSelectFinal).GetSQL() 350 | if err != nil { 351 | return 352 | } 353 | sql = "(" + sql + ")" 354 | case toUpdateFinal: 355 | sql, err = value.(toUpdateFinal).GetSQL() 356 | case Table: 357 | sql = value.(Table).GetSQL(scope) 358 | case CaseExpression: 359 | sql, err = value.(CaseExpression).End().GetSQL(scope) 360 | case time.Time: 361 | tm := value.(time.Time) 362 | if tm.IsZero() { 363 | sql = "NULL" 364 | } else { 365 | tmStr := tm.Format(mysqlTimeFormat) 366 | sql = quoteString(tmStr) 367 | } 368 | case *time.Time: 369 | tm := value.(*time.Time) 370 | if tm == nil || tm.IsZero() { 371 | sql = "NULL" 372 | } else { 373 | tmStr := tm.Format(mysqlTimeFormat) 374 | sql = quoteString(tmStr) 375 | } 376 | default: 377 | v := reflect.ValueOf(value) 378 | sql, priority, err = getSQLFromReflectValue(scope, v) 379 | } 380 | return 381 | } 382 | 383 | func getSQLFromReflectValue(scope scope, v reflect.Value) (sql string, priority priority, err error) { 384 | if v.Kind() == reflect.Ptr { 385 | // dereference pointers 386 | for { 387 | if v.IsNil() { 388 | sql = "NULL" 389 | return 390 | } 391 | v = v.Elem() 392 | if v.Kind() != reflect.Ptr { 393 | break 394 | } 395 | } 396 | sql, priority, err = getSQL(scope, v.Interface()) 397 | return 398 | } 399 | 400 | switch v.Kind() { 401 | case reflect.Bool: 402 | if v.Bool() { 403 | sql = "1" 404 | } else { 405 | sql = "0" 406 | } 407 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 408 | sql = strconv.FormatInt(v.Int(), 10) 409 | case reflect.Uint, reflect.Uint8, 
reflect.Uint16, reflect.Uint32, reflect.Uint64: 410 | sql = strconv.FormatUint(v.Uint(), 10) 411 | case reflect.Float32, reflect.Float64: 412 | sql = strconv.FormatFloat(v.Float(), 'g', -1, 64) 413 | case reflect.String: 414 | sql = quoteString(v.String()) 415 | case reflect.Array, reflect.Slice: 416 | length := v.Len() 417 | values := make([]interface{}, length) 418 | for i := 0; i < length; i++ { 419 | values[i] = v.Index(i).Interface() 420 | } 421 | sql, err = commaValues(scope, values) 422 | if err == nil { 423 | sql = "(" + sql + ")" 424 | } 425 | default: 426 | if vs, ok := v.Interface().(interface{ String() string }); ok { 427 | sql = quoteString(vs.String()) 428 | } else { 429 | err = fmt.Errorf("invalid type %s", v.Kind().String()) 430 | } 431 | } 432 | return 433 | } 434 | 435 | /* 436 | 1 INTERVAL 437 | 2 BINARY, COLLATE 438 | 3 ! 439 | 4 - (unary minus), ~ (unary bit inversion) 440 | 5 ^ 441 | 6 *, /, DIV, %, MOD 442 | 7 -, + 443 | 8 <<, >> 444 | 9 & 445 | 10 | 446 | 11 = (comparison), <=>, >=, >, <=, <, <>, !=, IS, LIKE, REGEXP, IN 447 | 12 BETWEEN, CASE, WHEN, THEN, ELSE 448 | 13 NOT 449 | 14 AND, && 450 | 15 XOR 451 | 16 OR, || 452 | 17 = (assignment), := 453 | */ 454 | func (e expression) NotEquals(other interface{}) BooleanExpression { 455 | return e.binaryOperation("<>", other, 11, true) 456 | } 457 | 458 | func (e expression) Equals(other interface{}) BooleanExpression { 459 | return e.binaryOperation("=", other, 11, true) 460 | } 461 | 462 | func (e expression) LessThan(other interface{}) BooleanExpression { 463 | return e.binaryOperation("<", other, 11, true) 464 | } 465 | 466 | func (e expression) LessThanOrEquals(other interface{}) BooleanExpression { 467 | return e.binaryOperation("<=", other, 11, true) 468 | } 469 | 470 | func (e expression) GreaterThan(other interface{}) BooleanExpression { 471 | return e.binaryOperation(">", other, 11, true) 472 | } 473 | 474 | func (e expression) GreaterThanOrEquals(other interface{}) BooleanExpression { 
475 | return e.binaryOperation(">=", other, 11, true) 476 | } 477 | 478 | func toBooleanExpression(value interface{}) BooleanExpression { 479 | e, ok := value.(expression) 480 | switch { 481 | case !ok: 482 | return nil 483 | case e.isTrue: 484 | return True() 485 | case e.isFalse: 486 | return False() 487 | case e.isBool: 488 | return e 489 | default: 490 | return nil 491 | } 492 | } 493 | 494 | func (e expression) And(other interface{}) BooleanExpression { 495 | switch { 496 | case e.isFalse: 497 | return e 498 | case e.isTrue: 499 | if exp := toBooleanExpression(other); exp != nil { 500 | return exp 501 | } 502 | } 503 | return e.binaryOperation("AND", other, 14, true) 504 | } 505 | 506 | func (e expression) Or(other interface{}) BooleanExpression { 507 | switch { 508 | case e.isTrue: 509 | return e 510 | case e.isFalse: 511 | if exp := toBooleanExpression(other); exp != nil { 512 | return exp 513 | } 514 | } 515 | return e.binaryOperation("OR", other, 16, true) 516 | } 517 | 518 | func (e expression) Xor(other interface{}) BooleanExpression { 519 | return e.binaryOperation("XOR", other, 15, true) 520 | } 521 | 522 | func (e expression) Add(other interface{}) NumberExpression { 523 | return e.binaryOperation("+", other, 7, false) 524 | } 525 | 526 | func (e expression) Sub(other interface{}) NumberExpression { 527 | return e.binaryOperation("-", other, 7, false) 528 | } 529 | 530 | func (e expression) Mul(other interface{}) NumberExpression { 531 | return e.binaryOperation("*", other, 6, false) 532 | } 533 | 534 | func (e expression) Div(other interface{}) NumberExpression { 535 | return e.binaryOperation("/", other, 6, false) 536 | } 537 | 538 | func (e expression) IntDiv(other interface{}) NumberExpression { 539 | return e.binaryOperation("DIV", other, 6, false) 540 | } 541 | 542 | func (e expression) Mod(other interface{}) NumberExpression { 543 | return e.binaryOperation("%", other, 6, false) 544 | } 545 | 546 | func (e expression) Sum() NumberExpression { 
547 | return function("SUM", e) 548 | } 549 | 550 | func (e expression) Avg() NumberExpression { 551 | return function("AVG", e) 552 | } 553 | 554 | func (e expression) Min() UnknownExpression { 555 | return function("MIN", e) 556 | } 557 | 558 | func (e expression) Max() UnknownExpression { 559 | return function("MAX", e) 560 | } 561 | 562 | func (e expression) Like(other interface{}) BooleanExpression { 563 | return e.binaryOperation("LIKE", other, 11, true) 564 | } 565 | 566 | func (e expression) Concat(other interface{}) StringExpression { 567 | return Concat(e, other) 568 | } 569 | 570 | func (e expression) Contains(substring string) BooleanExpression { 571 | return function("LOCATE", substring, e).GreaterThan(0) 572 | } 573 | 574 | func (e expression) binaryOperation(operator string, value interface{}, priority priority, isBool bool) expression { 575 | return expression{builder: func(scope scope) (string, error) { 576 | leftSql, err := e.GetSQL(scope) 577 | if err != nil { 578 | return "", err 579 | } 580 | leftPriority := e.priority 581 | rightSql, rightPriority, err := getSQL(scope, value) 582 | if err != nil { 583 | return "", err 584 | } 585 | shouldParenthesizeLeft := leftPriority > priority 586 | shouldParenthesizeRight := rightPriority >= priority 587 | var sb strings.Builder 588 | sb.Grow(len(leftSql) + len(operator) + len(rightSql) + 6) 589 | if shouldParenthesizeLeft { 590 | sb.WriteByte('(') 591 | } 592 | sb.WriteString(leftSql) 593 | if shouldParenthesizeLeft { 594 | sb.WriteByte(')') 595 | } 596 | sb.WriteByte(' ') 597 | sb.WriteString(operator) 598 | sb.WriteByte(' ') 599 | if shouldParenthesizeRight { 600 | sb.WriteByte('(') 601 | } 602 | sb.WriteString(rightSql) 603 | if shouldParenthesizeRight { 604 | sb.WriteByte(')') 605 | } 606 | return sb.String(), nil 607 | }, priority: priority, isBool: isBool} 608 | } 609 | 610 | func (e expression) prefixSuffixExpression(prefix string, suffix string, priority priority, isBool bool) expression { 611 | 
if e.sql != "" { 612 | return expression{ 613 | sql: prefix + e.sql + suffix, 614 | priority: priority, 615 | isBool: isBool, 616 | } 617 | } 618 | return expression{ 619 | builder: func(scope scope) (string, error) { 620 | exprSql, err := e.GetSQL(scope) 621 | if err != nil { 622 | return "", err 623 | } 624 | var sb strings.Builder 625 | sb.Grow(len(prefix) + len(exprSql) + len(suffix) + 2) 626 | sb.WriteString(prefix) 627 | shouldParenthesize := e.priority > priority 628 | if shouldParenthesize { 629 | sb.WriteByte('(') 630 | } 631 | sb.WriteString(exprSql) 632 | if shouldParenthesize { 633 | sb.WriteByte(')') 634 | } 635 | sb.WriteString(suffix) 636 | return sb.String(), nil 637 | }, 638 | priority: priority, 639 | isBool: isBool, 640 | } 641 | } 642 | 643 | func (e expression) IsNull() BooleanExpression { 644 | return e.prefixSuffixExpression("", " IS NULL", 11, true) 645 | } 646 | 647 | func (e expression) Not() BooleanExpression { 648 | switch { 649 | case e.isTrue: 650 | return False() 651 | case e.isFalse: 652 | return True() 653 | default: 654 | return e.prefixSuffixExpression("NOT ", "", 13, true) 655 | } 656 | } 657 | 658 | func (e expression) IsNotNull() BooleanExpression { 659 | return e.prefixSuffixExpression("", " IS NOT NULL", 11, true) 660 | } 661 | 662 | func (e expression) IsTrue() BooleanExpression { 663 | return e.prefixSuffixExpression("", " IS TRUE", 11, true) 664 | } 665 | 666 | func (e expression) IsNotTrue() BooleanExpression { 667 | return e.prefixSuffixExpression("", " IS NOT TRUE", 11, true) 668 | } 669 | 670 | func (e expression) IsFalse() BooleanExpression { 671 | return e.prefixSuffixExpression("", " IS FALSE", 11, true) 672 | } 673 | 674 | func (e expression) IsNotFalse() BooleanExpression { 675 | return e.prefixSuffixExpression("", " IS NOT FALSE", 11, true) 676 | } 677 | 678 | func expandSliceValue(value reflect.Value) (result []interface{}) { 679 | result = make([]interface{}, 0, 16) 680 | kind := value.Kind() 681 | switch kind 
{ 682 | case reflect.Array, reflect.Slice: 683 | length := value.Len() 684 | for i := 0; i < length; i++ { 685 | result = append(result, expandSliceValue(value.Index(i))...) 686 | } 687 | case reflect.Interface, reflect.Ptr: 688 | result = append(result, expandSliceValue(value.Elem())...) 689 | default: 690 | result = append(result, value.Interface()) 691 | } 692 | return 693 | } 694 | 695 | func expandSliceValues(values []interface{}) (result []interface{}) { 696 | result = make([]interface{}, 0, 16) 697 | for _, v := range values { 698 | value := reflect.ValueOf(v) 699 | result = append(result, expandSliceValue(value)...) 700 | } 701 | return 702 | } 703 | 704 | func (e expression) In(values ...interface{}) BooleanExpression { 705 | values = expandSliceValues(values) 706 | if len(values) == 0 { 707 | return False() 708 | } 709 | joiner := func(exprSql, valuesSql string) string { return exprSql + " IN (" + valuesSql + ")" } 710 | builder := e.getBuilder(e.Equals, joiner, values...) 711 | return expression{builder: builder, priority: 11} 712 | } 713 | 714 | func (e expression) NotIn(values ...interface{}) BooleanExpression { 715 | values = expandSliceValues(values) 716 | if len(values) == 0 { 717 | return True() 718 | } 719 | joiner := func(exprSql, valuesSql string) string { return exprSql + " NOT IN (" + valuesSql + ")" } 720 | builder := e.getBuilder(e.NotEquals, joiner, values...) 
721 | return expression{builder: builder, priority: 11} 722 | } 723 | 724 | type joinerFunc = func(exprSql, valuesSql string) string 725 | type booleanFunc = func(other interface{}) BooleanExpression 726 | type builderFunc = func(scope scope) (string, error) 727 | 728 | func (e expression) getBuilder(single booleanFunc, joiner joinerFunc, values ...interface{}) builderFunc { 729 | return func(scope scope) (string, error) { 730 | var valuesSql string 731 | var err error 732 | 733 | if len(values) == 1 { 734 | value := values[0] 735 | if selectStatus, ok := value.(toSelectFinal); ok { 736 | // IN subquery 737 | valuesSql, err = selectStatus.GetSQL() 738 | if err != nil { 739 | return "", err 740 | } 741 | } else { 742 | // IN a single value 743 | return single(value).GetSQL(scope) 744 | } 745 | } else { 746 | // IN a list 747 | valuesSql, err = commaValues(scope, values) 748 | if err != nil { 749 | return "", err 750 | } 751 | } 752 | 753 | exprSql, err := e.GetSQL(scope) 754 | if err != nil { 755 | return "", err 756 | } 757 | return joiner(exprSql, valuesSql), nil 758 | } 759 | } 760 | 761 | func (e expression) Between(min interface{}, max interface{}) BooleanExpression { 762 | return e.buildBetween(" BETWEEN ", min, max) 763 | } 764 | 765 | func (e expression) NotBetween(min interface{}, max interface{}) BooleanExpression { 766 | return e.buildBetween(" NOT BETWEEN ", min, max) 767 | } 768 | 769 | func (e expression) buildBetween(operator string, min interface{}, max interface{}) BooleanExpression { 770 | return expression{builder: func(scope scope) (string, error) { 771 | exprSql, err := e.GetSQL(scope) 772 | if err != nil { 773 | return "", err 774 | } 775 | minSql, _, err := getSQL(scope, min) 776 | if err != nil { 777 | return "", err 778 | } 779 | maxSql, _, err := getSQL(scope, max) 780 | if err != nil { 781 | return "", err 782 | } 783 | return exprSql + operator + minSql + " AND " + maxSql, nil 784 | }, priority: 12} 785 | } 786 | 787 | func (e expression) 
getOperatorPriority() priority { 788 | return e.priority 789 | } 790 | 791 | func (e expression) Desc() OrderBy { 792 | return orderBy{by: e, desc: true} 793 | } 794 | --------------------------------------------------------------------------------