├── .github └── workflows │ ├── go.yml │ └── lint.yml ├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── bench_test.go ├── cache.go ├── cache_test.go ├── dialect.go ├── doc.go ├── example_test.go ├── executor.go ├── executor_test.go ├── go.mod ├── go.sum ├── pool.go ├── stmt.go ├── stmt_test.go ├── util.go └── util_test.go /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Go 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | jobs: 11 | build: 12 | runs-on: ubuntu-latest 13 | strategy: 14 | matrix: 15 | go: ['1.17.x', '1.18.x', '1.19.x', '1.20.x', '1.21.x', '1.22.x', '1.23.x'] 16 | name: Go ${{ matrix.go }} job 17 | steps: 18 | - uses: actions/checkout@v4 19 | 20 | - name: Set up Go 21 | uses: actions/setup-go@v5 22 | with: 23 | go-version: ${{ matrix.go }} 24 | 25 | # Run testing on the code 26 | - name: Run testing 27 | run: go test -v ./... -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Lint 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | 9 | permissions: 10 | contents: read 11 | 12 | jobs: 13 | lint: 14 | name: lint 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v4 18 | - uses: actions/setup-go@v5 19 | with: 20 | go-version: stable 21 | # Run vet & lint on the code. NOTE(review): go mod tidy and go fix rewrite files in place and exit successfully even when they change something, so on their own they never fail this job; add a `git diff --exit-code` check afterwards if that is the intent. 22 | - name: Run vet & lint 23 | run: | 24 | go mod tidy 25 | go mod verify 26 | go fix ./... 27 | go vet -all ./... 
28 | 29 | - name: govulncheck 30 | uses: golang/govulncheck-action@v1 31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # If you prefer the allow list template instead of the deny list, see community template: 2 | # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore 3 | # 4 | # Binaries for programs and plugins 5 | *.exe 6 | *.exe~ 7 | *.dll 8 | *.so 9 | *.dylib 10 | 11 | # Test binary, built with `go test -c` 12 | *.test 13 | 14 | # Output of the go coverage tool, specifically when used with LiteIDE 15 | *.out 16 | 17 | # Dependency directories (remove the comment below to include it) 18 | vendor/ 19 | 20 | # Go workspace file 21 | go.work 22 | go.work.sum 23 | 24 | # Original 25 | .vscode 26 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - 1.16 5 | - 1.17 6 | - master 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Vlad Glushchuk 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sqlf 2 | 3 | [![GoDoc Reference](https://godoc.org/github.com/leporo/sqlf?status.svg)](http://godoc.org/github.com/leporo/sqlf) 4 | ![Build Status](https://github.com/leporo/sqlf/actions/workflows/go.yml/badge.svg) 5 | [![Go Report Card](https://goreportcard.com/badge/github.com/leporo/sqlf)](https://goreportcard.com/report/github.com/leporo/sqlf) 6 | [![Mentioned in Awesome Go](https://awesome.re/mentioned-badge-flat.svg)](https://github.com/avelino/awesome-go#sql-query-builders) 7 | 8 | 9 | A fast SQL query builder for Go. 10 | 11 | What does `sqlf` do? 12 | 13 | - It helps you efficiently build an SQL statement at run time. 14 | - You may change the number of affected columns and the number of arguments in a safe way. 15 | - You may use SQL expressions (like `UPDATE counters SET counter = counter + 1`) in your SQL statements. 16 | - You may dynamically apply filters by adding `WHERE` conditions, change result ordering, etc. 17 | - You may safely use `?` placeholders in your SQL fragments - `sqlf` converts them to PostgreSQL-like `$1, $2, ...` placeholders if needed and does the numbering for you. 18 | - You may `.Bind` your structure to database columns like you do with other similar libraries. 19 | - `sqlf.Stmt` has methods to execute a query using any `database/sql`-compatible driver.
20 | 21 | What doesn't `sqlf` do? 22 | 23 | - `sqlf` isn't an ORM; you'll still have to use raw SQL. 24 | - There are no database schema migrations or any other database schema maintenance tools. 25 | - There are no compile-time type checks for query arguments, column and table names. 26 | - There is no wrapper for the `OR` clause. It affects performance and in most cases can be avoided by using `UNION` expressions, `WITH` clauses or window functions. Another option is to split the query in two. 27 | - `sqlf` doesn't help a developer pinpoint the cause of an issue with an SQL statement. 28 | 29 | ## Is It Fast? 30 | 31 | It is. See benchmarks: https://github.com/leporo/golang-sql-builder-benchmark 32 | 33 | In order to maximize performance and minimize memory footprint, `sqlf` reuses memory allocated for query building. The heavier the load, the faster `sqlf` works. 34 | 35 | ## Usage 36 | 37 | Build complex statements: 38 | 39 | ```go 40 | var ( 41 | region string 42 | product string 43 | productUnits int 44 | productSales float64 45 | ) 46 | 47 | sqlf.SetDialect(sqlf.PostgreSQL) 48 | 49 | err := sqlf.From("orders"). 50 | With("regional_sales", 51 | sqlf.From("orders"). 52 | Select("region, SUM(amount) AS total_sales"). 53 | GroupBy("region")). 54 | With("top_regions", 55 | sqlf.From("regional_sales"). 56 | Select("region"). 57 | Where("total_sales > (SELECT SUM(total_sales)/10 FROM regional_sales)")). 58 | // Map query fields to variables 59 | Select("region").To(&region). 60 | Select("product").To(&product). 61 | Select("SUM(quantity)").To(&productUnits). 62 | Select("SUM(amount) AS product_sales").To(&productSales). 63 | // 64 | Where("region IN (SELECT region FROM top_regions)"). 65 | GroupBy("region, product"). 66 | OrderBy("product_sales DESC"). 67 | // Execute the query 68 | QueryAndClose(ctx, db, func(row *sql.Rows){ 69 | // Callback function is called for every returned row. 70 | // Row values are scanned automatically to bound variables.
71 | fmt.Printf("%s\t%s\t%d\t$%.2f\n", region, product, productUnits, productSales) 72 | }) 73 | if err != nil { 74 | panic(err) 75 | } 76 | ``` 77 | 78 | Bind a structure: 79 | 80 | ```go 81 | type Offer struct { 82 | Id int64 `db:"id"` 83 | ProductId int64 `db:"product_id"` 84 | Price float64 `db:"price"` 85 | IsDeleted bool `db:"is_deleted"` 86 | } 87 | 88 | var o Offer 89 | 90 | err := sqlf.From("offers"). 91 | Bind(&o). 92 | Where("id = ?", 42). 93 | QueryRowAndClose(ctx, db) 94 | if err != nil { 95 | panic(err) 96 | } 97 | ``` 98 | 99 | Retrieve data into private fields, with more granular control over the retrieved fields: 100 | 101 | ```go 102 | type Offer struct { 103 | id int64 104 | productId int64 105 | price float64 106 | isDeleted bool 107 | } 108 | 109 | var o Offer 110 | 111 | err := sqlf.From("offers"). 112 | Select("id").To(&o.id). 113 | Select("product_id").To(&o.productId). 114 | Select("price").To(&o.price). 115 | Select("is_deleted").To(&o.isDeleted). 116 | Where("id = ?", 42). 117 | QueryRowAndClose(ctx, db) 118 | if err != nil { 119 | panic(err) 120 | } 121 | ``` 122 | 123 | Some SQL fragments, like a list of fields to be selected or a filtering condition, may appear over and over. It can be annoying to repeat them or combine an SQL statement from chunks. Use `sqlf.Stmt` to construct a basic query and extend it for each case (this example uses the `Offer` structure with `db` tags shown above): 124 | 125 | ```go 126 | func (o *Offer) Select() *sqlf.Stmt { 127 | return sqlf.From("offers"). 128 | Bind(o). 129 | // Ignore records marked as deleted 130 | Where("is_deleted = false") 131 | } 132 | 133 | func (o Offer) Print() { 134 | fmt.Printf("%d\t%d\t$%.2f\n", o.Id, o.ProductId, o.Price) 135 | } 136 | 137 | var o Offer 138 | 139 | // Fetch offer data 140 | err := o.Select(). 141 | Where("id = ?", offerId). 142 | QueryRowAndClose(ctx, db) 143 | if err != nil { 144 | panic(err) 145 | } 146 | o.Print() 147 | // ... 148 | 149 | // Select and print 5 most recently placed 150 | // offers for a given product 151 | err = o.Select().
152 | Where("product_id = ?", productId). 153 | OrderBy("id DESC"). 154 | Limit(5). 155 | QueryAndClose(ctx, db, func(row *sql.Rows){ 156 | o.Print() 157 | }) 158 | if err != nil { 159 | panic(err) 160 | } 161 | // ... 162 | 163 | ``` 164 | 165 | ## SQL Statement Construction and Execution 166 | 167 | ### SELECT 168 | 169 | #### Value Binding 170 | 171 | Bind columns to variables using the `To` method: 172 | 173 | ```go 174 | var ( 175 | minAmountRequested = true 176 | maxAmount float64 177 | minAmount float64 178 | ) 179 | 180 | q := sqlf.From("offers"). 181 | Select("MAX(amount)").To(&maxAmount). 182 | Where("is_deleted = false") 183 | 184 | if minAmountRequested { 185 | q.Select("MIN(amount)").To(&minAmount) 186 | } 187 | 188 | err := q.QueryRowAndClose(ctx, db) 189 | if err != nil { 190 | panic(err) 191 | } 192 | if minAmountRequested { 193 | fmt.Printf("Cheapest offer: $%.2f\n", minAmount) 194 | } 195 | fmt.Printf("Most expensive offer: $%.2f\n", maxAmount) 196 | ``` 197 | 198 | #### Joins 199 | 200 | There are helper methods to construct a JOIN clause: `Join`, `LeftJoin`, `RightJoin` and `FullJoin`. 201 | 202 | ```go 203 | var ( 204 | offerId int64 205 | productName string 206 | price float64 207 | ) 208 | 209 | err := sqlf.From("offers o"). 210 | Select("o.id").To(&offerId). 211 | Select("price").To(&price). 212 | Where("is_deleted = false"). 213 | // Join 214 | LeftJoin("products p", "p.id = o.product_id"). 215 | // Bind a column from the joined table to a variable 216 | Select("p.name").To(&productName). 217 | // Print the 10 most expensive offers 218 | OrderBy("price DESC"). 219 | Limit(10).
220 | QueryAndClose(ctx, db, func(row *sql.Rows){ 221 | fmt.Printf("%d\t%s\t$%.2f\n", offerId, productName, price) 222 | }) 223 | if err != nil { 224 | panic(err) 225 | } 226 | ``` 227 | 228 | Use plain SQL for fancier cases: 229 | 230 | ```go 231 | var ( 232 | num int64 233 | name string 234 | value string 235 | ) 236 | err := sqlf.From("t1 JOIN t2 ON t1.num = t2.num AND t2.value IN (?, ?)", "xxx", "yyy"). 237 | Select("t1.num").To(&num). 238 | Select("t1.name").To(&name). 239 | Select("t2.value").To(&value). 240 | QueryAndClose(ctx, db, func(row *sql.Rows){ 241 | fmt.Printf("%d\t%s\t%s\n", num, name, value) 242 | }) 243 | if err != nil { 244 | panic(err) 245 | } 246 | ``` 247 | 248 | #### Subqueries 249 | 250 | Use the `SubQuery` method to add a subquery to a statement: 251 | 252 | ```go 253 | q := sqlf.From("orders o"). 254 | Select("date, region"). 255 | SubQuery("(", ") AS prev_order_date", 256 | sqlf.From("orders po"). 257 | Select("date"). 258 | Where("region = o.region"). 259 | Where("id < o.id"). 260 | OrderBy("id DESC"). 261 | Clause("LIMIT 1")). 262 | Where("date > CURRENT_DATE - interval '1 day'"). 263 | OrderBy("id DESC") 264 | fmt.Println(q.String()) 265 | q.Close() 266 | ``` 267 | 268 | Note that if a subquery uses no arguments, it's more efficient to add it as an SQL fragment: 269 | 270 | ```go 271 | q := sqlf.From("orders o"). 272 | Select("date, region"). 273 | Select("(SELECT date FROM orders po WHERE region = o.region AND id < o.id ORDER BY id DESC LIMIT 1) AS prev_order_date"). 274 | Where("date > CURRENT_DATE - interval '1 day'"). 275 | OrderBy("id DESC") 276 | // ... 277 | q.Close() 278 | ``` 279 | 280 | To select from a subquery, pass an empty string to `From` and immediately call the `SubQuery` method. 281 | 282 | The query constructed by the following example returns the top 5 news items in each section: 283 | 284 | ```go 285 | q := sqlf.Select("*"). 286 | From(""). 287 | SubQuery( 288 | "(", ") counted_news", 289 | sqlf.From("news").
290 | Select("id, section, header, score"). 291 | Select("row_number() OVER (PARTITION BY section ORDER BY score DESC) AS rating_in_section"). 292 | OrderBy("section, rating_in_section")). 293 | Where("rating_in_section <= 5") 294 | // ... 295 | q.Close() 296 | ``` 297 | 298 | #### Unions 299 | 300 | Use the `Union` method to combine the results of two queries: 301 | 302 | ```go 303 | q := sqlf.From("tasks"). 304 | Select("id, status"). 305 | Where("status = ?", "new"). 306 | Union(true, sqlf.From("tasks"). 307 | Select("id, status"). 308 | Where("status = ?", "wip")) 309 | // ... 310 | q.Close() 311 | ``` 312 | 313 | ### INSERT 314 | 315 | `sqlf` provides a `Set` method to be used both for UPDATE and INSERT statements: 316 | 317 | ```go 318 | var userId int64 319 | 320 | err := sqlf.InsertInto("users"). 321 | Set("email", "new@email.com"). 322 | Set("address", "320 Some Avenue, Somewhereville, GA, US"). 323 | Returning("id").To(&userId). 324 | Clause("ON CONFLICT (email) DO UPDATE SET address = users.address"). 325 | QueryRowAndClose(ctx, db) 326 | ``` 327 | 328 | Executing the same statement with the standard `database/sql` package looks like this: 329 | 330 | ```go 331 | var userId int64 332 | 333 | // database/sql 334 | err := db.QueryRowContext(ctx, "INSERT INTO users (email, address) VALUES ($1, $2) ON CONFLICT (email) DO UPDATE SET address = users.address RETURNING id", "new@email.com", "320 Some Avenue, Somewhereville, GA, US").Scan(&userId) 335 | ``` 336 | 337 | There are just 2 fields of a new database record to be populated, and yet it takes some time to figure out what columns are being updated and what values are to be assigned to them. 338 | 339 | In real-world cases there are tens of fields. On any update, both the list of field names and the list of values passed to `QueryRowContext` have to be reviewed and updated. It's common for values to end up misplaced.
340 | 341 | Using the `Set` method to maintain a field-value map solves this issue. 342 | 343 | #### Bulk Insert 344 | 345 | To insert multiple rows with a single query, use the `NewRow` method: 346 | 347 | ```go 348 | _, err := sqlf.InsertInto("users"). 349 | NewRow(). 350 | Set("email", "first@email.com"). 351 | Set("address", "320 Some Avenue, Somewhereville, GA, US"). 352 | NewRow(). 353 | Set("email", "second@email.com"). 354 | Set("address", "320 Some Avenue, Somewhereville, GA, US"). 355 | ExecAndClose(ctx, db) 356 | ``` 357 | 358 | ### UPDATE 359 | 360 | ```go 361 | _, err := sqlf.Update("users"). 362 | Set("email", "new@email.com"). 363 | Where("id = ?", 42). 364 | ExecAndClose(ctx, db) 365 | ``` 366 | 367 | ### DELETE 368 | 369 | ```go 370 | _, err := sqlf.DeleteFrom("products"). 371 | Where("id = ?", 42). 372 | ExecAndClose(ctx, db) 373 | ``` 374 | -------------------------------------------------------------------------------- /bench_test.go: -------------------------------------------------------------------------------- 1 | package sqlf_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/leporo/sqlf" 8 | ) 9 | 10 | var s string 11 | 12 | func BenchmarkSelectDontClose(b *testing.B) { 13 | sqlf.NoDialect.ClearCache() 14 | for i := 0; i < b.N; i++ { 15 | q := sqlf.Select("id").From("table").Where("id > ?", 42).Where("id < ?", 1000) 16 | s = q.String() 17 | } 18 | } 19 | 20 | func BenchmarkSelect(b *testing.B) { 21 | sqlf.NoDialect.ClearCache() 22 | for i := 0; i < b.N; i++ { 23 | q := sqlf.Select("id").From("table").Where("id > ?", 42).Where("id < ?", 1000) 24 | s = q.String() 25 | q.Close() 26 | } 27 | } 28 | 29 | func BenchmarkSelectPg(b *testing.B) { 30 | sqlf.PostgreSQL.ClearCache() 31 | for i := 0; i < b.N; i++ { 32 | q := sqlf.PostgreSQL.Select("id").From("table").Where("id > ?", 42).Where("id < ?", 1000) 33 | s = q.String() 34 | q.Close() 35 | } 36 | } 37 | 38 | func BenchmarkManyFields(b *testing.B) { 39 | fields := make([]string, 0, 100) 40 | 41 | for
n := 1; n <= cap(fields); n++ { 42 | fields = append(fields, fmt.Sprintf("field_%d", n)) 43 | } 44 | 45 | sqlf.NoDialect.ClearCache() 46 | 47 | b.ResetTimer() 48 | 49 | for i := 0; i < b.N; i++ { 50 | q := sqlf.Select("id").From("table").Where("id > ?", 42).Where("id < ?", 1000) 51 | for _, field := range fields { 52 | q.Select(field) 53 | } 54 | s = q.String() 55 | q.Close() 56 | } 57 | } 58 | 59 | func BenchmarkBind(b *testing.B) { 60 | type Record struct { 61 | ID int64 `db:"id"` 62 | } 63 | var u struct { 64 | Record 65 | Name string `db:"name"` 66 | } 67 | sqlf.NoDialect.ClearCache() 68 | 69 | b.ResetTimer() 70 | 71 | for i := 0; i < b.N; i++ { 72 | q := sqlf.From("table").Bind(&u).Where("id = ?", 42) 73 | s = q.String() 74 | q.Close() 75 | } 76 | } 77 | 78 | func BenchmarkManyFieldsPg(b *testing.B) { 79 | fields := make([]string, 0, 100) 80 | 81 | for n := 1; n <= cap(fields); n++ { 82 | fields = append(fields, fmt.Sprintf("field_%d", n)) 83 | } 84 | 85 | sqlf.PostgreSQL.ClearCache() 86 | 87 | b.ResetTimer() 88 | 89 | for i := 0; i < b.N; i++ { 90 | q := sqlf.PostgreSQL.Select("id").From("table").Where("id > ?", 42).Where("id < ?", 1000) 91 | for _, field := range fields { 92 | q.Select(field) 93 | } 94 | s = q.String() 95 | q.Close() 96 | } 97 | } 98 | 99 | func BenchmarkMixedOrder(b *testing.B) { 100 | sqlf.NoDialect.ClearCache() 101 | for i := 0; i < b.N; i++ { 102 | q := sqlf.Select("id").Where("id > ?", 42).From("table").Where("id < ?", 1000) 103 | s = q.String() 104 | q.Close() 105 | } 106 | } 107 | 108 | func BenchmarkBuildPg(b *testing.B) { 109 | sqlf.PostgreSQL.ClearCache() 110 | q := sqlf.PostgreSQL.Select("id").From("table").Where("id > ?", 42).Where("id < ?", 1000) 111 | 112 | for i := 0; i < b.N; i++ { 113 | q.Invalidate() 114 | s = q.String() 115 | } 116 | } 117 | 118 | func BenchmarkBuild(b *testing.B) { 119 | sqlf.NoDialect.ClearCache() 120 | q := sqlf.Select("id").From("table").Where("id > ?", 42).Where("id < ?", 1000) 121 | 122 | for i := 0;
i < b.N; i++ { 123 | q.Invalidate() 124 | s = q.String() 125 | } 126 | } 127 | 128 | func BenchmarkDest(b *testing.B) { 129 | sqlf.NoDialect.ClearCache() 130 | var ( 131 | field1 int 132 | field2 string 133 | ) 134 | for i := 0; i < b.N; i++ { 135 | q := sqlf.From("table"). 136 | Select("field1").To(&field1). 137 | Select("field2").To(&field2) 138 | q.Close() 139 | } 140 | } 141 | 142 | func selectComplex(b *testing.B, dialect *sqlf.Dialect) { 143 | dialect.ClearCache() 144 | for n := 0; n < b.N; n++ { 145 | q := dialect.Select("DISTINCT a, b, z, y, x"). 146 | From("c"). 147 | Where("(d = ? OR e = ?)", 1, "wat"). 148 | Where("f = ? and x = ?", 2, "hi"). 149 | Where("g = ?", 3). 150 | Where("h").In(1, 2, 3). 151 | GroupBy("i"). 152 | GroupBy("ii"). 153 | GroupBy("iii"). 154 | Having("j = k"). 155 | Having("jj = ?", 1). 156 | Having("jjj = ?", 2). 157 | OrderBy("l"). 158 | OrderBy("l"). 159 | OrderBy("l"). 160 | Limit(7). 161 | Offset(8) 162 | s = q.String() 163 | q.Close() 164 | } 165 | } 166 | 167 | func selectSubqueryFmt(b *testing.B, dialect *sqlf.Dialect) { 168 | dialect.ClearCache() 169 | for n := 0; n < b.N; n++ { 170 | sq := dialect.Select("id"). 171 | From("tickets"). 172 | Where("subdomain_id = ? and (state = ? or state = ?)", 1, "open", "spam") 173 | subQuery := sq.String() 174 | 175 | q := dialect.Select("DISTINCT a, b"). 176 | Select(fmt.Sprintf("(%s) AS subq", subQuery)). 177 | From("c"). 178 | Where("f = ? and x = ?", 2, "hi"). 179 | Where("g = ?", 3). 180 | OrderBy("l"). 181 | OrderBy("l"). 182 | Limit(7). 183 | Offset(8) 184 | s = q.String() 185 | q.Close() 186 | sq.Close() 187 | } 188 | } 189 | 190 | func selectSubquery(b *testing.B, dialect *sqlf.Dialect) { 191 | dialect.ClearCache() 192 | for n := 0; n < b.N; n++ { 193 | q := dialect.Select("DISTINCT a, b"). 194 | SubQuery("(", ") AS subq", sqlf.Select("id"). 195 | From("tickets"). 196 | Where("subdomain_id = ? and (state = ? or state = ?)", 1, "open", "spam")). 197 | From("c"). 198 | Where("f = ?
and x = ?", 2, "hi"). 199 | Where("g = ?", 3). 200 | OrderBy("l"). 201 | OrderBy("l"). 202 | Limit(7). 203 | Offset(8) 204 | s = q.String() 205 | q.Close() 206 | } 207 | } 208 | 209 | func BenchmarkSelectComplex(b *testing.B) { 210 | selectComplex(b, sqlf.NoDialect) 211 | } 212 | 213 | func BenchmarkSelectComplexPg(b *testing.B) { 214 | selectComplex(b, sqlf.PostgreSQL) 215 | } 216 | 217 | func BenchmarkSelectSubqueryFmt(b *testing.B) { 218 | selectSubqueryFmt(b, sqlf.NoDialect) 219 | } 220 | 221 | func BenchmarkSelectSubqueryFmtPostgreSQL(b *testing.B) { 222 | selectSubqueryFmt(b, sqlf.PostgreSQL) 223 | } 224 | 225 | func BenchmarkSelectSubquery(b *testing.B) { 226 | selectSubquery(b, sqlf.NoDialect) 227 | } 228 | 229 | func BenchmarkSelectSubqueryPostgreSQL(b *testing.B) { 230 | selectSubquery(b, sqlf.PostgreSQL) 231 | } 232 | 233 | func BenchmarkWith(b *testing.B) { 234 | sqlf.NoDialect.ClearCache() 235 | for n := 0; n < b.N; n++ { 236 | q := sqlf.From("orders"). 237 | With("regional_sales", 238 | sqlf.From("orders"). 239 | Select("region, SUM(amount) AS total_sales"). 240 | GroupBy("region")). 241 | With("top_regions", 242 | sqlf.From("regional_sales"). 243 | Select("region"). 244 | Where("total_sales > (SELECT SUM(total_sales)/10 FROM regional_sales)")). 245 | Select("region"). 246 | Select("product"). 247 | Select("SUM(quantity) AS product_units"). 248 | Select("SUM(amount) AS product_sales"). 249 | Where("region IN (SELECT region FROM top_regions)"). 250 | GroupBy("region, product") 251 | s = q.String() 252 | q.Close() 253 | } 254 | } 255 | 256 | func BenchmarkIn(b *testing.B) { 257 | a := make([]interface{}, 50) 258 | for i := 0; i < len(a); i++ { 259 | a[i] = i + 1 260 | } 261 | sqlf.NoDialect.ClearCache() 262 | b.ResetTimer() 263 | for n := 0; n < b.N; n++ { 264 | q := sqlf.From("orders"). 265 | Select("id"). 266 | Where("status").In(a...)
267 | s = q.String() 268 | q.Close() 269 | } 270 | } 271 | -------------------------------------------------------------------------------- /cache.go: -------------------------------------------------------------------------------- 1 | package sqlf 2 | 3 | import ( 4 | "github.com/valyala/bytebufferpool" 5 | ) 6 | 7 | // sqlCache maps a statement key (copied out of a builder buffer) to the SQL 8 | // text cached for it. It is only accessed while holding Dialect.cacheLock. 9 | type sqlCache map[string]string 10 | 11 | /* 12 | ClearCache clears the statement cache. 13 | 14 | In most cases you don't need to care about it. It's there to 15 | let caller free memory when a caller executes zillions of unique 16 | SQL statements. 17 | */ 18 | func (d *Dialect) ClearCache() { 19 | d.cacheLock.Lock() 20 | d.cache = make(sqlCache) 21 | d.cacheLock.Unlock() 22 | } 23 | 24 | // getCache returns the cache map, creating it on first use. 25 | // 26 | // ClearCache replaces the d.cache field under the write lock, so the field is 27 | // read and lazily initialized under that same lock here; an unlocked read 28 | // would be a data race. Callers must still hold cacheLock while reading or 29 | // writing entries of the returned map. 30 | func (d *Dialect) getCache() sqlCache { 31 | d.cacheLock.Lock() 32 | defer d.cacheLock.Unlock() 33 | if d.cache == nil { 34 | d.cache = make(sqlCache) 35 | } 36 | return d.cache 37 | } 38 | 39 | // getCachedSQL returns the SQL text cached for the key held in buf, if any. 40 | // 41 | // The lookup key comes from bufToString, which presumably aliases buf.B 42 | // without copying; that is fine because it is used only for the lookup, 43 | // done entirely under the read lock. Indexing a nil map is valid in Go, so 44 | // nothing needs to be initialized before the first putCachedSQL. 45 | func (d *Dialect) getCachedSQL(buf *bytebufferpool.ByteBuffer) (string, bool) { 46 | s := bufToString(&buf.B) 47 | 48 | d.cacheLock.RLock() 49 | res, ok := d.cache[s] 50 | d.cacheLock.RUnlock() 51 | return res, ok 52 | } 53 | 54 | // putCachedSQL stores sql under a copy of the key held in buf, creating the 55 | // cache map if needed. The key is copied with string(buf.B) because buf 56 | // appears to be a pooled buffer whose bytes may be reused afterwards. 57 | func (d *Dialect) putCachedSQL(buf *bytebufferpool.ByteBuffer, sql string) { 58 | key := string(buf.B) 59 | 60 | d.cacheLock.Lock() 61 | if d.cache == nil { 62 | d.cache = make(sqlCache) 63 | } 64 | d.cache[key] = sql 65 | d.cacheLock.Unlock() 66 | } 67 | -------------------------------------------------------------------------------- /cache_test.go: -------------------------------------------------------------------------------- 1 | package sqlf 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | ) 8 | 9 | func TestSQLCache(t *testing.T) { 10 | buf := getBuffer() 11 | defer putBuffer(buf) 12 | 13 | buf.WriteString("test") 14 | _, ok := defaultDialect.getCachedSQL(buf) 15 | require.False(t, ok) 16 | 17 | defaultDialect.putCachedSQL(buf, "test SQL") 18 | sql, ok := defaultDialect.getCachedSQL(buf) 19 | require.True(t, ok) 20 | require.Equal(t, "test SQL", sql) 21 | 22 | defaultDialect.ClearCache() 23 | }
24 | -------------------------------------------------------------------------------- /dialect.go: -------------------------------------------------------------------------------- 1 | package sqlf 2 | 3 | import ( 4 | "strconv" 5 | "strings" 6 | "sync" 7 | "sync/atomic" 8 | "unsafe" 9 | ) 10 | 11 | // Dialect defines how an SQL statement is to be built. 12 | // 13 | // NoDialect is a default statement builder mode. 14 | // No SQL fragments will be altered. 15 | // PostgreSQL mode can be set for a statement: 16 | // 17 | // q := sqlf.PostgreSQL.From("table").Select("field") 18 | // ... 19 | // q.Close() 20 | // 21 | // or as default mode: 22 | // 23 | // sqlf.SetDialect(sqlf.PostgreSQL) 24 | // ... 25 | // q := sqlf.From("table").Select("field") 26 | // q.Close() 27 | // 28 | // When PostgreSQL mode is activated, ? placeholders are 29 | // replaced with numbered positional arguments like $1, $2... 30 | type Dialect struct { 31 | cacheOnce sync.Once 32 | cacheLock sync.RWMutex 33 | cache sqlCache 34 | } 35 | 36 | var ( 37 | // NoDialect is the default statement builder mode. 38 | NoDialect = &Dialect{} 39 | // PostgreSQL mode is to be used to automatically replace ? placeholders with $1, $2... 40 | PostgreSQL = &Dialect{} 41 | ) 42 | 43 | var defaultDialect = NoDialect 44 | 45 | /* 46 | SetDialect selects a Dialect to be used by default. 47 | 48 | Dialect can be one of sqlf.NoDialect or sqlf.PostgreSQL. 49 | 50 | sqlf.SetDialect(sqlf.PostgreSQL) 51 | */ 52 | func SetDialect(newDefaultDialect *Dialect) { 53 | // NOTE(review): the store is atomic, but that only prevents a data race if 54 | // every read of defaultDialect is an atomic load too; plain reads elsewhere 55 | // in the package would still race with this call. Verify the readers. 56 | atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&defaultDialect)), unsafe.Pointer(newDefaultDialect)) 57 | } 58 | 59 | /* 60 | New starts an SQL statement with an arbitrary verb. 61 | 62 | Use From, Select, InsertInto or DeleteFrom methods to create 63 | an instance of an SQL statement builder for common statements.
64 | */ 65 | func (d *Dialect) New(verb string, args ...interface{}) *Stmt { 66 | q := getStmt(d) 67 | q.addChunk(posSelect, verb, "", args, ", ") 68 | return q 69 | } 70 | 71 | /* 72 | With starts a statement prepended by WITH clause 73 | and closes a subquery passed as an argument. 74 | */ 75 | func (d *Dialect) With(queryName string, query *Stmt) *Stmt { 76 | q := getStmt(d) 77 | return q.With(queryName, query) 78 | } 79 | 80 | /* 81 | From starts a SELECT statement. 82 | */ 83 | func (d *Dialect) From(expr string, args ...interface{}) *Stmt { 84 | q := getStmt(d) 85 | return q.From(expr, args...) 86 | } 87 | 88 | /* 89 | Select starts a SELECT statement. 90 | 91 | Consider using From method to start a SELECT statement - you may find 92 | it easier to read and maintain. 93 | */ 94 | func (d *Dialect) Select(expr string, args ...interface{}) *Stmt { 95 | q := getStmt(d) 96 | return q.Select(expr, args...) 97 | } 98 | 99 | // Update starts an UPDATE statement. 100 | func (d *Dialect) Update(tableName string) *Stmt { 101 | q := getStmt(d) 102 | return q.Update(tableName) 103 | } 104 | 105 | // InsertInto starts an INSERT statement. 106 | func (d *Dialect) InsertInto(tableName string) *Stmt { 107 | q := getStmt(d) 108 | return q.InsertInto(tableName) 109 | } 110 | 111 | // DeleteFrom starts a DELETE statement. 112 | func (d *Dialect) DeleteFrom(tableName string) *Stmt { 113 | q := getStmt(d) 114 | return q.DeleteFrom(tableName) 115 | } 116 | 117 | // writePg copies s into buf, replacing every ? placeholder with a numbered 118 | // one ($argNo, $argNo+1, ...) and every escaped \? with a literal ?. 119 | // It returns the next unused placeholder number and the first write error. 120 | func writePg(argNo int, s []byte, buf *strings.Builder) (int, error) { 121 | var err error 122 | start := 0 123 | // Iterate by runes; pos is a byte offset into s. Both '?' and '\\' are 124 | // single-byte ASCII, so slicing s at pos always lands on a rune boundary.
125 | for pos, r := range bufToString(&s) { 126 | // Skip the '?' already consumed by a preceding \? escape. 127 | if start > pos { 128 | continue 129 | } 130 | switch r { 131 | case '\\': 132 | if pos < len(s)-1 && s[pos+1] == '?' { 133 | _, err = buf.Write(s[start:pos]) 134 | if err == nil { 135 | err = buf.WriteByte('?') 136 | } 137 | start = pos + 2 138 | } 139 | case '?': 140 | _, err = buf.Write(s[start:pos]) 141 | start = pos + 1 142 | if err == nil { 143 | err = buf.WriteByte('$') 144 | if err == nil { 145 | buf.WriteString(strconv.Itoa(argNo)) 146 | argNo++ 147 | } 148 | } 149 | } 150 | if err != nil { 151 | break 152 | } 153 | } 154 | if err == nil && start < len(s) { 155 | _, err = buf.Write(s[start:]) 156 | } 157 | return argNo, err 158 | } 159 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // Package sqlf is an SQL statement builder and executor. 2 | /* 3 | 4 | SQL Statement Builder 5 | 6 | sqlf statement builder provides a way to: 7 | 8 | - Combine SQL statements from fragments of raw SQL and arguments that match those fragments, 9 | 10 | - Map columns to variables to be referenced by Scan, 11 | 12 | - Convert ? placeholders into numbered ones for PostgreSQL ($1, $2, etc).
13 | */
14 | package sqlf
15 | 
--------------------------------------------------------------------------------
/example_test.go:
--------------------------------------------------------------------------------
1 | package sqlf_test
2 | 
3 | import (
4 | 	"context"
5 | 	"database/sql"
6 | 	"fmt"
7 | 
8 | 	"github.com/leporo/sqlf"
9 | )
10 | 
11 | type dummyDB int
12 | 
13 | func (db *dummyDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
14 | 	return nil, nil
15 | }
16 | 
17 | func (db *dummyDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
18 | 	return nil, nil
19 | }
20 | 
21 | func (db *dummyDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
22 | 	return nil
23 | }
24 | 
25 | var (
26 | 	db  = new(dummyDB)
27 | 	ctx = context.Background()
28 | )
29 | 
30 | func Example() {
31 | 	var (
32 | 		region       string
33 | 		product      string
34 | 		productUnits int
35 | 		productSales float64
36 | 	)
37 | 
38 | 	sqlf.SetDialect(sqlf.PostgreSQL)
39 | 
40 | 	err := sqlf.From("orders").
41 | 		With("regional_sales",
42 | 			sqlf.From("orders").
43 | 				Select("region, SUM(amount) AS total_sales").
44 | 				GroupBy("region")).
45 | 		With("top_regions",
46 | 			sqlf.From("regional_sales").
47 | 				Select("region").
48 | 				Where("total_sales > (SELECT SUM(total_sales)/10 FROM regional_sales)")).
49 | 		// Map query fields to variables
50 | 		Select("region").To(&region).
51 | 		Select("product").To(&product).
52 | 		Select("SUM(quantity)").To(&productUnits).
53 | 		Select("SUM(amount) AS product_sales").To(&productSales).
54 | 		//
55 | 		Where("region IN (SELECT region FROM top_regions)").
56 | 		GroupBy("region, product").
57 | 		OrderBy("product_sales DESC").
58 | 		// Execute the query
59 | 		QueryAndClose(ctx, db, func(row *sql.Rows) {
60 | 			// Callback function is called for every returned row.
61 | 			// Row values are scanned automatically to bound variables.
62 | 			fmt.Printf("%s\t%s\t%d\t$%.2f\n", region, product, productUnits, productSales)
63 | 		})
64 | 	if err != nil {
65 | 		panic(err)
66 | 	}
67 | }
68 | 
69 | func ExampleStmt_OrderBy() {
70 | 	q := sqlf.Select("id").From("table").OrderBy("id", "name DESC")
71 | 	fmt.Println(q.String())
72 | 	// Output: SELECT id FROM table ORDER BY id, name DESC
73 | }
74 | 
75 | func ExampleStmt_Limit() {
76 | 	q := sqlf.Select("id").From("table").Limit(10)
77 | 	fmt.Println(q.String())
78 | 	// Output: SELECT id FROM table LIMIT ?
79 | }
80 | 
81 | func ExampleStmt_Offset() {
82 | 	q := sqlf.Select("id").From("table").Limit(10).Offset(10)
83 | 	fmt.Println(q.String())
84 | 	// Output: SELECT id FROM table LIMIT ? OFFSET ?
85 | }
86 | 
87 | func ExampleStmt_Paginate() {
88 | 	q := sqlf.Select("id").From("table").Paginate(5, 10)
89 | 	fmt.Println(q.String(), q.Args())
90 | 	q.Close()
91 | 
92 | 	q = sqlf.Select("id").From("table").Paginate(1, 10)
93 | 	fmt.Println(q.String(), q.Args())
94 | 	q.Close()
95 | 
96 | 	// Zero and negative values are replaced with 1
97 | 	q = sqlf.Select("id").From("table").Paginate(-1, -1)
98 | 	fmt.Println(q.String(), q.Args())
99 | 	q.Close()
100 | 
101 | 	// Output:
102 | 	// SELECT id FROM table LIMIT ? OFFSET ? [10 40]
103 | 	// SELECT id FROM table LIMIT ? [10]
104 | 	// SELECT id FROM table LIMIT ? [1]
105 | }
106 | 
107 | func ExampleStmt_Update() {
108 | 	q := sqlf.Update("table").Set("field1", "newvalue").Where("id = ?", 42)
109 | 	fmt.Println(q.String(), q.Args())
110 | 	q.Close()
111 | 	// Output:
112 | 	// UPDATE table SET field1=? WHERE id = ? [newvalue 42]
113 | }
114 | 
115 | func ExampleStmt_SetExpr() {
116 | 	q := sqlf.Update("table").SetExpr("field1", "field2 + 1").Where("id = ?", 42)
117 | 	fmt.Println(q.String())
118 | 	fmt.Println(q.Args())
119 | 	q.Close()
120 | 	// Output:
121 | 	// UPDATE table SET field1=field2 + 1 WHERE id = ?
122 | 	// [42]
123 | }
124 | 
125 | func ExampleStmt_InsertInto() {
126 | 	q := sqlf.InsertInto("table").
127 | 		Set("field1", "newvalue").
128 | 		SetExpr("field2", "field2 + 1")
129 | 	fmt.Println(q.String())
130 | 	fmt.Println(q.Args())
131 | 	q.Close()
132 | 	// Output:
133 | 	// INSERT INTO table ( field1, field2 ) VALUES ( ?, field2 + 1 )
134 | 	// [newvalue]
135 | }
136 | 
137 | func ExampleStmt_DeleteFrom() {
138 | 	q := sqlf.DeleteFrom("table").Where("id = ?", 42)
139 | 	fmt.Println(q.String())
140 | 	fmt.Println(q.Args())
141 | 	q.Close()
142 | 	// Output:
143 | 	// DELETE FROM table WHERE id = ?
144 | 	// [42]
145 | }
146 | 
147 | func ExampleStmt_GroupBy() {
148 | 	q := sqlf.From("incomes").
149 | 		Select("source, sum(amount) as s").
150 | 		Where("amount > ?", 42).
151 | 		GroupBy("source")
152 | 	fmt.Println(q.String())
153 | 	fmt.Println(q.Args())
154 | 	q.Close()
155 | 	// Output:
156 | 	// SELECT source, sum(amount) as s FROM incomes WHERE amount > ? GROUP BY source
157 | 	// [42]
158 | }
159 | 
160 | func ExampleStmt_Having() {
161 | 	q := sqlf.From("incomes").
162 | 		Select("source, sum(amount) as s").
163 | 		Where("amount > ?", 42).
164 | 		GroupBy("source").
165 | 		Having("s > ?", 100)
166 | 	fmt.Println(q.String())
167 | 	fmt.Println(q.Args())
168 | 	q.Close()
169 | 	// Output:
170 | 	// SELECT source, sum(amount) as s FROM incomes WHERE amount > ? GROUP BY source HAVING s > ?
171 | 	// [42 100]
172 | }
173 | 
174 | func ExampleStmt_Returning() {
175 | 	var newID int
176 | 	q := sqlf.InsertInto("table").
177 | 		Set("field1", "newvalue").
178 | 		Returning("id").To(&newID)
179 | 	fmt.Println(q.String(), q.Args())
180 | 	q.Close()
181 | 	// Output:
182 | 	// INSERT INTO table ( field1 ) VALUES ( ? ) RETURNING id [newvalue]
183 | }
184 | 
185 | func ExamplePostgreSQL() {
186 | 	q := sqlf.PostgreSQL.From("table").Select("field").Where("id = ?", 42)
187 | 	fmt.Println(q.String())
188 | 	q.Close()
189 | 	// Output:
190 | 	// SELECT field FROM table WHERE id = $1
191 | }
192 | 
193 | func ExampleStmt_With() {
194 | 	q := sqlf.From("orders").
195 | 		With("regional_sales",
196 | 			sqlf.From("orders").
197 | 				Select("region, SUM(amount) AS total_sales").
198 | 				GroupBy("region")).
199 | 		With("top_regions",
200 | 			sqlf.From("regional_sales").
201 | 				Select("region").
202 | 				Where("total_sales > (SELECT SUM(total_sales)/10 FROM regional_sales)")).
203 | 		Select("region").
204 | 		Select("product").
205 | 		Select("SUM(quantity) AS product_units").
206 | 		Select("SUM(amount) AS product_sales").
207 | 		Where("region IN (SELECT region FROM top_regions)").
208 | 		GroupBy("region, product")
209 | 	fmt.Println(q.String())
210 | 	q.Close()
211 | 	// Output:
212 | 	// WITH regional_sales AS (SELECT region, SUM(amount) AS total_sales FROM orders GROUP BY region), top_regions AS (SELECT region FROM regional_sales WHERE total_sales > (SELECT SUM(total_sales)/10 FROM regional_sales)) SELECT region, product, SUM(quantity) AS product_units, SUM(amount) AS product_sales FROM orders WHERE region IN (SELECT region FROM top_regions) GROUP BY region, product
213 | }
214 | 
215 | func ExampleStmt_From() {
216 | 	q := sqlf.Select("*").
217 | 		From("").
218 | 		SubQuery(
219 | 			"(", ") counted_news",
220 | 			sqlf.From("news").
221 | 				Select("id, section, header, score").
222 | 				Select("row_number() OVER (PARTITION BY section ORDER BY score DESC) AS rating_in_section").
223 | 				OrderBy("section, rating_in_section")).
224 | 		Where("rating_in_section <= 5")
225 | 	fmt.Println(q.String())
226 | 	q.Close()
227 | 	// Output:
228 | 	// SELECT * FROM (SELECT id, section, header, score, row_number() OVER (PARTITION BY section ORDER BY score DESC) AS rating_in_section FROM news ORDER BY section, rating_in_section) counted_news WHERE rating_in_section <= 5
229 | }
230 | 
231 | func ExampleStmt_SubQuery() {
232 | 	q := sqlf.From("orders o").
233 | 		Select("date, region").
234 | 		SubQuery("(", ") AS prev_order_date",
235 | 			sqlf.From("orders po").
236 | 				Select("date").
237 | 				Where("region = o.region").
238 | 				Where("id < o.id").
239 | 				OrderBy("id DESC").
240 | 				Clause("LIMIT 1")).
241 | 		Where("date > CURRENT_DATE - interval '1 day'").
242 | 		OrderBy("id DESC")
243 | 	fmt.Println(q.String())
244 | 	q.Close()
245 | 
246 | 	// Output:
247 | 	// SELECT date, region, (SELECT date FROM orders po WHERE region = o.region AND id < o.id ORDER BY id DESC LIMIT 1) AS prev_order_date FROM orders o WHERE date > CURRENT_DATE - interval '1 day' ORDER BY id DESC
248 | }
249 | 
250 | func ExampleStmt_Clause() {
251 | 	q := sqlf.From("empsalary").
252 | 		Select("sum(salary) OVER w").
253 | 		Clause("WINDOW w AS (PARTITION BY depname ORDER BY salary DESC)")
254 | 	fmt.Println(q.String())
255 | 	q.Close()
256 | 
257 | 	// Output:
258 | 	// SELECT sum(salary) OVER w FROM empsalary WINDOW w AS (PARTITION BY depname ORDER BY salary DESC)
259 | }
260 | 
261 | func ExampleStmt_QueryRowAndClose() {
262 | 	type Offer struct {
263 | 		id        int64
264 | 		productID int64
265 | 		price     float64
266 | 		isDeleted bool
267 | 	}
268 | 
269 | 	var o Offer
270 | 
271 | 	err := sqlf.From("offers").
272 | 		Select("id").To(&o.id).
273 | 		Select("product_id").To(&o.productID).
274 | 		Select("price").To(&o.price).
275 | 		Select("is_deleted").To(&o.isDeleted).
276 | 		Where("id = ?", 42).
277 | 		QueryRowAndClose(ctx, db)
278 | 	if err != nil {
279 | 		panic(err)
280 | 	}
281 | }
282 | 
283 | func ExampleStmt_Bind() {
284 | 	type Offer struct {
285 | 		ID        int64   `db:"id"`
286 | 		ProductID int64   `db:"product_id"`
287 | 		Price     float64 `db:"price"`
288 | 		IsDeleted bool    `db:"is_deleted"`
289 | 	}
290 | 
291 | 	var o Offer
292 | 
293 | 	err := sqlf.From("offers").
294 | 		Bind(&o).
295 | 		Where("id = ?", 42).
296 | 		QueryRowAndClose(ctx, db)
297 | 	if err != nil {
298 | 		panic(err)
299 | 	}
300 | }
301 | 
302 | func ExampleStmt_In() {
303 | 	q := sqlf.From("tasks").
304 | 		Select("id, status").
305 | 		Where("status").In("new", "pending", "wip")
306 | 	fmt.Println(q.String())
307 | 	fmt.Println(q.Args())
308 | 	q.Close()
309 | 
310 | 	// Output:
311 | 	// SELECT id, status FROM tasks WHERE status IN (?,?,?)
312 | 	// [new pending wip]
313 | }
314 | 
315 | func ExampleStmt_Union() {
316 | 	q := sqlf.From("tasks").
317 | 		Select("id, status").
318 | 		Where("status = ?", "new").
319 | 		Union(true, sqlf.From("tasks").
320 | 			Select("id, status").
321 | 			Where("status = ?", "pending")).
322 | 		Union(true, sqlf.From("tasks").
323 | 			Select("id, status").
324 | 			Where("status = ?", "wip")).
325 | 		OrderBy("id")
326 | 	fmt.Println(q.String())
327 | 	fmt.Println(q.Args())
328 | 	q.Close()
329 | 
330 | 	// Output:
331 | 	// SELECT id, status FROM tasks WHERE status = ? UNION ALL SELECT id, status FROM tasks WHERE status = ? UNION ALL SELECT id, status FROM tasks WHERE status = ? ORDER BY id
332 | 	// [new pending wip]
333 | }
334 | 
--------------------------------------------------------------------------------
/executor.go:
--------------------------------------------------------------------------------
1 | package sqlf
2 | 
3 | import (
4 | 	"context"
5 | 	"database/sql"
6 | )
7 | 
8 | // Executor performs SQL queries.
9 | // It's an interface accepted by Query, QueryRow and Exec methods.
10 | // *sql.DB, *sql.Conn and *sql.Tx all implement it and can be passed as an executor.
11 | type Executor interface {
12 | 	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
13 | 	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
14 | 	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
15 | }
16 | 
17 | // Query executes the statement (a nil ctx is treated as context.Background()).
18 | // For every row of a returned dataset it calls a handler function.
19 | // If scan targets were set via To method calls, Query method
20 | // executes rows.Scan right before calling a handler function.
21 | func (q *Stmt) Query(ctx context.Context, db Executor, handler func(rows *sql.Rows)) error {
22 | 	if ctx == nil {
23 | 		ctx = context.Background()
24 | 	}
25 | 
26 | 	// Fetch rows
27 | 	rows, err := db.QueryContext(ctx, q.String(), q.args...)
28 | 	if err != nil {
29 | 		return err
30 | 	}
31 | 
32 | 	// Iterate through rows of returned dataset
33 | 	for rows.Next() {
34 | 		if len(q.dest) > 0 {
35 | 			err = rows.Scan(q.dest...)
36 | 			if err != nil {
37 | 				break
38 | 			}
39 | 		}
40 | 		// Call a callback function
41 | 		handler(rows)
42 | 	}
43 | 	// Always close rows: a Close error may matter when multiple statements
44 | 	// are executed in a single batch and rows were written as well as read.
45 | 	closeErr := rows.Close()
46 | 
47 | 	// A scan error is the root cause of an aborted iteration,
48 | 	// so report it in preference to a Close error.
49 | 	if err != nil {
50 | 		return err
51 | 	}
52 | 	if closeErr != nil {
53 | 		return closeErr
54 | 	}
55 | 	// Check for errors during row iteration.
56 | 	return rows.Err()
57 | }
58 | 
59 | // QueryAndClose executes the statement and returns all the resources that
60 | // can be reused to a pool. Do not call any Stmt methods after this call.
61 | // For every row of a returned dataset QueryAndClose executes a handler function.
62 | // If scan targets were set via To method calls, QueryAndClose method
63 | // executes rows.Scan right before calling a handler function.
64 | func (q *Stmt) QueryAndClose(ctx context.Context, db Executor, handler func(rows *sql.Rows)) error {
65 | 	err := q.Query(ctx, db, handler)
66 | 	q.Close()
67 | 	return err
68 | }
69 | 
70 | // QueryRow executes the statement via Executor methods and scans values to
71 | // variables bound via To method calls (sql.ErrNoRows if no row matched).
72 | func (q *Stmt) QueryRow(ctx context.Context, db Executor) error {
73 | 	if ctx == nil {
74 | 		ctx = context.Background()
75 | 	}
76 | 	row := db.QueryRowContext(ctx, q.String(), q.args...)
77 | 
78 | 	return row.Scan(q.dest...)
79 | }
80 | 
81 | // QueryRowAndClose executes the statement via Executor methods
82 | // and scans values to variables bound via To method calls.
83 | // All the objects allocated by query builder are moved to a pool
84 | // to be reused.
85 | //
86 | // Do not call any Stmt methods after this call.
87 | func (q *Stmt) QueryRowAndClose(ctx context.Context, db Executor) error {
88 | 	err := q.QueryRow(ctx, db)
89 | 	q.Close()
90 | 	return err
91 | }
92 | 
93 | // Exec executes the statement (a nil ctx is treated as context.Background()).
94 | func (q *Stmt) Exec(ctx context.Context, db Executor) (sql.Result, error) {
95 | 	if ctx == nil {
96 | 		ctx = context.Background()
97 | 	}
98 | 	return db.ExecContext(ctx, q.String(), q.args...)
99 | }
100 | 
101 | // ExecAndClose executes the statement and releases all the objects
102 | // and buffers allocated by statement builder back to a pool.
103 | //
104 | // Do not call any Stmt methods after this call.
105 | func (q *Stmt) ExecAndClose(ctx context.Context, db Executor) (sql.Result, error) {
106 | 	res, err := q.Exec(ctx, db)
107 | 	q.Close()
108 | 	return res, err
109 | }
110 | 
--------------------------------------------------------------------------------
/executor_test.go:
--------------------------------------------------------------------------------
1 | package sqlf_test
2 | 
3 | import (
4 | 	"context"
5 | 	"database/sql"
6 | 	"fmt"
7 | 	"log"
8 | 	"os"
9 | 	"testing"
10 | 	"time"
11 | 
12 | 	"github.com/leporo/sqlf"
13 | 	_ "github.com/mattn/go-sqlite3"
14 | 	"github.com/stretchr/testify/require"
15 | )
16 | 
17 | type dbEnv struct {
18 | 	driver string
19 | 	db     *sql.DB
20 | 	sqlf   *sqlf.Dialect
21 | }
22 | 
23 | type dbConfig struct {
24 | 	driver  string
25 | 	envVar  string
26 | 	defDSN  string
27 | 	dialect *sqlf.Dialect
28 | }
29 | 
30 | var dbList = []dbConfig{
31 | 	{
32 | 		driver:  "sqlite3",
33 | 		envVar:  "SQLF_SQLITE_DSN",
34 | 		defDSN:  ":memory:",
35 | 		dialect: sqlf.NoDialect,
36 | 	},
37 | }
38 | 
39 | var envs = make([]dbEnv, 0, len(dbList))
40 | 
41 | func init() {
42 | 	connect()
43 | }
44 | 
45 | func connect() {
46 | 	for _, config := range dbList {
47 | 		dsn := os.Getenv(config.envVar)
48 | 		if dsn == "" {
49 | 			dsn = config.defDSN
50 | 		}
51 | 		if dsn == "" || dsn == "skip" {
52 | 			fmt.Printf("Skipping %s tests.\n", config.driver)
53 | 			continue
54 | 		}
55 | 		db, err := sql.Open(config.driver, dsn)
56 | 		if err != nil {
57 | 			log.Fatalf("Invalid %s DSN: %v", config.driver, err)
58 | 		}
59 | 		db.SetMaxOpenConns(1) // every ":memory:" SQLite connection is a separate, empty database
60 | 		ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
61 | 		err = db.PingContext(ctx)
62 | 		cancel()
63 | 		if err != nil {
64 | 			log.Fatalf("Unable to connect to %s: %v", config.driver, err)
65 | 		}
66 | 		envs = append(envs, dbEnv{
67 | 			driver: config.driver,
68 | 			db:     db,
69 | 			sqlf:   config.dialect,
70 | 		})
71 | 	}
72 | }
73 | 
74 | func execScript(db *sql.DB, script []string) (err error) {
75 | 	for _, stmt := range script {
76 | 		_, err = db.Exec(stmt)
77 | 		if err != nil {
78 | 			break
79 | 		}
80 | 	}
81 | 	return err
82 | }
83 | 
84 | func forEveryDB(t *testing.T, test func(ctx context.Context, env *dbEnv)) {
85 | 	for _, ctx := range []context.Context{nil, context.Background()} { // nil exercises the context.Background() fallback
86 | 		for n := range envs {
87 | 			runWithSchema(t, ctx, &envs[n], test)
88 | 		}
89 | 	}
90 | }
91 | 
92 | // runWithSchema creates and fills the schema, runs test and drops the schema
93 | // in a defer, so a test stopped by t.FailNow cannot leak tables into the next one.
94 | func runWithSchema(t *testing.T, ctx context.Context, env *dbEnv, test func(ctx context.Context, env *dbEnv)) {
95 | 	defer func() {
96 | 		if err := execScript(env.db, sqlSchemaDrop); err != nil {
97 | 			t.Errorf("Failed to drop a %s schema: %v", env.driver, err)
98 | 		}
99 | 	}()
100 | 	if err := execScript(env.db, sqlSchemaCreate); err != nil {
101 | 		t.Errorf("Failed to create a %s schema: %v", env.driver, err)
102 | 	} else if err := execScript(env.db, sqlFillDb); err != nil {
103 | 		t.Errorf("Failed to populate a %s database: %v", env.driver, err)
104 | 	} else {
105 | 		test(ctx, env)
106 | 	}
107 | }
108 | 
109 | func TestQueryRow(t *testing.T) {
110 | 	forEveryDB(t, func(ctx context.Context, env *dbEnv) {
111 | 		var name string
112 | 		q := env.sqlf.From("users").
113 | 			Select("name").To(&name).
114 | 			Where("id = ?", 1)
115 | 		err := q.QueryRow(ctx, env.db)
116 | 		q.Close()
117 | 		require.NoError(t, err, "Failed to execute a query: %v", err)
118 | 		require.Equal(t, "User 1", name)
119 | 	})
120 | }
121 | 
122 | func TestQueryRowAndClose(t *testing.T) {
123 | 	forEveryDB(t, func(ctx context.Context, env *dbEnv) {
124 | 		var name string
125 | 		err := env.sqlf.From("users").
126 | 			Select("name").To(&name).
127 | 			Where("id = ?", 1).
128 | 			QueryRowAndClose(ctx, env.db)
129 | 		require.NoError(t, err, "Failed to execute a query: %v", err)
130 | 		require.Equal(t, "User 1", name)
131 | 	})
132 | }
133 | 
134 | func TestBind(t *testing.T) {
135 | 	forEveryDB(t, func(ctx context.Context, env *dbEnv) {
136 | 		var u struct {
137 | 			ID   int64  `db:"id"`
138 | 			Name string `db:"name"`
139 | 		}
140 | 		err := env.sqlf.From("users").
141 | 			Bind(&u).
142 | 			Where("id = ?", 2).
143 | 			QueryRowAndClose(ctx, env.db)
144 | 		require.NoError(t, err, "Failed to execute a query: %v", err)
145 | 		require.Equal(t, "User 2", u.Name)
146 | 		require.EqualValues(t, 2, u.ID)
147 | 	})
148 | }
149 | 
150 | func TestBindNested(t *testing.T) {
151 | 	forEveryDB(t, func(ctx context.Context, env *dbEnv) {
152 | 		type Parent struct {
153 | 			ID int64 `db:"id"`
154 | 		}
155 | 		var u struct {
156 | 			Parent
157 | 			Name string `db:"name"`
158 | 		}
159 | 		err := env.sqlf.From("users").
160 | 			Bind(&u).
161 | 			Where("id = ?", 2).
162 | 			QueryRowAndClose(ctx, env.db)
163 | 		require.NoError(t, err, "Failed to execute a query: %v", err)
164 | 		require.Equal(t, "User 2", u.Name)
165 | 		require.EqualValues(t, 2, u.ID)
166 | 	})
167 | }
168 | 
169 | func TestExec(t *testing.T) {
170 | 	forEveryDB(t, func(ctx context.Context, env *dbEnv) {
171 | 		var (
172 | 			userID int
173 | 			count  int
174 | 		)
175 | 		q := env.sqlf.From("users").
176 | 			Select("count(*)").To(&count).
177 | 			Select("min(id)").To(&userID)
178 | 		defer q.Close()
179 | 
180 | 		err := q.QueryRow(ctx, env.db)
181 | 		require.NoError(t, err, "Failed to count rows. %s error: %v", env.driver, err)
182 | 		require.Equal(t, 3, count)
183 | 
184 | 		_, err = env.sqlf.DeleteFrom("users").
185 | 			Where("id = ?", userID).
186 | 			ExecAndClose(ctx, env.db)
187 | 		require.NoError(t, err, "Failed to delete a row. %s error: %v", env.driver, err)
188 | 
189 | 		// Re-check the number of remaining rows
190 | 		count = 0
191 | 		err = q.QueryRow(ctx, env.db)
192 | 		require.NoError(t, err, "Failed to re-count rows. %s error: %v", env.driver, err)
193 | 		require.Equal(t, 2, count)
194 | 	})
195 | }
196 | 
197 | func TestPagination(t *testing.T) {
198 | 	forEveryDB(t, func(ctx context.Context, env *dbEnv) {
199 | 		type Income struct {
200 | 			ID         int64   `db:"id"`
201 | 			UserID     int64   `db:"user_id"`
202 | 			FromUserID int64   `db:"from_user_id"`
203 | 			Amount     float64 `db:"amount"`
204 | 		}
205 | 
206 | 		type PaginatedIncomes struct {
207 | 			Count int64
208 | 			Data  []Income
209 | 		}
210 | 
211 | 		var (
212 | 			result PaginatedIncomes
213 | 			o      Income
214 | 			err    error
215 | 		)
216 | 
217 | 		// Create a base query, apply filters
218 | 		qs := env.sqlf.From("incomes").Where("amount > ?", 100)
219 | 		// Clone a statement and retrieve the record count
220 | 		err = qs.Clone().
221 | 			Select("count(id)").To(&result.Count).
222 | 			QueryRowAndClose(ctx, env.db)
223 | 		if err != nil {
224 | 			t.Fatalf("Failed to count incomes. %s error: %v", env.driver, err)
225 | 		}
226 | 
227 | 		// Retrieve page data
228 | 		err = qs.Bind(&o).
229 | 			OrderBy("id desc").
230 | 			Paginate(1, 2).
231 | 			QueryAndClose(ctx, env.db, func(rows *sql.Rows) {
232 | 				result.Data = append(result.Data, o)
233 | 			})
234 | 		if err != nil {
235 | 			t.Fatalf("Failed to fetch a page of incomes. %s error: %v", env.driver, err)
236 | 		}
237 | 		require.EqualValues(t, 4, result.Count)
238 | 		require.Len(t, result.Data, 2)
239 | 	})
240 | }
241 | 
242 | func TestQuery(t *testing.T) {
243 | 	forEveryDB(t, func(ctx context.Context, env *dbEnv) {
244 | 		var (
245 | 			nRows    int
246 | 			userTo   string
247 | 			userFrom string
248 | 			amount   float64
249 | 		)
250 | 		q := env.sqlf.
251 | 			From("incomes").
252 | 			From("users ut").Where("ut.id = user_id").
253 | 			From("users uf").Where("uf.id = from_user_id").
254 | 			Select("ut.name").To(&userTo).
255 | 			Select("uf.name").To(&userFrom).
256 | 			Select("sum(amount) as got").To(&amount).
257 | 			GroupBy("ut.name, uf.name").
258 | 			OrderBy("got DESC")
259 | 		defer q.Close()
260 | 		err := q.Query(ctx, env.db, func(rows *sql.Rows) {
261 | 			nRows++
262 | 		})
263 | 		if err != nil {
264 | 			t.Errorf("Failed to execute a query: %v", err)
265 | 		} else {
266 | 			require.Equal(t, 4, nRows)
267 | 
268 | 			q.Limit(1)
269 | 
270 | 			nRows = 0
271 | 			err = q.Query(ctx, env.db, func(rows *sql.Rows) {
272 | 				nRows++
273 | 			})
274 | 			if err != nil {
275 | 				t.Errorf("Failed to execute a query: %v", err)
276 | 			} else {
277 | 				require.Equal(t, 1, nRows)
278 | 				require.Equal(t, "User 3", userTo)
279 | 				require.Equal(t, "User 1", userFrom)
280 | 				require.Equal(t, 500.0, amount)
281 | 			}
282 | 		}
283 | 	})
284 | }
285 | 
286 | func TestQueryAndClose(t *testing.T) {
287 | 	forEveryDB(t, func(ctx context.Context, env *dbEnv) {
288 | 		var (
289 | 			nRows  int
290 | 			total  float64
291 | 			amount float64
292 | 		)
293 | 		err := env.sqlf.
294 | 			From("incomes").
295 | 			Select("sum(amount) as got").To(&amount).
296 | 			GroupBy("user_id, from_user_id").
297 | 			OrderBy("got DESC").
298 | 			QueryAndClose(ctx, env.db, func(rows *sql.Rows) {
299 | 				nRows++
300 | 				total += amount
301 | 			})
302 | 
303 | 		require.NoError(t, err, "Failed to execute a query. %s error: %v", env.driver, err)
304 | 		require.Equal(t, 4, nRows)
305 | 		require.Equal(t, 1550.0, total)
306 | 	})
307 | }
308 | 
309 | var sqlSchemaCreate = []string{ // "INTEGER PRIMARY KEY" makes SQLite auto-assign ids (rowid alias)
310 | 	`CREATE TABLE users (
311 | 		id INTEGER PRIMARY KEY,
312 | 		name varchar(128) NOT NULL)`,
313 | 	`CREATE TABLE incomes (
314 | 		id INTEGER PRIMARY KEY,
315 | 		user_id int REFERENCES users(id),
316 | 		from_user_id int REFERENCES users(id),
317 | 		amount money)`,
318 | }
319 | 
320 | var sqlFillDb = []string{
321 | 	`INSERT INTO users (id, name) VALUES (1, 'User 1')`,
322 | 	`INSERT INTO users (id, name) VALUES (2, 'User 2')`,
323 | 	`INSERT INTO users (id, name) VALUES (3, 'User 3')`,
324 | 
325 | 	`INSERT INTO incomes (user_id, from_user_id, amount) VALUES (1, 2, 100)`,
326 | 	`INSERT INTO incomes (user_id, from_user_id, amount) VALUES (1, 2, 200)`,
327 | 	`INSERT INTO incomes (user_id, from_user_id, amount) VALUES (1, 3, 350)`,
328 | 	`INSERT INTO incomes (user_id, from_user_id, amount) VALUES (2, 3, 400)`,
329 | 	`INSERT INTO incomes (user_id, from_user_id, amount) VALUES (3, 1, 500)`,
330 | }
331 | 
332 | var sqlSchemaDrop = []string{
333 | 	`DROP TABLE incomes`,
334 | 	`DROP TABLE users`,
335 | }
336 | 
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/leporo/sqlf
2 | 
3 | go 1.13
4 | 
5 | require (
6 | 	github.com/mattn/go-sqlite3 v1.14.16
7 | 	github.com/stretchr/testify v1.8.2
8 | 	github.com/valyala/bytebufferpool v1.0.0
9 | )
10 | 
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
4 | github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y=
5 | github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
6 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
7 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
8 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
9 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
10 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
11 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
12 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
13 | github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
14 | github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
15 | github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
16 | github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
17 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
18 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
19 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
20 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
21 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
22 | 
--------------------------------------------------------------------------------
/pool.go:
--------------------------------------------------------------------------------
1 | package sqlf
2 | 
3 | import (
4 | 	"sync"
5 | 
6 | 	"github.com/valyala/bytebufferpool"
7 | )
8 | 
9 | // stmtPool recycles Stmt objects handed back by reuseStmt.
10 | var stmtPool = sync.Pool{New: newStmt}
11 | 
12 | // newStmt allocates an empty statement for stmtPool.
13 | func newStmt() interface{} {
14 | 	return &Stmt{
15 | 		chunks: make(stmtChunks, 0, 8),
16 | 	}
17 | }
18 | 
19 | // getStmt takes a statement from the pool, binds it to dialect d and
20 | // attaches a byte buffer taken from bytebufferpool.
21 | func getStmt(d *Dialect) *Stmt {
22 | 	stmt := stmtPool.Get().(*Stmt)
23 | 	stmt.dialect = d
24 | 	stmt.buf = getBuffer()
25 | 	return stmt
26 | }
27 | 
28 | // reuseStmt clears the per-statement state of q (keeping slice capacity)
29 | // and puts it back to the pool. q must not be used after this call.
30 | func reuseStmt(q *Stmt) {
31 | 	// A pooled statement must start in the same state as a freshly allocated
32 | 	// one, so also drop the dialect and the most recently used clause position.
33 | 	q.dialect = nil
34 | 	q.pos = 0
35 | 	q.chunks = q.chunks[:0]
36 | 	if len(q.args) > 0 {
37 | 		// Nil out elements so the pooled slices do not keep values alive.
38 | 		for n := range q.args {
39 | 			q.args[n] = nil
40 | 		}
41 | 		q.args = q.args[:0]
42 | 	}
43 | 	if len(q.dest) > 0 {
44 | 		for n := range q.dest {
45 | 			q.dest[n] = nil
46 | 		}
47 | 		q.dest = q.dest[:0]
48 | 	}
49 | 	putBuffer(q.buf)
50 | 	q.buf = nil
51 | 	q.sql = ""
52 | 
53 | 	stmtPool.Put(q)
54 | }
55 | 
56 | // getBuffer takes a byte buffer from the bytebufferpool pool.
57 | func getBuffer() *bytebufferpool.ByteBuffer {
58 | 	return bytebufferpool.Get()
59 | }
60 | 
61 | // putBuffer returns buf to the bytebufferpool pool.
62 | func putBuffer(buf *bytebufferpool.ByteBuffer) {
63 | 	bytebufferpool.Put(buf)
64 | }
65 | 
--------------------------------------------------------------------------------
/stmt.go:
--------------------------------------------------------------------------------
1 | package sqlf
2 | 
3 | import (
4 | 	"reflect"
5 | 	"strings"
6 | 
7 | 	"github.com/valyala/bytebufferpool"
8 | )
9 | 
10 | /*
11 | New initializes a SQL statement builder instance with an arbitrary verb.
12 | 
13 | Use sqlf.Select(), sqlf.InsertInto(), sqlf.DeleteFrom() to start
14 | common SQL statements.
15 | 
16 | Use New for special cases like this:
17 | 
18 | 	q := sqlf.New("TRUNCATE")
19 | 	for _, table := range tableNames {
20 | 		q.Expr(table)
21 | 	}
22 | 	q.Clause("RESTART IDENTITY")
23 | 	_, err := q.ExecAndClose(ctx, db)
24 | 	if err != nil {
25 | 		panic(err)
26 | 	}
27 | */
28 | func New(verb string, args ...interface{}) *Stmt {
29 | 	return defaultDialect.New(verb, args...)
30 | }
31 | 
32 | /*
33 | From starts a SELECT statement.
34 | 
35 | 	var cnt int64
36 | 
37 | 	err := sqlf.From("table").
38 | 		Select("COUNT(*)").To(&cnt).
39 | 		Where("value >= ?", 42).
40 | 		QueryRowAndClose(ctx, db)
41 | 	if err != nil {
42 | 		panic(err)
43 | 	}
44 | */
45 | func From(expr string, args ...interface{}) *Stmt {
46 | 	return defaultDialect.From(expr, args...)
47 | }
48 | 
49 | /*
50 | With starts a statement prepended by WITH clause
51 | and closes a subquery passed as an argument.
52 | */
53 | func With(queryName string, query *Stmt) *Stmt {
54 | 	return defaultDialect.With(queryName, query)
55 | }
56 | 
57 | /*
58 | Select starts a SELECT statement.
59 | 
60 | 	var cnt int64
61 | 
62 | 	err := sqlf.Select("COUNT(*)").To(&cnt).
63 | 		From("table").
64 | 		Where("value >= ?", 42).
65 | 		QueryRowAndClose(ctx, db)
66 | 	if err != nil {
67 | 		panic(err)
68 | 	}
69 | 
70 | Note that From method can also be used to start a SELECT statement.
71 | */
72 | func Select(expr string, args ...interface{}) *Stmt {
73 | 	return defaultDialect.Select(expr, args...)
74 | }
75 | 
76 | /*
77 | Update starts an UPDATE statement.
78 | 
79 | 	_, err := sqlf.Update("table").
80 | 		Set("field1", "newvalue").
81 | 		Where("id = ?", 42).
82 | 		ExecAndClose(ctx, db)
83 | 	if err != nil {
84 | 		panic(err)
85 | 	}
86 | */
87 | func Update(tableName string) *Stmt {
88 | 	return defaultDialect.Update(tableName)
89 | }
90 | 
91 | /*
92 | InsertInto starts an INSERT statement.
93 | 
94 | 	var newID int64
95 | 	err := sqlf.InsertInto("table").
96 | 		Set("field", value).
97 | 		Returning("id").To(&newID).
98 | 		QueryRowAndClose(ctx, db)
99 | 	if err != nil {
100 | 		panic(err)
101 | 	}
102 | */
103 | func InsertInto(tableName string) *Stmt {
104 | 	return defaultDialect.InsertInto(tableName)
105 | }
106 | 
107 | /*
108 | DeleteFrom starts a DELETE statement.
109 | 
110 | 	_, err := sqlf.DeleteFrom("table").Where("id = ?", id).ExecAndClose(ctx, db)
111 | */
112 | func DeleteFrom(tableName string) *Stmt {
113 | 	return defaultDialect.DeleteFrom(tableName)
114 | }
115 | 
116 | type stmtChunk struct {
117 | 	pos     chunkPos
118 | 	bufLow  int
119 | 	bufHigh int
120 | 	hasExpr bool
121 | 	argLen  int
122 | }
123 | type stmtChunks []stmtChunk
124 | 
125 | /*
126 | Stmt provides a set of helper methods for SQL statement building and execution.
127 | 
128 | Use one of the following methods to create a SQL statement builder instance:
129 | 
130 | 	sqlf.From("table")
131 | 	sqlf.Select("field")
132 | 	sqlf.InsertInto("table")
133 | 	sqlf.Update("table")
134 | 	sqlf.DeleteFrom("table")
135 | 
136 | For other SQL statements use New:
137 | 
138 | 	q := sqlf.New("TRUNCATE")
139 | 	for _, table := range tablesToBeEmptied {
140 | 		q.Expr(table)
141 | 	}
142 | 	_, err := q.ExecAndClose(ctx, db)
143 | 	if err != nil {
144 | 		panic(err)
145 | 	}
146 | */
147 | type Stmt struct {
148 | 	dialect *Dialect
149 | 	pos     chunkPos
150 | 	chunks  stmtChunks
151 | 	buf     *bytebufferpool.ByteBuffer
152 | 	sql     string
153 | 	args    []interface{}
154 | 	dest    []interface{}
155 | }
156 | 
157 | type newRow struct {
158 | 	*Stmt
159 | 	first    bool
160 | 	notEmpty bool
161 | }
162 | 
163 | /*
164 | Select adds a SELECT clause to a statement and/or appends
165 | an expression that defines columns of a resulting data set.
166 | 
167 | 	q := sqlf.Select("DISTINCT field1, field2").From("table")
168 | 
169 | Select can be called multiple times to add more columns:
170 | 
171 | 	q := sqlf.From("table").Select("field1")
172 | 	if needField2 {
173 | 		q.Select("field2")
174 | 	}
175 | 	// ...
176 | 	q.Close()
177 | 
178 | Use To method to bind variables to selected columns:
179 | 
180 | 	var (
181 | 		num  int
182 | 		name string
183 | 	)
184 | 
185 | 	err := sqlf.From("table").
186 | 		Select("num, name").To(&num, &name).
187 | 		Where("id = ?", 42).
188 | 		QueryRowAndClose(ctx, db)
189 | 	if err != nil {
190 | 		panic(err)
191 | 	}
192 | 
193 | Note that a SELECT statement can also be started by a From method call.
194 | */
195 | func (q *Stmt) Select(expr string, args ...interface{}) *Stmt {
196 | 	q.addChunk(posSelect, "SELECT", expr, args, ", ")
197 | 	return q
198 | }
199 | 
200 | /*
201 | To sets a scan target for columns to be selected.
202 | 
203 | Accepts value pointers to be passed to sql.Rows.Scan by
204 | Query and QueryRow methods.
205 | 
206 | 	var (
207 | 		field1 int
208 | 		field2 string
209 | 	)
210 | 	q := sqlf.From("table").
211 | 		Select("field1").To(&field1).
212 | 		Select("field2").To(&field2)
213 | 	err := q.QueryRow(ctx, db)
214 | 	q.Close()
215 | 	if err != nil {
216 | 		// ...
217 | 	}
218 | 
219 | To method MUST be called immediately after Select, Returning or other
220 | method that defines data to be returned.
221 | */
222 | func (q *Stmt) To(dest ...interface{}) *Stmt {
223 | 	if len(dest) > 0 {
224 | 		// As Scan bindings make sense for a single clause per statement,
225 | 		// the order expressions appear in SQL matches the order expressions
226 | 		// are added. So dest value pointers can safely be appended
227 | 		// to the list on every To call.
228 | 		q.dest = insertAt(q.dest, dest, len(q.dest))
229 | 	}
230 | 	return q
231 | }
232 | 
233 | /*
234 | Update adds UPDATE clause to a statement.
235 | 
236 | 	q.Update("table")
237 | 
238 | tableName argument can be a SQL fragment:
239 | 
240 | 	q.Update("ONLY table AS t")
241 | */
242 | func (q *Stmt) Update(tableName string) *Stmt {
243 | 	q.addChunk(posUpdate, "UPDATE", tableName, nil, ", ")
244 | 	return q
245 | }
246 | 
247 | /*
248 | InsertInto adds INSERT INTO clause to a statement.
249 | 
250 | 	q.InsertInto("table")
251 | 
252 | tableName argument can be a SQL fragment:
253 | 
254 | 	q.InsertInto("table AS t")
255 | */
256 | func (q *Stmt) InsertInto(tableName string) *Stmt {
257 | 	q.addChunk(posInsert, "INSERT INTO", tableName, nil, ", ")
258 | 	q.addChunk(posInsertFields-1, "(", "", nil, "")
259 | 	q.addChunk(posValues-1, ") VALUES (", "", nil, "")
260 | 	q.addChunk(posValues+1, ")", "", nil, "")
261 | 	q.pos = posInsertFields
262 | 	return q
263 | }
264 | 
265 | /*
266 | DeleteFrom adds DELETE clause to a statement.
267 | 
268 | 	q.DeleteFrom("table").Where("id = ?", id)
269 | */
270 | func (q *Stmt) DeleteFrom(tableName string) *Stmt {
271 | 	q.addChunk(posDelete, "DELETE FROM", tableName, nil, ", ")
272 | 	return q
273 | }
274 | 
275 | /*
276 | Set method:
277 | 
278 |   - Adds a column to the list of columns and a value to VALUES clause of INSERT statement,
279 | 
280 |   - Adds an item to SET clause of an UPDATE statement.
281 | 
282 | 	q.Set("field", 32)
283 | 
284 | For INSERT statements a call to Set method generates
285 | both the list of columns and values to be inserted:
286 | 
287 | 	q := sqlf.InsertInto("table").Set("field", 42)
288 | 
289 | produces the following SQL, with 42 passed as a query argument:
290 | 
291 | 	INSERT INTO table ( field ) VALUES ( ? )
292 | 
293 | Do not use it to construct ON CONFLICT DO UPDATE SET or similar clauses.
294 | Use generic Clause and Expr methods instead:
295 | 
296 | 	q.Clause("ON CONFLICT DO UPDATE SET").Expr("column_name = ?", value)
297 | */
298 | func (q *Stmt) Set(field string, value interface{}) *Stmt {
299 | 	return q.SetExpr(field, "?", value)
300 | }
301 | 
302 | /*
303 | SetExpr is an extended version of Set method; it is a no-op unless the statement has an INSERT or UPDATE clause.
304 | 
305 | 	q.SetExpr("field", "field + 1")
306 | 	q.SetExpr("field", "? + ?", 31, 11)
307 | */
308 | func (q *Stmt) SetExpr(field, expr string, args ...interface{}) *Stmt {
309 | 	p := chunkPos(0)
310 | 	for _, chunk := range q.chunks {
311 | 		if chunk.pos == posInsert || chunk.pos == posUpdate {
312 | 			p = chunk.pos
313 | 			break
314 | 		}
315 | 	}
316 | 
317 | 	switch p {
318 | 	case posInsert:
319 | 		q.addChunk(posInsertFields, "", field, nil, ", ")
320 | 		q.addChunk(posValues, "", expr, args, ", ")
321 | 	case posUpdate:
322 | 		q.addChunk(posSet, "SET", field+"="+expr, args, ", ")
323 | 	}
324 | 	return q
325 | }
326 | 
327 | // From adds a FROM clause to statement.
328 | func (q *Stmt) From(expr string, args ...interface{}) *Stmt {
329 | 	q.addChunk(posFrom, "FROM", expr, args, ", ")
330 | 	return q
331 | }
332 | 
333 | /*
334 | Where adds a filter:
335 | 
336 | 	sqlf.From("users").
337 | 		Select("id, name").
338 | 		Where("email = ?", email).
339 | 		Where("is_active = 1")
340 | */
341 | func (q *Stmt) Where(expr string, args ...interface{}) *Stmt {
342 | 	q.addChunk(posWhere, "WHERE", expr, args, " AND ")
343 | 	return q
344 | }
345 | 
346 | /*
347 | In adds IN expression to the current filter.
348 | 
349 | In method must be called after a Where method call. Pass at least one value: with no values it renders "IN ()", which is not valid SQL.
350 | */
351 | func (q *Stmt) In(args ...interface{}) *Stmt {
352 | 	buf := bytebufferpool.Get()
353 | 	buf.WriteString("IN (")
354 | 	l := len(args) - 1
355 | 	for i := range args {
356 | 		if i < l {
357 | 			buf.Write(placeholderComma)
358 | 		} else {
359 | 			buf.Write(placeholder)
360 | 		}
361 | 	}
362 | 	buf.WriteString(")")
363 | 
364 | 	q.addChunk(posWhere, "", bufToString(&buf.B), args, " ")
365 | 
366 | 	bytebufferpool.Put(buf)
367 | 	return q
368 | }
369 | 
370 | /*
371 | Join adds an INNER JOIN clause to SELECT statement
372 | */
373 | func (q *Stmt) Join(table, on string) *Stmt {
374 | 	q.join("JOIN ", table, on)
375 | 	return q
376 | }
377 | 
378 | /*
379 | LeftJoin adds a LEFT OUTER JOIN clause to SELECT statement
380 | */
381 | func (q *Stmt) LeftJoin(table, on string) *Stmt {
382 | 	q.join("LEFT JOIN ", table, on)
383 | 	return q
384 | }
385 | 
386 | /*
387 | RightJoin adds a RIGHT OUTER JOIN clause to SELECT statement
388 | */
389 | func (q *Stmt) RightJoin(table, on string) *Stmt {
390 | 	q.join("RIGHT JOIN ", table, on)
391 | 	return q
392 | }
393 | 
394 | /*
395 | FullJoin adds a FULL OUTER JOIN clause to SELECT statement
396 | */
397 | func (q *Stmt) FullJoin(table, on string) *Stmt {
398 | 	q.join("FULL JOIN ", table, on)
399 | 	return q
400 | }
401 | 
402 | // OrderBy adds the ORDER BY clause to SELECT statement
403 | func (q *Stmt) OrderBy(expr ...string) *Stmt {
404 | 	q.addChunk(posOrderBy, "ORDER BY", strings.Join(expr, ", "), nil, ", ")
405 | 	return q
406 | }
407 | 
408 | // GroupBy adds the GROUP BY clause to SELECT statement
409 | func (q *Stmt) GroupBy(expr string) *Stmt {
410 | 	q.addChunk(posGroupBy, "GROUP BY", expr, nil, ", ")
411 | 	return q
412 | }
413 | 
414 | // Having adds the HAVING clause to SELECT statement
415 | func (q *Stmt) Having(expr string, args ...interface{}) *Stmt {
416 | 	q.addChunk(posHaving, "HAVING", expr, args, " AND ")
417 | 	return q
418 | }
419 | 
420 | // Limit adds a limit on number of returned rows
421 | func (q *Stmt) Limit(limit interface{}) *Stmt {
422 | 	q.addChunk(posLimit, "LIMIT ?", "", []interface{}{limit}, "")
423 | 	return q
424 | }
425 | 
426 | // Offset adds an OFFSET clause that skips the given number of rows
427 | func (q *Stmt) Offset(offset interface{}) *Stmt {
428 | 	q.addChunk(posOffset, "OFFSET ?", "", []interface{}{offset}, "")
429 | 	return q
430 | }
431 | 
432 | // Paginate sets LIMIT and OFFSET for a 1-based page number and a page size.
433 | func (q *Stmt) Paginate(page, pageSize int) *Stmt {
434 | 	if page < 1 {
435 | 		page = 1
436 | 	}
437 | 	if pageSize < 1 {
438 | 		pageSize = 1
439 | 	}
440 | 	if page > 1 {
441 | 		q.Offset((page - 1) * pageSize)
442 | 	}
443 | 	q.Limit(pageSize)
444 | 	return q
445 | }
446 | 
447 | // Returning adds a RETURNING clause to a statement
448 | func (q *Stmt) Returning(expr string) *Stmt {
449 | 	q.addChunk(posReturning, "RETURNING", expr, nil, ", ")
450 | 	return q
451 | }
452 | 
453 | // With prepends a statement with a WITH clause.
454 | // With method calls a Close method of a given query, so
455 | // make sure not to reuse it afterwards.
456 | func (q *Stmt) With(queryName string, query *Stmt) *Stmt {
457 | 	q.addChunk(posWith, "WITH", "", nil, "")
458 | 	return q.SubQuery(queryName+" AS (", ")", query)
459 | }
460 | 
461 | /*
462 | Expr appends an expression to the most recently added clause.
463 | 
464 | Expressions are separated with commas.
465 | */
466 | func (q *Stmt) Expr(expr string, args ...interface{}) *Stmt {
467 | 	q.addChunk(q.pos, "", expr, args, ", ")
468 | 	return q
469 | }
470 | 
471 | /*
472 | SubQuery appends a sub query expression to a current clause.
473 | 
474 | SubQuery method call closes the Stmt passed as query parameter.
475 | Do not reuse it afterwards.
476 | */ 477 | func (q *Stmt) SubQuery(prefix, suffix string, query *Stmt) *Stmt { 478 | delimiter := ", " 479 | if q.pos == posWhere { 480 | delimiter = " AND " 481 | } 482 | index := q.addChunk(q.pos, "", prefix, query.args, delimiter) 483 | chunk := &q.chunks[index] 484 | // Make sure subquery is not dialect-specific. 485 | if query.dialect != NoDialect { 486 | query.dialect = NoDialect 487 | query.Invalidate() 488 | } 489 | q.buf.WriteString(query.String()) 490 | q.buf.WriteString(suffix) 491 | chunk.bufHigh = q.buf.Len() 492 | // Close the subquery 493 | query.Close() 494 | 495 | return q 496 | } 497 | 498 | /* 499 | Union adds a UNION clause to the statement. 500 | 501 | all argument controls if UNION ALL or UNION clause 502 | is to be constructed. Use UNION ALL if possible to 503 | get faster queries. 504 | */ 505 | func (q *Stmt) Union(all bool, query *Stmt) *Stmt { 506 | p := posUnion 507 | if len(q.chunks) > 0 { 508 | last := (&q.chunks[len(q.chunks)-1]).pos 509 | if last >= p { 510 | p = last + 1 511 | } 512 | } 513 | var index int 514 | if all { 515 | index = q.addChunk(p, "UNION ALL ", "", query.args, "") 516 | } else { 517 | index = q.addChunk(p, "UNION ", "", query.args, "") 518 | } 519 | chunk := &q.chunks[index] 520 | // Make sure subquery is not dialect-specific. 521 | if query.dialect != NoDialect { 522 | query.dialect = NoDialect 523 | query.Invalidate() 524 | } 525 | q.buf.WriteString(query.String()) 526 | chunk.bufHigh = q.buf.Len() 527 | // Close the subquery 528 | query.Close() 529 | 530 | return q 531 | } 532 | 533 | /* 534 | Clause appends a raw SQL fragment to the statement. 535 | 536 | Use it to add a raw SQL fragment like ON CONFLICT, ON DUPLICATE KEY, WINDOW, etc. 537 | 538 | An SQL fragment added via Clause method appears after the last clause previously 539 | added. If called first, Clause method prepends a statement with a raw SQL. 
540 | */ 541 | func (q *Stmt) Clause(expr string, args ...interface{}) *Stmt { 542 | p := posStart 543 | if len(q.chunks) > 0 { 544 | p = (&q.chunks[len(q.chunks)-1]).pos + 10 545 | } 546 | q.addChunk(p, expr, "", args, ", ") 547 | return q 548 | } 549 | 550 | // String method builds and returns an SQL statement. 551 | func (q *Stmt) String() string { 552 | if q.sql == "" { 553 | // Calculate the buffer hash and check for available queries 554 | sql, ok := q.dialect.getCachedSQL(q.buf) 555 | if ok { 556 | q.sql = sql 557 | } else { 558 | // Build a query 559 | var argNo int = 1 560 | buf := strings.Builder{} 561 | 562 | pos := chunkPos(0) 563 | for n, chunk := range q.chunks { 564 | // Separate clauses with spaces 565 | if n > 0 && chunk.pos > pos { 566 | buf.Write(space) 567 | } 568 | s := q.buf.B[chunk.bufLow:chunk.bufHigh] 569 | if chunk.argLen > 0 && q.dialect == PostgreSQL { 570 | argNo, _ = writePg(argNo, s, &buf) 571 | } else { 572 | buf.Write(s) 573 | } 574 | pos = chunk.pos 575 | } 576 | q.sql = buf.String() 577 | // Save it for reuse 578 | q.dialect.putCachedSQL(q.buf, q.sql) 579 | } 580 | } 581 | return q.sql 582 | } 583 | 584 | /* 585 | Args returns the list of arguments to be passed to 586 | database driver for statement execution. 587 | 588 | Do not access a slice returned by this method after Stmt is closed. 589 | 590 | An array, a returned slice points to, can be altered by any method that 591 | adds a clause or an expression with arguments. 592 | 593 | Make sure to make a copy of the returned slice if you need to preserve it. 594 | */ 595 | func (q *Stmt) Args() []interface{} { 596 | return q.args 597 | } 598 | 599 | /* 600 | Dest returns a list of value pointers passed via To method calls. 601 | The order matches the constructed SQL statement. 602 | 603 | Do not access a slice returned by this method after Stmt is closed. 604 | 605 | Note that an array, a returned slice points to, can be altered by To method 606 | calls. 
607 | 608 | Make sure to make a copy if you need to preserve a slice returned by this method. 609 | */ 610 | func (q *Stmt) Dest() []interface{} { 611 | return q.dest 612 | } 613 | 614 | /* 615 | Invalidate forces a rebuild on next query execution. 616 | 617 | Most likely you don't need to call this method directly. 618 | */ 619 | func (q *Stmt) Invalidate() { 620 | if q.sql != "" { 621 | q.sql = "" 622 | } 623 | } 624 | 625 | /* 626 | Close puts buffers and other objects allocated to build an SQL statement 627 | back to pool for reuse by other Stmt instances. 628 | 629 | Stmt instance should not be used after Close method call. 630 | */ 631 | func (q *Stmt) Close() { 632 | reuseStmt(q) 633 | } 634 | 635 | // Clone creates a copy of the statement. 636 | func (q *Stmt) Clone() *Stmt { 637 | stmt := getStmt(q.dialect) 638 | if cap(stmt.chunks) < len(q.chunks) { 639 | stmt.chunks = make(stmtChunks, len(q.chunks), len(q.chunks)+2) 640 | copy(stmt.chunks, q.chunks) 641 | } else { 642 | stmt.chunks = append(stmt.chunks, q.chunks...) 643 | } 644 | stmt.args = insertAt(stmt.args, q.args, 0) 645 | stmt.dest = insertAt(stmt.dest, q.dest, 0) 646 | stmt.buf.Write(q.buf.B) 647 | stmt.sql = q.sql 648 | 649 | return stmt 650 | } 651 | 652 | // Bind adds structure fields to SELECT statement. 653 | // Structure fields have to be annotated with "db" tag. 654 | // Reflect-based Bind is slightly slower than `Select("field").To(&record.field)` 655 | // but provides an easier way to retrieve data. 656 | // 657 | // Note: this method does no type checks and returns no errors. 
658 | func (q *Stmt) Bind(data interface{}) *Stmt { 659 | typ := reflect.TypeOf(data).Elem() 660 | val := reflect.ValueOf(data).Elem() 661 | 662 | for i := 0; i < val.NumField(); i++ { 663 | field := val.Field(i) 664 | t := typ.Field(i) 665 | if field.Kind() == reflect.Struct && t.Anonymous { 666 | q.Bind(field.Addr().Interface()) 667 | } else { 668 | dbFieldName := t.Tag.Get("db") 669 | if dbFieldName != "" { 670 | q.Select(dbFieldName).To(field.Addr().Interface()) 671 | } 672 | } 673 | } 674 | return q 675 | } 676 | 677 | // join adds a join clause to a SELECT statement 678 | func (q *Stmt) join(joinType, table, on string) (index int) { 679 | buf := bytebufferpool.Get() 680 | buf.WriteString(joinType) 681 | buf.WriteString(table) 682 | buf.Write(joinOn) 683 | buf.WriteString(on) 684 | buf.WriteByte(')') 685 | 686 | index = q.addChunk(posFrom, "", bufToString(&buf.B), nil, " ") 687 | 688 | bytebufferpool.Put(buf) 689 | 690 | return index 691 | } 692 | 693 | // addChunk adds a clause or expression to a statement. 
694 | func (q *Stmt) addChunk(pos chunkPos, clause, expr string, args []interface{}, sep string) (index int) { 695 | // Remember the position 696 | q.pos = pos 697 | 698 | argLen := len(args) 699 | bufLow := len(q.buf.B) 700 | index = len(q.chunks) 701 | argTail := 0 702 | 703 | addNew := true 704 | addClause := clause != "" 705 | 706 | // Find the position to insert a chunk to 707 | loop: 708 | for i := index - 1; i >= 0; i-- { 709 | chunk := &q.chunks[i] 710 | index = i 711 | switch { 712 | // See if an existing chunk can be extended 713 | case chunk.pos == pos: 714 | // Do nothing if a clause is already there and no expressions are to be added 715 | if expr == "" { 716 | // See if arguments are to be updated 717 | if argLen > 0 { 718 | copy(q.args[len(q.args)-argTail-chunk.argLen:], args) 719 | } 720 | return i 721 | } 722 | // Write a separator 723 | if chunk.hasExpr { 724 | q.buf.WriteString(sep) 725 | } else { 726 | q.buf.WriteString(" ") 727 | } 728 | if chunk.bufHigh == bufLow { 729 | // Do not add a chunk 730 | addNew = false 731 | // Update the existing one 732 | q.buf.WriteString(expr) 733 | chunk.argLen += argLen 734 | chunk.bufHigh = len(q.buf.B) 735 | chunk.hasExpr = true 736 | } else { 737 | // Do not add a clause 738 | addClause = false 739 | index = i + 1 740 | } 741 | break loop 742 | // No existing chunks of this type 743 | case chunk.pos < pos: 744 | index = i + 1 745 | break loop 746 | default: 747 | argTail += chunk.argLen 748 | } 749 | } 750 | 751 | if addNew { 752 | // Insert a new chunk 753 | if addClause { 754 | q.buf.WriteString(clause) 755 | if expr != "" { 756 | q.buf.WriteString(" ") 757 | } 758 | } 759 | q.buf.WriteString(expr) 760 | 761 | if cap(q.chunks) == len(q.chunks) { 762 | chunks := make(stmtChunks, len(q.chunks), cap(q.chunks)*2) 763 | copy(chunks, q.chunks) 764 | q.chunks = chunks 765 | } 766 | 767 | chunk := stmtChunk{ 768 | pos: pos, 769 | bufLow: bufLow, 770 | bufHigh: len(q.buf.B), 771 | argLen: argLen, 772 | hasExpr: 
expr != "", 773 | } 774 | 775 | q.chunks = append(q.chunks, chunk) 776 | if index < len(q.chunks)-1 { 777 | copy(q.chunks[index+1:], q.chunks[index:]) 778 | q.chunks[index] = chunk 779 | } 780 | } 781 | 782 | // Insert query arguments 783 | if argLen > 0 { 784 | q.args = insertAt(q.args, args, len(q.args)-argTail) 785 | } 786 | q.Invalidate() 787 | 788 | return index 789 | } 790 | 791 | /* 792 | NewRow method helps to construct a bulk INSERT statement. 793 | 794 | The following code 795 | 796 | q := stmt.InsertInto("table") 797 | for k, v := range entries { 798 | q.NewRow(). 799 | Set("key", k). 800 | Set("value", v) 801 | } 802 | 803 | produces (assuming there were 2 key/value pairs at entries map): 804 | 805 | INSERT INTO table ( key, value ) VALUES ( ?, ? ), ( ?, ? ) 806 | */ 807 | func (q *Stmt) NewRow() newRow { 808 | first := true 809 | // Check if there are values 810 | loop: 811 | for i := len(q.chunks) - 1; i >= 0; i-- { 812 | chunk := q.chunks[i] 813 | switch { 814 | // See if an existing chunk can be extended 815 | case chunk.pos == posValues: 816 | // Values section is there, prepend 817 | first = false 818 | break loop 819 | case chunk.pos < posValues: 820 | break loop 821 | } 822 | } 823 | if !first { 824 | q.addChunk(posValues, "", " ", nil, " ), (") 825 | } 826 | return newRow{ 827 | Stmt: q, 828 | first: first, 829 | } 830 | } 831 | 832 | /* 833 | Set method: 834 | 835 | - Adds a column to the list of columns and a value to VALUES clause of INSERT statement, 836 | 837 | A call to Set method generates both the list of columns and 838 | values to be inserted by INSERT statement: 839 | 840 | q := sqlf.InsertInto("table").Set("field", 42) 841 | 842 | produces 843 | 844 | INSERT INTO table (field) VALUES (42) 845 | 846 | Do not use it to construct ON CONFLICT DO UPDATE SET or similar clauses. 
847 | Use generic Clause and Expr methods instead: 848 | 849 | q.Clause("ON CONFLICT DO UPDATE SET").Expr("column_name = ?", value) 850 | */ 851 | func (row newRow) Set(field string, value interface{}) newRow { 852 | return row.SetExpr(field, "?", value) 853 | } 854 | 855 | /* 856 | SetExpr is an extended version of Set method. 857 | 858 | q.SetExpr("field", "field + 1") 859 | q.SetExpr("field", "? + ?", 31, 11) 860 | */ 861 | func (row newRow) SetExpr(field, expr string, args ...interface{}) newRow { 862 | q := row.Stmt 863 | 864 | if row.first { 865 | q.addChunk(posInsertFields, "", field, nil, ", ") 866 | q.addChunk(posValues, "", expr, args, ", ") 867 | } else { 868 | sep := "" 869 | if row.notEmpty { 870 | sep = ", " 871 | } 872 | q.addChunk(posValues, "", expr, args, sep) 873 | } 874 | 875 | return newRow{ 876 | Stmt: row.Stmt, 877 | first: row.first, 878 | notEmpty: true, 879 | } 880 | } 881 | 882 | var ( 883 | space = []byte{' '} 884 | placeholder = []byte{'?'} 885 | placeholderComma = []byte{'?', ','} 886 | joinOn = []byte{' ', 'O', 'N', ' ', '('} 887 | ) 888 | 889 | type chunkPos int 890 | 891 | const ( 892 | _ chunkPos = iota 893 | posStart chunkPos = 100 * iota 894 | posWith 895 | posInsert 896 | posInsertFields 897 | posValues 898 | posDelete 899 | posUpdate 900 | posSet 901 | posSelect 902 | posInto 903 | posFrom 904 | posWhere 905 | posGroupBy 906 | posHaving 907 | posUnion 908 | posOrderBy 909 | posLimit 910 | posOffset 911 | posReturning 912 | posEnd 913 | ) 914 | -------------------------------------------------------------------------------- /stmt_test.go: -------------------------------------------------------------------------------- 1 | package sqlf_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | "time" 7 | 8 | "github.com/leporo/sqlf" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func TestNewBuilder(t *testing.T) { 13 | sqlf.SetDialect(sqlf.NoDialect) 14 | q := sqlf.New("SELECT *").From("table") 15 | defer q.Close() 16 | sql := 
q.String() 17 | args := q.Args() 18 | require.Equal(t, "SELECT * FROM table", sql) 19 | require.Empty(t, args) 20 | } 21 | 22 | func TestBasicSelect(t *testing.T) { 23 | q := sqlf.From("table").Select("id").Where("id > ?", 42).Where("id < ?", 1000) 24 | defer q.Close() 25 | sql, args := q.String(), q.Args() 26 | require.Equal(t, "SELECT id FROM table WHERE id > ? AND id < ?", sql) 27 | require.Equal(t, []interface{}{42, 1000}, args) 28 | } 29 | 30 | func TestMixedOrder(t *testing.T) { 31 | q := sqlf.Select("id").Where("id > ?", 42).From("table").Where("id < ?", 1000) 32 | defer q.Close() 33 | sql, args := q.String(), q.Args() 34 | require.Equal(t, "SELECT id FROM table WHERE id > ? AND id < ?", sql) 35 | require.Equal(t, []interface{}{42, 1000}, args) 36 | } 37 | 38 | func TestClause(t *testing.T) { 39 | q := sqlf.Select("id").From("table").Where("id > ?", 42).Clause("FETCH NEXT").Clause("FOR UPDATE") 40 | defer q.Close() 41 | sql, args := q.String(), q.Args() 42 | require.Equal(t, "SELECT id FROM table WHERE id > ? FETCH NEXT FOR UPDATE", sql) 43 | require.Equal(t, []interface{}{42}, args) 44 | } 45 | 46 | func TestExpr(t *testing.T) { 47 | q := sqlf.From("table"). 48 | Select("id"). 49 | Expr("(select 1 from related where table_id = table.id limit 1) AS has_related"). 50 | Where("id > ?", 42) 51 | require.Equal(t, "SELECT id, (select 1 from related where table_id = table.id limit 1) AS has_related FROM table WHERE id > ?", q.String()) 52 | require.Equal(t, []interface{}{42}, q.Args()) 53 | q.Close() 54 | } 55 | 56 | func TestManyFields(t *testing.T) { 57 | q := sqlf.Select("id").From("table").Where("id = ?", 42) 58 | defer q.Close() 59 | for i := 1; i <= 3; i++ { 60 | q.Select(fmt.Sprintf("(id + ?) as id_%d", i), i*10) 61 | } 62 | for _, field := range []string{"uno", "dos", "tres"} { 63 | q.Select(field) 64 | } 65 | sql, args := q.String(), q.Args() 66 | require.Equal(t, "SELECT id, (id + ?) as id_1, (id + ?) as id_2, (id + ?) 
as id_3, uno, dos, tres FROM table WHERE id = ?", sql) 67 | require.Equal(t, []interface{}{10, 20, 30, 42}, args) 68 | } 69 | 70 | func TestEvenMoreFields(t *testing.T) { 71 | q := sqlf.Select("id").From("table") 72 | defer q.Close() 73 | for n := 1; n <= 50; n++ { 74 | q.Select(fmt.Sprintf("field_%d", n)) 75 | } 76 | sql, args := q.String(), q.Args() 77 | require.Equal(t, 0, len(args)) 78 | for n := 1; n <= 50; n++ { 79 | field := fmt.Sprintf(", field_%d", n) 80 | require.Contains(t, sql, field) 81 | } 82 | } 83 | 84 | func TestPgPlaceholders(t *testing.T) { 85 | q := sqlf.PostgreSQL.From("series"). 86 | Select("id"). 87 | Where("time > ?", time.Now().Add(time.Hour*-24*14)). 88 | Where("(time < ?)", time.Now().Add(time.Hour*-24*7)) 89 | defer q.Close() 90 | sql, _ := q.String(), q.Args() 91 | require.Equal(t, "SELECT id FROM series WHERE time > $1 AND (time < $2)", sql) 92 | } 93 | 94 | func TestPgPlaceholderEscape(t *testing.T) { 95 | q := sqlf.PostgreSQL.From("series"). 96 | Select("id"). 97 | Where("time \\?> ? + 1", time.Now().Add(time.Hour*-24*14)). 98 | Where("time < ?", time.Now().Add(time.Hour*-24*7)) 99 | defer q.Close() 100 | sql, _ := q.String(), q.Args() 101 | require.Equal(t, "SELECT id FROM series WHERE time ?> $1 + 1 AND time < $2", sql) 102 | } 103 | 104 | func TestTo(t *testing.T) { 105 | var ( 106 | field1 int 107 | field2 string 108 | ) 109 | q := sqlf.From("table"). 110 | Select("field1").To(&field1). 111 | Select("field2").To(&field2) 112 | defer q.Close() 113 | dest := q.Dest() 114 | 115 | require.Equal(t, []interface{}{&field1, &field2}, dest) 116 | } 117 | 118 | func TestManyClauses(t *testing.T) { 119 | q := sqlf.From("table"). 120 | Select("field"). 121 | Where("id > ?", 2). 122 | Clause("UNO"). 123 | Clause("DOS"). 124 | Clause("TRES"). 125 | Clause("QUATRO"). 126 | Offset(10). 127 | Limit(5). 
128 | Clause("NO LOCK") 129 | defer q.Close() 130 | sql, args := q.String(), q.Args() 131 | 132 | require.Equal(t, "SELECT field FROM table WHERE id > ? UNO DOS TRES QUATRO LIMIT ? OFFSET ? NO LOCK", sql) 133 | require.Equal(t, []interface{}{2, 5, 10}, args) 134 | } 135 | 136 | func TestWith(t *testing.T) { 137 | var row struct { 138 | ID int64 `db:"id"` 139 | Quantity int64 `db:"quantity"` 140 | } 141 | q := sqlf.With("t", 142 | sqlf.From("orders"). 143 | Select("id, quantity"). 144 | Where("ts < ?", time.Now())). 145 | From("t"). 146 | Bind(&row) 147 | defer q.Close() 148 | 149 | require.Equal(t, "WITH t AS (SELECT id, quantity FROM orders WHERE ts < ?) SELECT id, quantity FROM t", q.String()) 150 | } 151 | 152 | func TestWithRecursive(t *testing.T) { 153 | q := sqlf.From("orders"). 154 | With("RECURSIVE regional_sales", sqlf.From("orders").Select("region, SUM(amount) AS total_sales").GroupBy("region")). 155 | With("top_regions", sqlf.From("regional_sales").Select("region").OrderBy("total_sales DESC").Limit(5)). 156 | Select("region"). 157 | Select("product"). 158 | Select("SUM(quantity) AS product_units"). 159 | Select("SUM(amount) AS product_sales"). 160 | Where("region IN (SELECT region FROM top_regions)"). 161 | GroupBy("region, product") 162 | defer q.Close() 163 | 164 | require.Equal(t, "WITH RECURSIVE regional_sales AS (SELECT region, SUM(amount) AS total_sales FROM orders GROUP BY region), top_regions AS (SELECT region FROM regional_sales ORDER BY total_sales DESC LIMIT ?) SELECT region, product, SUM(quantity) AS product_units, SUM(amount) AS product_sales FROM orders WHERE region IN (SELECT region FROM top_regions) GROUP BY region, product", q.String()) 165 | } 166 | 167 | func TestSubQueryDialect(t *testing.T) { 168 | q := sqlf.PostgreSQL.From("users u"). 169 | Select("email"). 170 | Where("registered > ?", "2019-01-01"). 171 | SubQuery("EXISTS (", ")", 172 | sqlf.PostgreSQL.From("orders"). 173 | Select("id"). 174 | Where("user_id = u.id"). 
175 | Where("amount > ?", 100)) 176 | defer q.Close() 177 | 178 | // Parameter placeholder numbering should match the arguments 179 | require.Equal(t, "SELECT email FROM users u WHERE registered > $1 AND EXISTS (SELECT id FROM orders WHERE user_id = u.id AND amount > $2)", q.String()) 180 | require.Equal(t, []interface{}{"2019-01-01", 100}, q.Args()) 181 | } 182 | 183 | func TestClone(t *testing.T) { 184 | var ( 185 | value string 186 | value2 string 187 | ) 188 | q := sqlf.From("table").Select("field").To(&value).Where("id = ?", 42) 189 | defer q.Close() 190 | 191 | require.Equal(t, "SELECT field FROM table WHERE id = ?", q.String()) 192 | 193 | q2 := q.Clone() 194 | defer q2.Close() 195 | 196 | require.Equal(t, q.Args(), q2.Args()) 197 | require.Equal(t, q.Dest(), q2.Dest()) 198 | require.Equal(t, q.String(), q2.String()) 199 | 200 | q2.Where("time < ?", time.Now()) 201 | 202 | require.Equal(t, q.Dest(), q2.Dest()) 203 | require.NotEqual(t, q.Args(), q2.Args()) 204 | require.NotEqual(t, q.String(), q2.String()) 205 | 206 | q2.Select("field2").To(&value2) 207 | require.NotEqual(t, q.Dest(), q2.Dest()) 208 | require.NotEqual(t, q.Args(), q2.Args()) 209 | require.NotEqual(t, q.String(), q2.String()) 210 | 211 | // Add more clauses to original statement to re-allocate chunks array 212 | q.With("top_users", sqlf.From("users").OrderBy("rating DESCT").Limit(10)). 213 | GroupBy("id"). 214 | Having("field > ?", 10). 215 | Paginate(1, 20). 216 | Clause("FETCH ROWS ONLY"). 
217 | Clause("FOR UPDATE") 218 | 219 | q3 := q.Clone() 220 | require.Equal(t, q.Args(), q3.Args()) 221 | require.Equal(t, q.Dest(), q3.Dest()) 222 | require.Equal(t, q.String(), q3.String()) 223 | 224 | require.NotEqual(t, q.Dest(), q2.Dest()) 225 | require.NotEqual(t, q.Args(), q2.Args()) 226 | require.NotEqual(t, q.String(), q2.String()) 227 | } 228 | 229 | func TestJoin(t *testing.T) { 230 | q := sqlf.From("orders o").Select("id").Join("users u", "u.id = o.user_id") 231 | defer q.Close() 232 | require.Equal(t, "SELECT id FROM orders o JOIN users u ON (u.id = o.user_id)", q.String()) 233 | } 234 | 235 | func TestLeftJoin(t *testing.T) { 236 | q := sqlf.From("orders o").Select("id").LeftJoin("users u", "u.id = o.user_id") 237 | defer q.Close() 238 | require.Equal(t, "SELECT id FROM orders o LEFT JOIN users u ON (u.id = o.user_id)", q.String()) 239 | } 240 | 241 | func TestRightJoin(t *testing.T) { 242 | q := sqlf.From("orders o").Select("id").RightJoin("users u", "u.id = o.user_id") 243 | defer q.Close() 244 | require.Equal(t, "SELECT id FROM orders o RIGHT JOIN users u ON (u.id = o.user_id)", q.String()) 245 | } 246 | 247 | func TestFullJoin(t *testing.T) { 248 | q := sqlf.From("orders o").Select("id").FullJoin("users u", "u.id = o.user_id") 249 | defer q.Close() 250 | require.Equal(t, "SELECT id FROM orders o FULL JOIN users u ON (u.id = o.user_id)", q.String()) 251 | } 252 | 253 | func TestUnion(t *testing.T) { 254 | q := sqlf.From("tasks"). 255 | Select("id, status"). 256 | Where("status = ?", "new"). 257 | Union(false, sqlf.PostgreSQL.From("tasks"). 258 | Select("id, status"). 259 | Where("status = ?", "wip")) 260 | defer q.Close() 261 | require.Equal(t, "SELECT id, status FROM tasks WHERE status = ? UNION SELECT id, status FROM tasks WHERE status = ?", q.String()) 262 | } 263 | 264 | func TestLimit(t *testing.T) { 265 | q := sqlf.From("items"). 266 | Select("id"). 267 | Where("id > ?", 42). 268 | Limit(10). 269 | Limit(11). 
270 | Limit(20) 271 | defer q.Close() 272 | require.Equal(t, "SELECT id FROM items WHERE id > ? LIMIT ?", q.String()) 273 | require.Equal(t, []interface{}{42, 20}, q.Args()) 274 | } 275 | 276 | func TestBindStruct(t *testing.T) { 277 | type Parent struct { 278 | ID int64 `db:"id"` 279 | Date time.Time `db:"date"` 280 | Skipped string 281 | } 282 | var u struct { 283 | Parent 284 | ChildTime time.Time `db:"child_time"` 285 | Name string `db:"name"` 286 | Extra int64 287 | } 288 | q := sqlf.From("users"). 289 | Bind(&u). 290 | Where("id = ?", 2) 291 | defer q.Close() 292 | require.Equal(t, "SELECT id, date, child_time, name FROM users WHERE id = ?", q.String()) 293 | require.Equal(t, []interface{}{2}, q.Args()) 294 | require.EqualValues(t, []interface{}{&u.ID, &u.Date, &u.ChildTime, &u.Name}, q.Dest()) 295 | } 296 | 297 | func TestBulkInsert(t *testing.T) { 298 | q := sqlf.InsertInto("vars") 299 | defer q.Close() 300 | for i := 1; i <= 5; i++ { 301 | q.NewRow(). 302 | Set("no", i). 303 | Set("val", i) 304 | } 305 | require.Equal(t, "INSERT INTO vars ( no, val ) VALUES ( ?, ? ), ( ?, ? ), ( ?, ? ), ( ?, ? ), ( ?, ? )", q.String()) 306 | require.Len(t, q.Args(), 10) 307 | } 308 | -------------------------------------------------------------------------------- /util.go: -------------------------------------------------------------------------------- 1 | package sqlf 2 | 3 | import ( 4 | "unsafe" 5 | ) 6 | 7 | func insertAt(dest, src []interface{}, index int) []interface{} { 8 | srcLen := len(src) 9 | if srcLen > 0 { 10 | oldLen := len(dest) 11 | dest = append(dest, src...) 12 | if index < oldLen { 13 | copy(dest[index+srcLen:], dest[index:]) 14 | copy(dest[index:], src) 15 | } 16 | } 17 | 18 | return dest 19 | } 20 | 21 | // bufToString returns a string pointing to a ByteBuffer contents 22 | // It helps to avoid memory copyng. 23 | // Use the returned string with care, make sure to never use it after 24 | // the ByteBuffer is deallocated or returned to a pool. 
25 | func bufToString(buf *[]byte) string { 26 | return *(*string)(unsafe.Pointer(buf)) 27 | } 28 | -------------------------------------------------------------------------------- /util_test.go: -------------------------------------------------------------------------------- 1 | package sqlf 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | ) 8 | 9 | func TestInsertAt(t *testing.T) { 10 | a := insertAt([]interface{}{1, 2, 3, 4}, []interface{}{5, 6}, 4) 11 | require.Equal(t, a, []interface{}{1, 2, 3, 4, 5, 6}) 12 | 13 | a = insertAt([]interface{}{}, []interface{}{3, 2}, 0) 14 | require.Equal(t, a, []interface{}{3, 2}) 15 | 16 | a = insertAt([]interface{}{}, []interface{}{}, 5) 17 | require.Equal(t, a, []interface{}{}) 18 | 19 | a = insertAt([]interface{}{1, 2}, []interface{}{3}, 1) 20 | require.Equal(t, a, []interface{}{1, 3, 2}) 21 | } 22 | --------------------------------------------------------------------------------