├── .gitignore ├── CODEOWNERS ├── integration ├── doc.go ├── go.mod ├── go.sum └── integration_test.go ├── go.mod ├── stmtcacher_ctx_test.go ├── stmtcacher_noctx.go ├── row.go ├── stmtcacher_test.go ├── where.go ├── .github └── workflows │ └── go.yaml ├── row_test.go ├── .travis.yml ├── select_ctx_test.go ├── go.sum ├── insert_ctx_test.go ├── update_ctx_test.go ├── delete_ctx_test.go ├── LICENSE ├── part.go ├── squirrel_ctx_test.go ├── where_test.go ├── delete_ctx.go ├── insert_ctx.go ├── select_ctx.go ├── update_ctx.go ├── placeholder_test.go ├── statement_test.go ├── delete_test.go ├── stmtcacher_ctx.go ├── placeholder.go ├── insert_test.go ├── stmtcacher.go ├── squirrel_ctx.go ├── update_test.go ├── case.go ├── statement.go ├── case_test.go ├── README.md ├── squirrel.go ├── squirrel_test.go ├── delete.go ├── insert.go ├── update.go ├── expr.go ├── expr_test.go ├── select.go └── select_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | squirrel.test -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @stytchauth/backend 2 | -------------------------------------------------------------------------------- /integration/doc.go: -------------------------------------------------------------------------------- 1 | // This is a tests-only package. 
2 | package integration 3 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/stytchauth/squirrel 2 | 3 | go 1.14 4 | 5 | require ( 6 | github.com/davecgh/go-spew v1.1.1 // indirect 7 | github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 8 | github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 // indirect 9 | github.com/pmezard/go-difflib v1.0.0 // indirect 10 | github.com/stretchr/testify v1.2.2 11 | ) 12 | -------------------------------------------------------------------------------- /integration/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/Masterminds/squirrel/integration 2 | 3 | go 1.12 4 | 5 | require ( 6 | github.com/Masterminds/squirrel v1.1.0 7 | github.com/go-sql-driver/mysql v1.4.1 8 | github.com/lib/pq v1.2.0 9 | github.com/mattn/go-sqlite3 v1.13.0 10 | github.com/stretchr/testify v1.4.0 11 | google.golang.org/appengine v1.6.5 // indirect 12 | ) 13 | 14 | replace github.com/Masterminds/squirrel => ../ 15 | -------------------------------------------------------------------------------- /stmtcacher_ctx_test.go: -------------------------------------------------------------------------------- 1 | // +build go1.8 2 | 3 | package squirrel 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestStmtCacherPrepareContext(t *testing.T) { 12 | db := &DBStub{} 13 | sc := NewStmtCache(db) 14 | query := "SELECT 1" 15 | 16 | sc.PrepareContext(ctx, query) 17 | assert.Equal(t, query, db.LastPrepareSql) 18 | 19 | sc.PrepareContext(ctx, query) 20 | assert.Equal(t, 1, db.PrepareCount, "expected 1 Prepare, got %d", db.PrepareCount) 21 | } 22 | -------------------------------------------------------------------------------- /stmtcacher_noctx.go: 
-------------------------------------------------------------------------------- 1 | // +build !go1.8 2 | 3 | package squirrel 4 | 5 | import ( 6 | "database/sql" 7 | ) 8 | 9 | // NewStmtCache returns a *StmtCache wrapping prep that caches prepared statements. 10 | // 11 | // Stmts are cached based on the string value of their queries. 12 | func NewStmtCache(prep Preparer) *StmtCache { 13 | return &StmtCache{prep: prep, cache: make(map[string]*sql.Stmt)} 14 | } 15 | 16 | // NewStmtCacher returns a DBProxy wrapping prep that caches prepared statements. 17 | // 18 | // Deprecated: Use NewStmtCache instead. 19 | func NewStmtCacher(prep Preparer) DBProxy { 20 | return NewStmtCache(prep) 21 | } 22 | -------------------------------------------------------------------------------- /row.go: -------------------------------------------------------------------------------- 1 | package squirrel 2 | 3 | // RowScanner is the interface that wraps the Scan method. 4 | // 5 | // Scan behaves like database/sql.Row.Scan. 6 | type RowScanner interface { 7 | Scan(...interface{}) error 8 | } 9 | 10 | // Row wraps a RowScanner (such as *sql.Row) so squirrel can surface its own errors from Scan. 11 | type Row struct { 12 | RowScanner 13 | err error // when non-nil, returned by Scan instead of delegating 14 | } 15 | 16 | // Scan returns r.err if it is set; otherwise it delegates to the embedded RowScanner. 17 | func (r *Row) Scan(dest ...interface{}) error { 18 | if r.err != nil { 19 | return r.err 20 | } 21 | return r.RowScanner.Scan(dest...)
22 | } 23 | -------------------------------------------------------------------------------- /stmtcacher_test.go: -------------------------------------------------------------------------------- 1 | package squirrel 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestStmtCachePrepare(t *testing.T) { 10 | db := &DBStub{} 11 | sc := NewStmtCache(db) 12 | query := "SELECT 1" 13 | 14 | sc.Prepare(query) 15 | assert.Equal(t, query, db.LastPrepareSql) 16 | 17 | sc.Prepare(query) 18 | assert.Equal(t, 1, db.PrepareCount, "expected 1 Prepare, got %d", db.PrepareCount) 19 | 20 | // clear statement cache 21 | assert.NoError(t, sc.Clear()) 22 | 23 | // should prepare the query again 24 | sc.Prepare(query) 25 | assert.Equal(t, 2, db.PrepareCount, "expected 2 Prepare, got %d", db.PrepareCount) 26 | } 27 | -------------------------------------------------------------------------------- /where.go: -------------------------------------------------------------------------------- 1 | package squirrel 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | type wherePart part 8 | 9 | func newWherePart(pred interface{}, args ...interface{}) Sqlizer { 10 | return &wherePart{pred: pred, args: args} 11 | } 12 | 13 | func (p wherePart) ToSql() (sql string, args []interface{}, err error) { 14 | switch pred := p.pred.(type) { 15 | case nil: 16 | // no-op 17 | case rawSqlizer: 18 | return pred.toSqlRaw() 19 | case Sqlizer: 20 | return pred.ToSql() 21 | case map[string]interface{}: 22 | return Eq(pred).ToSql() 23 | case string: 24 | sql = pred 25 | args = p.args 26 | default: 27 | err = fmt.Errorf("expected string-keyed map or string, not %T", pred) 28 | } 29 | return 30 | } 31 | -------------------------------------------------------------------------------- /.github/workflows/go.yaml: -------------------------------------------------------------------------------- 1 | name: Go 2 | on: 3 | push: 4 | branches: 5 | - 'main' 6 | pull_request: 7 | branches: 8 | -
'main' 9 | 10 | jobs: 11 | test: 12 | name: Test 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Install Go 16 | uses: actions/setup-go@v2.1.3 17 | with: 18 | go-version: 1.16 19 | 20 | - name: Check out repo 21 | uses: actions/checkout@v2 22 | 23 | - name: Download all Go modules 24 | run: go mod download 25 | 26 | - name: Check that modules are tidy 27 | run: | 28 | go mod tidy 29 | git diff --exit-code -- go.mod go.sum 30 | 31 | - name: Run Tests 32 | run: go test -v ./... 33 | -------------------------------------------------------------------------------- /row_test.go: -------------------------------------------------------------------------------- 1 | package squirrel 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | type RowStub struct { 11 | Scanned bool 12 | } 13 | 14 | func (r *RowStub) Scan(_ ...interface{}) error { 15 | r.Scanned = true 16 | return nil 17 | } 18 | 19 | func TestRowScan(t *testing.T) { 20 | stub := &RowStub{} 21 | row := &Row{RowScanner: stub} 22 | err := row.Scan() 23 | assert.True(t, stub.Scanned, "row was not scanned") 24 | assert.NoError(t, err) 25 | } 26 | 27 | func TestRowScanErr(t *testing.T) { 28 | stub := &RowStub{} 29 | rowErr := fmt.Errorf("scan err") 30 | row := &Row{RowScanner: stub, err: rowErr} 31 | err := row.Scan() 32 | assert.False(t, stub.Scanned, "row was scanned") 33 | assert.Equal(t, rowErr, err) 34 | } 35 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - 1.11.x 5 | - 1.12.x 6 | - 1.13.x 7 | 8 | services: 9 | - mysql 10 | - postgresql 11 | 12 | # Setting sudo access to false will let Travis CI use containers rather than 13 | # VMs to run the tests. 
For more details see: 14 | # - http://docs.travis-ci.com/user/workers/container-based-infrastructure/ 15 | # - http://docs.travis-ci.com/user/workers/standard-infrastructure/ 16 | sudo: false 17 | 18 | before_script: 19 | - mysql -e 'CREATE DATABASE squirrel;' 20 | - psql -c 'CREATE DATABASE squirrel;' -U postgres 21 | 22 | script: 23 | - go test 24 | - cd integration 25 | - go test -args -driver sqlite3 26 | - go test -args -driver mysql -dataSource travis@/squirrel 27 | - go test -args -driver postgres -dataSource 'postgres://postgres@localhost/squirrel?sslmode=disable' 28 | 29 | notifications: 30 | irc: "irc.freenode.net#masterminds" 31 | -------------------------------------------------------------------------------- /select_ctx_test.go: -------------------------------------------------------------------------------- 1 | // +build go1.8 2 | 3 | package squirrel 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestSelectBuilderContextRunners(t *testing.T) { 12 | db := &DBStub{} 13 | sb := Select("test").RunWith(db) 14 | 15 | wantSQL := "SELECT test" // every runner method should see the same rendered statement 16 | 17 | sb.ExecContext(ctx) 18 | assert.Equal(t, wantSQL, db.LastExecSql) 19 | 20 | sb.QueryContext(ctx) 21 | assert.Equal(t, wantSQL, db.LastQuerySql) 22 | 23 | sb.QueryRowContext(ctx) 24 | assert.Equal(t, wantSQL, db.LastQueryRowSql) 25 | 26 | scanErr := sb.ScanContext(ctx) 27 | assert.NoError(t, scanErr) 28 | } 29 | 30 | func TestSelectBuilderContextNoRunner(t *testing.T) { 31 | sb := Select("test") // no RunWith: every runner method must fail with RunnerNotSet 32 | 33 | _, runErr := sb.ExecContext(ctx) 34 | assert.Equal(t, RunnerNotSet, runErr) 35 | 36 | _, runErr = sb.QueryContext(ctx) 37 | assert.Equal(t, RunnerNotSet, runErr) 38 | 39 | runErr = sb.ScanContext(ctx) 40 | assert.Equal(t, RunnerNotSet, runErr) 41 | } 42 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1
h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 h1:SOEGU9fKiNWd/HOJuq6+3iTQz8KNCLtVX6idSoTLdUw= 4 | github.com/lann/builder v0.0.0-20180802200727-47ae307949d0/go.mod h1:dXGbAdH5GtBTC4WfIxhKZfyBF/HBFgRZSWwZ9g/He9o= 5 | github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 h1:P6pPBnrTSX3DEVR4fDembhRWSsG5rVo6hYhAB/ADZrk= 6 | github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0/go.mod h1:vmVJ0l/dxyfGW6FmdpVm2joNMFikkuWg0EoCKLGUMNw= 7 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 8 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 9 | github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= 10 | github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 11 | -------------------------------------------------------------------------------- /insert_ctx_test.go: -------------------------------------------------------------------------------- 1 | // +build go1.8 2 | 3 | package squirrel 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestInsertBuilderContextRunners(t *testing.T) { 12 | db := &DBStub{} 13 | qb := Insert("test").Values(1).RunWith(db) 14 | 15 | wantSQL := "INSERT INTO test VALUES (?)" // every runner method should see the same rendered statement 16 | 17 | qb.ExecContext(ctx) 18 | assert.Equal(t, wantSQL, db.LastExecSql) 19 | 20 | qb.QueryContext(ctx) 21 | assert.Equal(t, wantSQL, db.LastQuerySql) 22 | 23 | qb.QueryRowContext(ctx) 24 | assert.Equal(t, wantSQL, db.LastQueryRowSql) 25 | 26 | scanErr := qb.ScanContext(ctx) 27 | assert.NoError(t, scanErr) 28 | } 29 | 30 | func TestInsertBuilderContextNoRunner(t *testing.T) { 31 | qb := Insert("test").Values(1) // no RunWith: every runner method must fail with RunnerNotSet 32 | 33 | _, runErr := qb.ExecContext(ctx) 34 | assert.Equal(t, RunnerNotSet, runErr) 35 | 36 | _, runErr = qb.QueryContext(ctx) 37 |
assert.Equal(t, RunnerNotSet, runErr) 38 | 39 | runErr = qb.ScanContext(ctx) 40 | assert.Equal(t, RunnerNotSet, runErr) 41 | } 42 | -------------------------------------------------------------------------------- /update_ctx_test.go: -------------------------------------------------------------------------------- 1 | // +build go1.8 2 | 3 | package squirrel 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestUpdateBuilderContextRunners(t *testing.T) { 12 | db := &DBStub{} 13 | qb := Update("test").Set("x", 1).RunWith(db) 14 | 15 | wantSQL := "UPDATE test SET x = ?" // every runner method should see the same rendered statement 16 | 17 | qb.ExecContext(ctx) 18 | assert.Equal(t, wantSQL, db.LastExecSql) 19 | 20 | qb.QueryContext(ctx) 21 | assert.Equal(t, wantSQL, db.LastQuerySql) 22 | 23 | qb.QueryRowContext(ctx) 24 | assert.Equal(t, wantSQL, db.LastQueryRowSql) 25 | 26 | scanErr := qb.ScanContext(ctx) 27 | assert.NoError(t, scanErr) 28 | } 29 | 30 | func TestUpdateBuilderContextNoRunner(t *testing.T) { 31 | qb := Update("test").Set("x", 1) // no RunWith: every runner method must fail with RunnerNotSet 32 | 33 | _, runErr := qb.ExecContext(ctx) 34 | assert.Equal(t, RunnerNotSet, runErr) 35 | 36 | _, runErr = qb.QueryContext(ctx) 37 | assert.Equal(t, RunnerNotSet, runErr) 38 | 39 | runErr = qb.ScanContext(ctx) 40 | assert.Equal(t, RunnerNotSet, runErr) 41 | } 42 | -------------------------------------------------------------------------------- /delete_ctx_test.go: -------------------------------------------------------------------------------- 1 | // +build go1.8 2 | 3 | package squirrel 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestDeleteBuilderContextRunners(t *testing.T) { 12 | db := &DBStub{} 13 | qb := Delete("test").Where("x = ?", 1).RunWith(db) 14 | 15 | wantSQL := "DELETE test FROM test WHERE x = ?" // every runner method should see the same rendered statement
16 | 17 | qb.ExecContext(ctx) 18 | assert.Equal(t, wantSQL, db.LastExecSql) 19 | 20 | qb.QueryContext(ctx) 21 | assert.Equal(t, wantSQL, db.LastQuerySql) 22 | 23 | qb.QueryRowContext(ctx) 24 | assert.Equal(t, wantSQL, db.LastQueryRowSql) 25 | 26 | scanErr := qb.ScanContext(ctx) 27 | assert.NoError(t, scanErr) 28 | } 29 | 30 | func TestDeleteBuilderContextNoRunner(t *testing.T) { 31 | qb := Delete("test").Where("x != ?", 0).Suffix("RETURNING x") // no RunWith: every runner method must fail with RunnerNotSet 32 | 33 | _, runErr := qb.ExecContext(ctx) 34 | assert.Equal(t, RunnerNotSet, runErr) 35 | 36 | _, runErr = qb.QueryContext(ctx) 37 | assert.Equal(t, RunnerNotSet, runErr) 38 | 39 | runErr = qb.ScanContext(ctx) 40 | assert.Equal(t, RunnerNotSet, runErr) 41 | } 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Squirrel: The Masterminds 4 | Copyright (c) 2014-2015, Lann Martin. Copyright (C) 2015-2016, Google. Copyright (C) 2015, Matt Farina and Matt Butcher. 5 | 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 24 | -------------------------------------------------------------------------------- /part.go: -------------------------------------------------------------------------------- 1 | package squirrel 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | ) 7 | 8 | // part is a generic SQL fragment: pred is a raw SQL string (with args as its 9 | // placeholder values), a Sqlizer, or nil for an empty fragment. 10 | type part struct { 11 | pred interface{} 12 | args []interface{} 13 | } 14 | 15 | // newPart wraps pred and its args as a Sqlizer. 16 | func newPart(pred interface{}, args ...interface{}) Sqlizer { 17 | return &part{pred, args} 18 | } 19 | 20 | // ToSql renders the part: nil yields empty SQL, a Sqlizer is rendered via 21 | // nestedToSql, a string is returned verbatim with p.args, and any other pred 22 | // type is an error. 23 | func (p part) ToSql() (sql string, args []interface{}, err error) { 24 | switch pred := p.pred.(type) { 25 | case nil: 26 | // no-op 27 | case Sqlizer: 28 | sql, args, err = nestedToSql(pred) 29 | case string: 30 | sql = pred 31 | args = p.args 32 | default: 33 | err = fmt.Errorf("expected string or Sqlizer, not %T", pred) 34 | } 35 | return 36 | } 37 | 38 | // nestedToSql renders s for embedding in a larger statement, preferring 39 | // toSqlRaw when s also implements rawSqlizer. 40 | func nestedToSql(s Sqlizer) (string, []interface{}, error) { 41 | if raw, ok := s.(rawSqlizer); ok { 42 | return raw.toSqlRaw() 43 | } 44 | return s.ToSql() 45 | } 46 | 47 | // appendToSql renders parts in order, writes the non-empty results to w joined 48 | // by sep, and returns args extended with each written part's arguments. Parts 49 | // that render to empty SQL are skipped together with their args. 50 | func appendToSql(parts []Sqlizer, w io.Writer, sep string, args []interface{}) ([]interface{}, error) { 51 | wroteAny := false 52 | for _, p := range parts { 53 | partSql, partArgs, err := nestedToSql(p) 54 | if err != nil { 55 | return nil, err 56 | } 57 | if len(partSql) == 0 { 58 | continue 59 | } 60 | 61 | // Emit sep only between two written fragments. Keying this off the loop 62 | // index produced a leading sep whenever the first part(s) rendered empty. 63 | if wroteAny { 64 | if _, err := io.WriteString(w, sep); err != nil { 65 | return nil, err 66 | } 67 | } 68 | if _, err := io.WriteString(w, partSql); err != nil { 69 | return nil, err 70 | } 71 | wroteAny = true
72 | args = append(args, partArgs...) 73 | } 74 | return args, nil 75 | } 76 | -------------------------------------------------------------------------------- /squirrel_ctx_test.go: -------------------------------------------------------------------------------- 1 | // +build go1.8 2 | 3 | package squirrel 4 | 5 | import ( 6 | "context" 7 | "database/sql" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func (s *DBStub) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { 14 | s.LastPrepareSql = query 15 | s.PrepareCount++ 16 | return nil, nil 17 | } 18 | 19 | func (s *DBStub) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { 20 | s.LastExecSql = query 21 | s.LastExecArgs = args 22 | return nil, nil 23 | } 24 | 25 | func (s *DBStub) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { 26 | s.LastQuerySql = query 27 | s.LastQueryArgs = args 28 | return nil, nil 29 | } 30 | 31 | func (s *DBStub) QueryRowContext(ctx context.Context, query string, args ...interface{}) RowScanner { 32 | s.LastQueryRowSql = query 33 | s.LastQueryRowArgs = args 34 | return &Row{RowScanner: &RowStub{}} 35 | } 36 | 37 | var ctx = context.Background() 38 | 39 | func TestExecContextWith(t *testing.T) { 40 | db := &DBStub{} 41 | ExecContextWith(ctx, db, sqlizer) 42 | assert.Equal(t, sqlStr, db.LastExecSql) 43 | } 44 | 45 | func TestQueryContextWith(t *testing.T) { 46 | db := &DBStub{} 47 | QueryContextWith(ctx, db, sqlizer) 48 | assert.Equal(t, sqlStr, db.LastQuerySql) 49 | } 50 | 51 | func TestQueryRowContextWith(t *testing.T) { 52 | db := &DBStub{} 53 | QueryRowContextWith(ctx, db, sqlizer) 54 | assert.Equal(t, sqlStr, db.LastQueryRowSql) 55 | } 56 | -------------------------------------------------------------------------------- /where_test.go: -------------------------------------------------------------------------------- 1 | package squirrel 2 | 3 | import ( 4 | "testing" 5 | 6 | "bytes" 7 |
8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestWherePartsAppendToSql(t *testing.T) { 12 | parts := []Sqlizer{ 13 | newWherePart("x = ?", 1), 14 | newWherePart(nil), 15 | newWherePart(Eq{"y": 2}), 16 | } 17 | sql := &bytes.Buffer{} 18 | args, _ := appendToSql(parts, sql, " AND ", []interface{}{}) 19 | assert.Equal(t, "x = ? AND y = ?", sql.String()) 20 | assert.Equal(t, []interface{}{1, 2}, args) 21 | } 22 | 23 | func TestWherePartsAppendToSqlErr(t *testing.T) { 24 | parts := []Sqlizer{newWherePart(1)} 25 | _, err := appendToSql(parts, &bytes.Buffer{}, "", []interface{}{}) 26 | assert.Error(t, err) 27 | } 28 | 29 | func TestWherePartNil(t *testing.T) { 30 | sql, _, _ := newWherePart(nil).ToSql() 31 | assert.Equal(t, "", sql) 32 | } 33 | 34 | func TestWherePartErr(t *testing.T) { 35 | _, _, err := newWherePart(1).ToSql() 36 | assert.Error(t, err) 37 | } 38 | 39 | func TestWherePartString(t *testing.T) { 40 | sql, args, _ := newWherePart("x = ?", 1).ToSql() 41 | assert.Equal(t, "x = ?", sql) 42 | assert.Equal(t, []interface{}{1}, args) 43 | } 44 | 45 | func TestWherePartMap(t *testing.T) { 46 | test := func(pred interface{}) { 47 | sql, _, _ := newWherePart(pred).ToSql() 48 | expect := []string{"x = ? AND y = ?", "y = ?
AND x = ?"} 49 | if sql != expect[0] && sql != expect[1] { 50 | t.Errorf("expected one of %#v, got %#v", expect, sql) 51 | } 52 | } 53 | m := map[string]interface{}{"x": 1, "y": 2} 54 | test(m) 55 | test(Eq(m)) 56 | } 57 | 58 | // TestWherePartsAppendToSqlLeadingEmpty guards against a leading separator when the first part renders to empty SQL. 59 | func TestWherePartsAppendToSqlLeadingEmpty(t *testing.T) { 60 | parts := []Sqlizer{ 61 | newWherePart(nil), 62 | newWherePart("x = ?", 1), 63 | newWherePart(Eq{"y": 2}), 64 | } 65 | sql := &bytes.Buffer{} 66 | args, err := appendToSql(parts, sql, " AND ", []interface{}{}) 67 | assert.NoError(t, err) 68 | assert.Equal(t, "x = ? AND y = ?", sql.String()) 69 | assert.Equal(t, []interface{}{1, 2}, args) 70 | } 71 | -------------------------------------------------------------------------------- /delete_ctx.go: -------------------------------------------------------------------------------- 1 | // +build go1.8 2 | 3 | package squirrel 4 | 5 | import ( 6 | "context" 7 | "database/sql" 8 | 9 | "github.com/lann/builder" 10 | ) 11 | 12 | func (d *deleteData) ExecContext(ctx context.Context) (sql.Result, error) { 13 | if d.RunWith == nil { 14 | return nil, RunnerNotSet 15 | } 16 | ctxRunner, ok := d.RunWith.(ExecerContext) 17 | if !ok { 18 | return nil, NoContextSupport 19 | } 20 | return ExecContextWith(ctx, ctxRunner, d) 21 | } 22 | 23 | func (d *deleteData) QueryContext(ctx context.Context) (*sql.Rows, error) { 24 | if d.RunWith == nil { 25 | return nil, RunnerNotSet 26 | } 27 | ctxRunner, ok := d.RunWith.(QueryerContext) 28 | if !ok { 29 | return nil, NoContextSupport 30 | } 31 | return QueryContextWith(ctx, ctxRunner, d) 32 | } 33 | 34 | func (d *deleteData) QueryRowContext(ctx context.Context) RowScanner { 35 | if d.RunWith == nil { 36 | return &Row{err: RunnerNotSet} 37 | } 38 | queryRower, ok := d.RunWith.(QueryRowerContext) 39 | if !ok { 40 | if _, ok := d.RunWith.(QueryerContext); !ok { 41 | return &Row{err: RunnerNotQueryRunner} 42 | } 43 | return &Row{err: NoContextSupport} 44 | } 45 | return QueryRowContextWith(ctx, queryRower, d) 46 | } 47 | 48 | // ExecContext builds the query and executes it with ctx on the Runner set by RunWith. 49 | func (b DeleteBuilder) ExecContext(ctx context.Context) (sql.Result, error) { 50 | data := builder.GetStruct(b).(deleteData) 51 | return data.ExecContext(ctx) 52 | } 53 | 54 | // QueryContext builds the query and runs it with ctx on the Runner set by RunWith.
55 | func (b DeleteBuilder) QueryContext(ctx context.Context) (*sql.Rows, error) { 56 | d := builder.GetStruct(b).(deleteData) 57 | return d.QueryContext(ctx) 58 | } 59 | 60 | // QueryRowContext builds the query and runs it with ctx on the Runner set by RunWith, returning a RowScanner. 61 | func (b DeleteBuilder) QueryRowContext(ctx context.Context) RowScanner { 62 | d := builder.GetStruct(b).(deleteData) 63 | return d.QueryRowContext(ctx) 64 | } 65 | 66 | // ScanContext is a shortcut for QueryRowContext(ctx).Scan(dest...). 67 | func (b DeleteBuilder) ScanContext(ctx context.Context, dest ...interface{}) error { 68 | return b.QueryRowContext(ctx).Scan(dest...) 69 | } 70 | -------------------------------------------------------------------------------- /insert_ctx.go: -------------------------------------------------------------------------------- 1 | // +build go1.8 2 | 3 | package squirrel 4 | 5 | import ( 6 | "context" 7 | "database/sql" 8 | 9 | "github.com/lann/builder" 10 | ) 11 | 12 | func (d *insertData) ExecContext(ctx context.Context) (sql.Result, error) { 13 | if d.RunWith == nil { 14 | return nil, RunnerNotSet 15 | } 16 | // Context-aware execution requires a runner implementing ExecerContext. 17 | if runner, ok := d.RunWith.(ExecerContext); ok { 18 | return ExecContextWith(ctx, runner, d) 19 | } 20 | return nil, NoContextSupport 21 | } 22 | 23 | func (d *insertData) QueryContext(ctx context.Context) (*sql.Rows, error) { 24 | if d.RunWith == nil { 25 | return nil, RunnerNotSet 26 | } 27 | // Context-aware querying requires a runner implementing QueryerContext. 28 | if runner, ok := d.RunWith.(QueryerContext); ok { 29 | return QueryContextWith(ctx, runner, d) 30 | } 31 | return nil, NoContextSupport 32 | } 33 | 34 | func (d *insertData) QueryRowContext(ctx context.Context) RowScanner { 35 | if d.RunWith == nil { 36 | return &Row{err: RunnerNotSet} 37 | } 38 | if rower, ok := d.RunWith.(QueryRowerContext); ok { 39 | return QueryRowContextWith(ctx, rower, d) 40 | } 41 | // Not a QueryRowerContext: report NoContextSupport only if it is at least a QueryerContext. 42 | if _, isQueryer := d.RunWith.(QueryerContext); !isQueryer { 43 | return &Row{err: RunnerNotQueryRunner} 44 | } 45 | return &Row{err:
NoContextSupport} 46 | } 47 | 48 | // ExecContext builds the query and executes it with ctx on the Runner set by RunWith. 49 | func (b InsertBuilder) ExecContext(ctx context.Context) (sql.Result, error) { 50 | d := builder.GetStruct(b).(insertData) 51 | return d.ExecContext(ctx) 52 | } 53 | 54 | // QueryContext builds the query and runs it with ctx on the Runner set by RunWith. 55 | func (b InsertBuilder) QueryContext(ctx context.Context) (*sql.Rows, error) { 56 | d := builder.GetStruct(b).(insertData) 57 | return d.QueryContext(ctx) 58 | } 59 | 60 | // QueryRowContext builds the query and runs it with ctx on the Runner set by RunWith, returning a RowScanner. 61 | func (b InsertBuilder) QueryRowContext(ctx context.Context) RowScanner { 62 | d := builder.GetStruct(b).(insertData) 63 | return d.QueryRowContext(ctx) 64 | } 65 | 66 | // ScanContext is a shortcut for QueryRowContext(ctx).Scan(dest...). 67 | func (b InsertBuilder) ScanContext(ctx context.Context, dest ...interface{}) error { 68 | return b.QueryRowContext(ctx).Scan(dest...)
69 | } 70 | -------------------------------------------------------------------------------- /select_ctx.go: -------------------------------------------------------------------------------- 1 | // +build go1.8 2 | 3 | package squirrel 4 | 5 | import ( 6 | "context" 7 | "database/sql" 8 | 9 | "github.com/lann/builder" 10 | ) 11 | 12 | func (d *selectData) ExecContext(ctx context.Context) (sql.Result, error) { 13 | if d.RunWith == nil { 14 | return nil, RunnerNotSet 15 | } 16 | // Context-aware execution requires a runner implementing ExecerContext. 17 | if runner, ok := d.RunWith.(ExecerContext); ok { 18 | return ExecContextWith(ctx, runner, d) 19 | } 20 | return nil, NoContextSupport 21 | } 22 | 23 | func (d *selectData) QueryContext(ctx context.Context) (*sql.Rows, error) { 24 | if d.RunWith == nil { 25 | return nil, RunnerNotSet 26 | } 27 | // Context-aware querying requires a runner implementing QueryerContext. 28 | if runner, ok := d.RunWith.(QueryerContext); ok { 29 | return QueryContextWith(ctx, runner, d) 30 | } 31 | return nil, NoContextSupport 32 | } 33 | 34 | func (d *selectData) QueryRowContext(ctx context.Context) RowScanner { 35 | if d.RunWith == nil { 36 | return &Row{err: RunnerNotSet} 37 | } 38 | if rower, ok := d.RunWith.(QueryRowerContext); ok { 39 | return QueryRowContextWith(ctx, rower, d) 40 | } 41 | // Not a QueryRowerContext: report NoContextSupport only if it is at least a QueryerContext. 42 | if _, isQueryer := d.RunWith.(QueryerContext); !isQueryer { 43 | return &Row{err: RunnerNotQueryRunner} 44 | } 45 | return &Row{err: NoContextSupport} 46 | } 47 | 48 | // ExecContext builds the query and executes it with ctx on the Runner set by RunWith. 49 | func (b SelectBuilder) ExecContext(ctx context.Context) (sql.Result, error) { 50 | d := builder.GetStruct(b).(selectData) 51 | return d.ExecContext(ctx) 52 | } 53 | 54 | // QueryContext builds the query and runs it with ctx on the Runner set by RunWith. 55 | func (b SelectBuilder) QueryContext(ctx context.Context) (*sql.Rows, error) { 56 | d := builder.GetStruct(b).(selectData) 57 | return d.QueryContext(ctx) 58 | } 59 | 60 | // QueryRowContext builds the query and runs it with ctx on the Runner set by RunWith, returning a RowScanner.
61 | func (b SelectBuilder) QueryRowContext(ctx context.Context) RowScanner { 62 | d := builder.GetStruct(b).(selectData) 63 | return d.QueryRowContext(ctx) 64 | } 65 | 66 | // ScanContext is a shortcut for QueryRowContext(ctx).Scan(dest...). 67 | func (b SelectBuilder) ScanContext(ctx context.Context, dest ...interface{}) error { 68 | return b.QueryRowContext(ctx).Scan(dest...) 69 | } 70 | -------------------------------------------------------------------------------- /update_ctx.go: -------------------------------------------------------------------------------- 1 | // +build go1.8 2 | 3 | package squirrel 4 | 5 | import ( 6 | "context" 7 | "database/sql" 8 | 9 | "github.com/lann/builder" 10 | ) 11 | 12 | func (d *updateData) ExecContext(ctx context.Context) (sql.Result, error) { 13 | if d.RunWith == nil { 14 | return nil, RunnerNotSet 15 | } 16 | // Context-aware execution requires a runner implementing ExecerContext. 17 | if runner, ok := d.RunWith.(ExecerContext); ok { 18 | return ExecContextWith(ctx, runner, d) 19 | } 20 | return nil, NoContextSupport 21 | } 22 | 23 | func (d *updateData) QueryContext(ctx context.Context) (*sql.Rows, error) { 24 | if d.RunWith == nil { 25 | return nil, RunnerNotSet 26 | } 27 | // Context-aware querying requires a runner implementing QueryerContext. 28 | if runner, ok := d.RunWith.(QueryerContext); ok { 29 | return QueryContextWith(ctx, runner, d) 30 | } 31 | return nil, NoContextSupport 32 | } 33 | 34 | func (d *updateData) QueryRowContext(ctx context.Context) RowScanner { 35 | if d.RunWith == nil { 36 | return &Row{err: RunnerNotSet} 37 | } 38 | if rower, ok := d.RunWith.(QueryRowerContext); ok { 39 | return QueryRowContextWith(ctx, rower, d) 40 | } 41 | // Not a QueryRowerContext: report NoContextSupport only if it is at least a QueryerContext. 42 | if _, isQueryer := d.RunWith.(QueryerContext); !isQueryer { 43 | return &Row{err: RunnerNotQueryRunner} 44 | } 45 | return &Row{err: NoContextSupport} 46 | } 47 | 48 | // ExecContext builds the query and executes it with ctx on the Runner set by RunWith.
49 | func (b UpdateBuilder) ExecContext(ctx context.Context) (sql.Result, error) { 50 | d := builder.GetStruct(b).(updateData) 51 | return d.ExecContext(ctx) 52 | } 53 | 54 | // QueryContext builds the query and runs it with ctx on the Runner set by RunWith. 55 | func (b UpdateBuilder) QueryContext(ctx context.Context) (*sql.Rows, error) { 56 | d := builder.GetStruct(b).(updateData) 57 | return d.QueryContext(ctx) 58 | } 59 | 60 | // QueryRowContext builds the query and runs it with ctx on the Runner set by RunWith, returning a RowScanner. 61 | func (b UpdateBuilder) QueryRowContext(ctx context.Context) RowScanner { 62 | d := builder.GetStruct(b).(updateData) 63 | return d.QueryRowContext(ctx) 64 | } 65 | 66 | // ScanContext is a shortcut for QueryRowContext(ctx).Scan(dest...). 67 | func (b UpdateBuilder) ScanContext(ctx context.Context, dest ...interface{}) error { 68 | return b.QueryRowContext(ctx).Scan(dest...) 69 | } 70 | -------------------------------------------------------------------------------- /placeholder_test.go: -------------------------------------------------------------------------------- 1 | package squirrel 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestQuestion(t *testing.T) { 11 | sql := "x = ? AND y = ?" 12 | s, _ := Question.ReplacePlaceholders(sql) 13 | assert.Equal(t, sql, s) 14 | } 15 | 16 | func TestDollar(t *testing.T) { 17 | sql := "x = ? AND y = ?" 18 | s, _ := Dollar.ReplacePlaceholders(sql) 19 | assert.Equal(t, "x = $1 AND y = $2", s) 20 | } 21 | 22 | func TestColon(t *testing.T) { 23 | sql := "x = ? AND y = ?" 24 | s, _ := Colon.ReplacePlaceholders(sql) 25 | assert.Equal(t, "x = :1 AND y = :2", s) 26 | } 27 | 28 | func TestAtp(t *testing.T) { 29 | sql := "x = ? AND y = ?"
30 | s, _ := AtP.ReplacePlaceholders(sql) 31 | assert.Equal(t, "x = @p1 AND y = @p2", s) 32 | } 33 | 34 | func TestPlaceholders(t *testing.T) { 35 | assert.Equal(t, "?,?", Placeholders(2)) 36 | } 37 | 38 | func TestEscapeDollar(t *testing.T) { 39 | sql := "SELECT uuid, \"data\" #> '{tags}' AS tags FROM nodes WHERE \"data\" -> 'tags' ??| array['?'] AND enabled = ?" 40 | s, _ := Dollar.ReplacePlaceholders(sql) 41 | assert.Equal(t, "SELECT uuid, \"data\" #> '{tags}' AS tags FROM nodes WHERE \"data\" -> 'tags' ?| array['$1'] AND enabled = $2", s) 42 | } 43 | 44 | func TestEscapeColon(t *testing.T) { 45 | sql := "SELECT uuid, \"data\" #> '{tags}' AS tags FROM nodes WHERE \"data\" -> 'tags' ??| array['?'] AND enabled = ?" 46 | s, _ := Colon.ReplacePlaceholders(sql) 47 | assert.Equal(t, "SELECT uuid, \"data\" #> '{tags}' AS tags FROM nodes WHERE \"data\" -> 'tags' ?| array[':1'] AND enabled = :2", s) 48 | } 49 | 50 | func TestEscapeAtp(t *testing.T) { 51 | sql := "SELECT uuid, \"data\" #> '{tags}' AS tags FROM nodes WHERE \"data\" -> 'tags' ??| array['?'] AND enabled = ?" 52 | s, _ := AtP.ReplacePlaceholders(sql) 53 | assert.Equal(t, "SELECT uuid, \"data\" #> '{tags}' AS tags FROM nodes WHERE \"data\" -> 'tags' ?| array['@p1'] AND enabled = @p2", s) 54 | } 55 | 56 | func BenchmarkPlaceholdersArray(b *testing.B) { 57 | var count = b.N 58 | placeholders := make([]string, count) 59 | for i := 0; i < count; i++ { 60 | placeholders[i] = "?"
61 | } 62 | var _ = strings.Join(placeholders, ",") 63 | } 64 | 65 | func BenchmarkPlaceholdersStrings(b *testing.B) { 66 | Placeholders(b.N) 67 | } 68 | -------------------------------------------------------------------------------- /statement_test.go: -------------------------------------------------------------------------------- 1 | package squirrel 2 | 3 | import ( 4 | "database/sql" 5 | "testing" 6 | 7 | "github.com/lann/builder" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestStatementBuilder(t *testing.T) { 12 | db := &DBStub{} 13 | sb := StatementBuilder.RunWith(db) 14 | 15 | sb.Select("test").Exec() 16 | assert.Equal(t, "SELECT test", db.LastExecSql) 17 | } 18 | 19 | func TestStatementBuilderPlaceholderFormat(t *testing.T) { 20 | db := &DBStub{} 21 | sb := StatementBuilder.RunWith(db).PlaceholderFormat(Dollar) 22 | 23 | sb.Select("test").Where("x = ?").Exec() 24 | assert.Equal(t, "SELECT test WHERE x = $1", db.LastExecSql) 25 | } 26 | 27 | func TestRunWithDB(t *testing.T) { 28 | db := &sql.DB{} 29 | assert.NotPanics(t, func() { 30 | builder.GetStruct(Select().RunWith(db)) 31 | builder.GetStruct(Insert("t").RunWith(db)) 32 | builder.GetStruct(Update("t").RunWith(db)) 33 | builder.GetStruct(Delete("t").RunWith(db)) 34 | }, "RunWith(*sql.DB) should not panic") 35 | 36 | } 37 | 38 | func TestRunWithTx(t *testing.T) { 39 | tx := &sql.Tx{} 40 | assert.NotPanics(t, func() { 41 | builder.GetStruct(Select().RunWith(tx)) 42 | builder.GetStruct(Insert("t").RunWith(tx)) 43 | builder.GetStruct(Update("t").RunWith(tx)) 44 | builder.GetStruct(Delete("t").RunWith(tx)) 45 | }, "RunWith(*sql.Tx) should not panic") 46 | } 47 | 48 | type fakeBaseRunner struct{} 49 | 50 | func (fakeBaseRunner) Exec(query string, args ...interface{}) (sql.Result, error) { 51 | return nil, nil 52 | } 53 | 54 | func (fakeBaseRunner) Query(query string, args ...interface{}) (*sql.Rows, error) { 55 | return nil, nil 56 | } 57 | 58 | func TestRunWithBaseRunner(t *testing.T) { 59
| sb := StatementBuilder.RunWith(fakeBaseRunner{}) 60 | _, err := sb.Select("test").Exec() 61 | assert.NoError(t, err) 62 | } 63 | 64 | func TestRunWithBaseRunnerQueryRowError(t *testing.T) { 65 | sb := StatementBuilder.RunWith(fakeBaseRunner{}) 66 | assert.Error(t, RunnerNotQueryRunner, sb.Select("test").QueryRow().Scan(nil)) 67 | 68 | } 69 | 70 | func TestStatementBuilderWhere(t *testing.T) { 71 | sb := StatementBuilder.Where("x = ?", 1) 72 | 73 | sql, args, err := sb.Select("test").Where("y = ?", 2).ToSql() 74 | assert.NoError(t, err) 75 | 76 | expectedSql := "SELECT test WHERE x = ? AND y = ?" 77 | assert.Equal(t, expectedSql, sql) 78 | 79 | expectedArgs := []interface{}{1, 2} 80 | assert.Equal(t, expectedArgs, args) 81 | } 82 | -------------------------------------------------------------------------------- /delete_test.go: -------------------------------------------------------------------------------- 1 | package squirrel 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestDeleteBuilderToSql(t *testing.T) { 10 | b := Delete(""). 11 | Prefix("WITH prefix AS ?", 0). 12 | From("a"). 13 | JoinClause("CROSS JOIN j1"). 14 | Join("j2"). 15 | LeftJoin("j3"). 16 | RightJoin("j4"). 17 | InnerJoin("j5"). 18 | CrossJoin("j6"). 19 | Where("b = ?", 1). 20 | OrderBy("c"). 21 | Limit(2). 22 | Offset(3). 23 | Suffix("RETURNING ?", 4) 24 | 25 | sql, args, err := b.ToSql() 26 | assert.NoError(t, err) 27 | 28 | expectedSql := 29 | "WITH prefix AS ? " + 30 | "DELETE a FROM a " + 31 | "CROSS JOIN j1 " + 32 | "JOIN j2 " + 33 | "LEFT JOIN j3 " + 34 | "RIGHT JOIN j4 " + 35 | "INNER JOIN j5 " + 36 | "CROSS JOIN j6 " + 37 | "WHERE b = ? ORDER BY c LIMIT 2 OFFSET 3 " + 38 | "RETURNING ?" 
39 | assert.Equal(t, expectedSql, sql) 40 | 41 | expectedArgs := []interface{}{0, 1, 4} 42 | assert.Equal(t, expectedArgs, args) 43 | } 44 | 45 | func TestDeleteBuilderToSqlErr(t *testing.T) { 46 | _, _, err := Delete("").ToSql() 47 | assert.Error(t, err) 48 | } 49 | 50 | func TestDeleteBuilderMustSql(t *testing.T) { 51 | defer func() { 52 | if r := recover(); r == nil { 53 | t.Errorf("TestDeleteBuilderMustSql should have panicked!") 54 | } 55 | }() 56 | Delete("").MustSql() 57 | } 58 | 59 | func TestDeleteBuilderPlaceholders(t *testing.T) { 60 | b := Delete("test").Where("x = ? AND y = ?", 1, 2) 61 | 62 | sql, _, _ := b.PlaceholderFormat(Question).ToSql() 63 | assert.Equal(t, "DELETE test FROM test WHERE x = ? AND y = ?", sql) 64 | 65 | sql, _, _ = b.PlaceholderFormat(Dollar).ToSql() 66 | assert.Equal(t, "DELETE test FROM test WHERE x = $1 AND y = $2", sql) 67 | } 68 | 69 | func TestDeleteBuilderRunners(t *testing.T) { 70 | db := &DBStub{} 71 | b := Delete("test").Where("x = ?", 1).RunWith(db) 72 | 73 | expectedSql := "DELETE test FROM test WHERE x = ?" 74 | 75 | b.Exec() 76 | assert.Equal(t, expectedSql, db.LastExecSql) 77 | } 78 | 79 | func TestDeleteBuilderNoRunner(t *testing.T) { 80 | b := Delete("test") 81 | 82 | _, err := b.Exec() 83 | assert.Equal(t, RunnerNotSet, err) 84 | } 85 | 86 | func TestDeleteWithQuery(t *testing.T) { 87 | db := &DBStub{} 88 | b := Delete("test").Where("id=55").Suffix("RETURNING path").RunWith(db) 89 | 90 | expectedSql := "DELETE test FROM test WHERE id=55 RETURNING path" 91 | b.Query() 92 | 93 | assert.Equal(t, expectedSql, db.LastQuerySql) 94 | } 95 | -------------------------------------------------------------------------------- /stmtcacher_ctx.go: -------------------------------------------------------------------------------- 1 | // +build go1.8 2 | 3 | package squirrel 4 | 5 | import ( 6 | "context" 7 | "database/sql" 8 | ) 9 | 10 | // PrepareerContext is the interface that wraps the Prepare and PrepareContext methods. 
11 | // 12 | // Prepare executes the given query as implemented by database/sql.Prepare. 13 | // PrepareContext executes the given query as implemented by database/sql.PrepareContext. 14 | type PreparerContext interface { 15 | Preparer 16 | PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) 17 | } 18 | 19 | // DBProxyContext groups the Execer, Queryer, QueryRower and PreparerContext interfaces. 20 | type DBProxyContext interface { 21 | Execer 22 | Queryer 23 | QueryRower 24 | PreparerContext 25 | } 26 | 27 | // NewStmtCache returns a *StmtCache wrapping a PreparerContext that caches Prepared Stmts. 28 | // 29 | // Stmts are cached based on the string value of their queries. 30 | func NewStmtCache(prep PreparerContext) *StmtCache { 31 | return &StmtCache{prep: prep, cache: make(map[string]*sql.Stmt)} 32 | } 33 | 34 | // NewStmtCacher is deprecated 35 | // 36 | // Use NewStmtCache instead 37 | func NewStmtCacher(prep PreparerContext) DBProxyContext { 38 | return NewStmtCache(prep) 39 | } 40 | 41 | // PrepareContext delegates down to the underlying PreparerContext and caches the result 42 | // using the provided query as a key 43 | func (sc *StmtCache) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { 44 | ctxPrep, ok := sc.prep.(PreparerContext) 45 | if !ok { 46 | return nil, NoContextSupport 47 | } 48 | sc.mu.Lock() 49 | defer sc.mu.Unlock() 50 | stmt, ok := sc.cache[query] 51 | if ok { 52 | return stmt, nil 53 | } 54 | stmt, err := ctxPrep.PrepareContext(ctx, query) 55 | if err == nil { 56 | sc.cache[query] = stmt 57 | } 58 | return stmt, err 59 | } 60 | 61 | // ExecContext delegates down to the underlying PreparerContext using a prepared statement 62 | func (sc *StmtCache) ExecContext(ctx context.Context, query string, args ...interface{}) (res sql.Result, err error) { 63 | stmt, err := sc.PrepareContext(ctx, query) 64 | if err != nil { 65 | return 66 | } 67 | return stmt.ExecContext(ctx, args...) 
68 | } 69 | 70 | // QueryContext delegates down to the underlying PreparerContext using a prepared statement 71 | func (sc *StmtCache) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { 72 | stmt, err := sc.PrepareContext(ctx, query) 73 | if err != nil { 74 | return 75 | } 76 | return stmt.QueryContext(ctx, args...) 77 | } 78 | 79 | // QueryRowContext delegates down to the underlying PreparerContext using a prepared statement 80 | func (sc *StmtCache) QueryRowContext(ctx context.Context, query string, args ...interface{}) RowScanner { 81 | stmt, err := sc.PrepareContext(ctx, query) 82 | if err != nil { 83 | return &Row{err: err} 84 | } 85 | return stmt.QueryRowContext(ctx, args...) 86 | } 87 | -------------------------------------------------------------------------------- /placeholder.go: -------------------------------------------------------------------------------- 1 | package squirrel 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "strings" 7 | ) 8 | 9 | // PlaceholderFormat is the interface that wraps the ReplacePlaceholders method. 10 | // 11 | // ReplacePlaceholders takes a SQL statement and replaces each question mark 12 | // placeholder with a (possibly different) SQL placeholder. 13 | type PlaceholderFormat interface { 14 | ReplacePlaceholders(sql string) (string, error) 15 | } 16 | 17 | type placeholderDebugger interface { 18 | debugPlaceholder() string 19 | } 20 | 21 | var ( 22 | // Question is a PlaceholderFormat instance that leaves placeholders as 23 | // question marks. 24 | Question = questionFormat{} 25 | 26 | // Dollar is a PlaceholderFormat instance that replaces placeholders with 27 | // dollar-prefixed positional placeholders (e.g. $1, $2, $3). 28 | Dollar = dollarFormat{} 29 | 30 | // Colon is a PlaceholderFormat instance that replaces placeholders with 31 | // colon-prefixed positional placeholders (e.g. :1, :2, :3). 
32 | Colon = colonFormat{} 33 | 34 | // AtP is a PlaceholderFormat instance that replaces placeholders with 35 | // "@p"-prefixed positional placeholders (e.g. @p1, @p2, @p3). 36 | AtP = atpFormat{} 37 | ) 38 | 39 | type questionFormat struct{} 40 | 41 | func (questionFormat) ReplacePlaceholders(sql string) (string, error) { 42 | return sql, nil 43 | } 44 | 45 | func (questionFormat) debugPlaceholder() string { 46 | return "?" 47 | } 48 | 49 | type dollarFormat struct{} 50 | 51 | func (dollarFormat) ReplacePlaceholders(sql string) (string, error) { 52 | return replacePositionalPlaceholders(sql, "$") 53 | } 54 | 55 | func (dollarFormat) debugPlaceholder() string { 56 | return "$" 57 | } 58 | 59 | type colonFormat struct{} 60 | 61 | func (colonFormat) ReplacePlaceholders(sql string) (string, error) { 62 | return replacePositionalPlaceholders(sql, ":") 63 | } 64 | 65 | func (colonFormat) debugPlaceholder() string { 66 | return ":" 67 | } 68 | 69 | type atpFormat struct{} 70 | 71 | func (atpFormat) ReplacePlaceholders(sql string) (string, error) { 72 | return replacePositionalPlaceholders(sql, "@p") 73 | } 74 | 75 | func (atpFormat) debugPlaceholder() string { 76 | return "@p" 77 | } 78 | 79 | // Placeholders returns a string with count ? placeholders joined with commas. 80 | func Placeholders(count int) string { 81 | if count < 1 { 82 | return "" 83 | } 84 | 85 | return strings.Repeat(",?", count)[1:] 86 | } 87 | 88 | func replacePositionalPlaceholders(sql, prefix string) (string, error) { 89 | buf := &bytes.Buffer{} 90 | i := 0 91 | for { 92 | p := strings.Index(sql, "?") 93 | if p == -1 { 94 | break 95 | } 96 | 97 | if len(sql[p:]) > 1 && sql[p:p+2] == "??" { // escape ?? => ? 
98 | buf.WriteString(sql[:p]) 99 | buf.WriteString("?") 100 | if len(sql[p:]) == 1 { 101 | break 102 | } 103 | sql = sql[p+2:] 104 | } else { 105 | i++ 106 | buf.WriteString(sql[:p]) 107 | fmt.Fprintf(buf, "%s%d", prefix, i) 108 | sql = sql[p+1:] 109 | } 110 | } 111 | 112 | buf.WriteString(sql) 113 | return buf.String(), nil 114 | } 115 | -------------------------------------------------------------------------------- /integration/go.sum: -------------------------------------------------------------------------------- 1 | github.com/Masterminds/squirrel v1.1.0 h1:baP1qLdoQCeTw3ifCdOq2dkYc6vGcmRdaociKLbEJXs= 2 | github.com/Masterminds/squirrel v1.1.0/go.mod h1:yaPeOnPG5ZRwL9oKdTsO/prlkPbXWZlRVMQ/gGlzIuA= 3 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 5 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= 7 | github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= 8 | github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 9 | github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 h1:SOEGU9fKiNWd/HOJuq6+3iTQz8KNCLtVX6idSoTLdUw= 10 | github.com/lann/builder v0.0.0-20180802200727-47ae307949d0/go.mod h1:dXGbAdH5GtBTC4WfIxhKZfyBF/HBFgRZSWwZ9g/He9o= 11 | github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 h1:P6pPBnrTSX3DEVR4fDembhRWSsG5rVo6hYhAB/ADZrk= 12 | github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0/go.mod h1:vmVJ0l/dxyfGW6FmdpVm2joNMFikkuWg0EoCKLGUMNw= 13 | github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0= 14 | github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= 15 | github.com/mattn/go-sqlite3 v1.13.0 h1:LnJI81JidiW9r7pS/hXe6cFeO5EXNq7KbfvoJLRI69c= 16 | 
github.com/mattn/go-sqlite3 v1.13.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= 17 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 18 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 19 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 20 | github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 21 | github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= 22 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= 23 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 24 | golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= 25 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 26 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 27 | golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= 28 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 29 | google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM= 30 | google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= 31 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 32 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 33 | gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= 34 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 35 | -------------------------------------------------------------------------------- /insert_test.go: 
--------------------------------------------------------------------------------
1 | package squirrel
2 |
3 | import (
4 | 	"testing"
5 |
6 | 	"github.com/stretchr/testify/assert"
7 | )
8 | // TestInsertBuilderToSql builds an INSERT with prefix, options, columns, two value rows (one using Expr) and a suffix, and checks the SQL and argument order.
9 | func TestInsertBuilderToSql(t *testing.T) {
10 | 	b := Insert("").
11 | 		Prefix("WITH prefix AS ?", 0).
12 | 		Into("a").
13 | 		Options("DELAYED", "IGNORE").
14 | 		Columns("b", "c").
15 | 		Values(1, 2).
16 | 		Values(3, Expr("? + 1", 4)).
17 | 		Suffix("RETURNING ?", 5)
18 |
19 | 	sql, args, err := b.ToSql()
20 | 	assert.NoError(t, err)
21 |
22 | 	expectedSQL :=
23 | 		"WITH prefix AS ? " +
24 | 			"INSERT DELAYED IGNORE INTO a (b,c) VALUES (?,?),(?,? + 1) " +
25 | 			"RETURNING ?"
26 | 	assert.Equal(t, expectedSQL, sql)
27 |
28 | 	expectedArgs := []interface{}{0, 1, 2, 3, 4, 5}
29 | 	assert.Equal(t, expectedArgs, args)
30 | }
31 | // TestInsertBuilderToSqlErr expects ToSql to fail both with an empty table name and with no values.
32 | func TestInsertBuilderToSqlErr(t *testing.T) {
33 | 	_, _, err := Insert("").Values(1).ToSql()
34 | 	assert.Error(t, err)
35 |
36 | 	_, _, err = Insert("x").ToSql()
37 | 	assert.Error(t, err)
38 | }
39 | // TestInsertBuilderMustSql expects MustSql to panic for an INSERT with an empty table name.
40 | func TestInsertBuilderMustSql(t *testing.T) {
41 | 	defer func() {
42 | 		if r := recover(); r == nil {
43 | 			t.Errorf("TestInsertBuilderMustSql should have panicked!")
44 | 		}
45 | 	}()
46 | 	Insert("").MustSql()
47 | }
48 | // TestInsertBuilderPlaceholders renders the same INSERT with the Question and Dollar formats.
49 | func TestInsertBuilderPlaceholders(t *testing.T) {
50 | 	b := Insert("test").Values(1, 2)
51 |
52 | 	sql, _, _ := b.PlaceholderFormat(Question).ToSql()
53 | 	assert.Equal(t, "INSERT INTO test VALUES (?,?)", sql)
54 |
55 | 	sql, _, _ = b.PlaceholderFormat(Dollar).ToSql()
56 | 	assert.Equal(t, "INSERT INTO test VALUES ($1,$2)", sql)
57 | }
58 | // TestInsertBuilderRunners checks that Exec sends the generated SQL to the RunWith runner (recorded by DBStub).
59 | func TestInsertBuilderRunners(t *testing.T) {
60 | 	db := &DBStub{}
61 | 	b := Insert("test").Values(1).RunWith(db)
62 |
63 | 	expectedSQL := "INSERT INTO test VALUES (?)"
64 |
65 | 	b.Exec()
66 | 	assert.Equal(t, expectedSQL, db.LastExecSql)
67 | }
68 | // TestInsertBuilderNoRunner expects Exec to return RunnerNotSet when RunWith was never called.
69 | func TestInsertBuilderNoRunner(t *testing.T) {
70 | 	b := Insert("test").Values(1)
71 |
72 | 	_, err := b.Exec()
73 | 	assert.Equal(t, RunnerNotSet, err)
74 | }
75 | // TestInsertBuilderSetMap expects SetMap's columns and args in field1, field2, field3 order.
76 | func TestInsertBuilderSetMap(t *testing.T) {
77 | 	b := Insert("table").SetMap(Eq{"field1": 1, "field2": 2, "field3": 3})
78 |
79 | 	sql, args, err := b.ToSql()
80 | 	assert.NoError(t, err)
81 |
82 | 	expectedSQL := "INSERT INTO table (field1,field2,field3) VALUES (?,?,?)"
83 | 	assert.Equal(t, expectedSQL, sql)
84 |
85 | 	expectedArgs := []interface{}{1, 2, 3}
86 | 	assert.Equal(t, expectedArgs, args)
87 | }
88 | // TestInsertBuilderSelect checks INSERT ... SELECT, including propagation of the SELECT's args.
89 | func TestInsertBuilderSelect(t *testing.T) {
90 | 	sb := Select("field1").From("table1").Where(Eq{"field1": 1})
91 | 	ib := Insert("table2").Columns("field1").Select(sb)
92 |
93 | 	sql, args, err := ib.ToSql()
94 | 	assert.NoError(t, err)
95 |
96 | 	expectedSQL := "INSERT INTO table2 (field1) SELECT field1 FROM table1 WHERE field1 = ?"
97 | 	assert.Equal(t, expectedSQL, sql)
98 |
99 | 	expectedArgs := []interface{}{1}
100 | 	assert.Equal(t, expectedArgs, args)
101 | }
102 | // TestInsertBuilderReplace checks that Replace emits the REPLACE keyword instead of INSERT.
103 | func TestInsertBuilderReplace(t *testing.T) {
104 | 	b := Replace("table").Values(1)
105 |
106 | 	expectedSQL := "REPLACE INTO table VALUES (?)"
107 |
108 | 	sql, _, err := b.ToSql()
109 | 	assert.NoError(t, err)
110 |
111 | 	assert.Equal(t, expectedSQL, sql)
112 | }
113 |
--------------------------------------------------------------------------------
/stmtcacher.go:
--------------------------------------------------------------------------------
1 | package squirrel
2 |
3 | import (
4 | 	"database/sql"
5 | 	"fmt"
6 | 	"sync"
7 | )
8 |
9 | // Preparer is the interface that wraps the Prepare method.
10 | //
11 | // Prepare executes the given query as implemented by database/sql.Prepare.
12 | type Preparer interface {
13 | 	Prepare(query string) (*sql.Stmt, error)
14 | }
15 |
16 | // DBProxy groups the Execer, Queryer, QueryRower, and Preparer interfaces.
17 | type DBProxy interface {
18 | 	Execer
19 | 	Queryer
20 | 	QueryRower
21 | 	Preparer
22 | }
23 |
24 | // NOTE: NewStmtCache is defined in stmtcacher_ctx.go (Go >= 1.8) or stmtcacher_noctx.go (Go < 1.8).
25 |
26 | // StmtCache wraps and delegates down to a Preparer type.
27 | //
28 | // Exec, Query and QueryRow automatically prepare their query on the underlying
29 | // Preparer and cache the resulting *sql.Stmt, keyed by the query string, so
30 | // later calls with the same query re-use the prepared statement.
31 | type StmtCache struct {
32 | 	prep  Preparer
33 | 	cache map[string]*sql.Stmt
34 | 	mu    sync.Mutex
35 | }
36 |
37 | // Prepare delegates down to the underlying Preparer and caches the result
38 | // using the provided query as a key
39 | func (sc *StmtCache) Prepare(query string) (*sql.Stmt, error) {
40 | 	sc.mu.Lock()
41 | 	defer sc.mu.Unlock()
42 | 	// mu stays held across the underlying Prepare call, so concurrent cache misses are serialized.
43 | 	stmt, ok := sc.cache[query]
44 | 	if ok {
45 | 		return stmt, nil
46 | 	}
47 | 	stmt, err := sc.prep.Prepare(query)
48 | 	if err == nil {
49 | 		sc.cache[query] = stmt
50 | 	}
51 | 	return stmt, err
52 | }
53 |
54 | // Exec delegates down to the underlying Preparer using a prepared statement
55 | func (sc *StmtCache) Exec(query string, args ...interface{}) (res sql.Result, err error) {
56 | 	stmt, err := sc.Prepare(query)
57 | 	if err != nil {
58 | 		return
59 | 	}
60 | 	return stmt.Exec(args...)
61 | }
62 |
63 | // Query delegates down to the underlying Preparer using a prepared statement
64 | func (sc *StmtCache) Query(query string, args ...interface{}) (rows *sql.Rows, err error) {
65 | 	stmt, err := sc.Prepare(query)
66 | 	if err != nil {
67 | 		return
68 | 	}
69 | 	return stmt.Query(args...)
70 | }
71 |
72 | // QueryRow delegates down to the underlying Preparer using a prepared statement
73 | func (sc *StmtCache) QueryRow(query string, args ...interface{}) RowScanner {
74 | 	stmt, err := sc.Prepare(query)
75 | 	if err != nil {
76 | 		return &Row{err: err}
77 | 	}
78 | 	return stmt.QueryRow(args...)
79 | }
80 |
81 | // Clear removes and closes all the currently cached prepared statements; if several Close calls fail, only the last error is reported.
82 | func (sc *StmtCache) Clear() (err error) {
83 | 	sc.mu.Lock()
84 | 	defer sc.mu.Unlock()
85 |
86 | 	for key, stmt := range sc.cache {
87 | 		delete(sc.cache, key)
88 |
89 | 		if stmt == nil {
90 | 			continue
91 | 		}
92 |
93 | 		if cerr := stmt.Close(); cerr != nil {
94 | 			err = cerr
95 | 		}
96 | 	}
97 |
98 | 	if err != nil {
99 | 		return fmt.Errorf("one or more Stmt.Close failed; last error: %v", err)
100 | 	}
101 |
102 | 	return
103 | }
104 | // DBProxyBeginner is a DBProxy that can also begin transactions.
105 | type DBProxyBeginner interface {
106 | 	DBProxy
107 | 	Begin() (*sql.Tx, error)
108 | }
109 | // stmtCacheProxy serves Exec/Query/QueryRow/Prepare through the embedded DBProxy and forwards Begin to db.
110 | type stmtCacheProxy struct {
111 | 	DBProxy
112 | 	db *sql.DB
113 | }
114 | // NewStmtCacheProxy wraps db in a StmtCache; transactions returned by Begin come straight from db and do not use the cache.
115 | func NewStmtCacheProxy(db *sql.DB) DBProxyBeginner {
116 | 	return &stmtCacheProxy{DBProxy: NewStmtCache(db), db: db}
117 | }
118 | // Begin starts a transaction on the underlying *sql.DB.
119 | func (sp *stmtCacheProxy) Begin() (*sql.Tx, error) {
120 | 	return sp.db.Begin()
121 | }
122 |
--------------------------------------------------------------------------------
/squirrel_ctx.go:
--------------------------------------------------------------------------------
1 | // +build go1.8
2 |
3 | package squirrel
4 |
5 | import (
6 | 	"context"
7 | 	"database/sql"
8 | 	"errors"
9 | )
10 |
11 | // NoContextSupport is returned if a db doesn't support Context.
12 | var NoContextSupport = errors.New("DB does not support Context")
13 |
14 | // ExecerContext is the interface that wraps the ExecContext method.
15 | //
16 | // ExecContext executes the given query as implemented by database/sql.ExecContext.
17 | type ExecerContext interface {
18 | 	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
19 | }
20 |
21 | // QueryerContext is the interface that wraps the QueryContext method.
22 | //
23 | // QueryContext executes the given query as implemented by database/sql.QueryContext.
24 | type QueryerContext interface {
25 | 	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
26 | }
27 |
28 | // QueryRowerContext is the interface that wraps the QueryRowContext method.
29 | //
30 | // QueryRowContext executes the given query as implemented by database/sql.QueryRowContext.
31 | type QueryRowerContext interface {
32 | 	QueryRowContext(ctx context.Context, query string, args ...interface{}) RowScanner
33 | }
34 |
35 | // RunnerContext groups the Runner interface, along with the Context versions of each of
36 | // its methods.
37 | type RunnerContext interface {
38 | 	Runner
39 | 	QueryerContext
40 | 	QueryRowerContext
41 | 	ExecerContext
42 | }
43 |
44 | // WrapStdSqlCtx wraps a type implementing the standard SQL interface plus the context
45 | // versions of the methods with methods that squirrel expects.
46 | func WrapStdSqlCtx(stdSqlCtx StdSqlCtx) RunnerContext {
47 | 	return &stdsqlCtxRunner{stdSqlCtx}
48 | }
49 |
50 | // StdSqlCtx encompasses the standard methods of the *sql.DB type, along with the Context
51 | // versions of those methods, and other types that wrap these methods.
52 | type StdSqlCtx interface {
53 | 	StdSql
54 | 	QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
55 | 	QueryRowContext(context.Context, string, ...interface{}) *sql.Row
56 | 	ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
57 | }
58 | // stdsqlCtxRunner adapts a StdSqlCtx to RunnerContext by re-declaring QueryRow and QueryRowContext to return RowScanner.
59 | type stdsqlCtxRunner struct {
60 | 	StdSqlCtx
61 | }
62 | // QueryRow forwards to the wrapped StdSqlCtx and returns its row as a RowScanner.
63 | func (r *stdsqlCtxRunner) QueryRow(query string, args ...interface{}) RowScanner {
64 | 	return r.StdSqlCtx.QueryRow(query, args...)
65 | }
66 | // QueryRowContext forwards to the wrapped StdSqlCtx and returns its *sql.Row as a RowScanner.
67 | func (r *stdsqlCtxRunner) QueryRowContext(ctx context.Context, query string, args ...interface{}) RowScanner {
68 | 	return r.StdSqlCtx.QueryRowContext(ctx, query, args...)
69 | }
70 |
71 | // ExecContextWith ExecContexts the SQL returned by s with db.
72 | func ExecContextWith(ctx context.Context, db ExecerContext, s Sqlizer) (res sql.Result, err error) { 73 | query, args, err := s.ToSql() 74 | if err != nil { 75 | return 76 | } 77 | return db.ExecContext(ctx, query, args...) 78 | } 79 | 80 | // QueryContextWith QueryContexts the SQL returned by s with db. 81 | func QueryContextWith(ctx context.Context, db QueryerContext, s Sqlizer) (rows *sql.Rows, err error) { 82 | query, args, err := s.ToSql() 83 | if err != nil { 84 | return 85 | } 86 | return db.QueryContext(ctx, query, args...) 87 | } 88 | 89 | // QueryRowContextWith QueryRowContexts the SQL returned by s with db. 90 | func QueryRowContextWith(ctx context.Context, db QueryRowerContext, s Sqlizer) RowScanner { 91 | query, args, err := s.ToSql() 92 | return &Row{RowScanner: db.QueryRowContext(ctx, query, args...), err: err} 93 | } 94 | -------------------------------------------------------------------------------- /update_test.go: -------------------------------------------------------------------------------- 1 | package squirrel 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestUpdateBuilderToSql(t *testing.T) { 10 | b := Update(""). 11 | Prefix("WITH prefix AS ?", 0). 12 | Table("a"). 13 | JoinClause("CROSS JOIN j1"). 14 | Join("j2"). 15 | LeftJoin("j3"). 16 | RightJoin("j4"). 17 | InnerJoin("j5"). 18 | CrossJoin("j6"). 19 | Set("b", Expr("? + 1", 1)). 20 | SetMap(Eq{"c": 2}). 21 | Set("c1", Case("status").When("1", "2").When("2", "1")). 22 | Set("c2", Case().When("a = 2", Expr("?", "foo")).When("a = 3", Expr("?", "bar"))). 23 | Set("c3", Select("a").From("b")). 24 | Where("d = ?", 3). 25 | OrderBy("e"). 26 | Limit(4). 27 | Offset(5). 28 | Suffix("RETURNING ?", 6) 29 | 30 | sql, args, err := b.ToSql() 31 | assert.NoError(t, err) 32 | 33 | expectedSql := 34 | "WITH prefix AS ? 
" + 35 | "UPDATE a " + 36 | "CROSS JOIN j1 " + 37 | "JOIN j2 " + 38 | "LEFT JOIN j3 " + 39 | "RIGHT JOIN j4 " + 40 | "INNER JOIN j5 " + 41 | "CROSS JOIN j6 " + 42 | "SET b = ? + 1, c = ?, " + 43 | "c1 = CASE status WHEN 1 THEN 2 WHEN 2 THEN 1 END, " + 44 | "c2 = CASE WHEN a = 2 THEN ? WHEN a = 3 THEN ? END, " + 45 | "c3 = (SELECT a FROM b) " + 46 | "WHERE d = ? " + 47 | "ORDER BY e LIMIT 4 OFFSET 5 " + 48 | "RETURNING ?" 49 | assert.Equal(t, expectedSql, sql) 50 | 51 | expectedArgs := []interface{}{0, 1, 2, "foo", "bar", 3, 6} 52 | assert.Equal(t, expectedArgs, args) 53 | } 54 | 55 | func TestUpdateBuilderToSqlErr(t *testing.T) { 56 | _, _, err := Update("").Set("x", 1).ToSql() 57 | assert.Error(t, err) 58 | 59 | _, _, err = Update("x").ToSql() 60 | assert.Error(t, err) 61 | } 62 | 63 | func TestUpdateBuilderMustSql(t *testing.T) { 64 | defer func() { 65 | if r := recover(); r == nil { 66 | t.Errorf("TestUpdateBuilderMustSql should have panicked!") 67 | } 68 | }() 69 | Update("").MustSql() 70 | } 71 | 72 | func TestUpdateBuilderPlaceholders(t *testing.T) { 73 | b := Update("test").SetMap(Eq{"x": 1, "y": 2}) 74 | 75 | sql, _, _ := b.PlaceholderFormat(Question).ToSql() 76 | assert.Equal(t, "UPDATE test SET x = ?, y = ?", sql) 77 | 78 | sql, _, _ = b.PlaceholderFormat(Dollar).ToSql() 79 | assert.Equal(t, "UPDATE test SET x = $1, y = $2", sql) 80 | } 81 | 82 | func TestUpdateBuilderRunners(t *testing.T) { 83 | db := &DBStub{} 84 | b := Update("test").Set("x", 1).RunWith(db) 85 | 86 | expectedSql := "UPDATE test SET x = ?" 
87 | 88 | b.Exec() 89 | assert.Equal(t, expectedSql, db.LastExecSql) 90 | } 91 | 92 | func TestUpdateBuilderNoRunner(t *testing.T) { 93 | b := Update("test").Set("x", 1) 94 | 95 | _, err := b.Exec() 96 | assert.Equal(t, RunnerNotSet, err) 97 | } 98 | 99 | func TestUpdateBuilderFrom(t *testing.T) { 100 | sql, _, err := Update("employees").Set("sales_count", 100).From("accounts").Where("accounts.name = ?", "ACME").ToSql() 101 | assert.NoError(t, err) 102 | assert.Equal(t, "UPDATE employees SET sales_count = ? FROM accounts WHERE accounts.name = ?", sql) 103 | } 104 | 105 | func TestUpdateBuilderFromSelect(t *testing.T) { 106 | sql, _, err := Update("employees"). 107 | Set("sales_count", 100). 108 | FromSelect(Select("id"). 109 | From("accounts"). 110 | Where("accounts.name = ?", "ACME"), "subquery"). 111 | Where("employees.account_id = subquery.id").ToSql() 112 | assert.NoError(t, err) 113 | 114 | expectedSql := 115 | "UPDATE employees " + 116 | "SET sales_count = ? " + 117 | "FROM (SELECT id FROM accounts WHERE accounts.name = ?) 
AS subquery " + 118 | "WHERE employees.account_id = subquery.id" 119 | assert.Equal(t, expectedSql, sql) 120 | } 121 | -------------------------------------------------------------------------------- /case.go: -------------------------------------------------------------------------------- 1 | package squirrel 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | 7 | "github.com/lann/builder" 8 | ) 9 | 10 | func init() { 11 | builder.Register(CaseBuilder{}, caseData{}) 12 | } 13 | 14 | // sqlizerBuffer is a helper that allows to write many Sqlizers one by one 15 | // without constant checks for errors that may come from Sqlizer 16 | type sqlizerBuffer struct { 17 | bytes.Buffer 18 | args []interface{} 19 | err error 20 | } 21 | 22 | // WriteSql converts Sqlizer to SQL strings and writes it to buffer 23 | func (b *sqlizerBuffer) WriteSql(item Sqlizer) { 24 | if b.err != nil { 25 | return 26 | } 27 | 28 | var str string 29 | var args []interface{} 30 | str, args, b.err = nestedToSql(item) 31 | 32 | if b.err != nil { 33 | return 34 | } 35 | 36 | b.WriteString(str) 37 | b.WriteByte(' ') 38 | b.args = append(b.args, args...) 39 | } 40 | 41 | func (b *sqlizerBuffer) ToSql() (string, []interface{}, error) { 42 | return b.String(), b.args, b.err 43 | } 44 | 45 | // whenPart is a helper structure to describe SQLs "WHEN ... THEN ..." 
expression 46 | type whenPart struct { 47 | when Sqlizer 48 | then Sqlizer 49 | } 50 | 51 | func newWhenPart(when interface{}, then interface{}) whenPart { 52 | return whenPart{newPart(when), newPart(then)} 53 | } 54 | 55 | // caseData holds all the data required to build a CASE SQL construct 56 | type caseData struct { 57 | What Sqlizer 58 | WhenParts []whenPart 59 | Else Sqlizer 60 | } 61 | 62 | // ToSql implements Sqlizer 63 | func (d *caseData) ToSql() (sqlStr string, args []interface{}, err error) { 64 | if len(d.WhenParts) == 0 { 65 | err = errors.New("case expression must contain at lease one WHEN clause") 66 | 67 | return 68 | } 69 | 70 | sql := sqlizerBuffer{} 71 | 72 | sql.WriteString("CASE ") 73 | if d.What != nil { 74 | sql.WriteSql(d.What) 75 | } 76 | 77 | for _, p := range d.WhenParts { 78 | sql.WriteString("WHEN ") 79 | sql.WriteSql(p.when) 80 | sql.WriteString("THEN ") 81 | sql.WriteSql(p.then) 82 | } 83 | 84 | if d.Else != nil { 85 | sql.WriteString("ELSE ") 86 | sql.WriteSql(d.Else) 87 | } 88 | 89 | sql.WriteString("END") 90 | 91 | return sql.ToSql() 92 | } 93 | 94 | // CaseBuilder builds SQL CASE construct which could be used as parts of queries. 95 | type CaseBuilder builder.Builder 96 | 97 | // ToSql builds the query into a SQL string and bound args. 98 | func (b CaseBuilder) ToSql() (string, []interface{}, error) { 99 | data := builder.GetStruct(b).(caseData) 100 | return data.ToSql() 101 | } 102 | 103 | // MustSql builds the query into a SQL string and bound args. 104 | // It panics if there are any errors. 105 | func (b CaseBuilder) MustSql() (string, []interface{}) { 106 | sql, args, err := b.ToSql() 107 | if err != nil { 108 | panic(err) 109 | } 110 | return sql, args 111 | } 112 | 113 | // what sets optional value for CASE construct "CASE [value] ..." 114 | func (b CaseBuilder) what(expr interface{}) CaseBuilder { 115 | return builder.Set(b, "What", newPart(expr)).(CaseBuilder) 116 | } 117 | 118 | // When adds "WHEN ... THEN ..." 
part to CASE construct 119 | func (b CaseBuilder) When(when interface{}, then interface{}) CaseBuilder { 120 | // TODO: performance hint: replace slice of WhenPart with just slice of parts 121 | // where even indices of the slice belong to "when"s and odd indices belong to "then"s 122 | return builder.Append(b, "WhenParts", newWhenPart(when, then)).(CaseBuilder) 123 | } 124 | 125 | // Else sets the optional "ELSE ..." part of the CASE construct. 126 | func (b CaseBuilder) Else(expr interface{}) CaseBuilder { 127 | return builder.Set(b, "Else", newPart(expr)).(CaseBuilder) 128 | } 129 | -------------------------------------------------------------------------------- /statement.go: -------------------------------------------------------------------------------- 1 | package squirrel 2 | 3 | import "github.com/lann/builder" 4 | 5 | // StatementBuilderType is the type of StatementBuilder. 6 | type StatementBuilderType builder.Builder 7 | 8 | // Select returns a SelectBuilder for this StatementBuilderType. 9 | func (b StatementBuilderType) Select(columns ...string) SelectBuilder { 10 | return SelectBuilder(b).Columns(columns...) 11 | } 12 | 13 | // Insert returns an InsertBuilder for this StatementBuilderType. 14 | func (b StatementBuilderType) Insert(into string) InsertBuilder { 15 | return InsertBuilder(b).Into(into) 16 | } 17 | 18 | // Replace returns an InsertBuilder for this StatementBuilderType with the 19 | // statement keyword set to "REPLACE". 20 | func (b StatementBuilderType) Replace(into string) InsertBuilder { 21 | return InsertBuilder(b).statementKeyword("REPLACE").Into(into) 22 | } 23 | 24 | // Update returns an UpdateBuilder for this StatementBuilderType. 25 | func (b StatementBuilderType) Update(table string) UpdateBuilder { 26 | return UpdateBuilder(b).Table(table) 27 | } 28 | 29 | // Delete returns a DeleteBuilder for this StatementBuilderType.
30 | func (b StatementBuilderType) Delete(from string) DeleteBuilder { 31 | return DeleteBuilder(b).From(from) 32 | } 33 | 34 | // PlaceholderFormat sets the PlaceholderFormat field for any child builders. 35 | func (b StatementBuilderType) PlaceholderFormat(f PlaceholderFormat) StatementBuilderType { 36 | return builder.Set(b, "PlaceholderFormat", f).(StatementBuilderType) 37 | } 38 | 39 | // RunWith sets the RunWith field for any child builders. 40 | func (b StatementBuilderType) RunWith(runner BaseRunner) StatementBuilderType { 41 | return setRunWith(b, runner).(StatementBuilderType) 42 | } 43 | 44 | // Where adds WHERE expressions to the query. 45 | // 46 | // See SelectBuilder.Where for more information. 47 | func (b StatementBuilderType) Where(pred interface{}, args ...interface{}) StatementBuilderType { 48 | return builder.Append(b, "WhereParts", newWherePart(pred, args...)).(StatementBuilderType) 49 | } 50 | 51 | // StatementBuilder is a parent builder for other builders, e.g. SelectBuilder. 52 | var StatementBuilder = StatementBuilderType(builder.EmptyBuilder).PlaceholderFormat(Question) 53 | 54 | // Select returns a new SelectBuilder, optionally setting some result columns. 55 | // 56 | // See SelectBuilder.Columns. 57 | func Select(columns ...string) SelectBuilder { 58 | return StatementBuilder.Select(columns...) 59 | } 60 | 61 | // Insert returns a new InsertBuilder with the given table name. 62 | // 63 | // See InsertBuilder.Into. 64 | func Insert(into string) InsertBuilder { 65 | return StatementBuilder.Insert(into) 66 | } 67 | 68 | // Replace returns a new InsertBuilder with the statement keyword set to 69 | // "REPLACE" and with the given table name. 70 | // 71 | // See InsertBuilder.Into. 72 | func Replace(into string) InsertBuilder { 73 | return StatementBuilder.Replace(into) 74 | } 75 | 76 | // Update returns a new UpdateBuilder with the given table name. 77 | // 78 | // See UpdateBuilder.Table. 
79 | func Update(table string) UpdateBuilder { 80 | return StatementBuilder.Update(table) 81 | } 82 | 83 | // Delete returns a new DeleteBuilder with the given table name. 84 | // 85 | // See DeleteBuilder.From. 86 | func Delete(from string) DeleteBuilder { 87 | return StatementBuilder.Delete(from) 88 | } 89 | 90 | // Case returns a new CaseBuilder. 91 | // The optional "what" is the CASE value; any extra elements are its args. 92 | func Case(what ...interface{}) CaseBuilder { 93 | b := CaseBuilder(builder.EmptyBuilder) 94 | 95 | switch len(what) { 96 | case 0: 97 | case 1: 98 | b = b.what(what[0]) 99 | default: 100 | b = b.what(newPart(what[0], what[1:]...)) 101 | 102 | } 103 | return b 104 | } 105 | -------------------------------------------------------------------------------- /integration/integration_test.go: -------------------------------------------------------------------------------- 1 | package integration 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "flag" 7 | "fmt" 8 | "os" 9 | "testing" 10 | 11 | "github.com/stretchr/testify/assert" 12 | 13 | sqrl "github.com/Masterminds/squirrel" 14 | 15 | _ "github.com/go-sql-driver/mysql" 16 | _ "github.com/lib/pq" 17 | _ "github.com/mattn/go-sqlite3" 18 | ) 19 | 20 | const ( 21 | testSchema = ` 22 | CREATE TABLE squirrel_integration ( k INT, v TEXT )` 23 | testData = ` 24 | INSERT INTO squirrel_integration VALUES 25 | (1, 'foo'), 26 | (3, 'bar'), 27 | (2, 'foo'), 28 | (4, 'baz') 29 | ` 30 | ) 31 | 32 | var ( 33 | sb sqrl.StatementBuilderType 34 | ) 35 | 36 | func TestMain(m *testing.M) { 37 | var driver, dataSource string 38 | flag.StringVar(&driver, "driver", "", "integration database driver") 39 | flag.StringVar(&dataSource, "dataSource", "", "integration database data source") 40 | flag.Parse() 41 | 42 | if driver == "" { 43 | driver = "sqlite3" 44 | } 45 | 46 | if driver == "sqlite3" && dataSource == "" { 47 | dataSource = ":memory:" 48 | } 49 | 50 | db, err := sql.Open(driver, dataSource) 51 | if err != nil { 52 | fmt.Printf("error opening
database: %v\n", err) 53 | os.Exit(-1) 54 | } 55 | 56 | _, err = db.Exec(testSchema) 57 | if err != nil { 58 | fmt.Printf("error creating test schema: %v\n", err) 59 | os.Exit(-2) 60 | } 61 | 62 | dropSchema := func() { // called explicitly: os.Exit skips deferred calls 63 | if _, err := db.Exec("DROP TABLE squirrel_integration"); err != nil { 64 | fmt.Printf("error removing test schema: %v\n", err) 65 | } 66 | } 67 | _, err = db.Exec(testData) 68 | if err != nil { 69 | fmt.Printf("error inserting test data: %v\n", err) 70 | dropSchema() 71 | os.Exit(-3) 72 | } 73 | sb = sqrl.StatementBuilder.RunWith(db) 74 | if driver == "postgres" { 75 | sb = sb.PlaceholderFormat(sqrl.Dollar) 76 | } 77 | code := m.Run() 78 | dropSchema() 79 | os.Exit(code) 80 | } 81 | 82 | func assertVals(t *testing.T, s sqrl.SelectBuilder, expected ...string) { 83 | rows, err := s.Query() 84 | assert.NoError(t, err) 85 | defer rows.Close() 86 | 87 | vals := make([]string, len(expected)) 88 | for i := range vals { 89 | assert.True(t, rows.Next()) 90 | assert.NoError(t, rows.Scan(&vals[i])) 91 | } 92 | assert.False(t, rows.Next()) 93 | 94 | if expected != nil { 95 | assert.Equal(t, expected, vals) 96 | } 97 | } 98 | 99 | func TestSimpleSelect(t *testing.T) { 100 | assertVals( 101 | t, 102 | sb.Select("v").From("squirrel_integration"), 103 | "foo", "bar", "foo", "baz") 104 | } 105 | 106 | func TestEq(t *testing.T) { 107 | s := sb.Select("v").From("squirrel_integration") 108 | assertVals(t, s.Where(sqrl.Eq{"k": 4}), "baz") 109 | assertVals(t, s.Where(sqrl.NotEq{"k": 2}), "foo", "bar", "baz") 110 | assertVals(t, s.Where(sqrl.Eq{"k": []int{1, 4}}), "foo", "baz") 111 | assertVals(t, s.Where(sqrl.NotEq{"k": []int{1, 4}}), "bar", "foo") 112 | assertVals(t, s.Where(sqrl.Eq{"k": nil})) 113 | assertVals(t, s.Where(sqrl.NotEq{"k": nil}), "foo", "bar", "foo", "baz") 114 | assertVals(t, s.Where(sqrl.Eq{"k": []int{}})) 115 | assertVals(t, s.Where(sqrl.NotEq{"k": []int{}}), "foo", "bar", "foo", "baz") 116 | } 117 | 118 | func TestIneq(t *testing.T) { 119 | s := sb.Select("v").From("squirrel_integration") 120 | assertVals(t,
s.Where(sqrl.Lt{"k": 3}), "foo", "foo") 121 | assertVals(t, s.Where(sqrl.Gt{"k": 3}), "baz") 122 | } 123 | 124 | func TestConj(t *testing.T) { 125 | s := sb.Select("v").From("squirrel_integration") 126 | assertVals(t, s.Where(sqrl.And{sqrl.Gt{"k": 1}, sqrl.Lt{"k": 4}}), "bar", "foo") 127 | assertVals(t, s.Where(sqrl.Or{sqrl.Gt{"k": 3}, sqrl.Lt{"k": 2}}), "foo", "baz") 128 | } 129 | 130 | func TestContext(t *testing.T) { 131 | s := sb.Select("v").From("squirrel_integration") 132 | ctx := context.Background() 133 | _, err := s.QueryContext(ctx) 134 | assert.NoError(t, err) 135 | } 136 | -------------------------------------------------------------------------------- /case_test.go: -------------------------------------------------------------------------------- 1 | package squirrel 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestCaseWithVal(t *testing.T) { 10 | caseStmt := Case("number"). 11 | When("1", "one"). 12 | When("2", "two"). 13 | Else(Expr("?", "big number")) 14 | 15 | qb := Select(). 16 | Column(caseStmt). 17 | From("table") 18 | sql, args, err := qb.ToSql() 19 | 20 | assert.NoError(t, err) 21 | 22 | expectedSql := "SELECT CASE number " + 23 | "WHEN 1 THEN one " + 24 | "WHEN 2 THEN two " + 25 | "ELSE ? " + 26 | "END " + 27 | "FROM table" 28 | assert.Equal(t, expectedSql, sql) 29 | 30 | expectedArgs := []interface{}{"big number"} 31 | assert.Equal(t, expectedArgs, args) 32 | } 33 | 34 | func TestCaseWithComplexVal(t *testing.T) { 35 | caseStmt := Case("? > ?", 10, 5). 36 | When("true", "'T'") 37 | 38 | qb := Select(). 39 | Column(Alias(caseStmt, "complexCase")). 40 | From("table") 41 | sql, args, err := qb.ToSql() 42 | 43 | assert.NoError(t, err) 44 | 45 | expectedSql := "SELECT (CASE ? > ? 
" + 46 | "WHEN true THEN 'T' " + 47 | "END) AS complexCase " + 48 | "FROM table" 49 | assert.Equal(t, expectedSql, sql) 50 | 51 | expectedArgs := []interface{}{10, 5} 52 | assert.Equal(t, expectedArgs, args) 53 | } 54 | 55 | func TestCaseWithNoVal(t *testing.T) { 56 | caseStmt := Case(). 57 | When(Eq{"x": 0}, "x is zero"). 58 | When(Expr("x > ?", 1), Expr("CONCAT('x is greater than ', ?)", 2)) 59 | 60 | qb := Select().Column(caseStmt).From("table") 61 | sql, args, err := qb.ToSql() 62 | 63 | assert.NoError(t, err) 64 | 65 | expectedSql := "SELECT CASE " + 66 | "WHEN x = ? THEN x is zero " + 67 | "WHEN x > ? THEN CONCAT('x is greater than ', ?) " + 68 | "END " + 69 | "FROM table" 70 | 71 | assert.Equal(t, expectedSql, sql) 72 | 73 | expectedArgs := []interface{}{0, 1, 2} 74 | assert.Equal(t, expectedArgs, args) 75 | } 76 | 77 | func TestCaseWithExpr(t *testing.T) { 78 | caseStmt := Case(Expr("x = ?", true)). 79 | When("true", Expr("?", "it's true!")). 80 | Else("42") 81 | 82 | qb := Select().Column(caseStmt).From("table") 83 | sql, args, err := qb.ToSql() 84 | 85 | assert.NoError(t, err) 86 | 87 | expectedSql := "SELECT CASE x = ? " + 88 | "WHEN true THEN ? " + 89 | "ELSE 42 " + 90 | "END " + 91 | "FROM table" 92 | 93 | assert.Equal(t, expectedSql, sql) 94 | 95 | expectedArgs := []interface{}{true, "it's true!"} 96 | assert.Equal(t, expectedArgs, args) 97 | } 98 | 99 | func TestMultipleCase(t *testing.T) { 100 | caseStmtNoval := Case(Expr("x = ?", true)). 101 | When("true", Expr("?", "it's true!")). 102 | Else("42") 103 | caseStmtExpr := Case(). 104 | When(Eq{"x": 0}, "'x is zero'"). 105 | When(Expr("x > ?", 1), Expr("CONCAT('x is greater than ', ?)", 2)) 106 | 107 | qb := Select(). 108 | Column(Alias(caseStmtNoval, "case_noval")). 109 | Column(Alias(caseStmtExpr, "case_expr")). 110 | From("table") 111 | 112 | sql, args, err := qb.ToSql() 113 | 114 | assert.NoError(t, err) 115 | 116 | expectedSql := "SELECT " + 117 | "(CASE x = ? WHEN true THEN ? 
ELSE 42 END) AS case_noval, " + 118 | "(CASE WHEN x = ? THEN 'x is zero' WHEN x > ? THEN CONCAT('x is greater than ', ?) END) AS case_expr " + 119 | "FROM table" 120 | 121 | assert.Equal(t, expectedSql, sql) 122 | 123 | expectedArgs := []interface{}{ 124 | true, "it's true!", 125 | 0, 1, 2, 126 | } 127 | assert.Equal(t, expectedArgs, args) 128 | } 129 | 130 | func TestCaseWithNoWhenClause(t *testing.T) { 131 | caseStmt := Case("something"). 132 | Else("42") 133 | 134 | qb := Select().Column(caseStmt).From("table") 135 | 136 | _, _, err := qb.ToSql() 137 | 138 | assert.Error(t, err) 139 | 140 | assert.Equal(t, "case expression must contain at lease one WHEN clause", err.Error()) 141 | } 142 | 143 | func TestCaseBuilderMustSql(t *testing.T) { 144 | defer func() { 145 | if r := recover(); r == nil { 146 | t.Errorf("TestCaseBuilderMustSql should have panicked!") 147 | } 148 | }() 149 | Case("").MustSql() 150 | } 151 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Stability: Maintenance](https://masterminds.github.io/stability/maintenance.svg)](https://masterminds.github.io/stability/maintenance.html) 2 | ### Squirrel is "complete". 3 | Bug fixes will still be merged (slowly). Bug reports are welcome, but I will not necessarily respond to them. If another fork (or substantially similar project) actively improves on what Squirrel does, let me know and I may link to it here. 
4 | 5 | 6 | # Squirrel - fluent SQL generator for Go 7 | 8 | ```go 9 | import "github.com/Masterminds/squirrel" 10 | ``` 11 | 12 | 13 | [![GoDoc](https://godoc.org/github.com/Masterminds/squirrel?status.png)](https://godoc.org/github.com/Masterminds/squirrel) 14 | [![Build Status](https://api.travis-ci.org/Masterminds/squirrel.svg?branch=master)](https://travis-ci.org/Masterminds/squirrel) 15 | 16 | **Squirrel is not an ORM.** For an application of Squirrel, check out 17 | [structable, a table-struct mapper](https://github.com/Masterminds/structable) 18 | 19 | 20 | Squirrel helps you build SQL queries from composable parts: 21 | 22 | ```go 23 | import sq "github.com/Masterminds/squirrel" 24 | 25 | users := sq.Select("*").From("users").Join("emails USING (email_id)") 26 | 27 | active := users.Where(sq.Eq{"deleted_at": nil}) 28 | 29 | sql, args, err := active.ToSql() 30 | 31 | sql == "SELECT * FROM users JOIN emails USING (email_id) WHERE deleted_at IS NULL" 32 | ``` 33 | 34 | ```go 35 | sql, args, err := sq. 36 | Insert("users").Columns("name", "age"). 37 | Values("moe", 13).Values("larry", sq.Expr("? + 5", 12)). 38 | ToSql() 39 | 40 | sql == "INSERT INTO users (name,age) VALUES (?,?),(?,? + 5)" 41 | ``` 42 | 43 | Squirrel can also execute queries directly: 44 | 45 | ```go 46 | stooges := users.Where(sq.Eq{"username": []string{"moe", "larry", "curly", "shemp"}}) 47 | three_stooges := stooges.Limit(3) 48 | rows, err := three_stooges.RunWith(db).Query() 49 | 50 | // Behaves like: 51 | rows, err := db.Query("SELECT * FROM users WHERE username IN (?,?,?,?) 
LIMIT 3", 52 | "moe", "larry", "curly", "shemp") 53 | ``` 54 | 55 | Squirrel makes conditional query building a breeze: 56 | 57 | ```go 58 | if len(q) > 0 { 59 | users = users.Where("name LIKE ?", fmt.Sprint("%", q, "%")) 60 | } 61 | ``` 62 | 63 | Squirrel wants to make your life easier: 64 | 65 | ```go 66 | // StmtCache caches Prepared Stmts for you 67 | dbCache := sq.NewStmtCache(db) 68 | 69 | // StatementBuilder keeps your syntax neat 70 | mydb := sq.StatementBuilder.RunWith(dbCache) 71 | select_users := mydb.Select("*").From("users") 72 | ``` 73 | 74 | Squirrel loves PostgreSQL: 75 | 76 | ```go 77 | psql := sq.StatementBuilder.PlaceholderFormat(sq.Dollar) 78 | 79 | // You use question marks for placeholders... 80 | sql, _, _ := psql.Select("*").From("elephants").Where("name IN (?,?)", "Dumbo", "Verna").ToSql() 81 | 82 | /// ...squirrel replaces them using PlaceholderFormat. 83 | sql == "SELECT * FROM elephants WHERE name IN ($1,$2)" 84 | 85 | 86 | /// You can retrieve id ... 87 | query := sq.Insert("nodes"). 88 | Columns("uuid", "type", "data"). 89 | Values(node.Uuid, node.Type, node.Data). 90 | Suffix("RETURNING \"id\""). 91 | RunWith(m.db). 92 | PlaceholderFormat(sq.Dollar) 93 | 94 | query.QueryRow().Scan(&node.id) 95 | ``` 96 | 97 | You can escape question marks by inserting two question marks: 98 | 99 | ```sql 100 | SELECT * FROM nodes WHERE meta->'format' ??| array[?,?] 101 | ``` 102 | 103 | will generate with the Dollar Placeholder: 104 | 105 | ```sql 106 | SELECT * FROM nodes WHERE meta->'format' ?| array[$1,$2] 107 | ``` 108 | 109 | ## FAQ 110 | 111 | * **How can I build an IN query on composite keys / tuples, e.g. `WHERE (col1, col2) IN ((1,2),(3,4))`? 
([#104](https://github.com/Masterminds/squirrel/issues/104))** 112 | 113 | Squirrel does not explicitly support tuples, but you can get the same effect with e.g.: 114 | 115 | ```go 116 | sq.Or{ 117 | sq.Eq{"col1": 1, "col2": 2}, 118 | sq.Eq{"col1": 3, "col2": 4}} 119 | ``` 120 | 121 | ```sql 122 | WHERE (col1 = 1 AND col2 = 2) OR (col1 = 3 AND col2 = 4) 123 | ``` 124 | 125 | (which should produce the same query plan as the tuple version) 126 | 127 | * **Why doesn't `Eq{"mynumber": []uint8{1,2,3}}` turn into an `IN` query? ([#114](https://github.com/Masterminds/squirrel/issues/114))** 128 | 129 | Values of type `[]byte` are handled specially by `database/sql`. In Go, [`byte` is just an alias of `uint8`](https://golang.org/pkg/builtin/#byte), so there is no way to distinguish `[]uint8` from `[]byte`. 130 | 131 | * **Some features are poorly documented!** 132 | 133 | This isn't a frequent complaints section! 134 | 135 | * **Some features are poorly documented?** 136 | 137 | Yes. The tests should be considered a part of the documentation; take a look at those for ideas on how to express more complex queries. 138 | 139 | ## License 140 | 141 | Squirrel is released under the 142 | [MIT License](http://www.opensource.org/licenses/MIT). 143 | -------------------------------------------------------------------------------- /squirrel.go: -------------------------------------------------------------------------------- 1 | // Package squirrel provides a fluent SQL generator. 2 | // 3 | // See https://github.com/Masterminds/squirrel for examples. 4 | package squirrel 5 | 6 | import ( 7 | "bytes" 8 | "database/sql" 9 | "fmt" 10 | "strings" 11 | 12 | "github.com/lann/builder" 13 | ) 14 | 15 | // Sqlizer is the interface that wraps the ToSql method. 16 | // 17 | // ToSql returns a SQL representation of the Sqlizer, along with a slice of args 18 | // as passed to e.g. database/sql.Exec. It can also return an error. 
19 | type Sqlizer interface { 20 | ToSql() (string, []interface{}, error) 21 | } 22 | 23 | // rawSqlizer is expected to do what Sqlizer does, but without finalizing placeholders. 24 | // This is useful for nested queries. 25 | type rawSqlizer interface { 26 | toSqlRaw() (string, []interface{}, error) 27 | } 28 | 29 | // Execer is the interface that wraps the Exec method. 30 | // 31 | // Exec executes the given query as implemented by database/sql.Exec. 32 | type Execer interface { 33 | Exec(query string, args ...interface{}) (sql.Result, error) 34 | } 35 | 36 | // Queryer is the interface that wraps the Query method. 37 | // 38 | // Query executes the given query as implemented by database/sql.Query. 39 | type Queryer interface { 40 | Query(query string, args ...interface{}) (*sql.Rows, error) 41 | } 42 | 43 | // QueryRower is the interface that wraps the QueryRow method. 44 | // 45 | // QueryRow executes the given query as implemented by database/sql.QueryRow. 46 | type QueryRower interface { 47 | QueryRow(query string, args ...interface{}) RowScanner 48 | } 49 | 50 | // BaseRunner groups the Execer and Queryer interfaces. 51 | type BaseRunner interface { 52 | Execer 53 | Queryer 54 | } 55 | 56 | // Runner groups the Execer, Queryer, and QueryRower interfaces. 57 | type Runner interface { 58 | Execer 59 | Queryer 60 | QueryRower 61 | } 62 | 63 | // WrapStdSql wraps a type implementing the standard SQL interface with methods that 64 | // squirrel expects. 65 | func WrapStdSql(stdSql StdSql) Runner { 66 | return &stdsqlRunner{stdSql} 67 | } 68 | 69 | // StdSql encompasses the standard methods of the *sql.DB type, and other types that 70 | // wrap these methods. 
71 | type StdSql interface { 72 | Query(string, ...interface{}) (*sql.Rows, error) 73 | QueryRow(string, ...interface{}) *sql.Row 74 | Exec(string, ...interface{}) (sql.Result, error) 75 | } 76 | 77 | type stdsqlRunner struct { 78 | StdSql 79 | } 80 | 81 | func (r *stdsqlRunner) QueryRow(query string, args ...interface{}) RowScanner { 82 | return r.StdSql.QueryRow(query, args...) 83 | } 84 | 85 | func setRunWith(b interface{}, runner BaseRunner) interface{} { 86 | switch r := runner.(type) { 87 | case StdSqlCtx: 88 | runner = WrapStdSqlCtx(r) 89 | case StdSql: 90 | runner = WrapStdSql(r) 91 | } 92 | return builder.Set(b, "RunWith", runner) 93 | } 94 | 95 | // RunnerNotSet is returned by methods that need a Runner if it isn't set. 96 | var RunnerNotSet = fmt.Errorf("cannot run; no Runner set (RunWith)") 97 | 98 | // RunnerNotQueryRunner is returned by QueryRow if the RunWith value doesn't implement QueryRower. 99 | var RunnerNotQueryRunner = fmt.Errorf("cannot QueryRow; Runner is not a QueryRower") 100 | 101 | // ExecWith Execs the SQL returned by s with db. 102 | func ExecWith(db Execer, s Sqlizer) (res sql.Result, err error) { 103 | query, args, err := s.ToSql() 104 | if err != nil { 105 | return 106 | } 107 | return db.Exec(query, args...) 108 | } 109 | 110 | // QueryWith Querys the SQL returned by s with db. 111 | func QueryWith(db Queryer, s Sqlizer) (rows *sql.Rows, err error) { 112 | query, args, err := s.ToSql() 113 | if err != nil { 114 | return 115 | } 116 | return db.Query(query, args...) 117 | } 118 | 119 | // QueryRowWith QueryRows the SQL returned by s with db. 
120 | func QueryRowWith(db QueryRower, s Sqlizer) RowScanner { 121 | query, args, err := s.ToSql() 122 | if err != nil { // don't send a query that failed to build; Scan reports err 123 | return &Row{err: err} 124 | } 125 | return &Row{RowScanner: db.QueryRow(query, args...)} 126 | } 127 | 128 | // DebugSqlizer calls ToSql on s and shows the approximate SQL to be executed. 129 | // 130 | // If ToSql returns an error, the result of this method will look like: 131 | // "[ToSql error: %s]" or "[DebugSqlizer error: %s]" 132 | // 133 | // IMPORTANT: As its name suggests, this function should only be used for 134 | // debugging. While the string result *might* be valid SQL, this function does 135 | // not try very hard to ensure it. Additionally, executing the output of this 136 | // function with any untrusted user input is certainly insecure. 137 | func DebugSqlizer(s Sqlizer) string { 138 | sql, args, err := s.ToSql() 139 | if err != nil { 140 | return fmt.Sprintf("[ToSql error: %s]", err) 141 | } 142 | 143 | var placeholder string 144 | downCast, ok := s.(placeholderDebugger) 145 | if !ok { 146 | placeholder = "?" 147 | } else { 148 | placeholder = downCast.debugPlaceholder() 149 | } 150 | // TODO: dedupe this with placeholder.go 151 | buf := &bytes.Buffer{} 152 | i := 0 153 | for { 154 | p := strings.Index(sql, placeholder) 155 | if p == -1 { 156 | break 157 | } 158 | if len(sql[p:]) > 1 && sql[p:p+2] == "??" { // escape ?? => ?
159 | buf.WriteString(sql[:p]) 160 | buf.WriteString("?") 161 | sql = sql[p+2:] 162 | } else { 163 | if i+1 > len(args) { 164 | return fmt.Sprintf( 165 | "[DebugSqlizer error: too many placeholders in %#v for %d args]", 166 | sql, len(args)) 167 | } 168 | buf.WriteString(sql[:p]) 169 | fmt.Fprintf(buf, "'%v'", args[i]) 170 | // advance our sql string "cursor" beyond the arg we placed 171 | sql = sql[p+1:] 172 | i++ 173 | } 174 | } 175 | if i < len(args) { 176 | return fmt.Sprintf( 177 | "[DebugSqlizer error: not enough placeholders in %#v for %d args]", 178 | sql, len(args)) 179 | } 180 | // "append" any remaining sql that won't need interpolating 181 | buf.WriteString(sql) 182 | return buf.String() 183 | } 184 | -------------------------------------------------------------------------------- /squirrel_test.go: -------------------------------------------------------------------------------- 1 | package squirrel 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | type DBStub struct { 13 | err error 14 | 15 | LastPrepareSql string 16 | PrepareCount int 17 | 18 | LastExecSql string 19 | LastExecArgs []interface{} 20 | 21 | LastQuerySql string 22 | LastQueryArgs []interface{} 23 | 24 | LastQueryRowSql string 25 | LastQueryRowArgs []interface{} 26 | } 27 | 28 | var StubError = fmt.Errorf("this is a stub; this is only a stub") 29 | 30 | func (s *DBStub) Prepare(query string) (*sql.Stmt, error) { 31 | s.LastPrepareSql = query 32 | s.PrepareCount++ 33 | return nil, nil 34 | } 35 | 36 | func (s *DBStub) Exec(query string, args ...interface{}) (sql.Result, error) { 37 | s.LastExecSql = query 38 | s.LastExecArgs = args 39 | return nil, nil 40 | } 41 | 42 | func (s *DBStub) Query(query string, args ...interface{}) (*sql.Rows, error) { 43 | s.LastQuerySql = query 44 | s.LastQueryArgs = args 45 | return nil, nil 46 | } 47 | 48 | func (s *DBStub)
QueryRow(query string, args ...interface{}) RowScanner { 49 | s.LastQueryRowSql = query 50 | s.LastQueryRowArgs = args 51 | return &Row{RowScanner: &RowStub{}} 52 | } 53 | 54 | var sqlizer = Select("test") 55 | var sqlStr = "SELECT test" 56 | 57 | func TestExecWith(t *testing.T) { 58 | db := &DBStub{} 59 | ExecWith(db, sqlizer) 60 | assert.Equal(t, sqlStr, db.LastExecSql) 61 | } 62 | 63 | func TestQueryWith(t *testing.T) { 64 | db := &DBStub{} 65 | QueryWith(db, sqlizer) 66 | assert.Equal(t, sqlStr, db.LastQuerySql) 67 | } 68 | 69 | func TestQueryRowWith(t *testing.T) { 70 | db := &DBStub{} 71 | QueryRowWith(db, sqlizer) 72 | assert.Equal(t, sqlStr, db.LastQueryRowSql) 73 | } 74 | 75 | func TestWithToSqlErr(t *testing.T) { 76 | db := &DBStub{} 77 | sqlizer := Select() 78 | 79 | _, err := ExecWith(db, sqlizer) 80 | assert.Error(t, err) 81 | 82 | _, err = QueryWith(db, sqlizer) 83 | assert.Error(t, err) 84 | 85 | err = QueryRowWith(db, sqlizer).Scan() 86 | assert.Error(t, err) 87 | } 88 | 89 | var testDebugUpdateSQL = Update("table").SetMap(Eq{"x": 1, "y": "val"}) 90 | var expectedDebugUpateSQL = "UPDATE table SET x = '1', y = 'val'" 91 | 92 | func TestDebugSqlizerUpdateColon(t *testing.T) { 93 | testDebugUpdateSQL.PlaceholderFormat(Colon) 94 | assert.Equal(t, expectedDebugUpateSQL, DebugSqlizer(testDebugUpdateSQL)) 95 | } 96 | 97 | func TestDebugSqlizerUpdateAtp(t *testing.T) { 98 | testDebugUpdateSQL.PlaceholderFormat(AtP) 99 | assert.Equal(t, expectedDebugUpateSQL, DebugSqlizer(testDebugUpdateSQL)) 100 | } 101 | 102 | func TestDebugSqlizerUpdateDollar(t *testing.T) { 103 | testDebugUpdateSQL.PlaceholderFormat(Dollar) 104 | assert.Equal(t, expectedDebugUpateSQL, DebugSqlizer(testDebugUpdateSQL)) 105 | } 106 | 107 | func TestDebugSqlizerUpdateQuestion(t *testing.T) { 108 | testDebugUpdateSQL.PlaceholderFormat(Question) 109 | assert.Equal(t, expectedDebugUpateSQL, DebugSqlizer(testDebugUpdateSQL)) 110 | } 111 | 112 | var testDebugDeleteSQL = 
Delete("table").Where(And{ 113 | Eq{"column": "val"}, 114 | Eq{"other": 1}, 115 | }) 116 | var expectedDebugDeleteSQL = "DELETE table FROM table WHERE (column = 'val' AND other = '1')" 117 | 118 | func TestDebugSqlizerDeleteColon(t *testing.T) { 119 | testDebugDeleteSQL.PlaceholderFormat(Colon) 120 | assert.Equal(t, expectedDebugDeleteSQL, DebugSqlizer(testDebugDeleteSQL)) 121 | } 122 | 123 | func TestDebugSqlizerDeleteAtp(t *testing.T) { 124 | testDebugDeleteSQL.PlaceholderFormat(AtP) 125 | assert.Equal(t, expectedDebugDeleteSQL, DebugSqlizer(testDebugDeleteSQL)) 126 | } 127 | 128 | func TestDebugSqlizerDeleteDollar(t *testing.T) { 129 | testDebugDeleteSQL.PlaceholderFormat(Dollar) 130 | assert.Equal(t, expectedDebugDeleteSQL, DebugSqlizer(testDebugDeleteSQL)) 131 | } 132 | 133 | func TestDebugSqlizerDeleteQuestion(t *testing.T) { 134 | testDebugDeleteSQL.PlaceholderFormat(Question) 135 | assert.Equal(t, expectedDebugDeleteSQL, DebugSqlizer(testDebugDeleteSQL)) 136 | } 137 | 138 | var testDebugInsertSQL = Insert("table").Values(1, "test") 139 | var expectedDebugInsertSQL = "INSERT INTO table VALUES ('1','test')" 140 | 141 | func TestDebugSqlizerInsertColon(t *testing.T) { 142 | testDebugInsertSQL.PlaceholderFormat(Colon) 143 | assert.Equal(t, expectedDebugInsertSQL, DebugSqlizer(testDebugInsertSQL)) 144 | } 145 | 146 | func TestDebugSqlizerInsertAtp(t *testing.T) { 147 | testDebugInsertSQL.PlaceholderFormat(AtP) 148 | assert.Equal(t, expectedDebugInsertSQL, DebugSqlizer(testDebugInsertSQL)) 149 | } 150 | 151 | func TestDebugSqlizerInsertDollar(t *testing.T) { 152 | testDebugInsertSQL.PlaceholderFormat(Dollar) 153 | assert.Equal(t, expectedDebugInsertSQL, DebugSqlizer(testDebugInsertSQL)) 154 | } 155 | 156 | func TestDebugSqlizerInsertQuestion(t *testing.T) { 157 | testDebugInsertSQL.PlaceholderFormat(Question) 158 | assert.Equal(t, expectedDebugInsertSQL, DebugSqlizer(testDebugInsertSQL)) 159 | } 160 | 161 | var testDebugSelectSQL = 
Select("*").From("table").Where(And{ 162 | Eq{"column": "val"}, 163 | Eq{"other": 1}, 164 | }) 165 | var expectedDebugSelectSQL = "SELECT * FROM table WHERE (column = 'val' AND other = '1')" 166 | 167 | func TestDebugSqlizerSelectColon(t *testing.T) { 168 | testDebugSelectSQL.PlaceholderFormat(Colon) 169 | assert.Equal(t, expectedDebugSelectSQL, DebugSqlizer(testDebugSelectSQL)) 170 | } 171 | 172 | func TestDebugSqlizerSelectAtp(t *testing.T) { 173 | testDebugSelectSQL.PlaceholderFormat(AtP) 174 | assert.Equal(t, expectedDebugSelectSQL, DebugSqlizer(testDebugSelectSQL)) 175 | } 176 | 177 | func TestDebugSqlizerSelectDollar(t *testing.T) { 178 | testDebugSelectSQL.PlaceholderFormat(Dollar) 179 | assert.Equal(t, expectedDebugSelectSQL, DebugSqlizer(testDebugSelectSQL)) 180 | } 181 | 182 | func TestDebugSqlizerSelectQuestion(t *testing.T) { 183 | testDebugSelectSQL.PlaceholderFormat(Question) 184 | assert.Equal(t, expectedDebugSelectSQL, DebugSqlizer(testDebugSelectSQL)) 185 | } 186 | 187 | func TestDebugSqlizer(t *testing.T) { 188 | sqlizer := Expr("x = ? AND y = ? AND z = '??'", 1, "text") 189 | expectedDebug := "x = '1' AND y = 'text' AND z = '?'" 190 | assert.Equal(t, expectedDebug, DebugSqlizer(sqlizer)) 191 | } 192 | 193 | func TestDebugSqlizerErrors(t *testing.T) { 194 | errorMsg := DebugSqlizer(Expr("x = ?", 1, 2)) // Not enough placeholders 195 | assert.True(t, strings.HasPrefix(errorMsg, "[DebugSqlizer error: ")) 196 | 197 | errorMsg = DebugSqlizer(Expr("x = ? 
AND y = ?", 1)) // Too many placeholders 198 | assert.True(t, strings.HasPrefix(errorMsg, "[DebugSqlizer error: ")) 199 | 200 | errorMsg = DebugSqlizer(Lt{"x": nil}) // Cannot use nil values with Lt 201 | assert.True(t, strings.HasPrefix(errorMsg, "[ToSql error: ")) 202 | } 203 | -------------------------------------------------------------------------------- /delete.go: -------------------------------------------------------------------------------- 1 | package squirrel 2 | 3 | import ( 4 | "bytes" 5 | "database/sql" 6 | "fmt" 7 | "strings" 8 | 9 | "github.com/lann/builder" 10 | ) 11 | 12 | type deleteData struct { 13 | PlaceholderFormat PlaceholderFormat 14 | RunWith BaseRunner 15 | Prefixes []Sqlizer 16 | From string 17 | Joins []Sqlizer 18 | WhereParts []Sqlizer 19 | OrderBys []string 20 | Limit string 21 | Offset string 22 | Suffixes []Sqlizer 23 | } 24 | 25 | func (d *deleteData) Exec() (sql.Result, error) { 26 | if d.RunWith == nil { 27 | return nil, RunnerNotSet 28 | } 29 | return ExecWith(d.RunWith, d) 30 | } 31 | 32 | func (d *deleteData) ToSql() (sqlStr string, args []interface{}, err error) { 33 | if len(d.From) == 0 { 34 | err = fmt.Errorf("delete statements must specify a From table") 35 | return 36 | } 37 | 38 | sql := &bytes.Buffer{} 39 | 40 | if len(d.Prefixes) > 0 { 41 | args, err = appendToSql(d.Prefixes, sql, " ", args) 42 | if err != nil { 43 | return 44 | } 45 | 46 | sql.WriteString(" ") 47 | } 48 | 49 | // For DELETE JOINs, we need to say where we're deleting from. 50 | // DELETE FROM X doesn't do that on its own, so we now query 51 | // with DELETE X FROM X which is safe for non-JOIN queries. 
52 | sql.WriteString("DELETE ") 53 | sql.WriteString(d.From) 54 | sql.WriteString(" FROM ") 55 | sql.WriteString(d.From) 56 | 57 | if len(d.Joins) > 0 { 58 | sql.WriteString(" ") 59 | args, err = appendToSql(d.Joins, sql, " ", args) 60 | if err != nil { 61 | return 62 | } 63 | } 64 | 65 | if len(d.WhereParts) > 0 { 66 | sql.WriteString(" WHERE ") 67 | args, err = appendToSql(d.WhereParts, sql, " AND ", args) 68 | if err != nil { 69 | return 70 | } 71 | } 72 | 73 | if len(d.OrderBys) > 0 { 74 | sql.WriteString(" ORDER BY ") 75 | sql.WriteString(strings.Join(d.OrderBys, ", ")) 76 | } 77 | 78 | if len(d.Limit) > 0 { 79 | sql.WriteString(" LIMIT ") 80 | sql.WriteString(d.Limit) 81 | } 82 | 83 | if len(d.Offset) > 0 { 84 | sql.WriteString(" OFFSET ") 85 | sql.WriteString(d.Offset) 86 | } 87 | 88 | if len(d.Suffixes) > 0 { 89 | sql.WriteString(" ") 90 | args, err = appendToSql(d.Suffixes, sql, " ", args) 91 | if err != nil { 92 | return 93 | } 94 | } 95 | 96 | sqlStr, err = d.PlaceholderFormat.ReplacePlaceholders(sql.String()) 97 | return 98 | } 99 | 100 | // Builder 101 | 102 | // DeleteBuilder builds SQL DELETE statements. 103 | type DeleteBuilder builder.Builder 104 | 105 | func init() { 106 | builder.Register(DeleteBuilder{}, deleteData{}) 107 | } 108 | 109 | // Format methods 110 | 111 | // PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the 112 | // query. 113 | func (b DeleteBuilder) PlaceholderFormat(f PlaceholderFormat) DeleteBuilder { 114 | return builder.Set(b, "PlaceholderFormat", f).(DeleteBuilder) 115 | } 116 | 117 | // Runner methods 118 | 119 | // RunWith sets a Runner (like database/sql.DB) to be used with e.g. Exec. 120 | func (b DeleteBuilder) RunWith(runner BaseRunner) DeleteBuilder { 121 | return setRunWith(b, runner).(DeleteBuilder) 122 | } 123 | 124 | // Exec builds and Execs the query with the Runner set by RunWith. 
125 | func (b DeleteBuilder) Exec() (sql.Result, error) { 126 | data := builder.GetStruct(b).(deleteData) 127 | return data.Exec() 128 | } 129 | 130 | // SQL methods 131 | 132 | // ToSql builds the query into a SQL string and bound args. 133 | func (b DeleteBuilder) ToSql() (string, []interface{}, error) { 134 | data := builder.GetStruct(b).(deleteData) 135 | return data.ToSql() 136 | } 137 | 138 | // MustSql builds the query into a SQL string and bound args. 139 | // It panics if there are any errors. 140 | func (b DeleteBuilder) MustSql() (string, []interface{}) { 141 | sql, args, err := b.ToSql() 142 | if err != nil { 143 | panic(err) 144 | } 145 | return sql, args 146 | } 147 | 148 | // Prefix adds an expression to the beginning of the query 149 | func (b DeleteBuilder) Prefix(sql string, args ...interface{}) DeleteBuilder { 150 | return b.PrefixExpr(Expr(sql, args...)) 151 | } 152 | 153 | // PrefixExpr adds an expression to the very beginning of the query 154 | func (b DeleteBuilder) PrefixExpr(expr Sqlizer) DeleteBuilder { 155 | return builder.Append(b, "Prefixes", expr).(DeleteBuilder) 156 | } 157 | 158 | // From sets the table to be deleted from. 159 | func (b DeleteBuilder) From(from string) DeleteBuilder { 160 | return builder.Set(b, "From", from).(DeleteBuilder) 161 | } 162 | 163 | // JoinClause adds a join clause to the query. 164 | func (b DeleteBuilder) JoinClause(pred interface{}, args ...interface{}) DeleteBuilder { 165 | return builder.Append(b, "Joins", newPart(pred, args...)).(DeleteBuilder) 166 | } 167 | 168 | // Join adds a JOIN clause to the query. 169 | func (b DeleteBuilder) Join(join string, rest ...interface{}) DeleteBuilder { 170 | return b.JoinClause("JOIN "+join, rest...) 171 | } 172 | 173 | // LeftJoin adds a LEFT JOIN clause to the query. 174 | func (b DeleteBuilder) LeftJoin(join string, rest ...interface{}) DeleteBuilder { 175 | return b.JoinClause("LEFT JOIN "+join, rest...) 
176 | } 177 | 178 | // RightJoin adds a RIGHT JOIN clause to the query. 179 | func (b DeleteBuilder) RightJoin(join string, rest ...interface{}) DeleteBuilder { 180 | return b.JoinClause("RIGHT JOIN "+join, rest...) 181 | } 182 | 183 | // InnerJoin adds a INNER JOIN clause to the query. 184 | func (b DeleteBuilder) InnerJoin(join string, rest ...interface{}) DeleteBuilder { 185 | return b.JoinClause("INNER JOIN "+join, rest...) 186 | } 187 | 188 | // CrossJoin adds a CROSS JOIN clause to the query. 189 | func (b DeleteBuilder) CrossJoin(join string, rest ...interface{}) DeleteBuilder { 190 | return b.JoinClause("CROSS JOIN "+join, rest...) 191 | } 192 | 193 | // Where adds WHERE expressions to the query. 194 | // 195 | // See SelectBuilder.Where for more information. 196 | func (b DeleteBuilder) Where(pred interface{}, args ...interface{}) DeleteBuilder { 197 | return builder.Append(b, "WhereParts", newWherePart(pred, args...)).(DeleteBuilder) 198 | } 199 | 200 | // OrderBy adds ORDER BY expressions to the query. 201 | func (b DeleteBuilder) OrderBy(orderBys ...string) DeleteBuilder { 202 | return builder.Extend(b, "OrderBys", orderBys).(DeleteBuilder) 203 | } 204 | 205 | // Limit sets a LIMIT clause on the query. 206 | func (b DeleteBuilder) Limit(limit uint64) DeleteBuilder { 207 | return builder.Set(b, "Limit", fmt.Sprintf("%d", limit)).(DeleteBuilder) 208 | } 209 | 210 | // Offset sets a OFFSET clause on the query. 
211 | func (b DeleteBuilder) Offset(offset uint64) DeleteBuilder { 212 | return builder.Set(b, "Offset", fmt.Sprintf("%d", offset)).(DeleteBuilder) 213 | } 214 | 215 | // Suffix adds an expression to the end of the query 216 | func (b DeleteBuilder) Suffix(sql string, args ...interface{}) DeleteBuilder { 217 | return b.SuffixExpr(Expr(sql, args...)) 218 | } 219 | 220 | // SuffixExpr adds an expression to the end of the query 221 | func (b DeleteBuilder) SuffixExpr(expr Sqlizer) DeleteBuilder { 222 | return builder.Append(b, "Suffixes", expr).(DeleteBuilder) 223 | } 224 | 225 | func (b DeleteBuilder) Query() (*sql.Rows, error) { 226 | data := builder.GetStruct(b).(deleteData) 227 | return data.Query() 228 | } 229 | 230 | func (d *deleteData) Query() (*sql.Rows, error) { 231 | if d.RunWith == nil { 232 | return nil, RunnerNotSet 233 | } 234 | return QueryWith(d.RunWith, d) 235 | } 236 | -------------------------------------------------------------------------------- /insert.go: -------------------------------------------------------------------------------- 1 | package squirrel 2 | 3 | import ( 4 | "bytes" 5 | "database/sql" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "sort" 10 | "strings" 11 | 12 | "github.com/lann/builder" 13 | ) 14 | 15 | type insertData struct { 16 | PlaceholderFormat PlaceholderFormat 17 | RunWith BaseRunner 18 | Prefixes []Sqlizer 19 | StatementKeyword string 20 | Options []string 21 | Into string 22 | Columns []string 23 | Values [][]interface{} 24 | Suffixes []Sqlizer 25 | Select *SelectBuilder 26 | } 27 | 28 | func (d *insertData) Exec() (sql.Result, error) { 29 | if d.RunWith == nil { 30 | return nil, RunnerNotSet 31 | } 32 | return ExecWith(d.RunWith, d) 33 | } 34 | 35 | func (d *insertData) Query() (*sql.Rows, error) { 36 | if d.RunWith == nil { 37 | return nil, RunnerNotSet 38 | } 39 | return QueryWith(d.RunWith, d) 40 | } 41 | 42 | func (d *insertData) QueryRow() RowScanner { 43 | if d.RunWith == nil { 44 | return &Row{err: RunnerNotSet} 45 | } 
46 | queryRower, ok := d.RunWith.(QueryRower) 47 | if !ok { 48 | return &Row{err: RunnerNotQueryRunner} 49 | } 50 | return QueryRowWith(queryRower, d) 51 | } 52 | 53 | func (d *insertData) ToSql() (sqlStr string, args []interface{}, err error) { 54 | if len(d.Into) == 0 { 55 | err = errors.New("insert statements must specify a table") 56 | return 57 | } 58 | if len(d.Values) == 0 && d.Select == nil { 59 | err = errors.New("insert statements must have at least one set of values or select clause") 60 | return 61 | } 62 | 63 | sql := &bytes.Buffer{} 64 | 65 | if len(d.Prefixes) > 0 { 66 | args, err = appendToSql(d.Prefixes, sql, " ", args) 67 | if err != nil { 68 | return 69 | } 70 | 71 | sql.WriteString(" ") 72 | } 73 | 74 | if d.StatementKeyword == "" { 75 | sql.WriteString("INSERT ") 76 | } else { 77 | sql.WriteString(d.StatementKeyword) 78 | sql.WriteString(" ") 79 | } 80 | 81 | if len(d.Options) > 0 { 82 | sql.WriteString(strings.Join(d.Options, " ")) 83 | sql.WriteString(" ") 84 | } 85 | 86 | sql.WriteString("INTO ") 87 | sql.WriteString(d.Into) 88 | sql.WriteString(" ") 89 | 90 | if len(d.Columns) > 0 { 91 | sql.WriteString("(") 92 | sql.WriteString(strings.Join(d.Columns, ",")) 93 | sql.WriteString(") ") 94 | } 95 | 96 | if d.Select != nil { 97 | args, err = d.appendSelectToSQL(sql, args) 98 | } else { 99 | args, err = d.appendValuesToSQL(sql, args) 100 | } 101 | if err != nil { 102 | return 103 | } 104 | 105 | if len(d.Suffixes) > 0 { 106 | sql.WriteString(" ") 107 | args, err = appendToSql(d.Suffixes, sql, " ", args) 108 | if err != nil { 109 | return 110 | } 111 | } 112 | 113 | sqlStr, err = d.PlaceholderFormat.ReplacePlaceholders(sql.String()) 114 | return 115 | } 116 | 117 | func (d *insertData) appendValuesToSQL(w io.Writer, args []interface{}) ([]interface{}, error) { 118 | if len(d.Values) == 0 { 119 | return args, errors.New("values for insert statements are not set") 120 | } 121 | 122 | io.WriteString(w, "VALUES ") 123 | 124 | valuesStrings := 
make([]string, len(d.Values)) 125 | for r, row := range d.Values { 126 | valueStrings := make([]string, len(row)) 127 | for v, val := range row { 128 | if vs, ok := val.(Sqlizer); ok { 129 | vsql, vargs, err := vs.ToSql() 130 | if err != nil { 131 | return nil, err 132 | } 133 | valueStrings[v] = vsql 134 | args = append(args, vargs...) 135 | } else { 136 | valueStrings[v] = "?" 137 | args = append(args, val) 138 | } 139 | } 140 | valuesStrings[r] = fmt.Sprintf("(%s)", strings.Join(valueStrings, ",")) 141 | } 142 | 143 | io.WriteString(w, strings.Join(valuesStrings, ",")) 144 | 145 | return args, nil 146 | } 147 | 148 | func (d *insertData) appendSelectToSQL(w io.Writer, args []interface{}) ([]interface{}, error) { 149 | if d.Select == nil { 150 | return args, errors.New("select clause for insert statements are not set") 151 | } 152 | 153 | selectClause, sArgs, err := d.Select.ToSql() 154 | if err != nil { 155 | return args, err 156 | } 157 | 158 | io.WriteString(w, selectClause) 159 | args = append(args, sArgs...) 160 | 161 | return args, nil 162 | } 163 | 164 | // Builder 165 | 166 | // InsertBuilder builds SQL INSERT statements. 167 | type InsertBuilder builder.Builder 168 | 169 | func init() { 170 | builder.Register(InsertBuilder{}, insertData{}) 171 | } 172 | 173 | // Format methods 174 | 175 | // PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the 176 | // query. 177 | func (b InsertBuilder) PlaceholderFormat(f PlaceholderFormat) InsertBuilder { 178 | return builder.Set(b, "PlaceholderFormat", f).(InsertBuilder) 179 | } 180 | 181 | // Runner methods 182 | 183 | // RunWith sets a Runner (like database/sql.DB) to be used with e.g. Exec. 184 | func (b InsertBuilder) RunWith(runner BaseRunner) InsertBuilder { 185 | return setRunWith(b, runner).(InsertBuilder) 186 | } 187 | 188 | // Exec builds and Execs the query with the Runner set by RunWith. 
189 | func (b InsertBuilder) Exec() (sql.Result, error) { 190 | data := builder.GetStruct(b).(insertData) 191 | return data.Exec() 192 | } 193 | 194 | // Query builds and Querys the query with the Runner set by RunWith. 195 | func (b InsertBuilder) Query() (*sql.Rows, error) { 196 | data := builder.GetStruct(b).(insertData) 197 | return data.Query() 198 | } 199 | 200 | // QueryRow builds and QueryRows the query with the Runner set by RunWith. 201 | func (b InsertBuilder) QueryRow() RowScanner { 202 | data := builder.GetStruct(b).(insertData) 203 | return data.QueryRow() 204 | } 205 | 206 | // Scan is a shortcut for QueryRow().Scan. 207 | func (b InsertBuilder) Scan(dest ...interface{}) error { 208 | return b.QueryRow().Scan(dest...) 209 | } 210 | 211 | // SQL methods 212 | 213 | // ToSql builds the query into a SQL string and bound args. 214 | func (b InsertBuilder) ToSql() (string, []interface{}, error) { 215 | data := builder.GetStruct(b).(insertData) 216 | return data.ToSql() 217 | } 218 | 219 | // MustSql builds the query into a SQL string and bound args. 220 | // It panics if there are any errors. 221 | func (b InsertBuilder) MustSql() (string, []interface{}) { 222 | sql, args, err := b.ToSql() 223 | if err != nil { 224 | panic(err) 225 | } 226 | return sql, args 227 | } 228 | 229 | // Prefix adds an expression to the beginning of the query 230 | func (b InsertBuilder) Prefix(sql string, args ...interface{}) InsertBuilder { 231 | return b.PrefixExpr(Expr(sql, args...)) 232 | } 233 | 234 | // PrefixExpr adds an expression to the very beginning of the query 235 | func (b InsertBuilder) PrefixExpr(expr Sqlizer) InsertBuilder { 236 | return builder.Append(b, "Prefixes", expr).(InsertBuilder) 237 | } 238 | 239 | // Options adds keyword options before the INTO clause of the query. 
240 | func (b InsertBuilder) Options(options ...string) InsertBuilder { 241 | return builder.Extend(b, "Options", options).(InsertBuilder) 242 | } 243 | 244 | // Into sets the INTO clause of the query. 245 | func (b InsertBuilder) Into(from string) InsertBuilder { 246 | return builder.Set(b, "Into", from).(InsertBuilder) 247 | } 248 | 249 | // Columns adds insert columns to the query. 250 | func (b InsertBuilder) Columns(columns ...string) InsertBuilder { 251 | return builder.Extend(b, "Columns", columns).(InsertBuilder) 252 | } 253 | 254 | // Values adds a single row's values to the query. 255 | func (b InsertBuilder) Values(values ...interface{}) InsertBuilder { 256 | return builder.Append(b, "Values", values).(InsertBuilder) 257 | } 258 | 259 | // Suffix adds an expression to the end of the query 260 | func (b InsertBuilder) Suffix(sql string, args ...interface{}) InsertBuilder { 261 | return b.SuffixExpr(Expr(sql, args...)) 262 | } 263 | 264 | // SuffixExpr adds an expression to the end of the query 265 | func (b InsertBuilder) SuffixExpr(expr Sqlizer) InsertBuilder { 266 | return builder.Append(b, "Suffixes", expr).(InsertBuilder) 267 | } 268 | 269 | // SetMap set columns and values for insert builder from a map of column name and value 270 | // note that it will reset all previous columns and values was set if any 271 | func (b InsertBuilder) SetMap(clauses map[string]interface{}) InsertBuilder { 272 | // Keep the columns in a consistent order by sorting the column key string. 
273 | cols := make([]string, 0, len(clauses)) 274 | for col := range clauses { 275 | cols = append(cols, col) 276 | } 277 | sort.Strings(cols) 278 | 279 | vals := make([]interface{}, 0, len(clauses)) 280 | for _, col := range cols { 281 | vals = append(vals, clauses[col]) 282 | } 283 | 284 | b = builder.Set(b, "Columns", cols).(InsertBuilder) 285 | b = builder.Set(b, "Values", [][]interface{}{vals}).(InsertBuilder) 286 | 287 | return b 288 | } 289 | 290 | // Select set Select clause for insert query 291 | // If Values and Select are used, then Select has higher priority 292 | func (b InsertBuilder) Select(sb SelectBuilder) InsertBuilder { 293 | return builder.Set(b, "Select", &sb).(InsertBuilder) 294 | } 295 | 296 | func (b InsertBuilder) statementKeyword(keyword string) InsertBuilder { 297 | return builder.Set(b, "StatementKeyword", keyword).(InsertBuilder) 298 | } 299 | -------------------------------------------------------------------------------- /update.go: -------------------------------------------------------------------------------- 1 | package squirrel 2 | 3 | import ( 4 | "bytes" 5 | "database/sql" 6 | "fmt" 7 | "sort" 8 | "strings" 9 | 10 | "github.com/lann/builder" 11 | ) 12 | 13 | type updateData struct { 14 | PlaceholderFormat PlaceholderFormat 15 | RunWith BaseRunner 16 | Prefixes []Sqlizer 17 | Table string 18 | Joins []Sqlizer 19 | SetClauses []setClause 20 | From Sqlizer 21 | WhereParts []Sqlizer 22 | OrderBys []string 23 | Limit string 24 | Offset string 25 | Suffixes []Sqlizer 26 | } 27 | 28 | type setClause struct { 29 | column string 30 | value interface{} 31 | } 32 | 33 | func (d *updateData) Exec() (sql.Result, error) { 34 | if d.RunWith == nil { 35 | return nil, RunnerNotSet 36 | } 37 | return ExecWith(d.RunWith, d) 38 | } 39 | 40 | func (d *updateData) Query() (*sql.Rows, error) { 41 | if d.RunWith == nil { 42 | return nil, RunnerNotSet 43 | } 44 | return QueryWith(d.RunWith, d) 45 | } 46 | 47 | func (d *updateData) QueryRow() RowScanner 
{ 48 | if d.RunWith == nil { 49 | return &Row{err: RunnerNotSet} 50 | } 51 | queryRower, ok := d.RunWith.(QueryRower) 52 | if !ok { 53 | return &Row{err: RunnerNotQueryRunner} 54 | } 55 | return QueryRowWith(queryRower, d) 56 | } 57 | 58 | func (d *updateData) ToSql() (sqlStr string, args []interface{}, err error) { 59 | if len(d.Table) == 0 { 60 | err = fmt.Errorf("update statements must specify a table") 61 | return 62 | } 63 | if len(d.SetClauses) == 0 { 64 | err = fmt.Errorf("update statements must have at least one Set clause") 65 | return 66 | } 67 | 68 | sql := &bytes.Buffer{} 69 | 70 | if len(d.Prefixes) > 0 { 71 | args, err = appendToSql(d.Prefixes, sql, " ", args) 72 | if err != nil { 73 | return 74 | } 75 | 76 | sql.WriteString(" ") 77 | } 78 | 79 | sql.WriteString("UPDATE ") 80 | sql.WriteString(d.Table) 81 | 82 | if len(d.Joins) > 0 { 83 | sql.WriteString(" ") 84 | args, err = appendToSql(d.Joins, sql, " ", args) 85 | if err != nil { 86 | return 87 | } 88 | } 89 | 90 | sql.WriteString(" SET ") 91 | setSqls := make([]string, len(d.SetClauses)) 92 | for i, setClause := range d.SetClauses { 93 | var valSql string 94 | if vs, ok := setClause.value.(Sqlizer); ok { 95 | vsql, vargs, err := vs.ToSql() 96 | if err != nil { 97 | return "", nil, err 98 | } 99 | if _, ok := vs.(SelectBuilder); ok { 100 | valSql = fmt.Sprintf("(%s)", vsql) 101 | } else { 102 | valSql = vsql 103 | } 104 | args = append(args, vargs...) 105 | } else { 106 | valSql = "?" 
107 | args = append(args, setClause.value) 108 | } 109 | setSqls[i] = fmt.Sprintf("%s = %s", setClause.column, valSql) 110 | } 111 | sql.WriteString(strings.Join(setSqls, ", ")) 112 | 113 | if d.From != nil { 114 | sql.WriteString(" FROM ") 115 | args, err = appendToSql([]Sqlizer{d.From}, sql, "", args) 116 | if err != nil { 117 | return 118 | } 119 | } 120 | 121 | if len(d.WhereParts) > 0 { 122 | sql.WriteString(" WHERE ") 123 | args, err = appendToSql(d.WhereParts, sql, " AND ", args) 124 | if err != nil { 125 | return 126 | } 127 | } 128 | 129 | if len(d.OrderBys) > 0 { 130 | sql.WriteString(" ORDER BY ") 131 | sql.WriteString(strings.Join(d.OrderBys, ", ")) 132 | } 133 | 134 | if len(d.Limit) > 0 { 135 | sql.WriteString(" LIMIT ") 136 | sql.WriteString(d.Limit) 137 | } 138 | 139 | if len(d.Offset) > 0 { 140 | sql.WriteString(" OFFSET ") 141 | sql.WriteString(d.Offset) 142 | } 143 | 144 | if len(d.Suffixes) > 0 { 145 | sql.WriteString(" ") 146 | args, err = appendToSql(d.Suffixes, sql, " ", args) 147 | if err != nil { 148 | return 149 | } 150 | } 151 | 152 | sqlStr, err = d.PlaceholderFormat.ReplacePlaceholders(sql.String()) 153 | return 154 | } 155 | 156 | // Builder 157 | 158 | // UpdateBuilder builds SQL UPDATE statements. 159 | type UpdateBuilder builder.Builder 160 | 161 | func init() { 162 | builder.Register(UpdateBuilder{}, updateData{}) 163 | } 164 | 165 | // Format methods 166 | 167 | // PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the 168 | // query. 169 | func (b UpdateBuilder) PlaceholderFormat(f PlaceholderFormat) UpdateBuilder { 170 | return builder.Set(b, "PlaceholderFormat", f).(UpdateBuilder) 171 | } 172 | 173 | // Runner methods 174 | 175 | // RunWith sets a Runner (like database/sql.DB) to be used with e.g. Exec. 
176 | func (b UpdateBuilder) RunWith(runner BaseRunner) UpdateBuilder { 177 | return setRunWith(b, runner).(UpdateBuilder) 178 | } 179 | 180 | // Exec builds and Execs the query with the Runner set by RunWith. 181 | func (b UpdateBuilder) Exec() (sql.Result, error) { 182 | data := builder.GetStruct(b).(updateData) 183 | return data.Exec() 184 | } 185 | 186 | func (b UpdateBuilder) Query() (*sql.Rows, error) { 187 | data := builder.GetStruct(b).(updateData) 188 | return data.Query() 189 | } 190 | 191 | func (b UpdateBuilder) QueryRow() RowScanner { 192 | data := builder.GetStruct(b).(updateData) 193 | return data.QueryRow() 194 | } 195 | 196 | func (b UpdateBuilder) Scan(dest ...interface{}) error { 197 | return b.QueryRow().Scan(dest...) 198 | } 199 | 200 | // SQL methods 201 | 202 | // ToSql builds the query into a SQL string and bound args. 203 | func (b UpdateBuilder) ToSql() (string, []interface{}, error) { 204 | data := builder.GetStruct(b).(updateData) 205 | return data.ToSql() 206 | } 207 | 208 | // MustSql builds the query into a SQL string and bound args. 209 | // It panics if there are any errors. 210 | func (b UpdateBuilder) MustSql() (string, []interface{}) { 211 | sql, args, err := b.ToSql() 212 | if err != nil { 213 | panic(err) 214 | } 215 | return sql, args 216 | } 217 | 218 | // Prefix adds an expression to the beginning of the query 219 | func (b UpdateBuilder) Prefix(sql string, args ...interface{}) UpdateBuilder { 220 | return b.PrefixExpr(Expr(sql, args...)) 221 | } 222 | 223 | // PrefixExpr adds an expression to the very beginning of the query 224 | func (b UpdateBuilder) PrefixExpr(expr Sqlizer) UpdateBuilder { 225 | return builder.Append(b, "Prefixes", expr).(UpdateBuilder) 226 | } 227 | 228 | // Table sets the table to be updated. 229 | func (b UpdateBuilder) Table(table string) UpdateBuilder { 230 | return builder.Set(b, "Table", table).(UpdateBuilder) 231 | } 232 | 233 | // JoinClause adds a join clause to the query. 
234 | func (b UpdateBuilder) JoinClause(pred interface{}, args ...interface{}) UpdateBuilder { 235 | return builder.Append(b, "Joins", newPart(pred, args...)).(UpdateBuilder) 236 | } 237 | 238 | // Join adds a JOIN clause to the query. 239 | func (b UpdateBuilder) Join(join string, rest ...interface{}) UpdateBuilder { 240 | return b.JoinClause("JOIN "+join, rest...) 241 | } 242 | 243 | // LeftJoin adds a LEFT JOIN clause to the query. 244 | func (b UpdateBuilder) LeftJoin(join string, rest ...interface{}) UpdateBuilder { 245 | return b.JoinClause("LEFT JOIN "+join, rest...) 246 | } 247 | 248 | // RightJoin adds a RIGHT JOIN clause to the query. 249 | func (b UpdateBuilder) RightJoin(join string, rest ...interface{}) UpdateBuilder { 250 | return b.JoinClause("RIGHT JOIN "+join, rest...) 251 | } 252 | 253 | // InnerJoin adds a INNER JOIN clause to the query. 254 | func (b UpdateBuilder) InnerJoin(join string, rest ...interface{}) UpdateBuilder { 255 | return b.JoinClause("INNER JOIN "+join, rest...) 256 | } 257 | 258 | // CrossJoin adds a CROSS JOIN clause to the query. 259 | func (b UpdateBuilder) CrossJoin(join string, rest ...interface{}) UpdateBuilder { 260 | return b.JoinClause("CROSS JOIN "+join, rest...) 261 | } 262 | 263 | // Set adds SET clauses to the query. 264 | func (b UpdateBuilder) Set(column string, value interface{}) UpdateBuilder { 265 | return builder.Append(b, "SetClauses", setClause{column: column, value: value}).(UpdateBuilder) 266 | } 267 | 268 | // SetMap is a convenience method which calls .Set for each key/value pair in clauses. 
269 | func (b UpdateBuilder) SetMap(clauses map[string]interface{}) UpdateBuilder { 270 | keys := make([]string, len(clauses)) 271 | i := 0 272 | for key := range clauses { 273 | keys[i] = key 274 | i++ 275 | } 276 | sort.Strings(keys) 277 | for _, key := range keys { 278 | val, _ := clauses[key] 279 | b = b.Set(key, val) 280 | } 281 | return b 282 | } 283 | 284 | // From adds FROM clause to the query 285 | // FROM is valid construct in postgresql only. 286 | func (b UpdateBuilder) From(from string) UpdateBuilder { 287 | return builder.Set(b, "From", newPart(from)).(UpdateBuilder) 288 | } 289 | 290 | // FromSelect sets a subquery into the FROM clause of the query. 291 | func (b UpdateBuilder) FromSelect(from SelectBuilder, alias string) UpdateBuilder { 292 | // Prevent misnumbered parameters in nested selects (#183). 293 | from = from.PlaceholderFormat(Question) 294 | return builder.Set(b, "From", Alias(from, alias)).(UpdateBuilder) 295 | } 296 | 297 | // Where adds WHERE expressions to the query. 298 | // 299 | // See SelectBuilder.Where for more information. 300 | func (b UpdateBuilder) Where(pred interface{}, args ...interface{}) UpdateBuilder { 301 | return builder.Append(b, "WhereParts", newWherePart(pred, args...)).(UpdateBuilder) 302 | } 303 | 304 | // OrderBy adds ORDER BY expressions to the query. 305 | func (b UpdateBuilder) OrderBy(orderBys ...string) UpdateBuilder { 306 | return builder.Extend(b, "OrderBys", orderBys).(UpdateBuilder) 307 | } 308 | 309 | // Limit sets a LIMIT clause on the query. 310 | func (b UpdateBuilder) Limit(limit uint64) UpdateBuilder { 311 | return builder.Set(b, "Limit", fmt.Sprintf("%d", limit)).(UpdateBuilder) 312 | } 313 | 314 | // Offset sets a OFFSET clause on the query. 
315 | func (b UpdateBuilder) Offset(offset uint64) UpdateBuilder { 316 | return builder.Set(b, "Offset", fmt.Sprintf("%d", offset)).(UpdateBuilder) 317 | } 318 | 319 | // Suffix adds an expression to the end of the query 320 | func (b UpdateBuilder) Suffix(sql string, args ...interface{}) UpdateBuilder { 321 | return b.SuffixExpr(Expr(sql, args...)) 322 | } 323 | 324 | // SuffixExpr adds an expression to the end of the query 325 | func (b UpdateBuilder) SuffixExpr(expr Sqlizer) UpdateBuilder { 326 | return builder.Append(b, "Suffixes", expr).(UpdateBuilder) 327 | } 328 | -------------------------------------------------------------------------------- /expr.go: -------------------------------------------------------------------------------- 1 | package squirrel 2 | 3 | import ( 4 | "bytes" 5 | "database/sql/driver" 6 | "fmt" 7 | "reflect" 8 | "sort" 9 | "strings" 10 | ) 11 | 12 | const ( 13 | // Portable true/false literals. 14 | sqlTrue = "(1=1)" 15 | sqlFalse = "(1=0)" 16 | ) 17 | 18 | type expr struct { 19 | sql string 20 | args []interface{} 21 | } 22 | 23 | // Expr builds an expression from a SQL fragment and arguments. 24 | // 25 | // Ex: 26 | // Expr("FROM_UNIXTIME(?)", t) 27 | func Expr(sql string, args ...interface{}) Sqlizer { 28 | return expr{sql: sql, args: args} 29 | } 30 | 31 | func (e expr) ToSql() (sql string, args []interface{}, err error) { 32 | simple := true 33 | for _, arg := range e.args { 34 | if _, ok := arg.(Sqlizer); ok { 35 | simple = false 36 | } 37 | } 38 | if simple { 39 | return e.sql, e.args, nil 40 | } 41 | 42 | buf := &bytes.Buffer{} 43 | ap := e.args 44 | sp := e.sql 45 | 46 | var isql string 47 | var iargs []interface{} 48 | 49 | for err == nil && len(ap) > 0 && len(sp) > 0 { 50 | i := strings.Index(sp, "?") 51 | if i < 0 { 52 | // no more placeholders 53 | break 54 | } 55 | if len(sp) > i+1 && sp[i+1:i+2] == "?" 
{ 56 | // escaped "??"; append it and step past 57 | buf.WriteString(sp[:i+2]) 58 | sp = sp[i+2:] 59 | continue 60 | } 61 | 62 | if as, ok := ap[0].(Sqlizer); ok { 63 | // sqlizer argument; expand it and append the result 64 | isql, iargs, err = as.ToSql() 65 | buf.WriteString(sp[:i]) 66 | buf.WriteString(isql) 67 | args = append(args, iargs...) 68 | } else { 69 | // normal argument; append it and the placeholder 70 | buf.WriteString(sp[:i+1]) 71 | args = append(args, ap[0]) 72 | } 73 | 74 | // step past the argument and placeholder 75 | ap = ap[1:] 76 | sp = sp[i+1:] 77 | } 78 | 79 | // append the remaining sql and arguments 80 | buf.WriteString(sp) 81 | return buf.String(), append(args, ap...), err 82 | } 83 | 84 | type concatExpr []interface{} 85 | 86 | func (ce concatExpr) ToSql() (sql string, args []interface{}, err error) { 87 | for _, part := range ce { 88 | switch p := part.(type) { 89 | case string: 90 | sql += p 91 | case Sqlizer: 92 | pSql, pArgs, err := p.ToSql() 93 | if err != nil { 94 | return "", nil, err 95 | } 96 | sql += pSql 97 | args = append(args, pArgs...) 98 | default: 99 | return "", nil, fmt.Errorf("%#v is not a string or Sqlizer", part) 100 | } 101 | } 102 | return 103 | } 104 | 105 | // ConcatExpr builds an expression by concatenating strings and other expressions. 106 | // 107 | // Ex: 108 | // name_expr := Expr("CONCAT(?, ' ', ?)", firstName, lastName) 109 | // ConcatExpr("COALESCE(full_name,", name_expr, ")") 110 | func ConcatExpr(parts ...interface{}) concatExpr { 111 | return concatExpr(parts) 112 | } 113 | 114 | // aliasExpr helps to alias part of SQL query generated with underlying "expr" 115 | type aliasExpr struct { 116 | expr Sqlizer 117 | alias string 118 | } 119 | 120 | // Alias allows to define alias for column in SelectBuilder. 
Useful when column is 121 | // defined as complex expression like IF or CASE 122 | // Ex: 123 | // .Column(Alias(caseStmt, "case_column")) 124 | func Alias(expr Sqlizer, alias string) aliasExpr { 125 | return aliasExpr{expr, alias} 126 | } 127 | 128 | func (e aliasExpr) ToSql() (sql string, args []interface{}, err error) { 129 | sql, args, err = e.expr.ToSql() 130 | if err == nil { 131 | sql = fmt.Sprintf("(%s) AS %s", sql, e.alias) 132 | } 133 | return 134 | } 135 | 136 | // Eq is syntactic sugar for use with Where/Having/Set methods. 137 | type Eq map[string]interface{} 138 | 139 | func (eq Eq) toSQL(useNotOpr bool) (sql string, args []interface{}, err error) { 140 | if len(eq) == 0 { 141 | // Empty Sql{} evaluates to true. 142 | sql = sqlTrue 143 | return 144 | } 145 | 146 | var ( 147 | exprs []string 148 | equalOpr = "=" 149 | inOpr = "IN" 150 | nullOpr = "IS" 151 | inEmptyExpr = sqlFalse 152 | ) 153 | 154 | if useNotOpr { 155 | equalOpr = "<>" 156 | inOpr = "NOT IN" 157 | nullOpr = "IS NOT" 158 | inEmptyExpr = sqlTrue 159 | } 160 | 161 | sortedKeys := getSortedKeys(eq) 162 | for _, key := range sortedKeys { 163 | var expr string 164 | val := eq[key] 165 | 166 | switch v := val.(type) { 167 | case driver.Valuer: 168 | if val, err = v.Value(); err != nil { 169 | return 170 | } 171 | } 172 | 173 | r := reflect.ValueOf(val) 174 | if r.Kind() == reflect.Ptr { 175 | if r.IsNil() { 176 | val = nil 177 | } else { 178 | val = r.Elem().Interface() 179 | } 180 | } 181 | 182 | if val == nil { 183 | expr = fmt.Sprintf("%s %s NULL", key, nullOpr) 184 | } else { 185 | if isListType(val) { 186 | valVal := reflect.ValueOf(val) 187 | if valVal.Len() == 0 { 188 | expr = inEmptyExpr 189 | if args == nil { 190 | args = []interface{}{} 191 | } 192 | } else { 193 | for i := 0; i < valVal.Len(); i++ { 194 | args = append(args, valVal.Index(i).Interface()) 195 | } 196 | expr = fmt.Sprintf("%s %s (%s)", key, inOpr, Placeholders(valVal.Len())) 197 | } 198 | } else { 199 | expr = 
fmt.Sprintf("%s %s ?", key, equalOpr) 200 | args = append(args, val) 201 | } 202 | } 203 | exprs = append(exprs, expr) 204 | } 205 | sql = strings.Join(exprs, " AND ") 206 | return 207 | } 208 | 209 | func (eq Eq) ToSql() (sql string, args []interface{}, err error) { 210 | return eq.toSQL(false) 211 | } 212 | 213 | // NotEq is syntactic sugar for use with Where/Having/Set methods. 214 | // Ex: 215 | // .Where(NotEq{"id": 1}) == "id <> 1" 216 | type NotEq Eq 217 | 218 | func (neq NotEq) ToSql() (sql string, args []interface{}, err error) { 219 | return Eq(neq).toSQL(true) 220 | } 221 | 222 | // Like is syntactic sugar for use with LIKE conditions. 223 | // Ex: 224 | // .Where(Like{"name": "%irrel"}) 225 | type Like map[string]interface{} 226 | 227 | func (lk Like) toSql(opr string) (sql string, args []interface{}, err error) { 228 | var exprs []string 229 | for key, val := range lk { 230 | expr := "" 231 | 232 | switch v := val.(type) { 233 | case driver.Valuer: 234 | if val, err = v.Value(); err != nil { 235 | return 236 | } 237 | } 238 | 239 | if val == nil { 240 | err = fmt.Errorf("cannot use null with like operators") 241 | return 242 | } else { 243 | if isListType(val) { 244 | err = fmt.Errorf("cannot use array or slice with like operators") 245 | return 246 | } else { 247 | expr = fmt.Sprintf("%s %s ?", key, opr) 248 | args = append(args, val) 249 | } 250 | } 251 | exprs = append(exprs, expr) 252 | } 253 | sql = strings.Join(exprs, " AND ") 254 | return 255 | } 256 | 257 | func (lk Like) ToSql() (sql string, args []interface{}, err error) { 258 | return lk.toSql("LIKE") 259 | } 260 | 261 | // NotLike is syntactic sugar for use with LIKE conditions. 262 | // Ex: 263 | // .Where(NotLike{"name": "%irrel"}) 264 | type NotLike Like 265 | 266 | func (nlk NotLike) ToSql() (sql string, args []interface{}, err error) { 267 | return Like(nlk).toSql("NOT LIKE") 268 | } 269 | 270 | // ILike is syntactic sugar for use with ILIKE conditions. 
271 | // Ex: 272 | // .Where(ILike{"name": "sq%"}) 273 | type ILike Like 274 | 275 | func (ilk ILike) ToSql() (sql string, args []interface{}, err error) { 276 | return Like(ilk).toSql("ILIKE") 277 | } 278 | 279 | // NotILike is syntactic sugar for use with ILIKE conditions. 280 | // Ex: 281 | // .Where(NotILike{"name": "sq%"}) 282 | type NotILike Like 283 | 284 | func (nilk NotILike) ToSql() (sql string, args []interface{}, err error) { 285 | return Like(nilk).toSql("NOT ILIKE") 286 | } 287 | 288 | // Lt is syntactic sugar for use with Where/Having/Set methods. 289 | // Ex: 290 | // .Where(Lt{"id": 1}) 291 | type Lt map[string]interface{} 292 | 293 | func (lt Lt) toSql(opposite, orEq bool) (sql string, args []interface{}, err error) { 294 | var ( 295 | exprs []string 296 | opr = "<" 297 | ) 298 | 299 | if opposite { 300 | opr = ">" 301 | } 302 | 303 | if orEq { 304 | opr = fmt.Sprintf("%s%s", opr, "=") 305 | } 306 | 307 | sortedKeys := getSortedKeys(lt) 308 | for _, key := range sortedKeys { 309 | var expr string 310 | val := lt[key] 311 | 312 | switch v := val.(type) { 313 | case driver.Valuer: 314 | if val, err = v.Value(); err != nil { 315 | return 316 | } 317 | } 318 | 319 | if val == nil { 320 | err = fmt.Errorf("cannot use null with less than or greater than operators") 321 | return 322 | } 323 | if isListType(val) { 324 | err = fmt.Errorf("cannot use array or slice with less than or greater than operators") 325 | return 326 | } 327 | expr = fmt.Sprintf("%s %s ?", key, opr) 328 | args = append(args, val) 329 | 330 | exprs = append(exprs, expr) 331 | } 332 | sql = strings.Join(exprs, " AND ") 333 | return 334 | } 335 | 336 | func (lt Lt) ToSql() (sql string, args []interface{}, err error) { 337 | return lt.toSql(false, false) 338 | } 339 | 340 | // LtOrEq is syntactic sugar for use with Where/Having/Set methods. 
// Ex:
//
//	.Where(LtOrEq{"id": 1}) == "id <= 1"
type LtOrEq Lt

// ToSql renders ltOrEq as "<key> <= ?" conditions joined by AND.
func (ltOrEq LtOrEq) ToSql() (sql string, args []interface{}, err error) {
	return Lt(ltOrEq).toSql(false, true)
}

// Gt is syntactic sugar for use with Where/Having/Set methods.
// Ex:
//
//	.Where(Gt{"id": 1}) == "id > 1"
type Gt Lt

// ToSql renders gt as "<key> > ?" conditions joined by AND.
func (gt Gt) ToSql() (sql string, args []interface{}, err error) {
	return Lt(gt).toSql(true, false)
}

// GtOrEq is syntactic sugar for use with Where/Having/Set methods.
// Ex:
//
//	.Where(GtOrEq{"id": 1}) == "id >= 1"
type GtOrEq Lt

// ToSql renders gtOrEq as "<key> >= ?" conditions joined by AND.
func (gtOrEq GtOrEq) ToSql() (sql string, args []interface{}, err error) {
	return Lt(gtOrEq).toSql(true, true)
}

// conj is the shared implementation behind And and Or.
type conj []Sqlizer

// join renders every part via nestedToSql, drops parts that render as an empty
// string, and joins the rest with sep inside one pair of parentheses.
// An empty conj renders as defaultExpr with empty (non-nil) args; if every part
// renders empty, the result is an empty string. The first error aborts.
func (c conj) join(sep, defaultExpr string) (sql string, args []interface{}, err error) {
	if len(c) == 0 {
		return defaultExpr, []interface{}{}, nil
	}
	var sqlParts []string
	for _, sqlizer := range c {
		partSQL, partArgs, err := nestedToSql(sqlizer)
		if err != nil {
			return "", nil, err
		}
		if partSQL != "" {
			sqlParts = append(sqlParts, partSQL)
			args = append(args, partArgs...)
		}
	}
	if len(sqlParts) > 0 {
		sql = fmt.Sprintf("(%s)", strings.Join(sqlParts, sep))
	}
	return
}

// And conjunction Sqlizers: the parts are joined with AND. An empty And
// renders as sqlTrue.
type And conj

// ToSql renders a as "(<part> AND <part> ...)".
func (a And) ToSql() (string, []interface{}, error) {
	return conj(a).join(" AND ", sqlTrue)
}

// Or conjunction Sqlizers: the parts are joined with OR. An empty Or renders
// as sqlFalse.
type Or conj

// ToSql renders o as "(<part> OR <part> ...)".
func (o Or) ToSql() (string, []interface{}, error) {
	return conj(o).join(" OR ", sqlFalse)
}

// getSortedKeys returns the keys of exp in ascending order. Map iteration order
// is random, so callers use this to render SQL and args deterministically.
func getSortedKeys(exp map[string]interface{}) []string {
	sortedKeys := make([]string, 0, len(exp))
	for k := range exp {
		sortedKeys = append(sortedKeys, k)
	}
	sort.Strings(sortedKeys)
	return sortedKeys
}

// isListType reports whether val is an array or slice that should be expanded
// into a list of placeholders. Values that driver.IsValue accepts as-is (such
// as []byte) are never treated as lists.
func isListType(val interface{}) bool {
	if driver.IsValue(val) {
		return false
	}
	valVal := reflect.ValueOf(val)
	return valVal.Kind() == reflect.Array || valVal.Kind() == reflect.Slice
}

--------------------------------------------------------------------------------
/expr_test.go:
--------------------------------------------------------------------------------
package squirrel

import (
	"database/sql"
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestConcatExpr(t *testing.T) {
	b := ConcatExpr("COALESCE(name,", Expr("CONCAT(?,' ',?)", "f", "l"), ")")
	sql, args, err := b.ToSql()
	assert.NoError(t, err)

	expectedSql := "COALESCE(name,CONCAT(?,' ',?))"
	assert.Equal(t, expectedSql, sql)

	expectedArgs := []interface{}{"f", "l"}
	assert.Equal(t, expectedArgs, args)
}

func TestConcatExprBadType(t *testing.T) {
	b := ConcatExpr("prefix", 123, "suffix")
	_, _, err := b.ToSql()
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "123 is not")
}

func TestEqToSql(t *testing.T) {
	b := Eq{"id": 1}
	sql, args, err := b.ToSql()
	assert.NoError(t, err)

	expectedSql := "id = ?"
	assert.Equal(t, expectedSql, sql)

	expectedArgs := []interface{}{1}
	assert.Equal(t, expectedArgs, args)
}

func TestEqEmptyToSql(t *testing.T) {
	sql, args, err := Eq{}.ToSql()
	assert.NoError(t, err)

	expectedSql := "(1=1)"
	assert.Equal(t, expectedSql, sql)
	assert.Empty(t, args)
}

func TestEqInToSql(t *testing.T) {
	b := Eq{"id": []int{1, 2, 3}}
	sql, args, err := b.ToSql()
	assert.NoError(t, err)

	expectedSql := "id IN (?,?,?)"
	assert.Equal(t, expectedSql, sql)

	expectedArgs := []interface{}{1, 2, 3}
	assert.Equal(t, expectedArgs, args)
}

func TestNotEqToSql(t *testing.T) {
	b := NotEq{"id": 1}
	sql, args, err := b.ToSql()
	assert.NoError(t, err)

	expectedSql := "id <> ?"
	assert.Equal(t, expectedSql, sql)

	expectedArgs := []interface{}{1}
	assert.Equal(t, expectedArgs, args)
}

func TestEqNotInToSql(t *testing.T) {
	b := NotEq{"id": []int{1, 2, 3}}
	sql, args, err := b.ToSql()
	assert.NoError(t, err)

	expectedSql := "id NOT IN (?,?,?)"
	assert.Equal(t, expectedSql, sql)

	expectedArgs := []interface{}{1, 2, 3}
	assert.Equal(t, expectedArgs, args)
}

func TestEqInEmptyToSql(t *testing.T) {
	b := Eq{"id": []int{}}
	sql, args, err := b.ToSql()
	assert.NoError(t, err)

	expectedSql := "(1=0)"
	assert.Equal(t, expectedSql, sql)

	expectedArgs := []interface{}{}
	assert.Equal(t, expectedArgs, args)
}

func TestNotEqInEmptyToSql(t *testing.T) {
	b := NotEq{"id": []int{}}
	sql, args, err := b.ToSql()
	assert.NoError(t, err)

	expectedSql := "(1=1)"
	assert.Equal(t, expectedSql, sql)

	expectedArgs := []interface{}{}
	assert.Equal(t, expectedArgs, args)
}

func TestEqBytesToSql(t *testing.T) {
	b := Eq{"id": []byte("test")}
	sql, args, err := b.ToSql()
	assert.NoError(t, err)

	expectedSql := "id = ?"
	assert.Equal(t, expectedSql, sql)

	expectedArgs := []interface{}{[]byte("test")}
	assert.Equal(t, expectedArgs, args)
}

func TestLtToSql(t *testing.T) {
	b := Lt{"id": 1}
	sql, args, err := b.ToSql()
	assert.NoError(t, err)

	expectedSql := "id < ?"
	assert.Equal(t, expectedSql, sql)

	expectedArgs := []interface{}{1}
	assert.Equal(t, expectedArgs, args)
}

func TestLtOrEqToSql(t *testing.T) {
	b := LtOrEq{"id": 1}
	sql, args, err := b.ToSql()
	assert.NoError(t, err)

	expectedSql := "id <= ?"
	assert.Equal(t, expectedSql, sql)

	expectedArgs := []interface{}{1}
	assert.Equal(t, expectedArgs, args)
}

func TestGtToSql(t *testing.T) {
	b := Gt{"id": 1}
	sql, args, err := b.ToSql()
	assert.NoError(t, err)

	expectedSql := "id > ?"
	assert.Equal(t, expectedSql, sql)

	expectedArgs := []interface{}{1}
	assert.Equal(t, expectedArgs, args)
}

func TestGtOrEqToSql(t *testing.T) {
	b := GtOrEq{"id": 1}
	sql, args, err := b.ToSql()
	assert.NoError(t, err)

	expectedSql := "id >= ?"
	assert.Equal(t, expectedSql, sql)

	expectedArgs := []interface{}{1}
	assert.Equal(t, expectedArgs, args)
}

func TestExprNilToSql(t *testing.T) {
	var b Sqlizer
	b = NotEq{"name": nil}
	sql, args, err := b.ToSql()
	assert.NoError(t, err)
	assert.Empty(t, args)

	expectedSql := "name IS NOT NULL"
	assert.Equal(t, expectedSql, sql)

	b = Eq{"name": nil}
	sql, args, err = b.ToSql()
	assert.NoError(t, err)
	assert.Empty(t, args)

	expectedSql = "name IS NULL"
	assert.Equal(t, expectedSql, sql)
}

func TestNullTypeString(t *testing.T) {
	var b Sqlizer
	var name sql.NullString

	b = Eq{"name": name}
	sql, args, err := b.ToSql()

	assert.NoError(t, err)
	assert.Empty(t, args)
	assert.Equal(t, "name IS NULL", sql)

	name.Scan("Name")
	b = Eq{"name": name}
	sql, args, err = b.ToSql()

	assert.NoError(t, err)
	assert.Equal(t, []interface{}{"Name"}, args)
	assert.Equal(t, "name = ?", sql)
}

func TestNullTypeInt64(t *testing.T) {
	var userID sql.NullInt64
	userID.Scan(nil)
	b := Eq{"user_id": userID}
	sql, args, err := b.ToSql()

	assert.NoError(t, err)
	assert.Empty(t, args)
	assert.Equal(t, "user_id IS NULL", sql)

	userID.Scan(int64(10))
	b = Eq{"user_id": userID}
	sql, args, err = b.ToSql()

	assert.NoError(t, err)
	assert.Equal(t, []interface{}{int64(10)}, args)
	assert.Equal(t, "user_id = ?", sql)
}

func TestNilPointer(t *testing.T) {
	var name *string = nil
	eq := Eq{"name": name}
	sql, args, err := eq.ToSql()

	assert.NoError(t, err)
	assert.Empty(t, args)
	assert.Equal(t, "name IS NULL", sql)

	neq := NotEq{"name": name}
	sql, args, err = neq.ToSql()

	assert.NoError(t, err)
	assert.Empty(t, args)
	assert.Equal(t, "name IS NOT NULL", sql)

	var ids *[]int = nil
	eq = Eq{"id": ids}
	sql, args, err = eq.ToSql()
	assert.NoError(t, err)
	assert.Empty(t, args)
	assert.Equal(t, "id IS NULL", sql)

	neq = NotEq{"id": ids}
	sql, args, err = neq.ToSql()
	assert.NoError(t, err)
	assert.Empty(t, args)
	assert.Equal(t, "id IS NOT NULL", sql)

	var ida *[3]int = nil
	eq = Eq{"id": ida}
	sql, args, err = eq.ToSql()
	assert.NoError(t, err)
	assert.Empty(t, args)
	assert.Equal(t, "id IS NULL", sql)

	neq = NotEq{"id": ida}
	sql, args, err = neq.ToSql()
	assert.NoError(t, err)
	assert.Empty(t, args)
	assert.Equal(t, "id IS NOT NULL", sql)

}

func TestNotNilPointer(t *testing.T) {
	c := "Name"
	name := &c
	eq := Eq{"name": name}
	sql, args, err := eq.ToSql()

	assert.NoError(t, err)
	assert.Equal(t, []interface{}{"Name"}, args)
	assert.Equal(t, "name = ?", sql)

	neq := NotEq{"name": name}
	sql, args, err = neq.ToSql()

	assert.NoError(t, err)
	assert.Equal(t, []interface{}{"Name"}, args)
	assert.Equal(t, "name <> ?", sql)

	s := []int{1, 2, 3}
	ids := &s
	eq = Eq{"id": ids}
	sql, args, err = eq.ToSql()
	assert.NoError(t, err)
	assert.Equal(t, []interface{}{1, 2, 3}, args)
	assert.Equal(t, "id IN (?,?,?)", sql)

	neq = NotEq{"id": ids}
	sql, args, err = neq.ToSql()
	assert.NoError(t, err)
	assert.Equal(t, []interface{}{1, 2, 3}, args)
	assert.Equal(t, "id NOT IN (?,?,?)", sql)

	a := [3]int{1, 2, 3}
	ida := &a
	eq = Eq{"id": ida}
	sql, args, err = eq.ToSql()
	assert.NoError(t, err)
	assert.Equal(t, []interface{}{1, 2, 3}, args)
	assert.Equal(t, "id IN (?,?,?)", sql)

	neq = NotEq{"id": ida}
	sql, args, err = neq.ToSql()
	assert.NoError(t, err)
	assert.Equal(t, []interface{}{1, 2, 3}, args)
	assert.Equal(t, "id NOT IN (?,?,?)", sql)
}

func TestEmptyAndToSql(t *testing.T) {
	sql, args, err := And{}.ToSql()
	assert.NoError(t, err)

	expectedSql := "(1=1)"
	assert.Equal(t, expectedSql, sql)

	expectedArgs := []interface{}{}
	assert.Equal(t, expectedArgs, args)
}

func TestEmptyOrToSql(t *testing.T) {
	sql, args, err := Or{}.ToSql()
	assert.NoError(t, err)

	expectedSql := "(1=0)"
	assert.Equal(t, expectedSql, sql)

	expectedArgs := []interface{}{}
	assert.Equal(t, expectedArgs, args)
}

func TestLikeToSql(t *testing.T) {
	b := Like{"name": "%irrel"}
	sql, args, err := b.ToSql()
	assert.NoError(t, err)

	expectedSql := "name LIKE ?"
	assert.Equal(t, expectedSql, sql)

	expectedArgs := []interface{}{"%irrel"}
	assert.Equal(t, expectedArgs, args)
}

func TestNotLikeToSql(t *testing.T) {
	b := NotLike{"name": "%irrel"}
	sql, args, err := b.ToSql()
	assert.NoError(t, err)

	expectedSql := "name NOT LIKE ?"
	assert.Equal(t, expectedSql, sql)

	expectedArgs := []interface{}{"%irrel"}
	assert.Equal(t, expectedArgs, args)
}

func TestILikeToSql(t *testing.T) {
	b := ILike{"name": "sq%"}
	sql, args, err := b.ToSql()
	assert.NoError(t, err)

	expectedSql := "name ILIKE ?"
	assert.Equal(t, expectedSql, sql)

	expectedArgs := []interface{}{"sq%"}
	assert.Equal(t, expectedArgs, args)
}

func TestNotILikeToSql(t *testing.T) {
	b := NotILike{"name": "sq%"}
	sql, args, err := b.ToSql()
	assert.NoError(t, err)

	expectedSql := "name NOT ILIKE ?"
	assert.Equal(t, expectedSql, sql)

	expectedArgs := []interface{}{"sq%"}
	assert.Equal(t, expectedArgs, args)
}

// Keys are rendered in sorted order regardless of map iteration order.
func TestSqlEqOrder(t *testing.T) {
	b := Eq{"a": 1, "b": 2, "c": 3}
	sql, args, err := b.ToSql()
	assert.NoError(t, err)

	expectedSql := "a = ? AND b = ? AND c = ?"
	assert.Equal(t, expectedSql, sql)

	expectedArgs := []interface{}{1, 2, 3}
	assert.Equal(t, expectedArgs, args)
}

func TestSqlLtOrder(t *testing.T) {
	b := Lt{"a": 1, "b": 2, "c": 3}
	sql, args, err := b.ToSql()
	assert.NoError(t, err)

	expectedSql := "a < ? AND b < ? AND c < ?"
	assert.Equal(t, expectedSql, sql)

	expectedArgs := []interface{}{1, 2, 3}
	assert.Equal(t, expectedArgs, args)
}

func TestExprEscaped(t *testing.T) {
	b := Expr("count(??)", Expr("x"))
	sql, args, err := b.ToSql()
	assert.NoError(t, err)

	expectedSql := "count(??)"
	assert.Equal(t, expectedSql, sql)

	expectedArgs := []interface{}{Expr("x")}
	assert.Equal(t, expectedArgs, args)
}

func TestExprRecursion(t *testing.T) {
	{
		b := Expr("count(?)", Expr("nullif(a,?)", "b"))
		sql, args, err := b.ToSql()
		assert.NoError(t, err)

		expectedSql := "count(nullif(a,?))"
		assert.Equal(t, expectedSql, sql)

		expectedArgs := []interface{}{"b"}
		assert.Equal(t, expectedArgs, args)
	}
	{
		b := Expr("extract(? from ?)", Expr("epoch"), "2001-02-03")
		sql, args, err := b.ToSql()
		assert.NoError(t, err)

		expectedSql := "extract(epoch from ?)"
		assert.Equal(t, expectedSql, sql)

		expectedArgs := []interface{}{"2001-02-03"}
		assert.Equal(t, expectedArgs, args)
	}
	{
		b := Expr("JOIN t1 ON ?", And{Eq{"id": 1}, Expr("NOT c1"), Expr("? @@ ?", "x", "y")})
		sql, args, err := b.ToSql()
		assert.NoError(t, err)

		expectedSql := "JOIN t1 ON (id = ? AND NOT c1 AND ? @@ ?)"
		assert.Equal(t, expectedSql, sql)

		expectedArgs := []interface{}{1, "x", "y"}
		assert.Equal(t, expectedArgs, args)
	}
}

// ExampleEq shows Eq used as a WHERE predicate. It has no "// Output:"
// comment, so go test compiles it but does not run it.
func ExampleEq() {
	Select("id", "created", "first_name").From("users").Where(Eq{
		"company": 20,
	})
}

--------------------------------------------------------------------------------
/select.go:
--------------------------------------------------------------------------------
package squirrel

import (
	"bytes"
	"database/sql"
	"fmt"
	"strings"

	"github.com/lann/builder"
)

// selectData holds the parts of a SELECT statement accumulated by
// SelectBuilder (see the builder.Register call in init).
type selectData struct {
	PlaceholderFormat PlaceholderFormat
	RunWith           BaseRunner
	Prefixes          []Sqlizer
	Options           []string
	Columns           []Sqlizer
	From              Sqlizer
	Joins             []Sqlizer
	WhereParts        []Sqlizer
	GroupBys          []string
	HavingParts       []Sqlizer
	OrderByParts      []Sqlizer
	Limit             string
	Offset            string
	Suffixes          []Sqlizer
}

// Exec runs the statement through d.RunWith, returning RunnerNotSet when no
// runner was configured.
func (d *selectData) Exec() (sql.Result, error) {
	if d.RunWith == nil {
		return nil, RunnerNotSet
	}
	return ExecWith(d.RunWith, d)
}

// Query runs the statement through d.RunWith, returning RunnerNotSet when no
// runner was configured.
func (d *selectData) Query() (*sql.Rows, error) {
	if d.RunWith == nil {
		return nil, RunnerNotSet
	}
	return QueryWith(d.RunWith, d)
}

// QueryRow runs the statement through d.RunWith, which must also implement
// QueryRower. Failures are reported through the returned Row's error
// (RunnerNotSet or RunnerNotQueryRunner) rather than a separate return value.
func (d *selectData) QueryRow() RowScanner {
	if d.RunWith == nil {
		return &Row{err: RunnerNotSet}
	}
	queryRower, ok := d.RunWith.(QueryRower)
	if !ok {
		return &Row{err: RunnerNotQueryRunner}
	}
	return QueryRowWith(queryRower, d)
}

// ToSql renders the statement and then rewrites its "?" placeholders using
// d.PlaceholderFormat.
func (d *selectData) ToSql() (sqlStr string, args []interface{}, err error) {
	sqlStr, args, err = d.toSqlRaw()
	if err != nil {
		return
	}

	sqlStr, err = d.PlaceholderFormat.ReplacePlaceholders(sqlStr)
	return
}

// toSqlRaw renders the statement without placeholder replacement. Clauses are
// emitted in this order: prefixes, SELECT options, columns, FROM, joins, WHERE
// (ANDed), GROUP BY, HAVING (ANDed), ORDER BY, LIMIT, OFFSET, suffixes. It
// fails when no result column was set.
func (d *selectData) toSqlRaw() (sqlStr string, args []interface{}, err error) {
	if len(d.Columns) == 0 {
		err = fmt.Errorf("select statements must have at least one result column")
		return
	}

	sql := &bytes.Buffer{}

	if len(d.Prefixes) > 0 {
		args, err = appendToSql(d.Prefixes, sql, " ", args)
		if err != nil {
			return
		}

		sql.WriteString(" ")
	}

	sql.WriteString("SELECT ")

	if len(d.Options) > 0 {
		sql.WriteString(strings.Join(d.Options, " "))
		sql.WriteString(" ")
	}

	// NOTE(review): always true here; the empty case already returned above.
	if len(d.Columns) > 0 {
		args, err = appendToSql(d.Columns, sql, ", ", args)
		if err != nil {
			return
		}
	}

	if d.From != nil {
		sql.WriteString(" FROM ")
		args, err = appendToSql([]Sqlizer{d.From}, sql, "", args)
		if err != nil {
			return
		}
	}

	if len(d.Joins) > 0 {
		sql.WriteString(" ")
		args, err = appendToSql(d.Joins, sql, " ", args)
		if err != nil {
			return
		}
	}

	if len(d.WhereParts) > 0 {
		sql.WriteString(" WHERE ")
		args, err = appendToSql(d.WhereParts, sql, " AND ", args)
		if err != nil {
			return
		}
	}

	if len(d.GroupBys) > 0 {
		sql.WriteString(" GROUP BY ")
		sql.WriteString(strings.Join(d.GroupBys, ", "))
	}

	if len(d.HavingParts) > 0 {
		sql.WriteString(" HAVING ")
		args, err = appendToSql(d.HavingParts, sql, " AND ", args)
		if err != nil {
			return
		}
	}

	if len(d.OrderByParts) > 0 {
		sql.WriteString(" ORDER BY ")
		args, err = appendToSql(d.OrderByParts, sql, ", ", args)
		if err != nil {
			return
		}
	}

	if len(d.Limit) > 0 {
		sql.WriteString(" LIMIT ")
		sql.WriteString(d.Limit)
	}

	if len(d.Offset) > 0 {
		sql.WriteString(" OFFSET ")
		sql.WriteString(d.Offset)
	}

	if len(d.Suffixes) > 0 {
		sql.WriteString(" ")

		args, err = appendToSql(d.Suffixes, sql, " ", args)
		if err != nil {
			return
		}
	}

	sqlStr = sql.String()
	return
}

// Builder

// SelectBuilder builds SQL SELECT statements.
type SelectBuilder builder.Builder

// init registers SelectBuilder with the builder package, backed by selectData.
func init() {
	builder.Register(SelectBuilder{}, selectData{})
}

// Format methods

// PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the
// query.
func (b SelectBuilder) PlaceholderFormat(f PlaceholderFormat) SelectBuilder {
	return builder.Set(b, "PlaceholderFormat", f).(SelectBuilder)
}

// Runner methods

// RunWith sets a Runner (like database/sql.DB) to be used with e.g. Exec.
// For most cases runner will be a database connection.
//
// Internally we use this to mock out the database connection for testing.
func (b SelectBuilder) RunWith(runner BaseRunner) SelectBuilder {
	return setRunWith(b, runner).(SelectBuilder)
}

// Exec builds and Execs the query with the Runner set by RunWith.
func (b SelectBuilder) Exec() (sql.Result, error) {
	data := builder.GetStruct(b).(selectData)
	return data.Exec()
}

// Query builds and Querys the query with the Runner set by RunWith.
func (b SelectBuilder) Query() (*sql.Rows, error) {
	data := builder.GetStruct(b).(selectData)
	return data.Query()
}

// QueryRow builds and QueryRows the query with the Runner set by RunWith.
func (b SelectBuilder) QueryRow() RowScanner {
	data := builder.GetStruct(b).(selectData)
	return data.QueryRow()
}

// Scan is a shortcut for QueryRow().Scan.
func (b SelectBuilder) Scan(dest ...interface{}) error {
	return b.QueryRow().Scan(dest...)
}

// SQL methods

// ToSql builds the query into a SQL string and bound args.
func (b SelectBuilder) ToSql() (string, []interface{}, error) {
	data := builder.GetStruct(b).(selectData)
	return data.ToSql()
}

// toSqlRaw builds the query with plain "?" placeholders, skipping the
// PlaceholderFormat replacement that ToSql applies.
func (b SelectBuilder) toSqlRaw() (string, []interface{}, error) {
	data := builder.GetStruct(b).(selectData)
	return data.toSqlRaw()
}

// MustSql builds the query into a SQL string and bound args.
// It panics if there are any errors.
func (b SelectBuilder) MustSql() (string, []interface{}) {
	sql, args, err := b.ToSql()
	if err != nil {
		panic(err)
	}
	return sql, args
}

// Prefix adds an expression to the beginning of the query
func (b SelectBuilder) Prefix(sql string, args ...interface{}) SelectBuilder {
	return b.PrefixExpr(Expr(sql, args...))
}

// PrefixExpr adds an expression to the very beginning of the query
func (b SelectBuilder) PrefixExpr(expr Sqlizer) SelectBuilder {
	return builder.Append(b, "Prefixes", expr).(SelectBuilder)
}

// Distinct adds a DISTINCT clause to the query.
func (b SelectBuilder) Distinct() SelectBuilder {
	return b.Options("DISTINCT")
}

// Options adds keywords (e.g. "DISTINCT", "SQL_NO_CACHE") that are written
// space-separated right after SELECT.
func (b SelectBuilder) Options(options ...string) SelectBuilder {
	return builder.Extend(b, "Options", options).(SelectBuilder)
}

// Columns adds result columns to the query.
func (b SelectBuilder) Columns(columns ...string) SelectBuilder {
	parts := make([]interface{}, 0, len(columns))
	for _, str := range columns {
		parts = append(parts, newPart(str))
	}
	return builder.Extend(b, "Columns", parts).(SelectBuilder)
}

// RemoveColumns removes all columns from the query.
266 | // Must add a new column with Column or Columns methods, otherwise 267 | // return a error. 268 | func (b SelectBuilder) RemoveColumns() SelectBuilder { 269 | return builder.Delete(b, "Columns").(SelectBuilder) 270 | } 271 | 272 | // Column adds a result column to the query. 273 | // Unlike Columns, Column accepts args which will be bound to placeholders in 274 | // the columns string, for example: 275 | // Column("IF(col IN ("+squirrel.Placeholders(3)+"), 1, 0) as col", 1, 2, 3) 276 | func (b SelectBuilder) Column(column interface{}, args ...interface{}) SelectBuilder { 277 | return builder.Append(b, "Columns", newPart(column, args...)).(SelectBuilder) 278 | } 279 | 280 | // From sets the FROM clause of the query. 281 | func (b SelectBuilder) From(from string) SelectBuilder { 282 | return builder.Set(b, "From", newPart(from)).(SelectBuilder) 283 | } 284 | 285 | // FromSelect sets a subquery into the FROM clause of the query. 286 | func (b SelectBuilder) FromSelect(from SelectBuilder, alias string) SelectBuilder { 287 | // Prevent misnumbered parameters in nested selects (#183). 288 | from = from.PlaceholderFormat(Question) 289 | return builder.Set(b, "From", Alias(from, alias)).(SelectBuilder) 290 | } 291 | 292 | // JoinClause adds a join clause to the query. 293 | func (b SelectBuilder) JoinClause(pred interface{}, args ...interface{}) SelectBuilder { 294 | return builder.Append(b, "Joins", newPart(pred, args...)).(SelectBuilder) 295 | } 296 | 297 | // Join adds a JOIN clause to the query. 298 | func (b SelectBuilder) Join(join string, rest ...interface{}) SelectBuilder { 299 | return b.JoinClause("JOIN "+join, rest...) 300 | } 301 | 302 | // LeftJoin adds a LEFT JOIN clause to the query. 303 | func (b SelectBuilder) LeftJoin(join string, rest ...interface{}) SelectBuilder { 304 | return b.JoinClause("LEFT JOIN "+join, rest...) 305 | } 306 | 307 | // RightJoin adds a RIGHT JOIN clause to the query. 
308 | func (b SelectBuilder) RightJoin(join string, rest ...interface{}) SelectBuilder { 309 | return b.JoinClause("RIGHT JOIN "+join, rest...) 310 | } 311 | 312 | // InnerJoin adds a INNER JOIN clause to the query. 313 | func (b SelectBuilder) InnerJoin(join string, rest ...interface{}) SelectBuilder { 314 | return b.JoinClause("INNER JOIN "+join, rest...) 315 | } 316 | 317 | // CrossJoin adds a CROSS JOIN clause to the query. 318 | func (b SelectBuilder) CrossJoin(join string, rest ...interface{}) SelectBuilder { 319 | return b.JoinClause("CROSS JOIN "+join, rest...) 320 | } 321 | 322 | // Where adds an expression to the WHERE clause of the query. 323 | // 324 | // Expressions are ANDed together in the generated SQL. 325 | // 326 | // Where accepts several types for its pred argument: 327 | // 328 | // nil OR "" - ignored. 329 | // 330 | // string - SQL expression. 331 | // If the expression has SQL placeholders then a set of arguments must be passed 332 | // as well, one for each placeholder. 333 | // 334 | // map[string]interface{} OR Eq - map of SQL expressions to values. Each key is 335 | // transformed into an expression like " = ?", with the corresponding value 336 | // bound to the placeholder. If the value is nil, the expression will be " 337 | // IS NULL". If the value is an array or slice, the expression will be " IN 338 | // (?,?,...)", with one placeholder for each item in the value. These expressions 339 | // are ANDed together. 340 | // 341 | // Where will panic if pred isn't any of the above types. 342 | func (b SelectBuilder) Where(pred interface{}, args ...interface{}) SelectBuilder { 343 | if pred == nil || pred == "" { 344 | return b 345 | } 346 | return builder.Append(b, "WhereParts", newWherePart(pred, args...)).(SelectBuilder) 347 | } 348 | 349 | // GroupBy adds GROUP BY expressions to the query. 
350 | func (b SelectBuilder) GroupBy(groupBys ...string) SelectBuilder { 351 | return builder.Extend(b, "GroupBys", groupBys).(SelectBuilder) 352 | } 353 | 354 | // Having adds an expression to the HAVING clause of the query. 355 | // 356 | // See Where. 357 | func (b SelectBuilder) Having(pred interface{}, rest ...interface{}) SelectBuilder { 358 | return builder.Append(b, "HavingParts", newWherePart(pred, rest...)).(SelectBuilder) 359 | } 360 | 361 | // OrderByClause adds ORDER BY clause to the query. 362 | func (b SelectBuilder) OrderByClause(pred interface{}, args ...interface{}) SelectBuilder { 363 | return builder.Append(b, "OrderByParts", newPart(pred, args...)).(SelectBuilder) 364 | } 365 | 366 | // OrderBy adds ORDER BY expressions to the query. 367 | func (b SelectBuilder) OrderBy(orderBys ...string) SelectBuilder { 368 | for _, orderBy := range orderBys { 369 | b = b.OrderByClause(orderBy) 370 | } 371 | 372 | return b 373 | } 374 | 375 | // Limit sets a LIMIT clause on the query. 376 | func (b SelectBuilder) Limit(limit uint64) SelectBuilder { 377 | return builder.Set(b, "Limit", fmt.Sprintf("%d", limit)).(SelectBuilder) 378 | } 379 | 380 | // Limit ALL allows to access all records with limit 381 | func (b SelectBuilder) RemoveLimit() SelectBuilder { 382 | return builder.Delete(b, "Limit").(SelectBuilder) 383 | } 384 | 385 | // Offset sets a OFFSET clause on the query. 386 | func (b SelectBuilder) Offset(offset uint64) SelectBuilder { 387 | return builder.Set(b, "Offset", fmt.Sprintf("%d", offset)).(SelectBuilder) 388 | } 389 | 390 | // RemoveOffset removes OFFSET clause. 
391 | func (b SelectBuilder) RemoveOffset() SelectBuilder { 392 | return builder.Delete(b, "Offset").(SelectBuilder) 393 | } 394 | 395 | // Suffix adds an expression to the end of the query 396 | func (b SelectBuilder) Suffix(sql string, args ...interface{}) SelectBuilder { 397 | return b.SuffixExpr(Expr(sql, args...)) 398 | } 399 | 400 | // SuffixExpr adds an expression to the end of the query 401 | func (b SelectBuilder) SuffixExpr(expr Sqlizer) SelectBuilder { 402 | return builder.Append(b, "Suffixes", expr).(SelectBuilder) 403 | } 404 | -------------------------------------------------------------------------------- /select_test.go: -------------------------------------------------------------------------------- 1 | package squirrel 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "log" 7 | "testing" 8 | "time" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestSelectBuilderToSql(t *testing.T) { 14 | subQ := Select("aa", "bb").From("dd") 15 | b := Select("a", "b"). 16 | Prefix("WITH prefix AS ?", 0). 17 | Distinct(). 18 | Columns("c"). 19 | Column("IF(d IN ("+Placeholders(3)+"), 1, 0) as stat_column", 1, 2, 3). 20 | Column(Expr("a > ?", 100)). 21 | Column(Alias(Eq{"b": []int{101, 102, 103}}, "b_alias")). 22 | Column(Alias(subQ, "subq")). 23 | From("e"). 24 | JoinClause("CROSS JOIN j1"). 25 | Join("j2"). 26 | LeftJoin("j3"). 27 | RightJoin("j4"). 28 | InnerJoin("j5"). 29 | CrossJoin("j6"). 30 | Where("f = ?", 4). 31 | Where(Eq{"g": 5}). 32 | Where(map[string]interface{}{"h": 6}). 33 | Where(Eq{"i": []int{7, 8, 9}}). 34 | Where(Or{Expr("j = ?", 10), And{Eq{"k": 11}, Expr("true")}}). 35 | GroupBy("l"). 36 | Having("m = n"). 37 | OrderByClause("? DESC", 1). 38 | OrderBy("o ASC", "p DESC"). 39 | Limit(12). 40 | Offset(13). 41 | Suffix("FETCH FIRST ? ROWS ONLY", 14) 42 | 43 | sql, args, err := b.ToSql() 44 | assert.NoError(t, err) 45 | 46 | expectedSql := 47 | "WITH prefix AS ? 
" + 48 | "SELECT DISTINCT a, b, c, IF(d IN (?,?,?), 1, 0) as stat_column, a > ?, " + 49 | "(b IN (?,?,?)) AS b_alias, " + 50 | "(SELECT aa, bb FROM dd) AS subq " + 51 | "FROM e " + 52 | "CROSS JOIN j1 JOIN j2 LEFT JOIN j3 RIGHT JOIN j4 INNER JOIN j5 CROSS JOIN j6 " + 53 | "WHERE f = ? AND g = ? AND h = ? AND i IN (?,?,?) AND (j = ? OR (k = ? AND true)) " + 54 | "GROUP BY l HAVING m = n ORDER BY ? DESC, o ASC, p DESC LIMIT 12 OFFSET 13 " + 55 | "FETCH FIRST ? ROWS ONLY" 56 | assert.Equal(t, expectedSql, sql) 57 | 58 | expectedArgs := []interface{}{0, 1, 2, 3, 100, 101, 102, 103, 4, 5, 6, 7, 8, 9, 10, 11, 1, 14} 59 | assert.Equal(t, expectedArgs, args) 60 | } 61 | 62 | func TestSelectBuilderFromSelect(t *testing.T) { 63 | subQ := Select("c").From("d").Where(Eq{"i": 0}) 64 | b := Select("a", "b").FromSelect(subQ, "subq") 65 | sql, args, err := b.ToSql() 66 | assert.NoError(t, err) 67 | 68 | expectedSql := "SELECT a, b FROM (SELECT c FROM d WHERE i = ?) AS subq" 69 | assert.Equal(t, expectedSql, sql) 70 | 71 | expectedArgs := []interface{}{0} 72 | assert.Equal(t, expectedArgs, args) 73 | } 74 | 75 | func TestSelectBuilderFromSelectNestedDollarPlaceholders(t *testing.T) { 76 | subQ := Select("c"). 77 | From("t"). 78 | Where(Gt{"c": 1}). 79 | PlaceholderFormat(Dollar) 80 | b := Select("c"). 81 | FromSelect(subQ, "subq"). 82 | Where(Lt{"c": 2}). 83 | PlaceholderFormat(Dollar) 84 | sql, args, err := b.ToSql() 85 | assert.NoError(t, err) 86 | 87 | expectedSql := "SELECT c FROM (SELECT c FROM t WHERE c > $1) AS subq WHERE c < $2" 88 | assert.Equal(t, expectedSql, sql) 89 | 90 | expectedArgs := []interface{}{1, 2} 91 | assert.Equal(t, expectedArgs, args) 92 | } 93 | 94 | func TestSelectBuilderToSqlErr(t *testing.T) { 95 | _, _, err := Select().From("x").ToSql() 96 | assert.Error(t, err) 97 | } 98 | 99 | func TestSelectBuilderPlaceholders(t *testing.T) { 100 | b := Select("test").Where("x = ? 
AND y = ?") 101 | 102 | sql, _, _ := b.PlaceholderFormat(Question).ToSql() 103 | assert.Equal(t, "SELECT test WHERE x = ? AND y = ?", sql) 104 | 105 | sql, _, _ = b.PlaceholderFormat(Dollar).ToSql() 106 | assert.Equal(t, "SELECT test WHERE x = $1 AND y = $2", sql) 107 | 108 | sql, _, _ = b.PlaceholderFormat(Colon).ToSql() 109 | assert.Equal(t, "SELECT test WHERE x = :1 AND y = :2", sql) 110 | 111 | sql, _, _ = b.PlaceholderFormat(AtP).ToSql() 112 | assert.Equal(t, "SELECT test WHERE x = @p1 AND y = @p2", sql) 113 | } 114 | 115 | func TestSelectBuilderRunners(t *testing.T) { 116 | db := &DBStub{} 117 | b := Select("test").RunWith(db) 118 | 119 | expectedSql := "SELECT test" 120 | 121 | b.Exec() 122 | assert.Equal(t, expectedSql, db.LastExecSql) 123 | 124 | b.Query() 125 | assert.Equal(t, expectedSql, db.LastQuerySql) 126 | 127 | b.QueryRow() 128 | assert.Equal(t, expectedSql, db.LastQueryRowSql) 129 | 130 | err := b.Scan() 131 | assert.NoError(t, err) 132 | } 133 | 134 | func TestSelectBuilderNoRunner(t *testing.T) { 135 | b := Select("test") 136 | 137 | _, err := b.Exec() 138 | assert.Equal(t, RunnerNotSet, err) 139 | 140 | _, err = b.Query() 141 | assert.Equal(t, RunnerNotSet, err) 142 | 143 | err = b.Scan() 144 | assert.Equal(t, RunnerNotSet, err) 145 | } 146 | 147 | func TestSelectBuilderSimpleJoin(t *testing.T) { 148 | 149 | expectedSql := "SELECT * FROM bar JOIN baz ON bar.foo = baz.foo" 150 | expectedArgs := []interface{}(nil) 151 | 152 | b := Select("*").From("bar").Join("baz ON bar.foo = baz.foo") 153 | 154 | sql, args, err := b.ToSql() 155 | assert.NoError(t, err) 156 | 157 | assert.Equal(t, expectedSql, sql) 158 | assert.Equal(t, args, expectedArgs) 159 | } 160 | 161 | func TestSelectBuilderParamJoin(t *testing.T) { 162 | 163 | expectedSql := "SELECT * FROM bar JOIN baz ON bar.foo = baz.foo AND baz.foo = ?" 
164 | expectedArgs := []interface{}{42} 165 | 166 | b := Select("*").From("bar").Join("baz ON bar.foo = baz.foo AND baz.foo = ?", 42) 167 | 168 | sql, args, err := b.ToSql() 169 | assert.NoError(t, err) 170 | 171 | assert.Equal(t, expectedSql, sql) 172 | assert.Equal(t, args, expectedArgs) 173 | } 174 | 175 | func TestSelectBuilderNestedSelectJoin(t *testing.T) { 176 | 177 | expectedSql := "SELECT * FROM bar JOIN ( SELECT * FROM baz WHERE foo = ? ) r ON bar.foo = r.foo" 178 | expectedArgs := []interface{}{42} 179 | 180 | nestedSelect := Select("*").From("baz").Where("foo = ?", 42) 181 | 182 | b := Select("*").From("bar").JoinClause(nestedSelect.Prefix("JOIN (").Suffix(") r ON bar.foo = r.foo")) 183 | 184 | sql, args, err := b.ToSql() 185 | assert.NoError(t, err) 186 | 187 | assert.Equal(t, expectedSql, sql) 188 | assert.Equal(t, args, expectedArgs) 189 | } 190 | 191 | func TestSelectWithOptions(t *testing.T) { 192 | sql, _, err := Select("*").From("foo").Distinct().Options("SQL_NO_CACHE").ToSql() 193 | 194 | assert.NoError(t, err) 195 | assert.Equal(t, "SELECT DISTINCT SQL_NO_CACHE * FROM foo", sql) 196 | } 197 | 198 | func TestSelectWithRemoveLimit(t *testing.T) { 199 | sql, _, err := Select("*").From("foo").Limit(10).RemoveLimit().ToSql() 200 | 201 | assert.NoError(t, err) 202 | assert.Equal(t, "SELECT * FROM foo", sql) 203 | } 204 | 205 | func TestSelectWithRemoveOffset(t *testing.T) { 206 | sql, _, err := Select("*").From("foo").Offset(10).RemoveOffset().ToSql() 207 | 208 | assert.NoError(t, err) 209 | assert.Equal(t, "SELECT * FROM foo", sql) 210 | } 211 | 212 | func TestSelectBuilderNestedSelectDollar(t *testing.T) { 213 | nestedBuilder := StatementBuilder.PlaceholderFormat(Dollar).Select("*").Prefix("NOT EXISTS ("). 214 | From("bar").Where("y = ?", 42).Suffix(")") 215 | outerSql, _, err := StatementBuilder.PlaceholderFormat(Dollar).Select("*"). 
216 | From("foo").Where("x = ?").Where(nestedBuilder).ToSql() 217 | 218 | assert.NoError(t, err) 219 | assert.Equal(t, "SELECT * FROM foo WHERE x = $1 AND NOT EXISTS ( SELECT * FROM bar WHERE y = $2 )", outerSql) 220 | } 221 | 222 | func TestSelectBuilderMustSql(t *testing.T) { 223 | defer func() { 224 | if r := recover(); r == nil { 225 | t.Errorf("TestSelectBuilderMustSql should have panicked!") 226 | } 227 | }() 228 | // This function should cause a panic 229 | Select().From("foo").MustSql() 230 | } 231 | 232 | func TestSelectWithoutWhereClause(t *testing.T) { 233 | sql, _, err := Select("*").From("users").ToSql() 234 | assert.NoError(t, err) 235 | assert.Equal(t, "SELECT * FROM users", sql) 236 | } 237 | 238 | func TestSelectWithNilWhereClause(t *testing.T) { 239 | sql, _, err := Select("*").From("users").Where(nil).ToSql() 240 | assert.NoError(t, err) 241 | assert.Equal(t, "SELECT * FROM users", sql) 242 | } 243 | 244 | func TestSelectWithEmptyStringWhereClause(t *testing.T) { 245 | sql, _, err := Select("*").From("users").Where("").ToSql() 246 | assert.NoError(t, err) 247 | assert.Equal(t, "SELECT * FROM users", sql) 248 | } 249 | 250 | func TestSelectSubqueryPlaceholderNumbering(t *testing.T) { 251 | subquery := Select("a").Where("b = ?", 1).PlaceholderFormat(Dollar) 252 | with := subquery.Prefix("WITH a AS (").Suffix(")") 253 | 254 | sql, args, err := Select("*"). 255 | PrefixExpr(with). 256 | FromSelect(subquery, "q"). 257 | Where("c = ?", 2). 258 | PlaceholderFormat(Dollar). 
259 | ToSql() 260 | assert.NoError(t, err) 261 | 262 | expectedSql := "WITH a AS ( SELECT a WHERE b = $1 ) SELECT * FROM (SELECT a WHERE b = $2) AS q WHERE c = $3" 263 | assert.Equal(t, expectedSql, sql) 264 | assert.Equal(t, []interface{}{1, 1, 2}, args) 265 | } 266 | 267 | func TestSelectSubqueryInConjunctionPlaceholderNumbering(t *testing.T) { 268 | subquery := Select("a").Where(Eq{"b": 1}).Prefix("EXISTS(").Suffix(")").PlaceholderFormat(Dollar) 269 | 270 | sql, args, err := Select("*"). 271 | Where(Or{subquery}). 272 | Where("c = ?", 2). 273 | PlaceholderFormat(Dollar). 274 | ToSql() 275 | assert.NoError(t, err) 276 | 277 | expectedSql := "SELECT * WHERE (EXISTS( SELECT a WHERE b = $1 )) AND c = $2" 278 | assert.Equal(t, expectedSql, sql) 279 | assert.Equal(t, []interface{}{1, 2}, args) 280 | } 281 | 282 | func TestSelectJoinClausePlaceholderNumbering(t *testing.T) { 283 | subquery := Select("a").Where(Eq{"b": 2}).PlaceholderFormat(Dollar) 284 | 285 | sql, args, err := Select("t1.a"). 286 | From("t1"). 287 | Where(Eq{"a": 1}). 288 | JoinClause(subquery.Prefix("JOIN (").Suffix(") t2 ON (t1.a = t2.a)")). 289 | PlaceholderFormat(Dollar). 290 | ToSql() 291 | assert.NoError(t, err) 292 | 293 | expectedSql := "SELECT t1.a FROM t1 JOIN ( SELECT a WHERE b = $1 ) t2 ON (t1.a = t2.a) WHERE a = $2" 294 | assert.Equal(t, expectedSql, sql) 295 | assert.Equal(t, []interface{}{2, 1}, args) 296 | } 297 | 298 | func ExampleSelect() { 299 | Select("id", "created", "first_name").From("users") // ... continue building up your query 300 | 301 | // sql methods in select columns are ok 302 | Select("first_name", "count(*)").From("users") 303 | 304 | // column aliases are ok too 305 | Select("first_name", "count(*) as n_users").From("users") 306 | } 307 | 308 | func ExampleSelectBuilder_From() { 309 | Select("id", "created", "first_name").From("users") // ... 
continue building up your query 310 | } 311 | 312 | func ExampleSelectBuilder_Where() { 313 | companyId := 20 314 | Select("id", "created", "first_name").From("users").Where("company = ?", companyId) 315 | } 316 | 317 | func ExampleSelectBuilder_Where_helpers() { 318 | companyId := 20 319 | 320 | Select("id", "created", "first_name").From("users").Where(Eq{ 321 | "company": companyId, 322 | }) 323 | 324 | Select("id", "created", "first_name").From("users").Where(GtOrEq{ 325 | "created": time.Now().AddDate(0, 0, -7), 326 | }) 327 | 328 | Select("id", "created", "first_name").From("users").Where(And{ 329 | GtOrEq{ 330 | "created": time.Now().AddDate(0, 0, -7), 331 | }, 332 | Eq{ 333 | "company": companyId, 334 | }, 335 | }) 336 | } 337 | 338 | func ExampleSelectBuilder_Where_multiple() { 339 | companyId := 20 340 | 341 | // multiple where's are ok 342 | 343 | Select("id", "created", "first_name"). 344 | From("users"). 345 | Where("company = ?", companyId). 346 | Where(GtOrEq{ 347 | "created": time.Now().AddDate(0, 0, -7), 348 | }) 349 | } 350 | 351 | func ExampleSelectBuilder_FromSelect() { 352 | usersByCompany := Select("company", "count(*) as n_users").From("users").GroupBy("company") 353 | query := Select("company.id", "company.name", "users_by_company.n_users"). 354 | FromSelect(usersByCompany, "users_by_company"). 
355 | Join("company on company.id = users_by_company.company") 356 | 357 | sql, _, _ := query.ToSql() 358 | fmt.Println(sql) 359 | 360 | // Output: SELECT company.id, company.name, users_by_company.n_users FROM (SELECT company, count(*) as n_users FROM users GROUP BY company) AS users_by_company JOIN company on company.id = users_by_company.company 361 | } 362 | 363 | func ExampleSelectBuilder_Columns() { 364 | query := Select("id").Columns("created", "first_name").From("users") 365 | 366 | sql, _, _ := query.ToSql() 367 | fmt.Println(sql) 368 | // Output: SELECT id, created, first_name FROM users 369 | } 370 | 371 | func ExampleSelectBuilder_Columns_order() { 372 | // out of order is ok too 373 | query := Select("id").Columns("created").From("users").Columns("first_name") 374 | 375 | sql, _, _ := query.ToSql() 376 | fmt.Println(sql) 377 | // Output: SELECT id, created, first_name FROM users 378 | } 379 | 380 | func ExampleSelectBuilder_Scan() { 381 | 382 | var db *sql.DB 383 | 384 | query := Select("id", "created", "first_name").From("users") 385 | query = query.RunWith(db) 386 | 387 | var id int 388 | var created time.Time 389 | var firstName string 390 | 391 | if err := query.Scan(&id, &created, &firstName); err != nil { 392 | log.Println(err) 393 | return 394 | } 395 | } 396 | 397 | func ExampleSelectBuilder_ScanContext() { 398 | 399 | var db *sql.DB 400 | 401 | query := Select("id", "created", "first_name").From("users") 402 | query = query.RunWith(db) 403 | 404 | var id int 405 | var created time.Time 406 | var firstName string 407 | 408 | if err := query.ScanContext(ctx, &id, &created, &firstName); err != nil { 409 | log.Println(err) 410 | return 411 | } 412 | } 413 | 414 | func ExampleSelectBuilder_RunWith() { 415 | 416 | var db *sql.DB 417 | 418 | query := Select("id", "created", "first_name").From("users").RunWith(db) 419 | 420 | var id int 421 | var created time.Time 422 | var firstName string 423 | 424 | if err := query.Scan(&id, &created, &firstName); 
err != nil { 425 | log.Println(err) 426 | return 427 | } 428 | } 429 | 430 | func ExampleSelectBuilder_ToSql() { 431 | 432 | var db *sql.DB 433 | 434 | query := Select("id", "created", "first_name").From("users") 435 | 436 | sql, args, err := query.ToSql() 437 | if err != nil { 438 | log.Println(err) 439 | return 440 | } 441 | 442 | rows, err := db.Query(sql, args...) 443 | if err != nil { 444 | log.Println(err) 445 | return 446 | } 447 | 448 | defer rows.Close() 449 | 450 | for rows.Next() { 451 | // scan... 452 | } 453 | } 454 | 455 | func TestRemoveColumns(t *testing.T) { 456 | query := Select("id"). 457 | From("users"). 458 | RemoveColumns() 459 | query = query.Columns("name") 460 | sql, _, err := query.ToSql() 461 | assert.NoError(t, err) 462 | assert.Equal(t, "SELECT name FROM users", sql) 463 | } 464 | --------------------------------------------------------------------------------