├── .github └── workflows │ └── ci.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── go.mod ├── go.sum ├── internal ├── debug │ └── debug.go ├── driver.go ├── endtoend │ └── testdata │ │ └── authors │ │ ├── go │ │ ├── db.go │ │ ├── models.go │ │ └── query.sql.go │ │ ├── query.sql │ │ ├── schema.sql │ │ └── sqlc.yaml ├── enum.go ├── field.go ├── gen.go ├── go_type.go ├── imports.go ├── inflection │ └── singular.go ├── mysql_type.go ├── opts │ ├── enum.go │ ├── go_type.go │ ├── options.go │ ├── override.go │ ├── override_test.go │ └── shim.go ├── postgresql_type.go ├── query.go ├── reserved.go ├── result.go ├── result_test.go ├── sqlite_type.go ├── struct.go ├── template.go └── templates │ ├── go-sql-driver-mysql │ └── copyfromCopy.tmpl │ ├── pgx │ ├── batchCode.tmpl │ ├── copyfromCopy.tmpl │ ├── dbCode.tmpl │ ├── interfaceCode.tmpl │ └── queryCode.tmpl │ ├── stdlib │ ├── dbCode.tmpl │ ├── interfaceCode.tmpl │ └── queryCode.tmpl │ └── template.tmpl └── plugin └── main.go /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: go 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | jobs: 8 | test: 9 | name: test 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v4 13 | - uses: actions/setup-go@v4 14 | with: 15 | go-version: '1.21' 16 | - run: make -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | bin 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Riza, Inc. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: build test 2 | 3 | build: 4 | go build ./... 5 | 6 | test: bin/sqlc-gen-go.wasm 7 | go test ./... 8 | 9 | all: bin/sqlc-gen-go bin/sqlc-gen-go.wasm 10 | 11 | bin/sqlc-gen-go: bin go.mod go.sum $(wildcard **/*.go) 12 | cd plugin && go build -o ../bin/sqlc-gen-go ./main.go 13 | 14 | bin/sqlc-gen-go.wasm: bin/sqlc-gen-go 15 | cd plugin && GOOS=wasip1 GOARCH=wasm go build -o ../bin/sqlc-gen-go.wasm main.go 16 | 17 | bin: 18 | mkdir -p bin 19 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sqlc-gen-go 2 | 3 | > [!IMPORTANT] 4 | > This repository is read-only. 
It contains a working Go codegen plugin extracted from https://github.com/sqlc-dev/sqlc which you can fork and modify to meet your needs. 5 | 6 | See [Building from source](#building-from-source) and [Migrating from sqlc's built-in Go codegen](#migrating-from-sqlcs-built-in-go-codegen) if you want to use a modified fork in your project. 7 | 8 | ## Usage 9 | 10 | ```yaml 11 | version: '2' 12 | plugins: 13 | - name: golang 14 | wasm: 15 | url: https://downloads.sqlc.dev/plugin/sqlc-gen-go_1.5.0.wasm 16 | sha256: 4ca52949f4dc04b55188439f5de0ae20af2a71e3534b87907f2a7f466bda59ec 17 | sql: 18 | - schema: schema.sql 19 | queries: query.sql 20 | engine: postgresql 21 | codegen: 22 | - plugin: golang 23 | out: db 24 | options: 25 | package: db 26 | sql_package: pgx/v5 27 | ``` 28 | 29 | ## Building from source 30 | 31 | Assuming you have the Go toolchain set up, from the project root you can simply run `make all`. 32 | 33 | ```sh 34 | make all 35 | ``` 36 | 37 | This will produce a standalone binary and a WASM blob in the `bin` directory. 38 | They don't depend on each other; they're just two different plugin styles. You can 39 | use either with sqlc, but we recommend WASM and all of the configuration examples 40 | here assume you're using a WASM plugin. 41 | 42 | To use a local WASM build with sqlc, just update your configuration with a `file://` 43 | URL pointing at the WASM blob in your `bin` directory: 44 | 45 | ```yaml 46 | plugins: 47 | - name: golang 48 | wasm: 49 | url: file:///path/to/bin/sqlc-gen-go.wasm 50 | sha256: "" 51 | ``` 52 | 53 | As of sqlc v1.24.0 the `sha256` is optional, but without it sqlc won't cache your 54 | module internally, which will impact performance. 55 | 56 | ## Migrating from sqlc's built-in Go codegen 57 | 58 | We’ve worked hard to make switching to sqlc-gen-go as seamless as possible. 
Let’s say you’re generating Go code today using a sqlc.yaml configuration that looks something like this: 59 | 60 | ```yaml 61 | version: 2 62 | sql: 63 | - schema: "query.sql" 64 | queries: "query.sql" 65 | engine: "postgresql" 66 | gen: 67 | go: 68 | package: "db" 69 | out: "db" 70 | emit_json_tags: true 71 | emit_pointers_for_null_types: true 72 | query_parameter_limit: 5 73 | overrides: 74 | - column: "authors.id" 75 | go_type: "your/package.SomeType" 76 | rename: 77 | foo: "bar" 78 | ``` 79 | 80 | To use the sqlc-gen-go WASM plugin for Go codegen, your config will instead look something like this: 81 | 82 | ```yaml 83 | version: 2 84 | plugins: 85 | - name: golang 86 | wasm: 87 | url: https://downloads.sqlc.dev/plugin/sqlc-gen-go_1.3.0.wasm 88 | sha256: e8206081686f95b461daf91a307e108a761526c6768d6f3eca9781b0726b7ec8 89 | sql: 90 | - schema: "query.sql" 91 | queries: "query.sql" 92 | engine: "postgresql" 93 | codegen: 94 | - plugin: golang 95 | out: "db" 96 | options: 97 | package: "db" 98 | emit_json_tags: true 99 | emit_pointers_for_null_types: true 100 | query_parameter_limit: 5 101 | overrides: 102 | - column: "authors.id" 103 | go_type: "your/package.SomeType" 104 | rename: 105 | foo: "bar" 106 | ``` 107 | 108 | The differences are: 109 | * An additional top-level `plugins` list with an entry for the Go codegen WASM plugin. If you’ve built the plugin from source you’ll want to use a `file://` URL. As of sqlc v1.24.0 the `sha256` field is optional (earlier releases require it), but without it sqlc won’t cache the module internally, which will impact performance. 110 | * Within the `sql` block, rather than `gen` with `go` nested beneath you’ll have a `codegen` list with an entry referencing the plugin name from the top-level `plugins` list. All options from the current `go` configuration block move as-is into the `options` block within `codegen`. The only special case is `out`, which moves up a level into the `codegen` configuration itself. 
111 | 112 | ### Global overrides and renames 113 | 114 | If you have global overrides or renames configured, you’ll need to move those to the new top-level `options` field. Replace the existing `go` field name with the name you gave your plugin in the `plugins` list. We’ve used `"golang"` in this example. 115 | 116 | If your existing configuration looks like this: 117 | 118 | ```yaml 119 | version: "2" 120 | overrides: 121 | go: 122 | rename: 123 | id: "Identifier" 124 | overrides: 125 | - db_type: "timestamptz" 126 | nullable: true 127 | engine: "postgresql" 128 | go_type: 129 | import: "gopkg.in/guregu/null.v4" 130 | package: "null" 131 | type: "Time" 132 | ... 133 | ``` 134 | 135 | Then your updated configuration would look something like this: 136 | 137 | ```yaml 138 | version: "2" 139 | plugins: 140 | - name: golang 141 | wasm: 142 | url: https://downloads.sqlc.dev/plugin/sqlc-gen-go_1.3.0.wasm 143 | sha256: e8206081686f95b461daf91a307e108a761526c6768d6f3eca9781b0726b7ec8 144 | options: 145 | golang: 146 | rename: 147 | id: "Identifier" 148 | overrides: 149 | - db_type: "timestamptz" 150 | nullable: true 151 | engine: "postgresql" 152 | go_type: 153 | import: "gopkg.in/guregu/null.v4" 154 | package: "null" 155 | type: "Time" 156 | ... 
157 | ``` 158 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/sqlc-dev/sqlc-gen-go 2 | 3 | go 1.21.3 4 | 5 | require ( 6 | github.com/fatih/structtag v1.2.0 7 | github.com/google/go-cmp v0.5.9 8 | github.com/jinzhu/inflection v1.0.0 9 | github.com/sqlc-dev/plugin-sdk-go v1.23.0 10 | ) 11 | 12 | require ( 13 | github.com/golang/protobuf v1.5.3 // indirect 14 | golang.org/x/net v0.14.0 // indirect 15 | golang.org/x/sys v0.11.0 // indirect 16 | golang.org/x/text v0.12.0 // indirect 17 | google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d // indirect 18 | google.golang.org/grpc v1.59.0 // indirect 19 | google.golang.org/protobuf v1.31.0 // indirect 20 | ) 21 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4= 2 | github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= 3 | github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= 4 | github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= 5 | github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= 6 | github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 7 | github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= 8 | github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 9 | github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= 10 | github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= 11 | github.com/sqlc-dev/plugin-sdk-go v1.23.0 
h1:iSeJhnXPlbDXlbzUEebw/DxsGzE9rdDJArl8Hvt0RMM= 12 | github.com/sqlc-dev/plugin-sdk-go v1.23.0/go.mod h1:I1r4THOfyETD+LI2gogN2LX8wCjwUZrgy/NU4In3llA= 13 | golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14= 14 | golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= 15 | golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= 16 | golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 17 | golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc= 18 | golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= 19 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 20 | google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d h1:uvYuEyMHKNt+lT4K3bN6fGswmK8qSvcreM3BwjDh+y4= 21 | google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d/go.mod h1:+Bk1OCOj40wS2hwAMA+aCW9ypzm63QTBBHp6lQ3p+9M= 22 | google.golang.org/grpc v1.59.0 h1:Z5Iec2pjwb+LEOqzpB2MR12/eKFhDPhuqW91O+4bwUk= 23 | google.golang.org/grpc v1.59.0/go.mod h1:aUPDwccQo6OTjy7Hct4AfBPD1GptF4fyUjIkQ9YtF98= 24 | google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= 25 | google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= 26 | google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= 27 | google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= 28 | -------------------------------------------------------------------------------- /internal/debug/debug.go: -------------------------------------------------------------------------------- 1 | package debug 2 | 3 | var Active bool 4 | -------------------------------------------------------------------------------- /internal/driver.go: -------------------------------------------------------------------------------- 
1 | package golang 2 | 3 | import "github.com/sqlc-dev/sqlc-gen-go/internal/opts" 4 | 5 | func parseDriver(sqlPackage string) opts.SQLDriver { 6 | switch sqlPackage { 7 | case opts.SQLPackagePGXV4: 8 | return opts.SQLDriverPGXV4 9 | case opts.SQLPackagePGXV5: 10 | return opts.SQLDriverPGXV5 11 | default: 12 | return opts.SQLDriverLibPQ 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /internal/endtoend/testdata/authors/go/db.go: -------------------------------------------------------------------------------- 1 | // Code generated by sqlc. DO NOT EDIT. 2 | // versions: 3 | // sqlc v1.23.0 4 | 5 | package querytest 6 | 7 | import ( 8 | "context" 9 | 10 | "github.com/jackc/pgx/v5" 11 | "github.com/jackc/pgx/v5/pgconn" 12 | ) 13 | 14 | type DBTX interface { 15 | Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) 16 | Query(context.Context, string, ...interface{}) (pgx.Rows, error) 17 | QueryRow(context.Context, string, ...interface{}) pgx.Row 18 | } 19 | 20 | func New(db DBTX) *Queries { 21 | return &Queries{db: db} 22 | } 23 | 24 | type Queries struct { 25 | db DBTX 26 | } 27 | 28 | func (q *Queries) WithTx(tx pgx.Tx) *Queries { 29 | return &Queries{ 30 | db: tx, 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /internal/endtoend/testdata/authors/go/models.go: -------------------------------------------------------------------------------- 1 | // Code generated by sqlc. DO NOT EDIT. 2 | // versions: 3 | // sqlc v1.23.0 4 | 5 | package querytest 6 | 7 | import ( 8 | "github.com/jackc/pgx/v5/pgtype" 9 | ) 10 | 11 | type Author struct { 12 | ID int64 13 | Name string 14 | Bio pgtype.Text 15 | } 16 | -------------------------------------------------------------------------------- /internal/endtoend/testdata/authors/go/query.sql.go: -------------------------------------------------------------------------------- 1 | // Code generated by sqlc. 
DO NOT EDIT. 2 | // versions: 3 | // sqlc v1.23.0 4 | // source: query.sql 5 | 6 | package querytest 7 | 8 | import ( 9 | "context" 10 | 11 | "github.com/jackc/pgx/v5/pgtype" 12 | ) 13 | 14 | const createAuthor = `-- name: CreateAuthor :one 15 | INSERT INTO authors ( 16 | name, bio 17 | ) VALUES ( 18 | $1, $2 19 | ) 20 | RETURNING id, name, bio 21 | ` 22 | 23 | type CreateAuthorParams struct { 24 | Name string 25 | Bio pgtype.Text 26 | } 27 | 28 | func (q *Queries) CreateAuthor(ctx context.Context, arg CreateAuthorParams) (Author, error) { 29 | row := q.db.QueryRow(ctx, createAuthor, arg.Name, arg.Bio) 30 | var i Author 31 | err := row.Scan(&i.ID, &i.Name, &i.Bio) 32 | return i, err 33 | } 34 | 35 | const deleteAuthor = `-- name: DeleteAuthor :exec 36 | DELETE FROM authors 37 | WHERE id = $1 38 | ` 39 | 40 | func (q *Queries) DeleteAuthor(ctx context.Context, id int64) error { 41 | _, err := q.db.Exec(ctx, deleteAuthor, id) 42 | return err 43 | } 44 | 45 | const getAuthor = `-- name: GetAuthor :one 46 | SELECT id, name, bio FROM authors 47 | WHERE id = $1 LIMIT 1 48 | ` 49 | 50 | func (q *Queries) GetAuthor(ctx context.Context, id int64) (Author, error) { 51 | row := q.db.QueryRow(ctx, getAuthor, id) 52 | var i Author 53 | err := row.Scan(&i.ID, &i.Name, &i.Bio) 54 | return i, err 55 | } 56 | 57 | const listAuthors = `-- name: ListAuthors :many 58 | SELECT id, name, bio FROM authors 59 | ORDER BY name 60 | ` 61 | 62 | func (q *Queries) ListAuthors(ctx context.Context) ([]Author, error) { 63 | rows, err := q.db.Query(ctx, listAuthors) 64 | if err != nil { 65 | return nil, err 66 | } 67 | defer rows.Close() 68 | var items []Author 69 | for rows.Next() { 70 | var i Author 71 | if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil { 72 | return nil, err 73 | } 74 | items = append(items, i) 75 | } 76 | if err := rows.Err(); err != nil { 77 | return nil, err 78 | } 79 | return items, nil 80 | } 81 | 
-------------------------------------------------------------------------------- /internal/endtoend/testdata/authors/query.sql: -------------------------------------------------------------------------------- 1 | -- name: GetAuthor :one 2 | SELECT * FROM authors 3 | WHERE id = $1 LIMIT 1; 4 | 5 | -- name: ListAuthors :many 6 | SELECT * FROM authors 7 | ORDER BY name; 8 | 9 | -- name: CreateAuthor :one 10 | INSERT INTO authors ( 11 | name, bio 12 | ) VALUES ( 13 | $1, $2 14 | ) 15 | RETURNING *; 16 | 17 | -- name: DeleteAuthor :exec 18 | DELETE FROM authors 19 | WHERE id = $1; -------------------------------------------------------------------------------- /internal/endtoend/testdata/authors/schema.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE authors ( 2 | id BIGSERIAL PRIMARY KEY, 3 | name text NOT NULL, 4 | bio text 5 | ); -------------------------------------------------------------------------------- /internal/endtoend/testdata/authors/sqlc.yaml: -------------------------------------------------------------------------------- 1 | version: '2' 2 | plugins: 3 | - name: golang 4 | wasm: 5 | url: https://downloads.sqlc.dev/plugin/sqlc-gen-go_1.0.0.wasm 6 | sha256: dbe302a0208afd31118fffcc268bd39b295655dfa9e3f385d2f4413544cfbed1 7 | sql: 8 | - schema: schema.sql 9 | queries: query.sql 10 | engine: postgresql 11 | codegen: 12 | - plugin: golang 13 | out: go 14 | options: 15 | package: querytest 16 | sql_package: pgx/v5 17 | -------------------------------------------------------------------------------- /internal/enum.go: -------------------------------------------------------------------------------- 1 | package golang 2 | 3 | import ( 4 | "strings" 5 | "unicode" 6 | ) 7 | 8 | type Constant struct { 9 | Name string 10 | Type string 11 | Value string 12 | } 13 | 14 | type Enum struct { 15 | Name string 16 | Comment string 17 | Constants []Constant 18 | NameTags map[string]string 19 | ValidTags map[string]string 
20 | } 21 | 22 | func (e Enum) NameTag() string { 23 | return TagsToString(e.NameTags) 24 | } 25 | 26 | func (e Enum) ValidTag() string { 27 | return TagsToString(e.ValidTags) 28 | } 29 | 30 | func enumReplacer(r rune) rune { 31 | if strings.ContainsRune("-/:_", r) { 32 | return '_' 33 | } else if (r >= 'a' && r <= 'z') || 34 | (r >= 'A' && r <= 'Z') || 35 | (r >= '0' && r <= '9') { 36 | return r 37 | } else { 38 | return -1 39 | } 40 | } 41 | 42 | // EnumReplace removes all non ident symbols (all but letters, numbers and 43 | // underscore) and returns valid ident name for provided name. 44 | func EnumReplace(value string) string { 45 | return strings.Map(enumReplacer, value) 46 | } 47 | 48 | // EnumValueName removes all non ident symbols (all but letters, numbers and 49 | // underscore) and converts snake case ident to camel case. 50 | func EnumValueName(value string) string { 51 | parts := strings.Split(EnumReplace(value), "_") 52 | for i, part := range parts { 53 | parts[i] = titleFirst(part) 54 | } 55 | 56 | return strings.Join(parts, "") 57 | } 58 | 59 | func titleFirst(s string) string { 60 | r := []rune(s) 61 | r[0] = unicode.ToUpper(r[0]) 62 | 63 | return string(r) 64 | } 65 | -------------------------------------------------------------------------------- /internal/field.go: -------------------------------------------------------------------------------- 1 | package golang 2 | 3 | import ( 4 | "fmt" 5 | "regexp" 6 | "sort" 7 | "strings" 8 | 9 | "github.com/sqlc-dev/sqlc-gen-go/internal/opts" 10 | "github.com/sqlc-dev/plugin-sdk-go/plugin" 11 | ) 12 | 13 | type Field struct { 14 | Name string // CamelCased name for Go 15 | DBName string // Name as used in the DB 16 | Type string 17 | Tags map[string]string 18 | Comment string 19 | Column *plugin.Column 20 | // EmbedFields contains the embedded fields that require scanning. 
21 | EmbedFields []Field 22 | } 23 | 24 | func (gf Field) Tag() string { 25 | return TagsToString(gf.Tags) 26 | } 27 | 28 | func (gf Field) HasSqlcSlice() bool { 29 | return gf.Column.IsSqlcSlice 30 | } 31 | 32 | func TagsToString(tags map[string]string) string { 33 | if len(tags) == 0 { 34 | return "" 35 | } 36 | tagParts := make([]string, 0, len(tags)) 37 | for key, val := range tags { 38 | tagParts = append(tagParts, fmt.Sprintf("%s:%q", key, val)) 39 | } 40 | sort.Strings(tagParts) 41 | return strings.Join(tagParts, " ") 42 | } 43 | 44 | func JSONTagName(name string, options *opts.Options) string { 45 | style := options.JsonTagsCaseStyle 46 | idUppercase := options.JsonTagsIdUppercase 47 | if style == "" || style == "none" { 48 | return name 49 | } else { 50 | return SetJSONCaseStyle(name, style, idUppercase) 51 | } 52 | } 53 | 54 | func SetCaseStyle(name string, style string) string { 55 | switch style { 56 | case "camel": 57 | return toCamelCase(name) 58 | case "pascal": 59 | return toPascalCase(name) 60 | case "snake": 61 | return toSnakeCase(name) 62 | default: 63 | panic(fmt.Sprintf("unsupported JSON tags case style: '%s'", style)) 64 | } 65 | } 66 | 67 | func SetJSONCaseStyle(name string, style string, idUppercase bool) string { 68 | switch style { 69 | case "camel": 70 | return toJsonCamelCase(name, idUppercase) 71 | case "pascal": 72 | return toPascalCase(name) 73 | case "snake": 74 | return toSnakeCase(name) 75 | default: 76 | panic(fmt.Sprintf("unsupported JSON tags case style: '%s'", style)) 77 | } 78 | } 79 | 80 | var camelPattern = regexp.MustCompile("[^A-Z][A-Z]+") 81 | 82 | func toSnakeCase(s string) string { 83 | if !strings.ContainsRune(s, '_') { 84 | s = camelPattern.ReplaceAllStringFunc(s, func(x string) string { 85 | return x[:1] + "_" + x[1:] 86 | }) 87 | } 88 | return strings.ToLower(s) 89 | } 90 | 91 | func toCamelCase(s string) string { 92 | return toCamelInitCase(s, false) 93 | } 94 | 95 | func toPascalCase(s string) string { 96 | 
return toCamelInitCase(s, true) 97 | } 98 | 99 | func toCamelInitCase(name string, initUpper bool) string { 100 | out := "" 101 | for i, p := range strings.Split(name, "_") { 102 | if !initUpper && i == 0 { 103 | out += p 104 | continue 105 | } 106 | if p == "id" { 107 | out += "ID" 108 | } else { 109 | out += strings.Title(p) 110 | } 111 | } 112 | return out 113 | } 114 | 115 | func toJsonCamelCase(name string, idUppercase bool) string { 116 | out := "" 117 | idStr := "Id" 118 | 119 | if idUppercase { 120 | idStr = "ID" 121 | } 122 | 123 | for i, p := range strings.Split(name, "_") { 124 | if i == 0 { 125 | out += p 126 | continue 127 | } 128 | if p == "id" { 129 | out += idStr 130 | } else { 131 | out += strings.Title(p) 132 | } 133 | } 134 | return out 135 | } 136 | 137 | func toLowerCase(str string) string { 138 | if str == "" { 139 | return "" 140 | } 141 | 142 | return strings.ToLower(str[:1]) + str[1:] 143 | } 144 | -------------------------------------------------------------------------------- /internal/gen.go: -------------------------------------------------------------------------------- 1 | package golang 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "context" 7 | "errors" 8 | "fmt" 9 | "go/format" 10 | "strings" 11 | "text/template" 12 | 13 | "github.com/sqlc-dev/sqlc-gen-go/internal/opts" 14 | "github.com/sqlc-dev/plugin-sdk-go/sdk" 15 | "github.com/sqlc-dev/plugin-sdk-go/metadata" 16 | "github.com/sqlc-dev/plugin-sdk-go/plugin" 17 | ) 18 | 19 | type tmplCtx struct { 20 | Q string 21 | Package string 22 | SQLDriver opts.SQLDriver 23 | Enums []Enum 24 | Structs []Struct 25 | GoQueries []Query 26 | SqlcVersion string 27 | 28 | // TODO: Race conditions 29 | SourceName string 30 | 31 | EmitJSONTags bool 32 | JsonTagsIDUppercase bool 33 | EmitDBTags bool 34 | EmitPreparedQueries bool 35 | EmitInterface bool 36 | EmitEmptySlices bool 37 | EmitMethodsWithDBArgument bool 38 | EmitEnumValidMethod bool 39 | EmitAllEnumValues bool 40 | UsesCopyFrom bool 41 | 
UsesBatch bool 42 | OmitSqlcVersion bool 43 | BuildTags string 44 | } 45 | 46 | func (t *tmplCtx) OutputQuery(sourceName string) bool { 47 | return t.SourceName == sourceName 48 | } 49 | 50 | func (t *tmplCtx) codegenDbarg() string { 51 | if t.EmitMethodsWithDBArgument { 52 | return "db DBTX, " 53 | } 54 | return "" 55 | } 56 | 57 | // Called as a global method since subtemplate queryCodeStdExec does not have 58 | // access to the toplevel tmplCtx 59 | func (t *tmplCtx) codegenEmitPreparedQueries() bool { 60 | return t.EmitPreparedQueries 61 | } 62 | 63 | func (t *tmplCtx) codegenQueryMethod(q Query) string { 64 | db := "q.db" 65 | if t.EmitMethodsWithDBArgument { 66 | db = "db" 67 | } 68 | 69 | switch q.Cmd { 70 | case ":one": 71 | if t.EmitPreparedQueries { 72 | return "q.queryRow" 73 | } 74 | return db + ".QueryRowContext" 75 | 76 | case ":many": 77 | if t.EmitPreparedQueries { 78 | return "q.query" 79 | } 80 | return db + ".QueryContext" 81 | 82 | default: 83 | if t.EmitPreparedQueries { 84 | return "q.exec" 85 | } 86 | return db + ".ExecContext" 87 | } 88 | } 89 | 90 | func (t *tmplCtx) codegenQueryRetval(q Query) (string, error) { 91 | switch q.Cmd { 92 | case ":one": 93 | return "row :=", nil 94 | case ":many": 95 | return "rows, err :=", nil 96 | case ":exec": 97 | return "_, err :=", nil 98 | case ":execrows", ":execlastid": 99 | return "result, err :=", nil 100 | case ":execresult": 101 | return "return", nil 102 | default: 103 | return "", fmt.Errorf("unhandled q.Cmd case %q", q.Cmd) 104 | } 105 | } 106 | 107 | func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.GenerateResponse, error) { 108 | options, err := opts.Parse(req) 109 | if err != nil { 110 | return nil, err 111 | } 112 | 113 | if err := opts.ValidateOpts(options); err != nil { 114 | return nil, err 115 | } 116 | 117 | enums := buildEnums(req, options) 118 | structs := buildStructs(req, options) 119 | queries, err := buildQueries(req, options, structs) 120 | if err != nil 
{ 121 | return nil, err 122 | } 123 | 124 | if options.OmitUnusedStructs { 125 | enums, structs = filterUnusedStructs(enums, structs, queries) 126 | } 127 | 128 | if err := validate(options, enums, structs, queries); err != nil { 129 | return nil, err 130 | } 131 | 132 | return generate(req, options, enums, structs, queries) 133 | } 134 | 135 | func validate(options *opts.Options, enums []Enum, structs []Struct, queries []Query) error { 136 | enumNames := make(map[string]struct{}) 137 | for _, enum := range enums { 138 | enumNames[enum.Name] = struct{}{} 139 | enumNames["Null"+enum.Name] = struct{}{} 140 | } 141 | structNames := make(map[string]struct{}) 142 | for _, struckt := range structs { 143 | if _, ok := enumNames[struckt.Name]; ok { 144 | return fmt.Errorf("struct name conflicts with enum name: %s", struckt.Name) 145 | } 146 | structNames[struckt.Name] = struct{}{} 147 | } 148 | if !options.EmitExportedQueries { 149 | return nil 150 | } 151 | for _, query := range queries { 152 | if _, ok := enumNames[query.ConstantName]; ok { 153 | return fmt.Errorf("query constant name conflicts with enum name: %s", query.ConstantName) 154 | } 155 | if _, ok := structNames[query.ConstantName]; ok { 156 | return fmt.Errorf("query constant name conflicts with struct name: %s", query.ConstantName) 157 | } 158 | } 159 | return nil 160 | } 161 | 162 | func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum, structs []Struct, queries []Query) (*plugin.GenerateResponse, error) { 163 | i := &importer{ 164 | Options: options, 165 | Queries: queries, 166 | Enums: enums, 167 | Structs: structs, 168 | } 169 | 170 | tctx := tmplCtx{ 171 | EmitInterface: options.EmitInterface, 172 | EmitJSONTags: options.EmitJsonTags, 173 | JsonTagsIDUppercase: options.JsonTagsIdUppercase, 174 | EmitDBTags: options.EmitDbTags, 175 | EmitPreparedQueries: options.EmitPreparedQueries, 176 | EmitEmptySlices: options.EmitEmptySlices, 177 | EmitMethodsWithDBArgument: 
options.EmitMethodsWithDbArgument, 178 | EmitEnumValidMethod: options.EmitEnumValidMethod, 179 | EmitAllEnumValues: options.EmitAllEnumValues, 180 | UsesCopyFrom: usesCopyFrom(queries), 181 | UsesBatch: usesBatch(queries), 182 | SQLDriver: parseDriver(options.SqlPackage), 183 | Q: "`", 184 | Package: options.Package, 185 | Enums: enums, 186 | Structs: structs, 187 | SqlcVersion: req.SqlcVersion, 188 | BuildTags: options.BuildTags, 189 | OmitSqlcVersion: options.OmitSqlcVersion, 190 | } 191 | 192 | if tctx.UsesCopyFrom && !tctx.SQLDriver.IsPGX() && options.SqlDriver != opts.SQLDriverGoSQLDriverMySQL { 193 | return nil, errors.New(":copyfrom is only supported by pgx and github.com/go-sql-driver/mysql") 194 | } 195 | 196 | if tctx.UsesCopyFrom && options.SqlDriver == opts.SQLDriverGoSQLDriverMySQL { 197 | if err := checkNoTimesForMySQLCopyFrom(queries); err != nil { 198 | return nil, err 199 | } 200 | tctx.SQLDriver = opts.SQLDriverGoSQLDriverMySQL 201 | } 202 | 203 | if tctx.UsesBatch && !tctx.SQLDriver.IsPGX() { 204 | return nil, errors.New(":batch* commands are only supported by pgx") 205 | } 206 | 207 | funcMap := template.FuncMap{ 208 | "lowerTitle": sdk.LowerTitle, 209 | "comment": sdk.DoubleSlashComment, 210 | "escape": sdk.EscapeBacktick, 211 | "imports": i.Imports, 212 | "hasImports": i.HasImports, 213 | "hasPrefix": strings.HasPrefix, 214 | 215 | // These methods are Go specific, they do not belong in the codegen package 216 | // (as that is language independent) 217 | "dbarg": tctx.codegenDbarg, 218 | "emitPreparedQueries": tctx.codegenEmitPreparedQueries, 219 | "queryMethod": tctx.codegenQueryMethod, 220 | "queryRetval": tctx.codegenQueryRetval, 221 | } 222 | 223 | tmpl := template.Must( 224 | template.New("table"). 225 | Funcs(funcMap). 
226 | ParseFS( 227 | templates, 228 | "templates/*.tmpl", 229 | "templates/*/*.tmpl", 230 | ), 231 | ) 232 | 233 | output := map[string]string{} 234 | 235 | execute := func(name, templateName string) error { 236 | imports := i.Imports(name) 237 | replacedQueries := replaceConflictedArg(imports, queries) 238 | 239 | var b bytes.Buffer 240 | w := bufio.NewWriter(&b) 241 | tctx.SourceName = name 242 | tctx.GoQueries = replacedQueries 243 | err := tmpl.ExecuteTemplate(w, templateName, &tctx) 244 | w.Flush() 245 | if err != nil { 246 | return err 247 | } 248 | code, err := format.Source(b.Bytes()) 249 | if err != nil { 250 | fmt.Println(b.String()) 251 | return fmt.Errorf("source error: %w", err) 252 | } 253 | 254 | if templateName == "queryFile" && options.OutputFilesSuffix != "" { 255 | name += options.OutputFilesSuffix 256 | } 257 | 258 | if !strings.HasSuffix(name, ".go") { 259 | name += ".go" 260 | } 261 | output[name] = string(code) 262 | return nil 263 | } 264 | 265 | dbFileName := "db.go" 266 | if options.OutputDbFileName != "" { 267 | dbFileName = options.OutputDbFileName 268 | } 269 | modelsFileName := "models.go" 270 | if options.OutputModelsFileName != "" { 271 | modelsFileName = options.OutputModelsFileName 272 | } 273 | querierFileName := "querier.go" 274 | if options.OutputQuerierFileName != "" { 275 | querierFileName = options.OutputQuerierFileName 276 | } 277 | copyfromFileName := "copyfrom.go" 278 | if options.OutputCopyfromFileName != "" { 279 | copyfromFileName = options.OutputCopyfromFileName 280 | } 281 | 282 | batchFileName := "batch.go" 283 | if options.OutputBatchFileName != "" { 284 | batchFileName = options.OutputBatchFileName 285 | } 286 | 287 | if err := execute(dbFileName, "dbFile"); err != nil { 288 | return nil, err 289 | } 290 | if err := execute(modelsFileName, "modelsFile"); err != nil { 291 | return nil, err 292 | } 293 | if options.EmitInterface { 294 | if err := execute(querierFileName, "interfaceFile"); err != nil { 295 | return 
nil, err 296 | } 297 | } 298 | if tctx.UsesCopyFrom { 299 | if err := execute(copyfromFileName, "copyfromFile"); err != nil { 300 | return nil, err 301 | } 302 | } 303 | if tctx.UsesBatch { 304 | if err := execute(batchFileName, "batchFile"); err != nil { 305 | return nil, err 306 | } 307 | } 308 | 309 | files := map[string]struct{}{} 310 | for _, gq := range queries { 311 | files[gq.SourceName] = struct{}{} 312 | } 313 | 314 | for source := range files { 315 | if err := execute(source, "queryFile"); err != nil { 316 | return nil, err 317 | } 318 | } 319 | resp := plugin.GenerateResponse{} 320 | 321 | for filename, code := range output { 322 | resp.Files = append(resp.Files, &plugin.File{ 323 | Name: filename, 324 | Contents: []byte(code), 325 | }) 326 | } 327 | 328 | return &resp, nil 329 | } 330 | 331 | func usesCopyFrom(queries []Query) bool { 332 | for _, q := range queries { 333 | if q.Cmd == metadata.CmdCopyFrom { 334 | return true 335 | } 336 | } 337 | return false 338 | } 339 | 340 | func usesBatch(queries []Query) bool { 341 | for _, q := range queries { 342 | for _, cmd := range []string{metadata.CmdBatchExec, metadata.CmdBatchMany, metadata.CmdBatchOne} { 343 | if q.Cmd == cmd { 344 | return true 345 | } 346 | } 347 | } 348 | return false 349 | } 350 | 351 | func checkNoTimesForMySQLCopyFrom(queries []Query) error { 352 | for _, q := range queries { 353 | if q.Cmd != metadata.CmdCopyFrom { 354 | continue 355 | } 356 | for _, f := range q.Arg.CopyFromMySQLFields() { 357 | if f.Type == "time.Time" { 358 | return fmt.Errorf("values with a timezone are not yet supported") 359 | } 360 | } 361 | } 362 | return nil 363 | } 364 | 365 | func filterUnusedStructs(enums []Enum, structs []Struct, queries []Query) ([]Enum, []Struct) { 366 | keepTypes := make(map[string]struct{}) 367 | 368 | for _, query := range queries { 369 | if !query.Arg.isEmpty() { 370 | keepTypes[query.Arg.Type()] = struct{}{} 371 | if query.Arg.IsStruct() { 372 | for _, field := range 
query.Arg.Struct.Fields { 373 | keepTypes[field.Type] = struct{}{} 374 | } 375 | } 376 | } 377 | if query.hasRetType() { 378 | keepTypes[query.Ret.Type()] = struct{}{} 379 | if query.Ret.IsStruct() { 380 | for _, field := range query.Ret.Struct.Fields { 381 | keepTypes[field.Type] = struct{}{} 382 | for _, embedField := range field.EmbedFields { 383 | keepTypes[embedField.Type] = struct{}{} 384 | } 385 | } 386 | } 387 | } 388 | } 389 | 390 | keepEnums := make([]Enum, 0, len(enums)) 391 | for _, enum := range enums { 392 | _, keep := keepTypes[enum.Name] 393 | _, keepNull := keepTypes["Null"+enum.Name] 394 | if keep || keepNull { 395 | keepEnums = append(keepEnums, enum) 396 | } 397 | } 398 | 399 | keepStructs := make([]Struct, 0, len(structs)) 400 | for _, st := range structs { 401 | if _, ok := keepTypes[st.Name]; ok { 402 | keepStructs = append(keepStructs, st) 403 | } 404 | } 405 | 406 | return keepEnums, keepStructs 407 | } 408 | -------------------------------------------------------------------------------- /internal/go_type.go: -------------------------------------------------------------------------------- 1 | package golang 2 | 3 | import ( 4 | "strings" 5 | 6 | "github.com/sqlc-dev/sqlc-gen-go/internal/opts" 7 | "github.com/sqlc-dev/plugin-sdk-go/sdk" 8 | "github.com/sqlc-dev/plugin-sdk-go/plugin" 9 | ) 10 | 11 | func addExtraGoStructTags(tags map[string]string, req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) { 12 | for _, override := range options.Overrides { 13 | oride := override.ShimOverride 14 | if oride.GoType.StructTags == nil { 15 | continue 16 | } 17 | if !override.Matches(col.Table, req.Catalog.DefaultSchema) { 18 | // Different table. 19 | continue 20 | } 21 | cname := col.Name 22 | if col.OriginalName != "" { 23 | cname = col.OriginalName 24 | } 25 | if !sdk.MatchString(oride.ColumnName, cname) { 26 | // Different column. 27 | continue 28 | } 29 | // Add the extra tags. 
30 | for k, v := range oride.GoType.StructTags { 31 | tags[k] = v 32 | } 33 | } 34 | } 35 | 36 | func goType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) string { 37 | // Check if the column's type has been overridden 38 | for _, override := range options.Overrides { 39 | oride := override.ShimOverride 40 | 41 | if oride.GoType.TypeName == "" { 42 | continue 43 | } 44 | cname := col.Name 45 | if col.OriginalName != "" { 46 | cname = col.OriginalName 47 | } 48 | sameTable := override.Matches(col.Table, req.Catalog.DefaultSchema) 49 | if oride.Column != "" && sdk.MatchString(oride.ColumnName, cname) && sameTable { 50 | if col.IsSqlcSlice { 51 | return "[]" + oride.GoType.TypeName 52 | } 53 | return oride.GoType.TypeName 54 | } 55 | } 56 | typ := goInnerType(req, options, col) 57 | if col.IsSqlcSlice { 58 | return "[]" + typ 59 | } 60 | if col.IsArray { 61 | return strings.Repeat("[]", int(col.ArrayDims)) + typ 62 | } 63 | return typ 64 | } 65 | 66 | func goInnerType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) string { 67 | columnType := sdk.DataType(col.Type) 68 | notNull := col.NotNull || col.IsArray 69 | 70 | // package overrides have a higher precedence 71 | for _, override := range options.Overrides { 72 | oride := override.ShimOverride 73 | if oride.GoType.TypeName == "" { 74 | continue 75 | } 76 | if oride.DbType != "" && oride.DbType == columnType && oride.Nullable != notNull && oride.Unsigned == col.Unsigned { 77 | return oride.GoType.TypeName 78 | } 79 | } 80 | 81 | // TODO: Extend the engine interface to handle types 82 | switch req.Settings.Engine { 83 | case "mysql": 84 | return mysqlType(req, options, col) 85 | case "postgresql": 86 | return postgresType(req, options, col) 87 | case "sqlite": 88 | return sqliteType(req, options, col) 89 | default: 90 | return "interface{}" 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /internal/imports.go: 
-------------------------------------------------------------------------------- 1 | package golang 2 | 3 | import ( 4 | "fmt" 5 | "sort" 6 | "strings" 7 | 8 | "github.com/sqlc-dev/sqlc-gen-go/internal/opts" 9 | "github.com/sqlc-dev/plugin-sdk-go/metadata" 10 | ) 11 | 12 | type fileImports struct { 13 | Std []ImportSpec 14 | Dep []ImportSpec 15 | } 16 | 17 | type ImportSpec struct { 18 | ID string 19 | Path string 20 | } 21 | 22 | func (s ImportSpec) String() string { 23 | if s.ID != "" { 24 | return fmt.Sprintf("%s %q", s.ID, s.Path) 25 | } else { 26 | return fmt.Sprintf("%q", s.Path) 27 | } 28 | } 29 | 30 | func mergeImports(imps ...fileImports) [][]ImportSpec { 31 | if len(imps) == 1 { 32 | return [][]ImportSpec{ 33 | imps[0].Std, 34 | imps[0].Dep, 35 | } 36 | } 37 | 38 | var stds, pkgs []ImportSpec 39 | seenStd := map[string]struct{}{} 40 | seenPkg := map[string]struct{}{} 41 | for i := range imps { 42 | for _, spec := range imps[i].Std { 43 | if _, ok := seenStd[spec.Path]; ok { 44 | continue 45 | } 46 | stds = append(stds, spec) 47 | seenStd[spec.Path] = struct{}{} 48 | } 49 | for _, spec := range imps[i].Dep { 50 | if _, ok := seenPkg[spec.Path]; ok { 51 | continue 52 | } 53 | pkgs = append(pkgs, spec) 54 | seenPkg[spec.Path] = struct{}{} 55 | } 56 | } 57 | return [][]ImportSpec{stds, pkgs} 58 | } 59 | 60 | type importer struct { 61 | Options *opts.Options 62 | Queries []Query 63 | Enums []Enum 64 | Structs []Struct 65 | } 66 | 67 | func (i *importer) usesType(typ string) bool { 68 | for _, strct := range i.Structs { 69 | for _, f := range strct.Fields { 70 | if hasPrefixIgnoringSliceAndPointerPrefix(f.Type, typ) { 71 | return true 72 | } 73 | } 74 | } 75 | return false 76 | } 77 | 78 | func (i *importer) HasImports(filename string) bool { 79 | imports := i.Imports(filename) 80 | return len(imports[0]) != 0 || len(imports[1]) != 0 81 | } 82 | 83 | func (i *importer) Imports(filename string) [][]ImportSpec { 84 | dbFileName := "db.go" 85 | if 
i.Options.OutputDbFileName != "" { 86 | dbFileName = i.Options.OutputDbFileName 87 | } 88 | modelsFileName := "models.go" 89 | if i.Options.OutputModelsFileName != "" { 90 | modelsFileName = i.Options.OutputModelsFileName 91 | } 92 | querierFileName := "querier.go" 93 | if i.Options.OutputQuerierFileName != "" { 94 | querierFileName = i.Options.OutputQuerierFileName 95 | } 96 | copyfromFileName := "copyfrom.go" 97 | if i.Options.OutputCopyfromFileName != "" { 98 | copyfromFileName = i.Options.OutputCopyfromFileName 99 | } 100 | batchFileName := "batch.go" 101 | if i.Options.OutputBatchFileName != "" { 102 | batchFileName = i.Options.OutputBatchFileName 103 | } 104 | 105 | switch filename { 106 | case dbFileName: 107 | return mergeImports(i.dbImports()) 108 | case modelsFileName: 109 | return mergeImports(i.modelImports()) 110 | case querierFileName: 111 | return mergeImports(i.interfaceImports()) 112 | case copyfromFileName: 113 | return mergeImports(i.copyfromImports()) 114 | case batchFileName: 115 | return mergeImports(i.batchImports()) 116 | default: 117 | return mergeImports(i.queryImports(filename)) 118 | } 119 | } 120 | 121 | func (i *importer) dbImports() fileImports { 122 | var pkg []ImportSpec 123 | std := []ImportSpec{ 124 | {Path: "context"}, 125 | } 126 | 127 | sqlpkg := parseDriver(i.Options.SqlPackage) 128 | switch sqlpkg { 129 | case opts.SQLDriverPGXV4: 130 | pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgconn"}) 131 | pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgx/v4"}) 132 | case opts.SQLDriverPGXV5: 133 | pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgx/v5/pgconn"}) 134 | pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgx/v5"}) 135 | default: 136 | std = append(std, ImportSpec{Path: "database/sql"}) 137 | if i.Options.EmitPreparedQueries { 138 | std = append(std, ImportSpec{Path: "fmt"}) 139 | } 140 | } 141 | 142 | sort.Slice(std, func(i, j int) bool { return std[i].Path < std[j].Path }) 143 | sort.Slice(pkg, 
func(i, j int) bool { return pkg[i].Path < pkg[j].Path }) 144 | return fileImports{Std: std, Dep: pkg} 145 | } 146 | 147 | var stdlibTypes = map[string]string{ 148 | "json.RawMessage": "encoding/json", 149 | "time.Time": "time", 150 | "net.IP": "net", 151 | "net.HardwareAddr": "net", 152 | "netip.Addr": "net/netip", 153 | "netip.Prefix": "net/netip", 154 | } 155 | 156 | var pqtypeTypes = map[string]struct{}{ 157 | "pqtype.CIDR": {}, 158 | "pqtype.Inet": {}, 159 | "pqtype.Macaddr": {}, 160 | "pqtype.NullRawMessage": {}, 161 | } 162 | 163 | func buildImports(options *opts.Options, queries []Query, uses func(string) bool) (map[string]struct{}, map[ImportSpec]struct{}) { 164 | pkg := make(map[ImportSpec]struct{}) 165 | std := make(map[string]struct{}) 166 | 167 | if uses("sql.Null") { 168 | std["database/sql"] = struct{}{} 169 | } 170 | 171 | sqlpkg := parseDriver(options.SqlPackage) 172 | for _, q := range queries { 173 | if q.Cmd == metadata.CmdExecResult { 174 | switch sqlpkg { 175 | case opts.SQLDriverPGXV4: 176 | pkg[ImportSpec{Path: "github.com/jackc/pgconn"}] = struct{}{} 177 | case opts.SQLDriverPGXV5: 178 | pkg[ImportSpec{Path: "github.com/jackc/pgx/v5/pgconn"}] = struct{}{} 179 | default: 180 | std["database/sql"] = struct{}{} 181 | } 182 | } 183 | } 184 | 185 | for typeName, pkg := range stdlibTypes { 186 | if uses(typeName) { 187 | std[pkg] = struct{}{} 188 | } 189 | } 190 | 191 | if uses("pgtype.") { 192 | if sqlpkg == opts.SQLDriverPGXV5 { 193 | pkg[ImportSpec{Path: "github.com/jackc/pgx/v5/pgtype"}] = struct{}{} 194 | } else { 195 | pkg[ImportSpec{Path: "github.com/jackc/pgtype"}] = struct{}{} 196 | } 197 | } 198 | 199 | for typeName := range pqtypeTypes { 200 | if uses(typeName) { 201 | pkg[ImportSpec{Path: "github.com/sqlc-dev/pqtype"}] = struct{}{} 202 | break 203 | } 204 | } 205 | 206 | overrideTypes := map[string]string{} 207 | for _, override := range options.Overrides { 208 | o := override.ShimOverride 209 | if o.GoType.BasicType || 
o.GoType.TypeName == "" { 210 | continue 211 | } 212 | overrideTypes[o.GoType.TypeName] = o.GoType.ImportPath 213 | } 214 | 215 | _, overrideNullTime := overrideTypes["pq.NullTime"] 216 | if uses("pq.NullTime") && !overrideNullTime { 217 | pkg[ImportSpec{Path: "github.com/lib/pq"}] = struct{}{} 218 | } 219 | _, overrideUUID := overrideTypes["uuid.UUID"] 220 | if uses("uuid.UUID") && !overrideUUID { 221 | pkg[ImportSpec{Path: "github.com/google/uuid"}] = struct{}{} 222 | } 223 | _, overrideNullUUID := overrideTypes["uuid.NullUUID"] 224 | if uses("uuid.NullUUID") && !overrideNullUUID { 225 | pkg[ImportSpec{Path: "github.com/google/uuid"}] = struct{}{} 226 | } 227 | _, overrideVector := overrideTypes["pgvector.Vector"] 228 | if uses("pgvector.Vector") && !overrideVector { 229 | pkg[ImportSpec{Path: "github.com/pgvector/pgvector-go"}] = struct{}{} 230 | } 231 | 232 | // Custom imports 233 | for _, override := range options.Overrides { 234 | o := override.ShimOverride 235 | 236 | if o.GoType.BasicType || o.GoType.TypeName == "" { 237 | continue 238 | } 239 | _, alreadyImported := std[o.GoType.ImportPath] 240 | hasPackageAlias := o.GoType.Package != "" 241 | if (!alreadyImported || hasPackageAlias) && uses(o.GoType.TypeName) { 242 | pkg[ImportSpec{Path: o.GoType.ImportPath, ID: o.GoType.Package}] = struct{}{} 243 | } 244 | } 245 | 246 | return std, pkg 247 | } 248 | 249 | func (i *importer) interfaceImports() fileImports { 250 | std, pkg := buildImports(i.Options, i.Queries, func(name string) bool { 251 | for _, q := range i.Queries { 252 | if q.hasRetType() { 253 | if usesBatch([]Query{q}) { 254 | continue 255 | } 256 | if hasPrefixIgnoringSliceAndPointerPrefix(q.Ret.Type(), name) { 257 | return true 258 | } 259 | } 260 | for _, f := range q.Arg.Pairs() { 261 | if hasPrefixIgnoringSliceAndPointerPrefix(f.Type, name) { 262 | return true 263 | } 264 | } 265 | } 266 | return false 267 | }) 268 | 269 | std["context"] = struct{}{} 270 | 271 | return sortedImports(std, pkg) 
272 | } 273 | 274 | func (i *importer) modelImports() fileImports { 275 | std, pkg := buildImports(i.Options, nil, i.usesType) 276 | 277 | if len(i.Enums) > 0 { 278 | std["fmt"] = struct{}{} 279 | std["database/sql/driver"] = struct{}{} 280 | } 281 | 282 | return sortedImports(std, pkg) 283 | } 284 | 285 | func sortedImports(std map[string]struct{}, pkg map[ImportSpec]struct{}) fileImports { 286 | pkgs := make([]ImportSpec, 0, len(pkg)) 287 | for spec := range pkg { 288 | pkgs = append(pkgs, spec) 289 | } 290 | stds := make([]ImportSpec, 0, len(std)) 291 | for path := range std { 292 | stds = append(stds, ImportSpec{Path: path}) 293 | } 294 | sort.Slice(stds, func(i, j int) bool { return stds[i].Path < stds[j].Path }) 295 | sort.Slice(pkgs, func(i, j int) bool { return pkgs[i].Path < pkgs[j].Path }) 296 | return fileImports{stds, pkgs} 297 | } 298 | 299 | func (i *importer) queryImports(filename string) fileImports { 300 | var gq []Query 301 | anyNonCopyFrom := false 302 | for _, query := range i.Queries { 303 | if usesBatch([]Query{query}) { 304 | continue 305 | } 306 | if query.SourceName == filename { 307 | gq = append(gq, query) 308 | if query.Cmd != metadata.CmdCopyFrom { 309 | anyNonCopyFrom = true 310 | } 311 | } 312 | } 313 | 314 | std, pkg := buildImports(i.Options, gq, func(name string) bool { 315 | for _, q := range gq { 316 | if q.hasRetType() { 317 | if q.Ret.EmitStruct() { 318 | for _, f := range q.Ret.Struct.Fields { 319 | if hasPrefixIgnoringSliceAndPointerPrefix(f.Type, name) { 320 | return true 321 | } 322 | } 323 | } 324 | if hasPrefixIgnoringSliceAndPointerPrefix(q.Ret.Type(), name) { 325 | return true 326 | } 327 | } 328 | // Check the fields of the argument struct if it's emitted 329 | if q.Arg.EmitStruct() { 330 | for _, f := range q.Arg.Struct.Fields { 331 | if hasPrefixIgnoringSliceAndPointerPrefix(f.Type, name) { 332 | return true 333 | } 334 | } 335 | } 336 | // Check the argument pairs inside the method definition 337 | for _, f := range 
q.Arg.Pairs() { 338 | if hasPrefixIgnoringSliceAndPointerPrefix(f.Type, name) { 339 | return true 340 | } 341 | } 342 | } 343 | return false 344 | }) 345 | 346 | sliceScan := func() bool { 347 | for _, q := range gq { 348 | if q.hasRetType() { 349 | if q.Ret.IsStruct() { 350 | for _, f := range q.Ret.Struct.Fields { 351 | if strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" { 352 | return true 353 | } 354 | for _, embed := range f.EmbedFields { 355 | if strings.HasPrefix(embed.Type, "[]") && embed.Type != "[]byte" { 356 | return true 357 | } 358 | } 359 | } 360 | } else { 361 | if strings.HasPrefix(q.Ret.Type(), "[]") && q.Ret.Type() != "[]byte" { 362 | return true 363 | } 364 | } 365 | } 366 | if !q.Arg.isEmpty() { 367 | if q.Arg.IsStruct() { 368 | for _, f := range q.Arg.Struct.Fields { 369 | if strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && !f.HasSqlcSlice() { 370 | return true 371 | } 372 | } 373 | } else { 374 | if strings.HasPrefix(q.Arg.Type(), "[]") && q.Arg.Type() != "[]byte" && !q.Arg.HasSqlcSlices() { 375 | return true 376 | } 377 | } 378 | } 379 | } 380 | return false 381 | } 382 | 383 | // Search for sqlc.slice() calls 384 | sqlcSliceScan := func() bool { 385 | for _, q := range gq { 386 | if q.Arg.HasSqlcSlices() { 387 | return true 388 | } 389 | } 390 | return false 391 | } 392 | 393 | if anyNonCopyFrom { 394 | std["context"] = struct{}{} 395 | } 396 | 397 | sqlpkg := parseDriver(i.Options.SqlPackage) 398 | if sqlcSliceScan() && !sqlpkg.IsPGX() { 399 | std["strings"] = struct{}{} 400 | } 401 | if sliceScan() && !sqlpkg.IsPGX() { 402 | pkg[ImportSpec{Path: "github.com/lib/pq"}] = struct{}{} 403 | } 404 | 405 | return sortedImports(std, pkg) 406 | } 407 | 408 | func (i *importer) copyfromImports() fileImports { 409 | copyFromQueries := make([]Query, 0, len(i.Queries)) 410 | for _, q := range i.Queries { 411 | if q.Cmd == metadata.CmdCopyFrom { 412 | copyFromQueries = append(copyFromQueries, q) 413 | } 414 | } 415 | std, pkg := 
buildImports(i.Options, copyFromQueries, func(name string) bool { 416 | for _, q := range copyFromQueries { 417 | if q.hasRetType() { 418 | if strings.HasPrefix(q.Ret.Type(), name) { 419 | return true 420 | } 421 | } 422 | if !q.Arg.isEmpty() { 423 | if strings.HasPrefix(q.Arg.Type(), name) { 424 | return true 425 | } 426 | } 427 | } 428 | return false 429 | }) 430 | 431 | std["context"] = struct{}{} 432 | if i.Options.SqlDriver == opts.SQLDriverGoSQLDriverMySQL { 433 | std["io"] = struct{}{} 434 | std["fmt"] = struct{}{} 435 | std["sync/atomic"] = struct{}{} 436 | pkg[ImportSpec{Path: "github.com/go-sql-driver/mysql"}] = struct{}{} 437 | pkg[ImportSpec{Path: "github.com/hexon/mysqltsv"}] = struct{}{} 438 | } 439 | 440 | return sortedImports(std, pkg) 441 | } 442 | 443 | func (i *importer) batchImports() fileImports { 444 | batchQueries := make([]Query, 0, len(i.Queries)) 445 | for _, q := range i.Queries { 446 | if usesBatch([]Query{q}) { 447 | batchQueries = append(batchQueries, q) 448 | } 449 | } 450 | std, pkg := buildImports(i.Options, batchQueries, func(name string) bool { 451 | for _, q := range batchQueries { 452 | if q.hasRetType() { 453 | if q.Ret.EmitStruct() { 454 | for _, f := range q.Ret.Struct.Fields { 455 | if hasPrefixIgnoringSliceAndPointerPrefix(f.Type, name) { 456 | return true 457 | } 458 | } 459 | } 460 | if hasPrefixIgnoringSliceAndPointerPrefix(q.Ret.Type(), name) { 461 | return true 462 | } 463 | } 464 | if q.Arg.EmitStruct() { 465 | for _, f := range q.Arg.Struct.Fields { 466 | if hasPrefixIgnoringSliceAndPointerPrefix(f.Type, name) { 467 | return true 468 | } 469 | } 470 | } 471 | for _, f := range q.Arg.Pairs() { 472 | if hasPrefixIgnoringSliceAndPointerPrefix(f.Type, name) { 473 | return true 474 | } 475 | } 476 | } 477 | return false 478 | }) 479 | 480 | std["context"] = struct{}{} 481 | std["errors"] = struct{}{} 482 | sqlpkg := parseDriver(i.Options.SqlPackage) 483 | switch sqlpkg { 484 | case opts.SQLDriverPGXV4: 485 | 
// trimSliceAndPointerPrefix removes every leading "[]" and "*" marker from
// a Go type expression and returns what is left, e.g. "[][]*uuid.UUID"
// becomes "uuid.UUID". A value without such markers is returned unchanged.
//
// The markers are stripped repeatedly rather than once each: array columns
// are rendered with one "[]" per dimension, so a single trim would leave
// "[]uuid.UUID" for a two-dimensional array and the type would no longer be
// recognised when deciding which packages to import.
func trimSliceAndPointerPrefix(v string) string {
	for {
		switch {
		case strings.HasPrefix(v, "[]"):
			v = v[len("[]"):]
		case strings.HasPrefix(v, "*"):
			v = v[len("*"):]
		default:
			return v
		}
	}
}

// hasPrefixIgnoringSliceAndPointerPrefix reports whether s starts with
// prefix once the leading slice and pointer markers have been removed from
// both of them.
func hasPrefixIgnoringSliceAndPointerPrefix(s, prefix string) bool {
	return strings.HasPrefix(trimSliceAndPointerPrefix(s), trimSliceAndPointerPrefix(prefix))
}
| if strings.ToLower(s.Name) == "campus" { 26 | return s.Name 27 | } 28 | // Manual fix for incorrect handling of "meta" 29 | // 30 | // https://github.com/sqlc-dev/sqlc/issues/1217 31 | // https://github.com/jinzhu/inflection/issues/21 32 | if strings.ToLower(s.Name) == "meta" { 33 | return s.Name 34 | } 35 | // Manual fix for incorrect handling of "calories" 36 | // 37 | // https://github.com/sqlc-dev/sqlc/issues/2017 38 | // https://github.com/jinzhu/inflection/issues/23 39 | if strings.ToLower(s.Name) == "calories" { 40 | return "calorie" 41 | } 42 | // Manual fix for incorrect handling of "-ves" suffix 43 | if strings.ToLower(s.Name) == "waves" { 44 | return "wave" 45 | } 46 | 47 | if strings.ToLower(s.Name) == "metadata" { 48 | return "metadata" 49 | } 50 | 51 | return upstream.Singular(s.Name) 52 | } 53 | -------------------------------------------------------------------------------- /internal/mysql_type.go: -------------------------------------------------------------------------------- 1 | package golang 2 | 3 | import ( 4 | "log" 5 | 6 | "github.com/sqlc-dev/sqlc-gen-go/internal/opts" 7 | "github.com/sqlc-dev/plugin-sdk-go/sdk" 8 | "github.com/sqlc-dev/sqlc-gen-go/internal/debug" 9 | "github.com/sqlc-dev/plugin-sdk-go/plugin" 10 | ) 11 | 12 | func mysqlType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) string { 13 | columnType := sdk.DataType(col.Type) 14 | notNull := col.NotNull || col.IsArray 15 | unsigned := col.Unsigned 16 | 17 | switch columnType { 18 | 19 | case "varchar", "text", "char", "tinytext", "mediumtext", "longtext": 20 | if notNull { 21 | return "string" 22 | } 23 | return "sql.NullString" 24 | 25 | case "tinyint": 26 | if col.Length == 1 { 27 | if notNull { 28 | return "bool" 29 | } 30 | return "sql.NullBool" 31 | } else { 32 | if notNull { 33 | if unsigned { 34 | return "uint8" 35 | } 36 | return "int8" 37 | } 38 | // The database/sql package does not have a sql.NullInt8 type, so we 39 | // use the smallest 
type they have which is NullInt16 40 | return "sql.NullInt16" 41 | } 42 | 43 | case "year": 44 | if notNull { 45 | return "int16" 46 | } 47 | return "sql.NullInt16" 48 | 49 | case "smallint": 50 | if notNull { 51 | if unsigned { 52 | return "uint16" 53 | } 54 | return "int16" 55 | } 56 | return "sql.NullInt16" 57 | 58 | case "int", "integer", "mediumint": 59 | if notNull { 60 | if unsigned { 61 | return "uint32" 62 | } 63 | return "int32" 64 | } 65 | return "sql.NullInt32" 66 | 67 | case "bigint": 68 | if notNull { 69 | if unsigned { 70 | return "uint64" 71 | } 72 | return "int64" 73 | } 74 | return "sql.NullInt64" 75 | 76 | case "blob", "binary", "varbinary", "tinyblob", "mediumblob", "longblob": 77 | if notNull { 78 | return "[]byte" 79 | } 80 | return "sql.NullString" 81 | 82 | case "double", "double precision", "real", "float": 83 | if notNull { 84 | return "float64" 85 | } 86 | return "sql.NullFloat64" 87 | 88 | case "decimal", "dec", "fixed": 89 | if notNull { 90 | return "string" 91 | } 92 | return "sql.NullString" 93 | 94 | case "enum": 95 | // TODO: Proper Enum support 96 | return "string" 97 | 98 | case "date", "timestamp", "datetime", "time": 99 | if notNull { 100 | return "time.Time" 101 | } 102 | return "sql.NullTime" 103 | 104 | case "boolean", "bool": 105 | if notNull { 106 | return "bool" 107 | } 108 | return "sql.NullBool" 109 | 110 | case "json": 111 | return "json.RawMessage" 112 | 113 | case "any": 114 | return "interface{}" 115 | 116 | default: 117 | for _, schema := range req.Catalog.Schemas { 118 | for _, enum := range schema.Enums { 119 | if enum.Name == columnType { 120 | if notNull { 121 | if schema.Name == req.Catalog.DefaultSchema { 122 | return StructName(enum.Name, options) 123 | } 124 | return StructName(schema.Name+"_"+enum.Name, options) 125 | } else { 126 | if schema.Name == req.Catalog.DefaultSchema { 127 | return "Null" + StructName(enum.Name, options) 128 | } 129 | return "Null" + StructName(schema.Name+"_"+enum.Name, options) 
130 | } 131 | } 132 | } 133 | } 134 | if debug.Active { 135 | log.Printf("Unknown MySQL type: %s\n", columnType) 136 | } 137 | return "interface{}" 138 | 139 | } 140 | } 141 | -------------------------------------------------------------------------------- /internal/opts/enum.go: -------------------------------------------------------------------------------- 1 | package opts 2 | 3 | import "fmt" 4 | 5 | type SQLDriver string 6 | 7 | const ( 8 | SQLPackagePGXV4 string = "pgx/v4" 9 | SQLPackagePGXV5 string = "pgx/v5" 10 | SQLPackageStandard string = "database/sql" 11 | ) 12 | 13 | var validPackages = map[string]struct{}{ 14 | string(SQLPackagePGXV4): {}, 15 | string(SQLPackagePGXV5): {}, 16 | string(SQLPackageStandard): {}, 17 | } 18 | 19 | func validatePackage(sqlPackage string) error { 20 | if _, found := validPackages[sqlPackage]; !found { 21 | return fmt.Errorf("unknown SQL package: %s", sqlPackage) 22 | } 23 | return nil 24 | } 25 | 26 | const ( 27 | SQLDriverPGXV4 SQLDriver = "github.com/jackc/pgx/v4" 28 | SQLDriverPGXV5 = "github.com/jackc/pgx/v5" 29 | SQLDriverLibPQ = "github.com/lib/pq" 30 | SQLDriverGoSQLDriverMySQL = "github.com/go-sql-driver/mysql" 31 | ) 32 | 33 | var validDrivers = map[string]struct{}{ 34 | string(SQLDriverPGXV4): {}, 35 | string(SQLDriverPGXV5): {}, 36 | string(SQLDriverLibPQ): {}, 37 | string(SQLDriverGoSQLDriverMySQL): {}, 38 | } 39 | 40 | func validateDriver(sqlDriver string) error { 41 | if _, found := validDrivers[sqlDriver]; !found { 42 | return fmt.Errorf("unknown SQL driver: %s", sqlDriver) 43 | } 44 | return nil 45 | } 46 | 47 | func (d SQLDriver) IsPGX() bool { 48 | return d == SQLDriverPGXV4 || d == SQLDriverPGXV5 49 | } 50 | 51 | func (d SQLDriver) IsGoSQLDriverMySQL() bool { 52 | return d == SQLDriverGoSQLDriverMySQL 53 | } 54 | 55 | func (d SQLDriver) Package() string { 56 | switch d { 57 | case SQLDriverPGXV4: 58 | return SQLPackagePGXV4 59 | case SQLDriverPGXV5: 60 | return SQLPackagePGXV5 61 | default: 62 | return 
SQLPackageStandard 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /internal/opts/go_type.go: -------------------------------------------------------------------------------- 1 | package opts 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "go/types" 7 | "regexp" 8 | "strings" 9 | 10 | "github.com/fatih/structtag" 11 | ) 12 | 13 | type GoType struct { 14 | Path string `json:"import" yaml:"import"` 15 | Package string `json:"package" yaml:"package"` 16 | Name string `json:"type" yaml:"type"` 17 | Pointer bool `json:"pointer" yaml:"pointer"` 18 | Slice bool `json:"slice" yaml:"slice"` 19 | Spec string `json:"-"` 20 | BuiltIn bool `json:"-"` 21 | } 22 | 23 | type ParsedGoType struct { 24 | ImportPath string 25 | Package string 26 | TypeName string 27 | BasicType bool 28 | StructTag string 29 | } 30 | 31 | func (o *GoType) MarshalJSON() ([]byte, error) { 32 | if o.Spec != "" { 33 | return json.Marshal(o.Spec) 34 | } 35 | type alias GoType 36 | return json.Marshal(alias(*o)) 37 | } 38 | 39 | func (o *GoType) UnmarshalJSON(data []byte) error { 40 | var spec string 41 | if err := json.Unmarshal(data, &spec); err == nil { 42 | *o = GoType{Spec: spec} 43 | return nil 44 | } 45 | type alias GoType 46 | var a alias 47 | if err := json.Unmarshal(data, &a); err != nil { 48 | return err 49 | } 50 | *o = GoType(a) 51 | return nil 52 | } 53 | 54 | func (o *GoType) UnmarshalYAML(unmarshal func(interface{}) error) error { 55 | var spec string 56 | if err := unmarshal(&spec); err == nil { 57 | *o = GoType{Spec: spec} 58 | return nil 59 | } 60 | type alias GoType 61 | var a alias 62 | if err := unmarshal(&a); err != nil { 63 | return err 64 | } 65 | *o = GoType(a) 66 | return nil 67 | } 68 | 69 | var validIdentifier = regexp.MustCompile(`^[a-zA-Z0-9_]+$`) 70 | var versionNumber = regexp.MustCompile(`^v[0-9]+$`) 71 | var invalidIdentifier = regexp.MustCompile(`[^a-zA-Z0-9_]`) 72 | 73 | func generatePackageID(importPath string) 
(string, bool) { 74 | parts := strings.Split(importPath, "/") 75 | name := parts[len(parts)-1] 76 | // If the last part of the import path is a valid identifier, assume that's the package name 77 | if versionNumber.MatchString(name) && len(parts) >= 2 { 78 | name = parts[len(parts)-2] 79 | return invalidIdentifier.ReplaceAllString(strings.ToLower(name), "_"), true 80 | } 81 | if validIdentifier.MatchString(name) { 82 | return name, false 83 | } 84 | return invalidIdentifier.ReplaceAllString(strings.ToLower(name), "_"), true 85 | } 86 | 87 | // validate GoType 88 | func (gt GoType) parse() (*ParsedGoType, error) { 89 | var o ParsedGoType 90 | 91 | if gt.Spec == "" { 92 | // TODO: Validation 93 | if gt.Path == "" && gt.Package != "" { 94 | return nil, fmt.Errorf("Package override `go_type`: package name requires an import path") 95 | } 96 | var pkg string 97 | var pkgNeedsAlias bool 98 | 99 | if gt.Package == "" && gt.Path != "" { 100 | pkg, pkgNeedsAlias = generatePackageID(gt.Path) 101 | if pkgNeedsAlias { 102 | o.Package = pkg 103 | } 104 | } else { 105 | pkg = gt.Package 106 | o.Package = gt.Package 107 | } 108 | 109 | o.ImportPath = gt.Path 110 | o.TypeName = gt.Name 111 | o.BasicType = gt.Path == "" && gt.Package == "" 112 | if pkg != "" { 113 | o.TypeName = pkg + "." 
+ o.TypeName 114 | } 115 | if gt.Pointer { 116 | o.TypeName = "*" + o.TypeName 117 | } 118 | if gt.Slice { 119 | o.TypeName = "[]" + o.TypeName 120 | } 121 | return &o, nil 122 | } 123 | 124 | input := gt.Spec 125 | lastDot := strings.LastIndex(input, ".") 126 | lastSlash := strings.LastIndex(input, "/") 127 | typename := input 128 | if lastDot == -1 && lastSlash == -1 { 129 | // if the type name has no slash and no dot, validate that the type is a basic Go type 130 | var found bool 131 | for _, typ := range types.Typ { 132 | info := typ.Info() 133 | if info == 0 { 134 | continue 135 | } 136 | if info&types.IsUntyped != 0 { 137 | continue 138 | } 139 | if typename == typ.Name() { 140 | found = true 141 | } 142 | } 143 | if !found { 144 | return nil, fmt.Errorf("Package override `go_type` specifier %q is not a Go basic type e.g. 'string'", input) 145 | } 146 | o.BasicType = true 147 | } else { 148 | // assume the type lives in a Go package 149 | if lastDot == -1 { 150 | return nil, fmt.Errorf("Package override `go_type` specifier %q is not the proper format, expected 'package.type', e.g. 'github.com/segmentio/ksuid.KSUID'", input) 151 | } 152 | typename = input[lastSlash+1:] 153 | // a package name beginning with "go-" will give syntax errors in 154 | // generated code. We should do the right thing and get the actual 155 | // import name, but in lieu of that, stripping the leading "go-" may get 156 | // us what we want. 157 | typename = strings.TrimPrefix(typename, "go-") 158 | typename = strings.TrimSuffix(typename, "-go") 159 | o.ImportPath = input[:lastDot] 160 | } 161 | o.TypeName = typename 162 | isPointer := input[0] == '*' 163 | if isPointer { 164 | o.ImportPath = o.ImportPath[1:] 165 | o.TypeName = "*" + o.TypeName 166 | } 167 | return &o, nil 168 | } 169 | 170 | // GoStructTag is a raw Go struct tag. 171 | type GoStructTag string 172 | 173 | // Parse parses and validates a GoStructTag. 174 | // The output is in a form convenient for codegen. 
175 | // 176 | // Sample valid inputs/outputs: 177 | // 178 | // In Out 179 | // empty string {} 180 | // `a:"b"` {"a": "b"} 181 | // `a:"b" x:"y,z"` {"a": "b", "x": "y,z"} 182 | func (s GoStructTag) parse() (map[string]string, error) { 183 | m := make(map[string]string) 184 | tags, err := structtag.Parse(string(s)) 185 | if err != nil { 186 | return nil, err 187 | } 188 | for _, tag := range tags.Tags() { 189 | m[tag.Key] = tag.Value() 190 | } 191 | return m, nil 192 | } 193 | -------------------------------------------------------------------------------- /internal/opts/options.go: -------------------------------------------------------------------------------- 1 | package opts 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "maps" 7 | "path/filepath" 8 | 9 | "github.com/sqlc-dev/plugin-sdk-go/plugin" 10 | ) 11 | 12 | type Options struct { 13 | EmitInterface bool `json:"emit_interface" yaml:"emit_interface"` 14 | EmitJsonTags bool `json:"emit_json_tags" yaml:"emit_json_tags"` 15 | JsonTagsIdUppercase bool `json:"json_tags_id_uppercase" yaml:"json_tags_id_uppercase"` 16 | EmitDbTags bool `json:"emit_db_tags" yaml:"emit_db_tags"` 17 | EmitPreparedQueries bool `json:"emit_prepared_queries" yaml:"emit_prepared_queries"` 18 | EmitExactTableNames bool `json:"emit_exact_table_names,omitempty" yaml:"emit_exact_table_names"` 19 | EmitEmptySlices bool `json:"emit_empty_slices,omitempty" yaml:"emit_empty_slices"` 20 | EmitExportedQueries bool `json:"emit_exported_queries" yaml:"emit_exported_queries"` 21 | EmitResultStructPointers bool `json:"emit_result_struct_pointers" yaml:"emit_result_struct_pointers"` 22 | EmitParamsStructPointers bool `json:"emit_params_struct_pointers" yaml:"emit_params_struct_pointers"` 23 | EmitMethodsWithDbArgument bool `json:"emit_methods_with_db_argument,omitempty" yaml:"emit_methods_with_db_argument"` 24 | EmitPointersForNullTypes bool `json:"emit_pointers_for_null_types" yaml:"emit_pointers_for_null_types"` 25 | EmitEnumValidMethod bool 
`json:"emit_enum_valid_method,omitempty" yaml:"emit_enum_valid_method"` 26 | EmitAllEnumValues bool `json:"emit_all_enum_values,omitempty" yaml:"emit_all_enum_values"` 27 | EmitSqlAsComment bool `json:"emit_sql_as_comment,omitempty" yaml:"emit_sql_as_comment"` 28 | JsonTagsCaseStyle string `json:"json_tags_case_style,omitempty" yaml:"json_tags_case_style"` 29 | Package string `json:"package" yaml:"package"` 30 | Out string `json:"out" yaml:"out"` 31 | Overrides []Override `json:"overrides,omitempty" yaml:"overrides"` 32 | Rename map[string]string `json:"rename,omitempty" yaml:"rename"` 33 | SqlPackage string `json:"sql_package" yaml:"sql_package"` 34 | SqlDriver string `json:"sql_driver" yaml:"sql_driver"` 35 | OutputBatchFileName string `json:"output_batch_file_name,omitempty" yaml:"output_batch_file_name"` 36 | OutputDbFileName string `json:"output_db_file_name,omitempty" yaml:"output_db_file_name"` 37 | OutputModelsFileName string `json:"output_models_file_name,omitempty" yaml:"output_models_file_name"` 38 | OutputQuerierFileName string `json:"output_querier_file_name,omitempty" yaml:"output_querier_file_name"` 39 | OutputCopyfromFileName string `json:"output_copyfrom_file_name,omitempty" yaml:"output_copyfrom_file_name"` 40 | OutputFilesSuffix string `json:"output_files_suffix,omitempty" yaml:"output_files_suffix"` 41 | InflectionExcludeTableNames []string `json:"inflection_exclude_table_names,omitempty" yaml:"inflection_exclude_table_names"` 42 | QueryParameterLimit *int32 `json:"query_parameter_limit,omitempty" yaml:"query_parameter_limit"` 43 | OmitSqlcVersion bool `json:"omit_sqlc_version,omitempty" yaml:"omit_sqlc_version"` 44 | OmitUnusedStructs bool `json:"omit_unused_structs,omitempty" yaml:"omit_unused_structs"` 45 | BuildTags string `json:"build_tags,omitempty" yaml:"build_tags"` 46 | Initialisms *[]string `json:"initialisms,omitempty" yaml:"initialisms"` 47 | 48 | InitialismsMap map[string]struct{} `json:"-" yaml:"-"` 49 | } 50 | 51 | type 
GlobalOptions struct { 52 | Overrides []Override `json:"overrides,omitempty" yaml:"overrides"` 53 | Rename map[string]string `json:"rename,omitempty" yaml:"rename"` 54 | } 55 | 56 | func Parse(req *plugin.GenerateRequest) (*Options, error) { 57 | options, err := parseOpts(req) 58 | if err != nil { 59 | return nil, err 60 | } 61 | global, err := parseGlobalOpts(req) 62 | if err != nil { 63 | return nil, err 64 | } 65 | if len(global.Overrides) > 0 { 66 | options.Overrides = append(global.Overrides, options.Overrides...) 67 | } 68 | if len(global.Rename) > 0 { 69 | if options.Rename == nil { 70 | options.Rename = map[string]string{} 71 | } 72 | maps.Copy(options.Rename, global.Rename) 73 | } 74 | return options, nil 75 | } 76 | 77 | func parseOpts(req *plugin.GenerateRequest) (*Options, error) { 78 | var options Options 79 | if len(req.PluginOptions) == 0 { 80 | return &options, nil 81 | } 82 | if err := json.Unmarshal(req.PluginOptions, &options); err != nil { 83 | return nil, fmt.Errorf("unmarshalling plugin options: %w", err) 84 | } 85 | 86 | if options.Package == "" { 87 | if options.Out != "" { 88 | options.Package = filepath.Base(options.Out) 89 | } else { 90 | return nil, fmt.Errorf("invalid options: missing package name") 91 | } 92 | } 93 | 94 | for i := range options.Overrides { 95 | if err := options.Overrides[i].parse(req); err != nil { 96 | return nil, err 97 | } 98 | } 99 | 100 | if options.SqlPackage != "" { 101 | if err := validatePackage(options.SqlPackage); err != nil { 102 | return nil, fmt.Errorf("invalid options: %s", err) 103 | } 104 | } 105 | 106 | if options.SqlDriver != "" { 107 | if err := validateDriver(options.SqlDriver); err != nil { 108 | return nil, fmt.Errorf("invalid options: %s", err) 109 | } 110 | } 111 | 112 | if options.QueryParameterLimit == nil { 113 | options.QueryParameterLimit = new(int32) 114 | *options.QueryParameterLimit = 1 115 | } 116 | 117 | if options.Initialisms == nil { 118 | options.Initialisms = new([]string) 119 | 
*options.Initialisms = []string{"id"} 120 | } 121 | 122 | options.InitialismsMap = map[string]struct{}{} 123 | for _, initial := range *options.Initialisms { 124 | options.InitialismsMap[initial] = struct{}{} 125 | } 126 | 127 | return &options, nil 128 | } 129 | 130 | func parseGlobalOpts(req *plugin.GenerateRequest) (*GlobalOptions, error) { 131 | var options GlobalOptions 132 | if len(req.GlobalOptions) == 0 { 133 | return &options, nil 134 | } 135 | if err := json.Unmarshal(req.GlobalOptions, &options); err != nil { 136 | return nil, fmt.Errorf("unmarshalling global options: %w", err) 137 | } 138 | for i := range options.Overrides { 139 | if err := options.Overrides[i].parse(req); err != nil { 140 | return nil, err 141 | } 142 | } 143 | return &options, nil 144 | } 145 | 146 | func ValidateOpts(opts *Options) error { 147 | if opts.EmitMethodsWithDbArgument && opts.EmitPreparedQueries { 148 | return fmt.Errorf("invalid options: emit_methods_with_db_argument and emit_prepared_queries options are mutually exclusive") 149 | } 150 | if *opts.QueryParameterLimit < 0 { 151 | return fmt.Errorf("invalid options: query parameter limit must not be negative") 152 | } 153 | 154 | return nil 155 | } 156 | -------------------------------------------------------------------------------- /internal/opts/override.go: -------------------------------------------------------------------------------- 1 | package opts 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "strings" 7 | 8 | "github.com/sqlc-dev/plugin-sdk-go/pattern" 9 | "github.com/sqlc-dev/plugin-sdk-go/plugin" 10 | ) 11 | 12 | type Override struct { 13 | // name of the golang type to use, e.g. `github.com/segmentio/ksuid.KSUID` 14 | GoType GoType `json:"go_type" yaml:"go_type"` 15 | 16 | // additional Go struct tags to add to this field, in raw Go struct tag form, e.g. 
`validate:"required" x:"y,z"` 17 | // see https://github.com/sqlc-dev/sqlc/issues/534 18 | GoStructTag GoStructTag `json:"go_struct_tag" yaml:"go_struct_tag"` 19 | 20 | // fully qualified name of the Go type, e.g. `github.com/segmentio/ksuid.KSUID` 21 | DBType string `json:"db_type" yaml:"db_type"` 22 | Deprecated_PostgresType string `json:"postgres_type" yaml:"postgres_type"` 23 | 24 | // for global overrides only when two different engines are in use 25 | Engine string `json:"engine,omitempty" yaml:"engine"` 26 | 27 | // True if the GoType should override if the matching type is nullable 28 | Nullable bool `json:"nullable" yaml:"nullable"` 29 | 30 | // True if the GoType should override if the matching type is unsiged. 31 | Unsigned bool `json:"unsigned" yaml:"unsigned"` 32 | 33 | // Deprecated. Use the `nullable` property instead 34 | Deprecated_Null bool `json:"null" yaml:"null"` 35 | 36 | // fully qualified name of the column, e.g. `accounts.id` 37 | Column string `json:"column" yaml:"column"` 38 | 39 | ColumnName *pattern.Match `json:"-"` 40 | TableCatalog *pattern.Match `json:"-"` 41 | TableSchema *pattern.Match `json:"-"` 42 | TableRel *pattern.Match `json:"-"` 43 | GoImportPath string `json:"-"` 44 | GoPackage string `json:"-"` 45 | GoTypeName string `json:"-"` 46 | GoBasicType bool `json:"-"` 47 | 48 | // Parsed form of GoStructTag, e.g. 
{"validate:", "required"} 49 | GoStructTags map[string]string `json:"-"` 50 | ShimOverride *ShimOverride `json:"-"` 51 | } 52 | 53 | func (o *Override) Matches(n *plugin.Identifier, defaultSchema string) bool { 54 | if n == nil { 55 | return false 56 | } 57 | schema := n.Schema 58 | if n.Schema == "" { 59 | schema = defaultSchema 60 | } 61 | if o.TableCatalog != nil && !o.TableCatalog.MatchString(n.Catalog) { 62 | return false 63 | } 64 | if o.TableSchema == nil && schema != "" { 65 | return false 66 | } 67 | if o.TableSchema != nil && !o.TableSchema.MatchString(schema) { 68 | return false 69 | } 70 | if o.TableRel == nil && n.Name != "" { 71 | return false 72 | } 73 | if o.TableRel != nil && !o.TableRel.MatchString(n.Name) { 74 | return false 75 | } 76 | return true 77 | } 78 | 79 | func (o *Override) parse(req *plugin.GenerateRequest) (err error) { 80 | // validate deprecated postgres_type field 81 | if o.Deprecated_PostgresType != "" { 82 | fmt.Fprintf(os.Stderr, "WARNING: \"postgres_type\" is deprecated. Instead, use \"db_type\" to specify a type override.\n") 83 | if o.DBType != "" { 84 | return fmt.Errorf(`Type override configurations cannot have "db_type" and "postres_type" together. Use "db_type" alone`) 85 | } 86 | o.DBType = o.Deprecated_PostgresType 87 | } 88 | 89 | // validate deprecated null field 90 | if o.Deprecated_Null { 91 | fmt.Fprintf(os.Stderr, "WARNING: \"null\" is deprecated. 
Instead, use the \"nullable\" field.\n") 92 | o.Nullable = true 93 | } 94 | 95 | schema := "public" 96 | if req != nil && req.Catalog != nil { 97 | schema = req.Catalog.DefaultSchema 98 | } 99 | 100 | // validate option combinations 101 | switch { 102 | case o.Column != "" && o.DBType != "": 103 | return fmt.Errorf("Override specifying both `column` (%q) and `db_type` (%q) is not valid.", o.Column, o.DBType) 104 | case o.Column == "" && o.DBType == "": 105 | return fmt.Errorf("Override must specify one of either `column` or `db_type`") 106 | } 107 | 108 | // validate Column 109 | if o.Column != "" { 110 | colParts := strings.Split(o.Column, ".") 111 | switch len(colParts) { 112 | case 2: 113 | if o.ColumnName, err = pattern.MatchCompile(colParts[1]); err != nil { 114 | return err 115 | } 116 | if o.TableRel, err = pattern.MatchCompile(colParts[0]); err != nil { 117 | return err 118 | } 119 | if o.TableSchema, err = pattern.MatchCompile(schema); err != nil { 120 | return err 121 | } 122 | case 3: 123 | if o.ColumnName, err = pattern.MatchCompile(colParts[2]); err != nil { 124 | return err 125 | } 126 | if o.TableRel, err = pattern.MatchCompile(colParts[1]); err != nil { 127 | return err 128 | } 129 | if o.TableSchema, err = pattern.MatchCompile(colParts[0]); err != nil { 130 | return err 131 | } 132 | case 4: 133 | if o.ColumnName, err = pattern.MatchCompile(colParts[3]); err != nil { 134 | return err 135 | } 136 | if o.TableRel, err = pattern.MatchCompile(colParts[2]); err != nil { 137 | return err 138 | } 139 | if o.TableSchema, err = pattern.MatchCompile(colParts[1]); err != nil { 140 | return err 141 | } 142 | if o.TableCatalog, err = pattern.MatchCompile(colParts[0]); err != nil { 143 | return err 144 | } 145 | default: 146 | return fmt.Errorf("Override `column` specifier %q is not the proper format, expected '[catalog.][schema.]tablename.colname'", o.Column) 147 | } 148 | } 149 | 150 | // validate GoType 151 | parsed, err := o.GoType.parse() 152 | if err != 
nil { 153 | return err 154 | } 155 | o.GoImportPath = parsed.ImportPath 156 | o.GoPackage = parsed.Package 157 | o.GoTypeName = parsed.TypeName 158 | o.GoBasicType = parsed.BasicType 159 | 160 | // validate GoStructTag 161 | tags, err := o.GoStructTag.parse() 162 | if err != nil { 163 | return err 164 | } 165 | o.GoStructTags = tags 166 | 167 | o.ShimOverride = shimOverride(req, o) 168 | return nil 169 | } 170 | -------------------------------------------------------------------------------- /internal/opts/override_test.go: -------------------------------------------------------------------------------- 1 | package opts 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/google/go-cmp/cmp" 7 | ) 8 | 9 | func TestTypeOverrides(t *testing.T) { 10 | for _, test := range []struct { 11 | override Override 12 | pkg string 13 | typeName string 14 | basic bool 15 | }{ 16 | { 17 | Override{ 18 | DBType: "uuid", 19 | GoType: GoType{Spec: "github.com/segmentio/ksuid.KSUID"}, 20 | }, 21 | "github.com/segmentio/ksuid", 22 | "ksuid.KSUID", 23 | false, 24 | }, 25 | // TODO: Add test for struct pointers 26 | // 27 | // { 28 | // Override{ 29 | // DBType: "uuid", 30 | // GoType: "github.com/segmentio/*ksuid.KSUID", 31 | // }, 32 | // "github.com/segmentio/ksuid", 33 | // "*ksuid.KSUID", 34 | // false, 35 | // }, 36 | { 37 | Override{ 38 | DBType: "citext", 39 | GoType: GoType{Spec: "string"}, 40 | }, 41 | "", 42 | "string", 43 | true, 44 | }, 45 | { 46 | Override{ 47 | DBType: "timestamp", 48 | GoType: GoType{Spec: "time.Time"}, 49 | }, 50 | "time", 51 | "time.Time", 52 | false, 53 | }, 54 | } { 55 | tt := test 56 | t.Run(tt.override.GoType.Spec, func(t *testing.T) { 57 | if err := tt.override.parse(nil); err != nil { 58 | t.Fatalf("override parsing failed; %s", err) 59 | } 60 | if diff := cmp.Diff(tt.pkg, tt.override.GoImportPath); diff != "" { 61 | t.Errorf("package mismatch;\n%s", diff) 62 | } 63 | if diff := cmp.Diff(tt.typeName, tt.override.GoTypeName); diff != "" { 64 | 
t.Errorf("type name mismatch;\n%s", diff) 65 | } 66 | if diff := cmp.Diff(tt.basic, tt.override.GoBasicType); diff != "" { 67 | t.Errorf("basic mismatch;\n%s", diff) 68 | } 69 | }) 70 | } 71 | for _, test := range []struct { 72 | override Override 73 | err string 74 | }{ 75 | { 76 | Override{ 77 | DBType: "uuid", 78 | GoType: GoType{Spec: "Pointer"}, 79 | }, 80 | "Package override `go_type` specifier \"Pointer\" is not a Go basic type e.g. 'string'", 81 | }, 82 | { 83 | Override{ 84 | DBType: "uuid", 85 | GoType: GoType{Spec: "untyped rune"}, 86 | }, 87 | "Package override `go_type` specifier \"untyped rune\" is not a Go basic type e.g. 'string'", 88 | }, 89 | } { 90 | tt := test 91 | t.Run(tt.override.GoType.Spec, func(t *testing.T) { 92 | err := tt.override.parse(nil) 93 | if err == nil { 94 | t.Fatalf("expected parse to fail; got nil") 95 | } 96 | if diff := cmp.Diff(tt.err, err.Error()); diff != "" { 97 | t.Errorf("error mismatch;\n%s", diff) 98 | } 99 | }) 100 | } 101 | } 102 | 103 | func FuzzOverride(f *testing.F) { 104 | for _, spec := range []string{ 105 | "string", 106 | "github.com/gofrs/uuid.UUID", 107 | "github.com/segmentio/ksuid.KSUID", 108 | } { 109 | f.Add(spec) 110 | } 111 | f.Fuzz(func(t *testing.T, s string) { 112 | o := Override{ 113 | GoType: GoType{Spec: s}, 114 | } 115 | o.parse(nil) 116 | }) 117 | } 118 | -------------------------------------------------------------------------------- /internal/opts/shim.go: -------------------------------------------------------------------------------- 1 | package opts 2 | 3 | import ( 4 | "strings" 5 | 6 | "github.com/sqlc-dev/plugin-sdk-go/plugin" 7 | ) 8 | 9 | // The ShimOverride struct exists to bridge the gap between the Override struct 10 | // and the previous Override struct defined in codegen.proto. 
Eventually these 11 | // shim structs should be removed in favor of using the existing Override and 12 | // GoType structs, but it's easier to provide these shim structs to not change 13 | // the existing, working code. 14 | type ShimOverride struct { 15 | DbType string 16 | Nullable bool 17 | Column string 18 | Table *plugin.Identifier 19 | ColumnName string 20 | Unsigned bool 21 | GoType *ShimGoType 22 | } 23 | 24 | func shimOverride(req *plugin.GenerateRequest, o *Override) *ShimOverride { 25 | var column string 26 | var table plugin.Identifier 27 | 28 | if o.Column != "" { 29 | colParts := strings.Split(o.Column, ".") 30 | switch len(colParts) { 31 | case 2: 32 | table.Schema = req.Catalog.DefaultSchema 33 | table.Name = colParts[0] 34 | column = colParts[1] 35 | case 3: 36 | table.Schema = colParts[0] 37 | table.Name = colParts[1] 38 | column = colParts[2] 39 | case 4: 40 | table.Catalog = colParts[0] 41 | table.Schema = colParts[1] 42 | table.Name = colParts[2] 43 | column = colParts[3] 44 | } 45 | } 46 | return &ShimOverride{ 47 | DbType: o.DBType, 48 | Nullable: o.Nullable, 49 | Unsigned: o.Unsigned, 50 | Column: o.Column, 51 | ColumnName: column, 52 | Table: &table, 53 | GoType: shimGoType(o), 54 | } 55 | } 56 | 57 | type ShimGoType struct { 58 | ImportPath string 59 | Package string 60 | TypeName string 61 | BasicType bool 62 | StructTags map[string]string 63 | } 64 | 65 | func shimGoType(o *Override) *ShimGoType { 66 | // Note that there is a slight mismatch between this and the 67 | // proto api. 
The GoType on the override is the unparsed type, 68 | // which could be a qualified path or an object, as per 69 | // https://docs.sqlc.dev/en/v1.18.0/reference/config.html#type-overriding 70 | return &ShimGoType{ 71 | ImportPath: o.GoImportPath, 72 | Package: o.GoPackage, 73 | TypeName: o.GoTypeName, 74 | BasicType: o.GoBasicType, 75 | StructTags: o.GoStructTags, 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /internal/postgresql_type.go: -------------------------------------------------------------------------------- 1 | package golang 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "strings" 7 | 8 | "github.com/sqlc-dev/sqlc-gen-go/internal/opts" 9 | "github.com/sqlc-dev/plugin-sdk-go/sdk" 10 | "github.com/sqlc-dev/sqlc-gen-go/internal/debug" 11 | "github.com/sqlc-dev/plugin-sdk-go/plugin" 12 | ) 13 | 14 | func parseIdentifierString(name string) (*plugin.Identifier, error) { 15 | parts := strings.Split(name, ".") 16 | switch len(parts) { 17 | case 1: 18 | return &plugin.Identifier{ 19 | Name: parts[0], 20 | }, nil 21 | case 2: 22 | return &plugin.Identifier{ 23 | Schema: parts[0], 24 | Name: parts[1], 25 | }, nil 26 | case 3: 27 | return &plugin.Identifier{ 28 | Catalog: parts[0], 29 | Schema: parts[1], 30 | Name: parts[2], 31 | }, nil 32 | default: 33 | return nil, fmt.Errorf("invalid name: %s", name) 34 | } 35 | } 36 | 37 | func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) string { 38 | columnType := sdk.DataType(col.Type) 39 | notNull := col.NotNull || col.IsArray 40 | driver := parseDriver(options.SqlPackage) 41 | emitPointersForNull := driver.IsPGX() && options.EmitPointersForNullTypes 42 | 43 | switch columnType { 44 | case "serial", "serial4", "pg_catalog.serial4": 45 | if notNull { 46 | return "int32" 47 | } 48 | if emitPointersForNull { 49 | return "*int32" 50 | } 51 | if driver == opts.SQLDriverPGXV5 { 52 | return "pgtype.Int4" 53 | } 54 | return "sql.NullInt32" 55 | 
56 | case "bigserial", "serial8", "pg_catalog.serial8": 57 | if notNull { 58 | return "int64" 59 | } 60 | if emitPointersForNull { 61 | return "*int64" 62 | } 63 | if driver == opts.SQLDriverPGXV5 { 64 | return "pgtype.Int8" 65 | } 66 | return "sql.NullInt64" 67 | 68 | case "smallserial", "serial2", "pg_catalog.serial2": 69 | if notNull { 70 | return "int16" 71 | } 72 | if emitPointersForNull { 73 | return "*int16" 74 | } 75 | if driver == opts.SQLDriverPGXV5 { 76 | return "pgtype.Int2" 77 | } 78 | return "sql.NullInt16" 79 | 80 | case "integer", "int", "int4", "pg_catalog.int4": 81 | if notNull { 82 | return "int32" 83 | } 84 | if emitPointersForNull { 85 | return "*int32" 86 | } 87 | if driver == opts.SQLDriverPGXV5 { 88 | return "pgtype.Int4" 89 | } 90 | return "sql.NullInt32" 91 | 92 | case "bigint", "int8", "pg_catalog.int8": 93 | if notNull { 94 | return "int64" 95 | } 96 | if emitPointersForNull { 97 | return "*int64" 98 | } 99 | if driver == opts.SQLDriverPGXV5 { 100 | return "pgtype.Int8" 101 | } 102 | return "sql.NullInt64" 103 | 104 | case "smallint", "int2", "pg_catalog.int2": 105 | if notNull { 106 | return "int16" 107 | } 108 | if emitPointersForNull { 109 | return "*int16" 110 | } 111 | if driver == opts.SQLDriverPGXV5 { 112 | return "pgtype.Int2" 113 | } 114 | return "sql.NullInt16" 115 | 116 | case "float", "double precision", "float8", "pg_catalog.float8": 117 | if notNull { 118 | return "float64" 119 | } 120 | if emitPointersForNull { 121 | return "*float64" 122 | } 123 | if driver == opts.SQLDriverPGXV5 { 124 | return "pgtype.Float8" 125 | } 126 | return "sql.NullFloat64" 127 | 128 | case "real", "float4", "pg_catalog.float4": 129 | if notNull { 130 | return "float32" 131 | } 132 | if emitPointersForNull { 133 | return "*float32" 134 | } 135 | if driver == opts.SQLDriverPGXV5 { 136 | return "pgtype.Float4" 137 | } 138 | return "sql.NullFloat64" // TODO: Change to sql.NullFloat32 after updating the go.mod file 139 | 140 | case "numeric", 
"pg_catalog.numeric", "money": 141 | if driver.IsPGX() { 142 | return "pgtype.Numeric" 143 | } 144 | // Since the Go standard library does not have a decimal type, lib/pq 145 | // returns numerics as strings. 146 | // 147 | // https://github.com/lib/pq/issues/648 148 | if notNull { 149 | return "string" 150 | } 151 | if emitPointersForNull { 152 | return "*string" 153 | } 154 | return "sql.NullString" 155 | 156 | case "boolean", "bool", "pg_catalog.bool": 157 | if notNull { 158 | return "bool" 159 | } 160 | if emitPointersForNull { 161 | return "*bool" 162 | } 163 | if driver == opts.SQLDriverPGXV5 { 164 | return "pgtype.Bool" 165 | } 166 | return "sql.NullBool" 167 | 168 | case "json": 169 | switch driver { 170 | case opts.SQLDriverPGXV5: 171 | return "[]byte" 172 | case opts.SQLDriverPGXV4: 173 | return "pgtype.JSON" 174 | case opts.SQLDriverLibPQ: 175 | if notNull { 176 | return "json.RawMessage" 177 | } else { 178 | return "pqtype.NullRawMessage" 179 | } 180 | default: 181 | return "interface{}" 182 | } 183 | 184 | case "jsonb": 185 | switch driver { 186 | case opts.SQLDriverPGXV5: 187 | return "[]byte" 188 | case opts.SQLDriverPGXV4: 189 | return "pgtype.JSONB" 190 | case opts.SQLDriverLibPQ: 191 | if notNull { 192 | return "json.RawMessage" 193 | } else { 194 | return "pqtype.NullRawMessage" 195 | } 196 | default: 197 | return "interface{}" 198 | } 199 | 200 | case "bytea", "blob", "pg_catalog.bytea": 201 | return "[]byte" 202 | 203 | case "date": 204 | if driver == opts.SQLDriverPGXV5 { 205 | return "pgtype.Date" 206 | } 207 | if notNull { 208 | return "time.Time" 209 | } 210 | if emitPointersForNull { 211 | return "*time.Time" 212 | } 213 | return "sql.NullTime" 214 | 215 | case "pg_catalog.time": 216 | if driver == opts.SQLDriverPGXV5 { 217 | return "pgtype.Time" 218 | } 219 | if notNull { 220 | return "time.Time" 221 | } 222 | if emitPointersForNull { 223 | return "*time.Time" 224 | } 225 | return "sql.NullTime" 226 | 227 | case "pg_catalog.timetz": 228 | 
if notNull { 229 | return "time.Time" 230 | } 231 | if emitPointersForNull { 232 | return "*time.Time" 233 | } 234 | return "sql.NullTime" 235 | 236 | case "pg_catalog.timestamp": 237 | if driver == opts.SQLDriverPGXV5 { 238 | return "pgtype.Timestamp" 239 | } 240 | if notNull { 241 | return "time.Time" 242 | } 243 | if emitPointersForNull { 244 | return "*time.Time" 245 | } 246 | return "sql.NullTime" 247 | 248 | case "pg_catalog.timestamptz", "timestamptz": 249 | if driver == opts.SQLDriverPGXV5 { 250 | return "pgtype.Timestamptz" 251 | } 252 | if notNull { 253 | return "time.Time" 254 | } 255 | if emitPointersForNull { 256 | return "*time.Time" 257 | } 258 | return "sql.NullTime" 259 | 260 | case "text", "pg_catalog.varchar", "pg_catalog.bpchar", "string", "citext", "name": 261 | if notNull { 262 | return "string" 263 | } 264 | if emitPointersForNull { 265 | return "*string" 266 | } 267 | if driver == opts.SQLDriverPGXV5 { 268 | return "pgtype.Text" 269 | } 270 | return "sql.NullString" 271 | 272 | case "uuid": 273 | if driver == opts.SQLDriverPGXV5 { 274 | return "pgtype.UUID" 275 | } 276 | if notNull { 277 | return "uuid.UUID" 278 | } 279 | if emitPointersForNull { 280 | return "*uuid.UUID" 281 | } 282 | return "uuid.NullUUID" 283 | 284 | case "inet": 285 | switch driver { 286 | case opts.SQLDriverPGXV5: 287 | if notNull { 288 | return "netip.Addr" 289 | } 290 | return "*netip.Addr" 291 | case opts.SQLDriverPGXV4: 292 | return "pgtype.Inet" 293 | case opts.SQLDriverLibPQ: 294 | return "pqtype.Inet" 295 | default: 296 | return "interface{}" 297 | } 298 | 299 | case "cidr": 300 | switch driver { 301 | case opts.SQLDriverPGXV5: 302 | if notNull { 303 | return "netip.Prefix" 304 | } 305 | return "*netip.Prefix" 306 | case opts.SQLDriverPGXV4: 307 | return "pgtype.CIDR" 308 | case opts.SQLDriverLibPQ: 309 | return "pqtype.CIDR" 310 | default: 311 | return "interface{}" 312 | } 313 | 314 | case "macaddr", "macaddr8": 315 | switch driver { 316 | case 
opts.SQLDriverPGXV5: 317 | return "net.HardwareAddr" 318 | case opts.SQLDriverPGXV4: 319 | return "pgtype.Macaddr" 320 | case opts.SQLDriverLibPQ: 321 | return "pqtype.Macaddr" 322 | default: 323 | return "interface{}" 324 | } 325 | 326 | case "ltree", "lquery", "ltxtquery": 327 | // This module implements a data type ltree for representing labels 328 | // of data stored in a hierarchical tree-like structure. Extensive 329 | // facilities for searching through label trees are provided. 330 | // 331 | // https://www.postgresql.org/docs/current/ltree.html 332 | if notNull { 333 | return "string" 334 | } 335 | if emitPointersForNull { 336 | return "*string" 337 | } 338 | if driver == opts.SQLDriverPGXV5 { 339 | return "pgtype.Text" 340 | } 341 | return "sql.NullString" 342 | 343 | case "interval", "pg_catalog.interval": 344 | if driver == opts.SQLDriverPGXV5 { 345 | return "pgtype.Interval" 346 | } 347 | if notNull { 348 | return "int64" 349 | } 350 | if emitPointersForNull { 351 | return "*int64" 352 | } 353 | return "sql.NullInt64" 354 | 355 | case "daterange": 356 | switch driver { 357 | case opts.SQLDriverPGXV4: 358 | return "pgtype.Daterange" 359 | case opts.SQLDriverPGXV5: 360 | return "pgtype.Range[pgtype.Date]" 361 | default: 362 | return "interface{}" 363 | } 364 | 365 | case "datemultirange": 366 | switch driver { 367 | case opts.SQLDriverPGXV5: 368 | return "pgtype.Multirange[pgtype.Range[pgtype.Date]]" 369 | default: 370 | return "interface{}" 371 | } 372 | 373 | case "tsrange": 374 | switch driver { 375 | case opts.SQLDriverPGXV4: 376 | return "pgtype.Tsrange" 377 | case opts.SQLDriverPGXV5: 378 | return "pgtype.Range[pgtype.Timestamp]" 379 | default: 380 | return "interface{}" 381 | } 382 | 383 | case "tsmultirange": 384 | switch driver { 385 | case opts.SQLDriverPGXV5: 386 | return "pgtype.Multirange[pgtype.Range[pgtype.Timestamp]]" 387 | default: 388 | return "interface{}" 389 | } 390 | 391 | case "tstzrange": 392 | switch driver { 393 | case 
opts.SQLDriverPGXV4: 394 | return "pgtype.Tstzrange" 395 | case opts.SQLDriverPGXV5: 396 | return "pgtype.Range[pgtype.Timestamptz]" 397 | default: 398 | return "interface{}" 399 | } 400 | 401 | case "tstzmultirange": 402 | switch driver { 403 | case opts.SQLDriverPGXV5: 404 | return "pgtype.Multirange[pgtype.Range[pgtype.Timestamptz]]" 405 | default: 406 | return "interface{}" 407 | } 408 | 409 | case "numrange": 410 | switch driver { 411 | case opts.SQLDriverPGXV4: 412 | return "pgtype.Numrange" 413 | case opts.SQLDriverPGXV5: 414 | return "pgtype.Range[pgtype.Numeric]" 415 | default: 416 | return "interface{}" 417 | } 418 | 419 | case "nummultirange": 420 | switch driver { 421 | case opts.SQLDriverPGXV5: 422 | return "pgtype.Multirange[pgtype.Range[pgtype.Numeric]]" 423 | default: 424 | return "interface{}" 425 | } 426 | 427 | case "int4range": 428 | switch driver { 429 | case opts.SQLDriverPGXV4: 430 | return "pgtype.Int4range" 431 | case opts.SQLDriverPGXV5: 432 | return "pgtype.Range[pgtype.Int4]" 433 | default: 434 | return "interface{}" 435 | } 436 | 437 | case "int4multirange": 438 | switch driver { 439 | case opts.SQLDriverPGXV5: 440 | return "pgtype.Multirange[pgtype.Range[pgtype.Int4]]" 441 | default: 442 | return "interface{}" 443 | } 444 | 445 | case "int8range": 446 | switch driver { 447 | case opts.SQLDriverPGXV4: 448 | return "pgtype.Int8range" 449 | case opts.SQLDriverPGXV5: 450 | return "pgtype.Range[pgtype.Int8]" 451 | default: 452 | return "interface{}" 453 | } 454 | 455 | case "int8multirange": 456 | switch driver { 457 | case opts.SQLDriverPGXV5: 458 | return "pgtype.Multirange[pgtype.Range[pgtype.Int8]]" 459 | default: 460 | return "interface{}" 461 | } 462 | 463 | case "hstore": 464 | if driver.IsPGX() { 465 | return "pgtype.Hstore" 466 | } 467 | return "interface{}" 468 | 469 | case "bit", "varbit", "pg_catalog.bit", "pg_catalog.varbit": 470 | if driver == opts.SQLDriverPGXV5 { 471 | return "pgtype.Bits" 472 | } 473 | if driver == 
opts.SQLDriverPGXV4 { 474 | return "pgtype.Varbit" 475 | } 476 | 477 | case "cid": 478 | if driver == opts.SQLDriverPGXV5 { 479 | return "pgtype.Uint32" 480 | } 481 | if driver == opts.SQLDriverPGXV4 { 482 | return "pgtype.CID" 483 | } 484 | 485 | case "oid": 486 | if driver == opts.SQLDriverPGXV5 { 487 | return "pgtype.Uint32" 488 | } 489 | if driver == opts.SQLDriverPGXV4 { 490 | return "pgtype.OID" 491 | } 492 | 493 | case "tid": 494 | if driver.IsPGX() { 495 | return "pgtype.TID" 496 | } 497 | 498 | case "xid": 499 | if driver == opts.SQLDriverPGXV5 { 500 | return "pgtype.Uint32" 501 | } 502 | if driver == opts.SQLDriverPGXV4 { 503 | return "pgtype.XID" 504 | } 505 | 506 | case "box": 507 | if driver.IsPGX() { 508 | return "pgtype.Box" 509 | } 510 | 511 | case "circle": 512 | if driver.IsPGX() { 513 | return "pgtype.Circle" 514 | } 515 | 516 | case "line": 517 | if driver.IsPGX() { 518 | return "pgtype.Line" 519 | } 520 | 521 | case "lseg": 522 | if driver.IsPGX() { 523 | return "pgtype.Lseg" 524 | } 525 | 526 | case "path": 527 | if driver.IsPGX() { 528 | return "pgtype.Path" 529 | } 530 | 531 | case "point": 532 | if driver.IsPGX() { 533 | return "pgtype.Point" 534 | } 535 | 536 | case "polygon": 537 | if driver.IsPGX() { 538 | return "pgtype.Polygon" 539 | } 540 | 541 | case "vector": 542 | if driver == opts.SQLDriverPGXV5 { 543 | if emitPointersForNull { 544 | return "*pgvector.Vector" 545 | } else { 546 | return "pgvector.Vector" 547 | } 548 | } 549 | 550 | case "void": 551 | // A void value can only be scanned into an empty interface. 552 | return "interface{}" 553 | 554 | case "any": 555 | return "interface{}" 556 | 557 | default: 558 | rel, err := parseIdentifierString(columnType) 559 | if err != nil { 560 | // TODO: Should this actually return an error here? 
// QueryValue describes a value that flows into or out of a generated query
// method. It is used for both the argument (Query.Arg) and the result
// (Query.Ret) of a query, and is either a single typed value (Typ) or a
// struct (Struct).
type QueryValue struct {
	// Emit is what EmitStruct reports. When it is false and the value is a
	// struct, Pairs flattens the struct into one argument per field.
	Emit bool
	// EmitPointer makes IsPointer report true for struct-backed values, which
	// in turn makes DefineType prefix the type with "*".
	EmitPointer bool
	Name        string
	DBName      string // The name of the field in the database. Only set if Struct==nil.
	Struct      *Struct
	// Typ is the Go type name. When empty, Type falls back to Struct.Name.
	Typ       string
	SQLDriver opts.SQLDriver

	// Column is kept around this late in the generation process so that
	// MySQL slices can be told apart from PostgreSQL arrays.
	Column *plugin.Column
}
57 | func (v QueryValue) Pairs() []Argument { 58 | if v.isEmpty() { 59 | return nil 60 | } 61 | if !v.EmitStruct() && v.IsStruct() { 62 | var out []Argument 63 | for _, f := range v.Struct.Fields { 64 | out = append(out, Argument{ 65 | Name: escape(toLowerCase(f.Name)), 66 | Type: f.Type, 67 | }) 68 | } 69 | return out 70 | } 71 | return []Argument{ 72 | { 73 | Name: escape(v.Name), 74 | Type: v.DefineType(), 75 | }, 76 | } 77 | } 78 | 79 | func (v QueryValue) SlicePair() string { 80 | if v.isEmpty() { 81 | return "" 82 | } 83 | return v.Name + " []" + v.DefineType() 84 | } 85 | 86 | func (v QueryValue) Type() string { 87 | if v.Typ != "" { 88 | return v.Typ 89 | } 90 | if v.Struct != nil { 91 | return v.Struct.Name 92 | } 93 | panic("no type for QueryValue: " + v.Name) 94 | } 95 | 96 | func (v *QueryValue) DefineType() string { 97 | t := v.Type() 98 | if v.IsPointer() { 99 | return "*" + t 100 | } 101 | return t 102 | } 103 | 104 | func (v *QueryValue) ReturnName() string { 105 | if v.IsPointer() { 106 | return "&" + escape(v.Name) 107 | } 108 | return escape(v.Name) 109 | } 110 | 111 | func (v QueryValue) UniqueFields() []Field { 112 | seen := map[string]struct{}{} 113 | fields := make([]Field, 0, len(v.Struct.Fields)) 114 | 115 | for _, field := range v.Struct.Fields { 116 | if _, found := seen[field.Name]; found { 117 | continue 118 | } 119 | seen[field.Name] = struct{}{} 120 | fields = append(fields, field) 121 | } 122 | 123 | return fields 124 | } 125 | 126 | func (v QueryValue) Params() string { 127 | if v.isEmpty() { 128 | return "" 129 | } 130 | var out []string 131 | if v.Struct == nil { 132 | if !v.Column.IsSqlcSlice && strings.HasPrefix(v.Typ, "[]") && v.Typ != "[]byte" && !v.SQLDriver.IsPGX() { 133 | out = append(out, "pq.Array("+escape(v.Name)+")") 134 | } else { 135 | out = append(out, escape(v.Name)) 136 | } 137 | } else { 138 | for _, f := range v.Struct.Fields { 139 | if !f.HasSqlcSlice() && strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && 
!v.SQLDriver.IsPGX() { 140 | out = append(out, "pq.Array("+escape(v.VariableForField(f))+")") 141 | } else { 142 | out = append(out, escape(v.VariableForField(f))) 143 | } 144 | } 145 | } 146 | if len(out) <= 3 { 147 | return strings.Join(out, ",") 148 | } 149 | out = append(out, "") 150 | return "\n" + strings.Join(out, ",\n") 151 | } 152 | 153 | func (v QueryValue) ColumnNames() []string { 154 | if v.Struct == nil { 155 | return []string{v.DBName} 156 | } 157 | names := make([]string, len(v.Struct.Fields)) 158 | for i, f := range v.Struct.Fields { 159 | names[i] = f.DBName 160 | } 161 | return names 162 | } 163 | 164 | func (v QueryValue) ColumnNamesAsGoSlice() string { 165 | if v.Struct == nil { 166 | return fmt.Sprintf("[]string{%q}", v.DBName) 167 | } 168 | escapedNames := make([]string, len(v.Struct.Fields)) 169 | for i, f := range v.Struct.Fields { 170 | if f.Column != nil && f.Column.OriginalName != "" { 171 | escapedNames[i] = fmt.Sprintf("%q", f.Column.OriginalName) 172 | } else { 173 | escapedNames[i] = fmt.Sprintf("%q", f.DBName) 174 | } 175 | } 176 | return "[]string{" + strings.Join(escapedNames, ", ") + "}" 177 | } 178 | 179 | // When true, we have to build the arguments to q.db.QueryContext in addition to 180 | // munging the SQL 181 | func (v QueryValue) HasSqlcSlices() bool { 182 | if v.Struct == nil { 183 | return v.Column != nil && v.Column.IsSqlcSlice 184 | } 185 | for _, v := range v.Struct.Fields { 186 | if v.Column.IsSqlcSlice { 187 | return true 188 | } 189 | } 190 | return false 191 | } 192 | 193 | func (v QueryValue) Scan() string { 194 | var out []string 195 | if v.Struct == nil { 196 | if strings.HasPrefix(v.Typ, "[]") && v.Typ != "[]byte" && !v.SQLDriver.IsPGX() { 197 | out = append(out, "pq.Array(&"+v.Name+")") 198 | } else { 199 | out = append(out, "&"+v.Name) 200 | } 201 | } else { 202 | for _, f := range v.Struct.Fields { 203 | 204 | // append any embedded fields 205 | if len(f.EmbedFields) > 0 { 206 | for _, embed := range 
f.EmbedFields { 207 | if strings.HasPrefix(embed.Type, "[]") && embed.Type != "[]byte" && !v.SQLDriver.IsPGX() { 208 | out = append(out, "pq.Array(&"+v.Name+"."+f.Name+"."+embed.Name+")") 209 | } else { 210 | out = append(out, "&"+v.Name+"."+f.Name+"."+embed.Name) 211 | } 212 | } 213 | continue 214 | } 215 | 216 | if strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && !v.SQLDriver.IsPGX() { 217 | out = append(out, "pq.Array(&"+v.Name+"."+f.Name+")") 218 | } else { 219 | out = append(out, "&"+v.Name+"."+f.Name) 220 | } 221 | } 222 | } 223 | if len(out) <= 3 { 224 | return strings.Join(out, ",") 225 | } 226 | out = append(out, "") 227 | return "\n" + strings.Join(out, ",\n") 228 | } 229 | 230 | // Deprecated: This method does not respect the Emit field set on the 231 | // QueryValue. It's used by the go-sql-driver-mysql/copyfromCopy.tmpl and should 232 | // not be used other places. 233 | func (v QueryValue) CopyFromMySQLFields() []Field { 234 | // fmt.Printf("%#v\n", v) 235 | if v.Struct != nil { 236 | return v.Struct.Fields 237 | } 238 | return []Field{ 239 | { 240 | Name: v.Name, 241 | DBName: v.DBName, 242 | Type: v.Typ, 243 | }, 244 | } 245 | } 246 | 247 | func (v QueryValue) VariableForField(f Field) string { 248 | if !v.IsStruct() { 249 | return v.Name 250 | } 251 | if !v.EmitStruct() { 252 | return toLowerCase(f.Name) 253 | } 254 | return v.Name + "." 
// IsReserved reports whether s must not be used verbatim as an identifier in
// generated code. That covers every Go keyword plus "q", the receiver name
// used by the generated Queries methods.
func IsReserved(s string) bool {
	switch s {
	case "break", "case", "chan", "const", "continue",
		"default", "defer", "else", "fallthrough", "for",
		"func", "go", "goto", "if", "import",
		"interface", "map", "package", "range", "return",
		"select", "struct", "switch", "type", "var",
		"q":
		return true
	}
	return false
}
options) 39 | } 40 | 41 | seen := make(map[string]struct{}, len(enum.Vals)) 42 | for i, v := range enum.Vals { 43 | value := EnumReplace(v) 44 | if _, found := seen[value]; found || value == "" { 45 | value = fmt.Sprintf("value_%d", i) 46 | } 47 | e.Constants = append(e.Constants, Constant{ 48 | Name: StructName(enumName+"_"+value, options), 49 | Value: v, 50 | Type: e.Name, 51 | }) 52 | seen[value] = struct{}{} 53 | } 54 | enums = append(enums, e) 55 | } 56 | } 57 | if len(enums) > 0 { 58 | sort.Slice(enums, func(i, j int) bool { return enums[i].Name < enums[j].Name }) 59 | } 60 | return enums 61 | } 62 | 63 | func buildStructs(req *plugin.GenerateRequest, options *opts.Options) []Struct { 64 | var structs []Struct 65 | for _, schema := range req.Catalog.Schemas { 66 | if schema.Name == "pg_catalog" || schema.Name == "information_schema" { 67 | continue 68 | } 69 | for _, table := range schema.Tables { 70 | var tableName string 71 | if schema.Name == req.Catalog.DefaultSchema { 72 | tableName = table.Rel.Name 73 | } else { 74 | tableName = schema.Name + "_" + table.Rel.Name 75 | } 76 | structName := tableName 77 | if !options.EmitExactTableNames { 78 | structName = inflection.Singular(inflection.SingularParams{ 79 | Name: structName, 80 | Exclusions: options.InflectionExcludeTableNames, 81 | }) 82 | } 83 | s := Struct{ 84 | Table: &plugin.Identifier{Schema: schema.Name, Name: table.Rel.Name}, 85 | Name: StructName(structName, options), 86 | Comment: table.Comment, 87 | } 88 | for _, column := range table.Columns { 89 | tags := map[string]string{} 90 | if options.EmitDbTags { 91 | tags["db"] = column.Name 92 | } 93 | if options.EmitJsonTags { 94 | tags["json"] = JSONTagName(column.Name, options) 95 | } 96 | addExtraGoStructTags(tags, req, options, column) 97 | s.Fields = append(s.Fields, Field{ 98 | Name: StructName(column.Name, options), 99 | Type: goType(req, options, column), 100 | Tags: tags, 101 | Comment: column.Comment, 102 | }) 103 | } 104 | structs = 
// argName converts a snake_case column name into a lowerCamelCase Go argument
// name. The first underscore-separated part is lower-cased, a later part that
// is exactly "id" becomes "ID", and every other part is title-cased.
func argName(name string) string {
	var out strings.Builder
	for i, p := range strings.Split(name, "_") {
		switch {
		case i == 0:
			out.WriteString(strings.ToLower(p))
		case p == "id":
			out.WriteString("ID")
		default:
			// strings.Title is kept on purpose so generated names stay
			// byte-for-byte identical to what was emitted before.
			out.WriteString(strings.Title(p))
		}
	}
	return out.String()
}
*opts.Options, structs []Struct) ([]Query, error) { 187 | qs := make([]Query, 0, len(req.Queries)) 188 | for _, query := range req.Queries { 189 | if query.Name == "" { 190 | continue 191 | } 192 | if query.Cmd == "" { 193 | continue 194 | } 195 | 196 | var constantName string 197 | if options.EmitExportedQueries { 198 | constantName = sdk.Title(query.Name) 199 | } else { 200 | constantName = sdk.LowerTitle(query.Name) 201 | } 202 | 203 | comments := query.Comments 204 | if options.EmitSqlAsComment { 205 | if len(comments) == 0 { 206 | comments = append(comments, query.Name) 207 | } 208 | comments = append(comments, " ") 209 | scanner := bufio.NewScanner(strings.NewReader(query.Text)) 210 | for scanner.Scan() { 211 | line := scanner.Text() 212 | comments = append(comments, " "+line) 213 | } 214 | if err := scanner.Err(); err != nil { 215 | return nil, err 216 | } 217 | } 218 | 219 | gq := Query{ 220 | Cmd: query.Cmd, 221 | ConstantName: constantName, 222 | FieldName: sdk.LowerTitle(query.Name) + "Stmt", 223 | MethodName: query.Name, 224 | SourceName: query.Filename, 225 | SQL: query.Text, 226 | Comments: comments, 227 | Table: query.InsertIntoTable, 228 | } 229 | sqlpkg := parseDriver(options.SqlPackage) 230 | 231 | qpl := int(*options.QueryParameterLimit) 232 | 233 | if len(query.Params) == 1 && qpl != 0 { 234 | p := query.Params[0] 235 | gq.Arg = QueryValue{ 236 | Name: escape(paramName(p)), 237 | DBName: p.Column.GetName(), 238 | Typ: goType(req, options, p.Column), 239 | SQLDriver: sqlpkg, 240 | Column: p.Column, 241 | } 242 | } else if len(query.Params) >= 1 { 243 | var cols []goColumn 244 | for _, p := range query.Params { 245 | cols = append(cols, goColumn{ 246 | id: int(p.Number), 247 | Column: p.Column, 248 | }) 249 | } 250 | s, err := columnsToStruct(req, options, gq.MethodName+"Params", cols, false) 251 | if err != nil { 252 | return nil, err 253 | } 254 | gq.Arg = QueryValue{ 255 | Emit: true, 256 | Name: "arg", 257 | Struct: s, 258 | SQLDriver: sqlpkg, 
259 | EmitPointer: options.EmitParamsStructPointers, 260 | } 261 | 262 | // if query params is 2, and query params limit is 4 AND this is a copyfrom, we still want to emit the query's model 263 | // otherwise we end up with a copyfrom using a struct without the struct definition 264 | if len(query.Params) <= qpl && query.Cmd != ":copyfrom" { 265 | gq.Arg.Emit = false 266 | } 267 | } 268 | 269 | if len(query.Columns) == 1 && query.Columns[0].EmbedTable == nil { 270 | c := query.Columns[0] 271 | name := columnName(c, 0) 272 | name = strings.Replace(name, "$", "_", -1) 273 | gq.Ret = QueryValue{ 274 | Name: escape(name), 275 | DBName: name, 276 | Typ: goType(req, options, c), 277 | SQLDriver: sqlpkg, 278 | } 279 | } else if putOutColumns(query) { 280 | var gs *Struct 281 | var emit bool 282 | 283 | for _, s := range structs { 284 | if len(s.Fields) != len(query.Columns) { 285 | continue 286 | } 287 | same := true 288 | for i, f := range s.Fields { 289 | c := query.Columns[i] 290 | sameName := f.Name == StructName(columnName(c, i), options) 291 | sameType := f.Type == goType(req, options, c) 292 | sameTable := sdk.SameTableName(c.Table, s.Table, req.Catalog.DefaultSchema) 293 | if !sameName || !sameType || !sameTable { 294 | same = false 295 | } 296 | } 297 | if same { 298 | gs = &s 299 | break 300 | } 301 | } 302 | 303 | if gs == nil { 304 | var columns []goColumn 305 | for i, c := range query.Columns { 306 | columns = append(columns, goColumn{ 307 | id: i, 308 | Column: c, 309 | embed: newGoEmbed(c.EmbedTable, structs, req.Catalog.DefaultSchema), 310 | }) 311 | } 312 | var err error 313 | gs, err = columnsToStruct(req, options, gq.MethodName+"Row", columns, true) 314 | if err != nil { 315 | return nil, err 316 | } 317 | emit = true 318 | } 319 | gq.Ret = QueryValue{ 320 | Emit: emit, 321 | Name: "i", 322 | Struct: gs, 323 | SQLDriver: sqlpkg, 324 | EmitPointer: options.EmitResultStructPointers, 325 | } 326 | } 327 | 328 | qs = append(qs, gq) 329 | } 330 | 
sort.Slice(qs, func(i, j int) bool { return qs[i].MethodName < qs[j].MethodName }) 331 | return qs, nil 332 | } 333 | 334 | var cmdReturnsData = map[string]struct{}{ 335 | metadata.CmdBatchMany: {}, 336 | metadata.CmdBatchOne: {}, 337 | metadata.CmdMany: {}, 338 | metadata.CmdOne: {}, 339 | } 340 | 341 | func putOutColumns(query *plugin.Query) bool { 342 | _, found := cmdReturnsData[query.Cmd] 343 | return found 344 | } 345 | 346 | // It's possible that this method will generate duplicate JSON tag values 347 | // 348 | // Columns: count, count, count_2 349 | // Fields: Count, Count_2, Count2 350 | // 351 | // JSON tags: count, count_2, count_2 352 | // 353 | // This is unlikely to happen, so don't fix it yet 354 | func columnsToStruct(req *plugin.GenerateRequest, options *opts.Options, name string, columns []goColumn, useID bool) (*Struct, error) { 355 | gs := Struct{ 356 | Name: name, 357 | } 358 | seen := map[string][]int{} 359 | suffixes := map[int]int{} 360 | for i, c := range columns { 361 | colName := columnName(c.Column, i) 362 | tagName := colName 363 | 364 | // override col/tag with expected model name 365 | if c.embed != nil { 366 | colName = c.embed.modelName 367 | tagName = SetCaseStyle(colName, "snake") 368 | } 369 | 370 | fieldName := StructName(colName, options) 371 | baseFieldName := fieldName 372 | // Track suffixes by the ID of the column, so that columns referring to the same numbered parameter can be 373 | // reused. 
374 | suffix := 0 375 | if o, ok := suffixes[c.id]; ok && useID { 376 | suffix = o 377 | } else if v := len(seen[fieldName]); v > 0 && !c.IsNamedParam { 378 | suffix = v + 1 379 | } 380 | suffixes[c.id] = suffix 381 | if suffix > 0 { 382 | tagName = fmt.Sprintf("%s_%d", tagName, suffix) 383 | fieldName = fmt.Sprintf("%s_%d", fieldName, suffix) 384 | } 385 | tags := map[string]string{} 386 | if options.EmitDbTags { 387 | tags["db"] = tagName 388 | } 389 | if options.EmitJsonTags { 390 | tags["json"] = JSONTagName(tagName, options) 391 | } 392 | addExtraGoStructTags(tags, req, options, c.Column) 393 | f := Field{ 394 | Name: fieldName, 395 | DBName: colName, 396 | Tags: tags, 397 | Column: c.Column, 398 | } 399 | if c.embed == nil { 400 | f.Type = goType(req, options, c.Column) 401 | } else { 402 | f.Type = c.embed.modelType 403 | f.EmbedFields = c.embed.fields 404 | } 405 | 406 | gs.Fields = append(gs.Fields, f) 407 | if _, found := seen[baseFieldName]; !found { 408 | seen[baseFieldName] = []int{i} 409 | } else { 410 | seen[baseFieldName] = append(seen[baseFieldName], i) 411 | } 412 | } 413 | 414 | // If a field does not have a known type, but another 415 | // field with the same name has a known type, assign 416 | // the known type to the field without a known type 417 | for i, field := range gs.Fields { 418 | if len(seen[field.Name]) > 1 && field.Type == "interface{}" { 419 | for _, j := range seen[field.Name] { 420 | if i == j { 421 | continue 422 | } 423 | otherField := gs.Fields[j] 424 | if otherField.Type != field.Type { 425 | field.Type = otherField.Type 426 | } 427 | gs.Fields[i] = field 428 | } 429 | } 430 | } 431 | 432 | err := checkIncompatibleFieldTypes(gs.Fields) 433 | if err != nil { 434 | return nil, err 435 | } 436 | 437 | return &gs, nil 438 | } 439 | 440 | func checkIncompatibleFieldTypes(fields []Field) error { 441 | fieldTypes := map[string]string{} 442 | for _, field := range fields { 443 | if fieldType, found := fieldTypes[field.Name]; !found { 
444 | fieldTypes[field.Name] = field.Type 445 | } else if field.Type != fieldType { 446 | return fmt.Errorf("named param %s has incompatible types: %s, %s", field.Name, field.Type, fieldType) 447 | } 448 | } 449 | return nil 450 | } 451 | -------------------------------------------------------------------------------- /internal/result_test.go: -------------------------------------------------------------------------------- 1 | package golang 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/sqlc-dev/plugin-sdk-go/metadata" 7 | "github.com/sqlc-dev/plugin-sdk-go/plugin" 8 | ) 9 | 10 | func TestPutOutColumns_ForZeroColumns(t *testing.T) { 11 | tests := []struct { 12 | cmd string 13 | want bool 14 | }{ 15 | { 16 | cmd: metadata.CmdExec, 17 | want: false, 18 | }, 19 | { 20 | cmd: metadata.CmdExecResult, 21 | want: false, 22 | }, 23 | { 24 | cmd: metadata.CmdExecRows, 25 | want: false, 26 | }, 27 | { 28 | cmd: metadata.CmdExecLastId, 29 | want: false, 30 | }, 31 | { 32 | cmd: metadata.CmdMany, 33 | want: true, 34 | }, 35 | { 36 | cmd: metadata.CmdOne, 37 | want: true, 38 | }, 39 | { 40 | cmd: metadata.CmdCopyFrom, 41 | want: false, 42 | }, 43 | { 44 | cmd: metadata.CmdBatchExec, 45 | want: false, 46 | }, 47 | { 48 | cmd: metadata.CmdBatchMany, 49 | want: true, 50 | }, 51 | { 52 | cmd: metadata.CmdBatchOne, 53 | want: true, 54 | }, 55 | } 56 | for _, tc := range tests { 57 | t.Run(tc.cmd, func(t *testing.T) { 58 | query := &plugin.Query{ 59 | Cmd: tc.cmd, 60 | Columns: []*plugin.Column{}, 61 | } 62 | got := putOutColumns(query) 63 | if got != tc.want { 64 | t.Errorf("putOutColumns failed. 
want %v, got %v", tc.want, got) 65 | } 66 | }) 67 | } 68 | } 69 | 70 | func TestPutOutColumns_AlwaysTrueWhenQueryHasColumns(t *testing.T) { 71 | query := &plugin.Query{ 72 | Cmd: metadata.CmdMany, 73 | Columns: []*plugin.Column{{}}, 74 | } 75 | if putOutColumns(query) != true { 76 | t.Error("should be true when we have columns") 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /internal/sqlite_type.go: -------------------------------------------------------------------------------- 1 | package golang 2 | 3 | import ( 4 | "log" 5 | "strings" 6 | 7 | "github.com/sqlc-dev/sqlc-gen-go/internal/opts" 8 | "github.com/sqlc-dev/plugin-sdk-go/sdk" 9 | "github.com/sqlc-dev/sqlc-gen-go/internal/debug" 10 | "github.com/sqlc-dev/plugin-sdk-go/plugin" 11 | ) 12 | 13 | func sqliteType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) string { 14 | dt := strings.ToLower(sdk.DataType(col.Type)) 15 | notNull := col.NotNull || col.IsArray 16 | emitPointersForNull := options.EmitPointersForNullTypes 17 | 18 | switch dt { 19 | 20 | case "int", "integer", "tinyint", "smallint", "mediumint", "bigint", "unsignedbigint", "int2", "int8": 21 | if notNull { 22 | return "int64" 23 | } 24 | if emitPointersForNull { 25 | return "*int64" 26 | } 27 | return "sql.NullInt64" 28 | 29 | case "blob": 30 | return "[]byte" 31 | 32 | case "real", "double", "doubleprecision", "float": 33 | if notNull { 34 | return "float64" 35 | } 36 | if emitPointersForNull { 37 | return "*float64" 38 | } 39 | return "sql.NullFloat64" 40 | 41 | case "boolean", "bool": 42 | if notNull { 43 | return "bool" 44 | } 45 | if emitPointersForNull { 46 | return "*bool" 47 | } 48 | return "sql.NullBool" 49 | 50 | case "date", "datetime", "timestamp": 51 | if notNull { 52 | return "time.Time" 53 | } 54 | if emitPointersForNull { 55 | return "*time.Time" 56 | } 57 | return "sql.NullTime" 58 | 59 | case "any": 60 | return "interface{}" 61 | 62 | } 63 | 64 | 
switch { 65 | 66 | case strings.HasPrefix(dt, "character"), 67 | strings.HasPrefix(dt, "varchar"), 68 | strings.HasPrefix(dt, "varyingcharacter"), 69 | strings.HasPrefix(dt, "nchar"), 70 | strings.HasPrefix(dt, "nativecharacter"), 71 | strings.HasPrefix(dt, "nvarchar"), 72 | dt == "text", 73 | dt == "clob": 74 | if notNull { 75 | return "string" 76 | } 77 | if emitPointersForNull { 78 | return "*string" 79 | } 80 | return "sql.NullString" 81 | 82 | case strings.HasPrefix(dt, "decimal"), dt == "numeric": 83 | if notNull { 84 | return "float64" 85 | } 86 | if emitPointersForNull { 87 | return "*float64" 88 | } 89 | return "sql.NullFloat64" 90 | 91 | default: 92 | if debug.Active { 93 | log.Printf("unknown SQLite type: %s\n", dt) 94 | } 95 | 96 | return "interface{}" 97 | 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /internal/struct.go: -------------------------------------------------------------------------------- 1 | package golang 2 | 3 | import ( 4 | "strings" 5 | "unicode" 6 | "unicode/utf8" 7 | 8 | "github.com/sqlc-dev/sqlc-gen-go/internal/opts" 9 | "github.com/sqlc-dev/plugin-sdk-go/plugin" 10 | ) 11 | 12 | type Struct struct { 13 | Table *plugin.Identifier 14 | Name string 15 | Fields []Field 16 | Comment string 17 | } 18 | 19 | func StructName(name string, options *opts.Options) string { 20 | if rename := options.Rename[name]; rename != "" { 21 | return rename 22 | } 23 | out := "" 24 | name = strings.Map(func(r rune) rune { 25 | if unicode.IsLetter(r) { 26 | return r 27 | } 28 | if unicode.IsDigit(r) { 29 | return r 30 | } 31 | return rune('_') 32 | }, name) 33 | 34 | for _, p := range strings.Split(name, "_") { 35 | if _, found := options.InitialismsMap[p]; found { 36 | out += strings.ToUpper(p) 37 | } else { 38 | out += strings.Title(p) 39 | } 40 | } 41 | 42 | // If a name has a digit as its first char, prepand an underscore to make it a valid Go name. 
43 | r, _ := utf8.DecodeRuneInString(out) 44 | if unicode.IsDigit(r) { 45 | return "_" + out 46 | } else { 47 | return out 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /internal/template.go: -------------------------------------------------------------------------------- 1 | package golang 2 | 3 | import "embed" 4 | 5 | //go:embed templates/* 6 | //go:embed templates/*/* 7 | var templates embed.FS 8 | -------------------------------------------------------------------------------- /internal/templates/go-sql-driver-mysql/copyfromCopy.tmpl: -------------------------------------------------------------------------------- 1 | {{define "copyfromCodeGoSqlDriver"}} 2 | {{range .GoQueries}} 3 | {{if eq .Cmd ":copyfrom" }} 4 | var readerHandlerSequenceFor{{.MethodName}} uint32 = 1 5 | 6 | func convertRowsFor{{.MethodName}}(w *io.PipeWriter, {{.Arg.SlicePair}}) { 7 | e := mysqltsv.NewEncoder(w, {{ len .Arg.CopyFromMySQLFields }}, nil) 8 | for _, row := range {{.Arg.Name}} { 9 | {{- with $arg := .Arg }} 10 | {{- range $arg.CopyFromMySQLFields}} 11 | {{- if eq .Type "string"}} 12 | e.AppendString({{if $arg.Struct}}row.{{.Name}}{{else}}row{{end}}) 13 | {{- else if or (eq .Type "[]byte") (eq .Type "json.RawMessage")}} 14 | e.AppendBytes({{if $arg.Struct}}row.{{.Name}}{{else}}row{{end}}) 15 | {{- else}} 16 | e.AppendValue({{if $arg.Struct}}row.{{.Name}}{{else}}row{{end}}) 17 | {{- end}} 18 | {{- end}} 19 | {{- end}} 20 | } 21 | w.CloseWithError(e.Close()) 22 | } 23 | 24 | {{range .Comments}}//{{.}} 25 | {{end -}} 26 | // {{.MethodName}} uses MySQL's LOAD DATA LOCAL INFILE and is not atomic. 27 | // 28 | // Errors and duplicate keys are treated as warnings and insertion will 29 | // continue, even without an error for some cases. Use this in a transaction 30 | // and use SHOW WARNINGS to check for any problems and roll back if you want to. 
31 | // 32 | // Check the documentation for more information: 33 | // https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-error-handling 34 | func (q *Queries) {{.MethodName}}(ctx context.Context{{if $.EmitMethodsWithDBArgument}}, db DBTX{{end}}, {{.Arg.SlicePair}}) (int64, error) { 35 | pr, pw := io.Pipe() 36 | defer pr.Close() 37 | rh := fmt.Sprintf("{{.MethodName}}_%d", atomic.AddUint32(&readerHandlerSequenceFor{{.MethodName}}, 1)) 38 | mysql.RegisterReaderHandler(rh, func() io.Reader { return pr }) 39 | defer mysql.DeregisterReaderHandler(rh) 40 | go convertRowsFor{{.MethodName}}(pw, {{.Arg.Name}}) 41 | // The string interpolation is necessary because LOAD DATA INFILE requires 42 | // the file name to be given as a literal string. 43 | result, err := {{if (not $.EmitMethodsWithDBArgument)}}q.{{end}}db.ExecContext(ctx, fmt.Sprintf("LOAD DATA LOCAL INFILE '%s' INTO TABLE {{.TableIdentifierForMySQL}} %s ({{range $index, $name := .Arg.ColumnNames}}{{if gt $index 0}}, {{end}}{{$name}}{{end}})", "Reader::" + rh, mysqltsv.Escaping)) 44 | if err != nil { 45 | return 0, err 46 | } 47 | return result.RowsAffected() 48 | } 49 | 50 | {{end}} 51 | {{end}} 52 | {{end}} 53 | -------------------------------------------------------------------------------- /internal/templates/pgx/batchCode.tmpl: -------------------------------------------------------------------------------- 1 | {{define "batchCodePgx"}} 2 | 3 | var ( 4 | ErrBatchAlreadyClosed = errors.New("batch already closed") 5 | ) 6 | 7 | {{range .GoQueries}} 8 | {{if eq (hasPrefix .Cmd ":batch") true }} 9 | const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}} 10 | {{escape .SQL}} 11 | {{$.Q}} 12 | 13 | type {{.MethodName}}BatchResults struct { 14 | br pgx.BatchResults 15 | tot int 16 | closed bool 17 | } 18 | 19 | {{if .Arg.Struct}} 20 | type {{.Arg.Type}} struct { {{- range .Arg.Struct.Fields}} 21 | {{.Name}} {{.Type}} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}} 22 | {{- end}} 23 | } 24 | {{end}} 
25 | 26 | {{if .Ret.EmitStruct}} 27 | type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}} 28 | {{.Name}} {{.Type}} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}} 29 | {{- end}} 30 | } 31 | {{end}} 32 | 33 | {{range .Comments}}//{{.}} 34 | {{end -}} 35 | func (q *Queries) {{.MethodName}}(ctx context.Context, {{ if $.EmitMethodsWithDBArgument}}db DBTX,{{end}} {{.Arg.SlicePair}}) *{{.MethodName}}BatchResults { 36 | batch := &pgx.Batch{} 37 | for _, a := range {{index .Arg.Name}} { 38 | vals := []interface{}{ 39 | {{- if .Arg.Struct }} 40 | {{- range .Arg.Struct.Fields }} 41 | a.{{.Name}}, 42 | {{- end }} 43 | {{- else }} 44 | a, 45 | {{- end }} 46 | } 47 | batch.Queue({{.ConstantName}}, vals...) 48 | } 49 | br := {{if not $.EmitMethodsWithDBArgument}}q.{{end}}db.SendBatch(ctx, batch) 50 | return &{{.MethodName}}BatchResults{br,len({{.Arg.Name}}),false} 51 | } 52 | 53 | {{if eq .Cmd ":batchexec"}} 54 | func (b *{{.MethodName}}BatchResults) Exec(f func(int, error)) { 55 | defer b.br.Close() 56 | for t := 0; t < b.tot; t++ { 57 | if b.closed { 58 | if f != nil { 59 | f(t, ErrBatchAlreadyClosed) 60 | } 61 | continue 62 | } 63 | _, err := b.br.Exec() 64 | if f != nil { 65 | f(t, err) 66 | } 67 | } 68 | } 69 | {{end}} 70 | 71 | {{if eq .Cmd ":batchmany"}} 72 | func (b *{{.MethodName}}BatchResults) Query(f func(int, []{{.Ret.DefineType}}, error)) { 73 | defer b.br.Close() 74 | for t := 0; t < b.tot; t++ { 75 | {{- if $.EmitEmptySlices}} 76 | items := []{{.Ret.DefineType}}{} 77 | {{else}} 78 | var items []{{.Ret.DefineType}} 79 | {{end -}} 80 | if b.closed { 81 | if f != nil { 82 | f(t, items, ErrBatchAlreadyClosed) 83 | } 84 | continue 85 | } 86 | err := func() error { 87 | rows, err := b.br.Query() 88 | if err != nil { 89 | return err 90 | } 91 | defer rows.Close() 92 | for rows.Next() { 93 | var {{.Ret.Name}} {{.Ret.Type}} 94 | if err := rows.Scan({{.Ret.Scan}}); err != nil { 95 | return err 96 | } 97 | items = append(items, {{.Ret.ReturnName}}) 98 | } 99 | return 
rows.Err() 100 | }() 101 | if f != nil { 102 | f(t, items, err) 103 | } 104 | } 105 | } 106 | {{end}} 107 | 108 | {{if eq .Cmd ":batchone"}} 109 | func (b *{{.MethodName}}BatchResults) QueryRow(f func(int, {{.Ret.DefineType}}, error)) { 110 | defer b.br.Close() 111 | for t := 0; t < b.tot; t++ { 112 | var {{.Ret.Name}} {{.Ret.Type}} 113 | if b.closed { 114 | if f != nil { 115 | f(t, {{if .Ret.IsPointer}}nil{{else}}{{.Ret.Name}}{{end}}, ErrBatchAlreadyClosed) 116 | } 117 | continue 118 | } 119 | row := b.br.QueryRow() 120 | err := row.Scan({{.Ret.Scan}}) 121 | if f != nil { 122 | f(t, {{.Ret.ReturnName}}, err) 123 | } 124 | } 125 | } 126 | {{end}} 127 | 128 | func (b *{{.MethodName}}BatchResults) Close() error { 129 | b.closed = true 130 | return b.br.Close() 131 | } 132 | {{end}} 133 | {{end}} 134 | {{end}} 135 | -------------------------------------------------------------------------------- /internal/templates/pgx/copyfromCopy.tmpl: -------------------------------------------------------------------------------- 1 | {{define "copyfromCodePgx"}} 2 | {{range .GoQueries}} 3 | {{if eq .Cmd ":copyfrom" }} 4 | // iteratorFor{{.MethodName}} implements pgx.CopyFromSource. 
5 | type iteratorFor{{.MethodName}} struct { 6 | rows []{{.Arg.DefineType}} 7 | skippedFirstNextCall bool 8 | } 9 | 10 | func (r *iteratorFor{{.MethodName}}) Next() bool { 11 | if len(r.rows) == 0 { 12 | return false 13 | } 14 | if !r.skippedFirstNextCall { 15 | r.skippedFirstNextCall = true 16 | return true 17 | } 18 | r.rows = r.rows[1:] 19 | return len(r.rows) > 0 20 | } 21 | 22 | func (r iteratorFor{{.MethodName}}) Values() ([]interface{}, error) { 23 | return []interface{}{ 24 | {{- if .Arg.Struct }} 25 | {{- range .Arg.Struct.Fields }} 26 | r.rows[0].{{.Name}}, 27 | {{- end }} 28 | {{- else }} 29 | r.rows[0], 30 | {{- end }} 31 | }, nil 32 | } 33 | 34 | func (r iteratorFor{{.MethodName}}) Err() error { 35 | return nil 36 | } 37 | 38 | {{range .Comments}}//{{.}} 39 | {{end -}} 40 | {{- if $.EmitMethodsWithDBArgument -}} 41 | func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.SlicePair}}) (int64, error) { 42 | return db.CopyFrom(ctx, {{.TableIdentifierAsGoSlice}}, {{.Arg.ColumnNamesAsGoSlice}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}}) 43 | {{- else -}} 44 | func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.SlicePair}}) (int64, error) { 45 | return q.db.CopyFrom(ctx, {{.TableIdentifierAsGoSlice}}, {{.Arg.ColumnNamesAsGoSlice}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}}) 46 | {{- end}} 47 | } 48 | 49 | {{end}} 50 | {{end}} 51 | {{end}} 52 | -------------------------------------------------------------------------------- /internal/templates/pgx/dbCode.tmpl: -------------------------------------------------------------------------------- 1 | {{define "dbCodeTemplatePgx"}} 2 | 3 | type DBTX interface { 4 | Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) 5 | Query(context.Context, string, ...interface{}) (pgx.Rows, error) 6 | QueryRow(context.Context, string, ...interface{}) pgx.Row 7 | {{- if .UsesCopyFrom }} 8 | CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, 
rowSrc pgx.CopyFromSource) (int64, error) 9 | {{- end }} 10 | {{- if .UsesBatch }} 11 | SendBatch(context.Context, *pgx.Batch) pgx.BatchResults 12 | {{- end }} 13 | } 14 | 15 | {{ if .EmitMethodsWithDBArgument}} 16 | func New() *Queries { 17 | return &Queries{} 18 | {{- else -}} 19 | func New(db DBTX) *Queries { 20 | return &Queries{db: db} 21 | {{- end}} 22 | } 23 | 24 | type Queries struct { 25 | {{if not .EmitMethodsWithDBArgument}} 26 | db DBTX 27 | {{end}} 28 | } 29 | 30 | {{if not .EmitMethodsWithDBArgument}} 31 | func (q *Queries) WithTx(tx pgx.Tx) *Queries { 32 | return &Queries{ 33 | db: tx, 34 | } 35 | } 36 | {{end}} 37 | {{end}} 38 | -------------------------------------------------------------------------------- /internal/templates/pgx/interfaceCode.tmpl: -------------------------------------------------------------------------------- 1 | {{define "interfaceCodePgx"}} 2 | type Querier interface { 3 | {{- $dbtxParam := .EmitMethodsWithDBArgument -}} 4 | {{- range .GoQueries}} 5 | {{- if and (eq .Cmd ":one") ($dbtxParam) }} 6 | {{range .Comments}}//{{.}} 7 | {{end -}} 8 | {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) 9 | {{- else if eq .Cmd ":one" }} 10 | {{range .Comments}}//{{.}} 11 | {{end -}} 12 | {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) 13 | {{- end}} 14 | {{- if and (eq .Cmd ":many") ($dbtxParam) }} 15 | {{range .Comments}}//{{.}} 16 | {{end -}} 17 | {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) 18 | {{- else if eq .Cmd ":many" }} 19 | {{range .Comments}}//{{.}} 20 | {{end -}} 21 | {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) 22 | {{- end}} 23 | {{- if and (eq .Cmd ":exec") ($dbtxParam) }} 24 | {{range .Comments}}//{{.}} 25 | {{end -}} 26 | {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) error 27 | {{- else if eq .Cmd ":exec" }} 28 | {{range .Comments}}//{{.}} 29 | {{end 
-}} 30 | {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) error 31 | {{- end}} 32 | {{- if and (eq .Cmd ":execrows") ($dbtxParam) }} 33 | {{range .Comments}}//{{.}} 34 | {{end -}} 35 | {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (int64, error) 36 | {{- else if eq .Cmd ":execrows" }} 37 | {{range .Comments}}//{{.}} 38 | {{end -}} 39 | {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, error) 40 | {{- end}} 41 | {{- if and (eq .Cmd ":execresult") ($dbtxParam) }} 42 | {{range .Comments}}//{{.}} 43 | {{end -}} 44 | {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (pgconn.CommandTag, error) 45 | {{- else if eq .Cmd ":execresult" }} 46 | {{range .Comments}}//{{.}} 47 | {{end -}} 48 | {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (pgconn.CommandTag, error) 49 | {{- end}} 50 | {{- if and (eq .Cmd ":copyfrom") ($dbtxParam) }} 51 | {{range .Comments}}//{{.}} 52 | {{end -}} 53 | {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.SlicePair}}) (int64, error) 54 | {{- else if eq .Cmd ":copyfrom" }} 55 | {{range .Comments}}//{{.}} 56 | {{end -}} 57 | {{.MethodName}}(ctx context.Context, {{.Arg.SlicePair}}) (int64, error) 58 | {{- end}} 59 | {{- if and (or (eq .Cmd ":batchexec") (eq .Cmd ":batchmany") (eq .Cmd ":batchone")) ($dbtxParam) }} 60 | {{range .Comments}}//{{.}} 61 | {{end -}} 62 | {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.SlicePair}}) *{{.MethodName}}BatchResults 63 | {{- else if or (eq .Cmd ":batchexec") (eq .Cmd ":batchmany") (eq .Cmd ":batchone") }} 64 | {{range .Comments}}//{{.}} 65 | {{end -}} 66 | {{.MethodName}}(ctx context.Context, {{.Arg.SlicePair}}) *{{.MethodName}}BatchResults 67 | {{- end}} 68 | 69 | {{- end}} 70 | } 71 | 72 | var _ Querier = (*Queries)(nil) 73 | {{end}} 74 | -------------------------------------------------------------------------------- /internal/templates/pgx/queryCode.tmpl: -------------------------------------------------------------------------------- 1 | {{define 
"queryCodePgx"}} 2 | {{range .GoQueries}} 3 | {{if $.OutputQuery .SourceName}} 4 | {{if and (ne .Cmd ":copyfrom") (ne (hasPrefix .Cmd ":batch") true)}} 5 | const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}} 6 | {{escape .SQL}} 7 | {{$.Q}} 8 | {{end}} 9 | 10 | {{if ne (hasPrefix .Cmd ":batch") true}} 11 | {{if .Arg.EmitStruct}} 12 | type {{.Arg.Type}} struct { {{- range .Arg.Struct.Fields}} 13 | {{.Name}} {{.Type}} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}} 14 | {{- end}} 15 | } 16 | {{end}} 17 | 18 | {{if .Ret.EmitStruct}} 19 | type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}} 20 | {{.Name}} {{.Type}} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}} 21 | {{- end}} 22 | } 23 | {{end}} 24 | {{end}} 25 | 26 | {{if eq .Cmd ":one"}} 27 | {{range .Comments}}//{{.}} 28 | {{end -}} 29 | {{- if $.EmitMethodsWithDBArgument -}} 30 | func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { 31 | row := db.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}}) 32 | {{- else -}} 33 | func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { 34 | row := q.db.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}}) 35 | {{- end}} 36 | {{- if or (ne .Arg.Pair .Ret.Pair) (ne .Arg.DefineType .Ret.DefineType) }} 37 | var {{.Ret.Name}} {{.Ret.Type}} 38 | {{- end}} 39 | err := row.Scan({{.Ret.Scan}}) 40 | return {{.Ret.ReturnName}}, err 41 | } 42 | {{end}} 43 | 44 | {{if eq .Cmd ":many"}} 45 | {{range .Comments}}//{{.}} 46 | {{end -}} 47 | {{- if $.EmitMethodsWithDBArgument -}} 48 | func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) { 49 | rows, err := db.Query(ctx, {{.ConstantName}}, {{.Arg.Params}}) 50 | {{- else -}} 51 | func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) { 52 | rows, err := q.db.Query(ctx, {{.ConstantName}}, {{.Arg.Params}}) 53 | {{- end}} 54 | if err != nil { 55 
| return nil, err 56 | } 57 | defer rows.Close() 58 | {{- if $.EmitEmptySlices}} 59 | items := []{{.Ret.DefineType}}{} 60 | {{else}} 61 | var items []{{.Ret.DefineType}} 62 | {{end -}} 63 | for rows.Next() { 64 | var {{.Ret.Name}} {{.Ret.Type}} 65 | if err := rows.Scan({{.Ret.Scan}}); err != nil { 66 | return nil, err 67 | } 68 | items = append(items, {{.Ret.ReturnName}}) 69 | } 70 | if err := rows.Err(); err != nil { 71 | return nil, err 72 | } 73 | return items, nil 74 | } 75 | {{end}} 76 | 77 | {{if eq .Cmd ":exec"}} 78 | {{range .Comments}}//{{.}} 79 | {{end -}} 80 | {{- if $.EmitMethodsWithDBArgument -}} 81 | func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) error { 82 | _, err := db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) 83 | {{- else -}} 84 | func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) error { 85 | _, err := q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) 86 | {{- end}} 87 | return err 88 | } 89 | {{end}} 90 | 91 | {{if eq .Cmd ":execrows"}} 92 | {{range .Comments}}//{{.}} 93 | {{end -}} 94 | {{if $.EmitMethodsWithDBArgument -}} 95 | func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (int64, error) { 96 | result, err := db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) 97 | {{- else -}} 98 | func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, error) { 99 | result, err := q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) 100 | {{- end}} 101 | if err != nil { 102 | return 0, err 103 | } 104 | return result.RowsAffected(), nil 105 | } 106 | {{end}} 107 | 108 | {{if eq .Cmd ":execresult"}} 109 | {{range .Comments}}//{{.}} 110 | {{end -}} 111 | {{- if $.EmitMethodsWithDBArgument -}} 112 | func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (pgconn.CommandTag, error) { 113 | return db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) 114 | {{- else -}} 115 | func (q *Queries) {{.MethodName}}(ctx context.Context, 
{{.Arg.Pair}}) (pgconn.CommandTag, error) { 116 | return q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) 117 | {{- end}} 118 | } 119 | {{end}} 120 | 121 | 122 | {{end}} 123 | {{end}} 124 | {{end}} 125 | -------------------------------------------------------------------------------- /internal/templates/stdlib/dbCode.tmpl: -------------------------------------------------------------------------------- 1 | {{define "dbCodeTemplateStd"}} 2 | type DBTX interface { 3 | ExecContext(context.Context, string, ...interface{}) (sql.Result, error) 4 | PrepareContext(context.Context, string) (*sql.Stmt, error) 5 | QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) 6 | QueryRowContext(context.Context, string, ...interface{}) *sql.Row 7 | } 8 | 9 | {{ if .EmitMethodsWithDBArgument}} 10 | func New() *Queries { 11 | return &Queries{} 12 | {{- else -}} 13 | func New(db DBTX) *Queries { 14 | return &Queries{db: db} 15 | {{- end}} 16 | } 17 | 18 | {{if .EmitPreparedQueries}} 19 | func Prepare(ctx context.Context, db DBTX) (*Queries, error) { 20 | q := Queries{db: db} 21 | var err error 22 | {{- if eq (len .GoQueries) 0 }} 23 | _ = err 24 | {{- end }} 25 | {{- range .GoQueries }} 26 | if q.{{.FieldName}}, err = db.PrepareContext(ctx, {{.ConstantName}}); err != nil { 27 | return nil, fmt.Errorf("error preparing query {{.MethodName}}: %w", err) 28 | } 29 | {{- end}} 30 | return &q, nil 31 | } 32 | 33 | func (q *Queries) Close() error { 34 | var err error 35 | {{- range .GoQueries }} 36 | if q.{{.FieldName}} != nil { 37 | if cerr := q.{{.FieldName}}.Close(); cerr != nil { 38 | err = fmt.Errorf("error closing {{.FieldName}}: %w", cerr) 39 | } 40 | } 41 | {{- end}} 42 | return err 43 | } 44 | 45 | func (q *Queries) exec(ctx context.Context, stmt *sql.Stmt, query string, args ...interface{}) (sql.Result, error) { 46 | switch { 47 | case stmt != nil && q.tx != nil: 48 | return q.tx.StmtContext(ctx, stmt).ExecContext(ctx, args...) 
49 | case stmt != nil: 50 | return stmt.ExecContext(ctx, args...) 51 | default: 52 | return q.db.ExecContext(ctx, query, args...) 53 | } 54 | } 55 | 56 | func (q *Queries) query(ctx context.Context, stmt *sql.Stmt, query string, args ...interface{}) (*sql.Rows, error) { 57 | switch { 58 | case stmt != nil && q.tx != nil: 59 | return q.tx.StmtContext(ctx, stmt).QueryContext(ctx, args...) 60 | case stmt != nil: 61 | return stmt.QueryContext(ctx, args...) 62 | default: 63 | return q.db.QueryContext(ctx, query, args...) 64 | } 65 | } 66 | 67 | func (q *Queries) queryRow(ctx context.Context, stmt *sql.Stmt, query string, args ...interface{}) (*sql.Row) { 68 | switch { 69 | case stmt != nil && q.tx != nil: 70 | return q.tx.StmtContext(ctx, stmt).QueryRowContext(ctx, args...) 71 | case stmt != nil: 72 | return stmt.QueryRowContext(ctx, args...) 73 | default: 74 | return q.db.QueryRowContext(ctx, query, args...) 75 | } 76 | } 77 | {{end}} 78 | 79 | type Queries struct { 80 | {{- if not .EmitMethodsWithDBArgument}} 81 | db DBTX 82 | {{- end}} 83 | 84 | {{- if .EmitPreparedQueries}} 85 | tx *sql.Tx 86 | {{- range .GoQueries}} 87 | {{.FieldName}} *sql.Stmt 88 | {{- end}} 89 | {{- end}} 90 | } 91 | 92 | {{if not .EmitMethodsWithDBArgument}} 93 | func (q *Queries) WithTx(tx *sql.Tx) *Queries { 94 | return &Queries{ 95 | db: tx, 96 | {{- if .EmitPreparedQueries}} 97 | tx: tx, 98 | {{- range .GoQueries}} 99 | {{.FieldName}}: q.{{.FieldName}}, 100 | {{- end}} 101 | {{- end}} 102 | } 103 | } 104 | {{end}} 105 | {{end}} 106 | -------------------------------------------------------------------------------- /internal/templates/stdlib/interfaceCode.tmpl: -------------------------------------------------------------------------------- 1 | {{define "interfaceCodeStd"}} 2 | type Querier interface { 3 | {{- $dbtxParam := .EmitMethodsWithDBArgument -}} 4 | {{- range .GoQueries}} 5 | {{- if and (eq .Cmd ":one") ($dbtxParam) }} 6 | {{range .Comments}}//{{.}} 7 | {{end -}} 8 | 
{{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) 9 | {{- else if eq .Cmd ":one"}} 10 | {{range .Comments}}//{{.}} 11 | {{end -}} 12 | {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) 13 | {{- end}} 14 | {{- if and (eq .Cmd ":many") ($dbtxParam) }} 15 | {{range .Comments}}//{{.}} 16 | {{end -}} 17 | {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) 18 | {{- else if eq .Cmd ":many"}} 19 | {{range .Comments}}//{{.}} 20 | {{end -}} 21 | {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) 22 | {{- end}} 23 | {{- if and (eq .Cmd ":exec") ($dbtxParam) }} 24 | {{range .Comments}}//{{.}} 25 | {{end -}} 26 | {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) error 27 | {{- else if eq .Cmd ":exec"}} 28 | {{range .Comments}}//{{.}} 29 | {{end -}} 30 | {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) error 31 | {{- end}} 32 | {{- if and (eq .Cmd ":execrows") ($dbtxParam) }} 33 | {{range .Comments}}//{{.}} 34 | {{end -}} 35 | {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (int64, error) 36 | {{- else if eq .Cmd ":execrows"}} 37 | {{range .Comments}}//{{.}} 38 | {{end -}} 39 | {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, error) 40 | {{- end}} 41 | {{- if and (eq .Cmd ":execlastid") ($dbtxParam) }} 42 | {{range .Comments}}//{{.}} 43 | {{end -}} 44 | {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (int64, error) 45 | {{- else if eq .Cmd ":execlastid"}} 46 | {{range .Comments}}//{{.}} 47 | {{end -}} 48 | {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, error) 49 | {{- end}} 50 | {{- if and (eq .Cmd ":execresult") ($dbtxParam) }} 51 | {{range .Comments}}//{{.}} 52 | {{end -}} 53 | {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (sql.Result, error) 54 | {{- else if eq .Cmd ":execresult"}} 55 | {{range .Comments}}//{{.}} 56 | {{end -}} 57 | {{.MethodName}}(ctx 
context.Context, {{.Arg.Pair}}) (sql.Result, error) 58 | {{- end}} 59 | {{- end}} 60 | } 61 | 62 | var _ Querier = (*Queries)(nil) 63 | {{end}} 64 | -------------------------------------------------------------------------------- /internal/templates/stdlib/queryCode.tmpl: -------------------------------------------------------------------------------- 1 | {{define "queryCodeStd"}} 2 | {{range .GoQueries}} 3 | {{if $.OutputQuery .SourceName}} 4 | const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}} 5 | {{escape .SQL}} 6 | {{$.Q}} 7 | 8 | {{if .Arg.EmitStruct}} 9 | type {{.Arg.Type}} struct { {{- range .Arg.UniqueFields}} 10 | {{.Name}} {{.Type}} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}} 11 | {{- end}} 12 | } 13 | {{end}} 14 | 15 | {{if .Ret.EmitStruct}} 16 | type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}} 17 | {{.Name}} {{.Type}} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}} 18 | {{- end}} 19 | } 20 | {{end}} 21 | 22 | {{if eq .Cmd ":one"}} 23 | {{range .Comments}}//{{.}} 24 | {{end -}} 25 | func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { 26 | {{- template "queryCodeStdExec" . }} 27 | {{- if or (ne .Arg.Pair .Ret.Pair) (ne .Arg.DefineType .Ret.DefineType) }} 28 | var {{.Ret.Name}} {{.Ret.Type}} 29 | {{- end}} 30 | err := row.Scan({{.Ret.Scan}}) 31 | return {{.Ret.ReturnName}}, err 32 | } 33 | {{end}} 34 | 35 | {{if eq .Cmd ":many"}} 36 | {{range .Comments}}//{{.}} 37 | {{end -}} 38 | func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) { 39 | {{- template "queryCodeStdExec" . 
}} 40 | if err != nil { 41 | return nil, err 42 | } 43 | defer rows.Close() 44 | {{- if $.EmitEmptySlices}} 45 | items := []{{.Ret.DefineType}}{} 46 | {{else}} 47 | var items []{{.Ret.DefineType}} 48 | {{end -}} 49 | for rows.Next() { 50 | var {{.Ret.Name}} {{.Ret.Type}} 51 | if err := rows.Scan({{.Ret.Scan}}); err != nil { 52 | return nil, err 53 | } 54 | items = append(items, {{.Ret.ReturnName}}) 55 | } 56 | if err := rows.Close(); err != nil { 57 | return nil, err 58 | } 59 | if err := rows.Err(); err != nil { 60 | return nil, err 61 | } 62 | return items, nil 63 | } 64 | {{end}} 65 | 66 | {{if eq .Cmd ":exec"}} 67 | {{range .Comments}}//{{.}} 68 | {{end -}} 69 | func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) error { 70 | {{- template "queryCodeStdExec" . }} 71 | return err 72 | } 73 | {{end}} 74 | 75 | {{if eq .Cmd ":execrows"}} 76 | {{range .Comments}}//{{.}} 77 | {{end -}} 78 | func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) (int64, error) { 79 | {{- template "queryCodeStdExec" . }} 80 | if err != nil { 81 | return 0, err 82 | } 83 | return result.RowsAffected() 84 | } 85 | {{end}} 86 | 87 | {{if eq .Cmd ":execlastid"}} 88 | {{range .Comments}}//{{.}} 89 | {{end -}} 90 | func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) (int64, error) { 91 | {{- template "queryCodeStdExec" . }} 92 | if err != nil { 93 | return 0, err 94 | } 95 | return result.LastInsertId() 96 | } 97 | {{end}} 98 | 99 | {{if eq .Cmd ":execresult"}} 100 | {{range .Comments}}//{{.}} 101 | {{end -}} 102 | func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) (sql.Result, error) { 103 | {{- template "queryCodeStdExec" . 
}} 104 | } 105 | {{end}} 106 | 107 | {{end}} 108 | {{end}} 109 | {{end}} 110 | 111 | {{define "queryCodeStdExec"}} 112 | {{- if .Arg.HasSqlcSlices }} 113 | query := {{.ConstantName}} 114 | var queryParams []interface{} 115 | {{- if .Arg.Struct }} 116 | {{- $arg := .Arg }} 117 | {{- range .Arg.Struct.Fields }} 118 | {{- if .HasSqlcSlice }} 119 | if len({{$arg.VariableForField .}}) > 0 { 120 | for _, v := range {{$arg.VariableForField .}} { 121 | queryParams = append(queryParams, v) 122 | } 123 | query = strings.Replace(query, "/*SLICE:{{.Column.Name}}*/?", strings.Repeat(",?", len({{$arg.VariableForField .}}))[1:], 1) 124 | } else { 125 | query = strings.Replace(query, "/*SLICE:{{.Column.Name}}*/?", "NULL", 1) 126 | } 127 | {{- else }} 128 | queryParams = append(queryParams, {{$arg.VariableForField .}}) 129 | {{- end }} 130 | {{- end }} 131 | {{- else }} 132 | {{- /* The query takes a single argument that is passed directly (it is not 133 | packed in a struct). Because .Arg.HasSqlcSlices was true further up, that 134 | argument must be a slice (it is impossible to get here otherwise). 135 | */}} 136 | if len({{.Arg.Name}}) > 0 { 137 | for _, v := range {{.Arg.Name}} { 138 | queryParams = append(queryParams, v) 139 | } 140 | query = strings.Replace(query, "/*SLICE:{{.Arg.Column.Name}}*/?", strings.Repeat(",?", len({{.Arg.Name}}))[1:], 1) 141 | } else { 142 | query = strings.Replace(query, "/*SLICE:{{.Arg.Column.Name}}*/?", "NULL", 1) 143 | } 144 | {{- end }} 145 | {{- if emitPreparedQueries }} 146 | {{ queryRetval . }} {{ queryMethod . }}(ctx, nil, query, queryParams...) 147 | {{- else}} 148 | {{ queryRetval . }} {{ queryMethod . }}(ctx, query, queryParams...) 149 | {{- end -}} 150 | {{- else if emitPreparedQueries }} 151 | {{- queryRetval . }} {{ queryMethod . }}(ctx, q.{{.FieldName}}, {{.ConstantName}}, {{.Arg.Params}}) 152 | {{- else}} 153 | {{- queryRetval . }} {{ queryMethod . 
}}(ctx, {{.ConstantName}}, {{.Arg.Params}}) 154 | {{- end -}} 155 | {{end}} 156 | -------------------------------------------------------------------------------- /internal/templates/template.tmpl: -------------------------------------------------------------------------------- 1 | {{define "dbFile"}} 2 | {{if .BuildTags}} 3 | //go:build {{.BuildTags}} 4 | 5 | {{end}}// Code generated by sqlc. DO NOT EDIT. 6 | {{if not .OmitSqlcVersion}}// versions: 7 | // sqlc {{.SqlcVersion}} 8 | {{end}} 9 | 10 | package {{.Package}} 11 | 12 | {{ if hasImports .SourceName }} 13 | import ( 14 | {{range imports .SourceName}} 15 | {{range .}}{{.}} 16 | {{end}} 17 | {{end}} 18 | ) 19 | {{end}} 20 | 21 | {{template "dbCode" . }} 22 | {{end}} 23 | 24 | {{define "dbCode"}} 25 | 26 | {{if .SQLDriver.IsPGX }} 27 | {{- template "dbCodeTemplatePgx" .}} 28 | {{else}} 29 | {{- template "dbCodeTemplateStd" .}} 30 | {{end}} 31 | 32 | {{end}} 33 | 34 | {{define "interfaceFile"}} 35 | {{if .BuildTags}} 36 | //go:build {{.BuildTags}} 37 | 38 | {{end}}// Code generated by sqlc. DO NOT EDIT. 39 | {{if not .OmitSqlcVersion}}// versions: 40 | // sqlc {{.SqlcVersion}} 41 | {{end}} 42 | 43 | package {{.Package}} 44 | 45 | {{ if hasImports .SourceName }} 46 | import ( 47 | {{range imports .SourceName}} 48 | {{range .}}{{.}} 49 | {{end}} 50 | {{end}} 51 | ) 52 | {{end}} 53 | 54 | {{template "interfaceCode" . }} 55 | {{end}} 56 | 57 | {{define "interfaceCode"}} 58 | {{if .SQLDriver.IsPGX }} 59 | {{- template "interfaceCodePgx" .}} 60 | {{else}} 61 | {{- template "interfaceCodeStd" .}} 62 | {{end}} 63 | {{end}} 64 | 65 | {{define "modelsFile"}} 66 | {{if .BuildTags}} 67 | //go:build {{.BuildTags}} 68 | 69 | {{end}}// Code generated by sqlc. DO NOT EDIT. 
70 | {{if not .OmitSqlcVersion}}// versions: 71 | // sqlc {{.SqlcVersion}} 72 | {{end}} 73 | 74 | package {{.Package}} 75 | 76 | {{ if hasImports .SourceName }} 77 | import ( 78 | {{range imports .SourceName}} 79 | {{range .}}{{.}} 80 | {{end}} 81 | {{end}} 82 | ) 83 | {{end}} 84 | 85 | {{template "modelsCode" . }} 86 | {{end}} 87 | 88 | {{define "modelsCode"}} 89 | {{range .Enums}} 90 | {{if .Comment}}{{comment .Comment}}{{end}} 91 | type {{.Name}} string 92 | 93 | const ( 94 | {{- range .Constants}} 95 | {{.Name}} {{.Type}} = "{{.Value}}" 96 | {{- end}} 97 | ) 98 | 99 | func (e *{{.Name}}) Scan(src interface{}) error { 100 | switch s := src.(type) { 101 | case []byte: 102 | *e = {{.Name}}(s) 103 | case string: 104 | *e = {{.Name}}(s) 105 | default: 106 | return fmt.Errorf("unsupported scan type for {{.Name}}: %T", src) 107 | } 108 | return nil 109 | } 110 | 111 | type Null{{.Name}} struct { 112 | {{.Name}} {{.Name}} {{if .NameTag}}{{$.Q}}{{.NameTag}}{{$.Q}}{{end}} 113 | Valid bool {{if .ValidTag}}{{$.Q}}{{.ValidTag}}{{$.Q}}{{end}} // Valid is true if {{.Name}} is not NULL 114 | } 115 | 116 | // Scan implements the Scanner interface. 117 | func (ns *Null{{.Name}}) Scan(value interface{}) error { 118 | if value == nil { 119 | ns.{{.Name}}, ns.Valid = "", false 120 | return nil 121 | } 122 | ns.Valid = true 123 | return ns.{{.Name}}.Scan(value) 124 | } 125 | 126 | // Value implements the driver Valuer interface. 
127 | func (ns Null{{.Name}}) Value() (driver.Value, error) { 128 | if !ns.Valid { 129 | return nil, nil 130 | } 131 | return string(ns.{{.Name}}), nil 132 | } 133 | 134 | 135 | {{ if $.EmitEnumValidMethod }} 136 | func (e {{.Name}}) Valid() bool { 137 | switch e { 138 | case {{ range $idx, $name := .Constants }}{{ if ne $idx 0 }},{{ "\n" }}{{ end }}{{ .Name }}{{ end }}: 139 | return true 140 | } 141 | return false 142 | } 143 | {{ end }} 144 | 145 | {{ if $.EmitAllEnumValues }} 146 | func All{{ .Name }}Values() []{{ .Name }} { 147 | return []{{ .Name }}{ {{ range .Constants}}{{ "\n" }}{{ .Name }},{{ end }} 148 | } 149 | } 150 | {{ end }} 151 | {{end}} 152 | 153 | {{range .Structs}} 154 | {{if .Comment}}{{comment .Comment}}{{end}} 155 | type {{.Name}} struct { {{- range .Fields}} 156 | {{- if .Comment}} 157 | {{comment .Comment}}{{else}} 158 | {{- end}} 159 | {{.Name}} {{.Type}} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}} 160 | {{- end}} 161 | } 162 | {{end}} 163 | {{end}} 164 | 165 | {{define "queryFile"}} 166 | {{if .BuildTags}} 167 | //go:build {{.BuildTags}} 168 | 169 | {{end}}// Code generated by sqlc. DO NOT EDIT. 170 | {{if not .OmitSqlcVersion}}// versions: 171 | // sqlc {{.SqlcVersion}} 172 | {{end}}// source: {{.SourceName}} 173 | 174 | package {{.Package}} 175 | 176 | {{ if hasImports .SourceName }} 177 | import ( 178 | {{range imports .SourceName}} 179 | {{range .}}{{.}} 180 | {{end}} 181 | {{end}} 182 | ) 183 | {{end}} 184 | 185 | {{template "queryCode" . }} 186 | {{end}} 187 | 188 | {{define "queryCode"}} 189 | {{if .SQLDriver.IsPGX }} 190 | {{- template "queryCodePgx" .}} 191 | {{else}} 192 | {{- template "queryCodeStd" .}} 193 | {{end}} 194 | {{end}} 195 | 196 | {{define "copyfromFile"}} 197 | {{if .BuildTags}} 198 | //go:build {{.BuildTags}} 199 | 200 | {{end}}// Code generated by sqlc. DO NOT EDIT. 
201 | {{if not .OmitSqlcVersion}}// versions: 202 | // sqlc {{.SqlcVersion}} 203 | {{end}}// source: {{.SourceName}} 204 | 205 | package {{.Package}} 206 | 207 | {{ if hasImports .SourceName }} 208 | import ( 209 | {{range imports .SourceName}} 210 | {{range .}}{{.}} 211 | {{end}} 212 | {{end}} 213 | ) 214 | {{end}} 215 | 216 | {{template "copyfromCode" . }} 217 | {{end}} 218 | 219 | {{define "copyfromCode"}} 220 | {{if .SQLDriver.IsPGX }} 221 | {{- template "copyfromCodePgx" .}} 222 | {{else if .SQLDriver.IsGoSQLDriverMySQL }} 223 | {{- template "copyfromCodeGoSqlDriver" .}} 224 | {{end}} 225 | {{end}} 226 | 227 | {{define "batchFile"}} 228 | {{if .BuildTags}} 229 | //go:build {{.BuildTags}} 230 | 231 | {{end}}// Code generated by sqlc. DO NOT EDIT. 232 | {{if not .OmitSqlcVersion}}// versions: 233 | // sqlc {{.SqlcVersion}} 234 | {{end}}// source: {{.SourceName}} 235 | 236 | package {{.Package}} 237 | 238 | {{ if hasImports .SourceName }} 239 | import ( 240 | {{range imports .SourceName}} 241 | {{range .}}{{.}} 242 | {{end}} 243 | {{end}} 244 | ) 245 | {{end}} 246 | 247 | {{template "batchCode" . }} 248 | {{end}} 249 | 250 | {{define "batchCode"}} 251 | {{if .SQLDriver.IsPGX }} 252 | {{- template "batchCodePgx" .}} 253 | {{end}} 254 | {{end}} 255 | -------------------------------------------------------------------------------- /plugin/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/sqlc-dev/plugin-sdk-go/codegen" 5 | 6 | golang "github.com/sqlc-dev/sqlc-gen-go/internal" 7 | ) 8 | 9 | func main() { 10 | codegen.Run(golang.Generate) 11 | } 12 | --------------------------------------------------------------------------------