├── .github
│   ├── mddocs
│   └── workflows
│       ├── neocities.yml
│       └── tests.yml
├── .gitignore
├── LICENSE
├── README.md
├── START_HERE.md
├── builtins.go
├── builtins_test.go
├── colors.go
├── cte.go
├── cte_test.go
├── delete_query.go
├── delete_query_test.go
├── fetch_exec.go
├── fetch_exec_test.go
├── fields.go
├── fields_test.go
├── fmt.go
├── fmt_test.go
├── go.mod
├── go.sum
├── header.png
├── insert_query.go
├── insert_query_test.go
├── integration_test.go
├── internal
│   ├── googleuuid
│   │   └── googleuuid.go
│   ├── pqarray
│   │   └── pqarray.go
│   └── testutil
│       └── testutil.go
├── joins.go
├── joins_test.go
├── logger.go
├── logger_test.go
├── misc.go
├── misc_test.go
├── row_column.go
├── select_query.go
├── select_query_test.go
├── sq.go
├── sq.md
├── sq_test.go
├── update_query.go
├── update_query_test.go
├── window.go
└── window_test.go
/.github/mddocs:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bokwoon95/sq/eae3b0c03361f5b98ac6d0701c6aa71c94d4e4c2/.github/mddocs
--------------------------------------------------------------------------------
/.github/workflows/neocities.yml:
--------------------------------------------------------------------------------
1 | name: Deploy docs to Neocities
2 | on:
3 | push:
4 | branches: [main]
5 | jobs:
6 | deploy_to_neocities:
7 | runs-on: ubuntu-latest
8 | steps:
9 | - name: Clone repo
10 | uses: actions/checkout@v3
11 | - run: mkdir public && .github/mddocs sq.md public/sq.html
12 | - name: Deploy to neocities
13 | uses: bcomnes/deploy-to-neocities@v1
14 | with:
15 | api_token: ${{ secrets.NEOCITIES_API_KEY }}
16 | dist_dir: public
17 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: tests
2 | on:
3 | push:
4 | branches: [main]
5 | pull_request:
6 | branches: [main]
7 | jobs:
8 | run_sq_tests:
9 | runs-on: ubuntu-latest
10 | services:
11 | postgres:
12 | image: postgres
13 | env:
14 | POSTGRES_USER: 'user1'
15 | POSTGRES_PASSWORD: 'Hunter2!'
16 | POSTGRES_DB: 'sakila'
17 | options: >-
18 | --health-cmd pg_isready
19 | --health-interval 10s
20 | --health-timeout 5s
21 | --health-retries 5
22 | ports:
23 | - '5456:5432'
24 | mysql:
25 | image: mysql
26 | env:
27 | MYSQL_ROOT_PASSWORD: 'Hunter2!'
28 | MYSQL_USER: 'user1'
29 | MYSQL_PASSWORD: 'Hunter2!'
30 | MYSQL_DATABASE: 'sakila'
31 | options: >-
32 | --health-cmd "mysqladmin ping"
33 | --health-interval 10s
34 | --health-timeout 5s
35 | --health-retries 5
36 | --health-start-period 30s
37 | ports:
38 | - '3330:3306'
39 | sqlserver:
40 | image: 'mcr.microsoft.com/azure-sql-edge'
41 | env:
42 | ACCEPT_EULA: 'Y'
43 | MSSQL_SA_PASSWORD: 'Hunter2!'
44 | options: >-
45 | --health-cmd "/opt/mssql-tools/bin/sqlcmd -S localhost -U sa -P Hunter2! -Q 'select 1' -b -o /dev/null"
46 | --health-interval 10s
47 | --health-timeout 5s
48 | --health-retries 5
49 | --health-start-period 30s
50 | ports:
51 | - '1447:1433'
52 | steps:
53 | - name: Install go
54 | uses: actions/setup-go@v3
55 | with:
56 | go-version: '>=1.18.0'
57 | - name: Clone repo
58 | uses: actions/checkout@v3
59 | - run: go test . -tags=fts5 -failfast -shuffle on -coverprofile coverage -race -postgres 'postgres://user1:Hunter2!@localhost:5456/sakila?sslmode=disable' -mysql 'root:Hunter2!@tcp(localhost:3330)/sakila?multiStatements=true&parseTime=true' -sqlserver 'sqlserver://sa:Hunter2!@localhost:1447'
60 | - name: Convert coverage to coverage.lcov
61 | uses: jandelgado/gcov2lcov-action@v1.0.0
62 | with:
63 | infile: coverage
64 | outfile: coverage.lcov
65 | - name: Upload coverage.lcov to Coveralls
66 | uses: coverallsapp/github-action@master
67 | with:
68 | github-token: ${{ secrets.GITHUB_TOKEN }}
69 | path-to-lcov: coverage.lcov
70 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.sqlite*
2 | .idea
3 | coverage.out
4 | coverage
5 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Chua Bok Woon
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [](https://pkg.go.dev/github.com/bokwoon95/sq)
2 | 
3 | [](https://goreportcard.com/report/github.com/bokwoon95/sq)
4 | [](https://coveralls.io/github/bokwoon95/sq?branch=main)
5 |
6 |
7 |
8 | # sq (Structured Query)
9 |
10 | [one-page documentation](https://bokwoon.neocities.org/sq.html)
11 |
12 | sq is a type-safe data mapper and query builder for Go. Its concept is simple: you provide a callback function that maps a row to a struct, and generics ensure that you get back a slice of structs at the end. Additionally, mentioning a column in the callback function automatically adds it to the SELECT clause so you don't even have to explicitly mention what columns you want to select: the [act of mapping a column is the same as selecting it](#select-example-raw-sql). This eliminates a source of errors where you have to specify the columns twice (once in the query itself, once in the call to rows.Scan) and end up missing a column, getting the column order wrong, or mistyping a column name.
13 |
14 | Notable features:
15 |
16 | - Works across SQLite, Postgres, MySQL and SQL Server. [[more info](https://bokwoon.neocities.org/sq.html#set-query-dialect)]
17 | - Each dialect has its own query builder, allowing you to use dialect-specific features. [[more info](https://bokwoon.neocities.org/sq.html#dialect-specific-features)]
18 | - Declarative schema migrations. [[more info](https://bokwoon.neocities.org/sq.html#declarative-schema)]
19 | - Supports arrays, enums, JSON and UUID. [[more info](https://bokwoon.neocities.org/sq.html#arrays-enums-json-uuid)]
20 | - Query logging. [[more info](https://bokwoon.neocities.org/sq.html#logging)]
21 |
22 | # Installation
23 |
24 | This package only supports Go 1.19 and above.
25 |
26 | ```shell
27 | $ go get github.com/bokwoon95/sq
28 | $ go install -tags=fts5 github.com/bokwoon95/sqddl@latest
29 | ```
30 |
31 | # Features
32 |
33 | - IN
34 | - [In Slice](https://bokwoon.neocities.org/sq.html#in-slice) - `a IN (1, 2, 3)`
35 | - [In RowValues](https://bokwoon.neocities.org/sq.html#in-rowvalues) - `(a, b, c) IN ((1, 2, 3), (4, 5, 6), (7, 8, 9))`
36 | - [In Subquery](https://bokwoon.neocities.org/sq.html#in-subquery) - `(a, b) IN (SELECT a, b FROM tbl WHERE condition)`
37 | - CASE
38 | - [Predicate Case](https://bokwoon.neocities.org/sq.html#predicate-case) - `CASE WHEN a THEN b WHEN c THEN d ELSE e END`
39 | - [Simple case](https://bokwoon.neocities.org/sq.html#simple-case) - `CASE expr WHEN a THEN b WHEN c THEN d ELSE e END`
40 | - EXISTS
41 | - [Where Exists](https://bokwoon.neocities.org/sq.html#where-exists)
42 | - [Where Not Exists](https://bokwoon.neocities.org/sq.html#where-not-exists)
43 | - [Select Exists](https://bokwoon.neocities.org/sq.html#querybuilder-fetch-exists)
44 | - [Subqueries](https://bokwoon.neocities.org/sq.html#subqueries)
45 | - [WITH (Common Table Expressions)](https://bokwoon.neocities.org/sq.html#common-table-expressions)
46 | - [Aggregate functions](https://bokwoon.neocities.org/sq.html#aggregate-functions)
47 | - [Window functions](https://bokwoon.neocities.org/sq.html#window-functions)
48 | - [UNION, INTERSECT, EXCEPT](https://bokwoon.neocities.org/sq.html#union-intersect-except)
49 | - [INSERT from SELECT](https://bokwoon.neocities.org/sq.html#querybuilder-insert-from-select)
50 | - RETURNING
51 | - [SQLite RETURNING](https://bokwoon.neocities.org/sq.html#sqlite-returning)
52 | - [Postgres RETURNING](https://bokwoon.neocities.org/sq.html#postgres-returning)
53 | - LastInsertId
54 | - [SQLite LastInsertId](https://bokwoon.neocities.org/sq.html#sqlite-last-insert-id)
55 | - [MySQL LastInsertId](https://bokwoon.neocities.org/sq.html#mysql-last-insert-id)
56 | - Insert ignore duplicates
57 | - [SQLite Insert ignore duplicates](https://bokwoon.neocities.org/sq.html#sqlite-insert-ignore-duplicates)
58 | - [Postgres Insert ignore duplicates](https://bokwoon.neocities.org/sq.html#postgres-insert-ignore-duplicates)
59 | - [MySQL Insert ignore duplicates](https://bokwoon.neocities.org/sq.html#mysql-insert-ignore-duplicates)
60 | - [SQL Server Insert ignore duplicates](https://bokwoon.neocities.org/sq.html#sqlserver-insert-ignore-duplicates)
61 | - Upsert
62 | - [SQLite Upsert](https://bokwoon.neocities.org/sq.html#sqlite-upsert)
63 | - [Postgres Upsert](https://bokwoon.neocities.org/sq.html#postgres-upsert)
64 | - [MySQL Upsert](https://bokwoon.neocities.org/sq.html#mysql-upsert)
65 | - [SQL Server Upsert](https://bokwoon.neocities.org/sq.html#sqlserver-upsert)
66 | - Update with Join
67 | - [SQLite Update with Join](https://bokwoon.neocities.org/sq.html#sqlite-update-with-join)
68 | - [Postgres Update with Join](https://bokwoon.neocities.org/sq.html#postgres-update-with-join)
69 | - [MySQL Update with Join](https://bokwoon.neocities.org/sq.html#mysql-update-with-join)
70 | - [SQL Server Update with Join](https://bokwoon.neocities.org/sq.html#sqlserver-update-with-join)
71 | - Delete with Join
72 | - [SQLite Delete with Join](https://bokwoon.neocities.org/sq.html#sqlite-delete-with-join)
73 | - [Postgres Delete with Join](https://bokwoon.neocities.org/sq.html#postgres-delete-with-join)
74 | - [MySQL Delete with Join](https://bokwoon.neocities.org/sq.html#mysql-delete-with-join)
75 | - [SQL Server Delete with Join](https://bokwoon.neocities.org/sq.html#sqlserver-delete-with-join)
76 | - Bulk Update
77 | - [SQLite Bulk Update](https://bokwoon.neocities.org/sq.html#sqlite-bulk-update)
78 | - [Postgres Bulk Update](https://bokwoon.neocities.org/sq.html#postgres-bulk-update)
79 | - [MySQL Bulk Update](https://bokwoon.neocities.org/sq.html#mysql-bulk-update)
80 | - [SQL Server Bulk Update](https://bokwoon.neocities.org/sq.html#sqlserver-bulk-update)
81 |
82 | ## SELECT example (Raw SQL)
83 |
84 | ```go
85 | db, err := sql.Open("postgres", "postgres://username:password@localhost:5432/sakila?sslmode=disable")
86 |
87 | actors, err := sq.FetchAll(db, sq.
88 | Queryf("SELECT {*} FROM actor AS a WHERE a.actor_id IN ({})",
89 | []int{1, 2, 3, 4, 5},
90 | ).
91 | SetDialect(sq.DialectPostgres),
92 | func(row *sq.Row) Actor {
93 | return Actor{
94 | ActorID: row.Int("a.actor_id"),
95 | FirstName: row.String("a.first_name"),
96 | LastName: row.String("a.last_name"),
97 | LastUpdate: row.Time("a.last_update"),
98 | }
99 | },
100 | )
101 | ```
102 |
103 | ## SELECT example (Query Builder)
104 |
105 | To use the query builder, you must first [define your table structs](https://bokwoon.neocities.org/sq.html#table-structs).
106 |
107 | ```go
108 | type ACTOR struct {
109 | sq.TableStruct
110 | ACTOR_ID sq.NumberField
111 | FIRST_NAME sq.StringField
112 | LAST_NAME sq.StringField
113 | LAST_UPDATE sq.TimeField
114 | }
115 |
116 | db, err := sql.Open("postgres", "postgres://username:password@localhost:5432/sakila?sslmode=disable")
117 |
118 | a := sq.New[ACTOR]("a")
119 | actors, err := sq.FetchAll(db, sq.
120 | From(a).
121 | Where(a.ACTOR_ID.In([]int{1, 2, 3, 4, 5})).
122 | SetDialect(sq.DialectPostgres),
123 | func(row *sq.Row) Actor {
124 | return Actor{
125 | ActorID: row.IntField(a.ACTOR_ID),
126 | FirstName: row.StringField(a.FIRST_NAME),
127 | LastName: row.StringField(a.LAST_NAME),
128 | LastUpdate: row.TimeField(a.LAST_UPDATE),
129 | }
130 | },
131 | )
132 | ```
133 |
134 | ## INSERT example (Raw SQL)
135 |
136 | ```go
137 | db, err := sql.Open("postgres", "postgres://username:password@localhost:5432/sakila?sslmode=disable")
138 |
139 | _, err = sq.Exec(db, sq.
140 | Queryf("INSERT INTO actor (actor_id, first_name, last_name) VALUES {}", sq.RowValues{
141 | {18, "DAN", "TORN"},
142 | {56, "DAN", "HARRIS"},
143 | {166, "DAN", "STREEP"},
144 | }).
145 | SetDialect(sq.DialectPostgres),
146 | )
147 | ```
148 |
149 | ## INSERT example (Query Builder)
150 |
151 | To use the query builder, you must first [define your table structs](https://bokwoon.neocities.org/sq.html#table-structs).
152 |
153 | ```go
154 | type ACTOR struct {
155 | sq.TableStruct
156 | ACTOR_ID sq.NumberField
157 | FIRST_NAME sq.StringField
158 | LAST_NAME sq.StringField
159 | LAST_UPDATE sq.TimeField
160 | }
161 |
162 | db, err := sql.Open("postgres", "postgres://username:password@localhost:5432/sakila?sslmode=disable")
163 |
164 | a := sq.New[ACTOR]("a")
165 | _, err = sq.Exec(db, sq.
166 | InsertInto(a).
167 | Columns(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME).
168 | Values(18, "DAN", "TORN").
169 | Values(56, "DAN", "HARRIS").
170 | Values(166, "DAN", "STREEP").
171 | SetDialect(sq.DialectPostgres),
172 | )
173 | ```
174 |
175 | For a more detailed overview, look at the [Quickstart](https://bokwoon.neocities.org/sq.html#quickstart).
176 |
177 | ## Project Status
178 |
179 | sq is done for my use case (hence it may seem inactive, but it's just complete). At this point I'm just waiting for people to ask questions or file feature requests under [discussions](https://github.com/bokwoon95/sq/discussions).
180 |
181 | ## Contributing
182 |
183 | See [START\_HERE.md](https://github.com/bokwoon95/sq/blob/main/START_HERE.md).
184 |
--------------------------------------------------------------------------------
/START_HERE.md:
--------------------------------------------------------------------------------
1 | This document describes how the codebase is organized. It is meant for people who are contributing to the codebase (or are just casually browsing).
2 |
3 | Files are written in such a way that **each successive file in the list below only depends on files that come before it**. This self-enforced restriction makes deep architectural changes trivial because you can essentially blow away the entire codebase and rewrite it from scratch file-by-file, complete with working tests every step of the way. Please adhere to this file order when submitting pull requests.
4 |
5 | - [**sq.go**](https://github.com/bokwoon95/sq/blob/main/sq.go)
6 | - Core interfaces: SQLWriter, DB, Query, Table, PolicyTable, Window, Field, Predicate, Assignment, Any, Array, Binary, Boolean, Enum, JSON, Number, String, UUID, Time, Enumeration, DialectValuer.
7 | - Data types: Result, TableStruct, ViewStruct.
8 | - Misc utility functions.
9 | - [**fmt.go**](https://github.com/bokwoon95/sq/blob/main/fmt.go)
10 | - Two important string building functions that everything else is built on: [Writef](https://pkg.go.dev/github.com/bokwoon95/sq#Writef) and [WriteValue](https://pkg.go.dev/github.com/bokwoon95/sq#WriteValue).
11 | - Data types: Parameter, BinaryParameter, BooleanParameter, NumberParameter, StringParameter, TimeParameter.
12 | - Utility functions: QuoteIdentifier, EscapeQuote, Sprintf, Sprint.
13 | - [**builtins.go**](https://github.com/bokwoon95/sq/blob/main/builtins.go)
14 | - Builtin data types that are built on top of Writef and WriteValue: Expression (Expr), CustomQuery (Queryf), VariadicPredicate, assignment, RowValue, RowValues, Fields.
15 | - Builtin functions that are built on top of Writef and WriteValue: Eq, Ne, Lt, Le, Gt, Ge, Exists, NotExists, In.
16 | - [**fields.go**](https://github.com/bokwoon95/sq/blob/main/fields.go)
17 | - All of the field types: AnyField, ArrayField, BinaryField, BooleanField, EnumField, JSONField, NumberField, StringField, UUIDField, TimeField.
18 | - Data types: Identifier, Timestamp.
19 | - Functions: [New](https://pkg.go.dev/github.com/bokwoon95/sq#New), ArrayValue, EnumValue, JSONValue, UUIDValue.
20 | - [**cte.go**](https://github.com/bokwoon95/sq/blob/main/cte.go)
21 | - CTE represents an SQL common table expression (CTE).
22 | - UNION, INTERSECT, EXCEPT.
23 | - [**joins.go**](https://github.com/bokwoon95/sq/blob/main/joins.go)
24 | - The various SQL joins.
25 | - [**row_column.go**](https://github.com/bokwoon95/sq/blob/main/row_column.go)
26 | - Row and Column methods.
27 | - [**window.go**](https://github.com/bokwoon95/sq/blob/main/window.go)
28 | - SQL windows and window functions.
29 | - [**select_query.go**](https://github.com/bokwoon95/sq/blob/main/select_query.go)
30 | - SQL SELECT query builder.
31 | - [**insert_query.go**](https://github.com/bokwoon95/sq/blob/main/insert_query.go)
32 | - SQL INSERT query builder.
33 | - [**update_query.go**](https://github.com/bokwoon95/sq/blob/main/update_query.go)
34 | - SQL UPDATE query builder.
35 | - [**delete_query.go**](https://github.com/bokwoon95/sq/blob/main/delete_query.go)
36 | - SQL DELETE query builder.
37 | - [**logger.go**](https://github.com/bokwoon95/sq/blob/main/logger.go)
38 | - sq.Log and sq.VerboseLog.
39 | - [**fetch_exec.go**](https://github.com/bokwoon95/sq/blob/main/fetch_exec.go)
40 | - FetchCursor, FetchOne, FetchAll, Exec.
41 | - CompiledFetch, CompiledExec.
42 | - PreparedFetch, PreparedExec.
43 | - [**misc.go**](https://github.com/bokwoon95/sq/blob/main/misc.go)
44 | - Misc SQL constructs.
45 | - ValueExpression, LiteralValue, DialectExpression, CaseExpression, SimpleCaseExpression.
46 | - SelectValues (`SELECT ... UNION ALL SELECT ... UNION ALL SELECT ...`)
47 | - TableValues (`VALUES (...), (...), (...)`).
48 | - [**integration_test.go**](https://github.com/bokwoon95/sq/blob/main/integration_test.go)
49 | - Tests that interact with a live database i.e. SQLite, Postgres, MySQL and SQL Server.
50 |
51 | ## Testing
52 |
53 | Add tests if you add code.
54 |
55 | To run tests, use:
56 |
57 | ```shell
58 | $ go test . # -failfast -shuffle=on -coverprofile=coverage
59 | ```
60 |
61 | There are tests that require a live database connection. They will only run if you provide the corresponding database URL in the test flags:
62 |
63 | ```shell
64 | $ go test . -postgres $POSTGRES_URL -mysql $MYSQL_URL -sqlserver $SQLSERVER_URL # -failfast -shuffle=on -coverprofile=coverage
65 | ```
66 |
67 | You can consider using the [docker-compose.yml defined in the sqddl repo](https://github.com/bokwoon95/sqddl/blob/main/docker-compose.yml) to spin up Postgres, MySQL and SQL Server databases that are reachable at the following URLs:
68 |
69 | ```shell
70 | # docker-compose up -d
71 | POSTGRES_URL='postgres://user1:Hunter2!@localhost:5456/sakila?sslmode=disable'
72 | MYSQL_URL='root:Hunter2!@tcp(localhost:3330)/sakila?multiStatements=true&parseTime=true'
73 | MARIADB_URL='root:Hunter2!@tcp(localhost:3340)/sakila?multiStatements=true&parseTime=true'
74 | SQLSERVER_URL='sqlserver://sa:Hunter2!@localhost:1447'
75 | ```
76 |
77 | ## Documentation
78 |
79 | Documentation is contained entirely within [sq.md](https://github.com/bokwoon95/sq/blob/main/sq.md) in the project root directory. You can view the output at [https://bokwoon.neocities.org/sq.html](https://bokwoon.neocities.org/sq.html). The documentation is regenerated every time a new commit is pushed to the main branch, so to change the documentation just change sq.md and submit a pull request.
80 |
81 | You can preview the output of sq.md locally by installing [github.com/bokwoon95/mddocs](https://github.com/bokwoon95/mddocs) and running it with sq.md as the argument.
82 |
83 | ```shell
84 | $ go install github.com/bokwoon95/mddocs@latest
85 | $ mddocs
86 | Usage:
87 | mddocs project.md # serves project.md on a localhost connection
88 | mddocs project.md project.html # render project.md into project.html
89 |
90 | $ mddocs sq.md
91 | serving sq.md at localhost:6060
92 | ```
93 |
94 | To add a new section and register it in the table of contents, append a `#headerID` to the end of a header (replace `headerID` with the actual header ID). The header ID should only contain unicode letters, digits, hyphen `-` and underscore `_`.
95 |
96 | ```text
97 | ## This is a header.
98 |
99 | ## This is a header with a headerID. #header-id <-- added to table of contents
100 | ```
101 |
--------------------------------------------------------------------------------
/builtins.go:
--------------------------------------------------------------------------------
1 | package sq
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "fmt"
7 | "strings"
8 | )
9 |
10 | // Expression is an SQL expression that satisfies the Table, Field, Predicate,
11 | // Binary, Boolean, Number, String and Time interfaces.
12 | type Expression struct {
13 | format string
14 | values []any
15 | alias string
16 | }
17 |
18 | var _ interface {
19 | Table
20 | Field
21 | Predicate
22 | Any
23 | Assignment
24 | } = (*Expression)(nil)
25 |
26 | // Expr creates a new Expression using Writef syntax.
27 | func Expr(format string, values ...any) Expression {
28 | return Expression{format: format, values: values}
29 | }
30 |
31 | // WriteSQL implements the SQLWriter interface.
32 | func (expr Expression) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
33 | err := Writef(ctx, dialect, buf, args, params, expr.format, expr.values)
34 | if err != nil {
35 | return err
36 | }
37 | return nil
38 | }
39 |
40 | // As returns a new Expression with the given alias.
41 | func (expr Expression) As(alias string) Expression {
42 | expr.alias = alias
43 | return expr
44 | }
45 |
46 | // In returns an 'expr IN (value)' Predicate.
47 | func (expr Expression) In(value any) Predicate { return In(expr, value) }
48 |
49 | // In returns an 'expr NOT IN (value)' Predicate.
50 | func (expr Expression) NotIn(value any) Predicate { return NotIn(expr, value) }
51 |
52 | // Eq returns an 'expr = value' Predicate.
53 | func (expr Expression) Eq(value any) Predicate { return cmp("=", expr, value) }
54 |
55 | // Ne returns an 'expr <> value' Predicate.
56 | func (expr Expression) Ne(value any) Predicate { return cmp("<>", expr, value) }
57 |
58 | // Lt returns an 'expr < value' Predicate.
59 | func (expr Expression) Lt(value any) Predicate { return cmp("<", expr, value) }
60 |
61 | // Le returns an 'expr <= value' Predicate.
62 | func (expr Expression) Le(value any) Predicate { return cmp("<=", expr, value) }
63 |
64 | // Gt returns an 'expr > value' Predicate.
65 | func (expr Expression) Gt(value any) Predicate { return cmp(">", expr, value) }
66 |
67 | // Ge returns an 'expr >= value' Predicate.
68 | func (expr Expression) Ge(value any) Predicate { return cmp(">=", expr, value) }
69 |
70 | // GetAlias returns the alias of the Expression.
71 | func (expr Expression) GetAlias() string { return expr.alias }
72 |
73 | // IsTable implements the Table interface.
74 | func (expr Expression) IsTable() {}
75 |
76 | // IsField implements the Field interface.
77 | func (expr Expression) IsField() {}
78 |
79 | // IsArray implements the Array interface.
80 | func (expr Expression) IsArray() {}
81 |
82 | // IsBinary implements the Binary interface.
83 | func (expr Expression) IsBinary() {}
84 |
85 | // IsBoolean implements the Boolean interface.
86 | func (expr Expression) IsBoolean() {}
87 |
88 | // IsEnum implements the Enum interface.
89 | func (expr Expression) IsEnum() {}
90 |
91 | // IsJSON implements the JSON interface.
92 | func (expr Expression) IsJSON() {}
93 |
94 | // IsNumber implements the Number interface.
95 | func (expr Expression) IsNumber() {}
96 |
97 | // IsString implements the String interface.
98 | func (expr Expression) IsString() {}
99 |
100 | // IsTime implements the Time interface.
101 | func (expr Expression) IsTime() {}
102 |
103 | // IsUUID implements the UUID interface.
104 | func (expr Expression) IsUUID() {}
105 |
106 | func (e Expression) IsAssignment() {}
107 |
// CustomQuery represents a user-defined query, built from a Writef-style
// format string and its values (see Queryf).
//
// If Format contains a {*} placeholder, WriteSQL writes the fields set by
// SetFetchableFields in its place.
type CustomQuery struct {
	// Dialect is the SQL dialect of the query, as returned by GetDialect.
	Dialect string
	// Format is the Writef format string.
	Format string
	// Values are the values substituted into Format.
	Values []any
	// fields are the fetchable fields written in place of {*}.
	fields []Field
}

// Compile-time check that CustomQuery implements the Query interface.
var _ Query = (*CustomQuery)(nil)
117 |
118 | // Queryf creates a new query using Writef syntax.
119 | func Queryf(format string, values ...any) CustomQuery {
120 | return CustomQuery{Format: format, Values: values}
121 | }
122 |
123 | // Queryf creates a new SQLite query using Writef syntax.
124 | func (b sqliteQueryBuilder) Queryf(format string, values ...any) CustomQuery {
125 | return CustomQuery{Dialect: DialectSQLite, Format: format, Values: values}
126 | }
127 |
128 | // Queryf creates a new Postgres query using Writef syntax.
129 | func (b postgresQueryBuilder) Queryf(format string, values ...any) CustomQuery {
130 | return CustomQuery{Dialect: DialectPostgres, Format: format, Values: values}
131 | }
132 |
133 | // Queryf creates a new MySQL query using Writef syntax.
134 | func (b mysqlQueryBuilder) Queryf(format string, values ...any) CustomQuery {
135 | return CustomQuery{Dialect: DialectMySQL, Format: format, Values: values}
136 | }
137 |
138 | // Queryf creates a new SQL Server query using Writef syntax.
139 | func (b sqlserverQueryBuilder) Queryf(format string, values ...any) CustomQuery {
140 | return CustomQuery{Dialect: DialectSQLServer, Format: format, Values: values}
141 | }
142 |
143 | // Append returns a new CustomQuery with the format string and values slice
144 | // appended to the current CustomQuery.
145 | func (q CustomQuery) Append(format string, values ...any) CustomQuery {
146 | q.Format += " " + format
147 | q.Values = append(q.Values, values...)
148 | return q
149 | }
150 |
151 | // WriteSQL implements the SQLWriter interface.
152 | func (q CustomQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
153 | var err error
154 | format := q.Format
155 | splitAt := -1
156 | for i := strings.IndexByte(format, '{'); i >= 0; i = strings.IndexByte(format, '{') {
157 | if i+2 <= len(format) && format[i:i+2] == "{{" {
158 | format = format[i+2:]
159 | continue
160 | }
161 | if i+3 <= len(format) && format[i:i+3] == "{*}" {
162 | splitAt = len(q.Format) - len(format[i:])
163 | break
164 | }
165 | format = format[i+1:]
166 | }
167 | if splitAt < 0 {
168 | return Writef(ctx, dialect, buf, args, params, q.Format, q.Values)
169 | }
170 | runningValuesIndex := 0
171 | ordinalIndices := make(map[int]int)
172 | err = writef(ctx, dialect, buf, args, params, q.Format[:splitAt], q.Values, &runningValuesIndex, ordinalIndices)
173 | if err != nil {
174 | return err
175 | }
176 | err = writeFields(ctx, dialect, buf, args, params, q.fields, true)
177 | if err != nil {
178 | return err
179 | }
180 | err = writef(ctx, dialect, buf, args, params, q.Format[splitAt+3:], q.Values, &runningValuesIndex, ordinalIndices)
181 | if err != nil {
182 | return err
183 | }
184 | return nil
185 | }
186 |
187 | // SetFetchableFields sets the fetchable fields of the query.
188 | func (q CustomQuery) SetFetchableFields(fields []Field) (query Query, ok bool) {
189 | format := q.Format
190 | for i := strings.IndexByte(format, '{'); i >= 0; i = strings.IndexByte(format, '{') {
191 | if i+2 <= len(format) && format[i:i+2] == "{{" {
192 | format = format[i+2:]
193 | continue
194 | }
195 | if i+3 <= len(format) && format[i:i+3] == "{*}" {
196 | q.fields = fields
197 | return q, true
198 | }
199 | format = format[i+1:]
200 | }
201 | return q, false
202 | }
203 |
204 | // GetFetchableFields gets the fetchable fields of the query.
205 | func (q CustomQuery) GetFetchableFields() []Field {
206 | return q.fields
207 | }
208 |
209 | // GetDialect gets the dialect of the query.
210 | func (q CustomQuery) GetDialect() string { return q.Dialect }
211 |
212 | // SetDialect sets the dialect of the query.
213 | func (q CustomQuery) SetDialect(dialect string) CustomQuery {
214 | q.Dialect = dialect
215 | return q
216 | }
217 |
218 | // VariadicPredicate represents the 'x AND y AND z...' or 'x OR Y OR z...' SQL
219 | // construct.
220 | type VariadicPredicate struct {
221 | // Toplevel indicates if the VariadicPredicate can skip writing the
222 | // (surrounding brackets).
223 | Toplevel bool
224 | alias string
225 | // If IsDisjunction is true, the Predicates are joined using OR. If false,
226 | // the Predicates are joined using AND. The default is AND.
227 | IsDisjunction bool
228 | // Predicates holds the predicates inside the VariadicPredicate
229 | Predicates []Predicate
230 | }
231 |
232 | var _ Predicate = (*VariadicPredicate)(nil)
233 |
234 | // And joins the predicates together with the AND operator.
235 | func And(predicates ...Predicate) VariadicPredicate {
236 | return VariadicPredicate{IsDisjunction: false, Predicates: predicates}
237 | }
238 |
239 | // Or joins the predicates together with the OR operator.
240 | func Or(predicates ...Predicate) VariadicPredicate {
241 | return VariadicPredicate{IsDisjunction: true, Predicates: predicates}
242 | }
243 |
244 | // WriteSQL implements the SQLWriter interface.
245 | func (p VariadicPredicate) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
246 | var err error
247 | if len(p.Predicates) == 0 {
248 | return fmt.Errorf("VariadicPredicate empty")
249 | }
250 |
251 | if len(p.Predicates) == 1 {
252 | switch p1 := p.Predicates[0].(type) {
253 | case nil:
254 | return fmt.Errorf("predicate #1 is nil")
255 | case VariadicPredicate:
256 | p1.Toplevel = p.Toplevel
257 | err = p1.WriteSQL(ctx, dialect, buf, args, params)
258 | if err != nil {
259 | return err
260 | }
261 | default:
262 | err = p.Predicates[0].WriteSQL(ctx, dialect, buf, args, params)
263 | if err != nil {
264 | return err
265 | }
266 | }
267 | return nil
268 | }
269 |
270 | if !p.Toplevel {
271 | buf.WriteString("(")
272 | }
273 | for i, predicate := range p.Predicates {
274 | if i > 0 {
275 | if p.IsDisjunction {
276 | buf.WriteString(" OR ")
277 | } else {
278 | buf.WriteString(" AND ")
279 | }
280 | }
281 | switch predicate := predicate.(type) {
282 | case nil:
283 | return fmt.Errorf("predicate #%d is nil", i+1)
284 | case VariadicPredicate:
285 | predicate.Toplevel = false
286 | err = predicate.WriteSQL(ctx, dialect, buf, args, params)
287 | if err != nil {
288 | return fmt.Errorf("predicate #%d: %w", i+1, err)
289 | }
290 | default:
291 | err = predicate.WriteSQL(ctx, dialect, buf, args, params)
292 | if err != nil {
293 | return fmt.Errorf("predicate #%d: %w", i+1, err)
294 | }
295 | }
296 | }
297 | if !p.Toplevel {
298 | buf.WriteString(")")
299 | }
300 | return nil
301 | }
302 |
303 | // As returns a new VariadicPredicate with the given alias.
304 | func (p VariadicPredicate) As(alias string) VariadicPredicate {
305 | p.alias = alias
306 | return p
307 | }
308 |
309 | // GetAlias returns the alias of the VariadicPredicate.
310 | func (p VariadicPredicate) GetAlias() string { return p.alias }
311 |
312 | // IsField implements the Field interface.
313 | func (p VariadicPredicate) IsField() {}
314 |
315 | // IsBooleanType implements the Predicate interface.
316 | func (p VariadicPredicate) IsBoolean() {}
317 |
// assignment represents assigning a value to a Field. Its WriteSQL method
// renders it as 'field = value'.
type assignment struct {
	// field is the column being assigned to.
	field Field
	// value is the assigned value or expression; a Query value is wrapped in
	// parentheses when written.
	value any
}

// Compile-time check that assignment satisfies the Assignment interface.
var _ Assignment = (*assignment)(nil)
325 |
326 | // Set creates a new Assignment assigning the value to a field.
327 | func Set(field Field, value any) Assignment {
328 | return assignment{field: field, value: value}
329 | }
330 |
331 | // Setf creates a new Assignment assigning a custom expression to a Field.
332 | func Setf(field Field, format string, values ...any) Assignment {
333 | return assignment{field: field, value: Expr(format, values...)}
334 | }
335 |
336 | // WriteSQL implements the SQLWriter interface.
337 | func (a assignment) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
338 | if a.field == nil {
339 | return fmt.Errorf("field is nil")
340 | }
341 | var err error
342 | if dialect == DialectMySQL {
343 | err = a.field.WriteSQL(ctx, dialect, buf, args, params)
344 | if err != nil {
345 | return err
346 | }
347 | } else {
348 | err = withPrefix(a.field, "").WriteSQL(ctx, dialect, buf, args, params)
349 | if err != nil {
350 | return err
351 | }
352 | }
353 | buf.WriteString(" = ")
354 | _, isQuery := a.value.(Query)
355 | if isQuery {
356 | buf.WriteString("(")
357 | }
358 | err = WriteValue(ctx, dialect, buf, args, params, a.value)
359 | if err != nil {
360 | return err
361 | }
362 | if isQuery {
363 | buf.WriteString(")")
364 | }
365 | return nil
366 | }
367 |
368 | // IsAssignment implements the Assignment interface.
369 | func (a assignment) IsAssignment() {}
370 |
371 | // Assignments represents a list of Assignments e.g. x = 1, y = 2, z = 3.
372 | type Assignments []Assignment
373 |
374 | // WriteSQL implements the SQLWriter interface.
375 | func (as Assignments) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
376 | var err error
377 | for i, assignment := range as {
378 | if assignment == nil {
379 | return fmt.Errorf("assignment #%d is nil", i+1)
380 | }
381 | if i > 0 {
382 | buf.WriteString(", ")
383 | }
384 | err = assignment.WriteSQL(ctx, dialect, buf, args, params)
385 | if err != nil {
386 | return fmt.Errorf("assignment #%d: %w", i+1, err)
387 | }
388 | }
389 | return nil
390 | }
391 |
392 | // RowValue represents an SQL row value expression e.g. (x, y, z).
393 | type RowValue []any
394 |
395 | // WriteSQL implements the SQLWriter interface.
396 | func (r RowValue) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
397 | buf.WriteString("(")
398 | var err error
399 | for i, value := range r {
400 | if i > 0 {
401 | buf.WriteString(", ")
402 | }
403 | _, isQuery := value.(Query)
404 | if isQuery {
405 | buf.WriteString("(")
406 | }
407 | err = WriteValue(ctx, dialect, buf, args, params, value)
408 | if err != nil {
409 | return fmt.Errorf("rowvalue #%d: %w", i+1, err)
410 | }
411 | if isQuery {
412 | buf.WriteString(")")
413 | }
414 | }
415 | buf.WriteString(")")
416 | return nil
417 | }
418 |
419 | // In returns an 'rowvalue IN (value)' Predicate.
420 | func (r RowValue) In(v any) Predicate { return In(r, v) }
421 |
422 | // NotIn returns an 'rowvalue NOT IN (value)' Predicate.
423 | func (r RowValue) NotIn(v any) Predicate { return NotIn(r, v) }
424 |
425 | // Eq returns an 'rowvalue = value' Predicate.
426 | func (r RowValue) Eq(v any) Predicate { return cmp("=", r, v) }
427 |
428 | // RowValues represents a list of RowValues e.g. (x, y, z), (a, b, c).
429 | type RowValues []RowValue
430 |
431 | // WriteSQL implements the SQLWriter interface.
432 | func (rs RowValues) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
433 | var err error
434 | for i, r := range rs {
435 | if i > 0 {
436 | buf.WriteString(", ")
437 | }
438 | err = r.WriteSQL(ctx, dialect, buf, args, params)
439 | if err != nil {
440 | return fmt.Errorf("rowvalues #%d: %w", i+1, err)
441 | }
442 | }
443 | return nil
444 | }
445 |
446 | // Fields represents a list of Fields e.g. tbl.field1, tbl.field2, tbl.field3.
447 | type Fields []Field
448 |
449 | // WriteSQL implements the SQLWriter interface.
450 | func (fs Fields) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
451 | var err error
452 | for i, field := range fs {
453 | if field == nil {
454 | return fmt.Errorf("field #%d is nil", i+1)
455 | }
456 | if i > 0 {
457 | buf.WriteString(", ")
458 | }
459 | _, isQuery := field.(Query)
460 | if isQuery {
461 | buf.WriteString("(")
462 | }
463 | err = field.WriteSQL(ctx, dialect, buf, args, params)
464 | if err != nil {
465 | return fmt.Errorf("field #%d: %w", i+1, err)
466 | }
467 | if isQuery {
468 | buf.WriteString(")")
469 | }
470 | }
471 | return nil
472 | }
473 |
474 | type (
475 | sqliteQueryBuilder struct{ ctes []CTE }
476 | postgresQueryBuilder struct{ ctes []CTE }
477 | mysqlQueryBuilder struct{ ctes []CTE }
478 | sqlserverQueryBuilder struct{ ctes []CTE }
479 | )
480 |
481 | // Dialect-specific query builder variables.
482 | var (
483 | SQLite sqliteQueryBuilder
484 | Postgres postgresQueryBuilder
485 | MySQL mysqlQueryBuilder
486 | SQLServer sqlserverQueryBuilder
487 | )
488 |
489 | // With sets the CTEs in the SQLiteQueryBuilder.
490 | func (b sqliteQueryBuilder) With(ctes ...CTE) sqliteQueryBuilder {
491 | b.ctes = ctes
492 | return b
493 | }
494 |
495 | // With sets the CTEs in the PostgresQueryBuilder.
496 | func (b postgresQueryBuilder) With(ctes ...CTE) postgresQueryBuilder {
497 | b.ctes = ctes
498 | return b
499 | }
500 |
501 | // With sets the CTEs in the MySQLQueryBuilder.
502 | func (b mysqlQueryBuilder) With(ctes ...CTE) mysqlQueryBuilder {
503 | b.ctes = ctes
504 | return b
505 | }
506 |
507 | // With sets the CTEs in the SQLServerQueryBuilder.
508 | func (b sqlserverQueryBuilder) With(ctes ...CTE) sqlserverQueryBuilder {
509 | b.ctes = ctes
510 | return b
511 | }
512 |
513 | // ToSQL converts an SQLWriter into a query string and args slice.
514 | //
515 | // The params map is used to hold the mappings between named parameters in the
516 | // query to the corresponding index in the args slice and is used for rebinding
517 | // args by their parameter name. If you don't need to track this, you can pass
518 | // in a nil map.
519 | func ToSQL(dialect string, w SQLWriter, params map[string][]int) (query string, args []any, err error) {
520 | return ToSQLContext(context.Background(), dialect, w, params)
521 | }
522 |
523 | // ToSQLContext is like ToSQL but additionally requires a context.Context.
524 | func ToSQLContext(ctx context.Context, dialect string, w SQLWriter, params map[string][]int) (query string, args []any, err error) {
525 | if w == nil {
526 | return "", nil, fmt.Errorf("SQLWriter is nil")
527 | }
528 | if dialect == "" {
529 | if q, ok := w.(Query); ok {
530 | dialect = q.GetDialect()
531 | }
532 | }
533 | buf := bufpool.Get().(*bytes.Buffer)
534 | buf.Reset()
535 | defer bufpool.Put(buf)
536 | err = w.WriteSQL(ctx, dialect, buf, &args, params)
537 | query = buf.String()
538 | if err != nil {
539 | return query, args, err
540 | }
541 | return query, args, nil
542 | }
543 |
544 | // Eq returns an 'x = y' Predicate.
545 | func Eq(x, y any) Predicate { return cmp("=", x, y) }
546 |
547 | // Ne returns an 'x <> y' Predicate.
548 | func Ne(x, y any) Predicate { return cmp("<>", x, y) }
549 |
550 | // Lt returns an 'x < y' Predicate.
551 | func Lt(x, y any) Predicate { return cmp("<", x, y) }
552 |
553 | // Le returns an 'x <= y' Predicate.
554 | func Le(x, y any) Predicate { return cmp("<=", x, y) }
555 |
556 | // Gt returns an 'x > y' Predicate.
557 | func Gt(x, y any) Predicate { return cmp(">", x, y) }
558 |
559 | // Ge returns an 'x >= y' Predicate.
560 | func Ge(x, y any) Predicate { return cmp(">=", x, y) }
561 |
562 | // Exists returns an 'EXISTS (query)' Predicate.
563 | func Exists(query Query) Predicate { return Expr("EXISTS ({})", query) }
564 |
565 | // NotExists returns a 'NOT EXISTS (query)' Predicate.
566 | func NotExists(query Query) Predicate { return Expr("NOT EXISTS ({})", query) }
567 |
568 | // In returns an 'x IN (y)' Predicate.
569 | func In(x, y any) Predicate {
570 | _, isQueryA := x.(Query)
571 | _, isRowValueB := y.(RowValue)
572 | if !isQueryA && !isRowValueB {
573 | return Expr("{} IN ({})", x, y)
574 | } else if !isQueryA && isRowValueB {
575 | return Expr("{} IN {}", x, y)
576 | } else if isQueryA && !isRowValueB {
577 | return Expr("({}) IN ({})", x, y)
578 | } else {
579 | return Expr("({}) IN {}", x, y)
580 | }
581 | }
582 |
583 | // NotIn returns an 'x NOT IN (y)' Predicate.
584 | func NotIn(x, y any) Predicate {
585 | _, isQueryA := x.(Query)
586 | _, isRowValueB := y.(RowValue)
587 | if !isQueryA && !isRowValueB {
588 | return Expr("{} NOT IN ({})", x, y)
589 | } else if !isQueryA && isRowValueB {
590 | return Expr("{} NOT IN {}", x, y)
591 | } else if isQueryA && !isRowValueB {
592 | return Expr("({}) NOT IN ({})", x, y)
593 | } else {
594 | return Expr("({}) NOT IN {}", x, y)
595 | }
596 | }
597 |
598 | // cmp returns an 'x y' Predicate.
599 | func cmp(operator string, x, y any) Expression {
600 | _, isQueryA := x.(Query)
601 | _, isQueryB := y.(Query)
602 | if !isQueryA && !isQueryB {
603 | return Expr("{} "+operator+" {}", x, y)
604 | } else if !isQueryA && isQueryB {
605 | return Expr("{} "+operator+" ({})", x, y)
606 | } else if isQueryA && !isQueryB {
607 | return Expr("({}) "+operator+" {}", x, y)
608 | } else {
609 | return Expr("({}) "+operator+" ({})", x, y)
610 | }
611 | }
612 |
613 | // appendPolicy will append a policy from a Table (if it implements
614 | // PolicyTable) to a slice of policies. The resultant slice is returned.
615 | func appendPolicy(ctx context.Context, dialect string, policies []Predicate, table Table) ([]Predicate, error) {
616 | policyTable, ok := table.(PolicyTable)
617 | if !ok {
618 | return policies, nil
619 | }
620 | policy, err := policyTable.Policy(ctx, dialect)
621 | if err != nil {
622 | return nil, err
623 | }
624 | if policy != nil {
625 | policies = append(policies, policy)
626 | }
627 | return policies, nil
628 | }
629 |
630 | // appendPredicates will append a slices of predicates into a predicate.
631 | func appendPredicates(predicate Predicate, predicates []Predicate) VariadicPredicate {
632 | if predicate == nil {
633 | return And(predicates...)
634 | }
635 | if p1, ok := predicate.(VariadicPredicate); ok && !p1.IsDisjunction {
636 | p1.Predicates = append(p1.Predicates, predicates...)
637 | return p1
638 | }
639 | p2 := VariadicPredicate{Predicates: make([]Predicate, 1+len(predicates))}
640 | p2.Predicates[0] = predicate
641 | copy(p2.Predicates[1:], predicates)
642 | return p2
643 | }
644 |
645 | func writeTop(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int, topLimit, topPercentLimit any, withTies bool) error {
646 | var err error
647 | if topLimit != nil {
648 | buf.WriteString("TOP (")
649 | err = WriteValue(ctx, dialect, buf, args, params, topLimit)
650 | if err != nil {
651 | return fmt.Errorf("TOP: %w", err)
652 | }
653 | buf.WriteString(") ")
654 | } else if topPercentLimit != nil {
655 | buf.WriteString("TOP (")
656 | err = WriteValue(ctx, dialect, buf, args, params, topPercentLimit)
657 | if err != nil {
658 | return fmt.Errorf("TOP PERCENT: %w", err)
659 | }
660 | buf.WriteString(") PERCENT ")
661 | }
662 | if (topLimit != nil || topPercentLimit != nil) && withTies {
663 | buf.WriteString("WITH TIES ")
664 | }
665 | return nil
666 | }
667 |
--------------------------------------------------------------------------------
/colors.go:
--------------------------------------------------------------------------------
1 | //go:build windows
2 |
3 | package sq
4 |
5 | import (
6 | "os"
7 | "syscall"
8 | )
9 |
10 | func init() {
11 | // https://stackoverflow.com/a/69542231
12 | const ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x4
13 | var stderrMode uint32
14 | stderr := syscall.Handle(os.Stderr.Fd())
15 | syscall.GetConsoleMode(stderr, &stderrMode)
16 | syscall.MustLoadDLL("kernel32").MustFindProc("SetConsoleMode").Call(uintptr(stderr), uintptr(stderrMode|ENABLE_VIRTUAL_TERMINAL_PROCESSING))
17 | var stdoutMode uint32
18 | stdout := syscall.Handle(os.Stdout.Fd())
19 | syscall.GetConsoleMode(stdout, &stdoutMode)
20 | syscall.MustLoadDLL("kernel32").MustFindProc("SetConsoleMode").Call(uintptr(stdout), uintptr(stdoutMode|ENABLE_VIRTUAL_TERMINAL_PROCESSING))
21 | }
22 |
--------------------------------------------------------------------------------
/cte.go:
--------------------------------------------------------------------------------
1 | package sq
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "database/sql"
7 | "fmt"
8 | )
9 |
// CTE represents an SQL common table expression (CTE).
type CTE struct {
	// query is the CTE body, written inside 'AS (...)' by writeCTEs.
	query Query
	// columns are optional column names, written as 'name (col1, col2)'.
	columns []string
	// recursive marks a recursive CTE; if any CTE in a WITH clause is
	// recursive, the clause is written as WITH RECURSIVE.
	recursive bool
	// materialized: Valid=false means unspecified, otherwise Bool selects
	// MATERIALIZED or NOT MATERIALIZED (only written for postgres).
	materialized sql.NullBool
	// name is the CTE name, used both in the WITH clause and in references.
	name string
	// alias is used as the table alias for Fields obtained via Field.
	alias string
}

// Compile-time check that CTE satisfies the Table interface.
var _ Table = (*CTE)(nil)
21 |
22 | // NewCTE creates a new CTE.
23 | func NewCTE(name string, columns []string, query Query) CTE {
24 | return CTE{name: name, columns: columns, query: query}
25 | }
26 |
27 | // NewRecursiveCTE creates a new recursive CTE.
28 | func NewRecursiveCTE(name string, columns []string, query Query) CTE {
29 | return CTE{name: name, columns: columns, query: query, recursive: true}
30 | }
31 |
32 | // WriteSQL implements the SQLWriter interface.
33 | func (cte CTE) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
34 | buf.WriteString(QuoteIdentifier(dialect, cte.name))
35 | return nil
36 | }
37 |
38 | // As returns a new CTE with the given alias.
39 | func (cte CTE) As(alias string) CTE {
40 | cte.alias = alias
41 | return cte
42 | }
43 |
44 | // Materialized returns a new CTE marked as MATERIALIZED. This only works on
45 | // postgres.
46 | func (cte CTE) Materialized() CTE {
47 | cte.materialized.Valid = true
48 | cte.materialized.Bool = true
49 | return cte
50 | }
51 |
52 | // Materialized returns a new CTE marked as NOT MATERIALIZED. This only works
53 | // on postgres.
54 | func (cte CTE) NotMaterialized() CTE {
55 | cte.materialized.Valid = true
56 | cte.materialized.Bool = false
57 | return cte
58 | }
59 |
60 | // Field returns a Field from the CTE.
61 | func (cte CTE) Field(name string) AnyField {
62 | return NewAnyField(name, NewTableStruct("", cte.name, cte.alias))
63 | }
64 |
65 | // GetAlias returns the alias of the CTE.
66 | func (cte CTE) GetAlias() string { return cte.alias }
67 |
68 | // AssertTable implements the Table interface.
69 | func (cte CTE) IsTable() {}
70 |
71 | func writeCTEs(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int, ctes []CTE) error {
72 | var hasRecursiveCTE bool
73 | for _, cte := range ctes {
74 | if cte.recursive {
75 | hasRecursiveCTE = true
76 | break
77 | }
78 | }
79 | if hasRecursiveCTE {
80 | buf.WriteString("WITH RECURSIVE ")
81 | } else {
82 | buf.WriteString("WITH ")
83 | }
84 | for i, cte := range ctes {
85 | if i > 0 {
86 | buf.WriteString(", ")
87 | }
88 | if cte.name == "" {
89 | return fmt.Errorf("CTE #%d has no name", i+1)
90 | }
91 | buf.WriteString(QuoteIdentifier(dialect, cte.name))
92 | if len(cte.columns) > 0 {
93 | buf.WriteString(" (")
94 | for j, column := range cte.columns {
95 | if j > 0 {
96 | buf.WriteString(", ")
97 | }
98 | buf.WriteString(QuoteIdentifier(dialect, column))
99 | }
100 | buf.WriteString(")")
101 | }
102 | buf.WriteString(" AS ")
103 | if dialect == DialectPostgres && cte.materialized.Valid {
104 | if cte.materialized.Bool {
105 | buf.WriteString("MATERIALIZED ")
106 | } else {
107 | buf.WriteString("NOT MATERIALIZED ")
108 | }
109 | }
110 | buf.WriteString("(")
111 | switch query := cte.query.(type) {
112 | case nil:
113 | return fmt.Errorf("CTE #%d query is nil", i+1)
114 | case VariadicQuery:
115 | query.Toplevel = true
116 | err := query.WriteSQL(ctx, dialect, buf, args, params)
117 | if err != nil {
118 | return fmt.Errorf("CTE #%d failed to build query: %w", i+1, err)
119 | }
120 | default:
121 | err := query.WriteSQL(ctx, dialect, buf, args, params)
122 | if err != nil {
123 | return fmt.Errorf("CTE #%d failed to build query: %w", i+1, err)
124 | }
125 | }
126 | buf.WriteString(")")
127 | }
128 | buf.WriteString(" ")
129 | return nil
130 | }
131 |
// VariadicQueryOperator represents a variadic query operator: the keyword(s)
// written between consecutive queries of a VariadicQuery.
type VariadicQueryOperator string

// VariadicQuery operators.
const (
	// UNION variants.
	QueryUnion    VariadicQueryOperator = "UNION"
	QueryUnionAll VariadicQueryOperator = "UNION ALL"

	// INTERSECT variants.
	QueryIntersect    VariadicQueryOperator = "INTERSECT"
	QueryIntersectAll VariadicQueryOperator = "INTERSECT ALL"

	// EXCEPT variants.
	QueryExcept    VariadicQueryOperator = "EXCEPT"
	QueryExceptAll VariadicQueryOperator = "EXCEPT ALL"
)
144 |
// VariadicQuery represents the 'x UNION y UNION z...' etc SQL constructs.
type VariadicQuery struct {
	// Toplevel suppresses the parentheses that are otherwise written around
	// a VariadicQuery containing two or more queries.
	Toplevel bool
	// Operator is written between the queries; an empty Operator is treated
	// as UNION.
	Operator VariadicQueryOperator
	// Queries are the member queries, written in order.
	Queries []Query
}

// Compile-time check that VariadicQuery satisfies the Query interface.
var _ Query = (*VariadicQuery)(nil)
153 |
154 | // Union joins the queries together with the UNION operator.
155 | func Union(queries ...Query) VariadicQuery {
156 | return VariadicQuery{Operator: QueryUnion, Queries: queries}
157 | }
158 |
159 | // UnionAll joins the queries together with the UNION ALL operator.
160 | func UnionAll(queries ...Query) VariadicQuery {
161 | return VariadicQuery{Operator: QueryUnionAll, Queries: queries}
162 | }
163 |
164 | // Intersect joins the queries together with the INTERSECT operator.
165 | func Intersect(queries ...Query) VariadicQuery {
166 | return VariadicQuery{Operator: QueryIntersect, Queries: queries}
167 | }
168 |
169 | // IntersectAll joins the queries together with the INTERSECT ALL operator.
170 | func IntersectAll(queries ...Query) VariadicQuery {
171 | return VariadicQuery{Operator: QueryIntersectAll, Queries: queries}
172 | }
173 |
174 | // Except joins the queries together with the EXCEPT operator.
175 | func Except(queries ...Query) VariadicQuery {
176 | return VariadicQuery{Operator: QueryExcept, Queries: queries}
177 | }
178 |
179 | // ExceptAll joins the queries together with the EXCEPT ALL operator.
180 | func ExceptAll(queries ...Query) VariadicQuery {
181 | return VariadicQuery{Operator: QueryExceptAll, Queries: queries}
182 | }
183 |
184 | // WriteSQL implements the SQLWriter interface.
185 | func (q VariadicQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
186 | var err error
187 | if q.Operator == "" {
188 | q.Operator = QueryUnion
189 | }
190 | if len(q.Queries) == 0 {
191 | return fmt.Errorf("VariadicQuery empty")
192 | }
193 |
194 | if len(q.Queries) == 1 {
195 | switch q1 := q.Queries[0].(type) {
196 | case nil:
197 | return fmt.Errorf("query #1 is nil")
198 | case VariadicQuery:
199 | q1.Toplevel = q.Toplevel
200 | err = q1.WriteSQL(ctx, dialect, buf, args, params)
201 | if err != nil {
202 | return err
203 | }
204 | default:
205 | err = q.Queries[0].WriteSQL(ctx, dialect, buf, args, params)
206 | if err != nil {
207 | return err
208 | }
209 | }
210 | return nil
211 | }
212 |
213 | if !q.Toplevel {
214 | buf.WriteString("(")
215 | }
216 | for i, query := range q.Queries {
217 | if i > 0 {
218 | buf.WriteString(" " + string(q.Operator) + " ")
219 | }
220 | if query == nil {
221 | return fmt.Errorf("query #%d is nil", i+1)
222 | }
223 | err = query.WriteSQL(ctx, dialect, buf, args, params)
224 | if err != nil {
225 | return fmt.Errorf("query #%d: %w", i+1, err)
226 | }
227 | }
228 | if !q.Toplevel {
229 | buf.WriteString(")")
230 | }
231 | return nil
232 | }
233 |
234 | // SetFetchableFields implements the Query interface.
235 | func (q VariadicQuery) SetFetchableFields(fields []Field) (query Query, ok bool) {
236 | return q, false
237 | }
238 |
239 | // GetFetchableFields implements the Query interface.
240 | func (q VariadicQuery) GetFetchableFields() []Field {
241 | return nil
242 | }
243 |
244 | // GetDialect returns the SQL dialect of the VariadicQuery.
245 | func (q VariadicQuery) GetDialect() string {
246 | if len(q.Queries) == 0 {
247 | return ""
248 | }
249 | q1 := q.Queries[0]
250 | if q1 == nil {
251 | return ""
252 | }
253 | return q1.GetDialect()
254 | }
255 |
--------------------------------------------------------------------------------
/cte_test.go:
--------------------------------------------------------------------------------
1 | package sq
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "database/sql"
7 | "errors"
8 | "testing"
9 |
10 | "github.com/bokwoon95/sq/internal/testutil"
11 | )
12 |
13 | func TestCTE(t *testing.T) {
14 | t.Run("basic", func(t *testing.T) {
15 | cte := NewCTE("cte", []string{"n"}, Queryf("SELECT 1")).Materialized().NotMaterialized().As("c")
16 | TestTable{item: cte, wantQuery: "cte"}.assert(t)
17 | field := NewAnyField("ff", TableStruct{name: "cte", alias: "c"})
18 | if diff := testutil.Diff(cte.Field("ff"), field); diff != "" {
19 | t.Error(testutil.Callers(), diff)
20 | }
21 | if diff := testutil.Diff(cte.materialized, sql.NullBool{Valid: true, Bool: false}); diff != "" {
22 | t.Error(testutil.Callers(), diff)
23 | }
24 | if diff := testutil.Diff(cte.GetAlias(), "c"); diff != "" {
25 | t.Error(testutil.Callers(), diff)
26 | }
27 | })
28 | }
29 |
30 | func TestCTEs(t *testing.T) {
31 | type TT struct {
32 | description string
33 | dialect string
34 | ctes []CTE
35 | wantQuery string
36 | wantArgs []any
37 | wantParams map[string][]int
38 | }
39 |
40 | tests := []TT{{
41 | description: "basic",
42 | ctes: []CTE{NewCTE("cte", nil, Queryf("SELECT 1"))},
43 | wantQuery: "WITH cte AS (SELECT 1) ",
44 | }, {
45 | description: "recursive",
46 | ctes: []CTE{
47 | NewCTE("cte", nil, Queryf("SELECT 1")),
48 | NewRecursiveCTE("nums", []string{"n"}, Union(
49 | Queryf("SELECT 1"),
50 | Queryf("SELECT n+1 FROM nums WHERE n < 10"),
51 | )),
52 | },
53 | wantQuery: "WITH RECURSIVE cte AS (SELECT 1)" +
54 | ", nums (n) AS (SELECT 1 UNION SELECT n+1 FROM nums WHERE n < 10) ",
55 | }, {
56 | description: "mysql materialized",
57 | dialect: DialectMySQL,
58 | ctes: []CTE{NewCTE("cte", nil, Queryf("SELECT 1")).Materialized()},
59 | wantQuery: "WITH cte AS (SELECT 1) ",
60 | }, {
61 | description: "postgres materialized",
62 | dialect: DialectPostgres,
63 | ctes: []CTE{NewCTE("cte", nil, Queryf("SELECT 1")).Materialized()},
64 | wantQuery: "WITH cte AS MATERIALIZED (SELECT 1) ",
65 | }, {
66 | description: "postgres not materialized",
67 | dialect: DialectPostgres,
68 | ctes: []CTE{NewCTE("cte", nil, Queryf("SELECT 1")).NotMaterialized()},
69 | wantQuery: "WITH cte AS NOT MATERIALIZED (SELECT 1) ",
70 | }}
71 |
72 | for _, tt := range tests {
73 | tt := tt
74 | t.Run(tt.description, func(t *testing.T) {
75 | t.Parallel()
76 | buf, args, params := bufpool.Get().(*bytes.Buffer), &[]any{}, make(map[string][]int)
77 | buf.Reset()
78 | defer bufpool.Put(buf)
79 | err := writeCTEs(context.Background(), tt.dialect, buf, args, params, tt.ctes)
80 | if err != nil {
81 | t.Fatal(testutil.Callers(), err)
82 | }
83 | if diff := testutil.Diff(buf.String(), tt.wantQuery); diff != "" {
84 | t.Error(testutil.Callers(), diff)
85 | }
86 | if diff := testutil.Diff(*args, tt.wantArgs); diff != "" {
87 | t.Error(testutil.Callers(), diff)
88 | }
89 | if diff := testutil.Diff(params, tt.wantParams); diff != "" {
90 | t.Error(testutil.Callers(), diff)
91 | }
92 | })
93 | }
94 |
95 | t.Run("invalid cte", func(t *testing.T) {
96 | t.Parallel()
97 | buf, args, params := bufpool.Get().(*bytes.Buffer), &[]any{}, make(map[string][]int)
98 | buf.Reset()
99 | defer bufpool.Put(buf)
100 | // no name
101 | err := writeCTEs(context.Background(), "", buf, args, params, []CTE{
102 | NewCTE("", nil, Queryf("SELECT 1")),
103 | })
104 | if err == nil {
105 | t.Fatal(testutil.Callers(), "expected error but got nil")
106 | }
107 | // no query
108 | err = writeCTEs(context.Background(), "", buf, args, params, []CTE{
109 | NewCTE("cte", nil, nil),
110 | })
111 | if err == nil {
112 | t.Fatal(testutil.Callers(), "expected error but got nil")
113 | }
114 | })
115 |
116 | t.Run("err", func(t *testing.T) {
117 | t.Parallel()
118 | buf, args, params := bufpool.Get().(*bytes.Buffer), &[]any{}, make(map[string][]int)
119 | buf.Reset()
120 | defer bufpool.Put(buf)
121 | // VariadicQuery
122 | err := writeCTEs(context.Background(), "", buf, args, params, []CTE{
123 | NewCTE("cte", nil, Union(
124 | Queryf("SELECT 1"),
125 | Queryf("SELECT {}", FaultySQL{}),
126 | )),
127 | })
128 | if !errors.Is(err, ErrFaultySQL) {
129 | t.Errorf(testutil.Callers()+"expected error %q but got %q", ErrFaultySQL, err)
130 | }
131 | // Query
132 | err = writeCTEs(context.Background(), "", buf, args, params, []CTE{
133 | NewCTE("cte", nil, Queryf("SELECT {}", FaultySQL{})),
134 | })
135 | if !errors.Is(err, ErrFaultySQL) {
136 | t.Errorf(testutil.Callers()+"expected error %q but got %q", ErrFaultySQL, err)
137 | }
138 | })
139 | }
140 |
141 | func TestVariadicQuery(t *testing.T) {
142 | q1, q2, q3 := Queryf("SELECT 1"), Queryf("SELECT 2"), Queryf("SELECT 3")
143 | tests := []TestTable{{
144 | description: "Union",
145 | item: Union(q1, q2, q3),
146 | wantQuery: "(SELECT 1 UNION SELECT 2 UNION SELECT 3)",
147 | }, {
148 | description: "UnionAll",
149 | item: UnionAll(q1, q2, q3),
150 | wantQuery: "(SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3)",
151 | }, {
152 | description: "Intersect",
153 | item: Intersect(q1, q2, q3),
154 | wantQuery: "(SELECT 1 INTERSECT SELECT 2 INTERSECT SELECT 3)",
155 | }, {
156 | description: "IntersectAll",
157 | item: IntersectAll(q1, q2, q3),
158 | wantQuery: "(SELECT 1 INTERSECT ALL SELECT 2 INTERSECT ALL SELECT 3)",
159 | }, {
160 | description: "Except",
161 | item: Except(q1, q2, q3),
162 | wantQuery: "(SELECT 1 EXCEPT SELECT 2 EXCEPT SELECT 3)",
163 | }, {
164 | description: "ExceptAll",
165 | item: ExceptAll(q1, q2, q3),
166 | wantQuery: "(SELECT 1 EXCEPT ALL SELECT 2 EXCEPT ALL SELECT 3)",
167 | }, {
168 | description: "No operator specified",
169 | item: VariadicQuery{Queries: []Query{q1, q2, q3}},
170 | wantQuery: "(SELECT 1 UNION SELECT 2 UNION SELECT 3)",
171 | }, {
172 | description: "nested VariadicQuery",
173 | item: Union(Union(Union(q1, q2, q3))),
174 | wantQuery: "(SELECT 1 UNION SELECT 2 UNION SELECT 3)",
175 | }, {
176 | description: "1 query",
177 | item: Union(q1),
178 | wantQuery: "SELECT 1",
179 | }}
180 |
181 | for _, tt := range tests {
182 | tt := tt
183 | t.Run(tt.description, func(t *testing.T) {
184 | t.Parallel()
185 | tt.assert(t)
186 | })
187 | }
188 |
189 | t.Run("invalid VariadicQuery", func(t *testing.T) {
190 | t.Parallel()
191 | // empty
192 | TestTable{item: Union()}.assertNotOK(t)
193 | // nil query
194 | TestTable{item: Union(nil)}.assertNotOK(t)
195 | // nil query
196 | TestTable{item: Union(q1, q2, nil)}.assertNotOK(t)
197 | })
198 |
199 | t.Run("err", func(t *testing.T) {
200 | t.Parallel()
201 | // VariadicQuery
202 | TestTable{
203 | item: Union(
204 | Union(
205 | Queryf("SELECT 1"),
206 | Queryf("SELECT {}", FaultySQL{}),
207 | ),
208 | ),
209 | }.assertErr(t, ErrFaultySQL)
210 | // Query
211 | TestTable{
212 | item: Union(Queryf("SELECT {}", FaultySQL{})),
213 | }.assertErr(t, ErrFaultySQL)
214 | })
215 |
216 | t.Run("SetFetchableFields", func(t *testing.T) {
217 | t.Parallel()
218 | _, ok := Union().SetFetchableFields([]Field{Expr("f1")})
219 | if ok {
220 | t.Error(testutil.Callers(), "expected not ok but got ok")
221 | }
222 | })
223 |
224 | t.Run("GetDialect", func(t *testing.T) {
225 | // empty VariadicQuery
226 | if diff := testutil.Diff(Union().GetDialect(), ""); diff != "" {
227 | t.Error(testutil.Callers(), diff)
228 | }
229 | // nil query
230 | if diff := testutil.Diff(Union(nil).GetDialect(), ""); diff != "" {
231 | t.Error(testutil.Callers(), diff)
232 | }
233 | // empty dialect propagated
234 | if diff := testutil.Diff(Union(Queryf("SELECT 1")).GetDialect(), ""); diff != "" {
235 | t.Error(testutil.Callers(), diff)
236 | }
237 | })
238 | }
239 |
--------------------------------------------------------------------------------
/delete_query.go:
--------------------------------------------------------------------------------
1 | package sq
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "fmt"
7 | )
8 |
// DeleteQuery represents an SQL DELETE query.
type DeleteQuery struct {
	// Dialect records the query's SQL dialect (WriteSQL also uses it when
	// formatting some error messages).
	Dialect string
	// WITH
	CTEs []CTE
	// DELETE FROM
	DeleteTable Table
	// DeleteTables lists the targets of a multi-table DELETE. It is only
	// consulted for MySQL and SQL Server (SQL Server accepts at most one);
	// a table with an alias is written by its alias.
	DeleteTables []Table
	// USING
	// UsingTable is written as USING (postgres) or FROM (MySQL, SQL Server);
	// JoinTables require it to be set.
	UsingTable Table
	JoinTables []JoinTable
	// WHERE
	// WherePredicate is ANDed together with any table policies.
	WherePredicate Predicate
	// ORDER BY (MySQL only)
	OrderByFields Fields
	// LIMIT (MySQL only)
	LimitRows any
	// OFFSET
	OffsetRows any
	// RETURNING
	// For SQL Server these are written as an OUTPUT clause using the DELETED
	// prefix. NOTE(review): handling for other dialects is further down
	// WriteSQL — confirm there.
	ReturningFields []Field
}

// Compile-time check that DeleteQuery satisfies the Query interface.
var _ Query = (*DeleteQuery)(nil)
33 |
34 | // WriteSQL implements the SQLWriter interface.
35 | func (q DeleteQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
36 | var err error
37 | // Table Policies
38 | var policies []Predicate
39 | policies, err = appendPolicy(ctx, dialect, policies, q.DeleteTable)
40 | if err != nil {
41 | return fmt.Errorf("DELETE FROM %s Policy: %w", toString(q.Dialect, q.DeleteTable), err)
42 | }
43 | policies, err = appendPolicy(ctx, dialect, policies, q.UsingTable)
44 | if err != nil {
45 | return fmt.Errorf("USING %s Policy: %w", toString(q.Dialect, q.UsingTable), err)
46 | }
47 | for _, joinTable := range q.JoinTables {
48 | policies, err = appendPolicy(ctx, dialect, policies, joinTable.Table)
49 | if err != nil {
50 | return fmt.Errorf("%s %s Policy: %w", joinTable.JoinOperator, joinTable.Table, err)
51 | }
52 | }
53 | if len(policies) > 0 {
54 | if q.WherePredicate != nil {
55 | policies = append(policies, q.WherePredicate)
56 | }
57 | q.WherePredicate = And(policies...)
58 | }
59 | // WITH
60 | if len(q.CTEs) > 0 {
61 | err = writeCTEs(ctx, dialect, buf, args, params, q.CTEs)
62 | if err != nil {
63 | return fmt.Errorf("WITH: %w", err)
64 | }
65 | }
66 | // DELETE FROM
67 | if (dialect == DialectMySQL || dialect == DialectSQLServer) && len(q.DeleteTables) > 0 {
68 | buf.WriteString("DELETE ")
69 | if len(q.DeleteTables) > 1 && dialect != DialectMySQL {
70 | return fmt.Errorf("dialect %q does not support multi-table DELETE", dialect)
71 | }
72 | for i, table := range q.DeleteTables {
73 | if i > 0 {
74 | buf.WriteString(", ")
75 | }
76 | if alias := getAlias(table); alias != "" {
77 | buf.WriteString(alias)
78 | } else {
79 | err = table.WriteSQL(ctx, dialect, buf, args, params)
80 | if err != nil {
81 | return fmt.Errorf("table #%d: %w", i+1, err)
82 | }
83 | }
84 | }
85 | } else {
86 | buf.WriteString("DELETE FROM ")
87 | if q.DeleteTable == nil {
88 | return fmt.Errorf("no table provided to DELETE FROM")
89 | }
90 | err = q.DeleteTable.WriteSQL(ctx, dialect, buf, args, params)
91 | if err != nil {
92 | return fmt.Errorf("DELETE FROM: %w", err)
93 | }
94 | if dialect != DialectSQLServer {
95 | if alias := getAlias(q.DeleteTable); alias != "" {
96 | buf.WriteString(" AS " + QuoteIdentifier(dialect, alias))
97 | }
98 | }
99 | }
100 | if q.UsingTable != nil || len(q.JoinTables) > 0 {
101 | if dialect != DialectPostgres && dialect != DialectMySQL && dialect != DialectSQLServer {
102 | return fmt.Errorf("%s DELETE does not support JOIN", dialect)
103 | }
104 | }
105 | // OUTPUT
106 | if len(q.ReturningFields) > 0 && dialect == DialectSQLServer {
107 | buf.WriteString(" OUTPUT ")
108 | err = writeFieldsWithPrefix(ctx, dialect, buf, args, params, q.ReturningFields, "DELETED", true)
109 | if err != nil {
110 | return err
111 | }
112 | }
113 | // USING/FROM
114 | if q.UsingTable != nil {
115 | switch dialect {
116 | case DialectPostgres:
117 | buf.WriteString(" USING ")
118 | err = q.UsingTable.WriteSQL(ctx, dialect, buf, args, params)
119 | if err != nil {
120 | return fmt.Errorf("USING: %w", err)
121 | }
122 | case DialectMySQL, DialectSQLServer:
123 | buf.WriteString(" FROM ")
124 | err = q.UsingTable.WriteSQL(ctx, dialect, buf, args, params)
125 | if err != nil {
126 | return fmt.Errorf("FROM: %w", err)
127 | }
128 | }
129 | if alias := getAlias(q.UsingTable); alias != "" {
130 | buf.WriteString(" AS " + QuoteIdentifier(dialect, alias))
131 | }
132 | }
133 | // JOIN
134 | if len(q.JoinTables) > 0 {
135 | if q.UsingTable == nil {
136 | return fmt.Errorf("%s can't JOIN without a USING/FROM table", dialect)
137 | }
138 | buf.WriteString(" ")
139 | err = writeJoinTables(ctx, dialect, buf, args, params, q.JoinTables)
140 | if err != nil {
141 | return fmt.Errorf("JOIN: %w", err)
142 | }
143 | }
144 | // WHERE
145 | if q.WherePredicate != nil {
146 | buf.WriteString(" WHERE ")
147 | switch predicate := q.WherePredicate.(type) {
148 | case VariadicPredicate:
149 | predicate.Toplevel = true
150 | err = predicate.WriteSQL(ctx, dialect, buf, args, params)
151 | if err != nil {
152 | return fmt.Errorf("WHERE: %w", err)
153 | }
154 | default:
155 | err = q.WherePredicate.WriteSQL(ctx, dialect, buf, args, params)
156 | if err != nil {
157 | return fmt.Errorf("WHERE: %w", err)
158 | }
159 | }
160 | }
161 | // ORDER BY
162 | if len(q.OrderByFields) > 0 {
163 | if dialect != DialectMySQL {
164 | return fmt.Errorf("%s UPDATE does not support ORDER BY", dialect)
165 | }
166 | buf.WriteString(" ORDER BY ")
167 | err = q.OrderByFields.WriteSQL(ctx, dialect, buf, args, params)
168 | if err != nil {
169 | return fmt.Errorf("ORDER BY: %w", err)
170 | }
171 | }
172 | // LIMIT
173 | if q.LimitRows != nil {
174 | if dialect != DialectMySQL {
175 | return fmt.Errorf("%s UPDATE does not support LIMIT", dialect)
176 | }
177 | buf.WriteString(" LIMIT ")
178 | err = WriteValue(ctx, dialect, buf, args, params, q.LimitRows)
179 | if err != nil {
180 | return fmt.Errorf("LIMIT: %w", err)
181 | }
182 | }
183 | // RETURNING
184 | if len(q.ReturningFields) > 0 && dialect != DialectSQLServer {
185 | if dialect != DialectPostgres && dialect != DialectSQLite && dialect != DialectMySQL {
186 | return fmt.Errorf("%s UPDATE does not support RETURNING", dialect)
187 | }
188 | buf.WriteString(" RETURNING ")
189 | err = writeFields(ctx, dialect, buf, args, params, q.ReturningFields, true)
190 | if err != nil {
191 | return fmt.Errorf("RETURNING: %w", err)
192 | }
193 | }
194 | return nil
195 | }
196 |
197 | // DeleteFrom returns a new DeleteQuery.
198 | func DeleteFrom(table Table) DeleteQuery {
199 | return DeleteQuery{DeleteTable: table}
200 | }
201 |
202 | // Where appends to the WherePredicate field of the DeleteQuery.
203 | func (q DeleteQuery) Where(predicates ...Predicate) DeleteQuery {
204 | q.WherePredicate = appendPredicates(q.WherePredicate, predicates)
205 | return q
206 | }
207 |
208 | // SetFetchableFields implements the Query interface.
209 | func (q DeleteQuery) SetFetchableFields(fields []Field) (query Query, ok bool) {
210 | switch q.Dialect {
211 | case DialectPostgres, DialectSQLite:
212 | if len(q.ReturningFields) == 0 {
213 | q.ReturningFields = fields
214 | return q, true
215 | }
216 | return q, false
217 | default:
218 | return q, false
219 | }
220 | }
221 |
222 | // GetFetchableFields returns the fetchable fields of the query.
223 | func (q DeleteQuery) GetFetchableFields() []Field {
224 | switch q.Dialect {
225 | case DialectPostgres, DialectSQLite:
226 | return q.ReturningFields
227 | default:
228 | return nil
229 | }
230 | }
231 |
232 | // GetDialect implements the Query interface.
233 | func (q DeleteQuery) GetDialect() string { return q.Dialect }
234 |
235 | // SetDialect sets the dialect of the query.
236 | func (q DeleteQuery) SetDialect(dialect string) DeleteQuery {
237 | q.Dialect = dialect
238 | return q
239 | }
240 |
// SQLiteDeleteQuery represents an SQLite DELETE query. It has the same
// underlying struct as DeleteQuery, and its WriteSQL and fetchable-field
// methods convert back to DeleteQuery to reuse that logic.
type SQLiteDeleteQuery DeleteQuery

// Compile-time assertion that SQLiteDeleteQuery implements Query.
var _ Query = (*SQLiteDeleteQuery)(nil)
245 |
246 | // WriteSQL implements the SQLWriter interface.
247 | func (q SQLiteDeleteQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
248 | return DeleteQuery(q).WriteSQL(ctx, dialect, buf, args, params)
249 | }
250 |
251 | // DeleteFrom returns a new SQLiteDeleteQuery.
252 | func (b sqliteQueryBuilder) DeleteFrom(table Table) SQLiteDeleteQuery {
253 | return SQLiteDeleteQuery{
254 | Dialect: DialectSQLite,
255 | CTEs: b.ctes,
256 | DeleteTable: table,
257 | }
258 | }
259 |
260 | // Where appends to the WherePredicate field of the SQLiteDeleteQuery.
261 | func (q SQLiteDeleteQuery) Where(predicates ...Predicate) SQLiteDeleteQuery {
262 | q.WherePredicate = appendPredicates(q.WherePredicate, predicates)
263 | return q
264 | }
265 |
266 | // Returning appends fields to the RETURNING clause of the SQLiteDeleteQuery.
267 | func (q SQLiteDeleteQuery) Returning(fields ...Field) SQLiteDeleteQuery {
268 | q.ReturningFields = append(q.ReturningFields, fields...)
269 | return q
270 | }
271 |
272 | // SetFetchableFields implements the Query interface.
273 | func (q SQLiteDeleteQuery) SetFetchableFields(fields []Field) (query Query, ok bool) {
274 | return DeleteQuery(q).SetFetchableFields(fields)
275 | }
276 |
277 | // GetFetchableFields returns the fetchable fields of the query.
278 | func (q SQLiteDeleteQuery) GetFetchableFields() []Field {
279 | return DeleteQuery(q).GetFetchableFields()
280 | }
281 |
282 | // GetDialect implements the Query interface.
283 | func (q SQLiteDeleteQuery) GetDialect() string { return q.Dialect }
284 |
285 | // SetDialect sets the dialect of the query.
286 | func (q SQLiteDeleteQuery) SetDialect(dialect string) SQLiteDeleteQuery {
287 | q.Dialect = dialect
288 | return q
289 | }
290 |
// PostgresDeleteQuery represents a Postgres DELETE query. It has the same
// underlying struct as DeleteQuery, and its WriteSQL and fetchable-field
// methods convert back to DeleteQuery to reuse that logic.
type PostgresDeleteQuery DeleteQuery

// Compile-time assertion that PostgresDeleteQuery implements Query.
var _ Query = (*PostgresDeleteQuery)(nil)
295 |
296 | // WriteSQL implements the SQLWriter interface.
297 | func (q PostgresDeleteQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
298 | return DeleteQuery(q).WriteSQL(ctx, dialect, buf, args, params)
299 | }
300 |
301 | // DeleteFrom returns a new PostgresDeleteQuery.
302 | func (b postgresQueryBuilder) DeleteFrom(table Table) PostgresDeleteQuery {
303 | return PostgresDeleteQuery{
304 | Dialect: DialectPostgres,
305 | CTEs: b.ctes,
306 | DeleteTable: table,
307 | }
308 | }
309 |
310 | // Using sets the UsingTable field of the PostgresDeleteQuery.
311 | func (q PostgresDeleteQuery) Using(table Table) PostgresDeleteQuery {
312 | q.UsingTable = table
313 | return q
314 | }
315 |
316 | // Join joins a new Table to the PostgresDeleteQuery.
317 | func (q PostgresDeleteQuery) Join(table Table, predicates ...Predicate) PostgresDeleteQuery {
318 | q.JoinTables = append(q.JoinTables, Join(table, predicates...))
319 | return q
320 | }
321 |
322 | // LeftJoin left joins a new Table to the PostgresDeleteQuery.
323 | func (q PostgresDeleteQuery) LeftJoin(table Table, predicates ...Predicate) PostgresDeleteQuery {
324 | q.JoinTables = append(q.JoinTables, LeftJoin(table, predicates...))
325 | return q
326 | }
327 |
328 | // FullJoin full joins a new Table to the PostgresDeleteQuery.
329 | func (q PostgresDeleteQuery) FullJoin(table Table, predicates ...Predicate) PostgresDeleteQuery {
330 | q.JoinTables = append(q.JoinTables, FullJoin(table, predicates...))
331 | return q
332 | }
333 |
334 | // CrossJoin cross joins a new Table to the PostgresDeleteQuery.
335 | func (q PostgresDeleteQuery) CrossJoin(table Table) PostgresDeleteQuery {
336 | q.JoinTables = append(q.JoinTables, CrossJoin(table))
337 | return q
338 | }
339 |
340 | // CustomJoin joins a new Table to the PostgresDeleteQuery with a custom join
341 | // operator.
342 | func (q PostgresDeleteQuery) CustomJoin(joinOperator string, table Table, predicates ...Predicate) PostgresDeleteQuery {
343 | q.JoinTables = append(q.JoinTables, CustomJoin(joinOperator, table, predicates...))
344 | return q
345 | }
346 |
347 | // JoinUsing joins a new Table to the PostgresDeleteQuery with the USING operator.
348 | func (q PostgresDeleteQuery) JoinUsing(table Table, fields ...Field) PostgresDeleteQuery {
349 | q.JoinTables = append(q.JoinTables, JoinUsing(table, fields...))
350 | return q
351 | }
352 |
353 | // Where appends to the WherePredicate field of the PostgresDeleteQuery.
354 | func (q PostgresDeleteQuery) Where(predicates ...Predicate) PostgresDeleteQuery {
355 | q.WherePredicate = appendPredicates(q.WherePredicate, predicates)
356 | return q
357 | }
358 |
359 | // Returning appends fields to the RETURNING clause of the PostgresDeleteQuery.
360 | func (q PostgresDeleteQuery) Returning(fields ...Field) PostgresDeleteQuery {
361 | q.ReturningFields = append(q.ReturningFields, fields...)
362 | return q
363 | }
364 |
365 | // SetFetchableFields implements the Query interface.
366 | func (q PostgresDeleteQuery) SetFetchableFields(fields []Field) (query Query, ok bool) {
367 | return DeleteQuery(q).SetFetchableFields(fields)
368 | }
369 |
370 | // GetFetchableFields returns the fetchable fields of the query.
371 | func (q PostgresDeleteQuery) GetFetchableFields() []Field {
372 | return DeleteQuery(q).GetFetchableFields()
373 | }
374 |
375 | // GetDialect implements the Query interface.
376 | func (q PostgresDeleteQuery) GetDialect() string { return q.Dialect }
377 |
378 | // SetDialect sets the dialect of the query.
379 | func (q PostgresDeleteQuery) SetDialect(dialect string) PostgresDeleteQuery {
380 | q.Dialect = dialect
381 | return q
382 | }
383 |
// MySQLDeleteQuery represents a MySQL DELETE query. It has the same
// underlying struct as DeleteQuery, and its WriteSQL and fetchable-field
// methods convert back to DeleteQuery to reuse that logic.
type MySQLDeleteQuery DeleteQuery

// Compile-time assertion that MySQLDeleteQuery implements Query.
var _ Query = (*MySQLDeleteQuery)(nil)
388 |
389 | // WriteSQL implements the SQLWriter interface.
390 | func (q MySQLDeleteQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
391 | return DeleteQuery(q).WriteSQL(ctx, dialect, buf, args, params)
392 | }
393 |
394 | // DeleteFrom returns a new MySQLDeleteQuery.
395 | func (b mysqlQueryBuilder) DeleteFrom(table Table) MySQLDeleteQuery {
396 | return MySQLDeleteQuery{
397 | Dialect: DialectMySQL,
398 | CTEs: b.ctes,
399 | DeleteTable: table,
400 | }
401 | }
402 |
403 | // Delete returns a new MySQLDeleteQuery.
404 | func (b mysqlQueryBuilder) Delete(tables ...Table) MySQLDeleteQuery {
405 | return MySQLDeleteQuery{
406 | Dialect: DialectMySQL,
407 | CTEs: b.ctes,
408 | DeleteTables: tables,
409 | }
410 | }
411 |
412 | // From sets the UsingTable of the MySQLDeleteQuery.
413 | func (q MySQLDeleteQuery) From(table Table) MySQLDeleteQuery {
414 | q.UsingTable = table
415 | return q
416 | }
417 |
418 | // Join joins a new Table to the MySQLDeleteQuery.
419 | func (q MySQLDeleteQuery) Join(table Table, predicates ...Predicate) MySQLDeleteQuery {
420 | q.JoinTables = append(q.JoinTables, Join(table, predicates...))
421 | return q
422 | }
423 |
424 | // LeftJoin left joins a new Table to the MySQLDeleteQuery.
425 | func (q MySQLDeleteQuery) LeftJoin(table Table, predicates ...Predicate) MySQLDeleteQuery {
426 | q.JoinTables = append(q.JoinTables, LeftJoin(table, predicates...))
427 | return q
428 | }
429 |
430 | // FullJoin full joins a new Table to the MySQLDeleteQuery.
431 | func (q MySQLDeleteQuery) FullJoin(table Table, predicates ...Predicate) MySQLDeleteQuery {
432 | q.JoinTables = append(q.JoinTables, FullJoin(table, predicates...))
433 | return q
434 | }
435 |
436 | // CrossJoin cross joins a new Table to the MySQLDeleteQuery.
437 | func (q MySQLDeleteQuery) CrossJoin(table Table) MySQLDeleteQuery {
438 | q.JoinTables = append(q.JoinTables, CrossJoin(table))
439 | return q
440 | }
441 |
442 | // CustomJoin joins a new Table to the MySQLDeleteQuery with a custom join
443 | // operator.
444 | func (q MySQLDeleteQuery) CustomJoin(joinOperator string, table Table, predicates ...Predicate) MySQLDeleteQuery {
445 | q.JoinTables = append(q.JoinTables, CustomJoin(joinOperator, table, predicates...))
446 | return q
447 | }
448 |
449 | // JoinUsing joins a new Table to the MySQLDeleteQuery with the USING operator.
450 | func (q MySQLDeleteQuery) JoinUsing(table Table, fields ...Field) MySQLDeleteQuery {
451 | q.JoinTables = append(q.JoinTables, JoinUsing(table, fields...))
452 | return q
453 | }
454 |
455 | // Where appends to the WherePredicate field of the MySQLDeleteQuery.
456 | func (q MySQLDeleteQuery) Where(predicates ...Predicate) MySQLDeleteQuery {
457 | q.WherePredicate = appendPredicates(q.WherePredicate, predicates)
458 | return q
459 | }
460 |
461 | // OrderBy sets the OrderByFields field of the MySQLDeleteQuery.
462 | func (q MySQLDeleteQuery) OrderBy(fields ...Field) MySQLDeleteQuery {
463 | q.OrderByFields = append(q.OrderByFields, fields...)
464 | return q
465 | }
466 |
467 | // Limit sets the LimitRows field of the MySQLDeleteQuery.
468 | func (q MySQLDeleteQuery) Limit(limit any) MySQLDeleteQuery {
469 | q.LimitRows = limit
470 | return q
471 | }
472 |
473 | // Returning appends fields to the RETURNING clause of the MySQLDeleteQuery.
474 | func (q MySQLDeleteQuery) Returning(fields ...Field) MySQLDeleteQuery {
475 | q.ReturningFields = append(q.ReturningFields, fields...)
476 | return q
477 | }
478 |
479 | // SetFetchableFields implements the Query interface.
480 | func (q MySQLDeleteQuery) SetFetchableFields(fields []Field) (query Query, ok bool) {
481 | return DeleteQuery(q).SetFetchableFields(fields)
482 | }
483 |
484 | // GetFetchableFields returns the fetchable fields of the query.
485 | func (q MySQLDeleteQuery) GetFetchableFields() []Field {
486 | return DeleteQuery(q).GetFetchableFields()
487 | }
488 |
489 | // GetDialect implements the Query interface.
490 | func (q MySQLDeleteQuery) GetDialect() string { return q.Dialect }
491 |
492 | // SetDialect sets the dialect of the query.
493 | func (q MySQLDeleteQuery) SetDialect(dialect string) MySQLDeleteQuery {
494 | q.Dialect = dialect
495 | return q
496 | }
497 |
// SQLServerDeleteQuery represents an SQL Server DELETE query. It has the same
// underlying struct as DeleteQuery, and its WriteSQL and fetchable-field
// methods convert back to DeleteQuery to reuse that logic.
type SQLServerDeleteQuery DeleteQuery

// Compile-time assertion that SQLServerDeleteQuery implements Query.
var _ Query = (*SQLServerDeleteQuery)(nil)
502 |
503 | // WriteSQL implements the SQLWriter interface.
504 | func (q SQLServerDeleteQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
505 | return DeleteQuery(q).WriteSQL(ctx, dialect, buf, args, params)
506 | }
507 |
508 | // DeleteFrom returns a new SQLServerDeleteQuery.
509 | func (b sqlserverQueryBuilder) DeleteFrom(table Table) SQLServerDeleteQuery {
510 | return SQLServerDeleteQuery{
511 | Dialect: DialectSQLServer,
512 | CTEs: b.ctes,
513 | DeleteTable: table,
514 | }
515 | }
516 |
517 | // Delete returns a new SQLServerDeleteQuery.
518 | func (b sqlserverQueryBuilder) Delete(table Table) SQLServerDeleteQuery {
519 | return SQLServerDeleteQuery{
520 | Dialect: DialectSQLServer,
521 | CTEs: b.ctes,
522 | DeleteTables: []Table{table},
523 | }
524 | }
525 |
526 | // From sets the UsingTable of the SQLServerDeleteQuery.
527 | func (q SQLServerDeleteQuery) From(table Table) SQLServerDeleteQuery {
528 | q.UsingTable = table
529 | return q
530 | }
531 |
532 | // Join joins a new Table to the SQLServerDeleteQuery.
533 | func (q SQLServerDeleteQuery) Join(table Table, predicates ...Predicate) SQLServerDeleteQuery {
534 | q.JoinTables = append(q.JoinTables, Join(table, predicates...))
535 | return q
536 | }
537 |
538 | // LeftJoin left joins a new Table to the SQLServerDeleteQuery.
539 | func (q SQLServerDeleteQuery) LeftJoin(table Table, predicates ...Predicate) SQLServerDeleteQuery {
540 | q.JoinTables = append(q.JoinTables, LeftJoin(table, predicates...))
541 | return q
542 | }
543 |
544 | // FullJoin full joins a new Table to the SQLServerDeleteQuery.
545 | func (q SQLServerDeleteQuery) FullJoin(table Table, predicates ...Predicate) SQLServerDeleteQuery {
546 | q.JoinTables = append(q.JoinTables, FullJoin(table, predicates...))
547 | return q
548 | }
549 |
550 | // CrossJoin cross joins a new Table to the SQLServerDeleteQuery.
551 | func (q SQLServerDeleteQuery) CrossJoin(table Table) SQLServerDeleteQuery {
552 | q.JoinTables = append(q.JoinTables, CrossJoin(table))
553 | return q
554 | }
555 |
556 | // CustomJoin joins a new Table to the SQLServerDeleteQuery with a custom join
557 | // operator.
558 | func (q SQLServerDeleteQuery) CustomJoin(joinOperator string, table Table, predicates ...Predicate) SQLServerDeleteQuery {
559 | q.JoinTables = append(q.JoinTables, CustomJoin(joinOperator, table, predicates...))
560 | return q
561 | }
562 |
563 | // Where appends to the WherePredicate field of the SQLServerDeleteQuery.
564 | func (q SQLServerDeleteQuery) Where(predicates ...Predicate) SQLServerDeleteQuery {
565 | q.WherePredicate = appendPredicates(q.WherePredicate, predicates)
566 | return q
567 | }
568 |
569 | // SetFetchableFields implements the Query interface.
570 | func (q SQLServerDeleteQuery) SetFetchableFields(fields []Field) (query Query, ok bool) {
571 | return DeleteQuery(q).SetFetchableFields(fields)
572 | }
573 |
574 | // GetFetchableFields returns the fetchable fields of the query.
575 | func (q SQLServerDeleteQuery) GetFetchableFields() []Field {
576 | return DeleteQuery(q).GetFetchableFields()
577 | }
578 |
579 | // GetDialect implements the Query interface.
580 | func (q SQLServerDeleteQuery) GetDialect() string { return q.Dialect }
581 |
582 | // SetDialect sets the dialect of the query.
583 | func (q SQLServerDeleteQuery) SetDialect(dialect string) SQLServerDeleteQuery {
584 | q.Dialect = dialect
585 | return q
586 | }
587 |
--------------------------------------------------------------------------------
/delete_query_test.go:
--------------------------------------------------------------------------------
1 | package sq
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/bokwoon95/sq/internal/testutil"
7 | )
8 |
9 | func TestSQLiteDeleteQuery(t *testing.T) {
10 | type ACTOR struct {
11 | TableStruct
12 | ACTOR_ID NumberField
13 | FIRST_NAME StringField
14 | LAST_NAME StringField
15 | LAST_UPDATE TimeField
16 | }
17 | a := New[ACTOR]("a")
18 |
19 | t.Run("basic", func(t *testing.T) {
20 | t.Parallel()
21 | q1 := SQLite.DeleteFrom(a).Returning(a.FIRST_NAME).SetDialect("lorem ipsum")
22 | if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" {
23 | t.Error(testutil.Callers(), diff)
24 | }
25 | q1 = q1.SetDialect(DialectSQLite)
26 | fields := q1.GetFetchableFields()
27 | if diff := testutil.Diff(fields, []Field{a.FIRST_NAME}); diff != "" {
28 | t.Error(testutil.Callers(), diff)
29 | }
30 | _, ok := q1.SetFetchableFields([]Field{a.LAST_NAME})
31 | if ok {
32 | t.Fatal(testutil.Callers(), "field should not have been set")
33 | }
34 | q1.ReturningFields = q1.ReturningFields[:0]
35 | _, ok = q1.SetFetchableFields([]Field{a.LAST_NAME})
36 | if !ok {
37 | t.Fatal(testutil.Callers(), "field should have been set")
38 | }
39 | })
40 |
41 | t.Run("Delete Returning", func(t *testing.T) {
42 | t.Parallel()
43 | var tt TestTable
44 | tt.item = SQLite.
45 | With(NewCTE("cte", nil, Queryf("SELECT 1"))).
46 | DeleteFrom(a).
47 | Where(a.ACTOR_ID.EqInt(1)).
48 | Returning(a.FIRST_NAME, a.LAST_NAME)
49 | tt.wantQuery = "WITH cte AS (SELECT 1)" +
50 | " DELETE FROM actor AS a" +
51 | " WHERE a.actor_id = $1" +
52 | " RETURNING a.first_name, a.last_name"
53 | tt.wantArgs = []any{1}
54 | tt.assert(t)
55 | })
56 | }
57 |
58 | func TestPostgresDeleteQuery(t *testing.T) {
59 | type ACTOR struct {
60 | TableStruct
61 | ACTOR_ID NumberField
62 | FIRST_NAME StringField
63 | LAST_NAME StringField
64 | LAST_UPDATE TimeField
65 | }
66 | a := New[ACTOR]("a")
67 |
68 | t.Run("basic", func(t *testing.T) {
69 | t.Parallel()
70 | q1 := Postgres.DeleteFrom(a).Returning(a.FIRST_NAME).SetDialect("lorem ipsum")
71 | if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" {
72 | t.Error(testutil.Callers(), diff)
73 | }
74 | q1 = q1.SetDialect(DialectPostgres)
75 | fields := q1.GetFetchableFields()
76 | if diff := testutil.Diff(fields, []Field{a.FIRST_NAME}); diff != "" {
77 | t.Error(testutil.Callers(), diff)
78 | }
79 | _, ok := q1.SetFetchableFields([]Field{a.LAST_NAME})
80 | if ok {
81 | t.Fatal(testutil.Callers(), "field should not have been set")
82 | }
83 | q1.ReturningFields = q1.ReturningFields[:0]
84 | _, ok = q1.SetFetchableFields([]Field{a.LAST_NAME})
85 | if !ok {
86 | t.Fatal(testutil.Callers(), "field should have been set")
87 | }
88 | })
89 |
90 | t.Run("Delete Returning", func(t *testing.T) {
91 | t.Parallel()
92 | var tt TestTable
93 | tt.item = Postgres.
94 | With(NewCTE("cte", nil, Queryf("SELECT 1"))).
95 | DeleteFrom(a).
96 | Where(a.ACTOR_ID.EqInt(1)).
97 | Returning(a.FIRST_NAME, a.LAST_NAME)
98 | tt.wantQuery = "WITH cte AS (SELECT 1)" +
99 | " DELETE FROM actor AS a" +
100 | " WHERE a.actor_id = $1" +
101 | " RETURNING a.first_name, a.last_name"
102 | tt.wantArgs = []any{1}
103 | tt.assert(t)
104 | })
105 |
106 | t.Run("Join", func(t *testing.T) {
107 | t.Parallel()
108 | var tt TestTable
109 | tt.item = Postgres.
110 | DeleteFrom(a).
111 | Using(a).
112 | Join(a, Expr("1 = 1")).
113 | LeftJoin(a, Expr("1 = 1")).
114 | FullJoin(a, Expr("1 = 1")).
115 | CrossJoin(a).
116 | CustomJoin(",", a).
117 | JoinUsing(a, a.FIRST_NAME, a.LAST_NAME)
118 | tt.wantQuery = "DELETE FROM actor AS a" +
119 | " USING actor AS a" +
120 | " JOIN actor AS a ON 1 = 1" +
121 | " LEFT JOIN actor AS a ON 1 = 1" +
122 | " FULL JOIN actor AS a ON 1 = 1" +
123 | " CROSS JOIN actor AS a" +
124 | " , actor AS a" +
125 | " JOIN actor AS a USING (first_name, last_name)"
126 | tt.assert(t)
127 | })
128 | }
129 |
130 | func TestMySQLDeleteQuery(t *testing.T) {
131 | type ACTOR struct {
132 | TableStruct
133 | ACTOR_ID NumberField
134 | FIRST_NAME StringField
135 | LAST_NAME StringField
136 | LAST_UPDATE TimeField
137 | }
138 | a := New[ACTOR]("")
139 |
140 | t.Run("basic", func(t *testing.T) {
141 | t.Parallel()
142 | q1 := MySQL.DeleteFrom(a).SetDialect("lorem ipsum")
143 | if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" {
144 | t.Error(testutil.Callers(), diff)
145 | }
146 | q1 = q1.SetDialect(DialectMySQL)
147 | fields := q1.GetFetchableFields()
148 | if len(fields) != 0 {
149 | t.Error(testutil.Callers(), "expected 0 fields but got %v", fields)
150 | }
151 | _, ok := q1.SetFetchableFields([]Field{a.LAST_NAME})
152 | if ok {
153 | t.Fatal(testutil.Callers(), "field should not have been set")
154 | }
155 | q1.ReturningFields = q1.ReturningFields[:0]
156 | _, ok = q1.SetFetchableFields([]Field{a.LAST_NAME})
157 | if ok {
158 | t.Fatal(testutil.Callers(), "field should not have been set")
159 | }
160 | })
161 |
162 | t.Run("Where", func(t *testing.T) {
163 | t.Parallel()
164 | var tt TestTable
165 | tt.item = MySQL.
166 | With(NewCTE("cte", nil, Queryf("SELECT 1"))).
167 | DeleteFrom(a).
168 | Where(a.ACTOR_ID.EqInt(1))
169 | tt.wantQuery = "WITH cte AS (SELECT 1)" +
170 | " DELETE FROM actor" +
171 | " WHERE actor.actor_id = ?"
172 | tt.wantArgs = []any{1}
173 | tt.assert(t)
174 | })
175 |
176 | t.Run("OrderBy Limit", func(t *testing.T) {
177 | t.Parallel()
178 | var tt TestTable
179 | tt.item = MySQL.
180 | DeleteFrom(a).
181 | OrderBy(a.ACTOR_ID).
182 | Limit(5)
183 | tt.wantQuery = "DELETE FROM actor" +
184 | " ORDER BY actor.actor_id" +
185 | " LIMIT ?"
186 | tt.wantArgs = []any{5}
187 | tt.assert(t)
188 | })
189 |
190 | t.Run("Delete Returning", func(t *testing.T) {
191 | t.Parallel()
192 | var tt TestTable
193 | tt.item = MySQL.
194 | With(NewCTE("cte", nil, Queryf("SELECT 1"))).
195 | DeleteFrom(a).
196 | Where(a.ACTOR_ID.EqInt(1)).
197 | Returning(a.FIRST_NAME, a.LAST_NAME)
198 | tt.wantQuery = "WITH cte AS (SELECT 1)" +
199 | " DELETE FROM actor" +
200 | " WHERE actor.actor_id = ?" +
201 | " RETURNING actor.first_name, actor.last_name"
202 | tt.wantArgs = []any{1}
203 | tt.assert(t)
204 | })
205 |
206 | t.Run("Join", func(t *testing.T) {
207 | t.Parallel()
208 | var tt TestTable
209 | tt.item = MySQL.
210 | Delete(a).
211 | From(a).
212 | Join(a, Expr("1 = 1")).
213 | LeftJoin(a, Expr("1 = 1")).
214 | FullJoin(a, Expr("1 = 1")).
215 | CrossJoin(a).
216 | CustomJoin(",", a).
217 | JoinUsing(a, a.FIRST_NAME, a.LAST_NAME)
218 | tt.wantQuery = "DELETE actor" +
219 | " FROM actor" +
220 | " JOIN actor ON 1 = 1" +
221 | " LEFT JOIN actor ON 1 = 1" +
222 | " FULL JOIN actor ON 1 = 1" +
223 | " CROSS JOIN actor" +
224 | " , actor" +
225 | " JOIN actor USING (first_name, last_name)"
226 | tt.assert(t)
227 | })
228 | }
229 |
230 | func TestSQLServerDeleteQuery(t *testing.T) {
231 | type ACTOR struct {
232 | TableStruct
233 | ACTOR_ID NumberField
234 | FIRST_NAME StringField
235 | LAST_NAME StringField
236 | LAST_UPDATE TimeField
237 | }
238 | a := New[ACTOR]("")
239 |
240 | t.Run("basic", func(t *testing.T) {
241 | t.Parallel()
242 | q1 := SQLServer.DeleteFrom(a).SetDialect("lorem ipsum")
243 | if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" {
244 | t.Error(testutil.Callers(), diff)
245 | }
246 | q1 = q1.SetDialect(DialectSQLServer)
247 | q1 = q1.SetDialect(DialectMySQL)
248 | fields := q1.GetFetchableFields()
249 | if len(fields) != 0 {
250 | t.Error(testutil.Callers(), "expected 0 fields but got %v", fields)
251 | }
252 | _, ok := q1.SetFetchableFields([]Field{a.LAST_NAME})
253 | if ok {
254 | t.Fatal(testutil.Callers(), "field should not have been set")
255 | }
256 | q1.ReturningFields = q1.ReturningFields[:0]
257 | _, ok = q1.SetFetchableFields([]Field{a.LAST_NAME})
258 | if ok {
259 | t.Fatal(testutil.Callers(), "field should not have been set")
260 | }
261 | })
262 |
263 | t.Run("Where", func(t *testing.T) {
264 | t.Parallel()
265 | var tt TestTable
266 | tt.item = SQLServer.
267 | With(NewCTE("cte", nil, Queryf("SELECT 1"))).
268 | DeleteFrom(a).
269 | Where(a.ACTOR_ID.EqInt(1))
270 | tt.wantQuery = "WITH cte AS (SELECT 1)" +
271 | " DELETE FROM actor" +
272 | " WHERE actor.actor_id = @p1"
273 | tt.wantArgs = []any{1}
274 | tt.assert(t)
275 | })
276 |
277 | t.Run("Join", func(t *testing.T) {
278 | t.Parallel()
279 | var tt TestTable
280 | tt.item = SQLServer.
281 | DeleteFrom(a).
282 | From(a).
283 | Join(a, Expr("1 = 1")).
284 | LeftJoin(a, Expr("1 = 1")).
285 | FullJoin(a, Expr("1 = 1")).
286 | CrossJoin(a).
287 | CustomJoin(",", a)
288 | tt.wantQuery = "DELETE FROM actor" +
289 | " FROM actor" +
290 | " JOIN actor ON 1 = 1" +
291 | " LEFT JOIN actor ON 1 = 1" +
292 | " FULL JOIN actor ON 1 = 1" +
293 | " CROSS JOIN actor" +
294 | " , actor"
295 | tt.assert(t)
296 | })
297 | }
298 |
299 | func TestDeleteQuery(t *testing.T) {
300 | t.Run("basic", func(t *testing.T) {
301 | t.Parallel()
302 | q1 := DeleteQuery{DeleteTable: Expr("tbl"), Dialect: "lorem ipsum"}
303 | if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" {
304 | t.Error(testutil.Callers(), diff)
305 | }
306 | })
307 |
308 | t.Run("PolicyTable", func(t *testing.T) {
309 | t.Parallel()
310 | var tt TestTable
311 | tt.item = DeleteQuery{
312 | DeleteTable: policyTableStub{policy: And(Expr("1 = 1"), Expr("2 = 2"))},
313 | WherePredicate: Expr("3 = 3"),
314 | }
315 | tt.wantQuery = "DELETE FROM policy_table_stub WHERE (1 = 1 AND 2 = 2) AND 3 = 3"
316 | tt.assert(t)
317 | })
318 |
319 | notOKTests := []TestTable{{
320 | description: "nil FromTable not allowed",
321 | item: DeleteQuery{
322 | DeleteTable: nil,
323 | },
324 | }, {
325 | description: "sqlite does not support JOIN",
326 | item: DeleteQuery{
327 | Dialect: DialectSQLite,
328 | DeleteTable: Expr("tbl"),
329 | UsingTable: Expr("tbl"),
330 | JoinTables: []JoinTable{
331 | Join(Expr("tbl"), Expr("1 = 1")),
332 | },
333 | },
334 | }, {
335 | description: "postgres does not allow JOIN without USING",
336 | item: DeleteQuery{
337 | Dialect: DialectPostgres,
338 | DeleteTable: Expr("tbl"),
339 | JoinTables: []JoinTable{
340 | Join(Expr("tbl"), Expr("1 = 1")),
341 | },
342 | },
343 | }, {
344 | description: "dialect does not support ORDER BY",
345 | item: DeleteQuery{
346 | Dialect: DialectPostgres,
347 | DeleteTable: Expr("tbl"),
348 | OrderByFields: Fields{Expr("f1")},
349 | },
350 | }, {
351 | description: "dialect does not support LIMIT",
352 | item: DeleteQuery{
353 | Dialect: DialectPostgres,
354 | DeleteTable: Expr("tbl"),
355 | LimitRows: 5,
356 | },
357 | }}
358 |
359 | for _, tt := range notOKTests {
360 | tt := tt
361 | t.Run(tt.description, func(t *testing.T) {
362 | t.Parallel()
363 | tt.assertNotOK(t)
364 | })
365 | }
366 |
367 | errTests := []TestTable{{
368 | description: "FromTable Policy err",
369 | item: DeleteQuery{
370 | DeleteTable: policyTableStub{err: ErrFaultySQL},
371 | },
372 | }, {
373 | description: "UsingTable Policy err",
374 | item: DeleteQuery{
375 | DeleteTable: Expr("tbl"),
376 | UsingTable: policyTableStub{err: ErrFaultySQL},
377 | },
378 | }, {
379 | description: "JoinTables Policy err",
380 | item: DeleteQuery{
381 | DeleteTable: Expr("tbl"),
382 | UsingTable: Expr("tbl"),
383 | JoinTables: []JoinTable{
384 | Join(policyTableStub{err: ErrFaultySQL}, Expr("1 = 1")),
385 | },
386 | },
387 | }, {
388 | description: "CTEs err",
389 | item: DeleteQuery{
390 | CTEs: []CTE{NewCTE("cte", nil, Queryf("SELECT {}", FaultySQL{}))},
391 | DeleteTable: Expr("tbl"),
392 | },
393 | }, {
394 | description: "FromTable err",
395 | item: DeleteQuery{
396 | DeleteTable: FaultySQL{},
397 | },
398 | }, {
399 | description: "postgres UsingTable err",
400 | item: DeleteQuery{
401 | Dialect: DialectPostgres,
402 | DeleteTable: Expr("tbl"),
403 | UsingTable: FaultySQL{},
404 | },
405 | }, {
406 | description: "sqlserver UsingTable err",
407 | item: DeleteQuery{
408 | Dialect: DialectSQLServer,
409 | DeleteTable: Expr("tbl"),
410 | UsingTable: FaultySQL{},
411 | },
412 | }, {
413 | description: "JoinTables err",
414 | item: DeleteQuery{
415 | Dialect: DialectPostgres,
416 | DeleteTable: Expr("tbl"),
417 | UsingTable: Expr("tbl"),
418 | JoinTables: []JoinTable{
419 | Join(Expr("tbl"), FaultySQL{}),
420 | },
421 | },
422 | }, {
423 | description: "WherePredicate Variadic err",
424 | item: DeleteQuery{
425 | DeleteTable: Expr("tbl"),
426 | WherePredicate: And(FaultySQL{}),
427 | },
428 | }, {
429 | description: "WherePredicate err",
430 | item: DeleteQuery{
431 | DeleteTable: Expr("tbl"),
432 | WherePredicate: FaultySQL{},
433 | },
434 | }, {
435 | description: "OrderByFields err",
436 | item: DeleteQuery{
437 | Dialect: DialectMySQL,
438 | DeleteTable: Expr("tbl"),
439 | OrderByFields: Fields{FaultySQL{}},
440 | },
441 | }, {
442 | description: "LimitRows err",
443 | item: DeleteQuery{
444 | Dialect: DialectMySQL,
445 | DeleteTable: Expr("tbl"),
446 | OrderByFields: Fields{Expr("f1")},
447 | LimitRows: FaultySQL{},
448 | },
449 | }, {
450 | description: "ReturningFields err",
451 | item: DeleteQuery{
452 | Dialect: DialectPostgres,
453 | DeleteTable: Expr("tbl"),
454 | ReturningFields: Fields{FaultySQL{}},
455 | },
456 | }}
457 |
458 | for _, tt := range errTests {
459 | tt := tt
460 | t.Run(tt.description, func(t *testing.T) {
461 | t.Parallel()
462 | tt.assertErr(t, ErrFaultySQL)
463 | })
464 | }
465 | }
466 |
--------------------------------------------------------------------------------
/fetch_exec_test.go:
--------------------------------------------------------------------------------
1 | package sq
2 |
3 | import (
4 | "database/sql"
5 | "testing"
6 | "time"
7 |
8 | "github.com/bokwoon95/sq/internal/testutil"
9 | _ "github.com/mattn/go-sqlite3"
10 | )
11 |
// ACTOR is the shared actor-table handle used by the fetch/exec tests. The
// `sq:"actor"` tag on the embedded TableStruct presumably supplies the table
// name (TODO confirm against New). It is created with an empty alias, and the
// row mappers below address its columns as "actor.<column>".
var ACTOR = New[struct {
	TableStruct `sq:"actor"`
	ACTOR_ID    NumberField
	FIRST_NAME  StringField
	LAST_NAME   StringField
	LAST_UPDATE TimeField
}]("")
19 |
// Actor is the Go value that actorRowMapper and actorRowMapperRawSQL fill in
// from a single actor row.
type Actor struct {
	ActorID    int
	FirstName  string
	LastName   string
	LastUpdate time.Time
}
26 |
27 | func actorRowMapper(row *Row) Actor {
28 | var actor Actor
29 | actorID, _ := row.Value("actor.actor_id").(int64)
30 | actor.ActorID = int(actorID)
31 | actor.FirstName = row.StringField(ACTOR.FIRST_NAME)
32 | actor.LastName = row.StringField(ACTOR.LAST_NAME)
33 | actor.LastUpdate, _ = row.Value("actor.last_update").(time.Time)
34 | return actor
35 | }
36 |
37 | func actorRowMapperRawSQL(row *Row) Actor {
38 | result := make(map[string]any)
39 | values := row.Values()
40 | for i, column := range row.Columns() {
41 | result[column] = values[i]
42 | }
43 | var actor Actor
44 | actorID, _ := result["actor_id"].(int64)
45 | actor.ActorID = int(actorID)
46 | actor.FirstName, _ = result["first_name"].(string)
47 | actor.LastName, _ = result["last_name"].(string)
48 | actor.LastUpdate, _ = result["last_update"].(time.Time)
49 | return actor
50 | }
51 |
52 | func Test_substituteParams(t *testing.T) {
53 | t.Run("no params provided", func(t *testing.T) {
54 | t.Parallel()
55 | args := []any{1, 2, 3}
56 | params := map[string][]int{"one": {0}, "two": {1}, "three": {2}}
57 | gotArgs, err := substituteParams("", args, params, nil)
58 | if err != nil {
59 | t.Fatal(testutil.Callers(), err)
60 | }
61 | wantArgs := []any{1, 2, 3}
62 | if diff := testutil.Diff(gotArgs, wantArgs); diff != "" {
63 | t.Error(testutil.Callers(), diff)
64 | }
65 | })
66 |
67 | t.Run("not all params provided", func(t *testing.T) {
68 | t.Parallel()
69 | args := []any{1, 2, 3}
70 | params := map[string][]int{"one": {0}, "two": {1}, "three": {2}}
71 | paramValues := Params{"one": "One", "two": "Two"}
72 | gotArgs, err := substituteParams("", args, params, paramValues)
73 | if err != nil {
74 | t.Fatal(testutil.Callers(), err)
75 | }
76 | wantArgs := []any{"One", "Two", 3}
77 | if diff := testutil.Diff(gotArgs, wantArgs); diff != "" {
78 | t.Error(testutil.Callers(), diff)
79 | }
80 | })
81 |
82 | t.Run("params substituted", func(t *testing.T) {
83 | t.Parallel()
84 | type Data struct {
85 | id int
86 | name string
87 | }
88 | args := []any{
89 | 0,
90 | sql.Named("one", 1),
91 | sql.Named("two", 2),
92 | 3,
93 | }
94 | params := map[string][]int{
95 | "zero": {0},
96 | "one": {1},
97 | "two": {2},
98 | "three": {3},
99 | }
100 | paramValues := Params{
101 | "one": "[one]",
102 | "two": "[two]",
103 | "three": "[three]",
104 | }
105 | wantArgs := []any{
106 | 0,
107 | sql.Named("one", "[one]"),
108 | sql.Named("two", "[two]"),
109 | "[three]",
110 | }
111 | gotArgs, err := substituteParams("", args, params, paramValues)
112 | if err != nil {
113 | t.Fatal(testutil.Callers(), err)
114 | }
115 | if diff := testutil.Diff(gotArgs, wantArgs); diff != "" {
116 | t.Error(testutil.Callers(), diff)
117 | }
118 | })
119 | }
120 |
// Test_getFieldMappings checks the numbered "field => destination type"
// listing that getFieldMappings renders for a set of fields and scan targets.
func Test_getFieldMappings(t *testing.T) {
	type testCase struct {
		description       string
		dialect           string
		fields            []Field
		scanDest          []any
		wantFieldMappings string
	}

	cases := []testCase{
		{
			description:       "empty",
			wantFieldMappings: "",
		},
		{
			description: "basic",
			fields: []Field{
				Expr("actor_id"),
				Expr("first_name || {} || last_name", " "),
				Expr("last_update"),
			},
			scanDest: []any{
				&sql.NullInt64{},
				&sql.NullString{},
				&sql.NullTime{},
			},
			wantFieldMappings: "" +
				"\n 01. actor_id => *sql.NullInt64" +
				"\n 02. first_name || ' ' || last_name => *sql.NullString" +
				"\n 03. last_update => *sql.NullTime",
		},
	}

	for _, tc := range cases {
		tc := tc // each parallel subtest needs its own copy of the loop variable
		t.Run(tc.description, func(t *testing.T) {
			t.Parallel()
			got := getFieldMappings(tc.dialect, tc.fields, tc.scanDest)
			if diff := testutil.Diff(got, tc.wantFieldMappings); diff != "" {
				t.Error(testutil.Callers(), diff)
			}
		})
	}
}
162 |
// TestFetchExec exercises the one-shot Exec, FetchOne and FetchAll helpers
// against an in-memory SQLite database. It covers both the query builder and
// raw SQL (Queryf), and compares every result with referenceActors.
func TestFetchExec(t *testing.T) {
	t.Parallel()
	db := newDB(t)

	var referenceActors = []Actor{
		{ActorID: 1, FirstName: "PENELOPE", LastName: "GUINESS", LastUpdate: time.Unix(1, 0).UTC()},
		{ActorID: 2, FirstName: "NICK", LastName: "WAHLBERG", LastUpdate: time.Unix(1, 0).UTC()},
		{ActorID: 3, FirstName: "ED", LastName: "CHASE", LastUpdate: time.Unix(1, 0).UTC()},
		{ActorID: 4, FirstName: "JENNIFER", LastName: "DAVIS", LastUpdate: time.Unix(1, 0).UTC()},
		{ActorID: 5, FirstName: "JOHNNY", LastName: "LOLLOBRIGIDA", LastUpdate: time.Unix(1, 0).UTC()},
	}

	// Exec.
	res, err := Exec(Log(db), SQLite.
		InsertInto(ACTOR).
		ColumnValues(func(col *Column) {
			for _, actor := range referenceActors {
				col.SetInt(ACTOR.ACTOR_ID, actor.ActorID)
				col.SetString(ACTOR.FIRST_NAME, actor.FirstName)
				col.SetString(ACTOR.LAST_NAME, actor.LastName)
				col.SetTime(ACTOR.LAST_UPDATE, actor.LastUpdate)
			}
		}),
	)
	if err != nil {
		t.Fatal(testutil.Callers(), err)
	}
	if diff := testutil.Diff(res.RowsAffected, int64(len(referenceActors))); diff != "" {
		t.Fatal(testutil.Callers(), diff)
	}

	// FetchOne.
	actor, err := FetchOne(Log(db), SQLite.
		From(ACTOR).
		Where(ACTOR.ACTOR_ID.EqInt(1)),
		actorRowMapper,
	)
	if err != nil {
		t.Fatal(testutil.Callers(), err)
	}
	if diff := testutil.Diff(actor, referenceActors[0]); diff != "" {
		t.Fatal(testutil.Callers(), diff)
	}

	// FetchOne (Raw SQL).
	actor, err = FetchOne(Log(db),
		SQLite.Queryf("SELECT * FROM actor WHERE actor_id = {}", 1),
		actorRowMapperRawSQL,
	)
	if err != nil {
		t.Fatal(testutil.Callers(), err)
	}
	if diff := testutil.Diff(actor, referenceActors[0]); diff != "" {
		t.Fatal(testutil.Callers(), diff)
	}

	// FetchAll.
	actors, err := FetchAll(VerboseLog(db), SQLite.
		From(ACTOR).
		OrderBy(ACTOR.ACTOR_ID),
		actorRowMapper,
	)
	if err != nil {
		t.Fatal(testutil.Callers(), err)
	}
	if diff := testutil.Diff(actors, referenceActors); diff != "" {
		// Report the diff: err is necessarily nil here.
		t.Fatal(testutil.Callers(), diff)
	}

	// FetchAll (RawSQL).
	actors, err = FetchAll(VerboseLog(db),
		SQLite.Queryf("SELECT * FROM actor ORDER BY actor_id"),
		actorRowMapperRawSQL,
	)
	if err != nil {
		t.Fatal(testutil.Callers(), err)
	}
	if diff := testutil.Diff(actors, referenceActors); diff != "" {
		// Report the diff: err is necessarily nil here.
		t.Fatal(testutil.Callers(), diff)
	}
}
244 |
245 | func TestCompiledFetchExec(t *testing.T) {
246 | t.Parallel()
247 | db := newDB(t)
248 | var referenceActors = []Actor{
249 | {ActorID: 1, FirstName: "PENELOPE", LastName: "GUINESS", LastUpdate: time.Unix(1, 0).UTC()},
250 | {ActorID: 2, FirstName: "NICK", LastName: "WAHLBERG", LastUpdate: time.Unix(1, 0).UTC()},
251 | {ActorID: 3, FirstName: "ED", LastName: "CHASE", LastUpdate: time.Unix(1, 0).UTC()},
252 | {ActorID: 4, FirstName: "JENNIFER", LastName: "DAVIS", LastUpdate: time.Unix(1, 0).UTC()},
253 | {ActorID: 5, FirstName: "JOHNNY", LastName: "LOLLOBRIGIDA", LastUpdate: time.Unix(1, 0).UTC()},
254 | }
255 |
256 | // CompiledExec.
257 | compiledExec, err := CompileExec(SQLite.
258 | InsertInto(ACTOR).
259 | ColumnValues(func(col *Column) {
260 | col.Set(ACTOR.ACTOR_ID, IntParam("actor_id", 0))
261 | col.Set(ACTOR.FIRST_NAME, StringParam("first_name", ""))
262 | col.Set(ACTOR.LAST_NAME, StringParam("last_name", ""))
263 | col.Set(ACTOR.LAST_UPDATE, TimeParam("last_update", time.Time{}))
264 | }),
265 | )
266 | if err != nil {
267 | t.Fatal(testutil.Callers(), err)
268 | }
269 | for _, actor := range referenceActors {
270 | _, err = compiledExec.Exec(Log(db), Params{
271 | "actor_id": actor.ActorID,
272 | "first_name": actor.FirstName,
273 | "last_name": actor.LastName,
274 | "last_update": actor.LastUpdate,
275 | })
276 | if err != nil {
277 | t.Fatal(testutil.Callers(), err)
278 | }
279 | }
280 |
281 | // CompiledFetch FetchOne.
282 | compiledFetch, err := CompileFetch(SQLite.
283 | From(ACTOR).
284 | Where(ACTOR.ACTOR_ID.Eq(IntParam("actor_id", 0))),
285 | actorRowMapper,
286 | )
287 | if err != nil {
288 | t.Fatal(testutil.Callers(), err)
289 | }
290 | actor, err := compiledFetch.FetchOne(Log(db), Params{"actor_id": 1})
291 | if err != nil {
292 | t.Fatal(testutil.Callers(), err)
293 | }
294 | if diff := testutil.Diff(actor, referenceActors[0]); diff != "" {
295 | t.Fatal(testutil.Callers(), diff)
296 | }
297 |
298 | // CompiledFetch FetchOne (Raw SQL).
299 | compiledFetch, err = CompileFetch(
300 | SQLite.Queryf("SELECT * FROM actor WHERE actor_id = {actor_id}", IntParam("actor_id", 0)),
301 | actorRowMapperRawSQL,
302 | )
303 | if err != nil {
304 | t.Fatal(testutil.Callers(), err)
305 | }
306 | actor, err = compiledFetch.FetchOne(Log(db), Params{"actor_id": 1})
307 | if err != nil {
308 | t.Fatal(testutil.Callers(), err)
309 | }
310 | if diff := testutil.Diff(actor, referenceActors[0]); diff != "" {
311 | t.Fatal(testutil.Callers(), diff)
312 | }
313 |
314 | // CompiledFetch FetchAll.
315 | compiledFetch, err = CompileFetch(SQLite.
316 | From(ACTOR).
317 | OrderBy(ACTOR.ACTOR_ID),
318 | actorRowMapper,
319 | )
320 | if err != nil {
321 | t.Fatal(testutil.Callers(), err)
322 | }
323 | actors, err := compiledFetch.FetchAll(VerboseLog(db), nil)
324 | if err != nil {
325 | t.Fatal(testutil.Callers(), err)
326 | }
327 | if diff := testutil.Diff(actors, referenceActors); diff != "" {
328 | t.Fatal(testutil.Callers(), diff)
329 | }
330 |
331 | // CompiledFetch FetchAll (Raw SQL).
332 | compiledFetch, err = CompileFetch(
333 | SQLite.Queryf("SELECT * FROM actor ORDER BY actor_id"),
334 | actorRowMapperRawSQL,
335 | )
336 | if err != nil {
337 | t.Fatal(testutil.Callers(), err)
338 | }
339 | actors, err = compiledFetch.FetchAll(VerboseLog(db), nil)
340 | if err != nil {
341 | t.Fatal(testutil.Callers(), err)
342 | }
343 | if diff := testutil.Diff(actors, referenceActors); diff != "" {
344 | t.Fatal(testutil.Callers(), diff)
345 | }
346 | }
347 |
// TestPreparedFetchExec prepares an INSERT and several SELECTs once against an
// in-memory SQLite database. It then executes them repeatedly with different
// Params, covering both query-builder and raw-SQL (Queryf) forms. Every
// prepared statement is closed when the test returns.
func TestPreparedFetchExec(t *testing.T) {
	t.Parallel()
	db := newDB(t)

	// closeStmt closes a prepared statement and reports any error without
	// aborting. It is deferred right after each successful Prepare* call.
	// defer evaluates its argument immediately, so reassigning preparedFetch
	// later does not change which statement each deferred call closes.
	closeStmt := func(stmt interface{ Close() error }) {
		if err := stmt.Close(); err != nil {
			t.Error(testutil.Callers(), err)
		}
	}

	var referenceActors = []Actor{
		{ActorID: 1, FirstName: "PENELOPE", LastName: "GUINESS", LastUpdate: time.Unix(1, 0).UTC()},
		{ActorID: 2, FirstName: "NICK", LastName: "WAHLBERG", LastUpdate: time.Unix(1, 0).UTC()},
		{ActorID: 3, FirstName: "ED", LastName: "CHASE", LastUpdate: time.Unix(1, 0).UTC()},
		{ActorID: 4, FirstName: "JENNIFER", LastName: "DAVIS", LastUpdate: time.Unix(1, 0).UTC()},
		{ActorID: 5, FirstName: "JOHNNY", LastName: "LOLLOBRIGIDA", LastUpdate: time.Unix(1, 0).UTC()},
	}

	// PreparedExec.
	preparedExec, err := PrepareExec(Log(db), SQLite.
		InsertInto(ACTOR).
		ColumnValues(func(col *Column) {
			col.Set(ACTOR.ACTOR_ID, IntParam("actor_id", 0))
			col.Set(ACTOR.FIRST_NAME, StringParam("first_name", ""))
			col.Set(ACTOR.LAST_NAME, StringParam("last_name", ""))
			col.Set(ACTOR.LAST_UPDATE, TimeParam("last_update", time.Time{}))
		}),
	)
	if err != nil {
		t.Fatal(testutil.Callers(), err)
	}
	defer closeStmt(preparedExec)
	for _, actor := range referenceActors {
		_, err = preparedExec.Exec(Params{
			"actor_id":    actor.ActorID,
			"first_name":  actor.FirstName,
			"last_name":   actor.LastName,
			"last_update": actor.LastUpdate,
		})
		if err != nil {
			t.Fatal(testutil.Callers(), err)
		}
	}

	// PreparedFetch FetchOne.
	preparedFetch, err := PrepareFetch(Log(db), SQLite.
		From(ACTOR).
		Where(ACTOR.ACTOR_ID.Eq(IntParam("actor_id", 0))),
		actorRowMapper,
	)
	if err != nil {
		t.Fatal(testutil.Callers(), err)
	}
	defer closeStmt(preparedFetch)
	actor, err := preparedFetch.FetchOne(Params{"actor_id": 1})
	if err != nil {
		t.Fatal(testutil.Callers(), err)
	}
	if diff := testutil.Diff(actor, referenceActors[0]); diff != "" {
		t.Fatal(testutil.Callers(), diff)
	}

	// PreparedFetch FetchOne (Raw SQL).
	preparedFetch, err = PrepareFetch(
		Log(db),
		SQLite.Queryf("SELECT * FROM actor WHERE actor_id = {actor_id}", IntParam("actor_id", 0)),
		actorRowMapperRawSQL,
	)
	if err != nil {
		t.Fatal(testutil.Callers(), err)
	}
	defer closeStmt(preparedFetch)
	actor, err = preparedFetch.FetchOne(Params{"actor_id": 1})
	if err != nil {
		t.Fatal(testutil.Callers(), err)
	}
	if diff := testutil.Diff(actor, referenceActors[0]); diff != "" {
		t.Fatal(testutil.Callers(), diff)
	}

	// PreparedFetch FetchAll.
	preparedFetch, err = PrepareFetch(VerboseLog(db), SQLite.
		From(ACTOR).
		OrderBy(ACTOR.ACTOR_ID),
		actorRowMapper,
	)
	if err != nil {
		t.Fatal(testutil.Callers(), err)
	}
	defer closeStmt(preparedFetch)
	actors, err := preparedFetch.FetchAll(nil)
	if err != nil {
		t.Fatal(testutil.Callers(), err)
	}
	if diff := testutil.Diff(actors, referenceActors); diff != "" {
		t.Fatal(testutil.Callers(), diff)
	}

	// PreparedFetch FetchAll (Raw SQL).
	preparedFetch, err = PrepareFetch(
		VerboseLog(db),
		SQLite.Queryf("SELECT * FROM actor ORDER BY actor_id"),
		actorRowMapperRawSQL,
	)
	if err != nil {
		t.Fatal(testutil.Callers(), err)
	}
	defer closeStmt(preparedFetch)
	actors, err = preparedFetch.FetchAll(nil)
	if err != nil {
		t.Fatal(testutil.Callers(), err)
	}
	if diff := testutil.Diff(actors, referenceActors); diff != "" {
		t.Fatal(testutil.Callers(), diff)
	}
}
453 |
// newDB opens a fresh in-memory SQLite database that contains an empty actor
// table. The database is closed automatically when the calling test, including
// its subtests, finishes. Any setup failure aborts the test.
func newDB(t *testing.T) *sql.DB {
	t.Helper()
	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		t.Fatal(testutil.Callers(), err)
	}
	// Every new connection to ":memory:" gets its own private, empty database.
	// Pinning the pool to a single connection guarantees that every query sees
	// the table created below.
	db.SetMaxOpenConns(1)
	t.Cleanup(func() {
		if err := db.Close(); err != nil {
			t.Error(testutil.Callers(), err)
		}
	})
	_, err = db.Exec(`CREATE TABLE actor (
actor_id INTEGER PRIMARY KEY AUTOINCREMENT
,first_name TEXT NOT NULL
,last_name TEXT NOT NULL
,last_update DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
)`)
	if err != nil {
		t.Fatal(testutil.Callers(), err)
	}
	return db
}
470 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/bokwoon95/sq
2 |
3 | go 1.19
4 |
5 | require (
6 | github.com/denisenkom/go-mssqldb v0.12.3
7 | github.com/go-sql-driver/mysql v1.7.1
8 | github.com/google/go-cmp v0.5.9
9 | github.com/google/uuid v1.3.0
10 | github.com/lib/pq v1.10.9
11 | github.com/mattn/go-sqlite3 v1.14.16
12 | )
13 |
14 | require (
15 | github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
16 | github.com/golang-sql/sqlexp v0.1.0 // indirect
17 | golang.org/x/crypto v0.9.0 // indirect
18 | )
19 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/Azure/azure-sdk-for-go/sdk/azcore v0.19.0/go.mod h1:h6H6c8enJmmocHUbLiiGY6sx7f9i+X3m1CHdd5c6Rdw=
2 | github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.11.0/go.mod h1:HcM1YX14R7CJcghJGOYCgdezslRSVzqwLf/q+4Y2r/0=
3 | github.com/Azure/azure-sdk-for-go/sdk/internal v0.7.0/go.mod h1:yqy467j36fJxcRV2TzfVZ1pCb5vxm4BtZPUdYWe/Xo8=
4 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
5 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
6 | github.com/denisenkom/go-mssqldb v0.12.3 h1:pBSGx9Tq67pBOTLmxNuirNTeB8Vjmf886Kx+8Y+8shw=
7 | github.com/denisenkom/go-mssqldb v0.12.3/go.mod h1:k0mtMFOnU+AihqFxPMiF05rtiDrorD1Vrm1KEz5hxDo=
8 | github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ=
9 | github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI=
10 | github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
11 | github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
12 | github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
13 | github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
14 | github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
15 | github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI=
16 | github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
17 | github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
18 | github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
19 | github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
20 | github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
21 | github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
22 | github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y=
23 | github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
24 | github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8=
25 | github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4/go.mod h1:4OwLy04Bl9Ef3GJJCoec+30X3LQs/0/m4HFRt/2LUSA=
26 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
27 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
28 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
29 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
30 | golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
31 | golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
32 | golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
33 | golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
34 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
35 | golang.org/x/net v0.0.0-20210610132358-84b48f89b13b/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
36 | golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
37 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
38 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
39 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
40 | golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
41 | golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
42 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
43 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
44 | golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
45 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
46 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
47 | gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
48 | gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
49 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
50 | gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
51 |
--------------------------------------------------------------------------------
/header.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bokwoon95/sq/eae3b0c03361f5b98ac6d0701c6aa71c94d4e4c2/header.png
--------------------------------------------------------------------------------
/internal/googleuuid/googleuuid.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2009,2014 Google Inc. All rights reserved.
2 | //
3 | // Redistribution and use in source and binary forms, with or without
4 | // modification, are permitted provided that the following conditions are
5 | // met:
6 | //
7 | // * Redistributions of source code must retain the above copyright
8 | // notice, this list of conditions and the following disclaimer.
9 | // * Redistributions in binary form must reproduce the above
10 | // copyright notice, this list of conditions and the following disclaimer
11 | // in the documentation and/or other materials provided with the
12 | // distribution.
13 | // * Neither the name of Google Inc. nor the names of its
14 | // contributors may be used to endorse or promote products derived from
15 | // this software without specific prior written permission.
16 | //
17 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
18 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
19 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
20 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
21 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
22 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
23 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
24 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
25 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
26 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 |
29 | package googleuuid
30 |
31 | import (
32 | "bytes"
33 | "encoding/hex"
34 | "errors"
35 | "fmt"
36 | "strings"
37 | )
38 |
// ParseBytes decodes b into a UUID or returns an error. The accepted forms are
//
//	xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
//	urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx (prefix is case-insensitive)
//	{xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx}
//	xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
//
// where each x is an upper- or lower-case hexadecimal digit.
func ParseBytes(b []byte) (uuid [16]byte, err error) {
	switch len(b) {
	case 36: // xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
	case 36 + 9: // urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
		if !bytes.Equal(bytes.ToLower(b[:9]), []byte("urn:uuid:")) {
			return uuid, fmt.Errorf("invalid urn prefix: %q", b[:9])
		}
		b = b[9:]
	case 36 + 2: // {xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx}
		// The outer bytes must actually be braces. Without this check, any two
		// bytes wrapped around a canonical UUID would be accepted.
		if b[0] != '{' || b[37] != '}' {
			return uuid, errors.New("invalid UUID format")
		}
		b = b[1:37]
	case 32: // xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
		var ok bool
		for i := 0; i < 32; i += 2 {
			uuid[i/2], ok = xtob(b[i], b[i+1])
			if !ok {
				return uuid, errors.New("invalid UUID format")
			}
		}
		return uuid, nil
	default:
		return uuid, fmt.Errorf("invalid UUID length: %d", len(b))
	}
	// b is now exactly 36 bytes long and must be of the form
	// xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx.
	if b[8] != '-' || b[13] != '-' || b[18] != '-' || b[23] != '-' {
		return uuid, errors.New("invalid UUID format")
	}
	for i, x := range [16]int{
		0, 2, 4, 6,
		9, 11,
		14, 16,
		19, 21,
		24, 26, 28, 30, 32, 34,
	} {
		v, ok := xtob(b[x], b[x+1])
		if !ok {
			return uuid, errors.New("invalid UUID format")
		}
		uuid[i] = v
	}
	return uuid, nil
}
84 |
// Parse decodes s into a UUID or returns an error. It accepts two forms: the
// canonical xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx, and the same form prefixed
// with "urn:uuid:". The prefix is matched case-insensitively, and hex digits
// may be upper or lower case. Unlike ParseBytes, the braced and bare 32-digit
// forms are rejected.
func Parse(s string) (uuid [16]byte, err error) {
	switch len(s) {
	case 36:
		// Canonical form; nothing to strip.
	case 36 + 9:
		if prefix := s[:9]; strings.ToLower(prefix) != "urn:uuid:" {
			return uuid, fmt.Errorf("invalid urn prefix: %q", prefix)
		}
		s = s[9:]
	default:
		return uuid, fmt.Errorf("invalid UUID length: %d", len(s))
	}
	for _, dash := range [...]int{8, 13, 18, 23} {
		if s[dash] != '-' {
			return uuid, errors.New("invalid UUID format")
		}
	}
	pos := 0
	for i := range uuid {
		switch pos {
		case 8, 13, 18, 23:
			pos++ // step over a group separator that was validated above
		}
		v, ok := xtob(s[pos], s[pos+1])
		if !ok {
			return uuid, errors.New("invalid UUID format")
		}
		uuid[i] = v
		pos += 2
	}
	return uuid, nil
}
112 | }
113 |
// xvalues maps every byte to its value as a hexadecimal digit. '0'-'9' map to
// 0-9, 'a'-'f' and 'A'-'F' map to 10-15, and every other byte maps to 255,
// which marks "not a hex digit".
var xvalues = func() [256]byte {
	var table [256]byte
	for i := range table {
		table[i] = 255
	}
	for d := byte(0); d < 10; d++ {
		table['0'+d] = d
	}
	for d := byte(0); d < 6; d++ {
		table['a'+d] = 10 + d
		table['A'+d] = 10 + d
	}
	return table
}()
133 |
134 | // xtob converts hex characters x1 and x2 into a byte.
135 | func xtob(x1, x2 byte) (byte, bool) {
136 | b1 := xvalues[x1]
137 | b2 := xvalues[x2]
138 | return (b1 << 4) | b2, b1 != 255 && b2 != 255
139 | }
140 |
// EncodeHex writes the canonical 36-byte text form of uuid
// (xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx, lower-case hex) into dst. dst must be
// at least 36 bytes long; a shorter dst causes a panic. Typical use:
//
//	var buf [36]byte
//	EncodeHex(buf[:], uuid)
//	s := string(buf[:])
func EncodeHex(dst []byte, uuid [16]byte) {
	// Offsets in uuid where each dash-separated group ends (8-4-4-4-12 hex digits).
	groupEnds := [...]int{4, 6, 8, 10, 16}
	src, out := 0, 0
	for g, end := range groupEnds {
		if g > 0 {
			dst[out] = '-'
			out++
		}
		out += hex.Encode(dst[out:], uuid[src:end])
		src = end
	}
}
153 |
--------------------------------------------------------------------------------
/internal/testutil/testutil.go:
--------------------------------------------------------------------------------
1 | package testutil
2 |
3 | import (
4 | "path/filepath"
5 | "reflect"
6 | "runtime"
7 | "strconv"
8 | "strings"
9 |
10 | "github.com/google/go-cmp/cmp"
11 | "github.com/google/go-cmp/cmp/cmpopts"
12 | )
13 |
// Diff compares got with want. It returns "" when they are equal; otherwise it
// returns a human-readable diff headed by a "-got +want" legend. Unexported
// fields of every type are compared, and nil and empty slices/maps count as
// equal.
func Diff[T any](got, want T) string {
	options := []cmp.Option{
		cmp.Exporter(func(reflect.Type) bool { return true }),
		cmpopts.EquateEmpty(),
	}
	diff := cmp.Diff(got, want, options...)
	if diff == "" {
		return ""
	}
	return "\n-got +want\n" + diff
}
25 |
// Callers returns the chain of call sites leading to the caller of Callers, in
// the form "\n[outer.go:12 -> inner.go:34]" with the outermost site first.
// Frames from testing.tRunner (or the goroutine entry point) outward are
// omitted. It returns "" when at most one call site remains, which is the case
// when Callers is invoked directly from the test function.
func Callers() string {
	var pc [50]uintptr
	n := runtime.Callers(2, pc[:]) // skip runtime.Callers + Callers
	if n == 0 {
		return ""
	}
	callsites := make([]string, 0, n)
	frames := runtime.CallersFrames(pc[:n])
	for {
		// Next returns a valid frame even when more is false, so the frame
		// must be recorded before the loop checks more.
		frame, more := frames.Next()
		// The test harness and goroutine entry frames are noise when locating
		// a failing assertion.
		if frame.Function == "testing.tRunner" || frame.Function == "runtime.goexit" {
			break
		}
		callsites = append(callsites, frame.File+":"+strconv.Itoa(frame.Line))
		if !more {
			break
		}
	}
	// A lone call site adds nothing beyond the file:line that t.Error/t.Fatal
	// already print.
	if len(callsites) <= 1 {
		return ""
	}
	var b strings.Builder
	b.WriteString("\n[")
	for i := len(callsites) - 1; i >= 0; i-- {
		if i < len(callsites)-1 {
			b.WriteString(" -> ")
		}
		b.WriteString(filepath.Base(callsites[i]))
	}
	b.WriteString("]")
	return b.String()
}
49 |
--------------------------------------------------------------------------------
/joins.go:
--------------------------------------------------------------------------------
1 | package sq
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "fmt"
7 | )
8 |
// Join operators accepted by CustomJoin and by JoinTable.JoinOperator.
const (
	// JoinInner is a plain (inner) join. An empty JoinTable.JoinOperator is
	// treated as JoinInner.
	JoinInner = "JOIN"
	// JoinLeft is a left outer join.
	JoinLeft = "LEFT JOIN"
	// JoinRight is a right outer join. JoinTable.WriteSQL rejects it for SQLite.
	JoinRight = "RIGHT JOIN"
	// JoinFull is a full outer join. JoinTable.WriteSQL rejects it for SQLite.
	JoinFull = "FULL JOIN"
	// JoinCross is a cross join. Unlike the operators above, it is allowed
	// without any predicate.
	JoinCross = "CROSS JOIN"
)
17 |
// JoinTable represents a join on a table.
//
// WriteSQL renders OnPredicate as an ON clause when it is set. Otherwise,
// UsingFields (if any) are rendered as a USING (...) clause.
type JoinTable struct {
	// JoinOperator is the join keyword, e.g. JoinInner or JoinLeft. An empty
	// value is treated as JoinInner.
	JoinOperator string
	// Table is the table or subquery being joined. It must not be nil.
	Table Table
	// OnPredicate is the ON condition. It takes precedence over UsingFields.
	OnPredicate Predicate
	// UsingFields lists the columns of a USING (...) clause.
	UsingFields []Field
}
25 |
26 | // JoinUsing creates a new JoinTable with the USING operator.
27 | func JoinUsing(table Table, fields ...Field) JoinTable {
28 | return JoinTable{JoinOperator: JoinInner, Table: table, UsingFields: fields}
29 | }
30 |
31 | // Join creates a new JoinTable with the JOIN operator.
32 | func Join(table Table, predicates ...Predicate) JoinTable {
33 | return CustomJoin(JoinInner, table, predicates...)
34 | }
35 |
36 | // LeftJoin creates a new JoinTable with the LEFT JOIN operator.
37 | func LeftJoin(table Table, predicates ...Predicate) JoinTable {
38 | return CustomJoin(JoinLeft, table, predicates...)
39 | }
40 |
41 | // FullJoin creates a new JoinTable with the FULL JOIN operator.
42 | func FullJoin(table Table, predicates ...Predicate) JoinTable {
43 | return CustomJoin(JoinFull, table, predicates...)
44 | }
45 |
46 | // CrossJoin creates a new JoinTable with the CROSS JOIN operator.
47 | func CrossJoin(table Table) JoinTable {
48 | return CustomJoin(JoinCross, table)
49 | }
50 |
51 | // CustomJoin creates a new JoinTable with the a custom join operator.
52 | func CustomJoin(joinOperator string, table Table, predicates ...Predicate) JoinTable {
53 | switch len(predicates) {
54 | case 0:
55 | return JoinTable{JoinOperator: joinOperator, Table: table}
56 | case 1:
57 | return JoinTable{JoinOperator: joinOperator, Table: table, OnPredicate: predicates[0]}
58 | default:
59 | return JoinTable{JoinOperator: joinOperator, Table: table, OnPredicate: And(predicates...)}
60 | }
61 | }
62 |
// WriteSQL implements the SQLWriter interface.
//
// It renders the join as
//
//	<JoinOperator> <table> [AS <alias>] [ON <predicate> | USING (<fields>)]
//
// and wraps the table in parentheses when it is a Query (a subquery).
// It returns an error when:
//   - an inner/left/right/full join has neither an ON predicate nor USING
//     fields (SQLite is exempt from this check);
//   - the dialect is SQLite and the operator is RIGHT JOIN or FULL JOIN;
//   - Table is nil;
//   - a subquery has no alias (SQLite is exempt from this check);
//   - writing the table, the predicate or the USING fields fails.
//
// On error, buf may already contain a partially written clause.
func (join JoinTable) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
	// join is a value receiver, so this default does not mutate the caller's
	// JoinTable.
	if join.JoinOperator == "" {
		join.JoinOperator = JoinInner
	}
	variadicPredicate, isVariadic := join.OnPredicate.(VariadicPredicate)
	// NOTE(review): when OnPredicate is nil, the type assertion above fails and
	// variadicPredicate is the zero value. The middle term below is therefore
	// always true whenever the first term is. As a result, a VariadicPredicate
	// with no Predicates is NOT treated as "no predicate". Confirm this is
	// intended.
	hasNoPredicate := join.OnPredicate == nil && len(variadicPredicate.Predicates) == 0 && len(join.UsingFields) == 0
	if hasNoPredicate && (join.JoinOperator == JoinInner ||
		join.JoinOperator == JoinLeft ||
		join.JoinOperator == JoinRight ||
		join.JoinOperator == JoinFull) &&
		// exclude sqlite from this check because they allow join without predicate
		dialect != DialectSQLite {
		return fmt.Errorf("%s requires at least one predicate specified", join.JoinOperator)
	}
	// NOTE(review): SQLite 3.39.0 added RIGHT and FULL OUTER JOIN support, so
	// this restriction may be stale for newer SQLite versions. Confirm.
	if dialect == DialectSQLite && (join.JoinOperator == JoinRight || join.JoinOperator == JoinFull) {
		return fmt.Errorf("sqlite does not support %s", join.JoinOperator)
	}

	// JOIN
	buf.WriteString(string(join.JoinOperator) + " ")
	// The nil check runs after the operator has been written, so buf is left
	// holding a dangling operator on this error path.
	if join.Table == nil {
		return fmt.Errorf("joining on a nil table")
	}

	// Table; a subquery is wrapped in parentheses.
	_, isQuery := join.Table.(Query)
	if isQuery {
		buf.WriteString("(")
	}
	err := join.Table.WriteSQL(ctx, dialect, buf, args, params)
	if err != nil {
		return err
	}
	if isQuery {
		buf.WriteString(")")
	}

	// AS
	// quoteTableColumns presumably appends a column list for aliased
	// subqueries. NOTE(review): confirm.
	if tableAlias := getAlias(join.Table); tableAlias != "" {
		buf.WriteString(" AS " + QuoteIdentifier(dialect, tableAlias) + quoteTableColumns(dialect, join.Table))
	} else if isQuery && dialect != DialectSQLite {
		return fmt.Errorf("%s %s subquery must have alias", dialect, join.JoinOperator)
	}

	// ON takes precedence over USING; only one of them is ever written.
	if isVariadic {
		// ON VariadicPredicate
		buf.WriteString(" ON ")
		// Toplevel presumably suppresses the outer parentheses around the
		// AND/OR group. NOTE(review): confirm against VariadicPredicate.
		variadicPredicate.Toplevel = true
		err = variadicPredicate.WriteSQL(ctx, dialect, buf, args, params)
		if err != nil {
			return err
		}
	} else if join.OnPredicate != nil {
		// ON Predicate
		buf.WriteString(" ON ")
		err = join.OnPredicate.WriteSQL(ctx, dialect, buf, args, params)
		if err != nil {
			return err
		}
	} else if len(join.UsingFields) > 0 {
		// USING Fields
		buf.WriteString(" USING (")
		err = writeFieldsWithPrefix(ctx, dialect, buf, args, params, join.UsingFields, "", false)
		if err != nil {
			return err
		}
		buf.WriteString(")")
	}
	return nil
}
134 |
135 | func writeJoinTables(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int, joinTables []JoinTable) error {
136 | var err error
137 | for i, joinTable := range joinTables {
138 | if i > 0 {
139 | buf.WriteString(" ")
140 | }
141 | err = joinTable.WriteSQL(ctx, dialect, buf, args, params)
142 | if err != nil {
143 | return fmt.Errorf("join #%d: %w", i+1, err)
144 | }
145 | }
146 | return nil
147 | }
148 |
--------------------------------------------------------------------------------
/joins_test.go:
--------------------------------------------------------------------------------
1 | package sq
2 |
3 | import "testing"
4 |
5 | func TestJoinTables(t *testing.T) {
6 | type ACTOR struct {
7 | TableStruct
8 | ACTOR_ID NumberField
9 | FIRST_NAME StringField
10 | LAST_NAME StringField
11 | LAST_UPDATE TimeField
12 | }
13 | a := New[ACTOR]("a")
14 |
15 | tests := []TestTable{{
16 | description: "JoinUsing",
17 | item: JoinUsing(a, a.FIRST_NAME, a.LAST_NAME),
18 | wantQuery: "JOIN actor AS a USING (first_name, last_name)",
19 | }, {
20 | description: "Join without operator",
21 | item: CustomJoin("", a, a.ACTOR_ID.Eq(a.ACTOR_ID), a.FIRST_NAME.Ne(a.LAST_NAME)),
22 | wantQuery: "JOIN actor AS a ON a.actor_id = a.actor_id AND a.first_name <> a.last_name",
23 | }, {
24 | description: "Join",
25 | item: Join(a, a.ACTOR_ID.Eq(a.ACTOR_ID)),
26 | wantQuery: "JOIN actor AS a ON a.actor_id = a.actor_id",
27 | }, {
28 | description: "LeftJoin",
29 | item: LeftJoin(a, a.ACTOR_ID.Eq(a.ACTOR_ID)),
30 | wantQuery: "LEFT JOIN actor AS a ON a.actor_id = a.actor_id",
31 | }, {
32 | description: "Right Join",
33 | item: JoinTable{JoinOperator: JoinRight, Table: a, OnPredicate: a.ACTOR_ID.Eq(a.ACTOR_ID)},
34 | wantQuery: "RIGHT JOIN actor AS a ON a.actor_id = a.actor_id",
35 | }, {
36 | description: "FullJoin",
37 | item: FullJoin(a, a.ACTOR_ID.Eq(a.ACTOR_ID)),
38 | wantQuery: "FULL JOIN actor AS a ON a.actor_id = a.actor_id",
39 | }, {
40 | description: "CrossJoin",
41 | item: CrossJoin(a),
42 | wantQuery: "CROSS JOIN actor AS a",
43 | }}
44 |
45 | for _, tt := range tests {
46 | tt := tt
47 | t.Run(tt.description, func(t *testing.T) {
48 | t.Parallel()
49 | tt.assert(t)
50 | })
51 | }
52 |
53 | notOKTests := []TestTable{{
54 | description: "full join has no predicate",
55 | item: FullJoin(a),
56 | }, {
57 | description: "sqlite does not support full join",
58 | dialect: DialectSQLite,
59 | item: FullJoin(a, Expr("TRUE")),
60 | }, {
61 | description: "table is nil",
62 | item: Join(nil, Expr("TRUE")),
63 | }, {
64 | description: "UsingField returns err",
65 | item: JoinUsing(a, nil),
66 | }}
67 |
68 | for _, tt := range notOKTests {
69 | tt := tt
70 | t.Run(tt.description, func(t *testing.T) {
71 | t.Parallel()
72 | tt.assertNotOK(t)
73 | })
74 | }
75 |
76 | errTests := []TestTable{{
77 | description: "table err",
78 | item: Join(FaultySQL{}, a.ACTOR_ID.Eq(a.ACTOR_ID)),
79 | }, {
80 | description: "VariadicPredicate err",
81 | item: Join(a, And(FaultySQL{})),
82 | }, {
83 | description: "predicate err",
84 | item: Join(a, FaultySQL{}),
85 | }}
86 |
87 | for _, tt := range errTests {
88 | tt := tt
89 | t.Run(tt.description, func(t *testing.T) {
90 | t.Parallel()
91 | tt.assertErr(t, ErrFaultySQL)
92 | })
93 | }
94 | }
95 |
--------------------------------------------------------------------------------
/logger.go:
--------------------------------------------------------------------------------
1 | package sq
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "database/sql"
7 | "fmt"
8 | "io"
9 | "log"
10 | "os"
11 | "path/filepath"
12 | "strconv"
13 | "strings"
14 | "sync/atomic"
15 | "time"
16 | )
17 |
// QueryStats carries the details and outcome of a single query run. It is
// the value handed to SqLogger.SqLogQuery.
type QueryStats struct {
	// Dialect is the SQL dialect of the query.
	Dialect string

	// Query is the query string.
	Query string

	// Args are the arguments supplied alongside Query.
	Args []any

	// Params maps each param name to the indices in Args that hold its
	// value.
	Params map[string][]int

	// Err is the error produced by running the query, if any.
	Err error

	// RowCount is the number of rows fetched. Not valid for Exec().
	RowCount sql.NullInt64

	// RowsAffected is the number of rows affected. Not valid for FetchOne,
	// FetchAll or FetchCursor.
	RowsAffected sql.NullInt64

	// LastInsertId is the last insert id of the query.
	LastInsertId sql.NullInt64

	// Exists is the outcome of FetchExists().
	Exists sql.NullBool

	// StartedAt is when the query started.
	StartedAt time.Time

	// TimeTaken is how long the query took.
	TimeTaken time.Duration

	// CallerFile is the file the query was invoked from.
	CallerFile string

	// CallerLine is the line in CallerFile that invoked the query.
	CallerLine int

	// CallerFunction is the name of the function that invoked the query.
	CallerFunction string

	// Results holds the fetched results, when results were requested.
	Results string
}

// LogSettings is filled in by SqLogger.SqLogSettings and decides which
// optional details get recorded in QueryStats.
type LogSettings struct {
	// LogAsynchronously dispatches logging asynchronously. Calls are not
	// blocked, but log lines may arrive out of order.
	LogAsynchronously bool

	// IncludeTime records the time taken by the query.
	IncludeTime bool

	// IncludeCaller records the caller (file name and line number).
	IncludeCaller bool

	// IncludeResults controls how many fetched results are recorded.
	// NOTE(review): presumably 0 means none; confirm against the fetch code.
	IncludeResults int
}

// SqLogger is the logging hook used by the sq package.
type SqLogger interface {
	// SqLogSettings fills in a LogSettings, which shapes what goes into the
	// QueryStats passed to SqLogQuery.
	SqLogSettings(context.Context, *LogSettings)

	// SqLogQuery logs a query described by the given QueryStats.
	SqLogQuery(context.Context, QueryStats)
}

// sqLogger is the SqLogger built by NewLogger. It writes through a
// *log.Logger and formats entries according to config.
type sqLogger struct {
	logger *log.Logger
	config LoggerConfig
}

// LoggerConfig configures the logger returned by NewLogger.
type LoggerConfig struct {
	// LogAsynchronously dispatches logging asynchronously. Calls are not
	// blocked, but log lines may arrive out of order.
	LogAsynchronously bool

	// ShowTimeTaken shows how long the query took.
	ShowTimeTaken bool

	// ShowCaller shows the caller (file name and line number).
	ShowCaller bool

	// ShowResults is the number of fetched results to show. Zero shows none.
	ShowResults int

	// NoColor logs plain text without ANSI colors.
	NoColor bool

	// InterpolateVerbose additionally shows the query before its arguments
	// are interpolated. The logged query is interpolated regardless; to turn
	// interpolation off entirely, use HideArgs.
	InterpolateVerbose bool

	// HideArgs logs only the query with its placeholders, never the argument
	// values.
	HideArgs bool
}
126 |
127 | var _ SqLogger = (*sqLogger)(nil)
128 |
129 | var defaultLogger = NewLogger(os.Stdout, "", log.LstdFlags, LoggerConfig{
130 | ShowTimeTaken: true,
131 | ShowCaller: true,
132 | })
133 |
134 | var verboseLogger = NewLogger(os.Stdout, "", log.LstdFlags, LoggerConfig{
135 | ShowTimeTaken: true,
136 | ShowCaller: true,
137 | ShowResults: 5,
138 | InterpolateVerbose: true,
139 | })
140 |
141 | // NewLogger returns a new SqLogger.
142 | func NewLogger(w io.Writer, prefix string, flag int, config LoggerConfig) SqLogger {
143 | return &sqLogger{
144 | logger: log.New(w, prefix, flag),
145 | config: config,
146 | }
147 | }
148 |
149 | // SqLogSettings implements the SqLogger interface.
150 | func (l *sqLogger) SqLogSettings(ctx context.Context, settings *LogSettings) {
151 | settings.LogAsynchronously = l.config.LogAsynchronously
152 | settings.IncludeTime = l.config.ShowTimeTaken
153 | settings.IncludeCaller = l.config.ShowCaller
154 | settings.IncludeResults = l.config.ShowResults
155 | }
156 |
157 | // SqLogQuery implements the SqLogger interface.
158 | func (l *sqLogger) SqLogQuery(ctx context.Context, queryStats QueryStats) {
159 | var reset, red, green, blue, purple string
160 | envNoColor, _ := strconv.ParseBool(os.Getenv("NO_COLOR"))
161 | if !l.config.NoColor && !envNoColor {
162 | reset = colorReset
163 | red = colorRed
164 | green = colorGreen
165 | blue = colorBlue
166 | purple = colorPurple
167 | }
168 | buf := bufpool.Get().(*bytes.Buffer)
169 | buf.Reset()
170 | defer bufpool.Put(buf)
171 | if queryStats.Err == nil {
172 | buf.WriteString(green + "[OK]" + reset)
173 | } else {
174 | buf.WriteString(red + "[FAIL]" + reset)
175 | }
176 | if l.config.HideArgs {
177 | buf.WriteString(" " + queryStats.Query + ";")
178 | } else if !l.config.InterpolateVerbose {
179 | if queryStats.Err != nil {
180 | buf.WriteString(" " + queryStats.Query + ";")
181 | if len(queryStats.Args) > 0 {
182 | buf.WriteString(" [")
183 | }
184 | for i := 0; i < len(queryStats.Args); i++ {
185 | if i > 0 {
186 | buf.WriteString(", ")
187 | }
188 | buf.WriteString(fmt.Sprintf("%#v", queryStats.Args[i]))
189 | }
190 | if len(queryStats.Args) > 0 {
191 | buf.WriteString("]")
192 | }
193 | } else {
194 | query, err := Sprintf(queryStats.Dialect, queryStats.Query, queryStats.Args)
195 | if err != nil {
196 | query += " " + err.Error()
197 | }
198 | buf.WriteString(" " + query + ";")
199 | }
200 | }
201 | if queryStats.Err != nil {
202 | errStr := queryStats.Err.Error()
203 | if i := strings.IndexByte(errStr, '\n'); i < 0 {
204 | buf.WriteString(blue + " err" + reset + "={" + queryStats.Err.Error() + "}")
205 | }
206 | }
207 | if l.config.ShowTimeTaken {
208 | buf.WriteString(blue + " timeTaken" + reset + "=" + queryStats.TimeTaken.String())
209 | }
210 | if queryStats.RowCount.Valid {
211 | buf.WriteString(blue + " rowCount" + reset + "=" + strconv.FormatInt(queryStats.RowCount.Int64, 10))
212 | }
213 | if queryStats.RowsAffected.Valid {
214 | buf.WriteString(blue + " rowsAffected" + reset + "=" + strconv.FormatInt(queryStats.RowsAffected.Int64, 10))
215 | }
216 | if queryStats.LastInsertId.Valid {
217 | buf.WriteString(blue + " lastInsertId" + reset + "=" + strconv.FormatInt(queryStats.LastInsertId.Int64, 10))
218 | }
219 | if queryStats.Exists.Valid {
220 | buf.WriteString(blue + " exists" + reset + "=" + strconv.FormatBool(queryStats.Exists.Bool))
221 | }
222 | if l.config.ShowCaller {
223 | buf.WriteString(blue + " caller" + reset + "=" + queryStats.CallerFile + ":" + strconv.Itoa(queryStats.CallerLine) + ":" + filepath.Base(queryStats.CallerFunction))
224 | }
225 | if !l.config.HideArgs && l.config.InterpolateVerbose {
226 | buf.WriteString("\n" + purple + "----[ Executing query ]----" + reset)
227 | buf.WriteString("\n" + queryStats.Query + "; " + fmt.Sprintf("%#v", queryStats.Args))
228 | buf.WriteString("\n" + purple + "----[ with bind values ]----" + reset)
229 | query, err := Sprintf(queryStats.Dialect, queryStats.Query, queryStats.Args)
230 | query += ";"
231 | if err != nil {
232 | query += " " + err.Error()
233 | }
234 | buf.WriteString("\n" + query)
235 | }
236 | if l.config.ShowResults > 0 && queryStats.Err == nil {
237 | buf.WriteString("\n" + purple + "----[ Fetched result ]----" + reset)
238 | buf.WriteString(queryStats.Results)
239 | if queryStats.RowCount.Int64 > int64(l.config.ShowResults) {
240 | buf.WriteString("\n...\n(Fetched " + strconv.FormatInt(queryStats.RowCount.Int64, 10) + " rows)")
241 | }
242 | }
243 | if buf.Len() > 0 {
244 | l.logger.Println(buf.String())
245 | }
246 | }
247 |
248 | // Log wraps a DB and adds logging to it.
249 | func Log(db DB) interface {
250 | DB
251 | SqLogger
252 | } {
253 | return struct {
254 | DB
255 | SqLogger
256 | }{DB: db, SqLogger: defaultLogger}
257 | }
258 |
259 | // VerboseLog wraps a DB and adds verbose logging to it.
260 | func VerboseLog(db DB) interface {
261 | DB
262 | SqLogger
263 | } {
264 | return struct {
265 | DB
266 | SqLogger
267 | }{DB: db, SqLogger: verboseLogger}
268 | }
269 |
270 | var defaultLogSettings atomic.Value
271 |
272 | // SetDefaultLogSettings sets the function to configure the default
273 | // LogSettings. This value is not used unless SetDefaultLogQuery is also
274 | // configured.
275 | func SetDefaultLogSettings(logSettings func(context.Context, *LogSettings)) {
276 | defaultLogSettings.Store(logSettings)
277 | }
278 |
279 | var defaultLogQuery atomic.Value
280 |
281 | // SetDefaultLogQuery sets the default logging function to call for all
282 | // queries (if a logger is not explicitly passed in).
283 | func SetDefaultLogQuery(logQuery func(context.Context, QueryStats)) {
284 | defaultLogQuery.Store(logQuery)
285 | }
286 |
287 | type sqLogStruct struct {
288 | logSettings func(context.Context, *LogSettings)
289 | logQuery func(context.Context, QueryStats)
290 | }
291 |
292 | var _ SqLogger = (*sqLogStruct)(nil)
293 |
294 | func (l *sqLogStruct) SqLogSettings(ctx context.Context, logSettings *LogSettings) {
295 | if l.logSettings == nil {
296 | return
297 | }
298 | l.logSettings(ctx, logSettings)
299 | }
300 |
301 | func (l *sqLogStruct) SqLogQuery(ctx context.Context, queryStats QueryStats) {
302 | if l.logQuery == nil {
303 | return
304 | }
305 | l.logQuery(ctx, queryStats)
306 | }
307 |
// ANSI SGR escape sequences used to colorize log output. The bright
// foreground variants (91-97) are used. colorGray uses 90 ("bright black"),
// the conventional gray foreground. It previously duplicated colorWhite (97)
// by mistake.
const (
	colorReset  = "\x1b[0m"
	colorRed    = "\x1b[91m"
	colorGreen  = "\x1b[92m"
	colorYellow = "\x1b[93m"
	colorBlue   = "\x1b[94m"
	colorPurple = "\x1b[95m"
	colorCyan   = "\x1b[96m"
	colorGray   = "\x1b[90m"
	colorWhite  = "\x1b[97m"
)
319 |
--------------------------------------------------------------------------------
/logger_test.go:
--------------------------------------------------------------------------------
1 | package sq
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "database/sql"
7 | "fmt"
8 | "log"
9 | "testing"
10 | "time"
11 |
12 | "github.com/bokwoon95/sq/internal/testutil"
13 | )
14 |
15 | func TestLogger(t *testing.T) {
16 | type TT struct {
17 | description string
18 | ctx context.Context
19 | stats QueryStats
20 | config LoggerConfig
21 | wantOutput string
22 | }
23 |
24 | assert := func(t *testing.T, tt TT) {
25 | if tt.ctx == nil {
26 | tt.ctx = context.Background()
27 | }
28 | buf := &bytes.Buffer{}
29 | logger := sqLogger{
30 | logger: log.New(buf, "", 0),
31 | config: tt.config,
32 | }
33 | logger.SqLogQuery(tt.ctx, tt.stats)
34 | if diff := testutil.Diff(buf.String(), tt.wantOutput); diff != "" {
35 | t.Error(testutil.Callers(), diff)
36 | }
37 | }
38 |
39 | t.Run("Log VerboseLog", func(t *testing.T) {
40 | t.Parallel()
41 | var logSettings LogSettings
42 | Log(nil).SqLogSettings(context.Background(), &logSettings)
43 | diff := testutil.Diff(logSettings, LogSettings{
44 | LogAsynchronously: false,
45 | IncludeTime: true,
46 | IncludeCaller: true,
47 | IncludeResults: 0,
48 | })
49 | if diff != "" {
50 | t.Error(testutil.Callers(), diff)
51 | }
52 | VerboseLog(nil).SqLogSettings(context.Background(), &logSettings)
53 | diff = testutil.Diff(logSettings, LogSettings{
54 | LogAsynchronously: false,
55 | IncludeTime: true,
56 | IncludeCaller: true,
57 | IncludeResults: 5,
58 | })
59 | if diff != "" {
60 | t.Error(testutil.Callers(), diff)
61 | }
62 | })
63 |
64 | t.Run("no color", func(t *testing.T) {
65 | var tt TT
66 | tt.config.NoColor = true
67 | tt.stats.Query = "SELECT 1"
68 | tt.wantOutput = "[OK] SELECT 1;\n"
69 | assert(t, tt)
70 | })
71 |
72 | tests := []TT{{
73 | description: "err",
74 | stats: QueryStats{
75 | Query: "SELECT 1",
76 | Err: fmt.Errorf("lorem ipsum"),
77 | },
78 | wantOutput: "\x1b[91m[FAIL]\x1b[0m SELECT 1;\x1b[94m err\x1b[0m={lorem ipsum}\n",
79 | }, {
80 | description: "HideArgs",
81 | config: LoggerConfig{HideArgs: true},
82 | stats: QueryStats{
83 | Query: "SELECT ?", Args: []any{1},
84 | },
85 | wantOutput: "\x1b[92m[OK]\x1b[0m SELECT ?;\n",
86 | }, {
87 | description: "RowCount",
88 | stats: QueryStats{
89 | Query: "SELECT 1",
90 | RowCount: sql.NullInt64{Valid: true, Int64: 3},
91 | },
92 | wantOutput: "\x1b[92m[OK]\x1b[0m SELECT 1;\x1b[94m rowCount\x1b[0m=3\n",
93 | }, {
94 | description: "RowsAffected",
95 | stats: QueryStats{
96 | Query: "SELECT 1",
97 | RowsAffected: sql.NullInt64{Valid: true, Int64: 5},
98 | },
99 | wantOutput: "\x1b[92m[OK]\x1b[0m SELECT 1;\x1b[94m rowsAffected\x1b[0m=5\n",
100 | }, {
101 | description: "LastInsertId",
102 | stats: QueryStats{
103 | Query: "SELECT 1",
104 | LastInsertId: sql.NullInt64{Valid: true, Int64: 7},
105 | },
106 | wantOutput: "\x1b[92m[OK]\x1b[0m SELECT 1;\x1b[94m lastInsertId\x1b[0m=7\n",
107 | }, {
108 | description: "Exists",
109 | stats: QueryStats{
110 | Query: "SELECT EXISTS (SELECT 1)",
111 | Exists: sql.NullBool{Valid: true, Bool: true},
112 | },
113 | wantOutput: "\x1b[92m[OK]\x1b[0m SELECT EXISTS (SELECT 1);\x1b[94m exists\x1b[0m=true\n",
114 | }, {
115 | description: "ShowCaller",
116 | config: LoggerConfig{ShowCaller: true},
117 | stats: QueryStats{
118 | Query: "SELECT 1",
119 | CallerFile: "file.go",
120 | CallerLine: 22,
121 | CallerFunction: "someFunc",
122 | },
123 | wantOutput: "\x1b[92m[OK]\x1b[0m SELECT 1;\x1b[94m caller\x1b[0m=file.go:22:someFunc\n",
124 | }, {
125 | description: "Verbose",
126 | config: LoggerConfig{InterpolateVerbose: true, ShowTimeTaken: true},
127 | stats: QueryStats{
128 | Query: "SELECT ?, ?", Args: []any{1, "bob"},
129 | TimeTaken: 300 * time.Millisecond,
130 | },
131 | wantOutput: "\x1b[92m[OK]\x1b[0m\x1b[94m timeTaken\x1b[0m=300ms" +
132 | "\n\x1b[95m----[ Executing query ]----\x1b[0m" +
133 | "\nSELECT ?, ?; []interface {}{1, \"bob\"}" +
134 | "\n\x1b[95m----[ with bind values ]----\x1b[0m" +
135 | "\nSELECT 1, 'bob';\n",
136 | }, {
137 | description: "ShowResults",
138 | config: LoggerConfig{ShowResults: 1},
139 | stats: QueryStats{
140 | Query: "SELECT 1",
141 | Results: "\nlorem ipsum dolor sit amet",
142 | },
143 | wantOutput: "\x1b[92m[OK]\x1b[0m SELECT 1;" +
144 | "\n\x1b[95m----[ Fetched result ]----\x1b[0m" +
145 | "\nlorem ipsum dolor sit amet\n",
146 | }}
147 |
148 | for _, tt := range tests {
149 | tt := tt
150 | t.Run(tt.description, func(t *testing.T) {
151 | assert(t, tt)
152 | })
153 | }
154 | }
155 |
--------------------------------------------------------------------------------
/misc.go:
--------------------------------------------------------------------------------
1 | package sq
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "fmt"
7 | "strings"
8 | )
9 |
10 | // ValueExpression represents an SQL value that is passed in as an argument to
11 | // a prepared query.
12 | type ValueExpression struct {
13 | value any
14 | alias string
15 | }
16 |
17 | var _ interface {
18 | Field
19 | Predicate
20 | Any
21 | } = (*ValueExpression)(nil)
22 |
23 | // Value returns a new ValueExpression.
24 | func Value(value any) ValueExpression { return ValueExpression{value: value} }
25 |
26 | // WriteSQL implements the SQLWriter interface.
27 | func (e ValueExpression) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
28 | return WriteValue(ctx, dialect, buf, args, params, e.value)
29 | }
30 |
31 | // As returns a new ValueExpression with the given alias.
32 | func (e ValueExpression) As(alias string) ValueExpression {
33 | e.alias = alias
34 | return e
35 | }
36 |
37 | // In returns a 'expr IN (val)' Predicate.
38 | func (e ValueExpression) In(val any) Predicate { return In(e.value, val) }
39 |
40 | // Eq returns a 'expr = val' Predicate.
41 | func (e ValueExpression) Eq(val any) Predicate { return Eq(e.value, val) }
42 |
43 | // Ne returns a 'expr <> val' Predicate.
44 | func (e ValueExpression) Ne(val any) Predicate { return Ne(e.value, val) }
45 |
46 | // Lt returns a 'expr < val' Predicate.
47 | func (e ValueExpression) Lt(val any) Predicate { return Lt(e.value, val) }
48 |
49 | // Le returns a 'expr <= val' Predicate.
50 | func (e ValueExpression) Le(val any) Predicate { return Le(e.value, val) }
51 |
52 | // Gt returns a 'expr > val' Predicate.
53 | func (e ValueExpression) Gt(val any) Predicate { return Gt(e.value, val) }
54 |
55 | // Ge returns a 'expr >= val' Predicate.
56 | func (e ValueExpression) Ge(val any) Predicate { return Ge(e.value, val) }
57 |
58 | // GetAlias returns the alias of the ValueExpression.
59 | func (e ValueExpression) GetAlias() string { return e.alias }
60 |
61 | // IsField implements the Field interface.
62 | func (e ValueExpression) IsField() {}
63 |
64 | // IsArray implements the Array interface.
65 | func (e ValueExpression) IsArray() {}
66 |
67 | // IsBinary implements the Binary interface.
68 | func (e ValueExpression) IsBinary() {}
69 |
70 | // IsBoolean implements the Boolean interface.
71 | func (e ValueExpression) IsBoolean() {}
72 |
73 | // IsEnum implements the Enum interface.
74 | func (e ValueExpression) IsEnum() {}
75 |
76 | // IsJSON implements the JSON interface.
77 | func (e ValueExpression) IsJSON() {}
78 |
79 | // IsNumber implements the Number interface.
80 | func (e ValueExpression) IsNumber() {}
81 |
82 | // IsString implements the String interface.
83 | func (e ValueExpression) IsString() {}
84 |
85 | // IsTime implements the Time interfaces.
86 | func (e ValueExpression) IsTime() {}
87 |
88 | // IsUUID implements the UUID interface.
89 | func (e ValueExpression) IsUUID() {}
90 |
91 | // LiteralValue represents an SQL value literally interpolated into the query.
92 | // Doing so potentially exposes the query to SQL injection so only do this for
93 | // values that you trust e.g. literals and constants.
94 | type LiteralValue struct {
95 | value any
96 | alias string
97 | }
98 |
99 | var _ interface {
100 | Field
101 | Predicate
102 | Binary
103 | Boolean
104 | Number
105 | String
106 | Time
107 | } = (*LiteralValue)(nil)
108 |
109 | // Literal returns a new LiteralValue.
110 | func Literal(value any) LiteralValue { return LiteralValue{value: value} }
111 |
112 | // WriteSQL implements the SQLWriter interface.
113 | func (v LiteralValue) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
114 | s, err := Sprint(dialect, v.value)
115 | if err != nil {
116 | return err
117 | }
118 | buf.WriteString(s)
119 | return nil
120 | }
121 |
122 | // As returns a new LiteralValue with the given alias.
123 | func (v LiteralValue) As(alias string) LiteralValue {
124 | v.alias = alias
125 | return v
126 | }
127 |
128 | // In returns a 'literal IN (val)' Predicate.
129 | func (v LiteralValue) In(val any) Predicate { return In(v, val) }
130 |
131 | // Eq returns a 'literal = val' Predicate.
132 | func (v LiteralValue) Eq(val any) Predicate { return Eq(v, val) }
133 |
134 | // Ne returns a 'literal <> val' Predicate.
135 | func (v LiteralValue) Ne(val any) Predicate { return Ne(v, val) }
136 |
137 | // Lt returns a 'literal < val' Predicate.
138 | func (v LiteralValue) Lt(val any) Predicate { return Lt(v, val) }
139 |
140 | // Le returns a 'literal <= val' Predicate.
141 | func (v LiteralValue) Le(val any) Predicate { return Le(v, val) }
142 |
143 | // Gt returns a 'literal > val' Predicate.
144 | func (v LiteralValue) Gt(val any) Predicate { return Gt(v, val) }
145 |
146 | // Ge returns a 'literal >= val' Predicate.
147 | func (v LiteralValue) Ge(val any) Predicate { return Ge(v, val) }
148 |
149 | // GetAlias returns the alias of the LiteralValue.
150 | func (v LiteralValue) GetAlias() string { return v.alias }
151 |
152 | // IsField implements the Field interface.
153 | func (v LiteralValue) IsField() {}
154 |
155 | // IsBinary implements the Binary interface.
156 | func (v LiteralValue) IsBinary() {}
157 |
158 | // IsBoolean implements the Boolean interface.
159 | func (v LiteralValue) IsBoolean() {}
160 |
161 | // IsNumber implements the Number interface.
162 | func (v LiteralValue) IsNumber() {}
163 |
164 | // IsString implements the String interface.
165 | func (v LiteralValue) IsString() {}
166 |
167 | // IsTime implements the Time interfaces.
168 | func (v LiteralValue) IsTime() {}
169 |
170 | // DialectExpression represents an SQL expression that renders differently
171 | // depending on the dialect.
172 | type DialectExpression struct {
173 | Default any
174 | Cases DialectCases
175 | }
176 |
177 | // DialectCases is a slice of DialectCases.
178 | type DialectCases = []DialectCase
179 |
180 | // DialectCase holds the result to be used for a given dialect in a
181 | // DialectExpression.
182 | type DialectCase struct {
183 | Dialect string
184 | Result any
185 | }
186 |
187 | var _ interface {
188 | Table
189 | Field
190 | Predicate
191 | Any
192 | } = (*DialectExpression)(nil)
193 |
194 | // DialectValue returns a new DialectExpression. The value passed in is used as
195 | // the default.
196 | func DialectValue(value any) DialectExpression {
197 | return DialectExpression{Default: value}
198 | }
199 |
200 | // DialectExpr returns a new DialectExpression. The expression passed in is
201 | // used as the default.
202 | func DialectExpr(format string, values ...any) DialectExpression {
203 | return DialectExpression{Default: Expr(format, values...)}
204 | }
205 |
206 | // WriteSQL implements the SQLWriter interface.
207 | func (e DialectExpression) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
208 | for _, Case := range e.Cases {
209 | if dialect == Case.Dialect {
210 | return WriteValue(ctx, dialect, buf, args, params, Case.Result)
211 | }
212 | }
213 | return WriteValue(ctx, dialect, buf, args, params, e.Default)
214 | }
215 |
216 | // DialectValue adds a new dialect-value pair to the DialectExpression.
217 | func (e DialectExpression) DialectValue(dialect string, value any) DialectExpression {
218 | e.Cases = append(e.Cases, DialectCase{Dialect: dialect, Result: value})
219 | return e
220 | }
221 |
222 | // DialectExpr adds a new dialect-expression pair to the DialectExpression.
223 | func (e DialectExpression) DialectExpr(dialect string, format string, values ...any) DialectExpression {
224 | e.Cases = append(e.Cases, DialectCase{Dialect: dialect, Result: Expr(format, values...)})
225 | return e
226 | }
227 |
228 | // IsTable implements the Table interface.
229 | func (e DialectExpression) IsTable() {}
230 |
231 | // IsField implements the Field interface.
232 | func (e DialectExpression) IsField() {}
233 |
234 | // IsArray implements the Array interface.
235 | func (e DialectExpression) IsArray() {}
236 |
237 | // IsBinary implements the Binary interface.
238 | func (e DialectExpression) IsBinary() {}
239 |
240 | // IsBoolean implements the Boolean interface.
241 | func (e DialectExpression) IsBoolean() {}
242 |
243 | // IsEnum implements the Enum interface.
244 | func (e DialectExpression) IsEnum() {}
245 |
246 | // IsJSON implements the JSON interface.
247 | func (e DialectExpression) IsJSON() {}
248 |
249 | // IsNumber implements the Number interface.
250 | func (e DialectExpression) IsNumber() {}
251 |
252 | // IsString implements the String interface.
253 | func (e DialectExpression) IsString() {}
254 |
255 | // IsTime implements the Time interface.
256 | func (e DialectExpression) IsTime() {}
257 |
258 | // IsUUID implements the UUID interface.
259 | func (e DialectExpression) IsUUID() {}
260 |
261 | // CaseExpression represents an SQL CASE expression.
262 | type CaseExpression struct {
263 | alias string
264 | Cases PredicateCases
265 | Default any
266 | }
267 |
268 | // PredicateCases is a slice of PredicateCases.
269 | type PredicateCases = []PredicateCase
270 |
271 | // PredicateCase holds the result to be used for a given predicate in a
272 | // CaseExpression.
273 | type PredicateCase struct {
274 | Predicate Predicate
275 | Result any
276 | }
277 |
278 | var _ interface {
279 | Field
280 | Any
281 | } = (*CaseExpression)(nil)
282 |
283 | // CaseWhen returns a new CaseExpression.
284 | func CaseWhen(predicate Predicate, result any) CaseExpression {
285 | return CaseExpression{
286 | Cases: PredicateCases{{Predicate: predicate, Result: result}},
287 | }
288 | }
289 |
290 | // WriteSQL implements the SQLWriter interface.
291 | func (e CaseExpression) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
292 | buf.WriteString("CASE")
293 | if len(e.Cases) == 0 {
294 | return fmt.Errorf("CaseExpression empty")
295 | }
296 | var err error
297 | for i, Case := range e.Cases {
298 | buf.WriteString(" WHEN ")
299 | err = WriteValue(ctx, dialect, buf, args, params, Case.Predicate)
300 | if err != nil {
301 | return fmt.Errorf("CASE #%d WHEN: %w", i+1, err)
302 | }
303 | buf.WriteString(" THEN ")
304 | err = WriteValue(ctx, dialect, buf, args, params, Case.Result)
305 | if err != nil {
306 | return fmt.Errorf("CASE #%d THEN: %w", i+1, err)
307 | }
308 | }
309 | if e.Default != nil {
310 | buf.WriteString(" ELSE ")
311 | err = WriteValue(ctx, dialect, buf, args, params, e.Default)
312 | if err != nil {
313 | return fmt.Errorf("CASE ELSE: %w", err)
314 | }
315 | }
316 | buf.WriteString(" END")
317 | return nil
318 | }
319 |
320 | // When adds a new predicate-result pair to the CaseExpression.
321 | func (e CaseExpression) When(predicate Predicate, result any) CaseExpression {
322 | e.Cases = append(e.Cases, PredicateCase{Predicate: predicate, Result: result})
323 | return e
324 | }
325 |
326 | // Else sets the fallback result of the CaseExpression.
327 | func (e CaseExpression) Else(fallback any) CaseExpression {
328 | e.Default = fallback
329 | return e
330 | }
331 |
332 | // As returns a new CaseExpression with the given alias.
333 | func (e CaseExpression) As(alias string) CaseExpression {
334 | e.alias = alias
335 | return e
336 | }
337 |
338 | // GetAlias returns the alias of the CaseExpression.
339 | func (e CaseExpression) GetAlias() string { return e.alias }
340 |
341 | // IsField implements the Field interface.
342 | func (e CaseExpression) IsField() {}
343 |
344 | // IsArray implements the Array interface.
345 | func (e CaseExpression) IsArray() {}
346 |
347 | // IsBinary implements the Binary interface.
348 | func (e CaseExpression) IsBinary() {}
349 |
350 | // IsBoolean implements the Boolean interface.
351 | func (e CaseExpression) IsBoolean() {}
352 |
353 | // IsEnum implements the Enum interface.
354 | func (e CaseExpression) IsEnum() {}
355 |
356 | // IsJSON implements the JSON interface.
357 | func (e CaseExpression) IsJSON() {}
358 |
359 | // IsNumber implements the Number interface.
360 | func (e CaseExpression) IsNumber() {}
361 |
362 | // IsString implements the String interface.
363 | func (e CaseExpression) IsString() {}
364 |
365 | // IsTime implements the Time interface.
366 | func (e CaseExpression) IsTime() {}
367 |
368 | // IsUUID implements the UUID interface.
369 | func (e CaseExpression) IsUUID() {}
370 |
371 | // SimpleCaseExpression represents an SQL simple CASE expression.
372 | type SimpleCaseExpression struct {
373 | alias string
374 | Expression any
375 | Cases SimpleCases
376 | Default any
377 | }
378 |
379 | // SimpleCases is a slice of SimpleCases.
380 | type SimpleCases = []SimpleCase
381 |
382 | // SimpleCase holds the result to be used for a given value in a
383 | // SimpleCaseExpression.
384 | type SimpleCase struct {
385 | Value any
386 | Result any
387 | }
388 |
389 | var _ interface {
390 | Field
391 | Any
392 | } = (*SimpleCaseExpression)(nil)
393 |
394 | // Case returns a new SimpleCaseExpression.
395 | func Case(expression any) SimpleCaseExpression {
396 | return SimpleCaseExpression{Expression: expression}
397 | }
398 |
399 | // WriteSQL implements the SQLWriter interface.
400 | func (e SimpleCaseExpression) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
401 | buf.WriteString("CASE ")
402 | if len(e.Cases) == 0 {
403 | return fmt.Errorf("SimpleCaseExpression empty")
404 | }
405 | var err error
406 | err = WriteValue(ctx, dialect, buf, args, params, e.Expression)
407 | if err != nil {
408 | return fmt.Errorf("CASE: %w", err)
409 | }
410 | for i, Case := range e.Cases {
411 | buf.WriteString(" WHEN ")
412 | err = WriteValue(ctx, dialect, buf, args, params, Case.Value)
413 | if err != nil {
414 | return fmt.Errorf("CASE #%d WHEN: %w", i+1, err)
415 | }
416 | buf.WriteString(" THEN ")
417 | err = WriteValue(ctx, dialect, buf, args, params, Case.Result)
418 | if err != nil {
419 | return fmt.Errorf("CASE #%d THEN: %w", i+1, err)
420 | }
421 | }
422 | if e.Default != nil {
423 | buf.WriteString(" ELSE ")
424 | err = WriteValue(ctx, dialect, buf, args, params, e.Default)
425 | if err != nil {
426 | return fmt.Errorf("CASE ELSE: %w", err)
427 | }
428 | }
429 | buf.WriteString(" END")
430 | return nil
431 | }
432 |
433 | // When adds a new value-result pair to the SimpleCaseExpression.
434 | func (e SimpleCaseExpression) When(value any, result any) SimpleCaseExpression {
435 | e.Cases = append(e.Cases, SimpleCase{Value: value, Result: result})
436 | return e
437 | }
438 |
439 | // Else sets the fallback result of the SimpleCaseExpression.
440 | func (e SimpleCaseExpression) Else(fallback any) SimpleCaseExpression {
441 | e.Default = fallback
442 | return e
443 | }
444 |
445 | // As returns a new SimpleCaseExpression with the given alias.
446 | func (e SimpleCaseExpression) As(alias string) SimpleCaseExpression {
447 | e.alias = alias
448 | return e
449 | }
450 |
451 | // GetAlias returns the alias of the SimpleCaseExpression.
452 | func (e SimpleCaseExpression) GetAlias() string { return e.alias }
453 |
454 | // IsField implements the Field interface.
455 | func (e SimpleCaseExpression) IsField() {}
456 |
457 | // IsArray implements the Array interface.
458 | func (e SimpleCaseExpression) IsArray() {}
459 |
460 | // IsBinary implements the Binary interface.
461 | func (e SimpleCaseExpression) IsBinary() {}
462 |
463 | // IsBoolean implements the Boolean interface.
464 | func (e SimpleCaseExpression) IsBoolean() {}
465 |
466 | // IsEnum implements the Enum interface.
467 | func (e SimpleCaseExpression) IsEnum() {}
468 |
469 | // IsJSON implements the JSON interface.
470 | func (e SimpleCaseExpression) IsJSON() {}
471 |
472 | // IsNumber implements the Number interface.
473 | func (e SimpleCaseExpression) IsNumber() {}
474 |
475 | // IsString implements the String interface.
476 | func (e SimpleCaseExpression) IsString() {}
477 |
478 | // IsTime implements the Time interface.
479 | func (e SimpleCaseExpression) IsTime() {}
480 |
481 | // IsUUID implements the UUID interface.
482 | func (e SimpleCaseExpression) IsUUID() {}
483 |
484 | // Count represents an SQL COUNT() expression.
485 | func Count(field Field) Expression { return Expr("COUNT({})", field) }
486 |
487 | // CountStar represents an SQL COUNT(*) expression.
488 | func CountStar() Expression { return Expr("COUNT(*)") }
489 |
490 | // Sum represents an SQL SUM() expression.
491 | func Sum(num Number) Expression { return Expr("SUM({})", num) }
492 |
493 | // Avg represents an SQL AVG() expression.
494 | func Avg(num Number) Expression { return Expr("AVG({})", num) }
495 |
496 | // Min represent an SQL MIN() expression.
497 | func Min(field Field) Expression { return Expr("MIN({})", field) }
498 |
499 | // Max represents an SQL MAX() expression.
500 | func Max(field Field) Expression { return Expr("MAX({})", field) }
501 |
502 | // SelectValues represents a table literal comprised of SELECT statements
503 | // UNION-ed together e.g.
504 | //
505 | // (SELECT 1 AS a, 2 AS b, 3 AS c
506 | // UNION ALL
507 | // SELECT 4, 5, 6
508 | // UNION ALL
509 | // SELECT 7, 8, 9) AS tbl
510 | type SelectValues struct {
511 | Alias string
512 | Columns []string
513 | RowValues [][]any
514 | }
515 |
516 | var _ interface {
517 | Query
518 | Table
519 | } = (*SelectValues)(nil)
520 |
521 | // WriteSQL implements the SQLWriter interface.
522 | func (vs SelectValues) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
523 | var err error
524 | for i, rowvalue := range vs.RowValues {
525 | if i > 0 {
526 | buf.WriteString(" UNION ALL ")
527 | }
528 | if len(vs.Columns) > 0 && len(rowvalue) != len(vs.Columns) {
529 | return fmt.Errorf("rowvalue #%d: got %d values, want %d values (%s)", i+1, len(rowvalue), len(vs.Columns), strings.Join(vs.Columns, ", "))
530 | }
531 | buf.WriteString("SELECT ")
532 | for j, value := range rowvalue {
533 | if j > 0 {
534 | buf.WriteString(", ")
535 | }
536 | err = WriteValue(ctx, dialect, buf, args, params, value)
537 | if err != nil {
538 | return fmt.Errorf("rowvalue #%d value #%d: %w", i+1, j+1, err)
539 | }
540 | if i == 0 && j < len(vs.Columns) {
541 | buf.WriteString(" AS " + QuoteIdentifier(dialect, vs.Columns[j]))
542 | }
543 | }
544 | }
545 | return nil
546 | }
547 |
548 | // Field returns a new field qualified by the SelectValues' alias.
549 | func (vs SelectValues) Field(name string) AnyField {
550 | return NewAnyField(name, TableStruct{alias: vs.Alias})
551 | }
552 |
553 | // SetFetchableFields implements the Query interface. It always returns false
554 | // as the second result.
555 | func (vs SelectValues) SetFetchableFields([]Field) (query Query, ok bool) { return vs, false }
556 |
557 | // GetDialect implements the Query interface. It always returns an empty
558 | // string.
559 | func (vs SelectValues) GetDialect() string { return "" }
560 |
561 | // GetAlias returns the alias of the SelectValues.
562 | func (vs SelectValues) GetAlias() string { return vs.Alias }
563 |
564 | // IsTable implements the Table interface.
565 | func (vs SelectValues) IsTable() {}
566 |
567 | // TableValues represents a table literal created by the VALUES clause e.g.
568 | //
569 | // (VALUES
570 | //
571 | // (1, 2, 3),
572 | // (4, 5, 6),
573 | // (7, 8, 9)) AS tbl (a, b, c)
574 | type TableValues struct {
575 | Alias string
576 | Columns []string
577 | RowValues [][]any
578 | }
579 |
580 | var _ interface {
581 | Query
582 | Table
583 | } = (*TableValues)(nil)
584 |
585 | // WriteSQL implements the SQLWriter interface.
586 | func (vs TableValues) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
587 | if len(vs.RowValues) == 0 {
588 | return nil
589 | }
590 | var err error
591 | buf.WriteString("VALUES ")
592 | for i, rowvalue := range vs.RowValues {
593 | if len(vs.Columns) > 0 && len(vs.Columns) != len(rowvalue) {
594 | return fmt.Errorf("rowvalue #%d: got %d values, want %d values (%s)", i+1, len(rowvalue), len(vs.Columns), strings.Join(vs.Columns, ", "))
595 | }
596 | if i > 0 {
597 | buf.WriteString(", ")
598 | }
599 | if dialect == DialectMySQL {
600 | buf.WriteString("ROW(")
601 | } else {
602 | buf.WriteString("(")
603 | }
604 | for j, value := range rowvalue {
605 | if j > 0 {
606 | buf.WriteString(", ")
607 | }
608 | err = WriteValue(ctx, dialect, buf, args, params, value)
609 | if err != nil {
610 | return fmt.Errorf("rowvalue #%d value #%d: %w", i+1, j+1, err)
611 | }
612 | }
613 | buf.WriteString(")")
614 | }
615 | return nil
616 | }
617 |
618 | // Field returns a new field qualified by the TableValues' alias.
619 | func (vs TableValues) Field(name string) AnyField {
620 | return NewAnyField(name, TableStruct{alias: vs.Alias})
621 | }
622 |
623 | // SetFetchableFields implements the Query interface. It always returns false
624 | // as the second result.
625 | func (vs TableValues) SetFetchableFields([]Field) (query Query, ok bool) { return vs, false }
626 |
627 | // GetDialect implements the Query interface. It always returns an empty
628 | // string.
629 | func (vs TableValues) GetDialect() string { return "" }
630 |
631 | // GetAlias returns the alias of the TableValues.
632 | func (vs TableValues) GetAlias() string { return vs.Alias }
633 |
634 | // GetColumns returns the names of the columns in the TableValues.
635 | func (vs TableValues) GetColumns() []string { return vs.Columns }
636 |
637 | // IsTable implements the Table interface.
638 | func (vs TableValues) IsTable() {}
639 |
--------------------------------------------------------------------------------
/misc_test.go:
--------------------------------------------------------------------------------
1 | package sq
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "testing"
7 | "time"
8 |
9 | "github.com/bokwoon95/sq/internal/testutil"
10 | )
11 |
12 | func TestValueExpression(t *testing.T) {
13 | t.Run("alias", func(t *testing.T) {
14 | t.Parallel()
15 | expr := Value(1).As("num")
16 | if diff := testutil.Diff(expr.GetAlias(), "num"); diff != "" {
17 | t.Error(testutil.Callers(), diff)
18 | }
19 | })
20 |
21 | tests := []TestTable{{
22 | description: "basic",
23 | item: Value(Param("xyz", 42)),
24 | wantQuery: "?",
25 | wantArgs: []any{42},
26 | wantParams: map[string][]int{"xyz": {0}},
27 | }, {
28 | description: "In", item: Value(1).In([]int{18, 21, 32}),
29 | wantQuery: "? IN (?, ?, ?)", wantArgs: []any{1, 18, 21, 32},
30 | }, {
31 | description: "Eq", item: Value(1).Eq(34),
32 | wantQuery: "? = ?", wantArgs: []any{1, 34},
33 | }, {
34 | description: "Ne", item: Value(1).Ne(34),
35 | wantQuery: "? <> ?", wantArgs: []any{1, 34},
36 | }, {
37 | description: "Lt", item: Value(1).Lt(34),
38 | wantQuery: "? < ?", wantArgs: []any{1, 34},
39 | }, {
40 | description: "Le", item: Value(1).Le(34),
41 | wantQuery: "? <= ?", wantArgs: []any{1, 34},
42 | }, {
43 | description: "Gt", item: Value(1).Gt(34),
44 | wantQuery: "? > ?", wantArgs: []any{1, 34},
45 | }, {
46 | description: "Ge", item: Value(1).Ge(34),
47 | wantQuery: "? >= ?", wantArgs: []any{1, 34},
48 | }}
49 |
50 | for _, tt := range tests {
51 | tt := tt
52 | t.Run(tt.description, func(t *testing.T) {
53 | t.Parallel()
54 | tt.assert(t)
55 | })
56 | }
57 | }
58 |
59 | func TestLiteralExpression(t *testing.T) {
60 | t.Run("alias", func(t *testing.T) {
61 | t.Parallel()
62 | expr := Literal(1).As("num")
63 | if diff := testutil.Diff(expr.GetAlias(), "num"); diff != "" {
64 | t.Error(testutil.Callers(), diff)
65 | }
66 | })
67 |
68 | tests := []TestTable{{
69 | description: "binary",
70 | item: Literal([]byte{0xab, 0xcd, 0xef}),
71 | wantQuery: "x'abcdef'",
72 | }, {
73 | description: "time", item: Literal(time.Unix(0, 0).UTC()),
74 | wantQuery: "'1970-01-01 00:00:00'",
75 | }, {
76 | description: "In", item: Literal(1).In([]any{Literal(18), Literal(21), Literal(32)}),
77 | wantQuery: "1 IN (18, 21, 32)",
78 | }, {
79 | description: "Eq", item: Literal(true).Eq(Literal(false)),
80 | wantQuery: "TRUE = FALSE",
81 | }, {
82 | description: "Ne", item: Literal("one").Ne(Literal("thirty four")),
83 | wantQuery: "'one' <> 'thirty four'",
84 | }, {
85 | description: "Lt", item: Literal(1).Lt(Literal(34)),
86 | wantQuery: "1 < 34",
87 | }, {
88 | description: "Le", item: Literal(1).Le(Literal(34)),
89 | wantQuery: "1 <= 34",
90 | }, {
91 | description: "Gt", item: Literal(1).Gt(Literal(34)),
92 | wantQuery: "1 > 34",
93 | }, {
94 | description: "Ge", item: Literal(1).Ge(Literal(34)),
95 | wantQuery: "1 >= 34",
96 | }}
97 |
98 | for _, tt := range tests {
99 | tt := tt
100 | t.Run(tt.description, func(t *testing.T) {
101 | t.Parallel()
102 | tt.assert(t)
103 | })
104 | }
105 | }
106 |
107 | func TestDialectExpression(t *testing.T) {
108 | t.Parallel()
109 | expr := DialectValue(Expr("default")).
110 | DialectValue(DialectSQLite, Expr("sqlite")).
111 | DialectValue(DialectPostgres, Expr("postgres")).
112 | DialectValue(DialectMySQL, Expr("mysql")).
113 | DialectExpr(DialectSQLServer, "{}", Expr("sqlserver"))
114 | var tt TestTable
115 | tt.item = expr
116 | // default
117 | tt.wantQuery = "default"
118 | tt.assert(t)
119 | // sqlite
120 | tt.dialect = DialectSQLite
121 | tt.wantQuery = "sqlite"
122 | tt.assert(t)
123 | // postgres
124 | tt.dialect = DialectPostgres
125 | tt.wantQuery = "postgres"
126 | tt.assert(t)
127 | // mysql
128 | tt.dialect = DialectMySQL
129 | tt.wantQuery = "mysql"
130 | tt.assert(t)
131 | // sqlserver
132 | tt.dialect = DialectSQLServer
133 | tt.wantQuery = "sqlserver"
134 | tt.assert(t)
135 | }
136 |
137 | func TestCaseExpressions(t *testing.T) {
138 | t.Run("name and alias", func(t *testing.T) {
139 | t.Parallel()
140 | // CaseExpression
141 | caseExpr := CaseWhen(Value(true), 1).As("result_a")
142 | if diff := testutil.Diff(caseExpr.GetAlias(), "result_a"); diff != "" {
143 | t.Error(testutil.Callers(), diff)
144 | }
145 | // SimpleCaseExpression
146 | simpleCaseExpr := Case(1).When(1, 2).As("result_b")
147 | if diff := testutil.Diff(simpleCaseExpr.GetAlias(), "result_b"); diff != "" {
148 | t.Error(testutil.Callers(), diff)
149 | }
150 | })
151 |
152 | t.Run("CaseExpression", func(t *testing.T) {
153 | t.Parallel()
154 | TestTable{
155 | item: CaseWhen(Expr("x = y"), 1).When(Expr("a = b"), 2).Else(3),
156 | wantQuery: "CASE WHEN x = y THEN ? WHEN a = b THEN ? ELSE ? END",
157 | wantArgs: []any{1, 2, 3},
158 | }.assert(t)
159 | })
160 |
161 | t.Run("SimpleCaseExpression", func(t *testing.T) {
162 | t.Parallel()
163 | TestTable{
164 | item: Case(Expr("a")).When(1, 2).When(3, 4).Else(5),
165 | wantQuery: "CASE a WHEN ? THEN ? WHEN ? THEN ? ELSE ? END",
166 | wantArgs: []any{1, 2, 3, 4, 5},
167 | }.assert(t)
168 | })
169 |
170 | t.Run("CaseExpression cannot be empty", func(t *testing.T) {
171 | t.Parallel()
172 | TestTable{item: CaseExpression{}}.assertNotOK(t)
173 | })
174 |
175 | t.Run("SimpleCaseExpression cannot be empty", func(t *testing.T) {
176 | t.Parallel()
177 | TestTable{item: SimpleCaseExpression{}}.assertNotOK(t)
178 | })
179 |
180 | errTests := []TestTable{{
181 | description: "CASE WHEN predicate err", item: CaseWhen(FaultySQL{}, 1),
182 | }, {
183 | description: "CASE WHEN result err", item: CaseWhen(Value(true), FaultySQL{}),
184 | }, {
185 | description: "CASE WHEN fallback err", item: CaseWhen(Value(true), 1).Else(FaultySQL{}),
186 | }, {
187 | description: "CASE expression err", item: Case(FaultySQL{}).When(1, 2),
188 | }, {
189 | description: "CASE value err", item: Case(1).When(FaultySQL{}, 2),
190 | }, {
191 | description: "CASE result err", item: Case(1).When(2, FaultySQL{}),
192 | }, {
193 | description: "CASE fallback err", item: Case(1).When(2, 3).Else(FaultySQL{}),
194 | }}
195 |
196 | for _, tt := range errTests {
197 | tt := tt
198 | t.Run(tt.description, func(t *testing.T) {
199 | t.Parallel()
200 | tt.assertErr(t, ErrFaultySQL)
201 | })
202 | }
203 | }
204 |
205 | func TestSelectValues(t *testing.T) {
206 | type TestTable struct {
207 | description string
208 | dialect string
209 | item SelectValues
210 | wantQuery string
211 | wantArgs []any
212 | }
213 |
214 | t.Run("dialect alias and fields", func(t *testing.T) {
215 | selectValues := SelectValues{
216 | Alias: "aaa",
217 | }
218 | if diff := testutil.Diff(selectValues.GetAlias(), "aaa"); diff != "" {
219 | t.Error(testutil.Callers(), diff)
220 | }
221 | if diff := testutil.Diff(selectValues.GetDialect(), ""); diff != "" {
222 | t.Error(testutil.Callers(), diff)
223 | }
224 | _, ok := selectValues.SetFetchableFields(nil)
225 | if diff := testutil.Diff(ok, false); diff != "" {
226 | t.Error(testutil.Callers(), diff)
227 | }
228 | gotField, _, _ := ToSQL("", selectValues.Field("bbb"), nil)
229 | if diff := testutil.Diff(gotField, "aaa.bbb"); diff != "" {
230 | t.Error(testutil.Callers(), diff)
231 | }
232 | })
233 |
234 | tests := []TestTable{{
235 | description: "empty",
236 | item: SelectValues{},
237 | wantQuery: "",
238 | wantArgs: nil,
239 | }, {
240 | description: "no columns",
241 | item: SelectValues{
242 | RowValues: [][]any{
243 | {1, 2, 3},
244 | {4, 5, 6},
245 | {7, 8, 9},
246 | },
247 | },
248 | wantQuery: "SELECT ?, ?, ?" +
249 | " UNION ALL SELECT ?, ?, ?" +
250 | " UNION ALL SELECT ?, ?, ?",
251 | wantArgs: []any{1, 2, 3, 4, 5, 6, 7, 8, 9},
252 | }, {
253 | description: "postgres",
254 | dialect: DialectPostgres,
255 | item: SelectValues{
256 | Columns: []string{"a", "b", "c"},
257 | RowValues: [][]any{
258 | {1, 2, 3},
259 | {4, 5, 6},
260 | {7, 8, 9},
261 | },
262 | },
263 | wantQuery: "SELECT $1 AS a, $2 AS b, $3 AS c" +
264 | " UNION ALL SELECT $4, $5, $6" +
265 | " UNION ALL SELECT $7, $8, $9",
266 | wantArgs: []any{1, 2, 3, 4, 5, 6, 7, 8, 9},
267 | }}
268 |
269 | for _, tt := range tests {
270 | tt := tt
271 | t.Run(tt.description, func(t *testing.T) {
272 | t.Parallel()
273 | var buf bytes.Buffer
274 | var gotArgs []any
275 | err := tt.item.WriteSQL(context.Background(), tt.dialect, &buf, &gotArgs, nil)
276 | if err != nil {
277 | t.Fatal(testutil.Callers(), err)
278 | }
279 | gotQuery := buf.String()
280 | if diff := testutil.Diff(gotQuery, tt.wantQuery); diff != "" {
281 | t.Error(testutil.Callers(), diff)
282 | }
283 | if diff := testutil.Diff(gotArgs, tt.wantArgs); diff != "" {
284 | t.Error(testutil.Callers(), diff)
285 | }
286 | })
287 | }
288 | }
289 |
290 | func TestTableValues(t *testing.T) {
291 | type TestTable struct {
292 | description string
293 | dialect string
294 | item TableValues
295 | wantQuery string
296 | wantArgs []any
297 | }
298 |
299 | t.Run("dialect alias columns and fields", func(t *testing.T) {
300 | tableValues := TableValues{
301 | Alias: "aaa",
302 | Columns: []string{"a", "b", "c"},
303 | }
304 | if diff := testutil.Diff(tableValues.GetAlias(), "aaa"); diff != "" {
305 | t.Error(testutil.Callers(), diff)
306 | }
307 | if diff := testutil.Diff(tableValues.GetDialect(), ""); diff != "" {
308 | t.Error(testutil.Callers(), diff)
309 | }
310 | _, ok := tableValues.SetFetchableFields(nil)
311 | if diff := testutil.Diff(ok, false); diff != "" {
312 | t.Error(testutil.Callers(), diff)
313 | }
314 | gotColumns := tableValues.GetColumns()
315 | wantColumns := []string{"a", "b", "c"}
316 | if diff := testutil.Diff(gotColumns, wantColumns); diff != "" {
317 | t.Error(testutil.Callers(), diff)
318 | }
319 | gotField, _, _ := ToSQL("", tableValues.Field("bbb"), nil)
320 | wantField := "aaa.bbb"
321 | if diff := testutil.Diff(gotField, wantField); diff != "" {
322 | t.Error(testutil.Callers(), diff)
323 | }
324 | })
325 |
326 | tests := []TestTable{{
327 | description: "empty",
328 | item: TableValues{},
329 | wantQuery: "",
330 | wantArgs: nil,
331 | }, {
332 | description: "no columns",
333 | item: TableValues{
334 | RowValues: [][]any{
335 | {1, 2, 3},
336 | {4, 5, 6},
337 | {7, 8, 9},
338 | },
339 | },
340 | wantQuery: "VALUES (?, ?, ?)" +
341 | ", (?, ?, ?)" +
342 | ", (?, ?, ?)",
343 | wantArgs: []any{1, 2, 3, 4, 5, 6, 7, 8, 9},
344 | }, {
345 | description: "postgres",
346 | dialect: DialectPostgres,
347 | item: TableValues{
348 | Columns: []string{"a", "b", "c"},
349 | RowValues: [][]any{
350 | {1, 2, 3},
351 | {4, 5, 6},
352 | {7, 8, 9},
353 | },
354 | },
355 | wantQuery: "VALUES ($1, $2, $3)" +
356 | ", ($4, $5, $6)" +
357 | ", ($7, $8, $9)",
358 | wantArgs: []any{1, 2, 3, 4, 5, 6, 7, 8, 9},
359 | }, {
360 | description: "mysql",
361 | dialect: DialectMySQL,
362 | item: TableValues{
363 | Columns: []string{"a", "b", "c"},
364 | RowValues: [][]any{
365 | {1, 2, 3},
366 | {4, 5, 6},
367 | {7, 8, 9},
368 | },
369 | },
370 | wantQuery: "VALUES ROW(?, ?, ?)" +
371 | ", ROW(?, ?, ?)" +
372 | ", ROW(?, ?, ?)",
373 | wantArgs: []any{1, 2, 3, 4, 5, 6, 7, 8, 9},
374 | }}
375 |
376 | for _, tt := range tests {
377 | tt := tt
378 | t.Run(tt.description, func(t *testing.T) {
379 | t.Parallel()
380 | var buf bytes.Buffer
381 | var gotArgs []any
382 | err := tt.item.WriteSQL(context.Background(), tt.dialect, &buf, &gotArgs, nil)
383 | if err != nil {
384 | t.Fatal(testutil.Callers(), err)
385 | }
386 | gotQuery := buf.String()
387 | if diff := testutil.Diff(gotQuery, tt.wantQuery); diff != "" {
388 | t.Error(testutil.Callers(), diff)
389 | }
390 | if diff := testutil.Diff(gotArgs, tt.wantArgs); diff != "" {
391 | t.Error(testutil.Callers(), diff)
392 | }
393 | })
394 | }
395 | }
396 |
--------------------------------------------------------------------------------
/sq.go:
--------------------------------------------------------------------------------
1 | package sq
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "database/sql"
7 | "database/sql/driver"
8 | "encoding/json"
9 | "fmt"
10 | "reflect"
11 | "strings"
12 | "sync"
13 |
14 | "github.com/bokwoon95/sq/internal/googleuuid"
15 | "github.com/bokwoon95/sq/internal/pqarray"
16 | )
17 |
// bufpool recycles *bytes.Buffer values used for rendering SQL (e.g. by
// toString). Callers must Reset a buffer after Get.
var bufpool = &sync.Pool{
	New: func() any {
		return new(bytes.Buffer)
	},
}

// Dialects supported.
const (
	DialectSQLite    = "sqlite"
	DialectPostgres  = "postgres"
	DialectMySQL     = "mysql"
	DialectSQLServer = "sqlserver"
)
29 |
type (
	// SQLWriter is anything that can be converted to SQL.
	SQLWriter interface {
		// WriteSQL writes the SQL representation of the SQLWriter into the
		// query string (*bytes.Buffer) and the args slice (*[]any).
		//
		// The params map records, for each named parameter in the query, the
		// corresponding indexes in the args slice; it is used for rebinding
		// args by parameter name. params may be nil, so check it before
		// writing to it.
		WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error
	}

	// DB is the database/sql abstraction used to query the database.
	// *sql.Conn, *sql.DB and *sql.Tx all implement DB.
	DB interface {
		QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
		ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
		PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
	}

	// Result is the result of an Exec command.
	Result struct {
		LastInsertId int64
		RowsAffected int64
	}
)

type (
	// Query is a SELECT, INSERT, UPDATE or DELETE query.
	Query interface {
		SQLWriter
		// SetFetchableFields should return a query whose fetchable fields
		// are set to the given fields. When that does not apply, it should
		// return false as the second result.
		SetFetchableFields([]Field) (query Query, ok bool)
		GetDialect() string
	}

	// Table is anything that can be selected from or joined.
	Table interface {
		SQLWriter
		IsTable()
	}

	// PolicyTable is a Table that produces a policy (a predicate) to be
	// enforced whenever it is used in a query. It is an application-side
	// equivalent of Postgres' Row Level Security (RLS). Only SELECT, UPDATE
	// and DELETE queries are affected.
	PolicyTable interface {
		Table
		Policy(ctx context.Context, dialect string) (Predicate, error)
	}

	// Window is a window used by SQL window functions.
	Window interface {
		SQLWriter
		IsWindow()
	}

	// Field is either a table column or some SQL expression.
	Field interface {
		SQLWriter
		IsField()
	}

	// Predicate is an SQL expression that evaluates to true or false.
	Predicate interface {
		Boolean
	}

	// Assignment is an SQL assignment of the form 'field = value'.
	Assignment interface {
		SQLWriter
		IsAssignment()
	}
)

type (
	// Any is a catch-all interface satisfied only by fields that implement
	// every field type below.
	Any interface {
		Array
		Binary
		Boolean
		Enum
		JSON
		Number
		String
		Time
		UUID
	}

	// Array is a Field of array type.
	Array interface {
		Field
		IsArray()
	}

	// Binary is a Field of binary type.
	Binary interface {
		Field
		IsBinary()
	}

	// Boolean is a Field of boolean type.
	Boolean interface {
		Field
		IsBoolean()
	}

	// Enum is a Field of enum type.
	Enum interface {
		Field
		IsEnum()
	}

	// JSON is a Field of JSON type.
	JSON interface {
		Field
		IsJSON()
	}

	// Number is a Field of numeric type.
	Number interface {
		Field
		IsNumber()
	}

	// String is a Field of string type.
	String interface {
		Field
		IsString()
	}

	// Time is a Field of time type.
	Time interface {
		Field
		IsTime()
	}

	// UUID is a Field of UUID type.
	UUID interface {
		Field
		IsUUID()
	}
)

type (
	// Enumeration represents a Go enum.
	Enumeration interface {
		// Enumerate returns the names of all valid enum values.
		//
		// For a string-backed enum, each string in the slice is the string
		// value of the corresponding enum.
		//
		// For an int-backed enum, each index in the slice is the int value of
		// the corresponding enum and the string at that index is its name.
		// Enums with an empty name are considered invalid, except for the
		// very first enum (index 0).
		Enumerate() []string
	}

	// DialectValuer is any type that yields a different driver.Valuer
	// depending on the SQL dialect.
	DialectValuer interface {
		DialectValuer(dialect string) (driver.Valuer, error)
	}
)
190 |
191 | // TableStruct is meant to be embedded in table structs to make them implement
192 | // the Table interface.
193 | type TableStruct struct {
194 | schema string
195 | name string
196 | alias string
197 | }
198 |
199 | // ViewStruct is just an alias for TableStruct.
200 | type ViewStruct = TableStruct
201 |
202 | var _ Table = (*TableStruct)(nil)
203 |
204 | // NewTableStruct creates a new TableStruct.
205 | func NewTableStruct(schema, name, alias string) TableStruct {
206 | return TableStruct{schema: schema, name: name, alias: alias}
207 | }
208 |
209 | // WriteSQL implements the SQLWriter interface.
210 | func (ts TableStruct) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
211 | if ts.schema != "" {
212 | buf.WriteString(QuoteIdentifier(dialect, ts.schema) + ".")
213 | }
214 | buf.WriteString(QuoteIdentifier(dialect, ts.name))
215 | return nil
216 | }
217 |
218 | // GetAlias returns the alias of the TableStruct.
219 | func (ts TableStruct) GetAlias() string { return ts.alias }
220 |
221 | // IsTable implements the Table interface.
222 | func (ts TableStruct) IsTable() {}
223 |
224 | func withPrefix(w SQLWriter, prefix string) SQLWriter {
225 | if field, ok := w.(interface {
226 | SQLWriter
227 | WithPrefix(string) Field
228 | }); ok {
229 | return field.WithPrefix(prefix)
230 | }
231 | return w
232 | }
233 |
234 | func getAlias(w SQLWriter) string {
235 | if w, ok := w.(interface{ GetAlias() string }); ok {
236 | return w.GetAlias()
237 | }
238 | return ""
239 | }
240 |
241 | func toString(dialect string, w SQLWriter) string {
242 | buf := bufpool.Get().(*bytes.Buffer)
243 | buf.Reset()
244 | defer bufpool.Put(buf)
245 | var args []any
246 | _ = w.WriteSQL(context.Background(), dialect, buf, &args, nil)
247 | return buf.String()
248 | }
249 |
250 | func writeFieldsWithPrefix(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int, fields []Field, prefix string, includeAlias bool) error {
251 | var err error
252 | var alias string
253 | for i, field := range fields {
254 | if field == nil {
255 | return fmt.Errorf("field #%d is nil", i+1)
256 | }
257 | if i > 0 {
258 | buf.WriteString(", ")
259 | }
260 | err = withPrefix(field, prefix).WriteSQL(ctx, dialect, buf, args, params)
261 | if err != nil {
262 | return fmt.Errorf("field #%d: %w", i+1, err)
263 | }
264 | if includeAlias {
265 | if alias = getAlias(field); alias != "" {
266 | buf.WriteString(" AS " + QuoteIdentifier(dialect, alias))
267 | }
268 | }
269 | }
270 | return nil
271 | }
272 |
273 | func writeFields(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int, fields []Field, includeAlias bool) error {
274 | var err error
275 | var alias string
276 | for i, field := range fields {
277 | if field == nil {
278 | return fmt.Errorf("field #%d is nil", i+1)
279 | }
280 | if i > 0 {
281 | buf.WriteString(", ")
282 | }
283 | _, isQuery := field.(Query)
284 | if isQuery {
285 | buf.WriteString("(")
286 | }
287 | err = field.WriteSQL(ctx, dialect, buf, args, params)
288 | if err != nil {
289 | return fmt.Errorf("field #%d: %w", i+1, err)
290 | }
291 | if isQuery {
292 | buf.WriteString(")")
293 | }
294 | if includeAlias {
295 | if alias = getAlias(field); alias != "" {
296 | buf.WriteString(" AS " + QuoteIdentifier(dialect, alias))
297 | }
298 | }
299 | }
300 | return nil
301 | }
302 |
// mapperFunctionPanicked recovers from any panics and stores the panic as an
// error in *err. It must be invoked directly via defer
// (defer mapperFunctionPanicked(&err)) so that recover takes effect.
//
// The function is named this way so that it shows up as
// "sq.mapperFunctionPanicked" in a panic stack trace, giving the user a
// descriptive clue of what went wrong (i.e. their mapper function panicked).
//
// An error panic value is stored as-is. Any other value is formatted with %v.
// If nothing panicked, *err is left untouched.
func mapperFunctionPanicked(err *error) {
	r := recover()
	if r == nil {
		return
	}
	if e, ok := r.(error); ok {
		*err = e
		return
	}
	// Use a constant format string: the panic value is arbitrary text, and any
	// '%' inside it would otherwise be read as a formatting verb
	// (e.g. "100% broken" became "100%!b(MISSING)roken").
	*err = fmt.Errorf("%v", r)
}
318 |
319 | // ArrayValue takes in a []string, []int, []int64, []int32, []float64,
320 | // []float32 or []bool and returns a driver.Valuer for that type. For Postgres,
321 | // it serializes into a Postgres array. Otherwise, it serializes into a JSON
322 | // array.
323 | func ArrayValue(value any) driver.Valuer {
324 | return &arrayValue{value: value}
325 | }
326 |
327 | type arrayValue struct {
328 | dialect string
329 | value any
330 | }
331 |
332 | // Value implements the driver.Valuer interface.
333 | func (v *arrayValue) Value() (driver.Value, error) {
334 | switch v.value.(type) {
335 | case []string, []int, []int64, []int32, []float64, []float32, []bool:
336 | break
337 | default:
338 | return nil, fmt.Errorf("value %#v is not a []string, []int, []int32, []float64, []float32 or []bool", v.value)
339 | }
340 | if v.dialect != DialectPostgres {
341 | var b strings.Builder
342 | err := json.NewEncoder(&b).Encode(v.value)
343 | if err != nil {
344 | return nil, err
345 | }
346 | return strings.TrimSpace(b.String()), nil
347 | }
348 | if ints, ok := v.value.([]int); ok {
349 | bigints := make([]int64, len(ints))
350 | for i, num := range ints {
351 | bigints[i] = int64(num)
352 | }
353 | v.value = bigints
354 | }
355 | return pqarray.Array(v.value).Value()
356 | }
357 |
358 | // DialectValuer implements the DialectValuer interface.
359 | func (v *arrayValue) DialectValuer(dialect string) (driver.Valuer, error) {
360 | v.dialect = dialect
361 | return v, nil
362 | }
363 |
364 | // EnumValue takes in an Enumeration and returns a driver.Valuer which
365 | // serializes the enum into a string and additionally checks if the enum is
366 | // valid.
367 | func EnumValue(value Enumeration) driver.Valuer {
368 | return &enumValue{value: value}
369 | }
370 |
371 | type enumValue struct {
372 | value Enumeration
373 | }
374 |
375 | // Value implements the driver.Valuer interface.
376 | func (v *enumValue) Value() (driver.Value, error) {
377 | value := reflect.ValueOf(v.value)
378 | names := v.value.Enumerate()
379 | switch value.Kind() {
380 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
381 | i := int(value.Int())
382 | if i < 0 || i >= len(names) {
383 | return nil, fmt.Errorf("%d is not a valid %T", i, v.value)
384 | }
385 | name := names[i]
386 | if name == "" && i != 0 {
387 | return nil, fmt.Errorf("%d is not a valid %T", i, v.value)
388 | }
389 | return name, nil
390 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
391 | i := int(value.Uint())
392 | if i < 0 || i >= len(names) {
393 | return nil, fmt.Errorf("%d is not a valid %T", i, v.value)
394 | }
395 | name := names[i]
396 | if name == "" && i != 0 {
397 | return nil, fmt.Errorf("%d is not a valid %T", i, v.value)
398 | }
399 | return name, nil
400 | case reflect.String:
401 | typ := value.Type()
402 | name := value.String()
403 | if getEnumIndex(name, names, typ) < 0 {
404 | return nil, fmt.Errorf("%q is not a valid %T", name, v.value)
405 | }
406 | return name, nil
407 | default:
408 | return nil, fmt.Errorf("underlying type of %[1]v is neither an integer nor string (%[1]T)", v.value)
409 | }
410 | }
411 |
var (
	// enumIndexMu guards enumIndex.
	enumIndexMu sync.RWMutex
	// enumIndex caches, per enum type, a lookup table from each name to the
	// index of its first occurrence in that type's names slice.
	enumIndex = make(map[reflect.Type]map[string]int)
)

// getEnumIndex returns the index of the first occurrence of name within names,
// or -1 if name is not present.
//
// Slices of up to four names are scanned linearly. For longer slices, a lookup
// table is built on first use and cached under typ. The cache is keyed only by
// typ, so callers must pass the same names for a given typ every time. Both
// paths resolve duplicate names (such as "" placeholders) to the first index.
func getEnumIndex(name string, names []string, typ reflect.Type) int {
	if len(names) <= 4 {
		for i, n := range names {
			if n == name {
				return i
			}
		}
		return -1
	}
	enumIndexMu.RLock()
	nameIndex := enumIndex[typ]
	enumIndexMu.RUnlock()
	if nameIndex == nil {
		nameIndex = make(map[string]int, len(names))
		for i, n := range names {
			// Keep the first occurrence so the result agrees with the linear
			// scan above when names contains duplicates.
			if _, seen := nameIndex[n]; !seen {
				nameIndex[n] = i
			}
		}
		// Concurrent first calls may each build the table; they produce the
		// same contents, so last-writer-wins is harmless.
		enumIndexMu.Lock()
		enumIndex[typ] = nameIndex
		enumIndexMu.Unlock()
	}
	if idx, ok := nameIndex[name]; ok {
		return idx
	}
	return -1
}
451 |
// JSONValue takes in an interface{} and returns a driver.Valuer which encodes
// the value as JSON (using encoding/json) before submitting it to the
// database.
func JSONValue(value any) driver.Valuer {
	return &jsonValue{value: value}
}

// jsonValue is the driver.Valuer returned by JSONValue.
type jsonValue struct {
	value any
}

// Value implements the driver.Valuer interface. It returns the JSON encoding
// of the wrapped value as a string, without the trailing newline that
// json.Encoder appends. If encoding fails, it returns a nil value and the
// error.
func (v *jsonValue) Value() (driver.Value, error) {
	var b strings.Builder
	if err := json.NewEncoder(&b).Encode(v.value); err != nil {
		// Never hand a partially written document to the driver.
		return nil, err
	}
	return strings.TrimSpace(b.String()), nil
}
468 |
469 | // UUIDValue takes in a type whose underlying type must be a [16]byte and
470 | // returns a driver.Valuer.
471 | func UUIDValue(value any) driver.Valuer {
472 | return &uuidValue{value: value}
473 | }
474 |
475 | type uuidValue struct {
476 | dialect string
477 | value any
478 | }
479 |
480 | // Value implements the driver.Valuer interface.
481 | func (v *uuidValue) Value() (driver.Value, error) {
482 | if v.value == nil {
483 | return nil, nil
484 | }
485 | uuid, ok := v.value.([16]byte)
486 | if !ok {
487 | value := reflect.ValueOf(v.value)
488 | typ := value.Type()
489 | if value.Kind() != reflect.Array || value.Len() != 16 || typ.Elem().Kind() != reflect.Uint8 {
490 | return nil, fmt.Errorf("%[1]v %[1]T is not [16]byte", v.value)
491 | }
492 | for i := 0; i < value.Len(); i++ {
493 | uuid[i] = value.Index(i).Interface().(byte)
494 | }
495 | }
496 | if v.dialect != DialectPostgres {
497 | return uuid[:], nil
498 | }
499 | var buf [36]byte
500 | googleuuid.EncodeHex(buf[:], uuid)
501 | return string(buf[:]), nil
502 | }
503 |
504 | // DialectValuer implements the DialectValuer interface.
505 | func (v *uuidValue) DialectValuer(dialect string) (driver.Valuer, error) {
506 | v.dialect = dialect
507 | return v, nil
508 | }
509 |
510 | func preprocessValue(dialect string, value any) (any, error) {
511 | if dialectValuer, ok := value.(DialectValuer); ok {
512 | driverValuer, err := dialectValuer.DialectValuer(dialect)
513 | if err != nil {
514 | return nil, fmt.Errorf("calling DialectValuer on %#v: %w", dialectValuer, err)
515 | }
516 | value = driverValuer
517 | }
518 | switch value := value.(type) {
519 | case nil:
520 | return nil, nil
521 | case Enumeration:
522 | driverValue, err := (&enumValue{value: value}).Value()
523 | if err != nil {
524 | return nil, fmt.Errorf("converting %#v to string: %w", value, err)
525 | }
526 | return driverValue, nil
527 | case [16]byte:
528 | driverValue, err := (&uuidValue{dialect: dialect, value: value}).Value()
529 | if err != nil {
530 | if dialect == DialectPostgres {
531 | return nil, fmt.Errorf("converting %#v to string: %w", value, err)
532 | }
533 | return nil, fmt.Errorf("converting %#v to bytes: %w", value, err)
534 | }
535 | return driverValue, nil
536 | case driver.Valuer:
537 | driverValue, err := value.Value()
538 | if err != nil {
539 | return nil, fmt.Errorf("calling Value on %#v: %w", value, err)
540 | }
541 | return driverValue, nil
542 | }
543 | return value, nil
544 | }
545 |
--------------------------------------------------------------------------------
/sq_test.go:
--------------------------------------------------------------------------------
1 | package sq
2 |
3 | import (
4 | "database/sql"
5 | "testing"
6 |
7 | "github.com/bokwoon95/sq/internal/testutil"
8 | "github.com/google/uuid"
9 | )
10 |
// Weekday is a test Enumeration whose underlying type is uint. Its zero value,
// WeekdayInvalid, maps to the empty name.
type Weekday uint

// Weekday constants, numbered explicitly. WeekdayInvalid is the zero value.
const (
	WeekdayInvalid Weekday = 0
	Sunday         Weekday = 1
	Monday         Weekday = 2
	Tuesday        Weekday = 3
	Wednesday      Weekday = 4
	Thursday       Weekday = 5
	Friday         Weekday = 6
	Saturday       Weekday = 7
)

// Enumerate implements the Enumeration interface. The returned slice is
// indexed by Weekday value; slot WeekdayInvalid is left empty.
func (Weekday) Enumerate() []string {
	names := make([]string, Saturday+1)
	names[Sunday] = "Sunday"
	names[Monday] = "Monday"
	names[Tuesday] = "Tuesday"
	names[Wednesday] = "Wednesday"
	names[Thursday] = "Thursday"
	names[Friday] = "Friday"
	names[Saturday] = "Saturday"
	return names
}
36 |
37 | func Test_preprocessValue(t *testing.T) {
38 | type TestTable struct {
39 | description string
40 | dialect string
41 | input any
42 | wantOutput any
43 | }
44 |
45 | tests := []TestTable{{
46 | description: "empty",
47 | input: nil,
48 | wantOutput: nil,
49 | }, {
50 | description: "driver.Valuer",
51 | input: uuid.MustParse("a4f952f1-4c45-4e63-bd4e-159ca33c8e20"),
52 | wantOutput: "a4f952f1-4c45-4e63-bd4e-159ca33c8e20",
53 | }, {
54 | description: "Postgres DialectValuer",
55 | dialect: DialectPostgres,
56 | input: UUIDValue(uuid.MustParse("a4f952f1-4c45-4e63-bd4e-159ca33c8e20")),
57 | wantOutput: "a4f952f1-4c45-4e63-bd4e-159ca33c8e20",
58 | }, {
59 | description: "MySQL DialectValuer",
60 | dialect: DialectMySQL,
61 | input: UUIDValue(uuid.MustParse("a4f952f1-4c45-4e63-bd4e-159ca33c8e20")),
62 | wantOutput: []byte{0xa4, 0xf9, 0x52, 0xf1, 0x4c, 0x45, 0x4e, 0x63, 0xbd, 0x4e, 0x15, 0x9c, 0xa3, 0x3c, 0x8e, 0x20},
63 | }, {
64 | description: "Postgres [16]byte",
65 | dialect: DialectPostgres,
66 | input: [16]byte{0xa4, 0xf9, 0x52, 0xf1, 0x4c, 0x45, 0x4e, 0x63, 0xbd, 0x4e, 0x15, 0x9c, 0xa3, 0x3c, 0x8e, 0x20},
67 | wantOutput: "a4f952f1-4c45-4e63-bd4e-159ca33c8e20",
68 | }, {
69 | description: "MySQL [16]byte",
70 | dialect: DialectMySQL,
71 | input: [16]byte{0xa4, 0xf9, 0x52, 0xf1, 0x4c, 0x45, 0x4e, 0x63, 0xbd, 0x4e, 0x15, 0x9c, 0xa3, 0x3c, 0x8e, 0x20},
72 | wantOutput: []byte{0xa4, 0xf9, 0x52, 0xf1, 0x4c, 0x45, 0x4e, 0x63, 0xbd, 0x4e, 0x15, 0x9c, 0xa3, 0x3c, 0x8e, 0x20},
73 | }, {
74 | description: "Enumeration",
75 | input: Monday,
76 | wantOutput: "Monday",
77 | }, {
78 | description: "int",
79 | input: 42,
80 | wantOutput: 42,
81 | }, {
82 | description: "sql.NullString",
83 | input: sql.NullString{
84 | Valid: false,
85 | String: "lorem ipsum dolor sit amet",
86 | },
87 | wantOutput: nil,
88 | }}
89 |
90 | for _, tt := range tests {
91 | tt := tt
92 | t.Run(tt.description, func(t *testing.T) {
93 | t.Parallel()
94 | gotOutput, err := preprocessValue(tt.dialect, tt.input)
95 | if err != nil {
96 | t.Fatal(testutil.Callers(), err)
97 | }
98 | if diff := testutil.Diff(gotOutput, tt.wantOutput); diff != "" {
99 | t.Error(testutil.Callers(), diff)
100 | }
101 | })
102 | }
103 | }
104 |
--------------------------------------------------------------------------------
/update_query_test.go:
--------------------------------------------------------------------------------
1 | package sq
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/bokwoon95/sq/internal/testutil"
7 | )
8 |
// TestSQLiteUpdateQuery checks the query string and arguments produced by
// UPDATE queries built with the SQLite builder. As the expected strings show,
// the SQLite output uses $N placeholders and unqualified column names in SET.
func TestSQLiteUpdateQuery(t *testing.T) {
	// ACTOR is a table struct; the expected queries show its fields rendered
	// as the actor table's snake_case columns.
	type ACTOR struct {
		TableStruct
		ACTOR_ID    NumberField
		FIRST_NAME  StringField
		LAST_NAME   StringField
		LAST_UPDATE TimeField
	}
	// "a" is rendered as the table alias ("actor AS a").
	a := New[ACTOR]("a")

	// basic covers the dialect getter/setter and the fetchable-field accessors.
	t.Run("basic", func(t *testing.T) {
		t.Parallel()
		q1 := SQLite.Update(a).Returning(a.FIRST_NAME).SetDialect("lorem ipsum")
		if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		q1 = q1.SetDialect(DialectSQLite)
		// The RETURNING fields are what GetFetchableFields reports.
		fields := q1.GetFetchableFields()
		if diff := testutil.Diff(fields, []Field{a.FIRST_NAME}); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		// SetFetchableFields must refuse while RETURNING fields are present...
		_, ok := q1.SetFetchableFields([]Field{a.LAST_NAME})
		if ok {
			t.Fatal(testutil.Callers(), "field should not have been set")
		}
		// ...and succeed once they have been cleared.
		q1.ReturningFields = q1.ReturningFields[:0]
		_, ok = q1.SetFetchableFields([]Field{a.LAST_NAME})
		if !ok {
			t.Fatal(testutil.Callers(), "field should have been set")
		}
	})

	// Set: a CTE, assignments passed to Set, multiple WHERE predicates (joined
	// with AND) and RETURNING.
	t.Run("Set", func(t *testing.T) {
		t.Parallel()
		var tt TestTable
		tt.item = SQLite.
			With(NewCTE("cte", nil, Queryf("SELECT 1"))).
			Update(a).
			Set(
				a.FIRST_NAME.SetString("bob"),
				a.LAST_NAME.SetString("the builder"),
			).
			Where(a.ACTOR_ID.EqInt(1), a.LAST_UPDATE.IsNotNull()).
			Returning(a.ACTOR_ID)
		tt.wantQuery = "WITH cte AS (SELECT 1)" +
			" UPDATE actor AS a" +
			" SET first_name = $1, last_name = $2" +
			" WHERE a.actor_id = $3 AND a.last_update IS NOT NULL" +
			" RETURNING a.actor_id"
		tt.wantArgs = []any{"bob", "the builder", 1}
		tt.assert(t)
	})

	// SetFunc: assignments made through a column-mapper callback render the
	// same SET clause as Set.
	t.Run("SetFunc", func(t *testing.T) {
		t.Parallel()
		var tt TestTable
		tt.item = SQLite.
			With(NewCTE("cte", nil, Queryf("SELECT 1"))).
			Update(a).
			SetFunc(func(col *Column) {
				col.SetString(a.FIRST_NAME, "bob")
				col.SetString(a.LAST_NAME, "the builder")
			}).
			Where(a.ACTOR_ID.EqInt(1))
		tt.wantQuery = "WITH cte AS (SELECT 1)" +
			" UPDATE actor AS a" +
			" SET first_name = $1, last_name = $2" +
			" WHERE a.actor_id = $3"
		tt.wantArgs = []any{"bob", "the builder", 1}
		tt.assert(t)
	})

	// UPDATE with JOIN: UPDATE ... SET ... FROM, followed by each join kind
	// exercised for SQLite, rendered in call order after FROM.
	t.Run("UPDATE with JOIN", func(t *testing.T) {
		t.Parallel()
		var tt TestTable
		tt.item = SQLite.
			Update(a).
			Set(
				a.FIRST_NAME.SetString("bob"),
				a.LAST_NAME.SetString("the builder"),
			).
			From(a).
			Join(a, a.ACTOR_ID.Eq(a.ACTOR_ID)).
			LeftJoin(a, a.ACTOR_ID.Eq(a.ACTOR_ID)).
			CrossJoin(a).
			CustomJoin(",", a).
			JoinUsing(a, a.FIRST_NAME, a.LAST_NAME).
			Where(a.ACTOR_ID.EqInt(1))
		tt.wantQuery = "UPDATE actor AS a" +
			" SET first_name = $1, last_name = $2" +
			" FROM actor AS a" +
			" JOIN actor AS a ON a.actor_id = a.actor_id" +
			" LEFT JOIN actor AS a ON a.actor_id = a.actor_id" +
			" CROSS JOIN actor AS a" +
			" , actor AS a" +
			" JOIN actor AS a USING (first_name, last_name)" +
			" WHERE a.actor_id = $3"
		tt.wantArgs = []any{"bob", "the builder", 1}
		tt.assert(t)
	})
}
110 |
// TestPostgresUpdateQuery checks the query string and arguments produced by
// UPDATE queries built with the Postgres builder. As the expected strings
// show, the Postgres output uses $N placeholders and unqualified column names
// in SET.
func TestPostgresUpdateQuery(t *testing.T) {
	// ACTOR is a table struct; the expected queries show its fields rendered
	// as the actor table's snake_case columns.
	type ACTOR struct {
		TableStruct
		ACTOR_ID    NumberField
		FIRST_NAME  StringField
		LAST_NAME   StringField
		LAST_UPDATE TimeField
	}
	// "a" is rendered as the table alias ("actor AS a").
	a := New[ACTOR]("a")

	// basic covers the dialect getter/setter and the fetchable-field accessors.
	t.Run("basic", func(t *testing.T) {
		t.Parallel()
		q1 := Postgres.Update(a).Returning(a.FIRST_NAME).SetDialect("lorem ipsum")
		if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		q1 = q1.SetDialect(DialectPostgres)
		// The RETURNING fields are what GetFetchableFields reports.
		fields := q1.GetFetchableFields()
		if diff := testutil.Diff(fields, []Field{a.FIRST_NAME}); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		// SetFetchableFields must refuse while RETURNING fields are present...
		_, ok := q1.SetFetchableFields([]Field{a.LAST_NAME})
		if ok {
			t.Fatal(testutil.Callers(), "field should not have been set")
		}
		// ...and succeed once they have been cleared.
		q1.ReturningFields = q1.ReturningFields[:0]
		_, ok = q1.SetFetchableFields([]Field{a.LAST_NAME})
		if !ok {
			t.Fatal(testutil.Callers(), "field should have been set")
		}
	})

	// Set: a CTE, assignments passed to Set, multiple WHERE predicates (joined
	// with AND) and RETURNING.
	t.Run("Set", func(t *testing.T) {
		t.Parallel()
		var tt TestTable
		tt.item = Postgres.
			With(NewCTE("cte", nil, Queryf("SELECT 1"))).
			Update(a).
			Set(
				a.FIRST_NAME.SetString("bob"),
				a.LAST_NAME.SetString("the builder"),
			).
			Where(a.ACTOR_ID.EqInt(1), a.LAST_UPDATE.IsNotNull()).
			Returning(a.ACTOR_ID)
		tt.wantQuery = "WITH cte AS (SELECT 1)" +
			" UPDATE actor AS a" +
			" SET first_name = $1, last_name = $2" +
			" WHERE a.actor_id = $3 AND a.last_update IS NOT NULL" +
			" RETURNING a.actor_id"
		tt.wantArgs = []any{"bob", "the builder", 1}
		tt.assert(t)
	})

	// SetFunc: assignments made through a column-mapper callback render the
	// same SET clause as Set.
	t.Run("SetFunc", func(t *testing.T) {
		t.Parallel()
		var tt TestTable
		tt.item = Postgres.
			With(NewCTE("cte", nil, Queryf("SELECT 1"))).
			Update(a).
			SetFunc(func(col *Column) {
				col.SetString(a.FIRST_NAME, "bob")
				col.SetString(a.LAST_NAME, "the builder")
			}).
			Where(a.ACTOR_ID.EqInt(1))
		tt.wantQuery = "WITH cte AS (SELECT 1)" +
			" UPDATE actor AS a" +
			" SET first_name = $1, last_name = $2" +
			" WHERE a.actor_id = $3"
		tt.wantArgs = []any{"bob", "the builder", 1}
		tt.assert(t)
	})

	// UPDATE with JOIN: UPDATE ... SET ... FROM, followed by every join kind
	// (including FULL JOIN), rendered in call order after FROM.
	t.Run("UPDATE with JOIN", func(t *testing.T) {
		t.Parallel()
		var tt TestTable
		tt.item = Postgres.
			Update(a).
			Set(
				a.FIRST_NAME.SetString("bob"),
				a.LAST_NAME.SetString("the builder"),
			).
			From(a).
			Join(a, a.ACTOR_ID.Eq(a.ACTOR_ID)).
			LeftJoin(a, a.ACTOR_ID.Eq(a.ACTOR_ID)).
			FullJoin(a, a.ACTOR_ID.Eq(a.ACTOR_ID)).
			CrossJoin(a).
			CustomJoin(",", a).
			JoinUsing(a, a.FIRST_NAME, a.LAST_NAME).
			Where(a.ACTOR_ID.EqInt(1))
		tt.wantQuery = "UPDATE actor AS a" +
			" SET first_name = $1, last_name = $2" +
			" FROM actor AS a" +
			" JOIN actor AS a ON a.actor_id = a.actor_id" +
			" LEFT JOIN actor AS a ON a.actor_id = a.actor_id" +
			" FULL JOIN actor AS a ON a.actor_id = a.actor_id" +
			" CROSS JOIN actor AS a" +
			" , actor AS a" +
			" JOIN actor AS a USING (first_name, last_name)" +
			" WHERE a.actor_id = $3"
		tt.wantArgs = []any{"bob", "the builder", 1}
		tt.assert(t)
	})
}
214 |
215 | func TestMySQLUpdateQuery(t *testing.T) {
216 | type ACTOR struct {
217 | TableStruct
218 | ACTOR_ID NumberField
219 | FIRST_NAME StringField
220 | LAST_NAME StringField
221 | LAST_UPDATE TimeField
222 | }
223 | a := New[ACTOR]("a")
224 |
225 | t.Run("basic", func(t *testing.T) {
226 | t.Parallel()
227 | q1 := MySQL.Update(a).SetDialect("lorem ipsum")
228 | if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" {
229 | t.Error(testutil.Callers(), diff)
230 | }
231 | q1 = q1.SetDialect(DialectMySQL)
232 | fields := q1.GetFetchableFields()
233 | if len(fields) != 0 {
234 | t.Error(testutil.Callers(), "expected 0 fields but got %v", fields)
235 | }
236 | _, ok := q1.SetFetchableFields([]Field{a.LAST_NAME})
237 | if ok {
238 | t.Fatal(testutil.Callers(), "field should not have been set")
239 | }
240 | q1.ReturningFields = q1.ReturningFields[:0]
241 | _, ok = q1.SetFetchableFields([]Field{a.LAST_NAME})
242 | if ok {
243 | t.Fatal(testutil.Callers(), "field should not have been set")
244 | }
245 | })
246 |
247 | t.Run("Set", func(t *testing.T) {
248 | t.Parallel()
249 | var tt TestTable
250 | tt.item = MySQL.
251 | With(NewCTE("cte", nil, Queryf("SELECT 1"))).
252 | Update(a).
253 | Set(
254 | a.FIRST_NAME.SetString("bob"),
255 | a.LAST_NAME.SetString("the builder"),
256 | ).
257 | Where(a.ACTOR_ID.EqInt(1))
258 | tt.wantQuery = "WITH cte AS (SELECT 1)" +
259 | " UPDATE actor AS a" +
260 | " SET a.first_name = ?, a.last_name = ?" +
261 | " WHERE a.actor_id = ?"
262 | tt.wantArgs = []any{"bob", "the builder", 1}
263 | tt.assert(t)
264 | })
265 |
266 | t.Run("SetFunc", func(t *testing.T) {
267 | t.Parallel()
268 | var tt TestTable
269 | tt.item = MySQL.
270 | With(NewCTE("cte", nil, Queryf("SELECT 1"))).
271 | Update(a).
272 | SetFunc(func(col *Column) {
273 | col.SetString(a.FIRST_NAME, "bob")
274 | col.SetString(a.LAST_NAME, "the builder")
275 | }).
276 | Where(a.ACTOR_ID.EqInt(1))
277 | tt.wantQuery = "WITH cte AS (SELECT 1)" +
278 | " UPDATE actor AS a" +
279 | " SET a.first_name = ?, a.last_name = ?" +
280 | " WHERE a.actor_id = ?"
281 | tt.wantArgs = []any{"bob", "the builder", 1}
282 | tt.assert(t)
283 | })
284 |
285 | t.Run("UPDATE with JOIN, ORDER BY, LIMIT", func(t *testing.T) {
286 | t.Parallel()
287 | var tt TestTable
288 | tt.item = MySQL.
289 | Update(a).
290 | Join(a, a.ACTOR_ID.Eq(a.ACTOR_ID)).
291 | LeftJoin(a, a.ACTOR_ID.Eq(a.ACTOR_ID)).
292 | FullJoin(a, a.ACTOR_ID.Eq(a.ACTOR_ID)).
293 | CrossJoin(a).
294 | CustomJoin(",", a).
295 | JoinUsing(a, a.FIRST_NAME, a.LAST_NAME).
296 | Set(
297 | a.FIRST_NAME.SetString("bob"),
298 | a.LAST_NAME.SetString("the builder"),
299 | ).
300 | Where(a.ACTOR_ID.EqInt(1)).
301 | OrderBy(a.ACTOR_ID).
302 | Limit(5)
303 | tt.wantQuery = "UPDATE actor AS a" +
304 | " JOIN actor AS a ON a.actor_id = a.actor_id" +
305 | " LEFT JOIN actor AS a ON a.actor_id = a.actor_id" +
306 | " FULL JOIN actor AS a ON a.actor_id = a.actor_id" +
307 | " CROSS JOIN actor AS a" +
308 | " , actor AS a" +
309 | " JOIN actor AS a USING (first_name, last_name)" +
310 | " SET a.first_name = ?, a.last_name = ?" +
311 | " WHERE a.actor_id = ?" +
312 | " ORDER BY a.actor_id" +
313 | " LIMIT ?"
314 | tt.wantArgs = []any{"bob", "the builder", 1, 5}
315 | tt.assert(t)
316 | })
317 | }
318 |
319 | func TestSQLServerUpdateQuery(t *testing.T) {
320 | type ACTOR struct {
321 | TableStruct
322 | ACTOR_ID NumberField
323 | FIRST_NAME StringField
324 | LAST_NAME StringField
325 | LAST_UPDATE TimeField
326 | }
327 | a := New[ACTOR]("")
328 |
329 | t.Run("basic", func(t *testing.T) {
330 | t.Parallel()
331 | q1 := SQLServer.Update(a).SetDialect("lorem ipsum")
332 | if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" {
333 | t.Error(testutil.Callers(), diff)
334 | }
335 | q1 = q1.SetDialect(DialectSQLServer)
336 | fields := q1.GetFetchableFields()
337 | if len(fields) != 0 {
338 | t.Error(testutil.Callers(), "expected 0 fields but got %v", fields)
339 | }
340 | _, ok := q1.SetFetchableFields([]Field{a.LAST_NAME})
341 | if ok {
342 | t.Fatal(testutil.Callers(), "field should not have been set")
343 | }
344 | q1.ReturningFields = q1.ReturningFields[:0]
345 | _, ok = q1.SetFetchableFields([]Field{a.LAST_NAME})
346 | if ok {
347 | t.Fatal(testutil.Callers(), "field should not have been set")
348 | }
349 | })
350 |
351 | t.Run("Set", func(t *testing.T) {
352 | t.Parallel()
353 | var tt TestTable
354 | tt.item = SQLServer.
355 | With(NewCTE("cte", nil, Queryf("SELECT 1"))).
356 | Update(a).
357 | Set(
358 | a.FIRST_NAME.SetString("bob"),
359 | a.LAST_NAME.SetString("the builder"),
360 | ).
361 | Where(a.ACTOR_ID.EqInt(1))
362 | tt.wantQuery = "WITH cte AS (SELECT 1)" +
363 | " UPDATE actor" +
364 | " SET first_name = @p1, last_name = @p2" +
365 | " WHERE actor.actor_id = @p3"
366 | tt.wantArgs = []any{"bob", "the builder", 1}
367 | tt.assert(t)
368 | })
369 |
370 | t.Run("SetFunc", func(t *testing.T) {
371 | t.Parallel()
372 | var tt TestTable
373 | tt.item = SQLServer.
374 | With(NewCTE("cte", nil, Queryf("SELECT 1"))).
375 | Update(a).
376 | SetFunc(func(col *Column) {
377 | col.SetString(a.FIRST_NAME, "bob")
378 | col.SetString(a.LAST_NAME, "the builder")
379 | }).
380 | Where(a.ACTOR_ID.EqInt(1))
381 | tt.wantQuery = "WITH cte AS (SELECT 1)" +
382 | " UPDATE actor" +
383 | " SET first_name = @p1, last_name = @p2" +
384 | " WHERE actor.actor_id = @p3"
385 | tt.wantArgs = []any{"bob", "the builder", 1}
386 | tt.assert(t)
387 | })
388 |
389 | t.Run("UPDATE with JOIN", func(t *testing.T) {
390 | t.Parallel()
391 | var tt TestTable
392 | tt.item = SQLServer.
393 | Update(a).
394 | Set(
395 | a.FIRST_NAME.SetString("bob"),
396 | a.LAST_NAME.SetString("the builder"),
397 | ).
398 | From(a).
399 | Join(a, a.ACTOR_ID.Eq(a.ACTOR_ID)).
400 | LeftJoin(a, a.ACTOR_ID.Eq(a.ACTOR_ID)).
401 | FullJoin(a, a.ACTOR_ID.Eq(a.ACTOR_ID)).
402 | CrossJoin(a).
403 | CustomJoin(",", a).
404 | Where(a.ACTOR_ID.EqInt(1))
405 | tt.wantQuery = "UPDATE actor" +
406 | " SET first_name = @p1, last_name = @p2" +
407 | " FROM actor" +
408 | " JOIN actor ON actor.actor_id = actor.actor_id" +
409 | " LEFT JOIN actor ON actor.actor_id = actor.actor_id" +
410 | " FULL JOIN actor ON actor.actor_id = actor.actor_id" +
411 | " CROSS JOIN actor" +
412 | " , actor" +
413 | " WHERE actor.actor_id = @p3"
414 | tt.wantArgs = []any{"bob", "the builder", 1}
415 | tt.assert(t)
416 | })
417 | }
418 |
// TestUpdateQuery builds UpdateQuery values directly as struct literals,
// rather than through a dialect builder. It covers dialect access, the policy
// of the update table, query shapes that must be rejected, and error
// propagation from each part of the query.
func TestUpdateQuery(t *testing.T) {
	t.Run("basic", func(t *testing.T) {
		t.Parallel()
		q1 := UpdateQuery{UpdateTable: Expr("tbl"), Dialect: "lorem ipsum"}
		if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
	})

	// f1, f2 and f3 are placeholder columns. colmapper assigns them 1, 2 and 3
	// and is shared by most of the cases below.
	f1, f2, f3 := Expr("f1"), Expr("f2"), Expr("f3")
	colmapper := func(col *Column) {
		col.Set(f1, 1)
		col.Set(f2, 2)
		col.Set(f3, 3)
	}

	// PolicyTable: the update table's policy predicate is wrapped in
	// parentheses and ANDed in front of the WHERE predicate.
	t.Run("PolicyTable", func(t *testing.T) {
		t.Parallel()
		var tt TestTable
		tt.item = UpdateQuery{
			UpdateTable:    policyTableStub{policy: And(Expr("1 = 1"), Expr("2 = 2"))},
			ColumnMapper:   colmapper,
			WherePredicate: Expr("3 = 3"),
		}
		tt.wantQuery = "UPDATE policy_table_stub SET f1 = ?, f2 = ?, f3 = ? WHERE (1 = 1 AND 2 = 2) AND 3 = 3"
		tt.wantArgs = []any{1, 2, 3}
		tt.assert(t)
	})

	// notOKTests are query shapes that must fail to build. The check itself is
	// TestTable.assertNotOK, which is defined elsewhere in the package.
	notOKTests := []TestTable{{
		description: "nil UpdateTable not allowed",
		item: UpdateQuery{
			UpdateTable:  nil,
			ColumnMapper: colmapper,
		},
	}, {
		description: "empty Assignments not allowed",
		item: UpdateQuery{
			UpdateTable: Expr("tbl"),
			Assignments: nil,
		},
	}, {
		description: "mysql does not support FROM",
		item: UpdateQuery{
			Dialect:      DialectMySQL,
			UpdateTable:  Expr("tbl"),
			FromTable:    Expr("tbl"),
			ColumnMapper: colmapper,
		},
	}, {
		description: "dialect does not allow JOIN without FROM",
		item: UpdateQuery{
			Dialect:     DialectPostgres,
			UpdateTable: Expr("tbl"),
			FromTable:   nil,
			JoinTables: []JoinTable{
				Join(Expr("tbl"), Expr("1 = 1")),
			},
			ColumnMapper: colmapper,
		},
	}, {
		description: "dialect does not support ORDER BY",
		item: UpdateQuery{
			Dialect:       DialectPostgres,
			UpdateTable:   Expr("tbl"),
			ColumnMapper:  colmapper,
			OrderByFields: Fields{f1},
		},
	}, {
		description: "dialect does not support LIMIT",
		item: UpdateQuery{
			Dialect:      DialectPostgres,
			UpdateTable:  Expr("tbl"),
			ColumnMapper: colmapper,
			LimitRows:    5,
		},
	}, {
		description: "dialect does not support RETURNING",
		item: UpdateQuery{
			Dialect:         DialectMySQL,
			UpdateTable:     Expr("tbl"),
			ColumnMapper:    colmapper,
			ReturningFields: Fields{f1, f2, f3},
		},
	}}

	for _, tt := range notOKTests {
		tt := tt // per-iteration copy for the parallel subtest (pre-Go 1.22 loop semantics)
		t.Run(tt.description, func(t *testing.T) {
			t.Parallel()
			tt.assertNotOK(t)
		})
	}

	// errTests inject ErrFaultySQL into one part of the query at a time:
	// through FaultySQL, a failing policyTableStub, or, for ColumnMapper, a
	// panic. TestTable.assertErr is defined elsewhere and presumably checks
	// that building the query fails with an error wrapping ErrFaultySQL.
	errTests := []TestTable{{
		description: "ColumnMapper err",
		item: UpdateQuery{
			UpdateTable:  Expr("tbl"),
			ColumnMapper: func(*Column) { panic(ErrFaultySQL) },
		},
	}, {
		description: "UpdateTable Policy err",
		item: UpdateQuery{
			UpdateTable:  policyTableStub{err: ErrFaultySQL},
			ColumnMapper: colmapper,
		},
	}, {
		description: "FromTable Policy err",
		item: UpdateQuery{
			UpdateTable:  Expr("tbl"),
			FromTable:    policyTableStub{err: ErrFaultySQL},
			ColumnMapper: colmapper,
		},
	}, {
		description: "JoinTables Policy err",
		item: UpdateQuery{
			UpdateTable:  Expr("tbl"),
			ColumnMapper: colmapper,
			FromTable:    Expr("tbl"),
			JoinTables: []JoinTable{
				Join(policyTableStub{err: ErrFaultySQL}, Expr("1 = 1")),
			},
		},
	}, {
		description: "CTEs err",
		item: UpdateQuery{
			CTEs:         []CTE{NewCTE("cte", nil, Queryf("SELECT {}", FaultySQL{}))},
			UpdateTable:  Expr("tbl"),
			ColumnMapper: colmapper,
		},
	}, {
		description: "UpdateTable err",
		item: UpdateQuery{
			UpdateTable:  FaultySQL{},
			ColumnMapper: colmapper,
		},
	}, {
		description: "not mysql Assignments err",
		item: UpdateQuery{
			Dialect:     DialectPostgres,
			UpdateTable: Expr("tbl"),
			Assignments: []Assignment{FaultySQL{}},
		},
	}, {
		description: "FromTable err",
		item: UpdateQuery{
			Dialect:      DialectPostgres,
			UpdateTable:  Expr("tbl"),
			ColumnMapper: colmapper,
			FromTable:    FaultySQL{},
		},
	}, {
		description: "JoinTables err",
		item: UpdateQuery{
			Dialect:      DialectPostgres,
			UpdateTable:  Expr("tbl"),
			ColumnMapper: colmapper,
			FromTable:    Expr("tbl"),
			JoinTables: []JoinTable{
				Join(FaultySQL{}, Expr("1 = 1")),
			},
		},
	}, {
		description: "mysql Assignments err",
		item: UpdateQuery{
			Dialect:     DialectMySQL,
			UpdateTable: Expr("tbl"),
			Assignments: []Assignment{FaultySQL{}},
		},
	}, {
		description: "WherePredicate Variadic err",
		item: UpdateQuery{
			UpdateTable:    Expr("tbl"),
			ColumnMapper:   colmapper,
			WherePredicate: And(FaultySQL{}),
		},
	}, {
		description: "WherePredicate err",
		item: UpdateQuery{
			UpdateTable:    Expr("tbl"),
			ColumnMapper:   colmapper,
			WherePredicate: FaultySQL{},
		},
	}, {
		description: "OrderByFields err",
		item: UpdateQuery{
			Dialect:       DialectMySQL,
			UpdateTable:   Expr("tbl"),
			ColumnMapper:  colmapper,
			OrderByFields: Fields{FaultySQL{}},
		},
	}, {
		description: "LimitRows err",
		item: UpdateQuery{
			Dialect:       DialectMySQL,
			UpdateTable:   Expr("tbl"),
			ColumnMapper:  colmapper,
			OrderByFields: Fields{f1},
			LimitRows:     FaultySQL{},
		},
	}, {
		description: "ReturningFields err",
		item: UpdateQuery{
			Dialect:         DialectPostgres,
			UpdateTable:     Expr("tbl"),
			ColumnMapper:    colmapper,
			ReturningFields: Fields{FaultySQL{}},
		},
	}}

	for _, tt := range errTests {
		tt := tt // per-iteration copy for the parallel subtest (pre-Go 1.22 loop semantics)
		t.Run(tt.description, func(t *testing.T) {
			t.Parallel()
			tt.assertErr(t, ErrFaultySQL)
		})
	}
}
637 |
--------------------------------------------------------------------------------
/window.go:
--------------------------------------------------------------------------------
1 | package sq
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "fmt"
7 | )
8 |
// NamedWindow represents an SQL named window. As SQL it renders only its Name,
// i.e. a reference to the window. The "name AS definition" form is written by
// NamedWindows.WriteSQL.
type NamedWindow struct {
	// Name is the window name.
	Name string
	// Definition is the window specification associated with Name.
	Definition Window
}

// Compile-time check that *NamedWindow satisfies Window.
var _ Window = (*NamedWindow)(nil)

// WriteSQL implements the SQLWriter interface. It writes w.Name and never
// returns an error.
func (w NamedWindow) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
	buf.WriteString(w.Name)
	return nil
}

// IsWindow implements the Window interface.
func (w NamedWindow) IsWindow() {}
25 |
// WindowDefinition represents an SQL window definition. WriteSQL renders it
// in parentheses as: [base window name] [PARTITION BY ...] [ORDER BY ...]
// [frame].
type WindowDefinition struct {
	// BaseWindowName, if non-empty, names an existing window that this
	// definition builds on (see BaseWindow).
	BaseWindowName string
	// PartitionByFields are written after PARTITION BY.
	PartitionByFields []Field
	// OrderByFields are written after ORDER BY.
	OrderByFields []Field
	// FrameSpec is the frame clause. It is written through Writef with
	// FrameValues as its arguments.
	FrameSpec string
	// FrameValues are the arguments used when writing FrameSpec.
	FrameValues []any
}

// Compile-time check that *WindowDefinition satisfies Window.
var _ Window = (*WindowDefinition)(nil)
36 |
37 | // BaseWindow creates a new WindowDefinition based off an existing NamedWindow.
38 | func BaseWindow(w NamedWindow) WindowDefinition {
39 | return WindowDefinition{BaseWindowName: w.Name}
40 | }
41 |
42 | // PartitionBy returns a new WindowDefinition with the PARTITION BY clause.
43 | func PartitionBy(fields ...Field) WindowDefinition {
44 | return WindowDefinition{PartitionByFields: fields}
45 | }
46 |
47 | // PartitionBy returns a new WindowDefinition with the ORDER BY clause.
48 | func OrderBy(fields ...Field) WindowDefinition {
49 | return WindowDefinition{OrderByFields: fields}
50 | }
51 |
52 | // WriteSQL implements the SQLWriter interface.
53 | func (w WindowDefinition) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
54 | var err error
55 | var written bool
56 | buf.WriteString("(")
57 | if w.BaseWindowName != "" {
58 | buf.WriteString(w.BaseWindowName + " ")
59 | }
60 | if len(w.PartitionByFields) > 0 {
61 | written = true
62 | buf.WriteString("PARTITION BY ")
63 | err = writeFields(ctx, dialect, buf, args, params, w.PartitionByFields, false)
64 | if err != nil {
65 | return fmt.Errorf("Window PARTITION BY: %w", err)
66 | }
67 | }
68 | if len(w.OrderByFields) > 0 {
69 | if written {
70 | buf.WriteString(" ")
71 | }
72 | written = true
73 | buf.WriteString("ORDER BY ")
74 | err = writeFields(ctx, dialect, buf, args, params, w.OrderByFields, false)
75 | if err != nil {
76 | return fmt.Errorf("Window ORDER BY: %w", err)
77 | }
78 | }
79 | if w.FrameSpec != "" {
80 | if written {
81 | buf.WriteString(" ")
82 | }
83 | written = true
84 | err = Writef(ctx, dialect, buf, args, params, w.FrameSpec, w.FrameValues)
85 | if err != nil {
86 | return fmt.Errorf("Window FRAME: %w", err)
87 | }
88 | }
89 | buf.WriteString(")")
90 | return nil
91 | }
92 |
93 | // PartitionBy returns a new WindowDefinition with the PARTITION BY clause.
94 | func (w WindowDefinition) PartitionBy(fields ...Field) WindowDefinition {
95 | w.PartitionByFields = fields
96 | return w
97 | }
98 |
99 | // OrderBy returns a new WindowDefinition with the ORDER BY clause.
100 | func (w WindowDefinition) OrderBy(fields ...Field) WindowDefinition {
101 | w.OrderByFields = fields
102 | return w
103 | }
104 |
105 | // Frame returns a new WindowDefinition with the frame specification set.
106 | func (w WindowDefinition) Frame(frameSpec string, frameValues ...any) WindowDefinition {
107 | w.FrameSpec = frameSpec
108 | w.FrameValues = frameValues
109 | return w
110 | }
111 |
112 | // IsWindow implements the Window interface.
113 | func (w WindowDefinition) IsWindow() {}
114 |
115 | // NamedWindows represents a slice of NamedWindows.
116 | type NamedWindows []NamedWindow
117 |
118 | // WriteSQL imeplements the SQLWriter interface.
119 | func (ws NamedWindows) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
120 | var err error
121 | for i, window := range ws {
122 | if i > 0 {
123 | buf.WriteString(", ")
124 | }
125 | buf.WriteString(window.Name + " AS ")
126 | err = window.Definition.WriteSQL(ctx, dialect, buf, args, params)
127 | if err != nil {
128 | return fmt.Errorf("window #%d: %w", i+1, err)
129 | }
130 | }
131 | return nil
132 | }
133 |
134 | // CountOver represents the COUNT() OVER () window function.
135 | func CountOver(field Field, window Window) Expression {
136 | if window == nil {
137 | return Expr("COUNT({}) OVER ()", field)
138 | }
139 | return Expr("COUNT({}) OVER {}", field, window)
140 | }
141 |
142 | // CountStarOver represents the COUNT(*) OVER () window function.
143 | func CountStarOver(window Window) Expression {
144 | if window == nil {
145 | return Expr("COUNT(*) OVER ()")
146 | }
147 | return Expr("COUNT(*) OVER {}", window)
148 | }
149 |
150 | // SumOver represents the SUM() OVER () window function.
151 | func SumOver(num Number, window Window) Expression {
152 | if window == nil {
153 | return Expr("SUM({}) OVER ()", num)
154 | }
155 | return Expr("SUM({}) OVER {}", num, window)
156 | }
157 |
158 | // AvgOver represents the AVG() OVER () window function.
159 | func AvgOver(num Number, window Window) Expression {
160 | if window == nil {
161 | return Expr("AVG({}) OVER ()", num)
162 | }
163 | return Expr("AVG({}) OVER {}", num, window)
164 | }
165 |
166 | // MinOver represents the MIN() OVER () window function.
167 | func MinOver(field Field, window Window) Expression {
168 | if window == nil {
169 | return Expr("MIN({}) OVER ()", field)
170 | }
171 | return Expr("MIN({}) OVER {}", field, window)
172 | }
173 |
174 | // MaxOver represents the MAX() OVER () window function.
175 | func MaxOver(field Field, window Window) Expression {
176 | if window == nil {
177 | return Expr("MAX({}) OVER ()", field)
178 | }
179 | return Expr("MAX({}) OVER {}", field, window)
180 | }
181 |
182 | // RowNumberOver represents the ROW_NUMBER() OVER () window function.
183 | func RowNumberOver(window Window) Expression {
184 | if window == nil {
185 | return Expr("ROW_NUMBER() OVER ()")
186 | }
187 | return Expr("ROW_NUMBER() OVER {}", window)
188 | }
189 |
190 | // RankOver represents the RANK() OVER () window function.
191 | func RankOver(window Window) Expression {
192 | if window == nil {
193 | return Expr("RANK() OVER ()")
194 | }
195 | return Expr("RANK() OVER {}", window)
196 | }
197 |
198 | // DenseRankOver represents the DENSE_RANK() OVER () window function.
199 | func DenseRankOver(window Window) Expression {
200 | if window == nil {
201 | return Expr("DENSE_RANK() OVER ()")
202 | }
203 | return Expr("DENSE_RANK() OVER {}", window)
204 | }
205 |
206 | // CumeDistOver represents the CUME_DIST() OVER () window function.
207 | func CumeDistOver(window Window) Expression {
208 | if window == nil {
209 | return Expr("CUME_DIST() OVER ()")
210 | }
211 | return Expr("CUME_DIST() OVER {}", window)
212 | }
213 |
214 | // FirstValueOver represents the FIRST_VALUE() OVER () window function.
215 | func FirstValueOver(field Field, window Window) Expression {
216 | if window == nil {
217 | return Expr("FIRST_VALUE({}) OVER ()", field)
218 | }
219 | return Expr("FIRST_VALUE({}) OVER {}", field, window)
220 | }
221 |
222 | // LastValueOver represents the LAST_VALUE() OVER () window
223 | // function.
224 | func LastValueOver(field Field, window Window) Expression {
225 | if window == nil {
226 | return Expr("LAST_VALUE({}) OVER ()", field)
227 | }
228 | return Expr("LAST_VALUE({}) OVER {}", field, window)
229 | }
230 |
--------------------------------------------------------------------------------
/window_test.go:
--------------------------------------------------------------------------------
1 | package sq
2 |
3 | import "testing"
4 |
5 | func TestWindow(t *testing.T) {
6 | t.Run("basic", func(t *testing.T) {
7 | t.Parallel()
8 | f1, f2, f3 := Expr("f1"), Expr("f2"), Expr("f3")
9 | TestTable{
10 | item: PartitionBy(f1).OrderBy(f2, f3).Frame("RANGE UNBOUNDED PRECEDING"),
11 | wantQuery: "(PARTITION BY f1 ORDER BY f2, f3 RANGE UNBOUNDED PRECEDING)",
12 | }.assert(t)
13 | TestTable{
14 | item: OrderBy(f1).PartitionBy(f2, f3).Frame("ROWS {} PRECEDING", 5),
15 | wantQuery: "(PARTITION BY f2, f3 ORDER BY f1 ROWS ? PRECEDING)",
16 | wantArgs: []any{5},
17 | }.assert(t)
18 | })
19 |
20 | errTests := []TestTable{{
21 | description: "PartitionBy err", item: PartitionBy(FaultySQL{}),
22 | }, {
23 | description: "OrderBy err", item: OrderBy(FaultySQL{}),
24 | }, {
25 | description: "Frame err", item: OrderBy(Expr("f")).Frame("ROWS {} PRECEDING", FaultySQL{}),
26 | }, {
27 | description: "NamedWindows err", item: NamedWindows{{
28 | Name: "w",
29 | Definition: OrderBy(Expr("f")).Frame("ROWS {} PRECEDING", FaultySQL{}),
30 | }},
31 | }}
32 |
33 | for _, tt := range errTests {
34 | tt := tt
35 | t.Run(tt.description, func(t *testing.T) {
36 | t.Parallel()
37 | tt.assertErr(t, ErrFaultySQL)
38 | })
39 | }
40 |
41 | funcTests := []TestTable{{
42 | description: "CountOver", item: CountOver(Expr("f1"), WindowDefinition{}),
43 | wantQuery: "COUNT(f1) OVER ()",
44 | }, {
45 | description: "CountOver nil", item: CountOver(Expr("f1"), nil),
46 | wantQuery: "COUNT(f1) OVER ()",
47 | }, {
48 | description: "CountStarOver", item: CountStarOver(WindowDefinition{}),
49 | wantQuery: "COUNT(*) OVER ()",
50 | }, {
51 | description: "SumOver", item: SumOver(Expr("f1"), PartitionBy(Expr("f2"))),
52 | wantQuery: "SUM(f1) OVER (PARTITION BY f2)",
53 | }, {
54 | description: "AvgOver", item: AvgOver(Expr("f1"), PartitionBy(Expr("f2"))),
55 | wantQuery: "AVG(f1) OVER (PARTITION BY f2)",
56 | }, {
57 | description: "MinOver", item: MinOver(Expr("f1"), PartitionBy(Expr("f2"))),
58 | wantQuery: "MIN(f1) OVER (PARTITION BY f2)",
59 | }, {
60 | description: "MaxOver", item: MaxOver(Expr("f1"), PartitionBy(Expr("f2"))),
61 | wantQuery: "MAX(f1) OVER (PARTITION BY f2)",
62 | }, {
63 | description: "RowNumberOver", item: RowNumberOver(PartitionBy(Expr("f1"))),
64 | wantQuery: "ROW_NUMBER() OVER (PARTITION BY f1)",
65 | }, {
66 | description: "RankOver", item: RankOver(PartitionBy(Expr("f1"))),
67 | wantQuery: "RANK() OVER (PARTITION BY f1)",
68 | }, {
69 | description: "DenseRankOver", item: DenseRankOver(PartitionBy(Expr("f1"))),
70 | wantQuery: "DENSE_RANK() OVER (PARTITION BY f1)",
71 | }, {
72 | description: "CumeDistOver", item: CumeDistOver(PartitionBy(Expr("f1"))),
73 | wantQuery: "CUME_DIST() OVER (PARTITION BY f1)",
74 | }, {
75 | description: "FirstValueOver", item: FirstValueOver(Expr("f1"), PartitionBy(Expr("f2"))),
76 | wantQuery: "FIRST_VALUE(f1) OVER (PARTITION BY f2)",
77 | }, {
78 | description: "LastValueOver", item: LastValueOver(Expr("f1"), PartitionBy(Expr("f2"))),
79 | wantQuery: "LAST_VALUE(f1) OVER (PARTITION BY f2)",
80 | }, {
81 | description: "NamedWindow", item: CountStarOver(NamedWindow{Name: "w"}),
82 | wantQuery: "COUNT(*) OVER w",
83 | }, func() TestTable {
84 | var tt TestTable
85 | tt.description = "BaseWindow"
86 | w := NamedWindow{Name: "w", Definition: PartitionBy(Expr("f1"))}
87 | tt.item = CountStarOver(BaseWindow(w).Frame("ROWS UNBOUNDED PRECEDING"))
88 | tt.wantQuery = "COUNT(*) OVER (w ROWS UNBOUNDED PRECEDING)"
89 | return tt
90 | }(), func() TestTable {
91 | var tt TestTable
92 | tt.description = "NamedWindows"
93 | w1 := NamedWindow{Name: "w1", Definition: PartitionBy(Expr("f1"))}
94 | w2 := NamedWindow{Name: "w2", Definition: OrderBy(Expr("f2"))}
95 | w3 := NamedWindow{Name: "w3", Definition: OrderBy(Expr("f3")).Frame("ROWS UNBOUNDED PRECEDING")}
96 | tt.item = NamedWindows{w1, w2, w3}
97 | tt.wantQuery = "w1 AS (PARTITION BY f1)" +
98 | ", w2 AS (ORDER BY f2)" +
99 | ", w3 AS (ORDER BY f3 ROWS UNBOUNDED PRECEDING)"
100 | return tt
101 | }()}
102 |
103 | for _, tt := range funcTests {
104 | tt := tt
105 | t.Run(tt.description, func(t *testing.T) {
106 | t.Parallel()
107 | tt.assert(t)
108 | })
109 | }
110 | }
111 |
--------------------------------------------------------------------------------