├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── feature_request.md │ └── other-issues.md └── workflows │ └── ci.yml ├── .gitignore ├── CHANGELOG.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── Rakefile ├── batch.go ├── batch_test.go ├── bench_test.go ├── ci └── setup_test.bash ├── conn.go ├── conn_test.go ├── copy_from.go ├── copy_from_test.go ├── doc.go ├── examples ├── README.md ├── chat │ ├── README.md │ └── main.go ├── todo │ ├── README.md │ ├── main.go │ └── structure.sql └── url_shortener │ ├── README.md │ ├── main.go │ └── structure.sql ├── extended_query_builder.go ├── go.mod ├── go.sum ├── helper_test.go ├── internal ├── anynil │ └── anynil.go ├── iobufpool │ ├── iobufpool.go │ ├── iobufpool_internal_test.go │ └── iobufpool_test.go ├── nbconn │ ├── bufferqueue.go │ ├── nbconn.go │ ├── nbconn_fake_non_block.go │ ├── nbconn_real_non_block.go │ └── nbconn_test.go ├── pgio │ ├── README.md │ ├── doc.go │ ├── write.go │ └── write_test.go ├── pgmock │ ├── pgmock.go │ └── pgmock_test.go ├── sanitize │ ├── sanitize.go │ └── sanitize_test.go └── stmtcache │ ├── lru_cache.go │ ├── stmtcache.go │ └── unlimited_cache.go ├── large_objects.go ├── large_objects_test.go ├── log └── testingadapter │ └── adapter.go ├── named_args.go ├── named_args_test.go ├── pgbouncer_test.go ├── pgconn ├── README.md ├── auth_scram.go ├── benchmark_private_test.go ├── benchmark_test.go ├── config.go ├── config_test.go ├── defaults.go ├── defaults_windows.go ├── doc.go ├── errors.go ├── errors_test.go ├── export_test.go ├── helper_test.go ├── internal │ └── ctxwatch │ │ ├── context_watcher.go │ │ └── context_watcher_test.go ├── krb5.go ├── pgconn.go ├── pgconn_private_test.go ├── pgconn_stress_test.go └── pgconn_test.go ├── pgproto3 ├── README.md ├── authentication_cleartext_password.go ├── authentication_gss.go ├── authentication_gss_continue.go ├── authentication_md5_password.go ├── authentication_ok.go ├── authentication_sasl.go ├── authentication_sasl_continue.go ├── 
authentication_sasl_final.go ├── backend.go ├── backend_key_data.go ├── backend_test.go ├── big_endian.go ├── bind.go ├── bind_complete.go ├── cancel_request.go ├── chunkreader.go ├── chunkreader_test.go ├── close.go ├── close_complete.go ├── command_complete.go ├── copy_both_response.go ├── copy_both_response_test.go ├── copy_data.go ├── copy_done.go ├── copy_fail.go ├── copy_in_response.go ├── copy_out_response.go ├── data_row.go ├── describe.go ├── doc.go ├── empty_query_response.go ├── error_response.go ├── example │ └── pgfortune │ │ ├── README.md │ │ ├── main.go │ │ └── server.go ├── execute.go ├── flush.go ├── frontend.go ├── frontend_test.go ├── function_call.go ├── function_call_response.go ├── function_call_test.go ├── fuzz_test.go ├── gss_enc_request.go ├── gss_response.go ├── json_test.go ├── no_data.go ├── notice_response.go ├── notification_response.go ├── parameter_description.go ├── parameter_status.go ├── parse.go ├── parse_complete.go ├── password_message.go ├── pgproto3.go ├── portal_suspended.go ├── query.go ├── ready_for_query.go ├── row_description.go ├── sasl_initial_response.go ├── sasl_response.go ├── ssl_request.go ├── startup_message.go ├── sync.go ├── terminate.go ├── testdata │ └── fuzz │ │ └── FuzzFrontend │ │ ├── 39c5e864da4707fc15fea48f7062d6a07796fdc43b33e0ba9dbd7074a0211fa6 │ │ ├── 9b06792b1aaac8a907dbfa04d526ae14326c8573b7409032caac8461e83065f7 │ │ ├── a661fb98e802839f0a7361160fbc6e28794612a411d00bde104364ee281c4214 │ │ └── fc98dcd487a5173b38763a5f7dd023933f3a86ab566e3f2b091eb36248107eb4 ├── trace.go └── trace_test.go ├── pgtype ├── array.go ├── array_codec.go ├── array_codec_test.go ├── array_test.go ├── bits.go ├── bits_test.go ├── bool.go ├── bool_test.go ├── box.go ├── box_test.go ├── builtin_wrappers.go ├── bytea.go ├── bytea_test.go ├── circle.go ├── circle_test.go ├── composite.go ├── composite_test.go ├── convert.go ├── date.go ├── date_test.go ├── doc.go ├── enum_codec.go ├── enum_codec_test.go ├── 
example_child_records_test.go ├── example_custom_type_test.go ├── example_json_test.go ├── float4.go ├── float4_test.go ├── float8.go ├── float8_test.go ├── hstore.go ├── hstore_test.go ├── inet.go ├── inet_test.go ├── int.go ├── int.go.erb ├── int_test.go ├── int_test.go.erb ├── integration_benchmark_test.go ├── integration_benchmark_test.go.erb ├── integration_benchmark_test_gen.sh ├── interval.go ├── interval_test.go ├── json.go ├── json_test.go ├── jsonb.go ├── jsonb_test.go ├── line.go ├── line_test.go ├── lseg.go ├── lseg_test.go ├── macaddr.go ├── macaddr_test.go ├── multirange.go ├── multirange_test.go ├── numeric.go ├── numeric_test.go ├── path.go ├── path_test.go ├── pgtype.go ├── pgtype_test.go ├── point.go ├── point_test.go ├── polygon.go ├── polygon_test.go ├── qchar.go ├── qchar_test.go ├── range.go ├── range_codec.go ├── range_codec_test.go ├── range_test.go ├── record_codec.go ├── record_codec_test.go ├── register_default_pg_types.go ├── register_default_pg_types_disabled.go ├── text.go ├── text_format_only_codec.go ├── text_test.go ├── tid.go ├── tid_test.go ├── time.go ├── time_test.go ├── timestamp.go ├── timestamp_test.go ├── timestamptz.go ├── timestamptz_test.go ├── uint32.go ├── uint32_test.go ├── uuid.go ├── uuid_test.go └── zeronull │ ├── doc.go │ ├── float8.go │ ├── float8_test.go │ ├── int.go │ ├── int.go.erb │ ├── int_test.go │ ├── int_test.go.erb │ ├── text.go │ ├── text_test.go │ ├── timestamp.go │ ├── timestamp_test.go │ ├── timestamptz.go │ ├── timestamptz_test.go │ ├── uuid.go │ ├── uuid_test.go │ ├── zeronull.go │ └── zeronull_test.go ├── pgxpool ├── batch_results.go ├── bench_test.go ├── common_test.go ├── conn.go ├── conn_test.go ├── doc.go ├── pool.go ├── pool_test.go ├── rows.go ├── stat.go ├── tx.go └── tx_test.go ├── pgxtest └── pgxtest.go ├── pipeline_test.go ├── query_test.go ├── rows.go ├── rows_test.go ├── stdlib ├── bench_test.go ├── sql.go └── sql_test.go ├── testsetup ├── README.md ├── ca.cnf ├── localhost.cnf ├── 
pg_hba.conf ├── pgx_sslcert.cnf ├── postgresql_setup.sql └── postgresql_ssl.conf ├── tracelog ├── tracelog.go └── tracelog_test.go ├── tracer.go ├── tracer_test.go ├── tx.go ├── tx_test.go ├── values.go └── values_test.go /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 16 | If possible, please provide a runnable example such as: 17 | 18 | ```go 19 | package main 20 | 21 | import ( 22 | "context" 23 | "log" 24 | "os" 25 | 26 | "github.com/jackc/pgx/v5" 27 | ) 28 | 29 | func main() { 30 | conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL")) 31 | if err != nil { 32 | log.Fatal(err) 33 | } 34 | defer conn.Close(context.Background()) 35 | 36 | // Your code here... 37 | } 38 | ``` 39 | 40 | **Expected behavior** 41 | A clear and concise description of what you expected to happen. 42 | 43 | **Actual behavior** 44 | A clear and concise description of what actually happened. 45 | 46 | **Version** 47 | - Go: `$ go version` -> [e.g. go version go1.18.3 darwin/amd64] 48 | - PostgreSQL: `$ psql --no-psqlrc --tuples-only -c 'select version()'` -> [e.g. PostgreSQL 14.4 on x86_64-apple-darwin21.5.0, compiled by Apple clang version 13.1.6 (clang-1316.0.21.2.5), 64-bit] 49 | - pgx: `$ grep 'github.com/jackc/pgx/v[0-9]' go.mod` -> [e.g. v4.16.1] 50 | 51 | **Additional context** 52 | Add any other context about the problem here.
53 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/other-issues.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Other issues 3 | about: Any issue that is not a bug or a feature request 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | Please describe the issue in detail. If this is a question about how to use pgx please use discussions instead. 
11 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | 24 | .envrc 25 | /.testdb 26 | 27 | .DS_Store 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2013-2021 Jack Christensen 2 | 3 | MIT License 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining 6 | a copy of this software and associated documentation files (the 7 | "Software"), to deal in the Software without restriction, including 8 | without limitation the rights to use, copy, modify, merge, publish, 9 | distribute, sublicense, and/or sell copies of the Software, and to 10 | permit persons to whom the Software is furnished to do so, subject to 11 | the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be 14 | included in all copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 19 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 20 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 21 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 22 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
23 | -------------------------------------------------------------------------------- /Rakefile: --------------------------------------------------------------------------------
require "erb"

# Build a *.go target from the *.go.erb template of the same name
# (e.g. pgtype/int.go from pgtype/int.go.erb).
rule '.go' => '.go.erb' do |task|
  erb = ERB.new(File.read(task.source))
  # The template is rendered against this block's binding, so locals in scope
  # here (such as task and erb) are also visible inside the template. The output
  # is prefixed with a "Do not edit" marker naming the source template.
  File.write(task.name, "// Do not edit. Generated from #{task.source}\n" + erb.result(binding))
  # Run goimports on the generated file to format it and fix up its imports.
  sh "goimports", "-w", task.name
end

# Go files generated from templates by the rule above.
generated_code_files = [
  "pgtype/int.go",
  "pgtype/int_test.go",
  "pgtype/integration_benchmark_test.go",
  "pgtype/zeronull/int.go",
  "pgtype/zeronull/int_test.go"
]

# `rake generate` (re)builds every file listed in generated_code_files.
desc "Generate code"
task generate: generated_code_files

-------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | * chat is a command line chat program using listen/notify. 4 | * todo is a command line todo list that demonstrates basic CRUD actions. 5 | * url_shortener contains a simple example of using pgx in a web context. 6 | * [Tern](https://github.com/jackc/tern) is a migration tool that uses pgx. 7 | * [The Pithy Reader](https://github.com/jackc/tpr) is a RSS aggregator that uses pgx. 8 | -------------------------------------------------------------------------------- /examples/chat/README.md: -------------------------------------------------------------------------------- 1 | # Description 2 | 3 | This is a sample chat program implemented using PostgreSQL's listen/notify 4 | functionality with pgx. 5 | 6 | Start multiple instances of this program connected to the same database to chat 7 | between them. 8 | 9 | ## Connection configuration 10 | 11 | The database connection is configured via DATABASE_URL and standard PostgreSQL environment variables (PGHOST, PGUSER, etc.)
12 | 13 | You can either export them then run chat: 14 | 15 | export PGHOST=/private/tmp 16 | ./chat 17 | 18 | Or you can prefix the chat execution with the environment variables: 19 | 20 | PGHOST=/private/tmp ./chat 21 | -------------------------------------------------------------------------------- /examples/chat/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "fmt" 7 | "os" 8 | 9 | "github.com/jackc/pgx/v5/pgxpool" 10 | ) 11 | 12 | var pool *pgxpool.Pool 13 | 14 | func main() { 15 | var err error 16 | pool, err = pgxpool.New(context.Background(), os.Getenv("DATABASE_URL")) 17 | if err != nil { 18 | fmt.Fprintln(os.Stderr, "Unable to connect to database:", err) 19 | os.Exit(1) 20 | } 21 | 22 | go listen() 23 | 24 | fmt.Println(`Type a message and press enter. 25 | 26 | This message should appear in any other chat instances connected to the same 27 | database. 28 | 29 | Type "exit" to quit.`) 30 | 31 | scanner := bufio.NewScanner(os.Stdin) 32 | for scanner.Scan() { 33 | msg := scanner.Text() 34 | if msg == "exit" { 35 | os.Exit(0) 36 | } 37 | 38 | _, err = pool.Exec(context.Background(), "select pg_notify('chat', $1)", msg) 39 | if err != nil { 40 | fmt.Fprintln(os.Stderr, "Error sending notification:", err) 41 | os.Exit(1) 42 | } 43 | } 44 | if err := scanner.Err(); err != nil { 45 | fmt.Fprintln(os.Stderr, "Error scanning from stdin:", err) 46 | os.Exit(1) 47 | } 48 | } 49 | 50 | func listen() { 51 | conn, err := pool.Acquire(context.Background()) 52 | if err != nil { 53 | fmt.Fprintln(os.Stderr, "Error acquiring connection:", err) 54 | os.Exit(1) 55 | } 56 | defer conn.Release() 57 | 58 | _, err = conn.Exec(context.Background(), "listen chat") 59 | if err != nil { 60 | fmt.Fprintln(os.Stderr, "Error listening to chat channel:", err) 61 | os.Exit(1) 62 | } 63 | 64 | for { 65 | notification, err := conn.Conn().WaitForNotification(context.Background()) 66 | 
if err != nil { 67 | fmt.Fprintln(os.Stderr, "Error waiting for notification:", err) 68 | os.Exit(1) 69 | } 70 | 71 | fmt.Println("PID:", notification.PID, "Channel:", notification.Channel, "Payload:", notification.Payload) 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /examples/todo/README.md: -------------------------------------------------------------------------------- 1 | # Description 2 | 3 | This is a sample todo list implemented using pgx as the connector to a 4 | PostgreSQL data store. 5 | 6 | # Usage 7 | 8 | Create a PostgreSQL database and run structure.sql into it to create the 9 | necessary data schema. 10 | 11 | Example: 12 | 13 | createdb todo 14 | psql todo < structure.sql 15 | 16 | Build todo: 17 | 18 | go build 19 | 20 | ## Connection configuration 21 | 22 | The database connection is configured via DATABASE_URL and standard PostgreSQL environment variables (PGHOST, PGUSER, etc.) 23 | 24 | You can either export them then run todo: 25 | 26 | export PGDATABASE=todo 27 | ./todo list 28 | 29 | Or you can prefix the todo execution with the environment variables: 30 | 31 | PGDATABASE=todo ./todo list 32 | 33 | ## Add a todo item 34 | 35 | ./todo add 'Learn go' 36 | 37 | ## List tasks 38 | 39 | ./todo list 40 | 41 | ## Update a task 42 | 43 | ./todo update 1 'Learn more go' 44 | 45 | ## Delete a task 46 | 47 | ./todo remove 1 48 | 49 | # Example Setup and Execution 50 | 51 | jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ createdb todo 52 | jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ psql todo < structure.sql 53 | Expanded display is used automatically. 54 | Timing is on. 
55 | CREATE TABLE 56 | Time: 6.363 ms 57 | jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ go build 58 | jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ export PGDATABASE=todo 59 | jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo list 60 | jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo add 'Learn Go' 61 | jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo list 62 | 1. Learn Go 63 | jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo update 1 'Learn more Go' 64 | jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo list 65 | 1. Learn more Go 66 | jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo remove 1 67 | jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo list 68 | -------------------------------------------------------------------------------- /examples/todo/structure.sql: -------------------------------------------------------------------------------- 1 | create table tasks ( 2 | id serial primary key, 3 | description text not null 4 | ); 5 | -------------------------------------------------------------------------------- /examples/url_shortener/README.md: -------------------------------------------------------------------------------- 1 | # Description 2 | 3 | This is a sample REST URL shortener service implemented using pgx as the connector to a PostgreSQL data store. 4 | 5 | # Usage 6 | 7 | Create a PostgreSQL database and run structure.sql into it to create the necessary data schema. 
8 | 9 | Configure the database connection with `DATABASE_URL` or standard PostgreSQL (`PG*`) environment variables or 10 | 11 | Run main.go: 12 | 13 | ``` 14 | go run main.go 15 | ``` 16 | 17 | ## Create or Update a Shortened URL 18 | 19 | ``` 20 | curl -X PUT -d 'http://www.google.com' http://localhost:8080/google 21 | ``` 22 | 23 | ## Get a Shortened URL 24 | 25 | ``` 26 | curl http://localhost:8080/google 27 | ``` 28 | 29 | ## Delete a Shortened URL 30 | 31 | ``` 32 | curl -X DELETE http://localhost:8080/google 33 | ``` 34 | -------------------------------------------------------------------------------- /examples/url_shortener/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "io/ioutil" 6 | "log" 7 | "net/http" 8 | "os" 9 | 10 | "github.com/jackc/pgx/v5" 11 | "github.com/jackc/pgx/v5/pgxpool" 12 | ) 13 | 14 | var db *pgxpool.Pool 15 | 16 | func getUrlHandler(w http.ResponseWriter, req *http.Request) { 17 | var url string 18 | err := db.QueryRow(context.Background(), "select url from shortened_urls where id=$1", req.URL.Path).Scan(&url) 19 | switch err { 20 | case nil: 21 | http.Redirect(w, req, url, http.StatusSeeOther) 22 | case pgx.ErrNoRows: 23 | http.NotFound(w, req) 24 | default: 25 | http.Error(w, "Internal server error", http.StatusInternalServerError) 26 | } 27 | } 28 | 29 | func putUrlHandler(w http.ResponseWriter, req *http.Request) { 30 | id := req.URL.Path 31 | var url string 32 | if body, err := ioutil.ReadAll(req.Body); err == nil { 33 | url = string(body) 34 | } else { 35 | http.Error(w, "Internal server error", http.StatusInternalServerError) 36 | return 37 | } 38 | 39 | if _, err := db.Exec(context.Background(), `insert into shortened_urls(id, url) values ($1, $2) 40 | on conflict (id) do update set url=excluded.url`, id, url); err == nil { 41 | w.WriteHeader(http.StatusOK) 42 | } else { 43 | http.Error(w, "Internal server error", 
http.StatusInternalServerError) 44 | } 45 | } 46 | 47 | func deleteUrlHandler(w http.ResponseWriter, req *http.Request) { 48 | if _, err := db.Exec(context.Background(), "delete from shortened_urls where id=$1", req.URL.Path); err == nil { 49 | w.WriteHeader(http.StatusOK) 50 | } else { 51 | http.Error(w, "Internal server error", http.StatusInternalServerError) 52 | } 53 | } 54 | 55 | func urlHandler(w http.ResponseWriter, req *http.Request) { 56 | switch req.Method { 57 | case "GET": 58 | getUrlHandler(w, req) 59 | 60 | case "PUT": 61 | putUrlHandler(w, req) 62 | 63 | case "DELETE": 64 | deleteUrlHandler(w, req) 65 | 66 | default: 67 | w.Header().Add("Allow", "GET, PUT, DELETE") 68 | w.WriteHeader(http.StatusMethodNotAllowed) 69 | } 70 | } 71 | 72 | func main() { 73 | poolConfig, err := pgxpool.ParseConfig(os.Getenv("DATABASE_URL")) 74 | if err != nil { 75 | log.Fatalln("Unable to parse DATABASE_URL:", err) 76 | } 77 | 78 | db, err = pgxpool.NewWithConfig(context.Background(), poolConfig) 79 | if err != nil { 80 | log.Fatalln("Unable to create connection pool:", err) 81 | } 82 | 83 | http.HandleFunc("/", urlHandler) 84 | 85 | log.Println("Starting URL shortener on localhost:8080") 86 | err = http.ListenAndServe("localhost:8080", nil) 87 | if err != nil { 88 | log.Fatalln("Unable to start web server:", err) 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /examples/url_shortener/structure.sql: -------------------------------------------------------------------------------- 1 | create table shortened_urls ( 2 | id text primary key, 3 | url text not null 4 | ); -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/jackc/pgx/v5 2 | 3 | go 1.19 4 | 5 | require ( 6 | github.com/jackc/pgpassfile v1.0.0 7 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a 8 | 
github.com/jackc/puddle/v2 v2.1.3-0.20230114152537-cc12efc05a26 9 | github.com/stretchr/testify v1.8.0 10 | golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90 11 | golang.org/x/text v0.3.8 12 | ) 13 | 14 | require ( 15 | github.com/davecgh/go-spew v1.1.1 // indirect 16 | github.com/kr/pretty v0.3.0 // indirect 17 | github.com/pmezard/go-difflib v1.0.0 // indirect 18 | golang.org/x/sync v0.0.0-20220923202941-7f9b1623fab7 // indirect 19 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect 20 | gopkg.in/yaml.v3 v3.0.1 // indirect 21 | ) 22 | -------------------------------------------------------------------------------- /internal/anynil/anynil.go: -------------------------------------------------------------------------------- 1 | package anynil 2 | 3 | import "reflect" 4 | 5 | // Is returns true if value is any type of nil. e.g. nil or []byte(nil). 6 | func Is(value any) bool { 7 | if value == nil { 8 | return true 9 | } 10 | 11 | refVal := reflect.ValueOf(value) 12 | switch refVal.Kind() { 13 | case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice: 14 | return refVal.IsNil() 15 | default: 16 | return false 17 | } 18 | } 19 | 20 | // Normalize converts typed nils (e.g. []byte(nil)) into untyped nil. Other values are returned unmodified. 21 | func Normalize(v any) any { 22 | if Is(v) { 23 | return nil 24 | } 25 | return v 26 | } 27 | 28 | // NormalizeSlice converts all typed nils (e.g. []byte(nil)) in s into untyped nils. Other values are unmodified. s is 29 | // mutated in place. 30 | func NormalizeSlice(s []any) { 31 | for i := range s { 32 | if Is(s[i]) { 33 | s[i] = nil 34 | } 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /internal/iobufpool/iobufpool.go: -------------------------------------------------------------------------------- 1 | // Package iobufpool implements a global segregated-fit pool of buffers for IO. 
2 | // 3 | // It uses *[]byte instead of []byte to avoid the sync.Pool allocation with Put. Unfortunately, using a pointer to avoid 4 | // an allocation is purposely not documented. https://github.com/golang/go/issues/16323 5 | package iobufpool 6 | 7 | import "sync" 8 | 9 | const minPoolExpOf2 = 8 10 | 11 | var pools [18]*sync.Pool 12 | 13 | func init() { 14 | for i := range pools { 15 | bufLen := 1 << (minPoolExpOf2 + i) 16 | pools[i] = &sync.Pool{ 17 | New: func() any { 18 | buf := make([]byte, bufLen) 19 | return &buf 20 | }, 21 | } 22 | } 23 | } 24 | 25 | // Get gets a []byte of len size with cap <= size*2. 26 | func Get(size int) *[]byte { 27 | i := getPoolIdx(size) 28 | if i >= len(pools) { 29 | buf := make([]byte, size) 30 | return &buf 31 | } 32 | 33 | ptrBuf := (pools[i].Get().(*[]byte)) 34 | *ptrBuf = (*ptrBuf)[:size] 35 | 36 | return ptrBuf 37 | } 38 | 39 | func getPoolIdx(size int) int { 40 | size-- 41 | size >>= minPoolExpOf2 42 | i := 0 43 | for size > 0 { 44 | size >>= 1 45 | i++ 46 | } 47 | 48 | return i 49 | } 50 | 51 | // Put returns buf to the pool. 52 | func Put(buf *[]byte) { 53 | i := putPoolIdx(cap(*buf)) 54 | if i < 0 { 55 | return 56 | } 57 | 58 | pools[i].Put(buf) 59 | } 60 | 61 | func putPoolIdx(size int) int { 62 | minPoolSize := 1 << minPoolExpOf2 63 | for i := range pools { 64 | if size == minPoolSize<= len(bq.queue) { 20 | bq.growQueue() 21 | } 22 | bq.queue[bq.w] = buf 23 | bq.w++ 24 | } 25 | 26 | func (bq *bufferQueue) pushFront(buf *[]byte) { 27 | bq.lock.Lock() 28 | defer bq.lock.Unlock() 29 | 30 | if bq.w >= len(bq.queue) { 31 | bq.growQueue() 32 | } 33 | copy(bq.queue[bq.r+1:bq.w+1], bq.queue[bq.r:bq.w]) 34 | bq.queue[bq.r] = buf 35 | bq.w++ 36 | } 37 | 38 | func (bq *bufferQueue) popFront() *[]byte { 39 | bq.lock.Lock() 40 | defer bq.lock.Unlock() 41 | 42 | if bq.r == bq.w { 43 | return nil 44 | } 45 | 46 | buf := bq.queue[bq.r] 47 | bq.queue[bq.r] = nil // Clear reference so it can be garbage collected. 
48 | bq.r++ 49 | 50 | if bq.r == bq.w { 51 | bq.r = 0 52 | bq.w = 0 53 | if len(bq.queue) > minBufferQueueLen { 54 | bq.queue = make([]*[]byte, minBufferQueueLen) 55 | } 56 | } 57 | 58 | return buf 59 | } 60 | 61 | func (bq *bufferQueue) growQueue() { 62 | desiredLen := (len(bq.queue) + 1) * 3 / 2 63 | if desiredLen < minBufferQueueLen { 64 | desiredLen = minBufferQueueLen 65 | } 66 | 67 | newQueue := make([]*[]byte, desiredLen) 68 | copy(newQueue, bq.queue) 69 | bq.queue = newQueue 70 | } 71 | -------------------------------------------------------------------------------- /internal/nbconn/nbconn_fake_non_block.go: -------------------------------------------------------------------------------- 1 | //go:build !unix 2 | 3 | package nbconn 4 | 5 | func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { 6 | return c.fakeNonblockingWrite(b) 7 | } 8 | 9 | func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { 10 | return c.fakeNonblockingRead(b) 11 | } 12 | -------------------------------------------------------------------------------- /internal/nbconn/nbconn_real_non_block.go: -------------------------------------------------------------------------------- 1 | //go:build unix 2 | 3 | package nbconn 4 | 5 | import ( 6 | "errors" 7 | "io" 8 | "syscall" 9 | ) 10 | 11 | // realNonblockingWrite does a non-blocking write. readFlushLock must already be held. 12 | func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { 13 | if c.nonblockWriteFunc == nil { 14 | c.nonblockWriteFunc = func(fd uintptr) (done bool) { 15 | c.nonblockWriteN, c.nonblockWriteErr = syscall.Write(int(fd), c.nonblockWriteBuf) 16 | return true 17 | } 18 | } 19 | c.nonblockWriteBuf = b 20 | c.nonblockWriteN = 0 21 | c.nonblockWriteErr = nil 22 | 23 | err = c.rawConn.Write(c.nonblockWriteFunc) 24 | n = c.nonblockWriteN 25 | c.nonblockWriteBuf = nil // ensure that no reference to b is kept. 
26 | if err == nil && c.nonblockWriteErr != nil { 27 | if errors.Is(c.nonblockWriteErr, syscall.EWOULDBLOCK) { 28 | err = ErrWouldBlock 29 | } else { 30 | err = c.nonblockWriteErr 31 | } 32 | } 33 | if err != nil { 34 | // n may be -1 when an error occurs. 35 | if n < 0 { 36 | n = 0 37 | } 38 | 39 | return n, err 40 | } 41 | 42 | return n, nil 43 | } 44 | 45 | func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { 46 | if c.nonblockReadFunc == nil { 47 | c.nonblockReadFunc = func(fd uintptr) (done bool) { 48 | c.nonblockReadN, c.nonblockReadErr = syscall.Read(int(fd), c.nonblockReadBuf) 49 | return true 50 | } 51 | } 52 | c.nonblockReadBuf = b 53 | c.nonblockReadN = 0 54 | c.nonblockReadErr = nil 55 | 56 | err = c.rawConn.Read(c.nonblockReadFunc) 57 | n = c.nonblockReadN 58 | c.nonblockReadBuf = nil // ensure that no reference to b is kept. 59 | if err == nil && c.nonblockReadErr != nil { 60 | if errors.Is(c.nonblockReadErr, syscall.EWOULDBLOCK) { 61 | err = ErrWouldBlock 62 | } else { 63 | err = c.nonblockReadErr 64 | } 65 | } 66 | if err != nil { 67 | // n may be -1 when an error occurs. 68 | if n < 0 { 69 | n = 0 70 | } 71 | 72 | return n, err 73 | } 74 | 75 | // syscall read did not return an error and 0 bytes were read means EOF. 76 | if n == 0 { 77 | return 0, io.EOF 78 | } 79 | 80 | return n, nil 81 | } 82 | -------------------------------------------------------------------------------- /internal/pgio/README.md: -------------------------------------------------------------------------------- 1 | # pgio 2 | 3 | Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol. 4 | 5 | pgio provides functions for appending integers to a []byte while doing byte 6 | order conversion. 
7 | -------------------------------------------------------------------------------- /internal/pgio/doc.go: -------------------------------------------------------------------------------- 1 | // Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol. 2 | /* 3 | pgio provides functions for appending integers to a []byte while doing byte 4 | order conversion. 5 | */ 6 | package pgio 7 | -------------------------------------------------------------------------------- /internal/pgio/write.go: -------------------------------------------------------------------------------- 1 | package pgio 2 | 3 | import "encoding/binary" 4 | 5 | func AppendUint16(buf []byte, n uint16) []byte { 6 | wp := len(buf) 7 | buf = append(buf, 0, 0) 8 | binary.BigEndian.PutUint16(buf[wp:], n) 9 | return buf 10 | } 11 | 12 | func AppendUint32(buf []byte, n uint32) []byte { 13 | wp := len(buf) 14 | buf = append(buf, 0, 0, 0, 0) 15 | binary.BigEndian.PutUint32(buf[wp:], n) 16 | return buf 17 | } 18 | 19 | func AppendUint64(buf []byte, n uint64) []byte { 20 | wp := len(buf) 21 | buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0) 22 | binary.BigEndian.PutUint64(buf[wp:], n) 23 | return buf 24 | } 25 | 26 | func AppendInt16(buf []byte, n int16) []byte { 27 | return AppendUint16(buf, uint16(n)) 28 | } 29 | 30 | func AppendInt32(buf []byte, n int32) []byte { 31 | return AppendUint32(buf, uint32(n)) 32 | } 33 | 34 | func AppendInt64(buf []byte, n int64) []byte { 35 | return AppendUint64(buf, uint64(n)) 36 | } 37 | 38 | func SetInt32(buf []byte, n int32) { // overwrites buf[0:4] with n in big-endian order; panics if len(buf) < 4 39 | binary.BigEndian.PutUint32(buf, uint32(n)) 40 | } 41 | -------------------------------------------------------------------------------- /internal/pgio/write_test.go: -------------------------------------------------------------------------------- 1 | package pgio 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | func TestAppendUint16NilBuf(t *testing.T) { 9 | buf := AppendUint16(nil, 1) 10 | if
!reflect.DeepEqual(buf, []byte{0, 1}) { 11 | t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1}) 12 | } 13 | } 14 | 15 | func TestAppendUint16EmptyBuf(t *testing.T) { 16 | buf := []byte{} 17 | buf = AppendUint16(buf, 1) 18 | if !reflect.DeepEqual(buf, []byte{0, 1}) { 19 | t.Errorf("AppendUint16([]byte{}, 1) => %v, want %v", buf, []byte{0, 1}) 20 | } 21 | } 22 | 23 | func TestAppendUint16BufWithCapacityDoesNotAllocate(t *testing.T) { 24 | buf := make([]byte, 0, 4) 25 | AppendUint16(buf, 1) 26 | buf = buf[0:2] 27 | if !reflect.DeepEqual(buf, []byte{0, 1}) { 28 | t.Errorf("AppendUint16(make([]byte, 0, 4), 1) => %v, want %v", buf, []byte{0, 1}) 29 | } 30 | } 31 | 32 | func TestAppendUint32NilBuf(t *testing.T) { 33 | buf := AppendUint32(nil, 1) 34 | if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) { 35 | t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1}) 36 | } 37 | } 38 | 39 | func TestAppendUint32EmptyBuf(t *testing.T) { 40 | buf := []byte{} 41 | buf = AppendUint32(buf, 1) 42 | if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) { 43 | t.Errorf("AppendUint32([]byte{}, 1) => %v, want %v", buf, []byte{0, 0, 0, 1}) 44 | } 45 | } 46 | 47 | func TestAppendUint32BufWithCapacityDoesNotAllocate(t *testing.T) { 48 | buf := make([]byte, 0, 4) 49 | AppendUint32(buf, 1) 50 | buf = buf[0:4] 51 | if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) { 52 | t.Errorf("AppendUint32(make([]byte, 0, 4), 1) => %v, want %v", buf, []byte{0, 0, 0, 1}) 53 | } 54 | } 55 | 56 | func TestAppendUint64NilBuf(t *testing.T) { 57 | buf := AppendUint64(nil, 1) 58 | if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) { 59 | t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) 60 | } 61 | } 62 | 63 | func TestAppendUint64EmptyBuf(t *testing.T) { 64 | buf := []byte{} 65 | buf = AppendUint64(buf, 1) 66 | if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) { 67 | t.Errorf("AppendUint64([]byte{}, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) 68 | } 69 | } 70 | 71 | func
TestAppendUint64BufWithCapacityDoesNotAllocate(t *testing.T) { 72 | buf := make([]byte, 0, 8) 73 | AppendUint64(buf, 1) 74 | buf = buf[0:8] 75 | if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) { 76 | t.Errorf("AppendUint64(make([]byte, 0, 8), 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /internal/pgmock/pgmock_test.go: -------------------------------------------------------------------------------- 1 | package pgmock_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net" 7 | "strings" 8 | "testing" 9 | "time" 10 | 11 | "github.com/jackc/pgx/v5/internal/pgmock" 12 | "github.com/jackc/pgx/v5/pgconn" 13 | "github.com/jackc/pgx/v5/pgproto3" 14 | 15 | "github.com/stretchr/testify/assert" 16 | "github.com/stretchr/testify/require" 17 | ) 18 | 19 | func TestScript(t *testing.T) { 20 | script := &pgmock.Script{ 21 | Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), 22 | } 23 | script.Steps = append(script.Steps, pgmock.ExpectMessage(&pgproto3.Query{String: "select 42"})) 24 | script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.RowDescription{ 25 | Fields: []pgproto3.FieldDescription{ 26 | { 27 | Name: []byte("?column?"), 28 | TableOID: 0, 29 | TableAttributeNumber: 0, 30 | DataTypeOID: 23, 31 | DataTypeSize: 4, 32 | TypeModifier: -1, 33 | Format: 0, 34 | }, 35 | }, 36 | })) 37 | script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.DataRow{ 38 | Values: [][]byte{[]byte("42")}, 39 | })) 40 | script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")})) 41 | script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'})) 42 | script.Steps = append(script.Steps, pgmock.ExpectMessage(&pgproto3.Terminate{})) 43 | 44 | ln, err := net.Listen("tcp", "127.0.0.1:") 45 | require.NoError(t, err) 46 | defer ln.Close() 47 | 48 | serverErrChan := make(chan error, 1) 49
| go func() { 50 | defer close(serverErrChan) 51 | 52 | conn, err := ln.Accept() 53 | if err != nil { 54 | serverErrChan <- err 55 | return 56 | } 57 | defer conn.Close() 58 | 59 | err = conn.SetDeadline(time.Now().Add(time.Second)) 60 | if err != nil { 61 | serverErrChan <- err 62 | return 63 | } 64 | 65 | err = script.Run(pgproto3.NewBackend(conn, conn)) 66 | if err != nil { 67 | serverErrChan <- err 68 | return 69 | } 70 | }() 71 | 72 | parts := strings.Split(ln.Addr().String(), ":") 73 | host := parts[0] 74 | port := parts[1] 75 | connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port) 76 | 77 | ctx, cancel := context.WithTimeout(context.Background(), time.Second) 78 | defer cancel() 79 | pgConn, err := pgconn.Connect(ctx, connStr) 80 | require.NoError(t, err) 81 | results, err := pgConn.Exec(ctx, "select 42").ReadAll() 82 | assert.NoError(t, err) 83 | 84 | assert.Len(t, results, 1) 85 | assert.Nil(t, results[0].Err) 86 | assert.Equal(t, "SELECT 1", results[0].CommandTag.String()) 87 | assert.Len(t, results[0].Rows, 1) 88 | assert.Equal(t, "42", string(results[0].Rows[0][0])) 89 | 90 | pgConn.Close(ctx) 91 | 92 | assert.NoError(t, <-serverErrChan) 93 | } 94 | -------------------------------------------------------------------------------- /internal/stmtcache/lru_cache.go: -------------------------------------------------------------------------------- 1 | package stmtcache 2 | 3 | import ( 4 | "container/list" 5 | 6 | "github.com/jackc/pgx/v5/pgconn" 7 | ) 8 | 9 | // LRUCache implements Cache with a Least Recently Used (LRU) cache. 10 | type LRUCache struct { 11 | cap int 12 | m map[string]*list.Element 13 | l *list.List 14 | invalidStmts []*pgconn.StatementDescription 15 | } 16 | 17 | // NewLRUCache creates a new LRUCache. cap is the maximum size of the cache.
18 | func NewLRUCache(cap int) *LRUCache { 19 | return &LRUCache{ 20 | cap: cap, 21 | m: make(map[string]*list.Element), 22 | l: list.New(), 23 | } 24 | } 25 | 26 | // Get returns the statement description for key (the SQL text) and marks it as the most recently used entry. Returns nil if not found. 27 | func (c *LRUCache) Get(key string) *pgconn.StatementDescription { 28 | if el, ok := c.m[key]; ok { 29 | c.l.MoveToFront(el) 30 | return el.Value.(*pgconn.StatementDescription) 31 | } 32 | 33 | return nil 34 | 35 | } 36 | 37 | // Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache. When the cache already holds cap entries, the least recently used one is invalidated first. 38 | func (c *LRUCache) Put(sd *pgconn.StatementDescription) { 39 | if sd.SQL == "" { 40 | panic("cannot store statement description with empty SQL") 41 | } 42 | 43 | if _, present := c.m[sd.SQL]; present { 44 | return 45 | } 46 | 47 | if c.l.Len() == c.cap { // NOTE(review): with cap == 0 this evicts from an empty list and panics; with cap < 0 nothing is ever evicted 48 | c.invalidateOldest() 49 | } 50 | 51 | el := c.l.PushFront(sd) 52 | c.m[sd.SQL] = el 53 | } 54 | 55 | // Invalidate invalidates statement description identified by sql. Does nothing if not found. 56 | func (c *LRUCache) Invalidate(sql string) { 57 | if el, ok := c.m[sql]; ok { 58 | delete(c.m, sql) 59 | c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription)) 60 | c.l.Remove(el) 61 | } 62 | } 63 | 64 | // InvalidateAll invalidates all statement descriptions, queuing them (most recently used first) for HandleInvalidated. 65 | func (c *LRUCache) InvalidateAll() { 66 | el := c.l.Front() 67 | for el != nil { 68 | c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription)) 69 | el = el.Next() 70 | } 71 | 72 | c.m = make(map[string]*list.Element) 73 | c.l = list.New() 74 | } 75 | 76 | func (c *LRUCache) HandleInvalidated() []*pgconn.StatementDescription { // returns the queued invalidated descriptions and clears the queue 77 | invalidStmts := c.invalidStmts 78 | c.invalidStmts = nil 79 | return invalidStmts 80 | } 81 | 82 | // Len returns the number of cached prepared statement descriptions.
83 | func (c *LRUCache) Len() int { 84 | return c.l.Len() 85 | } 86 | 87 | // Cap returns the maximum number of cached prepared statement descriptions. 88 | func (c *LRUCache) Cap() int { 89 | return c.cap 90 | } 91 | 92 | func (c *LRUCache) invalidateOldest() { 93 | oldest := c.l.Back() // the back of the list is the least recently used entry; assumes the list is non-empty 94 | sd := oldest.Value.(*pgconn.StatementDescription) 95 | c.invalidStmts = append(c.invalidStmts, sd) 96 | delete(c.m, sd.SQL) 97 | c.l.Remove(oldest) 98 | } 99 | -------------------------------------------------------------------------------- /internal/stmtcache/stmtcache.go: -------------------------------------------------------------------------------- 1 | // Package stmtcache is a cache for statement descriptions. 2 | package stmtcache 3 | 4 | import ( 5 | "strconv" 6 | "sync/atomic" 7 | 8 | "github.com/jackc/pgx/v5/pgconn" 9 | ) 10 | 11 | var stmtCounter int64 // incremented atomically by NextStatementName 12 | 13 | // NextStatementName returns a statement name that will be unique for the lifetime of the program. 14 | func NextStatementName() string { 15 | n := atomic.AddInt64(&stmtCounter, 1) 16 | return "stmtcache_" + strconv.FormatInt(n, 10) 17 | } 18 | 19 | // Cache caches statement descriptions. 20 | type Cache interface { 21 | // Get returns the statement description for sql. Returns nil if not found. 22 | Get(sql string) *pgconn.StatementDescription 23 | 24 | // Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache. 25 | Put(sd *pgconn.StatementDescription) 26 | 27 | // Invalidate invalidates statement description identified by sql. Does nothing if not found. 28 | Invalidate(sql string) 29 | 30 | // InvalidateAll invalidates all statement descriptions. 31 | InvalidateAll() 32 | 33 | // HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated. 34 | HandleInvalidated() []*pgconn.StatementDescription 35 | 36 | // Len returns the number of cached prepared statement descriptions.
37 | Len() int 38 | 39 | // Cap returns the maximum number of cached prepared statement descriptions. 40 | Cap() int 41 | } 42 | 43 | func IsStatementInvalid(err error) bool { // reports whether err might mean a cached prepared statement is no longer valid 44 | pgErr, ok := err.(*pgconn.PgError) // NOTE(review): a plain type assertion does not see a *PgError wrapped by another error; errors.As would 45 | if !ok { 46 | return false 47 | } 48 | 49 | // https://github.com/jackc/pgx/issues/1162 50 | // 51 | // We used to look for the message "cached plan must not change result type". However, that message can be localized. 52 | // Unfortunately, error code "0A000" - "FEATURE NOT SUPPORTED" is used for many different errors and the only way to 53 | // tell the difference is by the message. But all that happens is we clear a statement that we otherwise wouldn't 54 | // have, so it should be safe. 55 | possibleInvalidCachedPlanError := pgErr.Code == "0A000" 56 | return possibleInvalidCachedPlanError 57 | } 58 | -------------------------------------------------------------------------------- /internal/stmtcache/unlimited_cache.go: -------------------------------------------------------------------------------- 1 | package stmtcache 2 | 3 | import ( 4 | "math" 5 | 6 | "github.com/jackc/pgx/v5/pgconn" 7 | ) 8 | 9 | // UnlimitedCache implements Cache with no capacity limit. 10 | type UnlimitedCache struct { 11 | m map[string]*pgconn.StatementDescription 12 | invalidStmts []*pgconn.StatementDescription 13 | } 14 | 15 | // NewUnlimitedCache creates a new UnlimitedCache. 16 | func NewUnlimitedCache() *UnlimitedCache { 17 | return &UnlimitedCache{ 18 | m: make(map[string]*pgconn.StatementDescription), 19 | } 20 | } 21 | 22 | // Get returns the statement description for sql. Returns nil if not found. 23 | func (c *UnlimitedCache) Get(sql string) *pgconn.StatementDescription { 24 | return c.m[sql] 25 | } 26 | 27 | // Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
28 | func (c *UnlimitedCache) Put(sd *pgconn.StatementDescription) { 29 | if sd.SQL == "" { 30 | panic("cannot store statement description with empty SQL") 31 | } 32 | 33 | if _, present := c.m[sd.SQL]; present { 34 | return 35 | } 36 | 37 | c.m[sd.SQL] = sd 38 | } 39 | 40 | // Invalidate invalidates statement description identified by sql. Does nothing if not found. 41 | func (c *UnlimitedCache) Invalidate(sql string) { 42 | if sd, ok := c.m[sql]; ok { 43 | delete(c.m, sql) 44 | c.invalidStmts = append(c.invalidStmts, sd) 45 | } 46 | } 47 | 48 | // InvalidateAll invalidates all statement descriptions, queuing them for HandleInvalidated in unspecified (map iteration) order. 49 | func (c *UnlimitedCache) InvalidateAll() { 50 | for _, sd := range c.m { 51 | c.invalidStmts = append(c.invalidStmts, sd) 52 | } 53 | 54 | c.m = make(map[string]*pgconn.StatementDescription) 55 | } 56 | 57 | func (c *UnlimitedCache) HandleInvalidated() []*pgconn.StatementDescription { // returns the queued invalidated descriptions and clears the queue 58 | invalidStmts := c.invalidStmts 59 | c.invalidStmts = nil 60 | return invalidStmts 61 | } 62 | 63 | // Len returns the number of cached prepared statement descriptions. 64 | func (c *UnlimitedCache) Len() int { 65 | return len(c.m) 66 | } 67 | 68 | // Cap returns the maximum number of cached prepared statement descriptions. UnlimitedCache has no limit, so this is always math.MaxInt. 69 | func (c *UnlimitedCache) Cap() int { 70 | return math.MaxInt 71 | } 72 | -------------------------------------------------------------------------------- /log/testingadapter/adapter.go: -------------------------------------------------------------------------------- 1 | // Package testingadapter provides a logger that writes to a test or benchmark 2 | // log. 3 | package testingadapter 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | 9 | "github.com/jackc/pgx/v5/tracelog" 10 | ) 11 | 12 | // TestingLogger interface defines the subset of testing.TB methods used by this 13 | // adapter.
14 | type TestingLogger interface { 15 | Log(args ...any) 16 | } 17 | 18 | type Logger struct { // Logger forwards tracelog log calls to a TestingLogger. 19 | l TestingLogger 20 | } 21 | 22 | func NewLogger(l TestingLogger) *Logger { 23 | return &Logger{l: l} 24 | } 25 | 26 | func (l *Logger) Log(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]any) { 27 | logArgs := make([]any, 0, 2+len(data)) // room for level, msg, and one "k=v" string per data entry 28 | logArgs = append(logArgs, level, msg) 29 | for k, v := range data { // map iteration: keys appear in no particular order 30 | logArgs = append(logArgs, fmt.Sprintf("%s=%v", k, v)) 31 | } 32 | l.l.Log(logArgs...) 33 | } 34 | -------------------------------------------------------------------------------- /pgbouncer_test.go: -------------------------------------------------------------------------------- 1 | package pgx_test 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "testing" 7 | 8 | "github.com/jackc/pgx/v5" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func TestPgbouncerStatementCacheDescribe(t *testing.T) { 14 | connString := os.Getenv("PGX_TEST_PGBOUNCER_CONN_STRING") 15 | if connString == "" { 16 | t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_PGBOUNCER_CONN_STRING") 17 | } 18 | 19 | config := mustParseConfig(t, connString) 20 | config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe 21 | config.DescriptionCacheCapacity = 1024 22 | 23 | testPgbouncer(t, config, 10, 100) 24 | } 25 | 26 | func TestPgbouncerSimpleProtocol(t *testing.T) { 27 | connString := os.Getenv("PGX_TEST_PGBOUNCER_CONN_STRING") 28 | if connString == "" { 29 | t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_PGBOUNCER_CONN_STRING") 30 | } 31 | 32 | config := mustParseConfig(t, connString) 33 | config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol 34 | 35 | testPgbouncer(t, config, 10, 100) 36 | } 37 | 38 | func testPgbouncer(t *testing.T, config *pgx.ConnConfig, workers, iterations int) { 39 | doneChan := make(chan struct{}) 40 | 41 | for i := 0; i < workers; i++ { 42 | go
func() { 43 | defer func() { doneChan <- struct{}{} }() // signal completion on every exit path of the worker 44 | conn, err := pgx.ConnectConfig(context.Background(), config) 45 | require.Nil(t, err) // NOTE(review): require calls t.FailNow, which the testing package only permits from the test goroutine; consider assert + return here 46 | defer closeConn(t, conn) 47 | 48 | for i := 0; i < iterations; i++ { 49 | var i32 int32 50 | var i64 int64 51 | var f32 float32 52 | var s string 53 | var s2 string 54 | err = conn.QueryRow(context.Background(), "select 1::int4, 2::int8, 3::float4, 'hi'::text").Scan(&i32, &i64, &f32, &s) 55 | require.NoError(t, err) 56 | assert.Equal(t, int32(1), i32) 57 | assert.Equal(t, int64(2), i64) 58 | assert.Equal(t, float32(3), f32) 59 | assert.Equal(t, "hi", s) 60 | 61 | err = conn.QueryRow(context.Background(), "select 1::int8, 2::float4, 'bye'::text, 4::int4, 'whatever'::text").Scan(&i64, &f32, &s, &i32, &s2) 62 | require.NoError(t, err) 63 | assert.Equal(t, int64(1), i64) 64 | assert.Equal(t, float32(2), f32) 65 | assert.Equal(t, "bye", s) 66 | assert.Equal(t, int32(4), i32) 67 | assert.Equal(t, "whatever", s2) 68 | } 69 | }() 70 | } 71 | 72 | for i := 0; i < workers; i++ { 73 | <-doneChan 74 | } 75 | 76 | } 77 | -------------------------------------------------------------------------------- /pgconn/README.md: -------------------------------------------------------------------------------- 1 | # pgconn 2 | 3 | Package pgconn is a low-level PostgreSQL database driver. It operates at nearly the same level as the C library libpq. 4 | It is primarily intended to serve as the foundation for higher level libraries such as https://github.com/jackc/pgx. 5 | Applications should handle normal queries with a higher level library and only use pgconn directly when required for 6 | low-level access to PostgreSQL functionality.
7 | 8 | ## Example Usage 9 | 10 | ```go 11 | pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL")) 12 | if err != nil { 13 | log.Fatalln("pgconn failed to connect:", err) 14 | } 15 | defer pgConn.Close(context.Background()) 16 | 17 | result := pgConn.ExecParams(context.Background(), "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil) 18 | for result.NextRow() { 19 | fmt.Println("User 123 has email:", string(result.Values()[0])) 20 | } 21 | _, err = result.Close() 22 | if err != nil { 23 | log.Fatalln("failed reading result:", err) 24 | } 25 | ``` 26 | 27 | ## Testing 28 | 29 | See CONTRIBUTING.md for setup instructions. 30 | -------------------------------------------------------------------------------- /pgconn/benchmark_private_test.go: -------------------------------------------------------------------------------- 1 | package pgconn 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | ) 7 | 8 | func BenchmarkCommandTagRowsAffected(b *testing.B) { 9 | benchmarks := []struct { 10 | commandTag string 11 | rowsAffected int64 12 | }{ 13 | {"UPDATE 1", 1}, 14 | {"UPDATE 123456789", 123456789}, 15 | {"INSERT 0 1", 1}, 16 | {"INSERT 0 123456789", 123456789}, 17 | } 18 | 19 | for _, bm := range benchmarks { 20 | ct := CommandTag{s: bm.commandTag} 21 | b.Run(bm.commandTag, func(b *testing.B) { 22 | var n int64 23 | for i := 0; i < b.N; i++ { 24 | n = ct.RowsAffected() 25 | } 26 | if n != bm.rowsAffected { // checked once, after the timed loop 27 | b.Errorf("expected %d got %d", bm.rowsAffected, n) 28 | } 29 | }) 30 | } 31 | } 32 | 33 | func BenchmarkCommandTagTypeFromString(b *testing.B) { 34 | ct := CommandTag{s: "UPDATE 1"} 35 | 36 | var update bool 37 | for i := 0; i < b.N; i++ { 38 | update = strings.HasPrefix(ct.String(), "UPDATE") 39 | } 40 | if !update { 41 | b.Error("expected update") 42 | } 43 | } 44 | 45 | func BenchmarkCommandTagInsert(b *testing.B) { 46 | benchmarks := []struct { 47 | commandTag string 48 | is bool 49 | }{ 50 | {"INSERT 1", true},
51 | {"INSERT 1234567890", true}, 52 | {"UPDATE 1", false}, 53 | {"UPDATE 1234567890", false}, 54 | {"DELETE 1", false}, 55 | {"DELETE 1234567890", false}, 56 | {"SELECT 1", false}, 57 | {"SELECT 1234567890", false}, 58 | {"UNKNOWN 1234567890", false}, 59 | } 60 | 61 | for _, bm := range benchmarks { 62 | ct := CommandTag{s: bm.commandTag} 63 | b.Run(bm.commandTag, func(b *testing.B) { 64 | var is bool 65 | for i := 0; i < b.N; i++ { 66 | is = ct.Insert() 67 | } 68 | if is != bm.is { 69 | b.Errorf("expected %v got %v", bm.is, is) 70 | } 71 | }) 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /pgconn/defaults.go: -------------------------------------------------------------------------------- 1 | //go:build !windows 2 | // +build !windows 3 | 4 | package pgconn 5 | 6 | import ( 7 | "os" 8 | "os/user" 9 | "path/filepath" 10 | ) 11 | 12 | func defaultSettings() map[string]string { // default connection settings for non-Windows platforms (see build constraint) 13 | settings := make(map[string]string) 14 | 15 | settings["host"] = defaultHost() 16 | settings["port"] = "5432" 17 | 18 | // Default to the OS user name. Purposely ignoring err getting user name from 19 | // OS. The client application will simply have to specify the user in that 20 | // case (which they typically will be doing anyway).
21 | user, err := user.Current() // shadows the os/user package for the rest of this function 22 | if err == nil { 23 | settings["user"] = user.Username 24 | settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass") 25 | settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf") 26 | sslcert := filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt") 27 | sslkey := filepath.Join(user.HomeDir, ".postgresql", "postgresql.key") 28 | if _, err := os.Stat(sslcert); err == nil { 29 | if _, err := os.Stat(sslkey); err == nil { 30 | // Both the cert and key must be present to use them, or do not use either 31 | settings["sslcert"] = sslcert 32 | settings["sslkey"] = sslkey 33 | } 34 | } 35 | sslrootcert := filepath.Join(user.HomeDir, ".postgresql", "root.crt") 36 | if _, err := os.Stat(sslrootcert); err == nil { 37 | settings["sslrootcert"] = sslrootcert 38 | } 39 | } 40 | 41 | settings["target_session_attrs"] = "any" 42 | 43 | return settings 44 | } 45 | 46 | // defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost 47 | // on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it 48 | // checks the existence of common locations.
49 | func defaultHost() string { 50 | candidatePaths := []string{ 51 | "/var/run/postgresql", // Debian 52 | "/private/tmp", // OSX - homebrew 53 | "/tmp", // standard PostgreSQL 54 | } 55 | 56 | for _, path := range candidatePaths { 57 | if _, err := os.Stat(path); err == nil { 58 | return path 59 | } 60 | } 61 | 62 | return "localhost" 63 | } 64 | -------------------------------------------------------------------------------- /pgconn/defaults_windows.go: -------------------------------------------------------------------------------- 1 | package pgconn 2 | 3 | import ( 4 | "os" 5 | "os/user" 6 | "path/filepath" 7 | "strings" 8 | ) 9 | 10 | func defaultSettings() map[string]string { 11 | settings := make(map[string]string) 12 | 13 | settings["host"] = defaultHost() 14 | settings["port"] = "5432" 15 | 16 | // Default to the OS user name. Purposely ignoring err getting user name from 17 | // OS. The client application will simply have to specify the user in that 18 | // case (which they typically will be doing anyway). 19 | user, err := user.Current() // shadows the os/user package for the rest of this function 20 | appData := os.Getenv("APPDATA") 21 | if err == nil { 22 | // Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`, 23 | // but the libpq default is just the `user` portion, so we strip off the first part.
24 | username := user.Username 25 | if strings.Contains(username, "\\") { 26 | username = username[strings.LastIndex(username, "\\")+1:] 27 | } 28 | 29 | settings["user"] = username 30 | settings["passfile"] = filepath.Join(appData, "postgresql", "pgpass.conf") 31 | settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf") 32 | sslcert := filepath.Join(appData, "postgresql", "postgresql.crt") 33 | sslkey := filepath.Join(appData, "postgresql", "postgresql.key") 34 | if _, err := os.Stat(sslcert); err == nil { 35 | if _, err := os.Stat(sslkey); err == nil { 36 | // Both the cert and key must be present to use them, or do not use either 37 | settings["sslcert"] = sslcert 38 | settings["sslkey"] = sslkey 39 | } 40 | } 41 | sslrootcert := filepath.Join(appData, "postgresql", "root.crt") 42 | if _, err := os.Stat(sslrootcert); err == nil { 43 | settings["sslrootcert"] = sslrootcert 44 | } 45 | } 46 | 47 | settings["target_session_attrs"] = "any" 48 | 49 | return settings 50 | } 51 | 52 | // defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost 53 | // on Windows, so this Windows implementation always returns "localhost"; unlike the non-Windows version it does not 54 | // probe any socket directories. 55 | func defaultHost() string { 56 | return "localhost" 57 | } 58 | -------------------------------------------------------------------------------- /pgconn/doc.go: -------------------------------------------------------------------------------- 1 | // Package pgconn is a low-level PostgreSQL database driver. 2 | /* 3 | pgconn provides lower level access to a PostgreSQL connection than a database/sql or pgx connection. It operates at 4 | nearly the same level as the C library libpq. 5 | 6 | Establishing a Connection 7 | 8 | Use Connect to establish a connection. It accepts a connection string in URL or DSN format and will read the environment for 9 | libpq style environment variables.
10 | 11 | Executing a Query 12 | 13 | ExecParams and ExecPrepared execute a single query. They return readers that iterate over each row. The Read method 14 | reads all rows into memory. 15 | 16 | Executing Multiple Queries in a Single Round Trip 17 | 18 | Exec and ExecBatch can execute multiple queries in a single round trip. They return readers that iterate over each query 19 | result. The ReadAll method reads all query results into memory. 20 | 21 | Pipeline Mode 22 | 23 | Pipeline mode allows sending queries without having read the results of previously sent queries. It allows 24 | control of exactly how many and when network round trips occur. 25 | 26 | Context Support 27 | 28 | All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress, the 29 | method immediately returns. In most circumstances, this will close the underlying connection. 30 | 31 | The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the 32 | client to abort.
33 | */ 34 | package pgconn 35 | -------------------------------------------------------------------------------- /pgconn/errors_test.go: -------------------------------------------------------------------------------- 1 | package pgconn_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/jackc/pgx/v5/pgconn" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestConfigError(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | err error 14 | expectedMsg string 15 | }{ 16 | { 17 | name: "url with password", 18 | err: pgconn.NewParseConfigError("postgresql://foo:password@host", "msg", nil), 19 | expectedMsg: "cannot parse `postgresql://foo:xxxxx@host`: msg", 20 | }, 21 | { 22 | name: "dsn with password unquoted", 23 | err: pgconn.NewParseConfigError("host=host password=password user=user", "msg", nil), 24 | expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg", 25 | }, 26 | { 27 | name: "dsn with password quoted", 28 | err: pgconn.NewParseConfigError("host=host password='pass word' user=user", "msg", nil), 29 | expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg", 30 | }, 31 | { 32 | name: "weird url", 33 | err: pgconn.NewParseConfigError("postgresql://foo::pasword@host:1:", "msg", nil), 34 | expectedMsg: "cannot parse `postgresql://foo:xxxxx@host:1:`: msg", 35 | }, 36 | { 37 | name: "weird url with slash in password", 38 | err: pgconn.NewParseConfigError("postgres://user:pass/word@host:5432/db_name", "msg", nil), 39 | expectedMsg: "cannot parse `postgres://user:xxxxxx@host:5432/db_name`: msg", 40 | }, 41 | { 42 | name: "url without password", 43 | err: pgconn.NewParseConfigError("postgresql://other@host/db", "msg", nil), 44 | expectedMsg: "cannot parse `postgresql://other@host/db`: msg", 45 | }, 46 | } 47 | for _, tt := range tests { 48 | tt := tt // capture the range variable for the parallel subtest (required before Go 1.22) 49 | t.Run(tt.name, func(t *testing.T) { 50 | t.Parallel() 51 | assert.EqualError(t, tt.err, tt.expectedMsg) 52 | }) 53 | } 54 | } 55 |
-------------------------------------------------------------------------------- /pgconn/export_test.go: -------------------------------------------------------------------------------- 1 | // File export_test exposes unexported identifiers (parseConfigError) to the external pgconn_test package. 2 | 3 | package pgconn 4 | 5 | func NewParseConfigError(conn, msg string, err error) error { 6 | return &parseConfigError{ 7 | connString: conn, 8 | msg: msg, 9 | err: err, 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /pgconn/helper_test.go: -------------------------------------------------------------------------------- 1 | package pgconn_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/jackc/pgx/v5/pgconn" 9 | 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func closeConn(t testing.TB, conn *pgconn.PgConn) { // closes conn and fails t if cleanup does not finish within 5 seconds 15 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 16 | defer cancel() 17 | require.NoError(t, conn.Close(ctx)) 18 | select { 19 | case <-conn.CleanupDone(): 20 | case <-time.After(5 * time.Second): 21 | t.Fatal("Connection cleanup exceeded maximum time") 22 | } 23 | } 24 | 25 | // ensureConnValid runs a simple query to ensure the connection is still usable. 26 | func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) { 27 | ctx, cancel := context.WithTimeout(context.Background(), time.Second) 28 | result := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil).Read() 29 | cancel() 30 | 31 | require.Nil(t, result.Err) 32 | assert.Equal(t, 3, len(result.Rows)) 33 | assert.Equal(t, "1", string(result.Rows[0][0])) 34 | assert.Equal(t, "2", string(result.Rows[1][0])) 35 | assert.Equal(t, "3", string(result.Rows[2][0])) 36 | } 37 | -------------------------------------------------------------------------------- /pgconn/internal/ctxwatch/context_watcher.go:
-------------------------------------------------------------------------------- 1 | package ctxwatch 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | ) 7 | 8 | // ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a 9 | // time. 10 | type ContextWatcher struct { 11 | onCancel func() 12 | onUnwatchAfterCancel func() 13 | unwatchChan chan struct{} 14 | 15 | lock sync.Mutex 16 | watchInProgress bool 17 | onCancelWasCalled bool 18 | } 19 | 20 | // NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled. 21 | // onUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and 22 | // onCancel called. 23 | func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher { 24 | cw := &ContextWatcher{ 25 | onCancel: onCancel, 26 | onUnwatchAfterCancel: onUnwatchAfterCancel, 27 | unwatchChan: make(chan struct{}), 28 | } 29 | 30 | return cw 31 | } 32 | 33 | // Watch starts watching ctx. If ctx is canceled then the onCancel function passed to NewContextWatcher will be called. 34 | func (cw *ContextWatcher) Watch(ctx context.Context) { 35 | cw.lock.Lock() 36 | defer cw.lock.Unlock() 37 | 38 | if cw.watchInProgress { 39 | panic("Watch already in progress") 40 | } 41 | 42 | cw.onCancelWasCalled = false 43 | 44 | if ctx.Done() != nil { // a nil Done channel means ctx can never be canceled, so there is nothing to watch 45 | cw.watchInProgress = true 46 | go func() { 47 | select { 48 | case <-ctx.Done(): 49 | cw.onCancel() 50 | cw.onCancelWasCalled = true // written without cw.lock; Unwatch sees it because its send on unwatchChan completes only after the receive below 51 | <-cw.unwatchChan 52 | case <-cw.unwatchChan: 53 | } 54 | }() 55 | } else { 56 | cw.watchInProgress = false 57 | } 58 | } 59 | 60 | // Unwatch stops watching the previously watched context. If the onCancel function passed to NewContextWatcher was 61 | // called then onUnwatchAfterCancel will also be called.
62 | func (cw *ContextWatcher) Unwatch() { 63 | cw.lock.Lock() 64 | defer cw.lock.Unlock() 65 | 66 | if cw.watchInProgress { 67 | cw.unwatchChan <- struct{}{} // blocks until the watcher goroutine receives, i.e. until it has finished any onCancel call 68 | if cw.onCancelWasCalled { 69 | cw.onUnwatchAfterCancel() 70 | } 71 | cw.watchInProgress = false 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /pgconn/pgconn_private_test.go: -------------------------------------------------------------------------------- 1 | package pgconn 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestCommandTag(t *testing.T) { 10 | t.Parallel() 11 | 12 | var tests = []struct { 13 | commandTag CommandTag 14 | rowsAffected int64 15 | isInsert bool 16 | isUpdate bool 17 | isDelete bool 18 | isSelect bool 19 | }{ 20 | {commandTag: CommandTag{s: "INSERT 0 5"}, rowsAffected: 5, isInsert: true}, 21 | {commandTag: CommandTag{s: "UPDATE 0"}, rowsAffected: 0, isUpdate: true}, 22 | {commandTag: CommandTag{s: "UPDATE 1"}, rowsAffected: 1, isUpdate: true}, 23 | {commandTag: CommandTag{s: "DELETE 0"}, rowsAffected: 0, isDelete: true}, 24 | {commandTag: CommandTag{s: "DELETE 1"}, rowsAffected: 1, isDelete: true}, 25 | {commandTag: CommandTag{s: "DELETE 1234567890"}, rowsAffected: 1234567890, isDelete: true}, 26 | {commandTag: CommandTag{s: "SELECT 1"}, rowsAffected: 1, isSelect: true}, 27 | {commandTag: CommandTag{s: "SELECT 99999999999"}, rowsAffected: 99999999999, isSelect: true}, 28 | {commandTag: CommandTag{s: "CREATE TABLE"}, rowsAffected: 0}, 29 | {commandTag: CommandTag{s: "ALTER TABLE"}, rowsAffected: 0}, 30 | {commandTag: CommandTag{s: "DROP TABLE"}, rowsAffected: 0}, 31 | } 32 | 33 | for i, tt := range tests { 34 | ct := tt.commandTag 35 | assert.Equalf(t, tt.rowsAffected, ct.RowsAffected(), "%d. %v", i, tt.commandTag) 36 | assert.Equalf(t, tt.isInsert, ct.Insert(), "%d. %v", i, tt.commandTag) 37 | assert.Equalf(t, tt.isUpdate, ct.Update(), "%d.
%v", i, tt.commandTag) 38 | assert.Equalf(t, tt.isDelete, ct.Delete(), "%d. %v", i, tt.commandTag) 39 | assert.Equalf(t, tt.isSelect, ct.Select(), "%d. %v", i, tt.commandTag) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /pgproto3/README.md: -------------------------------------------------------------------------------- 1 | # pgproto3 2 | 3 | Package pgproto3 is an encoder and decoder of the PostgreSQL wire protocol version 3. 4 | 5 | pgproto3 can be used as a foundation for PostgreSQL drivers, proxies, mock servers, load balancers and more. 6 | 7 | See example/pgfortune for a playful example of a fake PostgreSQL server. 8 | -------------------------------------------------------------------------------- /pgproto3/authentication_cleartext_password.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "encoding/binary" 5 | "encoding/json" 6 | "errors" 7 | 8 | "github.com/jackc/pgx/v5/internal/pgio" 9 | ) 10 | 11 | // AuthenticationCleartextPassword is a message sent from the backend indicating that a clear-text password is required. 12 | type AuthenticationCleartextPassword struct { 13 | } 14 | 15 | // Backend identifies this message as sendable by the PostgreSQL backend. 16 | func (*AuthenticationCleartextPassword) Backend() {} 17 | 18 | // AuthenticationResponse identifies this message as an authentication response. 19 | func (*AuthenticationCleartextPassword) AuthenticationResponse() {} 20 | 21 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 22 | // type identifier and 4 byte message length.
23 | func (dst *AuthenticationCleartextPassword) Decode(src []byte) error { 24 | if len(src) != 4 { 25 | return errors.New("bad authentication message size") 26 | } 27 | 28 | authType := binary.BigEndian.Uint32(src) 29 | 30 | if authType != AuthTypeCleartextPassword { 31 | return errors.New("bad auth type") 32 | } 33 | 34 | return nil 35 | } 36 | 37 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 38 | func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte { 39 | dst = append(dst, 'R') 40 | dst = pgio.AppendInt32(dst, 8) 41 | dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword) 42 | return dst 43 | } 44 | 45 | // MarshalJSON implements encoding/json.Marshaler. 46 | func (src AuthenticationCleartextPassword) MarshalJSON() ([]byte, error) { 47 | return json.Marshal(struct { 48 | Type string 49 | }{ 50 | Type: "AuthenticationCleartextPassword", 51 | }) 52 | } 53 | -------------------------------------------------------------------------------- /pgproto3/authentication_gss.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "encoding/binary" 5 | "encoding/json" 6 | "errors" 7 | 8 | "github.com/jackc/pgx/v5/internal/pgio" 9 | ) 10 | 11 | type AuthenticationGSS struct{} // carries no payload beyond the 4-byte auth type 12 | 13 | func (a *AuthenticationGSS) Backend() {} 14 | 15 | func (a *AuthenticationGSS) AuthenticationResponse() {} 16 | 17 | func (a *AuthenticationGSS) Decode(src []byte) error { 18 | if len(src) < 4 { 19 | return errors.New("authentication message too short") 20 | } 21 | 22 | authType := binary.BigEndian.Uint32(src) 23 | 24 | if authType != AuthTypeGSS { 25 | return errors.New("bad auth type") 26 | } 27 | return nil 28 | } 29 | 30 | func (a *AuthenticationGSS) Encode(dst []byte) []byte { 31 | dst = append(dst, 'R') 32 | dst = pgio.AppendInt32(dst, 8) // the length counts itself (4 bytes) plus the 4-byte auth type; 4 would make Decode reject the message as too short 33 | dst = pgio.AppendUint32(dst, AuthTypeGSS) 34 | return dst 35 | } 36 | 37 | func (a
*AuthenticationGSS) MarshalJSON() ([]byte, error) { 38 | return json.Marshal(struct { 39 | Type string 40 | Data []byte // never assigned: AuthenticationGSS has no payload, so this always marshals as null 41 | }{ 42 | Type: "AuthenticationGSS", 43 | }) 44 | } 45 | 46 | func (a *AuthenticationGSS) UnmarshalJSON(data []byte) error { 47 | // Ignore null, like in the main JSON package. 48 | if string(data) == "null" { 49 | return nil 50 | } 51 | 52 | var msg struct { // decoded only to validate the JSON; no fields are copied into a 53 | Type string 54 | } 55 | if err := json.Unmarshal(data, &msg); err != nil { 56 | return err 57 | } 58 | return nil 59 | } 60 | -------------------------------------------------------------------------------- /pgproto3/authentication_gss_continue.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "encoding/binary" 5 | "encoding/json" 6 | "errors" 7 | 8 | "github.com/jackc/pgx/v5/internal/pgio" 9 | ) 10 | 11 | type AuthenticationGSSContinue struct { 12 | Data []byte 13 | } 14 | 15 | func (a *AuthenticationGSSContinue) Backend() {} 16 | 17 | func (a *AuthenticationGSSContinue) AuthenticationResponse() {} 18 | 19 | func (a *AuthenticationGSSContinue) Decode(src []byte) error { 20 | if len(src) < 4 { 21 | return errors.New("authentication message too short") 22 | } 23 | 24 | authType := binary.BigEndian.Uint32(src) 25 | 26 | if authType != AuthTypeGSSCont { 27 | return errors.New("bad auth type") 28 | } 29 | 30 | a.Data = src[4:] // aliases src; no copy is made 31 | return nil 32 | } 33 | 34 | func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte { 35 | dst = append(dst, 'R') 36 | dst = pgio.AppendInt32(dst, int32(len(a.Data))+8) // length (4) + auth type (4) + payload 37 | dst = pgio.AppendUint32(dst, AuthTypeGSSCont) 38 | dst = append(dst, a.Data...)
39 | return dst 40 | } 41 | 42 | func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) { 43 | return json.Marshal(struct { 44 | Type string 45 | Data []byte 46 | }{ 47 | Type: "AuthenticationGSSContinue", 48 | Data: a.Data, 49 | }) 50 | } 51 | 52 | func (a *AuthenticationGSSContinue) UnmarshalJSON(data []byte) error { 53 | // Ignore null, like in the main JSON package. 54 | if string(data) == "null" { 55 | return nil 56 | } 57 | 58 | var msg struct { 59 | Type string 60 | Data []byte 61 | } 62 | if err := json.Unmarshal(data, &msg); err != nil { 63 | return err 64 | } 65 | 66 | a.Data = msg.Data 67 | return nil 68 | } 69 | -------------------------------------------------------------------------------- /pgproto3/authentication_md5_password.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "encoding/binary" 5 | "encoding/json" 6 | "errors" 7 | 8 | "github.com/jackc/pgx/v5/internal/pgio" 9 | ) 10 | 11 | // AuthenticationMD5Password is a message sent from the backend indicating that an MD5 hashed password is required. 12 | type AuthenticationMD5Password struct { 13 | Salt [4]byte 14 | } 15 | 16 | // Backend identifies this message as sendable by the PostgreSQL backend. 17 | func (*AuthenticationMD5Password) Backend() {} 18 | 19 | // AuthenticationResponse identifies this message as an authentication response. 20 | func (*AuthenticationMD5Password) AuthenticationResponse() {} 21 | 22 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 23 | // type identifier and 4 byte message length.
24 | func (dst *AuthenticationMD5Password) Decode(src []byte) error { 25 | if len(src) != 8 { // 4-byte auth type followed by the 4-byte salt 26 | return errors.New("bad authentication message size") 27 | } 28 | 29 | authType := binary.BigEndian.Uint32(src) 30 | 31 | if authType != AuthTypeMD5Password { 32 | return errors.New("bad auth type") 33 | } 34 | 35 | copy(dst.Salt[:], src[4:8]) 36 | 37 | return nil 38 | } 39 | 40 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 41 | func (src *AuthenticationMD5Password) Encode(dst []byte) []byte { 42 | dst = append(dst, 'R') 43 | dst = pgio.AppendInt32(dst, 12) // length (4) + auth type (4) + salt (4) 44 | dst = pgio.AppendUint32(dst, AuthTypeMD5Password) 45 | dst = append(dst, src.Salt[:]...) 46 | return dst 47 | } 48 | 49 | // MarshalJSON implements encoding/json.Marshaler. 50 | func (src AuthenticationMD5Password) MarshalJSON() ([]byte, error) { 51 | return json.Marshal(struct { 52 | Type string 53 | Salt [4]byte 54 | }{ 55 | Type: "AuthenticationMD5Password", 56 | Salt: src.Salt, 57 | }) 58 | } 59 | 60 | // UnmarshalJSON implements encoding/json.Unmarshaler. 61 | func (dst *AuthenticationMD5Password) UnmarshalJSON(data []byte) error { 62 | // Ignore null, like in the main JSON package. 63 | if string(data) == "null" { 64 | return nil 65 | } 66 | 67 | var msg struct { 68 | Type string 69 | Salt [4]byte 70 | } 71 | if err := json.Unmarshal(data, &msg); err != nil { 72 | return err 73 | } 74 | 75 | dst.Salt = msg.Salt 76 | return nil 77 | } 78 | -------------------------------------------------------------------------------- /pgproto3/authentication_ok.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "encoding/binary" 5 | "encoding/json" 6 | "errors" 7 | 8 | "github.com/jackc/pgx/v5/internal/pgio" 9 | ) 10 | 11 | // AuthenticationOk is a message sent from the backend indicating that authentication was successful.
12 | type AuthenticationOk struct { 13 | } 14 | 15 | // Backend identifies this message as sendable by the PostgreSQL backend. 16 | func (*AuthenticationOk) Backend() {} 17 | 18 | // AuthenticationResponse identifies this message as an authentication response. 19 | func (*AuthenticationOk) AuthenticationResponse() {} 20 | 21 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 22 | // type identifier and 4 byte message length. 23 | func (dst *AuthenticationOk) Decode(src []byte) error { 24 | if len(src) != 4 { 25 | return errors.New("bad authentication message size") 26 | } 27 | 28 | authType := binary.BigEndian.Uint32(src) 29 | 30 | if authType != AuthTypeOk { 31 | return errors.New("bad auth type") 32 | } 33 | 34 | return nil 35 | } 36 | 37 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 38 | func (src *AuthenticationOk) Encode(dst []byte) []byte { 39 | dst = append(dst, 'R') 40 | dst = pgio.AppendInt32(dst, 8) // length (4) + auth type (4) 41 | dst = pgio.AppendUint32(dst, AuthTypeOk) 42 | return dst 43 | } 44 | 45 | // MarshalJSON implements encoding/json.Marshaler. 46 | func (src AuthenticationOk) MarshalJSON() ([]byte, error) { 47 | return json.Marshal(struct { 48 | Type string 49 | }{ 50 | Type: "AuthenticationOK", // NOTE(review): spelled "OK" unlike the Go type name AuthenticationOk; changing it would alter JSON output 51 | }) 52 | } 53 | -------------------------------------------------------------------------------- /pgproto3/authentication_sasl.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "encoding/json" 7 | "errors" 8 | 9 | "github.com/jackc/pgx/v5/internal/pgio" 10 | ) 11 | 12 | // AuthenticationSASL is a message sent from the backend indicating that SASL authentication is required. 13 | type AuthenticationSASL struct { 14 | AuthMechanisms []string 15 | } 16 | 17 | // Backend identifies this message as sendable by the PostgreSQL backend.
18 | func (*AuthenticationSASL) Backend() {} 19 | 20 | // AuthenticationResponse identifies this message as an authentication response. 21 | func (*AuthenticationSASL) AuthenticationResponse() {} 22 | 23 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 24 | // type identifier and 4 byte message length. 25 | func (dst *AuthenticationSASL) Decode(src []byte) error { 26 | if len(src) < 4 { 27 | return errors.New("authentication message too short") 28 | } 29 | 30 | authType := binary.BigEndian.Uint32(src) 31 | 32 | if authType != AuthTypeSASL { 33 | return errors.New("bad auth type") 34 | } 35 | dst.AuthMechanisms = nil // reset so decoding into a reused message does not keep mechanisms from an earlier one 36 | authMechanisms := src[4:] 37 | for len(authMechanisms) > 1 { // a single remaining byte is taken to be the list's terminating NUL (not verified) 38 | idx := bytes.IndexByte(authMechanisms, 0) 39 | if idx == -1 { 40 | return &invalidMessageFormatErr{messageType: "AuthenticationSASL", details: "unterminated string"} 41 | } 42 | dst.AuthMechanisms = append(dst.AuthMechanisms, string(authMechanisms[:idx])) 43 | authMechanisms = authMechanisms[idx+1:] 44 | } 45 | 46 | return nil 47 | } 48 | 49 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 50 | func (src *AuthenticationSASL) Encode(dst []byte) []byte { 51 | dst = append(dst, 'R') 52 | sp := len(dst) 53 | dst = pgio.AppendInt32(dst, -1) // placeholder; the real length is written by SetInt32 below 54 | dst = pgio.AppendUint32(dst, AuthTypeSASL) 55 | 56 | for _, s := range src.AuthMechanisms { 57 | dst = append(dst, []byte(s)...) 58 | dst = append(dst, 0) 59 | } 60 | dst = append(dst, 0) // extra NUL terminates the mechanism list 61 | 62 | pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) 63 | 64 | return dst 65 | } 66 | 67 | // MarshalJSON implements encoding/json.Marshaler.
68 | func (src AuthenticationSASL) MarshalJSON() ([]byte, error) { 69 | return json.Marshal(struct { 70 | Type string 71 | AuthMechanisms []string 72 | }{ 73 | Type: "AuthenticationSASL", 74 | AuthMechanisms: src.AuthMechanisms, 75 | }) 76 | } 77 | -------------------------------------------------------------------------------- /pgproto3/authentication_sasl_continue.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "encoding/binary" 5 | "encoding/json" 6 | "errors" 7 | 8 | "github.com/jackc/pgx/v5/internal/pgio" 9 | ) 10 | 11 | // AuthenticationSASLContinue is a message sent from the backend containing a SASL challenge. 12 | type AuthenticationSASLContinue struct { 13 | Data []byte 14 | } 15 | 16 | // Backend identifies this message as sendable by the PostgreSQL backend. 17 | func (*AuthenticationSASLContinue) Backend() {} 18 | 19 | // AuthenticationResponse identifies this message as an authentication response. 20 | func (*AuthenticationSASLContinue) AuthenticationResponse() {} 21 | 22 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 23 | // type identifier and 4 byte message length. 24 | func (dst *AuthenticationSASLContinue) Decode(src []byte) error { 25 | if len(src) < 4 { 26 | return errors.New("authentication message too short") 27 | } 28 | 29 | authType := binary.BigEndian.Uint32(src) 30 | 31 | if authType != AuthTypeSASLContinue { 32 | return errors.New("bad auth type") 33 | } 34 | 35 | dst.Data = src[4:] // aliases src; no copy is made 36 | 37 | return nil 38 | } 39 | 40 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 41 | func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte { 42 | dst = append(dst, 'R') 43 | sp := len(dst) 44 | dst = pgio.AppendInt32(dst, -1) // placeholder; the real length is written by SetInt32 below 45 | dst = pgio.AppendUint32(dst, AuthTypeSASLContinue) 46 | 47 | dst = append(dst, src.Data...)
48 | 49 | pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) 50 | 51 | return dst 52 | } 53 | 54 | // MarshalJSON implements encoding/json.Marshaler. 55 | func (src AuthenticationSASLContinue) MarshalJSON() ([]byte, error) { 56 | return json.Marshal(struct { 57 | Type string 58 | Data string 59 | }{ 60 | Type: "AuthenticationSASLContinue", 61 | Data: string(src.Data), // NOTE(review): json.Marshal replaces invalid UTF-8 with U+FFFD, so non-UTF-8 payloads do not round-trip 62 | }) 63 | } 64 | 65 | // UnmarshalJSON implements encoding/json.Unmarshaler. 66 | func (dst *AuthenticationSASLContinue) UnmarshalJSON(data []byte) error { 67 | // Ignore null, like in the main JSON package. 68 | if string(data) == "null" { 69 | return nil 70 | } 71 | 72 | var msg struct { 73 | Data string 74 | } 75 | if err := json.Unmarshal(data, &msg); err != nil { 76 | return err 77 | } 78 | 79 | dst.Data = []byte(msg.Data) 80 | return nil 81 | } 82 | -------------------------------------------------------------------------------- /pgproto3/authentication_sasl_final.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "encoding/binary" 5 | "encoding/json" 6 | "errors" 7 | 8 | "github.com/jackc/pgx/v5/internal/pgio" 9 | ) 10 | 11 | // AuthenticationSASLFinal is a message sent from the backend indicating a SASL authentication has completed. 12 | type AuthenticationSASLFinal struct { 13 | Data []byte 14 | } 15 | 16 | // Backend identifies this message as sendable by the PostgreSQL backend. 17 | func (*AuthenticationSASLFinal) Backend() {} 18 | 19 | // AuthenticationResponse identifies this message as an authentication response. 20 | func (*AuthenticationSASLFinal) AuthenticationResponse() {} 21 | 22 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 23 | // type identifier and 4 byte message length.
24 | func (dst *AuthenticationSASLFinal) Decode(src []byte) error { 25 | if len(src) < 4 { 26 | return errors.New("authentication message too short") 27 | } 28 | 29 | authType := binary.BigEndian.Uint32(src) 30 | 31 | if authType != AuthTypeSASLFinal { 32 | return errors.New("bad auth type") 33 | } 34 | 35 | dst.Data = src[4:] // aliases src; no copy is made 36 | 37 | return nil 38 | } 39 | 40 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 41 | func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte { 42 | dst = append(dst, 'R') 43 | sp := len(dst) 44 | dst = pgio.AppendInt32(dst, -1) // placeholder; the real length is written by SetInt32 below 45 | dst = pgio.AppendUint32(dst, AuthTypeSASLFinal) 46 | 47 | dst = append(dst, src.Data...) 48 | 49 | pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) 50 | 51 | return dst 52 | } 53 | 54 | // MarshalJSON implements encoding/json.Marshaler. 55 | func (src AuthenticationSASLFinal) MarshalJSON() ([]byte, error) { 56 | return json.Marshal(struct { 57 | Type string 58 | Data string 59 | }{ 60 | Type: "AuthenticationSASLFinal", 61 | Data: string(src.Data), 62 | }) 63 | } 64 | 65 | // UnmarshalJSON implements encoding/json.Unmarshaler. 66 | func (dst *AuthenticationSASLFinal) UnmarshalJSON(data []byte) error { 67 | // Ignore null, like in the main JSON package.
68 | if string(data) == "null" { 69 | return nil 70 | } 71 | 72 | var msg struct { 73 | Data string 74 | } 75 | if err := json.Unmarshal(data, &msg); err != nil { 76 | return err 77 | } 78 | 79 | dst.Data = []byte(msg.Data) 80 | return nil 81 | } 82 | -------------------------------------------------------------------------------- /pgproto3/backend_key_data.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "encoding/binary" 5 | "encoding/json" 6 | 7 | "github.com/jackc/pgx/v5/internal/pgio" 8 | ) 9 | 10 | type BackendKeyData struct { 11 | ProcessID uint32 12 | SecretKey uint32 13 | } 14 | 15 | // Backend identifies this message as sendable by the PostgreSQL backend. 16 | func (*BackendKeyData) Backend() {} 17 | 18 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 19 | // type identifier and 4 byte message length. 20 | func (dst *BackendKeyData) Decode(src []byte) error { 21 | if len(src) != 8 { 22 | return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)} 23 | } 24 | 25 | dst.ProcessID = binary.BigEndian.Uint32(src[:4]) 26 | dst.SecretKey = binary.BigEndian.Uint32(src[4:]) 27 | 28 | return nil 29 | } 30 | 31 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 32 | func (src *BackendKeyData) Encode(dst []byte) []byte { 33 | dst = append(dst, 'K') 34 | dst = pgio.AppendUint32(dst, 12) // length (4) + ProcessID (4) + SecretKey (4) 35 | dst = pgio.AppendUint32(dst, src.ProcessID) 36 | dst = pgio.AppendUint32(dst, src.SecretKey) 37 | return dst 38 | } 39 | 40 | // MarshalJSON implements encoding/json.Marshaler.
41 | func (src BackendKeyData) MarshalJSON() ([]byte, error) { 42 | return json.Marshal(struct { 43 | Type string 44 | ProcessID uint32 45 | SecretKey uint32 46 | }{ 47 | Type: "BackendKeyData", 48 | ProcessID: src.ProcessID, 49 | SecretKey: src.SecretKey, 50 | }) 51 | } 52 | -------------------------------------------------------------------------------- /pgproto3/big_endian.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "encoding/binary" 5 | ) 6 | 7 | type BigEndianBuf [8]byte // methods use value receivers, so each returned slice refers to a per-call copy of the array, not the caller's buffer 8 | 9 | func (b BigEndianBuf) Int16(n int16) []byte { 10 | buf := b[0:2] 11 | binary.BigEndian.PutUint16(buf, uint16(n)) 12 | return buf 13 | } 14 | 15 | func (b BigEndianBuf) Uint16(n uint16) []byte { 16 | buf := b[0:2] 17 | binary.BigEndian.PutUint16(buf, n) 18 | return buf 19 | } 20 | 21 | func (b BigEndianBuf) Int32(n int32) []byte { 22 | buf := b[0:4] 23 | binary.BigEndian.PutUint32(buf, uint32(n)) 24 | return buf 25 | } 26 | 27 | func (b BigEndianBuf) Uint32(n uint32) []byte { 28 | buf := b[0:4] 29 | binary.BigEndian.PutUint32(buf, n) 30 | return buf 31 | } 32 | 33 | func (b BigEndianBuf) Int64(n int64) []byte { 34 | buf := b[0:8] 35 | binary.BigEndian.PutUint64(buf, uint64(n)) 36 | return buf 37 | } 38 | -------------------------------------------------------------------------------- /pgproto3/bind_complete.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "encoding/json" 5 | ) 6 | 7 | type BindComplete struct{} 8 | 9 | // Backend identifies this message as sendable by the PostgreSQL backend. 10 | func (*BindComplete) Backend() {} 11 | 12 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 13 | // type identifier and 4 byte message length.
14 | func (dst *BindComplete) Decode(src []byte) error { 15 | if len(src) != 0 { 16 | return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: len(src)} 17 | } 18 | 19 | return nil 20 | } 21 | 22 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 23 | func (src *BindComplete) Encode(dst []byte) []byte { 24 | return append(dst, '2', 0, 0, 0, 4) // type byte, then a big-endian length of 4 (no body) 25 | } 26 | 27 | // MarshalJSON implements encoding/json.Marshaler. 28 | func (src BindComplete) MarshalJSON() ([]byte, error) { 29 | return json.Marshal(struct { 30 | Type string 31 | }{ 32 | Type: "BindComplete", 33 | }) 34 | } 35 | -------------------------------------------------------------------------------- /pgproto3/cancel_request.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "encoding/binary" 5 | "encoding/json" 6 | "errors" 7 | 8 | "github.com/jackc/pgx/v5/internal/pgio" 9 | ) 10 | 11 | const cancelRequestCode = 80877102 // equals (1234 << 16) + 5678 12 | 13 | type CancelRequest struct { 14 | ProcessID uint32 15 | SecretKey uint32 16 | } 17 | 18 | // Frontend identifies this message as sendable by a PostgreSQL frontend. 19 | func (*CancelRequest) Frontend() {} 20 | 21 | func (dst *CancelRequest) Decode(src []byte) error { 22 | if len(src) != 12 { // request code, ProcessID and SecretKey, 4 bytes each 23 | return errors.New("bad cancel request size") 24 | } 25 | 26 | requestCode := binary.BigEndian.Uint32(src) 27 | 28 | if requestCode != cancelRequestCode { 29 | return errors.New("bad cancel request code") 30 | } 31 | 32 | dst.ProcessID = binary.BigEndian.Uint32(src[4:]) 33 | dst.SecretKey = binary.BigEndian.Uint32(src[8:]) 34 | 35 | return nil 36 | } 37 | 38 | // Encode encodes src into dst. dst will include the 4 byte message length.
39 | func (src *CancelRequest) Encode(dst []byte) []byte { 40 | dst = pgio.AppendInt32(dst, 16) // length (4) + request code (4) + ProcessID (4) + SecretKey (4) 41 | dst = pgio.AppendInt32(dst, cancelRequestCode) 42 | dst = pgio.AppendUint32(dst, src.ProcessID) 43 | dst = pgio.AppendUint32(dst, src.SecretKey) 44 | return dst 45 | } 46 | 47 | // MarshalJSON implements encoding/json.Marshaler. 48 | func (src CancelRequest) MarshalJSON() ([]byte, error) { 49 | return json.Marshal(struct { 50 | Type string 51 | ProcessID uint32 52 | SecretKey uint32 53 | }{ 54 | Type: "CancelRequest", 55 | ProcessID: src.ProcessID, 56 | SecretKey: src.SecretKey, 57 | }) 58 | } 59 | -------------------------------------------------------------------------------- /pgproto3/chunkreader_test.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "bytes" 5 | "math/rand" 6 | "testing" 7 | ) 8 | 9 | func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) { 10 | server := &bytes.Buffer{} 11 | r := newChunkReader(server, 4) 12 | 13 | src := []byte{1, 2, 3, 4} 14 | server.Write(src) // bytes.Buffer.Write always returns a nil error 15 | 16 | n1, err := r.Next(2) 17 | if err != nil { 18 | t.Fatal(err) 19 | } 20 | if !bytes.Equal(n1, src[0:2]) { 21 | t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:2], n1) 22 | } 23 | 24 | n2, err := r.Next(2) 25 | if err != nil { 26 | t.Fatal(err) 27 | } 28 | if !bytes.Equal(n2, src[2:4]) { 29 | t.Fatalf("Expected read bytes to be %v, but they were %v", src[2:4], n2) 30 | } 31 | 32 | if !bytes.Equal((*r.buf)[:len(src)], src) { 33 | t.Fatalf("Expected r.buf to be %v, but it was %v", src, r.buf) 34 | } 35 | 36 | _, err = r.Next(0) // Trigger the buffer reset.
37 | if err != nil { 38 | t.Fatal(err) 39 | } 40 | 41 | if r.rp != 0 { 42 | t.Fatalf("Expected r.rp to be %v, but it was %v", 0, r.rp) 43 | } 44 | if r.wp != 0 { 45 | t.Fatalf("Expected r.wp to be %v, but it was %v", 0, r.wp) 46 | } 47 | } 48 | 49 | type randomReader struct { 50 | rnd *rand.Rand 51 | } 52 | 53 | // Read reads a random number of random bytes. 54 | func (r *randomReader) Read(p []byte) (n int, err error) { 55 | n = r.rnd.Intn(len(p) + 1) 56 | return r.rnd.Read(p[:n]) 57 | } 58 | 59 | func TestChunkReaderNextFuzz(t *testing.T) { 60 | rr := &randomReader{rnd: rand.New(rand.NewSource(1))} 61 | r := newChunkReader(rr, 8192) 62 | 63 | randomSizes := rand.New(rand.NewSource(0)) 64 | 65 | for i := 0; i < 100000; i++ { 66 | size := randomSizes.Intn(16384) + 1 67 | buf, err := r.Next(size) 68 | if err != nil { 69 | t.Fatal(err) 70 | } 71 | if len(buf) != size { 72 | t.Fatalf("Expected to get %v bytes but got %v bytes", size, len(buf)) 73 | } 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /pgproto3/close.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "errors" 7 | 8 | "github.com/jackc/pgx/v5/internal/pgio" 9 | ) 10 | 11 | type Close struct { 12 | ObjectType byte // 'S' = prepared statement, 'P' = portal 13 | Name string 14 | } 15 | 16 | // Frontend identifies this message as sendable by a PostgreSQL frontend. 17 | func (*Close) Frontend() {} 18 | 19 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 20 | // type identifier and 4 byte message length.
21 | func (dst *Close) Decode(src []byte) error { 22 | if len(src) < 2 { 23 | return &invalidMessageFormatErr{messageType: "Close"} 24 | } 25 | 26 | dst.ObjectType = src[0] 27 | rp := 1 28 | 29 | idx := bytes.IndexByte(src[rp:], 0) 30 | if idx != len(src[rp:])-1 { // the name must be NUL-terminated with no interior NUL 31 | return &invalidMessageFormatErr{messageType: "Close"} 32 | } 33 | 34 | dst.Name = string(src[rp : len(src)-1]) 35 | 36 | return nil 37 | } 38 | 39 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 40 | func (src *Close) Encode(dst []byte) []byte { 41 | dst = append(dst, 'C') 42 | sp := len(dst) 43 | dst = pgio.AppendInt32(dst, -1) // placeholder; the real length is written by SetInt32 below 44 | 45 | dst = append(dst, src.ObjectType) 46 | dst = append(dst, src.Name...) 47 | dst = append(dst, 0) 48 | 49 | pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) 50 | 51 | return dst 52 | } 53 | 54 | // MarshalJSON implements encoding/json.Marshaler. 55 | func (src Close) MarshalJSON() ([]byte, error) { 56 | return json.Marshal(struct { 57 | Type string 58 | ObjectType string 59 | Name string 60 | }{ 61 | Type: "Close", 62 | ObjectType: string(src.ObjectType), // NOTE(review): a byte >= 0x80 becomes a 2-byte UTF-8 string here, which UnmarshalJSON's length check rejects 63 | Name: src.Name, 64 | }) 65 | } 66 | 67 | // UnmarshalJSON implements encoding/json.Unmarshaler. 68 | func (dst *Close) UnmarshalJSON(data []byte) error { 69 | // Ignore null, like in the main JSON package.
70 | if string(data) == "null" { 71 | return nil 72 | } 73 | 74 | var msg struct { 75 | ObjectType string 76 | Name string 77 | } 78 | if err := json.Unmarshal(data, &msg); err != nil { 79 | return err 80 | } 81 | 82 | if len(msg.ObjectType) != 1 { 83 | return errors.New("invalid length for Close.ObjectType") 84 | } 85 | 86 | dst.ObjectType = byte(msg.ObjectType[0]) 87 | dst.Name = msg.Name 88 | return nil 89 | } 90 | -------------------------------------------------------------------------------- /pgproto3/close_complete.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "encoding/json" 5 | ) 6 | 7 | type CloseComplete struct{} 8 | 9 | // Backend identifies this message as sendable by the PostgreSQL backend. 10 | func (*CloseComplete) Backend() {} 11 | 12 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 13 | // type identifier and 4 byte message length. 14 | func (dst *CloseComplete) Decode(src []byte) error { 15 | if len(src) != 0 { 16 | return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)} 17 | } 18 | 19 | return nil 20 | } 21 | 22 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 23 | func (src *CloseComplete) Encode(dst []byte) []byte { 24 | return append(dst, '3', 0, 0, 0, 4) // type byte, then a big-endian length of 4 (no body) 25 | } 26 | 27 | // MarshalJSON implements encoding/json.Marshaler.
28 | func (src CloseComplete) MarshalJSON() ([]byte, error) { 29 | return json.Marshal(struct { 30 | Type string 31 | }{ 32 | Type: "CloseComplete", 33 | }) 34 | } 35 | -------------------------------------------------------------------------------- /pgproto3/command_complete.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | 7 | "github.com/jackc/pgx/v5/internal/pgio" 8 | ) 9 | 10 | type CommandComplete struct { 11 | CommandTag []byte 12 | } 13 | 14 | // Backend identifies this message as sendable by the PostgreSQL backend. 15 | func (*CommandComplete) Backend() {} 16 | 17 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 18 | // type identifier and 4 byte message length. 19 | func (dst *CommandComplete) Decode(src []byte) error { 20 | idx := bytes.IndexByte(src, 0) 21 | if idx == -1 { 22 | return &invalidMessageFormatErr{messageType: "CommandComplete", details: "unterminated string"} 23 | } 24 | if idx != len(src)-1 { 25 | return &invalidMessageFormatErr{messageType: "CommandComplete", details: "string terminated too early"} 26 | } 27 | 28 | dst.CommandTag = src[:idx] // aliases src; no copy is made 29 | 30 | return nil 31 | } 32 | 33 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 34 | func (src *CommandComplete) Encode(dst []byte) []byte { 35 | dst = append(dst, 'C') 36 | sp := len(dst) 37 | dst = pgio.AppendInt32(dst, -1) // placeholder; the real length is written by SetInt32 below 38 | 39 | dst = append(dst, src.CommandTag...) 40 | dst = append(dst, 0) 41 | 42 | pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) 43 | 44 | return dst 45 | } 46 | 47 | // MarshalJSON implements encoding/json.Marshaler.
48 | func (src CommandComplete) MarshalJSON() ([]byte, error) { 49 | return json.Marshal(struct { 50 | Type string 51 | CommandTag string 52 | }{ 53 | Type: "CommandComplete", 54 | CommandTag: string(src.CommandTag), 55 | }) 56 | } 57 | 58 | // UnmarshalJSON implements encoding/json.Unmarshaler. 59 | func (dst *CommandComplete) UnmarshalJSON(data []byte) error { 60 | // Ignore null, like in the main JSON package. 61 | if string(data) == "null" { 62 | return nil 63 | } 64 | 65 | var msg struct { 66 | CommandTag string 67 | } 68 | if err := json.Unmarshal(data, &msg); err != nil { 69 | return err 70 | } 71 | 72 | dst.CommandTag = []byte(msg.CommandTag) 73 | return nil 74 | } 75 | -------------------------------------------------------------------------------- /pgproto3/copy_both_response_test.go: -------------------------------------------------------------------------------- 1 | package pgproto3_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/jackc/pgx/v5/pgproto3" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestEncodeDecode(t *testing.T) { 11 | srcBytes := []byte{'W', 0x00, 0x00, 0x00, 0x0b, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01} // 'W', length 11, then (per the protocol's CopyBothResponse layout) overall format 1 and two column formats: 0 and 1 12 | dstResp := pgproto3.CopyBothResponse{} 13 | err := dstResp.Decode(srcBytes[5:]) 14 | assert.NoError(t, err, "No errors on decode") 15 | dstBytes := []byte{} 16 | dstBytes = dstResp.Encode(dstBytes) 17 | assert.EqualValues(t, srcBytes, dstBytes, "Expecting src & dest bytes to match") 18 | } 19 | -------------------------------------------------------------------------------- /pgproto3/copy_data.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "encoding/hex" 5 | "encoding/json" 6 | 7 | "github.com/jackc/pgx/v5/internal/pgio" 8 | ) 9 | 10 | type CopyData struct { 11 | Data []byte 12 | } 13 | 14 | // Backend identifies this message as sendable by the PostgreSQL backend.
15 | func (*CopyData) Backend() {} 16 | 17 | // Frontend identifies this message as sendable by a PostgreSQL frontend. 18 | func (*CopyData) Frontend() {} 19 | 20 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 21 | // type identifier and 4 byte message length. 22 | func (dst *CopyData) Decode(src []byte) error { 23 | dst.Data = src // aliases src; no copy is made 24 | return nil 25 | } 26 | 27 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 28 | func (src *CopyData) Encode(dst []byte) []byte { 29 | dst = append(dst, 'd') 30 | dst = pgio.AppendInt32(dst, int32(4+len(src.Data))) 31 | dst = append(dst, src.Data...) 32 | return dst 33 | } 34 | 35 | // MarshalJSON implements encoding/json.Marshaler. 36 | func (src CopyData) MarshalJSON() ([]byte, error) { 37 | return json.Marshal(struct { 38 | Type string 39 | Data string 40 | }{ 41 | Type: "CopyData", 42 | Data: hex.EncodeToString(src.Data), 43 | }) 44 | } 45 | 46 | // UnmarshalJSON implements encoding/json.Unmarshaler. 47 | func (dst *CopyData) UnmarshalJSON(data []byte) error { 48 | // Ignore null, like in the main JSON package. 49 | if string(data) == "null" { 50 | return nil 51 | } 52 | 53 | var msg struct { 54 | Data string 55 | } 56 | if err := json.Unmarshal(data, &msg); err != nil { 57 | return err 58 | } 59 | 60 | // MarshalJSON hex-encodes Data, so decode it back to the raw bytes to keep the JSON round trip lossless. 61 | decoded, err := hex.DecodeString(msg.Data) 62 | if err != nil { 63 | return err 64 | } 65 | dst.Data = decoded 66 | return nil 67 | } 68 | -------------------------------------------------------------------------------- /pgproto3/copy_done.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "encoding/json" 5 | ) 6 | 7 | type CopyDone struct { 8 | } 9 | 10 | // Backend identifies this message as sendable by the PostgreSQL backend. 11 | func (*CopyDone) Backend() {} 12 | 13 | // Frontend identifies this message as sendable by a PostgreSQL frontend.
14 | func (*CopyDone) Frontend() {} 15 | 16 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 17 | // type identifier and 4 byte message length. 18 | func (dst *CopyDone) Decode(src []byte) error { 19 | if len(src) != 0 { 20 | return &invalidMessageLenErr{messageType: "CopyDone", expectedLen: 0, actualLen: len(src)} 21 | } 22 | 23 | return nil 24 | } 25 | 26 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 27 | func (src *CopyDone) Encode(dst []byte) []byte { 28 | return append(dst, 'c', 0, 0, 0, 4) 29 | } 30 | 31 | // MarshalJSON implements encoding/json.Marshaler. 32 | func (src CopyDone) MarshalJSON() ([]byte, error) { 33 | return json.Marshal(struct { 34 | Type string 35 | }{ 36 | Type: "CopyDone", 37 | }) 38 | } 39 | -------------------------------------------------------------------------------- /pgproto3/copy_fail.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | 7 | "github.com/jackc/pgx/v5/internal/pgio" 8 | ) 9 | 10 | type CopyFail struct { 11 | Message string 12 | } 13 | 14 | // Frontend identifies this message as sendable by a PostgreSQL frontend. 15 | func (*CopyFail) Frontend() {} 16 | 17 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 18 | // type identifier and 4 byte message length. The body must be exactly one NUL-terminated string. 19 | func (dst *CopyFail) Decode(src []byte) error { 20 | idx := bytes.IndexByte(src, 0) 21 | if len(src) == 0 || idx != len(src)-1 { // an empty src has idx == -1 == len(src)-1 and would panic on src[:idx] 22 | return &invalidMessageFormatErr{messageType: "CopyFail"} 23 | } 24 | 25 | dst.Message = string(src[:idx]) 26 | 27 | return nil 28 | } 29 | 30 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
31 | func (src *CopyFail) Encode(dst []byte) []byte { 32 | dst = append(dst, 'f') 33 | sp := len(dst) 34 | dst = pgio.AppendInt32(dst, -1) // length placeholder, patched by the SetInt32 call once the body is written 35 | 36 | dst = append(dst, src.Message...) 37 | dst = append(dst, 0) 38 | 39 | pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) 40 | 41 | return dst 42 | } 43 | 44 | // MarshalJSON implements encoding/json.Marshaler. 45 | func (src CopyFail) MarshalJSON() ([]byte, error) { 46 | return json.Marshal(struct { 47 | Type string 48 | Message string 49 | }{ 50 | Type: "CopyFail", 51 | Message: src.Message, 52 | }) 53 | } 54 | -------------------------------------------------------------------------------- /pgproto3/copy_out_response.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "encoding/json" 7 | "errors" 8 | 9 | "github.com/jackc/pgx/v5/internal/pgio" 10 | ) 11 | 12 | type CopyOutResponse struct { 13 | OverallFormat byte 14 | ColumnFormatCodes []uint16 15 | } 16 | // Backend identifies this message as sendable by the PostgreSQL backend. 17 | func (*CopyOutResponse) Backend() {} 18 | 19 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 20 | // type identifier and 4 byte message length.
21 | func (dst *CopyOutResponse) Decode(src []byte) error { 22 | buf := bytes.NewBuffer(src) 23 | 24 | if buf.Len() < 3 { 25 | return &invalidMessageFormatErr{messageType: "CopyOutResponse"} 26 | } 27 | 28 | overallFormat := buf.Next(1)[0] 29 | 30 | columnCount := int(binary.BigEndian.Uint16(buf.Next(2))) 31 | if buf.Len() != columnCount*2 { 32 | return &invalidMessageFormatErr{messageType: "CopyOutResponse"} 33 | } 34 | 35 | columnFormatCodes := make([]uint16, columnCount) 36 | for i := 0; i < columnCount; i++ { 37 | columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) 38 | } 39 | 40 | *dst = CopyOutResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes} 41 | 42 | return nil 43 | } 44 | 45 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 46 | func (src *CopyOutResponse) Encode(dst []byte) []byte { 47 | dst = append(dst, 'H') 48 | sp := len(dst) 49 | dst = pgio.AppendInt32(dst, -1) 50 | 51 | dst = append(dst, src.OverallFormat) 52 | 53 | dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) 54 | for _, fc := range src.ColumnFormatCodes { 55 | dst = pgio.AppendUint16(dst, fc) 56 | } 57 | 58 | pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) 59 | 60 | return dst 61 | } 62 | 63 | // MarshalJSON implements encoding/json.Marshaler. OverallFormat is included because UnmarshalJSON requires it. 64 | func (src CopyOutResponse) MarshalJSON() ([]byte, error) { 65 | return json.Marshal(struct { 66 | Type string 67 | OverallFormat string 68 | ColumnFormatCodes []uint16 69 | }{ 70 | Type: "CopyOutResponse", 71 | OverallFormat: string(src.OverallFormat), 72 | ColumnFormatCodes: src.ColumnFormatCodes, 73 | }) 74 | } 75 | 76 | // UnmarshalJSON implements encoding/json.Unmarshaler. 77 | func (dst *CopyOutResponse) UnmarshalJSON(data []byte) error { 78 | // Ignore null, like in the main JSON package.
79 | if string(data) == "null" { 80 | return nil 81 | } 82 | 83 | var msg struct { 84 | OverallFormat string 85 | ColumnFormatCodes []uint16 86 | } 87 | if err := json.Unmarshal(data, &msg); err != nil { 88 | return err 89 | } 90 | 91 | if len(msg.OverallFormat) != 1 { 92 | return errors.New("invalid length for CopyOutResponse.OverallFormat") 93 | } 94 | 95 | dst.OverallFormat = msg.OverallFormat[0] 96 | dst.ColumnFormatCodes = msg.ColumnFormatCodes 97 | return nil 98 | } 99 | -------------------------------------------------------------------------------- /pgproto3/describe.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "errors" 7 | 8 | "github.com/jackc/pgx/v5/internal/pgio" 9 | ) 10 | 11 | type Describe struct { 12 | ObjectType byte // 'S' = prepared statement, 'P' = portal 13 | Name string 14 | } 15 | 16 | // Frontend identifies this message as sendable by a PostgreSQL frontend. 17 | func (*Describe) Frontend() {} 18 | 19 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 20 | // type identifier and 4 byte message length. 21 | func (dst *Describe) Decode(src []byte) error { 22 | if len(src) < 2 { 23 | return &invalidMessageFormatErr{messageType: "Describe"} 24 | } 25 | 26 | dst.ObjectType = src[0] 27 | rp := 1 28 | 29 | idx := bytes.IndexByte(src[rp:], 0) 30 | if idx != len(src[rp:])-1 { 31 | return &invalidMessageFormatErr{messageType: "Describe"} 32 | } 33 | 34 | dst.Name = string(src[rp : len(src)-1]) 35 | 36 | return nil 37 | } 38 | 39 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 40 | func (src *Describe) Encode(dst []byte) []byte { 41 | dst = append(dst, 'D') 42 | sp := len(dst) 43 | dst = pgio.AppendInt32(dst, -1) 44 | 45 | dst = append(dst, src.ObjectType) 46 | dst = append(dst, src.Name...)
47 | dst = append(dst, 0) 48 | 49 | pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) 50 | 51 | return dst 52 | } 53 | 54 | // MarshalJSON implements encoding/json.Marshaler. 55 | func (src Describe) MarshalJSON() ([]byte, error) { 56 | return json.Marshal(struct { 57 | Type string 58 | ObjectType string 59 | Name string 60 | }{ 61 | Type: "Describe", 62 | ObjectType: string(src.ObjectType), 63 | Name: src.Name, 64 | }) 65 | } 66 | 67 | // UnmarshalJSON implements encoding/json.Unmarshaler. 68 | func (dst *Describe) UnmarshalJSON(data []byte) error { 69 | // Ignore null, like in the main JSON package. 70 | if string(data) == "null" { 71 | return nil 72 | } 73 | 74 | var msg struct { 75 | ObjectType string 76 | Name string 77 | } 78 | if err := json.Unmarshal(data, &msg); err != nil { 79 | return err 80 | } 81 | if len(msg.ObjectType) != 1 { 82 | return errors.New("invalid length for Describe.ObjectType") 83 | } 84 | 85 | dst.ObjectType = byte(msg.ObjectType[0]) 86 | dst.Name = msg.Name 87 | return nil 88 | } 89 | -------------------------------------------------------------------------------- /pgproto3/doc.go: -------------------------------------------------------------------------------- 1 | // Package pgproto3 is an encoder and decoder of the PostgreSQL wire protocol version 3. 2 | // 3 | // The primary interfaces are Frontend and Backend. They correspond to a client and server respectively. Messages are 4 | // sent with Send (or a specialized Send variant). Messages are automatically buffered to minimize small writes. Call 5 | // Flush to ensure a message has actually been sent. 6 | // 7 | // The Trace method of Frontend and Backend can be used to examine the wire-level message traffic. It outputs in a 8 | // similar format to the PQtrace function in libpq. 9 | // 10 | // See https://www.postgresql.org/docs/current/protocol-message-formats.html for meanings of the different messages.
11 | package pgproto3 12 | -------------------------------------------------------------------------------- /pgproto3/empty_query_response.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "encoding/json" 5 | ) 6 | 7 | type EmptyQueryResponse struct{} 8 | 9 | // Backend identifies this message as sendable by the PostgreSQL backend. 10 | func (*EmptyQueryResponse) Backend() {} 11 | 12 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 13 | // type identifier and 4 byte message length. 14 | func (dst *EmptyQueryResponse) Decode(src []byte) error { 15 | if len(src) != 0 { 16 | return &invalidMessageLenErr{messageType: "EmptyQueryResponse", expectedLen: 0, actualLen: len(src)} 17 | } 18 | 19 | return nil 20 | } 21 | 22 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 23 | func (src *EmptyQueryResponse) Encode(dst []byte) []byte { 24 | return append(dst, 'I', 0, 0, 0, 4) 25 | } 26 | 27 | // MarshalJSON implements encoding/json.Marshaler. 28 | func (src EmptyQueryResponse) MarshalJSON() ([]byte, error) { 29 | return json.Marshal(struct { 30 | Type string 31 | }{ 32 | Type: "EmptyQueryResponse", 33 | }) 34 | } 35 | -------------------------------------------------------------------------------- /pgproto3/example/pgfortune/README.md: -------------------------------------------------------------------------------- 1 | # pgfortune 2 | 3 | pgfortune is a mock PostgreSQL server that responds to every query with a fortune. 4 | 5 | ## Installation 6 | 7 | Install `fortune` and `cowsay`. They should be available in any Unix package manager (apt, yum, brew, etc.)
8 | 9 | ``` 10 | go install github.com/jackc/pgx/v5/pgproto3/example/pgfortune@latest 11 | ``` 12 | 13 | ## Usage 14 | 15 | ``` 16 | $ pgfortune 17 | ``` 18 | 19 | By default pgfortune listens on 127.0.0.1:15432 and responds to queries with `fortune | cowsay -f elephant`. These are 20 | configurable with the `listen` and `response-command` arguments respectively. 21 | 22 | While `pgfortune` is running, connect to it with `psql`. 23 | 24 | ``` 25 | $ psql -h 127.0.0.1 -p 15432 26 | Timing is on. 27 | Null display is "∅". 28 | Line style is unicode. 29 | psql (11.5, server 0.0.0) 30 | Type "help" for help. 31 | 32 | jack@127.0.0.1:15432 jack=# select foo; 33 | fortune 34 | ───────────────────────────────────────────── 35 | _________________________________________ ↵ 36 | / Ships are safe in harbor, but they were \↵ 37 | \ never meant to stay there. /↵ 38 | ----------------------------------------- ↵ 39 | \ /\ ___ /\ ↵ 40 | \ // \/ \/ \\ ↵ 41 | (( O O )) ↵ 42 | \\ / \ // ↵ 43 | \/ | | \/ ↵ 44 | | | | | ↵ 45 | | | | | ↵ 46 | | o | ↵ 47 | | | | | ↵ 48 | |m| |m| ↵ 49 | 50 | (1 row) 51 | 52 | Time: 28.161 ms 53 | ``` 54 | -------------------------------------------------------------------------------- /pgproto3/example/pgfortune/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "log" 7 | "net" 8 | "os" 9 | "os/exec" 10 | ) 11 | 12 | var options struct { 13 | listenAddress string 14 | responseCommand string 15 | } 16 | // main listens on options.listenAddress and serves each accepted connection in its own goroutine with NewPgFortuneBackend, passing a callback that runs options.responseCommand via sh -c. 17 | func main() { 18 | flag.Usage = func() { 19 | fmt.Fprintf(os.Stderr, "usage: %s [options]\n", os.Args[0]) 20 | flag.PrintDefaults() 21 | } 22 | 23 | flag.StringVar(&options.listenAddress, "listen", "127.0.0.1:15432", "Listen address") 24 | flag.StringVar(&options.responseCommand, "response-command", "fortune | cowsay -f elephant", "Command to execute to generate query response") 25 | flag.Parse() 26 | 27 | ln, err := net.Listen("tcp", options.listenAddress) 28 | if err !=
nil { 29 | log.Fatal(err) 30 | } 31 | log.Println("Listening on", ln.Addr()) 32 | 33 | for { 34 | conn, err := ln.Accept() 35 | if err != nil { 36 | log.Fatal(err) 37 | } 38 | log.Println("Accepted connection from", conn.RemoteAddr()) 39 | 40 | b := NewPgFortuneBackend(conn, func() ([]byte, error) { 41 | return exec.Command("sh", "-c", options.responseCommand).CombinedOutput() 42 | }) 43 | go func() { 44 | err := b.Run() 45 | if err != nil { 46 | log.Println(err) 47 | } 48 | log.Println("Closed connection from", conn.RemoteAddr()) 49 | }() 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /pgproto3/execute.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "encoding/json" 7 | 8 | "github.com/jackc/pgx/v5/internal/pgio" 9 | ) 10 | 11 | type Execute struct { 12 | Portal string 13 | MaxRows uint32 14 | } 15 | 16 | // Frontend identifies this message as sendable by a PostgreSQL frontend. 17 | func (*Execute) Frontend() {} 18 | 19 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 20 | // type identifier and 4 byte message length. 21 | func (dst *Execute) Decode(src []byte) error { 22 | buf := bytes.NewBuffer(src) 23 | 24 | b, err := buf.ReadBytes(0) 25 | if err != nil { 26 | return err 27 | } 28 | dst.Portal = string(b[:len(b)-1]) 29 | 30 | if buf.Len() < 4 { 31 | return &invalidMessageFormatErr{messageType: "Execute"} 32 | } 33 | dst.MaxRows = binary.BigEndian.Uint32(buf.Next(4)) 34 | 35 | return nil 36 | } 37 | 38 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 39 | func (src *Execute) Encode(dst []byte) []byte { 40 | dst = append(dst, 'E') 41 | sp := len(dst) 42 | dst = pgio.AppendInt32(dst, -1) 43 | 44 | dst = append(dst, src.Portal...)
45 | dst = append(dst, 0) 46 | 47 | dst = pgio.AppendUint32(dst, src.MaxRows) 48 | 49 | pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) 50 | 51 | return dst 52 | } 53 | 54 | // MarshalJSON implements encoding/json.Marshaler. 55 | func (src Execute) MarshalJSON() ([]byte, error) { 56 | return json.Marshal(struct { 57 | Type string 58 | Portal string 59 | MaxRows uint32 60 | }{ 61 | Type: "Execute", 62 | Portal: src.Portal, 63 | MaxRows: src.MaxRows, 64 | }) 65 | } 66 | -------------------------------------------------------------------------------- /pgproto3/flush.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "encoding/json" 5 | ) 6 | 7 | type Flush struct{} 8 | 9 | // Frontend identifies this message as sendable by a PostgreSQL frontend. 10 | func (*Flush) Frontend() {} 11 | 12 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 13 | // type identifier and 4 byte message length. 14 | func (dst *Flush) Decode(src []byte) error { 15 | if len(src) != 0 { 16 | return &invalidMessageLenErr{messageType: "Flush", expectedLen: 0, actualLen: len(src)} 17 | } 18 | 19 | return nil 20 | } 21 | 22 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 23 | func (src *Flush) Encode(dst []byte) []byte { 24 | return append(dst, 'H', 0, 0, 0, 4) 25 | } 26 | 27 | // MarshalJSON implements encoding/json.Marshaler.
28 | func (src Flush) MarshalJSON() ([]byte, error) { 29 | return json.Marshal(struct { 30 | Type string 31 | }{ 32 | Type: "Flush", 33 | }) 34 | } 35 | -------------------------------------------------------------------------------- /pgproto3/function_call_test.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "encoding/binary" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func TestFunctionCall_EncodeDecode(t *testing.T) { 10 | type fields struct { 11 | Function uint32 12 | ArgFormatCodes []uint16 13 | Arguments [][]byte 14 | ResultFormatCode uint16 15 | } 16 | tests := []struct { 17 | name string 18 | fields fields 19 | wantErr bool 20 | }{ 21 | {"valid", fields{uint32(123), []uint16{0, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(1)}, false}, 22 | {"invalid format code", fields{uint32(123), []uint16{2, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(0)}, true}, 23 | {"invalid result format code", fields{uint32(123), []uint16{1, 1, 0, 1}, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, uint16(2)}, true}, 24 | } 25 | for _, tt := range tests { 26 | t.Run(tt.name, func(t *testing.T) { 27 | src := &FunctionCall{ 28 | Function: tt.fields.Function, 29 | ArgFormatCodes: tt.fields.ArgFormatCodes, 30 | Arguments: tt.fields.Arguments, 31 | ResultFormatCode: tt.fields.ResultFormatCode, 32 | } 33 | encoded := src.Encode([]byte{}) 34 | dst := &FunctionCall{} 35 | // Check the header 36 | msgTypeCode := encoded[0] 37 | if msgTypeCode != 'F' { 38 | t.Errorf("msgTypeCode %v should be 'F'", msgTypeCode) 39 | return 40 | } 41 | // Check length, does not include type code character 42 | l := binary.BigEndian.Uint32(encoded[1:5]) 43 | if int(l) != (len(encoded) - 1) { 44 | t.Errorf("Incorrect message length, got = %v, wanted = %v", l, len(encoded)-1) 45 | } 46 | // Check decoding works as expected 47 | err := dst.Decode(encoded[5:]) 48 | if err != nil {
49 | if !tt.wantErr { 50 | t.Errorf("FunctionCall.Decode() error = %v, wantErr %v", err, tt.wantErr) 51 | } 52 | return 53 | } 54 | if tt.wantErr { // a successful decode of an invalid message is itself a failure 55 | t.Errorf("FunctionCall.Decode() error = nil, wantErr %v", tt.wantErr) 56 | return 57 | } 58 | 59 | if !reflect.DeepEqual(src, dst) { 60 | t.Error("difference after encode / decode cycle") 61 | t.Errorf("src = %v", src) 62 | t.Errorf("dst = %v", dst) 63 | } 64 | }) 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /pgproto3/fuzz_test.go: -------------------------------------------------------------------------------- 1 | package pgproto3_test 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/jackc/pgx/v5/internal/pgio" 8 | "github.com/jackc/pgx/v5/pgproto3" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func FuzzFrontend(f *testing.F) { 13 | testcases := []struct { 14 | msgType byte 15 | msgLen uint32 16 | msgBody []byte 17 | }{ 18 | { 19 | msgType: 'Z', 20 | msgLen: 2, 21 | msgBody: []byte{'I'}, 22 | }, 23 | { 24 | msgType: 'Z', 25 | msgLen: 5, 26 | msgBody: []byte{'I'}, 27 | }, 28 | } 29 | for _, tc := range testcases { 30 | f.Add(tc.msgType, tc.msgLen, tc.msgBody) 31 | } 32 | f.Fuzz(func(t *testing.T, msgType byte, msgLen uint32, msgBody []byte) { 33 | // Prune any msgLen > len(msgBody)+4 (the length field counts its own 4 bytes) because they would hang the test waiting for more input. 34 | if int(msgLen) > len(msgBody)+4 { 35 | return 36 | } 37 | 38 | // Prune any messages that are too long. 39 | if msgLen > 128 || len(msgBody) > 128 { 40 | return 41 | } 42 | 43 | r := &bytes.Buffer{} 44 | w := &bytes.Buffer{} 45 | fe := pgproto3.NewFrontend(r, w) 46 | 47 | var encodedMsg []byte 48 | encodedMsg = append(encodedMsg, msgType) 49 | encodedMsg = pgio.AppendUint32(encodedMsg, msgLen) 50 | encodedMsg = append(encodedMsg, msgBody...) 51 | _, err := r.Write(encodedMsg) 52 | require.NoError(t, err) 53 | 54 | // Not checking anything other than no panic.
55 | fe.Receive() 56 | }) 57 | } 58 | -------------------------------------------------------------------------------- /pgproto3/gss_enc_request.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "encoding/binary" 5 | "encoding/json" 6 | "errors" 7 | 8 | "github.com/jackc/pgx/v5/internal/pgio" 9 | ) 10 | 11 | const gssEncReqNumber = 80877104 12 | 13 | type GSSEncRequest struct { 14 | } 15 | 16 | // Frontend identifies this message as sendable by a PostgreSQL frontend. 17 | func (*GSSEncRequest) Frontend() {} 18 | // Decode validates src, which must start with the 4 byte GSS encryption request code (the 4 byte length is not included). 19 | func (dst *GSSEncRequest) Decode(src []byte) error { 20 | if len(src) < 4 { 21 | return errors.New("gss encoding request too short") 22 | } 23 | 24 | requestCode := binary.BigEndian.Uint32(src) 25 | 26 | if requestCode != gssEncReqNumber { 27 | return errors.New("bad gss encoding request code") 28 | } 29 | 30 | return nil 31 | } 32 | 33 | // Encode encodes src into dst. dst will include the 4 byte message length. 34 | func (src *GSSEncRequest) Encode(dst []byte) []byte { 35 | dst = pgio.AppendInt32(dst, 8) 36 | dst = pgio.AppendInt32(dst, gssEncReqNumber) 37 | return dst 38 | } 39 | 40 | // MarshalJSON implements encoding/json.Marshaler. 41 | func (src GSSEncRequest) MarshalJSON() ([]byte, error) { 42 | return json.Marshal(struct { 43 | Type string 44 | ProtocolVersion uint32 45 | Parameters map[string]string 46 | }{ 47 | Type: "GSSEncRequest", 48 | }) 49 | } 50 | -------------------------------------------------------------------------------- /pgproto3/gss_response.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "encoding/json" 5 | 6 | "github.com/jackc/pgx/v5/internal/pgio" 7 | ) 8 | 9 | type GSSResponse struct { 10 | Data []byte 11 | } 12 | 13 | // Frontend identifies this message as sendable by a PostgreSQL frontend.
14 | func (g *GSSResponse) Frontend() {} 15 | // Decode stores data as the message payload. g.Data aliases data; the bytes are not copied. 16 | func (g *GSSResponse) Decode(data []byte) error { 17 | g.Data = data 18 | return nil 19 | } 20 | // Encode encodes g into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 21 | func (g *GSSResponse) Encode(dst []byte) []byte { 22 | dst = append(dst, 'p') 23 | dst = pgio.AppendInt32(dst, int32(4+len(g.Data))) 24 | dst = append(dst, g.Data...) 25 | return dst 26 | } 27 | 28 | // MarshalJSON implements encoding/json.Marshaler. It uses a value receiver, like the other messages, so that both GSSResponse and *GSSResponse values are encoded with their Type. 29 | func (g GSSResponse) MarshalJSON() ([]byte, error) { 30 | return json.Marshal(struct { 31 | Type string 32 | Data []byte 33 | }{ 34 | Type: "GSSResponse", 35 | Data: g.Data, 36 | }) 37 | } 38 | 39 | // UnmarshalJSON implements encoding/json.Unmarshaler. 40 | func (g *GSSResponse) UnmarshalJSON(data []byte) error { 41 | // Ignore null, like in the main JSON package. 42 | if string(data) == "null" { 43 | return nil 44 | } 45 | 46 | var msg struct { 47 | Data []byte 48 | } 49 | if err := json.Unmarshal(data, &msg); err != nil { 50 | return err 51 | } 52 | g.Data = msg.Data 53 | return nil 54 | } 55 | -------------------------------------------------------------------------------- /pgproto3/no_data.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "encoding/json" 5 | ) 6 | 7 | type NoData struct{} 8 | 9 | // Backend identifies this message as sendable by the PostgreSQL backend. 10 | func (*NoData) Backend() {} 11 | 12 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 13 | // type identifier and 4 byte message length. 14 | func (dst *NoData) Decode(src []byte) error { 15 | if len(src) != 0 { 16 | return &invalidMessageLenErr{messageType: "NoData", expectedLen: 0, actualLen: len(src)} 17 | } 18 | 19 | return nil 20 | } 21 | 22 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 23 | func (src *NoData) Encode(dst []byte) []byte { 24 | return append(dst, 'n', 0, 0, 0, 4) 25 | } 26 | 27 | // MarshalJSON implements encoding/json.Marshaler.
28 | func (src NoData) MarshalJSON() ([]byte, error) { 29 | return json.Marshal(struct { 30 | Type string 31 | }{ 32 | Type: "NoData", 33 | }) 34 | } 35 | -------------------------------------------------------------------------------- /pgproto3/notice_response.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | // NoticeResponse has the same fields as ErrorResponse; Decode and Encode delegate to ErrorResponse via type conversion, with Encode using the 'N' message type. 3 | type NoticeResponse ErrorResponse 4 | 5 | // Backend identifies this message as sendable by the PostgreSQL backend. 6 | func (*NoticeResponse) Backend() {} 7 | 8 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 9 | // type identifier and 4 byte message length. 10 | func (dst *NoticeResponse) Decode(src []byte) error { 11 | return (*ErrorResponse)(dst).Decode(src) 12 | } 13 | 14 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 15 | func (src *NoticeResponse) Encode(dst []byte) []byte { 16 | return append(dst, (*ErrorResponse)(src).marshalBinary('N')...) 17 | } 18 | -------------------------------------------------------------------------------- /pgproto3/notification_response.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "encoding/json" 7 | 8 | "github.com/jackc/pgx/v5/internal/pgio" 9 | ) 10 | 11 | type NotificationResponse struct { 12 | PID uint32 13 | Channel string 14 | Payload string 15 | } 16 | 17 | // Backend identifies this message as sendable by the PostgreSQL backend. 18 | func (*NotificationResponse) Backend() {} 19 | 20 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 21 | // type identifier and 4 byte message length.
22 | func (dst *NotificationResponse) Decode(src []byte) error { 23 | buf := bytes.NewBuffer(src) 24 | 25 | if buf.Len() < 4 { 26 | return &invalidMessageFormatErr{messageType: "NotificationResponse", details: "too short"} 27 | } 28 | 29 | pid := binary.BigEndian.Uint32(buf.Next(4)) 30 | // Channel and payload are NUL-terminated; ReadBytes keeps the terminator, which is sliced off, and returns an error if it is missing. 31 | b, err := buf.ReadBytes(0) 32 | if err != nil { 33 | return err 34 | } 35 | channel := string(b[:len(b)-1]) 36 | 37 | b, err = buf.ReadBytes(0) 38 | if err != nil { 39 | return err 40 | } 41 | payload := string(b[:len(b)-1]) 42 | 43 | *dst = NotificationResponse{PID: pid, Channel: channel, Payload: payload} 44 | return nil 45 | } 46 | 47 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 48 | func (src *NotificationResponse) Encode(dst []byte) []byte { 49 | dst = append(dst, 'A') 50 | sp := len(dst) 51 | dst = pgio.AppendInt32(dst, -1) 52 | 53 | dst = pgio.AppendUint32(dst, src.PID) 54 | dst = append(dst, src.Channel...) 55 | dst = append(dst, 0) 56 | dst = append(dst, src.Payload...) 57 | dst = append(dst, 0) 58 | 59 | pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) 60 | 61 | return dst 62 | } 63 | 64 | // MarshalJSON implements encoding/json.Marshaler.
65 | func (src NotificationResponse) MarshalJSON() ([]byte, error) { 66 | return json.Marshal(struct { 67 | Type string 68 | PID uint32 69 | Channel string 70 | Payload string 71 | }{ 72 | Type: "NotificationResponse", 73 | PID: src.PID, 74 | Channel: src.Channel, 75 | Payload: src.Payload, 76 | }) 77 | } 78 | -------------------------------------------------------------------------------- /pgproto3/parameter_description.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "encoding/json" 7 | 8 | "github.com/jackc/pgx/v5/internal/pgio" 9 | ) 10 | 11 | type ParameterDescription struct { 12 | ParameterOIDs []uint32 13 | } 14 | 15 | // Backend identifies this message as sendable by the PostgreSQL backend. 16 | func (*ParameterDescription) Backend() {} 17 | 18 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 19 | // type identifier and 4 byte message length. 20 | func (dst *ParameterDescription) Decode(src []byte) error { 21 | buf := bytes.NewBuffer(src) 22 | 23 | if buf.Len() < 2 { 24 | return &invalidMessageFormatErr{messageType: "ParameterDescription"} 25 | } 26 | 27 | // Reported parameter count will be incorrect when number of args is greater than uint16 28 | buf.Next(2) 29 | // Instead infer parameter count by remaining size of message 30 | parameterCount := buf.Len() / 4 31 | // Trailing bytes that do not form a complete 4 byte OID are ignored. 32 | *dst = ParameterDescription{ParameterOIDs: make([]uint32, parameterCount)} 33 | 34 | for i := 0; i < parameterCount; i++ { 35 | dst.ParameterOIDs[i] = binary.BigEndian.Uint32(buf.Next(4)) 36 | } 37 | 38 | return nil 39 | } 40 | 41 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
42 | func (src *ParameterDescription) Encode(dst []byte) []byte { 43 | dst = append(dst, 't') 44 | sp := len(dst) 45 | dst = pgio.AppendInt32(dst, -1) 46 | // The count is truncated to uint16 above 65535 OIDs; ParameterDescription.Decode infers the count from the message length instead. 47 | dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) 48 | for _, oid := range src.ParameterOIDs { 49 | dst = pgio.AppendUint32(dst, oid) 50 | } 51 | 52 | pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) 53 | 54 | return dst 55 | } 56 | 57 | // MarshalJSON implements encoding/json.Marshaler. 58 | func (src ParameterDescription) MarshalJSON() ([]byte, error) { 59 | return json.Marshal(struct { 60 | Type string 61 | ParameterOIDs []uint32 62 | }{ 63 | Type: "ParameterDescription", 64 | ParameterOIDs: src.ParameterOIDs, 65 | }) 66 | } 67 | -------------------------------------------------------------------------------- /pgproto3/parameter_status.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | 7 | "github.com/jackc/pgx/v5/internal/pgio" 8 | ) 9 | 10 | type ParameterStatus struct { 11 | Name string 12 | Value string 13 | } 14 | 15 | // Backend identifies this message as sendable by the PostgreSQL backend. 16 | func (*ParameterStatus) Backend() {} 17 | 18 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 19 | // type identifier and 4 byte message length. 20 | func (dst *ParameterStatus) Decode(src []byte) error { 21 | buf := bytes.NewBuffer(src) 22 | 23 | b, err := buf.ReadBytes(0) 24 | if err != nil { 25 | return err 26 | } 27 | name := string(b[:len(b)-1]) 28 | 29 | b, err = buf.ReadBytes(0) 30 | if err != nil { 31 | return err 32 | } 33 | value := string(b[:len(b)-1]) 34 | 35 | *dst = ParameterStatus{Name: name, Value: value} 36 | return nil 37 | } 38 | 39 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
40 | func (src *ParameterStatus) Encode(dst []byte) []byte { 41 | dst = append(dst, 'S') 42 | sp := len(dst) 43 | dst = pgio.AppendInt32(dst, -1) 44 | // Name and Value are each written as NUL-terminated strings. 45 | dst = append(dst, src.Name...) 46 | dst = append(dst, 0) 47 | dst = append(dst, src.Value...) 48 | dst = append(dst, 0) 49 | 50 | pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) 51 | 52 | return dst 53 | } 54 | 55 | // MarshalJSON implements encoding/json.Marshaler. 56 | func (ps ParameterStatus) MarshalJSON() ([]byte, error) { 57 | return json.Marshal(struct { 58 | Type string 59 | Name string 60 | Value string 61 | }{ 62 | Type: "ParameterStatus", 63 | Name: ps.Name, 64 | Value: ps.Value, 65 | }) 66 | } 67 | -------------------------------------------------------------------------------- /pgproto3/parse.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "encoding/json" 7 | 8 | "github.com/jackc/pgx/v5/internal/pgio" 9 | ) 10 | 11 | type Parse struct { 12 | Name string 13 | Query string 14 | ParameterOIDs []uint32 15 | } 16 | 17 | // Frontend identifies this message as sendable by a PostgreSQL frontend. 18 | func (*Parse) Frontend() {} 19 | 20 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 21 | // type identifier and 4 byte message length.
22 | func (dst *Parse) Decode(src []byte) error { 23 | *dst = Parse{} 24 | 25 | buf := bytes.NewBuffer(src) 26 | 27 | b, err := buf.ReadBytes(0) 28 | if err != nil { 29 | return err 30 | } 31 | dst.Name = string(b[:len(b)-1]) 32 | 33 | b, err = buf.ReadBytes(0) 34 | if err != nil { 35 | return err 36 | } 37 | dst.Query = string(b[:len(b)-1]) 38 | 39 | if buf.Len() < 2 { 40 | return &invalidMessageFormatErr{messageType: "Parse"} 41 | } 42 | parameterOIDCount := int(binary.BigEndian.Uint16(buf.Next(2))) 43 | 44 | for i := 0; i < parameterOIDCount; i++ { 45 | if buf.Len() < 4 { 46 | return &invalidMessageFormatErr{messageType: "Parse"} 47 | } 48 | dst.ParameterOIDs = append(dst.ParameterOIDs, binary.BigEndian.Uint32(buf.Next(4))) 49 | } 50 | 51 | return nil 52 | } 53 | 54 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 55 | func (src *Parse) Encode(dst []byte) []byte { 56 | dst = append(dst, 'P') 57 | sp := len(dst) 58 | dst = pgio.AppendInt32(dst, -1) 59 | 60 | dst = append(dst, src.Name...) 61 | dst = append(dst, 0) 62 | dst = append(dst, src.Query...) 63 | dst = append(dst, 0) 64 | 65 | dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) 66 | for _, oid := range src.ParameterOIDs { 67 | dst = pgio.AppendUint32(dst, oid) 68 | } 69 | 70 | pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) 71 | 72 | return dst 73 | } 74 | 75 | // MarshalJSON implements encoding/json.Marshaler. 
76 | func (src Parse) MarshalJSON() ([]byte, error) { 77 | return json.Marshal(struct { 78 | Type string 79 | Name string 80 | Query string 81 | ParameterOIDs []uint32 82 | }{ 83 | Type: "Parse", 84 | Name: src.Name, 85 | Query: src.Query, 86 | ParameterOIDs: src.ParameterOIDs, 87 | }) 88 | } 89 | -------------------------------------------------------------------------------- /pgproto3/parse_complete.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "encoding/json" 5 | ) 6 | 7 | type ParseComplete struct{} 8 | 9 | // Backend identifies this message as sendable by the PostgreSQL backend. 10 | func (*ParseComplete) Backend() {} 11 | 12 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 13 | // type identifier and 4 byte message length. 14 | func (dst *ParseComplete) Decode(src []byte) error { 15 | if len(src) != 0 { 16 | return &invalidMessageLenErr{messageType: "ParseComplete", expectedLen: 0, actualLen: len(src)} 17 | } 18 | 19 | return nil 20 | } 21 | 22 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 23 | func (src *ParseComplete) Encode(dst []byte) []byte { 24 | return append(dst, '1', 0, 0, 0, 4) 25 | } 26 | 27 | // MarshalJSON implements encoding/json.Marshaler. 
28 | func (src ParseComplete) MarshalJSON() ([]byte, error) { 29 | return json.Marshal(struct { 30 | Type string 31 | }{ 32 | Type: "ParseComplete", 33 | }) 34 | } 35 | -------------------------------------------------------------------------------- /pgproto3/password_message.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | 7 | "github.com/jackc/pgx/v5/internal/pgio" 8 | ) 9 | 10 | type PasswordMessage struct { 11 | Password string 12 | } 13 | 14 | // Frontend identifies this message as sendable by a PostgreSQL frontend. 15 | func (*PasswordMessage) Frontend() {} 16 | 17 | // Frontend identifies this message as an authentication response. 18 | func (*PasswordMessage) InitialResponse() {} 19 | 20 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 21 | // type identifier and 4 byte message length. 22 | func (dst *PasswordMessage) Decode(src []byte) error { 23 | buf := bytes.NewBuffer(src) 24 | 25 | b, err := buf.ReadBytes(0) 26 | if err != nil { 27 | return err 28 | } 29 | dst.Password = string(b[:len(b)-1]) 30 | 31 | return nil 32 | } 33 | 34 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 35 | func (src *PasswordMessage) Encode(dst []byte) []byte { 36 | dst = append(dst, 'p') 37 | dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1)) 38 | 39 | dst = append(dst, src.Password...) 40 | dst = append(dst, 0) 41 | 42 | return dst 43 | } 44 | 45 | // MarshalJSON implements encoding/json.Marshaler. 
46 | func (src PasswordMessage) MarshalJSON() ([]byte, error) { 47 | return json.Marshal(struct { 48 | Type string 49 | Password string 50 | }{ 51 | Type: "PasswordMessage", 52 | Password: src.Password, 53 | }) 54 | } 55 | -------------------------------------------------------------------------------- /pgproto3/pgproto3.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "encoding/hex" 5 | "errors" 6 | "fmt" 7 | ) 8 | 9 | // Message is the interface implemented by an object that can decode and encode 10 | // a particular PostgreSQL message. 11 | type Message interface { 12 | // Decode is allowed and expected to retain a reference to data after 13 | // returning (unlike encoding.BinaryUnmarshaler). 14 | Decode(data []byte) error 15 | 16 | // Encode appends itself to dst and returns the new buffer. 17 | Encode(dst []byte) []byte 18 | } 19 | 20 | // FrontendMessage is a message sent by the frontend (i.e. the client). 21 | type FrontendMessage interface { 22 | Message 23 | Frontend() // no-op method to distinguish frontend from backend methods 24 | } 25 | 26 | // BackendMessage is a message sent by the backend (i.e. the server). 
27 | type BackendMessage interface { 28 | Message 29 | Backend() // no-op method to distinguish frontend from backend methods 30 | } 31 | 32 | type AuthenticationResponseMessage interface { 33 | BackendMessage 34 | AuthenticationResponse() // no-op method to distinguish authentication responses 35 | } 36 | 37 | type invalidMessageLenErr struct { 38 | messageType string 39 | expectedLen int 40 | actualLen int 41 | } 42 | 43 | func (e *invalidMessageLenErr) Error() string { 44 | return fmt.Sprintf("%s body must have length of %d, but it is %d", e.messageType, e.expectedLen, e.actualLen) 45 | } 46 | 47 | type invalidMessageFormatErr struct { 48 | messageType string 49 | details string 50 | } 51 | 52 | func (e *invalidMessageFormatErr) Error() string { 53 | return fmt.Sprintf("%s body is invalid %s", e.messageType, e.details) 54 | } 55 | 56 | type writeError struct { 57 | err error 58 | safeToRetry bool 59 | } 60 | 61 | func (e *writeError) Error() string { 62 | return fmt.Sprintf("write failed: %s", e.err.Error()) 63 | } 64 | 65 | func (e *writeError) SafeToRetry() bool { 66 | return e.safeToRetry 67 | } 68 | 69 | func (e *writeError) Unwrap() error { 70 | return e.err 71 | } 72 | 73 | // getValueFromJSON gets the value from a protocol message representation in JSON. 
74 | func getValueFromJSON(v map[string]string) ([]byte, error) { 75 | if v == nil { 76 | return nil, nil 77 | } 78 | if text, ok := v["text"]; ok { 79 | return []byte(text), nil 80 | } 81 | if binary, ok := v["binary"]; ok { 82 | return hex.DecodeString(binary) 83 | } 84 | return nil, errors.New("unknown protocol representation") 85 | } 86 | -------------------------------------------------------------------------------- /pgproto3/portal_suspended.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "encoding/json" 5 | ) 6 | 7 | type PortalSuspended struct{} 8 | 9 | // Backend identifies this message as sendable by the PostgreSQL backend. 10 | func (*PortalSuspended) Backend() {} 11 | 12 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 13 | // type identifier and 4 byte message length. 14 | func (dst *PortalSuspended) Decode(src []byte) error { 15 | if len(src) != 0 { 16 | return &invalidMessageLenErr{messageType: "PortalSuspended", expectedLen: 0, actualLen: len(src)} 17 | } 18 | 19 | return nil 20 | } 21 | 22 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 23 | func (src *PortalSuspended) Encode(dst []byte) []byte { 24 | return append(dst, 's', 0, 0, 0, 4) 25 | } 26 | 27 | // MarshalJSON implements encoding/json.Marshaler. 
28 | func (src PortalSuspended) MarshalJSON() ([]byte, error) { 29 | return json.Marshal(struct { 30 | Type string 31 | }{ 32 | Type: "PortalSuspended", 33 | }) 34 | } 35 | -------------------------------------------------------------------------------- /pgproto3/query.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | 7 | "github.com/jackc/pgx/v5/internal/pgio" 8 | ) 9 | 10 | type Query struct { 11 | String string 12 | } 13 | 14 | // Frontend identifies this message as sendable by a PostgreSQL frontend. 15 | func (*Query) Frontend() {} 16 | 17 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 18 | // type identifier and 4 byte message length. 19 | func (dst *Query) Decode(src []byte) error { 20 | i := bytes.IndexByte(src, 0) 21 | if i != len(src)-1 { 22 | return &invalidMessageFormatErr{messageType: "Query"} 23 | } 24 | 25 | dst.String = string(src[:i]) 26 | 27 | return nil 28 | } 29 | 30 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 31 | func (src *Query) Encode(dst []byte) []byte { 32 | dst = append(dst, 'Q') 33 | dst = pgio.AppendInt32(dst, int32(4+len(src.String)+1)) 34 | 35 | dst = append(dst, src.String...) 36 | dst = append(dst, 0) 37 | 38 | return dst 39 | } 40 | 41 | // MarshalJSON implements encoding/json.Marshaler. 
42 | func (src Query) MarshalJSON() ([]byte, error) { 43 | return json.Marshal(struct { 44 | Type string 45 | String string 46 | }{ 47 | Type: "Query", 48 | String: src.String, 49 | }) 50 | } 51 | -------------------------------------------------------------------------------- /pgproto3/ready_for_query.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | ) 7 | 8 | type ReadyForQuery struct { 9 | TxStatus byte 10 | } 11 | 12 | // Backend identifies this message as sendable by the PostgreSQL backend. 13 | func (*ReadyForQuery) Backend() {} 14 | 15 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 16 | // type identifier and 4 byte message length. 17 | func (dst *ReadyForQuery) Decode(src []byte) error { 18 | if len(src) != 1 { 19 | return &invalidMessageLenErr{messageType: "ReadyForQuery", expectedLen: 1, actualLen: len(src)} 20 | } 21 | 22 | dst.TxStatus = src[0] 23 | 24 | return nil 25 | } 26 | 27 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 28 | func (src *ReadyForQuery) Encode(dst []byte) []byte { 29 | return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus) 30 | } 31 | 32 | // MarshalJSON implements encoding/json.Marshaler. 33 | func (src ReadyForQuery) MarshalJSON() ([]byte, error) { 34 | return json.Marshal(struct { 35 | Type string 36 | TxStatus string 37 | }{ 38 | Type: "ReadyForQuery", 39 | TxStatus: string(src.TxStatus), 40 | }) 41 | } 42 | 43 | // UnmarshalJSON implements encoding/json.Unmarshaler. 44 | func (dst *ReadyForQuery) UnmarshalJSON(data []byte) error { 45 | // Ignore null, like in the main JSON package. 
46 | if string(data) == "null" { 47 | return nil 48 | } 49 | 50 | var msg struct { 51 | TxStatus string 52 | } 53 | if err := json.Unmarshal(data, &msg); err != nil { 54 | return err 55 | } 56 | if len(msg.TxStatus) != 1 { 57 | return errors.New("invalid length for ReadyForQuery.TxStatus") 58 | } 59 | dst.TxStatus = msg.TxStatus[0] 60 | return nil 61 | } 62 | -------------------------------------------------------------------------------- /pgproto3/sasl_initial_response.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "bytes" 5 | "encoding/hex" 6 | "encoding/json" 7 | "errors" 8 | 9 | "github.com/jackc/pgx/v5/internal/pgio" 10 | ) 11 | 12 | type SASLInitialResponse struct { 13 | AuthMechanism string 14 | Data []byte 15 | } 16 | 17 | // Frontend identifies this message as sendable by a PostgreSQL frontend. 18 | func (*SASLInitialResponse) Frontend() {} 19 | 20 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 21 | // type identifier and 4 byte message length. 22 | func (dst *SASLInitialResponse) Decode(src []byte) error { 23 | *dst = SASLInitialResponse{} 24 | 25 | rp := 0 26 | 27 | idx := bytes.IndexByte(src, 0) 28 | if idx < 0 { 29 | return errors.New("invalid SASLInitialResponse") 30 | } 31 | 32 | dst.AuthMechanism = string(src[rp:idx]) 33 | rp = idx + 1 34 | 35 | rp += 4 // The rest of the message is data so we can just skip the size 36 | dst.Data = src[rp:] 37 | 38 | return nil 39 | } 40 | 41 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 42 | func (src *SASLInitialResponse) Encode(dst []byte) []byte { 43 | dst = append(dst, 'p') 44 | sp := len(dst) 45 | dst = pgio.AppendInt32(dst, -1) 46 | 47 | dst = append(dst, []byte(src.AuthMechanism)...) 
48 | dst = append(dst, 0) 49 | 50 | dst = pgio.AppendInt32(dst, int32(len(src.Data))) 51 | dst = append(dst, src.Data...) 52 | 53 | pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) 54 | 55 | return dst 56 | } 57 | 58 | // MarshalJSON implements encoding/json.Marshaler. 59 | func (src SASLInitialResponse) MarshalJSON() ([]byte, error) { 60 | return json.Marshal(struct { 61 | Type string 62 | AuthMechanism string 63 | Data string 64 | }{ 65 | Type: "SASLInitialResponse", 66 | AuthMechanism: src.AuthMechanism, 67 | Data: string(src.Data), 68 | }) 69 | } 70 | 71 | // UnmarshalJSON implements encoding/json.Unmarshaler. 72 | func (dst *SASLInitialResponse) UnmarshalJSON(data []byte) error { 73 | // Ignore null, like in the main JSON package. 74 | if string(data) == "null" { 75 | return nil 76 | } 77 | 78 | var msg struct { 79 | AuthMechanism string 80 | Data string 81 | } 82 | if err := json.Unmarshal(data, &msg); err != nil { 83 | return err 84 | } 85 | dst.AuthMechanism = msg.AuthMechanism 86 | if msg.Data != "" { 87 | decoded, err := hex.DecodeString(msg.Data) 88 | if err != nil { 89 | return err 90 | } 91 | dst.Data = decoded 92 | } 93 | return nil 94 | } 95 | -------------------------------------------------------------------------------- /pgproto3/sasl_response.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "encoding/hex" 5 | "encoding/json" 6 | 7 | "github.com/jackc/pgx/v5/internal/pgio" 8 | ) 9 | 10 | type SASLResponse struct { 11 | Data []byte 12 | } 13 | 14 | // Frontend identifies this message as sendable by a PostgreSQL frontend. 15 | func (*SASLResponse) Frontend() {} 16 | 17 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 18 | // type identifier and 4 byte message length. 
19 | func (dst *SASLResponse) Decode(src []byte) error { 20 | *dst = SASLResponse{Data: src} 21 | return nil 22 | } 23 | 24 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 25 | func (src *SASLResponse) Encode(dst []byte) []byte { 26 | dst = append(dst, 'p') 27 | dst = pgio.AppendInt32(dst, int32(4+len(src.Data))) 28 | 29 | dst = append(dst, src.Data...) 30 | 31 | return dst 32 | } 33 | 34 | // MarshalJSON implements encoding/json.Marshaler. 35 | func (src SASLResponse) MarshalJSON() ([]byte, error) { 36 | return json.Marshal(struct { 37 | Type string 38 | Data string 39 | }{ 40 | Type: "SASLResponse", 41 | Data: string(src.Data), 42 | }) 43 | } 44 | 45 | // UnmarshalJSON implements encoding/json.Unmarshaler. 46 | func (dst *SASLResponse) UnmarshalJSON(data []byte) error { 47 | var msg struct { 48 | Data string 49 | } 50 | if err := json.Unmarshal(data, &msg); err != nil { 51 | return err 52 | } 53 | if msg.Data != "" { 54 | decoded, err := hex.DecodeString(msg.Data) 55 | if err != nil { 56 | return err 57 | } 58 | dst.Data = decoded 59 | } 60 | return nil 61 | } 62 | -------------------------------------------------------------------------------- /pgproto3/ssl_request.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "encoding/binary" 5 | "encoding/json" 6 | "errors" 7 | 8 | "github.com/jackc/pgx/v5/internal/pgio" 9 | ) 10 | 11 | const sslRequestNumber = 80877103 12 | 13 | type SSLRequest struct { 14 | } 15 | 16 | // Frontend identifies this message as sendable by a PostgreSQL frontend. 
17 | func (*SSLRequest) Frontend() {} 18 | 19 | func (dst *SSLRequest) Decode(src []byte) error { 20 | if len(src) < 4 { 21 | return errors.New("ssl request too short") 22 | } 23 | 24 | requestCode := binary.BigEndian.Uint32(src) 25 | 26 | if requestCode != sslRequestNumber { 27 | return errors.New("bad ssl request code") 28 | } 29 | 30 | return nil 31 | } 32 | 33 | // Encode encodes src into dst. dst will include the 4 byte message length. 34 | func (src *SSLRequest) Encode(dst []byte) []byte { 35 | dst = pgio.AppendInt32(dst, 8) 36 | dst = pgio.AppendInt32(dst, sslRequestNumber) 37 | return dst 38 | } 39 | 40 | // MarshalJSON implements encoding/json.Marshaler. 41 | func (src SSLRequest) MarshalJSON() ([]byte, error) { 42 | return json.Marshal(struct { 43 | Type string 44 | ProtocolVersion uint32 45 | Parameters map[string]string 46 | }{ 47 | Type: "SSLRequest", 48 | }) 49 | } 50 | -------------------------------------------------------------------------------- /pgproto3/startup_message.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "encoding/json" 7 | "errors" 8 | "fmt" 9 | 10 | "github.com/jackc/pgx/v5/internal/pgio" 11 | ) 12 | 13 | const ProtocolVersionNumber = 196608 // 3.0 14 | 15 | type StartupMessage struct { 16 | ProtocolVersion uint32 17 | Parameters map[string]string 18 | } 19 | 20 | // Frontend identifies this message as sendable by a PostgreSQL frontend. 21 | func (*StartupMessage) Frontend() {} 22 | 23 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 24 | // type identifier and 4 byte message length. 
25 | func (dst *StartupMessage) Decode(src []byte) error { 26 | if len(src) < 4 { 27 | return errors.New("startup message too short") 28 | } 29 | 30 | dst.ProtocolVersion = binary.BigEndian.Uint32(src) 31 | rp := 4 32 | 33 | if dst.ProtocolVersion != ProtocolVersionNumber { 34 | return fmt.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion) 35 | } 36 | 37 | dst.Parameters = make(map[string]string) 38 | for { 39 | idx := bytes.IndexByte(src[rp:], 0) 40 | if idx < 0 { 41 | return &invalidMessageFormatErr{messageType: "StartupMesage"} 42 | } 43 | key := string(src[rp : rp+idx]) 44 | rp += idx + 1 45 | 46 | idx = bytes.IndexByte(src[rp:], 0) 47 | if idx < 0 { 48 | return &invalidMessageFormatErr{messageType: "StartupMesage"} 49 | } 50 | value := string(src[rp : rp+idx]) 51 | rp += idx + 1 52 | 53 | dst.Parameters[key] = value 54 | 55 | if len(src[rp:]) == 1 { 56 | if src[rp] != 0 { 57 | return fmt.Errorf("Bad startup message last byte. Expected 0, got %d", src[rp]) 58 | } 59 | break 60 | } 61 | } 62 | 63 | return nil 64 | } 65 | 66 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 67 | func (src *StartupMessage) Encode(dst []byte) []byte { 68 | sp := len(dst) 69 | dst = pgio.AppendInt32(dst, -1) 70 | 71 | dst = pgio.AppendUint32(dst, src.ProtocolVersion) 72 | for k, v := range src.Parameters { 73 | dst = append(dst, k...) 74 | dst = append(dst, 0) 75 | dst = append(dst, v...) 76 | dst = append(dst, 0) 77 | } 78 | dst = append(dst, 0) 79 | 80 | pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) 81 | 82 | return dst 83 | } 84 | 85 | // MarshalJSON implements encoding/json.Marshaler. 
86 | func (src StartupMessage) MarshalJSON() ([]byte, error) { 87 | return json.Marshal(struct { 88 | Type string 89 | ProtocolVersion uint32 90 | Parameters map[string]string 91 | }{ 92 | Type: "StartupMessage", 93 | ProtocolVersion: src.ProtocolVersion, 94 | Parameters: src.Parameters, 95 | }) 96 | } 97 | -------------------------------------------------------------------------------- /pgproto3/sync.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "encoding/json" 5 | ) 6 | 7 | type Sync struct{} 8 | 9 | // Frontend identifies this message as sendable by a PostgreSQL frontend. 10 | func (*Sync) Frontend() {} 11 | 12 | // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message 13 | // type identifier and 4 byte message length. 14 | func (dst *Sync) Decode(src []byte) error { 15 | if len(src) != 0 { 16 | return &invalidMessageLenErr{messageType: "Sync", expectedLen: 0, actualLen: len(src)} 17 | } 18 | 19 | return nil 20 | } 21 | 22 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 23 | func (src *Sync) Encode(dst []byte) []byte { 24 | return append(dst, 'S', 0, 0, 0, 4) 25 | } 26 | 27 | // MarshalJSON implements encoding/json.Marshaler. 28 | func (src Sync) MarshalJSON() ([]byte, error) { 29 | return json.Marshal(struct { 30 | Type string 31 | }{ 32 | Type: "Sync", 33 | }) 34 | } 35 | -------------------------------------------------------------------------------- /pgproto3/terminate.go: -------------------------------------------------------------------------------- 1 | package pgproto3 2 | 3 | import ( 4 | "encoding/json" 5 | ) 6 | 7 | type Terminate struct{} 8 | 9 | // Frontend identifies this message as sendable by a PostgreSQL frontend. 10 | func (*Terminate) Frontend() {} 11 | 12 | // Decode decodes src into dst. 
src must contain the complete message with the exception of the initial 1 byte message 13 | // type identifier and 4 byte message length. 14 | func (dst *Terminate) Decode(src []byte) error { 15 | if len(src) != 0 { 16 | return &invalidMessageLenErr{messageType: "Terminate", expectedLen: 0, actualLen: len(src)} 17 | } 18 | 19 | return nil 20 | } 21 | 22 | // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 23 | func (src *Terminate) Encode(dst []byte) []byte { 24 | return append(dst, 'X', 0, 0, 0, 4) 25 | } 26 | 27 | // MarshalJSON implements encoding/json.Marshaler. 28 | func (src Terminate) MarshalJSON() ([]byte, error) { 29 | return json.Marshal(struct { 30 | Type string 31 | }{ 32 | Type: "Terminate", 33 | }) 34 | } 35 | -------------------------------------------------------------------------------- /pgproto3/testdata/fuzz/FuzzFrontend/39c5e864da4707fc15fea48f7062d6a07796fdc43b33e0ba9dbd7074a0211fa6: -------------------------------------------------------------------------------- 1 | go test fuzz v1 2 | byte('A') 3 | uint32(5) 4 | []byte("0") 5 | -------------------------------------------------------------------------------- /pgproto3/testdata/fuzz/FuzzFrontend/9b06792b1aaac8a907dbfa04d526ae14326c8573b7409032caac8461e83065f7: -------------------------------------------------------------------------------- 1 | go test fuzz v1 2 | byte('D') 3 | uint32(21) 4 | []byte("00\xb300000000000000") 5 | -------------------------------------------------------------------------------- /pgproto3/testdata/fuzz/FuzzFrontend/a661fb98e802839f0a7361160fbc6e28794612a411d00bde104364ee281c4214: -------------------------------------------------------------------------------- 1 | go test fuzz v1 2 | byte('C') 3 | uint32(4) 4 | []byte("0") 5 | -------------------------------------------------------------------------------- 
/pgproto3/testdata/fuzz/FuzzFrontend/fc98dcd487a5173b38763a5f7dd023933f3a86ab566e3f2b091eb36248107eb4: -------------------------------------------------------------------------------- 1 | go test fuzz v1 2 | byte('R') 3 | uint32(13) 4 | []byte("\x00\x00\x00\n0\x12\xebG\x8dI']G\xdac\x95\xb7\x18\xb0\x02\xe8m\xc2\x00\xef\x03\x12\x1b\xbdj\x10\x9f\xf9\xeb\xb8") 5 | -------------------------------------------------------------------------------- /pgproto3/trace_test.go: -------------------------------------------------------------------------------- 1 | package pgproto3_test 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "os" 7 | "testing" 8 | "time" 9 | 10 | "github.com/jackc/pgx/v5/pgconn" 11 | "github.com/jackc/pgx/v5/pgproto3" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | func TestTrace(t *testing.T) { 16 | t.Parallel() 17 | 18 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 19 | defer cancel() 20 | 21 | conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) 22 | require.NoError(t, err) 23 | defer conn.Close(ctx) 24 | 25 | if conn.ParameterStatus("crdb_version") != "" { 26 | t.Skip("Skipping message trace on CockroachDB as it varies slightly from PostgreSQL") 27 | } 28 | 29 | traceOutput := &bytes.Buffer{} 30 | conn.Frontend().Trace(traceOutput, pgproto3.TracerOptions{ 31 | SuppressTimestamps: true, 32 | RegressMode: true, 33 | }) 34 | 35 | result := conn.ExecParams(ctx, "select n from generate_series(1,5) n", nil, nil, nil, nil).Read() 36 | require.NoError(t, result.Err) 37 | 38 | expected := `F Parse 45 "" "select n from generate_series(1,5) n" 0 39 | F Bind 13 "" "" 0 0 0 40 | F Describe 7 P "" 41 | F Execute 10 "" 0 42 | F Sync 5 43 | B ParseComplete 5 44 | B BindComplete 5 45 | B RowDescription 27 1 "n" 0 0 23 4 -1 0 46 | B DataRow 12 1 1 '1' 47 | B DataRow 12 1 1 '2' 48 | B DataRow 12 1 1 '3' 49 | B DataRow 12 1 1 '4' 50 | B DataRow 12 1 1 '5' 51 | B CommandComplete 14 "SELECT 5" 52 | B ReadyForQuery 6 I 53 | ` 
54 | 55 | require.Equal(t, expected, traceOutput.String()) 56 | } 57 | -------------------------------------------------------------------------------- /pgtype/bits_test.go: -------------------------------------------------------------------------------- 1 | package pgtype_test 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "testing" 7 | 8 | "github.com/jackc/pgx/v5/pgtype" 9 | "github.com/jackc/pgx/v5/pgxtest" 10 | ) 11 | 12 | func isExpectedEqBits(a any) func(any) bool { 13 | return func(v any) bool { 14 | ab := a.(pgtype.Bits) 15 | vb := v.(pgtype.Bits) 16 | return bytes.Compare(ab.Bytes, vb.Bytes) == 0 && ab.Len == vb.Len && ab.Valid == vb.Valid 17 | } 18 | } 19 | 20 | func TestBitsCodecBit(t *testing.T) { 21 | pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "bit(40)", []pgxtest.ValueRoundTripTest{ 22 | { 23 | pgtype.Bits{Bytes: []byte{0, 0, 0, 0, 0}, Len: 40, Valid: true}, 24 | new(pgtype.Bits), 25 | isExpectedEqBits(pgtype.Bits{Bytes: []byte{0, 0, 0, 0, 0}, Len: 40, Valid: true}), 26 | }, 27 | { 28 | pgtype.Bits{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Valid: true}, 29 | new(pgtype.Bits), 30 | isExpectedEqBits(pgtype.Bits{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Valid: true}), 31 | }, 32 | {pgtype.Bits{}, new(pgtype.Bits), isExpectedEqBits(pgtype.Bits{})}, 33 | {nil, new(pgtype.Bits), isExpectedEqBits(pgtype.Bits{})}, 34 | }) 35 | } 36 | 37 | func TestBitsCodecVarbit(t *testing.T) { 38 | pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "varbit", []pgxtest.ValueRoundTripTest{ 39 | { 40 | pgtype.Bits{Bytes: []byte{}, Len: 0, Valid: true}, 41 | new(pgtype.Bits), 42 | isExpectedEqBits(pgtype.Bits{Bytes: []byte{}, Len: 0, Valid: true}), 43 | }, 44 | { 45 | pgtype.Bits{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Valid: true}, 46 | new(pgtype.Bits), 47 | isExpectedEqBits(pgtype.Bits{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Valid: true}), 48 | }, 49 | { 50 | pgtype.Bits{Bytes: 
[]byte{0, 1, 128, 254, 128}, Len: 33, Valid: true}, 51 | new(pgtype.Bits), 52 | isExpectedEqBits(pgtype.Bits{Bytes: []byte{0, 1, 128, 254, 128}, Len: 33, Valid: true}), 53 | }, 54 | {pgtype.Bits{}, new(pgtype.Bits), isExpectedEqBits(pgtype.Bits{})}, 55 | {nil, new(pgtype.Bits), isExpectedEqBits(pgtype.Bits{})}, 56 | }) 57 | } 58 | -------------------------------------------------------------------------------- /pgtype/bool_test.go: -------------------------------------------------------------------------------- 1 | package pgtype_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/jackc/pgx/v5/pgtype" 8 | "github.com/jackc/pgx/v5/pgxtest" 9 | ) 10 | 11 | func TestBoolCodec(t *testing.T) { 12 | pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "bool", []pgxtest.ValueRoundTripTest{ 13 | {true, new(bool), isExpectedEq(true)}, 14 | {false, new(bool), isExpectedEq(false)}, 15 | {true, new(pgtype.Bool), isExpectedEq(pgtype.Bool{Bool: true, Valid: true})}, 16 | {pgtype.Bool{}, new(pgtype.Bool), isExpectedEq(pgtype.Bool{})}, 17 | {nil, new(*bool), isExpectedEq((*bool)(nil))}, 18 | }) 19 | } 20 | 21 | func TestBoolMarshalJSON(t *testing.T) { 22 | successfulTests := []struct { 23 | source pgtype.Bool 24 | result string 25 | }{ 26 | {source: pgtype.Bool{}, result: "null"}, 27 | {source: pgtype.Bool{Bool: true, Valid: true}, result: "true"}, 28 | {source: pgtype.Bool{Bool: false, Valid: true}, result: "false"}, 29 | } 30 | for i, tt := range successfulTests { 31 | r, err := tt.source.MarshalJSON() 32 | if err != nil { 33 | t.Errorf("%d: %v", i, err) 34 | } 35 | 36 | if string(r) != tt.result { 37 | t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) 38 | } 39 | } 40 | } 41 | 42 | func TestBoolUnmarshalJSON(t *testing.T) { 43 | successfulTests := []struct { 44 | source string 45 | result pgtype.Bool 46 | }{ 47 | {source: "null", result: pgtype.Bool{}}, 48 | {source: "true", result: 
pgtype.Bool{Bool: true, Valid: true}}, 49 | {source: "false", result: pgtype.Bool{Bool: false, Valid: true}}, 50 | } 51 | for i, tt := range successfulTests { 52 | var r pgtype.Bool 53 | err := r.UnmarshalJSON([]byte(tt.source)) 54 | if err != nil { 55 | t.Errorf("%d: %v", i, err) 56 | } 57 | 58 | if r != tt.result { 59 | t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) 60 | } 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /pgtype/box_test.go: -------------------------------------------------------------------------------- 1 | package pgtype_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/jackc/pgx/v5/pgtype" 8 | "github.com/jackc/pgx/v5/pgxtest" 9 | ) 10 | 11 | func TestBoxCodec(t *testing.T) { 12 | skipCockroachDB(t, "Server does not support box type") 13 | 14 | pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "box", []pgxtest.ValueRoundTripTest{ 15 | { 16 | pgtype.Box{ 17 | P: [2]pgtype.Vec2{{7.1, 5.2345678}, {3.14, 1.678}}, 18 | Valid: true, 19 | }, 20 | new(pgtype.Box), 21 | isExpectedEq(pgtype.Box{ 22 | P: [2]pgtype.Vec2{{7.1, 5.2345678}, {3.14, 1.678}}, 23 | Valid: true, 24 | }), 25 | }, 26 | { 27 | pgtype.Box{ 28 | P: [2]pgtype.Vec2{{7.1, 5.2345678}, {-13.14, -5.234}}, 29 | Valid: true, 30 | }, 31 | new(pgtype.Box), 32 | isExpectedEq(pgtype.Box{ 33 | P: [2]pgtype.Vec2{{7.1, 5.2345678}, {-13.14, -5.234}}, 34 | Valid: true, 35 | }), 36 | }, 37 | {pgtype.Box{}, new(pgtype.Box), isExpectedEq(pgtype.Box{})}, 38 | {nil, new(pgtype.Box), isExpectedEq(pgtype.Box{})}, 39 | }) 40 | } 41 | -------------------------------------------------------------------------------- /pgtype/circle_test.go: -------------------------------------------------------------------------------- 1 | package pgtype_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/jackc/pgx/v5/pgtype" 8 | "github.com/jackc/pgx/v5/pgxtest" 9 | 
) 10 | 11 | func TestCircleTranscode(t *testing.T) { 12 | skipCockroachDB(t, "Server does not support box type") 13 | 14 | pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "circle", []pgxtest.ValueRoundTripTest{ 15 | { 16 | pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Valid: true}, 17 | new(pgtype.Circle), 18 | isExpectedEq(pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Valid: true}), 19 | }, 20 | { 21 | pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Valid: true}, 22 | new(pgtype.Circle), 23 | isExpectedEq(pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Valid: true}), 24 | }, 25 | {pgtype.Circle{}, new(pgtype.Circle), isExpectedEq(pgtype.Circle{})}, 26 | {nil, new(pgtype.Circle), isExpectedEq(pgtype.Circle{})}, 27 | }) 28 | } 29 | -------------------------------------------------------------------------------- /pgtype/enum_codec_test.go: -------------------------------------------------------------------------------- 1 | package pgtype_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | pgx "github.com/jackc/pgx/v5" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func TestEnumCodec(t *testing.T) { 12 | defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { 13 | 14 | _, err := conn.Exec(ctx, `drop type if exists enum_test; 15 | 16 | create type enum_test as enum ('foo', 'bar', 'baz');`) 17 | require.NoError(t, err) 18 | defer conn.Exec(ctx, "drop type enum_test") 19 | 20 | dt, err := conn.LoadType(ctx, "enum_test") 21 | require.NoError(t, err) 22 | 23 | conn.TypeMap().RegisterType(dt) 24 | 25 | var s string 26 | err = conn.QueryRow(ctx, `select 'foo'::enum_test`).Scan(&s) 27 | require.NoError(t, err) 28 | require.Equal(t, "foo", s) 29 | 30 | err = conn.QueryRow(ctx, `select $1::enum_test`, "bar").Scan(&s) 31 | require.NoError(t, err) 32 | require.Equal(t, "bar", s) 33 | 34 | err = conn.QueryRow(ctx, `select 
'foo'::enum_test`).Scan(&s) 35 | require.NoError(t, err) 36 | require.Equal(t, "foo", s) 37 | 38 | err = conn.QueryRow(ctx, `select $1::enum_test`, "bar").Scan(&s) 39 | require.NoError(t, err) 40 | require.Equal(t, "bar", s) 41 | 42 | err = conn.QueryRow(ctx, `select 'baz'::enum_test`).Scan(&s) 43 | require.NoError(t, err) 44 | require.Equal(t, "baz", s) 45 | }) 46 | } 47 | 48 | func TestEnumCodecValues(t *testing.T) { 49 | defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { 50 | 51 | _, err := conn.Exec(ctx, `drop type if exists enum_test; 52 | 53 | create type enum_test as enum ('foo', 'bar', 'baz');`) 54 | require.NoError(t, err) 55 | defer conn.Exec(ctx, "drop type enum_test") 56 | 57 | dt, err := conn.LoadType(ctx, "enum_test") 58 | require.NoError(t, err) 59 | 60 | conn.TypeMap().RegisterType(dt) 61 | 62 | rows, err := conn.Query(ctx, `select 'foo'::enum_test`) 63 | require.NoError(t, err) 64 | require.True(t, rows.Next()) 65 | values, err := rows.Values() 66 | require.NoError(t, err) 67 | require.Equal(t, values, []any{"foo"}) 68 | }) 69 | } 70 | -------------------------------------------------------------------------------- /pgtype/example_child_records_test.go: -------------------------------------------------------------------------------- 1 | package pgtype_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "time" 8 | 9 | "github.com/jackc/pgx/v5" 10 | ) 11 | 12 | type Player struct { 13 | Name string 14 | Position string 15 | } 16 | 17 | type Team struct { 18 | Name string 19 | Players []Player 20 | } 21 | 22 | // This example uses a single query to return parent and child records. 
23 | func Example_childRecords() { 24 | ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) 25 | defer cancel() 26 | 27 | conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) 28 | if err != nil { 29 | fmt.Printf("Unable to establish connection: %v", err) 30 | return 31 | } 32 | defer conn.Close(ctx) // deferred after cancel(), so it runs first while ctx is still live 33 | if conn.PgConn().ParameterStatus("crdb_version") != "" { 34 | // Skip test / example when running on CockroachDB. Since an example can't be skipped fake success instead. 35 | fmt.Println(`Alpha 36 | Adam: wing 37 | Bill: halfback 38 | Charlie: fullback 39 | Beta 40 | Don: halfback 41 | Edgar: halfback 42 | Frank: fullback`) 43 | return 44 | } 45 | 46 | // Setup example schema and data. 47 | _, err = conn.Exec(ctx, ` 48 | create temporary table teams ( 49 | name text primary key 50 | ); 51 | 52 | create temporary table players ( 53 | name text primary key, 54 | team_name text, 55 | position text 56 | ); 57 | 58 | insert into teams (name) values 59 | ('Alpha'), 60 | ('Beta'); 61 | 62 | insert into players (name, team_name, position) values 63 | ('Adam', 'Alpha', 'wing'), 64 | ('Bill', 'Alpha', 'halfback'), 65 | ('Charlie', 'Alpha', 'fullback'), 66 | ('Don', 'Beta', 'halfback'), 67 | ('Edgar', 'Beta', 'halfback'), 68 | ('Frank', 'Beta', 'fullback') 69 | `) 70 | if err != nil { 71 | fmt.Printf("Unable to setup example schema and data: %v", err) 72 | return 73 | } 74 | 75 | rows, _ := conn.Query(ctx, ` 76 | select t.name, 77 | (select array_agg(row(p.name, position) order by p.name) from players p where p.team_name = t.name) 78 | from teams t 79 | order by t.name 80 | `) 81 | teams, err := pgx.CollectRows(rows, pgx.RowToStructByPos[Team]) 82 | if err != nil { 83 | fmt.Printf("CollectRows error: %v", err) 84 | return 85 | } 86 | 87 | for _, team := range teams { 88 | fmt.Println(team.Name) 89 | for _, player := range team.Players { 90 | fmt.Printf(" %s: %s\n", player.Name, player.Position) 91 | } 92 | } 93 | 94 | // Output: 95 | // Alpha 96 | // Adam: wing 97 |
// Bill: halfback 98 | // Charlie: fullback 99 | // Beta 100 | // Don: halfback 101 | // Edgar: halfback 102 | // Frank: fullback 103 | } 104 | -------------------------------------------------------------------------------- /pgtype/example_custom_type_test.go: -------------------------------------------------------------------------------- 1 | package pgtype_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | 8 | "github.com/jackc/pgx/v5" 9 | "github.com/jackc/pgx/v5/pgtype" 10 | ) 11 | 12 | // Point represents a point that may be null. 13 | type Point struct { 14 | X, Y float32 // Coordinates of point 15 | Valid bool 16 | } 17 | 18 | func (p *Point) ScanPoint(v pgtype.Point) error { 19 | *p = Point{ 20 | X: float32(v.P.X), 21 | Y: float32(v.P.Y), 22 | Valid: v.Valid, 23 | } 24 | return nil 25 | } 26 | 27 | func (p Point) PointValue() (pgtype.Point, error) { 28 | return pgtype.Point{ 29 | P: pgtype.Vec2{X: float64(p.X), Y: float64(p.Y)}, 30 | Valid: p.Valid, // a null Point reports Valid: false instead of being sent as a valid (0,0) 31 | }, nil 32 | } 33 | 34 | func (p *Point) String() string { 35 | if !p.Valid { 36 | return "null point" 37 | } 38 | 39 | return fmt.Sprintf("%.1f, %.1f", p.X, p.Y) 40 | } 41 | 42 | func Example_customType() { 43 | conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) 44 | if err != nil { 45 | fmt.Printf("Unable to establish connection: %v", err) 46 | return 47 | } 48 | defer conn.Close(context.Background()) 49 | 50 | if conn.PgConn().ParameterStatus("crdb_version") != "" { 51 | // Skip test / example when running on CockroachDB which doesn't support the point type. Since an example can't be 52 | // skipped fake success instead.
53 | fmt.Println("null point") 54 | fmt.Println("1.5, 2.5") 55 | return 56 | } 57 | 58 | p := &Point{} 59 | err = conn.QueryRow(context.Background(), "select null::point").Scan(p) 60 | if err != nil { 61 | fmt.Println(err) 62 | return 63 | } 64 | fmt.Println(p) 65 | 66 | err = conn.QueryRow(context.Background(), "select point(1.5,2.5)").Scan(p) 67 | if err != nil { 68 | fmt.Println(err) 69 | return 70 | } 71 | fmt.Println(p) 72 | // Output: 73 | // null point 74 | // 1.5, 2.5 75 | } 76 | -------------------------------------------------------------------------------- /pgtype/example_json_test.go: -------------------------------------------------------------------------------- 1 | package pgtype_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | 8 | "github.com/jackc/pgx/v5" 9 | ) 10 | 11 | func Example_json() { 12 | conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) 13 | if err != nil { 14 | fmt.Printf("Unable to establish connection: %v", err) 15 | return 16 | } 17 | defer conn.Close(context.Background()) 18 | type person struct { 19 | Name string `json:"name"` 20 | Age int `json:"age"` 21 | } 22 | 23 | input := person{ 24 | Name: "John", 25 | Age: 42, 26 | } 27 | 28 | var output person 29 | 30 | err = conn.QueryRow(context.Background(), "select $1::json", input).Scan(&output) 31 | if err != nil { 32 | fmt.Println(err) 33 | return 34 | } 35 | 36 | fmt.Println(output.Name, output.Age) 37 | // Output: 38 | // John 42 39 | } 40 | -------------------------------------------------------------------------------- /pgtype/float4_test.go: -------------------------------------------------------------------------------- 1 | package pgtype_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/jackc/pgx/v5/pgtype" 8 | "github.com/jackc/pgx/v5/pgxtest" 9 | ) 10 | 11 | func TestFloat4Codec(t *testing.T) { 12 | pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "float4", []pgxtest.ValueRoundTripTest{ 13 | {pgtype.Float4{Float32: -1,
Valid: true}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{Float32: -1, Valid: true})}, 14 | {pgtype.Float4{Float32: 0, Valid: true}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{Float32: 0, Valid: true})}, 15 | {pgtype.Float4{Float32: 1, Valid: true}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{Float32: 1, Valid: true})}, 16 | {float32(0.00001), new(float32), isExpectedEq(float32(0.00001))}, 17 | {float32(9999.99), new(float32), isExpectedEq(float32(9999.99))}, 18 | {pgtype.Float4{}, new(pgtype.Float4), isExpectedEq(pgtype.Float4{})}, 19 | {int64(1), new(int64), isExpectedEq(int64(1))}, // Go integers also round-trip through the float4 type 20 | {"1.23", new(string), isExpectedEq("1.23")}, 21 | {nil, new(*float32), isExpectedEq((*float32)(nil))}, 22 | }) 23 | } 24 | -------------------------------------------------------------------------------- /pgtype/float8_test.go: -------------------------------------------------------------------------------- 1 | package pgtype_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/jackc/pgx/v5/pgtype" 8 | "github.com/jackc/pgx/v5/pgxtest" 9 | ) 10 | 11 | func TestFloat8Codec(t *testing.T) { 12 | pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "float8", []pgxtest.ValueRoundTripTest{ 13 | {pgtype.Float8{Float64: -1, Valid: true}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{Float64: -1, Valid: true})}, 14 | {pgtype.Float8{Float64: 0, Valid: true}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{Float64: 0, Valid: true})}, 15 | {pgtype.Float8{Float64: 1, Valid: true}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{Float64: 1, Valid: true})}, 16 | {float64(0.00001), new(float64), isExpectedEq(float64(0.00001))}, 17 | {float64(9999.99), new(float64), isExpectedEq(float64(9999.99))}, 18 | {pgtype.Float8{}, new(pgtype.Float8), isExpectedEq(pgtype.Float8{})}, 19 | {int64(1), new(int64), isExpectedEq(int64(1))}, 20 | {"1.23", new(string), isExpectedEq("1.23")}, 21 | {nil, new(*float64), isExpectedEq((*float64)(nil))}, 22 | }) 23
| } 24 | -------------------------------------------------------------------------------- /pgtype/integration_benchmark_test.go.erb: -------------------------------------------------------------------------------- 1 | package pgtype_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/jackc/pgx/v5/pgtype/testutil" 8 | "github.com/jackc/pgx/v5" 9 | ) 10 | 11 | <% 12 | [ 13 | ["int4", ["int16", "int32", "int64", "uint64", "pgtype.Int4"], [[1, 1], [1, 10], [10, 1], [100, 10]]], 14 | ["numeric", ["int64", "float64", "pgtype.Numeric"], [[1, 1], [1, 10], [10, 1], [100, 10]]], 15 | ].each do |pg_type, go_types, rows_columns| 16 | %> 17 | <% go_types.each do |go_type| %> 18 | <% rows_columns.each do |rows, columns| %> 19 | <% [["Text", "pgx.TextFormatCode"], ["Binary", "pgx.BinaryFormatCode"]].each do |format_name, format_code| %> 20 | func BenchmarkQuery<%= format_name %>FormatDecode_PG_<%= pg_type %>_to_Go_<%= go_type.gsub(/\W/, "_") %>_<%= rows %>_rows_<%= columns %>_columns(b *testing.B) { 21 | defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { 22 | b.ResetTimer() 23 | var v [<%= columns %>]<%= go_type %> 24 | for i := 0; i < b.N; i++ { 25 | rows, _ := conn.Query( 26 | ctx, 27 | `select <% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>n::<%= pg_type %> + <%= col_idx%><% end %> from generate_series(1, <%= rows %>) n`, 28 | pgx.QueryResultFormats{<%= format_code %>}, // passed as a top-level arg; nested in []any it would be an ordinary query argument, not a result-format option 29 | ) 30 | _, err := pgx.ForEachRow(rows, []any{<% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>&v[<%= col_idx%>]<% end %>}, func() error { return nil }) 31 | if err != nil { 32 | b.Fatal(err) 33 | } 34 | } 35 | }) 36 | } 37 | <% end %> 38 | <% end %> 39 | <% end %> 40 | <% end %> 41 | 42 | <% [10, 100, 1000].each do |array_size| %> 43 | <% [["Text", "pgx.TextFormatCode"], ["Binary", "pgx.BinaryFormatCode"]].each do |format_name, format_code| %> 44 | func BenchmarkQuery<%= format_name
%>FormatDecode_PG_Int4Array_With_Go_Int4Array_<%= array_size %>(b *testing.B) { 45 | defaultConnTestRunner.RunTest(context.Background(), b, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { 46 | b.ResetTimer() 47 | var v []int32 48 | for i := 0; i < b.N; i++ { 49 | rows, _ := conn.Query( 50 | ctx, 51 | `select array_agg(n) from generate_series(1, <%= array_size %>) n`, 52 | pgx.QueryResultFormats{<%= format_code %>}, // passed as a top-level arg so Query treats it as a result-format option 53 | ) 54 | _, err := pgx.ForEachRow(rows, []any{&v}, func() error { return nil }) 55 | if err != nil { 56 | b.Fatal(err) 57 | } 58 | } 59 | }) 60 | } 61 | <% end %> 62 | <% end %> 63 | -------------------------------------------------------------------------------- /pgtype/integration_benchmark_test_gen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | set -e 3 | erb integration_benchmark_test.go.erb > integration_benchmark_test.go 4 | goimports -w integration_benchmark_test.go 5 | -------------------------------------------------------------------------------- /pgtype/jsonb_test.go: -------------------------------------------------------------------------------- 1 | package pgtype_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/jackc/pgx/v5/pgxtest" 8 | ) 9 | 10 | func TestJSONBTranscode(t *testing.T) { 11 | type jsonStruct struct { 12 | Name string `json:"name"` 13 | Age int `json:"age"` 14 | } 15 | 16 | pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "jsonb", []pgxtest.ValueRoundTripTest{ 17 | {nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))}, 18 | {map[string]any(nil), new(*string), isExpectedEq((*string)(nil))}, 19 | {map[string]any(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, 20 | {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, 21 | {nil, new([]byte), isExpectedEqBytes([]byte(nil))}, 22 | }) 23 | 24 | pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "jsonb",
[]pgxtest.ValueRoundTripTest{ 25 | {[]byte("{}"), new([]byte), isExpectedEqBytes([]byte("{}"))}, 26 | {[]byte("null"), new([]byte), isExpectedEqBytes([]byte("null"))}, 27 | {[]byte("42"), new([]byte), isExpectedEqBytes([]byte("42"))}, 28 | {[]byte(`"hello"`), new([]byte), isExpectedEqBytes([]byte(`"hello"`))}, 29 | {[]byte(`"hello"`), new(string), isExpectedEq(`"hello"`)}, 30 | {map[string]any{"foo": "bar"}, new(map[string]any), isExpectedEqMap(map[string]any{"foo": "bar"})}, 31 | {jsonStruct{Name: "Adam", Age: 10}, new(jsonStruct), isExpectedEq(jsonStruct{Name: "Adam", Age: 10})}, 32 | }) 33 | } 34 | -------------------------------------------------------------------------------- /pgtype/line_test.go: -------------------------------------------------------------------------------- 1 | package pgtype_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | pgx "github.com/jackc/pgx/v5" 8 | "github.com/jackc/pgx/v5/pgtype" 9 | "github.com/jackc/pgx/v5/pgxtest" 10 | ) 11 | 12 | func TestLineTranscode(t *testing.T) { 13 | ctr := defaultConnTestRunner 14 | ctr.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { 15 | pgxtest.SkipCockroachDB(t, conn, "Server does not support type line") 16 | 17 | if _, ok := conn.TypeMap().TypeForName("line"); !ok { 18 | t.Skip("Skipping due to no line type") 19 | } 20 | 21 | // line may exist but not be usable on 9.3 :( 22 | var isPG93 bool // from server_version_num: an unanchored regex like version() ~ '9.3' can match unrelated numbers in the version text 23 | err := conn.QueryRow(ctx, "select current_setting('server_version_num')::int / 100 = 903").Scan(&isPG93) 24 | if err != nil { 25 | t.Fatal(err) 26 | } 27 | if isPG93 { 28 | t.Skip("Skipping due to unimplemented line type in PG 9.3") 29 | } 30 | } 31 | 32 | pgxtest.RunValueRoundTripTests(context.Background(), t, ctr, nil, "line", []pgxtest.ValueRoundTripTest{ 33 | { 34 | pgtype.Line{ 35 | A: 1.23, B: 4.56, C: 7.89012345, 36 | Valid: true, 37 | }, 38 | new(pgtype.Line), 39 | isExpectedEq(pgtype.Line{ 40 | A: 1.23, B: 4.56, C: 7.89012345, 41 | Valid: true, 42 | }), 43 | }, 44 | { 45 |
pgtype.Line{ 46 | A: -1.23, B: -4.56, C: -7.89, 47 | Valid: true, 48 | }, 49 | new(pgtype.Line), 50 | isExpectedEq(pgtype.Line{ 51 | A: -1.23, B: -4.56, C: -7.89, 52 | Valid: true, 53 | }), 54 | }, 55 | {pgtype.Line{}, new(pgtype.Line), isExpectedEq(pgtype.Line{})}, 56 | {nil, new(pgtype.Line), isExpectedEq(pgtype.Line{})}, 57 | }) 58 | } 59 | -------------------------------------------------------------------------------- /pgtype/lseg_test.go: -------------------------------------------------------------------------------- 1 | package pgtype_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/jackc/pgx/v5/pgtype" 8 | "github.com/jackc/pgx/v5/pgxtest" 9 | ) 10 | 11 | func TestLsegTranscode(t *testing.T) { 12 | skipCockroachDB(t, "Server does not support type lseg") 13 | 14 | pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "lseg", []pgxtest.ValueRoundTripTest{ 15 | { 16 | pgtype.Lseg{ 17 | P: [2]pgtype.Vec2{{3.14, 1.678}, {7.1, 5.2345678901}}, 18 | Valid: true, 19 | }, 20 | new(pgtype.Lseg), 21 | isExpectedEq(pgtype.Lseg{ 22 | P: [2]pgtype.Vec2{{3.14, 1.678}, {7.1, 5.2345678901}}, 23 | Valid: true, 24 | }), 25 | }, 26 | { 27 | pgtype.Lseg{ 28 | P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, 29 | Valid: true, 30 | }, 31 | new(pgtype.Lseg), 32 | isExpectedEq(pgtype.Lseg{ 33 | P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, 34 | Valid: true, 35 | }), 36 | }, 37 | {pgtype.Lseg{}, new(pgtype.Lseg), isExpectedEq(pgtype.Lseg{})}, 38 | {nil, new(pgtype.Lseg), isExpectedEq(pgtype.Lseg{})}, // a nil parameter scans back into the zero (Valid: false) Lseg 39 | }) 40 | } 41 | -------------------------------------------------------------------------------- /pgtype/macaddr_test.go: -------------------------------------------------------------------------------- 1 | package pgtype_test 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "net" 7 | "testing" 8 | 9 | "github.com/jackc/pgx/v5/pgxtest" 10 | ) 11 | 12 | func isExpectedEqHardwareAddr(a any) func(any) bool { 13 | return func(v
any) bool { 14 | aa := a.(net.HardwareAddr) 15 | vv := v.(net.HardwareAddr) 16 | 17 | if (aa == nil) != (vv == nil) { // bytes.Equal treats nil and empty as equal, so NULL-ness is compared first 18 | return false 19 | } 20 | 21 | if aa == nil { 22 | return true 23 | } 24 | 25 | return bytes.Equal(aa, vv) 26 | } 27 | } 28 | 29 | func TestMacaddrCodec(t *testing.T) { 30 | skipCockroachDB(t, "Server does not support type macaddr") 31 | 32 | // Only testing known OID query exec modes as net.HardwareAddr could map to macaddr or macaddr8. 33 | pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "macaddr", []pgxtest.ValueRoundTripTest{ 34 | { 35 | mustParseMacaddr(t, "01:23:45:67:89:ab"), 36 | new(net.HardwareAddr), 37 | isExpectedEqHardwareAddr(mustParseMacaddr(t, "01:23:45:67:89:ab")), 38 | }, 39 | { 40 | "01:23:45:67:89:ab", 41 | new(net.HardwareAddr), 42 | isExpectedEqHardwareAddr(mustParseMacaddr(t, "01:23:45:67:89:ab")), 43 | }, 44 | { 45 | mustParseMacaddr(t, "01:23:45:67:89:ab"), 46 | new(string), 47 | isExpectedEq("01:23:45:67:89:ab"), 48 | }, 49 | {nil, new(*net.HardwareAddr), isExpectedEq((*net.HardwareAddr)(nil))}, 50 | }) 51 | } 52 | -------------------------------------------------------------------------------- /pgtype/path_test.go: -------------------------------------------------------------------------------- 1 | package pgtype_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/jackc/pgx/v5/pgtype" 8 | "github.com/jackc/pgx/v5/pgxtest" 9 | ) 10 | 11 | func isExpectedEqPath(a any) func(any) bool { 12 | return func(v any) bool { 13 | ap := a.(pgtype.Path) 14 | vp := v.(pgtype.Path) 15 | 16 | if !(ap.Valid == vp.Valid && ap.Closed == vp.Closed && len(ap.P) == len(vp.P)) { 17 | return false 18 | } 19 | 20 | for i := range ap.P { 21 | if ap.P[i] != vp.P[i] { 22 | return false 23 | } 24 | } 25 | 26 | return true 27 | } 28 | } 29 | 30 | func TestPathTranscode(t *testing.T) { 31 | skipCockroachDB(t, "Server does not support type path") 32 | 33 |
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "path", []pgxtest.ValueRoundTripTest{ 34 | { 35 | pgtype.Path{ 36 | P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}}, 37 | Closed: false, 38 | Valid: true, 39 | }, 40 | new(pgtype.Path), 41 | isExpectedEqPath(pgtype.Path{ 42 | P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}}, 43 | Closed: false, 44 | Valid: true, 45 | }), 46 | }, 47 | { 48 | pgtype.Path{ 49 | P: []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}, {23.1, 9.34}}, 50 | Closed: true, 51 | Valid: true, 52 | }, 53 | new(pgtype.Path), 54 | isExpectedEqPath(pgtype.Path{ 55 | P: []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}, {23.1, 9.34}}, 56 | Closed: true, 57 | Valid: true, 58 | }), 59 | }, 60 | { 61 | pgtype.Path{ 62 | P: []pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, 63 | Closed: true, 64 | Valid: true, 65 | }, 66 | new(pgtype.Path), 67 | isExpectedEqPath(pgtype.Path{ 68 | P: []pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, 69 | Closed: true, 70 | Valid: true, 71 | }), 72 | }, 73 | {pgtype.Path{}, new(pgtype.Path), isExpectedEqPath(pgtype.Path{})}, 74 | {nil, new(pgtype.Path), isExpectedEqPath(pgtype.Path{})}, // a nil parameter scans back into the zero (Valid: false) Path 75 | }) 76 | } 77 | -------------------------------------------------------------------------------- /pgtype/polygon_test.go: -------------------------------------------------------------------------------- 1 | package pgtype_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/jackc/pgx/v5/pgtype" 8 | "github.com/jackc/pgx/v5/pgxtest" 9 | ) 10 | 11 | func isExpectedEqPolygon(a any) func(any) bool { 12 | return func(v any) bool { 13 | ap := a.(pgtype.Polygon) 14 | vp := v.(pgtype.Polygon) 15 | 16 | if !(ap.Valid == vp.Valid && len(ap.P) == len(vp.P)) { 17 | return false 18 | } 19 | 20 | for i := range ap.P { 21 | if ap.P[i] != vp.P[i] { 22 | return false 23 | } 24 | } 25 | 26 | return true 27 | } 28 | } 29 | 30 | func TestPolygonTranscode(t *testing.T) { 31 | skipCockroachDB(t, "Server does not support type
polygon") 32 | 33 | pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "polygon", []pgxtest.ValueRoundTripTest{ 34 | { 35 | pgtype.Polygon{ 36 | P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}, {5.0, 3.234}}, 37 | Valid: true, 38 | }, 39 | new(pgtype.Polygon), 40 | isExpectedEqPolygon(pgtype.Polygon{ 41 | P: []pgtype.Vec2{{3.14, 1.678901234}, {7.1, 5.234}, {5.0, 3.234}}, 42 | Valid: true, 43 | }), 44 | }, 45 | { 46 | pgtype.Polygon{ 47 | P: []pgtype.Vec2{{3.14, -1.678}, {7.1, -5.234}, {23.1, 9.34}}, 48 | Valid: true, 49 | }, 50 | new(pgtype.Polygon), 51 | isExpectedEqPolygon(pgtype.Polygon{ 52 | P: []pgtype.Vec2{{3.14, -1.678}, {7.1, -5.234}, {23.1, 9.34}}, 53 | Valid: true, 54 | }), 55 | }, 56 | {pgtype.Polygon{}, new(pgtype.Polygon), isExpectedEqPolygon(pgtype.Polygon{})}, 57 | {nil, new(pgtype.Polygon), isExpectedEqPolygon(pgtype.Polygon{})}, 58 | }) 59 | } 60 | -------------------------------------------------------------------------------- /pgtype/qchar_test.go: -------------------------------------------------------------------------------- 1 | package pgtype_test 2 | 3 | import ( 4 | "context" 5 | "math" 6 | "testing" 7 | 8 | "github.com/jackc/pgx/v5/pgxtest" 9 | ) 10 | 11 | func TestQcharTranscode(t *testing.T) { 12 | skipCockroachDB(t, "Server does not support qchar") 13 | 14 | var tests []pgxtest.ValueRoundTripTest 15 | for i := 0; i <= math.MaxUint8; i++ { // every single-byte value, 0 through 255 16 | tests = append(tests, pgxtest.ValueRoundTripTest{rune(i), new(rune), isExpectedEq(rune(i))}) 17 | tests = append(tests, pgxtest.ValueRoundTripTest{byte(i), new(byte), isExpectedEq(byte(i))}) 18 | } 19 | tests = append(tests, pgxtest.ValueRoundTripTest{nil, new(*rune), isExpectedEq((*rune)(nil))}) 20 | tests = append(tests, pgxtest.ValueRoundTripTest{nil, new(*byte), isExpectedEq((*byte)(nil))}) 21 | 22 | // Can only test with known OIDs as rune and byte would be considered numbers.
23 | pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, `"char"`, tests) 24 | } 25 | -------------------------------------------------------------------------------- /pgtype/record_codec_test.go: -------------------------------------------------------------------------------- 1 | package pgtype_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | pgx "github.com/jackc/pgx/v5" 8 | "github.com/jackc/pgx/v5/pgtype" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func TestRecordCodec(t *testing.T) { 13 | defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { 14 | var a string 15 | var b int32 16 | err := conn.QueryRow(ctx, `select row('foo'::text, 42::int4)`).Scan(pgtype.CompositeFields{&a, &b}) 17 | require.NoError(t, err) 18 | 19 | require.Equal(t, "foo", a) 20 | require.Equal(t, int32(42), b) 21 | }) 22 | } 23 | 24 | func TestRecordCodecDecodeValue(t *testing.T) { 25 | skipCockroachDB(t, "Server converts row int4 to int8") 26 | 27 | defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { 28 | for _, tt := range []struct { 29 | sql string 30 | expected any 31 | }{ 32 | { 33 | sql: `select row()`, 34 | expected: []any{}, 35 | }, 36 | { 37 | sql: `select row('foo'::text, 42::int4)`, 38 | expected: []any{"foo", int32(42)}, 39 | }, 40 | { 41 | sql: `select row(100.0::float4, 1.09::float4)`, 42 | expected: []any{float32(100), float32(1.09)}, 43 | }, 44 | { 45 | sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`, 46 | expected: []any{"foo", []any{int32(1), int32(2), nil, int32(4)}, int32(42)}, 47 | }, 48 | { 49 | sql: `select row(null)`, 50 | expected: []any{nil}, 51 | }, 52 | { 53 | sql: `select null::record`, 54 | expected: nil, 55 | }, 56 | } { 57 | t.Run(tt.sql, func(t *testing.T) { 58 | rows, err := conn.Query(ctx, tt.sql) // use the ctx supplied by RunTest rather than a fresh context.Background() 59 |
require.NoError(t, err) 60 | defer rows.Close() 61 | 62 | for rows.Next() { 63 | values, err := rows.Values() 64 | require.NoError(t, err) 65 | require.Len(t, values, 1) 66 | require.Equal(t, tt.expected, values[0]) 67 | } 68 | 69 | require.NoError(t, rows.Err()) 70 | }) 71 | } 72 | }) 73 | } 74 | -------------------------------------------------------------------------------- /pgtype/register_default_pg_types.go: -------------------------------------------------------------------------------- 1 | //go:build !nopgxregisterdefaulttypes 2 | 3 | package pgtype 4 | 5 | func registerDefaultPgTypeVariants[T any](m *Map, name string) { 6 | arrayName := "_" + name // every slice/Array/FlatArray variant below is registered under the "_"-prefixed array type name 7 | 8 | var value T 9 | m.RegisterDefaultPgType(value, name) // T 10 | m.RegisterDefaultPgType(&value, name) // *T 11 | 12 | var sliceT []T 13 | m.RegisterDefaultPgType(sliceT, arrayName) // []T 14 | m.RegisterDefaultPgType(&sliceT, arrayName) // *[]T 15 | 16 | var slicePtrT []*T 17 | m.RegisterDefaultPgType(slicePtrT, arrayName) // []*T 18 | m.RegisterDefaultPgType(&slicePtrT, arrayName) // *[]*T 19 | 20 | var arrayOfT Array[T] 21 | m.RegisterDefaultPgType(arrayOfT, arrayName) // Array[T] 22 | m.RegisterDefaultPgType(&arrayOfT, arrayName) // *Array[T] 23 | 24 | var arrayOfPtrT Array[*T] 25 | m.RegisterDefaultPgType(arrayOfPtrT, arrayName) // Array[*T] 26 | m.RegisterDefaultPgType(&arrayOfPtrT, arrayName) // *Array[*T] 27 | 28 | var flatArrayOfT FlatArray[T] 29 | m.RegisterDefaultPgType(flatArrayOfT, arrayName) // FlatArray[T] 30 | m.RegisterDefaultPgType(&flatArrayOfT, arrayName) // *FlatArray[T] 31 | 32 | var flatArrayOfPtrT FlatArray[*T] 33 | m.RegisterDefaultPgType(flatArrayOfPtrT, arrayName) // FlatArray[*T] 34 | m.RegisterDefaultPgType(&flatArrayOfPtrT, arrayName) // *FlatArray[*T] 35 | } 36 | -------------------------------------------------------------------------------- /pgtype/register_default_pg_types_disabled.go: -------------------------------------------------------------------------------- 1 |
//go:build nopgxregisterdefaulttypes 2 | 3 | package pgtype 4 | 5 | func registerDefaultPgTypeVariants[T any](m *Map, name string) { 6 | } 7 | -------------------------------------------------------------------------------- /pgtype/text_format_only_codec.go: -------------------------------------------------------------------------------- 1 | package pgtype 2 | 3 | type TextFormatOnlyCodec struct { 4 | Codec 5 | } 6 | 7 | func (c *TextFormatOnlyCodec) FormatSupported(format int16) bool { 8 | return format == TextFormatCode && c.Codec.FormatSupported(format) 9 | } 10 | 11 | func (TextFormatOnlyCodec) PreferredFormat() int16 { 12 | return TextFormatCode 13 | } 14 | -------------------------------------------------------------------------------- /pgtype/tid_test.go: -------------------------------------------------------------------------------- 1 | package pgtype_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/jackc/pgx/v5/pgtype" 8 | "github.com/jackc/pgx/v5/pgxtest" 9 | ) 10 | 11 | func TestTIDCodec(t *testing.T) { 12 | skipCockroachDB(t, "Server does not support type tid") 13 | 14 | pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "tid", []pgxtest.ValueRoundTripTest{ 15 | { 16 | pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Valid: true}, 17 | new(pgtype.TID), 18 | isExpectedEq(pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Valid: true}), 19 | }, 20 | { 21 | pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Valid: true}, 22 | new(pgtype.TID), 23 | isExpectedEq(pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Valid: true}), 24 | }, 25 | { 26 | pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Valid: true}, 27 | new(string), 28 | isExpectedEq("(42,43)"), 29 | }, 30 | { 31 | pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Valid: true}, 32 | new(string), 33 | isExpectedEq("(4294967295,65535)"), 34 | }, 35 | {pgtype.TID{}, new(pgtype.TID), isExpectedEq(pgtype.TID{})}, 36 | {nil, new(pgtype.TID), 
isExpectedEq(pgtype.TID{})}, 37 | }) 38 | } 39 | -------------------------------------------------------------------------------- /pgtype/time_test.go: -------------------------------------------------------------------------------- 1 | package pgtype_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/jackc/pgx/v5/pgtype" 9 | "github.com/jackc/pgx/v5/pgxtest" 10 | ) 11 | 12 | func TestTimeCodec(t *testing.T) { 13 | pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "time", []pgxtest.ValueRoundTripTest{ 14 | { 15 | pgtype.Time{Microseconds: 0, Valid: true}, 16 | new(pgtype.Time), 17 | isExpectedEq(pgtype.Time{Microseconds: 0, Valid: true}), 18 | }, 19 | { 20 | pgtype.Time{Microseconds: 1, Valid: true}, 21 | new(pgtype.Time), 22 | isExpectedEq(pgtype.Time{Microseconds: 1, Valid: true}), 23 | }, 24 | { 25 | pgtype.Time{Microseconds: 86399999999, Valid: true}, // 23:59:59.999999 26 | new(pgtype.Time), 27 | isExpectedEq(pgtype.Time{Microseconds: 86399999999, Valid: true}), 28 | }, 29 | { 30 | pgtype.Time{Microseconds: 86400000000, Valid: true}, // 24:00:00 31 | new(pgtype.Time), 32 | isExpectedEq(pgtype.Time{Microseconds: 86400000000, Valid: true}), 33 | }, 34 | { 35 | time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), 36 | new(pgtype.Time), 37 | isExpectedEq(pgtype.Time{Microseconds: 0, Valid: true}), 38 | }, 39 | { 40 | pgtype.Time{Microseconds: 0, Valid: true}, 41 | new(time.Time), 42 | isExpectedEq(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)), 43 | }, 44 | {pgtype.Time{}, new(pgtype.Time), isExpectedEq(pgtype.Time{})}, 45 | {nil, new(pgtype.Time), isExpectedEq(pgtype.Time{})}, 46 | }) 47 | } 48 | -------------------------------------------------------------------------------- /pgtype/uint32_test.go: -------------------------------------------------------------------------------- 1 | package pgtype_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/jackc/pgx/v5/pgtype" 8 | "github.com/jackc/pgx/v5/pgxtest" 9 | ) 10 | 11 |
func TestUint32Codec(t *testing.T) { 12 | pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "oid", []pgxtest.ValueRoundTripTest{ 13 | { 14 | pgtype.Uint32{Uint32: pgtype.TextOID, Valid: true}, 15 | new(pgtype.Uint32), 16 | isExpectedEq(pgtype.Uint32{Uint32: pgtype.TextOID, Valid: true}), 17 | }, 18 | {pgtype.Uint32{}, new(pgtype.Uint32), isExpectedEq(pgtype.Uint32{})}, 19 | {nil, new(pgtype.Uint32), isExpectedEq(pgtype.Uint32{})}, 20 | }) 21 | } 22 | -------------------------------------------------------------------------------- /pgtype/zeronull/doc.go: -------------------------------------------------------------------------------- 1 | // Package zeronull contains types that automatically convert between database NULLs and Go zero values. 2 | /* 3 | Sometimes the distinction between a zero value and a NULL value is not useful at the application level. For example, 4 | an application may choose to store an empty string as NULL. In that case there is usually no application-level distinction 5 | between an empty string and a NULL string. Package zeronull implements types that seamlessly convert between PostgreSQL 6 | NULL and the Go zero value. 7 | 8 | It is recommended to convert values to these types at the point of use rather than to declare variables of these types. 9 | In the example below, middlename would be stored as a NULL.
10 | 11 | firstname := "John" 12 | middlename := "" 13 | lastname := "Smith" 14 | _, err := conn.Exec( 15 | ctx, 16 | "insert into people(firstname, middlename, lastname) values($1, $2, $3)", 17 | zeronull.Text(firstname), 18 | zeronull.Text(middlename), 19 | zeronull.Text(lastname), 20 | ) 21 | */ 22 | package zeronull 23 | -------------------------------------------------------------------------------- /pgtype/zeronull/float8.go: -------------------------------------------------------------------------------- 1 | package zeronull 2 | 3 | import ( 4 | "database/sql/driver" 5 | 6 | "github.com/jackc/pgx/v5/pgtype" 7 | ) 8 | 9 | type Float8 float64 10 | 11 | func (Float8) SkipUnderlyingTypePlan() {} 12 | 13 | // ScanFloat64 implements the Float64Scanner interface. 14 | func (f *Float8) ScanFloat64(n pgtype.Float8) error { 15 | if !n.Valid { 16 | *f = 0 17 | return nil 18 | } 19 | 20 | *f = Float8(n.Float64) 21 | 22 | return nil 23 | } 24 | // Float64Value returns f as a pgtype.Float8; a zero f is returned as NULL (Valid: false). 25 | func (f Float8) Float64Value() (pgtype.Float8, error) { 26 | if f == 0 { 27 | return pgtype.Float8{}, nil 28 | } 29 | return pgtype.Float8{Float64: float64(f), Valid: true}, nil 30 | } 31 | 32 | // Scan implements the database/sql Scanner interface. 33 | func (f *Float8) Scan(src any) error { 34 | if src == nil { 35 | *f = 0 36 | return nil 37 | } 38 | 39 | var nullable pgtype.Float8 40 | err := nullable.Scan(src) 41 | if err != nil { 42 | return err 43 | } 44 | 45 | *f = Float8(nullable.Float64) 46 | 47 | return nil 48 | } 49 | 50 | // Value implements the database/sql/driver Valuer interface.
51 | func (f Float8) Value() (driver.Value, error) { 52 | if f == 0 { 53 | return nil, nil 54 | } 55 | return float64(f), nil 56 | } 57 | -------------------------------------------------------------------------------- /pgtype/zeronull/float8_test.go: -------------------------------------------------------------------------------- 1 | package zeronull_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/jackc/pgx/v5/pgtype/zeronull" 8 | "github.com/jackc/pgx/v5/pgxtest" 9 | ) 10 | 11 | func isExpectedEq(a any) func(any) bool { 12 | return func(v any) bool { 13 | return a == v 14 | } 15 | } 16 | 17 | func TestFloat8Transcode(t *testing.T) { 18 | pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "float8", []pgxtest.ValueRoundTripTest{ 19 | { 20 | (zeronull.Float8)(1), 21 | new(zeronull.Float8), 22 | isExpectedEq((zeronull.Float8)(1)), 23 | }, 24 | { 25 | nil, 26 | new(zeronull.Float8), 27 | isExpectedEq((zeronull.Float8)(0)), 28 | }, 29 | { 30 | (zeronull.Float8)(0), 31 | new(any), 32 | isExpectedEq(nil), 33 | }, 34 | }) 35 | } 36 | -------------------------------------------------------------------------------- /pgtype/zeronull/int.go.erb: -------------------------------------------------------------------------------- 1 | package zeronull 2 | 3 | import ( 4 | "database/sql/driver" 5 | "fmt" 6 | "math" 7 | 8 | "github.com/jackc/pgx/v5/pgtype" 9 | ) 10 | 11 | <% [2, 4, 8].each do |pg_byte_size| %> 12 | <% pg_bit_size = pg_byte_size * 8 %> 13 | type Int<%= pg_byte_size %> int<%= pg_bit_size %> 14 | 15 | func (Int<%= pg_byte_size %>) SkipUnderlyingTypePlan() {} 16 | 17 | // ScanInt64 implements the Int64Scanner interface.
18 | func (dst *Int<%= pg_byte_size %>) ScanInt64(n int64, valid bool) error { 19 | if !valid { 20 | *dst = 0 21 | return nil 22 | } 23 | // Reject values outside the range of the target integer type instead of silently truncating them. 24 | if n < math.MinInt<%= pg_bit_size %> { 25 | return fmt.Errorf("%d is less than minimum value for Int<%= pg_byte_size %>", n) 26 | } 27 | if n > math.MaxInt<%= pg_bit_size %> { 28 | return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n) 29 | } 30 | *dst = Int<%= pg_byte_size %>(n) 31 | 32 | return nil 33 | } 34 | 35 | // Scan implements the database/sql Scanner interface. 36 | func (dst *Int<%= pg_byte_size %>) Scan(src any) error { 37 | if src == nil { 38 | *dst = 0 39 | return nil 40 | } 41 | 42 | var nullable pgtype.Int<%= pg_byte_size %> 43 | err := nullable.Scan(src) 44 | if err != nil { 45 | return err 46 | } 47 | 48 | *dst = Int<%= pg_byte_size %>(nullable.Int<%= pg_bit_size %>) 49 | 50 | return nil 51 | } 52 | 53 | // Value implements the database/sql/driver Valuer interface. 54 | func (src Int<%= pg_byte_size %>) Value() (driver.Value, error) { 55 | if src == 0 { 56 | return nil, nil 57 | } 58 | return int64(src), nil 59 | } 60 | <% end %> 61 | -------------------------------------------------------------------------------- /pgtype/zeronull/int_test.go: -------------------------------------------------------------------------------- 1 | // Do not edit.
Generated from pgtype/zeronull/int_test.go.erb 2 | package zeronull_test 3 | 4 | import ( 5 | "context" 6 | "testing" 7 | 8 | "github.com/jackc/pgx/v5/pgtype/zeronull" 9 | "github.com/jackc/pgx/v5/pgxtest" 10 | ) 11 | 12 | func TestInt2Transcode(t *testing.T) { 13 | pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int2", []pgxtest.ValueRoundTripTest{ 14 | { 15 | (zeronull.Int2)(1), 16 | new(zeronull.Int2), 17 | isExpectedEq((zeronull.Int2)(1)), 18 | }, 19 | { 20 | nil, 21 | new(zeronull.Int2), 22 | isExpectedEq((zeronull.Int2)(0)), 23 | }, 24 | { 25 | (zeronull.Int2)(0), 26 | new(any), 27 | isExpectedEq(nil), 28 | }, 29 | }) 30 | } 31 | 32 | func TestInt4Transcode(t *testing.T) { 33 | pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int4", []pgxtest.ValueRoundTripTest{ 34 | { 35 | (zeronull.Int4)(1), 36 | new(zeronull.Int4), 37 | isExpectedEq((zeronull.Int4)(1)), 38 | }, 39 | { 40 | nil, 41 | new(zeronull.Int4), 42 | isExpectedEq((zeronull.Int4)(0)), 43 | }, 44 | { 45 | (zeronull.Int4)(0), 46 | new(any), 47 | isExpectedEq(nil), 48 | }, 49 | }) 50 | } 51 | 52 | func TestInt8Transcode(t *testing.T) { 53 | pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int8", []pgxtest.ValueRoundTripTest{ 54 | { 55 | (zeronull.Int8)(1), 56 | new(zeronull.Int8), 57 | isExpectedEq((zeronull.Int8)(1)), 58 | }, 59 | { 60 | nil, 61 | new(zeronull.Int8), 62 | isExpectedEq((zeronull.Int8)(0)), 63 | }, 64 | { 65 | (zeronull.Int8)(0), 66 | new(any), 67 | isExpectedEq(nil), 68 | }, 69 | }) 70 | } 71 | -------------------------------------------------------------------------------- /pgtype/zeronull/int_test.go.erb: -------------------------------------------------------------------------------- 1 | package zeronull_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/jackc/pgx/v5/pgtype/zeronull" 8 | "github.com/jackc/pgx/v5/pgxtest" 9 | ) 10 | 11 | <% [2, 4, 8].each do
|pg_byte_size| %> 12 | <% pg_bit_size = pg_byte_size * 8 %> 13 | func TestInt<%= pg_byte_size %>Transcode(t *testing.T) { 14 | pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int<%= pg_byte_size %>", []pgxtest.ValueRoundTripTest{ 15 | { 16 | (zeronull.Int<%= pg_byte_size %>)(1), 17 | new(zeronull.Int<%= pg_byte_size %>), 18 | isExpectedEq((zeronull.Int<%= pg_byte_size %>)(1)), 19 | }, 20 | { 21 | nil, 22 | new(zeronull.Int<%= pg_byte_size %>), 23 | isExpectedEq((zeronull.Int<%= pg_byte_size %>)(0)), 24 | }, 25 | { 26 | (zeronull.Int<%= pg_byte_size %>)(0), 27 | new(any), 28 | isExpectedEq(nil), 29 | }, 30 | }) 31 | } 32 | <% end %> 33 | -------------------------------------------------------------------------------- /pgtype/zeronull/text.go: -------------------------------------------------------------------------------- 1 | package zeronull 2 | 3 | import ( 4 | "database/sql/driver" 5 | 6 | "github.com/jackc/pgx/v5/pgtype" 7 | ) 8 | 9 | type Text string 10 | 11 | func (Text) SkipUnderlyingTypePlan() {} 12 | 13 | // ScanText implements the TextScanner interface. 14 | func (dst *Text) ScanText(v pgtype.Text) error { 15 | if !v.Valid { 16 | *dst = "" 17 | return nil 18 | } 19 | 20 | *dst = Text(v.String) 21 | 22 | return nil 23 | } 24 | 25 | // Scan implements the database/sql Scanner interface. 26 | func (dst *Text) Scan(src any) error { 27 | if src == nil { 28 | *dst = "" 29 | return nil 30 | } 31 | 32 | var nullable pgtype.Text 33 | err := nullable.Scan(src) 34 | if err != nil { 35 | return err 36 | } 37 | 38 | *dst = Text(nullable.String) 39 | 40 | return nil 41 | } 42 | 43 | // Value implements the database/sql/driver Valuer interface.
44 | func (src Text) Value() (driver.Value, error) { 45 | if src == "" { 46 | return nil, nil 47 | } 48 | return string(src), nil 49 | } 50 | -------------------------------------------------------------------------------- /pgtype/zeronull/text_test.go: -------------------------------------------------------------------------------- 1 | package zeronull_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/jackc/pgx/v5/pgtype/zeronull" 8 | "github.com/jackc/pgx/v5/pgxtest" 9 | ) 10 | 11 | func TestTextTranscode(t *testing.T) { 12 | pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "text", []pgxtest.ValueRoundTripTest{ 13 | { 14 | (zeronull.Text)("foo"), 15 | new(zeronull.Text), 16 | isExpectedEq((zeronull.Text)("foo")), 17 | }, 18 | { 19 | nil, 20 | new(zeronull.Text), 21 | isExpectedEq((zeronull.Text)("")), 22 | }, 23 | { 24 | (zeronull.Text)(""), 25 | new(any), 26 | isExpectedEq(nil), 27 | }, 28 | }) 29 | } 30 | -------------------------------------------------------------------------------- /pgtype/zeronull/timestamp.go: -------------------------------------------------------------------------------- 1 | package zeronull 2 | 3 | import ( 4 | "database/sql/driver" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/jackc/pgx/v5/pgtype" 9 | ) 10 | 11 | type Timestamp time.Time 12 | 13 | func (Timestamp) SkipUnderlyingTypePlan() {} 14 | // ScanTimestamp stores v in ts: a NULL (invalid) v yields the zero Timestamp, and infinite values are rejected with an error. 15 | func (ts *Timestamp) ScanTimestamp(v pgtype.Timestamp) error { 16 | if !v.Valid { 17 | *ts = Timestamp{} 18 | return nil 19 | } 20 | 21 | switch v.InfinityModifier { 22 | case pgtype.Finite: 23 | *ts = Timestamp(v.Time) 24 | return nil 25 | case pgtype.Infinity: 26 | return fmt.Errorf("cannot scan Infinity into *time.Time") 27 | case pgtype.NegativeInfinity: 28 | return fmt.Errorf("cannot scan -Infinity into *time.Time") 29 | default: 30 | return fmt.Errorf("invalid InfinityModifier: %v", v.InfinityModifier) 31 | } 32 | } 33 | // TimestampValue returns ts as a pgtype.Timestamp; the zero time is returned as NULL (Valid: false).
34 | func (ts Timestamp) TimestampValue() (pgtype.Timestamp,
error) { 35 | if time.Time(ts).IsZero() { 36 | return pgtype.Timestamp{}, nil 37 | } 38 | 39 | return pgtype.Timestamp{Time: time.Time(ts), Valid: true}, nil 40 | } 41 | 42 | // Scan implements the database/sql Scanner interface. 43 | func (ts *Timestamp) Scan(src any) error { 44 | if src == nil { 45 | *ts = Timestamp{} 46 | return nil 47 | } 48 | 49 | var nullable pgtype.Timestamp 50 | err := nullable.Scan(src) 51 | if err != nil { 52 | return err 53 | } 54 | 55 | *ts = Timestamp(nullable.Time) 56 | 57 | return nil 58 | } 59 | 60 | // Value implements the database/sql/driver Valuer interface. 61 | func (ts Timestamp) Value() (driver.Value, error) { 62 | if time.Time(ts).IsZero() { 63 | return nil, nil 64 | } 65 | 66 | return time.Time(ts), nil 67 | } 68 | -------------------------------------------------------------------------------- /pgtype/zeronull/timestamp_test.go: -------------------------------------------------------------------------------- 1 | package zeronull_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/jackc/pgx/v5/pgtype/zeronull" 9 | "github.com/jackc/pgx/v5/pgxtest" 10 | ) 11 | 12 | func isExpectedEqTimestamp(a any) func(any) bool { 13 | return func(v any) bool { 14 | at := time.Time(a.(zeronull.Timestamp)) 15 | vt := time.Time(v.(zeronull.Timestamp)) 16 | // Compare instants with Equal rather than ==, which would also compare Location and monotonic-clock data. 
17 | return at.Equal(vt) 18 | } 19 | } 20 | 21 | func TestTimestampTranscode(t *testing.T) { 22 | pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "timestamp", []pgxtest.ValueRoundTripTest{ 23 | { 24 | (zeronull.Timestamp)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)), 25 | new(zeronull.Timestamp), 26 | isExpectedEqTimestamp((zeronull.Timestamp)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC))), 27 | }, 28 | { 29 | nil, 30 | new(zeronull.Timestamp), 31 | isExpectedEqTimestamp((zeronull.Timestamp)(time.Time{})), 32 | }, 33 | { 34 | (zeronull.Timestamp)(time.Time{}), 35 | new(any), 36 | isExpectedEq(nil), 37 | }, 38 | }) 39 | } 40 |
-------------------------------------------------------------------------------- /pgtype/zeronull/timestamptz.go: -------------------------------------------------------------------------------- 1 | package zeronull 2 | 3 | import ( 4 | "database/sql/driver" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/jackc/pgx/v5/pgtype" 9 | ) 10 | 11 | type Timestamptz time.Time 12 | 13 | func (Timestamptz) SkipUnderlyingTypePlan() {} 14 | // ScanTimestamptz stores v in ts: a NULL (invalid) v yields the zero Timestamptz, and infinite values are rejected with an error. 15 | func (ts *Timestamptz) ScanTimestamptz(v pgtype.Timestamptz) error { 16 | if !v.Valid { 17 | *ts = Timestamptz{} 18 | return nil 19 | } 20 | 21 | switch v.InfinityModifier { 22 | case pgtype.Finite: 23 | *ts = Timestamptz(v.Time) 24 | return nil 25 | case pgtype.Infinity: 26 | return fmt.Errorf("cannot scan Infinity into *time.Time") 27 | case pgtype.NegativeInfinity: 28 | return fmt.Errorf("cannot scan -Infinity into *time.Time") 29 | default: 30 | return fmt.Errorf("invalid InfinityModifier: %v", v.InfinityModifier) 31 | } 32 | } 33 | // TimestamptzValue returns ts as a pgtype.Timestamptz; the zero time is returned as NULL (Valid: false). 34 | func (ts Timestamptz) TimestamptzValue() (pgtype.Timestamptz, error) { 35 | if time.Time(ts).IsZero() { 36 | return pgtype.Timestamptz{}, nil 37 | } 38 | 39 | return pgtype.Timestamptz{Time: time.Time(ts), Valid: true}, nil 40 | } 41 | 42 | // Scan implements the database/sql Scanner interface. 43 | func (ts *Timestamptz) Scan(src any) error { 44 | if src == nil { 45 | *ts = Timestamptz{} 46 | return nil 47 | } 48 | // Decode via pgtype.Timestamptz (timestamp with time zone), not pgtype.Timestamp (timestamp without time zone). 49 | var nullable pgtype.Timestamptz 50 | err := nullable.Scan(src) 51 | if err != nil { 52 | return err 53 | } 54 | 55 | *ts = Timestamptz(nullable.Time) 56 | 57 | return nil 58 | } 59 | 60 | // Value implements the database/sql/driver Valuer interface.
61 | func (ts Timestamptz) Value() (driver.Value, error) { 62 | if time.Time(ts).IsZero() { 63 | return nil, nil 64 | } 65 | 66 | return time.Time(ts), nil 67 | } 68 | -------------------------------------------------------------------------------- /pgtype/zeronull/timestamptz_test.go: -------------------------------------------------------------------------------- 1 | package zeronull_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/jackc/pgx/v5/pgtype/zeronull" 9 | "github.com/jackc/pgx/v5/pgxtest" 10 | ) 11 | 12 | func isExpectedEqTimestamptz(a any) func(any) bool { 13 | return func(v any) bool { 14 | at := time.Time(a.(zeronull.Timestamptz)) 15 | vt := time.Time(v.(zeronull.Timestamptz)) 16 | // Compare instants with Equal rather than ==; the round-tripped value may carry a different Location than the input. 17 | return at.Equal(vt) 18 | } 19 | } 20 | 21 | func TestTimestamptzTranscode(t *testing.T) { 22 | pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "timestamptz", []pgxtest.ValueRoundTripTest{ 23 | { 24 | (zeronull.Timestamptz)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)), 25 | new(zeronull.Timestamptz), 26 | isExpectedEqTimestamptz((zeronull.Timestamptz)(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC))), 27 | }, 28 | { 29 | nil, 30 | new(zeronull.Timestamptz), 31 | isExpectedEqTimestamptz((zeronull.Timestamptz)(time.Time{})), 32 | }, 33 | { 34 | (zeronull.Timestamptz)(time.Time{}), 35 | new(any), 36 | isExpectedEq(nil), 37 | }, 38 | }) 39 | } 40 | -------------------------------------------------------------------------------- /pgtype/zeronull/uuid.go: -------------------------------------------------------------------------------- 1 | package zeronull 2 | 3 | import ( 4 | "database/sql/driver" 5 | 6 | "github.com/jackc/pgx/v5/pgtype" 7 | ) 8 | 9 | type UUID [16]byte 10 | 11 | func (UUID) SkipUnderlyingTypePlan() {} 12 | 13 | // ScanUUID implements the UUIDScanner interface.
14 | func (u *UUID) ScanUUID(v pgtype.UUID) error { 15 | if !v.Valid { 16 | *u = UUID{} 17 | return nil 18 | } 19 | 20 | *u = UUID(v.Bytes) 21 | 22 | return nil 23 | } 24 | // UUIDValue returns u as a pgtype.UUID; the all-zero UUID is returned as NULL (Valid: false). 25 | func (u UUID) UUIDValue() (pgtype.UUID, error) { 26 | if u == (UUID{}) { 27 | return pgtype.UUID{}, nil 28 | } 29 | return pgtype.UUID{Bytes: u, Valid: true}, nil 30 | } 31 | 32 | // Scan implements the database/sql Scanner interface. 33 | func (u *UUID) Scan(src any) error { 34 | if src == nil { 35 | *u = UUID{} 36 | return nil 37 | } 38 | 39 | var nullable pgtype.UUID 40 | err := nullable.Scan(src) 41 | if err != nil { 42 | return err 43 | } 44 | 45 | *u = UUID(nullable.Bytes) 46 | 47 | return nil 48 | } 49 | 50 | // Value implements the database/sql/driver Valuer interface. 51 | func (u UUID) Value() (driver.Value, error) { 52 | if u == (UUID{}) { 53 | return nil, nil 54 | } 55 | 56 | buf, err := pgtype.UUIDCodec{}.PlanEncode(nil, pgtype.UUIDOID, pgtype.TextFormatCode, u).Encode(u, nil) 57 | if err != nil { 58 | return nil, err 59 | } 60 | 61 | return string(buf), nil 62 | } 63 | -------------------------------------------------------------------------------- /pgtype/zeronull/uuid_test.go: -------------------------------------------------------------------------------- 1 | package zeronull_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/jackc/pgx/v5/pgtype/zeronull" 8 | "github.com/jackc/pgx/v5/pgxtest" 9 | ) 10 | 11 | func TestUUIDTranscode(t *testing.T) { 12 | pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "uuid", []pgxtest.ValueRoundTripTest{ 13 | { 14 | (zeronull.UUID)([16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}), 15 | new(zeronull.UUID), 16 | isExpectedEq((zeronull.UUID)([16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15})), 17 | }, 18 | { 19 | nil, 20 | new(zeronull.UUID), 21 | isExpectedEq((zeronull.UUID)([16]byte{})), 22 | }, 23 | { 24 | (zeronull.UUID)([16]byte{}), 25 | new(any), 26 |
isExpectedEq(nil), 27 | }, 28 | }) 29 | } 30 | -------------------------------------------------------------------------------- /pgtype/zeronull/zeronull.go: -------------------------------------------------------------------------------- 1 | package zeronull 2 | 3 | import ( 4 | "github.com/jackc/pgx/v5/pgtype" 5 | ) 6 | 7 | // Register registers the zeronull types so they can be used in query exec modes that do not know the server OIDs. 8 | func Register(m *pgtype.Map) { 9 | m.RegisterDefaultPgType(Float8(0), "float8") 10 | m.RegisterDefaultPgType(Int2(0), "int2") 11 | m.RegisterDefaultPgType(Int4(0), "int4") 12 | m.RegisterDefaultPgType(Int8(0), "int8") 13 | m.RegisterDefaultPgType(Text(""), "text") 14 | m.RegisterDefaultPgType(Timestamp{}, "timestamp") 15 | m.RegisterDefaultPgType(Timestamptz{}, "timestamptz") 16 | m.RegisterDefaultPgType(UUID{}, "uuid") 17 | } 18 | -------------------------------------------------------------------------------- /pgtype/zeronull/zeronull_test.go: -------------------------------------------------------------------------------- 1 | package zeronull_test 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "testing" 7 | 8 | "github.com/jackc/pgx/v5" 9 | "github.com/jackc/pgx/v5/pgtype/zeronull" 10 | "github.com/jackc/pgx/v5/pgxtest" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | var defaultConnTestRunner pgxtest.ConnTestRunner 15 | // init builds defaultConnTestRunner from PGX_TEST_DATABASE and registers the zeronull types on connections via AfterConnect. 16 | func init() { 17 | defaultConnTestRunner = pgxtest.DefaultConnTestRunner() 18 | defaultConnTestRunner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { 19 | config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) 20 | require.NoError(t, err) 21 | return config 22 | } 23 | defaultConnTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { 24 | zeronull.Register(conn.TypeMap()) 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /pgxpool/batch_results.go:
-------------------------------------------------------------------------------- 1 | package pgxpool 2 | 3 | import ( 4 | "github.com/jackc/pgx/v5" 5 | "github.com/jackc/pgx/v5/pgconn" 6 | ) 7 | 8 | type errBatchResults struct { 9 | err error 10 | } 11 | 12 | func (br errBatchResults) Exec() (pgconn.CommandTag, error) { 13 | return pgconn.CommandTag{}, br.err 14 | } 15 | 16 | func (br errBatchResults) Query() (pgx.Rows, error) { 17 | return errRows{err: br.err}, br.err 18 | } 19 | 20 | func (br errBatchResults) QueryRow() pgx.Row { 21 | return errRow{err: br.err} 22 | } 23 | 24 | func (br errBatchResults) Close() error { 25 | return br.err 26 | } 27 | 28 | type poolBatchResults struct { 29 | br pgx.BatchResults 30 | c *Conn 31 | } 32 | 33 | func (br *poolBatchResults) Exec() (pgconn.CommandTag, error) { 34 | return br.br.Exec() 35 | } 36 | 37 | func (br *poolBatchResults) Query() (pgx.Rows, error) { 38 | return br.br.Query() 39 | } 40 | 41 | func (br *poolBatchResults) QueryRow() pgx.Row { 42 | return br.br.QueryRow() 43 | } 44 | 45 | func (br *poolBatchResults) Close() error { 46 | err := br.br.Close() 47 | if br.c != nil { 48 | br.c.Release() 49 | br.c = nil 50 | } 51 | return err 52 | } 53 | -------------------------------------------------------------------------------- /pgxpool/bench_test.go: -------------------------------------------------------------------------------- 1 | package pgxpool_test 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "testing" 7 | 8 | "github.com/jackc/pgx/v5" 9 | "github.com/jackc/pgx/v5/pgxpool" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func BenchmarkAcquireAndRelease(b *testing.B) { 14 | pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) 15 | require.NoError(b, err) 16 | defer pool.Close() 17 | 18 | b.ResetTimer() 19 | for i := 0; i < b.N; i++ { 20 | c, err := pool.Acquire(context.Background()) 21 | if err != nil { 22 | b.Fatal(err) 23 | } 24 | c.Release() 25 | } 26 | } 27 | 28 | func
BenchmarkMinimalPreparedSelectBaseline(b *testing.B) { 29 | config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) 30 | require.NoError(b, err) 31 | 32 | config.AfterConnect = func(ctx context.Context, c *pgx.Conn) error { 33 | _, err := c.Prepare(ctx, "ps1", "select $1::int8") 34 | return err 35 | } 36 | 37 | db, err := pgxpool.NewWithConfig(context.Background(), config) 38 | require.NoError(b, err) 39 | defer db.Close() 40 | 41 | conn, err := db.Acquire(context.Background()) 42 | require.NoError(b, err) 43 | defer conn.Release() 44 | 45 | var n int64 46 | 47 | b.ResetTimer() 48 | for i := 0; i < b.N; i++ { 49 | err = conn.QueryRow(context.Background(), "ps1", i).Scan(&n) 50 | if err != nil { 51 | b.Fatal(err) 52 | } 53 | 54 | if n != int64(i) { 55 | b.Fatalf("expected %d, got %d", i, n) 56 | } 57 | } 58 | } 59 | 60 | func BenchmarkMinimalPreparedSelect(b *testing.B) { 61 | config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) 62 | require.NoError(b, err) 63 | 64 | config.AfterConnect = func(ctx context.Context, c *pgx.Conn) error { 65 | _, err := c.Prepare(ctx, "ps1", "select $1::int8") 66 | return err 67 | } 68 | 69 | db, err := pgxpool.NewWithConfig(context.Background(), config) 70 | require.NoError(b, err) 71 | defer db.Close() 72 | 73 | var n int64 74 | 75 | b.ResetTimer() 76 | for i := 0; i < b.N; i++ { 77 | err = db.QueryRow(context.Background(), "ps1", i).Scan(&n) 78 | if err != nil { 79 | b.Fatal(err) 80 | } 81 | 82 | if n != int64(i) { 83 | b.Fatalf("expected %d, got %d", i, n) 84 | } 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /pgxpool/conn_test.go: -------------------------------------------------------------------------------- 1 | package pgxpool_test 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "testing" 7 | 8 | "github.com/jackc/pgx/v5/pgxpool" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func TestConnExec(t *testing.T) { 13 | t.Parallel() 14 | 15 | pool, err := pgxpool.New(context.Background(),
os.Getenv("PGX_TEST_DATABASE")) 16 | require.NoError(t, err) 17 | defer pool.Close() 18 | 19 | c, err := pool.Acquire(context.Background()) 20 | require.NoError(t, err) 21 | defer c.Release() 22 | 23 | testExec(t, c) 24 | } 25 | 26 | func TestConnQuery(t *testing.T) { 27 | t.Parallel() 28 | 29 | pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) 30 | require.NoError(t, err) 31 | defer pool.Close() 32 | 33 | c, err := pool.Acquire(context.Background()) 34 | require.NoError(t, err) 35 | defer c.Release() 36 | 37 | testQuery(t, c) 38 | } 39 | 40 | func TestConnQueryRow(t *testing.T) { 41 | t.Parallel() 42 | 43 | pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) 44 | require.NoError(t, err) 45 | defer pool.Close() 46 | 47 | c, err := pool.Acquire(context.Background()) 48 | require.NoError(t, err) 49 | defer c.Release() 50 | 51 | testQueryRow(t, c) 52 | } 53 | 54 | func TestConnSendBatch(t *testing.T) { 55 | t.Parallel() 56 | 57 | pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) 58 | require.NoError(t, err) 59 | defer pool.Close() 60 | 61 | c, err := pool.Acquire(context.Background()) 62 | require.NoError(t, err) 63 | defer c.Release() 64 | 65 | testSendBatch(t, c) 66 | } 67 | 68 | func TestConnCopyFrom(t *testing.T) { 69 | t.Parallel() 70 | 71 | pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) 72 | require.NoError(t, err) 73 | defer pool.Close() 74 | 75 | c, err := pool.Acquire(context.Background()) 76 | require.NoError(t, err) 77 | defer c.Release() 78 | 79 | testCopyFrom(t, c) 80 | } 81 | -------------------------------------------------------------------------------- /pgxpool/doc.go: -------------------------------------------------------------------------------- 1 | // Package pgxpool is a concurrency-safe connection pool for pgx. 2 | /* 3 | pgxpool implements a nearly identical interface to pgx connections.
4 | 5 | Creating a Pool 6 | 7 | The primary way of creating a pool is with `pgxpool.New`. 8 | 9 | pool, err := pgxpool.New(context.Background(), os.Getenv("DATABASE_URL")) 10 | 11 | The database connection string can be in URL or DSN format. PostgreSQL settings, pgx settings, and pool settings can be 12 | specified here. In addition, a config struct can be created by `ParseConfig` and modified before establishing the 13 | pool with `NewWithConfig`. 14 | 15 | config, err := pgxpool.ParseConfig(os.Getenv("DATABASE_URL")) 16 | if err != nil { 17 | // ... 18 | } 19 | config.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error { 20 | // do something with every new connection 21 | return nil 22 | } 23 | 24 | pool, err := pgxpool.NewWithConfig(context.Background(), config) 25 | 26 | A pool returns without waiting for any connections to be established. Acquire a connection immediately after creating 27 | the pool to check if a connection can successfully be established. 28 | */ 29 | package pgxpool 30 | -------------------------------------------------------------------------------- /pgxpool/rows.go: -------------------------------------------------------------------------------- 1 | package pgxpool 2 | 3 | import ( 4 | "github.com/jackc/pgx/v5" 5 | "github.com/jackc/pgx/v5/pgconn" 6 | ) 7 | 8 | type errRows struct { 9 | err error 10 | } 11 | 12 | func (errRows) Close() {} 13 | func (e errRows) Err() error { return e.err } 14 | func (errRows) CommandTag() pgconn.CommandTag { return pgconn.CommandTag{} } 15 | func (errRows) FieldDescriptions() []pgconn.FieldDescription { return nil } 16 | func (errRows) Next() bool { return false } 17 | func (e errRows) Scan(dest ...any) error { return e.err } 18 | func (e errRows) Values() ([]any, error) { return nil, e.err } 19 | func (e errRows) RawValues() [][]byte { return nil } 20 | func (e errRows) Conn() *pgx.Conn { return nil } 21 | 22 | type errRow struct { 23 | err error 24 | } 25 | 26 | func (e errRow) Scan(dest ...any) error {
return e.err } 27 | 28 | type poolRows struct { 29 | r pgx.Rows 30 | c *Conn 31 | err error 32 | } 33 | 34 | func (rows *poolRows) Close() { 35 | rows.r.Close() 36 | if rows.c != nil { 37 | rows.c.Release() 38 | rows.c = nil 39 | } 40 | } 41 | 42 | func (rows *poolRows) Err() error { 43 | if rows.err != nil { 44 | return rows.err 45 | } 46 | return rows.r.Err() 47 | } 48 | 49 | func (rows *poolRows) CommandTag() pgconn.CommandTag { 50 | return rows.r.CommandTag() 51 | } 52 | 53 | func (rows *poolRows) FieldDescriptions() []pgconn.FieldDescription { 54 | return rows.r.FieldDescriptions() 55 | } 56 | 57 | func (rows *poolRows) Next() bool { 58 | if rows.err != nil { 59 | return false 60 | } 61 | 62 | n := rows.r.Next() 63 | if !n { 64 | rows.Close() 65 | } 66 | return n 67 | } 68 | 69 | func (rows *poolRows) Scan(dest ...any) error { 70 | err := rows.r.Scan(dest...) 71 | if err != nil { 72 | rows.Close() 73 | } 74 | return err 75 | } 76 | 77 | func (rows *poolRows) Values() ([]any, error) { 78 | values, err := rows.r.Values() 79 | if err != nil { 80 | rows.Close() 81 | } 82 | return values, err 83 | } 84 | 85 | func (rows *poolRows) RawValues() [][]byte { 86 | return rows.r.RawValues() 87 | } 88 | 89 | func (rows *poolRows) Conn() *pgx.Conn { 90 | return rows.r.Conn() 91 | } 92 | 93 | type poolRow struct { 94 | r pgx.Row 95 | c *Conn 96 | err error 97 | } 98 | 99 | func (row *poolRow) Scan(dest ...any) error { 100 | if row.err != nil { 101 | return row.err 102 | } 103 | 104 | err := row.r.Scan(dest...) 105 | // Clear c after releasing it so a repeated Scan cannot return the same connection to the pool twice. 106 | if row.c != nil { 107 | row.c.Release() 108 | row.c = nil 109 | } 110 | return err 111 | } 112 | -------------------------------------------------------------------------------- /pgxpool/stat.go: -------------------------------------------------------------------------------- 1 | package pgxpool 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/jackc/puddle/v2" 7 | ) 8 | 9 | // Stat is a snapshot of Pool statistics.
10 | type Stat struct { 11 | s *puddle.Stat 12 | newConnsCount int64 13 | lifetimeDestroyCount int64 14 | idleDestroyCount int64 15 | } 16 | 17 | // AcquireCount returns the cumulative count of successful acquires from the pool. 18 | func (s *Stat) AcquireCount() int64 { 19 | return s.s.AcquireCount() 20 | } 21 | 22 | // AcquireDuration returns the total duration of all successful acquires from 23 | // the pool. 24 | func (s *Stat) AcquireDuration() time.Duration { 25 | return s.s.AcquireDuration() 26 | } 27 | 28 | // AcquiredConns returns the number of currently acquired connections in the pool. 29 | func (s *Stat) AcquiredConns() int32 { 30 | return s.s.AcquiredResources() 31 | } 32 | 33 | // CanceledAcquireCount returns the cumulative count of acquires from the pool 34 | // that were canceled by a context. 35 | func (s *Stat) CanceledAcquireCount() int64 { 36 | return s.s.CanceledAcquireCount() 37 | } 38 | 39 | // ConstructingConns returns the number of conns with construction in progress in 40 | // the pool. 41 | func (s *Stat) ConstructingConns() int32 { 42 | return s.s.ConstructingResources() 43 | } 44 | 45 | // EmptyAcquireCount returns the cumulative count of successful acquires from the pool 46 | // that waited for a resource to be released or constructed because the pool was 47 | // empty. 48 | func (s *Stat) EmptyAcquireCount() int64 { 49 | return s.s.EmptyAcquireCount() 50 | } 51 | 52 | // IdleConns returns the number of currently idle conns in the pool. 53 | func (s *Stat) IdleConns() int32 { 54 | return s.s.IdleResources() 55 | } 56 | 57 | // MaxConns returns the maximum size of the pool. 58 | func (s *Stat) MaxConns() int32 { 59 | return s.s.MaxResources() 60 | } 61 | 62 | // TotalConns returns the total number of resources currently in the pool. 63 | // The value is the sum of ConstructingConns, AcquiredConns, and 64 | // IdleConns. 
65 | func (s *Stat) TotalConns() int32 { 66 | return s.s.TotalResources() 67 | } 68 | 69 | // NewConnsCount returns the cumulative count of new connections opened. 70 | func (s *Stat) NewConnsCount() int64 { 71 | return s.newConnsCount 72 | } 73 | 74 | // MaxLifetimeDestroyCount returns the cumulative count of connections destroyed 75 | // because they exceeded MaxConnLifetime. 76 | func (s *Stat) MaxLifetimeDestroyCount() int64 { 77 | return s.lifetimeDestroyCount 78 | } 79 | 80 | // MaxIdleDestroyCount returns the cumulative count of connections destroyed because 81 | // they exceeded MaxConnIdleTime. 82 | func (s *Stat) MaxIdleDestroyCount() int64 { 83 | return s.idleDestroyCount 84 | } 85 | -------------------------------------------------------------------------------- /pgxpool/tx_test.go: -------------------------------------------------------------------------------- 1 | package pgxpool_test 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "testing" 7 | 8 | "github.com/jackc/pgx/v5/pgxpool" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func TestTxExec(t *testing.T) { 13 | t.Parallel() 14 | 15 | pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) 16 | require.NoError(t, err) 17 | defer pool.Close() 18 | 19 | tx, err := pool.Begin(context.Background()) 20 | require.NoError(t, err) 21 | defer tx.Rollback(context.Background()) 22 | 23 | testExec(t, tx) 24 | } 25 | 26 | func TestTxQuery(t *testing.T) { 27 | t.Parallel() 28 | 29 | pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) 30 | require.NoError(t, err) 31 | defer pool.Close() 32 | 33 | tx, err := pool.Begin(context.Background()) 34 | require.NoError(t, err) 35 | defer tx.Rollback(context.Background()) 36 | 37 | testQuery(t, tx) 38 | } 39 | 40 | func TestTxQueryRow(t *testing.T) { 41 | t.Parallel() 42 | 43 | pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) 44 | require.NoError(t, err) 45 | defer pool.Close() 46 | 47 
| tx, err := pool.Begin(context.Background()) 48 | require.NoError(t, err) 49 | defer tx.Rollback(context.Background()) 50 | 51 | testQueryRow(t, tx) 52 | } 53 | 54 | func TestTxSendBatch(t *testing.T) { 55 | t.Parallel() 56 | 57 | pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) 58 | require.NoError(t, err) 59 | defer pool.Close() 60 | 61 | tx, err := pool.Begin(context.Background()) 62 | require.NoError(t, err) 63 | defer tx.Rollback(context.Background()) 64 | 65 | testSendBatch(t, tx) 66 | } 67 | 68 | func TestTxCopyFrom(t *testing.T) { 69 | t.Parallel() 70 | 71 | pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) 72 | require.NoError(t, err) 73 | defer pool.Close() 74 | 75 | tx, err := pool.Begin(context.Background()) 76 | require.NoError(t, err) 77 | defer tx.Rollback(context.Background()) 78 | 79 | testCopyFrom(t, tx) 80 | } 81 | -------------------------------------------------------------------------------- /pipeline_test.go: -------------------------------------------------------------------------------- 1 | package pgx_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/jackc/pgx/v5" 8 | "github.com/jackc/pgx/v5/pgconn" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func TestPipelineWithoutPreparedOrDescribedStatements(t *testing.T) { 13 | t.Parallel() 14 | 15 | defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { 16 | pipeline := conn.PgConn().StartPipeline(ctx) 17 | 18 | eqb := pgx.ExtendedQueryBuilder{} 19 | 20 | err := eqb.Build(conn.TypeMap(), nil, []any{1, 2}) 21 | require.NoError(t, err) 22 | pipeline.SendQueryParams(`select $1::bigint + $2::bigint`, eqb.ParamValues, nil, eqb.ParamFormats, eqb.ResultFormats) 23 | 24 | err = eqb.Build(conn.TypeMap(), nil, []any{3, 4, 5}) 25 | require.NoError(t, err) 26 | pipeline.SendQueryParams(`select $1::bigint + $2::bigint + $3::bigint`, eqb.ParamValues, nil, 
eqb.ParamFormats, eqb.ResultFormats) 27 | 28 | err = pipeline.Sync() 29 | require.NoError(t, err) 30 | 31 | results, err := pipeline.GetResults() 32 | require.NoError(t, err) 33 | rr, ok := results.(*pgconn.ResultReader) 34 | require.True(t, ok) 35 | rows := pgx.RowsFromResultReader(conn.TypeMap(), rr) 36 | 37 | rowCount := 0 38 | var n int64 39 | for rows.Next() { 40 | err = rows.Scan(&n) 41 | require.NoError(t, err) 42 | rowCount++ 43 | } 44 | require.NoError(t, rows.Err()) 45 | require.Equal(t, 1, rowCount) 46 | require.Equal(t, "SELECT 1", rows.CommandTag().String()) 47 | require.EqualValues(t, 3, n) 48 | 49 | results, err = pipeline.GetResults() 50 | require.NoError(t, err) 51 | rr, ok = results.(*pgconn.ResultReader) 52 | require.True(t, ok) 53 | rows = pgx.RowsFromResultReader(conn.TypeMap(), rr) 54 | 55 | rowCount = 0 56 | n = 0 57 | for rows.Next() { 58 | err = rows.Scan(&n) 59 | require.NoError(t, err) 60 | rowCount++ 61 | } 62 | require.NoError(t, rows.Err()) 63 | require.Equal(t, 1, rowCount) 64 | require.Equal(t, "SELECT 1", rows.CommandTag().String()) 65 | require.EqualValues(t, 12, n) 66 | 67 | results, err = pipeline.GetResults() 68 | require.NoError(t, err) 69 | _, ok = results.(*pgconn.PipelineSync) 70 | require.True(t, ok) 71 | 72 | results, err = pipeline.GetResults() 73 | require.NoError(t, err) 74 | require.Nil(t, results) 75 | 76 | err = pipeline.Close() 77 | require.NoError(t, err) 78 | }) 79 | } 80 | -------------------------------------------------------------------------------- /testsetup/README.md: -------------------------------------------------------------------------------- 1 | # Test Setup 2 | 3 | This directory contains miscellaneous files used to setup a test database. 
4 | -------------------------------------------------------------------------------- /testsetup/ca.cnf: -------------------------------------------------------------------------------- 1 | [ req ] 2 | distinguished_name = dn 3 | [ dn ] 4 | commonName = ca 5 | [ ext ] 6 | basicConstraints =CA:TRUE,pathlen:0 7 | -------------------------------------------------------------------------------- /testsetup/localhost.cnf: -------------------------------------------------------------------------------- 1 | [ req ] 2 | default_bits = 2048 3 | distinguished_name = dn 4 | req_extensions = v3_req 5 | prompt = no 6 | [ dn ] 7 | commonName = localhost 8 | [ v3_req ] 9 | subjectAltName = @alt_names 10 | keyUsage = digitalSignature 11 | extendedKeyUsage = serverAuth 12 | [alt_names] 13 | DNS.1 = localhost 14 | -------------------------------------------------------------------------------- /testsetup/pg_hba.conf: -------------------------------------------------------------------------------- 1 | local all postgres trust 2 | local all all trust 3 | host all pgx_md5 127.0.0.1/32 md5 4 | host all pgx_scram 127.0.0.1/32 scram-sha-256 5 | host all pgx_pw 127.0.0.1/32 password 6 | hostssl all pgx_ssl 127.0.0.1/32 scram-sha-256 7 | hostssl all pgx_sslcert 127.0.0.1/32 cert 8 | -------------------------------------------------------------------------------- /testsetup/pgx_sslcert.cnf: -------------------------------------------------------------------------------- 1 | [ req ] 2 | default_bits = 2048 3 | distinguished_name = dn 4 | req_extensions = v3_req 5 | prompt = no 6 | [ dn ] 7 | commonName = pgx_sslcert 8 | [ v3_req ] 9 | keyUsage = digitalSignature 10 | -------------------------------------------------------------------------------- /testsetup/postgresql_setup.sql: -------------------------------------------------------------------------------- 1 | -- Create extensions and types. 
2 | create extension hstore; 3 | create domain uint64 as numeric(20,0); 4 | 5 | -- Create users for different types of connections and authentication. 6 | create user pgx_ssl with superuser PASSWORD 'secret'; 7 | create user pgx_sslcert with superuser PASSWORD 'secret'; 8 | set password_encryption = md5; 9 | create user pgx_md5 with superuser PASSWORD 'secret'; 10 | set password_encryption = 'scram-sha-256'; 11 | create user pgx_pw with superuser PASSWORD 'secret'; 12 | create user pgx_scram with superuser PASSWORD 'secret'; 13 | \set whoami `whoami` 14 | create user :whoami with superuser; -- unix domain socket user 15 | 16 | 17 | -- The tricky test user, below, has to actually exist so that it can be used in a test 18 | -- of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles. 19 | create user " tricky, ' } "" \\ test user " superuser password 'secret'; 20 | -------------------------------------------------------------------------------- /testsetup/postgresql_ssl.conf: -------------------------------------------------------------------------------- 1 | ssl = on 2 | ssl_cert_file = 'server.crt' 3 | ssl_key_file = 'server.key' 4 | ssl_ca_file = 'root.crt' 5 | -------------------------------------------------------------------------------- /values.go: -------------------------------------------------------------------------------- 1 | package pgx 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/jackc/pgx/v5/internal/anynil" 7 | "github.com/jackc/pgx/v5/internal/pgio" 8 | "github.com/jackc/pgx/v5/pgtype" 9 | ) 10 | 11 | // PostgreSQL format codes 12 | const ( 13 | TextFormatCode = 0 14 | BinaryFormatCode = 1 15 | ) 16 | 17 | func convertSimpleArgument(m *pgtype.Map, arg any) (any, error) { 18 | if anynil.Is(arg) { 19 | return nil, nil 20 | } 21 | 22 | buf, err := m.Encode(0, TextFormatCode, arg, []byte{}) 23 | if err != nil { 24 | return nil, err 25 | } 26 | if buf == nil { 27 | return nil, nil 28 | } 29 | return string(buf), nil 30 | 
} 31 | 32 | func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, error) { 33 | if anynil.Is(arg) { 34 | return pgio.AppendInt32(buf, -1), nil 35 | } 36 | 37 | sp := len(buf) 38 | buf = pgio.AppendInt32(buf, -1) 39 | argBuf, err := m.Encode(oid, BinaryFormatCode, arg, buf) 40 | if err != nil { 41 | if argBuf2, err2 := tryScanStringCopyValueThenEncode(m, buf, oid, arg); err2 == nil { 42 | argBuf = argBuf2 43 | } else { 44 | return nil, err 45 | } 46 | } 47 | 48 | if argBuf != nil { 49 | buf = argBuf 50 | pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) 51 | } 52 | return buf, nil 53 | } 54 | 55 | func tryScanStringCopyValueThenEncode(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, error) { 56 | s, ok := arg.(string) 57 | if !ok { 58 | return nil, errors.New("not a string") 59 | } 60 | 61 | var v any 62 | err := m.Scan(oid, TextFormatCode, []byte(s), &v) 63 | if err != nil { 64 | return nil, err 65 | } 66 | 67 | return m.Encode(oid, BinaryFormatCode, v, buf) 68 | } 69 | --------------------------------------------------------------------------------