├── .github └── workflows │ └── go.yml ├── .gitignore ├── LICENSE ├── README.md ├── dialect.go ├── dialect_test.go ├── filters.go ├── filters_test.go ├── foreign_key.go ├── go.mod ├── go.sum ├── migration.go ├── model.go ├── model_test.go ├── mysqldialect ├── mysqldialect.go └── mysqldialect_test.go ├── one_to_many.go ├── pqdialect ├── pqdialect.go ├── pqdialect_test.go └── pqlib.go ├── query.go ├── query_test.go ├── sqlitedialect ├── sqlitedialect.go └── sqlitedialect_test.go ├── stringers.go └── stringers_test.go /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | on: 3 | push: 4 | branches: [ "main" ] 5 | pull_request: 6 | branches: [ "main" ] 7 | jobs: 8 | test: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v3 12 | - name: Set up Go 13 | uses: actions/setup-go@v4 14 | with: 15 | go-version: '1.20' 16 | - name: Test 17 | run: go test -v ./... 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | ISC License 2 | 3 | Copyright © 2023 Evan Byrne (https://evanbyrne.com) 4 | 5 | Permission to use, copy, modify, and/or distribute this software for any purpose with or without fee is hereby granted, provided that the above copyright notice and this permission notice appear in all copies. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # REM 2 | 3 | The retro Golang ORM. **R**etro **E**ntity **M**apper. 4 | 5 | ```go 6 | type Accounts struct { 7 | Group rem.NullForeignKey[Groups] `db:"group_id"` 8 | Id int64 `db:"id" db_primary:"true"` 9 | Name string `db:"name"` 10 | } 11 | 12 | type Groups struct { 13 | Accounts rem.OneToMany[Accounts] `db:"group_id"` 14 | Id int64 `db:"id" db_primary:"true"` 15 | Name string `db:"name" db_max_length:"100"` 16 | } 17 | ``` 18 | 19 | ```go 20 | // Only one additional query is executed to fetch all related accounts. 21 | groups, err := rem.Use[Groups](). 22 | FetchRelated("Accounts"). 23 | Filter("id", "IN", []interface{}{10, 20, 30}). 24 | Sort("name", "-id"). 25 | All(db) 26 | 27 | if err != nil { 28 | panic(err) 29 | } 30 | for _, group := range groups { 31 | // group *Groups 32 | // group.Accounts.Rows []*Accounts 33 | } 34 | ``` 35 | 36 | ## Features 37 | 38 | - PostgreSQL, MySQL, and SQLite dialects. 39 | - Data and schema migrations that use the same model syntax. 40 | - Optimized foreign key and one-to-many prefetching. 41 | - Interface extensible query builder. Can be used for database-specific features. 42 | - Negligible performance difference from using database/sql directly. 43 | - Decoupled from database/sql connections and drivers. 44 | - Partially or fully fallback to a safely parameterized SQL format as desired. 45 | - Zero code gen. Models are just structs that may have your own fields and methods. 46 | - Standardized safety with explicitly null and not-null types. 47 | - Transaction and golang context support. 48 | - Subqueries, joins, selective fetching, map scanning, and more. 49 | 50 | ## Installation 51 | 52 | The `main` branch contains the latest release. From the shell: 53 | 54 | ``` 55 | go get github.com/evantbyrne/rem 56 | ``` 57 | 58 | **Note:** REM is not yet stable and pre-1.0 releases may result in breaking changes. 59 | 60 | 61 | ## Contributing 62 | 63 | Please post feature requests, questions, and other feedback to the [discussions board](https://github.com/evantbyrne/rem/discussions). Submit bug reports to the [issue tracker](https://github.com/evantbyrne/rem/issues). 64 | 65 | 66 | ## Dialects 67 | 68 | REM supports PostgreSQL, MySQL, and SQLite. To use a dialect, import the appropriate package and set it as the default once on application bootup. 69 | 70 | ```go 71 | import ( 72 | // Choose one: 73 | "github.com/evantbyrne/rem/mysqldialect" 74 | "github.com/evantbyrne/rem/pqdialect" 75 | "github.com/evantbyrne/rem/sqlitedialect" 76 | 77 | // Don't forget to import your database driver. 78 | ) 79 | ``` 80 | 81 | ```go 82 | // Choose one: 83 | rem.SetDialect(mysqldialect.MysqlDialect{}) 84 | rem.SetDialect(pqdialect.PqDialect{}) 85 | rem.SetDialect(sqlitedialect.SqliteDialect{}) 86 | 87 | // Then connect to your database as usual. 88 | db, err := sql.Open("", "") 89 | if err != nil { 90 | panic(err) 91 | } 92 | defer db.Close() 93 | ``` 94 | 95 | 96 | ## Models 97 | 98 | Models are structs that define table schemas. 99 | 100 | ```go 101 | type Accounts struct { 102 | Id int64 `db:"id" db_primary:"true"` 103 | Name string `db:"name" db_max_length:"100"` 104 | Junk string 105 | } 106 | ``` 107 | 108 | In the above struct definition, `Id` and `Name` are columns in the `accounts`. Their columns are defined by the `db` field tag. `Id` is also an auto-incrementing primary key. `Name` has a maximum character length of `100`. The `Junk` field is ignored by REM. 109 | 110 | After defining a model, register it once on application bootup, then query the database. 111 | 112 | ```go 113 | // rem.Register[To]() caches computed structure of the model. 114 | rem.Register[Accounts]() 115 | 116 | // rem.Use[To]() returns a query builder for the model. 117 | rows, err := rem.Use[Accounts]().All(db) 118 | 119 | // You can also reuse the Model[To] instance returned by rem.Register[To]() and rem.Use[To](). 120 | accounts := rem.Use[Accounts]() 121 | rows1, err1 := accounts.Filter("name", "=", "foo").All(db) 122 | rows2, err2 := accounts.Filter("name", "=", "bar").All(db) 123 | 124 | // Register and use a different table with the same model. 125 | rem.Register[Accounts](rem.Config{Table: "groups"}) 126 | groups := rem.Use[Accounts](rem.Config{Table: "groups"}) 127 | ``` 128 | 129 | 130 | ## Migrations 131 | 132 | REM provides the migrations interface as a way to simplify schema and data changes. The interface is just two methods to implement: 133 | 134 | ```go 135 | // github.com/evantbyrne/rem/migrations.go: 136 | type Migration interface { 137 | Down(db *sql.DB) error 138 | Up(db *sql.DB) error 139 | } 140 | ``` 141 | 142 | Models are defined in the same way within migrations as they are in the rest of the application. Here's an example: 143 | 144 | ```go 145 | type Migration0001Accounts struct{} 146 | 147 | func (m Migration0001Accounts) Up(db *sql.DB) error { 148 | // We embed the Accounts model to avoid colliding with the package-level Accounts model used for queries. You could also use `rem.Config` as demonstrated in the Models documentation section. 149 | type Accounts struct { 150 | Id int64 `db:"id" db_primary:"true"` 151 | Name string `db:"name" db_max_length:"100"` 152 | } 153 | 154 | // Note that we don't use rem.Register[To](), because we don't want to cache the model structure used within the migration. 155 | _, err := rem.Use[Accounts]().TableCreate(db) 156 | return err 157 | } 158 | 159 | func (m Migration0001Accounts) Down(db *sql.DB) error { 160 | // Fields aren't needed for dropping a table. 161 | type Accounts struct{} 162 | 163 | _, err := rem.Use[Accounts]().TableDrop(db) 164 | return err 165 | } 166 | ``` 167 | 168 | Then run the migrations: 169 | 170 | ```go 171 | logs, err := rem.MigrateUp(db, []rem.Migration{ 172 | Migration0001Accounts{}, 173 | // More migrations... 174 | }) 175 | // logs []string 176 | // For example: {"Migrating up to Migration0001Accounts..."} 177 | ``` 178 | 179 | REM will create a `migrationlogs` table to track which migrations have been run. Execution of subsequent migrations will stop if an error is returned. Use `rem.MigrateDown(*sql.DB, []rem.Migration)` to run migrations in reverse. 180 | 181 | 182 | ## Fields 183 | 184 | ### Field Types 185 | 186 | REM determines column types based on Go field types. The following table shows the default column types for each Go primative. 187 | 188 | **Note:** REM uses special Go types for nullable columns. Don't use pointers for model fields. 189 | 190 | Go | MySQL | PostgreSQL | SQLite 191 | --- | --- | --- | --- 192 | `bool` | `BOOLEAN` | `BOOLEAN` | `BOOLEAN`\[1\] 193 | `[]byte` | - | - | - 194 | `int8` | `TINYINT` | `SMALLINT` | `INTEGER` 195 | `int16` | `SMALLINT` | `SMALLINT` | `INTEGER` 196 | `int32` | `INTEGER` | `INTEGER` | `INTEGER` 197 | `int64` | `BIGINT` | `BIGINT` | `INTEGER` 198 | `float32` | `FLOAT` | - | `REAL` 199 | `float64` | `DOUBLE` | `DOUBLE PRECISION` | `REAL` 200 | `string` | `VARCHAR`,`TEXT`\[2\] | `VARCHAR`,`TEXT`\[2\] | `TEXT` 201 | `time.Time` | `DATETIME`\[3\] | `TIMESTAMP`\[4\] | `DATETIME` 202 | 203 | \[1\] SQLite `BOOLEAN` behaves as an `INTEGER` internally. The SQLite driver should automatically convert `bool` field values to `0` or `1` when parameterized. 204 | 205 | \[2\] The `VARCHAR` column type is used for `string` and `sql.NullString` fields when the `db_max_length` field tag is provided. Otherwise, `TEXT` is used. 206 | 207 | \[3\] Go's most popular MySQL driver requires adding the `parseTime=true` GET parameter to the connection string to properly scan into `time.Time` and `sql.NullTime` fields. 208 | 209 | \[4\] The PostgreSQL dialect defaults to `WITHOUT TIME ZONE` for time types. Add the `db_time_zone:"true"` field tag to use `WITH TIME ZONE` instead. 210 | 211 | Columns are not nullable by default. REM uses the standard `database/sql` package types to represent nullable columns. 212 | 213 | Not Null | Nullable 214 | --- | --- 215 | `bool` | `sql.NullBool` 216 | `float64` | `sql.NullFloat64` 217 | `int16` | `sql.NullInt16` 218 | `int32` | `sql.NullInt32` 219 | `int64` | `sql.NullInt64` 220 | `rem.ForeignKey[To]` | `rem.NullForeignKey[To]` 221 | `string` | `sql.NullString` 222 | `time.Time` | `sql.NullTime` 223 | 224 | Primary keys are specified with the `db_primary:"true"` field tag. All models must have a primary key. Integer fields that are primary keys will auto-increment. 225 | 226 | ```go 227 | // An auto-incrementing primary key. 228 | type A struct { 229 | Id int64 `db:"id" db_primary:"true"` 230 | } 231 | 232 | // VARCHAR primary key with no default value. 233 | type B struct { 234 | Guid string `db:"guid" db_max_length:"36" db_primary:"true"` 235 | } 236 | ``` 237 | 238 | ### Default 239 | 240 | The `db_default` field tag applies a default value to columns. It accepts any string. 241 | 242 | **Note:** Values provided to `db_default` are not escaped or otherwise sanitized. 243 | 244 | ```go 245 | // This timestamp uses the SQL function now() for its default value. 246 | type Logs struct { 247 | CreatedAt time.Time `db:"created_at" db_default:"now()"` 248 | // ... 249 | } 250 | ``` 251 | 252 | ### Unique 253 | 254 | The `db_unique:"true"` field tag applies a unique constraint to a column. 255 | 256 | ```go 257 | type Accounts struct { 258 | Nickname string `db:"created_at" db_unique:"true"` 259 | // ... 260 | } 261 | ``` 262 | 263 | ### Custom Types 264 | 265 | Custom column types can be set using the `db_type` field tag, which accpets any string value. 266 | 267 | **Note:** Values provided to `db_type` are not escaped or otherwise sanitized. 268 | 269 | ```go 270 | // An example of using PostgreSQL's JSONB type. 271 | type A struct { 272 | Id int64 `db:"id" db_primary:"true"` 273 | Data []byte `db:"data" db_type:"JSONB NOT NULL"` 274 | } 275 | 276 | // db_type takes priority over all other field tags, including primary key typing. 277 | type B struct { 278 | Guid string `db:"guid" db_type:"CHAR(36) NOT NULL" db_primary:"true"` 279 | } 280 | ``` 281 | 282 | Custom Go types may also be used for model fields, but they must implement the `driver.Valuer` and `sql.Scanner` interfaces in additon to being supported by your database driver. 283 | 284 | ### Foreign Keys 285 | 286 | Foreign keys are specified with the `rem.ForeignKey[To]` and `rem.NullForeignKey[To]` field types. REM automatically matches the foreign key column type to the primary key of the target model. 287 | 288 | On the other side of the relation, use `rem.OneToMany[To]`. On both sides of the relation, the `db` field tag signifies the column on the `rem.ForeignKey[To]` side. 289 | 290 | ```go 291 | type Groups struct { 292 | Members rem.OneToMany[Members] `db:"group_id"` 293 | Id int64 `db:"id" db_primary:"true"` 294 | } 295 | 296 | type Members struct { 297 | Group rem.ForeignKey[Groups] `db:"group_id"` 298 | Id int64 `db:"id" db_primary:"true"` 299 | } 300 | ``` 301 | 302 | See [Fetch Related](#fetch-related) for information on querying relationships effeciently. 303 | 304 | Relations may also be queried lazily. 305 | 306 | ```go 307 | // Lazily fetch from a one-to-many field. 308 | group, err := rem.Use[Groups]().Filter("id", "=", 100).First(db) 309 | if err != nil { 310 | panic(err) 311 | } 312 | accounts, err := group.Members.All(db) 313 | // accounts []*Accounts 314 | 315 | // Lazily fetch from a foreign key field. 316 | account, err := rem.Use[Accounts]().Filter("id", "=", 200).First(db) 317 | if err != nil { 318 | panic(err) 319 | } 320 | group, err := account.Group.Fetch(db) 321 | // group *Groups 322 | ``` 323 | 324 | Foreign key `ON DELETE` and `ON UPDATE` constraints, such as `CASCADE` or `SET NULL`, may be set with the `db_on_delete` and `db_on_update` field tags. 325 | 326 | ```go 327 | type Members struct { 328 | Group rem.NullForeignKey[Groups] `db:"group_id" db_on_delete:"SET NULL" db_on_update:"SET NULL"` 329 | // ... 330 | } 331 | ``` 332 | 333 | 334 | ## Reference 335 | 336 | ### All 337 | 338 | Executes a query and returns a list of records. 339 | 340 | ```go 341 | accounts, err := rem.Use[Accounts]().All(db) 342 | // accounts []*Accounts 343 | 344 | accounts, err := rem.Use[Accounts]().AllToMap(db) 345 | // accounts []map[string]interface{} 346 | ``` 347 | 348 | 349 | ### Context 350 | 351 | Pass a Golang context to queries. 352 | 353 | ```go 354 | var ctx context.Context 355 | rem.Use[Accounts]().Context(ctx).All(db) 356 | ``` 357 | 358 | 359 | ### Count 360 | 361 | The `Count` convenience method returns the number of matching records. 362 | 363 | ```go 364 | count, err := rem.Use[Accounts]().Filter("id", "<", 100).Count(db) 365 | // count uint 366 | ``` 367 | 368 | 369 | ### Delete 370 | 371 | The `Delete` convenience method deletes matching records. 372 | 373 | ```go 374 | results, err := rem.Use[Accounts]().Filter("id", "=", 100).Delete(db) 375 | // results sql.Result 376 | ``` 377 | 378 | 379 | ### Dialect 380 | 381 | Set the dialect for a specific query. This takes priority over the default dialect. 382 | 383 | ```go 384 | rem.Use[Accounts]().Dialect(mysqldialect.Dialect{}).All(db) 385 | ``` 386 | 387 | 388 | ### Fetch Related 389 | 390 | REM can optimize foreign key and one-to-many record lookups. This is done with the `FetchRelated` method, which takes any number of strings that represent the relation fields to prefetch. 391 | 392 | Regardless of which side of the relationship you start from or how many records are being fetched initially, REM will only execute one additional query for prefetching. 393 | 394 | ```go 395 | // Model definitions for Groups <->> Accounts relationship. 396 | type Accounts struct { 397 | Group rem.ForeignKey[Groups] `db:"group_id"` 398 | Id int64 `db:"id" db_primary:"true"` 399 | Name string `db:"name" db_max_length:"100"` 400 | } 401 | 402 | type Groups struct { 403 | Accounts rem.OneToMany[Accounts] `db:"group_id"` 404 | Id int64 `db:"id" db_primary:"true"` 405 | Name string `db:"name" db_max_length:"100"` 406 | } 407 | ``` 408 | 409 | ```go 410 | groups, err := rem.Use[Groups]().FetchRelated("Accounts").All(db) 411 | for _, group := range groups { 412 | // group *Groups 413 | // group.Accounts.Rows []*Accounts 414 | } 415 | 416 | accounts, err := rem.Use[Accounts]().FetchRelated("Group").All(db) 417 | for _, account := range accounts { 418 | // account *Accounts 419 | // account.Group.Row *Groups 420 | // account.Group.Valid bool 421 | } 422 | ``` 423 | 424 | 425 | ### Filter 426 | 427 | REM provides a few mechanisms for filtering database results. The most basic is the `Filter` method, which takes a left side value, operator, and right side value. 428 | 429 | Typically, the left side is a column name, which is represented by a `string`. 430 | 431 | The operator is always a `string`. Use uppercase for alphabetical operators such as `"IN"`, `"NOT IN"`, `"IS"`, `"IS NOT"`, `"EXISTS"`, and so on. 432 | 433 | The right side may be any value supported by the database driver for parameterization. 434 | 435 | The left and right sides may also be `rem.DialectStringerWithArgs`, `rem.DialectStringer`, or `rem.SqlUnsafe`. These types are used for more advanced filtering, such as subqueries, joins, or SQL function calls. 436 | 437 | ```go 438 | rem.Use[Accounts]().Filter("id", ">=", 100).All(db) 439 | 440 | // Filters may be chained. This is equivalent to "SELECT * FROM accounts WHERE id >= 100 AND id < 200". 441 | rem.Use[Accounts](). 442 | Filter("id", ">=", 100). 443 | Filter("id", "<", 200). 444 | All(db) 445 | 446 | // Chain filters with an OR using `rem.Q`. This is equivalent to "SELECT * FROM accounts WHERE name = 'foo' OR (id >= 100 AND id < 200"). 447 | rem.Use[Accounts](). 448 | FilterOr( 449 | rem.Q("name", "=", "foo"), 450 | rem.And( 451 | rem.Q("id", ">=", 100), 452 | rem.Q("id", "<", 200), 453 | ), 454 | ). 455 | All(db) 456 | 457 | // Complex chained and nested filters are fully supported. 458 | rem.Use[Accounts](). 459 | FilterAnd( 460 | rem.Q("a", "=", "foo"), 461 | rem.Or( 462 | rem.Q("ab", "=", "bar"), 463 | rem.And( 464 | rem.Q("abc1", ">", 100), 465 | rem.Q("abc2", "<", 200), 466 | ), 467 | ), 468 | ). 469 | FilterOr( 470 | rem.Q("b1", "IS", nil), 471 | rem.Q("b2", "IN", interface{}{10, 20, 30}), 472 | ). 473 | All(db) 474 | ``` 475 | 476 | #### Custom SQL 477 | 478 | Safely parameterized SQL may be embedded via the `rem.Sql()` and `rem.Param()` functions. String arguments to `rem.Sql()` are not escaped or otherwise sanitized. `rem.Param()` arguments are parameterized by the database driver. 479 | 480 | ```go 481 | // SQL: SELECT * FROM logs WHERE data.tags ?| array[$1] 482 | // Parameters: []interface{}{"foo"} 483 | rem.Use[Logs](). 484 | Filter("data.tags", "?|", rem.Sql("array[", rem.Param("foo"), "]")). 485 | All(db) 486 | ``` 487 | 488 | Raw SQL may also be embedded into either the left or right side of a filters via the `rem.Unsafe()` function. 489 | 490 | **Note:** Values provided to `rem.Unsafe()` are not escaped or otherwise sanitized. Only use this function with trusted values. 491 | 492 | ```go 493 | // SQL: SELECT * FROM accounts WHERE upper(name) = $1 494 | // Parameters: []interface{}{"FOO"} 495 | rem.Use[Accounts](). 496 | Filter(rem.Unsafe("upper(name)"), "=", "FOO"). 497 | All(db) 498 | ``` 499 | 500 | #### Subqueries 501 | 502 | REM allows subqueries to be embedded via the standard query syntax. 503 | 504 | ```go 505 | // SQL: SELECT * FROM accounts WHERE id IN (SELECT account_id FROM groups WHERE name = $1) 506 | // Parameters: []interface{}{"Group 1"} 507 | rem.Use[Accounts](). 508 | Filter("id", "IN", rem.Use[Groups]().Select("account_id").Filter("name", "=", "Group 1")). 509 | All(db) 510 | ``` 511 | 512 | The `rem.Exists()` and `rem.NotExists()` functions are provided as a convenience for subqueries that only need to check for the existence of a record. 513 | 514 | `rem.Column()` is also used in the following example to properly handle the column name that is used on the right side of a filter. 515 | 516 | ```go 517 | // SQL: SELECT * FROM groups WHERE EXISTS (SELECT * FROM accounts WHERE accounts.group_id = groups.id) 518 | rem.Use[Groups](). 519 | FilterAnd( 520 | rem.Exists(rem.Use[Accounts]().Filter("accounts.group_id", "=", rem.Column("groups.id"))) 521 | ). 522 | All(db) 523 | ``` 524 | 525 | 526 | ### First 527 | 528 | The `First` convenience method returns a single record. A `sql.ErrNoRows` error is returned if no matching records are found. 529 | 530 | ```go 531 | account, err := rem.Use[Accounts]().Filter("id", "=", 1).First(db) 532 | // account *Accounts 533 | 534 | account, err := rem.Use[Accounts]().Filter("id", "=", 1).FirstToMap(db) 535 | // account map[string]interface{} 536 | ``` 537 | 538 | 539 | ### Insert 540 | 541 | The `Insert` method adds new records to the database. 542 | 543 | The first argument is a `*sql.DB` instance. 544 | 545 | The second argument is a pointer to the new record. 546 | 547 | **Note:** Zero-valued primary keys aren't included in inserts via the `Insert` method. 548 | 549 | ```go 550 | account := &Accounts{ 551 | Name: "New Name", 552 | } 553 | 554 | results, err := rem.Use[Accounts]().Insert(db, account) 555 | // results sql.Result 556 | ``` 557 | 558 | REM also provides a `UpdateMap` convenience method that updates matching records with all columns provided by a `map[string]interface{}`. 559 | 560 | **Note:** Zero-valued primary keys **will** be included when provided to inserts via the `InsertMap` method. 561 | 562 | ```go 563 | account := map[string]interface{}{ 564 | "name": "New Name", 565 | } 566 | 567 | results, err := rem.Use[Accounts]().InsertMap(db, account) 568 | ``` 569 | 570 | 571 | ### Join 572 | 573 | The `Join`, `JoinFull`, `JoinLeft`, and `JoinRight` methods are for performing their respective types of SQL joins. 574 | 575 | The first argument is the table to join. 576 | 577 | The second argument takes any number of filters to join on. 578 | 579 | ```go 580 | rows, err := rem.Use[Accounts](). 581 | Select("accounts.id", "accounts.name", rem.As("groups.name", "group_name")). 582 | Join("groups", rem.Q("groups.id", "=", rem.Column("accounts.group_id"))). 583 | AllToMap(db) 584 | 585 | // Use a custom model. 586 | type AccountsWithGroupName struct { 587 | GroupName string `db:"group_name"` 588 | Id string `db:"id" db_primary:"true"` 589 | Name string `db:"name"` 590 | } 591 | 592 | rows, err := rem.Use[AccountsWithGroupName](rem.Config{Table: "accounts"}). 593 | Select(rem.As("accounts.id", "id"), rem.As("accounts.name", "name"), rem.As("groups.name", "group_name")). 594 | Join("groups", rem.Q("groups.id", "=", rem.Column("accounts.group_id"))). 595 | All(db) 596 | 597 | // Use Query() to join without selecting columns. 598 | rows, err := rem.Use[Accounts](). 599 | Query(). 600 | JoinFull("groups", rem.Or( 601 | rem.Q("groups.id", "IS", nil), 602 | rem.Q("groups.id", "=", rem.Column("accounts.group_id")), 603 | ). 604 | AllToMap(db) 605 | ``` 606 | 607 | 608 | ### Limit and Offset 609 | 610 | The `Limit` and `Offset` methods both take a single `int64` argument. 611 | 612 | ```go 613 | // LIMIT 10 614 | rem.Use[Accounts]().Limit(10).All(db) 615 | 616 | // LIMIT 10 OFFSET 20 617 | rem.Use[Accounts]().Limit(10).Offset(20).All(db) 618 | ``` 619 | 620 | 621 | ### Scan Map 622 | 623 | The `ScanMap` convenience method converts a `map[string]interface{}` into a model pointer. 624 | 625 | ```go 626 | data := map[string]interface{}{ 627 | "id": 100, 628 | "name": "New Name", 629 | } 630 | 631 | account, err := rem.Use[Accounts].ScanMap(data) 632 | // account *Accounts 633 | ``` 634 | 635 | 636 | ### Select 637 | 638 | By default, queries scans all columns on the model. The `Select` method takes any number of strings, which when present, represent the only columns to scan. It also accepts `rem.DialectStringer`, and `rem.SqlUnsafe` values for special cases. 639 | 640 | ```go 641 | // SELECT id FROM accounts 642 | rem.Use[Accounts]().Select("id").All(db) 643 | 644 | // SELECT id, UPPER(name) as name FROM accounts 645 | rem.Use[Accounts]().Select("id", rem.Unsafe("UPPER(name) as name")).All(db) 646 | ``` 647 | 648 | 649 | ### Sort 650 | 651 | The `Sort` method takes any number of strings, which represent columns. Using `-` as a prefix will sort in descending order. 652 | 653 | ```go 654 | // ORDER BY name ASC 655 | rem.Use[Accounts]().Sort("name").All(db) 656 | 657 | // ORDER BY name DESC 658 | rem.Use[Accounts]().Sort("-name").All(db) 659 | 660 | // ORDER BY name ASC, id DESC 661 | rem.Use[Accounts]().Sort("name", "-id").All(db) 662 | ``` 663 | 664 | 665 | ### SQL All 666 | 667 | Executes a raw SQL query with parameters and returns a list of records. 668 | 669 | ```go 670 | accounts, err := rem.Use[Accounts]().SqlAll(db, "select * from accounts where id >= ?", 100) 671 | // accounts []*Accounts 672 | 673 | accounts, err := rem.Use[Accounts]().SqlAllToMap(db, "select * from accounts where id >= ?", 100) 674 | // accounts []map[string]interface{} 675 | ``` 676 | 677 | 678 | ### Table Column Add 679 | 680 | The `TableColumnAdd` method adds a column to a table. A field must exist in the model struct for the column to be added. 681 | 682 | ```go 683 | type Accounts struct { 684 | Id int64 `db:"id" db_primary:"true"` 685 | Name string `db:"name"` 686 | IsAdmin bool `db:"is_admin"` 687 | } 688 | 689 | _, err := rem.Use[Accounts]().TableColumnAdd(db, "is_admin") 690 | ``` 691 | 692 | 693 | ### Table Column Drop 694 | 695 | The `TableColumnDrop` method drops a column to a table. 696 | 697 | ```go 698 | _, err := rem.Use[Accounts]().TableColumnDrop(db, "is_admin") 699 | ``` 700 | 701 | 702 | ### Table Create 703 | 704 | The `TableCreate` method creates a table for the model. 705 | 706 | ```go 707 | _, err := rem.Use[Accounts]().TableCreate(db) 708 | 709 | // Override the table name. 710 | _, err := rem.Use[Accounts](rem.Config{Table: "users"}).TableCreate(db) 711 | 712 | // Only create the table if it doesn't exist. 713 | _, err := rem.Use[Accounts]().TableCreate(db, rem.TableCreateConfig{IfNotExists: true}) 714 | ``` 715 | 716 | 717 | ### Table Drop 718 | 719 | The `TableDrop` method drops a table for the model. 720 | 721 | ```go 722 | _, err := rem.Use[Accounts]().TableDrop(db) 723 | 724 | // Override the table name. 725 | _, err := rem.Use[Accounts](rem.Config{Table: "users"}).TableDrop(db) 726 | 727 | // Only drop the table if it exists. 728 | _, err := rem.Use[Accounts]().TableDrop(db, rem.TableDropConfig{IfExists: true}) 729 | ``` 730 | 731 | 732 | ### To Map 733 | 734 | The `ToMap` convenience method converts a model pointer into a `map[string]interface{}`. Keys on the returned map are column names. 735 | 736 | **Note:** Zero-valued primary keys are excluded from the returned map. 737 | 738 | **Note:** Fields that implement the `driver.Valuer` interface are converted to their `Value()` representation. For example, a `sql.NullString` will be converted to either `string` or `nil`. 739 | 740 | ```go 741 | account := &Accounts{ 742 | Id: 100, 743 | Name: "New Name", 744 | } 745 | 746 | data := rem.Use[Accounts]().ToMap(account) 747 | // data map[string]interface{} 748 | ``` 749 | 750 | 751 | ### Transaction 752 | 753 | REM supports transactions via the `Transaction(*sql.Tx)` method. 754 | 755 | ```go 756 | tx, _ := db.Begin() 757 | 758 | _, err := rem.Use[Accounts](). 759 | Filter("id", "=", 100). 760 | Transaction(tx). 761 | Delete(db) 762 | 763 | if err != nil { 764 | tx.Rollback() 765 | panic(err) 766 | } 767 | 768 | err = tx.Commit() 769 | if err != nil { 770 | panic(err) 771 | } 772 | ``` 773 | 774 | 775 | ### Update 776 | 777 | The `Update` method updates matching records. 778 | 779 | The first argument is a `*sql.DB` instance. 780 | 781 | The second argument is a pointer to the updated record. 782 | 783 | The third argument is a spread of columns to update. If no columns are provided, the update will fail. This minor annoyance is by design and is intended to ensure that column updates are intentional. 784 | 785 | ```go 786 | account := &Accounts{ 787 | Id: 200, 788 | Name: "New Name", 789 | } 790 | 791 | // The `name` column will be updated, but `id` won't. 792 | results, err := rem.Use[Accounts](). 793 | Filter("id", "=", 100). 794 | Update(db, account, "name") 795 | 796 | // results sql.Result 797 | ``` 798 | 799 | REM also provides a `UpdateMap` convenience method that updates matching records with all columns provided by a `map[string]interface{}`. 800 | 801 | ```go 802 | account := map[string]interface{}{ 803 | "name": "New Name", 804 | } 805 | 806 | results, err := rem.Use[Accounts](). 807 | Filter("id", "=", 100). 808 | UpdateMap(db, account) 809 | ``` 810 | -------------------------------------------------------------------------------- /dialect.go: -------------------------------------------------------------------------------- 1 | package rem 2 | 3 | import ( 4 | "reflect" 5 | ) 6 | 7 | type Dialect interface { 8 | BuildDelete(QueryConfig) (string, []interface{}, error) 9 | BuildInsert(QueryConfig, map[string]interface{}, ...string) (string, []interface{}, error) 10 | BuildSelect(QueryConfig) (string, []interface{}, error) 11 | BuildTableColumnAdd(QueryConfig, string) (string, error) 12 | BuildTableColumnDrop(QueryConfig, string) (string, error) 13 | BuildTableCreate(QueryConfig, TableCreateConfig) (string, error) 14 | BuildTableDrop(QueryConfig, TableDropConfig) (string, error) 15 | BuildUpdate(QueryConfig, map[string]interface{}, ...string) (string, []interface{}, error) 16 | ColumnType(reflect.StructField) (string, error) 17 | Param(i int) string 18 | QuoteIdentifier(string) string 19 | } 20 | 21 | type DialectStringer interface { 22 | StringForDialect(Dialect) string 23 | } 24 | 25 | type DialectStringerWithArgs interface { 26 | StringWithArgs(Dialect, []interface{}) (string, []interface{}, error) 27 | } 28 | 29 | var defaultDialect Dialect 30 | 31 | func SetDialect(dialect Dialect) { 32 | defaultDialect = dialect 33 | } 34 | -------------------------------------------------------------------------------- /dialect_test.go: -------------------------------------------------------------------------------- 1 | package rem 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | ) 7 | 8 | //lint:file-ignore U1000 Ignore report 9 | type testDialect struct{} 10 | 11 | func (dialect testDialect) BuildDelete(QueryConfig) (string, []interface{}, error) { 12 | panic("Not implemented") 13 | } 14 | 15 | func (dialect testDialect) BuildInsert(QueryConfig, map[string]interface{}, ...string) (string, []interface{}, error) { 16 | panic("Not implemented") 17 | } 18 | 19 | func (dialect testDialect) BuildSelect(config QueryConfig) (string, []interface{}, error) { 20 | return fmt.Sprintf("SELECT|FILTER%+v|", config.Filters), nil, nil 21 | } 22 | 23 | func (dialect testDialect) BuildTableColumnAdd(QueryConfig, string) (string, error) { 24 | panic("Not implemented") 25 | } 26 | 27 | func (dialect testDialect) BuildTableColumnDrop(QueryConfig, string) (string, error) { 28 | panic("Not implemented") 29 | } 30 | 31 | func (dialect testDialect) BuildTableCreate(QueryConfig, TableCreateConfig) (string, error) { 32 | panic("Not implemented") 33 | } 34 | 35 | func (dialect testDialect) BuildTableDrop(QueryConfig, TableDropConfig) (string, error) { 36 | panic("Not implemented") 37 | } 38 | 39 | func (dialect testDialect) BuildUpdate(QueryConfig, map[string]interface{}, ...string) (string, []interface{}, error) { 40 | panic("Not implemented") 41 | } 42 | 43 | func (dialect testDialect) ColumnType(reflect.StructField) (string, error) { 44 | panic("Not implemented") 45 | } 46 | 47 | func (dialect testDialect) Param(identifier int) string { 48 | return fmt.Sprintf("$%d", identifier) 49 | } 50 | 51 | func (dialect testDialect) QuoteIdentifier(identifier string) string { 52 | return fmt.Sprintf(`"%s"`, identifier) 53 | } 54 | -------------------------------------------------------------------------------- /filters.go: -------------------------------------------------------------------------------- 1 | package rem 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | var filterOperators = map[string]struct{}{ 9 | "=": {}, 10 | "!=": {}, 11 | "<>": {}, 12 | "<": {}, 13 | ">": {}, 14 | "<=": {}, 15 | ">=": {}, 16 | "LIKE": {}, 17 | "NOT LIKE": {}, 18 | "IN": {}, 19 | "NOT IN": {}, 20 | "IS": {}, 21 | "IS NOT": {}, 22 | "ALL": {}, 23 | "<> ALL": {}, 24 | "ANY": {}, 25 | "<> ANY": {}, 26 | "EXISTS": {}, 27 | "NOT EXISTS": {}, 28 | "OVERLAPS": {}, 29 | "?": {}, 30 | "?&": {}, 31 | "?|": {}, 32 | "@>": {}, 33 | "<@": {}, 34 | } 35 | 36 | type FilterClause struct { 37 | Left interface{} 38 | Operator string 39 | Right interface{} 40 | Rule string 41 | } 42 | 43 | func (filter FilterClause) leftString(dialect Dialect, args []interface{}) ([]interface{}, string, error) { 44 | switch left := filter.Left.(type) { 45 | case string: 46 | return args, dialect.QuoteIdentifier(left), nil 47 | 48 | case DialectStringerWithArgs: 49 | lv, args, err := left.StringWithArgs(dialect, args) 50 | return args, lv, err 51 | 52 | case DialectStringer: 53 | return args, left.StringForDialect(dialect), nil 54 | 55 | case SqlUnsafe: 56 | return args, left.Sql, nil 57 | } 58 | 59 | return nil, "", fmt.Errorf("rem: unsupported type for left side of filter clause '%#v'", filter.Left) 60 | } 61 | 62 | func (filter FilterClause) rightString(dialect Dialect, args []interface{}) ([]interface{}, string, error) { 63 | switch right := filter.Right.(type) { 64 | case DialectStringerWithArgs: 65 | rv, args, err := right.StringWithArgs(dialect, args) 66 | return args, rv, err 67 | 68 | case DialectStringer: 69 | return args, right.StringForDialect(dialect), nil 70 | 71 | case SqlUnsafe: 72 | return args, right.Sql, nil 73 | 74 | case nil: 75 | return args, "NULL", nil 76 | 77 | case []interface{}: 78 | var sliceArgs strings.Builder 79 | for j, arg := range right { 80 | args = append(args, arg) 81 | if j > 0 { 82 | sliceArgs.WriteString(",") 83 | } 84 | sliceArgs.WriteString(dialect.Param(len(args))) 85 | } 86 | return args, sliceArgs.String(), nil 87 | 88 | default: 89 | args = append(args, right) 90 | return args, dialect.Param(len(args)), nil 91 | } 92 | } 93 | 94 | func (filter FilterClause) StringWithArgs(dialect Dialect, args []interface{}) (string, []interface{}, error) { 95 | switch filter.Rule { 96 | case "(": 97 | return " (", args, nil 98 | 99 | case ")": 100 | return " )", args, nil 101 | 102 | case "AND": 103 | return " AND", args, nil 104 | 105 | case "NOT": 106 | return " NOT", args, nil 107 | 108 | case "OR": 109 | return " OR", args, nil 110 | 111 | case "WHERE": 112 | if _, ok := filterOperators[filter.Operator]; !ok { 113 | return "", nil, fmt.Errorf("rem: invalid operator '%s' on WHERE clause", filter.Operator) 114 | } 115 | 116 | var err error 117 | var left string 118 | args, left, err = filter.leftString(dialect, args) 119 | if err != nil { 120 | return "", nil, err 121 | } 122 | 123 | var right string 124 | args, right, err = filter.rightString(dialect, args) 125 | if err != nil { 126 | return "", nil, err 127 | } 128 | 129 | if filter.Operator == "EXISTS" || filter.Operator == "NOT EXISTS" { 130 | return fmt.Sprintf(" %s (%s)", filter.Operator, right), args, nil 131 | } else if filter.Operator == "IN" || filter.Operator == "NOT IN" || filter.Operator == "ALL" || filter.Operator == "<> ALL" || filter.Operator == "ANY" || filter.Operator == "<> ANY" { 132 | return fmt.Sprintf(" %s %s (%s)", left, filter.Operator, right), args, nil 133 | } else if filter.Operator == "?&" || filter.Operator == "?|" { 134 | switch filter.Right.(type) { 135 | case DialectStringerWithArgs, SqlUnsafe: 136 | break 137 | default: 138 | return fmt.Sprintf(" %s %s array[%s]", left, filter.Operator, right), args, nil 139 | } 140 | } 141 | return fmt.Sprintf(" %s %s %s", left, filter.Operator, right), args, nil 142 | } 143 | 144 | return "", args, fmt.Errorf("rem: invalid rule '%s' on WHERE clause", filter.Rule) 145 | } 146 | 147 | func And(clauses ...interface{}) []FilterClause { 148 | flat := make([]FilterClause, 0) 149 | for _, clause := range clauses { 150 | flat = flattenFilterClause(flat, clause) 151 | } 152 | 153 | indent := 0 154 | filter := []FilterClause{{Rule: "("}} 155 | for i, clause := range flat { 156 | if i > 0 && indent == 0 && flat[i-1].Rule != "NOT" { 157 | filter = append(filter, FilterClause{Rule: "AND"}) 158 | } 159 | if clause.Rule == "(" { 160 | indent++ 161 | } else if clause.Rule == ")" { 162 | indent-- 163 | } 164 | filter = append(filter, clause) 165 | } 166 | return append(filter, FilterClause{Rule: ")"}) 167 | } 168 | 169 | func Exists(value interface{}) FilterClause { 170 | return FilterClause{ 171 | Left: "", 172 | Operator: "EXISTS", 173 | Right: value, 174 | Rule: "WHERE", 175 | } 176 | } 177 | 178 | func flattenFilterClause(clauses []FilterClause, clause interface{}) []FilterClause { 179 | switch ct := clause.(type) { 180 | case FilterClause: 181 | clauses = append(clauses, ct) 182 | case []FilterClause: 183 | clauses = append(clauses, ct...) 184 | } 185 | return clauses 186 | } 187 | 188 | func Not(column interface{}, operator string, value interface{}) []FilterClause { 189 | return []FilterClause{ 190 | {Rule: "NOT"}, 191 | {Rule: "("}, 192 | Q(column, operator, value), 193 | {Rule: ")"}, 194 | } 195 | } 196 | 197 | func NotExists(value interface{}) FilterClause { 198 | return FilterClause{ 199 | Left: "", 200 | Operator: "NOT EXISTS", 201 | Right: value, 202 | Rule: "WHERE", 203 | } 204 | } 205 | 206 | func Or(clauses ...interface{}) []FilterClause { 207 | flat := make([]FilterClause, 0) 208 | for _, clause := range clauses { 209 | flat = flattenFilterClause(flat, clause) 210 | } 211 | 212 | indent := 0 213 | filter := []FilterClause{{Rule: "("}} 214 | for i, clause := range flat { 215 | if i > 0 && indent == 0 && flat[i-1].Rule != "NOT" { 216 | filter = append(filter, FilterClause{Rule: "OR"}) 217 | } 218 | if clause.Rule == "(" { 219 | indent++ 220 | } else if clause.Rule == ")" { 221 | indent-- 222 | } 223 | filter = append(filter, clause) 224 | } 225 | return append(filter, FilterClause{Rule: ")"}) 226 | } 227 | 228 | func Q(column interface{}, operator string, value interface{}) FilterClause { 229 | return FilterClause{ 230 | Left: column, 231 | Operator: operator, 232 | Right: value, 233 | Rule: "WHERE", 234 | } 235 | } 236 | -------------------------------------------------------------------------------- /filters_test.go: -------------------------------------------------------------------------------- 1 | package rem 2 | 3 | import ( 4 | "testing" 5 | 6 | "golang.org/x/exp/slices" 7 | ) 8 | 9 | func TestFilterAnd(t *testing.T) { 10 | clauses := And( 11 | "SKIP", 12 | FilterClause{Rule: "A"}, 13 | FilterClause{Rule: "B"}, 14 | And( 15 | FilterClause{Rule: "C.1"}, 16 | FilterClause{Rule: "C.2"}, 17 | ), 18 | ) 19 | 20 | expected := []FilterClause{ 21 | {Rule: "("}, 22 | {Rule: "A"}, 23 | {Rule: "AND"}, 24 | {Rule: "B"}, 25 | {Rule: "AND"}, 26 | {Rule: "("}, 27 | {Rule: "C.1"}, 28 | {Rule: "AND"}, 29 | {Rule: "C.2"}, 30 | {Rule: ")"}, 31 | {Rule: ")"}, 32 | } 33 | if !slices.Equal(clauses, expected) { 34 | t.Errorf("Expected '%+v', got '%+v'", expected, clauses) 35 | } 36 | } 37 | 38 | func TestFilterNot(t *testing.T) { 39 | clauses := And( 40 | "SKIP", 41 | FilterClause{Rule: "A"}, 42 | Not("B", "=", "foo"), 43 | Or( 44 | Not("C.1", "=", "bar"), 45 | FilterClause{Rule: "C.2"}, 46 | ), 47 | ) 48 | 49 | expected := []FilterClause{ 50 | {Rule: "("}, 51 | {Rule: "A"}, 52 | {Rule: "AND"}, 53 | {Rule: "NOT"}, 54 | {Rule: "("}, 55 | {Left: "B", Operator: "=", Right: "foo", Rule: "WHERE"}, 56 | {Rule: ")"}, 57 | {Rule: "AND"}, 58 | {Rule: "("}, 59 | {Rule: "NOT"}, 60 | {Rule: "("}, 61 | {Left: "C.1", Operator: "=", Right: "bar", Rule: "WHERE"}, 62 | {Rule: ")"}, 63 | {Rule: "OR"}, 64 | {Rule: "C.2"}, 65 | {Rule: ")"}, 66 | {Rule: ")"}, 67 | } 68 | if !slices.Equal(clauses, expected) { 69 | t.Errorf("Expected '%+v', got '%+v'", expected, clauses) 70 | } 71 | } 72 | 73 | func TestFilterOr(t *testing.T) { 74 | clauses := Or( 75 | FilterClause{Rule: "A"}, 76 | "SKIP", 77 | Or( 78 | FilterClause{Rule: "B.1"}, 79 | FilterClause{Rule: "B.2"}, 80 | ), 81 | FilterClause{Rule: "C"}, 82 | ) 83 | 84 | expected := []FilterClause{ 85 | {Rule: "("}, 86 | {Rule: "A"}, 87 | {Rule: "OR"}, 88 | {Rule: "("}, 89 | {Rule: "B.1"}, 90 | {Rule: "OR"}, 91 | {Rule: "B.2"}, 92 | {Rule: ")"}, 93 | {Rule: "OR"}, 94 | {Rule: "C"}, 95 | {Rule: ")"}, 96 | } 97 | if !slices.Equal(clauses, expected) { 98 | t.Errorf("Expected '%+v', got '%+v'", expected, clauses) 99 | } 100 | } 101 | 102 | func TestFlattenFilterClause(t *testing.T) { 103 | clauses := []interface{}{ 104 | FilterClause{Rule: "A"}, 105 | FilterClause{Rule: "B"}, 106 | []FilterClause{ 107 | {Rule: "C.1"}, 108 | {Rule: "C.2"}, 109 | }, 110 | "SKIP", 111 | FilterClause{Rule: "D"}, 112 | } 113 | expected := []FilterClause{ 114 | {Rule: "Z"}, 115 | {Rule: "A"}, 116 | {Rule: "B"}, 117 | {Rule: "C.1"}, 118 | {Rule: "C.2"}, 119 | {Rule: "D"}, 120 | } 121 | flat := []FilterClause{ 122 | {Rule: "Z"}, 123 | } 124 | for _, clause := range clauses { 125 | flat = flattenFilterClause(flat, clause) 126 | } 127 | if !slices.Equal(flat, expected) { 128 | t.Errorf("Expected '%+v', got '%+v'", expected, flat) 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /foreign_key.go: -------------------------------------------------------------------------------- 1 | package rem 2 | 3 | import ( 4 | "database/sql" 5 | "encoding/json" 6 | "reflect" 7 | ) 8 | 9 | type ForeignKey[To any] struct { 10 | Row *To 11 | Valid bool 12 | } 13 | 14 | func (fk *ForeignKey[To]) Fetch(db *sql.DB) (*To, error) { 15 | query := &Query[To]{ 16 | Model: fk.Model(), 17 | } 18 | value := reflect.ValueOf(&fk.Row).Elem() 19 | id := value.FieldByName(query.Model.PrimaryField).Interface() 20 | return query.Filter("id", "=", id).First(db) 21 | } 22 | 23 | func (fk ForeignKey[To]) JsonValue() interface{} { 24 | if !fk.Valid { 25 | return nil 26 | } 27 | return fk.Model().ToJsonMap(fk.Row) 28 | } 29 | 30 | func (fk ForeignKey[To]) MarshalJSON() ([]byte, error) { 31 | if !fk.Valid { 32 | return json.Marshal(nil) 33 | } 34 | return json.Marshal(fk.Model().ToJsonMap(fk.Row)) 35 | } 36 | 37 | func (fk *ForeignKey[To]) Model() *Model[To] { 38 | if fk.Row == nil { 39 | var zero To 40 | fk.Row = &zero 41 | } 42 | return Use[To]() 43 | } 44 | 45 | func (fk *ForeignKey[To]) Query() *Query[To] { 46 | return &Query[To]{ 47 | Model: Use[To](), 48 | } 49 | } 50 | 51 | type NullForeignKey[To any] struct { 52 | Row *To 53 | Valid bool 54 | } 55 | 56 | func (fk *NullForeignKey[To]) Fetch(db *sql.DB) (*To, error) { 57 | query := &Query[To]{ 58 | Model: Use[To](), 59 | } 60 | value := reflect.ValueOf(&fk.Row).Elem() 61 | id := value.FieldByName(query.Model.PrimaryField).Interface() 62 | return query.Filter("id", "=", id).First(db) 63 | } 64 | 65 | func (fk NullForeignKey[To]) JsonValue() interface{} { 66 | if !fk.Valid { 67 | return nil 68 | } 69 | return fk.Model().ToJsonMap(fk.Row) 70 | } 71 | 72 | func (fk NullForeignKey[To]) MarshalJSON() ([]byte, error) { 73 | if !fk.Valid { 74 | return json.Marshal(nil) 75 | } 76 | return json.Marshal(fk.Model().ToJsonMap(fk.Row)) 77 | } 78 | 79 | func (fk *NullForeignKey[To]) Model() *Model[To] { 80 | if fk.Row == nil { 81 | var zero To 82 | fk.Row = &zero 83 | } 84 | return Use[To]() 85 | } 86 | 87 | func (fk *NullForeignKey[To]) Query() *Query[To] { 88 | return &Query[To]{ 89 | Model: Use[To](), 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/evantbyrne/rem 2 | 3 | go 1.21 4 | 5 | require golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 6 | 7 | require github.com/DATA-DOG/go-sqlmock v1.5.0 8 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= 2 | github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= 3 | golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 h1:MGwJjxBy0HJshjDNfLsYO8xppfqWlA5ZT9OhtUUhTNw= 4 | golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= 5 | -------------------------------------------------------------------------------- /migration.go: -------------------------------------------------------------------------------- 1 | package rem 2 | 3 | import ( 4 | "database/sql" 5 | "errors" 6 | "fmt" 7 | "reflect" 8 | "time" 9 | ) 10 | 11 | type Migration interface { 12 | Down(db *sql.DB) error 13 | Up(db *sql.DB) error 14 | } 15 | 16 | type MigrationLogs struct { 17 | CreatedAt time.Time `db:"created_at"` 18 | Direction string `db:"direction" db_max_length:"10"` 19 | Id int64 `db:"id" db_primary:"true"` 20 | MigrationType string `db:"migration_type" db_max_length:"255"` 21 | } 22 | 23 | func MigrateDown(db *sql.DB, migrations []Migration) ([]string, error) { 24 | logs, latestIndex, err := migrateSetup(db, migrations) 25 | if err != nil { 26 | return logs, err 27 | } 28 | migrationLogs := Use[MigrationLogs]() 29 | 30 | for i := latestIndex; i > -1; i-- { 31 | migrationType := reflect.TypeOf(migrations[i]).String() 32 | logs = append(logs, "Migrating down to "+migrationType+"...") 33 | if err := migrations[i].Down(db); err != nil { 34 | return logs, errors.Join(fmt.Errorf("migration %s: failed", migrationType), err) 35 | } 36 | _, err := migrationLogs.Insert(db, &MigrationLogs{ 37 | CreatedAt: time.Now(), 38 | Direction: "down", 39 | MigrationType: migrationType, 40 | }) 41 | if err != nil { 42 | return logs, errors.Join(fmt.Errorf("migration %s: failed to insert migration logs", migrationType), err) 43 | } 44 | } 45 | 46 | return logs, nil 47 | } 48 | 49 | func migrateSetup(db *sql.DB, migrations []Migration) ([]string, int, error) { 50 | logs := make([]string, 0) 51 | migrationLogs := Use[MigrationLogs]() 52 | 53 | _, err := migrationLogs.TableCreate(db, TableCreateConfig{IfNotExists: true}) 54 | if err != nil { 55 | return nil, -1, errors.Join(errors.New("rem: migrations setup: failed to create table for migration logs"), err) 56 | } 57 | 58 | latest, err := migrationLogs.Sort("-id").First(db) 59 | latestIndex := -1 60 | if err != nil { 61 | if !errors.Is(err, sql.ErrNoRows) { 62 | return nil, -1, errors.Join(errors.New("rem: migrations setup: failed to get migrations list"), err) 63 | } 64 | } else { 65 | for i, migration := range migrations { 66 | if latest.MigrationType == reflect.TypeOf(migration).String() { 67 | if latest.Direction == "down" { 68 | latestIndex = i - 1 69 | } else { 70 | latestIndex = i 71 | } 72 | break 73 | } 74 | } 75 | } 76 | 77 | return logs, latestIndex, nil 78 | } 79 | 80 | func MigrateUp(db *sql.DB, migrations []Migration) ([]string, error) { 81 | logs, latestIndex, err := migrateSetup(db, migrations) 82 | if err != nil { 83 | return logs, err 84 | } 85 | migrationLogs := Use[MigrationLogs]() 86 | 87 | for i := latestIndex + 1; i < len(migrations); i++ { 88 | migrationType := reflect.TypeOf(migrations[i]).String() 89 | logs = append(logs, "Migrating up to "+migrationType+"...") 90 | if err := migrations[i].Up(db); err != nil { 91 | return logs, errors.Join(fmt.Errorf("rem: migration %s: failed", migrationType), err) 92 | } 93 | _, err := migrationLogs.Insert(db, &MigrationLogs{ 94 | CreatedAt: time.Now(), 95 | Direction: "up", 96 | MigrationType: migrationType, 97 | }) 98 | if err != nil { 99 | return logs, errors.Join(fmt.Errorf("rem: migration %s: failed to insert migration logs", migrationType), err) 100 | } 101 | } 102 | 103 | return logs, nil 104 | } 105 | -------------------------------------------------------------------------------- /model.go: -------------------------------------------------------------------------------- 1 | package rem 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "database/sql/driver" 7 | "fmt" 8 | "reflect" 9 | "strings" 10 | "time" 11 | ) 12 | 13 | type Config struct { 14 | Table string 15 | } 16 | 17 | type JsonValuer interface { 18 | JsonValue() interface{} 19 | } 20 | 21 | type Model[T any] struct { 22 | Fields map[string]reflect.StructField 23 | PrimaryColumn string 24 | PrimaryField string 25 | Table string 26 | Type reflect.Type 27 | } 28 | 29 | func (model *Model[T]) All(db *sql.DB) ([]*T, error) { 30 | query := &Query[T]{Model: model} 31 | return query.All(db) 32 | } 33 | 34 | func (model *Model[T]) AllToMap(db *sql.DB) ([]map[string]interface{}, error) { 35 | query := &Query[T]{Model: model} 36 | return query.AllToMap(db) 37 | } 38 | 39 | func (model *Model[T]) Context(context context.Context) *Query[T] { 40 | return &Query[T]{ 41 | Config: QueryConfig{Context: context}, 42 | Model: model, 43 | } 44 | } 45 | 46 | func (model *Model[T]) Count(db *sql.DB) (uint, error) { 47 | query := &Query[T]{Model: model} 48 | return query.Count(db) 49 | } 50 | 51 | func (model *Model[T]) Dialect(dialect Dialect) *Query[T] { 52 | return &Query[T]{ 53 | dialect: dialect, 54 | Model: model, 55 | } 56 | } 57 | 58 | func (model *Model[T]) FetchRelated(columns ...string) *Query[T] { 59 | return &Query[T]{ 60 | Config: QueryConfig{FetchRelated: columns}, 61 | Model: model, 62 | } 63 | } 64 | 65 | func (model *Model[T]) Filter(column interface{}, operator string, value interface{}) *Query[T] { 66 | query := &Query[T]{Model: model} 67 | return query.Filter(column, operator, value) 68 | } 69 | 70 | func (model *Model[T]) FilterAnd(clauses ...interface{}) *Query[T] { 71 | query := &Query[T]{Model: model} 72 | return query.FilterAnd(clauses...) 73 | } 74 | 75 | func (model *Model[T]) FilterOr(clauses ...interface{}) *Query[T] { 76 | query := &Query[T]{Model: model} 77 | return query.FilterOr(clauses...) 78 | } 79 | 80 | func (model *Model[T]) Insert(db *sql.DB, row *T) (sql.Result, error) { 81 | query := &Query[T]{Model: model} 82 | return query.Insert(db, row) 83 | } 84 | 85 | func (model *Model[T]) InsertMap(db *sql.DB, data map[string]interface{}) (sql.Result, error) { 86 | query := &Query[T]{Model: model} 87 | return query.InsertMap(db, data) 88 | } 89 | 90 | func (model *Model[T]) Query() *Query[T] { 91 | return &Query[T]{Model: model} 92 | } 93 | 94 | func (model *Model[T]) Scan(rows *sql.Rows) (*T, error) { 95 | data, err := model.ScanToMap(rows) 96 | if err != nil { 97 | return nil, err 98 | } 99 | return model.ScanMap(data) 100 | } 101 | 102 | func (model *Model[T]) ScanMap(data map[string]interface{}) (*T, error) { 103 | var row T 104 | value := reflect.ValueOf(&row).Elem() 105 | 106 | for column, v := range data { 107 | if field, ok := model.Fields[column]; ok { 108 | if field := value.FieldByName(field.Name); field.IsValid() { 109 | columnValue := reflect.ValueOf(v) 110 | 111 | if v == nil { 112 | // database/sql null types (NullString, etc) default to `Valid: false`. 113 | // rem.ForeignKey and rem.NullForeignKey also follow this convention. 114 | 115 | } else if columnValue.CanConvert(field.Type()) { 116 | field.Set(columnValue) 117 | 118 | } else if field.Kind() == reflect.Struct { 119 | if scanner, ok := reflect.New(field.Type()).Interface().(sql.Scanner); ok { 120 | scanner.Scan(v) 121 | field.Set(reflect.ValueOf(scanner).Elem()) 122 | 123 | } else if strings.HasPrefix(field.Type().String(), "rem.ForeignKey[") || strings.HasPrefix(field.Type().String(), "rem.NullForeignKey[") { 124 | subModelQ := field.Addr().MethodByName("Model").Call(nil) 125 | subPrimaryField := reflect.Indirect(subModelQ[0]).FieldByName("PrimaryField").Interface().(string) 126 | subField := field.FieldByName("Row").Elem().FieldByName(subPrimaryField) 127 | if subField.IsValid() { 128 | // TODO: Handle primary keys that are nullable types 129 | subField.Set(columnValue) 130 | field.FieldByName("Valid").SetBool(true) 131 | } 132 | } else { 133 | return nil, fmt.Errorf("rem: unhandled struct conversion in scan from '%s' to '%s'", columnValue.Type(), field.Type()) 134 | } 135 | 136 | } else { 137 | return nil, fmt.Errorf("rem: unhandled type conversion in scan from '%s' to '%s'", columnValue.Type(), field.Type()) 138 | } 139 | } 140 | } 141 | } 142 | 143 | // OneToMany relationships. 144 | for _, field := range model.Fields { 145 | if strings.HasPrefix(field.Type.String(), "rem.OneToMany[") { 146 | oneToMany := value.FieldByName(field.Name) 147 | oneToMany.FieldByName("RelatedColumn").SetString(field.Tag.Get("db")) 148 | oneToMany.FieldByName("RowPk").Set(value.FieldByName(model.PrimaryField)) 149 | } 150 | } 151 | 152 | return &row, nil 153 | } 154 | 155 | func (model *Model[T]) ScanToMap(rows *sql.Rows) (map[string]interface{}, error) { 156 | columns, err := rows.Columns() 157 | if err != nil { 158 | return nil, err 159 | } 160 | 161 | pointers := make([]interface{}, len(columns)) 162 | for i, column := range columns { 163 | field, ok := model.Fields[column] 164 | if !ok { 165 | return nil, fmt.Errorf("rem: column '%s' not found on model '%T'", column, model) 166 | } 167 | fieldType := field.Type 168 | if strings.HasPrefix(fieldType.String(), "rem.ForeignKey[") || strings.HasPrefix(fieldType.String(), "rem.NullForeignKey[") { 169 | fk := reflect.New(fieldType) 170 | q := fk.MethodByName("Model").Call(nil) 171 | fkPrimaryField := reflect.Indirect(q[0]).FieldByName("PrimaryField").Interface().(string) 172 | pointers[i] = reflect.New(reflect.Indirect(reflect.Indirect(fk).FieldByName("Row")).FieldByName(fkPrimaryField).Type()).Interface() 173 | if strings.HasPrefix(fieldType.String(), "rem.NullForeignKey[") { 174 | switch pointers[i].(type) { 175 | case *bool: 176 | pointers[i] = new(sql.NullBool) 177 | case *byte: 178 | pointers[i] = new(sql.NullByte) 179 | case *int, *int64: 180 | pointers[i] = new(sql.NullInt64) 181 | case *int32: 182 | pointers[i] = new(sql.NullInt32) 183 | case *int8, *int16: 184 | pointers[i] = new(sql.NullInt16) 185 | case *float32, *float64: 186 | pointers[i] = new(sql.NullFloat64) 187 | case *string: 188 | pointers[i] = new(sql.NullString) 189 | case *time.Time: 190 | pointers[i] = new(sql.NullTime) 191 | } 192 | } 193 | } else { 194 | pointers[i] = reflect.New(fieldType).Interface() 195 | } 196 | } 197 | 198 | if err := rows.Scan(pointers...); err != nil { 199 | return nil, err 200 | } 201 | 202 | row := make(map[string]interface{}) 203 | for i, column := range columns { 204 | switch vt := reflect.ValueOf(pointers[i]).Elem().Interface().(type) { 205 | case driver.Valuer: 206 | row[column], _ = vt.Value() 207 | default: 208 | row[column] = vt 209 | } 210 | } 211 | 212 | return row, nil 213 | } 214 | 215 | func (model *Model[T]) Select(columns ...interface{}) *Query[T] { 216 | return &Query[T]{ 217 | Config: QueryConfig{Selected: columns}, 218 | Model: model, 219 | } 220 | } 221 | 222 | func (model *Model[T]) Sort(columns ...string) *Query[T] { 223 | return &Query[T]{ 224 | Config: QueryConfig{Sort: columns}, 225 | Model: model, 226 | } 227 | } 228 | 229 | func (model *Model[T]) SqlAll(db *sql.DB, sql string, args ...interface{}) ([]*T, error) { 230 | rows, err := db.Query(sql, args...) 231 | if err != nil { 232 | return nil, err 233 | } 234 | query := &Query[T]{ 235 | Model: model, 236 | Rows: rows, 237 | } 238 | return query.slice(db) 239 | } 240 | 241 | func (model *Model[T]) SqlAllToMap(db *sql.DB, sql string, args ...interface{}) ([]map[string]interface{}, error) { 242 | rows, err := db.Query(sql, args...) 243 | if err != nil { 244 | return nil, err 245 | } 246 | query := &Query[T]{ 247 | Model: model, 248 | Rows: rows, 249 | } 250 | defer query.Rows.Close() 251 | 252 | mapped := make([]map[string]interface{}, 0) 253 | for query.Rows.Next() { 254 | data, err := model.ScanToMap(query.Rows) 255 | if err != nil { 256 | return nil, err 257 | } 258 | mapped = append(mapped, data) 259 | } 260 | 261 | return mapped, nil 262 | } 263 | 264 | func (model *Model[T]) TableColumnAdd(db *sql.DB, column string) (sql.Result, error) { 265 | query := &Query[T]{Model: model} 266 | return query.TableColumnAdd(db, column) 267 | } 268 | 269 | func (model *Model[T]) TableColumnDrop(db *sql.DB, column string) (sql.Result, error) { 270 | query := &Query[T]{Model: model} 271 | return query.TableColumnDrop(db, column) 272 | } 273 | 274 | func (model *Model[T]) TableCreate(db *sql.DB, tableCreateConfig ...TableCreateConfig) (sql.Result, error) { 275 | query := &Query[T]{Model: model} 276 | return query.TableCreate(db, tableCreateConfig...) 277 | } 278 | 279 | func (model *Model[T]) TableDrop(db *sql.DB, tableDropConfig ...TableDropConfig) (sql.Result, error) { 280 | query := &Query[T]{Model: model} 281 | if len(tableDropConfig) > 0 { 282 | return query.TableDrop(db, tableDropConfig[0]) 283 | } 284 | return query.TableDrop(db, TableDropConfig{}) 285 | } 286 | 287 | func (model *Model[T]) ToJsonMap(row *T) map[string]interface{} { 288 | result := make(map[string]interface{}, 0) 289 | value := reflect.ValueOf(row).Elem() 290 | for _, field := range model.Fields { 291 | fieldName := strings.ToLower(field.Name) 292 | switch fv := value.FieldByName(field.Name).Interface().(type) { 293 | case JsonValuer: 294 | result[fieldName] = fv.JsonValue() 295 | 296 | case driver.Valuer: 297 | result[fieldName], _ = fv.Value() 298 | 299 | default: 300 | result[fieldName] = fv 301 | } 302 | } 303 | 304 | return result 305 | } 306 | 307 | func (model *Model[T]) ToMap(row *T) (map[string]interface{}, error) { 308 | args := make(map[string]interface{}) 309 | value := reflect.ValueOf(*row) 310 | 311 | for column := range model.Fields { 312 | fieldName := model.Fields[column].Name 313 | field := value.FieldByName(fieldName) 314 | 315 | // Skip zero valued primary keys. 316 | if field.IsZero() && model.Fields[column].Tag.Get("db_primary") == "true" { 317 | continue 318 | } 319 | 320 | switch field.Kind() { 321 | case reflect.Struct: 322 | switch vv := field.Interface().(type) { 323 | case driver.Valuer: 324 | v, _ := vv.Value() 325 | args[column] = v 326 | 327 | case time.Time: 328 | args[column] = vv 329 | 330 | default: 331 | if strings.HasPrefix(field.Type().String(), "rem.ForeignKey[") || strings.HasPrefix(field.Type().String(), "rem.NullForeignKey[") { 332 | if !field.FieldByName("Valid").Interface().(bool) { 333 | args[column] = nil 334 | } else { 335 | q := reflect.New(field.Type()).MethodByName("Model").Call(nil) 336 | fkPrimaryField := reflect.Indirect(q[0]).FieldByName("PrimaryField").Interface().(string) 337 | args[column] = reflect.Indirect(field.FieldByName("Row")).FieldByName(fkPrimaryField).Interface() 338 | } 339 | } else if strings.HasPrefix(field.Type().String(), "rem.OneToMany[") { 340 | continue 341 | } else { 342 | return nil, fmt.Errorf("rem: unsupported field type '%s' for column '%s' on table '%s'", field.Type().String(), column, model.Table) 343 | } 344 | } 345 | 346 | default: 347 | args[column] = field.Interface() 348 | } 349 | } 350 | 351 | return args, nil 352 | } 353 | 354 | func (model *Model[T]) Transaction(transaction *sql.Tx) *Query[T] { 355 | return &Query[T]{ 356 | Config: QueryConfig{Transaction: transaction}, 357 | Model: model, 358 | } 359 | } 360 | 361 | func (model *Model[T]) Zero() T { 362 | var zero T 363 | return zero 364 | } 365 | 366 | type TableCreateConfig struct { 367 | IfNotExists bool 368 | } 369 | 370 | type TableDropConfig struct { 371 | IfExists bool 372 | } 373 | 374 | var registeredModels = make(map[string]interface{}) 375 | 376 | func Register[T any](configs ...Config) *Model[T] { 377 | var model T 378 | modelType := reflect.TypeOf(model) 379 | modelTypeStr := modelType.String() 380 | for _, config := range configs { 381 | modelTypeStr = fmt.Sprintf("%s%+v", modelTypeStr, config) 382 | } 383 | 384 | m := Use[T](configs...) 385 | registeredModels[modelTypeStr] = m 386 | return m 387 | } 388 | 389 | func Use[T any](configs ...Config) *Model[T] { 390 | var model T 391 | modelType := reflect.TypeOf(model) 392 | modelTypeStr := modelType.String() 393 | for _, config := range configs { 394 | modelTypeStr = fmt.Sprintf("%s%+v", modelTypeStr, config) 395 | } 396 | 397 | if existing, ok := registeredModels[modelTypeStr]; ok { 398 | return existing.(*Model[T]) 399 | } 400 | 401 | var primaryColumn string 402 | var primaryField string 403 | fields := make(map[string]reflect.StructField, 0) 404 | 405 | for _, field := range reflect.VisibleFields(modelType) { 406 | if column, ok := field.Tag.Lookup("db"); ok { 407 | if strings.HasPrefix(field.Type.String(), "rem.OneToMany[") { 408 | fields[field.Name] = field 409 | } else { 410 | fields[column] = field 411 | if field.Tag.Get("db_primary") == "true" { 412 | primaryColumn = column 413 | primaryField = field.Name 414 | } 415 | } 416 | } 417 | } 418 | 419 | table := strings.ToLower(modelType.Name()) 420 | for _, config := range configs { 421 | if config.Table != "" { 422 | table = config.Table 423 | } 424 | } 425 | 426 | return &Model[T]{ 427 | Fields: fields, 428 | PrimaryColumn: primaryColumn, 429 | PrimaryField: primaryField, 430 | Table: table, 431 | Type: modelType, 432 | } 433 | } 434 | -------------------------------------------------------------------------------- /model_test.go: -------------------------------------------------------------------------------- 1 | package rem 2 | 3 | import ( 4 | "database/sql" 5 | "sort" 6 | "testing" 7 | "time" 8 | 9 | "github.com/DATA-DOG/go-sqlmock" 10 | "golang.org/x/exp/maps" 11 | "golang.org/x/exp/slices" 12 | ) 13 | 14 | func TestModelScanMap(t *testing.T) { 15 | type testGroups struct { 16 | Id int64 `db:"id" db_primary:"true"` 17 | Name string `db:"name" db_max_length:"100"` 18 | } 19 | type testAccounts struct { 20 | EditedAt sql.NullTime `db:"edited_at"` 21 | Group NullForeignKey[testGroups] `db:"group_id" db_on_delete:"SET NULL"` 22 | Id int64 `db:"id" db_primary:"true"` 23 | Name string `db:"name"` 24 | } 25 | model := Use[testAccounts]() 26 | 27 | data := []map[string]interface{}{ 28 | { 29 | "id": int64(1), 30 | "name": "foo", 31 | "edited_at": time.Date(2009, time.January, 2, 3, 0, 0, 0, time.UTC), 32 | "group_id": int64(10), 33 | }, 34 | { 35 | "id": int64(2), 36 | "name": "bar", 37 | "edited_at": nil, 38 | "group_id": int64(20), 39 | }, 40 | { 41 | "id": int64(3), 42 | "name": "baz", 43 | "edited_at": nil, 44 | "group_id": nil, 45 | }, 46 | } 47 | expected := []testAccounts{ 48 | { 49 | Id: 1, 50 | Name: "foo", 51 | EditedAt: sql.NullTime{ 52 | Time: time.Date(2009, time.January, 2, 3, 0, 0, 0, time.UTC), 53 | Valid: true, 54 | }, 55 | Group: NullForeignKey[testGroups]{ 56 | Row: &testGroups{Id: 10}, 57 | Valid: true, 58 | }, 59 | }, 60 | { 61 | Id: 2, 62 | Name: "bar", 63 | Group: NullForeignKey[testGroups]{ 64 | Row: &testGroups{Id: 20}, 65 | Valid: true, 66 | }, 67 | }, 68 | {Id: 3, Name: "baz"}, 69 | } 70 | for i, row := range data { 71 | actual, err := model.ScanMap(row) 72 | if err != nil { 73 | t.Fatal("Unexpected error:", err) 74 | } 75 | if actual.Id != expected[i].Id || 76 | actual.Name != expected[i].Name || 77 | actual.EditedAt != expected[i].EditedAt || 78 | actual.Group.Valid != expected[i].Group.Valid || 79 | (actual.Group.Valid && *actual.Group.Row != *expected[i].Group.Row) { 80 | t.Errorf("Expected '%+v', got '%+v'", expected[i], actual) 81 | } 82 | } 83 | } 84 | 85 | type testGroupsModelToMap struct { 86 | Accounts OneToMany[testAccountsModelToMap] `db:"group_id"` 87 | Id int64 `db:"id" db_primary:"true"` 88 | Name string `db:"name" db_max_length:"100"` 89 | } 90 | type testAccountsModelToMap struct { 91 | EditedAt sql.NullTime `db:"edited_at"` 92 | Group NullForeignKey[testGroupsModelToMap] `db:"group_id" db_on_delete:"SET NULL"` 93 | Id int64 `db:"id" db_primary:"true"` 94 | Name string `db:"name"` 95 | } 96 | 97 | func assertMapDeepEquals(t *testing.T, actual map[string]interface{}, expected map[string]interface{}) { 98 | actualKeys := maps.Keys(actual) 99 | expectKeys := maps.Keys(expected) 100 | sort.Strings(actualKeys) 101 | sort.Strings(expectKeys) 102 | if !slices.Equal(actualKeys, expectKeys) { 103 | t.Errorf("Expected '%#v', got '%#v'", expected, actual) 104 | } 105 | for _, key := range actualKeys { 106 | actualValue := actual[key] 107 | expectValue := expected[key] 108 | 109 | switch av := actualValue.(type) { 110 | case []map[string]interface{}: 111 | ev, ok := expectValue.([]map[string]interface{}) 112 | if !ok { 113 | t.Errorf("Expected\n'%#v', got\n'%#v'", expected, actual) 114 | } 115 | for i := range av { 116 | assertMapDeepEquals(t, av[i], ev[i]) 117 | } 118 | case map[string]interface{}: 119 | ev, ok := expectValue.(map[string]interface{}) 120 | if !ok { 121 | t.Errorf("Expected\n'%#v', got\n'%#v'", expected, actual) 122 | } 123 | assertMapDeepEquals(t, av, ev) 124 | default: 125 | if av != expectValue { 126 | t.Errorf("Expected '%#v', got '%#v'", expected, actual) 127 | } 128 | } 129 | } 130 | } 131 | 132 | func TestModelToJsonMap(t *testing.T) { 133 | model := Use[testAccountsModelToMap]() 134 | 135 | rows := []testAccountsModelToMap{ 136 | { 137 | Id: 1, 138 | Name: "foo", 139 | EditedAt: sql.NullTime{ 140 | Time: time.Date(2009, time.January, 2, 3, 0, 0, 0, time.UTC), 141 | Valid: true, 142 | }, 143 | Group: NullForeignKey[testGroupsModelToMap]{ 144 | Row: &testGroupsModelToMap{Id: 10}, 145 | Valid: true, 146 | }, 147 | }, 148 | { 149 | Id: 2, 150 | Name: "bar", 151 | Group: NullForeignKey[testGroupsModelToMap]{ 152 | Row: &testGroupsModelToMap{Id: 20}, 153 | Valid: true, 154 | }, 155 | }, 156 | {Id: 3, Name: "baz"}, 157 | } 158 | expected := []map[string]interface{}{ 159 | { 160 | "id": int64(1), 161 | "name": "foo", 162 | "editedat": time.Date(2009, time.January, 2, 3, 0, 0, 0, time.UTC), 163 | "group": map[string]interface{}{ 164 | "accounts": []map[string]interface{}{}, 165 | "name": "", 166 | "id": int64(10), 167 | }, 168 | }, 169 | { 170 | "id": int64(2), 171 | "name": "bar", 172 | "editedat": nil, 173 | "group": map[string]interface{}{ 174 | "accounts": []map[string]interface{}{}, 175 | "name": "", 176 | "id": int64(20), 177 | }, 178 | }, 179 | { 180 | "id": int64(3), 181 | "name": "baz", 182 | "editedat": nil, 183 | "group": nil, 184 | }, 185 | } 186 | 187 | for i, row := range rows { 188 | actual := model.ToJsonMap(&row) 189 | assertMapDeepEquals(t, actual, expected[i]) 190 | } 191 | 192 | groupsModel := Use[testGroupsModelToMap]() 193 | groups := []testGroupsModelToMap{ 194 | { 195 | Id: 1, 196 | Name: "foo", 197 | }, 198 | { 199 | Id: 2, 200 | Name: "bar", 201 | }, 202 | } 203 | expected = []map[string]interface{}{ 204 | { 205 | "accounts": []map[string]interface{}{}, 206 | "id": int64(1), 207 | "name": "foo", 208 | }, 209 | { 210 | "accounts": []map[string]interface{}{}, 211 | "id": int64(2), 212 | "name": "bar", 213 | }, 214 | } 215 | for i, row := range groups { 216 | actual := groupsModel.ToJsonMap(&row) 217 | assertMapDeepEquals(t, actual, expected[i]) 218 | } 219 | } 220 | 221 | func TestModelToMap(t *testing.T) { 222 | model := Use[testAccountsModelToMap]() 223 | 224 | rows := []testAccountsModelToMap{ 225 | { 226 | Id: 1, 227 | Name: "foo", 228 | EditedAt: sql.NullTime{ 229 | Time: time.Date(2009, time.January, 2, 3, 0, 0, 0, time.UTC), 230 | Valid: true, 231 | }, 232 | Group: NullForeignKey[testGroupsModelToMap]{ 233 | Row: &testGroupsModelToMap{Id: 10}, 234 | Valid: true, 235 | }, 236 | }, 237 | { 238 | Id: 2, 239 | Name: "bar", 240 | Group: NullForeignKey[testGroupsModelToMap]{ 241 | Row: &testGroupsModelToMap{Id: 20}, 242 | Valid: true, 243 | }, 244 | }, 245 | {Id: 3, Name: "baz"}, 246 | } 247 | expected := []map[string]interface{}{ 248 | { 249 | "id": int64(1), 250 | "name": "foo", 251 | "edited_at": time.Date(2009, time.January, 2, 3, 0, 0, 0, time.UTC), 252 | "group_id": int64(10), 253 | }, 254 | { 255 | "id": int64(2), 256 | "name": "bar", 257 | "edited_at": nil, 258 | "group_id": int64(20), 259 | }, 260 | { 261 | "id": int64(3), 262 | "name": "baz", 263 | "edited_at": nil, 264 | "group_id": nil, 265 | }, 266 | } 267 | for i, row := range rows { 268 | actual, err := model.ToMap(&row) 269 | if err != nil { 270 | t.Fatal("Unexpected error:", err) 271 | } 272 | if !maps.Equal(actual, expected[i]) { 273 | t.Errorf("Expected '%#v', got '%#v'", expected[i], actual) 274 | } 275 | } 276 | 277 | groupsModel := Use[testGroupsModelToMap]() 278 | groups := []testGroupsModelToMap{ 279 | { 280 | Id: 1, 281 | Name: "foo", 282 | }, 283 | { 284 | Id: 2, 285 | Name: "bar", 286 | }, 287 | } 288 | expected = []map[string]interface{}{ 289 | { 290 | "id": int64(1), 291 | "name": "foo", 292 | }, 293 | { 294 | "id": int64(2), 295 | "name": "bar", 296 | }, 297 | } 298 | for i, row := range groups { 299 | actual, err := groupsModel.ToMap(&row) 300 | if err != nil { 301 | t.Fatal("Unexpected error:", err) 302 | } 303 | if !maps.Equal(actual, expected[i]) { 304 | t.Errorf("Expected '%#v', got '%#v'", expected[i], actual) 305 | } 306 | } 307 | } 308 | 309 | func TestRegister(t *testing.T) { 310 | defer func() { 311 | registeredModels = make(map[string]interface{}) 312 | }() 313 | type testModel struct { 314 | Id int64 `db:"id" db_primary:"true"` 315 | Name string `db:"name"` 316 | } 317 | m1 := Use[testModel]() 318 | m2 := Register[testModel]() 319 | m3 := Use[testModel]() 320 | m4 := Use[testModel](Config{Table: "testmodelwithconfig"}) 321 | if m1 == m2 || m1 == m3 { 322 | t.Errorf("Expected '%#v' to be different from '%#v' and '%#v'", m1, m2, m3) 323 | } 324 | if m2 != m3 { 325 | t.Errorf("Expected '%#v', got '%#v'", m2, m3) 326 | } 327 | if m2 == m4 { 328 | t.Errorf("Expected '%#v' to be different from '%#v'", m2, m4) 329 | } 330 | } 331 | 332 | func TestScanToMap(t *testing.T) { 333 | type testAccounts struct { 334 | EditedAt sql.NullTime `db:"edited_at"` 335 | Id int64 `db:"id" db_primary:"true"` 336 | Name string `db:"name"` 337 | } 338 | accounts := Use[testAccounts]() 339 | 340 | db, mock, err := sqlmock.New() 341 | if err != nil { 342 | t.Fatal("failed to open sqlmock database:", err) 343 | } 344 | defer db.Close() 345 | 346 | rows := sqlmock.NewRows([]string{"id", "name", "edited_at"}). 347 | AddRow(1, "foo", time.Date(2009, time.January, 2, 3, 0, 0, 0, time.UTC)). 348 | AddRow(2, "bar", nil). 349 | AddRow(3, "baz", nil) 350 | 351 | mock.ExpectQuery("SELECT").WillReturnRows(rows) 352 | rs, _ := db.Query("SELECT") 353 | defer rs.Close() 354 | 355 | expected := []map[string]interface{}{ 356 | { 357 | "id": int64(1), 358 | "name": "foo", 359 | "edited_at": time.Date(2009, time.January, 2, 3, 0, 0, 0, time.UTC), 360 | }, 361 | { 362 | "id": int64(2), 363 | "name": "bar", 364 | "edited_at": nil, 365 | }, 366 | { 367 | "id": int64(3), 368 | "name": "baz", 369 | "edited_at": nil, 370 | }, 371 | } 372 | i := 0 373 | for rs.Next() { 374 | row, err := accounts.ScanToMap(rs) 375 | if err != nil { 376 | t.Fatal("Unexpected error:", err) 377 | } 378 | if !maps.Equal(row, expected[i]) { 379 | t.Errorf("Expected '%#v', got '%#v'", expected[i], row) 380 | } 381 | i++ 382 | } 383 | } 384 | 385 | func TestUse(t *testing.T) { 386 | type testAccounts struct { 387 | EditedAt sql.NullTime `db:"edited_at"` 388 | Id int64 `db:"id" db_primary:"true"` 389 | Name string `db:"name"` 390 | } 391 | type testGroups struct { 392 | Id int64 `db:"id" db_primary:"true"` 393 | Name string `db:"name" db_max_length:"100"` 394 | Accounts OneToMany[testAccounts] `db:"group_id"` 395 | } 396 | groups := Use[testGroups]() 397 | columns := maps.Keys(groups.Fields) 398 | sort.Strings(columns) 399 | expectedColumns := []string{"Accounts", "id", "name"} 400 | expectedPrimaryColumn := "id" 401 | expectedPrimaryField := "Id" 402 | expectedTable := "testgroups" 403 | if !slices.Equal(columns, expectedColumns) { 404 | t.Errorf("Expected '%+v', got '%+v'", expectedColumns, columns) 405 | } 406 | if groups.PrimaryColumn != expectedPrimaryColumn { 407 | t.Errorf("Expected '%s', got '%s'", expectedPrimaryColumn, groups.PrimaryColumn) 408 | } 409 | if groups.PrimaryField != expectedPrimaryField { 410 | t.Errorf("Expected '%s', got '%s'", expectedPrimaryField, groups.PrimaryField) 411 | } 412 | if groups.Table != expectedTable { 413 | t.Errorf("Expected '%s', got '%s'", expectedTable, groups.Table) 414 | } 415 | } 416 | -------------------------------------------------------------------------------- /mysqldialect/mysqldialect.go: -------------------------------------------------------------------------------- 1 | package mysqldialect 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "reflect" 7 | "sort" 8 | "strings" 9 | "time" 10 | 11 | "github.com/evantbyrne/rem" 12 | "golang.org/x/exp/maps" 13 | ) 14 | 15 | type MysqlDialect struct{} 16 | 17 | func (dialect MysqlDialect) BuildDelete(config rem.QueryConfig) (string, []interface{}, error) { 18 | args := append([]interface{}(nil), config.Params...) 19 | var queryString strings.Builder 20 | queryString.WriteString("DELETE FROM ") 21 | queryString.WriteString(dialect.QuoteIdentifier(config.Table)) 22 | 23 | // WHERE 24 | where, args, err := dialect.buildWhere(config, args) 25 | if err != nil { 26 | return "", nil, err 27 | } 28 | if where != "" { 29 | queryString.WriteString(where) 30 | } 31 | 32 | // ORDER BY 33 | if len(config.Sort) > 0 { 34 | queryString.WriteString(" ORDER BY ") 35 | for i, column := range config.Sort { 36 | if i > 0 { 37 | queryString.WriteString(", ") 38 | } 39 | if strings.HasPrefix(column, "-") { 40 | queryString.WriteString(dialect.QuoteIdentifier(column[1:])) 41 | queryString.WriteString(" DESC") 42 | } else { 43 | queryString.WriteString(dialect.QuoteIdentifier(column)) 44 | queryString.WriteString(" ASC") 45 | } 46 | } 47 | } 48 | 49 | // LIMIT 50 | if config.Limit != nil { 51 | args = append(args, config.Limit) 52 | queryString.WriteString(" LIMIT ") 53 | queryString.WriteString(dialect.Param(len(args))) 54 | } 55 | 56 | // OFFSET 57 | if config.Offset != nil { 58 | return "", nil, fmt.Errorf("rem: DELETE does not support OFFSET") 59 | } 60 | 61 | return queryString.String(), args, nil 62 | } 63 | 64 | func (dialect MysqlDialect) BuildInsert(config rem.QueryConfig, rowMap map[string]interface{}, columns ...string) (string, []interface{}, error) { 65 | args := make([]interface{}, 0) 66 | var queryString strings.Builder 67 | 68 | queryString.WriteString("INSERT INTO ") 69 | queryString.WriteString(dialect.QuoteIdentifier(config.Table)) 70 | queryString.WriteString(" (") 71 | first := true 72 | for _, column := range columns { 73 | if arg, ok := rowMap[column]; ok { 74 | if _, ok := config.Fields[column]; !ok { 75 | return "", nil, fmt.Errorf("rem: field for column '%s' not found on model for table '%s'", column, config.Table) 76 | } 77 | args = append(args, arg) 78 | if first { 79 | first = false 80 | } else { 81 | queryString.WriteString(",") 82 | } 83 | queryString.WriteString(dialect.QuoteIdentifier(column)) 84 | } else { 85 | return "", nil, fmt.Errorf("rem: invalid column '%s' on INSERT", column) 86 | } 87 | } 88 | 89 | queryString.WriteString(") VALUES (") 90 | for i := 1; i <= len(rowMap); i++ { 91 | if i > 1 { 92 | queryString.WriteString(",") 93 | } 94 | queryString.WriteString(dialect.Param(i)) 95 | } 96 | queryString.WriteString(")") 97 | 98 | return queryString.String(), args, nil 99 | } 100 | 101 | func (dialect MysqlDialect) buildJoins(config rem.QueryConfig, args []interface{}) (string, []interface{}, error) { 102 | var queryPart strings.Builder 103 | if len(config.Joins) > 0 { 104 | for _, join := range config.Joins { 105 | if len(join.On) > 0 { 106 | queryPart.WriteString(fmt.Sprintf(" %s JOIN %s ON", join.Direction, dialect.QuoteIdentifier(join.Table))) 107 | for _, where := range join.On { 108 | queryWhere, whereArgs, err := where.StringWithArgs(dialect, args) 109 | if err != nil { 110 | return "", nil, err 111 | } 112 | args = whereArgs 113 | queryPart.WriteString(queryWhere) 114 | } 115 | } 116 | } 117 | } 118 | return queryPart.String(), args, nil 119 | } 120 | 121 | func (dialect MysqlDialect) BuildSelect(config rem.QueryConfig) (string, []interface{}, error) { 122 | args := append([]interface{}(nil), config.Params...) 123 | var queryString strings.Builder 124 | if config.Count { 125 | queryString.WriteString("SELECT count(*) FROM ") 126 | } else if len(config.Selected) > 0 { 127 | queryString.WriteString("SELECT ") 128 | for i, column := range config.Selected { 129 | if i > 0 { 130 | queryString.WriteString(",") 131 | } 132 | switch cv := column.(type) { 133 | case string: 134 | queryString.WriteString(dialect.QuoteIdentifier(cv)) 135 | 136 | case rem.DialectStringer: 137 | queryString.WriteString(cv.StringForDialect(dialect)) 138 | 139 | case fmt.Stringer: 140 | queryString.WriteString(cv.String()) 141 | 142 | default: 143 | return "", nil, fmt.Errorf("rem: invalid column type %#v", column) 144 | } 145 | } 146 | queryString.WriteString(" FROM ") 147 | } else { 148 | queryString.WriteString("SELECT * FROM ") 149 | } 150 | queryString.WriteString(dialect.QuoteIdentifier(config.Table)) 151 | 152 | // JOIN 153 | joins, args, err := dialect.buildJoins(config, args) 154 | if err != nil { 155 | return "", nil, err 156 | } 157 | if joins != "" { 158 | queryString.WriteString(joins) 159 | } 160 | 161 | // WHERE 162 | where, args, err := dialect.buildWhere(config, args) 163 | if err != nil { 164 | return "", nil, err 165 | } 166 | if where != "" { 167 | queryString.WriteString(where) 168 | } 169 | 170 | // ORDER BY 171 | if len(config.Sort) > 0 { 172 | queryString.WriteString(" ORDER BY ") 173 | for i, column := range config.Sort { 174 | if i > 0 { 175 | queryString.WriteString(", ") 176 | } 177 | if strings.HasPrefix(column, "-") { 178 | queryString.WriteString(dialect.QuoteIdentifier(column[1:])) 179 | queryString.WriteString(" DESC") 180 | } else { 181 | queryString.WriteString(dialect.QuoteIdentifier(column)) 182 | queryString.WriteString(" ASC") 183 | } 184 | } 185 | } 186 | 187 | // LIMIT 188 | if config.Limit != nil { 189 | args = append(args, config.Limit) 190 | queryString.WriteString(" LIMIT ") 191 | queryString.WriteString(dialect.Param(len(args))) 192 | } 193 | 194 | // OFFSET 195 | if config.Offset != nil { 196 | args = append(args, config.Offset) 197 | queryString.WriteString(" OFFSET ") 198 | queryString.WriteString(dialect.Param(len(args))) 199 | } 200 | 201 | return queryString.String(), args, nil 202 | } 203 | 204 | func (dialect MysqlDialect) BuildTableColumnAdd(config rem.QueryConfig, column string) (string, error) { 205 | field, ok := config.Fields[column] 206 | if !ok { 207 | return "", fmt.Errorf("rem: invalid column '%s' on model for table '%s'", column, config.Table) 208 | } 209 | 210 | columnType, err := dialect.ColumnType(field) 211 | if err != nil { 212 | return "", err 213 | } 214 | return fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s", dialect.QuoteIdentifier(config.Table), dialect.QuoteIdentifier(column), columnType), nil 215 | } 216 | 217 | func (dialect MysqlDialect) BuildTableColumnDrop(config rem.QueryConfig, column string) (string, error) { 218 | return fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", dialect.QuoteIdentifier(config.Table), dialect.QuoteIdentifier(column)), nil 219 | } 220 | 221 | func (dialect MysqlDialect) BuildTableCreate(config rem.QueryConfig, tableCreateConfig rem.TableCreateConfig) (string, error) { 222 | var sql strings.Builder 223 | sql.WriteString("CREATE TABLE ") 224 | if tableCreateConfig.IfNotExists { 225 | sql.WriteString("IF NOT EXISTS ") 226 | } 227 | sql.WriteString(dialect.QuoteIdentifier(config.Table)) 228 | sql.WriteString(" (") 229 | fieldNames := maps.Keys(config.Fields) 230 | sort.Strings(fieldNames) 231 | for i, fieldName := range fieldNames { 232 | field := config.Fields[fieldName] 233 | columnType, err := dialect.ColumnType(field) 234 | if err != nil { 235 | return "", err 236 | } 237 | if i > 0 { 238 | sql.WriteString(",") 239 | } 240 | sql.WriteString("\n\t") 241 | sql.WriteString(dialect.QuoteIdentifier(fieldName)) 242 | sql.WriteString(" ") 243 | sql.WriteString(columnType) 244 | } 245 | sql.WriteString("\n)") 246 | return sql.String(), nil 247 | } 248 | 249 | func (dialect MysqlDialect) BuildTableDrop(config rem.QueryConfig, tableDropConfig rem.TableDropConfig) (string, error) { 250 | var queryString strings.Builder 251 | queryString.WriteString("DROP TABLE ") 252 | if tableDropConfig.IfExists { 253 | queryString.WriteString("IF EXISTS ") 254 | } 255 | queryString.WriteString(dialect.QuoteIdentifier(config.Table)) 256 | return queryString.String(), nil 257 | } 258 | 259 | func (dialect MysqlDialect) BuildUpdate(config rem.QueryConfig, rowMap map[string]interface{}, columns ...string) (string, []interface{}, error) { 260 | args := append([]interface{}(nil), config.Params...) 261 | var queryString strings.Builder 262 | 263 | queryString.WriteString("UPDATE ") 264 | queryString.WriteString(dialect.QuoteIdentifier(config.Table)) 265 | queryString.WriteString(" SET ") 266 | 267 | first := true 268 | for _, column := range columns { 269 | if arg, ok := rowMap[column]; ok { 270 | args = append(args, arg) 271 | if first { 272 | first = false 273 | } else { 274 | queryString.WriteString(",") 275 | } 276 | queryString.WriteString(dialect.QuoteIdentifier(column)) 277 | queryString.WriteString(" = ") 278 | queryString.WriteString(dialect.Param(len(args))) 279 | } else { 280 | return "", nil, fmt.Errorf("rem: invalid column '%s' on UPDATE", column) 281 | } 282 | } 283 | 284 | // WHERE 285 | where, args, err := dialect.buildWhere(config, args) 286 | if err != nil { 287 | return "", nil, err 288 | } 289 | if where != "" { 290 | queryString.WriteString(where) 291 | } 292 | 293 | // ORDER BY 294 | if len(config.Sort) > 0 { 295 | queryString.WriteString(" ORDER BY ") 296 | for i, column := range config.Sort { 297 | if i > 0 { 298 | queryString.WriteString(", ") 299 | } 300 | if strings.HasPrefix(column, "-") { 301 | queryString.WriteString(dialect.QuoteIdentifier(column[1:])) 302 | queryString.WriteString(" DESC") 303 | } else { 304 | queryString.WriteString(dialect.QuoteIdentifier(column)) 305 | queryString.WriteString(" ASC") 306 | } 307 | } 308 | } 309 | 310 | // LIMIT 311 | if config.Limit != nil { 312 | args = append(args, config.Limit) 313 | queryString.WriteString(" LIMIT ") 314 | queryString.WriteString(dialect.Param(len(args))) 315 | } 316 | 317 | // OFFSET 318 | if config.Offset != nil { 319 | return "", nil, fmt.Errorf("rem: UPDATE does not support OFFSET") 320 | } 321 | 322 | return queryString.String(), args, nil 323 | } 324 | 325 | func (dialect MysqlDialect) buildWhere(config rem.QueryConfig, args []interface{}) (string, []interface{}, error) { 326 | var queryPart strings.Builder 327 | if len(config.Filters) > 0 { 328 | queryPart.WriteString(" WHERE") 329 | for _, where := range config.Filters { 330 | queryWhere, whereArgs, err := where.StringWithArgs(dialect, args) 331 | if err != nil { 332 | return "", nil, err 333 | } 334 | args = whereArgs 335 | queryPart.WriteString(queryWhere) 336 | } 337 | } 338 | return queryPart.String(), args, nil 339 | } 340 | 341 | func (dialect MysqlDialect) ColumnType(field reflect.StructField) (string, error) { 342 | tagType := field.Tag.Get("db_type") 343 | if tagType != "" { 344 | return tagType, nil 345 | } 346 | 347 | fieldInstance := reflect.Indirect(reflect.New(field.Type)).Interface() 348 | var columnNull string 349 | var columnPrimary string 350 | var columnType string 351 | 352 | if field.Tag.Get("db_primary") == "true" { 353 | columnPrimary = " PRIMARY KEY" 354 | 355 | switch fieldInstance.(type) { 356 | case int, int64: 357 | columnNull = " NOT NULL AUTO_INCREMENT" 358 | columnType = "BIGINT" 359 | 360 | case int32: 361 | columnNull = " NOT NULL AUTO_INCREMENT" 362 | columnType = "INTEGER" 363 | 364 | case int16: 365 | columnNull = " NOT NULL AUTO_INCREMENT" 366 | columnType = "SMALLINT" 367 | 368 | case int8: 369 | columnNull = " NOT NULL AUTO_INCREMENT" 370 | columnType = "TINYINT" 371 | } 372 | } 373 | 374 | if columnType == "" { 375 | switch fieldInstance.(type) { 376 | case bool: 377 | columnNull = " NOT NULL" 378 | columnType = "BOOLEAN" 379 | 380 | case sql.NullBool: 381 | columnNull = " NULL" 382 | columnType = "BOOLEAN" 383 | 384 | case float32: 385 | columnNull = " NOT NULL" 386 | columnType = "FLOAT" 387 | 388 | case float64: 389 | columnNull = " NOT NULL" 390 | columnType = "DOUBLE" 391 | 392 | case sql.NullFloat64: 393 | columnNull = " NULL" 394 | columnType = "DOUBLE" 395 | 396 | case int, int64: 397 | columnNull = " NOT NULL" 398 | columnType = "BIGINT" 399 | 400 | case sql.NullInt64: 401 | columnNull = " NULL" 402 | columnType = "BIGINT" 403 | 404 | case int32: 405 | columnNull = " NOT NULL" 406 | columnType = "INTEGER" 407 | 408 | case sql.NullInt32: 409 | columnNull = " NULL" 410 | columnType = "INTEGER" 411 | 412 | case int8: 413 | columnNull = " NOT NULL" 414 | columnType = "TINYINT" 415 | 416 | case int16: 417 | columnNull = " NOT NULL" 418 | columnType = "SMALLINT" 419 | 420 | case sql.NullInt16: 421 | columnNull = " NULL" 422 | columnType = "SMALLINT" 423 | 424 | case string: 425 | columnNull = " NOT NULL" 426 | if tagMaxLength := field.Tag.Get("db_max_length"); tagMaxLength != "" { 427 | columnType = fmt.Sprintf("VARCHAR(%s)", tagMaxLength) 428 | } else { 429 | columnType = "TEXT" 430 | } 431 | 432 | case sql.NullString: 433 | columnNull = " NULL" 434 | if tagMaxLength := field.Tag.Get("db_max_length"); tagMaxLength != "" { 435 | columnType = fmt.Sprintf("VARCHAR(%s)", tagMaxLength) 436 | } else { 437 | columnType = "TEXT" 438 | } 439 | 440 | case time.Time: 441 | columnNull = " NOT NULL" 442 | columnType = "DATETIME" 443 | 444 | case sql.NullTime: 445 | columnNull = " NULL" 446 | columnType = "DATETIME" 447 | 448 | default: 449 | if strings.HasPrefix(field.Type.String(), "rem.ForeignKey[") || strings.HasPrefix(field.Type.String(), "rem.NullForeignKey[") { 450 | // Foreign keys. 451 | fv := reflect.New(field.Type).Elem() 452 | subModelQ := fv.Addr().MethodByName("Model").Call(nil) 453 | subFields := reflect.Indirect(subModelQ[0]).FieldByName("Fields").Interface().(map[string]reflect.StructField) 454 | subPrimaryColumn := reflect.Indirect(subModelQ[0]).FieldByName("PrimaryColumn").Interface().(string) 455 | subTable := reflect.Indirect(subModelQ[0]).FieldByName("Table").Interface().(string) 456 | columnTypeTemp, err := dialect.ColumnType(subFields[subPrimaryColumn]) 457 | if err != nil { 458 | return "", err 459 | } 460 | columnType = strings.SplitN(columnTypeTemp, " ", 2)[0] 461 | columnType = strings.Replace(columnType, " AUTO_INCREMENT", "", 1) 462 | 463 | columnNull = " NOT NULL" 464 | if strings.HasPrefix(field.Type.String(), "rem.NullForeignKey[") { 465 | columnNull = " NULL" 466 | } 467 | columnNull = fmt.Sprintf("%s REFERENCES %s (%s)", columnNull, dialect.QuoteIdentifier(subTable), dialect.QuoteIdentifier(subPrimaryColumn)) 468 | 469 | if tagOnUpdate := field.Tag.Get("db_on_update"); tagOnUpdate != "" { 470 | // ON UPDATE. 471 | columnNull = fmt.Sprint(columnNull, " ON UPDATE ", tagOnUpdate) 472 | } 473 | 474 | if tagOnDelete := field.Tag.Get("db_on_delete"); tagOnDelete != "" { 475 | // ON DELETE. 476 | columnNull = fmt.Sprint(columnNull, " ON DELETE ", tagOnDelete) 477 | } 478 | } 479 | } 480 | } 481 | 482 | if columnType == "" { 483 | return "", fmt.Errorf("rem: Unsupported column type: %T. Use the 'db_type' field tag to define a SQL type", fieldInstance) 484 | } 485 | 486 | if tagDefault := field.Tag.Get("db_default"); tagDefault != "" { 487 | // DEFAULT. 488 | columnNull += " DEFAULT " + tagDefault 489 | } 490 | 491 | if tagUnique := field.Tag.Get("db_unique"); tagUnique == "true" { 492 | // UNIQUE. 493 | columnNull += " UNIQUE" 494 | } 495 | 496 | return fmt.Sprint(columnType, columnPrimary, columnNull), nil 497 | } 498 | 499 | func (dialect MysqlDialect) Param(identifier int) string { 500 | return "?" 501 | } 502 | 503 | func (dialect MysqlDialect) QuoteIdentifier(identifier string) string { 504 | var query strings.Builder 505 | for i, part := range strings.Split(identifier, ".") { 506 | if i > 0 { 507 | query.WriteString(".") 508 | } 509 | query.WriteString(QuoteIdentifier(part)) 510 | } 511 | return query.String() 512 | } 513 | 514 | func QuoteIdentifier(identifier string) string { 515 | return "`" + strings.Replace(identifier, "`", "``", -1) + "`" 516 | } 517 | -------------------------------------------------------------------------------- /mysqldialect/mysqldialect_test.go: -------------------------------------------------------------------------------- 1 | package mysqldialect 2 | 3 | import ( 4 | "database/sql" 5 | "sort" 6 | "testing" 7 | "time" 8 | 9 | "github.com/evantbyrne/rem" 10 | "golang.org/x/exp/maps" 11 | "golang.org/x/exp/slices" 12 | ) 13 | 14 | func TestAs(t *testing.T) { 15 | dialect := MysqlDialect{} 16 | expected := map[string]rem.SqlAs{ 17 | "`x` AS `alias1`": rem.As("x", "alias1"), 18 | "`x` AS `y` AS `alias2`": rem.As(rem.As("x", "y"), "alias2"), 19 | "count(*) AS `alias3`": rem.As(rem.Unsafe("count(*)"), "alias3"), 20 | } 21 | for expected, alias := range expected { 22 | sql := alias.StringForDialect(dialect) 23 | if expected != sql { 24 | t.Errorf("Expected '%+v', got '%+v'", expected, sql) 25 | } 26 | } 27 | } 28 | 29 | func TestColumn(t *testing.T) { 30 | dialect := MysqlDialect{} 31 | expected := map[string]rem.SqlColumn{ 32 | "`x`": rem.Column("x"), 33 | "`x`.`y`": rem.Column("x.y"), 34 | "`x`.`y`.`z`": rem.Column("x.y.z"), 35 | "`x```": rem.Column("x`"), 36 | } 37 | for expected, column := range expected { 38 | sql := column.StringForDialect(dialect) 39 | if expected != sql { 40 | t.Errorf("Expected '%+v', got '%+v'", expected, sql) 41 | } 42 | } 43 | } 44 | 45 | func TestBuildDelete(t *testing.T) { 46 | type testModel struct { 47 | Id int64 `db:"test_id" db_primary:"true"` 48 | Value1 string `db:"test_value_1" db_max_length:"100"` 49 | Value2 string `db:"test_value_2" db_max_length:"100"` 50 | } 51 | 52 | dialect := MysqlDialect{} 53 | model := rem.Use[testModel]() 54 | 55 | query := model.Query() 56 | config := query.Config 57 | config.Fields = model.Fields 58 | config.Table = "testmodel" 59 | expectedArgs := []interface{}{} 60 | expectedSql := "DELETE FROM `testmodel`" 61 | queryString, args, err := dialect.BuildDelete(config) 62 | if err != nil { 63 | t.Errorf("Unexpected error %s", err.Error()) 64 | } 65 | if queryString != expectedSql { 66 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 67 | } 68 | if !slices.Equal(args, expectedArgs) { 69 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 70 | } 71 | 72 | // WHERE 73 | query.Filter("test_id", "=", 1) 74 | config = query.Config 75 | config.Fields = model.Fields 76 | config.Table = "testmodel" 77 | expectedArgs = []interface{}{1} 78 | expectedSql = "DELETE FROM `testmodel` WHERE `test_id` = ?" 79 | queryString, args, err = dialect.BuildDelete(config) 80 | if err != nil { 81 | t.Errorf("Unexpected error %s", err.Error()) 82 | } 83 | if queryString != expectedSql { 84 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 85 | } 86 | if !slices.Equal(args, expectedArgs) { 87 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 88 | } 89 | 90 | // ORDER BY 91 | query.Sort("-test_id") 92 | config = query.Config 93 | config.Fields = model.Fields 94 | config.Table = "testmodel" 95 | expectedSql = "DELETE FROM `testmodel` WHERE `test_id` = ? ORDER BY `test_id` DESC" 96 | queryString, args, err = dialect.BuildDelete(config) 97 | if err != nil { 98 | t.Errorf("Unexpected error %s", err.Error()) 99 | } 100 | if queryString != expectedSql { 101 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 102 | } 103 | if !slices.Equal(args, expectedArgs) { 104 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 105 | } 106 | 107 | // LIMIT 108 | query.Limit(3) 109 | config = query.Config 110 | config.Fields = model.Fields 111 | config.Table = "testmodel" 112 | expectedArgs = []interface{}{1, 3} 113 | expectedSql = "DELETE FROM `testmodel` WHERE `test_id` = ? ORDER BY `test_id` DESC LIMIT ?" 114 | queryString, args, err = dialect.BuildDelete(config) 115 | if err != nil { 116 | t.Errorf("Unexpected error %s", err.Error()) 117 | } 118 | if queryString != expectedSql { 119 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 120 | } 121 | if !slices.Equal(args, expectedArgs) { 122 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 123 | } 124 | } 125 | 126 | func TestBuildInsert(t *testing.T) { 127 | type testModel struct { 128 | Id int64 `db:"test_id" db_primary:"true"` 129 | Value1 string `db:"test_value_1" db_max_length:"100"` 130 | Value2 string `db:"test_value_2" db_max_length:"100"` 131 | } 132 | 133 | dialect := MysqlDialect{} 134 | model := rem.Use[testModel]() 135 | 136 | config := model.Query().Config 137 | config.Fields = model.Fields 138 | config.Table = "testmodel" 139 | expectedArgs := []interface{}{"foo", "bar"} 140 | expectedSql := "INSERT INTO `testmodel` (`test_value_1`,`test_value_2`) VALUES (?,?)" 141 | queryString, args, err := dialect.BuildInsert(config, map[string]interface{}{ 142 | "test_value_1": "foo", 143 | "test_value_2": "bar", 144 | }, "test_value_1", "test_value_2") 145 | if err != nil { 146 | t.Errorf("Unexpected error %s", err.Error()) 147 | } 148 | if queryString != expectedSql { 149 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 150 | } 151 | if !slices.Equal(args, expectedArgs) { 152 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 153 | } 154 | } 155 | 156 | func TestBuildSelect(t *testing.T) { 157 | type testModel struct { 158 | Id int64 `db:"test_id" db_primary:"true"` 159 | Value1 string `db:"test_value_1" db_max_length:"100"` 160 | Value2 string `db:"test_value_2" db_max_length:"100"` 161 | } 162 | 163 | dialect := MysqlDialect{} 164 | model := rem.Use[testModel]() 165 | 166 | config := model.Query().Config 167 | config.Fields = model.Fields 168 | config.Table = "testmodel" 169 | expectedArgs := []interface{}{} 170 | expectedSql := "SELECT * FROM `testmodel`" 171 | queryString, args, err := dialect.BuildSelect(config) 172 | if err != nil { 173 | t.Errorf("Unexpected error %s", err.Error()) 174 | } 175 | if queryString != expectedSql { 176 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 177 | } 178 | if !slices.Equal(args, expectedArgs) { 179 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 180 | } 181 | 182 | // SELECT 183 | config = model.Select("id", "value1", rem.Unsafe("count(1) as `count`"), rem.As("value2", "value3")).Config 184 | config.Fields = model.Fields 185 | config.Table = "testmodel" 186 | expectedArgs = []interface{}{} 187 | expectedSql = "SELECT `id`,`value1`,count(1) as `count`,`value2` AS `value3` FROM `testmodel`" 188 | queryString, args, err = dialect.BuildSelect(config) 189 | if err != nil { 190 | t.Errorf("Unexpected error %s", err.Error()) 191 | } 192 | if queryString != expectedSql { 193 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 194 | } 195 | if !slices.Equal(args, expectedArgs) { 196 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 197 | } 198 | 199 | // WHERE 200 | config = model.Filter("id", "=", 1).Config 201 | config.Fields = model.Fields 202 | config.Table = "testmodel" 203 | expectedArgs = []interface{}{1} 204 | expectedSql = "SELECT * FROM `testmodel` WHERE `id` = ?" 205 | queryString, args, err = dialect.BuildSelect(config) 206 | if err != nil { 207 | t.Errorf("Unexpected error %s", err.Error()) 208 | } 209 | if queryString != expectedSql { 210 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 211 | } 212 | if !slices.Equal(args, expectedArgs) { 213 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 214 | } 215 | 216 | config = model.Filter("id", "IN", rem.Sql(rem.Param(1), ",", rem.Param(2))).Config 217 | config.Fields = model.Fields 218 | config.Table = "testmodel" 219 | expectedArgs = []interface{}{1, 2} 220 | expectedSql = "SELECT * FROM `testmodel` WHERE `id` IN (?,?)" 221 | queryString, args, err = dialect.BuildSelect(config) 222 | if err != nil { 223 | t.Errorf("Unexpected error %s", err.Error()) 224 | } 225 | if queryString != expectedSql { 226 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 227 | } 228 | if !slices.Equal(args, expectedArgs) { 229 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 230 | } 231 | 232 | // JOIN 233 | config = model.Select(rem.Unsafe("*")).Join("groups", rem.Or( 234 | rem.Q("groups.id", "=", rem.Column("accounts.group_id")), 235 | rem.Q("groups.id", "IS", nil))).Config 236 | config.Fields = model.Fields 237 | config.Table = "testmodel" 238 | expectedArgs = []interface{}{} 239 | expectedSql = "SELECT * FROM `testmodel` INNER JOIN `groups` ON ( `groups`.`id` = `accounts`.`group_id` OR `groups`.`id` IS NULL )" 240 | queryString, args, err = dialect.BuildSelect(config) 241 | if err != nil { 242 | t.Errorf("Unexpected error %s", err.Error()) 243 | } 244 | if queryString != expectedSql { 245 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 246 | } 247 | if !slices.Equal(args, expectedArgs) { 248 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 249 | } 250 | 251 | // SORT 252 | config = model.Sort("test_id", "-test_value_1").Config 253 | config.Fields = model.Fields 254 | config.Table = "testmodel" 255 | expectedArgs = []interface{}{} 256 | expectedSql = "SELECT * FROM `testmodel` ORDER BY `test_id` ASC, `test_value_1` DESC" 257 | queryString, args, err = dialect.BuildSelect(config) 258 | if err != nil { 259 | t.Errorf("Unexpected error %s", err.Error()) 260 | } 261 | if queryString != expectedSql { 262 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 263 | } 264 | if !slices.Equal(args, expectedArgs) { 265 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 266 | } 267 | 268 | // LIMIT and OFFSET 269 | config = model.Filter("id", "=", 1).Offset(20).Limit(10).Config 270 | config.Fields = model.Fields 271 | config.Table = "testmodel" 272 | expectedArgs = []interface{}{1, 10, 20} 273 | expectedSql = "SELECT * FROM `testmodel` WHERE `id` = ? LIMIT ? OFFSET ?" 274 | queryString, args, err = dialect.BuildSelect(config) 275 | if err != nil { 276 | t.Errorf("Unexpected error %s", err.Error()) 277 | } 278 | if queryString != expectedSql { 279 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 280 | } 281 | if !slices.Equal(args, expectedArgs) { 282 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 283 | } 284 | } 285 | 286 | func TestBuildTableColumnAdd(t *testing.T) { 287 | type testModel struct { 288 | Value string `db:"test_value" db_max_length:"100"` 289 | } 290 | 291 | dialect := MysqlDialect{} 292 | model := rem.Use[testModel]() 293 | config := rem.QueryConfig{ 294 | Fields: model.Fields, 295 | Table: "testmodel", 296 | } 297 | expectedSql := "ALTER TABLE `testmodel` ADD COLUMN `test_value` VARCHAR(100) NOT NULL" 298 | queryString, err := dialect.BuildTableColumnAdd(config, "test_value") 299 | if err != nil { 300 | t.Errorf("Unexpected error %s", err.Error()) 301 | } 302 | if queryString != expectedSql { 303 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 304 | } 305 | } 306 | 307 | func TestBuildTableColumnDrop(t *testing.T) { 308 | dialect := MysqlDialect{} 309 | config := rem.QueryConfig{Table: "testmodel"} 310 | expectedSql := "ALTER TABLE `testmodel` DROP COLUMN `test_value`" 311 | queryString, err := dialect.BuildTableColumnDrop(config, "test_value") 312 | if err != nil { 313 | t.Errorf("Unexpected error %s", err.Error()) 314 | } 315 | if queryString != expectedSql { 316 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 317 | } 318 | } 319 | 320 | func TestBuildTableCreate(t *testing.T) { 321 | type testModel struct { 322 | Id int64 `db:"test_id" db_primary:"true"` 323 | Value1 string `db:"test_value_1" db_max_length:"100"` 324 | } 325 | 326 | dialect := MysqlDialect{} 327 | model := rem.Use[testModel]() 328 | config := rem.QueryConfig{ 329 | Fields: model.Fields, 330 | Table: "testmodel", 331 | } 332 | expectedSql := "CREATE TABLE `testmodel` (\n" + 333 | "\t`test_id` BIGINT PRIMARY KEY NOT NULL AUTO_INCREMENT,\n" + 334 | "\t`test_value_1` VARCHAR(100) NOT NULL\n" + 335 | ")" 336 | queryString, err := dialect.BuildTableCreate(config, rem.TableCreateConfig{}) 337 | if err != nil { 338 | t.Errorf("Unexpected error %s", err.Error()) 339 | } 340 | if queryString != expectedSql { 341 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 342 | } 343 | 344 | expectedSql = "CREATE TABLE IF NOT EXISTS `testmodel` (\n" + 345 | "\t`test_id` BIGINT PRIMARY KEY NOT NULL AUTO_INCREMENT,\n" + 346 | "\t`test_value_1` VARCHAR(100) NOT NULL\n" + 347 | ")" 348 | queryString, err = dialect.BuildTableCreate(config, rem.TableCreateConfig{IfNotExists: true}) 349 | if err != nil { 350 | t.Errorf("Unexpected error %s", err.Error()) 351 | } 352 | if queryString != expectedSql { 353 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 354 | } 355 | } 356 | 357 | func TestBuildTableDrop(t *testing.T) { 358 | dialect := MysqlDialect{} 359 | config := rem.QueryConfig{Table: "testmodel"} 360 | expectedSql := "DROP TABLE `testmodel`" 361 | queryString, err := dialect.BuildTableDrop(config, rem.TableDropConfig{}) 362 | if err != nil { 363 | t.Errorf("Unexpected error %s", err.Error()) 364 | } 365 | if queryString != expectedSql { 366 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 367 | } 368 | 369 | expectedSql = "DROP TABLE IF EXISTS `testmodel`" 370 | queryString, err = dialect.BuildTableDrop(config, rem.TableDropConfig{IfExists: true}) 371 | if err != nil { 372 | t.Errorf("Unexpected error %s", err.Error()) 373 | } 374 | if queryString != expectedSql { 375 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 376 | } 377 | } 378 | 379 | func TestBuildUpdate(t *testing.T) { 380 | type testModel struct { 381 | Id int64 `db:"test_id" db_primary:"true"` 382 | Value1 string `db:"test_value_1" db_max_length:"100"` 383 | Value2 string `db:"test_value_2" db_max_length:"100"` 384 | } 385 | 386 | dialect := MysqlDialect{} 387 | model := rem.Use[testModel]() 388 | 389 | query := model.Filter("test_id", "=", 1) 390 | config := query.Config 391 | config.Fields = model.Fields 392 | config.Table = "testmodel" 393 | expectedArgs := []interface{}{"foo", "bar", 1} 394 | expectedSql := "UPDATE `testmodel` SET `test_value_1` = ?,`test_value_2` = ? WHERE `test_id` = ?" 395 | queryString, args, err := dialect.BuildUpdate(config, map[string]interface{}{ 396 | "id": 123, 397 | "test_value_1": "foo", 398 | "test_value_2": "bar", 399 | }, "test_value_1", "test_value_2") 400 | if err != nil { 401 | t.Errorf("Unexpected error %s", err.Error()) 402 | } 403 | if queryString != expectedSql { 404 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 405 | } 406 | if !slices.Equal(args, expectedArgs) { 407 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 408 | } 409 | 410 | query.Limit(3) 411 | config = query.Config 412 | config.Fields = model.Fields 413 | config.Table = "testmodel" 414 | expectedArgs = []interface{}{"foo", "bar", 1, 3} 415 | expectedSql = "UPDATE `testmodel` SET `test_value_1` = ?,`test_value_2` = ? WHERE `test_id` = ? LIMIT ?" 416 | queryString, args, err = dialect.BuildUpdate(config, map[string]interface{}{ 417 | "id": 123, 418 | "test_value_1": "foo", 419 | "test_value_2": "bar", 420 | }, "test_value_1", "test_value_2") 421 | if err != nil { 422 | t.Errorf("Unexpected error %s", err.Error()) 423 | } 424 | if queryString != expectedSql { 425 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 426 | } 427 | if !slices.Equal(args, expectedArgs) { 428 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 429 | } 430 | 431 | query.Sort("-test_id") 432 | config = query.Config 433 | config.Fields = model.Fields 434 | config.Table = "testmodel" 435 | expectedSql = "UPDATE `testmodel` SET `test_value_1` = ?,`test_value_2` = ? WHERE `test_id` = ? ORDER BY `test_id` DESC LIMIT ?" 436 | queryString, args, err = dialect.BuildUpdate(config, map[string]interface{}{ 437 | "id": 123, 438 | "test_value_1": "foo", 439 | "test_value_2": "bar", 440 | }, "test_value_1", "test_value_2") 441 | if err != nil { 442 | t.Errorf("Unexpected error %s", err.Error()) 443 | } 444 | if queryString != expectedSql { 445 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 446 | } 447 | if !slices.Equal(args, expectedArgs) { 448 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 449 | } 450 | } 451 | 452 | func TestColumnType(t *testing.T) { 453 | type testFkInt struct { 454 | Id int64 `db:"id" db_primary:"true"` 455 | } 456 | 457 | type testFkString struct { 458 | Id string `db:"id" db_primary:"true" db_max_length:"100"` 459 | } 460 | 461 | type testModel struct { 462 | BigInt int64 `db:"test_big_int"` 463 | BigIntNull sql.NullInt64 `db:"test_big_int_null"` 464 | Bool bool `db:"test_bool"` 465 | BoolNull sql.NullBool `db:"test_bool_null"` 466 | Custom []byte `db:"test_custom" db_type:"BINARY(128) NOT NULL"` 467 | Default string `db:"test_default" db_default:"'foo'" db_max_length:"100"` 468 | Float float32 `db:"test_float"` 469 | Double float64 `db:"test_double"` 470 | DoubleNull sql.NullFloat64 `db:"test_double_null"` 471 | Id int64 `db:"test_id" db_primary:"true"` 472 | Int int32 `db:"test_int"` 473 | IntNull sql.NullInt32 `db:"test_int_null"` 474 | SmallInt int16 `db:"test_small_int"` 475 | SmallIntNull sql.NullInt16 `db:"test_small_int_null"` 476 | Text string `db:"test_text"` 477 | TextNull sql.NullString `db:"test_text_null"` 478 | Time time.Time `db:"test_time"` 479 | TimeNow time.Time `db:"test_time_now" db_default:"now()"` 480 | TimeNull sql.NullTime `db:"test_time_null"` 481 | TinyInt int8 `db:"test_tiny_int"` 482 | Varchar string `db:"test_varchar" db_max_length:"100"` 483 | VarcharNull sql.NullString `db:"test_varchar_null" db_max_length:"50"` 484 | ForiegnKey rem.ForeignKey[testFkString] `db:"test_fk_id" db_on_delete:"CASCADE"` 485 | ForiegnKeyNull rem.NullForeignKey[testFkInt] `db:"test_fk_null_id" db_on_delete:"SET NULL" db_on_update:"SET NULL"` 486 | Unique string `db:"test_unique" db_max_length:"255" db_unique:"true"` 487 | } 488 | 489 | expected := map[string]string{ 490 | "test_big_int": "BIGINT NOT NULL", 491 | "test_big_int_null": "BIGINT NULL", 492 | "test_bool": "BOOLEAN NOT NULL", 493 | "test_bool_null": "BOOLEAN NULL", 494 | "test_custom": "BINARY(128) NOT NULL", 495 | "test_default": "VARCHAR(100) NOT NULL DEFAULT 'foo'", 496 | "test_float": "FLOAT NOT NULL", 497 | "test_double": "DOUBLE NOT NULL", 498 | "test_double_null": "DOUBLE NULL", 499 | "test_id": "BIGINT PRIMARY KEY NOT NULL AUTO_INCREMENT", 500 | "test_int": "INTEGER NOT NULL", 501 | "test_int_null": "INTEGER NULL", 502 | "test_small_int": "SMALLINT NOT NULL", 503 | "test_small_int_null": "SMALLINT NULL", 504 | "test_time": "DATETIME NOT NULL", 505 | "test_time_now": "DATETIME NOT NULL DEFAULT now()", 506 | "test_time_null": "DATETIME NULL", 507 | "test_text": "TEXT NOT NULL", 508 | "test_text_null": "TEXT NULL", 509 | "test_tiny_int": "TINYINT NOT NULL", 510 | "test_varchar": "VARCHAR(100) NOT NULL", 511 | "test_varchar_null": "VARCHAR(50) NULL", 512 | "test_fk_id": "VARCHAR(100) NOT NULL REFERENCES `testfkstring` (`id`) ON DELETE CASCADE", 513 | "test_fk_null_id": "BIGINT NULL REFERENCES `testfkint` (`id`) ON UPDATE SET NULL ON DELETE SET NULL", 514 | "test_unique": "VARCHAR(255) NOT NULL UNIQUE", 515 | } 516 | 517 | dialect := MysqlDialect{} 518 | model := rem.Use[testModel]() 519 | fieldKeys := maps.Keys(model.Fields) 520 | sort.Strings(fieldKeys) 521 | 522 | for _, fieldName := range fieldKeys { 523 | field := model.Fields[fieldName] 524 | columnType, err := dialect.ColumnType(field) 525 | if err != nil { 526 | t.Fatalf(`dialect.ColumnType() threw error for '%#v': %s`, field, err) 527 | } 528 | if columnType != expected[fieldName] { 529 | t.Fatalf(`dialect.ColumnType() returned '%s', but expected '%s' for '%#v'`, columnType, expected[fieldName], field) 530 | } 531 | } 532 | } 533 | 534 | func TestQuoteIdentifier(t *testing.T) { 535 | values := map[string]string{ 536 | "abc": "`abc`", 537 | "a`bc": "`a``bc`", 538 | "a``b`c": "`a````b``c`", 539 | "`abc": "```abc`", 540 | "abc`": "`abc```", 541 | "ab\\`c": "`ab\\``c`", 542 | "abc\\": "`abc\\`", 543 | } 544 | 545 | for identifier, expected := range values { 546 | actual := QuoteIdentifier(identifier) 547 | if actual != expected { 548 | t.Errorf("Expected %s, got %s", expected, actual) 549 | } 550 | } 551 | } 552 | -------------------------------------------------------------------------------- /one_to_many.go: -------------------------------------------------------------------------------- 1 | package rem 2 | 3 | import ( 4 | "database/sql" 5 | "encoding/json" 6 | ) 7 | 8 | type OneToMany[To any] struct { 9 | RelatedColumn string 10 | RowPk interface{} 11 | Rows []*To 12 | } 13 | 14 | func (field *OneToMany[To]) All(db *sql.DB) ([]*To, error) { 15 | return field.Query().Filter(field.RelatedColumn, "=", field.RowPk).All(db) 16 | } 17 | 18 | func (field OneToMany[To]) JsonValue() interface{} { 19 | model := field.Model() 20 | results := make([]map[string]interface{}, len(field.Rows)) 21 | for i := range field.Rows { 22 | results[i] = model.ToJsonMap(field.Rows[i]) 23 | } 24 | return results 25 | } 26 | 27 | func (field OneToMany[To]) MarshalJSON() ([]byte, error) { 28 | model := field.Model() 29 | results := make([]map[string]interface{}, len(field.Rows)) 30 | for i, row := range field.Rows { 31 | results[i] = model.ToJsonMap(row) 32 | } 33 | return json.Marshal(results) 34 | } 35 | 36 | func (field *OneToMany[To]) Model() *Model[To] { 37 | return Use[To]() 38 | } 39 | 40 | func (field *OneToMany[To]) Query() *Query[To] { 41 | return &Query[To]{ 42 | Model: Use[To](), 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /pqdialect/pqdialect.go: -------------------------------------------------------------------------------- 1 | package pqdialect 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "reflect" 7 | "sort" 8 | "strconv" 9 | "strings" 10 | "time" 11 | 12 | "github.com/evantbyrne/rem" 13 | "golang.org/x/exp/maps" 14 | ) 15 | 16 | type PqDialect struct{} 17 | 18 | func (dialect PqDialect) BuildDelete(config rem.QueryConfig) (string, []interface{}, error) { 19 | args := append([]interface{}(nil), config.Params...) 20 | var queryString strings.Builder 21 | queryString.WriteString("DELETE FROM ") 22 | queryString.WriteString(dialect.QuoteIdentifier(config.Table)) 23 | 24 | // WHERE 25 | where, args, err := dialect.buildWhere(config, args) 26 | if err != nil { 27 | return "", nil, err 28 | } 29 | if where != "" { 30 | queryString.WriteString(where) 31 | } 32 | 33 | // ORDER BY 34 | if len(config.Sort) > 0 { 35 | return "", nil, fmt.Errorf("rem: DELETE does not support ORDER BY") 36 | } 37 | 38 | // LIMIT 39 | if config.Limit != nil { 40 | return "", nil, fmt.Errorf("rem: DELETE does not support LIMIT") 41 | } 42 | 43 | // OFFSET 44 | if config.Offset != nil { 45 | return "", nil, fmt.Errorf("rem: DELETE does not support OFFSET") 46 | } 47 | 48 | return queryString.String(), args, nil 49 | } 50 | 51 | func (dialect PqDialect) BuildInsert(config rem.QueryConfig, rowMap map[string]interface{}, columns ...string) (string, []interface{}, error) { 52 | args := make([]interface{}, 0) 53 | var queryString strings.Builder 54 | 55 | queryString.WriteString("INSERT INTO ") 56 | queryString.WriteString(dialect.QuoteIdentifier(config.Table)) 57 | queryString.WriteString(" (") 58 | first := true 59 | for _, column := range columns { 60 | if arg, ok := rowMap[column]; ok { 61 | if _, ok := config.Fields[column]; !ok { 62 | return "", nil, fmt.Errorf("rem: field for column '%s' not found on model for table '%s'", column, config.Table) 63 | } 64 | args = append(args, arg) 65 | if first { 66 | first = false 67 | } else { 68 | queryString.WriteString(",") 69 | } 70 | queryString.WriteString(dialect.QuoteIdentifier(column)) 71 | } else { 72 | return "", nil, fmt.Errorf("rem: invalid column '%s' on INSERT", column) 73 | } 74 | } 75 | 76 | queryString.WriteString(") VALUES (") 77 | for i := 1; i <= len(rowMap); i++ { 78 | if i > 1 { 79 | queryString.WriteString(",") 80 | } 81 | queryString.WriteString(dialect.Param(i)) 82 | } 83 | queryString.WriteString(")") 84 | 85 | return queryString.String(), args, nil 86 | } 87 | 88 | func (dialect PqDialect) buildJoins(config rem.QueryConfig, args []interface{}) (string, []interface{}, error) { 89 | var queryPart strings.Builder 90 | if len(config.Joins) > 0 { 91 | for _, join := range config.Joins { 92 | if len(join.On) > 0 { 93 | queryPart.WriteString(fmt.Sprintf(" %s JOIN %s ON", join.Direction, dialect.QuoteIdentifier(join.Table))) 94 | for _, where := range join.On { 95 | queryWhere, whereArgs, err := where.StringWithArgs(dialect, args) 96 | if err != nil { 97 | return "", nil, err 98 | } 99 | args = whereArgs 100 | queryPart.WriteString(queryWhere) 101 | } 102 | } 103 | } 104 | } 105 | return queryPart.String(), args, nil 106 | } 107 | 108 | func (dialect PqDialect) BuildSelect(config rem.QueryConfig) (string, []interface{}, error) { 109 | args := append([]interface{}(nil), config.Params...) 110 | var queryString strings.Builder 111 | if config.Count { 112 | queryString.WriteString("SELECT count(*) FROM ") 113 | } else if len(config.Selected) > 0 { 114 | queryString.WriteString("SELECT ") 115 | for i, column := range config.Selected { 116 | if i > 0 { 117 | queryString.WriteString(",") 118 | } 119 | switch cv := column.(type) { 120 | case string: 121 | queryString.WriteString(dialect.QuoteIdentifier(cv)) 122 | 123 | case rem.DialectStringer: 124 | queryString.WriteString(cv.StringForDialect(dialect)) 125 | 126 | case fmt.Stringer: 127 | queryString.WriteString(cv.String()) 128 | 129 | default: 130 | return "", nil, fmt.Errorf("rem: invalid column type %#v", column) 131 | } 132 | } 133 | queryString.WriteString(" FROM ") 134 | } else { 135 | queryString.WriteString("SELECT * FROM ") 136 | } 137 | queryString.WriteString(dialect.QuoteIdentifier(config.Table)) 138 | 139 | // JOIN 140 | joins, args, err := dialect.buildJoins(config, args) 141 | if err != nil { 142 | return "", nil, err 143 | } 144 | if joins != "" { 145 | queryString.WriteString(joins) 146 | } 147 | 148 | // WHERE 149 | where, args, err := dialect.buildWhere(config, args) 150 | if err != nil { 151 | return "", nil, err 152 | } 153 | if where != "" { 154 | queryString.WriteString(where) 155 | } 156 | 157 | // ORDER BY 158 | if len(config.Sort) > 0 { 159 | queryString.WriteString(" ORDER BY ") 160 | for i, column := range config.Sort { 161 | if i > 0 { 162 | queryString.WriteString(", ") 163 | } 164 | if strings.HasPrefix(column, "-") { 165 | queryString.WriteString(dialect.QuoteIdentifier(column[1:])) 166 | queryString.WriteString(" DESC") 167 | } else { 168 | queryString.WriteString(dialect.QuoteIdentifier(column)) 169 | queryString.WriteString(" ASC") 170 | } 171 | } 172 | } 173 | 174 | // LIMIT 175 | if config.Limit != nil { 176 | args = append(args, config.Limit) 177 | queryString.WriteString(" LIMIT ") 178 | queryString.WriteString(dialect.Param(len(args))) 179 | } 180 | 181 | // OFFSET 182 | if config.Offset != nil { 183 | args = append(args, config.Offset) 184 | queryString.WriteString(" OFFSET ") 185 | queryString.WriteString(dialect.Param(len(args))) 186 | } 187 | 188 | return queryString.String(), args, nil 189 | } 190 | 191 | func (dialect PqDialect) BuildTableColumnAdd(config rem.QueryConfig, column string) (string, error) { 192 | field, ok := config.Fields[column] 193 | if !ok { 194 | return "", fmt.Errorf("rem: invalid column '%s' on model for table '%s'", column, config.Table) 195 | } 196 | 197 | columnType, err := dialect.ColumnType(field) 198 | if err != nil { 199 | return "", err 200 | } 201 | return fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s", dialect.QuoteIdentifier(config.Table), dialect.QuoteIdentifier(column), columnType), nil 202 | } 203 | 204 | func (dialect PqDialect) BuildTableColumnDrop(config rem.QueryConfig, column string) (string, error) { 205 | return fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", dialect.QuoteIdentifier(config.Table), dialect.QuoteIdentifier(column)), nil 206 | } 207 | 208 | func (dialect PqDialect) BuildTableCreate(config rem.QueryConfig, tableCreateConfig rem.TableCreateConfig) (string, error) { 209 | var sql strings.Builder 210 | sql.WriteString("CREATE TABLE ") 211 | if tableCreateConfig.IfNotExists { 212 | sql.WriteString("IF NOT EXISTS ") 213 | } 214 | sql.WriteString(dialect.QuoteIdentifier(config.Table)) 215 | sql.WriteString(" (") 216 | fieldNames := maps.Keys(config.Fields) 217 | sort.Strings(fieldNames) 218 | for i, fieldName := range fieldNames { 219 | field := config.Fields[fieldName] 220 | columnType, err := dialect.ColumnType(field) 221 | if err != nil { 222 | return "", err 223 | } 224 | if i > 0 { 225 | sql.WriteString(",") 226 | } 227 | sql.WriteString("\n\t") 228 | sql.WriteString(dialect.QuoteIdentifier(fieldName)) 229 | sql.WriteString(" ") 230 | sql.WriteString(columnType) 231 | } 232 | sql.WriteString("\n)") 233 | return sql.String(), nil 234 | } 235 | 236 | func (dialect PqDialect) BuildTableDrop(config rem.QueryConfig, tableDropConfig rem.TableDropConfig) (string, error) { 237 | var queryString strings.Builder 238 | queryString.WriteString("DROP TABLE ") 239 | if tableDropConfig.IfExists { 240 | queryString.WriteString("IF EXISTS ") 241 | } 242 | queryString.WriteString(dialect.QuoteIdentifier(config.Table)) 243 | return queryString.String(), nil 244 | } 245 | 246 | func (dialect PqDialect) BuildUpdate(config rem.QueryConfig, rowMap map[string]interface{}, columns ...string) (string, []interface{}, error) { 247 | args := append([]interface{}(nil), config.Params...) 248 | var queryString strings.Builder 249 | 250 | queryString.WriteString("UPDATE ") 251 | queryString.WriteString(dialect.QuoteIdentifier(config.Table)) 252 | queryString.WriteString(" SET ") 253 | 254 | first := true 255 | for _, column := range columns { 256 | if arg, ok := rowMap[column]; ok { 257 | args = append(args, arg) 258 | if first { 259 | first = false 260 | } else { 261 | queryString.WriteString(",") 262 | } 263 | queryString.WriteString(dialect.QuoteIdentifier(column)) 264 | queryString.WriteString(" = ") 265 | queryString.WriteString(dialect.Param(len(args))) 266 | } else { 267 | return "", nil, fmt.Errorf("rem: invalid column '%s' on UPDATE", column) 268 | } 269 | } 270 | 271 | // WHERE 272 | where, args, err := dialect.buildWhere(config, args) 273 | if err != nil { 274 | return "", nil, err 275 | } 276 | if where != "" { 277 | queryString.WriteString(where) 278 | } 279 | 280 | return queryString.String(), args, nil 281 | } 282 | 283 | func (dialect PqDialect) buildWhere(config rem.QueryConfig, args []interface{}) (string, []interface{}, error) { 284 | var queryPart strings.Builder 285 | if len(config.Filters) > 0 { 286 | queryPart.WriteString(" WHERE") 287 | for _, where := range config.Filters { 288 | queryWhere, whereArgs, err := where.StringWithArgs(dialect, args) 289 | if err != nil { 290 | return "", nil, err 291 | } 292 | args = whereArgs 293 | queryPart.WriteString(queryWhere) 294 | } 295 | } 296 | return queryPart.String(), args, nil 297 | } 298 | 299 | func (dialect PqDialect) ColumnType(field reflect.StructField) (string, error) { 300 | tagType := field.Tag.Get("db_type") 301 | if tagType != "" { 302 | return tagType, nil 303 | } 304 | 305 | fieldInstance := reflect.Indirect(reflect.New(field.Type)).Interface() 306 | var columnNull string 307 | var columnPrimary string 308 | var columnType string 309 | 310 | if field.Tag.Get("db_primary") == "true" { 311 | columnPrimary = " PRIMARY KEY" 312 | 313 | switch fieldInstance.(type) { 314 | case int, int64: 315 | columnNull = " NOT NULL" 316 | columnType = "BIGSERIAL" 317 | 318 | case int32: 319 | columnNull = " NOT NULL" 320 | columnType = "SERIAL" 321 | 322 | case int8, int16: 323 | columnNull = " NOT NULL" 324 | columnType = "SMALLSERIAL" 325 | } 326 | } 327 | 328 | if columnType == "" { 329 | switch fieldInstance.(type) { 330 | case bool: 331 | columnNull = " NOT NULL" 332 | columnType = "BOOLEAN" 333 | 334 | case sql.NullBool: 335 | columnNull = " NULL" 336 | columnType = "BOOLEAN" 337 | 338 | case float64: 339 | columnNull = " NOT NULL" 340 | columnType = "DOUBLE PRECISION" 341 | 342 | case sql.NullFloat64: 343 | columnNull = " NULL" 344 | columnType = "DOUBLE PRECISION" 345 | 346 | case int, int64: 347 | columnNull = " NOT NULL" 348 | columnType = "BIGINT" 349 | 350 | case sql.NullInt64: 351 | columnNull = " NULL" 352 | columnType = "BIGINT" 353 | 354 | case int32: 355 | columnNull = " NOT NULL" 356 | columnType = "INTEGER" 357 | 358 | case sql.NullInt32: 359 | columnNull = " NULL" 360 | columnType = "INTEGER" 361 | 362 | case int8, int16: 363 | columnNull = " NOT NULL" 364 | columnType = "SMALLINT" 365 | 366 | case sql.NullInt16: 367 | columnNull = " NULL" 368 | columnType = "SMALLINT" 369 | 370 | case string: 371 | columnNull = " NOT NULL" 372 | if tagMaxLength := field.Tag.Get("db_max_length"); tagMaxLength != "" { 373 | columnType = fmt.Sprintf("VARCHAR(%s)", tagMaxLength) 374 | } else { 375 | columnType = "TEXT" 376 | } 377 | 378 | case sql.NullString: 379 | columnNull = " NULL" 380 | if tagMaxLength := field.Tag.Get("db_max_length"); tagMaxLength != "" { 381 | columnType = fmt.Sprintf("VARCHAR(%s)", tagMaxLength) 382 | } else { 383 | columnType = "TEXT" 384 | } 385 | 386 | case time.Time: 387 | columnNull = " NOT NULL" 388 | if tagTimeZone := field.Tag.Get("db_time_zone"); tagTimeZone == "true" { 389 | columnType = "TIMESTAMP WITH TIME ZONE" 390 | } else { 391 | columnType = "TIMESTAMP WITHOUT TIME ZONE" 392 | } 393 | 394 | case sql.NullTime: 395 | columnNull = " NULL" 396 | if tagTimeZone := field.Tag.Get("db_time_zone"); tagTimeZone == "true" { 397 | columnType = "TIMESTAMP WITH TIME ZONE" 398 | } else { 399 | columnType = "TIMESTAMP WITHOUT TIME ZONE" 400 | } 401 | 402 | default: 403 | if strings.HasPrefix(field.Type.String(), "rem.ForeignKey[") || strings.HasPrefix(field.Type.String(), "rem.NullForeignKey[") { 404 | // Foreign keys. 405 | fv := reflect.New(field.Type).Elem() 406 | subModelQ := fv.Addr().MethodByName("Model").Call(nil) 407 | subFields := reflect.Indirect(subModelQ[0]).FieldByName("Fields").Interface().(map[string]reflect.StructField) 408 | subPrimaryColumn := reflect.Indirect(subModelQ[0]).FieldByName("PrimaryColumn").Interface().(string) 409 | subTable := reflect.Indirect(subModelQ[0]).FieldByName("Table").Interface().(string) 410 | columnTypeTemp, err := dialect.ColumnType(subFields[subPrimaryColumn]) 411 | if err != nil { 412 | return "", err 413 | } 414 | columnType = strings.SplitN(columnTypeTemp, " ", 2)[0] 415 | columnType = strings.Replace(columnType, "BIGSERIAL", "BIGINT", 1) 416 | columnType = strings.Replace(columnType, "SMALLSERIAL", "SMALLINT", 1) 417 | columnType = strings.Replace(columnType, "SERIAL", "INTEGER", 1) 418 | 419 | columnNull = " NOT NULL" 420 | if strings.HasPrefix(field.Type.String(), "rem.NullForeignKey[") { 421 | columnNull = " NULL" 422 | } 423 | columnNull = fmt.Sprintf("%s REFERENCES %s (%s)", columnNull, dialect.QuoteIdentifier(subTable), dialect.QuoteIdentifier(subPrimaryColumn)) 424 | 425 | if tagOnUpdate := field.Tag.Get("db_on_update"); tagOnUpdate != "" { 426 | // ON UPDATE. 427 | columnNull = fmt.Sprint(columnNull, " ON UPDATE ", tagOnUpdate) 428 | } 429 | 430 | if tagOnDelete := field.Tag.Get("db_on_delete"); tagOnDelete != "" { 431 | // ON DELETE. 432 | columnNull = fmt.Sprint(columnNull, " ON DELETE ", tagOnDelete) 433 | } 434 | } 435 | } 436 | } 437 | 438 | if columnType == "" { 439 | return "", fmt.Errorf("rem: Unsupported column type: %T. Use the 'db_type' field tag to define a SQL type", fieldInstance) 440 | } 441 | 442 | if tagDefault := field.Tag.Get("db_default"); tagDefault != "" { 443 | // DEFAULT. 444 | columnNull += " DEFAULT " + tagDefault 445 | } 446 | 447 | if tagUnique := field.Tag.Get("db_unique"); tagUnique == "true" { 448 | // UNIQUE. 449 | columnNull += " UNIQUE" 450 | } 451 | 452 | return fmt.Sprint(columnType, columnPrimary, columnNull), nil 453 | } 454 | 455 | func (dialect PqDialect) Param(identifier int) string { 456 | var query strings.Builder 457 | query.WriteString("$") 458 | query.WriteString(strconv.Itoa(identifier)) 459 | return query.String() 460 | } 461 | 462 | func (dialect PqDialect) QuoteIdentifier(identifier string) string { 463 | // 100-500ns all the way up to ~45us on early op for some reason. 464 | var query strings.Builder 465 | for i, part := range strings.Split(identifier, ".") { 466 | if i > 0 { 467 | query.WriteString(".") 468 | } 469 | query.WriteString(QuoteIdentifier(part)) 470 | } 471 | return query.String() 472 | } 473 | -------------------------------------------------------------------------------- /pqdialect/pqdialect_test.go: -------------------------------------------------------------------------------- 1 | package pqdialect 2 | 3 | import ( 4 | "database/sql" 5 | "sort" 6 | "testing" 7 | "time" 8 | 9 | "github.com/evantbyrne/rem" 10 | "golang.org/x/exp/maps" 11 | "golang.org/x/exp/slices" 12 | ) 13 | 14 | func TestAs(t *testing.T) { 15 | dialect := PqDialect{} 16 | expected := map[string]rem.SqlAs{ 17 | `"x" AS "alias1"`: rem.As("x", "alias1"), 18 | `"x" AS "y" AS "alias2"`: rem.As(rem.As("x", "y"), "alias2"), 19 | `count(*) AS "alias3"`: rem.As(rem.Unsafe("count(*)"), "alias3"), 20 | } 21 | for expected, alias := range expected { 22 | sql := alias.StringForDialect(dialect) 23 | if expected != sql { 24 | t.Errorf("Expected '%+v', got '%+v'", expected, sql) 25 | } 26 | } 27 | } 28 | 29 | func TestColumn(t *testing.T) { 30 | dialect := PqDialect{} 31 | expected := map[string]rem.SqlColumn{ 32 | `"x"`: rem.Column("x"), 33 | `"x"."y"`: rem.Column("x.y"), 34 | `"x"."y"."z"`: rem.Column("x.y.z"), 35 | `"x"""`: rem.Column(`x"`), 36 | } 37 | for expected, column := range expected { 38 | sql := column.StringForDialect(dialect) 39 | if expected != sql { 40 | t.Errorf("Expected '%+v', got '%+v'", expected, sql) 41 | } 42 | } 43 | } 44 | 45 | func TestBuildDelete(t *testing.T) { 46 | type testModel struct { 47 | Id int64 `db:"test_id" db_primary:"true"` 48 | Value1 string `db:"test_value_1" db_max_length:"100"` 49 | Value2 string `db:"test_value_2" db_max_length:"100"` 50 | } 51 | 52 | dialect := PqDialect{} 53 | model := rem.Use[testModel]() 54 | 55 | config := model.Query().Config 56 | config.Fields = model.Fields 57 | config.Table = "testmodel" 58 | expectedArgs := []interface{}{} 59 | expectedSql := `DELETE FROM "testmodel"` 60 | queryString, args, err := dialect.BuildDelete(config) 61 | if err != nil { 62 | t.Errorf("Unexpected error %s", err.Error()) 63 | } 64 | if queryString != expectedSql { 65 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 66 | } 67 | if !slices.Equal(args, expectedArgs) { 68 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 69 | } 70 | 71 | // WHERE 72 | config = model.Filter("test_id", "=", 1).Config 73 | config.Fields = model.Fields 74 | config.Table = "testmodel" 75 | expectedArgs = []interface{}{1} 76 | expectedSql = `DELETE FROM "testmodel" WHERE "test_id" = $1` 77 | queryString, args, err = dialect.BuildDelete(config) 78 | if err != nil { 79 | t.Errorf("Unexpected error %s", err.Error()) 80 | } 81 | if queryString != expectedSql { 82 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 83 | } 84 | if !slices.Equal(args, expectedArgs) { 85 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 86 | } 87 | } 88 | 89 | func TestBuildInsert(t *testing.T) { 90 | type testModel struct { 91 | Id int64 `db:"test_id" db_primary:"true"` 92 | Value1 string `db:"test_value_1" db_max_length:"100"` 93 | Value2 string `db:"test_value_2" db_max_length:"100"` 94 | } 95 | 96 | dialect := PqDialect{} 97 | model := rem.Use[testModel]() 98 | 99 | config := model.Query().Config 100 | config.Fields = model.Fields 101 | config.Table = "testmodel" 102 | expectedArgs := []interface{}{"foo", "bar"} 103 | expectedSql := `INSERT INTO "testmodel" ("test_value_1","test_value_2") VALUES ($1,$2)` 104 | queryString, args, err := dialect.BuildInsert(config, map[string]interface{}{ 105 | "test_value_1": "foo", 106 | "test_value_2": "bar", 107 | }, "test_value_1", "test_value_2") 108 | if err != nil { 109 | t.Errorf("Unexpected error %s", err.Error()) 110 | } 111 | if queryString != expectedSql { 112 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 113 | } 114 | if !slices.Equal(args, expectedArgs) { 115 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 116 | } 117 | } 118 | 119 | func TestBuildSelect(t *testing.T) { 120 | type testModel struct { 121 | Id int64 `db:"test_id" db_primary:"true"` 122 | Value1 string `db:"test_value_1" db_max_length:"100"` 123 | Value2 string `db:"test_value_2" db_max_length:"100"` 124 | } 125 | 126 | dialect := PqDialect{} 127 | model := rem.Use[testModel]() 128 | 129 | config := model.Query().Config 130 | config.Fields = model.Fields 131 | config.Table = "testmodel" 132 | expectedArgs := []interface{}{} 133 | expectedSql := `SELECT * FROM "testmodel"` 134 | queryString, args, err := dialect.BuildSelect(config) 135 | if err != nil { 136 | t.Errorf("Unexpected error %s", err.Error()) 137 | } 138 | if queryString != expectedSql { 139 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 140 | } 141 | if !slices.Equal(args, expectedArgs) { 142 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 143 | } 144 | 145 | // SELECT 146 | config = model.Select("id", "value1", rem.Unsafe(`count(1) as "count"`), rem.As("value2", "value3")).Config 147 | config.Fields = model.Fields 148 | config.Table = "testmodel" 149 | expectedArgs = []interface{}{} 150 | expectedSql = `SELECT "id","value1",count(1) as "count","value2" AS "value3" FROM "testmodel"` 151 | queryString, args, err = dialect.BuildSelect(config) 152 | if err != nil { 153 | t.Errorf("Unexpected error %s", err.Error()) 154 | } 155 | if queryString != expectedSql { 156 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 157 | } 158 | if !slices.Equal(args, expectedArgs) { 159 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 160 | } 161 | 162 | // WHERE 163 | config = model.Filter("id", "=", 1).Config 164 | config.Fields = model.Fields 165 | config.Table = "testmodel" 166 | expectedArgs = []interface{}{1} 167 | expectedSql = `SELECT * FROM "testmodel" WHERE "id" = $1` 168 | queryString, args, err = dialect.BuildSelect(config) 169 | if err != nil { 170 | t.Errorf("Unexpected error %s", err.Error()) 171 | } 172 | if queryString != expectedSql { 173 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 174 | } 175 | if !slices.Equal(args, expectedArgs) { 176 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 177 | } 178 | 179 | config = model.Filter("id", "IN", rem.Sql(rem.Param(1), ",", rem.Param(2))).Config 180 | config.Fields = model.Fields 181 | config.Table = "testmodel" 182 | expectedArgs = []interface{}{1, 2} 183 | expectedSql = `SELECT * FROM "testmodel" WHERE "id" IN ($1,$2)` 184 | queryString, args, err = dialect.BuildSelect(config) 185 | if err != nil { 186 | t.Errorf("Unexpected error %s", err.Error()) 187 | } 188 | if queryString != expectedSql { 189 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 190 | } 191 | if !slices.Equal(args, expectedArgs) { 192 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 193 | } 194 | 195 | // JOIN 196 | config = model.Select(rem.Unsafe("*")).Join("groups", rem.Or( 197 | rem.Q("groups.id", "=", rem.Column("accounts.group_id")), 198 | rem.Q("groups.id", "IS", nil))).Config 199 | config.Fields = model.Fields 200 | config.Table = "testmodel" 201 | expectedArgs = []interface{}{} 202 | expectedSql = `SELECT * FROM "testmodel" INNER JOIN "groups" ON ( "groups"."id" = "accounts"."group_id" OR "groups"."id" IS NULL )` 203 | queryString, args, err = dialect.BuildSelect(config) 204 | if err != nil { 205 | t.Errorf("Unexpected error %s", err.Error()) 206 | } 207 | if queryString != expectedSql { 208 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 209 | } 210 | if !slices.Equal(args, expectedArgs) { 211 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 212 | } 213 | 214 | // SORT 215 | config = model.Sort("test_id", "-test_value_1").Config 216 | config.Fields = model.Fields 217 | config.Table = "testmodel" 218 | expectedArgs = []interface{}{} 219 | expectedSql = `SELECT * FROM "testmodel" ORDER BY "test_id" ASC, "test_value_1" DESC` 220 | queryString, args, err = dialect.BuildSelect(config) 221 | if err != nil { 222 | t.Errorf("Unexpected error %s", err.Error()) 223 | } 224 | if queryString != expectedSql { 225 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 226 | } 227 | if !slices.Equal(args, expectedArgs) { 228 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 229 | } 230 | 231 | // LIMIT and OFFSET 232 | config = model.Filter("id", "=", 1).Offset(20).Limit(10).Config 233 | config.Fields = model.Fields 234 | config.Table = "testmodel" 235 | expectedArgs = []interface{}{1, 10, 20} 236 | expectedSql = `SELECT * FROM "testmodel" WHERE "id" = $1 LIMIT $2 OFFSET $3` 237 | queryString, args, err = dialect.BuildSelect(config) 238 | if err != nil { 239 | t.Errorf("Unexpected error %s", err.Error()) 240 | } 241 | if queryString != expectedSql { 242 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 243 | } 244 | if !slices.Equal(args, expectedArgs) { 245 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 246 | } 247 | } 248 | 249 | func TestBuildTableColumnAdd(t *testing.T) { 250 | type testModel struct { 251 | Value string `db:"test_value" db_max_length:"100"` 252 | } 253 | 254 | dialect := PqDialect{} 255 | model := rem.Use[testModel]() 256 | config := rem.QueryConfig{ 257 | Fields: model.Fields, 258 | Table: "testmodel", 259 | } 260 | expectedSql := `ALTER TABLE "testmodel" ADD COLUMN "test_value" VARCHAR(100) NOT NULL` 261 | queryString, err := dialect.BuildTableColumnAdd(config, "test_value") 262 | if err != nil { 263 | t.Errorf("Unexpected error %s", err.Error()) 264 | } 265 | if queryString != expectedSql { 266 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 267 | } 268 | } 269 | 270 | func TestBuildTableColumnDrop(t *testing.T) { 271 | dialect := PqDialect{} 272 | config := rem.QueryConfig{Table: "testmodel"} 273 | expectedSql := `ALTER TABLE "testmodel" DROP COLUMN "test_value"` 274 | queryString, err := dialect.BuildTableColumnDrop(config, "test_value") 275 | if err != nil { 276 | t.Errorf("Unexpected error %s", err.Error()) 277 | } 278 | if queryString != expectedSql { 279 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 280 | } 281 | } 282 | 283 | func TestBuildTableCreate(t *testing.T) { 284 | type testModel struct { 285 | Id int64 `db:"test_id" db_primary:"true"` 286 | Value1 string `db:"test_value_1" db_max_length:"100"` 287 | } 288 | 289 | dialect := PqDialect{} 290 | model := rem.Use[testModel]() 291 | config := rem.QueryConfig{ 292 | Fields: model.Fields, 293 | Table: "testmodel", 294 | } 295 | expectedSql := `CREATE TABLE "testmodel" ( 296 | "test_id" BIGSERIAL PRIMARY KEY NOT NULL, 297 | "test_value_1" VARCHAR(100) NOT NULL 298 | )` 299 | queryString, err := dialect.BuildTableCreate(config, rem.TableCreateConfig{}) 300 | if err != nil { 301 | t.Errorf("Unexpected error %s", err.Error()) 302 | } 303 | if queryString != expectedSql { 304 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 305 | } 306 | 307 | expectedSql = `CREATE TABLE IF NOT EXISTS "testmodel" ( 308 | "test_id" BIGSERIAL PRIMARY KEY NOT NULL, 309 | "test_value_1" VARCHAR(100) NOT NULL 310 | )` 311 | queryString, err = dialect.BuildTableCreate(config, rem.TableCreateConfig{IfNotExists: true}) 312 | if err != nil { 313 | t.Errorf("Unexpected error %s", err.Error()) 314 | } 315 | if queryString != expectedSql { 316 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 317 | } 318 | } 319 | 320 | func TestBuildTableDrop(t *testing.T) { 321 | dialect := PqDialect{} 322 | config := rem.QueryConfig{Table: "testmodel"} 323 | expectedSql := `DROP TABLE "testmodel"` 324 | queryString, err := dialect.BuildTableDrop(config, rem.TableDropConfig{}) 325 | if err != nil { 326 | t.Errorf("Unexpected error %s", err.Error()) 327 | } 328 | if queryString != expectedSql { 329 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 330 | } 331 | 332 | expectedSql = `DROP TABLE IF EXISTS "testmodel"` 333 | queryString, err = dialect.BuildTableDrop(config, rem.TableDropConfig{IfExists: true}) 334 | if err != nil { 335 | t.Errorf("Unexpected error %s", err.Error()) 336 | } 337 | if queryString != expectedSql { 338 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 339 | } 340 | } 341 | 342 | func TestBuildUpdate(t *testing.T) { 343 | type testModel struct { 344 | Id int64 `db:"test_id" db_primary:"true"` 345 | Value1 string `db:"test_value_1" db_max_length:"100"` 346 | Value2 string `db:"test_value_2" db_max_length:"100"` 347 | } 348 | 349 | dialect := PqDialect{} 350 | model := rem.Use[testModel]() 351 | config := model.Filter("test_id", "=", 1).Config 352 | config.Fields = model.Fields 353 | config.Table = "testmodel" 354 | 355 | expectedArgs := []interface{}{"foo", "bar", 1} 356 | expectedSql := `UPDATE "testmodel" SET "test_value_1" = $1,"test_value_2" = $2 WHERE "test_id" = $3` 357 | queryString, args, err := dialect.BuildUpdate(config, map[string]interface{}{ 358 | "id": 123, 359 | "test_value_1": "foo", 360 | "test_value_2": "bar", 361 | }, "test_value_1", "test_value_2") 362 | if err != nil { 363 | t.Errorf("Unexpected error %s", err.Error()) 364 | } 365 | if queryString != expectedSql { 366 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 367 | } 368 | if !slices.Equal(args, expectedArgs) { 369 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 370 | } 371 | } 372 | 373 | func TestColumnType(t *testing.T) { 374 | type testFkInt struct { 375 | Id int64 `db:"id" db_primary:"true"` 376 | } 377 | 378 | type testFkString struct { 379 | Id string `db:"id" db_primary:"true" db_max_length:"100"` 380 | } 381 | 382 | type testModel struct { 383 | BigInt int64 `db:"test_big_int"` 384 | BigIntNull sql.NullInt64 `db:"test_big_int_null"` 385 | Bool bool `db:"test_bool"` 386 | BoolNull sql.NullBool `db:"test_bool_null"` 387 | Custom []byte `db:"test_custom" db_type:"JSONB NOT NULL"` 388 | Default string `db:"test_default" db_default:"'foo'" db_max_length:"100"` 389 | Float float64 `db:"test_float"` 390 | FloatNull sql.NullFloat64 `db:"test_float_null"` 391 | Id int64 `db:"test_id" db_primary:"true"` 392 | Int int32 `db:"test_int"` 393 | IntNull sql.NullInt32 `db:"test_int_null"` 394 | SmallInt int16 `db:"test_small_int"` 395 | SmallIntNull sql.NullInt16 `db:"test_small_int_null"` 396 | Text string `db:"test_text"` 397 | TextNull sql.NullString `db:"test_text_null"` 398 | Time time.Time `db:"test_time"` 399 | TimeNow time.Time `db:"test_time_now" db_default:"now()"` 400 | TimeNull sql.NullTime `db:"test_time_null"` 401 | TimeZone time.Time `db:"test_time_zone" db_time_zone:"true"` 402 | Varchar string `db:"test_varchar" db_max_length:"100"` 403 | VarcharNull sql.NullString `db:"test_varchar_null" db_max_length:"50"` 404 | ForiegnKey rem.ForeignKey[testFkString] `db:"test_fk_id" db_on_delete:"CASCADE" db_on_update:"CASCADE"` 405 | ForiegnKeyNull rem.NullForeignKey[testFkInt] `db:"test_fk_null_id" db_on_delete:"SET NULL"` 406 | Unique string `db:"test_unique" db_max_length:"255" db_unique:"true"` 407 | } 408 | 409 | expected := map[string]string{ 410 | "test_big_int": "BIGINT NOT NULL", 411 | "test_big_int_null": "BIGINT NULL", 412 | "test_bool": "BOOLEAN NOT NULL", 413 | "test_bool_null": "BOOLEAN NULL", 414 | "test_custom": "JSONB NOT NULL", 415 | "test_default": "VARCHAR(100) NOT NULL DEFAULT 'foo'", 416 | "test_float": "DOUBLE PRECISION NOT NULL", 417 | "test_float_null": "DOUBLE PRECISION NULL", 418 | "test_id": "BIGSERIAL PRIMARY KEY NOT NULL", 419 | "test_int": "INTEGER NOT NULL", 420 | "test_int_null": "INTEGER NULL", 421 | "test_small_int": "SMALLINT NOT NULL", 422 | "test_small_int_null": "SMALLINT NULL", 423 | "test_time": "TIMESTAMP WITHOUT TIME ZONE NOT NULL", 424 | "test_time_now": "TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT now()", 425 | "test_time_null": "TIMESTAMP WITHOUT TIME ZONE NULL", 426 | "test_time_zone": "TIMESTAMP WITH TIME ZONE NOT NULL", 427 | "test_text": "TEXT NOT NULL", 428 | "test_text_null": "TEXT NULL", 429 | "test_varchar": "VARCHAR(100) NOT NULL", 430 | "test_varchar_null": "VARCHAR(50) NULL", 431 | "test_fk_id": `VARCHAR(100) NOT NULL REFERENCES "testfkstring" ("id") ON UPDATE CASCADE ON DELETE CASCADE`, 432 | "test_fk_null_id": `BIGINT NULL REFERENCES "testfkint" ("id") ON DELETE SET NULL`, 433 | "test_unique": `VARCHAR(255) NOT NULL UNIQUE`, 434 | } 435 | 436 | dialect := PqDialect{} 437 | model := rem.Use[testModel]() 438 | fieldKeys := maps.Keys(model.Fields) 439 | sort.Strings(fieldKeys) 440 | 441 | for _, fieldName := range fieldKeys { 442 | field := model.Fields[fieldName] 443 | columnType, err := dialect.ColumnType(field) 444 | if err != nil { 445 | t.Fatalf(`dialect.ColumnType() threw error for '%#v': %s`, field, err) 446 | } 447 | if columnType != expected[fieldName] { 448 | t.Fatalf(`dialect.ColumnType() returned '%s', but expected '%s' for '%#v'`, columnType, expected[fieldName], field) 449 | } 450 | } 451 | } 452 | -------------------------------------------------------------------------------- /pqdialect/pqlib.go: -------------------------------------------------------------------------------- 1 | package pqdialect 2 | 3 | import ( 4 | "strings" 5 | ) 6 | 7 | // QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be 8 | // used as part of an SQL statement. For example: 9 | // 10 | // tblname := "my_table" 11 | // data := "my_data" 12 | // quoted := pq.QuoteIdentifier(tblname) 13 | // err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data) 14 | // 15 | // Any double quotes in name will be escaped. The quoted identifier will be 16 | // case sensitive when used in a query. If the input string contains a zero 17 | // byte, the result will be truncated immediately before it. 18 | // 19 | // Via https://github.com/lib/pq v1.10.9. 20 | // Copyright (c) 2011-2013, 'pq' Contributors Portions Copyright (C) 2011 Blake Mizerany 21 | // Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 22 | // The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 23 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 24 | func QuoteIdentifier(name string) string { 25 | end := strings.IndexRune(name, 0) 26 | if end > -1 { 27 | name = name[:end] 28 | } 29 | return `"` + strings.Replace(name, `"`, `""`, -1) + `"` 30 | } 31 | -------------------------------------------------------------------------------- /query.go: -------------------------------------------------------------------------------- 1 | package rem 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "reflect" 8 | "strings" 9 | 10 | "golang.org/x/exp/maps" 11 | ) 12 | 13 | type JoinClause struct { 14 | Direction string 15 | On []FilterClause 16 | Table string 17 | } 18 | 19 | type QueryConfig struct { 20 | Count bool 21 | Context context.Context 22 | FetchRelated []string 23 | Fields map[string]reflect.StructField 24 | Filters []FilterClause 25 | Joins []JoinClause 26 | Limit interface{} 27 | Offset interface{} 28 | Params []interface{} 29 | Selected []interface{} 30 | Sort []string 31 | Table string 32 | Transaction *sql.Tx 33 | } 34 | 35 | type Query[T any] struct { 36 | Config QueryConfig 37 | Error error 38 | Model *Model[T] 39 | Rows *sql.Rows 40 | 41 | dialect Dialect 42 | } 43 | 44 | func (query *Query[T]) All(db *sql.DB) ([]*T, error) { 45 | query.detectDialect() 46 | query.configure() 47 | 48 | queryString, args, err := query.dialect.BuildSelect(query.Config) 49 | if err != nil { 50 | return make([]*T, 0), err 51 | } 52 | 53 | rows, err := query.dbQuery(db, queryString, args...) 54 | if err != nil { 55 | return make([]*T, 0), err 56 | } 57 | query.Rows = rows 58 | return query.slice(db) 59 | } 60 | 61 | func (query *Query[T]) AllToMap(db *sql.DB) ([]map[string]interface{}, error) { 62 | query.detectDialect() 63 | query.configure() 64 | 65 | queryString, args, err := query.dialect.BuildSelect(query.Config) 66 | if err != nil { 67 | return nil, err 68 | } 69 | 70 | rows, err := query.dbQuery(db, queryString, args...) 71 | if err != nil { 72 | return nil, err 73 | } 74 | query.Rows = rows 75 | defer query.Rows.Close() 76 | 77 | mapped := make([]map[string]interface{}, 0) 78 | for query.Rows.Next() { 79 | data, err := query.Model.ScanToMap(query.Rows) 80 | if err != nil { 81 | return nil, err 82 | } 83 | mapped = append(mapped, data) 84 | } 85 | 86 | if query.Config.Context != nil { 87 | select { 88 | default: 89 | case <-query.Config.Context.Done(): 90 | return nil, query.Config.Context.Err() 91 | } 92 | } 93 | 94 | return mapped, nil 95 | } 96 | 97 | func (query *Query[T]) configure() { 98 | query.Config.Fields = query.Model.Fields 99 | query.Config.Table = query.Model.Table 100 | } 101 | 102 | func (query *Query[T]) Context(context context.Context) *Query[T] { 103 | query.Config.Context = context 104 | return query 105 | } 106 | 107 | func (query *Query[T]) Count(db *sql.DB) (uint, error) { 108 | query.detectDialect() 109 | query.configure() 110 | 111 | var count uint 112 | 113 | query.Config.Count = true 114 | 115 | queryString, args, err := query.dialect.BuildSelect(query.Config) 116 | if err != nil { 117 | return count, err 118 | } 119 | 120 | if query.Config.Transaction != nil { 121 | if query.Config.Context != nil { 122 | err = query.Config.Transaction.QueryRowContext(query.Config.Context, queryString, args...).Scan(&count) 123 | } else { 124 | err = query.Config.Transaction.QueryRow(queryString, args...).Scan(&count) 125 | } 126 | } else if query.Config.Context != nil { 127 | err = db.QueryRowContext(query.Config.Context, queryString, args...).Scan(&count) 128 | } else { 129 | err = db.QueryRow(queryString, args...).Scan(&count) 130 | } 131 | if err != nil { 132 | return count, err 133 | } 134 | 135 | return count, nil 136 | } 137 | 138 | func (query *Query[T]) dbExec(db *sql.DB, queryString string, args ...interface{}) (sql.Result, error) { 139 | if query.Config.Transaction != nil { 140 | if query.Config.Context != nil { 141 | return query.Config.Transaction.ExecContext(query.Config.Context, queryString, args...) 142 | } 143 | return query.Config.Transaction.Exec(queryString, args...) 144 | } 145 | 146 | if query.Config.Context != nil { 147 | return db.ExecContext(query.Config.Context, queryString, args...) 148 | } 149 | return db.Exec(queryString, args...) 150 | } 151 | 152 | func (query *Query[T]) dbQuery(db *sql.DB, queryString string, args ...interface{}) (*sql.Rows, error) { 153 | if query.Config.Transaction != nil { 154 | if query.Config.Context != nil { 155 | return query.Config.Transaction.QueryContext(query.Config.Context, queryString, args...) 156 | } 157 | return query.Config.Transaction.Query(queryString, args...) 158 | } 159 | 160 | if query.Config.Context != nil { 161 | return db.QueryContext(query.Config.Context, queryString, args...) 162 | } 163 | return db.Query(queryString, args...) 164 | } 165 | 166 | func (query *Query[T]) Delete(db *sql.DB) (sql.Result, error) { 167 | query.detectDialect() 168 | query.configure() 169 | 170 | queryString, args, err := query.dialect.BuildDelete(query.Config) 171 | if err != nil { 172 | return nil, err 173 | } 174 | return query.dbExec(db, queryString, args...) 175 | } 176 | 177 | func (query *Query[T]) detectDialect() { 178 | if query.dialect == nil { 179 | if defaultDialect != nil { 180 | query.dialect = defaultDialect 181 | } else { 182 | panic("rem: no dialect registered. Use rem.SetDialect(dialect rem.Dialect) to register a default for SQL queries") 183 | } 184 | } 185 | } 186 | 187 | func (query *Query[T]) Dialect(dialect Dialect) *Query[T] { 188 | query.dialect = dialect 189 | return query 190 | } 191 | 192 | func (query *Query[T]) Exists(db *sql.DB) (bool, error) { 193 | query.detectDialect() 194 | query.configure() 195 | 196 | query.Config.Limit = 1 197 | 198 | queryString, args, err := query.dialect.BuildSelect(query.Config) 199 | if err != nil { 200 | return false, err 201 | } 202 | 203 | rows, err := query.dbQuery(db, queryString, args...) 204 | if err != nil { 205 | return false, err 206 | } 207 | return rows.Next(), nil 208 | } 209 | 210 | func (query *Query[T]) FetchRelated(columns ...string) *Query[T] { 211 | query.Config.FetchRelated = columns 212 | return query 213 | } 214 | 215 | func (query *Query[T]) Filter(column interface{}, operator string, value interface{}) *Query[T] { 216 | if len(query.Config.Filters) > 0 { 217 | query.Config.Filters = append(query.Config.Filters, FilterClause{Rule: "AND"}) 218 | } 219 | query.Config.Filters = append(query.Config.Filters, Q(column, operator, value)) 220 | return query 221 | } 222 | 223 | func (query *Query[T]) FilterAnd(clauses ...interface{}) *Query[T] { 224 | flat := make([]FilterClause, 0) 225 | for _, clause := range clauses { 226 | flat = flattenFilterClause(flat, clause) 227 | } 228 | 229 | if len(query.Config.Filters) > 0 { 230 | query.Config.Filters = append(query.Config.Filters, FilterClause{Rule: "AND"}) 231 | } 232 | query.Config.Filters = append(query.Config.Filters, FilterClause{Rule: "("}) 233 | indent := 0 234 | for i, clause := range flat { 235 | if i > 0 && indent == 0 && flat[i-1].Rule != "NOT" { 236 | query.Config.Filters = append(query.Config.Filters, FilterClause{Rule: "AND"}) 237 | } 238 | if clause.Rule == "(" { 239 | indent++ 240 | } else if clause.Rule == ")" { 241 | indent-- 242 | } 243 | query.Config.Filters = append(query.Config.Filters, clause) 244 | } 245 | 246 | query.Config.Filters = append(query.Config.Filters, FilterClause{Rule: ")"}) 247 | return query 248 | } 249 | 250 | func (query *Query[T]) FilterOr(clauses ...interface{}) *Query[T] { 251 | flat := make([]FilterClause, 0) 252 | for _, clause := range clauses { 253 | flat = flattenFilterClause(flat, clause) 254 | } 255 | 256 | if len(query.Config.Filters) > 0 { 257 | query.Config.Filters = append(query.Config.Filters, FilterClause{Rule: "AND"}) 258 | } 259 | query.Config.Filters = append(query.Config.Filters, FilterClause{Rule: "("}) 260 | indent := 0 261 | for i, clause := range flat { 262 | if i > 0 && indent == 0 && flat[i-1].Rule != "NOT" { 263 | query.Config.Filters = append(query.Config.Filters, FilterClause{Rule: "OR"}) 264 | } 265 | if clause.Rule == "(" { 266 | indent++ 267 | } else if clause.Rule == ")" { 268 | indent-- 269 | } 270 | query.Config.Filters = append(query.Config.Filters, clause) 271 | } 272 | 273 | query.Config.Filters = append(query.Config.Filters, FilterClause{Rule: ")"}) 274 | return query 275 | } 276 | 277 | func (query *Query[T]) First(db *sql.DB) (*T, error) { 278 | query.detectDialect() 279 | query.configure() 280 | 281 | query.Limit(1) 282 | 283 | queryString, args, err := query.dialect.BuildSelect(query.Config) 284 | if err != nil { 285 | return nil, err 286 | } 287 | 288 | rows, err := query.dbQuery(db, queryString, args...) 289 | if err != nil { 290 | return nil, err 291 | } 292 | query.Rows = rows 293 | 294 | defer query.Rows.Close() 295 | if query.Rows.Next() { 296 | return query.Model.Scan(query.Rows) 297 | } 298 | 299 | if query.Config.Context != nil { 300 | select { 301 | default: 302 | case <-query.Config.Context.Done(): 303 | return nil, query.Config.Context.Err() 304 | } 305 | } 306 | 307 | return nil, sql.ErrNoRows 308 | } 309 | 310 | func (query *Query[T]) FirstToMap(db *sql.DB) (map[string]interface{}, error) { 311 | query.detectDialect() 312 | query.configure() 313 | 314 | query.Limit(1) 315 | 316 | queryString, args, err := query.dialect.BuildSelect(query.Config) 317 | if err != nil { 318 | return nil, err 319 | } 320 | 321 | rows, err := query.dbQuery(db, queryString, args...) 322 | if err != nil { 323 | return nil, err 324 | } 325 | query.Rows = rows 326 | 327 | defer query.Rows.Close() 328 | if query.Rows.Next() { 329 | return query.Model.ScanToMap(query.Rows) 330 | } 331 | 332 | if query.Config.Context != nil { 333 | select { 334 | default: 335 | case <-query.Config.Context.Done(): 336 | return nil, query.Config.Context.Err() 337 | } 338 | } 339 | 340 | return nil, sql.ErrNoRows 341 | } 342 | 343 | func (query *Query[T]) Insert(db *sql.DB, row *T) (sql.Result, error) { 344 | query.detectDialect() 345 | query.configure() 346 | rowMap, err := query.Model.ToMap(row) 347 | if err != nil { 348 | return nil, err 349 | } 350 | queryString, args, err := query.dialect.BuildInsert(query.Config, rowMap, maps.Keys(rowMap)...) 351 | if err != nil { 352 | return nil, err 353 | } 354 | return query.dbExec(db, queryString, args...) 355 | } 356 | 357 | func (query *Query[T]) InsertMap(db *sql.DB, data map[string]interface{}) (sql.Result, error) { 358 | query.detectDialect() 359 | query.configure() 360 | queryString, args, err := query.dialect.BuildInsert(query.Config, data, maps.Keys(data)...) 361 | if err != nil { 362 | return nil, err 363 | } 364 | return query.dbExec(db, queryString, args...) 365 | } 366 | 367 | func (query *Query[T]) Join(table string, clauses ...interface{}) *Query[T] { 368 | flat := make([]FilterClause, 0) 369 | for _, clause := range clauses { 370 | flat = flattenFilterClause(flat, clause) 371 | } 372 | 373 | query.Config.Joins = append(query.Config.Joins, JoinClause{ 374 | Direction: "INNER", 375 | On: flat, 376 | Table: table, 377 | }) 378 | return query 379 | } 380 | 381 | func (query *Query[T]) JoinFull(table string, clauses ...interface{}) *Query[T] { 382 | flat := make([]FilterClause, 0) 383 | for _, clause := range clauses { 384 | flat = flattenFilterClause(flat, clause) 385 | } 386 | 387 | query.Config.Joins = append(query.Config.Joins, JoinClause{ 388 | Direction: "FULL", 389 | On: flat, 390 | Table: table, 391 | }) 392 | return query 393 | } 394 | 395 | func (query *Query[T]) JoinLeft(table string, clauses ...interface{}) *Query[T] { 396 | flat := make([]FilterClause, 0) 397 | for _, clause := range clauses { 398 | flat = flattenFilterClause(flat, clause) 399 | } 400 | 401 | query.Config.Joins = append(query.Config.Joins, JoinClause{ 402 | Direction: "LEFT", 403 | On: flat, 404 | Table: table, 405 | }) 406 | return query 407 | } 408 | 409 | func (query *Query[T]) JoinRight(table string, clauses ...interface{}) *Query[T] { 410 | flat := make([]FilterClause, 0) 411 | for _, clause := range clauses { 412 | flat = flattenFilterClause(flat, clause) 413 | } 414 | 415 | query.Config.Joins = append(query.Config.Joins, JoinClause{ 416 | Direction: "RIGHT", 417 | On: flat, 418 | Table: table, 419 | }) 420 | return query 421 | } 422 | 423 | func (query *Query[T]) Limit(limit interface{}) *Query[T] { 424 | query.Config.Limit = limit 425 | return query 426 | } 427 | 428 | func (query *Query[T]) Offset(offset interface{}) *Query[T] { 429 | query.Config.Offset = offset 430 | return query 431 | } 432 | 433 | func (query *Query[T]) Select(columns ...interface{}) *Query[T] { 434 | query.Config.Selected = columns 435 | return query 436 | } 437 | 438 | func (query *Query[T]) slice(db *sql.DB) ([]*T, error) { 439 | rows := make([]*T, 0) 440 | if query.Error != nil { 441 | return rows, query.Error 442 | } 443 | defer query.Rows.Close() 444 | 445 | relatedPks := make(map[string]relatedPk) 446 | for query.Rows.Next() { 447 | row, err := query.Model.Scan(query.Rows) 448 | if err != nil { 449 | return rows, err 450 | } 451 | if len(query.Config.FetchRelated) > 0 { 452 | value := reflect.ValueOf(*row) 453 | for _, column := range query.Config.FetchRelated { 454 | valueFk := value.FieldByName(column) 455 | if !valueFk.IsValid() { 456 | return rows, fmt.Errorf("rem: invalid field '%s' for fetching related. Field does not exist on model", column) 457 | } 458 | if strings.HasPrefix(valueFk.Type().String(), "rem.ForeignKey[") || strings.HasPrefix(valueFk.Type().String(), "rem.NullForeignKey[") { 459 | if valueFk.FieldByName("Valid").Interface().(bool) { 460 | r := reflect.New(valueFk.Type()).MethodByName("Model").Call(nil) 461 | rpk, ok := relatedPks[column] 462 | if !ok { 463 | rpk = relatedPk{ 464 | RelatedColumn: reflect.Indirect(r[0]).FieldByName("PrimaryColumn").Interface().(string), 465 | RelatedField: reflect.Indirect(r[0]).FieldByName("PrimaryField").Interface().(string), 466 | RelatedValues: make([]interface{}, 0), 467 | } 468 | } 469 | rpk.RelatedValues = append(rpk.RelatedValues, valueFk.FieldByName("Row").Elem().FieldByName(rpk.RelatedField).Interface()) 470 | relatedPks[column] = rpk 471 | } 472 | } else if strings.HasPrefix(valueFk.Type().String(), "rem.OneToMany[") { 473 | rpk, ok := relatedPks[column] 474 | if !ok { 475 | relatedColumn := valueFk.FieldByName("RelatedColumn").Interface().(string) 476 | r := reflect.New(valueFk.Type()).MethodByName("Model").Call(nil) 477 | fkModelFields := reflect.Indirect(r[0]).FieldByName("Fields").MapRange() 478 | var relatedField string 479 | for fkModelFields.Next() { 480 | fkModelField := fkModelFields.Value().FieldByName("Tag").MethodByName("Get").Call([]reflect.Value{reflect.ValueOf("db")})[0].Interface().(string) 481 | if fkModelField == relatedColumn { 482 | relatedField = fkModelFields.Value().FieldByName("Name").Interface().(string) 483 | break 484 | } 485 | } 486 | 487 | if relatedField == "" { 488 | return rows, fmt.Errorf("rem: invalid db tag of '%s' for fetching related on field '%s'. No fields with a matching column exist on the related model", relatedColumn, column) 489 | } 490 | 491 | rpk = relatedPk{ 492 | RelatedColumn: relatedColumn, 493 | RelatedField: relatedField, 494 | RelatedValues: make([]interface{}, 0), 495 | } 496 | } 497 | rpk.RelatedValues = append(rpk.RelatedValues, value.FieldByName(query.Model.PrimaryField).Interface()) 498 | relatedPks[column] = rpk 499 | } else { 500 | return rows, fmt.Errorf("rem: invalid field '%s' for fetching related. Field must be of type rem.ForeignKey[To], rem.NullForeignKey[To], or rem.OneToMany[To, From]", column) 501 | } 502 | } 503 | } 504 | rows = append(rows, row) 505 | } 506 | 507 | if len(relatedPks) > 0 { 508 | var temp T 509 | modelValue := reflect.ValueOf(&temp).Elem() 510 | 511 | for column, rpk := range relatedPks { 512 | if len(rpk.RelatedValues) > 0 { 513 | fk := reflect.New(modelValue.FieldByName(column).Type()) 514 | 515 | q := fk.MethodByName("Query").Call(nil) 516 | q = q[0].MethodByName("Filter").Call([]reflect.Value{ 517 | reflect.ValueOf(rpk.RelatedColumn), 518 | reflect.ValueOf("IN"), 519 | reflect.ValueOf(rpk.RelatedValues), 520 | }) 521 | q = q[0].MethodByName("All").Call([]reflect.Value{ 522 | reflect.ValueOf(db), 523 | }) 524 | rowsValue := reflect.ValueOf(rows) 525 | for i := 0; i < rowsValue.Len(); i++ { 526 | value := rowsValue.Index(i).Elem() 527 | valueFk := value.FieldByName(column) 528 | 529 | if strings.HasPrefix(fk.Type().String(), "*rem.OneToMany[") { 530 | mq := fk.MethodByName("Model").Call(nil) 531 | fkPrimaryField := reflect.Indirect(mq[0]).FieldByName("PrimaryField").Interface().(string) 532 | for j := 0; j < q[0].Len(); j++ { 533 | fkRow := q[0].Index(j).Elem() 534 | relatedFieldId := fkRow.FieldByName(rpk.RelatedField).FieldByName("Row").Elem().FieldByName(fkPrimaryField).Interface() 535 | if value.FieldByName(query.Model.PrimaryField).Interface() == relatedFieldId { 536 | valueFk.FieldByName("Rows").Set(reflect.Append(valueFk.FieldByName("Rows"), fkRow.Addr())) 537 | } 538 | } 539 | } else if valueFk.FieldByName("Valid").Interface().(bool) { 540 | mq := fk.MethodByName("Model").Call(nil) 541 | fkPrimaryField := reflect.Indirect(mq[0]).FieldByName("PrimaryField").Interface().(string) 542 | for j := 0; j < q[0].Len(); j++ { 543 | fkRow := q[0].Index(j) 544 | if valueFk.FieldByName("Row").Elem().FieldByName(query.Model.PrimaryField).Interface() == fkRow.Elem().FieldByName(fkPrimaryField).Interface() { 545 | valueFk.FieldByName("Row").Set(fkRow) 546 | break 547 | } 548 | } 549 | } 550 | } 551 | } 552 | } 553 | } 554 | 555 | if query.Config.Context != nil { 556 | select { 557 | default: 558 | case <-query.Config.Context.Done(): 559 | return nil, query.Config.Context.Err() 560 | } 561 | } 562 | 563 | return rows, nil 564 | } 565 | 566 | func (query *Query[T]) Sort(columns ...string) *Query[T] { 567 | query.Config.Sort = columns 568 | return query 569 | } 570 | 571 | func (query Query[T]) StringWithArgs(dialect Dialect, args []interface{}) (string, []interface{}, error) { 572 | query.dialect = dialect 573 | query.configure() 574 | query.Config.Params = args 575 | return query.dialect.BuildSelect(query.Config) 576 | } 577 | 578 | func (query *Query[T]) TableColumnAdd(db *sql.DB, column string) (sql.Result, error) { 579 | query.detectDialect() 580 | query.configure() 581 | queryString, err := query.dialect.BuildTableColumnAdd(query.Config, column) 582 | if err != nil { 583 | return nil, err 584 | } 585 | return query.dbExec(db, queryString) 586 | } 587 | 588 | func (query *Query[T]) TableColumnDrop(db *sql.DB, column string) (sql.Result, error) { 589 | query.detectDialect() 590 | query.configure() 591 | queryString, err := query.dialect.BuildTableColumnDrop(query.Config, column) 592 | if err != nil { 593 | return nil, err 594 | } 595 | return query.dbExec(db, queryString) 596 | } 597 | 598 | func (query *Query[T]) TableCreate(db *sql.DB, tableCreateConfig ...TableCreateConfig) (sql.Result, error) { 599 | query.detectDialect() 600 | query.configure() 601 | var config TableCreateConfig 602 | if len(tableCreateConfig) > 0 { 603 | config = tableCreateConfig[0] 604 | } 605 | queryString, err := query.dialect.BuildTableCreate(query.Config, config) 606 | if err != nil { 607 | return nil, err 608 | } 609 | return query.dbExec(db, queryString) 610 | } 611 | 612 | func (query *Query[T]) TableDrop(db *sql.DB, tableDropConfig ...TableDropConfig) (sql.Result, error) { 613 | query.detectDialect() 614 | query.configure() 615 | var config TableDropConfig 616 | if len(tableDropConfig) > 0 { 617 | config = tableDropConfig[0] 618 | } 619 | queryString, err := query.dialect.BuildTableDrop(query.Config, config) 620 | if err != nil { 621 | return nil, err 622 | } 623 | return query.dbExec(db, queryString) 624 | } 625 | 626 | func (query *Query[T]) Transaction(transaction *sql.Tx) *Query[T] { 627 | query.Config.Transaction = transaction 628 | return query 629 | } 630 | 631 | func (query *Query[T]) Update(db *sql.DB, row *T, columns ...string) (sql.Result, error) { 632 | query.detectDialect() 633 | query.configure() 634 | 635 | if len(columns) == 0 { 636 | return nil, fmt.Errorf("rem: no columns specified for update") 637 | } 638 | 639 | rowMap, err := query.Model.ToMap(row) 640 | if err != nil { 641 | return nil, err 642 | } 643 | 644 | queryString, args, err := query.dialect.BuildUpdate(query.Config, rowMap, columns...) 645 | if err != nil { 646 | return nil, err 647 | } 648 | return query.dbExec(db, queryString, args...) 649 | } 650 | 651 | func (query *Query[T]) UpdateMap(db *sql.DB, data map[string]interface{}) (sql.Result, error) { 652 | query.detectDialect() 653 | query.configure() 654 | 655 | if len(data) == 0 { 656 | return nil, fmt.Errorf("rem: no columns specified for update") 657 | } 658 | 659 | columns := make([]string, 0) 660 | for column := range data { 661 | columns = append(columns, column) 662 | } 663 | 664 | queryString, args, err := query.dialect.BuildUpdate(query.Config, data, columns...) 665 | if err != nil { 666 | return nil, err 667 | } 668 | return query.dbExec(db, queryString, args...) 669 | } 670 | 671 | type relatedPk struct { 672 | RelatedColumn string 673 | RelatedField string 674 | RelatedValues []interface{} 675 | } 676 | -------------------------------------------------------------------------------- /query_test.go: -------------------------------------------------------------------------------- 1 | package rem 2 | 3 | import ( 4 | "sort" 5 | "testing" 6 | 7 | "github.com/DATA-DOG/go-sqlmock" 8 | "golang.org/x/exp/maps" 9 | "golang.org/x/exp/slices" 10 | ) 11 | 12 | func TestQueryConfigure(t *testing.T) { 13 | type testModel struct { 14 | Id int64 `db:"test_id" db_primary:"true"` 15 | Value1 string `db:"test_value_1" db_max_length:"100"` 16 | Value2 string `db:"test_value_2" db_max_length:"100"` 17 | } 18 | 19 | query := Use[testModel]().Query() 20 | query.configure() 21 | columns := maps.Keys(query.Config.Fields) 22 | sort.Strings(columns) 23 | expectedColumns := []string{"test_id", "test_value_1", "test_value_2"} 24 | if !slices.Equal(columns, expectedColumns) { 25 | t.Errorf(`Expected '%+v', got '%+v'`, expectedColumns, columns) 26 | } 27 | if query.Config.Table != "testmodel" { 28 | t.Errorf(`Expected 'testmodel', got '%s'`, query.Config.Table) 29 | } 30 | } 31 | 32 | func TestQueryDetectDialect(t *testing.T) { 33 | type testModel struct { 34 | Id int64 `db:"test_id" db_primary:"true"` 35 | Value1 string `db:"test_value_1" db_max_length:"100"` 36 | Value2 string `db:"test_value_2" db_max_length:"100"` 37 | } 38 | 39 | query := &Query[testModel]{} 40 | 41 | defer func() { 42 | if r := recover(); r == nil { 43 | t.Errorf("Expected panic") 44 | } 45 | }() 46 | query.detectDialect() 47 | } 48 | 49 | func TestQueryFilters(t *testing.T) { 50 | type testModel struct { 51 | Id int64 `db:"test_id" db_primary:"true"` 52 | Value1 string `db:"test_value_1" db_max_length:"100"` 53 | Value2 string `db:"test_value_2" db_max_length:"100"` 54 | } 55 | 56 | query := Use[testModel]().Query().Filter("a", "=", "foo") 57 | expected := []FilterClause{ 58 | {Left: "a", Operator: "=", Right: "foo", Rule: "WHERE"}, 59 | } 60 | if !slices.Equal(query.Config.Filters, expected) { 61 | t.Errorf(`Expected '%+v', got '%+v'`, expected, query.Config.Filters) 62 | } 63 | 64 | query = query.Filter("b", "!=", "bar") 65 | expected = []FilterClause{ 66 | {Left: "a", Operator: "=", Right: "foo", Rule: "WHERE"}, 67 | {Rule: "AND"}, 68 | {Left: "b", Operator: "!=", Right: "bar", Rule: "WHERE"}, 69 | } 70 | if !slices.Equal(query.Config.Filters, expected) { 71 | t.Errorf(`Expected '%+v', got '%+v'`, expected, query.Config.Filters) 72 | } 73 | 74 | query = query.FilterOr( 75 | Q("c.one", "=", 1), 76 | Q("c.two", "=", 2), 77 | ) 78 | expected = []FilterClause{ 79 | {Left: "a", Operator: "=", Right: "foo", Rule: "WHERE"}, 80 | {Rule: "AND"}, 81 | {Left: "b", Operator: "!=", Right: "bar", Rule: "WHERE"}, 82 | {Rule: "AND"}, 83 | {Rule: "("}, 84 | {Left: "c.one", Operator: "=", Right: 1, Rule: "WHERE"}, 85 | {Rule: "OR"}, 86 | {Left: "c.two", Operator: "=", Right: 2, Rule: "WHERE"}, 87 | {Rule: ")"}, 88 | } 89 | if !slices.Equal(query.Config.Filters, expected) { 90 | t.Errorf(`Expected '%+v', got '%+v'`, expected, query.Config.Filters) 91 | } 92 | 93 | query = query.FilterAnd( 94 | Q("d.one", "IS", nil), 95 | Q("d.two", "IS NOT", nil), 96 | Q("d.three", "IN", []string{"foo", "bar", "baz"}), 97 | ) 98 | expected = []FilterClause{ 99 | {Left: "a", Operator: "=", Right: "foo", Rule: "WHERE"}, 100 | {Rule: "AND"}, 101 | {Left: "b", Operator: "!=", Right: "bar", Rule: "WHERE"}, 102 | {Rule: "AND"}, 103 | {Rule: "("}, 104 | {Left: "c.one", Operator: "=", Right: 1, Rule: "WHERE"}, 105 | {Rule: "OR"}, 106 | {Left: "c.two", Operator: "=", Right: 2, Rule: "WHERE"}, 107 | {Rule: ")"}, 108 | {Rule: "AND"}, 109 | {Rule: "("}, 110 | {Left: "d.one", Operator: "IS", Right: nil, Rule: "WHERE"}, 111 | {Rule: "AND"}, 112 | {Left: "d.two", Operator: "IS NOT", Right: nil, Rule: "WHERE"}, 113 | {Rule: "AND"}, 114 | {Left: "d.three", Operator: "IN", Right: []string{"foo", "bar", "baz"}, Rule: "WHERE"}, 115 | {Rule: ")"}, 116 | } 117 | if len(query.Config.Filters) != len(expected) || !slices.Equal(query.Config.Filters[:15], expected[:15]) { 118 | t.Errorf(`Expected '%+v', got '%+v'`, expected, query.Config.Filters) 119 | } 120 | if left := query.Config.Filters[15].Left; left != "d.three" { 121 | t.Errorf(`Expected '%s', got '%s'`, "d.three", left) 122 | } 123 | if right := query.Config.Filters[15].Right.([]string); !slices.Equal(right, []string{"foo", "bar", "baz"}) { 124 | t.Errorf(`Expected '%+v', got '%+v'`, []string{"foo", "bar", "baz"}, right) 125 | } 126 | if last := query.Config.Filters[len(query.Config.Filters)-1]; last.Rule != ")" { 127 | t.Errorf(`Expected '%+v', got '%+v'`, FilterClause{Rule: ")"}, last) 128 | } 129 | } 130 | 131 | func TestQueryJoins(t *testing.T) { 132 | type testGroups struct { 133 | Id int64 `db:"test_id" db_primary:"true"` 134 | Name string `db:"name" db_max_length:"100"` 135 | } 136 | type testAccounts struct { 137 | Id int64 `db:"id" db_primary:"true"` 138 | Group NullForeignKey[testGroups] `db:"group_id" db_on_delete:"SET NULL"` 139 | Name string `db:"name" db_max_length:"100"` 140 | } 141 | 142 | assertJoins := func(t *testing.T, expected, actual []JoinClause) { 143 | if len(actual) != len(expected) { 144 | t.Errorf(`Expected '%+v', got '%+v'`, expected, actual) 145 | } 146 | for i, join := range actual { 147 | if join.Direction != expected[i].Direction { 148 | t.Errorf(`Expected '%s', got '%s'`, expected[i].Direction, join.Direction) 149 | } 150 | if join.Table != expected[i].Table { 151 | t.Errorf(`Expected '%s', got '%s'`, expected[i].Table, join.Table) 152 | } 153 | if !slices.Equal(join.On, expected[i].On) { 154 | t.Errorf(`Expected '%+v', got '%+v'`, expected[i].On, join.On) 155 | } 156 | } 157 | } 158 | 159 | model := Use[testAccounts]() 160 | query := model.Query().Join("testgroups", Q("testgroups.id", "=", "testaccounts.group_id")) 161 | expected := []JoinClause{ 162 | { 163 | Direction: "INNER", 164 | Table: "testgroups", 165 | On: []FilterClause{ 166 | {Left: "testgroups.id", Operator: "=", Right: "testaccounts.group_id", Rule: "WHERE"}, 167 | }, 168 | }, 169 | } 170 | assertJoins(t, expected, query.Config.Joins) 171 | 172 | query = model.Query().JoinLeft("testgroups", Q("testgroups.id", "=", "testaccounts.group_id")) 173 | expected[0].Direction = "LEFT" 174 | assertJoins(t, expected, query.Config.Joins) 175 | 176 | query = model.Query().JoinRight("testgroups", Q("testgroups.id", "=", "testaccounts.group_id")) 177 | expected[0].Direction = "RIGHT" 178 | assertJoins(t, expected, query.Config.Joins) 179 | 180 | query = model.Query().JoinFull("testgroups", Or( 181 | Q("testaccounts.group_id", "=", "testgroups.id"), 182 | Q("testaccounts.group_id", "IS", nil), 183 | )) 184 | expected = []JoinClause{ 185 | { 186 | Direction: "FULL", 187 | Table: "testgroups", 188 | On: []FilterClause{ 189 | {Rule: "("}, 190 | {Left: "testaccounts.group_id", Operator: "=", Right: "testgroups.id", Rule: "WHERE"}, 191 | {Rule: "OR"}, 192 | {Left: "testaccounts.group_id", Operator: "IS", Right: nil, Rule: "WHERE"}, 193 | {Rule: ")"}, 194 | }, 195 | }, 196 | } 197 | assertJoins(t, expected, query.Config.Joins) 198 | } 199 | 200 | type testGroupsQuerySlice struct { 201 | Accounts OneToMany[testAccountsQuerySlice] `db:"group_id"` 202 | Id int64 `db:"id" db_primary:"true"` 203 | Name string `db:"name" db_max_length:"100"` 204 | } 205 | type testAccountsQuerySlice struct { 206 | Group NullForeignKey[testGroupsQuerySlice] `db:"group_id" db_on_delete:"SET NULL"` 207 | Id int64 `db:"id" db_primary:"true"` 208 | Name string `db:"name"` 209 | } 210 | 211 | func TestQuerySlice(t *testing.T) { 212 | defer func() { 213 | defaultDialect = nil 214 | }() 215 | SetDialect(testDialect{}) 216 | 217 | db, mock, err := sqlmock.New() 218 | if err != nil { 219 | t.Fatal("failed to open sqlmock database:", err) 220 | } 221 | defer db.Close() 222 | 223 | expect := func(t *testing.T, actual []*testAccountsQuerySlice, expected []testAccountsQuerySlice) { 224 | if len(actual) != len(expected) { 225 | t.Errorf(`Expected '%#v', got '%#v'`, expected, actual) 226 | } 227 | for i, account := range expected { 228 | if actual[i].Id != account.Id || 229 | actual[i].Name != account.Name || 230 | actual[i].Group.Valid != account.Group.Valid || 231 | (actual[i].Group.Valid && (actual[i].Group.Row.Id != account.Group.Row.Id || actual[i].Group.Row.Name != account.Group.Row.Name)) { 232 | t.Errorf(`Expected '%#v', got '%#v'`, account, *actual[i]) 233 | } 234 | } 235 | } 236 | 237 | query := Use[testAccountsQuerySlice]().Query() 238 | 239 | mock.ExpectQuery("SELECT"). 240 | WillReturnRows(sqlmock.NewRows([]string{"id", "name", "group_id"}). 241 | AddRow(1, "foo", 10). 242 | AddRow(2, "bar", nil). 243 | AddRow(3, "baz", nil)) 244 | rs, _ := db.Query("SELECT") 245 | defer rs.Close() 246 | query.Rows = rs 247 | actual, err := query.slice(db) 248 | if err != nil { 249 | t.Fatal("Unexpected error:", err) 250 | } 251 | expect(t, actual, []testAccountsQuerySlice{ 252 | {Id: 1, Name: "foo", Group: NullForeignKey[testGroupsQuerySlice]{Row: &testGroupsQuerySlice{Id: 10}, Valid: true}}, 253 | {Id: 2, Name: "bar"}, 254 | {Id: 3, Name: "baz"}, 255 | }) 256 | 257 | // Foreign keys. 258 | query.FetchRelated("Group") 259 | mock.ExpectQuery("SELECT"). 260 | WillReturnRows(sqlmock.NewRows([]string{"id", "name", "group_id"}). 261 | AddRow(1, "foo", 10). 262 | AddRow(2, "bar", 20). 263 | AddRow(3, "baz", nil)) 264 | rs, _ = db.Query("SELECT") 265 | defer rs.Close() 266 | query.Rows = rs 267 | 268 | mock.ExpectQuery(`SELECT|FILTER[{Left:id Operator:IN Right:[10,20] Rule:WHERE}]|`). 269 | WillReturnRows(sqlmock.NewRows([]string{"id", "name"}). 270 | AddRow(10, "Group 10"). 271 | AddRow(20, "Group 20")) 272 | 273 | actual, err = query.slice(db) 274 | if err != nil { 275 | t.Fatal("Unexpected error:", err) 276 | } 277 | expect(t, actual, []testAccountsQuerySlice{ 278 | {Id: 1, Name: "foo", Group: NullForeignKey[testGroupsQuerySlice]{Row: &testGroupsQuerySlice{Id: 10, Name: "Group 10"}, Valid: true}}, 279 | {Id: 2, Name: "bar", Group: NullForeignKey[testGroupsQuerySlice]{Row: &testGroupsQuerySlice{Id: 20, Name: "Group 20"}, Valid: true}}, 280 | {Id: 3, Name: "baz"}, 281 | }) 282 | 283 | // One to many. 284 | query2 := Use[testGroupsQuerySlice]().Query() 285 | query2.FetchRelated("Accounts") 286 | mock.ExpectQuery("SELECT"). 287 | WillReturnRows(sqlmock.NewRows([]string{"id", "name"}). 288 | AddRow(10, "Group 10"). 289 | AddRow(20, "Group 20")) 290 | 291 | rs2, _ := db.Query("SELECT") 292 | defer rs2.Close() 293 | query2.Rows = rs2 294 | 295 | mock.ExpectQuery(`SELECT|FILTER[{Left:id Operator:IN Right:[10,20] Rule:WHERE}]|`). 296 | WillReturnRows(sqlmock.NewRows([]string{"id", "name", "group_id"}). 297 | AddRow(1, "foo", 10). 298 | AddRow(2, "bar", 20). 299 | AddRow(3, "baz", 10)) 300 | 301 | actual2, err := query2.slice(db) 302 | if err != nil { 303 | t.Fatal("Unexpected error:", err) 304 | } 305 | expected2 := make([][]map[string]interface{}, 0) 306 | expected2 = append(expected2, []map[string]interface{}{ 307 | { 308 | "Id": int64(1), 309 | "Name": "foo", 310 | "GroupId": int64(10), 311 | }, 312 | { 313 | "Id": int64(3), 314 | "Name": "baz", 315 | "GroupId": int64(10), 316 | }, 317 | }) 318 | expected2 = append(expected2, []map[string]interface{}{ 319 | { 320 | "Id": int64(2), 321 | "Name": "bar", 322 | "GroupId": int64(20), 323 | }, 324 | }) 325 | if len(actual2) != len(expected2) { 326 | t.Errorf(`Expected 2 groups, got '%#v'`, actual2) 327 | } 328 | for i, expect2 := range expected2 { 329 | if len(actual2[i].Accounts.Rows) != len(expect2) { 330 | t.Errorf(`Expected '%#v' accounts, got '%#v'`, expect2, actual2[i].Accounts.Rows) 331 | } else { 332 | for j, actualAccount2 := range actual2[i].Accounts.Rows { 333 | if actualAccount2.Id != expect2[j]["Id"] || 334 | actualAccount2.Name != expect2[j]["Name"] || 335 | !actualAccount2.Group.Valid || 336 | actualAccount2.Group.Row.Id != expect2[j]["GroupId"] { 337 | t.Errorf(`Expected '%#v' accounts, got '%#v'`, expect2, actual2[i].Accounts.Rows) 338 | } 339 | } 340 | } 341 | } 342 | } 343 | -------------------------------------------------------------------------------- /sqlitedialect/sqlitedialect.go: -------------------------------------------------------------------------------- 1 | package sqlitedialect 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "reflect" 7 | "sort" 8 | "strings" 9 | "time" 10 | 11 | "github.com/evantbyrne/rem" 12 | "golang.org/x/exp/maps" 13 | ) 14 | 15 | type SqliteDialect struct { 16 | PreserveBooleans bool 17 | } 18 | 19 | func (dialect SqliteDialect) BuildDelete(config rem.QueryConfig) (string, []interface{}, error) { 20 | args := append([]interface{}(nil), config.Params...) 21 | var queryString strings.Builder 22 | queryString.WriteString("DELETE FROM ") 23 | queryString.WriteString(dialect.QuoteIdentifier(config.Table)) 24 | 25 | // WHERE 26 | where, args, err := dialect.buildWhere(config, args) 27 | if err != nil { 28 | return "", nil, err 29 | } 30 | if where != "" { 31 | queryString.WriteString(where) 32 | } 33 | 34 | // ORDER BY 35 | if len(config.Sort) > 0 { 36 | queryString.WriteString(" ORDER BY ") 37 | for i, column := range config.Sort { 38 | if i > 0 { 39 | queryString.WriteString(", ") 40 | } 41 | if strings.HasPrefix(column, "-") { 42 | queryString.WriteString(dialect.QuoteIdentifier(column[1:])) 43 | queryString.WriteString(" DESC") 44 | } else { 45 | queryString.WriteString(dialect.QuoteIdentifier(column)) 46 | queryString.WriteString(" ASC") 47 | } 48 | } 49 | } 50 | 51 | // LIMIT 52 | if config.Limit != nil { 53 | args = append(args, config.Limit) 54 | queryString.WriteString(" LIMIT ") 55 | queryString.WriteString(dialect.Param(len(args))) 56 | } 57 | 58 | // OFFSET 59 | if config.Offset != nil { 60 | return "", nil, fmt.Errorf("rem: DELETE does not support OFFSET") 61 | } 62 | 63 | return queryString.String(), args, nil 64 | } 65 | 66 | func (dialect SqliteDialect) BuildInsert(config rem.QueryConfig, rowMap map[string]interface{}, columns ...string) (string, []interface{}, error) { 67 | args := make([]interface{}, 0) 68 | var queryString strings.Builder 69 | 70 | queryString.WriteString("INSERT INTO ") 71 | queryString.WriteString(dialect.QuoteIdentifier(config.Table)) 72 | queryString.WriteString(" (") 73 | first := true 74 | for _, column := range columns { 75 | if arg, ok := rowMap[column]; ok { 76 | if _, ok := config.Fields[column]; !ok { 77 | return "", nil, fmt.Errorf("rem: field for column '%s' not found on model for table '%s'", column, config.Table) 78 | } 79 | args = append(args, arg) 80 | if first { 81 | first = false 82 | } else { 83 | queryString.WriteString(",") 84 | } 85 | queryString.WriteString(dialect.QuoteIdentifier(column)) 86 | } else { 87 | return "", nil, fmt.Errorf("rem: invalid column '%s' on INSERT", column) 88 | } 89 | } 90 | 91 | queryString.WriteString(") VALUES (") 92 | for i := 1; i <= len(rowMap); i++ { 93 | if i > 1 { 94 | queryString.WriteString(",") 95 | } 96 | queryString.WriteString(dialect.Param(i)) 97 | } 98 | queryString.WriteString(")") 99 | 100 | return queryString.String(), args, nil 101 | } 102 | 103 | func (dialect SqliteDialect) buildJoins(config rem.QueryConfig, args []interface{}) (string, []interface{}, error) { 104 | var queryPart strings.Builder 105 | if len(config.Joins) > 0 { 106 | for _, join := range config.Joins { 107 | if len(join.On) > 0 { 108 | queryPart.WriteString(fmt.Sprintf(" %s JOIN %s ON", join.Direction, dialect.QuoteIdentifier(join.Table))) 109 | for _, where := range join.On { 110 | queryWhere, whereArgs, err := where.StringWithArgs(dialect, args) 111 | if err != nil { 112 | return "", nil, err 113 | } 114 | args = whereArgs 115 | queryPart.WriteString(queryWhere) 116 | } 117 | } 118 | } 119 | } 120 | return queryPart.String(), args, nil 121 | } 122 | 123 | func (dialect SqliteDialect) BuildSelect(config rem.QueryConfig) (string, []interface{}, error) { 124 | args := append([]interface{}(nil), config.Params...) 125 | var queryString strings.Builder 126 | if config.Count { 127 | queryString.WriteString("SELECT count(*) FROM ") 128 | } else if len(config.Selected) > 0 { 129 | queryString.WriteString("SELECT ") 130 | for i, column := range config.Selected { 131 | if i > 0 { 132 | queryString.WriteString(",") 133 | } 134 | switch cv := column.(type) { 135 | case string: 136 | queryString.WriteString(dialect.QuoteIdentifier(cv)) 137 | 138 | case rem.DialectStringer: 139 | queryString.WriteString(cv.StringForDialect(dialect)) 140 | 141 | case fmt.Stringer: 142 | queryString.WriteString(cv.String()) 143 | 144 | default: 145 | return "", nil, fmt.Errorf("rem: invalid column type %#v", column) 146 | } 147 | } 148 | queryString.WriteString(" FROM ") 149 | } else { 150 | queryString.WriteString("SELECT * FROM ") 151 | } 152 | queryString.WriteString(dialect.QuoteIdentifier(config.Table)) 153 | 154 | // JOIN 155 | joins, args, err := dialect.buildJoins(config, args) 156 | if err != nil { 157 | return "", nil, err 158 | } 159 | if joins != "" { 160 | queryString.WriteString(joins) 161 | } 162 | 163 | // WHERE 164 | where, args, err := dialect.buildWhere(config, args) 165 | if err != nil { 166 | return "", nil, err 167 | } 168 | if where != "" { 169 | queryString.WriteString(where) 170 | } 171 | 172 | // ORDER BY 173 | if len(config.Sort) > 0 { 174 | queryString.WriteString(" ORDER BY ") 175 | for i, column := range config.Sort { 176 | if i > 0 { 177 | queryString.WriteString(", ") 178 | } 179 | if strings.HasPrefix(column, "-") { 180 | queryString.WriteString(dialect.QuoteIdentifier(column[1:])) 181 | queryString.WriteString(" DESC") 182 | } else { 183 | queryString.WriteString(dialect.QuoteIdentifier(column)) 184 | queryString.WriteString(" ASC") 185 | } 186 | } 187 | } 188 | 189 | // LIMIT 190 | if config.Limit != nil { 191 | args = append(args, config.Limit) 192 | queryString.WriteString(" LIMIT ") 193 | queryString.WriteString(dialect.Param(len(args))) 194 | } 195 | 196 | // OFFSET 197 | if config.Offset != nil { 198 | args = append(args, config.Offset) 199 | queryString.WriteString(" OFFSET ") 200 | queryString.WriteString(dialect.Param(len(args))) 201 | } 202 | 203 | return queryString.String(), args, nil 204 | } 205 | 206 | func (dialect SqliteDialect) BuildTableColumnAdd(config rem.QueryConfig, column string) (string, error) { 207 | field, ok := config.Fields[column] 208 | if !ok { 209 | return "", fmt.Errorf("rem: invalid column '%s' on model for table '%s'", column, config.Table) 210 | } 211 | 212 | columnType, err := dialect.ColumnType(field) 213 | if err != nil { 214 | return "", err 215 | } 216 | return fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s", dialect.QuoteIdentifier(config.Table), dialect.QuoteIdentifier(column), columnType), nil 217 | } 218 | 219 | func (dialect SqliteDialect) BuildTableColumnDrop(config rem.QueryConfig, column string) (string, error) { 220 | return fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", dialect.QuoteIdentifier(config.Table), dialect.QuoteIdentifier(column)), nil 221 | } 222 | 223 | func (dialect SqliteDialect) BuildTableCreate(config rem.QueryConfig, tableCreateConfig rem.TableCreateConfig) (string, error) { 224 | var sql strings.Builder 225 | sql.WriteString("CREATE TABLE ") 226 | if tableCreateConfig.IfNotExists { 227 | sql.WriteString("IF NOT EXISTS ") 228 | } 229 | sql.WriteString(dialect.QuoteIdentifier(config.Table)) 230 | sql.WriteString(" (") 231 | fieldNames := maps.Keys(config.Fields) 232 | sort.Strings(fieldNames) 233 | for i, fieldName := range fieldNames { 234 | field := config.Fields[fieldName] 235 | columnType, err := dialect.ColumnType(field) 236 | if err != nil { 237 | return "", err 238 | } 239 | if i > 0 { 240 | sql.WriteString(",") 241 | } 242 | sql.WriteString("\n\t") 243 | sql.WriteString(dialect.QuoteIdentifier(fieldName)) 244 | sql.WriteString(" ") 245 | sql.WriteString(columnType) 246 | } 247 | sql.WriteString("\n)") 248 | 249 | return sql.String(), nil 250 | } 251 | 252 | func (dialect SqliteDialect) BuildTableDrop(config rem.QueryConfig, tableDropConfig rem.TableDropConfig) (string, error) { 253 | var queryString strings.Builder 254 | queryString.WriteString("DROP TABLE ") 255 | if tableDropConfig.IfExists { 256 | queryString.WriteString("IF EXISTS ") 257 | } 258 | queryString.WriteString(dialect.QuoteIdentifier(config.Table)) 259 | return queryString.String(), nil 260 | } 261 | 262 | func (dialect SqliteDialect) BuildUpdate(config rem.QueryConfig, rowMap map[string]interface{}, columns ...string) (string, []interface{}, error) { 263 | args := append([]interface{}(nil), config.Params...) 264 | var queryString strings.Builder 265 | 266 | queryString.WriteString("UPDATE ") 267 | queryString.WriteString(dialect.QuoteIdentifier(config.Table)) 268 | queryString.WriteString(" SET ") 269 | 270 | first := true 271 | for _, column := range columns { 272 | if arg, ok := rowMap[column]; ok { 273 | args = append(args, arg) 274 | if first { 275 | first = false 276 | } else { 277 | queryString.WriteString(",") 278 | } 279 | queryString.WriteString(dialect.QuoteIdentifier(column)) 280 | queryString.WriteString(" = ") 281 | queryString.WriteString(dialect.Param(len(args))) 282 | } else { 283 | return "", nil, fmt.Errorf("rem: invalid column '%s' on UPDATE", column) 284 | } 285 | } 286 | 287 | // WHERE 288 | where, args, err := dialect.buildWhere(config, args) 289 | if err != nil { 290 | return "", nil, err 291 | } 292 | if where != "" { 293 | queryString.WriteString(where) 294 | } 295 | 296 | // ORDER BY 297 | if len(config.Sort) > 0 { 298 | queryString.WriteString(" ORDER BY ") 299 | for i, column := range config.Sort { 300 | if i > 0 { 301 | queryString.WriteString(", ") 302 | } 303 | if strings.HasPrefix(column, "-") { 304 | queryString.WriteString(dialect.QuoteIdentifier(column[1:])) 305 | queryString.WriteString(" DESC") 306 | } else { 307 | queryString.WriteString(dialect.QuoteIdentifier(column)) 308 | queryString.WriteString(" ASC") 309 | } 310 | } 311 | } 312 | 313 | // LIMIT 314 | if config.Limit != nil { 315 | args = append(args, config.Limit) 316 | queryString.WriteString(" LIMIT ") 317 | queryString.WriteString(dialect.Param(len(args))) 318 | } 319 | 320 | // OFFSET 321 | if config.Offset != nil { 322 | return "", nil, fmt.Errorf("rem: UPDATE does not support OFFSET") 323 | } 324 | 325 | return queryString.String(), args, nil 326 | } 327 | 328 | func (dialect SqliteDialect) buildWhere(config rem.QueryConfig, args []interface{}) (string, []interface{}, error) { 329 | var queryPart strings.Builder 330 | if len(config.Filters) > 0 { 331 | queryPart.WriteString(" WHERE") 332 | for _, where := range config.Filters { 333 | queryWhere, whereArgs, err := where.StringWithArgs(dialect, args) 334 | if err != nil { 335 | return "", nil, err 336 | } 337 | args = whereArgs 338 | queryPart.WriteString(queryWhere) 339 | } 340 | } 341 | return queryPart.String(), args, nil 342 | } 343 | 344 | func (dialect SqliteDialect) ColumnType(field reflect.StructField) (string, error) { 345 | tagType := field.Tag.Get("db_type") 346 | if tagType != "" { 347 | return tagType, nil 348 | } 349 | 350 | fieldInstance := reflect.Indirect(reflect.New(field.Type)).Interface() 351 | var columnNull string 352 | var columnPrimary string 353 | var columnType string 354 | 355 | if field.Tag.Get("db_primary") == "true" { 356 | columnPrimary = " PRIMARY KEY" 357 | } 358 | 359 | switch fieldInstance.(type) { 360 | case bool: 361 | columnNull = " NOT NULL" 362 | columnType = "BOOLEAN" 363 | 364 | case int, int8, int16, int32, int64: 365 | columnNull = " NOT NULL" 366 | columnType = "INTEGER" 367 | 368 | case sql.NullBool: 369 | columnNull = " NULL" 370 | columnType = "BOOLEAN" 371 | 372 | case sql.NullInt16, sql.NullInt32, sql.NullInt64: 373 | columnNull = " NULL" 374 | columnType = "INTEGER" 375 | 376 | case float32, float64: 377 | columnNull = " NOT NULL" 378 | columnType = "REAL" 379 | 380 | case sql.NullFloat64: 381 | columnNull = " NULL" 382 | columnType = "REAL" 383 | 384 | case string: 385 | columnNull = " NOT NULL" 386 | columnType = "TEXT" 387 | 388 | case time.Time: 389 | columnNull = " NOT NULL" 390 | columnType = "DATETIME" 391 | 392 | case sql.NullString: 393 | columnNull = " NULL" 394 | columnType = "TEXT" 395 | 396 | case sql.NullTime: 397 | columnNull = " NULL" 398 | columnType = "DATETIME" 399 | 400 | default: 401 | if strings.HasPrefix(field.Type.String(), "rem.ForeignKey[") || strings.HasPrefix(field.Type.String(), "rem.NullForeignKey[") { 402 | // Foreign keys. 403 | fv := reflect.New(field.Type).Elem() 404 | subModelQ := fv.Addr().MethodByName("Model").Call(nil) 405 | subFields := reflect.Indirect(subModelQ[0]).FieldByName("Fields").Interface().(map[string]reflect.StructField) 406 | subPrimaryColumn := reflect.Indirect(subModelQ[0]).FieldByName("PrimaryColumn").Interface().(string) 407 | subTable := reflect.Indirect(subModelQ[0]).FieldByName("Table").Interface().(string) 408 | columnTypeTemp, err := dialect.ColumnType(subFields[subPrimaryColumn]) 409 | if err != nil { 410 | return "", err 411 | } 412 | columnType = strings.SplitN(columnTypeTemp, " ", 2)[0] 413 | 414 | columnNull = " NOT NULL" 415 | if strings.HasPrefix(field.Type.String(), "rem.NullForeignKey[") { 416 | columnNull = " NULL" 417 | } 418 | columnNull = fmt.Sprintf("%s REFERENCES %s (%s)", columnNull, dialect.QuoteIdentifier(subTable), dialect.QuoteIdentifier(subPrimaryColumn)) 419 | 420 | if tagOnUpdate := field.Tag.Get("db_on_update"); tagOnUpdate != "" { 421 | // ON UPDATE. 422 | columnNull = fmt.Sprint(columnNull, " ON UPDATE ", tagOnUpdate) 423 | } 424 | 425 | if tagOnDelete := field.Tag.Get("db_on_delete"); tagOnDelete != "" { 426 | // ON DELETE. 427 | columnNull = fmt.Sprint(columnNull, " ON DELETE ", tagOnDelete) 428 | } 429 | } 430 | } 431 | 432 | if columnType == "" { 433 | return "", fmt.Errorf("rem: Unsupported column type: %T. Use the 'db_type' field tag to define a SQL type", fieldInstance) 434 | } 435 | 436 | if tagDefault := field.Tag.Get("db_default"); tagDefault != "" { 437 | // DEFAULT. 438 | columnNull += " DEFAULT " + tagDefault 439 | } 440 | 441 | if tagUnique := field.Tag.Get("db_unique"); tagUnique == "true" { 442 | // UNIQUE. 443 | columnNull += " UNIQUE" 444 | } 445 | 446 | return fmt.Sprint(columnType, columnPrimary, columnNull), nil 447 | } 448 | 449 | func (dialect SqliteDialect) Param(identifier int) string { 450 | return "?" 451 | } 452 | 453 | func (dialect SqliteDialect) QuoteIdentifier(identifier string) string { 454 | var query strings.Builder 455 | for i, part := range strings.Split(identifier, ".") { 456 | if i > 0 { 457 | query.WriteString(".") 458 | } 459 | query.WriteString(QuoteIdentifier(part)) 460 | } 461 | return query.String() 462 | } 463 | 464 | func QuoteIdentifier(identifier string) string { 465 | return "`" + strings.Replace(identifier, "`", "``", -1) + "`" 466 | } 467 | -------------------------------------------------------------------------------- /sqlitedialect/sqlitedialect_test.go: -------------------------------------------------------------------------------- 1 | package sqlitedialect 2 | 3 | import ( 4 | "database/sql" 5 | "sort" 6 | "testing" 7 | "time" 8 | 9 | "github.com/evantbyrne/rem" 10 | "golang.org/x/exp/maps" 11 | "golang.org/x/exp/slices" 12 | ) 13 | 14 | func TestAs(t *testing.T) { 15 | dialect := SqliteDialect{} 16 | expected := map[string]rem.SqlAs{ 17 | "`x` AS `alias1`": rem.As("x", "alias1"), 18 | "`x` AS `y` AS `alias2`": rem.As(rem.As("x", "y"), "alias2"), 19 | "count(*) AS `alias3`": rem.As(rem.Unsafe("count(*)"), "alias3"), 20 | } 21 | for expected, alias := range expected { 22 | sql := alias.StringForDialect(dialect) 23 | if expected != sql { 24 | t.Errorf("Expected '%+v', got '%+v'", expected, sql) 25 | } 26 | } 27 | } 28 | 29 | func TestColumn(t *testing.T) { 30 | dialect := SqliteDialect{} 31 | expected := map[string]rem.SqlColumn{ 32 | "`x`": rem.Column("x"), 33 | "`x`.`y`": rem.Column("x.y"), 34 | "`x`.`y`.`z`": rem.Column("x.y.z"), 35 | "`x```": rem.Column("x`"), 36 | } 37 | for expected, column := range expected { 38 | sql := column.StringForDialect(dialect) 39 | if expected != sql { 40 | t.Errorf("Expected '%+v', got '%+v'", expected, sql) 41 | } 42 | } 43 | } 44 | 45 | func TestBuildDelete(t *testing.T) { 46 | type testModel struct { 47 | Id int64 `db:"test_id" db_primary:"true"` 48 | Value1 string `db:"test_value_1" db_max_length:"100"` 49 | Value2 string `db:"test_value_2" db_max_length:"100"` 50 | } 51 | 52 | dialect := SqliteDialect{} 53 | model := rem.Use[testModel]() 54 | 55 | query := model.Query() 56 | config := query.Config 57 | config.Fields = model.Fields 58 | config.Table = "testmodel" 59 | expectedArgs := []interface{}{} 60 | expectedSql := "DELETE FROM `testmodel`" 61 | queryString, args, err := dialect.BuildDelete(config) 62 | if err != nil { 63 | t.Errorf("Unexpected error %s", err.Error()) 64 | } 65 | if queryString != expectedSql { 66 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 67 | } 68 | if !slices.Equal(args, expectedArgs) { 69 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 70 | } 71 | 72 | // WHERE 73 | query.Filter("test_id", "=", 1) 74 | config = query.Config 75 | config.Fields = model.Fields 76 | config.Table = "testmodel" 77 | expectedArgs = []interface{}{1} 78 | expectedSql = "DELETE FROM `testmodel` WHERE `test_id` = ?" 79 | queryString, args, err = dialect.BuildDelete(config) 80 | if err != nil { 81 | t.Errorf("Unexpected error %s", err.Error()) 82 | } 83 | if queryString != expectedSql { 84 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 85 | } 86 | if !slices.Equal(args, expectedArgs) { 87 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 88 | } 89 | 90 | // ORDER BY 91 | query.Sort("-test_id") 92 | config = query.Config 93 | config.Fields = model.Fields 94 | config.Table = "testmodel" 95 | expectedSql = "DELETE FROM `testmodel` WHERE `test_id` = ? ORDER BY `test_id` DESC" 96 | queryString, args, err = dialect.BuildDelete(config) 97 | if err != nil { 98 | t.Errorf("Unexpected error %s", err.Error()) 99 | } 100 | if queryString != expectedSql { 101 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 102 | } 103 | if !slices.Equal(args, expectedArgs) { 104 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 105 | } 106 | 107 | // LIMIT 108 | query.Limit(3) 109 | config = query.Config 110 | config.Fields = model.Fields 111 | config.Table = "testmodel" 112 | expectedArgs = []interface{}{1, 3} 113 | expectedSql = "DELETE FROM `testmodel` WHERE `test_id` = ? ORDER BY `test_id` DESC LIMIT ?" 114 | queryString, args, err = dialect.BuildDelete(config) 115 | if err != nil { 116 | t.Errorf("Unexpected error %s", err.Error()) 117 | } 118 | if queryString != expectedSql { 119 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 120 | } 121 | if !slices.Equal(args, expectedArgs) { 122 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 123 | } 124 | } 125 | 126 | func TestBuildInsert(t *testing.T) { 127 | type testModel struct { 128 | Id int64 `db:"test_id" db_primary:"true"` 129 | Value1 string `db:"test_value_1" db_max_length:"100"` 130 | Value2 string `db:"test_value_2" db_max_length:"100"` 131 | } 132 | 133 | dialect := SqliteDialect{} 134 | model := rem.Use[testModel]() 135 | 136 | config := model.Query().Config 137 | config.Fields = model.Fields 138 | config.Table = "testmodel" 139 | expectedArgs := []interface{}{"foo", "bar"} 140 | expectedSql := "INSERT INTO `testmodel` (`test_value_1`,`test_value_2`) VALUES (?,?)" 141 | queryString, args, err := dialect.BuildInsert(config, map[string]interface{}{ 142 | "test_value_1": "foo", 143 | "test_value_2": "bar", 144 | }, "test_value_1", "test_value_2") 145 | if err != nil { 146 | t.Errorf("Unexpected error %s", err.Error()) 147 | } 148 | if queryString != expectedSql { 149 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 150 | } 151 | if !slices.Equal(args, expectedArgs) { 152 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 153 | } 154 | } 155 | 156 | func TestBuildSelect(t *testing.T) { 157 | type testModel struct { 158 | Id int64 `db:"test_id" db_primary:"true"` 159 | Value1 string `db:"test_value_1" db_max_length:"100"` 160 | Value2 string `db:"test_value_2" db_max_length:"100"` 161 | } 162 | 163 | dialect := SqliteDialect{} 164 | model := rem.Use[testModel]() 165 | 166 | config := model.Query().Config 167 | config.Fields = model.Fields 168 | config.Table = "testmodel" 169 | expectedArgs := []interface{}{} 170 | expectedSql := "SELECT * FROM `testmodel`" 171 | queryString, args, err := dialect.BuildSelect(config) 172 | if err != nil { 173 | t.Errorf("Unexpected error %s", err.Error()) 174 | } 175 | if queryString != expectedSql { 176 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 177 | } 178 | if !slices.Equal(args, expectedArgs) { 179 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 180 | } 181 | 182 | // SELECT 183 | config = model.Select("id", "value1", rem.Unsafe("count(1) as `count`"), rem.As("value2", "value3")).Config 184 | config.Fields = model.Fields 185 | config.Table = "testmodel" 186 | expectedArgs = []interface{}{} 187 | expectedSql = "SELECT `id`,`value1`,count(1) as `count`,`value2` AS `value3` FROM `testmodel`" 188 | queryString, args, err = dialect.BuildSelect(config) 189 | if err != nil { 190 | t.Errorf("Unexpected error %s", err.Error()) 191 | } 192 | if queryString != expectedSql { 193 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 194 | } 195 | if !slices.Equal(args, expectedArgs) { 196 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 197 | } 198 | 199 | // WHERE 200 | config = model.Filter("id", "=", 1).Config 201 | config.Fields = model.Fields 202 | config.Table = "testmodel" 203 | expectedArgs = []interface{}{1} 204 | expectedSql = "SELECT * FROM `testmodel` WHERE `id` = ?" 205 | queryString, args, err = dialect.BuildSelect(config) 206 | if err != nil { 207 | t.Errorf("Unexpected error %s", err.Error()) 208 | } 209 | if queryString != expectedSql { 210 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 211 | } 212 | if !slices.Equal(args, expectedArgs) { 213 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 214 | } 215 | 216 | config = model.Filter("id", "IN", rem.Sql(rem.Param(1), ",", rem.Param(2))).Config 217 | config.Fields = model.Fields 218 | config.Table = "testmodel" 219 | expectedArgs = []interface{}{1, 2} 220 | expectedSql = "SELECT * FROM `testmodel` WHERE `id` IN (?,?)" 221 | queryString, args, err = dialect.BuildSelect(config) 222 | if err != nil { 223 | t.Errorf("Unexpected error %s", err.Error()) 224 | } 225 | if queryString != expectedSql { 226 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 227 | } 228 | if !slices.Equal(args, expectedArgs) { 229 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 230 | } 231 | 232 | // JOIN 233 | config = model.Select(rem.Unsafe("*")).Join("groups", rem.Or( 234 | rem.Q("groups.id", "=", rem.Column("accounts.group_id")), 235 | rem.Q("groups.id", "IS", nil))).Config 236 | config.Fields = model.Fields 237 | config.Table = "testmodel" 238 | expectedArgs = []interface{}{} 239 | expectedSql = "SELECT * FROM `testmodel` INNER JOIN `groups` ON ( `groups`.`id` = `accounts`.`group_id` OR `groups`.`id` IS NULL )" 240 | queryString, args, err = dialect.BuildSelect(config) 241 | if err != nil { 242 | t.Errorf("Unexpected error %s", err.Error()) 243 | } 244 | if queryString != expectedSql { 245 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 246 | } 247 | if !slices.Equal(args, expectedArgs) { 248 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 249 | } 250 | 251 | // SORT 252 | config = model.Sort("test_id", "-test_value_1").Config 253 | config.Fields = model.Fields 254 | config.Table = "testmodel" 255 | expectedArgs = []interface{}{} 256 | expectedSql = "SELECT * FROM `testmodel` ORDER BY `test_id` ASC, `test_value_1` DESC" 257 | queryString, args, err = dialect.BuildSelect(config) 258 | if err != nil { 259 | t.Errorf("Unexpected error %s", err.Error()) 260 | } 261 | if queryString != expectedSql { 262 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 263 | } 264 | if !slices.Equal(args, expectedArgs) { 265 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 266 | } 267 | 268 | // LIMIT and OFFSET 269 | config = model.Filter("id", "=", 1).Offset(20).Limit(10).Config 270 | config.Fields = model.Fields 271 | config.Table = "testmodel" 272 | expectedArgs = []interface{}{1, 10, 20} 273 | expectedSql = "SELECT * FROM `testmodel` WHERE `id` = ? LIMIT ? OFFSET ?" 274 | queryString, args, err = dialect.BuildSelect(config) 275 | if err != nil { 276 | t.Errorf("Unexpected error %s", err.Error()) 277 | } 278 | if queryString != expectedSql { 279 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 280 | } 281 | if !slices.Equal(args, expectedArgs) { 282 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 283 | } 284 | } 285 | 286 | func TestBuildTableColumnAdd(t *testing.T) { 287 | type testModel struct { 288 | Value string `db:"test_value" db_max_length:"100"` 289 | } 290 | 291 | dialect := SqliteDialect{} 292 | model := rem.Use[testModel]() 293 | config := rem.QueryConfig{ 294 | Fields: model.Fields, 295 | Table: "testmodel", 296 | } 297 | expectedSql := "ALTER TABLE `testmodel` ADD COLUMN `test_value` TEXT NOT NULL" 298 | queryString, err := dialect.BuildTableColumnAdd(config, "test_value") 299 | if err != nil { 300 | t.Errorf("Unexpected error %s", err.Error()) 301 | } 302 | if queryString != expectedSql { 303 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 304 | } 305 | } 306 | 307 | func TestBuildTableColumnDrop(t *testing.T) { 308 | dialect := SqliteDialect{} 309 | config := rem.QueryConfig{Table: "testmodel"} 310 | expectedSql := "ALTER TABLE `testmodel` DROP COLUMN `test_value`" 311 | queryString, err := dialect.BuildTableColumnDrop(config, "test_value") 312 | if err != nil { 313 | t.Errorf("Unexpected error %s", err.Error()) 314 | } 315 | if queryString != expectedSql { 316 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 317 | } 318 | } 319 | 320 | func TestBuildTableCreate(t *testing.T) { 321 | type testModel struct { 322 | Id int64 `db:"test_id" db_primary:"true"` 323 | Value1 string `db:"test_value_1" db_max_length:"100"` 324 | } 325 | 326 | dialect := SqliteDialect{} 327 | model := rem.Use[testModel]() 328 | config := rem.QueryConfig{ 329 | Fields: model.Fields, 330 | Table: "testmodel", 331 | } 332 | expectedSql := "CREATE TABLE `testmodel` (\n" + 333 | "\t`test_id` INTEGER PRIMARY KEY NOT NULL,\n" + 334 | "\t`test_value_1` TEXT NOT NULL\n" + 335 | ")" 336 | queryString, err := dialect.BuildTableCreate(config, rem.TableCreateConfig{}) 337 | if err != nil { 338 | t.Errorf("Unexpected error %s", err.Error()) 339 | } 340 | if queryString != expectedSql { 341 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 342 | } 343 | 344 | expectedSql = "CREATE TABLE IF NOT EXISTS `testmodel` (\n" + 345 | "\t`test_id` INTEGER PRIMARY KEY NOT NULL,\n" + 346 | "\t`test_value_1` TEXT NOT NULL\n" + 347 | ")" 348 | queryString, err = dialect.BuildTableCreate(config, rem.TableCreateConfig{IfNotExists: true}) 349 | if err != nil { 350 | t.Errorf("Unexpected error %s", err.Error()) 351 | } 352 | if queryString != expectedSql { 353 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 354 | } 355 | } 356 | 357 | func TestBuildTableDrop(t *testing.T) { 358 | dialect := SqliteDialect{} 359 | config := rem.QueryConfig{Table: "testmodel"} 360 | expectedSql := "DROP TABLE `testmodel`" 361 | queryString, err := dialect.BuildTableDrop(config, rem.TableDropConfig{}) 362 | if err != nil { 363 | t.Errorf("Unexpected error %s", err.Error()) 364 | } 365 | if queryString != expectedSql { 366 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 367 | } 368 | 369 | expectedSql = "DROP TABLE IF EXISTS `testmodel`" 370 | queryString, err = dialect.BuildTableDrop(config, rem.TableDropConfig{IfExists: true}) 371 | if err != nil { 372 | t.Errorf("Unexpected error %s", err.Error()) 373 | } 374 | if queryString != expectedSql { 375 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 376 | } 377 | } 378 | 379 | func TestBuildUpdate(t *testing.T) { 380 | type testModel struct { 381 | Id int64 `db:"test_id" db_primary:"true"` 382 | Value1 string `db:"test_value_1" db_max_length:"100"` 383 | Value2 string `db:"test_value_2" db_max_length:"100"` 384 | } 385 | 386 | dialect := SqliteDialect{} 387 | model := rem.Use[testModel]() 388 | 389 | query := model.Filter("test_id", "=", 1) 390 | config := query.Config 391 | config.Fields = model.Fields 392 | config.Table = "testmodel" 393 | expectedArgs := []interface{}{"foo", "bar", 1} 394 | expectedSql := "UPDATE `testmodel` SET `test_value_1` = ?,`test_value_2` = ? WHERE `test_id` = ?" 395 | queryString, args, err := dialect.BuildUpdate(config, map[string]interface{}{ 396 | "id": 123, 397 | "test_value_1": "foo", 398 | "test_value_2": "bar", 399 | }, "test_value_1", "test_value_2") 400 | if err != nil { 401 | t.Errorf("Unexpected error %s", err.Error()) 402 | } 403 | if queryString != expectedSql { 404 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 405 | } 406 | if !slices.Equal(args, expectedArgs) { 407 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 408 | } 409 | 410 | query.Limit(3) 411 | config = query.Config 412 | config.Fields = model.Fields 413 | config.Table = "testmodel" 414 | expectedArgs = []interface{}{"foo", "bar", 1, 3} 415 | expectedSql = "UPDATE `testmodel` SET `test_value_1` = ?,`test_value_2` = ? WHERE `test_id` = ? LIMIT ?" 416 | queryString, args, err = dialect.BuildUpdate(config, map[string]interface{}{ 417 | "id": 123, 418 | "test_value_1": "foo", 419 | "test_value_2": "bar", 420 | }, "test_value_1", "test_value_2") 421 | if err != nil { 422 | t.Errorf("Unexpected error %s", err.Error()) 423 | } 424 | if queryString != expectedSql { 425 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 426 | } 427 | if !slices.Equal(args, expectedArgs) { 428 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 429 | } 430 | 431 | query.Sort("-test_id") 432 | config = query.Config 433 | config.Fields = model.Fields 434 | config.Table = "testmodel" 435 | expectedSql = "UPDATE `testmodel` SET `test_value_1` = ?,`test_value_2` = ? WHERE `test_id` = ? ORDER BY `test_id` DESC LIMIT ?" 436 | queryString, args, err = dialect.BuildUpdate(config, map[string]interface{}{ 437 | "id": 123, 438 | "test_value_1": "foo", 439 | "test_value_2": "bar", 440 | }, "test_value_1", "test_value_2") 441 | if err != nil { 442 | t.Errorf("Unexpected error %s", err.Error()) 443 | } 444 | if queryString != expectedSql { 445 | t.Errorf("Expected '%s', got '%s'", expectedSql, queryString) 446 | } 447 | if !slices.Equal(args, expectedArgs) { 448 | t.Errorf("Expected '%s', got '%s'", expectedArgs, args) 449 | } 450 | } 451 | 452 | func TestColumnType(t *testing.T) { 453 | type testFkInt struct { 454 | Id int64 `db:"id" db_primary:"true"` 455 | } 456 | 457 | type testFkString struct { 458 | Id string `db:"id" db_primary:"true" db_max_length:"100"` 459 | } 460 | 461 | type testModel struct { 462 | BigInt int64 `db:"test_big_int"` 463 | BigIntNull sql.NullInt64 `db:"test_big_int_null"` 464 | Bool bool `db:"test_bool"` 465 | BoolNull sql.NullBool `db:"test_bool_null"` 466 | Custom []byte `db:"test_custom" db_type:"BLOB NOT NULL"` 467 | Default string `db:"test_default" db_default:"'foo'" db_max_length:"100"` 468 | Float float32 `db:"test_float"` 469 | Double float64 `db:"test_double"` 470 | DoubleNull sql.NullFloat64 `db:"test_double_null"` 471 | Id int64 `db:"test_id" db_primary:"true"` 472 | Int int32 `db:"test_int"` 473 | IntNull sql.NullInt32 `db:"test_int_null"` 474 | SmallInt int16 `db:"test_small_int"` 475 | SmallIntNull sql.NullInt16 `db:"test_small_int_null"` 476 | Text string `db:"test_text"` 477 | TextNull sql.NullString `db:"test_text_null"` 478 | Time time.Time `db:"test_time"` 479 | TimeNow time.Time `db:"test_time_now" db_default:"CURRENT_TIMESTAMP"` 480 | TimeNull sql.NullTime `db:"test_time_null"` 481 | TinyInt int8 `db:"test_tiny_int"` 482 | Varchar string `db:"test_varchar" db_max_length:"100"` 483 | VarcharNull sql.NullString `db:"test_varchar_null" db_max_length:"50"` 484 | ForiegnKey rem.ForeignKey[testFkString] `db:"test_fk_id" db_on_delete:"CASCADE"` 485 | ForiegnKeyNull rem.NullForeignKey[testFkInt] `db:"test_fk_null_id" db_on_delete:"SET NULL" db_on_update:"SET NULL"` 486 | Unique string `db:"test_unique" db_max_length:"255" db_unique:"true"` 487 | } 488 | 489 | expected := map[string]string{ 490 | "test_big_int": "INTEGER NOT NULL", 491 | "test_big_int_null": "INTEGER NULL", 492 | "test_bool": "BOOLEAN NOT NULL", 493 | "test_bool_null": "BOOLEAN NULL", 494 | "test_custom": "BLOB NOT NULL", 495 | "test_default": "TEXT NOT NULL DEFAULT 'foo'", 496 | "test_float": "REAL NOT NULL", 497 | "test_double": "REAL NOT NULL", 498 | "test_double_null": "REAL NULL", 499 | "test_id": "INTEGER PRIMARY KEY NOT NULL", 500 | "test_int": "INTEGER NOT NULL", 501 | "test_int_null": "INTEGER NULL", 502 | "test_small_int": "INTEGER NOT NULL", 503 | "test_small_int_null": "INTEGER NULL", 504 | "test_time": "DATETIME NOT NULL", 505 | "test_time_now": "DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP", 506 | "test_time_null": "DATETIME NULL", 507 | "test_text": "TEXT NOT NULL", 508 | "test_text_null": "TEXT NULL", 509 | "test_tiny_int": "INTEGER NOT NULL", 510 | "test_varchar": "TEXT NOT NULL", 511 | "test_varchar_null": "TEXT NULL", 512 | "test_fk_id": "TEXT NOT NULL REFERENCES `testfkstring` (`id`) ON DELETE CASCADE", 513 | "test_fk_null_id": "INTEGER NULL REFERENCES `testfkint` (`id`) ON UPDATE SET NULL ON DELETE SET NULL", 514 | "test_unique": "TEXT NOT NULL UNIQUE", 515 | } 516 | 517 | dialect := SqliteDialect{} 518 | model := rem.Use[testModel]() 519 | fieldKeys := maps.Keys(model.Fields) 520 | sort.Strings(fieldKeys) 521 | 522 | for _, fieldName := range fieldKeys { 523 | field := model.Fields[fieldName] 524 | columnType, err := dialect.ColumnType(field) 525 | if err != nil { 526 | t.Fatalf(`dialect.ColumnType() threw error for '%#v': %s`, field, err) 527 | } 528 | if columnType != expected[fieldName] { 529 | t.Fatalf(`dialect.ColumnType() returned '%s', but expected '%s' for '%#v'`, columnType, expected[fieldName], field) 530 | } 531 | } 532 | } 533 | 534 | func TestQuoteIdentifier(t *testing.T) { 535 | values := map[string]string{ 536 | "abc": "`abc`", 537 | "a`bc": "`a``bc`", 538 | "a``b`c": "`a````b``c`", 539 | "`abc": "```abc`", 540 | "abc`": "`abc```", 541 | "ab\\`c": "`ab\\``c`", 542 | "abc\\": "`abc\\`", 543 | } 544 | 545 | for identifier, expected := range values { 546 | actual := QuoteIdentifier(identifier) 547 | if actual != expected { 548 | t.Errorf("Expected %s, got %s", expected, actual) 549 | } 550 | } 551 | } 552 | -------------------------------------------------------------------------------- /stringers.go: -------------------------------------------------------------------------------- 1 | package rem 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | type SqlAs struct { 9 | Alias string 10 | Column interface{} 11 | } 12 | 13 | func (as SqlAs) StringForDialect(dialect Dialect) string { 14 | switch cv := as.Column.(type) { 15 | case string: 16 | return fmt.Sprint(dialect.QuoteIdentifier(cv), " AS ", dialect.QuoteIdentifier(as.Alias)) 17 | 18 | case DialectStringer: 19 | return fmt.Sprint(cv.StringForDialect(dialect), " AS ", dialect.QuoteIdentifier(as.Alias)) 20 | 21 | case fmt.Stringer: 22 | return fmt.Sprint(cv.String(), " AS ", dialect.QuoteIdentifier(as.Alias)) 23 | } 24 | 25 | panic(fmt.Sprintf("rem: unsupported type for rem.As '%#v'", as.Column)) 26 | } 27 | 28 | func As(column interface{}, alias string) SqlAs { 29 | return SqlAs{Alias: alias, Column: column} 30 | } 31 | 32 | type SqlColumn string 33 | 34 | func (column SqlColumn) StringForDialect(dialect Dialect) string { 35 | return dialect.QuoteIdentifier(string(column)) 36 | } 37 | 38 | func Column(column string) SqlColumn { 39 | return SqlColumn(column) 40 | } 41 | 42 | type SqlParam struct { 43 | Value interface{} 44 | } 45 | 46 | func Param(value interface{}) SqlParam { 47 | return SqlParam{Value: value} 48 | } 49 | 50 | type SqlWithParams struct { 51 | Segments []interface{} 52 | } 53 | 54 | func (sqlWithParams SqlWithParams) StringWithArgs(dialect Dialect, args []interface{}) (string, []interface{}, error) { 55 | var queryString strings.Builder 56 | for _, part := range sqlWithParams.Segments { 57 | switch cv := part.(type) { 58 | case SqlParam: 59 | args = append(args, cv.Value) 60 | queryString.WriteString(dialect.Param(len(args))) 61 | case string: 62 | queryString.WriteString(cv) 63 | default: 64 | queryString.WriteString(fmt.Sprint(cv)) 65 | } 66 | } 67 | return queryString.String(), args, nil 68 | } 69 | 70 | func Sql(segments ...interface{}) SqlWithParams { 71 | return SqlWithParams{Segments: segments} 72 | } 73 | 74 | type SqlUnsafe struct { 75 | Sql string 76 | } 77 | 78 | func (sqlUnsafe SqlUnsafe) String() string { 79 | return sqlUnsafe.Sql 80 | } 81 | 82 | func Unsafe(sql string) SqlUnsafe { 83 | return SqlUnsafe{Sql: sql} 84 | } 85 | -------------------------------------------------------------------------------- /stringers_test.go: -------------------------------------------------------------------------------- 1 | package rem 2 | 3 | import ( 4 | "testing" 5 | 6 | "golang.org/x/exp/slices" 7 | ) 8 | 9 | func TestSql(t *testing.T) { 10 | dialect := testDialect{} 11 | sql, args, err := Sql(`SELECT count(1) AS "x"`).StringWithArgs(dialect, []interface{}{}) 12 | expectedArgs := []interface{}{} 13 | expectedSql := `SELECT count(1) AS "x"` 14 | if err != nil { 15 | t.Error("Unexpected error:", err) 16 | } 17 | if !slices.Equal(args, expectedArgs) { 18 | t.Errorf("Expected '%+v', got '%+v'", expectedArgs, args) 19 | } 20 | if sql != expectedSql { 21 | t.Errorf("Expected '%s', got '%s'", expectedSql, sql) 22 | } 23 | 24 | sql, args, err = Sql("SELECT * FROM x WHERE y = ", Param(100), " AND z IS ", Param(true)).StringWithArgs(dialect, []interface{}{}) 25 | expectedArgs = []interface{}{100, true} 26 | expectedSql = "SELECT * FROM x WHERE y = $1 AND z IS $2" 27 | if err != nil { 28 | t.Error("Unexpected error:", err) 29 | } 30 | if !slices.Equal(args, expectedArgs) { 31 | t.Errorf("Expected '%+v', got '%+v'", expectedArgs, args) 32 | } 33 | if sql != expectedSql { 34 | t.Errorf("Expected '%s', got '%s'", expectedSql, sql) 35 | } 36 | } 37 | 38 | func TestUnsafe(t *testing.T) { 39 | unsafe := Unsafe(`SELECT count(1) AS "x"`) 40 | if unsafe.Sql != `SELECT count(1) AS "x"` { 41 | t.Errorf(`Expected 'SELECT count(1) AS "x"', got '%s'`, unsafe.Sql) 42 | } 43 | } 44 | --------------------------------------------------------------------------------