├── .github └── workflows │ ├── go_lint.yaml │ └── go_test.yaml ├── .gitignore ├── .golangci.yaml ├── ARCHITECTURE.md ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README.md ├── cmd └── pggen │ └── pggen.go ├── dev ├── sync_intellij_dictionary.sh └── words.dic ├── docker-compose.yml ├── example ├── acceptance_test.go ├── author │ ├── codegen_test.go │ ├── query.sql │ ├── query.sql.go │ ├── query.sql_test.go │ └── schema.sql ├── complex_params │ ├── codegen_test.go │ ├── query.sql │ ├── query.sql.go │ ├── query.sql_test.go │ └── schema.sql ├── composite │ ├── codegen_test.go │ ├── query.sql │ ├── query.sql.go │ ├── query.sql_test.go │ └── schema.sql ├── custom_types │ ├── codegen_test.go │ ├── mytype │ │ └── mytypes.go │ ├── query.sql │ ├── query.sql.go │ ├── query.sql_test.go │ ├── schema.sql │ └── types.go ├── device │ ├── codegen_test.go │ ├── query.sql │ ├── query.sql.go │ ├── query.sql_test.go │ └── schema.sql ├── domain │ ├── codegen_test.go │ ├── query.sql │ ├── query.sql.go │ ├── query.sql_test.go │ └── schema.sql ├── enums │ ├── codegen_test.go │ ├── query.sql │ ├── query.sql.go │ ├── query.sql_test.go │ └── schema.sql ├── erp │ ├── 01_schema.sql │ ├── 02_schema.sql │ └── order │ │ ├── codegen_test.go │ │ ├── customer.sql │ │ ├── customer.sql.go │ │ ├── customer.sql_test.go │ │ ├── price.sql │ │ └── price.sql.go ├── function │ ├── codegen_test.go │ ├── query.sql │ ├── query.sql.go │ ├── query.sql_test.go │ └── schema.sql ├── go_pointer_types │ ├── codegen_test.go │ ├── query.sql │ ├── query.sql.go │ ├── query.sql_test.go │ └── schema.sql ├── inline_param_count │ ├── codegen_test.go │ ├── inline0 │ │ ├── query.sql.go │ │ └── query.sql_test.go │ ├── inline1 │ │ ├── query.sql.go │ │ └── query.sql_test.go │ ├── inline2 │ │ ├── query.sql.go │ │ └── query.sql_test.go │ ├── inline3 │ │ ├── query.sql.go │ │ └── query.sql_test.go │ ├── query.sql │ └── schema.sql ├── ltree │ ├── codegen_test.go │ ├── query.sql │ ├── query.sql.go │ ├── query.sql_test.go │ └── 
schema.sql ├── nested │ ├── codegen_test.go │ ├── query.sql │ ├── query.sql.go │ ├── query.sql_test.go │ └── schema.sql ├── pgcrypto │ ├── codegen_test.go │ ├── query.sql │ ├── query.sql.go │ ├── query.sql_test.go │ └── schema.sql ├── pggen_schema.sql ├── separate_out_dir │ ├── README.md │ ├── alpha │ │ ├── alpha │ │ │ └── query.sql │ │ └── query.sql │ ├── bravo │ │ └── query.sql │ ├── out │ │ ├── alpha_query.sql.0.go │ │ ├── alpha_query.sql.1.go │ │ ├── bravo_query.sql.go │ │ ├── codegen_test.go │ │ └── query.sql_test.go │ └── schema.sql ├── slices │ ├── codegen_test.go │ ├── query.sql │ ├── query.sql.go │ ├── query.sql_test.go │ └── schema.sql ├── syntax │ ├── codegen_test.go │ ├── query.sql │ ├── query.sql.go │ ├── query.sql_test.go │ └── schema.sql └── void │ ├── codegen_test.go │ ├── query.sql │ ├── query.sql.go │ ├── query.sql_test.go │ └── schema.sql ├── generate.go ├── generate_test.go ├── go.mod ├── go.sum ├── internal ├── ast │ └── ast.go ├── casing │ ├── casing.go │ ├── casing_test.go │ ├── sanitize.go │ └── sanitize_test.go ├── codegen │ ├── common.go │ └── golang │ │ ├── declarer.go │ │ ├── declarer_array.go │ │ ├── declarer_composite.go │ │ ├── declarer_enum.go │ │ ├── declarer_test.go │ │ ├── emitter.go │ │ ├── generate.go │ │ ├── gotype │ │ ├── known_types.go │ │ ├── predicates.go │ │ ├── predicates_test.go │ │ ├── types.go │ │ └── types_test.go │ │ ├── import_set.go │ │ ├── query.gotemplate │ │ ├── templated_file.go │ │ ├── templater.go │ │ ├── testdata │ │ ├── declarer_composite.input.golden │ │ ├── declarer_composite.output.golden │ │ ├── declarer_composite_array.input.golden │ │ ├── declarer_composite_array.output.golden │ │ ├── declarer_composite_enum.input.golden │ │ ├── declarer_composite_enum.output.golden │ │ ├── declarer_composite_nested.input.golden │ │ ├── declarer_composite_nested.output.golden │ │ ├── declarer_enum_escaping.input.golden │ │ ├── declarer_enum_escaping.output.golden │ │ ├── declarer_enum_simple.input.golden │ │ └── 
declarer_enum_simple.output.golden │ │ ├── type_resolver.go │ │ └── type_resolver_test.go ├── difftest │ └── difftest.go ├── errs │ └── errs.go ├── flags │ └── flags.go ├── gomod │ ├── gomod.go │ └── gomod_test.go ├── parser │ ├── interface.go │ ├── parser.go │ └── parser_test.go ├── paths │ └── paths.go ├── pg │ ├── column.go │ ├── column_test.go │ ├── known_types.go │ ├── pgoid │ │ └── oids.go │ ├── query.sql │ ├── query.sql.go │ ├── type_cache.go │ ├── type_fetcher.go │ ├── type_fetcher_test.go │ └── types.go ├── pgdocker │ ├── pgdocker.go │ └── template.go ├── pginfer │ ├── explain.go │ ├── nullability.go │ ├── pginfer.go │ └── pginfer_test.go ├── pgplan │ ├── node.go │ ├── pgplan.go │ └── pgplan_test.go ├── pgtest │ └── pg_test_db.go ├── ports │ └── port.go ├── ptrs │ └── ptrs.go ├── scanner │ ├── scanner.go │ └── scanner_test.go ├── texts │ ├── dedent.go │ └── dedent_test.go └── token │ └── token.go └── script ├── .goreleaser.yaml └── release.sh /.github/workflows/go_lint.yaml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | on: push 3 | jobs: 4 | golangci: 5 | runs-on: ubuntu-latest 6 | steps: 7 | - name: Checkout 8 | uses: actions/checkout@v4 9 | 10 | - name: Setup Go 11 | uses: actions/setup-go@v5 12 | with: 13 | go-version-file: go.mod 14 | 15 | - name: Run 16 | uses: golangci/golangci-lint-action@v6 17 | with: 18 | version: v1.64 19 | -------------------------------------------------------------------------------- /.github/workflows/go_test.yaml: -------------------------------------------------------------------------------- 1 | name: Test 2 | on: push 3 | jobs: 4 | test: 5 | runs-on: ubuntu-latest 6 | services: 7 | postgres: 8 | image: postgres 9 | env: 10 | POSTGRES_DB: pggen 11 | POSTGRES_USER: postgres 12 | POSTGRES_PASSWORD: hunter2 13 | ports: 14 | - 5555:5432 15 | options: >- 16 | --health-cmd pg_isready 17 | --health-interval 10s 18 | --health-timeout 5s 19 | --health-retries 5 20 | steps: 21 | 
- name: Checkout 22 | uses: actions/checkout@v4 23 | 24 | - name: Setup Go 25 | uses: actions/setup-go@v5 26 | with: 27 | go-version-file: go.mod 28 | 29 | - run: go test ./... 30 | 31 | - run: go test --tags=acceptance_test ./... 32 | env: 33 | DOCKER_API_VERSION: 1.39 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /*.iml 2 | /.idea/ 3 | /dist/ 4 | .DS_Store 5 | -------------------------------------------------------------------------------- /.golangci.yaml: -------------------------------------------------------------------------------- 1 | linters: 2 | enable: 3 | # Enabled by default. 4 | - errcheck 5 | - gosimple 6 | - govet 7 | - ineffassign 8 | - staticcheck 9 | - unused 10 | # Disabled by default. 11 | - dupword 12 | - durationcheck 13 | - errchkjson 14 | - errname 15 | - errorlint 16 | - exhaustive 17 | - exptostd 18 | - gocheckcompilerdirectives 19 | - gochecknoglobals 20 | - gochecknoinits 21 | - gochecksumtype 22 | - gocritic 23 | - godot 24 | - godox 25 | - gofumpt 26 | - gosec 27 | - intrange 28 | - mirror 29 | - misspell 30 | - nilerr 31 | - nilnesserr 32 | - nilnil 33 | - nolintlint 34 | - predeclared 35 | - reassign 36 | - recvcheck 37 | - sloglint 38 | - unconvert 39 | - unparam 40 | - usestdlibvars 41 | - usetesting 42 | - wastedassign 43 | 44 | linters-settings: 45 | gosec: 46 | excludes: 47 | - G115 48 | 49 | -------------------------------------------------------------------------------- /ARCHITECTURE.md: -------------------------------------------------------------------------------- 1 | # Architecture of pggen 2 | 3 | In a nutshell, pggen runs each query on Postgres to extract type information, 4 | and generates the appropriate code. In detail, pggen processes a query file 5 | in the following steps. 6 | 7 | 1. 
Resolve the query files from the `--query-glob` flag and schema files from 8 | the `--schema-glob` flag in [cmd/pggen/pggen.go]. Pass the normalized 9 | options to `pggen.Generate` in [generate.go]. 10 | 11 | 2. Start Postgres by either connecting to the database specified in 12 | `--postgres-connection` or by starting a new Dockerized Postgres instance. 13 | [internal/pgdocker/pgdocker.go] creates and destroys Docker images for 14 | pggen. 15 | 16 | 3. Parse each query file into an `*ast.File` containing many 17 | `*ast.SourceQuery` nodes in [internal/parser/interface.go]. 18 | 19 | 4. Infer the Postgres types and nullability for the input parameters and output 20 | columns of an `*ast.SourceQuery` and store the results in 21 | `pginfer.TypedQuery` in [internal/pginfer/pginfer.go]. 22 | 23 | To determine the Postgres types, pggen uses itself to compile the queries 24 | in [internal/pg/query.sql]. The queries leverage the Postgres prepare 25 | command to find the input parameter types. 26 | 27 | pggen determines output column types and names by preparing the query and 28 | reading the field descriptions returned with the query result rows. The 29 | field descriptions contain the type ID for each output column. The type ID 30 | is a Postgres object ID (OID), the primary key to identify a row in the 31 | [`pg_type`] catalog table. 32 | 33 | pggen determines if an output column can be null using heuristics. If a 34 | column cannot be null, pggen uses more ergonomic types to represent the 35 | output, like `string` instead of `pgtype.Text`. The heuristics are quite 36 | simple; see [internal/pginfer/nullability.go](internal/pginfer/nullability.go). A proper approach requires a 37 | control-flow analysis to determine nullability. I've started down that road 38 | in [pgplan.go](./internal/pgplan/pgplan.go). 39 | 40 | 5. Transform each `*ast.File` into `codegen.QueryFile` in the `parseQueries` 41 | function of [generate.go]. 42 | 43 | 6.
Use a language-specific code generator to transform `codegen.QueryFile` 44 | into a `golang.TemplatedFile`; see [internal/codegen/golang/templater.go]. 45 | 46 | 7. Emit the generated code from `golang.TemplatedFile` in 47 | [internal/codegen/golang/templated_file.go]. 48 | 49 | [cmd/pggen/pggen.go]: cmd/pggen/pggen.go 50 | [internal/parser/interface.go]: internal/parser/interface.go 51 | [internal/pgdocker/pgdocker.go]: internal/pgdocker/pgdocker.go 52 | [internal/pginfer/pginfer.go]: internal/pginfer/pginfer.go 53 | [internal/pg/query.sql]: internal/pg/query.sql 54 | [generate.go]: ./generate.go 55 | [internal/codegen/golang/templater.go]: internal/codegen/golang/templater.go 56 | [internal/codegen/golang/templated_file.go]: internal/codegen/golang/templated_file.go 57 | [`pg_prepared_statement`]: https://www.postgresql.org/docs/current/view-pg-prepared-statements.html 58 | [`pg_type`]: https://www.postgresql.org/docs/current/catalog-pg-type.html 59 | 60 | For additional detail, see the original (now outdated) [design doc] and the discussions with the 61 | [pgx author] and [sqlc author]. 62 | 63 | [design doc]: https://docs.google.com/document/d/1NvVKD6cyXvJLWUfqFYad76CWMDFoK9mzKuj1JawkL2A/edit# 64 | [pgx author]: https://github.com/jackc/pgx/issues/915 65 | [sqlc author]: https://github.com/sqlc-dev/sqlc/issues/854 66 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to pggen 2 | 3 | First off, thank you for contributing! I welcome PRs. 4 | 5 | tl;dr: 6 | 7 | First, read [ARCHITECTURE.md](ARCHITECTURE.md) to get the lay of the land. 8 | 9 | ```shell 10 | # Dependencies - see Setup below 11 | 12 | # Start a long-lived Postgres server in Docker for integration tests.
13 | # Connect with "make psql" 14 | make start 15 | 16 | # Hack 17 | # Commit changes 18 | 19 | # Validate changes 20 | make lint && make test && make acceptance-test 21 | # make all - equivalent 22 | # make - equivalent 23 | 24 | # Send PR to GitHub. Check that tests and lints passed. 25 | 26 | # Stop Postgres server running in Docker. 27 | make stop 28 | ``` 29 | 30 | ## Design goals of pggen 31 | 32 | - Minimal API surface. There should be only 1 way to run pggen. For example, 33 | pggen only offers a `--query-glob` flag and not also a `--query-file` 34 | flag. The `--query-glob` flag can also be a normal file path. 35 | 36 | - If it's possible in SQL, don't add an option in pggen. If we can use SQL 37 | features to control output, prefer that over adding more controls to pggen. 38 | 39 | - Correctness over convenience. Prefer to expose the nitty-gritty details of 40 | Postgres instead of providing ergonomic APIs. 41 | 42 | - Generated code should look like a human wrote it. The generated code should 43 | be near perfect, including formatting. pggen output doesn't depend on gofmt. 44 | 45 | ## Setup 46 | 47 | You need to install 1 dependency: 48 | 49 | - [golangci-lint] to lint the project locally. CI pins v1.64 (see .github/workflows/go_lint.yaml); install a matching v1.64.x release so local results match CI. 50 | 51 | For macOS: 52 | 53 | ```shell 54 | brew install golangci-lint 55 | ``` 56 | 57 | For Windows and Linux: 58 | 59 | ```shell 60 | # binary will be $(go env GOPATH)/bin/golangci-lint 61 | curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v1.64.8 62 | golangci-lint --version 63 | ``` 64 | 65 | [golangci-lint]: https://golangci-lint.run/ 66 | 67 | ## Testing 68 | 69 | To test pggen, you'll typically start a long-lived Docker container with a 70 | Postgres instance. The pggen tests create a new Postgres schema to isolate 71 | tests from one another. Creating a new schema is much faster than spinning up a 72 | new Dockerized Postgres instance.
73 | 74 | ```shell 75 | make start 76 | make test # all unit tests 77 | ``` 78 | 79 | To run the acceptance tests to validate that pggen produces the same code as 80 | the checked-in example code: 81 | 82 | ```shell 83 | # Acceptance tests check that there's no Git diffs so commit code first. 84 | git commit -m "some message" 85 | 86 | make acceptance-test 87 | ``` 88 | 89 | To update the acceptance tests after changing the code generator: 90 | 91 | ```shell 92 | make update-acceptance-test 93 | ``` 94 | 95 | ### Testing hierarchy 96 | 97 | pggen has tests at most parts of the testing hierarchy. 98 | 99 | - Unit tests to test the logic of small, independent components, like 100 | [casing_test.go]. Run with `make test`. 101 | 102 | - Integration tests like the [pginfer_test.go] to test that the code works 103 | (integrates) with different subsystems like Postgres, Docker, or other Go 104 | packages. As with unit tests, run with `make test`. 105 | 106 | - Acceptance tests like [example/nested/codegen_test.go] to test that pggen 107 | produces the exact same output as the checked-in examples. Run with 108 | `make acceptance-test`. 109 | 110 | [casing_test.go]: internal/casing/casing_test.go 111 | [pginfer_test.go]: internal/pginfer/pginfer_test.go 112 | [example/nested/codegen_test.go]: example/nested/codegen_test.go 113 | 114 | ## Debugging 115 | 116 | For unit-testable things, like type resolution, there should be a test you can 117 | debug. 118 | 119 | For debugging codegen bugs, the best place to start is the `codegen_test.go` 120 | file in each folder in ./example. 121 | 122 | To debug generated query execution, start with the `query.sql_test.go` file in 123 | each example. I've structured the tests (at least the recent ones like 124 | `example/author`) so that every generated query has an isolated subtest you can 125 | debug. 126 | 127 | For tests that use a Postgres instance, you can find the schema used in the test 128 | in the test output. 
You can connect to that schema with the following, replacing `pggen_test_<id>` with the schema name printed in the test output: 129 | 130 | ``` 131 | PGPASSWORD=hunter2 psql --host=127.0.0.1 --port=5555 --username=postgres pggen 132 | 133 | postgres> SET search_path TO 'pggen_test_<id>'; 134 | ``` -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2021 Joe Schafer 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | the Software, and to permit persons to whom the Software is furnished to do so, 8 | subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE.
20 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SHELL := bash 2 | .SHELLFLAGS := -euo pipefail -c 3 | .ONESHELL: # use a single shell for commands instead of a new shell per line 4 | .DELETE_ON_ERROR: # delete output files when make rule fails 5 | MAKEFLAGS += --warn-undefined-variables 6 | MAKEFLAGS += --no-builtin-rules 7 | 8 | version := $(shell date '+%Y-%m-%d') 9 | commit := $(shell git rev-parse --short HEAD) 10 | ldflags := -ldflags "-X 'main.version=${version}' -X 'main.commit=${commit}'" 11 | 12 | .PHONY: all 13 | all: lint test acceptance-test 14 | 15 | .PHONY: start 16 | start: 17 | docker-compose up -d 18 | 19 | .PHONY: stop 20 | stop: 21 | docker-compose down 22 | 23 | .PHONY: restart 24 | restart: stop start 25 | 26 | .PHONY: psql 27 | psql: 28 | PGPASSWORD=hunter2 psql --host=127.0.0.1 --port=5555 --username=postgres pggen 29 | 30 | .PHONY: test 31 | test: 32 | go test ./... 33 | 34 | .PHONY: acceptance-test 35 | acceptance-test: 36 | DOCKER_API_VERSION=1.39 go test ./example/acceptance_test.go 37 | 38 | .PHONY: update-acceptance-test 39 | update-acceptance-test: 40 | DOCKER_API_VERSION=1.39 go test ./example/acceptance_test.go -update 41 | 42 | .PHONY: lint 43 | lint: 44 | golangci-lint run 45 | 46 | .PHONY: dist-dir 47 | dist-dir: 48 | mkdir -p dist 49 | 50 | .PHONY: release 51 | release: 52 | ./script/release.sh 53 | -------------------------------------------------------------------------------- /dev/sync_intellij_dictionary.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euo pipefail 3 | 4 | dictFile=".idea/dictionaries/joe.xml" 5 | wordFile="dev/words.dic" 6 | 7 | # Check if the files exist 8 | if [[ ! -f "$dictFile" ]]; then 9 | echo "IntelliJ XML dictionary file not found at $dictFile" 10 | exit 1 11 | fi 12 | 13 | if [[ !
-f "$wordFile" ]]; then 14 | echo "Word list file not found at $wordFile" 15 | exit 1 16 | fi 17 | 18 | # Extract the words from the <w>word</w> entries of the IntelliJ XML dictionary 19 | words=$(awk -F'[<>]' '/<w>/ {print $3}' "$dictFile") 20 | 21 | for word in $words; do 22 | # Check if the word already exists as a whole line (fixed string, not a regex) in the word list 23 | if ! grep -qxF -- "$word" "$wordFile"; then 24 | echo "Adding word: $word" 25 | echo "$word" >>"$wordFile" 26 | fi 27 | done 28 | 29 | # Sort the words in-place 30 | sort -u -o "$wordFile" "$wordFile" 31 | 32 | echo "Synced IntelliJ dictionary to dev/words.dic" 33 | -------------------------------------------------------------------------------- /dev/words.dic: -------------------------------------------------------------------------------- 1 | .shellflags 2 | abcdefghijklmnopqrstuvwxyz 3 | aclitem 4 | acro 5 | acros 6 | arrs 7 | attname 8 | attnotnull 9 | attnum 10 | attrelid 11 | atttypid 12 | barbaz 13 | bigserial 14 | boolp 15 | bools 16 | bpchar 17 | caser 18 | citext 19 | collatable 20 | costsize 21 | crand 22 | cstring 23 | cust 24 | daterange 25 | descs 26 | enumsortorder 27 | erroring 28 | foos 29 | golangci 30 | gomod 31 | goqu 32 | goreleaser 33 | gotemplate 34 | initdb 35 | intp 36 | ints 37 | isready 38 | jackc 39 | jschaf 40 | lseg 41 | ltree 42 | macaddr 43 | makeflags 44 | msgs 45 | mytype 46 | nitty 47 | nspname 48 | numrange 49 | pgconn 50 | pgdocker 51 | pggen 52 | pggen's 53 | pginfer 54 | pgpassword 55 | pgplan 56 | pgtest 57 | pgtype 58 | pgxpool 59 | plannodes 60 | preferrer 61 | qual 62 | querier 63 | recv 64 | relname 65 | rname 66 | shopspring 67 | smallserial 68 | sourcegraph 69 | sqlc 70 | sslmode 71 | stringp 72 | subplan 73 | subquery 74 | subscripted 75 | targetlist 76 | templater 77 | timestamptz 78 | timestamptzs 79 | toks 80 | tsrange 81 | tstzrange 82 | tuplestore 83 | typbasetype 84 | typdefault 85 | typdefaultbin 86 | typelem 87 | typlen 88 | typname 89 | typnamespace 90 | typndims 91 | typnotnull 92 | typtype 93 | uncapitalized 94 |
upgrader 95 | varbit 96 | varlena 97 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.6" 2 | services: 3 | postgresql: 4 | image: "postgres:15" 5 | ports: 6 | - "5555:5432" 7 | restart: always 8 | command: ["postgres", "-c", "log_statement=all"] 9 | environment: 10 | POSTGRES_DB: pggen 11 | POSTGRES_PASSWORD: hunter2 12 | POSTGRES_USER: postgres 13 | -------------------------------------------------------------------------------- /example/author/codegen_test.go: -------------------------------------------------------------------------------- 1 | package author 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | 8 | "github.com/jschaf/pggen" 9 | "github.com/jschaf/pggen/internal/pgtest" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | // TestGenerate_Go_Example_Author regenerates example/author and checks the output matches the checked-in query.sql.go. 14 | func TestGenerate_Go_Example_Author(t *testing.T) { 15 | conn, cleanupFunc := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 16 | defer cleanupFunc() 17 | 18 | tmpDir := t.TempDir() 19 | err := pggen.Generate( 20 | pggen.GenerateOptions{ 21 | ConnString: conn.Config().ConnString(), 22 | QueryFiles: []string{"query.sql"}, 23 | OutputDir: tmpDir, 24 | GoPackage: "author", 25 | Language: pggen.LangGo, 26 | InlineParamCount: 2, 27 | }) 28 | if err != nil { 29 | t.Fatalf("Generate() example/author: %s", err) 30 | } 31 | 32 | wantQueryFile := "query.sql.go" 33 | gotQueryFile := filepath.Join(tmpDir, wantQueryFile) 34 | assert.FileExists(t, gotQueryFile, 35 | "Generate() should emit query.sql.go") 36 | wantQueries, err := os.ReadFile(wantQueryFile) 37 | if err != nil { 38 | t.Fatalf("read wanted query.sql.go: %s", err) 39 | } 40 | gotQueries, err := os.ReadFile(gotQueryFile) 41 | if err != nil { 42 | t.Fatalf("read generated query.sql.go: %s", err) 43 | } 44 | assert.Equalf(t, string(wantQueries), string(gotQueries), 45 | "Got file %s; does not match contents of
%s", 46 | gotQueryFile, wantQueryFile) 47 | } 48 | -------------------------------------------------------------------------------- /example/author/query.sql: -------------------------------------------------------------------------------- 1 | -- FindAuthorById finds one (or zero) authors by ID. 2 | -- name: FindAuthorByID :one 3 | SELECT * FROM author WHERE author_id = pggen.arg('AuthorID'); 4 | 5 | -- FindAuthors finds authors by first name. 6 | -- name: FindAuthors :many 7 | SELECT * FROM author WHERE first_name = pggen.arg('FirstName'); 8 | 9 | -- FindAuthorNames finds one (or zero) authors by ID. 10 | -- name: FindAuthorNames :many 11 | SELECT first_name, last_name FROM author ORDER BY author_id = pggen.arg('AuthorID'); 12 | 13 | -- FindFirstNames finds one (or zero) authors by ID. 14 | -- name: FindFirstNames :many 15 | SELECT first_name FROM author ORDER BY author_id = pggen.arg('AuthorID'); 16 | 17 | -- DeleteAuthors deletes authors with a first name of "joe". 18 | -- name: DeleteAuthors :exec 19 | DELETE FROM author WHERE first_name = 'joe'; 20 | 21 | -- DeleteAuthorsByFirstName deletes authors by first name. 22 | -- name: DeleteAuthorsByFirstName :exec 23 | DELETE FROM author WHERE first_name = pggen.arg('FirstName'); 24 | 25 | -- DeleteAuthorsByFullName deletes authors by the full name. 26 | -- name: DeleteAuthorsByFullName :exec 27 | DELETE 28 | FROM author 29 | WHERE first_name = pggen.arg('FirstName') 30 | AND last_name = pggen.arg('LastName') 31 | AND suffix = pggen.arg('Suffix'); 32 | 33 | -- InsertAuthor inserts an author by name and returns the ID. 34 | -- name: InsertAuthor :one 35 | INSERT INTO author (first_name, last_name) 36 | VALUES (pggen.arg('FirstName'), pggen.arg('LastName')) 37 | RETURNING author_id; 38 | 39 | -- InsertAuthorSuffix inserts an author by name and suffix and returns the 40 | -- entire row.
41 | -- name: InsertAuthorSuffix :one 42 | INSERT INTO author (first_name, last_name, suffix) 43 | VALUES (pggen.arg('FirstName'), pggen.arg('LastName'), pggen.arg('Suffix')) 44 | RETURNING author_id, first_name, last_name, suffix; 45 | 46 | -- name: StringAggFirstName :one 47 | SELECT string_agg(first_name, ',') AS names FROM author WHERE author_id = pggen.arg('author_id'); 48 | 49 | -- name: ArrayAggFirstName :one 50 | SELECT array_agg(first_name) AS names FROM author WHERE author_id = pggen.arg('author_id'); 51 | -------------------------------------------------------------------------------- /example/author/schema.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE author ( 2 | author_id serial PRIMARY KEY, 3 | first_name text NOT NULL, 4 | last_name text NOT NULL, 5 | suffix text NULL 6 | ); 7 | -------------------------------------------------------------------------------- /example/complex_params/codegen_test.go: -------------------------------------------------------------------------------- 1 | package complex_params 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | 8 | "github.com/jschaf/pggen" 9 | "github.com/jschaf/pggen/internal/pgtest" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | // TestGenerate_Go_Example_ComplexParams regenerates example/complex_params and checks the output matches the checked-in query.sql.go. 14 | func TestGenerate_Go_Example_ComplexParams(t *testing.T) { 15 | conn, cleanupFunc := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 16 | defer cleanupFunc() 17 | 18 | tmpDir := t.TempDir() 19 | err := pggen.Generate( 20 | pggen.GenerateOptions{ 21 | ConnString: conn.Config().ConnString(), 22 | QueryFiles: []string{"query.sql"}, 23 | OutputDir: tmpDir, 24 | GoPackage: "complex_params", 25 | Language: pggen.LangGo, 26 | InlineParamCount: 2, 27 | TypeOverrides: map[string]string{ 28 | "int4": "int", 29 | "text": "string", 30 | }, 31 | }) 32 | if err != nil { 33 | t.Fatalf("Generate() example/complex_params: %s", err) 34 | } 35 | 36 | wantQueryFile := "query.sql.go" 37 | gotQueryFile :=
filepath.Join(tmpDir, wantQueryFile) 38 | assert.FileExists(t, gotQueryFile, "Generate() should emit query.sql.go") 39 | wantQueries, err := os.ReadFile(wantQueryFile) 40 | if err != nil { 41 | t.Fatalf("read wanted query.sql.go: %s", err) 42 | } 43 | gotQueries, err := os.ReadFile(gotQueryFile) 44 | if err != nil { 45 | t.Fatalf("read generated query.sql.go: %s", err) 46 | } 47 | assert.Equalf(t, string(wantQueries), string(gotQueries), 48 | "Got file %s; does not match contents of %s", 49 | gotQueryFile, wantQueryFile) 50 | } 51 | -------------------------------------------------------------------------------- /example/complex_params/query.sql: -------------------------------------------------------------------------------- 1 | -- name: ParamArrayInt :one 2 | SELECT pggen.arg('ints')::bigint[]; 3 | 4 | -- name: ParamNested1 :one 5 | SELECT pggen.arg('dimensions')::dimensions; 6 | 7 | -- name: ParamNested2 :one 8 | SELECT pggen.arg('image')::product_image_type; 9 | 10 | -- name: ParamNested2Array :one 11 | SELECT pggen.arg('images')::product_image_type[]; 12 | 13 | -- name: ParamNested3 :one 14 | SELECT pggen.arg('image_set')::product_image_set_type; 15 | 16 | -------------------------------------------------------------------------------- /example/complex_params/query.sql_test.go: -------------------------------------------------------------------------------- 1 | package complex_params 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/jschaf/pggen/internal/pgtest" 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func TestNewQuerier_ParamArrayInt(t *testing.T) { 12 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 13 | defer cleanup() 14 | 15 | q := NewQuerier(conn) 16 | ctx := t.Context() 17 | 18 | want := []int{1, 2, 3, 4} 19 | 20 | t.Run("ParamArrayInt", func(t *testing.T) { 21 | row, err := q.ParamArrayInt(ctx, want) 22 | require.NoError(t, err) 23 | assert.Equal(t, want, row) 24 | }) 25
| } 26 | 27 | func TestNewQuerier_ParamNested1(t *testing.T) { 28 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 29 | defer cleanup() 30 | 31 | q := NewQuerier(conn) 32 | ctx := t.Context() 33 | 34 | want := Dimensions{Width: 77, Height: 77} 35 | 36 | t.Run("ParamNested1", func(t *testing.T) { 37 | row, err := q.ParamNested1(ctx, want) 38 | require.NoError(t, err) 39 | assert.Equal(t, want, row) 40 | }) 41 | } 42 | 43 | func TestNewQuerier_ParamNested2(t *testing.T) { 44 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 45 | defer cleanup() 46 | 47 | q := NewQuerier(conn) 48 | ctx := t.Context() 49 | 50 | want := ProductImageType{ 51 | Source: "src", 52 | Dimensions: Dimensions{Width: 77, Height: 77}, 53 | } 54 | 55 | t.Run("ParamNested2", func(t *testing.T) { 56 | row, err := q.ParamNested2(ctx, want) 57 | require.NoError(t, err) 58 | assert.Equal(t, want, row) 59 | }) 60 | } 61 | 62 | func TestNewQuerier_ParamNested2Array(t *testing.T) { 63 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 64 | defer cleanup() 65 | 66 | q := NewQuerier(conn) 67 | ctx := t.Context() 68 | 69 | want := []ProductImageType{ 70 | {Source: "src1", Dimensions: Dimensions{Width: 11, Height: 11}}, 71 | {Source: "src2", Dimensions: Dimensions{Width: 22, Height: 22}}, 72 | } 73 | 74 | t.Run("ParamNested2Array", func(t *testing.T) { 75 | row, err := q.ParamNested2Array(ctx, want) 76 | require.NoError(t, err) 77 | assert.Equal(t, want, row) 78 | }) 79 | } 80 | 81 | func TestNewQuerier_ParamNested3(t *testing.T) { 82 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 83 | defer cleanup() 84 | 85 | q := NewQuerier(conn) 86 | ctx := t.Context() 87 | 88 | want := ProductImageSetType{ 89 | Name: "set1", 90 | OrigImage: ProductImageType{Source: "src1", Dimensions: Dimensions{Width: 11, Height: 11}}, 91 | Images: []ProductImageType{ 92 | {Source: "src1", Dimensions: Dimensions{Width: 11, Height: 11}}, 93 | {Source:
"src2", Dimensions: Dimensions{Width: 22, Height: 22}}, 94 | }, 95 | } 96 | 97 | t.Run("ParamNested3", func(t *testing.T) { 98 | row, err := q.ParamNested3(ctx, want) 99 | require.NoError(t, err) 100 | assert.Equal(t, want, row) 101 | }) 102 | } 103 | 104 | func TestNewQuerier_ParamNested3_QueryAllDataTypes(t *testing.T) { 105 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 106 | defer cleanup() 107 | ctx := t.Context() 108 | // dataTypes, err := QueryAllDataTypes(ctx, conn) 109 | // require.NoError(t, err) 110 | q := NewQuerier(conn) 111 | 112 | want := ProductImageSetType{ 113 | Name: "set1", 114 | OrigImage: ProductImageType{Source: "src1", Dimensions: Dimensions{Width: 11, Height: 11}}, 115 | Images: []ProductImageType{ 116 | {Source: "src1", Dimensions: Dimensions{Width: 11, Height: 11}}, 117 | {Source: "src2", Dimensions: Dimensions{Width: 22, Height: 22}}, 118 | }, 119 | } 120 | 121 | t.Run("ParamNested3", func(t *testing.T) { 122 | row, err := q.ParamNested3(ctx, want) 123 | require.NoError(t, err) 124 | assert.Equal(t, want, row) 125 | }) 126 | } 127 | -------------------------------------------------------------------------------- /example/complex_params/schema.sql: -------------------------------------------------------------------------------- 1 | CREATE TYPE dimensions AS ( 2 | width int4, 3 | height int4 4 | ); 5 | 6 | CREATE TYPE product_image_type AS ( 7 | source text, 8 | dimensions dimensions 9 | ); 10 | 11 | CREATE TYPE product_image_set_type AS ( 12 | name text, 13 | orig_image product_image_type, 14 | images product_image_type[] 15 | ); 16 | -------------------------------------------------------------------------------- /example/composite/codegen_test.go: -------------------------------------------------------------------------------- 1 | package composite 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | 8 | "github.com/jschaf/pggen" 9 | "github.com/jschaf/pggen/internal/pgtest" 10 |
"github.com/stretchr/testify/assert" 11 | ) 12 | 13 | // TestGenerate_Go_Example_Composite regenerates example/composite and checks the output matches the checked-in query.sql.go. 14 | func TestGenerate_Go_Example_Composite(t *testing.T) { 15 | conn, cleanupFunc := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 16 | defer cleanupFunc() 17 | 18 | tmpDir := t.TempDir() 19 | err := pggen.Generate( 20 | pggen.GenerateOptions{ 21 | ConnString: conn.Config().ConnString(), 22 | QueryFiles: []string{"query.sql"}, 23 | OutputDir: tmpDir, 24 | GoPackage: "composite", 25 | Language: pggen.LangGo, 26 | InlineParamCount: 2, 27 | TypeOverrides: map[string]string{ 28 | "_bool": "[]bool", 29 | "bool": "bool", 30 | "int8": "int", 31 | "int4": "int", 32 | "text": "string", 33 | "citext": "github.com/jackc/pgtype.Text", 34 | }, 35 | }) 36 | if err != nil { 37 | t.Fatalf("Generate(): %s", err) 38 | } 39 | 40 | wantQueryFile := "query.sql.go" 41 | gotQueryFile := filepath.Join(tmpDir, wantQueryFile) 42 | assert.FileExists(t, gotQueryFile, 43 | "Generate() should emit query.sql.go") 44 | wantQueries, err := os.ReadFile(wantQueryFile) 45 | if err != nil { 46 | t.Fatalf("read wanted query.sql.go: %s", err) 47 | } 48 | gotQueries, err := os.ReadFile(gotQueryFile) 49 | if err != nil { 50 | t.Fatalf("read generated query.sql.go: %s", err) 51 | } 52 | assert.Equalf(t, string(wantQueries), string(gotQueries), 53 | "Got file %s; does not match contents of %s", 54 | gotQueryFile, wantQueryFile) 55 | } 56 | -------------------------------------------------------------------------------- /example/composite/query.sql: -------------------------------------------------------------------------------- 1 | -- name: SearchScreenshots :many 2 | SELECT 3 | ss.id, 4 | array_agg(bl) AS blocks 5 | FROM screenshots ss 6 | JOIN blocks bl ON bl.screenshot_id = ss.id 7 | WHERE bl.body LIKE pggen.arg('Body') || '%' 8 | GROUP BY ss.id 9 | ORDER BY ss.id 10 | LIMIT pggen.arg('Limit') OFFSET pggen.arg('Offset'); 11 | 12 | -- name: SearchScreenshotsOneCol :many 13 | SELECT 14 | array_agg(bl) AS blocks 15 | FROM screenshots ss 16 | JOIN
blocks bl ON bl.screenshot_id = ss.id 17 | WHERE bl.body LIKE pggen.arg('Body') || '%' 18 | GROUP BY ss.id 19 | ORDER BY ss.id 20 | LIMIT pggen.arg('Limit') OFFSET pggen.arg('Offset'); 21 | 22 | -- name: InsertScreenshotBlocks :one 23 | WITH screens AS ( 24 | INSERT INTO screenshots (id) VALUES (pggen.arg('ScreenshotID')) 25 | ON CONFLICT DO NOTHING 26 | ) 27 | INSERT 28 | INTO blocks (screenshot_id, body) 29 | VALUES (pggen.arg('ScreenshotID'), pggen.arg('Body')) 30 | RETURNING id, screenshot_id, body; 31 | 32 | -- name: ArraysInput :one 33 | SELECT pggen.arg('arrays')::arrays; 34 | 35 | -- name: UserEmails :one 36 | SELECT ('foo', 'bar@example.com')::user_email; 37 | -------------------------------------------------------------------------------- /example/composite/query.sql_test.go: -------------------------------------------------------------------------------- 1 | package composite 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/jackc/pgtype" 7 | "github.com/jschaf/pggen/internal/difftest" 8 | "github.com/jschaf/pggen/internal/pgtest" 9 | "github.com/jschaf/pggen/internal/ptrs" 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestNewQuerier_SearchScreenshots(t *testing.T) { 15 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 16 | defer cleanup() 17 | 18 | q := NewQuerier(conn) 19 | screenshotID := 99 20 | screenshot1 := insertScreenshotBlock(t, q, screenshotID, "body1") 21 | screenshot2 := insertScreenshotBlock(t, q, screenshotID, "body2") 22 | want := []SearchScreenshotsRow{ 23 | { 24 | ID: screenshotID, 25 | Blocks: []Blocks{ 26 | { 27 | ID: screenshot1.ID, 28 | ScreenshotID: screenshotID, 29 | Body: screenshot1.Body, 30 | }, 31 | { 32 | ID: screenshot2.ID, 33 | ScreenshotID: screenshotID, 34 | Body: screenshot2.Body, 35 | }, 36 | }, 37 | }, 38 | } 39 | 40 | t.Run("SearchScreenshots", func(t *testing.T) { 41 | rows, err := q.SearchScreenshots(t.Context(),
SearchScreenshotsParams{ 42 | Body: "body", 43 | Limit: 5, 44 | Offset: 0, 45 | }) 46 | require.NoError(t, err) 47 | assert.Equal(t, want, rows) 48 | }) 49 | 50 | t.Run("SearchScreenshotsOneCol", func(t *testing.T) { 51 | rows, err := q.SearchScreenshotsOneCol(t.Context(), SearchScreenshotsOneColParams{ 52 | Body: "body", 53 | Limit: 5, 54 | Offset: 0, 55 | }) 56 | require.NoError(t, err) 57 | assert.Equal(t, [][]Blocks{want[0].Blocks}, rows) 58 | }) 59 | } 60 | 61 | func TestNewQuerier_ArraysInput(t *testing.T) { 62 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 63 | defer cleanup() 64 | 65 | q := NewQuerier(conn) 66 | 67 | t.Run("ArraysInput", func(t *testing.T) { 68 | want := Arrays{ 69 | Texts: []string{"foo", "bar"}, 70 | Int8s: []*int{ptrs.Int(1), ptrs.Int(2), ptrs.Int(3)}, 71 | Bools: []bool{true, true, false}, 72 | Floats: []*float64{ptrs.Float64(33.3), ptrs.Float64(66.6)}, 73 | } 74 | got, err := q.ArraysInput(t.Context(), want) 75 | require.NoError(t, err) 76 | difftest.AssertSame(t, want, got) 77 | }) 78 | } 79 | 80 | func TestNewQuerier_UserEmails(t *testing.T) { 81 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 82 | defer cleanup() 83 | 84 | q := NewQuerier(conn) 85 | 86 | got, err := q.UserEmails(t.Context()) 87 | require.NoError(t, err) 88 | want := UserEmail{ 89 | ID: "foo", 90 | Email: pgtype.Text{String: "bar@example.com", Status: pgtype.Present}, 91 | } 92 | difftest.AssertSame(t, want, got) 93 | } 94 | 95 | func insertScreenshotBlock(t *testing.T, q *DBQuerier, screenID int, body string) InsertScreenshotBlocksRow { 96 | t.Helper() 97 | row, err := q.InsertScreenshotBlocks(t.Context(), screenID, body) 98 | require.NoError(t, err, "insert screenshot blocks") 99 | return row 100 | } 101 | -------------------------------------------------------------------------------- /example/composite/schema.sql: -------------------------------------------------------------------------------- 1 | CREATE EXTENSION IF
NOT EXISTS citext; 2 | 3 | CREATE TABLE screenshots ( 4 | id bigint PRIMARY KEY 5 | ); 6 | 7 | CREATE TABLE blocks ( 8 | id serial PRIMARY KEY, 9 | screenshot_id bigint NOT NULL REFERENCES screenshots (id), 10 | body text NOT NULL 11 | ); 12 | 13 | CREATE TYPE arrays AS ( 14 | texts text[], 15 | int8s int8[], 16 | bools boolean[], 17 | floats float8[] 18 | ); 19 | 20 | 21 | CREATE TYPE user_email AS ( 22 | id text, 23 | email citext 24 | ); 25 | -------------------------------------------------------------------------------- /example/custom_types/codegen_test.go: -------------------------------------------------------------------------------- 1 | package custom_types 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | 8 | "github.com/jschaf/pggen" 9 | "github.com/jschaf/pggen/internal/pgtest" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | // TestGenerate_Go_Example_CustomTypes regenerates Go code for query.sql (with the custom type overrides below) into a temp dir and fails if it differs from the checked-in query.sql.go. 13 | func TestGenerate_Go_Example_CustomTypes(t *testing.T) { 14 | conn, cleanupFunc := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 15 | defer cleanupFunc() 16 | 17 | tmpDir := t.TempDir() 18 | err := pggen.Generate( 19 | pggen.GenerateOptions{ 20 | ConnString: conn.Config().ConnString(), 21 | QueryFiles: []string{"query.sql"}, 22 | OutputDir: tmpDir, 23 | GoPackage: "custom_types", 24 | Language: pggen.LangGo, 25 | InlineParamCount: 2, 26 | TypeOverrides: map[string]string{ 27 | "text": "github.com/jschaf/pggen/example/custom_types/mytype.String", 28 | "int8": "github.com/jschaf/pggen/example/custom_types.CustomInt", 29 | "my_int": "int", 30 | "_my_int": "[]int", 31 | }, 32 | }) 33 | if err != nil { 34 | t.Fatalf("Generate() example/custom_types: %s", err) 35 | } 36 | 37 | wantQueryFile := "query.sql.go" 38 | gotQueryFile := filepath.Join(tmpDir, "query.sql.go") 39 | assert.FileExists(t, gotQueryFile, "Generate() should emit 
query.sql.go") 40 | wantQueries, err := os.ReadFile(wantQueryFile) 41 | if err != nil { 42 | t.Fatalf("read wanted query.sql.go: %s", err) 43 | } 44 | gotQueries, err :=
os.ReadFile(gotQueryFile) 45 | if err != nil { 46 | t.Fatalf("read generated query.sql.go: %s", err) 47 | } 48 | assert.Equalf(t, string(wantQueries), string(gotQueries), 49 | "Got file %s; does not match contents of %s", 50 | gotQueryFile, wantQueryFile) 51 | } 52 | -------------------------------------------------------------------------------- /example/custom_types/mytype/mytypes.go: -------------------------------------------------------------------------------- 1 | package mytype 2 | 3 | // String is a simple custom type we can use for the Postgres text type. 4 | type String string 5 | -------------------------------------------------------------------------------- /example/custom_types/query.sql: -------------------------------------------------------------------------------- 1 | -- name: CustomTypes :one 2 | SELECT 'some_text', 1::bigint; 3 | 4 | -- name: CustomMyInt :one 5 | SELECT '5'::my_int as int5; 6 | 7 | -- name: IntArray :many 8 | SELECT ARRAY ['5', '6', '7']::int[] as ints; 9 | -------------------------------------------------------------------------------- /example/custom_types/query.sql_test.go: -------------------------------------------------------------------------------- 1 | package custom_types 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/jackc/pgtype" 7 | "github.com/jschaf/pggen/internal/pgtest" 8 | "github.com/jschaf/pggen/internal/texts" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func TestQuerier_CustomTypes(t *testing.T) { 14 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 15 | defer cleanup() 16 | q := NewQuerier(conn) 17 | ctx := t.Context() 18 | 19 | t.Run("CustomTypes", func(t *testing.T) { 20 | val, err := q.CustomTypes(ctx) 21 | require.NoError(t, err) 22 | want := CustomTypesRow{ 23 | Column: "some_text", 24 | Int8: 1, 25 | } 26 | assert.Equal(t, want, val) 27 | }) 28 | } 29 | 30 | func TestQuerier_CustomMyInt(t *testing.T) { 31 | conn, cleanup
:= pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 32 | defer cleanup() 33 | row := conn.QueryRow(t.Context(), texts.Dedent(` 34 | SELECT pt.oid 35 | FROM pg_type pt 36 | JOIN pg_namespace pn ON pt.typnamespace = pn.oid 37 | WHERE typname = 'my_int' 38 | AND pn.nspname = current_schema() 39 | LIMIT 1; 40 | `)) 41 | oidVal := pgtype.OIDValue{} 42 | err := row.Scan(&oidVal) 43 | require.NoError(t, err) 44 | t.Logf("my_int oid: %d", oidVal.Uint) 45 | 46 | conn.ConnInfo().RegisterDataType(pgtype.DataType{ 47 | Value: &pgtype.Int2{}, 48 | Name: "my_int", 49 | OID: oidVal.Uint, 50 | }) 51 | 52 | q := NewQuerier(conn) 53 | ctx := t.Context() 54 | 55 | t.Run("CustomMyInt", func(t *testing.T) { 56 | val, err := q.CustomMyInt(ctx) 57 | require.NoError(t, err) 58 | assert.Equal(t, 5, val) 59 | }) 60 | } 61 | 62 | func TestQuerier_IntArray(t *testing.T) { 63 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 64 | defer cleanup() 65 | q := NewQuerier(conn) 66 | ctx := t.Context() 67 | 68 | t.Run("IntArray", func(t *testing.T) { 69 | array, err := q.IntArray(ctx) 70 | require.NoError(t, err) 71 | assert.Equal(t, [][]int32{{5, 6, 7}}, array) 72 | }) 73 | } 74 | -------------------------------------------------------------------------------- /example/custom_types/schema.sql: -------------------------------------------------------------------------------- 1 | -- New base type my_int.
2 | -- https://stackoverflow.com/a/45190420/30900 3 | CREATE TYPE my_int; 4 | 5 | CREATE FUNCTION my_int_in(cstring) RETURNS my_int 6 | LANGUAGE internal 7 | IMMUTABLE STRICT PARALLEL SAFE AS 8 | 'int2in'; 9 | 10 | CREATE FUNCTION my_int_out(my_int) RETURNS cstring 11 | LANGUAGE internal 12 | IMMUTABLE STRICT PARALLEL SAFE AS 13 | 'int2out'; 14 | 15 | CREATE FUNCTION my_int_recv(internal) RETURNS my_int 16 | LANGUAGE internal 17 | IMMUTABLE STRICT PARALLEL SAFE AS 18 | 'int2recv'; 19 | 20 | CREATE FUNCTION my_int_send(my_int) RETURNS bytea 21 | LANGUAGE internal 22 | IMMUTABLE STRICT PARALLEL SAFE AS 23 | 'int2send'; 24 | 25 | CREATE TYPE my_int ( 26 | INPUT = my_int_in, 27 | OUTPUT = my_int_out, 28 | RECEIVE = my_int_recv, 29 | SEND = my_int_send, 30 | LIKE = smallint, 31 | CATEGORY = 'N', 32 | PREFERRED = FALSE, 33 | DELIMITER = ',', 34 | COLLATABLE = FALSE 35 | ); 36 | -------------------------------------------------------------------------------- /example/custom_types/types.go: -------------------------------------------------------------------------------- 1 | package custom_types 2 | 3 | // CustomInt is a custom type in the same package as the query file.
4 | type CustomInt int 5 | -------------------------------------------------------------------------------- /example/device/codegen_test.go: -------------------------------------------------------------------------------- 1 | package device 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | 8 | "github.com/jschaf/pggen" 9 | "github.com/jschaf/pggen/internal/pgtest" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | // TestGenerate_Go_Example_Device regenerates Go code for query.sql into a temp dir and fails if it differs from the checked-in query.sql.go. 13 | func TestGenerate_Go_Example_Device(t *testing.T) { 14 | conn, cleanupFunc := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 15 | defer cleanupFunc() 16 | 17 | tmpDir := t.TempDir() 18 | err := pggen.Generate( 19 | pggen.GenerateOptions{ 20 | ConnString: conn.Config().ConnString(), 21 | QueryFiles: []string{"query.sql"}, 22 | OutputDir: tmpDir, 23 | GoPackage: "device", 24 | Language: pggen.LangGo, 25 | InlineParamCount: 2, 26 | }) 27 | if err != nil { 28 | t.Fatalf("Generate() example/device: %s", err) 29 | } 30 | 31 | wantQueryFile := "query.sql.go" 32 | gotQueryFile := filepath.Join(tmpDir, "query.sql.go") 33 | assert.FileExists(t, gotQueryFile, "Generate() should emit query.sql.go") 34 | wantQueries, err := os.ReadFile(wantQueryFile) 35 | if err != nil { 36 | t.Fatalf("read wanted query.sql.go: %s", err) 37 | } 38 | gotQueries, err := os.ReadFile(gotQueryFile) 39 | if err != nil { 40 | t.Fatalf("read generated query.sql.go: %s", err) 41 | } 42 | assert.Equalf(t, string(wantQueries), string(gotQueries), 43 | "Got file %s; does not match contents of %s", 44 | gotQueryFile, wantQueryFile) 45 | } 46 | -------------------------------------------------------------------------------- /example/device/query.sql: -------------------------------------------------------------------------------- 1 | -- name: FindDevicesByUser :many 2 | SELECT 3 | id, 4 | name, 5 | (SELECT array_agg(mac) FROM device WHERE owner = id) AS 
mac_addrs 6 | FROM "user" 7 | WHERE id = pggen.arg('ID'); 8 | 9 | -- name: CompositeUser :many 10 | SELECT 11 | d.mac,
12 | d.type, 13 | ROW (u.id, u.name)::"user" AS "user" 14 | FROM device d 15 | LEFT JOIN "user" u ON u.id = d.owner; 16 | 17 | -- name: CompositeUserOne :one 18 | SELECT ROW (15, 'qux')::"user" AS "user"; 19 | 20 | -- name: CompositeUserOneTwoCols :one 21 | SELECT 1 AS num, ROW (15, 'qux')::"user" AS "user"; 22 | 23 | -- name: CompositeUserMany :many 24 | SELECT ROW (15, 'qux')::"user" AS "user"; 25 | 26 | -- name: InsertUser :exec 27 | INSERT INTO "user" (id, name) 28 | VALUES (pggen.arg('user_id'), pggen.arg('name')); 29 | 30 | -- name: InsertDevice :exec 31 | INSERT INTO device (mac, owner) 32 | VALUES (pggen.arg('mac'), pggen.arg('owner')); 33 | -------------------------------------------------------------------------------- /example/device/query.sql_test.go: -------------------------------------------------------------------------------- 1 | package device 2 | 3 | import ( 4 | "net" 5 | "testing" 6 | 7 | "github.com/jackc/pgtype" 8 | "github.com/jschaf/pggen/internal/pgtest" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func TestQuerier_FindDevicesByUser(t *testing.T) { 14 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 15 | defer cleanup() 16 | q := NewQuerier(conn) 17 | ctx := t.Context() 18 | userID := 18 19 | _, err := q.InsertUser(ctx, userID, "foo") 20 | require.NoError(t, err) 21 | mac1, _ := net.ParseMAC("11:22:33:44:55:66") 22 | _, err = q.InsertDevice(ctx, pgtype.Macaddr{Status: pgtype.Present, Addr: mac1}, userID) 23 | require.NoError(t, err) 24 | 25 | t.Run("FindDevicesByUser", func(t *testing.T) { 26 | val, err := q.FindDevicesByUser(ctx, userID) 27 | require.NoError(t, err) 28 | want := []FindDevicesByUserRow{ 29 | { 30 | ID: userID, 31 | Name: "foo", 32 | MacAddrs: pgtype.MacaddrArray{ 33 | Elements: []pgtype.Macaddr{{Addr: mac1, Status: pgtype.Present}}, 34 | Dimensions: []pgtype.ArrayDimension{{ 35 | Length: 1, 36 | LowerBound: 1, 37 | }}, 38 | Status:
pgtype.Present, 39 | }, 40 | }, 41 | } 42 | assert.Equal(t, want, val) 43 | }) 44 | } 45 | 46 | func TestQuerier_CompositeUser(t *testing.T) { 47 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 48 | defer cleanup() 49 | q := NewQuerier(conn) 50 | ctx := t.Context() 51 | 52 | userID := 18 53 | name := "foo" 54 | _, err := q.InsertUser(ctx, userID, name) 55 | require.NoError(t, err) 56 | 57 | mac1, _ := net.ParseMAC("11:22:33:44:55:66") 58 | mac2, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff") 59 | _, err = q.InsertDevice(ctx, pgtype.Macaddr{Status: pgtype.Present, Addr: mac1}, userID) 60 | require.NoError(t, err) 61 | _, err = q.InsertDevice(ctx, pgtype.Macaddr{Status: pgtype.Present, Addr: mac2}, userID) 62 | require.NoError(t, err) 63 | 64 | t.Run("CompositeUser", func(t *testing.T) { 65 | users, err := q.CompositeUser(ctx) 66 | require.NoError(t, err) 67 | want := []CompositeUserRow{ 68 | { 69 | Mac: pgtype.Macaddr{Addr: mac1, Status: pgtype.Present}, 70 | Type: DeviceTypeUndefined, 71 | User: User{ID: &userID, Name: &name}, 72 | }, 73 | { 74 | Mac: pgtype.Macaddr{Addr: mac2, Status: pgtype.Present}, 75 | Type: DeviceTypeUndefined, 76 | User: User{ID: &userID, Name: &name}, 77 | }, 78 | } 79 | assert.Equal(t, want, users) 80 | }) 81 | } 82 | 83 | func TestQuerier_CompositeUserOne(t *testing.T) { 84 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 85 | defer cleanup() 86 | q := NewQuerier(conn) 87 | ctx := t.Context() 88 | id := 15 89 | name := "qux" 90 | wantUser := User{ID: &id, Name: &name} 91 | 92 | t.Run("CompositeUserOne", func(t *testing.T) { 93 | got, err := q.CompositeUserOne(ctx) 94 | require.NoError(t, err) 95 | assert.Equal(t, wantUser, got) 96 | }) 97 | } 98 | -------------------------------------------------------------------------------- /example/device/schema.sql: -------------------------------------------------------------------------------- 1 | CREATE TYPE device_type AS ENUM ( 2 | 'undefined', 3 | 'phone', 4
| 'laptop', 5 | 'ipad', 6 | 'desktop', 7 | 'iot' 8 | ); 9 | 10 | CREATE TABLE "user" ( 11 | id bigint PRIMARY KEY, 12 | name text NOT NULL 13 | ); 14 | 15 | CREATE TABLE device ( 16 | mac MACADDR PRIMARY KEY, 17 | owner bigint REFERENCES "user" (id), 18 | type device_type NOT NULL DEFAULT 'undefined' 19 | ); 20 | -------------------------------------------------------------------------------- /example/domain/codegen_test.go: -------------------------------------------------------------------------------- 1 | package domain 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | 8 | "github.com/jschaf/pggen" 9 | "github.com/jschaf/pggen/internal/pgtest" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | // TestGenerate_Go_Example_Domain regenerates Go code for query.sql into a temp dir and fails if it differs from the checked-in query.sql.go. 13 | func TestGenerate_Go_Example_Domain(t *testing.T) { 14 | conn, cleanupFunc := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 15 | defer cleanupFunc() 16 | 17 | tmpDir := t.TempDir() 18 | err := pggen.Generate( 19 | pggen.GenerateOptions{ 20 | ConnString: conn.Config().ConnString(), 21 | QueryFiles: []string{"query.sql"}, 22 | OutputDir: tmpDir, 23 | GoPackage: "domain", 24 | Language: pggen.LangGo, 25 | InlineParamCount: 2, 26 | }) 27 | if err != nil { 28 | t.Fatalf("Generate() example/domain: %s", err) 29 | } 30 | 31 | wantQueryFile := "query.sql.go" 32 | gotQueryFile := filepath.Join(tmpDir, "query.sql.go") 33 | assert.FileExists(t, gotQueryFile, "Generate() should emit query.sql.go") 34 | wantQueries, err := os.ReadFile(wantQueryFile) 35 | if err != nil { 36 | t.Fatalf("read wanted query.sql.go: %s", err) 37 | } 38 | gotQueries, err := os.ReadFile(gotQueryFile) 39 | if err != nil { 40 | t.Fatalf("read generated query.sql.go: %s", err) 41 | } 42 | assert.Equalf(t, string(wantQueries), string(gotQueries), 43 | "Got file %s; does not match contents of %s", 44 | gotQueryFile, wantQueryFile) 45 | } 46 | 
-------------------------------------------------------------------------------- /example/domain/query.sql:
-------------------------------------------------------------------------------- 1 | -- name: DomainOne :one 2 | SELECT '90210'::us_postal_code; 3 | -------------------------------------------------------------------------------- /example/domain/query.sql.go: -------------------------------------------------------------------------------- 1 | // Code generated by pggen. DO NOT EDIT. 2 | 3 | package domain 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "github.com/jackc/pgconn" 9 | "github.com/jackc/pgtype" 10 | "github.com/jackc/pgx/v4" 11 | ) 12 | 13 | // Querier is a typesafe Go interface backed by SQL queries. 14 | type Querier interface { 15 | DomainOne(ctx context.Context) (string, error) 16 | } 17 | 18 | var _ Querier = &DBQuerier{} 19 | 20 | type DBQuerier struct { 21 | conn genericConn // underlying Postgres transport to use 22 | types *typeResolver // resolve types by name 23 | } 24 | 25 | // genericConn is a connection like *pgx.Conn, pgx.Tx, or *pgxpool.Pool. 26 | type genericConn interface { 27 | Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) 28 | QueryRow(ctx context.Context, sql string, args ...any) pgx.Row 29 | Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) 30 | } 31 | 32 | // NewQuerier creates a DBQuerier that implements Querier. 33 | func NewQuerier(conn genericConn) *DBQuerier { 34 | return &DBQuerier{conn: conn, types: newTypeResolver()} 35 | } 36 | 37 | // typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. 38 | type typeResolver struct { 39 | connInfo *pgtype.ConnInfo // types by Postgres type name 40 | } 41 | 42 | func newTypeResolver() *typeResolver { 43 | ci := pgtype.NewConnInfo() 44 | return &typeResolver{connInfo: ci} 45 | } 46 | 47 | // findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name.
48 | func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { 49 | typ, ok := tr.connInfo.DataTypeForName(name) 50 | if !ok { 51 | return 0, nil, false 52 | } 53 | v := pgtype.NewValue(typ.Value) 54 | return typ.OID, v.(pgtype.ValueTranscoder), true 55 | } 56 | 57 | // setValue sets the value of a ValueTranscoder to a value that should always 58 | // work and panics if it fails. 59 | func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { 60 | if err := vt.Set(val); err != nil { 61 | panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) 62 | } 63 | return vt 64 | } 65 | 66 | const domainOneSQL = `SELECT '90210'::us_postal_code;` 67 | 68 | // DomainOne implements Querier.DomainOne. 69 | func (q *DBQuerier) DomainOne(ctx context.Context) (string, error) { 70 | ctx = context.WithValue(ctx, "pggen_query_name", "DomainOne") 71 | row := q.conn.QueryRow(ctx, domainOneSQL) 72 | var item string 73 | if err := row.Scan(&item); err != nil { 74 | return item, fmt.Errorf("query DomainOne: %w", err) 75 | } 76 | return item, nil 77 | } 78 | 79 | // textPreferrer wraps a pgtype.ValueTranscoder and sets the preferred encoding 80 | // format to text instead binary (the default). pggen uses the text format 81 | // when the OID is unknownOID because the binary format requires the OID. 82 | // Typically occurs for unregistered types. 83 | type textPreferrer struct { 84 | pgtype.ValueTranscoder 85 | typeName string 86 | } 87 | 88 | // PreferredParamFormat implements pgtype.ParamFormatPreferrer.
89 | func (t textPreferrer) PreferredParamFormat() int16 { return pgtype.TextFormatCode } 90 | 91 | func (t textPreferrer) NewTypeValue() pgtype.Value { 92 | return textPreferrer{ValueTranscoder: pgtype.NewValue(t.ValueTranscoder).(pgtype.ValueTranscoder), typeName: t.typeName} 93 | } 94 | 95 | func (t textPreferrer) TypeName() string { 96 | return t.typeName 97 | } 98 | 99 | // unknownOID means we don't know the OID for a type. This is okay for decoding 100 | // because pgx call DecodeText or DecodeBinary without requiring the OID. For 101 | // encoding parameters, pggen uses textPreferrer if the OID is unknown. 102 | const unknownOID = 0 103 | -------------------------------------------------------------------------------- /example/domain/query.sql_test.go: -------------------------------------------------------------------------------- 1 | package domain 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/jschaf/pggen/internal/pgtest" 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func TestQuerier_DomainOne(t *testing.T) { 12 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 13 | defer cleanup() 14 | 15 | q := NewQuerier(conn) 16 | ctx := t.Context() 17 | 18 | t.Run("DomainOne", func(t *testing.T) { 19 | postCode, err := q.DomainOne(ctx) 20 | require.NoError(t, err) 21 | assert.Equal(t, "90210", postCode) 22 | }) 23 | } 24 | -------------------------------------------------------------------------------- /example/domain/schema.sql: -------------------------------------------------------------------------------- 1 | CREATE DOMAIN us_postal_code AS TEXT 2 | CHECK ( 3 | VALUE ~ '^\d{5}$' 4 | OR VALUE ~ '^\d{5}-\d{4}$' 5 | ); 6 | -------------------------------------------------------------------------------- /example/enums/codegen_test.go: -------------------------------------------------------------------------------- 1 | package enums 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 |
"testing" 7 | 8 | "github.com/jschaf/pggen" 9 | "github.com/jschaf/pggen/internal/pgtest" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | // TestGenerate_Go_Example_Enums regenerates Go code for query.sql into a temp dir and fails if it differs from the checked-in query.sql.go. 13 | func TestGenerate_Go_Example_Enums(t *testing.T) { 14 | conn, cleanupFunc := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 15 | defer cleanupFunc() 16 | 17 | tmpDir := t.TempDir() 18 | err := pggen.Generate( 19 | pggen.GenerateOptions{ 20 | ConnString: conn.Config().ConnString(), 21 | QueryFiles: []string{"query.sql"}, 22 | OutputDir: tmpDir, 23 | GoPackage: "enums", 24 | Language: pggen.LangGo, 25 | InlineParamCount: 2, 26 | }) 27 | if err != nil { 28 | t.Fatalf("Generate() example/enums: %s", err) 29 | } 30 | 31 | wantQueryFile := "query.sql.go" 32 | gotQueryFile := filepath.Join(tmpDir, "query.sql.go") 33 | assert.FileExists(t, gotQueryFile, "Generate() should emit query.sql.go") 34 | wantQueries, err := os.ReadFile(wantQueryFile) 35 | if err != nil { 36 | t.Fatalf("read wanted query.sql.go: %s", err) 37 | } 38 | gotQueries, err := os.ReadFile(gotQueryFile) 39 | if err != nil { 40 | t.Fatalf("read generated query.sql.go: %s", err) 41 | } 42 | assert.Equalf(t, string(wantQueries), string(gotQueries), 43 | "Got file %s; does not match contents of %s", 44 | gotQueryFile, wantQueryFile) 45 | } 46 | -------------------------------------------------------------------------------- /example/enums/query.sql: -------------------------------------------------------------------------------- 1 | -- name: FindAllDevices :many 2 | SELECT mac, type 3 | FROM device; 4 | 5 | -- name: InsertDevice :exec 6 | INSERT INTO device (mac, type) 7 | VALUES (pggen.arg('Mac'), pggen.arg('TypePg')); 8 | 9 | -- Select an array of all device_type enum values. 10 | -- name: FindOneDeviceArray :one 11 | SELECT enum_range(NULL::device_type) AS device_types; 12 | 13 | -- Select many rows of device_type enum values.
14 | -- name: FindManyDeviceArray :many 15 | SELECT enum_range('ipad'::device_type, 'iot'::device_type) AS device_types 16 | UNION ALL 17 | SELECT enum_range(NULL::device_type) AS device_types; 18 | 19 | -- Select many rows of device_type enum values with multiple output columns. 20 | -- name: FindManyDeviceArrayWithNum :many 21 | SELECT 1 AS num, enum_range('ipad'::device_type, 'iot'::device_type) AS device_types 22 | UNION ALL 23 | SELECT 2 as num, enum_range(NULL::device_type) AS device_types; 24 | 25 | -- Regression test for https://github.com/jschaf/pggen/issues/23. 26 | -- name: EnumInsideComposite :one 27 | SELECT ROW('08:00:2b:01:02:03'::macaddr, 'phone'::device_type) ::device; 28 | -------------------------------------------------------------------------------- /example/enums/query.sql_test.go: -------------------------------------------------------------------------------- 1 | package enums 2 | 3 | import ( 4 | "context" 5 | "net" 6 | "testing" 7 | "time" 8 | 9 | "github.com/jackc/pgtype" 10 | "github.com/jschaf/pggen/internal/pgtest" 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | func TestNewQuerier_FindAllDevices(t *testing.T) { 16 | ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) 17 | defer cancel() 18 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 19 | defer cleanup() 20 | 21 | q := NewQuerier(conn) 22 | mac, _ := net.ParseMAC("00:00:5e:00:53:01") 23 | 24 | insertDevice(t, q, mac, DeviceTypeIot) 25 | 26 | t.Run("FindAllDevices", func(t *testing.T) { 27 | devices, err := q.FindAllDevices(ctx) 28 | require.NoError(t, err) 29 | assert.Equal(t, 30 | []FindAllDevicesRow{ 31 | {Mac: pgtype.Macaddr{Addr: mac, Status: pgtype.Present}, Type: DeviceTypeIot}, 32 | }, 33 | devices, 34 | ) 35 | }) 36 | } 37 | 38 | //nolint:gochecknoglobals 39 | var allDeviceTypes = []DeviceType{ 40 | DeviceTypeUndefined, 41 | DeviceTypePhone, 42 | DeviceTypeLaptop, 43 | DeviceTypeIpad, 44
| DeviceTypeDesktop, 45 | DeviceTypeIot, 46 | } 47 | 48 | func TestNewQuerier_FindOneDeviceArray(t *testing.T) { 49 | ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) 50 | defer cancel() 51 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 52 | defer cleanup() 53 | 54 | q := NewQuerier(conn) 55 | 56 | t.Run("FindOneDeviceArray", func(t *testing.T) { 57 | devices, err := q.FindOneDeviceArray(ctx) 58 | require.NoError(t, err) 59 | assert.Equal(t, allDeviceTypes, devices) 60 | }) 61 | } 62 | 63 | func TestNewQuerier_FindManyDeviceArray(t *testing.T) { 64 | ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) 65 | defer cancel() 66 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 67 | defer cleanup() 68 | 69 | q := NewQuerier(conn) 70 | 71 | t.Run("FindManyDeviceArray", func(t *testing.T) { 72 | devices, err := q.FindManyDeviceArray(ctx) 73 | require.NoError(t, err) 74 | assert.Equal(t, [][]DeviceType{allDeviceTypes[3:], allDeviceTypes}, devices) 75 | }) 76 | } 77 | 78 | func TestNewQuerier_FindManyDeviceArrayWithNum(t *testing.T) { 79 | ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) 80 | defer cancel() 81 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 82 | defer cleanup() 83 | 84 | q := NewQuerier(conn) 85 | one, two := int32(1), int32(2) 86 | 87 | t.Run("FindManyDeviceArrayWithNum", func(t *testing.T) { 88 | devices, err := q.FindManyDeviceArrayWithNum(ctx) 89 | require.NoError(t, err) 90 | assert.Equal(t, []FindManyDeviceArrayWithNumRow{ 91 | {Num: &one, DeviceTypes: allDeviceTypes[3:]}, 92 | {Num: &two, DeviceTypes: allDeviceTypes}, 93 | }, devices) 94 | }) 95 | } 96 | 97 | func TestNewQuerier_EnumInsideComposite(t *testing.T) { 98 | ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) 99 | defer cancel() 100 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 101 | defer cleanup() 102 | 103 | q := NewQuerier(conn) 104 | mac, _ :=
net.ParseMAC("08:00:2b:01:02:03") 105 | 106 | t.Run("EnumInsideComposite", func(t *testing.T) { 107 | device, err := q.EnumInsideComposite(ctx) 108 | require.NoError(t, err) 109 | assert.Equal(t, 110 | Device{Mac: pgtype.Macaddr{Addr: mac, Status: pgtype.Present}, Type: DeviceTypePhone}, 111 | device, 112 | ) 113 | }) 114 | } 115 | 116 | func insertDevice(t *testing.T, q *DBQuerier, mac net.HardwareAddr, device DeviceType) { 117 | t.Helper() 118 | _, err := q.InsertDevice(t.Context(), 119 | pgtype.Macaddr{Addr: mac, Status: pgtype.Present}, 120 | device, 121 | ) 122 | require.NoError(t, err) 123 | } 124 | -------------------------------------------------------------------------------- /example/enums/schema.sql: -------------------------------------------------------------------------------- 1 | CREATE TYPE device_type AS ENUM ( 2 | 'undefined', 3 | 'phone', 4 | 'laptop', 5 | 'ipad', 6 | 'desktop', 7 | 'iot' 8 | ); 9 | 10 | CREATE TABLE device ( 11 | mac MACADDR PRIMARY KEY, 12 | type device_type NOT NULL DEFAULT 'undefined' 13 | ); 14 | -------------------------------------------------------------------------------- /example/erp/01_schema.sql: -------------------------------------------------------------------------------- 1 | CREATE DOMAIN js_int AS bigint CHECK ( 0 < value AND value < 9007199254740991 ); 2 | -- tenant_id must be 4-5 chars in base 36 (strictly between 36^3, i.e. '1000', and 36^5).
3 | CREATE DOMAIN tenant_id AS js_int CHECK ( 36 * 36 * 36 < value AND value < 36 * 36 * 36 * 36 * 36 ); 4 | 5 | CREATE TABLE tenant ( 6 | tenant_id tenant_id PRIMARY KEY, 7 | rname text UNIQUE GENERATED ALWAYS AS ( 'tenants/' || tenant_id::text ) STORED, 8 | name text NOT NULL CHECK ( name != '' ) 9 | ); 10 | 11 | CREATE TABLE customer ( 12 | customer_id serial PRIMARY KEY, 13 | first_name text NOT NULL, 14 | last_name text NOT NULL, 15 | email text NOT NULL 16 | ); 17 | 18 | CREATE TABLE orders ( 19 | order_id serial PRIMARY KEY, 20 | order_date timestamptz NOT NULL, 21 | order_total numeric NOT NULL, 22 | customer_id int REFERENCES customer 23 | ); 24 | 25 | CREATE TABLE product ( 26 | product_id serial PRIMARY KEY, 27 | name text NOT NULL, 28 | description text NOT NULL, 29 | list_price numeric NOT NULL 30 | ); 31 | -- base36_encode renders abs(digits) in lowercase base 36; the sign is dropped and 0 yields ''. 32 | CREATE OR REPLACE FUNCTION base36_encode( 33 | IN digits bigint 34 | ) RETURNS text 35 | AS 36 | $$ 37 | DECLARE 38 | chars char[]; 39 | ret text; 40 | val bigint; 41 | BEGIN 42 | chars := 43 | ARRAY ['0','1','2','3','4','5','6','7','8','9','a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z']; 44 | val := digits; 45 | ret := ''; 46 | IF val < 0 THEN 47 | val := val * -1; 48 | END IF; 49 | WHILE val != 0 50 | LOOP 51 | ret := chars[(val % 36) + 1] || ret; 52 | val := val / 36; 53 | END LOOP; 54 | RETURN ret; 55 | END; 56 | $$ LANGUAGE plpgsql IMMUTABLE 57 | PARALLEL SAFE; 58 | 59 | -- base36_decode parses a case-insensitive base-36 string into a bigint; NULL yields 0 and characters outside [0-9A-Z] are not rejected. 60 | CREATE OR REPLACE FUNCTION base36_decode( 61 | IN base36 text 62 | ) 63 | RETURNS bigint 64 | AS 65 | $$ 66 | DECLARE 67 | a char[]; 68 | ret bigint; 69 | i int; 70 | val int; 71 | chars text; 72 | BEGIN 73 | -- Check for null so pggen can pass in null.
74 | IF base36 IS NULL THEN 75 | RETURN 0; 76 | END IF; 77 | chars := '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ'; 78 | FOR i IN REVERSE char_length(base36)..1 79 | LOOP 80 | a := a || substring(upper(base36) FROM i FOR 1)::char; 81 | END LOOP; 82 | i := 0; 83 | ret := 0; 84 | WHILE i < (array_length(a, 1)) 85 | LOOP 86 | val := position(a[i + 1] IN chars) - 1; 87 | ret := ret + (val * (36 ^ i)); 88 | i := i + 1; 89 | END LOOP; 90 | RETURN ret; 91 | END; 92 | $$ LANGUAGE plpgsql IMMUTABLE 93 | PARALLEL SAFE; 94 | -------------------------------------------------------------------------------- /example/erp/02_schema.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE order_product ( 2 | order_id int REFERENCES orders, 3 | product_id int REFERENCES product 4 | ); 5 | 6 | -------------------------------------------------------------------------------- /example/erp/order/codegen_test.go: -------------------------------------------------------------------------------- 1 | package order 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | 8 | "github.com/jschaf/pggen" 9 | "github.com/jschaf/pggen/internal/pgtest" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestGenerate_Go_Example_ERP_Order(t *testing.T) { 14 | conn, cleanupFunc := pgtest.NewPostgresSchema(t, []string{ 15 | "../01_schema.sql", 16 | "../02_schema.sql", 17 | }) 18 | defer cleanupFunc() 19 | 20 | tmpDir := t.TempDir() 21 | err := pggen.Generate( 22 | pggen.GenerateOptions{ 23 | ConnString: conn.Config().ConnString(), 24 | QueryFiles: []string{ 25 | "customer.sql", 26 | "price.sql", 27 | }, 28 | OutputDir: tmpDir, 29 | GoPackage: "order", 30 | Language: pggen.LangGo, 31 | InlineParamCount: 2, 32 | Acronyms: map[string]string{"mrr": "MRR"}, 33 | TypeOverrides: map[string]string{"tenant_id": "int"}, 34 | }) 35 | if err != nil { 36 | t.Fatalf("Generate() example/erp/order: %s", err) 37 | } 38 | 39 | for _, file := range
[]string{"customer.sql.go", "price.sql.go"} { 40 | wantQueries, err := os.ReadFile(file) 41 | if err != nil { 42 | t.Fatalf("read wanted file %s: %s", file, err) 43 | } 44 | 45 | gotFile := filepath.Join(tmpDir, file) 46 | assert.FileExists(t, gotFile, "Generate() should emit "+file) 47 | gotQueries, err := os.ReadFile(gotFile) 48 | if err != nil { 49 | t.Fatalf("read generated %s: %s", file, err) 50 | } 51 | assert.Equalf(t, string(wantQueries), string(gotQueries), 52 | "Got file %s; does not match contents of file %s", 53 | gotFile, file) 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /example/erp/order/customer.sql: -------------------------------------------------------------------------------- 1 | -- name: CreateTenant :one 2 | INSERT INTO tenant (tenant_id, name) 3 | VALUES (base36_decode(pggen.arg('key')::text)::tenant_id, pggen.arg('name')::text) 4 | RETURNING *; 5 | 6 | -- name: FindOrdersByCustomer :many 7 | SELECT * 8 | FROM orders 9 | WHERE customer_id = pggen.arg('CustomerID'); 10 | 11 | -- name: FindProductsInOrder :many 12 | SELECT o.order_id, p.product_id, p.name 13 | FROM orders o 14 | INNER JOIN order_product op USING (order_id) 15 | INNER JOIN product p USING (product_id) 16 | WHERE o.order_id = pggen.arg('OrderID'); 17 | 18 | -- name: InsertCustomer :one 19 | INSERT INTO customer (first_name, last_name, email) 20 | VALUES (pggen.arg('first_name'), pggen.arg('last_name'), pggen.arg('email')) 21 | RETURNING *; 22 | 23 | -- name: InsertOrder :one 24 | INSERT INTO orders (order_date, order_total, customer_id) 25 | VALUES (pggen.arg('order_date'), pggen.arg('order_total'), pggen.arg('cust_id')) 26 | RETURNING *; 27 | -------------------------------------------------------------------------------- /example/erp/order/customer.sql_test.go: -------------------------------------------------------------------------------- 1 | package order 2 | 3 | import ( 4 | "math/big" 5 | "testing" 6 | "time" 7 | 8 | 
"github.com/jackc/pgtype" 9 | "github.com/jschaf/pggen/internal/pgtest" 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestNewQuerier_FindOrdersByCustomer(t *testing.T) { 15 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"../01_schema.sql", "../02_schema.sql"}) 16 | defer cleanup() 17 | ctx := t.Context() 18 | 19 | q := NewQuerier(conn) 20 | cust1, err := q.InsertCustomer(ctx, InsertCustomerParams{ 21 | FirstName: "foo_first", 22 | LastName: "foo_last", 23 | Email: "foo_email", 24 | }) 25 | if err != nil { 26 | t.Error(err) 27 | return 28 | } 29 | order1, err := q.InsertOrder(ctx, InsertOrderParams{ 30 | OrderDate: pgtype.Timestamptz{Time: time.Now(), Status: pgtype.Present}, 31 | OrderTotal: pgtype.Numeric{Int: big.NewInt(77), Status: pgtype.Present}, 32 | CustID: cust1.CustomerID, 33 | }) 34 | if err != nil { 35 | t.Error(err) 36 | return 37 | } 38 | 39 | t.Run("FindOrdersByCustomer", func(t *testing.T) { 40 | orders, err := q.FindOrdersByCustomer(t.Context(), cust1.CustomerID) 41 | require.NoError(t, err) 42 | assert.Equal(t, []FindOrdersByCustomerRow{ 43 | { 44 | OrderID: order1.OrderID, 45 | OrderDate: order1.OrderDate, 46 | OrderTotal: order1.OrderTotal, 47 | CustomerID: order1.CustomerID, 48 | }, 49 | }, orders) 50 | }) 51 | } 52 | 53 | func TestNewQuerier_QuerierMatchesDBQuerier(t *testing.T) { 54 | var q Querier = NewQuerier(nil) 55 | require.NotNil(t, q.FindOrdersByCustomer) 56 | require.NotNil(t, q.FindProductsInOrder) 57 | require.NotNil(t, q.InsertOrder) 58 | require.NotNil(t, q.FindOrdersByPrice) 59 | require.NotNil(t, q.FindOrdersMRR) 60 | } 61 | -------------------------------------------------------------------------------- /example/erp/order/price.sql: -------------------------------------------------------------------------------- 1 | -- name: FindOrdersByPrice :many 2 | SELECT * FROM orders WHERE order_total > pggen.arg('MinTotal'); 3 | 4 | -- name: FindOrdersMRR :many 5 | 
SELECT date_trunc('month', order_date) AS month, sum(order_total) AS order_mrr 6 | FROM orders 7 | GROUP BY date_trunc('month', order_date); 8 | -------------------------------------------------------------------------------- /example/erp/order/price.sql.go: -------------------------------------------------------------------------------- 1 | // Code generated by pggen. DO NOT EDIT. 2 | 3 | package order 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "github.com/jackc/pgtype" 9 | ) 10 | 11 | const findOrdersByPriceSQL = `SELECT * FROM orders WHERE order_total > $1;` 12 | 13 | type FindOrdersByPriceRow struct { 14 | OrderID int32 `json:"order_id"` 15 | OrderDate pgtype.Timestamptz `json:"order_date"` 16 | OrderTotal pgtype.Numeric `json:"order_total"` 17 | CustomerID *int32 `json:"customer_id"` 18 | } 19 | 20 | // FindOrdersByPrice implements Querier.FindOrdersByPrice. 21 | func (q *DBQuerier) FindOrdersByPrice(ctx context.Context, minTotal pgtype.Numeric) ([]FindOrdersByPriceRow, error) { 22 | ctx = context.WithValue(ctx, "pggen_query_name", "FindOrdersByPrice") 23 | rows, err := q.conn.Query(ctx, findOrdersByPriceSQL, minTotal) 24 | if err != nil { 25 | return nil, fmt.Errorf("query FindOrdersByPrice: %w", err) 26 | } 27 | defer rows.Close() 28 | items := []FindOrdersByPriceRow{} 29 | for rows.Next() { 30 | var item FindOrdersByPriceRow 31 | if err := rows.Scan(&item.OrderID, &item.OrderDate, &item.OrderTotal, &item.CustomerID); err != nil { 32 | return nil, fmt.Errorf("scan FindOrdersByPrice row: %w", err) 33 | } 34 | items = append(items, item) 35 | } 36 | if err := rows.Err(); err != nil { 37 | return nil, fmt.Errorf("close FindOrdersByPrice rows: %w", err) 38 | } 39 | return items, err 40 | } 41 | 42 | const findOrdersMRRSQL = `SELECT date_trunc('month', order_date) AS month, sum(order_total) AS order_mrr 43 | FROM orders 44 | GROUP BY date_trunc('month', order_date);` 45 | 46 | type FindOrdersMRRRow struct { 47 | Month pgtype.Timestamptz `json:"month"` 48 | 
OrderMRR pgtype.Numeric `json:"order_mrr"` 49 | } 50 | 51 | // FindOrdersMRR implements Querier.FindOrdersMRR. 52 | func (q *DBQuerier) FindOrdersMRR(ctx context.Context) ([]FindOrdersMRRRow, error) { 53 | ctx = context.WithValue(ctx, "pggen_query_name", "FindOrdersMRR") 54 | rows, err := q.conn.Query(ctx, findOrdersMRRSQL) 55 | if err != nil { 56 | return nil, fmt.Errorf("query FindOrdersMRR: %w", err) 57 | } 58 | defer rows.Close() 59 | items := []FindOrdersMRRRow{} 60 | for rows.Next() { 61 | var item FindOrdersMRRRow 62 | if err := rows.Scan(&item.Month, &item.OrderMRR); err != nil { 63 | return nil, fmt.Errorf("scan FindOrdersMRR row: %w", err) 64 | } 65 | items = append(items, item) 66 | } 67 | if err := rows.Err(); err != nil { 68 | return nil, fmt.Errorf("close FindOrdersMRR rows: %w", err) 69 | } 70 | return items, err 71 | } 72 | -------------------------------------------------------------------------------- /example/function/codegen_test.go: -------------------------------------------------------------------------------- 1 | package function 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | 8 | "github.com/jschaf/pggen" 9 | "github.com/jschaf/pggen/internal/pgtest" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestGenerate_Go_Example_Function(t *testing.T) { 14 | conn, cleanupFunc := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 15 | defer cleanupFunc() 16 | 17 | tmpDir := t.TempDir() 18 | err := pggen.Generate( 19 | pggen.GenerateOptions{ 20 | ConnString: conn.Config().ConnString(), 21 | QueryFiles: []string{"query.sql"}, 22 | OutputDir: tmpDir, 23 | GoPackage: "function", 24 | Language: pggen.LangGo, 25 | InlineParamCount: 2, 26 | }) 27 | if err != nil { 28 | t.Fatalf("Generate() example/function: %s", err) 29 | } 30 | 31 | wantQueryFile := "query.sql.go" 32 | gotQueryFile := filepath.Join(tmpDir, "query.sql.go") 33 | assert.FileExists(t, gotQueryFile, 34 | "Generate() should emit query.sql.go") 35 | wantQueries, 
err := os.ReadFile(wantQueryFile) 36 | if err != nil { 37 | t.Fatalf("read wanted query.go.sql: %s", err) 38 | } 39 | gotQueries, err := os.ReadFile(gotQueryFile) 40 | if err != nil { 41 | t.Fatalf("read generated query.go.sql: %s", err) 42 | } 43 | assert.Equalf(t, string(wantQueries), string(gotQueries), 44 | "Got file %s; does not match contents of %s", 45 | gotQueryFile, wantQueryFile) 46 | } 47 | -------------------------------------------------------------------------------- /example/function/query.sql: -------------------------------------------------------------------------------- 1 | -- name: OutParams :many 2 | SELECT * FROM out_params(); 3 | -------------------------------------------------------------------------------- /example/function/query.sql_test.go: -------------------------------------------------------------------------------- 1 | package function 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/jschaf/pggen/internal/difftest" 7 | "github.com/jschaf/pggen/internal/ptrs" 8 | "github.com/stretchr/testify/require" 9 | 10 | "github.com/jschaf/pggen/internal/pgtest" 11 | ) 12 | 13 | func TestNewQuerier_OutParams(t *testing.T) { 14 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 15 | defer cleanup() 16 | 17 | q := NewQuerier(conn) 18 | 19 | t.Run("OutParams", func(t *testing.T) { 20 | got, err := q.OutParams(t.Context()) 21 | require.NoError(t, err) 22 | want := []OutParamsRow{ 23 | { 24 | Items: []ListItem{{Name: ptrs.String("some_name"), Color: ptrs.String("some_color")}}, 25 | Stats: ListStats{ 26 | Val1: ptrs.String("abc"), 27 | Val2: []*int32{ptrs.Int32(1), ptrs.Int32(2)}, 28 | }, 29 | }, 30 | } 31 | difftest.AssertSame(t, want, got) 32 | }) 33 | } 34 | -------------------------------------------------------------------------------- /example/function/schema.sql: -------------------------------------------------------------------------------- 1 | CREATE TYPE list_item AS ( 2 | name text, 3 | color text 4 | ); 5 | 6 | 
CREATE TYPE list_stats AS ( 7 | val1 text, 8 | val2 int[] 9 | ); 10 | 11 | CREATE OR REPLACE FUNCTION out_params( 12 | OUT _items list_item[], 13 | OUT _stats list_stats 14 | ) 15 | LANGUAGE plpgsql AS $$ 16 | BEGIN 17 | _items := ARRAY [('some_name', 'some_color')::list_item]; 18 | _stats := ('abc', ARRAY [1, 2])::list_stats; 19 | END 20 | $$; 21 | -------------------------------------------------------------------------------- /example/go_pointer_types/codegen_test.go: -------------------------------------------------------------------------------- 1 | package go_pointer_types 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | 8 | "github.com/jschaf/pggen" 9 | "github.com/jschaf/pggen/internal/pgtest" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestGenerate_Go_Example_GoPointerTypes(t *testing.T) { 14 | conn, cleanupFunc := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 15 | defer cleanupFunc() 16 | 17 | tmpDir := t.TempDir() 18 | err := pggen.Generate( 19 | pggen.GenerateOptions{ 20 | ConnString: conn.Config().ConnString(), 21 | QueryFiles: []string{"query.sql"}, 22 | OutputDir: tmpDir, 23 | GoPackage: "go_pointer_types", 24 | Language: pggen.LangGo, 25 | InlineParamCount: 2, 26 | TypeOverrides: map[string]string{ 27 | "int4": "*int", 28 | "_int4": "[]int", 29 | "int8": "*int", 30 | "_int8": "[]int", 31 | "text": "*string", 32 | }, 33 | }) 34 | if err != nil { 35 | t.Fatalf("Generate() example/go_pointer_types: %s", err) 36 | } 37 | 38 | wantQueryFile := "query.sql.go" 39 | gotQueryFile := filepath.Join(tmpDir, "query.sql.go") 40 | assert.FileExists(t, gotQueryFile, "Generate() should emit query.sql.go") 41 | wantQueries, err := os.ReadFile(wantQueryFile) 42 | if err != nil { 43 | t.Fatalf("read wanted query.go.sql: %s", err) 44 | } 45 | gotQueries, err := os.ReadFile(gotQueryFile) 46 | if err != nil { 47 | t.Fatalf("read generated query.go.sql: %s", err) 48 | } 49 | assert.Equalf(t, string(wantQueries), string(gotQueries), 
50 | "Got file %s; does not match contents of %s", 51 | gotQueryFile, wantQueryFile) 52 | } 53 | -------------------------------------------------------------------------------- /example/go_pointer_types/query.sql: -------------------------------------------------------------------------------- 1 | -- name: GenSeries1 :one 2 | SELECT n 3 | FROM generate_series(0, 2) n 4 | LIMIT 1; 5 | 6 | -- name: GenSeries :many 7 | SELECT n 8 | FROM generate_series(0, 2) n; 9 | 10 | -- name: GenSeriesArr1 :one 11 | SELECT array_agg(n) 12 | FROM generate_series(0, 2) n; 13 | 14 | -- name: GenSeriesArr :many 15 | SELECT array_agg(n) 16 | FROM generate_series(0, 2) n; 17 | 18 | -- name: GenSeriesStr1 :one 19 | SELECT n::text 20 | FROM generate_series(0, 2) n 21 | LIMIT 1; 22 | 23 | -- name: GenSeriesStr :many 24 | SELECT n::text 25 | FROM generate_series(0, 2) n; 26 | -------------------------------------------------------------------------------- /example/go_pointer_types/query.sql_test.go: -------------------------------------------------------------------------------- 1 | package go_pointer_types 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/jschaf/pggen/internal/pgtest" 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func TestQuerier_GenSeries1(t *testing.T) { 12 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 13 | defer cleanup() 14 | 15 | q := NewQuerier(conn) 16 | ctx := t.Context() 17 | 18 | t.Run("GenSeries1", func(t *testing.T) { 19 | got, err := q.GenSeries1(ctx) 20 | require.NoError(t, err) 21 | zero := 0 22 | assert.Equal(t, &zero, got) 23 | }) 24 | } 25 | 26 | func TestQuerier_GenSeries(t *testing.T) { 27 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 28 | defer cleanup() 29 | 30 | q := NewQuerier(conn) 31 | ctx := t.Context() 32 | 33 | t.Run("GenSeries", func(t *testing.T) { 34 | got, err := q.GenSeries(ctx) 35 | if err != nil { 36 | t.Fatal(err) 37 | } 38 | zero, 
one, two := 0, 1, 2 39 | assert.Equal(t, []*int{&zero, &one, &two}, got) 40 | }) 41 | } 42 | 43 | func TestQuerier_GenSeriesArr1(t *testing.T) { 44 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 45 | defer cleanup() 46 | 47 | q := NewQuerier(conn) 48 | ctx := t.Context() 49 | 50 | t.Run("GenSeriesArr1", func(t *testing.T) { 51 | got, err := q.GenSeriesArr1(ctx) 52 | require.NoError(t, err) 53 | assert.Equal(t, []int{0, 1, 2}, got) 54 | }) 55 | } 56 | 57 | func TestQuerier_GenSeriesArr(t *testing.T) { 58 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 59 | defer cleanup() 60 | 61 | q := NewQuerier(conn) 62 | ctx := t.Context() 63 | 64 | t.Run("GenSeriesArr", func(t *testing.T) { 65 | got, err := q.GenSeriesArr(ctx) 66 | require.NoError(t, err) 67 | assert.Equal(t, [][]int{{0, 1, 2}}, got) 68 | }) 69 | } 70 | 71 | func TestQuerier_GenSeriesStr(t *testing.T) { 72 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 73 | defer cleanup() 74 | 75 | q := NewQuerier(conn) 76 | ctx := t.Context() 77 | 78 | t.Run("GenSeriesStr1", func(t *testing.T) { 79 | got, err := q.GenSeriesStr1(ctx) 80 | require.NoError(t, err) 81 | zero := "0" 82 | assert.Equal(t, &zero, got) 83 | }) 84 | 85 | t.Run("GenSeriesStr", func(t *testing.T) { 86 | got, err := q.GenSeriesStr(ctx) 87 | require.NoError(t, err) 88 | zero, one, two := "0", "1", "2" 89 | assert.Equal(t, []*string{&zero, &one, &two}, got) 90 | }) 91 | } 92 | -------------------------------------------------------------------------------- /example/go_pointer_types/schema.sql: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jschaf/pggen/632aa8e2e34733a36dbda854dcb4aa8a9b57f884/example/go_pointer_types/schema.sql -------------------------------------------------------------------------------- /example/inline_param_count/codegen_test.go: 
-------------------------------------------------------------------------------- 1 | package author 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | 8 | "github.com/jschaf/pggen" 9 | "github.com/jschaf/pggen/internal/pgtest" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestGenerate_Go_Example_InlineParamCount(t *testing.T) { 14 | conn, cleanupFunc := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 15 | defer cleanupFunc() 16 | 17 | tests := []struct { 18 | name string 19 | opts pggen.GenerateOptions 20 | wantQueryPath string 21 | }{ 22 | { 23 | name: "inline0", 24 | opts: pggen.GenerateOptions{ 25 | ConnString: conn.Config().ConnString(), 26 | QueryFiles: []string{"query.sql"}, 27 | GoPackage: "inline0", 28 | Language: pggen.LangGo, 29 | InlineParamCount: 0, 30 | }, 31 | wantQueryPath: "inline0/query.sql.go", 32 | }, 33 | { 34 | name: "inline1", 35 | opts: pggen.GenerateOptions{ 36 | ConnString: conn.Config().ConnString(), 37 | QueryFiles: []string{"query.sql"}, 38 | GoPackage: "inline1", 39 | Language: pggen.LangGo, 40 | InlineParamCount: 1, 41 | }, 42 | wantQueryPath: "inline1/query.sql.go", 43 | }, 44 | { 45 | name: "inline2", 46 | opts: pggen.GenerateOptions{ 47 | ConnString: conn.Config().ConnString(), 48 | QueryFiles: []string{"query.sql"}, 49 | GoPackage: "inline2", 50 | Language: pggen.LangGo, 51 | InlineParamCount: 2, 52 | }, 53 | wantQueryPath: "inline2/query.sql.go", 54 | }, 55 | { 56 | name: "inline3", 57 | opts: pggen.GenerateOptions{ 58 | ConnString: conn.Config().ConnString(), 59 | QueryFiles: []string{"query.sql"}, 60 | GoPackage: "inline3", 61 | Language: pggen.LangGo, 62 | InlineParamCount: 3, 63 | }, 64 | wantQueryPath: "inline3/query.sql.go", 65 | }, 66 | } 67 | for _, tt := range tests { 68 | t.Run(tt.name, func(t *testing.T) { 69 | tmpDir := t.TempDir() 70 | tt.opts.OutputDir = tmpDir 71 | err := pggen.Generate(tt.opts) 72 | if err != nil { 73 | t.Fatalf("Generate() example/author %s: %s", tt.name, 
err.Error()) 74 | } 75 | 76 | gotQueryFile := filepath.Join(tmpDir, "query.sql.go") 77 | assert.FileExists(t, gotQueryFile, "Generate() should emit query.sql.go") 78 | wantQueries, err := os.ReadFile(tt.wantQueryPath) 79 | if err != nil { 80 | t.Fatalf("read wanted query.go.sql: %s", err) 81 | } 82 | gotQueries, err := os.ReadFile(gotQueryFile) 83 | if err != nil { 84 | t.Fatalf("read generated query.go.sql: %s", err) 85 | } 86 | assert.Equalf(t, string(wantQueries), string(gotQueries), 87 | "Got file %s; does not match contents of %s", 88 | gotQueryFile, tt.wantQueryPath) 89 | }) 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /example/inline_param_count/inline0/query.sql_test.go: -------------------------------------------------------------------------------- 1 | package inline0 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/require" 8 | 9 | "github.com/jackc/pgx/v4" 10 | "github.com/jschaf/pggen/internal/pgtest" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestNewQuerier_FindAuthorByID(t *testing.T) { 15 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"../schema.sql"}) 16 | defer cleanup() 17 | 18 | q := NewQuerier(conn) 19 | adamsID := insertAuthor(t, q, "john", "adams") 20 | insertAuthor(t, q, "george", "washington") 21 | 22 | t.Run("CountAuthors two", func(t *testing.T) { 23 | got, err := q.CountAuthors(t.Context()) 24 | require.NoError(t, err) 25 | assert.Equal(t, 2, *got) 26 | }) 27 | 28 | t.Run("FindAuthorByID", func(t *testing.T) { 29 | authorByID, err := q.FindAuthorByID(t.Context(), FindAuthorByIDParams{AuthorID: adamsID}) 30 | require.NoError(t, err) 31 | assert.Equal(t, FindAuthorByIDRow{ 32 | AuthorID: adamsID, 33 | FirstName: "john", 34 | LastName: "adams", 35 | Suffix: nil, 36 | }, authorByID) 37 | }) 38 | 39 | t.Run("FindAuthorByID - none-exists", func(t *testing.T) { 40 | missingAuthorByID, err := q.FindAuthorByID(t.Context(), 
FindAuthorByIDParams{AuthorID: 888}) 41 | require.Error(t, err, "expected error when finding author ID that doesn't match") 42 | assert.Zero(t, missingAuthorByID, "expected zero value when error") 43 | if !errors.Is(err, pgx.ErrNoRows) { 44 | t.Fatalf("expected no rows error to wrap pgx.ErrNoRows; got %s", err) 45 | } 46 | }) 47 | } 48 | 49 | func TestNewQuerier_DeleteAuthorsByFullName(t *testing.T) { 50 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"../schema.sql"}) 51 | defer cleanup() 52 | q := NewQuerier(conn) 53 | insertAuthor(t, q, "george", "washington") 54 | 55 | t.Run("DeleteAuthorsByFullName", func(t *testing.T) { 56 | tag, err := q.DeleteAuthorsByFullName(t.Context(), DeleteAuthorsByFullNameParams{ 57 | FirstName: "george", 58 | LastName: "washington", 59 | Suffix: "", 60 | }) 61 | require.NoError(t, err) 62 | assert.Truef(t, tag.Delete(), "expected delete tag; got %s", tag.String()) 63 | assert.Equal(t, int64(1), tag.RowsAffected()) 64 | }) 65 | } 66 | 67 | func insertAuthor(t *testing.T, q *DBQuerier, first, last string) int32 { 68 | t.Helper() 69 | authorID, err := q.InsertAuthor(t.Context(), InsertAuthorParams{ 70 | FirstName: first, 71 | LastName: last, 72 | }) 73 | require.NoError(t, err, "insert author") 74 | return authorID 75 | } 76 | -------------------------------------------------------------------------------- /example/inline_param_count/inline1/query.sql_test.go: -------------------------------------------------------------------------------- 1 | package inline1 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/require" 8 | 9 | "github.com/jackc/pgx/v4" 10 | "github.com/jschaf/pggen/internal/pgtest" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestNewQuerier_FindAuthorByID(t *testing.T) { 15 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"../schema.sql"}) 16 | defer cleanup() 17 | 18 | q := NewQuerier(conn) 19 | adamsID := insertAuthor(t, q, "john", "adams") 20 | 
insertAuthor(t, q, "george", "washington") 21 | 22 | t.Run("CountAuthors two", func(t *testing.T) { 23 | got, err := q.CountAuthors(t.Context()) 24 | require.NoError(t, err) 25 | assert.Equal(t, 2, *got) 26 | }) 27 | 28 | t.Run("FindAuthorByID", func(t *testing.T) { 29 | authorByID, err := q.FindAuthorByID(t.Context(), adamsID) 30 | require.NoError(t, err) 31 | assert.Equal(t, FindAuthorByIDRow{ 32 | AuthorID: adamsID, 33 | FirstName: "john", 34 | LastName: "adams", 35 | Suffix: nil, 36 | }, authorByID) 37 | }) 38 | 39 | t.Run("FindAuthorByID - none-exists", func(t *testing.T) { 40 | missingAuthorByID, err := q.FindAuthorByID(t.Context(), 888) 41 | require.Error(t, err, "expected error when finding author ID that doesn't match") 42 | assert.Zero(t, missingAuthorByID, "expected zero value when error") 43 | if !errors.Is(err, pgx.ErrNoRows) { 44 | t.Fatalf("expected no rows error to wrap pgx.ErrNoRows; got %s", err) 45 | } 46 | }) 47 | } 48 | 49 | func TestNewQuerier_DeleteAuthorsByFullName(t *testing.T) { 50 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"../schema.sql"}) 51 | defer cleanup() 52 | q := NewQuerier(conn) 53 | insertAuthor(t, q, "george", "washington") 54 | 55 | t.Run("DeleteAuthorsByFullName", func(t *testing.T) { 56 | tag, err := q.DeleteAuthorsByFullName(t.Context(), DeleteAuthorsByFullNameParams{ 57 | FirstName: "george", 58 | LastName: "washington", 59 | Suffix: "", 60 | }) 61 | require.NoError(t, err) 62 | assert.Truef(t, tag.Delete(), "expected delete tag; got %s", tag.String()) 63 | assert.Equal(t, int64(1), tag.RowsAffected()) 64 | }) 65 | } 66 | 67 | func insertAuthor(t *testing.T, q *DBQuerier, first, last string) int32 { 68 | t.Helper() 69 | authorID, err := q.InsertAuthor(t.Context(), InsertAuthorParams{ 70 | FirstName: first, 71 | LastName: last, 72 | }) 73 | require.NoError(t, err, "insert author") 74 | return authorID 75 | } 76 | -------------------------------------------------------------------------------- 
/example/inline_param_count/inline2/query.sql_test.go: -------------------------------------------------------------------------------- 1 | package inline2 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/require" 8 | 9 | "github.com/jackc/pgx/v4" 10 | "github.com/jschaf/pggen/internal/pgtest" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestNewQuerier_FindAuthorByID(t *testing.T) { 15 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"../schema.sql"}) 16 | defer cleanup() 17 | 18 | q := NewQuerier(conn) 19 | adamsID := insertAuthor(t, q, "john", "adams") 20 | insertAuthor(t, q, "george", "washington") 21 | 22 | t.Run("CountAuthors two", func(t *testing.T) { 23 | got, err := q.CountAuthors(t.Context()) 24 | require.NoError(t, err) 25 | assert.Equal(t, 2, *got) 26 | }) 27 | 28 | t.Run("FindAuthorByID", func(t *testing.T) { 29 | authorByID, err := q.FindAuthorByID(t.Context(), adamsID) 30 | require.NoError(t, err) 31 | assert.Equal(t, FindAuthorByIDRow{ 32 | AuthorID: adamsID, 33 | FirstName: "john", 34 | LastName: "adams", 35 | Suffix: nil, 36 | }, authorByID) 37 | }) 38 | 39 | t.Run("FindAuthorByID - none-exists", func(t *testing.T) { 40 | missingAuthorByID, err := q.FindAuthorByID(t.Context(), 888) 41 | require.Error(t, err, "expected error when finding author ID that doesn't match") 42 | assert.Zero(t, missingAuthorByID, "expected zero value when error") 43 | if !errors.Is(err, pgx.ErrNoRows) { 44 | t.Fatalf("expected no rows error to wrap pgx.ErrNoRows; got %s", err) 45 | } 46 | }) 47 | } 48 | 49 | func TestNewQuerier_DeleteAuthorsByFullName(t *testing.T) { 50 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"../schema.sql"}) 51 | defer cleanup() 52 | q := NewQuerier(conn) 53 | insertAuthor(t, q, "george", "washington") 54 | 55 | t.Run("DeleteAuthorsByFullName", func(t *testing.T) { 56 | tag, err := q.DeleteAuthorsByFullName(t.Context(), DeleteAuthorsByFullNameParams{ 57 | FirstName: "george", 58 | 
LastName: "washington", 59 | Suffix: "", 60 | }) 61 | require.NoError(t, err) 62 | assert.Truef(t, tag.Delete(), "expected delete tag; got %s", tag.String()) 63 | assert.Equal(t, int64(1), tag.RowsAffected()) 64 | }) 65 | } 66 | 67 | func insertAuthor(t *testing.T, q *DBQuerier, first, last string) int32 { 68 | t.Helper() 69 | authorID, err := q.InsertAuthor(t.Context(), first, last) 70 | require.NoError(t, err, "insert author") 71 | return authorID 72 | } 73 | -------------------------------------------------------------------------------- /example/inline_param_count/inline3/query.sql_test.go: -------------------------------------------------------------------------------- 1 | package inline3 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/require" 8 | 9 | "github.com/jackc/pgx/v4" 10 | "github.com/jschaf/pggen/internal/pgtest" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestNewQuerier_FindAuthorByID(t *testing.T) { 15 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"../schema.sql"}) 16 | defer cleanup() 17 | 18 | q := NewQuerier(conn) 19 | adamsID := insertAuthor(t, q, "john", "adams") 20 | insertAuthor(t, q, "george", "washington") 21 | 22 | t.Run("CountAuthors two", func(t *testing.T) { 23 | got, err := q.CountAuthors(t.Context()) 24 | require.NoError(t, err) 25 | assert.Equal(t, 2, *got) 26 | }) 27 | 28 | t.Run("FindAuthorByID", func(t *testing.T) { 29 | authorByID, err := q.FindAuthorByID(t.Context(), adamsID) 30 | require.NoError(t, err) 31 | assert.Equal(t, FindAuthorByIDRow{ 32 | AuthorID: adamsID, 33 | FirstName: "john", 34 | LastName: "adams", 35 | Suffix: nil, 36 | }, authorByID) 37 | }) 38 | 39 | t.Run("FindAuthorByID - none-exists", func(t *testing.T) { 40 | missingAuthorByID, err := q.FindAuthorByID(t.Context(), 888) 41 | require.Error(t, err, "expected error when finding author ID that doesn't match") 42 | assert.Zero(t, missingAuthorByID, "expected zero value when error") 43 | if 
!errors.Is(err, pgx.ErrNoRows) { 44 | t.Fatalf("expected no rows error to wrap pgx.ErrNoRows; got %s", err) 45 | } 46 | }) 47 | } 48 | 49 | func TestNewQuerier_DeleteAuthorsByFullName(t *testing.T) { 50 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"../schema.sql"}) 51 | defer cleanup() 52 | q := NewQuerier(conn) 53 | insertAuthor(t, q, "george", "washington") 54 | 55 | t.Run("DeleteAuthorsByFullName", func(t *testing.T) { 56 | tag, err := q.DeleteAuthorsByFullName(t.Context(), "george", "washington", "") 57 | require.NoError(t, err) 58 | assert.Truef(t, tag.Delete(), "expected delete tag; got %s", tag.String()) 59 | assert.Equal(t, int64(1), tag.RowsAffected()) 60 | }) 61 | } 62 | 63 | func insertAuthor(t *testing.T, q *DBQuerier, first, last string) int32 { 64 | t.Helper() 65 | authorID, err := q.InsertAuthor(t.Context(), first, last) 66 | require.NoError(t, err, "insert author") 67 | return authorID 68 | } 69 | -------------------------------------------------------------------------------- /example/inline_param_count/query.sql: -------------------------------------------------------------------------------- 1 | -- CountAuthors returns the number of authors (zero params). 2 | -- name: CountAuthors :one 3 | SELECT count(*) FROM author; 4 | 5 | -- FindAuthorById finds one (or zero) authors by ID (one param). 6 | -- name: FindAuthorByID :one 7 | SELECT * FROM author WHERE author_id = pggen.arg('AuthorID'); 8 | 9 | -- InsertAuthor inserts an author by name and returns the ID (two params). 10 | -- name: InsertAuthor :one 11 | INSERT INTO author (first_name, last_name) 12 | VALUES (pggen.arg('FirstName'), pggen.arg('LastName')) 13 | RETURNING author_id; 14 | 15 | -- DeleteAuthorsByFullName deletes authors by the full name (three params). 
16 | -- name: DeleteAuthorsByFullName :exec 17 | DELETE 18 | FROM author 19 | WHERE first_name = pggen.arg('FirstName') 20 | AND last_name = pggen.arg('LastName') 21 | AND CASE WHEN pggen.arg('Suffix') = '' THEN suffix IS NULL ELSE suffix = pggen.arg('Suffix') END; -------------------------------------------------------------------------------- /example/inline_param_count/schema.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE author ( 2 | author_id serial PRIMARY KEY, 3 | first_name text NOT NULL, 4 | last_name text NOT NULL, 5 | suffix text NULL 6 | ); 7 | -------------------------------------------------------------------------------- /example/ltree/codegen_test.go: -------------------------------------------------------------------------------- 1 | package ltree 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | 8 | "github.com/jschaf/pggen" 9 | "github.com/jschaf/pggen/internal/pgtest" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestGenerate_Go_Example_ltree(t *testing.T) { 14 | conn, cleanupFunc := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 15 | defer cleanupFunc() 16 | 17 | tmpDir := t.TempDir() 18 | err := pggen.Generate( 19 | pggen.GenerateOptions{ 20 | ConnString: conn.Config().ConnString(), 21 | QueryFiles: []string{"query.sql"}, 22 | OutputDir: tmpDir, 23 | GoPackage: "ltree", 24 | Language: pggen.LangGo, 25 | InlineParamCount: 2, 26 | TypeOverrides: map[string]string{ 27 | "ltree": "github.com/jackc/pgtype.Text", 28 | "_ltree": "github.com/jackc/pgtype.TextArray", 29 | }, 30 | }) 31 | if err != nil { 32 | t.Fatalf("Generate() example/ltree: %s", err) 33 | } 34 | 35 | wantQueryFile := "query.sql.go" 36 | gotQueryFile := filepath.Join(tmpDir, "query.sql.go") 37 | assert.FileExists(t, gotQueryFile, "Generate() should emit query.sql.go") 38 | wantQueries, err := os.ReadFile(wantQueryFile) 39 | if err != nil { 40 | t.Fatalf("read wanted query.go.sql: %s", err) 41 
| } 42 | gotQueries, err := os.ReadFile(gotQueryFile) 43 | if err != nil { 44 | t.Fatalf("read generated query.go.sql: %s", err) 45 | } 46 | assert.Equalf(t, string(wantQueries), string(gotQueries), 47 | "Got file %s; does not match contents of %s", 48 | gotQueryFile, wantQueryFile) 49 | } 50 | -------------------------------------------------------------------------------- /example/ltree/query.sql: -------------------------------------------------------------------------------- 1 | -- name: FindTopScienceChildren :many 2 | SELECT path 3 | FROM test 4 | WHERE path <@ 'Top.Science'; 5 | 6 | -- name: FindTopScienceChildrenAgg :one 7 | SELECT array_agg(path) 8 | FROM test 9 | WHERE path <@ 'Top.Science'; 10 | 11 | -- name: InsertSampleData :exec 12 | INSERT INTO test 13 | VALUES ('Top'), 14 | ('Top.Science'), 15 | ('Top.Science.Astronomy'), 16 | ('Top.Science.Astronomy.Astrophysics'), 17 | ('Top.Science.Astronomy.Cosmology'), 18 | ('Top.Hobbies'), 19 | ('Top.Hobbies.Amateurs_Astronomy'), 20 | ('Top.Collections'), 21 | ('Top.Collections.Pictures'), 22 | ('Top.Collections.Pictures.Astronomy'), 23 | ('Top.Collections.Pictures.Astronomy.Stars'), 24 | ('Top.Collections.Pictures.Astronomy.Galaxies'), 25 | ('Top.Collections.Pictures.Astronomy.Astronauts'); 26 | 27 | -- name: FindLtreeInput :one 28 | SELECT 29 | pggen.arg('in_ltree')::ltree AS ltree, 30 | -- This won't work, but I'm not quite sure why. 31 | -- Postgres errors with "wrong element type (SQLSTATE 42804)" 32 | -- All caps because we use regex to find pggen.arg and it confuses pggen. 33 | -- PGGEN.arg('in_ltree_array_direct')::ltree[] AS direct_arr, 34 | 35 | -- The parenthesis around the text[] cast are important. They signal to pggen 36 | -- that we need a text array that Postgres then converts to ltree[]. 
37 | (pggen.arg('in_ltree_array')::text[])::ltree[] AS text_arr; -------------------------------------------------------------------------------- /example/ltree/query.sql_test.go: -------------------------------------------------------------------------------- 1 | package ltree 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/jackc/pgtype" 7 | "github.com/jschaf/pggen/internal/pgtest" 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func TestQuerier(t *testing.T) { 13 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 14 | defer cleanup() 15 | 16 | q := NewQuerier(conn) 17 | ctx := t.Context() 18 | 19 | if _, err := q.InsertSampleData(ctx); err != nil { 20 | t.Fatal(err) 21 | } 22 | 23 | { 24 | rows, err := q.FindTopScienceChildren(ctx) 25 | require.NoError(t, err) 26 | want := []pgtype.Text{ 27 | {String: "Top.Science", Status: pgtype.Present}, 28 | {String: "Top.Science.Astronomy", Status: pgtype.Present}, 29 | {String: "Top.Science.Astronomy.Astrophysics", Status: pgtype.Present}, 30 | {String: "Top.Science.Astronomy.Cosmology", Status: pgtype.Present}, 31 | } 32 | assert.Equal(t, want, rows) 33 | } 34 | 35 | { 36 | rows, err := q.FindTopScienceChildrenAgg(ctx) 37 | require.NoError(t, err) 38 | want := pgtype.TextArray{ 39 | Elements: []pgtype.Text{ 40 | {String: "Top.Science", Status: pgtype.Present}, 41 | {String: "Top.Science.Astronomy", Status: pgtype.Present}, 42 | {String: "Top.Science.Astronomy.Astrophysics", Status: pgtype.Present}, 43 | {String: "Top.Science.Astronomy.Cosmology", Status: pgtype.Present}, 44 | }, 45 | Status: pgtype.Present, 46 | Dimensions: []pgtype.ArrayDimension{{Length: 4, LowerBound: 1}}, 47 | } 48 | assert.Equal(t, want, rows) 49 | } 50 | 51 | { 52 | in1 := pgtype.Text{String: "foo", Status: pgtype.Present} 53 | in2 := []string{"qux", "qux"} 54 | in2Txt := newTextArray(in2) 55 | rows, err := q.FindLtreeInput(ctx, in1, in2) 56 | require.NoError(t, err) 57 | 
assert.Equal(t, FindLtreeInputRow{ 58 | Ltree: in1, 59 | TextArr: in2Txt, 60 | }, rows) 61 | } 62 | } 63 | 64 | // newTextArray creates a one dimensional text array from the string slice with 65 | // no null elements. 66 | func newTextArray(ss []string) pgtype.TextArray { 67 | elems := make([]pgtype.Text, len(ss)) 68 | for i, s := range ss { 69 | elems[i] = pgtype.Text{String: s, Status: pgtype.Present} 70 | } 71 | return pgtype.TextArray{ 72 | Elements: elems, 73 | Dimensions: []pgtype.ArrayDimension{{Length: int32(len(ss)), LowerBound: 1}}, 74 | Status: pgtype.Present, 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /example/ltree/schema.sql: -------------------------------------------------------------------------------- 1 | CREATE EXTENSION IF NOT EXISTS ltree; 2 | 3 | -- noinspection SqlResolve 4 | CREATE TABLE test (path ltree); 5 | -------------------------------------------------------------------------------- /example/nested/codegen_test.go: -------------------------------------------------------------------------------- 1 | package nested 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | 8 | "github.com/jschaf/pggen" 9 | "github.com/jschaf/pggen/internal/pgtest" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestGenerate_Go_Example_nested(t *testing.T) { 14 | conn, cleanupFunc := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 15 | defer cleanupFunc() 16 | 17 | tmpDir := t.TempDir() 18 | err := pggen.Generate( 19 | pggen.GenerateOptions{ 20 | ConnString: conn.Config().ConnString(), 21 | QueryFiles: []string{"query.sql"}, 22 | OutputDir: tmpDir, 23 | GoPackage: "nested", 24 | Language: pggen.LangGo, 25 | InlineParamCount: 2, 26 | TypeOverrides: map[string]string{ 27 | "int4": "int", 28 | "text": "string", 29 | }, 30 | }) 31 | if err != nil { 32 | t.Fatalf("Generate() example/nested: %s", err) 33 | } 34 | 35 | wantQueryFile := "query.sql.go" 36 | gotQueryFile := 
filepath.Join(tmpDir, "query.sql.go") 37 | assert.FileExists(t, gotQueryFile, "Generate() should emit query.sql.go") 38 | wantQueries, err := os.ReadFile(wantQueryFile) 39 | if err != nil { 40 | t.Fatalf("read wanted query.go.sql: %s", err) 41 | } 42 | gotQueries, err := os.ReadFile(gotQueryFile) 43 | if err != nil { 44 | t.Fatalf("read generated query.go.sql: %s", err) 45 | } 46 | assert.Equalf(t, string(wantQueries), string(gotQueries), 47 | "Got file %s; does not match contents of %s", 48 | gotQueryFile, wantQueryFile) 49 | } 50 | -------------------------------------------------------------------------------- /example/nested/query.sql: -------------------------------------------------------------------------------- 1 | -- name: ArrayNested2 :one 2 | SELECT 3 | ARRAY [ 4 | ROW ('img2', ROW (22, 22)::dimensions)::product_image_type, 5 | ROW ('img3', ROW (33, 33)::dimensions)::product_image_type 6 | ] AS images; 7 | 8 | -- name: Nested3 :many 9 | SELECT 10 | ROW ( 11 | 'name', -- name 12 | ROW ('img1', ROW (11, 11)::dimensions)::product_image_type, -- orig_image 13 | ARRAY [ --images 14 | ROW ('img2', ROW (22, 22)::dimensions)::product_image_type, 15 | ROW ('img3', ROW (33, 33)::dimensions)::product_image_type 16 | ] 17 | )::product_image_set_type; 18 | 19 | -------------------------------------------------------------------------------- /example/nested/query.sql_test.go: -------------------------------------------------------------------------------- 1 | package nested 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/jschaf/pggen/internal/pgtest" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestNewQuerier_ArrayNested2(t *testing.T) { 11 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 12 | defer cleanup() 13 | 14 | q := NewQuerier(conn) 15 | ctx := t.Context() 16 | 17 | want := []ProductImageType{ 18 | {Source: "img2", Dimensions: Dimensions{22, 22}}, 19 | {Source: "img3", Dimensions: Dimensions{33, 33}}, 20 | } 21 | 
t.Run("ArrayNested2", func(t *testing.T) { 22 | rows, err := q.ArrayNested2(ctx) 23 | if err != nil { 24 | t.Fatal(err) 25 | } 26 | assert.Equal(t, want, rows) 27 | }) 28 | } 29 | 30 | func TestNewQuerier_Nested3(t *testing.T) { 31 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 32 | defer cleanup() 33 | 34 | q := NewQuerier(conn) 35 | ctx := t.Context() 36 | 37 | want := []ProductImageSetType{ 38 | { 39 | Name: "name", 40 | OrigImage: ProductImageType{ 41 | Source: "img1", 42 | Dimensions: Dimensions{Width: 11, Height: 11}, 43 | }, 44 | Images: []ProductImageType{ 45 | {Source: "img2", Dimensions: Dimensions{22, 22}}, 46 | {Source: "img3", Dimensions: Dimensions{33, 33}}, 47 | }, 48 | }, 49 | } 50 | t.Run("Nested3", func(t *testing.T) { 51 | t.Skipf("https://github.com/jackc/pgx/issues/874") 52 | rows, err := q.Nested3(ctx) 53 | if err != nil { 54 | t.Fatal(err) 55 | } 56 | assert.Equal(t, want, rows) 57 | }) 58 | } 59 | -------------------------------------------------------------------------------- /example/nested/schema.sql: -------------------------------------------------------------------------------- 1 | CREATE TYPE dimensions AS ( 2 | width int4, 3 | height int4 4 | ); 5 | 6 | CREATE TYPE product_image_type AS ( 7 | source text, 8 | dimensions dimensions 9 | ); 10 | 11 | CREATE TYPE product_image_set_type AS ( 12 | name text, 13 | orig_image product_image_type, 14 | images product_image_type[] 15 | ); 16 | -------------------------------------------------------------------------------- /example/pgcrypto/codegen_test.go: -------------------------------------------------------------------------------- 1 | package pgcrypto 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | 8 | "github.com/jschaf/pggen" 9 | "github.com/jschaf/pggen/internal/pgtest" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestGenerate_Go_Example_Pgcrypto(t *testing.T) { 14 | conn, cleanupFunc := pgtest.NewPostgresSchema(t, 
[]string{"schema.sql"}) 15 | defer cleanupFunc() 16 | 17 | tmpDir := t.TempDir() 18 | err := pggen.Generate( 19 | pggen.GenerateOptions{ 20 | ConnString: conn.Config().ConnString(), 21 | QueryFiles: []string{"query.sql"}, 22 | OutputDir: tmpDir, 23 | GoPackage: "pgcrypto", 24 | Language: pggen.LangGo, 25 | InlineParamCount: 2, 26 | }) 27 | if err != nil { 28 | t.Fatalf("Generate() example/pgcrypto: %s", err) 29 | } 30 | 31 | wantQueryFile := "query.sql.go" 32 | gotQueryFile := filepath.Join(tmpDir, "query.sql.go") 33 | assert.FileExists(t, gotQueryFile, "Generate() should emit query.sql.go") 34 | wantQueries, err := os.ReadFile(wantQueryFile) 35 | if err != nil { 36 | t.Fatalf("read wanted query.go.sql: %s", err) 37 | } 38 | gotQueries, err := os.ReadFile(gotQueryFile) 39 | if err != nil { 40 | t.Fatalf("read generated query.go.sql: %s", err) 41 | } 42 | assert.Equalf(t, string(wantQueries), string(gotQueries), 43 | "Got file %s; does not match contents of %s", 44 | gotQueryFile, wantQueryFile) 45 | } 46 | -------------------------------------------------------------------------------- /example/pgcrypto/query.sql: -------------------------------------------------------------------------------- 1 | -- name: CreateUser :exec 2 | INSERT INTO "user" (email, pass) 3 | VALUES (pggen.arg('email'), crypt(pggen.arg('password'), gen_salt('bf'))); 4 | 5 | -- name: FindUser :one 6 | SELECT email, pass from "user" 7 | where email = pggen.arg('email'); 8 | -------------------------------------------------------------------------------- /example/pgcrypto/query.sql.go: -------------------------------------------------------------------------------- 1 | // Code generated by pggen. DO NOT EDIT. 2 | 3 | package pgcrypto 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "github.com/jackc/pgconn" 9 | "github.com/jackc/pgtype" 10 | "github.com/jackc/pgx/v4" 11 | ) 12 | 13 | // Querier is a typesafe Go interface backed by SQL queries. 
14 | type Querier interface { 15 | CreateUser(ctx context.Context, email string, password string) (pgconn.CommandTag, error) 16 | 17 | FindUser(ctx context.Context, email string) (FindUserRow, error) 18 | } 19 | 20 | var _ Querier = &DBQuerier{} 21 | 22 | type DBQuerier struct { 23 | conn genericConn // underlying Postgres transport to use 24 | types *typeResolver // resolve types by name 25 | } 26 | 27 | // genericConn is a connection like *pgx.Conn, pgx.Tx, or *pgxpool.Pool. 28 | type genericConn interface { 29 | Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) 30 | QueryRow(ctx context.Context, sql string, args ...any) pgx.Row 31 | Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) 32 | } 33 | 34 | // NewQuerier creates a DBQuerier that implements Querier. 35 | func NewQuerier(conn genericConn) *DBQuerier { 36 | return &DBQuerier{conn: conn, types: newTypeResolver()} 37 | } 38 | 39 | // typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. 40 | type typeResolver struct { 41 | connInfo *pgtype.ConnInfo // types by Postgres type name 42 | } 43 | 44 | func newTypeResolver() *typeResolver { 45 | ci := pgtype.NewConnInfo() 46 | return &typeResolver{connInfo: ci} 47 | } 48 | 49 | // findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. 50 | func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { 51 | typ, ok := tr.connInfo.DataTypeForName(name) 52 | if !ok { 53 | return 0, nil, false 54 | } 55 | v := pgtype.NewValue(typ.Value) 56 | return typ.OID, v.(pgtype.ValueTranscoder), true 57 | } 58 | 59 | // setValue sets the value of a ValueTranscoder to a value that should always 60 | // work and panics if it fails. 
61 | func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { 62 | if err := vt.Set(val); err != nil { 63 | panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) 64 | } 65 | return vt 66 | } 67 | 68 | const createUserSQL = `INSERT INTO "user" (email, pass) 69 | VALUES ($1, crypt($2, gen_salt('bf')));` 70 | 71 | // CreateUser implements Querier.CreateUser. 72 | func (q *DBQuerier) CreateUser(ctx context.Context, email string, password string) (pgconn.CommandTag, error) { 73 | ctx = context.WithValue(ctx, "pggen_query_name", "CreateUser") 74 | cmdTag, err := q.conn.Exec(ctx, createUserSQL, email, password) 75 | if err != nil { 76 | return cmdTag, fmt.Errorf("exec query CreateUser: %w", err) 77 | } 78 | return cmdTag, err 79 | } 80 | 81 | const findUserSQL = `SELECT email, pass from "user" 82 | where email = $1;` 83 | 84 | type FindUserRow struct { 85 | Email string `json:"email"` 86 | Pass string `json:"pass"` 87 | } 88 | 89 | // FindUser implements Querier.FindUser. 90 | func (q *DBQuerier) FindUser(ctx context.Context, email string) (FindUserRow, error) { 91 | ctx = context.WithValue(ctx, "pggen_query_name", "FindUser") 92 | row := q.conn.QueryRow(ctx, findUserSQL, email) 93 | var item FindUserRow 94 | if err := row.Scan(&item.Email, &item.Pass); err != nil { 95 | return item, fmt.Errorf("query FindUser: %w", err) 96 | } 97 | return item, nil 98 | } 99 | 100 | // textPreferrer wraps a pgtype.ValueTranscoder and sets the preferred encoding 101 | // format to text instead binary (the default). pggen uses the text format 102 | // when the OID is unknownOID because the binary format requires the OID. 103 | // Typically occurs for unregistered types. 104 | type textPreferrer struct { 105 | pgtype.ValueTranscoder 106 | typeName string 107 | } 108 | 109 | // PreferredParamFormat implements pgtype.ParamFormatPreferrer. 
110 | func (t textPreferrer) PreferredParamFormat() int16 { return pgtype.TextFormatCode } 111 | 112 | func (t textPreferrer) NewTypeValue() pgtype.Value { 113 | return textPreferrer{ValueTranscoder: pgtype.NewValue(t.ValueTranscoder).(pgtype.ValueTranscoder), typeName: t.typeName} 114 | } 115 | 116 | func (t textPreferrer) TypeName() string { 117 | return t.typeName 118 | } 119 | 120 | // unknownOID means we don't know the OID for a type. This is okay for decoding 121 | // because pgx call DecodeText or DecodeBinary without requiring the OID. For 122 | // encoding parameters, pggen uses textPreferrer if the OID is unknown. 123 | const unknownOID = 0 124 | -------------------------------------------------------------------------------- /example/pgcrypto/query.sql_test.go: -------------------------------------------------------------------------------- 1 | package pgcrypto 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | 7 | "github.com/jschaf/pggen/internal/pgtest" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestQuerier(t *testing.T) { 12 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 13 | defer cleanup() 14 | 15 | q := NewQuerier(conn) 16 | ctx := t.Context() 17 | 18 | _, err := q.CreateUser(ctx, "foo", "hunter2") 19 | if err != nil { 20 | t.Fatal(err) 21 | } 22 | 23 | row, err := q.FindUser(ctx, "foo") 24 | if err != nil { 25 | t.Fatal(err) 26 | } 27 | assert.Equal(t, "foo", row.Email, "email should match") 28 | if !strings.HasPrefix(row.Pass, "$2a$") { 29 | t.Fatalf("expected hashed password to have prefix $2a$; got %s", row.Pass) 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /example/pgcrypto/schema.sql: -------------------------------------------------------------------------------- 1 | CREATE EXTENSION IF NOT EXISTS pgcrypto; 2 | 3 | CREATE TABLE "user" ( 4 | email TEXT PRIMARY KEY, 5 | pass TEXT NOT NULL 6 | ); 7 | 
-------------------------------------------------------------------------------- /example/pggen_schema.sql: -------------------------------------------------------------------------------- 1 | -- This schema file exists solely so IntelliJ doesn't underline every 2 | -- pggen.arg() expression in squiggly red. 3 | CREATE SCHEMA pggen; 4 | 5 | -- pggen.arg defines a named parameter that's eventually compiled into a 6 | -- placeholder for a prepared query: $1, $2, etc. 7 | CREATE FUNCTION pggen.arg(param text) RETURNS any AS 8 | ''; 9 | -------------------------------------------------------------------------------- /example/separate_out_dir/README.md: -------------------------------------------------------------------------------- 1 | # Example using a separate output dir 2 | 3 | This example shows how to use the `--output-dir` flag to control where pggen 4 | writes query output files. -------------------------------------------------------------------------------- /example/separate_out_dir/alpha/alpha/query.sql: -------------------------------------------------------------------------------- 1 | -- name: AlphaNested :one 2 | SELECT 'alpha_nested' as output; 3 | 4 | -- name: AlphaCompositeArray :one 5 | SELECT ARRAY[ROW('key')]::alpha[]; -------------------------------------------------------------------------------- /example/separate_out_dir/alpha/query.sql: -------------------------------------------------------------------------------- 1 | -- name: Alpha :one 2 | SELECT 'alpha' as output; -------------------------------------------------------------------------------- /example/separate_out_dir/bravo/query.sql: -------------------------------------------------------------------------------- 1 | -- name: Bravo :one 2 | SELECT 'bravo' as output; -------------------------------------------------------------------------------- /example/separate_out_dir/out/alpha_query.sql.1.go: -------------------------------------------------------------------------------- 1 | // Code 
generated by pggen. DO NOT EDIT. 2 | 3 | package out 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | ) 9 | 10 | const alphaSQL = `SELECT 'alpha' as output;` 11 | 12 | // Alpha implements Querier.Alpha. 13 | func (q *DBQuerier) Alpha(ctx context.Context) (string, error) { 14 | ctx = context.WithValue(ctx, "pggen_query_name", "Alpha") 15 | row := q.conn.QueryRow(ctx, alphaSQL) 16 | var item string 17 | if err := row.Scan(&item); err != nil { 18 | return item, fmt.Errorf("query Alpha: %w", err) 19 | } 20 | return item, nil 21 | } 22 | -------------------------------------------------------------------------------- /example/separate_out_dir/out/bravo_query.sql.go: -------------------------------------------------------------------------------- 1 | // Code generated by pggen. DO NOT EDIT. 2 | 3 | package out 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | ) 9 | 10 | const bravoSQL = `SELECT 'bravo' as output;` 11 | 12 | // Bravo implements Querier.Bravo. 13 | func (q *DBQuerier) Bravo(ctx context.Context) (string, error) { 14 | ctx = context.WithValue(ctx, "pggen_query_name", "Bravo") 15 | row := q.conn.QueryRow(ctx, bravoSQL) 16 | var item string 17 | if err := row.Scan(&item); err != nil { 18 | return item, fmt.Errorf("query Bravo: %w", err) 19 | } 20 | return item, nil 21 | } 22 | -------------------------------------------------------------------------------- /example/separate_out_dir/out/codegen_test.go: -------------------------------------------------------------------------------- 1 | package out 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | 8 | "github.com/jschaf/pggen" 9 | "github.com/jschaf/pggen/internal/pgtest" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestGenerate_Go_Example_SeparateOutDir(t *testing.T) { 14 | conn, cleanupFunc := pgtest.NewPostgresSchema(t, []string{ 15 | "../schema.sql", 16 | }) 17 | defer cleanupFunc() 18 | 19 | tmpDir := t.TempDir() 20 | err := pggen.Generate( 21 | pggen.GenerateOptions{ 22 | 
ConnString: conn.Config().ConnString(), 23 | QueryFiles: []string{ 24 | "../alpha/query.sql", 25 | "../alpha/alpha/query.sql", 26 | "../bravo/query.sql", 27 | }, 28 | OutputDir: tmpDir, 29 | GoPackage: "out", 30 | Language: pggen.LangGo, 31 | }) 32 | if err != nil { 33 | t.Fatalf("Generate(): %s", err) 34 | } 35 | 36 | for _, file := range []string{ 37 | "alpha_query.sql.0.go", 38 | "alpha_query.sql.1.go", 39 | "bravo_query.sql.go", 40 | } { 41 | wantQueries, err := os.ReadFile(file) 42 | if err != nil { 43 | t.Fatalf("read wanted file %s: %s", file, err) 44 | } 45 | 46 | gotFile := filepath.Join(tmpDir, file) 47 | assert.FileExists(t, gotFile, "Generate() should emit "+file) 48 | gotQueries, err := os.ReadFile(gotFile) 49 | if err != nil { 50 | t.Fatalf("read generated %s: %s", file, err) 51 | } 52 | assert.Equalf(t, string(wantQueries), string(gotQueries), 53 | "Got file %s; does not match contents of file %s", 54 | gotFile, file) 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /example/separate_out_dir/out/query.sql_test.go: -------------------------------------------------------------------------------- 1 | package out 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | 8 | "github.com/jschaf/pggen/internal/pgtest" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestNewQuerier_FindAuthorByID(t *testing.T) { 13 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"../schema.sql"}) 14 | defer cleanup() 15 | 16 | q := NewQuerier(conn) 17 | 18 | t.Run("AlphaNested", func(t *testing.T) { 19 | got, err := q.AlphaNested(t.Context()) 20 | require.NoError(t, err) 21 | assert.Equal(t, "alpha_nested", got) 22 | }) 23 | 24 | t.Run("Alpha", func(t *testing.T) { 25 | got, err := q.Alpha(t.Context()) 26 | require.NoError(t, err) 27 | assert.Equal(t, "alpha", got) 28 | }) 29 | 30 | t.Run("Bravo", func(t *testing.T) { 31 | got, err := q.Bravo(t.Context()) 32 | require.NoError(t, 
err) 33 | assert.Equal(t, "bravo", got) 34 | }) 35 | } 36 | -------------------------------------------------------------------------------- /example/separate_out_dir/schema.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE alpha ( 2 | key text NOT NULL PRIMARY KEY 3 | ); 4 | 5 | CREATE TABLE bravo ( 6 | key text NOT NULL PRIMARY KEY 7 | ); 8 | -------------------------------------------------------------------------------- /example/slices/codegen_test.go: -------------------------------------------------------------------------------- 1 | package slices 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | 8 | "github.com/jschaf/pggen" 9 | "github.com/jschaf/pggen/internal/difftest" 10 | "github.com/jschaf/pggen/internal/pgtest" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestGenerate_Go_Example_Slices(t *testing.T) { 15 | conn, cleanupFunc := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 16 | defer cleanupFunc() 17 | 18 | tmpDir := t.TempDir() 19 | err := pggen.Generate( 20 | pggen.GenerateOptions{ 21 | ConnString: conn.Config().ConnString(), 22 | QueryFiles: []string{"query.sql"}, 23 | OutputDir: tmpDir, 24 | GoPackage: "slices", 25 | Language: pggen.LangGo, 26 | InlineParamCount: 2, 27 | TypeOverrides: map[string]string{ 28 | "_bool": "[]bool", 29 | "bool": "bool", 30 | "timestamp": "*time.Time", 31 | "_timestamp": "[]*time.Time", 32 | "timestamptz": "*time.Time", 33 | "_timestamptz": "[]time.Time", 34 | }, 35 | }) 36 | if err != nil { 37 | t.Fatalf("Generate() example/slices: %s", err) 38 | } 39 | 40 | wantQueryFile := "query.sql.go" 41 | gotQueryFile := filepath.Join(tmpDir, "query.sql.go") 42 | assert.FileExists(t, gotQueryFile, "Generate() should emit query.sql.go") 43 | wantQueries, err := os.ReadFile(wantQueryFile) 44 | if err != nil { 45 | t.Fatalf("read wanted query.go.sql: %s", err) 46 | } 47 | gotQueries, err := os.ReadFile(gotQueryFile) 48 | if err != nil { 49 | 
t.Fatalf("read generated query.go.sql: %s", err) 50 | } 51 | difftest.AssertSame(t, wantQueries, gotQueries) 52 | } 53 | -------------------------------------------------------------------------------- /example/slices/query.sql: -------------------------------------------------------------------------------- 1 | -- name: GetBools :one 2 | SELECT pggen.arg('data')::boolean[]; 3 | 4 | -- name: GetOneTimestamp :one 5 | SELECT pggen.arg('data')::timestamp; 6 | 7 | -- name: GetManyTimestamptzs :many 8 | SELECT * 9 | FROM unnest(pggen.arg('data')::timestamptz[]); 10 | 11 | -- name: GetManyTimestamps :many 12 | SELECT * 13 | FROM unnest(pggen.arg('data')::timestamp[]); 14 | -------------------------------------------------------------------------------- /example/slices/query.sql_test.go: -------------------------------------------------------------------------------- 1 | package slices 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/jschaf/pggen/internal/difftest" 8 | "github.com/jschaf/pggen/internal/pgtest" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func TestNewQuerier_GetBools(t *testing.T) { 13 | ctx := t.Context() 14 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 15 | defer cleanup() 16 | 17 | q := NewQuerier(conn) 18 | 19 | t.Run("GetBools", func(t *testing.T) { 20 | want := []bool{true, true, false} 21 | got, err := q.GetBools(ctx, want) 22 | require.NoError(t, err) 23 | difftest.AssertSame(t, want, got) 24 | }) 25 | } 26 | 27 | func TestNewQuerier_GetOneTimestamp(t *testing.T) { 28 | ctx := t.Context() 29 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 30 | defer cleanup() 31 | 32 | q := NewQuerier(conn) 33 | ts := time.Date(2020, 1, 1, 11, 11, 11, 0, time.UTC) 34 | 35 | t.Run("GetOneTimestamp", func(t *testing.T) { 36 | got, err := q.GetOneTimestamp(ctx, &ts) 37 | require.NoError(t, err) 38 | difftest.AssertSame(t, &ts, got) 39 | }) 40 | } 41 | 42 | func 
TestNewQuerier_GetManyTimestamptzs(t *testing.T) { 43 | ctx := t.Context() 44 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 45 | defer cleanup() 46 | 47 | q := NewQuerier(conn) 48 | ts1 := time.Date(2020, 1, 1, 11, 11, 11, 0, time.UTC) 49 | ts2 := time.Date(2022, 2, 2, 22, 22, 22, 0, time.UTC) 50 | 51 | t.Run("GetManyTimestamptzs", func(t *testing.T) { 52 | got, err := q.GetManyTimestamptzs(ctx, []time.Time{ts1, ts2}) 53 | require.NoError(t, err) 54 | difftest.AssertSame(t, []*time.Time{&ts1, &ts2}, got) 55 | }) 56 | } 57 | -------------------------------------------------------------------------------- /example/slices/schema.sql: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jschaf/pggen/632aa8e2e34733a36dbda854dcb4aa8a9b57f884/example/slices/schema.sql -------------------------------------------------------------------------------- /example/syntax/codegen_test.go: -------------------------------------------------------------------------------- 1 | package syntax 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | 8 | "github.com/jschaf/pggen" 9 | "github.com/jschaf/pggen/internal/pgtest" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestGenerate_Go_Example_Syntax(t *testing.T) { 14 | conn, cleanupFunc := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 15 | defer cleanupFunc() 16 | 17 | tmpDir := t.TempDir() 18 | err := pggen.Generate( 19 | pggen.GenerateOptions{ 20 | ConnString: conn.Config().ConnString(), 21 | QueryFiles: []string{ 22 | "query.sql", 23 | }, 24 | OutputDir: tmpDir, 25 | GoPackage: "syntax", 26 | Language: pggen.LangGo, 27 | InlineParamCount: 2, 28 | }) 29 | if err != nil { 30 | t.Fatalf("Generate() example/syntax: %s", err) 31 | } 32 | 33 | wantQueryFile := "query.sql.go" 34 | gotQueryFile := filepath.Join(tmpDir, "query.sql.go") 35 | assert.FileExists(t, gotQueryFile, "Generate() should emit query.sql.go") 36 | 
wantQueries, err := os.ReadFile(wantQueryFile) 37 | if err != nil { 38 | t.Fatalf("read wanted query.go.sql: %s", err) 39 | } 40 | gotQueries, err := os.ReadFile(gotQueryFile) 41 | if err != nil { 42 | t.Fatalf("read generated query.go.sql: %s", err) 43 | } 44 | assert.Equalf(t, string(wantQueries), string(gotQueries), 45 | "Got file %s; does not match contents of %s", 46 | gotQueryFile, wantQueryFile) 47 | } 48 | -------------------------------------------------------------------------------- /example/syntax/query.sql: -------------------------------------------------------------------------------- 1 | -- Query to test escaping in generated Go. 2 | -- name: Backtick :one 3 | SELECT '`'; 4 | 5 | -- Query to test escaping in generated Go. 6 | -- name: BacktickQuoteBacktick :one 7 | SELECT '`"`'; 8 | 9 | -- Query to test escaping in generated Go. 10 | -- name: BacktickNewline :one 11 | SELECT '` 12 | '; 13 | 14 | -- Query to test escaping in generated Go. 15 | -- name: BacktickDoubleQuote :one 16 | SELECT '`"'; 17 | 18 | -- Query to test escaping in generated Go. 19 | -- name: BacktickBackslashN :one 20 | SELECT '`\n'; 21 | 22 | -- Illegal names. 23 | -- name: IllegalNameSymbols :one 24 | SELECT '`\n' as "$", pggen.arg('@hello world!') as "foo.bar!@#$%&*()""--+"; 25 | 26 | -- Space after pggen.arg 27 | -- name: SpaceAfter :one 28 | SELECT pggen.arg ('space'); 29 | 30 | -- Enum named 123. 
31 | -- name: BadEnumName :one 32 | SELECT 'inconvertible_enum_name'::"123"; 33 | 34 | -- name: GoKeyword :one 35 | SELECT pggen.arg('go')::text; 36 | -------------------------------------------------------------------------------- /example/syntax/query.sql_test.go: -------------------------------------------------------------------------------- 1 | package syntax 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/jschaf/pggen/internal/pgtest" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestQuerier(t *testing.T) { 11 | conn, cleanup := pgtest.NewPostgresSchema(t, nil) 12 | defer cleanup() 13 | q := NewQuerier(conn) 14 | ctx := t.Context() 15 | 16 | val, err := q.Backtick(ctx) 17 | assert.NoError(t, err, "Backtick") 18 | assert.Equal(t, "`", val, "Backtick") 19 | 20 | val, err = q.BacktickDoubleQuote(ctx) 21 | assert.NoError(t, err, "BacktickDoubleQuote") 22 | assert.Equal(t, "`\"", val, "BacktickDoubleQuote") 23 | 24 | val, err = q.BacktickQuoteBacktick(ctx) 25 | assert.NoError(t, err, "BacktickQuoteBacktick") 26 | assert.Equal(t, "`\"`", val, "BacktickQuoteBacktick") 27 | 28 | val, err = q.BacktickNewline(ctx) 29 | assert.NoError(t, err, "BacktickNewline") 30 | assert.Equal(t, "`\n", val, "BacktickNewline") 31 | 32 | val, err = q.BacktickBackslashN(ctx) 33 | assert.NoError(t, err, "BacktickBackslashN") 34 | assert.Equal(t, "`\\n", val, "BacktickBackslashN") 35 | } 36 | -------------------------------------------------------------------------------- /example/syntax/schema.sql: -------------------------------------------------------------------------------- 1 | CREATE TYPE "123" AS ENUM ('inconvertible_enum_name', '', '111', '!!'); 2 | 3 | -------------------------------------------------------------------------------- /example/void/codegen_test.go: -------------------------------------------------------------------------------- 1 | package void 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | 8 | "github.com/jschaf/pggen" 9 | 
"github.com/jschaf/pggen/internal/pgtest" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestGenerate_Go_Example_void(t *testing.T) { 14 | conn, cleanupFunc := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 15 | defer cleanupFunc() 16 | 17 | tmpDir := t.TempDir() 18 | err := pggen.Generate( 19 | pggen.GenerateOptions{ 20 | ConnString: conn.Config().ConnString(), 21 | QueryFiles: []string{"query.sql"}, 22 | OutputDir: tmpDir, 23 | GoPackage: "void", 24 | Language: pggen.LangGo, 25 | InlineParamCount: 2, 26 | }) 27 | if err != nil { 28 | t.Fatalf("Generate() example/void: %s", err) 29 | } 30 | 31 | wantQueryFile := "query.sql.go" 32 | gotQueryFile := filepath.Join(tmpDir, "query.sql.go") 33 | assert.FileExists(t, gotQueryFile, "Generate() should emit query.sql.go") 34 | wantQueries, err := os.ReadFile(wantQueryFile) 35 | if err != nil { 36 | t.Fatalf("read wanted query.go.sql: %s", err) 37 | } 38 | gotQueries, err := os.ReadFile(gotQueryFile) 39 | if err != nil { 40 | t.Fatalf("read generated query.go.sql: %s", err) 41 | } 42 | assert.Equalf(t, string(wantQueries), string(gotQueries), 43 | "Got file %s; does not match contents of %s", 44 | gotQueryFile, wantQueryFile) 45 | } 46 | -------------------------------------------------------------------------------- /example/void/query.sql: -------------------------------------------------------------------------------- 1 | -- name: VoidOnly :exec 2 | SELECT void_fn(); 3 | 4 | -- name: VoidOnlyTwoParams :exec 5 | SELECT void_fn_two_params(pggen.arg('id'), 'text'); 6 | 7 | -- name: VoidTwo :one 8 | SELECT void_fn(), 'foo' as name; 9 | 10 | -- name: VoidThree :one 11 | SELECT void_fn(), 'foo' as foo, 'bar' as bar; 12 | 13 | -- name: VoidThree2 :many 14 | SELECT 'foo' as foo, void_fn(), void_fn(); 15 | -------------------------------------------------------------------------------- /example/void/query.sql_test.go: -------------------------------------------------------------------------------- 1 | package 
void 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/jschaf/pggen/internal/pgtest" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestQuerier(t *testing.T) { 11 | conn, cleanup := pgtest.NewPostgresSchema(t, []string{"schema.sql"}) 12 | defer cleanup() 13 | 14 | q := NewQuerier(conn) 15 | ctx := t.Context() 16 | 17 | if _, err := q.VoidOnly(ctx); err != nil { 18 | t.Fatal(err) 19 | } 20 | 21 | if _, err := q.VoidOnlyTwoParams(ctx, 33); err != nil { 22 | t.Fatal(err) 23 | } 24 | 25 | { 26 | row, err := q.VoidTwo(ctx) 27 | if err != nil { 28 | t.Fatal(err) 29 | } 30 | assert.Equal(t, "foo", row) 31 | } 32 | 33 | { 34 | row, err := q.VoidThree(ctx) 35 | if err != nil { 36 | t.Fatal(err) 37 | } 38 | assert.Equal(t, VoidThreeRow{Foo: "foo", Bar: "bar"}, row) 39 | } 40 | 41 | { 42 | foos, err := q.VoidThree2(ctx) 43 | if err != nil { 44 | t.Fatal(err) 45 | } 46 | assert.Equal(t, []string{"foo"}, foos) 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /example/void/schema.sql: -------------------------------------------------------------------------------- 1 | CREATE FUNCTION void_fn() RETURNS void AS 2 | $$ 3 | BEGIN 4 | END; 5 | $$ LANGUAGE plpgsql; 6 | 7 | -- noinspection SqlUnused 8 | CREATE FUNCTION void_fn_two_params(id int, comment text) RETURNS void AS 9 | $$ 10 | BEGIN 11 | END; 12 | $$ LANGUAGE plpgsql; 13 | -------------------------------------------------------------------------------- /generate_test.go: -------------------------------------------------------------------------------- 1 | package pggen 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | 8 | "github.com/jschaf/pggen/internal/pgtest" 9 | "github.com/jschaf/pggen/internal/texts" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestGenerate_Golang_Error(t *testing.T) { 14 | tests := []struct { 15 | name string 16 | schema string 17 | queries string 18 | wantErrMsg string 19 | }{ 20 | { 21 | name: 
"duplicate query name", 22 | schema: "", 23 | queries: texts.Dedent(` 24 | -- name: Foo :many 25 | SELECT 1; 26 | -- name: Foo :many 27 | SELECT 1; 28 | `), 29 | wantErrMsg: `duplicate query name Foo`, 30 | }, 31 | { 32 | name: "type error", 33 | schema: "", 34 | queries: texts.Dedent(` 35 | -- name: Foo :one 36 | SELECT encode(123, 'foo'::text); 37 | `), 38 | wantErrMsg: `function encode(integer, text) does not exist`, 39 | }, 40 | } 41 | for _, tt := range tests { 42 | t.Run(tt.name, func(t *testing.T) { 43 | conn, cleanupFunc := pgtest.NewPostgresSchemaString(t, tt.schema) 44 | defer cleanupFunc() 45 | tmpDir := t.TempDir() 46 | queryFile := filepath.Join(tmpDir, "query.sql") 47 | err := os.WriteFile(queryFile, []byte(tt.queries), 0o600) 48 | if err != nil { 49 | t.Fatal(err) 50 | } 51 | 52 | err = Generate( 53 | GenerateOptions{ 54 | ConnString: conn.Config().ConnString(), 55 | QueryFiles: []string{queryFile}, 56 | OutputDir: tmpDir, 57 | GoPackage: "error_test", 58 | Language: LangGo, 59 | }) 60 | 61 | if err == nil { 62 | t.Fatal("expected error from generate") 63 | } 64 | assert.Contains(t, err.Error(), tt.wantErrMsg, "error message should contain substring") 65 | }) 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/jschaf/pggen 2 | 3 | go 1.24.1 4 | 5 | require ( 6 | github.com/bmatcuk/doublestar v1.3.4 7 | github.com/docker/docker v28.0.1+incompatible 8 | github.com/docker/go-connections v0.5.0 9 | github.com/google/go-cmp v0.7.0 10 | github.com/jackc/pgconn v1.14.3 11 | github.com/jackc/pgproto3/v2 v2.3.3 12 | github.com/jackc/pgtype v1.14.4 13 | github.com/jackc/pgx/v4 v4.18.3 14 | github.com/peterbourgon/ff/v3 v3.4.0 15 | github.com/stretchr/testify v1.10.0 16 | golang.org/x/mod v0.24.0 17 | ) 18 | 19 | require ( 20 | github.com/Microsoft/go-winio v0.6.2 // indirect 21 | 
github.com/containerd/log v0.1.0 // indirect 22 | github.com/davecgh/go-spew v1.1.1 // indirect 23 | github.com/distribution/reference v0.6.0 // indirect 24 | github.com/docker/go-units v0.5.0 // indirect 25 | github.com/felixge/httpsnoop v1.0.4 // indirect 26 | github.com/go-logr/logr v1.4.2 // indirect 27 | github.com/go-logr/stdr v1.2.2 // indirect 28 | github.com/gogo/protobuf v1.3.2 // indirect 29 | github.com/jackc/chunkreader/v2 v2.0.1 // indirect 30 | github.com/jackc/pgio v1.0.0 // indirect 31 | github.com/jackc/pgpassfile v1.0.0 // indirect 32 | github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect 33 | github.com/moby/docker-image-spec v1.3.1 // indirect 34 | github.com/moby/term v0.5.2 // indirect 35 | github.com/morikuni/aec v1.0.0 // indirect 36 | github.com/opencontainers/go-digest v1.0.0 // indirect 37 | github.com/opencontainers/image-spec v1.1.1 // indirect 38 | github.com/pkg/errors v0.9.1 // indirect 39 | github.com/pmezard/go-difflib v1.0.0 // indirect 40 | github.com/shopspring/decimal v1.4.0 // indirect 41 | go.opentelemetry.io/auto/sdk v1.1.0 // indirect 42 | go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0 // indirect 43 | go.opentelemetry.io/otel v1.35.0 // indirect 44 | go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.35.0 // indirect 45 | go.opentelemetry.io/otel/metric v1.35.0 // indirect 46 | go.opentelemetry.io/otel/trace v1.35.0 // indirect 47 | golang.org/x/crypto v0.36.0 // indirect 48 | golang.org/x/net v0.37.0 // indirect 49 | golang.org/x/sys v0.31.0 // indirect 50 | golang.org/x/text v0.23.0 // indirect 51 | golang.org/x/time v0.11.0 // indirect 52 | gopkg.in/yaml.v3 v3.0.1 // indirect 53 | gotest.tools/v3 v3.5.2 // indirect 54 | ) 55 | -------------------------------------------------------------------------------- /internal/casing/casing.go: -------------------------------------------------------------------------------- 1 | package casing 2 | 3 | import ( 4 | 
"strings" 5 | "unicode" 6 | "unicode/utf8" 7 | ) 8 | 9 | // Caser converts strings from camel_case to UpperCamelCase. 10 | type Caser struct { 11 | acronyms map[string]string 12 | } 13 | 14 | func NewCaser() Caser { 15 | return Caser{ 16 | acronyms: map[string]string{}, 17 | } 18 | } 19 | 20 | // AddAcronyms adds each acronym that's specially handled in conversion 21 | // routines. 22 | func (cs Caser) AddAcronyms(acros map[string]string) { 23 | for a, b := range acros { 24 | cs.acronyms[a] = b 25 | } 26 | } 27 | 28 | // AddAcronym adds an acronym that's specially handled in conversion routines. 29 | func (cs Caser) AddAcronym(str, acronym string) { 30 | cs.acronyms[str] = acronym 31 | } 32 | 33 | // ToUpperGoIdent converts a string into a legal, capitalized Go identifier, 34 | // respecting registered acronyms. Returns the empty string if no conversion 35 | // is possible. 36 | func (cs Caser) ToUpperGoIdent(s string) string { 37 | san := sanitize(s) 38 | if san == "" { 39 | return "" 40 | } 41 | return cs.convert(san, cs.appendUpperCamel) 42 | } 43 | 44 | // ToLowerGoIdent converts a string into a legal, uncapitalized Go identifier, 45 | // respecting registered acronyms. Returns the empty string if no conversion 46 | // is possible. 47 | func (cs Caser) ToLowerGoIdent(s string) string { 48 | san := sanitize(s) 49 | if san == "" { 50 | return "" 51 | } 52 | con := cs.convert(san, cs.appendLowerCamel) 53 | switch con { 54 | case "func", "interface", "select", "case", "defer", "go", "map", "struct", 55 | "chan", "else", "goto", "package", "switch", "const", "fallthrough", "if", 56 | "range", "type", "continue", "for", "import", "return", "var": 57 | return con + "_" 58 | default: 59 | return con 60 | } 61 | } 62 | 63 | type converter func(*strings.Builder, []byte) 64 | 65 | // convert converts a string using converter for each sub-word while 66 | // respecting the registered acronyms. 
67 | func (cs Caser) convert(s string, converter converter) string { 68 | s = strings.TrimSpace(s) 69 | if len(s) == 0 { 70 | return s 71 | } 72 | sb := &strings.Builder{} 73 | sb.Grow(len(s)) 74 | chars := []byte(s) 75 | lo := 0 76 | isUpSpan := false 77 | for hi := 0; hi < len(chars); { 78 | ch, size := utf8.DecodeRune(chars[hi:]) 79 | switch { 80 | case ch == '_': 81 | isUpSpan = false 82 | converter(sb, chars[lo:hi]) 83 | lo = hi + size // skip underscore 84 | case unicode.IsUpper(ch): 85 | isUpSpan = lo+1 == hi || isUpSpan 86 | if !isUpSpan { 87 | converter(sb, chars[lo:hi]) 88 | lo = hi 89 | } 90 | default: 91 | isUpSpan = false 92 | } 93 | hi += size 94 | } 95 | converter(sb, chars[lo:]) 96 | return sb.String() 97 | } 98 | 99 | func (cs Caser) appendUpperCamel(sb *strings.Builder, chars []byte) { 100 | if len(chars) == 0 { 101 | return 102 | } 103 | if a, ok := cs.acronyms[strings.ToLower(string(chars))]; ok { 104 | sb.WriteString(a) 105 | return 106 | } 107 | firstCh, size := utf8.DecodeRune(chars) 108 | sb.WriteRune(unicode.ToUpper(firstCh)) 109 | sb.Write(chars[size:]) 110 | } 111 | 112 | func (cs Caser) appendLowerCamel(sb *strings.Builder, chars []byte) { 113 | if len(chars) == 0 { 114 | return 115 | } 116 | isFirst := sb.Len() == 0 117 | if a, ok := cs.acronyms[strings.ToLower(string(chars))]; ok { 118 | if isFirst { 119 | // First word should be uncapitalized. We don't know exactly how to do 120 | // that, so assume lower casing the acronym is sufficient. 121 | sb.WriteString(strings.ToLower(a)) 122 | } else { 123 | sb.WriteString(a) 124 | } 125 | return 126 | } 127 | firstCh, size := utf8.DecodeRune(chars) 128 | if isFirst { 129 | sb.WriteRune(unicode.ToLower(firstCh)) 130 | // Lowercase rest of first word. 
131 | for i := size; i < len(chars); { 132 | ch, n := utf8.DecodeRune(chars[i:]) 133 | sb.WriteRune(unicode.ToLower(ch)) 134 | i += n 135 | } 136 | } else { 137 | sb.WriteRune(unicode.ToUpper(firstCh)) 138 | sb.Write(chars[size:]) 139 | } 140 | } 141 | -------------------------------------------------------------------------------- /internal/casing/casing_test.go: -------------------------------------------------------------------------------- 1 | package casing 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestCaser_ToUpperGoIdent(t *testing.T) { 10 | tests := []struct { 11 | word string 12 | want string 13 | acronyms map[string]string 14 | }{ 15 | {"fooBar", "FooBar", nil}, 16 | {"FooBar", "FooBar", nil}, 17 | {"$", "", nil}, 18 | {"user.id", "UserID", map[string]string{"id": "ID"}}, 19 | {"$foo$bar", "FooBar", nil}, 20 | {"foo bar@@@!", "FooBar", nil}, 21 | {"12!foo bar@@@!", "FooBar", nil}, 22 | {"foo", "Foo", nil}, 23 | {"foo123", "Foo123", nil}, 24 | {"foo123bar", "Foo123bar", nil}, 25 | {"foo123bar_baz", "Foo123barBaz", nil}, 26 | {"foo", "FOO", map[string]string{"foo": "FOO"}}, 27 | {"foo_", "Foo", nil}, 28 | {"_foo_", "Foo", nil}, 29 | {"foo__", "Foo", nil}, 30 | {"foo__bar", "FooBar", nil}, 31 | {"foo_bar", "FooBar", nil}, 32 | {"foo_bar_baz", "FooBarBaz", nil}, 33 | {"foo_bar_baz", "FooBarBAZ", map[string]string{"baz": "BAZ"}}, 34 | {"foo_bar_baz", "FooBARBAZ", map[string]string{"bar": "BAR", "baz": "BAZ"}}, 35 | {"Ě", "Ě", nil}, 36 | {"ě", "Ě", nil}, 37 | {"Ěě_ě", "ĚěĚ", nil}, 38 | {"OIDs", "OIDs", map[string]string{"oids": "OIDs"}}, 39 | {"OIDsBar", "OIDsBar", map[string]string{"oids": "OIDs"}}, 40 | } 41 | for _, tt := range tests { 42 | t.Run(tt.word+"="+tt.want, func(t *testing.T) { 43 | caser := NewCaser() 44 | caser.AddAcronyms(tt.acronyms) 45 | got := caser.ToUpperGoIdent(tt.word) 46 | assert.Equal(t, tt.want, got) 47 | }) 48 | } 49 | } 50 | 51 | func TestCaser_ToLowerGoIdent(t *testing.T) { 52 | 
tests := []struct { 53 | word string 54 | want string 55 | acronyms map[string]string 56 | }{ 57 | {"fooBar", "fooBar", nil}, 58 | {"FooBar", "fooBar", nil}, 59 | {"$", "", nil}, 60 | {"user.id", "userID", map[string]string{"id": "ID"}}, 61 | {"$foo$bar", "fooBar", nil}, 62 | {"foo bar@@@!", "fooBar", nil}, 63 | {"12!foo bar@@@!", "fooBar", nil}, 64 | {"foo", "foo", nil}, 65 | {"foo123", "foo123", nil}, 66 | {"foo123bar", "foo123bar", nil}, 67 | {"foo123bar_baz", "foo123barBaz", nil}, 68 | {"foo", "foo", map[string]string{"foo": "FOO"}}, 69 | {"foo_", "foo", nil}, 70 | {"_foo_", "foo", nil}, 71 | {"foo__", "foo", nil}, 72 | {"foo__bar", "fooBar", nil}, 73 | {"foo_bar", "fooBar", nil}, 74 | {"foo_bar_baz", "fooBarBaz", nil}, 75 | {"foo_bar_baz", "fooBarBAZ", map[string]string{"baz": "BAZ"}}, 76 | {"foo_bar_baz", "fooBARBAZ", map[string]string{"bar": "BAR", "baz": "BAZ"}}, 77 | {"Ě", "ě", nil}, 78 | {"ě", "ě", nil}, 79 | {"Ěě_ě", "ěěĚ", nil}, 80 | {"if", "if_", nil}, 81 | {"type", "type_", nil}, 82 | {"OIDs", "oids", nil}, 83 | {"OIDsBar", "oidsBar", nil}, 84 | {"FindOIDByVal", "findOIDByVal", map[string]string{"oid": "OID"}}, 85 | } 86 | for _, tt := range tests { 87 | t.Run(tt.word+"="+tt.want, func(t *testing.T) { 88 | caser := NewCaser() 89 | caser.AddAcronyms(tt.acronyms) 90 | got := caser.ToLowerGoIdent(tt.word) 91 | assert.Equal(t, tt.want, got) 92 | }) 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /internal/casing/sanitize.go: -------------------------------------------------------------------------------- 1 | package casing 2 | 3 | import ( 4 | "strings" 5 | "unicode" 6 | "unicode/utf8" 7 | ) 8 | 9 | // sanitize replaces a string with a version safe for using as a Go 10 | // identifier. A Go identifier begins with a unicode letter or underscore and 11 | // is followed by 0 or more unicode letters or digits, or underscores. 12 | // Replaces illegal runes with underscores. 
Skips leading characters that aren't 13 | // a letter or underscore. 14 | func sanitize(s string) string { 15 | sb := &strings.Builder{} 16 | sb.Grow(len(s)) 17 | var firstLetter rune 18 | secondCharIdx := -1 19 | // Find first legal starting char, a letter or an underscore. 20 | for idx, ch := range s { 21 | if unicode.IsLetter(ch) || ch == '_' { 22 | firstLetter = ch 23 | secondCharIdx = idx + utf8.RuneLen(ch) 24 | break 25 | } 26 | } 27 | if secondCharIdx == -1 { 28 | return "" 29 | } 30 | 31 | sb.WriteRune(firstLetter) 32 | prevUnderscore := firstLetter == '_' 33 | for _, ch := range s[secondCharIdx:] { 34 | switch { 35 | case unicode.IsLetter(ch) || unicode.IsDigit(ch): 36 | sb.WriteRune(ch) 37 | prevUnderscore = false 38 | default: 39 | if !prevUnderscore { 40 | sb.WriteRune('_') 41 | } 42 | prevUnderscore = true 43 | } 44 | } 45 | if sb.Len() == 0 { 46 | return "" 47 | } 48 | return sb.String() 49 | } 50 | -------------------------------------------------------------------------------- /internal/casing/sanitize_test.go: -------------------------------------------------------------------------------- 1 | package casing 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestSanitize(t *testing.T) { 10 | tests := []struct { 11 | str string 12 | want string 13 | }{ 14 | {"a", "a"}, 15 | {"a.", "a_"}, 16 | {"a.b", "a_b"}, 17 | {"", ""}, 18 | {"1", ""}, 19 | {"1abc", "abc"}, 20 | {"abc@123", "abc_123"}, 21 | {"abc@!123", "abc_123"}, 22 | {"T食", "T食"}, 23 | } 24 | for _, tt := range tests { 25 | t.Run(tt.str, func(t *testing.T) { 26 | got := sanitize(tt.str) 27 | assert.Equal(t, tt.want, got) 28 | }) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /internal/codegen/common.go: -------------------------------------------------------------------------------- 1 | // Package codegen contains common code shared between codegen and language 2 | // specific code generators. 
Separate package to avoid dependency cycles. 3 | package codegen 4 | 5 | import ( 6 | "github.com/jschaf/pggen/internal/pginfer" 7 | ) 8 | 9 | // QueryFile represents all SQL queries from a single file. 10 | type QueryFile struct { 11 | SourcePath string // absolute path to the source SQL query file 12 | Queries []pginfer.TypedQuery // the typed queries 13 | } 14 | -------------------------------------------------------------------------------- /internal/codegen/golang/declarer_enum.go: -------------------------------------------------------------------------------- 1 | package golang 2 | 3 | import ( 4 | "strconv" 5 | "strings" 6 | 7 | "github.com/jschaf/pggen/internal/codegen/golang/gotype" 8 | ) 9 | 10 | func NameEnumTranscoderFunc(typ *gotype.EnumType) string { 11 | return "new" + typ.Name + "Enum" 12 | } 13 | 14 | // EnumTypeDeclarer declares a new string type and the const values to map to a 15 | // Postgres enum. 16 | type EnumTypeDeclarer struct { 17 | enum *gotype.EnumType 18 | } 19 | 20 | func NewEnumTypeDeclarer(enum *gotype.EnumType) EnumTypeDeclarer { 21 | return EnumTypeDeclarer{enum: enum} 22 | } 23 | 24 | func (e EnumTypeDeclarer) DedupeKey() string { 25 | return "enum_type::" + e.enum.Name 26 | } 27 | 28 | func (e EnumTypeDeclarer) Declare(string) (string, error) { 29 | sb := &strings.Builder{} 30 | // Doc string. 31 | if e.enum.PgEnum.Name != "" { 32 | sb.WriteString("// ") 33 | sb.WriteString(e.enum.Name) 34 | sb.WriteString(" represents the Postgres enum ") 35 | sb.WriteString(strconv.Quote(e.enum.PgEnum.Name)) 36 | sb.WriteString(".\n") 37 | } 38 | // Type declaration. 39 | sb.WriteString("type ") 40 | sb.WriteString(e.enum.Name) 41 | sb.WriteString(" string\n\n") 42 | // Const enum values. 
43 | sb.WriteString("const (\n") 44 | nameLen := 0 45 | for _, label := range e.enum.Labels { 46 | if len(label) > nameLen { 47 | nameLen = len(label) 48 | } 49 | } 50 | for i, label := range e.enum.Labels { 51 | sb.WriteString("\t") 52 | sb.WriteString(label) 53 | sb.WriteString(strings.Repeat(" ", nameLen+1-len(label))) 54 | sb.WriteString(e.enum.Name) 55 | sb.WriteString(` = `) 56 | sb.WriteString(strconv.Quote(e.enum.Values[i])) 57 | sb.WriteByte('\n') 58 | } 59 | sb.WriteString(")\n\n") 60 | // Stringer 61 | dispatcher := strings.ToLower(e.enum.Name)[0] 62 | sb.WriteString("func (") 63 | sb.WriteByte(dispatcher) 64 | sb.WriteByte(' ') 65 | sb.WriteString(e.enum.Name) 66 | sb.WriteString(") String() string { return string(") 67 | sb.WriteByte(dispatcher) 68 | sb.WriteString(") }") 69 | return sb.String(), nil 70 | } 71 | 72 | // EnumTranscoderDeclarer declares a new Go function that creates a pgx decoder 73 | // for the Postgres type represented by the gotype.EnumType. 74 | type EnumTranscoderDeclarer struct { 75 | typ *gotype.EnumType 76 | } 77 | 78 | func NewEnumTranscoderDeclarer(enum *gotype.EnumType) EnumTranscoderDeclarer { 79 | return EnumTranscoderDeclarer{typ: enum} 80 | } 81 | 82 | func (e EnumTranscoderDeclarer) DedupeKey() string { 83 | return "enum_decoder::" + e.typ.Name 84 | } 85 | 86 | func (e EnumTranscoderDeclarer) Declare(string) (string, error) { 87 | sb := &strings.Builder{} 88 | funcName := NameEnumTranscoderFunc(e.typ) 89 | 90 | // Doc comment 91 | sb.WriteString("// ") 92 | sb.WriteString(funcName) 93 | sb.WriteString(" creates a new pgtype.ValueTranscoder for the\n") 94 | sb.WriteString("// Postgres enum type '") 95 | sb.WriteString(e.typ.PgEnum.Name) 96 | sb.WriteString("'.\n") 97 | 98 | // Function signature 99 | sb.WriteString("func ") 100 | sb.WriteString(funcName) 101 | sb.WriteString("() pgtype.ValueTranscoder {\n\t") 102 | 103 | // NewEnumType call 104 | sb.WriteString("return pgtype.NewEnumType(\n\t\t") 105 | 
sb.WriteString(strconv.Quote(e.typ.PgEnum.Name)) 106 | sb.WriteString(",\n\t\t") 107 | sb.WriteString(`[]string{`) 108 | for _, label := range e.typ.Labels { 109 | sb.WriteString("\n\t\t\t") 110 | sb.WriteString("string(") 111 | sb.WriteString(label) 112 | sb.WriteString("),") 113 | } 114 | sb.WriteString("\n\t\t") 115 | sb.WriteString("},") 116 | sb.WriteString("\n\t") 117 | sb.WriteString(")") 118 | sb.WriteString("\n") 119 | sb.WriteString("}") 120 | 121 | return sb.String(), nil 122 | } 123 | -------------------------------------------------------------------------------- /internal/codegen/golang/emitter.go: -------------------------------------------------------------------------------- 1 | package golang 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path/filepath" 7 | "strconv" 8 | "text/template" 9 | 10 | "github.com/jschaf/pggen/internal/errs" 11 | ) 12 | 13 | // Emitter writes a templated query file to a file. 14 | type Emitter struct { 15 | outDir string 16 | tmpl *template.Template 17 | } 18 | 19 | func NewEmitter(outDir string, tmpl *template.Template) Emitter { 20 | return Emitter{outDir: outDir, tmpl: tmpl} 21 | } 22 | 23 | // EmitAllQueryFiles emits a query file for each TemplatedFile. Ensure that 24 | // emitted files don't clash by prefixing with the parent directory if 25 | // necessary. 26 | func (em Emitter) EmitAllQueryFiles(tfs []TemplatedFile) (mErr error) { 27 | outs := em.chooseOutputFiles(tfs) 28 | for i, tf := range tfs { 29 | if err := em.emitQueryFile(outs[i], tf); err != nil { 30 | return err 31 | } 32 | } 33 | return nil 34 | } 35 | 36 | // chooseOutputFiles returns the output paths to use for each TemplatedFile. 37 | // Necessary for cases like "alpha/query.sql" and "bravo/query.sql" where 38 | // we can't simply use "query.sql.go". 39 | func (em Emitter) chooseOutputFiles(tfs []TemplatedFile) []string { 40 | // Check for any basename collisions. 
41 | seenBaseNames := make(map[string]struct{}, len(tfs)) 42 | hasBaseCollision := false 43 | for _, tf := range tfs { 44 | base := filepath.Base(tf.SourcePath) 45 | if _, ok := seenBaseNames[base]; ok { 46 | hasBaseCollision = true 47 | } 48 | seenBaseNames[base] = struct{}{} 49 | } 50 | 51 | // If no base collision, just use base names. If no collisions, use the 52 | // basename, like "query.sql" => "query.go.sql". 53 | if !hasBaseCollision { 54 | outNames := make([]string, len(tfs)) 55 | for i, tf := range tfs { 56 | out := filepath.Base(tf.SourcePath) 57 | out += ".go" 58 | outNames[i] = out 59 | } 60 | return outNames 61 | } 62 | 63 | // If there's a basename collision, check for collisions after prefixing the 64 | // parent directory name. If there's still a collision we'll make each name 65 | // unique with a numeric literal suffix. Occurs with a file pattern like: 66 | // "alpha/query.sql" and "parent/alpha/query.sql". 67 | outNames := make([]string, len(tfs)) // names to return 68 | usedNames := make(map[string]int) // next int to use for a collision on key 69 | firstIdx := make(map[string]int) // first index a name was used 70 | for i, tf := range tfs { 71 | out := filepath.Base(tf.SourcePath) 72 | parent := filepath.Base(filepath.Dir(tf.SourcePath)) 73 | out = parent + "_" + out 74 | n, ok := usedNames[out] 75 | usedNames[out] = n + 1 76 | if ok { 77 | // We've seen this entry already. 78 | firstI := firstIdx[out] 79 | outNames[i] = out + "." + strconv.Itoa(n) 80 | if n == 1 { 81 | // Add suffix to first entry since we didn't do it the first time around 82 | // because we didn't know if it had a collision. 83 | outNames[firstI] += ".0" 84 | } 85 | } else { 86 | // First time seeing an entry. 87 | outNames[i] = out 88 | firstIdx[out] = i 89 | } 90 | } 91 | for i := range outNames { 92 | outNames[i] += ".go" 93 | } 94 | return outNames 95 | } 96 | 97 | // emitQueryFile emits a single query file. 
98 | func (em Emitter) emitQueryFile(outRelPath string, tf TemplatedFile) (mErr error) { 99 | out := filepath.Join(em.outDir, outRelPath) 100 | file, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644) 101 | defer errs.Capture(&mErr, file.Close, "close emit query file") 102 | if err != nil { 103 | return fmt.Errorf("open generated query file for writing: %w", err) 104 | } 105 | if err := em.tmpl.ExecuteTemplate(file, "gen_query", tf); err != nil { 106 | return fmt.Errorf("execute generated query file template %s: %w", out, err) 107 | } 108 | return nil 109 | } 110 | -------------------------------------------------------------------------------- /internal/codegen/golang/generate.go: -------------------------------------------------------------------------------- 1 | package golang 2 | 3 | import ( 4 | _ "embed" 5 | "fmt" 6 | "path/filepath" 7 | "sort" 8 | "text/template" 9 | 10 | "github.com/jschaf/pggen/internal/casing" 11 | "github.com/jschaf/pggen/internal/codegen" 12 | ) 13 | 14 | // GenerateOptions are options to control generated Go output. 15 | type GenerateOptions struct { 16 | GoPkg string 17 | OutputDir string 18 | // A map of lowercase acronyms to the upper case equivalent, like: 19 | // "api" => "API". 20 | Acronyms map[string]string 21 | // A map from a Postgres type name to a fully qualified Go type. 22 | TypeOverrides map[string]string 23 | // How many params to inline when calling querier methods. 24 | // Set to 0 to always create a struct for params. 25 | InlineParamCount int 26 | } 27 | 28 | // Generate emits generated Go files for each of the queryFiles. 
29 | func Generate(opts GenerateOptions, queryFiles []codegen.QueryFile) error { 30 | pkgName := opts.GoPkg 31 | if pkgName == "" { 32 | pkgName = filepath.Base(opts.OutputDir) 33 | } 34 | caser := casing.NewCaser() 35 | caser.AddAcronyms(opts.Acronyms) 36 | templater := NewTemplater(TemplaterOpts{ 37 | Caser: caser, 38 | Resolver: NewTypeResolver(caser, opts.TypeOverrides), 39 | Pkg: pkgName, 40 | InlineParamCount: opts.InlineParamCount, 41 | }) 42 | templatedFiles, err := templater.TemplateAll(queryFiles) 43 | if err != nil { 44 | return fmt.Errorf("template all: %w", err) 45 | } 46 | 47 | // Order for reproducible results. 48 | sort.Slice(templatedFiles, func(i, j int) bool { 49 | return templatedFiles[i].SourcePath < templatedFiles[j].SourcePath 50 | }) 51 | 52 | // Link each child to the package. Necessary so the leader can define all 53 | // Querier methods. 54 | pkg := TemplatedPackage{Files: templatedFiles} 55 | for i := range templatedFiles { 56 | templatedFiles[i].Pkg = pkg 57 | } 58 | 59 | tmpl, err := parseQueryTemplate() 60 | if err != nil { 61 | return fmt.Errorf("parse generated Go code template: %w", err) 62 | } 63 | emitter := NewEmitter(opts.OutputDir, tmpl) 64 | if err := emitter.EmitAllQueryFiles(templatedFiles); err != nil { 65 | return fmt.Errorf("emit generated Go code: %w", err) 66 | } 67 | return nil 68 | } 69 | 70 | //go:embed query.gotemplate 71 | var queryTemplate string 72 | 73 | func parseQueryTemplate() (*template.Template, error) { 74 | tmpl, err := template.New("gen_query").Parse(queryTemplate) 75 | if err != nil { 76 | return nil, fmt.Errorf("parse query.gotemplate: %w", err) 77 | } 78 | return tmpl, nil 79 | } 80 | -------------------------------------------------------------------------------- /internal/codegen/golang/gotype/predicates.go: -------------------------------------------------------------------------------- 1 | package gotype 2 | 3 | // HasCompositeType returns true if t or any of t's descendants (for array and 4 | // 
composite types) is a composite type. 5 | func HasCompositeType(t Type) bool { 6 | switch t := t.(type) { 7 | case *ArrayType: 8 | return HasCompositeType(t.Elem) 9 | case *CompositeType: 10 | return true 11 | default: 12 | return false 13 | } 14 | } 15 | 16 | // HasArrayType returns true if t or any of t's descendants (for array and 17 | // composite types) is an array type. 18 | func HasArrayType(t Type) bool { 19 | switch t := t.(type) { 20 | case *ArrayType: 21 | return true 22 | case *CompositeType: 23 | for _, typ := range t.FieldTypes { 24 | if ok := HasArrayType(typ); ok { 25 | return true 26 | } 27 | } 28 | return false 29 | default: 30 | return false 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /internal/codegen/golang/gotype/predicates_test.go: -------------------------------------------------------------------------------- 1 | package gotype 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestHasCompositeType(t *testing.T) { 8 | tests := []struct { 9 | name string 10 | typ Type 11 | want bool 12 | }{ 13 | {"enum", &EnumType{}, false}, 14 | {"void", &VoidType{}, false}, 15 | {"opaque", &OpaqueType{}, false}, 16 | {"empty array", &ArrayType{}, false}, 17 | {"array with composite", &ArrayType{Elem: &CompositeType{}}, true}, 18 | {"composite", &CompositeType{}, true}, 19 | } 20 | for _, tt := range tests { 21 | t.Run(tt.name, func(t *testing.T) { 22 | if got := HasCompositeType(tt.typ); got != tt.want { 23 | t.Errorf("HasCompositeType() = %v, want %v", got, tt.want) 24 | } 25 | }) 26 | } 27 | } 28 | 29 | func TestHasArrayType(t *testing.T) { 30 | tests := []struct { 31 | name string 32 | typ Type 33 | want bool 34 | }{ 35 | {"enum", &EnumType{}, false}, 36 | {"void", &VoidType{}, false}, 37 | {"opaque", &OpaqueType{}, false}, 38 | {"empty array", &ArrayType{}, true}, 39 | {"array with composite", &ArrayType{Elem: &CompositeType{}}, true}, 40 | {"empty composite", &CompositeType{}, false}, 41 | 
{"composite with array", &CompositeType{FieldTypes: []Type{&ArrayType{}}}, true}, 42 | } 43 | for _, tt := range tests { 44 | t.Run(tt.name, func(t *testing.T) { 45 | if got := HasArrayType(tt.typ); got != tt.want { 46 | t.Errorf("HasArrayType() = %v, want %v", got, tt.want) 47 | } 48 | }) 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /internal/codegen/golang/gotype/types_test.go: -------------------------------------------------------------------------------- 1 | package gotype 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/google/go-cmp/cmp" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestMustParseKnownType(t *testing.T) { 11 | tests := []struct { 12 | qualType string 13 | want Type 14 | }{ 15 | { 16 | qualType: "string", 17 | want: &OpaqueType{Name: "string"}, 18 | }, 19 | { 20 | qualType: "*string", 21 | want: &PointerType{Elem: &OpaqueType{Name: "string"}}, 22 | }, 23 | { 24 | qualType: "[]string", 25 | want: &ArrayType{Elem: &OpaqueType{Name: "string"}}, 26 | }, 27 | { 28 | qualType: "[]*string", 29 | want: &ArrayType{Elem: &PointerType{Elem: &OpaqueType{Name: "string"}}}, 30 | }, 31 | { 32 | qualType: "time.Time", 33 | want: &ImportType{ 34 | PkgPath: "time", 35 | Type: &OpaqueType{Name: "Time"}, 36 | }, 37 | }, 38 | { 39 | qualType: "[]time.Time", 40 | want: &ArrayType{ 41 | Elem: &ImportType{PkgPath: "time", Type: &OpaqueType{Name: "Time"}}, 42 | }, 43 | }, 44 | { 45 | qualType: "[]*time.Time", 46 | want: &ArrayType{ 47 | Elem: &PointerType{ 48 | Elem: &ImportType{PkgPath: "time", Type: &OpaqueType{Name: "Time"}}, 49 | }, 50 | }, 51 | }, 52 | { 53 | qualType: "[]util/custom/times.Interval", 54 | want: &ArrayType{ 55 | Elem: &ImportType{PkgPath: "util/custom/times", Type: &OpaqueType{Name: "Interval"}}, 56 | }, 57 | }, 58 | } 59 | for _, tt := range tests { 60 | t.Run(tt.qualType, func(t *testing.T) { 61 | got := MustParseOpaqueType(tt.qualType) 62 | if diff := cmp.Diff(tt.want, 
got); diff != "" { 63 | t.Errorf("mismatch (-want +got):\n%s", diff) 64 | } 65 | }) 66 | } 67 | } 68 | 69 | func TestQualifyType(t *testing.T) { 70 | tests := []struct { 71 | name string 72 | typ Type 73 | otherPkg string 74 | want string 75 | }{ 76 | { 77 | name: "string", 78 | typ: &OpaqueType{Name: "string"}, 79 | otherPkg: "example.com/foo", 80 | want: "string", 81 | }, 82 | { 83 | name: "[]string", 84 | typ: &ArrayType{Elem: &OpaqueType{Name: "string"}}, 85 | otherPkg: "example.com/foo", 86 | want: "[]string", 87 | }, 88 | { 89 | name: "[]*string", 90 | typ: &ArrayType{Elem: &PointerType{Elem: &OpaqueType{Name: "string"}}}, 91 | otherPkg: "example.com/foo", 92 | want: "[]*string", 93 | }, 94 | { 95 | name: "foo.com/qux.Bar - example.com/foo", 96 | typ: &ImportType{PkgPath: "foo.com/qux", Type: &OpaqueType{Name: "Bar"}}, 97 | otherPkg: "example.com/foo", 98 | want: "qux.Bar", 99 | }, 100 | { 101 | name: "[]foo.com/qux.Bar - example.com/foo", 102 | typ: &ArrayType{Elem: &ImportType{PkgPath: "foo.com/qux", Type: &OpaqueType{Name: "Bar"}}}, 103 | otherPkg: "example.com/foo", 104 | want: "[]qux.Bar", 105 | }, 106 | { 107 | name: "[]example.com/qux.Bar - example.com/foo", 108 | typ: &ArrayType{Elem: &ImportType{PkgPath: "example.com/qux", Type: &OpaqueType{Name: "Bar"}}}, 109 | otherPkg: "example.com/foo", 110 | want: "[]qux.Bar", 111 | }, 112 | { 113 | name: "[]example.com/foo.Bar - example.com/foo", 114 | typ: &ArrayType{Elem: &ImportType{PkgPath: "example.com/foo", Type: &OpaqueType{Name: "Bar"}}}, 115 | otherPkg: "example.com/foo", 116 | want: "[]Bar", 117 | }, 118 | } 119 | 120 | for _, tt := range tests { 121 | t.Run(tt.name, func(t *testing.T) { 122 | got := QualifyType(tt.typ, tt.otherPkg) 123 | assert.Equal(t, tt.want, got) 124 | }) 125 | } 126 | } 127 | -------------------------------------------------------------------------------- /internal/codegen/golang/import_set.go: -------------------------------------------------------------------------------- 1 | 
package golang

import (
	"sort"

	"github.com/jschaf/pggen/internal/codegen/golang/gotype"
)

// ImportSet collects the distinct package paths that one generated Go file
// must import.
type ImportSet struct {
	imports map[string]struct{}
}

// NewImportSet returns an empty ImportSet.
func NewImportSet() *ImportSet {
	const initialCap = 4
	return &ImportSet{imports: make(map[string]struct{}, initialCap)}
}

// AddPackage records the fully qualified package path p, such as
// "github.com/jschaf/pggen/foo". Recording the same path twice is harmless.
func (s *ImportSet) AddPackage(p string) {
	s.imports[p] = struct{}{}
}

// AddType records the package path reported by typ.Import and, when typ is a
// composite type, recursively the package paths of each of its field types.
func (s *ImportSet) AddType(typ gotype.Type) {
	s.AddPackage(typ.Import())
	if comp, isComposite := typ.(*gotype.CompositeType); isComposite {
		for _, field := range comp.FieldTypes {
			s.AddType(field)
		}
	}
}

// SortedPackages returns a newly allocated, lexically sorted slice of the
// recorded package paths, omitting the empty path, ready to be written into an
// import declaration.
func (s *ImportSet) SortedPackages() []string {
	pkgs := make([]string, 0, len(s.imports))
	for p := range s.imports {
		if p == "" {
			// An empty path means "no import needed" (e.g. builtin types).
			continue
		}
		pkgs = append(pkgs, p)
	}
	sort.Strings(pkgs)
	return pkgs
}

--------------------------------------------------------------------------------
/internal/codegen/golang/testdata/declarer_composite.input.golden:
--------------------------------------------------------------------------------
// SomeTable represents the Postgres composite type "some_table".
type SomeTable struct {
	Foo    int16       `json:"foo"`
	BarBaz pgtype.Text `json:"bar_baz"`
}

// typeResolver looks up the pgtype.ValueTranscoder by Postgres type name.
8 | type typeResolver struct { 9 | connInfo *pgtype.ConnInfo // types by Postgres type name 10 | } 11 | 12 | func newTypeResolver() *typeResolver { 13 | ci := pgtype.NewConnInfo() 14 | return &typeResolver{connInfo: ci} 15 | } 16 | 17 | // findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. 18 | func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { 19 | typ, ok := tr.connInfo.DataTypeForName(name) 20 | if !ok { 21 | return 0, nil, false 22 | } 23 | v := pgtype.NewValue(typ.Value) 24 | return typ.OID, v.(pgtype.ValueTranscoder), true 25 | } 26 | 27 | // setValue sets the value of a ValueTranscoder to a value that should always 28 | // work and panics if it fails. 29 | func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { 30 | if err := vt.Set(val); err != nil { 31 | panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) 32 | } 33 | return vt 34 | } 35 | 36 | type compositeField struct { 37 | name string // name of the field 38 | typeName string // Postgres type name 39 | defaultVal pgtype.ValueTranscoder // default value to use 40 | } 41 | 42 | func (tr *typeResolver) newCompositeValue(name string, fields ...compositeField) pgtype.ValueTranscoder { 43 | if _, val, ok := tr.findValue(name); ok { 44 | return val 45 | } 46 | fs := make([]pgtype.CompositeTypeField, len(fields)) 47 | vals := make([]pgtype.ValueTranscoder, len(fields)) 48 | isBinaryOk := true 49 | for i, field := range fields { 50 | oid, val, ok := tr.findValue(field.typeName) 51 | if !ok { 52 | oid = unknownOID 53 | val = field.defaultVal 54 | } 55 | isBinaryOk = isBinaryOk && oid != unknownOID 56 | fs[i] = pgtype.CompositeTypeField{Name: field.name, OID: oid} 57 | vals[i] = val 58 | } 59 | // Okay to ignore error because it's only thrown when the number of field 60 | // names does not equal the number of ValueTranscoders. 
61 | typ, _ := pgtype.NewCompositeTypeValues(name, fs, vals) 62 | if !isBinaryOk { 63 | return textPreferrer{ValueTranscoder: typ, typeName: name} 64 | } 65 | return typ 66 | } 67 | 68 | func (tr *typeResolver) newArrayValue(name, elemName string, defaultVal func() pgtype.ValueTranscoder) pgtype.ValueTranscoder { 69 | if _, val, ok := tr.findValue(name); ok { 70 | return val 71 | } 72 | elemOID, elemVal, ok := tr.findValue(elemName) 73 | elemValFunc := func() pgtype.ValueTranscoder { 74 | return pgtype.NewValue(elemVal).(pgtype.ValueTranscoder) 75 | } 76 | if !ok { 77 | elemOID = unknownOID 78 | elemValFunc = defaultVal 79 | } 80 | typ := pgtype.NewArrayType(name, elemOID, elemValFunc) 81 | if elemOID == unknownOID { 82 | return textPreferrer{ValueTranscoder: typ, typeName: name} 83 | } 84 | return typ 85 | } 86 | 87 | // newSomeTable creates a new pgtype.ValueTranscoder for the Postgres 88 | // composite type 'some_table'. 89 | func (tr *typeResolver) newSomeTable() pgtype.ValueTranscoder { 90 | return tr.newCompositeValue( 91 | "some_table", 92 | compositeField{name: "foo", typeName: "int2", defaultVal: &pgtype.Int2{}}, 93 | compositeField{name: "bar_baz", typeName: "text", defaultVal: &pgtype.Text{}}, 94 | ) 95 | } 96 | 97 | // newSomeTableInit creates an initialized pgtype.ValueTranscoder for the 98 | // Postgres composite type 'some_table' to encode query parameters. 99 | func (tr *typeResolver) newSomeTableInit(v SomeTable) pgtype.ValueTranscoder { 100 | return tr.setValue(tr.newSomeTable(), tr.newSomeTableRaw(v)) 101 | } 102 | 103 | // newSomeTableRaw returns all composite fields for the Postgres composite 104 | // type 'some_table' as a slice of interface{} to encode query parameters. 
105 | func (tr *typeResolver) newSomeTableRaw(v SomeTable) []interface{} { 106 | return []interface{}{ 107 | v.Foo, 108 | v.BarBaz, 109 | } 110 | } -------------------------------------------------------------------------------- /internal/codegen/golang/testdata/declarer_composite.output.golden: -------------------------------------------------------------------------------- 1 | // SomeTable represents the Postgres composite type "some_table". 2 | type SomeTable struct { 3 | Foo int16 `json:"foo"` 4 | BarBaz pgtype.Text `json:"bar_baz"` 5 | } 6 | 7 | // typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. 8 | type typeResolver struct { 9 | connInfo *pgtype.ConnInfo // types by Postgres type name 10 | } 11 | 12 | func newTypeResolver() *typeResolver { 13 | ci := pgtype.NewConnInfo() 14 | return &typeResolver{connInfo: ci} 15 | } 16 | 17 | // findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. 18 | func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { 19 | typ, ok := tr.connInfo.DataTypeForName(name) 20 | if !ok { 21 | return 0, nil, false 22 | } 23 | v := pgtype.NewValue(typ.Value) 24 | return typ.OID, v.(pgtype.ValueTranscoder), true 25 | } 26 | 27 | // setValue sets the value of a ValueTranscoder to a value that should always 28 | // work and panics if it fails. 
29 | func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { 30 | if err := vt.Set(val); err != nil { 31 | panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) 32 | } 33 | return vt 34 | } 35 | 36 | type compositeField struct { 37 | name string // name of the field 38 | typeName string // Postgres type name 39 | defaultVal pgtype.ValueTranscoder // default value to use 40 | } 41 | 42 | func (tr *typeResolver) newCompositeValue(name string, fields ...compositeField) pgtype.ValueTranscoder { 43 | if _, val, ok := tr.findValue(name); ok { 44 | return val 45 | } 46 | fs := make([]pgtype.CompositeTypeField, len(fields)) 47 | vals := make([]pgtype.ValueTranscoder, len(fields)) 48 | isBinaryOk := true 49 | for i, field := range fields { 50 | oid, val, ok := tr.findValue(field.typeName) 51 | if !ok { 52 | oid = unknownOID 53 | val = field.defaultVal 54 | } 55 | isBinaryOk = isBinaryOk && oid != unknownOID 56 | fs[i] = pgtype.CompositeTypeField{Name: field.name, OID: oid} 57 | vals[i] = val 58 | } 59 | // Okay to ignore error because it's only thrown when the number of field 60 | // names does not equal the number of ValueTranscoders. 
61 | typ, _ := pgtype.NewCompositeTypeValues(name, fs, vals) 62 | if !isBinaryOk { 63 | return textPreferrer{ValueTranscoder: typ, typeName: name} 64 | } 65 | return typ 66 | } 67 | 68 | func (tr *typeResolver) newArrayValue(name, elemName string, defaultVal func() pgtype.ValueTranscoder) pgtype.ValueTranscoder { 69 | if _, val, ok := tr.findValue(name); ok { 70 | return val 71 | } 72 | elemOID, elemVal, ok := tr.findValue(elemName) 73 | elemValFunc := func() pgtype.ValueTranscoder { 74 | return pgtype.NewValue(elemVal).(pgtype.ValueTranscoder) 75 | } 76 | if !ok { 77 | elemOID = unknownOID 78 | elemValFunc = defaultVal 79 | } 80 | typ := pgtype.NewArrayType(name, elemOID, elemValFunc) 81 | if elemOID == unknownOID { 82 | return textPreferrer{ValueTranscoder: typ, typeName: name} 83 | } 84 | return typ 85 | } 86 | 87 | // newSomeTable creates a new pgtype.ValueTranscoder for the Postgres 88 | // composite type 'some_table'. 89 | func (tr *typeResolver) newSomeTable() pgtype.ValueTranscoder { 90 | return tr.newCompositeValue( 91 | "some_table", 92 | compositeField{name: "foo", typeName: "int2", defaultVal: &pgtype.Int2{}}, 93 | compositeField{name: "bar_baz", typeName: "text", defaultVal: &pgtype.Text{}}, 94 | ) 95 | } -------------------------------------------------------------------------------- /internal/codegen/golang/testdata/declarer_composite_array.output.golden: -------------------------------------------------------------------------------- 1 | // SomeTable represents the Postgres composite type "some_table". 2 | type SomeTable struct { 3 | Foo int16 `json:"foo"` 4 | BarBaz pgtype.Text `json:"bar_baz"` 5 | } 6 | 7 | // typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. 
8 | type typeResolver struct { 9 | connInfo *pgtype.ConnInfo // types by Postgres type name 10 | } 11 | 12 | func newTypeResolver() *typeResolver { 13 | ci := pgtype.NewConnInfo() 14 | return &typeResolver{connInfo: ci} 15 | } 16 | 17 | // findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. 18 | func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { 19 | typ, ok := tr.connInfo.DataTypeForName(name) 20 | if !ok { 21 | return 0, nil, false 22 | } 23 | v := pgtype.NewValue(typ.Value) 24 | return typ.OID, v.(pgtype.ValueTranscoder), true 25 | } 26 | 27 | // setValue sets the value of a ValueTranscoder to a value that should always 28 | // work and panics if it fails. 29 | func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { 30 | if err := vt.Set(val); err != nil { 31 | panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) 32 | } 33 | return vt 34 | } 35 | 36 | type compositeField struct { 37 | name string // name of the field 38 | typeName string // Postgres type name 39 | defaultVal pgtype.ValueTranscoder // default value to use 40 | } 41 | 42 | func (tr *typeResolver) newCompositeValue(name string, fields ...compositeField) pgtype.ValueTranscoder { 43 | if _, val, ok := tr.findValue(name); ok { 44 | return val 45 | } 46 | fs := make([]pgtype.CompositeTypeField, len(fields)) 47 | vals := make([]pgtype.ValueTranscoder, len(fields)) 48 | isBinaryOk := true 49 | for i, field := range fields { 50 | oid, val, ok := tr.findValue(field.typeName) 51 | if !ok { 52 | oid = unknownOID 53 | val = field.defaultVal 54 | } 55 | isBinaryOk = isBinaryOk && oid != unknownOID 56 | fs[i] = pgtype.CompositeTypeField{Name: field.name, OID: oid} 57 | vals[i] = val 58 | } 59 | // Okay to ignore error because it's only thrown when the number of field 60 | // names does not equal the number of ValueTranscoders. 
61 | typ, _ := pgtype.NewCompositeTypeValues(name, fs, vals) 62 | if !isBinaryOk { 63 | return textPreferrer{ValueTranscoder: typ, typeName: name} 64 | } 65 | return typ 66 | } 67 | 68 | func (tr *typeResolver) newArrayValue(name, elemName string, defaultVal func() pgtype.ValueTranscoder) pgtype.ValueTranscoder { 69 | if _, val, ok := tr.findValue(name); ok { 70 | return val 71 | } 72 | elemOID, elemVal, ok := tr.findValue(elemName) 73 | elemValFunc := func() pgtype.ValueTranscoder { 74 | return pgtype.NewValue(elemVal).(pgtype.ValueTranscoder) 75 | } 76 | if !ok { 77 | elemOID = unknownOID 78 | elemValFunc = defaultVal 79 | } 80 | typ := pgtype.NewArrayType(name, elemOID, elemValFunc) 81 | if elemOID == unknownOID { 82 | return textPreferrer{ValueTranscoder: typ, typeName: name} 83 | } 84 | return typ 85 | } 86 | 87 | // newSomeTable creates a new pgtype.ValueTranscoder for the Postgres 88 | // composite type 'some_table'. 89 | func (tr *typeResolver) newSomeTable() pgtype.ValueTranscoder { 90 | return tr.newCompositeValue( 91 | "some_table", 92 | compositeField{name: "foo", typeName: "int2", defaultVal: &pgtype.Int2{}}, 93 | compositeField{name: "bar_baz", typeName: "text", defaultVal: &pgtype.Text{}}, 94 | ) 95 | } 96 | 97 | // newSomeTableArray creates a new pgtype.ValueTranscoder for the Postgres 98 | // '_some_array' array type. 99 | func (tr *typeResolver) newSomeTableArray() pgtype.ValueTranscoder { 100 | return tr.newArrayValue("_some_array", "some_table", tr.newSomeTable) 101 | } -------------------------------------------------------------------------------- /internal/codegen/golang/testdata/declarer_composite_enum.output.golden: -------------------------------------------------------------------------------- 1 | // SomeTableEnum represents the Postgres composite type "some_table_enum". 
2 | type SomeTableEnum struct { 3 | Foo DeviceType `json:"foo"` 4 | } 5 | 6 | // newDeviceTypeEnum creates a new pgtype.ValueTranscoder for the 7 | // Postgres enum type 'device_type'. 8 | func newDeviceTypeEnum() pgtype.ValueTranscoder { 9 | return pgtype.NewEnumType( 10 | "device_type", 11 | []string{ 12 | string(DeviceTypeIOS), 13 | string(DeviceTypeMobile), 14 | }, 15 | ) 16 | } 17 | 18 | // DeviceType represents the Postgres enum "device_type". 19 | type DeviceType string 20 | 21 | const ( 22 | DeviceTypeIOS DeviceType = "ios" 23 | DeviceTypeMobile DeviceType = "mobile" 24 | ) 25 | 26 | func (d DeviceType) String() string { return string(d) } 27 | 28 | // typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. 29 | type typeResolver struct { 30 | connInfo *pgtype.ConnInfo // types by Postgres type name 31 | } 32 | 33 | func newTypeResolver() *typeResolver { 34 | ci := pgtype.NewConnInfo() 35 | return &typeResolver{connInfo: ci} 36 | } 37 | 38 | // findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. 39 | func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { 40 | typ, ok := tr.connInfo.DataTypeForName(name) 41 | if !ok { 42 | return 0, nil, false 43 | } 44 | v := pgtype.NewValue(typ.Value) 45 | return typ.OID, v.(pgtype.ValueTranscoder), true 46 | } 47 | 48 | // setValue sets the value of a ValueTranscoder to a value that should always 49 | // work and panics if it fails. 
50 | func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { 51 | if err := vt.Set(val); err != nil { 52 | panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) 53 | } 54 | return vt 55 | } 56 | 57 | type compositeField struct { 58 | name string // name of the field 59 | typeName string // Postgres type name 60 | defaultVal pgtype.ValueTranscoder // default value to use 61 | } 62 | 63 | func (tr *typeResolver) newCompositeValue(name string, fields ...compositeField) pgtype.ValueTranscoder { 64 | if _, val, ok := tr.findValue(name); ok { 65 | return val 66 | } 67 | fs := make([]pgtype.CompositeTypeField, len(fields)) 68 | vals := make([]pgtype.ValueTranscoder, len(fields)) 69 | isBinaryOk := true 70 | for i, field := range fields { 71 | oid, val, ok := tr.findValue(field.typeName) 72 | if !ok { 73 | oid = unknownOID 74 | val = field.defaultVal 75 | } 76 | isBinaryOk = isBinaryOk && oid != unknownOID 77 | fs[i] = pgtype.CompositeTypeField{Name: field.name, OID: oid} 78 | vals[i] = val 79 | } 80 | // Okay to ignore error because it's only thrown when the number of field 81 | // names does not equal the number of ValueTranscoders. 
82 | typ, _ := pgtype.NewCompositeTypeValues(name, fs, vals) 83 | if !isBinaryOk { 84 | return textPreferrer{ValueTranscoder: typ, typeName: name} 85 | } 86 | return typ 87 | } 88 | 89 | func (tr *typeResolver) newArrayValue(name, elemName string, defaultVal func() pgtype.ValueTranscoder) pgtype.ValueTranscoder { 90 | if _, val, ok := tr.findValue(name); ok { 91 | return val 92 | } 93 | elemOID, elemVal, ok := tr.findValue(elemName) 94 | elemValFunc := func() pgtype.ValueTranscoder { 95 | return pgtype.NewValue(elemVal).(pgtype.ValueTranscoder) 96 | } 97 | if !ok { 98 | elemOID = unknownOID 99 | elemValFunc = defaultVal 100 | } 101 | typ := pgtype.NewArrayType(name, elemOID, elemValFunc) 102 | if elemOID == unknownOID { 103 | return textPreferrer{ValueTranscoder: typ, typeName: name} 104 | } 105 | return typ 106 | } 107 | 108 | // newSomeTableEnum creates a new pgtype.ValueTranscoder for the Postgres 109 | // composite type 'some_table_enum'. 110 | func (tr *typeResolver) newSomeTableEnum() pgtype.ValueTranscoder { 111 | return tr.newCompositeValue( 112 | "some_table_enum", 113 | compositeField{name: "foo", typeName: "some_table_enum", defaultVal: newDeviceTypeEnum()}, 114 | ) 115 | } -------------------------------------------------------------------------------- /internal/codegen/golang/testdata/declarer_composite_nested.output.golden: -------------------------------------------------------------------------------- 1 | // FooType represents the Postgres composite type "foo_type". 2 | type FooType struct { 3 | Alpha pgtype.Text `json:"alpha"` 4 | } 5 | 6 | // SomeTableNested represents the Postgres composite type "some_table_nested". 7 | type SomeTableNested struct { 8 | Foo FooType `json:"foo"` 9 | BarBaz pgtype.Text `json:"bar_baz"` 10 | } 11 | 12 | // typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. 
13 | type typeResolver struct { 14 | connInfo *pgtype.ConnInfo // types by Postgres type name 15 | } 16 | 17 | func newTypeResolver() *typeResolver { 18 | ci := pgtype.NewConnInfo() 19 | return &typeResolver{connInfo: ci} 20 | } 21 | 22 | // findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. 23 | func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { 24 | typ, ok := tr.connInfo.DataTypeForName(name) 25 | if !ok { 26 | return 0, nil, false 27 | } 28 | v := pgtype.NewValue(typ.Value) 29 | return typ.OID, v.(pgtype.ValueTranscoder), true 30 | } 31 | 32 | // setValue sets the value of a ValueTranscoder to a value that should always 33 | // work and panics if it fails. 34 | func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { 35 | if err := vt.Set(val); err != nil { 36 | panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) 37 | } 38 | return vt 39 | } 40 | 41 | type compositeField struct { 42 | name string // name of the field 43 | typeName string // Postgres type name 44 | defaultVal pgtype.ValueTranscoder // default value to use 45 | } 46 | 47 | func (tr *typeResolver) newCompositeValue(name string, fields ...compositeField) pgtype.ValueTranscoder { 48 | if _, val, ok := tr.findValue(name); ok { 49 | return val 50 | } 51 | fs := make([]pgtype.CompositeTypeField, len(fields)) 52 | vals := make([]pgtype.ValueTranscoder, len(fields)) 53 | isBinaryOk := true 54 | for i, field := range fields { 55 | oid, val, ok := tr.findValue(field.typeName) 56 | if !ok { 57 | oid = unknownOID 58 | val = field.defaultVal 59 | } 60 | isBinaryOk = isBinaryOk && oid != unknownOID 61 | fs[i] = pgtype.CompositeTypeField{Name: field.name, OID: oid} 62 | vals[i] = val 63 | } 64 | // Okay to ignore error because it's only thrown when the number of field 65 | // names does not equal the number of ValueTranscoders. 
66 | typ, _ := pgtype.NewCompositeTypeValues(name, fs, vals) 67 | if !isBinaryOk { 68 | return textPreferrer{ValueTranscoder: typ, typeName: name} 69 | } 70 | return typ 71 | } 72 | 73 | func (tr *typeResolver) newArrayValue(name, elemName string, defaultVal func() pgtype.ValueTranscoder) pgtype.ValueTranscoder { 74 | if _, val, ok := tr.findValue(name); ok { 75 | return val 76 | } 77 | elemOID, elemVal, ok := tr.findValue(elemName) 78 | elemValFunc := func() pgtype.ValueTranscoder { 79 | return pgtype.NewValue(elemVal).(pgtype.ValueTranscoder) 80 | } 81 | if !ok { 82 | elemOID = unknownOID 83 | elemValFunc = defaultVal 84 | } 85 | typ := pgtype.NewArrayType(name, elemOID, elemValFunc) 86 | if elemOID == unknownOID { 87 | return textPreferrer{ValueTranscoder: typ, typeName: name} 88 | } 89 | return typ 90 | } 91 | 92 | // newFooType creates a new pgtype.ValueTranscoder for the Postgres 93 | // composite type 'foo_type'. 94 | func (tr *typeResolver) newFooType() pgtype.ValueTranscoder { 95 | return tr.newCompositeValue( 96 | "foo_type", 97 | compositeField{name: "alpha", typeName: "text", defaultVal: &pgtype.Text{}}, 98 | ) 99 | } 100 | 101 | // newSomeTableNested creates a new pgtype.ValueTranscoder for the Postgres 102 | // composite type 'some_table_nested'. 103 | func (tr *typeResolver) newSomeTableNested() pgtype.ValueTranscoder { 104 | return tr.newCompositeValue( 105 | "some_table_nested", 106 | compositeField{name: "foo", typeName: "foo_type", defaultVal: tr.newFooType()}, 107 | compositeField{name: "bar_baz", typeName: "text", defaultVal: &pgtype.Text{}}, 108 | ) 109 | } -------------------------------------------------------------------------------- /internal/codegen/golang/testdata/declarer_enum_escaping.input.golden: -------------------------------------------------------------------------------- 1 | // Quoting represents the Postgres enum "quoting". 
2 | type Quoting string 3 | 4 | const ( 5 | QuotingUnnamedLabel0 Quoting = "\"\n\t" 6 | QuotingUnnamedLabel1 Quoting = "`\"`" 7 | ) 8 | 9 | func (q Quoting) String() string { return string(q) } 10 | 11 | // typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. 12 | type typeResolver struct { 13 | connInfo *pgtype.ConnInfo // types by Postgres type name 14 | } 15 | 16 | func newTypeResolver() *typeResolver { 17 | ci := pgtype.NewConnInfo() 18 | return &typeResolver{connInfo: ci} 19 | } 20 | 21 | // findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. 22 | func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { 23 | typ, ok := tr.connInfo.DataTypeForName(name) 24 | if !ok { 25 | return 0, nil, false 26 | } 27 | v := pgtype.NewValue(typ.Value) 28 | return typ.OID, v.(pgtype.ValueTranscoder), true 29 | } 30 | 31 | // setValue sets the value of a ValueTranscoder to a value that should always 32 | // work and panics if it fails. 33 | func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { 34 | if err := vt.Set(val); err != nil { 35 | panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) 36 | } 37 | return vt 38 | } -------------------------------------------------------------------------------- /internal/codegen/golang/testdata/declarer_enum_escaping.output.golden: -------------------------------------------------------------------------------- 1 | // Quoting represents the Postgres enum "quoting". 2 | type Quoting string 3 | 4 | const ( 5 | QuotingUnnamedLabel0 Quoting = "\"\n\t" 6 | QuotingUnnamedLabel1 Quoting = "`\"`" 7 | ) 8 | 9 | func (q Quoting) String() string { return string(q) } 10 | 11 | // typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. 
12 | type typeResolver struct { 13 | connInfo *pgtype.ConnInfo // types by Postgres type name 14 | } 15 | 16 | func newTypeResolver() *typeResolver { 17 | ci := pgtype.NewConnInfo() 18 | return &typeResolver{connInfo: ci} 19 | } 20 | 21 | // findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. 22 | func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { 23 | typ, ok := tr.connInfo.DataTypeForName(name) 24 | if !ok { 25 | return 0, nil, false 26 | } 27 | v := pgtype.NewValue(typ.Value) 28 | return typ.OID, v.(pgtype.ValueTranscoder), true 29 | } 30 | 31 | // setValue sets the value of a ValueTranscoder to a value that should always 32 | // work and panics if it fails. 33 | func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { 34 | if err := vt.Set(val); err != nil { 35 | panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) 36 | } 37 | return vt 38 | } -------------------------------------------------------------------------------- /internal/codegen/golang/testdata/declarer_enum_simple.input.golden: -------------------------------------------------------------------------------- 1 | // DeviceType represents the Postgres enum "device_type". 2 | type DeviceType string 3 | 4 | const ( 5 | DeviceTypeIOS DeviceType = "ios" 6 | DeviceTypeMobile DeviceType = "mobile" 7 | ) 8 | 9 | func (d DeviceType) String() string { return string(d) } 10 | 11 | // typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. 12 | type typeResolver struct { 13 | connInfo *pgtype.ConnInfo // types by Postgres type name 14 | } 15 | 16 | func newTypeResolver() *typeResolver { 17 | ci := pgtype.NewConnInfo() 18 | return &typeResolver{connInfo: ci} 19 | } 20 | 21 | // findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. 
22 | func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { 23 | typ, ok := tr.connInfo.DataTypeForName(name) 24 | if !ok { 25 | return 0, nil, false 26 | } 27 | v := pgtype.NewValue(typ.Value) 28 | return typ.OID, v.(pgtype.ValueTranscoder), true 29 | } 30 | 31 | // setValue sets the value of a ValueTranscoder to a value that should always 32 | // work and panics if it fails. 33 | func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { 34 | if err := vt.Set(val); err != nil { 35 | panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) 36 | } 37 | return vt 38 | } -------------------------------------------------------------------------------- /internal/codegen/golang/testdata/declarer_enum_simple.output.golden: -------------------------------------------------------------------------------- 1 | // DeviceType represents the Postgres enum "device_type". 2 | type DeviceType string 3 | 4 | const ( 5 | DeviceTypeIOS DeviceType = "ios" 6 | DeviceTypeMobile DeviceType = "mobile" 7 | ) 8 | 9 | func (d DeviceType) String() string { return string(d) } 10 | 11 | // typeResolver looks up the pgtype.ValueTranscoder by Postgres type name. 12 | type typeResolver struct { 13 | connInfo *pgtype.ConnInfo // types by Postgres type name 14 | } 15 | 16 | func newTypeResolver() *typeResolver { 17 | ci := pgtype.NewConnInfo() 18 | return &typeResolver{connInfo: ci} 19 | } 20 | 21 | // findValue find the OID, and pgtype.ValueTranscoder for a Postgres type name. 22 | func (tr *typeResolver) findValue(name string) (uint32, pgtype.ValueTranscoder, bool) { 23 | typ, ok := tr.connInfo.DataTypeForName(name) 24 | if !ok { 25 | return 0, nil, false 26 | } 27 | v := pgtype.NewValue(typ.Value) 28 | return typ.OID, v.(pgtype.ValueTranscoder), true 29 | } 30 | 31 | // setValue sets the value of a ValueTranscoder to a value that should always 32 | // work and panics if it fails. 
33 | func (tr *typeResolver) setValue(vt pgtype.ValueTranscoder, val interface{}) pgtype.ValueTranscoder { 34 | if err := vt.Set(val); err != nil { 35 | panic(fmt.Sprintf("set ValueTranscoder %T to %+v: %s", vt, val, err)) 36 | } 37 | return vt 38 | } -------------------------------------------------------------------------------- /internal/difftest/difftest.go: -------------------------------------------------------------------------------- 1 | package difftest 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/google/go-cmp/cmp" 7 | "github.com/google/go-cmp/cmp/cmpopts" 8 | ) 9 | 10 | func AssertSame(t *testing.T, want, got interface{}, opts ...cmp.Option) { 11 | t.Helper() 12 | allOpts := append([]cmp.Option{ 13 | cmpopts.EquateEmpty(), // useful so nil is same as 0-sized slice 14 | }, opts...) 15 | if diff := cmp.Diff(want, got, allOpts...); diff != "" { 16 | t.Errorf("mismatch (-want +got)\n%s", diff) 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /internal/errs/errs.go: -------------------------------------------------------------------------------- 1 | package errs 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "testing" 7 | ) 8 | 9 | // Capture runs errF and assigns the error, if any, to *err. Preserves the 10 | // original err by wrapping in a MultiError if err is non-nil. If msg is not 11 | // empty, wrap the error returned by closer with the msg. 12 | // 13 | // - If errF returns nil, do nothing. 14 | // - If errF returns an error and *err == nil, replace *err with the error. 15 | // - If errF returns an error and *err != nil, replace *err with a MultiError 16 | // containing *err and the errF err. 
17 | func Capture(err *error, errF func() error, msg string) { 18 | fErr := errF() 19 | if fErr == nil { 20 | return 21 | } 22 | 23 | wErr := fErr 24 | if msg != "" { 25 | wErr = fmt.Errorf(msg+": %w", fErr) 26 | } 27 | if *err == nil { 28 | // Only 1 error so return it directly 29 | *err = wErr 30 | return 31 | } 32 | 33 | *err = errors.Join(*err, wErr) 34 | } 35 | 36 | // CaptureT call t.Error if errF returns an error with an optional message. 37 | func CaptureT(t *testing.T, errF func() error, msg string) { 38 | t.Helper() 39 | if err := errF(); err != nil { 40 | if msg == "" { 41 | t.Error(err) 42 | } else { 43 | t.Errorf(msg+": %s", err) 44 | } 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /internal/flags/flags.go: -------------------------------------------------------------------------------- 1 | package flags 2 | 3 | import ( 4 | "flag" 5 | "strings" 6 | ) 7 | 8 | // Strings returns a repeated string flag that accumulates value into a slice. 9 | func Strings(fset *flag.FlagSet, name string, value []string, usage string) *[]string { 10 | sv := &stringsValue{ 11 | strings: &value, 12 | } 13 | fset.Var(sv, name, usage) 14 | return sv.strings 15 | } 16 | 17 | type stringsValue struct { 18 | strings *[]string 19 | } 20 | 21 | // String implements flag.Value and fmt.Stringer. 22 | func (sv *stringsValue) String() string { 23 | return strings.Join(*sv.strings, ",") 24 | } 25 | 26 | // Get implements flag.Getter. 27 | func (sv *stringsValue) Get() interface{} { 28 | return *sv.strings 29 | } 30 | 31 | // Set implements flag.Value. 
32 | func (sv *stringsValue) Set(value string) error { 33 | *sv.strings = append(*sv.strings, value) 34 | return nil 35 | } 36 | -------------------------------------------------------------------------------- /internal/gomod/gomod.go: -------------------------------------------------------------------------------- 1 | // Package gomod provides utilities for getting information about the current 2 | // Go module. 3 | package gomod 4 | 5 | import ( 6 | "fmt" 7 | "os" 8 | "path/filepath" 9 | "sync" 10 | 11 | "github.com/jschaf/pggen/internal/paths" 12 | "golang.org/x/mod/modfile" 13 | ) 14 | 15 | //nolint:gochecknoglobals 16 | var ( 17 | goModDirOnce = &sync.Once{} 18 | goModDir string 19 | errGoModDir error 20 | 21 | goModNameOnce = &sync.Once{} 22 | goModPath string 23 | errGoModPath error 24 | ) 25 | 26 | // FindDir finds the nearest directory containing a go.mod file. Checks 27 | // the current dir and then walks up parent directories. 28 | func FindDir() (string, error) { 29 | goModDirOnce.Do(func() { 30 | wd, err := os.Getwd() 31 | if err != nil { 32 | errGoModDir = fmt.Errorf("FindDir working dir: %w", err) 33 | return 34 | } 35 | goModDir, errGoModDir = paths.WalkUp(wd, "go.mod") 36 | }) 37 | return goModDir, errGoModDir 38 | } 39 | 40 | // ParsePath finds the module path in the nearest go.mod file. 41 | func ParsePath() (string, error) { 42 | goModNameOnce.Do(func() { 43 | dir, err := FindDir() 44 | if err != nil { 45 | errGoModPath = fmt.Errorf("find go.mod dir: %w", err) 46 | return 47 | } 48 | p := filepath.Join(dir, "go.mod") 49 | bs, err := os.ReadFile(p) 50 | if err != nil { 51 | errGoModPath = fmt.Errorf("read go.mod: %w", err) 52 | return 53 | } 54 | goModPath = modfile.ModulePath(bs) 55 | }) 56 | return goModPath, errGoModPath 57 | } 58 | 59 | // GuessPackage guesses the full Go package path for a file name, relative to 60 | // current working directory. 61 | // Imperfect. Assumes package names always match directory names. 
62 | func GuessPackage(fileName string) (string, error) { 63 | goModDir, err := FindDir() 64 | if err != nil { 65 | return "", fmt.Errorf("find go.mod dir: %w", err) 66 | } 67 | goModPath, err := ParsePath() 68 | if err != nil { 69 | return "", fmt.Errorf("parse go.mod dir: %w", err) 70 | } 71 | abs, err := filepath.Abs(fileName) 72 | if err != nil { 73 | return "", fmt.Errorf("abs path for %s: %w", fileName, err) 74 | } 75 | rel, err := filepath.Rel(goModDir, abs) 76 | if err != nil { 77 | return "", fmt.Errorf("rel path to go.mod for %s: %w", fileName, err) 78 | } 79 | // Dir to remove file name. Clean to remove "./" suffix. Convert to slash to 80 | // get forward slashes, which match Go package paths. 81 | relDir := filepath.ToSlash(filepath.Clean(filepath.Dir(rel))) 82 | return goModPath + "/" + relDir, nil 83 | } 84 | -------------------------------------------------------------------------------- /internal/gomod/gomod_test.go: -------------------------------------------------------------------------------- 1 | package gomod 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestResolvePackage(t *testing.T) { 10 | tests := []struct { 11 | path string 12 | want string 13 | }{ 14 | { 15 | path: "Foo.go", 16 | want: "github.com/jschaf/pggen/internal/gomod", 17 | }, 18 | { 19 | path: "../Foo.go", 20 | want: "github.com/jschaf/pggen/internal", 21 | }, 22 | { 23 | path: "./Foo.go", 24 | want: "github.com/jschaf/pggen/internal/gomod", 25 | }, 26 | { 27 | path: "blah/qux/Foo.go", 28 | want: "github.com/jschaf/pggen/internal/gomod/blah/qux", 29 | }, 30 | } 31 | for _, tt := range tests { 32 | t.Run(tt.path, func(t *testing.T) { 33 | got, err := GuessPackage(tt.path) 34 | assert.NoError(t, err) 35 | assert.Equal(t, tt.want, got) 36 | }) 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /internal/parser/interface.go: 
--------------------------------------------------------------------------------
// Package parser contains the exported entry points for invoking the parser.
package parser

import (
	"bytes"
	"errors"
	gotok "go/token"
	"io"
	"os"

	"github.com/jschaf/pggen/internal/ast"
)

// If src != nil, readSource converts src to a []byte if possible; otherwise it
// returns an error. If src == nil, readSource returns the result of reading the
// file specified by filename.
//
// Accepted src types are string, []byte, *bytes.Buffer, and io.Reader. For a
// []byte or *bytes.Buffer the returned slice shares memory with src (no copy).
// A nil *bytes.Buffer is not read; it falls through to the "invalid source"
// error.
func readSource(filename string, src interface{}) ([]byte, error) {
	if src != nil {
		switch s := src.(type) {
		case string:
			return []byte(s), nil
		case []byte:
			return s, nil
		case *bytes.Buffer:
			// is io.Reader, but src is already available in []byte form
			if s != nil {
				return s.Bytes(), nil
			}
		case io.Reader:
			return io.ReadAll(s)
		}
		return nil, errors.New("invalid source")
	}
	return os.ReadFile(filename)
}

// A Mode value is a set of flags (or 0).
// They control the amount of source code parsed and other optional parser
// functionality.
type Mode uint

const (
	Trace Mode = 1 << iota // print a trace of parsed productions
)

// ParseFile parses the source code of a single query source file and returns
// the corresponding ast.File node. The source code may be provided via the
// filename of the source file, or via the src parameter.
//
// If src != nil, ParseFile parses the source from src and the filename is only
// used when recording position information. The type of the argument for the
// src parameter must be string, []byte, or io.Reader. If src == nil, ParseFile
// parses the file specified by filename.
//
// The mode parameter controls the amount of source text parsed and other
// optional parser functionality. Position information is recorded in the file
// set fset, which must not be nil; ParseFile panics if it is.
//
// If the source couldn't be read, the returned AST is nil and the error
// indicates the specific failure. If the source was read but syntax errors were
// found, the result is a partial AST (with ast.Bad* nodes representing the
// fragments of erroneous source code). Multiple errors are returned via
// a scanner.ErrorList which is sorted by source position.
//
// Once the source has been read, the returned file is never nil: if parsing
// produces no file, an empty *ast.File named filename is returned instead.
//
// NOTE(review): the partial-AST and scanner.ErrorList sentences mirror
// go/parser; the type of p.errors and the node kinds in a partial AST aren't
// visible here — confirm they still hold for this parser.
func ParseFile(fset *gotok.FileSet, filename string, src interface{}, mode Mode) (f *ast.File, err error) {
	if fset == nil {
		panic("parser.ParseFile: no token.FileSet provided (fset == nil)")
	}

	// get source
	text, err := readSource(filename, src)
	if err != nil {
		return nil, err
	}

	var p parser
	// f and err are named results so this deferred func can overwrite them
	// after parsing finishes, including when the parser unwinds by panicking
	// with a bailout value.
	defer func() {
		if e := recover(); e != nil {
			// resume same panic if it's not a bailout
			if _, ok := e.(bailout); !ok {
				panic(e)
			}
		}

		// set result values
		if f == nil {
			// src is not a valid query source file - satisfy ParseFile API and
			// return a valid (but) empty *ast.File
			f = &ast.File{Name: filename}
		}

		// The collected parse errors, sorted by position, become the returned
		// error.
		p.errors.Sort()
		err = p.errors.Err()
	}()

	// parse source
	p.init(fset, filename, text, mode)
	f = p.parseFile()

	return
}

--------------------------------------------------------------------------------
/internal/paths/paths.go:
--------------------------------------------------------------------------------
package paths

import (
	"fmt"
	"os"
	"path/filepath"
)

// WalkUp traverses up the directory tree from dir until it finds a directory
// containing an entry named name, and returns that directory. Checks dir
// first and then iteratively checks parent directories.
12 | func WalkUp(dir, name string) (string, error) { 13 | for dir != string(os.PathSeparator) { 14 | p := filepath.Join(dir, name) 15 | if _, err := os.Stat(p); err != nil { 16 | if !os.IsNotExist(err) { 17 | return "", fmt.Errorf("stat file %s: %w", p, err) 18 | } 19 | } else { 20 | return dir, nil 21 | } 22 | dir = filepath.Dir(dir) 23 | } 24 | return "", fmt.Errorf("dir not found in directory tree starting from %s", dir) 25 | } 26 | -------------------------------------------------------------------------------- /internal/pg/column.go: -------------------------------------------------------------------------------- 1 | package pg 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "strconv" 7 | "strings" 8 | "sync" 9 | "time" 10 | 11 | "github.com/jackc/pgtype" 12 | "github.com/jackc/pgx/v4" 13 | "github.com/jschaf/pggen/internal/texts" 14 | ) 15 | 16 | // Column stores information about a column in a TableOID. 17 | // https://www.postgresql.org/docs/13/catalog-pg-attribute.html 18 | type Column struct { 19 | Name string // pg_attribute.attname: column name 20 | TableOID pgtype.OID // pg_attribute:attrelid: table the column belongs to 21 | TableName string // pg_class.relname: name of table that owns the column 22 | Number uint16 // pg_attribute.attnum: the number of column starting from 1 23 | Type Type // pg_attribute.atttypid: data type of the column 24 | Null bool // pg_attribute.attnotnull: represents a not-null constraint 25 | } 26 | 27 | // ColumnKey is a composite key of a table OID and the number of the column 28 | // within the table. 29 | type ColumnKey struct { 30 | TableOID pgtype.OID 31 | Number uint16 // the number of column starting from 1 32 | } 33 | 34 | //nolint:gochecknoglobals 35 | var ( 36 | columnsMu = &sync.Mutex{} 37 | columnCache = make(map[ColumnKey]Column, 32) 38 | ) 39 | 40 | // FetchColumns fetches meta information about a Postgres column from the 41 | // pg_class and pg_attribute catalog tables. 
42 | func FetchColumns(conn *pgx.Conn, keys []ColumnKey) ([]Column, error) { 43 | if len(keys) == 0 { 44 | return nil, nil 45 | } 46 | 47 | // Try cache first. 48 | uncachedKeys := make([]ColumnKey, 0, len(keys)) 49 | columnsMu.Lock() 50 | for _, key := range keys { 51 | if _, ok := columnCache[key]; !ok && key.TableOID > 0 { 52 | uncachedKeys = append(uncachedKeys, key) 53 | } 54 | } 55 | columnsMu.Unlock() 56 | if len(uncachedKeys) == 0 { 57 | return fetchCachedColumns(keys) 58 | } 59 | 60 | // Build query predicate. 61 | predicate := &strings.Builder{} 62 | predicate.Grow(len(uncachedKeys) * 40) 63 | for i, key := range uncachedKeys { 64 | predicate.WriteString("(cls.oid = ") 65 | predicate.WriteString(strconv.Itoa(int(key.TableOID))) 66 | predicate.WriteString(" AND attr.attnum = ") 67 | predicate.WriteString(strconv.Itoa(int(key.Number))) 68 | predicate.WriteString(")") 69 | if i < len(uncachedKeys)-1 { 70 | predicate.WriteString("\n OR ") 71 | } 72 | } 73 | 74 | // Execute query. 75 | q := texts.Dedent(` 76 | SELECT cls.oid AS table_oid, 77 | cls.relname AS table_name, 78 | attr.attname AS col_name, 79 | attr.attnum AS col_num, 80 | attr.attnotnull AS col_null 81 | FROM pg_class cls 82 | JOIN pg_attribute attr ON (attr.attrelid = cls.oid) 83 | `) + "\nWHERE " + predicate.String() 84 | ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) 85 | defer cancel() 86 | rows, err := conn.Query(ctx, q) 87 | if err != nil { 88 | return nil, fmt.Errorf("fetch column metadata: %w", err) 89 | } 90 | defer rows.Close() 91 | for rows.Next() { 92 | col := Column{} 93 | notNull := false 94 | if err := rows.Scan(&col.TableOID, &col.TableName, &col.Name, &col.Number, ¬Null); err != nil { 95 | return nil, fmt.Errorf("scan fetch column row: %w", err) 96 | } 97 | col.Null = !notNull 98 | columnCache[ColumnKey{col.TableOID, col.Number}] = col 99 | } 100 | if err := rows.Err(); err != nil { 101 | return nil, fmt.Errorf("close fetch column rows: %w", err) 102 | } 
103 | 104 | return fetchCachedColumns(keys) 105 | } 106 | 107 | func fetchCachedColumns(keys []ColumnKey) ([]Column, error) { 108 | cols := make([]Column, 0, len(keys)) 109 | columnsMu.Lock() 110 | defer columnsMu.Unlock() 111 | for _, key := range keys { 112 | col, ok := columnCache[key] 113 | // Ignore columns not directly backed by a table. 114 | if !ok && col.TableOID > 0 { 115 | return nil, fmt.Errorf("missing column in fetch cache table_oid=%d Number=%d", key.TableOID, key.Number) 116 | } 117 | cols = append(cols, col) 118 | } 119 | return cols, nil 120 | } 121 | -------------------------------------------------------------------------------- /internal/pg/column_test.go: -------------------------------------------------------------------------------- 1 | package pg 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "testing" 7 | "time" 8 | 9 | "github.com/google/go-cmp/cmp" 10 | "github.com/jackc/pgtype" 11 | "github.com/jackc/pgx/v4" 12 | "github.com/jschaf/pggen/internal/pgtest" 13 | "github.com/jschaf/pggen/internal/texts" 14 | "github.com/stretchr/testify/assert" 15 | ) 16 | 17 | func TestFetchColumns(t *testing.T) { 18 | tests := []struct { 19 | name string 20 | schema string 21 | colNums []uint16 22 | want []Column 23 | }{ 24 | {"empty", "", nil, nil}, 25 | { 26 | "one col null", 27 | "CREATE TABLE author ( first_name text );", 28 | []uint16{1}, 29 | []Column{{Name: "first_name", TableName: "author", Number: 1, Null: true}}, 30 | }, 31 | { 32 | "one col not null", 33 | "CREATE TABLE author ( first_name text NOT NULL);", 34 | []uint16{1}, 35 | []Column{{Name: "first_name", TableName: "author", Number: 1, Null: false}}, 36 | }, 37 | { 38 | "two col mixed", 39 | "CREATE TABLE author ( first_name text NOT NULL, last_name text);", 40 | []uint16{2, 1}, 41 | []Column{ 42 | {Name: "last_name", TableName: "author", Number: 2, Null: true}, 43 | {Name: "first_name", TableName: "author", Number: 1, Null: false}, 44 | }, 45 | }, 46 | } 47 | for _, tt := range tests { 
48 | t.Run(tt.name, func(t *testing.T) { 49 | conn, cleanup := pgtest.NewPostgresSchemaString(t, tt.schema) 50 | defer cleanup() 51 | oid := findTableOID(t, conn, "author") 52 | keys := make([]ColumnKey, len(tt.colNums)) 53 | for i, num := range tt.colNums { 54 | keys[i] = ColumnKey{oid, num} 55 | } 56 | cols, err := FetchColumns(conn, keys) 57 | if err != nil { 58 | t.Fatal(err) 59 | } 60 | // Add table OID to each key. 61 | for i, col := range tt.want { 62 | col.TableOID = oid 63 | tt.want[i] = col 64 | } 65 | if diff := cmp.Diff(tt.want, cols); diff != "" { 66 | t.Errorf("FetchColumns() query mismatch (-want +got):\n%s", diff) 67 | } 68 | 69 | // Test cache. 70 | cols2, err := FetchColumns(conn, keys) 71 | if err != nil { 72 | t.Fatal(err) 73 | } 74 | assert.Equal(t, cols, cols2, "same fetch columns in succession") 75 | }) 76 | } 77 | } 78 | 79 | func findTableOID(t *testing.T, conn *pgx.Conn, table string) pgtype.OID { 80 | sql := texts.Dedent(` 81 | SELECT oid AS table_oid 82 | FROM pg_class 83 | WHERE relname = $1 84 | ORDER BY table_oid DESC 85 | LIMIT 1; 86 | `) 87 | ctx, cancel := context.WithTimeout(t.Context(), 3*time.Second) 88 | defer cancel() 89 | row := conn.QueryRow(ctx, sql, table) 90 | var oid pgtype.OID = 0 91 | if err := row.Scan(&oid); err != nil && !errors.Is(err, pgx.ErrNoRows) { 92 | t.Fatal(err) 93 | } 94 | return oid 95 | } 96 | -------------------------------------------------------------------------------- /internal/pg/pgoid/oids.go: -------------------------------------------------------------------------------- 1 | package pgoid 2 | 3 | const ( 4 | PgNodeTree = 194 5 | OIDArray = 1028 6 | MacaddrArray = 1040 7 | Void = 2278 8 | ) 9 | -------------------------------------------------------------------------------- /internal/pg/type_cache.go: -------------------------------------------------------------------------------- 1 | package pg 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/jackc/pgtype" 7 | ) 8 | 9 | // typeCache caches a map 
from a Postgres pg_type.oid to a Type. 10 | type typeCache struct { 11 | types map[pgtype.OID]Type 12 | mu *sync.Mutex 13 | } 14 | 15 | func newTypeCache() *typeCache { 16 | m := make(map[pgtype.OID]Type, len(defaultKnownTypes)) 17 | for oid, typ := range defaultKnownTypes { 18 | m[oid] = typ 19 | } 20 | return &typeCache{ 21 | types: m, 22 | mu: &sync.Mutex{}, 23 | } 24 | } 25 | 26 | // getOIDs returns the cached OIDS (with the type) and uncached OIDs. 27 | func (tc *typeCache) getOIDs(oids ...uint32) (map[pgtype.OID]Type, map[pgtype.OID]struct{}) { 28 | cachedTypes := make(map[pgtype.OID]Type, len(oids)) 29 | uncachedTypes := make(map[pgtype.OID]struct{}, len(oids)) 30 | tc.mu.Lock() 31 | defer tc.mu.Unlock() 32 | for _, oid := range oids { 33 | if t, ok := tc.types[pgtype.OID(oid)]; ok { 34 | cachedTypes[pgtype.OID(oid)] = t 35 | } else { 36 | uncachedTypes[pgtype.OID(oid)] = struct{}{} 37 | } 38 | } 39 | return cachedTypes, uncachedTypes 40 | } 41 | 42 | func (tc *typeCache) getOID(oid uint32) (Type, bool) { 43 | tc.mu.Lock() 44 | typ, ok := tc.types[pgtype.OID(oid)] 45 | tc.mu.Unlock() 46 | return typ, ok 47 | } 48 | 49 | func (tc *typeCache) addType(typ Type) { 50 | tc.mu.Lock() 51 | tc.types[typ.OID()] = typ 52 | tc.mu.Unlock() 53 | } 54 | -------------------------------------------------------------------------------- /internal/pgdocker/template.go: -------------------------------------------------------------------------------- 1 | package pgdocker 2 | 3 | type pgTemplate struct { 4 | PGPass string 5 | InitScripts []string 6 | } 7 | 8 | const dockerfileTemplate = ` 9 | {{- /*gotype: github.com/jschaf/pggen/internal/pgdocker.pgTemplate*/ -}} 10 | {{- define "dockerfile" -}} 11 | FROM postgres:13 12 | {{ range .InitScripts }} 13 | COPY {{.}} /docker-entrypoint-initdb.d/ 14 | {{ end }} 15 | {{- end }} 16 | ` 17 | -------------------------------------------------------------------------------- /internal/pginfer/explain.go: 
-------------------------------------------------------------------------------- 1 | package pginfer 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/jschaf/pggen/internal/ast" 8 | ) 9 | 10 | // PlanType is the top-level node plan type that Postgres plans for executing 11 | // query. https://www.postgresql.org/docs/13/executor.html 12 | type PlanType string 13 | 14 | const ( 15 | PlanResult PlanType = "Result" // select statement 16 | PlanLimit PlanType = "Limit" // select statement with a limit 17 | PlanModifyTable PlanType = "ModifyTable" // update, insert, or delete statement 18 | ) 19 | 20 | // Plan is the plan output from an EXPLAIN query. 21 | type Plan struct { 22 | Type PlanType 23 | Relation string // target relation if any 24 | Outputs []string // the output expressions if any 25 | } 26 | 27 | type ExplainQueryResultRow struct { 28 | Plan map[string]interface{} `json:"Plan,omitempty"` 29 | QueryIdentifier *uint64 `json:"QueryIdentifier,omitempty"` 30 | } 31 | 32 | // explainQuery executes explain plan to get the node plan type and the format 33 | // of the output columns. 34 | func (inf *Inferrer) explainQuery(query *ast.SourceQuery) (Plan, error) { 35 | explainQuery := `EXPLAIN (VERBOSE, FORMAT JSON) ` + query.PreparedSQL 36 | ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout) 37 | defer cancel() 38 | row := inf.conn.QueryRow(ctx, explainQuery, createParamArgs(query)...) 
39 | explain := make([]ExplainQueryResultRow, 0, 1) 40 | if err := row.Scan(&explain); err != nil { 41 | return Plan{}, fmt.Errorf("explain prepared query: %w", err) 42 | } 43 | if len(explain) == 0 { 44 | return Plan{}, fmt.Errorf("no explain output") 45 | } 46 | plan := explain[0].Plan 47 | if len(plan) == 0 { 48 | return Plan{}, fmt.Errorf("explain output had no 'Plan' node") 49 | } 50 | 51 | // Node type 52 | node, ok := plan["Node Type"] 53 | if !ok { 54 | return Plan{}, fmt.Errorf("explain output had no 'Plan[Node Type]' node") 55 | } 56 | strNode, ok := node.(string) 57 | if !ok { 58 | return Plan{}, fmt.Errorf("explain output 'Plan[Node Type]' is not string; got type %T for value %v", node, node) 59 | } 60 | 61 | // Relation 62 | relation := plan["Relation Name"] 63 | relationStr, _ := relation.(string) 64 | 65 | // Outputs 66 | rawOuts := plan["Output"] 67 | outs, _ := rawOuts.([]interface{}) 68 | strOuts := make([]string, len(outs)) 69 | for i, out := range outs { 70 | out, ok := out.(string) 71 | if !ok { 72 | return Plan{}, fmt.Errorf("explain output 'Plan.Output[%d]' was not a string; got type %T for value %v", i, out, out) 73 | } 74 | strOuts[i] = out 75 | } 76 | return Plan{ 77 | Type: PlanType(strNode), 78 | Relation: relationStr, 79 | Outputs: strOuts, 80 | }, nil 81 | } 82 | -------------------------------------------------------------------------------- /internal/pginfer/nullability.go: -------------------------------------------------------------------------------- 1 | package pginfer 2 | 3 | import ( 4 | "strings" 5 | "unicode" 6 | 7 | "github.com/jschaf/pggen/internal/ast" 8 | "github.com/jschaf/pggen/internal/pg" 9 | ) 10 | 11 | // isColNullable tries to prove the column is not nullable. Strive for 12 | // correctness here: it's better to assume a column is nullable when we can't 13 | // know for sure. 
14 | func isColNullable(query *ast.SourceQuery, plan Plan, out string, column pg.Column) bool { 15 | switch { 16 | case len(out) == 0: 17 | // No output? Not sure what this means but do the check here so that we 18 | // don't have to do it in each case below. 19 | return false 20 | case strings.HasPrefix(out, "'"): 21 | return false // literal string can't be null 22 | case unicode.IsDigit(rune(out[0])): 23 | return false // literal number can't be null 24 | default: 25 | // try below 26 | } 27 | 28 | // A plain select query (possibly with a LIMIT clause) with no joins where 29 | // the column comes from a table and has a not-null constraint. Not full 30 | // proof because of cross-join with comma syntax. 31 | if (plan.Type == PlanResult || plan.Type == PlanLimit) && 32 | !strings.Contains(strings.ToLower(query.PreparedSQL), "join") && 33 | !column.Null { 34 | return false 35 | } 36 | 37 | // A returning clause in an insert, update, or delete statement. The column 38 | // must come from the underlying table and must have a not null constraint. 
39 | if plan.Type == PlanModifyTable && plan.Relation == column.TableName && !column.Null { 40 | return false 41 | } 42 | return true // we can't figure it out; assume nullable 43 | } 44 | -------------------------------------------------------------------------------- /internal/pgplan/pgplan_test.go: -------------------------------------------------------------------------------- 1 | package pgplan 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/google/go-cmp/cmp" 7 | "github.com/google/go-cmp/cmp/cmpopts" 8 | "github.com/jschaf/pggen/internal/pgtest" 9 | "github.com/jschaf/pggen/internal/texts" 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestParseNode(t *testing.T) { 15 | tests := []struct { 16 | name string 17 | plan map[string]interface{} 18 | want Node 19 | }{ 20 | { 21 | name: "Result - common fields", 22 | plan: map[string]interface{}{ 23 | "Node Type": "Result", 24 | "Startup Cost": 88.8, 25 | "Total Cost": 99.9, 26 | "Plan Rows": 55.5, 27 | "Plan Width": 44, 28 | "Parallel Aware": true, 29 | "Parallel Safe": true, 30 | }, 31 | want: Result{ 32 | Plan: Plan{ 33 | StartupCost: 88.8, 34 | TotalCost: 99.9, 35 | PlanRows: 55.5, 36 | PlanWidth: 44, 37 | ParallelAware: true, 38 | ParallelSafe: true, 39 | }, 40 | }, 41 | }, 42 | } 43 | for _, tt := range tests { 44 | t.Run(tt.name, func(t *testing.T) { 45 | got, err := ParseNode(tt.plan) 46 | if err != nil { 47 | t.Fatal(err) 48 | } 49 | assert.Equal(t, tt.want, got) 50 | }) 51 | } 52 | } 53 | 54 | func TestParseNode_DB(t *testing.T) { 55 | conn, cleanupFunc := pgtest.NewPostgresSchemaString(t, texts.Dedent(` 56 | CREATE TABLE author ( 57 | author_id int PRIMARY KEY 58 | ); 59 | `)) 60 | defer cleanupFunc() 61 | tests := []struct { 62 | sql string 63 | want Node 64 | }{ 65 | { 66 | sql: "SELECT 1 AS one", 67 | want: Result{Plan{Outs: []string{"1"}}}, 68 | }, 69 | { 70 | sql: "SELECT 1 AS num UNION ALL SELECT 2 AS num", 71 | want: Append{ 72 | Plan{ 
73 | Nodes: []Node{ 74 | Result{Plan{Outs: []string{"1"}}}, 75 | Result{Plan{Outs: []string{"2"}}}, 76 | }, 77 | }, 78 | }, 79 | }, 80 | { 81 | sql: "SELECT 1 AS num UNION SELECT 2 AS num", 82 | want: Unique{ 83 | Plan: Plan{ 84 | Outs: []string{"(1)"}, 85 | Nodes: []Node{ 86 | Sort{ 87 | Plan: Plan{ 88 | Outs: []string{"(1)"}, 89 | Nodes: []Node{ 90 | Append{ 91 | Plan{Nodes: []Node{ 92 | Result{Plan{Outs: []string{"1"}}}, 93 | Result{Plan{Outs: []string{"2"}}}, 94 | }}, 95 | }, 96 | }, 97 | }, 98 | SortKey: []string{"(1)"}, 99 | }, 100 | }, 101 | }, 102 | }, 103 | }, 104 | { 105 | sql: "INSERT INTO author (author_id) VALUES (1)", 106 | want: ModifyTable{ 107 | Plan: Plan{ 108 | Nodes: []Node{Result{Plan{Outs: []string{"1"}}}}, 109 | }, 110 | Operation: OperationInsert, 111 | RelationName: "author", 112 | Alias: "author", 113 | }, 114 | }, 115 | { 116 | sql: "INSERT INTO author (author_id) VALUES (1) RETURNING author_id", 117 | want: ModifyTable{ 118 | Plan: Plan{ 119 | Outs: []string{"author.author_id"}, 120 | Nodes: []Node{Result{Plan{Outs: []string{"1"}}}}, 121 | }, 122 | Operation: OperationInsert, 123 | RelationName: "author", 124 | Alias: "author", 125 | }, 126 | }, 127 | { 128 | sql: "SELECT generate_series(1,2)", 129 | want: ProjectSet{ 130 | Plan{ 131 | Outs: []string{"generate_series(1, 2)"}, 132 | Nodes: []Node{Result{}}, 133 | }, 134 | }, 135 | }, 136 | } 137 | for _, tt := range tests { 138 | t.Run(tt.sql, func(t *testing.T) { 139 | got, err := ExplainQuery(conn, tt.sql) 140 | require.NoError(t, err) 141 | 142 | opts := cmp.Options{ 143 | cmpopts.IgnoreFields(Plan{}, 144 | "StartupCost", "TotalCost", "ParallelAware", "ParallelSafe", 145 | "PlanRows", "PlanWidth", "ParentRelationship", 146 | ), 147 | cmpopts.IgnoreFields(ModifyTable{}, "Schema"), 148 | } 149 | if diff := cmp.Diff(tt.want, got, opts); diff != "" { 150 | t.Errorf("ExplainQuery() mismatch (-want +got):\n%s", diff) 151 | } 152 | }) 153 | } 154 | } 155 | 
-------------------------------------------------------------------------------- /internal/pgtest/pg_test_db.go: -------------------------------------------------------------------------------- 1 | package pgtest 2 | 3 | import ( 4 | "context" 5 | "math/rand/v2" 6 | "os" 7 | "strconv" 8 | "strings" 9 | "testing" 10 | "time" 11 | 12 | "github.com/jackc/pgx/v4" 13 | ) 14 | 15 | // CleanupFunc deletes the schema and all database objects. 16 | type CleanupFunc func() 17 | 18 | type Option func(config *pgx.ConnConfig) 19 | 20 | // NewPostgresSchemaString opens a connection with search_path set to a randomly 21 | // named, new schema and loads the sql string. 22 | func NewPostgresSchemaString(t *testing.T, sql string, opts ...Option) (*pgx.Conn, CleanupFunc) { 23 | t.Helper() 24 | // Create a new schema. 25 | connStr := "user=postgres password=hunter2 host=localhost port=5555 dbname=pggen" 26 | ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) 27 | defer cancel() 28 | conn, err := pgx.Connect(ctx, connStr) 29 | if err != nil { 30 | t.Fatalf("connect to docker postgres: %s", err) 31 | } 32 | schema := "pggen_test_" + strconv.Itoa(int(rand.Int32())) //nolint:gosec 33 | if _, err = conn.Exec(ctx, "CREATE SCHEMA "+schema); err != nil { 34 | t.Fatalf("create new schema: %s", err) 35 | } 36 | t.Logf("created schema: %s", schema) 37 | 38 | // Load SQL files into new schema. 
39 | connStr += " search_path=" + schema 40 | connCfg, err := pgx.ParseConfig(connStr) 41 | if err != nil { 42 | t.Fatalf("parse config: %q: %s", connStr, err) 43 | } 44 | for _, opt := range opts { 45 | opt(connCfg) 46 | } 47 | schemaConn, err := pgx.ConnectConfig(ctx, connCfg) 48 | if err != nil { 49 | t.Fatalf("connect to docker postgres with search path: %s", err) 50 | } 51 | 52 | if _, err := schemaConn.Exec(ctx, sql); err != nil { 53 | t.Fatalf("run sql: %s", err) 54 | } 55 | 56 | cleanup := func() { 57 | ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) 58 | defer cancel() 59 | if _, err := conn.Exec(ctx, "DROP SCHEMA "+schema+" CASCADE"); err != nil { 60 | t.Errorf("close conn: %s", err) 61 | } 62 | if err := conn.Close(ctx); err != nil { 63 | t.Errorf("close conn: %s", err) 64 | } 65 | if err = schemaConn.Close(ctx); err != nil { 66 | t.Errorf("close schema conn: %s", err) 67 | } 68 | } 69 | return schemaConn, cleanup 70 | } 71 | 72 | // NewPostgresSchema opens a connection with search_path set to a randomly 73 | // named, new schema and loads all sqlFiles. 74 | func NewPostgresSchema(t *testing.T, sqlFiles []string, opts ...Option) (*pgx.Conn, CleanupFunc) { 75 | t.Helper() 76 | sb := &strings.Builder{} 77 | for _, file := range sqlFiles { 78 | bs, err := os.ReadFile(file) 79 | if err != nil { 80 | t.Fatalf("read test db sql file: %s", err) 81 | } 82 | sb.Write(bs) 83 | sb.WriteString(";\n\n -- FILE: ") 84 | sb.WriteString(file) 85 | sb.WriteString("\n") 86 | 87 | } 88 | return NewPostgresSchemaString(t, sb.String(), opts...) 89 | } 90 | -------------------------------------------------------------------------------- /internal/ports/port.go: -------------------------------------------------------------------------------- 1 | package ports 2 | 3 | import ( 4 | "net" 5 | 6 | "github.com/jschaf/pggen/internal/errs" 7 | ) 8 | 9 | // Port is a port. 10 | type Port = int 11 | 12 | // Licensed under BSD-3. Copyright (c) 2014, Patrick Hayes. 
13 | func FindAvailable() (p Port, mErr error) { 14 | addr, err := net.ResolveTCPAddr("tcp", "localhost:0") 15 | if err != nil { 16 | return 0, err 17 | } 18 | l, err := net.ListenTCP("tcp", addr) 19 | if err != nil { 20 | return 0, err 21 | } 22 | defer errs.Capture(&mErr, l.Close, "") 23 | return l.Addr().(*net.TCPAddr).Port, nil 24 | } 25 | -------------------------------------------------------------------------------- /internal/ptrs/ptrs.go: -------------------------------------------------------------------------------- 1 | package ptrs 2 | 3 | func Int(n int) *int { return &n } 4 | func Int32(n int32) *int32 { return &n } 5 | func Float64(f float64) *float64 { return &f } 6 | func String(s string) *string { return &s } 7 | -------------------------------------------------------------------------------- /internal/texts/dedent.go: -------------------------------------------------------------------------------- 1 | package texts 2 | 3 | import ( 4 | "bytes" 5 | "math" 6 | "strings" 7 | "unicode" 8 | ) 9 | 10 | // Dedent removes leading whitespace indentation from each line in the text. 11 | // 12 | // Whitespace is removed according to smallest whitespace prefix of a 13 | // determining line. A determining line is a line that has at least 1 non-space 14 | // character. The algorithm is: 15 | // 16 | // - If the first line is whitespace, discard it. 17 | // - If the last line is whitespace, discard it. 18 | // - For each remaining line: 19 | // - If the line only has whitespace, replace it with a single newline. 20 | // - If the line has non-whitespace chars, find the common whitespace prefix 21 | // of all such lines. 22 | // 23 | // - Remove the common whitespace prefix from each line. 
24 | func Dedent(text string) string {
25 | 	lines := strings.Split(text, "\n")
26 |
27 | 	// isBlank reports whether line has no non-whitespace rune.
28 | 	isBlank := func(line string) bool {
29 | 		return strings.TrimRightFunc(line, unicode.IsSpace) == ""
30 | 	}
31 | 	// leadingSpace counts the whitespace runes at the start of line. Widths
32 | 	// are counted in runes rather than bytes so multi-byte whitespace such
33 | 	// as U+00A0 is measured and removed without splitting a character.
34 | 	leadingSpace := func(line string) int {
35 | 		n := 0
36 | 		for _, r := range line {
37 | 			if !unicode.IsSpace(r) {
38 | 				break
39 | 			}
40 | 			n++
41 | 		}
42 | 		return n
43 | 	}
44 | 	// dropRunes returns s without its first n runes, slicing on a rune
45 | 	// boundary.
46 | 	dropRunes := func(s string, n int) string {
47 | 		for off := range s {
48 | 			if n == 0 {
49 | 				return s[off:]
50 | 			}
51 | 			n--
52 | 		}
53 | 		return ""
54 | 	}
55 |
56 | 	// The indent to remove is the smallest leading-whitespace width among
57 | 	// determining (non-blank) lines.
58 | 	indent := math.MaxInt32
59 | 	for _, line := range lines {
60 | 		if isBlank(line) {
61 | 			continue
62 | 		}
63 | 		if n := leadingSpace(line); n < indent {
64 | 			indent = n
65 | 		}
66 | 	}
67 |
68 | 	// Discard a whitespace-only first and/or last line.
69 | 	start, end := 0, len(lines)
70 | 	if isBlank(lines[0]) {
71 | 		start = 1
72 | 	}
73 | 	if isBlank(lines[len(lines)-1]) {
74 | 		end = len(lines) - 1
75 | 	}
76 | 	if end < start {
77 | 		// Only possible for a single whitespace-only line; return it as is.
78 | 		return text
79 | 	}
80 |
81 | 	kept := lines[start:end]
82 | 	b := new(bytes.Buffer)
83 | 	for i, line := range kept {
84 | 		// Trim trailing whitespace rune by rune. Testing individual bytes
85 | 		// would treat the final byte of characters like "à" (0xC3 0xA0) as
86 | 		// U+00A0 and cut the character in half.
87 | 		line = strings.TrimRightFunc(line, unicode.IsSpace)
88 | 		if line == "" {
89 | 			b.WriteString("\n")
90 | 			continue
91 | 		}
92 | 		skip := leadingSpace(line)
93 | 		if skip > indent {
94 | 			skip = indent
95 | 		}
96 | 		b.WriteString(dropRunes(line, skip))
97 | 		if i < len(kept)-1 {
98 | 			b.WriteString("\n")
99 | 		}
100 | 	}
101 |
102 | 	return b.String()
103 | }
104 |
--------------------------------------------------------------------------------
/internal/texts/dedent_test.go:
--------------------------------------------------------------------------------
1 | package texts
2 |
3 | import "testing"
4 |
5 | func TestDedent(t *testing.T) {
6 | 	tests := []struct {
7 | 		name  string
8 | 		input string
9 | 		want  string
10 | 	}{
11 | 		{"empty", "", ""},
12 | 		{"only whitespace", "\n  ", ""},
13 | 		{"trailing newline", "foo\n ", "foo"},
14 | 		{"trailing newline + whitespace", "foo\n   ", "foo"},
15 | 		{"simple", "foo", "foo"},
16 | 		{"leading space 1 line", " foo", "foo"},
17 | 		{"trailing space 1 line", "foo ", "foo"},
18 | 		{"leading + trailing space 1 line", " foo ", "foo"},
19 | 		{"preceding newline", "\n foo", "foo"},
20 | 		{"preceding newline 2 lines", "\n  foo \n bar \n", " foo\nbar"},
21 | 		{"leading space same 3 lines", " foo\n bar\n qux", "foo\nbar\nqux"},
22 | 		{"leading space diff 3 lines", "  foo\n bar\n qux", " foo\nbar\nqux"},
23 | 		{"blank middle line", "\n  foo\n\n  bar\n", "foo\n\nbar"},
24 | 		{"multi-byte char ending in 0xA0 byte", "voilà", "voilà"},
25 | 		{"multi-byte whitespace indent", "\u00a0foo\n\u00a0bar", "foo\nbar"},
26 | 		{"mixed-width whitespace indent", "\u00a0 foo\n bar", " foo\nbar"},
27 | 	}
28 | 	for _, tt := range tests {
29 | 		t.Run(tt.name, func(t *testing.T) {
30 | 			if got := Dedent(tt.input); got != tt.want {
31 | 				t.Errorf("Dedent(%q) = %q, want %q", tt.input, got, tt.want)
32 | 			}
33 | 		})
34 | 	}
35 | }
36 |
--------------------------------------------------------------------------------
/internal/token/token.go:
--------------------------------------------------------------------------------
1 | package token
2 |
3 | import "strconv"
4 |
5 | // Token is the minimal set of lexical tokens for SQL that we need to extract
6 | // queries.
7 | type Token int
8 |
9 | const (
10 | 	Illegal Token = iota
11 | 	EOF
12 | 	LineComment   // -- foo
13 | 	BlockComment  // /* foo */
14 | 	String        // 'foo', $$bar$$, $a$baz$a$
15 | 	QuotedIdent   // "foo_bar""baz"
16 | 	QueryFragment // anything else
17 | 	Semicolon     // semicolon ending a query
18 | )
19 |
20 | func (t Token) String() string {
21 | 	switch t {
22 | 	case Illegal:
23 | 		return "Illegal"
24 | 	case EOF:
25 | 		return "EOF"
26 | 	case LineComment:
27 | 		return "LineComment"
28 | 	case BlockComment:
29 | 		return "BlockComment"
30 | 	case String:
31 | 		return "String"
32 | 	case QuotedIdent:
33 | 		return "QuotedIdent"
34 | 	case QueryFragment:
35 | 		return "QueryFragment"
36 | 	case Semicolon:
37 | 		return "Semicolon"
38 | 	default:
39 | 		panic("unhandled token.String(): " + strconv.Itoa(int(t)))
40 | 	}
41 | }
42 |
--------------------------------------------------------------------------------
/script/.goreleaser.yaml:
--------------------------------------------------------------------------------
1 | # Release pggen.
2 | # https://goreleaser.com 3 | builds: 4 | - env: 5 | - CGO_ENABLED=0 6 | goos: 7 | - linux 8 | - windows 9 | - darwin 10 | goarch: 11 | - amd64 12 | - arm64 13 | main: ./cmd/pggen 14 | ldflags: 15 | - -s -w -X main.version={{.Version}} -X main.commit={{.Commit}} 16 | archives: 17 | - files: 18 | - none* 19 | format: tar.xz 20 | name_template: "{{ .ProjectName }}-{{ .Os }}-{{ .Arch }}" 21 | rlcp: true 22 | checksum: 23 | disable: true 24 | changelog: 25 | sort: asc 26 | use: github-native 27 | -------------------------------------------------------------------------------- /script/release.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -euo pipefail 4 | 5 | if [[ -z "${GITHUB_TOKEN:-}" ]]; then 6 | echo 'error: no GITHUB_TOKEN env var' 7 | exit 1 8 | fi 9 | 10 | 11 | # Download github-release if necessary. Only used to delete existing releases. 12 | # We use GoReleaser to create new releases. 13 | githubRelease='github-release' 14 | if ! command -v "$githubRelease"; then 15 | echo 'downloading github-release' 16 | githubRelease="$(mktemp)" 17 | url=https://github.com/github-release/github-release/releases/download/v0.10.0/linux-amd64-github-release.bz2 18 | curl -L --fail --silent "${url}" | bzip2 -dc >"$githubRelease" 19 | chmod +x "$githubRelease" 20 | else 21 | echo 'github-release already downloaded' 22 | fi 23 | 24 | goReleaser='goreleaser' 25 | if ! command -v "$goReleaser"; then 26 | goReleaserUrl='https://github.com/goreleaser/goreleaser/releases/download/v1.10.2/goreleaser_Linux_x86_64.tar.gz' 27 | curl -L --fail --silent "${goReleaserUrl}" | tar xvz >"$goReleaser" 28 | chmod +x "$goReleaser" 29 | else 30 | echo 'goreleaser already downloaded' 31 | fi 32 | 33 | day="$(date '+%Y-%m-%d')" 34 | 35 | # Delete the remote tag since we're creating a new release tagged today. 
36 | echo 'deleting existing release tag' 37 | git push origin ":refs/tags/$day" 2>/dev/null 38 | # Create or move the day tag to the latest commit. 39 | git tag -f "$day" 40 | git push origin "$day" 41 | 42 | # Delete any existing releases. We only support 1 release per day. 43 | # Ignore errors if we try to delete a release that doesn't exist. 44 | echo 'deleting existing releases' 45 | "$githubRelease" delete --user jschaf --repo pggen --tag "$day" || true 46 | 47 | echo 48 | echo "creating release $day" 49 | goreleaser release --config ./script/.goreleaser.yaml --clean 50 | --------------------------------------------------------------------------------