├── .github ├── CODE_OF_CONDUCT.md └── workflows │ └── ci.yml ├── .gitignore ├── LICENSE ├── README.md ├── _benchmarks ├── README.md ├── go.mod ├── go.sum └── insert_single_test.go ├── _examples ├── README.md ├── basic │ └── main.go ├── live-table │ ├── go.mod │ ├── go.sum │ ├── index.html │ └── main.go ├── logging │ ├── go.mod │ ├── go.sum │ └── main.go ├── password │ ├── go.mod │ ├── go.sum │ └── main.go ├── presenter │ └── main.go └── view │ ├── _embed │ └── example.sql │ └── main.go ├── common.go ├── concurrent_tx.go ├── db.go ├── db_example_test.go ├── db_information.go ├── db_information_example_test.go ├── db_information_test.go ├── db_repository.go ├── db_stat.go ├── db_table_listener.go ├── db_table_listener_example_test.go ├── desc ├── alter_table_constraint_query.go ├── argument.go ├── column.go ├── column_basic_info.go ├── column_filter_text_parser.go ├── column_filter_text_parser_test.go ├── constraint.go ├── constraint_test.go ├── create_table_query.go ├── data_type.go ├── delete_query.go ├── desc.go ├── duplicate_query.go ├── duplicate_query_test.go ├── exists_query.go ├── index_type.go ├── insert_query.go ├── naming.go ├── naming_test.go ├── password_handler.go ├── reflect.go ├── scanner.go ├── struct_table.go ├── struct_table_test.go ├── table.go ├── table_test.go ├── trigger.go ├── unique_index.go ├── update_query.go ├── zeroer.go └── zeroer_test.go ├── errors.go ├── example_test.go ├── gen ├── README.md ├── db_schema_gen.go ├── db_schema_gen_example_test.go ├── export_options.go ├── schema_columns_gen.go └── schema_columns_gen_test.go ├── go.mod ├── go.sum ├── listener.go ├── listener_example_test.go ├── repository.go ├── repository_example_test.go ├── repository_table_listener_example_test.go ├── schema.go └── schema_example_test.go /.github/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an 
open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, gender identity and expression, level of experience, 9 | nationality, personal appearance, race, religion, or sexual identity and 10 | orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 
45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at kataras2006@hotmail.com. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 
67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at [http://contributor-covenant.org/version/1/4][version] 72 | 73 | [homepage]: http://contributor-covenant.org 74 | [version]: http://contributor-covenant.org/version/1/4/ -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | permissions: 10 | contents: read 11 | 12 | jobs: 13 | 14 | test: 15 | name: Test 16 | runs-on: ubuntu-latest 17 | 18 | strategy: 19 | matrix: 20 | go_version: [1.24.x] 21 | 22 | services: 23 | postgres: 24 | image: postgres:16-alpine 25 | env: 26 | POSTGRES_USER: postgres 27 | POSTGRES_PASSWORD: admin!123 28 | POSTGRES_DB: test_db 29 | ports: 30 | - 5432:5432 31 | options: >- 32 | --health-cmd pg_isready 33 | --health-interval 10s 34 | --health-timeout 5s 35 | --health-retries 5 36 | 37 | steps: 38 | # Check out first: setup-go's built-in module cache keys on go.sum from the workspace. 39 | - name: Check out code into the Go module directory 40 | uses: actions/checkout@v4 41 | 42 | - name: Set up Go 1.x 43 | uses: actions/setup-go@v5 44 | with: 45 | go-version: ${{ matrix.go_version }} 46 | 47 | - name: Test 48 | run: go test -v ./...
49 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .vscode 3 | .directory 4 | cov 5 | coverage.out 6 | package-lock.json 7 | access.log 8 | node_modules 9 | issue-*/ 10 | _bckp 11 | internalcode-*/ 12 | _testdata/** 13 | /_examples/feature-*/ 14 | _examples/**/uploads/* 15 | _issues/** 16 | .DS_STORE -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2023-2025 Gerasimos Maropoulos 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /_benchmarks/README.md: -------------------------------------------------------------------------------- 1 | # Benchmarks 2 | 3 | Execute the following SQL statements to create the table. Note that the trigger calls a `public.trigger_set_timestamp()` function, which must already exist in the database: 4 | 5 | ```sql 6 | CREATE EXTENSION IF NOT EXISTS pgcrypto; 7 | 8 | -- ---------------------------- 9 | -- Table structure for customers 10 | -- ---------------------------- 11 | DROP TABLE IF EXISTS "public"."customers"; 12 | CREATE TABLE "public"."customers" ( 13 | "id" uuid NOT NULL DEFAULT gen_random_uuid(), 14 | "created_at" timestamp(6) NOT NULL DEFAULT clock_timestamp(), 15 | "updated_at" timestamp(6) NOT NULL DEFAULT clock_timestamp(), 16 | "cognito_user_id" uuid NOT NULL 17 | ); 18 | 19 | -- ---------------------------- 20 | -- Triggers structure for table customers 21 | -- ---------------------------- 22 | CREATE TRIGGER "set_timestamp" BEFORE UPDATE ON "public"."customers" 23 | FOR EACH ROW 24 | EXECUTE PROCEDURE "public"."trigger_set_timestamp"(); 25 | 26 | -- ---------------------------- 27 | -- Uniques structure for table customers 28 | -- ---------------------------- 29 | ALTER TABLE "public"."customers" ADD CONSTRAINT "customers_cognito_user_id_key" UNIQUE ("cognito_user_id"); 30 | 31 | -- ---------------------------- 32 | -- Primary Key structure for table customers 33 | -- ---------------------------- 34 | ALTER TABLE "public"."customers" ADD CONSTRAINT "customers_pkey" PRIMARY KEY ("id"); 35 | ``` 36 | 37 | Run the benchmarks: 38 | 39 | ```sh 40 | $ go test -bench=BenchmarkDB_InsertSingle_Gorm -count 6 | tee result_gorm.txt 41 | $ go test -bench=BenchmarkDB_InsertSingle_Pg -count 6 | tee result_pg.txt 42 | $ # Inspect the ns/op numbers in these files directly, or compare them with benchstat: 43 | $ go install golang.org/x/perf/cmd/benchstat@latest 44 | $ benchstat result_gorm.txt result_pg.txt 45 | ``` 46 | -------------------------------------------------------------------------------- /_benchmarks/go.mod:
-------------------------------------------------------------------------------- 1 | module benchmarks 2 | 3 | go 1.21 4 | 5 | replace github.com/kataras/pg => ../ 6 | 7 | require ( 8 | github.com/google/uuid v1.3.0 9 | github.com/kataras/pg v0.0.0-00010101000000-000000000000 10 | gorm.io/driver/postgres v1.5.2 11 | gorm.io/gorm v1.25.3 12 | ) 13 | 14 | require ( 15 | github.com/gertd/go-pluralize v0.2.1 // indirect 16 | github.com/jackc/pgpassfile v1.0.0 // indirect 17 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect 18 | github.com/jackc/pgx/v5 v5.4.3 // indirect 19 | github.com/jackc/puddle/v2 v2.2.1 // indirect 20 | github.com/jinzhu/inflection v1.0.0 // indirect 21 | github.com/jinzhu/now v1.1.5 // indirect 22 | golang.org/x/crypto v0.9.0 // indirect 23 | golang.org/x/sync v0.1.0 // indirect 24 | golang.org/x/text v0.9.0 // indirect 25 | ) 26 | -------------------------------------------------------------------------------- /_benchmarks/go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/gertd/go-pluralize v0.2.1 h1:M3uASbVjMnTsPb0PNqg+E/24Vwigyo/tvyMTtAlLgiA= 5 | github.com/gertd/go-pluralize v0.2.1/go.mod h1:rbYaKDbsXxmRfr8uygAEKhOWsjyrrqrkHVpZvoOp8zk= 6 | github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= 7 | github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 8 | github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= 9 | github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= 10 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a 
h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= 11 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= 12 | github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY= 13 | github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= 14 | github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= 15 | github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= 16 | github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= 17 | github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= 18 | github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= 19 | github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= 20 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 21 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 22 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 23 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 24 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 25 | github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= 26 | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 27 | golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= 28 | golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= 29 | golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= 30 | golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 31 | golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= 32 | golang.org/x/text v0.9.0/go.mod 
h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= 33 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 34 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 35 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 36 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 37 | gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0= 38 | gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8= 39 | gorm.io/gorm v1.25.3 h1:zi4rHZj1anhZS2EuEODMhDisGy+Daq9jtPrNGgbQYD8= 40 | gorm.io/gorm v1.25.3/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= 41 | -------------------------------------------------------------------------------- /_benchmarks/insert_single_test.go: -------------------------------------------------------------------------------- 1 | package benchmarks 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/google/uuid" 9 | 10 | "github.com/kataras/pg" 11 | 12 | "gorm.io/driver/postgres" 13 | "gorm.io/gorm" 14 | ) 15 | 16 | // Customer is a struct that represents a customer entity in the database. 
17 | type Customer struct { 18 | ID string `pg:"type=uuid,primary"` 19 | CreatedAt time.Time `pg:"type=timestamp,default=clock_timestamp()"` 20 | UpdatedAt time.Time `pg:"type=timestamp,default=clock_timestamp()"` 21 | // CognitoUserID string `pg:"type=uuid,unique,conflict=DO UPDATE SET cognito_user_id=EXCLUDED.cognito_user_id"` 22 | CognitoUserID string `pg:"type=uuid,unique"` 23 | } 24 | 25 | var ( 26 | dsn = "host=localhost user=postgres password=admin!123 dbname=test_db sslmode=disable search_path=public" 27 | ) 28 | 29 | // go test -benchtime=5s -benchmem -run=^$ -bench ^BenchmarkDB_Insert* 30 | 31 | // BenchmarkDB_InsertSingle_Gorm measures inserting one customer per iteration through GORM. 32 | // The connection is opened once and excluded from the timed region. 33 | // 34 | // go test -bench=BenchmarkDB_InsertSingle_Gorm -count 6 | tee result_gorm.txt 35 | func BenchmarkDB_InsertSingle_Gorm(b *testing.B) { 36 | db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{}) 37 | if err != nil { 38 | b.Fatal(err) 39 | } 40 | 41 | sqlDB, err := db.DB() 42 | if err != nil { 43 | b.Fatal(err) 44 | } 45 | defer sqlDB.Close() 46 | 47 | /* To create the schema: 48 | db.AutoMigrate(&Customer{}) 49 | */ 50 | 51 | // This doesn't even work... 52 | // db.Clauses(clause.OnConflict{DoUpdates: clause.AssignmentColumns([]string{"cognito_user_id"})}). 53 | // db.Clauses(clause.OnConflict{DoUpdates: clause.Assignments(map[string]any{"cognito_user_id": `EXCLUDED.cognito_user_id`})}). 54 | 55 | b.ResetTimer() // exclude connection setup from the measurement 56 | for i := 0; i < b.N; i++ { 57 | // A fresh UUID per iteration keeps the UNIQUE cognito_user_id constraint satisfied. 58 | customer := Customer{ 59 | CognitoUserID: uuid.NewString(), 60 | } 61 | 62 | err = db. 63 | Omit("id", "created_at", "updated_at"). 64 | Create(&customer).Error 65 | if err != nil { 66 | b.Fatal(err) 67 | } 68 | } 69 | } 70 | 71 | // BenchmarkDB_InsertSingle_Pg measures inserting one customer per iteration through pg. 72 | // The connection is opened once and excluded from the timed region. 73 | // 74 | // go test -bench=BenchmarkDB_InsertSingle_Pg -count 6 | tee result_pg.txt 75 | func BenchmarkDB_InsertSingle_Pg(b *testing.B) { 76 | var schema = pg.NewSchema().MustRegister("customers", Customer{}) 77 | 78 | db, err := pg.Open(context.Background(), schema, dsn) 79 | if err != nil { 80 | b.Fatal(err) 81 | } 82 | defer db.Close() 83 | 84 | /* To create the schema: 85 | db.CreateSchema(context.Background()) 86 | */ 87 | 88 | b.ResetTimer() // exclude connection setup from the measurement 89 | for i := 0; i < b.N; i++ { 90 | // Automatically takes care of id, created_at and updated_at fields.
91 | customer := Customer{CognitoUserID: uuid.NewString()} 92 | 93 | err = db.InsertSingle(context.Background(), customer, &customer.ID) 94 | if err != nil { 95 | b.Fatal(err) 96 | } 97 | } 98 | 99 | /* To create a record from repository (static types): 100 | repo := pg.NewRepository[Customer](db) 101 | err := repo.InsertSingle(context.Background(), customer, &customer.ID) 102 | */ 103 | } 104 | 105 | // benchstat result_gorm.txt result_pg.txt 106 | // ± 351% 107 | -------------------------------------------------------------------------------- /_examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | This folder contains some examples for the `PG` package. 4 | 5 | - [Basic](./basic/main.go) 6 | - [HTML5 Table Real-Time Database Table View](./live-table/main.go) 7 | - [Logging](./logging/main.go) 8 | - [Password](./password/main.go) 9 | - [Presenter](./presenter/main.go) 10 | - [View](./view/main.go) 11 | 12 | The document below describes some basic principles of the package. 13 | 14 | ## Basic Example 15 | 16 | This example shows how to use pg to perform basic CRUD operations on a single table. 17 | 18 | ### Model 19 | 20 | The model is a struct that represents a customer entity with an id and a firstname. 21 | 22 | ```go 23 | type Customer struct { 24 | ID string `pg:"type=uuid,primary"` 25 | CreatedAt time.Time `pg:"type=timestamp,default=clock_timestamp()"` 26 | UpdatedAt time.Time `pg:"type=timestamp,default=clock_timestamp()"` 27 | Firstname string `pg:"type=varchar(255)"` 28 | } 29 | ``` 30 | 31 | ### Schema 32 | 33 | The schema is an instance of `pg.Schema` that registers the model and its table name. 34 | 35 | ```go 36 | schema := pg.NewSchema() 37 | schema.MustRegister("customers", Customer{}) 38 | ``` 39 | 40 | ### Database 41 | 42 | The database is an instance of `pg.DB` that connects to the PostgreSQL server using the connection string and the schema.
43 | 44 | ```go 45 | connString := "postgres://postgres:admin!123@localhost:5432/test_db?sslmode=disable" 46 | db, err := pg.Open(context.Background(), schema, connString) 47 | if err != nil { 48 | panic(err) 49 | } 50 | defer db.Close() 51 | ``` 52 | 53 | ### Operations 54 | 55 | The operations are methods of `pg.DB` or `pg.Repository` that perform queries on the database using the model. 56 | 57 | - To create the tables for the pg.Schema above, use the `db.CreateSchema` method: 58 | 59 | ```go 60 | err := db.CreateSchema(context.Background()) 61 | if err != nil { 62 | panic(err) 63 | } 64 | ``` 65 | 66 | - To insert a record and bind the result ID, use the `db.InsertSingle` method: 67 | 68 | ```go 69 | customer := &Customer{ 70 | Firstname: "Alice", 71 | } 72 | err := db.InsertSingle(context.Background(), customer, &customer.ID) 73 | if err != nil { 74 | panic(err) 75 | } 76 | ``` 77 | 78 | - To insert one or more records, use the `db.Insert` method: 79 | 80 | ```go 81 | customer := &Customer{ 82 | Firstname: "Alice", 83 | } 84 | err := db.Insert(context.Background(), customer) 85 | if err != nil { 86 | panic(err) 87 | } 88 | ``` 89 | 90 | - To query a record by primary key, use the `db.SelectByID` method: 91 | 92 | ```go 93 | var customer Customer 94 | err := db.SelectByID(context.Background(), &customer, "some-uuid") 95 | if err != nil { 96 | panic(err) 97 | } 98 | fmt.Println(customer.Firstname) // Alice 99 | ``` 100 | 101 | - To update a record, use the `db.Update` method: 102 | 103 | ```go 104 | customer.Firstname = "Bob" 105 | err := db.Update(context.Background(), customer) 106 | if err != nil { 107 | panic(err) 108 | } 109 | ``` 110 | 111 | - To delete a record, use the `db.Delete` method: 112 | 113 | ```go 114 | err := db.Delete(context.Background(), customer) 115 | if err != nil { 116 | panic(err) 117 | } 118 | ``` 119 | 120 | ## Repository Example 121 | 122 | This example shows how to use pg to implement the repository pattern for a single table. 
123 | 124 | ### Model 125 | 126 | The model is a struct that represents a product entity with an ID, a name, and a price. 127 | 128 | ```go 129 | type Product struct { 130 | ID int64 `pg:"type=int,primary"` 131 | Name string `pg:"name"` 132 | Price float64 `pg:"price"` 133 | } 134 | ``` 135 | 136 | ### Schema 137 | 138 | The schema is an instance of `pg.Schema` that registers the model and its table name. 139 | 140 | ```go 141 | schema := pg.NewSchema() 142 | schema.MustRegister("products", Product{}) 143 | ``` 144 | 145 | ### Database 146 | 147 | The database is an instance of `pg.DB` that connects to the PostgreSQL server using the connection string and the schema. 148 | 149 | ```go 150 | connString := "postgres://postgres:admin!123@localhost:5432/test_db?sslmode=disable" 151 | db, err := pg.Open(context.Background(), schema, connString) 152 | if err != nil { 153 | panic(err) 154 | } 155 | defer db.Close() 156 | ``` 157 | 158 | ### Repository 159 | 160 | The repository is an instance of `pg.Repository[Product]` that provides methods to perform queries on the products table using the model. 
161 | 162 | ```go 163 | products := pg.NewRepository[Product](db) 164 | ``` 165 | 166 | ### Operations 167 | 168 | - To insert a record, use the `products.InsertSingle` method: 169 | 170 | ```go 171 | product := Product{ 172 | Name: "Laptop", 173 | Price: 999.99, 174 | } 175 | err := products.InsertSingle(context.Background(), product, &product.ID) 176 | if err != nil { 177 | panic(err) 178 | } 179 | ``` 180 | 181 | - To query a record by primary key, use the `products.SelectByID` method: 182 | 183 | ```go 184 | product, err := products.SelectByID(context.Background(), 1) 185 | if err != nil { 186 | panic(err) 187 | } 188 | fmt.Println(product.Name) // Laptop 189 | ``` 190 | 191 | - To query multiple records by a condition, use the `products.Select` method: 192 | 193 | ```go 194 | query := `SELECT * FROM products WHERE price > $1 ORDER BY price DESC;` 195 | expensiveProducts, err := products.Select(context.Background(), query, 500) 196 | if err != nil { 197 | panic(err) 198 | } 199 | for _, product := range expensiveProducts { 200 | fmt.Printf("- (%d) %s: $%.2f\n", product.ID, product.Name, product.Price) 201 | } 202 | ``` 203 | 204 | - To update a record, use the `products.Update` method: 205 | 206 | ```go 207 | product.Price = 899.99 208 | err := products.Update(context.Background(), product) 209 | if err != nil { 210 | panic(err) 211 | } 212 | ``` 213 | 214 | - To delete a record, use the `products.Delete` method: 215 | 216 | ```go 217 | err := products.Delete(context.Background(), product) 218 | if err != nil { 219 | panic(err) 220 | } 221 | ``` 222 | 223 | ## Transaction Example 224 | 225 | This example shows how to use pg to perform queries within a transaction. 226 | 227 | ### Model 228 | 229 | The model is a struct that represents a customer entity with an id and a firstname.
230 | 231 | ```go 232 | type Customer struct { 233 | ID string `pg:"type=uuid,primary"` 234 | CreatedAt time.Time `pg:"type=timestamp,default=clock_timestamp()"` 235 | UpdatedAt time.Time `pg:"type=timestamp,default=clock_timestamp()"` 236 | Firstname string `pg:"type=varchar(255)"` 237 | } 238 | ``` 239 | 240 | ### Schema 241 | 242 | The schema is an instance of `pg.Schema` that registers the model and its table name. 243 | 244 | ```go 245 | schema := pg.NewSchema() 246 | schema.MustRegister("customers", Customer{}) 247 | ``` 248 | 249 | ### Database 250 | 251 | The database is an instance of `pg.DB` that connects to the PostgreSQL server using the connection string and the schema. 252 | 253 | ```go 254 | connString := "postgres://postgres:admin!123@localhost:5432/test_db?sslmode=disable" 255 | db, err := pg.Open(context.Background(), schema, connString) 256 | if err != nil { 257 | panic(err) 258 | } 259 | defer db.Close() 260 | ``` 261 | 262 | ### Transaction 263 | 264 | The transaction is an instance of `pg.DB` that is created by the `db.InTransaction` method. The `db.InTransaction` method takes a `context.Context` and a function whose only argument is the transaction-bound `*pg.DB` instance. Use that `*pg.DB` instance, not the outer one, to run queries within the transaction. If the function returns an error, the transaction is rolled back; otherwise, it is committed.
265 | 266 | ```go 267 | err := db.InTransaction(context.Background(), func(db *pg.DB) error { 268 | // Run queries within the transaction 269 | err := db.Insert(context.Background(), customer) 270 | if err != nil { 271 | return err 272 | } 273 | err := db.Update(context.Background(), customer) 274 | if err != nil { 275 | return err 276 | } 277 | // Return nil to commit the transaction 278 | return nil 279 | }) 280 | if err != nil { 281 | panic(err) 282 | } 283 | ``` 284 | -------------------------------------------------------------------------------- /_examples/basic/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "time" 8 | 9 | "github.com/kataras/pg" 10 | ) 11 | 12 | type Base struct { 13 | ID string `pg:"type=uuid,primary"` 14 | CreatedAt time.Time `pg:"type=timestamp,default=clock_timestamp()"` 15 | UpdatedAt time.Time `pg:"type=timestamp,default=clock_timestamp()"` 16 | } 17 | 18 | type Customer struct { 19 | Base 20 | 21 | Firstname string `pg:"type=varchar(255)"` 22 | } 23 | 24 | func main() { 25 | // Create Schema instance. 26 | schema := pg.NewSchema() 27 | schema.MustRegister("customers", Customer{}) 28 | 29 | // Create Database instance. 30 | connString := "postgres://postgres:admin!123@localhost:5432/test_db?sslmode=disable" 31 | db, err := pg.Open(context.Background(), schema, connString) 32 | if err != nil { 33 | log.Fatal(fmt.Errorf("open database: %w", err)) 34 | } 35 | defer db.Close() 36 | 37 | // Optionally create and check the database schema. 38 | if err = db.CreateSchema(context.Background()); err != nil { 39 | log.Fatal(fmt.Errorf("create schema: %w", err)) 40 | } 41 | 42 | if err = db.CheckSchema(context.Background()); err != nil { 43 | log.Fatal(fmt.Errorf("check schema: %w", err)) 44 | } 45 | 46 | // Create a Repository of Customer type. 
47 | customers := pg.NewRepository[Customer](db) 48 | 49 | var newCustomer = Customer{ 50 | Firstname: "John", 51 | } 52 | 53 | // Insert a new Customer. 54 | err = customers.InsertSingle(context.Background(), newCustomer, &newCustomer.ID) 55 | if err != nil { 56 | log.Fatal(fmt.Errorf("insert customer: %w", err)) 57 | } 58 | 59 | // Get by id. 60 | 61 | /* 62 | query := `SELECT * FROM customers WHERE id = $1 LIMIT 1;` 63 | existing, err := customers.SelectSingle(context.Background(), query, newCustomer.ID) 64 | OR: 65 | */ 66 | existing, err := customers.SelectByID(context.Background(), newCustomer.ID) 67 | if err != nil { 68 | log.Fatal(err) 69 | } 70 | 71 | log.Printf("Existing Customer (SelectSingle):\n%#+v\n", existing) 72 | 73 | // Get all. 74 | query := `SELECT * FROM customers ORDER BY created_at DESC;` 75 | allCustomers, err := customers.Select(context.Background(), query) 76 | if err != nil { 77 | log.Fatal(fmt.Errorf("select all: %w", err)) 78 | } 79 | log.Printf("All Customers (%d): ", len(allCustomers)) 80 | for _, customer := range allCustomers { 81 | fmt.Printf("- (%s) %s\n", customer.ID, customer.Firstname) 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /_examples/live-table/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/kataras/pg_examples/live-table-reload 2 | 3 | go 1.21.3 4 | 5 | replace github.com/kataras/pg => ../../ 6 | 7 | require ( 8 | github.com/gorilla/websocket v1.5.0 9 | github.com/kataras/pg v1.0.6 10 | ) 11 | 12 | require ( 13 | github.com/gertd/go-pluralize v0.2.1 // indirect 14 | github.com/jackc/pgpassfile v1.0.0 // indirect 15 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect 16 | github.com/jackc/pgx/v5 v5.4.3 // indirect 17 | github.com/jackc/puddle/v2 v2.2.1 // indirect 18 | golang.org/x/crypto v0.14.0 // indirect 19 | golang.org/x/sync v0.4.0 // indirect 20 | golang.org/x/text 
v0.13.0 // indirect 21 | ) 22 | -------------------------------------------------------------------------------- /_examples/live-table/go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/gertd/go-pluralize v0.2.1 h1:M3uASbVjMnTsPb0PNqg+E/24Vwigyo/tvyMTtAlLgiA= 5 | github.com/gertd/go-pluralize v0.2.1/go.mod h1:rbYaKDbsXxmRfr8uygAEKhOWsjyrrqrkHVpZvoOp8zk= 6 | github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= 7 | github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= 8 | github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= 9 | github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= 10 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= 11 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= 12 | github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY= 13 | github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= 14 | github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= 15 | github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= 16 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 17 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 18 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 19 | github.com/stretchr/testify v1.3.0/go.mod 
h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 20 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 21 | github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= 22 | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 23 | golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= 24 | golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= 25 | golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ= 26 | golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= 27 | golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= 28 | golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= 29 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 30 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 31 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 32 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 33 | -------------------------------------------------------------------------------- /_examples/live-table/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | PG Real-Time Data 8 | 45 | 89 | 90 | 91 | 92 | 93 |
94 |

PG Real-Time Data

95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 111 | 112 |
IDUsernameEmail
113 |
114 | 115 | 116 | -------------------------------------------------------------------------------- /_examples/logging/go.mod: -------------------------------------------------------------------------------- 1 | module example_logging 2 | 3 | go 1.21 4 | 5 | replace github.com/kataras/pg => ../../ 6 | 7 | require ( 8 | github.com/kataras/golog v0.1.9 9 | github.com/kataras/pg v0.0.0-00010101000000-000000000000 10 | github.com/kataras/pgx-golog v0.0.1 11 | ) 12 | 13 | require ( 14 | github.com/gertd/go-pluralize v0.2.1 // indirect 15 | github.com/jackc/pgpassfile v1.0.0 // indirect 16 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect 17 | github.com/jackc/pgx/v5 v5.4.3 // indirect 18 | github.com/jackc/puddle/v2 v2.2.1 // indirect 19 | github.com/kataras/pio v0.0.12 // indirect 20 | golang.org/x/crypto v0.9.0 // indirect 21 | golang.org/x/sync v0.1.0 // indirect 22 | golang.org/x/sys v0.9.0 // indirect 23 | golang.org/x/text v0.9.0 // indirect 24 | ) 25 | -------------------------------------------------------------------------------- /_examples/logging/go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/gertd/go-pluralize v0.2.1 h1:M3uASbVjMnTsPb0PNqg+E/24Vwigyo/tvyMTtAlLgiA= 5 | github.com/gertd/go-pluralize v0.2.1/go.mod h1:rbYaKDbsXxmRfr8uygAEKhOWsjyrrqrkHVpZvoOp8zk= 6 | github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= 7 | github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= 8 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= 9 | github.com/jackc/pgservicefile 
v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= 10 | github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY= 11 | github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= 12 | github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= 13 | github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= 14 | github.com/kataras/golog v0.1.9 h1:vLvSDpP7kihFGKFAvBSofYo7qZNULYSHOH2D7rPTKJk= 15 | github.com/kataras/golog v0.1.9/go.mod h1:jlpk/bOaYCyqDqH18pgDHdaJab72yBE6i0O3s30hpWY= 16 | github.com/kataras/pgx-golog v0.0.1 h1:e8bankbEM/2rKLgtb6wiiB0ze5nY+6cx3wmr1bj+KEI= 17 | github.com/kataras/pgx-golog v0.0.1/go.mod h1:lnfwUCGl9cPXNwu1yiepE+aal6N1vJmpCgb+UGy1p7k= 18 | github.com/kataras/pio v0.0.12 h1:o52SfVYauS3J5X08fNjlGS5arXHjW/ItLkyLcKjoH6w= 19 | github.com/kataras/pio v0.0.12/go.mod h1:ODK/8XBhhQ5WqrAhKy+9lTPS7sBf6O3KcLhc9klfRcY= 20 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 21 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 22 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 23 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 24 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 25 | github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= 26 | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 27 | golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= 28 | golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= 29 | golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= 30 | golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 31 | golang.org/x/sys v0.9.0 
h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= 32 | golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 33 | golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= 34 | golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= 35 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 36 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 37 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 38 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 39 | -------------------------------------------------------------------------------- /_examples/logging/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/kataras/golog" 7 | "github.com/kataras/pg" 8 | pgxgolog "github.com/kataras/pgx-golog" 9 | ) 10 | 11 | // [...] 12 | 13 | const connString = "postgres://postgres:admin!123@localhost:5432/test_db?sslmode=disable&search_path=public" 14 | 15 | func main() { 16 | golog.SetLevel("debug") 17 | schema := pg.NewSchema() 18 | 19 | logger := pgxgolog.NewLogger(golog.Default) 20 | /* 21 | tracer := &tracelog.TraceLog{ 22 | Logger: logger, 23 | LogLevel: tracelog.LogLevelTrace, 24 | } 25 | 26 | connConfig, err := pgxpool.ParseConfig(connString) 27 | if err != nil { 28 | panic(err) 29 | } 30 | 31 | // Set the tracer. 32 | connConfig.ConnConfig.Tracer = tracer 33 | 34 | pool, err := pgxpool.NewWithConfig(context.Background(), connConfig) 35 | if err != nil { 36 | panic(err) 37 | } 38 | 39 | // Use OpenPool instead of Open to use the pool's connections. 
40 | db := pg.OpenPool(schema, pool) 41 | */ 42 | // OR: 43 | db, err := pg.Open(context.Background(), schema, connString, pg.WithLogger(logger)) 44 | if err != nil { 45 | panic(err) 46 | } 47 | defer db.Close() 48 | 49 | rows, err := db.Query(context.Background(), `SELECT * FROM blog_posts;`) 50 | if err != nil { 51 | panic(err) 52 | } 53 | 54 | rows.Close() 55 | } 56 | -------------------------------------------------------------------------------- /_examples/password/go.mod: -------------------------------------------------------------------------------- 1 | module example_password 2 | 3 | go 1.21 4 | 5 | replace github.com/kataras/pg => ../../ 6 | 7 | require github.com/kataras/pg v0.0.0-00010101000000-000000000000 8 | 9 | require ( 10 | github.com/gertd/go-pluralize v0.2.1 // indirect 11 | github.com/jackc/pgpassfile v1.0.0 // indirect 12 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect 13 | github.com/jackc/pgx/v5 v5.4.3 // indirect 14 | github.com/jackc/puddle/v2 v2.2.1 // indirect 15 | golang.org/x/crypto v0.9.0 // indirect 16 | golang.org/x/mod v0.12.0 // indirect 17 | golang.org/x/sync v0.1.0 // indirect 18 | golang.org/x/text v0.9.0 // indirect 19 | ) 20 | -------------------------------------------------------------------------------- /_examples/password/go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/gertd/go-pluralize v0.2.1 h1:M3uASbVjMnTsPb0PNqg+E/24Vwigyo/tvyMTtAlLgiA= 5 | github.com/gertd/go-pluralize v0.2.1/go.mod h1:rbYaKDbsXxmRfr8uygAEKhOWsjyrrqrkHVpZvoOp8zk= 6 | github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= 7 | github.com/jackc/pgpassfile v1.0.0/go.mod 
h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= 8 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= 9 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= 10 | github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY= 11 | github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= 12 | github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= 13 | github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= 14 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 15 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 16 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 17 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 18 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 19 | github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= 20 | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 21 | golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= 22 | golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= 23 | golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= 24 | golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= 25 | golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= 26 | golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 27 | golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= 28 | golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= 29 | gopkg.in/check.v1 
v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 30 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 31 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 32 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 33 | -------------------------------------------------------------------------------- /_examples/password/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "time" 8 | 9 | "github.com/kataras/pg" 10 | "github.com/kataras/pg/gen" 11 | ) 12 | 13 | func init() { 14 | pg.SetDefaultTag("pg") // you can modify it to "db" as well. 15 | } 16 | 17 | type Base struct { 18 | ID string `pg:"type=uuid,primary"` 19 | CreatedAt time.Time `pg:"type=timestamp,default=clock_timestamp()"` 20 | UpdatedAt time.Time `pg:"type=timestamp,default=clock_timestamp()"` 21 | } 22 | 23 | type User struct { 24 | Base 25 | 26 | Firstname string `pg:"type=varchar(255)"` 27 | Lastname string `pg:"type=varchar(255)"` 28 | Email string `pg:"type=varchar(255),username,unique,conflict=DO UPDATE SET email=EXCLUDED.email"` 29 | Password string `pg:"type=varchar(72),password" json:"password,omitempty"` 30 | } 31 | 32 | /* 33 | // Use a PasswordHandler and Encrypt and Decrypt to manually encrypt and decrypt passwords. 34 | // However, for better security, just use the: `pg:"type=varchar(72),password"` tag for password fields 35 | // and let the library do the job for you. 36 | var passwordHandler = pg.PasswordHandler{ 37 | Encrypt: func(tableName, plainPassword string) (encryptedPassword string, err error) { 38 | return 39 | }, 40 | // If you don't want to set passwords on Select then skip this Decrypt field. 
41 | Decrypt: func(tableName, encryptedPassword string) (plainPassword string, err error) { 42 | return 43 | }, 44 | } 45 | 46 | schema.HandlePassword(passwordHandler) 47 | */ 48 | 49 | func main() { 50 | // Create Schema instance. 51 | schema := pg.NewSchema() 52 | schema.MustRegister("users", User{}) 53 | 54 | // Optionally generate the files for the given schema. 55 | // This can be used to statically have access to column names of each registered table. 56 | // It's not required to run this, it's just a helper 57 | // for a separate CLI flag to generate-only your table definition. 58 | // 59 | // Generated code usage: 60 | // definition.User.PG_TableName // "users" 61 | // definition.User.CreatedAt.String() // "created_at" 62 | // definition.User.Firstname.String() // "firstname" 63 | defer func() { 64 | opts := gen.ExportOptions{ 65 | RootDir: "./definition", 66 | } 67 | gen.GenerateColumnsFromSchema(schema, &opts) 68 | }() 69 | // Create Database instance. 70 | /* 71 | Available connection string formats: 72 | - 73 | connString := fmt.Sprintf("host=%s port=%d user=%s password=%s search_path=%s dbname=%s sslmode=%s", 74 | host, port, user, password, schema, dbname, sslMode) 75 | - 76 | connString := "postgres://postgres:admin!123@localhost:5432/test_db?sslmode=disable&search_path=public" 77 | */ 78 | 79 | connString := "postgres://postgres:admin!123@localhost:5432/test_db?sslmode=disable&search_path=public" 80 | db, err := pg.Open(context.Background(), schema, connString) 81 | if err != nil { 82 | log.Fatal(fmt.Errorf("open database: %w", err)) 83 | } 84 | defer db.Close() 85 | 86 | if err = db.CreateSchema(context.Background()); err != nil { 87 | log.Fatal(fmt.Errorf("create schema: %w", err)) 88 | } 89 | 90 | if err = db.CheckSchema(context.Background()); err != nil { 91 | log.Fatal(fmt.Errorf("check schema: %w", err)) 92 | } 93 | 94 | // Create a Repository of User type. 
95 | users := pg.NewRepository[User](db) 96 | 97 | var newUser = User{ 98 | Firstname: "John", 99 | Lastname: "Doe", 100 | Email: "kataras2006@hotmail.com", 101 | Password: "123456", 102 | } 103 | 104 | // Insert a new User with credentials. 105 | err = users.InsertSingle(context.Background(), newUser, &newUser.ID) 106 | if err != nil { 107 | log.Fatal(fmt.Errorf("insert user: %w", err)) 108 | } 109 | 110 | // Get by id. 111 | query := `SELECT * FROM users WHERE id = $1 LIMIT 1;` 112 | existingUser, err := users.SelectSingle(context.Background(), query, newUser.ID) 113 | if err != nil { 114 | log.Fatal(err) 115 | } 116 | 117 | log.Printf("Existing User (SelectSingle):\n%#+v\n", existingUser) 118 | 119 | // Check credentials. 120 | verifiedUser, err := users.SelectByUsernameAndPassword(context.Background(), "kataras2006@hotmail.com", "123456") 121 | if err != nil { // will return pg.ErrNoRows if not found (invalid username or password). 122 | log.Fatal(err) 123 | } 124 | verifiedUser.Password = "" // clear the password if you want (it contains the encrypted anyways). 125 | 126 | log.Printf("Verified User (SelectByUsernameAndPassword):\n%#+v\n", verifiedUser) 127 | } 128 | -------------------------------------------------------------------------------- /_examples/presenter/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | 8 | "github.com/kataras/pg" 9 | ) 10 | 11 | type TableInfo struct { 12 | TableName string `pg:"table_name"` 13 | TableType string `pg:"table_type"` 14 | } 15 | 16 | func main() { 17 | // Create Schema instance. 18 | schema := pg.NewSchema() 19 | // Register the table as a presenter, the third argument is the only important step here. 20 | schema.MustRegister("table_info_presenters", TableInfo{}, pg.Presenter) 21 | 22 | // Create Database instance. 
23 | connString := "postgres://postgres:admin!123@localhost:5432/test_db?sslmode=disable" 24 | db, err := pg.Open(context.Background(), schema, connString) 25 | if err != nil { 26 | log.Fatal(fmt.Errorf("open database: %w", err)) 27 | } 28 | defer db.Close() 29 | 30 | if err = db.CreateSchema(context.Background()); err != nil { 31 | log.Fatal(err) 32 | } 33 | 34 | if err = db.CheckSchema(context.Background()); err != nil { 35 | log.Fatal(err) 36 | } 37 | 38 | repo := pg.NewRepository[TableInfo](db) 39 | 40 | // This can be created through normal table registration but 41 | // this is just an example of how to use the presenter. 42 | tables, err := repo.Select(context.Background(), `SELECT table_name,table_type FROM information_schema.tables WHERE table_schema = $1;`, db.SearchPath()) 43 | if err != nil { 44 | log.Fatal(err) 45 | } 46 | 47 | fmt.Printf("Found %d table(s).\n", len(tables)) 48 | for _, t := range tables { 49 | fmt.Printf("- %s (%s)", t.TableName, t.TableType) 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /_examples/view/_embed/example.sql: -------------------------------------------------------------------------------- 1 | -- your commands here, can be splitted using ';' 2 | CREATE OR REPLACE VIEW blog_master AS 3 | SELECT b.*, COUNT(bp) as posts_count 4 | FROM blogs b 5 | INNER JOIN blog_posts bp ON blog_id = b.id 6 | GROUP BY b.id; -------------------------------------------------------------------------------- /_examples/view/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "embed" 6 | "fmt" 7 | "log" 8 | "time" 9 | 10 | "github.com/kataras/pg" 11 | ) 12 | 13 | //go:embed _embed 14 | var embedDir embed.FS 15 | 16 | type ( 17 | BaseEntity struct { 18 | ID string `pg:"type=uuid,primary"` 19 | CreatedAt time.Time `pg:"type=timestamp,default=clock_timestamp()"` 20 | UpdatedAt time.Time 
`pg:"type=timestamp,default=clock_timestamp()"` 21 | } 22 | 23 | Blog struct { 24 | BaseEntity 25 | 26 | Name string `pg:"type=varchar(255)"` 27 | } 28 | 29 | BlogMaster struct { 30 | Blog 31 | PostsCount int64 `pg:"type=bigint"` 32 | } 33 | ) 34 | 35 | func main() { 36 | // Create Schema instance. 37 | schema := pg.NewSchema() 38 | // Register the table as a view, the third argument is the only important step here. 39 | // This view is created through _embed/example.sql file. 40 | schema.MustRegister("blog_master", BlogMaster{}, pg.View) 41 | 42 | // Create Database instance. 43 | connString := "postgres://postgres:admin!123@localhost:5432/test_db?sslmode=disable" 44 | db, err := pg.Open(context.Background(), schema, connString) 45 | if err != nil { 46 | log.Fatal(fmt.Errorf("open database: %w", err)) 47 | } 48 | defer db.Close() 49 | 50 | // Here you can define your functions, triggers, tables and e.t.c. as an embedded sql file which 51 | // should be executed on the database. 52 | if err = db.ExecFiles(context.Background(), embedDir, "_embed/example.sql"); err != nil { 53 | log.Fatal(err) 54 | } 55 | 56 | // Optional, and this doesn't have any meaning here 57 | // because we explore just the "views" example here. 
58 | if err := db.CreateSchema(context.Background()); err != nil { 59 | log.Fatal(err) 60 | } 61 | 62 | if err := db.CheckSchema(context.Background()); err != nil { 63 | log.Fatal(err) 64 | } 65 | // 66 | 67 | repo := pg.NewRepository[BlogMaster](db) 68 | blogs, err := repo.Select(context.Background(), `SELECT * FROM blog_master`) 69 | if err != nil { 70 | log.Fatal(fmt.Errorf("select all blog masters: %w", err)) 71 | } 72 | 73 | for _, blog := range blogs { 74 | fmt.Printf("%s: posts count: %d\n", blog.Name, blog.PostsCount) 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /common.go: -------------------------------------------------------------------------------- 1 | package pg 2 | 3 | import ( 4 | "context" 5 | ) 6 | 7 | // QuerySlice executes the given query and returns a list of T entries. 8 | // Note that the rows scanner will directly scan an element of T, meaning 9 | // that the type of T should be a database scannabled type (e.g. string, int, time.Time, etc.). 10 | // 11 | // The ErrNoRows is discarded, an empty list and a nil error will be returned instead. 12 | // If a string column is empty then it's skipped from the returning list. 13 | // Example: 14 | // 15 | // names, err := QuerySlice[string](ctx, db, "SELECT name FROM users;") 16 | func QuerySlice[T any](ctx context.Context, db *DB, query string, args ...any) ([]T, error) { 17 | rows, err := db.Query(ctx, query, args...) 
18 | if err != nil { 19 | return nil, err 20 | } 21 | defer rows.Close() 22 | 23 | var t T 24 | _, isString := any(t).(string) 25 | 26 | var list []T 27 | 28 | for rows.Next() { 29 | var entry T 30 | if err = rows.Scan(&entry); err != nil { 31 | return nil, err 32 | } 33 | 34 | if isString { 35 | if any(entry).(string) == "" { 36 | continue 37 | } 38 | } 39 | 40 | list = append(list, entry) 41 | } 42 | 43 | if err = rows.Err(); err != nil && err != ErrNoRows { 44 | return nil, err 45 | } 46 | 47 | return list, nil 48 | } 49 | 50 | // QueryTwoSlices executes the given query and returns two lists of T and V entries. 51 | // Same behavior as QuerySlice but with two lists. 52 | func QueryTwoSlices[T, V any](ctx context.Context, db *DB, query string, args ...any) ([]T, []V, error) { 53 | rows, err := db.Query(ctx, query, args...) 54 | if err != nil { 55 | return nil, nil, err 56 | } 57 | defer rows.Close() 58 | 59 | var ( 60 | tList []T 61 | vList []V 62 | ) 63 | for rows.Next() { 64 | var ( 65 | t T 66 | v V 67 | ) 68 | if err = rows.Scan(&t, &v); err != nil { 69 | return nil, nil, err 70 | } 71 | 72 | tList = append(tList, t) 73 | vList = append(vList, v) 74 | } 75 | 76 | if err = rows.Err(); err != nil && err != ErrNoRows { 77 | return nil, nil, err 78 | } 79 | 80 | return tList, vList, nil 81 | } 82 | 83 | // QuerySingle executes the given query and returns a single T entry. 84 | // 85 | // Example: 86 | // 87 | // names, err := QuerySingle[MyType](ctx, db, "SELECT a_json_field FROM users;") 88 | func QuerySingle[T any](ctx context.Context, db *DB, query string, args ...any) (entry T, err error) { 89 | err = db.QueryRow(ctx, query, args...).Scan(&entry) 90 | return 91 | } 92 | 93 | func scanQuery[T any](ctx context.Context, db *DB, scanner func(rows Rows) (T, error), query string, args ...any) ([]T, error) { 94 | rows, err := db.Query(ctx, query, args...) 
95 | if err != nil { 96 | return nil, err 97 | } 98 | defer rows.Close() 99 | var list []T 100 | 101 | for rows.Next() { 102 | entry, err := scanner(rows) 103 | if err != nil { 104 | return nil, err 105 | } 106 | 107 | list = append(list, entry) 108 | } 109 | 110 | if err = rows.Err(); err != nil && err != ErrNoRows { 111 | return nil, err 112 | } 113 | 114 | return list, nil 115 | } 116 | -------------------------------------------------------------------------------- /concurrent_tx.go: -------------------------------------------------------------------------------- 1 | package pg 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | 7 | "github.com/jackc/pgx/v5" 8 | "github.com/jackc/pgx/v5/pgconn" 9 | "github.com/jackc/pgx/v5/pgxpool" 10 | ) 11 | 12 | // ConcurrentTx is a wrapper around pgx.Tx that provides a mutex to synchronize access 13 | // to the underlying pgx.Tx. This is useful when you want to use a pgx.Tx from 14 | // multiple goroutines. 15 | type ConcurrentTx struct { 16 | pgx.Tx 17 | mu sync.Mutex 18 | } 19 | 20 | // NewConcurrentTx is a wrapper around pgxpool.Pool.Begin that provides a mutex to synchronize 21 | // access to the underlying pgx.Tx. 22 | // It returns a TxSync that wraps the pgx.Tx. 23 | // The TxSync must be closed when done with it. 24 | func NewConcurrentTx(ctx context.Context, p *pgxpool.Pool) (*ConcurrentTx, error) { 25 | tx, err := p.Begin(ctx) 26 | if err != nil { 27 | return nil, err 28 | } 29 | 30 | return &ConcurrentTx{Tx: tx}, nil 31 | } 32 | 33 | // Rollback is a wrapper around pgx.Tx.Rollback that provides a mutex to synchronize 34 | // access to the underlying pgx.Tx. 35 | func (ct *ConcurrentTx) Rollback(ctx context.Context) error { 36 | ct.mu.Lock() 37 | defer ct.mu.Unlock() 38 | 39 | return ct.Tx.Rollback(ctx) 40 | } 41 | 42 | // Commit is a wrapper around pgx.Tx.Commit that provides a mutex to synchronize 43 | // access to the underlying pgx.Tx. 
44 | func (ct *ConcurrentTx) Commit(ctx context.Context) error { 45 | ct.mu.Lock() 46 | defer ct.mu.Unlock() 47 | 48 | return ct.Tx.Commit(ctx) 49 | } 50 | 51 | // QueryRow is a wrapper around pgx.Tx.QueryRow that provides a mutex to synchronize 52 | // access to the underlying pgx.Tx. 53 | func (ct *ConcurrentTx) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { 54 | ct.mu.Lock() 55 | defer ct.mu.Unlock() 56 | 57 | return ct.Tx.QueryRow(ctx, sql, args...) 58 | } 59 | 60 | // Query is a wrapper around pgx.Tx.Query that provides a mutex to synchronize 61 | // access to the underlying pgx.Tx. 62 | func (ct *ConcurrentTx) Query(ctx context.Context, sql string, args ...any) (Rows, error) { 63 | ct.mu.Lock() 64 | defer ct.mu.Unlock() 65 | 66 | return ct.Tx.Query(ctx, sql, args...) 67 | } 68 | 69 | // QueryRow is a wrapper around pgx.Tx.QueryRow that provides a mutex to synchronize 70 | // access to the underlying pgx.Tx. 71 | func (ct *ConcurrentTx) Exec(ctx context.Context, sql string, args ...any) (commandTag pgconn.CommandTag, err error) { 72 | ct.mu.Lock() 73 | defer ct.mu.Unlock() 74 | 75 | return ct.Tx.Exec(ctx, sql, args...) 76 | } 77 | 78 | // Prepare is a wrapper around pgx.Tx.Prepare that provides a mutex to synchronize 79 | // access to the underlying pgx.Tx. 80 | func (ct *ConcurrentTx) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) { 81 | ct.mu.Lock() 82 | defer ct.mu.Unlock() 83 | 84 | return ct.Tx.Prepare(ctx, name, sql) 85 | } 86 | 87 | // SendBatch is a wrapper around pgx.Tx.SendBatch that provides a mutex to synchronize 88 | // access to the underlying pgx.Tx. 89 | func (ct *ConcurrentTx) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { 90 | ct.mu.Lock() 91 | defer ct.mu.Unlock() 92 | 93 | return ct.Tx.SendBatch(ctx, b) 94 | } 95 | 96 | // Begin is a wrapper around pgx.Tx.Begin that provides a mutex to synchronize 97 | // access to the underlying pgx.Tx. 
98 | func (ct *ConcurrentTx) Begin(ctx context.Context) (pgx.Tx, error) { 99 | ct.mu.Lock() 100 | defer ct.mu.Unlock() 101 | 102 | tx, err := ct.Tx.Begin(ctx) 103 | if err != nil { 104 | return nil, err 105 | } 106 | 107 | return &ConcurrentTx{Tx: tx}, nil 108 | } 109 | -------------------------------------------------------------------------------- /db_example_test.go: -------------------------------------------------------------------------------- 1 | package pg 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | ) 7 | 8 | func ExampleOpen() { 9 | db, err := openTestConnection(true) 10 | if err != nil { 11 | handleExampleError(err) 12 | return 13 | } 14 | defer db.Close() 15 | 16 | // Work with the database... 17 | } 18 | 19 | func openTestConnection(resetSchema bool) (*DB, error) { 20 | // Database code. 21 | schema := NewSchema() 22 | schema.MustRegister("customers", Customer{}) // Register the Customer struct as a table named "customers". 23 | schema.MustRegister("blogs", Blog{}) // Register the Blog struct as a table named "blogs". 24 | schema.MustRegister("blog_posts", BlogPost{}) // Register the BlogPost struct as a table named "blog_posts". 25 | 26 | // Open a connection to the database using the schema and the connection string. 27 | db, err := Open(context.Background(), schema, getTestConnString()) 28 | if err != nil { 29 | return nil, err 30 | } 31 | // Let the caller close the database connection pool: defer db.Close() 32 | 33 | if resetSchema { 34 | // Let's clear the schema, so we can run the tests even if already ran once in the past. 35 | if err = db.DeleteSchema(context.Background()); err != nil { // DON'T DO THIS ON PRODUCTION. 36 | return nil, fmt.Errorf("delete schema: %w", err) 37 | } 38 | 39 | if err = db.CreateSchema(context.Background()); err != nil { // Create the schema in the database if it does not exist. 
40 | return nil, fmt.Errorf("create schema: %w", err) 41 | } 42 | 43 | if err = db.CheckSchema(context.Background()); err != nil { // Check if the schema in the database matches the schema in the code. 44 | return nil, fmt.Errorf("check schema: %w", err) 45 | } 46 | } 47 | 48 | return db, nil 49 | } 50 | 51 | func openEmptyTestConnection() (*DB, error) { // without a schema. 52 | schema := NewSchema() 53 | // Open a connection to the database using the schema and the connection string. 54 | return Open(context.Background(), schema, getTestConnString()) 55 | } 56 | 57 | func createTestConnectionSchema() error { 58 | db, err := openTestConnection(true) 59 | if err != nil { 60 | return err 61 | } 62 | 63 | db.Close() 64 | return nil 65 | } 66 | 67 | // getTestConnString returns a connection string for connecting to a test database. 68 | // It uses constants to define the host, port, user, password, schema, dbname, and sslmode parameters. 69 | func getTestConnString() string { 70 | const ( 71 | host = "localhost" // The host name or IP address of the database server. 72 | port = 5432 // The port number of the database server. 73 | user = "postgres" // The user name to connect to the database with. 74 | password = "admin!123" // The password to connect to the database with. 75 | schema = "public" // The schema name to use in the database. 76 | dbname = "test_db" // The database name to connect to. 77 | sslMode = "disable" // The SSL mode to use for the connection. Can be disable, require, verify-ca or verify-full. 78 | ) 79 | 80 | connString := fmt.Sprintf("host=%s port=%d user=%s password=%s search_path=%s dbname=%s sslmode=%s", 81 | host, port, user, password, schema, dbname, sslMode) // Format the connection string using the parameters. 82 | 83 | return connString // Return the connection string. 
84 | } 85 | -------------------------------------------------------------------------------- /db_information_test.go: -------------------------------------------------------------------------------- 1 | package pg 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | ) 7 | 8 | // This should match the CI's postgres version. 9 | const expectedDBVersion = "16" 10 | 11 | func TestInformation_GetVersion(t *testing.T) { 12 | db, err := openEmptyTestConnection() 13 | if err != nil { 14 | t.Fatal(err) 15 | } 16 | defer db.Close() 17 | 18 | version, err := db.GetVersion(context.Background()) 19 | if err != nil { 20 | t.Fatal(err) 21 | } 22 | 23 | if version != expectedDBVersion { 24 | t.Fatalf("expected version: %s but got: %s", expectedDBVersion, version) 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /db_stat.go: -------------------------------------------------------------------------------- 1 | package pg 2 | 3 | import "time" 4 | 5 | // PoolStat holds the database pool's statistics. 6 | type PoolStat struct { 7 | // AcquireCount is the cumulative count of successful acquires from the pool. 8 | AcquireCount int64 `json:"acquire_count"` 9 | // AcquireDuration is the total duration of all successful acquires from 10 | // the pool. 11 | AcquireDuration time.Duration `json:"acquire_duration"` 12 | // AcquiredConns is the number of currently acquired connections in the pool. 13 | AcquiredConns int32 `json:"acquired_conns"` 14 | // CanceledAcquireCount is the cumulative count of acquires from the pool 15 | // that were canceled by a context. 16 | CanceledAcquireCount int64 `json:"canceled_acquire_count"` 17 | // ConstructingConns is the number of conns with construction in progress in 18 | // the pool. 
19 | ConstructingConns int32 `json:"constructing_conns"` 20 | // EmptyAcquireCount is the cumulative count of successful acquires from the pool 21 | // that waited for a resource to be released or constructed because the pool was 22 | // empty. 23 | EmptyAcquireCount int64 `json:"empty_acquire_count"` 24 | // IdleConns is the number of currently idle conns in the pool. 25 | IdleConns int32 `json:"idle_conns"` 26 | // MaxConns is the maximum size of the pool. 27 | MaxConns int32 `json:"max_conns"` 28 | // TotalConns is the total number of resources currently in the pool. 29 | // The value is the sum of ConstructingConns, AcquiredConns, and 30 | // IdleConns. 31 | TotalConns int32 `json:"total_conns"` 32 | } 33 | 34 | // PoolStat returns a snapshot of the database pool statistics. 35 | // The returned structure can be represented through JSON. 36 | func (db *DB) PoolStat() PoolStat { 37 | stats := db.Pool.Stat() 38 | return PoolStat{ 39 | AcquireCount: stats.AcquireCount(), 40 | AcquireDuration: stats.AcquireDuration(), 41 | AcquiredConns: stats.AcquiredConns(), 42 | CanceledAcquireCount: stats.CanceledAcquireCount(), 43 | ConstructingConns: stats.ConstructingConns(), 44 | EmptyAcquireCount: stats.EmptyAcquireCount(), 45 | IdleConns: stats.IdleConns(), 46 | MaxConns: stats.MaxConns(), 47 | TotalConns: stats.TotalConns(), 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /db_table_listener.go: -------------------------------------------------------------------------------- 1 | package pg 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "net" 10 | "strings" 11 | "sync/atomic" 12 | 13 | "github.com/kataras/pg/desc" 14 | ) 15 | 16 | // TableChangeType is the type of the table change. 17 | // Available values: INSERT, UPDATE, DELETE. 18 | type TableChangeType string 19 | 20 | const ( 21 | // TableChangeTypeInsert is the INSERT table change type. 
22 | TableChangeTypeInsert TableChangeType = "INSERT" 23 | // TableChangeTypeUpdate is the UPDATE table change type. 24 | TableChangeTypeUpdate TableChangeType = "UPDATE" 25 | // TableChangeTypeDelete is the DELETE table change type. 26 | TableChangeTypeDelete TableChangeType = "DELETE" 27 | ) 28 | 29 | func changesToString(changes []TableChangeType) string { 30 | if len(changes) == 0 { 31 | return "" 32 | } 33 | 34 | var b strings.Builder 35 | for i, change := range changes { 36 | b.WriteString(string(change)) 37 | if i < len(changes)-1 { 38 | b.WriteString(" OR ") 39 | } 40 | } 41 | 42 | return b.String() 43 | } 44 | 45 | type ( 46 | // TableNotification is the notification message sent by the postgresql server 47 | // when a table change occurs. 48 | // The subscribed postgres channel is named 'table_change_notifications'. 49 | // The "old" and "new" fields are the old and new values of the row. 50 | // The "old" field is only available for UPDATE and DELETE table change types. 51 | // The "new" field is only available for INSERT and UPDATE table change types. 52 | // The "old" and "new" fields are raw json values, use the "json.Unmarshal" to decode them. 53 | // See "DB.ListenTable" method. 54 | TableNotification[T any] struct { 55 | Table string `json:"table"` 56 | Change TableChangeType `json:"change"` // INSERT, UPDATE, DELETE. 57 | 58 | New T `json:"new"` 59 | Old T `json:"old"` 60 | 61 | payload string `json:"-"` /* just in case */ 62 | } 63 | 64 | // TableNotificationJSON is the generic version of the TableNotification. 65 | TableNotificationJSON = TableNotification[json.RawMessage] 66 | ) 67 | 68 | // GetPayload returns the raw payload of the notification. 69 | func (tn TableNotification[T]) GetPayload() string { 70 | return tn.payload 71 | } 72 | 73 | // ListenTableOptions is the options for the "DB.ListenTable" method. 74 | type ListenTableOptions struct { 75 | // Tables map of table name and changes to listen for. 
76 | // 77 | // Key is the table to listen on for changes. 78 | // Value is changes is the list of table changes to listen for. 79 | // Defaults to {"*": ["INSERT", "UPDATE", "DELETE"] }. 80 | Tables map[string][]TableChangeType 81 | 82 | // Channel is the name of the postgres channel to listen on. 83 | // Default: "table_change_notifications". 84 | Channel string 85 | 86 | // Function is the name of the postgres function 87 | // which is used to notify on table changes, the 88 | // trigger name is _. 89 | // Defaults to "table_change_notify". 90 | Function string 91 | } 92 | 93 | var defaultChangesToWatch = []TableChangeType{TableChangeTypeInsert, TableChangeTypeUpdate, TableChangeTypeDelete} 94 | 95 | func (opts *ListenTableOptions) setDefaults() { 96 | if opts.Channel == "" { 97 | opts.Channel = "table_change_notifications" 98 | } 99 | 100 | if opts.Function == "" { 101 | opts.Function = "table_change_notify" 102 | } 103 | 104 | if len(opts.Tables) == 0 { 105 | opts.Tables = map[string][]TableChangeType{wildcardTableStr: defaultChangesToWatch} 106 | } 107 | } 108 | 109 | const wildcardTableStr = "*" 110 | 111 | // PrepareListenTable prepares the table for listening for live table updates. 112 | // See "db.ListenTable" method for more. 113 | func (db *DB) PrepareListenTable(ctx context.Context, opts *ListenTableOptions) error { 114 | opts.setDefaults() 115 | 116 | isWildcard := false 117 | for table := range opts.Tables { 118 | if table == wildcardTableStr { 119 | isWildcard = true 120 | break 121 | } 122 | } 123 | 124 | if isWildcard { 125 | changesToWatch := opts.Tables[wildcardTableStr] 126 | if len(changesToWatch) == 0 { 127 | return nil 128 | } 129 | 130 | delete(opts.Tables, wildcardTableStr) // remove the wildcard entry and replace with table names in registered schema. 
131 | for _, table := range db.schema.TableNames(desc.TableTypeBase) { 132 | opts.Tables[table] = changesToWatch 133 | } 134 | } 135 | 136 | if len(opts.Tables) == 0 { 137 | return nil 138 | } 139 | 140 | for table, changes := range opts.Tables { 141 | if err := db.prepareListenTable(ctx, opts.Channel, opts.Function, table, changes); err != nil { 142 | return err 143 | } 144 | } 145 | 146 | return nil 147 | } 148 | 149 | // PrepareListenTable prepares the table for listening for live table updates. 150 | // See "db.ListenTable" method for more. 151 | func (db *DB) prepareListenTable(ctx context.Context, channel, function, table string, changes []TableChangeType) error { 152 | if table == "" { 153 | return errors.New("empty table name") 154 | } 155 | 156 | if len(changes) == 0 { 157 | return nil 158 | } 159 | 160 | if atomic.LoadUint32(db.tableChangeNotifyFunctionOnce) == 0 { 161 | // First, check and create the trigger for all tables. 162 | query := fmt.Sprintf(` 163 | CREATE OR REPLACE FUNCTION %s() RETURNS trigger AS $$ 164 | DECLARE 165 | payload text; 166 | channel text := '%s'; 167 | 168 | BEGIN 169 | SELECT json_build_object('table', TG_TABLE_NAME, 'change', TG_OP, 'old', OLD, 'new', NEW)::text 170 | INTO payload; 171 | PERFORM pg_notify(channel, payload); 172 | IF (TG_OP = 'DELETE') THEN 173 | RETURN OLD; 174 | ELSE 175 | RETURN NEW; 176 | END IF; 177 | END; 178 | $$ 179 | LANGUAGE plpgsql;`, function, channel) 180 | 181 | _, err := db.Exec(ctx, query) 182 | if err != nil { 183 | return fmt.Errorf("create or replace function table_change_notify: %w", err) 184 | } 185 | 186 | atomic.StoreUint32(db.tableChangeNotifyFunctionOnce, 1) 187 | } 188 | 189 | db.tableChangeNotifyOnceMutex.RLock() 190 | _, triggerCreated := db.tableChangeNotifyTriggerOnce[table] 191 | db.tableChangeNotifyOnceMutex.RUnlock() 192 | if !triggerCreated { 193 | query := fmt.Sprintf(`CREATE OR REPLACE TRIGGER %s_%s 194 | AFTER %s 195 | ON %s 196 | FOR EACH ROW 197 | EXECUTE FUNCTION 
table_change_notify();`, table, function, changesToString(changes), table) 198 | 199 | _, err := db.Exec(ctx, query) 200 | if err != nil { 201 | return fmt.Errorf("create trigger %s_table_change_notify: %w", table, err) 202 | } 203 | 204 | db.tableChangeNotifyOnceMutex.Lock() 205 | db.tableChangeNotifyTriggerOnce[table] = struct{}{} 206 | db.tableChangeNotifyOnceMutex.Unlock() 207 | } 208 | 209 | return nil 210 | } 211 | 212 | // ListenTable registers a function which notifies on the given "table" changes (INSERT, UPDATE, DELETE), 213 | // the subscribed postgres channel is named 'table_change_notifications'. 214 | // 215 | // The callback function can return any other error to stop the listener. 216 | // The callback function can return nil to continue listening. 217 | // 218 | // TableNotification's New and Old fields are raw json values, use the "json.Unmarshal" to decode them 219 | // to the actual type. 220 | func (db *DB) ListenTable(ctx context.Context, opts *ListenTableOptions, callback func(TableNotificationJSON, error) error) (Closer, error) { 221 | if err := db.PrepareListenTable(ctx, opts); err != nil { 222 | return nil, err 223 | } 224 | 225 | conn, err := db.Listen(ctx, opts.Channel) 226 | if err != nil { 227 | return nil, err 228 | } 229 | 230 | go func() { 231 | defer conn.Close(ctx) 232 | 233 | for { 234 | var evt TableNotificationJSON 235 | 236 | notification, err := conn.Accept(ctx) 237 | if err != nil { 238 | if errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, net.ErrClosed) { 239 | return // may produced by close. 240 | } 241 | 242 | if callback(evt, err) != nil { 243 | return 244 | } 245 | } 246 | 247 | // make payload available for debugging on errors. 
248 | evt.payload = notification.Payload 249 | 250 | if err = json.Unmarshal([]byte(notification.Payload), &evt); err != nil { 251 | if callback(evt, err) != nil { 252 | return 253 | } 254 | } 255 | 256 | if err = callback(evt, nil); err != nil { 257 | // callback(evt, err) 258 | return 259 | } 260 | } 261 | }() 262 | 263 | return conn, nil 264 | } 265 | -------------------------------------------------------------------------------- /db_table_listener_example_test.go: -------------------------------------------------------------------------------- 1 | package pg 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | ) 8 | 9 | func ExampleDB_ListenTable() { 10 | db, err := openTestConnection(true) 11 | if err != nil { 12 | handleExampleError(err) 13 | return 14 | } 15 | defer db.Close() 16 | 17 | opts := &ListenTableOptions{ 18 | Tables: map[string][]TableChangeType{"customers": defaultChangesToWatch}, 19 | } 20 | closer, err := db.ListenTable(context.Background(), opts, func(evt TableNotificationJSON, err error) error { 21 | if err != nil { 22 | fmt.Printf("received error: %v\n", err) 23 | return err 24 | } 25 | 26 | if evt.Change == "INSERT" { 27 | fmt.Printf("table: %s, event: %s, old: %s\n", evt.Table, evt.Change, string(evt.Old)) // new can't be predicated through its ID and timestamps. 
28 | } else { 29 | fmt.Printf("table: %s, event: %s\n", evt.Table, evt.Change) 30 | } 31 | 32 | return nil 33 | }) 34 | if err != nil { 35 | fmt.Println(err) 36 | return 37 | } 38 | defer closer.Close(context.Background()) 39 | 40 | newCustomer := Customer{ 41 | CognitoUserID: "766064d4-a2a7-442d-aa75-33493bb4dbb9", 42 | Email: "kataras2024@hotmail.com", 43 | Name: "Makis", 44 | } 45 | err = db.InsertSingle(context.Background(), newCustomer, &newCustomer.ID) 46 | if err != nil { 47 | fmt.Println(err) 48 | return 49 | } 50 | 51 | newCustomer.Name = "Makis_UPDATED" 52 | _, err = db.UpdateOnlyColumns(context.Background(), []string{"name"}, newCustomer) 53 | if err != nil { 54 | fmt.Println(err) 55 | return 56 | } 57 | time.Sleep(8 * time.Second) // give it sometime to receive the notifications. 58 | // Output: 59 | // table: customers, event: INSERT, old: null 60 | // table: customers, event: UPDATE 61 | } 62 | -------------------------------------------------------------------------------- /desc/alter_table_constraint_query.go: -------------------------------------------------------------------------------- 1 | package desc 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | // BuildAlterTableForeignKeysQueries creates ALTER TABLE queries for adding foreign key constraints. 
8 | func BuildAlterTableForeignKeysQueries(td *Table) []string { 9 | foreignKeys := td.ForeignKeys() 10 | queries := make([]string, 0, len(foreignKeys)) 11 | 12 | for _, fk := range foreignKeys { 13 | constraintName := fmt.Sprintf("%s_%s_fkey", td.Name, fk.ColumnName) 14 | 15 | dropQuery := fmt.Sprintf(`ALTER TABLE %s DROP CONSTRAINT IF EXISTS %s;`, td.Name, constraintName) 16 | queries = append(queries, dropQuery) 17 | 18 | q := fmt.Sprintf(`ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s) ON DELETE %s`, 19 | td.Name, constraintName, fk.ColumnName, fk.ReferenceTableName, fk.ReferenceColumnName, fk.OnDelete) 20 | 21 | // Add the DEFERRABLE option if applicable 22 | if fk.Deferrable { 23 | q += " DEFERRABLE" 24 | } 25 | 26 | q += ";" 27 | queries = append(queries, q) 28 | } 29 | 30 | return queries 31 | } 32 | -------------------------------------------------------------------------------- /desc/argument.go: -------------------------------------------------------------------------------- 1 | package desc 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | 7 | "github.com/jackc/pgx/v5/pgtype/zeronull" 8 | ) 9 | 10 | // Argument represents a single argument for a database query 11 | // It contains a column definition and a value. 12 | type Argument struct { 13 | Column *Column // the column definition for the argument 14 | Value any // the value for the argument 15 | } 16 | 17 | // Arguments is a slice of Argument. 18 | type Arguments []Argument 19 | 20 | // Values returns a slice of values from the arguments. 21 | func (args Arguments) Values() []any { 22 | values := make([]any, len(args)) // create a slice to hold the values 23 | for i := range args { 24 | values[i] = args[i].Value // assign each value from the argument to the slice 25 | } 26 | 27 | return values // return the slice of values 28 | } 29 | 30 | // ShiftEnd moves the argument with the given column name to the end of the slice. 
31 | func (args *Arguments) ShiftEnd(arg Argument) { 32 | for i, a := range *args { 33 | if a.Column.Name == arg.Column.Name { // already exists, move to the end and return. 34 | *args = shiftToEndEnd(*args, i) 35 | return 36 | } 37 | } 38 | 39 | *args = append(*args, arg) // append the argument to the end of the slice. 40 | } 41 | 42 | func shiftToEndEnd[T any](s []T, x int) []T { 43 | if x < 0 { 44 | return s 45 | } 46 | 47 | if x >= len(s)-1 { 48 | return s 49 | } 50 | 51 | tmp := s[x] 52 | s = append(s[:x], s[x+1:]...) 53 | s = append(s, tmp) 54 | return s 55 | } 56 | 57 | // extractArguments takes a reflect value of a struct and a table definition 58 | // and returns a slice of arguments for each column in the table that is not auto-generated or has a default value. 59 | func extractArguments(td *Table, structValue reflect.Value, filter func(columnName string) bool) (Arguments, error) { 60 | args := make(Arguments, 0, len(td.Columns)) // create a slice to hold the arguments 61 | 62 | for _, c := range td.Columns { // loop over each column in the table definition 63 | if c.AutoGenerated || c.Presenter { 64 | continue // skip this column if it is auto-generated 65 | } 66 | 67 | field := structValue.FieldByIndex(c.FieldIndex) // get the struct field by using the column field index 68 | if !field.CanInterface() { 69 | continue // skip this field if it cannot be converted to an interface 70 | } 71 | 72 | fieldValue := field.Interface() // get the field value as an interface 73 | 74 | // If filter passed, respect just the filter. 75 | if filter != nil { 76 | if !filter(c.Name) { 77 | continue 78 | } 79 | } else { // if no custom filter passed, then check by its zero value if no default value on database. 
80 | if c.Default != "" { 81 | if isZero(field) { 82 | // skip this field if it has a default value and the field value is zero, 83 | // the createTable function has configured the database's default value option 84 | continue 85 | } 86 | } 87 | } 88 | 89 | if c.Default != "" && c.Type == UUID && !c.Nullable && c.PrimaryKey { 90 | if isZero(fieldValue) { 91 | continue // skip this field if it is a UUID primary key and required and the field value is zero 92 | } 93 | } 94 | 95 | if c.Password && td.PasswordHandler.canEncrypt() { 96 | passwordFieldValue, ok := fieldValue.(string) 97 | if !ok { 98 | return nil, fmt.Errorf("password field: %s is not string", c.Name) 99 | } 100 | 101 | if passwordFieldValue == "" { 102 | return nil, fmt.Errorf("password field: %s is empty", c.Name) 103 | } 104 | 105 | encryptedPassword, err := td.PasswordHandler.Encrypt(td.Name, passwordFieldValue) 106 | if err != nil { 107 | return nil, fmt.Errorf("password handler: set: %w", err) 108 | } 109 | 110 | fieldValue = encryptedPassword // replace the value with the new password text 111 | } 112 | 113 | args = append(args, Argument{ 114 | Column: c, // assign the column definition to the argument 115 | Value: fieldValue, // assign the field value to the argument 116 | }) 117 | } 118 | 119 | return args, nil // return the arguments and nil error 120 | } 121 | 122 | // filterArguments takes a slice of arguments and a filter function and returns a slice of arguments. 123 | func filterArguments(args Arguments, filter func(arg *Argument) bool) Arguments { 124 | var filtered Arguments 125 | for _, arg := range args { 126 | if filter(&arg) { 127 | filtered = append(filtered, arg) 128 | } 129 | } 130 | return filtered 131 | } 132 | 133 | // FilterArgumentsForInsert takes a slice of arguments and returns a slice of arguments for insert. 
134 | func filterArgumentsForFullUpdate(args Arguments) Arguments { 135 | return filterArguments(args, func(arg *Argument) bool { 136 | c := arg.Column 137 | 138 | if (c.PrimaryKey || c.ReferenceColumnName != "") && c.Default != "" && c.Type == UUID && c.Nullable { 139 | if isZero(arg.Value) { // fixes full update of a record which contains an optional reference UUID, we allow setting it to null, but 140 | // we have to replace empty string with zeronull.UUID{}. Note that on insert we omit it from the query, as it will default to the default sql line default value. 141 | arg.Value = zeronull.UUID{} 142 | } 143 | 144 | return true 145 | } 146 | 147 | return !c.IsGenerated() && !c.Presenter // && !arg.Column.Unscannable 148 | }) 149 | } 150 | 151 | // extractPrimaryKeyValues takes a table definition and a slice of reflect values of structs 152 | // and returns the primary key column name and a slice of primary key values. 153 | func extractPrimaryKeyValues(td *Table, values []any) (string, []any, error) { 154 | primaryKey, ok := td.PrimaryKey() 155 | if !ok { 156 | return "", nil, fmt.Errorf("no primary key found in table definition: %s", td.Name) 157 | } 158 | 159 | ids := make([]any, 0, len(values)) 160 | for _, value := range values { 161 | idValue, err := ExtractPrimaryKeyValue(primaryKey, IndirectValue(value)) 162 | if err != nil { 163 | return "", nil, err 164 | } 165 | 166 | ids = append(ids, idValue) 167 | } 168 | 169 | return primaryKey.Name, ids, nil 170 | } 171 | 172 | // ExtractPrimaryKeyValue takes a column definition and a reflect value of a struct 173 | func ExtractPrimaryKeyValue(primaryKey *Column, structValue reflect.Value) (any, error) { 174 | idField := structValue.FieldByIndex(primaryKey.FieldIndex) 175 | if idField.IsZero() { 176 | return nil, fmt.Errorf("primary key field value is zero") 177 | } 178 | 179 | if !idField.CanInterface() { 180 | return nil, fmt.Errorf("primary key field value cannot be extracted") 181 | } 182 | 183 | idValue := 
idField.Interface() 184 | if idValue == nil { 185 | return nil, fmt.Errorf("primary key value is nil") 186 | } 187 | 188 | return idValue, nil 189 | } 190 | -------------------------------------------------------------------------------- /desc/column.go: -------------------------------------------------------------------------------- 1 | package desc 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "reflect" 7 | "strings" 8 | ) 9 | 10 | type ( 11 | // ColumnBuilder is an interface that is used to build a column definition. 12 | ColumnBuilder interface { 13 | BuildColumn(*Column) error 14 | } 15 | 16 | // Column is a type that represents a column definition for the database. 17 | Column struct { 18 | Table *Table // the parent table reference. 19 | TableName string // the name of the table this column lives at. 20 | TableDescription string // the description of the table this column lives at. 21 | TableType TableType // the type of the table this column lives at. 22 | 23 | Name string // the name of the column 24 | Type DataType // the data type of the column 25 | Description string // the description of the column 26 | OrdinalPosition int // the position (starting from 1) of the corresponding column in the table. 27 | FieldIndex []int // the index of the corresponding struct field 28 | FieldType reflect.Type // the reflect.Type of the corresponding struct field 29 | isPtr bool // reprots whether FieldType.Kind() == reflect.Ptr. 30 | /* if nil then wasn't able to resolve it by builtin method */ 31 | FieldName string // the name of the corresponding struct field 32 | TypeArgument string // an optional argument for the data type, e.g. 255 when Type is "varchar" 33 | PrimaryKey bool // a flag that indicates if the column is a primary key 34 | Identity bool // a flag that indicates if the column is an identity column, e.g. 
INT GENERATED ALWAYS AS IDENTITY 35 | // Required bool // a flag that indicates if the column is required (not null, let's just use the !Nullable) 36 | Default string // an optional default value or sql function for the column 37 | CheckConstraint string // an optional check constraint for the column 38 | Unique bool // a flag that indicates if the column has a unique constraint (postgres automatically adds an index for that single one) 39 | // As tested, the unique=true and a single column of unique_index is the same result in the database on table creation, 40 | // note that because the generator does store unique_index instead of simple unique on Go generated source files. 41 | Conflict string // an optional conflict action for the unique constraint, e.g do nothing 42 | 43 | // If true this and password field is used to SelectByUsernameAndPassword repository method. 44 | Username bool 45 | // If true Postgres handles password encryption (on inserts) and decryption (on selects), 46 | // note that you MUST set the Schema.HandlePassword in order for this to work by both ways. 47 | // A flag that indicates if the column is a password column. 48 | Password bool 49 | // If true it's a shorthand of default="null". 50 | Nullable bool // a flag that indicates if the column is nullable 51 | ReferenceTableName string // an optional reference table name for a foreign key constraint, e.g. user_profiles(id) -> user_profiles 52 | ReferenceColumnName string // an optional reference column name for a foreign key constraint, e.g. user_profiles(id) -> id 53 | DeferrableReference bool // a flag that indicates if the foreign key constraint is deferrable (omits foreign key checks on transactions) 54 | ReferenceOnDelete string // an optional action for deleting referenced rows when referencing rows are deleted, e.g. NO ACTION, RESTRICT, CASCADE, SET NULL and SET DEFAULT. Defaults to CASCADE. 
55 | 56 | Index IndexType // an optional index type for the column 57 | 58 | // Unique indexes can really improve the performance on big data select queries 59 | // Read more at: https://www.postgresql.org/docs/current/indexes-unique.html 60 | UniqueIndex string // an optional name for a unique index on the column 61 | // If true then create table, insert, update and duplicate queries will omit this column. 62 | Presenter bool 63 | // If true then insert query will omit this column. 64 | AutoGenerated bool 65 | // If true then this column->struct value is skipped from the Select queries 66 | Unscannable bool // a flag that indicates if the column is unscannable 67 | 68 | // If true then this column-> struct field type is already implements a scanner interface for the table. 69 | isScanner bool 70 | } 71 | ) 72 | 73 | // IsGeneratedTimestamp returns true if the column is a timestamp column and 74 | // has a default value of "clock_timestamp()" or "now()". 75 | func (c *Column) IsGeneratedTimestamp() bool { 76 | if c.Type.IsTime() { 77 | defaultValue := strings.ToLower(c.Default) 78 | return (defaultValue == "clock_timestamp()" || defaultValue == "now()") 79 | } 80 | 81 | return false 82 | } 83 | 84 | // IsGeneratedPrimaryUUID returns true if the column is a primary UUID column and 85 | // has a default value of "gen_random_uuid()" or "uuid_generate_v4()". 86 | func (c *Column) IsGeneratedPrimaryUUID() bool { 87 | return c.PrimaryKey && !c.Nullable && c.Type == UUID && 88 | (c.Default == genRandomUUIDPGCryptoFunction1 || c.Default == genRandomUUIDPGCryptoFunction2) 89 | } 90 | 91 | // IsGenerated returns true if the column is a generated column. 
92 | func (c *Column) IsGenerated() bool { 93 | return c.IsGeneratedPrimaryUUID() || c.IsGeneratedTimestamp() 94 | } 95 | 96 | //nolint:all 97 | func writeTagProp(w io.StringWriter, key string, value any) { 98 | if key == "" { 99 | return 100 | } 101 | 102 | if value == nil { 103 | w.WriteString(key) 104 | return 105 | } 106 | 107 | if isZero(value) { 108 | return // don't write if arg value is empty, e.g. "". 109 | } 110 | 111 | // if key[len(key)-1] != '=' { 112 | if !strings.Contains(key, "%") { 113 | // is probably just a boolean (which we don't need to declare its value if true). 114 | w.WriteString(key) 115 | return 116 | } 117 | 118 | if b, ok := value.(bool); ok && b { 119 | w.WriteString(key) 120 | return 121 | } 122 | 123 | _, _ = w.WriteString(fmt.Sprintf(key, value)) 124 | } 125 | 126 | // FieldTagString returns a string representation of the struct field tag for the column. 127 | func (c *Column) FieldTagString(strict bool) string { 128 | b := new(strings.Builder) 129 | b.WriteString(DefaultTag) 130 | b.WriteString(`:"`) 131 | 132 | writeTagProp(b, "name=%s", c.Name) 133 | writeTagProp(b, ",type=%s", c.Type.String()) 134 | if (c.Table != nil && c.Table.Type.IsReadOnly()) || c.TableType.IsReadOnly() { 135 | // If it's a view then don't write the rest of the tags, we only care for name and type. 136 | b.WriteString(`"`) 137 | return b.String() 138 | } 139 | 140 | if strict { 141 | writeTagProp(b, "(%s)", c.TypeArgument) 142 | } 143 | 144 | writeTagProp(b, ",primary", c.PrimaryKey) 145 | writeTagProp(b, ",identity", c.Identity) 146 | //writeTagProp(b, "", c.Required) 147 | if c.Nullable { 148 | // writeTagProp(b, ",default=%s", nullLiteral) 149 | writeTagProp(b, ",nullable", true) 150 | } else { 151 | defaultValue := c.Default 152 | if !strict { 153 | // E.g. {}::integer[], we need to cut the ::integer[] part as it's so strict. 154 | // Cut {}::integer[] the :: part. 
155 | if names, ok := dataTypeText[c.Type]; ok { 156 | for _, name := range names { 157 | defaultValue = strings.TrimSuffix(defaultValue, "::"+name) 158 | } 159 | } 160 | } 161 | 162 | writeTagProp(b, ",default=%s", defaultValue) 163 | } 164 | 165 | writeTagProp(b, ",unique", c.Unique) 166 | writeTagProp(b, ",conflict=%s", c.Conflict) 167 | if strict { 168 | writeTagProp(b, ",username", c.Username) 169 | writeTagProp(b, ",password", c.Password) 170 | } 171 | 172 | if tb := c.ReferenceTableName; tb != "" { 173 | // write the ref line. 174 | writeTagProp(b, ",ref=%s", tb) 175 | 176 | if rc := c.ReferenceColumnName; rc != "" { 177 | writeTagProp(b, "(%s", rc) 178 | 179 | if c.ReferenceOnDelete != "" { 180 | writeTagProp(b, " "+c.ReferenceOnDelete, nil) 181 | } 182 | 183 | writeTagProp(b, " deferrable", c.DeferrableReference) 184 | writeTagProp(b, ")", nil) 185 | } 186 | } 187 | 188 | if c.Index != InvalidIndex { 189 | writeTagProp(b, ",index=%s", c.Index.String()) 190 | } 191 | 192 | writeTagProp(b, ",unique_index=%s", c.UniqueIndex) 193 | writeTagProp(b, ",check=%s", c.CheckConstraint) 194 | if strict { 195 | writeTagProp(b, ",auto", c.AutoGenerated) 196 | writeTagProp(b, ",presenter", c.Presenter) 197 | writeTagProp(b, ",unscannable", c.Unscannable) 198 | } 199 | 200 | b.WriteString(`"`) 201 | return b.String() 202 | } 203 | -------------------------------------------------------------------------------- /desc/column_basic_info.go: -------------------------------------------------------------------------------- 1 | package desc 2 | 3 | import "reflect" 4 | 5 | // ColumnBasicInfo represents a basic column information, contains the table name, column name, ordinal position, column default value, 6 | // data type, data type argument, whether the column is nullable, whether the column is identity and whether the column is generated. 
7 | type ColumnBasicInfo struct { 8 | TableName string 9 | TableDescription string 10 | TableType TableType 11 | Name string 12 | OrdinalPosition int 13 | Description string 14 | Default string 15 | DataType DataType 16 | DataTypeArgument string 17 | IsNullable bool 18 | IsIdentity bool 19 | IsGenerated bool 20 | } 21 | 22 | var _ ColumnBuilder = (*ColumnBasicInfo)(nil) 23 | 24 | func (c *ColumnBasicInfo) BuildColumn(column *Column) error { 25 | column.TableName = c.TableName 26 | column.TableDescription = c.TableDescription 27 | column.TableType = c.TableType 28 | column.Name = c.Name 29 | column.Description = c.Description 30 | column.OrdinalPosition = c.OrdinalPosition 31 | column.Default = c.Default 32 | column.Type = c.DataType 33 | column.TypeArgument = c.DataTypeArgument 34 | column.Nullable = c.IsNullable 35 | // column.Required = !c.IsNullable 36 | column.Identity = c.IsIdentity 37 | column.AutoGenerated = c.IsGenerated 38 | 39 | column.FieldIndex = []int{c.OrdinalPosition} 40 | column.FieldName = ToStructFieldName(c.Name) 41 | 42 | if typ := c.DataType.GoType(); typ != nil { 43 | column.FieldType = typ 44 | column.isPtr = typ.Kind() == reflect.Ptr 45 | } 46 | 47 | return nil 48 | } 49 | -------------------------------------------------------------------------------- /desc/column_filter_text_parser.go: -------------------------------------------------------------------------------- 1 | package desc 2 | 3 | import ( 4 | "fmt" 5 | "sort" 6 | "strings" 7 | ) 8 | 9 | // ColumnFilter is a function that returns whether this column should be live inside a table. 10 | type ColumnFilter func(*Column) bool 11 | 12 | // mergeColumnFilters merges multiple ColumnFilters into one. 
13 | // func mergeColumnFilters(filters ...ColumnFilter) ColumnFilter { 14 | // if len(filters) == 0 { 15 | // return func(*Column) bool { 16 | // return true 17 | // } 18 | // } 19 | // 20 | // toFilter := func(c *Column) bool { 21 | // for _, filter := range filters { 22 | // if !filter(c) { 23 | // return false 24 | // } 25 | // } 26 | // 27 | // return true 28 | // } 29 | // 30 | // return toFilter 31 | // } 32 | 33 | const wildcardLiteral = "*" 34 | 35 | // columnFilterExpression is a type that represents a column filter expression. 36 | type columnFilterExpression struct { 37 | // store input here. 38 | input string 39 | // 40 | tableName string 41 | columnName string 42 | columnDataType DataType 43 | prefix string 44 | suffix string 45 | notEqualTo string 46 | containsColumnNames []string 47 | 48 | // For custom data storage. 49 | Data any // data to store. 50 | } 51 | 52 | // sortColumnFilterExpressions sorts the given ColumnFilterExpressions by static to more dynamic. 53 | func sortColumnFilterExpressions(expressions []*columnFilterExpression) { 54 | sort.SliceStable(expressions, func(i, j int) bool { 55 | c1 := expressions[i] 56 | c2 := expressions[j] 57 | 58 | if c1.tableName != wildcardLiteral && c1.columnName != wildcardLiteral && c1.columnDataType != InvalidDataType && len(c1.containsColumnNames) > 0 { 59 | return true 60 | } 61 | 62 | if c1.tableName == wildcardLiteral && c2.tableName == wildcardLiteral && c1.columnName != wildcardLiteral && c2.columnName != wildcardLiteral { 63 | // checks like target_date pointer and target_date not pointer field types. 
64 | return len(c1.containsColumnNames) > len(c2.containsColumnNames) 65 | } 66 | 67 | if c1.tableName != wildcardLiteral && c1.columnName != wildcardLiteral && c1.columnDataType != InvalidDataType { 68 | return true 69 | } 70 | 71 | if c1.tableName != wildcardLiteral && c1.columnName != wildcardLiteral && 72 | c2.tableName == wildcardLiteral || c2.columnName == wildcardLiteral { 73 | return true 74 | } 75 | 76 | if c1.tableName != wildcardLiteral && c2.tableName == wildcardLiteral { 77 | return true 78 | } 79 | 80 | if c1.columnName != wildcardLiteral && c2.columnName == wildcardLiteral { 81 | return true 82 | } 83 | 84 | return false 85 | }) 86 | } 87 | 88 | // String returns the filter's raw input. 89 | func (p *columnFilterExpression) String() string { 90 | return p.input 91 | } 92 | 93 | // tableNameIsWildcard returns true if the table name is wildcard. 94 | func (p *columnFilterExpression) tableNameIsWildcard() bool { 95 | return p.tableName == wildcardLiteral 96 | } 97 | 98 | // columnNameIsWildcard returns true if the column name is wildcard. 99 | func (p *columnFilterExpression) columnNameIsWildcard() bool { 100 | return p.columnName == wildcardLiteral 101 | } 102 | 103 | // BuildColumnFilter returns a ColumnFilter. 
104 | func (p *columnFilterExpression) BuildColumnFilter(otherColumnNamesInsideTheTable []string) ColumnFilter { 105 | return func(c *Column) bool { 106 | if p.tableName == "" || p.columnName == "" { 107 | return false 108 | } 109 | 110 | if !p.tableNameIsWildcard() { 111 | if c.TableName != p.tableName { 112 | return false 113 | } 114 | } 115 | 116 | if p.columnDataType != InvalidDataType { 117 | if c.Type != p.columnDataType { 118 | return false 119 | } 120 | } 121 | 122 | if p.prefix != "" { 123 | if !strings.HasPrefix(c.Name, p.prefix) { 124 | return false 125 | } 126 | } else if p.suffix != "" { 127 | if !strings.HasSuffix(c.Name, p.suffix) { 128 | return false 129 | } 130 | } else if p.notEqualTo != "" { 131 | if c.Name == p.notEqualTo { 132 | return false 133 | } 134 | } else { 135 | if !p.columnNameIsWildcard() { 136 | if c.Name != p.columnName { 137 | return false 138 | } 139 | } 140 | } 141 | 142 | if len(p.containsColumnNames) > 0 { 143 | foundCount := 0 144 | for _, columnName := range p.containsColumnNames { 145 | for _, v := range otherColumnNamesInsideTheTable { 146 | if columnName == v { 147 | foundCount++ 148 | break 149 | } 150 | } 151 | } 152 | 153 | if foundCount != len(p.containsColumnNames) { 154 | return false 155 | } 156 | } 157 | 158 | return true 159 | } 160 | } 161 | 162 | // parseColumnFilterExpression parses the input string and returns a slice of columnFilterExpression. 163 | func parseColumnFilterExpression(input string) ([]*columnFilterExpression, error) { 164 | var expressions []*columnFilterExpression 165 | 166 | fields := strings.FieldsFunc(input, func(r rune) bool { 167 | return r == '.' 
168 | }) 169 | 170 | if len(fields) < 2 || len(fields) > 3 { 171 | return nil, fmt.Errorf("invalid input: %s", input) 172 | } 173 | 174 | tableName := fields[0] 175 | columnLine := fields[1] 176 | 177 | columnName := columnLine 178 | dataType := InvalidDataType 179 | if len(fields) == 3 { 180 | dataType, _ = ParseDataType(fields[2]) 181 | if dataType == InvalidDataType { 182 | return nil, fmt.Errorf("invalid data type: %s", fields[2]) 183 | } 184 | } 185 | 186 | var tableShouldContainColumnNames []string 187 | 188 | containsIdx := strings.IndexByte(columnLine, '&') 189 | if containsIdx > 0 { 190 | rest := columnLine[containsIdx+1:] 191 | tableShouldContainColumnNames = strings.Split(rest, ",") 192 | columnName = columnLine[0:containsIdx] 193 | } 194 | 195 | var moreColumnNames []string 196 | multipleIdx := strings.IndexByte(columnLine, ',') 197 | containsColumns := multipleIdx > 0 && (containsIdx == -1 || multipleIdx < containsIdx) 198 | if containsColumns { // for order, maybe we can improve it even better. 199 | columnName = columnLine[0:multipleIdx] 200 | } 201 | 202 | prefix, suffix, notEqualTo := parseColumnNameFilterFuncs(columnName) 203 | expr := &columnFilterExpression{ 204 | input: input, 205 | tableName: tableName, 206 | columnName: columnName, 207 | columnDataType: dataType, 208 | prefix: prefix, 209 | suffix: suffix, 210 | notEqualTo: notEqualTo, 211 | containsColumnNames: tableShouldContainColumnNames, 212 | } 213 | 214 | expressions = append(expressions, expr) 215 | 216 | if containsColumns { 217 | rest := columnLine[multipleIdx+1:] 218 | stopIdx := strings.IndexFunc(rest, func(r rune) bool { 219 | return r == '&' // || r == rune(rest[len(rest)-1]) // & or last letter. 
220 | }) 221 | 222 | if stopIdx == -1 { 223 | stopIdx = len(rest) 224 | } 225 | 226 | moreColumnNames = strings.Split(rest[0:stopIdx], ",") 227 | 228 | // fmt.Printf("rest:stopidx = %s, more column names: %s\n", rest[0:stopIdx], strings.Join(moreColumnNames, ",)")) 229 | 230 | for _, columnName := range moreColumnNames { 231 | prefix, suffix, notEqualTo := parseColumnNameFilterFuncs(columnName) 232 | expr := &columnFilterExpression{ 233 | input: input, 234 | tableName: tableName, 235 | columnName: columnName, 236 | columnDataType: dataType, 237 | prefix: prefix, 238 | suffix: suffix, 239 | notEqualTo: notEqualTo, 240 | containsColumnNames: tableShouldContainColumnNames, 241 | } 242 | 243 | expressions = append(expressions, expr) 244 | } 245 | 246 | } 247 | 248 | return expressions, nil 249 | } 250 | 251 | func parseColumnNameFilterFuncs(columnName string) (prefix string, suffix string, notEqualTo string) { 252 | if strings.HasPrefix(columnName, "prefix(") { 253 | prefix = strings.TrimPrefix(columnName, "prefix(") 254 | prefix = strings.TrimSuffix(prefix, ")") 255 | columnName = prefix // assume the column name is the same as the prefix 256 | } else if strings.HasPrefix(columnName, "suffix(") { 257 | suffix = strings.TrimPrefix(columnName, "suffix(") 258 | suffix = strings.TrimSuffix(suffix, ")") 259 | columnName = suffix // assume the column name is the same as the suffix 260 | } else if strings.HasPrefix(columnName, "noteq(") { 261 | notEqualTo = strings.TrimPrefix(columnName, "noteq(") 262 | notEqualTo = strings.TrimSuffix(notEqualTo, ")") 263 | columnName = notEqualTo // assume the column name is the same as the not equal value 264 | } 265 | 266 | return 267 | } 268 | -------------------------------------------------------------------------------- /desc/column_filter_text_parser_test.go: -------------------------------------------------------------------------------- 1 | package desc 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func 
TestParseColumnFilterExpression(t *testing.T) { 10 | var tests = []struct { 11 | input string 12 | expected []columnFilterExpression 13 | }{ 14 | {"tablename.column.varchar", []columnFilterExpression{ 15 | { 16 | tableName: "tablename", 17 | columnName: "column", 18 | columnDataType: CharacterVarying, 19 | }, 20 | }}, 21 | {"*.column", []columnFilterExpression{ 22 | { 23 | tableName: "*", 24 | columnName: "column", 25 | }, 26 | }}, 27 | {"*.column.varchar", []columnFilterExpression{ 28 | { 29 | tableName: "*", 30 | columnName: "column", 31 | columnDataType: CharacterVarying, 32 | }, 33 | }}, 34 | {"tablename.prefix(col).varchar", []columnFilterExpression{ 35 | { 36 | tableName: "tablename", 37 | columnName: "prefix(col)", 38 | columnDataType: CharacterVarying, 39 | prefix: "col", 40 | }, 41 | }}, 42 | {"tablename.suffix(col).varchar", []columnFilterExpression{ 43 | { 44 | tableName: "tablename", 45 | columnName: "suffix(col)", 46 | columnDataType: CharacterVarying, 47 | suffix: "col", 48 | }, 49 | }}, 50 | {"tablename.noteq(col).varchar", []columnFilterExpression{ 51 | { 52 | tableName: "tablename", 53 | columnName: "noteq(col)", 54 | columnDataType: CharacterVarying, 55 | notEqualTo: "col", 56 | }, 57 | }}, 58 | {"tablename.column1&column2,column3.varchar", []columnFilterExpression{ 59 | { 60 | tableName: "tablename", 61 | columnName: "column1", 62 | columnDataType: CharacterVarying, 63 | containsColumnNames: []string{"column2", "column3"}, 64 | }, 65 | }}, 66 | {"tablename.column1&column2,column3", []columnFilterExpression{ 67 | { 68 | tableName: "tablename", 69 | columnName: "column1", 70 | containsColumnNames: []string{"column2", "column3"}, 71 | }, 72 | }}, 73 | {"*.column1,column2,column3&column4.character[]", []columnFilterExpression{ 74 | { 75 | tableName: "*", 76 | columnName: "column1", 77 | columnDataType: CharacterArray, 78 | containsColumnNames: []string{"column4"}, 79 | }, 80 | { 81 | tableName: "*", 82 | columnName: "column2", 83 | columnDataType: 
CharacterArray, 84 | containsColumnNames: []string{"column4"}, 85 | }, 86 | { 87 | tableName: "*", 88 | columnName: "column3", 89 | columnDataType: CharacterArray, 90 | containsColumnNames: []string{"column4"}, 91 | }, 92 | }}, 93 | {"*.column1,column2,column3&column4", []columnFilterExpression{ 94 | { 95 | tableName: "*", 96 | columnName: "column1", 97 | containsColumnNames: []string{"column4"}, 98 | }, 99 | { 100 | tableName: "*", 101 | columnName: "column2", 102 | containsColumnNames: []string{"column4"}, 103 | }, 104 | { 105 | tableName: "*", 106 | columnName: "column3", 107 | containsColumnNames: []string{"column4"}, 108 | }, 109 | }}, 110 | {"tablename.column1,column2,column3&column4", []columnFilterExpression{ 111 | { 112 | tableName: "tablename", 113 | columnName: "column1", 114 | containsColumnNames: []string{"column4"}, 115 | }, 116 | { 117 | tableName: "tablename", 118 | columnName: "column2", 119 | containsColumnNames: []string{"column4"}, 120 | }, 121 | { 122 | tableName: "tablename", 123 | columnName: "column3", 124 | containsColumnNames: []string{"column4"}, 125 | }, 126 | }}, 127 | } 128 | 129 | for i, tt := range tests { 130 | exprs, err := parseColumnFilterExpression(tt.input) 131 | if err != nil { 132 | t.Fatal(err) 133 | } 134 | 135 | for j, expr := range exprs { 136 | tt.expected[j].input = tt.input 137 | 138 | if !reflect.DeepEqual(*expr, tt.expected[j]) { 139 | t.Fatalf("[%d:%d] [%s] expected:\n%v\nbut got:\n%v", i, j, tt.input, tt.expected[j], *expr) 140 | } 141 | } 142 | } 143 | } 144 | 145 | func TestParsedColumnFilterExpression(t *testing.T) { 146 | exprs, err := parseColumnFilterExpression("tablename.column1,column2,column3&column4.varchar") 147 | if err != nil { 148 | t.Fatal(err) 149 | } 150 | 151 | for i, expr := range exprs { 152 | filter := expr.BuildColumnFilter([]string{"column4"}) 153 | column := &Column{ 154 | TableName: "tablename", 155 | Name: fmt.Sprintf("column%d", i+1), 156 | Type: CharacterVarying, 157 | } 158 | 159 | if 
!filter(column) { 160 | t.Fatalf("[%d] expected to pass", i) 161 | } 162 | 163 | column2 := &Column{ 164 | TableName: "other", 165 | Name: fmt.Sprintf("column%d", i+1), 166 | Type: CharacterVarying, 167 | } 168 | if filter(column2) { 169 | t.Fatalf("[%d] expected not to pass", i) 170 | } 171 | } 172 | } 173 | -------------------------------------------------------------------------------- /desc/constraint_test.go: -------------------------------------------------------------------------------- 1 | package desc 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | // TestParseForeignKeyConstraint tests the parseForeignKeyConstraint function with a variety 9 | // of PostgreSQL foreign key definitions, ensuring that all supported actions (CASCADE, RESTRICT, 10 | // NO ACTION, SET NULL, SET DEFAULT) for ON DELETE and ON UPDATE clauses—as well as the DEFERRABLE 11 | // flag—are parsed correctly. 12 | func TestParseForeignKeyConstraint(t *testing.T) { 13 | tests := []struct { 14 | name string 15 | input string 16 | expected *ForeignKeyConstraint 17 | }{ 18 | { 19 | name: "Minimal definition", 20 | input: "FOREIGN KEY (col) REFERENCES tbl (ref)", 21 | expected: &ForeignKeyConstraint{ 22 | ColumnName: "col", 23 | ReferenceTableName: "tbl", 24 | ReferenceColumnName: "ref", 25 | OnDelete: "", 26 | OnUpdate: "", 27 | Deferrable: false, 28 | }, 29 | }, 30 | { 31 | name: "ON DELETE CASCADE", 32 | input: "FOREIGN KEY (col) REFERENCES tbl (ref) ON DELETE CASCADE", 33 | expected: &ForeignKeyConstraint{ 34 | ColumnName: "col", 35 | ReferenceTableName: "tbl", 36 | ReferenceColumnName: "ref", 37 | OnDelete: "CASCADE", 38 | OnUpdate: "", 39 | Deferrable: false, 40 | }, 41 | }, 42 | { 43 | name: "ON DELETE RESTRICT", 44 | input: "FOREIGN KEY (col) REFERENCES tbl (ref) ON DELETE RESTRICT", 45 | expected: &ForeignKeyConstraint{ 46 | ColumnName: "col", 47 | ReferenceTableName: "tbl", 48 | ReferenceColumnName: "ref", 49 | OnDelete: "RESTRICT", 50 | OnUpdate: "", 51 | Deferrable: 
false, 52 | }, 53 | }, 54 | { 55 | name: "ON DELETE NO ACTION", 56 | input: "FOREIGN KEY (col) REFERENCES tbl (ref) ON DELETE NO ACTION", 57 | expected: &ForeignKeyConstraint{ 58 | ColumnName: "col", 59 | ReferenceTableName: "tbl", 60 | ReferenceColumnName: "ref", 61 | OnDelete: "NO ACTION", 62 | OnUpdate: "", 63 | Deferrable: false, 64 | }, 65 | }, 66 | { 67 | name: "ON DELETE SET NULL", 68 | input: "FOREIGN KEY (col) REFERENCES tbl (ref) ON DELETE SET NULL", 69 | expected: &ForeignKeyConstraint{ 70 | ColumnName: "col", 71 | ReferenceTableName: "tbl", 72 | ReferenceColumnName: "ref", 73 | OnDelete: "SET NULL", 74 | OnUpdate: "", 75 | Deferrable: false, 76 | }, 77 | }, 78 | { 79 | name: "ON DELETE SET DEFAULT", 80 | input: "FOREIGN KEY (col) REFERENCES tbl (ref) ON DELETE SET DEFAULT", 81 | expected: &ForeignKeyConstraint{ 82 | ColumnName: "col", 83 | ReferenceTableName: "tbl", 84 | ReferenceColumnName: "ref", 85 | OnDelete: "SET DEFAULT", 86 | OnUpdate: "", 87 | Deferrable: false, 88 | }, 89 | }, 90 | { 91 | name: "ON UPDATE CASCADE", 92 | input: "FOREIGN KEY (col) REFERENCES tbl (ref) ON UPDATE CASCADE", 93 | expected: &ForeignKeyConstraint{ 94 | ColumnName: "col", 95 | ReferenceTableName: "tbl", 96 | ReferenceColumnName: "ref", 97 | OnDelete: "", 98 | OnUpdate: "CASCADE", 99 | Deferrable: false, 100 | }, 101 | }, 102 | { 103 | name: "ON UPDATE RESTRICT", 104 | input: "FOREIGN KEY (col) REFERENCES tbl (ref) ON UPDATE RESTRICT", 105 | expected: &ForeignKeyConstraint{ 106 | ColumnName: "col", 107 | ReferenceTableName: "tbl", 108 | ReferenceColumnName: "ref", 109 | OnDelete: "", 110 | OnUpdate: "RESTRICT", 111 | Deferrable: false, 112 | }, 113 | }, 114 | { 115 | name: "ON UPDATE NO ACTION", 116 | input: "FOREIGN KEY (col) REFERENCES tbl (ref) ON UPDATE NO ACTION", 117 | expected: &ForeignKeyConstraint{ 118 | ColumnName: "col", 119 | ReferenceTableName: "tbl", 120 | ReferenceColumnName: "ref", 121 | OnDelete: "", 122 | OnUpdate: "NO ACTION", 123 | Deferrable: false, 
124 | }, 125 | }, 126 | { 127 | name: "ON UPDATE SET NULL", 128 | input: "FOREIGN KEY (col) REFERENCES tbl (ref) ON UPDATE SET NULL", 129 | expected: &ForeignKeyConstraint{ 130 | ColumnName: "col", 131 | ReferenceTableName: "tbl", 132 | ReferenceColumnName: "ref", 133 | OnDelete: "", 134 | OnUpdate: "SET NULL", 135 | Deferrable: false, 136 | }, 137 | }, 138 | { 139 | name: "ON UPDATE SET DEFAULT", 140 | input: "FOREIGN KEY (col) REFERENCES tbl (ref) ON UPDATE SET DEFAULT", 141 | expected: &ForeignKeyConstraint{ 142 | ColumnName: "col", 143 | ReferenceTableName: "tbl", 144 | ReferenceColumnName: "ref", 145 | OnDelete: "", 146 | OnUpdate: "SET DEFAULT", 147 | Deferrable: false, 148 | }, 149 | }, 150 | { 151 | name: "Combined ON DELETE and ON UPDATE", 152 | input: "FOREIGN KEY (col) REFERENCES tbl (ref) ON DELETE CASCADE ON UPDATE NO ACTION", 153 | expected: &ForeignKeyConstraint{ 154 | ColumnName: "col", 155 | ReferenceTableName: "tbl", 156 | ReferenceColumnName: "ref", 157 | OnDelete: "CASCADE", 158 | OnUpdate: "NO ACTION", 159 | Deferrable: false, 160 | }, 161 | }, 162 | { 163 | name: "Combined with DEFERRABLE", 164 | input: "FOREIGN KEY (col) REFERENCES tbl (ref) ON DELETE RESTRICT ON UPDATE SET DEFAULT DEFERRABLE", 165 | expected: &ForeignKeyConstraint{ 166 | ColumnName: "col", 167 | ReferenceTableName: "tbl", 168 | ReferenceColumnName: "ref", 169 | OnDelete: "RESTRICT", 170 | OnUpdate: "SET DEFAULT", 171 | Deferrable: true, 172 | }, 173 | }, 174 | { 175 | name: "Case Insensitive and extra spaces", 176 | input: "foreign key (Col) references TBL (Ref) on delete set null on update cascade deferrable", 177 | expected: &ForeignKeyConstraint{ 178 | ColumnName: "Col", 179 | ReferenceTableName: "TBL", 180 | ReferenceColumnName: "Ref", 181 | OnDelete: "SET NULL", 182 | OnUpdate: "CASCADE", 183 | Deferrable: true, 184 | }, 185 | }, 186 | } 187 | 188 | for _, tt := range tests { 189 | t.Run(tt.name, func(t *testing.T) { 190 | result := parseForeignKeyConstraint(tt.input) 
191 | if !reflect.DeepEqual(result, tt.expected) { 192 | t.Errorf("For input %q, expected %+v, got %+v", tt.input, tt.expected, result) 193 | } 194 | }) 195 | } 196 | } 197 | -------------------------------------------------------------------------------- /desc/create_table_query.go: -------------------------------------------------------------------------------- 1 | package desc 2 | 3 | import ( 4 | "fmt" 5 | "strconv" 6 | "strings" 7 | ) 8 | 9 | // BuildCreateTableQuery creates a table in the database according to the given table definition. 10 | func BuildCreateTableQuery(td *Table) string { 11 | // Generate the SQL query to create the table 12 | var query strings.Builder 13 | 14 | // Start with the CREATE TABLE statement and the table name 15 | query.WriteString(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (", td.Name)) 16 | 17 | columns := td.ListColumnsWithoutPresenter() 18 | // Loop over the columns and append their definitions to the query 19 | for i, col := range columns { 20 | // Add the column name and type 21 | query.WriteString(strconv.Quote(col.Name) + " " + col.Type.String()) 22 | 23 | // Add the type argument if any 24 | if col.TypeArgument != "" { 25 | query.WriteString(fmt.Sprintf("(%s)", col.TypeArgument)) 26 | } 27 | 28 | // Add the default value if any 29 | if col.Default != "" { 30 | query.WriteString(" DEFAULT " + col.Default) 31 | } 32 | // Add the NOT NULL constraint if applicable 33 | if !col.Nullable { 34 | query.WriteString(" NOT NULL") 35 | } 36 | // Add the UNIQUE constraint if applicable 37 | if col.Unique { 38 | query.WriteString(" UNIQUE") 39 | } 40 | 41 | // Add the CHECK constraint if any 42 | if col.CheckConstraint != "" { 43 | query.WriteString(fmt.Sprintf(" CHECK (%s)", col.CheckConstraint)) 44 | } 45 | 46 | // Add a comma separator if this is not the last column. 47 | if i < len(columns)-1 { 48 | query.WriteString(", ") 49 | } 50 | } 51 | 52 | // Add the primary key constraint if any. We only allow one Primary Key column. 
53 | if primaryKey, ok := td.PrimaryKey(); ok { 54 | query.WriteString(fmt.Sprintf(`, PRIMARY KEY ("%s")`, primaryKey.Name)) 55 | } 56 | 57 | // Loop over the foreign key constraints and append them to the query 58 | /* No, let's create foreign keys at the end of the all known tables creation, 59 | so registeration order does not matter. 60 | for _, fk := range td.ForeignKeys() { 61 | query.WriteString(fmt.Sprintf(", FOREIGN KEY (%s) REFERENCES %s(%s) ON DELETE %s", fk.ColumnName, fk.ReferenceTableName, fk.ReferenceColumnName, fk.OnDelete)) 62 | 63 | // Add the DEFERRABLE option if applicable 64 | if fk.Deferrable { 65 | query.WriteString(" DEFERRABLE") 66 | } 67 | } 68 | See `buildAlterTableForeignKeysQuery`. 69 | */ 70 | 71 | // Loop over the unique indexes and append them to the query as constraints, 72 | // no WHERE clause is allowed in this case. 73 | // 74 | // Read more at: https://stackoverflow.com/questions/23542794/postgres-unique-constraint-vs-index 75 | for idxName, colNames := range td.UniqueIndexes() { 76 | for i := range colNames { 77 | colNames[i] = strconv.Quote(colNames[i]) // quote column names. 
78 | } 79 | query.WriteString(fmt.Sprintf(", CONSTRAINT %s UNIQUE (%s)", idxName, strings.Join(colNames, ", "))) 80 | } 81 | 82 | // Close the CREATE TABLE statement with a semicolon 83 | query.WriteString(");") 84 | 85 | // Loop over the non-unique indexes and append them to the query as separate statements 86 | for _, idx := range td.Indexes() { 87 | // Use the CREATE INDEX statement with the index name, table name, type and column name 88 | query.WriteString(fmt.Sprintf(`CREATE INDEX IF NOT EXISTS %s ON %s USING %s ("%s");`, 89 | idx.Name, td.Name, idx.Type.String(), idx.ColumnName)) 90 | } 91 | 92 | return query.String() 93 | } 94 | -------------------------------------------------------------------------------- /desc/delete_query.go: -------------------------------------------------------------------------------- 1 | package desc 2 | 3 | import "fmt" 4 | 5 | // BuildDeleteQuery builds and returns a SQL query for deleting one or more rows from a table. 6 | func BuildDeleteQuery(td *Table, values []any) (string, []any, error) { 7 | // extract the primary key column name and the primary key values from the table definition and the values 8 | primaryKeyName, ids, err := extractPrimaryKeyValues(td, values) 9 | if err != nil { 10 | return "", nil, err // return false and the wrapped error if extracting fails 11 | } 12 | 13 | // build the SQL query using the table name, the primary key name and a placeholder for the primary key values 14 | query := fmt.Sprintf(`DELETE FROM "%s" WHERE "%s" = ANY($1);`, td.Name, primaryKeyName) 15 | return query, ids, nil 16 | } 17 | -------------------------------------------------------------------------------- /desc/desc.go: -------------------------------------------------------------------------------- 1 | package desc 2 | 3 | var ( 4 | // DefaultTag is the default struct field tag. 5 | DefaultTag = "pg" 6 | // DefaultSearchPath is the default search path for the table. 
7 | DefaultSearchPath = "public" 8 | ) 9 | -------------------------------------------------------------------------------- /desc/duplicate_query.go: -------------------------------------------------------------------------------- 1 | package desc 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | // BuildDuplicateQuery returns a query that duplicates a row by its primary key. 9 | func BuildDuplicateQuery(td *Table, idPtr any) (string, error) { 10 | primaryKey, ok := td.PrimaryKey() // get the primary key column definition from the table definition 11 | if !ok { 12 | return "", fmt.Errorf("duplicate: no primary key") 13 | } 14 | 15 | returningColumn := "" // a variable to store the name of the column to return after insertion 16 | if idPtr != nil && ok { 17 | // if idPtr is not nil, it means we want to get the primary key value of the inserted row 18 | returningColumn = primaryKey.Name // assign the column name to returningColumn 19 | } 20 | 21 | var b strings.Builder 22 | 23 | // INSERT INTO "schema"."tableName" 24 | b.WriteString(`INSERT INTO`) 25 | b.WriteByte(' ') 26 | writeTableName(&b, td.SearchPath, td.Name) 27 | b.WriteByte(' ') 28 | 29 | // (name, tag, source_id) 30 | b.WriteByte(leftParenLiteral) 31 | 32 | columns := td.listColumnsForSelectWithoutGenerated() 33 | for i, c := range columns { 34 | if i > 0 { 35 | b.WriteByte(',') 36 | } 37 | 38 | b.WriteString(c.Name) 39 | } 40 | 41 | b.WriteByte(rightParenLiteral) 42 | 43 | // SELECT (name, tag, COALESCE(source_id, id)) 44 | b.WriteByte(' ') 45 | b.WriteString(`SELECT`) 46 | b.WriteByte(' ') 47 | 48 | for i, c := range columns { 49 | if i > 0 { 50 | b.WriteByte(',') 51 | } 52 | 53 | columnName := c.Name 54 | if c.ReferenceColumnName == primaryKey.Name && c.ReferenceTableName == td.Name { 55 | // If self reference, then COALESCE(source_id, id). 56 | // This work as self reference is always a one-to-one relationship, 57 | // useful for applications that use tables for original vs daily weekly X plans. 
58 | // 59 | // NOTE: TODO: However, keep a note that COALESCE may not work on that case, need testing on actual db. 60 | columnName = fmt.Sprintf("COALESCE(%s, %s)", c.Name, primaryKey.Name) 61 | } 62 | 63 | b.WriteString(columnName) 64 | } 65 | 66 | // FROM "schema"."tableName" 67 | b.WriteString(" FROM ") 68 | writeTableName(&b, td.SearchPath, td.Name) 69 | 70 | // WHERE id = $1 71 | buildWhereSubQueryByArguments(&b, Arguments{ 72 | { 73 | Column: primaryKey, 74 | }, 75 | }) 76 | 77 | // RETURNING id 78 | // If returningColumn is not empty. 79 | writeInsertReturning(&b, returningColumn) 80 | 81 | b.WriteByte(';') 82 | 83 | query := b.String() 84 | return query, nil 85 | } 86 | -------------------------------------------------------------------------------- /desc/duplicate_query_test.go: -------------------------------------------------------------------------------- 1 | package desc 2 | 3 | import "testing" 4 | 5 | func TestBuildDuplicateQuery(t *testing.T) { 6 | td := &Table{ 7 | SearchPath: "public", 8 | Name: "test", 9 | } 10 | 11 | td.AddColumns( 12 | &Column{ 13 | Name: "id", 14 | PrimaryKey: true, 15 | Type: UUID, 16 | Default: genRandomUUIDPGCryptoFunction1, 17 | }, 18 | &Column{ 19 | Name: "created_at", 20 | Type: Timestamp, 21 | Default: "clock_timestamp()", 22 | }, 23 | &Column{ 24 | Name: "source_id", 25 | Type: UUID, 26 | ReferenceTableName: "test", 27 | ReferenceColumnName: "id", 28 | }, 29 | &Column{ 30 | Name: "name", 31 | Type: BitVarying, 32 | TypeArgument: "255", 33 | }, 34 | ) 35 | 36 | var newID string 37 | query, err := BuildDuplicateQuery(td, &newID) 38 | if err != nil { 39 | t.Fatal(err) 40 | } 41 | 42 | expected := `INSERT INTO "public"."test" (source_id,name) SELECT COALESCE(source_id, id),name FROM "public"."test" WHERE id = $1 RETURNING id;` 43 | if query != expected { 44 | t.Logf("expected duplicated query (returning id) to match: %s, but got: %s", expected, query) 45 | } 46 | 47 | queryNoReturningID, err := BuildDuplicateQuery(td, 
nil) 48 | if err != nil { 49 | t.Fatal(err) 50 | } 51 | 52 | expectedyNoReturningID := `INSERT INTO "public"."test" (source_id,name) SELECT COALESCE(source_id, id),name FROM "public"."test" WHERE id = $1;` 53 | if queryNoReturningID != expectedyNoReturningID { 54 | t.Logf("expected duplicated query (no returning id) to match: %s, but got: %s", expected, query) 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /desc/exists_query.go: -------------------------------------------------------------------------------- 1 | package desc 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "strconv" 7 | "strings" 8 | ) 9 | 10 | // BuildExistsQuery builds and returns an SQL query for checking of existing in a row in the table, 11 | // based on the given struct value. 12 | func BuildExistsQuery(td *Table, structValue reflect.Value) (string, []any, error) { 13 | args, err := extractArguments(td, structValue, nil) 14 | if err != nil { 15 | return "", nil, err // return the error if finding arguments fails 16 | } 17 | 18 | if len(args) == 0 { 19 | return "", nil, fmt.Errorf(`no arguments found for exists, maybe missing struct field tag of "%s"`, DefaultTag) // return an error if no arguments are found. 20 | } 21 | 22 | // build the SQL query using the table definition, 23 | // the arguments and the returning column 24 | query := buildExistsQuery(td, args) 25 | return query, args.Values(), nil 26 | } 27 | 28 | // buildExistsQuery builds and returns an SQL query for checking of existing in a row in the table, 29 | // based on the given arguments. 
30 | func buildExistsQuery(td *Table, args Arguments) string { 31 | // Create a new strings.Builder 32 | var b strings.Builder 33 | 34 | // Write the query prefix 35 | b.WriteString(`SELECT EXISTS(SELECT 1 FROM "` + td.Name + `"`) 36 | 37 | buildWhereSubQueryByArguments(&b, args) 38 | 39 | // Write the query (EXISTS) suffix 40 | b.WriteString(")") 41 | 42 | b.WriteByte(';') 43 | 44 | // Return the query string 45 | return b.String() 46 | } 47 | 48 | func buildWhereSubQueryByArguments(b *strings.Builder, args Arguments) { 49 | b.WriteString(` WHERE `) 50 | 51 | var paramIndex int 52 | 53 | for i, a := range args { 54 | c := a.Column 55 | 56 | if i > 0 { 57 | b.WriteString(" AND ") 58 | } 59 | 60 | paramIndex++ // starts from 1. 61 | paramIndexStr := strconv.Itoa(paramIndex) 62 | paramName := "$" + paramIndexStr 63 | 64 | b.WriteString(fmt.Sprintf("%s = %s", c.Name, paramName)) 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /desc/index_type.go: -------------------------------------------------------------------------------- 1 | package desc 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | // IndexType is an enumeration type that represents different types of indexes in a database. 9 | type IndexType uint8 10 | 11 | // These are the possible values for IndexType. 
12 | const ( 13 | InvalidIndex IndexType = iota // InvalidIndex is the zero value for IndexType and indicates an invalid or unknown index type 14 | Btree // Btree is an index type that uses a balanced tree data structure 15 | Hash // Hash is an index type that uses a hash table data structure 16 | Gist // Gist is an index type that supports generalized search trees for various data types 17 | Spgist // Spgist is an index type that supports space-partitioned generalized search trees for various data types 18 | Gin // Gin is an index type that supports inverted indexes for various data types 19 | Brin // Brin is an index type that supports block range indexes for large tables 20 | ) 21 | 22 | // indexTypeText is a map from IndexType to its string representation. 23 | var indexTypeText = map[IndexType]string{ 24 | Btree: "btree", 25 | Hash: "hash", 26 | Gist: "gist", 27 | Spgist: "spgist", 28 | Gin: "gin", 29 | Brin: "brin", 30 | } 31 | 32 | // String returns the string representation of an IndexType value. 33 | func (t IndexType) String() string { 34 | if name, ok := indexTypeText[t]; ok { 35 | return name // if the value is in the map, return the corresponding name 36 | } 37 | 38 | return fmt.Sprintf("IndexType(unexpected %d)", t) // otherwise, return a formatted string with the numeric value 39 | } 40 | 41 | func (t *IndexType) Scan(src interface{}) error { 42 | if src == nil { 43 | return nil 44 | } 45 | 46 | s, ok := src.(string) 47 | if !ok { 48 | return fmt.Errorf("index type: unknown type of: %T", src) 49 | } 50 | 51 | if s == "" { // allow empty strings to be scanned as nil. 52 | return nil 53 | } 54 | 55 | for k, v := range indexTypeText { 56 | if v == s { 57 | *t = k 58 | return nil 59 | } 60 | } 61 | 62 | return fmt.Errorf("index type: unknown value of: %s", s) 63 | } 64 | 65 | // parseIndexType takes a string and returns the corresponding IndexType value. 
66 | func parseIndexType(s string) IndexType { 67 | s = strings.ToLower(s) // convert the string to lower case for case-insensitive comparison 68 | for t, name := range indexTypeText { 69 | if s == name { 70 | return t // if the string matches a name in the map, return the corresponding value 71 | } 72 | } 73 | 74 | return InvalidIndex // otherwise, return InvalidIndex to indicate an invalid or unknown index type 75 | } 76 | -------------------------------------------------------------------------------- /desc/insert_query.go: -------------------------------------------------------------------------------- 1 | package desc 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "strconv" 7 | "strings" 8 | ) 9 | 10 | // BuildInsertQuery builds and returns an SQL query for inserting a row into the table, 11 | // based on the given struct value, arguments, and returning column. 12 | // The struct value is a reflect.Value of the struct that represents the row to be inserted. 13 | // The arguments are a slice of Argument that contains the column definitions and values for each field of the struct. 14 | // The returning column is an optional string that specifies which column to return after the insertion, 15 | // such as the primary key or any other generated value. 
16 | func BuildInsertQuery(td *Table, structValue reflect.Value, idPtr any, forceOnConflictExpr string, upsert bool) (string, []any, error) { 17 | returningColumn := "" // a variable to store the name of the column to return after insertion 18 | if idPtr != nil { 19 | // if idPtr is not nil, it means we want to get the primary key value of the inserted row 20 | columnDefinition, ok := td.PrimaryKey() // get the primary key column definition from the table definition 21 | if ok { 22 | returningColumn = columnDefinition.Name // assign the column name to returningColumn 23 | } 24 | } 25 | 26 | // find the arguments for the SQL query based on the struct value and the table definition 27 | args, err := extractArguments(td, structValue, nil) 28 | if err != nil { 29 | return "", nil, err // return the error if finding arguments fails 30 | } 31 | 32 | if len(args) == 0 { 33 | return "", nil, fmt.Errorf(`no arguments found, maybe missing struct field tag of "%s"`, DefaultTag) // return an error if no arguments are found. 
34 | } 35 | 36 | // build the SQL query using the table definition, 37 | // the arguments and the returning column 38 | query, err := buildInsertQuery(td, args, returningColumn, forceOnConflictExpr, upsert) 39 | if err != nil { 40 | return "", nil, err 41 | } 42 | 43 | return query, args.Values(), nil 44 | } 45 | 46 | func buildInsertQuery(td *Table, args Arguments, returningColumn string, forceOnConflictExpr string, upsert bool) (string, error) { 47 | var b strings.Builder 48 | 49 | // INSERT INTO "schema"."tableName" 50 | b.WriteString(`INSERT INTO`) 51 | b.WriteByte(' ') 52 | writeTableName(&b, td.SearchPath, td.Name) 53 | b.WriteByte(' ') 54 | 55 | var ( 56 | namedParametersValues = make([]string, 0, len(td.Columns)) 57 | columnNamesToInsert = make([]string, 0, len(td.Columns)) 58 | paramIndex int 59 | 60 | conflicts []string 61 | ) 62 | 63 | // (record_id,record) 64 | b.WriteByte(leftParenLiteral) 65 | 66 | onConflictExpression, hasConflict := td.OnConflict() 67 | 68 | i := 0 69 | for _, a := range args { 70 | c := a.Column 71 | 72 | if i > 0 { 73 | b.WriteByte(',') 74 | } 75 | 76 | if hasConflict { 77 | // if conflict is empty then an error of: duplicate key value violates unique constraint "$key" 78 | // will be fired, otherwise even if one is exists 79 | // then error will be ignored but we can't use the returning id. 80 | 81 | if c.Unique { 82 | conflicts = append(conflicts, c.Name) 83 | } 84 | } else if c.UniqueIndex != "" { 85 | conflicts = append(conflicts, c.Name) 86 | } 87 | 88 | paramIndex++ // starts from 1. 89 | paramIndexStr := strconv.Itoa(paramIndex) 90 | paramName := "$" + paramIndexStr 91 | 92 | if c.Password { 93 | if td.PasswordHandler.canEncrypt() { 94 | // handled at args state. 
95 | } else { 96 | paramName = buildInsertPassword(paramName) 97 | } 98 | } 99 | 100 | namedParametersValues = append(namedParametersValues, paramName) 101 | columnNamesToInsert = append(columnNamesToInsert, c.Name) 102 | 103 | b.WriteString(c.Name) 104 | i++ 105 | } 106 | 107 | if len(namedParametersValues) == 0 { 108 | return "", fmt.Errorf("no columns to insert") 109 | } 110 | 111 | // set on conflict expression by custom unqiue index or column name 112 | // even if upsert is not true. 113 | if forceOnConflictExpr != "" { 114 | uniqueIndexes := td.UniqueIndexes() 115 | 116 | selectedUniqueIndexColumns, ok := uniqueIndexes[forceOnConflictExpr] 117 | if ok { 118 | conflicts = selectedUniqueIndexColumns // override the conflicts. 119 | } else { 120 | // if not found then check for unique index OR unique by column name. 121 | for _, conflict := range conflicts { 122 | if conflict == forceOnConflictExpr { 123 | conflicts = []string{forceOnConflictExpr} // override the conflicts. 124 | } 125 | } 126 | } 127 | 128 | if len(conflicts) == 0 { // force check of conflicts. 129 | return "", fmt.Errorf("can't find unique index with name: %s", forceOnConflictExpr) 130 | } 131 | 132 | // override the on conflict expression. 133 | onConflictExpression = `DO UPDATE SET ` 134 | j := 0 135 | for _, colName := range columnNamesToInsert { 136 | excluded := false 137 | for _, conflict := range conflicts { // skip the conflict columns. 
138 | if conflict == colName { 139 | excluded = true 140 | } 141 | } 142 | 143 | if excluded { 144 | continue 145 | } 146 | if j > 0 { 147 | onConflictExpression += "," 148 | } 149 | 150 | onConflictExpression += fmt.Sprintf(`%s = EXCLUDED.%s`, colName, colName) 151 | j++ 152 | } 153 | } else if upsert && len(conflicts) > 0 && !hasConflict { 154 | // if asked for upsert and forceOnConflictExpr is empty, conflicts are set from unique_index or unique as always, 155 | // but on conflict tag was not set manually then generate a full upsert method to update all columns. 156 | 157 | // override the on conflict expression. 158 | onConflictExpression = `DO UPDATE SET ` 159 | j := 0 160 | for _, colName := range columnNamesToInsert { 161 | excluded := false 162 | for _, conflict := range conflicts { // skip the conflict columns. 163 | if conflict == colName { 164 | excluded = true 165 | } 166 | } 167 | 168 | if excluded { 169 | continue 170 | } 171 | if j > 0 { 172 | onConflictExpression += "," 173 | } 174 | 175 | onConflictExpression += fmt.Sprintf(`%s = EXCLUDED.%s`, colName, colName) 176 | j++ 177 | } 178 | } else { 179 | // If had unique tags but no custom on conflict expression then ignore them, 180 | // so the caller receives a duplication error. 
181 | conflicts = nil 182 | } 183 | 184 | b.WriteByte(rightParenLiteral) 185 | 186 | // VALUES($1,$2,$3) 187 | b.WriteByte(' ') 188 | b.WriteString(`VALUES`) 189 | 190 | b.WriteByte(leftParenLiteral) 191 | b.WriteString(strings.Join(namedParametersValues, ",")) 192 | b.WriteByte(rightParenLiteral) 193 | 194 | if len(conflicts) > 0 { 195 | // ON CONFLICT(record_id) 196 | b.WriteByte(' ') 197 | b.WriteString(`ON CONFLICT`) 198 | 199 | b.WriteByte(leftParenLiteral) 200 | b.WriteString(strings.Join(conflicts, ",")) 201 | b.WriteByte(rightParenLiteral) 202 | 203 | b.WriteByte(' ') 204 | b.WriteString(onConflictExpression) 205 | 206 | if returningColumn != "" && strings.Contains(strings.ToUpper(onConflictExpression), "DO UPDATE") { 207 | // we can still use the returning column (source: https://stackoverflow.com/a/37543015). 208 | writeInsertReturning(&b, returningColumn) 209 | } 210 | } else if returningColumn != "" { 211 | writeInsertReturning(&b, returningColumn) 212 | } 213 | 214 | b.WriteByte(';') 215 | 216 | query := b.String() 217 | return query, nil 218 | } 219 | 220 | func writeTableName(b *strings.Builder, schema, tableName string) { 221 | b.WriteString(strconv.Quote(schema)) 222 | b.WriteByte('.') 223 | b.WriteString(strconv.Quote(tableName)) 224 | } 225 | 226 | // PasswordAlg is the password algorithm the library uses to tell postgres 227 | // how to generate a password field's salt. 228 | // Alternatives: 229 | // md5 230 | // xdes 231 | // des 232 | var PasswordAlg = "bf" // max password length: 72, salt bits: 128, output length: 60, blowfish-based. 
233 | 234 | const ( 235 | singleQuoteLiteral = '\'' 236 | ) 237 | 238 | // crypt($1,gen_salt('PasswordAlg')) 239 | func buildInsertPassword(paramName string) string { 240 | var b strings.Builder 241 | 242 | // crypt($1, 243 | b.WriteString(`crypt`) 244 | b.WriteByte(leftParenLiteral) 245 | b.WriteString(paramName) 246 | 247 | b.WriteByte(',') 248 | 249 | // gen_salt('bf') 250 | b.WriteString(`gen_salt`) 251 | b.WriteByte(leftParenLiteral) 252 | b.WriteByte(singleQuoteLiteral) 253 | b.WriteString(PasswordAlg) 254 | b.WriteByte(singleQuoteLiteral) 255 | b.WriteByte(rightParenLiteral) 256 | 257 | // ) 258 | b.WriteByte(rightParenLiteral) 259 | return b.String() 260 | } 261 | 262 | func writeInsertReturning(b *strings.Builder, columnKey string) { 263 | if columnKey == "" { 264 | return 265 | } 266 | 267 | b.WriteByte(' ') 268 | b.WriteString(`RETURNING`) 269 | b.WriteByte(' ') 270 | b.WriteString(columnKey) 271 | } 272 | -------------------------------------------------------------------------------- /desc/naming.go: -------------------------------------------------------------------------------- 1 | package desc 2 | 3 | import ( 4 | "reflect" 5 | "regexp" 6 | "strings" 7 | 8 | "github.com/gertd/go-pluralize" 9 | ) 10 | 11 | var ( 12 | // ToStructName returns the struct name for the table name. 13 | // TODO: It can go to a NewTable function. 14 | ToStructName = func(tableName string) string { return PascalCase(Singular(tableName)) } 15 | // ToStructFieldName returns the struct field name for the column name. 16 | ToStructFieldName = func(columnName string) string { return PascalCase(columnName) } 17 | // ToColumnName returns the column name for the struct field. 18 | ToColumnName = func(field reflect.StructField) string { return SnakeCase(field.Name) } 19 | ) 20 | 21 | var p = pluralize.NewClient() 22 | 23 | func init() { 24 | p.AddIrregularRule("data", "data") 25 | p.AddSingularRule("*data", "data") // e.g. 
customer_health_data, we do NOT want it to become customer_health_datum. 26 | } 27 | 28 | // Singular returns the singular form of the given string. 29 | func Singular(s string) string { 30 | s = p.Singular(s) 31 | return s 32 | } 33 | 34 | // SnakeCase converts a given string to a friendly snake case, e.g. 35 | // - userId to user_id 36 | // - ID to id 37 | // - ProviderAPIKey to provider_api_key 38 | // - Option to option 39 | func SnakeCase(camel string) string { 40 | var ( 41 | b strings.Builder 42 | prevWasUpper bool 43 | ) 44 | 45 | for i, c := range camel { 46 | if isUppercase(c) { // it's upper. 47 | if b.Len() > 0 && !prevWasUpper { // it's not the first and the previous was not uppercased too (e.g "ID"). 48 | b.WriteRune('_') 49 | } else { // check for XxxAPIKey, it should be written as xxx_api_key. 50 | next := i + 1 51 | if next > 1 && len(camel)-1 > next { 52 | if !isUppercase(rune(camel[next])) { 53 | b.WriteRune('_') 54 | } 55 | } 56 | } 57 | 58 | b.WriteRune(c - 'A' + 'a') // write its lowercase version. 59 | prevWasUpper = true 60 | } else { 61 | b.WriteRune(c) // write it as it is, it's already lowercased. 62 | prevWasUpper = false 63 | } 64 | } 65 | 66 | return b.String() 67 | } 68 | 69 | // isUppercase returns true if the given rune is uppercase. 70 | func isUppercase(c rune) bool { 71 | return 'A' <= c && c <= 'Z' 72 | } 73 | 74 | // This should match id, api or url (case-insensitive) 75 | // only if they are preceded or followed by either a word boundary or a non-word character. 76 | var pascalReplacer = regexp.MustCompile(`(?i)(?:\b|[^a-z0-9])(id|api|url)(?:\b|[^a-z0-9])`) 77 | 78 | // PascalCase converts a given string to a friendly pascal case, e.g. 
79 | // - user_id to UserID 80 | // - id to ID 81 | // - provider_api_key to ProviderAPIKey 82 | // - customer_provider to CustomerProvider 83 | func PascalCase(snake string) string { 84 | var ( 85 | b strings.Builder 86 | shouldUpper bool 87 | ) 88 | 89 | snake = pascalReplacer.ReplaceAllStringFunc(snake, strings.ToUpper) 90 | 91 | for i := range snake { 92 | c := rune(snake[i]) 93 | 94 | if i >= len(snake)-1 { // it's the last character. 95 | b.WriteRune(c) 96 | break 97 | } 98 | 99 | if c == '_' { // it's a separator. 100 | shouldUpper = true // the next character should be uppercased. 101 | } else if isLowercase(c) { // it's lower. 102 | if b.Len() == 0 || shouldUpper { // it's the first character or it should be uppercased. 103 | b.WriteRune(c - 'a' + 'A') // write its uppercase version. 104 | shouldUpper = false 105 | } else { 106 | b.WriteRune(c) // write it as it is, it's already lowercased. 107 | } 108 | } else { 109 | b.WriteRune(c) // write it as it is, it's already uppercased. 110 | shouldUpper = false 111 | } 112 | } 113 | 114 | return b.String() 115 | } 116 | 117 | // isLowercase returns true if the given rune is lowercase. 
118 | func isLowercase(c rune) bool { 119 | return 'a' <= c && c <= 'z' 120 | } 121 | -------------------------------------------------------------------------------- /desc/naming_test.go: -------------------------------------------------------------------------------- 1 | package desc 2 | 3 | import "testing" 4 | 5 | // TestSnakeCase tests the SnakeCase function with various inputs and outputs 6 | func TestSnakeCase(t *testing.T) { 7 | // Define a table of test cases 8 | testCases := []struct { 9 | input string // input string 10 | output string // expected output string 11 | }{ 12 | {"userId", "user_id"}, 13 | {"userID", "user_id"}, 14 | {"id", "id"}, 15 | {"ID", "id"}, 16 | {"ProviderAPIKey", "provider_api_key"}, 17 | {"Option", "option"}, 18 | {"CustomerHealthData", "customer_health_data"}, 19 | } 20 | 21 | // Loop over the test cases 22 | for _, tc := range testCases { 23 | // Call the SnakeCase function with the input 24 | result := SnakeCase(tc.input) 25 | // Compare the result with the expected output 26 | if result != tc.output { 27 | // Report an error if they don't match 28 | t.Errorf("SnakeCase(%q) = %q, want %q", tc.input, result, tc.output) 29 | } 30 | } 31 | } 32 | 33 | // TestPascalCase tests the PascalCase function with various inputs and outputs 34 | func TestPascalCase(t *testing.T) { 35 | // Define a table of test cases 36 | testCases := []struct { 37 | input string // input string 38 | output string // expected output string 39 | }{ 40 | {"user_id", "UserID"}, 41 | {"id", "ID"}, 42 | {"provider_api_key", "ProviderAPIKey"}, 43 | {"customer_provider", "CustomerProvider"}, 44 | {"url", "URL"}, 45 | {"api", "API"}, 46 | } 47 | 48 | // Loop over the test cases 49 | for _, tc := range testCases { 50 | // Call the PascalCase function with the input 51 | result := PascalCase(tc.input) 52 | // Compare the result with the expected output 53 | if result != tc.output { 54 | // Report an error if they don't match 55 | t.Errorf("PascalCase(%q) = %q, want 
%q", tc.input, result, tc.output) 56 | } 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /desc/password_handler.go: -------------------------------------------------------------------------------- 1 | package desc 2 | 3 | // PasswordHandler is a type that represents a password handler for the database. 4 | type PasswordHandler struct { 5 | // Encrypt takes a table name and a plain password as strings and returns an encrypted password as a string. 6 | Encrypt func(tableName, plainPassword string) (encryptedPassword string, err error) 7 | // Decrypt takes a table name and an encrypted password as strings and returns a plain password as a string. 8 | Decrypt func(tableName, encryptedPassword string) (plainPassword string, err error) 9 | } 10 | 11 | func (h *PasswordHandler) canEncrypt() bool { 12 | if h == nil { 13 | return false 14 | } 15 | 16 | return h.Encrypt != nil 17 | } 18 | 19 | func (h *PasswordHandler) canDecrypt() bool { 20 | if h == nil { 21 | return false 22 | } 23 | 24 | return h.Decrypt != nil 25 | } 26 | -------------------------------------------------------------------------------- /desc/reflect.go: -------------------------------------------------------------------------------- 1 | package desc 2 | 3 | import ( 4 | "database/sql" 5 | "reflect" 6 | "strings" 7 | ) 8 | 9 | var scannerInterface = reflect.TypeOf((*sql.Scanner)(nil)).Elem() 10 | 11 | func implementsScanner(typ reflect.Type) bool { 12 | return typ.Implements(scannerInterface) || reflect.PointerTo(typ).Implements(scannerInterface) 13 | } 14 | 15 | // reflect. 16 | 17 | // IndirectType returns the value of a pointer-type "typ". 18 | // If "typ" is a pointer, array, chan, map or slice it returns its Elem, 19 | // otherwise returns the "typ" as it is. 
20 | func IndirectType(typ reflect.Type) reflect.Type { 21 | switch typ.Kind() { 22 | case reflect.Ptr, reflect.Array, reflect.Chan, reflect.Map, reflect.Slice: 23 | return typ.Elem() 24 | } 25 | return typ 26 | } 27 | 28 | // IndirectValue returns the element type (e.g. if pointer of *User it will return the User type). 29 | func IndirectValue(v any) reflect.Value { 30 | return reflect.Indirect(reflect.ValueOf(v)) 31 | } 32 | 33 | // lookupFields takes a reflect.Type that represents a struct and a parent index slice 34 | // and returns a slice of reflect.StructField that represents the exported fields of the struct 35 | // that have a non-empty and non-dash value for the ‘pg’ tag. 36 | func lookupFields(typ reflect.Type, parentIndex []int) (fields []reflect.StructField) { 37 | // loop over all the exported fields of the struct (flattening any nested structs) 38 | for _, field := range lookupStructFields(typ, parentIndex) { 39 | // get the value of the tag with the default name and check if it is empty or dash 40 | if v := field.Tag.Get(DefaultTag); v == "" || v == "-" { 41 | // Skip fields that don’t contain the ‘pg’ tag or has ‘-’. 42 | // We do it here so we can have a calculated number of fields for columns. 43 | continue // skip this field 44 | } 45 | 46 | fields = append(fields, field) // append the field to the result slice 47 | } 48 | 49 | return // return the result slice 50 | } 51 | 52 | // isSpecialJSONStructure checks if a struct field has a tag that indicates a JSON or JSONB type. 
53 | func isSpecialJSONStructure(field reflect.StructField) bool { 54 | tag := strings.ToLower(field.Tag.Get(DefaultTag)) // get the lower case value of the tag with the default name 55 | return strings.Contains(tag, "type=json") // return true if the tag contains "type=json" (this includes "type=jsonb" too) 56 | } 57 | 58 | // lookupStructFields takes a reflect.Type that represents a struct and a parent index slice 59 | // and returns a slice of reflect.StructField that represents the exported fields of the struct. 60 | func lookupStructFields(typ reflect.Type, parentIndex []int) (fields []reflect.StructField) { 61 | for i := 0; i < typ.NumField(); i++ { // loop over all the fields of the struct 62 | field := typ.Field(i) // get the i-th field 63 | if field.PkgPath != "" { // skip unexported fields (they have a non-empty package path) 64 | continue 65 | } 66 | 67 | fieldType := IndirectType(field.Type) // get the underlying type of the field 68 | 69 | if fieldType.Kind() == reflect.Struct { // if the field is a struct itself and it's not time, flatten it 70 | if fieldType != timeType && !isSpecialJSONStructure(field) /* do not flatten the struct's fields when jsonb struct field, let it behave as it is. */ { 71 | // on struct field: include all fields with an exception if the struct field itself is tagged for skipping explicitly "-" 72 | if field.Tag.Get(DefaultTag) == "-" { 73 | continue 74 | } 75 | 76 | if c, _ := convertStructFieldToColumnDefinion("", field); c != nil { 77 | if c.Presenter { 78 | continue 79 | } 80 | } 81 | 82 | // recursively look up the fields of the nested struct and append the current index to the parent index 83 | structFields := lookupFields(fieldType, append(parentIndex, i)) 84 | 85 | // as an exception, when this struct field is marked as a postgres column 86 | // but this field's struct type does not contain any `pg` tags 87 | // then treat that struct field itself as a postgres column, 88 | // e.g. a custom time.Time implementation. 
89 | if len(structFields) > 0 { // if there are any nested fields found 90 | fields = append(fields, structFields...) // append them to the result slice 91 | continue 92 | } 93 | } 94 | } 95 | 96 | index := []int{i} // create a slice with the current index 97 | if len(parentIndex) > 0 { // if there is a parent index 98 | index = append(parentIndex, i) // append the current index to it 99 | } 100 | 101 | tmp := make([]int, len(index)) // make a copy of the index slice 102 | copy(tmp, index) 103 | field.Index = tmp // assign it to the field's index 104 | 105 | fields = append(fields, field) // append the field to the result slice 106 | } 107 | 108 | return // return the result slice 109 | } 110 | -------------------------------------------------------------------------------- /desc/struct_table_test.go: -------------------------------------------------------------------------------- 1 | package desc 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | ) 7 | 8 | func TestParseReferenceTagValue(t *testing.T) { 9 | // Define some test cases with different input values and expected outputs. 10 | testCases := []struct { 11 | input string 12 | refTableName string 13 | refColumnName string 14 | onDeleteAction string 15 | isDeferrable bool 16 | err error 17 | }{ 18 | // Valid cases. 19 | {"blogs(id no action deferrable)", "blogs", "id", "NO ACTION", true, nil}, 20 | {"blogs(id no action)", "blogs", "id", "NO ACTION", false, nil}, 21 | {"blogs(id)", "blogs", "id", "CASCADE", false, nil}, 22 | {"blogs(id cascade)", "blogs", "id", "CASCADE", false, nil}, 23 | {"blogs(id set null deferrable)", "blogs", "id", "SET NULL", true, nil}, 24 | {"blogs(id set default)", "blogs", "id", "SET DEFAULT", false, nil}, 25 | {"users(id no action deferrable)", "users", "id", "NO ACTION", true, nil}, 26 | // Invalid cases. 
27 | {"blogs(id foo)", "", "", "", false, errInvalidReferenceTag}, 28 | {"blogs(id restrict deferrable)", "", "", "", false, errInvalidReferenceTag}, 29 | } 30 | 31 | for _, tc := range testCases { 32 | t.Run(tc.input, func(t *testing.T) { 33 | // Call the function with the input value and get the output values. 34 | refTableName, refColumnName, onDeleteAction, isDeferrable, err := parseReferenceTagValue(tc.input) 35 | 36 | // Check if the output values match the expected values. 37 | if refTableName != tc.refTableName { 38 | t.Errorf("%s: expected refTableName to be %s, got %s", tc.input, tc.refTableName, refTableName) 39 | } 40 | 41 | if refColumnName != tc.refColumnName { 42 | t.Errorf("%s: expected refColumnName to be %s, got %s", tc.input, tc.refColumnName, refColumnName) 43 | } 44 | 45 | if onDeleteAction != tc.onDeleteAction { 46 | t.Errorf("%s: expected onDeleteAction to be %s, got %s", tc.input, tc.onDeleteAction, onDeleteAction) 47 | } 48 | 49 | if isDeferrable != tc.isDeferrable { 50 | t.Errorf("%s: expected isDeferrable to be %t, got %t", tc.input, tc.isDeferrable, isDeferrable) 51 | } 52 | 53 | if err != tc.err { 54 | if !errors.Is(err, tc.err) { 55 | t.Errorf("%s: expected err to be %v, got %v", tc.input, tc.err, err) 56 | } 57 | } 58 | }) 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /desc/table_test.go: -------------------------------------------------------------------------------- 1 | package desc 2 | 3 | import "testing" 4 | 5 | func TestColumnFieldTagString(t *testing.T) { 6 | c := &Column{ 7 | Name: "test", 8 | Type: CharacterVarying, 9 | TypeArgument: "255", 10 | PrimaryKey: true, 11 | Identity: true, 12 | // Required: true, 13 | Default: "test_default", 14 | Unique: true, 15 | Conflict: "test_conflict", 16 | Username: true, 17 | Password: true, 18 | ReferenceTableName: "test_reference_table_name", 19 | ReferenceColumnName: "test_reference_column_name", 20 | ReferenceOnDelete: 
"test_reference_on_delete", 21 | DeferrableReference: true, 22 | Index: Hash, 23 | UniqueIndex: "test_unique_index", 24 | CheckConstraint: "test_check_constraint", 25 | AutoGenerated: true, 26 | Presenter: true, 27 | Unscannable: true, 28 | } 29 | 30 | expected := `pg:"name=test,type=varchar(255),primary,identity,default=test_default,unique,conflict=test_conflict,username,password,ref=test_reference_table_name(test_reference_column_name test_reference_on_delete deferrable),index=hash,unique_index=test_unique_index,check=test_check_constraint,auto,presenter,unscannable"` 31 | if got := c.FieldTagString(true); expected != got { 32 | t.Fatalf("expected field tag:\n%s\nbut got:\n%s", expected, got) 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /desc/trigger.go: -------------------------------------------------------------------------------- 1 | package desc 2 | 3 | // Trigger represents a database trigger. 4 | type Trigger struct { 5 | Catalog string // Catalog name of the trigger 6 | SearchPath string // Search path of the trigger 7 | Name string // Name of the trigger 8 | Manipulation string // Type of manipulation (INSERT, UPDATE, DELETE) 9 | TableName string // Name of the table the trigger is on 10 | ActionStatement string // SQL statement executed by the trigger 11 | ActionOrientation string // Orientation of the trigger (ROW or STATEMENT) 12 | ActionTiming string // Timing of the trigger (BEFORE or AFTER) 13 | } 14 | -------------------------------------------------------------------------------- /desc/unique_index.go: -------------------------------------------------------------------------------- 1 | package desc 2 | 3 | // UniqueIndex is a struct that represents a unique index. 4 | // See DB.ListUniqueIndexes method for more. 5 | type UniqueIndex struct { 6 | TableName string // table name 7 | IndexName string // index name 8 | Columns []string // column names. 
9 | } 10 | -------------------------------------------------------------------------------- /desc/update_query.go: -------------------------------------------------------------------------------- 1 | package desc 2 | 3 | import ( 4 | "fmt" 5 | "strconv" 6 | "strings" 7 | ) 8 | 9 | // BuildUpdateQuery builds and returns an SQL query for updating a row in the table, 10 | // using the given struct value and the primary key. 11 | func BuildUpdateQuery(value any, columnsToUpdate []string, reportNotFound bool, primaryKey *Column) (string, []any, error) { 12 | args, err := extractUpdateArguments(value, columnsToUpdate, primaryKey) 13 | if err != nil { 14 | return "", nil, err 15 | } 16 | 17 | shouldUpdateID := false 18 | for _, col := range columnsToUpdate { 19 | if col == primaryKey.Name { 20 | shouldUpdateID = true 21 | break 22 | } 23 | } 24 | 25 | if len(args) == 1 { // the last one is the id. 26 | return "", nil, fmt.Errorf("no arguments found for update, maybe missing struct field tag of \"%s\"", DefaultTag) 27 | } 28 | 29 | // build the SQL query using the table definition and its primary key. 30 | query := buildUpdateQuery(primaryKey.Table, args, primaryKey.Name, shouldUpdateID, reportNotFound) 31 | return query, args.Values(), nil 32 | } 33 | 34 | // extractUpdateArguments extracts the arguments from the given struct value and returns them. 35 | func extractUpdateArguments(value any, columnsToUpdate []string, primaryKey *Column) (Arguments, error) { 36 | structValue := IndirectValue(value) 37 | 38 | id, err := ExtractPrimaryKeyValue(primaryKey, structValue) 39 | if err != nil { 40 | return nil, err 41 | } 42 | 43 | columnsToUpdateLength := len(columnsToUpdate) 44 | 45 | args, err := extractArguments(primaryKey.Table, structValue, func(fieldName string) bool { 46 | if columnsToUpdateLength == 0 { 47 | // full update. 
48 | return true 49 | } 50 | 51 | for _, onlyColumnName := range columnsToUpdate { 52 | if onlyColumnName == fieldName { 53 | return true 54 | } 55 | } 56 | 57 | return false 58 | }) 59 | if err != nil { 60 | return nil, err // return the error if finding arguments fails 61 | } 62 | 63 | if columnsToUpdateLength == 0 { 64 | // full update, even zero values (e.g. integer 0) all except ID and any created_at, updated_at. 65 | args = filterArgumentsForFullUpdate(args) 66 | } 67 | 68 | if len(args) == 0 { 69 | // nothing to update, raise an error 70 | return nil, fmt.Errorf(`no arguments found for update, maybe missing struct field tag of "%s"`, DefaultTag) 71 | } 72 | 73 | // Add (or move) the primary key value as the last argument, 74 | // move is a requiremend here in order to remove a duplicated primary key name in the query; 75 | // this can happen if the specified column names to update do not match the database schema. 76 | args.ShiftEnd(Argument{ 77 | Column: primaryKey, 78 | Value: id, 79 | }) 80 | 81 | return args, nil 82 | } 83 | 84 | func buildUpdateQuery(td *Table, args Arguments, primaryKeyName string, shouldUpdateID bool, reportNotFound bool) string { 85 | var b strings.Builder 86 | 87 | b.WriteString(`UPDATE "` + td.Name + `" SET `) 88 | 89 | var paramIndex int 90 | 91 | for i, a := range args { 92 | c := a.Column 93 | 94 | if !shouldUpdateID && c.Name == primaryKeyName { 95 | // Do not update ID if not specifically asked to. 96 | // Fixes #1. 97 | continue 98 | } 99 | 100 | if i > 0 { 101 | b.WriteByte(',') 102 | } 103 | 104 | paramIndex++ // starts from 1. 105 | paramIndexStr := strconv.Itoa(paramIndex) 106 | paramName := "$" + paramIndexStr 107 | 108 | if c.Password { 109 | if td.PasswordHandler.canEncrypt() { 110 | // handled at args state. 
111 | } else { 112 | paramName = buildInsertPassword(paramName) 113 | } 114 | } 115 | 116 | b.WriteString(fmt.Sprintf(`"%s" = %s`, c.Name, paramName)) 117 | } 118 | 119 | primaryKeyWhereIndex := paramIndex + 1 120 | if shouldUpdateID { // if updating ID, then the last argument is the ID. 121 | primaryKeyWhereIndex = paramIndex 122 | } 123 | b.WriteString(` WHERE "` + primaryKeyName + `" = $` + strconv.Itoa(primaryKeyWhereIndex)) 124 | 125 | if reportNotFound { 126 | b.WriteString(` RETURNING "` + primaryKeyName + `"`) 127 | } 128 | 129 | b.WriteByte(';') 130 | 131 | return b.String() 132 | } 133 | -------------------------------------------------------------------------------- /desc/zeroer.go: -------------------------------------------------------------------------------- 1 | package desc 2 | 3 | import ( 4 | "encoding/json" 5 | "math/big" 6 | "net" 7 | "reflect" 8 | "time" 9 | ) 10 | 11 | // Zeroer is an interface that defines a method to check if a value is zero. 12 | // 13 | // Zeroer can be implemented by custom types 14 | // to report whether its current value is zero. 15 | // Standard Time also implements that. 16 | type Zeroer interface { 17 | IsZero() bool // IsZero returns true if the value is zero 18 | } 19 | 20 | // isZero takes an interface value and returns true if it is nil or zero. 
21 | func isZero(v any) bool { 22 | if v == nil { 23 | // if the value is nil, return true 24 | return true 25 | } 26 | 27 | switch t := v.(type) { // switch on the type of the value 28 | case *time.Time: 29 | return t == nil || t.IsZero() 30 | case *string: 31 | return t == nil || *t == "" 32 | case *int: 33 | return t == nil || *t == 0 34 | case *int8: 35 | return t == nil || *t == 0 36 | case *int16: 37 | return t == nil || *t == 0 38 | case *int32: 39 | return t == nil || *t == 0 40 | case *int64: 41 | return t == nil || *t == 0 42 | case *uint: 43 | return t == nil || *t == 0 44 | case *uint8: 45 | return t == nil || *t == 0 46 | case *uint16: 47 | return t == nil || *t == 0 48 | case *uint32: 49 | return t == nil || *t == 0 50 | case *uint64: 51 | return t == nil || *t == 0 52 | case *float32: 53 | return t == nil || *t == 0 54 | case *float64: 55 | return t == nil || *t == 0 56 | case *bool: 57 | return t == nil || !*t 58 | case *[]string: 59 | return t == nil || len(*t) == 0 60 | case *[]int: 61 | return t == nil || len(*t) == 0 62 | case *[]int8: 63 | return t == nil || len(*t) == 0 64 | case *[]int16: 65 | return t == nil || len(*t) == 0 66 | case *[]int32: 67 | return t == nil || len(*t) == 0 68 | case *[]int64: 69 | return t == nil || len(*t) == 0 70 | case *[]uint: 71 | return t == nil || len(*t) == 0 72 | case *[]uint8: 73 | return t == nil || len(*t) == 0 74 | case *[]uint16: 75 | return t == nil || len(*t) == 0 76 | case *[]uint32: 77 | return t == nil || len(*t) == 0 78 | case *[]uint64: 79 | return t == nil || len(*t) == 0 80 | case *[]float32: 81 | return t == nil || len(*t) == 0 82 | case *[]float64: 83 | return t == nil || len(*t) == 0 84 | case *[]bool: 85 | return t == nil || len(*t) == 0 86 | case *[]any: 87 | return t == nil || len(*t) == 0 88 | case *map[string]string: 89 | return t == nil || len(*t) == 0 90 | case *map[string]int: 91 | return t == nil || len(*t) == 0 92 | case *map[string]any: 93 | return t == nil || len(*t) == 0 94 | 
case *map[int]int: 95 | return t == nil || len(*t) == 0 96 | case *map[int]any: 97 | return t == nil || len(*t) == 0 98 | case *map[any]any: 99 | return t == nil || len(*t) == 0 100 | case *map[any]int: 101 | return t == nil || len(*t) == 0 102 | case *map[any]string: 103 | return t == nil || len(*t) == 0 104 | case *map[any]float64: 105 | return t == nil || len(*t) == 0 106 | case *map[any]bool: 107 | return t == nil || len(*t) == 0 108 | case *map[any][]any: 109 | return t == nil || len(*t) == 0 110 | case *map[any][]int: 111 | return t == nil || len(*t) == 0 112 | case *map[any][]string: 113 | return t == nil || len(*t) == 0 114 | case *map[any]map[any]any: 115 | return t == nil || len(*t) == 0 116 | case *map[any]map[any]int: 117 | return t == nil || len(*t) == 0 118 | case *map[any]map[any]string: 119 | return t == nil || len(*t) == 0 120 | case *map[any]map[any]float64: 121 | return t == nil || len(*t) == 0 122 | case *map[any]map[any]bool: 123 | return t == nil || len(*t) == 0 124 | case *map[any]map[any][]any: 125 | return t == nil || len(*t) == 0 126 | case *map[any]map[any][]int: 127 | return t == nil || len(*t) == 0 128 | case reflect.Value: 129 | if t.Kind() == reflect.Ptr { 130 | return t.IsNil() 131 | } 132 | 133 | return t.IsZero() 134 | case Zeroer: // if the value implements the Zeroer interface 135 | return t == nil || t.IsZero() // call the IsZero method on the value 136 | case string: // if the value is a string 137 | return t == "" // return true if the string is empty 138 | case int: // if the value is an int 139 | return t == 0 // return true if the int is zero 140 | case int8: // if the value is an int8 141 | return t == 0 // return true if the int8 is zero 142 | case int16: // if the value is an int16 143 | return t == 0 // return true if the int16 is zero 144 | case int32: // if the value is an int32 145 | return t == 0 // return true if the int32 is zero 146 | case int64: // if the value is an int64 147 | return t == 0 // return true if 
the int64 is zero 148 | case uint: // if the value is a uint 149 | return t == 0 // return true if the uint is zero 150 | case uint8: // if the value is a uint8 151 | return t == 0 // return true if the uint8 is zero 152 | case uint16: // if the value is a uint16 153 | return t == 0 // return true if the uint16 is zero 154 | case uint32: // if the value is a uint32 155 | return t == 0 // return true if the uint32 is zero 156 | case uint64: // if the value is a uint64 157 | return t == 0 // return true if the uint64 is zero 158 | case float32: // if the value is a float32 159 | return t == 0 // return true if the float32 is zero 160 | case float64: // if the value is a float64 161 | return t == 0 // return true if the float64 is zero 162 | case bool: // if the value is a bool 163 | return !t // return true if the bool is false (the opposite of its value) 164 | case []int: // if the value is a slice of ints 165 | return len(t) == 0 // return true if the slice has zero length 166 | case []string: // if the value is a slice of strings 167 | return len(t) == 0 // return true if the slice has zero length 168 | case [][]int: // if the value is a slice of slices of ints 169 | return len(t) == 0 // return true if the slice has zero length 170 | case [][]string: // if the value is a slice of slices of strings 171 | return len(t) == 0 // return true if the slice has zero length 172 | case json.Number: // if the value is a json.Number (a string that represents a number in JSON) 173 | return t.String() == "" // return true if the string representation of the number is empty 174 | case net.IP: // if the value is a net.IP (a slice of bytes that represents an IP address) 175 | return len(t) == 0 // return true if the slice has zero length 176 | case map[string]any: 177 | return len(t) == 0 178 | case map[int]any: 179 | return len(t) == 0 180 | case map[string]string: 181 | return len(t) == 0 182 | case map[string]int: 183 | return len(t) == 0 184 | case map[int]int: 185 | return 
len(t) == 0 186 | case struct{}: 187 | return true 188 | case *big.Int: 189 | return t == nil 190 | case big.Int: 191 | return isZero(t.Int64()) 192 | case *big.Rat: 193 | return t == nil 194 | case big.Rat: 195 | return isZero(t.Num()) 196 | case *big.Float: 197 | return t == nil 198 | default: // for any other type of value 199 | return false // return false (assume it's not zero) 200 | } 201 | } 202 | -------------------------------------------------------------------------------- /desc/zeroer_test.go: -------------------------------------------------------------------------------- 1 | package desc 2 | 3 | import ( 4 | "encoding/json" 5 | "math/big" 6 | "net" 7 | "reflect" 8 | "testing" 9 | "time" 10 | ) 11 | 12 | // TestIsZero checks isZero (or, for values implementing Zeroer, their IsZero method) against a table of inputs and expected results. 13 | func TestIsZero(t *testing.T) { 14 | now := time.Now() 15 | timePtr := &now 16 | var nilTimePtr *time.Time 17 | 18 | // Define a table of test cases 19 | testCases := []struct { 20 | input any // input value 21 | output bool // expected output value 22 | }{ 23 | {nil, true}, // nil value should be zero 24 | {"", true}, // empty string should be zero 25 | {"hello", false}, // non-empty string should not be zero 26 | {0, true}, // zero int should be zero 27 | {1, false}, // non-zero int should not be zero 28 | {0.0, true}, // zero float should be zero 29 | {1.0, false}, // non-zero float should not be zero 30 | {false, true}, // false bool should be zero 31 | {true, false}, // true bool should not be zero 32 | {[]int{}, true}, // empty slice of ints should be zero 33 | {[]int{1, 2, 3}, false}, // non-empty slice of ints should not be zero 34 | {[]string{}, true}, // empty slice of strings should be zero 35 | {[]string{"a", "b", "c"}, false}, // non-empty slice of strings should not be zero 36 | {map[string]int{}, true}, // empty map of strings to ints should be zero 37 | {map[string]int{"a": 1, "b": 2}, false}, // non-empty map of strings to ints should not be zero 38 |
{struct{}{}, true}, // empty struct should be zero 39 | {struct{ x int }{1}, false}, // non-empty struct should not be zero 40 | {big.NewInt(0), false}, // big int pointer with value 0 should not be zero 41 | {big.NewInt(1), false}, // big int pointer with value 1 should not be zero 42 | {big.NewRat(0, 1), false}, // big rational pointer with value 0/1 should not be zero (only a nil pointer is zero) 43 | {big.NewRat(1, 2), false}, // big rational pointer with value 1/2 should not be zero 44 | {big.NewFloat(0.0), false}, // big float pointer with value 0.0 should not be zero 45 | {big.NewFloat(1.0), false}, // big float pointer with value 1.0 should not be zero 46 | {json.Number(""), true}, // empty json.Number should be zero 47 | {json.Number("123"), false}, // non-empty json.Number should not be zero 48 | {net.IP{}, true}, // empty net.IP should be zero 49 | {net.IPv4(127, 0, 0, 1), false}, // non-empty net.IP should not be zero 50 | {time.Time{}, true}, // empty time.Time (zero time) should be zero 51 | {time.Now(), false}, // non-empty time.Time (current time) should not be zero 52 | {timePtr, false}, // non-nil *time.Time (current time) should not be zero 53 | {nilTimePtr, true}, // nil *time.Time should be zero 54 | } 55 | 56 | for i, tc := range testCases { 57 | isNil := false 58 | 59 | if val := reflect.ValueOf(tc.input); val.Kind() == reflect.Pointer { 60 | isNil = val.IsNil() 61 | } 62 | 63 | if tc.input == nil || isNil { 64 | t.Run("nil", func(t *testing.T) { 65 | result := isZero(tc.input) // call the isZero function with the input 66 | if result != tc.output { // compare the result with the expected output 67 | t.Errorf("[%d] isZero(%v) = %v, want %v", i, tc.input, result, tc.output) // report an error if they don't match 68 | } 69 | }) 70 | continue 71 | } 72 | 73 | if zr, ok := tc.input.(Zeroer); ok { // if the input implements the Zeroer interface (this includes time.Time as well) 74 | result := zr.IsZero() // call the IsZero method on the input value 75 | if result != tc.output { //
compare the result with the expected output 76 | t.Errorf("[%d] %T.IsZero() = %v, want %v", i, tc.input, result, tc.output) // report an error if they don't match 77 | } 78 | 79 | continue 80 | } 81 | 82 | if tm, ok := tc.input.(time.Time); ok { // NOTE(review): time.Time has an IsZero method and, per the comment on the Zeroer branch above, is handled there, so this branch looks unreachable — confirm against Zeroer's definition 83 | result := tm.IsZero() || tm.UnixNano() == 0 // call the IsZero method on the time value or check if its UnixNano representation is zero (this covers both the standard library definition and the custom definition of zero for time.Time) 84 | if result != tc.output { // compare the result with the expected output 85 | t.Errorf("[%d] %T.IsZero() = %v, want %v", i, tc.input, result, tc.output) // report an error if they don't match 86 | } 87 | 88 | continue 89 | } 90 | 91 | if ip, ok := tc.input.(net.IP); ok { // net.IP has no IsZero method, so exercise isZero's own net.IP case (only an empty IP is zero) instead of re-deriving the answer locally 92 | result := isZero(ip) // call the isZero function with the IP value 93 | if result != tc.output { // compare the result with the expected output 94 | t.Errorf("[%d] isZero(%v) = %v, want %v", i, tc.input, result, tc.output) // report an error if they don't match 95 | } 96 | 97 | continue 98 | } 99 | 100 | result := isZero(tc.input) // call the isZero function with the input 101 | if result != tc.output { // compare the result with the expected output 102 | t.Errorf("[%d] isZero(%v) = %v, want %v", i, tc.input, result, tc.output) // report an error if they don't
match 103 | } 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /errors.go: -------------------------------------------------------------------------------- 1 | package pg 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/jackc/pgx/v5" 8 | ) 9 | 10 | var ( 11 | // ErrNoRows is returned from a query when no rows came back (it is pgx.ErrNoRows). 12 | // Usually it's ignored and an empty json array is sent to the client instead. 13 | ErrNoRows = pgx.ErrNoRows 14 | ) 15 | 16 | // IsErrDuplicate reports whether the error returned from the `Insert` method 17 | // was caused by a unique constraint violation; it matches the error message text, not a typed driver error. 18 | // If so, it also returns the violated constraint's name (the first double-quoted name in the message). 19 | func IsErrDuplicate(err error) (string, bool) { 20 | if err != nil { 21 | errText := err.Error() 22 | if strings.Contains(errText, "ERROR: duplicate key value violates unique constraint") { 23 | if startIdx := strings.IndexByte(errText, '"'); startIdx > 0 && startIdx+1 < len(errText) { 24 | errText = errText[startIdx+1:] 25 | if endIdx := strings.IndexByte(errText, '"'); endIdx > 0 && endIdx < len(errText) { 26 | return errText[:endIdx], true 27 | } 28 | } 29 | } 30 | } 31 | 32 | return "", false 33 | } 34 | 35 | // IsErrForeignKey reports whether an insert or update command failed due 36 | // to an invalid foreign key (a foreign key is missing or its source was not found). NOTE(review): the returned string is the first double-quoted name in the message, which for the example below is the table, not the constraint. 37 | // E.g.
ERROR: insert or update on table "food_user_friendly_units" violates foreign key constraint "fk_food" (SQLSTATE 23503) 38 | func IsErrForeignKey(err error) (string, bool) { 39 | if err != nil { 40 | errText := err.Error() 41 | if strings.Contains(errText, "violates foreign key constraint") { 42 | if startIdx := strings.IndexByte(errText, '"'); startIdx > 0 && startIdx+1 < len(errText) { 43 | errText = errText[startIdx+1:] 44 | if endIdx := strings.IndexByte(errText, '"'); endIdx > 0 && endIdx < len(errText) { 45 | return errText[:endIdx], true 46 | } 47 | } 48 | } 49 | } 50 | return "", false 51 | } 52 | 53 | // IsErrInputSyntax reports whether the error returned from the `Insert` method 54 | // was caused by invalid input syntax for a PostgreSQL column type or by a tsquery syntax error. It returns the first double-quoted text in the message (the offending value) when present, otherwise "invalid input syntax". 55 | func IsErrInputSyntax(err error) (string, bool) { 56 | if err != nil { 57 | errText := err.Error() 58 | if strings.HasPrefix(errText, "ERROR: ") { 59 | if strings.Contains(errText, "ERROR: invalid input syntax for type") || strings.Contains(errText, "ERROR: syntax error in tsquery") || strings.Contains(errText, "ERROR: no operand in tsquery") { 60 | if startIdx := strings.IndexByte(errText, '"'); startIdx > 0 && startIdx+1 < len(errText) { 61 | errText = errText[startIdx+1:] 62 | if endIdx := strings.IndexByte(errText, '"'); endIdx > 0 { 63 | return errText[:endIdx], true 64 | } 65 | } 66 | 67 | // more generic error: no quoted value, or an empty/unterminated one such as `for type uuid: ""`. 68 | return "invalid input syntax", true 69 | } 70 | } 71 | } 72 | 73 | return "", false 74 | } 75 | 76 | // IsErrColumnNotExists reports whether the error was caused because the column "col" 77 | // referenced in a query does not exist. 78 | // It matches the error message text rather than a typed driver error.
79 | func IsErrColumnNotExists(err error, col string) bool { 80 | if err == nil { 81 | return false 82 | } 83 | 84 | errText := fmt.Sprintf(`column "%s" does not exist`, col) 85 | return strings.Contains(err.Error(), errText) 86 | } 87 | -------------------------------------------------------------------------------- /example_test.go: -------------------------------------------------------------------------------- 1 | package pg 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "reflect" 7 | "time" 8 | ) 9 | 10 | // Example is a function that demonstrates how to use the Registry and Repository types 11 | // to perform database operations within a transaction. It uses the Customer, Blog, and BlogPost structs 12 | // as the entities to be stored and manipulated in the database. It prints nothing if everything succeeds 13 | // (see the empty Output section), or an error message otherwise. 14 | func Example() { 15 | db, err := openTestConnection(true) 16 | if err != nil { 17 | fmt.Println(err.Error()) 18 | return 19 | } 20 | defer db.Close() 21 | 22 | // Registry code. 23 | registry := NewRegistry(db) // Create a new Registry instance with the DB instance. 24 | 25 | // Execute a function within a database transaction, passing it a Registry instance that uses the transactional DB instance. 26 | err = registry.InTransaction(context.Background(), func(registry *Registry) error { 27 | customers := registry.Customers() // Get the CustomerRepository instance from the Registry. 28 | 29 | customerToInsert := Customer{ // Create a Customer struct to be inserted into the database. 30 | CognitoUserID: "373f90eb-00ac-410f-9fe0-1a7058d090ba", 31 | Email: "kataras2006@hotmail.com", 32 | Name: "kataras", 33 | Username: "kataras", 34 | } 35 | 36 | // Insert the customer into the database and get its ID (err is local to this callback). 37 | err := customers.InsertSingle(context.Background(), customerToInsert, &customerToInsert.ID) 38 | if err != nil { 39 | return fmt.Errorf("insert single: %w", err) 40 | } 41 | 42 | // Modify cognito user id.
43 | newCognitoUserID := "1e6a93d0-6276-4a43-b90a-4badad8407bb" 44 | // Update specific columns by id: 45 | updated, err := customers.UpdateOnlyColumns( 46 | context.Background(), 47 | []string{"cognito_user_id"}, 48 | Customer{ 49 | BaseEntity: BaseEntity{ 50 | ID: customerToInsert.ID, 51 | }, 52 | CognitoUserID: newCognitoUserID, 53 | }) 54 | // Full update of the object by id (except id and created_at, updated_at columns): 55 | // updated, err := customers.Update(context.Background(), 56 | // Customer{ 57 | // BaseEntity: BaseEntity{ 58 | // ID: customerToInsert.ID, 59 | // }, 60 | // CognitoUserID: newCognitoUserID, 61 | // Email: customerToInsert.Email, 62 | // Name: customerToInsert.Name, 63 | // }) 64 | if err != nil { 65 | return fmt.Errorf("update: %w", err) 66 | } else if updated == 0 { 67 | return fmt.Errorf("update: no record was updated") 68 | } 69 | 70 | // Update a default column to its zero value. 71 | updated, err = customers.UpdateOnlyColumns( 72 | context.Background(), 73 | []string{"username"}, 74 | Customer{ 75 | BaseEntity: BaseEntity{ 76 | ID: customerToInsert.ID, 77 | }, 78 | Username: "", 79 | }) 80 | if err != nil { 81 | return fmt.Errorf("update username: %w", err) 82 | } else if updated == 0 { 83 | return fmt.Errorf("update username: no record was updated") 84 | } 85 | // Select the customer from the database by its ID.
86 | customer, err := customers.SelectSingle(context.Background(), `SELECT * FROM customers WHERE id = $1;`, customerToInsert.ID) 87 | if err != nil { 88 | return fmt.Errorf("select single: %w", err) 89 | } 90 | 91 | if customer.CognitoUserID != newCognitoUserID { 92 | return fmt.Errorf("expected cognito user id to be updated but it wasn't ('%s' vs '%s')", 93 | newCognitoUserID, customer.CognitoUserID) 94 | } 95 | if customer.Email == "" { 96 | return fmt.Errorf("expected email field not be removed after update") 97 | } 98 | if customer.Name == "" { 99 | return fmt.Errorf("expected name field not be removed after update") 100 | } 101 | 102 | // Test Upsert with the same cognito user id but a different email and name. 103 | customerToUpsert := Customer{ 104 | CognitoUserID: customer.CognitoUserID, 105 | Email: "kataras2023@hotmail.com", 106 | Name: "kataras2023", 107 | } 108 | 109 | // Manually passing a column as the conflict column: 110 | // err = customers.UpsertSingle(context.Background(), "email", customerToUpsert, &customerToUpsert.ID) 111 | // 112 | // Automatically find the conflict column or expression by setting it to empty value: 113 | // err = customers.UpsertSingle(context.Background(), "", customerToUpsert, &customerToUpsert.ID) 114 | // Manually passing a unique index name, pg will resolve the conflict columns: 115 | err = customers.UpsertSingle(context.Background(), "customer_unique_idx", customerToUpsert, &customerToUpsert.ID) 116 | if err != nil { 117 | return fmt.Errorf("upsert single: %w", err) 118 | } 119 | 120 | if customerToUpsert.ID == "" { 121 | return fmt.Errorf("expected customer id to be filled after upsert") 122 | } 123 | 124 | // Delete the customer from the database by its struct value.
125 | deleted, err := customers.Delete(context.Background(), customer) 126 | if err != nil { 127 | return fmt.Errorf("delete: %w", err) 128 | } else if deleted == 0 { 129 | return fmt.Errorf("delete: was not removed") 130 | } 131 | 132 | exists, err := customers.Exists(context.Background(), customer.CognitoUserID) 133 | if err != nil { 134 | return fmt.Errorf("exists: %w", err) 135 | } 136 | if exists { 137 | return fmt.Errorf("exists: customer should not exist") 138 | } 139 | 140 | // Do something else with customers. 141 | return nil 142 | }) 143 | 144 | if err != nil { 145 | fmt.Println(fmt.Errorf("in transaction: %w", err)) 146 | return 147 | } 148 | 149 | // Insert a blog. 150 | blogs := registry.Blogs() 151 | newBlog := Blog{ 152 | Name: "test_blog_1", 153 | } 154 | err = blogs.InsertSingle(context.Background(), newBlog, &newBlog.ID) 155 | if err != nil { 156 | fmt.Println(fmt.Errorf("insert single: blog: %w", err)) 157 | return 158 | } 159 | 160 | // Insert a blog post that belongs to the blog created above.
161 | blogPosts := registry.BlogPosts() 162 | newBlogPost := BlogPost{ 163 | BlogID: newBlog.ID, 164 | Title: "test_blog_post_1", 165 | PhotoURL: "https://test.com/test_blog_post_1.png", 166 | SourceURL: "https://test.com/test_blog_post_1.html", 167 | ReadTimeMinutes: 5, 168 | Category: 1, 169 | SearchTerms: []string{ 170 | "test_search_blog_post_1", 171 | "test_search_blog_post_2", 172 | }, 173 | ReadDurations: []time.Duration{ 174 | 5 * time.Minute, 175 | 10 * time.Minute, 176 | }, 177 | Feature: Feature{ 178 | IsFeatured: true, 179 | }, 180 | OtherFeatures: Features{ 181 | Feature{ 182 | IsFeatured: true, 183 | }, 184 | Feature{ 185 | IsFeatured: false, 186 | }, 187 | }, 188 | Tags: []Tag{ 189 | {"test_tag_1", "test_tag_value_1"}, 190 | {"test_tag_2", 42}, 191 | }, 192 | } 193 | err = blogPosts.InsertSingle(context.Background(), newBlogPost, &newBlogPost.ID) 194 | if err != nil { 195 | fmt.Println(fmt.Errorf("insert single: blog post: %w", err)) 196 | return 197 | } 198 | 199 | query := `SELECT * FROM blog_posts WHERE id = $1 LIMIT 1;` 200 | existingBlogPost, err := blogPosts.SelectSingle(context.Background(), query, newBlogPost.ID) 201 | if err != nil { 202 | fmt.Println(fmt.Errorf("select single: blog post: %s: %w", newBlogPost.ID, err)) 203 | return 204 | } 205 | 206 | // Select the jsonb other_features column directly into Features (a custom slice type of a custom struct type).
207 | // 208 | var otherFeatures Features 209 | err = blogPosts.QueryRow( 210 | context.Background(), 211 | `SELECT other_features FROM blog_posts WHERE id = $1 LIMIT 1;`, 212 | newBlogPost.ID, 213 | ).Scan(&otherFeatures) 214 | // OR 215 | // otherFeatures, err := QuerySingle[Features]( 216 | // context.Background(), 217 | // db, 218 | // `SELECT other_features FROM blog_posts WHERE id = $1 LIMIT 1;`, 219 | // newBlogPost.ID, 220 | // ) 221 | if err != nil { 222 | fmt.Println(fmt.Errorf("select single jsonb column of custom array type of custom type: blog post: %s: %w", newBlogPost.ID, err)) 223 | return 224 | } 225 | 226 | if expected, got := len(otherFeatures), len(existingBlogPost.OtherFeatures); expected != got { 227 | fmt.Printf("expected %d other_features but got %d", expected, got) 228 | return 229 | } 230 | 231 | if !reflect.DeepEqual(otherFeatures, existingBlogPost.OtherFeatures) { 232 | fmt.Printf("expected other_features to be equal but got %#+v and %#+v", otherFeatures, existingBlogPost.OtherFeatures) 233 | return 234 | } 235 | 236 | // Output: 237 | // 238 | } 239 | -------------------------------------------------------------------------------- /gen/README.md: -------------------------------------------------------------------------------- 1 | # gen 2 | 3 | The gen package provides a function to generate Go schema files from a PostgreSQL database. 4 | 5 | ## Usage 6 | 7 | To use the gen package, you need to import it in your Go code: 8 | 9 | ```go 10 | import "github.com/kataras/pg/gen" 11 | ``` 12 | 13 | The main function of the gen package is `GenerateSchemaFromDatabase`, which takes a context, an `ImportOptions` struct and an `ExportOptions` struct as arguments. The `ImportOptions` struct contains the connection string and the list of tables to import from the database. The `ExportOptions` struct contains the root directory and the file name generator for the schema files.
14 | 15 | For example, this code snippet shows how to generate schema files for all tables in a test database: 16 | 17 | ```go 18 | package main 19 | 20 | import ( 21 | "context" 22 | "fmt" 23 | 24 | "github.com/kataras/pg/gen" 25 | ) 26 | 27 | func main() { 28 | rootDir := "./_testdata" 29 | 30 | i := gen.ImportOptions{ 31 | ConnString: "postgres://postgres:admin!123@localhost:5432/test_db?sslmode=disable", 32 | } 33 | 34 | e := gen.ExportOptions{ 35 | RootDir: rootDir, 36 | } 37 | 38 | if err := gen.GenerateSchemaFromDatabase(context.Background(), i, e); err != nil { 39 | fmt.Println(err.Error()) 40 | return 41 | } 42 | 43 | fmt.Println("OK") 44 | } 45 | ``` 46 | 47 | The `GenerateSchemaFromDatabase` function will create a directory named `_testdata` and write schema files for each table in the test database. The schema files will have the same name as the table name (in its singular form), with a `.go` extension. 48 | 49 | You can also customize the import and export options by using the fields of the `ImportOptions` and `ExportOptions` structs. For example, you can use the `ListTables.Filter` field to filter out some tables or columns, or use the `GetFileName` field to change how the schema files are named. 50 | 51 | For more details on how to use the gen package, please refer to the [godoc](https://pkg.go.dev/github.com/kataras/pg/gen) documentation. 52 | 53 | ## License 54 | 55 | The gen package is licensed under the MIT license. See [LICENSE](https://github.com/kataras/pg/blob/main/LICENSE) for more information.
-------------------------------------------------------------------------------- /gen/db_schema_gen_example_test.go: -------------------------------------------------------------------------------- 1 | package gen 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "reflect" 8 | "time" 9 | 10 | "github.com/kataras/pg" 11 | ) 12 | 13 | type Features []Feature 14 | 15 | type Feature struct { 16 | IsFeatured bool `json:"is_featured"` 17 | } 18 | 19 | type Tag struct { 20 | Name string `json:"name"` 21 | Value any `json:"value"` 22 | } 23 | 24 | func ExampleGenerateSchemaFromDatabase() { 25 | const ( 26 | rootDir = "./_testdata" 27 | ) 28 | defer func() { 29 | os.RemoveAll(rootDir) 30 | time.Sleep(1 * time.Second) 31 | }() 32 | 33 | i := ImportOptions{ 34 | ConnString: "postgres://postgres:admin!123@localhost:5432/test_db?sslmode=disable", 35 | ListTables: pg.ListTablesOptions{ 36 | Filter: pg.TableFilterFunc(func(table *pg.Table) bool { 37 | columnFilter := func(column *pg.Column) bool { 38 | columnName := column.Name 39 | 40 | switch table.Name { 41 | case "blog_posts": 42 | switch columnName { 43 | case "feature": 44 | column.FieldType = reflect.TypeOf(Feature{}) 45 | case "other_features": 46 | column.FieldType = reflect.TypeOf(Features{}) 47 | case "tags": 48 | column.FieldType = reflect.TypeOf([]Tag{}) 49 | } 50 | } 51 | 52 | return true 53 | } 54 | 55 | table.FilterColumns(columnFilter) 56 | return true 57 | }), 58 | }, 59 | } 60 | 61 | e := ExportOptions{ 62 | RootDir: rootDir, 63 | // Optionally, to write every table into its own package instead of grouping tables by name prefix: 64 | // GetFileName: EachTableToItsOwnPackage, 65 | GetFileName: EachTableGroupToItsOwnPackage(), 66 | } 67 | 68 | if err := GenerateSchemaFromDatabase(context.Background(), i, e); err != nil { 69 | fmt.Println(err.Error()) 70 | return 71 | } 72 | 73 | fmt.Println("OK") 74 | 75 | // Output: 76 | // OK 77 | } 78 | -------------------------------------------------------------------------------- /gen/export_options.go:
-------------------------------------------------------------------------------- 1 | package gen 2 | 3 | import ( 4 | "fmt" 5 | "io/fs" 6 | "path/filepath" 7 | "strings" 8 | 9 | "github.com/kataras/pg/desc" 10 | ) 11 | 12 | // ExportOptions configures where and how generated Go files are written (see GenerateColumnsFromSchema and GenerateSchemaFromDatabase). Unset fields are filled with defaults by apply. 13 | type ExportOptions struct { 14 | RootDir string // output directory; defaults to "./" and is converted to an absolute path. 15 | FileMode fs.FileMode // permission bits of the written files; defaults to 0777 when zero. 16 | 17 | ToSingular func(string) string // converts a table name to its singular form; defaults to desc.Singular. 18 | GetFileName func(rootDir, tableName string) string // returns a table's output file path; an empty tableName requests the root file, and GenerateColumnsFromSchema skips a table whose path is empty. 19 | GetPackageName func(tableName string) string // returns the package name for a table's file; an empty tableName requests the root package name. 20 | } 21 | 22 | // EachTableToItsOwnPackage is a GetFileName implementation that writes every table into its own package, 23 | // as <rootDir>/<singular table name>/<singular table name>.go. A tableName ending with ".go" is joined to rootDir as is. 24 | func EachTableToItsOwnPackage(rootDir, tableName string) string { 25 | if strings.HasSuffix(tableName, ".go") { 26 | return filepath.Join(rootDir, tableName) 27 | } 28 | 29 | packageName := desc.Singular(tableName) 30 | filename := filepath.Join(rootDir, packageName, packageName+".go") 31 | return filename 32 | } 33 | 34 | // EachTableGroupToItsOwnPackage returns a GetFileName implementation that groups tables by name prefix: a table whose 35 | // singular name starts with an already seen group name followed by "_" is written to <rootDir>/<group>/<singular table name>.go, 36 | // otherwise it starts a new group. The returned function keeps unsynchronized state, so groups depend on call order and it must not be called concurrently. 37 | func EachTableGroupToItsOwnPackage() func(rootDir, tableName string) string { 38 | visitedTables := make(map[string]struct{}) // table group.
39 | 40 | getTableGroup := func(rootDir, tableName string) string { 41 | tableName = desc.Singular(tableName) 42 | group := tableName // pick the longest matching group so the result does not depend on map iteration order. 43 | for t := range visitedTables { 44 | if strings.HasPrefix(tableName, t+"_") && (group == tableName || len(t) > len(group)) { 45 | group = t 46 | } 47 | } 48 | visitedTables[group] = struct{}{} // no-op when tableName joined an existing group. 49 | return group 50 | } 51 | 52 | return func(rootDir, tableName string) string { 53 | if strings.HasSuffix(tableName, ".go") { 54 | return filepath.Join(rootDir, tableName) 55 | } 56 | 57 | tableGroup := getTableGroup(rootDir, tableName) 58 | return filepath.Join(rootDir, tableGroup, desc.Singular(tableName)+".go") 59 | } 60 | } 61 | 62 | // apply fills in defaults for unset options and converts RootDir to an absolute path. 63 | func (opts *ExportOptions) apply() error { 64 | if opts.RootDir == "" { 65 | opts.RootDir = "./" 66 | } 67 | 68 | if opts.FileMode == 0 { // fs.FileMode is unsigned, so zero is the only "unset" value. 69 | opts.FileMode = 0777 70 | } 71 | 72 | rootDir, err := filepath.Abs(opts.RootDir) 73 | if err != nil { 74 | return fmt.Errorf("filepath.Abs: %w", err) 75 | } 76 | opts.RootDir = rootDir // we need the fullpath in order to find the package name if missing. 77 | 78 | if opts.ToSingular == nil { 79 | opts.ToSingular = desc.Singular 80 | } 81 | 82 | if opts.GetFileName == nil { 83 | opts.GetFileName = func(rootDir, tableName string) string { 84 | filename := tableName 85 | 86 | if filename == "" { // if empty default the filename to the last part of the root dir +.go. 87 | filename = strings.TrimPrefix(filepath.Base(rootDir), "_") 88 | } else if strings.HasSuffix(filename, ".go") { 89 | return filepath.Join(rootDir, filename) 90 | } else { // otherwise get the singular form of the tablename + .go.
91 | filename = opts.ToSingular(tableName) 92 | } 93 | 94 | filename = filepath.Join(rootDir, filename) 95 | return fmt.Sprintf("%s.go", filename) 96 | } 97 | } 98 | 99 | if opts.GetPackageName == nil { 100 | opts.GetPackageName = func(tableName string) string { 101 | if tableName == "" { 102 | return strings.TrimPrefix(filepath.Base(opts.RootDir), "_") 103 | } 104 | 105 | filename := opts.GetFileName(opts.RootDir, tableName) // contains the full path let's get the last part of it as package name. 106 | packageName := filepath.Base(filepath.Dir(filename)) 107 | packageName = strings.TrimPrefix(packageName, "_") 108 | if packageName == "" { 109 | packageName = filepath.Base(opts.RootDir) // else it's current dir. 110 | } 111 | 112 | return packageName 113 | } 114 | } 115 | 116 | return nil 117 | } 118 | -------------------------------------------------------------------------------- /gen/schema_columns_gen.go: -------------------------------------------------------------------------------- 1 | package gen 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "go/format" 7 | "os" 8 | "path/filepath" 9 | "runtime" 10 | "strings" 11 | "text/template" 12 | 13 | "github.com/kataras/pg" 14 | ) 15 | 16 | // GenerateColumnsFromSchema generates Go code for the given schema: a root file declaring the Column type, 17 | // plus one file per table declaring a struct value that holds the table name (PG_TableName) 18 | // and a Column field for each of the table's columns. 19 | // 20 | // Example Code: 21 | // 22 | // schema := pg.NewSchema() 23 | // schema.MustRegister("companies", Company{}) 24 | // schema.MustRegister("customers", Customer{}) 25 | // 26 | // opts := ExportOptions{ 27 | // RootDir: "./definition", 28 | // } 29 | // 30 | // if err := GenerateColumnsFromSchema(schema, opts); err != nil { 31 | // t.Fatal(err) 32 | // } 33 | // 34 | // Usage: 35 | // definition.Company.Name.String() // returns "name" 36 | // definition.Customer.Email.String() // returns "email" 37 | // 38 | // Useful for type-safe query builders.
39 | func GenerateColumnsFromSchema(s *pg.Schema, e ExportOptions) error { 40 | if err := e.apply(); err != nil { 41 | return err 42 | } 43 | 44 | tables := s.Tables() 45 | if len(tables) == 0 { 46 | return nil 47 | } 48 | 49 | // Create root file to store common structures and functions. 50 | data, err := generateRoot(e.GetPackageName("")) 51 | if err != nil { 52 | return fmt.Errorf("generate root: %s: %w", e.GetPackageName("columns"), err) 53 | } 54 | 55 | filename := e.GetFileName(e.RootDir, "") 56 | 57 | err = mkdir(filename) 58 | if err != nil { 59 | return fmt.Errorf("mkdir: %s: %w", e.RootDir, err) 60 | } 61 | 62 | err = os.WriteFile(filename, data, e.FileMode) 63 | if err != nil { 64 | return fmt.Errorf("write root: %s: %w", filename, err) 65 | } 66 | 67 | // Create each file for each table definition. 68 | for _, td := range tables { 69 | data, err = generateTableDefininion(e.GetPackageName(td.Name), td) 70 | if err != nil { 71 | return fmt.Errorf("generate table: %s: %w", td.Name, err) 72 | } 73 | 74 | filename := e.GetFileName(e.RootDir, td.Name) 75 | if filename == "" { 76 | continue 77 | } 78 | 79 | mkdir(filename) 80 | 81 | err = os.WriteFile(filename, data, e.FileMode) 82 | if err != nil { 83 | return fmt.Errorf("write table: %s: defininion file: %s: %w", td.Name, filename, err) 84 | } 85 | } 86 | 87 | return nil 88 | } 89 | 90 | var generateRootTmpl = template.Must( 91 | template.New("").Parse(` 92 | // Code generated by pg. DO NOT EDIT. 93 | package {{.PackageName}} 94 | 95 | // Column is a struct that represents a column in a table. 96 | type Column struct { 97 | Name string 98 | } 99 | 100 | // String returns the name of the column. 
101 | func (c Column) String() string { 102 | return c.Name 103 | } 104 | `)) 105 | 106 | func generateRoot(packageName string) ([]byte, error) { 107 | tmplData := generateTemplateData{ 108 | PackageName: packageName, 109 | } 110 | var buf bytes.Buffer 111 | if err := generateRootTmpl.Execute(&buf, tmplData); err != nil { 112 | return nil, err 113 | } 114 | 115 | return format.Source(buf.Bytes()) 116 | } 117 | 118 | type generateTemplateData struct { 119 | *pg.Table 120 | PackageName string 121 | } 122 | 123 | var generateTableDefininionTmpl = template.Must( 124 | template.New(""). // Note that we don't need to put import paths here as the only one type is the generated Column in the same package. 125 | Parse(` 126 | // Code generated by pg. DO NOT EDIT. 127 | package {{.PackageName}} 128 | 129 | // {{.StructName}} is a struct value that represents a record in the {{.Name}} table. 130 | var {{.StructName}} = struct { 131 | PG_TableName string 132 | {{range .Columns}} {{.FieldName}} Column 133 | {{end}} 134 | }{ 135 | PG_TableName: "{{.Name}}", 136 | {{range .Columns}} {{.FieldName}}: Column{ 137 | Name: "{{.Name}}", 138 | }, 139 | {{end }} 140 | }`)) 141 | 142 | func generateTableDefininion(packageName string, td *pg.Table) ([]byte, error) { 143 | tmplData := generateTemplateData{ 144 | Table: td, 145 | PackageName: packageName, 146 | } 147 | var buf bytes.Buffer 148 | if err := generateTableDefininionTmpl.Execute(&buf, tmplData); err != nil { 149 | return nil, err 150 | } 151 | 152 | return format.Source(buf.Bytes()) 153 | } 154 | 155 | func mkdir(path string) error { 156 | dir := filepath.Dir(path) 157 | return os.MkdirAll(dir, 0777) 158 | } 159 | 160 | func getCallerPackageName() string { 161 | pc, _, _, _ := runtime.Caller(2) 162 | funcName := runtime.FuncForPC(pc).Name() 163 | lastSlash := strings.LastIndexByte(funcName, '/') 164 | if lastSlash < 0 { 165 | lastSlash = 0 166 | } 167 | 168 | lastDot := strings.LastIndexByte(funcName[lastSlash:], '.') 169 | if 
lastDot == -1 { 170 | return "" 171 | } 172 | 173 | return funcName[:lastDot] 174 | } 175 | -------------------------------------------------------------------------------- /gen/schema_columns_gen_test.go: -------------------------------------------------------------------------------- 1 | package gen 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | "time" 7 | 8 | "github.com/kataras/pg" 9 | ) 10 | 11 | // BaseEntity is a struct that defines common fields for all entities in the database. 12 | // It has an ID field of type uuid that is the primary key, and two timestamp fields 13 | // for tracking the creation and update times of each row. 14 | type BaseEntity struct { 15 | ID string `pg:"type=uuid,primary"` 16 | CreatedAt time.Time `pg:"type=timestamp,default=clock_timestamp()"` 17 | UpdatedAt time.Time `pg:"type=timestamp,default=clock_timestamp()"` 18 | } 19 | 20 | // Company is a struct that represents a company entity in the database. 21 | type Company struct { 22 | BaseEntity 23 | 24 | Name string `pg:"type=varchar(255)"` 25 | } 26 | 27 | // Customer is a struct that represents a customer entity in the database. 28 | // It embeds BaseEntity and adds CognitoUserID and Email, which together form the 29 | // composite unique index customer_unique_idx, plus a Name and a CompanyID 30 | // foreign key that references companies(id). 31 | type Customer struct { 32 | BaseEntity 33 | // CognitoUserID string `pg:"type=uuid,conflict=DO UPDATE SET cognito_user_id=EXCLUDED.cognito_user_id"` 34 | 35 | CognitoUserID string `pg:"type=uuid,unique_index=customer_unique_idx"` 36 | Email string `pg:"type=varchar(255),unique_index=customer_unique_idx"` // optional: tag Email as unique instead to allow upserts that conflict on "email" alone rather than on the composite unique_index.
37 | Name string `pg:"type=varchar(255)"` 38 | CompanyID string `pg:"type=uuid,ref=companies(id)"` 39 | } 40 | // TestGenerateColumnsFromSchema generates column files for Company and Customer into rootDir and compares each one byte-for-byte with the expected source. 41 | func TestGenerateColumnsFromSchema(t *testing.T) { 42 | const ( 43 | rootDir = "./_testdata" 44 | ) 45 | defer func() { 46 | os.RemoveAll(rootDir) 47 | time.Sleep(1 * time.Second) // NOTE(review): this only delays the test's return by one second after cleanup; confirm it is still needed. 48 | }() 49 | 50 | schema := pg.NewSchema() 51 | schema.MustRegister("companies", Company{}) 52 | schema.MustRegister("customers", Customer{}) 53 | 54 | opts := ExportOptions{ 55 | RootDir: rootDir, 56 | } 57 | if err := GenerateColumnsFromSchema(schema, opts); err != nil { 58 | t.Fatal(err) 59 | } 60 | 61 | companyContents, err := os.ReadFile(rootDir + "/company.go") 62 | if err != nil { 63 | t.Fatal(err) 64 | } 65 | 66 | rootContents, err := os.ReadFile(rootDir + "/testdata.go") 67 | if err != nil { 68 | t.Fatal(err) 69 | } 70 | customerContents, err := os.ReadFile(rootDir + "/customer.go") 71 | if err != nil { 72 | t.Fatal(err) 73 | } 74 | 75 | const ( 76 | expectedRootContents = `// Code generated by pg. DO NOT EDIT. 77 | package testdata 78 | 79 | // Column is a struct that represents a column in a table. 80 | type Column struct { 81 | Name string 82 | } 83 | 84 | // String returns the name of the column. 85 | func (c Column) String() string { 86 | return c.Name 87 | } 88 | ` 89 | 90 | expectedCompanyContents = `// Code generated by pg. DO NOT EDIT. 91 | package testdata 92 | 93 | // Company is a struct value that represents a record in the companies table. 94 | var Company = struct { 95 | PG_TableName string 96 | ID Column 97 | CreatedAt Column 98 | UpdatedAt Column 99 | Name Column 100 | }{ 101 | PG_TableName: "companies", 102 | ID: Column{ 103 | Name: "id", 104 | }, 105 | CreatedAt: Column{ 106 | Name: "created_at", 107 | }, 108 | UpdatedAt: Column{ 109 | Name: "updated_at", 110 | }, 111 | Name: Column{ 112 | Name: "name", 113 | }, 114 | } 115 | ` 116 | expectedCustomerContents = `// Code generated by pg. DO NOT EDIT.
117 | package testdata 118 | 119 | // Customer is a struct value that represents a record in the customers table. 120 | var Customer = struct { 121 | PG_TableName string 122 | ID Column 123 | CreatedAt Column 124 | UpdatedAt Column 125 | CognitoUserID Column 126 | Email Column 127 | Name Column 128 | CompanyID Column 129 | }{ 130 | PG_TableName: "customers", 131 | ID: Column{ 132 | Name: "id", 133 | }, 134 | CreatedAt: Column{ 135 | Name: "created_at", 136 | }, 137 | UpdatedAt: Column{ 138 | Name: "updated_at", 139 | }, 140 | CognitoUserID: Column{ 141 | Name: "cognito_user_id", 142 | }, 143 | Email: Column{ 144 | Name: "email", 145 | }, 146 | Name: Column{ 147 | Name: "name", 148 | }, 149 | CompanyID: Column{ 150 | Name: "company_id", 151 | }, 152 | } 153 | ` 154 | ) 155 | 156 | if string(companyContents) != expectedCompanyContents { 157 | t.Fatalf("expected company contents to be %q but got %q", expectedCompanyContents, companyContents) 158 | } 159 | 160 | if string(rootContents) != expectedRootContents { 161 | t.Fatalf("expected root contents to be %q but got %q", expectedRootContents, rootContents) 162 | } 163 | 164 | if string(customerContents) != expectedCustomerContents { 165 | t.Fatalf("expected customer contents to be %q but got %q", expectedCustomerContents, customerContents) 166 | } 167 | } 168 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/kataras/pg 2 | 3 | go 1.24 4 | 5 | require ( 6 | github.com/gertd/go-pluralize v0.2.1 7 | github.com/jackc/pgx/v5 v5.7.5 8 | golang.org/x/mod v0.24.0 9 | ) 10 | 11 | require ( 12 | github.com/jackc/pgpassfile v1.0.0 // indirect 13 | github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect 14 | github.com/jackc/puddle/v2 v2.2.2 // indirect 15 | golang.org/x/crypto v0.37.0 // indirect 16 | golang.org/x/sync v0.13.0 // indirect 17 | golang.org/x/text
v0.24.0 // indirect 18 | ) 19 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/gertd/go-pluralize v0.2.1 h1:M3uASbVjMnTsPb0PNqg+E/24Vwigyo/tvyMTtAlLgiA= 5 | github.com/gertd/go-pluralize v0.2.1/go.mod h1:rbYaKDbsXxmRfr8uygAEKhOWsjyrrqrkHVpZvoOp8zk= 6 | github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= 7 | github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= 8 | github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= 9 | github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= 10 | github.com/jackc/pgx/v5 v5.7.5 h1:JHGfMnQY+IEtGM63d+NGMjoRpysB2JBwDr5fsngwmJs= 11 | github.com/jackc/pgx/v5 v5.7.5/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= 12 | github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= 13 | github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= 14 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 15 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 16 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 17 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 18 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 19 | github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= 20 
| github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 21 | golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= 22 | golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= 23 | golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= 24 | golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= 25 | golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= 26 | golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= 27 | golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= 28 | golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= 29 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 30 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 31 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 32 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 33 | -------------------------------------------------------------------------------- /listener.go: -------------------------------------------------------------------------------- 1 | package pg 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "sync/atomic" 8 | "unsafe" 9 | 10 | "github.com/jackc/pgx/v5/pgconn" 11 | "github.com/jackc/pgx/v5/pgxpool" 12 | ) 13 | 14 | // Notification is a type alias of pgconn.Notification type. 15 | type Notification = pgconn.Notification 16 | 17 | // Closer is the interface which is implemented by the Listener. 18 | // It's used to close the underlying connection. 19 | type Closer interface { 20 | Close(ctx context.Context) error 21 | } 22 | 23 | // Listener represents a postgres database LISTEN connection.
24 | type Listener struct { 25 | conn *pgxpool.Conn 26 | 27 | channel string 28 | closed uint32 29 | } 30 | 31 | var _ Closer = (*Listener)(nil) 32 | 33 | // ErrEmptyPayload is returned when the notification payload is empty. 34 | var ErrEmptyPayload = fmt.Errorf("empty payload") 35 | 36 | // Accept waits for a notification and returns it. 37 | func (l *Listener) Accept(ctx context.Context) (*Notification, error) { 38 | nf, err := l.conn.Conn().WaitForNotification(ctx) 39 | if err != nil { 40 | return nil, err 41 | } 42 | 43 | /* Sadly this is not possible due to the Go's limitations. 44 | var payload T 45 | if s, ok := payload.(string); ok { 46 | // use nativeAccept. 47 | } 48 | */ 49 | 50 | if len(nf.Payload) == 0 { 51 | return nil, ErrEmptyPayload 52 | } 53 | 54 | return nf, nil 55 | } 56 | 57 | // Close closes the listener connection. 58 | func (l *Listener) Close(ctx context.Context) error { 59 | if l == nil { 60 | return nil 61 | } 62 | 63 | if l.conn == nil { 64 | return nil 65 | } 66 | 67 | if atomic.CompareAndSwapUint32(&l.closed, 0, 1) { 68 | defer l.conn.Release() 69 | 70 | query := `SELECT UNLISTEN $1;` 71 | _, err := l.conn.Exec(ctx, query, l.channel) 72 | if err != nil { 73 | return err 74 | } 75 | } 76 | 77 | return nil 78 | } 79 | 80 | // notifyJSON sends a notification of any type to the underline database listener. 81 | func notifyJSON(ctx context.Context, db *DB, channel string, payload any) error { 82 | b, err := json.Marshal(payload) 83 | if err != nil { 84 | return err 85 | } 86 | 87 | return notifyNative(ctx, db, channel, b) 88 | } 89 | 90 | // NotifyNative sends a raw notification to the underline database listener, 91 | // it accepts string or a slice of bytes because that's the only raw types that are allowed to be delivered. 
92 | func notifyNative[T string | []byte](ctx context.Context, db *DB, channel string, payload T) error { 93 | query := `SELECT pg_notify($1, $2)` 94 | _, err := db.Pool.Exec(context.Background(), query, channel, payload) // Always on top. NOTE(review): the ctx argument is ignored (context.Background() is passed), so caller cancellation and deadlines do not apply; confirm this is intentional. 95 | return err 96 | } 97 | 98 | // UnmarshalNotification decodes the JSON payload of n into a new value of type T and returns it together with any json.Unmarshal error. 99 | func UnmarshalNotification[T any](n *Notification) (T, error) { 100 | var payload T 101 | 102 | b := stringToBytes(n.Payload) 103 | 104 | err := json.Unmarshal(b, &payload) 105 | if err != nil { 106 | return payload, err 107 | } 108 | 109 | return payload, nil 110 | } 111 | // stringToBytes returns the bytes of s without copying; the result aliases the string's memory and must never be modified. 112 | func stringToBytes(s string) []byte { 113 | return unsafe.Slice(unsafe.StringData(s), len(s)) 114 | } 115 | -------------------------------------------------------------------------------- /listener_example_test.go: -------------------------------------------------------------------------------- 1 | package pg 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | ) 8 | 9 | // go test -coverprofile=cov 10 | // go tool cover -html=cov 11 | func ExampleDB_Listen() { 12 | db, err := openEmptyTestConnection() 13 | if err != nil { 14 | handleExampleError(err) 15 | return 16 | } 17 | // defer db.Close() 18 | 19 | const channel = "chat_db" 20 | 21 | conn, err := db.Listen(context.Background(), channel) 22 | if err != nil { 23 | fmt.Println(fmt.Errorf("listen: %w\n", err)) 24 | return 25 | } 26 | 27 | go func() { 28 | // To just terminate this listener's connection and unlisten from the channel: 29 | defer conn.Close(context.Background()) 30 | 31 | for { 32 | notification, err := conn.Accept(context.Background()) 33 | if err != nil { 34 | fmt.Println(fmt.Errorf("accept: %w\n", err)) 35 | return 36 | } 37 | 38 | fmt.Printf("channel: %s, payload: %s\n", notification.Channel, notification.Payload) 39 | } 40 | }() 41 | 42 | if err = db.Notify(context.Background(), channel, "hello"); err != nil { 43 | fmt.Println(fmt.Errorf("notify: hello: %w",
err)) 44 | return 45 | } 46 | 47 | if err = db.Notify(context.Background(), channel, "world"); err != nil { 48 | fmt.Println(fmt.Errorf("notify: world: %w", err)) 49 | return 50 | } 51 | 52 | time.Sleep(5 * time.Second) // give the listener goroutine some time to receive the notifications. 53 | // Output: 54 | // channel: chat_db, payload: hello 55 | // channel: chat_db, payload: world 56 | } 57 | // Message is the JSON payload sent and received by Example_notify_JSON. 58 | type Message struct { 59 | BaseEntity 60 | 61 | Sender string `pg:"type=varchar(255)" json:"sender"` 62 | Body string `pg:"type=text" json:"body"` 63 | } 64 | 65 | func Example_notify_JSON() { 66 | schema := NewSchema() 67 | db, err := Open(context.Background(), schema, getTestConnString()) 68 | if err != nil { 69 | fmt.Println(err) 70 | return 71 | } 72 | // defer db.Close() 73 | 74 | const channel = "chat_json" 75 | 76 | conn, err := db.Listen(context.Background(), channel) 77 | if err != nil { 78 | fmt.Println(fmt.Errorf("listen: %w", err)) // NOTE(review): there is no return here, so conn is still used below even when Listen failed. 79 | } 80 | 81 | go func() { 82 | // To just terminate this listener's connection and unlisten from the channel: 83 | defer conn.Close(context.Background()) 84 | 85 | for { 86 | notification, err := conn.Accept(context.Background()) 87 | if err != nil { 88 | fmt.Println(fmt.Errorf("accept: %w\n", err)) 89 | return 90 | } 91 | 92 | payload, err := UnmarshalNotification[Message](notification) 93 | if err != nil { 94 | fmt.Println(fmt.Errorf("N: %w", err)) 95 | return 96 | } 97 | 98 | fmt.Printf("channel: %s, payload.sender: %s, payload.body: %s\n", 99 | notification.Channel, payload.Sender, payload.Body) 100 | } 101 | }() 102 | 103 | firstMessage := Message{ 104 | Sender: "kataras", 105 | Body: "hello", 106 | } 107 | if err = db.Notify(context.Background(), channel, firstMessage); err != nil { 108 | fmt.Println(fmt.Errorf("notify: first message: %w", err)) 109 | return 110 | } 111 | 112 | secondMessage := Message{ 113 | Sender: "kataras", 114 | Body: "world", 115 | } 116 | 117 | if err = db.Notify(context.Background(), channel, secondMessage); err !=
nil { 118 | fmt.Println(fmt.Errorf("notify: second message: %w", err)) 119 | return 120 | } 121 | 122 | time.Sleep(5 * time.Second) // give the listener goroutine some time to receive the notifications (5s is more than enough). 123 | // Output: 124 | // channel: chat_json, payload.sender: kataras, payload.body: hello 125 | // channel: chat_json, payload.sender: kataras, payload.body: world 126 | } 127 | -------------------------------------------------------------------------------- /repository_example_test.go: -------------------------------------------------------------------------------- 1 | package pg 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | ) 7 | 8 | // Repositories. 9 | 10 | // CustomerRepository is a struct that wraps a generic Repository instance with the Customer type parameter. 11 | // It provides methods for accessing and manipulating customer data in the database. 12 | type CustomerRepository struct { 13 | *Repository[Customer] 14 | } 15 | 16 | // NewCustomerRepository creates and returns a new CustomerRepository instance with the given DB instance. 17 | func NewCustomerRepository(db *DB) *CustomerRepository { 18 | return &CustomerRepository{ 19 | Repository: NewRepository[Customer](db), 20 | } 21 | } 22 | 23 | // InTransaction shadows the embedded Repository's InTransaction method so that fn receives a *CustomerRepository. 24 | // If r's DB is already a transaction, fn is called directly with r; otherwise fn runs inside a new database transaction, 25 | // receiving a CustomerRepository bound to the transactional DB instance.
26 | func (r *CustomerRepository) InTransaction(ctx context.Context, fn func(*CustomerRepository) error) (err error) { 27 | if r.DB().IsTransaction() { 28 | return fn(r) 29 | } 30 | 31 | return r.DB().InTransaction(ctx, func(db *DB) error { 32 | txRepository := NewCustomerRepository(db) 33 | return fn(txRepository) 34 | }) 35 | } 36 | 37 | // Exists is a custom method that uses the pg repository's Database instance to execute a query and return a result. 38 | // It takes a context and a cognitoUserID as arguments and checks if there is any customer row with that cognitoUserID in the database. 39 | func (r *CustomerRepository) Exists(ctx context.Context, cognitoUserID string) (exists bool, err error) { 40 | // query := `SELECT EXISTS(SELECT 1 FROM customers WHERE cognito_user_id = $1)` 41 | // err = r.QueryRow(ctx, query, cognitoUserID).Scan(&exists) 42 | // OR: 43 | 44 | exists, err = r.Repository.Exists(ctx, Customer{CognitoUserID: cognitoUserID}) 45 | return 46 | } 47 | 48 | // Registry is (optional) a struct that holds references to different repositories for accessing and manipulating data in the database. 49 | // It has a db field that is a pointer to a DB instance, and a customers field that is a pointer to a CustomerRepository instance. 50 | type Registry struct { 51 | db *DB 52 | 53 | customers *CustomerRepository 54 | blogs *Repository[Blog] 55 | blogPosts *Repository[BlogPost] 56 | } 57 | 58 | // NewRegistry creates and returns a new Registry instance with the given DB instance. 59 | // It also initializes the customers field with a new CustomerRepository instance that uses the same DB instance. 60 | func NewRegistry(db *DB) *Registry { 61 | return &Registry{ 62 | db: db, 63 | 64 | customers: NewCustomerRepository(db), 65 | blogs: NewRepository[Blog](db), 66 | blogPosts: NewRepository[BlogPost](db), 67 | } 68 | } 69 | 70 | // InTransaction overrides the pg Repository's InTransaction method to include the custom type of Registry. 
71 | // If r's DB is already a transaction, fn is called directly with r; otherwise fn runs inside a new database transaction, 72 | // receiving a Registry built on the transactional DB instance. 73 | func (r *Registry) InTransaction(ctx context.Context, fn func(*Registry) error) (err error) { 74 | if r.db.IsTransaction() { 75 | return fn(r) 76 | } 77 | 78 | return r.db.InTransaction(ctx, func(db *DB) error { 79 | txRegistry := NewRegistry(db) 80 | return fn(txRegistry) 81 | }) 82 | } 83 | 84 | // Customers returns the CustomerRepository instance of the Registry. 85 | func (r *Registry) Customers() *CustomerRepository { 86 | return r.customers 87 | } 88 | 89 | // Blogs returns the Repository instance of the Blog entity. 90 | func (r *Registry) Blogs() *Repository[Blog] { 91 | return r.blogs 92 | } 93 | 94 | // BlogPosts returns the Repository instance of the BlogPost entity. 95 | func (r *Registry) BlogPosts() *Repository[BlogPost] { 96 | return r.blogPosts 97 | } 98 | // NOTE(review): this example has no expected-output comment, so `go test` compiles it but never executes it. 99 | func ExampleNewRepository() { 100 | db, err := openTestConnection(true) 101 | if err != nil { 102 | handleExampleError(err) 103 | return 104 | } 105 | defer db.Close() 106 | 107 | registry := NewRegistry(db) // Create a new Registry instance with the DB instance. 108 | customers := registry.Customers() // Get the CustomerRepository instance from the Registry. 109 | 110 | // Repository example code. 111 | customerToInsert := Customer{ // Create a Customer struct to be inserted into the database.
112 | CognitoUserID: "373f90eb-00ac-410f-9fe0-1a7058d090ba", 113 | Email: "kataras2006@hotmail.com", 114 | Name: "kataras", 115 | } 116 | 117 | err = customers.InsertSingle(context.Background(), customerToInsert, &customerToInsert.ID) 118 | if err != nil { 119 | handleExampleError(err) 120 | return 121 | } 122 | 123 | fmt.Println(customerToInsert.ID) 124 | } 125 | -------------------------------------------------------------------------------- /repository_table_listener_example_test.go: -------------------------------------------------------------------------------- 1 | package pg 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | ) 8 | 9 | func ExampleRepository_ListenTable() { 10 | db, err := openTestConnection(true) 11 | if err != nil { 12 | handleExampleError(err) 13 | return 14 | } 15 | defer db.Close() 16 | 17 | customers := NewRepository[Customer](db) 18 | 19 | closer, err := customers.ListenTable(context.Background(), func(evt TableNotification[Customer], err error) error { 20 | if err != nil { 21 | fmt.Printf("received error: %v\n", err) 22 | return err 23 | } 24 | 25 | fmt.Printf("table: %s, event: %s, old name: %s new name: %s\n", evt.Table, evt.Change, evt.Old.Name, evt.New.Name) 26 | return nil 27 | }) 28 | if err != nil { 29 | fmt.Println(err) 30 | return 31 | } 32 | defer closer.Close(context.Background()) 33 | 34 | newCustomer := Customer{ 35 | CognitoUserID: "766064d4-a2a7-442d-aa75-33493bb4dbb9", 36 | Email: "kataras2024@hotmail.com", 37 | Name: "Makis", 38 | } 39 | err = customers.InsertSingle(context.Background(), newCustomer, &newCustomer.ID) 40 | if err != nil { 41 | fmt.Println(err) 42 | return 43 | } 44 | 45 | newCustomer.Name = "Makis_UPDATED" 46 | _, err = customers.UpdateOnlyColumns(context.Background(), []string{"name"}, newCustomer) 47 | if err != nil { 48 | fmt.Println(err) 49 | return 50 | } 51 | time.Sleep(8 * time.Second) // give the table listener some time to deliver the INSERT and UPDATE notifications.
52 | // Output: 53 | // table: customers, event: INSERT, old name: new name: Makis 54 | // table: customers, event: UPDATE, old name: Makis new name: Makis_UPDATED 55 | } 56 | -------------------------------------------------------------------------------- /schema.go: -------------------------------------------------------------------------------- 1 | package pg 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "sort" 7 | 8 | "github.com/kataras/pg/desc" 9 | ) 10 | 11 | // Schema is a type that represents a schema for the database. 12 | type Schema struct { 13 | // structCache is a map from reflect.Type to Table 14 | // that stores the table definitions for the registered structs 15 | structCache map[reflect.Type]*desc.Table 16 | orderedTypes []reflect.Type // registration order of the structCache keys; Last reads the final entry. 17 | 18 | passwordHandler *desc.PasswordHandler // set via HandlePassword and copied into each table definition by Register. 19 | // The name of the "updated_at" column. Defaults to "updated_at" but it can be modified, 20 | // this is useful to set when triggers should be registered automatically. 21 | // 22 | // If set to empty then triggers will not be registered automatically. 23 | UpdatedAtColumnName string 24 | // The name of the trigger that sets the "updated_at" column; defaults to "set_timestamp". 25 | // 26 | // If set to empty then triggers will not be registered automatically. 27 | SetTimestampTriggerName string 28 | 29 | // Strict reports whether the schema should be strict on the database side. 30 | // It's enabled by default and applied to each table registered through MustRegister. 31 | Strict bool 32 | } 33 | 34 | // NewSchema creates and returns a new Schema with an initialized struct cache. 35 | func NewSchema() *Schema { 36 | return &Schema{ 37 | // Make a map from reflect.Type to Table. 38 | structCache: make(map[reflect.Type]*desc.Table), 39 | // Set the default name for the "updated_at" column. 40 | UpdatedAtColumnName: "updated_at", 41 | // Set the default name for the trigger that sets the "updated_at" column.
42 | // Having triggers with the same name (`set_timestamp`) across different tables is perfectly fine in PostgreSQL. 43 | // Trigger names only need to be unique within a table, not across the entire database. 44 | SetTimestampTriggerName: "set_timestamp", 45 | Strict: true, 46 | } 47 | } 48 | 49 | /* 50 | 51 | type TextFunc = func(context.Context, string) (string, error) 52 | 53 | func NewPasswordHandler(set, get TextFunc) PasswordHandler { 54 | return &plainPasswordHandler{ 55 | setter: set, 56 | getter: get, 57 | } 58 | } 59 | 60 | type plainPasswordHandler struct { 61 | setter TextFunc 62 | getter TextFunc 63 | } 64 | 65 | func (h *plainPasswordHandler) Set(ctx context.Context, plainPassword string) (encryptedPassword string, err error) { 66 | return h.setter(ctx, plainPassword) 67 | } 68 | 69 | func (h *plainPasswordHandler) Get(ctx context.Context, encryptedPassword string) (plainPassword string, err error) { 70 | return h.getter(ctx, encryptedPassword) 71 | } 72 | */ 73 | 74 | // HandlePassword sets the password handler that Register copies into tables registered afterwards. A handler whose Encrypt and Decrypt are both nil is ignored. It returns s for chaining. 75 | func (s *Schema) HandlePassword(handler desc.PasswordHandler) *Schema { 76 | if handler.Encrypt == nil && handler.Decrypt == nil { 77 | return s 78 | } 79 | 80 | s.passwordHandler = &handler 81 | return s 82 | } 83 | 84 | // View is a TableFilterFunc that sets the table type to "view" and returns true. 85 | // 86 | // Example: 87 | // 88 | // schema.MustRegister("customer_master", FullCustomer{}, pg.View) 89 | var View = func(td *desc.Table) bool { 90 | td.Type = desc.TableTypeView 91 | return true 92 | } 93 | 94 | // Presenter is a TableFilterFunc that sets the table type to "presenter" and returns true. 95 | // A presenter is a table that is used to present data from one or more tables with custom select queries. 96 | // It is neither a base table nor a view.
97 | // Example: 98 | // 99 | // schema.MustRegister("customer_presenter", CustomerPresenter{}, pg.Presenter) 100 | var Presenter = func(td *desc.Table) bool { 101 | td.Type = desc.TableTypePresenter 102 | return true 103 | } 104 | 105 | // MustRegister same as "Register" but it panics on errors and returns the Schema instance instead of the Table one. 106 | func (s *Schema) MustRegister(tableName string, emptyStructValue any, opts ...TableFilterFunc) *Schema { 107 | td, err := s.Register(tableName, emptyStructValue, opts...) // call Register with the same arguments 108 | if err != nil { // if there is an error 109 | panic(err) // panic with the error 110 | } 111 | td.SetStrict(s.Strict) 112 | 113 | return s // return the table definition 114 | } 115 | 116 | // Register registers a database model (a struct value) mapped to a specific database table name. 117 | // Returns the generated Table definition. 118 | func (s *Schema) Register(tableName string, emptyStructValue any, opts ...TableFilterFunc) (*desc.Table, error) { 119 | typ := desc.IndirectType(reflect.TypeOf(emptyStructValue)) // get the underlying type of the struct value 120 | 121 | td, err := desc.ConvertStructToTable(tableName, typ) // convert the type to a table definition 122 | if err != nil { // if there is an error 123 | return nil, err // return the error 124 | } 125 | 126 | td.RegisteredPosition = len(s.structCache) + 1 // assign the registered position as the current size of the cache plus one 127 | td.PasswordHandler = s.passwordHandler 128 | 129 | for _, opt := range opts { 130 | if !opt(td) { // do not register if returns false. 131 | return td, nil 132 | } 133 | } 134 | 135 | s.structCache[typ] = td // store the table definition in the cache with the type as the key 136 | s.orderedTypes = append(s.orderedTypes, typ) 137 | 138 | return td, nil // return the table definition and no error 139 | } 140 | 141 | // Last returns the last registered table definition. 
142 | func (s *Schema) Last() *desc.Table { 143 | if len(s.orderedTypes) == 0 { 144 | return nil 145 | } 146 | 147 | return s.structCache[s.orderedTypes[len(s.orderedTypes)-1]] 148 | } 149 | 150 | // Get takes a reflect.Type that represents a struct type 151 | // and returns a pointer to a Table that represents the table definition for the database 152 | // or an error if the type is not registered in the schema. 153 | func (s *Schema) Get(typ reflect.Type) (*desc.Table, error) { // NOTE: to make it even faster we could set and then retrieve a Definition variable for each table struct type by interface check. 154 | typ = desc.IndirectType(typ) // get the underlying type of the struct value. 155 | 156 | td, ok := s.structCache[typ] // get the table definition from the cache 157 | if !ok { // if not found 158 | return nil, fmt.Errorf("%s was not registered, forgot Schema.Register?", typ.String()) // return an error 159 | } 160 | 161 | return td, nil // return the table definition and no error 162 | } 163 | 164 | // GetByTableName takes a table name as a string 165 | // and returns a pointer to a Table that represents the table definition for the database 166 | // or an error if the table name is not registered in the schema. 167 | func (s *Schema) GetByTableName(tableName string) (*desc.Table, error) { 168 | for _, td := range s.structCache { // loop over all the table definitions in the cache 169 | if td.Name == tableName { // if the table name matches 170 | return td, nil // return the table definition and no error 171 | } 172 | } 173 | 174 | return nil, fmt.Errorf("table %s was not registered, forgot Schema.Register?", tableName) // return an error if no match found 175 | } 176 | 177 | // Tables returns a slice of pointers to Table that represents all the table definitions in the schema 178 | // sorted by their registered position. 
179 | func (s *Schema) Tables(types ...desc.TableType) []*desc.Table { 180 | // make a slice of pointers to Table with the same capacity as the number of entries in the cache 181 | list := make([]*desc.Table, 0, len(s.structCache)) 182 | 183 | for _, td := range s.structCache { // loop over all the table definitions in the cache 184 | if !td.IsType(types...) { // if not the table type matches the given types (if any) then skip it. 185 | continue 186 | } 187 | 188 | list = append(list, td) // append each table definition to the slice 189 | } 190 | 191 | sort.Slice(list, func(i, j int) bool { // sort the slice by their registered position 192 | return list[i].RegisteredPosition < list[j].RegisteredPosition 193 | }) 194 | 195 | return list // return the sorted slice 196 | } 197 | 198 | // TableNames returns a slice of strings that represents all the table names in the schema. 199 | func (s *Schema) TableNames(types ...desc.TableType) []string { 200 | // make a slice of strings with the same capacity as the number of entries in the cache 201 | list := make([]string, 0, len(s.structCache)) 202 | 203 | for _, td := range s.Tables(types...) { // loop over all the table definitions in the schema (sorted by their registered position) 204 | list = append(list, td.Name) // append each table name to the slice 205 | } 206 | 207 | return list // return the slice of table names 208 | } 209 | 210 | // HasColumnType takes a DataType that represents a data type for the database 211 | // and returns true if any of the tables in the schema has a column with that data type. 
212 | func (s *Schema) HasColumnType(dataTypes ...desc.DataType) bool { 213 | for _, td := range s.Tables() { // every registered table; the iteration order does not affect the result 214 | for _, col := range td.Columns { // loop over all the columns in each table 215 | for _, dt := range dataTypes { 216 | if col.Type == dt { // the column's type is one of the requested types 217 | return true // first match is enough 218 | } 219 | } 220 | } 221 | } 222 | 223 | return false // no table has a column of any of the requested types 224 | } 225 | 226 | // HasPassword reports whether any table in the schema has a column with the password feature enabled. 227 | func (s *Schema) HasPassword() bool { 228 | for _, td := range s.Tables() { 229 | for _, col := range td.Columns { // loop over all the columns in each table 230 | if col.Password { 231 | return true 232 | } 233 | } 234 | } 235 | 236 | return false 237 | } 238 | -------------------------------------------------------------------------------- /schema_example_test.go: -------------------------------------------------------------------------------- 1 | package pg 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | ) 7 | 8 | // Structs. 9 | 10 | // BaseEntity is a struct that defines common fields for all entities in the database. 11 | // It has an ID field of type uuid that is the primary key, and two timestamp fields 12 | // for tracking the creation and update times of each row. 13 | type BaseEntity struct { 14 | ID string `pg:"type=uuid,primary"` 15 | CreatedAt time.Time `pg:"type=timestamp,default=clock_timestamp()"` 16 | UpdatedAt time.Time `pg:"type=timestamp,default=clock_timestamp()"` 17 | } 18 | 19 | // Customer is a struct that represents a customer entity in the database. 20 | // It embeds BaseEntity and adds CognitoUserID and Email, which together form the 21 | // composite unique index customer_unique_idx, plus a btree-indexed Name 22 | // and a Username that defaults to an empty string.
23 | type Customer struct { 24 | BaseEntity 25 | // CognitoUserID string `pg:"type=uuid,unique,conflict=DO UPDATE SET cognito_user_id=EXCLUDED.cognito_user_id"` 26 | 27 | CognitoUserID string `pg:"type=uuid,unique_index=customer_unique_idx"` 28 | Email string `pg:"type=varchar(255),unique_index=customer_unique_idx"` 29 | // ^ optional: unique to allow upsert by "email"-only column confliction instead of the unique_index. 30 | Name string `pg:"type=varchar(255),index=btree"` 31 | 32 | Username string `pg:"type=varchar(255),default=''"` 33 | } 34 | 35 | // Blog is a struct that represents a blog entity in the database. 36 | // It embeds the BaseEntity struct and has no other fields. 37 | type Blog struct { 38 | BaseEntity 39 | 40 | Name string `pg:"type=varchar(255)"` 41 | } 42 | 43 | // BlogPost is a struct that represents a blog post entity in the database. 44 | // It embeds the BaseEntity struct and adds several fields for the blog post details, 45 | // such as BlogID, Title, PhotoURL, SourceURL, ReadTimeMinutes, and Category. 46 | // The BlogID field is a foreign key that references the ID field of the blogs table, 47 | // with cascade option for deletion and deferrable option for constraint checking. 48 | // The Title and SourceURL fields are part of a unique index named uk_blog_post, 49 | // which ensures that no two blog posts have the same title or source URL. 50 | // The ReadTimeMinutes field is a smallint with a default value of 1 and a check constraint 51 | // that ensures it is positive. The Category field is a smallint with a default value of 0. 
52 | type BlogPost struct { 53 | BaseEntity 54 | 55 | BlogID string `pg:"type=uuid,index,ref=blogs(id cascade deferrable)"` 56 | Title string `pg:"type=varchar(255),unique_index=uk_blog_post"` 57 | PhotoURL string `pg:"type=varchar(255)"` 58 | SourceURL string `pg:"type=varchar(255),unique_index=uk_blog_post"` 59 | ReadTimeMinutes int `pg:"type=smallint,default=1,check=read_time_minutes > 0"` 60 | Category int `pg:"type=smallint,default=0"` 61 | 62 | SearchTerms []string `pg:"type=varchar[]"` // Test a slice of strings. 63 | ReadDurations []time.Duration `pg:"type=bigint[]"` // Test a slice of time.Duration based on an int64. 64 | 65 | // Custom types. 66 | Feature Feature `pg:"type=jsonb"` // Test a JSON structure. 67 | OtherFeatures Features `pg:"type=jsonb"` // Test a JSON array of structures behind a custom type. 68 | Tags []Tag `pg:"type=jsonb"` // Test a JSON array of structures. 69 | } 70 | 71 | type Features []Feature 72 | 73 | type Feature struct { 74 | IsFeatured bool `json:"is_featured"` 75 | } 76 | 77 | type Tag struct { 78 | Name string `json:"name"` 79 | Value any `json:"value"` 80 | } 81 | 82 | func ExampleNewSchema() { 83 | // Database code. 84 | schema := NewSchema() 85 | schema.MustRegister("customers", Customer{}) // Register the Customer struct as a table named "customers". 86 | schema.MustRegister("blogs", Blog{}) // Register the Blog struct as a table named "blogs". 87 | schema.MustRegister("blog_posts", BlogPost{}) // Register the BlogPost struct as a table named "blog_posts". 88 | 89 | fmt.Println("OK") 90 | // Output: 91 | // OK 92 | } 93 | --------------------------------------------------------------------------------