├── logo.png
├── sqlingo-gen-sqlite3
│   └── main.go
├── sqlingo-gen-mysql
│   └── main.go
├── sqlingo-gen-postgres
│   └── main.go
├── order_test.go
├── order.go
├── dialect_test.go
├── geometry_test.go
├── dialect.go
├── table_test.go
├── function_test.go
├── sqlingo-gen
│   └── main.go
├── generator
│   ├── generator_test.go
│   ├── fetcher_sqlite3.go
│   ├── fetcher_postgres.go
│   ├── args.go
│   ├── fetcher_mysql.go
│   └── generator.go
├── .github
│   └── workflows
│       └── go.yml
├── case_test.go
├── LICENSE
├── utils_test.go
├── geometry.go
├── delete_test.go
├── interceptor.go
├── table.go
├── interceptor_test.go
├── function.go
├── transaction.go
├── field_test.go
├── transaction_test.go
├── update_test.go
├── case.go
├── common_test.go
├── value.go
├── delete.go
├── database_test.go
├── value_test.go
├── update.go
├── field.go
├── common.go
├── insert_test.go
├── README.md
├── cursor_test.go
├── array.go
├── insert.go
├── cursor.go
├── expression_test.go
├── database.go
├── select_test.go
├── select.go
└── expression.go
/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lqs/sqlingo/HEAD/logo.png
--------------------------------------------------------------------------------
/sqlingo-gen-sqlite3/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "github.com/lqs/sqlingo/generator"
6 | _ "github.com/mattn/go-sqlite3"
7 | )
8 |
9 | func main() {
10 | code, err := generator.Generate("mysql", "./testdb.sqlite3")
11 | if err != nil {
12 | panic(err)
13 | }
14 |
15 | fmt.Print(code)
16 | }
17 |
--------------------------------------------------------------------------------
/sqlingo-gen-mysql/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | _ "github.com/go-sql-driver/mysql"
6 | "github.com/lqs/sqlingo/generator"
7 | )
8 |
9 | func main() {
10 | code, err := generator.Generate("mysql", "username:password@tcp(hostname:3306)/database")
11 | if err != nil {
12 | panic(err)
13 | }
14 |
15 | fmt.Print(code)
16 | }
17 |
--------------------------------------------------------------------------------
/sqlingo-gen-postgres/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | _ "github.com/lib/pq"
6 | "github.com/lqs/sqlingo/generator"
7 | )
8 |
9 | func main() {
10 | code, err := generator.Generate("postgres", "host=localhost port=5432 user=user password=pass dbname=db sslmode=disable")
11 | if err != nil {
12 | panic(err)
13 | }
14 |
15 | fmt.Print(code)
16 | }
17 |
--------------------------------------------------------------------------------
/order_test.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import (
4 | "errors"
5 | "testing"
6 | )
7 |
8 | func TestOrder(t *testing.T) {
9 | e := expression{sql: "x"}
10 | assertValue(t, orderBy{by: e}, "x")
11 | assertValue(t, orderBy{by: e, desc: true}, "x DESC")
12 | assertError(t, orderBy{by: expression{builder: func(scope scope) (string, error) {
13 | return "", errors.New("error")
14 | }}})
15 | }
16 |
--------------------------------------------------------------------------------
/order.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | // OrderBy indicates the ORDER BY column and the status of descending order.
4 | type OrderBy interface {
5 | GetSQL(scope scope) (string, error)
6 | }
7 |
8 | type orderBy struct {
9 | by Expression
10 | desc bool
11 | }
12 |
13 | func (o orderBy) GetSQL(scope scope) (string, error) {
14 | sql, err := o.by.GetSQL(scope)
15 | if err != nil {
16 | return "", err
17 | }
18 | if o.desc {
19 | sql += " DESC"
20 | }
21 | return sql, nil
22 | }
23 |
--------------------------------------------------------------------------------
/dialect_test.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import "testing"
4 |
5 | func TestDialect(t *testing.T) {
6 | nameToDialect := map[string]dialect{
7 | "mysql": dialectMySQL,
8 | "sqlite3": dialectSqlite3,
9 | "postgres": dialectPostgres,
10 | "sqlserver": dialectMSSQL,
11 | "mssql": dialectMSSQL,
12 | "somedbidontknow": dialectUnknown,
13 | }
14 |
15 | for name, dialect := range nameToDialect {
16 | if getDialectFromDriverName(name) != dialect {
17 | t.Error()
18 | }
19 | }
20 | }
21 |
--------------------------------------------------------------------------------
/geometry_test.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import "testing"
4 |
5 | func TestGeometry(t *testing.T) {
6 | assertValue(t, STGeomFromText("sample wkt"), "ST_GeomFromText('sample wkt')")
7 | assertValue(t, STGeomFromTextf("sample wkt %d", 1), "ST_GeomFromText('sample wkt 1')")
8 |
9 | e := expression{
10 | builder: func(scope scope) (string, error) {
11 | return "<>", nil
12 | },
13 | }
14 | assertValue(t, e.STAsText(), "ST_AsText(<>)")
15 |
16 | t1 := NewTable("t1")
17 | field := NewWellKnownBinaryField(t1, "f1")
18 | assertValue(t, field, "`t1`.`f1`")
19 | }
20 |
--------------------------------------------------------------------------------
/dialect.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
// dialect identifies the SQL flavour spoken by a database/sql driver.
// Values are used as indexes into dialectArray, so they start at zero and
// must stay contiguous.
type dialect int

const (
	dialectUnknown dialect = iota // driver name not recognised
	dialectMySQL
	dialectSqlite3
	dialectPostgres
	dialectMSSQL

	dialectCount // number of dialects; must remain the last constant
)

// dialectArray stores one string per dialect, indexed by dialect value.
type dialectArray [dialectCount]string

// getDialectFromDriverName maps a database/sql driver name to its dialect.
// Both "sqlserver" and "mssql" select MSSQL; any other unrecognised name
// yields dialectUnknown.
func getDialectFromDriverName(driverName string) dialect {
	d := dialectUnknown
	switch driverName {
	case "mysql":
		d = dialectMySQL
	case "sqlite3":
		d = dialectSqlite3
	case "postgres":
		d = dialectPostgres
	case "mssql", "sqlserver":
		d = dialectMSSQL
	}
	return d
}
31 |
--------------------------------------------------------------------------------
/table_test.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import "testing"
4 |
5 | func TestTable(t *testing.T) {
6 | table := table{}
7 | if table.getOperatorPriority() != 0 {
8 | t.Error()
9 | }
10 | }
11 |
12 | func TestDerivedTable(t *testing.T) {
13 | dummyFields := []Field{NewNumberField(NewTable("table"), "field")}
14 | dt := derivedTable{
15 | name: "t",
16 | selectStatus: selectStatus{
17 | base: selectBase{
18 | fields: dummyFields,
19 | },
20 | },
21 | }
22 | if dt.GetName() != "t" {
23 | t.Error()
24 | }
25 |
26 | sql, err := dt.GetFields()[0].GetSQL(dummyMySQLScope)
27 | if err != nil {
28 | t.Error(err)
29 | }
30 | if sql != "`table`.`field`" {
31 | t.Error(sql)
32 | }
33 | }
34 |
--------------------------------------------------------------------------------
/function_test.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import (
4 | "errors"
5 | "testing"
6 | )
7 |
8 | func TestFunction(t *testing.T) {
9 | a1 := expression{sql: "a1"}
10 | a2 := expression{sql: "a2"}
11 | ee := expression{builder: func(scope scope) (string, error) {
12 | return "", errors.New("error")
13 | }}
14 |
15 | assertValue(t, Function("func"), "func()")
16 | assertValue(t, Function("func", a1), "func(a1)")
17 | assertValue(t, Function("func", a1, a2), "func(a1, a2)")
18 | assertError(t, Function("func", a1, ee))
19 |
20 | assertValue(t, Concat(a1, a2), "CONCAT(a1, a2)")
21 | assertValue(t, Count(a1), "COUNT(a1)")
22 | assertValue(t, If(a1, 1, 2), "IF(a1, 1, 2)")
23 | assertValue(t, Length(a1), "LENGTH(a1)")
24 | assertValue(t, Sum(a1), "SUM(a1)")
25 | }
26 |
--------------------------------------------------------------------------------
/sqlingo-gen/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | _ "github.com/go-sql-driver/mysql"
6 | "github.com/lqs/sqlingo/generator"
7 | "os"
8 | "strings"
9 | )
10 |
11 | func main() {
12 | warningLines := []string{
13 | "\u001b[31mThis command is deprecated. Please install the new generator with the corresponding driver:",
14 | "go get -u github.com/lqs/sqlingo/sqlingo-gen-mysql",
15 | "go get -u github.com/lqs/sqlingo/sqlingo-gen-sqlite3",
16 | "go get -u github.com/lqs/sqlingo/sqlingo-gen-postgres",
17 | "\u001b[0m",
18 | }
19 | _, _ = fmt.Fprintln(os.Stderr, strings.Join(warningLines, "\n"))
20 | code, err := generator.Generate("mysql", "username:password@tcp(hostname:3306)/database")
21 | if err != nil {
22 | panic(err)
23 | }
24 |
25 | fmt.Print(code)
26 | }
27 |
--------------------------------------------------------------------------------
/generator/generator_test.go:
--------------------------------------------------------------------------------
1 | package generator
2 |
3 | import "testing"
4 |
5 | func TestConvert(t *testing.T) {
6 | m := map[string]string{
7 | "abc": "Abc",
8 | "name": "Name",
9 | "Name": "Name",
10 | "abc_def": "AbcDef",
11 | "ϢϢϢϢ1": "ϢϢϢϢ1",
12 | "_[[[[[": "E",
13 | "abc/def": "AbcDef",
14 | "中文开头": "E中文开头",
15 | "abc~~中文": "Abc中文",
16 | "abc-def--ghi": "AbcDefGhi",
17 | "user_id": "UserID",
18 | "user_ip_address": "UserIPAddress",
19 | "user_ix": "UserIx",
20 | "volume_db": "VolumedB",
21 | "db_volume": "EdBVolume",
22 | }
23 | for k, v := range m {
24 | if convertToExportedIdentifier(k, []string{"ID", "IP", "dB"}) != v {
25 | t.Errorf("'%s' should be converted to '%s'", k, v)
26 | }
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/.github/workflows/go.yml:
--------------------------------------------------------------------------------
1 | # This workflow will build a golang project
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go
3 |
4 | name: Go
5 |
6 | on:
7 | push:
8 | branches: [ "master" ]
9 | pull_request:
10 | branches: [ "master" ]
11 |
12 | jobs:
13 |
14 | build:
15 | runs-on: ubuntu-latest
16 | strategy:
17 | matrix:
18 | go-version: [ '1.22', '1.23' ]
19 | steps:
20 | - uses: actions/checkout@v4
21 | - name: Setup Go ${{ matrix.go-version }}
22 | uses: actions/setup-go@v4
23 | with:
24 | go-version: ${{ matrix.go-version }}
25 | - name: Display Go version
26 | run: go version
27 | - name: Build
28 | run: go mod init github.com/lqs/sqlingo && go mod tidy && go build -v .
29 | - name: Test
30 | run: go test -v .
31 |
--------------------------------------------------------------------------------
/case_test.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import (
4 | "errors"
5 | "testing"
6 | )
7 |
8 | func TestCase(t *testing.T) {
9 | c1 := expression{
10 | builder: func(scope scope) (string, error) {
11 | return "c1", nil
12 | },
13 | }
14 | c2 := expression{
15 | builder: func(scope scope) (string, error) {
16 | return "c2", nil
17 | },
18 | }
19 | assertValue(t, Case().WhenThen(c1, 1).WhenThen(c2, 2),
20 | "CASE WHEN c1 THEN 1 WHEN c2 THEN 2 END")
21 | assertValue(t, Case().WhenThen(c1, 1).WhenThen(c2, 2).Else(0),
22 | "CASE WHEN c1 THEN 1 WHEN c2 THEN 2 ELSE 0 END")
23 | assertValue(t, Case().Else(0),
24 | "0")
25 |
26 | ee := expression{
27 | builder: func(scope scope) (string, error) {
28 | return "", errors.New("error")
29 | },
30 | }
31 | assertError(t, Case().WhenThen(ee, 2))
32 | assertError(t, Case().WhenThen(c1, ee))
33 | assertError(t, Case().WhenThen(c1, 1).Else(ee))
34 | }
35 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Liu Qishuai
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/utils_test.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import "testing"
4 |
5 | var dummyMySQLScope = scope{Database: &database{dialect: dialectMySQL}}
6 |
7 | func assertEqual(t *testing.T, actualValue string, expectedValue string) {
8 | t.Helper()
9 | if actualValue != expectedValue {
10 | t.Errorf("actual [%s] expected [%s]", actualValue, expectedValue)
11 | }
12 | }
13 |
14 | func assertValue(t *testing.T, value interface{}, expectedSql string) {
15 | t.Helper()
16 | if generatedSql, _, _ := getSQL(dummyMySQLScope, value); generatedSql != expectedSql {
17 | t.Errorf("value [%v] generated [%s] expected [%s]", value, generatedSql, expectedSql)
18 | }
19 | }
20 |
21 | func assertLastSql(t *testing.T, expectedSql string) {
22 | t.Helper()
23 | if sharedMockConn.lastSql != expectedSql {
24 | t.Errorf("last sql [%s] expected [%s]", sharedMockConn.lastSql, expectedSql)
25 | }
26 | sharedMockConn.lastSql = ""
27 | }
28 |
29 | func assertError(t *testing.T, value interface{}) {
30 | t.Helper()
31 | if generatedSql, _, err := getSQL(dummyMySQLScope, value); err == nil {
32 | t.Errorf("value [%v] generated [%s] expected error", value, generatedSql)
33 | }
34 | }
35 |
--------------------------------------------------------------------------------
/geometry.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import "fmt"
4 |
5 | // WellKnownBinaryField is the interface of a generated field of binary geometry (WKB) type.
6 | type WellKnownBinaryField interface {
7 | WellKnownBinaryExpression
8 | GetTable() Table
9 | }
10 |
11 | // WellKnownBinaryExpression is the interface of an SQL expression with binary geometry (WKB) value.
12 | type WellKnownBinaryExpression interface {
13 | Expression
14 | STAsText() StringExpression
15 | }
16 |
17 | // WellKnownBinary is the type of geometry well-known binary (WKB) field.
18 | type WellKnownBinary []byte
19 |
20 | // NewWellKnownBinaryField creates a reference to a geometry WKB field. It should only be called from generated code.
21 | func NewWellKnownBinaryField(table Table, fieldName string) WellKnownBinaryField {
22 | return newField(table, fieldName)
23 | }
24 |
25 | func (e expression) STAsText() StringExpression {
26 | return function("ST_AsText", e)
27 | }
28 |
29 | func STGeomFromText(text interface{}) WellKnownBinaryExpression {
30 | return function("ST_GeomFromText", text)
31 | }
32 |
33 | func STGeomFromTextf(format string, a ...interface{}) WellKnownBinaryExpression {
34 | text := fmt.Sprintf(format, a...)
35 | return STGeomFromText(text)
36 | }
37 |
--------------------------------------------------------------------------------
/delete_test.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "testing"
7 | )
8 |
9 | func TestDelete(t *testing.T) {
10 | errorExpression := expression{
11 | builder: func(scope scope) (string, error) {
12 | return "", errors.New("error")
13 | },
14 | }
15 | db := newMockDatabase()
16 | if _, err := db.DeleteFrom(Table1).Where(staticExpression("##", 1, false)).Execute(); err != nil {
17 | t.Error(err)
18 | }
19 | assertLastSql(t, "DELETE FROM `table1` WHERE ##")
20 |
21 | if _, err := db.DeleteFrom(Table1).Where(errorExpression).Execute(); err == nil {
22 | t.Error("should get error here")
23 | }
24 |
25 | if _, err := db.DeleteFrom(Table1).Where(Raw("#1#")).Limit(3).Execute(); err != nil {
26 | t.Error(err)
27 | }
28 | assertLastSql(t, "DELETE FROM `table1` WHERE #1# LIMIT 3")
29 |
30 | if _, err := db.DeleteFrom(Table1).Where(Raw("#1#")).OrderBy(Raw("#2#")).Limit(3).Execute(); err != nil {
31 | t.Error(err)
32 | }
33 | assertLastSql(t, "DELETE FROM `table1` WHERE #1# ORDER BY #2# LIMIT 3")
34 |
35 | if _, err := db.DeleteFrom(Table1).Where(Raw("#1#")).WithContext(context.Background()).Execute(); err != nil {
36 | t.Error(err)
37 | }
38 | assertLastSql(t, "DELETE FROM `table1` WHERE #1#")
39 | }
40 |
--------------------------------------------------------------------------------
/interceptor.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import (
4 | "context"
5 | )
6 |
// InvokerFunc performs the actual work for a statement. Interceptors receive
// one and call it to let execution continue.
type InvokerFunc = func(ctx context.Context, sql string) error

// InterceptorFunc wraps a statement invocation. An interceptor may inspect or
// replace ctx and sql before calling invoker, and act on its result
// afterwards; not calling invoker stops the rest of the chain from running.
type InterceptorFunc = func(ctx context.Context, sql string, invoker InvokerFunc) error

// noopInterceptor forwards the call to invoker unchanged.
func noopInterceptor(ctx context.Context, sql string, invoker InvokerFunc) error {
	return invoker(ctx, sql)
}

// ChainInterceptors combines interceptors into a single interceptor. They run
// in the order given: interceptors[0] is outermost, and the last one calls the
// final invoker. With no interceptors the call is passed straight through.
func ChainInterceptors(interceptors ...InterceptorFunc) InterceptorFunc {
	if len(interceptors) == 0 {
		return noopInterceptor
	}
	return func(ctx context.Context, sql string, invoker InvokerFunc) error {
		// Build the call chain from the inside out. The final invoker is
		// wrapped by the last interceptor, that wrapper by the one before it,
		// and so on up to the first interceptor.
		next := invoker
		for i := len(interceptors) - 1; i >= 0; i-- {
			current, inner := interceptors[i], next
			next = func(ctx context.Context, sql string) error {
				return current(ctx, sql, inner)
			}
		}
		return next(ctx, sql)
	}
}
35 |
--------------------------------------------------------------------------------
/table.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | // Table is the interface of a generated table.
4 | type Table interface {
5 | GetName() string
6 | GetSQL(scope scope) string
7 | GetFields() []Field
8 | }
9 |
10 | type actualTable interface {
11 | Table
12 | GetFieldsSQL() string
13 | GetFullFieldsSQL() string
14 | }
15 |
16 | type table struct {
17 | Table
18 | name string
19 | sqlDialects dialectArray
20 | }
21 |
22 | func (t table) GetName() string {
23 | return t.name
24 | }
25 |
26 | func (t table) GetSQL(scope scope) string {
27 | return t.sqlDialects[scope.Database.dialect]
28 | }
29 |
30 | func (t table) getOperatorPriority() int {
31 | return 0
32 | }
33 |
34 | // NewTable creates a reference to a table. It should only be called from generated code.
35 | func NewTable(name string) Table {
36 | return table{name: name, sqlDialects: quoteIdentifier(name)}
37 | }
38 |
39 | type derivedTable struct {
40 | name string
41 | selectStatus selectStatus
42 | }
43 |
44 | func (t derivedTable) GetName() string {
45 | return t.name
46 | }
47 |
48 | func (t derivedTable) GetSQL(scope scope) string {
49 | sql, _ := t.selectStatus.GetSQL()
50 | return "(" + sql + ") AS " + t.name
51 | }
52 |
53 | func (t derivedTable) GetFields() []Field {
54 | return activeSelectBase(&t.selectStatus).fields
55 | }
56 |
--------------------------------------------------------------------------------
/interceptor_test.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import (
4 | "context"
5 | "testing"
6 | )
7 |
8 | func TestChainInterceptors(t *testing.T) {
9 | s := ""
10 | i1 := func(ctx context.Context, sql string, invoker InvokerFunc) error {
11 | s += ""
12 | s += sql
13 | defer func() {
14 | s += ""
15 | }()
16 | return invoker(ctx, sql+"s1")
17 | }
18 | i2 := func(ctx context.Context, sql string, invoker InvokerFunc) error {
19 | s += ""
20 | s += sql
21 | defer func() {
22 | s += ""
23 | }()
24 | return invoker(ctx, sql+"s2")
25 | }
26 | chain := ChainInterceptors(i1, i2)
27 | _ = chain(context.Background(), "sql", func(ctx context.Context, sql string) error {
28 | s += ""
29 | s += sql
30 | defer func() {
31 | s += ""
32 | }()
33 | return nil
34 | })
35 | if s != "sqlsqls1sqls1s2" {
36 | t.Error(s)
37 | }
38 | }
39 |
40 | func TestEmptyChainInterceptors(t *testing.T) {
41 | s := ""
42 | chain := ChainInterceptors()
43 | _ = chain(context.Background(), "sql", func(ctx context.Context, sql string) error {
44 | s += ""
45 | defer func() {
46 | s += ""
47 | }()
48 | s += sql
49 | return nil
50 | })
51 |
52 | if s != "sql" {
53 | t.Error(s)
54 | }
55 | }
56 |
--------------------------------------------------------------------------------
/function.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | func function(name string, args ...interface{}) expression {
4 | return expression{builder: func(scope scope) (string, error) {
5 | valuesSql, err := commaValues(scope, args)
6 | if err != nil {
7 | return "", err
8 | }
9 | return name + "(" + valuesSql + ")", nil
10 | }}
11 | }
12 |
13 | // Function creates an expression of the call to specified function.
14 | func Function(name string, args ...interface{}) UnknownExpression {
15 | return function(name, args...)
16 | }
17 |
18 | // Concat creates an expression of CONCAT function.
19 | func Concat(args ...interface{}) StringExpression {
20 | return function("CONCAT", args...)
21 | }
22 |
23 | // Count creates an expression of COUNT aggregator.
24 | func Count(arg interface{}) NumberExpression {
25 | return function("COUNT", arg)
26 | }
27 |
28 | // If creates an expression of IF function.
29 | func If(predicate Expression, trueValue interface{}, falseValue interface{}) (result UnknownExpression) {
30 | return function("IF", predicate, trueValue, falseValue)
31 | }
32 |
33 | // Length creates an expression of LENGTH function.
34 | func Length(arg interface{}) NumberExpression {
35 | return function("LENGTH", arg)
36 | }
37 |
38 | // Sum creates an expression of SUM aggregator.
39 | func Sum(arg interface{}) NumberExpression {
40 | return function("SUM", arg)
41 | }
42 |
--------------------------------------------------------------------------------
/transaction.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import (
4 | "context"
5 | "database/sql"
6 | )
7 |
8 | // Transaction is the interface of a transaction with underlying sql.Tx object.
9 | // It provides methods to execute DDL and TCL operations.
10 | type Transaction interface {
11 | GetDB() *sql.DB
12 | GetTx() *sql.Tx
13 | Query(sql string) (Cursor, error)
14 | Execute(sql string) (sql.Result, error)
15 |
16 | Select(fields ...interface{}) selectWithFields
17 | SelectDistinct(fields ...interface{}) selectWithFields
18 | SelectFrom(tables ...Table) selectWithTables
19 | InsertInto(table Table) insertWithTable
20 | Update(table Table) updateWithSet
21 | DeleteFrom(table Table) deleteWithTable
22 | }
23 |
24 | func (d *database) GetTx() *sql.Tx {
25 | return d.tx
26 | }
27 |
28 | func (d *database) BeginTx(ctx context.Context, opts *sql.TxOptions, f func(tx Transaction) error) error {
29 | if ctx == nil {
30 | ctx = context.Background()
31 | }
32 | tx, err := d.db.BeginTx(ctx, opts)
33 | if err != nil {
34 | return err
35 | }
36 | isCommitted := false
37 | defer func() {
38 | if !isCommitted {
39 | _ = tx.Rollback()
40 | }
41 | }()
42 |
43 | if f != nil {
44 | db := *d
45 | db.tx = tx
46 | err = f(&db)
47 | if err != nil {
48 | return err
49 | }
50 | }
51 |
52 | err = tx.Commit()
53 | if err != nil {
54 | return err
55 | }
56 | isCommitted = true
57 | return nil
58 | }
59 |
--------------------------------------------------------------------------------
/generator/fetcher_sqlite3.go:
--------------------------------------------------------------------------------
1 | package generator
2 |
3 | import "database/sql"
4 |
5 | type sqlite3SchemaFetcher struct {
6 | db *sql.DB
7 | }
8 |
9 | func (s sqlite3SchemaFetcher) GetDatabaseName() (dbName string, err error) {
10 | dbName = "main"
11 | return
12 | }
13 |
14 | func (s sqlite3SchemaFetcher) GetTableNames() (tableNames []string, err error) {
15 | rows, err := s.db.Query("SELECT `name` FROM `sqlite_master` WHERE `type` ='table' AND `name` NOT LIKE 'sqlite_%'")
16 | if err != nil {
17 | return
18 | }
19 | defer rows.Close()
20 | for rows.Next() {
21 | var name string
22 | if err = rows.Scan(&name); err != nil {
23 | return
24 | }
25 | tableNames = append(tableNames, name)
26 | }
27 | return
28 | }
29 |
30 | func (s sqlite3SchemaFetcher) GetFieldDescriptors(tableName string) (result []fieldDescriptor, err error) {
31 | rows, err := s.db.Query("SELECT `name`, `type`, `notnull` FROM pragma_table_info('" + tableName + "')")
32 | if err != nil {
33 | return
34 | }
35 | defer rows.Close()
36 | for rows.Next() {
37 | var fieldDescriptor fieldDescriptor
38 | var notNull int
39 | if err = rows.Scan(&fieldDescriptor.Name, &fieldDescriptor.Type, ¬Null); err != nil {
40 | return
41 | }
42 | fieldDescriptor.AllowNull = notNull == 0
43 | result = append(result, fieldDescriptor)
44 | }
45 | return
46 | }
47 |
48 | func (s sqlite3SchemaFetcher) QuoteIdentifier(identifier string) string {
49 | return "\"" + identifier + "\""
50 | }
51 |
52 | func newSQLite3SchemaFetcher(db *sql.DB) schemaFetcher {
53 | return sqlite3SchemaFetcher{db: db}
54 | }
55 |
--------------------------------------------------------------------------------
/field_test.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import (
4 | "errors"
5 | "testing"
6 | )
7 |
8 | type dummyTable struct {
9 | }
10 |
11 | func (d dummyTable) GetName() string {
12 | panic("should not be here")
13 | }
14 |
15 | func (d dummyTable) GetSQL(scope scope) string {
16 | panic("should not be here")
17 | }
18 |
19 | func (d dummyTable) GetFieldByName(name string) Field {
20 | panic("should not be here")
21 | }
22 |
23 | func (d dummyTable) GetFields() []Field {
24 | panic("implement me")
25 | }
26 |
27 | func (d dummyTable) GetFieldsSQL() string {
28 | return ""
29 | }
30 |
31 | func (d dummyTable) GetFullFieldsSQL() string {
32 | return ""
33 | }
34 |
35 | func TestField(t *testing.T) {
36 | t1 := NewTable("t1")
37 | assertValue(t, NewNumberField(t1, "f1").Equals(1), "`t1`.`f1` = 1")
38 | assertValue(t, NewBooleanField(t1, "f1").Equals(true), "`t1`.`f1` = 1")
39 | assertValue(t, NewStringField(t1, "f1").Equals("x"), "`t1`.`f1` = 'x'")
40 |
41 | sql, _ := fieldList{}.GetSQL(scope{
42 | Tables: []Table{
43 | &dummyTable{},
44 | },
45 | })
46 | if sql != "" {
47 | t.Error(sql)
48 | }
49 |
50 | sql, _ = fieldList{}.GetSQL(scope{
51 | Tables: []Table{
52 | &dummyTable{},
53 | &dummyTable{},
54 | },
55 | })
56 | if sql != ", " {
57 | t.Error(sql)
58 | }
59 |
60 | if _, err := (fieldList{
61 | expression{builder: func(scope scope) (string, error) {
62 | return "", errors.New("error")
63 | }},
64 | }.GetSQL(scope{
65 | Tables: []Table{
66 | &dummyTable{},
67 | &dummyTable{},
68 | },
69 | })); err == nil {
70 | t.Error("should get error here")
71 | }
72 | }
73 |
--------------------------------------------------------------------------------
/generator/fetcher_postgres.go:
--------------------------------------------------------------------------------
1 | package generator
2 |
3 | import "database/sql"
4 |
5 | type postgresSchemaFetcher struct {
6 | db *sql.DB
7 | }
8 |
9 | func (p postgresSchemaFetcher) GetDatabaseName() (dbName string, err error) {
10 | row := p.db.QueryRow("SELECT current_database()")
11 | err = row.Scan(&dbName)
12 | return
13 | }
14 |
15 | func (p postgresSchemaFetcher) GetTableNames() (tableNames []string, err error) {
16 | rows, err := p.db.Query("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'")
17 | if err != nil {
18 | return
19 | }
20 | defer rows.Close()
21 | for rows.Next() {
22 | var name string
23 | if err = rows.Scan(&name); err != nil {
24 | return
25 | }
26 | tableNames = append(tableNames, name)
27 | }
28 | return
29 | }
30 |
31 | func (p postgresSchemaFetcher) GetFieldDescriptors(tableName string) (result []fieldDescriptor, err error) {
32 | rows, err := p.db.Query("SELECT column_name, is_nullable, data_type FROM information_schema.columns WHERE table_schema = 'public' AND table_name = $1", tableName)
33 | if err != nil {
34 | return
35 | }
36 | defer rows.Close()
37 | for rows.Next() {
38 | var fieldDescriptor fieldDescriptor
39 | var isNullable string
40 | if err = rows.Scan(&fieldDescriptor.Name, &isNullable, &fieldDescriptor.Type); err != nil {
41 | return
42 | }
43 | fieldDescriptor.AllowNull = isNullable == "YES"
44 | result = append(result, fieldDescriptor)
45 | }
46 | return
47 | }
48 |
49 | func (p postgresSchemaFetcher) QuoteIdentifier(identifier string) string {
50 | return "\"" + identifier + "\""
51 | }
52 |
53 | func newPostgresSchemaFetcher(db *sql.DB) schemaFetcher {
54 | return postgresSchemaFetcher{db: db}
55 | }
56 |
--------------------------------------------------------------------------------
/generator/args.go:
--------------------------------------------------------------------------------
1 | package generator
2 |
3 | import (
4 | "fmt"
5 | "os"
6 | "strings"
7 | )
8 |
// options holds the settings parsed from the generator's command line.
type options struct {
	dataSourceName string   // the single required positional argument
	tableNames     []string // values of -t (comma-separated, may repeat across uses)
	forceCases     []string // values of -forcecases, e.g. "ID", "HTML"
}
14 |
// printUsageAndExit writes the command-line synopsis and an example
// invocation (using exampleDataSourceName) to stderr, then exits with
// status 1.
//
// Keep the synopsis in sync with the flags accepted by parseArgs; it
// previously omitted -timeAsString.
func printUsageAndExit(exampleDataSourceName string) {
	cmd := os.Args[0]
	// Best effort: the process exits right after, whether or not this write
	// succeeds.
	_, _ = fmt.Fprintf(os.Stderr,
		"Usage:\n\t%s [-t table1,table2,...] [-forcecases ID,IDs,HTML] [-timeAsString] dataSourceName\nExample:\n\t%s \"%s\"\n",
		cmd, cmd, exampleDataSourceName)
	os.Exit(1)
}
24 |
25 | func parseArgs(exampleDataSourceName string) (options options) {
26 | var args []string
27 | parseTable := false
28 | parseForceCases := false
29 | for _, arg := range os.Args[1:] {
30 | if arg != "" && arg[0] == '-' {
31 | switch arg[1:] {
32 | case "t":
33 | if parseTable {
34 | printUsageAndExit(exampleDataSourceName)
35 | }
36 | parseTable = true
37 | case "forcecases":
38 | if parseForceCases {
39 | printUsageAndExit(exampleDataSourceName)
40 | }
41 | parseForceCases = true
42 | case "timeAsString":
43 | timeAsString = true
44 | default:
45 | printUsageAndExit(exampleDataSourceName)
46 | }
47 | } else {
48 | if parseTable {
49 | options.tableNames = append(options.tableNames, strings.Split(arg, ",")...)
50 | parseTable = false
51 | } else if parseForceCases {
52 | options.forceCases = append(options.forceCases, strings.Split(arg, ",")...)
53 | parseForceCases = false
54 | } else {
55 | args = append(args, arg)
56 | }
57 | }
58 | }
59 | if parseTable || parseForceCases {
60 | // "-t" not closed
61 | printUsageAndExit(exampleDataSourceName)
62 | }
63 |
64 | switch len(args) {
65 | case 1:
66 | options.dataSourceName = args[0]
67 | default:
68 | printUsageAndExit(exampleDataSourceName)
69 | }
70 |
71 | return
72 | }
73 |
--------------------------------------------------------------------------------
/transaction_test.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "testing"
7 | )
8 |
// mockTx is a fake driver.Tx that records which of Commit and Rollback ran.
// When commitError is set, Commit returns it and does not mark the
// transaction as committed.
type mockTx struct {
	isCommitted, isRolledBack bool
	commitError               error
}

// Commit marks the transaction committed unless commitError is set, and
// returns commitError (nil on success).
func (m *mockTx) Commit() error {
	if m.commitError == nil {
		m.isCommitted = true
	}
	return m.commitError
}

// Rollback marks the transaction as rolled back; it never fails.
func (m *mockTx) Rollback() error {
	m.isRolledBack = true
	return nil
}
27 |
28 | func TestTransaction(t *testing.T) {
29 | db := newMockDatabase()
30 | err := db.BeginTx(nil, nil, func(tx Transaction) error {
31 | if tx.GetDB() != db.GetDB() {
32 | t.Error()
33 | }
34 | if tx.GetTx() == nil {
35 | t.Error()
36 | }
37 |
38 | _, err := tx.Execute("")
39 | if err != nil {
40 | t.Error(err)
41 | }
42 | return nil
43 | })
44 | if err != nil {
45 | t.Error(err)
46 | }
47 | if !sharedMockConn.mockTx.isCommitted {
48 | t.Error()
49 | }
50 | if sharedMockConn.mockTx.isRolledBack {
51 | t.Error()
52 | }
53 |
54 | err = db.BeginTx(context.Background(), nil, func(tx Transaction) error {
55 | return errors.New("error")
56 | })
57 | if err == nil {
58 | t.Error("should get error here")
59 | }
60 | if sharedMockConn.mockTx.isCommitted {
61 | t.Error()
62 | }
63 | if !sharedMockConn.mockTx.isRolledBack {
64 | t.Error()
65 | }
66 |
67 | sharedMockConn.beginTxError = errors.New("error")
68 | err = db.BeginTx(context.Background(), nil, func(tx Transaction) error {
69 | return nil
70 | })
71 | if err == nil {
72 | t.Error("should get error here")
73 | }
74 | sharedMockConn.beginTxError = nil
75 |
76 | err = db.BeginTx(context.Background(), nil, func(tx Transaction) error {
77 | sharedMockConn.mockTx.commitError = errors.New("error")
78 | return nil
79 | })
80 | if err == nil {
81 | t.Error("should get error here")
82 | }
83 | }
84 |
--------------------------------------------------------------------------------
/update_test.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "testing"
7 | )
8 |
9 | func TestUpdate(t *testing.T) {
10 | db := newMockDatabase()
11 |
12 | _, _ = db.Update(Table1).Set(field1, field2).Where(True()).Execute()
13 | assertLastSql(t, "UPDATE `table1` SET `field1` = `field2`")
14 |
15 | _, _ = db.Update(Table1).
16 | Set(field1, 10).
17 | Where(field2.Equals(2)).
18 | OrderBy(field1.Desc()).
19 | Limit(2).
20 | Execute()
21 | assertLastSql(t, "UPDATE `table1` SET `field1` = 10 WHERE `field2` = 2 ORDER BY `field1` DESC LIMIT 2")
22 |
23 | _, _ = db.Update(Table1).
24 | SetIf(true, field1, 10).
25 | SetIf(false, field2, 10).
26 | Where(True()).
27 | Execute()
28 | assertLastSql(t, "UPDATE `table1` SET `field1` = 10")
29 |
30 | _, _ = db.Update(Table1).
31 | SetIf(false, field1, 10).
32 | Where(True()).
33 | Execute()
34 | assertLastSql(t, "/* UPDATE without SET clause */ DO 0")
35 |
36 | _, _ = db.Update(Table1).Limit(3).Execute()
37 | assertLastSql(t, "/* UPDATE without SET clause */ DO 0")
38 |
39 | errExp := &expression{
40 | builder: func(scope scope) (string, error) {
41 | return "", errors.New("error")
42 | },
43 | }
44 |
45 | if _, err := db.Update(Table1).
46 | Set(field1, 10).
47 | OrderBy(orderBy{by: errExp}).
48 | Execute(); err == nil {
49 | t.Error("should get error here")
50 | }
51 |
52 | if _, err := db.Update(Table1).
53 | Set(field1, errExp).
54 | Where(True()).
55 | Execute(); err == nil {
56 | t.Error("should get error here")
57 | }
58 |
59 | if _, err := db.Update(Table1).
60 | Set(field1, 10).
61 | Where(errExp).
62 | Execute(); err == nil {
63 | t.Error("should get error here")
64 | }
65 |
66 | if _, err := db.Update(Table1).
67 | Set(field1, 10).
68 | Where(True()).
69 | WithContext(context.Background()).
70 | Execute(); err != nil {
71 | t.Error(err)
72 | }
73 | }
74 |
--------------------------------------------------------------------------------
/case.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import "strings"
4 |
5 | // CaseExpression indicates the status in a CASE statement
6 | type CaseExpression interface {
7 | WhenThen(when BooleanExpression, then interface{}) CaseExpression
8 | Else(value interface{}) CaseExpressionWithElse
9 | End() Expression
10 | }
11 |
12 | // CaseExpressionWithElse indicates the status in CASE ... ELSE ... statement
13 | type CaseExpressionWithElse interface {
14 | End() Expression
15 | }
16 |
17 | type caseStatus struct {
18 | head, tail *whenThen
19 | elseValue interface{}
20 | }
21 |
22 | type whenThen struct {
23 | next *whenThen
24 | when BooleanExpression
25 | then interface{}
26 | }
27 |
28 | // Case initiates a CASE statement
29 | func Case() CaseExpression {
30 | return caseStatus{}
31 | }
32 |
33 | func (s caseStatus) WhenThen(when BooleanExpression, then interface{}) CaseExpression {
34 | whenThen := &whenThen{when: when, then: then}
35 | if s.head == nil {
36 | s.head = whenThen
37 | }
38 | if s.tail != nil {
39 | s.tail.next = whenThen
40 | }
41 | s.tail = whenThen
42 | return s
43 | }
44 |
45 | func (s caseStatus) Else(value interface{}) CaseExpressionWithElse {
46 | s.elseValue = value
47 | return s
48 | }
49 |
50 | func (s caseStatus) End() Expression {
51 | if s.head == nil {
52 | return expression{
53 | builder: func(scope scope) (string, error) {
54 | elseSql, _, err := getSQL(scope, s.elseValue)
55 | return elseSql, err
56 | },
57 | }
58 | }
59 |
60 | return expression{
61 | builder: func(scope scope) (string, error) {
62 | sb := strings.Builder{}
63 | sb.WriteString("CASE ")
64 |
65 | for whenThen := s.head; whenThen != nil; whenThen = whenThen.next {
66 | whenSql, err := whenThen.when.GetSQL(scope)
67 | if err != nil {
68 | return "", err
69 | }
70 | thenSql, _, err := getSQL(scope, whenThen.then)
71 | if err != nil {
72 | return "", err
73 | }
74 | sb.WriteString("WHEN " + whenSql + " THEN " + thenSql + " ")
75 | }
76 | if s.elseValue != nil {
77 | elseSql, _, err := getSQL(scope, s.elseValue)
78 | if err != nil {
79 | return "", err
80 | }
81 | sb.WriteString("ELSE " + elseSql + " ")
82 | }
83 | sb.WriteString("END")
84 | return sb.String(), nil
85 | },
86 | }
87 | }
88 |
--------------------------------------------------------------------------------
/common_test.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import (
4 | "errors"
5 | "strings"
6 | "testing"
7 | )
8 |
9 | func TestCommon(t *testing.T) {
10 | db := newMockDatabase()
11 |
12 | dummyExp1 := expression{sql: ""}
13 | dummyExp2 := expression{sql: ""}
14 | errExp := expression{builder: func(scope scope) (string, error) {
15 | return "", errors.New("error")
16 | }}
17 | assertValue(t, &assignment{
18 | field: dummyExp1,
19 | value: dummyExp2,
20 | }, " = ")
21 | assertError(t, &assignment{
22 | field: errExp,
23 | value: dummyExp2,
24 | })
25 | assertError(t, &assignment{
26 | field: dummyExp1,
27 | value: errExp,
28 | })
29 | assertError(t, &assignment{
30 | field: errExp,
31 | value: errExp,
32 | })
33 | assertError(t, command("COMMAND", errExp))
34 |
35 | sql, err := commaExpressions(scope{}, []Expression{dummyExp1, dummyExp2, dummyExp1})
36 | if err != nil {
37 | t.Error(err)
38 | }
39 | if sql != ", , " {
40 | t.Error()
41 | }
42 |
43 | _, err = commaExpressions(scope{}, []Expression{dummyExp1, dummyExp2, errExp})
44 | if err == nil {
45 | t.Error("should get error")
46 | }
47 |
48 | sql, err = commaAssignments(scope{}, []assignment{
49 | {field: dummyExp1, value: dummyExp1},
50 | {field: dummyExp1, value: dummyExp2},
51 | {field: dummyExp2, value: dummyExp2},
52 | })
53 | if err != nil {
54 | t.Error(err)
55 | }
56 | if sql != " = , = , = " {
57 | t.Error()
58 | }
59 | _, err = commaAssignments(scope{}, []assignment{
60 | {field: dummyExp1, value: dummyExp1},
61 | {field: dummyExp1, value: errExp},
62 | })
63 | if err == nil {
64 | t.Error("should get error")
65 | }
66 |
67 | sql, err = commaOrderBys(scope{}, []OrderBy{
68 | orderBy{by: dummyExp1, desc: true},
69 | orderBy{by: dummyExp2},
70 | })
71 | if err != nil {
72 | t.Error(err)
73 | }
74 | if sql != " DESC, " {
75 | t.Error()
76 | }
77 |
78 | _, err = commaOrderBys(scope{}, []OrderBy{
79 | orderBy{by: errExp},
80 | })
81 | if err == nil {
82 | t.Error("should get error")
83 | }
84 |
85 | db.EnableCallerInfo(true)
86 | if _, err := db.Select(1).FetchFirst(); err != nil {
87 | t.Error(err)
88 | }
89 | if !strings.HasPrefix(sharedMockConn.lastSql, "/* ") {
90 | t.Error()
91 | }
92 | }
93 |
--------------------------------------------------------------------------------
/value.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import (
4 | "math"
5 | "strconv"
6 | )
7 |
8 | const (
9 | maxInt = 1<<(strconv.IntSize-1) - 1
10 | minInt = -1 << (strconv.IntSize - 1)
11 | maxUint = 1<= minInt && r <= maxInt {
42 | return int(r)
43 | }
44 | return 0
45 | }
46 |
47 | func (v value) Int8() int8 {
48 | if r := v.Int64(); r >= math.MinInt8 && r <= math.MaxInt8 {
49 | return int8(r)
50 | }
51 | return 0
52 | }
53 |
54 | func (v value) Int16() int16 {
55 | if r := v.Int64(); r >= math.MinInt16 && r <= math.MaxInt16 {
56 | return int16(r)
57 | }
58 | return 0
59 | }
60 |
61 | func (v value) Int32() int32 {
62 | if r := v.Int64(); r >= math.MinInt32 && r <= math.MaxInt32 {
63 | return int32(r)
64 | }
65 | return 0
66 | }
67 |
68 | func (v value) Uint() uint {
69 | if r := v.Uint64(); r <= maxUint {
70 | return uint(r)
71 | }
72 | return 0
73 | }
74 |
75 | func (v value) Uint8() uint8 {
76 | if r := v.Uint64(); r <= math.MaxUint8 {
77 | return uint8(r)
78 | }
79 | return 0
80 | }
81 |
82 | func (v value) Uint16() uint16 {
83 | if r := v.Uint64(); r <= math.MaxUint16 {
84 | return uint16(r)
85 | }
86 | return 0
87 | }
88 |
89 | func (v value) Uint32() uint32 {
90 | if r := v.Uint64(); r <= math.MaxUint32 {
91 | return uint32(r)
92 | }
93 | return 0
94 | }
95 |
96 | func (v value) Bool() bool {
97 | if v.stringValue == nil {
98 | return false
99 | }
100 | switch *v.stringValue {
101 | case "", "0", "\x00":
102 | return false
103 | default:
104 | return true
105 | }
106 | }
107 |
108 | func (v value) String() string {
109 | if v.stringValue == nil {
110 | return ""
111 | }
112 | return *v.stringValue
113 | }
114 |
115 | func (v value) IsNull() bool {
116 | return v.stringValue == nil
117 | }
118 |
--------------------------------------------------------------------------------
/generator/fetcher_mysql.go:
--------------------------------------------------------------------------------
1 | package generator
2 |
3 | import (
4 | "database/sql"
5 | "regexp"
6 | "strconv"
7 | )
8 |
// timeAsString is a package-wide switch. parseArgs sets it to true when the
// -timeAsString flag is given; otherwise it keeps its zero value, false.
// NOTE(review): presumably it makes date/time columns generate as strings
// instead of time values — confirm where the generator reads it.
var timeAsString bool
10 |
// mysqlSchemaFetcher reads table and column metadata from a MySQL server
// using the wrapped handle.
type mysqlSchemaFetcher struct {
	db *sql.DB // every metadata query is issued on this handle
}
14 |
15 | func (m mysqlSchemaFetcher) GetDatabaseName() (dbName string, err error) {
16 | row := m.db.QueryRow("SELECT DATABASE()")
17 | err = row.Scan(&dbName)
18 | return
19 | }
20 |
21 | func (m mysqlSchemaFetcher) GetTableNames() (tableNames []string, err error) {
22 | rows, err := m.db.Query("SHOW TABLES")
23 | if err != nil {
24 | return
25 | }
26 | defer rows.Close()
27 |
28 | for rows.Next() {
29 | var name string
30 | err = rows.Scan(&name)
31 | if err != nil {
32 | return
33 | }
34 | tableNames = append(tableNames, name)
35 | }
36 | return
37 | }
38 |
39 | func (m mysqlSchemaFetcher) GetFieldDescriptors(tableName string) ([]fieldDescriptor, error) {
40 | rows, err := m.db.Query("SHOW FULL COLUMNS FROM `" + tableName + "`")
41 | if err != nil {
42 | return nil, err
43 | }
44 |
45 | var result []fieldDescriptor
46 | for rows.Next() {
47 | columns, err := rows.Columns()
48 | if err != nil {
49 | return nil, err
50 | }
51 | var pointers []interface{}
52 | for i := 0; i < len(columns); i++ {
53 | var value *string
54 | pointers = append(pointers, &value)
55 | }
56 | err = rows.Scan(pointers...)
57 | if err != nil {
58 | return nil, err
59 | }
60 | row := make(map[string]string)
61 | for i, column := range columns {
62 | pointer := *pointers[i].(**string)
63 | if pointer != nil {
64 | row[column] = *pointer
65 | }
66 | }
67 |
68 | r, _ := regexp.Compile("([a-z]+)(\\(([0-9]+)\\))?( ([a-z]+))?")
69 | submatches := r.FindStringSubmatch(row["Type"])
70 |
71 | fieldType := submatches[1]
72 | fieldSize := 0
73 | if submatches[3] != "" {
74 | fieldSize, err = strconv.Atoi(submatches[3])
75 | if err != nil {
76 | return nil, err
77 | }
78 | }
79 | unsigned := submatches[5] == "unsigned"
80 |
81 | result = append(result, fieldDescriptor{
82 | Name: row["Field"],
83 | Type: fieldType,
84 | Size: fieldSize,
85 | Unsigned: unsigned,
86 | AllowNull: row["Null"] == "YES",
87 | Comment: row["Comment"],
88 | })
89 | }
90 | return result, nil
91 | }
92 |
93 | func (m mysqlSchemaFetcher) QuoteIdentifier(identifier string) string {
94 | return "`" + identifier + "`"
95 | }
96 |
97 | func newMySQLSchemaFetcher(db *sql.DB) schemaFetcher {
98 | return mysqlSchemaFetcher{db: db}
99 | }
100 |
--------------------------------------------------------------------------------
/delete.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import (
4 | "context"
5 | "database/sql"
6 | "strconv"
7 | "strings"
8 | )
9 |
10 | type deleteStatus struct {
11 | scope scope
12 | where BooleanExpression
13 | orderBys []OrderBy
14 | limit *int
15 | ctx context.Context
16 | }
17 |
18 | type deleteWithTable interface {
19 | Where(conditions ...BooleanExpression) deleteWithWhere
20 | }
21 |
22 | type deleteWithWhere interface {
23 | toDeleteWithContext
24 | toDeleteFinal
25 | OrderBy(orderBys ...OrderBy) deleteWithOrder
26 | Limit(limit int) deleteWithLimit
27 | }
28 |
29 | type deleteWithOrder interface {
30 | toDeleteWithContext
31 | toDeleteFinal
32 | Limit(limit int) deleteWithLimit
33 | }
34 |
35 | type deleteWithLimit interface {
36 | toDeleteWithContext
37 | toDeleteFinal
38 | }
39 |
40 | type toDeleteWithContext interface {
41 | WithContext(ctx context.Context) toDeleteFinal
42 | }
43 |
44 | type toDeleteFinal interface {
45 | GetSQL() (string, error)
46 | Execute() (result sql.Result, err error)
47 | }
48 |
49 | func (d *database) DeleteFrom(table Table) deleteWithTable {
50 | return deleteStatus{scope: scope{Database: d, Tables: []Table{table}}}
51 | }
52 |
53 | func (s deleteStatus) Where(conditions ...BooleanExpression) deleteWithWhere {
54 | s.where = And(conditions...)
55 | return s
56 | }
57 |
58 | func (s deleteStatus) OrderBy(orderBys ...OrderBy) deleteWithOrder {
59 | s.orderBys = orderBys
60 | return s
61 | }
62 |
63 | func (s deleteStatus) Limit(limit int) deleteWithLimit {
64 | s.limit = &limit
65 | return s
66 | }
67 |
68 | func (s deleteStatus) GetSQL() (string, error) {
69 | var sb strings.Builder
70 | sb.Grow(128)
71 |
72 | sb.WriteString("DELETE FROM ")
73 | sb.WriteString(s.scope.Tables[0].GetSQL(s.scope))
74 |
75 | if err := appendWhere(&sb, s.scope, s.where); err != nil {
76 | return "", err
77 | }
78 |
79 | if len(s.orderBys) > 0 {
80 | orderBySql, err := commaOrderBys(s.scope, s.orderBys)
81 | if err != nil {
82 | return "", err
83 | }
84 | sb.WriteString(" ORDER BY ")
85 | sb.WriteString(orderBySql)
86 | }
87 |
88 | if s.limit != nil {
89 | sb.WriteString(" LIMIT ")
90 | sb.WriteString(strconv.Itoa(*s.limit))
91 | }
92 |
93 | return sb.String(), nil
94 | }
95 |
96 | func (s deleteStatus) WithContext(ctx context.Context) toDeleteFinal {
97 | s.ctx = ctx
98 | return s
99 | }
100 |
101 | func (s deleteStatus) Execute() (sql.Result, error) {
102 | sqlString, err := s.GetSQL()
103 | if err != nil {
104 | return nil, err
105 | }
106 | return s.scope.Database.ExecuteContext(s.ctx, sqlString)
107 | }
108 |
--------------------------------------------------------------------------------
/database_test.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import (
4 | "context"
5 | "database/sql"
6 | "database/sql/driver"
7 | "errors"
8 | "testing"
9 | "time"
10 | )
11 |
12 | func (m *mockConn) Prepare(query string) (driver.Stmt, error) {
13 | if m.prepareError != nil {
14 | return nil, m.prepareError
15 | }
16 | m.lastSql = query
17 | return &mockStmt{
18 | columnCount: m.columnCount,
19 | rowCount: m.rowCount,
20 | }, nil
21 | }
22 |
23 | func (m mockConn) Close() error {
24 | return nil
25 | }
26 |
27 | func (m *mockConn) Begin() (driver.Tx, error) {
28 | if m.beginTxError != nil {
29 | return nil, m.beginTxError
30 | }
31 | m.mockTx = &mockTx{}
32 | return m.mockTx, nil
33 | }
34 |
35 | var sharedMockConn = &mockConn{
36 | columnCount: 11,
37 | rowCount: 10,
38 | }
39 |
40 | func (m mockDriver) Open(name string) (driver.Conn, error) {
41 | return sharedMockConn, nil
42 | }
43 |
44 | func newMockDatabase() Database {
45 | db, err := Open("sqlingo-mock", "dummy")
46 | if err != nil {
47 | panic(err)
48 | }
49 | db.(*database).dialect = dialectMySQL
50 | return db
51 | }
52 |
53 | func init() {
54 | sql.Register("sqlingo-mock", &mockDriver{})
55 | }
56 |
57 | func TestDatabase(t *testing.T) {
58 | if _, err := Open("unknowndb", "unknown"); err == nil {
59 | t.Error()
60 | }
61 |
62 | db := newMockDatabase()
63 | if db.GetDB() == nil {
64 | t.Error()
65 | }
66 |
67 | interceptorExecutedCount := 0
68 | loggerExecutedCount := 0
69 | db.SetInterceptor(func(ctx context.Context, sql string, invoker InvokerFunc) error {
70 | if sql != "SELECT 1" {
71 | t.Error()
72 | }
73 | interceptorExecutedCount++
74 | return invoker(ctx, sql)
75 | })
76 | db.SetLogger(func(sql string, _ time.Duration, _, _ bool) {
77 | if sql != "SELECT 1" {
78 | t.Error(sql)
79 | }
80 | loggerExecutedCount++
81 | })
82 | _, _ = db.Query("SELECT 1")
83 | if interceptorExecutedCount != 1 || loggerExecutedCount != 1 {
84 | t.Error(interceptorExecutedCount, loggerExecutedCount)
85 | }
86 |
87 | _, _ = db.Execute("SELECT 1")
88 | if loggerExecutedCount != 2 {
89 | t.Error(loggerExecutedCount)
90 | }
91 |
92 | db.SetInterceptor(func(ctx context.Context, sql string, invoker InvokerFunc) error {
93 | return errors.New("error")
94 | })
95 | if _, err := db.Query("SELECT 1"); err == nil {
96 | t.Error("should get error here")
97 | }
98 | }
99 |
100 | func TestDatabaseRetry(t *testing.T) {
101 | db := newMockDatabase()
102 | retryCount := 0
103 | db.SetRetryPolicy(func(err error) bool {
104 | retryCount++
105 | return retryCount < 10
106 | })
107 |
108 | sharedMockConn.prepareError = errors.New("error")
109 | if _, err := db.Query("SELECT 1"); err == nil {
110 | t.Error("should get error here")
111 | }
112 | if retryCount != 10 {
113 | t.Error(retryCount)
114 | }
115 | sharedMockConn.prepareError = nil
116 | }
117 |
--------------------------------------------------------------------------------
/value_test.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import "testing"
4 |
5 | func newValue(s string) value {
6 | return value{stringValue: &s}
7 | }
8 |
9 | func TestValue(t *testing.T) {
10 | value := newValue("42")
11 | if value.String() != "42" {
12 | t.Error()
13 | }
14 | if !value.Bool() {
15 | t.Error()
16 | }
17 | if value.Int() != 42 {
18 | t.Error()
19 | }
20 | if value.Int8() != 42 {
21 | t.Error()
22 | }
23 | if value.Int16() != 42 {
24 | t.Error()
25 | }
26 | if value.Int32() != 42 {
27 | t.Error()
28 | }
29 | if value.Int64() != 42 {
30 | t.Error()
31 | }
32 |
33 | if value.Uint() != 42 {
34 | t.Error()
35 | }
36 | if value.Uint8() != 42 {
37 | t.Error()
38 | }
39 | if value.Uint16() != 42 {
40 | t.Error()
41 | }
42 | if value.Uint32() != 42 {
43 | t.Error()
44 | }
45 | if value.Uint64() != 42 {
46 | t.Error()
47 | }
48 | }
49 |
50 | func TestValueOverflow1(t *testing.T) {
51 | value := newValue("3000000000")
52 | if value.Int() != 3000000000 {
53 | t.Error()
54 | }
55 | if value.Int8() != 0 {
56 | t.Error()
57 | }
58 | if value.Int16() != 0 {
59 | t.Error()
60 | }
61 | if value.Int32() != 0 {
62 | t.Error()
63 | }
64 | if value.Int64() != 3000000000 {
65 | t.Error()
66 | }
67 |
68 | if value.Uint() != 3000000000 {
69 | t.Error()
70 | }
71 | if value.Uint8() != 0 {
72 | t.Error()
73 | }
74 | if value.Uint16() != 0 {
75 | t.Error()
76 | }
77 | if value.Uint32() != 3000000000 {
78 | t.Error()
79 | }
80 | if value.Uint64() != 3000000000 {
81 | t.Error()
82 | }
83 | }
84 |
85 | func TestValueOverflow2(t *testing.T) {
86 | value := newValue("3000000000000000")
87 | if value.Int() != 3000000000000000 {
88 | t.Error()
89 | }
90 | if value.Int8() != 0 {
91 | t.Error()
92 | }
93 | if value.Int16() != 0 {
94 | t.Error()
95 | }
96 | if value.Int32() != 0 {
97 | t.Error()
98 | }
99 | if value.Int64() != 3000000000000000 {
100 | t.Error()
101 | }
102 |
103 | if value.Uint() != 3000000000000000 {
104 | t.Error()
105 | }
106 | if value.Uint8() != 0 {
107 | t.Error()
108 | }
109 | if value.Uint16() != 0 {
110 | t.Error()
111 | }
112 | if value.Uint32() != 0 {
113 | t.Error()
114 | }
115 | if value.Uint64() != 3000000000000000 {
116 | t.Error()
117 | }
118 | }
119 |
120 | func TestValueOverflow3(t *testing.T) {
121 | value := newValue("10000000000000000000")
122 | if value.Int() != 0 {
123 | t.Error(value.Int())
124 | }
125 | if value.Uint() != 10000000000000000000 {
126 | t.Error()
127 | }
128 | }
129 |
130 | func TestValueBool(t *testing.T) {
131 | if !newValue("1").Bool() {
132 | t.Error()
133 | }
134 | if newValue("0").Bool() {
135 | t.Error()
136 | }
137 | if newValue("").Bool() {
138 | t.Error()
139 | }
140 | if (value{}).Bool() {
141 | t.Error()
142 | }
143 | }
144 |
145 | func TestValueNull(t *testing.T) {
146 | value := value{}
147 | if value.String() != "" {
148 | t.Error()
149 | }
150 | if value.Int() != 0 {
151 | t.Error()
152 | }
153 | if value.Uint() != 0 {
154 | t.Error()
155 | }
156 | }
157 |
--------------------------------------------------------------------------------
/update.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import (
4 | "context"
5 | "database/sql"
6 | "strconv"
7 | "strings"
8 | )
9 |
10 | type updateStatus struct {
11 | scope scope
12 | assignments []assignment
13 | where BooleanExpression
14 | orderBys []OrderBy
15 | limit *int
16 | ctx context.Context
17 | }
18 |
19 | func (d *database) Update(table Table) updateWithSet {
20 | return updateStatus{scope: scope{Database: d, Tables: []Table{table}}}
21 | }
22 |
23 | type updateWithSet interface {
24 | Set(Field Field, value interface{}) updateWithSet
25 | SetIf(prerequisite bool, Field Field, value interface{}) updateWithSet
26 | Where(conditions ...BooleanExpression) updateWithWhere
27 | OrderBy(orderBys ...OrderBy) updateWithOrder
28 | Limit(limit int) updateWithLimit
29 | }
30 |
31 | type updateWithWhere interface {
32 | toUpdateWithContext
33 | toUpdateFinal
34 | OrderBy(orderBys ...OrderBy) updateWithOrder
35 | Limit(limit int) updateWithLimit
36 | }
37 |
38 | type updateWithOrder interface {
39 | toUpdateWithContext
40 | toUpdateFinal
41 | Limit(limit int) updateWithLimit
42 | }
43 |
44 | type updateWithLimit interface {
45 | toUpdateWithContext
46 | toUpdateFinal
47 | }
48 |
49 | type toUpdateWithContext interface {
50 | WithContext(ctx context.Context) toUpdateFinal
51 | }
52 |
53 | type toUpdateFinal interface {
54 | GetSQL() (string, error)
55 | Execute() (sql.Result, error)
56 | }
57 |
58 | func (s updateStatus) Set(field Field, value interface{}) updateWithSet {
59 | s.assignments = append([]assignment{}, s.assignments...)
60 | s.assignments = append(s.assignments, assignment{
61 | field: field,
62 | value: value,
63 | })
64 | return s
65 | }
66 |
67 | func (s updateStatus) SetIf(prerequisite bool, field Field, value interface{}) updateWithSet {
68 | if prerequisite {
69 | return s.Set(field, value)
70 | }
71 | return s
72 | }
73 |
74 | func (s updateStatus) Where(conditions ...BooleanExpression) updateWithWhere {
75 | s.where = And(conditions...)
76 | return s
77 | }
78 |
79 | func (s updateStatus) OrderBy(orderBys ...OrderBy) updateWithOrder {
80 | s.orderBys = orderBys
81 | return s
82 | }
83 |
84 | func (s updateStatus) Limit(limit int) updateWithLimit {
85 | s.limit = &limit
86 | return s
87 | }
88 |
89 | func (s updateStatus) GetSQL() (string, error) {
90 | if len(s.assignments) == 0 {
91 | return "/* UPDATE without SET clause */ DO 0", nil
92 | }
93 | var sb strings.Builder
94 | sb.Grow(128)
95 |
96 | sb.WriteString("UPDATE ")
97 | sb.WriteString(s.scope.Tables[0].GetSQL(s.scope))
98 |
99 | assignmentsSql, err := commaAssignments(s.scope, s.assignments)
100 | if err != nil {
101 | return "", err
102 | }
103 | sb.WriteString(" SET ")
104 | sb.WriteString(assignmentsSql)
105 |
106 | if err := appendWhere(&sb, s.scope, s.where); err != nil {
107 | return "", err
108 | }
109 |
110 | if len(s.orderBys) > 0 {
111 | orderBySql, err := commaOrderBys(s.scope, s.orderBys)
112 | if err != nil {
113 | return "", err
114 | }
115 | sb.WriteString(" ORDER BY ")
116 | sb.WriteString(orderBySql)
117 | }
118 |
119 | if s.limit != nil {
120 | sb.WriteString(" LIMIT ")
121 | sb.WriteString(strconv.Itoa(*s.limit))
122 | }
123 |
124 | return sb.String(), nil
125 | }
126 |
127 | func (s updateStatus) WithContext(ctx context.Context) toUpdateFinal {
128 | s.ctx = ctx
129 | return s
130 | }
131 |
132 | func (s updateStatus) Execute() (sql.Result, error) {
133 | sqlString, err := s.GetSQL()
134 | if err != nil {
135 | return nil, err
136 | }
137 | return s.scope.Database.ExecuteContext(s.ctx, sqlString)
138 | }
139 |
--------------------------------------------------------------------------------
/field.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import "strings"
4 |
5 | // Field is the interface of a generated field.
6 | type Field interface {
7 | Expression
8 | GetTable() Table
9 | }
10 |
11 | // NumberField is the interface of a generated field of number type.
12 | type NumberField interface {
13 | NumberExpression
14 | GetTable() Table
15 | }
16 |
17 | // BooleanField is the interface of a generated field of boolean type.
18 | type BooleanField interface {
19 | BooleanExpression
20 | GetTable() Table
21 | }
22 |
23 | // StringField is the interface of a generated field of string type.
24 | type StringField interface {
25 | StringExpression
26 | GetTable() Table
27 | }
28 |
29 | // ArrayField is the interface of a generated field of array type.
30 | type ArrayField interface {
31 | ArrayExpression
32 | IndexOf(index int) Field
33 | GetTable() Table
34 | }
35 |
36 | type DateField interface {
37 | DateExpression
38 | GetTable() Table
39 | }
40 |
41 | type actualField struct {
42 | expression
43 | table Table
44 | }
45 |
46 | func (f actualField) GetTable() Table {
47 | return f.table
48 | }
49 |
50 | func newField(table Table, fieldName string) actualField {
51 | tableName := table.GetName()
52 | tableNameSqlArray := quoteIdentifier(tableName)
53 | fieldNameSqlArray := quoteIdentifier(fieldName)
54 |
55 | var fullFieldNameSqlArray dialectArray
56 | for dialect := dialect(0); dialect < dialectCount; dialect++ {
57 | fullFieldNameSqlArray[dialect] = tableNameSqlArray[dialect] + "." + fieldNameSqlArray[dialect]
58 | }
59 |
60 | return actualField{
61 | expression: expression{
62 | builder: func(scope scope) (string, error) {
63 | dialect := dialectUnknown
64 | if scope.Database != nil {
65 | dialect = scope.Database.dialect
66 | }
67 | if len(scope.Tables) != 1 || scope.lastJoin != nil || scope.Tables[0].GetName() != tableName {
68 | return fullFieldNameSqlArray[dialect], nil
69 | }
70 | return fieldNameSqlArray[dialect], nil
71 | },
72 | },
73 | table: table,
74 | }
75 | }
76 |
77 | // NewNumberField creates a reference to a number field. It should only be called from generated code.
78 | func NewNumberField(table Table, fieldName string) NumberField {
79 | return newField(table, fieldName)
80 | }
81 |
82 | // NewBooleanField creates a reference to a boolean field. It should only be called from generated code.
83 | func NewBooleanField(table Table, fieldName string) BooleanField {
84 | return newField(table, fieldName)
85 | }
86 |
87 | // NewStringField creates a reference to a string field. It should only be called from generated code.
88 | func NewStringField(table Table, fieldName string) StringField {
89 | return newField(table, fieldName)
90 | }
91 |
92 | // NewDateField creates a reference to a time.Time field. It should only be called from generated code.
93 | func NewDateField(table Table, fieldName string) DateField {
94 | return newField(table, fieldName)
95 | }
96 |
97 | type fieldList []Field
98 |
99 | func (fields fieldList) GetSQL(scope scope) (string, error) {
100 | isSingleTable := len(scope.Tables) == 1 && scope.lastJoin == nil
101 | var sb strings.Builder
102 | if len(fields) == 0 {
103 | for i, table := range scope.Tables {
104 | if i > 0 {
105 | sb.WriteString(", ")
106 | }
107 | actualTable, ok := table.(actualTable)
108 | if ok {
109 | if isSingleTable {
110 | sb.WriteString(actualTable.GetFieldsSQL())
111 | } else {
112 | sb.WriteString(actualTable.GetFullFieldsSQL())
113 | }
114 | } else {
115 | sb.WriteByte('*')
116 | }
117 | }
118 | } else {
119 | fieldsSql, err := commaFields(scope, fields)
120 | if err != nil {
121 | return "", err
122 | }
123 | sb.WriteString(fieldsSql)
124 | }
125 | return sb.String(), nil
126 | }
127 |
--------------------------------------------------------------------------------
/common.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import (
4 | "fmt"
5 | "runtime"
6 | "strings"
7 | )
8 |
// SqlingoRuntimeVersion is the runtime version of sqlingo.
const SqlingoRuntimeVersion = 2
13 |
14 | // Model is the interface of generated model struct
15 | type Model interface {
16 | GetTable() Table
17 | GetValues() []interface{}
18 | }
19 |
20 | // Assignment is an assignment statement
21 | type Assignment interface {
22 | GetSQL(scope scope) (string, error)
23 | }
24 |
25 | type assignment struct {
26 | Assignment
27 | field Field
28 | value interface{}
29 | }
30 |
31 | func (a assignment) GetSQL(scope scope) (string, error) {
32 | value, _, err := getSQL(scope, a.value)
33 | if err != nil {
34 | return "", err
35 | }
36 | fieldSql, err := a.field.GetSQL(scope)
37 | if err != nil {
38 | return "", err
39 | }
40 | return fieldSql + " = " + value, nil
41 | }
42 |
43 | func command(name string, arg interface{}) expression {
44 | return expression{builder: func(scope scope) (string, error) {
45 | sql, _, err := getSQL(scope, arg)
46 | if err != nil {
47 | return "", err
48 | }
49 | return name + " " + sql, nil
50 | }}
51 | }
52 |
53 | func commaFields(scope scope, fields []Field) (string, error) {
54 | var sqlBuilder strings.Builder
55 | sqlBuilder.Grow(128)
56 | for i, item := range fields {
57 | if i > 0 {
58 | sqlBuilder.WriteString(", ")
59 | }
60 | itemSql, err := item.GetSQL(scope)
61 | if err != nil {
62 | return "", err
63 | }
64 | sqlBuilder.WriteString(itemSql)
65 | }
66 | return sqlBuilder.String(), nil
67 | }
68 |
69 | func commaExpressions(scope scope, expressions []Expression) (string, error) {
70 | var sqlBuilder strings.Builder
71 | sqlBuilder.Grow(128)
72 | for i, item := range expressions {
73 | if i > 0 {
74 | sqlBuilder.WriteString(", ")
75 | }
76 | itemSql, err := item.GetSQL(scope)
77 | if err != nil {
78 | return "", err
79 | }
80 | sqlBuilder.WriteString(itemSql)
81 | }
82 | return sqlBuilder.String(), nil
83 | }
84 |
85 | func commaTables(scope scope, tables []Table) string {
86 | var sqlBuilder strings.Builder
87 | sqlBuilder.Grow(32)
88 | for i, table := range tables {
89 | if i > 0 {
90 | sqlBuilder.WriteString(", ")
91 | }
92 | sqlBuilder.WriteString(table.GetSQL(scope))
93 | }
94 | return sqlBuilder.String()
95 | }
96 |
97 | func commaValues(scope scope, values []interface{}) (string, error) {
98 | var sqlBuilder strings.Builder
99 | for i, item := range values {
100 | if i > 0 {
101 | sqlBuilder.WriteString(", ")
102 | }
103 | itemSql, _, err := getSQL(scope, item)
104 | if err != nil {
105 | return "", err
106 | }
107 | sqlBuilder.WriteString(itemSql)
108 | }
109 | return sqlBuilder.String(), nil
110 | }
111 |
112 | func commaAssignments(scope scope, assignments []assignment) (string, error) {
113 | var sqlBuilder strings.Builder
114 | for i, item := range assignments {
115 | if i > 0 {
116 | sqlBuilder.WriteString(", ")
117 | }
118 | itemSql, err := item.GetSQL(scope)
119 | if err != nil {
120 | return "", err
121 | }
122 | sqlBuilder.WriteString(itemSql)
123 | }
124 | return sqlBuilder.String(), nil
125 | }
126 |
127 | func commaOrderBys(scope scope, orderBys []OrderBy) (string, error) {
128 | var sqlBuilder strings.Builder
129 | for i, item := range orderBys {
130 | if i > 0 {
131 | sqlBuilder.WriteString(", ")
132 | }
133 | itemSql, err := item.GetSQL(scope)
134 | if err != nil {
135 | return "", err
136 | }
137 | sqlBuilder.WriteString(itemSql)
138 | }
139 | return sqlBuilder.String(), nil
140 | }
141 |
142 | func getCallerInfo(db database, retry bool) string {
143 | if !db.enableCallerInfo {
144 | return ""
145 | }
146 | extraInfo := ""
147 | if db.tx != nil {
148 | extraInfo += " (tx)"
149 | }
150 | if retry {
151 | extraInfo += " (retry)"
152 | }
153 | for i := 0; true; i++ {
154 | _, file, line, ok := runtime.Caller(i)
155 | if !ok {
156 | break
157 | }
158 | if file == "" || strings.Contains(file, "/sqlingo@v") {
159 | continue
160 | }
161 | segs := strings.Split(file, "/")
162 | name := segs[len(segs)-1]
163 | return fmt.Sprintf("/* %s:%d%s */ ", name, line, extraInfo)
164 | }
165 | return ""
166 | }
167 |
--------------------------------------------------------------------------------
/insert_test.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "testing"
7 | )
8 |
9 | type tTest struct {
10 | Table
11 |
12 | F1 fTestF1
13 | F2 fTestF2
14 | }
15 |
16 | func (t tTest) GetFields() []Field {
17 | return []Field{t.F1, t.F2}
18 | }
19 |
20 | type fTestF1 struct{ NumberField }
21 | type fTestF2 struct{ StringField }
22 |
23 | type TestModel struct {
24 | F1 int64
25 | F2 string
26 | }
27 |
28 | var tTestTable = NewTable("test")
29 |
30 | var Test = tTest{
31 | Table: NewTable("test"),
32 | F1: fTestF1{NewNumberField(tTestTable, "f1")},
33 | F2: fTestF2{NewStringField(tTestTable, "f2")},
34 | }
35 |
36 | func (m TestModel) GetTable() Table {
37 | return Test
38 | }
39 |
40 | func (m TestModel) GetValues() []interface{} {
41 | return []interface{}{m.F1, m.F2}
42 | }
43 |
44 | func TestInsert(t *testing.T) {
45 | db := newMockDatabase()
46 |
47 | if _, err := db.InsertInto(Table1).Fields(field1).
48 | Values(1).
49 | Values(2).
50 | OnDuplicateKeyUpdate().Set(field1, 10).
51 | Execute(); err != nil {
52 | t.Error(err)
53 | }
54 | assertLastSql(t, "INSERT INTO `table1` (`field1`)"+
55 | " VALUES (1), (2)"+
56 | " ON DUPLICATE KEY UPDATE `field1` = 10")
57 |
58 | if _, err := db.InsertInto(Table1).Fields(field1).
59 | Values(1).
60 | OnDuplicateKeyUpdate().
61 | SetIf(false, field1, 2).
62 | Execute(); err != nil {
63 | t.Error(err)
64 | }
65 | assertLastSql(t, "INSERT INTO `table1` (`field1`)"+
66 | " VALUES (1)")
67 |
68 | if _, err := db.InsertInto(Table1).Fields(field1).
69 | Values(0).
70 | OnDuplicateKeyUpdate().
71 | SetIf(true, field1, 1).
72 | Execute(); err != nil {
73 | t.Error(err)
74 | }
75 | assertLastSql(t, "INSERT INTO `table1` (`field1`)"+
76 | " VALUES (0)"+
77 | " ON DUPLICATE KEY UPDATE `field1` = 1")
78 |
79 | if _, err := db.InsertInto(Table1).
80 | Fields(field1, field2).
81 | Values(1, 2).
82 | OnDuplicateKeyUpdate().
83 | SetIf(false, field1, 10).
84 | SetIf(true, field2, 20).
85 | Execute(); err != nil {
86 | t.Error(err)
87 | }
88 | assertLastSql(t, "INSERT INTO `table1` (`field1`, `field2`)"+
89 | " VALUES (1, 2)"+
90 | " ON DUPLICATE KEY UPDATE `field2` = 20")
91 |
92 | if _, err := db.InsertInto(Table1).Fields(field1).
93 | Values(1).
94 | Values(2).
95 | OnDuplicateKeyIgnore().
96 | Execute(); err != nil {
97 | t.Error(err)
98 | }
99 | assertLastSql(t, "INSERT INTO `table1` (`field1`)"+
100 | " VALUES (1), (2)"+
101 | " ON DUPLICATE KEY UPDATE `field1` = `field1`")
102 |
103 | model := &TestModel{
104 | F1: 1,
105 | F2: "test",
106 | }
107 | if _, err := db.InsertInto(Test).Values(1, 2).Execute(); err != nil {
108 | t.Error(err)
109 | }
110 | assertLastSql(t, "INSERT INTO `test` (`f1`, `f2`) VALUES (1, 2)")
111 |
112 | if _, err := db.InsertInto(Test).Models(model, &model, []Model{model}).Execute(); err != nil {
113 | t.Error(err)
114 | }
115 | assertLastSql(t, "INSERT INTO `test` (`f1`, `f2`) VALUES (1, 'test'), (1, 'test'), (1, 'test')")
116 |
117 | if _, err := db.InsertInto(Test).Models(model, &model, []interface{}{model, "invalid type"}).Execute(); err == nil {
118 | t.Error("should get error here")
119 | }
120 |
121 | if _, err := db.InsertInto(Table1).Models(model).Execute(); err == nil {
122 | t.Error("should get error here")
123 | }
124 |
125 | if _, err := db.ReplaceInto(Test).Values(1, 2).Execute(); err != nil {
126 | t.Error(err)
127 | }
128 | assertLastSql(t, "REPLACE INTO `test` (`f1`, `f2`) VALUES (1, 2)")
129 |
130 | errExpr := expression{
131 | builder: func(scope scope) (string, error) {
132 | return "", errors.New("error")
133 | },
134 | }
135 | if _, err := db.InsertInto(Test).Fields(errExpr).Values(1).Execute(); err == nil {
136 | t.Error("should get error here")
137 | }
138 | if _, err := db.InsertInto(Test).Fields(Test.F1).Values(errExpr).Execute(); err == nil {
139 | t.Error("should get error here")
140 | }
141 | if _, err := db.InsertInto(Test).
142 | Fields(Test.F1).Values(1).
143 | OnDuplicateKeyUpdate().Set(Test.F1, errExpr).Execute(); err == nil {
144 | t.Error("should get error here")
145 | }
146 |
147 | if _, err := db.InsertInto(Test).Fields(Test.F1).Values(1).WithContext(context.Background()).Execute(); err != nil {
148 | t.Error(err)
149 | }
150 |
151 | if _, err := db.InsertInto(Test).
152 | Fields(Test.F1).Values(1).
153 | OnDuplicateKeyUpdate().Set(Test.F1, errExpr).WithContext(context.Background()).Execute(); err == nil {
154 | t.Error("should get error here")
155 | }
156 | }
157 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | [](https://github.com/avelino/awesome-go)
4 | [](https://pkg.go.dev/github.com/lqs/sqlingo?tab=doc)
5 | [](https://app.travis-ci.com/github/lqs/sqlingo)
6 | [](https://goreportcard.com/report/github.com/lqs/sqlingo)
7 | [](https://codecov.io/gh/lqs/sqlingo)
8 | [](http://opensource.org/licenses/MIT)
9 | [](https://github.com/lqs/sqlingo/commits)
10 |
11 | **sqlingo** is a SQL DSL (a.k.a. SQL Builder or ORM) library in Go. It generates code from the database and lets you write SQL queries in an elegant way.
12 |
13 |
14 |
15 | ## Features
16 | * Auto-generating DSL objects and model structs from the database so you don't need to manually keep things in sync
17 | * SQL DML (SELECT / INSERT / UPDATE / DELETE) with some advanced SQL query syntaxes
18 | * Many common errors can be detected at compile time
19 | * You can use your editor / IDE features, such as autocompleting fields and queries, or finding the usages of a field or a table
20 | * Context support
21 | * Transaction support
22 | * Interceptor support
23 | * Go's `time.Time` is supported; you can still use the `string` type instead by adding `-timeAsString` when generating the model
24 |
25 | ## Database Support Status
26 | | Database | Status |
27 | ------------- | --------------
28 | | MySQL | stable |
29 | | PostgreSQL | experimental |
30 | | SQLite | experimental |
31 |
32 | ## Tutorial
33 |
34 | ### Install and use sqlingo code generator
35 | The first step is to generate code from the database. To generate code, sqlingo requires that your tables already exist in the database.
36 |
37 | ```
38 | $ go install github.com/lqs/sqlingo/sqlingo-gen-mysql@latest
39 | $ mkdir -p generated/sqlingo
40 | $ sqlingo-gen-mysql root:123456@/database_name >generated/sqlingo/database_name.dsl.go
41 | ```
42 |
43 |
44 | ### Write your application
45 | Here's a demonstration of some simple & advanced usage of sqlingo.
46 | ```go
47 | package main
48 |
49 | import (
50 | "github.com/lqs/sqlingo"
51 | . "./generated/sqlingo"
52 | )
53 |
54 | func main() {
55 | db, err := sqlingo.Open("mysql", "root:123456@/database_name")
56 | if err != nil {
57 | panic(err)
58 | }
59 |
60 | // a simple query
61 | var customers []*CustomerModel
62 | db.SelectFrom(Customer).
63 | Where(Customer.Id.In(1, 2)).
64 | OrderBy(Customer.Name.Desc()).
65 | FetchAll(&customers)
66 |
67 | // query from multiple tables
68 | var customerId int64
69 | var orderId int64
70 | err = db.Select(Customer.Id, Order.Id).
71 | From(Customer, Order).
72 | Where(Customer.Id.Equals(Order.CustomerId), Order.Id.Equals(1)).
73 | FetchFirst(&customerId, &orderId)
74 |
75 | // subquery and count
76 | count, err := db.SelectFrom(Order).
77 | Where(Order.CustomerId.In(db.Select(Customer.Id).
78 | From(Customer).
79 | Where(Customer.Name.Equals("Customer One")))).
80 | Count()
81 |
82 | // group-by with auto conversion to map
83 | var customerIdToOrderCount map[int64]int64
84 | err = db.Select(Order.CustomerId, f.Count(1)).
85 | From(Order).
86 | GroupBy(Order.CustomerId).
87 | FetchAll(&customerIdToOrderCount)
88 | if err != nil {
89 | println(err)
90 | }
91 |
92 | // insert some rows
93 | customer1 := &CustomerModel{Name: "Customer One"}
94 | customer2 := &CustomerModel{Name: "Customer Two"}
95 | _, err = db.InsertInto(Customer).
96 | Models(customer1, customer2).
97 | Execute()
98 |
99 | // insert with on-duplicate-key-update
100 | _, err = db.InsertInto(Customer).
101 | Fields(Customer.Id, Customer.Name).
102 | Values(42, "Universe").
103 | OnDuplicateKeyUpdate().
104 | Set(Customer.Name, Customer.Name.Concat(" 2")).
105 | Execute()
106 | }
107 | ```
108 |
--------------------------------------------------------------------------------
/cursor_test.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import (
4 | "database/sql"
5 | "database/sql/driver"
6 | "io"
7 | "strconv"
8 | "testing"
9 | "time"
10 | )
11 |
12 | type mockDriver struct{}
13 |
14 | type mockConn struct {
15 | lastSql string
16 | mockTx *mockTx
17 | beginTxError error
18 | prepareError error
19 | columnCount int
20 | rowCount int
21 | }
22 |
23 | type mockStmt struct {
24 | columnCount int
25 | rowCount int
26 | }
27 |
28 | type mockRows struct {
29 | columnCount int
30 | cursorPosition int
31 | rowCount int
32 | }
33 |
34 | func (m mockRows) Columns() []string {
35 | return []string{"a", "b", "c", "d", "e", "f", "g", "h", "j", "k", "l"}[:m.columnCount]
36 | }
37 |
38 | func (m mockRows) Close() error {
39 | return nil
40 | }
41 |
42 | func (m *mockRows) Next(dest []driver.Value) error {
43 | if m.cursorPosition >= m.rowCount {
44 | return io.EOF
45 | }
46 | m.cursorPosition++
47 | for i := 0; i < m.columnCount; i++ {
48 | switch i {
49 | case 0:
50 | dest[i] = strconv.Itoa(m.cursorPosition)
51 | case 1:
52 | dest[i] = float32(m.cursorPosition)
53 | case 2:
54 | dest[i] = m.cursorPosition
55 | case 3:
56 | dest[i] = string(rune(m.cursorPosition % 2)) // '\x00' or '\x01'
57 | case 4:
58 | dest[i] = strconv.Itoa(m.cursorPosition % 2) // '0' or '1'
59 | case 5:
60 | dest[i] = dest[0]
61 | case 6:
62 | dest[i] = nil
63 | case 7:
64 | dest[i] = time.Date(2023, 9, 6, 18, 37, 46, 828000000, time.UTC)
65 | case 8:
66 | dest[i] = "2023-09-06 18:37:46.828"
67 | case 9:
68 | dest[i] = "2023-09-06 18:37:46"
69 | case 10:
70 | dest[i] = "2023-09-06 18:37:46"
71 | }
72 | }
73 | return nil
74 | }
75 |
76 | func (m mockStmt) Close() error {
77 | return nil
78 | }
79 |
80 | func (m mockStmt) NumInput() int {
81 | return 0
82 | }
83 |
84 | func (m mockStmt) Exec(args []driver.Value) (driver.Result, error) {
85 | return driver.ResultNoRows, nil
86 | }
87 |
88 | func (m mockStmt) Query(args []driver.Value) (driver.Rows, error) {
89 | return &mockRows{
90 | columnCount: m.columnCount,
91 | rowCount: m.rowCount,
92 | }, nil
93 | }
94 |
// TestCursor reads the mock result set row by row and scans each row into a
// variety of destination types. Column values come from mockRows.Next. The
// conversions checked are: string->int, float32->string, int->float32,
// single-byte and "0"/"1" strings->bool, multi-level pointers, NULL->nil
// pointer, and time.Time / formatted strings->time.Time. The test expects
// exactly 10 rows.
func TestCursor(t *testing.T) {
	db := newMockDatabase()
	cursor, _ := db.Query("dummy sql")

	var a int
	var b string

	// Struct destinations are flattened field by field: C, DE.D and DE.E
	// consume three consecutive columns (2, 3 and 4).
	var cde struct {
		C  float32
		DE struct {
			D, E bool
		}
	}
	var f ****int // deep pointer
	var g *int    // always null
	var h *time.Time
	var j time.Time
	var k *time.Time
	var l time.Time
	// Expected values for the time columns: 7 is a time.Time, 8 a string
	// with milliseconds, 9 and 10 strings without fractional seconds.
	tmh, _ := time.Parse("2006-01-02 15:04:05.000", "2023-09-06 18:37:46.828")
	tmj, _ := time.Parse("2006-01-02 15:04:05.000", "2023-09-06 18:37:46.828")
	tmk, _ := time.Parse("2006-01-02 15:04:05", "2023-09-06 18:37:46")
	tml, _ := time.Parse("2006-01-02 15:04:05", "2023-09-06 18:37:46")
	for i := 1; i <= 10; i++ {
		if !cursor.Next() {
			t.Error()
		}
		// Make g non-nil so the g != nil check below proves that Scan reset
		// it for the NULL column.
		g = &i
		if err := cursor.Scan(&a, &b, &cde, &f, &g, &h, &j, &k, &l); err != nil {
			t.Errorf("%v", err)
		}
		if a != i ||
			b != strconv.Itoa(i) ||
			cde.C != float32(i) ||
			cde.DE.D != (i%2 == 1) ||
			cde.DE.E != cde.DE.D ||
			****f != i ||
			g != nil ||
			*h != tmh ||
			j != tmj ||
			*k != tmk ||
			l != tml {
			t.Error(a, b, cde.C, cde.DE.D, cde.DE.E, ****f, g)
		}
		// Scan with no destinations is a dry run and must succeed.
		if err := cursor.Scan(); err != nil {
			t.Errorf("%v", err)
		}

		// Scan the same row again into different types. This b shadows the
		// outer string b for the rest of the iteration.
		var s string
		var b ****bool
		var p *string
		var bs []byte
		var u string
		if err := cursor.Scan(&s, &s, &s, &b, &s, &bs, &p, &u, &u, &u, &u); err != nil {
			t.Error(err)
		}
		// Column 3 -> b, column 5 (a copy of column 0) -> bs, column 6 (NULL) -> p.
		if ****b != (i%2 == 1) ||
			p != nil ||
			string(bs) != strconv.Itoa(i) {
			t.Error(s, ****b, p, string(bs))
		}
	}
	// The mock must be exhausted after 10 rows.
	if cursor.Next() {
		t.Errorf("d")
	}
	if err := cursor.Close(); err != nil {
		t.Error(err)
	}

}
165 |
166 | func TestScanTime(t *testing.T) {
167 | db := newMockDatabase()
168 | cursor, _ := db.Query("dummy sql")
169 | defer cursor.Close()
170 |
171 | var row struct {
172 | A sql.NullString
173 | B []byte
174 | C sql.NullInt32
175 | D sql.NullString
176 | E sql.NullString
177 | F sql.NullString
178 | G sql.NullString
179 | H sql.NullTime
180 | I sql.NullString
181 | J sql.NullString
182 | K sql.NullString
183 | }
184 | if !cursor.Next() {
185 | t.Error()
186 | }
187 | if err := cursor.Scan(&row); err != nil {
188 | t.Error(err)
189 | }
190 | }
191 |
192 | func TestCursorMap(t *testing.T) {
193 | db := newMockDatabase()
194 | cursor, _ := db.Query("dummy sql")
195 |
196 | for i := 1; i <= 10; i++ {
197 | if !cursor.Next() {
198 | t.Error()
199 | }
200 | row, err := cursor.GetMap()
201 | if err != nil {
202 | t.Error(err)
203 | }
204 | if row["a"].Int() != i {
205 | t.Error()
206 | }
207 | }
208 | if cursor.Next() {
209 | t.Error()
210 | }
211 | }
212 |
213 | func TestParseTime(t *testing.T) {
214 | tests := []struct {
215 | input string
216 | output time.Time
217 | }{
218 | {"2024-09-06 11:22:33", time.Date(2024, 9, 6, 11, 22, 33, 0, time.UTC)},
219 | {"2024-09-06 11:22:33.444", time.Date(2024, 9, 6, 11, 22, 33, 444000000, time.UTC)},
220 | {"2024-09-06 11:22:33.444555666", time.Date(2024, 9, 6, 11, 22, 33, 444555666, time.UTC)},
221 | {"2024-09-06T11:22:33.444555666Z", time.Date(2024, 9, 6, 11, 22, 33, 444555666, time.UTC)},
222 | {"0000-00-00 00:00:00", time.Time{}},
223 | }
224 | for _, test := range tests {
225 | tm, err := parseTime(test.input)
226 | if err != nil {
227 | t.Error(err)
228 | continue
229 | }
230 | if tm != test.output {
231 | t.Error(tm, test.output)
232 | }
233 | }
234 | }
235 |
--------------------------------------------------------------------------------
/array.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import (
4 | "bytes"
5 | "errors"
6 | "fmt"
7 | "io"
8 | "strconv"
9 | "unicode"
10 | )
11 |
type (
	// ArrayDimension describes one dimension of an array literal
	// (presumably PostgreSQL's text format): its element count and the
	// index of its first element.
	ArrayDimension struct {
		Length     int32
		LowerBound int32
	}

	// untypedTextArray is the raw result of parsing an array literal such as
	// `{a,"b c"}`, without interpreting the element values.
	untypedTextArray struct {
		// Elements holds each element's text, in order.
		Elements []string
		// Quoted[i] reports whether Elements[i] was written in double quotes.
		Quoted []bool
		// Dimensions has one entry per dimension; it stays empty for an
		// array without elements.
		Dimensions []ArrayDimension
	}
)
22 |
// skipWhitespace consumes leading Unicode whitespace from buf. The first
// non-space rune is left unread, so the next ReadRune returns it.
//
// The previous version never captured ReadRune's error: its `err != io.EOF`
// check compared a never-assigned variable, so it always called UnreadRune,
// even after hitting end of input. That call then failed silently. Here the
// EOF case returns without trying to unread anything.
func skipWhitespace(buf *bytes.Buffer) {
	for {
		r, _, err := buf.ReadRune()
		if err == io.EOF {
			// Nothing was read, so there is nothing to put back.
			return
		}
		if !unicode.IsSpace(r) {
			// Put the first significant rune back for the caller.
			_ = buf.UnreadRune()
			return
		}
	}
}
34 |
35 | func parseToUntypedTextArray(src string) (*untypedTextArray, error) {
36 | dst := &untypedTextArray{
37 | Elements: []string{},
38 | Quoted: []bool{},
39 | Dimensions: []ArrayDimension{},
40 | }
41 |
42 | buf := bytes.NewBufferString(src)
43 | skipWhitespace(buf)
44 |
45 | // read failed
46 | r, _, err := buf.ReadRune()
47 | if err != nil {
48 | return nil, fmt.Errorf("invalid array: %v", err)
49 | }
50 |
51 | var explicitDimensions []ArrayDimension
52 |
53 | // Array has explicit dimensions
54 | if r == '[' {
55 | buf.UnreadRune()
56 | for {
57 | r, _, err = buf.ReadRune()
58 | if err != nil {
59 | return nil, fmt.Errorf("invalid array: %v", err)
60 | }
61 | if r == '=' {
62 | break
63 | } else if r != '[' {
64 | return nil, fmt.Errorf("invalid array, expected '[' or '=' got %v", r)
65 | }
66 |
67 | lower, err := arrayParseInteger(buf)
68 | if err != nil {
69 | return nil, fmt.Errorf("invalid array: %v", err)
70 | }
71 |
72 | r, _, err = buf.ReadRune()
73 | if err != nil {
74 | return nil, fmt.Errorf("invalid array: %v", err)
75 | }
76 |
77 | if r != ':' {
78 | return nil, fmt.Errorf("invalid array, expected ':' got %v", r)
79 | }
80 |
81 | upper, err := arrayParseInteger(buf)
82 | if err != nil {
83 | return nil, fmt.Errorf("invalid array: %v", err)
84 | }
85 |
86 | r, _, err = buf.ReadRune()
87 | if err != nil {
88 | return nil, fmt.Errorf("invalid array: %v", err)
89 | }
90 |
91 | if r != ']' {
92 | return nil, fmt.Errorf("invalid array, expected ']' got %v", r)
93 | }
94 |
95 | explicitDimensions = append(explicitDimensions, ArrayDimension{LowerBound: lower, Length: upper - lower + 1})
96 | }
97 | }
98 |
99 | if r != '{' {
100 | return nil, errors.New("invalid array, expected '{' prefix")
101 | }
102 | implicitDimensions := []ArrayDimension{{LowerBound: 1, Length: 0}}
103 |
104 | // Consume all initial opening brackets. This provides number of dimensions.
105 | for {
106 | r, _, err = buf.ReadRune()
107 | if err != nil {
108 | return nil, fmt.Errorf("invalid array: %v", err)
109 | }
110 |
111 | if r == '{' {
112 | implicitDimensions[len(implicitDimensions)-1].Length = 1
113 | implicitDimensions = append(implicitDimensions, ArrayDimension{LowerBound: 1})
114 | } else {
115 | buf.UnreadRune()
116 | break
117 | }
118 | }
119 | currentDim := len(implicitDimensions) - 1
120 | counterDim := currentDim
121 |
122 | for {
123 | r, _, err = buf.ReadRune()
124 | if err != nil {
125 | return nil, fmt.Errorf("invalid array: %v", err)
126 | }
127 |
128 | switch r {
129 | case '{':
130 | if currentDim == counterDim {
131 | implicitDimensions[currentDim].Length++
132 | }
133 | currentDim++
134 | case ',':
135 | case '}':
136 | currentDim--
137 | if currentDim < counterDim {
138 | counterDim = currentDim
139 | }
140 | default:
141 | buf.UnreadRune()
142 | value, quoted, err := arrayParseValue(buf)
143 | if err != nil {
144 | return nil, fmt.Errorf("invalid array value: %v", err)
145 | }
146 | if currentDim == counterDim {
147 | implicitDimensions[currentDim].Length++
148 | }
149 | dst.Quoted = append(dst.Quoted, quoted)
150 | dst.Elements = append(dst.Elements, value)
151 | }
152 |
153 | if currentDim < 0 {
154 | break
155 | }
156 | }
157 |
158 | skipWhitespace(buf)
159 |
160 | if buf.Len() > 0 {
161 | return nil, fmt.Errorf("unexpected trailing data: %v", buf.String())
162 | }
163 |
164 | if len(dst.Elements) == 0 {
165 | } else if len(explicitDimensions) > 0 {
166 | dst.Dimensions = explicitDimensions
167 | } else {
168 | dst.Dimensions = implicitDimensions
169 | }
170 |
171 | return dst, nil
172 | }
173 |
// arrayParseInteger reads a decimal integer (digits and '-') from buf. It
// stops at the first other rune, which is left unread. An empty or malformed
// number is an error. So is running out of input before that terminating
// rune; the read error is returned in that case.
func arrayParseInteger(buf *bytes.Buffer) (int32, error) {
	var digits bytes.Buffer
	for {
		r, _, err := buf.ReadRune()
		if err != nil {
			return 0, err
		}
		if r == '-' || ('0' <= r && r <= '9') {
			digits.WriteRune(r)
			continue
		}
		buf.UnreadRune()
		n, err := strconv.ParseInt(digits.String(), 10, 32)
		if err != nil {
			return 0, err
		}
		return int32(n), nil
	}
}
195 | func arrayParseValue(buf *bytes.Buffer) (string, bool, error) {
196 | r, _, err := buf.ReadRune()
197 | if err != nil {
198 | return "", false, err
199 | }
200 | if r == '"' {
201 | return arrayParseQuotedValue(buf)
202 | }
203 | buf.UnreadRune()
204 |
205 | s := &bytes.Buffer{}
206 |
207 | for {
208 | r, _, err := buf.ReadRune()
209 | if err != nil {
210 | return "", false, err
211 | }
212 |
213 | switch r {
214 | case ',', '}':
215 | buf.UnreadRune()
216 | return s.String(), false, nil
217 | }
218 |
219 | s.WriteRune(r)
220 | }
221 | }
222 |
// arrayParseQuotedValue reads the body of a double-quoted element. It is
// called after the opening '"' has been consumed. A backslash makes the next
// rune literal. After the closing quote, one more rune must be readable (it
// is left unread); end of input there, or anywhere before, is an error. The
// second return value is always true (the element was quoted).
func arrayParseQuotedValue(buf *bytes.Buffer) (string, bool, error) {
	var text bytes.Buffer
	for {
		r, _, err := buf.ReadRune()
		if err != nil {
			return "", false, err
		}
		if r == '"' {
			// Peek at the rune following the closing quote.
			if _, _, err := buf.ReadRune(); err != nil {
				return "", false, err
			}
			buf.UnreadRune()
			return text.String(), true, nil
		}
		if r == '\\' {
			// Escaped rune: taken literally, even if it is a quote.
			if r, _, err = buf.ReadRune(); err != nil {
				return "", false, err
			}
		}
		text.WriteRune(r)
	}
}
249 |
--------------------------------------------------------------------------------
/insert.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import (
4 | "context"
5 | "database/sql"
6 | "errors"
7 | "fmt"
8 | "reflect"
9 | )
10 |
11 | type insertStatus struct {
12 | method string
13 | scope scope
14 | fields []Field
15 | values []interface{}
16 | models []interface{}
17 | onDuplicateKeyUpdateAssignments []assignment
18 | ctx context.Context
19 | }
20 |
21 | type insertWithTable interface {
22 | Fields(fields ...Field) insertWithValues
23 | Values(values ...interface{}) insertWithValues
24 | Models(models ...interface{}) insertWithModels
25 | }
26 |
27 | type insertWithValues interface {
28 | toInsertWithContext
29 | toInsertFinal
30 | Values(values ...interface{}) insertWithValues
31 | OnDuplicateKeyIgnore() toInsertWithDuplicateKey
32 | OnDuplicateKeyUpdate() insertWithOnDuplicateKeyUpdateBegin
33 | }
34 |
35 | type insertWithModels interface {
36 | toInsertWithContext
37 | toInsertFinal
38 | Models(models ...interface{}) insertWithModels
39 | OnDuplicateKeyIgnore() toInsertWithDuplicateKey
40 | OnDuplicateKeyUpdate() insertWithOnDuplicateKeyUpdateBegin
41 | }
42 |
43 | type insertWithOnDuplicateKeyUpdateBegin interface {
44 | Set(Field Field, value interface{}) insertWithOnDuplicateKeyUpdate
45 | SetIf(condition bool, Field Field, value interface{}) insertWithOnDuplicateKeyUpdate
46 | }
47 |
48 | type insertWithOnDuplicateKeyUpdate interface {
49 | insertWithOnDuplicateKeyUpdateBegin
50 | toInsertWithDuplicateKey
51 | }
52 |
53 | type toInsertWithContext interface {
54 | WithContext(ctx context.Context) toInsertFinal
55 | }
56 |
57 | type toInsertFinal interface {
58 | GetSQL() (string, error)
59 | Execute() (result sql.Result, err error)
60 | }
61 |
62 | type toInsertWithDuplicateKey interface {
63 | toInsertWithContext
64 | toInsertFinal
65 | }
66 |
67 | func (d *database) InsertInto(table Table) insertWithTable {
68 | return insertStatus{method: "INSERT", scope: scope{Database: d, Tables: []Table{table}}}
69 | }
70 |
71 | func (d *database) ReplaceInto(table Table) insertWithTable {
72 | return insertStatus{method: "REPLACE", scope: scope{Database: d, Tables: []Table{table}}}
73 | }
74 |
75 | func (s insertStatus) Fields(fields ...Field) insertWithValues {
76 | s.fields = fields
77 | return s
78 | }
79 |
80 | func (s insertStatus) Values(values ...interface{}) insertWithValues {
81 | s.values = append([]interface{}{}, s.values...)
82 | s.values = append(s.values, values)
83 | return s
84 | }
85 |
86 | func addModel(models *[]Model, model interface{}) error {
87 | if model, ok := model.(Model); ok {
88 | *models = append(*models, model)
89 | return nil
90 | }
91 |
92 | value := reflect.ValueOf(model)
93 | switch value.Kind() {
94 | case reflect.Ptr:
95 | value = reflect.Indirect(value)
96 | return addModel(models, value.Interface())
97 | case reflect.Slice, reflect.Array:
98 | for i := 0; i < value.Len(); i++ {
99 | elem := value.Index(i)
100 | addr := elem.Addr()
101 | inter := addr.Interface()
102 | if err := addModel(models, inter); err != nil {
103 | return err
104 | }
105 | }
106 | return nil
107 | default:
108 | return fmt.Errorf("unknown model type (kind = %d)", value.Kind())
109 | }
110 | }
111 |
112 | func (s insertStatus) Models(models ...interface{}) insertWithModels {
113 | s.models = models
114 | return s
115 | }
116 |
117 | func (s insertStatus) OnDuplicateKeyUpdate() insertWithOnDuplicateKeyUpdateBegin {
118 | return s
119 | }
120 |
121 | func (s insertStatus) SetIf(condition bool, field Field, value interface{}) insertWithOnDuplicateKeyUpdate {
122 | if condition {
123 | return s.Set(field, value)
124 | }
125 | return s
126 | }
127 |
128 | func (s insertStatus) Set(field Field, value interface{}) insertWithOnDuplicateKeyUpdate {
129 | s.onDuplicateKeyUpdateAssignments = append([]assignment{}, s.onDuplicateKeyUpdateAssignments...)
130 | s.onDuplicateKeyUpdateAssignments = append(s.onDuplicateKeyUpdateAssignments, assignment{
131 | field: field,
132 | value: value,
133 | })
134 | return s
135 | }
136 |
137 | func (s insertStatus) OnDuplicateKeyIgnore() toInsertWithDuplicateKey {
138 | firstField := s.scope.Tables[0].GetFields()[0]
139 | return s.OnDuplicateKeyUpdate().Set(firstField, firstField)
140 | }
141 |
142 | func (s insertStatus) GetSQL() (string, error) {
143 | var fields []Field
144 | var values []interface{}
145 | if len(s.models) > 0 {
146 | models := make([]Model, 0, len(s.models))
147 | for _, model := range s.models {
148 | if err := addModel(&models, model); err != nil {
149 | return "", err
150 | }
151 | }
152 |
153 | if len(models) > 0 {
154 | fields = models[0].GetTable().GetFields()
155 | for _, model := range models {
156 | if model.GetTable().GetName() != s.scope.Tables[0].GetName() {
157 | return "", errors.New("invalid table from model")
158 | }
159 | values = append(values, model.GetValues())
160 | }
161 | }
162 | } else {
163 | if len(s.fields) == 0 {
164 | fields = s.scope.Tables[0].GetFields()
165 | } else {
166 | fields = s.fields
167 | }
168 | values = s.values
169 | }
170 |
171 | if len(values) == 0 {
172 | return "/* INSERT without VALUES */ DO 0", nil
173 | }
174 |
175 | tableSql := s.scope.Tables[0].GetSQL(s.scope)
176 | fieldsSql, err := commaFields(s.scope, fields)
177 | if err != nil {
178 | return "", err
179 | }
180 | valuesSql, err := commaValues(s.scope, values)
181 | if err != nil {
182 | return "", err
183 | }
184 |
185 | sqlString := s.method + " INTO " + tableSql + " (" + fieldsSql + ") VALUES " + valuesSql
186 | if len(s.onDuplicateKeyUpdateAssignments) > 0 {
187 | assignmentsSql, err := commaAssignments(s.scope, s.onDuplicateKeyUpdateAssignments)
188 | if err != nil {
189 | return "", err
190 | }
191 | sqlString += " ON DUPLICATE KEY UPDATE " + assignmentsSql
192 | }
193 |
194 | return sqlString, nil
195 | }
196 |
197 | func (s insertStatus) WithContext(ctx context.Context) toInsertFinal {
198 | s.ctx = ctx
199 | return s
200 | }
201 |
202 | func (s insertStatus) Execute() (result sql.Result, err error) {
203 | sqlString, err := s.GetSQL()
204 | if err != nil {
205 | return nil, err
206 | }
207 | return s.scope.Database.ExecuteContext(s.ctx, sqlString)
208 | }
209 |
--------------------------------------------------------------------------------
/cursor.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import (
4 | "database/sql"
5 | "fmt"
6 | "reflect"
7 | "regexp"
8 | "strconv"
9 | "strings"
10 | "time"
11 | )
12 |
13 | // Scanner is the interface that wraps the Scan method.
14 | type Scanner interface {
15 | Scan(dest ...interface{}) error
16 | }
17 |
18 | // Cursor is the interface of a row cursor.
19 | type Cursor interface {
20 | Next() bool
21 | Scan(dest ...interface{}) error
22 | GetMap() (map[string]value, error)
23 | Close() error
24 | }
25 |
// cursor wraps a *sql.Rows result set (presumably as the Cursor
// implementation; not all of its methods are visible here).
type cursor struct {
	rows *sql.Rows
}

// Next advances to the next row. It reports false once the rows are exhausted
// or an error occurred, exactly as sql.Rows.Next does.
func (c cursor) Next() bool { return c.rows.Next() }
33 |
var (
	// timeType is the reflect.Type of time.Time, compared against while
	// preparing scan destinations.
	timeType = reflect.TypeOf(time.Time{})

	// simpleTimeLayoutRegexp matches "YYYY-MM-DD hh:mm:ss" with optional
	// fractional seconds. Group 1 is the fraction including the dot;
	// group 2 is its digits.
	simpleTimeLayoutRegexp = regexp.MustCompile(`^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}(\.(\d+))?$`)
)
37 |
38 | func guessTimeLayout(s string) string {
39 | matches := simpleTimeLayoutRegexp.FindStringSubmatch(s)
40 | if len(matches) > 0 {
41 | var sb strings.Builder
42 | sb.Grow(32)
43 | sb.WriteString("2006-01-02 15:04:05")
44 | if matches[1] != "" {
45 | sb.WriteString(".")
46 | for i := 0; i < len(matches[2]); i++ {
47 | sb.WriteByte('0')
48 | }
49 | }
50 | return sb.String()
51 | }
52 | return time.RFC3339Nano
53 | }
54 |
55 | func parseTime(s string) (time.Time, error) {
56 | if strings.HasPrefix(s, "0000-00-00") {
57 | // MySQL zero date
58 | return time.Time{}, nil
59 | }
60 | layout := guessTimeLayout(s)
61 | t, err := time.Parse(layout, s)
62 | if err != nil {
63 | return time.Time{}, fmt.Errorf("unknown time format %s: %w", s, err)
64 | }
65 | return t, nil
66 | }
67 |
// isScanner reports whether a pointer to val implements sql.Scanner.
// val must be addressable, because val.Addr is called.
func isScanner(val reflect.Value) bool {
	switch val.Addr().Interface().(type) {
	case sql.Scanner:
		return true
	default:
		return false
	}
}
72 |
73 | func preparePointers(val reflect.Value, scans *[]interface{}) error {
74 | kind := val.Kind()
75 | switch kind {
76 | case reflect.Bool,
77 | reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
78 | reflect.Uint, reflect.Uint8, reflect.Uint32, reflect.Uint64,
79 | reflect.Float32, reflect.Float64,
80 | reflect.String:
81 | if addr := val.Addr(); addr.CanInterface() {
82 | *scans = append(*scans, addr.Interface())
83 | }
84 | case reflect.Struct:
85 | if canScan := val.Type() == timeType || isScanner(val); canScan {
86 | *scans = append(*scans, val.Addr().Interface())
87 | return nil
88 | }
89 | for j := 0; j < val.NumField(); j++ {
90 | field := val.Field(j)
91 | if field.Kind() == reflect.Interface {
92 | continue
93 | }
94 | if err := preparePointers(field, scans); err != nil {
95 | return err
96 | }
97 | }
98 | case reflect.Ptr:
99 | toType := val.Type().Elem()
100 | switch toType.Kind() {
101 | case reflect.Bool,
102 | reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
103 | reflect.Uint, reflect.Uint8, reflect.Uint32, reflect.Uint64,
104 | reflect.Float32, reflect.Float64,
105 | reflect.String:
106 | *scans = append(*scans, val.Addr().Interface())
107 | case reflect.Struct:
108 | if toType == reflect.TypeOf(time.Time{}) {
109 | *scans = append(*scans, val.Addr().Interface())
110 | } else {
111 | to := reflect.New(toType).Elem()
112 | val.Set(to.Addr())
113 | err := preparePointers(to, scans)
114 | if err != nil {
115 | return nil
116 | }
117 | }
118 | default:
119 | to := reflect.New(toType).Elem()
120 | val.Set(to.Addr())
121 | err := preparePointers(to, scans)
122 | if err != nil {
123 | return nil
124 | }
125 | }
126 | case reflect.Slice:
127 | if _, ok := (val.Interface()).([]byte); ok {
128 | *scans = append(*scans, val.Addr().Interface())
129 | } else {
130 | return fmt.Errorf("unknown type []%s", val.Type().Elem().Kind().String())
131 | }
132 | default:
133 | return fmt.Errorf("unknown type %s", kind.String())
134 | }
135 | return nil
136 | }
137 |
// parseBool interprets a raw column value as a boolean. A single raw byte 0x00
// or 0x01 maps to false or true (presumably the binary form some drivers use
// for BIT columns); anything else is handed to strconv.ParseBool.
func parseBool(s []byte) (bool, error) {
	if len(s) == 1 {
		switch s[0] {
		case 0:
			return false, nil
		case 1:
			return true, nil
		}
	}
	return strconv.ParseBool(string(s))
}
148 |
149 | func (c cursor) Scan(dest ...interface{}) error {
150 | columns, err := c.rows.Columns()
151 | if err != nil {
152 | return err
153 | }
154 | values := make([]interface{}, len(columns))
155 | pointers := make([]interface{}, len(columns))
156 | for i := range columns {
157 | pointers[i] = &values[i]
158 | }
159 | if err := c.rows.Scan(pointers...); err != nil {
160 | return err
161 | }
162 |
163 | if len(dest) == 0 {
164 | // dry run
165 | return nil
166 | }
167 |
168 | var scans []interface{}
169 | for i, item := range dest {
170 | if reflect.ValueOf(item).Kind() != reflect.Ptr {
171 | return fmt.Errorf("argument %d is not pointer", i)
172 | }
173 |
174 | val := reflect.Indirect(reflect.ValueOf(item))
175 |
176 | err := preparePointers(val, &scans)
177 | if err != nil {
178 | return err
179 | }
180 | }
181 |
182 | pbs := make(map[int]*bool)
183 | ppbs := make(map[int]**bool)
184 | pts := make(map[int]*time.Time)
185 | ppts := make(map[int]**time.Time)
186 |
187 | for i, scan := range scans {
188 | switch scan.(type) {
189 | case *bool:
190 | var s []uint8
191 | scans[i] = &s
192 | pbs[i] = scan.(*bool)
193 | case **bool:
194 | var s *[]uint8
195 | scans[i] = &s
196 | ppbs[i] = scan.(**bool)
197 | case *time.Time:
198 | var s string
199 | scans[i] = &s
200 | pts[i] = scan.(*time.Time)
201 | case **time.Time:
202 | var s sql.NullString
203 | scans[i] = &s
204 | ppts[i] = scan.(**time.Time)
205 | }
206 | }
207 |
208 | if err := c.rows.Scan(scans...); err != nil {
209 | return err
210 | }
211 |
212 | for i, pb := range pbs {
213 | if *(scans[i].(*[]byte)) == nil {
214 | return fmt.Errorf("field %d is null", i)
215 | }
216 | b, err := parseBool(*(scans[i].(*[]byte)))
217 | if err != nil {
218 | return err
219 | }
220 | *pb = b
221 | }
222 | for i, ppb := range ppbs {
223 | if *(scans[i].(**[]uint8)) == nil {
224 | *ppb = nil
225 | } else {
226 | b, err := parseBool(**(scans[i].(**[]byte)))
227 | if err != nil {
228 | return err
229 | }
230 | *ppb = &b
231 | }
232 | }
233 | for i := range pts {
234 | s := scans[i].(*string)
235 | if s == nil {
236 | return fmt.Errorf("field %d is null", i)
237 | }
238 | t, err := parseTime(*s)
239 | if err != nil {
240 | return err
241 | }
242 | *pts[i] = t
243 |
244 | }
245 | for i := range ppts {
246 | nullString := scans[i].(*sql.NullString)
247 | if nullString == nil {
248 | return fmt.Errorf("field %d is null", i)
249 | }
250 | if !nullString.Valid {
251 | *ppts[i] = nil
252 | } else {
253 | t, err := parseTime(nullString.String)
254 | if err != nil {
255 | return err
256 | }
257 | *ppts[i] = &t
258 | }
259 | }
260 |
261 | return err
262 | }
263 |
264 | func (c cursor) GetMap() (result map[string]value, err error) {
265 | columns, err := c.rows.Columns()
266 | if err != nil {
267 | return
268 | }
269 |
270 | columnCount := len(columns)
271 | values := make([]interface{}, columnCount)
272 | for i := 0; i < columnCount; i++ {
273 | var value *string
274 | values[i] = &value
275 | }
276 | if err = c.rows.Scan(values...); err != nil {
277 | return
278 | }
279 |
280 | result = make(map[string]value, columnCount)
281 | for i, column := range columns {
282 | result[column] = value{stringValue: *values[i].(**string)}
283 | }
284 |
285 | return
286 | }
287 |
288 | func (c cursor) Close() error {
289 | return c.rows.Close()
290 | }
291 |
--------------------------------------------------------------------------------
/expression_test.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import (
4 | "errors"
5 | "testing"
6 | )
7 |
// Named types defined over basic kinds. The expression tests use them to check
// that such values are rendered like their underlying kind.
type (
	CustomInt    int
	CustomBool   bool
	CustomFloat  float32
	CustomString string
)
12 |
13 | func TestExpression(t *testing.T) {
14 | assertValue(t, nil, "NULL")
15 |
16 | assertValue(t, false, "0")
17 | assertValue(t, true, "1")
18 | assertValue(t, CustomBool(false), "0")
19 | assertValue(t, CustomBool(true), "1")
20 |
21 | assertValue(t, int8(11), "11")
22 | assertValue(t, int16(11111), "11111")
23 | assertValue(t, int32(1111111111), "1111111111")
24 | assertValue(t, CustomInt(1111111111), "1111111111")
25 | assertValue(t, int(1111111111), "1111111111")
26 | assertValue(t, int64(1111111111111111111), "1111111111111111111")
27 |
28 | assertValue(t, int8(-11), "-11")
29 | assertValue(t, int16(-11111), "-11111")
30 | assertValue(t, int32(-1111111111), "-1111111111")
31 | assertValue(t, CustomInt(-1111111111), "-1111111111")
32 | assertValue(t, int(-1111111111), "-1111111111")
33 | assertValue(t, int64(-1111111111111111111), "-1111111111111111111")
34 |
35 | assertValue(t, uint8(1), "1")
36 | assertValue(t, uint16(55555), "55555")
37 | assertValue(t, uint32(3333333333), "3333333333")
38 | assertValue(t, uint(3333333333), "3333333333")
39 | assertValue(t, uint64(11111111111111111111), "11111111111111111111")
40 |
41 | assertValue(t, float32(2), "2")
42 | assertValue(t, float32(-2), "-2")
43 | assertValue(t, float64(2), "2")
44 | assertValue(t, float64(-2), "-2")
45 |
46 | assertValue(t, "abc", "'abc'")
47 | assertValue(t, "", "''")
48 | assertValue(t, "a' or 'a'='a", "'a\\' or \\'a\\'=\\'a'")
49 | assertValue(t, "\n", "'\\\n'")
50 | assertValue(t, CustomString("abc"), "'abc'")
51 |
52 | x := 3
53 | px := &x
54 | ppx := &px
55 | var deepNil *****int
56 | assertValue(t, &x, "3")
57 | assertValue(t, &px, "3")
58 | assertValue(t, &ppx, "3")
59 | assertValue(t, deepNil, "NULL")
60 | }
61 |
62 | func TestFunc(t *testing.T) {
63 | e := expression{
64 | builder: func(scope scope) (string, error) {
65 | return "<>", nil
66 | },
67 | }
68 |
69 | assertValue(t, e.Equals(e), "<> = <>")
70 | assertValue(t, e.NotEquals(e), "<> <> <>")
71 | assertValue(t, e.LessThan(e), "<> < <>")
72 | assertValue(t, e.LessThanOrEquals(e), "<> <= <>")
73 | assertValue(t, e.GreaterThan(e), "<> > <>")
74 | assertValue(t, e.GreaterThanOrEquals(e), "<> >= <>")
75 | assertValue(t, e.And(e), "<> AND <>")
76 | assertValue(t, e.Or(e), "<> OR <>")
77 | assertValue(t, e.Xor(e), "<> XOR <>")
78 | assertValue(t, e.Not(), "NOT <>")
79 |
80 | assertValue(t, e.Add(e), "<> + <>")
81 | assertValue(t, e.Sub(e), "<> - <>")
82 | assertValue(t, e.Mul(e), "<> * <>")
83 | assertValue(t, e.Div(e), "<> / <>")
84 | assertValue(t, e.IntDiv(e), "<> DIV <>")
85 | assertValue(t, e.Mod(e), "<> % <>")
86 | assertValue(t, e.Sum(), "SUM(<>)")
87 | assertValue(t, e.Avg(), "AVG(<>)")
88 | assertValue(t, e.Min(), "MIN(<>)")
89 | assertValue(t, e.Max(), "MAX(<>)")
90 | assertValue(t, e.Between(2, 4), "<> BETWEEN 2 AND 4")
91 | assertValue(t, e.NotBetween(2, 4), "<> NOT BETWEEN 2 AND 4")
92 |
93 | assertValue(t, e.In(), "FALSE")
94 | assertValue(t, e.In(1), "<> = 1")
95 | assertValue(t, e.In(1, 2, 3), "<> IN (1, 2, 3)")
96 | assertValue(t, e.In([]int64{}), "FALSE")
97 | assertValue(t, e.In([]int64{1}), "<> = 1")
98 | assertValue(t, e.In([]int64{1, 2, 3}), "<> IN (1, 2, 3)")
99 | assertValue(t, e.In([]byte{1, 2, 3}), "<> IN (1, 2, 3)")
100 |
101 | assertValue(t, e.NotIn(), "TRUE")
102 | assertValue(t, e.NotIn(1), "<> <> 1")
103 | assertValue(t, e.NotIn(1, 2, 3), "<> NOT IN (1, 2, 3)")
104 | assertValue(t, e.NotIn([]int64{}), "TRUE")
105 | assertValue(t, e.NotIn([]int64{1}), "<> <> 1")
106 | assertValue(t, e.NotIn([]int64{1, 2, 3}), "<> NOT IN (1, 2, 3)")
107 |
108 | assertValue(t, e.Like("%A%"), "<> LIKE '%A%'")
109 | assertValue(t, e.Concat("-suffix"), "CONCAT(<>, '-suffix')")
110 | assertValue(t, e.Contains("\n"), "LOCATE('\\\n', <>) > 0")
111 |
112 | assertValue(t, []interface{}{1, 2, 3, "d"}, "(1, 2, 3, 'd')")
113 |
114 | assertValue(t, e.IsNull(), "<> IS NULL")
115 | assertValue(t, e.IsNotNull(), "<> IS NOT NULL")
116 | assertValue(t, e.IsTrue(), "<> IS TRUE")
117 | assertValue(t, e.IsNotTrue(), "<> IS NOT TRUE")
118 | assertValue(t, e.IsFalse(), "<> IS FALSE")
119 | assertValue(t, e.IsNotFalse(), "<> IS NOT FALSE")
120 | assertValue(t, e.If(3, 4), "IF(<>, 3, 4)")
121 | assertValue(t, e.IfNull(3), "IFNULL(<>, 3)")
122 | assertValue(t, e.IfEmpty(3), "IF(<> <> '', <>, 3)")
123 | assertValue(t, e.IsEmpty(), "<> = ''")
124 | assertValue(t, e.Lower(), "LOWER(<>)")
125 | assertValue(t, e.Upper(), "UPPER(<>)")
126 | assertValue(t, e.Left(10), "LEFT(<>, 10)")
127 | assertValue(t, e.Right(10), "RIGHT(<>, 10)")
128 | assertValue(t, e.Trim(), "TRIM(<>)")
129 | assertValue(t, e.HasPrefix("abc"), "LEFT(<>, CHAR_LENGTH('abc')) = 'abc'")
130 | assertValue(t, e.HasSuffix("abc"), "RIGHT(<>, CHAR_LENGTH('abc')) = 'abc'")
131 |
132 | e5 := expression{
133 | builder: func(scope scope) (string, error) {
134 | return "e5", nil
135 | },
136 | priority: 5,
137 | }
138 | e7 := expression{
139 | builder: func(scope scope) (string, error) {
140 | return "e7", nil
141 | },
142 | priority: 7,
143 | }
144 | e9 := expression{
145 | builder: func(scope scope) (string, error) {
146 | return "e9", nil
147 | },
148 | priority: 9,
149 | }
150 |
151 | assertValue(t, e7.Add(e7), "e7 + (e7)")
152 | assertValue(t, e5.Add(e7), "e5 + (e7)")
153 | assertValue(t, e7.Add(e5), "e7 + e5")
154 | assertValue(t, e5.Add(e9), "e5 + (e9)")
155 | assertValue(t, e9.Add(e5), "(e9) + e5")
156 |
157 | ee := expression{
158 | builder: func(scope scope) (string, error) {
159 | return "", errors.New("error")
160 | },
161 | }
162 | assertError(t, e.Add(ee))
163 | assertError(t, ee.Add(e))
164 | assertError(t, ee.IsNull())
165 | assertError(t, e.In(ee, ee, ee))
166 | assertError(t, ee.In(e, e, e))
167 |
168 | assertError(t, ee.Between(2, 4))
169 | assertError(t, e.Between(2, ee))
170 | assertError(t, e.Between(ee, 4))
171 |
172 | }
173 |
174 | func TestMisc(t *testing.T) {
175 | assertValue(t, True(), "TRUE")
176 | assertValue(t, False(), "FALSE")
177 |
178 | assertValue(t, command("COMMAND", staticExpression("", 0, false)), "COMMAND ")
179 |
180 | assertValue(t, staticExpression("", 1, false).
181 | prefixSuffixExpression("", "", 1, false), "")
182 | }
183 |
184 | func TestLogicalExpression(t *testing.T) {
185 | a := expression{sql: "a", priority: 1}
186 | b := expression{sql: "b", priority: 1}
187 | c := expression{sql: "c", priority: 1}
188 | d := expression{sql: "d", priority: 1}
189 |
190 | assertValue(t, And(a, b, c, d), "a AND b AND c AND d")
191 | assertValue(t, Or(a, b, c, d), "a OR b OR c OR d")
192 | assertValue(t, a.And(b).Or(c).And(a).Or(b).And(c), "((a AND b OR c) AND a OR b) AND c")
193 | assertValue(t, a.Or(b).And(c.Or(d)), "(a OR b) AND (c OR d)")
194 | assertValue(t, a.Or(b).And(c).Not(), "NOT ((a OR b) AND c)")
195 |
196 | assertValue(t, And(), "TRUE")
197 | assertValue(t, Or(), "FALSE")
198 | }
199 |
200 | func TestLogicalOptimizer(t *testing.T) {
201 | trueValue := True()
202 | falseValue := False()
203 | otherValue := staticExpression("<>", 0, false)
204 | otherBoolValue := staticExpression("<>", 0, true)
205 |
206 | assertValue(t, trueValue.Or(trueValue), "TRUE")
207 | assertValue(t, trueValue.Or(falseValue), "TRUE")
208 | assertValue(t, falseValue.Or(trueValue), "TRUE")
209 | assertValue(t, falseValue.Or(falseValue), "FALSE")
210 |
211 | assertValue(t, trueValue.And(trueValue), "TRUE")
212 | assertValue(t, trueValue.And(falseValue), "FALSE")
213 | assertValue(t, falseValue.And(trueValue), "FALSE")
214 | assertValue(t, falseValue.And(falseValue), "FALSE")
215 |
216 | assertValue(t, falseValue.Not(), "TRUE")
217 | assertValue(t, trueValue.Not(), "FALSE")
218 |
219 | assertValue(t, trueValue.And(otherValue), "TRUE AND <>")
220 | assertValue(t, trueValue.Or(otherValue), "TRUE")
221 | assertValue(t, trueValue.And(123), "TRUE AND 123")
222 | assertValue(t, trueValue.Or(123), "TRUE")
223 | assertValue(t, falseValue.And(otherValue), "FALSE")
224 | assertValue(t, falseValue.Or(otherValue), "FALSE OR <>")
225 | assertValue(t, falseValue.And(123), "FALSE")
226 | assertValue(t, falseValue.Or(123), "FALSE OR 123")
227 |
228 | assertValue(t, trueValue.And(otherBoolValue), "<>")
229 | assertValue(t, falseValue.Or(otherBoolValue), "<>")
230 | }
231 |
--------------------------------------------------------------------------------
/database.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import (
4 | "context"
5 | "database/sql"
6 | "fmt"
7 | "os"
8 | "path/filepath"
9 | "runtime"
10 | "strings"
11 | "sync"
12 | "time"
13 | )
14 |
// ANSI escape sequences DefaultLogger uses to colour its terminal output.
const (
	reset = "\033[0m"  // back to the terminal's default colour
	red   = "\033[31m" // SQL of slow statements
	green = "\033[32m" // SQL of fast statements
	blue  = "\033[34m" // header line
)
22 |
23 | // Database is the interface of a database with underlying sql.DB object.
24 | type Database interface {
25 | // GetDB returns the underlying sql.DB object of the database
26 | GetDB() *sql.DB
27 | // BeginTx starts a transaction and executes the function f.
28 | BeginTx(ctx context.Context, opts *sql.TxOptions, f func(tx Transaction) error) error
29 | // Query executes a query and returns the cursor
30 | Query(sql string) (Cursor, error)
31 | // QueryContext executes a query with context and returns the cursor
32 | QueryContext(ctx context.Context, sqlString string) (Cursor, error)
33 | // Execute executes a statement
34 | Execute(sql string) (sql.Result, error)
35 | // ExecuteContext executes a statement with context
36 | ExecuteContext(ctx context.Context, sql string) (sql.Result, error)
37 | // SetLogger sets the logger function.
38 | // Deprecated: use SetInterceptor instead
39 | SetLogger(logger LoggerFunc)
40 | // SetRetryPolicy sets the retry policy function.
41 | // Deprecated: use SetInterceptor instead
42 | SetRetryPolicy(retryPolicy func(err error) bool)
43 | // EnableCallerInfo enable or disable the caller info in the log.
44 | // Deprecated: use SetInterceptor instead
45 | EnableCallerInfo(enableCallerInfo bool)
46 | // SetInterceptor sets an interceptor function
47 | SetInterceptor(interceptor InterceptorFunc)
48 |
49 | // Select initiates a SELECT statement
50 | Select(fields ...interface{}) selectWithFields
51 | // SelectDistinct initiates a SELECT DISTINCT statement
52 | SelectDistinct(fields ...interface{}) selectWithFields
53 | // SelectFrom initiates a SELECT * FROM statement
54 | SelectFrom(tables ...Table) selectWithTables
55 | // InsertInto initiates a INSERT INTO statement
56 | InsertInto(table Table) insertWithTable
57 | // ReplaceInto initiates a REPLACE INTO statement
58 | ReplaceInto(table Table) insertWithTable
59 | // Update initiates a UPDATE statement
60 | Update(table Table) updateWithSet
61 | // DeleteFrom initiates a DELETE FROM statement
62 | DeleteFrom(table Table) deleteWithTable
63 | }
64 |
// txOrDB is the method subset shared by *sql.DB and *sql.Tx that database
// needs, so statements can run on either one.
type txOrDB interface {
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}
69 |
// once guards DefaultLogger's one-time computation of srcPrefix.
var once sync.Once

// srcPrefix is the directory containing this package's source, as derived by
// DefaultLogger from runtime.Caller(0).
var srcPrefix string
74 |
75 | type database struct {
76 | db *sql.DB
77 | tx *sql.Tx
78 | logger LoggerFunc
79 | dialect dialect
80 | retryPolicy func(error) bool
81 | enableCallerInfo bool
82 | interceptor InterceptorFunc
83 | }
84 |
// LoggerFunc receives each executed statement: its SQL text, how long it took,
// whether it ran inside a transaction, and whether it was a retry.
type LoggerFunc func(sql string, duration time.Duration, isTx, retry bool)
86 |
87 | func (d *database) SetLogger(loggerFunc LoggerFunc) {
88 | d.logger = loggerFunc
89 | }
90 |
91 | // DefaultLogger is sqlingo default logger,
92 | // which print log to stderr and regard executing time gt 100ms as slow sql.
93 | func DefaultLogger(sql string, duration time.Duration, isTx bool, retry bool) {
94 | // for finding code position, try once is enough
95 | once.Do(func() {
96 | // $GOPATH/pkg/mod/github.com/lqs/sqlingo@vX.X.X/database.go
97 | _, file, _, _ := runtime.Caller(0)
98 | // $GOPATH/pkg/mod/github.com/lqs/sqlingo@vX.X.X
99 | srcPrefix = filepath.Dir(file)
100 | })
101 |
102 | var file string
103 | var line int
104 | var ok bool
105 | for i := 0; i < 16; i++ {
106 | _, file, line, ok = runtime.Caller(i)
107 | // `!strings.HasPrefix(file, srcPrefix)` jump out when using sqlingo as dependent package
108 | // `strings.HasSuffix(file, "_test.go")` jump out when executing unit test cases
109 | // `!ok` this is so terrible for something unexpected happened
110 | if !ok || !strings.HasPrefix(file, srcPrefix) || strings.HasSuffix(file, "_test.go") {
111 | break
112 | }
113 | }
114 |
115 | // todo shouldn't append ';' here
116 | if !strings.HasSuffix(sql, ";") {
117 | sql += ";"
118 | }
119 |
120 | sb := strings.Builder{}
121 | sb.Grow(32)
122 | sb.WriteString("|")
123 | sb.WriteString(duration.String())
124 | if isTx {
125 | sb.WriteString("|transaction") // todo using something traceable
126 | }
127 | if retry {
128 | sb.WriteString("|retry")
129 | }
130 | sb.WriteString("|")
131 |
132 | line1 := strings.Join(
133 | []string{
134 | "[sqlingo]",
135 | time.Now().Format("2006-01-02 15:04:05"),
136 | sb.String(),
137 | file + ":" + fmt.Sprint(line),
138 | },
139 | " ")
140 |
141 | // print to stderr
142 | fmt.Fprintln(os.Stderr, blue+line1+reset)
143 | if duration < 100*time.Millisecond {
144 | fmt.Fprintf(os.Stderr, "%s%s%s\n", green, sql, reset)
145 | } else {
146 | fmt.Fprintf(os.Stderr, "%s%s%s\n", red, sql, reset)
147 | }
148 | fmt.Fprintln(os.Stderr)
149 | }
150 |
151 | func (d *database) SetRetryPolicy(retryPolicy func(err error) bool) {
152 | d.retryPolicy = retryPolicy
153 | }
154 |
155 | func (d *database) EnableCallerInfo(enableCallerInfo bool) {
156 | d.enableCallerInfo = enableCallerInfo
157 | }
158 |
159 | func (d *database) SetInterceptor(interceptor InterceptorFunc) {
160 | d.interceptor = interceptor
161 | }
162 |
163 | // Open a database, similar to sql.Open.
164 | // `db` using a default logger, which print log to stderr and regard executing time gt 100ms as slow sql.
165 | // To disable the default logger, use `db.SetLogger(nil)`.
166 | func Open(driverName string, dataSourceName string) (db Database, err error) {
167 | var sqlDB *sql.DB
168 | if dataSourceName != "" {
169 | sqlDB, err = sql.Open(driverName, dataSourceName)
170 | if err != nil {
171 | return
172 | }
173 | }
174 | db = Use(driverName, sqlDB)
175 | return
176 | }
177 |
178 | // Use an existing *sql.DB handle
179 | func Use(driverName string, sqlDB *sql.DB) Database {
180 | return &database{
181 | dialect: getDialectFromDriverName(driverName),
182 | db: sqlDB,
183 | }
184 | }
185 |
186 | func (d database) GetDB() *sql.DB {
187 | return d.db
188 | }
189 |
190 | func (d database) getTxOrDB() txOrDB {
191 | if d.tx != nil {
192 | return d.tx
193 | }
194 | return d.db
195 | }
196 |
197 | func (d database) Query(sqlString string) (Cursor, error) {
198 | return d.QueryContext(context.Background(), sqlString)
199 | }
200 |
201 | func (d database) QueryContext(ctx context.Context, sqlString string) (Cursor, error) {
202 | isRetry := false
203 | for {
204 | sqlStringWithCallerInfo := getCallerInfo(d, isRetry) + sqlString
205 | rows, err := d.queryContextOnce(ctx, sqlStringWithCallerInfo, isRetry)
206 | if err != nil {
207 | isRetry = d.tx == nil && d.retryPolicy != nil && d.retryPolicy(err)
208 | if isRetry {
209 | continue
210 | }
211 | return nil, err
212 | }
213 | return cursor{rows: rows}, nil
214 | }
215 | }
216 |
217 | func (d database) queryContextOnce(ctx context.Context, sqlString string, retry bool) (*sql.Rows, error) {
218 | if ctx == nil {
219 | ctx = context.Background()
220 | }
221 | startTime := time.Now()
222 | defer func() {
223 | endTime := time.Now()
224 | if d.logger != nil {
225 | d.logger(sqlString, endTime.Sub(startTime), false, retry)
226 | }
227 | }()
228 |
229 | interceptor := d.interceptor
230 | var rows *sql.Rows
231 | invoker := func(ctx context.Context, sql string) (err error) {
232 | rows, err = d.getTxOrDB().QueryContext(ctx, sql)
233 | return
234 | }
235 |
236 | var err error
237 | if interceptor == nil {
238 | err = invoker(ctx, sqlString)
239 | } else {
240 | err = interceptor(ctx, sqlString, invoker)
241 | }
242 | if err != nil {
243 | return nil, err
244 | }
245 |
246 | return rows, nil
247 | }
248 |
249 | func (d database) Execute(sqlString string) (sql.Result, error) {
250 | return d.ExecuteContext(context.Background(), sqlString)
251 | }
252 |
253 | // ExecuteContext todo Is there need retry?
254 | func (d database) ExecuteContext(ctx context.Context, sqlString string) (sql.Result, error) {
255 | if ctx == nil {
256 | ctx = context.Background()
257 | }
258 | sqlStringWithCallerInfo := getCallerInfo(d, false) + sqlString
259 | startTime := time.Now()
260 | defer func() {
261 | endTime := time.Now()
262 | if d.logger != nil {
263 | d.logger(sqlStringWithCallerInfo, endTime.Sub(startTime), false, false)
264 | }
265 | }()
266 |
267 | var result sql.Result
268 | invoker := func(ctx context.Context, sql string) (err error) {
269 | result, err = d.getTxOrDB().ExecContext(ctx, sql)
270 | return
271 | }
272 | var err error
273 | if d.interceptor == nil {
274 | err = invoker(ctx, sqlStringWithCallerInfo)
275 | } else {
276 | err = d.interceptor(ctx, sqlStringWithCallerInfo, invoker)
277 | }
278 | if err != nil {
279 | return nil, err
280 | }
281 |
282 | return result, err
283 | }
284 |
--------------------------------------------------------------------------------
/select_test.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import (
4 | "context"
5 | "testing"
6 | )
7 |
8 | type tTable1 struct {
9 | Table
10 | }
11 |
12 | var Table1 = tTable1{
13 | NewTable("table1"),
14 | }
15 |
16 | var table1 = NewTable("table1")
17 | var field1 = NewNumberField(table1, "field1")
18 | var field2 = NewNumberField(table1, "field2")
19 |
20 | var table2 = NewTable("table2")
21 | var field3 = NewNumberField(table2, "field3")
22 |
23 | var table3 = NewTable("table3")
24 | var field4 = NewNumberField(table3, "field4")
25 |
26 | func (t tTable1) GetFields() []Field {
27 | return []Field{field1, field2}
28 | }
29 |
30 | func (t tTable1) GetFieldsSQL() string {
31 | return ""
32 | }
33 |
34 | func (t tTable1) GetFullFieldsSQL() string {
35 | return ""
36 | }
37 |
38 | func TestSelect(t *testing.T) {
39 | db := newMockDatabase()
40 | assertValue(t, db.Select(1), "(SELECT 1)")
41 |
42 | db.Select(field1).From(Table1).Where(field1.Equals(42)).Limit(10).GetSQL()
43 |
44 | _, _ = db.Select(field1).From(Table1).Where(field1.Equals(1), field2.Equals(2)).FetchFirst()
45 | assertLastSql(t, "SELECT `field1` FROM `table1` WHERE `field1` = 1 AND `field2` = 2")
46 |
47 | _, _ = db.Select(field1).From(Table1).Where(field1.Equals(1)).Where(field2.Equals(2)).FetchFirst()
48 | assertLastSql(t, "SELECT `field1` FROM `table1` WHERE `field1` = 1 AND `field2` = 2")
49 |
50 | _, _ = db.Select(field1).From(Table1).Where(field1.Equals(1)).WhereIf(true, field2.Equals(2)).FetchFirst()
51 | assertLastSql(t, "SELECT `field1` FROM `table1` WHERE `field1` = 1 AND `field2` = 2")
52 |
53 | _, _ = db.Select(field1).From(Table1).Where(field1.Equals(1)).WhereIf(false, field2.Equals(2)).FetchFirst()
54 | assertLastSql(t, "SELECT `field1` FROM `table1` WHERE `field1` = 1")
55 |
56 | _, _ = db.Select(field1).From(Table1).WhereIf(true, field2.Equals(2)).FetchFirst()
57 | assertLastSql(t, "SELECT `field1` FROM `table1` WHERE `field2` = 2")
58 |
59 | _, _ = db.Select(field1).From(Table1).WhereIf(false, field2.Equals(2)).FetchFirst()
60 | assertLastSql(t, "SELECT `field1` FROM `table1`")
61 |
62 | _, _ = db.Select(field1, field2, field3, Count(1).As("count")).
63 | From(Table1, table2).
64 | Where(field1.Equals(field3), field2.In(db.Select(field3).From(table2))).
65 | GroupBy(field2).
66 | Having(Raw("count").GreaterThan(1)).
67 | OrderBy(field1.Desc(), field2).
68 | Limit(10).
69 | Offset(20).
70 | LockInShareMode().
71 | FetchFirst()
72 | assertLastSql(t, "SELECT `table1`.`field1`, `table1`.`field2`, `table2`.`field3`, COUNT(1) AS count FROM `table1`, `table2` WHERE `table1`.`field1` = `table2`.`field3` AND `table1`.`field2` IN (SELECT `field3` FROM `table2`) GROUP BY `table1`.`field2` HAVING (count) > 1 ORDER BY `table1`.`field1` DESC, `table1`.`field2` LIMIT 10 OFFSET 20 LOCK IN SHARE MODE")
73 |
74 | _, _ = db.SelectDistinct(field2).From(Table1).FetchFirst()
75 | assertLastSql(t, "SELECT DISTINCT `field2` FROM `table1`")
76 |
77 | _, _ = db.Select(field1, field3).From(Table1).Join(table2).On(field1.Equals(field3)).FetchFirst()
78 | assertLastSql(t, "SELECT `table1`.`field1`, `table2`.`field3` FROM `table1` JOIN `table2` ON `table1`.`field1` = `table2`.`field3`")
79 | _, _ = db.Select(field1, field3).From(Table1).LeftJoin(table2).On(field1.Equals(field3)).FetchFirst()
80 | assertLastSql(t, "SELECT `table1`.`field1`, `table2`.`field3` FROM `table1` LEFT JOIN `table2` ON `table1`.`field1` = `table2`.`field3`")
81 | _, _ = db.Select(field1, field3).From(Table1).RightJoin(table2).On(field1.Equals(field3)).FetchFirst()
82 | assertLastSql(t, "SELECT `table1`.`field1`, `table2`.`field3` FROM `table1` RIGHT JOIN `table2` ON `table1`.`field1` = `table2`.`field3`")
83 |
84 | _, _ = db.Select(field1, field3).From(Table1).
85 | LeftJoin(table2).On(field1.Equals(field3)).
86 | RightJoin(table3).On(field1.Equals(field4)).FetchFirst()
87 | assertLastSql(t, "SELECT `table1`.`field1`, `table2`.`field3` FROM `table1` LEFT JOIN `table2` ON `table1`.`field1` = `table2`.`field3` RIGHT JOIN `table3` ON `table1`.`field1` = `table3`.`field4`")
88 |
89 | db.Select(1).WithContext(context.Background())
90 |
91 | _, _ = db.SelectFrom(Table1).FetchFirst()
92 | assertLastSql(t, "SELECT FROM `table1`")
93 |
94 | _, _ = db.Select([]Field{field1, field2}).From(Table1).FetchFirst()
95 | assertLastSql(t, "SELECT `field1`, `field2` FROM `table1`")
96 |
97 | _, _ = db.Select([]interface{}{&field1, field2, []int{3, 4}}).From(Table1).FetchFirst()
98 | assertLastSql(t, "SELECT `field1`, `field2`, 3, 4 FROM `table1`")
99 |
100 | _, _ = db.Select(field1, Table1).FetchFirst()
101 | assertLastSql(t, "SELECT `field1`, `field1`, `field2` FROM `table1`")
102 | }
103 |
104 | func TestCount(t *testing.T) {
105 | db := newMockDatabase()
106 |
107 | _, _ = db.SelectFrom(Test).Count()
108 | assertLastSql(t, "SELECT COUNT(1) FROM `test`")
109 |
110 | _, _ = db.SelectDistinct(Test.F1).From(Test).Count()
111 | assertLastSql(t, "SELECT COUNT(DISTINCT `f1`) FROM `test`")
112 |
113 | _, _ = db.Select(Test.F1).From(Test).GroupBy(Test.F2).Count()
114 | assertLastSql(t, "SELECT COUNT(1) FROM (SELECT 1 FROM `test` GROUP BY `f2`) AS t")
115 |
116 | _, _ = db.SelectDistinct(Test.F1).From(Test).GroupBy(Test.F2).Count()
117 | assertLastSql(t, "SELECT COUNT(1) FROM (SELECT DISTINCT `f1` FROM `test` GROUP BY `f2`) AS t")
118 |
119 | _, _ = db.Select(Test.F1).From(Test).Exists()
120 | assertLastSql(t, "SELECT EXISTS (SELECT `f1` FROM `test`)")
121 |
122 | _, _ = db.Select(Test.F1).From(Test).Limit(10).Count()
123 | assertLastSql(t, "SELECT COUNT(1) FROM (SELECT 1 FROM `test` LIMIT 10) AS t")
124 | }
125 |
126 | func TestSelectAutoFrom(t *testing.T) {
127 | db := newMockDatabase()
128 |
129 | _, _ = db.Select(field1, field2, 123).FetchFirst()
130 | assertLastSql(t, "SELECT `field1`, `field2`, 123 FROM `table1`")
131 |
132 | _, _ = db.Select(field1, field2, 123, field3).FetchFirst()
133 | assertLastSql(t, "SELECT `table1`.`field1`, `table1`.`field2`, 123, `table2`.`field3` FROM `table1`, `table2`")
134 | }
135 |
136 | func TestFetch(t *testing.T) {
137 | db := newMockDatabase()
138 | defer func() {
139 | sharedMockConn.columnCount = 7
140 | sharedMockConn.rowCount = 10
141 | }()
142 |
143 | sharedMockConn.columnCount = 2
144 |
145 | _ = db
146 | var f1 string
147 | var f2 int
148 |
149 | ok, err := db.Select(field1, field2).From(Table1).FetchFirst(&f1, &f2)
150 | if !ok || err != nil {
151 | t.Error()
152 | }
153 |
154 | if err := db.Select(field1, field2).From(Table1).FetchExactlyOne(&f1, &f2); err == nil {
155 | t.Error("should get error")
156 | }
157 |
158 | sharedMockConn.rowCount = 1
159 | if err := db.Select(field1, field2).From(Table1).FetchExactlyOne(&f1, &f2); err != nil {
160 | t.Error(err)
161 | }
162 |
163 | sharedMockConn.rowCount = 0
164 | if err := db.Select(field1, field2).From(Table1).FetchExactlyOne(&f1, &f2); err == nil {
165 | t.Error("should get error")
166 | }
167 |
168 | }
169 |
170 | func TestFetchAll(t *testing.T) {
171 | db := newMockDatabase()
172 |
173 | sharedMockConn.columnCount = 2
174 | defer func() {
175 | sharedMockConn.columnCount = 7
176 | }()
177 |
178 | // fetch all as slices
179 | var f1s []string
180 | var f2s []int
181 | if _, err := db.Select(field1, field2).From(Table1).FetchAll(&f1s, &f2s); err != nil {
182 | t.Error(err)
183 | }
184 | if len(f1s) != 10 || len(f2s) != 10 {
185 | t.Error(f1s, f2s)
186 | }
187 |
188 | type record struct {
189 | unexportedF1 string
190 | ExportedF1 string
191 | unexportedF2 int
192 | ExportedF2 int
193 | }
194 | var records []record
195 | if _, err := db.Select(field1, field2).From(Table1).FetchAll(&records); err != nil {
196 | t.Error(err)
197 | }
198 | if len(records) != 10 {
199 | t.Error(records)
200 | }
201 |
202 | // fetch all as map
203 | var m map[string]int
204 | if _, err := db.Select(field1).From(Table1).FetchAll(&m); err != nil {
205 | t.Error(err)
206 | }
207 |
208 | // fetch all as multiple maps is illegal
209 | if _, err := db.Select(field1).From(Table1).FetchAll(&m, &m); err == nil {
210 | t.Error("should get error here")
211 | }
212 |
213 | // fetch all as unsupported type
214 | var unsupported int
215 | if _, err := db.Select(field1).From(Table1).FetchAll(&unsupported); err == nil {
216 | t.Error("should get error here")
217 | }
218 |
219 | // fetch all for non-pointer
220 | if _, err := db.Select(field1).From(Table1).FetchAll(123); err == nil {
221 | t.Error("should get error here")
222 | }
223 | }
224 |
225 | func TestRangeFunc(t *testing.T) {
226 | db := newMockDatabase()
227 |
228 | oldColumnCount := sharedMockConn.columnCount
229 | sharedMockConn.columnCount = 2
230 | defer func() {
231 | sharedMockConn.columnCount = oldColumnCount
232 | }()
233 |
234 | count := 0
235 | seq := db.Select(field1, field2).From(Table1).FetchSeq()
236 |
237 | // for row := range db.Select(field1, field2).From(Table1).FetchSeq() {}
238 | seq(func(row Scanner) bool {
239 | var f1 string
240 | var f2 int
241 | if err := row.Scan(&f1, &f2); err != nil {
242 | t.Error(err)
243 | }
244 | count++
245 | return true
246 | })
247 |
248 | if count != 10 {
249 | t.Error(count)
250 | }
251 | }
252 |
253 | func TestLock(t *testing.T) {
254 | db := newMockDatabase()
255 | table1 := NewTable("table1")
256 | _, _ = db.Select(1).From(table1).LockInShareMode().FetchAll()
257 | assertLastSql(t, "SELECT 1 FROM `table1` LOCK IN SHARE MODE")
258 | _, _ = db.Select(1).From(table1).ForUpdate().FetchAll()
259 | assertLastSql(t, "SELECT 1 FROM `table1` FOR UPDATE")
260 | _, _ = db.Select(1).From(table1).ForUpdateNoWait().FetchAll()
261 | assertLastSql(t, "SELECT 1 FROM `table1` FOR UPDATE NOWAIT")
262 | _, _ = db.Select(1).From(table1).ForUpdateSkipLocked().FetchAll()
263 | assertLastSql(t, "SELECT 1 FROM `table1` FOR UPDATE SKIP LOCKED")
264 | }
265 |
266 | func TestUnion(t *testing.T) {
267 | db := newMockDatabase()
268 | table1 := NewTable("table1")
269 | table2 := NewTable("table2")
270 |
271 | cond1 := Raw("")
272 | cond2 := Raw("")
273 |
274 | _, _ = db.SelectFrom(table1).UnionSelectFrom(table2).Where(cond1).FetchAll()
275 | assertLastSql(t, "SELECT * FROM `table1` UNION SELECT * FROM `table2` WHERE ")
276 |
277 | _, _ = db.SelectFrom(table1).Where(cond1).
278 | UnionSelectFrom(table2).Where(cond2).FetchAll()
279 | assertLastSql(t, "SELECT * FROM `table1` WHERE UNION SELECT * FROM `table2` WHERE ")
280 |
281 | _, _ = db.SelectFrom(table1).Where(Raw("C1")).
282 | UnionSelectFrom(table2).Where(Raw("C2")).
283 | UnionSelect(3).From(table2).Where(Raw("C3")).
284 | UnionSelectDistinct(4).From(table2).Where(Raw("C4")).
285 | UnionAllSelectFrom(table2).Where(Raw("C5")).
286 | UnionAllSelect(6).From(table2).Where(Raw("C6")).
287 | UnionAllSelectDistinct(7).From(table2).Where(Raw("C7")).
288 | FetchAll()
289 | assertLastSql(t, "SELECT * FROM `table1` WHERE C1 "+
290 | "UNION SELECT * FROM `table2` WHERE C2 "+
291 | "UNION SELECT 3 FROM `table2` WHERE C3 "+
292 | "UNION SELECT DISTINCT 4 FROM `table2` WHERE C4 "+
293 | "UNION ALL SELECT * FROM `table2` WHERE C5 "+
294 | "UNION ALL SELECT 6 FROM `table2` WHERE C6 "+
295 | "UNION ALL SELECT DISTINCT 7 FROM `table2` WHERE C7")
296 |
297 | _, _ = db.SelectFrom(table1).Where(Raw("C1")).
298 | UnionSelectFrom(table2).Where(Raw("C2")).
299 | UnionSelect(3).From(table2).Where(Raw("C3")).
300 | UnionSelectDistinct(4).From(table2).Where(Raw("C4")).
301 | UnionAllSelectFrom(table2).Where(Raw("C5")).
302 | UnionAllSelect(6).From(table2).Where(Raw("C6")).
303 | UnionAllSelectDistinct(7).From(table2).Where(Raw("C7")).
304 | Count()
305 | assertLastSql(t, "SELECT COUNT(1) FROM ("+
306 | "SELECT 1 FROM `table1` WHERE C1 "+
307 | "UNION SELECT * FROM `table2` WHERE C2 "+
308 | "UNION SELECT 3 FROM `table2` WHERE C3 "+
309 | "UNION SELECT DISTINCT 4 FROM `table2` WHERE C4 "+
310 | "UNION ALL SELECT * FROM `table2` WHERE C5 "+
311 | "UNION ALL SELECT 6 FROM `table2` WHERE C6 "+
312 | "UNION ALL SELECT DISTINCT 7 FROM `table2` WHERE C7"+
313 | ") AS t")
314 | }
315 |
316 | func Test_selectStatus_NaturalJoin(t *testing.T) {
317 | db := newMockDatabase()
318 | table1 := NewTable("table1")
319 | table2 := NewTable("table2")
320 | table3 := NewTable("table3")
321 | table4 := NewTable("table4")
322 | cond1 := Raw("")
323 | cond2 := Raw("")
324 | cond3 := Raw("")
325 |
326 | _, _ = db.SelectFrom(table1).NaturalJoin(table2).Where(cond1).FetchAll()
327 | assertLastSql(t, "SELECT * FROM `table1` NATURAL JOIN `table2` WHERE ")
328 | _, _ = db.SelectFrom(table1).
329 | Join(table2).On(cond1).
330 | NaturalJoin(table2).
331 | NaturalJoin(table3).
332 | LeftJoin(table4).On(cond3).
333 | Where(cond2).FetchAll()
334 | assertLastSql(t, "SELECT * FROM `table1` JOIN `table2` ON NATURAL JOIN `table2`"+
335 | " NATURAL JOIN `table3` LEFT JOIN `table4` ON WHERE ")
336 |
337 | }
338 |
--------------------------------------------------------------------------------
/generator/generator.go:
--------------------------------------------------------------------------------
1 | package generator
2 |
3 | import (
4 | "database/sql"
5 | "errors"
6 | "fmt"
7 | "go/format"
8 | "os"
9 | "regexp"
10 | "strconv"
11 | "strings"
12 | "sync"
13 | "sync/atomic"
14 | "unicode"
15 | )
16 |
// sqlingoGeneratorVersion is written into every generated file. The generated
// code only compiles when it equals sqlingo.SqlingoRuntimeVersion.
const sqlingoGeneratorVersion = 2
20 |
21 | type schemaFetcher interface {
22 | GetDatabaseName() (dbName string, err error)
23 | GetTableNames() (tableNames []string, err error)
24 | GetFieldDescriptors(tableName string) ([]fieldDescriptor, error)
25 | QuoteIdentifier(identifier string) string
26 | }
27 |
// fieldDescriptor describes one table column as reported by a schemaFetcher.
type fieldDescriptor struct {
	// Name is the column name; Type is its SQL type name, which getType
	// matches case-insensitively.
	Name, Type string
	// Size is consulted by getType only for bit/bool columns (1 maps to bool).
	Size int
	// Unsigned turns a generated intN Go type into uintN.
	Unsigned bool
	// AllowNull makes the generated model field a pointer.
	AllowNull bool
	// Comment is the column comment, copied into the generated code.
	Comment string
}
36 |
// convertToExportedIdentifier converts a database identifier such as
// "user_name" into an exported Go identifier such as "UserName".
//
// Runes that are neither letters nor digits separate words. The first rune of
// each word is upper-cased and the rest is kept as-is. A word that equals an
// entry of forceCases case-insensitively is replaced by that entry (e.g.
// forceCases ["ID"] turns "user_id" into "UserID"). If the result is empty or
// does not start with an upper-case letter, it is prefixed with "E".
func convertToExportedIdentifier(s string, forceCases []string) string {
	var words []string
	startWord := true
	for _, r := range s {
		if !unicode.IsLetter(r) && !unicode.IsDigit(r) {
			startWord = true
			continue
		}
		if startWord {
			words = append(words, string(unicode.ToUpper(r)))
			startWord = false
			continue
		}
		words[len(words)-1] += string(r)
	}

	var sb strings.Builder
	for _, word := range words {
		for _, forced := range forceCases {
			if strings.EqualFold(word, forced) {
				word = forced
				break
			}
		}
		sb.WriteString(word)
	}

	result := sb.String()
	if result == "" {
		return "E"
	}
	if first := []rune(result)[0]; !unicode.IsUpper(first) {
		return "E" + result
	}
	return result
}
73 |
74 | func getType(fieldDescriptor fieldDescriptor) (goType string, fieldClass string, fieldComment string, err error) {
75 | switch strings.ToLower(fieldDescriptor.Type) {
76 | case "tinyint":
77 | goType = "int8"
78 | fieldClass = "NumberField"
79 | case "smallint":
80 | goType = "int16"
81 | fieldClass = "NumberField"
82 | case "int", "mediumint":
83 | goType = "int32"
84 | fieldClass = "NumberField"
85 | case "bigint", "integer":
86 | goType = "int64"
87 | fieldClass = "NumberField"
88 | case "float", "double", "decimal", "real":
89 | goType = "float64"
90 | fieldClass = "NumberField"
91 | case "char", "varchar", "text", "tinytext", "mediumtext", "longtext", "enum", "date", "time", "json", "numeric", "character varying", "timestamp without time zone", "timestamp with time zone", "jsonb", "uuid":
92 | goType = "string"
93 | fieldClass = "StringField"
94 | case "year":
95 | goType = "int16"
96 | fieldClass = "NumberField"
97 | fieldDescriptor.Unsigned = true
98 | case "binary", "varbinary", "blob", "tinyblob", "mediumblob", "longblob":
99 | // TODO: use []byte ?
100 | goType = "string"
101 | fieldClass = "StringField"
102 | case "array":
103 | // TODO: Switch to specific type instead of interface.
104 | goType = "[]interface{}"
105 | fieldClass = "ArrayField"
106 | case "timestamp":
107 | if !timeAsString {
108 | goType = "time.Time"
109 | fieldClass = "DateField"
110 | fieldComment = "NOTICE: the range of timestamp is [1970-01-01 08:00:01, 2038-01-19 11:14:07]"
111 | } else {
112 | goType = "string"
113 | fieldClass = "StringField"
114 | }
115 | case "datetime":
116 | if !timeAsString {
117 | goType = "time.Time"
118 | fieldClass = "DateField"
119 | fieldComment = "NOTICE: the range of datetime is [0000-01-01 00:00:00, 9999-12-31 23:59:59]"
120 | } else {
121 | goType = "string"
122 | fieldClass = "StringField"
123 | }
124 | case "geometry", "point", "linestring", "polygon", "multipoint", "multilinestring", "multipolygon", "geometrycollection":
125 | goType = "sqlingo.WellKnownBinary"
126 | fieldClass = "WellKnownBinaryField"
127 | case "bit", "bool", "boolean":
128 | if fieldDescriptor.Size == 1 {
129 | goType = "bool"
130 | fieldClass = "BooleanField"
131 | } else {
132 | goType = "string"
133 | fieldClass = "StringField"
134 | }
135 | default:
136 | err = fmt.Errorf("unknown field type %s", fieldDescriptor.Type)
137 | return
138 | }
139 | if fieldDescriptor.Unsigned && strings.HasPrefix(goType, "int") {
140 | goType = "u" + goType
141 | }
142 | if fieldDescriptor.AllowNull {
143 | goType = "*" + goType
144 | }
145 | return
146 | }
147 |
148 | func getSchemaFetcherFactory(driverName string) func(db *sql.DB) schemaFetcher {
149 | switch driverName {
150 | case "mysql":
151 | return newMySQLSchemaFetcher
152 | case "sqlite3":
153 | return newSQLite3SchemaFetcher
154 | case "postgres":
155 | return newPostgresSchemaFetcher
156 | default:
157 | _, _ = fmt.Fprintln(os.Stderr, "unsupported driver "+driverName)
158 | os.Exit(2)
159 | return nil
160 | }
161 | }
162 |
// nonIdentifierRegexp matches any single character outside [0-9A-Za-z_].
var nonIdentifierRegexp = regexp.MustCompile(`\W`)

// ensureIdentifier turns name into a valid Go identifier (used for the
// generated package name). Every non-word character becomes "_", and a
// leading "_" is added when the result is empty or starts with a digit.
func ensureIdentifier(name string) string {
	sanitized := nonIdentifierRegexp.ReplaceAllString(name, "_")
	if sanitized != "" && (sanitized[0] < '0' || sanitized[0] > '9') {
		return sanitized
	}
	return "_" + sanitized
}
172 |
173 | // Generate generates code for the given driverName.
174 | func Generate(driverName string, exampleDataSourceName string) (string, error) {
175 | options := parseArgs(exampleDataSourceName)
176 |
177 | db, err := sql.Open(driverName, options.dataSourceName)
178 | if err != nil {
179 | return "", err
180 | }
181 | db.SetMaxOpenConns(10)
182 |
183 | schemaFetcherFactory := getSchemaFetcherFactory(driverName)
184 | schemaFetcher := schemaFetcherFactory(db)
185 |
186 | dbName, err := schemaFetcher.GetDatabaseName()
187 | if err != nil {
188 | return "", err
189 | }
190 |
191 | if dbName == "" {
192 | return "", errors.New("no database selected")
193 | }
194 |
195 | if len(options.tableNames) == 0 {
196 | options.tableNames, err = schemaFetcher.GetTableNames()
197 | if err != nil {
198 | return "", err
199 | }
200 | }
201 |
202 | needImportTime := false
203 | for _, tableName := range options.tableNames {
204 | fieldDescriptors, err := schemaFetcher.GetFieldDescriptors(tableName)
205 | if err != nil {
206 | return "", err
207 | }
208 | for _, fieldDescriptor := range fieldDescriptors {
209 | if !timeAsString && fieldDescriptor.Type == "datetime" || fieldDescriptor.Type == "timestamp" {
210 | needImportTime = true
211 | break
212 | }
213 | }
214 | }
215 |
216 | code := "// This file is generated by sqlingo (https://github.com/lqs/sqlingo)\n"
217 | code += "// DO NOT EDIT.\n\n"
218 | code += "package " + ensureIdentifier(dbName) + "_dsl\n"
219 | if needImportTime {
220 | code += "import (\n"
221 | code += "\t\"time\"\n"
222 | code += "\t\"github.com/lqs/sqlingo\"\n"
223 | code += ")\n\n"
224 | } else {
225 | code += "import \"github.com/lqs/sqlingo\"\n\n"
226 | }
227 |
228 | code += "type sqlingoRuntimeAndGeneratorVersionsShouldBeTheSame uint32\n\n"
229 |
230 | sqlingoGeneratorVersionString := strconv.Itoa(sqlingoGeneratorVersion)
231 | code += "const _ = sqlingoRuntimeAndGeneratorVersionsShouldBeTheSame(sqlingo.SqlingoRuntimeVersion - " + sqlingoGeneratorVersionString + ")\n"
232 | code += "const _ = sqlingoRuntimeAndGeneratorVersionsShouldBeTheSame(" + sqlingoGeneratorVersionString + " - sqlingo.SqlingoRuntimeVersion)\n\n"
233 |
234 | code += "type table interface {\n"
235 | code += "\tsqlingo.Table\n"
236 | code += "}\n\n"
237 |
238 | code += "type numberField interface {\n"
239 | code += "\tsqlingo.NumberField\n"
240 | code += "}\n\n"
241 |
242 | code += "type stringField interface {\n"
243 | code += "\tsqlingo.StringField\n"
244 | code += "}\n\n"
245 |
246 | code += "type booleanField interface {\n"
247 | code += "\tsqlingo.BooleanField\n"
248 | code += "}\n\n"
249 |
250 | code += "type arrayField interface {\n"
251 | code += "\tsqlingo.ArrayField\n"
252 | code += "}\n\n"
253 |
254 | code += "type dateField interface {\n"
255 | code += "\tsqlingo.DateField\n"
256 | code += "}\n\n"
257 |
258 | var wg sync.WaitGroup
259 |
260 | type tableCodeItem struct {
261 | code string
262 | err error
263 | }
264 | tableCodeMap := make(map[string]*tableCodeItem)
265 | fmt.Fprintln(os.Stderr, "Generating code for tables...")
266 | var counter int32
267 | for _, tableName := range options.tableNames {
268 | wg.Add(1)
269 | item := &tableCodeItem{}
270 | tableCodeMap[tableName] = item
271 | go func(tableName string) {
272 | defer wg.Done()
273 | tableCode, err := generateTable(schemaFetcher, tableName, options.forceCases)
274 | if err != nil {
275 | item.err = err
276 | return
277 | }
278 | _, _ = fmt.Fprintf(os.Stderr, "Generated (%d/%d) %s\n", atomic.AddInt32(&counter, 1), len(options.tableNames), tableName)
279 | item.code = tableCode
280 | }(tableName)
281 | }
282 | wg.Wait()
283 | for _, tableName := range options.tableNames {
284 | item := tableCodeMap[tableName]
285 | if item.err != nil {
286 | return "", item.err
287 | }
288 | code += item.code
289 | }
290 | code += generateGetTable(options)
291 | codeOut, err := format.Source([]byte(code))
292 | if err != nil {
293 | return "", err
294 | }
295 | return string(codeOut), nil
296 | }
297 |
298 | func generateGetTable(options options) string {
299 | code := "func GetTable(name string) sqlingo.Table {\n"
300 | code += "\tswitch name {\n"
301 | for _, tableName := range options.tableNames {
302 | code += "\tcase " + strconv.Quote(tableName) + ": return " + convertToExportedIdentifier(tableName, options.forceCases) + "\n"
303 | }
304 | code += "\tdefault: return nil\n"
305 | code += "\t}\n"
306 | code += "}\n\n"
307 |
308 | code += "func GetTables() []sqlingo.Table {\n"
309 | code += "\treturn []sqlingo.Table{\n"
310 | for _, tableName := range options.tableNames {
311 | code += "\t" + convertToExportedIdentifier(tableName, options.forceCases) + ",\n"
312 | }
313 | code += "\t}"
314 | code += "}\n\n"
315 |
316 | return code
317 | }
318 |
319 | func generateTable(schemaFetcher schemaFetcher, tableName string, forceCases []string) (string, error) {
320 | fieldDescriptors, err := schemaFetcher.GetFieldDescriptors(tableName)
321 | if err != nil {
322 | return "", err
323 | }
324 |
325 | className := convertToExportedIdentifier(tableName, forceCases)
326 | tableStructName := "t" + className
327 | tableObjectName := "o" + className
328 |
329 | modelClassName := className + "Model"
330 |
331 | tableLines := ""
332 | modelLines := ""
333 | objectLines := "\ttable: " + tableObjectName + ",\n\n"
334 | fieldCaseLines := ""
335 | classLines := ""
336 |
337 | fields := ""
338 | fieldsSQL := ""
339 | fullFieldsSQL := ""
340 | values := ""
341 |
342 | for _, fieldDescriptor := range fieldDescriptors {
343 |
344 | goName := convertToExportedIdentifier(fieldDescriptor.Name, forceCases)
345 | goType, fieldClass, typeComment, err := getType(fieldDescriptor)
346 | if err != nil {
347 | return "", err
348 | }
349 |
350 | privateFieldClass := string(fieldClass[0]+'a'-'A') + fieldClass[1:]
351 |
352 | commentLine := ""
353 | if fieldDescriptor.Comment != "" {
354 | commentLine = "\t// " + strings.ReplaceAll(fieldDescriptor.Comment, "\n", " ") + "\n"
355 | }
356 | if typeComment != "" {
357 | commentLine = "\t// " + typeComment + "\n"
358 | }
359 |
360 | fieldStructName := strings.ToLower(replaceTypeSpace(fieldDescriptor.Type)) + "_" + className + "_" + goName
361 |
362 | tableLines += commentLine
363 | tableLines += "\t" + goName + " " + fieldStructName + "\n"
364 |
365 | modelLines += commentLine
366 | modelLines += "\t" + goName + " " + goType + "\n"
367 |
368 | objectLines += commentLine
369 | objectLines += "\t" + goName + ": " + fieldStructName + "{"
370 | objectLines += "sqlingo.New" + fieldClass + "(" + tableObjectName + ", " + strconv.Quote(fieldDescriptor.Name) + ")},\n"
371 |
372 | fieldCaseLines += "\tcase " + strconv.Quote(fieldDescriptor.Name) + ": return t." + goName + "\n"
373 |
374 | classLines += "type " + fieldStructName + " struct{ " + privateFieldClass + " }\n"
375 |
376 | fields += "t." + goName + ", "
377 |
378 | if fieldsSQL != "" {
379 | fieldsSQL += ", "
380 | }
381 | fieldsSQL += schemaFetcher.QuoteIdentifier(fieldDescriptor.Name)
382 |
383 | if fullFieldsSQL != "" {
384 | fullFieldsSQL += ", "
385 | }
386 | fullFieldsSQL += schemaFetcher.QuoteIdentifier(tableName) + "." + schemaFetcher.QuoteIdentifier(fieldDescriptor.Name)
387 |
388 | values += "m." + goName + ", "
389 | }
390 | code := ""
391 | code += "type " + tableStructName + " struct {\n\ttable\n\n"
392 | code += tableLines
393 | code += "}\n\n"
394 |
395 | code += classLines
396 |
397 | code += "var " + tableObjectName + " = sqlingo.NewTable(" + strconv.Quote(tableName) + ")\n"
398 | code += "var " + className + " = " + tableStructName + "{\n"
399 | code += objectLines
400 | code += "}\n\n"
401 |
402 | code += "func (t t" + className + ") GetFields() []sqlingo.Field {\n"
403 | code += "\treturn []sqlingo.Field{" + fields + "}\n"
404 | code += "}\n\n"
405 |
406 | code += "func (t t" + className + ") GetFieldByName(name string) sqlingo.Field {\n"
407 | code += "\tswitch name {\n"
408 | code += fieldCaseLines
409 | code += "\tdefault: return nil\n"
410 | code += "\t}\n"
411 | code += "}\n\n"
412 |
413 | code += "func (t t" + className + ") GetFieldsSQL() string {\n"
414 | code += "\treturn " + strconv.Quote(fieldsSQL) + "\n"
415 | code += "}\n\n"
416 |
417 | code += "func (t t" + className + ") GetFullFieldsSQL() string {\n"
418 | code += "\treturn " + strconv.Quote(fullFieldsSQL) + "\n"
419 | code += "}\n\n"
420 |
421 | code += "type " + modelClassName + " struct {\n"
422 | code += modelLines
423 | code += "}\n\n"
424 |
425 | code += "func (m " + modelClassName + ") GetTable() sqlingo.Table {\n"
426 | code += "\treturn " + className + "\n"
427 | code += "}\n\n"
428 |
429 | code += "func (m " + modelClassName + ") GetValues() []interface{} {\n"
430 | code += "\treturn []interface{}{" + values + "}\n"
431 | code += "}\n\n"
432 | return code, nil
433 | }
434 |
// replaceTypeSpace replaces every space in a type name with "_" so that
// multi-word PostgreSQL types (e.g. "character varying",
// "timestamp with time zone") can be used inside generated identifiers.
func replaceTypeSpace(typename string) string {
	parts := strings.Split(typename, " ")
	return strings.Join(parts, "_")
}
440 |
--------------------------------------------------------------------------------
/select.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "reflect"
7 | "strconv"
8 | "strings"
9 | )
10 |
11 | type selectWithFields interface {
12 | toSelectWithContext
13 | toSelectFinal
14 | From(tables ...Table) selectWithTables
15 | }
16 |
17 | type selectWithTables interface {
18 | toSelectJoin
19 | toSelectWhere
20 | toSelectWithLock
21 | toSelectWithContext
22 | toSelectFinal
23 | toUnionSelect
24 | GroupBy(expressions ...Expression) selectWithGroupBy
25 | OrderBy(orderBys ...OrderBy) selectWithOrder
26 | Limit(limit int) selectWithLimit
27 | }
28 |
29 | type toSelectJoin interface {
30 | Join(table Table) selectWithJoin
31 | LeftJoin(table Table) selectWithJoin
32 | RightJoin(table Table) selectWithJoin
33 | NaturalJoin(table Table) selectWithJoinOn
34 | }
35 |
36 | type selectWithJoin interface {
37 | On(condition BooleanExpression) selectWithJoinOn
38 | }
39 |
40 | type selectWithJoinOn interface {
41 | toSelectWhere
42 | toSelectWithLock
43 | toSelectWithContext
44 | toSelectFinal
45 | toUnionSelect
46 | toSelectJoin
47 | GroupBy(expressions ...Expression) selectWithGroupBy
48 | OrderBy(orderBys ...OrderBy) selectWithOrder
49 | Limit(limit int) selectWithLimit
50 | }
51 |
52 | type toSelectWhere interface {
53 | Where(conditions ...BooleanExpression) selectWithWhere
54 | WhereIf(prerequisite bool, conditions ...BooleanExpression) selectWithWhere
55 | }
56 |
57 | type selectWithWhere interface {
58 | toSelectWhere
59 | toSelectWithLock
60 | toSelectWithContext
61 | toSelectFinal
62 | toUnionSelect
63 | GroupBy(expressions ...Expression) selectWithGroupBy
64 | OrderBy(orderBys ...OrderBy) selectWithOrder
65 | Limit(limit int) selectWithLimit
66 | }
67 |
68 | type selectWithGroupBy interface {
69 | toSelectWithLock
70 | toSelectWithContext
71 | toSelectFinal
72 | toUnionSelect
73 | Having(conditions ...BooleanExpression) selectWithGroupByHaving
74 | OrderBy(orderBys ...OrderBy) selectWithOrder
75 | Limit(limit int) selectWithLimit
76 | }
77 |
78 | type selectWithGroupByHaving interface {
79 | toSelectWithLock
80 | toSelectWithContext
81 | toSelectFinal
82 | toUnionSelect
83 | OrderBy(orderBys ...OrderBy) selectWithOrder
84 | }
85 |
86 | type selectWithOrder interface {
87 | toSelectWithLock
88 | toSelectWithContext
89 | toSelectFinal
90 | Limit(limit int) selectWithLimit
91 | }
92 |
93 | type selectWithLimit interface {
94 | toSelectWithLock
95 | toSelectWithContext
96 | toSelectFinal
97 | Offset(offset int) selectWithOffset
98 | }
99 |
100 | type selectWithOffset interface {
101 | toSelectWithLock
102 | toSelectWithContext
103 | toSelectFinal
104 | }
105 |
106 | type toSelectWithLock interface {
107 | LockInShareMode() selectWithLock
108 | ForUpdate() selectWithLock
109 | ForUpdateNoWait() selectWithLock
110 | ForUpdateSkipLocked() selectWithLock
111 | }
112 |
113 | type selectWithLock interface {
114 | toSelectWithContext
115 | toSelectFinal
116 | }
117 |
118 | type toSelectWithContext interface {
119 | WithContext(ctx context.Context) toSelectFinal
120 | }
121 |
122 | type toUnionSelect interface {
123 | UnionSelect(fields ...interface{}) selectWithFields
124 | UnionSelectFrom(tables ...Table) selectWithTables
125 | UnionSelectDistinct(fields ...interface{}) selectWithFields
126 | UnionAllSelect(fields ...interface{}) selectWithFields
127 | UnionAllSelectFrom(tables ...Table) selectWithTables
128 | UnionAllSelectDistinct(fields ...interface{}) selectWithFields
129 | }
130 |
131 | type toSelectFinal interface {
132 | Exists() (bool, error)
133 | Count() (int, error)
134 | GetSQL() (string, error)
135 | FetchFirst(out ...interface{}) (bool, error)
136 | FetchExactlyOne(out ...interface{}) error
137 | FetchAll(dest ...interface{}) (rows int, err error)
138 | FetchCursor() (Cursor, error)
139 | FetchSeq() func(yield func(row Scanner) bool) // use with "range over function" in Go 1.22
140 | }
141 |
142 | type join struct {
143 | previous *join
144 | prefix string
145 | table Table
146 | on BooleanExpression
147 | }
148 |
149 | type selectBase struct {
150 | scope scope
151 | distinct bool
152 | fields fieldList
153 | where BooleanExpression
154 | groupBys []Expression
155 | having BooleanExpression
156 | }
157 |
158 | type selectStatus struct {
159 | base selectBase
160 | orderBys []OrderBy
161 | lastUnion *unionSelectStatus
162 | limit *int
163 | offset int
164 | ctx context.Context
165 | lock string
166 | }
167 |
// errorScanner stands in for a row when a query could not be started. Every
// Scan call reports the stored error.
type errorScanner struct {
	err error
}

// Scan ignores its destinations and returns the stored error.
func (s errorScanner) Scan(_ ...interface{}) error {
	return s.err
}
175 |
176 | func (s selectStatus) FetchSeq() func(yield func(row Scanner) bool) {
177 | return func(yield func(row Scanner) bool) {
178 | cursor, err := s.FetchCursor()
179 | if err != nil {
180 | yield(errorScanner{err})
181 | return
182 | }
183 |
184 | defer cursor.Close()
185 | for cursor.Next() {
186 | if !yield(cursor) {
187 | break
188 | }
189 | }
190 | }
191 | }
192 |
193 | type unionSelectStatus struct {
194 | base selectBase
195 | all bool
196 | previous *unionSelectStatus
197 | }
198 |
199 | func activeSelectBase(s *selectStatus) *selectBase {
200 | if s.lastUnion != nil {
201 | return &s.lastUnion.base
202 | }
203 | return &s.base
204 | }
205 |
206 | func (s selectStatus) Join(table Table) selectWithJoin {
207 | return s.join("", table)
208 | }
209 |
210 | func (s selectStatus) LeftJoin(table Table) selectWithJoin {
211 | return s.join("LEFT ", table)
212 | }
213 |
214 | func (s selectStatus) RightJoin(table Table) selectWithJoin {
215 | return s.join("RIGHT ", table)
216 | }
217 |
218 | // NaturalJoin joins the table using the NATURAL keyword.
219 | // it automatically matches the columns in the two tables that have the same name.
220 | // it not be needed but be provided for completeness.
221 | func (s selectStatus) NaturalJoin(table Table) selectWithJoinOn {
222 | base := activeSelectBase(&s)
223 | base.scope.lastJoin = &join{
224 | previous: base.scope.lastJoin,
225 | prefix: "NATURAL ",
226 | table: table,
227 | }
228 | join := *base.scope.lastJoin
229 | base.scope.lastJoin = &join
230 | return s
231 | }
232 |
233 | func (s selectStatus) join(prefix string, table Table) selectWithJoin {
234 | base := activeSelectBase(&s)
235 | base.scope.lastJoin = &join{
236 | previous: base.scope.lastJoin,
237 | prefix: prefix,
238 | table: table,
239 | }
240 | return s
241 | }
242 |
243 | func (s selectStatus) On(condition BooleanExpression) selectWithJoinOn {
244 | base := activeSelectBase(&s)
245 | join := *base.scope.lastJoin
246 | join.on = condition
247 | base.scope.lastJoin = &join
248 | return s
249 | }
250 |
251 | func getFields(fields []interface{}) (result []Field) {
252 | fields = expandSliceValues(fields)
253 | result = make([]Field, 0, len(fields))
254 | for _, field := range fields {
255 | switch field.(type) {
256 | case Field:
257 | result = append(result, field.(Field))
258 | case Table:
259 | result = append(result, field.(Table).GetFields()...)
260 | default:
261 | fieldCopy := field
262 | fieldExpression := expression{builder: func(scope scope) (string, error) {
263 | sql, _, err := getSQL(scope, fieldCopy)
264 | if err != nil {
265 | return "", err
266 | }
267 | return sql, nil
268 | }}
269 | result = append(result, fieldExpression)
270 | }
271 | }
272 | return
273 | }
274 |
275 | func (d *database) Select(fields ...interface{}) selectWithFields {
276 | return selectStatus{
277 | base: selectBase{
278 | scope: scope{
279 | Database: d,
280 | },
281 | fields: getFields(fields),
282 | },
283 | }
284 | }
285 |
286 | func (s selectStatus) From(tables ...Table) selectWithTables {
287 | activeSelectBase(&s).scope.Tables = tables
288 | return s
289 | }
290 |
291 | func (d *database) SelectFrom(tables ...Table) selectWithTables {
292 | return selectStatus{
293 | base: selectBase{
294 | scope: scope{
295 | Database: d,
296 | Tables: tables,
297 | },
298 | },
299 | }
300 | }
301 |
302 | func (d *database) SelectDistinct(fields ...interface{}) selectWithFields {
303 | return selectStatus{
304 | base: selectBase{
305 | scope: scope{
306 | Database: d,
307 | },
308 | fields: getFields(fields),
309 | distinct: true,
310 | },
311 | }
312 | }
313 |
314 | func (s selectStatus) Where(conditions ...BooleanExpression) selectWithWhere {
315 | if base := activeSelectBase(&s); base.where == nil {
316 | base.where = And(conditions...)
317 | } else {
318 | base.where = And(base.where, And(conditions...))
319 | }
320 | return s
321 | }
322 |
323 | func (s selectStatus) WhereIf(prerequisite bool, conditions ...BooleanExpression) selectWithWhere {
324 | if !prerequisite {
325 | return s
326 | }
327 | if base := activeSelectBase(&s); base.where == nil {
328 | base.where = And(conditions...)
329 | } else {
330 | base.where = And(base.where, And(conditions...))
331 | }
332 | return s
333 | }
334 |
335 | func (s selectStatus) GroupBy(expressions ...Expression) selectWithGroupBy {
336 | activeSelectBase(&s).groupBys = expressions
337 | return s
338 | }
339 |
340 | func (s selectStatus) Having(conditions ...BooleanExpression) selectWithGroupByHaving {
341 | activeSelectBase(&s).having = And(conditions...)
342 | return s
343 | }
344 |
345 | func (s selectStatus) UnionSelect(fields ...interface{}) selectWithFields {
346 | return s.withUnionSelect(false, false, fields, nil)
347 | }
348 |
349 | func (s selectStatus) UnionSelectFrom(tables ...Table) selectWithTables {
350 | return s.withUnionSelect(false, false, nil, tables)
351 | }
352 |
353 | func (s selectStatus) UnionSelectDistinct(fields ...interface{}) selectWithFields {
354 | return s.withUnionSelect(false, true, fields, nil)
355 | }
356 |
357 | func (s selectStatus) UnionAllSelect(fields ...interface{}) selectWithFields {
358 | return s.withUnionSelect(true, false, fields, nil)
359 | }
360 |
361 | func (s selectStatus) UnionAllSelectFrom(tables ...Table) selectWithTables {
362 | return s.withUnionSelect(true, false, nil, tables)
363 | }
364 |
365 | func (s selectStatus) UnionAllSelectDistinct(fields ...interface{}) selectWithFields {
366 | return s.withUnionSelect(true, true, fields, nil)
367 | }
368 |
369 | func (s selectStatus) withUnionSelect(all bool, distinct bool, fields []interface{}, tables []Table) selectStatus {
370 | s.lastUnion = &unionSelectStatus{
371 | base: selectBase{
372 | scope: scope{
373 | Database: s.base.scope.Database,
374 | Tables: tables,
375 | },
376 | distinct: distinct,
377 | fields: getFields(fields),
378 | },
379 | all: all,
380 | previous: s.lastUnion,
381 | }
382 | return s
383 | }
384 |
385 | func (s selectStatus) OrderBy(orderBys ...OrderBy) selectWithOrder {
386 | s.orderBys = orderBys
387 | return s
388 | }
389 |
390 | func (s selectStatus) Limit(limit int) selectWithLimit {
391 | s.limit = &limit
392 | return s
393 | }
394 |
395 | func (s selectStatus) Offset(offset int) selectWithOffset {
396 | s.offset = offset
397 | return s
398 | }
399 |
400 | func (s selectStatus) Count() (count int, err error) {
401 | if s.lastUnion == nil && len(s.base.groupBys) == 0 && s.limit == nil {
402 | if s.base.distinct {
403 | fields := s.base.fields
404 | s.base.distinct = false
405 | s.base.fields = []Field{expression{builder: func(scope scope) (string, error) {
406 | fieldsSql, err := fields.GetSQL(scope)
407 | if err != nil {
408 | return "", err
409 | }
410 | return "COUNT(DISTINCT " + fieldsSql + ")", nil
411 | }}}
412 | _, err = s.FetchFirst(&count)
413 | } else {
414 | s.base.fields = []Field{staticExpression("COUNT(1)", 0, false)}
415 | _, err = s.FetchFirst(&count)
416 | }
417 | } else {
418 | if !s.base.distinct {
419 | s.base.fields = []Field{staticExpression("1", 0, false)}
420 | }
421 | _, err = s.base.scope.Database.Select(Function("COUNT", 1)).
422 | From(s.asDerivedTable("t")).
423 | FetchFirst(&count)
424 | }
425 |
426 | return
427 | }
428 |
429 | func (s selectStatus) LockInShareMode() selectWithLock {
430 | s.lock = " LOCK IN SHARE MODE"
431 | return s
432 | }
433 |
434 | func (s selectStatus) ForUpdate() selectWithLock {
435 | s.lock = " FOR UPDATE"
436 | return s
437 | }
438 |
439 | func (s selectStatus) ForUpdateNoWait() selectWithLock {
440 | s.lock = " FOR UPDATE NOWAIT"
441 | return s
442 | }
443 |
444 | func (s selectStatus) ForUpdateSkipLocked() selectWithLock {
445 | s.lock = " FOR UPDATE SKIP LOCKED"
446 | return s
447 | }
448 |
449 | func (s selectStatus) asDerivedTable(name string) Table {
450 | return derivedTable{
451 | name: name,
452 | selectStatus: s,
453 | }
454 | }
455 |
456 | func (s selectStatus) Exists() (exists bool, err error) {
457 | _, err = s.base.scope.Database.Select(command("EXISTS", s)).FetchFirst(&exists)
458 | return
459 | }
460 |
461 | func (s selectBase) buildSelectBase(sb *strings.Builder) error {
462 | sb.WriteString("SELECT ")
463 | if s.distinct {
464 | sb.WriteString("DISTINCT ")
465 | }
466 |
467 | // find tables from fields if "From" is not specified
468 | if len(s.scope.Tables) == 0 && len(s.fields) > 0 {
469 | tableNames := make([]string, 0, len(s.fields))
470 | tableMap := make(map[string]Table)
471 | for _, field := range s.fields {
472 | table := field.GetTable()
473 | if table == nil {
474 | continue
475 | }
476 | tableName := table.GetName()
477 | if _, ok := tableMap[tableName]; !ok {
478 | tableMap[tableName] = table
479 | tableNames = append(tableNames, tableName)
480 | }
481 | }
482 | for _, tableName := range tableNames {
483 | table := tableMap[tableName]
484 | s.scope.Tables = append(s.scope.Tables, table)
485 | }
486 | }
487 |
488 | fieldsSql, err := s.fields.GetSQL(s.scope)
489 | if err != nil {
490 | return err
491 | }
492 | sb.WriteString(fieldsSql)
493 |
494 | if len(s.scope.Tables) > 0 {
495 | fromSql := commaTables(s.scope, s.scope.Tables)
496 | sb.WriteString(" FROM ")
497 | sb.WriteString(fromSql)
498 | }
499 |
500 | if s.scope.lastJoin != nil {
501 | var joins []*join
502 | for j := s.scope.lastJoin; j != nil; j = j.previous {
503 | joins = append(joins, j)
504 | }
505 | for i := len(joins) - 1; i >= 0; i-- {
506 | join := joins[i]
507 | sb.WriteString(" ")
508 | sb.WriteString(join.prefix)
509 | sb.WriteString("JOIN ")
510 | sb.WriteString(join.table.GetSQL(s.scope))
511 | // cause on isn't a required part of join when using natural join,
512 | // so move it to if statement
513 | if join.on != nil {
514 | onSql, err := join.on.GetSQL(s.scope)
515 | if err != nil {
516 | return err
517 | }
518 | sb.WriteString(" ON ")
519 | sb.WriteString(onSql)
520 | }
521 | }
522 | }
523 |
524 | if err := appendWhere(sb, s.scope, s.where); err != nil {
525 | return err
526 | }
527 |
528 | if len(s.groupBys) != 0 {
529 | groupBySql, err := commaExpressions(s.scope, s.groupBys)
530 | if err != nil {
531 | return err
532 | }
533 | sb.WriteString(" GROUP BY ")
534 | sb.WriteString(groupBySql)
535 |
536 | if s.having != nil {
537 | havingSql, err := s.having.GetSQL(s.scope)
538 | if err != nil {
539 | return err
540 | }
541 | sb.WriteString(" HAVING ")
542 | sb.WriteString(havingSql)
543 | }
544 | }
545 |
546 | return nil
547 | }
548 |
549 | func (s selectStatus) GetSQL() (string, error) {
550 | var sb strings.Builder
551 | sb.Grow(128)
552 |
553 | if err := s.base.buildSelectBase(&sb); err != nil {
554 | return "", err
555 | }
556 |
557 | var unions []*unionSelectStatus
558 | for union := s.lastUnion; union != nil; union = union.previous {
559 | unions = append(unions, union)
560 | }
561 | for i := len(unions) - 1; i >= 0; i-- {
562 | union := unions[i]
563 | if union.all {
564 | sb.WriteString(" UNION ALL ")
565 | } else {
566 | sb.WriteString(" UNION ")
567 | }
568 | if err := union.base.buildSelectBase(&sb); err != nil {
569 | return "", err
570 | }
571 | }
572 |
573 | if len(s.orderBys) > 0 {
574 | orderBySql, err := commaOrderBys(s.base.scope, s.orderBys)
575 | if err != nil {
576 | return "", err
577 | }
578 | sb.WriteString(" ORDER BY ")
579 | sb.WriteString(orderBySql)
580 | }
581 |
582 | if s.limit != nil {
583 | sb.WriteString(" LIMIT ")
584 | sb.WriteString(strconv.Itoa(*s.limit))
585 | }
586 |
587 | if s.offset != 0 {
588 | sb.WriteString(" OFFSET ")
589 | sb.WriteString(strconv.Itoa(s.offset))
590 | }
591 |
592 | sb.WriteString(s.lock)
593 |
594 | return sb.String(), nil
595 | }
596 |
597 | func (s selectStatus) WithContext(ctx context.Context) toSelectFinal {
598 | s.ctx = ctx
599 | return s
600 | }
601 |
602 | func (s selectStatus) FetchCursor() (Cursor, error) {
603 | sqlString, err := s.GetSQL()
604 | if err != nil {
605 | return nil, err
606 | }
607 |
608 | cursor, err := s.base.scope.Database.QueryContext(s.ctx, sqlString)
609 | if err != nil {
610 | return nil, err
611 | }
612 | return cursor, nil
613 | }
614 |
615 | func (s selectStatus) FetchFirst(dest ...interface{}) (ok bool, err error) {
616 | cursor, err := s.FetchCursor()
617 | if err != nil {
618 | return
619 | }
620 | defer cursor.Close()
621 |
622 | for cursor.Next() {
623 | err = cursor.Scan(dest...)
624 | if err != nil {
625 | return
626 | }
627 | ok = true
628 | break
629 | }
630 |
631 | return
632 | }
633 |
634 | func (s selectStatus) FetchExactlyOne(dest ...interface{}) (err error) {
635 | cursor, err := s.FetchCursor()
636 | if err != nil {
637 | return
638 | }
639 | defer cursor.Close()
640 |
641 | hasResult := false
642 | for cursor.Next() {
643 | if hasResult {
644 | return errors.New("more than one rows")
645 | }
646 | err = cursor.Scan(dest...)
647 | if err != nil {
648 | return
649 | }
650 | hasResult = true
651 | }
652 | if !hasResult {
653 | err = errors.New("no rows")
654 | }
655 | return
656 | }
657 |
658 | func (s selectStatus) fetchAllAsMap(cursor Cursor, mapType reflect.Type) (mapValue reflect.Value, err error) {
659 | mapValue = reflect.MakeMap(mapType)
660 | key := reflect.New(mapType.Key())
661 | elem := reflect.New(mapType.Elem())
662 |
663 | for cursor.Next() {
664 | err = cursor.Scan(key.Interface(), elem.Interface())
665 | if err != nil {
666 | return
667 | }
668 |
669 | mapValue.SetMapIndex(reflect.Indirect(key), reflect.Indirect(elem))
670 | }
671 | return
672 | }
673 |
674 | func (s selectStatus) FetchAll(dest ...interface{}) (rows int, err error) {
675 | cursor, err := s.FetchCursor()
676 | if err != nil {
677 | return
678 | }
679 | defer cursor.Close()
680 |
681 | count := len(dest)
682 | values := make([]reflect.Value, count)
683 | for i, item := range dest {
684 | if reflect.ValueOf(item).Kind() != reflect.Ptr {
685 | err = errors.New("dest should be a pointer")
686 | return
687 | }
688 | val := reflect.Indirect(reflect.ValueOf(item))
689 |
690 | switch val.Kind() {
691 | case reflect.Slice:
692 | values[i] = val
693 | case reflect.Map:
694 | if len(dest) != 1 {
695 | err = errors.New("dest map should be 1 element")
696 | return
697 | }
698 | var mapValue reflect.Value
699 | mapValue, err = s.fetchAllAsMap(cursor, val.Type())
700 | if err != nil {
701 | return
702 | }
703 | reflect.ValueOf(item).Elem().Set(mapValue)
704 | return
705 | default:
706 | err = errors.New("dest should be pointed to a slice")
707 | return
708 | }
709 | }
710 |
711 | elements := make([]reflect.Value, count)
712 | pointers := make([]interface{}, count)
713 | for i := 0; i < count; i++ {
714 | elements[i] = reflect.New(values[i].Type().Elem())
715 | pointers[i] = elements[i].Interface()
716 | }
717 | for cursor.Next() {
718 | err = cursor.Scan(pointers...)
719 | if err != nil {
720 | return
721 | }
722 | for i := 0; i < count; i++ {
723 | values[i].Set(reflect.Append(values[i], reflect.Indirect(elements[i])))
724 | }
725 | rows++
726 | }
727 | return
728 | }
729 |
--------------------------------------------------------------------------------
/expression.go:
--------------------------------------------------------------------------------
1 | package sqlingo
2 |
3 | import (
4 | "fmt"
5 | "reflect"
6 | "strconv"
7 | "strings"
8 | "time"
9 | "unsafe"
10 | )
11 |
12 | type priority uint8
13 |
14 | // Expression is the interface of an SQL expression.
15 | type Expression interface {
16 | // get the SQL string
17 | GetSQL(scope scope) (string, error)
18 | getOperatorPriority() priority
19 |
20 | // <> operator
21 | NotEquals(other interface{}) BooleanExpression
22 | // == operator
23 | Equals(other interface{}) BooleanExpression
24 | // < operator
25 | LessThan(other interface{}) BooleanExpression
26 | // <= operator
27 | LessThanOrEquals(other interface{}) BooleanExpression
28 | // > operator
29 | GreaterThan(other interface{}) BooleanExpression
30 | // >= operator
31 | GreaterThanOrEquals(other interface{}) BooleanExpression
32 |
33 | IsNull() BooleanExpression
34 | IsNotNull() BooleanExpression
35 | IsTrue() BooleanExpression
36 | IsNotTrue() BooleanExpression
37 | IsFalse() BooleanExpression
38 | IsNotFalse() BooleanExpression
39 | In(values ...interface{}) BooleanExpression
40 | NotIn(values ...interface{}) BooleanExpression
41 | Between(min interface{}, max interface{}) BooleanExpression
42 | NotBetween(min interface{}, max interface{}) BooleanExpression
43 | Desc() OrderBy
44 |
45 | As(alias string) Alias
46 |
47 | If(trueValue interface{}, falseValue interface{}) UnknownExpression
48 | IfNull(altValue interface{}) UnknownExpression
49 | }
50 |
51 | // Alias is the interface of an table/column alias.
52 | type Alias interface {
53 | GetSQL(scope scope) (string, error)
54 | }
55 |
56 | // BooleanExpression is the interface of an SQL expression with boolean value.
57 | type BooleanExpression interface {
58 | Expression
59 | And(other interface{}) BooleanExpression
60 | Or(other interface{}) BooleanExpression
61 | Xor(other interface{}) BooleanExpression
62 | Not() BooleanExpression
63 | }
64 |
65 | // NumberExpression is the interface of an SQL expression with number value.
66 | type NumberExpression interface {
67 | Expression
68 | Add(other interface{}) NumberExpression
69 | Sub(other interface{}) NumberExpression
70 | Mul(other interface{}) NumberExpression
71 | Div(other interface{}) NumberExpression
72 | IntDiv(other interface{}) NumberExpression
73 | Mod(other interface{}) NumberExpression
74 |
75 | Sum() NumberExpression
76 | Avg() NumberExpression
77 | Min() UnknownExpression
78 | Max() UnknownExpression
79 | }
80 |
81 | // StringExpression is the interface of an SQL expression with string value.
82 | type StringExpression interface {
83 | Expression
84 | Min() UnknownExpression
85 | Max() UnknownExpression
86 | Like(other interface{}) BooleanExpression
87 | Contains(substring string) BooleanExpression
88 | Concat(other interface{}) StringExpression
89 | IfEmpty(altValue interface{}) StringExpression
90 | IsEmpty() BooleanExpression
91 | Lower() StringExpression
92 | Upper() StringExpression
93 | Left(count interface{}) StringExpression
94 | Right(count interface{}) StringExpression
95 | Trim() StringExpression
96 | }
97 |
98 | type ArrayExpression interface {
99 | Expression
100 | }
101 |
102 | type DateExpression interface {
103 | Expression
104 | Min() UnknownExpression
105 | Max() UnknownExpression
106 | }
107 |
108 | // UnknownExpression is the interface of an SQL expression with unknown value.
109 | type UnknownExpression interface {
110 | Expression
111 | And(other interface{}) BooleanExpression
112 | Or(other interface{}) BooleanExpression
113 | Xor(other interface{}) BooleanExpression
114 | Not() BooleanExpression
115 | Add(other interface{}) NumberExpression
116 | Sub(other interface{}) NumberExpression
117 | Mul(other interface{}) NumberExpression
118 | Div(other interface{}) NumberExpression
119 | IntDiv(other interface{}) NumberExpression
120 | Mod(other interface{}) NumberExpression
121 |
122 | Sum() NumberExpression
123 | Avg() NumberExpression
124 | Min() UnknownExpression
125 | Max() UnknownExpression
126 |
127 | Like(other interface{}) BooleanExpression
128 | Contains(substring string) BooleanExpression
129 | Concat(other interface{}) StringExpression
130 | IfEmpty(altValue interface{}) StringExpression
131 | IsEmpty() BooleanExpression
132 | Lower() StringExpression
133 | Upper() StringExpression
134 | Left(count interface{}) StringExpression
135 | Right(count interface{}) StringExpression
136 | Trim() StringExpression
137 | }
138 |
// expression is the concrete type behind the Expression interfaces in this
// file. It either carries a fixed SQL fragment in sql or renders one on demand
// through builder; GetSQL prefers a non-empty sql over builder.
type expression struct {
	// sql is a pre-rendered SQL fragment; empty when builder is used.
	sql string
	// builder renders the SQL for a given scope when sql is empty.
	builder func(scope scope) (string, error)
	// priority is the precedence of the outermost operator. Operator builders
	// compare it against their own precedence to decide on parentheses.
	priority priority
	// isTrue and isFalse mark the TRUE / FALSE literals made by True() and
	// False(); And, Or and Not use them to simplify.
	isTrue  bool
	isFalse bool
	// isBool marks boolean-valued expressions; toBooleanExpression only passes
	// on expressions with this flag set.
	isBool bool
}

// GetTable returns nil: an expression is not tied to a table.
// NOTE(review): presumably this satisfies an interface elsewhere that asks
// for the owning table — confirm.
func (e expression) GetTable() Table {
	return nil
}

// scope is the rendering context passed to GetSQL.
type scope struct {
	// Database is the database the statement runs against (FetchCursor
	// queries through it).
	Database *database
	// Tables lists the tables of the FROM clause when rendering a SELECT.
	Tables []Table
	// lastJoin is the most recently added join; earlier joins are reached
	// through its previous links.
	lastJoin *join
}
157 |
158 | func staticExpression(sql string, priority priority, isBool bool) expression {
159 | return expression{
160 | sql: sql,
161 | priority: priority,
162 | isBool: isBool,
163 | }
164 | }
165 |
166 | func True() BooleanExpression {
167 | return expression{
168 | sql: "TRUE",
169 | isTrue: true,
170 | isBool: true,
171 | }
172 | }
173 |
174 | func False() BooleanExpression {
175 | return expression{
176 | sql: "FALSE",
177 | isFalse: true,
178 | isBool: true,
179 | }
180 | }
181 |
182 | // Raw create a raw SQL statement
183 | func Raw(sql string) UnknownExpression {
184 | return expression{
185 | sql: sql,
186 | priority: 99,
187 | }
188 | }
189 |
190 | // And creates an expression with AND operator.
191 | func And(expressions ...BooleanExpression) (result BooleanExpression) {
192 | if len(expressions) == 0 {
193 | result = True()
194 | return
195 | }
196 | for _, condition := range expressions {
197 | if result == nil {
198 | result = condition
199 | } else {
200 | result = result.And(condition)
201 | }
202 | }
203 | return
204 | }
205 |
206 | // Or creates an expression with OR operator.
207 | func Or(expressions ...BooleanExpression) (result BooleanExpression) {
208 | if len(expressions) == 0 {
209 | result = False()
210 | return
211 | }
212 | for _, condition := range expressions {
213 | if result == nil {
214 | result = condition
215 | } else {
216 | result = result.Or(condition)
217 | }
218 | }
219 | return
220 | }
221 |
222 | func (e expression) As(name string) Alias {
223 | return expression{builder: func(scope scope) (string, error) {
224 | expressionSql, err := e.GetSQL(scope)
225 | if err != nil {
226 | return "", err
227 | }
228 | return expressionSql + " AS " + name, nil
229 | }}
230 | }
231 |
232 | func (e expression) If(trueValue interface{}, falseValue interface{}) UnknownExpression {
233 | return If(e, trueValue, falseValue)
234 | }
235 |
236 | func (e expression) IfNull(altValue interface{}) UnknownExpression {
237 | return Function("IFNULL", e, altValue)
238 | }
239 |
240 | func (e expression) IfEmpty(altValue interface{}) StringExpression {
241 | return If(e.NotEquals(""), e, altValue)
242 | }
243 |
244 | func (e expression) IsEmpty() BooleanExpression {
245 | return e.Equals("")
246 | }
247 |
248 | func (e expression) Lower() StringExpression {
249 | return function("LOWER", e)
250 | }
251 |
252 | func (e expression) Upper() StringExpression {
253 | return function("UPPER", e)
254 | }
255 |
256 | func (e expression) Left(count interface{}) StringExpression {
257 | return function("LEFT", e, count)
258 | }
259 |
260 | func (e expression) Right(count interface{}) StringExpression {
261 | return function("RIGHT", e, count)
262 | }
263 |
264 | func (e expression) Trim() StringExpression {
265 | return function("TRIM", e)
266 | }
267 |
268 | func (e expression) CharLength() NumberExpression {
269 | return function("CHAR_LENGTH", e)
270 | }
271 |
272 | func (e expression) HasPrefix(prefix interface{}) BooleanExpression {
273 | return e.Left(function("CHAR_LENGTH", prefix)).Equals(prefix)
274 | }
275 |
276 | func (e expression) HasSuffix(suffix interface{}) BooleanExpression {
277 | return e.Right(function("CHAR_LENGTH", suffix)).Equals(suffix)
278 | }
279 |
280 | func (e expression) GetSQL(scope scope) (string, error) {
281 | if e.sql != "" {
282 | return e.sql, nil
283 | }
284 | return e.builder(scope)
285 | }
286 |
// needsEscape flags, per byte value, the bytes that quoteString prefixes with
// a backslash: NUL, LF, CR, Ctrl-Z (0x1a), double quote, single quote and
// backslash. Entries are exactly 0 or 1.
var needsEscape = [256]int{
	'\x00': 1,
	'\n':   1,
	'\r':   1,
	'\x1a': 1,
	'"':    1,
	'\'':   1,
	'\\':   1,
}
296 |
297 | func quoteIdentifier(identifier string) (result dialectArray) {
298 | for dialect := dialect(0); dialect < dialectCount; dialect++ {
299 | switch dialect {
300 | case dialectMySQL:
301 | result[dialect] = "`" + identifier + "`"
302 | case dialectMSSQL:
303 | result[dialect] = "[" + identifier + "]"
304 | default:
305 | result[dialect] = "\"" + identifier + "\""
306 | }
307 | }
308 | return
309 | }
310 |
311 | func quoteString(s string) string {
312 | if s == "" {
313 | return "''"
314 | }
315 |
316 | buf := make([]byte, len(s)*2+2)
317 | buf[0] = '\''
318 | n := 1
319 | for i := 0; i < len(s); i++ {
320 | b := s[i]
321 | buf[n] = '\\'
322 | n += needsEscape[b]
323 | buf[n] = b
324 | n++
325 | }
326 | buf[n] = '\''
327 | n++
328 | buf = buf[:n]
329 | return *(*string)(unsafe.Pointer(&buf))
330 | }
331 |
332 | func getSQL(scope scope, value interface{}) (sql string, priority priority, err error) {
333 | const mysqlTimeFormat = "2006-01-02 15:04:05.000000"
334 | if value == nil {
335 | sql = "NULL"
336 | return
337 | }
338 | switch value.(type) {
339 | case int:
340 | sql = strconv.Itoa(value.(int))
341 | case string:
342 | sql = quoteString(value.(string))
343 | case Expression:
344 | sql, err = value.(Expression).GetSQL(scope)
345 | priority = value.(Expression).getOperatorPriority()
346 | case Assignment:
347 | sql, err = value.(Assignment).GetSQL(scope)
348 | case toSelectFinal:
349 | sql, err = value.(toSelectFinal).GetSQL()
350 | if err != nil {
351 | return
352 | }
353 | sql = "(" + sql + ")"
354 | case toUpdateFinal:
355 | sql, err = value.(toUpdateFinal).GetSQL()
356 | case Table:
357 | sql = value.(Table).GetSQL(scope)
358 | case CaseExpression:
359 | sql, err = value.(CaseExpression).End().GetSQL(scope)
360 | case time.Time:
361 | tm := value.(time.Time)
362 | if tm.IsZero() {
363 | sql = "NULL"
364 | } else {
365 | tmStr := tm.Format(mysqlTimeFormat)
366 | sql = quoteString(tmStr)
367 | }
368 | case *time.Time:
369 | tm := value.(*time.Time)
370 | if tm == nil || tm.IsZero() {
371 | sql = "NULL"
372 | } else {
373 | tmStr := tm.Format(mysqlTimeFormat)
374 | sql = quoteString(tmStr)
375 | }
376 | default:
377 | v := reflect.ValueOf(value)
378 | sql, priority, err = getSQLFromReflectValue(scope, v)
379 | }
380 | return
381 | }
382 |
383 | func getSQLFromReflectValue(scope scope, v reflect.Value) (sql string, priority priority, err error) {
384 | if v.Kind() == reflect.Ptr {
385 | // dereference pointers
386 | for {
387 | if v.IsNil() {
388 | sql = "NULL"
389 | return
390 | }
391 | v = v.Elem()
392 | if v.Kind() != reflect.Ptr {
393 | break
394 | }
395 | }
396 | sql, priority, err = getSQL(scope, v.Interface())
397 | return
398 | }
399 |
400 | switch v.Kind() {
401 | case reflect.Bool:
402 | if v.Bool() {
403 | sql = "1"
404 | } else {
405 | sql = "0"
406 | }
407 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
408 | sql = strconv.FormatInt(v.Int(), 10)
409 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
410 | sql = strconv.FormatUint(v.Uint(), 10)
411 | case reflect.Float32, reflect.Float64:
412 | sql = strconv.FormatFloat(v.Float(), 'g', -1, 64)
413 | case reflect.String:
414 | sql = quoteString(v.String())
415 | case reflect.Array, reflect.Slice:
416 | length := v.Len()
417 | values := make([]interface{}, length)
418 | for i := 0; i < length; i++ {
419 | values[i] = v.Index(i).Interface()
420 | }
421 | sql, err = commaValues(scope, values)
422 | if err == nil {
423 | sql = "(" + sql + ")"
424 | }
425 | default:
426 | if vs, ok := v.Interface().(interface{ String() string }); ok {
427 | sql = quoteString(vs.String())
428 | } else {
429 | err = fmt.Errorf("invalid type %s", v.Kind().String())
430 | }
431 | }
432 | return
433 | }
434 |
435 | /*
436 | 1 INTERVAL
437 | 2 BINARY, COLLATE
438 | 3 !
439 | 4 - (unary minus), ~ (unary bit inversion)
440 | 5 ^
441 | 6 *, /, DIV, %, MOD
442 | 7 -, +
443 | 8 <<, >>
444 | 9 &
445 | 10 |
446 | 11 = (comparison), <=>, >=, >, <=, <, <>, !=, IS, LIKE, REGEXP, IN
447 | 12 BETWEEN, CASE, WHEN, THEN, ELSE
448 | 13 NOT
449 | 14 AND, &&
450 | 15 XOR
451 | 16 OR, ||
452 | 17 = (assignment), :=
453 | */
454 | func (e expression) NotEquals(other interface{}) BooleanExpression {
455 | return e.binaryOperation("<>", other, 11, true)
456 | }
457 |
458 | func (e expression) Equals(other interface{}) BooleanExpression {
459 | return e.binaryOperation("=", other, 11, true)
460 | }
461 |
462 | func (e expression) LessThan(other interface{}) BooleanExpression {
463 | return e.binaryOperation("<", other, 11, true)
464 | }
465 |
466 | func (e expression) LessThanOrEquals(other interface{}) BooleanExpression {
467 | return e.binaryOperation("<=", other, 11, true)
468 | }
469 |
470 | func (e expression) GreaterThan(other interface{}) BooleanExpression {
471 | return e.binaryOperation(">", other, 11, true)
472 | }
473 |
474 | func (e expression) GreaterThanOrEquals(other interface{}) BooleanExpression {
475 | return e.binaryOperation(">=", other, 11, true)
476 | }
477 |
478 | func toBooleanExpression(value interface{}) BooleanExpression {
479 | e, ok := value.(expression)
480 | switch {
481 | case !ok:
482 | return nil
483 | case e.isTrue:
484 | return True()
485 | case e.isFalse:
486 | return False()
487 | case e.isBool:
488 | return e
489 | default:
490 | return nil
491 | }
492 | }
493 |
494 | func (e expression) And(other interface{}) BooleanExpression {
495 | switch {
496 | case e.isFalse:
497 | return e
498 | case e.isTrue:
499 | if exp := toBooleanExpression(other); exp != nil {
500 | return exp
501 | }
502 | }
503 | return e.binaryOperation("AND", other, 14, true)
504 | }
505 |
506 | func (e expression) Or(other interface{}) BooleanExpression {
507 | switch {
508 | case e.isTrue:
509 | return e
510 | case e.isFalse:
511 | if exp := toBooleanExpression(other); exp != nil {
512 | return exp
513 | }
514 | }
515 | return e.binaryOperation("OR", other, 16, true)
516 | }
517 |
518 | func (e expression) Xor(other interface{}) BooleanExpression {
519 | return e.binaryOperation("XOR", other, 15, true)
520 | }
521 |
522 | func (e expression) Add(other interface{}) NumberExpression {
523 | return e.binaryOperation("+", other, 7, false)
524 | }
525 |
526 | func (e expression) Sub(other interface{}) NumberExpression {
527 | return e.binaryOperation("-", other, 7, false)
528 | }
529 |
530 | func (e expression) Mul(other interface{}) NumberExpression {
531 | return e.binaryOperation("*", other, 6, false)
532 | }
533 |
534 | func (e expression) Div(other interface{}) NumberExpression {
535 | return e.binaryOperation("/", other, 6, false)
536 | }
537 |
538 | func (e expression) IntDiv(other interface{}) NumberExpression {
539 | return e.binaryOperation("DIV", other, 6, false)
540 | }
541 |
542 | func (e expression) Mod(other interface{}) NumberExpression {
543 | return e.binaryOperation("%", other, 6, false)
544 | }
545 |
546 | func (e expression) Sum() NumberExpression {
547 | return function("SUM", e)
548 | }
549 |
550 | func (e expression) Avg() NumberExpression {
551 | return function("AVG", e)
552 | }
553 |
554 | func (e expression) Min() UnknownExpression {
555 | return function("MIN", e)
556 | }
557 |
558 | func (e expression) Max() UnknownExpression {
559 | return function("MAX", e)
560 | }
561 |
562 | func (e expression) Like(other interface{}) BooleanExpression {
563 | return e.binaryOperation("LIKE", other, 11, true)
564 | }
565 |
566 | func (e expression) Concat(other interface{}) StringExpression {
567 | return Concat(e, other)
568 | }
569 |
570 | func (e expression) Contains(substring string) BooleanExpression {
571 | return function("LOCATE", substring, e).GreaterThan(0)
572 | }
573 |
574 | func (e expression) binaryOperation(operator string, value interface{}, priority priority, isBool bool) expression {
575 | return expression{builder: func(scope scope) (string, error) {
576 | leftSql, err := e.GetSQL(scope)
577 | if err != nil {
578 | return "", err
579 | }
580 | leftPriority := e.priority
581 | rightSql, rightPriority, err := getSQL(scope, value)
582 | if err != nil {
583 | return "", err
584 | }
585 | shouldParenthesizeLeft := leftPriority > priority
586 | shouldParenthesizeRight := rightPriority >= priority
587 | var sb strings.Builder
588 | sb.Grow(len(leftSql) + len(operator) + len(rightSql) + 6)
589 | if shouldParenthesizeLeft {
590 | sb.WriteByte('(')
591 | }
592 | sb.WriteString(leftSql)
593 | if shouldParenthesizeLeft {
594 | sb.WriteByte(')')
595 | }
596 | sb.WriteByte(' ')
597 | sb.WriteString(operator)
598 | sb.WriteByte(' ')
599 | if shouldParenthesizeRight {
600 | sb.WriteByte('(')
601 | }
602 | sb.WriteString(rightSql)
603 | if shouldParenthesizeRight {
604 | sb.WriteByte(')')
605 | }
606 | return sb.String(), nil
607 | }, priority: priority, isBool: isBool}
608 | }
609 |
610 | func (e expression) prefixSuffixExpression(prefix string, suffix string, priority priority, isBool bool) expression {
611 | if e.sql != "" {
612 | return expression{
613 | sql: prefix + e.sql + suffix,
614 | priority: priority,
615 | isBool: isBool,
616 | }
617 | }
618 | return expression{
619 | builder: func(scope scope) (string, error) {
620 | exprSql, err := e.GetSQL(scope)
621 | if err != nil {
622 | return "", err
623 | }
624 | var sb strings.Builder
625 | sb.Grow(len(prefix) + len(exprSql) + len(suffix) + 2)
626 | sb.WriteString(prefix)
627 | shouldParenthesize := e.priority > priority
628 | if shouldParenthesize {
629 | sb.WriteByte('(')
630 | }
631 | sb.WriteString(exprSql)
632 | if shouldParenthesize {
633 | sb.WriteByte(')')
634 | }
635 | sb.WriteString(suffix)
636 | return sb.String(), nil
637 | },
638 | priority: priority,
639 | isBool: isBool,
640 | }
641 | }
642 |
643 | func (e expression) IsNull() BooleanExpression {
644 | return e.prefixSuffixExpression("", " IS NULL", 11, true)
645 | }
646 |
647 | func (e expression) Not() BooleanExpression {
648 | switch {
649 | case e.isTrue:
650 | return False()
651 | case e.isFalse:
652 | return True()
653 | default:
654 | return e.prefixSuffixExpression("NOT ", "", 13, true)
655 | }
656 | }
657 |
658 | func (e expression) IsNotNull() BooleanExpression {
659 | return e.prefixSuffixExpression("", " IS NOT NULL", 11, true)
660 | }
661 |
662 | func (e expression) IsTrue() BooleanExpression {
663 | return e.prefixSuffixExpression("", " IS TRUE", 11, true)
664 | }
665 |
666 | func (e expression) IsNotTrue() BooleanExpression {
667 | return e.prefixSuffixExpression("", " IS NOT TRUE", 11, true)
668 | }
669 |
670 | func (e expression) IsFalse() BooleanExpression {
671 | return e.prefixSuffixExpression("", " IS FALSE", 11, true)
672 | }
673 |
674 | func (e expression) IsNotFalse() BooleanExpression {
675 | return e.prefixSuffixExpression("", " IS NOT FALSE", 11, true)
676 | }
677 |
// expandSliceValue flattens value into a list of scalar values.
//   - Arrays and slices are expanded element by element, recursively.
//   - Pointers and interfaces are followed.
//   - Anything else is returned as-is.
//
// A nil pointer, a nil interface, or an invalid reflect.Value (as produced by
// reflect.ValueOf(nil)) contributes a single nil, which getSQL renders as
// NULL. Previously such values reached Value.Interface on a zero Value and
// panicked.
//
// All values are appended to one result slice instead of allocating a new
// slice per recursion level.
func expandSliceValue(value reflect.Value) (result []interface{}) {
	var walk func(v reflect.Value)
	walk = func(v reflect.Value) {
		switch v.Kind() {
		case reflect.Invalid:
			result = append(result, nil)
		case reflect.Array, reflect.Slice:
			for i, n := 0, v.Len(); i < n; i++ {
				walk(v.Index(i))
			}
		case reflect.Interface, reflect.Ptr:
			// Elem of a nil pointer/interface is invalid and handled above.
			walk(v.Elem())
		default:
			result = append(result, v.Interface())
		}
	}

	result = make([]interface{}, 0, 16)
	walk(value)
	return result
}

// expandSliceValues flattens every element of values with expandSliceValue
// and concatenates the results in order.
func expandSliceValues(values []interface{}) (result []interface{}) {
	result = make([]interface{}, 0, 16)
	for _, v := range values {
		result = append(result, expandSliceValue(reflect.ValueOf(v))...)
	}
	return result
}
703 |
704 | func (e expression) In(values ...interface{}) BooleanExpression {
705 | values = expandSliceValues(values)
706 | if len(values) == 0 {
707 | return False()
708 | }
709 | joiner := func(exprSql, valuesSql string) string { return exprSql + " IN (" + valuesSql + ")" }
710 | builder := e.getBuilder(e.Equals, joiner, values...)
711 | return expression{builder: builder, priority: 11}
712 | }
713 |
714 | func (e expression) NotIn(values ...interface{}) BooleanExpression {
715 | values = expandSliceValues(values)
716 | if len(values) == 0 {
717 | return True()
718 | }
719 | joiner := func(exprSql, valuesSql string) string { return exprSql + " NOT IN (" + valuesSql + ")" }
720 | builder := e.getBuilder(e.NotEquals, joiner, values...)
721 | return expression{builder: builder, priority: 11}
722 | }
723 |
724 | type joinerFunc = func(exprSql, valuesSql string) string
725 | type booleanFunc = func(other interface{}) BooleanExpression
726 | type builderFunc = func(scope scope) (string, error)
727 |
728 | func (e expression) getBuilder(single booleanFunc, joiner joinerFunc, values ...interface{}) builderFunc {
729 | return func(scope scope) (string, error) {
730 | var valuesSql string
731 | var err error
732 |
733 | if len(values) == 1 {
734 | value := values[0]
735 | if selectStatus, ok := value.(toSelectFinal); ok {
736 | // IN subquery
737 | valuesSql, err = selectStatus.GetSQL()
738 | if err != nil {
739 | return "", err
740 | }
741 | } else {
742 | // IN a single value
743 | return single(value).GetSQL(scope)
744 | }
745 | } else {
746 | // IN a list
747 | valuesSql, err = commaValues(scope, values)
748 | if err != nil {
749 | return "", err
750 | }
751 | }
752 |
753 | exprSql, err := e.GetSQL(scope)
754 | if err != nil {
755 | return "", err
756 | }
757 | return joiner(exprSql, valuesSql), nil
758 | }
759 | }
760 |
761 | func (e expression) Between(min interface{}, max interface{}) BooleanExpression {
762 | return e.buildBetween(" BETWEEN ", min, max)
763 | }
764 |
765 | func (e expression) NotBetween(min interface{}, max interface{}) BooleanExpression {
766 | return e.buildBetween(" NOT BETWEEN ", min, max)
767 | }
768 |
769 | func (e expression) buildBetween(operator string, min interface{}, max interface{}) BooleanExpression {
770 | return expression{builder: func(scope scope) (string, error) {
771 | exprSql, err := e.GetSQL(scope)
772 | if err != nil {
773 | return "", err
774 | }
775 | minSql, _, err := getSQL(scope, min)
776 | if err != nil {
777 | return "", err
778 | }
779 | maxSql, _, err := getSQL(scope, max)
780 | if err != nil {
781 | return "", err
782 | }
783 | return exprSql + operator + minSql + " AND " + maxSql, nil
784 | }, priority: 12}
785 | }
786 |
787 | func (e expression) getOperatorPriority() priority {
788 | return e.priority
789 | }
790 |
791 | func (e expression) Desc() OrderBy {
792 | return orderBy{by: e, desc: true}
793 | }
794 |
--------------------------------------------------------------------------------