├── .github └── workflows │ └── push.yml ├── .gitignore ├── .vscode └── settings.json ├── LICENSE ├── Makefile ├── README.md ├── changelogs └── v2.0.md ├── db.go ├── db_test.go ├── example ├── delete.go ├── generator.go ├── insert.go ├── mapper.go ├── open.go ├── other.go ├── outer_join.go ├── query_builder.go ├── serial_mapper.go ├── tx.go └── update.go ├── export_test.go ├── finder.go ├── finder_test.go ├── generator.go ├── generator_test.go ├── go.mod ├── go.sum ├── interface.go ├── mapper.go ├── mapper_test.go ├── mocks ├── mock_exql │ ├── interface.go │ └── saver.go └── mock_query │ └── query.go ├── model ├── fields.go ├── group_users.go ├── testmodel │ └── testmodel.go ├── user_groups.go ├── user_login_histories.go └── users.go ├── parser.go ├── parser_test.go ├── query.go ├── query ├── builder.go ├── builder_test.go ├── query.go ├── query_test.go ├── util.go └── util_test.go ├── query_test.go ├── saver.go ├── saver_test.go ├── schema └── model.sql ├── stmt.go ├── stmt_test.go ├── tag.go ├── tag_test.go ├── template └── README.md ├── test └── db.go ├── test_db_test.go ├── tool ├── composegen │ └── main.go ├── modelgen │ └── main.go └── rdmegen │ └── main.go ├── tx.go ├── tx_test.go ├── util.go └── util_test.go /.github/workflows/push.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | on: 3 | push: 4 | branches: 5 | - "**" 6 | env: 7 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 8 | jobs: 9 | test: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v1 13 | - uses: actions/setup-go@v2 14 | with: 15 | go-version: "1.22" 16 | - run: make up 17 | - name: Check generated codes 18 | run: | 19 | go install 20 | go run tool/modelgen/main.go 21 | go run tool/rdmegen/main.go 22 | make fmt 23 | git diff --exit-code 24 | - run: make test 25 | - uses: codecov/codecov-action@v1 26 | with: 27 | token: ${{ secrets.CODECOV_TOKEN }} 28 | file: coverage.out 29 | 
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | dist 3 | coverage.out 4 | compose.yml 5 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "explorer.fileNesting.enabled": true, 3 | "explorer.fileNesting.patterns": { 4 | "*.go": "${capture}_test.go", 5 | "go.mod": "go.sum" 6 | } 7 | } 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 LoiLo inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | up: down 2 | docker compose up -d 3 | down: compose.yml 4 | docker compose down 5 | fmt: 6 | go fmt github.com/loilo-inc/exql/... 7 | test: 8 | go test ./... -race -cover -coverprofile=coverage.out -covermode=atomic -count 1 9 | README.md: template/README.md tool/**/*.go example/*.go 10 | go run tool/rdmegen/main.go 11 | .PHONY: mocks 12 | mocks: 13 | rm -rf mocks/ 14 | go generate ./... 15 | compose.yml: tool/composegen/* 16 | go run tool/composegen/main.go 17 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | exql 2 | --- 3 | [![codecov](https://codecov.io/gh/loilo-inc/exql/branch/master/graph/badge.svg?token=aGixN2xIMP)](https://codecov.io/gh/loilo-inc/exql) 4 | 5 | Safe, strict and clear ORM for Go 6 | 7 | ## Introduction 8 | 9 | exql is a simple ORM library for MySQL, written in Go. It is designed to work at the minimum for real software development. It has a few, limited but enough convenient functionalities of SQL database. 10 | We adopted the data mapper model, not the active record. Records in the database are mapped into structs simply. Each model has no state and also no methods to modify itself and sync database records. You need to write bare SQL code for every operation you need except for a few cases. 11 | 12 | exql is designed by focusing on safety and clearness in SQL usage. In other words, we never generate any SQL statements that are potentially dangerous or have ambiguous side effects across tables and the database. 13 | 14 | It does: 15 | 16 | - make insert/update query from model structs. 17 | - map rows returned from the database into structs. 18 | - map joined table into one or more structs. 
19 | - provide a safe syntax for the transaction. 20 | - provide a framework to build dynamic SQL statements safely. 21 | - generate model codes automatically from the database. 22 | 23 | It DOESN'T 24 | 25 | - make delete/update statements across the table. 26 | - make unexpectedly slow select queries that don't use correct indices. 27 | - modify any database settings, schemas and indices. 28 | 29 | ## Table of contents 30 | 31 | - [exql](#exql) 32 | - [Introduction](#introduction) 33 | - [Table of contents](#table-of-contents) 34 | - [Usage](#usage) 35 | - [Open database connection](#open-database-connection) 36 | - [Code Generation](#code-generation) 37 | - [Execute queries](#execute-queries) 38 | - [Insert](#insert) 39 | - [Update](#update) 40 | - [Delete](#delete) 41 | - [Other](#other) 42 | - [Transaction](#transaction) 43 | - [Find records](#find-records) 44 | - [For simple query](#for-simple-query) 45 | - [For joined table](#for-joined-table) 46 | - [For outer-joined table](#for-outer-joined-table) 47 | - [Use query builder](#use-query-builder) 48 | - [License](#license) 49 | 50 | ## Usage 51 | 52 | ### Open database connection 53 | 54 | ```go 55 | package main 56 | 57 | import ( 58 | "time" 59 | 60 | "log" 61 | 62 | "github.com/loilo-inc/exql/v2" 63 | ) 64 | 65 | func OpenDB() exql.DB { 66 | db, err := exql.Open(&exql.OpenOptions{ 67 | // MySQL url for sql.Open() 68 | Url: "user:password@tcp(127.0.0.1:3306)/database?charset=utf8mb4&parseTime=True&loc=Local", 69 | // Max retry count for database connection failure 70 | MaxRetryCount: 3, 71 | RetryInterval: 10 * time.Second, 72 | }) 73 | if err != nil { 74 | log.Fatalf("open error: %s", err) 75 | return nil 76 | } 77 | return db 78 | } 79 | 80 | ``` 81 | 82 | ### Code Generation 83 | exql provides an automated code generator of models based on the database schema. This is a typical table schema of MySQL database. 
84 | 85 | ``` 86 | mysql> show columns from users; 87 | +-------+--------------+------+-----+---------+----------------+ 88 | | Field | Type | Null | Key | Default | Extra | 89 | +-------+--------------+------+-----+---------+----------------+ 90 | | id | int(11) | NO | PRI | NULL | auto_increment | 91 | | name | varchar(255) | NO | | NULL | | 92 | | age | int(11) | NO | | NULL | | 93 | +-------+--------------+------+-----+---------+----------------+ 94 | ``` 95 | 96 | To generate model codes, based on the schema, you need to write the code like this: 97 | 98 | ```go 99 | package main 100 | 101 | import ( 102 | "database/sql" 103 | "log" 104 | 105 | _ "github.com/go-sql-driver/mysql" 106 | "github.com/loilo-inc/exql/v2" 107 | ) 108 | 109 | func GenerateModels() { 110 | db, _ := sql.Open("mysql", "url-for-db") 111 | gen := exql.NewGenerator(db) 112 | err := gen.Generate(&exql.GenerateOptions{ 113 | // Directory path for result. Default is `model` 114 | OutDir: "dist", 115 | // Package name for models. Default is `model` 116 | Package: "dist", 117 | // Exclude table names for generation. Default is [] 118 | Exclude: []string{ 119 | "internal", 120 | }, 121 | }) 122 | if err != nil { 123 | log.Fatalf(err.Error()) 124 | } 125 | } 126 | 127 | ``` 128 | 129 | And results are mostly like this: 130 | 131 | ```go 132 | // This file is generated by exql. DO NOT edit. 
133 | package model 134 | 135 | type Users struct { 136 | Id int64 `exql:"column:id;type:int;primary;not null;auto_increment" json:"id"` 137 | Name string `exql:"column:name;type:varchar(255);not null" json:"name"` 138 | Age int64 `exql:"column:age;type:int;not null" json:"age"` 139 | } 140 | 141 | func (u *Users) TableName() string { 142 | return UsersTableName 143 | } 144 | 145 | type UpdateUsers struct { 146 | Id *int64 `exql:"column:id;type:int;primary;not null;auto_increment" json:"id"` 147 | Name *string `exql:"column:name;type:varchar(255);not null" json:"name"` 148 | Age *int64 `exql:"column:age;type:int;not null" json:"age"` 149 | } 150 | 151 | func (u *UpdateUsers) UpdateTableName() string { 152 | return UsersTableName 153 | } 154 | 155 | const UsersTableName = "users" 156 | 157 | ``` 158 | 159 | `Users` is the destination of the data mapper. It only has value fields and one method, `TableName()`. This is the implementation of `exql.Model` that can be passed into data saver. All structs, methods and field tags must be preserved as it is, for internal use. If you want to modify the results, you must run the generator again. 160 | 161 | `UpdateUsers` is a partial structure for the data model. It has identical name fields to `Users`, but all types are represented as a pointer. It is used to update table columns partially. In other words, it is a designated, typesafe map for the model. 162 | 163 | ### Execute queries 164 | 165 | There are several ways to publish SQL statements with exql. 166 | 167 | #### Insert 168 | 169 | INSERT query is constructed automatically based on model data and executed without writing the statement. To insert new records into the database, set values to the model and pass it to `exql.DB#Insert` method. 
170 | 171 | ```go 172 | package main 173 | 174 | import ( 175 | "log" 176 | 177 | "github.com/loilo-inc/exql/v2" 178 | "github.com/loilo-inc/exql/v2/model" 179 | ) 180 | 181 | func Insert(db exql.DB) { 182 | // Create a user model 183 | // Primary key (id) is not needed to set. 184 | // It will be ignored on building the insert query. 185 | user := model.Users{ 186 | Name: "Go", 187 | } 188 | // You must pass the model as a pointer. 189 | if result, err := db.Insert(&user); err != nil { 190 | log.Fatal(err.Error()) 191 | } else { 192 | insertedId, _ := result.LastInsertId() 193 | // Inserted id is assigned into the auto-increment field after the insertion, 194 | // if these field is int64/uint64 195 | if insertedId != user.Id { 196 | log.Fatal("never happens") 197 | } 198 | } 199 | } 200 | 201 | func BulkInsert(db exql.DB) { 202 | user1 := model.Users{Name: "Go"} 203 | user2 := model.Users{Name: "Lang"} 204 | // INSERT INTO users (name) VALUES (?),(?) 205 | // ["Go", "Lang"] 206 | if q, err := exql.QueryForBulkInsert(&user1, &user2); err != nil { 207 | log.Fatal(err) 208 | } else if _, err := db.Exec(q); err != nil { 209 | log.Fatal(err) 210 | } 211 | // NOTE: unlike a single insertion, bulk insertion doesn't obtain auto-incremented values from results. 212 | } 213 | 214 | ``` 215 | 216 | #### Update 217 | 218 | UPDATE query is constructed automatically based on the model update struct. To avoid unexpected updates to the table, all values are represented by a pointer of data type. 219 | 220 | ```go 221 | package main 222 | 223 | import ( 224 | "log" 225 | 226 | "github.com/loilo-inc/exql/v2" 227 | "github.com/loilo-inc/exql/v2/model" 228 | ) 229 | 230 | // Using designated update struct 231 | func UpdateModel(db exql.DB) { 232 | // UPDATE `users` SET `name` = `GoGo` WHERE `id` = ? 
233 | // [1] 234 | _, err := db.UpdateModel(&model.UpdateUsers{ 235 | Name: exql.Ptr("GoGo"), 236 | }, exql.Where("id = ?", 1), 237 | ) 238 | if err != nil { 239 | log.Fatal(err) 240 | } 241 | } 242 | 243 | // With table name and key-value pairs 244 | func Update(db exql.DB) { 245 | // UPDATE `users` SET `name` = `GoGo` WHERE `id` = ? 246 | // [1] 247 | _, err := db.Update("users", map[string]any{ 248 | "name": "GoGo", 249 | }, exql.Where("id = ?", 1)) 250 | if err != nil { 251 | log.Fatal(err) 252 | } 253 | } 254 | 255 | ``` 256 | 257 | #### Delete 258 | 259 | DELETE query is published to the table with given conditions. There's no way to construct DELETE query from the model as a security reason. 260 | 261 | ```go 262 | package main 263 | 264 | import ( 265 | "log" 266 | 267 | "github.com/loilo-inc/exql/v2" 268 | ) 269 | 270 | func Delete(db exql.DB) { 271 | // DELETE FROM `users` WHERE id = ? 272 | // [1] 273 | _, err := db.Delete("users", exql.Where("id = ?", 1)) 274 | if err != nil { 275 | log.Fatal(err) 276 | } 277 | } 278 | 279 | ``` 280 | 281 | #### Other 282 | 283 | Other queries should be executed by `sql.DB` that got from `DB`. 284 | 285 | ```go 286 | package main 287 | 288 | import ( 289 | "log" 290 | 291 | "github.com/loilo-inc/exql/v2" 292 | ) 293 | 294 | // To execute other kind of queries, unwrap sql.DB. 295 | func OtherQuery(db exql.DB) { 296 | // db.DB() returns *sql.DB 297 | row := db.DB().QueryRow("SELECT COUNT(*) FROM users") 298 | var count int 299 | row.Scan(&count) 300 | log.Printf("%d", count) 301 | } 302 | 303 | ``` 304 | 305 | ### Transaction 306 | 307 | Transaction with `BEGIN`~`COMMIT`/`ROLLBACK` is done by `TransactionWithContext`. You don't need to call `BeginTx` and `Commit`/`Rollback` manually and all atomic operations are done within a callback. 
308 | 309 | ```go 310 | package main 311 | 312 | import ( 313 | "context" 314 | "database/sql" 315 | "time" 316 | 317 | "github.com/loilo-inc/exql/v2" 318 | "github.com/loilo-inc/exql/v2/model" 319 | ) 320 | 321 | func Transaction(db exql.DB) { 322 | timeout, _ := context.WithTimeout(context.Background(), 10*time.Second) 323 | err := db.TransactionWithContext(timeout, &sql.TxOptions{ 324 | Isolation: sql.LevelDefault, 325 | ReadOnly: false, 326 | }, func(tx exql.Tx) error { 327 | user := model.Users{Name: "go"} 328 | _, err := tx.Insert(&user) 329 | return err 330 | }) 331 | if err != nil { 332 | // Transaction has been rolled back 333 | } else { 334 | // Transaction has been committed 335 | } 336 | } 337 | 338 | ``` 339 | 340 | ### Find records 341 | 342 | To find records from the database, use `Find`/`FindMany` method. It executes the query and maps results into structs correctly. 343 | 344 | #### For simple query 345 | 346 | ```go 347 | package main 348 | 349 | import ( 350 | "log" 351 | 352 | "github.com/loilo-inc/exql/v2" 353 | "github.com/loilo-inc/exql/v2/model" 354 | "github.com/loilo-inc/exql/v2/query" 355 | ) 356 | 357 | func Find(db exql.DB) { 358 | // Destination model struct 359 | var user model.Users 360 | // Pass as a pointer 361 | err := db.Find(query.Q(`SELECT * FROM users WHERE id = ?`, 1), &user) 362 | if err != nil { 363 | log.Fatal(err) 364 | } 365 | log.Printf("%d", user.Id) // -> 1 366 | } 367 | 368 | func FindMany(db exql.DB) { 369 | // Destination slice of models. 370 | // NOTE: It must be the slice of pointers of models. 371 | var users []*model.Users 372 | // Passing destination to MapMany(). 373 | // Second argument must be a pointer. 
374 | err := db.FindMany(query.Q(`SELECT * FROM users LIMIT ?`, 5), &users) 375 | if err != nil { 376 | log.Fatal(err) 377 | } 378 | log.Printf("%d", len(users)) // -> 5 379 | } 380 | 381 | ``` 382 | 383 | #### For joined table 384 | 385 | ```go 386 | package main 387 | 388 | import ( 389 | "log" 390 | 391 | "github.com/loilo-inc/exql/v2" 392 | "github.com/loilo-inc/exql/v2/model" 393 | ) 394 | 395 | /* 396 | user_groups has many users 397 | users belongs to many groups 398 | */ 399 | func MapSerial(db exql.DB) { 400 | query := ` 401 | SELECT * FROM users 402 | JOIN group_users ON group_users.user_id = users.id 403 | JOIN user_groups ON user_groups.id = group_users.id 404 | WHERE user_groups.name = ?` 405 | rows, err := db.DB().Query(query, "goland") 406 | if err != nil { 407 | log.Fatal(err) 408 | return 409 | } 410 | defer rows.Close() 411 | serialMapper := exql.NewSerialMapper(func(i int) string { 412 | // Each column's separator is `id` 413 | return "id" 414 | }) 415 | var users []*model.Users 416 | for rows.Next() { 417 | var user model.Users 418 | var groupUsers model.GroupUsers 419 | var userGroup model.UserGroups 420 | // Create serial mapper. It will split joined columns by logical tables. 421 | // In this case, joined table and destination mappings are: 422 | // | users | group_users | user_groups | 423 | // + --------- + ------------------------ + ------------- + 424 | // | id | name | id | user_id | group_id | id | name | 425 | // + --------- + ------------------------ + ------------- + 426 | // | &user | &groupUsers | &userGroup | 427 | // + --------- + ------------------------ + ------------- + 428 | if err := serialMapper.Map(rows, &user, &groupUsers, &userGroup); err != nil { 429 | log.Fatalf(err.Error()) 430 | return 431 | } 432 | users = append(users, &user) 433 | } 434 | // enumerate users... 
435 | } 436 | 437 | ``` 438 | 439 | #### For outer-joined table 440 | 441 | ```go 442 | package main 443 | 444 | import ( 445 | "log" 446 | 447 | "github.com/loilo-inc/exql/v2" 448 | "github.com/loilo-inc/exql/v2/model" 449 | ) 450 | 451 | func MapSerialOuterJoin(db exql.DB) { 452 | query := ` 453 | SELECT * FROM users 454 | LEFT JOIN group_users ON group_users.user_id = users.id 455 | LEFT JOIN user_groups ON user_groups.id = group_users.id 456 | WHERE users.id = ?` 457 | rows, err := db.DB().Query(query, 1) 458 | if err != nil { 459 | log.Fatal(err) 460 | return 461 | } 462 | defer rows.Close() 463 | serialMapper := exql.NewSerialMapper(func(i int) string { 464 | // Each column's separator is `id` 465 | return "id" 466 | }) 467 | var users []*model.Users 468 | var groups []*model.UserGroups 469 | for rows.Next() { 470 | var user model.Users 471 | var groupUser *model.GroupUsers // Use *GroupUsers/*Group for outer join so that it can be nil 472 | var group *model.UserGroups // when the values of outer joined columns are NULL. 473 | if err := serialMapper.Map(rows, &user, &groupUser, &group); err != nil { 474 | log.Fatal(err.Error()) 475 | return 476 | } 477 | users = append(users, &user) 478 | groups = append(groups, group) // group = nil when the user does not belong to any group. 479 | } 480 | // enumerate users and groups. 481 | } 482 | 483 | ``` 484 | 485 | ### Use query builder 486 | 487 | `exql/query` package is a low-level API for building complicated SQL statements. See [V2 Release Notes](https://github.com/loilo-inc/exql/blob/main/changelogs/v2.0.md#exqlquery-package) for more details. 488 | 489 | ```go 490 | package main 491 | 492 | import ( 493 | "github.com/loilo-inc/exql/v2" 494 | "github.com/loilo-inc/exql/v2/query" 495 | ) 496 | 497 | func Query(db exql.DB) { 498 | q := query.New( 499 | `SELECT * FROM users WHERE id IN (:?) AND age = ?`, 500 | query.V(1, 2, 3), 20, 501 | ) 502 | // SELECT * FROM users WHERE id IN (?,?,?) AND age = ? 
503 | // [1,2,3,20] 504 | db.Query(q) 505 | } 506 | 507 | func QueryBulider(db exql.DB) { 508 | qb := query.NewBuilder() 509 | qb.Sprintf("SELECT * FROM %s", "users") 510 | qb.Query("WHERE id IN (:?) AND age >= ?", query.V(1, 2), 20) 511 | // SELECT * FROM users WHERE id IN (?,?) AND age >= ? 512 | // [1,2,20] 513 | db.Query(qb.Build()) 514 | } 515 | 516 | func CondBulider(db exql.DB) { 517 | cond := query.Cond("id = ?", 1) 518 | cond.And("age >= ?", 20) 519 | cond.And("name in (:?)", query.V("go", "lang")) 520 | q := query.New("SELECT * FROM users WHERE :?", cond) 521 | // SELECT * FROM users WHERE id = ? and age >= ? and name in (?,?) 522 | // [1, 20, go, lang] 523 | db.Query(q) 524 | } 525 | 526 | ``` 527 | 528 | ## License 529 | 530 | MIT License / Copyright (c) LoiLo inc. 531 | 532 | -------------------------------------------------------------------------------- /changelogs/v2.0.md: -------------------------------------------------------------------------------- 1 | # exql v2 Release Note 2 | 3 | 2022-02-03 4 | Yusuke SAKURAI 5 | Software Engineer at LoiLo Inc. 6 | 7 | ## Introduction 8 | 9 | exql@v2 is the first major update from the release in 2020. Through the real development experience of 3 years, we met many DRY issues and learned many practices from them. All new features and improvements included in v2 are actual resolutions for them. We hope that exql becomes a more pragmatic and essential tool that lets Gophers go far. 10 | 11 | ## New methods of Saver 12 | 13 | New methods were introduced into `exql.Saver`. 14 | 15 | ### Delete/DeleteContext 16 | 17 | ```go 18 | Delete(table string, where q.Condition) (sql.Result, error) 19 | DeleteContext(ctx context.Context, table string, where q.Condition) (sql.Result, error) 20 | ``` 21 | 22 | Those are methods for deleting entities from database. Their usage is mostly the same as Update/UpdateContext.
They require a where condition clause and do not accept an empty condition, which would cause the deletion of all data. Please use them carefully. 23 | 24 | ### Query Executors 25 | 26 | ```go 27 | Exec(query q.Query) (sql.Result, error) 28 | ExecContext(ctx context.Context, query q.Query) (sql.Result, error) 29 | Query(query q.Query) (*sql.Rows, error) 30 | QueryContext(ctx context.Context, query q.Query) (*sql.Rows, error) 31 | QueryRow(query q.Query) (*sql.Row, error) 32 | QueryRowContext(ctx context.Context, query q.Query) (*sql.Row, error) 33 | ``` 34 | 35 | Those are methods that execute the new `query.Query` interface directly instead of raw SQL statements. It is documented in detail later. 36 | 37 | ### Query maker for bulk insertion 38 | 39 | ```go 40 | func QueryForBulkInsert[T Model](modelPtrs ...T) (q.Query, error) 41 | ``` 42 | 43 | It is a similar function to `QueryForInsert`/`QueryForUpdate`, making an insert statement for multiple entities. It is useful for inserting models in batch. Unlike a single insertion, auto-increment fields are not filled automatically. 44 | 45 | Example: 46 | 47 | ```go 48 | q, err := exql.QueryForBulkInsert(user1, user2) 49 | result, err := db.Exec(q) 50 | // INSERT INTO users (age,name) VALUES (?,?),(?,?) 51 | // [20, go, 30, lang] 52 | ``` 53 | 54 | ## Finder interface 55 | 56 | A new interface, `Finder`, was introduced. This is an integrated interface for querying records and mapping rows into models.
In the former version, a typical SELECT query and code were like this: 57 | 58 | ```go 59 | rows, err := db.DB().Query(`SELECT * FROM users WHERE id = ?`, 1) 60 | if err != nil { 61 | log.Fatal(err) 62 | } else { 63 | var user model.Users 64 | if err := db.Map(rows, &user); err != nil { 65 | log.Fatal(err) 66 | } 67 | log.Printf("%d", user.Id) // -> 1 68 | } 69 | ``` 70 | 71 | That can be rewritten briefly in the new version: 72 | 73 | ```go 74 | var user model.Users 75 | err := db.Find(query.Q(`SELECT * FROM users WHERE id = ?`, 1), &user) 76 | if err != nil { 77 | log.Fatal(err) 78 | } 79 | log.Printf("%d", user.Id) // -> 1 80 | ``` 81 | 82 | `Mapper` has been deprecated and is going to be removed in the next major version. Early refactoring is recommended. 83 | 84 | ## StmtExecutor 85 | 86 | `StmtExecutor` is the `Executor` that caches queries as `sql.Stmt`. This is designed for the repeated execution of the same query. It prevents the potential exhaustion of the connection pool caused by too many prepared statements. 87 | 88 | ```go 89 | stmtExecer := exql.NewStmtExecutor(tx.Tx()) 90 | defer stmtExecer.Close() 91 | stmtSaver := exql.NewSaver(stmtExecer) 92 | stmtSaver.Insert(&model.Users{Name: "user1"}) 93 | stmtSaver.Insert(&model.Users{Name: "user2"}) 94 | ``` 95 | 96 | The code above is equivalent to the SQL below: 97 | 98 | ```sql 99 | PREPARE stmt FROM "INSERT INTO `users` (`name`) VALUES (?)"; 100 | SET @name = "user1"; 101 | EXECUTE stmt USING @name; 102 | SET @name = "user2"; 103 | EXECUTE stmt USING @name; 104 | DEALLOCATE PREPARE stmt; 105 | ``` 106 | 107 | This is advantageous for the concurrent execution of the same query because it holds a single underlying connection for the statement. Without preparation, even if the queries are identical, a different connection is assigned to each query.
108 | 109 | ## Interface Updates 110 | 111 | ### New interfaces 112 | 113 | ```go 114 | // An abstraction of sql.DB/sql.Tx 115 | type Executor interface { 116 | Exec(query string, args ...any) (sql.Result, error) 117 | ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) 118 | Query(query string, args ...any) (*sql.Rows, error) 119 | QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) 120 | QueryRow(query string, args ...any) *sql.Row 121 | QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row 122 | Prepare(stmt string) (*sql.Stmt, error) 123 | PrepareContext(ctx context.Context, stmt string) (*sql.Stmt, error) 124 | } 125 | 126 | type Model interface { 127 | TableName() string 128 | } 129 | 130 | type ModelUpdate interface { 131 | UpdateTableName() string 132 | } 133 | ``` 134 | 135 | Some interfaces introduced for the convenience. 136 | 137 | ### Removed interfaces 138 | 139 | Two interfaces were removed. 140 | 141 | - `exql.Clause`: Use `query.Query` or `query.Condition` instead. 142 | - `exql.SET`: Use `map[string]any` instead. 143 | 144 | ## Changes in generated codes 145 | 146 | Codes that generated by `Generator` changed partially. It was simplified for presenting model's table name. 147 | 148 | Old: 149 | 150 | ```go 151 | type usersTable struct {} 152 | var UsersTable = &usersTable{} 153 | func (u *usersTable) Name() string { 154 | return "users" 155 | } 156 | ``` 157 | 158 | New: 159 | 160 | ```go 161 | const UsersTableName = "users" 162 | ``` 163 | 164 | ## exql/query package 165 | 166 | `query` package is the biggest new feature in exql v2. It is an independent framework for building SQL query. In exql v1, except for insert/update queries, you must write a raw SQL down. 167 | However, from the real development for 3 years, we've found there are many cases for constructing SQL queries dynamically with `fmt.Sprintf`. 
The typical code is below: 168 | 169 | ```go 170 | // make a select statement with the dynamic condition clause 171 | ids := []int{1,2,3} 172 | placeholders := "?,?,?" // depends on ids size 173 | q := fmt.Sprintf(`SELECT * FROM users WHERE id IN (%s) AND age = ?`, placeholders) 174 | var args []any 175 | args = append(args, util.MapToInterfaces(ids)..., ) 176 | args = append(args, 20) 177 | rows, err := db.DB().Query(q, args...) 178 | ``` 179 | 180 | This is a very simple, straightforward expression of code that Gophers like. But it contains much redundant, unsafe code, and you also have to repeat yourself. We designed the system to make query construction simpler, shorter and safer. 181 | We paid much attention to not losing the clearness of SQL generation. The implicit generation of ambiguous SQL statements, especially by a library, is terribly harmful. SQL statements should be written down in full by programmers as far as possible; the library should only help them partially. That is what we think a good ORM is. 182 | 183 | The code above can be rewritten in v2: 184 | 185 | ```go 186 | q := query.New( 187 | `SELECT * FROM users WHERE id IN (:?) AND age = ?`, query.V(1,2,3), 20, 188 | ) 189 | // SELECT * FROM users WHERE id IN (?,?,?) AND age = ? 190 | // [1,2,3,20] 191 | rows, err := db.Query(q) 192 | ``` 193 | 194 | You may have noticed an unfamiliar symbol, `:?`, which is the new designated placeholder of exql. It is accepted by `query.New` and many methods in the package, and is interpolated with the corresponding value in the rest arguments. In this case, the first argument, `query.V()`, is used. `query.V()` is one of the utility functions that make a `query.Query` interface. The `query.Query` interface is an abstraction of the query component and argument values. It holds the query string and values separately. It embeds the string into the final SQL statement and passes the values later in order. The built query object can be passed to `exql.Saver` to get execution results from the database.
195 | 196 | What is the functionality of `query.V()`? 197 | 198 | The code below: 199 | 200 | ```go 201 | ids := []int{1,2} 202 | db.Query(query.New("select * from users where id in (:?)", query.V(ids...))) 203 | ``` 204 | 205 | ...is the same as: 206 | 207 | ```go 208 | db.DB().Query("select * from users where id in (?,?)", 1, 2) 209 | ``` 210 | 211 | `query.V()` transforms values to SQL placeholder (`?`) and joins them by commas. Actual values are stored in the buffer and passed on calling `Exec`/`Query` methods of `sql.DB`. 212 | By using those functions, dynamic SQL generation gets more simple and code less. 213 | 214 | ### query.Condition 215 | 216 | `query.Condition` is the builder especially for the condition of where clause. 217 | 218 | ```go 219 | cond := query.Cond("id = ?", 1) 220 | cond.And("age >= ?", 20) 221 | cond.And("name in (:?)", query.V("go","lang")) 222 | q := query.New("select * from users where :?", cond) 223 | // select * from users where id = ? and age >= ? and name in (?,?) 224 | // [1, 20, go, lang] 225 | ``` 226 | 227 | ### query.Builder 228 | 229 | `query.Builder` is the general utility for building query dynamically. It has very similar interfaces to `strings.Builder`. But it is different in points that it accepts exql placeholder and holds values separately. 230 | 231 | ```go 232 | var qb query.Builder 233 | qb.Sprintf("SELECT * FROM %s", "users") 234 | qb.Query("WHERE id IN (:?) AND age >= ?", query.V(1,2), 20) 235 | // SELECT * FROM users WHERE id IN (?,?) AND age >= ? 
236 | // [1,2,20] 237 | rows, err := db.Query(qb.Build()) 238 | ``` 239 | -------------------------------------------------------------------------------- /db.go: -------------------------------------------------------------------------------- 1 | package exql 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "sync" 7 | "time" 8 | 9 | "log" 10 | 11 | "golang.org/x/xerrors" 12 | ) 13 | 14 | type DB interface { 15 | Saver 16 | Mapper 17 | Finder 18 | // DB returns the underlying *sql.DB object. 19 | DB() *sql.DB 20 | // SetDB replaces the underlying *sql.DB object. 21 | SetDB(db *sql.DB) 22 | // Transaction begins a transaction and commits after the callback is called. 23 | // If an error is returned from the callback, it is rolled back. 24 | // It is equivalent to TransactionWithContext(context.Background(), nil, callback). 25 | Transaction(callback func(tx Tx) error) error 26 | // TransactionWithContext is the same as Transaction(), 27 | // but begins the transaction with the given ctx and opts. 28 | TransactionWithContext(ctx context.Context, opts *sql.TxOptions, callback func(tx Tx) error) error 29 | // Close calls db.Close(). 30 | Close() error 31 | } 32 | 33 | type db struct { // db is the DB implementation built by NewDB. 34 | *saver 35 | *finder 36 | *mapper 37 | db *sql.DB 38 | mutex sync.Mutex // held by SetDB while the handle is swapped 39 | } 40 | 41 | // OpenFunc is an abstraction of sql.Open function. 42 | type OpenFunc func(driverName string, url string) (*sql.DB, error) 43 | 44 | type OpenOptions struct { // OpenOptions configures Open and OpenContext. 45 | // @required 46 | // DSN passed to the opener as the connection URL; OpenContext returns an error when it is empty. 47 | Url string 48 | // Driver name passed to the opener. @default "mysql" 49 | DriverName string 50 | // Maximum number of connection attempts; values <= 0 select the default. @default 5 51 | MaxRetryCount int 52 | // Wait after each failed attempt; values <= 0 select the default. @default 5s 53 | RetryInterval time.Duration 54 | // Custom opener function. sql.Open is used when nil. 55 | OpenFunc OpenFunc 56 | } 57 | 58 | // Open is OpenContext called with context.Background(). 59 | func Open(opts *OpenOptions) (DB, error) { 60 | return OpenContext(context.Background(), opts) 61 | } 62 | 63 | // OpenContext opens the connection to the database and returns it as an exql.DB.
64 | // If opening or pinging fails, it waits RetryInterval and tries again, 65 | // up to MaxRetryCount attempts, then returns the last error. 66 | // 67 | // Example: 68 | // 69 | // db, err := exql.OpenContext(context.Background(), &exql.OpenOptions{ 70 | // Url: "user:pass@tcp(127.0.0.1:3306)/database?charset=utf8mb4&parseTime=True&loc=Local", 71 | // MaxRetryCount: 3, 72 | // RetryInterval: 10 * time.Second, 73 | // }) 74 | func OpenContext(ctx context.Context, opts *OpenOptions) (DB, error) { 75 | if opts.Url == "" { 76 | return nil, xerrors.New("opts.Url is required") 77 | } 78 | driverName := "mysql" 79 | if opts.DriverName != "" { 80 | driverName = opts.DriverName 81 | } 82 | maxRetryCount := 5 // defaults; replaced below only by positive option values 83 | retryInterval := 5 * time.Second 84 | if opts.MaxRetryCount > 0 { 85 | maxRetryCount = opts.MaxRetryCount 86 | } 87 | if opts.RetryInterval > 0 { 88 | retryInterval = opts.RetryInterval 89 | } 90 | var d *sql.DB 91 | var err error 92 | var openFunc OpenFunc = sql.Open 93 | if opts.OpenFunc != nil { 94 | openFunc = opts.OpenFunc 95 | } 96 | retryCnt := 0 97 | for retryCnt < maxRetryCount { 98 | d, err = openFunc(driverName, opts.Url) 99 | if err != nil { 100 | goto retry 101 | } else if err = d.PingContext(ctx); err != nil { // ping to verify the connection is actually usable 102 | goto retry // NOTE(review): d is not closed before the next attempt opens a new handle 103 | } else { 104 | goto success 105 | } 106 | retry: 107 | log.Printf("failed to connect database: %s, retrying after %ds...\n", err, int(retryInterval.Seconds())) 108 | <-time.NewTimer(retryInterval).C // NOTE(review): waits the full interval even if ctx is cancelled, and also after the final attempt 109 | retryCnt++ 110 | } 111 | if err != nil { // every attempt failed; err is the last failure 112 | return nil, err 113 | } 114 | success: 115 | return NewDB(d), nil 116 | } 117 | 118 | func NewDB(d *sql.DB) DB { // NewDB wraps an existing *sql.DB in an exql.DB. 119 | return &db{ 120 | saver: newSaver(d), 121 | finder: newFinder(d), 122 | mapper: &mapper{}, 123 | db: d, 124 | } 125 | } 126 | 127 | func (d *db) Close() error { 128 | return d.db.Close() 129 | } 130 | 131 | func (d *db) DB() *sql.DB { // NOTE(review): reads d.db without taking d.mutex, unlike SetDB 132 | return d.db 133 | } 134 | 135 | func (d *db) SetDB(db *sql.DB) { // NOTE(review): only d.db and the saver's executor are replaced; confirm whether the finder built in NewDB should be rebound too 136 | d.mutex.Lock() 137 | defer d.mutex.Unlock() 138 | d.db = db 139 | d.saver.ex = 
db 140 | } 141 | 142 | func (d *db) Transaction(callback func(tx Tx) error) error { 143 | return d.TransactionWithContext(context.Background(), nil, callback) 144 | } 145 | 146 | func (d *db) TransactionWithContext(ctx context.Context, opts *sql.TxOptions, callback func(tx Tx) error) error { 147 | return Transaction(d.db, ctx, opts, callback) 148 | } 149 | -------------------------------------------------------------------------------- /db_test.go: -------------------------------------------------------------------------------- 1 | package exql_test 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "testing" 7 | 8 | "github.com/loilo-inc/exql/v2" 9 | "github.com/loilo-inc/exql/v2/test" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestDb_DB(t *testing.T) { 14 | db := testDb() 15 | db.SetDB(nil) 16 | assert.Nil(t, db.DB()) 17 | } 18 | 19 | func TestNewDB(t *testing.T) { 20 | d := testSqlDB() 21 | db := exql.NewDB(d) 22 | assert.Equal(t, d, db.DB()) 23 | } 24 | 25 | func TestOpen(t *testing.T) { 26 | t.Run("should call OpenContext", func(t *testing.T) { 27 | d, err := exql.Open(&exql.OpenOptions{ 28 | Url: test.DbUrl, 29 | }) 30 | if err != nil { 31 | t.Fatal(err) 32 | } 33 | assert.NotNil(t, d) 34 | }) 35 | } 36 | 37 | func TestOpenContext(t *testing.T) { 38 | t.Run("should return error when url is empty", func(t *testing.T) { 39 | _, err := exql.OpenContext(context.TODO(), &exql.OpenOptions{ 40 | Url: "", 41 | }) 42 | assert.EqualError(t, err, "opts.Url is required") 43 | }) 44 | t.Run("with custom opener", func(t *testing.T) { 45 | var called bool 46 | _, err := exql.OpenContext(context.TODO(), &exql.OpenOptions{ 47 | Url: test.DbUrl, 48 | OpenFunc: func(driverName string, url string) (*sql.DB, error) { 49 | called = true 50 | return sql.Open(driverName, url) 51 | }, 52 | }) 53 | assert.NoError(t, err) 54 | assert.True(t, called) 55 | }) 56 | } 57 | --------------------------------------------------------------------------------
/example/delete.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | 6 | "github.com/loilo-inc/exql/v2" 7 | ) 8 | 9 | func Delete(db exql.DB) { 10 | // DELETE FROM `users` WHERE id = ? 11 | // [1] 12 | _, err := db.Delete("users", exql.Where("id = ?", 1)) 13 | if err != nil { 14 | log.Fatal(err) 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /example/generator.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "database/sql" 5 | "log" 6 | 7 | _ "github.com/go-sql-driver/mysql" 8 | "github.com/loilo-inc/exql/v2" 9 | ) 10 | 11 | func GenerateModels() { 12 | db, _ := sql.Open("mysql", "url-for-db") 13 | gen := exql.NewGenerator(db) 14 | err := gen.Generate(&exql.GenerateOptions{ 15 | // Directory path for result. Default is `model` 16 | OutDir: "dist", 17 | // Package name for models. Default is `model` 18 | Package: "dist", 19 | // Exclude table names for generation. Default is [] 20 | Exclude: []string{ 21 | "internal", 22 | }, 23 | }) 24 | if err != nil { 25 | log.Fatalf(err.Error()) 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /example/insert.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | 6 | "github.com/loilo-inc/exql/v2" 7 | "github.com/loilo-inc/exql/v2/model" 8 | ) 9 | 10 | func Insert(db exql.DB) { 11 | // Create a user model 12 | // Primary key (id) is not needed to set. 13 | // It will be ignored on building the insert query. 14 | user := model.Users{ 15 | Name: "Go", 16 | } 17 | // You must pass the model as a pointer. 
18 | if result, err := db.Insert(&user); err != nil { 19 | log.Fatal(err.Error()) 20 | } else { 21 | insertedId, _ := result.LastInsertId() 22 | // Inserted id is assigned into the auto-increment field after the insertion, 23 | // if these field is int64/uint64 24 | if insertedId != user.Id { 25 | log.Fatal("never happens") 26 | } 27 | } 28 | } 29 | 30 | func BulkInsert(db exql.DB) { 31 | user1 := model.Users{Name: "Go"} 32 | user2 := model.Users{Name: "Lang"} 33 | // INSERT INTO users (name) VALUES (?),(?) 34 | // ["Go", "Lang"] 35 | if q, err := exql.QueryForBulkInsert(&user1, &user2); err != nil { 36 | log.Fatal(err) 37 | } else if _, err := db.Exec(q); err != nil { 38 | log.Fatal(err) 39 | } 40 | // NOTE: unlike a single insertion, bulk insertion doesn't obtain auto-incremented values from results. 41 | } 42 | -------------------------------------------------------------------------------- /example/mapper.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | 6 | "github.com/loilo-inc/exql/v2" 7 | "github.com/loilo-inc/exql/v2/model" 8 | "github.com/loilo-inc/exql/v2/query" 9 | ) 10 | 11 | func Find(db exql.DB) { 12 | // Destination model struct 13 | var user model.Users 14 | // Pass as a pointer 15 | err := db.Find(query.Q(`SELECT * FROM users WHERE id = ?`, 1), &user) 16 | if err != nil { 17 | log.Fatal(err) 18 | } 19 | log.Printf("%d", user.Id) // -> 1 20 | } 21 | 22 | func FindMany(db exql.DB) { 23 | // Destination slice of models. 24 | // NOTE: It must be the slice of pointers of models. 25 | var users []*model.Users 26 | // Passing destination to MapMany(). 27 | // Second argument must be a pointer. 
28 | err := db.FindMany(query.Q(`SELECT * FROM users LIMIT ?`, 5), &users) 29 | if err != nil { 30 | log.Fatal(err) 31 | } 32 | log.Printf("%d", len(users)) // -> 5 33 | } 34 | -------------------------------------------------------------------------------- /example/open.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "time" 5 | 6 | "log" 7 | 8 | "github.com/loilo-inc/exql/v2" 9 | ) 10 | 11 | func OpenDB() exql.DB { 12 | db, err := exql.Open(&exql.OpenOptions{ 13 | // MySQL url for sql.Open() 14 | Url: "user:password@tcp(127.0.0.1:3306)/database?charset=utf8mb4&parseTime=True&loc=Local", 15 | // Max retry count for database connection failure 16 | MaxRetryCount: 3, 17 | RetryInterval: 10 * time.Second, 18 | }) 19 | if err != nil { 20 | log.Fatalf("open error: %s", err) 21 | return nil 22 | } 23 | return db 24 | } 25 | -------------------------------------------------------------------------------- /example/other.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | 6 | "github.com/loilo-inc/exql/v2" 7 | ) 8 | 9 | // To execute other kind of queries, unwrap sql.DB. 
10 | func OtherQuery(db exql.DB) { 11 | // db.DB() returns *sql.DB 12 | row := db.DB().QueryRow("SELECT COUNT(*) FROM users") 13 | var count int 14 | row.Scan(&count) 15 | log.Printf("%d", count) 16 | } 17 | -------------------------------------------------------------------------------- /example/outer_join.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | 6 | "github.com/loilo-inc/exql/v2" 7 | "github.com/loilo-inc/exql/v2/model" 8 | ) 9 | 10 | func MapSerialOuterJoin(db exql.DB) { 11 | query := ` 12 | SELECT * FROM users 13 | LEFT JOIN group_users ON group_users.user_id = users.id 14 | LEFT JOIN user_groups ON user_groups.id = group_users.id 15 | WHERE users.id = ?` 16 | rows, err := db.DB().Query(query, 1) 17 | if err != nil { 18 | log.Fatal(err) 19 | return 20 | } 21 | defer rows.Close() 22 | serialMapper := exql.NewSerialMapper(func(i int) string { 23 | // Each column's separator is `id` 24 | return "id" 25 | }) 26 | var users []*model.Users 27 | var groups []*model.UserGroups 28 | for rows.Next() { 29 | var user model.Users 30 | var groupUser *model.GroupUsers // Use *GroupUsers/*Group for outer join so that it can be nil 31 | var group *model.UserGroups // when the values of outer joined columns are NULL. 32 | if err := serialMapper.Map(rows, &user, &groupUser, &group); err != nil { 33 | log.Fatal(err.Error()) 34 | return 35 | } 36 | users = append(users, &user) 37 | groups = append(groups, group) // group = nil when the user does not belong to any group. 38 | } 39 | // enumerate users and groups. 
40 | } 41 | -------------------------------------------------------------------------------- /example/query_builder.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/loilo-inc/exql/v2" 5 | "github.com/loilo-inc/exql/v2/query" 6 | ) 7 | 8 | func Query(db exql.DB) { 9 | q := query.New( 10 | `SELECT * FROM users WHERE id IN (:?) AND age = ?`, 11 | query.V(1, 2, 3), 20, 12 | ) 13 | // SELECT * FROM users WHERE id IN (?,?,?) AND age = ? 14 | // [1,2,3,20] 15 | db.Query(q) 16 | } 17 | 18 | func QueryBulider(db exql.DB) { 19 | qb := query.NewBuilder() 20 | qb.Sprintf("SELECT * FROM %s", "users") 21 | qb.Query("WHERE id IN (:?) AND age >= ?", query.V(1, 2), 20) 22 | // SELECT * FROM users WHERE id IN (?,?) AND age >= ? 23 | // [1,2,20] 24 | db.Query(qb.Build()) 25 | } 26 | 27 | func CondBulider(db exql.DB) { 28 | cond := query.Cond("id = ?", 1) 29 | cond.And("age >= ?", 20) 30 | cond.And("name in (:?)", query.V("go", "lang")) 31 | q := query.New("SELECT * FROM users WHERE :?", cond) 32 | // SELECT * FROM users WHERE id = ? and age >= ? and name in (?,?) 
33 | // [1, 20, go, lang] 34 | db.Query(q) 35 | } 36 | -------------------------------------------------------------------------------- /example/serial_mapper.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | 6 | "github.com/loilo-inc/exql/v2" 7 | "github.com/loilo-inc/exql/v2/model" 8 | ) 9 | 10 | /* 11 | user_groups has many users 12 | users belongs to many groups 13 | */ 14 | func MapSerial(db exql.DB) { 15 | query := ` 16 | SELECT * FROM users 17 | JOIN group_users ON group_users.user_id = users.id 18 | JOIN user_groups ON user_groups.id = group_users.id 19 | WHERE user_groups.name = ?` 20 | rows, err := db.DB().Query(query, "goland") 21 | if err != nil { 22 | log.Fatal(err) 23 | return 24 | } 25 | defer rows.Close() 26 | serialMapper := exql.NewSerialMapper(func(i int) string { 27 | // Each column's separator is `id` 28 | return "id" 29 | }) 30 | var users []*model.Users 31 | for rows.Next() { 32 | var user model.Users 33 | var groupUsers model.GroupUsers 34 | var userGroup model.UserGroups 35 | // Create serial mapper. It will split joined columns by logical tables. 36 | // In this case, joined table and destination mappings are: 37 | // | users | group_users | user_groups | 38 | // + --------- + ------------------------ + ------------- + 39 | // | id | name | id | user_id | group_id | id | name | 40 | // + --------- + ------------------------ + ------------- + 41 | // | &user | &groupUsers | &userGroup | 42 | // + --------- + ------------------------ + ------------- + 43 | if err := serialMapper.Map(rows, &user, &groupUsers, &userGroup); err != nil { 44 | log.Fatalf(err.Error()) 45 | return 46 | } 47 | users = append(users, &user) 48 | } 49 | // enumerate users... 
50 | } 51 | -------------------------------------------------------------------------------- /example/tx.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "time" 7 | 8 | "github.com/loilo-inc/exql/v2" 9 | "github.com/loilo-inc/exql/v2/model" 10 | ) 11 | 12 | func Transaction(db exql.DB) { 13 | timeout, _ := context.WithTimeout(context.Background(), 10*time.Second) 14 | err := db.TransactionWithContext(timeout, &sql.TxOptions{ 15 | Isolation: sql.LevelDefault, 16 | ReadOnly: false, 17 | }, func(tx exql.Tx) error { 18 | user := model.Users{Name: "go"} 19 | _, err := tx.Insert(&user) 20 | return err 21 | }) 22 | if err != nil { 23 | // Transaction has been rolled back 24 | } else { 25 | // Transaction has been committed 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /example/update.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | 6 | "github.com/loilo-inc/exql/v2" 7 | "github.com/loilo-inc/exql/v2/model" 8 | ) 9 | 10 | // Using designated update struct 11 | func UpdateModel(db exql.DB) { 12 | // UPDATE `users` SET `name` = `GoGo` WHERE `id` = ? 13 | // [1] 14 | _, err := db.UpdateModel(&model.UpdateUsers{ 15 | Name: exql.Ptr("GoGo"), 16 | }, exql.Where("id = ?", 1), 17 | ) 18 | if err != nil { 19 | log.Fatal(err) 20 | } 21 | } 22 | 23 | // With table name and key-value pairs 24 | func Update(db exql.DB) { 25 | // UPDATE `users` SET `name` = `GoGo` WHERE `id` = ? 
// The declarations below re-export unexported identifiers of package exql
// under exported names so that tests in the external exql_test package can use them.
var ErrMapRowSerialDestination = errMapRowSerialDestination
var ErrMapDestination = errMapDestination
var ErrMapManyDestination = errMapManyDestination

// Adb is an alias of the unexported db type.
type Adb = db

// NewFinder exposes newFinder: it returns a finder that runs its queries on ex.
func NewFinder(ex Executor) *finder {
	return newFinder(ex)
}
10 | type Finder interface { 11 | Find(q query.Query, destPtrOfStruct any) error 12 | FindContext(ctx context.Context, q query.Query, destPtrOfStruct any) error 13 | FindMany(q query.Query, destSlicePtrOfStruct any) error 14 | FindManyContext(ctx context.Context, q query.Query, destSlicePtrOfStruct any) error 15 | } 16 | 17 | type finder struct { 18 | ex Executor 19 | } 20 | 21 | // Find implements Finder 22 | func (f *finder) Find(q query.Query, destPtrOfStruct any) error { 23 | return f.FindContext(context.Background(), q, destPtrOfStruct) 24 | } 25 | 26 | // FindContext implements Finder 27 | func (f *finder) FindContext(ctx context.Context, q query.Query, destPtrOfStruct any) error { 28 | if stmt, args, err := q.Query(); err != nil { 29 | return err 30 | } else if rows, err := f.ex.QueryContext(ctx, stmt, args...); err != nil { 31 | return err 32 | } else if err := MapRow(rows, destPtrOfStruct); err != nil { 33 | return err 34 | } 35 | return nil 36 | } 37 | 38 | // FindMany implements Finder 39 | func (f *finder) FindMany(q query.Query, destSlicePtrOfStruct any) error { 40 | return f.FindManyContext(context.Background(), q, destSlicePtrOfStruct) 41 | } 42 | 43 | // FindManyContext implements Finder 44 | func (f *finder) FindManyContext(ctx context.Context, q query.Query, destSlicePtrOfStruct any) error { 45 | if stmt, args, err := q.Query(); err != nil { 46 | return err 47 | } else if rows, err := f.ex.QueryContext(ctx, stmt, args...); err != nil { 48 | return err 49 | } else if err := MapRows(rows, destSlicePtrOfStruct); err != nil { 50 | return err 51 | } 52 | return nil 53 | } 54 | 55 | func newFinder(ex Executor) *finder { 56 | return &finder{ex: ex} 57 | } 58 | -------------------------------------------------------------------------------- /finder_test.go: -------------------------------------------------------------------------------- 1 | package exql_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/golang/mock/gomock" 8 | 
"github.com/loilo-inc/exql/v2" 9 | "github.com/loilo-inc/exql/v2/mocks/mock_query" 10 | "github.com/loilo-inc/exql/v2/model" 11 | "github.com/loilo-inc/exql/v2/query" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | func TestFinder(t *testing.T) { 16 | db := testDb() 17 | ctrl := gomock.NewController(t) 18 | user1 := model.Users{Name: "user1"} 19 | user2 := model.Users{Name: "user2"} 20 | db.Insert(&user1) 21 | db.Insert(&user2) 22 | t.Cleanup(func() { 23 | db.Delete( 24 | model.UsersTableName, 25 | query.Cond("id in (?,?)", user1.Id, user2.Id), 26 | ) 27 | }) 28 | f := exql.NewFinder(db.DB()) 29 | t.Run("FindContext", func(t *testing.T) { 30 | t.Run("basic", func(t *testing.T) { 31 | var dest model.Users 32 | err := f.Find(query.Q(`select * from users where id = ?`, user1.Id), &dest) 33 | assert.NoError(t, err) 34 | assert.Equal(t, user1.Name, dest.Name) 35 | }) 36 | t.Run("should error if query is invalid", func(t *testing.T) { 37 | q := mock_query.NewMockQuery(ctrl) 38 | q.EXPECT().Query().Return("", nil, fmt.Errorf("err")) 39 | err := f.Find(q, nil) 40 | assert.EqualError(t, err, "err") 41 | }) 42 | t.Run("should error if query failed", func(t *testing.T) { 43 | err := f.Find(query.Q(`select`), nil) 44 | assert.Error(t, err) 45 | }) 46 | t.Run("should error if mapping failed", func(t *testing.T) { 47 | var dest model.Users 48 | err := f.Find(query.Q(`select * from users where id = -1`), &dest) 49 | assert.ErrorIs(t, err, exql.ErrRecordNotFound) 50 | }) 51 | }) 52 | t.Run("FindManyContext", func(t *testing.T) { 53 | t.Run("basic", func(t *testing.T) { 54 | var dest []*model.Users 55 | err := f.FindMany(query.Q(`select * from users where id in (?,?)`, user1.Id, user2.Id), &dest) 56 | assert.NoError(t, err) 57 | assert.Equal(t, 2, len(dest)) 58 | assert.ElementsMatch(t, []int64{user1.Id, user2.Id}, []int64{dest[0].Id, dest[1].Id}) 59 | }) 60 | t.Run("should error if query is invalid", func(t *testing.T) { 61 | q := mock_query.NewMockQuery(ctrl) 62 | 
q.EXPECT().Query().Return("", nil, fmt.Errorf("err")) 63 | err := f.FindMany(q, nil) 64 | assert.EqualError(t, err, "err") 65 | }) 66 | t.Run("should error if query failed", func(t *testing.T) { 67 | err := f.FindMany(query.Q(`select`), nil) 68 | assert.Error(t, err) 69 | }) 70 | t.Run("should error if mapping failed", func(t *testing.T) { 71 | var dest []*model.Users 72 | err := f.FindMany(query.Q(`select * from users where id = -1`), &dest) 73 | assert.ErrorIs(t, err, exql.ErrRecordNotFound) 74 | }) 75 | }) 76 | } 77 | -------------------------------------------------------------------------------- /generator.go: -------------------------------------------------------------------------------- 1 | package exql 2 | 3 | import ( 4 | "bytes" 5 | "database/sql" 6 | "fmt" 7 | "go/format" 8 | "os" 9 | "path/filepath" 10 | "strings" 11 | "text/template" 12 | 13 | "github.com/iancoleman/strcase" 14 | ) 15 | 16 | type Generator interface { 17 | Generate(opts *GenerateOptions) error 18 | } 19 | type generator struct { 20 | db *sql.DB 21 | } 22 | type GenerateOptions struct { 23 | OutDir string 24 | Package string 25 | Exclude []string 26 | } 27 | 28 | type templateData struct { 29 | Imports string 30 | Model string 31 | ModelLower string 32 | M string 33 | Package string 34 | Fields string 35 | UpdaterFields string 36 | ScannedFields string 37 | TableName string 38 | } 39 | 40 | func NewGenerator(db *sql.DB) Generator { 41 | return &generator{db: db} 42 | } 43 | 44 | func (d *generator) Generate(opts *GenerateOptions) error { 45 | rows, err := d.db.Query(`show tables`) 46 | if err != nil { 47 | return err 48 | } 49 | if opts.OutDir == "" { 50 | opts.OutDir = "model" 51 | } 52 | if opts.Package == "" { 53 | opts.Package = "model" 54 | } 55 | if _, err := os.Stat(opts.OutDir); os.IsNotExist(err) { 56 | err := os.Mkdir(opts.OutDir, 0777) 57 | if err != nil { 58 | return err 59 | } 60 | } else if err != nil { 61 | return err 62 | } 63 | defer rows.Close() 64 | var tables 
[]string 65 | for rows.Next() { 66 | var table string 67 | if err := rows.Scan(&table); err != nil { 68 | return err 69 | } 70 | for _, e := range opts.Exclude { 71 | if e == table { 72 | goto EOL 73 | } 74 | } 75 | tables = append(tables, table) 76 | EOL: 77 | } 78 | if err := rows.Err(); err != nil { 79 | return err 80 | } 81 | for _, table := range tables { 82 | if err := d.generateModelFile(table, opts); err != nil { 83 | return err 84 | } 85 | } 86 | return nil 87 | } 88 | 89 | func (d *generator) generateModelFile(tableName string, opt *GenerateOptions) error { 90 | tmpl := template.Must(template.New("model").Parse(modelTemplate)) 91 | p := NewParser() 92 | table, err := p.ParseTable(d.db, tableName) 93 | if err != nil { 94 | return err 95 | } 96 | var imports []string 97 | 98 | if table.HasJsonField() { 99 | imports = append(imports, `import "encoding/json"`) 100 | } 101 | if table.HasTimeField() { 102 | imports = append(imports, `import "time"`) 103 | } 104 | if table.HasNullField() { 105 | imports = append(imports, `import "github.com/volatiletech/null"`) 106 | } 107 | fields := strings.Builder{} 108 | updateFields := strings.Builder{} 109 | scannedFields := strings.Builder{} 110 | for i, col := range table.Columns { 111 | scannedFields.WriteString(fmt.Sprintf( 112 | "\t&%s.%s,", table.TableName[0:1], col.Field()), 113 | ) 114 | fields.WriteString(fmt.Sprintf("\t%s", col.Field())) 115 | updateFields.WriteString(fmt.Sprintf("\t%s", col.UpdateField())) 116 | if i < len(table.Columns)-1 { 117 | scannedFields.WriteString("\n") 118 | fields.WriteString("\n") 119 | updateFields.WriteString("\n") 120 | } 121 | } 122 | data := &templateData{ 123 | Imports: strings.Join(imports, "\n"), 124 | Model: strcase.ToCamel(table.TableName), 125 | ModelLower: strcase.ToLowerCamel(table.TableName), 126 | M: table.TableName[0:1], 127 | UpdaterFields: updateFields.String(), 128 | Package: opt.Package, 129 | Fields: fields.String(), 130 | TableName: tableName, 131 | 
ScannedFields: scannedFields.String(), 132 | } 133 | outFile := filepath.Join( 134 | opt.OutDir, 135 | fmt.Sprintf("%s.go", strcase.ToSnake(table.TableName)), 136 | ) 137 | var buf = &bytes.Buffer{} 138 | if err := tmpl.Execute(buf, data); err != nil { 139 | return err 140 | } 141 | if fmted, err := format.Source(buf.Bytes()); err != nil { 142 | return err 143 | } else if err := os.WriteFile(outFile, fmted, os.ModePerm); err != nil { 144 | return err 145 | } 146 | return nil 147 | } 148 | 149 | const modelTemplate = `// This file is generated by exql. DO NOT edit. 150 | package {{.Package}} 151 | 152 | {{.Imports}} 153 | 154 | type {{.Model}} struct { 155 | {{.Fields}} 156 | } 157 | 158 | func ({{.M}} *{{.Model}}) TableName() string { 159 | return {{.Model}}TableName 160 | } 161 | 162 | type Update{{.Model}} struct { 163 | {{.UpdaterFields}} 164 | } 165 | 166 | func ({{.M}} *Update{{.Model}}) UpdateTableName() string { 167 | return {{.Model}}TableName 168 | } 169 | 170 | const {{.Model}}TableName = "{{.TableName}}" 171 | ` 172 | -------------------------------------------------------------------------------- /generator_test.go: -------------------------------------------------------------------------------- 1 | package exql_test 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "testing" 7 | 8 | "github.com/DATA-DOG/go-sqlmock" 9 | "github.com/loilo-inc/exql/v2" 10 | 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestGenerator_Generate(t *testing.T) { 15 | for version, db := range map[string]exql.DB{ 16 | "mysql8": testDb(), 17 | } { 18 | t.Run(version, func(t *testing.T) { 19 | g := exql.NewGenerator(db.DB()) 20 | checkFiles := func(dir string, elements []string) { 21 | entries, err := os.ReadDir(dir) 22 | assert.NoError(t, err) 23 | var files []string 24 | for _, e := range entries { 25 | files = append(files, e.Name()) 26 | } 27 | assert.ElementsMatch(t, files, elements) 28 | } 29 | t.Run("basic", func(t *testing.T) { 30 | dir := t.TempDir() 31 | err := 
g.Generate(&exql.GenerateOptions{ 32 | OutDir: dir, 33 | Package: "dist", 34 | }) 35 | assert.NoError(t, err) 36 | checkFiles(dir, []string{"users.go", "user_groups.go", "user_login_histories.go", "group_users.go", "fields.go"}) 37 | }) 38 | t.Run("exclude", func(t *testing.T) { 39 | dir := t.TempDir() 40 | err := g.Generate(&exql.GenerateOptions{ 41 | OutDir: dir, 42 | Package: "dist", 43 | Exclude: []string{"fields"}, 44 | }) 45 | assert.NoError(t, err) 46 | checkFiles(dir, []string{"users.go", "user_groups.go", "user_login_histories.go", "group_users.go"}) 47 | }) 48 | 49 | t.Run("should return error when rows.Error() return error", func(t *testing.T) { 50 | mockDb, mock, err := sqlmock.New() 51 | assert.NoError(t, err) 52 | defer mockDb.Close() 53 | 54 | mock.ExpectQuery(`show tables`).WillReturnRows( 55 | sqlmock.NewRows([]string{"tables"}). 56 | AddRow("users"). 57 | RowError(0, fmt.Errorf("err"))) 58 | 59 | dir := t.TempDir() 60 | assert.EqualError(t, exql.NewGenerator(mockDb). 61 | Generate(&exql.GenerateOptions{ 62 | OutDir: dir, 63 | Package: "dist", 64 | }), "err") 65 | }) 66 | }) 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/loilo-inc/exql/v2 2 | 3 | go 1.22 4 | 5 | require ( 6 | github.com/DATA-DOG/go-sqlmock v1.5.2 7 | github.com/friendsofgo/errors v0.9.2 // indirect 8 | github.com/go-sql-driver/mysql v1.8.1 9 | github.com/gofrs/uuid v4.4.0+incompatible // indirect 10 | github.com/iancoleman/strcase v0.3.0 11 | github.com/stretchr/testify v1.9.0 12 | github.com/volatiletech/null v8.0.0+incompatible 13 | ) 14 | 15 | require github.com/golang/mock v1.6.0 16 | 17 | require ( 18 | filippo.io/edwards25519 v1.1.0 // indirect 19 | github.com/davecgh/go-spew v1.1.1 // indirect 20 | github.com/kr/pretty v0.3.1 // indirect 21 | github.com/pmezard/go-difflib v1.0.0 // indirect 22 | 
github.com/volatiletech/inflect v0.0.1 // indirect 23 | github.com/volatiletech/sqlboiler v3.7.1+incompatible // indirect 24 | golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect 25 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect 26 | gopkg.in/yaml.v3 v3.0.1 // indirect 27 | ) 28 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= 2 | filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= 3 | github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= 4 | github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= 5 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 6 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 7 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 8 | github.com/friendsofgo/errors v0.9.2 h1:X6NYxef4efCBdwI7BgS820zFaN7Cphrmb+Pljdzjtgk= 9 | github.com/friendsofgo/errors v0.9.2/go.mod h1:yCvFW5AkDIL9qn7suHVLiI/gH228n7PC4Pn44IGoTOI= 10 | github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= 11 | github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= 12 | github.com/gofrs/uuid v4.4.0+incompatible h1:3qXRTX8/NbyulANqlc0lchS1gqAVxRgsuW1YrTJupqA= 13 | github.com/gofrs/uuid v4.4.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= 14 | github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= 15 | github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= 16 | github.com/iancoleman/strcase v0.3.0 h1:nTXanmYxhfFAMjZL34Ov6gkzEsSJZ5DbhxWjvSASxEI= 17 | github.com/iancoleman/strcase v0.3.0/go.mod 
h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho= 18 | github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= 19 | github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= 20 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 21 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= 22 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 23 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 24 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 25 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 26 | github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= 27 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 28 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 29 | github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= 30 | github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= 31 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= 32 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 33 | github.com/volatiletech/inflect v0.0.1 h1:2a6FcMQyhmPZcLa+uet3VJ8gLn/9svWhJxJYwvE8KsU= 34 | github.com/volatiletech/inflect v0.0.1/go.mod h1:IBti31tG6phkHitLlr5j7shC5SOo//x0AjDzaJU1PLA= 35 | github.com/volatiletech/null v8.0.0+incompatible h1:7wP8m5d/gZ6kW/9GnrLtMCRre2dlEnaQ9Km5OXlK4zg= 36 | github.com/volatiletech/null v8.0.0+incompatible/go.mod h1:0wD98JzdqB+rLyZ70fN05VDbXbafIb0KU0MdVhCzmOQ= 37 | github.com/volatiletech/sqlboiler v3.7.1+incompatible h1:dm9/NjDskQVwAarmpeZ2UqLn1NKE8M3WHSHBS4jw2x8= 38 | github.com/volatiletech/sqlboiler 
v3.7.1+incompatible/go.mod h1:jLfDkkHWPbS2cWRLkyC20vQWaIQsASEY7gM7zSo11Yw= 39 | github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= 40 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 41 | golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= 42 | golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= 43 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 44 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 45 | golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= 46 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 47 | golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 48 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 49 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 50 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 51 | golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 52 | golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 53 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 54 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 55 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 56 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 
// Executor is an abstraction of both sql.DB/sql.Tx.
//
// It declares the exec, query and prepare methods, each with and without a
// context.Context, so that code written against Executor can be handed either
// a database handle or a transaction.
type Executor interface {
	// Exec runs query with args and returns the sql.Result.
	Exec(query string, args ...any) (sql.Result, error)
	// ExecContext is Exec with a caller-supplied context.
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
	// Query runs query with args and returns the resulting *sql.Rows.
	Query(query string, args ...any) (*sql.Rows, error)
	// QueryContext is Query with a caller-supplied context.
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
	// QueryRow runs query with args and returns a *sql.Row; it has no error
	// result of its own.
	QueryRow(query string, args ...any) *sql.Row
	// QueryRowContext is QueryRow with a caller-supplied context.
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
	// Prepare builds a *sql.Stmt from stmt.
	Prepare(stmt string) (*sql.Stmt, error)
	// PrepareContext is Prepare with a caller-supplied context.
	PrepareContext(ctx context.Context, stmt string) (*sql.Stmt, error)
}
40 | // 41 | // NOTE: DO NOT FORGET to close rows manually, as it WON'T do it automatically. 42 | // 43 | // Example: 44 | // 45 | // var user User 46 | // var favorite UserFavorite 47 | // defer rows.Close() 48 | // err := m.Map(rows, &user, &favorite) 49 | Map(rows *sql.Rows, pointersOfStruct ...any) error 50 | } 51 | 52 | type serialMapper struct { 53 | splitter ColumnSplitter 54 | } 55 | 56 | func NewSerialMapper(s ColumnSplitter) SerialMapper { 57 | return &serialMapper{splitter: s} 58 | } 59 | 60 | var errMapDestination = xerrors.Errorf("destination must be a pointer of struct") 61 | 62 | // MapRow reads data from single row and maps those columns into destination struct. 63 | // pointerOfStruct MUST BE a pointer of struct. 64 | // It closes rows after mapping regardless error occurred. 65 | // 66 | // Example: 67 | // 68 | // var user User 69 | // err := exql.MapRow(rows, &user) 70 | func MapRow(row *sql.Rows, pointerOfStruct interface{}) error { 71 | defer func() { 72 | if row != nil { 73 | row.Close() 74 | } 75 | }() 76 | if pointerOfStruct == nil { 77 | return errMapDestination 78 | } 79 | destValue := reflect.ValueOf(pointerOfStruct) 80 | destType := destValue.Type() 81 | if destType.Kind() != reflect.Ptr { 82 | return errMapDestination 83 | } 84 | destValue = destValue.Elem() 85 | if destValue.Kind() != reflect.Struct { 86 | return errMapDestination 87 | } 88 | if row.Next() { 89 | return mapRow(row, &destValue) 90 | } 91 | if err := row.Err(); err != nil { 92 | return err 93 | } 94 | err := row.Close() 95 | if err != nil { 96 | return err 97 | } 98 | return ErrRecordNotFound 99 | } 100 | 101 | var errMapManyDestination = xerrors.Errorf("destination must be a pointer of slice of struct") 102 | 103 | // MapRows reads all data from rows and maps those columns for each destination struct. 104 | // pointerOfSliceOfStruct MUST BE a pointer of slice of pointer of struct. 105 | // It closes rows after mapping regardless error occurred. 
106 | // 107 | // Example: 108 | // 109 | // var users []*Users 110 | // err := exql.MapRows(rows, &users) 111 | func MapRows(rows *sql.Rows, structPtrOrSlicePtr interface{}) error { 112 | defer func() { 113 | if rows != nil { 114 | rows.Close() 115 | } 116 | }() 117 | if structPtrOrSlicePtr == nil { 118 | return errMapManyDestination 119 | } 120 | destValue := reflect.ValueOf(structPtrOrSlicePtr) 121 | destType := destValue.Type() 122 | if destType.Kind() != reflect.Ptr { 123 | return errMapManyDestination 124 | } 125 | destType = destType.Elem() 126 | if destType.Kind() != reflect.Slice { 127 | return errMapManyDestination 128 | } 129 | // []*Model -> *Model 130 | sliceType := destType.Elem() 131 | if sliceType.Kind() != reflect.Ptr { 132 | return errMapManyDestination 133 | } 134 | // *Model -> Model 135 | sliceType = sliceType.Elem() 136 | cnt := 0 137 | for rows.Next() { 138 | // modelValue := SliceType{} 139 | modelValue := reflect.New(sliceType).Elem() 140 | if err := mapRow(rows, &modelValue); err != nil { 141 | return err 142 | } 143 | // *dest = append(*dest, i) 144 | destValue.Elem().Set(reflect.Append(destValue.Elem(), modelValue.Addr())) 145 | cnt++ 146 | } 147 | if err := rows.Err(); err != nil { 148 | return err 149 | } 150 | err := rows.Close() 151 | if err != nil { 152 | return err 153 | } 154 | if cnt == 0 { 155 | return ErrRecordNotFound 156 | } 157 | return nil 158 | } 159 | 160 | func mapRow( 161 | row *sql.Rows, 162 | dest *reflect.Value, 163 | ) error { 164 | fields, err := aggregateFields(dest) 165 | if err != nil { 166 | return err 167 | } 168 | cols, err := row.ColumnTypes() 169 | if err != nil { 170 | return err 171 | } 172 | destVals := make([]interface{}, len(cols)) 173 | for j, col := range cols { 174 | if fIndex, ok := fields[col.Name()]; ok { 175 | f := dest.Field(fIndex) 176 | destVals[j] = f.Addr().Interface() 177 | } else { 178 | ns := &noopScanner{} 179 | destVals[j] = ns 180 | } 181 | } 182 | return row.Scan(destVals...) 
183 | } 184 | 185 | func aggregateFields(dest *reflect.Value) (map[string]int, error) { 186 | // *Model || **Model 187 | destType := dest.Type() 188 | if dest.Kind() == reflect.Ptr { 189 | destType = destType.Elem() 190 | } 191 | fields := make(map[string]int) 192 | for i := 0; i < destType.NumField(); i++ { 193 | f := destType.Field(i) 194 | tag := f.Tag.Get("exql") 195 | if tag != "" { 196 | if f.Type.Kind() == reflect.Ptr { 197 | return nil, xerrors.Errorf("struct field must not be a pointer: %s %s", f.Type.Name(), f.Type.Kind()) 198 | } 199 | tags, err := ParseTags(tag) 200 | if err != nil { 201 | return nil, err 202 | } 203 | col := tags["column"] 204 | fields[col] = i 205 | } 206 | } 207 | return fields, nil 208 | } 209 | 210 | var errMapRowSerialDestination = xerrors.Errorf("destination must be either *(struct) or *((*struct)(nil))") 211 | 212 | func (s *serialMapper) Map(rows *sql.Rows, dest ...interface{}) error { 213 | var values []*reflect.Value 214 | 215 | if len(dest) == 0 { 216 | return xerrors.Errorf("empty dest list") 217 | } 218 | 219 | for _, model := range dest { 220 | v := reflect.ValueOf(model) 221 | if v.Kind() != reflect.Ptr { 222 | return errMapRowSerialDestination 223 | } 224 | v = v.Elem() 225 | if v.Kind() == reflect.Struct { 226 | values = append(values, &v) 227 | } else if v.Kind() != reflect.Ptr { 228 | return errMapRowSerialDestination 229 | } else if !v.IsNil() || v.Type().Elem().Kind() != reflect.Struct { 230 | return errMapRowSerialDestination 231 | } else { 232 | values = append(values, &v) 233 | 234 | } 235 | } 236 | return mapRowSerial(rows, values, s.splitter) 237 | } 238 | 239 | func mapRowSerial( 240 | row *sql.Rows, 241 | destList []*reflect.Value, 242 | headColProvider ColumnSplitter, 243 | ) error { 244 | // *Model || **Model 245 | var destFields []map[string]int 246 | destTypes := map[int]reflect.Type{} 247 | for destIndex, dest := range destList { 248 | fields, err := aggregateFields(dest) 249 | if err != nil { 250 | 
return err 251 | } 252 | destFields = append(destFields, fields) 253 | destTypes[destIndex] = dest.Type() // Model || *Model 254 | } 255 | cols, err := row.ColumnTypes() 256 | if err != nil { 257 | return err 258 | } 259 | destVals := make([]interface{}, len(cols)) 260 | colIndex := 0 261 | columnCounts := map[int]int{} 262 | for destIndex, dest := range destList { 263 | fields := destFields[destIndex] 264 | headCol := cols[colIndex] 265 | expectedHeadCol := headColProvider(destIndex) 266 | if headCol.Name() != expectedHeadCol { 267 | return xerrors.Errorf( 268 | "head col mismatch: expected=%s, actual=%s", 269 | expectedHeadCol, headCol.Name(), 270 | ) 271 | } 272 | start := colIndex 273 | ns := &noopScanner{} 274 | model := dest 275 | if destTypes[destIndex].Kind() == reflect.Ptr { 276 | m := reflect.New(destTypes[destIndex].Elem()).Elem() // Model 277 | model = &m 278 | } 279 | for ; colIndex < len(cols); colIndex++ { 280 | col := cols[colIndex] 281 | if colIndex > start && destIndex < len(destList)-1 { 282 | // Reach next column's head 283 | if col.Name() == headColProvider(destIndex+1) { 284 | columnCounts[destIndex] = colIndex - start 285 | break 286 | } 287 | } else if destIndex == len(destList)-1 { 288 | columnCounts[destIndex]++ 289 | } 290 | if fIndex, ok := fields[col.Name()]; ok { 291 | f := model.Field(fIndex) 292 | if destTypes[destIndex].Kind() == reflect.Struct { 293 | destVals[colIndex] = f.Addr().Interface() // *(Model.Field) 294 | } else { 295 | destVals[colIndex] = reflect.New(f.Addr().Type()).Interface() // **(Model.Field) 296 | } 297 | } else { 298 | destVals[colIndex] = ns 299 | } 300 | } 301 | } 302 | if err := row.Scan(destVals...); err != nil { 303 | return err 304 | } 305 | 306 | colIndex = 0 307 | for destIndex, dest := range destList { 308 | fields := destFields[destIndex] 309 | if destTypes[destIndex].Kind() == reflect.Struct || reflect.ValueOf(destVals[colIndex]).Elem().IsNil() { 310 | if destIndex < len(destList)-1 { 311 | colIndex 
+= columnCounts[destIndex] 312 | } 313 | continue 314 | } 315 | 316 | model := reflect.New(destTypes[destIndex].Elem()) // *Model 317 | start := colIndex 318 | for ; colIndex < start+columnCounts[destIndex]; colIndex++ { 319 | col := cols[colIndex] 320 | if fIndex, ok := fields[col.Name()]; ok { 321 | f := model.Elem().Field(fIndex) 322 | if t := reflect.ValueOf(destVals[colIndex]).Elem(); t.IsNil() { 323 | f.Set(reflect.Zero(t.Type().Elem())) // To set (*null.Type)(nil) as null.Type{} 324 | } else { 325 | f.Set(reflect.ValueOf(destVals[colIndex]).Elem().Elem()) 326 | } 327 | } 328 | } 329 | dest.Set(model) // dest = *Model 330 | } 331 | 332 | return nil 333 | } 334 | 335 | type noopScanner struct { 336 | } 337 | 338 | func (n *noopScanner) Scan(_ interface{}) error { 339 | // noop 340 | return nil 341 | } 342 | -------------------------------------------------------------------------------- /mocks/mock_exql/interface.go: -------------------------------------------------------------------------------- 1 | // Code generated by MockGen. DO NOT EDIT. 2 | // Source: interface.go 3 | 4 | // Package mock_exql is a generated GoMock package. 5 | package mock_exql 6 | 7 | import ( 8 | context "context" 9 | sql "database/sql" 10 | reflect "reflect" 11 | 12 | gomock "github.com/golang/mock/gomock" 13 | ) 14 | 15 | // MockExecutor is a mock of Executor interface. 16 | type MockExecutor struct { 17 | ctrl *gomock.Controller 18 | recorder *MockExecutorMockRecorder 19 | } 20 | 21 | // MockExecutorMockRecorder is the mock recorder for MockExecutor. 22 | type MockExecutorMockRecorder struct { 23 | mock *MockExecutor 24 | } 25 | 26 | // NewMockExecutor creates a new mock instance. 27 | func NewMockExecutor(ctrl *gomock.Controller) *MockExecutor { 28 | mock := &MockExecutor{ctrl: ctrl} 29 | mock.recorder = &MockExecutorMockRecorder{mock} 30 | return mock 31 | } 32 | 33 | // EXPECT returns an object that allows the caller to indicate expected use. 
34 | func (m *MockExecutor) EXPECT() *MockExecutorMockRecorder { 35 | return m.recorder 36 | } 37 | 38 | // Exec mocks base method. 39 | func (m *MockExecutor) Exec(query string, args ...any) (sql.Result, error) { 40 | m.ctrl.T.Helper() 41 | varargs := []interface{}{query} 42 | for _, a := range args { 43 | varargs = append(varargs, a) 44 | } 45 | ret := m.ctrl.Call(m, "Exec", varargs...) 46 | ret0, _ := ret[0].(sql.Result) 47 | ret1, _ := ret[1].(error) 48 | return ret0, ret1 49 | } 50 | 51 | // Exec indicates an expected call of Exec. 52 | func (mr *MockExecutorMockRecorder) Exec(query interface{}, args ...interface{}) *gomock.Call { 53 | mr.mock.ctrl.T.Helper() 54 | varargs := append([]interface{}{query}, args...) 55 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockExecutor)(nil).Exec), varargs...) 56 | } 57 | 58 | // ExecContext mocks base method. 59 | func (m *MockExecutor) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { 60 | m.ctrl.T.Helper() 61 | varargs := []interface{}{ctx, query} 62 | for _, a := range args { 63 | varargs = append(varargs, a) 64 | } 65 | ret := m.ctrl.Call(m, "ExecContext", varargs...) 66 | ret0, _ := ret[0].(sql.Result) 67 | ret1, _ := ret[1].(error) 68 | return ret0, ret1 69 | } 70 | 71 | // ExecContext indicates an expected call of ExecContext. 72 | func (mr *MockExecutorMockRecorder) ExecContext(ctx, query interface{}, args ...interface{}) *gomock.Call { 73 | mr.mock.ctrl.T.Helper() 74 | varargs := append([]interface{}{ctx, query}, args...) 75 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecContext", reflect.TypeOf((*MockExecutor)(nil).ExecContext), varargs...) 76 | } 77 | 78 | // Prepare mocks base method. 
79 | func (m *MockExecutor) Prepare(stmt string) (*sql.Stmt, error) { 80 | m.ctrl.T.Helper() 81 | ret := m.ctrl.Call(m, "Prepare", stmt) 82 | ret0, _ := ret[0].(*sql.Stmt) 83 | ret1, _ := ret[1].(error) 84 | return ret0, ret1 85 | } 86 | 87 | // Prepare indicates an expected call of Prepare. 88 | func (mr *MockExecutorMockRecorder) Prepare(stmt interface{}) *gomock.Call { 89 | mr.mock.ctrl.T.Helper() 90 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockExecutor)(nil).Prepare), stmt) 91 | } 92 | 93 | // PrepareContext mocks base method. 94 | func (m *MockExecutor) PrepareContext(ctx context.Context, stmt string) (*sql.Stmt, error) { 95 | m.ctrl.T.Helper() 96 | ret := m.ctrl.Call(m, "PrepareContext", ctx, stmt) 97 | ret0, _ := ret[0].(*sql.Stmt) 98 | ret1, _ := ret[1].(error) 99 | return ret0, ret1 100 | } 101 | 102 | // PrepareContext indicates an expected call of PrepareContext. 103 | func (mr *MockExecutorMockRecorder) PrepareContext(ctx, stmt interface{}) *gomock.Call { 104 | mr.mock.ctrl.T.Helper() 105 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PrepareContext", reflect.TypeOf((*MockExecutor)(nil).PrepareContext), ctx, stmt) 106 | } 107 | 108 | // Query mocks base method. 109 | func (m *MockExecutor) Query(query string, args ...any) (*sql.Rows, error) { 110 | m.ctrl.T.Helper() 111 | varargs := []interface{}{query} 112 | for _, a := range args { 113 | varargs = append(varargs, a) 114 | } 115 | ret := m.ctrl.Call(m, "Query", varargs...) 116 | ret0, _ := ret[0].(*sql.Rows) 117 | ret1, _ := ret[1].(error) 118 | return ret0, ret1 119 | } 120 | 121 | // Query indicates an expected call of Query. 122 | func (mr *MockExecutorMockRecorder) Query(query interface{}, args ...interface{}) *gomock.Call { 123 | mr.mock.ctrl.T.Helper() 124 | varargs := append([]interface{}{query}, args...) 125 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockExecutor)(nil).Query), varargs...) 
126 | } 127 | 128 | // QueryContext mocks base method. 129 | func (m *MockExecutor) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { 130 | m.ctrl.T.Helper() 131 | varargs := []interface{}{ctx, query} 132 | for _, a := range args { 133 | varargs = append(varargs, a) 134 | } 135 | ret := m.ctrl.Call(m, "QueryContext", varargs...) 136 | ret0, _ := ret[0].(*sql.Rows) 137 | ret1, _ := ret[1].(error) 138 | return ret0, ret1 139 | } 140 | 141 | // QueryContext indicates an expected call of QueryContext. 142 | func (mr *MockExecutorMockRecorder) QueryContext(ctx, query interface{}, args ...interface{}) *gomock.Call { 143 | mr.mock.ctrl.T.Helper() 144 | varargs := append([]interface{}{ctx, query}, args...) 145 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockExecutor)(nil).QueryContext), varargs...) 146 | } 147 | 148 | // QueryRow mocks base method. 149 | func (m *MockExecutor) QueryRow(query string, args ...any) *sql.Row { 150 | m.ctrl.T.Helper() 151 | varargs := []interface{}{query} 152 | for _, a := range args { 153 | varargs = append(varargs, a) 154 | } 155 | ret := m.ctrl.Call(m, "QueryRow", varargs...) 156 | ret0, _ := ret[0].(*sql.Row) 157 | return ret0 158 | } 159 | 160 | // QueryRow indicates an expected call of QueryRow. 161 | func (mr *MockExecutorMockRecorder) QueryRow(query interface{}, args ...interface{}) *gomock.Call { 162 | mr.mock.ctrl.T.Helper() 163 | varargs := append([]interface{}{query}, args...) 164 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockExecutor)(nil).QueryRow), varargs...) 165 | } 166 | 167 | // QueryRowContext mocks base method. 
168 | func (m *MockExecutor) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { 169 | m.ctrl.T.Helper() 170 | varargs := []interface{}{ctx, query} 171 | for _, a := range args { 172 | varargs = append(varargs, a) 173 | } 174 | ret := m.ctrl.Call(m, "QueryRowContext", varargs...) 175 | ret0, _ := ret[0].(*sql.Row) 176 | return ret0 177 | } 178 | 179 | // QueryRowContext indicates an expected call of QueryRowContext. 180 | func (mr *MockExecutorMockRecorder) QueryRowContext(ctx, query interface{}, args ...interface{}) *gomock.Call { 181 | mr.mock.ctrl.T.Helper() 182 | varargs := append([]interface{}{ctx, query}, args...) 183 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRowContext", reflect.TypeOf((*MockExecutor)(nil).QueryRowContext), varargs...) 184 | } 185 | 186 | // MockModel is a mock of Model interface. 187 | type MockModel struct { 188 | ctrl *gomock.Controller 189 | recorder *MockModelMockRecorder 190 | } 191 | 192 | // MockModelMockRecorder is the mock recorder for MockModel. 193 | type MockModelMockRecorder struct { 194 | mock *MockModel 195 | } 196 | 197 | // NewMockModel creates a new mock instance. 198 | func NewMockModel(ctrl *gomock.Controller) *MockModel { 199 | mock := &MockModel{ctrl: ctrl} 200 | mock.recorder = &MockModelMockRecorder{mock} 201 | return mock 202 | } 203 | 204 | // EXPECT returns an object that allows the caller to indicate expected use. 205 | func (m *MockModel) EXPECT() *MockModelMockRecorder { 206 | return m.recorder 207 | } 208 | 209 | // TableName mocks base method. 210 | func (m *MockModel) TableName() string { 211 | m.ctrl.T.Helper() 212 | ret := m.ctrl.Call(m, "TableName") 213 | ret0, _ := ret[0].(string) 214 | return ret0 215 | } 216 | 217 | // TableName indicates an expected call of TableName. 
218 | func (mr *MockModelMockRecorder) TableName() *gomock.Call { 219 | mr.mock.ctrl.T.Helper() 220 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TableName", reflect.TypeOf((*MockModel)(nil).TableName)) 221 | } 222 | 223 | // MockModelUpdate is a mock of ModelUpdate interface. 224 | type MockModelUpdate struct { 225 | ctrl *gomock.Controller 226 | recorder *MockModelUpdateMockRecorder 227 | } 228 | 229 | // MockModelUpdateMockRecorder is the mock recorder for MockModelUpdate. 230 | type MockModelUpdateMockRecorder struct { 231 | mock *MockModelUpdate 232 | } 233 | 234 | // NewMockModelUpdate creates a new mock instance. 235 | func NewMockModelUpdate(ctrl *gomock.Controller) *MockModelUpdate { 236 | mock := &MockModelUpdate{ctrl: ctrl} 237 | mock.recorder = &MockModelUpdateMockRecorder{mock} 238 | return mock 239 | } 240 | 241 | // EXPECT returns an object that allows the caller to indicate expected use. 242 | func (m *MockModelUpdate) EXPECT() *MockModelUpdateMockRecorder { 243 | return m.recorder 244 | } 245 | 246 | // UpdateTableName mocks base method. 247 | func (m *MockModelUpdate) UpdateTableName() string { 248 | m.ctrl.T.Helper() 249 | ret := m.ctrl.Call(m, "UpdateTableName") 250 | ret0, _ := ret[0].(string) 251 | return ret0 252 | } 253 | 254 | // UpdateTableName indicates an expected call of UpdateTableName. 255 | func (mr *MockModelUpdateMockRecorder) UpdateTableName() *gomock.Call { 256 | mr.mock.ctrl.T.Helper() 257 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTableName", reflect.TypeOf((*MockModelUpdate)(nil).UpdateTableName)) 258 | } 259 | -------------------------------------------------------------------------------- /mocks/mock_exql/saver.go: -------------------------------------------------------------------------------- 1 | // Code generated by MockGen. DO NOT EDIT. 2 | // Source: saver.go 3 | 4 | // Package mock_exql is a generated GoMock package. 
5 | package mock_exql 6 | 7 | import ( 8 | context "context" 9 | sql "database/sql" 10 | reflect "reflect" 11 | 12 | gomock "github.com/golang/mock/gomock" 13 | exql "github.com/loilo-inc/exql/v2" 14 | query "github.com/loilo-inc/exql/v2/query" 15 | ) 16 | 17 | // MockSaver is a mock of Saver interface. 18 | type MockSaver struct { 19 | ctrl *gomock.Controller 20 | recorder *MockSaverMockRecorder 21 | } 22 | 23 | // MockSaverMockRecorder is the mock recorder for MockSaver. 24 | type MockSaverMockRecorder struct { 25 | mock *MockSaver 26 | } 27 | 28 | // NewMockSaver creates a new mock instance. 29 | func NewMockSaver(ctrl *gomock.Controller) *MockSaver { 30 | mock := &MockSaver{ctrl: ctrl} 31 | mock.recorder = &MockSaverMockRecorder{mock} 32 | return mock 33 | } 34 | 35 | // EXPECT returns an object that allows the caller to indicate expected use. 36 | func (m *MockSaver) EXPECT() *MockSaverMockRecorder { 37 | return m.recorder 38 | } 39 | 40 | // Delete mocks base method. 41 | func (m *MockSaver) Delete(table string, where query.Condition) (sql.Result, error) { 42 | m.ctrl.T.Helper() 43 | ret := m.ctrl.Call(m, "Delete", table, where) 44 | ret0, _ := ret[0].(sql.Result) 45 | ret1, _ := ret[1].(error) 46 | return ret0, ret1 47 | } 48 | 49 | // Delete indicates an expected call of Delete. 50 | func (mr *MockSaverMockRecorder) Delete(table, where interface{}) *gomock.Call { 51 | mr.mock.ctrl.T.Helper() 52 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockSaver)(nil).Delete), table, where) 53 | } 54 | 55 | // DeleteContext mocks base method. 56 | func (m *MockSaver) DeleteContext(ctx context.Context, table string, where query.Condition) (sql.Result, error) { 57 | m.ctrl.T.Helper() 58 | ret := m.ctrl.Call(m, "DeleteContext", ctx, table, where) 59 | ret0, _ := ret[0].(sql.Result) 60 | ret1, _ := ret[1].(error) 61 | return ret0, ret1 62 | } 63 | 64 | // DeleteContext indicates an expected call of DeleteContext. 
65 | func (mr *MockSaverMockRecorder) DeleteContext(ctx, table, where interface{}) *gomock.Call { 66 | mr.mock.ctrl.T.Helper() 67 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteContext", reflect.TypeOf((*MockSaver)(nil).DeleteContext), ctx, table, where) 68 | } 69 | 70 | // Exec mocks base method. 71 | func (m *MockSaver) Exec(query query.Query) (sql.Result, error) { 72 | m.ctrl.T.Helper() 73 | ret := m.ctrl.Call(m, "Exec", query) 74 | ret0, _ := ret[0].(sql.Result) 75 | ret1, _ := ret[1].(error) 76 | return ret0, ret1 77 | } 78 | 79 | // Exec indicates an expected call of Exec. 80 | func (mr *MockSaverMockRecorder) Exec(query interface{}) *gomock.Call { 81 | mr.mock.ctrl.T.Helper() 82 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockSaver)(nil).Exec), query) 83 | } 84 | 85 | // ExecContext mocks base method. 86 | func (m *MockSaver) ExecContext(ctx context.Context, query query.Query) (sql.Result, error) { 87 | m.ctrl.T.Helper() 88 | ret := m.ctrl.Call(m, "ExecContext", ctx, query) 89 | ret0, _ := ret[0].(sql.Result) 90 | ret1, _ := ret[1].(error) 91 | return ret0, ret1 92 | } 93 | 94 | // ExecContext indicates an expected call of ExecContext. 95 | func (mr *MockSaverMockRecorder) ExecContext(ctx, query interface{}) *gomock.Call { 96 | mr.mock.ctrl.T.Helper() 97 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecContext", reflect.TypeOf((*MockSaver)(nil).ExecContext), ctx, query) 98 | } 99 | 100 | // Insert mocks base method. 101 | func (m *MockSaver) Insert(structPtr exql.Model) (sql.Result, error) { 102 | m.ctrl.T.Helper() 103 | ret := m.ctrl.Call(m, "Insert", structPtr) 104 | ret0, _ := ret[0].(sql.Result) 105 | ret1, _ := ret[1].(error) 106 | return ret0, ret1 107 | } 108 | 109 | // Insert indicates an expected call of Insert. 
110 | func (mr *MockSaverMockRecorder) Insert(structPtr interface{}) *gomock.Call { 111 | mr.mock.ctrl.T.Helper() 112 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockSaver)(nil).Insert), structPtr) 113 | } 114 | 115 | // InsertContext mocks base method. 116 | func (m *MockSaver) InsertContext(ctx context.Context, structPtr exql.Model) (sql.Result, error) { 117 | m.ctrl.T.Helper() 118 | ret := m.ctrl.Call(m, "InsertContext", ctx, structPtr) 119 | ret0, _ := ret[0].(sql.Result) 120 | ret1, _ := ret[1].(error) 121 | return ret0, ret1 122 | } 123 | 124 | // InsertContext indicates an expected call of InsertContext. 125 | func (mr *MockSaverMockRecorder) InsertContext(ctx, structPtr interface{}) *gomock.Call { 126 | mr.mock.ctrl.T.Helper() 127 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertContext", reflect.TypeOf((*MockSaver)(nil).InsertContext), ctx, structPtr) 128 | } 129 | 130 | // Query mocks base method. 131 | func (m *MockSaver) Query(query query.Query) (*sql.Rows, error) { 132 | m.ctrl.T.Helper() 133 | ret := m.ctrl.Call(m, "Query", query) 134 | ret0, _ := ret[0].(*sql.Rows) 135 | ret1, _ := ret[1].(error) 136 | return ret0, ret1 137 | } 138 | 139 | // Query indicates an expected call of Query. 140 | func (mr *MockSaverMockRecorder) Query(query interface{}) *gomock.Call { 141 | mr.mock.ctrl.T.Helper() 142 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockSaver)(nil).Query), query) 143 | } 144 | 145 | // QueryContext mocks base method. 146 | func (m *MockSaver) QueryContext(ctx context.Context, query query.Query) (*sql.Rows, error) { 147 | m.ctrl.T.Helper() 148 | ret := m.ctrl.Call(m, "QueryContext", ctx, query) 149 | ret0, _ := ret[0].(*sql.Rows) 150 | ret1, _ := ret[1].(error) 151 | return ret0, ret1 152 | } 153 | 154 | // QueryContext indicates an expected call of QueryContext. 
155 | func (mr *MockSaverMockRecorder) QueryContext(ctx, query interface{}) *gomock.Call { 156 | mr.mock.ctrl.T.Helper() 157 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockSaver)(nil).QueryContext), ctx, query) 158 | } 159 | 160 | // QueryRow mocks base method. 161 | func (m *MockSaver) QueryRow(query query.Query) (*sql.Row, error) { 162 | m.ctrl.T.Helper() 163 | ret := m.ctrl.Call(m, "QueryRow", query) 164 | ret0, _ := ret[0].(*sql.Row) 165 | ret1, _ := ret[1].(error) 166 | return ret0, ret1 167 | } 168 | 169 | // QueryRow indicates an expected call of QueryRow. 170 | func (mr *MockSaverMockRecorder) QueryRow(query interface{}) *gomock.Call { 171 | mr.mock.ctrl.T.Helper() 172 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockSaver)(nil).QueryRow), query) 173 | } 174 | 175 | // QueryRowContext mocks base method. 176 | func (m *MockSaver) QueryRowContext(ctx context.Context, query query.Query) (*sql.Row, error) { 177 | m.ctrl.T.Helper() 178 | ret := m.ctrl.Call(m, "QueryRowContext", ctx, query) 179 | ret0, _ := ret[0].(*sql.Row) 180 | ret1, _ := ret[1].(error) 181 | return ret0, ret1 182 | } 183 | 184 | // QueryRowContext indicates an expected call of QueryRowContext. 185 | func (mr *MockSaverMockRecorder) QueryRowContext(ctx, query interface{}) *gomock.Call { 186 | mr.mock.ctrl.T.Helper() 187 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRowContext", reflect.TypeOf((*MockSaver)(nil).QueryRowContext), ctx, query) 188 | } 189 | 190 | // Update mocks base method. 191 | func (m *MockSaver) Update(table string, set map[string]any, where query.Condition) (sql.Result, error) { 192 | m.ctrl.T.Helper() 193 | ret := m.ctrl.Call(m, "Update", table, set, where) 194 | ret0, _ := ret[0].(sql.Result) 195 | ret1, _ := ret[1].(error) 196 | return ret0, ret1 197 | } 198 | 199 | // Update indicates an expected call of Update. 
200 | func (mr *MockSaverMockRecorder) Update(table, set, where interface{}) *gomock.Call { 201 | mr.mock.ctrl.T.Helper() 202 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockSaver)(nil).Update), table, set, where) 203 | } 204 | 205 | // UpdateContext mocks base method. 206 | func (m *MockSaver) UpdateContext(ctx context.Context, table string, set map[string]any, where query.Condition) (sql.Result, error) { 207 | m.ctrl.T.Helper() 208 | ret := m.ctrl.Call(m, "UpdateContext", ctx, table, set, where) 209 | ret0, _ := ret[0].(sql.Result) 210 | ret1, _ := ret[1].(error) 211 | return ret0, ret1 212 | } 213 | 214 | // UpdateContext indicates an expected call of UpdateContext. 215 | func (mr *MockSaverMockRecorder) UpdateContext(ctx, table, set, where interface{}) *gomock.Call { 216 | mr.mock.ctrl.T.Helper() 217 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateContext", reflect.TypeOf((*MockSaver)(nil).UpdateContext), ctx, table, set, where) 218 | } 219 | 220 | // UpdateModel mocks base method. 221 | func (m *MockSaver) UpdateModel(updaterStructPtr exql.ModelUpdate, where query.Condition) (sql.Result, error) { 222 | m.ctrl.T.Helper() 223 | ret := m.ctrl.Call(m, "UpdateModel", updaterStructPtr, where) 224 | ret0, _ := ret[0].(sql.Result) 225 | ret1, _ := ret[1].(error) 226 | return ret0, ret1 227 | } 228 | 229 | // UpdateModel indicates an expected call of UpdateModel. 230 | func (mr *MockSaverMockRecorder) UpdateModel(updaterStructPtr, where interface{}) *gomock.Call { 231 | mr.mock.ctrl.T.Helper() 232 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateModel", reflect.TypeOf((*MockSaver)(nil).UpdateModel), updaterStructPtr, where) 233 | } 234 | 235 | // UpdateModelContext mocks base method. 
236 | func (m *MockSaver) UpdateModelContext(ctx context.Context, updaterStructPtr exql.ModelUpdate, where query.Condition) (sql.Result, error) { 237 | m.ctrl.T.Helper() 238 | ret := m.ctrl.Call(m, "UpdateModelContext", ctx, updaterStructPtr, where) 239 | ret0, _ := ret[0].(sql.Result) 240 | ret1, _ := ret[1].(error) 241 | return ret0, ret1 242 | } 243 | 244 | // UpdateModelContext indicates an expected call of UpdateModelContext. 245 | func (mr *MockSaverMockRecorder) UpdateModelContext(ctx, updaterStructPtr, where interface{}) *gomock.Call { 246 | mr.mock.ctrl.T.Helper() 247 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateModelContext", reflect.TypeOf((*MockSaver)(nil).UpdateModelContext), ctx, updaterStructPtr, where) 248 | } 249 | -------------------------------------------------------------------------------- /mocks/mock_query/query.go: -------------------------------------------------------------------------------- 1 | // Code generated by MockGen. DO NOT EDIT. 2 | // Source: query.go 3 | 4 | // Package mock_query is a generated GoMock package. 5 | package mock_query 6 | 7 | import ( 8 | reflect "reflect" 9 | 10 | gomock "github.com/golang/mock/gomock" 11 | query "github.com/loilo-inc/exql/v2/query" 12 | ) 13 | 14 | // MockQuery is a mock of Query interface. 15 | type MockQuery struct { 16 | ctrl *gomock.Controller 17 | recorder *MockQueryMockRecorder 18 | } 19 | 20 | // MockQueryMockRecorder is the mock recorder for MockQuery. 21 | type MockQueryMockRecorder struct { 22 | mock *MockQuery 23 | } 24 | 25 | // NewMockQuery creates a new mock instance. 26 | func NewMockQuery(ctrl *gomock.Controller) *MockQuery { 27 | mock := &MockQuery{ctrl: ctrl} 28 | mock.recorder = &MockQueryMockRecorder{mock} 29 | return mock 30 | } 31 | 32 | // EXPECT returns an object that allows the caller to indicate expected use. 33 | func (m *MockQuery) EXPECT() *MockQueryMockRecorder { 34 | return m.recorder 35 | } 36 | 37 | // Query mocks base method. 
38 | func (m *MockQuery) Query() (string, []any, error) { 39 | m.ctrl.T.Helper() 40 | ret := m.ctrl.Call(m, "Query") 41 | ret0, _ := ret[0].(string) 42 | ret1, _ := ret[1].([]any) 43 | ret2, _ := ret[2].(error) 44 | return ret0, ret1, ret2 45 | } 46 | 47 | // Query indicates an expected call of Query. 48 | func (mr *MockQueryMockRecorder) Query() *gomock.Call { 49 | mr.mock.ctrl.T.Helper() 50 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockQuery)(nil).Query)) 51 | } 52 | 53 | // MockCondition is a mock of Condition interface. 54 | type MockCondition struct { 55 | ctrl *gomock.Controller 56 | recorder *MockConditionMockRecorder 57 | } 58 | 59 | // MockConditionMockRecorder is the mock recorder for MockCondition. 60 | type MockConditionMockRecorder struct { 61 | mock *MockCondition 62 | } 63 | 64 | // NewMockCondition creates a new mock instance. 65 | func NewMockCondition(ctrl *gomock.Controller) *MockCondition { 66 | mock := &MockCondition{ctrl: ctrl} 67 | mock.recorder = &MockConditionMockRecorder{mock} 68 | return mock 69 | } 70 | 71 | // EXPECT returns an object that allows the caller to indicate expected use. 72 | func (m *MockCondition) EXPECT() *MockConditionMockRecorder { 73 | return m.recorder 74 | } 75 | 76 | // And mocks base method. 77 | func (m *MockCondition) And(str string, args ...any) { 78 | m.ctrl.T.Helper() 79 | varargs := []interface{}{str} 80 | for _, a := range args { 81 | varargs = append(varargs, a) 82 | } 83 | m.ctrl.Call(m, "And", varargs...) 84 | } 85 | 86 | // And indicates an expected call of And. 87 | func (mr *MockConditionMockRecorder) And(str interface{}, args ...interface{}) *gomock.Call { 88 | mr.mock.ctrl.T.Helper() 89 | varargs := append([]interface{}{str}, args...) 90 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "And", reflect.TypeOf((*MockCondition)(nil).And), varargs...) 91 | } 92 | 93 | // AndCond mocks base method. 
94 | func (m *MockCondition) AndCond(other query.Condition) { 95 | m.ctrl.T.Helper() 96 | m.ctrl.Call(m, "AndCond", other) 97 | } 98 | 99 | // AndCond indicates an expected call of AndCond. 100 | func (mr *MockConditionMockRecorder) AndCond(other interface{}) *gomock.Call { 101 | mr.mock.ctrl.T.Helper() 102 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AndCond", reflect.TypeOf((*MockCondition)(nil).AndCond), other) 103 | } 104 | 105 | // Or mocks base method. 106 | func (m *MockCondition) Or(str string, args ...any) { 107 | m.ctrl.T.Helper() 108 | varargs := []interface{}{str} 109 | for _, a := range args { 110 | varargs = append(varargs, a) 111 | } 112 | m.ctrl.Call(m, "Or", varargs...) 113 | } 114 | 115 | // Or indicates an expected call of Or. 116 | func (mr *MockConditionMockRecorder) Or(str interface{}, args ...interface{}) *gomock.Call { 117 | mr.mock.ctrl.T.Helper() 118 | varargs := append([]interface{}{str}, args...) 119 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Or", reflect.TypeOf((*MockCondition)(nil).Or), varargs...) 120 | } 121 | 122 | // OrCond mocks base method. 123 | func (m *MockCondition) OrCond(other query.Condition) { 124 | m.ctrl.T.Helper() 125 | m.ctrl.Call(m, "OrCond", other) 126 | } 127 | 128 | // OrCond indicates an expected call of OrCond. 129 | func (mr *MockConditionMockRecorder) OrCond(other interface{}) *gomock.Call { 130 | mr.mock.ctrl.T.Helper() 131 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OrCond", reflect.TypeOf((*MockCondition)(nil).OrCond), other) 132 | } 133 | 134 | // Query mocks base method. 135 | func (m *MockCondition) Query() (string, []any, error) { 136 | m.ctrl.T.Helper() 137 | ret := m.ctrl.Call(m, "Query") 138 | ret0, _ := ret[0].(string) 139 | ret1, _ := ret[1].([]any) 140 | ret2, _ := ret[2].(error) 141 | return ret0, ret1, ret2 142 | } 143 | 144 | // Query indicates an expected call of Query. 
145 | func (mr *MockConditionMockRecorder) Query() *gomock.Call { 146 | mr.mock.ctrl.T.Helper() 147 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockCondition)(nil).Query)) 148 | } 149 | -------------------------------------------------------------------------------- /model/fields.go: -------------------------------------------------------------------------------- 1 | // This file is generated by exql. DO NOT edit. 2 | package model 3 | 4 | import "encoding/json" 5 | import "time" 6 | import "github.com/volatiletech/null" 7 | 8 | type Fields struct { 9 | Id int64 `exql:"column:id;type:int;primary;not null;auto_increment" json:"id"` 10 | TinyintField int64 `exql:"column:tinyint_field;type:tinyint;not null" json:"tinyint_field"` 11 | TinyintUnsignedField int64 `exql:"column:tinyint_unsigned_field;type:tinyint unsigned;not null" json:"tinyint_unsigned_field"` 12 | TinyintNullableField null.Int64 `exql:"column:tinyint_nullable_field;type:tinyint" json:"tinyint_nullable_field"` 13 | TinyintUnsignedNullableField null.Int64 `exql:"column:tinyint_unsigned_nullable_field;type:tinyint unsigned" json:"tinyint_unsigned_nullable_field"` 14 | SmallintField int64 `exql:"column:smallint_field;type:smallint;not null" json:"smallint_field"` 15 | SmallintUnsignedField int64 `exql:"column:smallint_unsigned_field;type:smallint unsigned;not null" json:"smallint_unsigned_field"` 16 | SmallintNullableField null.Int64 `exql:"column:smallint_nullable_field;type:smallint" json:"smallint_nullable_field"` 17 | SmallintUnsignedNullableField null.Int64 `exql:"column:smallint_unsigned_nullable_field;type:smallint unsigned" json:"smallint_unsigned_nullable_field"` 18 | MediumintField int64 `exql:"column:mediumint_field;type:mediumint;not null" json:"mediumint_field"` 19 | MediumintUnsignedField int64 `exql:"column:mediumint_unsigned_field;type:mediumint unsigned;not null" json:"mediumint_unsigned_field"` 20 | MediumintNullableField null.Int64 
`exql:"column:mediumint_nullable_field;type:mediumint" json:"mediumint_nullable_field"` 21 | MediumintUnsignedNullableField null.Int64 `exql:"column:mediumint_unsigned_nullable_field;type:mediumint unsigned" json:"mediumint_unsigned_nullable_field"` 22 | IntField int64 `exql:"column:int_field;type:int;not null" json:"int_field"` 23 | IntUnsignedField int64 `exql:"column:int_unsigned_field;type:int unsigned;not null" json:"int_unsigned_field"` 24 | IntNullableField null.Int64 `exql:"column:int_nullable_field;type:int" json:"int_nullable_field"` 25 | IntUnsignedNullableField null.Int64 `exql:"column:int_unsigned_nullable_field;type:int unsigned" json:"int_unsigned_nullable_field"` 26 | BigintField int64 `exql:"column:bigint_field;type:bigint;not null" json:"bigint_field"` 27 | BigintUnsignedField uint64 `exql:"column:bigint_unsigned_field;type:bigint unsigned;not null" json:"bigint_unsigned_field"` 28 | BigintNullableField null.Int64 `exql:"column:bigint_nullable_field;type:bigint" json:"bigint_nullable_field"` 29 | BigintUnsignedNullableField null.Uint64 `exql:"column:bigint_unsigned_nullable_field;type:bigint unsigned" json:"bigint_unsigned_nullable_field"` 30 | FloatField float32 `exql:"column:float_field;type:float;not null" json:"float_field"` 31 | FloatNullField null.Float32 `exql:"column:float_null_field;type:float" json:"float_null_field"` 32 | DoubleField float64 `exql:"column:double_field;type:double;not null" json:"double_field"` 33 | DoubleNullField null.Float64 `exql:"column:double_null_field;type:double" json:"double_null_field"` 34 | TinytextField string `exql:"column:tinytext_field;type:tinytext;not null" json:"tinytext_field"` 35 | TinytextNullField null.String `exql:"column:tinytext_null_field;type:tinytext" json:"tinytext_null_field"` 36 | MediumtextField string `exql:"column:mediumtext_field;type:mediumtext;not null" json:"mediumtext_field"` 37 | MediumtextNullField null.String `exql:"column:mediumtext_null_field;type:mediumtext" 
json:"mediumtext_null_field"` 38 | TextField string `exql:"column:text_field;type:text;not null" json:"text_field"` 39 | TextNullField null.String `exql:"column:text_null_field;type:text" json:"text_null_field"` 40 | LongtextField string `exql:"column:longtext_field;type:longtext;not null" json:"longtext_field"` 41 | LongtextNullField null.String `exql:"column:longtext_null_field;type:longtext" json:"longtext_null_field"` 42 | VarcharFiledField string `exql:"column:varchar_filed_field;type:varchar(255);not null" json:"varchar_filed_field"` 43 | VarcharNullField null.String `exql:"column:varchar_null_field;type:varchar(255)" json:"varchar_null_field"` 44 | CharFiledField string `exql:"column:char_filed_field;type:char(10);not null" json:"char_filed_field"` 45 | CharFiledNullField null.String `exql:"column:char_filed_null_field;type:char(10)" json:"char_filed_null_field"` 46 | DateField time.Time `exql:"column:date_field;type:date;not null" json:"date_field"` 47 | DateNullField null.Time `exql:"column:date_null_field;type:date" json:"date_null_field"` 48 | DatetimeField time.Time `exql:"column:datetime_field;type:datetime;not null" json:"datetime_field"` 49 | DatetimeNullField null.Time `exql:"column:datetime_null_field;type:datetime" json:"datetime_null_field"` 50 | TimeField string `exql:"column:time_field;type:time;not null" json:"time_field"` 51 | TimeNullField null.String `exql:"column:time_null_field;type:time" json:"time_null_field"` 52 | TimestampField time.Time `exql:"column:timestamp_field;type:timestamp;not null" json:"timestamp_field"` 53 | TimestampNullField null.Time `exql:"column:timestamp_null_field;type:timestamp" json:"timestamp_null_field"` 54 | TinyblobField []byte `exql:"column:tinyblob_field;type:tinyblob;not null" json:"tinyblob_field"` 55 | TinyblobNullField null.Bytes `exql:"column:tinyblob_null_field;type:tinyblob" json:"tinyblob_null_field"` 56 | MediumblobField []byte `exql:"column:mediumblob_field;type:mediumblob;not null" 
json:"mediumblob_field"` 57 | MediumblobNullField null.Bytes `exql:"column:mediumblob_null_field;type:mediumblob" json:"mediumblob_null_field"` 58 | BlobField []byte `exql:"column:blob_field;type:blob;not null" json:"blob_field"` 59 | BlobNullField null.Bytes `exql:"column:blob_null_field;type:blob" json:"blob_null_field"` 60 | LongblobField []byte `exql:"column:longblob_field;type:longblob;not null" json:"longblob_field"` 61 | LongblobNullField null.Bytes `exql:"column:longblob_null_field;type:longblob" json:"longblob_null_field"` 62 | JsonField json.RawMessage `exql:"column:json_field;type:json;not null" json:"json_field"` 63 | JsonNullField null.JSON `exql:"column:json_null_field;type:json" json:"json_null_field"` 64 | } 65 | 66 | func (f *Fields) TableName() string { 67 | return FieldsTableName 68 | } 69 | 70 | type UpdateFields struct { 71 | Id *int64 `exql:"column:id;type:int;primary;not null;auto_increment" json:"id"` 72 | TinyintField *int64 `exql:"column:tinyint_field;type:tinyint;not null" json:"tinyint_field"` 73 | TinyintUnsignedField *int64 `exql:"column:tinyint_unsigned_field;type:tinyint unsigned;not null" json:"tinyint_unsigned_field"` 74 | TinyintNullableField *null.Int64 `exql:"column:tinyint_nullable_field;type:tinyint" json:"tinyint_nullable_field"` 75 | TinyintUnsignedNullableField *null.Int64 `exql:"column:tinyint_unsigned_nullable_field;type:tinyint unsigned" json:"tinyint_unsigned_nullable_field"` 76 | SmallintField *int64 `exql:"column:smallint_field;type:smallint;not null" json:"smallint_field"` 77 | SmallintUnsignedField *int64 `exql:"column:smallint_unsigned_field;type:smallint unsigned;not null" json:"smallint_unsigned_field"` 78 | SmallintNullableField *null.Int64 `exql:"column:smallint_nullable_field;type:smallint" json:"smallint_nullable_field"` 79 | SmallintUnsignedNullableField *null.Int64 `exql:"column:smallint_unsigned_nullable_field;type:smallint unsigned" json:"smallint_unsigned_nullable_field"` 80 | MediumintField *int64 
`exql:"column:mediumint_field;type:mediumint;not null" json:"mediumint_field"` 81 | MediumintUnsignedField *int64 `exql:"column:mediumint_unsigned_field;type:mediumint unsigned;not null" json:"mediumint_unsigned_field"` 82 | MediumintNullableField *null.Int64 `exql:"column:mediumint_nullable_field;type:mediumint" json:"mediumint_nullable_field"` 83 | MediumintUnsignedNullableField *null.Int64 `exql:"column:mediumint_unsigned_nullable_field;type:mediumint unsigned" json:"mediumint_unsigned_nullable_field"` 84 | IntField *int64 `exql:"column:int_field;type:int;not null" json:"int_field"` 85 | IntUnsignedField *int64 `exql:"column:int_unsigned_field;type:int unsigned;not null" json:"int_unsigned_field"` 86 | IntNullableField *null.Int64 `exql:"column:int_nullable_field;type:int" json:"int_nullable_field"` 87 | IntUnsignedNullableField *null.Int64 `exql:"column:int_unsigned_nullable_field;type:int unsigned" json:"int_unsigned_nullable_field"` 88 | BigintField *int64 `exql:"column:bigint_field;type:bigint;not null" json:"bigint_field"` 89 | BigintUnsignedField *uint64 `exql:"column:bigint_unsigned_field;type:bigint unsigned;not null" json:"bigint_unsigned_field"` 90 | BigintNullableField *null.Int64 `exql:"column:bigint_nullable_field;type:bigint" json:"bigint_nullable_field"` 91 | BigintUnsignedNullableField *null.Uint64 `exql:"column:bigint_unsigned_nullable_field;type:bigint unsigned" json:"bigint_unsigned_nullable_field"` 92 | FloatField *float32 `exql:"column:float_field;type:float;not null" json:"float_field"` 93 | FloatNullField *null.Float32 `exql:"column:float_null_field;type:float" json:"float_null_field"` 94 | DoubleField *float64 `exql:"column:double_field;type:double;not null" json:"double_field"` 95 | DoubleNullField *null.Float64 `exql:"column:double_null_field;type:double" json:"double_null_field"` 96 | TinytextField *string `exql:"column:tinytext_field;type:tinytext;not null" json:"tinytext_field"` 97 | TinytextNullField *null.String 
`exql:"column:tinytext_null_field;type:tinytext" json:"tinytext_null_field"` 98 | MediumtextField *string `exql:"column:mediumtext_field;type:mediumtext;not null" json:"mediumtext_field"` 99 | MediumtextNullField *null.String `exql:"column:mediumtext_null_field;type:mediumtext" json:"mediumtext_null_field"` 100 | TextField *string `exql:"column:text_field;type:text;not null" json:"text_field"` 101 | TextNullField *null.String `exql:"column:text_null_field;type:text" json:"text_null_field"` 102 | LongtextField *string `exql:"column:longtext_field;type:longtext;not null" json:"longtext_field"` 103 | LongtextNullField *null.String `exql:"column:longtext_null_field;type:longtext" json:"longtext_null_field"` 104 | VarcharFiledField *string `exql:"column:varchar_filed_field;type:varchar(255);not null" json:"varchar_filed_field"` 105 | VarcharNullField *null.String `exql:"column:varchar_null_field;type:varchar(255)" json:"varchar_null_field"` 106 | CharFiledField *string `exql:"column:char_filed_field;type:char(10);not null" json:"char_filed_field"` 107 | CharFiledNullField *null.String `exql:"column:char_filed_null_field;type:char(10)" json:"char_filed_null_field"` 108 | DateField *time.Time `exql:"column:date_field;type:date;not null" json:"date_field"` 109 | DateNullField *null.Time `exql:"column:date_null_field;type:date" json:"date_null_field"` 110 | DatetimeField *time.Time `exql:"column:datetime_field;type:datetime;not null" json:"datetime_field"` 111 | DatetimeNullField *null.Time `exql:"column:datetime_null_field;type:datetime" json:"datetime_null_field"` 112 | TimeField *string `exql:"column:time_field;type:time;not null" json:"time_field"` 113 | TimeNullField *null.String `exql:"column:time_null_field;type:time" json:"time_null_field"` 114 | TimestampField *time.Time `exql:"column:timestamp_field;type:timestamp;not null" json:"timestamp_field"` 115 | TimestampNullField *null.Time `exql:"column:timestamp_null_field;type:timestamp" json:"timestamp_null_field"` 
116 | TinyblobField *[]byte `exql:"column:tinyblob_field;type:tinyblob;not null" json:"tinyblob_field"` 117 | TinyblobNullField *null.Bytes `exql:"column:tinyblob_null_field;type:tinyblob" json:"tinyblob_null_field"` 118 | MediumblobField *[]byte `exql:"column:mediumblob_field;type:mediumblob;not null" json:"mediumblob_field"` 119 | MediumblobNullField *null.Bytes `exql:"column:mediumblob_null_field;type:mediumblob" json:"mediumblob_null_field"` 120 | BlobField *[]byte `exql:"column:blob_field;type:blob;not null" json:"blob_field"` 121 | BlobNullField *null.Bytes `exql:"column:blob_null_field;type:blob" json:"blob_null_field"` 122 | LongblobField *[]byte `exql:"column:longblob_field;type:longblob;not null" json:"longblob_field"` 123 | LongblobNullField *null.Bytes `exql:"column:longblob_null_field;type:longblob" json:"longblob_null_field"` 124 | JsonField *json.RawMessage `exql:"column:json_field;type:json;not null" json:"json_field"` 125 | JsonNullField *null.JSON `exql:"column:json_null_field;type:json" json:"json_null_field"` 126 | } 127 | 128 | func (f *UpdateFields) UpdateTableName() string { 129 | return FieldsTableName 130 | } 131 | 132 | const FieldsTableName = "fields" 133 | -------------------------------------------------------------------------------- /model/group_users.go: -------------------------------------------------------------------------------- 1 | // This file is generated by exql. DO NOT edit. 
package model

type GroupUsers struct {
	Id      int64 `exql:"column:id;type:int;primary;not null;auto_increment" json:"id"`
	UserId  int64 `exql:"column:user_id;type:int;not null" json:"user_id"`
	GroupId int64 `exql:"column:group_id;type:int;not null" json:"group_id"`
}

func (g *GroupUsers) TableName() string {
	return GroupUsersTableName
}

type UpdateGroupUsers struct {
	Id      *int64 `exql:"column:id;type:int;primary;not null;auto_increment" json:"id"`
	UserId  *int64 `exql:"column:user_id;type:int;not null" json:"user_id"`
	GroupId *int64 `exql:"column:group_id;type:int;not null" json:"group_id"`
}

func (g *UpdateGroupUsers) UpdateTableName() string {
	return GroupUsersTableName
}

const GroupUsersTableName = "group_users"

--------------------------------------------------------------------------------
/model/testmodel/testmodel.go:
--------------------------------------------------------------------------------
// Package testmodel holds hand-written model fixtures. Each type below has a
// deliberately unusual tag layout or table name.
// NOTE(review): presumably consumed by the mapper/saver/tag tests — confirm.
package testmodel

// MultiplePrimaryKey tags two columns (pk1, pk2) as primary and carries one
// ordinary column.
type MultiplePrimaryKey struct {
	Pk1   string `exql:"column:pk1;primary"`
	Pk2   string `exql:"column:pk2;primary"`
	Other int    `exql:"column:other"`
}

// TableName returns the placeholder table name "dummy".
func (*MultiplePrimaryKey) TableName() string {
	return "dummy"
}

// NoTag has no fields at all, and therefore no exql tags.
type NoTag struct {
}

// TableName returns the placeholder table name "dummy".
func (*NoTag) TableName() string {
	return "dummy"
}

// BadTableName has a well-formed primary key column but reports an empty
// table name.
type BadTableName struct {
	Id int `exql:"column:id;primary;auto_increment"`
}

// TableName returns the empty string. Note the value receiver.
func (BadTableName) TableName() string {
	return ""
}

// NoPrimaryKey tags its only column as auto_increment but not as primary.
type NoPrimaryKey struct {
	Id int `exql:"column:id;auto_increment"`
}

// TableName returns the empty string. Note the value receiver.
func (NoPrimaryKey) TableName() string {
	return ""
}

// NoColumnTag tags its field as primary and auto_increment but omits the
// `column:` entry.
type NoColumnTag struct {
	Id int `exql:"primary;auto_increment"`
}

// TableName returns the empty string. Note the value receiver.
func (NoColumnTag) TableName() string {
	return ""
}

// BadTag carries the tag "a;a:1", which repeats the key "a".
// NOTE(review): presumably rejected by the tag parser — confirm against tag.go.
type BadTag struct {
	Id int `exql:"a;a:1"`
}

// TableName returns the empty string. Note the value receiver.
func (BadTag) TableName() string {
	return ""
}

// NoAutoIncrementKey has a primary key column that is not auto_increment.
type NoAutoIncrementKey struct {
	Id   int    `exql:"column:id;primary"`
	Name string `exql:"column:name"`
}

// TableName returns "sampleNoAutoIncrementKey".
func (s *NoAutoIncrementKey) TableName() string {
	return "sampleNoAutoIncrementKey"
}

// PrimaryUint64 has an auto_increment primary key stored as uint64.
type PrimaryUint64 struct {
	Id   uint64 `exql:"column:id;primary;auto_increment"`
	Name string `exql:"column:name"`
}

// TableName returns "samplePrimaryUint64".
func (s *PrimaryUint64) TableName() string {
	return "samplePrimaryUint64"
}

--------------------------------------------------------------------------------
/model/user_groups.go:
--------------------------------------------------------------------------------
// This file is generated by exql. DO NOT edit.
package model

type UserGroups struct {
	Id   int64  `exql:"column:id;type:int;primary;not null;auto_increment" json:"id"`
	Name string `exql:"column:name;type:varchar(255);not null" json:"name"`
}

func (u *UserGroups) TableName() string {
	return UserGroupsTableName
}

type UpdateUserGroups struct {
	Id   *int64  `exql:"column:id;type:int;primary;not null;auto_increment" json:"id"`
	Name *string `exql:"column:name;type:varchar(255);not null" json:"name"`
}

func (u *UpdateUserGroups) UpdateTableName() string {
	return UserGroupsTableName
}

const UserGroupsTableName = "user_groups"

--------------------------------------------------------------------------------
/model/user_login_histories.go:
--------------------------------------------------------------------------------
// This file is generated by exql. DO NOT edit.
2 | package model 3 | 4 | import "time" 5 | 6 | type UserLoginHistories struct { 7 | Id int64 `exql:"column:id;type:int;primary;not null;auto_increment" json:"id"` 8 | UserId int64 `exql:"column:user_id;type:int;not null" json:"user_id"` 9 | CreatedAt time.Time `exql:"column:created_at;type:datetime;primary;not null" json:"created_at"` 10 | } 11 | 12 | func (u *UserLoginHistories) TableName() string { 13 | return UserLoginHistoriesTableName 14 | } 15 | 16 | type UpdateUserLoginHistories struct { 17 | Id *int64 `exql:"column:id;type:int;primary;not null;auto_increment" json:"id"` 18 | UserId *int64 `exql:"column:user_id;type:int;not null" json:"user_id"` 19 | CreatedAt *time.Time `exql:"column:created_at;type:datetime;primary;not null" json:"created_at"` 20 | } 21 | 22 | func (u *UpdateUserLoginHistories) UpdateTableName() string { 23 | return UserLoginHistoriesTableName 24 | } 25 | 26 | const UserLoginHistoriesTableName = "user_login_histories" 27 | -------------------------------------------------------------------------------- /model/users.go: -------------------------------------------------------------------------------- 1 | // This file is generated by exql. DO NOT edit. 
2 | package model 3 | 4 | type Users struct { 5 | Id int64 `exql:"column:id;type:int;primary;not null;auto_increment" json:"id"` 6 | Name string `exql:"column:name;type:varchar(255);not null" json:"name"` 7 | Age int64 `exql:"column:age;type:int;not null" json:"age"` 8 | } 9 | 10 | func (u *Users) TableName() string { 11 | return UsersTableName 12 | } 13 | 14 | type UpdateUsers struct { 15 | Id *int64 `exql:"column:id;type:int;primary;not null;auto_increment" json:"id"` 16 | Name *string `exql:"column:name;type:varchar(255);not null" json:"name"` 17 | Age *int64 `exql:"column:age;type:int;not null" json:"age"` 18 | } 19 | 20 | func (u *UpdateUsers) UpdateTableName() string { 21 | return UsersTableName 22 | } 23 | 24 | const UsersTableName = "users" 25 | -------------------------------------------------------------------------------- /parser.go: -------------------------------------------------------------------------------- 1 | package exql 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "regexp" 7 | "strings" 8 | 9 | "github.com/iancoleman/strcase" 10 | "golang.org/x/xerrors" 11 | ) 12 | 13 | type parser struct{} 14 | 15 | type Parser interface { 16 | ParseTable(db *sql.DB, table string) (*Table, error) 17 | } 18 | 19 | func NewParser() Parser { 20 | return &parser{} 21 | } 22 | 23 | type Table struct { 24 | TableName string `json:"table_name"` 25 | Columns []*Column `json:"columns"` 26 | } 27 | 28 | func (t *Table) Fields() []string { 29 | var ret []string 30 | for _, c := range t.Columns { 31 | ret = append(ret, c.Field()) 32 | } 33 | return ret 34 | } 35 | func (t *Table) HasNullField() bool { 36 | for _, c := range t.Columns { 37 | if c.Nullable { 38 | return true 39 | } 40 | } 41 | return false 42 | } 43 | 44 | func (t *Table) HasTimeField() bool { 45 | for _, c := range t.Columns { 46 | if c.GoFieldType == "time.Time" { 47 | return true 48 | } 49 | } 50 | return false 51 | } 52 | 53 | func (t *Table) HasJsonField() bool { 54 | for _, c := range t.Columns { 
55 | if c.GoFieldType == "json.RawMessage" { 56 | return true 57 | } 58 | } 59 | return false 60 | } 61 | 62 | type Column struct { 63 | FieldName string `json:"field_name"` 64 | FieldType string `json:"field_type"` 65 | FieldIndex int `json:"field_index"` 66 | GoFieldType string `json:"go_field_type"` 67 | Nullable bool `json:"nullable"` 68 | DefaultValue sql.NullString `json:"default_value"` 69 | Key sql.NullString `json:"key"` 70 | Extra sql.NullString `json:"extra"` 71 | } 72 | 73 | func (c *Column) IsPrimary() bool { 74 | return c.Key.String == "PRI" 75 | } 76 | 77 | func (c *Column) ParseExtra() []string { 78 | comps := strings.Split(c.Extra.String, " ") 79 | empty := regexp.MustCompile(`^\s*$`) 80 | var ret []string 81 | for i := 0; i < len(comps); i++ { 82 | v := strings.Trim(comps[i], " ") 83 | if empty.MatchString(v) { 84 | continue 85 | } 86 | ret = append(ret, v) 87 | } 88 | return ret 89 | } 90 | 91 | func (c *Column) Field() string { 92 | return c.field(c.GoFieldType) 93 | } 94 | 95 | func (c *Column) UpdateField() string { 96 | return c.field("*" + c.GoFieldType) 97 | } 98 | 99 | func (c *Column) field(goFiledType string) string { 100 | var tag []string 101 | tag = append(tag, fmt.Sprintf("column:%s", c.FieldName)) 102 | tag = append(tag, fmt.Sprintf("type:%s", c.FieldType)) 103 | if c.IsPrimary() { 104 | tag = append(tag, "primary") 105 | } 106 | if !c.Nullable { 107 | tag = append(tag, "not null") 108 | } 109 | tag = append(tag, c.ParseExtra()...) 
110 | return fmt.Sprintf("%s %s `exql:\"%s\" json:\"%s\"`", 111 | strcase.ToCamel(c.FieldName), 112 | goFiledType, 113 | strings.Join(tag, ";"), 114 | strcase.ToSnake(c.FieldName), 115 | ) 116 | } 117 | 118 | func (p *parser) ParseTable(db *sql.DB, table string) (*Table, error) { 119 | rows, err := db.Query(fmt.Sprintf("show columns from %s", table)) 120 | if err != nil { 121 | return nil, err 122 | } 123 | defer rows.Close() 124 | var cols []*Column 125 | i := 0 126 | for rows.Next() { 127 | field := "" 128 | _type := "" 129 | _null := sql.NullString{} 130 | key := sql.NullString{} 131 | _default := sql.NullString{} 132 | extra := sql.NullString{} 133 | if err := rows.Scan(&field, &_type, &_null, &key, &_default, &extra); err != nil { 134 | return nil, err 135 | } 136 | parsedType, err := ParseType(_type, _null.String == "YES") 137 | if err != nil { 138 | return nil, err 139 | } 140 | cols = append(cols, &Column{ 141 | FieldName: field, 142 | FieldType: _type, 143 | FieldIndex: i, 144 | GoFieldType: parsedType, 145 | Nullable: _null.String == "YES", 146 | DefaultValue: _default, 147 | Key: key, 148 | Extra: extra, 149 | }) 150 | i++ 151 | } 152 | if err := rows.Err(); err != nil { 153 | return nil, err 154 | } 155 | return &Table{ 156 | TableName: table, 157 | Columns: cols, 158 | }, nil 159 | } 160 | 161 | var ( 162 | intPat = regexp.MustCompile(`^(tiny|small|medium|big)?int(\(\d+?\))?( unsigned)?( zerofill)?$`) 163 | floatPat = regexp.MustCompile(`^float$`) 164 | doublePat = regexp.MustCompile(`^double$`) 165 | charPat = regexp.MustCompile(`^(var)?char\(\d+?\)$`) 166 | textPat = regexp.MustCompile(`^(tiny|medium|long)?text$`) 167 | blobPat = regexp.MustCompile(`^(tiny|medium|long)?blob$`) 168 | datePat = regexp.MustCompile(`^(date|datetime|datetime\(\d\)|timestamp|timestamp\(\d\))$`) 169 | timePat = regexp.MustCompile(`^(time|time\(\d\))$`) 170 | jsonPat = regexp.MustCompile(`^json$`) 171 | ) 172 | 173 | const ( 174 | nullUint64Type = "null.Uint64" 175 | 
nullInt64Type = "null.Int64" 176 | uint64Type = "uint64" 177 | int64Type = "int64" 178 | nullFloat64Type = "null.Float64" 179 | float64Type = "float64" 180 | nullFloat32Type = "null.Float32" 181 | float32Type = "float32" 182 | nullTimeType = "null.Time" 183 | timeType = "time.Time" 184 | nullStrType = "null.String" 185 | strType = "string" 186 | nullBytesType = "null.Bytes" 187 | bytesType = "[]byte" 188 | nullJsonType = "null.JSON" 189 | jsonType = "json.RawMessage" 190 | ) 191 | 192 | func ParseType(t string, nullable bool) (string, error) { 193 | if intPat.MatchString(t) { 194 | m := intPat.FindStringSubmatch(t) 195 | unsigned := strings.Contains(t, "unsigned") 196 | is64 := false 197 | if len(m) > 2 { 198 | switch m[1] { 199 | case "big": 200 | is64 = true 201 | default: 202 | } 203 | } 204 | if nullable { 205 | if unsigned && is64 { 206 | return nullUint64Type, nil 207 | } else { 208 | return nullInt64Type, nil 209 | } 210 | } else { 211 | if unsigned && is64 { 212 | return uint64Type, nil 213 | } else { 214 | return int64Type, nil 215 | } 216 | } 217 | } else if datePat.MatchString(t) { 218 | if nullable { 219 | return nullTimeType, nil 220 | } 221 | return timeType, nil 222 | } else if timePat.MatchString(t) { 223 | if nullable { 224 | return nullStrType, nil 225 | } 226 | return strType, nil 227 | } else if textPat.MatchString(t) || charPat.MatchString(t) { 228 | if nullable { 229 | return nullStrType, nil 230 | } 231 | return strType, nil 232 | } else if floatPat.MatchString(t) { 233 | if nullable { 234 | return nullFloat32Type, nil 235 | } 236 | return float32Type, nil 237 | } else if doublePat.MatchString(t) { 238 | if nullable { 239 | return nullFloat64Type, nil 240 | } 241 | return float64Type, nil 242 | } else if blobPat.MatchString(t) { 243 | if nullable { 244 | return nullBytesType, nil 245 | } 246 | return bytesType, nil 247 | } else if jsonPat.MatchString(t) { 248 | if nullable { 249 | return nullJsonType, nil 250 | } 251 | return jsonType, nil 
252 | } 253 | return "", xerrors.Errorf("unknown type: %s", t) 254 | } 255 | -------------------------------------------------------------------------------- /parser_test.go: -------------------------------------------------------------------------------- 1 | package exql_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/DATA-DOG/go-sqlmock" 8 | "github.com/loilo-inc/exql/v2" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestParser_ParseTable(t *testing.T) { 14 | t.Run("should return error when rows.Error() return error", func(t *testing.T) { 15 | mockDb, mock, err := sqlmock.New() 16 | assert.NoError(t, err) 17 | defer mockDb.Close() 18 | 19 | p := exql.NewParser() 20 | 21 | mock.ExpectQuery(`show columns from users`).WillReturnRows( 22 | sqlmock.NewRows([]string{"field", "type"}). 23 | AddRow("id", "int(11)"). 24 | RowError(0, fmt.Errorf("err"))) 25 | 26 | table, err := p.ParseTable(mockDb, "users") 27 | assert.Nil(t, table) 28 | assert.EqualError(t, err, "err") 29 | }) 30 | } 31 | 32 | func TestParser_ParseType(t *testing.T) { 33 | assertType := func(s string, nullable bool, tp interface{}) { 34 | ret, err := exql.ParseType(s, nullable) 35 | assert.NoError(t, err) 36 | assert.Equal(t, ret, tp) 37 | } 38 | t.Run("int", func(t *testing.T) { 39 | list := [][]interface{}{ 40 | {"int", "int64", "int64", "null.Int64", "null.Int64"}, 41 | {"tinyint", "int64", "int64", "null.Int64", "null.Int64"}, 42 | {"smallint", "int64", "int64", "null.Int64", "null.Int64"}, 43 | {"mediumint", "int64", "int64", "null.Int64", "null.Int64"}, 44 | {"bigint", "int64", "uint64", "null.Int64", "null.Uint64"}, 45 | } 46 | for _, v := range list { 47 | title := v[0].(string) 48 | t.Run(title, func(t *testing.T) { 49 | assertType(fmt.Sprintf("%s(1)", title), false, v[1]) 50 | assertType(fmt.Sprintf("%s(1) unsigned", title), false, v[2]) 51 | assertType(fmt.Sprintf("%s(1)", title), true, v[3]) 52 | assertType(fmt.Sprintf("%s(1) unsigned", title), true, 
v[4]) 53 | }) 54 | } 55 | }) 56 | t.Run("float", func(t *testing.T) { 57 | assertType("float", false, "float32") 58 | assertType("float", true, "null.Float32") 59 | }) 60 | t.Run("double", func(t *testing.T) { 61 | assertType("double", false, "float64") 62 | assertType("double", true, "null.Float64") 63 | }) 64 | t.Run("date", func(t *testing.T) { 65 | list := [][]interface{}{ 66 | {"date", "time.Time", "null.Time"}, 67 | {"datetime", "time.Time", "null.Time"}, 68 | {"datetime(6)", "time.Time", "null.Time"}, 69 | {"timestamp", "time.Time", "null.Time"}, 70 | {"timestamp(6)", "time.Time", "null.Time"}, 71 | {"time", "string", "null.String"}, 72 | {"time(6)", "string", "null.String"}, 73 | } 74 | for _, v := range list { 75 | t := v[0].(string) 76 | assertType(t, false, v[1].(string)) 77 | assertType(t, true, v[2].(string)) 78 | } 79 | }) 80 | t.Run("string", func(t *testing.T) { 81 | list := [][]interface{}{ 82 | {"text", "string", "null.String"}, 83 | {"tinytext", "string", "null.String"}, 84 | {"mediumtext", "string", "null.String"}, 85 | {"longtext", "string", "null.String"}, 86 | {"char(10)", "string", "null.String"}, 87 | {"varchar(255)", "string", "null.String"}, 88 | } 89 | for _, v := range list { 90 | t := v[0].(string) 91 | assertType(t, false, v[1].(string)) 92 | assertType(t, true, v[2].(string)) 93 | } 94 | }) 95 | t.Run("blob", func(t *testing.T) { 96 | list := [][]interface{}{ 97 | {"blob", "[]byte", "null.Bytes"}, 98 | {"tinyblob", "[]byte", "null.Bytes"}, 99 | {"mediumblob", "[]byte", "null.Bytes"}, 100 | {"longblob", "[]byte", "null.Bytes"}, 101 | } 102 | for _, v := range list { 103 | t := v[0].(string) 104 | assertType(t, false, v[1].(string)) 105 | assertType(t, true, v[2].(string)) 106 | } 107 | }) 108 | t.Run("json", func(t *testing.T) { 109 | assertType("json", false, "json.RawMessage") 110 | assertType("json", true, "null.JSON") 111 | }) 112 | } 113 | -------------------------------------------------------------------------------- /query.go: 
-------------------------------------------------------------------------------- 1 | package exql 2 | 3 | import ( 4 | "errors" 5 | "reflect" 6 | 7 | q "github.com/loilo-inc/exql/v2/query" 8 | "golang.org/x/xerrors" 9 | ) 10 | 11 | func Where(str string, args ...any) q.Condition { 12 | return q.Cond(str, args...) 13 | } 14 | 15 | type ModelMetadata struct { 16 | TableName string 17 | AutoIncrementField *reflect.Value 18 | PrimaryKeyColumns []string 19 | PrimaryKeyValues []any 20 | Values q.KeyIterator[any] 21 | } 22 | 23 | func QueryForInsert(modelPtr Model) (q.Query, *reflect.Value, error) { 24 | m, err := AggregateModelMetadata(modelPtr) 25 | if err != nil { 26 | return nil, nil, err 27 | } 28 | b := q.NewBuilder() 29 | cols := q.Cols(m.Values.Keys()...) 30 | vals := q.Vals(m.Values.Values()) 31 | b.Sprintf("INSERT INTO `%s`", modelPtr.TableName()) 32 | b.Query("(:?) VALUES (:?)", cols, vals) 33 | return b.Build(), m.AutoIncrementField, nil 34 | } 35 | 36 | func QueryForBulkInsert[T Model](modelPtrs ...T) (q.Query, error) { 37 | if len(modelPtrs) == 0 { 38 | return nil, errors.New("empty list") 39 | } 40 | var head *ModelMetadata 41 | b := q.NewBuilder() 42 | vals := q.NewBuilder() 43 | for _, v := range modelPtrs { 44 | if data, err := AggregateModelMetadata(v); err != nil { 45 | return nil, err 46 | } else { 47 | if head == nil { 48 | head = data 49 | } 50 | vals.Query("(:?)", q.Vals(data.Values.Values())) 51 | } 52 | } 53 | b.Sprintf("INSERT INTO `%s`", head.TableName) 54 | b.Query("(:?) 
VALUES :?", q.Cols(head.Values.Keys()...), vals.Join(",")) 55 | return b.Build(), nil 56 | } 57 | 58 | func AggregateModelMetadata(modelPtr Model) (*ModelMetadata, error) { 59 | if modelPtr == nil { 60 | return nil, xerrors.Errorf("pointer is nil") 61 | } 62 | objValue := reflect.ValueOf(modelPtr) 63 | objType := objValue.Type() 64 | if objType.Kind() != reflect.Ptr || objType.Elem().Kind() != reflect.Struct { 65 | return nil, xerrors.Errorf("object must be pointer of struct") 66 | } 67 | data := map[string]any{} 68 | // *User -> User 69 | objType = objType.Elem() 70 | exqlTagCount := 0 71 | var primaryKeyColumns []string 72 | var primaryKeyValues []any 73 | var autoIncrementField *reflect.Value 74 | for i := 0; i < objType.NumField(); i++ { 75 | f := objType.Field(i) 76 | if t, ok := f.Tag.Lookup("exql"); ok { 77 | tags, err := ParseTags(t) 78 | if err != nil { 79 | return nil, err 80 | } 81 | colName, ok := tags["column"] 82 | if !ok || colName == "" { 83 | return nil, xerrors.Errorf("column tag is not set") 84 | } 85 | exqlTagCount++ 86 | if _, primary := tags["primary"]; primary { 87 | primaryKeyField := objValue.Elem().Field(i) 88 | primaryKeyColumns = append(primaryKeyColumns, colName) 89 | primaryKeyValues = append(primaryKeyValues, primaryKeyField.Interface()) 90 | } 91 | if _, autoIncrement := tags["auto_increment"]; autoIncrement { 92 | field := objValue.Elem().Field(i) 93 | autoIncrementField = &field 94 | // Not include auto_increment field in insert query 95 | continue 96 | } 97 | data[colName] = objValue.Elem().Field(i).Interface() 98 | } 99 | } 100 | if exqlTagCount == 0 { 101 | return nil, xerrors.Errorf("obj doesn't have exql tags in any fields") 102 | } 103 | 104 | if len(primaryKeyColumns) == 0 { 105 | return nil, xerrors.Errorf("table has no primary key") 106 | } 107 | 108 | tableName := modelPtr.TableName() 109 | if tableName == "" { 110 | return nil, xerrors.Errorf("empty table name") 111 | } 112 | return &ModelMetadata{ 113 | TableName: 
tableName, 114 | AutoIncrementField: autoIncrementField, 115 | PrimaryKeyColumns: primaryKeyColumns, 116 | PrimaryKeyValues: primaryKeyValues, 117 | Values: q.NewKeyIterator(data), 118 | }, nil 119 | } 120 | 121 | func QueryForUpdateModel( 122 | updateStructPtr ModelUpdate, 123 | where q.Condition, 124 | ) (q.Query, error) { 125 | if updateStructPtr == nil { 126 | return nil, xerrors.Errorf("pointer is nil") 127 | } 128 | objValue := reflect.ValueOf(updateStructPtr) 129 | objType := objValue.Type() 130 | if objType.Kind() != reflect.Ptr || objType.Elem().Kind() != reflect.Struct { 131 | return nil, xerrors.Errorf("must be pointer of struct") 132 | } 133 | objType = objType.Elem() 134 | values := make(map[string]any) 135 | if objType.NumField() == 0 { 136 | return nil, xerrors.Errorf("struct has no field") 137 | } 138 | 139 | for i := 0; i < objType.NumField(); i++ { 140 | f := objType.Field(i) 141 | tag, ok := f.Tag.Lookup("exql") 142 | if !ok { 143 | continue 144 | } 145 | var colName string 146 | if tags, err := ParseTags(tag); err != nil { 147 | return nil, err 148 | } else if col, ok := tags["column"]; !ok { 149 | return nil, xerrors.Errorf("tag must include column") 150 | } else { 151 | colName = col 152 | } 153 | if f.Type.Kind() != reflect.Ptr { 154 | return nil, xerrors.Errorf("field must be pointer") 155 | } 156 | fieldValue := objValue.Elem().Field(i) 157 | if !fieldValue.IsNil() { 158 | values[colName] = fieldValue.Elem().Interface() 159 | } 160 | } 161 | if len(values) == 0 { 162 | return nil, xerrors.Errorf("no value for update") 163 | } 164 | 165 | tableName := updateStructPtr.UpdateTableName() 166 | if tableName == "" { 167 | return nil, xerrors.Errorf("empty table name") 168 | } 169 | b := q.NewBuilder() 170 | b.Sprintf("UPDATE `%s`", tableName) 171 | b.Query("SET :? 
WHERE :?", q.Set(values), where) 172 | return b.Build(), nil 173 | } 174 | -------------------------------------------------------------------------------- /query/builder.go: -------------------------------------------------------------------------------- 1 | package query 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | // Builder is a dynamic SQL query builder. 8 | type Builder struct { 9 | qs []Query 10 | } 11 | 12 | // Sprintf is short-hand for fmt.Sprintf. 13 | // 14 | // Example: 15 | // 16 | // b.Sprintf("%s", "go") 17 | // 18 | // is the same as: 19 | // 20 | // b.Query(fmt.Sprintf("%s", "go")) 21 | func (b *Builder) Sprintf(str string, args ...any) *Builder { 22 | return b.Query(fmt.Sprintf(str, args...)) 23 | } 24 | 25 | // Query appends the given query component and arguments into the buffer. 26 | // 27 | // Example: 28 | // 29 | // b.Query(":?", query.V(1,2)) 30 | // 31 | // is the same as: 32 | // 33 | // b.Add(query.Q(":?", query.V(1,2))) 34 | func (b *Builder) Query(str string, args ...any) *Builder { 35 | b.qs = append(b.qs, Q(str, args...)) 36 | return b 37 | } 38 | 39 | // Add appends given Queries components. 40 | func (b *Builder) Add(q ...Query) *Builder { 41 | b.qs = append(b.qs, q...) 42 | return b 43 | } 44 | 45 | // Build constructs final SQL statement, joining by single space(" "). 46 | func (b *Builder) Build() Query { 47 | return b.Join(" ") 48 | } 49 | 50 | // Clone makes a shallow copy of the builder. 51 | func (b *Builder) Clone() *Builder { 52 | return NewBuilder(b.qs...) 53 | } 54 | 55 | // Join joins accumulative query components by given separator. 56 | func (b *Builder) Join(sep string) Query { 57 | c := &chain{joiner: sep} 58 | c.append(b.qs...) 
59 | return c 60 | } 61 | 62 | func NewBuilder(base ...Query) *Builder { 63 | return &Builder{qs: base} 64 | } 65 | -------------------------------------------------------------------------------- /query/builder_test.go: -------------------------------------------------------------------------------- 1 | package query_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/loilo-inc/exql/v2/query" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func assertQuery(t *testing.T, q query.Query, str string, args ...any) { 11 | stmt, vals, err := q.Query() 12 | assert.NoError(t, err) 13 | assert.Equal(t, str, stmt) 14 | assert.ElementsMatch(t, args, vals) 15 | } 16 | func assertQueryErr(t *testing.T, q query.Query, msg string) { 17 | stmt, vals, err := q.Query() 18 | assert.EqualError(t, err, msg) 19 | assert.Equal(t, "", stmt) 20 | assert.Nil(t, vals) 21 | } 22 | func TestBuilder(t *testing.T) { 23 | t.Run("Sprintf", func(t *testing.T) { 24 | assertQuery(t, 25 | query.NewBuilder().Sprintf("this is %s", "str").Build(), 26 | "this is str", 27 | ) 28 | }) 29 | t.Run("Qprintf", func(t *testing.T) { 30 | assertQuery(t, 31 | query.NewBuilder().Query("(:?)", query.Q("id = ?", 1)).Build(), 32 | "(id = ?)", 1, 33 | ) 34 | }) 35 | t.Run("Query", func(t *testing.T) { 36 | assertQuery(t, 37 | query.NewBuilder().Query("id = ?", 1).Build(), 38 | "id = ?", 1, 39 | ) 40 | }) 41 | t.Run("Add", func(t *testing.T) { 42 | assertQuery(t, 43 | query.NewBuilder().Add(query.Q("id = ?", 1)).Build(), 44 | "id = ?", 1, 45 | ) 46 | }) 47 | t.Run("Clone", func(t *testing.T) { 48 | base := query.NewBuilder().Query("id = ?", 1) 49 | copy := base.Clone() 50 | assertQuery(t, base.Build(), "id = ?", 1) 51 | assertQuery(t, copy.Build(), "id = ?", 1) 52 | }) 53 | } 54 | -------------------------------------------------------------------------------- /query/query.go: -------------------------------------------------------------------------------- 1 | //go:generate mockgen -source $GOFILE 
-destination ../mocks/mock_$GOPACKAGE/$GOFILE -package mock_$GOPACKAGE 2 | package query 3 | 4 | import ( 5 | "regexp" 6 | "strings" 7 | 8 | "golang.org/x/xerrors" 9 | ) 10 | 11 | type Query interface { 12 | Query() (string, []any, error) 13 | } 14 | 15 | type query struct { 16 | query string 17 | args []any 18 | err error 19 | } 20 | 21 | func errQuery(err error) Query { 22 | return &query{err: err} 23 | } 24 | 25 | func (f *query) Query() (sqlStmt string, sqlArgs []any, resErr error) { 26 | if f.err != nil { 27 | resErr = f.err 28 | return 29 | } 30 | str := f.query 31 | args := f.args 32 | sb := &strings.Builder{} 33 | var argIdx = 0 34 | reg := regexp.MustCompile(`:?\?`) 35 | for { 36 | match := reg.FindStringIndex(str) 37 | if match == nil { 38 | break 39 | } 40 | if argIdx == len(args) { 41 | resErr = xerrors.Errorf("missing argument at %d", argIdx) 42 | return 43 | } 44 | mStart := match[0] 45 | mEnd := match[1] 46 | if mEnd-mStart == 2 { 47 | // :? 48 | if q, ok := args[argIdx].(Query); !ok { 49 | resErr = xerrors.Errorf("unexpected argument type for :? placeholder at %d", argIdx) 50 | return 51 | } else if stmt, vals, err := q.Query(); err != nil { 52 | resErr = err 53 | return 54 | } else { 55 | pre := str[:mStart] 56 | sb.WriteString(pre) 57 | sb.WriteString(stmt) 58 | sqlArgs = append(sqlArgs, vals...) 59 | } 60 | } else { 61 | // ? 
62 | sb.WriteString(str[:mEnd]) 63 | sqlArgs = append(sqlArgs, args[argIdx]) 64 | } 65 | str = str[mEnd:] 66 | argIdx += 1 67 | } 68 | if len(args) != argIdx { 69 | resErr = xerrors.Errorf("arguments count mismatch: found %d, got %d", argIdx, len(args)) 70 | return 71 | } 72 | if len(str) > 0 { 73 | sb.WriteString(str) 74 | } 75 | sqlStmt = sb.String() 76 | if resErr = guardQuery(sqlStmt); resErr != nil { 77 | return 78 | } 79 | return sqlStmt, sqlArgs, nil 80 | } 81 | 82 | type Condition interface { 83 | Query 84 | And(str string, args ...any) 85 | Or(str string, args ...any) 86 | AndCond(other Condition) 87 | OrCond(other Condition) 88 | } 89 | 90 | func Cond(str string, args ...any) Condition { 91 | return CondFrom(Q(str, args...)) 92 | } 93 | 94 | func CondFrom(q ...Query) Condition { 95 | base := &chain{ 96 | joiner: " ", 97 | list: q, 98 | } 99 | return &cond{base: base} 100 | } 101 | 102 | type cond struct { 103 | base *chain 104 | } 105 | 106 | func (c *cond) And(str string, args ...any) { 107 | c.append("AND", Q(str, args...)) 108 | } 109 | 110 | func (c *cond) Or(str string, args ...any) { 111 | c.append("OR", Q(str, args...)) 112 | } 113 | 114 | func (c *cond) AndCond(other Condition) { 115 | c.append("AND", other) 116 | } 117 | 118 | func (c *cond) OrCond(other Condition) { 119 | c.append("OR", other) 120 | } 121 | 122 | func (c *cond) Query() (string, []any, error) { 123 | return c.base.Query() 124 | } 125 | 126 | func (c *cond) append(sep string, other ...Query) { 127 | joiner := Q(sep) 128 | for _, v := range other { 129 | if len(c.base.list) == 0 { 130 | c.base.append(v) 131 | } else { 132 | c.base.append(joiner, v) 133 | } 134 | } 135 | } 136 | 137 | type chain struct { 138 | joiner string 139 | list []Query 140 | } 141 | 142 | func (c *chain) append(other ...Query) { 143 | c.list = append(c.list, other...) 
144 | } 145 | 146 | func (c *chain) Query() (string, []any, error) { 147 | var strs []string 148 | var args []any 149 | for _, v := range c.list { 150 | if s, v, err := v.Query(); err != nil { 151 | return "", nil, err 152 | } else { 153 | strs = append(strs, s) 154 | args = append(args, v...) 155 | } 156 | } 157 | stmt := strings.Join(strs, c.joiner) 158 | if err := guardQuery(stmt); err != nil { 159 | return "", nil, err 160 | } 161 | return stmt, args, nil 162 | } 163 | 164 | // New returns Query based on given query and arguments. 165 | // First argument query can contain exql placeholder format (:?) with the corresponding Query in rest arguments. 166 | // Given query component will be interpolated internally and embedded into the final SQL statement. 167 | // Except (:?) placeholders, all static statements will be embedded barely with no assertions. 168 | // You must pay attention to the input query if it is variable. 169 | func New(q string, args ...any) Query { 170 | return NewBuilder().Query(q, args...).Build() 171 | } 172 | 173 | // Q is a short-hand version of New. 174 | func Q(q string, args ...any) Query { 175 | return &query{ 176 | query: q, 177 | args: args, 178 | } 179 | } 180 | 181 | // Cols wraps given identifiers like column, and table with backquote as possible. 182 | // It is used for embedding table names or columns into queries dynamically. 183 | // If multiple values are given, they will be joined by a comma(,). 184 | // 185 | // Example: 186 | // 187 | // Cols("aaa","bbb") // `aaa`,`bbb` 188 | // Cols("users.*") // `users`.* 189 | func Cols(cols ...string) Query { 190 | if len(cols) == 0 { 191 | return errQuery(xerrors.Errorf("empty columns")) 192 | } 193 | return &query{ 194 | query: QuoteColumns(cols...), 195 | } 196 | } 197 | 198 | // V wraps one or more values for the prepared statement. 199 | // It counts number of values and interpolates Go's SQL placeholder(?), holding values for later. 
200 | // Multiple values will be joined by comma(,). 201 | // 202 | // Example: 203 | // 204 | // V(1,"a") // ?,? -> query | [1,"a"] -> arguments 205 | // 206 | // The code below 207 | // 208 | // db.Query(query.New("select * from users where id in (:?)", query.V(1,2))) 209 | // 210 | // is the same as: 211 | // 212 | // db.DB().Query("select * from users where id in (?,?)", 1, 2) 213 | func V(a ...any) Query { 214 | return &query{ 215 | query: Placeholders(len(a)), 216 | args: a, 217 | } 218 | } 219 | 220 | // Vals is another form of V that accepts a slice in generic type. 221 | func Vals[T any](vals []T) Query { 222 | if len(vals) == 0 { 223 | return errQuery(xerrors.Errorf("empty values")) 224 | } 225 | var args []any 226 | for _, v := range vals { 227 | args = append(args, v) 228 | } 229 | return &query{ 230 | query: Placeholders(len(vals)), 231 | args: args, 232 | } 233 | } 234 | 235 | // Set transforms map into "key = value" assignment expression in SQL. 236 | // Example: 237 | // 238 | // values := map[string]any{ "name": "go", "age": 20} 239 | // db.Exec("update users set :? where id = ?", query.Set(values, 1)) 240 | // 241 | // is the same as: 242 | // 243 | // db.DB().Exec("update users set age = ?, name = ? where id = ?", 20, "go", 1) 244 | func Set(m map[string]any) Query { 245 | if len(m) == 0 { 246 | return errQuery(xerrors.Errorf("empty values for set clause")) 247 | } 248 | b := NewBuilder() 249 | it := NewKeyIterator(m) 250 | for i := 0; i < it.Size(); i++ { 251 | k, v := it.Get(i) 252 | b.Query(":? 
= ?", Cols(k), v) 253 | } 254 | return b.Join(",") 255 | } 256 | -------------------------------------------------------------------------------- /query/query_test.go: -------------------------------------------------------------------------------- 1 | package query_test 2 | 3 | import ( 4 | "testing" 5 | 6 | q "github.com/loilo-inc/exql/v2/query" 7 | ) 8 | 9 | func TestQuery(t *testing.T) { 10 | assertQuery(t, q.V(1, 2), "?,?", 1, 2) 11 | assertQuery(t, q.Vals([]int{1, 2}), "?,?", 1, 2) 12 | assertQuery(t, q.Cols("a.b", "c.*"), "`a`.`b`,`c`.*") 13 | assertQuery(t, q.Q("id = ?", 1), "id = ?", 1) 14 | assertQuery(t, 15 | q.Set(map[string]any{ 16 | "a": "a", 17 | "b": "b", 18 | "a.b.*": "ab*", 19 | "`c`": "c", 20 | }), 21 | "`c` = ?,`a` = ?,`a`.`b`.* = ?,`b` = ?", "c", "a", "ab*", "b", 22 | ) 23 | assertQueryErr(t, q.Q(""), "DANGER: empty query") 24 | assertQueryErr(t, q.Vals[any](nil), "empty values") 25 | assertQueryErr(t, q.Cols(), "empty columns") 26 | assertQueryErr(t, q.Set(map[string]any{}), "empty values for set clause") 27 | } 28 | 29 | func TestNew(t *testing.T) { 30 | assertQuery(t, 31 | q.New("id in (:?) and name = ? and more", q.Vals([]int{1, 2}), "go"), 32 | "id in (?,?) and name = ? and more", 1, 2, "go", 33 | ) 34 | assertQueryErr(t, q.New(""), "DANGER: empty query") 35 | assertQueryErr(t, q.New(":?", q.Q("")), "DANGER: empty query") 36 | assertQueryErr(t, q.New("?"), "missing argument at 0") 37 | assertQueryErr(t, q.New("?,?", 1), "missing argument at 1") 38 | assertQueryErr(t, q.New(":?", 1), "unexpected argument type for :? placeholder at 0") 39 | assertQueryErr(t, q.New("?", 1, 2), "arguments count mismatch: found 1, got 2") 40 | } 41 | 42 | func TestCondition(t *testing.T) { 43 | t.Run("basic", func(t *testing.T) { 44 | cond := q.Cond("id = ?", 1) 45 | cond.And("name = ?", "go") 46 | cond.Or("age in (:?)", q.V(20, 21)) 47 | cond.AndCond(q.Cond("foo = ?", "foo")) 48 | cond.OrCond(q.Cond("var = ?", "var")) 49 | assertQuery(t, cond, 50 | "id = ? 
// keyIterator is the slice-backed implementation of KeyIterator.
// keys and values have the same length; values[i] belongs to keys[i].
type keyIterator[T any] struct {
	keys   []string
	values []T
}

// KeyIterator gives indexed, ordered access to the entries of a map.
type KeyIterator[T any] interface {
	// Get returns the i-th key and its value.
	Get(i int) (string, T)
	// Keys returns all keys in iteration order.
	Keys() []string
	// Values returns all values, in the order of Keys.
	Values() []T
	// Size returns the number of entries.
	Size() int
	// Map returns the entries as a new map.
	Map() map[string]T
}

// NewKeyIterator creates a KeyIterator over data. Keys are sorted in
// ascending byte order so that iteration is deterministic regardless of Go's
// randomized map order. For an empty map Keys and Values return nil.
func NewKeyIterator[T any](data map[string]T) KeyIterator[T] {
	if len(data) == 0 {
		return &keyIterator[T]{}
	}
	keys := make([]string, 0, len(data))
	for k := range data {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	values := make([]T, len(keys))
	for i, k := range keys {
		values[i] = data[k]
	}
	return &keyIterator[T]{keys: keys, values: values}
}

// Get returns the i-th key and its value. It panics if i is out of range.
func (it *keyIterator[T]) Get(i int) (string, T) {
	return it.keys[i], it.values[i]
}

// Size returns the number of entries.
func (it *keyIterator[T]) Size() int {
	return len(it.keys)
}

// Keys returns the sorted keys. The returned slice is the iterator's own
// storage, not a copy.
func (it *keyIterator[T]) Keys() []string {
	return it.keys
}

// Values returns the values in key order. The returned slice is the
// iterator's own storage, not a copy.
func (it *keyIterator[T]) Values() []T {
	return it.values
}

// Map returns the entries as a newly allocated map.
func (it *keyIterator[T]) Map() map[string]T {
	res := make(map[string]T, len(it.keys))
	for i, k := range it.keys {
		res[k] = it.values[i]
	}
	return res
}

// Placeholders makes n-th repeats of Go's SQL placeholder(?),
// joining them by comma(,). A count of zero or less yields an empty string.
//
// Example:
//
//	Placeholders(2) // ?,?
func Placeholders(repeat int) string {
	if repeat <= 0 {
		return ""
	}
	return strings.Repeat("?,", repeat-1) + "?"
}

// QuoteColumns quotes each string and joins them by comma(,).
func QuoteColumns(str ...string) string {
	quoted := make([]string, len(str))
	for i, v := range str {
		quoted[i] = QuoteColumn(v)
	}
	return strings.Join(quoted, ",")
}

// guardQuery rejects an empty statement.
func guardQuery(q string) error {
	if q == "" {
		return errors.New("DANGER: empty query")
	}
	return nil
}

// QuoteColumn surrounds SQL identifiers with backquote,
// keeping the meta-characters "*" and "." intact. Backquotes already present
// in the input act as separators and are dropped, so an already quoted
// identifier is not quoted twice.
//
// Example:
//
//	QuoteColumn("users.id") // `users`.`id`
//	QuoteColumn("users.*") // `users`.*
func QuoteColumn(col string) string {
	var sb strings.Builder
	start := 0 // start of the identifier segment currently being scanned
	for i := 0; i < len(col); i++ {
		switch c := col[i]; c {
		case '.', '*', '`':
			if start != i {
				fmt.Fprintf(&sb, "`%s`", col[start:i])
			}
			if c != '`' {
				sb.WriteByte(c)
			}
			start = i + 1
		}
	}
	if start < len(col) {
		fmt.Fprintf(&sb, "`%s`", col[start:])
	}
	return sb.String()
}
"github.com/loilo-inc/exql/v2/query" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestKeyIteretor(t *testing.T) { 11 | src := map[string]any{ 12 | "a": 1, 13 | "b": 2, 14 | "c": 3, 15 | } 16 | it := NewKeyIterator(src) 17 | assert.Equal(t, it.Size(), 3) 18 | assert.ElementsMatch(t, it.Keys(), []string{"a", "b", "c"}) 19 | assert.ElementsMatch(t, it.Values(), []any{1, 2, 3}) 20 | for i := 0; i < it.Size(); i++ { 21 | k, v := it.Get(i) 22 | assert.Equal(t, it.Keys()[i], k) 23 | assert.Equal(t, it.Values()[i], v) 24 | } 25 | assert.InDeltaMapValues(t, it.Map(), src, 0) 26 | } 27 | 28 | func TestSqlPraceholder(t *testing.T) { 29 | assert.Equal(t, "", Placeholders(0)) 30 | assert.Equal(t, "?", Placeholders(1)) 31 | assert.Equal(t, "?,?,?", Placeholders(3)) 32 | } 33 | 34 | func TestCuoteColumn(t *testing.T) { 35 | assert.Equal(t, "", QuoteColumn("")) 36 | assert.Equal(t, "`table`", QuoteColumn("table")) 37 | assert.Equal(t, "`users`.`id`", QuoteColumn("users.id")) 38 | assert.Equal(t, "`users`.`id`", QuoteColumn("`users`.`id`")) 39 | assert.Equal(t, "`users`.`id`", QuoteColumn("`users`.id")) 40 | assert.Equal(t, "`users`.*", QuoteColumn("users.*")) 41 | assert.Equal(t, "`users`.", QuoteColumn("users.")) 42 | } 43 | 44 | func TestQuoteColumns(t *testing.T) { 45 | assert.Equal(t, "`a`,`b`", QuoteColumns("a", "b")) 46 | } 47 | -------------------------------------------------------------------------------- /query_test.go: -------------------------------------------------------------------------------- 1 | package exql_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/loilo-inc/exql/v2" 7 | "github.com/loilo-inc/exql/v2/model" 8 | "github.com/loilo-inc/exql/v2/model/testmodel" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestQueryWhere(t *testing.T) { 13 | t.Run("Where", func(t *testing.T) { 14 | v, args, err := exql.Where("q = ?", 1).Query() 15 | assert.NoError(t, err) 16 | assert.Equal(t, "q = ?", v) 17 | assert.ElementsMatch(t, 
[]any{1}, args) 18 | }) 19 | } 20 | func TestQueryForInsert(t *testing.T) { 21 | t.Run("basic", func(t *testing.T) { 22 | user := model.Users{ 23 | Name: "go", Age: 10, 24 | } 25 | s, f, err := exql.QueryForInsert(&user) 26 | assert.NoError(t, err) 27 | assert.NotNil(t, f) 28 | exp := "INSERT INTO `users` (`age`,`name`) VALUES (?,?)" 29 | stmt, args, err := s.Query() 30 | assert.NoError(t, err) 31 | assert.Equal(t, exp, stmt) 32 | assert.ElementsMatch(t, args, []any{user.Age, user.Name}) 33 | }) 34 | } 35 | 36 | func TestQueryForBulkInsert(t *testing.T) { 37 | t.Run("basic", func(t *testing.T) { 38 | q, err := exql.QueryForBulkInsert( 39 | &model.Users{Age: 1, Name: "one"}, 40 | &model.Users{Age: 2, Name: "two"}, 41 | ) 42 | assert.NoError(t, err) 43 | stmt, args, err := q.Query() 44 | assert.NoError(t, err) 45 | assert.Equal(t, "INSERT INTO `users` (`age`,`name`) VALUES (?,?),(?,?)", stmt) 46 | assert.ElementsMatch(t, []any{int64(1), "one", int64(2), "two"}, args) 47 | }) 48 | t.Run("error if args empty", func(t *testing.T) { 49 | q, err := exql.QueryForBulkInsert[*model.Users]() 50 | assert.Nil(t, q) 51 | assert.EqualError(t, err, "empty list") 52 | }) 53 | } 54 | 55 | func TestAggregateModelMetadata(t *testing.T) { 56 | t.Run("basic", func(t *testing.T) { 57 | m, err := exql.AggregateModelMetadata(&model.Users{Name: "go", Age: 10}) 58 | assert.NoError(t, err) 59 | assert.Equal(t, "users", m.TableName) 60 | assert.NotNil(t, m.AutoIncrementField) 61 | assert.ElementsMatch(t, []string{"id"}, m.PrimaryKeyColumns) 62 | assert.ElementsMatch(t, []any{int64(0)}, m.PrimaryKeyValues) 63 | assert.ElementsMatch(t, []string{"age", "name"}, m.Values.Keys()) 64 | assert.ElementsMatch(t, []any{int64(10), "go"}, m.Values.Values()) 65 | }) 66 | t.Run("multiple primary key", func(t *testing.T) { 67 | data := &testmodel.MultiplePrimaryKey{ 68 | Pk1: "val1", 69 | Pk2: "val2", 70 | Other: 1, 71 | } 72 | md, err := exql.AggregateModelMetadata(data) 73 | assert.NoError(t, err) 74 | 
assert.Equal(t, data.TableName(), md.TableName) 75 | assert.Nil(t, md.AutoIncrementField) 76 | assert.ElementsMatch(t, []string{"pk1", "pk2"}, md.PrimaryKeyColumns) 77 | assert.ElementsMatch(t, []any{"val1", "val2"}, md.PrimaryKeyValues) 78 | assert.ElementsMatch(t, []string{"pk1", "pk2", "other"}, md.Values.Keys()) 79 | assert.ElementsMatch(t, []any{"val1", "val2", 1}, md.Values.Values()) 80 | }) 81 | assertInvalid := func(t *testing.T, m exql.Model, e string) { 82 | s, f, err := exql.QueryForInsert(m) 83 | assert.Nil(t, s) 84 | assert.Nil(t, f) 85 | assert.EqualError(t, err, e) 86 | } 87 | t.Run("should error if dest is nil", func(t *testing.T) { 88 | assertInvalid(t, nil, "pointer is nil") 89 | }) 90 | t.Run("should error if TableName() doesn't return string", func(t *testing.T) { 91 | assertInvalid(t, &testmodel.BadTableName{}, "empty table name") 92 | }) 93 | t.Run("should error if field doesn't have column tag", func(t *testing.T) { 94 | assertInvalid(t, &testmodel.NoColumnTag{}, "column tag is not set") 95 | }) 96 | t.Run("should error if field tag is invalid", func(t *testing.T) { 97 | assertInvalid(t, &testmodel.BadTag{}, "duplicated tag: a") 98 | }) 99 | t.Run("should error if dest has no primary key tag", func(t *testing.T) { 100 | assertInvalid(t, &testmodel.NoPrimaryKey{}, "table has no primary key") 101 | }) 102 | t.Run("shoud error if no exql tags found", func(t *testing.T) { 103 | assertInvalid(t, &testmodel.NoTag{}, "obj doesn't have exql tags in any fields") 104 | }) 105 | } 106 | 107 | func TestQueryForUpdateModel(t *testing.T) { 108 | t.Run("basic", func(t *testing.T) { 109 | name := "go" 110 | age := int64(20) 111 | q, err := exql.QueryForUpdateModel(&model.UpdateUsers{ 112 | Name: &name, 113 | Age: &age, 114 | }, exql.Where(`id = ?`, 1)) 115 | if err != nil { 116 | t.Fatal(err) 117 | } 118 | stmt, args, err := q.Query() 119 | assert.NoError(t, err) 120 | assert.Equal(t, stmt, 121 | "UPDATE `users` SET `age` = ?,`name` = ? 
WHERE id = ?", 122 | ) 123 | assert.ElementsMatch(t, []any{age, name, 1}, args) 124 | }) 125 | t.Run("should error if pointer is nil", func(t *testing.T) { 126 | _, err := exql.QueryForUpdateModel(nil, nil) 127 | assert.EqualError(t, err, "pointer is nil") 128 | }) 129 | t.Run("should error if has invalid tag", func(t *testing.T) { 130 | _, err := exql.QueryForUpdateModel(&upSampleInvalidTag{}, nil) 131 | assert.EqualError(t, err, "invalid tag format") 132 | }) 133 | t.Run("should error if field is not pointer", func(t *testing.T) { 134 | _, err := exql.QueryForUpdateModel(&upSampleNotPtr{}, nil) 135 | assert.EqualError(t, err, "field must be pointer") 136 | }) 137 | t.Run("should ignore if field is nil", func(t *testing.T) { 138 | _, err := exql.QueryForUpdateModel(&upSample{}, nil) 139 | assert.EqualError(t, err, "no value for update") 140 | }) 141 | t.Run("should error if struct has no fields", func(t *testing.T) { 142 | _, err := exql.QueryForUpdateModel(&upSampleNoFields{}, nil) 143 | assert.EqualError(t, err, "struct has no field") 144 | }) 145 | t.Run("should error if struct doesn't implement ForTableName()", func(t *testing.T) { 146 | id := 1 147 | _, err := exql.QueryForUpdateModel(&upSample{Id: &id}, nil) 148 | assert.EqualError(t, err, "empty table name") 149 | }) 150 | t.Run("should error if no column in tag", func(t *testing.T) { 151 | id := 1 152 | _, err := exql.QueryForUpdateModel(&upSampleNoColumn{Id: &id}, nil) 153 | assert.EqualError(t, err, "tag must include column") 154 | }) 155 | } 156 | -------------------------------------------------------------------------------- /saver.go: -------------------------------------------------------------------------------- 1 | //go:generate mockgen -source $GOFILE -destination ./mocks/mock_$GOPACKAGE/$GOFILE -package mock_$GOPACKAGE 2 | package exql 3 | 4 | import ( 5 | "context" 6 | "database/sql" 7 | "errors" 8 | "reflect" 9 | 10 | q "github.com/loilo-inc/exql/v2/query" 11 | ) 12 | 13 | type Saver interface 
{ 14 | Insert(structPtr Model) (sql.Result, error) 15 | InsertContext(ctx context.Context, structPtr Model) (sql.Result, error) 16 | Update(table string, set map[string]any, where q.Condition) (sql.Result, error) 17 | UpdateModel(updaterStructPtr ModelUpdate, where q.Condition) (sql.Result, error) 18 | UpdateContext(ctx context.Context, table string, set map[string]any, where q.Condition) (sql.Result, error) 19 | UpdateModelContext(ctx context.Context, updaterStructPtr ModelUpdate, where q.Condition) (sql.Result, error) 20 | Delete(table string, where q.Condition) (sql.Result, error) 21 | DeleteContext(ctx context.Context, table string, where q.Condition) (sql.Result, error) 22 | Exec(query q.Query) (sql.Result, error) 23 | ExecContext(ctx context.Context, query q.Query) (sql.Result, error) 24 | Query(query q.Query) (*sql.Rows, error) 25 | QueryContext(ctx context.Context, query q.Query) (*sql.Rows, error) 26 | QueryRow(query q.Query) (*sql.Row, error) 27 | QueryRowContext(ctx context.Context, query q.Query) (*sql.Row, error) 28 | } 29 | 30 | type saver struct { 31 | ex Executor 32 | } 33 | 34 | func NewSaver(ex Executor) Saver { 35 | return &saver{ex: ex} 36 | } 37 | 38 | func newSaver(ex Executor) *saver { 39 | return &saver{ex: ex} 40 | } 41 | 42 | func (s *saver) Insert(modelPtr Model) (sql.Result, error) { 43 | return s.InsertContext(context.Background(), modelPtr) 44 | } 45 | 46 | func (s *saver) InsertContext(ctx context.Context, modelPtr Model) (sql.Result, error) { 47 | q, autoIncrField, err := QueryForInsert(modelPtr) 48 | if err != nil { 49 | return nil, err 50 | } 51 | result, err := s.ExecContext(ctx, q) 52 | if err != nil { 53 | return nil, err 54 | } 55 | if autoIncrField != nil { 56 | lid, err := result.LastInsertId() 57 | if err != nil { 58 | return nil, err 59 | } 60 | kind := autoIncrField.Kind() 61 | if kind == reflect.Int64 { 62 | autoIncrField.Set(reflect.ValueOf(lid)) 63 | } else if kind == reflect.Uint64 { 64 | 
autoIncrField.Set(reflect.ValueOf(uint64(lid))) 65 | } 66 | } 67 | return result, nil 68 | } 69 | 70 | func (s *saver) Update( 71 | table string, 72 | set map[string]any, 73 | where q.Condition, 74 | ) (sql.Result, error) { 75 | return s.UpdateContext(context.Background(), table, set, where) 76 | } 77 | 78 | func (s *saver) UpdateContext( 79 | ctx context.Context, 80 | table string, 81 | set map[string]any, 82 | where q.Condition, 83 | ) (sql.Result, error) { 84 | if table == "" { 85 | return nil, errors.New("empty table name for update query") 86 | } else if where == nil { 87 | return nil, errors.New("nil condition for update query") 88 | } 89 | b := q.NewBuilder() 90 | b.Sprintf("UPDATE `%s`", table) 91 | b.Query("SET :? WHERE :?", q.Set(set), where) 92 | return s.ExecContext(ctx, b.Build()) 93 | } 94 | 95 | func (s *saver) Delete(from string, where q.Condition) (sql.Result, error) { 96 | return s.DeleteContext(context.Background(), from, where) 97 | } 98 | 99 | func (s *saver) DeleteContext(ctx context.Context, from string, where q.Condition) (sql.Result, error) { 100 | if from == "" { 101 | return nil, errors.New("empty table name for delete query") 102 | } else if where == nil { 103 | return nil, errors.New("nil condition for delete query") 104 | } 105 | b := q.NewBuilder() 106 | b.Sprintf("DELETE FROM `%s`", from) 107 | b.Query("WHERE :?", where) 108 | return s.ExecContext(ctx, b.Build()) 109 | } 110 | 111 | func (s *saver) UpdateModel( 112 | ptr ModelUpdate, 113 | where q.Condition, 114 | ) (sql.Result, error) { 115 | return s.UpdateModelContext(context.Background(), ptr, where) 116 | } 117 | 118 | func (s *saver) UpdateModelContext( 119 | ctx context.Context, 120 | ptr ModelUpdate, 121 | where q.Condition, 122 | ) (sql.Result, error) { 123 | q, err := QueryForUpdateModel(ptr, where) 124 | if err != nil { 125 | return nil, err 126 | } 127 | return s.ExecContext(ctx, q) 128 | } 129 | 130 | func (s *saver) Exec(query q.Query) (sql.Result, error) { 131 | if 
stmt, args, err := query.Query(); err != nil { 132 | return nil, err 133 | } else { 134 | return s.ex.Exec(stmt, args...) 135 | } 136 | } 137 | 138 | func (s *saver) ExecContext(ctx context.Context, query q.Query) (sql.Result, error) { 139 | if stmt, args, err := query.Query(); err != nil { 140 | return nil, err 141 | } else { 142 | return s.ex.ExecContext(ctx, stmt, args...) 143 | } 144 | } 145 | 146 | func (s *saver) Query(query q.Query) (*sql.Rows, error) { 147 | if stmt, args, err := query.Query(); err != nil { 148 | return nil, err 149 | } else { 150 | return s.ex.Query(stmt, args...) 151 | } 152 | } 153 | 154 | func (s *saver) QueryContext(ctx context.Context, query q.Query) (*sql.Rows, error) { 155 | if stmt, args, err := query.Query(); err != nil { 156 | return nil, err 157 | } else { 158 | return s.ex.QueryContext(ctx, stmt, args...) 159 | } 160 | } 161 | 162 | func (s *saver) QueryRow(query q.Query) (*sql.Row, error) { 163 | if stmt, args, err := query.Query(); err != nil { 164 | return nil, err 165 | } else { 166 | return s.ex.QueryRow(stmt, args...), nil 167 | } 168 | } 169 | 170 | func (s *saver) QueryRowContext(ctx context.Context, query q.Query) (*sql.Row, error) { 171 | if stmt, args, err := query.Query(); err != nil { 172 | return nil, err 173 | } else { 174 | return s.ex.QueryRowContext(ctx, stmt, args...), nil 175 | } 176 | } 177 | -------------------------------------------------------------------------------- /saver_test.go: -------------------------------------------------------------------------------- 1 | package exql_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "testing" 7 | "time" 8 | 9 | "github.com/DATA-DOG/go-sqlmock" 10 | "github.com/golang/mock/gomock" 11 | "github.com/loilo-inc/exql/v2" 12 | "github.com/loilo-inc/exql/v2/mocks/mock_exql" 13 | "github.com/loilo-inc/exql/v2/mocks/mock_query" 14 | "github.com/loilo-inc/exql/v2/model" 15 | "github.com/loilo-inc/exql/v2/model/testmodel" 16 | q "github.com/loilo-inc/exql/v2/query" 17 
| "github.com/stretchr/testify/assert" 18 | ) 19 | 20 | func TestSaver_Insert(t *testing.T) { 21 | d := testDb() 22 | s := exql.NewSaver(d.DB()) 23 | t.Run("basic", func(t *testing.T) { 24 | user := &model.Users{ 25 | Name: "go", Age: 10, 26 | } 27 | result, err := s.Insert(user) 28 | assert.NoError(t, err) 29 | assert.False(t, user.Id == 0) 30 | defer func() { 31 | d.DB().Exec(`DELETE FROM users WHERE id = ?`, user.Id) 32 | }() 33 | r, err := result.RowsAffected() 34 | assert.NoError(t, err) 35 | assert.Equal(t, int64(1), r) 36 | lid, err := result.LastInsertId() 37 | assert.NoError(t, err) 38 | assert.Equal(t, user.Id, lid) 39 | rows, err := d.DB().Query(`SELECT * FROM users WHERE id = ?`, lid) 40 | assert.NoError(t, err) 41 | var actual model.Users 42 | err = exql.MapRow(rows, &actual) 43 | assert.NoError(t, err) 44 | assert.Equal(t, lid, actual.Id) 45 | assert.Equal(t, user.Name, actual.Name) 46 | assert.Equal(t, user.Age, actual.Age) 47 | }) 48 | t.Run("should error if modelPtr is invalid", func(t *testing.T) { 49 | res, err := s.Insert(nil) 50 | assert.Error(t, err) 51 | assert.Nil(t, res) 52 | }) 53 | t.Run("should error if db.Exec() failed", func(t *testing.T) { 54 | db, mock, _ := sqlmock.New() 55 | mock.ExpectExec("INSERT INTO `users`").WithArgs(int64(0), "").WillReturnError(fmt.Errorf("err")) 56 | s := exql.NewSaver(db) 57 | user := &model.Users{} 58 | _, err := s.Insert(user) 59 | assert.EqualError(t, err, "err") 60 | }) 61 | t.Run("should error if result.LastInsertId() failed", func(t *testing.T) { 62 | db, mock, _ := sqlmock.New() 63 | mock.ExpectExec("INSERT INTO `users`").WithArgs(int64(0), "").WillReturnResult(sqlmock.NewErrorResult(fmt.Errorf("err"))) 64 | s := exql.NewSaver(db) 65 | user := &model.Users{} 66 | _, err := s.Insert(user) 67 | assert.EqualError(t, err, "err") 68 | }) 69 | t.Run("should assign lid to uint primary key", func(t *testing.T) { 70 | db, mock, _ := sqlmock.New() 71 | mock.ExpectExec("INSERT INTO 
`samplePrimaryUint64`").WillReturnResult(sqlmock.NewResult(11, 1)) 72 | s := exql.NewSaver(db) 73 | user := &testmodel.PrimaryUint64{} 74 | _, err := s.Insert(user) 75 | assert.NoError(t, err) 76 | assert.Equal(t, uint64(11), user.Id) 77 | }) 78 | t.Run("should not assign lid in case of not auto_increment", func(t *testing.T) { 79 | db, mock, _ := sqlmock.New() 80 | mock.ExpectExec("INSERT INTO `sampleNoAutoIncrementKey`").WillReturnResult(sqlmock.NewResult(11, 1)) 81 | s := exql.NewSaver(db) 82 | user := &testmodel.NoAutoIncrementKey{ 83 | Id: 1, 84 | } 85 | _, err := s.Insert(user) 86 | assert.NoError(t, err) 87 | assert.Equal(t, 1, user.Id) 88 | }) 89 | } 90 |
// TestSaver_InsertContext inserts rows through Saver.InsertContext against the
// test database, reads them back, and compares the stored values with the
// model that was inserted. Each subtest removes the row it created.
func TestSaver_InsertContext(t *testing.T) {
	d := testDb()
	s := exql.NewSaver(d.DB())
	t.Run("basic", func(t *testing.T) {
		user := &model.Users{
			Name: "go", Age: 10,
		}
		result, err := s.InsertContext(context.Background(), user)
		assert.NoError(t, err)
		// The generated id must have been written back into the model.
		assert.False(t, user.Id == 0)
		defer func() {
			d.DB().Exec(`DELETE FROM users WHERE id = ?`, user.Id)
		}()
		r, err := result.RowsAffected()
		assert.NoError(t, err)
		assert.Equal(t, int64(1), r)
		lid, err := result.LastInsertId()
		assert.NoError(t, err)
		assert.Equal(t, user.Id, lid)
		rows, err := d.DB().Query(`SELECT * FROM users WHERE id = ?`, lid)
		assert.NoError(t, err)
		var actual model.Users
		err = exql.MapRow(rows, &actual)
		assert.NoError(t, err)
		assert.Equal(t, lid, actual.Id)
		assert.Equal(t, user.Name, actual.Name)
		assert.Equal(t, user.Age, actual.Age)
	})
	t.Run("inserting to composite primary key table", func(t *testing.T) {
		history := &model.UserLoginHistories{
			UserId:    1,
			CreatedAt: time.Now(),
		}
		result, err := s.InsertContext(context.Background(), history)
		assert.NoError(t, err)
		assert.False(t, history.Id == 0)
		defer func() {
			// The table name must match the schema (user_login_histories);
			// the Exec error is ignored, so a misspelled name would silently
			// leave the row behind.
			d.DB().Exec(`DELETE FROM user_login_histories WHERE id = ?`, history.Id)
		}()
		r, err := result.RowsAffected()
		assert.NoError(t, err)
		assert.Equal(t, int64(1), r)
		lid, err := result.LastInsertId()
		assert.NoError(t, err)
		assert.Equal(t, history.Id, lid)
		rows, err := d.DB().Query(`SELECT * FROM user_login_histories WHERE id = ?`, lid)
		assert.NoError(t, err)
		var actual model.UserLoginHistories
		err = exql.MapRow(rows, &actual)
		assert.NoError(t, err)
		assert.Equal(t, lid, actual.Id)
		assert.Equal(t, history.UserId, actual.UserId)
		// Compare at second precision; the column is a datetime.
		assert.Equal(t, history.CreatedAt.Round(time.Second), actual.CreatedAt.Round(time.Second))
	})
	// NOTE(review): this subtest repeats the previous one under the same name.
	t.Run("inserting to composite primary key table", func(t *testing.T) {
		history := &model.UserLoginHistories{
			UserId:    1,
			CreatedAt: time.Now(),
		}
		result, err := s.InsertContext(context.Background(), history)
		assert.NoError(t, err)
		assert.False(t, history.Id == 0)
		defer func() {
			d.DB().Exec(`DELETE FROM user_login_histories WHERE id = ?`, history.Id)
		}()
		r, err := result.RowsAffected()
		assert.NoError(t, err)
		assert.Equal(t, int64(1), r)
		lid, err := result.LastInsertId()
		assert.NoError(t, err)
		assert.Equal(t, history.Id, lid)
		rows, err := d.DB().Query(`SELECT * FROM user_login_histories WHERE id = ?`, lid)
		assert.NoError(t, err)
		var actual model.UserLoginHistories
		err = exql.MapRow(rows, &actual)
		assert.NoError(t, err)
		assert.Equal(t, lid, actual.Id)
		assert.Equal(t, history.UserId, actual.UserId)
		assert.Equal(t, history.CreatedAt.Round(time.Second), actual.CreatedAt.Round(time.Second))
	})
	// NOTE(review): the body below is the same as "basic" and inserts into
	// users; confirm which table this subtest was meant to cover.
	t.Run("inserting to no auto_increment key table", func(t *testing.T) {
		user := &model.Users{
			Name: "go", Age: 10,
		}
		result, err := s.InsertContext(context.Background(), user)
		assert.NoError(t, err)
		assert.False(t, user.Id == 0)
		defer func() {
			d.DB().Exec(`DELETE FROM users WHERE id = ?`, user.Id)
		}()
		r, err := result.RowsAffected()
		assert.NoError(t, err)
		assert.Equal(t, int64(1), r)
		lid, err := result.LastInsertId()
		assert.NoError(t, err)
		assert.Equal(t, user.Id, lid)
		rows, err := d.DB().Query(`SELECT * FROM users WHERE id = ?`, lid)
		assert.NoError(t, err)
		var actual model.Users
		err = exql.MapRow(rows, &actual)
		assert.NoError(t, err)
		assert.Equal(t, lid, actual.Id)
		assert.Equal(t, user.Name, actual.Name)
		assert.Equal(t, user.Age, actual.Age)
	})
}
197 | 198 | func TestSaver_Update(t *testing.T) { 199 | d := testDb() 200 | s := exql.NewSaver(d.DB()) 201 | t.Run("basic", func(t *testing.T) { 202 | result, err := d.DB().Exec( 203 | "INSERT INTO `users` (`age`,`name`) VALUES (?, ?)", 204 | int64(10), "go") 205 | assert.NoError(t, err) 206 | lid, err := result.LastInsertId() 207 | assert.NoError(t, err) 208 | defer func() { 209 | d.DB().Exec(`DELETE FROM users WHERE id = ?`, lid) 210 | }() 211 | result, err = s.Update("users", map[string]interface{}{ 212 | "name": "lang", 213 | "age": int64(20), 214 | }, exql.Where(`id = ?`, lid)) 215 | assert.NoError(t, err) 216 | ra, err := result.RowsAffected() 217 | assert.NoError(t, err) 218 | assert.Equal(t, int64(1), ra) 219 | var actual model.Users 220 | rows, err := d.DB().Query(`SELECT * FROM users WHERE id = ?`, lid) 221 | assert.NoError(t, err) 222 | err = exql.MapRow(rows, &actual) 223 | assert.NoError(t, err) 224 | assert.Equal(t, "lang", actual.Name) 225 | assert.Equal(t, int64(20), actual.Age) 226 | }) 227 | t.Run("should error if tableName is empty", func(t *testing.T) { 228 | q, err := s.Update("", nil, nil) 229 | assert.Nil(t, q) 230 | assert.EqualError(t, err, "empty table name for update query") 231 | }) 232 | t.Run("should error if where clause is nil", func(t *testing.T) { 233 | q, err :=
s.Update("users", make(map[string]interface{}), nil) 234 | assert.Nil(t, q) 235 | assert.EqualError(t, err, "nil condition for update query") 236 | }) 237 | t.Run("should error if map is empty", func(t *testing.T) { 238 | q, err := s.Update("users", make(map[string]interface{}), exql.Where("id = 1")) 239 | assert.Nil(t, q) 240 | assert.EqualError(t, err, "empty values for set clause") 241 | }) 242 | t.Run("should error if where clause is empty", func(t *testing.T) { 243 | q, err := s.Update("users", map[string]interface{}{"first_name": "go"}, exql.Where("")) 244 | assert.Nil(t, q) 245 | assert.EqualError(t, err, "DANGER: empty query") 246 | }) 247 | } 248 | 249 | func TestSaver_UpdateModel(t *testing.T) { 250 | t.Run("basic", func(t *testing.T) { 251 | db, mock, _ := sqlmock.New() 252 | s := exql.NewSaver(db) 253 | name := "lang" 254 | mock.ExpectExec( 255 | "UPDATE `users` SET `name` = \\? WHERE id = \\?", 256 | ).WithArgs(name, 1).WillReturnResult(sqlmock.NewResult(1, 1)) 257 | result, err := s.UpdateModel(&model.UpdateUsers{ 258 | Name: &name, 259 | }, exql.Where(`id = ?`, 1)) 260 | if err != nil { 261 | t.Fatal(err) 262 | } 263 | lid, _ := result.LastInsertId() 264 | row, _ := result.RowsAffected() 265 | assert.Equal(t, int64(1), row) 266 | assert.Equal(t, int64(1), lid) 267 | }) 268 | } 269 | 270 | func TestSaver_UpdateModelContext(t *testing.T) { 271 | t.Run("basic", func(t *testing.T) { 272 | db, mock, _ := sqlmock.New() 273 | s := exql.NewSaver(db) 274 | name := "name" 275 | mock.ExpectExec( 276 | "UPDATE `users` SET `name` = \\? 
WHERE id = \\?", 277 | ).WithArgs(name, 1).WillReturnResult(sqlmock.NewResult(1, 1)) 278 | result, err := s.UpdateModelContext(context.Background(), &model.UpdateUsers{ 279 | Name: &name, 280 | }, exql.Where(`id = ?`, 1)) 281 | if err != nil { 282 | t.Fatal(err) 283 | } 284 | lid, _ := result.LastInsertId() 285 | row, _ := result.RowsAffected() 286 | assert.Equal(t, int64(1), row) 287 | assert.Equal(t, int64(1), lid) 288 | }) 289 | t.Run("should error if model invalid", func(t *testing.T) { 290 | db, _, _ := sqlmock.New() 291 | s := exql.NewSaver(db) 292 | _, err := s.UpdateModelContext(context.Background(), nil, exql.Where("id = ?", 1)) 293 | assert.EqualError(t, err, "pointer is nil") 294 | }) 295 | } 296 | 297 | func TestSaver_UpdateContext(t *testing.T) { 298 | d := testDb() 299 | s := exql.NewSaver(d.DB()) 300 | t.Run("basic", func(t *testing.T) { 301 | result, err := d.DB().Exec( 302 | "INSERT INTO `users` (`age`,`name`) VALUES (?, ?)", 303 | int64(10), "last") 304 | assert.NoError(t, err) 305 | lid, err := result.LastInsertId() 306 | assert.NoError(t, err) 307 | defer func() { 308 | d.DB().Exec(`DELETE FROM users WHERE id = ?`, lid) 309 | }() 310 | result, err = s.UpdateContext(context.Background(), "users", map[string]interface{}{ 311 | "age": int64(20), 312 | "name": "lang", 313 | }, exql.Where(`id = ?`, lid)) 314 | assert.NoError(t, err) 315 | ra, err := result.RowsAffected() 316 | assert.NoError(t, err) 317 | assert.Equal(t, int64(1), ra) 318 | var actual model.Users 319 | rows, err := d.DB().Query(`SELECT * FROM users WHERE id = ?`, lid) 320 | assert.NoError(t, err) 321 | err = exql.MapRow(rows, &actual) 322 | assert.NoError(t, err) 323 | assert.Equal(t, "lang", actual.Name) 324 | assert.Equal(t, int64(20), actual.Age) 325 | }) 326 | } 327 | 328 | func TestSaver_Delete(t *testing.T) { 329 | t.Run("basic", func(t *testing.T) { 330 | db, mock, _ := sqlmock.New() 331 | mock.ExpectExec("DELETE FROM `table` WHERE id = ?"). 332 | WithArgs(1). 
333 | WillReturnResult(sqlmock.NewResult(0, 1)) 334 | s := exql.NewSaver(db) 335 | _, err := s.Delete("table", exql.Where("id = ?", 1)) 336 | assert.NoError(t, err) 337 | }) 338 | t.Run("should error if clause returened an error", func(t *testing.T) { 339 | s := exql.NewSaver(nil) 340 | res, err := s.Delete("table", exql.Where("")) 341 | assert.EqualError(t, err, "DANGER: empty query") 342 | assert.Nil(t, res) 343 | }) 344 | t.Run("should error if table name is empty", func(t *testing.T) { 345 | s := exql.NewSaver(nil) 346 | res, err := s.Delete("", exql.Where("")) 347 | assert.EqualError(t, err, "empty table name for delete query") 348 | assert.Nil(t, res) 349 | }) 350 | t.Run("should error if condition is nil ", func(t *testing.T) { 351 | s := exql.NewSaver(nil) 352 | res, err := s.Delete("table", nil) 353 | assert.EqualError(t, err, "nil condition for delete query") 354 | assert.Nil(t, res) 355 | }) 356 | } 357 | 358 | type upSampleInvalidTag struct { 359 | Id *int `exql:"column::"` 360 | } 361 | 362 | func (upSampleInvalidTag) UpdateTableName() string { 363 | return "" 364 | } 365 | 366 | type upSampleNotPtr struct { 367 | Id int `exql:"column:id"` 368 | } 369 | 370 | func (upSampleNotPtr) UpdateTableName() string { 371 | return "" 372 | } 373 | 374 | type upSample struct { 375 | Id *int `exql:"column:id"` 376 | } 377 | 378 | func (upSample) UpdateTableName() string { 379 | return "" 380 | } 381 | 382 | type upSampleNoFields struct { 383 | } 384 | 385 | func (upSampleNoFields) UpdateTableName() string { 386 | return "" 387 | } 388 | 389 | type upSampleNoColumn struct { 390 | Id *int `exql:"row:id"` 391 | } 392 | 393 | func (upSampleNoColumn) UpdateTableName() string { 394 | return "table" 395 | } 396 | 397 | func TestSaver_QueryExtra(t *testing.T) { 398 | query := q.NewBuilder().Query("SELECT * FROM table WHERE id = ?", 1).Build() 399 | stmt := "SELECT * FROM table WHERE id = ?" 
400 | args := []any{1} 401 | aErr := fmt.Errorf("err") 402 | ctx := context.TODO() 403 | setup := func(t *testing.T) (*mock_exql.MockExecutor, exql.Saver) { 404 | ctrl := gomock.NewController(t) 405 | ex := mock_exql.NewMockExecutor(ctrl) 406 | s := exql.NewSaver(ex) 407 | return ex, s 408 | } 409 | setupQueryErr := func(t *testing.T) (*mock_query.MockQuery, exql.Saver) { 410 | ctrl := gomock.NewController(t) 411 | query := mock_query.NewMockQuery(ctrl) 412 | query.EXPECT().Query().Return("", nil, aErr) 413 | s := exql.NewSaver(nil) 414 | return query, s 415 | } 416 | 417 | t.Run("Exec", func(t *testing.T) { 418 | ex, s := setup(t) 419 | ex.EXPECT().Exec(stmt, args...).Return(nil, nil) 420 | res, err := s.Exec(query) 421 | assert.Nil(t, res) 422 | assert.NoError(t, err) 423 | }) 424 | t.Run("Exec/Error", func(t *testing.T) { 425 | query, s := setupQueryErr(t) 426 | res, err := s.Exec(query) 427 | assert.Nil(t, res) 428 | assert.Equal(t, aErr, err) 429 | }) 430 | t.Run("ExecContext", func(t *testing.T) { 431 | ex, s := setup(t) 432 | ex.EXPECT().ExecContext(ctx, stmt, args...).Return(nil, nil) 433 | res, err := s.ExecContext(ctx, query) 434 | assert.Nil(t, res) 435 | assert.NoError(t, err) 436 | }) 437 | t.Run("ExecContext/Error", func(t *testing.T) { 438 | query, s := setupQueryErr(t) 439 | res, err := s.ExecContext(ctx, query) 440 | assert.Nil(t, res) 441 | assert.Equal(t, aErr, err) 442 | }) 443 | t.Run("Query", func(t *testing.T) { 444 | ex, s := setup(t) 445 | ex.EXPECT().Query(stmt, args...).Return(nil, nil) 446 | res, err := s.Query(query) 447 | assert.Nil(t, res) 448 | assert.NoError(t, err) 449 | }) 450 | t.Run("Query/Error", func(t *testing.T) { 451 | query, s := setupQueryErr(t) 452 | res, err := s.Query(query) 453 | assert.Nil(t, res) 454 | assert.Equal(t, aErr, err) 455 | }) 456 | t.Run("QueryContext", func(t *testing.T) { 457 | ex, s := setup(t) 458 | ex.EXPECT().QueryContext(ctx, stmt, args...).Return(nil, nil) 459 | res, err := s.QueryContext(ctx, 
query) 460 | assert.Nil(t, res) 461 | assert.NoError(t, err) 462 | }) 463 | t.Run("QueryContext/Error", func(t *testing.T) { 464 | query, s := setupQueryErr(t) 465 | res, err := s.QueryContext(ctx, query) 466 | assert.Nil(t, res) 467 | assert.Equal(t, aErr, err) 468 | }) 469 | t.Run("QueryRow", func(t *testing.T) { 470 | ex, s := setup(t) 471 | ex.EXPECT().QueryRow(stmt, args...).Return(nil) 472 | res, err := s.QueryRow(query) 473 | assert.Nil(t, res) 474 | assert.NoError(t, err) 475 | }) 476 | t.Run("QueryRow/Error", func(t *testing.T) { 477 | query, s := setupQueryErr(t) 478 | res, err := s.QueryRow(query) 479 | assert.Nil(t, res) 480 | assert.Equal(t, aErr, err) 481 | }) 482 | t.Run("QueryRowContext", func(t *testing.T) { 483 | ex, s := setup(t) 484 | ex.EXPECT().QueryRowContext(ctx, stmt, args...).Return(nil) 485 | res, err := s.QueryRowContext(ctx, query) 486 | assert.Nil(t, res) 487 | assert.NoError(t, err) 488 | }) 489 | t.Run("QueryRowContext/Error", func(t *testing.T) { 490 | query, s := setupQueryErr(t) 491 | res, err := s.QueryRowContext(ctx, query) 492 | assert.Nil(t, res) 493 | assert.Equal(t, aErr, err) 494 | }) 495 | } 496 | -------------------------------------------------------------------------------- /schema/model.sql: -------------------------------------------------------------------------------- 1 | drop table if exists users; 2 | create table users ( 3 | id int(11) not null auto_increment, 4 | name varchar(255) not null, 5 | age int(11) not null, 6 | primary key (id) 7 | ); 8 | 9 | drop table if exists user_groups; 10 | create table user_groups ( 11 | id int(11) not null auto_increment, 12 | name varchar(255) not null, 13 | primary key (id) 14 | ); 15 | 16 | drop table if exists group_users; 17 | create table group_users ( 18 | id int(11) not null auto_increment, 19 | user_id int(11) not null, 20 | group_id int(11) not null, 21 | primary key (id), 22 | foreign key (user_id) references users(id), 23 | foreign key (group_id) references 
user_groups(id) 24 | ); 25 | create index group_users_user_id on group_users(user_id); 26 | create index group_users_group_id on group_users(group_id); 27 | 28 | drop table if exists user_login_histories; 29 | create table user_login_histories ( 30 | id int(11) not null auto_increment, 31 | user_id int(11) not null, 32 | created_at datetime not null, 33 | primary key (id, created_at) 34 | ) partition by hash(year(created_at)) partitions 16; 35 | 36 | drop table if exists fields; 37 | create table fields ( 38 | id int not null auto_increment, 39 | tinyint_field tinyint(4) not null, 40 | tinyint_unsigned_field tinyint(4) unsigned not null, 41 | tinyint_nullable_field tinyint(4), 42 | tinyint_unsigned_nullable_field tinyint(4) unsigned, 43 | smallint_field smallint(6) not null, 44 | smallint_unsigned_field smallint(6) unsigned not null, 45 | smallint_nullable_field smallint(6) , 46 | smallint_unsigned_nullable_field smallint(6) unsigned, 47 | mediumint_field mediumint(6) not null, 48 | mediumint_unsigned_field mediumint(6) unsigned not null, 49 | mediumint_nullable_field mediumint(6) , 50 | mediumint_unsigned_nullable_field mediumint(6) unsigned, 51 | int_field int(11) not null, 52 | int_unsigned_field int(11) unsigned not null, 53 | int_nullable_field int(11) , 54 | int_unsigned_nullable_field int(11) unsigned, 55 | bigint_field bigint(20) not null, 56 | bigint_unsigned_field bigint(20) unsigned not null, 57 | bigint_nullable_field bigint(20) , 58 | bigint_unsigned_nullable_field bigint(20) unsigned, 59 | float_field float not null, 60 | float_null_field float, 61 | double_field double not null, 62 | double_null_field double , 63 | tinytext_field tinytext not null, 64 | tinytext_null_field tinytext, 65 | mediumtext_field mediumtext not null, 66 | mediumtext_null_field mediumtext, 67 | text_field text not null, 68 | text_null_field text, 69 | longtext_field longtext not null, 70 | longtext_null_field longtext, 71 | varchar_filed_field varchar(255) not null, 72 | 
varchar_null_field varchar(255), 73 | char_filed_field char(10) not null, 74 | char_filed_null_field char(10), 75 | date_field date not null, 76 | date_null_field date, 77 | datetime_field datetime not null, 78 | datetime_null_field datetime, 79 | time_field time not null, 80 | time_null_field time, 81 | timestamp_field timestamp not null, 82 | timestamp_null_field timestamp null, 83 | tinyblob_field tinyblob not null, 84 | tinyblob_null_field tinyblob, 85 | mediumblob_field mediumblob not null, 86 | mediumblob_null_field mediumblob, 87 | blob_field blob not null, 88 | blob_null_field blob, 89 | longblob_field longblob not null, 90 | longblob_null_field longblob, 91 | json_field json not null, 92 | json_null_field json, 93 | primary key (id) 94 | ); 95 | -------------------------------------------------------------------------------- /stmt.go: -------------------------------------------------------------------------------- 1 | package exql 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | ) 7 | 8 | type stmtExecutor struct { 9 | ex Executor 10 | stmts map[string]*sql.Stmt 11 | } 12 | 13 | func (e *stmtExecutor) Exec(query string, args ...any) (sql.Result, error) { 14 | return e.ExecContext(context.Background(), query, args...) 15 | } 16 | 17 | func (e *stmtExecutor) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { 18 | stmt, err := e.prepare(ctx, query) 19 | if err != nil { 20 | return nil, err 21 | } 22 | return stmt.ExecContext(ctx, args...) 23 | } 24 | 25 | func (e *stmtExecutor) Prepare(stmt string) (*sql.Stmt, error) { 26 | return e.ex.Prepare(stmt) 27 | } 28 | 29 | func (e *stmtExecutor) PrepareContext(ctx context.Context, stmt string) (*sql.Stmt, error) { 30 | return e.ex.PrepareContext(ctx, stmt) 31 | } 32 | 33 | func (e *stmtExecutor) Query(query string, args ...any) (*sql.Rows, error) { 34 | return e.QueryContext(context.Background(), query, args...) 
35 | } 36 | 37 | func (e *stmtExecutor) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { 38 | stmt, err := e.prepare(ctx, query) 39 | if err != nil { 40 | return nil, err 41 | } 42 | return stmt.QueryContext(ctx, args...) 43 | } 44 | 45 | func (e *stmtExecutor) QueryRow(query string, args ...any) *sql.Row { 46 | return e.QueryRowContext(context.Background(), query, args...) 47 | } 48 | 49 | func (e *stmtExecutor) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { 50 | return e.ex.QueryRowContext(ctx, query, args...) 51 | } 52 | 53 | // StmtExecutor is the Executor that caches queries as *sql.Stmt. 54 | // It uses the cached Stmt for the next execution if query is identical. 55 | // They are held until Close() is called. This is useful for the case 56 | // of executing the same query repeatedly in the for-loop. 57 | // It may prevent errors caused by the db's connection pool. 58 | // 59 | // Example: 60 | // 61 | // stmtExecer := exql.NewStmtExecutor(tx.Tx()) 62 | // defer stmtExecer.Close() 63 | // stmtSaver := exql.NewSaver(stmtExecer) 64 | type StmtExecutor interface { 65 | Executor 66 | // Close calls all retained *sql.Stmt and clears the buffer. 67 | // DON'T forget to call this on the manual use. 
68 | Close() error 69 | } 70 | 71 | func (e *stmtExecutor) prepare(ctx context.Context, q string) (*sql.Stmt, error) { 72 | var err error 73 | stmt, ok := e.stmts[q] 74 | if !ok { 75 | if stmt, err = e.PrepareContext(ctx, q); err != nil { 76 | return nil, err 77 | } else { 78 | e.stmts[q] = stmt 79 | } 80 | } 81 | return stmt, nil 82 | } 83 | 84 | func (e *stmtExecutor) Close() error { 85 | var lastErr error 86 | for _, v := range e.stmts { 87 | err := v.Close() 88 | if err != nil { 89 | lastErr = err 90 | } 91 | } 92 | e.stmts = make(map[string]*sql.Stmt) 93 | return lastErr 94 | } 95 | 96 | func NewStmtExecutor(ex Executor) StmtExecutor { 97 | return newStmtExecutor(ex) 98 | } 99 | 100 | func newStmtExecutor(ex Executor) *stmtExecutor { 101 | return &stmtExecutor{ex: ex, stmts: make(map[string]*sql.Stmt)} 102 | } 103 | -------------------------------------------------------------------------------- /stmt_test.go: -------------------------------------------------------------------------------- 1 | package exql_test 2 | 3 | import ( 4 | "fmt" 5 | "regexp" 6 | "testing" 7 | 8 | "github.com/DATA-DOG/go-sqlmock" 9 | "github.com/golang/mock/gomock" 10 | "github.com/loilo-inc/exql/v2" 11 | "github.com/loilo-inc/exql/v2/mocks/mock_exql" 12 | "github.com/loilo-inc/exql/v2/model" 13 | "github.com/loilo-inc/exql/v2/query" 14 | "github.com/stretchr/testify/assert" 15 | ) 16 | 17 | func TestPreparedExecutor(t *testing.T) { 18 | setup := func(t *testing.T, db exql.Executor) exql.StmtExecutor { 19 | pex := exql.NewStmtExecutor(db) 20 | t.Cleanup(func() { 21 | assert.Nil(t, pex.Close()) 22 | }) 23 | return pex 24 | } 25 | t.Run("integration", func(t *testing.T) { 26 | db := testDb() 27 | user1 := model.Users{Name: "go"} 28 | user2 := model.Users{Name: "lang"} 29 | err := db.Transaction(func(tx exql.Tx) error { 30 | pex := setup(t, tx.Tx()) 31 | saver := exql.NewSaver(pex) 32 | for _, user := range []*model.Users{&user1, &user2} { 33 | if _, err := saver.Insert(user); err != nil 
{ 34 | return err 35 | } 36 | t.Cleanup(func() { 37 | db.Delete(model.UsersTableName, exql.Where("id = ?", user.Id)) 38 | }) 39 | } 40 | return nil 41 | }) 42 | assert.Nil(t, err) 43 | var list []*model.Users 44 | err = db.FindMany(query.Q( 45 | `select * from users where id in (?,?)`, user1.Id, user2.Id), 46 | &list, 47 | ) 48 | assert.Nil(t, err) 49 | }) 50 | t.Run("mock", func(t *testing.T) { 51 | db, mock, err := sqlmock.New() 52 | assert.NoError(t, err) 53 | ex := setup(t, db) 54 | qm := regexp.QuoteMeta 55 | insertQ := "insert into `users` (`name`) values (?)" 56 | selectQ := "select * from `users` where `name` = ?" 57 | stmt1 := mock.ExpectPrepare(qm(insertQ)).WillBeClosed() 58 | stmt1.ExpectExec().WithArgs("go").WillReturnResult(sqlmock.NewResult(0, 0)) 59 | stmt1.ExpectExec().WithArgs("og").WillReturnResult(sqlmock.NewResult(0, 0)) 60 | stmt2 := mock.ExpectPrepare(qm(selectQ)).WillBeClosed() 61 | stmt2.ExpectQuery().WithArgs("go").WillReturnRows(sqlmock.NewRows([]string{})) 62 | _, err = ex.Exec(insertQ, "go") 63 | assert.NoError(t, err) 64 | _, err = ex.Exec(insertQ, "og") 65 | assert.NoError(t, err) 66 | _, err = ex.Query(selectQ, "go") 67 | assert.NoError(t, err) 68 | err = ex.Close() 69 | assert.NoError(t, err) 70 | }) 71 | 72 | t.Run("preparation error", func(t *testing.T) { 73 | ctrl := gomock.NewController(t) 74 | stmt := "stmt" 75 | testFunc := func(t *testing.T, body func(ex exql.StmtExecutor) (err error)) { 76 | mock := mock_exql.NewMockExecutor(ctrl) 77 | mock.EXPECT().PrepareContext(gomock.Any(), stmt).Return(nil, fmt.Errorf("err")) 78 | ex := exql.NewStmtExecutor(mock) 79 | err := body(ex) 80 | assert.EqualError(t, err, "err") 81 | } 82 | t.Run("Exec", func(t *testing.T) { 83 | testFunc(t, func(ex exql.StmtExecutor) (err error) { 84 | _, err = ex.Exec(stmt) 85 | return 86 | }) 87 | }) 88 | t.Run("Query", func(t *testing.T) { 89 | testFunc(t, func(ex exql.StmtExecutor) (err error) { 90 | _, err = ex.Query(stmt) 91 | return 92 | }) 93 | }) 94 | 
}) 95 | t.Run("Prepare bypass to the inner executor", func(t *testing.T) { 96 | db, mock, _ := sqlmock.New() 97 | mock.ExpectPrepare("stmt").WillBeClosed() 98 | ex := exql.NewStmtExecutor(db) 99 | stmt, err := ex.Prepare("stmt") 100 | stmt.Close() 101 | assert.Nil(t, err) 102 | }) 103 | t.Run("QueryRow bypass to the inner executor", func(t *testing.T) { 104 | ctrl := gomock.NewController(t) 105 | mock := mock_exql.NewMockExecutor(ctrl) 106 | mock.EXPECT().QueryRowContext(gomock.Any(), "stmt").Return(nil) 107 | ex := exql.NewStmtExecutor(mock) 108 | row := ex.QueryRow("stmt") 109 | assert.Nil(t, row) 110 | }) 111 | } 112 | -------------------------------------------------------------------------------- /tag.go: -------------------------------------------------------------------------------- 1 | package exql 2 | 3 | import ( 4 | "strings" 5 | 6 | "golang.org/x/xerrors" 7 | ) 8 | 9 | func ParseTags(tag string) (map[string]string, error) { 10 | tags := strings.Split(tag, ";") 11 | ret := make(map[string]string) 12 | set := func(k string, v string) error { 13 | if k == "" { 14 | return nil 15 | } 16 | if _, ok := ret[k]; ok { 17 | return xerrors.Errorf("duplicated tag: %s", k) 18 | } 19 | ret[k] = v 20 | return nil 21 | } 22 | for _, tag := range tags { 23 | kv := strings.Split(tag, ":") 24 | if len(kv) == 1 { 25 | if err := set(kv[0], ""); err != nil { 26 | return nil, err 27 | } 28 | } else if len(kv) == 2 { 29 | if err := set(kv[0], kv[1]); err != nil { 30 | return nil, err 31 | } 32 | } else { 33 | return nil, xerrors.Errorf("invalid tag format") 34 | } 35 | } 36 | if len(ret) == 0 { 37 | return nil, xerrors.Errorf("invalid tag format") 38 | } 39 | return ret, nil 40 | } 41 | -------------------------------------------------------------------------------- /tag_test.go: -------------------------------------------------------------------------------- 1 | package exql_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/loilo-inc/exql/v2" 7 | 
"github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestParseTags(t *testing.T) { 11 | t.Run("basic", func(t *testing.T) { 12 | tags, err := exql.ParseTags("a:1;b:2;c:3") 13 | assert.NoError(t, err) 14 | assert.Equal(t, len(tags), 3) 15 | assert.Equal(t, "1", tags["a"]) 16 | assert.Equal(t, "2", tags["b"]) 17 | assert.Equal(t, "3", tags["c"]) 18 | }) 19 | t.Run("key only", func(t *testing.T) { 20 | tags, err := exql.ParseTags("a;b;c;") 21 | assert.NoError(t, err) 22 | assert.Equal(t, len(tags), 3) 23 | assert.Equal(t, "", tags["a"]) 24 | assert.Equal(t, "", tags["b"]) 25 | assert.Equal(t, "", tags["c"]) 26 | }) 27 | assertInvalid := func(s string, e string) { 28 | tags, err := exql.ParseTags(s) 29 | assert.Nil(t, tags) 30 | assert.EqualError(t, err, e) 31 | } 32 | t.Run("should return error for duplicate tag", func(t *testing.T) { 33 | assertInvalid("a:1;a:2", "duplicated tag: a") 34 | assertInvalid("a;a;", "duplicated tag: a") 35 | }) 36 | t.Run("should return error if tag is empty", func(t *testing.T) { 37 | assertInvalid(";", "invalid tag format") 38 | assertInvalid("", "invalid tag format") 39 | assertInvalid(";:;", "invalid tag format") 40 | assertInvalid(":::", "invalid tag format") 41 | assertInvalid(";;;", "invalid tag format") 42 | }) 43 | } 44 | -------------------------------------------------------------------------------- /template/README.md: -------------------------------------------------------------------------------- 1 | exql 2 | --- 3 | [![codecov](https://codecov.io/gh/loilo-inc/exql/branch/master/graph/badge.svg?token=aGixN2xIMP)](https://codecov.io/gh/loilo-inc/exql) 4 | 5 | Safe, strict and clear ORM for Go 6 | 7 | ## Introduction 8 | 9 | exql is a simple ORM library for MySQL, written in Go. It is designed to work at the minimum for real software development. It has a few, limited but enough convenient functionalities of SQL database. 10 | We adopted the data mapper model, not the active record. 
Records in the database are mapped into structs simply. Each model has no state and also no methods to modify itself and sync database records. You need to write bare SQL code for every operation you need except for a few cases. 11 | 12 | exql is designed by focusing on safety and clearness in SQL usage. In other words, we never generate any SQL statements that are potentially dangerous or have ambiguous side effects across tables and the database. 13 | 14 | It does: 15 | 16 | - make insert/update query from model structs. 17 | - map rows returned from the database into structs. 18 | - map joined table into one or more structs. 19 | - provide a safe syntax for the transaction. 20 | - provide a framework to build dynamic SQL statements safely. 21 | - generate model codes automatically from the database. 22 | 23 | It DOESN'T 24 | 25 | - make delete/update statements across the table. 26 | - make unexpectedly slow select queries that don't use correct indices. 27 | - modify any database settings, schemas and indices. 28 | 29 | ## Table of contents 30 | 31 | - [exql](#exql) 32 | - [Introduction](#introduction) 33 | - [Table of contents](#table-of-contents) 34 | - [Usage](#usage) 35 | - [Open database connection](#open-database-connection) 36 | - [Code Generation](#code-generation) 37 | - [Execute queries](#execute-queries) 38 | - [Insert](#insert) 39 | - [Update](#update) 40 | - [Delete](#delete) 41 | - [Other](#other) 42 | - [Transaction](#transaction) 43 | - [Find records](#find-records) 44 | - [For simple query](#for-simple-query) 45 | - [For joined table](#for-joined-table) 46 | - [For outer-joined table](#for-outer-joined-table) 47 | - [Use query builder](#use-query-builder) 48 | - [License](#license) 49 | 50 | ## Usage 51 | 52 | ### Open database connection 53 | 54 | ```go 55 | {{.Open}} 56 | ``` 57 | 58 | ### Code Generation 59 | exql provides an automated code generator of models based on the database schema. This is a typical table schema of MySQL database. 
60 | 61 | ``` 62 | mysql> show columns from users; 63 | +-------+--------------+------+-----+---------+----------------+ 64 | | Field | Type | Null | Key | Default | Extra | 65 | +-------+--------------+------+-----+---------+----------------+ 66 | | id | int(11) | NO | PRI | NULL | auto_increment | 67 | | name | varchar(255) | NO | | NULL | | 68 | | age | int(11) | NO | | NULL | | 69 | +-------+--------------+------+-----+---------+----------------+ 70 | ``` 71 | 72 | To generate model code based on the schema, write code like this: 73 | 74 | ```go 75 | {{.GenerateModels}} 76 | ``` 77 | 78 | The generated result looks like this: 79 | 80 | ```go 81 | {{.AutoGenerateCode}} 82 | ``` 83 | 84 | `Users` is the destination of the data mapper. It only has value fields and one method, `TableName()`. This is the implementation of `exql.Model` that can be passed to the data saver. All structs, methods and field tags must be preserved as they are, for internal use. If you want to modify the results, you must run the generator again. 85 | 86 | `UpdateUsers` is a partial structure for the data model. It has fields with the same names as `Users`, but every field type is a pointer. It is used to update table columns partially. In other words, it is a designated, typesafe map for the model. 87 | 88 | ### Execute queries 89 | 90 | There are several ways to issue SQL statements with exql. 91 | 92 | #### Insert 93 | 94 | An INSERT query is constructed automatically based on model data and executed without writing the statement. To insert new records into the database, set values on the model and pass it to the `exql.DB#Insert` method. 95 | 96 | ```go 97 | {{.Insert}} 98 | ``` 99 | 100 | #### Update 101 | 102 | An UPDATE query is constructed automatically based on the model update struct. To avoid unexpected updates to the table, every value is represented as a pointer to its data type.
103 | 104 | ```go 105 | {{.Update}} 106 | ``` 107 | 108 | #### Delete 109 | 110 | A DELETE query is issued against the table with the given conditions. For safety reasons, there is no way to construct a DELETE query from a model. 111 | 112 | ```go 113 | {{.Delete}} 114 | ``` 115 | 116 | #### Other 117 | 118 | Other queries should be executed with the `sql.DB` obtained from `DB`. 119 | 120 | ```go 121 | {{.Other}} 122 | ``` 123 | 124 | ### Transaction 125 | 126 | A transaction with `BEGIN`~`COMMIT`/`ROLLBACK` is performed by `TransactionWithContext`. You don't need to call `BeginTx` and `Commit`/`Rollback` manually; all atomic operations are done within a callback. 127 | 128 | ```go 129 | {{.Tx}} 130 | ``` 131 | 132 | ### Find records 133 | 134 | To find records in the database, use the `Find`/`FindMany` methods. They execute the query and map the results into structs. 135 | 136 | #### For simple query 137 | 138 | ```go 139 | {{.MapRows}} 140 | ``` 141 | 142 | #### For joined table 143 | 144 | ```go 145 | {{.MapJoinedRows}} 146 | ``` 147 | 148 | #### For outer-joined table 149 | 150 | ```go 151 | {{.MapOuterJoinedRows}} 152 | ``` 153 | 154 | ### Use query builder 155 | 156 | The `exql/query` package is a low-level API for building complicated SQL statements. See [V2 Release Notes](https://github.com/loilo-inc/exql/blob/main/changelogs/v2.0.md#exqlquery-package) for more details. 157 | 158 | ```go 159 | {{.QueryBuilder}} 160 | ``` 161 | 162 | ## License 163 | 164 | MIT License / Copyright (c) LoiLo inc.
// main writes compose.yml into the current directory. The file declares a
// MySQL 8 service whose image namespace is chosen from the architecture the
// program was built for. The program exits with a fatal log message when the
// architecture is not supported or the file cannot be written.
func main() {
	// Image namespace for each supported GOARCH value.
	imagePrefixes := map[string]string{
		"amd64": "amd64",
		"arm64": "arm64v8",
	}
	arch := runtime.GOARCH
	prefix, supported := imagePrefixes[arch]
	if !supported {
		log.Fatalf("unsupported arch: %s", arch)
	}
	const composeTemplate = `
services:
  mysql:
    container_name: exql_mysql8
    image: %s/mysql:8
    ports:
      - 13326:3306
    environment:
      MYSQL_ALLOW_EMPTY_PASSWORD: 1
      MYSQL_DATABASE: exql
    volumes:
      - ./schema:/docker-entrypoint-initdb.d`
	content := fmt.Sprintf(composeTemplate, prefix)
	if err := os.WriteFile("compose.yml", []byte(content), 0644); err != nil {
		log.Fatalf("failed to write compose.yml: %s", err)
	}
}
// main renders template/README.md with the contents of the example and model
// source files and writes the result to README.md in the current directory.
// Any failure aborts the program with a panic.
func main() {
	tmpl := template.Must(template.ParseFiles("template/README.md"))
	data := map[string]string{
		"Open":               catFile("example/open.go"),
		"GenerateModels":     catFile("example/generator.go"),
		"Insert":             catFile("example/insert.go"),
		"Update":             catFile("example/update.go"),
		"Delete":             catFile("example/delete.go"),
		"Other":              catFile("example/other.go"),
		"MapRows":            catFile("example/mapper.go"),
		"MapJoinedRows":      catFile("example/serial_mapper.go"),
		"MapOuterJoinedRows": catFile("example/outer_join.go"),
		"Tx":                 catFile("example/tx.go"),
		"QueryBuilder":       catFile("example/query_builder.go"),
		"AutoGenerateCode":   catFile("model/users.go"),
	}
	out, err := os.Create("README.md")
	if err != nil {
		panic(err)
	}
	if err := tmpl.Execute(out, data); err != nil {
		// Release the handle before aborting; the render error is the one
		// worth reporting, so the Close result is intentionally dropped here.
		_ = out.Close()
		panic(err)
	}
	// Close explicitly so a failure while finalizing the file is not lost.
	if err := out.Close(); err != nil {
		panic(err)
	}
}

// catFile returns the whole content of the file f as a string and panics
// when the file cannot be read.
func catFile(f string) string {
	s, err := os.ReadFile(f)
	if err != nil {
		panic(err)
	}
	return string(s)
}
-------------------------------------------------------------------------------- /tx.go: -------------------------------------------------------------------------------- 1 | package exql 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | 7 | "golang.org/x/xerrors" 8 | ) 9 | 10 | type Tx interface { 11 | Saver 12 | Finder 13 | Mapper 14 | Tx() *sql.Tx 15 | } 16 | 17 | type tx struct { 18 | *saver 19 | *finder 20 | *mapper 21 | tx *sql.Tx 22 | } 23 | 24 | func newTx(t *sql.Tx) *tx { 25 | return &tx{saver: newSaver(t), finder: newFinder(t), mapper: &mapper{}, tx: t} 26 | } 27 | 28 | func (t *tx) Tx() *sql.Tx { 29 | return t.tx 30 | } 31 | 32 | func Transaction(db *sql.DB, ctx context.Context, opts *sql.TxOptions, callback func(tx Tx) error) error { 33 | sqlTx, err := db.BeginTx(ctx, opts) 34 | if err != nil { 35 | return err 36 | } 37 | tx := newTx(sqlTx) 38 | var p interface{} 39 | txErr := func() error { 40 | defer func() { 41 | p = recover() 42 | }() 43 | return callback(tx) 44 | }() 45 | if p != nil { 46 | txErr = xerrors.Errorf("recovered: %s", p) 47 | } 48 | if txErr != nil { 49 | if err := sqlTx.Rollback(); err != nil { 50 | return err 51 | } 52 | return txErr 53 | } else if err := sqlTx.Commit(); err != nil { 54 | return err 55 | } 56 | return nil 57 | } 58 | -------------------------------------------------------------------------------- /tx_test.go: -------------------------------------------------------------------------------- 1 | package exql_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "testing" 7 | 8 | "github.com/loilo-inc/exql/v2" 9 | "github.com/loilo-inc/exql/v2/model" 10 | "github.com/loilo-inc/exql/v2/query" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestTx_Transaction(t *testing.T) { 15 | db := testDb() 16 | t.Run("basic", func(t *testing.T) { 17 | var user *model.Users 18 | err := exql.Transaction(db.DB(), context.Background(), nil, func(tx exql.Tx) error { 19 | user = &model.Users{Name: "go"} 20 | res, err := 
tx.Insert(user) 21 | assert.NoError(t, err) 22 | lid, err := res.LastInsertId() 23 | assert.NoError(t, err) 24 | assert.Equal(t, lid, user.Id) 25 | return nil 26 | }) 27 | assert.NoError(t, err) 28 | var dest model.Users 29 | err = db.Find(query.Q(`select * from users where id = ?`, user.Id), &dest) 30 | assert.NoError(t, err) 31 | assert.Equal(t, user.Id, dest.Id) 32 | }) 33 | t.Run("rollback", func(t *testing.T) { 34 | var user *model.Users 35 | err := exql.Transaction(db.DB(), context.Background(), nil, func(tx exql.Tx) error { 36 | user = &model.Users{Name: "go"} 37 | res, err := tx.Insert(user) 38 | assert.NoError(t, err) 39 | lid, err := res.LastInsertId() 40 | assert.NoError(t, err) 41 | assert.Equal(t, lid, user.Id) 42 | return fmt.Errorf("err") 43 | }) 44 | assert.EqualError(t, err, "err") 45 | var dest model.Users 46 | rows, err := db.DB().Query(`select * from users where id = ?`, user.Id) 47 | assert.NoError(t, err) 48 | err = exql.MapRow(rows, &dest) 49 | assert.ErrorIs(t, err, exql.ErrRecordNotFound) 50 | }) 51 | t.Run("should rollback if panic happened during transaction", func(t *testing.T) { 52 | var user *model.Users 53 | err := exql.Transaction(db.DB(), context.Background(), nil, func(tx exql.Tx) error { 54 | user = &model.Users{} 55 | _, err := tx.Insert(user) 56 | assert.NoError(t, err) 57 | panic("panic") 58 | }) 59 | assert.EqualError(t, err, "recovered: panic") 60 | var dest model.Users 61 | err = db.Find(query.Q(`select * from users where id = ?`, user.Id), &dest) 62 | assert.Equal(t, exql.ErrRecordNotFound, err) 63 | }) 64 | } 65 | func TestTx_Map(t *testing.T) { 66 | db := testDb() 67 | user := &model.Users{Name: "go"} 68 | defer func() { 69 | db.DB().Exec(`delete from users where id = ?`, user.Id) 70 | }() 71 | var dest model.Users 72 | err := exql.Transaction(db.DB(), context.Background(), nil, func(tx exql.Tx) error { 73 | if _, err := tx.Insert(user); err != nil { 74 | return err 75 | } 76 | rows, err := tx.Tx().Query(`select * from 
// Ptr stores a copy of t in a newly allocated variable and returns the
// address of that variable.
func Ptr[T any](t T) *T {
	p := new(T)
	*p = t
	return p
}