├── .github ├── mddocs └── workflows │ ├── neocities.yml │ └── tests.yml ├── .gitignore ├── LICENSE ├── README.md ├── START_HERE.md ├── builtins.go ├── builtins_test.go ├── colors.go ├── cte.go ├── cte_test.go ├── delete_query.go ├── delete_query_test.go ├── fetch_exec.go ├── fetch_exec_test.go ├── fields.go ├── fields_test.go ├── fmt.go ├── fmt_test.go ├── go.mod ├── go.sum ├── header.png ├── insert_query.go ├── insert_query_test.go ├── integration_test.go ├── internal ├── googleuuid │ └── googleuuid.go ├── pqarray │ └── pqarray.go └── testutil │ └── testutil.go ├── joins.go ├── joins_test.go ├── logger.go ├── logger_test.go ├── misc.go ├── misc_test.go ├── row_column.go ├── select_query.go ├── select_query_test.go ├── sq.go ├── sq.md ├── sq_test.go ├── update_query.go ├── update_query_test.go ├── window.go └── window_test.go /.github/mddocs: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bokwoon95/sq/eae3b0c03361f5b98ac6d0701c6aa71c94d4e4c2/.github/mddocs -------------------------------------------------------------------------------- /.github/workflows/neocities.yml: -------------------------------------------------------------------------------- 1 | name: Deploy docs to Neocities 2 | on: 3 | push: 4 | branches: [main] 5 | jobs: 6 | deploy_to_neocities: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - name: Clone repo 10 | uses: actions/checkout@v3 11 | - run: mkdir public && .github/mddocs sq.md public/sq.html 12 | - name: Deploy to neocities 13 | uses: bcomnes/deploy-to-neocities@v1 14 | with: 15 | api_token: ${{ secrets.NEOCITIES_API_KEY }} 16 | dist_dir: public 17 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | on: 3 | push: 4 | branches: [main] 5 | pull_request: 6 | branches: [main] 7 | jobs: 8 | run_sq_tests: 9 | runs-on: 
ubuntu-latest 10 | services: 11 | postgres: 12 | image: postgres 13 | env: 14 | POSTGRES_USER: 'user1' 15 | POSTGRES_PASSWORD: 'Hunter2!' 16 | POSTGRES_DB: 'sakila' 17 | options: >- 18 | --health-cmd pg_isready 19 | --health-interval 10s 20 | --health-timeout 5s 21 | --health-retries 5 22 | ports: 23 | - '5456:5432' 24 | mysql: 25 | image: mysql 26 | env: 27 | MYSQL_ROOT_PASSWORD: 'Hunter2!' 28 | MYSQL_USER: 'user1' 29 | MYSQL_PASSWORD: 'Hunter2!' 30 | MYSQL_DATABASE: 'sakila' 31 | options: >- 32 | --health-cmd "mysqladmin ping" 33 | --health-interval 10s 34 | --health-timeout 5s 35 | --health-retries 5 36 | --health-start-period 30s 37 | ports: 38 | - '3330:3306' 39 | sqlserver: 40 | image: 'mcr.microsoft.com/azure-sql-edge' 41 | env: 42 | ACCEPT_EULA: 'Y' 43 | MSSQL_SA_PASSWORD: 'Hunter2!' 44 | options: >- 45 | --health-cmd "/opt/mssql-tools/bin/sqlcmd -S localhost -U sa -P Hunter2! -Q 'select 1' -b -o /dev/null" 46 | --health-interval 10s 47 | --health-timeout 5s 48 | --health-retries 5 49 | --health-start-period 30s 50 | ports: 51 | - '1447:1433' 52 | steps: 53 | - name: Install go 54 | uses: actions/setup-go@v3 55 | with: 56 | go-version: '>=1.18.0' 57 | - name: Clone repo 58 | uses: actions/checkout@v3 59 | - run: go test . 
-tags=fts5 -failfast -shuffle on -coverprofile coverage -race -postgres 'postgres://user1:Hunter2!@localhost:5456/sakila?sslmode=disable' -mysql 'root:Hunter2!@tcp(localhost:3330)/sakila?multiStatements=true&parseTime=true' -sqlserver 'sqlserver://sa:Hunter2!@localhost:1447' 60 | - name: Convert coverage to coverage.lcov 61 | uses: jandelgado/gcov2lcov-action@v1.0.0 62 | with: 63 | infile: coverage 64 | outfile: coverage.lcov 65 | - name: Upload coverage.lcov to Coveralls 66 | uses: coverallsapp/github-action@master 67 | with: 68 | github-token: ${{ secrets.GITHUB_TOKEN }} 69 | path-to-lcov: coverage.lcov 70 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.sqlite* 2 | .idea 3 | coverage.out 4 | coverage 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Chua Bok Woon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![GoDoc](https://img.shields.io/badge/pkg.go.dev-sq-blue)](https://pkg.go.dev/github.com/bokwoon95/sq) 2 | ![tests](https://github.com/bokwoon95/sq/actions/workflows/tests.yml/badge.svg?branch=main) 3 | [![Go Report Card](https://goreportcard.com/badge/github.com/bokwoon95/sq)](https://goreportcard.com/report/github.com/bokwoon95/sq) 4 | [![Coverage Status](https://coveralls.io/repos/github/bokwoon95/sq/badge.svg?branch=main)](https://coveralls.io/github/bokwoon95/sq?branch=main) 5 | 6 | code example of a select query using sq, to give viewers a quick idea of what the library is about 7 | 8 | # sq (Structured Query) 9 | 10 | [one-page documentation](https://bokwoon.neocities.org/sq.html) 11 | 12 | sq is a type-safe data mapper and query builder for Go. Its concept is simple: you provide a callback function that maps a row to a struct, generics ensure that you get back a slice of structs at the end. Additionally, mentioning a column in the callback function automatically adds it to the SELECT clause so you don't even have to explicitly mention what columns you want to select: the [act of mapping a column is the same as selecting it](#select-example-raw-sql). This eliminates a source of errors where you have to specify the columns twice (once in the query itself, once in the call to rows.Scan) and end up missing a column, getting the column order wrong or mistyping a column name. 13 | 14 | Notable features: 15 | 16 | - Works across SQLite, Postgres, MySQL and SQL Server. 
[[more info](https://bokwoon.neocities.org/sq.html#set-query-dialect)] 17 | - Each dialect has its own query builder, allowing you to use dialect-specific features. [[more info](https://bokwoon.neocities.org/sq.html#dialect-specific-features)] 18 | - Declarative schema migrations. [[more info](https://bokwoon.neocities.org/sq.html#declarative-schema)] 19 | - Supports arrays, enums, JSON and UUID. [[more info](https://bokwoon.neocities.org/sq.html#arrays-enums-json-uuid)] 20 | - Query logging. [[more info](https://bokwoon.neocities.org/sq.html#logging)] 21 | 22 | # Installation 23 | 24 | This package only supports Go 1.19 and above. 25 | 26 | ```shell 27 | $ go get github.com/bokwoon95/sq 28 | $ go install -tags=fts5 github.com/bokwoon95/sqddl@latest 29 | ``` 30 | 31 | # Features 32 | 33 | - IN 34 | - [In Slice](https://bokwoon.neocities.org/sq.html#in-slice) - `a IN (1, 2, 3)` 35 | - [In RowValues](https://bokwoon.neocities.org/sq.html#in-rowvalues) - `(a, b, c) IN ((1, 2, 3), (4, 5, 6), (7, 8, 9))` 36 | - [In Subquery](https://bokwoon.neocities.org/sq.html#in-subquery) - `(a, b) IN (SELECT a, b FROM tbl WHERE condition)` 37 | - CASE 38 | - [Predicate Case](https://bokwoon.neocities.org/sq.html#predicate-case) - `CASE WHEN a THEN b WHEN c THEN d ELSE e END` 39 | - [Simple case](https://bokwoon.neocities.org/sq.html#simple-case) - `CASE expr WHEN a THEN b WHEN c THEN d ELSE e END` 40 | - EXISTS 41 | - [Where Exists](https://bokwoon.neocities.org/sq.html#where-exists) 42 | - [Where Not Exists](https://bokwoon.neocities.org/sq.html#where-not-exists) 43 | - [Select Exists](https://bokwoon.neocities.org/sq.html#querybuilder-fetch-exists) 44 | - [Subqueries](https://bokwoon.neocities.org/sq.html#subqueries) 45 | - [WITH (Common Table Expressions)](https://bokwoon.neocities.org/sq.html#common-table-expressions) 46 | - [Aggregate functions](https://bokwoon.neocities.org/sq.html#aggregate-functions) 47 | - [Window 
functions](https://bokwoon.neocities.org/sq.html#window-functions) 48 | - [UNION, INTERSECT, EXCEPT](https://bokwoon.neocities.org/sq.html#union-intersect-except) 49 | - [INSERT from SELECT](https://bokwoon.neocities.org/sq.html#querybuilder-insert-from-select) 50 | - RETURNING 51 | - [SQLite RETURNING](https://bokwoon.neocities.org/sq.html#sqlite-returning) 52 | - [Postgres RETURNING](https://bokwoon.neocities.org/sq.html#postgres-returning) 53 | - LastInsertId 54 | - [SQLite LastInsertId](https://bokwoon.neocities.org/sq.html#sqlite-last-insert-id) 55 | - [MySQL LastInsertId](https://bokwoon.neocities.org/sq.html#mysql-last-insert-id) 56 | - Insert ignore duplicates 57 | - [SQLite Insert ignore duplicates](https://bokwoon.neocities.org/sq.html#sqlite-insert-ignore-duplicates) 58 | - [Postgres Insert ignore duplicates](https://bokwoon.neocities.org/sq.html#postgres-insert-ignore-duplicates) 59 | - [MySQL Insert ignore duplicates](https://bokwoon.neocities.org/sq.html#mysql-insert-ignore-duplicates) 60 | - [SQL Server Insert ignore duplicates](https://bokwoon.neocities.org/sq.html#sqlserver-insert-ignore-duplicates) 61 | - Upsert 62 | - [SQLite Upsert](https://bokwoon.neocities.org/sq.html#sqlite-upsert) 63 | - [Postgres Upsert](https://bokwoon.neocities.org/sq.html#postgres-upsert) 64 | - [MySQL Upsert](https://bokwoon.neocities.org/sq.html#mysql-upsert) 65 | - [SQL Server Upsert](https://bokwoon.neocities.org/sq.html#sqlserver-upsert) 66 | - Update with Join 67 | - [SQLite Update with Join](https://bokwoon.neocities.org/sq.html#sqlite-update-with-join) 68 | - [Postgres Update with Join](https://bokwoon.neocities.org/sq.html#postgres-update-with-join) 69 | - [MySQL Update with Join](https://bokwoon.neocities.org/sq.html#mysql-update-with-join) 70 | - [SQL Server Update with Join](https://bokwoon.neocities.org/sq.html#sqlserver-update-with-join) 71 | - Delete with Join 72 | - [SQLite Delete with Join](https://bokwoon.neocities.org/sq.html#sqlite-delete-with-join) 
73 | - [Postgres Delete with Join](https://bokwoon.neocities.org/sq.html#postgres-delete-with-join) 74 | - [MySQL Delete with Join](https://bokwoon.neocities.org/sq.html#mysql-delete-with-join) 75 | - [SQL Server Delete with Join](https://bokwoon.neocities.org/sq.html#sqlserver-delete-with-join) 76 | - Bulk Update 77 | - [SQLite Bulk Update](https://bokwoon.neocities.org/sq.html#sqlite-bulk-update) 78 | - [Postgres Bulk Update](https://bokwoon.neocities.org/sq.html#postgres-bulk-update) 79 | - [MySQL Bulk Update](https://bokwoon.neocities.org/sq.html#mysql-bulk-update) 80 | - [SQL Server Bulk Update](https://bokwoon.neocities.org/sq.html#sqlserver-bulk-update) 81 | 82 | ## SELECT example (Raw SQL) 83 | 84 | ```go 85 | db, err := sql.Open("postgres", "postgres://username:password@localhost:5432/sakila?sslmode=disable") 86 | 87 | actors, err := sq.FetchAll(db, sq. 88 | Queryf("SELECT {*} FROM actor AS a WHERE a.actor_id IN ({})", 89 | []int{1, 2, 3, 4, 5}, 90 | ). 91 | SetDialect(sq.DialectPostgres), 92 | func(row *sq.Row) Actor { 93 | return Actor{ 94 | ActorID: row.Int("a.actor_id"), 95 | FirstName: row.String("a.first_name"), 96 | LastName: row.String("a.last_name"), 97 | LastUpdate: row.Time("a.last_update"), 98 | } 99 | }, 100 | ) 101 | ``` 102 | 103 | ## SELECT example (Query Builder) 104 | 105 | To use the query builder, you must first [define your table structs](https://bokwoon.neocities.org/sq.html#table-structs). 106 | 107 | ```go 108 | type ACTOR struct { 109 | sq.TableStruct 110 | ACTOR_ID sq.NumberField 111 | FIRST_NAME sq.StringField 112 | LAST_NAME sq.StringField 113 | LAST_UPDATE sq.TimeField 114 | } 115 | 116 | db, err := sql.Open("postgres", "postgres://username:password@localhost:5432/sakila?sslmode=disable") 117 | 118 | a := sq.New[ACTOR]("a") 119 | actors, err := sq.FetchAll(db, sq. 120 | From(a). 121 | Where(a.ACTOR_ID.In([]int{1, 2, 3, 4, 5})). 
122 | SetDialect(sq.DialectPostgres), 123 | func(row *sq.Row) Actor { 124 | return Actor{ 125 | ActorID: row.IntField(a.ACTOR_ID), 126 | FirstName: row.StringField(a.FIRST_NAME), 127 | LastName: row.StringField(a.LAST_NAME), 128 | LastUpdate: row.TimeField(a.LAST_UPDATE), 129 | } 130 | }, 131 | ) 132 | ``` 133 | 134 | ## INSERT example (Raw SQL) 135 | 136 | ```go 137 | db, err := sql.Open("postgres", "postgres://username:password@localhost:5432/sakila?sslmode=disable") 138 | 139 | _, err = sq.Exec(db, sq. 140 | Queryf("INSERT INTO actor (actor_id, first_name, last_name) VALUES {}", sq.RowValues{ 141 | {18, "DAN", "TORN"}, 142 | {56, "DAN", "HARRIS"}, 143 | {166, "DAN", "STREEP"}, 144 | }). 145 | SetDialect(sq.DialectPostgres), 146 | ) 147 | ``` 148 | 149 | ## INSERT example (Query Builder) 150 | 151 | To use the query builder, you must first [define your table structs](https://bokwoon.neocities.org/sq.html#table-structs). 152 | 153 | ```go 154 | type ACTOR struct { 155 | sq.TableStruct 156 | ACTOR_ID sq.NumberField 157 | FIRST_NAME sq.StringField 158 | LAST_NAME sq.StringField 159 | LAST_UPDATE sq.TimeField 160 | } 161 | 162 | db, err := sql.Open("postgres", "postgres://username:password@localhost:5432/sakila?sslmode=disable") 163 | 164 | a := sq.New[ACTOR]("a") 165 | _, err = sq.Exec(db, sq. 166 | InsertInto(a). 167 | Columns(a.ACTOR_ID, a.FIRST_NAME, a.LAST_NAME). 168 | Values(18, "DAN", "TORN"). 169 | Values(56, "DAN", "HARRIS"). 170 | Values(166, "DAN", "STREEP"). 171 | SetDialect(sq.DialectPostgres), 172 | ) 173 | ``` 174 | 175 | For a more detailed overview, look at the [Quickstart](https://bokwoon.neocities.org/sq.html#quickstart). 176 | 177 | ## Project Status 178 | 179 | sq is done for my use case (hence it may seem inactive, but it's just complete). At this point I'm just waiting for people to ask questions or file feature requests under [discussions](https://github.com/bokwoon95/sq/discussions). 
180 | 181 | ## Contributing 182 | 183 | See [START\_HERE.md](https://github.com/bokwoon95/sq/blob/main/START_HERE.md). 184 | -------------------------------------------------------------------------------- /START_HERE.md: -------------------------------------------------------------------------------- 1 | This document describes how the codebase is organized. It is meant for people who are contributing to the codebase (or are just casually browsing). 2 | 3 | Files are written in such a way that **each successive file in the list below only depends on files that come before it**. This self-enforced restriction makes deep architectural changes trivial because you can essentially blow away the entire codebase and rewrite it from scratch file-by-file, complete with working tests every step of the way. Please adhere to this file order when submitting pull requests. 4 | 5 | - [**sq.go**](https://github.com/bokwoon95/sq/blob/main/sq.go) 6 | - Core interfaces: SQLWriter, DB, Query, Table, PolicyTable, Window, Field, Predicate, Assignment, Any, Array, Binary, Boolean, Enum, JSON, Number, String, UUID, Time, Enumeration, DialectValuer, 7 | - Data types: Result, TableStruct, ViewStruct. 8 | - Misc utility functions. 9 | - [**fmt.go**](https://github.com/bokwoon95/sq/blob/main/fmt.go) 10 | - Two important string building functions that everything else is built on: [Writef](https://pkg.go.dev/github.com/bokwoon95/sq#Writef) and [WriteValue](https://pkg.go.dev/github.com/bokwoon95/sq#WriteValue). 11 | - Data types: Parameter, BinaryParameter, BooleanParameter, NumberParameter, StringParameter, TimeParameter. 12 | - Utility functions: QuoteIdentifier, EscapeQuote, Sprintf, Sprint. 13 | - [**builtins.go**](https://github.com/bokwoon95/sq/blob/main/builtins.go) 14 | - Builtin data types that are built on top of Writef and WriteValue: Expression (Expr), CustomQuery (Queryf), VariadicPredicate, assignment, RowValue, RowValues, Fields. 
15 | - Builtin functions that are built on top of Writef and WriteValue: Eq, Ne, Lt, Le, Gt, Ge, Exists, NotExists, In. 16 | - [**fields.go**](https://github.com/bokwoon95/sq/blob/main/fields.go) 17 | - All of the field types: AnyField, ArrayField, BinaryField, BooleanField, EnumField, JSONField, NumberField, StringField, UUIDField, TimeField. 18 | - Data types: Identifier, Timestamp. 19 | - Functions: [New](https://pkg.go.dev/github.com/bokwoon95/sq#New), ArrayValue, EnumValue, JSONValue, UUIDValue. 20 | - [**cte.go**](https://github.com/bokwoon95/sq/blob/main/cte.go) 21 | - CTE represents an SQL common table expression (CTE). 22 | - UNION, INTERSECT, EXCEPT. 23 | - [**joins.go**](https://github.com/bokwoon95/sq/blob/main/joins.go) 24 | - The various SQL joins. 25 | - [**row_column.go**](https://github.com/bokwoon95/sq/blob/main/row_column.go) 26 | - Row and Column methods. 27 | - [**window.go**](https://github.com/bokwoon95/sq/blob/main/window.go) 28 | - SQL windows and window functions. 29 | - [**select_query.go**](https://github.com/bokwoon95/sq/blob/main/select_query.go) 30 | - SQL SELECT query builder. 31 | - [**insert_query.go**](https://github.com/bokwoon95/sq/blob/main/insert_query.go) 32 | - SQL INSERT query builder. 33 | - [**update_query.go**](https://github.com/bokwoon95/sq/blob/main/update_query.go) 34 | - SQL UPDATE query builder. 35 | - [**delete_query.go**](https://github.com/bokwoon95/sq/blob/main/delete_query.go) 36 | - SQL DELETE query builder. 37 | - [**logger.go**](https://github.com/bokwoon95/sq/blob/main/logger.go) 38 | - sq.Log and sq.VerboseLog. 39 | - [**fetch_exec.go**](https://github.com/bokwoon95/sq/blob/main/fetch_exec.go) 40 | - FetchCursor, FetchOne, FetchAll, Exec. 41 | - CompiledFetch, CompiledExec. 42 | - PreparedFetch, PreparedExec. 43 | - [**misc.go**](https://github.com/bokwoon95/sq/blob/main/misc.go) 44 | - Misc SQL constructs. 45 | - ValueExpression, LiteralValue, DialectExpression, CaseExpression, SimpleCaseExpression. 
46 | - SelectValues (`SELECT ... UNION ALL SELECT ... UNION ALL SELECT ...`) 47 | - TableValues (`VALUES (...), (...), (...)`). 48 | - [**integration_test.go**](https://github.com/bokwoon95/sq/blob/main/integration_test.go) 49 | - Tests that interact with a live database i.e. SQLite, Postgres, MySQL and SQL Server. 50 | 51 | ## Testing 52 | 53 | Add tests if you add code. 54 | 55 | To run tests, use: 56 | 57 | ```shell 58 | $ go test . # -failfast -shuffle=on -coverprofile=coverage 59 | ``` 60 | 61 | There are tests that require a live database connection. They will only run if you provide the corresponding database URL in the test flags: 62 | 63 | ```shell 64 | $ go test . -postgres $POSTGRES_URL -mysql $MYSQL_URL -sqlserver $SQLSERVER_URL # -failfast -shuffle=on -coverprofile=coverage 65 | ``` 66 | 67 | You can consider using the [docker-compose.yml defined in the sqddl repo](https://github.com/bokwoon95/sqddl/blob/main/docker-compose.yml) to spin up Postgres, MySQL and SQL Server databases that are reachable at the following URLs: 68 | 69 | ```shell 70 | # docker-compose up -d 71 | POSTGRES_URL='postgres://user1:Hunter2!@localhost:5456/sakila?sslmode=disable' 72 | MYSQL_URL='root:Hunter2!@tcp(localhost:3330)/sakila?multiStatements=true&parseTime=true' 73 | MARIADB_URL='root:Hunter2!@tcp(localhost:3340)/sakila?multiStatements=true&parseTime=true' 74 | SQLSERVER_URL='sqlserver://sa:Hunter2!@localhost:1447' 75 | ``` 76 | 77 | ## Documentation 78 | 79 | Documentation is contained entirely within [sq.md](https://github.com/bokwoon95/sq/blob/main/sq.md) in the project root directory. You can view the output at [https://bokwoon.neocities.org/sq.html](https://bokwoon.neocities.org/sq.html). The documentation is regenerated every time a new commit is pushed to the main branch, so to change the documentation just change sq.md and submit a pull request. 
80 | 81 | You can preview the output of sq.md locally by installing [github.com/bokwoon95/mddocs](https://github.com/bokwoon95/mddocs) and running it with sq.md as the argument. 82 | 83 | ```shell 84 | $ go install github.com/bokwoon95/mddocs@latest 85 | $ mddocs 86 | Usage: 87 | mddocs project.md # serves project.md on a localhost connection 88 | mddocs project.md project.html # render project.md into project.html 89 | 90 | $ mddocs sq.md 91 | serving sq.md at localhost:6060 92 | ``` 93 | 94 | To add a new section and register it in the table of contents, append a `#headerID` to the end of a header (replace `headerID` with the actual header ID). The header ID should only contain unicode letters, digits, hyphen `-` and underscore `_`. 95 | 96 | ```text 97 | ## This is a header. 98 | 99 | ## This is a header with a headerID. #header-id <-- added to table of contents 100 | ``` 101 | -------------------------------------------------------------------------------- /builtins.go: -------------------------------------------------------------------------------- 1 | package sq 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | "strings" 8 | ) 9 | 10 | // Expression is an SQL expression that satisfies the Table, Field, Predicate, 11 | // Binary, Boolean, Number, String and Time interfaces. 12 | type Expression struct { 13 | format string 14 | values []any 15 | alias string 16 | } 17 | 18 | var _ interface { 19 | Table 20 | Field 21 | Predicate 22 | Any 23 | Assignment 24 | } = (*Expression)(nil) 25 | 26 | // Expr creates a new Expression using Writef syntax. 27 | func Expr(format string, values ...any) Expression { 28 | return Expression{format: format, values: values} 29 | } 30 | 31 | // WriteSQL implements the SQLWriter interface. 
32 | func (expr Expression) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { 33 | err := Writef(ctx, dialect, buf, args, params, expr.format, expr.values) 34 | if err != nil { 35 | return err 36 | } 37 | return nil 38 | } 39 | 40 | // As returns a new Expression with the given alias. 41 | func (expr Expression) As(alias string) Expression { 42 | expr.alias = alias 43 | return expr 44 | } 45 | 46 | // In returns an 'expr IN (value)' Predicate. 47 | func (expr Expression) In(value any) Predicate { return In(expr, value) } 48 | 49 | // In returns an 'expr NOT IN (value)' Predicate. 50 | func (expr Expression) NotIn(value any) Predicate { return NotIn(expr, value) } 51 | 52 | // Eq returns an 'expr = value' Predicate. 53 | func (expr Expression) Eq(value any) Predicate { return cmp("=", expr, value) } 54 | 55 | // Ne returns an 'expr <> value' Predicate. 56 | func (expr Expression) Ne(value any) Predicate { return cmp("<>", expr, value) } 57 | 58 | // Lt returns an 'expr < value' Predicate. 59 | func (expr Expression) Lt(value any) Predicate { return cmp("<", expr, value) } 60 | 61 | // Le returns an 'expr <= value' Predicate. 62 | func (expr Expression) Le(value any) Predicate { return cmp("<=", expr, value) } 63 | 64 | // Gt returns an 'expr > value' Predicate. 65 | func (expr Expression) Gt(value any) Predicate { return cmp(">", expr, value) } 66 | 67 | // Ge returns an 'expr >= value' Predicate. 68 | func (expr Expression) Ge(value any) Predicate { return cmp(">=", expr, value) } 69 | 70 | // GetAlias returns the alias of the Expression. 71 | func (expr Expression) GetAlias() string { return expr.alias } 72 | 73 | // IsTable implements the Table interface. 74 | func (expr Expression) IsTable() {} 75 | 76 | // IsField implements the Field interface. 77 | func (expr Expression) IsField() {} 78 | 79 | // IsArray implements the Array interface. 
80 | func (expr Expression) IsArray() {} 81 | 82 | // IsBinary implements the Binary interface. 83 | func (expr Expression) IsBinary() {} 84 | 85 | // IsBoolean implements the Boolean interface. 86 | func (expr Expression) IsBoolean() {} 87 | 88 | // IsEnum implements the Enum interface. 89 | func (expr Expression) IsEnum() {} 90 | 91 | // IsJSON implements the JSON interface. 92 | func (expr Expression) IsJSON() {} 93 | 94 | // IsNumber implements the Number interface. 95 | func (expr Expression) IsNumber() {} 96 | 97 | // IsString implements the String interface. 98 | func (expr Expression) IsString() {} 99 | 100 | // IsTime implements the Time interface. 101 | func (expr Expression) IsTime() {} 102 | 103 | // IsUUID implements the UUID interface. 104 | func (expr Expression) IsUUID() {} 105 | 106 | func (e Expression) IsAssignment() {} 107 | 108 | // CustomQuery represents a user-defined query. 109 | type CustomQuery struct { 110 | Dialect string 111 | Format string 112 | Values []any 113 | fields []Field 114 | } 115 | 116 | var _ Query = (*CustomQuery)(nil) 117 | 118 | // Queryf creates a new query using Writef syntax. 119 | func Queryf(format string, values ...any) CustomQuery { 120 | return CustomQuery{Format: format, Values: values} 121 | } 122 | 123 | // Queryf creates a new SQLite query using Writef syntax. 124 | func (b sqliteQueryBuilder) Queryf(format string, values ...any) CustomQuery { 125 | return CustomQuery{Dialect: DialectSQLite, Format: format, Values: values} 126 | } 127 | 128 | // Queryf creates a new Postgres query using Writef syntax. 129 | func (b postgresQueryBuilder) Queryf(format string, values ...any) CustomQuery { 130 | return CustomQuery{Dialect: DialectPostgres, Format: format, Values: values} 131 | } 132 | 133 | // Queryf creates a new MySQL query using Writef syntax. 
134 | func (b mysqlQueryBuilder) Queryf(format string, values ...any) CustomQuery { 135 | return CustomQuery{Dialect: DialectMySQL, Format: format, Values: values} 136 | } 137 | 138 | // Queryf creates a new SQL Server query using Writef syntax. 139 | func (b sqlserverQueryBuilder) Queryf(format string, values ...any) CustomQuery { 140 | return CustomQuery{Dialect: DialectSQLServer, Format: format, Values: values} 141 | } 142 | 143 | // Append returns a new CustomQuery with the format string and values slice 144 | // appended to the current CustomQuery. 145 | func (q CustomQuery) Append(format string, values ...any) CustomQuery { 146 | q.Format += " " + format 147 | q.Values = append(q.Values, values...) 148 | return q 149 | } 150 | 151 | // WriteSQL implements the SQLWriter interface. 152 | func (q CustomQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { 153 | var err error 154 | format := q.Format 155 | splitAt := -1 156 | for i := strings.IndexByte(format, '{'); i >= 0; i = strings.IndexByte(format, '{') { 157 | if i+2 <= len(format) && format[i:i+2] == "{{" { 158 | format = format[i+2:] 159 | continue 160 | } 161 | if i+3 <= len(format) && format[i:i+3] == "{*}" { 162 | splitAt = len(q.Format) - len(format[i:]) 163 | break 164 | } 165 | format = format[i+1:] 166 | } 167 | if splitAt < 0 { 168 | return Writef(ctx, dialect, buf, args, params, q.Format, q.Values) 169 | } 170 | runningValuesIndex := 0 171 | ordinalIndices := make(map[int]int) 172 | err = writef(ctx, dialect, buf, args, params, q.Format[:splitAt], q.Values, &runningValuesIndex, ordinalIndices) 173 | if err != nil { 174 | return err 175 | } 176 | err = writeFields(ctx, dialect, buf, args, params, q.fields, true) 177 | if err != nil { 178 | return err 179 | } 180 | err = writef(ctx, dialect, buf, args, params, q.Format[splitAt+3:], q.Values, &runningValuesIndex, ordinalIndices) 181 | if err != nil { 182 | return err 183 | } 184 | 
return nil 185 | } 186 | 187 | // SetFetchableFields sets the fetchable fields of the query. 188 | func (q CustomQuery) SetFetchableFields(fields []Field) (query Query, ok bool) { 189 | format := q.Format 190 | for i := strings.IndexByte(format, '{'); i >= 0; i = strings.IndexByte(format, '{') { 191 | if i+2 <= len(format) && format[i:i+2] == "{{" { 192 | format = format[i+2:] 193 | continue 194 | } 195 | if i+3 <= len(format) && format[i:i+3] == "{*}" { 196 | q.fields = fields 197 | return q, true 198 | } 199 | format = format[i+1:] 200 | } 201 | return q, false 202 | } 203 | 204 | // GetFetchableFields gets the fetchable fields of the query. 205 | func (q CustomQuery) GetFetchableFields() []Field { 206 | return q.fields 207 | } 208 | 209 | // GetDialect gets the dialect of the query. 210 | func (q CustomQuery) GetDialect() string { return q.Dialect } 211 | 212 | // SetDialect sets the dialect of the query. 213 | func (q CustomQuery) SetDialect(dialect string) CustomQuery { 214 | q.Dialect = dialect 215 | return q 216 | } 217 | 218 | // VariadicPredicate represents the 'x AND y AND z...' or 'x OR Y OR z...' SQL 219 | // construct. 220 | type VariadicPredicate struct { 221 | // Toplevel indicates if the VariadicPredicate can skip writing the 222 | // (surrounding brackets). 223 | Toplevel bool 224 | alias string 225 | // If IsDisjunction is true, the Predicates are joined using OR. If false, 226 | // the Predicates are joined using AND. The default is AND. 227 | IsDisjunction bool 228 | // Predicates holds the predicates inside the VariadicPredicate 229 | Predicates []Predicate 230 | } 231 | 232 | var _ Predicate = (*VariadicPredicate)(nil) 233 | 234 | // And joins the predicates together with the AND operator. 235 | func And(predicates ...Predicate) VariadicPredicate { 236 | return VariadicPredicate{IsDisjunction: false, Predicates: predicates} 237 | } 238 | 239 | // Or joins the predicates together with the OR operator. 
240 | func Or(predicates ...Predicate) VariadicPredicate { 241 | return VariadicPredicate{IsDisjunction: true, Predicates: predicates} 242 | } 243 | 244 | // WriteSQL implements the SQLWriter interface. 245 | func (p VariadicPredicate) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { 246 | var err error 247 | if len(p.Predicates) == 0 { 248 | return fmt.Errorf("VariadicPredicate empty") 249 | } 250 | 251 | if len(p.Predicates) == 1 { 252 | switch p1 := p.Predicates[0].(type) { 253 | case nil: 254 | return fmt.Errorf("predicate #1 is nil") 255 | case VariadicPredicate: 256 | p1.Toplevel = p.Toplevel 257 | err = p1.WriteSQL(ctx, dialect, buf, args, params) 258 | if err != nil { 259 | return err 260 | } 261 | default: 262 | err = p.Predicates[0].WriteSQL(ctx, dialect, buf, args, params) 263 | if err != nil { 264 | return err 265 | } 266 | } 267 | return nil 268 | } 269 | 270 | if !p.Toplevel { 271 | buf.WriteString("(") 272 | } 273 | for i, predicate := range p.Predicates { 274 | if i > 0 { 275 | if p.IsDisjunction { 276 | buf.WriteString(" OR ") 277 | } else { 278 | buf.WriteString(" AND ") 279 | } 280 | } 281 | switch predicate := predicate.(type) { 282 | case nil: 283 | return fmt.Errorf("predicate #%d is nil", i+1) 284 | case VariadicPredicate: 285 | predicate.Toplevel = false 286 | err = predicate.WriteSQL(ctx, dialect, buf, args, params) 287 | if err != nil { 288 | return fmt.Errorf("predicate #%d: %w", i+1, err) 289 | } 290 | default: 291 | err = predicate.WriteSQL(ctx, dialect, buf, args, params) 292 | if err != nil { 293 | return fmt.Errorf("predicate #%d: %w", i+1, err) 294 | } 295 | } 296 | } 297 | if !p.Toplevel { 298 | buf.WriteString(")") 299 | } 300 | return nil 301 | } 302 | 303 | // As returns a new VariadicPredicate with the given alias. 
304 | func (p VariadicPredicate) As(alias string) VariadicPredicate { 305 | p.alias = alias 306 | return p 307 | } 308 | 309 | // GetAlias returns the alias of the VariadicPredicate. 310 | func (p VariadicPredicate) GetAlias() string { return p.alias } 311 | 312 | // IsField implements the Field interface. 313 | func (p VariadicPredicate) IsField() {} 314 | 315 | // IsBooleanType implements the Predicate interface. 316 | func (p VariadicPredicate) IsBoolean() {} 317 | 318 | // assignment represents assigning a value to a Field. 319 | type assignment struct { 320 | field Field 321 | value any 322 | } 323 | 324 | var _ Assignment = (*assignment)(nil) 325 | 326 | // Set creates a new Assignment assigning the value to a field. 327 | func Set(field Field, value any) Assignment { 328 | return assignment{field: field, value: value} 329 | } 330 | 331 | // Setf creates a new Assignment assigning a custom expression to a Field. 332 | func Setf(field Field, format string, values ...any) Assignment { 333 | return assignment{field: field, value: Expr(format, values...)} 334 | } 335 | 336 | // WriteSQL implements the SQLWriter interface. 
337 | func (a assignment) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { 338 | if a.field == nil { 339 | return fmt.Errorf("field is nil") 340 | } 341 | var err error 342 | if dialect == DialectMySQL { 343 | err = a.field.WriteSQL(ctx, dialect, buf, args, params) 344 | if err != nil { 345 | return err 346 | } 347 | } else { 348 | err = withPrefix(a.field, "").WriteSQL(ctx, dialect, buf, args, params) 349 | if err != nil { 350 | return err 351 | } 352 | } 353 | buf.WriteString(" = ") 354 | _, isQuery := a.value.(Query) 355 | if isQuery { 356 | buf.WriteString("(") 357 | } 358 | err = WriteValue(ctx, dialect, buf, args, params, a.value) 359 | if err != nil { 360 | return err 361 | } 362 | if isQuery { 363 | buf.WriteString(")") 364 | } 365 | return nil 366 | } 367 | 368 | // IsAssignment implements the Assignment interface. 369 | func (a assignment) IsAssignment() {} 370 | 371 | // Assignments represents a list of Assignments e.g. x = 1, y = 2, z = 3. 372 | type Assignments []Assignment 373 | 374 | // WriteSQL implements the SQLWriter interface. 375 | func (as Assignments) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { 376 | var err error 377 | for i, assignment := range as { 378 | if assignment == nil { 379 | return fmt.Errorf("assignment #%d is nil", i+1) 380 | } 381 | if i > 0 { 382 | buf.WriteString(", ") 383 | } 384 | err = assignment.WriteSQL(ctx, dialect, buf, args, params) 385 | if err != nil { 386 | return fmt.Errorf("assignment #%d: %w", i+1, err) 387 | } 388 | } 389 | return nil 390 | } 391 | 392 | // RowValue represents an SQL row value expression e.g. (x, y, z). 393 | type RowValue []any 394 | 395 | // WriteSQL implements the SQLWriter interface. 
396 | func (r RowValue) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { 397 | buf.WriteString("(") 398 | var err error 399 | for i, value := range r { 400 | if i > 0 { 401 | buf.WriteString(", ") 402 | } 403 | _, isQuery := value.(Query) 404 | if isQuery { 405 | buf.WriteString("(") 406 | } 407 | err = WriteValue(ctx, dialect, buf, args, params, value) 408 | if err != nil { 409 | return fmt.Errorf("rowvalue #%d: %w", i+1, err) 410 | } 411 | if isQuery { 412 | buf.WriteString(")") 413 | } 414 | } 415 | buf.WriteString(")") 416 | return nil 417 | } 418 | 419 | // In returns an 'rowvalue IN (value)' Predicate. 420 | func (r RowValue) In(v any) Predicate { return In(r, v) } 421 | 422 | // NotIn returns an 'rowvalue NOT IN (value)' Predicate. 423 | func (r RowValue) NotIn(v any) Predicate { return NotIn(r, v) } 424 | 425 | // Eq returns an 'rowvalue = value' Predicate. 426 | func (r RowValue) Eq(v any) Predicate { return cmp("=", r, v) } 427 | 428 | // RowValues represents a list of RowValues e.g. (x, y, z), (a, b, c). 429 | type RowValues []RowValue 430 | 431 | // WriteSQL implements the SQLWriter interface. 432 | func (rs RowValues) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { 433 | var err error 434 | for i, r := range rs { 435 | if i > 0 { 436 | buf.WriteString(", ") 437 | } 438 | err = r.WriteSQL(ctx, dialect, buf, args, params) 439 | if err != nil { 440 | return fmt.Errorf("rowvalues #%d: %w", i+1, err) 441 | } 442 | } 443 | return nil 444 | } 445 | 446 | // Fields represents a list of Fields e.g. tbl.field1, tbl.field2, tbl.field3. 447 | type Fields []Field 448 | 449 | // WriteSQL implements the SQLWriter interface. 
450 | func (fs Fields) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { 451 | var err error 452 | for i, field := range fs { 453 | if field == nil { 454 | return fmt.Errorf("field #%d is nil", i+1) 455 | } 456 | if i > 0 { 457 | buf.WriteString(", ") 458 | } 459 | _, isQuery := field.(Query) 460 | if isQuery { 461 | buf.WriteString("(") 462 | } 463 | err = field.WriteSQL(ctx, dialect, buf, args, params) 464 | if err != nil { 465 | return fmt.Errorf("field #%d: %w", i+1, err) 466 | } 467 | if isQuery { 468 | buf.WriteString(")") 469 | } 470 | } 471 | return nil 472 | } 473 | 474 | type ( 475 | sqliteQueryBuilder struct{ ctes []CTE } 476 | postgresQueryBuilder struct{ ctes []CTE } 477 | mysqlQueryBuilder struct{ ctes []CTE } 478 | sqlserverQueryBuilder struct{ ctes []CTE } 479 | ) 480 | 481 | // Dialect-specific query builder variables. 482 | var ( 483 | SQLite sqliteQueryBuilder 484 | Postgres postgresQueryBuilder 485 | MySQL mysqlQueryBuilder 486 | SQLServer sqlserverQueryBuilder 487 | ) 488 | 489 | // With sets the CTEs in the SQLiteQueryBuilder. 490 | func (b sqliteQueryBuilder) With(ctes ...CTE) sqliteQueryBuilder { 491 | b.ctes = ctes 492 | return b 493 | } 494 | 495 | // With sets the CTEs in the PostgresQueryBuilder. 496 | func (b postgresQueryBuilder) With(ctes ...CTE) postgresQueryBuilder { 497 | b.ctes = ctes 498 | return b 499 | } 500 | 501 | // With sets the CTEs in the MySQLQueryBuilder. 502 | func (b mysqlQueryBuilder) With(ctes ...CTE) mysqlQueryBuilder { 503 | b.ctes = ctes 504 | return b 505 | } 506 | 507 | // With sets the CTEs in the SQLServerQueryBuilder. 508 | func (b sqlserverQueryBuilder) With(ctes ...CTE) sqlserverQueryBuilder { 509 | b.ctes = ctes 510 | return b 511 | } 512 | 513 | // ToSQL converts an SQLWriter into a query string and args slice. 
514 | // 515 | // The params map is used to hold the mappings between named parameters in the 516 | // query to the corresponding index in the args slice and is used for rebinding 517 | // args by their parameter name. If you don't need to track this, you can pass 518 | // in a nil map. 519 | func ToSQL(dialect string, w SQLWriter, params map[string][]int) (query string, args []any, err error) { 520 | return ToSQLContext(context.Background(), dialect, w, params) 521 | } 522 | 523 | // ToSQLContext is like ToSQL but additionally requires a context.Context. 524 | func ToSQLContext(ctx context.Context, dialect string, w SQLWriter, params map[string][]int) (query string, args []any, err error) { 525 | if w == nil { 526 | return "", nil, fmt.Errorf("SQLWriter is nil") 527 | } 528 | if dialect == "" { 529 | if q, ok := w.(Query); ok { 530 | dialect = q.GetDialect() 531 | } 532 | } 533 | buf := bufpool.Get().(*bytes.Buffer) 534 | buf.Reset() 535 | defer bufpool.Put(buf) 536 | err = w.WriteSQL(ctx, dialect, buf, &args, params) 537 | query = buf.String() 538 | if err != nil { 539 | return query, args, err 540 | } 541 | return query, args, nil 542 | } 543 | 544 | // Eq returns an 'x = y' Predicate. 545 | func Eq(x, y any) Predicate { return cmp("=", x, y) } 546 | 547 | // Ne returns an 'x <> y' Predicate. 548 | func Ne(x, y any) Predicate { return cmp("<>", x, y) } 549 | 550 | // Lt returns an 'x < y' Predicate. 551 | func Lt(x, y any) Predicate { return cmp("<", x, y) } 552 | 553 | // Le returns an 'x <= y' Predicate. 554 | func Le(x, y any) Predicate { return cmp("<=", x, y) } 555 | 556 | // Gt returns an 'x > y' Predicate. 557 | func Gt(x, y any) Predicate { return cmp(">", x, y) } 558 | 559 | // Ge returns an 'x >= y' Predicate. 560 | func Ge(x, y any) Predicate { return cmp(">=", x, y) } 561 | 562 | // Exists returns an 'EXISTS (query)' Predicate. 
563 | func Exists(query Query) Predicate { return Expr("EXISTS ({})", query) } 564 | 565 | // NotExists returns a 'NOT EXISTS (query)' Predicate. 566 | func NotExists(query Query) Predicate { return Expr("NOT EXISTS ({})", query) } 567 | 568 | // In returns an 'x IN (y)' Predicate. 569 | func In(x, y any) Predicate { 570 | _, isQueryA := x.(Query) 571 | _, isRowValueB := y.(RowValue) 572 | if !isQueryA && !isRowValueB { 573 | return Expr("{} IN ({})", x, y) 574 | } else if !isQueryA && isRowValueB { 575 | return Expr("{} IN {}", x, y) 576 | } else if isQueryA && !isRowValueB { 577 | return Expr("({}) IN ({})", x, y) 578 | } else { 579 | return Expr("({}) IN {}", x, y) 580 | } 581 | } 582 | 583 | // NotIn returns an 'x NOT IN (y)' Predicate. 584 | func NotIn(x, y any) Predicate { 585 | _, isQueryA := x.(Query) 586 | _, isRowValueB := y.(RowValue) 587 | if !isQueryA && !isRowValueB { 588 | return Expr("{} NOT IN ({})", x, y) 589 | } else if !isQueryA && isRowValueB { 590 | return Expr("{} NOT IN {}", x, y) 591 | } else if isQueryA && !isRowValueB { 592 | return Expr("({}) NOT IN ({})", x, y) 593 | } else { 594 | return Expr("({}) NOT IN {}", x, y) 595 | } 596 | } 597 | 598 | // cmp returns an 'x y' Predicate. 599 | func cmp(operator string, x, y any) Expression { 600 | _, isQueryA := x.(Query) 601 | _, isQueryB := y.(Query) 602 | if !isQueryA && !isQueryB { 603 | return Expr("{} "+operator+" {}", x, y) 604 | } else if !isQueryA && isQueryB { 605 | return Expr("{} "+operator+" ({})", x, y) 606 | } else if isQueryA && !isQueryB { 607 | return Expr("({}) "+operator+" {}", x, y) 608 | } else { 609 | return Expr("({}) "+operator+" ({})", x, y) 610 | } 611 | } 612 | 613 | // appendPolicy will append a policy from a Table (if it implements 614 | // PolicyTable) to a slice of policies. The resultant slice is returned. 
615 | func appendPolicy(ctx context.Context, dialect string, policies []Predicate, table Table) ([]Predicate, error) { 616 | policyTable, ok := table.(PolicyTable) 617 | if !ok { 618 | return policies, nil 619 | } 620 | policy, err := policyTable.Policy(ctx, dialect) 621 | if err != nil { 622 | return nil, err 623 | } 624 | if policy != nil { 625 | policies = append(policies, policy) 626 | } 627 | return policies, nil 628 | } 629 | 630 | // appendPredicates will append a slices of predicates into a predicate. 631 | func appendPredicates(predicate Predicate, predicates []Predicate) VariadicPredicate { 632 | if predicate == nil { 633 | return And(predicates...) 634 | } 635 | if p1, ok := predicate.(VariadicPredicate); ok && !p1.IsDisjunction { 636 | p1.Predicates = append(p1.Predicates, predicates...) 637 | return p1 638 | } 639 | p2 := VariadicPredicate{Predicates: make([]Predicate, 1+len(predicates))} 640 | p2.Predicates[0] = predicate 641 | copy(p2.Predicates[1:], predicates) 642 | return p2 643 | } 644 | 645 | func writeTop(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int, topLimit, topPercentLimit any, withTies bool) error { 646 | var err error 647 | if topLimit != nil { 648 | buf.WriteString("TOP (") 649 | err = WriteValue(ctx, dialect, buf, args, params, topLimit) 650 | if err != nil { 651 | return fmt.Errorf("TOP: %w", err) 652 | } 653 | buf.WriteString(") ") 654 | } else if topPercentLimit != nil { 655 | buf.WriteString("TOP (") 656 | err = WriteValue(ctx, dialect, buf, args, params, topPercentLimit) 657 | if err != nil { 658 | return fmt.Errorf("TOP PERCENT: %w", err) 659 | } 660 | buf.WriteString(") PERCENT ") 661 | } 662 | if (topLimit != nil || topPercentLimit != nil) && withTies { 663 | buf.WriteString("WITH TIES ") 664 | } 665 | return nil 666 | } 667 | -------------------------------------------------------------------------------- /colors.go: 
-------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package sq 4 | 5 | import ( 6 | "os" 7 | "syscall" 8 | ) 9 | 10 | func init() { 11 | // https://stackoverflow.com/a/69542231 12 | const ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x4 13 | var stderrMode uint32 14 | stderr := syscall.Handle(os.Stderr.Fd()) 15 | syscall.GetConsoleMode(stderr, &stderrMode) 16 | syscall.MustLoadDLL("kernel32").MustFindProc("SetConsoleMode").Call(uintptr(stderr), uintptr(stderrMode|ENABLE_VIRTUAL_TERMINAL_PROCESSING)) 17 | var stdoutMode uint32 18 | stdout := syscall.Handle(os.Stdout.Fd()) 19 | syscall.GetConsoleMode(stdout, &stdoutMode) 20 | syscall.MustLoadDLL("kernel32").MustFindProc("SetConsoleMode").Call(uintptr(stdout), uintptr(stdoutMode|ENABLE_VIRTUAL_TERMINAL_PROCESSING)) 21 | } 22 | -------------------------------------------------------------------------------- /cte.go: -------------------------------------------------------------------------------- 1 | package sq 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "database/sql" 7 | "fmt" 8 | ) 9 | 10 | // CTE represents an SQL common table expression (CTE). 11 | type CTE struct { 12 | query Query 13 | columns []string 14 | recursive bool 15 | materialized sql.NullBool 16 | name string 17 | alias string 18 | } 19 | 20 | var _ Table = (*CTE)(nil) 21 | 22 | // NewCTE creates a new CTE. 23 | func NewCTE(name string, columns []string, query Query) CTE { 24 | return CTE{name: name, columns: columns, query: query} 25 | } 26 | 27 | // NewRecursiveCTE creates a new recursive CTE. 28 | func NewRecursiveCTE(name string, columns []string, query Query) CTE { 29 | return CTE{name: name, columns: columns, query: query, recursive: true} 30 | } 31 | 32 | // WriteSQL implements the SQLWriter interface. 
33 | func (cte CTE) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { 34 | buf.WriteString(QuoteIdentifier(dialect, cte.name)) 35 | return nil 36 | } 37 | 38 | // As returns a new CTE with the given alias. 39 | func (cte CTE) As(alias string) CTE { 40 | cte.alias = alias 41 | return cte 42 | } 43 | 44 | // Materialized returns a new CTE marked as MATERIALIZED. This only works on 45 | // postgres. 46 | func (cte CTE) Materialized() CTE { 47 | cte.materialized.Valid = true 48 | cte.materialized.Bool = true 49 | return cte 50 | } 51 | 52 | // Materialized returns a new CTE marked as NOT MATERIALIZED. This only works 53 | // on postgres. 54 | func (cte CTE) NotMaterialized() CTE { 55 | cte.materialized.Valid = true 56 | cte.materialized.Bool = false 57 | return cte 58 | } 59 | 60 | // Field returns a Field from the CTE. 61 | func (cte CTE) Field(name string) AnyField { 62 | return NewAnyField(name, NewTableStruct("", cte.name, cte.alias)) 63 | } 64 | 65 | // GetAlias returns the alias of the CTE. 66 | func (cte CTE) GetAlias() string { return cte.alias } 67 | 68 | // AssertTable implements the Table interface. 
69 | func (cte CTE) IsTable() {} 70 | 71 | func writeCTEs(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int, ctes []CTE) error { 72 | var hasRecursiveCTE bool 73 | for _, cte := range ctes { 74 | if cte.recursive { 75 | hasRecursiveCTE = true 76 | break 77 | } 78 | } 79 | if hasRecursiveCTE { 80 | buf.WriteString("WITH RECURSIVE ") 81 | } else { 82 | buf.WriteString("WITH ") 83 | } 84 | for i, cte := range ctes { 85 | if i > 0 { 86 | buf.WriteString(", ") 87 | } 88 | if cte.name == "" { 89 | return fmt.Errorf("CTE #%d has no name", i+1) 90 | } 91 | buf.WriteString(QuoteIdentifier(dialect, cte.name)) 92 | if len(cte.columns) > 0 { 93 | buf.WriteString(" (") 94 | for j, column := range cte.columns { 95 | if j > 0 { 96 | buf.WriteString(", ") 97 | } 98 | buf.WriteString(QuoteIdentifier(dialect, column)) 99 | } 100 | buf.WriteString(")") 101 | } 102 | buf.WriteString(" AS ") 103 | if dialect == DialectPostgres && cte.materialized.Valid { 104 | if cte.materialized.Bool { 105 | buf.WriteString("MATERIALIZED ") 106 | } else { 107 | buf.WriteString("NOT MATERIALIZED ") 108 | } 109 | } 110 | buf.WriteString("(") 111 | switch query := cte.query.(type) { 112 | case nil: 113 | return fmt.Errorf("CTE #%d query is nil", i+1) 114 | case VariadicQuery: 115 | query.Toplevel = true 116 | err := query.WriteSQL(ctx, dialect, buf, args, params) 117 | if err != nil { 118 | return fmt.Errorf("CTE #%d failed to build query: %w", i+1, err) 119 | } 120 | default: 121 | err := query.WriteSQL(ctx, dialect, buf, args, params) 122 | if err != nil { 123 | return fmt.Errorf("CTE #%d failed to build query: %w", i+1, err) 124 | } 125 | } 126 | buf.WriteString(")") 127 | } 128 | buf.WriteString(" ") 129 | return nil 130 | } 131 | 132 | // VariadicQueryOperator represents a variadic query operator. 133 | type VariadicQueryOperator string 134 | 135 | // VariadicQuery operators. 
136 | const ( 137 | QueryUnion VariadicQueryOperator = "UNION" 138 | QueryUnionAll VariadicQueryOperator = "UNION ALL" 139 | QueryIntersect VariadicQueryOperator = "INTERSECT" 140 | QueryIntersectAll VariadicQueryOperator = "INTERSECT ALL" 141 | QueryExcept VariadicQueryOperator = "EXCEPT" 142 | QueryExceptAll VariadicQueryOperator = "EXCEPT ALL" 143 | ) 144 | 145 | // VariadicQuery represents the 'x UNION y UNION z...' etc SQL constructs. 146 | type VariadicQuery struct { 147 | Toplevel bool 148 | Operator VariadicQueryOperator 149 | Queries []Query 150 | } 151 | 152 | var _ Query = (*VariadicQuery)(nil) 153 | 154 | // Union joins the queries together with the UNION operator. 155 | func Union(queries ...Query) VariadicQuery { 156 | return VariadicQuery{Operator: QueryUnion, Queries: queries} 157 | } 158 | 159 | // UnionAll joins the queries together with the UNION ALL operator. 160 | func UnionAll(queries ...Query) VariadicQuery { 161 | return VariadicQuery{Operator: QueryUnionAll, Queries: queries} 162 | } 163 | 164 | // Intersect joins the queries together with the INTERSECT operator. 165 | func Intersect(queries ...Query) VariadicQuery { 166 | return VariadicQuery{Operator: QueryIntersect, Queries: queries} 167 | } 168 | 169 | // IntersectAll joins the queries together with the INTERSECT ALL operator. 170 | func IntersectAll(queries ...Query) VariadicQuery { 171 | return VariadicQuery{Operator: QueryIntersectAll, Queries: queries} 172 | } 173 | 174 | // Except joins the queries together with the EXCEPT operator. 175 | func Except(queries ...Query) VariadicQuery { 176 | return VariadicQuery{Operator: QueryExcept, Queries: queries} 177 | } 178 | 179 | // ExceptAll joins the queries together with the EXCEPT ALL operator. 180 | func ExceptAll(queries ...Query) VariadicQuery { 181 | return VariadicQuery{Operator: QueryExceptAll, Queries: queries} 182 | } 183 | 184 | // WriteSQL implements the SQLWriter interface. 
185 | func (q VariadicQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { 186 | var err error 187 | if q.Operator == "" { 188 | q.Operator = QueryUnion 189 | } 190 | if len(q.Queries) == 0 { 191 | return fmt.Errorf("VariadicQuery empty") 192 | } 193 | 194 | if len(q.Queries) == 1 { 195 | switch q1 := q.Queries[0].(type) { 196 | case nil: 197 | return fmt.Errorf("query #1 is nil") 198 | case VariadicQuery: 199 | q1.Toplevel = q.Toplevel 200 | err = q1.WriteSQL(ctx, dialect, buf, args, params) 201 | if err != nil { 202 | return err 203 | } 204 | default: 205 | err = q.Queries[0].WriteSQL(ctx, dialect, buf, args, params) 206 | if err != nil { 207 | return err 208 | } 209 | } 210 | return nil 211 | } 212 | 213 | if !q.Toplevel { 214 | buf.WriteString("(") 215 | } 216 | for i, query := range q.Queries { 217 | if i > 0 { 218 | buf.WriteString(" " + string(q.Operator) + " ") 219 | } 220 | if query == nil { 221 | return fmt.Errorf("query #%d is nil", i+1) 222 | } 223 | err = query.WriteSQL(ctx, dialect, buf, args, params) 224 | if err != nil { 225 | return fmt.Errorf("query #%d: %w", i+1, err) 226 | } 227 | } 228 | if !q.Toplevel { 229 | buf.WriteString(")") 230 | } 231 | return nil 232 | } 233 | 234 | // SetFetchableFields implements the Query interface. 235 | func (q VariadicQuery) SetFetchableFields(fields []Field) (query Query, ok bool) { 236 | return q, false 237 | } 238 | 239 | // GetFetchableFields implements the Query interface. 240 | func (q VariadicQuery) GetFetchableFields() []Field { 241 | return nil 242 | } 243 | 244 | // GetDialect returns the SQL dialect of the VariadicQuery. 
func (q VariadicQuery) GetDialect() string {
	// The dialect is taken from the first query; an empty list or a nil
	// first query yields the empty string.
	if len(q.Queries) == 0 {
		return ""
	}
	q1 := q.Queries[0]
	if q1 == nil {
		return ""
	}
	return q1.GetDialect()
}

--------------------------------------------------------------------------------
/cte_test.go:
--------------------------------------------------------------------------------
package sq

import (
	"bytes"
	"context"
	"database/sql"
	"errors"
	"testing"

	"github.com/bokwoon95/sq/internal/testutil"
)

// TestCTE checks the CTE value itself: what it writes as a table, the Field
// it produces, and that chained modifiers (Materialized, NotMaterialized, As)
// leave the last-applied settings in place.
func TestCTE(t *testing.T) {
	t.Run("basic", func(t *testing.T) {
		cte := NewCTE("cte", []string{"n"}, Queryf("SELECT 1")).Materialized().NotMaterialized().As("c")
		TestTable{item: cte, wantQuery: "cte"}.assert(t)
		field := NewAnyField("ff", TableStruct{name: "cte", alias: "c"})
		if diff := testutil.Diff(cte.Field("ff"), field); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		// NotMaterialized was applied last, so it wins over Materialized.
		if diff := testutil.Diff(cte.materialized, sql.NullBool{Valid: true, Bool: false}); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		if diff := testutil.Diff(cte.GetAlias(), "c"); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
	})
}

// TestCTEs checks the WITH clause produced by writeCTEs, including the
// RECURSIVE keyword, dialect-dependent MATERIALIZED handling, and the error
// paths for invalid CTEs and failing queries.
func TestCTEs(t *testing.T) {
	type TT struct {
		description string
		dialect     string
		ctes        []CTE
		wantQuery   string
		wantArgs    []any
		wantParams  map[string][]int
	}

	tests := []TT{{
		description: "basic",
		ctes:        []CTE{NewCTE("cte", nil, Queryf("SELECT 1"))},
		wantQuery:   "WITH cte AS (SELECT 1) ",
	}, {
		// One recursive CTE is enough to switch the whole clause to
		// WITH RECURSIVE.
		description: "recursive",
		ctes: []CTE{
			NewCTE("cte", nil, Queryf("SELECT 1")),
			NewRecursiveCTE("nums", []string{"n"}, Union(
				Queryf("SELECT 1"),
				Queryf("SELECT n+1 FROM nums WHERE n < 10"),
			)),
		},
		wantQuery: "WITH RECURSIVE cte AS (SELECT 1)" +
			", nums (n) AS (SELECT 1 UNION SELECT n+1 FROM nums WHERE n < 10) ",
	}, {
		// MATERIALIZED is dropped for non-postgres dialects.
		description: "mysql materialized",
		dialect:     DialectMySQL,
		ctes:        []CTE{NewCTE("cte", nil, Queryf("SELECT 1")).Materialized()},
		wantQuery:   "WITH cte AS (SELECT 1) ",
	}, {
		description: "postgres materialized",
		dialect:     DialectPostgres,
		ctes:        []CTE{NewCTE("cte", nil, Queryf("SELECT 1")).Materialized()},
		wantQuery:   "WITH cte AS MATERIALIZED (SELECT 1) ",
	}, {
		description: "postgres not materialized",
		dialect:     DialectPostgres,
		ctes:        []CTE{NewCTE("cte", nil, Queryf("SELECT 1")).NotMaterialized()},
		wantQuery:   "WITH cte AS NOT MATERIALIZED (SELECT 1) ",
	}}

	for _, tt := range tests {
		// Copy the loop variable so each parallel subtest sees its own case.
		tt := tt
		t.Run(tt.description, func(t *testing.T) {
			t.Parallel()
			buf, args, params := bufpool.Get().(*bytes.Buffer), &[]any{}, make(map[string][]int)
			buf.Reset()
			defer bufpool.Put(buf)
			err := writeCTEs(context.Background(), tt.dialect, buf, args, params, tt.ctes)
			if err != nil {
				t.Fatal(testutil.Callers(), err)
			}
			if diff := testutil.Diff(buf.String(), tt.wantQuery); diff != "" {
				t.Error(testutil.Callers(), diff)
			}
			if diff := testutil.Diff(*args, tt.wantArgs); diff != "" {
				t.Error(testutil.Callers(), diff)
			}
			if diff := testutil.Diff(params, tt.wantParams); diff != "" {
				t.Error(testutil.Callers(), diff)
			}
		})
	}

	t.Run("invalid cte", func(t *testing.T) {
		t.Parallel()
		buf, args, params := bufpool.Get().(*bytes.Buffer), &[]any{}, make(map[string][]int)
		buf.Reset()
		defer bufpool.Put(buf)
		// no name
		err := writeCTEs(context.Background(), "", buf, args, params, []CTE{
			NewCTE("", nil, Queryf("SELECT 1")),
		})
		if err == nil {
			t.Fatal(testutil.Callers(), "expected error but got nil")
		}
		// no query
		err = writeCTEs(context.Background(), "", buf, args, params, []CTE{
			NewCTE("cte", nil, nil),
		})
		if err == nil {
			t.Fatal(testutil.Callers(), "expected error but got nil")
		}
	})

	// Errors raised while writing the CTE's query must surface through
	// writeCTEs with the original error still reachable via errors.Is.
	t.Run("err", func(t *testing.T) {
		t.Parallel()
		buf, args, params := bufpool.Get().(*bytes.Buffer), &[]any{}, make(map[string][]int)
		buf.Reset()
		defer bufpool.Put(buf)
		// VariadicQuery
		err := writeCTEs(context.Background(), "", buf, args, params, []CTE{
			NewCTE("cte", nil, Union(
				Queryf("SELECT 1"),
				Queryf("SELECT {}", FaultySQL{}),
			)),
		})
		if !errors.Is(err, ErrFaultySQL) {
			t.Errorf(testutil.Callers()+"expected error %q but got %q", ErrFaultySQL, err)
		}
		// Query
		err = writeCTEs(context.Background(), "", buf, args, params, []CTE{
			NewCTE("cte", nil, Queryf("SELECT {}", FaultySQL{})),
		})
		if !errors.Is(err, ErrFaultySQL) {
			t.Errorf(testutil.Callers()+"expected error %q but got %q", ErrFaultySQL, err)
		}
	})
}

// TestVariadicQuery checks the SQL written for each set operator, the
// handling of nested and single-query VariadicQueries, the error paths, and
// the Query interface methods.
func TestVariadicQuery(t *testing.T) {
	q1, q2, q3 := Queryf("SELECT 1"), Queryf("SELECT 2"), Queryf("SELECT 3")
	tests := []TestTable{{
		description: "Union",
		item:        Union(q1, q2, q3),
		wantQuery:   "(SELECT 1 UNION SELECT 2 UNION SELECT 3)",
	}, {
		description: "UnionAll",
		item:        UnionAll(q1, q2, q3),
		wantQuery:   "(SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3)",
	}, {
		description: "Intersect",
		item:        Intersect(q1, q2, q3),
		wantQuery:   "(SELECT 1 INTERSECT SELECT 2 INTERSECT SELECT 3)",
	}, {
		description: "IntersectAll",
		item:        IntersectAll(q1, q2, q3),
		wantQuery:   "(SELECT 1 INTERSECT ALL SELECT 2 INTERSECT ALL SELECT 3)",
	}, {
		description: "Except",
		item:        Except(q1, q2, q3),
		wantQuery:   "(SELECT 1 EXCEPT SELECT 2 EXCEPT SELECT 3)",
	}, {
		description: "ExceptAll",
		item:        ExceptAll(q1, q2, q3),
		wantQuery:   "(SELECT 1 EXCEPT ALL SELECT 2 EXCEPT ALL SELECT 3)",
	}, {
		// An empty Operator falls back to UNION.
		description: "No operator specified",
		item:        VariadicQuery{Queries: []Query{q1, q2, q3}},
		wantQuery:   "(SELECT 1 UNION SELECT 2 UNION SELECT 3)",
	}, {
		// Single-element wrappers collapse, so only one pair of brackets
		// is written.
		description: "nested VariadicQuery",
		item:        Union(Union(Union(q1, q2, q3))),
		wantQuery:   "(SELECT 1 UNION SELECT 2 UNION SELECT 3)",
	}, {
		description: "1 query",
		item:        Union(q1),
		wantQuery:   "SELECT 1",
	}}

	for _, tt := range tests {
		// Copy the loop variable so each parallel subtest sees its own case.
		tt := tt
		t.Run(tt.description, func(t *testing.T) {
			t.Parallel()
			tt.assert(t)
		})
	}

	t.Run("invalid VariadicQuery", func(t *testing.T) {
		t.Parallel()
		// empty
		TestTable{item: Union()}.assertNotOK(t)
		// nil query
		TestTable{item: Union(nil)}.assertNotOK(t)
		// nil query
		TestTable{item: Union(q1, q2, nil)}.assertNotOK(t)
	})

	t.Run("err", func(t *testing.T) {
		t.Parallel()
		// VariadicQuery
		TestTable{
			item: Union(
				Union(
					Queryf("SELECT 1"),
					Queryf("SELECT {}", FaultySQL{}),
				),
			),
		}.assertErr(t, ErrFaultySQL)
		// Query
		TestTable{
			item: Union(Queryf("SELECT {}", FaultySQL{})),
		}.assertErr(t, ErrFaultySQL)
	})

	t.Run("SetFetchableFields", func(t *testing.T) {
		t.Parallel()
		_, ok := Union().SetFetchableFields([]Field{Expr("f1")})
		if ok {
			t.Error(testutil.Callers(), "expected not ok but got ok")
		}
	})

	t.Run("GetDialect", func(t *testing.T) {
		// empty VariadicQuery
		if diff := testutil.Diff(Union().GetDialect(), ""); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		// nil query
		if diff := testutil.Diff(Union(nil).GetDialect(), ""); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		// empty dialect propagated
		if diff := testutil.Diff(Union(Queryf("SELECT 1")).GetDialect(), ""); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
	})
}
239 | -------------------------------------------------------------------------------- /delete_query.go: -------------------------------------------------------------------------------- 1 | package sq 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | ) 8 | 9 | // DeleteQuery represents an SQL DELETE query. 10 | type DeleteQuery struct { 11 | Dialect string 12 | // WITH 13 | CTEs []CTE 14 | // DELETE FROM 15 | DeleteTable Table 16 | DeleteTables []Table 17 | // USING 18 | UsingTable Table 19 | JoinTables []JoinTable 20 | // WHERE 21 | WherePredicate Predicate 22 | // ORDER BY 23 | OrderByFields Fields 24 | // LIMIT 25 | LimitRows any 26 | // OFFSET 27 | OffsetRows any 28 | // RETURNING 29 | ReturningFields []Field 30 | } 31 | 32 | var _ Query = (*DeleteQuery)(nil) 33 | 34 | // WriteSQL implements the SQLWriter interface. 35 | func (q DeleteQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { 36 | var err error 37 | // Table Policies 38 | var policies []Predicate 39 | policies, err = appendPolicy(ctx, dialect, policies, q.DeleteTable) 40 | if err != nil { 41 | return fmt.Errorf("DELETE FROM %s Policy: %w", toString(q.Dialect, q.DeleteTable), err) 42 | } 43 | policies, err = appendPolicy(ctx, dialect, policies, q.UsingTable) 44 | if err != nil { 45 | return fmt.Errorf("USING %s Policy: %w", toString(q.Dialect, q.UsingTable), err) 46 | } 47 | for _, joinTable := range q.JoinTables { 48 | policies, err = appendPolicy(ctx, dialect, policies, joinTable.Table) 49 | if err != nil { 50 | return fmt.Errorf("%s %s Policy: %w", joinTable.JoinOperator, joinTable.Table, err) 51 | } 52 | } 53 | if len(policies) > 0 { 54 | if q.WherePredicate != nil { 55 | policies = append(policies, q.WherePredicate) 56 | } 57 | q.WherePredicate = And(policies...) 
58 | } 59 | // WITH 60 | if len(q.CTEs) > 0 { 61 | err = writeCTEs(ctx, dialect, buf, args, params, q.CTEs) 62 | if err != nil { 63 | return fmt.Errorf("WITH: %w", err) 64 | } 65 | } 66 | // DELETE FROM 67 | if (dialect == DialectMySQL || dialect == DialectSQLServer) && len(q.DeleteTables) > 0 { 68 | buf.WriteString("DELETE ") 69 | if len(q.DeleteTables) > 1 && dialect != DialectMySQL { 70 | return fmt.Errorf("dialect %q does not support multi-table DELETE", dialect) 71 | } 72 | for i, table := range q.DeleteTables { 73 | if i > 0 { 74 | buf.WriteString(", ") 75 | } 76 | if alias := getAlias(table); alias != "" { 77 | buf.WriteString(alias) 78 | } else { 79 | err = table.WriteSQL(ctx, dialect, buf, args, params) 80 | if err != nil { 81 | return fmt.Errorf("table #%d: %w", i+1, err) 82 | } 83 | } 84 | } 85 | } else { 86 | buf.WriteString("DELETE FROM ") 87 | if q.DeleteTable == nil { 88 | return fmt.Errorf("no table provided to DELETE FROM") 89 | } 90 | err = q.DeleteTable.WriteSQL(ctx, dialect, buf, args, params) 91 | if err != nil { 92 | return fmt.Errorf("DELETE FROM: %w", err) 93 | } 94 | if dialect != DialectSQLServer { 95 | if alias := getAlias(q.DeleteTable); alias != "" { 96 | buf.WriteString(" AS " + QuoteIdentifier(dialect, alias)) 97 | } 98 | } 99 | } 100 | if q.UsingTable != nil || len(q.JoinTables) > 0 { 101 | if dialect != DialectPostgres && dialect != DialectMySQL && dialect != DialectSQLServer { 102 | return fmt.Errorf("%s DELETE does not support JOIN", dialect) 103 | } 104 | } 105 | // OUTPUT 106 | if len(q.ReturningFields) > 0 && dialect == DialectSQLServer { 107 | buf.WriteString(" OUTPUT ") 108 | err = writeFieldsWithPrefix(ctx, dialect, buf, args, params, q.ReturningFields, "DELETED", true) 109 | if err != nil { 110 | return err 111 | } 112 | } 113 | // USING/FROM 114 | if q.UsingTable != nil { 115 | switch dialect { 116 | case DialectPostgres: 117 | buf.WriteString(" USING ") 118 | err = q.UsingTable.WriteSQL(ctx, dialect, buf, args, params) 
119 | if err != nil { 120 | return fmt.Errorf("USING: %w", err) 121 | } 122 | case DialectMySQL, DialectSQLServer: 123 | buf.WriteString(" FROM ") 124 | err = q.UsingTable.WriteSQL(ctx, dialect, buf, args, params) 125 | if err != nil { 126 | return fmt.Errorf("FROM: %w", err) 127 | } 128 | } 129 | if alias := getAlias(q.UsingTable); alias != "" { 130 | buf.WriteString(" AS " + QuoteIdentifier(dialect, alias)) 131 | } 132 | } 133 | // JOIN 134 | if len(q.JoinTables) > 0 { 135 | if q.UsingTable == nil { 136 | return fmt.Errorf("%s can't JOIN without a USING/FROM table", dialect) 137 | } 138 | buf.WriteString(" ") 139 | err = writeJoinTables(ctx, dialect, buf, args, params, q.JoinTables) 140 | if err != nil { 141 | return fmt.Errorf("JOIN: %w", err) 142 | } 143 | } 144 | // WHERE 145 | if q.WherePredicate != nil { 146 | buf.WriteString(" WHERE ") 147 | switch predicate := q.WherePredicate.(type) { 148 | case VariadicPredicate: 149 | predicate.Toplevel = true 150 | err = predicate.WriteSQL(ctx, dialect, buf, args, params) 151 | if err != nil { 152 | return fmt.Errorf("WHERE: %w", err) 153 | } 154 | default: 155 | err = q.WherePredicate.WriteSQL(ctx, dialect, buf, args, params) 156 | if err != nil { 157 | return fmt.Errorf("WHERE: %w", err) 158 | } 159 | } 160 | } 161 | // ORDER BY 162 | if len(q.OrderByFields) > 0 { 163 | if dialect != DialectMySQL { 164 | return fmt.Errorf("%s UPDATE does not support ORDER BY", dialect) 165 | } 166 | buf.WriteString(" ORDER BY ") 167 | err = q.OrderByFields.WriteSQL(ctx, dialect, buf, args, params) 168 | if err != nil { 169 | return fmt.Errorf("ORDER BY: %w", err) 170 | } 171 | } 172 | // LIMIT 173 | if q.LimitRows != nil { 174 | if dialect != DialectMySQL { 175 | return fmt.Errorf("%s UPDATE does not support LIMIT", dialect) 176 | } 177 | buf.WriteString(" LIMIT ") 178 | err = WriteValue(ctx, dialect, buf, args, params, q.LimitRows) 179 | if err != nil { 180 | return fmt.Errorf("LIMIT: %w", err) 181 | } 182 | } 183 | // RETURNING 184 
| if len(q.ReturningFields) > 0 && dialect != DialectSQLServer { 185 | if dialect != DialectPostgres && dialect != DialectSQLite && dialect != DialectMySQL { 186 | return fmt.Errorf("%s UPDATE does not support RETURNING", dialect) 187 | } 188 | buf.WriteString(" RETURNING ") 189 | err = writeFields(ctx, dialect, buf, args, params, q.ReturningFields, true) 190 | if err != nil { 191 | return fmt.Errorf("RETURNING: %w", err) 192 | } 193 | } 194 | return nil 195 | } 196 | 197 | // DeleteFrom returns a new DeleteQuery. 198 | func DeleteFrom(table Table) DeleteQuery { 199 | return DeleteQuery{DeleteTable: table} 200 | } 201 | 202 | // Where appends to the WherePredicate field of the DeleteQuery. 203 | func (q DeleteQuery) Where(predicates ...Predicate) DeleteQuery { 204 | q.WherePredicate = appendPredicates(q.WherePredicate, predicates) 205 | return q 206 | } 207 | 208 | // SetFetchableFields implements the Query interface. 209 | func (q DeleteQuery) SetFetchableFields(fields []Field) (query Query, ok bool) { 210 | switch q.Dialect { 211 | case DialectPostgres, DialectSQLite: 212 | if len(q.ReturningFields) == 0 { 213 | q.ReturningFields = fields 214 | return q, true 215 | } 216 | return q, false 217 | default: 218 | return q, false 219 | } 220 | } 221 | 222 | // GetFetchableFields returns the fetchable fields of the query. 223 | func (q DeleteQuery) GetFetchableFields() []Field { 224 | switch q.Dialect { 225 | case DialectPostgres, DialectSQLite: 226 | return q.ReturningFields 227 | default: 228 | return nil 229 | } 230 | } 231 | 232 | // GetDialect implements the Query interface. 233 | func (q DeleteQuery) GetDialect() string { return q.Dialect } 234 | 235 | // SetDialect sets the dialect of the query. 236 | func (q DeleteQuery) SetDialect(dialect string) DeleteQuery { 237 | q.Dialect = dialect 238 | return q 239 | } 240 | 241 | // SQLiteDeleteQuery represents an SQLite DELETE query. 
242 | type SQLiteDeleteQuery DeleteQuery 243 | 244 | var _ Query = (*SQLiteDeleteQuery)(nil) 245 | 246 | // WriteSQL implements the SQLWriter interface. 247 | func (q SQLiteDeleteQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { 248 | return DeleteQuery(q).WriteSQL(ctx, dialect, buf, args, params) 249 | } 250 | 251 | // DeleteFrom returns a new SQLiteDeleteQuery. 252 | func (b sqliteQueryBuilder) DeleteFrom(table Table) SQLiteDeleteQuery { 253 | return SQLiteDeleteQuery{ 254 | Dialect: DialectSQLite, 255 | CTEs: b.ctes, 256 | DeleteTable: table, 257 | } 258 | } 259 | 260 | // Where appends to the WherePredicate field of the SQLiteDeleteQuery. 261 | func (q SQLiteDeleteQuery) Where(predicates ...Predicate) SQLiteDeleteQuery { 262 | q.WherePredicate = appendPredicates(q.WherePredicate, predicates) 263 | return q 264 | } 265 | 266 | // Returning appends fields to the RETURNING clause of the SQLiteDeleteQuery. 267 | func (q SQLiteDeleteQuery) Returning(fields ...Field) SQLiteDeleteQuery { 268 | q.ReturningFields = append(q.ReturningFields, fields...) 269 | return q 270 | } 271 | 272 | // SetFetchableFields implements the Query interface. 273 | func (q SQLiteDeleteQuery) SetFetchableFields(fields []Field) (query Query, ok bool) { 274 | return DeleteQuery(q).SetFetchableFields(fields) 275 | } 276 | 277 | // GetFetchableFields returns the fetchable fields of the query. 278 | func (q SQLiteDeleteQuery) GetFetchableFields() []Field { 279 | return DeleteQuery(q).GetFetchableFields() 280 | } 281 | 282 | // GetDialect implements the Query interface. 283 | func (q SQLiteDeleteQuery) GetDialect() string { return q.Dialect } 284 | 285 | // SetDialect sets the dialect of the query. 286 | func (q SQLiteDeleteQuery) SetDialect(dialect string) SQLiteDeleteQuery { 287 | q.Dialect = dialect 288 | return q 289 | } 290 | 291 | // PostgresDeleteQuery represents a Postgres DELETE query. 
292 | type PostgresDeleteQuery DeleteQuery 293 | 294 | var _ Query = (*PostgresDeleteQuery)(nil) 295 | 296 | // WriteSQL implements the SQLWriter interface. 297 | func (q PostgresDeleteQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { 298 | return DeleteQuery(q).WriteSQL(ctx, dialect, buf, args, params) 299 | } 300 | 301 | // DeleteFrom returns a new PostgresDeleteQuery. 302 | func (b postgresQueryBuilder) DeleteFrom(table Table) PostgresDeleteQuery { 303 | return PostgresDeleteQuery{ 304 | Dialect: DialectPostgres, 305 | CTEs: b.ctes, 306 | DeleteTable: table, 307 | } 308 | } 309 | 310 | // Using sets the UsingTable field of the PostgresDeleteQuery. 311 | func (q PostgresDeleteQuery) Using(table Table) PostgresDeleteQuery { 312 | q.UsingTable = table 313 | return q 314 | } 315 | 316 | // Join joins a new Table to the PostgresDeleteQuery. 317 | func (q PostgresDeleteQuery) Join(table Table, predicates ...Predicate) PostgresDeleteQuery { 318 | q.JoinTables = append(q.JoinTables, Join(table, predicates...)) 319 | return q 320 | } 321 | 322 | // LeftJoin left joins a new Table to the PostgresDeleteQuery. 323 | func (q PostgresDeleteQuery) LeftJoin(table Table, predicates ...Predicate) PostgresDeleteQuery { 324 | q.JoinTables = append(q.JoinTables, LeftJoin(table, predicates...)) 325 | return q 326 | } 327 | 328 | // FullJoin full joins a new Table to the PostgresDeleteQuery. 329 | func (q PostgresDeleteQuery) FullJoin(table Table, predicates ...Predicate) PostgresDeleteQuery { 330 | q.JoinTables = append(q.JoinTables, FullJoin(table, predicates...)) 331 | return q 332 | } 333 | 334 | // CrossJoin cross joins a new Table to the PostgresDeleteQuery. 
335 | func (q PostgresDeleteQuery) CrossJoin(table Table) PostgresDeleteQuery { 336 | q.JoinTables = append(q.JoinTables, CrossJoin(table)) 337 | return q 338 | } 339 | 340 | // CustomJoin joins a new Table to the PostgresDeleteQuery with a custom join 341 | // operator. 342 | func (q PostgresDeleteQuery) CustomJoin(joinOperator string, table Table, predicates ...Predicate) PostgresDeleteQuery { 343 | q.JoinTables = append(q.JoinTables, CustomJoin(joinOperator, table, predicates...)) 344 | return q 345 | } 346 | 347 | // JoinUsing joins a new Table to the PostgresDeleteQuery with the USING operator. 348 | func (q PostgresDeleteQuery) JoinUsing(table Table, fields ...Field) PostgresDeleteQuery { 349 | q.JoinTables = append(q.JoinTables, JoinUsing(table, fields...)) 350 | return q 351 | } 352 | 353 | // Where appends to the WherePredicate field of the PostgresDeleteQuery. 354 | func (q PostgresDeleteQuery) Where(predicates ...Predicate) PostgresDeleteQuery { 355 | q.WherePredicate = appendPredicates(q.WherePredicate, predicates) 356 | return q 357 | } 358 | 359 | // Returning appends fields to the RETURNING clause of the PostgresDeleteQuery. 360 | func (q PostgresDeleteQuery) Returning(fields ...Field) PostgresDeleteQuery { 361 | q.ReturningFields = append(q.ReturningFields, fields...) 362 | return q 363 | } 364 | 365 | // SetFetchableFields implements the Query interface. 366 | func (q PostgresDeleteQuery) SetFetchableFields(fields []Field) (query Query, ok bool) { 367 | return DeleteQuery(q).SetFetchableFields(fields) 368 | } 369 | 370 | // GetFetchableFields returns the fetchable fields of the query. 371 | func (q PostgresDeleteQuery) GetFetchableFields() []Field { 372 | return DeleteQuery(q).GetFetchableFields() 373 | } 374 | 375 | // GetDialect implements the Query interface. 376 | func (q PostgresDeleteQuery) GetDialect() string { return q.Dialect } 377 | 378 | // SetDialect sets the dialect of the query. 
379 | func (q PostgresDeleteQuery) SetDialect(dialect string) PostgresDeleteQuery { 380 | q.Dialect = dialect 381 | return q 382 | } 383 | 384 | // MySQLDeleteQuery represents a MySQL DELETE query. 385 | type MySQLDeleteQuery DeleteQuery 386 | 387 | var _ Query = (*MySQLDeleteQuery)(nil) 388 | 389 | // WriteSQL implements the SQLWriter interface. 390 | func (q MySQLDeleteQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { 391 | return DeleteQuery(q).WriteSQL(ctx, dialect, buf, args, params) 392 | } 393 | 394 | // DeleteFrom returns a new MySQLDeleteQuery. 395 | func (b mysqlQueryBuilder) DeleteFrom(table Table) MySQLDeleteQuery { 396 | return MySQLDeleteQuery{ 397 | Dialect: DialectMySQL, 398 | CTEs: b.ctes, 399 | DeleteTable: table, 400 | } 401 | } 402 | 403 | // Delete returns a new MySQLDeleteQuery. 404 | func (b mysqlQueryBuilder) Delete(tables ...Table) MySQLDeleteQuery { 405 | return MySQLDeleteQuery{ 406 | Dialect: DialectMySQL, 407 | CTEs: b.ctes, 408 | DeleteTables: tables, 409 | } 410 | } 411 | 412 | // From sets the UsingTable of the MySQLDeleteQuery. 413 | func (q MySQLDeleteQuery) From(table Table) MySQLDeleteQuery { 414 | q.UsingTable = table 415 | return q 416 | } 417 | 418 | // Join joins a new Table to the MySQLDeleteQuery. 419 | func (q MySQLDeleteQuery) Join(table Table, predicates ...Predicate) MySQLDeleteQuery { 420 | q.JoinTables = append(q.JoinTables, Join(table, predicates...)) 421 | return q 422 | } 423 | 424 | // LeftJoin left joins a new Table to the MySQLDeleteQuery. 425 | func (q MySQLDeleteQuery) LeftJoin(table Table, predicates ...Predicate) MySQLDeleteQuery { 426 | q.JoinTables = append(q.JoinTables, LeftJoin(table, predicates...)) 427 | return q 428 | } 429 | 430 | // FullJoin full joins a new Table to the MySQLDeleteQuery. 
431 | func (q MySQLDeleteQuery) FullJoin(table Table, predicates ...Predicate) MySQLDeleteQuery { 432 | q.JoinTables = append(q.JoinTables, FullJoin(table, predicates...)) 433 | return q 434 | } 435 | 436 | // CrossJoin cross joins a new Table to the MySQLDeleteQuery. 437 | func (q MySQLDeleteQuery) CrossJoin(table Table) MySQLDeleteQuery { 438 | q.JoinTables = append(q.JoinTables, CrossJoin(table)) 439 | return q 440 | } 441 | 442 | // CustomJoin joins a new Table to the MySQLDeleteQuery with a custom join 443 | // operator. 444 | func (q MySQLDeleteQuery) CustomJoin(joinOperator string, table Table, predicates ...Predicate) MySQLDeleteQuery { 445 | q.JoinTables = append(q.JoinTables, CustomJoin(joinOperator, table, predicates...)) 446 | return q 447 | } 448 | 449 | // JoinUsing joins a new Table to the MySQLDeleteQuery with the USING operator. 450 | func (q MySQLDeleteQuery) JoinUsing(table Table, fields ...Field) MySQLDeleteQuery { 451 | q.JoinTables = append(q.JoinTables, JoinUsing(table, fields...)) 452 | return q 453 | } 454 | 455 | // Where appends to the WherePredicate field of the MySQLDeleteQuery. 456 | func (q MySQLDeleteQuery) Where(predicates ...Predicate) MySQLDeleteQuery { 457 | q.WherePredicate = appendPredicates(q.WherePredicate, predicates) 458 | return q 459 | } 460 | 461 | // OrderBy sets the OrderByFields field of the MySQLDeleteQuery. 462 | func (q MySQLDeleteQuery) OrderBy(fields ...Field) MySQLDeleteQuery { 463 | q.OrderByFields = append(q.OrderByFields, fields...) 464 | return q 465 | } 466 | 467 | // Limit sets the LimitRows field of the MySQLDeleteQuery. 468 | func (q MySQLDeleteQuery) Limit(limit any) MySQLDeleteQuery { 469 | q.LimitRows = limit 470 | return q 471 | } 472 | 473 | // Returning appends fields to the RETURNING clause of the MySQLDeleteQuery. 474 | func (q MySQLDeleteQuery) Returning(fields ...Field) MySQLDeleteQuery { 475 | q.ReturningFields = append(q.ReturningFields, fields...) 
476 | return q 477 | } 478 | 479 | // SetFetchableFields implements the Query interface. 480 | func (q MySQLDeleteQuery) SetFetchableFields(fields []Field) (query Query, ok bool) { 481 | return DeleteQuery(q).SetFetchableFields(fields) 482 | } 483 | 484 | // GetFetchableFields returns the fetchable fields of the query. 485 | func (q MySQLDeleteQuery) GetFetchableFields() []Field { 486 | return DeleteQuery(q).GetFetchableFields() 487 | } 488 | 489 | // GetDialect implements the Query interface. 490 | func (q MySQLDeleteQuery) GetDialect() string { return q.Dialect } 491 | 492 | // SetDialect sets the dialect of the query. 493 | func (q MySQLDeleteQuery) SetDialect(dialect string) MySQLDeleteQuery { 494 | q.Dialect = dialect 495 | return q 496 | } 497 | 498 | // SQLServerDeleteQuery represents an SQL Server DELETE query. 499 | type SQLServerDeleteQuery DeleteQuery 500 | 501 | var _ Query = (*SQLServerDeleteQuery)(nil) 502 | 503 | // WriteSQL implements the SQLWriter interface. 504 | func (q SQLServerDeleteQuery) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { 505 | return DeleteQuery(q).WriteSQL(ctx, dialect, buf, args, params) 506 | } 507 | 508 | // DeleteFrom returns a new SQLServerDeleteQuery. 509 | func (b sqlserverQueryBuilder) DeleteFrom(table Table) SQLServerDeleteQuery { 510 | return SQLServerDeleteQuery{ 511 | Dialect: DialectSQLServer, 512 | CTEs: b.ctes, 513 | DeleteTable: table, 514 | } 515 | } 516 | 517 | // Delete returns a new SQLServerDeleteQuery. 518 | func (b sqlserverQueryBuilder) Delete(table Table) SQLServerDeleteQuery { 519 | return SQLServerDeleteQuery{ 520 | Dialect: DialectSQLServer, 521 | CTEs: b.ctes, 522 | DeleteTables: []Table{table}, 523 | } 524 | } 525 | 526 | // From sets the UsingTable of the SQLServerDeleteQuery. 
527 | func (q SQLServerDeleteQuery) From(table Table) SQLServerDeleteQuery { 528 | q.UsingTable = table 529 | return q 530 | } 531 | 532 | // Join joins a new Table to the SQLServerDeleteQuery. 533 | func (q SQLServerDeleteQuery) Join(table Table, predicates ...Predicate) SQLServerDeleteQuery { 534 | q.JoinTables = append(q.JoinTables, Join(table, predicates...)) 535 | return q 536 | } 537 | 538 | // LeftJoin left joins a new Table to the SQLServerDeleteQuery. 539 | func (q SQLServerDeleteQuery) LeftJoin(table Table, predicates ...Predicate) SQLServerDeleteQuery { 540 | q.JoinTables = append(q.JoinTables, LeftJoin(table, predicates...)) 541 | return q 542 | } 543 | 544 | // FullJoin full joins a new Table to the SQLServerDeleteQuery. 545 | func (q SQLServerDeleteQuery) FullJoin(table Table, predicates ...Predicate) SQLServerDeleteQuery { 546 | q.JoinTables = append(q.JoinTables, FullJoin(table, predicates...)) 547 | return q 548 | } 549 | 550 | // CrossJoin cross joins a new Table to the SQLServerDeleteQuery. 551 | func (q SQLServerDeleteQuery) CrossJoin(table Table) SQLServerDeleteQuery { 552 | q.JoinTables = append(q.JoinTables, CrossJoin(table)) 553 | return q 554 | } 555 | 556 | // CustomJoin joins a new Table to the SQLServerDeleteQuery with a custom join 557 | // operator. 558 | func (q SQLServerDeleteQuery) CustomJoin(joinOperator string, table Table, predicates ...Predicate) SQLServerDeleteQuery { 559 | q.JoinTables = append(q.JoinTables, CustomJoin(joinOperator, table, predicates...)) 560 | return q 561 | } 562 | 563 | // Where appends to the WherePredicate field of the SQLServerDeleteQuery. 564 | func (q SQLServerDeleteQuery) Where(predicates ...Predicate) SQLServerDeleteQuery { 565 | q.WherePredicate = appendPredicates(q.WherePredicate, predicates) 566 | return q 567 | } 568 | 569 | // SetFetchableFields implements the Query interface. 
570 | func (q SQLServerDeleteQuery) SetFetchableFields(fields []Field) (query Query, ok bool) { 571 | return DeleteQuery(q).SetFetchableFields(fields) 572 | } 573 | 574 | // GetFetchableFields returns the fetchable fields of the query. 575 | func (q SQLServerDeleteQuery) GetFetchableFields() []Field { 576 | return DeleteQuery(q).GetFetchableFields() 577 | } 578 | 579 | // GetDialect implements the Query interface. 580 | func (q SQLServerDeleteQuery) GetDialect() string { return q.Dialect } 581 | 582 | // SetDialect sets the dialect of the query. 583 | func (q SQLServerDeleteQuery) SetDialect(dialect string) SQLServerDeleteQuery { 584 | q.Dialect = dialect 585 | return q 586 | } 587 | -------------------------------------------------------------------------------- /delete_query_test.go: -------------------------------------------------------------------------------- 1 | package sq 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/bokwoon95/sq/internal/testutil" 7 | ) 8 | 9 | func TestSQLiteDeleteQuery(t *testing.T) { 10 | type ACTOR struct { 11 | TableStruct 12 | ACTOR_ID NumberField 13 | FIRST_NAME StringField 14 | LAST_NAME StringField 15 | LAST_UPDATE TimeField 16 | } 17 | a := New[ACTOR]("a") 18 | 19 | t.Run("basic", func(t *testing.T) { 20 | t.Parallel() 21 | q1 := SQLite.DeleteFrom(a).Returning(a.FIRST_NAME).SetDialect("lorem ipsum") 22 | if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" { 23 | t.Error(testutil.Callers(), diff) 24 | } 25 | q1 = q1.SetDialect(DialectSQLite) 26 | fields := q1.GetFetchableFields() 27 | if diff := testutil.Diff(fields, []Field{a.FIRST_NAME}); diff != "" { 28 | t.Error(testutil.Callers(), diff) 29 | } 30 | _, ok := q1.SetFetchableFields([]Field{a.LAST_NAME}) 31 | if ok { 32 | t.Fatal(testutil.Callers(), "field should not have been set") 33 | } 34 | q1.ReturningFields = q1.ReturningFields[:0] 35 | _, ok = q1.SetFetchableFields([]Field{a.LAST_NAME}) 36 | if !ok { 37 | t.Fatal(testutil.Callers(), "field should have 
been set") 38 | } 39 | }) 40 | 41 | t.Run("Delete Returning", func(t *testing.T) { 42 | t.Parallel() 43 | var tt TestTable 44 | tt.item = SQLite. 45 | With(NewCTE("cte", nil, Queryf("SELECT 1"))). 46 | DeleteFrom(a). 47 | Where(a.ACTOR_ID.EqInt(1)). 48 | Returning(a.FIRST_NAME, a.LAST_NAME) 49 | tt.wantQuery = "WITH cte AS (SELECT 1)" + 50 | " DELETE FROM actor AS a" + 51 | " WHERE a.actor_id = $1" + 52 | " RETURNING a.first_name, a.last_name" 53 | tt.wantArgs = []any{1} 54 | tt.assert(t) 55 | }) 56 | } 57 | 58 | func TestPostgresDeleteQuery(t *testing.T) { 59 | type ACTOR struct { 60 | TableStruct 61 | ACTOR_ID NumberField 62 | FIRST_NAME StringField 63 | LAST_NAME StringField 64 | LAST_UPDATE TimeField 65 | } 66 | a := New[ACTOR]("a") 67 | 68 | t.Run("basic", func(t *testing.T) { 69 | t.Parallel() 70 | q1 := Postgres.DeleteFrom(a).Returning(a.FIRST_NAME).SetDialect("lorem ipsum") 71 | if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" { 72 | t.Error(testutil.Callers(), diff) 73 | } 74 | q1 = q1.SetDialect(DialectPostgres) 75 | fields := q1.GetFetchableFields() 76 | if diff := testutil.Diff(fields, []Field{a.FIRST_NAME}); diff != "" { 77 | t.Error(testutil.Callers(), diff) 78 | } 79 | _, ok := q1.SetFetchableFields([]Field{a.LAST_NAME}) 80 | if ok { 81 | t.Fatal(testutil.Callers(), "field should not have been set") 82 | } 83 | q1.ReturningFields = q1.ReturningFields[:0] 84 | _, ok = q1.SetFetchableFields([]Field{a.LAST_NAME}) 85 | if !ok { 86 | t.Fatal(testutil.Callers(), "field should have been set") 87 | } 88 | }) 89 | 90 | t.Run("Delete Returning", func(t *testing.T) { 91 | t.Parallel() 92 | var tt TestTable 93 | tt.item = Postgres. 94 | With(NewCTE("cte", nil, Queryf("SELECT 1"))). 95 | DeleteFrom(a). 96 | Where(a.ACTOR_ID.EqInt(1)). 
97 | Returning(a.FIRST_NAME, a.LAST_NAME) 98 | tt.wantQuery = "WITH cte AS (SELECT 1)" + 99 | " DELETE FROM actor AS a" + 100 | " WHERE a.actor_id = $1" + 101 | " RETURNING a.first_name, a.last_name" 102 | tt.wantArgs = []any{1} 103 | tt.assert(t) 104 | }) 105 | 106 | t.Run("Join", func(t *testing.T) { 107 | t.Parallel() 108 | var tt TestTable 109 | tt.item = Postgres. 110 | DeleteFrom(a). 111 | Using(a). 112 | Join(a, Expr("1 = 1")). 113 | LeftJoin(a, Expr("1 = 1")). 114 | FullJoin(a, Expr("1 = 1")). 115 | CrossJoin(a). 116 | CustomJoin(",", a). 117 | JoinUsing(a, a.FIRST_NAME, a.LAST_NAME) 118 | tt.wantQuery = "DELETE FROM actor AS a" + 119 | " USING actor AS a" + 120 | " JOIN actor AS a ON 1 = 1" + 121 | " LEFT JOIN actor AS a ON 1 = 1" + 122 | " FULL JOIN actor AS a ON 1 = 1" + 123 | " CROSS JOIN actor AS a" + 124 | " , actor AS a" + 125 | " JOIN actor AS a USING (first_name, last_name)" 126 | tt.assert(t) 127 | }) 128 | } 129 | 130 | func TestMySQLDeleteQuery(t *testing.T) { 131 | type ACTOR struct { 132 | TableStruct 133 | ACTOR_ID NumberField 134 | FIRST_NAME StringField 135 | LAST_NAME StringField 136 | LAST_UPDATE TimeField 137 | } 138 | a := New[ACTOR]("") 139 | 140 | t.Run("basic", func(t *testing.T) { 141 | t.Parallel() 142 | q1 := MySQL.DeleteFrom(a).SetDialect("lorem ipsum") 143 | if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" { 144 | t.Error(testutil.Callers(), diff) 145 | } 146 | q1 = q1.SetDialect(DialectMySQL) 147 | fields := q1.GetFetchableFields() 148 | if len(fields) != 0 { 149 | t.Error(testutil.Callers(), "expected 0 fields but got %v", fields) 150 | } 151 | _, ok := q1.SetFetchableFields([]Field{a.LAST_NAME}) 152 | if ok { 153 | t.Fatal(testutil.Callers(), "field should not have been set") 154 | } 155 | q1.ReturningFields = q1.ReturningFields[:0] 156 | _, ok = q1.SetFetchableFields([]Field{a.LAST_NAME}) 157 | if ok { 158 | t.Fatal(testutil.Callers(), "field should not have been set") 159 | } 160 | }) 161 | 162 | 
t.Run("Where", func(t *testing.T) { 163 | t.Parallel() 164 | var tt TestTable 165 | tt.item = MySQL. 166 | With(NewCTE("cte", nil, Queryf("SELECT 1"))). 167 | DeleteFrom(a). 168 | Where(a.ACTOR_ID.EqInt(1)) 169 | tt.wantQuery = "WITH cte AS (SELECT 1)" + 170 | " DELETE FROM actor" + 171 | " WHERE actor.actor_id = ?" 172 | tt.wantArgs = []any{1} 173 | tt.assert(t) 174 | }) 175 | 176 | t.Run("OrderBy Limit", func(t *testing.T) { 177 | t.Parallel() 178 | var tt TestTable 179 | tt.item = MySQL. 180 | DeleteFrom(a). 181 | OrderBy(a.ACTOR_ID). 182 | Limit(5) 183 | tt.wantQuery = "DELETE FROM actor" + 184 | " ORDER BY actor.actor_id" + 185 | " LIMIT ?" 186 | tt.wantArgs = []any{5} 187 | tt.assert(t) 188 | }) 189 | 190 | t.Run("Delete Returning", func(t *testing.T) { 191 | t.Parallel() 192 | var tt TestTable 193 | tt.item = MySQL. 194 | With(NewCTE("cte", nil, Queryf("SELECT 1"))). 195 | DeleteFrom(a). 196 | Where(a.ACTOR_ID.EqInt(1)). 197 | Returning(a.FIRST_NAME, a.LAST_NAME) 198 | tt.wantQuery = "WITH cte AS (SELECT 1)" + 199 | " DELETE FROM actor" + 200 | " WHERE actor.actor_id = ?" + 201 | " RETURNING actor.first_name, actor.last_name" 202 | tt.wantArgs = []any{1} 203 | tt.assert(t) 204 | }) 205 | 206 | t.Run("Join", func(t *testing.T) { 207 | t.Parallel() 208 | var tt TestTable 209 | tt.item = MySQL. 210 | Delete(a). 211 | From(a). 212 | Join(a, Expr("1 = 1")). 213 | LeftJoin(a, Expr("1 = 1")). 214 | FullJoin(a, Expr("1 = 1")). 215 | CrossJoin(a). 216 | CustomJoin(",", a). 
217 | JoinUsing(a, a.FIRST_NAME, a.LAST_NAME) 218 | tt.wantQuery = "DELETE actor" + 219 | " FROM actor" + 220 | " JOIN actor ON 1 = 1" + 221 | " LEFT JOIN actor ON 1 = 1" + 222 | " FULL JOIN actor ON 1 = 1" + 223 | " CROSS JOIN actor" + 224 | " , actor" + 225 | " JOIN actor USING (first_name, last_name)" 226 | tt.assert(t) 227 | }) 228 | } 229 | 230 | func TestSQLServerDeleteQuery(t *testing.T) { 231 | type ACTOR struct { 232 | TableStruct 233 | ACTOR_ID NumberField 234 | FIRST_NAME StringField 235 | LAST_NAME StringField 236 | LAST_UPDATE TimeField 237 | } 238 | a := New[ACTOR]("") 239 | 240 | t.Run("basic", func(t *testing.T) { 241 | t.Parallel() 242 | q1 := SQLServer.DeleteFrom(a).SetDialect("lorem ipsum") 243 | if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" { 244 | t.Error(testutil.Callers(), diff) 245 | } 246 | q1 = q1.SetDialect(DialectSQLServer) 247 | q1 = q1.SetDialect(DialectMySQL) 248 | fields := q1.GetFetchableFields() 249 | if len(fields) != 0 { 250 | t.Error(testutil.Callers(), "expected 0 fields but got %v", fields) 251 | } 252 | _, ok := q1.SetFetchableFields([]Field{a.LAST_NAME}) 253 | if ok { 254 | t.Fatal(testutil.Callers(), "field should not have been set") 255 | } 256 | q1.ReturningFields = q1.ReturningFields[:0] 257 | _, ok = q1.SetFetchableFields([]Field{a.LAST_NAME}) 258 | if ok { 259 | t.Fatal(testutil.Callers(), "field should not have been set") 260 | } 261 | }) 262 | 263 | t.Run("Where", func(t *testing.T) { 264 | t.Parallel() 265 | var tt TestTable 266 | tt.item = SQLServer. 267 | With(NewCTE("cte", nil, Queryf("SELECT 1"))). 268 | DeleteFrom(a). 269 | Where(a.ACTOR_ID.EqInt(1)) 270 | tt.wantQuery = "WITH cte AS (SELECT 1)" + 271 | " DELETE FROM actor" + 272 | " WHERE actor.actor_id = @p1" 273 | tt.wantArgs = []any{1} 274 | tt.assert(t) 275 | }) 276 | 277 | t.Run("Join", func(t *testing.T) { 278 | t.Parallel() 279 | var tt TestTable 280 | tt.item = SQLServer. 281 | DeleteFrom(a). 282 | From(a). 
283 | Join(a, Expr("1 = 1")). 284 | LeftJoin(a, Expr("1 = 1")). 285 | FullJoin(a, Expr("1 = 1")). 286 | CrossJoin(a). 287 | CustomJoin(",", a) 288 | tt.wantQuery = "DELETE FROM actor" + 289 | " FROM actor" + 290 | " JOIN actor ON 1 = 1" + 291 | " LEFT JOIN actor ON 1 = 1" + 292 | " FULL JOIN actor ON 1 = 1" + 293 | " CROSS JOIN actor" + 294 | " , actor" 295 | tt.assert(t) 296 | }) 297 | } 298 | 299 | func TestDeleteQuery(t *testing.T) { 300 | t.Run("basic", func(t *testing.T) { 301 | t.Parallel() 302 | q1 := DeleteQuery{DeleteTable: Expr("tbl"), Dialect: "lorem ipsum"} 303 | if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" { 304 | t.Error(testutil.Callers(), diff) 305 | } 306 | }) 307 | 308 | t.Run("PolicyTable", func(t *testing.T) { 309 | t.Parallel() 310 | var tt TestTable 311 | tt.item = DeleteQuery{ 312 | DeleteTable: policyTableStub{policy: And(Expr("1 = 1"), Expr("2 = 2"))}, 313 | WherePredicate: Expr("3 = 3"), 314 | } 315 | tt.wantQuery = "DELETE FROM policy_table_stub WHERE (1 = 1 AND 2 = 2) AND 3 = 3" 316 | tt.assert(t) 317 | }) 318 | 319 | notOKTests := []TestTable{{ 320 | description: "nil FromTable not allowed", 321 | item: DeleteQuery{ 322 | DeleteTable: nil, 323 | }, 324 | }, { 325 | description: "sqlite does not support JOIN", 326 | item: DeleteQuery{ 327 | Dialect: DialectSQLite, 328 | DeleteTable: Expr("tbl"), 329 | UsingTable: Expr("tbl"), 330 | JoinTables: []JoinTable{ 331 | Join(Expr("tbl"), Expr("1 = 1")), 332 | }, 333 | }, 334 | }, { 335 | description: "postgres does not allow JOIN without USING", 336 | item: DeleteQuery{ 337 | Dialect: DialectPostgres, 338 | DeleteTable: Expr("tbl"), 339 | JoinTables: []JoinTable{ 340 | Join(Expr("tbl"), Expr("1 = 1")), 341 | }, 342 | }, 343 | }, { 344 | description: "dialect does not support ORDER BY", 345 | item: DeleteQuery{ 346 | Dialect: DialectPostgres, 347 | DeleteTable: Expr("tbl"), 348 | OrderByFields: Fields{Expr("f1")}, 349 | }, 350 | }, { 351 | description: "dialect does not 
support LIMIT", 352 | item: DeleteQuery{ 353 | Dialect: DialectPostgres, 354 | DeleteTable: Expr("tbl"), 355 | LimitRows: 5, 356 | }, 357 | }} 358 | 359 | for _, tt := range notOKTests { 360 | tt := tt 361 | t.Run(tt.description, func(t *testing.T) { 362 | t.Parallel() 363 | tt.assertNotOK(t) 364 | }) 365 | } 366 | 367 | errTests := []TestTable{{ 368 | description: "FromTable Policy err", 369 | item: DeleteQuery{ 370 | DeleteTable: policyTableStub{err: ErrFaultySQL}, 371 | }, 372 | }, { 373 | description: "UsingTable Policy err", 374 | item: DeleteQuery{ 375 | DeleteTable: Expr("tbl"), 376 | UsingTable: policyTableStub{err: ErrFaultySQL}, 377 | }, 378 | }, { 379 | description: "JoinTables Policy err", 380 | item: DeleteQuery{ 381 | DeleteTable: Expr("tbl"), 382 | UsingTable: Expr("tbl"), 383 | JoinTables: []JoinTable{ 384 | Join(policyTableStub{err: ErrFaultySQL}, Expr("1 = 1")), 385 | }, 386 | }, 387 | }, { 388 | description: "CTEs err", 389 | item: DeleteQuery{ 390 | CTEs: []CTE{NewCTE("cte", nil, Queryf("SELECT {}", FaultySQL{}))}, 391 | DeleteTable: Expr("tbl"), 392 | }, 393 | }, { 394 | description: "FromTable err", 395 | item: DeleteQuery{ 396 | DeleteTable: FaultySQL{}, 397 | }, 398 | }, { 399 | description: "postgres UsingTable err", 400 | item: DeleteQuery{ 401 | Dialect: DialectPostgres, 402 | DeleteTable: Expr("tbl"), 403 | UsingTable: FaultySQL{}, 404 | }, 405 | }, { 406 | description: "sqlserver UsingTable err", 407 | item: DeleteQuery{ 408 | Dialect: DialectSQLServer, 409 | DeleteTable: Expr("tbl"), 410 | UsingTable: FaultySQL{}, 411 | }, 412 | }, { 413 | description: "JoinTables err", 414 | item: DeleteQuery{ 415 | Dialect: DialectPostgres, 416 | DeleteTable: Expr("tbl"), 417 | UsingTable: Expr("tbl"), 418 | JoinTables: []JoinTable{ 419 | Join(Expr("tbl"), FaultySQL{}), 420 | }, 421 | }, 422 | }, { 423 | description: "WherePredicate Variadic err", 424 | item: DeleteQuery{ 425 | DeleteTable: Expr("tbl"), 426 | WherePredicate: And(FaultySQL{}), 427 | 
}, 428 | }, { 429 | description: "WherePredicate err", 430 | item: DeleteQuery{ 431 | DeleteTable: Expr("tbl"), 432 | WherePredicate: FaultySQL{}, 433 | }, 434 | }, { 435 | description: "OrderByFields err", 436 | item: DeleteQuery{ 437 | Dialect: DialectMySQL, 438 | DeleteTable: Expr("tbl"), 439 | OrderByFields: Fields{FaultySQL{}}, 440 | }, 441 | }, { 442 | description: "LimitRows err", 443 | item: DeleteQuery{ 444 | Dialect: DialectMySQL, 445 | DeleteTable: Expr("tbl"), 446 | OrderByFields: Fields{Expr("f1")}, 447 | LimitRows: FaultySQL{}, 448 | }, 449 | }, { 450 | description: "ReturningFields err", 451 | item: DeleteQuery{ 452 | Dialect: DialectPostgres, 453 | DeleteTable: Expr("tbl"), 454 | ReturningFields: Fields{FaultySQL{}}, 455 | }, 456 | }} 457 | 458 | for _, tt := range errTests { 459 | tt := tt 460 | t.Run(tt.description, func(t *testing.T) { 461 | t.Parallel() 462 | tt.assertErr(t, ErrFaultySQL) 463 | }) 464 | } 465 | } 466 | -------------------------------------------------------------------------------- /fetch_exec_test.go: -------------------------------------------------------------------------------- 1 | package sq 2 | 3 | import ( 4 | "database/sql" 5 | "testing" 6 | "time" 7 | 8 | "github.com/bokwoon95/sq/internal/testutil" 9 | _ "github.com/mattn/go-sqlite3" 10 | ) 11 | 12 | var ACTOR = New[struct { 13 | TableStruct `sq:"actor"` 14 | ACTOR_ID NumberField 15 | FIRST_NAME StringField 16 | LAST_NAME StringField 17 | LAST_UPDATE TimeField 18 | }]("") 19 | 20 | type Actor struct { 21 | ActorID int 22 | FirstName string 23 | LastName string 24 | LastUpdate time.Time 25 | } 26 | 27 | func actorRowMapper(row *Row) Actor { 28 | var actor Actor 29 | actorID, _ := row.Value("actor.actor_id").(int64) 30 | actor.ActorID = int(actorID) 31 | actor.FirstName = row.StringField(ACTOR.FIRST_NAME) 32 | actor.LastName = row.StringField(ACTOR.LAST_NAME) 33 | actor.LastUpdate, _ = row.Value("actor.last_update").(time.Time) 34 | return actor 35 | } 36 | 37 | func 
actorRowMapperRawSQL(row *Row) Actor { 38 | result := make(map[string]any) 39 | values := row.Values() 40 | for i, column := range row.Columns() { 41 | result[column] = values[i] 42 | } 43 | var actor Actor 44 | actorID, _ := result["actor_id"].(int64) 45 | actor.ActorID = int(actorID) 46 | actor.FirstName, _ = result["first_name"].(string) 47 | actor.LastName, _ = result["last_name"].(string) 48 | actor.LastUpdate, _ = result["last_update"].(time.Time) 49 | return actor 50 | } 51 | 52 | func Test_substituteParams(t *testing.T) { 53 | t.Run("no params provided", func(t *testing.T) { 54 | t.Parallel() 55 | args := []any{1, 2, 3} 56 | params := map[string][]int{"one": {0}, "two": {1}, "three": {2}} 57 | gotArgs, err := substituteParams("", args, params, nil) 58 | if err != nil { 59 | t.Fatal(testutil.Callers(), err) 60 | } 61 | wantArgs := []any{1, 2, 3} 62 | if diff := testutil.Diff(gotArgs, wantArgs); diff != "" { 63 | t.Error(testutil.Callers(), diff) 64 | } 65 | }) 66 | 67 | t.Run("not all params provided", func(t *testing.T) { 68 | t.Parallel() 69 | args := []any{1, 2, 3} 70 | params := map[string][]int{"one": {0}, "two": {1}, "three": {2}} 71 | paramValues := Params{"one": "One", "two": "Two"} 72 | gotArgs, err := substituteParams("", args, params, paramValues) 73 | if err != nil { 74 | t.Fatal(testutil.Callers(), err) 75 | } 76 | wantArgs := []any{"One", "Two", 3} 77 | if diff := testutil.Diff(gotArgs, wantArgs); diff != "" { 78 | t.Error(testutil.Callers(), diff) 79 | } 80 | }) 81 | 82 | t.Run("params substituted", func(t *testing.T) { 83 | t.Parallel() 84 | type Data struct { 85 | id int 86 | name string 87 | } 88 | args := []any{ 89 | 0, 90 | sql.Named("one", 1), 91 | sql.Named("two", 2), 92 | 3, 93 | } 94 | params := map[string][]int{ 95 | "zero": {0}, 96 | "one": {1}, 97 | "two": {2}, 98 | "three": {3}, 99 | } 100 | paramValues := Params{ 101 | "one": "[one]", 102 | "two": "[two]", 103 | "three": "[three]", 104 | } 105 | wantArgs := []any{ 106 | 0, 107 | 
sql.Named("one", "[one]"), 108 | sql.Named("two", "[two]"), 109 | "[three]", 110 | } 111 | gotArgs, err := substituteParams("", args, params, paramValues) 112 | if err != nil { 113 | t.Fatal(testutil.Callers(), err) 114 | } 115 | if diff := testutil.Diff(gotArgs, wantArgs); diff != "" { 116 | t.Error(testutil.Callers(), diff) 117 | } 118 | }) 119 | } 120 | 121 | func Test_getFieldMappings(t *testing.T) { 122 | type TestTable struct { 123 | description string 124 | dialect string 125 | fields []Field 126 | scanDest []any 127 | wantFieldMappings string 128 | } 129 | 130 | var tests = []TestTable{{ 131 | description: "empty", 132 | wantFieldMappings: "", 133 | }, { 134 | description: "basic", 135 | fields: []Field{ 136 | Expr("actor_id"), 137 | Expr("first_name || {} || last_name", " "), 138 | Expr("last_update"), 139 | }, 140 | scanDest: []any{ 141 | &sql.NullInt64{}, 142 | &sql.NullString{}, 143 | &sql.NullTime{}, 144 | }, 145 | wantFieldMappings: "" + 146 | "\n 01. actor_id => *sql.NullInt64" + 147 | "\n 02. first_name || ' ' || last_name => *sql.NullString" + 148 | "\n 03. 
last_update => *sql.NullTime",
	}}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.description, func(t *testing.T) {
			t.Parallel()
			gotFieldMappings := getFieldMappings(tt.dialect, tt.fields, tt.scanDest)
			if diff := testutil.Diff(gotFieldMappings, tt.wantFieldMappings); diff != "" {
				t.Error(testutil.Callers(), diff)
			}
		})
	}
}

// TestFetchExec runs Exec, FetchOne and FetchAll (each fetch once through the
// query builder and once with raw SQL) against the database returned by newDB
// and compares every result with referenceActors.
func TestFetchExec(t *testing.T) {
	t.Parallel()
	db := newDB(t)

	var referenceActors = []Actor{
		{ActorID: 1, FirstName: "PENELOPE", LastName: "GUINESS", LastUpdate: time.Unix(1, 0).UTC()},
		{ActorID: 2, FirstName: "NICK", LastName: "WAHLBERG", LastUpdate: time.Unix(1, 0).UTC()},
		{ActorID: 3, FirstName: "ED", LastName: "CHASE", LastUpdate: time.Unix(1, 0).UTC()},
		{ActorID: 4, FirstName: "JENNIFER", LastName: "DAVIS", LastUpdate: time.Unix(1, 0).UTC()},
		{ActorID: 5, FirstName: "JOHNNY", LastName: "LOLLOBRIGIDA", LastUpdate: time.Unix(1, 0).UTC()},
	}

	// Exec: one call whose column mapper sets the four columns once per
	// reference actor; RowsAffected must equal the number of actors.
	res, err := Exec(Log(db), SQLite.
		InsertInto(ACTOR).
		ColumnValues(func(col *Column) {
			for _, actor := range referenceActors {
				col.SetInt(ACTOR.ACTOR_ID, actor.ActorID)
				col.SetString(ACTOR.FIRST_NAME, actor.FirstName)
				col.SetString(ACTOR.LAST_NAME, actor.LastName)
				col.SetTime(ACTOR.LAST_UPDATE, actor.LastUpdate)
			}
		}),
	)
	if err != nil {
		t.Fatal(testutil.Callers(), err)
	}
	if diff := testutil.Diff(res.RowsAffected, int64(len(referenceActors))); diff != "" {
		t.Fatal(testutil.Callers(), diff)
	}

	// FetchOne.
	actor, err := FetchOne(Log(db), SQLite.
		From(ACTOR).
		Where(ACTOR.ACTOR_ID.EqInt(1)),
		actorRowMapper,
	)
	if err != nil {
		t.Fatal(testutil.Callers(), err)
	}
	if diff := testutil.Diff(actor, referenceActors[0]); diff != "" {
		t.Fatal(testutil.Callers(), diff)
	}

	// FetchOne (Raw SQL).
	actor, err = FetchOne(Log(db),
		SQLite.Queryf("SELECT * FROM actor WHERE actor_id = {}", 1),
		actorRowMapperRawSQL,
	)
	if err != nil {
		t.Fatal(testutil.Callers(), err)
	}
	if diff := testutil.Diff(actor, referenceActors[0]); diff != "" {
		t.Fatal(testutil.Callers(), diff)
	}

	// FetchAll.
	actors, err := FetchAll(VerboseLog(db), SQLite.
		From(ACTOR).
		OrderBy(ACTOR.ACTOR_ID),
		actorRowMapper,
	)
	if err != nil {
		t.Fatal(testutil.Callers(), err)
	}
	if diff := testutil.Diff(actors, referenceActors); diff != "" {
		// Report the diff: err is nil on this path, so printing it would
		// hide what actually differed.
		t.Fatal(testutil.Callers(), diff)
	}

	// FetchAll (RawSQL).
	actors, err = FetchAll(VerboseLog(db),
		SQLite.Queryf("SELECT * FROM actor ORDER BY actor_id"),
		actorRowMapperRawSQL,
	)
	if err != nil {
		t.Fatal(testutil.Callers(), err)
	}
	if diff := testutil.Diff(actors, referenceActors); diff != "" {
		// Same as above: the mismatch is in diff, not in err.
		t.Fatal(testutil.Callers(), diff)
	}
}

func TestCompiledFetchExec(t *testing.T) {
	t.Parallel()
	db := newDB(t)
	var referenceActors = []Actor{
		{ActorID: 1, FirstName: "PENELOPE", LastName: "GUINESS", LastUpdate: time.Unix(1, 0).UTC()},
		{ActorID: 2, FirstName: "NICK", LastName: "WAHLBERG", LastUpdate: time.Unix(1, 0).UTC()},
		{ActorID: 3, FirstName: "ED", LastName: "CHASE", LastUpdate: time.Unix(1, 0).UTC()},
		{ActorID: 4, FirstName: "JENNIFER", LastName: "DAVIS", LastUpdate: time.Unix(1, 0).UTC()},
		{ActorID: 5, FirstName: "JOHNNY", LastName: "LOLLOBRIGIDA", LastUpdate: time.Unix(1, 0).UTC()},
	}

	// CompiledExec.
257 | compiledExec, err := CompileExec(SQLite. 258 | InsertInto(ACTOR). 259 | ColumnValues(func(col *Column) { 260 | col.Set(ACTOR.ACTOR_ID, IntParam("actor_id", 0)) 261 | col.Set(ACTOR.FIRST_NAME, StringParam("first_name", "")) 262 | col.Set(ACTOR.LAST_NAME, StringParam("last_name", "")) 263 | col.Set(ACTOR.LAST_UPDATE, TimeParam("last_update", time.Time{})) 264 | }), 265 | ) 266 | if err != nil { 267 | t.Fatal(testutil.Callers(), err) 268 | } 269 | for _, actor := range referenceActors { 270 | _, err = compiledExec.Exec(Log(db), Params{ 271 | "actor_id": actor.ActorID, 272 | "first_name": actor.FirstName, 273 | "last_name": actor.LastName, 274 | "last_update": actor.LastUpdate, 275 | }) 276 | if err != nil { 277 | t.Fatal(testutil.Callers(), err) 278 | } 279 | } 280 | 281 | // CompiledFetch FetchOne. 282 | compiledFetch, err := CompileFetch(SQLite. 283 | From(ACTOR). 284 | Where(ACTOR.ACTOR_ID.Eq(IntParam("actor_id", 0))), 285 | actorRowMapper, 286 | ) 287 | if err != nil { 288 | t.Fatal(testutil.Callers(), err) 289 | } 290 | actor, err := compiledFetch.FetchOne(Log(db), Params{"actor_id": 1}) 291 | if err != nil { 292 | t.Fatal(testutil.Callers(), err) 293 | } 294 | if diff := testutil.Diff(actor, referenceActors[0]); diff != "" { 295 | t.Fatal(testutil.Callers(), diff) 296 | } 297 | 298 | // CompiledFetch FetchOne (Raw SQL). 299 | compiledFetch, err = CompileFetch( 300 | SQLite.Queryf("SELECT * FROM actor WHERE actor_id = {actor_id}", IntParam("actor_id", 0)), 301 | actorRowMapperRawSQL, 302 | ) 303 | if err != nil { 304 | t.Fatal(testutil.Callers(), err) 305 | } 306 | actor, err = compiledFetch.FetchOne(Log(db), Params{"actor_id": 1}) 307 | if err != nil { 308 | t.Fatal(testutil.Callers(), err) 309 | } 310 | if diff := testutil.Diff(actor, referenceActors[0]); diff != "" { 311 | t.Fatal(testutil.Callers(), diff) 312 | } 313 | 314 | // CompiledFetch FetchAll. 315 | compiledFetch, err = CompileFetch(SQLite. 316 | From(ACTOR). 
317 | OrderBy(ACTOR.ACTOR_ID), 318 | actorRowMapper, 319 | ) 320 | if err != nil { 321 | t.Fatal(testutil.Callers(), err) 322 | } 323 | actors, err := compiledFetch.FetchAll(VerboseLog(db), nil) 324 | if err != nil { 325 | t.Fatal(testutil.Callers(), err) 326 | } 327 | if diff := testutil.Diff(actors, referenceActors); diff != "" { 328 | t.Fatal(testutil.Callers(), diff) 329 | } 330 | 331 | // CompiledFetch FetchAll (Raw SQL). 332 | compiledFetch, err = CompileFetch( 333 | SQLite.Queryf("SELECT * FROM actor ORDER BY actor_id"), 334 | actorRowMapperRawSQL, 335 | ) 336 | if err != nil { 337 | t.Fatal(testutil.Callers(), err) 338 | } 339 | actors, err = compiledFetch.FetchAll(VerboseLog(db), nil) 340 | if err != nil { 341 | t.Fatal(testutil.Callers(), err) 342 | } 343 | if diff := testutil.Diff(actors, referenceActors); diff != "" { 344 | t.Fatal(testutil.Callers(), diff) 345 | } 346 | } 347 | 348 | func TestPreparedFetchExec(t *testing.T) { 349 | t.Parallel() 350 | db := newDB(t) 351 | 352 | var referenceActors = []Actor{ 353 | {ActorID: 1, FirstName: "PENELOPE", LastName: "GUINESS", LastUpdate: time.Unix(1, 0).UTC()}, 354 | {ActorID: 2, FirstName: "NICK", LastName: "WAHLBERG", LastUpdate: time.Unix(1, 0).UTC()}, 355 | {ActorID: 3, FirstName: "ED", LastName: "CHASE", LastUpdate: time.Unix(1, 0).UTC()}, 356 | {ActorID: 4, FirstName: "JENNIFER", LastName: "DAVIS", LastUpdate: time.Unix(1, 0).UTC()}, 357 | {ActorID: 5, FirstName: "JOHNNY", LastName: "LOLLOBRIGIDA", LastUpdate: time.Unix(1, 0).UTC()}, 358 | } 359 | 360 | // PreparedExec. 361 | preparedExec, err := PrepareExec(Log(db), SQLite. 362 | InsertInto(ACTOR). 
363 | ColumnValues(func(col *Column) { 364 | col.Set(ACTOR.ACTOR_ID, IntParam("actor_id", 0)) 365 | col.Set(ACTOR.FIRST_NAME, StringParam("first_name", "")) 366 | col.Set(ACTOR.LAST_NAME, StringParam("last_name", "")) 367 | col.Set(ACTOR.LAST_UPDATE, TimeParam("last_update", time.Time{})) 368 | }), 369 | ) 370 | if err != nil { 371 | t.Fatal(testutil.Callers(), err) 372 | } 373 | for _, actor := range referenceActors { 374 | _, err = preparedExec.Exec(Params{ 375 | "actor_id": actor.ActorID, 376 | "first_name": actor.FirstName, 377 | "last_name": actor.LastName, 378 | "last_update": actor.LastUpdate, 379 | }) 380 | if err != nil { 381 | t.Fatal(testutil.Callers(), err) 382 | } 383 | } 384 | 385 | // PreparedFetch FetchOne. 386 | preparedFetch, err := PrepareFetch(Log(db), SQLite. 387 | From(ACTOR). 388 | Where(ACTOR.ACTOR_ID.Eq(IntParam("actor_id", 0))), 389 | actorRowMapper, 390 | ) 391 | if err != nil { 392 | t.Fatal(testutil.Callers(), err) 393 | } 394 | actor, err := preparedFetch.FetchOne(Params{"actor_id": 1}) 395 | if err != nil { 396 | t.Fatal(testutil.Callers(), err) 397 | } 398 | if diff := testutil.Diff(actor, referenceActors[0]); diff != "" { 399 | t.Fatal(testutil.Callers(), diff) 400 | } 401 | 402 | // PreparedFetch FetchOne (Raw SQL). 403 | preparedFetch, err = PrepareFetch( 404 | Log(db), 405 | SQLite.Queryf("SELECT * FROM actor WHERE actor_id = {actor_id}", IntParam("actor_id", 0)), 406 | actorRowMapperRawSQL, 407 | ) 408 | if err != nil { 409 | t.Fatal(testutil.Callers(), err) 410 | } 411 | actor, err = preparedFetch.FetchOne(Params{"actor_id": 1}) 412 | if err != nil { 413 | t.Fatal(testutil.Callers(), err) 414 | } 415 | if diff := testutil.Diff(actor, referenceActors[0]); diff != "" { 416 | t.Fatal(testutil.Callers(), diff) 417 | } 418 | 419 | // PreparedFetch FetchAll. 420 | preparedFetch, err = PrepareFetch(VerboseLog(db), SQLite. 421 | From(ACTOR). 
422 | OrderBy(ACTOR.ACTOR_ID), 423 | actorRowMapper, 424 | ) 425 | if err != nil { 426 | t.Fatal(testutil.Callers(), err) 427 | } 428 | actors, err := preparedFetch.FetchAll(nil) 429 | if err != nil { 430 | t.Fatal(testutil.Callers(), err) 431 | } 432 | if diff := testutil.Diff(actors, referenceActors); diff != "" { 433 | t.Fatal(testutil.Callers(), diff) 434 | } 435 | 436 | // PreparedFetch FetchAll (Raw SQL). 437 | preparedFetch, err = PrepareFetch( 438 | VerboseLog(db), 439 | SQLite.Queryf("SELECT * FROM actor ORDER BY actor_id"), 440 | actorRowMapperRawSQL, 441 | ) 442 | if err != nil { 443 | t.Fatal(testutil.Callers(), err) 444 | } 445 | actors, err = preparedFetch.FetchAll(nil) 446 | if err != nil { 447 | t.Fatal(testutil.Callers(), err) 448 | } 449 | if diff := testutil.Diff(actors, referenceActors); diff != "" { 450 | t.Fatal(testutil.Callers(), diff) 451 | } 452 | } 453 | 454 | func newDB(t *testing.T) *sql.DB { 455 | db, err := sql.Open("sqlite3", ":memory:") 456 | if err != nil { 457 | t.Fatal(testutil.Callers(), err) 458 | } 459 | _, err = db.Exec(`CREATE TABLE actor ( 460 | actor_id INTEGER PRIMARY KEY AUTOINCREMENT 461 | ,first_name TEXT NOT NULL 462 | ,last_name TEXT NOT NULL 463 | ,last_update DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP 464 | )`) 465 | if err != nil { 466 | t.Fatal(testutil.Callers(), err) 467 | } 468 | return db 469 | } 470 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/bokwoon95/sq 2 | 3 | go 1.19 4 | 5 | require ( 6 | github.com/denisenkom/go-mssqldb v0.12.3 7 | github.com/go-sql-driver/mysql v1.7.1 8 | github.com/google/go-cmp v0.5.9 9 | github.com/google/uuid v1.3.0 10 | github.com/lib/pq v1.10.9 11 | github.com/mattn/go-sqlite3 v1.14.16 12 | ) 13 | 14 | require ( 15 | github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect 16 | github.com/golang-sql/sqlexp v0.1.0 
// indirect 17 | golang.org/x/crypto v0.9.0 // indirect 18 | ) 19 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/Azure/azure-sdk-for-go/sdk/azcore v0.19.0/go.mod h1:h6H6c8enJmmocHUbLiiGY6sx7f9i+X3m1CHdd5c6Rdw= 2 | github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.11.0/go.mod h1:HcM1YX14R7CJcghJGOYCgdezslRSVzqwLf/q+4Y2r/0= 3 | github.com/Azure/azure-sdk-for-go/sdk/internal v0.7.0/go.mod h1:yqy467j36fJxcRV2TzfVZ1pCb5vxm4BtZPUdYWe/Xo8= 4 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 5 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/denisenkom/go-mssqldb v0.12.3 h1:pBSGx9Tq67pBOTLmxNuirNTeB8Vjmf886Kx+8Y+8shw= 7 | github.com/denisenkom/go-mssqldb v0.12.3/go.mod h1:k0mtMFOnU+AihqFxPMiF05rtiDrorD1Vrm1KEz5hxDo= 8 | github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= 9 | github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= 10 | github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= 11 | github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= 12 | github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA= 13 | github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= 14 | github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= 15 | github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= 16 | github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= 17 | github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 18 | github.com/google/uuid v1.3.0 
h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= 19 | github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 20 | github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= 21 | github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= 22 | github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= 23 | github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= 24 | github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8= 25 | github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4/go.mod h1:4OwLy04Bl9Ef3GJJCoec+30X3LQs/0/m4HFRt/2LUSA= 26 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 27 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 28 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 29 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 30 | golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= 31 | golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= 32 | golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= 33 | golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= 34 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 35 | golang.org/x/net v0.0.0-20210610132358-84b48f89b13b/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= 36 | golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= 37 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 38 | golang.org/x/sys 
v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 39 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 40 | golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 41 | golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 42 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 43 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 44 | golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 45 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 46 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 47 | gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 48 | gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= 49 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 50 | gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 51 | -------------------------------------------------------------------------------- /header.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bokwoon95/sq/eae3b0c03361f5b98ac6d0701c6aa71c94d4e4c2/header.png -------------------------------------------------------------------------------- /internal/googleuuid/googleuuid.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2009,2014 Google Inc. All rights reserved. 
2 | // 3 | // Redistribution and use in source and binary forms, with or without 4 | // modification, are permitted provided that the following conditions are 5 | // met: 6 | // 7 | // * Redistributions of source code must retain the above copyright 8 | // notice, this list of conditions and the following disclaimer. 9 | // * Redistributions in binary form must reproduce the above 10 | // copyright notice, this list of conditions and the following disclaimer 11 | // in the documentation and/or other materials provided with the 12 | // distribution. 13 | // * Neither the name of Google Inc. nor the names of its 14 | // contributors may be used to endorse or promote products derived from 15 | // this software without specific prior written permission. 16 | // 17 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 18 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 19 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 20 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 21 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 22 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 23 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 24 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 25 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | 29 | package googleuuid 30 | 31 | import ( 32 | "bytes" 33 | "encoding/hex" 34 | "errors" 35 | "fmt" 36 | "strings" 37 | ) 38 | 39 | // ParseBytes decodes b into a UUID or returns an error. Both the UUID form of 40 | // xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx and 41 | // urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx are decoded. 
42 | func ParseBytes(b []byte) (uuid [16]byte, err error) { 43 | switch len(b) { 44 | case 36: // xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx 45 | case 36 + 9: // urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx 46 | if !bytes.Equal(bytes.ToLower(b[:9]), []byte("urn:uuid:")) { 47 | return uuid, fmt.Errorf("invalid urn prefix: %q", b[:9]) 48 | } 49 | b = b[9:] 50 | case 36 + 2: // {xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx} 51 | b = b[1:] 52 | case 32: // xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx 53 | var ok bool 54 | for i := 0; i < 32; i += 2 { 55 | uuid[i/2], ok = xtob(b[i], b[i+1]) 56 | if !ok { 57 | return uuid, errors.New("invalid UUID format") 58 | } 59 | } 60 | return uuid, nil 61 | default: 62 | return uuid, fmt.Errorf("invalid UUID length: %d", len(b)) 63 | } 64 | // s is now at least 36 bytes long 65 | // it must be of the form xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx 66 | if b[8] != '-' || b[13] != '-' || b[18] != '-' || b[23] != '-' { 67 | return uuid, errors.New("invalid UUID format") 68 | } 69 | for i, x := range [16]int{ 70 | 0, 2, 4, 6, 71 | 9, 11, 72 | 14, 16, 73 | 19, 21, 74 | 24, 26, 28, 30, 32, 34, 75 | } { 76 | v, ok := xtob(b[x], b[x+1]) 77 | if !ok { 78 | return uuid, errors.New("invalid UUID format") 79 | } 80 | uuid[i] = v 81 | } 82 | return uuid, nil 83 | } 84 | 85 | func Parse(s string) (uuid [16]byte, err error) { 86 | if len(s) != 36 { 87 | if len(s) != 36+9 { 88 | return uuid, fmt.Errorf("invalid UUID length: %d", len(s)) 89 | } 90 | if strings.ToLower(s[:9]) != "urn:uuid:" { 91 | return uuid, fmt.Errorf("invalid urn prefix: %q", s[:9]) 92 | } 93 | s = s[9:] 94 | } 95 | if s[8] != '-' || s[13] != '-' || s[18] != '-' || s[23] != '-' { 96 | return uuid, errors.New("invalid UUID format") 97 | } 98 | for i, x := range [16]int{ 99 | 0, 2, 4, 6, 100 | 9, 11, 101 | 14, 16, 102 | 19, 21, 103 | 24, 26, 28, 30, 32, 34, 104 | } { 105 | v, ok := xtob(s[x], s[x+1]) 106 | if !ok { 107 | return uuid, errors.New("invalid UUID format") 108 | } 109 | uuid[i] = v 110 | } 111 | 
return uuid, nil 112 | } 113 | 114 | // xvalues returns the value of a byte as a hexadecimal digit or 255. 115 | var xvalues = [256]byte{ 116 | 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 117 | 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 118 | 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 119 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 255, 255, 255, 255, 255, 255, 120 | 255, 10, 11, 12, 13, 14, 15, 255, 255, 255, 255, 255, 255, 255, 255, 255, 121 | 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 122 | 255, 10, 11, 12, 13, 14, 15, 255, 255, 255, 255, 255, 255, 255, 255, 255, 123 | 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 124 | 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 125 | 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 126 | 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127 | 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 128 | 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 129 | 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 130 | 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 131 | 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 132 | } 133 | 134 | // xtob converts hex characters x1 and x2 into a byte. 
135 | func xtob(x1, x2 byte) (byte, bool) { 136 | b1 := xvalues[x1] 137 | b2 := xvalues[x2] 138 | return (b1 << 4) | b2, b1 != 255 && b2 != 255 139 | } 140 | 141 | // var buf [36]byte; encodeHex(buf[:], [16]byte{...}); string(buf[:]) 142 | func EncodeHex(dst []byte, uuid [16]byte) { 143 | hex.Encode(dst[:], uuid[:4]) 144 | dst[8] = '-' 145 | hex.Encode(dst[9:13], uuid[4:6]) 146 | dst[13] = '-' 147 | hex.Encode(dst[14:18], uuid[6:8]) 148 | dst[18] = '-' 149 | hex.Encode(dst[19:23], uuid[8:10]) 150 | dst[23] = '-' 151 | hex.Encode(dst[24:], uuid[10:]) 152 | } 153 | -------------------------------------------------------------------------------- /internal/testutil/testutil.go: -------------------------------------------------------------------------------- 1 | package testutil 2 | 3 | import ( 4 | "path/filepath" 5 | "reflect" 6 | "runtime" 7 | "strconv" 8 | "strings" 9 | 10 | "github.com/google/go-cmp/cmp" 11 | "github.com/google/go-cmp/cmp/cmpopts" 12 | ) 13 | 14 | func Diff[T any](got, want T) string { 15 | diff := cmp.Diff( 16 | got, want, 17 | cmp.Exporter(func(typ reflect.Type) bool { return true }), 18 | cmpopts.EquateEmpty(), 19 | ) 20 | if diff != "" { 21 | return "\n-got +want\n" + diff 22 | } 23 | return "" 24 | } 25 | 26 | func Callers() string { 27 | var pc [50]uintptr 28 | n := runtime.Callers(2, pc[:]) // skip runtime.Callers + Callers 29 | callsites := make([]string, 0, n) 30 | frames := runtime.CallersFrames(pc[:n]) 31 | for frame, more := frames.Next(); more; frame, more = frames.Next() { 32 | callsites = append(callsites, frame.File+":"+strconv.Itoa(frame.Line)) 33 | } 34 | callsites = callsites[:len(callsites)-1] // skip testing.tRunner 35 | if len(callsites) == 1 { 36 | return "" 37 | } 38 | var b strings.Builder 39 | b.WriteString("\n[") 40 | for i := len(callsites) - 1; i >= 0; i-- { 41 | if i < len(callsites)-1 { 42 | b.WriteString(" -> ") 43 | } 44 | b.WriteString(filepath.Base(callsites[i])) 45 | } 46 | b.WriteString("]") 47 | return 
b.String() 48 | } 49 | -------------------------------------------------------------------------------- /joins.go: -------------------------------------------------------------------------------- 1 | package sq 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | ) 8 | 9 | // Join operators. 10 | const ( 11 | JoinInner = "JOIN" 12 | JoinLeft = "LEFT JOIN" 13 | JoinRight = "RIGHT JOIN" 14 | JoinFull = "FULL JOIN" 15 | JoinCross = "CROSS JOIN" 16 | ) 17 | 18 | // JoinTable represents a join on a table. 19 | type JoinTable struct { 20 | JoinOperator string 21 | Table Table 22 | OnPredicate Predicate 23 | UsingFields []Field 24 | } 25 | 26 | // JoinUsing creates a new JoinTable with the USING operator. 27 | func JoinUsing(table Table, fields ...Field) JoinTable { 28 | return JoinTable{JoinOperator: JoinInner, Table: table, UsingFields: fields} 29 | } 30 | 31 | // Join creates a new JoinTable with the JOIN operator. 32 | func Join(table Table, predicates ...Predicate) JoinTable { 33 | return CustomJoin(JoinInner, table, predicates...) 34 | } 35 | 36 | // LeftJoin creates a new JoinTable with the LEFT JOIN operator. 37 | func LeftJoin(table Table, predicates ...Predicate) JoinTable { 38 | return CustomJoin(JoinLeft, table, predicates...) 39 | } 40 | 41 | // FullJoin creates a new JoinTable with the FULL JOIN operator. 42 | func FullJoin(table Table, predicates ...Predicate) JoinTable { 43 | return CustomJoin(JoinFull, table, predicates...) 44 | } 45 | 46 | // CrossJoin creates a new JoinTable with the CROSS JOIN operator. 47 | func CrossJoin(table Table) JoinTable { 48 | return CustomJoin(JoinCross, table) 49 | } 50 | 51 | // CustomJoin creates a new JoinTable with the a custom join operator. 
func CustomJoin(joinOperator string, table Table, predicates ...Predicate) JoinTable {
	switch len(predicates) {
	case 0:
		// No predicate given: leave OnPredicate unset.
		return JoinTable{JoinOperator: joinOperator, Table: table}
	case 1:
		// A single predicate is used as-is.
		return JoinTable{JoinOperator: joinOperator, Table: table, OnPredicate: predicates[0]}
	default:
		// Several predicates are combined with And.
		return JoinTable{JoinOperator: joinOperator, Table: table, OnPredicate: And(predicates...)}
	}
}

// WriteSQL implements the SQLWriter interface.
//
// It writes, in order: the join operator (an empty operator defaults to
// JoinInner), the table (wrapped in parentheses when the table is a Query),
// an optional "AS alias" and finally either an ON clause or a USING clause.
// An error is returned when a required predicate is missing, when the dialect
// is sqlite and the operator is JoinRight or JoinFull, when the table is nil,
// when a non-sqlite subquery has no alias, or when any nested WriteSQL fails.
func (join JoinTable) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
	// join is a value receiver, so defaulting the operator here does not
	// modify the caller's JoinTable.
	if join.JoinOperator == "" {
		join.JoinOperator = JoinInner
	}
	variadicPredicate, isVariadic := join.OnPredicate.(VariadicPredicate)
	// NOTE(review): when OnPredicate is nil the type assertion above yields a
	// zero VariadicPredicate, so the len(variadicPredicate.Predicates) clause
	// never changes the result. A non-nil VariadicPredicate with zero
	// predicates is therefore NOT treated as "no predicate" — confirm that
	// this is intended.
	hasNoPredicate := join.OnPredicate == nil && len(variadicPredicate.Predicates) == 0 && len(join.UsingFields) == 0
	if hasNoPredicate && (join.JoinOperator == JoinInner ||
		join.JoinOperator == JoinLeft ||
		join.JoinOperator == JoinRight ||
		join.JoinOperator == JoinFull) &&
		// exclude sqlite from this check because they allow join without predicate
		dialect != DialectSQLite {
		return fmt.Errorf("%s requires at least one predicate specified", join.JoinOperator)
	}
	if dialect == DialectSQLite && (join.JoinOperator == JoinRight || join.JoinOperator == JoinFull) {
		return fmt.Errorf("sqlite does not support %s", join.JoinOperator)
	}

	// JOIN
	// The operator is written before the nil-table check, so buf may contain
	// partial output when an error is returned.
	buf.WriteString(string(join.JoinOperator) + " ")
	if join.Table == nil {
		return fmt.Errorf("joining on a nil table")
	}

	// Table, wrapped in parentheses when it is a Query (i.e. a subquery).
	_, isQuery := join.Table.(Query)
	if isQuery {
		buf.WriteString("(")
	}
	err := join.Table.WriteSQL(ctx, dialect, buf, args, params)
	if err != nil {
		return err
	}
	if isQuery {
		buf.WriteString(")")
	}

	// AS
	if tableAlias := getAlias(join.Table); tableAlias != "" {
		buf.WriteString(" AS " + QuoteIdentifier(dialect, tableAlias) + quoteTableColumns(dialect, join.Table))
	} else if isQuery && dialect != DialectSQLite {
		// An unaliased subquery is only accepted for sqlite.
		return fmt.Errorf("%s %s subquery must have alias", dialect, join.JoinOperator)
	}

	if isVariadic {
		// ON VariadicPredicate
		buf.WriteString(" ON ")
		// Toplevel is set on the local copy obtained from the type assertion,
		// not on join.OnPredicate itself.
		variadicPredicate.Toplevel = true
		err = variadicPredicate.WriteSQL(ctx, dialect, buf, args, params)
		if err != nil {
			return err
		}
	} else if join.OnPredicate != nil {
		// ON Predicate
		buf.WriteString(" ON ")
		err = join.OnPredicate.WriteSQL(ctx, dialect, buf, args, params)
		if err != nil {
			return err
		}
	} else if len(join.UsingFields) > 0 {
		// USING Fields
		buf.WriteString(" USING (")
		err = writeFieldsWithPrefix(ctx, dialect, buf, args, params, join.UsingFields, "", false)
		if err != nil {
			return err
		}
		buf.WriteString(")")
	}
	return nil
}

// writeJoinTables writes each JoinTable in order, separated by a single
// space. An error from a JoinTable is wrapped with its 1-based position in
// joinTables.
func writeJoinTables(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int, joinTables []JoinTable) error {
	var err error
	for i, joinTable := range joinTables {
		if i > 0 {
			buf.WriteString(" ")
		}
		err = joinTable.WriteSQL(ctx, dialect, buf, args, params)
		if err != nil {
			return fmt.Errorf("join #%d: %w", i+1, err)
		}
	}
	return nil
}

--------------------------------------------------------------------------------
/joins_test.go:
--------------------------------------------------------------------------------
package sq

import "testing"

// TestJoinTables checks the SQL produced by the join constructors, the inputs
// that must be rejected, and that errors from nested SQLWriters are
// propagated.
func TestJoinTables(t *testing.T) {
	type ACTOR struct {
		TableStruct
		ACTOR_ID    NumberField
		FIRST_NAME  StringField
		LAST_NAME   StringField
		LAST_UPDATE TimeField
	}
	a := New[ACTOR]("a")

	// Cases that must render exactly wantQuery.
	tests := []TestTable{{
		description: "JoinUsing",
		item:        JoinUsing(a, a.FIRST_NAME, a.LAST_NAME),
		wantQuery:   "JOIN actor AS a USING (first_name, last_name)",
	}, {
		description: "Join without operator",
		item:        CustomJoin("", a, a.ACTOR_ID.Eq(a.ACTOR_ID), a.FIRST_NAME.Ne(a.LAST_NAME)),
		wantQuery:   "JOIN actor AS a ON a.actor_id = a.actor_id AND a.first_name <> a.last_name",
	}, {
		description: "Join",
		item:        Join(a, a.ACTOR_ID.Eq(a.ACTOR_ID)),
		wantQuery:   "JOIN actor AS a ON a.actor_id = a.actor_id",
	}, {
		description: "LeftJoin",
		item:        LeftJoin(a, a.ACTOR_ID.Eq(a.ACTOR_ID)),
		wantQuery:   "LEFT JOIN actor AS a ON a.actor_id = a.actor_id",
	}, {
		description: "Right Join",
		item:        JoinTable{JoinOperator: JoinRight, Table: a, OnPredicate: a.ACTOR_ID.Eq(a.ACTOR_ID)},
		wantQuery:   "RIGHT JOIN actor AS a ON a.actor_id = a.actor_id",
	}, {
		description: "FullJoin",
		item:        FullJoin(a, a.ACTOR_ID.Eq(a.ACTOR_ID)),
		wantQuery:   "FULL JOIN actor AS a ON a.actor_id = a.actor_id",
	}, {
		description: "CrossJoin",
		item:        CrossJoin(a),
		wantQuery:   "CROSS JOIN actor AS a",
	}}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.description, func(t *testing.T) {
			t.Parallel()
			tt.assert(t)
		})
	}

	// Cases that are checked with assertNotOK.
	notOKTests := []TestTable{{
		description: "full join has no predicate",
		item:        FullJoin(a),
	}, {
		description: "sqlite does not support full join",
		dialect:     DialectSQLite,
		item:        FullJoin(a, Expr("TRUE")),
	}, {
		description: "table is nil",
		item:        Join(nil, Expr("TRUE")),
	}, {
		description: "UsingField returns err",
		item:        JoinUsing(a, nil),
	}}

	for _, tt := range notOKTests {
		tt := tt
		t.Run(tt.description, func(t *testing.T) {
			t.Parallel()
			tt.assertNotOK(t)
		})
	}

	// Cases where a nested FaultySQL must surface ErrFaultySQL.
	errTests := []TestTable{{
		description: "table err",
		item:        Join(FaultySQL{}, a.ACTOR_ID.Eq(a.ACTOR_ID)),
	}, {
		description: "VariadicPredicate err",
		item:        Join(a, And(FaultySQL{})),
	}, {
		description: "predicate err",
		item:        Join(a, FaultySQL{}),
	}}

	for _, tt := range errTests {
		tt := tt
		t.Run(tt.description, func(t *testing.T) {
			t.Parallel()
			tt.assertErr(t, ErrFaultySQL)
		})
	}
}

--------------------------------------------------------------------------------
/logger.go:
--------------------------------------------------------------------------------
package sq

import (
	"bytes"
	"context"
	"database/sql"
	"fmt"
	"io"
	"log"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"sync/atomic"
	"time"
)

// QueryStats represents the statistics from running a query.
type QueryStats struct {
	// Dialect of the query.
	Dialect string

	// Query string.
	Query string

	// Args slice provided with the query string.
	Args []any

	// Params maps param names back to arguments in the args slice (by index).
	Params map[string][]int

	// Err is the error from running the query.
	Err error

	// RowCount from running the query. Not valid for Exec().
	RowCount sql.NullInt64

	// RowsAffected by running the query. Not valid for
	// FetchOne/FetchAll/FetchCursor.
	RowsAffected sql.NullInt64

	// LastInsertId of the query.
	LastInsertId sql.NullInt64

	// Exists is the result of FetchExists().
	Exists sql.NullBool

	// When the query started at.
	StartedAt time.Time

	// Time taken by the query.
	TimeTaken time.Duration

	// The caller file where the query was invoked.
	CallerFile string

	// The line in the caller file that invoked the query.
	CallerLine int

	// The name of the function where the query was invoked.
	CallerFunction string

	// The results from running the query (if it was provided).
	Results string
}

// LogSettings are the various log settings taken into account when producing
// the QueryStats.
type LogSettings struct {
	// Dispatch logging asynchronously (logs may arrive out of order which can be confusing, but it won't block function calls).
	LogAsynchronously bool

	// Include time taken by the query.
	IncludeTime bool

	// Include caller (filename and line number).
	IncludeCaller bool

	// Include fetched results. The value is an int rather than a bool;
	// presumably it is the maximum number of rows to include — confirm
	// against the code that fills QueryStats.Results.
	IncludeResults int
}

// SqLogger represents a logger for the sq package.
type SqLogger interface {
	// SqLogSettings should populate a LogSettings struct, which influences
	// what is added into the QueryStats.
	SqLogSettings(context.Context, *LogSettings)

	// SqLogQuery logs a query for the given QueryStats.
	SqLogQuery(context.Context, QueryStats)
}

// sqLogger is the SqLogger implementation returned by NewLogger. It writes
// through a *log.Logger and is configured by a LoggerConfig.
type sqLogger struct {
	logger *log.Logger
	config LoggerConfig
}

// LoggerConfig is the config used for the sq logger.
type LoggerConfig struct {
	// Dispatch logging asynchronously (logs may arrive out of order which can be confusing, but it won't block function calls).
	LogAsynchronously bool

	// Show time taken by the query.
	ShowTimeTaken bool

	// Show caller (filename and line number).
	ShowCaller bool

	// Show fetched results. Results are only logged when this is greater
	// than zero.
	ShowResults int

	// If true, logs are shown as plaintext (no color).
	NoColor bool

	// Verbose query interpolation, which shows the query before and after
	// interpolating query arguments. The logged query is interpolated by
	// default, InterpolateVerbose only controls whether the query before
	// interpolation is shown. To disable query interpolation entirely, look at
	// HideArgs.
	InterpolateVerbose bool

	// Explicitly hides arguments when logging the query (only the query
	// placeholders will be shown).
	HideArgs bool
}

// Compile-time check that *sqLogger satisfies SqLogger.
var _ SqLogger = (*sqLogger)(nil)

// defaultLogger is the logger used by Log: it writes to os.Stdout and shows
// the time taken and the caller.
var defaultLogger = NewLogger(os.Stdout, "", log.LstdFlags, LoggerConfig{
	ShowTimeTaken: true,
	ShowCaller:    true,
})

// verboseLogger is the logger used by VerboseLog: like defaultLogger, but it
// also shows results (ShowResults is 5) and the query both before and after
// interpolation.
var verboseLogger = NewLogger(os.Stdout, "", log.LstdFlags, LoggerConfig{
	ShowTimeTaken:      true,
	ShowCaller:         true,
	ShowResults:        5,
	InterpolateVerbose: true,
})

// NewLogger returns a new SqLogger. The w, prefix and flag arguments are
// passed to log.New unchanged.
func NewLogger(w io.Writer, prefix string, flag int, config LoggerConfig) SqLogger {
	return &sqLogger{
		logger: log.New(w, prefix, flag),
		config: config,
	}
}

// SqLogSettings implements the SqLogger interface. It copies the relevant
// LoggerConfig fields into settings.
func (l *sqLogger) SqLogSettings(ctx context.Context, settings *LogSettings) {
	settings.LogAsynchronously = l.config.LogAsynchronously
	settings.IncludeTime = l.config.ShowTimeTaken
	settings.IncludeCaller = l.config.ShowCaller
	settings.IncludeResults = l.config.ShowResults
}

// SqLogQuery implements the SqLogger interface.
158 | func (l *sqLogger) SqLogQuery(ctx context.Context, queryStats QueryStats) { 159 | var reset, red, green, blue, purple string 160 | envNoColor, _ := strconv.ParseBool(os.Getenv("NO_COLOR")) 161 | if !l.config.NoColor && !envNoColor { 162 | reset = colorReset 163 | red = colorRed 164 | green = colorGreen 165 | blue = colorBlue 166 | purple = colorPurple 167 | } 168 | buf := bufpool.Get().(*bytes.Buffer) 169 | buf.Reset() 170 | defer bufpool.Put(buf) 171 | if queryStats.Err == nil { 172 | buf.WriteString(green + "[OK]" + reset) 173 | } else { 174 | buf.WriteString(red + "[FAIL]" + reset) 175 | } 176 | if l.config.HideArgs { 177 | buf.WriteString(" " + queryStats.Query + ";") 178 | } else if !l.config.InterpolateVerbose { 179 | if queryStats.Err != nil { 180 | buf.WriteString(" " + queryStats.Query + ";") 181 | if len(queryStats.Args) > 0 { 182 | buf.WriteString(" [") 183 | } 184 | for i := 0; i < len(queryStats.Args); i++ { 185 | if i > 0 { 186 | buf.WriteString(", ") 187 | } 188 | buf.WriteString(fmt.Sprintf("%#v", queryStats.Args[i])) 189 | } 190 | if len(queryStats.Args) > 0 { 191 | buf.WriteString("]") 192 | } 193 | } else { 194 | query, err := Sprintf(queryStats.Dialect, queryStats.Query, queryStats.Args) 195 | if err != nil { 196 | query += " " + err.Error() 197 | } 198 | buf.WriteString(" " + query + ";") 199 | } 200 | } 201 | if queryStats.Err != nil { 202 | errStr := queryStats.Err.Error() 203 | if i := strings.IndexByte(errStr, '\n'); i < 0 { 204 | buf.WriteString(blue + " err" + reset + "={" + queryStats.Err.Error() + "}") 205 | } 206 | } 207 | if l.config.ShowTimeTaken { 208 | buf.WriteString(blue + " timeTaken" + reset + "=" + queryStats.TimeTaken.String()) 209 | } 210 | if queryStats.RowCount.Valid { 211 | buf.WriteString(blue + " rowCount" + reset + "=" + strconv.FormatInt(queryStats.RowCount.Int64, 10)) 212 | } 213 | if queryStats.RowsAffected.Valid { 214 | buf.WriteString(blue + " rowsAffected" + reset + "=" + 
strconv.FormatInt(queryStats.RowsAffected.Int64, 10)) 215 | } 216 | if queryStats.LastInsertId.Valid { 217 | buf.WriteString(blue + " lastInsertId" + reset + "=" + strconv.FormatInt(queryStats.LastInsertId.Int64, 10)) 218 | } 219 | if queryStats.Exists.Valid { 220 | buf.WriteString(blue + " exists" + reset + "=" + strconv.FormatBool(queryStats.Exists.Bool)) 221 | } 222 | if l.config.ShowCaller { 223 | buf.WriteString(blue + " caller" + reset + "=" + queryStats.CallerFile + ":" + strconv.Itoa(queryStats.CallerLine) + ":" + filepath.Base(queryStats.CallerFunction)) 224 | } 225 | if !l.config.HideArgs && l.config.InterpolateVerbose { 226 | buf.WriteString("\n" + purple + "----[ Executing query ]----" + reset) 227 | buf.WriteString("\n" + queryStats.Query + "; " + fmt.Sprintf("%#v", queryStats.Args)) 228 | buf.WriteString("\n" + purple + "----[ with bind values ]----" + reset) 229 | query, err := Sprintf(queryStats.Dialect, queryStats.Query, queryStats.Args) 230 | query += ";" 231 | if err != nil { 232 | query += " " + err.Error() 233 | } 234 | buf.WriteString("\n" + query) 235 | } 236 | if l.config.ShowResults > 0 && queryStats.Err == nil { 237 | buf.WriteString("\n" + purple + "----[ Fetched result ]----" + reset) 238 | buf.WriteString(queryStats.Results) 239 | if queryStats.RowCount.Int64 > int64(l.config.ShowResults) { 240 | buf.WriteString("\n...\n(Fetched " + strconv.FormatInt(queryStats.RowCount.Int64, 10) + " rows)") 241 | } 242 | } 243 | if buf.Len() > 0 { 244 | l.logger.Println(buf.String()) 245 | } 246 | } 247 | 248 | // Log wraps a DB and adds logging to it. 249 | func Log(db DB) interface { 250 | DB 251 | SqLogger 252 | } { 253 | return struct { 254 | DB 255 | SqLogger 256 | }{DB: db, SqLogger: defaultLogger} 257 | } 258 | 259 | // VerboseLog wraps a DB and adds verbose logging to it. 
func VerboseLog(db DB) interface {
	DB
	SqLogger
} {
	// The returned value embeds db together with the package's verbose logger.
	return struct {
		DB
		SqLogger
	}{DB: db, SqLogger: verboseLogger}
}

// defaultLogSettings holds the function stored by SetDefaultLogSettings.
var defaultLogSettings atomic.Value

// SetDefaultLogSettings sets the function to configure the default
// LogSettings. This value is not used unless SetDefaultLogQuery is also
// configured.
func SetDefaultLogSettings(logSettings func(context.Context, *LogSettings)) {
	defaultLogSettings.Store(logSettings)
}

// defaultLogQuery holds the function stored by SetDefaultLogQuery.
var defaultLogQuery atomic.Value

// SetDefaultLogQuery sets the default logging function to call for all
// queries (if a logger is not explicitly passed in).
func SetDefaultLogQuery(logQuery func(context.Context, QueryStats)) {
	defaultLogQuery.Store(logQuery)
}

// sqLogStruct adapts a pair of plain functions to the SqLogger interface.
// Either function may be nil, in which case the corresponding method is a
// no-op.
type sqLogStruct struct {
	logSettings func(context.Context, *LogSettings)
	logQuery    func(context.Context, QueryStats)
}

// Compile-time check that *sqLogStruct satisfies SqLogger.
var _ SqLogger = (*sqLogStruct)(nil)

// SqLogSettings implements the SqLogger interface. It does nothing when no
// logSettings function is set.
func (l *sqLogStruct) SqLogSettings(ctx context.Context, logSettings *LogSettings) {
	if l.logSettings == nil {
		return
	}
	l.logSettings(ctx, logSettings)
}

// SqLogQuery implements the SqLogger interface. It does nothing when no
// logQuery function is set.
func (l *sqLogStruct) SqLogQuery(ctx context.Context, queryStats QueryStats) {
	if l.logQuery == nil {
		return
	}
	l.logQuery(ctx, queryStats)
}

// ANSI escape sequences used to color log output.
const (
	colorReset  = "\x1b[0m"
	colorRed    = "\x1b[91m"
	colorGreen  = "\x1b[92m"
	colorYellow = "\x1b[93m"
	colorBlue   = "\x1b[94m"
	colorPurple = "\x1b[95m"
	colorCyan   = "\x1b[96m"
	// NOTE(review): colorGray has the same code as colorWhite (97); a distinct
	// gray was presumably intended — confirm.
	colorGray  = "\x1b[97m"
	colorWhite = "\x1b[97m"
)

--------------------------------------------------------------------------------
/logger_test.go:
--------------------------------------------------------------------------------
package sq

import (
	"bytes"
	"context"
	"database/sql"
	"fmt"
	"log"
	"testing"
	"time"

	"github.com/bokwoon95/sq/internal/testutil"
)

// TestLogger checks the LogSettings reported by Log and VerboseLog and the
// exact text written by sqLogger.SqLogQuery for a range of QueryStats and
// LoggerConfig values.
func TestLogger(t *testing.T) {
	type TT struct {
		description string
		ctx         context.Context
		stats       QueryStats
		config      LoggerConfig
		wantOutput  string
	}

	// assert logs tt.stats through a logger that writes into a buffer (no
	// prefix, no flags) and compares the buffer with tt.wantOutput.
	assert := func(t *testing.T, tt TT) {
		if tt.ctx == nil {
			tt.ctx = context.Background()
		}
		buf := &bytes.Buffer{}
		logger := sqLogger{
			logger: log.New(buf, "", 0),
			config: tt.config,
		}
		logger.SqLogQuery(tt.ctx, tt.stats)
		if diff := testutil.Diff(buf.String(), tt.wantOutput); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
	}

	t.Run("Log VerboseLog", func(t *testing.T) {
		t.Parallel()
		var logSettings LogSettings
		Log(nil).SqLogSettings(context.Background(), &logSettings)
		diff := testutil.Diff(logSettings, LogSettings{
			LogAsynchronously: false,
			IncludeTime:       true,
			IncludeCaller:     true,
			IncludeResults:    0,
		})
		if diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		VerboseLog(nil).SqLogSettings(context.Background(), &logSettings)
		diff = testutil.Diff(logSettings, LogSettings{
			LogAsynchronously: false,
			IncludeTime:       true,
			IncludeCaller:     true,
			IncludeResults:    5,
		})
		if diff != "" {
			t.Error(testutil.Callers(), diff)
		}
	})

	t.Run("no color", func(t *testing.T) {
		var tt TT
		tt.config.NoColor = true
		tt.stats.Query = "SELECT 1"
		tt.wantOutput = "[OK] SELECT 1;\n"
		assert(t, tt)
	})

	// The expected outputs below contain color escape codes, so these cases
	// assume colors are not disabled through the NO_COLOR environment
	// variable.
	tests := []TT{{
		description: "err",
		stats: QueryStats{
			Query: "SELECT 1",
			Err:   fmt.Errorf("lorem ipsum"),
		},
		wantOutput: "\x1b[91m[FAIL]\x1b[0m SELECT 1;\x1b[94m err\x1b[0m={lorem ipsum}\n",
	}, {
		description: "HideArgs",
		config:      LoggerConfig{HideArgs: true},
		stats: QueryStats{
			Query: "SELECT ?", Args: []any{1},
		},
		wantOutput: "\x1b[92m[OK]\x1b[0m SELECT ?;\n",
	}, {
		description: "RowCount",
		stats: QueryStats{
			Query:    "SELECT 1",
			RowCount: sql.NullInt64{Valid: true, Int64: 3},
		},
		wantOutput: "\x1b[92m[OK]\x1b[0m SELECT 1;\x1b[94m rowCount\x1b[0m=3\n",
	}, {
		description: "RowsAffected",
		stats: QueryStats{
			Query:        "SELECT 1",
			RowsAffected: sql.NullInt64{Valid: true, Int64: 5},
		},
		wantOutput: "\x1b[92m[OK]\x1b[0m SELECT 1;\x1b[94m rowsAffected\x1b[0m=5\n",
	}, {
		description: "LastInsertId",
		stats: QueryStats{
			Query:        "SELECT 1",
			LastInsertId: sql.NullInt64{Valid: true, Int64: 7},
		},
		wantOutput: "\x1b[92m[OK]\x1b[0m SELECT 1;\x1b[94m lastInsertId\x1b[0m=7\n",
	}, {
		description: "Exists",
		stats: QueryStats{
			Query:  "SELECT EXISTS (SELECT 1)",
			Exists: sql.NullBool{Valid: true, Bool: true},
		},
		wantOutput: "\x1b[92m[OK]\x1b[0m SELECT EXISTS (SELECT 1);\x1b[94m exists\x1b[0m=true\n",
	}, {
		description: "ShowCaller",
		config:      LoggerConfig{ShowCaller: true},
		stats: QueryStats{
			Query:          "SELECT 1",
			CallerFile:     "file.go",
			CallerLine:     22,
			CallerFunction: "someFunc",
		},
		wantOutput: "\x1b[92m[OK]\x1b[0m SELECT 1;\x1b[94m caller\x1b[0m=file.go:22:someFunc\n",
	}, {
		description: "Verbose",
		config:      LoggerConfig{InterpolateVerbose: true, ShowTimeTaken: true},
		stats: QueryStats{
			Query: "SELECT ?, ?", Args: []any{1, "bob"},
			TimeTaken: 300 * time.Millisecond,
		},
		wantOutput: "\x1b[92m[OK]\x1b[0m\x1b[94m timeTaken\x1b[0m=300ms" +
			"\n\x1b[95m----[ Executing query ]----\x1b[0m" +
			"\nSELECT ?, ?; []interface {}{1, \"bob\"}" +
			"\n\x1b[95m----[ with bind values ]----\x1b[0m" +
			"\nSELECT 1, 'bob';\n",
	}, {
		description: "ShowResults",
		config:      LoggerConfig{ShowResults: 1},
		stats: QueryStats{
			Query:   "SELECT 1",
			Results: "\nlorem ipsum dolor sit amet",
		},
		wantOutput: "\x1b[92m[OK]\x1b[0m SELECT 1;" +
			"\n\x1b[95m----[ Fetched result ]----\x1b[0m" +
			"\nlorem ipsum dolor sit amet\n",
	}}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.description, func(t *testing.T) {
			assert(t, tt)
		})
	}
}

--------------------------------------------------------------------------------
/misc.go:
--------------------------------------------------------------------------------
package sq

import (
	"bytes"
	"context"
	"fmt"
	"strings"
)

// ValueExpression represents an SQL value that is passed in as an argument to
// a prepared query.
type ValueExpression struct {
	value any
	alias string
}

// Compile-time check of the interfaces *ValueExpression must satisfy.
var _ interface {
	Field
	Predicate
	Any
} = (*ValueExpression)(nil)

// Value returns a new ValueExpression.
func Value(value any) ValueExpression { return ValueExpression{value: value} }

// WriteSQL implements the SQLWriter interface. It delegates to WriteValue
// with the wrapped value.
func (e ValueExpression) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error {
	return WriteValue(ctx, dialect, buf, args, params, e.value)
}

// As returns a new ValueExpression with the given alias.
func (e ValueExpression) As(alias string) ValueExpression {
	e.alias = alias
	return e
}

// In returns a 'expr IN (val)' Predicate.
func (e ValueExpression) In(val any) Predicate { return In(e.value, val) }

// Eq returns a 'expr = val' Predicate.
func (e ValueExpression) Eq(val any) Predicate { return Eq(e.value, val) }

// Ne returns a 'expr <> val' Predicate.
func (e ValueExpression) Ne(val any) Predicate { return Ne(e.value, val) }

// Lt returns a 'expr < val' Predicate.
47 | func (e ValueExpression) Lt(val any) Predicate { return Lt(e.value, val) } 48 | 49 | // Le returns a 'expr <= val' Predicate. 50 | func (e ValueExpression) Le(val any) Predicate { return Le(e.value, val) } 51 | 52 | // Gt returns a 'expr > val' Predicate. 53 | func (e ValueExpression) Gt(val any) Predicate { return Gt(e.value, val) } 54 | 55 | // Ge returns a 'expr >= val' Predicate. 56 | func (e ValueExpression) Ge(val any) Predicate { return Ge(e.value, val) } 57 | 58 | // GetAlias returns the alias of the ValueExpression. 59 | func (e ValueExpression) GetAlias() string { return e.alias } 60 | 61 | // IsField implements the Field interface. 62 | func (e ValueExpression) IsField() {} 63 | 64 | // IsArray implements the Array interface. 65 | func (e ValueExpression) IsArray() {} 66 | 67 | // IsBinary implements the Binary interface. 68 | func (e ValueExpression) IsBinary() {} 69 | 70 | // IsBoolean implements the Boolean interface. 71 | func (e ValueExpression) IsBoolean() {} 72 | 73 | // IsEnum implements the Enum interface. 74 | func (e ValueExpression) IsEnum() {} 75 | 76 | // IsJSON implements the JSON interface. 77 | func (e ValueExpression) IsJSON() {} 78 | 79 | // IsNumber implements the Number interface. 80 | func (e ValueExpression) IsNumber() {} 81 | 82 | // IsString implements the String interface. 83 | func (e ValueExpression) IsString() {} 84 | 85 | // IsTime implements the Time interfaces. 86 | func (e ValueExpression) IsTime() {} 87 | 88 | // IsUUID implements the UUID interface. 89 | func (e ValueExpression) IsUUID() {} 90 | 91 | // LiteralValue represents an SQL value literally interpolated into the query. 92 | // Doing so potentially exposes the query to SQL injection so only do this for 93 | // values that you trust e.g. literals and constants. 
94 | type LiteralValue struct { 95 | value any 96 | alias string 97 | } 98 | 99 | var _ interface { 100 | Field 101 | Predicate 102 | Binary 103 | Boolean 104 | Number 105 | String 106 | Time 107 | } = (*LiteralValue)(nil) 108 | 109 | // Literal returns a new LiteralValue. 110 | func Literal(value any) LiteralValue { return LiteralValue{value: value} } 111 | 112 | // WriteSQL implements the SQLWriter interface. 113 | func (v LiteralValue) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { 114 | s, err := Sprint(dialect, v.value) 115 | if err != nil { 116 | return err 117 | } 118 | buf.WriteString(s) 119 | return nil 120 | } 121 | 122 | // As returns a new LiteralValue with the given alias. 123 | func (v LiteralValue) As(alias string) LiteralValue { 124 | v.alias = alias 125 | return v 126 | } 127 | 128 | // In returns a 'literal IN (val)' Predicate. 129 | func (v LiteralValue) In(val any) Predicate { return In(v, val) } 130 | 131 | // Eq returns a 'literal = val' Predicate. 132 | func (v LiteralValue) Eq(val any) Predicate { return Eq(v, val) } 133 | 134 | // Ne returns a 'literal <> val' Predicate. 135 | func (v LiteralValue) Ne(val any) Predicate { return Ne(v, val) } 136 | 137 | // Lt returns a 'literal < val' Predicate. 138 | func (v LiteralValue) Lt(val any) Predicate { return Lt(v, val) } 139 | 140 | // Le returns a 'literal <= val' Predicate. 141 | func (v LiteralValue) Le(val any) Predicate { return Le(v, val) } 142 | 143 | // Gt returns a 'literal > val' Predicate. 144 | func (v LiteralValue) Gt(val any) Predicate { return Gt(v, val) } 145 | 146 | // Ge returns a 'literal >= val' Predicate. 147 | func (v LiteralValue) Ge(val any) Predicate { return Ge(v, val) } 148 | 149 | // GetAlias returns the alias of the LiteralValue. 150 | func (v LiteralValue) GetAlias() string { return v.alias } 151 | 152 | // IsField implements the Field interface. 
153 | func (v LiteralValue) IsField() {} 154 | 155 | // IsBinary implements the Binary interface. 156 | func (v LiteralValue) IsBinary() {} 157 | 158 | // IsBoolean implements the Boolean interface. 159 | func (v LiteralValue) IsBoolean() {} 160 | 161 | // IsNumber implements the Number interface. 162 | func (v LiteralValue) IsNumber() {} 163 | 164 | // IsString implements the String interface. 165 | func (v LiteralValue) IsString() {} 166 | 167 | // IsTime implements the Time interfaces. 168 | func (v LiteralValue) IsTime() {} 169 | 170 | // DialectExpression represents an SQL expression that renders differently 171 | // depending on the dialect. 172 | type DialectExpression struct { 173 | Default any 174 | Cases DialectCases 175 | } 176 | 177 | // DialectCases is a slice of DialectCases. 178 | type DialectCases = []DialectCase 179 | 180 | // DialectCase holds the result to be used for a given dialect in a 181 | // DialectExpression. 182 | type DialectCase struct { 183 | Dialect string 184 | Result any 185 | } 186 | 187 | var _ interface { 188 | Table 189 | Field 190 | Predicate 191 | Any 192 | } = (*DialectExpression)(nil) 193 | 194 | // DialectValue returns a new DialectExpression. The value passed in is used as 195 | // the default. 196 | func DialectValue(value any) DialectExpression { 197 | return DialectExpression{Default: value} 198 | } 199 | 200 | // DialectExpr returns a new DialectExpression. The expression passed in is 201 | // used as the default. 202 | func DialectExpr(format string, values ...any) DialectExpression { 203 | return DialectExpression{Default: Expr(format, values...)} 204 | } 205 | 206 | // WriteSQL implements the SQLWriter interface. 
207 | func (e DialectExpression) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { 208 | for _, Case := range e.Cases { 209 | if dialect == Case.Dialect { 210 | return WriteValue(ctx, dialect, buf, args, params, Case.Result) 211 | } 212 | } 213 | return WriteValue(ctx, dialect, buf, args, params, e.Default) 214 | } 215 | 216 | // DialectValue adds a new dialect-value pair to the DialectExpression. 217 | func (e DialectExpression) DialectValue(dialect string, value any) DialectExpression { 218 | e.Cases = append(e.Cases, DialectCase{Dialect: dialect, Result: value}) 219 | return e 220 | } 221 | 222 | // DialectExpr adds a new dialect-expression pair to the DialectExpression. 223 | func (e DialectExpression) DialectExpr(dialect string, format string, values ...any) DialectExpression { 224 | e.Cases = append(e.Cases, DialectCase{Dialect: dialect, Result: Expr(format, values...)}) 225 | return e 226 | } 227 | 228 | // IsTable implements the Table interface. 229 | func (e DialectExpression) IsTable() {} 230 | 231 | // IsField implements the Field interface. 232 | func (e DialectExpression) IsField() {} 233 | 234 | // IsArray implements the Array interface. 235 | func (e DialectExpression) IsArray() {} 236 | 237 | // IsBinary implements the Binary interface. 238 | func (e DialectExpression) IsBinary() {} 239 | 240 | // IsBoolean implements the Boolean interface. 241 | func (e DialectExpression) IsBoolean() {} 242 | 243 | // IsEnum implements the Enum interface. 244 | func (e DialectExpression) IsEnum() {} 245 | 246 | // IsJSON implements the JSON interface. 247 | func (e DialectExpression) IsJSON() {} 248 | 249 | // IsNumber implements the Number interface. 250 | func (e DialectExpression) IsNumber() {} 251 | 252 | // IsString implements the String interface. 253 | func (e DialectExpression) IsString() {} 254 | 255 | // IsTime implements the Time interface. 
256 | func (e DialectExpression) IsTime() {} 257 | 258 | // IsUUID implements the UUID interface. 259 | func (e DialectExpression) IsUUID() {} 260 | 261 | // CaseExpression represents an SQL CASE expression. 262 | type CaseExpression struct { 263 | alias string 264 | Cases PredicateCases 265 | Default any 266 | } 267 | 268 | // PredicateCases is a slice of PredicateCases. 269 | type PredicateCases = []PredicateCase 270 | 271 | // PredicateCase holds the result to be used for a given predicate in a 272 | // CaseExpression. 273 | type PredicateCase struct { 274 | Predicate Predicate 275 | Result any 276 | } 277 | 278 | var _ interface { 279 | Field 280 | Any 281 | } = (*CaseExpression)(nil) 282 | 283 | // CaseWhen returns a new CaseExpression. 284 | func CaseWhen(predicate Predicate, result any) CaseExpression { 285 | return CaseExpression{ 286 | Cases: PredicateCases{{Predicate: predicate, Result: result}}, 287 | } 288 | } 289 | 290 | // WriteSQL implements the SQLWriter interface. 291 | func (e CaseExpression) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { 292 | buf.WriteString("CASE") 293 | if len(e.Cases) == 0 { 294 | return fmt.Errorf("CaseExpression empty") 295 | } 296 | var err error 297 | for i, Case := range e.Cases { 298 | buf.WriteString(" WHEN ") 299 | err = WriteValue(ctx, dialect, buf, args, params, Case.Predicate) 300 | if err != nil { 301 | return fmt.Errorf("CASE #%d WHEN: %w", i+1, err) 302 | } 303 | buf.WriteString(" THEN ") 304 | err = WriteValue(ctx, dialect, buf, args, params, Case.Result) 305 | if err != nil { 306 | return fmt.Errorf("CASE #%d THEN: %w", i+1, err) 307 | } 308 | } 309 | if e.Default != nil { 310 | buf.WriteString(" ELSE ") 311 | err = WriteValue(ctx, dialect, buf, args, params, e.Default) 312 | if err != nil { 313 | return fmt.Errorf("CASE ELSE: %w", err) 314 | } 315 | } 316 | buf.WriteString(" END") 317 | return nil 318 | } 319 | 320 | // When adds a new 
predicate-result pair to the CaseExpression. 321 | func (e CaseExpression) When(predicate Predicate, result any) CaseExpression { 322 | e.Cases = append(e.Cases, PredicateCase{Predicate: predicate, Result: result}) 323 | return e 324 | } 325 | 326 | // Else sets the fallback result of the CaseExpression. 327 | func (e CaseExpression) Else(fallback any) CaseExpression { 328 | e.Default = fallback 329 | return e 330 | } 331 | 332 | // As returns a new CaseExpression with the given alias. 333 | func (e CaseExpression) As(alias string) CaseExpression { 334 | e.alias = alias 335 | return e 336 | } 337 | 338 | // GetAlias returns the alias of the CaseExpression. 339 | func (e CaseExpression) GetAlias() string { return e.alias } 340 | 341 | // IsField implements the Field interface. 342 | func (e CaseExpression) IsField() {} 343 | 344 | // IsArray implements the Array interface. 345 | func (e CaseExpression) IsArray() {} 346 | 347 | // IsBinary implements the Binary interface. 348 | func (e CaseExpression) IsBinary() {} 349 | 350 | // IsBoolean implements the Boolean interface. 351 | func (e CaseExpression) IsBoolean() {} 352 | 353 | // IsEnum implements the Enum interface. 354 | func (e CaseExpression) IsEnum() {} 355 | 356 | // IsJSON implements the JSON interface. 357 | func (e CaseExpression) IsJSON() {} 358 | 359 | // IsNumber implements the Number interface. 360 | func (e CaseExpression) IsNumber() {} 361 | 362 | // IsString implements the String interface. 363 | func (e CaseExpression) IsString() {} 364 | 365 | // IsTime implements the Time interface. 366 | func (e CaseExpression) IsTime() {} 367 | 368 | // IsUUID implements the UUID interface. 369 | func (e CaseExpression) IsUUID() {} 370 | 371 | // SimpleCaseExpression represents an SQL simple CASE expression. 372 | type SimpleCaseExpression struct { 373 | alias string 374 | Expression any 375 | Cases SimpleCases 376 | Default any 377 | } 378 | 379 | // SimpleCases is a slice of SimpleCases. 
380 | type SimpleCases = []SimpleCase 381 | 382 | // SimpleCase holds the result to be used for a given value in a 383 | // SimpleCaseExpression. 384 | type SimpleCase struct { 385 | Value any 386 | Result any 387 | } 388 | 389 | var _ interface { 390 | Field 391 | Any 392 | } = (*SimpleCaseExpression)(nil) 393 | 394 | // Case returns a new SimpleCaseExpression. 395 | func Case(expression any) SimpleCaseExpression { 396 | return SimpleCaseExpression{Expression: expression} 397 | } 398 | 399 | // WriteSQL implements the SQLWriter interface. 400 | func (e SimpleCaseExpression) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { 401 | buf.WriteString("CASE ") 402 | if len(e.Cases) == 0 { 403 | return fmt.Errorf("SimpleCaseExpression empty") 404 | } 405 | var err error 406 | err = WriteValue(ctx, dialect, buf, args, params, e.Expression) 407 | if err != nil { 408 | return fmt.Errorf("CASE: %w", err) 409 | } 410 | for i, Case := range e.Cases { 411 | buf.WriteString(" WHEN ") 412 | err = WriteValue(ctx, dialect, buf, args, params, Case.Value) 413 | if err != nil { 414 | return fmt.Errorf("CASE #%d WHEN: %w", i+1, err) 415 | } 416 | buf.WriteString(" THEN ") 417 | err = WriteValue(ctx, dialect, buf, args, params, Case.Result) 418 | if err != nil { 419 | return fmt.Errorf("CASE #%d THEN: %w", i+1, err) 420 | } 421 | } 422 | if e.Default != nil { 423 | buf.WriteString(" ELSE ") 424 | err = WriteValue(ctx, dialect, buf, args, params, e.Default) 425 | if err != nil { 426 | return fmt.Errorf("CASE ELSE: %w", err) 427 | } 428 | } 429 | buf.WriteString(" END") 430 | return nil 431 | } 432 | 433 | // When adds a new value-result pair to the SimpleCaseExpression. 434 | func (e SimpleCaseExpression) When(value any, result any) SimpleCaseExpression { 435 | e.Cases = append(e.Cases, SimpleCase{Value: value, Result: result}) 436 | return e 437 | } 438 | 439 | // Else sets the fallback result of the SimpleCaseExpression. 
// Count represents an SQL COUNT() expression. It is built by passing the
// format string "COUNT({})" and the given field to Expr (presumably the {}
// placeholder is replaced by the field's SQL; see Expr). Use CountStar for
// COUNT(*).
func Count(field Field) Expression { return Expr("COUNT({})", field) }
494 | func Avg(num Number) Expression { return Expr("AVG({})", num) } 495 | 496 | // Min represent an SQL MIN() expression. 497 | func Min(field Field) Expression { return Expr("MIN({})", field) } 498 | 499 | // Max represents an SQL MAX() expression. 500 | func Max(field Field) Expression { return Expr("MAX({})", field) } 501 | 502 | // SelectValues represents a table literal comprised of SELECT statements 503 | // UNION-ed together e.g. 504 | // 505 | // (SELECT 1 AS a, 2 AS b, 3 AS c 506 | // UNION ALL 507 | // SELECT 4, 5, 6 508 | // UNION ALL 509 | // SELECT 7, 8, 9) AS tbl 510 | type SelectValues struct { 511 | Alias string 512 | Columns []string 513 | RowValues [][]any 514 | } 515 | 516 | var _ interface { 517 | Query 518 | Table 519 | } = (*SelectValues)(nil) 520 | 521 | // WriteSQL implements the SQLWriter interface. 522 | func (vs SelectValues) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { 523 | var err error 524 | for i, rowvalue := range vs.RowValues { 525 | if i > 0 { 526 | buf.WriteString(" UNION ALL ") 527 | } 528 | if len(vs.Columns) > 0 && len(rowvalue) != len(vs.Columns) { 529 | return fmt.Errorf("rowvalue #%d: got %d values, want %d values (%s)", i+1, len(rowvalue), len(vs.Columns), strings.Join(vs.Columns, ", ")) 530 | } 531 | buf.WriteString("SELECT ") 532 | for j, value := range rowvalue { 533 | if j > 0 { 534 | buf.WriteString(", ") 535 | } 536 | err = WriteValue(ctx, dialect, buf, args, params, value) 537 | if err != nil { 538 | return fmt.Errorf("rowvalue #%d value #%d: %w", i+1, j+1, err) 539 | } 540 | if i == 0 && j < len(vs.Columns) { 541 | buf.WriteString(" AS " + QuoteIdentifier(dialect, vs.Columns[j])) 542 | } 543 | } 544 | } 545 | return nil 546 | } 547 | 548 | // Field returns a new field qualified by the SelectValues' alias. 
// TableValues represents a table literal created by the VALUES clause e.g.
//
//	(VALUES
//	    (1, 2, 3),
//	    (4, 5, 6),
//	    (7, 8, 9)) AS tbl (a, b, c)
//
// WriteSQL emits only the "VALUES ..." part; for the MySQL dialect each row
// is written as ROW(...) instead of (...).
type TableValues struct {
	// Alias is the name returned by GetAlias and the table qualifier used for
	// fields created with Field.
	Alias string
	// Columns are the column names returned by GetColumns. When non-empty,
	// WriteSQL rejects any row whose number of values differs from it.
	Columns []string
	// RowValues holds one slice of values per row. When it is empty WriteSQL
	// writes nothing.
	RowValues [][]any
}
586 | func (vs TableValues) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { 587 | if len(vs.RowValues) == 0 { 588 | return nil 589 | } 590 | var err error 591 | buf.WriteString("VALUES ") 592 | for i, rowvalue := range vs.RowValues { 593 | if len(vs.Columns) > 0 && len(vs.Columns) != len(rowvalue) { 594 | return fmt.Errorf("rowvalue #%d: got %d values, want %d values (%s)", i+1, len(rowvalue), len(vs.Columns), strings.Join(vs.Columns, ", ")) 595 | } 596 | if i > 0 { 597 | buf.WriteString(", ") 598 | } 599 | if dialect == DialectMySQL { 600 | buf.WriteString("ROW(") 601 | } else { 602 | buf.WriteString("(") 603 | } 604 | for j, value := range rowvalue { 605 | if j > 0 { 606 | buf.WriteString(", ") 607 | } 608 | err = WriteValue(ctx, dialect, buf, args, params, value) 609 | if err != nil { 610 | return fmt.Errorf("rowvalue #%d value #%d: %w", i+1, j+1, err) 611 | } 612 | } 613 | buf.WriteString(")") 614 | } 615 | return nil 616 | } 617 | 618 | // Field returns a new field qualified by the TableValues' alias. 619 | func (vs TableValues) Field(name string) AnyField { 620 | return NewAnyField(name, TableStruct{alias: vs.Alias}) 621 | } 622 | 623 | // SetFetchableFields implements the Query interface. It always returns false 624 | // as the second result. 625 | func (vs TableValues) SetFetchableFields([]Field) (query Query, ok bool) { return vs, false } 626 | 627 | // GetDialect implements the Query interface. It always returns an empty 628 | // string. 629 | func (vs TableValues) GetDialect() string { return "" } 630 | 631 | // GetAlias returns the alias of the TableValues. 632 | func (vs TableValues) GetAlias() string { return vs.Alias } 633 | 634 | // GetColumns returns the names of the columns in the TableValues. 635 | func (vs TableValues) GetColumns() []string { return vs.Columns } 636 | 637 | // IsTable implements the Table interface. 
// TestValueExpression checks that Value(...) keeps the alias set through As
// and that a value and its comparison helpers (In, Eq, Ne, Lt, Le, Gt, Ge)
// render as "?" placeholders with the operands collected into the args.
func TestValueExpression(t *testing.T) {
	t.Run("alias", func(t *testing.T) {
		t.Parallel()
		expr := Value(1).As("num")
		if diff := testutil.Diff(expr.GetAlias(), "num"); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
	})

	// No dialect is set in these cases, so every placeholder is expected to be
	// a plain "?".
	tests := []TestTable{{
		// A named parameter records the index of its arg under its name.
		description: "basic",
		item:        Value(Param("xyz", 42)),
		wantQuery:   "?",
		wantArgs:    []any{42},
		wantParams:  map[string][]int{"xyz": {0}},
	}, {
		// The slice passed to In is expanded into one placeholder per element.
		description: "In", item: Value(1).In([]int{18, 21, 32}),
		wantQuery: "? IN (?, ?, ?)", wantArgs: []any{1, 18, 21, 32},
	}, {
		description: "Eq", item: Value(1).Eq(34),
		wantQuery: "? = ?", wantArgs: []any{1, 34},
	}, {
		description: "Ne", item: Value(1).Ne(34),
		wantQuery: "? <> ?", wantArgs: []any{1, 34},
	}, {
		description: "Lt", item: Value(1).Lt(34),
		wantQuery: "? < ?", wantArgs: []any{1, 34},
	}, {
		description: "Le", item: Value(1).Le(34),
		wantQuery: "? <= ?", wantArgs: []any{1, 34},
	}, {
		description: "Gt", item: Value(1).Gt(34),
		wantQuery: "? > ?", wantArgs: []any{1, 34},
	}, {
		description: "Ge", item: Value(1).Ge(34),
		wantQuery: "? >= ?", wantArgs: []any{1, 34},
	}}

	for _, tt := range tests {
		// Copy the range variable so that each parallel subtest closure sees
		// its own test case (needed before Go 1.22).
		tt := tt
		t.Run(tt.description, func(t *testing.T) {
			t.Parallel()
			tt.assert(t)
		})
	}
}
// TestCaseExpressions covers both CASE builders: CaseWhen (searched CASE,
// CaseExpression) and Case (simple CASE, SimpleCaseExpression). It checks
// aliases, the rendered SQL and args, that an expression without any cases is
// rejected, and that an error from any nested SQLWriter is propagated.
func TestCaseExpressions(t *testing.T) {
	t.Run("name and alias", func(t *testing.T) {
		t.Parallel()
		// CaseExpression
		caseExpr := CaseWhen(Value(true), 1).As("result_a")
		if diff := testutil.Diff(caseExpr.GetAlias(), "result_a"); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		// SimpleCaseExpression
		simpleCaseExpr := Case(1).When(1, 2).As("result_b")
		if diff := testutil.Diff(simpleCaseExpr.GetAlias(), "result_b"); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
	})

	t.Run("CaseExpression", func(t *testing.T) {
		t.Parallel()
		// Predicates built with Expr are written verbatim; only the results
		// and the fallback become args.
		TestTable{
			item:      CaseWhen(Expr("x = y"), 1).When(Expr("a = b"), 2).Else(3),
			wantQuery: "CASE WHEN x = y THEN ? WHEN a = b THEN ? ELSE ? END",
			wantArgs:  []any{1, 2, 3},
		}.assert(t)
	})

	t.Run("SimpleCaseExpression", func(t *testing.T) {
		t.Parallel()
		// Both the WHEN values and the THEN results become args, in order.
		TestTable{
			item:      Case(Expr("a")).When(1, 2).When(3, 4).Else(5),
			wantQuery: "CASE a WHEN ? THEN ? WHEN ? THEN ? ELSE ? END",
			wantArgs:  []any{1, 2, 3, 4, 5},
		}.assert(t)
	})

	t.Run("CaseExpression cannot be empty", func(t *testing.T) {
		t.Parallel()
		TestTable{item: CaseExpression{}}.assertNotOK(t)
	})

	t.Run("SimpleCaseExpression cannot be empty", func(t *testing.T) {
		t.Parallel()
		TestTable{item: SimpleCaseExpression{}}.assertNotOK(t)
	})

	// FaultySQL is placed in every position that gets written (predicate,
	// value, result, fallback, CASE expression); each case must surface
	// ErrFaultySQL.
	errTests := []TestTable{{
		description: "CASE WHEN predicate err", item: CaseWhen(FaultySQL{}, 1),
	}, {
		description: "CASE WHEN result err", item: CaseWhen(Value(true), FaultySQL{}),
	}, {
		description: "CASE WHEN fallback err", item: CaseWhen(Value(true), 1).Else(FaultySQL{}),
	}, {
		description: "CASE expression err", item: Case(FaultySQL{}).When(1, 2),
	}, {
		description: "CASE value err", item: Case(1).When(FaultySQL{}, 2),
	}, {
		description: "CASE result err", item: Case(1).When(2, FaultySQL{}),
	}, {
		description: "CASE fallback err", item: Case(1).When(2, 3).Else(FaultySQL{}),
	}}

	for _, tt := range errTests {
		// Copy the range variable so that each parallel subtest closure sees
		// its own test case (needed before Go 1.22).
		tt := tt
		t.Run(tt.description, func(t *testing.T) {
			t.Parallel()
			tt.assertErr(t, ErrFaultySQL)
		})
	}
}
// TestTableValues checks the accessor methods of TableValues and the SQL
// produced by TableValues.WriteSQL for the default, Postgres and MySQL
// dialects.
func TestTableValues(t *testing.T) {
	// This local TestTable shadows the package-level TestTable used by the
	// other tests in this file: item is a concrete TableValues here so that
	// WriteSQL can be called on it directly.
	type TestTable struct {
		description string
		dialect     string
		item        TableValues
		wantQuery   string
		wantArgs    []any
	}

	t.Run("dialect alias columns and fields", func(t *testing.T) {
		tableValues := TableValues{
			Alias:   "aaa",
			Columns: []string{"a", "b", "c"},
		}
		if diff := testutil.Diff(tableValues.GetAlias(), "aaa"); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		// TableValues is not tied to a dialect.
		if diff := testutil.Diff(tableValues.GetDialect(), ""); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		// TableValues has no fetchable fields, so ok must be false.
		_, ok := tableValues.SetFetchableFields(nil)
		if diff := testutil.Diff(ok, false); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		gotColumns := tableValues.GetColumns()
		wantColumns := []string{"a", "b", "c"}
		if diff := testutil.Diff(gotColumns, wantColumns); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
		// A field created from the table is qualified by the table's alias.
		// The second and third results of ToSQL are discarded here.
		gotField, _, _ := ToSQL("", tableValues.Field("bbb"), nil)
		wantField := "aaa.bbb"
		if diff := testutil.Diff(gotField, wantField); diff != "" {
			t.Error(testutil.Callers(), diff)
		}
	})

	tests := []TestTable{{
		// No rows: nothing is written at all, not even the VALUES keyword.
		description: "empty",
		item:        TableValues{},
		wantQuery:   "",
		wantArgs:    nil,
	}, {
		description: "no columns",
		item: TableValues{
			RowValues: [][]any{
				{1, 2, 3},
				{4, 5, 6},
				{7, 8, 9},
			},
		},
		wantQuery: "VALUES (?, ?, ?)" +
			", (?, ?, ?)" +
			", (?, ?, ?)",
		wantArgs: []any{1, 2, 3, 4, 5, 6, 7, 8, 9},
	}, {
		// Postgres uses numbered placeholders; the column names do not appear
		// in the output of WriteSQL.
		description: "postgres",
		dialect:     DialectPostgres,
		item: TableValues{
			Columns: []string{"a", "b", "c"},
			RowValues: [][]any{
				{1, 2, 3},
				{4, 5, 6},
				{7, 8, 9},
			},
		},
		wantQuery: "VALUES ($1, $2, $3)" +
			", ($4, $5, $6)" +
			", ($7, $8, $9)",
		wantArgs: []any{1, 2, 3, 4, 5, 6, 7, 8, 9},
	}, {
		// MySQL wraps every row in ROW(...).
		description: "mysql",
		dialect:     DialectMySQL,
		item: TableValues{
			Columns: []string{"a", "b", "c"},
			RowValues: [][]any{
				{1, 2, 3},
				{4, 5, 6},
				{7, 8, 9},
			},
		},
		wantQuery: "VALUES ROW(?, ?, ?)" +
			", ROW(?, ?, ?)" +
			", ROW(?, ?, ?)",
		wantArgs: []any{1, 2, 3, 4, 5, 6, 7, 8, 9},
	}}

	for _, tt := range tests {
		// Copy the range variable so that each parallel subtest closure sees
		// its own test case (needed before Go 1.22).
		tt := tt
		t.Run(tt.description, func(t *testing.T) {
			t.Parallel()
			var buf bytes.Buffer
			var gotArgs []any
			err := tt.item.WriteSQL(context.Background(), tt.dialect, &buf, &gotArgs, nil)
			if err != nil {
				t.Fatal(testutil.Callers(), err)
			}
			gotQuery := buf.String()
			if diff := testutil.Diff(gotQuery, tt.wantQuery); diff != "" {
				t.Error(testutil.Callers(), diff)
			}
			if diff := testutil.Diff(gotArgs, tt.wantArgs); diff != "" {
				t.Error(testutil.Callers(), diff)
			}
		})
	}
}
// Query is either SELECT, INSERT, UPDATE or DELETE.
type Query interface {
	SQLWriter
	// SetFetchableFields should return a query with its fetchable fields set
	// to the given fields. If not applicable, it should return false as the
	// second return value.
	SetFetchableFields([]Field) (query Query, ok bool)
	// GetDialect returns the SQL dialect of the query (presumably one of the
	// Dialect* constants). Implementations that are not tied to a dialect,
	// such as SelectValues and TableValues, return an empty string.
	GetDialect() string
}
// Any is a catch-all interface that covers every field type.
//
// It embeds all of the typed field interfaces, so a type satisfies Any only if
// it implements Field plus every marker method (IsArray, IsBinary, IsBoolean,
// IsEnum, IsJSON, IsNumber, IsString, IsTime and IsUUID). CaseExpression and
// SimpleCaseExpression are examples of such types.
type Any interface {
	Array
	Binary
	Boolean
	Enum
	JSON
	Number
	String
	Time
	UUID
}
180 | type UUID interface { 181 | Field 182 | IsUUID() 183 | } 184 | 185 | // DialectValuer is any type that will yield a different driver.Valuer 186 | // depending on the SQL dialect. 187 | type DialectValuer interface { 188 | DialectValuer(dialect string) (driver.Valuer, error) 189 | } 190 | 191 | // TableStruct is meant to be embedded in table structs to make them implement 192 | // the Table interface. 193 | type TableStruct struct { 194 | schema string 195 | name string 196 | alias string 197 | } 198 | 199 | // ViewStruct is just an alias for TableStruct. 200 | type ViewStruct = TableStruct 201 | 202 | var _ Table = (*TableStruct)(nil) 203 | 204 | // NewTableStruct creates a new TableStruct. 205 | func NewTableStruct(schema, name, alias string) TableStruct { 206 | return TableStruct{schema: schema, name: name, alias: alias} 207 | } 208 | 209 | // WriteSQL implements the SQLWriter interface. 210 | func (ts TableStruct) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { 211 | if ts.schema != "" { 212 | buf.WriteString(QuoteIdentifier(dialect, ts.schema) + ".") 213 | } 214 | buf.WriteString(QuoteIdentifier(dialect, ts.name)) 215 | return nil 216 | } 217 | 218 | // GetAlias returns the alias of the TableStruct. 219 | func (ts TableStruct) GetAlias() string { return ts.alias } 220 | 221 | // IsTable implements the Table interface. 
222 | func (ts TableStruct) IsTable() {} 223 | 224 | func withPrefix(w SQLWriter, prefix string) SQLWriter { 225 | if field, ok := w.(interface { 226 | SQLWriter 227 | WithPrefix(string) Field 228 | }); ok { 229 | return field.WithPrefix(prefix) 230 | } 231 | return w 232 | } 233 | 234 | func getAlias(w SQLWriter) string { 235 | if w, ok := w.(interface{ GetAlias() string }); ok { 236 | return w.GetAlias() 237 | } 238 | return "" 239 | } 240 | 241 | func toString(dialect string, w SQLWriter) string { 242 | buf := bufpool.Get().(*bytes.Buffer) 243 | buf.Reset() 244 | defer bufpool.Put(buf) 245 | var args []any 246 | _ = w.WriteSQL(context.Background(), dialect, buf, &args, nil) 247 | return buf.String() 248 | } 249 | 250 | func writeFieldsWithPrefix(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int, fields []Field, prefix string, includeAlias bool) error { 251 | var err error 252 | var alias string 253 | for i, field := range fields { 254 | if field == nil { 255 | return fmt.Errorf("field #%d is nil", i+1) 256 | } 257 | if i > 0 { 258 | buf.WriteString(", ") 259 | } 260 | err = withPrefix(field, prefix).WriteSQL(ctx, dialect, buf, args, params) 261 | if err != nil { 262 | return fmt.Errorf("field #%d: %w", i+1, err) 263 | } 264 | if includeAlias { 265 | if alias = getAlias(field); alias != "" { 266 | buf.WriteString(" AS " + QuoteIdentifier(dialect, alias)) 267 | } 268 | } 269 | } 270 | return nil 271 | } 272 | 273 | func writeFields(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int, fields []Field, includeAlias bool) error { 274 | var err error 275 | var alias string 276 | for i, field := range fields { 277 | if field == nil { 278 | return fmt.Errorf("field #%d is nil", i+1) 279 | } 280 | if i > 0 { 281 | buf.WriteString(", ") 282 | } 283 | _, isQuery := field.(Query) 284 | if isQuery { 285 | buf.WriteString("(") 286 | } 287 | err = field.WriteSQL(ctx, dialect, buf, args, params) 
// mapperFunctionPanicked recovers from any panics.
//
// It must be invoked directly by a defer statement (recover only works
// there). If the surrounding function panicked, the panic value is stored in
// *err: an error value is stored as is, anything else is converted to an
// error carrying its default string representation. If there was no panic,
// *err is left untouched.
//
// The function is called as such so that it shows up as
// "sq.mapperFunctionPanicked" in panic stack trace, giving the user a
// descriptive clue of what went wrong (i.e. their mapper function panicked).
func mapperFunctionPanicked(err *error) {
	if r := recover(); r != nil {
		switch r := r.(type) {
		case error:
			*err = r
		default:
			// The panic value is data, not a format string: formatting it
			// through an explicit %v keeps any '%' it contains intact.
			*err = fmt.Errorf("%v", r)
		}
	}
}
333 | func (v *arrayValue) Value() (driver.Value, error) { 334 | switch v.value.(type) { 335 | case []string, []int, []int64, []int32, []float64, []float32, []bool: 336 | break 337 | default: 338 | return nil, fmt.Errorf("value %#v is not a []string, []int, []int32, []float64, []float32 or []bool", v.value) 339 | } 340 | if v.dialect != DialectPostgres { 341 | var b strings.Builder 342 | err := json.NewEncoder(&b).Encode(v.value) 343 | if err != nil { 344 | return nil, err 345 | } 346 | return strings.TrimSpace(b.String()), nil 347 | } 348 | if ints, ok := v.value.([]int); ok { 349 | bigints := make([]int64, len(ints)) 350 | for i, num := range ints { 351 | bigints[i] = int64(num) 352 | } 353 | v.value = bigints 354 | } 355 | return pqarray.Array(v.value).Value() 356 | } 357 | 358 | // DialectValuer implements the DialectValuer interface. 359 | func (v *arrayValue) DialectValuer(dialect string) (driver.Valuer, error) { 360 | v.dialect = dialect 361 | return v, nil 362 | } 363 | 364 | // EnumValue takes in an Enumeration and returns a driver.Valuer which 365 | // serializes the enum into a string and additionally checks if the enum is 366 | // valid. 367 | func EnumValue(value Enumeration) driver.Valuer { 368 | return &enumValue{value: value} 369 | } 370 | 371 | type enumValue struct { 372 | value Enumeration 373 | } 374 | 375 | // Value implements the driver.Valuer interface. 
376 | func (v *enumValue) Value() (driver.Value, error) { 377 | value := reflect.ValueOf(v.value) 378 | names := v.value.Enumerate() 379 | switch value.Kind() { 380 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 381 | i := int(value.Int()) 382 | if i < 0 || i >= len(names) { 383 | return nil, fmt.Errorf("%d is not a valid %T", i, v.value) 384 | } 385 | name := names[i] 386 | if name == "" && i != 0 { 387 | return nil, fmt.Errorf("%d is not a valid %T", i, v.value) 388 | } 389 | return name, nil 390 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 391 | i := int(value.Uint()) 392 | if i < 0 || i >= len(names) { 393 | return nil, fmt.Errorf("%d is not a valid %T", i, v.value) 394 | } 395 | name := names[i] 396 | if name == "" && i != 0 { 397 | return nil, fmt.Errorf("%d is not a valid %T", i, v.value) 398 | } 399 | return name, nil 400 | case reflect.String: 401 | typ := value.Type() 402 | name := value.String() 403 | if getEnumIndex(name, names, typ) < 0 { 404 | return nil, fmt.Errorf("%q is not a valid %T", name, v.value) 405 | } 406 | return name, nil 407 | default: 408 | return nil, fmt.Errorf("underlying type of %[1]v is neither an integer nor string (%[1]T)", v.value) 409 | } 410 | } 411 | 412 | var ( 413 | enumIndexMu sync.RWMutex 414 | enumIndex = make(map[reflect.Type]map[string]int) 415 | ) 416 | 417 | // getEnumIndex returns the index of the enum within the names slice. 
418 | func getEnumIndex(name string, names []string, typ reflect.Type) int { 419 | if len(names) <= 4 { 420 | for idx := range names { 421 | if names[idx] == name { 422 | return idx 423 | } 424 | } 425 | return -1 426 | } 427 | var nameIndex map[string]int 428 | enumIndexMu.RLock() 429 | nameIndex = enumIndex[typ] 430 | enumIndexMu.RUnlock() 431 | if nameIndex != nil { 432 | idx, ok := nameIndex[name] 433 | if !ok { 434 | return -1 435 | } 436 | return idx 437 | } 438 | idx := -1 439 | nameIndex = make(map[string]int) 440 | for i := range names { 441 | if names[i] == name { 442 | idx = i 443 | } 444 | nameIndex[names[i]] = i 445 | } 446 | enumIndexMu.Lock() 447 | enumIndex[typ] = nameIndex 448 | enumIndexMu.Unlock() 449 | return idx 450 | } 451 | 452 | // JSONValue takes in an interface{} and returns a driver.Valuer which runs the 453 | // value through json.Marshal before submitting it to the database. 454 | func JSONValue(value any) driver.Valuer { 455 | return &jsonValue{value: value} 456 | } 457 | 458 | type jsonValue struct { 459 | value any 460 | } 461 | 462 | // Value implements the driver.Valuer interface. 463 | func (v *jsonValue) Value() (driver.Value, error) { 464 | var b strings.Builder 465 | err := json.NewEncoder(&b).Encode(v.value) 466 | return strings.TrimSpace(b.String()), err 467 | } 468 | 469 | // UUIDValue takes in a type whose underlying type must be a [16]byte and 470 | // returns a driver.Valuer. 471 | func UUIDValue(value any) driver.Valuer { 472 | return &uuidValue{value: value} 473 | } 474 | 475 | type uuidValue struct { 476 | dialect string 477 | value any 478 | } 479 | 480 | // Value implements the driver.Valuer interface. 
481 | func (v *uuidValue) Value() (driver.Value, error) { 482 | if v.value == nil { 483 | return nil, nil 484 | } 485 | uuid, ok := v.value.([16]byte) 486 | if !ok { 487 | value := reflect.ValueOf(v.value) 488 | typ := value.Type() 489 | if value.Kind() != reflect.Array || value.Len() != 16 || typ.Elem().Kind() != reflect.Uint8 { 490 | return nil, fmt.Errorf("%[1]v %[1]T is not [16]byte", v.value) 491 | } 492 | for i := 0; i < value.Len(); i++ { 493 | uuid[i] = value.Index(i).Interface().(byte) 494 | } 495 | } 496 | if v.dialect != DialectPostgres { 497 | return uuid[:], nil 498 | } 499 | var buf [36]byte 500 | googleuuid.EncodeHex(buf[:], uuid) 501 | return string(buf[:]), nil 502 | } 503 | 504 | // DialectValuer implements the DialectValuer interface. 505 | func (v *uuidValue) DialectValuer(dialect string) (driver.Valuer, error) { 506 | v.dialect = dialect 507 | return v, nil 508 | } 509 | 510 | func preprocessValue(dialect string, value any) (any, error) { 511 | if dialectValuer, ok := value.(DialectValuer); ok { 512 | driverValuer, err := dialectValuer.DialectValuer(dialect) 513 | if err != nil { 514 | return nil, fmt.Errorf("calling DialectValuer on %#v: %w", dialectValuer, err) 515 | } 516 | value = driverValuer 517 | } 518 | switch value := value.(type) { 519 | case nil: 520 | return nil, nil 521 | case Enumeration: 522 | driverValue, err := (&enumValue{value: value}).Value() 523 | if err != nil { 524 | return nil, fmt.Errorf("converting %#v to string: %w", value, err) 525 | } 526 | return driverValue, nil 527 | case [16]byte: 528 | driverValue, err := (&uuidValue{dialect: dialect, value: value}).Value() 529 | if err != nil { 530 | if dialect == DialectPostgres { 531 | return nil, fmt.Errorf("converting %#v to string: %w", value, err) 532 | } 533 | return nil, fmt.Errorf("converting %#v to bytes: %w", value, err) 534 | } 535 | return driverValue, nil 536 | case driver.Valuer: 537 | driverValue, err := value.Value() 538 | if err != nil { 539 | return nil, 
fmt.Errorf("calling Value on %#v: %w", value, err) 540 | } 541 | return driverValue, nil 542 | } 543 | return value, nil 544 | } 545 | -------------------------------------------------------------------------------- /sq_test.go: -------------------------------------------------------------------------------- 1 | package sq 2 | 3 | import ( 4 | "database/sql" 5 | "testing" 6 | 7 | "github.com/bokwoon95/sq/internal/testutil" 8 | "github.com/google/uuid" 9 | ) 10 | 11 | type Weekday uint 12 | 13 | const ( 14 | WeekdayInvalid Weekday = iota 15 | Sunday 16 | Monday 17 | Tuesday 18 | Wednesday 19 | Thursday 20 | Friday 21 | Saturday 22 | ) 23 | 24 | func (d Weekday) Enumerate() []string { 25 | return []string{ 26 | WeekdayInvalid: "", 27 | Sunday: "Sunday", 28 | Monday: "Monday", 29 | Tuesday: "Tuesday", 30 | Wednesday: "Wednesday", 31 | Thursday: "Thursday", 32 | Friday: "Friday", 33 | Saturday: "Saturday", 34 | } 35 | } 36 | 37 | func Test_preprocessValue(t *testing.T) { 38 | type TestTable struct { 39 | description string 40 | dialect string 41 | input any 42 | wantOutput any 43 | } 44 | 45 | tests := []TestTable{{ 46 | description: "empty", 47 | input: nil, 48 | wantOutput: nil, 49 | }, { 50 | description: "driver.Valuer", 51 | input: uuid.MustParse("a4f952f1-4c45-4e63-bd4e-159ca33c8e20"), 52 | wantOutput: "a4f952f1-4c45-4e63-bd4e-159ca33c8e20", 53 | }, { 54 | description: "Postgres DialectValuer", 55 | dialect: DialectPostgres, 56 | input: UUIDValue(uuid.MustParse("a4f952f1-4c45-4e63-bd4e-159ca33c8e20")), 57 | wantOutput: "a4f952f1-4c45-4e63-bd4e-159ca33c8e20", 58 | }, { 59 | description: "MySQL DialectValuer", 60 | dialect: DialectMySQL, 61 | input: UUIDValue(uuid.MustParse("a4f952f1-4c45-4e63-bd4e-159ca33c8e20")), 62 | wantOutput: []byte{0xa4, 0xf9, 0x52, 0xf1, 0x4c, 0x45, 0x4e, 0x63, 0xbd, 0x4e, 0x15, 0x9c, 0xa3, 0x3c, 0x8e, 0x20}, 63 | }, { 64 | description: "Postgres [16]byte", 65 | dialect: DialectPostgres, 66 | input: [16]byte{0xa4, 0xf9, 0x52, 0xf1, 0x4c, 
0x45, 0x4e, 0x63, 0xbd, 0x4e, 0x15, 0x9c, 0xa3, 0x3c, 0x8e, 0x20}, 67 | wantOutput: "a4f952f1-4c45-4e63-bd4e-159ca33c8e20", 68 | }, { 69 | description: "MySQL [16]byte", 70 | dialect: DialectMySQL, 71 | input: [16]byte{0xa4, 0xf9, 0x52, 0xf1, 0x4c, 0x45, 0x4e, 0x63, 0xbd, 0x4e, 0x15, 0x9c, 0xa3, 0x3c, 0x8e, 0x20}, 72 | wantOutput: []byte{0xa4, 0xf9, 0x52, 0xf1, 0x4c, 0x45, 0x4e, 0x63, 0xbd, 0x4e, 0x15, 0x9c, 0xa3, 0x3c, 0x8e, 0x20}, 73 | }, { 74 | description: "Enumeration", 75 | input: Monday, 76 | wantOutput: "Monday", 77 | }, { 78 | description: "int", 79 | input: 42, 80 | wantOutput: 42, 81 | }, { 82 | description: "sql.NullString", 83 | input: sql.NullString{ 84 | Valid: false, 85 | String: "lorem ipsum dolor sit amet", 86 | }, 87 | wantOutput: nil, 88 | }} 89 | 90 | for _, tt := range tests { 91 | tt := tt 92 | t.Run(tt.description, func(t *testing.T) { 93 | t.Parallel() 94 | gotOutput, err := preprocessValue(tt.dialect, tt.input) 95 | if err != nil { 96 | t.Fatal(testutil.Callers(), err) 97 | } 98 | if diff := testutil.Diff(gotOutput, tt.wantOutput); diff != "" { 99 | t.Error(testutil.Callers(), diff) 100 | } 101 | }) 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /update_query_test.go: -------------------------------------------------------------------------------- 1 | package sq 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/bokwoon95/sq/internal/testutil" 7 | ) 8 | 9 | func TestSQLiteUpdateQuery(t *testing.T) { 10 | type ACTOR struct { 11 | TableStruct 12 | ACTOR_ID NumberField 13 | FIRST_NAME StringField 14 | LAST_NAME StringField 15 | LAST_UPDATE TimeField 16 | } 17 | a := New[ACTOR]("a") 18 | 19 | t.Run("basic", func(t *testing.T) { 20 | t.Parallel() 21 | q1 := SQLite.Update(a).Returning(a.FIRST_NAME).SetDialect("lorem ipsum") 22 | if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" { 23 | t.Error(testutil.Callers(), diff) 24 | } 25 | q1 = q1.SetDialect(DialectSQLite) 26 | 
fields := q1.GetFetchableFields() 27 | if diff := testutil.Diff(fields, []Field{a.FIRST_NAME}); diff != "" { 28 | t.Error(testutil.Callers(), diff) 29 | } 30 | _, ok := q1.SetFetchableFields([]Field{a.LAST_NAME}) 31 | if ok { 32 | t.Fatal(testutil.Callers(), "field should not have been set") 33 | } 34 | q1.ReturningFields = q1.ReturningFields[:0] 35 | _, ok = q1.SetFetchableFields([]Field{a.LAST_NAME}) 36 | if !ok { 37 | t.Fatal(testutil.Callers(), "field should have been set") 38 | } 39 | }) 40 | 41 | t.Run("Set", func(t *testing.T) { 42 | t.Parallel() 43 | var tt TestTable 44 | tt.item = SQLite. 45 | With(NewCTE("cte", nil, Queryf("SELECT 1"))). 46 | Update(a). 47 | Set( 48 | a.FIRST_NAME.SetString("bob"), 49 | a.LAST_NAME.SetString("the builder"), 50 | ). 51 | Where(a.ACTOR_ID.EqInt(1), a.LAST_UPDATE.IsNotNull()). 52 | Returning(a.ACTOR_ID) 53 | tt.wantQuery = "WITH cte AS (SELECT 1)" + 54 | " UPDATE actor AS a" + 55 | " SET first_name = $1, last_name = $2" + 56 | " WHERE a.actor_id = $3 AND a.last_update IS NOT NULL" + 57 | " RETURNING a.actor_id" 58 | tt.wantArgs = []any{"bob", "the builder", 1} 59 | tt.assert(t) 60 | }) 61 | 62 | t.Run("SetFunc", func(t *testing.T) { 63 | t.Parallel() 64 | var tt TestTable 65 | tt.item = SQLite. 66 | With(NewCTE("cte", nil, Queryf("SELECT 1"))). 67 | Update(a). 68 | SetFunc(func(col *Column) { 69 | col.SetString(a.FIRST_NAME, "bob") 70 | col.SetString(a.LAST_NAME, "the builder") 71 | }). 72 | Where(a.ACTOR_ID.EqInt(1)) 73 | tt.wantQuery = "WITH cte AS (SELECT 1)" + 74 | " UPDATE actor AS a" + 75 | " SET first_name = $1, last_name = $2" + 76 | " WHERE a.actor_id = $3" 77 | tt.wantArgs = []any{"bob", "the builder", 1} 78 | tt.assert(t) 79 | }) 80 | 81 | t.Run("UPDATE with JOIN", func(t *testing.T) { 82 | t.Parallel() 83 | var tt TestTable 84 | tt.item = SQLite. 85 | Update(a). 86 | Set( 87 | a.FIRST_NAME.SetString("bob"), 88 | a.LAST_NAME.SetString("the builder"), 89 | ). 90 | From(a). 91 | Join(a, a.ACTOR_ID.Eq(a.ACTOR_ID)). 
92 | LeftJoin(a, a.ACTOR_ID.Eq(a.ACTOR_ID)). 93 | CrossJoin(a). 94 | CustomJoin(",", a). 95 | JoinUsing(a, a.FIRST_NAME, a.LAST_NAME). 96 | Where(a.ACTOR_ID.EqInt(1)) 97 | tt.wantQuery = "UPDATE actor AS a" + 98 | " SET first_name = $1, last_name = $2" + 99 | " FROM actor AS a" + 100 | " JOIN actor AS a ON a.actor_id = a.actor_id" + 101 | " LEFT JOIN actor AS a ON a.actor_id = a.actor_id" + 102 | " CROSS JOIN actor AS a" + 103 | " , actor AS a" + 104 | " JOIN actor AS a USING (first_name, last_name)" + 105 | " WHERE a.actor_id = $3" 106 | tt.wantArgs = []any{"bob", "the builder", 1} 107 | tt.assert(t) 108 | }) 109 | } 110 | 111 | func TestPostgresUpdateQuery(t *testing.T) { 112 | type ACTOR struct { 113 | TableStruct 114 | ACTOR_ID NumberField 115 | FIRST_NAME StringField 116 | LAST_NAME StringField 117 | LAST_UPDATE TimeField 118 | } 119 | a := New[ACTOR]("a") 120 | 121 | t.Run("basic", func(t *testing.T) { 122 | t.Parallel() 123 | q1 := Postgres.Update(a).Returning(a.FIRST_NAME).SetDialect("lorem ipsum") 124 | if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" { 125 | t.Error(testutil.Callers(), diff) 126 | } 127 | q1 = q1.SetDialect(DialectPostgres) 128 | fields := q1.GetFetchableFields() 129 | if diff := testutil.Diff(fields, []Field{a.FIRST_NAME}); diff != "" { 130 | t.Error(testutil.Callers(), diff) 131 | } 132 | _, ok := q1.SetFetchableFields([]Field{a.LAST_NAME}) 133 | if ok { 134 | t.Fatal(testutil.Callers(), "field should not have been set") 135 | } 136 | q1.ReturningFields = q1.ReturningFields[:0] 137 | _, ok = q1.SetFetchableFields([]Field{a.LAST_NAME}) 138 | if !ok { 139 | t.Fatal(testutil.Callers(), "field should have been set") 140 | } 141 | }) 142 | 143 | t.Run("Set", func(t *testing.T) { 144 | t.Parallel() 145 | var tt TestTable 146 | tt.item = Postgres. 147 | With(NewCTE("cte", nil, Queryf("SELECT 1"))). 148 | Update(a). 149 | Set( 150 | a.FIRST_NAME.SetString("bob"), 151 | a.LAST_NAME.SetString("the builder"), 152 | ). 
153 | Where(a.ACTOR_ID.EqInt(1), a.LAST_UPDATE.IsNotNull()). 154 | Returning(a.ACTOR_ID) 155 | tt.wantQuery = "WITH cte AS (SELECT 1)" + 156 | " UPDATE actor AS a" + 157 | " SET first_name = $1, last_name = $2" + 158 | " WHERE a.actor_id = $3 AND a.last_update IS NOT NULL" + 159 | " RETURNING a.actor_id" 160 | tt.wantArgs = []any{"bob", "the builder", 1} 161 | tt.assert(t) 162 | }) 163 | 164 | t.Run("SetFunc", func(t *testing.T) { 165 | t.Parallel() 166 | var tt TestTable 167 | tt.item = Postgres. 168 | With(NewCTE("cte", nil, Queryf("SELECT 1"))). 169 | Update(a). 170 | SetFunc(func(col *Column) { 171 | col.SetString(a.FIRST_NAME, "bob") 172 | col.SetString(a.LAST_NAME, "the builder") 173 | }). 174 | Where(a.ACTOR_ID.EqInt(1)) 175 | tt.wantQuery = "WITH cte AS (SELECT 1)" + 176 | " UPDATE actor AS a" + 177 | " SET first_name = $1, last_name = $2" + 178 | " WHERE a.actor_id = $3" 179 | tt.wantArgs = []any{"bob", "the builder", 1} 180 | tt.assert(t) 181 | }) 182 | 183 | t.Run("UPDATE with JOIN", func(t *testing.T) { 184 | t.Parallel() 185 | var tt TestTable 186 | tt.item = Postgres. 187 | Update(a). 188 | Set( 189 | a.FIRST_NAME.SetString("bob"), 190 | a.LAST_NAME.SetString("the builder"), 191 | ). 192 | From(a). 193 | Join(a, a.ACTOR_ID.Eq(a.ACTOR_ID)). 194 | LeftJoin(a, a.ACTOR_ID.Eq(a.ACTOR_ID)). 195 | FullJoin(a, a.ACTOR_ID.Eq(a.ACTOR_ID)). 196 | CrossJoin(a). 197 | CustomJoin(",", a). 198 | JoinUsing(a, a.FIRST_NAME, a.LAST_NAME). 
199 | Where(a.ACTOR_ID.EqInt(1)) 200 | tt.wantQuery = "UPDATE actor AS a" + 201 | " SET first_name = $1, last_name = $2" + 202 | " FROM actor AS a" + 203 | " JOIN actor AS a ON a.actor_id = a.actor_id" + 204 | " LEFT JOIN actor AS a ON a.actor_id = a.actor_id" + 205 | " FULL JOIN actor AS a ON a.actor_id = a.actor_id" + 206 | " CROSS JOIN actor AS a" + 207 | " , actor AS a" + 208 | " JOIN actor AS a USING (first_name, last_name)" + 209 | " WHERE a.actor_id = $3" 210 | tt.wantArgs = []any{"bob", "the builder", 1} 211 | tt.assert(t) 212 | }) 213 | } 214 | 215 | func TestMySQLUpdateQuery(t *testing.T) { 216 | type ACTOR struct { 217 | TableStruct 218 | ACTOR_ID NumberField 219 | FIRST_NAME StringField 220 | LAST_NAME StringField 221 | LAST_UPDATE TimeField 222 | } 223 | a := New[ACTOR]("a") 224 | 225 | t.Run("basic", func(t *testing.T) { 226 | t.Parallel() 227 | q1 := MySQL.Update(a).SetDialect("lorem ipsum") 228 | if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" { 229 | t.Error(testutil.Callers(), diff) 230 | } 231 | q1 = q1.SetDialect(DialectMySQL) 232 | fields := q1.GetFetchableFields() 233 | if len(fields) != 0 { 234 | t.Error(testutil.Callers(), "expected 0 fields but got %v", fields) 235 | } 236 | _, ok := q1.SetFetchableFields([]Field{a.LAST_NAME}) 237 | if ok { 238 | t.Fatal(testutil.Callers(), "field should not have been set") 239 | } 240 | q1.ReturningFields = q1.ReturningFields[:0] 241 | _, ok = q1.SetFetchableFields([]Field{a.LAST_NAME}) 242 | if ok { 243 | t.Fatal(testutil.Callers(), "field should not have been set") 244 | } 245 | }) 246 | 247 | t.Run("Set", func(t *testing.T) { 248 | t.Parallel() 249 | var tt TestTable 250 | tt.item = MySQL. 251 | With(NewCTE("cte", nil, Queryf("SELECT 1"))). 252 | Update(a). 253 | Set( 254 | a.FIRST_NAME.SetString("bob"), 255 | a.LAST_NAME.SetString("the builder"), 256 | ). 
257 | Where(a.ACTOR_ID.EqInt(1)) 258 | tt.wantQuery = "WITH cte AS (SELECT 1)" + 259 | " UPDATE actor AS a" + 260 | " SET a.first_name = ?, a.last_name = ?" + 261 | " WHERE a.actor_id = ?" 262 | tt.wantArgs = []any{"bob", "the builder", 1} 263 | tt.assert(t) 264 | }) 265 | 266 | t.Run("SetFunc", func(t *testing.T) { 267 | t.Parallel() 268 | var tt TestTable 269 | tt.item = MySQL. 270 | With(NewCTE("cte", nil, Queryf("SELECT 1"))). 271 | Update(a). 272 | SetFunc(func(col *Column) { 273 | col.SetString(a.FIRST_NAME, "bob") 274 | col.SetString(a.LAST_NAME, "the builder") 275 | }). 276 | Where(a.ACTOR_ID.EqInt(1)) 277 | tt.wantQuery = "WITH cte AS (SELECT 1)" + 278 | " UPDATE actor AS a" + 279 | " SET a.first_name = ?, a.last_name = ?" + 280 | " WHERE a.actor_id = ?" 281 | tt.wantArgs = []any{"bob", "the builder", 1} 282 | tt.assert(t) 283 | }) 284 | 285 | t.Run("UPDATE with JOIN, ORDER BY, LIMIT", func(t *testing.T) { 286 | t.Parallel() 287 | var tt TestTable 288 | tt.item = MySQL. 289 | Update(a). 290 | Join(a, a.ACTOR_ID.Eq(a.ACTOR_ID)). 291 | LeftJoin(a, a.ACTOR_ID.Eq(a.ACTOR_ID)). 292 | FullJoin(a, a.ACTOR_ID.Eq(a.ACTOR_ID)). 293 | CrossJoin(a). 294 | CustomJoin(",", a). 295 | JoinUsing(a, a.FIRST_NAME, a.LAST_NAME). 296 | Set( 297 | a.FIRST_NAME.SetString("bob"), 298 | a.LAST_NAME.SetString("the builder"), 299 | ). 300 | Where(a.ACTOR_ID.EqInt(1)). 301 | OrderBy(a.ACTOR_ID). 302 | Limit(5) 303 | tt.wantQuery = "UPDATE actor AS a" + 304 | " JOIN actor AS a ON a.actor_id = a.actor_id" + 305 | " LEFT JOIN actor AS a ON a.actor_id = a.actor_id" + 306 | " FULL JOIN actor AS a ON a.actor_id = a.actor_id" + 307 | " CROSS JOIN actor AS a" + 308 | " , actor AS a" + 309 | " JOIN actor AS a USING (first_name, last_name)" + 310 | " SET a.first_name = ?, a.last_name = ?" + 311 | " WHERE a.actor_id = ?" + 312 | " ORDER BY a.actor_id" + 313 | " LIMIT ?" 
314 | tt.wantArgs = []any{"bob", "the builder", 1, 5} 315 | tt.assert(t) 316 | }) 317 | } 318 | 319 | func TestSQLServerUpdateQuery(t *testing.T) { 320 | type ACTOR struct { 321 | TableStruct 322 | ACTOR_ID NumberField 323 | FIRST_NAME StringField 324 | LAST_NAME StringField 325 | LAST_UPDATE TimeField 326 | } 327 | a := New[ACTOR]("") 328 | 329 | t.Run("basic", func(t *testing.T) { 330 | t.Parallel() 331 | q1 := SQLServer.Update(a).SetDialect("lorem ipsum") 332 | if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" { 333 | t.Error(testutil.Callers(), diff) 334 | } 335 | q1 = q1.SetDialect(DialectSQLServer) 336 | fields := q1.GetFetchableFields() 337 | if len(fields) != 0 { 338 | t.Error(testutil.Callers(), "expected 0 fields but got %v", fields) 339 | } 340 | _, ok := q1.SetFetchableFields([]Field{a.LAST_NAME}) 341 | if ok { 342 | t.Fatal(testutil.Callers(), "field should not have been set") 343 | } 344 | q1.ReturningFields = q1.ReturningFields[:0] 345 | _, ok = q1.SetFetchableFields([]Field{a.LAST_NAME}) 346 | if ok { 347 | t.Fatal(testutil.Callers(), "field should not have been set") 348 | } 349 | }) 350 | 351 | t.Run("Set", func(t *testing.T) { 352 | t.Parallel() 353 | var tt TestTable 354 | tt.item = SQLServer. 355 | With(NewCTE("cte", nil, Queryf("SELECT 1"))). 356 | Update(a). 357 | Set( 358 | a.FIRST_NAME.SetString("bob"), 359 | a.LAST_NAME.SetString("the builder"), 360 | ). 361 | Where(a.ACTOR_ID.EqInt(1)) 362 | tt.wantQuery = "WITH cte AS (SELECT 1)" + 363 | " UPDATE actor" + 364 | " SET first_name = @p1, last_name = @p2" + 365 | " WHERE actor.actor_id = @p3" 366 | tt.wantArgs = []any{"bob", "the builder", 1} 367 | tt.assert(t) 368 | }) 369 | 370 | t.Run("SetFunc", func(t *testing.T) { 371 | t.Parallel() 372 | var tt TestTable 373 | tt.item = SQLServer. 374 | With(NewCTE("cte", nil, Queryf("SELECT 1"))). 375 | Update(a). 
376 | SetFunc(func(col *Column) { 377 | col.SetString(a.FIRST_NAME, "bob") 378 | col.SetString(a.LAST_NAME, "the builder") 379 | }). 380 | Where(a.ACTOR_ID.EqInt(1)) 381 | tt.wantQuery = "WITH cte AS (SELECT 1)" + 382 | " UPDATE actor" + 383 | " SET first_name = @p1, last_name = @p2" + 384 | " WHERE actor.actor_id = @p3" 385 | tt.wantArgs = []any{"bob", "the builder", 1} 386 | tt.assert(t) 387 | }) 388 | 389 | t.Run("UPDATE with JOIN", func(t *testing.T) { 390 | t.Parallel() 391 | var tt TestTable 392 | tt.item = SQLServer. 393 | Update(a). 394 | Set( 395 | a.FIRST_NAME.SetString("bob"), 396 | a.LAST_NAME.SetString("the builder"), 397 | ). 398 | From(a). 399 | Join(a, a.ACTOR_ID.Eq(a.ACTOR_ID)). 400 | LeftJoin(a, a.ACTOR_ID.Eq(a.ACTOR_ID)). 401 | FullJoin(a, a.ACTOR_ID.Eq(a.ACTOR_ID)). 402 | CrossJoin(a). 403 | CustomJoin(",", a). 404 | Where(a.ACTOR_ID.EqInt(1)) 405 | tt.wantQuery = "UPDATE actor" + 406 | " SET first_name = @p1, last_name = @p2" + 407 | " FROM actor" + 408 | " JOIN actor ON actor.actor_id = actor.actor_id" + 409 | " LEFT JOIN actor ON actor.actor_id = actor.actor_id" + 410 | " FULL JOIN actor ON actor.actor_id = actor.actor_id" + 411 | " CROSS JOIN actor" + 412 | " , actor" + 413 | " WHERE actor.actor_id = @p3" 414 | tt.wantArgs = []any{"bob", "the builder", 1} 415 | tt.assert(t) 416 | }) 417 | } 418 | 419 | func TestUpdateQuery(t *testing.T) { 420 | t.Run("basic", func(t *testing.T) { 421 | t.Parallel() 422 | q1 := UpdateQuery{UpdateTable: Expr("tbl"), Dialect: "lorem ipsum"} 423 | if diff := testutil.Diff(q1.GetDialect(), "lorem ipsum"); diff != "" { 424 | t.Error(testutil.Callers(), diff) 425 | } 426 | }) 427 | 428 | f1, f2, f3 := Expr("f1"), Expr("f2"), Expr("f3") 429 | colmapper := func(col *Column) { 430 | col.Set(f1, 1) 431 | col.Set(f2, 2) 432 | col.Set(f3, 3) 433 | } 434 | 435 | t.Run("PolicyTable", func(t *testing.T) { 436 | t.Parallel() 437 | var tt TestTable 438 | tt.item = UpdateQuery{ 439 | UpdateTable: policyTableStub{policy: 
And(Expr("1 = 1"), Expr("2 = 2"))}, 440 | ColumnMapper: colmapper, 441 | WherePredicate: Expr("3 = 3"), 442 | } 443 | tt.wantQuery = "UPDATE policy_table_stub SET f1 = ?, f2 = ?, f3 = ? WHERE (1 = 1 AND 2 = 2) AND 3 = 3" 444 | tt.wantArgs = []any{1, 2, 3} 445 | tt.assert(t) 446 | }) 447 | 448 | notOKTests := []TestTable{{ 449 | description: "nil UpdateTable not allowed", 450 | item: UpdateQuery{ 451 | UpdateTable: nil, 452 | ColumnMapper: colmapper, 453 | }, 454 | }, { 455 | description: "empty Assignments not allowed", 456 | item: UpdateQuery{ 457 | UpdateTable: Expr("tbl"), 458 | Assignments: nil, 459 | }, 460 | }, { 461 | description: "mysql does not support FROM", 462 | item: UpdateQuery{ 463 | Dialect: DialectMySQL, 464 | UpdateTable: Expr("tbl"), 465 | FromTable: Expr("tbl"), 466 | ColumnMapper: colmapper, 467 | }, 468 | }, { 469 | description: "dialect does not allow JOIN without FROM", 470 | item: UpdateQuery{ 471 | Dialect: DialectPostgres, 472 | UpdateTable: Expr("tbl"), 473 | FromTable: nil, 474 | JoinTables: []JoinTable{ 475 | Join(Expr("tbl"), Expr("1 = 1")), 476 | }, 477 | ColumnMapper: colmapper, 478 | }, 479 | }, { 480 | description: "dialect does not support ORDER BY", 481 | item: UpdateQuery{ 482 | Dialect: DialectPostgres, 483 | UpdateTable: Expr("tbl"), 484 | ColumnMapper: colmapper, 485 | OrderByFields: Fields{f1}, 486 | }, 487 | }, { 488 | description: "dialect does not support LIMIT", 489 | item: UpdateQuery{ 490 | Dialect: DialectPostgres, 491 | UpdateTable: Expr("tbl"), 492 | ColumnMapper: colmapper, 493 | LimitRows: 5, 494 | }, 495 | }, { 496 | description: "dialect does not support RETURNING", 497 | item: UpdateQuery{ 498 | Dialect: DialectMySQL, 499 | UpdateTable: Expr("tbl"), 500 | ColumnMapper: colmapper, 501 | ReturningFields: Fields{f1, f2, f3}, 502 | }, 503 | }} 504 | 505 | for _, tt := range notOKTests { 506 | tt := tt 507 | t.Run(tt.description, func(t *testing.T) { 508 | t.Parallel() 509 | tt.assertNotOK(t) 510 | }) 511 | } 512 | 
513 | errTests := []TestTable{{ 514 | description: "ColumnMapper err", 515 | item: UpdateQuery{ 516 | UpdateTable: Expr("tbl"), 517 | ColumnMapper: func(*Column) { panic(ErrFaultySQL) }, 518 | }, 519 | }, { 520 | description: "UpdateTable Policy err", 521 | item: UpdateQuery{ 522 | UpdateTable: policyTableStub{err: ErrFaultySQL}, 523 | ColumnMapper: colmapper, 524 | }, 525 | }, { 526 | description: "FromTable Policy err", 527 | item: UpdateQuery{ 528 | UpdateTable: Expr("tbl"), 529 | FromTable: policyTableStub{err: ErrFaultySQL}, 530 | ColumnMapper: colmapper, 531 | }, 532 | }, { 533 | description: "JoinTables Policy err", 534 | item: UpdateQuery{ 535 | UpdateTable: Expr("tbl"), 536 | ColumnMapper: colmapper, 537 | FromTable: Expr("tbl"), 538 | JoinTables: []JoinTable{ 539 | Join(policyTableStub{err: ErrFaultySQL}, Expr("1 = 1")), 540 | }, 541 | }, 542 | }, { 543 | description: "CTEs err", 544 | item: UpdateQuery{ 545 | CTEs: []CTE{NewCTE("cte", nil, Queryf("SELECT {}", FaultySQL{}))}, 546 | UpdateTable: Expr("tbl"), 547 | ColumnMapper: colmapper, 548 | }, 549 | }, { 550 | description: "UpdateTable err", 551 | item: UpdateQuery{ 552 | UpdateTable: FaultySQL{}, 553 | ColumnMapper: colmapper, 554 | }, 555 | }, { 556 | description: "not mysql Assignments err", 557 | item: UpdateQuery{ 558 | Dialect: DialectPostgres, 559 | UpdateTable: Expr("tbl"), 560 | Assignments: []Assignment{FaultySQL{}}, 561 | }, 562 | }, { 563 | description: "FromTable err", 564 | item: UpdateQuery{ 565 | Dialect: DialectPostgres, 566 | UpdateTable: Expr("tbl"), 567 | ColumnMapper: colmapper, 568 | FromTable: FaultySQL{}, 569 | }, 570 | }, { 571 | description: "JoinTables err", 572 | item: UpdateQuery{ 573 | Dialect: DialectPostgres, 574 | UpdateTable: Expr("tbl"), 575 | ColumnMapper: colmapper, 576 | FromTable: Expr("tbl"), 577 | JoinTables: []JoinTable{ 578 | Join(FaultySQL{}, Expr("1 = 1")), 579 | }, 580 | }, 581 | }, { 582 | description: "mysql Assignments err", 583 | item: UpdateQuery{ 584 
| Dialect: DialectMySQL, 585 | UpdateTable: Expr("tbl"), 586 | Assignments: []Assignment{FaultySQL{}}, 587 | }, 588 | }, { 589 | description: "WherePredicate Variadic err", 590 | item: UpdateQuery{ 591 | UpdateTable: Expr("tbl"), 592 | ColumnMapper: colmapper, 593 | WherePredicate: And(FaultySQL{}), 594 | }, 595 | }, { 596 | description: "WherePredicate err", 597 | item: UpdateQuery{ 598 | UpdateTable: Expr("tbl"), 599 | ColumnMapper: colmapper, 600 | WherePredicate: FaultySQL{}, 601 | }, 602 | }, { 603 | description: "OrderByFields err", 604 | item: UpdateQuery{ 605 | Dialect: DialectMySQL, 606 | UpdateTable: Expr("tbl"), 607 | ColumnMapper: colmapper, 608 | OrderByFields: Fields{FaultySQL{}}, 609 | }, 610 | }, { 611 | description: "LimitRows err", 612 | item: UpdateQuery{ 613 | Dialect: DialectMySQL, 614 | UpdateTable: Expr("tbl"), 615 | ColumnMapper: colmapper, 616 | OrderByFields: Fields{f1}, 617 | LimitRows: FaultySQL{}, 618 | }, 619 | }, { 620 | description: "ReturningFields err", 621 | item: UpdateQuery{ 622 | Dialect: DialectPostgres, 623 | UpdateTable: Expr("tbl"), 624 | ColumnMapper: colmapper, 625 | ReturningFields: Fields{FaultySQL{}}, 626 | }, 627 | }} 628 | 629 | for _, tt := range errTests { 630 | tt := tt 631 | t.Run(tt.description, func(t *testing.T) { 632 | t.Parallel() 633 | tt.assertErr(t, ErrFaultySQL) 634 | }) 635 | } 636 | } 637 | -------------------------------------------------------------------------------- /window.go: -------------------------------------------------------------------------------- 1 | package sq 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | ) 8 | 9 | // NamedWindow represents an SQL named window. 10 | type NamedWindow struct { 11 | Name string 12 | Definition Window 13 | } 14 | 15 | var _ Window = (*NamedWindow)(nil) 16 | 17 | // WriteSQL implements the SQLWriter interface. 
18 | func (w NamedWindow) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { 19 | buf.WriteString(w.Name) 20 | return nil 21 | } 22 | 23 | // IsWindow implements the Window interface. 24 | func (w NamedWindow) IsWindow() {} 25 | 26 | // WindowDefinition represents an SQL window definition. 27 | type WindowDefinition struct { 28 | BaseWindowName string 29 | PartitionByFields []Field 30 | OrderByFields []Field 31 | FrameSpec string 32 | FrameValues []any 33 | } 34 | 35 | var _ Window = (*WindowDefinition)(nil) 36 | 37 | // BaseWindow creates a new WindowDefinition based off an existing NamedWindow. 38 | func BaseWindow(w NamedWindow) WindowDefinition { 39 | return WindowDefinition{BaseWindowName: w.Name} 40 | } 41 | 42 | // PartitionBy returns a new WindowDefinition with the PARTITION BY clause. 43 | func PartitionBy(fields ...Field) WindowDefinition { 44 | return WindowDefinition{PartitionByFields: fields} 45 | } 46 | 47 | // PartitionBy returns a new WindowDefinition with the ORDER BY clause. 48 | func OrderBy(fields ...Field) WindowDefinition { 49 | return WindowDefinition{OrderByFields: fields} 50 | } 51 | 52 | // WriteSQL implements the SQLWriter interface. 
53 | func (w WindowDefinition) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { 54 | var err error 55 | var written bool 56 | buf.WriteString("(") 57 | if w.BaseWindowName != "" { 58 | buf.WriteString(w.BaseWindowName + " ") 59 | } 60 | if len(w.PartitionByFields) > 0 { 61 | written = true 62 | buf.WriteString("PARTITION BY ") 63 | err = writeFields(ctx, dialect, buf, args, params, w.PartitionByFields, false) 64 | if err != nil { 65 | return fmt.Errorf("Window PARTITION BY: %w", err) 66 | } 67 | } 68 | if len(w.OrderByFields) > 0 { 69 | if written { 70 | buf.WriteString(" ") 71 | } 72 | written = true 73 | buf.WriteString("ORDER BY ") 74 | err = writeFields(ctx, dialect, buf, args, params, w.OrderByFields, false) 75 | if err != nil { 76 | return fmt.Errorf("Window ORDER BY: %w", err) 77 | } 78 | } 79 | if w.FrameSpec != "" { 80 | if written { 81 | buf.WriteString(" ") 82 | } 83 | written = true 84 | err = Writef(ctx, dialect, buf, args, params, w.FrameSpec, w.FrameValues) 85 | if err != nil { 86 | return fmt.Errorf("Window FRAME: %w", err) 87 | } 88 | } 89 | buf.WriteString(")") 90 | return nil 91 | } 92 | 93 | // PartitionBy returns a new WindowDefinition with the PARTITION BY clause. 94 | func (w WindowDefinition) PartitionBy(fields ...Field) WindowDefinition { 95 | w.PartitionByFields = fields 96 | return w 97 | } 98 | 99 | // OrderBy returns a new WindowDefinition with the ORDER BY clause. 100 | func (w WindowDefinition) OrderBy(fields ...Field) WindowDefinition { 101 | w.OrderByFields = fields 102 | return w 103 | } 104 | 105 | // Frame returns a new WindowDefinition with the frame specification set. 106 | func (w WindowDefinition) Frame(frameSpec string, frameValues ...any) WindowDefinition { 107 | w.FrameSpec = frameSpec 108 | w.FrameValues = frameValues 109 | return w 110 | } 111 | 112 | // IsWindow implements the Window interface. 
113 | func (w WindowDefinition) IsWindow() {} 114 | 115 | // NamedWindows represents a slice of NamedWindows. 116 | type NamedWindows []NamedWindow 117 | 118 | // WriteSQL imeplements the SQLWriter interface. 119 | func (ws NamedWindows) WriteSQL(ctx context.Context, dialect string, buf *bytes.Buffer, args *[]any, params map[string][]int) error { 120 | var err error 121 | for i, window := range ws { 122 | if i > 0 { 123 | buf.WriteString(", ") 124 | } 125 | buf.WriteString(window.Name + " AS ") 126 | err = window.Definition.WriteSQL(ctx, dialect, buf, args, params) 127 | if err != nil { 128 | return fmt.Errorf("window #%d: %w", i+1, err) 129 | } 130 | } 131 | return nil 132 | } 133 | 134 | // CountOver represents the COUNT() OVER () window function. 135 | func CountOver(field Field, window Window) Expression { 136 | if window == nil { 137 | return Expr("COUNT({}) OVER ()", field) 138 | } 139 | return Expr("COUNT({}) OVER {}", field, window) 140 | } 141 | 142 | // CountStarOver represents the COUNT(*) OVER () window function. 143 | func CountStarOver(window Window) Expression { 144 | if window == nil { 145 | return Expr("COUNT(*) OVER ()") 146 | } 147 | return Expr("COUNT(*) OVER {}", window) 148 | } 149 | 150 | // SumOver represents the SUM() OVER () window function. 151 | func SumOver(num Number, window Window) Expression { 152 | if window == nil { 153 | return Expr("SUM({}) OVER ()", num) 154 | } 155 | return Expr("SUM({}) OVER {}", num, window) 156 | } 157 | 158 | // AvgOver represents the AVG() OVER () window function. 159 | func AvgOver(num Number, window Window) Expression { 160 | if window == nil { 161 | return Expr("AVG({}) OVER ()", num) 162 | } 163 | return Expr("AVG({}) OVER {}", num, window) 164 | } 165 | 166 | // MinOver represents the MIN() OVER () window function. 
167 | func MinOver(field Field, window Window) Expression { 168 | if window == nil { 169 | return Expr("MIN({}) OVER ()", field) 170 | } 171 | return Expr("MIN({}) OVER {}", field, window) 172 | } 173 | 174 | // MaxOver represents the MAX() OVER () window function. 175 | func MaxOver(field Field, window Window) Expression { 176 | if window == nil { 177 | return Expr("MAX({}) OVER ()", field) 178 | } 179 | return Expr("MAX({}) OVER {}", field, window) 180 | } 181 | 182 | // RowNumberOver represents the ROW_NUMBER() OVER () window function. 183 | func RowNumberOver(window Window) Expression { 184 | if window == nil { 185 | return Expr("ROW_NUMBER() OVER ()") 186 | } 187 | return Expr("ROW_NUMBER() OVER {}", window) 188 | } 189 | 190 | // RankOver represents the RANK() OVER () window function. 191 | func RankOver(window Window) Expression { 192 | if window == nil { 193 | return Expr("RANK() OVER ()") 194 | } 195 | return Expr("RANK() OVER {}", window) 196 | } 197 | 198 | // DenseRankOver represents the DENSE_RANK() OVER () window function. 199 | func DenseRankOver(window Window) Expression { 200 | if window == nil { 201 | return Expr("DENSE_RANK() OVER ()") 202 | } 203 | return Expr("DENSE_RANK() OVER {}", window) 204 | } 205 | 206 | // CumeDistOver represents the CUME_DIST() OVER () window function. 207 | func CumeDistOver(window Window) Expression { 208 | if window == nil { 209 | return Expr("CUME_DIST() OVER ()") 210 | } 211 | return Expr("CUME_DIST() OVER {}", window) 212 | } 213 | 214 | // FirstValueOver represents the FIRST_VALUE() OVER () window function. 215 | func FirstValueOver(field Field, window Window) Expression { 216 | if window == nil { 217 | return Expr("FIRST_VALUE({}) OVER ()", field) 218 | } 219 | return Expr("FIRST_VALUE({}) OVER {}", field, window) 220 | } 221 | 222 | // LastValueOver represents the LAST_VALUE() OVER () window 223 | // function. 
224 | func LastValueOver(field Field, window Window) Expression { 225 | if window == nil { 226 | return Expr("LAST_VALUE({}) OVER ()", field) 227 | } 228 | return Expr("LAST_VALUE({}) OVER {}", field, window) 229 | } 230 | -------------------------------------------------------------------------------- /window_test.go: -------------------------------------------------------------------------------- 1 | package sq 2 | 3 | import "testing" 4 | 5 | func TestWindow(t *testing.T) { 6 | t.Run("basic", func(t *testing.T) { 7 | t.Parallel() 8 | f1, f2, f3 := Expr("f1"), Expr("f2"), Expr("f3") 9 | TestTable{ 10 | item: PartitionBy(f1).OrderBy(f2, f3).Frame("RANGE UNBOUNDED PRECEDING"), 11 | wantQuery: "(PARTITION BY f1 ORDER BY f2, f3 RANGE UNBOUNDED PRECEDING)", 12 | }.assert(t) 13 | TestTable{ 14 | item: OrderBy(f1).PartitionBy(f2, f3).Frame("ROWS {} PRECEDING", 5), 15 | wantQuery: "(PARTITION BY f2, f3 ORDER BY f1 ROWS ? PRECEDING)", 16 | wantArgs: []any{5}, 17 | }.assert(t) 18 | }) 19 | 20 | errTests := []TestTable{{ 21 | description: "PartitionBy err", item: PartitionBy(FaultySQL{}), 22 | }, { 23 | description: "OrderBy err", item: OrderBy(FaultySQL{}), 24 | }, { 25 | description: "Frame err", item: OrderBy(Expr("f")).Frame("ROWS {} PRECEDING", FaultySQL{}), 26 | }, { 27 | description: "NamedWindows err", item: NamedWindows{{ 28 | Name: "w", 29 | Definition: OrderBy(Expr("f")).Frame("ROWS {} PRECEDING", FaultySQL{}), 30 | }}, 31 | }} 32 | 33 | for _, tt := range errTests { 34 | tt := tt 35 | t.Run(tt.description, func(t *testing.T) { 36 | t.Parallel() 37 | tt.assertErr(t, ErrFaultySQL) 38 | }) 39 | } 40 | 41 | funcTests := []TestTable{{ 42 | description: "CountOver", item: CountOver(Expr("f1"), WindowDefinition{}), 43 | wantQuery: "COUNT(f1) OVER ()", 44 | }, { 45 | description: "CountOver nil", item: CountOver(Expr("f1"), nil), 46 | wantQuery: "COUNT(f1) OVER ()", 47 | }, { 48 | description: "CountStarOver", item: CountStarOver(WindowDefinition{}), 49 | wantQuery: 
"COUNT(*) OVER ()", 50 | }, { 51 | description: "SumOver", item: SumOver(Expr("f1"), PartitionBy(Expr("f2"))), 52 | wantQuery: "SUM(f1) OVER (PARTITION BY f2)", 53 | }, { 54 | description: "AvgOver", item: AvgOver(Expr("f1"), PartitionBy(Expr("f2"))), 55 | wantQuery: "AVG(f1) OVER (PARTITION BY f2)", 56 | }, { 57 | description: "MinOver", item: MinOver(Expr("f1"), PartitionBy(Expr("f2"))), 58 | wantQuery: "MIN(f1) OVER (PARTITION BY f2)", 59 | }, { 60 | description: "MaxOver", item: MaxOver(Expr("f1"), PartitionBy(Expr("f2"))), 61 | wantQuery: "MAX(f1) OVER (PARTITION BY f2)", 62 | }, { 63 | description: "RowNumberOver", item: RowNumberOver(PartitionBy(Expr("f1"))), 64 | wantQuery: "ROW_NUMBER() OVER (PARTITION BY f1)", 65 | }, { 66 | description: "RankOver", item: RankOver(PartitionBy(Expr("f1"))), 67 | wantQuery: "RANK() OVER (PARTITION BY f1)", 68 | }, { 69 | description: "DenseRankOver", item: DenseRankOver(PartitionBy(Expr("f1"))), 70 | wantQuery: "DENSE_RANK() OVER (PARTITION BY f1)", 71 | }, { 72 | description: "CumeDistOver", item: CumeDistOver(PartitionBy(Expr("f1"))), 73 | wantQuery: "CUME_DIST() OVER (PARTITION BY f1)", 74 | }, { 75 | description: "FirstValueOver", item: FirstValueOver(Expr("f1"), PartitionBy(Expr("f2"))), 76 | wantQuery: "FIRST_VALUE(f1) OVER (PARTITION BY f2)", 77 | }, { 78 | description: "LastValueOver", item: LastValueOver(Expr("f1"), PartitionBy(Expr("f2"))), 79 | wantQuery: "LAST_VALUE(f1) OVER (PARTITION BY f2)", 80 | }, { 81 | description: "NamedWindow", item: CountStarOver(NamedWindow{Name: "w"}), 82 | wantQuery: "COUNT(*) OVER w", 83 | }, func() TestTable { 84 | var tt TestTable 85 | tt.description = "BaseWindow" 86 | w := NamedWindow{Name: "w", Definition: PartitionBy(Expr("f1"))} 87 | tt.item = CountStarOver(BaseWindow(w).Frame("ROWS UNBOUNDED PRECEDING")) 88 | tt.wantQuery = "COUNT(*) OVER (w ROWS UNBOUNDED PRECEDING)" 89 | return tt 90 | }(), func() TestTable { 91 | var tt TestTable 92 | tt.description = "NamedWindows" 93 | 
w1 := NamedWindow{Name: "w1", Definition: PartitionBy(Expr("f1"))} 94 | w2 := NamedWindow{Name: "w2", Definition: OrderBy(Expr("f2"))} 95 | w3 := NamedWindow{Name: "w3", Definition: OrderBy(Expr("f3")).Frame("ROWS UNBOUNDED PRECEDING")} 96 | tt.item = NamedWindows{w1, w2, w3} 97 | tt.wantQuery = "w1 AS (PARTITION BY f1)" + 98 | ", w2 AS (ORDER BY f2)" + 99 | ", w3 AS (ORDER BY f3 ROWS UNBOUNDED PRECEDING)" 100 | return tt 101 | }()} 102 | 103 | for _, tt := range funcTests { 104 | tt := tt 105 | t.Run(tt.description, func(t *testing.T) { 106 | t.Parallel() 107 | tt.assert(t) 108 | }) 109 | } 110 | } 111 | --------------------------------------------------------------------------------